Enable explicit ipv4 / ipv6 proxying

Signed-off-by: Jacob Kiers <code@kiers.eu>
This commit is contained in:
Jacob Kiers 2023-06-02 17:35:29 +02:00
parent f010f8c76b
commit f4bc441ca8
6 changed files with 154 additions and 55 deletions

View File

@ -3,7 +3,11 @@ use serde::Deserialize;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fs::File; use std::fs::File;
use std::io::{Error as IOError, Read}; use std::io::{Error as IOError, Read};
use std::net::SocketAddr;
use tokio::sync::Mutex;
use url::Url; use url::Url;
use tokio::time::Instant;
use time::OffsetDateTime;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Config { pub struct Config {
@ -42,11 +46,75 @@ pub enum Upstream {
Custom(CustomUpstream), Custom(CustomUpstream),
} }
#[derive(Debug)]
struct Addr(Mutex<Vec<SocketAddr>>);
impl Default for Addr {
fn default() -> Self {
Self(Default::default())
}
}
impl Clone for Addr {
fn clone(&self) -> Self {
tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone())))
}
}
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct CustomUpstream { pub struct CustomUpstream {
pub name: String, pub name: String,
pub addr: String, pub addr: String,
pub protocol: String, pub protocol: String,
#[serde(skip_deserializing)]
addresses: Addr,
}
impl CustomUpstream {
pub async fn resolve_addresses(&self) -> std::io::Result<()> {
{
let addr = self.addresses.0.lock().await;
if addr.len() > 0 {
debug!("Already have addresses: {:?}", &addr);
return Ok(());
}
}
debug!("Resolving addresses for {}", &self.addr);
let addresses = tokio::net::lookup_host(self.addr.clone()).await?;
let mut addr: Vec<SocketAddr> = match self.protocol.as_ref() {
"tcp4" => addresses.into_iter().filter(|a| a.is_ipv4()).collect(),
"tcp6" => addresses.into_iter().filter(|a| a.is_ipv6()).collect(),
_ => addresses.collect(),
};
debug!("Got addresses for {}: {:?}", &self.addr, &addr);
debug!("Resolved at {}", OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339).expect("Format"));
{
let mut self_addr = self.addresses.0.lock().await;
self_addr.clear();
self_addr.append(&mut addr);
}
Ok(())
}
pub async fn get_addresses(&self) -> Vec<SocketAddr> {
let a = self.addresses.0.lock().await;
a.clone()
}
}
impl Default for CustomUpstream {
fn default() -> Self {
Self {
name: Default::default(),
addr: Default::default(),
protocol: Default::default(),
addresses: Default::default(),
}
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -119,11 +187,14 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
} }
}; };
if upstream_url.scheme() != "tcp" { match upstream_url.scheme() {
return Err(ConfigError::Custom(format!( "tcp" | "tcp4" | "tcp6" => {}
"Invalid upstream scheme {}", _ => {
upstream return Err(ConfigError::Custom(format!(
))); "Invalid upstream scheme {}",
upstream
)))
}
} }
parsed_upstream.insert( parsed_upstream.insert(
@ -132,6 +203,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
name: name.to_string(), name: name.to_string(),
addr: format!("{}:{}", upstream_host, upsteam_port), addr: format!("{}:{}", upstream_host, upsteam_port),
protocol: upstream_url.scheme().to_string(), protocol: upstream_url.scheme().to_string(),
..Default::default()
}), }),
); );
} }

View File

@ -1 +1 @@
pub mod kcp; //pub mod kcp;

View File

@ -6,7 +6,7 @@ use tokio::task::JoinHandle;
mod protocol; mod protocol;
use crate::config::{ParsedConfig, Upstream}; use crate::config::{ParsedConfig, Upstream};
use protocol::{kcp, tcp}; use protocol::tcp;
#[derive(Debug)] #[derive(Debug)]
pub struct Server { pub struct Server {
@ -88,12 +88,24 @@ impl Server {
error!("Failed to start {}: {}", config.name, res.err().unwrap()); error!("Failed to start {}: {}", config.name, res.err().unwrap());
} }
} }
"kcp" => { "tcp4" => {
let res = kcp::proxy(config.clone()).await; let res = tcp::proxy(config.clone()).await;
if res.is_err() { if res.is_err() {
error!("Failed to start {}: {}", config.name, res.err().unwrap()); error!("Failed to start {}: {}", config.name, res.err().unwrap());
} }
} }
"tcp6" => {
let res = tcp::proxy(config.clone()).await;
if res.is_err() {
error!("Failed to start {}: {}", config.name, res.err().unwrap());
}
}
// "kcp" => {
// let res = kcp::proxy(config.clone()).await;
// if res.is_err() {
// error!("Failed to start {}: {}", config.name, res.err().unwrap());
// }
// }
_ => { _ => {
error!("Invalid protocol: {}", config.protocol) error!("Invalid protocol: {}", config.protocol)
} }
@ -111,12 +123,11 @@ impl Server {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::plugins::kcp::{KcpConfig, KcpStream}; //use crate::plugins::kcp::{KcpConfig, KcpStream};
use std::net::SocketAddr;
use std::thread::{self, sleep}; use std::thread::{self, sleep};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::TcpListener;
use super::*; use super::*;
@ -155,7 +166,9 @@ mod tests {
sleep(Duration::from_secs(1)); // wait for server to start sleep(Duration::from_secs(1)); // wait for server to start
// test TCP proxy // test TCP proxy
let mut conn = TcpStream::connect("127.0.0.1:54500").await.unwrap(); let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54500")
.await
.unwrap();
let mut buf = [0u8; 5]; let mut buf = [0u8; 5];
conn.write(b"hi").await.unwrap(); conn.write(b"hi").await.unwrap();
conn.read(&mut buf).await.unwrap(); conn.read(&mut buf).await.unwrap();
@ -163,7 +176,9 @@ mod tests {
conn.shutdown().await.unwrap(); conn.shutdown().await.unwrap();
// test TCP echo // test TCP echo
let mut conn = TcpStream::connect("127.0.0.1:54956").await.unwrap(); let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54956")
.await
.unwrap();
let mut buf = [0u8; 1]; let mut buf = [0u8; 1];
for i in 0..=10u8 { for i in 0..=10u8 {
conn.write(&[i]).await.unwrap(); conn.write(&[i]).await.unwrap();
@ -173,25 +188,25 @@ mod tests {
conn.shutdown().await.unwrap(); conn.shutdown().await.unwrap();
// test KCP echo // test KCP echo
let kcp_config = KcpConfig::default(); // let kcp_config = KcpConfig::default();
let server_addr: SocketAddr = "127.0.0.1:54959".parse().unwrap(); // let server_addr: SocketAddr = "127.0.0.1:54959".parse().unwrap();
let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap(); // let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
let mut buf = [0u8; 1]; // let mut buf = [0u8; 1];
for i in 0..=10u8 { // for i in 0..=10u8 {
conn.write(&[i]).await.unwrap(); // conn.write(&[i]).await.unwrap();
conn.read(&mut buf).await.unwrap(); // conn.read(&mut buf).await.unwrap();
assert_eq!(&buf, &[i]); // assert_eq!(&buf, &[i]);
} // }
conn.shutdown().await.unwrap(); // conn.shutdown().await.unwrap();
//
// test KCP proxy and close mock server // // test KCP proxy and close mock server
let kcp_config = KcpConfig::default(); // let kcp_config = KcpConfig::default();
let server_addr: SocketAddr = "127.0.0.1:54958".parse().unwrap(); // let server_addr: SocketAddr = "127.0.0.1:54958".parse().unwrap();
let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap(); // let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
let mut buf = [0u8; 5]; // let mut buf = [0u8; 5];
conn.write(b"by").await.unwrap(); // conn.write(b"by").await.unwrap();
conn.read(&mut buf).await.unwrap(); // conn.read(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello"); // assert_eq!(&buf, b"hello");
conn.shutdown().await.unwrap(); // conn.shutdown().await.unwrap();
} }
} }

View File

@ -1,3 +1,3 @@
pub mod kcp; //pub mod kcp;
pub mod tcp; pub mod tcp;
pub mod tls; pub mod tls;

View File

@ -2,7 +2,7 @@ use crate::config::Upstream;
use crate::servers::protocol::tls::get_sni; use crate::servers::protocol::tls::get_sni;
use crate::servers::Proxy; use crate::servers::Proxy;
use futures::future::try_join; use futures::future::try_join;
use log::{debug, error, warn}; use log::{debug, error, info, warn};
use std::sync::Arc; use std::sync::Arc;
use tokio::io; use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@ -34,7 +34,7 @@ pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>>
} }
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> { async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
debug!("New connection from {:?}", inbound.peer_addr()?); info!("New connection from {:?}", inbound.peer_addr()?);
let upstream_name = match proxy.tls { let upstream_name = match proxy.tls {
false => proxy.default.clone(), false => proxy.default.clone(),
@ -72,16 +72,22 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std
"No upstream named {:?} on server {:?}", "No upstream named {:?} on server {:?}",
proxy.default, proxy.name proxy.default, proxy.name
); );
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; return process(inbound, proxy.upstream.get(&proxy.default).unwrap().clone()).await;
// ToDo: Remove unwrap and check default option // ToDo: Remove unwrap and check default option
} }
}; };
return process(inbound, upstream).await;
match upstream {
Upstream::Custom(u) => u.resolve_addresses().await?,
_ => {}
}
return process(inbound, upstream.clone()).await;
} }
async fn process( async fn process(
mut inbound: TcpStream, mut inbound: TcpStream,
upstream: &Upstream, upstream: Upstream,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
match upstream { match upstream {
Upstream::Ban => { Upstream::Ban => {
@ -93,25 +99,30 @@ async fn process(
let bytes_tx = inbound_to_inbound.await; let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx); debug!("Bytes read: {:?}", bytes_tx);
} }
Upstream::Custom(custom) => match custom.protocol.as_ref() { Upstream::Custom(custom) => {
"tcp" => { custom.resolve_addresses().await?;
let outbound = TcpStream::connect(custom.addr.clone()).await?; let outbound = match custom.protocol.as_ref() {
"tcp4" | "tcp6" | "tcp" => {
TcpStream::connect(custom.get_addresses().await.as_slice()).await?
}
_ => {
error!("Reached unknown protocol: {:?}", custom.protocol);
return Err("Reached unknown protocol".into());
}
};
let (mut ri, mut wi) = io::split(inbound); debug!("Connected to {:?}", outbound.peer_addr().unwrap());
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo); let (mut ri, mut wi) = io::split(inbound);
let outbound_to_inbound = copy(&mut ro, &mut wi); let (mut ro, mut wo) = io::split(outbound);
let (bytes_tx, bytes_rx) = let inbound_to_outbound = copy(&mut ri, &mut wo);
try_join(inbound_to_outbound, outbound_to_inbound).await?; let outbound_to_inbound = copy(&mut ro, &mut wi);
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx); let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
}
_ => { debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
error!("Reached unknown protocol: {:?}", custom.protocol); }
}
},
}; };
Ok(()) Ok(())
} }

View File

@ -49,6 +49,7 @@ pub fn get_sni(buf: &[u8]) -> Vec<String> {
} }
} }
debug!("Found SNIs: {:?}", &snis);
snis snis
} }