Add config validation
This commit is contained in:
		| @@ -1,6 +1,6 @@ | |||||||
| use log::debug; | use log::{debug, warn}; | ||||||
| use serde::Deserialize; | use serde::Deserialize; | ||||||
| use std::collections::HashMap; | use std::collections::{HashMap, HashSet}; | ||||||
| use std::fs::File; | use std::fs::File; | ||||||
| use std::io::{Error as IOError, Read}; | use std::io::{Error as IOError, Read}; | ||||||
| use url::Url; | use url::Url; | ||||||
| @@ -93,7 +93,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | |||||||
|             Ok(url) => url, |             Ok(url) => url, | ||||||
|             Err(_) => { |             Err(_) => { | ||||||
|                 return Err(ConfigError::Custom(format!( |                 return Err(ConfigError::Custom(format!( | ||||||
|                     "Invalid upstream url \"{}\"", |                     "Invalid upstream url {}", | ||||||
|                     upstream |                     upstream | ||||||
|                 ))) |                 ))) | ||||||
|             } |             } | ||||||
| @@ -103,7 +103,7 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | |||||||
|             Some(host) => host, |             Some(host) => host, | ||||||
|             None => { |             None => { | ||||||
|                 return Err(ConfigError::Custom(format!( |                 return Err(ConfigError::Custom(format!( | ||||||
|                     "Invalid upstream url \"{}\"", |                     "Invalid upstream url {}", | ||||||
|                     upstream |                     upstream | ||||||
|                 ))) |                 ))) | ||||||
|             } |             } | ||||||
| @@ -113,12 +113,19 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | |||||||
|             Some(port) => port, |             Some(port) => port, | ||||||
|             None => { |             None => { | ||||||
|                 return Err(ConfigError::Custom(format!( |                 return Err(ConfigError::Custom(format!( | ||||||
|                     "Invalid upstream url \"{}\"", |                     "Invalid upstream url {}", | ||||||
|                     upstream |                     upstream | ||||||
|                 ))) |                 ))) | ||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|  |         if upstream_url.scheme() != "tcp" { | ||||||
|  |             return Err(ConfigError::Custom(format!( | ||||||
|  |                 "Invalid upstream scheme {}", | ||||||
|  |                 upstream | ||||||
|  |             ))); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         parsed_upstream.insert( |         parsed_upstream.insert( | ||||||
|             name.to_string(), |             name.to_string(), | ||||||
|             Upstream::Custom(CustomUpstream { |             Upstream::Custom(CustomUpstream { | ||||||
| @@ -129,15 +136,9 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | |||||||
|         ); |         ); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     parsed_upstream.insert( |     parsed_upstream.insert("ban".to_string(), Upstream::Ban); | ||||||
|         "ban".to_string(), |  | ||||||
|         Upstream::Ban, |  | ||||||
|     ); |  | ||||||
|  |  | ||||||
|     parsed_upstream.insert( |     parsed_upstream.insert("echo".to_string(), Upstream::Echo); | ||||||
|         "echo".to_string(), |  | ||||||
|         Upstream::Echo, |  | ||||||
|     ); |  | ||||||
|  |  | ||||||
|     let parsed = ParsedConfig { |     let parsed = ParsedConfig { | ||||||
|         version: base.version, |         version: base.version, | ||||||
| @@ -146,9 +147,66 @@ fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> { | |||||||
|         upstream: parsed_upstream, |         upstream: parsed_upstream, | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     // ToDo: validate config |     verify_config(parsed) | ||||||
|  | } | ||||||
|  |  | ||||||
|     Ok(parsed) | fn verify_config(config: ParsedConfig) -> Result<ParsedConfig, ConfigError> { | ||||||
|  |     let mut used_upstreams: HashSet<String> = HashSet::new(); | ||||||
|  |     let mut upstream_names: HashSet<String> = HashSet::new(); | ||||||
|  |     let mut listen_addresses: HashSet<String> = HashSet::new(); | ||||||
|  |  | ||||||
|  |     // Check for duplicate upstream names | ||||||
|  |     for (name, _) in config.upstream.iter() { | ||||||
|  |         if upstream_names.contains(name) { | ||||||
|  |             return Err(ConfigError::Custom(format!( | ||||||
|  |                 "Duplicate upstream name {}", | ||||||
|  |                 name | ||||||
|  |             ))); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         upstream_names.insert(name.to_string()); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for (_, server) in config.servers.clone() { | ||||||
|  |         // check for duplicate listen addresses | ||||||
|  |         for listen in server.listen { | ||||||
|  |             if listen_addresses.contains(&listen) { | ||||||
|  |                 return Err(ConfigError::Custom(format!( | ||||||
|  |                     "Duplicate listen address {}", | ||||||
|  |                     listen | ||||||
|  |                 ))); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             listen_addresses.insert(listen.to_string()); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if server.tls.unwrap_or_default() && server.sni.is_some() { | ||||||
|  |             for (_, val) in server.sni.unwrap() { | ||||||
|  |                 used_upstreams.insert(val.to_string()); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if server.default.is_some() { | ||||||
|  |             used_upstreams.insert(server.default.unwrap().to_string()); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         for key in &used_upstreams { | ||||||
|  |             if !config.upstream.contains_key(key) { | ||||||
|  |                 return Err(ConfigError::Custom(format!( | ||||||
|  |                     "Upstream {} not found", | ||||||
|  |                     key | ||||||
|  |                 ))); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for key in &upstream_names { | ||||||
|  |         if !used_upstreams.contains(key) { | ||||||
|  |             warn!("Upstream {} not used", key); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     Ok(config) | ||||||
| } | } | ||||||
|  |  | ||||||
| impl From<IOError> for ConfigError { | impl From<IOError> for ConfigError { | ||||||
|   | |||||||
| @@ -53,13 +53,17 @@ async fn accept( | |||||||
|                 "No upstream named {:?} on server {:?}", |                 "No upstream named {:?} on server {:?}", | ||||||
|                 proxy.default, proxy.name |                 proxy.default, proxy.name | ||||||
|             ); |             ); | ||||||
|             return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option |             return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; | ||||||
|  |             // ToDo: Remove unwrap and check default option | ||||||
|         } |         } | ||||||
|     }; |     }; | ||||||
|     return process(inbound, upstream).await; |     return process(inbound, upstream).await; | ||||||
| } | } | ||||||
|  |  | ||||||
| async fn process(mut inbound: KcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> { | async fn process( | ||||||
|  |     mut inbound: KcpStream, | ||||||
|  |     upstream: &Upstream, | ||||||
|  | ) -> Result<(), Box<dyn std::error::Error>> { | ||||||
|     match upstream { |     match upstream { | ||||||
|         Upstream::Ban => { |         Upstream::Ban => { | ||||||
|             let _ = inbound.shutdown(); |             let _ = inbound.shutdown(); | ||||||
|   | |||||||
| @@ -72,13 +72,17 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std | |||||||
|                 "No upstream named {:?} on server {:?}", |                 "No upstream named {:?} on server {:?}", | ||||||
|                 proxy.default, proxy.name |                 proxy.default, proxy.name | ||||||
|             ); |             ); | ||||||
|             return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; // ToDo: Remove unwrap and check default option |             return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await; | ||||||
|  |             // ToDo: Remove unwrap and check default option | ||||||
|         } |         } | ||||||
|     }; |     }; | ||||||
|     return process(inbound, upstream).await; |     return process(inbound, upstream).await; | ||||||
| } | } | ||||||
|  |  | ||||||
| async fn process(mut inbound: TcpStream, upstream: &Upstream) -> Result<(), Box<dyn std::error::Error>> { | async fn process( | ||||||
|  |     mut inbound: TcpStream, | ||||||
|  |     upstream: &Upstream, | ||||||
|  | ) -> Result<(), Box<dyn std::error::Error>> { | ||||||
|     match upstream { |     match upstream { | ||||||
|         Upstream::Ban => { |         Upstream::Ban => { | ||||||
|             let _ = inbound.shutdown(); |             let _ = inbound.shutdown(); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 KernelErr
					KernelErr