use std::sync::atomic::{AtomicU16, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::task::JoinHandle;

/// First port handed out by [`next_port`]; high enough to stay clear of well-known services.
const FIRST_TEST_PORT: u16 = 19000;

/// Atomic port allocator starting at 19000 so tests in this process never share a port.
static PORT_COUNTER: AtomicU16 = AtomicU16::new(FIRST_TEST_PORT);

/// Get the next available port for testing.
///
/// Ports are handed out sequentially starting at 19000. A port that cannot currently be
/// bound on `127.0.0.1` (for example because another process holds it) is skipped, so the
/// returned port was free at the moment of the call. Another process can still grab it
/// before the caller binds it.
///
/// # Panics
///
/// Panics once the counter wraps past 65535, rather than handing out port 0 (which `bind`
/// interprets as "any port") or ports below the test range.
pub fn next_port() -> u16 {
    loop {
        let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
        assert!(
            port >= FIRST_TEST_PORT,
            "test port range exhausted (port counter wrapped past 65535)"
        );
        // Probe with a throwaway std listener; it is dropped (and the port released) immediately.
        if std::net::TcpListener::bind(("127.0.0.1", port)).is_ok() {
            return port;
        }
    }
}

/// Start a simple TCP echo server that echoes back whatever it receives.
///
/// The listener on `127.0.0.1:port` is bound before this function returns, so clients may
/// connect as soon as it resolves. Every accepted connection is served on its own task.
/// Returns the join handle of the accept-loop task.
///
/// # Panics
///
/// Panics if the port cannot be bound.
pub async fn start_echo_server(port: u16) -> JoinHandle<()> {
    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap_or_else(|e| panic!("Failed to bind echo server on port {}: {}", port, e));
    tokio::spawn(async move {
        loop {
            let (mut stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(_) => break,
            };
            tokio::spawn(async move {
                let mut buf = vec![0u8; 65536];
                loop {
                    let n = match stream.read(&mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => n,
                    };
                    if stream.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
            });
        }
    })
}

/// Start a TCP echo server that prefixes responses to identify which backend responded.
///
/// Every chunk read from a connection is written back as `prefix` followed by the chunk.
/// The prefix is added per `read`, not per logical message, so a message split across
/// reads gets the prefix more than once.
///
/// # Panics
///
/// Panics if the port cannot be bound.
pub async fn start_prefix_echo_server(port: u16, prefix: &str) -> JoinHandle<()> {
    let prefix = prefix.to_string();
    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap_or_else(|e| panic!("Failed to bind prefix echo server on port {}: {}", port, e));
    tokio::spawn(async move {
        loop {
            let (mut stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(_) => break,
            };
            let pfx = prefix.clone();
            tokio::spawn(async move {
                let mut buf = vec![0u8; 65536];
                // Reused for every chunk instead of allocating a fresh Vec per read.
                let mut response = Vec::with_capacity(pfx.len() + buf.len());
                loop {
                    let n = match stream.read(&mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => n,
                    };
                    response.clear();
                    response.extend_from_slice(pfx.as_bytes());
                    response.extend_from_slice(&buf[..n]);
                    if stream.write_all(&response).await.is_err() {
                        break;
                    }
                }
            });
        }
    })
}

/// Start a simple HTTP server that responds with a fixed status and body.
pub async fn start_http_server(port: u16, status: u16, body: &str) -> JoinHandle<()> { let body = body.to_string(); let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) .await .expect("Failed to bind HTTP server"); tokio::spawn(async move { loop { let (mut stream, _) = match listener.accept().await { Ok(conn) => conn, Err(_) => break, }; let b = body.clone(); tokio::spawn(async move { let mut buf = vec![0u8; 8192]; // Read the request let _n = stream.read(&mut buf).await.unwrap_or(0); // Send response let response = format!( "HTTP/1.1 {} OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", status, b.len(), b, ); let _ = stream.write_all(response.as_bytes()).await; let _ = stream.shutdown().await; }); } }) } /// Start an HTTP backend server that echoes back request details as JSON. /// The response body contains: {"method":"GET","path":"/foo","host":"example.com","backend":""} /// Supports keep-alive by reading HTTP requests properly. pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandle<()> { let name = backend_name.to_string(); let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap_or_else(|_| panic!("Failed to bind HTTP echo backend on port {}", port)); tokio::spawn(async move { loop { let (mut stream, _) = match listener.accept().await { Ok(conn) => conn, Err(_) => break, }; let backend = name.clone(); tokio::spawn(async move { let mut buf = vec![0u8; 16384]; // Read request data let n = match stream.read(&mut buf).await { Ok(0) | Err(_) => return, Ok(n) => n, }; let req_str = String::from_utf8_lossy(&buf[..n]); // Parse first line: METHOD PATH HTTP/x.x let first_line = req_str.lines().next().unwrap_or(""); let parts: Vec<&str> = first_line.split_whitespace().collect(); let method = parts.first().copied().unwrap_or("UNKNOWN"); let path = parts.get(1).copied().unwrap_or("/"); // Extract Host header let host = req_str.lines() .find(|l| l.to_lowercase().starts_with("host:")) .map(|l| l[5..].trim()) 
.unwrap_or("unknown"); let body = format!( r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#, method, path, host, backend ); let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", body.len(), body, ); let _ = stream.write_all(response.as_bytes()).await; let _ = stream.shutdown().await; }); } }) } /// Wrap a future with a timeout, preventing tests from hanging. pub async fn with_timeout(future: F, secs: u64) -> Result where F: std::future::Future, { match tokio::time::timeout(std::time::Duration::from_secs(secs), future).await { Ok(result) => Ok(result), Err(_) => Err("Test timed out"), } } /// Wait briefly for a server to be ready by attempting TCP connections. pub async fn wait_for_port(port: u16, timeout_ms: u64) -> bool { let start = std::time::Instant::now(); let timeout = std::time::Duration::from_millis(timeout_ms); while start.elapsed() < timeout { if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) .await .is_ok() { return true; } tokio::time::sleep(std::time::Duration::from_millis(10)).await; } false } /// Helper to create a minimal route config for testing. 
///
/// The route matches `port` (optionally restricted to `domain`) and forwards to a single
/// fixed `target_host:target_port`. Every other optional field is left as `None`.
pub fn make_test_route(
    port: u16,
    domain: Option<&str>,
    target_host: &str,
    target_port: u16,
) -> rustproxy_config::RouteConfig {
    use rustproxy_config::{
        DomainSpec, HostSpec, PortRange, PortSpec, RouteAction, RouteActionType, RouteConfig,
        RouteMatch, RouteTarget,
    };

    // The single upstream this route forwards to.
    let upstream = RouteTarget {
        host: HostSpec::Single(target_host.to_owned()),
        port: PortSpec::Fixed(target_port),
        target_match: None,
        tls: None,
        websocket: None,
        load_balancing: None,
        send_proxy_protocol: None,
        headers: None,
        advanced: None,
        priority: None,
    };

    // Match on the listening port and, when given, a single domain.
    let matcher = RouteMatch {
        ports: PortRange::Single(port),
        domains: domain.map(str::to_string).map(DomainSpec::Single),
        path: None,
        client_ip: None,
        tls_version: None,
        headers: None,
        protocol: None,
    };

    let forward = RouteAction {
        action_type: RouteActionType::Forward,
        targets: Some(vec![upstream]),
        tls: None,
        websocket: None,
        load_balancing: None,
        advanced: None,
        options: None,
        forwarding_engine: None,
        nftables: None,
        send_proxy_protocol: None,
    };

    RouteConfig {
        route_match: matcher,
        action: forward,
        id: None,
        headers: None,
        security: None,
        name: None,
        description: None,
        priority: None,
        tags: None,
        enabled: None,
    }
}

/// Start a simple WebSocket echo backend.
///
/// Accepts WebSocket upgrade requests (HTTP Upgrade: websocket), sends 101 back,
/// then echoes all data received on the connection.
pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> { let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap_or_else(|_| panic!("Failed to bind WS echo backend on port {}", port)); tokio::spawn(async move { loop { let (mut stream, _) = match listener.accept().await { Ok(conn) => conn, Err(_) => break, }; tokio::spawn(async move { // Read the HTTP upgrade request let mut buf = vec![0u8; 4096]; let n = match stream.read(&mut buf).await { Ok(0) | Err(_) => return, Ok(n) => n, }; let req_str = String::from_utf8_lossy(&buf[..n]); // Extract Sec-WebSocket-Key for proper handshake let ws_key = req_str.lines() .find(|l| l.to_lowercase().starts_with("sec-websocket-key:")) .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string()) .unwrap_or_default(); // Compute Sec-WebSocket-Accept (simplified - just echo for test purposes) // Real implementation would compute SHA-1 + base64 let accept_response = format!( "HTTP/1.1 101 Switching Protocols\r\n\ Upgrade: websocket\r\n\ Connection: Upgrade\r\n\ Sec-WebSocket-Accept: {}\r\n\ \r\n", ws_key ); if stream.write_all(accept_response.as_bytes()).await.is_err() { return; } // Echo all data back (raw TCP after upgrade) let mut echo_buf = vec![0u8; 65536]; loop { let n = match stream.read(&mut echo_buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if stream.write_all(&echo_buf[..n]).await.is_err() { break; } } }); } }) } /// Generate a self-signed certificate for testing using rcgen. /// Returns (cert_pem, key_pem). pub fn generate_self_signed_cert(domain: &str) -> (String, String) { use rcgen::{CertificateParams, KeyPair}; let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap(); params.distinguished_name.push(rcgen::DnType::CommonName, domain); let key_pair = KeyPair::generate().unwrap(); let cert = params.self_signed(&key_pair).unwrap(); (cert.pem(), key_pair.serialize_pem()) } /// Start a TLS echo server using the given cert/key. /// Returns the join handle. 
///
/// Each accepted connection first completes a TLS handshake (failed handshakes are dropped).
/// After that, every decrypted chunk is written back and flushed.
///
/// # Panics
///
/// Panics if the TLS acceptor cannot be built from `cert_pem`/`key_pem` or the port cannot
/// be bound.
pub async fn start_tls_echo_server(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
    use std::sync::Arc;
    let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
        .expect("Failed to build TLS acceptor");
    let acceptor = Arc::new(acceptor);
    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap_or_else(|e| panic!("Failed to bind TLS echo server on port {}: {}", port, e));
    tokio::spawn(async move {
        loop {
            let (stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(_) => break,
            };
            let acc = Arc::clone(&acceptor);
            tokio::spawn(async move {
                let mut tls_stream = match acc.accept(stream).await {
                    Ok(s) => s,
                    Err(_) => return,
                };
                let mut buf = vec![0u8; 65536];
                loop {
                    let n = match tls_stream.read(&mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => n,
                    };
                    // A TLS write may leave encrypted records buffered in the session when the
                    // socket is not immediately writable; flush so the echo reaches the peer
                    // before this task blocks on the next read.
                    if tls_stream.write_all(&buf[..n]).await.is_err()
                        || tls_stream.flush().await.is_err()
                    {
                        break;
                    }
                }
            });
        }
    })
}

/// Helper to create a TLS terminate route with static cert for testing.
///
/// Starts from [`make_test_route`] and adds a `Terminate` TLS config that uses the given
/// PEM certificate and key. All other TLS options are left as `None`.
pub fn make_tls_terminate_route(
    port: u16,
    domain: &str,
    target_host: &str,
    target_port: u16,
    cert_pem: &str,
    key_pem: &str,
) -> rustproxy_config::RouteConfig {
    let mut route = make_test_route(port, Some(domain), target_host, target_port);
    route.action.tls = Some(rustproxy_config::RouteTls {
        mode: rustproxy_config::TlsMode::Terminate,
        certificate: Some(rustproxy_config::CertificateSpec::Static(
            rustproxy_config::CertificateConfig {
                cert: cert_pem.to_string(),
                key: key_pem.to_string(),
                ca: None,
                key_file: None,
                cert_file: None,
            },
        )),
        acme: None,
        versions: None,
        ciphers: None,
        honor_cipher_order: None,
        session_timeout: None,
    });
    route
}

/// Helper to create a TLS passthrough route for testing.
pub fn make_tls_passthrough_route( port: u16, domain: Option<&str>, target_host: &str, target_port: u16, ) -> rustproxy_config::RouteConfig { let mut route = make_test_route(port, domain, target_host, target_port); route.action.tls = Some(rustproxy_config::RouteTls { mode: rustproxy_config::TlsMode::Passthrough, certificate: None, acme: None, versions: None, ciphers: None, honor_cipher_order: None, session_timeout: None, }); route }