554 lines
19 KiB
Rust
554 lines
19 KiB
Rust
use std::sync::atomic::{AtomicU16, Ordering};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::TcpListener;
|
|
use tokio::task::JoinHandle;
|
|
|
|
/// First port handed out by [`next_port`]; chosen high enough to stay clear of
/// well-known service ports.
const PORT_BASE: u16 = 19000;

/// Atomic port allocator starting at [`PORT_BASE`], shared by every test in
/// this process so concurrently running tests never receive the same port.
static PORT_COUNTER: AtomicU16 = AtomicU16::new(PORT_BASE);

/// Get the next available port for testing.
///
/// Ports are handed out sequentially from a process-wide counter. When the
/// counter passes `u16::MAX` it wraps back to [`PORT_BASE`] instead of
/// overflowing to 0 (which `bind` would treat as "any port").
///
/// Candidates that are already bound on `127.0.0.1` (for example by another
/// process) are skipped. This check is best-effort: a port found free here can
/// still be taken by someone else before the caller binds it. If every port in
/// the range looks busy, the last candidate is returned and the caller's own
/// bind reports the conflict.
pub fn next_port() -> u16 {
    // Number of distinct ports in [PORT_BASE, u16::MAX]. It bounds the search so
    // a fully occupied range cannot loop forever.
    let range_len = u32::from(u16::MAX - PORT_BASE) + 1;
    let mut candidate = PORT_BASE;
    for _ in 0..range_len {
        // Advance the shared cursor atomically, wrapping to PORT_BASE after MAX.
        // The closure always returns Some, so this is always Ok(previous).
        candidate = PORT_COUNTER
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |p| {
                Some(if p == u16::MAX { PORT_BASE } else { p + 1 })
            })
            .unwrap_or_else(|prev| prev);
        // The probe listener is dropped immediately, releasing the port.
        if std::net::TcpListener::bind(("127.0.0.1", candidate)).is_ok() {
            return candidate;
        }
    }
    candidate
}
|
|
|
|
/// Start a simple TCP echo server that echoes back whatever it receives.
|
|
/// Returns the join handle for the server task.
|
|
pub async fn start_echo_server(port: u16) -> JoinHandle<()> {
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.expect("Failed to bind echo server");
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let (mut stream, _) = match listener.accept().await {
|
|
Ok(conn) => conn,
|
|
Err(_) => break,
|
|
};
|
|
tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 65536];
|
|
loop {
|
|
let n = match stream.read(&mut buf).await {
|
|
Ok(0) | Err(_) => break,
|
|
Ok(n) => n,
|
|
};
|
|
if stream.write_all(&buf[..n]).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Start a TCP echo server that prefixes responses to identify which backend responded.
|
|
pub async fn start_prefix_echo_server(port: u16, prefix: &str) -> JoinHandle<()> {
|
|
let prefix = prefix.to_string();
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.expect("Failed to bind prefix echo server");
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let (mut stream, _) = match listener.accept().await {
|
|
Ok(conn) => conn,
|
|
Err(_) => break,
|
|
};
|
|
let pfx = prefix.clone();
|
|
tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 65536];
|
|
loop {
|
|
let n = match stream.read(&mut buf).await {
|
|
Ok(0) | Err(_) => break,
|
|
Ok(n) => n,
|
|
};
|
|
let mut response = pfx.as_bytes().to_vec();
|
|
response.extend_from_slice(&buf[..n]);
|
|
if stream.write_all(&response).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Start a simple HTTP server that responds with a fixed status and body.
|
|
pub async fn start_http_server(port: u16, status: u16, body: &str) -> JoinHandle<()> {
|
|
let body = body.to_string();
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.expect("Failed to bind HTTP server");
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let (mut stream, _) = match listener.accept().await {
|
|
Ok(conn) => conn,
|
|
Err(_) => break,
|
|
};
|
|
let b = body.clone();
|
|
tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 8192];
|
|
// Read the request
|
|
let _n = stream.read(&mut buf).await.unwrap_or(0);
|
|
// Send response
|
|
let response = format!(
|
|
"HTTP/1.1 {} OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
|
status,
|
|
b.len(),
|
|
b,
|
|
);
|
|
let _ = stream.write_all(response.as_bytes()).await;
|
|
let _ = stream.shutdown().await;
|
|
});
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Start an HTTP backend server that echoes back request details as JSON.
|
|
/// The response body contains: {"method":"GET","path":"/foo","host":"example.com","backend":"<name>"}
|
|
/// Supports keep-alive by reading HTTP requests properly.
|
|
pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandle<()> {
|
|
let name = backend_name.to_string();
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.unwrap_or_else(|_| panic!("Failed to bind HTTP echo backend on port {}", port));
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let (mut stream, _) = match listener.accept().await {
|
|
Ok(conn) => conn,
|
|
Err(_) => break,
|
|
};
|
|
let backend = name.clone();
|
|
tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 16384];
|
|
// Read request data
|
|
let n = match stream.read(&mut buf).await {
|
|
Ok(0) | Err(_) => return,
|
|
Ok(n) => n,
|
|
};
|
|
let req_str = String::from_utf8_lossy(&buf[..n]);
|
|
|
|
// Parse first line: METHOD PATH HTTP/x.x
|
|
let first_line = req_str.lines().next().unwrap_or("");
|
|
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
|
let method = parts.first().copied().unwrap_or("UNKNOWN");
|
|
let path = parts.get(1).copied().unwrap_or("/");
|
|
|
|
// Extract Host header
|
|
let host = req_str.lines()
|
|
.find(|l| l.to_lowercase().starts_with("host:"))
|
|
.map(|l| l[5..].trim())
|
|
.unwrap_or("unknown");
|
|
|
|
let body = format!(
|
|
r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#,
|
|
method, path, host, backend
|
|
);
|
|
|
|
let response = format!(
|
|
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
|
body.len(),
|
|
body,
|
|
);
|
|
let _ = stream.write_all(response.as_bytes()).await;
|
|
let _ = stream.shutdown().await;
|
|
});
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Wrap a future with a timeout, preventing tests from hanging.
|
|
pub async fn with_timeout<F, T>(future: F, secs: u64) -> Result<T, &'static str>
|
|
where
|
|
F: std::future::Future<Output = T>,
|
|
{
|
|
match tokio::time::timeout(std::time::Duration::from_secs(secs), future).await {
|
|
Ok(result) => Ok(result),
|
|
Err(_) => Err("Test timed out"),
|
|
}
|
|
}
|
|
|
|
/// Wait briefly for a server to be ready by attempting TCP connections.
|
|
pub async fn wait_for_port(port: u16, timeout_ms: u64) -> bool {
|
|
let start = std::time::Instant::now();
|
|
let timeout = std::time::Duration::from_millis(timeout_ms);
|
|
while start.elapsed() < timeout {
|
|
if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.is_ok()
|
|
{
|
|
return true;
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
|
}
|
|
false
|
|
}
|
|
|
|
/// Start a TLS HTTP echo backend: accepts TLS, then responds with HTTP JSON
|
|
/// containing request details. Combines TLS acceptance with HTTP echo behavior.
|
|
pub async fn start_tls_http_backend(
|
|
port: u16,
|
|
backend_name: &str,
|
|
cert_pem: &str,
|
|
key_pem: &str,
|
|
) -> JoinHandle<()> {
|
|
use std::sync::Arc;
|
|
|
|
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
|
.expect("Failed to build TLS acceptor");
|
|
let acceptor = Arc::new(acceptor);
|
|
let name = backend_name.to_string();
|
|
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.unwrap_or_else(|_| panic!("Failed to bind TLS HTTP backend on port {}", port));
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let (stream, _) = match listener.accept().await {
|
|
Ok(conn) => conn,
|
|
Err(_) => break,
|
|
};
|
|
let acc = acceptor.clone();
|
|
let backend = name.clone();
|
|
tokio::spawn(async move {
|
|
let mut tls_stream = match acc.accept(stream).await {
|
|
Ok(s) => s,
|
|
Err(_) => return,
|
|
};
|
|
|
|
let mut buf = vec![0u8; 16384];
|
|
let n = match tls_stream.read(&mut buf).await {
|
|
Ok(0) | Err(_) => return,
|
|
Ok(n) => n,
|
|
};
|
|
let req_str = String::from_utf8_lossy(&buf[..n]);
|
|
|
|
// Parse first line: METHOD PATH HTTP/x.x
|
|
let first_line = req_str.lines().next().unwrap_or("");
|
|
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
|
let method = parts.first().copied().unwrap_or("UNKNOWN");
|
|
let path = parts.get(1).copied().unwrap_or("/");
|
|
|
|
// Extract Host header
|
|
let host = req_str
|
|
.lines()
|
|
.find(|l| l.to_lowercase().starts_with("host:"))
|
|
.map(|l| l[5..].trim())
|
|
.unwrap_or("unknown");
|
|
|
|
let body = format!(
|
|
r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#,
|
|
method, path, host, backend
|
|
);
|
|
|
|
let response = format!(
|
|
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
|
body.len(),
|
|
body,
|
|
);
|
|
let _ = tls_stream.write_all(response.as_bytes()).await;
|
|
let _ = tls_stream.shutdown().await;
|
|
});
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Helper to create a minimal route config for testing.
|
|
pub fn make_test_route(
|
|
port: u16,
|
|
domain: Option<&str>,
|
|
target_host: &str,
|
|
target_port: u16,
|
|
) -> rustproxy_config::RouteConfig {
|
|
rustproxy_config::RouteConfig {
|
|
id: None,
|
|
route_match: rustproxy_config::RouteMatch {
|
|
ports: rustproxy_config::PortRange::Single(port),
|
|
domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())),
|
|
path: None,
|
|
client_ip: None,
|
|
tls_version: None,
|
|
headers: None,
|
|
protocol: None,
|
|
},
|
|
action: rustproxy_config::RouteAction {
|
|
action_type: rustproxy_config::RouteActionType::Forward,
|
|
targets: Some(vec![rustproxy_config::RouteTarget {
|
|
target_match: None,
|
|
host: rustproxy_config::HostSpec::Single(target_host.to_string()),
|
|
port: rustproxy_config::PortSpec::Fixed(target_port),
|
|
tls: None,
|
|
websocket: None,
|
|
load_balancing: None,
|
|
send_proxy_protocol: None,
|
|
headers: None,
|
|
advanced: None,
|
|
priority: None,
|
|
}]),
|
|
tls: None,
|
|
websocket: None,
|
|
load_balancing: None,
|
|
advanced: None,
|
|
options: None,
|
|
forwarding_engine: None,
|
|
nftables: None,
|
|
send_proxy_protocol: None,
|
|
},
|
|
headers: None,
|
|
security: None,
|
|
name: None,
|
|
description: None,
|
|
priority: None,
|
|
tags: None,
|
|
enabled: None,
|
|
}
|
|
}
|
|
|
|
/// Start a simple WebSocket echo backend.
|
|
///
|
|
/// Accepts WebSocket upgrade requests (HTTP Upgrade: websocket), sends 101 back,
|
|
/// then echoes all data received on the connection.
|
|
pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.unwrap_or_else(|_| panic!("Failed to bind WS echo backend on port {}", port));
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let (mut stream, _) = match listener.accept().await {
|
|
Ok(conn) => conn,
|
|
Err(_) => break,
|
|
};
|
|
tokio::spawn(async move {
|
|
// Read the HTTP upgrade request
|
|
let mut buf = vec![0u8; 4096];
|
|
let n = match stream.read(&mut buf).await {
|
|
Ok(0) | Err(_) => return,
|
|
Ok(n) => n,
|
|
};
|
|
|
|
let req_str = String::from_utf8_lossy(&buf[..n]);
|
|
|
|
// Extract Sec-WebSocket-Key for proper handshake
|
|
let ws_key = req_str.lines()
|
|
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
|
|
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
|
.unwrap_or_default();
|
|
|
|
// Compute Sec-WebSocket-Accept (simplified - just echo for test purposes)
|
|
// Real implementation would compute SHA-1 + base64
|
|
let accept_response = format!(
|
|
"HTTP/1.1 101 Switching Protocols\r\n\
|
|
Upgrade: websocket\r\n\
|
|
Connection: Upgrade\r\n\
|
|
Sec-WebSocket-Accept: {}\r\n\
|
|
\r\n",
|
|
ws_key
|
|
);
|
|
|
|
if stream.write_all(accept_response.as_bytes()).await.is_err() {
|
|
return;
|
|
}
|
|
|
|
// Echo all data back (raw TCP after upgrade)
|
|
let mut echo_buf = vec![0u8; 65536];
|
|
loop {
|
|
let n = match stream.read(&mut echo_buf).await {
|
|
Ok(0) | Err(_) => break,
|
|
Ok(n) => n,
|
|
};
|
|
if stream.write_all(&echo_buf[..n]).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Generate a self-signed certificate for testing using rcgen.
|
|
/// Returns (cert_pem, key_pem).
|
|
pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
|
|
use rcgen::{CertificateParams, KeyPair};
|
|
|
|
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
|
|
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
|
|
|
let key_pair = KeyPair::generate().unwrap();
|
|
let cert = params.self_signed(&key_pair).unwrap();
|
|
|
|
(cert.pem(), key_pair.serialize_pem())
|
|
}
|
|
|
|
/// Start a TLS echo server using the given cert/key.
|
|
/// Returns the join handle.
|
|
pub async fn start_tls_echo_server(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
|
|
use std::sync::Arc;
|
|
|
|
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
|
.expect("Failed to build TLS acceptor");
|
|
let acceptor = Arc::new(acceptor);
|
|
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.expect("Failed to bind TLS echo server");
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let (stream, _) = match listener.accept().await {
|
|
Ok(conn) => conn,
|
|
Err(_) => break,
|
|
};
|
|
let acc = acceptor.clone();
|
|
tokio::spawn(async move {
|
|
let mut tls_stream = match acc.accept(stream).await {
|
|
Ok(s) => s,
|
|
Err(_) => return,
|
|
};
|
|
let mut buf = vec![0u8; 65536];
|
|
loop {
|
|
let n = match tls_stream.read(&mut buf).await {
|
|
Ok(0) | Err(_) => break,
|
|
Ok(n) => n,
|
|
};
|
|
if tls_stream.write_all(&buf[..n]).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Helper to create a TLS terminate route with static cert for testing.
|
|
pub fn make_tls_terminate_route(
|
|
port: u16,
|
|
domain: &str,
|
|
target_host: &str,
|
|
target_port: u16,
|
|
cert_pem: &str,
|
|
key_pem: &str,
|
|
) -> rustproxy_config::RouteConfig {
|
|
let mut route = make_test_route(port, Some(domain), target_host, target_port);
|
|
route.action.tls = Some(rustproxy_config::RouteTls {
|
|
mode: rustproxy_config::TlsMode::Terminate,
|
|
certificate: Some(rustproxy_config::CertificateSpec::Static(
|
|
rustproxy_config::CertificateConfig {
|
|
cert: cert_pem.to_string(),
|
|
key: key_pem.to_string(),
|
|
ca: None,
|
|
key_file: None,
|
|
cert_file: None,
|
|
},
|
|
)),
|
|
acme: None,
|
|
versions: None,
|
|
ciphers: None,
|
|
honor_cipher_order: None,
|
|
session_timeout: None,
|
|
});
|
|
route
|
|
}
|
|
|
|
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
|
|
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
|
|
pub async fn start_tls_ws_echo_backend(
|
|
port: u16,
|
|
cert_pem: &str,
|
|
key_pem: &str,
|
|
) -> JoinHandle<()> {
|
|
use std::sync::Arc;
|
|
|
|
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
|
.expect("Failed to build TLS acceptor");
|
|
let acceptor = Arc::new(acceptor);
|
|
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
|
.await
|
|
.unwrap_or_else(|_| panic!("Failed to bind TLS WS echo backend on port {}", port));
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
let (stream, _) = match listener.accept().await {
|
|
Ok(conn) => conn,
|
|
Err(_) => break,
|
|
};
|
|
let acc = acceptor.clone();
|
|
tokio::spawn(async move {
|
|
let mut tls_stream = match acc.accept(stream).await {
|
|
Ok(s) => s,
|
|
Err(_) => return,
|
|
};
|
|
|
|
// Read the HTTP upgrade request
|
|
let mut buf = vec![0u8; 4096];
|
|
let n = match tls_stream.read(&mut buf).await {
|
|
Ok(0) | Err(_) => return,
|
|
Ok(n) => n,
|
|
};
|
|
|
|
let req_str = String::from_utf8_lossy(&buf[..n]);
|
|
|
|
// Extract Sec-WebSocket-Key for handshake
|
|
let ws_key = req_str
|
|
.lines()
|
|
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
|
|
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
|
.unwrap_or_default();
|
|
|
|
// Send 101 Switching Protocols
|
|
let accept_response = format!(
|
|
"HTTP/1.1 101 Switching Protocols\r\n\
|
|
Upgrade: websocket\r\n\
|
|
Connection: Upgrade\r\n\
|
|
Sec-WebSocket-Accept: {}\r\n\
|
|
\r\n",
|
|
ws_key
|
|
);
|
|
|
|
if tls_stream
|
|
.write_all(accept_response.as_bytes())
|
|
.await
|
|
.is_err()
|
|
{
|
|
return;
|
|
}
|
|
|
|
// Echo all data back (raw TCP after upgrade)
|
|
let mut echo_buf = vec![0u8; 65536];
|
|
loop {
|
|
let n = match tls_stream.read(&mut echo_buf).await {
|
|
Ok(0) | Err(_) => break,
|
|
Ok(n) => n,
|
|
};
|
|
if tls_stream.write_all(&echo_buf[..n]).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Helper to create a TLS passthrough route for testing.
|
|
pub fn make_tls_passthrough_route(
|
|
port: u16,
|
|
domain: Option<&str>,
|
|
target_host: &str,
|
|
target_port: u16,
|
|
) -> rustproxy_config::RouteConfig {
|
|
let mut route = make_test_route(port, domain, target_host, target_port);
|
|
route.action.tls = Some(rustproxy_config::RouteTls {
|
|
mode: rustproxy_config::TlsMode::Passthrough,
|
|
certificate: None,
|
|
acme: None,
|
|
versions: None,
|
|
ciphers: None,
|
|
honor_cipher_order: None,
|
|
session_timeout: None,
|
|
});
|
|
route
|
|
}
|