Move tokio_kcp to local files

This commit is contained in:
KernelErr 2021-10-26 23:02:05 +08:00
parent bfce455a7e
commit a88a263d20
15 changed files with 1010 additions and 19 deletions

17
Cargo.lock generated
View File

@ -100,14 +100,16 @@ dependencies = [
name = "fourth"
version = "0.1.2"
dependencies = [
"byte_string",
"bytes 1.1.0",
"futures",
"kcp",
"log 0.4.14",
"pretty_env_logger",
"serde",
"serde_yaml",
"tls-parser",
"tokio",
"tokio_kcp",
]
[[package]]
@ -764,19 +766,6 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio_kcp"
version = "0.8.0"
source = "git+https://github.com/Matrix-Zhang/tokio_kcp?rev=d93a2f2#d93a2f2ad2cba731dce8490f0960246d2a655033"
dependencies = [
"byte_string",
"bytes 1.1.0",
"futures",
"kcp",
"log 0.4.14",
"tokio",
]
[[package]]
name = "unicode-xid"
version = "0.2.2"

View File

@ -24,4 +24,7 @@ futures = "0.3"
tls-parser = "0.11"
tokio = { version = "1.0", features = ["full"] }
tokio_kcp = { git = "https://github.com/Matrix-Zhang/tokio_kcp", rev="d93a2f2" }
bytes = "1.1"
kcp = "0.4"
byte_string = "1"

View File

@ -10,7 +10,7 @@ Fourth is a layer 4 proxy implemented by Rust to listen on specific ports and tr
- Listen on specific port and proxy to local or remote port
- SNI-based rule without terminating TLS connection
- Allow TCP inbound
- Allow KCP inbound(warning: untested)
## Installation
@ -60,6 +60,10 @@ upstream:
Built-in two upstreams: ban(terminate connection immediately), echo
## Thanks
- [tokio_kcp](https://github.com/Matrix-Zhang/tokio_kcp)
## License
Fourth is available under terms of Apache-2.0.

View File

@ -12,7 +12,7 @@ Fourth是一个Rust实现的Layer 4代理用于监听指定端口TCP/KCP流
- 监听指定端口代理到本地或远端指定端口
- 监听指定端口通过TLS ClientHello消息中的SNI进行分流
- 支持KCP入站
- 支持KCP入站(警告:未测试)
## 安装方法
@ -73,6 +73,10 @@ upstream:
可能以后会为Linux高内核版本的用户提供可选的io_uring加速。
## 感谢
- [tokio_kcp](https://github.com/Matrix-Zhang/tokio_kcp)
## 协议
Fourth以Apache-2.0协议开源。

View File

@ -1,4 +1,5 @@
mod config;
mod plugins;
mod servers;
use crate::config::Config;

110
src/plugins/kcp/config.rs Normal file
View File

@ -0,0 +1,110 @@
use std::{io::Write, time::Duration};
use kcp::Kcp;
/// Kcp Delay Config
#[derive(Debug, Clone, Copy)]
pub struct KcpNoDelayConfig {
/// Enable nodelay
pub nodelay: bool,
/// Internal update interval (ms)
pub interval: i32,
/// ACK number to enable fast resend
pub resend: i32,
/// Disable congetion control
pub nc: bool,
}
impl Default for KcpNoDelayConfig {
fn default() -> KcpNoDelayConfig {
KcpNoDelayConfig {
nodelay: false,
interval: 100,
resend: 0,
nc: false,
}
}
}
#[allow(unused)]
impl KcpNoDelayConfig {
/// Get a fastest configuration
///
/// 1. Enable NoDelay
/// 2. Set ticking interval to be 10ms
/// 3. Set fast resend to be 2
/// 4. Disable congestion control
pub fn fastest() -> KcpNoDelayConfig {
KcpNoDelayConfig {
nodelay: true,
interval: 10,
resend: 2,
nc: true,
}
}
/// Get a normal configuration
///
/// 1. Disable NoDelay
/// 2. Set ticking interval to be 40ms
/// 3. Disable fast resend
/// 4. Enable congestion control
pub fn normal() -> KcpNoDelayConfig {
KcpNoDelayConfig {
nodelay: false,
interval: 40,
resend: 0,
nc: false,
}
}
}
/// Kcp Config
#[derive(Debug, Clone, Copy)]
pub struct KcpConfig {
/// Max Transmission Unit
pub mtu: usize,
/// nodelay
pub nodelay: KcpNoDelayConfig,
/// Send window size
pub wnd_size: (u16, u16),
/// Session expire duration, default is 90 seconds
pub session_expire: Duration,
/// Flush KCP state immediately after write
pub flush_write: bool,
/// Flush ACKs immediately after input
pub flush_acks_input: bool,
/// Stream mode
pub stream: bool,
}
impl Default for KcpConfig {
fn default() -> KcpConfig {
KcpConfig {
mtu: 1400,
nodelay: KcpNoDelayConfig::normal(),
wnd_size: (256, 256),
session_expire: Duration::from_secs(90),
flush_write: false,
flush_acks_input: false,
stream: true,
}
}
}
impl KcpConfig {
/// Applies config onto `Kcp`
#[doc(hidden)]
pub fn apply_config<W: Write>(&self, k: &mut Kcp<W>) {
k.set_mtu(self.mtu).expect("invalid MTU");
k.set_nodelay(
self.nodelay.nodelay,
self.nodelay.interval,
self.nodelay.resend,
self.nodelay.nc,
);
k.set_wndsize(self.wnd_size.0, self.wnd_size.1);
}
}

128
src/plugins/kcp/listener.rs Normal file
View File

@ -0,0 +1,128 @@
use std::{
io::{self, ErrorKind},
net::SocketAddr,
sync::Arc,
time::Duration,
};
use byte_string::ByteStr;
use kcp::{Error as KcpError, KcpResult};
use log::{debug, error, trace};
use tokio::{
net::{ToSocketAddrs, UdpSocket},
sync::mpsc,
task::JoinHandle,
time,
};
use crate::plugins::kcp::{config::KcpConfig, session::KcpSessionManager, stream::KcpStream};
#[allow(unused)]
pub struct KcpListener {
udp: Arc<UdpSocket>,
accept_rx: mpsc::Receiver<(KcpStream, SocketAddr)>,
task_watcher: JoinHandle<()>,
}
impl Drop for KcpListener {
fn drop(&mut self) {
self.task_watcher.abort();
}
}
impl KcpListener {
pub async fn bind<A: ToSocketAddrs>(config: KcpConfig, addr: A) -> KcpResult<KcpListener> {
let udp = UdpSocket::bind(addr).await?;
let udp = Arc::new(udp);
let server_udp = udp.clone();
let (accept_tx, accept_rx) = mpsc::channel(1024 /* backlogs */);
let task_watcher = tokio::spawn(async move {
let (close_tx, mut close_rx) = mpsc::channel(64);
let mut sessions = KcpSessionManager::new();
let mut packet_buffer = [0u8; 65536];
loop {
tokio::select! {
conv = close_rx.recv() => {
let conv = conv.expect("close_tx closed unexpectly");
sessions.close_conv(conv);
trace!("session conv: {} removed", conv);
}
recv_res = udp.recv_from(&mut packet_buffer) => {
match recv_res {
Err(err) => {
error!("udp.recv_from failed, error: {}", err);
time::sleep(Duration::from_secs(1)).await;
}
Ok((n, peer_addr)) => {
let packet = &mut packet_buffer[..n];
log::trace!("received peer: {}, {:?}", peer_addr, ByteStr::new(packet));
let mut conv = kcp::get_conv(packet);
if conv == 0 {
// Allocate a conv for client.
conv = sessions.alloc_conv();
debug!("allocate {} conv for peer: {}", conv, peer_addr);
kcp::set_conv(packet, conv);
}
let session = match sessions.get_or_create(&config, conv, &udp, peer_addr, &close_tx) {
Ok((s, created)) => {
if created {
// Created a new session, constructed a new accepted client
let stream = KcpStream::with_session(s.clone());
if let Err(..) = accept_tx.try_send((stream, peer_addr)) {
debug!("failed to create accepted stream due to channel failure");
// remove it from session
sessions.close_conv(conv);
continue;
}
}
s
},
Err(err) => {
error!("failed to create session, error: {}, peer: {}, conv: {}", err, peer_addr, conv);
continue;
}
};
// let mut kcp = session.kcp_socket().lock().await;
// if let Err(err) = kcp.input(packet) {
// error!("kcp.input failed, peer: {}, conv: {}, error: {}, packet: {:?}", peer_addr, conv, err, ByteStr::new(packet));
// }
session.input(packet).await;
}
}
}
}
}
});
Ok(KcpListener {
udp: server_udp,
accept_rx,
task_watcher,
})
}
pub async fn accept(&mut self) -> KcpResult<(KcpStream, SocketAddr)> {
match self.accept_rx.recv().await {
Some(s) => Ok(s),
None => Err(KcpError::IoError(io::Error::new(
ErrorKind::Other,
"accept channel closed unexpectly",
))),
}
}
#[allow(unused)]
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.udp.local_addr()
}
}

14
src/plugins/kcp/mod.rs Normal file
View File

@ -0,0 +1,14 @@
//! Library of KCP on Tokio
pub use self::{
config::{KcpConfig, KcpNoDelayConfig},
listener::KcpListener,
stream::KcpStream,
};
mod config;
mod listener;
mod session;
mod skcp;
mod stream;
mod utils;

256
src/plugins/kcp/session.rs Normal file
View File

@ -0,0 +1,256 @@
use std::{
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use byte_string::ByteStr;
use kcp::KcpResult;
use log::{error, trace};
use tokio::{
net::UdpSocket,
sync::{mpsc, Mutex},
time::{self, Instant},
};
use crate::plugins::kcp::{skcp::KcpSocket, KcpConfig};
pub struct KcpSession {
socket: Mutex<KcpSocket>,
closed: AtomicBool,
session_expire: Duration,
session_close_notifier: Option<mpsc::Sender<u32>>,
input_tx: mpsc::Sender<Vec<u8>>,
}
impl KcpSession {
fn new(
socket: KcpSocket,
session_expire: Duration,
session_close_notifier: Option<mpsc::Sender<u32>>,
input_tx: mpsc::Sender<Vec<u8>>,
) -> KcpSession {
KcpSession {
socket: Mutex::new(socket),
closed: AtomicBool::new(false),
session_expire,
session_close_notifier,
input_tx,
}
}
pub fn new_shared(
socket: KcpSocket,
session_expire: Duration,
session_close_notifier: Option<mpsc::Sender<u32>>,
) -> Arc<KcpSession> {
let is_client = session_close_notifier.is_none();
let (input_tx, mut input_rx) = mpsc::channel(64);
let udp_socket = socket.udp_socket().clone();
let session = Arc::new(KcpSession::new(
socket,
session_expire,
session_close_notifier,
input_tx,
));
{
let session = session.clone();
tokio::spawn(async move {
let mut input_buffer = [0u8; 65536];
let update_timer = time::sleep(Duration::from_millis(10));
tokio::pin!(update_timer);
loop {
tokio::select! {
// recv() then input()
// Drives the KCP machine forward
recv_result = udp_socket.recv(&mut input_buffer), if is_client => {
match recv_result {
Err(err) => {
error!("[SESSION] UDP recv failed, error: {}", err);
}
Ok(n) => {
let input_buffer = &input_buffer[..n];
trace!("[SESSION] UDP recv {} bytes, going to input {:?}", n, ByteStr::new(input_buffer));
let mut socket = session.socket.lock().await;
match socket.input(input_buffer) {
Ok(true) => {
trace!("[SESSION] UDP input {} bytes and waked sender/receiver", n);
}
Ok(false) => {}
Err(err) => {
error!("[SESSION] UDP input {} bytes error: {}, input buffer {:?}", n, err, ByteStr::new(input_buffer));
}
}
}
}
}
// bytes received from listener socket
input_opt = input_rx.recv() => {
if let Some(input_buffer) = input_opt {
let mut socket = session.socket.lock().await;
match socket.input(&input_buffer) {
Ok(..) => {
trace!("[SESSION] UDP input {} bytes from channel {:?}", input_buffer.len(), ByteStr::new(&input_buffer));
}
Err(err) => {
error!("[SESSION] UDP input {} bytes from channel failed, error: {}, input buffer {:?}",
input_buffer.len(), err, ByteStr::new(&input_buffer));
}
}
}
}
// Call update() in period
_ = &mut update_timer => {
let mut socket = session.socket.lock().await;
let is_closed = session.closed.load(Ordering::Acquire);
if is_closed && socket.can_close() {
trace!("[SESSION] KCP session closed");
break;
}
// server socket expires
if !is_client {
// If this is a server stream, close it automatically after a period of time
let last_update_time = socket.last_update_time();
let elapsed = last_update_time.elapsed();
if elapsed > session.session_expire {
if elapsed > session.session_expire * 2 {
// Force close. Client may have already gone.
trace!(
"[SESSION] force close inactive session, conv: {}, last_update: {}s ago",
socket.conv(),
elapsed.as_secs()
);
break;
}
if !is_closed {
trace!(
"[SESSION] closing inactive session, conv: {}, last_update: {}s ago",
socket.conv(),
elapsed.as_secs()
);
session.closed.store(true, Ordering::Release);
}
}
}
match socket.update() {
Ok(next_next) => {
update_timer.as_mut().reset(Instant::from_std(next_next));
}
Err(err) => {
error!("[SESSION] KCP update failed, error: {}", err);
update_timer.as_mut().reset(Instant::now() + Duration::from_millis(10));
}
}
}
}
}
{
// Close the socket.
// Wake all pending tasks and let all send/recv return EOF
let mut socket = session.socket.lock().await;
socket.close();
}
if let Some(ref notifier) = session.session_close_notifier {
let socket = session.socket.lock().await;
let _ = notifier.send(socket.conv()).await;
}
});
}
session
}
pub fn kcp_socket(&self) -> &Mutex<KcpSocket> {
&self.socket
}
pub fn close(&self) {
self.closed.store(true, Ordering::Release);
}
pub async fn input(&self, buf: &[u8]) {
self.input_tx
.send(buf.to_owned())
.await
.expect("input channel closed")
}
}
pub struct KcpSessionManager {
sessions: HashMap<u32, Arc<KcpSession>>,
next_free_conv: u32,
}
impl KcpSessionManager {
pub fn new() -> KcpSessionManager {
KcpSessionManager {
sessions: HashMap::new(),
next_free_conv: 0,
}
}
pub fn close_conv(&mut self, conv: u32) {
self.sessions.remove(&conv);
}
pub fn alloc_conv(&mut self) -> u32 {
loop {
let (mut c, _) = self.next_free_conv.overflowing_add(1);
if c == 0 {
let (nc, _) = c.overflowing_add(1);
c = nc;
}
self.next_free_conv = c;
if self.sessions.get(&self.next_free_conv).is_none() {
let conv = self.next_free_conv;
return conv;
}
}
}
pub fn get_or_create(
&mut self,
config: &KcpConfig,
conv: u32,
udp: &Arc<UdpSocket>,
peer_addr: SocketAddr,
session_close_notifier: &mpsc::Sender<u32>,
) -> KcpResult<(Arc<KcpSession>, bool)> {
match self.sessions.entry(conv) {
Entry::Occupied(occ) => Ok((occ.get().clone(), false)),
Entry::Vacant(vac) => {
let socket = KcpSocket::new(config, conv, udp.clone(), peer_addr, config.stream)?;
let session = KcpSession::new_shared(
socket,
config.session_expire,
Some(session_close_notifier.clone()),
);
trace!("created session for conv: {}, peer: {}", conv, peer_addr);
vac.insert(session.clone());
Ok((session, true))
}
}
}
}

288
src/plugins/kcp/skcp.rs Normal file
View File

@ -0,0 +1,288 @@
use std::{
io::{self, ErrorKind, Write},
net::SocketAddr,
sync::Arc,
task::{Context, Poll, Waker},
time::{Duration, Instant},
};
use futures::future;
use kcp::{Error as KcpError, Kcp, KcpResult};
use log::{error, trace};
use tokio::{net::UdpSocket, sync::mpsc};
use crate::plugins::kcp::{utils::now_millis, KcpConfig};
/// Writer for sending packets to the underlying UdpSocket
struct UdpOutput {
socket: Arc<UdpSocket>,
target_addr: SocketAddr,
delay_tx: mpsc::UnboundedSender<Vec<u8>>,
}
impl UdpOutput {
/// Create a new Writer for writing packets to UdpSocket
pub fn new(socket: Arc<UdpSocket>, target_addr: SocketAddr) -> UdpOutput {
let (delay_tx, mut delay_rx) = mpsc::unbounded_channel::<Vec<u8>>();
{
let socket = socket.clone();
tokio::spawn(async move {
while let Some(buf) = delay_rx.recv().await {
if let Err(err) = socket.send_to(&buf, target_addr).await {
error!("[SEND] UDP delayed send failed, error: {}", err);
}
}
});
}
UdpOutput {
socket,
target_addr,
delay_tx,
}
}
}
impl Write for UdpOutput {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.socket.try_send_to(buf, self.target_addr) {
Ok(n) => Ok(n),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
// send return EAGAIN
// ignored as packet was lost in transmission
trace!(
"[SEND] UDP send EAGAIN, packet.size: {} bytes, delayed send",
buf.len()
);
self.delay_tx
.send(buf.to_owned())
.expect("channel closed unexpectly");
Ok(buf.len())
}
Err(err) => Err(err),
}
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
pub struct KcpSocket {
kcp: Kcp<UdpOutput>,
last_update: Instant,
socket: Arc<UdpSocket>,
flush_write: bool,
flush_ack_input: bool,
sent_first: bool,
pending_sender: Option<Waker>,
pending_receiver: Option<Waker>,
closed: bool,
}
impl KcpSocket {
pub fn new(
c: &KcpConfig,
conv: u32,
socket: Arc<UdpSocket>,
target_addr: SocketAddr,
stream: bool,
) -> KcpResult<KcpSocket> {
let output = UdpOutput::new(socket.clone(), target_addr);
let mut kcp = if stream {
Kcp::new_stream(conv, output)
} else {
Kcp::new(conv, output)
};
c.apply_config(&mut kcp);
// Ask server to allocate one
if conv == 0 {
kcp.input_conv();
}
kcp.update(now_millis())?;
Ok(KcpSocket {
kcp,
last_update: Instant::now(),
socket,
flush_write: c.flush_write,
flush_ack_input: c.flush_acks_input,
sent_first: false,
pending_sender: None,
pending_receiver: None,
closed: false,
})
}
/// Call every time you got data from transmission
pub fn input(&mut self, buf: &[u8]) -> KcpResult<bool> {
match self.kcp.input(buf) {
Ok(..) => {}
Err(KcpError::ConvInconsistent(expected, actual)) => {
trace!(
"[INPUT] Conv expected={} actual={} ignored",
expected,
actual
);
return Ok(false);
}
Err(err) => return Err(err),
}
self.last_update = Instant::now();
if self.flush_ack_input {
self.kcp.flush_ack()?;
}
Ok(self.try_wake_pending_waker())
}
/// Call if you want to send some data
pub fn poll_send(&mut self, cx: &mut Context<'_>, mut buf: &[u8]) -> Poll<KcpResult<usize>> {
if self.closed {
return Ok(0).into();
}
// If:
// 1. Have sent the first packet (asking for conv)
// 2. Too many pending packets
if self.sent_first
&& (self.kcp.wait_snd() >= self.kcp.snd_wnd() as usize || self.kcp.waiting_conv())
{
trace!(
"[SEND] waitsnd={} sndwnd={} excceeded or waiting conv={}",
self.kcp.wait_snd(),
self.kcp.snd_wnd(),
self.kcp.waiting_conv()
);
self.pending_sender = Some(cx.waker().clone());
return Poll::Pending;
}
if !self.sent_first && self.kcp.waiting_conv() && buf.len() > self.kcp.mss() as usize {
buf = &buf[..self.kcp.mss() as usize];
}
let n = self.kcp.send(buf)?;
self.sent_first = true;
self.last_update = Instant::now();
if self.flush_write {
self.kcp.flush()?;
}
Ok(n).into()
}
/// Call if you want to send some data
#[allow(dead_code)]
pub async fn send(&mut self, buf: &[u8]) -> KcpResult<usize> {
future::poll_fn(|cx| self.poll_send(cx, buf)).await
}
#[allow(dead_code)]
pub fn try_recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
if self.closed {
return Ok(0);
}
self.kcp.recv(buf)
}
pub fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<KcpResult<usize>> {
if self.closed {
return Ok(0).into();
}
match self.kcp.recv(buf) {
Ok(n) => Ok(n).into(),
Err(KcpError::RecvQueueEmpty) => {
self.pending_receiver = Some(cx.waker().clone());
Poll::Pending
}
Err(err) => Err(err).into(),
}
}
#[allow(dead_code)]
pub async fn recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
future::poll_fn(|cx| self.poll_recv(cx, buf)).await
}
pub fn flush(&mut self) -> KcpResult<()> {
self.kcp.flush()?;
self.last_update = Instant::now();
Ok(())
}
fn try_wake_pending_waker(&mut self) -> bool {
let mut waked = false;
if self.pending_sender.is_some()
&& self.kcp.wait_snd() < self.kcp.snd_wnd() as usize
&& !self.kcp.waiting_conv()
{
let waker = self.pending_sender.take().unwrap();
waker.wake();
waked = true;
}
if self.pending_receiver.is_some() {
if let Ok(peek) = self.kcp.peeksize() {
if peek > 0 {
let waker = self.pending_receiver.take().unwrap();
waker.wake();
waked = true;
}
}
}
waked
}
pub fn update(&mut self) -> KcpResult<Instant> {
let now = now_millis();
self.kcp.update(now)?;
let next = self.kcp.check(now);
self.try_wake_pending_waker();
Ok(Instant::now() + Duration::from_millis(next as u64))
}
pub fn close(&mut self) {
self.closed = true;
if let Some(w) = self.pending_sender.take() {
w.wake();
}
if let Some(w) = self.pending_receiver.take() {
w.wake();
}
}
pub fn udp_socket(&self) -> &Arc<UdpSocket> {
&self.socket
}
pub fn can_close(&self) -> bool {
self.kcp.wait_snd() == 0
}
pub fn conv(&self) -> u32 {
self.kcp.conv()
}
pub fn peek_size(&self) -> KcpResult<usize> {
self.kcp.peeksize()
}
pub fn last_update_time(&self) -> Instant {
self.last_update
}
}

183
src/plugins/kcp/stream.rs Normal file
View File

@ -0,0 +1,183 @@
use std::{
io::{self, ErrorKind},
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::{future, ready};
use kcp::{Error as KcpError, KcpResult};
use log::trace;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::UdpSocket,
};
use crate::plugins::kcp::{config::KcpConfig, session::KcpSession, skcp::KcpSocket};
pub struct KcpStream {
session: Arc<KcpSession>,
recv_buffer: Vec<u8>,
recv_buffer_pos: usize,
recv_buffer_cap: usize,
}
impl Drop for KcpStream {
fn drop(&mut self) {
self.session.close();
}
}
#[allow(unused)]
impl KcpStream {
pub async fn connect(config: &KcpConfig, addr: SocketAddr) -> KcpResult<KcpStream> {
let udp = match addr.ip() {
IpAddr::V4(..) => UdpSocket::bind("0.0.0.0:0").await?,
IpAddr::V6(..) => UdpSocket::bind("[::]:0").await?,
};
let udp = Arc::new(udp);
let socket = KcpSocket::new(config, 0, udp, addr, config.stream)?;
let session = KcpSession::new_shared(socket, config.session_expire, None);
Ok(KcpStream::with_session(session))
}
pub(crate) fn with_session(session: Arc<KcpSession>) -> KcpStream {
KcpStream {
session,
recv_buffer: Vec::new(),
recv_buffer_pos: 0,
recv_buffer_cap: 0,
}
}
pub fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<KcpResult<usize>> {
// Mutex doesn't have poll_lock, spinning on it.
let socket = self.session.kcp_socket();
let mut kcp = match socket.try_lock() {
Ok(guard) => guard,
Err(..) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
kcp.poll_send(cx, buf)
}
pub async fn send(&mut self, buf: &[u8]) -> KcpResult<usize> {
future::poll_fn(|cx| self.poll_send(cx, buf)).await
}
pub fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<KcpResult<usize>> {
loop {
// Consumes all data in buffer
if self.recv_buffer_pos < self.recv_buffer_cap {
let remaining = self.recv_buffer_cap - self.recv_buffer_pos;
let copy_length = remaining.min(buf.len());
buf.copy_from_slice(
&self.recv_buffer[self.recv_buffer_pos..self.recv_buffer_pos + copy_length],
);
self.recv_buffer_pos += copy_length;
return Ok(copy_length).into();
}
// Mutex doesn't have poll_lock, spinning on it.
let socket = self.session.kcp_socket();
let mut kcp = match socket.try_lock() {
Ok(guard) => guard,
Err(..) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
// Try to read from KCP
// 1. Read directly with user provided `buf`
match ready!(kcp.poll_recv(cx, buf)) {
Ok(n) => {
trace!("[CLIENT] recv directly {} bytes", n);
return Ok(n).into();
}
Err(KcpError::UserBufTooSmall) => {}
Err(err) => return Err(err).into(),
}
// 2. User `buf` too small, read to recv_buffer
let required_size = kcp.peek_size()?;
if self.recv_buffer.len() < required_size {
self.recv_buffer.resize(required_size, 0);
}
match ready!(kcp.poll_recv(cx, &mut self.recv_buffer)) {
Ok(n) => {
trace!("[CLIENT] recv buffered {} bytes", n);
self.recv_buffer_pos = 0;
self.recv_buffer_cap = n;
}
Err(err) => return Err(err).into(),
}
}
}
pub async fn recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
future::poll_fn(|cx| self.poll_recv(cx, buf)).await
}
}
impl AsyncRead for KcpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match ready!(self.poll_recv(cx, buf.initialize_unfilled())) {
Ok(n) => {
buf.advance(n);
Ok(()).into()
}
Err(KcpError::IoError(err)) => Err(err).into(),
Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
}
}
}
impl AsyncWrite for KcpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match ready!(self.poll_send(cx, buf)) {
Ok(n) => Ok(n).into(),
Err(KcpError::IoError(err)) => Err(err).into(),
Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
// Mutex doesn't have poll_lock, spinning on it.
let socket = self.session.kcp_socket();
let mut kcp = match socket.try_lock() {
Ok(guard) => guard,
Err(..) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
match kcp.flush() {
Ok(..) => Ok(()).into(),
Err(KcpError::IoError(err)) => Err(err).into(),
Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
}
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
}

10
src/plugins/kcp/utils.rs Normal file
View File

@ -0,0 +1,10 @@
use std::time::{SystemTime, UNIX_EPOCH};
#[inline]
pub fn now_millis() -> u32 {
let start = SystemTime::now();
let since_the_epoch = start
.duration_since(UNIX_EPOCH)
.expect("time went afterwards");
(since_the_epoch.as_secs() * 1000 + since_the_epoch.subsec_millis() as u64 / 1_000_000) as u32
}

1
src/plugins/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod kcp;

View File

@ -104,12 +104,12 @@ impl Server {
#[cfg(test)]
mod test {
use crate::plugins::kcp::{KcpConfig, KcpStream};
use std::net::SocketAddr;
use std::thread::{self, sleep};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_kcp::{KcpConfig, KcpStream};
use super::*;

View File

@ -1,3 +1,4 @@
use crate::plugins::kcp::{KcpConfig, KcpListener, KcpStream};
use crate::servers::Proxy;
use futures::future::try_join;
use log::{debug, error, warn};
@ -6,7 +7,6 @@ use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_kcp::{KcpConfig, KcpListener, KcpStream};
pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
let kcp_config = KcpConfig::default();