Wait for the entire TLS header to become available, even if it takes
multiple packets. Closes: #10
This commit is contained in:
parent
4c2711fc81
commit
aecffa0d14
@@ -1,12 +1,15 @@
|
||||
use crate::servers::Proxy;
|
||||
use log::{debug, error, trace, warn};
|
||||
use std::error::Error;
|
||||
use std::io; // Import io for ErrorKind
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration; // For potential delays
|
||||
use tls_parser::{
|
||||
parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage,
|
||||
TlsMessageHandshake,
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::timeout; // Use timeout for peek operations
|
||||
|
||||
fn get_sni(buf: &[u8]) -> Vec<String> {
|
||||
let mut snis: Vec<String> = Vec::new();
|
||||
@@ -57,6 +60,9 @@ fn get_sni(buf: &[u8]) -> Vec<String> {
|
||||
snis
|
||||
}
|
||||
|
||||
// Timeout duration for waiting for TLS Hello data
|
||||
const TLS_PEEK_TIMEOUT: Duration = Duration::from_secs(5); // Adjust as needed
|
||||
|
||||
pub(crate) async fn determine_upstream_name(
|
||||
inbound: &TcpStream,
|
||||
proxy: &Arc<Proxy>,
|
||||
@@ -64,35 +70,170 @@ pub(crate) async fn determine_upstream_name(
|
||||
let default_upstream = proxy.default_action.clone();
|
||||
|
||||
let mut header = [0u8; 9];
|
||||
inbound.peek(&mut header).await?;
|
||||
|
||||
let required_bytes = client_hello_buffer_size(&header)?;
|
||||
// --- Step 1: Peek the initial header (9 bytes) with timeout ---
|
||||
match timeout(TLS_PEEK_TIMEOUT, async {
|
||||
loop {
|
||||
match inbound.peek(&mut header).await {
|
||||
Ok(n) if n >= header.len() => return Ok::<usize, io::Error>(n), // Got enough bytes
|
||||
Ok(0) => {
|
||||
// Connection closed cleanly before sending enough data
|
||||
trace!("Connection closed while peeking for TLS header");
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"Connection closed while peeking for TLS header",
|
||||
)
|
||||
.into()); // Convert to Box<dyn Error>
|
||||
}
|
||||
Ok(_) => {
|
||||
// Not enough bytes yet, yield and loop again
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
// Should not happen with await, but yield defensively
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
Err(e) => {
|
||||
// Other I/O error
|
||||
warn!("Error peeking for TLS header: {}", e);
|
||||
return Err(e.into()); // Convert to Box<dyn Error>
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(Ok(_)) => { /* Header peeked successfully */ }
|
||||
Ok(Err(e)) => {
|
||||
// Inner loop returned an error (e.g., EOF, IO error)
|
||||
trace!("Failed to peek header (inner error): {}", e);
|
||||
return Ok(default_upstream); // Fallback on error/EOF
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout occurred
|
||||
error!("Timeout waiting for TLS header");
|
||||
return Ok(default_upstream); // Fallback on timeout
|
||||
}
|
||||
}
|
||||
|
||||
let mut hello_buf = vec![0; required_bytes];
|
||||
let read_bytes = inbound.peek(&mut hello_buf).await?;
|
||||
// --- Step 2: Calculate required size ---
|
||||
let required_bytes = match client_hello_buffer_size(&header) {
|
||||
Ok(size) => size,
|
||||
Err(e) => {
|
||||
// Header was invalid or not a ClientHello
|
||||
trace!("Could not determine required buffer size: {}", e);
|
||||
return Ok(default_upstream);
|
||||
}
|
||||
};
|
||||
|
||||
if read_bytes < required_bytes.into() {
|
||||
error!("Could not read enough bytes to determine SNI");
|
||||
// Basic sanity check on size
|
||||
if required_bytes > 16384 + 9 {
|
||||
// TLS max record size + header approx
|
||||
error!(
|
||||
"Calculated required TLS buffer size is too large: {}",
|
||||
required_bytes
|
||||
);
|
||||
return Ok(default_upstream);
|
||||
}
|
||||
|
||||
// --- Step 3: Peek the full ClientHello with timeout ---
|
||||
let mut hello_buf = vec![0; required_bytes];
|
||||
match timeout(TLS_PEEK_TIMEOUT, async {
|
||||
let mut total_peeked = 0;
|
||||
loop {
|
||||
// Peek into the portion of the buffer that hasn't been filled yet.
|
||||
match inbound.peek(&mut hello_buf[total_peeked..]).await {
|
||||
Ok(0) => {
|
||||
// Connection closed cleanly before sending full ClientHello
|
||||
trace!(
|
||||
"Connection closed while peeking for full ClientHello (peeked {}/{} bytes)",
|
||||
total_peeked,
|
||||
required_bytes
|
||||
);
|
||||
return Err::<usize, io::Error>(
|
||||
io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"Connection closed while peeking for full ClientHello",
|
||||
)
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
Ok(n) => {
|
||||
total_peeked += n;
|
||||
if total_peeked >= required_bytes {
|
||||
trace!("Successfully peeked {} bytes for ClientHello", total_peeked);
|
||||
return Ok(total_peeked); // Got enough
|
||||
} else {
|
||||
// Not enough bytes yet, yield and loop again
|
||||
trace!(
|
||||
"Peeked {}/{} bytes for ClientHello, waiting for more...",
|
||||
total_peeked,
|
||||
required_bytes
|
||||
);
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error peeking for full ClientHello: {}", e);
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(Ok(_)) => { /* Full hello peeked successfully */ }
|
||||
Ok(Err(e)) => {
|
||||
error!("Could not peek full ClientHello (inner error): {}", e);
|
||||
return Ok(default_upstream); // Fallback on error/EOF
|
||||
}
|
||||
Err(_) => {
|
||||
error!(
|
||||
"Timeout waiting for full ClientHello (needed {} bytes)",
|
||||
required_bytes
|
||||
);
|
||||
return Ok(default_upstream); // Fallback on timeout
|
||||
}
|
||||
}
|
||||
|
||||
// --- Step 4: Parse SNI ---
|
||||
let snis = get_sni(&hello_buf);
|
||||
|
||||
// --- Step 5: Determine upstream based on SNI ---
|
||||
if snis.is_empty() {
|
||||
debug!("No SNI found in ClientHello, using default upstream.");
|
||||
return Ok(default_upstream);
|
||||
} else {
|
||||
match proxy.sni.clone() {
|
||||
Some(sni_map) => {
|
||||
let mut upstream = default_upstream;
|
||||
let mut upstream = default_upstream.clone(); // Clone here for default case
|
||||
let mut found_match = false;
|
||||
for sni in snis {
|
||||
let m = sni_map.get(&sni);
|
||||
if m.is_some() {
|
||||
upstream = m.unwrap().clone();
|
||||
// snis is already Vec<String>
|
||||
if let Some(target_upstream) = sni_map.get(&sni) {
|
||||
debug!(
|
||||
"Found matching SNI '{}', routing to upstream: {}",
|
||||
sni, target_upstream
|
||||
);
|
||||
upstream = target_upstream.clone();
|
||||
found_match = true;
|
||||
break;
|
||||
} else {
|
||||
trace!("SNI '{}' not found in map.", sni);
|
||||
}
|
||||
}
|
||||
if !found_match {
|
||||
debug!("SNI(s) found but none matched configuration, using default upstream.");
|
||||
}
|
||||
Ok(upstream)
|
||||
}
|
||||
None => return Ok(default_upstream),
|
||||
None => {
|
||||
debug!("SNI found but no SNI map configured, using default upstream.");
|
||||
Ok(default_upstream)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user