diff --git a/src/config.rs b/src/config.rs index 3fb3252..64c7699 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,4 @@ +use crate::servers::upstream_address::UpstreamAddress; use log::{debug, warn}; use serde::Deserialize; use std::collections::{HashMap, HashSet}; @@ -6,8 +7,6 @@ use std::io::{Error as IOError, Read}; use std::net::SocketAddr; use tokio::sync::Mutex; use url::Url; -use tokio::time::Instant; -use time::OffsetDateTime; #[derive(Debug, Clone)] pub struct Config { @@ -47,7 +46,7 @@ pub enum Upstream { } #[derive(Debug)] -struct Addr(Mutex>); +struct Addr(Mutex); impl Default for Addr { fn default() -> Self { @@ -71,38 +70,9 @@ pub struct CustomUpstream { } impl CustomUpstream { - pub async fn resolve_addresses(&self) -> std::io::Result<()> { - { - let addr = self.addresses.0.lock().await; - if addr.len() > 0 { - debug!("Already have addresses: {:?}", &addr); - return Ok(()); - } - } - - debug!("Resolving addresses for {}", &self.addr); - let addresses = tokio::net::lookup_host(self.addr.clone()).await?; - - let mut addr: Vec = match self.protocol.as_ref() { - "tcp4" => addresses.into_iter().filter(|a| a.is_ipv4()).collect(), - "tcp6" => addresses.into_iter().filter(|a| a.is_ipv6()).collect(), - _ => addresses.collect(), - }; - - debug!("Got addresses for {}: {:?}", &self.addr, &addr); - debug!("Resolved at {}", OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339).expect("Format")); - - { - let mut self_addr = self.addresses.0.lock().await; - self_addr.clear(); - self_addr.append(&mut addr); - } - Ok(()) - } - - pub async fn get_addresses(&self) -> Vec { - let a = self.addresses.0.lock().await; - a.clone() + pub async fn resolve_addresses(&self) -> std::io::Result> { + let mut addr = self.addresses.0.lock().await; + addr.resolve((*self.protocol).into()).await } } diff --git a/src/servers/mod.rs b/src/servers/mod.rs index 01ffb18..bda4e79 100644 --- a/src/servers/mod.rs +++ b/src/servers/mod.rs @@ -5,17 +5,19 @@ use std::sync::Arc; use tokio::task::JoinHandle; mod protocol; +pub(crate) mod upstream_address; + use crate::config::{ParsedConfig, Upstream}; use protocol::tcp; #[derive(Debug)] -pub struct Server { +pub(crate) struct Server { pub proxies: Vec>, pub config: ParsedConfig, } #[derive(Debug, Clone)] -pub struct Proxy { +pub(crate) struct Proxy { pub name: String, pub listen: SocketAddr, pub protocol: String, diff --git a/src/servers/protocol/tcp.rs b/src/servers/protocol/tcp.rs index 13533f0..f33a7f8 100644 --- a/src/servers/protocol/tcp.rs +++ b/src/servers/protocol/tcp.rs @@ -8,7 +8,7 @@ use tokio::io; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -pub async fn proxy(config: Arc) -> Result<(), Box> { +pub(crate) async fn proxy(config: Arc) -> Result<(), Box> { let listener = TcpListener::bind(config.listen).await?; let config = config.clone(); @@ -81,11 +81,6 @@ async fn accept(inbound: TcpStream, proxy: Arc) -> Result<(), Box u.resolve_addresses().await?, - _ => {} - } - return process(inbound, upstream.clone()).await; } @@ -104,10 +99,9 @@ async fn process( debug!("Bytes read: {:?}", bytes_tx); } Upstream::Custom(custom) => { - custom.resolve_addresses().await?; let outbound = match custom.protocol.as_ref() { "tcp4" | "tcp6" | "tcp" => { - TcpStream::connect(custom.get_addresses().await.as_slice()).await? + TcpStream::connect(custom.resolve_addresses().await?.as_slice()).await? } _ => { error!("Reached unknown protocol: {:?}", custom.protocol); diff --git a/src/servers/upstream_address.rs b/src/servers/upstream_address.rs new file mode 100644 index 0000000..3220a12 --- /dev/null +++ b/src/servers/upstream_address.rs @@ -0,0 +1,115 @@ +use log::debug; +use std::fmt::{Display, Formatter}; +use std::io::Result; +use std::net::SocketAddr; +use time::{Duration, Instant, OffsetDateTime}; + +#[derive(Debug, Clone, Default)] +pub(crate) struct UpstreamAddress { + address: String, + resolved_addresses: Vec, + resolved_time: Option, + ttl: Option, +} + +impl Display for UpstreamAddress { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.address.fmt(f) + } +} + +impl UpstreamAddress { + pub fn is_valid(&self) -> bool { + if let Some(resolved) = self.resolved_time { + if let Some(ttl) = self.ttl { + return resolved.elapsed() < ttl; + } + } + + false + } + + fn is_resolved(&self) -> bool { + self.resolved_addresses.len() > 0 + } + + fn time_remaining(&self) -> Duration { + if !self.is_valid() { + return Duration::seconds(0); + } + + self.ttl.unwrap() - self.resolved_time.unwrap().elapsed() + } + + pub async fn resolve(&mut self, mode: ResolutionMode) -> Result> { + if self.is_resolved() && self.is_valid() { + debug!( + "Already got address {:?}, still valid for {}", + &self.resolved_addresses, + self.time_remaining() + ); + return Ok(self.resolved_addresses.clone()); + } + + debug!("Resolving addresses for {}", &self.address); + + let lookup_result = tokio::net::lookup_host(&self.address).await; + + let resolved_addresses = match lookup_result { + Ok(resolved_addresses) => resolved_addresses, + Err(e) => { + // Protect against DNS flooding. Cache the result for 1 second. + self.resolved_time = Some(Instant::now()); + self.ttl = Some(Duration::seconds(3)); + return Err(e); + } + }; + + let addresses: Vec = match mode { + ResolutionMode::Ipv4 => resolved_addresses + .into_iter() + .filter(|a| a.is_ipv4()) + .collect(), + + ResolutionMode::Ipv6 => resolved_addresses + .into_iter() + .filter(|a| a.is_ipv6()) + .collect(), + + _ => resolved_addresses.collect(), + }; + + debug!("Got addresses for {}: {:?}", &self.address, &addresses); + debug!( + "Resolved at {}", + OffsetDateTime::now_utc() + .format(&time::format_description::well_known::Rfc3339) + .expect("Format") + ); + + self.resolved_addresses = addresses; + self.resolved_time = Some(Instant::now()); + self.ttl = Some(Duration::minutes(1)); + + Ok(self.resolved_addresses.clone()) + } +} + +#[derive(Debug, Default, Clone)] +pub(crate) enum ResolutionMode { + #[default] + Ipv4AndIpv6, + Ipv4, + Ipv6, +} + +impl From<&str> for ResolutionMode { + fn from(value: &str) -> Self { + match value { + "tcp4" => ResolutionMode::Ipv4, + "tcp6" => ResolutionMode::Ipv6, + "tcp" => ResolutionMode::Ipv4AndIpv6, + _ => panic!("This should never happen. Please check configuration parser."), + } + } +}