Solve synchronization issue

The async mutex in the previous variant would fail when used in a single
threaded mode, because block_in_place() cannot be used there.

Instead, replace the code with a Arc<RwLock> inside of the
UpstreamAddress to let that class take care of its own mutability.

Signed-off-by: Jacob Kiers <code@kiers.eu>
This commit is contained in:
Jacob Kiers 2024-02-23 22:56:28 +01:00
parent 59c7128f93
commit 97b4bf6bbe
2 changed files with 24 additions and 29 deletions

View File

@ -2,14 +2,16 @@ use log::debug;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::io::Result; use std::io::Result;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::RwLock;
use time::{Duration, Instant, OffsetDateTime}; use time::{Duration, Instant, OffsetDateTime};
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub(crate) struct UpstreamAddress { pub(crate) struct UpstreamAddress {
address: String, address: String,
resolved_addresses: Vec<SocketAddr>, resolved_addresses: Arc<RwLock<Vec<SocketAddr>>>,
resolved_time: Option<Instant>, resolved_time: Arc<RwLock<Option<Instant>>>,
ttl: Option<Duration>, ttl: Arc<RwLock<Option<Duration>>>,
} }
impl Display for UpstreamAddress { impl Display for UpstreamAddress {
@ -27,8 +29,10 @@ impl UpstreamAddress {
} }
pub fn is_valid(&self) -> bool { pub fn is_valid(&self) -> bool {
if let Some(resolved) = self.resolved_time { let r = { *self.resolved_time.read().unwrap() };
if let Some(ttl) = self.ttl {
if let Some(resolved) = r {
if let Some(ttl) = { *self.ttl.read().unwrap() } {
return resolved.elapsed() < ttl; return resolved.elapsed() < ttl;
} }
} }
@ -37,7 +41,7 @@ impl UpstreamAddress {
} }
fn is_resolved(&self) -> bool { fn is_resolved(&self) -> bool {
!self.resolved_addresses.is_empty() !self.resolved_addresses.read().unwrap().is_empty()
} }
fn time_remaining(&self) -> Duration { fn time_remaining(&self) -> Duration {
@ -45,17 +49,19 @@ impl UpstreamAddress {
return Duration::seconds(0); return Duration::seconds(0);
} }
self.ttl.unwrap() - self.resolved_time.unwrap().elapsed() let rt = { *self.resolved_time.read().unwrap() };
let ttl = { *self.ttl.read().unwrap() };
ttl.unwrap() - rt.unwrap().elapsed()
} }
pub async fn resolve(&mut self, mode: ResolutionMode) -> Result<Vec<SocketAddr>> { pub async fn resolve(&self, mode: ResolutionMode) -> Result<Vec<SocketAddr>> {
if self.is_resolved() && self.is_valid() { if self.is_resolved() && self.is_valid() {
debug!( debug!(
"Already got address {:?}, still valid for {:.3}s", "Already got address {:?}, still valid for {:.3}s",
&self.resolved_addresses, &self.resolved_addresses,
self.time_remaining().as_seconds_f64() self.time_remaining().as_seconds_f64()
); );
return Ok(self.resolved_addresses.clone()); return Ok(self.resolved_addresses.read().unwrap().clone());
} }
debug!( debug!(
@ -70,8 +76,8 @@ impl UpstreamAddress {
Err(e) => { Err(e) => {
debug!("Failed looking up {}: {}", &self.address, &e); debug!("Failed looking up {}: {}", &self.address, &e);
// Protect against DNS flooding. Cache the result for 1 second. // Protect against DNS flooding. Cache the result for 1 second.
self.resolved_time = Some(Instant::now()); *self.resolved_time.write().unwrap() = Some(Instant::now());
self.ttl = Some(Duration::seconds(3)); *self.ttl.write().unwrap() = Some(Duration::seconds(3));
return Err(e); return Err(e);
} }
}; };
@ -103,11 +109,11 @@ impl UpstreamAddress {
.expect("Format") .expect("Format")
); );
self.resolved_addresses = addresses; *self.resolved_addresses.write().unwrap() = addresses.clone();
self.resolved_time = Some(Instant::now()); *self.resolved_time.write().unwrap() = Some(Instant::now());
self.ttl = Some(Duration::minutes(1)); *self.ttl.write().unwrap() = Some(Duration::minutes(1));
Ok(self.resolved_addresses.clone()) Ok(addresses)
} }
} }

View File

@ -7,36 +7,25 @@ use serde::Deserialize;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io; use tokio::io;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::Mutex;
#[derive(Debug, Default)]
struct Addr(Mutex<UpstreamAddress>);
impl Clone for Addr {
fn clone(&self) -> Self {
tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone())))
}
}
#[derive(Debug, Clone, Deserialize, Default)] #[derive(Debug, Clone, Deserialize, Default)]
pub struct ProxyToUpstream { pub struct ProxyToUpstream {
pub addr: String, pub addr: String,
pub protocol: String, pub protocol: String,
#[serde(skip_deserializing)] #[serde(skip_deserializing)]
addresses: Addr, addresses: UpstreamAddress,
} }
impl ProxyToUpstream { impl ProxyToUpstream {
pub async fn resolve_addresses(&self) -> std::io::Result<Vec<SocketAddr>> { pub async fn resolve_addresses(&self) -> std::io::Result<Vec<SocketAddr>> {
let mut addr = self.addresses.0.lock().await; self.addresses.resolve((*self.protocol).into()).await
addr.resolve((*self.protocol).into()).await
} }
pub fn new(address: String, protocol: String) -> Self { pub fn new(address: String, protocol: String) -> Self {
Self { Self {
addr: address.clone(), addr: address.clone(),
protocol, protocol,
addresses: Addr(Mutex::new(UpstreamAddress::new(address))), addresses: UpstreamAddress::new(address),
} }
} }