use log::{error, info}; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; use tokio::task::JoinHandle; mod protocol; pub(crate) mod upstream_address; use crate::config::ParsedConfigV1; use crate::upstreams::Upstream; use protocol::tcp; #[derive(Debug)] pub(crate) struct Server { pub proxies: Vec>, } #[derive(Debug, Clone)] pub(crate) struct Proxy { pub name: String, pub listen: SocketAddr, pub protocol: String, pub tls: bool, pub sni: Option>, pub default_action: String, pub upstream: HashMap, } impl Server { pub fn new_from_v1_config(config: ParsedConfigV1) -> Self { let mut new_server = Server { proxies: Vec::new(), }; for (name, proxy) in config.servers.iter() { let protocol = proxy.protocol.clone().unwrap_or_else(|| "tcp".to_string()); let tls = proxy.tls.unwrap_or(false); let sni = proxy.sni.clone(); let default = proxy.default.clone().unwrap_or_else(|| "ban".to_string()); let upstream = config.upstream.clone(); let mut upstream_set: HashSet = HashSet::new(); for key in upstream.keys() { if key.eq("ban") || key.eq("echo") { continue; } upstream_set.insert(key.clone()); } for listen in proxy.listen.clone() { let listen_addr: SocketAddr = match listen.parse() { Ok(addr) => addr, Err(_) => { error!("Invalid listen address: {}", listen); continue; } }; let proxy = Proxy { name: name.clone(), listen: listen_addr, protocol: protocol.clone(), tls, sni: sni.clone(), default_action: default.clone(), upstream: upstream.clone(), }; new_server.proxies.push(Arc::new(proxy)); } } new_server } #[tokio::main] pub async fn run(&mut self) -> Result<(), Box> { let proxies = self.proxies.clone(); let mut handles: Vec> = Vec::new(); for config in proxies { info!( "Starting {} server {} on {}", config.protocol, config.name, config.listen ); let handle = tokio::spawn(async move { match config.protocol.as_ref() { "tcp" | "tcp4" | "tcp6" => { let res = tcp::proxy(config.clone()).await; if res.is_err() { error!("Failed to start {}: {}", config.name, res.err().unwrap()); } } _ => { error!("Invalid protocol: {}", config.protocol) } } }); handles.push(handle); } for handle in handles { handle.await?; } Ok(()) } } #[cfg(test)] mod tests { use std::thread::{self, sleep}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use super::*; #[tokio::main] async fn tcp_mock_server() { let server_addr: SocketAddr = "127.0.0.1:54599".parse().unwrap(); let listener = TcpListener::bind(server_addr).await.unwrap(); loop { let (mut stream, _) = listener.accept().await.unwrap(); let mut buf = [0u8; 2]; let mut n = stream.read(&mut buf).await.unwrap(); while n > 0 { let _ = stream.write(b"hello").await.unwrap(); if buf.eq(b"by") { stream.shutdown().await.unwrap(); break; } n = stream.read(&mut buf).await.unwrap(); } stream.shutdown().await.unwrap(); } } #[tokio::test] async fn test_proxy() { use crate::config::ConfigV1; let config = ConfigV1::new("tests/config.yaml").unwrap(); let mut server = Server::new_from_v1_config(config.base); thread::spawn(move || { tcp_mock_server(); }); sleep(Duration::from_secs(1)); // wait for server to start thread::spawn(move || { let _ = server.run(); }); sleep(Duration::from_secs(1)); // wait for server to start // test TCP proxy let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54500") .await .unwrap(); let mut buf = [0u8; 5]; let _ = conn.write(b"hi").await.unwrap(); let _ = conn.read(&mut buf).await.unwrap(); assert_eq!(&buf, b"hello"); conn.shutdown().await.unwrap(); // test TCP echo let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54956") .await .unwrap(); let mut buf = [0u8; 1]; for i in 0..=10u8 { let _ = conn.write(&[i]).await.unwrap(); let _ = conn.read(&mut buf).await.unwrap(); assert_eq!(&buf, &[i]); } conn.shutdown().await.unwrap(); } }