diff --git a/src/servers/upstream_address.rs b/src/servers/upstream_address.rs index 0e24124..6b2808c 100644 --- a/src/servers/upstream_address.rs +++ b/src/servers/upstream_address.rs @@ -2,14 +2,16 @@ use log::debug; use std::fmt::{Display, Formatter}; use std::io::Result; use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::RwLock; use time::{Duration, Instant, OffsetDateTime}; #[derive(Debug, Clone, Default)] pub(crate) struct UpstreamAddress { address: String, - resolved_addresses: Vec, - resolved_time: Option, - ttl: Option, + resolved_addresses: Arc>>, + resolved_time: Arc>>, + ttl: Arc>>, } impl Display for UpstreamAddress { @@ -27,8 +29,10 @@ impl UpstreamAddress { } pub fn is_valid(&self) -> bool { - if let Some(resolved) = self.resolved_time { - if let Some(ttl) = self.ttl { + let r = { *self.resolved_time.read().unwrap() }; + + if let Some(resolved) = r { + if let Some(ttl) = { *self.ttl.read().unwrap() } { return resolved.elapsed() < ttl; } } @@ -37,7 +41,7 @@ impl UpstreamAddress { } fn is_resolved(&self) -> bool { - !self.resolved_addresses.is_empty() + !self.resolved_addresses.read().unwrap().is_empty() } fn time_remaining(&self) -> Duration { @@ -45,17 +49,19 @@ impl UpstreamAddress { return Duration::seconds(0); } - self.ttl.unwrap() - self.resolved_time.unwrap().elapsed() + let rt = { *self.resolved_time.read().unwrap() }; + let ttl = { *self.ttl.read().unwrap() }; + ttl.unwrap() - rt.unwrap().elapsed() } - pub async fn resolve(&mut self, mode: ResolutionMode) -> Result> { + pub async fn resolve(&self, mode: ResolutionMode) -> Result> { if self.is_resolved() && self.is_valid() { debug!( "Already got address {:?}, still valid for {:.3}s", &self.resolved_addresses, self.time_remaining().as_seconds_f64() ); - return Ok(self.resolved_addresses.clone()); + return Ok(self.resolved_addresses.read().unwrap().clone()); } debug!( @@ -70,8 +76,8 @@ impl UpstreamAddress { Err(e) => { debug!("Failed looking up {}: {}", &self.address, &e); // Protect against DNS flooding. Cache the result for 1 second. - self.resolved_time = Some(Instant::now()); - self.ttl = Some(Duration::seconds(3)); + *self.resolved_time.write().unwrap() = Some(Instant::now()); + *self.ttl.write().unwrap() = Some(Duration::seconds(3)); return Err(e); } }; @@ -103,11 +109,11 @@ impl UpstreamAddress { .expect("Format") ); - self.resolved_addresses = addresses; - self.resolved_time = Some(Instant::now()); - self.ttl = Some(Duration::minutes(1)); + *self.resolved_addresses.write().unwrap() = addresses.clone(); + *self.resolved_time.write().unwrap() = Some(Instant::now()); + *self.ttl.write().unwrap() = Some(Duration::minutes(1)); - Ok(self.resolved_addresses.clone()) + Ok(addresses) } } diff --git a/src/upstreams/proxy_to_upstream.rs b/src/upstreams/proxy_to_upstream.rs index 4ca77a5..b60c31c 100644 --- a/src/upstreams/proxy_to_upstream.rs +++ b/src/upstreams/proxy_to_upstream.rs @@ -7,36 +7,25 @@ use serde::Deserialize; use std::net::SocketAddr; use tokio::io; use tokio::net::TcpStream; -use tokio::sync::Mutex; - -#[derive(Debug, Default)] -struct Addr(Mutex); - -impl Clone for Addr { - fn clone(&self) -> Self { - tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone()))) - } -} #[derive(Debug, Clone, Deserialize, Default)] pub struct ProxyToUpstream { pub addr: String, pub protocol: String, #[serde(skip_deserializing)] - addresses: Addr, + addresses: UpstreamAddress, } impl ProxyToUpstream { pub async fn resolve_addresses(&self) -> std::io::Result> { - let mut addr = self.addresses.0.lock().await; - addr.resolve((*self.protocol).into()).await + self.addresses.resolve((*self.protocol).into()).await } pub fn new(address: String, protocol: String) -> Self { Self { addr: address.clone(), protocol, - addresses: Addr(Mutex::new(UpstreamAddress::new(address))), + addresses: UpstreamAddress::new(address), } }