Solve synchronization issue
The async mutex in the previous variant would fail when used in a single threaded mode, because block_in_place() cannot be used there. Instead, replace the code with a Arc<RwLock> inside of the UpstreamAddress to let that class take care of its own mutability. Signed-off-by: Jacob Kiers <code@kiers.eu>
This commit is contained in:
parent
59c7128f93
commit
97b4bf6bbe
@ -2,14 +2,16 @@ use log::debug;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::io::Result;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use time::{Duration, Instant, OffsetDateTime};
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub(crate) struct UpstreamAddress {
|
||||
address: String,
|
||||
resolved_addresses: Vec<SocketAddr>,
|
||||
resolved_time: Option<Instant>,
|
||||
ttl: Option<Duration>,
|
||||
resolved_addresses: Arc<RwLock<Vec<SocketAddr>>>,
|
||||
resolved_time: Arc<RwLock<Option<Instant>>>,
|
||||
ttl: Arc<RwLock<Option<Duration>>>,
|
||||
}
|
||||
|
||||
impl Display for UpstreamAddress {
|
||||
@ -27,8 +29,10 @@ impl UpstreamAddress {
|
||||
}
|
||||
|
||||
pub fn is_valid(&self) -> bool {
|
||||
if let Some(resolved) = self.resolved_time {
|
||||
if let Some(ttl) = self.ttl {
|
||||
let r = { *self.resolved_time.read().unwrap() };
|
||||
|
||||
if let Some(resolved) = r {
|
||||
if let Some(ttl) = { *self.ttl.read().unwrap() } {
|
||||
return resolved.elapsed() < ttl;
|
||||
}
|
||||
}
|
||||
@ -37,7 +41,7 @@ impl UpstreamAddress {
|
||||
}
|
||||
|
||||
fn is_resolved(&self) -> bool {
|
||||
!self.resolved_addresses.is_empty()
|
||||
!self.resolved_addresses.read().unwrap().is_empty()
|
||||
}
|
||||
|
||||
fn time_remaining(&self) -> Duration {
|
||||
@ -45,17 +49,19 @@ impl UpstreamAddress {
|
||||
return Duration::seconds(0);
|
||||
}
|
||||
|
||||
self.ttl.unwrap() - self.resolved_time.unwrap().elapsed()
|
||||
let rt = { *self.resolved_time.read().unwrap() };
|
||||
let ttl = { *self.ttl.read().unwrap() };
|
||||
ttl.unwrap() - rt.unwrap().elapsed()
|
||||
}
|
||||
|
||||
pub async fn resolve(&mut self, mode: ResolutionMode) -> Result<Vec<SocketAddr>> {
|
||||
pub async fn resolve(&self, mode: ResolutionMode) -> Result<Vec<SocketAddr>> {
|
||||
if self.is_resolved() && self.is_valid() {
|
||||
debug!(
|
||||
"Already got address {:?}, still valid for {:.3}s",
|
||||
&self.resolved_addresses,
|
||||
self.time_remaining().as_seconds_f64()
|
||||
);
|
||||
return Ok(self.resolved_addresses.clone());
|
||||
return Ok(self.resolved_addresses.read().unwrap().clone());
|
||||
}
|
||||
|
||||
debug!(
|
||||
@ -70,8 +76,8 @@ impl UpstreamAddress {
|
||||
Err(e) => {
|
||||
debug!("Failed looking up {}: {}", &self.address, &e);
|
||||
// Protect against DNS flooding. Cache the result for 1 second.
|
||||
self.resolved_time = Some(Instant::now());
|
||||
self.ttl = Some(Duration::seconds(3));
|
||||
*self.resolved_time.write().unwrap() = Some(Instant::now());
|
||||
*self.ttl.write().unwrap() = Some(Duration::seconds(3));
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
@ -103,11 +109,11 @@ impl UpstreamAddress {
|
||||
.expect("Format")
|
||||
);
|
||||
|
||||
self.resolved_addresses = addresses;
|
||||
self.resolved_time = Some(Instant::now());
|
||||
self.ttl = Some(Duration::minutes(1));
|
||||
*self.resolved_addresses.write().unwrap() = addresses.clone();
|
||||
*self.resolved_time.write().unwrap() = Some(Instant::now());
|
||||
*self.ttl.write().unwrap() = Some(Duration::minutes(1));
|
||||
|
||||
Ok(self.resolved_addresses.clone())
|
||||
Ok(addresses)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,36 +7,25 @@ use serde::Deserialize;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct Addr(Mutex<UpstreamAddress>);
|
||||
|
||||
impl Clone for Addr {
|
||||
fn clone(&self) -> Self {
|
||||
tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone())))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct ProxyToUpstream {
|
||||
pub addr: String,
|
||||
pub protocol: String,
|
||||
#[serde(skip_deserializing)]
|
||||
addresses: Addr,
|
||||
addresses: UpstreamAddress,
|
||||
}
|
||||
|
||||
impl ProxyToUpstream {
|
||||
pub async fn resolve_addresses(&self) -> std::io::Result<Vec<SocketAddr>> {
|
||||
let mut addr = self.addresses.0.lock().await;
|
||||
addr.resolve((*self.protocol).into()).await
|
||||
self.addresses.resolve((*self.protocol).into()).await
|
||||
}
|
||||
|
||||
pub fn new(address: String, protocol: String) -> Self {
|
||||
Self {
|
||||
addr: address.clone(),
|
||||
protocol,
|
||||
addresses: Addr(Mutex::new(UpstreamAddress::new(address))),
|
||||
addresses: UpstreamAddress::new(address),
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user