Moved upstreams to their own dedicated namespace

Signed-off-by: Jacob Kiers <code@kiers.eu>
This commit is contained in:
Jacob Kiers 2023-10-05 00:23:34 +02:00
parent 2116659a14
commit 3a2367ef28
6 changed files with 143 additions and 117 deletions

View File

@ -1,11 +1,10 @@
use crate::servers::upstream_address::UpstreamAddress; use crate::upstreams::ProxyToUpstream;
use crate::upstreams::Upstream;
use log::{debug, warn}; use log::{debug, warn};
use serde::Deserialize; use serde::Deserialize;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fs::File; use std::fs::File;
use std::io::{Error as IOError, Read}; use std::io::{Error as IOError, Read};
use std::net::SocketAddr;
use tokio::sync::Mutex;
use url::Url; use url::Url;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -37,48 +36,16 @@ pub struct ServerConfig {
pub sni: Option<HashMap<String, String>>, pub sni: Option<HashMap<String, String>>,
pub default: Option<String>, pub default: Option<String>,
} }
impl TryInto<ProxyToUpstream> for &str {
#[derive(Debug, Clone, Deserialize)]
pub enum Upstream {
Ban,
Echo,
Proxy(ProxyToUpstream),
}
#[derive(Debug, Default)]
struct Addr(Mutex<UpstreamAddress>);
impl Clone for Addr {
fn clone(&self) -> Self {
tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone())))
}
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProxyToUpstream {
pub addr: String,
pub protocol: String,
#[serde(skip_deserializing)]
addresses: Addr,
}
impl ProxyToUpstream {
pub async fn resolve_addresses(&self) -> std::io::Result<Vec<SocketAddr>> {
let mut addr = self.addresses.0.lock().await;
addr.resolve((*self.protocol).into()).await
}
}
impl TryFrom<&str> for ProxyToUpstream {
type Error = ConfigError; type Error = ConfigError;
fn try_from(upstream: &str) -> Result<Self, Self::Error> { fn try_into(self) -> Result<ProxyToUpstream, Self::Error> {
let upstream_url = match Url::parse(upstream) { let upstream_url = match Url::parse(self) {
Ok(url) => url, Ok(url) => url,
Err(_) => { Err(_) => {
return Err(ConfigError::Custom(format!( return Err(ConfigError::Custom(format!(
"Invalid upstream url {}", "Invalid upstream url {}",
upstream self
))) )))
} }
}; };
@ -88,7 +55,7 @@ impl TryFrom<&str> for ProxyToUpstream {
None => { None => {
return Err(ConfigError::Custom(format!( return Err(ConfigError::Custom(format!(
"Invalid upstream url {}", "Invalid upstream url {}",
upstream self
))) )))
} }
}; };
@ -98,7 +65,7 @@ impl TryFrom<&str> for ProxyToUpstream {
None => { None => {
return Err(ConfigError::Custom(format!( return Err(ConfigError::Custom(format!(
"Invalid upstream url {}", "Invalid upstream url {}",
upstream self
))) )))
} }
}; };
@ -108,17 +75,15 @@ impl TryFrom<&str> for ProxyToUpstream {
_ => { _ => {
return Err(ConfigError::Custom(format!( return Err(ConfigError::Custom(format!(
"Invalid upstream scheme {}", "Invalid upstream scheme {}",
upstream self
))) )))
} }
} }
let addr = UpstreamAddress::new(format!("{}:{}", upstream_host, upstream_port)); Ok(ProxyToUpstream::new(
Ok(ProxyToUpstream { format!("{}:{}", upstream_host, upstream_port),
addr: format!("{}:{}", upstream_host, upstream_port), upstream_url.scheme().to_string(),
protocol: upstream_url.scheme().to_string(), ))
addresses: Addr(Mutex::new(addr)),
})
} }
} }
@ -165,7 +130,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
parsed_upstream.insert("echo".to_string(), Upstream::Echo); parsed_upstream.insert("echo".to_string(), Upstream::Echo);
for (name, upstream) in base.upstream.iter() { for (name, upstream) in base.upstream.iter() {
let ups = ProxyToUpstream::try_from(upstream.as_str())?; let ups = upstream.as_str().try_into()?;
parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups)); parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups));
} }

View File

@ -1,6 +1,7 @@
mod config; mod config;
mod plugins; mod plugins;
mod servers; mod servers;
mod upstreams;
use crate::config::Config; use crate::config::Config;
use crate::servers::Server; use crate::servers::Server;

View File

@ -2,14 +2,14 @@ use log::{error, info};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
mod protocol; mod protocol;
pub(crate) mod upstream_address; pub(crate) mod upstream_address;
use crate::config::{ParsedConfig, Upstream}; use crate::config::ParsedConfig;
use crate::upstreams::Upstream;
use protocol::tcp; use protocol::tcp;
#[derive(Debug)] #[derive(Debug)]
@ -212,17 +212,3 @@ mod tests {
// conn.shutdown().await.unwrap(); // conn.shutdown().await.unwrap();
} }
} }
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
match io::copy(reader, writer).await {
Ok(u64) => {
let _ = writer.shutdown().await;
Ok(u64)
}
Err(_) => Ok(0),
}
}

View File

@ -1,14 +1,11 @@
use crate::config::Upstream;
use crate::servers::protocol::tls::get_sni; use crate::servers::protocol::tls::get_sni;
use crate::servers::{copy, Proxy}; use crate::servers::Proxy;
use futures::future::try_join;
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use std::error::Error;
use std::sync::Arc; use std::sync::Arc;
use tokio::io;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> { pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(config.listen).await?; let listener = TcpListener::bind(config.listen).await?;
let config = config.clone(); let config = config.clone();
@ -33,7 +30,7 @@ pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::
} }
} }
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> { async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn Error>> {
info!("New connection from {:?}", inbound.peer_addr()?); info!("New connection from {:?}", inbound.peer_addr()?);
let upstream_name = match proxy.tls { let upstream_name = match proxy.tls {
@ -72,51 +69,9 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std
"No upstream named {:?} on server {:?}", "No upstream named {:?} on server {:?}",
proxy.default_action, proxy.name proxy.default_action, proxy.name
); );
return process(inbound, proxy.upstream.get(&proxy.default_action).unwrap()).await; proxy.upstream.get(&proxy.default_action).unwrap()
// ToDo: Remove unwrap and check default option
} }
}; };
process(inbound, upstream).await upstream.process(inbound).await
}
async fn process(
mut inbound: TcpStream,
upstream: &Upstream,
) -> Result<(), Box<dyn std::error::Error>> {
match upstream {
Upstream::Ban => {
inbound.shutdown().await?;
}
Upstream::Echo => {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
}
Upstream::Proxy(config) => {
let outbound = match config.protocol.as_ref() {
"tcp4" | "tcp6" | "tcp" => {
TcpStream::connect(config.resolve_addresses().await?.as_slice()).await?
}
_ => {
error!("Reached unknown protocol: {:?}", config.protocol);
return Err("Reached unknown protocol".into());
}
};
debug!("Connected to {:?}", outbound.peer_addr().unwrap());
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
}
};
Ok(())
} }

51
src/upstreams/mod.rs Normal file
View File

@ -0,0 +1,51 @@
mod proxy_to_upstream;
use log::debug;
use serde::Deserialize;
use std::error::Error;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
pub use crate::upstreams::proxy_to_upstream::ProxyToUpstream;
#[derive(Debug, Clone, Deserialize)]
pub enum Upstream {
Ban,
Echo,
Proxy(ProxyToUpstream),
}
impl Upstream {
pub(crate) async fn process(&self, mut inbound: TcpStream) -> Result<(), Box<dyn Error>> {
match self {
Upstream::Ban => {
inbound.shutdown().await?;
}
Upstream::Echo => {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
}
Upstream::Proxy(config) => {
config.proxy(inbound).await?;
}
};
Ok(())
}
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
match io::copy(reader, writer).await {
Ok(u64) => {
let _ = writer.shutdown().await;
Ok(u64)
}
Err(_) => Ok(0),
}
}

View File

@ -0,0 +1,68 @@
use crate::servers::upstream_address::UpstreamAddress;
use crate::upstreams::copy;
use futures::future::try_join;
use log::{debug, error};
use serde::Deserialize;
use std::net::SocketAddr;
use tokio::io;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
#[derive(Debug, Default)]
struct Addr(Mutex<UpstreamAddress>);
impl Clone for Addr {
fn clone(&self) -> Self {
tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone())))
}
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProxyToUpstream {
pub addr: String,
pub protocol: String,
#[serde(skip_deserializing)]
addresses: Addr,
}
impl ProxyToUpstream {
pub async fn resolve_addresses(&self) -> std::io::Result<Vec<SocketAddr>> {
let mut addr = self.addresses.0.lock().await;
addr.resolve((*self.protocol).into()).await
}
pub fn new(address: String, protocol: String) -> Self {
Self {
addr: address.clone(),
protocol,
addresses: Addr(Mutex::new(UpstreamAddress::new(address))),
}
}
pub(crate) async fn proxy(&self, inbound: TcpStream) -> Result<(), Box<dyn std::error::Error>> {
let outbound = match self.protocol.as_ref() {
"tcp4" | "tcp6" | "tcp" => {
TcpStream::connect(self.resolve_addresses().await?.as_slice()).await?
}
_ => {
error!("Reached unknown protocol: {:?}", self.protocol);
return Err("Reached unknown protocol".into());
}
};
debug!("Connected to {:?}", outbound.peer_addr().unwrap());
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
}