Updates #1
@ -1,3 +1,4 @@
|
|||||||
|
use crate::servers::upstream_address::UpstreamAddress;
|
||||||
use log::{debug, warn};
|
use log::{debug, warn};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
@ -6,8 +7,6 @@ use std::io::{Error as IOError, Read};
|
|||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
use tokio::time::Instant;
|
|
||||||
use time::OffsetDateTime;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
@ -47,7 +46,7 @@ pub enum Upstream {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Addr(Mutex<Vec<SocketAddr>>);
|
struct Addr(Mutex<UpstreamAddress>);
|
||||||
|
|
||||||
impl Default for Addr {
|
impl Default for Addr {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
@ -71,38 +70,9 @@ pub struct CustomUpstream {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl CustomUpstream {
|
impl CustomUpstream {
|
||||||
pub async fn resolve_addresses(&self) -> std::io::Result<()> {
|
pub async fn resolve_addresses(&self) -> std::io::Result<Vec<SocketAddr>> {
|
||||||
{
|
let mut addr = self.addresses.0.lock().await;
|
||||||
let addr = self.addresses.0.lock().await;
|
addr.resolve((*self.protocol).into()).await
|
||||||
if addr.len() > 0 {
|
|
||||||
debug!("Already have addresses: {:?}", &addr);
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
debug!("Resolving addresses for {}", &self.addr);
|
|
||||||
let addresses = tokio::net::lookup_host(self.addr.clone()).await?;
|
|
||||||
|
|
||||||
let mut addr: Vec<SocketAddr> = match self.protocol.as_ref() {
|
|
||||||
"tcp4" => addresses.into_iter().filter(|a| a.is_ipv4()).collect(),
|
|
||||||
"tcp6" => addresses.into_iter().filter(|a| a.is_ipv6()).collect(),
|
|
||||||
_ => addresses.collect(),
|
|
||||||
};
|
|
||||||
|
|
||||||
debug!("Got addresses for {}: {:?}", &self.addr, &addr);
|
|
||||||
debug!("Resolved at {}", OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339).expect("Format"));
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut self_addr = self.addresses.0.lock().await;
|
|
||||||
self_addr.clear();
|
|
||||||
self_addr.append(&mut addr);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_addresses(&self) -> Vec<SocketAddr> {
|
|
||||||
let a = self.addresses.0.lock().await;
|
|
||||||
a.clone()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,17 +5,19 @@ use std::sync::Arc;
|
|||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
mod protocol;
|
mod protocol;
|
||||||
|
pub(crate) mod upstream_address;
|
||||||
|
|
||||||
use crate::config::{ParsedConfig, Upstream};
|
use crate::config::{ParsedConfig, Upstream};
|
||||||
use protocol::tcp;
|
use protocol::tcp;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Server {
|
pub(crate) struct Server {
|
||||||
pub proxies: Vec<Arc<Proxy>>,
|
pub proxies: Vec<Arc<Proxy>>,
|
||||||
pub config: ParsedConfig,
|
pub config: ParsedConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Proxy {
|
pub(crate) struct Proxy {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub listen: SocketAddr,
|
pub listen: SocketAddr,
|
||||||
pub protocol: String,
|
pub protocol: String,
|
||||||
|
@ -8,7 +8,7 @@ use tokio::io;
|
|||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
|
||||||
pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
|
pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let listener = TcpListener::bind(config.listen).await?;
|
let listener = TcpListener::bind(config.listen).await?;
|
||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
|
|
||||||
@ -81,11 +81,6 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match upstream {
|
|
||||||
Upstream::Custom(u) => u.resolve_addresses().await?,
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
return process(inbound, upstream.clone()).await;
|
return process(inbound, upstream.clone()).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,10 +99,9 @@ async fn process(
|
|||||||
debug!("Bytes read: {:?}", bytes_tx);
|
debug!("Bytes read: {:?}", bytes_tx);
|
||||||
}
|
}
|
||||||
Upstream::Custom(custom) => {
|
Upstream::Custom(custom) => {
|
||||||
custom.resolve_addresses().await?;
|
|
||||||
let outbound = match custom.protocol.as_ref() {
|
let outbound = match custom.protocol.as_ref() {
|
||||||
"tcp4" | "tcp6" | "tcp" => {
|
"tcp4" | "tcp6" | "tcp" => {
|
||||||
TcpStream::connect(custom.get_addresses().await.as_slice()).await?
|
TcpStream::connect(custom.resolve_addresses().await?.as_slice()).await?
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
error!("Reached unknown protocol: {:?}", custom.protocol);
|
error!("Reached unknown protocol: {:?}", custom.protocol);
|
||||||
|
115
src/servers/upstream_address.rs
Normal file
115
src/servers/upstream_address.rs
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
use log::debug;
|
||||||
|
use std::fmt::{Display, Formatter};
|
||||||
|
use std::io::Result;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use time::{Duration, Instant, OffsetDateTime};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub(crate) struct UpstreamAddress {
|
||||||
|
address: String,
|
||||||
|
resolved_addresses: Vec<SocketAddr>,
|
||||||
|
resolved_time: Option<Instant>,
|
||||||
|
ttl: Option<Duration>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for UpstreamAddress {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
self.address.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamAddress {
|
||||||
|
pub fn is_valid(&self) -> bool {
|
||||||
|
if let Some(resolved) = self.resolved_time {
|
||||||
|
if let Some(ttl) = self.ttl {
|
||||||
|
return resolved.elapsed() < ttl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_resolved(&self) -> bool {
|
||||||
|
self.resolved_addresses.len() > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn time_remaining(&self) -> Duration {
|
||||||
|
if !self.is_valid() {
|
||||||
|
return Duration::seconds(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.ttl.unwrap() - self.resolved_time.unwrap().elapsed()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn resolve(&mut self, mode: ResolutionMode) -> Result<Vec<SocketAddr>> {
|
||||||
|
if self.is_resolved() && self.is_valid() {
|
||||||
|
debug!(
|
||||||
|
"Already got address {:?}, still valid for {}",
|
||||||
|
&self.resolved_addresses,
|
||||||
|
self.time_remaining()
|
||||||
|
);
|
||||||
|
return Ok(self.resolved_addresses.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!("Resolving addresses for {}", &self.address);
|
||||||
|
|
||||||
|
let lookup_result = tokio::net::lookup_host(&self.address).await;
|
||||||
|
|
||||||
|
let resolved_addresses = match lookup_result {
|
||||||
|
Ok(resolved_addresses) => resolved_addresses,
|
||||||
|
Err(e) => {
|
||||||
|
// Protect against DNS flooding. Cache the result for 1 second.
|
||||||
|
self.resolved_time = Some(Instant::now());
|
||||||
|
self.ttl = Some(Duration::seconds(3));
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let addresses: Vec<SocketAddr> = match mode {
|
||||||
|
ResolutionMode::Ipv4 => resolved_addresses
|
||||||
|
.into_iter()
|
||||||
|
.filter(|a| a.is_ipv4())
|
||||||
|
.collect(),
|
||||||
|
|
||||||
|
ResolutionMode::Ipv6 => resolved_addresses
|
||||||
|
.into_iter()
|
||||||
|
.filter(|a| a.is_ipv6())
|
||||||
|
.collect(),
|
||||||
|
|
||||||
|
_ => resolved_addresses.collect(),
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("Got addresses for {}: {:?}", &self.address, &addresses);
|
||||||
|
debug!(
|
||||||
|
"Resolved at {}",
|
||||||
|
OffsetDateTime::now_utc()
|
||||||
|
.format(&time::format_description::well_known::Rfc3339)
|
||||||
|
.expect("Format")
|
||||||
|
);
|
||||||
|
|
||||||
|
self.resolved_addresses = addresses;
|
||||||
|
self.resolved_time = Some(Instant::now());
|
||||||
|
self.ttl = Some(Duration::minutes(1));
|
||||||
|
|
||||||
|
Ok(self.resolved_addresses.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
|
pub(crate) enum ResolutionMode {
|
||||||
|
#[default]
|
||||||
|
Ipv4AndIpv6,
|
||||||
|
Ipv4,
|
||||||
|
Ipv6,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for ResolutionMode {
|
||||||
|
fn from(value: &str) -> Self {
|
||||||
|
match value {
|
||||||
|
"tcp4" => ResolutionMode::Ipv4,
|
||||||
|
"tcp6" => ResolutionMode::Ipv6,
|
||||||
|
"tcp" => ResolutionMode::Ipv4AndIpv6,
|
||||||
|
_ => panic!("This should never happen. Please check configuration parser."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user