Extract DNS address resolution

Signed-off-by: Jacob Kiers <code@kiers.eu>
This commit is contained in:
Jacob Kiers 2023-08-16 09:32:05 +02:00
parent 0c5153bbd6
commit 915e39b684
4 changed files with 126 additions and 45 deletions

View File

@ -1,3 +1,4 @@
use crate::servers::upstream_address::UpstreamAddress;
use log::{debug, warn}; use log::{debug, warn};
use serde::Deserialize; use serde::Deserialize;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
@ -6,8 +7,6 @@ use std::io::{Error as IOError, Read};
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use url::Url; use url::Url;
use tokio::time::Instant;
use time::OffsetDateTime;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Config { pub struct Config {
@ -47,7 +46,7 @@ pub enum Upstream {
} }
#[derive(Debug)] #[derive(Debug)]
struct Addr(Mutex<Vec<SocketAddr>>); struct Addr(Mutex<UpstreamAddress>);
impl Default for Addr { impl Default for Addr {
fn default() -> Self { fn default() -> Self {
@ -71,38 +70,9 @@ pub struct CustomUpstream {
} }
impl CustomUpstream { impl CustomUpstream {
pub async fn resolve_addresses(&self) -> std::io::Result<()> { pub async fn resolve_addresses(&self) -> std::io::Result<Vec<SocketAddr>> {
{ let mut addr = self.addresses.0.lock().await;
let addr = self.addresses.0.lock().await; addr.resolve((*self.protocol).into()).await
if addr.len() > 0 {
debug!("Already have addresses: {:?}", &addr);
return Ok(());
}
}
debug!("Resolving addresses for {}", &self.addr);
let addresses = tokio::net::lookup_host(self.addr.clone()).await?;
let mut addr: Vec<SocketAddr> = match self.protocol.as_ref() {
"tcp4" => addresses.into_iter().filter(|a| a.is_ipv4()).collect(),
"tcp6" => addresses.into_iter().filter(|a| a.is_ipv6()).collect(),
_ => addresses.collect(),
};
debug!("Got addresses for {}: {:?}", &self.addr, &addr);
debug!("Resolved at {}", OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339).expect("Format"));
{
let mut self_addr = self.addresses.0.lock().await;
self_addr.clear();
self_addr.append(&mut addr);
}
Ok(())
}
pub async fn get_addresses(&self) -> Vec<SocketAddr> {
let a = self.addresses.0.lock().await;
a.clone()
} }
} }

View File

@ -5,17 +5,19 @@ use std::sync::Arc;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
mod protocol; mod protocol;
pub(crate) mod upstream_address;
use crate::config::{ParsedConfig, Upstream}; use crate::config::{ParsedConfig, Upstream};
use protocol::tcp; use protocol::tcp;
#[derive(Debug)] #[derive(Debug)]
pub struct Server { pub(crate) struct Server {
pub proxies: Vec<Arc<Proxy>>, pub proxies: Vec<Arc<Proxy>>,
pub config: ParsedConfig, pub config: ParsedConfig,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Proxy { pub(crate) struct Proxy {
pub name: String, pub name: String,
pub listen: SocketAddr, pub listen: SocketAddr,
pub protocol: String, pub protocol: String,

View File

@ -8,7 +8,7 @@ use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> { pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(config.listen).await?; let listener = TcpListener::bind(config.listen).await?;
let config = config.clone(); let config = config.clone();
@ -81,11 +81,6 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std
} }
}; };
match upstream {
Upstream::Custom(u) => u.resolve_addresses().await?,
_ => {}
}
return process(inbound, upstream.clone()).await; return process(inbound, upstream.clone()).await;
} }
@ -104,10 +99,9 @@ async fn process(
debug!("Bytes read: {:?}", bytes_tx); debug!("Bytes read: {:?}", bytes_tx);
} }
Upstream::Custom(custom) => { Upstream::Custom(custom) => {
custom.resolve_addresses().await?;
let outbound = match custom.protocol.as_ref() { let outbound = match custom.protocol.as_ref() {
"tcp4" | "tcp6" | "tcp" => { "tcp4" | "tcp6" | "tcp" => {
TcpStream::connect(custom.get_addresses().await.as_slice()).await? TcpStream::connect(custom.resolve_addresses().await?.as_slice()).await?
} }
_ => { _ => {
error!("Reached unknown protocol: {:?}", custom.protocol); error!("Reached unknown protocol: {:?}", custom.protocol);

View File

@ -0,0 +1,115 @@
use log::debug;
use std::fmt::{Display, Formatter};
use std::io::Result;
use std::net::SocketAddr;
use time::{Duration, Instant, OffsetDateTime};
#[derive(Debug, Clone, Default)]
pub(crate) struct UpstreamAddress {
address: String,
resolved_addresses: Vec<SocketAddr>,
resolved_time: Option<Instant>,
ttl: Option<Duration>,
}
impl Display for UpstreamAddress {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.address.fmt(f)
}
}
impl UpstreamAddress {
pub fn is_valid(&self) -> bool {
if let Some(resolved) = self.resolved_time {
if let Some(ttl) = self.ttl {
return resolved.elapsed() < ttl;
}
}
false
}
fn is_resolved(&self) -> bool {
self.resolved_addresses.len() > 0
}
fn time_remaining(&self) -> Duration {
if !self.is_valid() {
return Duration::seconds(0);
}
self.ttl.unwrap() - self.resolved_time.unwrap().elapsed()
}
pub async fn resolve(&mut self, mode: ResolutionMode) -> Result<Vec<SocketAddr>> {
if self.is_resolved() && self.is_valid() {
debug!(
"Already got address {:?}, still valid for {}",
&self.resolved_addresses,
self.time_remaining()
);
return Ok(self.resolved_addresses.clone());
}
debug!("Resolving addresses for {}", &self.address);
let lookup_result = tokio::net::lookup_host(&self.address).await;
let resolved_addresses = match lookup_result {
Ok(resolved_addresses) => resolved_addresses,
Err(e) => {
// Protect against DNS flooding. Cache the result for 1 second.
self.resolved_time = Some(Instant::now());
self.ttl = Some(Duration::seconds(3));
return Err(e);
}
};
let addresses: Vec<SocketAddr> = match mode {
ResolutionMode::Ipv4 => resolved_addresses
.into_iter()
.filter(|a| a.is_ipv4())
.collect(),
ResolutionMode::Ipv6 => resolved_addresses
.into_iter()
.filter(|a| a.is_ipv6())
.collect(),
_ => resolved_addresses.collect(),
};
debug!("Got addresses for {}: {:?}", &self.address, &addresses);
debug!(
"Resolved at {}",
OffsetDateTime::now_utc()
.format(&time::format_description::well_known::Rfc3339)
.expect("Format")
);
self.resolved_addresses = addresses;
self.resolved_time = Some(Instant::now());
self.ttl = Some(Duration::minutes(1));
Ok(self.resolved_addresses.clone())
}
}
#[derive(Debug, Default, Clone)]
pub(crate) enum ResolutionMode {
#[default]
Ipv4AndIpv6,
Ipv4,
Ipv6,
}
impl From<&str> for ResolutionMode {
fn from(value: &str) -> Self {
match value {
"tcp4" => ResolutionMode::Ipv4,
"tcp6" => ResolutionMode::Ipv6,
"tcp" => ResolutionMode::Ipv4AndIpv6,
_ => panic!("This should never happen. Please check configuration parser."),
}
}
}