Add upstream scheme support

Need to implement TCP and UDP upstream support.
This commit is contained in:
KernelErr
2021-10-31 19:21:32 +08:00
parent 5944beb6a2
commit 47be2568ba
9 changed files with 254 additions and 89 deletions

View File

@ -3,10 +3,19 @@ use serde::Deserialize;
use std::collections::HashMap;
use std::fs::File;
use std::io::{Error as IOError, Read};
use url::Url;
#[derive(Debug, Clone)]
pub struct Config {
pub base: BaseConfig,
pub base: ParsedConfig,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct ParsedConfig {
pub version: i32,
pub log: Option<String>,
pub servers: HashMap<String, ServerConfig>,
pub upstream: HashMap<String, Upstream>,
}
#[derive(Debug, Default, Deserialize, Clone)]
@ -26,6 +35,20 @@ pub struct ServerConfig {
pub default: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub enum Upstream {
Ban,
Echo,
Custom(CustomUpstream),
}
#[derive(Debug, Clone, Deserialize)]
pub struct CustomUpstream {
pub name: String,
pub addr: String,
pub protocol: String,
}
#[derive(Debug)]
pub enum ConfigError {
IO(IOError),
@ -41,27 +64,89 @@ impl Config {
}
}
fn load_config(path: &str) -> Result<BaseConfig, ConfigError> {
fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
let mut contents = String::new();
let mut file = (File::open(path))?;
(file.read_to_string(&mut contents))?;
let parsed: BaseConfig = serde_yaml::from_str(&contents)?;
let base: BaseConfig = serde_yaml::from_str(&contents)?;
if parsed.version != 1 {
if base.version != 1 {
return Err(ConfigError::Custom(
"Unsupported config version".to_string(),
));
}
let log_level = parsed.log.clone().unwrap_or_else(|| "info".to_string());
let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
if !log_level.eq("disable") {
std::env::set_var("FOURTH_LOG", log_level.clone());
pretty_env_logger::init_custom_env("FOURTH_LOG");
debug!("Set log level to {}", log_level);
}
debug!("Config version {}", parsed.version);
debug!("Config version {}", base.version);
let mut parsed_upstream: HashMap<String, Upstream> = HashMap::new();
for (name, upstream) in base.upstream.iter() {
let upstream_url = match Url::parse(upstream) {
Ok(url) => url,
Err(_) => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"",
upstream
)))
}
};
let upstream_host = match upstream_url.host_str() {
Some(host) => host,
None => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"",
upstream
)))
}
};
let upsteam_port = match upstream_url.port_or_known_default() {
Some(port) => port,
None => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"",
upstream
)))
}
};
parsed_upstream.insert(
name.to_string(),
Upstream::Custom(CustomUpstream {
name: name.to_string(),
addr: format!("{}:{}", upstream_host, upsteam_port),
protocol: upstream_url.scheme().to_string(),
}),
);
}
parsed_upstream.insert(
"ban".to_string(),
Upstream::Ban,
);
parsed_upstream.insert(
"echo".to_string(),
Upstream::Echo,
);
let parsed = ParsedConfig {
version: base.version,
log: base.log,
servers: base.servers,
upstream: parsed_upstream,
};
// ToDo: validate config
Ok(parsed)
}
@ -88,6 +173,6 @@ mod tests {
assert_eq!(config.base.version, 1);
assert_eq!(config.base.log.unwrap(), "disable");
assert_eq!(config.base.servers.len(), 5);
assert_eq!(config.base.upstream.len(), 3);
assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
}
}

View File

@ -5,13 +5,13 @@ use std::sync::Arc;
use tokio::task::JoinHandle;
mod protocol;
use crate::config::BaseConfig;
use crate::config::{ParsedConfig, Upstream};
use protocol::{kcp, tcp};
#[derive(Debug)]
pub struct Server {
pub proxies: Vec<Arc<Proxy>>,
pub config: BaseConfig,
pub config: ParsedConfig,
}
#[derive(Debug, Clone)]
@ -22,11 +22,11 @@ pub struct Proxy {
pub tls: bool,
pub sni: Option<HashMap<String, String>>,
pub default: String,
pub upstream: HashMap<String, String>,
pub upstream: HashMap<String, Upstream>,
}
impl Server {
pub fn new(config: BaseConfig) -> Self {
pub fn new(config: ParsedConfig) -> Self {
let mut new_server = Server {
proxies: Vec::new(),
config: config.clone(),
@ -53,6 +53,7 @@ impl Server {
continue;
}
};
let proxy = Proxy {
name: name.clone(),
listen: listen_addr,
@ -103,7 +104,7 @@ impl Server {
}
#[cfg(test)]
mod test {
mod tests {
use crate::plugins::kcp::{KcpConfig, KcpStream};
use std::net::SocketAddr;
use std::thread::{self, sleep};

View File

@ -1,3 +1,4 @@
use crate::config::Upstream;
use crate::plugins::kcp::{KcpConfig, KcpListener, KcpStream};
use crate::servers::Proxy;
use futures::future::try_join;
@ -52,37 +53,41 @@ async fn accept(
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, &proxy.default).await;
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option
}
};
return process(inbound, upstream).await;
}
async fn process(mut inbound: KcpStream, upstream: &str) -> Result<(), Box<dyn std::error::Error>> {
if upstream == "ban" {
let _ = inbound.shutdown();
return Ok(());
} else if upstream == "echo" {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
return Ok(());
async fn process(mut inbound: KcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> {
match upstream {
Upstream::Ban => {
let _ = inbound.shutdown();
Ok(())
}
Upstream::Echo => {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
Ok(())
}
Upstream::Custom(custom) => {
let outbound = TcpStream::connect(custom.addr.clone()).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
}
let outbound = TcpStream::connect(upstream).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>

View File

@ -1,3 +1,4 @@
use crate::config::Upstream;
use crate::servers::protocol::tls::get_sni;
use crate::servers::Proxy;
use futures::future::try_join;
@ -71,37 +72,41 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, &proxy.default).await;
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option
}
};
return process(inbound, upstream).await;
}
async fn process(mut inbound: TcpStream, upstream: &str) -> Result<(), Box<dyn std::error::Error>> {
if upstream == "ban" {
let _ = inbound.shutdown();
return Ok(());
} else if upstream == "echo" {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
return Ok(());
async fn process(mut inbound: TcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> {
match upstream {
Upstream::Ban => {
let _ = inbound.shutdown();
Ok(())
}
Upstream::Echo => {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
Ok(())
}
Upstream::Custom(custom) => {
let outbound = TcpStream::connect(custom.addr.clone()).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
}
let outbound = TcpStream::connect(upstream).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>