use crate::upstreams::ProxyToUpstream; use crate::upstreams::Upstream; use log::{debug, info, warn}; use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{Error as IOError, Read}; use url::Url; #[derive(Debug, Clone)] pub struct ConfigV1 { pub base: ParsedConfigV1, } #[derive(Debug, Default, Deserialize, Clone)] pub struct ParsedConfigV1 { pub version: i32, pub log: Option, pub servers: HashMap, pub upstream: HashMap, } #[derive(Debug, Default, Deserialize, Clone)] pub struct BaseConfig { pub version: i32, pub log: Option, pub servers: HashMap, pub upstream: HashMap, } #[derive(Debug, Default, Deserialize, Clone)] pub struct ServerConfig { pub listen: Vec, pub protocol: Option, pub tls: Option, pub sni: Option>, pub default: Option, } impl TryInto for &str { type Error = ConfigError; fn try_into(self) -> Result { let upstream_url = match Url::parse(self) { Ok(url) => url, Err(_) => { return Err(ConfigError::Custom(format!( "Invalid upstream url {}", self ))) } }; let upstream_host = match upstream_url.host_str() { Some(host) => host, None => { return Err(ConfigError::Custom(format!( "Invalid upstream url {}", self ))) } }; let upstream_port = match upstream_url.port_or_known_default() { Some(port) => port, None => { return Err(ConfigError::Custom(format!( "Invalid upstream url {}", self ))) } }; match upstream_url.scheme() { "tcp" | "tcp4" | "tcp6" => {} _ => { return Err(ConfigError::Custom(format!( "Invalid upstream scheme {}", self ))) } } Ok(ProxyToUpstream::new( format!("{}:{}", upstream_host, upstream_port), upstream_url.scheme().to_string(), )) } } #[derive(Debug)] pub enum ConfigError { IO(IOError), Yaml(serde_yaml::Error), Custom(String), } impl ConfigV1 { pub fn new(path: &str) -> Result { let base = load_config(path)?; Ok(ConfigV1 { base }) } } fn load_config(path: &str) -> Result { let mut contents = String::new(); let mut file = File::open(path)?; file.read_to_string(&mut contents)?; let base: BaseConfig = serde_yaml::from_str(&contents)?; if base.version != 1 { return Err(ConfigError::Custom( "Unsupported config version".to_string(), )); } let log_level = base.log.clone().unwrap_or_else(|| "info".to_string()); if !log_level.eq("disable") { std::env::set_var("FOURTH_LOG", log_level.clone()); pretty_env_logger::init_custom_env("FOURTH_LOG"); } info!("Using config file: {}", &path); debug!("Set log level to {}", log_level); debug!("Config version {}", base.version); let mut parsed_upstream: HashMap = HashMap::new(); parsed_upstream.insert("ban".to_string(), Upstream::Ban); parsed_upstream.insert("echo".to_string(), Upstream::Echo); for (name, upstream) in base.upstream.iter() { let ups = upstream.as_str().try_into()?; parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups)); } let parsed = ParsedConfigV1 { version: base.version, log: base.log, servers: base.servers, upstream: parsed_upstream, }; verify_config(parsed) } fn verify_config(config: ParsedConfigV1) -> Result { let mut used_upstreams: HashSet = HashSet::new(); let mut upstream_names: HashSet = HashSet::new(); let mut listen_addresses: HashSet = HashSet::new(); // Check for duplicate upstream names for (name, _) in config.upstream.iter() { if upstream_names.contains(name) { return Err(ConfigError::Custom(format!( "Duplicate upstream name {}", name ))); } upstream_names.insert(name.to_string()); } for (_, server) in config.servers.clone() { // check for duplicate listen addresses for listen in server.listen { if listen_addresses.contains(&listen) { return Err(ConfigError::Custom(format!( "Duplicate listen address {}", listen ))); } listen_addresses.insert(listen.to_string()); } if server.tls.unwrap_or_default() && server.sni.is_some() { for (_, val) in server.sni.unwrap() { used_upstreams.insert(val.to_string()); } } if server.default.is_some() { used_upstreams.insert(server.default.unwrap().to_string()); } for key in &used_upstreams { if !config.upstream.contains_key(key) { return Err(ConfigError::Custom(format!("Upstream {} not found", key))); } } } for key in &upstream_names { if !used_upstreams.contains(key) && !key.eq("echo") && !key.eq("ban") { warn!("Upstream {} not used", key); } } Ok(config) } impl From for ConfigError { fn from(err: IOError) -> ConfigError { ConfigError::IO(err) } } impl From for ConfigError { fn from(err: serde_yaml::Error) -> ConfigError { ConfigError::Yaml(err) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_load_config() { let config = ConfigV1::new("tests/config.yaml").unwrap(); assert_eq!(config.base.version, 1); assert_eq!(config.base.log.unwrap(), "disable"); assert_eq!(config.base.servers.len(), 3); assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams } }