Compare commits
6 Commits
4c2711fc81
...
v0.1.12
| Author | SHA1 | Date | |
|---|---|---|---|
|
aff96d1a01
|
|||
|
590740f40e
|
|||
|
a674895173
|
|||
|
644ca99004
|
|||
| 913e50ff1c | |||
| aecffa0d14 |
52
.woodpecker/build.yaml
Normal file
52
.woodpecker/build.yaml
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
when:
|
||||||
|
- event:
|
||||||
|
- push
|
||||||
|
- tag
|
||||||
|
- manual
|
||||||
|
|
||||||
|
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- TARGET: x86_64-unknown-linux-musl
|
||||||
|
SHORT: amd64-musl
|
||||||
|
BIN_SUFFIX:
|
||||||
|
- TARGET: aarch64-unknown-linux-musl
|
||||||
|
SHORT: arm64-musl
|
||||||
|
BIN_SUFFIX:
|
||||||
|
- TARGET: x86_64-pc-windows-gnu
|
||||||
|
SHORT: windows
|
||||||
|
BIN_SUFFIX: .exe
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Prepare
|
||||||
|
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
|
||||||
|
commands:
|
||||||
|
- echo Using image img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
|
||||||
|
- mkdir -p artifacts
|
||||||
|
- cargo --version
|
||||||
|
- rustc --version
|
||||||
|
- set
|
||||||
|
|
||||||
|
- name: Build for ${SHORT}
|
||||||
|
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
|
||||||
|
commands:
|
||||||
|
- echo Building ${TARGET} \(${SHORT}\)
|
||||||
|
- cargo zigbuild --release --target ${TARGET}
|
||||||
|
- mkdir -p artifacts
|
||||||
|
- cp target/${TARGET}/release/l4p${BIN_SUFFIX} artifacts/l4p-${TARGET}${BIN_SUFFIX}
|
||||||
|
- rm -rf target/${TARGET}/release/*
|
||||||
|
depends_on:
|
||||||
|
- Prepare
|
||||||
|
|
||||||
|
- name: Release
|
||||||
|
image: img.kie.rs/jjkiers/rust-crossbuild:rust1.79.0-zig0.11.0-zig
|
||||||
|
when:
|
||||||
|
- event: tag
|
||||||
|
commands:
|
||||||
|
- ls -lah artifacts
|
||||||
|
- scripts/create_release_artifacts.sh
|
||||||
|
environment:
|
||||||
|
GITEA_SERVER_TOKEN:
|
||||||
|
from_secret: gitea_token
|
||||||
|
depends_on:
|
||||||
|
- Build for ${SHORT}
|
||||||
20
Cargo.lock
generated
20
Cargo.lock
generated
@@ -1,6 +1,6 @@
|
|||||||
# This file is automatically @generated by Cargo.
|
# This file is automatically @generated by Cargo.
|
||||||
# It is not intended for manual editing.
|
# It is not intended for manual editing.
|
||||||
version = 3
|
version = 4
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "addr2line"
|
name = "addr2line"
|
||||||
@@ -516,7 +516,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "l4p"
|
name = "l4p"
|
||||||
version = "0.1.9"
|
version = "0.1.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"byte_string",
|
"byte_string",
|
||||||
@@ -525,6 +525,7 @@ dependencies = [
|
|||||||
"log",
|
"log",
|
||||||
"pico-args",
|
"pico-args",
|
||||||
"pretty_env_logger",
|
"pretty_env_logger",
|
||||||
|
"psl",
|
||||||
"self_update",
|
"self_update",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
@@ -856,6 +857,21 @@ dependencies = [
|
|||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "psl"
|
||||||
|
version = "2.1.199"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "70b63978a2742d3f662188698ab45854156e7e34658f53fa951e9253a3dfd583"
|
||||||
|
dependencies = [
|
||||||
|
"psl-types",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "psl-types"
|
||||||
|
version = "2.0.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quick-xml"
|
name = "quick-xml"
|
||||||
version = "0.37.2"
|
version = "0.37.2"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "l4p"
|
name = "l4p"
|
||||||
version = "0.1.10"
|
version = "0.1.12"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Jacob Kiers <code@kiers.eu>"]
|
authors = ["Jacob Kiers <code@kiers.eu>"]
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
@@ -33,6 +33,7 @@ time = { version = "0.3.37", features = ["local-offset", "formatting"] }
|
|||||||
tls-parser = "0.12.2"
|
tls-parser = "0.12.2"
|
||||||
tokio = { version = "1.0", features = ["full"] }
|
tokio = { version = "1.0", features = ["full"] }
|
||||||
url = "2.2.2"
|
url = "2.2.2"
|
||||||
|
psl = "2.1"
|
||||||
|
|
||||||
[dependencies.self_update]
|
[dependencies.self_update]
|
||||||
version = "0.42.0"
|
version = "0.42.0"
|
||||||
|
|||||||
15
README.md
15
README.md
@@ -31,6 +31,13 @@ $ cargo install l4p
|
|||||||
|
|
||||||
Or you can download binary file form the Release page.
|
Or you can download binary file form the Release page.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Listen on specific port and proxy to local or remote port
|
||||||
|
- SNI-based rule without terminating TLS connection
|
||||||
|
- Wildcard SNI matching with DNS-style longest-suffix-match
|
||||||
|
- DNS-based backend with periodic resolution
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
`l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example:
|
`l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example:
|
||||||
@@ -55,6 +62,14 @@ There are two upstreams built in:
|
|||||||
|
|
||||||
For detailed configuration, check [this example](./config.yaml.example).
|
For detailed configuration, check [this example](./config.yaml.example).
|
||||||
|
|
||||||
|
### SNI Matching
|
||||||
|
|
||||||
|
The proxy supports both exact and wildcard SNI patterns in the `sni` config. Wildcards use DNS-style longest-suffix-match: more specific patterns take precedence. For example, with `*.example.com` and `*.api.example.com`, request `api.example.com` matches the first, while `v2.api.example.com` matches the second.
|
||||||
|
|
||||||
|
Wildcards are validated against the Public Suffix List (PSL). Known suffixes (`.com`, `.org`) require at least one label below the suffix (`*.example.com` OK, `*.com` rejected). Unknown suffixes (`.local`, `.lan`) are allowed without restriction.
|
||||||
|
|
||||||
|
Invalid wildcard patterns are rejected at config load time with clear error messages.
|
||||||
|
|
||||||
## Thanks
|
## Thanks
|
||||||
|
|
||||||
- [`fourth`](https://crates.io/crates/fourth), of which this is a heavily modified fork.
|
- [`fourth`](https://crates.io/crates/fourth), of which this is a heavily modified fork.
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ servers:
|
|||||||
sni:
|
sni:
|
||||||
api.example.org: example-api
|
api.example.org: example-api
|
||||||
www.example.org: proxy
|
www.example.org: proxy
|
||||||
|
*.example.org: wildcard-proxy # Matches any subdomain of example.org
|
||||||
|
*.dev.example.org: dev-proxy # More specific: matches v2.dev.example.org, etc.
|
||||||
|
*.local: local-upstream # Unknown suffix - allowed (no PSL restriction)
|
||||||
default: ban
|
default: ban
|
||||||
|
|
||||||
second-server:
|
second-server:
|
||||||
@@ -19,3 +22,6 @@ servers:
|
|||||||
upstream:
|
upstream:
|
||||||
proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443
|
proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443
|
||||||
example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443
|
example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443
|
||||||
|
wildcard-proxy: "tcp://wildcard.example.org:443"
|
||||||
|
dev-proxy: "tcp://dev.example.org:443"
|
||||||
|
local-upstream: "tcp://localhost:8080"
|
||||||
|
|||||||
50
scripts/create_release_artifacts.sh
Executable file
50
scripts/create_release_artifacts.sh
Executable file
@@ -0,0 +1,50 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# vim: set expandtab shiftwidth=4 softtabstop=4 tabstop=4 :
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
if [ -z "${CI_COMMIT_TAG:-}" ]; then
|
||||||
|
echo "No commit tag set"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
DIR=$(realpath $(dirname "${BASH_SOURCE[0]}") )
|
||||||
|
|
||||||
|
echo ${DIR}
|
||||||
|
|
||||||
|
${DIR}/install_tea.sh linux-amd64 https://gitea.com/api/v1/repos/gitea/tea/releases/latest
|
||||||
|
|
||||||
|
## Log in to Gitea
|
||||||
|
|
||||||
|
TEA=$(pwd)/tea
|
||||||
|
|
||||||
|
if [ -z "${GITEA_SERVER_URL:-}" ]; then
|
||||||
|
|
||||||
|
if [ -z "${CI_FORGE_URL:-}" ]; then
|
||||||
|
echo "Cannot log in to gitea: GITEA_SERVER_URL or CI_FORGE_URL missing"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
GITEA_SERVER_URL=${CI_FORGE_URL}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "${GITEA_SERVER_TOKEN:-}" ]; then
|
||||||
|
echo "Cannot log in to gitea: GITEA_SERVER_TOKEN missing"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! ${TEA} login ls | grep ${GITEA_SERVER_URL} 2>&1 > /dev/null || false; then
|
||||||
|
${TEA} login add
|
||||||
|
else
|
||||||
|
echo "Already logged in to ${GITEA_SERVER_URL}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
## Check and create tag
|
||||||
|
|
||||||
|
if ${TEA} release ls -o json | jq -e --arg tag "${CI_COMMIT_TAG}" 'map(.["tag-_name"]) | index($tag) != null' >/dev/null; then
|
||||||
|
echo "Release ${CI_COMMIT_TAG} exists"
|
||||||
|
else
|
||||||
|
echo "Creating release ${CI_COMMIT_TAG}"
|
||||||
|
${TEA} release create -o json --tag "${CI_COMMIT_TAG}" --title "${CI_COMMIT_TAG}" --draft
|
||||||
|
fi
|
||||||
|
|
||||||
|
find $(dirname ${DIR})/artifacts -type f -exec ${TEA} releases assets create -o json ${CI_COMMIT_TAG} {} +
|
||||||
39
scripts/install_tea.sh
Executable file
39
scripts/install_tea.sh
Executable file
@@ -0,0 +1,39 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -eo pipefail
|
||||||
|
|
||||||
|
if [ -x ./tea ]; then
|
||||||
|
echo "tea already installed in current directory"; exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
platform="${1:-linux-amd64}"
|
||||||
|
src="${2:-release.json}"
|
||||||
|
|
||||||
|
# obtain JSON: if src looks like a URL fetch it, otherwise treat as filename (or default file)
|
||||||
|
if [[ "$src" =~ ^https?:// ]]; then
|
||||||
|
curl -fsSL "$src" -o /tmp/release.json.$$
|
||||||
|
json="/tmp/release.json.$$"
|
||||||
|
trap 'rm -f "$json"' EXIT
|
||||||
|
elif [ -f "$src" ]; then
|
||||||
|
json="$src"
|
||||||
|
else
|
||||||
|
echo "release JSON not found; provide a filename or URL as second arg" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# read tag and find binary URL (exclude archives/checksums/sigs)
|
||||||
|
tag=$(jq -r '.tag_name' "$json")
|
||||||
|
url=$(jq -r --arg p "$platform" '.assets[]
|
||||||
|
| select(.name | test($p))
|
||||||
|
| select(.name | test("\\.(xz|zip|gz|tar|bz2|7z|sha256|sha256sum|sig|asc)$") | not)
|
||||||
|
| .browser_download_url' "$json" | head -n1)
|
||||||
|
|
||||||
|
[ -n "$url" ] || { echo "binary not found for $platform" >&2; exit 1; }
|
||||||
|
|
||||||
|
tmp="$(mktemp)"
|
||||||
|
trap 'rm -f "$tmp"' EXIT
|
||||||
|
|
||||||
|
curl -fsSL "$url" -o "$tmp"
|
||||||
|
mv "$tmp" tea
|
||||||
|
chmod +x tea
|
||||||
|
echo "Downloaded tag ${tag}: $url -> ./tea"
|
||||||
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::sni_matcher::SniMatcher;
|
||||||
use crate::upstreams::ProxyToUpstream;
|
use crate::upstreams::ProxyToUpstream;
|
||||||
use crate::upstreams::Upstream;
|
use crate::upstreams::Upstream;
|
||||||
use log::{debug, info, warn};
|
use log::{debug, info, warn};
|
||||||
@@ -12,14 +13,6 @@ pub struct ConfigV1 {
|
|||||||
pub base: ParsedConfigV1,
|
pub base: ParsedConfigV1,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Deserialize, Clone)]
|
|
||||||
pub struct ParsedConfigV1 {
|
|
||||||
pub version: i32,
|
|
||||||
pub log: Option<String>,
|
|
||||||
pub servers: HashMap<String, ServerConfig>,
|
|
||||||
pub upstream: HashMap<String, Upstream>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Default, Deserialize, Clone)]
|
#[derive(Debug, Default, Deserialize, Clone)]
|
||||||
pub struct BaseConfig {
|
pub struct BaseConfig {
|
||||||
pub version: i32,
|
pub version: i32,
|
||||||
@@ -28,6 +21,14 @@ pub struct BaseConfig {
|
|||||||
pub upstream: HashMap<String, String>,
|
pub upstream: HashMap<String, String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
|
pub struct ParsedConfigV1 {
|
||||||
|
pub version: i32,
|
||||||
|
pub log: Option<String>,
|
||||||
|
pub servers: HashMap<String, ParsedServerConfig>,
|
||||||
|
pub upstream: HashMap<String, Upstream>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Deserialize, Clone)]
|
#[derive(Debug, Default, Deserialize, Clone)]
|
||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
pub listen: Vec<String>,
|
pub listen: Vec<String>,
|
||||||
@@ -36,6 +37,35 @@ pub struct ServerConfig {
|
|||||||
pub sni: Option<HashMap<String, String>>,
|
pub sni: Option<HashMap<String, String>>,
|
||||||
pub default: Option<String>,
|
pub default: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ServerConfig {
|
||||||
|
pub fn into_parsed(self) -> Result<ParsedServerConfig, Vec<String>> {
|
||||||
|
let sni = match self.sni {
|
||||||
|
Some(sni_map) => {
|
||||||
|
let matcher = SniMatcher::new(sni_map)?;
|
||||||
|
Some(matcher)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ParsedServerConfig {
|
||||||
|
listen: self.listen,
|
||||||
|
protocol: self.protocol,
|
||||||
|
tls: self.tls,
|
||||||
|
sni,
|
||||||
|
default: self.default,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ParsedServerConfig {
|
||||||
|
pub listen: Vec<String>,
|
||||||
|
pub protocol: Option<String>,
|
||||||
|
pub tls: Option<bool>,
|
||||||
|
pub sni: Option<SniMatcher>,
|
||||||
|
pub default: Option<String>,
|
||||||
|
}
|
||||||
impl TryInto<ProxyToUpstream> for &str {
|
impl TryInto<ProxyToUpstream> for &str {
|
||||||
type Error = ConfigError;
|
type Error = ConfigError;
|
||||||
|
|
||||||
@@ -102,12 +132,23 @@ impl ConfigV1 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
|
/// Load and parse configuration from a YAML string.
|
||||||
let mut contents = String::new();
|
///
|
||||||
let mut file = File::open(path)?;
|
/// This public function takes raw YAML content as a string and returns a parsed,
|
||||||
file.read_to_string(&mut contents)?;
|
/// validated configuration. It performs all validation including:
|
||||||
|
/// - Version checking
|
||||||
let base: BaseConfig = serde_yaml::from_str(&contents)?;
|
/// - SNI pattern validation
|
||||||
|
/// - Upstream URL parsing
|
||||||
|
/// - Cross-reference validation
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `yaml_str` - The YAML configuration content as a string
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * `Ok(ParsedConfigV1)` - Successfully parsed and validated configuration
|
||||||
|
/// * `Err(ConfigError)` - If YAML parsing fails or validation errors occur
|
||||||
|
pub fn load_config_from_yaml(yaml_str: &str) -> Result<ParsedConfigV1, ConfigError> {
|
||||||
|
let base: BaseConfig = serde_yaml::from_str(yaml_str)?;
|
||||||
|
|
||||||
if base.version != 1 {
|
if base.version != 1 {
|
||||||
return Err(ConfigError::Custom(
|
return Err(ConfigError::Custom(
|
||||||
@@ -117,11 +158,12 @@ fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
|
|||||||
|
|
||||||
let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
|
let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
|
||||||
if !log_level.eq("disable") {
|
if !log_level.eq("disable") {
|
||||||
std::env::set_var("FOURTH_LOG", log_level.clone());
|
unsafe {
|
||||||
pretty_env_logger::init_custom_env("FOURTH_LOG");
|
std::env::set_var("FOURTH_LOG", log_level.clone());
|
||||||
|
pretty_env_logger::init_custom_env("FOURTH_LOG");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Using config file: {}", &path);
|
|
||||||
debug!("Set log level to {}", log_level);
|
debug!("Set log level to {}", log_level);
|
||||||
debug!("Config version {}", base.version);
|
debug!("Config version {}", base.version);
|
||||||
|
|
||||||
@@ -135,16 +177,50 @@ fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
|
|||||||
parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups));
|
parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert ServerConfig to ParsedServerConfig, collecting all SNI validation errors
|
||||||
|
let mut all_errors: Vec<String> = Vec::new();
|
||||||
|
let mut parsed_servers: HashMap<String, ParsedServerConfig> = HashMap::new();
|
||||||
|
|
||||||
|
for (name, server_config) in base.servers {
|
||||||
|
match server_config.into_parsed() {
|
||||||
|
Ok(parsed) => {
|
||||||
|
parsed_servers.insert(name, parsed);
|
||||||
|
}
|
||||||
|
Err(errors) => {
|
||||||
|
for err in errors {
|
||||||
|
all_errors.push(format!("Server '{}': {}", name, err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !all_errors.is_empty() {
|
||||||
|
return Err(ConfigError::Custom(format!(
|
||||||
|
"Invalid SNI configuration:\n{}",
|
||||||
|
all_errors.join("\n")
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
let parsed = ParsedConfigV1 {
|
let parsed = ParsedConfigV1 {
|
||||||
version: base.version,
|
version: base.version,
|
||||||
log: base.log,
|
log: base.log,
|
||||||
servers: base.servers,
|
servers: parsed_servers,
|
||||||
upstream: parsed_upstream,
|
upstream: parsed_upstream,
|
||||||
};
|
};
|
||||||
|
|
||||||
verify_config(parsed)
|
verify_config(parsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
|
||||||
|
let mut contents = String::new();
|
||||||
|
let mut file = File::open(path)?;
|
||||||
|
file.read_to_string(&mut contents)?;
|
||||||
|
|
||||||
|
info!("Using config file: {}", &path);
|
||||||
|
|
||||||
|
load_config_from_yaml(&contents)
|
||||||
|
}
|
||||||
|
|
||||||
fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError> {
|
fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError> {
|
||||||
let mut used_upstreams: HashSet<String> = HashSet::new();
|
let mut used_upstreams: HashSet<String> = HashSet::new();
|
||||||
let mut upstream_names: HashSet<String> = HashSet::new();
|
let mut upstream_names: HashSet<String> = HashSet::new();
|
||||||
@@ -175,14 +251,20 @@ fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError>
|
|||||||
listen_addresses.insert(listen.to_string());
|
listen_addresses.insert(listen.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
if server.tls.unwrap_or_default() && server.sni.is_some() {
|
if server.tls.unwrap_or_default() {
|
||||||
for (_, val) in server.sni.unwrap() {
|
if let Some(matcher) = &server.sni {
|
||||||
used_upstreams.insert(val.to_string());
|
// Collect all upstream names from the SniMatcher
|
||||||
|
for (_, upstream) in matcher.exact.iter() {
|
||||||
|
used_upstreams.insert(upstream.clone());
|
||||||
|
}
|
||||||
|
for pattern in &matcher.wildcards {
|
||||||
|
used_upstreams.insert(pattern.upstream.clone());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if server.default.is_some() {
|
if let Some(default) = &server.default {
|
||||||
used_upstreams.insert(server.default.unwrap().to_string());
|
used_upstreams.insert(default.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
for key in &used_upstreams {
|
for key in &used_upstreams {
|
||||||
@@ -225,4 +307,45 @@ mod tests {
|
|||||||
assert_eq!(config.base.servers.len(), 3);
|
assert_eq!(config.base.servers.len(), 3);
|
||||||
assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
|
assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_hard_failure_on_invalid_sni() {
|
||||||
|
// Test that invalid SNI wildcard (*.com) causes hard failure
|
||||||
|
let config_content = r#"version: 1
|
||||||
|
log: disable
|
||||||
|
servers:
|
||||||
|
test_server:
|
||||||
|
listen:
|
||||||
|
- "127.0.0.1:8443"
|
||||||
|
protocol: tcp
|
||||||
|
tls: true
|
||||||
|
sni:
|
||||||
|
"*.com": "upstream1"
|
||||||
|
default: ban
|
||||||
|
upstream:
|
||||||
|
upstream1: tcp://127.0.0.1:9000
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let result = load_config_from_yaml(config_content);
|
||||||
|
|
||||||
|
// Should fail with an error
|
||||||
|
assert!(result.is_err(), "Expected config to fail with invalid SNI");
|
||||||
|
|
||||||
|
// Verify error message contains helpful information
|
||||||
|
match result {
|
||||||
|
Err(ConfigError::Custom(msg)) => {
|
||||||
|
assert!(
|
||||||
|
msg.contains("Invalid SNI"),
|
||||||
|
"Error message should mention invalid SNI: {}",
|
||||||
|
msg
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
msg.contains("*.com"),
|
||||||
|
"Error message should mention the invalid pattern: {}",
|
||||||
|
msg
|
||||||
|
);
|
||||||
|
}
|
||||||
|
_ => panic!("Expected ConfigError::Custom"),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
mod config_v1;
|
mod config_v1;
|
||||||
pub(crate) use config_v1::ConfigV1;
|
pub(crate) use config_v1::ConfigV1;
|
||||||
pub(crate) use config_v1::ParsedConfigV1;
|
pub(crate) use config_v1::ParsedConfigV1;
|
||||||
|
pub(crate) use config_v1::ParsedServerConfig;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
mod config;
|
mod config;
|
||||||
mod servers;
|
mod servers;
|
||||||
|
mod sni_matcher;
|
||||||
mod update;
|
mod update;
|
||||||
mod upstreams;
|
mod upstreams;
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ mod protocol;
|
|||||||
pub(crate) mod upstream_address;
|
pub(crate) mod upstream_address;
|
||||||
|
|
||||||
use crate::config::ParsedConfigV1;
|
use crate::config::ParsedConfigV1;
|
||||||
|
use crate::sni_matcher::SniMatcher;
|
||||||
use crate::upstreams::Upstream;
|
use crate::upstreams::Upstream;
|
||||||
use protocol::tcp;
|
use protocol::tcp;
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ pub(crate) struct Proxy {
|
|||||||
pub listen: SocketAddr,
|
pub listen: SocketAddr,
|
||||||
pub protocol: String,
|
pub protocol: String,
|
||||||
pub tls: bool,
|
pub tls: bool,
|
||||||
pub sni: Option<HashMap<String, String>>,
|
pub sni: Option<SniMatcher>,
|
||||||
pub default_action: String,
|
pub default_action: String,
|
||||||
pub upstream: HashMap<String, Upstream>,
|
pub upstream: HashMap<String, Upstream>,
|
||||||
}
|
}
|
||||||
@@ -134,6 +135,132 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Mock server for wildcard SNI test that responds with "tls_wildcard_response" on first read
|
||||||
|
#[tokio::main]
|
||||||
|
async fn tls_mock_server_wildcard() {
|
||||||
|
let server_addr: SocketAddr = "127.0.0.1:54598".parse().unwrap();
|
||||||
|
let listener = TcpListener::bind(server_addr).await.unwrap();
|
||||||
|
loop {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
// Read client hello (which will be peeked but not actually read by proxy)
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
let _ = stream.read(&mut buf).await;
|
||||||
|
// Send a response to verify connection succeeded
|
||||||
|
let _ = stream.write(b"tls_wildcard_response").await;
|
||||||
|
let _ = stream.shutdown().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mock server for SNI test that doesn't match wildcard pattern
|
||||||
|
#[tokio::main]
|
||||||
|
async fn tls_mock_server_default() {
|
||||||
|
let server_addr: SocketAddr = "127.0.0.1:54597".parse().unwrap();
|
||||||
|
let listener = TcpListener::bind(server_addr).await.unwrap();
|
||||||
|
loop {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
let _ = stream.read(&mut buf).await;
|
||||||
|
let _ = stream.write(b"tls_default_response").await;
|
||||||
|
let _ = stream.shutdown().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to build a minimal TLS ClientHello with SNI extension
|
||||||
|
/// This creates a valid TLS 1.2 ClientHello packet with the specified SNI hostname
|
||||||
|
fn build_tls_client_hello(sni_hostname: &str) -> Vec<u8> {
|
||||||
|
// TLS record header (9 bytes)
|
||||||
|
let mut hello = Vec::new();
|
||||||
|
|
||||||
|
// Record type: Handshake (0x16)
|
||||||
|
hello.push(0x16);
|
||||||
|
// Version: TLS 1.2 (0x0303)
|
||||||
|
hello.extend_from_slice(&[0x03, 0x03]);
|
||||||
|
|
||||||
|
// We'll set the record length later
|
||||||
|
let record_length_pos = hello.len();
|
||||||
|
hello.extend_from_slice(&[0x00, 0x00]); // Placeholder for record length
|
||||||
|
|
||||||
|
// Handshake message type: ClientHello (0x01)
|
||||||
|
hello.push(0x01);
|
||||||
|
|
||||||
|
// We'll set the handshake length later
|
||||||
|
let handshake_length_pos = hello.len();
|
||||||
|
hello.extend_from_slice(&[0x00, 0x00, 0x00]); // Placeholder for handshake length
|
||||||
|
|
||||||
|
// ClientHello fields
|
||||||
|
// Protocol version: TLS 1.2 (0x0303)
|
||||||
|
hello.extend_from_slice(&[0x03, 0x03]);
|
||||||
|
|
||||||
|
// Random: 32 bytes (we'll use a fixed pattern)
|
||||||
|
hello.extend_from_slice(&[0x00; 32]);
|
||||||
|
|
||||||
|
// Session ID length: 0 (no session)
|
||||||
|
hello.push(0x00);
|
||||||
|
|
||||||
|
// Cipher suites length: 2 bytes + cipher suites
|
||||||
|
hello.extend_from_slice(&[0x00, 0x02]); // Length of cipher suites list (2 bytes)
|
||||||
|
// Cipher suite: TLS_RSA_WITH_AES_128_CBC_SHA (0x002F)
|
||||||
|
hello.extend_from_slice(&[0x00, 0x2F]);
|
||||||
|
|
||||||
|
// Compression methods length: 1 byte
|
||||||
|
hello.push(0x01);
|
||||||
|
// Compression method: null (0x00)
|
||||||
|
hello.push(0x00);
|
||||||
|
|
||||||
|
// Extensions
|
||||||
|
let extensions_start = hello.len();
|
||||||
|
|
||||||
|
// SNI Extension (type 0x0000)
|
||||||
|
let mut sni_extension = Vec::new();
|
||||||
|
sni_extension.extend_from_slice(&[0x00, 0x00]); // Extension type: server_name
|
||||||
|
|
||||||
|
// Extension length (will be set later)
|
||||||
|
let ext_length_pos = sni_extension.len();
|
||||||
|
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder
|
||||||
|
|
||||||
|
// Server name list
|
||||||
|
let server_name_list_start = sni_extension.len();
|
||||||
|
sni_extension.extend_from_slice(&[0x00, 0x00]); // Placeholder for server name list length
|
||||||
|
|
||||||
|
// Server name: host_name(0), length, hostname
|
||||||
|
sni_extension.push(0x00); // Name type: host_name
|
||||||
|
let hostname_bytes = sni_hostname.as_bytes();
|
||||||
|
sni_extension.extend_from_slice(&[(hostname_bytes.len() >> 8) as u8, (hostname_bytes.len() & 0xFF) as u8]);
|
||||||
|
sni_extension.extend_from_slice(hostname_bytes);
|
||||||
|
|
||||||
|
// Set server name list length
|
||||||
|
let server_name_list_len = sni_extension.len() - server_name_list_start - 2;
|
||||||
|
let pos = server_name_list_start;
|
||||||
|
sni_extension[pos] = (server_name_list_len >> 8) as u8;
|
||||||
|
sni_extension[pos + 1] = (server_name_list_len & 0xFF) as u8;
|
||||||
|
|
||||||
|
// Set extension length
|
||||||
|
let ext_len = sni_extension.len() - ext_length_pos - 2;
|
||||||
|
sni_extension[ext_length_pos] = (ext_len >> 8) as u8;
|
||||||
|
sni_extension[ext_length_pos + 1] = (ext_len & 0xFF) as u8;
|
||||||
|
|
||||||
|
// Add SNI extension to hello
|
||||||
|
hello.extend_from_slice(&sni_extension);
|
||||||
|
|
||||||
|
// Set extensions total length
|
||||||
|
let extensions_length = hello.len() - extensions_start;
|
||||||
|
hello.insert(extensions_start, (extensions_length & 0xFF) as u8);
|
||||||
|
hello.insert(extensions_start, (extensions_length >> 8) as u8);
|
||||||
|
|
||||||
|
// Set handshake message length
|
||||||
|
let handshake_len = hello.len() - handshake_length_pos - 3;
|
||||||
|
hello[handshake_length_pos] = (handshake_len >> 16) as u8;
|
||||||
|
hello[handshake_length_pos + 1] = (handshake_len >> 8) as u8;
|
||||||
|
hello[handshake_length_pos + 2] = (handshake_len & 0xFF) as u8;
|
||||||
|
|
||||||
|
// Set record length
|
||||||
|
let record_len = hello.len() - record_length_pos - 2;
|
||||||
|
hello[record_length_pos] = (record_len >> 8) as u8;
|
||||||
|
hello[record_length_pos + 1] = (record_len & 0xFF) as u8;
|
||||||
|
|
||||||
|
hello
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_proxy() {
|
async fn test_proxy() {
|
||||||
use crate::config::ConfigV1;
|
use crate::config::ConfigV1;
|
||||||
@@ -170,4 +297,118 @@ mod tests {
|
|||||||
}
|
}
|
||||||
conn.shutdown().await.unwrap();
|
conn.shutdown().await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_wildcard_sni_routing() {
|
||||||
|
// Create test configuration with wildcard SNI pattern
|
||||||
|
use crate::upstreams::Upstream;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
// Start mock servers for upstreams
|
||||||
|
thread::spawn(move || {
|
||||||
|
tls_mock_server_wildcard();
|
||||||
|
});
|
||||||
|
thread::spawn(move || {
|
||||||
|
tls_mock_server_default();
|
||||||
|
});
|
||||||
|
sleep(Duration::from_millis(500)); // wait for mock servers to start
|
||||||
|
|
||||||
|
// Create inline configuration
|
||||||
|
let mut config = crate::config::ParsedConfigV1 {
|
||||||
|
version: 1,
|
||||||
|
log: Some("disable".to_string()),
|
||||||
|
servers: HashMap::new(),
|
||||||
|
upstream: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add upstreams
|
||||||
|
config.upstream.insert(
|
||||||
|
"wildcard_upstream".to_string(),
|
||||||
|
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
|
||||||
|
"127.0.0.1:54598".to_string(),
|
||||||
|
"tcp".to_string(),
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
config.upstream.insert(
|
||||||
|
"default_upstream".to_string(),
|
||||||
|
Upstream::Proxy(crate::upstreams::ProxyToUpstream::new(
|
||||||
|
"127.0.0.1:54597".to_string(),
|
||||||
|
"tcp".to_string(),
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add TLS server with wildcard SNI pattern
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.api.example.com".to_string(), "wildcard_upstream".to_string());
|
||||||
|
|
||||||
|
let server_config = crate::config::ParsedServerConfig {
|
||||||
|
listen: vec!["127.0.0.1:54595".to_string()],
|
||||||
|
protocol: Some("tcp".to_string()),
|
||||||
|
tls: Some(true),
|
||||||
|
sni: Some(crate::sni_matcher::SniMatcher::new(sni_map).unwrap()),
|
||||||
|
default: Some("default_upstream".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
config.servers.insert("wildcard_test_server".to_string(), server_config);
|
||||||
|
|
||||||
|
// Start proxy server
|
||||||
|
let mut server = Server::new_from_v1_config(config);
|
||||||
|
thread::spawn(move || {
|
||||||
|
let _ = server.run();
|
||||||
|
});
|
||||||
|
sleep(Duration::from_secs(1)); // wait for proxy to start
|
||||||
|
|
||||||
|
// Test 1: Send ClientHello with SNI matching wildcard pattern
|
||||||
|
// Expected: Should route to wildcard_upstream (127.0.0.1:54598)
|
||||||
|
let client_hello = build_tls_client_hello("app.api.example.com");
|
||||||
|
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let _ = conn.write(&client_hello).await.unwrap();
|
||||||
|
let mut response = [0u8; 21];
|
||||||
|
let n = conn.read(&mut response).await.unwrap();
|
||||||
|
assert!(n > 0, "Should receive response from wildcard upstream");
|
||||||
|
assert_eq!(
|
||||||
|
&response[..n],
|
||||||
|
b"tls_wildcard_response",
|
||||||
|
"Should receive expected response from wildcard upstream"
|
||||||
|
);
|
||||||
|
conn.shutdown().await.unwrap();
|
||||||
|
|
||||||
|
// Test 2: Send ClientHello with SNI not matching any pattern
|
||||||
|
// Expected: Should route to default_upstream (127.0.0.1:54597)
|
||||||
|
let client_hello_nomatch = build_tls_client_hello("unrelated.example.com");
|
||||||
|
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let _ = conn.write(&client_hello_nomatch).await.unwrap();
|
||||||
|
let mut response = [0u8; 20];
|
||||||
|
let n = conn.read(&mut response).await.unwrap();
|
||||||
|
assert!(n > 0, "Should receive response from default upstream");
|
||||||
|
assert_eq!(
|
||||||
|
&response[..n],
|
||||||
|
b"tls_default_response",
|
||||||
|
"Should receive expected response from default upstream"
|
||||||
|
);
|
||||||
|
conn.shutdown().await.unwrap();
|
||||||
|
|
||||||
|
// Test 3: Send ClientHello with another SNI matching wildcard pattern
|
||||||
|
let client_hello_match2 = build_tls_client_hello("v2.api.example.com");
|
||||||
|
let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54595")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let _ = conn.write(&client_hello_match2).await.unwrap();
|
||||||
|
let mut response = [0u8; 21];
|
||||||
|
let n = conn.read(&mut response).await.unwrap();
|
||||||
|
assert!(n > 0, "Should receive response from wildcard upstream for second match");
|
||||||
|
assert_eq!(
|
||||||
|
&response[..n],
|
||||||
|
b"tls_wildcard_response",
|
||||||
|
"Should receive expected response from wildcard upstream for second match"
|
||||||
|
);
|
||||||
|
conn.shutdown().await.unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
use crate::servers::Proxy;
|
use crate::servers::Proxy;
|
||||||
use log::{debug, error, trace, warn};
|
use log::{debug, error, trace, warn};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
use std::io; // Import io for ErrorKind
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration; // For potential delays
|
||||||
use tls_parser::{
|
use tls_parser::{
|
||||||
parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage,
|
parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage,
|
||||||
TlsMessageHandshake,
|
TlsMessageHandshake,
|
||||||
};
|
};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::time::timeout; // Use timeout for peek operations
|
||||||
|
|
||||||
fn get_sni(buf: &[u8]) -> Vec<String> {
|
fn get_sni(buf: &[u8]) -> Vec<String> {
|
||||||
let mut snis: Vec<String> = Vec::new();
|
let mut snis: Vec<String> = Vec::new();
|
||||||
@@ -57,6 +60,9 @@ fn get_sni(buf: &[u8]) -> Vec<String> {
|
|||||||
snis
|
snis
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Timeout duration for waiting for TLS Hello data
|
||||||
|
const TLS_PEEK_TIMEOUT: Duration = Duration::from_secs(5); // Adjust as needed
|
||||||
|
|
||||||
pub(crate) async fn determine_upstream_name(
|
pub(crate) async fn determine_upstream_name(
|
||||||
inbound: &TcpStream,
|
inbound: &TcpStream,
|
||||||
proxy: &Arc<Proxy>,
|
proxy: &Arc<Proxy>,
|
||||||
@@ -64,37 +70,162 @@ pub(crate) async fn determine_upstream_name(
|
|||||||
let default_upstream = proxy.default_action.clone();
|
let default_upstream = proxy.default_action.clone();
|
||||||
|
|
||||||
let mut header = [0u8; 9];
|
let mut header = [0u8; 9];
|
||||||
inbound.peek(&mut header).await?;
|
|
||||||
|
|
||||||
let required_bytes = client_hello_buffer_size(&header)?;
|
// --- Step 1: Peek the initial header (9 bytes) with timeout ---
|
||||||
|
match timeout(TLS_PEEK_TIMEOUT, async {
|
||||||
let mut hello_buf = vec![0; required_bytes];
|
loop {
|
||||||
let read_bytes = inbound.peek(&mut hello_buf).await?;
|
match inbound.peek(&mut header).await {
|
||||||
|
Ok(n) if n >= header.len() => return Ok::<usize, io::Error>(n), // Got enough bytes
|
||||||
if read_bytes < required_bytes.into() {
|
Ok(0) => {
|
||||||
error!("Could not read enough bytes to determine SNI");
|
// Connection closed cleanly before sending enough data
|
||||||
return Ok(default_upstream);
|
trace!("Connection closed while peeking for TLS header");
|
||||||
}
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::UnexpectedEof,
|
||||||
let snis = get_sni(&hello_buf);
|
"Connection closed while peeking for TLS header",
|
||||||
if snis.is_empty() {
|
)
|
||||||
return Ok(default_upstream);
|
.into()); // Convert to Box<dyn Error>
|
||||||
} else {
|
}
|
||||||
match proxy.sni.clone() {
|
Ok(_) => {
|
||||||
Some(sni_map) => {
|
// Not enough bytes yet, yield and loop again
|
||||||
let mut upstream = default_upstream;
|
tokio::task::yield_now().await;
|
||||||
for sni in snis {
|
}
|
||||||
let m = sni_map.get(&sni);
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||||
if m.is_some() {
|
// Should not happen with await, but yield defensively
|
||||||
upstream = m.unwrap().clone();
|
tokio::task::yield_now().await;
|
||||||
break;
|
}
|
||||||
}
|
Err(e) => {
|
||||||
|
// Other I/O error
|
||||||
|
warn!("Error peeking for TLS header: {}", e);
|
||||||
|
return Err(e.into()); // Convert to Box<dyn Error>
|
||||||
}
|
}
|
||||||
Ok(upstream)
|
|
||||||
}
|
}
|
||||||
None => return Ok(default_upstream),
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(Ok(_)) => { /* Header peeked successfully */ }
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
// Inner loop returned an error (e.g., EOF, IO error)
|
||||||
|
trace!("Failed to peek header (inner error): {}", e);
|
||||||
|
return Ok(default_upstream); // Fallback on error/EOF
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// Timeout occurred
|
||||||
|
error!("Timeout waiting for TLS header");
|
||||||
|
return Ok(default_upstream); // Fallback on timeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Step 2: Calculate required size ---
|
||||||
|
let required_bytes = match client_hello_buffer_size(&header) {
|
||||||
|
Ok(size) => size,
|
||||||
|
Err(e) => {
|
||||||
|
// Header was invalid or not a ClientHello
|
||||||
|
trace!("Could not determine required buffer size: {}", e);
|
||||||
|
return Ok(default_upstream);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Basic sanity check on size
|
||||||
|
if required_bytes > 16384 + 9 {
|
||||||
|
// TLS max record size + header approx
|
||||||
|
error!(
|
||||||
|
"Calculated required TLS buffer size is too large: {}",
|
||||||
|
required_bytes
|
||||||
|
);
|
||||||
|
return Ok(default_upstream);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Step 3: Peek the full ClientHello with timeout ---
|
||||||
|
let mut hello_buf = vec![0; required_bytes];
|
||||||
|
match timeout(TLS_PEEK_TIMEOUT, async {
|
||||||
|
let mut total_peeked = 0;
|
||||||
|
loop {
|
||||||
|
// Peek into the portion of the buffer that hasn't been filled yet.
|
||||||
|
match inbound.peek(&mut hello_buf[total_peeked..]).await {
|
||||||
|
Ok(0) => {
|
||||||
|
// Connection closed cleanly before sending full ClientHello
|
||||||
|
trace!(
|
||||||
|
"Connection closed while peeking for full ClientHello (peeked {}/{} bytes)",
|
||||||
|
total_peeked,
|
||||||
|
required_bytes
|
||||||
|
);
|
||||||
|
return Err::<usize, io::Error>(
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::UnexpectedEof,
|
||||||
|
"Connection closed while peeking for full ClientHello",
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(n) => {
|
||||||
|
total_peeked += n;
|
||||||
|
if total_peeked >= required_bytes {
|
||||||
|
trace!("Successfully peeked {} bytes for ClientHello", total_peeked);
|
||||||
|
return Ok(total_peeked); // Got enough
|
||||||
|
} else {
|
||||||
|
// Not enough bytes yet, yield and loop again
|
||||||
|
trace!(
|
||||||
|
"Peeked {}/{} bytes for ClientHello, waiting for more...",
|
||||||
|
total_peeked,
|
||||||
|
required_bytes
|
||||||
|
);
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Error peeking for full ClientHello: {}", e);
|
||||||
|
return Err(e.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(Ok(_)) => { /* Full hello peeked successfully */ }
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
error!("Could not peek full ClientHello (inner error): {}", e);
|
||||||
|
return Ok(default_upstream); // Fallback on error/EOF
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
error!(
|
||||||
|
"Timeout waiting for full ClientHello (needed {} bytes)",
|
||||||
|
required_bytes
|
||||||
|
);
|
||||||
|
return Ok(default_upstream); // Fallback on timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Step 4: Parse SNI ---
|
||||||
|
let snis = get_sni(&hello_buf);
|
||||||
|
|
||||||
|
// --- Step 5: Determine upstream based on SNI ---
|
||||||
|
if snis.is_empty() {
|
||||||
|
debug!("No SNI found in ClientHello, using default upstream.");
|
||||||
|
return Ok(default_upstream);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(matcher) = &proxy.sni {
|
||||||
|
for sni in snis {
|
||||||
|
if let Some(upstream) = matcher.match_sni(&sni) {
|
||||||
|
debug!(
|
||||||
|
"Found matching SNI '{}', routing to upstream: {}",
|
||||||
|
sni, upstream
|
||||||
|
);
|
||||||
|
return Ok(upstream);
|
||||||
|
} else {
|
||||||
|
trace!("SNI '{}' not found in matcher.", sni);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debug!("SNI(s) found but none matched configuration, using default upstream.");
|
||||||
|
} else {
|
||||||
|
debug!("SNI found but no SNI matcher configured, using default upstream.");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(default_upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn client_hello_buffer_size(data: &[u8]) -> Result<usize, String> {
|
fn client_hello_buffer_size(data: &[u8]) -> Result<usize, String> {
|
||||||
|
|||||||
391
src/sni_matcher.rs
Normal file
391
src/sni_matcher.rs
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SniPattern {
|
||||||
|
pub pattern: String,
|
||||||
|
pub upstream: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates and matches SNI patterns against incoming SNI values. Supports exact matches and wildcard patterns.
|
||||||
|
///
|
||||||
|
/// Rules:
|
||||||
|
/// - Wildcard patterns must start with "*." and have a valid domain suffix after it (e.g., "*.example.com")
|
||||||
|
/// - Wildcard patterns cannot have a wildcard in the middle (e.g., "*.example.*" is invalid)
|
||||||
|
/// - Wildcard patterns cannot have multiple wildcards (e.g., "*.*.example.com" is invalid)
|
||||||
|
/// - When there are two wildcard patterns that could match the same SNI, the longest suffix wins
|
||||||
|
/// (e.g., "*.example.com" vs "*.api.example.com" - "v2.api.example.com" matches "*.api.example.com")
|
||||||
|
/// - Wildcard patterns that overlap are not matched (e.g. "*.example.com" and "*.bar.example.com" - "bar.example.com" matches neither, "v2.bar.example.com" matches "*.bar.example.com")
|
||||||
|
/// - Wildcard patterns cannot be just "*." or "*" (must have a valid suffix)
|
||||||
|
/// - For known public suffixes (e.g., "com", "org"), the wildcard must have at least one label below the public suffix (e.g., "*.example.com" is valid, but "*.com", or *.co.uk is invalid)
|
||||||
|
/// - For unknown suffixes (e.g., "local", "lan"), the wildcard is allowed without restriction (e.g., "*.local" is valid)
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SniMatcher {
|
||||||
|
pub exact: HashMap<String, String>,
|
||||||
|
pub wildcards: Vec<SniPattern>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SniMatcher {
|
||||||
|
pub fn new(sni_map: HashMap<String, String>) -> Result<Self, Vec<String>> {
|
||||||
|
Self::validate(&sni_map)?;
|
||||||
|
|
||||||
|
let mut exact = HashMap::new();
|
||||||
|
let mut wildcards = Vec::new();
|
||||||
|
|
||||||
|
for (pattern, upstream) in sni_map {
|
||||||
|
if pattern.starts_with("*.") {
|
||||||
|
wildcards.push(SniPattern {
|
||||||
|
pattern: pattern.clone(),
|
||||||
|
upstream,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
exact.insert(pattern, upstream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wildcards.sort_by(|a, b| {
|
||||||
|
let a_suffix = a.pattern.trim_start_matches("*.");
|
||||||
|
let b_suffix = b.pattern.trim_start_matches("*.");
|
||||||
|
b_suffix.len().cmp(&a_suffix.len())
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(SniMatcher { exact, wildcards })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Matches the provided SNI against the patterns in the matcher. Returns Some(upstream) if a match is found, or None if no match is found.
|
||||||
|
pub fn match_sni(&self, sni: &str) -> Option<String> {
|
||||||
|
if let Some(upstream) = self.exact.get(sni) {
|
||||||
|
return Some(upstream.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try each wildcard in order (longest suffix first)
|
||||||
|
for wildcard in &self.wildcards {
|
||||||
|
let suffix = wildcard.pattern.trim_start_matches("*.");
|
||||||
|
let suffix_len = suffix.len();
|
||||||
|
let check = format!(".{}", suffix);
|
||||||
|
|
||||||
|
if !sni.ends_with(&check) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must have at least one label before the suffix to match
|
||||||
|
let prefix = &sni[..sni.len() - suffix_len - 1];
|
||||||
|
|
||||||
|
// Check if a more specific wildcard could also match this SNI
|
||||||
|
let is_owned = {
|
||||||
|
let sni_labels = sni.matches('.').count() + 1;
|
||||||
|
|
||||||
|
self.wildcards.iter().any(|w| {
|
||||||
|
// Skip the current wildcard itself
|
||||||
|
if w.pattern == wildcard.pattern {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let w_suffix = w.pattern.trim_start_matches("*.");
|
||||||
|
let w_len = w_suffix.len();
|
||||||
|
|
||||||
|
// Only care about wildcards with longer suffix (more specific)
|
||||||
|
if w_len <= suffix_len {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let w_suffix_labels = w_suffix.matches('.').count() + 1;
|
||||||
|
|
||||||
|
if sni == w_suffix {
|
||||||
|
// Exact match to more specific suffix - owned if could potentially match
|
||||||
|
// (sni has at least as many labels as the suffix needs)
|
||||||
|
sni_labels >= w_suffix_labels
|
||||||
|
} else if sni.ends_with(&format!(".{}", w_suffix)) {
|
||||||
|
// Ends with more specific suffix - owned if SNI has enough labels
|
||||||
|
sni_labels >= w_suffix_labels + 1
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_owned {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only return if we have a valid prefix (at least one label)
|
||||||
|
if !prefix.is_empty() {
|
||||||
|
return Some(wildcard.upstream.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_wildcard_suffix(pattern: &str) -> Result<(), String> {
|
||||||
|
let suffix = pattern.trim_start_matches("*.");
|
||||||
|
let domain_str = format!("a.{}", suffix);
|
||||||
|
|
||||||
|
if let Some(ps) = psl::suffix(domain_str.as_bytes()) {
|
||||||
|
let ps_str = std::str::from_utf8(ps.as_bytes()).unwrap_or("");
|
||||||
|
if ps_str == suffix {
|
||||||
|
return Err(format!(
|
||||||
|
"Invalid wildcard pattern: {} - wildcard cannot be at the public suffix level",
|
||||||
|
pattern
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate(sni_map: &HashMap<String, String>) -> Result<(), Vec<String>> {
|
||||||
|
let mut errors = Vec::new();
|
||||||
|
|
||||||
|
for (pattern, _upstream) in sni_map {
|
||||||
|
if pattern == "*" {
|
||||||
|
errors.push(format!(
|
||||||
|
"Invalid wildcard pattern: * - just asterisk is not allowed"
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if pattern == "*." {
|
||||||
|
errors.push(format!(
|
||||||
|
"Invalid wildcard pattern: *. - empty suffix after wildcard"
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(rest) = pattern.strip_prefix("*.") {
|
||||||
|
if rest.is_empty() {
|
||||||
|
errors.push(format!(
|
||||||
|
"Invalid wildcard pattern: {pattern} - empty suffix after wildcard"
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if rest.contains('*') {
|
||||||
|
errors.push(format!("Invalid wildcard pattern: {pattern} - wildcard cannot be in the middle of suffix"));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = Self::validate_wildcard_suffix(pattern) {
|
||||||
|
errors.push(e);
|
||||||
|
}
|
||||||
|
} else if pattern.contains('*') {
|
||||||
|
errors.push(format!(
|
||||||
|
"Invalid wildcard pattern: {pattern} - wildcard must be at the start"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.is_empty() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_exact_match() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("example.com".to_string(), "upstream1".to_string());
|
||||||
|
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
|
||||||
|
|
||||||
|
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("example.com"),
|
||||||
|
Some("upstream1".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wildcard_match() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("example.com".to_string(), "upstream1".to_string());
|
||||||
|
sni_map.insert("*.example.com".to_string(), "upstream2".to_string());
|
||||||
|
|
||||||
|
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("www.example.com"),
|
||||||
|
Some("upstream2".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("api.example.com"),
|
||||||
|
Some("upstream2".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_longest_suffix_match() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.example.com".to_string(), "wildcard1".to_string());
|
||||||
|
sni_map.insert("*.api.example.com".to_string(), "wildcard2".to_string());
|
||||||
|
|
||||||
|
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("v2.api.example.com"),
|
||||||
|
Some("wildcard2".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_no_match() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("example.com".to_string(), "upstream1".to_string());
|
||||||
|
|
||||||
|
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matcher.match_sni("other.com"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unknown_suffix_allowed() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert(
|
||||||
|
"*.private.local".to_string(),
|
||||||
|
"private_upstream".to_string(),
|
||||||
|
);
|
||||||
|
sni_map.insert(
|
||||||
|
"*.internal.net".to_string(),
|
||||||
|
"internal_upstream".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("server.private.local"),
|
||||||
|
Some("private_upstream".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("app.internal.net"),
|
||||||
|
Some("internal_upstream".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_public_suffix() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.com".to_string(), "invalid".to_string());
|
||||||
|
|
||||||
|
let result = SniMatcher::new(sni_map);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
let errors = result.unwrap_err();
|
||||||
|
assert!(!errors.is_empty());
|
||||||
|
assert!(errors[0].contains("*.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multiple_errors_collected() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.com".to_string(), "invalid1".to_string());
|
||||||
|
sni_map.insert("*.org".to_string(), "invalid2".to_string());
|
||||||
|
sni_map.insert("*.net".to_string(), "invalid3".to_string());
|
||||||
|
|
||||||
|
let result = SniMatcher::new(sni_map);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
let errors = result.unwrap_err();
|
||||||
|
assert_eq!(errors.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_valid_public_suffix() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.example.com".to_string(), "valid".to_string());
|
||||||
|
|
||||||
|
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("www.example.com"),
|
||||||
|
Some("valid".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_static() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.com".to_string(), "invalid".to_string());
|
||||||
|
sni_map.insert("*.example.com".to_string(), "valid".to_string());
|
||||||
|
|
||||||
|
let result = SniMatcher::new(sni_map);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
let errors = result.unwrap_err();
|
||||||
|
assert_eq!(errors.len(), 1);
|
||||||
|
assert!(errors[0].contains("*.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wildcard_not_at_start_rejected() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("foo*.example.com".to_string(), "invalid".to_string());
|
||||||
|
let result = SniMatcher::new(sni_map);
|
||||||
|
assert!(result.is_err());
|
||||||
|
let errors = result.unwrap_err();
|
||||||
|
assert!(!errors.is_empty());
|
||||||
|
assert!(errors[0].contains("*.example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wildcard_in_middle_rejected() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.example.*".to_string(), "invalid".to_string());
|
||||||
|
let result = SniMatcher::new(sni_map);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_trailing_dot_rejected() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.".to_string(), "invalid".to_string());
|
||||||
|
let result = SniMatcher::new(sni_map);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_just_asterisk_rejected() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*".to_string(), "invalid".to_string());
|
||||||
|
let result = SniMatcher::new(sni_map);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wildcard_requires_subdomain() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.example.com".to_string(), "upstream".to_string());
|
||||||
|
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matcher.match_sni("example.com"), None);
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("www.example.com"),
|
||||||
|
Some("upstream".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("foo.bar.example.com"),
|
||||||
|
Some("upstream".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_longest_suffix_wins_not_shortest() {
|
||||||
|
let mut sni_map = HashMap::new();
|
||||||
|
sni_map.insert("*.example.org".to_string(), "broad".to_string());
|
||||||
|
sni_map.insert("*.bar.example.org".to_string(), "narrow".to_string());
|
||||||
|
let matcher = SniMatcher::new(sni_map).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matcher.match_sni("bar.example.org"), None);
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("v2.bar.example.org"),
|
||||||
|
Some("narrow".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
matcher.match_sni("v2.example.org"),
|
||||||
|
Some("broad".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user