1 Commits

17 changed files with 325 additions and 1020 deletions

View File

@@ -1,52 +0,0 @@
when:
- event:
- push
- tag
- manual
matrix:
include:
- TARGET: x86_64-unknown-linux-musl
SHORT: amd64-musl
BIN_SUFFIX:
- TARGET: aarch64-unknown-linux-musl
SHORT: arm64-musl
BIN_SUFFIX:
- TARGET: x86_64-pc-windows-gnu
SHORT: windows
BIN_SUFFIX: .exe
steps:
- name: Prepare
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
commands:
- echo Using image img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
- mkdir -p artifacts
- cargo --version
- rustc --version
- name: Build for ${SHORT}
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
commands:
- echo Building ${TARGET} \(${SHORT}\)
- cargo zigbuild --release --target ${TARGET}
- mkdir -p artifacts
- cp target/${TARGET}/release/l4p${BIN_SUFFIX} artifacts/l4p-${TARGET}${BIN_SUFFIX}
- rm -rf target/${TARGET}/release/*
depends_on:
- Prepare
- name: Release
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
when:
- event: tag
commands:
- ls -lah artifacts
- scripts/create_release_artifacts.sh
environment:
GITEA_SERVER_URL: https://code.kiers.eu
GITEA_SERVER_TOKEN:
from_secret: gitea_token
depends_on:
- Build for ${SHORT}

130
Cargo.lock generated
View File

@@ -26,6 +26,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "anyhow"
version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
[[package]]
name = "async-trait"
version = "0.1.77"
@@ -49,6 +55,28 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "aws-lc-rs"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9"
dependencies = [
"aws-lc-sys",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.37.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549"
dependencies = [
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]]
name = "backtrace"
version = "0.3.69"
@@ -102,9 +130,15 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
[[package]]
name = "cc"
version = "1.0.86"
version = "1.2.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730"
checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2"
dependencies = [
"find-msvc-tools",
"jobserver",
"libc",
"shlex",
]
[[package]]
name = "cfg-if"
@@ -112,6 +146,15 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cmake"
version = "0.1.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d"
dependencies = [
"cc",
]
[[package]]
name = "console"
version = "0.15.8"
@@ -134,6 +177,12 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "encode_unicode"
version = "0.3.6"
@@ -184,6 +233,12 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
[[package]]
name = "find-msvc-tools"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
[[package]]
name = "fnv"
version = "1.0.7"
@@ -199,6 +254,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futures"
version = "0.3.30"
@@ -505,6 +566,15 @@ version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]]
name = "jobserver"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
dependencies = [
"libc",
]
[[package]]
name = "js-sys"
version = "0.3.69"
@@ -516,8 +586,9 @@ dependencies = [
[[package]]
name = "l4p"
version = "0.1.12"
version = "0.1.11"
dependencies = [
"anyhow",
"async-trait",
"byte_string",
"bytes",
@@ -525,13 +596,15 @@ dependencies = [
"log",
"pico-args",
"pretty_env_logger",
"psl",
"rustls",
"rustls-pemfile",
"self_update",
"serde",
"serde_yaml",
"time",
"tls-parser",
"tokio",
"tokio-rustls",
"url",
]
@@ -857,21 +930,6 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "psl"
version = "2.1.199"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70b63978a2742d3f662188698ab45854156e7e34658f53fa951e9253a3dfd583"
dependencies = [
"psl-types",
]
[[package]]
name = "psl-types"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac"
[[package]]
name = "quick-xml"
version = "0.37.2"
@@ -1100,10 +1158,12 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.10"
version = "0.23.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402"
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
dependencies = [
"aws-lc-rs",
"log",
"once_cell",
"ring",
"rustls-pki-types",
@@ -1114,26 +1174,29 @@ dependencies = [
[[package]]
name = "rustls-pemfile"
version = "2.1.2"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d"
checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
dependencies = [
"base64",
"rustls-pki-types",
]
[[package]]
name = "rustls-pki-types"
version = "1.7.0"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
dependencies = [
"zeroize",
]
[[package]]
name = "rustls-webpki"
version = "0.102.4"
version = "0.103.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e"
checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53"
dependencies = [
"aws-lc-rs",
"ring",
"rustls-pki-types",
"untrusted",
@@ -1249,6 +1312,12 @@ dependencies = [
"unsafe-libyaml",
]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook-registry"
version = "1.4.1"
@@ -1464,12 +1533,11 @@ dependencies = [
[[package]]
name = "tokio-rustls"
version = "0.26.0"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
dependencies = [
"rustls",
"rustls-pki-types",
"tokio",
]

View File

@@ -1,6 +1,6 @@
[package]
name = "l4p"
version = "0.1.12"
version = "0.1.11"
edition = "2021"
authors = ["Jacob Kiers <code@kiers.eu>"]
license = "Apache-2.0"
@@ -20,6 +20,7 @@ name = "l4p"
path = "src/main.rs"
[dependencies]
anyhow = "1.0.102"
async-trait = "0.1.73"
byte_string = "1"
bytes = "1.1"
@@ -27,13 +28,15 @@ futures = "0.3"
log = "0.4"
pico-args = "0.5.0"
pretty_env_logger = "0.5"
rustls = "0.23"
rustls-pemfile = "2.2.0"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.21"
time = { version = "0.3.37", features = ["local-offset", "formatting"] }
tls-parser = "0.12.2"
tokio = { version = "1.0", features = ["full"] }
tokio-rustls = "0.26.4"
url = "2.2.2"
psl = "2.1"
[dependencies.self_update]
version = "0.42.0"

30
PLAN.md Normal file
View File

@@ -0,0 +1,30 @@
## Plan for TLS Termination and Dynamic Port Handling
### Task
Modify the `l4p` (layer 4 proxy) to perform TLS termination and handle dynamic/random ports for backend services. The backend services are on a specific IPv6 address, and the user can dynamically determine the hostname of these services in the format "https://{port}.my-host". The proxy should listen on port 443 and use SNI for routing.
### Completed Actions
- Added `tokio-rustls`, `rustls-pemfile`, `anyhow` dependencies to `Cargo.toml`.
- Modified `src/config/config_v1.rs` to include `TlsTerminationConfig`, `CertificateConfig`, `SniCertificateConfig` structs and updated `ServerConfig`.
- Created `src/tls.rs` (multiple iterations due to compilation issues).
- Integrated `anyhow::Result` into various functions and imported `Context` to `src/servers/mod.rs`.
- Corrected imports in `src/main.rs`, `src/servers/mod.rs`, `src/servers/protocol/tcp.rs`.
- Removed `mod tls;` from `src/servers/protocol/mod.rs`.
- Attempted to fix various compilation errors related to `rustls` API changes, lifetime issues, and `tokio` task handling.
- Changed `handle.await??;` to explicit match for debugging purposes.
### Current State (with persistent errors)
The code currently has compilation errors, primarily related to:
1. **`src/servers/mod.rs`**: Still showing an error for `map_err` not found for unit type `()`. This arises from the complex double `Result` handling (`Result<anyhow::Result<()>, JoinError>`) when awaiting spawned tasks.
2. **`src/tls.rs`**: Facing issues with `rustls::pki_types::PrivateKeyDer` and `CertificateDer` conversions, specifically for ensuring `'static` lifetimes and incorrect method usages like `to_vec()` or `as_ref()`, or `into_owned()` methods not existing for certain types. The `borrowed data escapes outside of function` error indicates deeper lifetime mismatches.
### Next Steps (Requires Manual Intervention)
- **Refactor `src/servers/mod.rs` error handling**: The current `match handle.await` block needs to be carefully reviewed to ensure correct unwraping of the nested `Result` types and proper error propagation from `tokio::task::JoinError` to `anyhow::Error`.
- **Re-evaluate `src/tls.rs` `rustls::pki_types` usage**: A deeper understanding of `rustls-pki-types` crate and its `CertificateDer` and `PrivateKeyDer` lifetimes and conversion methods is needed. The specific error message `no method named to_vec found for struct PrivatePkcs8KeyDer` is a key indicator of incorrect usage.
- **Review `rustls` version and documentation**: It might be helpful to review the `rustls` and `tokio-rustls` documentation for version-specific changes and best practices regarding `pki_types` and asynchronous error handling.
This commit contains the work in progress as of the current session, including these unresolved errors, to allow for external review and debugging.

View File

@@ -31,13 +31,6 @@ $ cargo install l4p
Or you can download binary file form the Release page.
## Features
- Listen on specific port and proxy to local or remote port
- SNI-based rule without terminating TLS connection
- Wildcard SNI matching with DNS-style longest-suffix-match
- DNS-based backend with periodic resolution
## Configuration
`l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example:
@@ -62,14 +55,6 @@ There are two upstreams built in:
For detailed configuration, check [this example](./config.yaml.example).
### SNI Matching
The proxy supports both exact and wildcard SNI patterns in the `sni` config. Wildcards use DNS-style longest-suffix-match: more specific patterns take precedence. For example, with `*.example.com` and `*.api.example.com`, request `api.example.com` matches the first, while `v2.api.example.com` matches the second.
Wildcards are validated against the Public Suffix List (PSL). Known suffixes (`.com`, `.org`) require at least one label below the suffix (`*.example.com` OK, `*.com` rejected). Unknown suffixes (`.local`, `.lan`) are allowed without restriction.
Invalid wildcard patterns are rejected at config load time with clear error messages.
## Thanks
- [`fourth`](https://crates.io/crates/fourth), of which this is a heavily modified fork.

View File

@@ -10,9 +10,6 @@ servers:
sni:
api.example.org: example-api
www.example.org: proxy
*.example.org: wildcard-proxy # Matches any subdomain of example.org
*.dev.example.org: dev-proxy # More specific: matches v2.dev.example.org, etc.
*.local: local-upstream # Unknown suffix - allowed (no PSL restriction)
default: ban
second-server:
@@ -22,6 +19,3 @@ servers:
upstream:
proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443
example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443
wildcard-proxy: "tcp://wildcard.example.org:443"
dev-proxy: "tcp://dev.example.org:443"
local-upstream: "tcp://localhost:8080"

View File

@@ -1,50 +0,0 @@
#!/usr/bin/env bash
# vim: set expandtab shiftwidth=4 softtabstop=4 tabstop=4 :
set -euo pipefail
if [ -z "${CI_COMMIT_TAG:-}" ]; then
echo "No commit tag set"
exit 1
fi
DIR=$(realpath $(dirname "${BASH_SOURCE[0]}") )
echo ${DIR}
${DIR}/install_tea.sh linux-amd64 https://gitea.com/api/v1/repos/gitea/tea/releases/latest
## Log in to Gitea
TEA=$(pwd)/tea
if [ -z "${GITEA_SERVER_URL:-}" ]; then
if [ -z "${CI_FORGE_URL:-}" ]; then
echo "Cannot log in to gitea: GITEA_SERVER_URL or CI_FORGE_URL missing"
exit 1
fi
GITEA_SERVER_URL=${CI_FORGE_URL}
fi
if [ -z "${GITEA_SERVER_TOKEN:-}" ]; then
echo "Cannot log in to gitea: GITEA_SERVER_TOKEN missing"
exit 1
fi
if ! ${TEA} login ls | grep ${GITEA_SERVER_URL} 2>&1 > /dev/null || false; then
${TEA} login add
else
echo "Already logged in to ${GITEA_SERVER_URL}"
fi
## Check and create tag
if ${TEA} release ls -o json | jq -e --arg tag "${CI_COMMIT_TAG}" 'map(.["tag-_name"]) | index($tag) != null' >/dev/null; then
echo "Release ${CI_COMMIT_TAG} exists"
else
echo "Creating release ${CI_COMMIT_TAG}"
${TEA} release create -o json --tag "${CI_COMMIT_TAG}" --title "${CI_COMMIT_TAG}" --draft
fi
find $(dirname ${DIR})/artifacts -type f -exec ${TEA} releases assets create -o json ${CI_COMMIT_TAG} {} +

View File

@@ -1,39 +0,0 @@
#!/usr/bin/env bash
set -eo pipefail
if [ -x ./tea ]; then
echo "tea already installed in current directory"; exit 0
fi
platform="${1:-linux-amd64}"
src="${2:-release.json}"
# obtain JSON: if src looks like a URL fetch it, otherwise treat as filename (or default file)
if [[ "$src" =~ ^https?:// ]]; then
curl -fsSL "$src" -o /tmp/release.json.$$
json="/tmp/release.json.$$"
trap 'rm -f "$json"' EXIT
elif [ -f "$src" ]; then
json="$src"
else
echo "release JSON not found; provide a filename or URL as second arg" >&2
exit 1
fi
# read tag and find binary URL (exclude archives/checksums/sigs)
tag=$(jq -r '.tag_name' "$json")
url=$(jq -r --arg p "$platform" '.assets[]
| select(.name | test($p))
| select(.name | test("\\.(xz|zip|gz|tar|bz2|7z|sha256|sha256sum|sig|asc)$") | not)
| .browser_download_url' "$json" | head -n1)
[ -n "$url" ] || { echo "binary not found for $platform" >&2; exit 1; }
tmp="$(mktemp)"
trap 'rm -f "$tmp"' EXIT
curl -fsSL "$url" -o "$tmp"
mv "$tmp" tea
chmod +x tea
echo "Downloaded tag ${tag}: $url -> ./tea"

View File

@@ -1,4 +1,3 @@
use crate::sni_matcher::SniMatcher;
use crate::upstreams::ProxyToUpstream;
use crate::upstreams::Upstream;
use log::{debug, info, warn};
@@ -13,6 +12,14 @@ pub struct ConfigV1 {
pub base: ParsedConfigV1,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct ParsedConfigV1 {
pub version: i32,
pub log: Option<String>,
pub servers: HashMap<String, ServerConfig>,
pub upstream: HashMap<String, Upstream>,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct BaseConfig {
pub version: i32,
@@ -21,14 +28,6 @@ pub struct BaseConfig {
pub upstream: HashMap<String, String>,
}
#[derive(Debug, Default, Clone)]
pub struct ParsedConfigV1 {
pub version: i32,
pub log: Option<String>,
pub servers: HashMap<String, ParsedServerConfig>,
pub upstream: HashMap<String, Upstream>,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct ServerConfig {
pub listen: Vec<String>,
@@ -36,35 +35,9 @@ pub struct ServerConfig {
pub tls: Option<bool>,
pub sni: Option<HashMap<String, String>>,
pub default: Option<String>,
}
impl ServerConfig {
pub fn into_parsed(self) -> Result<ParsedServerConfig, Vec<String>> {
let sni = match self.sni {
Some(sni_map) => {
let matcher = SniMatcher::new(sni_map)?;
Some(matcher)
}
None => None,
};
Ok(ParsedServerConfig {
listen: self.listen,
protocol: self.protocol,
tls: self.tls,
sni,
default: self.default,
})
}
}
#[derive(Debug, Clone)]
pub struct ParsedServerConfig {
pub listen: Vec<String>,
pub protocol: Option<String>,
pub tls: Option<bool>,
pub sni: Option<SniMatcher>,
pub default: Option<String>,
pub termination_certs: Option<TlsTerminationConfig>,
pub dynamic_backend_pattern: Option<String>,
pub fixed_backend_ipv6: Option<String>,
}
impl TryInto<ProxyToUpstream> for &str {
type Error = ConfigError;
@@ -132,23 +105,12 @@ impl ConfigV1 {
}
}
/// Load and parse configuration from a YAML string.
///
/// This public function takes raw YAML content as a string and returns a parsed,
/// validated configuration. It performs all validation including:
/// - Version checking
/// - SNI pattern validation
/// - Upstream URL parsing
/// - Cross-reference validation
///
/// # Arguments
/// * `yaml_str` - The YAML configuration content as a string
///
/// # Returns
/// * `Ok(ParsedConfigV1)` - Successfully parsed and validated configuration
/// * `Err(ConfigError)` - If YAML parsing fails or validation errors occur
pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigError> {
let base: BaseConfig = serde_yaml::from_str(yaml_str)?;
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
let mut contents = String::new();
let mut file = File::open(path)?;
file.read_to_string(&mut contents)?;
let base: BaseConfig = serde_yaml::from_str(&contents)?;
if base.version != 1 {
return Err(ConfigError::Custom(
@@ -158,12 +120,11 @@ pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigErr
let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
if !log_level.eq("disable") {
unsafe {
std::env::set_var("FOURTH_LOG", log_level.clone());
pretty_env_logger::init_custom_env("FOURTH_LOG");
}
std::env::set_var("FOURTH_LOG", log_level.clone());
pretty_env_logger::init_custom_env("FOURTH_LOG");
}
info!("Using config file: {}", &path);
debug!("Set log level to {}", log_level);
debug!("Config version {}", base.version);
@@ -177,50 +138,16 @@ pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigErr
parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups));
}
// Convert ServerConfig to ParsedServerConfig, collecting all SNI validation errors
let mut all_errors: Vec<String> = Vec::new();
let mut parsed_servers: HashMap<String, ParsedServerConfig> = HashMap::new();
for (name, server_config) in base.servers {
match server_config.into_parsed() {
Ok(parsed) => {
parsed_servers.insert(name, parsed);
}
Err(errors) => {
for err in errors {
all_errors.push(format!("Server '{}': {}", name, err));
}
}
}
}
if !all_errors.is_empty() {
return Err(ConfigError::Custom(format!(
"Invalid SNI configuration:\n{}",
all_errors.join("\n")
)));
}
let parsed = ParsedConfigV1 {
version: base.version,
log: base.log,
servers: parsed_servers,
servers: base.servers,
upstream: parsed_upstream,
};
verify_config(parsed)
}
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
let mut contents = String::new();
let mut file = File::open(path)?;
file.read_to_string(&mut contents)?;
info!("Using config file: {}", &path);
load_config_from_yaml(&contents)
}
fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError> {
let mut used_upstreams: HashSet<String> = HashSet::new();
let mut upstream_names: HashSet<String> = HashSet::new();
@@ -251,20 +178,14 @@ fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError>
listen_addresses.insert(listen.to_string());
}
if server.tls.unwrap_or_default() {
if let Some(matcher) = &server.sni {
// Collect all upstream names from the SniMatcher
for (_, upstream) in matcher.exact.iter() {
used_upstreams.insert(upstream.clone());
}
for pattern in &matcher.wildcards {
used_upstreams.insert(pattern.upstream.clone());
}
if server.tls.unwrap_or_default() && server.sni.is_some() {
for (_, val) in server.sni.unwrap() {
used_upstreams.insert(val.to_string());
}
}
if let Some(default) = &server.default {
used_upstreams.insert(default.clone());
if server.default.is_some() {
used_upstreams.insert(server.default.unwrap().to_string());
}
for key in &used_upstreams {
@@ -283,6 +204,25 @@ fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError>
Ok(config)
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct TlsTerminationConfig {
pub default_certificate: CertificateConfig,
pub sni_certificates: Option<Vec<SniCertificateConfig>>,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct CertificateConfig {
pub certificate_path: String,
pub private_key_path: String,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct SniCertificateConfig {
pub hostname: String,
pub certificate_path: String,
pub private_key_path: String,
}
impl From<IOError> for ConfigError {
fn from(err: IOError) -> ConfigError {
ConfigError::IO(err)
@@ -307,45 +247,4 @@ mod tests {
assert_eq!(config.base.servers.len(), 3);
assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
}
#[test]
fn test_config_hard_failure_on_invalid_sni() {
// Test that invalid SNI wildcard (*.com) causes hard failure
let config_content = r#"version: 1
log: disable
servers:
test_server:
listen:
- "127.0.0.1:8443"
protocol: tcp
tls: true
sni:
"*.com": "upstream1"
default: ban
upstream:
upstream1: tcp://127.0.0.1:9000
"#;
let result = load_config_from_yaml(config_content);
// Should fail with an error
assert!(result.is_err(), "Expected config to fail with invalid SNI");
// Verify error message contains helpful information
match result {
Err(ConfigError::Custom(msg)) => {
assert!(
msg.contains("Invalid SNI"),
"Error message should mention invalid SNI: {}",
msg
);
assert!(
msg.contains("*.com"),
"Error message should mention the invalid pattern: {}",
msg
);
}
_ => panic!("Expected ConfigError::Custom"),
}
}
}

View File

@@ -1,4 +1,3 @@
mod config_v1;
pub(crate) use config_v1::ConfigV1;
pub(crate) use config_v1::ParsedConfigV1;
pub(crate) use config_v1::ParsedServerConfig;
pub mod config_v1;
pub(crate) use config_v1::ConfigV1;
pub(crate) use config_v1::ParsedConfigV1;

View File

@@ -1,8 +1,8 @@
mod config;
mod servers;
mod sni_matcher;
mod update;
mod upstreams;
mod tls; // NEW: Declare the new TLS module
use crate::config::ConfigV1;
use crate::servers::Server;

View File

@@ -1,3 +1,4 @@
use anyhow::{anyhow, Result};
use log::{error, info};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
@@ -8,11 +9,15 @@ use tokio::task::JoinHandle;
mod protocol;
pub(crate) mod upstream_address;
use crate::config::ParsedConfigV1;
use crate::sni_matcher::SniMatcher;
use crate::config::{ParsedConfigV1, config_v1::TlsTerminationConfig};
// use crate::tls;
use crate::upstreams::Upstream;
use protocol::tcp;
// A helper to convert Box<dyn Error> to anyhow::Error
fn unhandled_error_for_box_error(e: Box<dyn std::error::Error>) -> anyhow::Error {
anyhow!("{}", e)
}
#[derive(Debug)]
pub(crate) struct Server {
pub proxies: Vec<Arc<Proxy>>,
@@ -24,9 +29,12 @@ pub(crate) struct Proxy {
pub listen: SocketAddr,
pub protocol: String,
pub tls: bool,
pub sni: Option<SniMatcher>,
pub sni: Option<HashMap<String, String>>,
pub default_action: String,
pub upstream: HashMap<String, Upstream>,
pub termination_certs: Option<TlsTerminationConfig>,
pub dynamic_backend_pattern: Option<String>,
pub fixed_backend_ipv6: Option<String>,
}
impl Server {
@@ -65,6 +73,9 @@ impl Server {
sni: sni.clone(),
default_action: default.clone(),
upstream: upstream.clone(),
termination_certs: proxy.termination_certs.clone(),
dynamic_backend_pattern: proxy.dynamic_backend_pattern.clone(),
fixed_backend_ipv6: proxy.fixed_backend_ipv6.clone(),
};
new_server.proxies.push(Arc::new(proxy));
}
@@ -74,7 +85,7 @@ impl Server {
}
#[tokio::main]
pub async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
pub async fn run(&mut self) -> Result<()> {
let proxies = self.proxies.clone();
let mut handles: Vec<JoinHandle<()>> = Vec::new();
@@ -84,15 +95,24 @@ impl Server {
config.protocol, config.name, config.listen
);
let handle = tokio::spawn(async move {
match config.protocol.as_ref() {
"tcp" | "tcp4" | "tcp6" => {
let res = tcp::proxy(config.clone()).await;
if res.is_err() {
error!("Failed to start {}: {}", config.name, res.err().unwrap());
}
if config.tls && config.termination_certs.is_some() {
// New TLS termination handling
let res = tcp::tls_proxy(config.clone()).await;
if res.is_err() {
error!("Failed to start TLS server {}: {}", config.name, res.err().unwrap());
}
_ => {
error!("Invalid protocol: {}", config.protocol)
} else {
// Existing plain TCP handling
match config.protocol.as_ref() {
"tcp" | "tcp4" | "tcp6" => {
let res = tcp::proxy(config.clone()).await;
if res.is_err() {
error!("Failed to start {}: {}", config.name, res.err().unwrap());
}
}
_ => {
error!("Invalid protocol: {}", config.protocol)
}
}
}
});
@@ -100,8 +120,7 @@ impl Server {
}
for handle in handles {
handle.await?;
}
handle.await.map_err(anyhow::Error::from)?.map_err(anyhow::Error::from)?; }
Ok(())
}
}
@@ -135,132 +154,6 @@ mod tests {
}
}
/// Mock server for wildcard SNI test that responds with "tls_wildcard_response" on first read
#[tokio::main]
async fn tls_mock_server_wildcard() {
let server_addr: SocketAddr = "127.0.0.1:54598".parse().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
loop {
let (mut stream, _) = listener.accept().await.unwrap();
// Read client hello (which will be peeked but not actually read by proxy)
let mut buf = [0u8; 1024];
let _ = stream.read(&mut buf).await;
// Send a response to verify connection succeeded
let _ = stream.write(b"tls_wildcard_response").await;
let _ = stream.shutdown().await;
}
}
/// Mock server for SNI test that doesn't match wildcard pattern
#[tokio::main]
async fn tls_mock_server_default() {
let server_addr: SocketAddr = "127.0.0.1:54597".parse().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
loop {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 1024];
let _ = stream.read(&mut buf).await;
let _ = stream.write(b"tls_default_response").await;
let _ = stream.shutdown().await;
}
}
/// Helper function to build a minimal TLS ClientHello with SNI extension
/// This creates a valid TLS 1.2 ClientHello packet with the specified SNI hostname
fn build_tls_client_hello(sni_hostname: &str) -> Vec<u8> {
// TLS record header (9 bytes)
let mut hello = Vec::new();
// Record type: Handshake (0x16)
hello.push(0x16);
// Version: TLS 1.2 (0x0303)
hello.extend_from_slice(&[0x03, 0x03]);
// We'll set the record length later
let record_length_pos = hello.len();
hello.extend_from_slice(&[0x00, 0x00]); // Placeholder for record length
// Handshake message type: ClientHello (0x01)
hello.push(0x01);
// We'll set the handshake length later
let handshake_length_pos = hello.len();
hello.extend_from_slice(&[0x00, 0x00, 0x00]); // Placeholder for handshake length
// ClientHello fields
// Protocol version: TLS 1.2 (0x0303)
hello.extend_from_slice(&[0x03, 0x03]);
// Random: 32 bytes (we'll use a fixed pattern)
hello.extend_from_slice(&[0x00; 32]);
// Session ID length: 0 (no session)
hello.push(0x00);
// Cipher suites length: 2 bytes + cipher suites
hello.extend_from_slice(&[0x00, 0x02]); // Length of cipher suites list (2 bytes)
// Cipher suite: TLS_RSA_WITH_AES_128_CBC_SHA (0x002F)
hello.extend_from_slice(&[0x00, 0x2F]);
// Compression methods length: 1 byte
hello.push(0x01);
// Compression method: null (0x00)
hello.push(0x00);
// Extensions
let extensions_start = hello.len();
// SNI Extension (type 0x0000)
let mut sni_extension = Vec::new();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Extension type: server_name
// Extension length (will be set later)
let ext_length_pos = sni_extension.len();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder
// Server name list
let server_name_list_start = sni_extension.len();
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder for server name list length
// Server name: host_name(0), length, hostname
sni_extension.push(0x00); // Name type: host_name
let hostname_bytes = sni_hostname.as_bytes();
sni_extension.extend_from_slice(&[(hostname_bytes.len() >> 8) as u8, (hostname_bytes.len() & 0xFF) as u8]);
sni_extension.extend_from_slice(hostname_bytes);
// Set server name list length
let server_name_list_len = sni_extension.len() - server_name_list_start - 2;
let pos = server_name_list_start;
sni_extension[pos] = (server_name_list_len >> 8) as u8;
sni_extension[pos + 1] = (server_name_list_len & 0xFF) as u8;
// Set extension length
let ext_len = sni_extension.len() - ext_length_pos - 2;
sni_extension[ext_length_pos] = (ext_len >> 8) as u8;
sni_extension[ext_length_pos + 1] = (ext_len & 0xFF) as u8;
// Add SNI extension to hello
hello.extend_from_slice(&sni_extension);
// Set extensions total length
let extensions_length = hello.len() - extensions_start;
hello.insert(extensions_start, (extensions_length & 0xFF) as u8);
hello.insert(extensions_start, (extensions_length >> 8) as u8);
// Set handshake message length
let handshake_len = hello.len() - handshake_length_pos - 3;
hello[handshake_length_pos] = (handshake_len >> 16) as u8;
hello[handshake_length_pos + 1] = (handshake_len >> 8) as u8;
hello[handshake_length_pos + 2] = (handshake_len & 0xFF) as u8;
// Set record length
let record_len = hello.len() - record_length_pos - 2;
hello[record_length_pos] = (record_len >> 8) as u8;
hello[record_length_pos + 1] = (record_len & 0xFF) as u8;
hello
}
#[tokio::test]
async fn test_proxy() {
use crate::config::ConfigV1;
@@ -297,118 +190,4 @@ mod tests {
}
conn.shutdown().await.unwrap();
}
#[tokio::test]
async fn test_wildcard_sni_routing() {
// Create test configuration with wildcard SNI pattern
use crate::upstreams::Upstream;
use std::collections::HashMap;
// Start mock servers for upstreams
thread::spawn(move || {
tls_mock_server_wildcard();
});
thread::spawn(move || {
tls_mock_server_default();
});
sleep(Duration::from_millis(500)); // wait for mock servers to start
// Create inline configuration
let mut config = crate::config::ParsedConfigV1 {
version: 1,
log: Some("disable".to_string()),
servers: HashMap::new(),
upstream: HashMap::new(),
};
// Add upstreams
config.upstream.insert(
"wildcard_upstream".to_string(),
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
"127.0.0.1:54598".to_string(),
"tcp".to_string(),
)),
);
config.upstream.insert(
"default_upstream".to_string(),
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
"127.0.0.1:54597".to_string(),
"tcp".to_string(),
)),
);
// Add TLS server with wildcard SNI pattern
let mut sni_map = HashMap::new();
sni_map.insert("*.api.example.com".to_string(), "wildcard_upstream".to_string());
let server_config = crate::config::ParsedServerConfig {
listen: vec!["127.0.0.1:54595".to_string()],
protocol: Some("tcp".to_string()),
tls: Some(true),
sni: Some(crate::sni_matcher::SniMatcher::new(sni_map).unwrap()),
default: Some("default_upstream".to_string()),
};
config.servers.insert("wildcard_test_server".to_string(), server_config);
// Start proxy server
let mut server = Server::new_from_v1_config(config);
thread::spawn(move || {
let _ = server.run();
});
sleep(Duration::from_secs(1)); // wait for proxy to start
// Test 1: Send ClientHello with SNI matching wildcard pattern
// Expected: Should route to wildcard_upstream (127.0.0.1:54598)
let client_hello = build_tls_client_hello("app.api.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello).await.unwrap();
let mut response = [0u8; 21];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from wildcard upstream");
assert_eq!(
&response[..n],
b"tls_wildcard_response",
"Should receive expected response from wildcard upstream"
);
conn.shutdown().await.unwrap();
// Test 2: Send ClientHello with SNI not matching any pattern
// Expected: Should route to default_upstream (127.0.0.1:54597)
let client_hello_nomatch = build_tls_client_hello("unrelated.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello_nomatch).await.unwrap();
let mut response = [0u8; 20];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from default upstream");
assert_eq!(
&response[..n],
b"tls_default_response",
"Should receive expected response from default upstream"
);
conn.shutdown().await.unwrap();
// Test 3: Send ClientHello with another SNI matching wildcard pattern
let client_hello_match2 = build_tls_client_hello("v2.api.example.com");
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
.await
.unwrap();
let _ = conn.write(&client_hello_match2).await.unwrap();
let mut response = [0u8; 21];
let n = conn.read(&mut response).await.unwrap();
assert!(n > 0, "Should receive response from wildcard upstream for second match");
assert_eq!(
&response[..n],
b"tls_wildcard_response",
"Should receive expected response from wildcard upstream for second match"
);
conn.shutdown().await.unwrap();
}
}

View File

@@ -1,2 +1 @@
pub mod tcp;
pub mod tls;

View File

@@ -1,11 +1,48 @@
use crate::servers::protocol::tls::determine_upstream_name;
use anyhow::Result;
use crate::servers::Proxy;
use log::{debug, error, info, warn};
use std::error::Error;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use crate::tls;
pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn Error>> {
pub(crate) async fn tls_proxy(config: Arc<Proxy>) -> Result<()> {
let listener = TcpListener::bind(config.listen).await?;
let config = config.clone();
let acceptor = tls::build_tls_acceptor(
config.termination_certs.as_ref().expect("TLS termination config missing"),
)?;
loop {
let _config = config.clone();
let thread_acceptor = acceptor.clone();
match listener.accept().await {
Err(err) => {
error!("Failed to accept TLS connection: {}", err);
}
Ok((stream, _)) => {
tokio::spawn(async move {
let res = match thread_acceptor.accept(stream).await {
Ok(tls_stream) => {
info!("TLS handshake successful with {:?}", tls_stream.into_inner().0.peer_addr().ok().map(|s| s.to_string()).unwrap_or_else(|| "unknown".to_string()));
Ok(()) // Return Ok(()) for now
}
Err(err) => {
error!("TLS handshake failed: {}", err);
Err(anyhow::anyhow!("{}", err))
}
};
if res.is_err() {
error!("TLS handling error: {}", res.unwrap_err());
}
});
}
}
}
}
pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<()> {
let listener = TcpListener::bind(config.listen).await?;
let config = config.clone();
@@ -14,7 +51,7 @@ pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn Error>> {
match listener.accept().await {
Err(err) => {
error!("Failed to accept connection: {}", err);
return Err(Box::new(err));
return Err(anyhow::Error::new(err)); // Convert to anyhow::Error
}
Ok((stream, _)) => {
tokio::spawn(async move {
@@ -30,13 +67,10 @@ pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn Error>> {
}
}
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn Error>> {
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<()> {
info!("New connection from {:?}", inbound.peer_addr()?);
let upstream_name = match proxy.tls {
false => proxy.default_action.clone(),
true => determine_upstream_name(&inbound, &proxy).await?,
};
let upstream_name = proxy.default_action.clone();
debug!("Upstream: {}", upstream_name);
@@ -47,9 +81,9 @@ async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn Err
"No upstream named {:?} on server {:?}",
proxy.default_action, proxy.name
);
proxy.upstream.get(&proxy.default_action).unwrap()
proxy.upstream.get(&proxy.default_action).expect("Default upstream must exist")
}
};
upstream.process(inbound).await
upstream.process(inbound).await.map_err(|e| anyhow::anyhow!("{}", e))
}

View File

@@ -206,26 +206,36 @@ pub(crate) async fn determine_upstream_name(
if snis.is_empty() {
debug!("No SNI found in ClientHello, using default upstream.");
return Ok(default_upstream);
}
if let Some(matcher) = &proxy.sni {
for sni in snis {
if let Some(upstream) = matcher.match_sni(&sni) {
debug!(
"Found matching SNI '{}', routing to upstream: {}",
sni, upstream
);
return Ok(upstream);
} else {
trace!("SNI '{}' not found in matcher.", sni);
} else {
match proxy.sni.clone() {
Some(sni_map) => {
let mut upstream = default_upstream.clone(); // Clone here for default case
let mut found_match = false;
for sni in snis {
// snis is already Vec<String>
if let Some(target_upstream) = sni_map.get(&sni) {
debug!(
"Found matching SNI '{}', routing to upstream: {}",
sni, target_upstream
);
upstream = target_upstream.clone();
found_match = true;
break;
} else {
trace!("SNI '{}' not found in map.", sni);
}
}
if !found_match {
debug!("SNI(s) found but none matched configuration, using default upstream.");
}
Ok(upstream)
}
None => {
debug!("SNI found but no SNI map configured, using default upstream.");
Ok(default_upstream)
}
}
debug!("SNI(s) found but none matched configuration, using default upstream.");
} else {
debug!("SNI found but no SNI matcher configured, using default upstream.");
}
Ok(default_upstream)
}
fn client_hello_buffer_size(data: &[u8]) -> Result<usize, String> {

View File

@@ -1,391 +0,0 @@
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct SniPattern {
pub pattern: String,
pub upstream: String,
}
/// Validates and matches SNI patterns against incoming SNI values. Supports exact matches and wildcard patterns.
///
/// Rules:
/// - Wildcard patterns must start with "*." and have a valid domain suffix after it (e.g., "*.example.com")
/// - Wildcard patterns cannot have a wildcard in the middle (e.g., "*.example.*" is invalid)
/// - Wildcard patterns cannot have multiple wildcards (e.g., "*.*.example.com" is invalid)
/// - When there are two wildcard patterns that could match the same SNI, the longest suffix wins
/// (e.g., "*.example.com" vs "*.api.example.com" - "v2.api.example.com" matches "*.api.example.com")
/// - Wildcard patterns that overlap are not matched (e.g. "*.example.com" and "*.bar.example.com" - "bar.example.com" matches neither, "v2.bar.example.com" matches "*.bar.example.com")
/// - Wildcard patterns cannot be just "*." or "*" (must have a valid suffix)
/// - For known public suffixes (e.g., "com", "org"), the wildcard must have at least one label below the public suffix (e.g., "*.example.com" is valid, but "*.com", or *.co.uk is invalid)
/// - For unknown suffixes (e.g., "local", "lan"), the wildcard is allowed without restriction (e.g., "*.local" is valid)
#[derive(Debug, Clone)]
pub struct SniMatcher {
pub exact: HashMap<String, String>,
pub wildcards: Vec<SniPattern>,
}
impl SniMatcher {
pub fn new(sni_map: HashMap<String, String>) -> Result<Self, Vec<String>> {
Self::validate(&sni_map)?;
let mut exact = HashMap::new();
let mut wildcards = Vec::new();
for (pattern, upstream) in sni_map {
if pattern.starts_with("*.") {
wildcards.push(SniPattern {
pattern: pattern.clone(),
upstream,
});
} else {
exact.insert(pattern, upstream);
}
}
wildcards.sort_by(|a, b| {
let a_suffix = a.pattern.trim_start_matches("*.");
let b_suffix = b.pattern.trim_start_matches("*.");
b_suffix.len().cmp(&a_suffix.len())
});
Ok(SniMatcher { exact, wildcards })
}
/// Matches the provided SNI against the patterns in the matcher. Returns Some(upstream) if a match is found, or None if no match is found.
pub fn match_sni(&self, sni: &str) -> Option<String> {
if let Some(upstream) = self.exact.get(sni) {
return Some(upstream.clone());
}
// Try each wildcard in order (longest suffix first)
for wildcard in &self.wildcards {
let suffix = wildcard.pattern.trim_start_matches("*.");
let suffix_len = suffix.len();
let check = format!(".{}", suffix);
if !sni.ends_with(&check) {
continue;
}
// Must have at least one label before the suffix to match
let prefix = &sni[..sni.len() - suffix_len - 1];
// Check if a more specific wildcard could also match this SNI
let is_owned = {
let sni_labels = sni.matches('.').count() + 1;
self.wildcards.iter().any(|w| {
// Skip the current wildcard itself
if w.pattern == wildcard.pattern {
return false;
}
let w_suffix = w.pattern.trim_start_matches("*.");
let w_len = w_suffix.len();
// Only care about wildcards with longer suffix (more specific)
if w_len <= suffix_len {
return false;
}
let w_suffix_labels = w_suffix.matches('.').count() + 1;
if sni == w_suffix {
// Exact match to more specific suffix - owned if could potentially match
// (sni has at least as many labels as the suffix needs)
sni_labels >= w_suffix_labels
} else if sni.ends_with(&format!(".{}", w_suffix)) {
// Ends with more specific suffix - owned if SNI has enough labels
sni_labels >= w_suffix_labels + 1
} else {
false
}
})
};
if is_owned {
continue;
}
// Only return if we have a valid prefix (at least one label)
if !prefix.is_empty() {
return Some(wildcard.upstream.clone());
}
}
None
}
fn validate_wildcard_suffix(pattern: &str) -> Result<(), String> {
let suffix = pattern.trim_start_matches("*.");
let domain_str = format!("a.{}", suffix);
if let Some(ps) = psl::suffix(domain_str.as_bytes()) {
let ps_str = std::str::from_utf8(ps.as_bytes()).unwrap_or("");
if ps_str == suffix {
return Err(format!(
"Invalid wildcard pattern: {} - wildcard cannot be at the public suffix level",
pattern
));
}
}
Ok(())
}
fn validate(sni_map: &HashMap<String, String>) -> Result<(), Vec<String>> {
let mut errors = Vec::new();
for (pattern, _upstream) in sni_map {
if pattern == "*" {
errors.push(format!(
"Invalid wildcard pattern: * - just asterisk is not allowed"
));
continue;
}
if pattern == "*." {
errors.push(format!(
"Invalid wildcard pattern: *. - empty suffix after wildcard"
));
continue;
}
if let Some(rest) = pattern.strip_prefix("*.") {
if rest.is_empty() {
errors.push(format!(
"Invalid wildcard pattern: {pattern} - empty suffix after wildcard"
));
continue;
}
if rest.contains('*') {
errors.push(format!("Invalid wildcard pattern: {pattern} - wildcard cannot be in the middle of suffix"));
continue;
}
if let Err(e) = Self::validate_wildcard_suffix(pattern) {
errors.push(e);
}
} else if pattern.contains('*') {
errors.push(format!(
"Invalid wildcard pattern: {pattern} - wildcard must be at the start"
));
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("example.com"),
Some("upstream1".to_string())
);
}
#[test]
fn test_wildcard_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("www.example.com"),
Some("upstream2".to_string())
);
assert_eq!(
matcher.match_sni("api.example.com"),
Some("upstream2".to_string())
);
}
#[test]
fn test_longest_suffix_match() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "wildcard1".to_string());
sni_map.insert("*.api.example.com".to_string(), "wildcard2".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("v2.api.example.com"),
Some("wildcard2".to_string())
);
}
#[test]
fn test_no_match() {
let mut sni_map = HashMap::new();
sni_map.insert("example.com".to_string(), "upstream1".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("other.com"), None);
}
#[test]
fn test_unknown_suffix_allowed() {
let mut sni_map = HashMap::new();
sni_map.insert(
"*.private.local".to_string(),
"private_upstream".to_string(),
);
sni_map.insert(
"*.internal.net".to_string(),
"internal_upstream".to_string(),
);
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("server.private.local"),
Some("private_upstream".to_string())
);
assert_eq!(
matcher.match_sni("app.internal.net"),
Some("internal_upstream".to_string())
);
}
#[test]
fn test_invalid_public_suffix() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert!(!errors.is_empty());
assert!(errors[0].contains("*.com"));
}
#[test]
fn test_multiple_errors_collected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid1".to_string());
sni_map.insert("*.org".to_string(), "invalid2".to_string());
sni_map.insert("*.net".to_string(), "invalid3".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert_eq!(errors.len(), 3);
}
#[test]
fn test_valid_public_suffix() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "valid".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(
matcher.match_sni("www.example.com"),
Some("valid".to_string())
);
}
#[test]
fn test_validate_static() {
let mut sni_map = HashMap::new();
sni_map.insert("*.com".to_string(), "invalid".to_string());
sni_map.insert("*.example.com".to_string(), "valid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert_eq!(errors.len(), 1);
assert!(errors[0].contains("*.com"));
}
#[test]
fn test_wildcard_not_at_start_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("foo*.example.com".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
let errors = result.unwrap_err();
assert!(!errors.is_empty());
assert!(errors[0].contains("*.example.com"));
}
#[test]
fn test_wildcard_in_middle_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.*".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_trailing_dot_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*.".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_just_asterisk_rejected() {
let mut sni_map = HashMap::new();
sni_map.insert("*".to_string(), "invalid".to_string());
let result = SniMatcher::new(sni_map);
assert!(result.is_err());
}
#[test]
fn test_wildcard_requires_subdomain() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.com".to_string(), "upstream".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("example.com"), None);
assert_eq!(
matcher.match_sni("www.example.com"),
Some("upstream".to_string())
);
assert_eq!(
matcher.match_sni("foo.bar.example.com"),
Some("upstream".to_string())
);
}
#[test]
fn test_longest_suffix_wins_not_shortest() {
let mut sni_map = HashMap::new();
sni_map.insert("*.example.org".to_string(), "broad".to_string());
sni_map.insert("*.bar.example.org".to_string(), "narrow".to_string());
let matcher = SniMatcher::new(sni_map).unwrap();
assert_eq!(matcher.match_sni("bar.example.org"), None);
assert_eq!(
matcher.match_sni("v2.bar.example.org"),
Some("narrow".to_string())
);
assert_eq!(
matcher.match_sni("v2.example.org"),
Some("broad".to_string())
);
}
}

37
src/tls.rs Normal file
View File

@@ -0,0 +1,37 @@
use anyhow::{anyhow, Result};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use rustls::ServerConfig;
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
pub fn load_certs(path: &str) -> Result<Vec<CertificateDer>> {
let mut reader = BufReader::new(File::open(path)?);
certs(&mut reader)
.collect::<Result<Vec<CertificateDer>, std::io::Error>>()
.map_err(|e| anyhow!("failed to load certificates from {}: {}", path, e))
}
pub fn load_private_key(path: &str) -> Result<PrivateKeyDer> {
let mut reader = BufReader::new(File::open(path)?);
let mut keys = pkcs8_private_keys(&mut reader)
.collect::<Result<Vec<PrivatePkcs8KeyDer>, std::io::Error>>()
.map_err(|e| anyhow!("failed to load private keys from {}: {}", path, e))?;
keys.pop()
.map(|k| PrivateKeyDer::Pkcs8(k.to_vec().into()))
.ok_or_else(|| anyhow!("no private keys found for {}", path))
}
pub fn build_tls_acceptor(
config: &crate::config::config_v1::TlsTerminationConfig,
) -> Result<tokio_rustls::TlsAcceptor> {
let certs = load_certs(&config.default_certificate.certificate_path)?;
let key = load_private_key(&config.default_certificate.private_key_path)?;
let tls_config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?;
Ok(tokio_rustls::TlsAcceptor::from(Arc::new(tls_config)))
}