Wait for the entire TLS header to become available, even if it takes

multiple packets.

Closes: #10
This commit is contained in:
Jacob Kiers 2025-01-21 21:22:53 +01:00
parent 4c2711fc81
commit aecffa0d14

View File

@ -1,12 +1,15 @@
use crate::servers::Proxy; use crate::servers::Proxy;
use log::{debug, error, trace, warn}; use log::{debug, error, trace, warn};
use std::error::Error; use std::error::Error;
use std::io; // Import io for ErrorKind
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; // For potential delays
use tls_parser::{ use tls_parser::{
parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage, parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage,
TlsMessageHandshake, TlsMessageHandshake,
}; };
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::timeout; // Use timeout for peek operations
fn get_sni(buf: &[u8]) -> Vec<String> { fn get_sni(buf: &[u8]) -> Vec<String> {
let mut snis: Vec<String> = Vec::new(); let mut snis: Vec<String> = Vec::new();
@ -57,6 +60,9 @@ fn get_sni(buf: &[u8]) -> Vec<String> {
snis snis
} }
// Timeout duration for waiting for TLS Hello data
const TLS_PEEK_TIMEOUT: Duration = Duration::from_secs(5); // Adjust as needed
pub(crate) async fn determine_upstream_name( pub(crate) async fn determine_upstream_name(
inbound: &TcpStream, inbound: &TcpStream,
proxy: &Arc<Proxy>, proxy: &Arc<Proxy>,
@ -64,35 +70,170 @@ pub(crate) async fn determine_upstream_name(
let default_upstream = proxy.default_action.clone(); let default_upstream = proxy.default_action.clone();
let mut header = [0u8; 9]; let mut header = [0u8; 9];
inbound.peek(&mut header).await?;
let required_bytes = client_hello_buffer_size(&header)?; // --- Step 1: Peek the initial header (9 bytes) with timeout ---
match timeout(TLS_PEEK_TIMEOUT, async {
loop {
match inbound.peek(&mut header).await {
Ok(n) if n >= header.len() => return Ok::<usize, io::Error>(n), // Got enough bytes
Ok(0) => {
// Connection closed cleanly before sending enough data
trace!("Connection closed while peeking for TLS header");
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Connection closed while peeking for TLS header",
)
.into()); // Convert to Box<dyn Error>
}
Ok(_) => {
// Not enough bytes yet, yield and loop again
tokio::task::yield_now().await;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// Should not happen with await, but yield defensively
tokio::task::yield_now().await;
}
Err(e) => {
// Other I/O error
warn!("Error peeking for TLS header: {}", e);
return Err(e.into()); // Convert to Box<dyn Error>
}
}
}
})
.await
{
Ok(Ok(_)) => { /* Header peeked successfully */ }
Ok(Err(e)) => {
// Inner loop returned an error (e.g., EOF, IO error)
trace!("Failed to peek header (inner error): {}", e);
return Ok(default_upstream); // Fallback on error/EOF
}
Err(_) => {
// Timeout occurred
error!("Timeout waiting for TLS header");
return Ok(default_upstream); // Fallback on timeout
}
}
let mut hello_buf = vec![0; required_bytes]; // --- Step 2: Calculate required size ---
let read_bytes = inbound.peek(&mut hello_buf).await?; let required_bytes = match client_hello_buffer_size(&header) {
Ok(size) => size,
Err(e) => {
// Header was invalid or not a ClientHello
trace!("Could not determine required buffer size: {}", e);
return Ok(default_upstream);
}
};
if read_bytes < required_bytes.into() { // Basic sanity check on size
error!("Could not read enough bytes to determine SNI"); if required_bytes > 16384 + 9 {
// TLS max record size + header approx
error!(
"Calculated required TLS buffer size is too large: {}",
required_bytes
);
return Ok(default_upstream); return Ok(default_upstream);
} }
// --- Step 3: Peek the full ClientHello with timeout ---
let mut hello_buf = vec![0; required_bytes];
match timeout(TLS_PEEK_TIMEOUT, async {
let mut total_peeked = 0;
loop {
// Peek into the portion of the buffer that hasn't been filled yet.
match inbound.peek(&mut hello_buf[total_peeked..]).await {
Ok(0) => {
// Connection closed cleanly before sending full ClientHello
trace!(
"Connection closed while peeking for full ClientHello (peeked {}/{} bytes)",
total_peeked,
required_bytes
);
return Err::<usize, io::Error>(
io::Error::new(
io::ErrorKind::UnexpectedEof,
"Connection closed while peeking for full ClientHello",
)
.into(),
);
}
Ok(n) => {
total_peeked += n;
if total_peeked >= required_bytes {
trace!("Successfully peeked {} bytes for ClientHello", total_peeked);
return Ok(total_peeked); // Got enough
} else {
// Not enough bytes yet, yield and loop again
trace!(
"Peeked {}/{} bytes for ClientHello, waiting for more...",
total_peeked,
required_bytes
);
tokio::task::yield_now().await;
}
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
tokio::task::yield_now().await;
}
Err(e) => {
warn!("Error peeking for full ClientHello: {}", e);
return Err(e.into());
}
}
}
})
.await
{
Ok(Ok(_)) => { /* Full hello peeked successfully */ }
Ok(Err(e)) => {
error!("Could not peek full ClientHello (inner error): {}", e);
return Ok(default_upstream); // Fallback on error/EOF
}
Err(_) => {
error!(
"Timeout waiting for full ClientHello (needed {} bytes)",
required_bytes
);
return Ok(default_upstream); // Fallback on timeout
}
}
// --- Step 4: Parse SNI ---
let snis = get_sni(&hello_buf); let snis = get_sni(&hello_buf);
// --- Step 5: Determine upstream based on SNI ---
if snis.is_empty() { if snis.is_empty() {
debug!("No SNI found in ClientHello, using default upstream.");
return Ok(default_upstream); return Ok(default_upstream);
} else { } else {
match proxy.sni.clone() { match proxy.sni.clone() {
Some(sni_map) => { Some(sni_map) => {
let mut upstream = default_upstream; let mut upstream = default_upstream.clone(); // Clone here for default case
let mut found_match = false;
for sni in snis { for sni in snis {
let m = sni_map.get(&sni); // snis is already Vec<String>
if m.is_some() { if let Some(target_upstream) = sni_map.get(&sni) {
upstream = m.unwrap().clone(); debug!(
"Found matching SNI '{}', routing to upstream: {}",
sni, target_upstream
);
upstream = target_upstream.clone();
found_match = true;
break; break;
} else {
trace!("SNI '{}' not found in map.", sni);
} }
} }
if !found_match {
debug!("SNI(s) found but none matched configuration, using default upstream.");
}
Ok(upstream) Ok(upstream)
} }
None => return Ok(default_upstream), None => {
debug!("SNI found but no SNI map configured, using default upstream.");
Ok(default_upstream)
}
} }
} }
} }