First commit
This commit is contained in:
76
src/config.rs
Normal file
76
src/config.rs
Normal file
@ -0,0 +1,76 @@
|
||||
use log::debug;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{Error as IOError, Read};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub base: BaseConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Clone)]
|
||||
pub struct BaseConfig {
|
||||
pub version: i32,
|
||||
pub log: Option<String>,
|
||||
pub servers: HashMap<String, ServerConfig>,
|
||||
pub upstream: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Clone)]
|
||||
pub struct ServerConfig {
|
||||
pub listen: Vec<String>,
|
||||
pub tls: Option<bool>,
|
||||
pub sni: Option<HashMap<String, String>>,
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ConfigError {
|
||||
IO(IOError),
|
||||
Yaml(serde_yaml::Error),
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new(path: &str) -> Result<Config, ConfigError> {
|
||||
let base = (load_config(path))?;
|
||||
|
||||
Ok(Config { base })
|
||||
}
|
||||
}
|
||||
|
||||
fn load_config(path: &str) -> Result<BaseConfig, ConfigError> {
|
||||
let mut contents = String::new();
|
||||
let mut file = (File::open(path))?;
|
||||
(file.read_to_string(&mut contents))?;
|
||||
|
||||
let parsed: BaseConfig = serde_yaml::from_str(&contents)?;
|
||||
|
||||
if parsed.version != 1 {
|
||||
return Err(ConfigError::Custom(
|
||||
"Unsupported config version".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let log_level = parsed.log.clone().unwrap_or_else(|| "info".to_string());
|
||||
std::env::set_var("FOURTH_LOG", log_level.clone());
|
||||
pretty_env_logger::init_custom_env("FOURTH_LOG");
|
||||
debug!("Set log level to {}", log_level);
|
||||
|
||||
debug!("Config version {}", parsed.version);
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
|
||||
impl From<IOError> for ConfigError {
|
||||
fn from(err: IOError) -> ConfigError {
|
||||
ConfigError::IO(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_yaml::Error> for ConfigError {
|
||||
fn from(err: serde_yaml::Error) -> ConfigError {
|
||||
ConfigError::Yaml(err)
|
||||
}
|
||||
}
|
24
src/main.rs
Normal file
24
src/main.rs
Normal file
@ -0,0 +1,24 @@
|
||||
mod config;
|
||||
mod servers;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::servers::Server;
|
||||
|
||||
use log::{debug, error};
|
||||
|
||||
fn main() {
|
||||
let config = match Config::new("/etc/fourth/config.yaml") {
|
||||
Ok(config) => config,
|
||||
Err(e) => {
|
||||
println!("Could not load config: {:?}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
debug!("{:?}", config);
|
||||
|
||||
let mut server = Server::new(config.base);
|
||||
debug!("{:?}", server);
|
||||
|
||||
let res = server.run();
|
||||
error!("Server returned an error: {:?}", res);
|
||||
}
|
208
src/servers/mod.rs
Normal file
208
src/servers/mod.rs
Normal file
@ -0,0 +1,208 @@
|
||||
use futures::future::try_join;
|
||||
use log::{debug, error, info, warn};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::io;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
mod tls;
|
||||
use self::tls::get_sni;
|
||||
use crate::config::BaseConfig;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Server {
|
||||
pub proxies: Vec<Arc<Proxy>>,
|
||||
pub config: BaseConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Proxy {
|
||||
pub name: String,
|
||||
pub listen: SocketAddr,
|
||||
pub tls: bool,
|
||||
pub sni: Option<HashMap<String, String>>,
|
||||
pub default: String,
|
||||
pub upstream: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub fn new(config: BaseConfig) -> Self {
|
||||
let mut new_server = Server {
|
||||
proxies: Vec::new(),
|
||||
config: config.clone(),
|
||||
};
|
||||
|
||||
for (name, proxy) in config.servers.iter() {
|
||||
let tls = proxy.tls.unwrap_or(false);
|
||||
let sni = proxy.sni.clone();
|
||||
let default = proxy.default.clone().unwrap_or_else(|| "ban".to_string());
|
||||
let upstream = config.upstream.clone();
|
||||
let mut upstream_set: HashSet<String> = HashSet::new();
|
||||
for (key, _) in &upstream {
|
||||
if key.eq("ban") || key.eq("echo") {
|
||||
continue;
|
||||
}
|
||||
upstream_set.insert(key.clone());
|
||||
}
|
||||
for listen in proxy.listen.clone() {
|
||||
println!("{:?}", listen);
|
||||
let listen_addr: SocketAddr = match listen.parse() {
|
||||
Ok(addr) => addr,
|
||||
Err(_) => {
|
||||
error!("Invalid listen address: {}", listen);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let proxy = Proxy {
|
||||
name: name.clone(),
|
||||
listen: listen_addr,
|
||||
tls,
|
||||
sni: sni.clone(),
|
||||
default: default.clone(),
|
||||
upstream: upstream.clone(),
|
||||
};
|
||||
new_server.proxies.push(Arc::new(proxy));
|
||||
}
|
||||
}
|
||||
|
||||
new_server
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
pub async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let proxies = self.proxies.clone();
|
||||
let mut handles: Vec<JoinHandle<()>> = Vec::new();
|
||||
|
||||
for config in proxies {
|
||||
info!("Starting server {} on {}", config.name, config.listen);
|
||||
let handle = tokio::spawn(async move {
|
||||
let _ = proxy(config).await;
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let listener = TcpListener::bind(config.listen).await?;
|
||||
let config = config.clone();
|
||||
|
||||
loop {
|
||||
let thread_proxy = config.clone();
|
||||
match listener.accept().await {
|
||||
Err(err) => {
|
||||
error!("Failed to accept connection: {}", err);
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
Ok((stream, _)) => {
|
||||
tokio::spawn(async move {
|
||||
match accept(stream, thread_proxy).await {
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
error!("Relay thread returned an error: {}", err);
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
debug!("New connection from {:?}", inbound.peer_addr()?);
|
||||
|
||||
let upstream_name = match proxy.tls {
|
||||
false => proxy.default.clone(),
|
||||
true => {
|
||||
let mut hello_buf = [0u8; 1024];
|
||||
inbound.peek(&mut hello_buf).await?;
|
||||
let snis = get_sni(&hello_buf);
|
||||
if snis.is_empty() {
|
||||
proxy.default.clone()
|
||||
} else {
|
||||
match proxy.sni.clone() {
|
||||
Some(sni_map) => {
|
||||
let mut upstream = proxy.default.clone();
|
||||
for sni in snis {
|
||||
let m = sni_map.get(&sni);
|
||||
if m.is_some() {
|
||||
upstream = m.unwrap().clone();
|
||||
break;
|
||||
}
|
||||
}
|
||||
upstream
|
||||
}
|
||||
None => proxy.default.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Upstream: {}", upstream_name);
|
||||
|
||||
let upstream = match proxy.upstream.get(&upstream_name) {
|
||||
Some(upstream) => upstream,
|
||||
None => {
|
||||
warn!(
|
||||
"No upstream named {:?} on server {:?}",
|
||||
proxy.default, proxy.name
|
||||
);
|
||||
return process(inbound, &proxy.default).await;
|
||||
}
|
||||
};
|
||||
return process(inbound, upstream).await;
|
||||
}
|
||||
|
||||
async fn process(mut inbound: TcpStream, upstream: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if upstream == "ban" {
|
||||
let _ = inbound.shutdown();
|
||||
return Ok(());
|
||||
} else if upstream == "echo" {
|
||||
loop {
|
||||
let mut buf = [0u8; 1];
|
||||
let b = inbound.read(&mut buf).await?;
|
||||
if b == 0 {
|
||||
break;
|
||||
} else {
|
||||
inbound.write(&buf).await?;
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let outbound = TcpStream::connect(upstream).await?;
|
||||
|
||||
let (mut ri, mut wi) = io::split(inbound);
|
||||
let (mut ro, mut wo) = io::split(outbound);
|
||||
|
||||
let inbound_to_outbound = copy(&mut ri, &mut wo);
|
||||
let outbound_to_inbound = copy(&mut ro, &mut wi);
|
||||
|
||||
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
|
||||
|
||||
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
|
||||
where
|
||||
R: AsyncRead + Unpin + ?Sized,
|
||||
W: AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
match io::copy(reader, writer).await {
|
||||
Ok(u64) => {
|
||||
writer.shutdown().await?;
|
||||
Ok(u64)
|
||||
}
|
||||
Err(_) => Ok(0),
|
||||
}
|
||||
}
|
62
src/servers/tls.rs
Normal file
62
src/servers/tls.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use log::{debug, warn};
|
||||
use tls_parser::{
|
||||
parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage,
|
||||
TlsMessageHandshake,
|
||||
};
|
||||
|
||||
pub fn get_sni(buf: &[u8]) -> Vec<String> {
|
||||
let mut snis: Vec<String> = Vec::new();
|
||||
match parse_tls_raw_record(buf) {
|
||||
Ok((_, ref r)) => {
|
||||
match parse_tls_record_with_header(r.data, &r.hdr) {
|
||||
Ok((_, ref msg_list)) => {
|
||||
for msg in msg_list {
|
||||
match *msg {
|
||||
TlsMessage::Handshake(ref m) => match *m {
|
||||
TlsMessageHandshake::ClientHello(ref content) => {
|
||||
debug!("TLS ClientHello version: {}", content.version);
|
||||
let ext = parse_tls_extensions(content.ext.unwrap_or(b""));
|
||||
match ext {
|
||||
Ok((_, ref extensions)) => {
|
||||
for ext in extensions {
|
||||
match *ext {
|
||||
tls_parser::TlsExtension::SNI(ref v) => {
|
||||
for &(t, sni) in v {
|
||||
match String::from_utf8(sni.to_vec()) {
|
||||
Ok(s) => {
|
||||
debug!("TLS SNI: {} {}", t, s);
|
||||
snis.push(s);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse SNI: {} {}", t, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("TLS extensions error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to parse TLS: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to parse TLS: {}", err);
|
||||
}
|
||||
}
|
||||
|
||||
snis
|
||||
}
|
Reference in New Issue
Block a user