From aecffa0d14831376362526babf8817dd537b5961 Mon Sep 17 00:00:00 2001 From: Jacob Kiers Date: Tue, 21 Jan 2025 21:22:53 +0100 Subject: [PATCH] Wait for the entire TLS header to become available, even if it takes multiple packets. Closes: #10 --- src/servers/protocol/tls.rs | 163 +++++++++++++++++++++++++++++++++--- 1 file changed, 152 insertions(+), 11 deletions(-) diff --git a/src/servers/protocol/tls.rs b/src/servers/protocol/tls.rs index a061115..20beef2 100644 --- a/src/servers/protocol/tls.rs +++ b/src/servers/protocol/tls.rs @@ -1,12 +1,15 @@ use crate::servers::Proxy; use log::{debug, error, trace, warn}; use std::error::Error; +use std::io; // Import io for ErrorKind use std::sync::Arc; +use std::time::Duration; // For potential delays use tls_parser::{ parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage, TlsMessageHandshake, }; use tokio::net::TcpStream; +use tokio::time::timeout; // Use timeout for peek operations fn get_sni(buf: &[u8]) -> Vec { let mut snis: Vec = Vec::new(); @@ -57,6 +60,9 @@ fn get_sni(buf: &[u8]) -> Vec { snis } +// Timeout duration for waiting for TLS Hello data +const TLS_PEEK_TIMEOUT: Duration = Duration::from_secs(5); // Adjust as needed + pub(crate) async fn determine_upstream_name( inbound: &TcpStream, proxy: &Arc, @@ -64,35 +70,170 @@ pub(crate) async fn determine_upstream_name( let default_upstream = proxy.default_action.clone(); let mut header = [0u8; 9]; - inbound.peek(&mut header).await?; - let required_bytes = client_hello_buffer_size(&header)?; + // --- Step 1: Peek the initial header (9 bytes) with timeout --- + match timeout(TLS_PEEK_TIMEOUT, async { + loop { + match inbound.peek(&mut header).await { + Ok(n) if n >= header.len() => return Ok::(n), // Got enough bytes + Ok(0) => { + // Connection closed cleanly before sending enough data + trace!("Connection closed while peeking for TLS header"); + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "Connection closed while peeking for TLS header", + ) + .into()); // Convert to Box + } + Ok(_) => { + // Not enough bytes yet, yield and loop again + tokio::task::yield_now().await; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // Should not happen with await, but yield defensively + tokio::task::yield_now().await; + } + Err(e) => { + // Other I/O error + warn!("Error peeking for TLS header: {}", e); + return Err(e.into()); // Convert to Box + } + } + } + }) + .await + { + Ok(Ok(_)) => { /* Header peeked successfully */ } + Ok(Err(e)) => { + // Inner loop returned an error (e.g., EOF, IO error) + trace!("Failed to peek header (inner error): {}", e); + return Ok(default_upstream); // Fallback on error/EOF + } + Err(_) => { + // Timeout occurred + error!("Timeout waiting for TLS header"); + return Ok(default_upstream); // Fallback on timeout + } + } - let mut hello_buf = vec![0; required_bytes]; - let read_bytes = inbound.peek(&mut hello_buf).await?; + // --- Step 2: Calculate required size --- + let required_bytes = match client_hello_buffer_size(&header) { + Ok(size) => size, + Err(e) => { + // Header was invalid or not a ClientHello + trace!("Could not determine required buffer size: {}", e); + return Ok(default_upstream); + } + }; - if read_bytes < required_bytes.into() { - error!("Could not read enough bytes to determine SNI"); + // Basic sanity check on size + if required_bytes > 16384 + 9 { + // TLS max record size + header approx + error!( + "Calculated required TLS buffer size is too large: {}", + required_bytes + ); return Ok(default_upstream); } + // --- Step 3: Peek the full ClientHello with timeout --- + let mut hello_buf = vec![0; required_bytes]; + match timeout(TLS_PEEK_TIMEOUT, async { + let mut total_peeked = 0; + loop { + // Peek into the portion of the buffer that hasn't been filled yet. + match inbound.peek(&mut hello_buf[total_peeked..]).await { + Ok(0) => { + // Connection closed cleanly before sending full ClientHello + trace!( + "Connection closed while peeking for full ClientHello (peeked {}/{} bytes)", + total_peeked, + required_bytes + ); + return Err::( + io::Error::new( + io::ErrorKind::UnexpectedEof, + "Connection closed while peeking for full ClientHello", + ) + .into(), + ); + } + Ok(n) => { + total_peeked += n; + if total_peeked >= required_bytes { + trace!("Successfully peeked {} bytes for ClientHello", total_peeked); + return Ok(total_peeked); // Got enough + } else { + // Not enough bytes yet, yield and loop again + trace!( + "Peeked {}/{} bytes for ClientHello, waiting for more...", + total_peeked, + required_bytes + ); + tokio::task::yield_now().await; + } + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + tokio::task::yield_now().await; + } + Err(e) => { + warn!("Error peeking for full ClientHello: {}", e); + return Err(e.into()); + } + } + } + }) + .await + { + Ok(Ok(_)) => { /* Full hello peeked successfully */ } + Ok(Err(e)) => { + error!("Could not peek full ClientHello (inner error): {}", e); + return Ok(default_upstream); // Fallback on error/EOF + } + Err(_) => { + error!( + "Timeout waiting for full ClientHello (needed {} bytes)", + required_bytes + ); + return Ok(default_upstream); // Fallback on timeout + } + } + + // --- Step 4: Parse SNI --- let snis = get_sni(&hello_buf); + + // --- Step 5: Determine upstream based on SNI --- if snis.is_empty() { + debug!("No SNI found in ClientHello, using default upstream."); return Ok(default_upstream); } else { match proxy.sni.clone() { Some(sni_map) => { - let mut upstream = default_upstream; + let mut upstream = default_upstream.clone(); // Clone here for default case + let mut found_match = false; for sni in snis { - let m = sni_map.get(&sni); - if m.is_some() { - upstream = m.unwrap().clone(); + // snis is already Vec + if let Some(target_upstream) = sni_map.get(&sni) { + debug!( + "Found matching SNI '{}', routing to upstream: {}", + sni, target_upstream + ); + upstream = target_upstream.clone(); + found_match = true; break; + } else { + trace!("SNI '{}' not found in map.", sni); } } + if !found_match { + debug!("SNI(s) found but none matched configuration, using default upstream."); + } Ok(upstream) } - None => return Ok(default_upstream), + None => { + debug!("SNI found but no SNI map configured, using default upstream."); + Ok(default_upstream) + } } } }