diff --git a/src/config.rs b/src/config.rs index 7e162d9..5d243bf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,11 +1,10 @@ -use crate::servers::upstream_address::UpstreamAddress; +use crate::upstreams::ProxyToUpstream; +use crate::upstreams::Upstream; use log::{debug, warn}; use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{Error as IOError, Read}; -use std::net::SocketAddr; -use tokio::sync::Mutex; use url::Url; #[derive(Debug, Clone)] @@ -37,48 +36,16 @@ pub struct ServerConfig { pub sni: Option>, pub default: Option, } - -#[derive(Debug, Clone, Deserialize)] -pub enum Upstream { - Ban, - Echo, - Proxy(ProxyToUpstream), -} - -#[derive(Debug, Default)] -struct Addr(Mutex); - -impl Clone for Addr { - fn clone(&self) -> Self { - tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone()))) - } -} - -#[derive(Debug, Clone, Deserialize, Default)] -pub struct ProxyToUpstream { - pub addr: String, - pub protocol: String, - #[serde(skip_deserializing)] - addresses: Addr, -} - -impl ProxyToUpstream { - pub async fn resolve_addresses(&self) -> std::io::Result> { - let mut addr = self.addresses.0.lock().await; - addr.resolve((*self.protocol).into()).await - } -} - -impl TryFrom<&str> for ProxyToUpstream { +impl TryInto for &str { type Error = ConfigError; - fn try_from(upstream: &str) -> Result { - let upstream_url = match Url::parse(upstream) { + fn try_into(self) -> Result { + let upstream_url = match Url::parse(self) { Ok(url) => url, Err(_) => { return Err(ConfigError::Custom(format!( "Invalid upstream url {}", - upstream + self ))) } }; @@ -88,7 +55,7 @@ impl TryFrom<&str> for ProxyToUpstream { None => { return Err(ConfigError::Custom(format!( "Invalid upstream url {}", - upstream + self ))) } }; @@ -98,7 +65,7 @@ impl TryFrom<&str> for ProxyToUpstream { None => { return Err(ConfigError::Custom(format!( "Invalid upstream url {}", - upstream + self ))) } }; @@ -108,17 +75,15 @@ impl TryFrom<&str> for ProxyToUpstream { _ => { return Err(ConfigError::Custom(format!( "Invalid upstream scheme {}", - upstream + self ))) } } - let addr = UpstreamAddress::new(format!("{}:{}", upstream_host, upstream_port)); - Ok(ProxyToUpstream { - addr: format!("{}:{}", upstream_host, upstream_port), - protocol: upstream_url.scheme().to_string(), - addresses: Addr(Mutex::new(addr)), - }) + Ok(ProxyToUpstream::new( + format!("{}:{}", upstream_host, upstream_port), + upstream_url.scheme().to_string(), + )) } } @@ -165,7 +130,7 @@ fn load_config(path: &str) -> Result { parsed_upstream.insert("echo".to_string(), Upstream::Echo); for (name, upstream) in base.upstream.iter() { - let ups = ProxyToUpstream::try_from(upstream.as_str())?; + let ups = upstream.as_str().try_into()?; parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups)); } diff --git a/src/main.rs b/src/main.rs index c223887..0b92f6c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod config; mod plugins; mod servers; +mod upstreams; use crate::config::Config; use crate::servers::Server; diff --git a/src/servers/mod.rs b/src/servers/mod.rs index 27a47c5..c90391e 100644 --- a/src/servers/mod.rs +++ b/src/servers/mod.rs @@ -2,14 +2,14 @@ use log::{error, info}; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; -use tokio::io; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; + use tokio::task::JoinHandle; mod protocol; pub(crate) mod upstream_address; -use crate::config::{ParsedConfig, Upstream}; +use crate::config::ParsedConfig; +use crate::upstreams::Upstream; use protocol::tcp; #[derive(Debug)] @@ -212,17 +212,3 @@ mod tests { // conn.shutdown().await.unwrap(); } } - -async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result -where - R: AsyncRead + Unpin + ?Sized, - W: AsyncWrite + Unpin + ?Sized, -{ - match io::copy(reader, writer).await { - Ok(u64) => { - let _ = writer.shutdown().await; - Ok(u64) - } - Err(_) => Ok(0), - } -} diff --git a/src/servers/protocol/tcp.rs b/src/servers/protocol/tcp.rs index 43e2987..d8917f4 100644 --- a/src/servers/protocol/tcp.rs +++ b/src/servers/protocol/tcp.rs @@ -1,14 +1,11 @@ -use crate::config::Upstream; use crate::servers::protocol::tls::get_sni; -use crate::servers::{copy, Proxy}; -use futures::future::try_join; +use crate::servers::Proxy; use log::{debug, error, info, warn}; +use std::error::Error; use std::sync::Arc; -use tokio::io; -use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; -pub(crate) async fn proxy(config: Arc) -> Result<(), Box> { +pub(crate) async fn proxy(config: Arc) -> Result<(), Box> { let listener = TcpListener::bind(config.listen).await?; let config = config.clone(); @@ -33,7 +30,7 @@ pub(crate) async fn proxy(config: Arc) -> Result<(), Box) -> Result<(), Box> { +async fn accept(inbound: TcpStream, proxy: Arc) -> Result<(), Box> { info!("New connection from {:?}", inbound.peer_addr()?); let upstream_name = match proxy.tls { @@ -72,51 +69,9 @@ async fn accept(inbound: TcpStream, proxy: Arc) -> Result<(), Box Result<(), Box> { - match upstream { - Upstream::Ban => { - inbound.shutdown().await?; - } - Upstream::Echo => { - let (mut ri, mut wi) = io::split(inbound); - let inbound_to_inbound = copy(&mut ri, &mut wi); - let bytes_tx = inbound_to_inbound.await; - debug!("Bytes read: {:?}", bytes_tx); - } - Upstream::Proxy(config) => { - let outbound = match config.protocol.as_ref() { - "tcp4" | "tcp6" | "tcp" => { - TcpStream::connect(config.resolve_addresses().await?.as_slice()).await? - } - _ => { - error!("Reached unknown protocol: {:?}", config.protocol); - return Err("Reached unknown protocol".into()); - } - }; - - debug!("Connected to {:?}", outbound.peer_addr().unwrap()); - - let (mut ri, mut wi) = io::split(inbound); - let (mut ro, mut wo) = io::split(outbound); - - let inbound_to_outbound = copy(&mut ri, &mut wo); - let outbound_to_inbound = copy(&mut ro, &mut wi); - - let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?; - - debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx); - } - }; - Ok(()) + upstream.process(inbound).await } diff --git a/src/upstreams/mod.rs b/src/upstreams/mod.rs new file mode 100644 index 0000000..2d0b086 --- /dev/null +++ b/src/upstreams/mod.rs @@ -0,0 +1,51 @@ +mod proxy_to_upstream; + +use log::debug; +use serde::Deserialize; +use std::error::Error; +use tokio::io; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; + +pub use crate::upstreams::proxy_to_upstream::ProxyToUpstream; + +#[derive(Debug, Clone, Deserialize)] +pub enum Upstream { + Ban, + Echo, + Proxy(ProxyToUpstream), +} + +impl Upstream { + pub(crate) async fn process(&self, mut inbound: TcpStream) -> Result<(), Box> { + match self { + Upstream::Ban => { + inbound.shutdown().await?; + } + Upstream::Echo => { + let (mut ri, mut wi) = io::split(inbound); + let inbound_to_inbound = copy(&mut ri, &mut wi); + let bytes_tx = inbound_to_inbound.await; + debug!("Bytes read: {:?}", bytes_tx); + } + Upstream::Proxy(config) => { + config.proxy(inbound).await?; + } + }; + Ok(()) + } +} + +async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result +where + R: AsyncRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + match io::copy(reader, writer).await { + Ok(u64) => { + let _ = writer.shutdown().await; + Ok(u64) + } + Err(_) => Ok(0), + } +} diff --git a/src/upstreams/proxy_to_upstream.rs b/src/upstreams/proxy_to_upstream.rs new file mode 100644 index 0000000..4ca77a5 --- /dev/null +++ b/src/upstreams/proxy_to_upstream.rs @@ -0,0 +1,68 @@ +use crate::servers::upstream_address::UpstreamAddress; + +use crate::upstreams::copy; +use futures::future::try_join; +use log::{debug, error}; +use serde::Deserialize; +use std::net::SocketAddr; +use tokio::io; +use tokio::net::TcpStream; +use tokio::sync::Mutex; + +#[derive(Debug, Default)] +struct Addr(Mutex); + +impl Clone for Addr { + fn clone(&self) -> Self { + tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone()))) + } +} + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct ProxyToUpstream { + pub addr: String, + pub protocol: String, + #[serde(skip_deserializing)] + addresses: Addr, +} + +impl ProxyToUpstream { + pub async fn resolve_addresses(&self) -> std::io::Result> { + let mut addr = self.addresses.0.lock().await; + addr.resolve((*self.protocol).into()).await + } + + pub fn new(address: String, protocol: String) -> Self { + Self { + addr: address.clone(), + protocol, + addresses: Addr(Mutex::new(UpstreamAddress::new(address))), + } + } + + pub(crate) async fn proxy(&self, inbound: TcpStream) -> Result<(), Box> { + let outbound = match self.protocol.as_ref() { + "tcp4" | "tcp6" | "tcp" => { + TcpStream::connect(self.resolve_addresses().await?.as_slice()).await? + } + _ => { + error!("Reached unknown protocol: {:?}", self.protocol); + return Err("Reached unknown protocol".into()); + } + }; + + debug!("Connected to {:?}", outbound.peer_addr().unwrap()); + + let (mut ri, mut wi) = io::split(inbound); + let (mut ro, mut wo) = io::split(outbound); + + let inbound_to_outbound = copy(&mut ri, &mut wo); + let outbound_to_inbound = copy(&mut ro, &mut wi); + + let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?; + + debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx); + + Ok(()) + } +}