feat(smart-proxy): add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers

This commit is contained in:
2026-04-26 15:11:10 +00:00
parent 8fa3a51b03
commit af4908b63f
53 changed files with 2350 additions and 1196 deletions
+12 -8
View File
@@ -13,7 +13,7 @@ use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, error};
use tracing::{debug, error, info};
/// ACME HTTP-01 challenge server.
pub struct ChallengeServer {
@@ -47,7 +47,10 @@ impl ChallengeServer {
}
/// Start the challenge server on the given port.
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pub async fn start(
&mut self,
port: u16,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = format!("0.0.0.0:{}", port);
let listener = TcpListener::bind(&addr).await?;
info!("ACME challenge server listening on port {}", port);
@@ -101,10 +104,7 @@ impl ChallengeServer {
pub async fn stop(&mut self) {
self.cancel.cancel();
if let Some(handle) = self.handle.take() {
let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
handle,
).await;
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
}
self.challenges.clear();
self.cancel = CancellationToken::new();
@@ -154,10 +154,14 @@ mod tests {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Fetch the challenge
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
let client = tokio::net::TcpStream::connect("127.0.0.1:19900")
.await
.unwrap();
let io = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move { let _ = conn.await; });
tokio::spawn(async move {
let _ = conn.await;
});
let req = Request::get("/.well-known/acme-challenge/test-token")
.body(Full::new(Bytes::new()))
+297 -140
View File
@@ -57,24 +57,27 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwap;
use anyhow::Result;
use tracing::{info, warn, debug, error};
use arc_swap::ArcSwap;
use tracing::{debug, error, info, warn};
// Re-export key types
pub use rustproxy_config;
pub use rustproxy_routing;
pub use rustproxy_passthrough;
pub use rustproxy_tls;
pub use rustproxy_http;
pub use rustproxy_metrics;
pub use rustproxy_passthrough;
pub use rustproxy_routing;
pub use rustproxy_security;
pub use rustproxy_tls;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec};
use rustproxy_config::{CertificateSpec, RouteConfig, RustProxyOptions, TlsMode};
use rustproxy_metrics::{Metrics, MetricsCollector, Statistics};
use rustproxy_passthrough::{
ConnectionConfig, TcpListenerManager, TlsCertConfig, UdpListenerManager,
};
use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, UdpListenerManager, TlsCertConfig, ConnectionConfig};
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use rustproxy_security::IpBlockList;
use rustproxy_tls::{CertBundle, CertManager, CertMetadata, CertSource, CertStore};
use tokio_util::sync::CancellationToken;
/// Certificate status.
@@ -106,6 +109,8 @@ pub struct RustProxy {
loaded_certs: HashMap<String, TlsCertConfig>,
/// Cancellation token for cooperative shutdown of background tasks.
cancel_token: CancellationToken,
/// Shared global ingress blocklist, hot-reloadable across TCP/UDP listeners.
security_policy: Arc<ArcSwap<IpBlockList>>,
}
impl RustProxy {
@@ -127,13 +132,19 @@ impl RustProxy {
let route_manager = RouteManager::new(options.routes.clone());
// Set up certificate manager if ACME is configured
let cert_manager = Self::build_cert_manager(&options)
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
let cert_manager =
Self::build_cert_manager(&options).map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
let retention = options.metrics.as_ref()
let retention = options
.metrics
.as_ref()
.and_then(|m| m.retention_seconds)
.unwrap_or(3600) as usize;
let security_policy = Arc::new(ArcSwap::from(Arc::new(Self::build_ip_block_list(
options.security_policy.as_ref(),
))));
Ok(Self {
options,
route_table: ArcSwap::from(Arc::new(route_manager)),
@@ -149,6 +160,7 @@ impl RustProxy {
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
loaded_certs: HashMap::new(),
cancel_token: CancellationToken::new(),
security_policy,
})
}
@@ -163,24 +175,25 @@ impl RustProxy {
// Apply default target if route has no targets
if route.action.targets.is_none() {
if let Some(ref default_target) = defaults.target {
debug!("Applying default target {}:{} to route {:?}",
default_target.host, default_target.port,
route.name.as_deref().unwrap_or("unnamed"));
route.action.targets = Some(vec![
rustproxy_config::RouteTarget {
target_match: None,
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
port: rustproxy_config::PortSpec::Fixed(default_target.port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}
]);
debug!(
"Applying default target {}:{} to route {:?}",
default_target.host,
default_target.port,
route.name.as_deref().unwrap_or("unnamed")
);
route.action.targets = Some(vec![rustproxy_config::RouteTarget {
target_match: None,
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
port: rustproxy_config::PortSpec::Fixed(default_target.port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]);
}
}
@@ -199,7 +212,10 @@ impl RustProxy {
if let Some(ref allow_list) = default_security.ip_allow_list {
security.ip_allow_list = Some(
allow_list.iter().map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone())).collect()
allow_list
.iter()
.map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone()))
.collect(),
);
}
if let Some(ref block_list) = default_security.ip_block_list {
@@ -208,8 +224,10 @@ impl RustProxy {
// Only apply if there's something meaningful
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
debug!("Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed"));
debug!(
"Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed")
);
route.security = Some(security);
}
}
@@ -224,13 +242,17 @@ impl RustProxy {
return None;
}
let email = acme.email.clone()
.or_else(|| acme.account_email.clone());
let email = acme.email.clone().or_else(|| acme.account_email.clone());
let use_production = acme.use_production.unwrap_or(false);
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
let store = CertStore::new();
Some(CertManager::new(store, email, use_production, renew_before_days))
Some(CertManager::new(
store,
email,
use_production,
renew_before_days,
))
}
/// Build ConnectionConfig from RustProxyOptions.
@@ -248,7 +270,10 @@ impl RustProxy {
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
proxy_ips: options.proxy_ips.as_deref().unwrap_or(&[])
proxy_ips: options
.proxy_ips
.as_deref()
.unwrap_or(&[])
.iter()
.filter_map(|s| s.parse::<std::net::IpAddr>().ok())
.collect(),
@@ -258,6 +283,22 @@ impl RustProxy {
}
}
fn build_ip_block_list(policy: Option<&rustproxy_config::SecurityPolicy>) -> IpBlockList {
let Some(policy) = policy else {
return IpBlockList::empty();
};
let mut entries = Vec::new();
if let Some(blocked_ips) = &policy.blocked_ips {
entries.extend(blocked_ips.iter().cloned());
}
if let Some(blocked_cidrs) = &policy.blocked_cidrs {
entries.extend(blocked_cidrs.iter().cloned());
}
IpBlockList::new(&entries)
}
/// Start the proxy, binding to all configured ports.
pub async fn start(&mut self) -> Result<()> {
if self.started {
@@ -272,7 +313,11 @@ impl RustProxy {
let route_manager = self.route_table.load();
let ports = route_manager.listening_ports();
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
info!(
"Configured {} routes on {} ports",
route_manager.route_count(),
ports.len()
);
// Create TCP listener manager with metrics
let mut listener = TcpListenerManager::with_metrics(
@@ -282,7 +327,8 @@ impl RustProxy {
// Apply connection config from options
let conn_config = Self::build_connection_config(&self.options);
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
debug!(
"Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
conn_config.connection_timeout_ms,
conn_config.initial_data_timeout_ms,
conn_config.socket_timeout_ms,
@@ -291,6 +337,7 @@ impl RustProxy {
// Clone proxy_ips before conn_config is moved into the TCP listener
let udp_proxy_ips = conn_config.proxy_ips.clone();
listener.set_connection_config(conn_config);
listener.set_security_policy(Arc::clone(&self.security_policy));
// Share the socket-handler relay path with the listener
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
@@ -303,10 +350,13 @@ impl RustProxy {
let cm = cm.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
tls_configs.insert(
domain.clone(),
TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
}
}
}
@@ -330,7 +380,9 @@ impl RustProxy {
let mut tcp_ports = std::collections::HashSet::new();
let mut udp_ports = std::collections::HashSet::new();
for route in &self.options.routes {
if !route.is_enabled() { continue; }
if !route.is_enabled() {
continue;
}
let transport = route.route_match.transport.as_ref();
let route_ports = route.route_match.ports.to_ports();
for port in route_ports {
@@ -371,6 +423,7 @@ impl RustProxy {
connection_registry,
);
udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
// Share HttpProxyService with H3 — same route matching, connection
// pool, and ALPN protocol detection as the TCP/HTTP path.
@@ -379,10 +432,15 @@ impl RustProxy {
udp_mgr.set_h3_service(Arc::new(h3_svc));
for port in &udp_ports {
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?;
udp_mgr
.add_port_with_tls(*port, quic_tls_config.clone())
.await?;
}
info!("UDP listeners started on {} ports: {:?}",
udp_ports.len(), udp_mgr.listening_ports());
info!(
"UDP listeners started on {} ports: {:?}",
udp_ports.len(),
udp_mgr.listening_ports()
);
self.udp_listener_manager = Some(udp_mgr);
}
@@ -391,16 +449,22 @@ impl RustProxy {
// Start the throughput sampling task with cooperative cancellation
let metrics = Arc::clone(&self.metrics);
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone();
let conn_tracker = self
.listener_manager
.as_ref()
.unwrap()
.conn_tracker()
.clone();
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
let interval_ms = self.options.metrics.as_ref()
let interval_ms = self
.options
.metrics
.as_ref()
.and_then(|m| m.sample_interval_ms)
.unwrap_or(1000);
let sampling_cancel = self.cancel_token.clone();
self.sampling_handle = Some(tokio::spawn(async move {
let mut interval = tokio::time::interval(
std::time::Duration::from_millis(interval_ms)
);
let mut interval = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
loop {
tokio::select! {
_ = sampling_cancel.cancelled() => break,
@@ -442,7 +506,10 @@ impl RustProxy {
continue;
}
let cert_spec = route.action.tls.as_ref()
let cert_spec = route
.action
.tls
.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Auto(_)) = cert_spec {
@@ -466,16 +533,25 @@ impl RustProxy {
return;
}
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
info!(
"Auto-provisioning certificates for {} domains",
domains_to_provision.len()
);
// Start challenge server
let acme_port = self.options.acme.as_ref()
let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut challenge_server = challenge_server::ChallengeServer::new();
if let Err(e) = challenge_server.start(acme_port).await {
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
error!(
"Failed to start ACME challenge server on port {}: {}",
acme_port, e
);
return;
}
@@ -488,13 +564,15 @@ impl RustProxy {
if let Some(acme_client) = acme_client {
let challenge_server_ref = &challenge_server;
let result = acme_client.provision(domain, |pending| {
challenge_server_ref.set_challenge(
pending.token.clone(),
pending.key_authorization.clone(),
);
async move { Ok(()) }
}).await;
let result = acme_client
.provision(domain, |pending| {
challenge_server_ref.set_challenge(
pending.token.clone(),
pending.key_authorization.clone(),
);
async move { Ok(()) }
})
.await;
match result {
Ok((cert_pem, key_pem)) => {
@@ -539,7 +617,10 @@ impl RustProxy {
None => return,
};
let auto_renew = self.options.acme.as_ref()
let auto_renew = self
.options
.acme
.as_ref()
.and_then(|a| a.auto_renew)
.unwrap_or(true);
@@ -547,11 +628,17 @@ impl RustProxy {
return;
}
let check_interval_hours = self.options.acme.as_ref()
let check_interval_hours = self
.options
.acme
.as_ref()
.and_then(|a| a.renew_check_interval_hours)
.unwrap_or(24);
let acme_port = self.options.acme.as_ref()
let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
@@ -664,17 +751,19 @@ impl RustProxy {
/// Update routes atomically (hot-reload).
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
// Validate new routes
rustproxy_config::validate_routes(&routes)
.map_err(|errors| {
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
})?;
rustproxy_config::validate_routes(&routes).map_err(|errors| {
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
})?;
let new_manager = RouteManager::new(routes.clone());
let new_ports = new_manager.listening_ports();
info!("Updating routes: {} routes on {} ports",
new_manager.route_count(), new_ports.len());
info!(
"Updating routes: {} routes on {} ports",
new_manager.route_count(),
new_ports.len()
);
// Get old ports
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
@@ -684,28 +773,35 @@ impl RustProxy {
};
// Prune per-route metrics for route IDs that no longer exist
let active_route_ids: HashSet<String> = routes.iter()
.filter_map(|r| r.id.clone())
.collect();
let active_route_ids: HashSet<String> =
routes.iter().filter_map(|r| r.id.clone()).collect();
self.metrics.retain_routes(&active_route_ids);
// Prune per-backend metrics for backends no longer in any route target.
// For PortSpec::Preserve routes, expand across all listening ports since
// the actual runtime port depends on the incoming connection.
let listening_ports = self.get_listening_ports();
let active_backends: HashSet<String> = routes.iter()
let active_backends: HashSet<String> = routes
.iter()
.filter_map(|r| r.action.targets.as_ref())
.flat_map(|targets| targets.iter())
.flat_map(|target| {
let hosts: Vec<String> = target.host.to_vec().into_iter().map(|s| s.to_string()).collect();
let hosts: Vec<String> = target
.host
.to_vec()
.into_iter()
.map(|s| s.to_string())
.collect();
match &target.port {
rustproxy_config::PortSpec::Fixed(p) => {
hosts.into_iter().map(|h| format!("{}:{}", h, p)).collect::<Vec<_>>()
}
rustproxy_config::PortSpec::Fixed(p) => hosts
.into_iter()
.map(|h| format!("{}:{}", h, p))
.collect::<Vec<_>>(),
_ => {
// Preserve/special: expand across all listening ports
let lp = &listening_ports;
hosts.into_iter()
hosts
.into_iter()
.flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p)))
.collect::<Vec<_>>()
}
@@ -733,10 +829,13 @@ impl RustProxy {
let cm = cm_arc.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
tls_configs.insert(
domain.clone(),
TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
}
}
}
@@ -753,7 +852,9 @@ impl RustProxy {
// Cancel connections on routes that were removed or disabled
listener.invalidate_removed_routes(&active_route_ids);
// Clean up registry entries for removed routes
listener.connection_registry().cleanup_removed_routes(&active_route_ids);
listener
.connection_registry()
.cleanup_removed_routes(&active_route_ids);
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
listener.prune_http_proxy_caches(&active_route_ids);
@@ -766,9 +867,10 @@ impl RustProxy {
None => continue,
};
// Find corresponding old route
let old_route = old_manager.routes().iter().find(|r| {
r.id.as_deref() == Some(new_id)
});
let old_route = old_manager
.routes()
.iter()
.find(|r| r.id.as_deref() == Some(new_id));
let old_route = match old_route {
Some(r) => r,
None => continue, // new route, no existing connections to recycle
@@ -812,11 +914,13 @@ impl RustProxy {
{
let mut new_udp_ports = HashSet::new();
for route in &routes {
if !route.is_enabled() { continue; }
if !route.is_enabled() {
continue;
}
let transport = route.route_match.transport.as_ref();
match transport {
Some(rustproxy_config::TransportProtocol::Udp) |
Some(rustproxy_config::TransportProtocol::All) => {
Some(rustproxy_config::TransportProtocol::Udp)
| Some(rustproxy_config::TransportProtocol::All) => {
for port in route.route_match.ports.to_ports() {
new_udp_ports.insert(port);
}
@@ -825,7 +929,8 @@ impl RustProxy {
}
}
let old_udp_ports: HashSet<u16> = self.udp_listener_manager
let old_udp_ports: HashSet<u16> = self
.udp_listener_manager
.as_ref()
.map(|u| u.listening_ports().into_iter().collect())
.unwrap_or_default();
@@ -847,6 +952,7 @@ impl RustProxy {
connection_registry,
);
udp_mgr.set_proxy_ips(conn_config.proxy_ips);
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
// Wire up H3ProxyService so QUIC connections can serve HTTP/3
let http_proxy = listener.http_proxy().clone();
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
@@ -898,56 +1004,77 @@ impl RustProxy {
/// Provision a certificate for a named route.
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
let cm_arc = self.cert_manager.as_ref()
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
let cm_arc = self.cert_manager.as_ref().ok_or_else(|| {
anyhow::anyhow!("No certificate manager configured (ACME not enabled)")
})?;
// Find the route by name
let route = self.options.routes.iter()
let route = self
.options
.routes
.iter()
.find(|r| r.name.as_deref() == Some(route_name))
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
let domain = route.route_match.domains.as_ref()
let domain = route
.route_match
.domains
.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
info!(
"Provisioning certificate for route '{}' (domain: {})",
route_name, domain
);
// Start challenge server
let acme_port = self.options.acme.as_ref()
let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut cs = challenge_server::ChallengeServer::new();
cs.start(acme_port).await
cs.start(acme_port)
.await
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
let cs_ref = &cs;
let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(&domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
}).await;
let result = cm
.renew_domain(&domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
})
.await;
drop(cm);
cs.stop().await;
let bundle = result
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
let bundle = result.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
// Hot-swap into TLS configs
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
tls_configs.insert(
domain.clone(),
TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
{
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
tls_configs.insert(
d.clone(),
TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
},
);
}
}
}
@@ -966,7 +1093,10 @@ impl RustProxy {
}
}
info!("Certificate provisioned and loaded for route '{}'", route_name);
info!(
"Certificate provisioned and loaded for route '{}'",
route_name
);
Ok(())
}
@@ -978,10 +1108,16 @@ impl RustProxy {
/// Get the status of a certificate for a named route.
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
let route = self.options.routes.iter()
let route = self
.options
.routes
.iter()
.find(|r| r.name.as_deref() == Some(route_name))?;
let domain = route.route_match.domains.as_ref()
let domain = route
.route_match
.domains
.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
if let Some(ref cm_arc) = self.cert_manager {
@@ -1010,8 +1146,9 @@ impl RustProxy {
let mut metrics = self.metrics.snapshot();
if let Some(ref lm) = self.listener_manager {
let entries = lm.http_proxy().protocol_cache_snapshot();
metrics.detected_protocols = entries.into_iter().map(|e| {
rustproxy_metrics::ProtocolCacheEntryMetric {
metrics.detected_protocols = entries
.into_iter()
.map(|e| rustproxy_metrics::ProtocolCacheEntryMetric {
host: e.host,
port: e.port,
domain: e.domain,
@@ -1026,8 +1163,8 @@ impl RustProxy {
h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs,
h2_consecutive_failures: e.h2_consecutive_failures,
h3_consecutive_failures: e.h3_consecutive_failures,
}
}).collect();
})
.collect();
}
metrics
}
@@ -1058,9 +1195,7 @@ impl RustProxy {
/// Get statistics snapshot.
pub fn get_statistics(&self) -> Statistics {
let uptime = self.started_at
.map(|t| t.elapsed().as_secs())
.unwrap_or(0);
let uptime = self.started_at.map(|t| t.elapsed().as_secs()).unwrap_or(0);
Statistics {
active_connections: self.metrics.active_connections(),
@@ -1071,6 +1206,13 @@ impl RustProxy {
}
}
/// Update the global ingress security policy.
pub fn set_security_policy(&mut self, policy: rustproxy_config::SecurityPolicy) {
self.security_policy
.store(Arc::new(Self::build_ip_block_list(Some(&policy))));
self.options.security_policy = Some(policy);
}
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
/// take effect immediately for all new connections.
@@ -1130,10 +1272,13 @@ impl RustProxy {
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !configs.contains_key(d) {
configs.insert(d.clone(), TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
configs.insert(
d.clone(),
TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
},
);
}
}
}
@@ -1166,7 +1311,8 @@ impl RustProxy {
info!("Loading certificate for domain: {}", domain);
// Check if the cert actually changed (for selective connection recycling)
let cert_changed = self.loaded_certs
let cert_changed = self
.loaded_certs
.get(domain)
.map(|existing| existing.cert_pem != cert_pem)
.unwrap_or(false); // new domain = no existing connections to recycle
@@ -1196,10 +1342,13 @@ impl RustProxy {
}
// Persist in loaded_certs so future rebuild calls include this cert
self.loaded_certs.insert(domain.to_string(), TlsCertConfig {
cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
});
self.loaded_certs.insert(
domain.to_string(),
TlsCertConfig {
cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
},
);
// Hot-swap TLS config on TCP and QUIC listeners
let tls_configs = self.current_tls_configs().await;
@@ -1222,7 +1371,9 @@ impl RustProxy {
// Recycle existing connections if cert actually changed
if cert_changed {
if let Some(ref listener) = self.listener_manager {
listener.connection_registry().recycle_for_cert_change(domain);
listener
.connection_registry()
.recycle_for_cert_change(domain);
}
}
@@ -1244,16 +1395,22 @@ impl RustProxy {
continue;
}
let cert_spec = route.action.tls.as_ref()
let cert_spec = route
.action
.tls
.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() {
configs.insert(domain.to_string(), TlsCertConfig {
cert_pem: cert_config.cert.clone(),
key_pem: cert_config.key.clone(),
});
configs.insert(
domain.to_string(),
TlsCertConfig {
cert_pem: cert_config.cert.clone(),
key_pem: cert_config.key.clone(),
},
);
}
}
}
+4 -9
View File
@@ -1,12 +1,12 @@
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use anyhow::Result;
use clap::Parser;
use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustproxy::RustProxy;
use rustproxy::management;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
/// RustProxy - High-performance multi-protocol proxy
@@ -43,8 +43,7 @@ async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
)
.init();
@@ -60,11 +59,7 @@ async fn main() -> Result<()> {
let options = RustProxyOptions::from_file(&cli.config)
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
tracing::info!(
"Loaded {} routes from {}",
options.routes.len(),
cli.config
);
tracing::info!("Loaded {} routes from {}", options.routes.len(), cli.config);
// Validate-only mode
if cli.validate {
+157 -65
View File
@@ -1,7 +1,7 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error};
use tracing::{error, info};
use crate::RustProxy;
use rustproxy_config::RustProxyOptions;
@@ -141,14 +141,19 @@ async fn handle_request(
"start" => handle_start(&id, &request.params, proxy).await,
"stop" => handle_stop(&id, proxy).await,
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
"setSecurityPolicy" => handle_set_security_policy(&id, &request.params, proxy),
"getMetrics" => handle_get_metrics(&id, proxy),
"getStatistics" => handle_get_statistics(&id, proxy),
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await,
"setSocketHandlerRelay" => {
handle_set_socket_handler_relay(&id, &request.params, proxy).await
}
"setDatagramHandlerRelay" => {
handle_set_datagram_handler_relay(&id, &request.params, proxy).await
}
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
@@ -167,7 +172,12 @@ async fn handle_start(
let config = match params.get("config") {
Some(config) => config,
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'config' parameter".to_string(),
)
}
};
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
@@ -176,38 +186,31 @@ async fn handle_start(
};
match RustProxy::new(options) {
Ok(mut p) => {
match p.start().await {
Ok(()) => {
send_event("started", serde_json::json!({}));
*proxy = Some(p);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
Ok(mut p) => match p.start().await {
Ok(()) => {
send_event("started", serde_json::json!({}));
*proxy = Some(p);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
}
Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
},
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
}
}
async fn handle_stop(
id: &str,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
async fn handle_stop(id: &str, proxy: &mut Option<RustProxy>) -> ManagementResponse {
match proxy.as_mut() {
Some(p) => {
match p.stop().await {
Ok(()) => {
*proxy = None;
send_event("stopped", serde_json::json!({}));
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
Some(p) => match p.stop().await {
Ok(()) => {
*proxy = None;
send_event("stopped", serde_json::json!({}));
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
},
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
@@ -224,7 +227,12 @@ async fn handle_update_routes(
let routes = match params.get("routes") {
Some(routes) => routes,
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routes' parameter".to_string(),
)
}
};
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
@@ -234,36 +242,72 @@ async fn handle_update_routes(
match p.update_routes(routes).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
Err(e) => {
ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e))
}
}
}
fn handle_get_metrics(
fn handle_set_security_policy(
id: &str,
proxy: &Option<RustProxy>,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let policy = match params.get("policy") {
Some(policy) => policy,
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'policy' parameter".to_string(),
)
}
};
let policy: rustproxy_config::SecurityPolicy = match serde_json::from_value(policy.clone()) {
Ok(policy) => policy,
Err(e) => {
return ManagementResponse::err(
id.to_string(),
format!("Invalid security policy: {}", e),
)
}
};
p.set_security_policy(policy);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
fn handle_get_metrics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let metrics = p.get_metrics();
match serde_json::to_value(&metrics) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize metrics: {}", e),
),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}
}
fn handle_get_statistics(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
fn handle_get_statistics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let stats = p.get_statistics();
match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize statistics: {}", e),
),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
@@ -282,12 +326,20 @@ async fn handle_provision_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.provision_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to provision certificate: {}", e),
),
}
}
@@ -303,12 +355,20 @@ async fn handle_renew_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.renew_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to renew certificate: {}", e),
),
}
}
@@ -324,24 +384,29 @@ async fn handle_get_certificate_status(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name,
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.get_certificate_status(route_name).await {
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
"domain": status.domain,
"source": status.source,
"expiresAt": status.expires_at,
"isValid": status.is_valid,
})),
Some(status) => ManagementResponse::ok(
id.to_string(),
serde_json::json!({
"domain": status.domain,
"source": status.source,
"expiresAt": status.expires_at,
"isValid": status.is_valid,
}),
),
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
}
}
fn handle_get_listening_ports(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
fn handle_get_listening_ports(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let ports = p.get_listening_ports();
@@ -361,7 +426,8 @@ async fn handle_set_socket_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params.get("socketPath")
let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
@@ -381,7 +447,8 @@ async fn handle_set_datagram_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params.get("socketPath")
let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
@@ -403,12 +470,17 @@ async fn handle_add_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
};
match p.add_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to add port {}: {}", port, e),
),
}
}
@@ -424,12 +496,17 @@ async fn handle_remove_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
};
match p.remove_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to remove port {}: {}", port, e),
),
}
}
@@ -445,26 +522,41 @@ async fn handle_load_certificate(
let domain = match params.get("domain").and_then(|v| v.as_str()) {
Some(d) => d.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'domain' parameter".to_string(),
)
}
};
let cert = match params.get("cert").and_then(|v| v.as_str()) {
Some(c) => c.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string())
}
};
let key = match params.get("key").and_then(|v| v.as_str()) {
Some(k) => k.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string())
}
};
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
let ca = params
.get("ca")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("loadCertificate: domain={}", domain);
// Load cert into cert manager and hot-swap TLS config
match p.load_certificate(&domain, cert, key, ca).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to load certificate for {}: {}", domain, e),
),
}
}
+8 -8
View File
@@ -136,7 +136,8 @@ pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandl
let path = parts.get(1).copied().unwrap_or("/");
// Extract Host header
let host = req_str.lines()
let host = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("host:"))
.map(|l| l[5..].trim())
.unwrap_or("unknown");
@@ -336,7 +337,8 @@ pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
let req_str = String::from_utf8_lossy(&buf[..n]);
// Extract Sec-WebSocket-Key for proper handshake
let ws_key = req_str.lines()
let ws_key = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
.unwrap_or_default();
@@ -378,7 +380,9 @@ pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
use rcgen::{CertificateParams, KeyPair};
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
params
.distinguished_name
.push(rcgen::DnType::CommonName, domain);
let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap();
@@ -458,11 +462,7 @@ pub fn make_tls_terminate_route(
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
pub async fn start_tls_ws_echo_backend(
port: u16,
cert_pem: &str,
key_pem: &str,
) -> JoinHandle<()> {
pub async fn start_tls_ws_echo_backend(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
@@ -1,9 +1,9 @@
mod common;
use bytes::Buf;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic};
use bytes::Buf;
use rustproxy_config::{RouteQuic, RouteUdp, RustProxyOptions, TransportProtocol};
use std::sync::Arc;
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
@@ -14,7 +14,14 @@ fn make_h3_route(
cert_pem: &str,
key_pem: &str,
) -> rustproxy_config::RouteConfig {
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem);
let mut route = make_tls_terminate_route(
port,
"localhost",
target_host,
target_port,
cert_pem,
key_pem,
);
route.route_match.transport = Some(TransportProtocol::All);
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
route.action.udp = Some(RouteUdp {
@@ -89,11 +96,9 @@ async fn test_h3_response_stream_finishes() {
.await
.expect("QUIC handshake failed");
let (mut driver, mut send_request) = h3::client::new(
h3_quinn::Connection::new(connection),
)
.await
.expect("H3 connection setup failed");
let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(connection))
.await
.expect("H3 connection setup failed");
// Drive the H3 connection in background
tokio::spawn(async move {
@@ -108,33 +113,46 @@ async fn test_h3_response_stream_finishes() {
.body(())
.unwrap();
let mut stream = send_request.send_request(req).await
let mut stream = send_request
.send_request(req)
.await
.expect("Failed to send H3 request");
stream.finish().await
stream
.finish()
.await
.expect("Failed to finish sending H3 request body");
// 6. Read response headers
let resp = stream.recv_response().await
let resp = stream
.recv_response()
.await
.expect("Failed to receive H3 response");
assert_eq!(resp.status(), http::StatusCode::OK,
"Expected 200 OK, got {}", resp.status());
assert_eq!(
resp.status(),
http::StatusCode::OK,
"Expected 200 OK, got {}",
resp.status()
);
// 7. Read body and verify stream ends (FIN received)
// This is the critical assertion: recv_data() must return None (stream ended)
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
let result = with_timeout(async {
let mut total = 0usize;
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
total += chunk.remaining();
}
// recv_data() returned None => stream ended (FIN received)
total
}, 10)
let result = with_timeout(
async {
let mut total = 0usize;
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
total += chunk.remaining();
}
// recv_data() returned None => stream ended (FIN received)
total
},
10,
)
.await;
let bytes_received = result.expect(
"TIMEOUT: H3 stream never ended (FIN not received by client). \
The proxy sent all response data but failed to send the QUIC stream FIN."
The proxy sent all response data but failed to send the QUIC stream FIN.",
);
assert_eq!(
bytes_received,
@@ -43,17 +43,32 @@ async fn test_http_forward_basic() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
let body = extract_body(&response);
body.to_string()
}, 10)
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
let body = extract_body(&response);
body.to_string()
},
10,
)
.await
.unwrap();
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
assert!(
result.contains(r#""method":"GET"#),
"Expected GET method, got: {}",
result
);
assert!(
result.contains(r#""path":"/hello"#),
"Expected /hello path, got: {}",
result
);
assert!(
result.contains(r#""backend":"main"#),
"Expected main backend, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -69,8 +84,18 @@ async fn test_http_forward_host_routing() {
let options = RustProxyOptions {
routes: vec![
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
make_test_route(
proxy_port,
Some("alpha.example.com"),
"127.0.0.1",
backend1_port,
),
make_test_route(
proxy_port,
Some("beta.example.com"),
"127.0.0.1",
backend2_port,
),
],
..Default::default()
};
@@ -80,24 +105,38 @@ async fn test_http_forward_host_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let alpha_result = with_timeout(async {
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
let alpha_result = with_timeout(
async {
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
extract_body(&response).to_string()
},
10,
)
.await
.unwrap();
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
assert!(
alpha_result.contains(r#""backend":"alpha"#),
"Expected alpha backend, got: {}",
alpha_result
);
// Test beta domain
let beta_result = with_timeout(async {
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
let beta_result = with_timeout(
async {
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
extract_body(&response).to_string()
},
10,
)
.await
.unwrap();
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
assert!(
beta_result.contains(r#""backend":"beta"#),
"Expected beta backend, got: {}",
beta_result
);
proxy.stop().await.unwrap();
}
@@ -127,24 +166,38 @@ async fn test_http_forward_path_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test API path
let api_result = with_timeout(async {
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
extract_body(&response).to_string()
}, 10)
let api_result = with_timeout(
async {
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
extract_body(&response).to_string()
},
10,
)
.await
.unwrap();
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
assert!(
api_result.contains(r#""backend":"api"#),
"Expected api backend, got: {}",
api_result
);
// Test web path (no /api prefix)
let web_result = with_timeout(async {
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
extract_body(&response).to_string()
}, 10)
let web_result = with_timeout(
async {
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
extract_body(&response).to_string()
},
10,
)
.await
.unwrap();
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
assert!(
web_result.contains(r#""backend":"web"#),
"Expected web backend, got: {}",
web_result
);
proxy.stop().await.unwrap();
}
@@ -184,9 +237,18 @@ async fn test_http_forward_cors_preflight() {
.unwrap();
// Should get 204 No Content with CORS headers
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
assert!(result.to_lowercase().contains("access-control-allow-origin"),
"Expected CORS header, got: {}", result);
assert!(
result.contains("204"),
"Expected 204 status, got: {}",
result
);
assert!(
result
.to_lowercase()
.contains("access-control-allow-origin"),
"Expected CORS header, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -208,15 +270,22 @@ async fn test_http_forward_backend_error() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
response
}, 10)
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
response
},
10,
)
.await
.unwrap();
// Proxy should relay the 500 from backend
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
assert!(
result.contains("500"),
"Expected 500 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -227,7 +296,12 @@ async fn test_http_forward_no_route_matched() {
// Create a route only for a specific domain
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
routes: vec![make_test_route(
proxy_port,
Some("known.example.com"),
"127.0.0.1",
9999,
)],
..Default::default()
};
@@ -235,15 +309,22 @@ async fn test_http_forward_no_route_matched() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
response
}, 10)
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
response
},
10,
)
.await
.unwrap();
// Should get 502 Bad Gateway (no route matched)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -262,15 +343,22 @@ async fn test_http_forward_backend_unavailable() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
response
}, 10)
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
response
},
10,
)
.await
.unwrap();
// Should get 502 Bad Gateway (backend unavailable)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -286,7 +374,12 @@ async fn test_https_terminate_http_forward() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -295,38 +388,53 @@ async fn test_https_terminate_http_forward() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send HTTP request through TLS
let request = format!(
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
domain
);
tls_stream.write_all(request.as_bytes()).await.unwrap();
// Send HTTP request through TLS
let request = format!(
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
domain
);
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
},
10,
)
.await
.unwrap();
let body = extract_body(&result);
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
assert!(
body.contains(r#""method":"GET"#),
"Expected GET, got: {}",
body
);
assert!(
body.contains(r#""path":"/api/data"#),
"Expected /api/data, got: {}",
body
);
assert!(
body.contains(r#""backend":"tls-backend"#),
"Expected tls-backend, got: {}",
body
);
proxy.stop().await.unwrap();
}
@@ -347,59 +455,68 @@ async fn test_websocket_through_proxy() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let result = with_timeout(
async {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
// Send WebSocket upgrade request
let request = format!(
"GET /ws HTTP/1.1\r\n\
// Send WebSocket upgrade request
let request = format!(
"GET /ws HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
\r\n"
);
stream.write_all(request.as_bytes()).await.unwrap();
);
stream.write_all(request.as_bytes()).await.unwrap();
// Read the 101 response
let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1];
loop {
let n = stream.read(&mut temp).await.unwrap();
if n == 0 { break; }
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len-4..] == *b"\r\n\r\n" {
// Read the 101 response
let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1];
loop {
let n = stream.read(&mut temp).await.unwrap();
if n == 0 {
break;
}
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len - 4..] == *b"\r\n\r\n" {
break;
}
}
}
}
let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
response_str
);
let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(
response_str.contains("101"),
"Expected 101 Switching Protocols, got: {}",
response_str
);
assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
response_str
);
// After upgrade, send data and verify echo
let test_data = b"Hello WebSocket!";
stream.write_all(test_data).await.unwrap();
// After upgrade, send data and verify echo
let test_data = b"Hello WebSocket!";
stream.write_all(test_data).await.unwrap();
// Read echoed data
let mut echo_buf = vec![0u8; 256];
let n = stream.read(&mut echo_buf).await.unwrap();
let echoed = &echo_buf[..n];
// Read echoed data
let mut echo_buf = vec![0u8; 256];
let n = stream.read(&mut echo_buf).await.unwrap();
let echoed = &echo_buf[..n];
assert_eq!(echoed, test_data, "Expected echo of sent data");
assert_eq!(echoed, test_data, "Expected echo of sent data");
"ok".to_string()
}, 10)
"ok".to_string()
},
10,
)
.await
.unwrap();
@@ -431,12 +548,22 @@ async fn test_terminate_and_reencrypt_http_routing() {
// Create terminate-and-reencrypt routes
let mut route1 = make_tls_terminate_route(
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1,
proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
);
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let mut route2 = make_tls_terminate_route(
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2,
proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
);
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -450,27 +577,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt
let alpha_result = with_timeout(async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let alpha_result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let request =
"GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
},
10,
)
.await
.unwrap();
@@ -498,27 +630,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
);
// Test beta domain - different host goes to different backend
let beta_result = with_timeout(async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let beta_result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let request =
"GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
},
10,
)
.await
.unwrap();
@@ -589,14 +726,12 @@ async fn test_terminate_and_reencrypt_websocket() {
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector =
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name =
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send WebSocket upgrade request through TLS
@@ -685,10 +820,13 @@ async fn test_protocol_field_in_route_config() {
assert!(wait_for_port(proxy_port, 2000).await);
// HTTP request should match the route and get proxied
let result = with_timeout(async {
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
extract_body(&response).to_string()
}, 10)
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
extract_body(&response).to_string()
},
10,
)
.await
.unwrap();
@@ -20,13 +20,19 @@ async fn test_start_and_stop() {
assert!(!wait_for_port(port, 200).await);
proxy.start().await.unwrap();
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
assert!(
wait_for_port(port, 2000).await,
"Port should be listening after start"
);
proxy.stop().await.unwrap();
// Give the OS a moment to release the port
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
assert!(
!wait_for_port(port, 200).await,
"Port should not be listening after stop"
);
}
#[tokio::test]
@@ -54,7 +60,12 @@ async fn test_update_routes_hot_reload() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
routes: vec![make_test_route(
port,
Some("old.example.com"),
"127.0.0.1",
8080,
)],
..Default::default()
};
@@ -62,9 +73,12 @@ async fn test_update_routes_hot_reload() {
proxy.start().await.unwrap();
// Update routes atomically
let new_routes = vec![
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
];
let new_routes = vec![make_test_route(
port,
Some("new.example.com"),
"127.0.0.1",
9090,
)];
let result = proxy.update_routes(new_routes).await;
assert!(result.is_ok());
@@ -87,15 +101,24 @@ async fn test_add_remove_listening_port() {
// Add a new port
proxy.add_listening_port(port2).await.unwrap();
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
assert!(
wait_for_port(port2, 2000).await,
"New port should be listening"
);
// Remove the port
proxy.remove_listening_port(port2).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
assert!(
!wait_for_port(port2, 200).await,
"Removed port should not be listening"
);
// Original port should still be listening
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
assert!(
wait_for_port(port1, 200).await,
"Original port should still be listening"
);
proxy.stop().await.unwrap();
}
@@ -168,7 +191,11 @@ async fn test_metrics_track_connections() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
assert!(
stats.total_connections > 0,
"Expected total_connections > 0, got {}",
stats.total_connections
);
proxy.stop().await.unwrap();
}
@@ -205,8 +232,11 @@ async fn test_metrics_track_bytes() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0,
"Expected some connections tracked, got {}", stats.total_connections);
assert!(
stats.total_connections > 0,
"Expected some connections tracked, got {}",
stats.total_connections
);
proxy.stop().await.unwrap();
}
@@ -228,23 +258,38 @@ async fn test_hot_reload_port_changes() {
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await);
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
assert!(
!wait_for_port(port2, 200).await,
"port2 should not be listening yet"
);
// Update routes to use port2 instead
let new_routes = vec![
make_test_route(port2, None, "127.0.0.1", backend_port),
];
let new_routes = vec![make_test_route(port2, None, "127.0.0.1", backend_port)];
proxy.update_routes(new_routes).await.unwrap();
// Port2 should now be listening, port1 should be closed
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
assert!(
wait_for_port(port2, 2000).await,
"port2 should be listening after reload"
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
assert!(
!wait_for_port(port1, 200).await,
"port1 should be closed after reload"
);
// Verify port2 works
let ports = proxy.get_listening_ports();
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
assert!(
ports.contains(&port2),
"Expected port2 in listening ports: {:?}",
ports
);
assert!(
!ports.contains(&port1),
"port1 should not be in listening ports: {:?}",
ports
);
proxy.stop().await.unwrap();
}
@@ -24,19 +24,25 @@ async fn test_tcp_forward_echo() {
proxy.start().await.unwrap();
// Wait for proxy to be ready
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
assert!(
wait_for_port(proxy_port, 2000).await,
"Proxy port not ready"
);
// Connect and send data
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"hello world").await.unwrap();
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"hello world").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
5,
)
.await
.unwrap();
@@ -61,21 +67,24 @@ async fn test_tcp_forward_large_payload() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
// Send 1MB of data
let data = vec![b'A'; 1_000_000];
stream.write_all(&data).await.unwrap();
stream.shutdown().await.unwrap();
// Send 1MB of data
let data = vec![b'A'; 1_000_000];
stream.write_all(&data).await.unwrap();
stream.shutdown().await.unwrap();
// Read all back
let mut received = Vec::new();
stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 10)
// Read all back
let mut received = Vec::new();
stream.read_to_end(&mut received).await.unwrap();
received.len()
},
10,
)
.await
.unwrap();
@@ -100,29 +109,32 @@ async fn test_tcp_forward_multiple_connections() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
handles.push(tokio::spawn(async move {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let msg = format!("connection-{}", i);
stream.write_all(msg.as_bytes()).await.unwrap();
let result = with_timeout(
async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
handles.push(tokio::spawn(async move {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let msg = format!("connection-{}", i);
stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}));
}
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}));
}
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap());
}
results
}, 10)
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap());
}
results
},
10,
)
.await
.unwrap();
@@ -149,14 +161,20 @@ async fn test_tcp_forward_backend_unreachable() {
assert!(wait_for_port(proxy_port, 2000).await);
// Connection should complete (proxy accepts it) but data should not flow
let result = with_timeout(async {
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
stream.is_ok()
}, 5)
let result = with_timeout(
async {
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
stream.is_ok()
},
5,
)
.await
.unwrap();
assert!(result, "Should be able to connect to proxy even if backend is down");
assert!(
result,
"Should be able to connect to proxy even if backend is down"
);
proxy.stop().await.unwrap();
}
@@ -178,16 +196,19 @@ async fn test_tcp_forward_bidirectional() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"test data").await.unwrap();
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"test data").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
5,
)
.await
.unwrap();
@@ -65,8 +65,18 @@ async fn test_tls_passthrough_sni_routing() {
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
make_tls_passthrough_route(
proxy_port,
Some("one.example.com"),
"127.0.0.1",
backend1_port,
),
make_tls_passthrough_route(
proxy_port,
Some("two.example.com"),
"127.0.0.1",
backend2_port,
),
],
..Default::default()
};
@@ -76,39 +86,53 @@ async fn test_tls_passthrough_sni_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Send a fake ClientHello with SNI "one.example.com"
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("one.example.com");
stream.write_all(&hello).await.unwrap();
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("one.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
5,
)
.await
.unwrap();
// Backend1 should have received the ClientHello and prefixed its response
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
assert!(
result.starts_with("BACKEND1:"),
"Expected BACKEND1 prefix, got: {}",
result
);
// Now test routing to backend2
let result2 = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("two.example.com");
stream.write_all(&hello).await.unwrap();
let result2 = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("two.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
5,
)
.await
.unwrap();
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
assert!(
result2.starts_with("BACKEND2:"),
"Expected BACKEND2 prefix, got: {}",
result2
);
proxy.stop().await.unwrap();
}
@@ -121,9 +145,12 @@ async fn test_tls_passthrough_unknown_sni() {
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
],
routes: vec![make_tls_passthrough_route(
proxy_port,
Some("known.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default()
};
@@ -132,21 +159,24 @@ async fn test_tls_passthrough_unknown_sni() {
assert!(wait_for_port(proxy_port, 2000).await);
// Send ClientHello with unknown SNI - should get no response (connection dropped)
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("unknown.example.com");
stream.write_all(&hello).await.unwrap();
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("unknown.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
// Should either get 0 bytes (closed) or an error
match stream.read(&mut buf).await {
Ok(0) => true, // Connection closed = no route matched
Ok(_) => false, // Got data = route shouldn't have matched
Err(_) => true, // Error = connection dropped
}
}, 5)
let mut buf = vec![0u8; 4096];
// Should either get 0 bytes (closed) or an error
match stream.read(&mut buf).await {
Ok(0) => true, // Connection closed = no route matched
Ok(_) => false, // Got data = route shouldn't have matched
Err(_) => true, // Error = connection dropped
}
},
5,
)
.await
.unwrap();
@@ -163,9 +193,12 @@ async fn test_tls_passthrough_wildcard_domain() {
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
],
routes: vec![make_tls_passthrough_route(
proxy_port,
Some("*.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default()
};
@@ -174,21 +207,28 @@ async fn test_tls_passthrough_wildcard_domain() {
assert!(wait_for_port(proxy_port, 2000).await);
// Should match any subdomain of example.com
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("anything.example.com");
stream.write_all(&hello).await.unwrap();
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("anything.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
5,
)
.await
.unwrap();
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
assert!(
result.starts_with("WILDCARD:"),
"Expected WILDCARD prefix, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -222,24 +262,29 @@ async fn test_tls_passthrough_multiple_domains() {
("beta.example.com", "B2:"),
("gamma.example.com", "B3:"),
] {
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello(domain);
stream.write_all(&hello).await.unwrap();
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello(domain);
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
5,
)
.await
.unwrap();
assert!(
result.starts_with(expected_prefix),
"Domain {} should route to {}, got: {}",
domain, expected_prefix, result
domain,
expected_prefix,
result
);
}
@@ -74,7 +74,12 @@ async fn test_tls_terminate_basic() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -84,23 +89,26 @@ async fn test_tls_terminate_basic() {
assert!(wait_for_port(proxy_port, 2000).await);
// Connect with TLS client
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello TLS").await.unwrap();
tls_stream.write_all(b"hello TLS").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
10,
)
.await
.unwrap();
@@ -125,7 +133,12 @@ async fn test_tls_terminate_and_reencrypt() {
// Create terminate-and-reencrypt route
let mut route = make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&proxy_cert,
&proxy_key,
);
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -138,23 +151,26 @@ async fn test_tls_terminate_and_reencrypt() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello reencrypt").await.unwrap();
tls_stream.write_all(b"hello reencrypt").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
10,
)
.await
.unwrap();
@@ -177,8 +193,22 @@ async fn test_tls_terminate_sni_cert_selection() {
let options = RustProxyOptions {
routes: vec![
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
make_tls_terminate_route(
proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
),
make_tls_terminate_route(
proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
),
],
..Default::default()
};
@@ -188,27 +218,35 @@ async fn test_tls_terminate_sni_cert_selection() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"test").await.unwrap();
tls_stream.write_all(b"test").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
},
10,
)
.await
.unwrap();
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
assert!(
result.starts_with("ALPHA:"),
"Expected ALPHA prefix, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -224,7 +262,12 @@ async fn test_tls_terminate_large_payload() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -233,26 +276,29 @@ async fn test_tls_terminate_large_payload() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send 1MB of data
let data = vec![b'X'; 1_000_000];
tls_stream.write_all(&data).await.unwrap();
tls_stream.shutdown().await.unwrap();
// Send 1MB of data
let data = vec![b'X'; 1_000_000];
tls_stream.write_all(&data).await.unwrap();
tls_stream.shutdown().await.unwrap();
let mut received = Vec::new();
tls_stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 15)
let mut received = Vec::new();
tls_stream.read_to_end(&mut received).await.unwrap();
received.len()
},
15,
)
.await
.unwrap();
@@ -272,7 +318,12 @@ async fn test_tls_terminate_concurrent() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -281,37 +332,40 @@ async fn test_tls_terminate_concurrent() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
let dom = domain.to_string();
handles.push(tokio::spawn(async move {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let result = with_timeout(
async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
let dom = domain.to_string();
handles.push(tokio::spawn(async move {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let msg = format!("conn-{}", i);
tls_stream.write_all(msg.as_bytes()).await.unwrap();
let msg = format!("conn-{}", i);
tls_stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}));
}
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}));
}
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap());
}
results
}, 15)
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap());
}
results
},
15,
)
.await
.unwrap();