mod common; use common::*; use rustproxy::RustProxy; use rustproxy_config::RustProxyOptions; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; /// Build a minimal TLS ClientHello with the given SNI domain. /// This is enough for the proxy's SNI parser to extract the domain. fn build_client_hello(domain: &str) -> Vec { let domain_bytes = domain.as_bytes(); let sni_length = domain_bytes.len() as u16; // Server Name extension (type 0x0000) let mut sni_ext = Vec::new(); sni_ext.extend_from_slice(&[0x00, 0x00]); // extension type: server_name let sni_list_len = sni_length + 5; // 2 (list len) + 1 (type) + 2 (name len) + name sni_ext.extend_from_slice(&(sni_list_len as u16).to_be_bytes()); // extension data length sni_ext.extend_from_slice(&((sni_list_len - 2) as u16).to_be_bytes()); // server name list length sni_ext.push(0x00); // host_name type sni_ext.extend_from_slice(&sni_length.to_be_bytes()); sni_ext.extend_from_slice(domain_bytes); let extensions_length = sni_ext.len() as u16; // ClientHello message let mut client_hello = Vec::new(); client_hello.extend_from_slice(&[0x03, 0x03]); // TLS 1.2 version client_hello.extend_from_slice(&[0x00; 32]); // random client_hello.push(0x00); // session_id length client_hello.extend_from_slice(&[0x00, 0x02, 0x00, 0xff]); // cipher suites (1 suite) client_hello.extend_from_slice(&[0x01, 0x00]); // compression methods (null) client_hello.extend_from_slice(&extensions_length.to_be_bytes()); client_hello.extend_from_slice(&sni_ext); let hello_len = client_hello.len() as u32; // Handshake wrapper (type 1 = ClientHello) let mut handshake = Vec::new(); handshake.push(0x01); // ClientHello handshake.extend_from_slice(&hello_len.to_be_bytes()[1..4]); // 3-byte length handshake.extend_from_slice(&client_hello); let hs_len = handshake.len() as u16; // TLS record let mut record = Vec::new(); record.push(0x16); // ContentType: Handshake record.extend_from_slice(&[0x03, 0x01]); // TLS 1.0 (record version) record.extend_from_slice(&hs_len.to_be_bytes()); record.extend_from_slice(&handshake); record } #[tokio::test] async fn test_tls_passthrough_sni_routing() { let backend1_port = next_port(); let backend2_port = next_port(); let proxy_port = next_port(); let _b1 = start_prefix_echo_server(backend1_port, "BACKEND1:").await; let _b2 = start_prefix_echo_server(backend2_port, "BACKEND2:").await; let options = RustProxyOptions { routes: vec![ make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port), make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port), ], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); // Send a fake ClientHello with SNI "one.example.com" let result = with_timeout(async { let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let hello = build_client_hello("one.example.com"); stream.write_all(&hello).await.unwrap(); let mut buf = vec![0u8; 4096]; let n = stream.read(&mut buf).await.unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() }, 5) .await .unwrap(); // Backend1 should have received the ClientHello and prefixed its response assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result); // Now test routing to backend2 let result2 = with_timeout(async { let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let hello = build_client_hello("two.example.com"); stream.write_all(&hello).await.unwrap(); let mut buf = vec![0u8; 4096]; let n = stream.read(&mut buf).await.unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() }, 5) .await .unwrap(); assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2); proxy.stop().await.unwrap(); } #[tokio::test] async fn test_tls_passthrough_unknown_sni() { let backend_port = next_port(); let proxy_port = next_port(); let _backend = start_echo_server(backend_port).await; let options = RustProxyOptions { routes: vec![ make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port), ], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); // Send ClientHello with unknown SNI - should get no response (connection dropped) let result = with_timeout(async { let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let hello = build_client_hello("unknown.example.com"); stream.write_all(&hello).await.unwrap(); let mut buf = vec![0u8; 4096]; // Should either get 0 bytes (closed) or an error match stream.read(&mut buf).await { Ok(0) => true, // Connection closed = no route matched Ok(_) => false, // Got data = route shouldn't have matched Err(_) => true, // Error = connection dropped } }, 5) .await .unwrap(); assert!(result, "Unknown SNI should result in dropped connection"); proxy.stop().await.unwrap(); } #[tokio::test] async fn test_tls_passthrough_wildcard_domain() { let backend_port = next_port(); let proxy_port = next_port(); let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await; let options = RustProxyOptions { routes: vec![ make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port), ], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); // Should match any subdomain of example.com let result = with_timeout(async { let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let hello = build_client_hello("anything.example.com"); stream.write_all(&hello).await.unwrap(); let mut buf = vec![0u8; 4096]; let n = stream.read(&mut buf).await.unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() }, 5) .await .unwrap(); assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result); proxy.stop().await.unwrap(); } #[tokio::test] async fn test_tls_passthrough_multiple_domains() { let b1_port = next_port(); let b2_port = next_port(); let b3_port = next_port(); let proxy_port = next_port(); let _b1 = start_prefix_echo_server(b1_port, "B1:").await; let _b2 = start_prefix_echo_server(b2_port, "B2:").await; let _b3 = start_prefix_echo_server(b3_port, "B3:").await; let options = RustProxyOptions { routes: vec![ make_tls_passthrough_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", b1_port), make_tls_passthrough_route(proxy_port, Some("beta.example.com"), "127.0.0.1", b2_port), make_tls_passthrough_route(proxy_port, Some("gamma.example.com"), "127.0.0.1", b3_port), ], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); for (domain, expected_prefix) in [ ("alpha.example.com", "B1:"), ("beta.example.com", "B2:"), ("gamma.example.com", "B3:"), ] { let result = with_timeout(async { let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let hello = build_client_hello(domain); stream.write_all(&hello).await.unwrap(); let mut buf = vec![0u8; 4096]; let n = stream.read(&mut buf).await.unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() }, 5) .await .unwrap(); assert!( result.starts_with(expected_prefix), "Domain {} should route to {}, got: {}", domain, expected_prefix, result ); } proxy.stop().await.unwrap(); }