Compare commits
2 Commits
a674895173
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
aff96d1a01
|
|||
|
590740f40e
|
18
Cargo.lock
generated
18
Cargo.lock
generated
@@ -516,7 +516,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "l4p"
|
||||
version = "0.1.10"
|
||||
version = "0.1.12"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byte_string",
|
||||
@@ -525,6 +525,7 @@ dependencies = [
|
||||
"log",
|
||||
"pico-args",
|
||||
"pretty_env_logger",
|
||||
"psl",
|
||||
"self_update",
|
||||
"serde",
|
||||
"serde_yaml",
|
||||
@@ -856,6 +857,21 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psl"
|
||||
version = "2.1.199"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70b63978a2742d3f662188698ab45854156e7e34658f53fa951e9253a3dfd583"
|
||||
dependencies = [
|
||||
"psl-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psl-types"
|
||||
version = "2.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac"
|
||||
|
||||
[[package]]
|
||||
name = "quick-xml"
|
||||
version = "0.37.2"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "l4p"
|
||||
version = "0.1.11"
|
||||
version = "0.1.12"
|
||||
edition = "2021"
|
||||
authors = ["Jacob Kiers <code@kiers.eu>"]
|
||||
license = "Apache-2.0"
|
||||
@@ -33,6 +33,7 @@ time = { version = "0.3.37", features = ["local-offset", "formatting"] }
|
||||
tls-parser = "0.12.2"
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
url = "2.2.2"
|
||||
psl = "2.1"
|
||||
|
||||
[dependencies.self_update]
|
||||
version = "0.42.0"
|
||||
|
||||
15
README.md
15
README.md
@@ -31,6 +31,13 @@ $ cargo install l4p
|
||||
|
||||
Or you can download binary file form the Release page.
|
||||
|
||||
## Features
|
||||
|
||||
- Listen on specific port and proxy to local or remote port
|
||||
- SNI-based rule without terminating TLS connection
|
||||
- Wildcard SNI matching with DNS-style longest-suffix-match
|
||||
- DNS-based backend with periodic resolution
|
||||
|
||||
## Configuration
|
||||
|
||||
`l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example:
|
||||
@@ -55,6 +62,14 @@ There are two upstreams built in:
|
||||
|
||||
For detailed configuration, check [this example](./config.yaml.example).
|
||||
|
||||
### SNI Matching
|
||||
|
||||
The proxy supports both exact and wildcard SNI patterns in the `sni` config. Wildcards use DNS-style longest-suffix-match: more specific patterns take precedence. For example, with `*.example.com` and `*.api.example.com`, request `api.example.com` matches the first, while `v2.api.example.com` matches the second.
|
||||
|
||||
Wildcards are validated against the Public Suffix List (PSL). Known suffixes (`.com`, `.org`) require at least one label below the suffix (`*.example.com` OK, `*.com` rejected). Unknown suffixes (`.local`, `.lan`) are allowed without restriction.
|
||||
|
||||
Invalid wildcard patterns are rejected at config load time with clear error messages.
|
||||
|
||||
## Thanks
|
||||
|
||||
- [`fourth`](https://crates.io/crates/fourth), of which this is a heavily modified fork.
|
||||
|
||||
@@ -10,6 +10,9 @@ servers:
|
||||
sni:
|
||||
api.example.org: example-api
|
||||
www.example.org: proxy
|
||||
*.example.org: wildcard-proxy # Matches any subdomain of example.org
|
||||
*.dev.example.org: dev-proxy # More specific: matches v2.dev.example.org, etc.
|
||||
*.local: local-upstream # Unknown suffix - allowed (no PSL restriction)
|
||||
default: ban
|
||||
|
||||
second-server:
|
||||
@@ -19,3 +22,6 @@ servers:
|
||||
upstream:
|
||||
proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443
|
||||
example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443
|
||||
wildcard-proxy: "tcp://wildcard.example.org:443"
|
||||
dev-proxy: "tcp://dev.example.org:443"
|
||||
local-upstream: "tcp://localhost:8080"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::sni_matcher::SniMatcher;
|
||||
use crate::upstreams::ProxyToUpstream;
|
||||
use crate::upstreams::Upstream;
|
||||
use log::{debug, info, warn};
|
||||
@@ -12,14 +13,6 @@ pub struct ConfigV1 {
|
||||
pub base: ParsedConfigV1,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Clone)]
|
||||
pub struct ParsedConfigV1 {
|
||||
pub version: i32,
|
||||
pub log: Option<String>,
|
||||
pub servers: HashMap<String, ServerConfig>,
|
||||
pub upstream: HashMap<String, Upstream>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Clone)]
|
||||
pub struct BaseConfig {
|
||||
pub version: i32,
|
||||
@@ -28,6 +21,14 @@ pub struct BaseConfig {
|
||||
pub upstream: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct ParsedConfigV1 {
|
||||
pub version: i32,
|
||||
pub log: Option<String>,
|
||||
pub servers: HashMap<String, ParsedServerConfig>,
|
||||
pub upstream: HashMap<String, Upstream>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Clone)]
|
||||
pub struct ServerConfig {
|
||||
pub listen: Vec<String>,
|
||||
@@ -36,6 +37,35 @@ pub struct ServerConfig {
|
||||
pub sni: Option<HashMap<String, String>>,
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
impl ServerConfig {
|
||||
pub fn into_parsed(self) -> Result<ParsedServerConfig, Vec<String>> {
|
||||
let sni = match self.sni {
|
||||
Some(sni_map) => {
|
||||
let matcher = SniMatcher::new(sni_map)?;
|
||||
Some(matcher)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(ParsedServerConfig {
|
||||
listen: self.listen,
|
||||
protocol: self.protocol,
|
||||
tls: self.tls,
|
||||
sni,
|
||||
default: self.default,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParsedServerConfig {
|
||||
pub listen: Vec<String>,
|
||||
pub protocol: Option<String>,
|
||||
pub tls: Option<bool>,
|
||||
pub sni: Option<SniMatcher>,
|
||||
pub default: Option<String>,
|
||||
}
|
||||
impl TryInto<ProxyToUpstream> for &str {
|
||||
type Error = ConfigError;
|
||||
|
||||
@@ -102,12 +132,23 @@ impl ConfigV1 {
|
||||
}
|
||||
}
|
||||
|
||||
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
|
||||
let mut contents = String::new();
|
||||
let mut file = File::open(path)?;
|
||||
file.read_to_string(&mut contents)?;
|
||||
|
||||
let base: BaseConfig = serde_yaml::from_str(&contents)?;
|
||||
/// Load and parse configuration from a YAML string.
|
||||
///
|
||||
/// This public function takes raw YAML content as a string and returns a parsed,
|
||||
/// validated configuration. It performs all validation including:
|
||||
/// - Version checking
|
||||
/// - SNI pattern validation
|
||||
/// - Upstream URL parsing
|
||||
/// - Cross-reference validation
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `yaml_str` - The YAML configuration content as a string
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(ParsedConfigV1)` - Successfully parsed and validated configuration
|
||||
/// * `Err(ConfigError)` - If YAML parsing fails or validation errors occur
|
||||
pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigError> {
|
||||
let base: BaseConfig = serde_yaml::from_str(yaml_str)?;
|
||||
|
||||
if base.version != 1 {
|
||||
return Err(ConfigError::Custom(
|
||||
@@ -117,11 +158,12 @@ fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
|
||||
|
||||
let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
|
||||
if !log_level.eq("disable") {
|
||||
std::env::set_var("FOURTH_LOG", log_level.clone());
|
||||
pretty_env_logger::init_custom_env("FOURTH_LOG");
|
||||
unsafe {
|
||||
std::env::set_var("FOURTH_LOG", log_level.clone());
|
||||
pretty_env_logger::init_custom_env("FOURTH_LOG");
|
||||
}
|
||||
}
|
||||
|
||||
info!("Using config file: {}", &path);
|
||||
debug!("Set log level to {}", log_level);
|
||||
debug!("Config version {}", base.version);
|
||||
|
||||
@@ -135,16 +177,50 @@ fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
|
||||
parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups));
|
||||
}
|
||||
|
||||
// Convert ServerConfig to ParsedServerConfig, collecting all SNI validation errors
|
||||
let mut all_errors: Vec<String> = Vec::new();
|
||||
let mut parsed_servers: HashMap<String, ParsedServerConfig> = HashMap::new();
|
||||
|
||||
for (name, server_config) in base.servers {
|
||||
match server_config.into_parsed() {
|
||||
Ok(parsed) => {
|
||||
parsed_servers.insert(name, parsed);
|
||||
}
|
||||
Err(errors) => {
|
||||
for err in errors {
|
||||
all_errors.push(format!("Server '{}': {}", name, err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !all_errors.is_empty() {
|
||||
return Err(ConfigError::Custom(format!(
|
||||
"Invalid SNI configuration:\n{}",
|
||||
all_errors.join("\n")
|
||||
)));
|
||||
}
|
||||
|
||||
let parsed = ParsedConfigV1 {
|
||||
version: base.version,
|
||||
log: base.log,
|
||||
servers: base.servers,
|
||||
servers: parsed_servers,
|
||||
upstream: parsed_upstream,
|
||||
};
|
||||
|
||||
verify_config(parsed)
|
||||
}
|
||||
|
||||
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
|
||||
let mut contents = String::new();
|
||||
let mut file = File::open(path)?;
|
||||
file.read_to_string(&mut contents)?;
|
||||
|
||||
info!("Using config file: {}", &path);
|
||||
|
||||
load_config_from_yaml(&contents)
|
||||
}
|
||||
|
||||
fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError> {
|
||||
let mut used_upstreams: HashSet<String> = HashSet::new();
|
||||
let mut upstream_names: HashSet<String> = HashSet::new();
|
||||
@@ -175,14 +251,20 @@ fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError>
|
||||
listen_addresses.insert(listen.to_string());
|
||||
}
|
||||
|
||||
if server.tls.unwrap_or_default() && server.sni.is_some() {
|
||||
for (_, val) in server.sni.unwrap() {
|
||||
used_upstreams.insert(val.to_string());
|
||||
if server.tls.unwrap_or_default() {
|
||||
if let Some(matcher) = &server.sni {
|
||||
// Collect all upstream names from the SniMatcher
|
||||
for (_, upstream) in matcher.exact.iter() {
|
||||
used_upstreams.insert(upstream.clone());
|
||||
}
|
||||
for pattern in &matcher.wildcards {
|
||||
used_upstreams.insert(pattern.upstream.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if server.default.is_some() {
|
||||
used_upstreams.insert(server.default.unwrap().to_string());
|
||||
if let Some(default) = &server.default {
|
||||
used_upstreams.insert(default.clone());
|
||||
}
|
||||
|
||||
for key in &used_upstreams {
|
||||
@@ -225,4 +307,45 @@ mod tests {
|
||||
assert_eq!(config.base.servers.len(), 3);
|
||||
assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_hard_failure_on_invalid_sni() {
|
||||
// Test that invalid SNI wildcard (*.com) causes hard failure
|
||||
let config_content = r#"version: 1
|
||||
log: disable
|
||||
servers:
|
||||
test_server:
|
||||
listen:
|
||||
- "127.0.0.1:8443"
|
||||
protocol: tcp
|
||||
tls: true
|
||||
sni:
|
||||
"*.com": "upstream1"
|
||||
default: ban
|
||||
upstream:
|
||||
upstream1: tcp://127.0.0.1:9000
|
||||
"#;
|
||||
|
||||
let result = load_config_from_yaml(config_content);
|
||||
|
||||
// Should fail with an error
|
||||
assert!(result.is_err(), "Expected config to fail with invalid SNI");
|
||||
|
||||
// Verify error message contains helpful information
|
||||
match result {
|
||||
Err(ConfigError::Custom(msg)) => {
|
||||
assert!(
|
||||
msg.contains("Invalid SNI"),
|
||||
"Error message should mention invalid SNI: {}",
|
||||
msg
|
||||
);
|
||||
assert!(
|
||||
msg.contains("*.com"),
|
||||
"Error message should mention the invalid pattern: {}",
|
||||
msg
|
||||
);
|
||||
}
|
||||
_ => panic!("Expected ConfigError::Custom"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod config_v1;
|
||||
pub(crate) use config_v1::ConfigV1;
|
||||
pub(crate) use config_v1::ParsedConfigV1;
|
||||
pub(crate) use config_v1::ParsedServerConfig;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod config;
|
||||
mod servers;
|
||||
mod sni_matcher;
|
||||
mod update;
|
||||
mod upstreams;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ mod protocol;
|
||||
pub(crate) mod upstream_address;
|
||||
|
||||
use crate::config::ParsedConfigV1;
|
||||
use crate::sni_matcher::SniMatcher;
|
||||
use crate::upstreams::Upstream;
|
||||
use protocol::tcp;
|
||||
|
||||
@@ -23,7 +24,7 @@ pub(crate) struct Proxy {
|
||||
pub listen: SocketAddr,
|
||||
pub protocol: String,
|
||||
pub tls: bool,
|
||||
pub sni: Option<HashMap<String, String>>,
|
||||
pub sni: Option<SniMatcher>,
|
||||
pub default_action: String,
|
||||
pub upstream: HashMap<String, Upstream>,
|
||||
}
|
||||
@@ -134,6 +135,132 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock server for wildcard SNI test that responds with "tls_wildcard_response" on first read
|
||||
#[tokio::main]
|
||||
async fn tls_mock_server_wildcard() {
|
||||
let server_addr: SocketAddr = "127.0.0.1:54598".parse().unwrap();
|
||||
let listener = TcpListener::bind(server_addr).await.unwrap();
|
||||
loop {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
// Read client hello (which will be peeked but not actually read by proxy)
|
||||
let mut buf = [0u8; 1024];
|
||||
let _ = stream.read(&mut buf).await;
|
||||
// Send a response to verify connection succeeded
|
||||
let _ = stream.write(b"tls_wildcard_response").await;
|
||||
let _ = stream.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock server for SNI test that doesn't match wildcard pattern
|
||||
#[tokio::main]
|
||||
async fn tls_mock_server_default() {
|
||||
let server_addr: SocketAddr = "127.0.0.1:54597".parse().unwrap();
|
||||
let listener = TcpListener::bind(server_addr).await.unwrap();
|
||||
loop {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 1024];
|
||||
let _ = stream.read(&mut buf).await;
|
||||
let _ = stream.write(b"tls_default_response").await;
|
||||
let _ = stream.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to build a minimal TLS ClientHello with SNI extension
|
||||
/// This creates a valid TLS 1.2 ClientHello packet with the specified SNI hostname
|
||||
fn build_tls_client_hello(sni_hostname: &str) -> Vec<u8> {
|
||||
// TLS record header (9 bytes)
|
||||
let mut hello = Vec::new();
|
||||
|
||||
// Record type: Handshake (0x16)
|
||||
hello.push(0x16);
|
||||
// Version: TLS 1.2 (0x0303)
|
||||
hello.extend_from_slice(&[0x03, 0x03]);
|
||||
|
||||
// We'll set the record length later
|
||||
let record_length_pos = hello.len();
|
||||
hello.extend_from_slice(&[0x00, 0x00]); // Placeholder for record length
|
||||
|
||||
// Handshake message type: ClientHello (0x01)
|
||||
hello.push(0x01);
|
||||
|
||||
// We'll set the handshake length later
|
||||
let handshake_length_pos = hello.len();
|
||||
hello.extend_from_slice(&[0x00, 0x00, 0x00]); // Placeholder for handshake length
|
||||
|
||||
// ClientHello fields
|
||||
// Protocol version: TLS 1.2 (0x0303)
|
||||
hello.extend_from_slice(&[0x03, 0x03]);
|
||||
|
||||
// Random: 32 bytes (we'll use a fixed pattern)
|
||||
hello.extend_from_slice(&[0x00; 32]);
|
||||
|
||||
// Session ID length: 0 (no session)
|
||||
hello.push(0x00);
|
||||
|
||||
// Cipher suites length: 2 bytes + cipher suites
|
||||
hello.extend_from_slice(&[0x00, 0x02]); // Length of cipher suites list (2 bytes)
|
||||
// Cipher suite: TLS_RSA_WITH_AES_128_CBC_SHA (0x002F)
|
||||
hello.extend_from_slice(&[0x00, 0x2F]);
|
||||
|
||||
// Compression methods length: 1 byte
|
||||
hello.push(0x01);
|
||||
// Compression method: null (0x00)
|
||||
hello.push(0x00);
|
||||
|
||||
// Extensions
|
||||
let extensions_start = hello.len();
|
||||
|
||||
// SNI Extension (type 0x0000)
|
||||
let mut sni_extension = Vec::new();
|
||||
sni_extension.extend_from_slice(&[0x00, 0x00]); // Extension type: server_name
|
||||
|
||||
// Extension length (will be set later)
|
||||
let ext_length_pos = sni_extension.len();
|
||||
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder
|
||||
|
||||
// Server name list
|
||||
let server_name_list_start = sni_extension.len();
|
||||
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder for server name list length
|
||||
|
||||
// Server name: host_name(0), length, hostname
|
||||
sni_extension.push(0x00); // Name type: host_name
|
||||
let hostname_bytes = sni_hostname.as_bytes();
|
||||
sni_extension.extend_from_slice(&[(hostname_bytes.len() >> 8) as u8, (hostname_bytes.len() & 0xFF) as u8]);
|
||||
sni_extension.extend_from_slice(hostname_bytes);
|
||||
|
||||
// Set server name list length
|
||||
let server_name_list_len = sni_extension.len() - server_name_list_start - 2;
|
||||
let pos = server_name_list_start;
|
||||
sni_extension[pos] = (server_name_list_len >> 8) as u8;
|
||||
sni_extension[pos + 1] = (server_name_list_len & 0xFF) as u8;
|
||||
|
||||
// Set extension length
|
||||
let ext_len = sni_extension.len() - ext_length_pos - 2;
|
||||
sni_extension[ext_length_pos] = (ext_len >> 8) as u8;
|
||||
sni_extension[ext_length_pos + 1] = (ext_len & 0xFF) as u8;
|
||||
|
||||
// Add SNI extension to hello
|
||||
hello.extend_from_slice(&sni_extension);
|
||||
|
||||
// Set extensions total length
|
||||
let extensions_length = hello.len() - extensions_start;
|
||||
hello.insert(extensions_start, (extensions_length & 0xFF) as u8);
|
||||
hello.insert(extensions_start, (extensions_length >> 8) as u8);
|
||||
|
||||
// Set handshake message length
|
||||
let handshake_len = hello.len() - handshake_length_pos - 3;
|
||||
hello[handshake_length_pos] = (handshake_len >> 16) as u8;
|
||||
hello[handshake_length_pos + 1] = (handshake_len >> 8) as u8;
|
||||
hello[handshake_length_pos + 2] = (handshake_len & 0xFF) as u8;
|
||||
|
||||
// Set record length
|
||||
let record_len = hello.len() - record_length_pos - 2;
|
||||
hello[record_length_pos] = (record_len >> 8) as u8;
|
||||
hello[record_length_pos + 1] = (record_len & 0xFF) as u8;
|
||||
|
||||
hello
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_proxy() {
|
||||
use crate::config::ConfigV1;
|
||||
@@ -170,4 +297,118 @@ mod tests {
|
||||
}
|
||||
conn.shutdown().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_wildcard_sni_routing() {
|
||||
// Create test configuration with wildcard SNI pattern
|
||||
use crate::upstreams::Upstream;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Start mock servers for upstreams
|
||||
thread::spawn(move || {
|
||||
tls_mock_server_wildcard();
|
||||
});
|
||||
thread::spawn(move || {
|
||||
tls_mock_server_default();
|
||||
});
|
||||
sleep(Duration::from_millis(500)); // wait for mock servers to start
|
||||
|
||||
// Create inline configuration
|
||||
let mut config = crate::config::ParsedConfigV1 {
|
||||
version: 1,
|
||||
log: Some("disable".to_string()),
|
||||
servers: HashMap::new(),
|
||||
upstream: HashMap::new(),
|
||||
};
|
||||
|
||||
// Add upstreams
|
||||
config.upstream.insert(
|
||||
"wildcard_upstream".to_string(),
|
||||
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
|
||||
"127.0.0.1:54598".to_string(),
|
||||
"tcp".to_string(),
|
||||
)),
|
||||
);
|
||||
config.upstream.insert(
|
||||
"default_upstream".to_string(),
|
||||
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
|
||||
"127.0.0.1:54597".to_string(),
|
||||
"tcp".to_string(),
|
||||
)),
|
||||
);
|
||||
|
||||
// Add TLS server with wildcard SNI pattern
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.api.example.com".to_string(), "wildcard_upstream".to_string());
|
||||
|
||||
let server_config = crate::config::ParsedServerConfig {
|
||||
listen: vec!["127.0.0.1:54595".to_string()],
|
||||
protocol: Some("tcp".to_string()),
|
||||
tls: Some(true),
|
||||
sni: Some(crate::sni_matcher::SniMatcher::new(sni_map).unwrap()),
|
||||
default: Some("default_upstream".to_string()),
|
||||
};
|
||||
|
||||
config.servers.insert("wildcard_test_server".to_string(), server_config);
|
||||
|
||||
// Start proxy server
|
||||
let mut server = Server::new_from_v1_config(config);
|
||||
thread::spawn(move || {
|
||||
let _ = server.run();
|
||||
});
|
||||
sleep(Duration::from_secs(1)); // wait for proxy to start
|
||||
|
||||
// Test 1: Send ClientHello with SNI matching wildcard pattern
|
||||
// Expected: Should route to wildcard_upstream (127.0.0.1:54598)
|
||||
let client_hello = build_tls_client_hello("app.api.example.com");
|
||||
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let _ = conn.write(&client_hello).await.unwrap();
|
||||
let mut response = [0u8; 21];
|
||||
let n = conn.read(&mut response).await.unwrap();
|
||||
assert!(n > 0, "Should receive response from wildcard upstream");
|
||||
assert_eq!(
|
||||
&response[..n],
|
||||
b"tls_wildcard_response",
|
||||
"Should receive expected response from wildcard upstream"
|
||||
);
|
||||
conn.shutdown().await.unwrap();
|
||||
|
||||
// Test 2: Send ClientHello with SNI not matching any pattern
|
||||
// Expected: Should route to default_upstream (127.0.0.1:54597)
|
||||
let client_hello_nomatch = build_tls_client_hello("unrelated.example.com");
|
||||
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let _ = conn.write(&client_hello_nomatch).await.unwrap();
|
||||
let mut response = [0u8; 20];
|
||||
let n = conn.read(&mut response).await.unwrap();
|
||||
assert!(n > 0, "Should receive response from default upstream");
|
||||
assert_eq!(
|
||||
&response[..n],
|
||||
b"tls_default_response",
|
||||
"Should receive expected response from default upstream"
|
||||
);
|
||||
conn.shutdown().await.unwrap();
|
||||
|
||||
// Test 3: Send ClientHello with another SNI matching wildcard pattern
|
||||
let client_hello_match2 = build_tls_client_hello("v2.api.example.com");
|
||||
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let _ = conn.write(&client_hello_match2).await.unwrap();
|
||||
let mut response = [0u8; 21];
|
||||
let n = conn.read(&mut response).await.unwrap();
|
||||
assert!(n > 0, "Should receive response from wildcard upstream for second match");
|
||||
assert_eq!(
|
||||
&response[..n],
|
||||
b"tls_wildcard_response",
|
||||
"Should receive expected response from wildcard upstream for second match"
|
||||
);
|
||||
conn.shutdown().await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,36 +206,26 @@ pub(crate) async fn determine_upstream_name(
|
||||
if snis.is_empty() {
|
||||
debug!("No SNI found in ClientHello, using default upstream.");
|
||||
return Ok(default_upstream);
|
||||
} else {
|
||||
match proxy.sni.clone() {
|
||||
Some(sni_map) => {
|
||||
let mut upstream = default_upstream.clone(); // Clone here for default case
|
||||
let mut found_match = false;
|
||||
for sni in snis {
|
||||
// snis is already Vec<String>
|
||||
if let Some(target_upstream) = sni_map.get(&sni) {
|
||||
debug!(
|
||||
"Found matching SNI '{}', routing to upstream: {}",
|
||||
sni, target_upstream
|
||||
);
|
||||
upstream = target_upstream.clone();
|
||||
found_match = true;
|
||||
break;
|
||||
} else {
|
||||
trace!("SNI '{}' not found in map.", sni);
|
||||
}
|
||||
}
|
||||
if !found_match {
|
||||
debug!("SNI(s) found but none matched configuration, using default upstream.");
|
||||
}
|
||||
Ok(upstream)
|
||||
}
|
||||
None => {
|
||||
debug!("SNI found but no SNI map configured, using default upstream.");
|
||||
Ok(default_upstream)
|
||||
}
|
||||
|
||||
if let Some(matcher) = &proxy.sni {
|
||||
for sni in snis {
|
||||
if let Some(upstream) = matcher.match_sni(&sni) {
|
||||
debug!(
|
||||
"Found matching SNI '{}', routing to upstream: {}",
|
||||
sni, upstream
|
||||
);
|
||||
return Ok(upstream);
|
||||
} else {
|
||||
trace!("SNI '{}' not found in matcher.", sni);
|
||||
}
|
||||
}
|
||||
debug!("SNI(s) found but none matched configuration, using default upstream.");
|
||||
} else {
|
||||
debug!("SNI found but no SNI matcher configured, using default upstream.");
|
||||
}
|
||||
|
||||
Ok(default_upstream)
|
||||
}
|
||||
|
||||
fn client_hello_buffer_size(data: &[u8]) -> Result<usize, String> {
|
||||
|
||||
391
src/sni_matcher.rs
Normal file
391
src/sni_matcher.rs
Normal file
@@ -0,0 +1,391 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SniPattern {
|
||||
pub pattern: String,
|
||||
pub upstream: String,
|
||||
}
|
||||
|
||||
/// Validates and matches SNI patterns against incoming SNI values. Supports exact matches and wildcard patterns.
|
||||
///
|
||||
/// Rules:
|
||||
/// - Wildcard patterns must start with "*." and have a valid domain suffix after it (e.g., "*.example.com")
|
||||
/// - Wildcard patterns cannot have a wildcard in the middle (e.g., "*.example.*" is invalid)
|
||||
/// - Wildcard patterns cannot have multiple wildcards (e.g., "*.*.example.com" is invalid)
|
||||
/// - When there are two wildcard patterns that could match the same SNI, the longest suffix wins
|
||||
/// (e.g., "*.example.com" vs "*.api.example.com" - "v2.api.example.com" matches "*.api.example.com")
|
||||
/// - Wildcard patterns that overlap are not matched (e.g. "*.example.com" and "*.bar.example.com" - "bar.example.com" matches neither, "v2.bar.example.com" matches "*.bar.example.com")
|
||||
/// - Wildcard patterns cannot be just "*." or "*" (must have a valid suffix)
|
||||
/// - For known public suffixes (e.g., "com", "org"), the wildcard must have at least one label below the public suffix (e.g., "*.example.com" is valid, but "*.com", or *.co.uk is invalid)
|
||||
/// - For unknown suffixes (e.g., "local", "lan"), the wildcard is allowed without restriction (e.g., "*.local" is valid)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SniMatcher {
|
||||
pub exact: HashMap<String, String>,
|
||||
pub wildcards: Vec<SniPattern>,
|
||||
}
|
||||
|
||||
impl SniMatcher {
|
||||
pub fn new(sni_map: HashMap<String, String>) -> Result<Self, Vec<String>> {
|
||||
Self::validate(&sni_map)?;
|
||||
|
||||
let mut exact = HashMap::new();
|
||||
let mut wildcards = Vec::new();
|
||||
|
||||
for (pattern, upstream) in sni_map {
|
||||
if pattern.starts_with("*.") {
|
||||
wildcards.push(SniPattern {
|
||||
pattern: pattern.clone(),
|
||||
upstream,
|
||||
});
|
||||
} else {
|
||||
exact.insert(pattern, upstream);
|
||||
}
|
||||
}
|
||||
|
||||
wildcards.sort_by(|a, b| {
|
||||
let a_suffix = a.pattern.trim_start_matches("*.");
|
||||
let b_suffix = b.pattern.trim_start_matches("*.");
|
||||
b_suffix.len().cmp(&a_suffix.len())
|
||||
});
|
||||
|
||||
Ok(SniMatcher { exact, wildcards })
|
||||
}
|
||||
|
||||
/// Matches the provided SNI against the patterns in the matcher. Returns Some(upstream) if a match is found, or None if no match is found.
|
||||
pub fn match_sni(&self, sni: &str) -> Option<String> {
|
||||
if let Some(upstream) = self.exact.get(sni) {
|
||||
return Some(upstream.clone());
|
||||
}
|
||||
|
||||
// Try each wildcard in order (longest suffix first)
|
||||
for wildcard in &self.wildcards {
|
||||
let suffix = wildcard.pattern.trim_start_matches("*.");
|
||||
let suffix_len = suffix.len();
|
||||
let check = format!(".{}", suffix);
|
||||
|
||||
if !sni.ends_with(&check) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Must have at least one label before the suffix to match
|
||||
let prefix = &sni[..sni.len() - suffix_len - 1];
|
||||
|
||||
// Check if a more specific wildcard could also match this SNI
|
||||
let is_owned = {
|
||||
let sni_labels = sni.matches('.').count() + 1;
|
||||
|
||||
self.wildcards.iter().any(|w| {
|
||||
// Skip the current wildcard itself
|
||||
if w.pattern == wildcard.pattern {
|
||||
return false;
|
||||
}
|
||||
|
||||
let w_suffix = w.pattern.trim_start_matches("*.");
|
||||
let w_len = w_suffix.len();
|
||||
|
||||
// Only care about wildcards with longer suffix (more specific)
|
||||
if w_len <= suffix_len {
|
||||
return false;
|
||||
}
|
||||
|
||||
let w_suffix_labels = w_suffix.matches('.').count() + 1;
|
||||
|
||||
if sni == w_suffix {
|
||||
// Exact match to more specific suffix - owned if could potentially match
|
||||
// (sni has at least as many labels as the suffix needs)
|
||||
sni_labels >= w_suffix_labels
|
||||
} else if sni.ends_with(&format!(".{}", w_suffix)) {
|
||||
// Ends with more specific suffix - owned if SNI has enough labels
|
||||
sni_labels >= w_suffix_labels + 1
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
if is_owned {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Only return if we have a valid prefix (at least one label)
|
||||
if !prefix.is_empty() {
|
||||
return Some(wildcard.upstream.clone());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn validate_wildcard_suffix(pattern: &str) -> Result<(), String> {
|
||||
let suffix = pattern.trim_start_matches("*.");
|
||||
let domain_str = format!("a.{}", suffix);
|
||||
|
||||
if let Some(ps) = psl::suffix(domain_str.as_bytes()) {
|
||||
let ps_str = std::str::from_utf8(ps.as_bytes()).unwrap_or("");
|
||||
if ps_str == suffix {
|
||||
return Err(format!(
|
||||
"Invalid wildcard pattern: {} - wildcard cannot be at the public suffix level",
|
||||
pattern
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate(sni_map: &HashMap<String, String>) -> Result<(), Vec<String>> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
for (pattern, _upstream) in sni_map {
|
||||
if pattern == "*" {
|
||||
errors.push(format!(
|
||||
"Invalid wildcard pattern: * - just asterisk is not allowed"
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
if pattern == "*." {
|
||||
errors.push(format!(
|
||||
"Invalid wildcard pattern: *. - empty suffix after wildcard"
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = pattern.strip_prefix("*.") {
|
||||
if rest.is_empty() {
|
||||
errors.push(format!(
|
||||
"Invalid wildcard pattern: {pattern} - empty suffix after wildcard"
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
if rest.contains('*') {
|
||||
errors.push(format!("Invalid wildcard pattern: {pattern} - wildcard cannot be in the middle of suffix"));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = Self::validate_wildcard_suffix(pattern) {
|
||||
errors.push(e);
|
||||
}
|
||||
} else if pattern.contains('*') {
|
||||
errors.push(format!(
|
||||
"Invalid wildcard pattern: {pattern} - wildcard must be at the start"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exact_match() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("example.com".to_string(), "upstream1".to_string());
|
||||
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
|
||||
|
||||
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
matcher.match_sni("example.com"),
|
||||
Some("upstream1".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_match() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("example.com".to_string(), "upstream1".to_string());
|
||||
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
|
||||
|
||||
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
matcher.match_sni("www.example.com"),
|
||||
Some("upstream2".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
matcher.match_sni("api.example.com"),
|
||||
Some("upstream2".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_longest_suffix_match() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.example.com".to_string(), "wildcard1".to_string());
|
||||
sni_map.insert("*.api.example.com".to_string(), "wildcard2".to_string());
|
||||
|
||||
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
matcher.match_sni("v2.api.example.com"),
|
||||
Some("wildcard2".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_match() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("example.com".to_string(), "upstream1".to_string());
|
||||
|
||||
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||
|
||||
assert_eq!(matcher.match_sni("other.com"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_suffix_allowed() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert(
|
||||
"*.private.local".to_string(),
|
||||
"private_upstream".to_string(),
|
||||
);
|
||||
sni_map.insert(
|
||||
"*.internal.net".to_string(),
|
||||
"internal_upstream".to_string(),
|
||||
);
|
||||
|
||||
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
matcher.match_sni("server.private.local"),
|
||||
Some("private_upstream".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
matcher.match_sni("app.internal.net"),
|
||||
Some("internal_upstream".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_public_suffix() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.com".to_string(), "invalid".to_string());
|
||||
|
||||
let result = SniMatcher::new(sni_map);
|
||||
assert!(result.is_err());
|
||||
|
||||
let errors = result.unwrap_err();
|
||||
assert!(!errors.is_empty());
|
||||
assert!(errors[0].contains("*.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_errors_collected() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.com".to_string(), "invalid1".to_string());
|
||||
sni_map.insert("*.org".to_string(), "invalid2".to_string());
|
||||
sni_map.insert("*.net".to_string(), "invalid3".to_string());
|
||||
|
||||
let result = SniMatcher::new(sni_map);
|
||||
assert!(result.is_err());
|
||||
|
||||
let errors = result.unwrap_err();
|
||||
assert_eq!(errors.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_public_suffix() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.example.com".to_string(), "valid".to_string());
|
||||
|
||||
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||
assert_eq!(
|
||||
matcher.match_sni("www.example.com"),
|
||||
Some("valid".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_static() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.com".to_string(), "invalid".to_string());
|
||||
sni_map.insert("*.example.com".to_string(), "valid".to_string());
|
||||
|
||||
let result = SniMatcher::new(sni_map);
|
||||
assert!(result.is_err());
|
||||
|
||||
let errors = result.unwrap_err();
|
||||
assert_eq!(errors.len(), 1);
|
||||
assert!(errors[0].contains("*.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_not_at_start_rejected() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("foo*.example.com".to_string(), "invalid".to_string());
|
||||
let result = SniMatcher::new(sni_map);
|
||||
assert!(result.is_err());
|
||||
let errors = result.unwrap_err();
|
||||
assert!(!errors.is_empty());
|
||||
assert!(errors[0].contains("*.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_in_middle_rejected() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.example.*".to_string(), "invalid".to_string());
|
||||
let result = SniMatcher::new(sni_map);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trailing_dot_rejected() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.".to_string(), "invalid".to_string());
|
||||
let result = SniMatcher::new(sni_map);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_just_asterisk_rejected() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*".to_string(), "invalid".to_string());
|
||||
let result = SniMatcher::new(sni_map);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_requires_subdomain() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.example.com".to_string(), "upstream".to_string());
|
||||
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||
|
||||
assert_eq!(matcher.match_sni("example.com"), None);
|
||||
assert_eq!(
|
||||
matcher.match_sni("www.example.com"),
|
||||
Some("upstream".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
matcher.match_sni("foo.bar.example.com"),
|
||||
Some("upstream".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_longest_suffix_wins_not_shortest() {
|
||||
let mut sni_map = HashMap::new();
|
||||
sni_map.insert("*.example.org".to_string(), "broad".to_string());
|
||||
sni_map.insert("*.bar.example.org".to_string(), "narrow".to_string());
|
||||
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||
|
||||
assert_eq!(matcher.match_sni("bar.example.org"), None);
|
||||
assert_eq!(
|
||||
matcher.match_sni("v2.bar.example.org"),
|
||||
Some("narrow".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
matcher.match_sni("v2.example.org"),
|
||||
Some("broad".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user