Add config validation
This commit is contained in:
		| @@ -1,6 +1,6 @@ | ||||
| use log::debug; | ||||
| use log::{debug, warn}; | ||||
| use serde::Deserialize; | ||||
| use std::collections::HashMap; | ||||
| use std::collections::{HashMap, HashSet}; | ||||
| use std::fs::File; | ||||
| use std::io::{Error as IOError, Read}; | ||||
| use url::Url; | ||||
| @@ -93,7 +93,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | ||||
|             Ok(url) => url, | ||||
|             Err(_) => { | ||||
|                 return Err(ConfigError::Custom(format!( | ||||
|                     "Invalid upstream url \"{}\"", | ||||
|                     "Invalid upstream url {}", | ||||
|                     upstream | ||||
|                 ))) | ||||
|             } | ||||
| @@ -103,7 +103,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | ||||
|             Some(host) => host, | ||||
|             None => { | ||||
|                 return Err(ConfigError::Custom(format!( | ||||
|                     "Invalid upstream url \"{}\"", | ||||
|                     "Invalid upstream url {}", | ||||
|                     upstream | ||||
|                 ))) | ||||
|             } | ||||
| @@ -113,12 +113,19 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | ||||
|             Some(port) => port, | ||||
|             None => { | ||||
|                 return Err(ConfigError::Custom(format!( | ||||
|                     "Invalid upstream url \"{}\"", | ||||
|                     "Invalid upstream url {}", | ||||
|                     upstream | ||||
|                 ))) | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         if upstream_url.scheme() != "tcp" { | ||||
|             return Err(ConfigError::Custom(format!( | ||||
|                 "Invalid upstream scheme {}", | ||||
|                 upstream | ||||
|             ))); | ||||
|         } | ||||
|  | ||||
|         parsed_upstream.insert( | ||||
|             name.to_string(), | ||||
|             Upstream::Custom(CustomUpstream { | ||||
| @@ -129,15 +136,9 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | ||||
|         ); | ||||
|     } | ||||
|  | ||||
|     parsed_upstream.insert( | ||||
|         "ban".to_string(), | ||||
|         Upstream::Ban, | ||||
|     ); | ||||
|     parsed_upstream.insert("ban".to_string(), Upstream::Ban); | ||||
|  | ||||
|     parsed_upstream.insert( | ||||
|         "echo".to_string(), | ||||
|         Upstream::Echo, | ||||
|     ); | ||||
|     parsed_upstream.insert("echo".to_string(), Upstream::Echo); | ||||
|  | ||||
|     let parsed = ParsedConfig { | ||||
|         version: base.version, | ||||
| @@ -146,9 +147,66 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | ||||
|         upstream: parsed_upstream, | ||||
|     }; | ||||
|  | ||||
|     // ToDo: validate config | ||||
|     verify_config(parsed) | ||||
| } | ||||
|  | ||||
|     Ok(parsed) | ||||
| fn verify_config(config: ParsedConfig) -> Result<ParsedConfig, ConfigError> { | ||||
|     let mut used_upstreams: HashSet<String> = HashSet::new(); | ||||
|     let mut upstream_names: HashSet<String> = HashSet::new(); | ||||
|     let mut listen_addresses: HashSet<String> = HashSet::new(); | ||||
|  | ||||
|     // Check for duplicate upstream names | ||||
|     for (name, _) in config.upstream.iter() { | ||||
|         if upstream_names.contains(name) { | ||||
|             return Err(ConfigError::Custom(format!( | ||||
|                 "Duplicate upstream name {}", | ||||
|                 name | ||||
|             ))); | ||||
|         } | ||||
|  | ||||
|         upstream_names.insert(name.to_string()); | ||||
|     } | ||||
|  | ||||
|     for (_, server) in config.servers.clone() { | ||||
|         // check for duplicate listen addresses | ||||
|         for listen in server.listen { | ||||
|             if listen_addresses.contains(&listen) { | ||||
|                 return Err(ConfigError::Custom(format!( | ||||
|                     "Duplicate listen address {}", | ||||
|                     listen | ||||
|                 ))); | ||||
|             } | ||||
|  | ||||
|             listen_addresses.insert(listen.to_string()); | ||||
|         } | ||||
|  | ||||
|         if server.tls.unwrap_or_default() && server.sni.is_some() { | ||||
|             for (_, val) in server.sni.unwrap() { | ||||
|                 used_upstreams.insert(val.to_string()); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if server.default.is_some() { | ||||
|             used_upstreams.insert(server.default.unwrap().to_string()); | ||||
|         } | ||||
|  | ||||
|         for key in &used_upstreams { | ||||
|             if !config.upstream.contains_key(key) { | ||||
|                 return Err(ConfigError::Custom(format!( | ||||
|                     "Upstream {} not found", | ||||
|                     key | ||||
|                 ))); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     for key in &upstream_names { | ||||
|         if !used_upstreams.contains(key) { | ||||
|             warn!("Upstream {} not used", key); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(config) | ||||
| } | ||||
|  | ||||
| impl From<IOError> for ConfigError { | ||||
|   | ||||
| @@ -53,13 +53,17 @@ async fn accept( | ||||
|                 "No upstream named {:?} on server {:?}", | ||||
|                 proxy.default, proxy.name | ||||
|             ); | ||||
|             return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option | ||||
|             return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; | ||||
|             // ToDo: Remove unwrap and check default option | ||||
|         } | ||||
|     }; | ||||
|     return process(inbound, upstream).await; | ||||
| } | ||||
|  | ||||
| async fn process(mut inbound: KcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> { | ||||
| async fn process( | ||||
|     mut inbound: KcpStream, | ||||
|     upstream: &Upstream, | ||||
| ) -> Result<(), Box<dyn std::error::Error>> { | ||||
|     match upstream { | ||||
|         Upstream::Ban => { | ||||
|             let _ = inbound.shutdown(); | ||||
|   | ||||
| @@ -72,13 +72,17 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std | ||||
|                 "No upstream named {:?} on server {:?}", | ||||
|                 proxy.default, proxy.name | ||||
|             ); | ||||
|             return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option | ||||
|             return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; | ||||
|             // ToDo: Remove unwrap and check default option | ||||
|         } | ||||
|     }; | ||||
|     return process(inbound, upstream).await; | ||||
| } | ||||
|  | ||||
| async fn process(mut inbound: TcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> { | ||||
| async fn process( | ||||
|     mut inbound: TcpStream, | ||||
|     upstream: &Upstream, | ||||
| ) -> Result<(), Box<dyn std::error::Error>> { | ||||
|     match upstream { | ||||
|         Upstream::Ban => { | ||||
|             let _ = inbound.shutdown(); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 KernelErr
					KernelErr