Wait for the entire TLS header to become available, even if it takes
multiple packets. Closes: #10
This commit is contained in:
parent
4c2711fc81
commit
aecffa0d14
@ -1,12 +1,15 @@
|
|||||||
use crate::servers::Proxy;
|
use crate::servers::Proxy;
|
||||||
use log::{debug, error, trace, warn};
|
use log::{debug, error, trace, warn};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
use std::io; // Import io for ErrorKind
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration; // For potential delays
|
||||||
use tls_parser::{
|
use tls_parser::{
|
||||||
parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage,
|
parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage,
|
||||||
TlsMessageHandshake,
|
TlsMessageHandshake,
|
||||||
};
|
};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::time::timeout; // Use timeout for peek operations
|
||||||
|
|
||||||
fn get_sni(buf: &[u8]) -> Vec<String> {
|
fn get_sni(buf: &[u8]) -> Vec<String> {
|
||||||
let mut snis: Vec<String> = Vec::new();
|
let mut snis: Vec<String> = Vec::new();
|
||||||
@ -57,6 +60,9 @@ fn get_sni(buf: &[u8]) -> Vec<String> {
|
|||||||
snis
|
snis
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Timeout duration for waiting for TLS Hello data
|
||||||
|
const TLS_PEEK_TIMEOUT: Duration = Duration::from_secs(5); // Adjust as needed
|
||||||
|
|
||||||
pub(crate) async fn determine_upstream_name(
|
pub(crate) async fn determine_upstream_name(
|
||||||
inbound: &TcpStream,
|
inbound: &TcpStream,
|
||||||
proxy: &Arc<Proxy>,
|
proxy: &Arc<Proxy>,
|
||||||
@ -64,35 +70,170 @@ pub(crate) async fn determine_upstream_name(
|
|||||||
let default_upstream = proxy.default_action.clone();
|
let default_upstream = proxy.default_action.clone();
|
||||||
|
|
||||||
let mut header = [0u8; 9];
|
let mut header = [0u8; 9];
|
||||||
inbound.peek(&mut header).await?;
|
|
||||||
|
|
||||||
let required_bytes = client_hello_buffer_size(&header)?;
|
// --- Step 1: Peek the initial header (9 bytes) with timeout ---
|
||||||
|
match timeout(TLS_PEEK_TIMEOUT, async {
|
||||||
|
loop {
|
||||||
|
match inbound.peek(&mut header).await {
|
||||||
|
Ok(n) if n >= header.len() => return Ok::<usize, io::Error>(n), // Got enough bytes
|
||||||
|
Ok(0) => {
|
||||||
|
// Connection closed cleanly before sending enough data
|
||||||
|
trace!("Connection closed while peeking for TLS header");
|
||||||
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::UnexpectedEof,
|
||||||
|
"Connection closed while peeking for TLS header",
|
||||||
|
)
|
||||||
|
.into()); // Convert to Box<dyn Error>
|
||||||
|
}
|
||||||
|
Ok(_) => {
|
||||||
|
// Not enough bytes yet, yield and loop again
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||||
|
// Should not happen with await, but yield defensively
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Other I/O error
|
||||||
|
warn!("Error peeking for TLS header: {}", e);
|
||||||
|
return Err(e.into()); // Convert to Box<dyn Error>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(Ok(_)) => { /* Header peeked successfully */ }
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
// Inner loop returned an error (e.g., EOF, IO error)
|
||||||
|
trace!("Failed to peek header (inner error): {}", e);
|
||||||
|
return Ok(default_upstream); // Fallback on error/EOF
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// Timeout occurred
|
||||||
|
error!("Timeout waiting for TLS header");
|
||||||
|
return Ok(default_upstream); // Fallback on timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut hello_buf = vec![0; required_bytes];
|
// --- Step 2: Calculate required size ---
|
||||||
let read_bytes = inbound.peek(&mut hello_buf).await?;
|
let required_bytes = match client_hello_buffer_size(&header) {
|
||||||
|
Ok(size) => size,
|
||||||
|
Err(e) => {
|
||||||
|
// Header was invalid or not a ClientHello
|
||||||
|
trace!("Could not determine required buffer size: {}", e);
|
||||||
|
return Ok(default_upstream);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if read_bytes < required_bytes.into() {
|
// Basic sanity check on size
|
||||||
error!("Could not read enough bytes to determine SNI");
|
if required_bytes > 16384 + 9 {
|
||||||
|
// TLS max record size + header approx
|
||||||
|
error!(
|
||||||
|
"Calculated required TLS buffer size is too large: {}",
|
||||||
|
required_bytes
|
||||||
|
);
|
||||||
return Ok(default_upstream);
|
return Ok(default_upstream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Step 3: Peek the full ClientHello with timeout ---
|
||||||
|
let mut hello_buf = vec![0; required_bytes];
|
||||||
|
match timeout(TLS_PEEK_TIMEOUT, async {
|
||||||
|
let mut total_peeked = 0;
|
||||||
|
loop {
|
||||||
|
// Peek into the portion of the buffer that hasn't been filled yet.
|
||||||
|
match inbound.peek(&mut hello_buf[total_peeked..]).await {
|
||||||
|
Ok(0) => {
|
||||||
|
// Connection closed cleanly before sending full ClientHello
|
||||||
|
trace!(
|
||||||
|
"Connection closed while peeking for full ClientHello (peeked {}/{} bytes)",
|
||||||
|
total_peeked,
|
||||||
|
required_bytes
|
||||||
|
);
|
||||||
|
return Err::<usize, io::Error>(
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::UnexpectedEof,
|
||||||
|
"Connection closed while peeking for full ClientHello",
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(n) => {
|
||||||
|
total_peeked += n;
|
||||||
|
if total_peeked >= required_bytes {
|
||||||
|
trace!("Successfully peeked {} bytes for ClientHello", total_peeked);
|
||||||
|
return Ok(total_peeked); // Got enough
|
||||||
|
} else {
|
||||||
|
// Not enough bytes yet, yield and loop again
|
||||||
|
trace!(
|
||||||
|
"Peeked {}/{} bytes for ClientHello, waiting for more...",
|
||||||
|
total_peeked,
|
||||||
|
required_bytes
|
||||||
|
);
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Error peeking for full ClientHello: {}", e);
|
||||||
|
return Err(e.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(Ok(_)) => { /* Full hello peeked successfully */ }
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
error!("Could not peek full ClientHello (inner error): {}", e);
|
||||||
|
return Ok(default_upstream); // Fallback on error/EOF
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
error!(
|
||||||
|
"Timeout waiting for full ClientHello (needed {} bytes)",
|
||||||
|
required_bytes
|
||||||
|
);
|
||||||
|
return Ok(default_upstream); // Fallback on timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Step 4: Parse SNI ---
|
||||||
let snis = get_sni(&hello_buf);
|
let snis = get_sni(&hello_buf);
|
||||||
|
|
||||||
|
// --- Step 5: Determine upstream based on SNI ---
|
||||||
if snis.is_empty() {
|
if snis.is_empty() {
|
||||||
|
debug!("No SNI found in ClientHello, using default upstream.");
|
||||||
return Ok(default_upstream);
|
return Ok(default_upstream);
|
||||||
} else {
|
} else {
|
||||||
match proxy.sni.clone() {
|
match proxy.sni.clone() {
|
||||||
Some(sni_map) => {
|
Some(sni_map) => {
|
||||||
let mut upstream = default_upstream;
|
let mut upstream = default_upstream.clone(); // Clone here for default case
|
||||||
|
let mut found_match = false;
|
||||||
for sni in snis {
|
for sni in snis {
|
||||||
let m = sni_map.get(&sni);
|
// snis is already Vec<String>
|
||||||
if m.is_some() {
|
if let Some(target_upstream) = sni_map.get(&sni) {
|
||||||
upstream = m.unwrap().clone();
|
debug!(
|
||||||
|
"Found matching SNI '{}', routing to upstream: {}",
|
||||||
|
sni, target_upstream
|
||||||
|
);
|
||||||
|
upstream = target_upstream.clone();
|
||||||
|
found_match = true;
|
||||||
break;
|
break;
|
||||||
|
} else {
|
||||||
|
trace!("SNI '{}' not found in map.", sni);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !found_match {
|
||||||
|
debug!("SNI(s) found but none matched configuration, using default upstream.");
|
||||||
|
}
|
||||||
Ok(upstream)
|
Ok(upstream)
|
||||||
}
|
}
|
||||||
None => return Ok(default_upstream),
|
None => {
|
||||||
|
debug!("SNI found but no SNI map configured, using default upstream.");
|
||||||
|
Ok(default_upstream)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user