Add config validation

This commit is contained in:
KernelErr 2021-11-01 13:45:47 +08:00
parent 47be2568ba
commit 0407f4b40c
3 changed files with 85 additions and 19 deletions

View File

@ -1,6 +1,6 @@
use log::debug;
use log::{debug, warn};
use serde::Deserialize;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{Error as IOError, Read};
use url::Url;
@ -93,7 +93,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
Ok(url) => url,
Err(_) => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"",
"Invalid upstream url {}",
upstream
)))
}
@ -103,7 +103,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
Some(host) => host,
None => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"",
"Invalid upstream url {}",
upstream
)))
}
@ -113,12 +113,19 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
Some(port) => port,
None => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url \"{}\"",
"Invalid upstream url {}",
upstream
)))
}
};
if upstream_url.scheme() != "tcp" {
return Err(ConfigError::Custom(format!(
"Invalid upstream scheme {}",
upstream
)));
}
parsed_upstream.insert(
name.to_string(),
Upstream::Custom(CustomUpstream {
@ -129,15 +136,9 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
);
}
parsed_upstream.insert(
"ban".to_string(),
Upstream::Ban,
);
parsed_upstream.insert("ban".to_string(), Upstream::Ban);
parsed_upstream.insert(
"echo".to_string(),
Upstream::Echo,
);
parsed_upstream.insert("echo".to_string(), Upstream::Echo);
let parsed = ParsedConfig {
version: base.version,
@ -146,9 +147,66 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
upstream: parsed_upstream,
};
// ToDo: validate config
verify_config(parsed)
}
Ok(parsed)
fn verify_config(config: ParsedConfig) -> Result<ParsedConfig, ConfigError> {
let mut used_upstreams: HashSet<String> = HashSet::new();
let mut upstream_names: HashSet<String> = HashSet::new();
let mut listen_addresses: HashSet<String> = HashSet::new();
// Check for duplicate upstream names
for (name, _) in config.upstream.iter() {
if upstream_names.contains(name) {
return Err(ConfigError::Custom(format!(
"Duplicate upstream name {}",
name
)));
}
upstream_names.insert(name.to_string());
}
for (_, server) in config.servers.clone() {
// check for duplicate listen addresses
for listen in server.listen {
if listen_addresses.contains(&listen) {
return Err(ConfigError::Custom(format!(
"Duplicate listen address {}",
listen
)));
}
listen_addresses.insert(listen.to_string());
}
if server.tls.unwrap_or_default() && server.sni.is_some() {
for (_, val) in server.sni.unwrap() {
used_upstreams.insert(val.to_string());
}
}
if server.default.is_some() {
used_upstreams.insert(server.default.unwrap().to_string());
}
for key in &used_upstreams {
if !config.upstream.contains_key(key) {
return Err(ConfigError::Custom(format!(
"Upstream {} not found",
key
)));
}
}
}
for key in &upstream_names {
if !used_upstreams.contains(key) {
warn!("Upstream {} not used", key);
}
}
Ok(config)
}
impl From<IOError> for ConfigError {

View File

@ -53,13 +53,17 @@ async fn accept(
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await;
// ToDo: Remove unwrap and check default option
}
};
return process(inbound, upstream).await;
}
async fn process(mut inbound: KcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> {
async fn process(
mut inbound: KcpStream,
upstream: &Upstream,
) -> Result<(), Box<dyn std::error::Error>> {
match upstream {
Upstream::Ban => {
let _ = inbound.shutdown();

View File

@ -72,13 +72,17 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await;
// ToDo: Remove unwrap and check default option
}
};
return process(inbound, upstream).await;
}
async fn process(mut inbound: TcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> {
async fn process(
mut inbound: TcpStream,
upstream: &Upstream,
) -> Result<(), Box<dyn std::error::Error>> {
match upstream {
Upstream::Ban => {
let _ = inbound.shutdown();