Add wildcard SNI matching
Some checks failed
ci/woodpecker/push/build/1 Pipeline was canceled
ci/woodpecker/push/build/3 Pipeline was canceled
ci/woodpecker/push/build/2 Pipeline was canceled
ci/woodpecker/tag/build/2 Pipeline is pending
ci/woodpecker/tag/build/3 Pipeline is pending
ci/woodpecker/tag/build/1 Pipeline was canceled

This commit was merged in pull request #13.
This commit is contained in:
2026-04-03 00:31:05 +02:00
parent a674895173
commit 590740f40e
10 changed files with 837 additions and 52 deletions

18
Cargo.lock generated
View File

@@ -516,7 +516,7 @@ dependencies = [
[[package]] [[package]]
name = "l4p" name = "l4p"
version = "0.1.10" version = "0.1.11"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"byte_string", "byte_string",
@@ -525,6 +525,7 @@ dependencies = [
"log", "log",
"pico-args", "pico-args",
"pretty_env_logger", "pretty_env_logger",
"psl",
"self_update", "self_update",
"serde", "serde",
"serde_yaml", "serde_yaml",
@@ -856,6 +857,21 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "psl"
version = "2.1.199"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70b63978a2742d3f662188698ab45854156e7e34658f53fa951e9253a3dfd583"
dependencies = [
"psl-types",
]
[[package]]
name = "psl-types"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac"
[[package]] [[package]]
name = "quick-xml" name = "quick-xml"
version = "0.37.2" version = "0.37.2"

View File

@@ -33,6 +33,7 @@ time = { version = "0.3.37", features = ["local-offset", "formatting"] }
tls-parser = "0.12.2" tls-parser = "0.12.2"
tokio = { version = "1.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] }
url = "2.2.2" url = "2.2.2"
psl = "2.1"
[dependencies.self_update] [dependencies.self_update]
version = "0.42.0" version = "0.42.0"

View File

@@ -31,6 +31,13 @@ $ cargo install l4p
Or you can download binary file form the Release page. Or you can download binary file form the Release page.
## Features
- Listen on specific port and proxy to local or remote port
- SNI-based rule without terminating TLS connection
- Wildcard SNI matching with DNS-style longest-suffix-match
- DNS-based backend with periodic resolution
## Configuration ## Configuration
`l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example: `l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example:
@@ -55,6 +62,14 @@ There are two upstreams built in:
For detailed configuration, check [this example](./config.yaml.example). For detailed configuration, check [this example](./config.yaml.example).
### SNI Matching
The proxy supports both exact and wildcard SNI patterns in the `sni` config. Wildcards use DNS-style longest-suffix-match: more specific patterns take precedence. For example, with `*.example.com` and `*.api.example.com`, request `api.example.com` matches the first, while `v2.api.example.com` matches the second.
Wildcards are validated against the Public Suffix List (PSL). Known suffixes (`.com`, `.org`) require at least one label below the suffix (`*.example.com` OK, `*.com` rejected). Unknown suffixes (`.local`, `.lan`) are allowed without restriction.
Invalid wildcard patterns are rejected at config load time with clear error messages.
## Thanks ## Thanks
- [`fourth`](https://crates.io/crates/fourth), of which this is a heavily modified fork. - [`fourth`](https://crates.io/crates/fourth), of which this is a heavily modified fork.

View File

@@ -10,6 +10,9 @@ servers:
sni: sni:
api.example.org: example-api api.example.org: example-api
www.example.org: proxy www.example.org: proxy
*.example.org: wildcard-proxy # Matches any subdomain of example.org
*.dev.example.org: dev-proxy # More specific: matches v2.dev.example.org, etc.
*.local: local-upstream # Unknown suffix - allowed (no PSL restriction)
default: ban default: ban
second-server: second-server:
@@ -19,3 +22,6 @@ servers:
upstream: upstream:
proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443 proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443
example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443 example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443
wildcard-proxy: "tcp://wildcard.example.org:443"
dev-proxy: "tcp://dev.example.org:443"
local-upstream: "tcp://localhost:8080"

View File

@@ -1,3 +1,4 @@
use crate::sni_matcher::SniMatcher;
use crate::upstreams::ProxyToUpstream; use crate::upstreams::ProxyToUpstream;
use crate::upstreams::Upstream; use crate::upstreams::Upstream;
use log::{debug, info, warn}; use log::{debug, info, warn};
@@ -12,14 +13,6 @@ pub struct ConfigV1 {
pub base: ParsedConfigV1, pub base: ParsedConfigV1,
} }
#[derive(Debug, Default, Deserialize, Clone)]
pub struct ParsedConfigV1 {
pub version: i32,
pub log: Option<String>,
pub servers: HashMap<String, ServerConfig>,
pub upstream: HashMap<String, Upstream>,
}
#[derive(Debug, Default, Deserialize, Clone)] #[derive(Debug, Default, Deserialize, Clone)]
pub struct BaseConfig { pub struct BaseConfig {
pub version: i32, pub version: i32,
@@ -28,6 +21,14 @@ pub struct BaseConfig {
pub upstream: HashMap<String, String>, pub upstream: HashMap<String, String>,
} }
#[derive(Debug, Default, Clone)]
pub struct ParsedConfigV1 {
pub version: i32,
pub log: Option<String>,
pub servers: HashMap<String, ParsedServerConfig>,
pub upstream: HashMap<String, Upstream>,
}
#[derive(Debug, Default, Deserialize, Clone)] #[derive(Debug, Default, Deserialize, Clone)]
pub struct ServerConfig { pub struct ServerConfig {
pub listen: Vec<String>, pub listen: Vec<String>,
@@ -36,6 +37,35 @@ pub struct ServerConfig {
pub sni: Option<HashMap<String, String>>, pub sni: Option<HashMap<String, String>>,
pub default: Option<String>, pub default: Option<String>,
} }
impl ServerConfig {
pub fn into_parsed(self) -> Result<ParsedServerConfig, Vec<String>> {
let sni = match self.sni {
Some(sni_map) => {
let matcher = SniMatcher::new(sni_map)?;
Some(matcher)
}
None => None,
};
Ok(ParsedServerConfig {
listen: self.listen,
protocol: self.protocol,
tls: self.tls,
sni,
default: self.default,
})
}
}
#[derive(Debug, Clone)]
pub struct ParsedServerConfig {
pub listen: Vec<String>,
pub protocol: Option<String>,
pub tls: Option<bool>,
pub sni: Option<SniMatcher>,
pub default: Option<String>,
}
impl TryInto<ProxyToUpstream> for &str { impl TryInto<ProxyToUpstream> for &str {
type Error = ConfigError; type Error = ConfigError;
@@ -102,12 +132,23 @@ impl ConfigV1 {
} }
} }
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> { /// Load and parse configuration from a YAML string.
let mut contents = String::new(); ///
let mut file = File::open(path)?; /// This public function takes raw YAML content as a string and returns a parsed,
file.read_to_string(&mut contents)?; /// validated configuration. It performs all validation including:
/// - Version checking
let base: BaseConfig = serde_yaml::from_str(&contents)?; /// - SNI pattern validation
/// - Upstream URL parsing
/// - Cross-reference validation
///
/// # Arguments
/// * `yaml_str` - The YAML configuration content as a string
///
/// # Returns
/// * `Ok(ParsedConfigV1)` - Successfully parsed and validated configuration
/// * `Err(ConfigError)` - If YAML parsing fails or validation errors occur
pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigError> {
let base: BaseConfig = serde_yaml::from_str(yaml_str)?;
if base.version != 1 { if base.version != 1 {
return Err(ConfigError::Custom( return Err(ConfigError::Custom(
@@ -117,11 +158,12 @@ fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
let log_level = base.log.clone().unwrap_or_else(|| "info".to_string()); let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
if !log_level.eq("disable") { if !log_level.eq("disable") {
unsafe {
std::env::set_var("FOURTH_LOG", log_level.clone()); std::env::set_var("FOURTH_LOG", log_level.clone());
pretty_env_logger::init_custom_env("FOURTH_LOG"); pretty_env_logger::init_custom_env("FOURTH_LOG");
} }
}
info!("Using config file: {}", &path);
debug!("Set log level to {}", log_level); debug!("Set log level to {}", log_level);
debug!("Config version {}", base.version); debug!("Config version {}", base.version);
@@ -135,16 +177,50 @@ fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups)); parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups));
} }
// Convert ServerConfig to ParsedServerConfig, collecting all SNI validation errors
let mut all_errors: Vec<String> = Vec::new();
let mut parsed_servers: HashMap<String, ParsedServerConfig> = HashMap::new();
for (name, server_config) in base.servers {
match server_config.into_parsed() {
Ok(parsed) => {
parsed_servers.insert(name, parsed);
}
Err(errors) => {
for err in errors {
all_errors.push(format!("Server '{}': {}", name, err));
}
}
}
}
if !all_errors.is_empty() {
return Err(ConfigError::Custom(format!(
"Invalid SNI configuration:\n{}",
all_errors.join("\n")
)));
}
let parsed = ParsedConfigV1 { let parsed = ParsedConfigV1 {
version: base.version, version: base.version,
log: base.log, log: base.log,
servers: base.servers, servers: parsed_servers,
upstream: parsed_upstream, upstream: parsed_upstream,
}; };
verify_config(parsed) verify_config(parsed)
} }
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
let mut contents = String::new();
let mut file = File::open(path)?;
file.read_to_string(&mut contents)?;
info!("Using config file: {}", &path);
load_config_from_yaml(&contents)
}
fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError> { fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError> {
let mut used_upstreams: HashSet<String> = HashSet::new(); let mut used_upstreams: HashSet<String> = HashSet::new();
let mut upstream_names: HashSet<String> = HashSet::new(); let mut upstream_names: HashSet<String> = HashSet::new();
@@ -175,14 +251,20 @@ fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError>
listen_addresses.insert(listen.to_string()); listen_addresses.insert(listen.to_string());
} }
if server.tls.unwrap_or_default() && server.sni.is_some() { if server.tls.unwrap_or_default() {
for (_, val) in server.sni.unwrap() { if let Some(matcher) = &server.sni {
used_upstreams.insert(val.to_string()); // Collect all upstream names from the SniMatcher
for (_, upstream) in matcher.exact.iter() {
used_upstreams.insert(upstream.clone());
}
for pattern in &matcher.wildcards {
used_upstreams.insert(pattern.upstream.clone());
}
} }
} }
if server.default.is_some() { if let Some(default) = &server.default {
used_upstreams.insert(server.default.unwrap().to_string()); used_upstreams.insert(default.clone());
} }
for key in &used_upstreams { for key in &used_upstreams {
@@ -225,4 +307,45 @@ mod tests {
assert_eq!(config.base.servers.len(), 3); assert_eq!(config.base.servers.len(), 3);
assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
} }
#[test]
fn test_config_hard_failure_on_invalid_sni() {
// Test that invalid SNI wildcard (*.com) causes hard failure
let config_content = r#"version: 1
log: disable
servers:
test_server:
listen:
- "127.0.0.1:8443"
protocol: tcp
tls: true
sni:
"*.com": "upstream1"
default: ban
upstream:
upstream1: tcp://127.0.0.1:9000
"#;
let result = load_config_from_yaml(config_content);
// Should fail with an error
assert!(result.is_err(), "Expected config to fail with invalid SNI");
// Verify error message contains helpful information
match result {
Err(ConfigError::Custom(msg)) => {
assert!(
msg.contains("Invalid SNI"),
"Error message should mention invalid SNI: {}",
msg
);
assert!(
msg.contains("*.com"),
"Error message should mention the invalid pattern: {}",
msg
);
}
_ => panic!("Expected ConfigError::Custom"),
}
}
} }

View File

@@ -1,3 +1,4 @@
mod config_v1; mod config_v1;
pub(crate) use config_v1::ConfigV1; pub(crate) use config_v1::ConfigV1;
pub(crate) use config_v1::ParsedConfigV1; pub(crate) use config_v1::ParsedConfigV1;
pub(crate) use config_v1::ParsedServerConfig;

View File

@@ -1,5 +1,6 @@
mod config; mod config;
mod servers; mod servers;
mod sni_matcher;
mod update; mod update;
mod upstreams; mod upstreams;

View File

@@ -9,6 +9,7 @@ mod protocol;
pub(crate) mod upstream_address; pub(crate) mod upstream_address;
use crate::config::ParsedConfigV1; use crate::config::ParsedConfigV1;
use crate::sni_matcher::SniMatcher;
use crate::upstreams::Upstream; use crate::upstreams::Upstream;
use protocol::tcp; use protocol::tcp;
@@ -23,7 +24,7 @@ pub(crate) struct Proxy {
pub listen: SocketAddr, pub listen: SocketAddr,
pub protocol: String, pub protocol: String,
pub tls: bool, pub tls: bool,
pub sni: Option<HashMap<String, String>>, pub sni: Option<SniMatcher>,
pub default_action: String, pub default_action: String,
pub upstream: HashMap<String, Upstream>, pub upstream: HashMap<String, Upstream>,
} }
@@ -134,6 +135,132 @@ mod tests {
} }
} }
/// Mock server for wildcard SNI test that responds with "tls_wildcard_response" on first read
#[tokio::main]
async fn tls_mock_server_wildcard() {
let server_addr: SocketAddr = "127.0.0.1:54598".parse().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
loop {
let (mut stream, _) = listener.accept().await.unwrap();
// Read client hello (which will be peeked but not actually read by proxy)
let mut buf = [0u8; 1024];
let _ = stream.read(&mut buf).await;
// Send a response to verify connection succeeded
let _ = stream.write(b"tls_wildcard_response").await;
let _ = stream.shutdown().await;
}
}
/// Mock server for SNI test that doesn't match wildcard pattern
#[tokio::main]
async fn tls_mock_server_default() {
let server_addr: SocketAddr = "127.0.0.1:54597".parse().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
loop {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 1024];
let _ = stream.read(&mut buf).await;
let _ = stream.write(b"tls_default_response").await;
let _ = stream.shutdown().await;
}
}
/// Helper function to build a minimal TLS ClientHello with SNI extension
/// This creates a valid TLS 1.2 ClientHello packet with the specified SNI hostname
fn build_tls_client_hello(sni_hostname: &str) -> Vec<u8> {
// TLS record header (9 bytes)
let mut hello = Vec::new();
// Record type: Handshake (0x16)
hello.push(0x16);
// Version: TLS 1.2 (0x0303)
hello.extend_from_slice(&[0x03, 0x03]);
// We'll set the record length later
let record_length_pos = hello.len();
hello.extend_from_slice(&[0x00, 0x00]); // Placeholder for record length
// Handshake message type: ClientHello (0x01)
hello.push(0x01);
// We'll set the handshake length later
let handshake_length_pos = hello.len();
hello.extend_from_slice(&[0x00, 0x00, 0x00]); // Placeholder for handshake length
// ClientHello fields
// Protocol version: TLS 1.2 (0x0303)
hello.extend_from_slice(&[0x03, 0x03]);
// Random: 32 bytes (we'll use a fixed pattern)
hello.extend_from_slice(&[0x00; 32]);
// Session ID length: 0 (no session)
hello.push(0x00);
// Cipher suites length: 2 bytes + cipher suites
hello.extend_from_slice(&[0x00, 0x02]); // Length of cipher suites list (2 bytes)
// Cipher suite: TLS_RSA_WITH_AES_128_CBC_SHA (0x002F)
hello.extend_from_slice(&[0x00, 0x2F]);
// Compression methods length: 1 byte
hello.push(0x01);
// Compression method: null (0x00)
hello.push(0x00);
// Extensions
let extensions_start = hello.len();
// SNI Extension (type 0x0000)
let mut sni_extension = Vec::new();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Extension type: server_name
// Extension length (will be set later)
let ext_length_pos = sni_extension.len();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder
// Server name list
let server_name_list_start = sni_extension.len();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder for server name list length
// Server name: host_name(0), length, hostname
sni_extension.push(0x00); // Name type: host_name
let hostname_bytes = sni_hostname.as_bytes();
sni_extension.extend_from_slice(&[(hostname_bytes.len() >> 8) as u8, (hostname_bytes.len() & 0xFF) as u8]);
sni_extension.extend_from_slice(hostname_bytes);
// Set server name list length
let server_name_list_len = sni_extension.len() - server_name_list_start - 2;
let pos = server_name_list_start;
sni_extension[pos] = (server_name_list_len >> 8) as u8;
sni_extension[pos + 1] = (server_name_list_len & 0xFF) as u8;
// Set extension length
let ext_len = sni_extension.len() - ext_length_pos - 2;
sni_extension[ext_length_pos] = (ext_len >> 8) as u8;
sni_extension[ext_length_pos + 1] = (ext_len & 0xFF) as u8;
// Add SNI extension to hello
hello.extend_from_slice(&sni_extension);
// Set extensions total length
let extensions_length = hello.len() - extensions_start;
hello.insert(extensions_start, (extensions_length & 0xFF) as u8);
hello.insert(extensions_start, (extensions_length >> 8) as u8);
// Set handshake message length
let handshake_len = hello.len() - handshake_length_pos - 3;
hello[handshake_length_pos] = (handshake_len >> 16) as u8;
hello[handshake_length_pos + 1] = (handshake_len >> 8) as u8;
hello[handshake_length_pos + 2] = (handshake_len & 0xFF) as u8;
// Set record length
let record_len = hello.len() - record_length_pos - 2;
hello[record_length_pos] = (record_len >> 8) as u8;
hello[record_length_pos + 1] = (record_len & 0xFF) as u8;
hello
}
#[tokio::test] #[tokio::test]
async fn test_proxy() { async fn test_proxy() {
use crate::config::ConfigV1; use crate::config::ConfigV1;
@@ -170,4 +297,118 @@ mod tests {
} }
conn.shutdown().await.unwrap(); conn.shutdown().await.unwrap();
} }
#[tokio::test]
async fn test_wildcard_sni_routing() {
// Create test configuration with wildcard SNI pattern
use crate::upstreams::Upstream;
use std::collections::HashMap;
// Start mock servers for upstreams
thread::spawn(move || {
tls_mock_server_wildcard();
});
thread::spawn(move || {
tls_mock_server_default();
});
sleep(Duration::from_millis(500)); // wait for mock servers to start
// Create inline configuration
let mut config = crate::config::ParsedConfigV1 {
version: 1,
log: Some("disable".to_string()),
servers: HashMap::new(),
upstream: HashMap::new(),
};
// Add upstreams
config.upstream.insert(
"wildcard_upstream".to_string(),
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
"127.0.0.1:54598".to_string(),
"tcp".to_string(),
)),
);
config.upstream.insert(
"default_upstream".to_string(),
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
"127.0.0.1:54597".to_string(),
"tcp".to_string(),
)),
);
// Add TLS server with wildcard SNI pattern
let mut sni_map = HashMap::new();
sni_map.insert("*.api.example.com".to_string(), "wildcard_upstream".to_string());
let server_config = crate::config::ParsedServerConfig {
listen: vec!["127.0.0.1:54595".to_string()],
protocol: Some("tcp".to_string()),
tls: Some(true),
sni: Some(crate::sni_matcher::SniMatcher::new(sni_map).unwrap()),
default: Some("default_upstream".to_string()),
};
config.servers.insert("wildcard_test_server".to_string(), server_config);
// Start proxy server
let mut server = Server::new_from_v1_config(config);
thread::spawn(move || {
let _ = server.run();
});
sleep(Duration::from_secs(1)); // wait for proxy to start
// Test 1: Send ClientHello with SNI matching wildcard pattern
// Expected: Should route to wildcard_upstream (127.0.0.1:54598)
let client_hello = build_tls_client_hello("app.api.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello).await.unwrap();
let mut response = [0u8; 21];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from wildcard upstream");
assert_eq!(
&response[..n],
b"tls_wildcard_response",
"Should receive expected response from wildcard upstream"
);
conn.shutdown().await.unwrap();
// Test 2: Send ClientHello with SNI not matching any pattern
// Expected: Should route to default_upstream (127.0.0.1:54597)
let client_hello_nomatch = build_tls_client_hello("unrelated.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello_nomatch).await.unwrap();
let mut response = [0u8; 20];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from default upstream");
assert_eq!(
&response[..n],
b"tls_default_response",
"Should receive expected response from default upstream"
);
conn.shutdown().await.unwrap();
// Test 3: Send ClientHello with another SNI matching wildcard pattern
let client_hello_match2 = build_tls_client_hello("v2.api.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello_match2).await.unwrap();
let mut response = [0u8; 21];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from wildcard upstream for second match");
assert_eq!(
&response[..n],
b"tls_wildcard_response",
"Should receive expected response from wildcard upstream for second match"
);
conn.shutdown().await.unwrap();
}
} }

View File

@@ -206,36 +206,26 @@ pub(crate) async fn determine_upstream_name(
if snis.is_empty() { if snis.is_empty() {
debug!("No SNI found in ClientHello, using default upstream."); debug!("No SNI found in ClientHello, using default upstream.");
return Ok(default_upstream); return Ok(default_upstream);
} else { }
match proxy.sni.clone() {
Some(sni_map) => { if let Some(matcher) = &proxy.sni {
let mut upstream = default_upstream.clone(); // Clone here for default case
let mut found_match = false;
for sni in snis { for sni in snis {
// snis is already Vec<String> if let Some(upstream) = matcher.match_sni(&sni) {
if let Some(target_upstream) = sni_map.get(&sni) {
debug!( debug!(
"Found matching SNI '{}', routing to upstream: {}", "Found matching SNI '{}', routing to upstream: {}",
sni, target_upstream sni, upstream
); );
upstream = target_upstream.clone(); return Ok(upstream);
found_match = true;
break;
} else { } else {
trace!("SNI '{}' not found in map.", sni); trace!("SNI '{}' not found in matcher.", sni);
} }
} }
if !found_match {
debug!("SNI(s) found but none matched configuration, using default upstream."); debug!("SNI(s) found but none matched configuration, using default upstream.");
} else {
debug!("SNI found but no SNI matcher configured, using default upstream.");
} }
Ok(upstream)
}
None => {
debug!("SNI found but no SNI map configured, using default upstream.");
Ok(default_upstream) Ok(default_upstream)
}
}
}
} }
fn client_hello_buffer_size(data: &[u8]) -> Result<usize, String> { fn client_hello_buffer_size(data: &[u8]) -> Result<usize, String> {

391
src/sni_matcher.rs Normal file
View File

@@ -0,0 +1,391 @@
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct SniPattern {
pub pattern: String,
pub upstream: String,
}
/// Validates and matches SNI patterns against incoming SNI values. Supports exact matches and wildcard patterns.
///
/// Rules:
/// - Wildcard patterns must start with "*." and have a valid domain suffix after it (e.g., "*.example.com")
/// - Wildcard patterns cannot have a wildcard in the middle (e.g., "*.example.*" is invalid)
/// - Wildcard patterns cannot have multiple wildcards (e.g., "*.*.example.com" is invalid)
/// - When there are two wildcard patterns that could match the same SNI, the longest suffix wins
/// (e.g., "*.example.com" vs "*.api.example.com" - "v2.api.example.com" matches "*.api.example.com")
/// - Wildcard patterns that overlap are not matched (e.g. "*.example.com" and "*.bar.example.com" - "bar.example.com" matches neither, "v2.bar.example.com" matches "*.bar.example.com")
/// - Wildcard patterns cannot be just "*." or "*" (must have a valid suffix)
/// - For known public suffixes (e.g., "com", "org"), the wildcard must have at least one label below the public suffix (e.g., "*.example.com" is valid, but "*.com", or *.co.uk is invalid)
/// - For unknown suffixes (e.g., "local", "lan"), the wildcard is allowed without restriction (e.g., "*.local" is valid)
#[derive(Debug, Clone)]
pub struct SniMatcher {
pub exact: HashMap<String, String>,
pub wildcards: Vec<SniPattern>,
}
impl SniMatcher {
pub fn new(sni_map: HashMap<String, String>) -> Result<Self, Vec<String>> {
Self::validate(&sni_map)?;
let mut exact = HashMap::new();
let mut wildcards = Vec::new();
for (pattern, upstream) in sni_map {
if pattern.starts_with("*.") {
wildcards.push(SniPattern {
pattern: pattern.clone(),
upstream,
});
} else {
exact.insert(pattern, upstream);
}
}
wildcards.sort_by(|a, b| {
let a_suffix = a.pattern.trim_start_matches("*.");
let b_suffix = b.pattern.trim_start_matches("*.");
b_suffix.len().cmp(&a_suffix.len())
});
Ok(SniMatcher { exact, wildcards })
}
/// Matches the provided SNI against the patterns in the matcher. Returns Some(upstream) if a match is found, or None if no match is found.
pub fn match_sni(&self, sni: &str) -> Option<String> {
if let Some(upstream) = self.exact.get(sni) {
return Some(upstream.clone());
}
// Try each wildcard in order (longest suffix first)
for wildcard in &self.wildcards {
let suffix = wildcard.pattern.trim_start_matches("*.");
let suffix_len = suffix.len();
let check = format!(".{}", suffix);
if !sni.ends_with(&check) {
continue;
}
// Must have at least one label before the suffix to match
let prefix = &sni[..sni.len() - suffix_len - 1];
// Check if a more specific wildcard could also match this SNI
let is_owned = {
let sni_labels = sni.matches('.').count() + 1;
self.wildcards.iter().any(|w| {
// Skip the current wildcard itself
if w.pattern == wildcard.pattern {
return false;
}
let w_suffix = w.pattern.trim_start_matches("*.");
let w_len = w_suffix.len();
// Only care about wildcards with longer suffix (more specific)
if w_len <= suffix_len {
return false;
}
let w_suffix_labels = w_suffix.matches('.').count() + 1;
if sni == w_suffix {
// Exact match to more specific suffix - owned if could potentially match
// (sni has at least as many labels as the suffix needs)
sni_labels >= w_suffix_labels
} else if sni.ends_with(&format!(".{}", w_suffix)) {
// Ends with more specific suffix - owned if SNI has enough labels
sni_labels >= w_suffix_labels + 1
} else {
false
}
})
};
if is_owned {
continue;
}
// Only return if we have a valid prefix (at least one label)
if !prefix.is_empty() {
return Some(wildcard.upstream.clone());
}
}
None
}
fn validate_wildcard_suffix(pattern: &str) -> Result<(), String> {
let suffix = pattern.trim_start_matches("*.");
let domain_str = format!("a.{}", suffix);
if let Some(ps) = psl::suffix(domain_str.as_bytes()) {
let ps_str = std::str::from_utf8(ps.as_bytes()).unwrap_or("");
if ps_str == suffix {
return Err(format!(
"Invalid wildcard pattern: {} - wildcard cannot be at the public suffix level",
pattern
));
}
}
Ok(())
}
fn validate(sni_map: &HashMap<String, String>) -> Result<(), Vec<String>> {
let mut errors = Vec::new();
for (pattern, _upstream) in sni_map {
if pattern == "*" {
errors.push(format!(
"Invalid wildcard pattern: * - just asterisk is not allowed"
));
continue;
}
if pattern == "*." {
errors.push(format!(
"Invalid wildcard pattern: *. - empty suffix after wildcard"
));
continue;
}
if let Some(rest) = pattern.strip_prefix("*.") {
if rest.is_empty() {
errors.push(format!(
"Invalid wildcard pattern: {pattern} - empty suffix after wildcard"
));
continue;
}
if rest.contains('*') {
errors.push(format!("Invalid wildcard pattern: {pattern} - wildcard cannot be in the middle of suffix"));
continue;
}
if let Err(e) = Self::validate_wildcard_suffix(pattern) {
errors.push(e);
}
} else if pattern.contains('*') {
errors.push(format!(
"Invalid wildcard pattern: {pattern} - wildcard must be at the start"
));
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("example.com"),
Some("upstream1".to_string())
);
}
#[test]
fn test_wildcard_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("www.example.com"),
Some("upstream2".to_string())
);
assert_eq!(
matcher.match_sni("api.example.com"),
Some("upstream2".to_string())
);
}
#[test]
fn test_longest_suffix_match() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "wildcard1".to_string());
sni_map.insert("*.api.example.com".to_string(), "wildcard2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("v2.api.example.com"),
Some("wildcard2".to_string())
);
}
#[test]
fn test_no_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("other.com"), None);
}
#[test]
fn test_unknown_suffix_allowed() {
let mut sni_map = HashMap::new();
sni_map.insert(
"*.private.local".to_string(),
"private_upstream".to_string(),
);
sni_map.insert(
"*.internal.net".to_string(),
"internal_upstream".to_string(),
);
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("server.private.local"),
Some("private_upstream".to_string())
);
assert_eq!(
matcher.match_sni("app.internal.net"),
Some("internal_upstream".to_string())
);
}
#[test]
fn test_invalid_public_suffix() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert!(!errors.is_empty());
assert!(errors[0].contains("*.com"));
}
#[test]
fn test_multiple_errors_collected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid1".to_string());
sni_map.insert("*.org".to_string(), "invalid2".to_string());
sni_map.insert("*.net".to_string(), "invalid3".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert_eq!(errors.len(), 3);
}
#[test]
fn test_valid_public_suffix() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "valid".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("www.example.com"),
Some("valid".to_string())
);
}
#[test]
fn test_validate_static() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid".to_string());
sni_map.insert("*.example.com".to_string(), "valid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert_eq!(errors.len(), 1);
assert!(errors[0].contains("*.com"));
}
#[test]
fn test_wildcard_not_at_start_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("foo*.example.com".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert!(!errors.is_empty());
assert!(errors[0].contains("*.example.com"));
}
#[test]
fn test_wildcard_in_middle_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.*".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_trailing_dot_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_just_asterisk_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_wildcard_requires_subdomain() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "upstream".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("example.com"), None);
assert_eq!(
matcher.match_sni("www.example.com"),
Some("upstream".to_string())
);
assert_eq!(
matcher.match_sni("foo.bar.example.com"),
Some("upstream".to_string())
);
}
#[test]
fn test_longest_suffix_wins_not_shortest() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.org".to_string(), "broad".to_string());
sni_map.insert("*.bar.example.org".to_string(), "narrow".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("bar.example.org"), None);
assert_eq!(
matcher.match_sni("v2.bar.example.org"),
Some("narrow".to_string())
);
assert_eq!(
matcher.match_sni("v2.example.org"),
Some("broad".to_string())
);
}
}