Add config validation

This commit is contained in:
KernelErr 2021-11-01 13:45:47 +08:00
parent 47be2568ba
commit 0407f4b40c
3 changed files with 85 additions and 19 deletions

View File

@ -1,6 +1,6 @@
use log::debug; use log::{debug, warn};
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::fs::File; use std::fs::File;
use std::io::{Error as IOError, Read}; use std::io::{Error as IOError, Read};
use url::Url; use url::Url;
@ -93,7 +93,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
Ok(url) => url, Ok(url) => url,
Err(_) => { Err(_) => {
return Err(ConfigError::Custom(format!( return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"", "Invalid upstream url {}",
upstream upstream
))) )))
} }
@ -103,7 +103,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
Some(host) => host, Some(host) => host,
None => { None => {
return Err(ConfigError::Custom(format!( return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"", "Invalid upstream url {}",
upstream upstream
))) )))
} }
@ -113,12 +113,19 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
Some(port) => port, Some(port) => port,
None => { None => {
return Err(ConfigError::Custom(format!( return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"", "Invalid upstream url {}",
upstream upstream
))) )))
} }
}; };
if upstream_url.scheme() != "tcp" {
return Err(ConfigError::Custom(format!(
"Invalid upstream scheme {}",
upstream
)));
}
parsed_upstream.insert( parsed_upstream.insert(
name.to_string(), name.to_string(),
Upstream::Custom(CustomUpstream { Upstream::Custom(CustomUpstream {
@ -129,15 +136,9 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
); );
} }
parsed_upstream.insert( parsed_upstream.insert("ban".to_string(), Upstream::Ban);
"ban".to_string(),
Upstream::Ban,
);
parsed_upstream.insert( parsed_upstream.insert("echo".to_string(), Upstream::Echo);
"echo".to_string(),
Upstream::Echo,
);
let parsed = ParsedConfig { let parsed = ParsedConfig {
version: base.version, version: base.version,
@ -146,9 +147,66 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
upstream: parsed_upstream, upstream: parsed_upstream,
}; };
// ToDo: validate config verify_config(parsed)
}
Ok(parsed) fn verify_config(config: ParsedConfig) -> Result<ParsedConfig, ConfigError> {
let mut used_upstreams: HashSet<String> = HashSet::new();
let mut upstream_names: HashSet<String> = HashSet::new();
let mut listen_addresses: HashSet<String> = HashSet::new();
// Check for duplicate upstream names
for (name, _) in config.upstream.iter() {
if upstream_names.contains(name) {
return Err(ConfigError::Custom(format!(
"Duplicate upstream name {}",
name
)));
}
upstream_names.insert(name.to_string());
}
for (_, server) in config.servers.clone() {
// check for duplicate listen addresses
for listen in server.listen {
if listen_addresses.contains(&listen) {
return Err(ConfigError::Custom(format!(
"Duplicate listen address {}",
listen
)));
}
listen_addresses.insert(listen.to_string());
}
if server.tls.unwrap_or_default() && server.sni.is_some() {
for (_, val) in server.sni.unwrap() {
used_upstreams.insert(val.to_string());
}
}
if server.default.is_some() {
used_upstreams.insert(server.default.unwrap().to_string());
}
for key in &used_upstreams {
if !config.upstream.contains_key(key) {
return Err(ConfigError::Custom(format!(
"Upstream {} not found",
key
)));
}
}
}
for key in &upstream_names {
if !used_upstreams.contains(key) {
warn!("Upstream {} not used", key);
}
}
Ok(config)
} }
impl From<IOError> for ConfigError { impl From<IOError> for ConfigError {

View File

@ -53,13 +53,17 @@ async fn accept(
"No upstream named {:?} on server {:?}", "No upstream named {:?} on server {:?}",
proxy.default, proxy.name proxy.default, proxy.name
); );
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await;
// ToDo: Remove unwrap and check default option
} }
}; };
return process(inbound, upstream).await; return process(inbound, upstream).await;
} }
async fn process(mut inbound: KcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> { async fn process(
mut inbound: KcpStream,
upstream: &Upstream,
) -> Result<(), Box<dyn std::error::Error>> {
match upstream { match upstream {
Upstream::Ban => { Upstream::Ban => {
let _ = inbound.shutdown(); let _ = inbound.shutdown();

View File

@ -72,13 +72,17 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std
"No upstream named {:?} on server {:?}", "No upstream named {:?} on server {:?}",
proxy.default, proxy.name proxy.default, proxy.name
); );
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await;
// ToDo: Remove unwrap and check default option
} }
}; };
return process(inbound, upstream).await; return process(inbound, upstream).await;
} }
async fn process(mut inbound: TcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> { async fn process(
mut inbound: TcpStream,
upstream: &Upstream,
) -> Result<(), Box<dyn std::error::Error>> {
match upstream { match upstream {
Upstream::Ban => { Upstream::Ban => {
let _ = inbound.shutdown(); let _ = inbound.shutdown();