diff --git a/src/config.rs b/src/config.rs index d0a82c2..3fb3252 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,11 @@ use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{Error as IOError, Read}; +use std::net::SocketAddr; +use tokio::sync::Mutex; use url::Url; +use tokio::time::Instant; +use time::OffsetDateTime; #[derive(Debug, Clone)] pub struct Config { @@ -42,11 +46,75 @@ pub enum Upstream { Custom(CustomUpstream), } +#[derive(Debug)] +struct Addr(Mutex>); + +impl Default for Addr { + fn default() -> Self { + Self(Default::default()) + } +} + +impl Clone for Addr { + fn clone(&self) -> Self { + tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone()))) + } +} + #[derive(Debug, Clone, Deserialize)] pub struct CustomUpstream { pub name: String, pub addr: String, pub protocol: String, + #[serde(skip_deserializing)] + addresses: Addr, +} + +impl CustomUpstream { + pub async fn resolve_addresses(&self) -> std::io::Result<()> { + { + let addr = self.addresses.0.lock().await; + if addr.len() > 0 { + debug!("Already have addresses: {:?}", &addr); + return Ok(()); + } + } + + debug!("Resolving addresses for {}", &self.addr); + let addresses = tokio::net::lookup_host(self.addr.clone()).await?; + + let mut addr: Vec = match self.protocol.as_ref() { + "tcp4" => addresses.into_iter().filter(|a| a.is_ipv4()).collect(), + "tcp6" => addresses.into_iter().filter(|a| a.is_ipv6()).collect(), + _ => addresses.collect(), + }; + + debug!("Got addresses for {}: {:?}", &self.addr, &addr); + debug!("Resolved at {}", OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339).expect("Format")); + + { + let mut self_addr = self.addresses.0.lock().await; + self_addr.clear(); + self_addr.append(&mut addr); + } + Ok(()) + } + + pub async fn get_addresses(&self) -> Vec { + let a = self.addresses.0.lock().await; + a.clone() + } +} + +impl Default for CustomUpstream { + fn default() -> Self { + Self { + name: Default::default(), + addr: Default::default(), + protocol: Default::default(), + addresses: Default::default(), + } + } } #[derive(Debug)] @@ -119,11 +187,14 @@ fn load_config(path: &str) -> Result { } }; - if upstream_url.scheme() != "tcp" { - return Err(ConfigError::Custom(format!( - "Invalid upstream scheme {}", - upstream - ))); + match upstream_url.scheme() { + "tcp" | "tcp4" | "tcp6" => {} + _ => { + return Err(ConfigError::Custom(format!( + "Invalid upstream scheme {}", + upstream + ))) + } } parsed_upstream.insert( @@ -132,6 +203,7 @@ fn load_config(path: &str) -> Result { name: name.to_string(), addr: format!("{}:{}", upstream_host, upsteam_port), protocol: upstream_url.scheme().to_string(), + ..Default::default() }), ); } diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs index ee8689c..a219535 100644 --- a/src/plugins/mod.rs +++ b/src/plugins/mod.rs @@ -1 +1 @@ -pub mod kcp; +//pub mod kcp; diff --git a/src/servers/mod.rs b/src/servers/mod.rs index eafe736..ea54ef8 100644 --- a/src/servers/mod.rs +++ b/src/servers/mod.rs @@ -6,7 +6,7 @@ use tokio::task::JoinHandle; mod protocol; use crate::config::{ParsedConfig, Upstream}; -use protocol::{kcp, tcp}; +use protocol::tcp; #[derive(Debug)] pub struct Server { @@ -88,12 +88,24 @@ impl Server { error!("Failed to start {}: {}", config.name, res.err().unwrap()); } } - "kcp" => { - let res = kcp::proxy(config.clone()).await; + "tcp4" => { + let res = tcp::proxy(config.clone()).await; if res.is_err() { error!("Failed to start {}: {}", config.name, res.err().unwrap()); } } + "tcp6" => { + let res = tcp::proxy(config.clone()).await; + if res.is_err() { + error!("Failed to start {}: {}", config.name, res.err().unwrap()); + } + } + // "kcp" => { + // let res = kcp::proxy(config.clone()).await; + // if res.is_err() { + // error!("Failed to start {}: {}", config.name, res.err().unwrap()); + // } + // } _ => { error!("Invalid protocol: {}", config.protocol) } @@ -111,12 +123,11 @@ impl Server { #[cfg(test)] mod tests { - use crate::plugins::kcp::{KcpConfig, KcpStream}; - use std::net::SocketAddr; + //use crate::plugins::kcp::{KcpConfig, KcpStream}; use std::thread::{self, sleep}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use tokio::net::{TcpListener, TcpStream}; + use tokio::net::TcpListener; use super::*; @@ -155,7 +166,9 @@ mod tests { sleep(Duration::from_secs(1)); // wait for server to start // test TCP proxy - let mut conn = TcpStream::connect("127.0.0.1:54500").await.unwrap(); + let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54500") + .await + .unwrap(); let mut buf = [0u8; 5]; conn.write(b"hi").await.unwrap(); conn.read(&mut buf).await.unwrap(); @@ -163,7 +176,9 @@ mod tests { conn.shutdown().await.unwrap(); // test TCP echo - let mut conn = TcpStream::connect("127.0.0.1:54956").await.unwrap(); + let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54956") + .await + .unwrap(); let mut buf = [0u8; 1]; for i in 0..=10u8 { conn.write(&[i]).await.unwrap(); @@ -173,25 +188,25 @@ mod tests { conn.shutdown().await.unwrap(); // test KCP echo - let kcp_config = KcpConfig::default(); - let server_addr: SocketAddr = "127.0.0.1:54959".parse().unwrap(); - let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap(); - let mut buf = [0u8; 1]; - for i in 0..=10u8 { - conn.write(&[i]).await.unwrap(); - conn.read(&mut buf).await.unwrap(); - assert_eq!(&buf, &[i]); - } - conn.shutdown().await.unwrap(); - - // test KCP proxy and close mock server - let kcp_config = KcpConfig::default(); - let server_addr: SocketAddr = "127.0.0.1:54958".parse().unwrap(); - let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap(); - let mut buf = [0u8; 5]; - conn.write(b"by").await.unwrap(); - conn.read(&mut buf).await.unwrap(); - assert_eq!(&buf, b"hello"); - conn.shutdown().await.unwrap(); + // let kcp_config = KcpConfig::default(); + // let server_addr: SocketAddr = "127.0.0.1:54959".parse().unwrap(); + // let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap(); + // let mut buf = [0u8; 1]; + // for i in 0..=10u8 { + // conn.write(&[i]).await.unwrap(); + // conn.read(&mut buf).await.unwrap(); + // assert_eq!(&buf, &[i]); + // } + // conn.shutdown().await.unwrap(); + // + // // test KCP proxy and close mock server + // let kcp_config = KcpConfig::default(); + // let server_addr: SocketAddr = "127.0.0.1:54958".parse().unwrap(); + // let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap(); + // let mut buf = [0u8; 5]; + // conn.write(b"by").await.unwrap(); + // conn.read(&mut buf).await.unwrap(); + // assert_eq!(&buf, b"hello"); + // conn.shutdown().await.unwrap(); } } diff --git a/src/servers/protocol/mod.rs b/src/servers/protocol/mod.rs index 014642a..7969da2 100644 --- a/src/servers/protocol/mod.rs +++ b/src/servers/protocol/mod.rs @@ -1,3 +1,3 @@ -pub mod kcp; +//pub mod kcp; pub mod tcp; pub mod tls; diff --git a/src/servers/protocol/tcp.rs b/src/servers/protocol/tcp.rs index a93363a..c415d29 100644 --- a/src/servers/protocol/tcp.rs +++ b/src/servers/protocol/tcp.rs @@ -2,7 +2,7 @@ use crate::config::Upstream; use crate::servers::protocol::tls::get_sni; use crate::servers::Proxy; use futures::future::try_join; -use log::{debug, error, warn}; +use log::{debug, error, info, warn}; use std::sync::Arc; use tokio::io; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; @@ -34,7 +34,7 @@ pub async fn proxy(config: Arc) -> Result<(), Box> } async fn accept(inbound: TcpStream, proxy: Arc) -> Result<(), Box> { - debug!("New connection from {:?}", inbound.peer_addr()?); + info!("New connection from {:?}", inbound.peer_addr()?); let upstream_name = match proxy.tls { false => proxy.default.clone(), @@ -72,16 +72,22 @@ async fn accept(inbound: TcpStream, proxy: Arc) -> Result<(), Box u.resolve_addresses().await?, + _ => {} + } + + return process(inbound, upstream.clone()).await; } async fn process( mut inbound: TcpStream, - upstream: &Upstream, + upstream: Upstream, ) -> Result<(), Box> { match upstream { Upstream::Ban => { @@ -93,25 +99,30 @@ async fn process( let bytes_tx = inbound_to_inbound.await; debug!("Bytes read: {:?}", bytes_tx); } - Upstream::Custom(custom) => match custom.protocol.as_ref() { - "tcp" => { - let outbound = TcpStream::connect(custom.addr.clone()).await?; + Upstream::Custom(custom) => { + custom.resolve_addresses().await?; + let outbound = match custom.protocol.as_ref() { + "tcp4" | "tcp6" | "tcp" => { + TcpStream::connect(custom.get_addresses().await.as_slice()).await? + } + _ => { + error!("Reached unknown protocol: {:?}", custom.protocol); + return Err("Reached unknown protocol".into()); + } + }; - let (mut ri, mut wi) = io::split(inbound); - let (mut ro, mut wo) = io::split(outbound); + debug!("Connected to {:?}", outbound.peer_addr().unwrap()); - let inbound_to_outbound = copy(&mut ri, &mut wo); - let outbound_to_inbound = copy(&mut ro, &mut wi); + let (mut ri, mut wi) = io::split(inbound); + let (mut ro, mut wo) = io::split(outbound); - let (bytes_tx, bytes_rx) = - try_join(inbound_to_outbound, outbound_to_inbound).await?; + let inbound_to_outbound = copy(&mut ri, &mut wo); + let outbound_to_inbound = copy(&mut ro, &mut wi); - debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx); - } - _ => { - error!("Reached unknown protocol: {:?}", custom.protocol); - } - }, + let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?; + + debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx); + } }; Ok(()) } diff --git a/src/servers/protocol/tls.rs b/src/servers/protocol/tls.rs index 398c36a..1227c15 100644 --- a/src/servers/protocol/tls.rs +++ b/src/servers/protocol/tls.rs @@ -49,6 +49,7 @@ pub fn get_sni(buf: &[u8]) -> Vec { } } + debug!("Found SNIs: {:?}", &snis); snis }