diff --git a/Cargo.lock b/Cargo.lock index b86446e..ffdcffb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,14 +100,16 @@ dependencies = [ name = "fourth" version = "0.1.2" dependencies = [ + "byte_string", + "bytes 1.1.0", "futures", + "kcp", "log 0.4.14", "pretty_env_logger", "serde", "serde_yaml", "tls-parser", "tokio", - "tokio_kcp", ] [[package]] @@ -764,19 +766,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tokio_kcp" -version = "0.8.0" -source = "git+https://github.com/Matrix-Zhang/tokio_kcp?rev=d93a2f2#d93a2f2ad2cba731dce8490f0960246d2a655033" -dependencies = [ - "byte_string", - "bytes 1.1.0", - "futures", - "kcp", - "log 0.4.14", - "tokio", -] - [[package]] name = "unicode-xid" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index 9659251..e6a193b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,4 +24,7 @@ futures = "0.3" tls-parser = "0.11" tokio = { version = "1.0", features = ["full"] } -tokio_kcp = { git = "https://github.com/Matrix-Zhang/tokio_kcp", rev="d93a2f2" } \ No newline at end of file + +bytes = "1.1" +kcp = "0.4" +byte_string = "1" \ No newline at end of file diff --git a/README-EN.md b/README-EN.md index 184ff14..c048bc0 100644 --- a/README-EN.md +++ b/README-EN.md @@ -10,7 +10,7 @@ Fourth is a layer 4 proxy implemented by Rust to listen on specific ports and tr - Listen on specific port and proxy to local or remote port - SNI-based rule without terminating TLS connection -- Allow TCP inbound +- Allow KCP inbound(warning: untested) ## Installation @@ -60,6 +60,10 @@ upstream: Built-in two upstreams: ban(terminate connection immediately), echo +## Thanks + +- [tokio_kcp](https://github.com/Matrix-Zhang/tokio_kcp) + ## License Fourth is available under terms of Apache-2.0. 
\ No newline at end of file diff --git a/README.md b/README.md index 324b973..5c07f87 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Fourth是一个Rust实现的Layer 4代理,用于监听指定端口TCP/KCP流 - 监听指定端口代理到本地或远端指定端口 - 监听指定端口,通过TLS ClientHello消息中的SNI进行分流 -- 支持KCP入站 +- 支持KCP入站(警告:未测试) ## 安装方法 @@ -73,6 +73,10 @@ upstream: 可能以后会为Linux高内核版本的用户提供可选的io_uring加速。 +## 感谢 + +- [tokio_kcp](https://github.com/Matrix-Zhang/tokio_kcp) + ## 协议 Fourth以Apache-2.0协议开源。 diff --git a/src/main.rs b/src/main.rs index 8b0b42d..4a9532f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ mod config; +mod plugins; mod servers; use crate::config::Config; diff --git a/src/plugins/kcp/config.rs b/src/plugins/kcp/config.rs new file mode 100644 index 0000000..7ba881d --- /dev/null +++ b/src/plugins/kcp/config.rs @@ -0,0 +1,110 @@ +use std::{io::Write, time::Duration}; + +use kcp::Kcp; + +/// Kcp Delay Config +#[derive(Debug, Clone, Copy)] +pub struct KcpNoDelayConfig { + /// Enable nodelay + pub nodelay: bool, + /// Internal update interval (ms) + pub interval: i32, + /// ACK number to enable fast resend + pub resend: i32, + /// Disable congetion control + pub nc: bool, +} + +impl Default for KcpNoDelayConfig { + fn default() -> KcpNoDelayConfig { + KcpNoDelayConfig { + nodelay: false, + interval: 100, + resend: 0, + nc: false, + } + } +} + +#[allow(unused)] +impl KcpNoDelayConfig { + /// Get a fastest configuration + /// + /// 1. Enable NoDelay + /// 2. Set ticking interval to be 10ms + /// 3. Set fast resend to be 2 + /// 4. Disable congestion control + pub fn fastest() -> KcpNoDelayConfig { + KcpNoDelayConfig { + nodelay: true, + interval: 10, + resend: 2, + nc: true, + } + } + + /// Get a normal configuration + /// + /// 1. Disable NoDelay + /// 2. Set ticking interval to be 40ms + /// 3. Disable fast resend + /// 4. 
Enable congestion control + pub fn normal() -> KcpNoDelayConfig { + KcpNoDelayConfig { + nodelay: false, + interval: 40, + resend: 0, + nc: false, + } + } +} + +/// Kcp Config +#[derive(Debug, Clone, Copy)] +pub struct KcpConfig { + /// Max Transmission Unit + pub mtu: usize, + /// nodelay + pub nodelay: KcpNoDelayConfig, + /// Send window size + pub wnd_size: (u16, u16), + /// Session expire duration, default is 90 seconds + pub session_expire: Duration, + /// Flush KCP state immediately after write + pub flush_write: bool, + /// Flush ACKs immediately after input + pub flush_acks_input: bool, + /// Stream mode + pub stream: bool, +} + +impl Default for KcpConfig { + fn default() -> KcpConfig { + KcpConfig { + mtu: 1400, + nodelay: KcpNoDelayConfig::normal(), + wnd_size: (256, 256), + session_expire: Duration::from_secs(90), + flush_write: false, + flush_acks_input: false, + stream: true, + } + } +} + +impl KcpConfig { + /// Applies config onto `Kcp` + #[doc(hidden)] + pub fn apply_config(&self, k: &mut Kcp) { + k.set_mtu(self.mtu).expect("invalid MTU"); + + k.set_nodelay( + self.nodelay.nodelay, + self.nodelay.interval, + self.nodelay.resend, + self.nodelay.nc, + ); + + k.set_wndsize(self.wnd_size.0, self.wnd_size.1); + } +} diff --git a/src/plugins/kcp/listener.rs b/src/plugins/kcp/listener.rs new file mode 100644 index 0000000..89b0ffc --- /dev/null +++ b/src/plugins/kcp/listener.rs @@ -0,0 +1,128 @@ +use std::{ + io::{self, ErrorKind}, + net::SocketAddr, + sync::Arc, + time::Duration, +}; + +use byte_string::ByteStr; +use kcp::{Error as KcpError, KcpResult}; +use log::{debug, error, trace}; +use tokio::{ + net::{ToSocketAddrs, UdpSocket}, + sync::mpsc, + task::JoinHandle, + time, +}; + +use crate::plugins::kcp::{config::KcpConfig, session::KcpSessionManager, stream::KcpStream}; + +#[allow(unused)] +pub struct KcpListener { + udp: Arc, + accept_rx: mpsc::Receiver<(KcpStream, SocketAddr)>, + task_watcher: JoinHandle<()>, +} + +impl Drop for KcpListener { + fn 
drop(&mut self) { + self.task_watcher.abort(); + } +} + +impl KcpListener { + pub async fn bind(config: KcpConfig, addr: A) -> KcpResult { + let udp = UdpSocket::bind(addr).await?; + let udp = Arc::new(udp); + let server_udp = udp.clone(); + + let (accept_tx, accept_rx) = mpsc::channel(1024 /* backlogs */); + let task_watcher = tokio::spawn(async move { + let (close_tx, mut close_rx) = mpsc::channel(64); + + let mut sessions = KcpSessionManager::new(); + let mut packet_buffer = [0u8; 65536]; + loop { + tokio::select! { + conv = close_rx.recv() => { + let conv = conv.expect("close_tx closed unexpectly"); + sessions.close_conv(conv); + trace!("session conv: {} removed", conv); + } + + recv_res = udp.recv_from(&mut packet_buffer) => { + match recv_res { + Err(err) => { + error!("udp.recv_from failed, error: {}", err); + time::sleep(Duration::from_secs(1)).await; + } + Ok((n, peer_addr)) => { + let packet = &mut packet_buffer[..n]; + + log::trace!("received peer: {}, {:?}", peer_addr, ByteStr::new(packet)); + + let mut conv = kcp::get_conv(packet); + if conv == 0 { + // Allocate a conv for client. + conv = sessions.alloc_conv(); + debug!("allocate {} conv for peer: {}", conv, peer_addr); + + kcp::set_conv(packet, conv); + } + + let session = match sessions.get_or_create(&config, conv, &udp, peer_addr, &close_tx) { + Ok((s, created)) => { + if created { + // Created a new session, constructed a new accepted client + let stream = KcpStream::with_session(s.clone()); + if let Err(..) 
= accept_tx.try_send((stream, peer_addr)) { + debug!("failed to create accepted stream due to channel failure"); + + // remove it from session + sessions.close_conv(conv); + continue; + } + } + + s + }, + Err(err) => { + error!("failed to create session, error: {}, peer: {}, conv: {}", err, peer_addr, conv); + continue; + } + }; + + // let mut kcp = session.kcp_socket().lock().await; + // if let Err(err) = kcp.input(packet) { + // error!("kcp.input failed, peer: {}, conv: {}, error: {}, packet: {:?}", peer_addr, conv, err, ByteStr::new(packet)); + // } + session.input(packet).await; + } + } + } + } + } + }); + + Ok(KcpListener { + udp: server_udp, + accept_rx, + task_watcher, + }) + } + + pub async fn accept(&mut self) -> KcpResult<(KcpStream, SocketAddr)> { + match self.accept_rx.recv().await { + Some(s) => Ok(s), + None => Err(KcpError::IoError(io::Error::new( + ErrorKind::Other, + "accept channel closed unexpectly", + ))), + } + } + + #[allow(unused)] + pub fn local_addr(&self) -> io::Result { + self.udp.local_addr() + } +} diff --git a/src/plugins/kcp/mod.rs b/src/plugins/kcp/mod.rs new file mode 100644 index 0000000..51f88c7 --- /dev/null +++ b/src/plugins/kcp/mod.rs @@ -0,0 +1,14 @@ +//! 
Library of KCP on Tokio + +pub use self::{ + config::{KcpConfig, KcpNoDelayConfig}, + listener::KcpListener, + stream::KcpStream, +}; + +mod config; +mod listener; +mod session; +mod skcp; +mod stream; +mod utils; diff --git a/src/plugins/kcp/session.rs b/src/plugins/kcp/session.rs new file mode 100644 index 0000000..25a0bef --- /dev/null +++ b/src/plugins/kcp/session.rs @@ -0,0 +1,256 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + net::SocketAddr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; + +use byte_string::ByteStr; +use kcp::KcpResult; +use log::{error, trace}; +use tokio::{ + net::UdpSocket, + sync::{mpsc, Mutex}, + time::{self, Instant}, +}; + +use crate::plugins::kcp::{skcp::KcpSocket, KcpConfig}; + +pub struct KcpSession { + socket: Mutex, + closed: AtomicBool, + session_expire: Duration, + session_close_notifier: Option>, + input_tx: mpsc::Sender>, +} + +impl KcpSession { + fn new( + socket: KcpSocket, + session_expire: Duration, + session_close_notifier: Option>, + input_tx: mpsc::Sender>, + ) -> KcpSession { + KcpSession { + socket: Mutex::new(socket), + closed: AtomicBool::new(false), + session_expire, + session_close_notifier, + input_tx, + } + } + + pub fn new_shared( + socket: KcpSocket, + session_expire: Duration, + session_close_notifier: Option>, + ) -> Arc { + let is_client = session_close_notifier.is_none(); + + let (input_tx, mut input_rx) = mpsc::channel(64); + + let udp_socket = socket.udp_socket().clone(); + + let session = Arc::new(KcpSession::new( + socket, + session_expire, + session_close_notifier, + input_tx, + )); + + { + let session = session.clone(); + tokio::spawn(async move { + let mut input_buffer = [0u8; 65536]; + let update_timer = time::sleep(Duration::from_millis(10)); + tokio::pin!(update_timer); + + loop { + tokio::select! 
{ + // recv() then input() + // Drives the KCP machine forward + recv_result = udp_socket.recv(&mut input_buffer), if is_client => { + match recv_result { + Err(err) => { + error!("[SESSION] UDP recv failed, error: {}", err); + } + Ok(n) => { + let input_buffer = &input_buffer[..n]; + trace!("[SESSION] UDP recv {} bytes, going to input {:?}", n, ByteStr::new(input_buffer)); + + let mut socket = session.socket.lock().await; + + match socket.input(input_buffer) { + Ok(true) => { + trace!("[SESSION] UDP input {} bytes and waked sender/receiver", n); + } + Ok(false) => {} + Err(err) => { + error!("[SESSION] UDP input {} bytes error: {}, input buffer {:?}", n, err, ByteStr::new(input_buffer)); + } + } + } + } + } + + // bytes received from listener socket + input_opt = input_rx.recv() => { + if let Some(input_buffer) = input_opt { + let mut socket = session.socket.lock().await; + match socket.input(&input_buffer) { + Ok(..) => { + trace!("[SESSION] UDP input {} bytes from channel {:?}", input_buffer.len(), ByteStr::new(&input_buffer)); + } + Err(err) => { + error!("[SESSION] UDP input {} bytes from channel failed, error: {}, input buffer {:?}", + input_buffer.len(), err, ByteStr::new(&input_buffer)); + } + } + } + } + + // Call update() in period + _ = &mut update_timer => { + let mut socket = session.socket.lock().await; + + let is_closed = session.closed.load(Ordering::Acquire); + if is_closed && socket.can_close() { + trace!("[SESSION] KCP session closed"); + break; + } + + // server socket expires + if !is_client { + // If this is a server stream, close it automatically after a period of time + let last_update_time = socket.last_update_time(); + let elapsed = last_update_time.elapsed(); + + if elapsed > session.session_expire { + if elapsed > session.session_expire * 2 { + // Force close. Client may have already gone. 
+ trace!( + "[SESSION] force close inactive session, conv: {}, last_update: {}s ago", + socket.conv(), + elapsed.as_secs() + ); + break; + } + + if !is_closed { + trace!( + "[SESSION] closing inactive session, conv: {}, last_update: {}s ago", + socket.conv(), + elapsed.as_secs() + ); + session.closed.store(true, Ordering::Release); + } + } + } + + match socket.update() { + Ok(next_next) => { + update_timer.as_mut().reset(Instant::from_std(next_next)); + } + Err(err) => { + error!("[SESSION] KCP update failed, error: {}", err); + update_timer.as_mut().reset(Instant::now() + Duration::from_millis(10)); + } + } + } + } + } + + { + // Close the socket. + // Wake all pending tasks and let all send/recv return EOF + + let mut socket = session.socket.lock().await; + socket.close(); + } + + if let Some(ref notifier) = session.session_close_notifier { + let socket = session.socket.lock().await; + let _ = notifier.send(socket.conv()).await; + } + }); + } + + session + } + + pub fn kcp_socket(&self) -> &Mutex { + &self.socket + } + + pub fn close(&self) { + self.closed.store(true, Ordering::Release); + } + + pub async fn input(&self, buf: &[u8]) { + self.input_tx + .send(buf.to_owned()) + .await + .expect("input channel closed") + } +} + +pub struct KcpSessionManager { + sessions: HashMap>, + next_free_conv: u32, +} + +impl KcpSessionManager { + pub fn new() -> KcpSessionManager { + KcpSessionManager { + sessions: HashMap::new(), + next_free_conv: 0, + } + } + + pub fn close_conv(&mut self, conv: u32) { + self.sessions.remove(&conv); + } + + pub fn alloc_conv(&mut self) -> u32 { + loop { + let (mut c, _) = self.next_free_conv.overflowing_add(1); + if c == 0 { + let (nc, _) = c.overflowing_add(1); + c = nc; + } + self.next_free_conv = c; + + if self.sessions.get(&self.next_free_conv).is_none() { + let conv = self.next_free_conv; + return conv; + } + } + } + + pub fn get_or_create( + &mut self, + config: &KcpConfig, + conv: u32, + udp: &Arc, + peer_addr: SocketAddr, + 
session_close_notifier: &mpsc::Sender, + ) -> KcpResult<(Arc, bool)> { + match self.sessions.entry(conv) { + Entry::Occupied(occ) => Ok((occ.get().clone(), false)), + Entry::Vacant(vac) => { + let socket = KcpSocket::new(config, conv, udp.clone(), peer_addr, config.stream)?; + let session = KcpSession::new_shared( + socket, + config.session_expire, + Some(session_close_notifier.clone()), + ); + trace!("created session for conv: {}, peer: {}", conv, peer_addr); + vac.insert(session.clone()); + Ok((session, true)) + } + } + } +} diff --git a/src/plugins/kcp/skcp.rs b/src/plugins/kcp/skcp.rs new file mode 100644 index 0000000..fa064c1 --- /dev/null +++ b/src/plugins/kcp/skcp.rs @@ -0,0 +1,288 @@ +use std::{ + io::{self, ErrorKind, Write}, + net::SocketAddr, + sync::Arc, + task::{Context, Poll, Waker}, + time::{Duration, Instant}, +}; + +use futures::future; +use kcp::{Error as KcpError, Kcp, KcpResult}; +use log::{error, trace}; +use tokio::{net::UdpSocket, sync::mpsc}; + +use crate::plugins::kcp::{utils::now_millis, KcpConfig}; + +/// Writer for sending packets to the underlying UdpSocket +struct UdpOutput { + socket: Arc, + target_addr: SocketAddr, + delay_tx: mpsc::UnboundedSender>, +} + +impl UdpOutput { + /// Create a new Writer for writing packets to UdpSocket + pub fn new(socket: Arc, target_addr: SocketAddr) -> UdpOutput { + let (delay_tx, mut delay_rx) = mpsc::unbounded_channel::>(); + + { + let socket = socket.clone(); + tokio::spawn(async move { + while let Some(buf) = delay_rx.recv().await { + if let Err(err) = socket.send_to(&buf, target_addr).await { + error!("[SEND] UDP delayed send failed, error: {}", err); + } + } + }); + } + + UdpOutput { + socket, + target_addr, + delay_tx, + } + } +} + +impl Write for UdpOutput { + fn write(&mut self, buf: &[u8]) -> io::Result { + match self.socket.try_send_to(buf, self.target_addr) { + Ok(n) => Ok(n), + Err(ref err) if err.kind() == ErrorKind::WouldBlock => { + // send return EAGAIN + // ignored as packet was lost 
in transmission + trace!( + "[SEND] UDP send EAGAIN, packet.size: {} bytes, delayed send", + buf.len() + ); + + self.delay_tx + .send(buf.to_owned()) + .expect("channel closed unexpectly"); + + Ok(buf.len()) + } + Err(err) => Err(err), + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +pub struct KcpSocket { + kcp: Kcp, + last_update: Instant, + socket: Arc, + flush_write: bool, + flush_ack_input: bool, + sent_first: bool, + pending_sender: Option, + pending_receiver: Option, + closed: bool, +} + +impl KcpSocket { + pub fn new( + c: &KcpConfig, + conv: u32, + socket: Arc, + target_addr: SocketAddr, + stream: bool, + ) -> KcpResult { + let output = UdpOutput::new(socket.clone(), target_addr); + let mut kcp = if stream { + Kcp::new_stream(conv, output) + } else { + Kcp::new(conv, output) + }; + c.apply_config(&mut kcp); + + // Ask server to allocate one + if conv == 0 { + kcp.input_conv(); + } + + kcp.update(now_millis())?; + + Ok(KcpSocket { + kcp, + last_update: Instant::now(), + socket, + flush_write: c.flush_write, + flush_ack_input: c.flush_acks_input, + sent_first: false, + pending_sender: None, + pending_receiver: None, + closed: false, + }) + } + + /// Call every time you got data from transmission + pub fn input(&mut self, buf: &[u8]) -> KcpResult { + match self.kcp.input(buf) { + Ok(..) => {} + Err(KcpError::ConvInconsistent(expected, actual)) => { + trace!( + "[INPUT] Conv expected={} actual={} ignored", + expected, + actual + ); + return Ok(false); + } + Err(err) => return Err(err), + } + self.last_update = Instant::now(); + + if self.flush_ack_input { + self.kcp.flush_ack()?; + } + + Ok(self.try_wake_pending_waker()) + } + + /// Call if you want to send some data + pub fn poll_send(&mut self, cx: &mut Context<'_>, mut buf: &[u8]) -> Poll> { + if self.closed { + return Ok(0).into(); + } + + // If: + // 1. Have sent the first packet (asking for conv) + // 2. 
Too many pending packets + if self.sent_first + && (self.kcp.wait_snd() >= self.kcp.snd_wnd() as usize || self.kcp.waiting_conv()) + { + trace!( + "[SEND] waitsnd={} sndwnd={} excceeded or waiting conv={}", + self.kcp.wait_snd(), + self.kcp.snd_wnd(), + self.kcp.waiting_conv() + ); + self.pending_sender = Some(cx.waker().clone()); + return Poll::Pending; + } + + if !self.sent_first && self.kcp.waiting_conv() && buf.len() > self.kcp.mss() as usize { + buf = &buf[..self.kcp.mss() as usize]; + } + + let n = self.kcp.send(buf)?; + self.sent_first = true; + self.last_update = Instant::now(); + + if self.flush_write { + self.kcp.flush()?; + } + + Ok(n).into() + } + + /// Call if you want to send some data + #[allow(dead_code)] + pub async fn send(&mut self, buf: &[u8]) -> KcpResult { + future::poll_fn(|cx| self.poll_send(cx, buf)).await + } + + #[allow(dead_code)] + pub fn try_recv(&mut self, buf: &mut [u8]) -> KcpResult { + if self.closed { + return Ok(0); + } + self.kcp.recv(buf) + } + + pub fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + if self.closed { + return Ok(0).into(); + } + + match self.kcp.recv(buf) { + Ok(n) => Ok(n).into(), + Err(KcpError::RecvQueueEmpty) => { + self.pending_receiver = Some(cx.waker().clone()); + Poll::Pending + } + Err(err) => Err(err).into(), + } + } + + #[allow(dead_code)] + pub async fn recv(&mut self, buf: &mut [u8]) -> KcpResult { + future::poll_fn(|cx| self.poll_recv(cx, buf)).await + } + + pub fn flush(&mut self) -> KcpResult<()> { + self.kcp.flush()?; + self.last_update = Instant::now(); + Ok(()) + } + + fn try_wake_pending_waker(&mut self) -> bool { + let mut waked = false; + + if self.pending_sender.is_some() + && self.kcp.wait_snd() < self.kcp.snd_wnd() as usize + && !self.kcp.waiting_conv() + { + let waker = self.pending_sender.take().unwrap(); + waker.wake(); + + waked = true; + } + + if self.pending_receiver.is_some() { + if let Ok(peek) = self.kcp.peeksize() { + if peek > 0 { + let waker = 
self.pending_receiver.take().unwrap(); + waker.wake(); + + waked = true; + } + } + } + + waked + } + + pub fn update(&mut self) -> KcpResult { + let now = now_millis(); + self.kcp.update(now)?; + let next = self.kcp.check(now); + + self.try_wake_pending_waker(); + + Ok(Instant::now() + Duration::from_millis(next as u64)) + } + + pub fn close(&mut self) { + self.closed = true; + if let Some(w) = self.pending_sender.take() { + w.wake(); + } + if let Some(w) = self.pending_receiver.take() { + w.wake(); + } + } + + pub fn udp_socket(&self) -> &Arc { + &self.socket + } + + pub fn can_close(&self) -> bool { + self.kcp.wait_snd() == 0 + } + + pub fn conv(&self) -> u32 { + self.kcp.conv() + } + + pub fn peek_size(&self) -> KcpResult { + self.kcp.peeksize() + } + + pub fn last_update_time(&self) -> Instant { + self.last_update + } +} diff --git a/src/plugins/kcp/stream.rs b/src/plugins/kcp/stream.rs new file mode 100644 index 0000000..9e07f66 --- /dev/null +++ b/src/plugins/kcp/stream.rs @@ -0,0 +1,183 @@ +use std::{ + io::{self, ErrorKind}, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures::{future, ready}; +use kcp::{Error as KcpError, KcpResult}; +use log::trace; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::UdpSocket, +}; + +use crate::plugins::kcp::{config::KcpConfig, session::KcpSession, skcp::KcpSocket}; + +pub struct KcpStream { + session: Arc, + recv_buffer: Vec, + recv_buffer_pos: usize, + recv_buffer_cap: usize, +} + +impl Drop for KcpStream { + fn drop(&mut self) { + self.session.close(); + } +} + +#[allow(unused)] +impl KcpStream { + pub async fn connect(config: &KcpConfig, addr: SocketAddr) -> KcpResult { + let udp = match addr.ip() { + IpAddr::V4(..) => UdpSocket::bind("0.0.0.0:0").await?, + IpAddr::V6(..) 
=> UdpSocket::bind("[::]:0").await?, + }; + + let udp = Arc::new(udp); + let socket = KcpSocket::new(config, 0, udp, addr, config.stream)?; + + let session = KcpSession::new_shared(socket, config.session_expire, None); + + Ok(KcpStream::with_session(session)) + } + + pub(crate) fn with_session(session: Arc) -> KcpStream { + KcpStream { + session, + recv_buffer: Vec::new(), + recv_buffer_pos: 0, + recv_buffer_cap: 0, + } + } + + pub fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + // Mutex doesn't have poll_lock, spinning on it. + let socket = self.session.kcp_socket(); + let mut kcp = match socket.try_lock() { + Ok(guard) => guard, + Err(..) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + + kcp.poll_send(cx, buf) + } + + pub async fn send(&mut self, buf: &[u8]) -> KcpResult { + future::poll_fn(|cx| self.poll_send(cx, buf)).await + } + + pub fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + loop { + // Consumes all data in buffer + if self.recv_buffer_pos < self.recv_buffer_cap { + let remaining = self.recv_buffer_cap - self.recv_buffer_pos; + let copy_length = remaining.min(buf.len()); + + buf.copy_from_slice( + &self.recv_buffer[self.recv_buffer_pos..self.recv_buffer_pos + copy_length], + ); + self.recv_buffer_pos += copy_length; + return Ok(copy_length).into(); + } + + // Mutex doesn't have poll_lock, spinning on it. + let socket = self.session.kcp_socket(); + let mut kcp = match socket.try_lock() { + Ok(guard) => guard, + Err(..) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + + // Try to read from KCP + // 1. Read directly with user provided `buf` + match ready!(kcp.poll_recv(cx, buf)) { + Ok(n) => { + trace!("[CLIENT] recv directly {} bytes", n); + return Ok(n).into(); + } + Err(KcpError::UserBufTooSmall) => {} + Err(err) => return Err(err).into(), + } + + // 2. 
User `buf` too small, read to recv_buffer + let required_size = kcp.peek_size()?; + if self.recv_buffer.len() < required_size { + self.recv_buffer.resize(required_size, 0); + } + + match ready!(kcp.poll_recv(cx, &mut self.recv_buffer)) { + Ok(n) => { + trace!("[CLIENT] recv buffered {} bytes", n); + self.recv_buffer_pos = 0; + self.recv_buffer_cap = n; + } + Err(err) => return Err(err).into(), + } + } + } + + pub async fn recv(&mut self, buf: &mut [u8]) -> KcpResult { + future::poll_fn(|cx| self.poll_recv(cx, buf)).await + } +} + +impl AsyncRead for KcpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match ready!(self.poll_recv(cx, buf.initialize_unfilled())) { + Ok(n) => { + buf.advance(n); + Ok(()).into() + } + Err(KcpError::IoError(err)) => Err(err).into(), + Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(), + } + } +} + +impl AsyncWrite for KcpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match ready!(self.poll_send(cx, buf)) { + Ok(n) => Ok(n).into(), + Err(KcpError::IoError(err)) => Err(err).into(), + Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Mutex doesn't have poll_lock, spinning on it. + let socket = self.session.kcp_socket(); + let mut kcp = match socket.try_lock() { + Ok(guard) => guard, + Err(..) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + + match kcp.flush() { + Ok(..) 
=> Ok(()).into(), + Err(KcpError::IoError(err)) => Err(err).into(), + Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } +} diff --git a/src/plugins/kcp/utils.rs b/src/plugins/kcp/utils.rs new file mode 100644 index 0000000..293c69f --- /dev/null +++ b/src/plugins/kcp/utils.rs @@ -0,0 +1,10 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +#[inline] +pub fn now_millis() -> u32 { + let start = SystemTime::now(); + let since_the_epoch = start + .duration_since(UNIX_EPOCH) + .expect("time went backwards"); + (since_the_epoch.as_secs() * 1000 + since_the_epoch.subsec_millis() as u64) as u32 +} diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs new file mode 100644 index 0000000..ee8689c --- /dev/null +++ b/src/plugins/mod.rs @@ -0,0 +1 @@ +pub mod kcp; diff --git a/src/servers/mod.rs b/src/servers/mod.rs index 87189b2..3b72e97 100644 --- a/src/servers/mod.rs +++ b/src/servers/mod.rs @@ -104,12 +104,12 @@ impl Server { #[cfg(test)] mod test { + use crate::plugins::kcp::{KcpConfig, KcpStream}; use std::net::SocketAddr; use std::thread::{self, sleep}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; - use tokio_kcp::{KcpConfig, KcpStream}; use super::*; diff --git a/src/servers/protocol/kcp.rs b/src/servers/protocol/kcp.rs index 47b769f..cd3e0a8 100644 --- a/src/servers/protocol/kcp.rs +++ b/src/servers/protocol/kcp.rs @@ -1,3 +1,4 @@ +use crate::plugins::kcp::{KcpConfig, KcpListener, KcpStream}; use crate::servers::Proxy; use futures::future::try_join; use log::{debug, error, warn}; @@ -6,7 +7,6 @@ use std::sync::Arc; use tokio::io; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio_kcp::{KcpConfig, KcpListener, KcpStream}; pub async fn proxy(config: Arc) -> Result<(), Box> { let kcp_config = KcpConfig::default();