Wait for the entire TLS header to become available, even if it takes
multiple packets. Closes: #10
This commit is contained in:
		| @@ -1,12 +1,15 @@ | ||||
| use crate::servers::Proxy; | ||||
| use log::{debug, error, trace, warn}; | ||||
| use std::error::Error; | ||||
| use std::io; // Import io for ErrorKind | ||||
| use std::sync::Arc; | ||||
| use std::time::Duration; // For potential delays | ||||
| use tls_parser::{ | ||||
|     parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage, | ||||
|     TlsMessageHandshake, | ||||
| }; | ||||
| use tokio::net::TcpStream; | ||||
| use tokio::time::timeout; // Use timeout for peek operations | ||||
|  | ||||
| fn get_sni(buf: &[u8]) -> Vec<String> { | ||||
|     let mut snis: Vec<String> = Vec::new(); | ||||
| @@ -57,6 +60,9 @@ fn get_sni(buf: &[u8]) -> Vec<String> { | ||||
|     snis | ||||
| } | ||||
|  | ||||
| // Timeout duration for waiting for TLS Hello data | ||||
| const TLS_PEEK_TIMEOUT: Duration = Duration::from_secs(5); // Adjust as needed | ||||
|  | ||||
| pub(crate) async fn determine_upstream_name( | ||||
|     inbound: &TcpStream, | ||||
|     proxy: &Arc<Proxy>, | ||||
| @@ -64,35 +70,170 @@ pub(crate) async fn determine_upstream_name( | ||||
|     let default_upstream = proxy.default_action.clone(); | ||||
|  | ||||
|     let mut header = [0u8; 9]; | ||||
|     inbound.peek(&mut header).await?; | ||||
|  | ||||
|     let required_bytes = client_hello_buffer_size(&header)?; | ||||
|     // --- Step 1: Peek the initial header (9 bytes) with timeout --- | ||||
|     match timeout(TLS_PEEK_TIMEOUT, async { | ||||
|         loop { | ||||
|             match inbound.peek(&mut header).await { | ||||
|                 Ok(n) if n >= header.len() => return Ok::<usize, io::Error>(n), // Got enough bytes | ||||
|                 Ok(0) => { | ||||
|                     // Connection closed cleanly before sending enough data | ||||
|                     trace!("Connection closed while peeking for TLS header"); | ||||
|                     return Err(io::Error::new( | ||||
|                         io::ErrorKind::UnexpectedEof, | ||||
|                         "Connection closed while peeking for TLS header", | ||||
|                     ) | ||||
|                     .into()); // Convert to Box<dyn Error> | ||||
|                 } | ||||
|                 Ok(_) => { | ||||
|                     // Not enough bytes yet, yield and loop again | ||||
|                     tokio::task::yield_now().await; | ||||
|                 } | ||||
|                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { | ||||
|                     // Should not happen with await, but yield defensively | ||||
|                     tokio::task::yield_now().await; | ||||
|                 } | ||||
|                 Err(e) => { | ||||
|                     // Other I/O error | ||||
|                     warn!("Error peeking for TLS header: {}", e); | ||||
|                     return Err(e.into()); // Convert to Box<dyn Error> | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     }) | ||||
|     .await | ||||
|     { | ||||
|         Ok(Ok(_)) => { /* Header peeked successfully */ } | ||||
|         Ok(Err(e)) => { | ||||
|             // Inner loop returned an error (e.g., EOF, IO error) | ||||
|             trace!("Failed to peek header (inner error): {}", e); | ||||
|             return Ok(default_upstream); // Fallback on error/EOF | ||||
|         } | ||||
|         Err(_) => { | ||||
|             // Timeout occurred | ||||
|             error!("Timeout waiting for TLS header"); | ||||
|             return Ok(default_upstream); // Fallback on timeout | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     let mut hello_buf = vec![0; required_bytes]; | ||||
|     let read_bytes = inbound.peek(&mut hello_buf).await?; | ||||
|     // --- Step 2: Calculate required size --- | ||||
|     let required_bytes = match client_hello_buffer_size(&header) { | ||||
|         Ok(size) => size, | ||||
|         Err(e) => { | ||||
|             // Header was invalid or not a ClientHello | ||||
|             trace!("Could not determine required buffer size: {}", e); | ||||
|             return Ok(default_upstream); | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     if read_bytes < required_bytes.into() { | ||||
|         error!("Could not read enough bytes to determine SNI"); | ||||
|     // Basic sanity check on size | ||||
|     if required_bytes > 16384 + 9 { | ||||
|         // TLS max record size + header approx | ||||
|         error!( | ||||
|             "Calculated required TLS buffer size is too large: {}", | ||||
|             required_bytes | ||||
|         ); | ||||
|         return Ok(default_upstream); | ||||
|     } | ||||
|  | ||||
|     // --- Step 3: Peek the full ClientHello with timeout --- | ||||
|     let mut hello_buf = vec![0; required_bytes]; | ||||
|     match timeout(TLS_PEEK_TIMEOUT, async { | ||||
|         let mut total_peeked = 0; | ||||
|         loop { | ||||
|             // Peek into the portion of the buffer that hasn't been filled yet. | ||||
|             match inbound.peek(&mut hello_buf[total_peeked..]).await { | ||||
|                 Ok(0) => { | ||||
|                     // Connection closed cleanly before sending full ClientHello | ||||
|                     trace!( | ||||
|                         "Connection closed while peeking for full ClientHello (peeked {}/{} bytes)", | ||||
|                         total_peeked, | ||||
|                         required_bytes | ||||
|                     ); | ||||
|                     return Err::<usize, io::Error>( | ||||
|                         io::Error::new( | ||||
|                             io::ErrorKind::UnexpectedEof, | ||||
|                             "Connection closed while peeking for full ClientHello", | ||||
|                         ) | ||||
|                         .into(), | ||||
|                     ); | ||||
|                 } | ||||
|                 Ok(n) => { | ||||
|                     total_peeked += n; | ||||
|                     if total_peeked >= required_bytes { | ||||
|                         trace!("Successfully peeked {} bytes for ClientHello", total_peeked); | ||||
|                         return Ok(total_peeked); // Got enough | ||||
|                     } else { | ||||
|                         // Not enough bytes yet, yield and loop again | ||||
|                         trace!( | ||||
|                             "Peeked {}/{} bytes for ClientHello, waiting for more...", | ||||
|                             total_peeked, | ||||
|                             required_bytes | ||||
|                         ); | ||||
|                         tokio::task::yield_now().await; | ||||
|                     } | ||||
|                 } | ||||
|                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { | ||||
|                     tokio::task::yield_now().await; | ||||
|                 } | ||||
|                 Err(e) => { | ||||
|                     warn!("Error peeking for full ClientHello: {}", e); | ||||
|                     return Err(e.into()); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     }) | ||||
|     .await | ||||
|     { | ||||
|         Ok(Ok(_)) => { /* Full hello peeked successfully */ } | ||||
|         Ok(Err(e)) => { | ||||
|             error!("Could not peek full ClientHello (inner error): {}", e); | ||||
|             return Ok(default_upstream); // Fallback on error/EOF | ||||
|         } | ||||
|         Err(_) => { | ||||
|             error!( | ||||
|                 "Timeout waiting for full ClientHello (needed {} bytes)", | ||||
|                 required_bytes | ||||
|             ); | ||||
|             return Ok(default_upstream); // Fallback on timeout | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // --- Step 4: Parse SNI --- | ||||
|     let snis = get_sni(&hello_buf); | ||||
|  | ||||
|     // --- Step 5: Determine upstream based on SNI --- | ||||
|     if snis.is_empty() { | ||||
|         debug!("No SNI found in ClientHello, using default upstream."); | ||||
|         return Ok(default_upstream); | ||||
|     } else { | ||||
|         match proxy.sni.clone() { | ||||
|             Some(sni_map) => { | ||||
|                 let mut upstream = default_upstream; | ||||
|                 let mut upstream = default_upstream.clone(); // Clone here for default case | ||||
|                 let mut found_match = false; | ||||
|                 for sni in snis { | ||||
|                     let m = sni_map.get(&sni); | ||||
|                     if m.is_some() { | ||||
|                         upstream = m.unwrap().clone(); | ||||
|                     // snis is already Vec<String> | ||||
|                     if let Some(target_upstream) = sni_map.get(&sni) { | ||||
|                         debug!( | ||||
|                             "Found matching SNI '{}', routing to upstream: {}", | ||||
|                             sni, target_upstream | ||||
|                         ); | ||||
|                         upstream = target_upstream.clone(); | ||||
|                         found_match = true; | ||||
|                         break; | ||||
|                     } else { | ||||
|                         trace!("SNI '{}' not found in map.", sni); | ||||
|                     } | ||||
|                 } | ||||
|                 if !found_match { | ||||
|                     debug!("SNI(s) found but none matched configuration, using default upstream."); | ||||
|                 } | ||||
|                 Ok(upstream) | ||||
|             } | ||||
|             None => return Ok(default_upstream), | ||||
|             None => { | ||||
|                 debug!("SNI found but no SNI map configured, using default upstream."); | ||||
|                 Ok(default_upstream) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user