Add KCP support

This commit is contained in:
KernelErr
2021-10-26 21:36:12 +08:00
parent 4c9b7a6990
commit 55eef8581c
12 changed files with 404 additions and 143 deletions

View File

@ -20,6 +20,7 @@ pub struct BaseConfig {
#[derive(Debug, Default, Deserialize, Clone)]
pub struct ServerConfig {
pub listen: Vec<String>,
pub protocol: Option<String>,
pub tls: Option<bool>,
pub sni: Option<HashMap<String, String>>,
pub default: Option<String>,
@ -86,7 +87,7 @@ mod tests {
let config = Config::new("tests/config.yaml").unwrap();
assert_eq!(config.base.version, 1);
assert_eq!(config.base.log.unwrap(), "disable");
assert_eq!(config.base.servers.len(), 2);
assert_eq!(config.base.upstream.len(), 2);
assert_eq!(config.base.servers.len(), 4);
assert_eq!(config.base.upstream.len(), 3);
}
}

View File

@ -1,16 +1,12 @@
use futures::future::try_join;
use log::{debug, error, info, warn};
use log::{error, info};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::task::JoinHandle;
mod tls;
use self::tls::get_sni;
mod protocol;
use crate::config::BaseConfig;
use protocol::{kcp, tcp};
#[derive(Debug)]
pub struct Server {
@ -22,6 +18,7 @@ pub struct Server {
pub struct Proxy {
pub name: String,
pub listen: SocketAddr,
pub protocol: String,
pub tls: bool,
pub sni: Option<HashMap<String, String>>,
pub default: String,
@ -36,6 +33,7 @@ impl Server {
};
for (name, proxy) in config.servers.iter() {
let protocol = proxy.protocol.clone().unwrap_or_else(|| "tcp".to_string());
let tls = proxy.tls.unwrap_or(false);
let sni = proxy.sni.clone();
let default = proxy.default.clone().unwrap_or_else(|| "ban".to_string());
@ -48,7 +46,6 @@ impl Server {
upstream_set.insert(key.clone());
}
for listen in proxy.listen.clone() {
println!("{:?}", listen);
let listen_addr: SocketAddr = match listen.parse() {
Ok(addr) => addr,
Err(_) => {
@ -59,6 +56,7 @@ impl Server {
let proxy = Proxy {
name: name.clone(),
listen: listen_addr,
protocol: protocol.clone(),
tls,
sni: sni.clone(),
default: default.clone(),
@ -77,9 +75,22 @@ impl Server {
let mut handles: Vec<JoinHandle<()>> = Vec::new();
for config in proxies {
info!("Starting server {} on {}", config.name, config.listen);
info!(
"Starting {} server {} on {}",
config.protocol, config.name, config.listen
);
let handle = tokio::spawn(async move {
let _ = proxy(config).await;
match config.protocol.as_ref() {
"tcp" => {
let _ = tcp::proxy(config).await;
}
"kcp" => {
let _ = kcp::proxy(config).await;
}
_ => {
error!("Invalid protocol: {}", config.protocol)
}
}
});
handles.push(handle);
}
@ -91,131 +102,48 @@ impl Server {
}
}
async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(config.listen).await?;
let config = config.clone();
loop {
let thread_proxy = config.clone();
match listener.accept().await {
Err(err) => {
error!("Failed to accept connection: {}", err);
return Err(Box::new(err));
}
Ok((stream, _)) => {
tokio::spawn(async move {
match accept(stream, thread_proxy).await {
Ok(_) => {}
Err(err) => {
error!("Relay thread returned an error: {}", err);
}
};
});
}
}
}
}
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
debug!("New connection from {:?}", inbound.peer_addr()?);
let upstream_name = match proxy.tls {
false => proxy.default.clone(),
true => {
let mut hello_buf = [0u8; 1024];
inbound.peek(&mut hello_buf).await?;
let snis = get_sni(&hello_buf);
if snis.is_empty() {
proxy.default.clone()
} else {
match proxy.sni.clone() {
Some(sni_map) => {
let mut upstream = proxy.default.clone();
for sni in snis {
let m = sni_map.get(&sni);
if m.is_some() {
upstream = m.unwrap().clone();
break;
}
}
upstream
}
None => proxy.default.clone(),
}
}
}
};
debug!("Upstream: {}", upstream_name);
let upstream = match proxy.upstream.get(&upstream_name) {
Some(upstream) => upstream,
None => {
warn!(
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, &proxy.default).await;
}
};
return process(inbound, upstream).await;
}
async fn process(mut inbound: TcpStream, upstream: &str) -> Result<(), Box<dyn std::error::Error>> {
if upstream == "ban" {
let _ = inbound.shutdown();
return Ok(());
} else if upstream == "echo" {
loop {
let mut buf = [0u8; 1];
let b = inbound.read(&mut buf).await?;
if b == 0 {
break;
} else {
inbound.write(&buf).await?;
}
}
return Ok(());
}
let outbound = TcpStream::connect(upstream).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
match io::copy(reader, writer).await {
Ok(u64) => {
let _ = writer.shutdown().await;
Ok(u64)
}
Err(_) => Ok(0),
}
}
#[cfg(test)]
mod test {
use std::net::SocketAddr;
use std::thread::{self, sleep};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_kcp::{KcpConfig, KcpStream};
use super::*;
#[tokio::main]
async fn tcp_mock_server() {
let server_addr: SocketAddr = "127.0.0.1:54599".parse().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
let (mut stream, _) = listener.accept().await.unwrap();
stream.write(b"hello").await.unwrap();
stream.shutdown().await.unwrap();
}
#[tokio::test]
async fn test_echo_server() {
async fn test_tcp_proxy() {
use crate::config::Config;
let config = Config::new("tests/config.yaml").unwrap();
let mut server = Server::new(config.base);
thread::spawn(move || {
tcp_mock_server();
});
sleep(Duration::from_secs(1)); // wait for server to start
thread::spawn(move || {
let _ = server.run();
});
sleep(Duration::from_secs(1)); // wait for server to start
let mut conn = TcpStream::connect("127.0.0.1:54500").await.unwrap();
let mut buf = [0u8; 5];
conn.read(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
conn.shutdown().await.unwrap();
}
#[tokio::test]
async fn test_tcp_echo_server() {
use crate::config::Config;
let config = Config::new("tests/config.yaml").unwrap();
let mut server = Server::new(config.base);
@ -232,4 +160,25 @@ mod test {
}
conn.shutdown().await.unwrap();
}
#[tokio::test]
async fn test_kcp_echo_server() {
use crate::config::Config;
let config = Config::new("tests/config.yaml").unwrap();
let mut server = Server::new(config.base);
thread::spawn(move || {
let _ = server.run();
});
sleep(Duration::from_secs(1)); // wait for server to start
let kcp_config = KcpConfig::default();
let server_addr: SocketAddr = "127.0.0.1:54959".parse().unwrap();
let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
let mut buf = [0u8; 1];
for i in 0..=10u8 {
conn.write(&[i]).await.unwrap();
conn.read(&mut buf).await.unwrap();
assert_eq!(&buf, &[i]);
}
conn.shutdown().await.unwrap();
}
}

100
src/servers/protocol/kcp.rs Normal file
View File

@ -0,0 +1,100 @@
use crate::servers::Proxy;
use futures::future::try_join;
use log::{debug, error, warn};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_kcp::{KcpConfig, KcpListener, KcpStream};
pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
let kcp_config = KcpConfig::default();
let mut listener = KcpListener::bind(kcp_config, config.listen).await?;
let config = config.clone();
loop {
let thread_proxy = config.clone();
match listener.accept().await {
Err(err) => {
error!("Failed to accept connection: {}", err);
return Err(Box::new(err));
}
Ok((stream, peer)) => {
tokio::spawn(async move {
match accept(stream, peer, thread_proxy).await {
Ok(_) => {}
Err(err) => {
error!("Relay thread returned an error: {}", err);
}
};
});
}
}
}
}
async fn accept(
inbound: KcpStream,
peer: SocketAddr,
proxy: Arc<Proxy>,
) -> Result<(), Box<dyn std::error::Error>> {
debug!("New connection from {:?}", peer);
let upstream_name = proxy.default.clone();
debug!("Upstream: {}", upstream_name);
let upstream = match proxy.upstream.get(&upstream_name) {
Some(upstream) => upstream,
None => {
warn!(
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, &proxy.default).await;
}
};
return process(inbound, upstream).await;
}
async fn process(mut inbound: KcpStream, upstream: &str) -> Result<(), Box<dyn std::error::Error>> {
if upstream == "ban" {
let _ = inbound.shutdown();
return Ok(());
} else if upstream == "echo" {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
return Ok(());
}
let outbound = TcpStream::connect(upstream).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
match io::copy(reader, writer).await {
Ok(u64) => {
let _ = writer.shutdown().await;
Ok(u64)
}
Err(_) => Ok(0),
}
}

View File

@ -0,0 +1,3 @@
pub mod kcp;
pub mod tcp;
pub mod tls;

119
src/servers/protocol/tcp.rs Normal file
View File

@ -0,0 +1,119 @@
use crate::servers::protocol::tls::get_sni;
use crate::servers::Proxy;
use futures::future::try_join;
use log::{debug, error, warn};
use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(config.listen).await?;
let config = config.clone();
loop {
let thread_proxy = config.clone();
match listener.accept().await {
Err(err) => {
error!("Failed to accept connection: {}", err);
return Err(Box::new(err));
}
Ok((stream, _)) => {
tokio::spawn(async move {
match accept(stream, thread_proxy).await {
Ok(_) => {}
Err(err) => {
error!("Relay thread returned an error: {}", err);
}
};
});
}
}
}
}
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
debug!("New connection from {:?}", inbound.peer_addr()?);
let upstream_name = match proxy.tls {
false => proxy.default.clone(),
true => {
let mut hello_buf = [0u8; 1024];
inbound.peek(&mut hello_buf).await?;
let snis = get_sni(&hello_buf);
if snis.is_empty() {
proxy.default.clone()
} else {
match proxy.sni.clone() {
Some(sni_map) => {
let mut upstream = proxy.default.clone();
for sni in snis {
let m = sni_map.get(&sni);
if m.is_some() {
upstream = m.unwrap().clone();
break;
}
}
upstream
}
None => proxy.default.clone(),
}
}
}
};
debug!("Upstream: {}", upstream_name);
let upstream = match proxy.upstream.get(&upstream_name) {
Some(upstream) => upstream,
None => {
warn!(
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, &proxy.default).await;
}
};
return process(inbound, upstream).await;
}
async fn process(mut inbound: TcpStream, upstream: &str) -> Result<(), Box<dyn std::error::Error>> {
if upstream == "ban" {
let _ = inbound.shutdown();
return Ok(());
} else if upstream == "echo" {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
return Ok(());
}
let outbound = TcpStream::connect(upstream).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
match io::copy(reader, writer).await {
Ok(u64) => {
let _ = writer.shutdown().await;
Ok(u64)
}
Err(_) => Ok(0),
}
}