191 lines
6.1 KiB
Rust
191 lines
6.1 KiB
Rust
use log::{error, info};
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use tokio::task::JoinHandle;
|
|
|
|
mod protocol;
|
|
use crate::config::BaseConfig;
|
|
use protocol::{kcp, tcp};
|
|
|
|
#[derive(Debug)]
|
|
pub struct Server {
|
|
pub proxies: Vec<Arc<Proxy>>,
|
|
pub config: BaseConfig,
|
|
}
|
|
|
|
/// Description of a single listener: where it binds, which protocol it speaks
/// and the routing data it carries.
#[derive(Debug, Clone)]
pub struct Proxy {
    /// Name of the server entry this listener belongs to.
    pub name: String,
    /// Socket address the listener binds to.
    pub listen: SocketAddr,
    /// Protocol identifier; `Server::run` dispatches on `"tcp"` and `"kcp"`.
    pub protocol: String,
    /// TLS flag for this listener.
    // NOTE(review): how the protocol handlers interpret this flag is not visible here.
    pub tls: bool,
    /// Optional string-to-string SNI table.
    // NOTE(review): presumably maps an SNI host name to an upstream name — confirm
    // against the protocol handlers.
    pub sni: Option<HashMap<String, String>>,
    /// Default target name.
    // NOTE(review): presumably the upstream used when no SNI entry matches — confirm.
    pub default: String,
    /// Upstream table (string keys to string values).
    // NOTE(review): presumably upstream name -> address; verify against the config schema.
    pub upstream: HashMap<String, String>,
}
|
|
|
|
impl Server {
|
|
pub fn new(config: BaseConfig) -> Self {
|
|
let mut new_server = Server {
|
|
proxies: Vec::new(),
|
|
config: config.clone(),
|
|
};
|
|
|
|
for (name, proxy) in config.servers.iter() {
|
|
let protocol = proxy.protocol.clone().unwrap_or_else(|| "tcp".to_string());
|
|
let tls = proxy.tls.unwrap_or(false);
|
|
let sni = proxy.sni.clone();
|
|
let default = proxy.default.clone().unwrap_or_else(|| "ban".to_string());
|
|
let upstream = config.upstream.clone();
|
|
let mut upstream_set: HashSet<String> = HashSet::new();
|
|
for key in upstream.keys() {
|
|
if key.eq("ban") || key.eq("echo") {
|
|
continue;
|
|
}
|
|
upstream_set.insert(key.clone());
|
|
}
|
|
for listen in proxy.listen.clone() {
|
|
let listen_addr: SocketAddr = match listen.parse() {
|
|
Ok(addr) => addr,
|
|
Err(_) => {
|
|
error!("Invalid listen address: {}", listen);
|
|
continue;
|
|
}
|
|
};
|
|
let proxy = Proxy {
|
|
name: name.clone(),
|
|
listen: listen_addr,
|
|
protocol: protocol.clone(),
|
|
tls,
|
|
sni: sni.clone(),
|
|
default: default.clone(),
|
|
upstream: upstream.clone(),
|
|
};
|
|
new_server.proxies.push(Arc::new(proxy));
|
|
}
|
|
}
|
|
|
|
new_server
|
|
}
|
|
|
|
#[tokio::main]
|
|
pub async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
|
let proxies = self.proxies.clone();
|
|
let mut handles: Vec<JoinHandle<()>> = Vec::new();
|
|
|
|
for config in proxies {
|
|
info!(
|
|
"Starting {} server {} on {}",
|
|
config.protocol, config.name, config.listen
|
|
);
|
|
let handle = tokio::spawn(async move {
|
|
match config.protocol.as_ref() {
|
|
"tcp" => {
|
|
let _ = tcp::proxy(config).await;
|
|
}
|
|
"kcp" => {
|
|
let _ = kcp::proxy(config).await;
|
|
}
|
|
_ => {
|
|
error!("Invalid protocol: {}", config.protocol)
|
|
}
|
|
}
|
|
});
|
|
handles.push(handle);
|
|
}
|
|
|
|
for handle in handles {
|
|
handle.await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use crate::plugins::kcp::{KcpConfig, KcpStream};
    use std::net::SocketAddr;
    use std::thread::{self, sleep};
    use std::time::Duration;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};

    use super::*;

    /// Mock TCP upstream listening on 127.0.0.1:54599.
    ///
    /// Connections are served one at a time. For every chunk read (up to two
    /// bytes) the server answers with `hello`; after answering a chunk equal to
    /// `by`, or once the peer closes, the write side is shut down and the next
    /// connection is accepted. Runs forever on its own runtime
    /// (`#[tokio::main]`), so it must be started on a dedicated thread.
    #[tokio::main]
    async fn tcp_mock_server() {
        let server_addr: SocketAddr = "127.0.0.1:54599".parse().unwrap();
        let listener = TcpListener::bind(server_addr).await.unwrap();
        loop {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 2];
            let mut n = stream.read(&mut buf).await.unwrap();
            while n > 0 {
                // `write_all` guarantees the whole reply is sent; plain `write`
                // may stop after a partial write.
                stream.write_all(b"hello").await.unwrap();
                if buf.eq(b"by") {
                    // Leave the loop and let the single shutdown below run.
                    // Shutting down twice returns `NotConnected` on some
                    // platforms (e.g. macOS), which would panic on `unwrap`.
                    break;
                }
                n = stream.read(&mut buf).await.unwrap();
            }
            stream.shutdown().await.unwrap();
        }
    }

    /// End-to-end check of the TCP and KCP listeners, covering both the echo
    /// targets and proxying to the mock upstream.
    // NOTE(review): the ports used below are assumed to match tests/config.yaml.
    #[tokio::test]
    async fn test_proxy() {
        use crate::config::Config;
        let config = Config::new("tests/config.yaml").unwrap();
        let mut server = Server::new(config.base);

        // Both the mock upstream and the server create their own runtime, so
        // each gets a dedicated thread.
        thread::spawn(move || {
            tcp_mock_server();
        });
        sleep(Duration::from_secs(1)); // wait for the mock upstream to start
        thread::spawn(move || {
            let _ = server.run();
        });
        sleep(Duration::from_secs(1)); // wait for the proxy server to start

        // test TCP proxy
        let mut conn = TcpStream::connect("127.0.0.1:54500").await.unwrap();
        let mut buf = [0u8; 5];
        conn.write_all(b"hi").await.unwrap();
        // `read_exact` fails instead of silently returning a short read.
        conn.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
        conn.shutdown().await.unwrap();

        // test TCP echo
        let mut conn = TcpStream::connect("127.0.0.1:54956").await.unwrap();
        let mut buf = [0u8; 1];
        for i in 0..=10u8 {
            conn.write_all(&[i]).await.unwrap();
            // A 0-byte `read` would leave the stale buffer in place and could
            // satisfy the assertion by accident (e.g. for `i == 0`).
            conn.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, &[i]);
        }
        conn.shutdown().await.unwrap();

        // test KCP echo
        let kcp_config = KcpConfig::default();
        let server_addr: SocketAddr = "127.0.0.1:54959".parse().unwrap();
        let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
        let mut buf = [0u8; 1];
        for i in 0..=10u8 {
            conn.write_all(&[i]).await.unwrap();
            conn.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, &[i]);
        }
        conn.shutdown().await.unwrap();

        // test KCP proxy; sending "by" also makes the mock upstream close the
        // connection
        let kcp_config = KcpConfig::default();
        let server_addr: SocketAddr = "127.0.0.1:54958".parse().unwrap();
        let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
        let mut buf = [0u8; 5];
        conn.write_all(b"by").await.unwrap();
        conn.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
        conn.shutdown().await.unwrap();
    }
}
|