From 0407f4b40cb45a22a4d535189de9cb474659be5a Mon Sep 17 00:00:00 2001 From: KernelErr <45716019+KernelErr@users.noreply.github.com> Date: Mon, 1 Nov 2021 13:45:47 +0800 Subject: [PATCH] Add config validation --- src/config.rs | 88 ++++++++++++++++++++++++++++++------- src/servers/protocol/kcp.rs | 8 +++- src/servers/protocol/tcp.rs | 8 +++- 3 files changed, 85 insertions(+), 19 deletions(-) diff --git a/src/config.rs b/src/config.rs index 5e21a29..ec5622a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,6 @@ -use log::debug; +use log::{debug, warn}; use serde::Deserialize; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{Error as IOError, Read}; use url::Url; @@ -93,7 +93,7 @@ fn load_config(path: &str) -> Result { Ok(url) => url, Err(_) => { return Err(ConfigError::Custom(format!( - "Invalid upstream url \"{}\"", + "Invalid upstream url {}", upstream ))) } @@ -103,7 +103,7 @@ fn load_config(path: &str) -> Result { Some(host) => host, None => { return Err(ConfigError::Custom(format!( - "Invalid upstream url \"{}\"", + "Invalid upstream url {}", upstream ))) } @@ -113,12 +113,19 @@ fn load_config(path: &str) -> Result { Some(port) => port, None => { return Err(ConfigError::Custom(format!( - "Invalid upstream url \"{}\"", + "Invalid upstream url {}", upstream ))) } }; + if upstream_url.scheme() != "tcp" { + return Err(ConfigError::Custom(format!( + "Invalid upstream scheme {}", + upstream + ))); + } + parsed_upstream.insert( name.to_string(), Upstream::Custom(CustomUpstream { @@ -129,15 +136,9 @@ fn load_config(path: &str) -> Result { ); } - parsed_upstream.insert( - "ban".to_string(), - Upstream::Ban, - ); + parsed_upstream.insert("ban".to_string(), Upstream::Ban); - parsed_upstream.insert( - "echo".to_string(), - Upstream::Echo, - ); + parsed_upstream.insert("echo".to_string(), Upstream::Echo); let parsed = ParsedConfig { version: base.version, @@ -146,9 +147,66 @@ fn load_config(path: &str) -> Result { upstream: parsed_upstream, }; - // ToDo: validate config + verify_config(parsed) +} - Ok(parsed) +fn verify_config(config: ParsedConfig) -> Result { + let mut used_upstreams: HashSet = HashSet::new(); + let mut upstream_names: HashSet = HashSet::new(); + let mut listen_addresses: HashSet = HashSet::new(); + + // Check for duplicate upstream names + for (name, _) in config.upstream.iter() { + if upstream_names.contains(name) { + return Err(ConfigError::Custom(format!( + "Duplicate upstream name {}", + name + ))); + } + + upstream_names.insert(name.to_string()); + } + + for (_, server) in config.servers.clone() { + // check for duplicate listen addresses + for listen in server.listen { + if listen_addresses.contains(&listen) { + return Err(ConfigError::Custom(format!( + "Duplicate listen address {}", + listen + ))); + } + + listen_addresses.insert(listen.to_string()); + } + + if server.tls.unwrap_or_default() && server.sni.is_some() { + for (_, val) in server.sni.unwrap() { + used_upstreams.insert(val.to_string()); + } + } + + if server.default.is_some() { + used_upstreams.insert(server.default.unwrap().to_string()); + } + + for key in &used_upstreams { + if !config.upstream.contains_key(key) { + return Err(ConfigError::Custom(format!( + "Upstream {} not found", + key + ))); + } + } + } + + for key in &upstream_names { + if !used_upstreams.contains(key) { + warn!("Upstream {} not used", key); + } + } + + Ok(config) } impl From for ConfigError { diff --git a/src/servers/protocol/kcp.rs b/src/servers/protocol/kcp.rs index 195001f..f587431 100644 --- a/src/servers/protocol/kcp.rs +++ b/src/servers/protocol/kcp.rs @@ -53,13 +53,17 @@ async fn accept( "No upstream named {:?} on server {:?}", proxy.default, proxy.name ); - return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option + return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; + // ToDo: Remove unwrap and check default option } }; return process(inbound, upstream).await; } -async fn process(mut inbound: KcpStream, upstream: &Upstream) -> Result<(), Box> { +async fn process( + mut inbound: KcpStream, + upstream: &Upstream, +) -> Result<(), Box> { match upstream { Upstream::Ban => { let _ = inbound.shutdown(); diff --git a/src/servers/protocol/tcp.rs b/src/servers/protocol/tcp.rs index 4988b56..2340881 100644 --- a/src/servers/protocol/tcp.rs +++ b/src/servers/protocol/tcp.rs @@ -72,13 +72,17 @@ async fn accept(inbound: TcpStream, proxy: Arc) -> Result<(), Box Result<(), Box> { +async fn process( + mut inbound: TcpStream, + upstream: &Upstream, +) -> Result<(), Box> { match upstream { Upstream::Ban => { let _ = inbound.shutdown();