Wait for the entire TLS header to become available, even if it takes
multiple packets. Closes: #10
This commit is contained in:
		| @@ -1,12 +1,15 @@ | |||||||
| use crate::servers::Proxy; | use crate::servers::Proxy; | ||||||
| use log::{debug, error, trace, warn}; | use log::{debug, error, trace, warn}; | ||||||
| use std::error::Error; | use std::error::Error; | ||||||
|  | use std::io; // Import io for ErrorKind | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
|  | use std::time::Duration; // For potential delays | ||||||
| use tls_parser::{ | use tls_parser::{ | ||||||
|     parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage, |     parse_tls_extensions, parse_tls_raw_record, parse_tls_record_with_header, TlsMessage, | ||||||
|     TlsMessageHandshake, |     TlsMessageHandshake, | ||||||
| }; | }; | ||||||
| use tokio::net::TcpStream; | use tokio::net::TcpStream; | ||||||
|  | use tokio::time::timeout; // Use timeout for peek operations | ||||||
|  |  | ||||||
| fn get_sni(buf: &[u8]) -> Vec<String> { | fn get_sni(buf: &[u8]) -> Vec<String> { | ||||||
|     let mut snis: Vec<String> = Vec::new(); |     let mut snis: Vec<String> = Vec::new(); | ||||||
| @@ -57,6 +60,9 @@ fn get_sni(buf: &[u8]) -> Vec<String> { | |||||||
|     snis |     snis | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Timeout duration for waiting for TLS Hello data | ||||||
|  | const TLS_PEEK_TIMEOUT: Duration = Duration::from_secs(5); // Adjust as needed | ||||||
|  |  | ||||||
| pub(crate) async fn determine_upstream_name( | pub(crate) async fn determine_upstream_name( | ||||||
|     inbound: &TcpStream, |     inbound: &TcpStream, | ||||||
|     proxy: &Arc<Proxy>, |     proxy: &Arc<Proxy>, | ||||||
| @@ -64,35 +70,170 @@ pub(crate) async fn determine_upstream_name( | |||||||
|     let default_upstream = proxy.default_action.clone(); |     let default_upstream = proxy.default_action.clone(); | ||||||
|  |  | ||||||
|     let mut header = [0u8; 9]; |     let mut header = [0u8; 9]; | ||||||
|     inbound.peek(&mut header).await?; |  | ||||||
|  |  | ||||||
|     let required_bytes = client_hello_buffer_size(&header)?; |     // --- Step 1: Peek the initial header (9 bytes) with timeout --- | ||||||
|  |     match timeout(TLS_PEEK_TIMEOUT, async { | ||||||
|  |         loop { | ||||||
|  |             match inbound.peek(&mut header).await { | ||||||
|  |                 Ok(n) if n >= header.len() => return Ok::<usize, io::Error>(n), // Got enough bytes | ||||||
|  |                 Ok(0) => { | ||||||
|  |                     // Connection closed cleanly before sending enough data | ||||||
|  |                     trace!("Connection closed while peeking for TLS header"); | ||||||
|  |                     return Err(io::Error::new( | ||||||
|  |                         io::ErrorKind::UnexpectedEof, | ||||||
|  |                         "Connection closed while peeking for TLS header", | ||||||
|  |                     ) | ||||||
|  |                     .into()); // Convert to Box<dyn Error> | ||||||
|  |                 } | ||||||
|  |                 Ok(_) => { | ||||||
|  |                     // Not enough bytes yet, yield and loop again | ||||||
|  |                     tokio::task::yield_now().await; | ||||||
|  |                 } | ||||||
|  |                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { | ||||||
|  |                     // Should not happen with await, but yield defensively | ||||||
|  |                     tokio::task::yield_now().await; | ||||||
|  |                 } | ||||||
|  |                 Err(e) => { | ||||||
|  |                     // Other I/O error | ||||||
|  |                     warn!("Error peeking for TLS header: {}", e); | ||||||
|  |                     return Err(e.into()); // Convert to Box<dyn Error> | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     }) | ||||||
|  |     .await | ||||||
|  |     { | ||||||
|  |         Ok(Ok(_)) => { /* Header peeked successfully */ } | ||||||
|  |         Ok(Err(e)) => { | ||||||
|  |             // Inner loop returned an error (e.g., EOF, IO error) | ||||||
|  |             trace!("Failed to peek header (inner error): {}", e); | ||||||
|  |             return Ok(default_upstream); // Fallback on error/EOF | ||||||
|  |         } | ||||||
|  |         Err(_) => { | ||||||
|  |             // Timeout occurred | ||||||
|  |             error!("Timeout waiting for TLS header"); | ||||||
|  |             return Ok(default_upstream); // Fallback on timeout | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     let mut hello_buf = vec![0; required_bytes]; |     // --- Step 2: Calculate required size --- | ||||||
|     let read_bytes = inbound.peek(&mut hello_buf).await?; |     let required_bytes = match client_hello_buffer_size(&header) { | ||||||
|  |         Ok(size) => size, | ||||||
|  |         Err(e) => { | ||||||
|  |             // Header was invalid or not a ClientHello | ||||||
|  |             trace!("Could not determine required buffer size: {}", e); | ||||||
|  |             return Ok(default_upstream); | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  |  | ||||||
|     if read_bytes < required_bytes.into() { |     // Basic sanity check on size | ||||||
|         error!("Could not read enough bytes to determine SNI"); |     if required_bytes > 16384 + 9 { | ||||||
|  |         // TLS max record size + header approx | ||||||
|  |         error!( | ||||||
|  |             "Calculated required TLS buffer size is too large: {}", | ||||||
|  |             required_bytes | ||||||
|  |         ); | ||||||
|         return Ok(default_upstream); |         return Ok(default_upstream); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // --- Step 3: Peek the full ClientHello with timeout --- | ||||||
|  |     let mut hello_buf = vec![0; required_bytes]; | ||||||
|  |     match timeout(TLS_PEEK_TIMEOUT, async { | ||||||
|  |         let mut total_peeked = 0; | ||||||
|  |         loop { | ||||||
|  |             // Peek into the portion of the buffer that hasn't been filled yet. | ||||||
|  |             match inbound.peek(&mut hello_buf[total_peeked..]).await { | ||||||
|  |                 Ok(0) => { | ||||||
|  |                     // Connection closed cleanly before sending full ClientHello | ||||||
|  |                     trace!( | ||||||
|  |                         "Connection closed while peeking for full ClientHello (peeked {}/{} bytes)", | ||||||
|  |                         total_peeked, | ||||||
|  |                         required_bytes | ||||||
|  |                     ); | ||||||
|  |                     return Err::<usize, io::Error>( | ||||||
|  |                         io::Error::new( | ||||||
|  |                             io::ErrorKind::UnexpectedEof, | ||||||
|  |                             "Connection closed while peeking for full ClientHello", | ||||||
|  |                         ) | ||||||
|  |                         .into(), | ||||||
|  |                     ); | ||||||
|  |                 } | ||||||
|  |                 Ok(n) => { | ||||||
|  |                     total_peeked += n; | ||||||
|  |                     if total_peeked >= required_bytes { | ||||||
|  |                         trace!("Successfully peeked {} bytes for ClientHello", total_peeked); | ||||||
|  |                         return Ok(total_peeked); // Got enough | ||||||
|  |                     } else { | ||||||
|  |                         // Not enough bytes yet, yield and loop again | ||||||
|  |                         trace!( | ||||||
|  |                             "Peeked {}/{} bytes for ClientHello, waiting for more...", | ||||||
|  |                             total_peeked, | ||||||
|  |                             required_bytes | ||||||
|  |                         ); | ||||||
|  |                         tokio::task::yield_now().await; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { | ||||||
|  |                     tokio::task::yield_now().await; | ||||||
|  |                 } | ||||||
|  |                 Err(e) => { | ||||||
|  |                     warn!("Error peeking for full ClientHello: {}", e); | ||||||
|  |                     return Err(e.into()); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     }) | ||||||
|  |     .await | ||||||
|  |     { | ||||||
|  |         Ok(Ok(_)) => { /* Full hello peeked successfully */ } | ||||||
|  |         Ok(Err(e)) => { | ||||||
|  |             error!("Could not peek full ClientHello (inner error): {}", e); | ||||||
|  |             return Ok(default_upstream); // Fallback on error/EOF | ||||||
|  |         } | ||||||
|  |         Err(_) => { | ||||||
|  |             error!( | ||||||
|  |                 "Timeout waiting for full ClientHello (needed {} bytes)", | ||||||
|  |                 required_bytes | ||||||
|  |             ); | ||||||
|  |             return Ok(default_upstream); // Fallback on timeout | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // --- Step 4: Parse SNI --- | ||||||
|     let snis = get_sni(&hello_buf); |     let snis = get_sni(&hello_buf); | ||||||
|  |  | ||||||
|  |     // --- Step 5: Determine upstream based on SNI --- | ||||||
|     if snis.is_empty() { |     if snis.is_empty() { | ||||||
|  |         debug!("No SNI found in ClientHello, using default upstream."); | ||||||
|         return Ok(default_upstream); |         return Ok(default_upstream); | ||||||
|     } else { |     } else { | ||||||
|         match proxy.sni.clone() { |         match proxy.sni.clone() { | ||||||
|             Some(sni_map) => { |             Some(sni_map) => { | ||||||
|                 let mut upstream = default_upstream; |                 let mut upstream = default_upstream.clone(); // Clone here for default case | ||||||
|  |                 let mut found_match = false; | ||||||
|                 for sni in snis { |                 for sni in snis { | ||||||
|                     let m = sni_map.get(&sni); |                     // snis is already Vec<String> | ||||||
|                     if m.is_some() { |                     if let Some(target_upstream) = sni_map.get(&sni) { | ||||||
|                         upstream = m.unwrap().clone(); |                         debug!( | ||||||
|  |                             "Found matching SNI '{}', routing to upstream: {}", | ||||||
|  |                             sni, target_upstream | ||||||
|  |                         ); | ||||||
|  |                         upstream = target_upstream.clone(); | ||||||
|  |                         found_match = true; | ||||||
|                         break; |                         break; | ||||||
|  |                     } else { | ||||||
|  |                         trace!("SNI '{}' not found in map.", sni); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|  |                 if !found_match { | ||||||
|  |                     debug!("SNI(s) found but none matched configuration, using default upstream."); | ||||||
|  |                 } | ||||||
|                 Ok(upstream) |                 Ok(upstream) | ||||||
|             } |             } | ||||||
|             None => return Ok(default_upstream), |             None => { | ||||||
|  |                 debug!("SNI found but no SNI map configured, using default upstream."); | ||||||
|  |                 Ok(default_upstream) | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user