mod common; use common::*; use rustproxy::RustProxy; use rustproxy_config::RustProxyOptions; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; /// Create a rustls client config that trusts self-signed certs. fn make_insecure_tls_client_config() -> Arc { let _ = rustls::crypto::ring::default_provider().install_default(); let config = rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(InsecureVerifier)) .with_no_client_auth(); Arc::new(config) } #[derive(Debug)] struct InsecureVerifier; impl rustls::client::danger::ServerCertVerifier for InsecureVerifier { fn verify_server_cert( &self, _end_entity: &rustls::pki_types::CertificateDer<'_>, _intermediates: &[rustls::pki_types::CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::RSA_PSS_SHA256, ] } } #[tokio::test] async fn test_tls_terminate_basic() { let backend_port = next_port(); let proxy_port = next_port(); let domain = "test.example.com"; // Generate self-signed cert let (cert_pem, key_pem) = generate_self_signed_cert(domain); // Start plain TCP echo backend (proxy terminates TLS, sends plain to backend) let _backend = start_echo_server(backend_port).await; let options = RustProxyOptions { routes: vec![make_tls_terminate_route( proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, )], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); // Connect with TLS client let result = with_timeout(async { let tls_config = make_insecure_tls_client_config(); let connector = tokio_rustls::TlsConnector::from(tls_config); let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); tls_stream.write_all(b"hello TLS").await.unwrap(); let mut buf = vec![0u8; 1024]; let n = tls_stream.read(&mut buf).await.unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() }, 10) .await .unwrap(); assert_eq!(result, "hello TLS"); proxy.stop().await.unwrap(); } #[tokio::test] async fn test_tls_terminate_and_reencrypt() { let backend_port = next_port(); let proxy_port = next_port(); let domain = "reencrypt.example.com"; let backend_domain = "backend.internal"; // Generate certs let (proxy_cert, proxy_key) = generate_self_signed_cert(domain); let (backend_cert, backend_key) = generate_self_signed_cert(backend_domain); // Start TLS echo backend let _backend = start_tls_echo_server(backend_port, &backend_cert, &backend_key).await; // Create terminate-and-reencrypt route let mut route = make_tls_terminate_route( proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key, ); route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt; let options = RustProxyOptions { routes: vec![route], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); let result = with_timeout(async { let tls_config = make_insecure_tls_client_config(); let connector = tokio_rustls::TlsConnector::from(tls_config); let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); tls_stream.write_all(b"hello reencrypt").await.unwrap(); let mut buf = vec![0u8; 1024]; let n = tls_stream.read(&mut buf).await.unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() }, 10) .await .unwrap(); assert_eq!(result, "hello reencrypt"); proxy.stop().await.unwrap(); } #[tokio::test] async fn test_tls_terminate_sni_cert_selection() { let backend1_port = next_port(); let backend2_port = next_port(); let proxy_port = next_port(); let (cert1, key1) = generate_self_signed_cert("alpha.example.com"); let (cert2, key2) = generate_self_signed_cert("beta.example.com"); let _b1 = start_prefix_echo_server(backend1_port, "ALPHA:").await; let _b2 = start_prefix_echo_server(backend2_port, "BETA:").await; let options = RustProxyOptions { routes: vec![ make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1), make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2), ], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); // Test alpha domain let result = with_timeout(async { let tls_config = make_insecure_tls_client_config(); let connector = tokio_rustls::TlsConnector::from(tls_config); let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); tls_stream.write_all(b"test").await.unwrap(); let mut buf = vec![0u8; 1024]; let n = tls_stream.read(&mut buf).await.unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() }, 10) .await .unwrap(); assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result); proxy.stop().await.unwrap(); } #[tokio::test] async fn test_tls_terminate_large_payload() { let backend_port = next_port(); let proxy_port = next_port(); let domain = "large.example.com"; let (cert_pem, key_pem) = generate_self_signed_cert(domain); let _backend = start_echo_server(backend_port).await; let options = RustProxyOptions { routes: vec![make_tls_terminate_route( proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, )], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); let result = with_timeout(async { let tls_config = make_insecure_tls_client_config(); let connector = tokio_rustls::TlsConnector::from(tls_config); let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) .await .unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); // Send 1MB of data let data = vec![b'X'; 1_000_000]; tls_stream.write_all(&data).await.unwrap(); tls_stream.shutdown().await.unwrap(); let mut received = Vec::new(); tls_stream.read_to_end(&mut received).await.unwrap(); received.len() }, 15) .await .unwrap(); assert_eq!(result, 1_000_000); proxy.stop().await.unwrap(); } #[tokio::test] async fn test_tls_terminate_concurrent() { let backend_port = next_port(); let proxy_port = next_port(); let domain = "concurrent.example.com"; let (cert_pem, key_pem) = generate_self_signed_cert(domain); let _backend = start_echo_server(backend_port).await; let options = RustProxyOptions { routes: vec![make_tls_terminate_route( proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, )], ..Default::default() }; let mut proxy = RustProxy::new(options).unwrap(); proxy.start().await.unwrap(); assert!(wait_for_port(proxy_port, 2000).await); let result = with_timeout(async { let mut handles = Vec::new(); for i in 0..10 { let port = proxy_port; let dom = domain.to_string(); handles.push(tokio::spawn(async move { let tls_config = make_insecure_tls_client_config(); let connector = tokio_rustls::TlsConnector::from(tls_config); let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) .await .unwrap(); let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let msg = format!("conn-{}", i); tls_stream.write_all(msg.as_bytes()).await.unwrap(); let mut buf = vec![0u8; 1024]; let n = tls_stream.read(&mut buf).await.unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() })); } let mut results = Vec::new(); for handle in handles { results.push(handle.await.unwrap()); } results }, 15) .await .unwrap(); assert_eq!(result.len(), 10); for (i, r) in result.iter().enumerate() { assert_eq!(r, &format!("conn-{}", i)); } proxy.stop().await.unwrap(); }