2 Commits

Author SHA1 Message Date
2c6ad1d7b8 Update based on lints
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/pr Build is passing
Signed-off-by: Jacob Kiers <code@kiers.eu>
2025-01-09 20:38:05 +01:00
8fe1f7f57b Fix SNI header parsing
When a listener is configured to deal with TLS upstreams, we use the SNI
field of the TLS ClientHello message to decide where to send the traffic.

Therefore, a buffer of 1024 bytes was used to temporarily store this
message. However, a TLS ClientHello message can be larger than that, up
to 16K bytes.

So now the first few bytes are read and manually parsed to find out how
long the message is. And then the entire ClientHello message is
retrieved.

So hopefully that will fix the issue causing the ClientHello
determination to fail.

Closes #10

Signed-off-by: Jacob Kiers <code@kiers.eu>
2025-01-09 20:23:02 +01:00
14 changed files with 48 additions and 1124 deletions

View File

@@ -1,52 +0,0 @@
when:
- event:
- push
- tag
- manual
matrix:
include:
- TARGET: x86_64-unknown-linux-musl
SHORT: amd64-musl
BIN_SUFFIX:
- TARGET: aarch64-unknown-linux-musl
SHORT: arm64-musl
BIN_SUFFIX:
- TARGET: x86_64-pc-windows-gnu
SHORT: windows
BIN_SUFFIX: .exe
steps:
- name: Prepare
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
commands:
- echo Using image img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
- mkdir -p artifacts
- cargo --version
- rustc --version
- set
- name: Build for ${SHORT}
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
commands:
- echo Building ${TARGET} \(${SHORT}\)
- cargo zigbuild --release --target ${TARGET}
- mkdir -p artifacts
- cp target/${TARGET}/release/l4p${BIN_SUFFIX} artifacts/l4p-${TARGET}${BIN_SUFFIX}
- rm -rf target/${TARGET}/release/*
depends_on:
- Prepare
- name: Release
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
when:
- event: tag
commands:
- ls -lah artifacts
- scripts/create_release_artifacts.sh
environment:
GITEA_SERVER_TOKEN:
from_secret: gitea_token
depends_on:
- Build for ${SHORT}

View File

@@ -7,15 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.1.10] - 2025-01-09
### Fixed
* The ClientHello TLS header is now read in full before it is parsed, solving
an error where there was not enough data to fully read it. In those cases
it was not possible to determine the upstream address and therefore the proxy
would go the the default action instead.
### Changed
* Updated some dependencies to prevent the build from breaking.
@@ -41,10 +32,10 @@ The ability to run `l4p` without arguments is now deprecated. Please use
## Previous versions
[unreleased]: https://code.kiers.eu/jjkiers/layer4-proxy/compare/v0.1.9...HEAD
[0.1.10]: https://code.kiers.eu/jjkiers/layer4-proxy/compare/v0.1.9...v0.1.10
[0.1.9]: https://code.kiers.eu/jjkiers/layer4-proxy/compare/v0.1.8...v0.1.9
Types of changes:
* `Added` for new features.

20
Cargo.lock generated
View File

@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
version = 3
[[package]]
name = "addr2line"
@@ -516,7 +516,7 @@ dependencies = [
[[package]]
name = "l4p"
version = "0.1.12"
version = "0.1.9"
dependencies = [
"async-trait",
"byte_string",
@@ -525,7 +525,6 @@ dependencies = [
"log",
"pico-args",
"pretty_env_logger",
"psl",
"self_update",
"serde",
"serde_yaml",
@@ -857,21 +856,6 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "psl"
version = "2.1.199"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70b63978a2742d3f662188698ab45854156e7e34658f53fa951e9253a3dfd583"
dependencies = [
"psl-types",
]
[[package]]
name = "psl-types"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac"
[[package]]
name = "quick-xml"
version = "0.37.2"

View File

@@ -1,6 +1,6 @@
[package]
name = "l4p"
version = "0.1.12"
version = "0.1.9"
edition = "2021"
authors = ["Jacob Kiers <code@kiers.eu>"]
license = "Apache-2.0"
@@ -33,7 +33,6 @@ time = { version = "0.3.37", features = ["local-offset", "formatting"] }
tls-parser = "0.12.2"
tokio = { version = "1.0", features = ["full"] }
url = "2.2.2"
psl = "2.1"
[dependencies.self_update]
version = "0.42.0"

View File

@@ -31,13 +31,6 @@ $ cargo install l4p
Or you can download binary file form the Release page.
## Features
- Listen on specific port and proxy to local or remote port
- SNI-based rule without terminating TLS connection
- Wildcard SNI matching with DNS-style longest-suffix-match
- DNS-based backend with periodic resolution
## Configuration
`l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example:
@@ -62,14 +55,6 @@ There are two upstreams built in:
For detailed configuration, check [this example](./config.yaml.example).
### SNI Matching
The proxy supports both exact and wildcard SNI patterns in the `sni` config. Wildcards use DNS-style longest-suffix-match: more specific patterns take precedence. For example, with `*.example.com` and `*.api.example.com`, request `api.example.com` matches the first, while `v2.api.example.com` matches the second.
Wildcards are validated against the Public Suffix List (PSL). Known suffixes (`.com`, `.org`) require at least one label below the suffix (`*.example.com` OK, `*.com` rejected). Unknown suffixes (`.local`, `.lan`) are allowed without restriction.
Invalid wildcard patterns are rejected at config load time with clear error messages.
## Thanks
- [`fourth`](https://crates.io/crates/fourth), of which this is a heavily modified fork.

View File

@@ -10,9 +10,6 @@ servers:
sni:
api.example.org: example-api
www.example.org: proxy
*.example.org: wildcard-proxy # Matches any subdomain of example.org
*.dev.example.org: dev-proxy # More specific: matches v2.dev.example.org, etc.
*.local: local-upstream # Unknown suffix - allowed (no PSL restriction)
default: ban
second-server:
@@ -22,6 +19,3 @@ servers:
upstream:
proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443
example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443
wildcard-proxy: "tcp://wildcard.example.org:443"
dev-proxy: "tcp://dev.example.org:443"
local-upstream: "tcp://localhost:8080"

View File

@@ -1,50 +0,0 @@
#!/usr/bin/env bash
# vim: set expandtab shiftwidth=4 softtabstop=4 tabstop=4 :
set -euo pipefail
if [ -z "${CI_COMMIT_TAG:-}" ]; then
echo "No commit tag set"
exit 1
fi
DIR=$(realpath $(dirname "${BASH_SOURCE[0]}") )
echo ${DIR}
${DIR}/install_tea.sh linux-amd64 https://gitea.com/api/v1/repos/gitea/tea/releases/latest
## Log in to Gitea
TEA=$(pwd)/tea
if [ -z "${GITEA_SERVER_URL:-}" ]; then
if [ -z "${CI_FORGE_URL:-}" ]; then
echo "Cannot log in to gitea: GITEA_SERVER_URL or CI_FORGE_URL missing"
exit 1
fi
GITEA_SERVER_URL=${CI_FORGE_URL}
fi
if [ -z "${GITEA_SERVER_TOKEN:-}" ]; then
echo "Cannot log in to gitea: GITEA_SERVER_TOKEN missing"
exit 1
fi
if ! ${TEA} login ls | grep ${GITEA_SERVER_URL} 2>&1 > /dev/null || false; then
${TEA} login add
else
echo "Already logged in to ${GITEA_SERVER_URL}"
fi
## Check and create tag
if ${TEA} release ls -o json | jq -e --arg tag "${CI_COMMIT_TAG}" 'map(.["tag-_name"]) | index($tag) != null' >/dev/null; then
echo "Release ${CI_COMMIT_TAG} exists"
else
echo "Creating release ${CI_COMMIT_TAG}"
${TEA} release create -o json --tag "${CI_COMMIT_TAG}" --title "${CI_COMMIT_TAG}" --draft
fi
find $(dirname ${DIR})/artifacts -type f -exec ${TEA} releases assets create -o json ${CI_COMMIT_TAG} {} +

View File

@@ -1,39 +0,0 @@
#!/usr/bin/env bash
set -eo pipefail
if [ -x ./tea ]; then
echo "tea already installed in current directory"; exit 0
fi
platform="${1:-linux-amd64}"
src="${2:-release.json}"
# obtain JSON: if src looks like a URL fetch it, otherwise treat as filename (or default file)
if [[ "$src" =~ ^https?:// ]]; then
curl -fsSL "$src" -o /tmp/release.json.$$
json="/tmp/release.json.$$"
trap 'rm -f "$json"' EXIT
elif [ -f "$src" ]; then
json="$src"
else
echo "release JSON not found; provide a filename or URL as second arg" >&2
exit 1
fi
# read tag and find binary URL (exclude archives/checksums/sigs)
tag=$(jq -r '.tag_name' "$json")
url=$(jq -r --arg p "$platform" '.assets[]
| select(.name | test($p))
| select(.name | test("\\.(xz|zip|gz|tar|bz2|7z|sha256|sha256sum|sig|asc)$") | not)
| .browser_download_url' "$json" | head -n1)
[ -n "$url" ] || { echo "binary not found for $platform" >&2; exit 1; }
tmp="$(mktemp)"
trap 'rm -f "$tmp"' EXIT
curl -fsSL "$url" -o "$tmp"
mv "$tmp" tea
chmod +x tea
echo "Downloaded tag ${tag}: $url -> ./tea"

View File

@@ -1,4 +1,3 @@
use crate::sni_matcher::SniMatcher;
use crate::upstreams::ProxyToUpstream;
use crate::upstreams::Upstream;
use log::{debug, info, warn};
@@ -13,6 +12,14 @@ pub struct ConfigV1 {
pub base: ParsedConfigV1,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct ParsedConfigV1 {
pub version: i32,
pub log: Option<String>,
pub servers: HashMap<String, ServerConfig>,
pub upstream: HashMap<String, Upstream>,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct BaseConfig {
pub version: i32,
@@ -21,14 +28,6 @@ pub struct BaseConfig {
pub upstream: HashMap<String, String>,
}
#[derive(Debug, Default, Clone)]
pub struct ParsedConfigV1 {
pub version: i32,
pub log: Option<String>,
pub servers: HashMap<String, ParsedServerConfig>,
pub upstream: HashMap<String, Upstream>,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct ServerConfig {
pub listen: Vec<String>,
@@ -37,35 +36,6 @@ pub struct ServerConfig {
pub sni: Option<HashMap<String, String>>,
pub default: Option<String>,
}
impl ServerConfig {
pub fn into_parsed(self) -> Result<ParsedServerConfig, Vec<String>> {
let sni = match self.sni {
Some(sni_map) => {
let matcher = SniMatcher::new(sni_map)?;
Some(matcher)
}
None => None,
};
Ok(ParsedServerConfig {
listen: self.listen,
protocol: self.protocol,
tls: self.tls,
sni,
default: self.default,
})
}
}
#[derive(Debug, Clone)]
pub struct ParsedServerConfig {
pub listen: Vec<String>,
pub protocol: Option<String>,
pub tls: Option<bool>,
pub sni: Option<SniMatcher>,
pub default: Option<String>,
}
impl TryInto<ProxyToUpstream> for &str {
type Error = ConfigError;
@@ -132,23 +102,12 @@ impl ConfigV1 {
}
}
/// Load and parse configuration from a YAML string.
///
/// This public function takes raw YAML content as a string and returns a parsed,
/// validated configuration. It performs all validation including:
/// - Version checking
/// - SNI pattern validation
/// - Upstream URL parsing
/// - Cross-reference validation
///
/// # Arguments
/// * `yaml_str` - The YAML configuration content as a string
///
/// # Returns
/// * `Ok(ParsedConfigV1)` - Successfully parsed and validated configuration
/// * `Err(ConfigError)` - If YAML parsing fails or validation errors occur
pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigError> {
let base: BaseConfig = serde_yaml::from_str(yaml_str)?;
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
let mut contents = String::new();
let mut file = File::open(path)?;
file.read_to_string(&mut contents)?;
let base: BaseConfig = serde_yaml::from_str(&contents)?;
if base.version != 1 {
return Err(ConfigError::Custom(
@@ -158,12 +117,11 @@ pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigErr
let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
if !log_level.eq("disable") {
unsafe {
std::env::set_var("FOURTH_LOG", log_level.clone());
pretty_env_logger::init_custom_env("FOURTH_LOG");
}
std::env::set_var("FOURTH_LOG", log_level.clone());
pretty_env_logger::init_custom_env("FOURTH_LOG");
}
info!("Using config file: {}", &path);
debug!("Set log level to {}", log_level);
debug!("Config version {}", base.version);
@@ -177,50 +135,16 @@ pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigErr
parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups));
}
// Convert ServerConfig to ParsedServerConfig, collecting all SNI validation errors
let mut all_errors: Vec<String> = Vec::new();
let mut parsed_servers: HashMap<String, ParsedServerConfig> = HashMap::new();
for (name, server_config) in base.servers {
match server_config.into_parsed() {
Ok(parsed) => {
parsed_servers.insert(name, parsed);
}
Err(errors) => {
for err in errors {
all_errors.push(format!("Server '{}': {}", name, err));
}
}
}
}
if !all_errors.is_empty() {
return Err(ConfigError::Custom(format!(
"Invalid SNI configuration:\n{}",
all_errors.join("\n")
)));
}
let parsed = ParsedConfigV1 {
version: base.version,
log: base.log,
servers: parsed_servers,
servers: base.servers,
upstream: parsed_upstream,
};
verify_config(parsed)
}
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
let mut contents = String::new();
let mut file = File::open(path)?;
file.read_to_string(&mut contents)?;
info!("Using config file: {}", &path);
load_config_from_yaml(&contents)
}
fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError> {
let mut used_upstreams: HashSet<String> = HashSet::new();
let mut upstream_names: HashSet<String> = HashSet::new();
@@ -251,20 +175,14 @@ fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError>
listen_addresses.insert(listen.to_string());
}
if server.tls.unwrap_or_default() {
if let Some(matcher) = &server.sni {
// Collect all upstream names from the SniMatcher
for (_, upstream) in matcher.exact.iter() {
used_upstreams.insert(upstream.clone());
}
for pattern in &matcher.wildcards {
used_upstreams.insert(pattern.upstream.clone());
}
if server.tls.unwrap_or_default() && server.sni.is_some() {
for (_, val) in server.sni.unwrap() {
used_upstreams.insert(val.to_string());
}
}
if let Some(default) = &server.default {
used_upstreams.insert(default.clone());
if server.default.is_some() {
used_upstreams.insert(server.default.unwrap().to_string());
}
for key in &used_upstreams {
@@ -307,45 +225,4 @@ mod tests {
assert_eq!(config.base.servers.len(), 3);
assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
}
#[test]
fn test_config_hard_failure_on_invalid_sni() {
// Test that invalid SNI wildcard (*.com) causes hard failure
let config_content = r#"version: 1
log: disable
servers:
test_server:
listen:
- "127.0.0.1:8443"
protocol: tcp
tls: true
sni:
"*.com": "upstream1"
default: ban
upstream:
upstream1: tcp://127.0.0.1:9000
"#;
let result = load_config_from_yaml(config_content);
// Should fail with an error
assert!(result.is_err(), "Expected config to fail with invalid SNI");
// Verify error message contains helpful information
match result {
Err(ConfigError::Custom(msg)) => {
assert!(
msg.contains("Invalid SNI"),
"Error message should mention invalid SNI: {}",
msg
);
assert!(
msg.contains("*.com"),
"Error message should mention the invalid pattern: {}",
msg
);
}
_ => panic!("Expected ConfigError::Custom"),
}
}
}

View File

@@ -1,4 +1,3 @@
mod config_v1;
pub(crate) use config_v1::ConfigV1;
pub(crate) use config_v1::ParsedConfigV1;
pub(crate) use config_v1::ParsedServerConfig;

View File

@@ -1,6 +1,5 @@
mod config;
mod servers;
mod sni_matcher;
mod update;
mod upstreams;

View File

@@ -9,7 +9,6 @@ mod protocol;
pub(crate) mod upstream_address;
use crate::config::ParsedConfigV1;
use crate::sni_matcher::SniMatcher;
use crate::upstreams::Upstream;
use protocol::tcp;
@@ -24,7 +23,7 @@ pub(crate) struct Proxy {
pub listen: SocketAddr,
pub protocol: String,
pub tls: bool,
pub sni: Option<SniMatcher>,
pub sni: Option<HashMap<String, String>>,
pub default_action: String,
pub upstream: HashMap<String, Upstream>,
}
@@ -135,132 +134,6 @@ mod tests {
}
}
/// Mock server for wildcard SNI test that responds with "tls_wildcard_response" on first read
#[tokio::main]
async fn tls_mock_server_wildcard() {
let server_addr: SocketAddr = "127.0.0.1:54598".parse().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
loop {
let (mut stream, _) = listener.accept().await.unwrap();
// Read client hello (which will be peeked but not actually read by proxy)
let mut buf = [0u8; 1024];
let _ = stream.read(&mut buf).await;
// Send a response to verify connection succeeded
let _ = stream.write(b"tls_wildcard_response").await;
let _ = stream.shutdown().await;
}
}
/// Mock server for SNI test that doesn't match wildcard pattern
#[tokio::main]
async fn tls_mock_server_default() {
let server_addr: SocketAddr = "127.0.0.1:54597".parse().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
loop {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 1024];
let _ = stream.read(&mut buf).await;
let _ = stream.write(b"tls_default_response").await;
let _ = stream.shutdown().await;
}
}
/// Helper function to build a minimal TLS ClientHello with SNI extension
/// This creates a valid TLS 1.2 ClientHello packet with the specified SNI hostname
fn build_tls_client_hello(sni_hostname: &str) -> Vec<u8> {
// TLS record header (9 bytes)
let mut hello = Vec::new();
// Record type: Handshake (0x16)
hello.push(0x16);
// Version: TLS 1.2 (0x0303)
hello.extend_from_slice(&[0x03, 0x03]);
// We'll set the record length later
let record_length_pos = hello.len();
hello.extend_from_slice(&[0x00, 0x00]); // Placeholder for record length
// Handshake message type: ClientHello (0x01)
hello.push(0x01);
// We'll set the handshake length later
let handshake_length_pos = hello.len();
hello.extend_from_slice(&[0x00, 0x00, 0x00]); // Placeholder for handshake length
// ClientHello fields
// Protocol version: TLS 1.2 (0x0303)
hello.extend_from_slice(&[0x03, 0x03]);
// Random: 32 bytes (we'll use a fixed pattern)
hello.extend_from_slice(&[0x00; 32]);
// Session ID length: 0 (no session)
hello.push(0x00);
// Cipher suites length: 2 bytes + cipher suites
hello.extend_from_slice(&[0x00, 0x02]); // Length of cipher suites list (2 bytes)
// Cipher suite: TLS_RSA_WITH_AES_128_CBC_SHA (0x002F)
hello.extend_from_slice(&[0x00, 0x2F]);
// Compression methods length: 1 byte
hello.push(0x01);
// Compression method: null (0x00)
hello.push(0x00);
// Extensions
let extensions_start = hello.len();
// SNI Extension (type 0x0000)
let mut sni_extension = Vec::new();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Extension type: server_name
// Extension length (will be set later)
let ext_length_pos = sni_extension.len();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder
// Server name list
let server_name_list_start = sni_extension.len();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder for server name list length
// Server name: host_name(0), length, hostname
sni_extension.push(0x00); // Name type: host_name
let hostname_bytes = sni_hostname.as_bytes();
sni_extension.extend_from_slice(&[(hostname_bytes.len() >> 8) as u8, (hostname_bytes.len() & 0xFF) as u8]);
sni_extension.extend_from_slice(hostname_bytes);
// Set server name list length
let server_name_list_len = sni_extension.len() - server_name_list_start - 2;
let pos = server_name_list_start;
sni_extension[pos] = (server_name_list_len >> 8) as u8;
sni_extension[pos + 1] = (server_name_list_len & 0xFF) as u8;
// Set extension length
let ext_len = sni_extension.len() - ext_length_pos - 2;
sni_extension[ext_length_pos] = (ext_len >> 8) as u8;
sni_extension[ext_length_pos + 1] = (ext_len & 0xFF) as u8;
// Add SNI extension to hello
hello.extend_from_slice(&sni_extension);
// Set extensions total length
let extensions_length = hello.len() - extensions_start;
hello.insert(extensions_start, (extensions_length & 0xFF) as u8);
hello.insert(extensions_start, (extensions_length >> 8) as u8);
// Set handshake message length
let handshake_len = hello.len() - handshake_length_pos - 3;
hello[handshake_length_pos] = (handshake_len >> 16) as u8;
hello[handshake_length_pos + 1] = (handshake_len >> 8) as u8;
hello[handshake_length_pos + 2] = (handshake_len & 0xFF) as u8;
// Set record length
let record_len = hello.len() - record_length_pos - 2;
hello[record_length_pos] = (record_len >> 8) as u8;
hello[record_length_pos + 1] = (record_len & 0xFF) as u8;
hello
}
#[tokio::test]
async fn test_proxy() {
use crate::config::ConfigV1;
@@ -297,118 +170,4 @@ mod tests {
}
conn.shutdown().await.unwrap();
}
#[tokio::test]
async fn test_wildcard_sni_routing() {
// Create test configuration with wildcard SNI pattern
use crate::upstreams::Upstream;
use std::collections::HashMap;
// Start mock servers for upstreams
thread::spawn(move || {
tls_mock_server_wildcard();
});
thread::spawn(move || {
tls_mock_server_default();
});
sleep(Duration::from_millis(500)); // wait for mock servers to start
// Create inline configuration
let mut config = crate::config::ParsedConfigV1 {
version: 1,
log: Some("disable".to_string()),
servers: HashMap::new(),
upstream: HashMap::new(),
};
// Add upstreams
config.upstream.insert(
"wildcard_upstream".to_string(),
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
"127.0.0.1:54598".to_string(),
"tcp".to_string(),
)),
);
config.upstream.insert(
"default_upstream".to_string(),
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
"127.0.0.1:54597".to_string(),
"tcp".to_string(),
)),
);
// Add TLS server with wildcard SNI pattern
let mut sni_map = HashMap::new();
sni_map.insert("*.api.example.com".to_string(), "wildcard_upstream".to_string());
let server_config = crate::config::ParsedServerConfig {
listen: vec!["127.0.0.1:54595".to_string()],
protocol: Some("tcp".to_string()),
tls: Some(true),
sni: Some(crate::sni_matcher::SniMatcher::new(sni_map).unwrap()),
default: Some("default_upstream".to_string()),
};
config.servers.insert("wildcard_test_server".to_string(), server_config);
// Start proxy server
let mut server = Server::new_from_v1_config(config);
thread::spawn(move || {
let _ = server.run();
});
sleep(Duration::from_secs(1)); // wait for proxy to start
// Test 1: Send ClientHello with SNI matching wildcard pattern
// Expected: Should route to wildcard_upstream (127.0.0.1:54598)
let client_hello = build_tls_client_hello("app.api.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello).await.unwrap();
let mut response = [0u8; 21];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from wildcard upstream");
assert_eq!(
&response[..n],
b"tls_wildcard_response",
"Should receive expected response from wildcard upstream"
);
conn.shutdown().await.unwrap();
// Test 2: Send ClientHello with SNI not matching any pattern
// Expected: Should route to default_upstream (127.0.0.1:54597)
let client_hello_nomatch = build_tls_client_hello("unrelated.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello_nomatch).await.unwrap();
let mut response = [0u8; 20];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from default upstream");
assert_eq!(
&response[..n],
b"tls_default_response",
"Should receive expected response from default upstream"
);
conn.shutdown().await.unwrap();
// Test 3: Send ClientHello with another SNI matching wildcard pattern
let client_hello_match2 = build_tls_client_hello("v2.api.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello_match2).await.unwrap();
let mut response = [0u8; 21];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from wildcard upstream for second match");
assert_eq!(
&response[..n],
b"tls_wildcard_response",
"Should receive expected response from wildcard upstream for second match"
);
conn.shutdown().await.unwrap();
}
}

View File

@@ -1,15 +1,12 @@
use crate::servers::Proxy;
use log::{debug, error, trace, warn};
use std::error::Error;
use std::io; // Import io for ErrorKind
use std::sync::Arc;
use std::time::Duration; // For potential delays
use tls_parser::{
parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage,
TlsMessageHandshake,
};
use tokio::net::TcpStream;
use tokio::time::timeout; // Use timeout for peek operations
fn get_sni(buf: &[u8]) -> Vec<String> {
let mut snis: Vec<String> = Vec::new();
@@ -60,9 +57,6 @@ fn get_sni(buf: &[u8]) -> Vec<String> {
snis
}
// Timeout duration for waiting for TLS Hello data
const TLS_PEEK_TIMEOUT: Duration = Duration::from_secs(5); // Adjust as needed
pub(crate) async fn determine_upstream_name(
inbound: &TcpStream,
proxy: &Arc<Proxy>,
@@ -70,162 +64,37 @@ pub(crate) async fn determine_upstream_name(
let default_upstream = proxy.default_action.clone();
let mut header = [0u8; 9];
inbound.peek(&mut header).await?;
// --- Step 1: Peek the initial header (9 bytes) with timeout ---
match timeout(TLS_PEEK_TIMEOUT, async {
loop {
match inbound.peek(&mut header).await {
Ok(n) if n >= header.len() => return Ok::<usize, io::Error>(n), // Got enough bytes
Ok(0) => {
// Connection closed cleanly before sending enough data
trace!("Connection closed while peeking for TLS header");
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Connection closed while peeking for TLS header",
)
.into()); // Convert to Box<dyn Error>
}
Ok(_) => {
// Not enough bytes yet, yield and loop again
tokio::task::yield_now().await;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// Should not happen with await, but yield defensively
tokio::task::yield_now().await;
}
Err(e) => {
// Other I/O error
warn!("Error peeking for TLS header: {}", e);
return Err(e.into()); // Convert to Box<dyn Error>
}
}
}
})
.await
{
Ok(Ok(_)) => { /* Header peeked successfully */ }
Ok(Err(e)) => {
// Inner loop returned an error (e.g., EOF, IO error)
trace!("Failed to peek header (inner error): {}", e);
return Ok(default_upstream); // Fallback on error/EOF
}
Err(_) => {
// Timeout occurred
error!("Timeout waiting for TLS header");
return Ok(default_upstream); // Fallback on timeout
}
}
let required_bytes = client_hello_buffer_size(&header)?;
// --- Step 2: Calculate required size ---
let required_bytes = match client_hello_buffer_size(&header) {
Ok(size) => size,
Err(e) => {
// Header was invalid or not a ClientHello
trace!("Could not determine required buffer size: {}", e);
return Ok(default_upstream);
}
};
let mut hello_buf = vec![0; required_bytes];
let read_bytes = inbound.peek(&mut hello_buf).await?;
// Basic sanity check on size
if required_bytes > 16384 + 9 {
// TLS max record size + header approx
error!(
"Calculated required TLS buffer size is too large: {}",
required_bytes
);
if read_bytes < required_bytes.into() {
error!("Could not read enough bytes to determine SNI");
return Ok(default_upstream);
}
// --- Step 3: Peek the full ClientHello with timeout ---
let mut hello_buf = vec![0; required_bytes];
match timeout(TLS_PEEK_TIMEOUT, async {
let mut total_peeked = 0;
loop {
// Peek into the portion of the buffer that hasn't been filled yet.
match inbound.peek(&mut hello_buf[total_peeked..]).await {
Ok(0) => {
// Connection closed cleanly before sending full ClientHello
trace!(
"Connection closed while peeking for full ClientHello (peeked {}/{} bytes)",
total_peeked,
required_bytes
);
return Err::<usize, io::Error>(
io::Error::new(
io::ErrorKind::UnexpectedEof,
"Connection closed while peeking for full ClientHello",
)
.into(),
);
}
Ok(n) => {
total_peeked += n;
if total_peeked >= required_bytes {
trace!("Successfully peeked {} bytes for ClientHello", total_peeked);
return Ok(total_peeked); // Got enough
} else {
// Not enough bytes yet, yield and loop again
trace!(
"Peeked {}/{} bytes for ClientHello, waiting for more...",
total_peeked,
required_bytes
);
tokio::task::yield_now().await;
let snis = get_sni(&hello_buf);
if snis.is_empty() {
return Ok(default_upstream);
} else {
match proxy.sni.clone() {
Some(sni_map) => {
let mut upstream = default_upstream;
for sni in snis {
let m = sni_map.get(&sni);
if m.is_some() {
upstream = m.unwrap().clone();
break;
}
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
tokio::task::yield_now().await;
}
Err(e) => {
warn!("Error peeking for full ClientHello: {}", e);
return Err(e.into());
}
Ok(upstream)
}
}
})
.await
{
Ok(Ok(_)) => { /* Full hello peeked successfully */ }
Ok(Err(e)) => {
error!("Could not peek full ClientHello (inner error): {}", e);
return Ok(default_upstream); // Fallback on error/EOF
}
Err(_) => {
error!(
"Timeout waiting for full ClientHello (needed {} bytes)",
required_bytes
);
return Ok(default_upstream); // Fallback on timeout
None => return Ok(default_upstream),
}
}
// --- Step 4: Parse SNI ---
let snis = get_sni(&hello_buf);
// --- Step 5: Determine upstream based on SNI ---
if snis.is_empty() {
debug!("No SNI found in ClientHello, using default upstream.");
return Ok(default_upstream);
}
if let Some(matcher) = &proxy.sni {
for sni in snis {
if let Some(upstream) = matcher.match_sni(&sni) {
debug!(
"Found matching SNI '{}', routing to upstream: {}",
sni, upstream
);
return Ok(upstream);
} else {
trace!("SNI '{}' not found in matcher.", sni);
}
}
debug!("SNI(s) found but none matched configuration, using default upstream.");
} else {
debug!("SNI found but no SNI matcher configured, using default upstream.");
}
Ok(default_upstream)
}
fn client_hello_buffer_size(data: &[u8]) -> Result<usize, String> {

View File

@@ -1,391 +0,0 @@
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct SniPattern {
pub pattern: String,
pub upstream: String,
}
/// Validates and matches SNI patterns against incoming SNI values. Supports exact matches and wildcard patterns.
///
/// Rules:
/// - Wildcard patterns must start with "*." and have a valid domain suffix after it (e.g., "*.example.com")
/// - Wildcard patterns cannot have a wildcard in the middle (e.g., "*.example.*" is invalid)
/// - Wildcard patterns cannot have multiple wildcards (e.g., "*.*.example.com" is invalid)
/// - When there are two wildcard patterns that could match the same SNI, the longest suffix wins
/// (e.g., "*.example.com" vs "*.api.example.com" - "v2.api.example.com" matches "*.api.example.com")
/// - Wildcard patterns that overlap are not matched (e.g. "*.example.com" and "*.bar.example.com" - "bar.example.com" matches neither, "v2.bar.example.com" matches "*.bar.example.com")
/// - Wildcard patterns cannot be just "*." or "*" (must have a valid suffix)
/// - For known public suffixes (e.g., "com", "org"), the wildcard must have at least one label below the public suffix (e.g., "*.example.com" is valid, but "*.com", or *.co.uk is invalid)
/// - For unknown suffixes (e.g., "local", "lan"), the wildcard is allowed without restriction (e.g., "*.local" is valid)
#[derive(Debug, Clone)]
pub struct SniMatcher {
pub exact: HashMap<String, String>,
pub wildcards: Vec<SniPattern>,
}
impl SniMatcher {
pub fn new(sni_map: HashMap<String, String>) -> Result<Self, Vec<String>> {
Self::validate(&sni_map)?;
let mut exact = HashMap::new();
let mut wildcards = Vec::new();
for (pattern, upstream) in sni_map {
if pattern.starts_with("*.") {
wildcards.push(SniPattern {
pattern: pattern.clone(),
upstream,
});
} else {
exact.insert(pattern, upstream);
}
}
wildcards.sort_by(|a, b| {
let a_suffix = a.pattern.trim_start_matches("*.");
let b_suffix = b.pattern.trim_start_matches("*.");
b_suffix.len().cmp(&a_suffix.len())
});
Ok(SniMatcher { exact, wildcards })
}
/// Matches the provided SNI against the patterns in the matcher. Returns Some(upstream) if a match is found, or None if no match is found.
pub fn match_sni(&self, sni: &str) -> Option<String> {
if let Some(upstream) = self.exact.get(sni) {
return Some(upstream.clone());
}
// Try each wildcard in order (longest suffix first)
for wildcard in &self.wildcards {
let suffix = wildcard.pattern.trim_start_matches("*.");
let suffix_len = suffix.len();
let check = format!(".{}", suffix);
if !sni.ends_with(&check) {
continue;
}
// Must have at least one label before the suffix to match
let prefix = &sni[..sni.len() - suffix_len - 1];
// Check if a more specific wildcard could also match this SNI
let is_owned = {
let sni_labels = sni.matches('.').count() + 1;
self.wildcards.iter().any(|w| {
// Skip the current wildcard itself
if w.pattern == wildcard.pattern {
return false;
}
let w_suffix = w.pattern.trim_start_matches("*.");
let w_len = w_suffix.len();
// Only care about wildcards with longer suffix (more specific)
if w_len <= suffix_len {
return false;
}
let w_suffix_labels = w_suffix.matches('.').count() + 1;
if sni == w_suffix {
// Exact match to more specific suffix - owned if could potentially match
// (sni has at least as many labels as the suffix needs)
sni_labels >= w_suffix_labels
} else if sni.ends_with(&format!(".{}", w_suffix)) {
// Ends with more specific suffix - owned if SNI has enough labels
sni_labels >= w_suffix_labels + 1
} else {
false
}
})
};
if is_owned {
continue;
}
// Only return if we have a valid prefix (at least one label)
if !prefix.is_empty() {
return Some(wildcard.upstream.clone());
}
}
None
}
fn validate_wildcard_suffix(pattern: &str) -> Result<(), String> {
let suffix = pattern.trim_start_matches("*.");
let domain_str = format!("a.{}", suffix);
if let Some(ps) = psl::suffix(domain_str.as_bytes()) {
let ps_str = std::str::from_utf8(ps.as_bytes()).unwrap_or("");
if ps_str == suffix {
return Err(format!(
"Invalid wildcard pattern: {} - wildcard cannot be at the public suffix level",
pattern
));
}
}
Ok(())
}
fn validate(sni_map: &HashMap<String, String>) -> Result<(), Vec<String>> {
let mut errors = Vec::new();
for (pattern, _upstream) in sni_map {
if pattern == "*" {
errors.push(format!(
"Invalid wildcard pattern: * - just asterisk is not allowed"
));
continue;
}
if pattern == "*." {
errors.push(format!(
"Invalid wildcard pattern: *. - empty suffix after wildcard"
));
continue;
}
if let Some(rest) = pattern.strip_prefix("*.") {
if rest.is_empty() {
errors.push(format!(
"Invalid wildcard pattern: {pattern} - empty suffix after wildcard"
));
continue;
}
if rest.contains('*') {
errors.push(format!("Invalid wildcard pattern: {pattern} - wildcard cannot be in the middle of suffix"));
continue;
}
if let Err(e) = Self::validate_wildcard_suffix(pattern) {
errors.push(e);
}
} else if pattern.contains('*') {
errors.push(format!(
"Invalid wildcard pattern: {pattern} - wildcard must be at the start"
));
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("example.com"),
Some("upstream1".to_string())
);
}
#[test]
fn test_wildcard_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("www.example.com"),
Some("upstream2".to_string())
);
assert_eq!(
matcher.match_sni("api.example.com"),
Some("upstream2".to_string())
);
}
#[test]
fn test_longest_suffix_match() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "wildcard1".to_string());
sni_map.insert("*.api.example.com".to_string(), "wildcard2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("v2.api.example.com"),
Some("wildcard2".to_string())
);
}
#[test]
fn test_no_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("other.com"), None);
}
#[test]
fn test_unknown_suffix_allowed() {
let mut sni_map = HashMap::new();
sni_map.insert(
"*.private.local".to_string(),
"private_upstream".to_string(),
);
sni_map.insert(
"*.internal.net".to_string(),
"internal_upstream".to_string(),
);
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("server.private.local"),
Some("private_upstream".to_string())
);
assert_eq!(
matcher.match_sni("app.internal.net"),
Some("internal_upstream".to_string())
);
}
#[test]
fn test_invalid_public_suffix() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert!(!errors.is_empty());
assert!(errors[0].contains("*.com"));
}
#[test]
fn test_multiple_errors_collected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid1".to_string());
sni_map.insert("*.org".to_string(), "invalid2".to_string());
sni_map.insert("*.net".to_string(), "invalid3".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert_eq!(errors.len(), 3);
}
#[test]
fn test_valid_public_suffix() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "valid".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("www.example.com"),
Some("valid".to_string())
);
}
#[test]
fn test_validate_static() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid".to_string());
sni_map.insert("*.example.com".to_string(), "valid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert_eq!(errors.len(), 1);
assert!(errors[0].contains("*.com"));
}
#[test]
fn test_wildcard_not_at_start_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("foo*.example.com".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert!(!errors.is_empty());
assert!(errors[0].contains("*.example.com"));
}
#[test]
fn test_wildcard_in_middle_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.*".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_trailing_dot_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_just_asterisk_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_wildcard_requires_subdomain() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "upstream".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("example.com"), None);
assert_eq!(
matcher.match_sni("www.example.com"),
Some("upstream".to_string())
);
assert_eq!(
matcher.match_sni("foo.bar.example.com"),
Some("upstream".to_string())
);
}
#[test]
fn test_longest_suffix_wins_not_shortest() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.org".to_string(), "broad".to_string());
sni_map.insert("*.bar.example.org".to_string(), "narrow".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("bar.example.org"), None);
assert_eq!(
matcher.match_sni("v2.bar.example.org"),
Some("narrow".to_string())
);
assert_eq!(
matcher.match_sni("v2.example.org"),
Some("broad".to_string())
);
}
}