From 590740f40e9b4e8eeed43ede68e9e39862c9e316 Mon Sep 17 00:00:00 2001 From: Jacob Kiers Date: Fri, 3 Apr 2026 00:31:05 +0200 Subject: [PATCH] Add wildcard SNI matching --- Cargo.lock | 18 +- Cargo.toml | 1 + README.md | 15 ++ config.yaml.example | 6 + src/config/config_v1.rs | 169 +++++++++++++--- src/config/mod.rs | 1 + src/main.rs | 1 + src/servers/mod.rs | 243 +++++++++++++++++++++- src/servers/protocol/tls.rs | 44 ++-- src/sni_matcher.rs | 391 ++++++++++++++++++++++++++++++++++++ 10 files changed, 837 insertions(+), 52 deletions(-) create mode 100644 src/sni_matcher.rs diff --git a/Cargo.lock b/Cargo.lock index c7301c8..0d563d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -516,7 +516,7 @@ dependencies = [ [[package]] name = "l4p" -version = "0.1.10" +version = "0.1.11" dependencies = [ "async-trait", "byte_string", @@ -525,6 +525,7 @@ dependencies = [ "log", "pico-args", "pretty_env_logger", + "psl", "self_update", "serde", "serde_yaml", @@ -856,6 +857,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psl" +version = "2.1.199" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70b63978a2742d3f662188698ab45854156e7e34658f53fa951e9253a3dfd583" +dependencies = [ + "psl-types", +] + +[[package]] +name = "psl-types" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac" + [[package]] name = "quick-xml" version = "0.37.2" diff --git a/Cargo.toml b/Cargo.toml index 0946f33..7bdf768 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ time = { version = "0.3.37", features = ["local-offset", "formatting"] } tls-parser = "0.12.2" tokio = { version = "1.0", features = ["full"] } url = "2.2.2" +psl = "2.1" [dependencies.self_update] version = "0.42.0" diff --git a/README.md b/README.md index 19f114f..e21c791 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,13 @@ $ cargo install l4p Or you can download binary file form the Release page. +## Features + +- Listen on specific port and proxy to local or remote port +- SNI-based rule without terminating TLS connection +- Wildcard SNI matching with DNS-style longest-suffix-match +- DNS-based backend with periodic resolution + ## Configuration `l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example: @@ -55,6 +62,14 @@ There are two upstreams built in: For detailed configuration, check [this example](./config.yaml.example). +### SNI Matching + +The proxy supports both exact and wildcard SNI patterns in the `sni` config. Wildcards use DNS-style longest-suffix-match: more specific patterns take precedence. For example, with `*.example.com` and `*.api.example.com`, request `api.example.com` matches the first, while `v2.api.example.com` matches the second. + +Wildcards are validated against the Public Suffix List (PSL). Known suffixes (`.com`, `.org`) require at least one label below the suffix (`*.example.com` OK, `*.com` rejected). Unknown suffixes (`.local`, `.lan`) are allowed without restriction. + +Invalid wildcard patterns are rejected at config load time with clear error messages. + ## Thanks - [`fourth`](https://crates.io/crates/fourth), of which this is a heavily modified fork. diff --git a/config.yaml.example b/config.yaml.example index cf64858..a9d6fdc 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -10,6 +10,9 @@ servers: sni: api.example.org: example-api www.example.org: proxy + *.example.org: wildcard-proxy # Matches any subdomain of example.org + *.dev.example.org: dev-proxy # More specific: matches v2.dev.example.org, etc. + *.local: local-upstream # Unknown suffix - allowed (no PSL restriction) default: ban second-server: @@ -19,3 +22,6 @@ servers: upstream: proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443 example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443 + wildcard-proxy: "tcp://wildcard.example.org:443" + dev-proxy: "tcp://dev.example.org:443" + local-upstream: "tcp://localhost:8080" diff --git a/src/config/config_v1.rs b/src/config/config_v1.rs index 123c4e2..eddce6a 100644 --- a/src/config/config_v1.rs +++ b/src/config/config_v1.rs @@ -1,3 +1,4 @@ +use crate::sni_matcher::SniMatcher; use crate::upstreams::ProxyToUpstream; use crate::upstreams::Upstream; use log::{debug, info, warn}; @@ -12,14 +13,6 @@ pub struct ConfigV1 { pub base: ParsedConfigV1, } -#[derive(Debug, Default, Deserialize, Clone)] -pub struct ParsedConfigV1 { - pub version: i32, - pub log: Option, - pub servers: HashMap, - pub upstream: HashMap, -} - #[derive(Debug, Default, Deserialize, Clone)] pub struct BaseConfig { pub version: i32, @@ -28,6 +21,14 @@ pub struct BaseConfig { pub upstream: HashMap, } +#[derive(Debug, Default, Clone)] +pub struct ParsedConfigV1 { + pub version: i32, + pub log: Option, + pub servers: HashMap, + pub upstream: HashMap, +} + #[derive(Debug, Default, Deserialize, Clone)] pub struct ServerConfig { pub listen: Vec, @@ -36,6 +37,35 @@ pub struct ServerConfig { pub sni: Option>, pub default: Option, } + +impl ServerConfig { + pub fn into_parsed(self) -> Result> { + let sni = match self.sni { + Some(sni_map) => { + let matcher = SniMatcher::new(sni_map)?; + Some(matcher) + } + None => None, + }; + + Ok(ParsedServerConfig { + listen: self.listen, + protocol: self.protocol, + tls: self.tls, + sni, + default: self.default, + }) + } +} + +#[derive(Debug, Clone)] +pub struct ParsedServerConfig { + pub listen: Vec, + pub protocol: Option, + pub tls: Option, + pub sni: Option, + pub default: Option, +} impl TryInto for &str { type Error = ConfigError; @@ -102,12 +132,23 @@ impl ConfigV1 { } } -fn load_config(path: &str) -> Result { - let mut contents = String::new(); - let mut file = File::open(path)?; - file.read_to_string(&mut contents)?; - - let base: BaseConfig = serde_yaml::from_str(&contents)?; +/// Load and parse configuration from a YAML string. +/// +/// This public function takes raw YAML content as a string and returns a parsed, +/// validated configuration. It performs all validation including: +/// - Version checking +/// - SNI pattern validation +/// - Upstream URL parsing +/// - Cross-reference validation +/// +/// # Arguments +/// * `yaml_str` - The YAML configuration content as a string +/// +/// # Returns +/// * `Ok(ParsedConfigV1)` - Successfully parsed and validated configuration +/// * `Err(ConfigError)` - If YAML parsing fails or validation errors occur +pub fn load_config_from_yaml(yaml_str: &str) -> Result { + let base: BaseConfig = serde_yaml::from_str(yaml_str)?; if base.version != 1 { return Err(ConfigError::Custom( @@ -117,11 +158,12 @@ fn load_config(path: &str) -> Result { let log_level = base.log.clone().unwrap_or_else(|| "info".to_string()); if !log_level.eq("disable") { - std::env::set_var("FOURTH_LOG", log_level.clone()); - pretty_env_logger::init_custom_env("FOURTH_LOG"); + unsafe { + std::env::set_var("FOURTH_LOG", log_level.clone()); + pretty_env_logger::init_custom_env("FOURTH_LOG"); + } } - info!("Using config file: {}", &path); debug!("Set log level to {}", log_level); debug!("Config version {}", base.version); @@ -135,16 +177,50 @@ fn load_config(path: &str) -> Result { parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups)); } + // Convert ServerConfig to ParsedServerConfig, collecting all SNI validation errors + let mut all_errors: Vec = Vec::new(); + let mut parsed_servers: HashMap = HashMap::new(); + + for (name, server_config) in base.servers { + match server_config.into_parsed() { + Ok(parsed) => { + parsed_servers.insert(name, parsed); + } + Err(errors) => { + for err in errors { + all_errors.push(format!("Server '{}': {}", name, err)); + } + } + } + } + + if !all_errors.is_empty() { + return Err(ConfigError::Custom(format!( + "Invalid SNI configuration:\n{}", + all_errors.join("\n") + ))); + } + let parsed = ParsedConfigV1 { version: base.version, log: base.log, - servers: base.servers, + servers: parsed_servers, upstream: parsed_upstream, }; verify_config(parsed) } +fn load_config(path: &str) -> Result { + let mut contents = String::new(); + let mut file = File::open(path)?; + file.read_to_string(&mut contents)?; + + info!("Using config file: {}", &path); + + load_config_from_yaml(&contents) +} + fn verify_config(config: ParsedConfigV1) -> Result { let mut used_upstreams: HashSet = HashSet::new(); let mut upstream_names: HashSet = HashSet::new(); @@ -175,14 +251,20 @@ fn verify_config(config: ParsedConfigV1) -> Result listen_addresses.insert(listen.to_string()); } - if server.tls.unwrap_or_default() && server.sni.is_some() { - for (_, val) in server.sni.unwrap() { - used_upstreams.insert(val.to_string()); + if server.tls.unwrap_or_default() { + if let Some(matcher) = &server.sni { + // Collect all upstream names from the SniMatcher + for (_, upstream) in matcher.exact.iter() { + used_upstreams.insert(upstream.clone()); + } + for pattern in &matcher.wildcards { + used_upstreams.insert(pattern.upstream.clone()); + } } } - if server.default.is_some() { - used_upstreams.insert(server.default.unwrap().to_string()); + if let Some(default) = &server.default { + used_upstreams.insert(default.clone()); } for key in &used_upstreams { @@ -225,4 +307,45 @@ mod tests { assert_eq!(config.base.servers.len(), 3); assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams } + + #[test] + fn test_config_hard_failure_on_invalid_sni() { + // Test that invalid SNI wildcard (*.com) causes hard failure + let config_content = r#"version: 1 +log: disable +servers: + test_server: + listen: + - "127.0.0.1:8443" + protocol: tcp + tls: true + sni: + "*.com": "upstream1" + default: ban +upstream: + upstream1: tcp://127.0.0.1:9000 +"#; + + let result = load_config_from_yaml(config_content); + + // Should fail with an error + assert!(result.is_err(), "Expected config to fail with invalid SNI"); + + // Verify error message contains helpful information + match result { + Err(ConfigError::Custom(msg)) => { + assert!( + msg.contains("Invalid SNI"), + "Error message should mention invalid SNI: {}", + msg + ); + assert!( + msg.contains("*.com"), + "Error message should mention the invalid pattern: {}", + msg + ); + } + _ => panic!("Expected ConfigError::Custom"), + } + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 9cae978..27ffd6c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,3 +1,4 @@ mod config_v1; pub(crate) use config_v1::ConfigV1; pub(crate) use config_v1::ParsedConfigV1; +pub(crate) use config_v1::ParsedServerConfig; diff --git a/src/main.rs b/src/main.rs index 296eccd..6fc315a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod config; mod servers; +mod sni_matcher; mod update; mod upstreams; diff --git a/src/servers/mod.rs b/src/servers/mod.rs index cd68daf..7b3da59 100644 --- a/src/servers/mod.rs +++ b/src/servers/mod.rs @@ -9,6 +9,7 @@ mod protocol; pub(crate) mod upstream_address; use crate::config::ParsedConfigV1; +use crate::sni_matcher::SniMatcher; use crate::upstreams::Upstream; use protocol::tcp; @@ -23,7 +24,7 @@ pub(crate) struct Proxy { pub listen: SocketAddr, pub protocol: String, pub tls: bool, - pub sni: Option>, + pub sni: Option, pub default_action: String, pub upstream: HashMap, } @@ -134,6 +135,132 @@ mod tests { } } + /// Mock server for wildcard SNI test that responds with "tls_wildcard_response" on first read + #[tokio::main] + async fn tls_mock_server_wildcard() { + let server_addr: SocketAddr = "127.0.0.1:54598".parse().unwrap(); + let listener = TcpListener::bind(server_addr).await.unwrap(); + loop { + let (mut stream, _) = listener.accept().await.unwrap(); + // Read client hello (which will be peeked but not actually read by proxy) + let mut buf = [0u8; 1024]; + let _ = stream.read(&mut buf).await; + // Send a response to verify connection succeeded + let _ = stream.write(b"tls_wildcard_response").await; + let _ = stream.shutdown().await; + } + } + + /// Mock server for SNI test that doesn't match wildcard pattern + #[tokio::main] + async fn tls_mock_server_default() { + let server_addr: SocketAddr = "127.0.0.1:54597".parse().unwrap(); + let listener = TcpListener::bind(server_addr).await.unwrap(); + loop { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 1024]; + let _ = stream.read(&mut buf).await; + let _ = stream.write(b"tls_default_response").await; + let _ = stream.shutdown().await; + } + } + + /// Helper function to build a minimal TLS ClientHello with SNI extension + /// This creates a valid TLS 1.2 ClientHello packet with the specified SNI hostname + fn build_tls_client_hello(sni_hostname: &str) -> Vec { + // TLS record header (9 bytes) + let mut hello = Vec::new(); + + // Record type: Handshake (0x16) + hello.push(0x16); + // Version: TLS 1.2 (0x0303) + hello.extend_from_slice(&[0x03, 0x03]); + + // We'll set the record length later + let record_length_pos = hello.len(); + hello.extend_from_slice(&[0x00, 0x00]); // Placeholder for record length + + // Handshake message type: ClientHello (0x01) + hello.push(0x01); + + // We'll set the handshake length later + let handshake_length_pos = hello.len(); + hello.extend_from_slice(&[0x00, 0x00, 0x00]); // Placeholder for handshake length + + // ClientHello fields + // Protocol version: TLS 1.2 (0x0303) + hello.extend_from_slice(&[0x03, 0x03]); + + // Random: 32 bytes (we'll use a fixed pattern) + hello.extend_from_slice(&[0x00; 32]); + + // Session ID length: 0 (no session) + hello.push(0x00); + + // Cipher suites length: 2 bytes + cipher suites + hello.extend_from_slice(&[0x00, 0x02]); // Length of cipher suites list (2 bytes) + // Cipher suite: TLS_RSA_WITH_AES_128_CBC_SHA (0x002F) + hello.extend_from_slice(&[0x00, 0x2F]); + + // Compression methods length: 1 byte + hello.push(0x01); + // Compression method: null (0x00) + hello.push(0x00); + + // Extensions + let extensions_start = hello.len(); + + // SNI Extension (type 0x0000) + let mut sni_extension = Vec::new(); + sni_extension.extend_from_slice(&[0x00, 0x00]); // Extension type: server_name + + // Extension length (will be set later) + let ext_length_pos = sni_extension.len(); + sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder + + // Server name list + let server_name_list_start = sni_extension.len(); + sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder for server name list length + + // Server name: host_name(0), length, hostname + sni_extension.push(0x00); // Name type: host_name + let hostname_bytes = sni_hostname.as_bytes(); + sni_extension.extend_from_slice(&[(hostname_bytes.len() >> 8) as u8, (hostname_bytes.len() & 0xFF) as u8]); + sni_extension.extend_from_slice(hostname_bytes); + + // Set server name list length + let server_name_list_len = sni_extension.len() - server_name_list_start - 2; + let pos = server_name_list_start; + sni_extension[pos] = (server_name_list_len >> 8) as u8; + sni_extension[pos + 1] = (server_name_list_len & 0xFF) as u8; + + // Set extension length + let ext_len = sni_extension.len() - ext_length_pos - 2; + sni_extension[ext_length_pos] = (ext_len >> 8) as u8; + sni_extension[ext_length_pos + 1] = (ext_len & 0xFF) as u8; + + // Add SNI extension to hello + hello.extend_from_slice(&sni_extension); + + // Set extensions total length + let extensions_length = hello.len() - extensions_start; + hello.insert(extensions_start, (extensions_length & 0xFF) as u8); + hello.insert(extensions_start, (extensions_length >> 8) as u8); + + // Set handshake message length + let handshake_len = hello.len() - handshake_length_pos - 3; + hello[handshake_length_pos] = (handshake_len >> 16) as u8; + hello[handshake_length_pos + 1] = (handshake_len >> 8) as u8; + hello[handshake_length_pos + 2] = (handshake_len & 0xFF) as u8; + + // Set record length + let record_len = hello.len() - record_length_pos - 2; + hello[record_length_pos] = (record_len >> 8) as u8; + hello[record_length_pos + 1] = (record_len & 0xFF) as u8; + + hello + } + #[tokio::test] async fn test_proxy() { use crate::config::ConfigV1; @@ -170,4 +297,118 @@ mod tests { } conn.shutdown().await.unwrap(); } + + #[tokio::test] + async fn test_wildcard_sni_routing() { + // Create test configuration with wildcard SNI pattern + use crate::upstreams::Upstream; + use std::collections::HashMap; + + // Start mock servers for upstreams + thread::spawn(move || { + tls_mock_server_wildcard(); + }); + thread::spawn(move || { + tls_mock_server_default(); + }); + sleep(Duration::from_millis(500)); // wait for mock servers to start + + // Create inline configuration + let mut config = crate::config::ParsedConfigV1 { + version: 1, + log: Some("disable".to_string()), + servers: HashMap::new(), + upstream: HashMap::new(), + }; + + // Add upstreams + config.upstream.insert( + "wildcard_upstream".to_string(), + Upstream::Proxy(crate::upstreams::ProxyToUpstream::new( + "127.0.0.1:54598".to_string(), + "tcp".to_string(), + )), + ); + config.upstream.insert( + "default_upstream".to_string(), + Upstream::Proxy(crate::upstreams::ProxyToUpstream::new( + "127.0.0.1:54597".to_string(), + "tcp".to_string(), + )), + ); + + // Add TLS server with wildcard SNI pattern + let mut sni_map = HashMap::new(); + sni_map.insert("*.api.example.com".to_string(), "wildcard_upstream".to_string()); + + let server_config = crate::config::ParsedServerConfig { + listen: vec!["127.0.0.1:54595".to_string()], + protocol: Some("tcp".to_string()), + tls: Some(true), + sni: Some(crate::sni_matcher::SniMatcher::new(sni_map).unwrap()), + default: Some("default_upstream".to_string()), + }; + + config.servers.insert("wildcard_test_server".to_string(), server_config); + + // Start proxy server + let mut server = Server::new_from_v1_config(config); + thread::spawn(move || { + let _ = server.run(); + }); + sleep(Duration::from_secs(1)); // wait for proxy to start + + // Test 1: Send ClientHello with SNI matching wildcard pattern + // Expected: Should route to wildcard_upstream (127.0.0.1:54598) + let client_hello = build_tls_client_hello("app.api.example.com"); + let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595") + .await + .unwrap(); + + let _ = conn.write(&client_hello).await.unwrap(); + let mut response = [0u8; 21]; + let n = conn.read(&mut response).await.unwrap(); + assert!(n > 0, "Should receive response from wildcard upstream"); + assert_eq!( + &response[..n], + b"tls_wildcard_response", + "Should receive expected response from wildcard upstream" + ); + conn.shutdown().await.unwrap(); + + // Test 2: Send ClientHello with SNI not matching any pattern + // Expected: Should route to default_upstream (127.0.0.1:54597) + let client_hello_nomatch = build_tls_client_hello("unrelated.example.com"); + let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595") + .await + .unwrap(); + + let _ = conn.write(&client_hello_nomatch).await.unwrap(); + let mut response = [0u8; 20]; + let n = conn.read(&mut response).await.unwrap(); + assert!(n > 0, "Should receive response from default upstream"); + assert_eq!( + &response[..n], + b"tls_default_response", + "Should receive expected response from default upstream" + ); + conn.shutdown().await.unwrap(); + + // Test 3: Send ClientHello with another SNI matching wildcard pattern + let client_hello_match2 = build_tls_client_hello("v2.api.example.com"); + let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595") + .await + .unwrap(); + + let _ = conn.write(&client_hello_match2).await.unwrap(); + let mut response = [0u8; 21]; + let n = conn.read(&mut response).await.unwrap(); + assert!(n > 0, "Should receive response from wildcard upstream for second match"); + assert_eq!( + &response[..n], + b"tls_wildcard_response", + "Should receive expected response from wildcard upstream for second match" + ); + conn.shutdown().await.unwrap(); + } } diff --git a/src/servers/protocol/tls.rs b/src/servers/protocol/tls.rs index 20beef2..0797128 100644 --- a/src/servers/protocol/tls.rs +++ b/src/servers/protocol/tls.rs @@ -206,36 +206,26 @@ pub(crate) async fn determine_upstream_name( if snis.is_empty() { debug!("No SNI found in ClientHello, using default upstream."); return Ok(default_upstream); - } else { - match proxy.sni.clone() { - Some(sni_map) => { - let mut upstream = default_upstream.clone(); // Clone here for default case - let mut found_match = false; - for sni in snis { - // snis is already Vec - if let Some(target_upstream) = sni_map.get(&sni) { - debug!( - "Found matching SNI '{}', routing to upstream: {}", - sni, target_upstream - ); - upstream = target_upstream.clone(); - found_match = true; - break; - } else { - trace!("SNI '{}' not found in map.", sni); - } - } - if !found_match { - debug!("SNI(s) found but none matched configuration, using default upstream."); - } - Ok(upstream) - } - None => { - debug!("SNI found but no SNI map configured, using default upstream."); - Ok(default_upstream) + } + + if let Some(matcher) = &proxy.sni { + for sni in snis { + if let Some(upstream) = matcher.match_sni(&sni) { + debug!( + "Found matching SNI '{}', routing to upstream: {}", + sni, upstream + ); + return Ok(upstream); + } else { + trace!("SNI '{}' not found in matcher.", sni); } } + debug!("SNI(s) found but none matched configuration, using default upstream."); + } else { + debug!("SNI found but no SNI matcher configured, using default upstream."); } + + Ok(default_upstream) } fn client_hello_buffer_size(data: &[u8]) -> Result { diff --git a/src/sni_matcher.rs b/src/sni_matcher.rs new file mode 100644 index 0000000..aaa22e8 --- /dev/null +++ b/src/sni_matcher.rs @@ -0,0 +1,391 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone)] +pub struct SniPattern { + pub pattern: String, + pub upstream: String, +} + +/// Validates and matches SNI patterns against incoming SNI values. Supports exact matches and wildcard patterns. +/// +/// Rules: +/// - Wildcard patterns must start with "*." and have a valid domain suffix after it (e.g., "*.example.com") +/// - Wildcard patterns cannot have a wildcard in the middle (e.g., "*.example.*" is invalid) +/// - Wildcard patterns cannot have multiple wildcards (e.g., "*.*.example.com" is invalid) +/// - When there are two wildcard patterns that could match the same SNI, the longest suffix wins +/// (e.g., "*.example.com" vs "*.api.example.com" - "v2.api.example.com" matches "*.api.example.com") +/// - Wildcard patterns that overlap are not matched (e.g. "*.example.com" and "*.bar.example.com" - "bar.example.com" matches neither, "v2.bar.example.com" matches "*.bar.example.com") +/// - Wildcard patterns cannot be just "*." or "*" (must have a valid suffix) +/// - For known public suffixes (e.g., "com", "org"), the wildcard must have at least one label below the public suffix (e.g., "*.example.com" is valid, but "*.com", or *.co.uk is invalid) +/// - For unknown suffixes (e.g., "local", "lan"), the wildcard is allowed without restriction (e.g., "*.local" is valid) +#[derive(Debug, Clone)] +pub struct SniMatcher { + pub exact: HashMap, + pub wildcards: Vec, +} + +impl SniMatcher { + pub fn new(sni_map: HashMap) -> Result> { + Self::validate(&sni_map)?; + + let mut exact = HashMap::new(); + let mut wildcards = Vec::new(); + + for (pattern, upstream) in sni_map { + if pattern.starts_with("*.") { + wildcards.push(SniPattern { + pattern: pattern.clone(), + upstream, + }); + } else { + exact.insert(pattern, upstream); + } + } + + wildcards.sort_by(|a, b| { + let a_suffix = a.pattern.trim_start_matches("*."); + let b_suffix = b.pattern.trim_start_matches("*."); + b_suffix.len().cmp(&a_suffix.len()) + }); + + Ok(SniMatcher { exact, wildcards }) + } + + /// Matches the provided SNI against the patterns in the matcher. Returns Some(upstream) if a match is found, or None if no match is found. + pub fn match_sni(&self, sni: &str) -> Option { + if let Some(upstream) = self.exact.get(sni) { + return Some(upstream.clone()); + } + + // Try each wildcard in order (longest suffix first) + for wildcard in &self.wildcards { + let suffix = wildcard.pattern.trim_start_matches("*."); + let suffix_len = suffix.len(); + let check = format!(".{}", suffix); + + if !sni.ends_with(&check) { + continue; + } + + // Must have at least one label before the suffix to match + let prefix = &sni[..sni.len() - suffix_len - 1]; + + // Check if a more specific wildcard could also match this SNI + let is_owned = { + let sni_labels = sni.matches('.').count() + 1; + + self.wildcards.iter().any(|w| { + // Skip the current wildcard itself + if w.pattern == wildcard.pattern { + return false; + } + + let w_suffix = w.pattern.trim_start_matches("*."); + let w_len = w_suffix.len(); + + // Only care about wildcards with longer suffix (more specific) + if w_len <= suffix_len { + return false; + } + + let w_suffix_labels = w_suffix.matches('.').count() + 1; + + if sni == w_suffix { + // Exact match to more specific suffix - owned if could potentially match + // (sni has at least as many labels as the suffix needs) + sni_labels >= w_suffix_labels + } else if sni.ends_with(&format!(".{}", w_suffix)) { + // Ends with more specific suffix - owned if SNI has enough labels + sni_labels >= w_suffix_labels + 1 + } else { + false + } + }) + }; + + if is_owned { + continue; + } + + // Only return if we have a valid prefix (at least one label) + if !prefix.is_empty() { + return Some(wildcard.upstream.clone()); + } + } + + None + } + + fn validate_wildcard_suffix(pattern: &str) -> Result<(), String> { + let suffix = pattern.trim_start_matches("*."); + let domain_str = format!("a.{}", suffix); + + if let Some(ps) = psl::suffix(domain_str.as_bytes()) { + let ps_str = std::str::from_utf8(ps.as_bytes()).unwrap_or(""); + if ps_str == suffix { + return Err(format!( + "Invalid wildcard pattern: {} - wildcard cannot be at the public suffix level", + pattern + )); + } + } + + Ok(()) + } + + fn validate(sni_map: &HashMap) -> Result<(), Vec> { + let mut errors = Vec::new(); + + for (pattern, _upstream) in sni_map { + if pattern == "*" { + errors.push(format!( + "Invalid wildcard pattern: * - just asterisk is not allowed" + )); + continue; + } + + if pattern == "*." { + errors.push(format!( + "Invalid wildcard pattern: *. - empty suffix after wildcard" + )); + continue; + } + + if let Some(rest) = pattern.strip_prefix("*.") { + if rest.is_empty() { + errors.push(format!( + "Invalid wildcard pattern: {pattern} - empty suffix after wildcard" + )); + continue; + } + + if rest.contains('*') { + errors.push(format!("Invalid wildcard pattern: {pattern} - wildcard cannot be in the middle of suffix")); + continue; + } + + if let Err(e) = Self::validate_wildcard_suffix(pattern) { + errors.push(e); + } + } else if pattern.contains('*') { + errors.push(format!( + "Invalid wildcard pattern: {pattern} - wildcard must be at the start" + )); + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_exact_match() { + let mut sni_map = HashMap::new(); + sni_map.insert("example.com".to_string(), "upstream1".to_string()); + sni_map.insert("*.example.com".to_string(), "upstream2".to_string()); + + let matcher = SniMatcher::new(sni_map).unwrap(); + + assert_eq!( + matcher.match_sni("example.com"), + Some("upstream1".to_string()) + ); + } + + #[test] + fn test_wildcard_match() { + let mut sni_map = HashMap::new(); + sni_map.insert("example.com".to_string(), "upstream1".to_string()); + sni_map.insert("*.example.com".to_string(), "upstream2".to_string()); + + let matcher = SniMatcher::new(sni_map).unwrap(); + + assert_eq!( + matcher.match_sni("www.example.com"), + Some("upstream2".to_string()) + ); + assert_eq!( + matcher.match_sni("api.example.com"), + Some("upstream2".to_string()) + ); + } + + #[test] + fn test_longest_suffix_match() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.example.com".to_string(), "wildcard1".to_string()); + sni_map.insert("*.api.example.com".to_string(), "wildcard2".to_string()); + + let matcher = SniMatcher::new(sni_map).unwrap(); + + assert_eq!( + matcher.match_sni("v2.api.example.com"), + Some("wildcard2".to_string()) + ); + } + + #[test] + fn test_no_match() { + let mut sni_map = HashMap::new(); + sni_map.insert("example.com".to_string(), "upstream1".to_string()); + + let matcher = SniMatcher::new(sni_map).unwrap(); + + assert_eq!(matcher.match_sni("other.com"), None); + } + + #[test] + fn test_unknown_suffix_allowed() { + let mut sni_map = HashMap::new(); + sni_map.insert( + "*.private.local".to_string(), + "private_upstream".to_string(), + ); + sni_map.insert( + "*.internal.net".to_string(), + "internal_upstream".to_string(), + ); + + let matcher = SniMatcher::new(sni_map).unwrap(); + + assert_eq!( + matcher.match_sni("server.private.local"), + Some("private_upstream".to_string()) + ); + assert_eq!( + matcher.match_sni("app.internal.net"), + Some("internal_upstream".to_string()) + ); + } + + #[test] + fn test_invalid_public_suffix() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.com".to_string(), "invalid".to_string()); + + let result = SniMatcher::new(sni_map); + assert!(result.is_err()); + + let errors = result.unwrap_err(); + assert!(!errors.is_empty()); + assert!(errors[0].contains("*.com")); + } + + #[test] + fn test_multiple_errors_collected() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.com".to_string(), "invalid1".to_string()); + sni_map.insert("*.org".to_string(), "invalid2".to_string()); + sni_map.insert("*.net".to_string(), "invalid3".to_string()); + + let result = SniMatcher::new(sni_map); + assert!(result.is_err()); + + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 3); + } + + #[test] + fn test_valid_public_suffix() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.example.com".to_string(), "valid".to_string()); + + let matcher = SniMatcher::new(sni_map).unwrap(); + assert_eq!( + matcher.match_sni("www.example.com"), + Some("valid".to_string()) + ); + } + + #[test] + fn test_validate_static() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.com".to_string(), "invalid".to_string()); + sni_map.insert("*.example.com".to_string(), "valid".to_string()); + + let result = SniMatcher::new(sni_map); + assert!(result.is_err()); + + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].contains("*.com")); + } + + #[test] + fn test_wildcard_not_at_start_rejected() { + let mut sni_map = HashMap::new(); + sni_map.insert("foo*.example.com".to_string(), "invalid".to_string()); + let result = SniMatcher::new(sni_map); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert!(!errors.is_empty()); + assert!(errors[0].contains("*.example.com")); + } + + #[test] + fn test_wildcard_in_middle_rejected() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.example.*".to_string(), "invalid".to_string()); + let result = SniMatcher::new(sni_map); + assert!(result.is_err()); + } + + #[test] + fn test_trailing_dot_rejected() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.".to_string(), "invalid".to_string()); + let result = SniMatcher::new(sni_map); + assert!(result.is_err()); + } + + #[test] + fn test_just_asterisk_rejected() { + let mut sni_map = HashMap::new(); + sni_map.insert("*".to_string(), "invalid".to_string()); + let result = SniMatcher::new(sni_map); + assert!(result.is_err()); + } + + #[test] + fn test_wildcard_requires_subdomain() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.example.com".to_string(), "upstream".to_string()); + let matcher = SniMatcher::new(sni_map).unwrap(); + + assert_eq!(matcher.match_sni("example.com"), None); + assert_eq!( + matcher.match_sni("www.example.com"), + Some("upstream".to_string()) + ); + assert_eq!( + matcher.match_sni("foo.bar.example.com"), + Some("upstream".to_string()) + ); + } + + #[test] + fn test_longest_suffix_wins_not_shortest() { + let mut sni_map = HashMap::new(); + sni_map.insert("*.example.org".to_string(), "broad".to_string()); + sni_map.insert("*.bar.example.org".to_string(), "narrow".to_string()); + let matcher = SniMatcher::new(sni_map).unwrap(); + + assert_eq!(matcher.match_sni("bar.example.org"), None); + assert_eq!( + matcher.match_sni("v2.bar.example.org"), + Some("narrow".to_string()) + ); + assert_eq!( + matcher.match_sni("v2.example.org"), + Some("broad".to_string()) + ); + } +} -- 2.43.0