feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates

This commit is contained in:
2026-02-09 10:55:46 +00:00
parent a31fee41df
commit 1df3b7af4a
151 changed files with 16927 additions and 19432 deletions

View File

@@ -0,0 +1,44 @@
[package]
name = "rustproxy"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "High-performance multi-protocol proxy built on Pingora, compatible with SmartProxy configuration"
[[bin]]
name = "rustproxy"
path = "src/main.rs"
[lib]
name = "rustproxy"
path = "src/lib.rs"
[dependencies]
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-tls = { workspace = true }
rustproxy-passthrough = { workspace = true }
rustproxy-http = { workspace = true }
rustproxy-nftables = { workspace = true }
rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
clap = { workspace = true }
anyhow = { workspace = true }
arc-swap = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
rustls = { workspace = true }
tokio-rustls = { workspace = true }
tokio-util = { workspace = true }
dashmap = { workspace = true }
hyper = { workspace = true }
hyper-util = { workspace = true }
http-body-util = { workspace = true }
bytes = { workspace = true }
[dev-dependencies]
rcgen = { workspace = true }

View File

@@ -0,0 +1,177 @@
//! HTTP-01 ACME challenge server.
//!
//! A lightweight HTTP server that serves ACME challenge responses at
//! `/.well-known/acme-challenge/<token>`.
use std::sync::Arc;
use bytes::Bytes;
use dashmap::DashMap;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, error};
/// ACME HTTP-01 challenge server.
pub struct ChallengeServer {
/// Token -> key authorization mapping
challenges: Arc<DashMap<String, String>>,
/// Cancellation token to stop the server
cancel: CancellationToken,
/// Server task handle
handle: Option<tokio::task::JoinHandle<()>>,
}
impl ChallengeServer {
/// Create a new challenge server (not yet started).
pub fn new() -> Self {
Self {
challenges: Arc::new(DashMap::new()),
cancel: CancellationToken::new(),
handle: None,
}
}
/// Register a challenge token -> key_authorization mapping.
pub fn set_challenge(&self, token: String, key_authorization: String) {
debug!("Registered ACME challenge: token={}", token);
self.challenges.insert(token, key_authorization);
}
/// Remove a challenge token.
pub fn remove_challenge(&self, token: &str) {
self.challenges.remove(token);
}
/// Start the challenge server on the given port.
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = format!("0.0.0.0:{}", port);
let listener = TcpListener::bind(&addr).await?;
info!("ACME challenge server listening on port {}", port);
let challenges = Arc::clone(&self.challenges);
let cancel = self.cancel.clone();
let handle = tokio::spawn(async move {
loop {
tokio::select! {
_ = cancel.cancelled() => {
info!("ACME challenge server stopping");
break;
}
result = listener.accept() => {
match result {
Ok((stream, _)) => {
let challenges = Arc::clone(&challenges);
tokio::spawn(async move {
let io = TokioIo::new(stream);
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
let challenges = Arc::clone(&challenges);
async move {
Self::handle_request(req, &challenges)
}
});
let conn = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service);
if let Err(e) = conn.await {
debug!("Challenge server connection error: {}", e);
}
});
}
Err(e) => {
error!("Challenge server accept error: {}", e);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
}
}
}
});
self.handle = Some(handle);
Ok(())
}
/// Stop the challenge server.
pub async fn stop(&mut self) {
self.cancel.cancel();
if let Some(handle) = self.handle.take() {
let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
handle,
).await;
}
self.challenges.clear();
self.cancel = CancellationToken::new();
info!("ACME challenge server stopped");
}
/// Handle an HTTP request for ACME challenges.
fn handle_request(
req: Request<Incoming>,
challenges: &DashMap<String, String>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
let path = req.uri().path();
if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
if let Some(key_auth) = challenges.get(token) {
debug!("Serving ACME challenge for token: {}", token);
return Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/plain")
.body(Full::new(Bytes::from(key_auth.value().clone())))
.unwrap());
}
}
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found")))
.unwrap())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_challenge_server_lifecycle() {
let mut server = ChallengeServer::new();
// Set a challenge before starting
server.set_challenge("test-token".to_string(), "test-key-auth".to_string());
// Start on a random port
server.start(19900).await.unwrap();
// Give server a moment to start
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Fetch the challenge
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
let io = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move { let _ = conn.await; });
let req = Request::get("/.well-known/acme-challenge/test-token")
.body(Full::new(Bytes::new()))
.unwrap();
let resp = sender.send_request(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// Test 404 for unknown token
let req = Request::get("/.well-known/acme-challenge/unknown")
.body(Full::new(Bytes::new()))
.unwrap();
let resp = sender.send_request(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
server.stop().await;
}
}

View File

@@ -0,0 +1,931 @@
//! # RustProxy
//!
//! High-performance multi-protocol proxy built on Rust,
//! compatible with SmartProxy configuration.
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use rustproxy::RustProxy;
//! use rustproxy_config::{RustProxyOptions, create_https_passthrough_route};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! let options = RustProxyOptions {
//! routes: vec![
//! create_https_passthrough_route("example.com", "backend", 443),
//! ],
//! ..Default::default()
//! };
//!
//! let mut proxy = RustProxy::new(options)?;
//! proxy.start().await?;
//! Ok(())
//! }
//! ```
pub mod challenge_server;
pub mod management;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwap;
use anyhow::Result;
use tracing::{info, warn, debug, error};
// Re-export key types
pub use rustproxy_config;
pub use rustproxy_routing;
pub use rustproxy_passthrough;
pub use rustproxy_tls;
pub use rustproxy_http;
pub use rustproxy_nftables;
pub use rustproxy_metrics;
pub use rustproxy_security;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine};
use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, TlsCertConfig, ConnectionConfig};
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use rustproxy_nftables::{NftManager, rule_builder};
/// Certificate status.
#[derive(Debug, Clone)]
pub struct CertStatus {
pub domain: String,
pub source: String,
pub expires_at: u64,
pub is_valid: bool,
}
/// The main RustProxy struct.
/// This is the primary public API matching SmartProxy's interface.
pub struct RustProxy {
options: RustProxyOptions,
route_table: ArcSwap<RouteManager>,
listener_manager: Option<TcpListenerManager>,
metrics: Arc<MetricsCollector>,
cert_manager: Option<Arc<tokio::sync::Mutex<CertManager>>>,
challenge_server: Option<challenge_server::ChallengeServer>,
renewal_handle: Option<tokio::task::JoinHandle<()>>,
nft_manager: Option<NftManager>,
started: bool,
started_at: Option<Instant>,
/// Path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
socket_handler_relay_path: Option<String>,
}
impl RustProxy {
/// Create a new RustProxy instance with the given configuration.
pub fn new(mut options: RustProxyOptions) -> Result<Self> {
// Apply defaults to routes before validation
Self::apply_defaults(&mut options);
// Validate routes
if let Err(errors) = rustproxy_config::validate_routes(&options.routes) {
for err in &errors {
warn!("Route validation error: {}", err);
}
if !errors.is_empty() {
anyhow::bail!("Route validation failed with {} errors", errors.len());
}
}
let route_manager = RouteManager::new(options.routes.clone());
// Set up certificate manager if ACME is configured
let cert_manager = Self::build_cert_manager(&options)
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
Ok(Self {
options,
route_table: ArcSwap::from(Arc::new(route_manager)),
listener_manager: None,
metrics: Arc::new(MetricsCollector::new()),
cert_manager,
challenge_server: None,
renewal_handle: None,
nft_manager: None,
started: false,
started_at: None,
socket_handler_relay_path: None,
})
}
/// Apply default configuration to routes that lack targets or security.
fn apply_defaults(options: &mut RustProxyOptions) {
let defaults = match &options.defaults {
Some(d) => d.clone(),
None => return,
};
for route in &mut options.routes {
// Apply default target if route has no targets
if route.action.targets.is_none() {
if let Some(ref default_target) = defaults.target {
debug!("Applying default target {}:{} to route {:?}",
default_target.host, default_target.port,
route.name.as_deref().unwrap_or("unnamed"));
route.action.targets = Some(vec![
rustproxy_config::RouteTarget {
target_match: None,
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
port: rustproxy_config::PortSpec::Fixed(default_target.port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}
]);
}
}
// Apply default security if route has no security
if route.security.is_none() {
if let Some(ref default_security) = defaults.security {
let mut security = rustproxy_config::RouteSecurity {
ip_allow_list: None,
ip_block_list: None,
max_connections: default_security.max_connections,
authentication: None,
rate_limit: None,
basic_auth: None,
jwt_auth: None,
};
if let Some(ref allow_list) = default_security.ip_allow_list {
security.ip_allow_list = Some(allow_list.clone());
}
if let Some(ref block_list) = default_security.ip_block_list {
security.ip_block_list = Some(block_list.clone());
}
// Only apply if there's something meaningful
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
debug!("Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed"));
route.security = Some(security);
}
}
}
}
}
/// Build a CertManager from options.
fn build_cert_manager(options: &RustProxyOptions) -> Option<CertManager> {
let acme = options.acme.as_ref()?;
if !acme.enabled.unwrap_or(false) {
return None;
}
let store_path = acme.certificate_store
.as_deref()
.unwrap_or("./certs");
let email = acme.email.clone()
.or_else(|| acme.account_email.clone());
let use_production = acme.use_production.unwrap_or(false);
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
let store = CertStore::new(store_path);
Some(CertManager::new(store, email, use_production, renew_before_days))
}
/// Build ConnectionConfig from RustProxyOptions.
fn build_connection_config(options: &RustProxyOptions) -> ConnectionConfig {
ConnectionConfig {
connection_timeout_ms: options.effective_connection_timeout(),
initial_data_timeout_ms: options.effective_initial_data_timeout(),
socket_timeout_ms: options.effective_socket_timeout(),
max_connection_lifetime_ms: options.effective_max_connection_lifetime(),
graceful_shutdown_timeout_ms: options.graceful_shutdown_timeout.unwrap_or(30_000),
max_connections_per_ip: options.max_connections_per_ip,
connection_rate_limit_per_minute: options.connection_rate_limit_per_minute,
keep_alive_treatment: options.keep_alive_treatment.clone(),
keep_alive_inactivity_multiplier: options.keep_alive_inactivity_multiplier,
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
}
}
/// Start the proxy, binding to all configured ports.
pub async fn start(&mut self) -> Result<()> {
if self.started {
anyhow::bail!("Proxy is already started");
}
info!("Starting RustProxy...");
// Load persisted certificates
if let Some(ref cm) = self.cert_manager {
let mut cm = cm.lock().await;
match cm.load_all() {
Ok(count) => {
if count > 0 {
info!("Loaded {} persisted certificates", count);
}
}
Err(e) => warn!("Failed to load persisted certificates: {}", e),
}
}
// Auto-provision certificates for routes with certificate: 'auto'
self.auto_provision_certificates().await;
let route_manager = self.route_table.load();
let ports = route_manager.listening_ports();
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
// Create TCP listener manager with metrics
let mut listener = TcpListenerManager::with_metrics(
Arc::clone(&*route_manager),
Arc::clone(&self.metrics),
);
// Apply connection config from options
let conn_config = Self::build_connection_config(&self.options);
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
conn_config.connection_timeout_ms,
conn_config.initial_data_timeout_ms,
conn_config.socket_timeout_ms,
conn_config.max_connection_lifetime_ms,
);
listener.set_connection_config(conn_config);
// Extract TLS configurations from routes and cert manager
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
// Also load certs from cert manager into TLS config
if let Some(ref cm) = self.cert_manager {
let cm = cm.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
}
}
}
if !tls_configs.is_empty() {
debug!("Loaded TLS certificates for {} domains", tls_configs.len());
listener.set_tls_configs(tls_configs);
}
// Bind all ports
for port in &ports {
listener.add_port(*port).await?;
}
self.listener_manager = Some(listener);
self.started = true;
self.started_at = Some(Instant::now());
// Apply NFTables rules for routes using nftables forwarding engine
self.apply_nftables_rules(&self.options.routes.clone()).await;
// Start renewal timer if ACME is enabled
self.start_renewal_timer();
info!("RustProxy started successfully on ports: {:?}", ports);
Ok(())
}
/// Auto-provision certificates for routes that use certificate: 'auto'.
async fn auto_provision_certificates(&mut self) {
let cm_arc = match self.cert_manager {
Some(ref cm) => Arc::clone(cm),
None => return,
};
let mut domains_to_provision = Vec::new();
for route in &self.options.routes {
let tls_mode = route.tls_mode();
let needs_cert = matches!(
tls_mode,
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
);
if !needs_cert {
continue;
}
let cert_spec = route.action.tls.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Auto(_)) = cert_spec {
if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() {
let domain = domain.to_string();
// Skip if we already have a valid cert
let cm = cm_arc.lock().await;
if cm.store().has(&domain) {
debug!("Already have cert for {}, skipping auto-provision", domain);
continue;
}
drop(cm);
domains_to_provision.push(domain);
}
}
}
}
if domains_to_provision.is_empty() {
return;
}
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
// Start challenge server
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut challenge_server = challenge_server::ChallengeServer::new();
if let Err(e) = challenge_server.start(acme_port).await {
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
return;
}
for domain in &domains_to_provision {
info!("Provisioning certificate for {}", domain);
let cm = cm_arc.lock().await;
let acme_client = cm.acme_client();
drop(cm);
if let Some(acme_client) = acme_client {
let challenge_server_ref = &challenge_server;
let result = acme_client.provision(domain, |pending| {
challenge_server_ref.set_challenge(
pending.token.clone(),
pending.key_authorization.clone(),
);
async move { Ok(()) }
}).await;
match result {
Ok((cert_pem, key_pem)) => {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let bundle = CertBundle {
cert_pem,
key_pem,
ca_pem: None,
metadata: CertMetadata {
domain: domain.clone(),
source: CertSource::Acme,
issued_at: now,
expires_at: now + 90 * 86400, // 90 days
renewed_at: None,
},
};
let mut cm = cm_arc.lock().await;
if let Err(e) = cm.load_static(domain.clone(), bundle) {
error!("Failed to store certificate for {}: {}", domain, e);
}
info!("Certificate provisioned for {}", domain);
}
Err(e) => {
error!("Failed to provision certificate for {}: {}", domain, e);
}
}
}
}
challenge_server.stop().await;
}
/// Start the renewal timer background task.
/// The background task checks for expiring certificates and renews them.
fn start_renewal_timer(&mut self) {
let cm_arc = match self.cert_manager {
Some(ref cm) => Arc::clone(cm),
None => return,
};
let auto_renew = self.options.acme.as_ref()
.and_then(|a| a.auto_renew)
.unwrap_or(true);
if !auto_renew {
return;
}
let check_interval_hours = self.options.acme.as_ref()
.and_then(|a| a.renew_check_interval_hours)
.unwrap_or(24);
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let interval = std::time::Duration::from_secs(check_interval_hours as u64 * 3600);
let handle = tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
debug!("Certificate renewal check triggered (interval: {}h)", check_interval_hours);
// Check which domains need renewal
let domains = {
let cm = cm_arc.lock().await;
cm.check_renewals()
};
if domains.is_empty() {
debug!("No certificates need renewal");
continue;
}
info!("Renewing {} certificate(s)", domains.len());
// Start challenge server for renewals
let mut cs = challenge_server::ChallengeServer::new();
if let Err(e) = cs.start(acme_port).await {
error!("Failed to start challenge server for renewal: {}", e);
continue;
}
for domain in &domains {
let cs_ref = &cs;
let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
}).await;
match result {
Ok(_bundle) => {
info!("Successfully renewed certificate for {}", domain);
}
Err(e) => {
error!("Failed to renew certificate for {}: {}", domain, e);
}
}
}
cs.stop().await;
}
});
self.renewal_handle = Some(handle);
}
/// Stop the proxy gracefully.
pub async fn stop(&mut self) -> Result<()> {
if !self.started {
return Ok(());
}
info!("Stopping RustProxy...");
// Stop renewal timer
if let Some(handle) = self.renewal_handle.take() {
handle.abort();
}
// Stop challenge server if running
if let Some(ref mut cs) = self.challenge_server {
cs.stop().await;
}
self.challenge_server = None;
// Clean up NFTables rules
if let Some(ref mut nft) = self.nft_manager {
if let Err(e) = nft.cleanup().await {
warn!("NFTables cleanup failed: {}", e);
}
}
self.nft_manager = None;
if let Some(ref mut listener) = self.listener_manager {
listener.graceful_stop().await;
}
self.listener_manager = None;
self.started = false;
info!("RustProxy stopped");
Ok(())
}
/// Update routes atomically (hot-reload).
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
// Validate new routes
rustproxy_config::validate_routes(&routes)
.map_err(|errors| {
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
})?;
let new_manager = RouteManager::new(routes.clone());
let new_ports = new_manager.listening_ports();
info!("Updating routes: {} routes on {} ports",
new_manager.route_count(), new_ports.len());
// Get old ports
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
listener.listening_ports()
} else {
vec![]
};
// Atomically swap the route table
let new_manager = Arc::new(new_manager);
self.route_table.store(Arc::clone(&new_manager));
// Update listener manager
if let Some(ref mut listener) = self.listener_manager {
listener.update_route_manager(Arc::clone(&new_manager));
// Update TLS configs
let mut tls_configs = Self::extract_tls_configs(&routes);
if let Some(ref cm_arc) = self.cert_manager {
let cm = cm_arc.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
}
}
}
listener.set_tls_configs(tls_configs);
// Add new ports
for port in &new_ports {
if !old_ports.contains(port) {
listener.add_port(*port).await?;
}
}
// Remove old ports no longer needed
for port in &old_ports {
if !new_ports.contains(port) {
listener.remove_port(*port);
}
}
}
// Update NFTables rules: remove old, apply new
self.update_nftables_rules(&routes).await;
self.options.routes = routes;
Ok(())
}
/// Provision a certificate for a named route.
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
let cm_arc = self.cert_manager.as_ref()
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
// Find the route by name
let route = self.options.routes.iter()
.find(|r| r.name.as_deref() == Some(route_name))
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
let domain = route.route_match.domains.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
// Start challenge server
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut cs = challenge_server::ChallengeServer::new();
cs.start(acme_port).await
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
let cs_ref = &cs;
let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(&domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
}).await;
drop(cm);
cs.stop().await;
let bundle = result
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
// Hot-swap into TLS configs
if let Some(ref mut listener) = self.listener_manager {
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
}
}
listener.set_tls_configs(tls_configs);
}
info!("Certificate provisioned and loaded for route '{}'", route_name);
Ok(())
}
/// Renew a certificate for a named route.
pub async fn renew_certificate(&mut self, route_name: &str) -> Result<()> {
// Renewal is just re-provisioning
self.provision_certificate(route_name).await
}
/// Get the status of a certificate for a named route.
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
let route = self.options.routes.iter()
.find(|r| r.name.as_deref() == Some(route_name))?;
let domain = route.route_match.domains.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
if let Some(ref cm_arc) = self.cert_manager {
let cm = cm_arc.lock().await;
if let Some(bundle) = cm.get_cert(&domain) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
return Some(CertStatus {
domain,
source: format!("{:?}", bundle.metadata.source),
expires_at: bundle.metadata.expires_at,
is_valid: bundle.metadata.expires_at > now,
});
}
}
None
}
/// Get current metrics snapshot.
pub fn get_metrics(&self) -> Metrics {
self.metrics.snapshot()
}
/// Add a listening port at runtime.
pub async fn add_listening_port(&mut self, port: u16) -> Result<()> {
if let Some(ref mut listener) = self.listener_manager {
listener.add_port(port).await?;
}
Ok(())
}
/// Remove a listening port at runtime.
pub async fn remove_listening_port(&mut self, port: u16) -> Result<()> {
if let Some(ref mut listener) = self.listener_manager {
listener.remove_port(port);
}
Ok(())
}
/// Get all currently listening ports.
pub fn get_listening_ports(&self) -> Vec<u16> {
self.listener_manager
.as_ref()
.map(|l| l.listening_ports())
.unwrap_or_default()
}
/// Get statistics snapshot.
pub fn get_statistics(&self) -> Statistics {
let uptime = self.started_at
.map(|t| t.elapsed().as_secs())
.unwrap_or(0);
Statistics {
active_connections: self.metrics.active_connections(),
total_connections: self.metrics.total_connections(),
routes_count: self.route_table.load().route_count() as u64,
listening_ports: self.get_listening_ports(),
uptime_seconds: uptime,
}
}
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
pub fn set_socket_handler_relay_path(&mut self, path: Option<String>) {
info!("Socket handler relay path set to: {:?}", path);
self.socket_handler_relay_path = path;
}
/// Get the current socket handler relay path.
pub fn get_socket_handler_relay_path(&self) -> Option<&str> {
self.socket_handler_relay_path.as_deref()
}
/// Load a certificate for a domain and hot-swap the TLS configuration.
pub async fn load_certificate(
&mut self,
domain: &str,
cert_pem: String,
key_pem: String,
ca_pem: Option<String>,
) -> Result<()> {
info!("Loading certificate for domain: {}", domain);
// Store in cert manager if available
if let Some(ref cm_arc) = self.cert_manager {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let bundle = CertBundle {
cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
ca_pem: ca_pem.clone(),
metadata: CertMetadata {
domain: domain.to_string(),
source: CertSource::Static,
issued_at: now,
expires_at: now + 90 * 86400, // assume 90 days
renewed_at: None,
},
};
let mut cm = cm_arc.lock().await;
cm.load_static(domain.to_string(), bundle)
.map_err(|e| anyhow::anyhow!("Failed to store certificate: {}", e))?;
}
// Hot-swap TLS config on the listener
if let Some(ref mut listener) = self.listener_manager {
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
// Add the new cert
tls_configs.insert(domain.to_string(), TlsCertConfig {
cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
});
// Also include all existing certs from cert manager
if let Some(ref cm_arc) = self.cert_manager {
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
}
}
}
listener.set_tls_configs(tls_configs);
}
info!("Certificate loaded and TLS config updated for {}", domain);
Ok(())
}
/// Get NFTables status.
pub async fn get_nftables_status(&self) -> Result<HashMap<String, serde_json::Value>> {
match &self.nft_manager {
Some(nft) => Ok(nft.status()),
None => Ok(HashMap::new()),
}
}
/// Apply NFTables rules for routes using the nftables forwarding engine.
async fn apply_nftables_rules(&mut self, routes: &[RouteConfig]) {
let nft_routes: Vec<&RouteConfig> = routes.iter()
.filter(|r| r.action.forwarding_engine.as_ref() == Some(&ForwardingEngine::Nftables))
.collect();
if nft_routes.is_empty() {
return;
}
info!("Applying NFTables rules for {} routes", nft_routes.len());
let table_name = nft_routes.iter()
.find_map(|r| r.action.nftables.as_ref()?.table_name.clone())
.unwrap_or_else(|| "rustproxy".to_string());
let mut nft = NftManager::new(Some(table_name));
for route in &nft_routes {
let route_id = route.id.as_deref()
.or(route.name.as_deref())
.unwrap_or("unnamed");
let nft_options = match &route.action.nftables {
Some(opts) => opts.clone(),
None => rustproxy_config::NfTablesOptions {
preserve_source_ip: None,
protocol: None,
max_rate: None,
priority: None,
table_name: None,
use_ip_sets: None,
use_advanced_nat: None,
},
};
let targets = match &route.action.targets {
Some(targets) => targets,
None => {
warn!("NFTables route '{}' has no targets, skipping", route_id);
continue;
}
};
let source_ports = route.route_match.ports.to_ports();
for target in targets {
let target_host = target.host.first().to_string();
let target_port_spec = &target.port;
for &source_port in &source_ports {
let resolved_port = target_port_spec.resolve(source_port);
let rules = rule_builder::build_dnat_rule(
nft.table_name(),
"prerouting",
source_port,
&target_host,
resolved_port,
&nft_options,
);
let rule_id = format!("{}-{}-{}", route_id, source_port, resolved_port);
if let Err(e) = nft.apply_rules(&rule_id, rules).await {
error!("Failed to apply NFTables rules for route '{}': {}", route_id, e);
}
}
}
}
self.nft_manager = Some(nft);
}
/// Update NFTables rules when routes change.
async fn update_nftables_rules(&mut self, new_routes: &[RouteConfig]) {
// Clean up old rules
if let Some(ref mut nft) = self.nft_manager {
if let Err(e) = nft.cleanup().await {
warn!("NFTables cleanup during update failed: {}", e);
}
}
self.nft_manager = None;
// Apply new rules
self.apply_nftables_rules(new_routes).await;
}
/// Extract TLS configurations from route configs.
fn extract_tls_configs(routes: &[RouteConfig]) -> HashMap<String, TlsCertConfig> {
let mut configs = HashMap::new();
for route in routes {
let tls_mode = route.tls_mode();
let needs_cert = matches!(
tls_mode,
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
);
if !needs_cert {
continue;
}
let cert_spec = route.action.tls.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() {
configs.insert(domain.to_string(), TlsCertConfig {
cert_pem: cert_config.cert.clone(),
key_pem: cert_config.key.clone(),
});
}
}
}
}
configs
}
}

View File

@@ -0,0 +1,90 @@
use clap::Parser;
use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustproxy::RustProxy;
use rustproxy::management;
use rustproxy_config::RustProxyOptions;
/// RustProxy - High-performance multi-protocol proxy
#[derive(Parser, Debug)]
#[command(name = "rustproxy", version, about)]
struct Cli {
/// Path to JSON configuration file
#[arg(short, long, default_value = "config.json")]
config: String,
/// Log level (trace, debug, info, warn, error)
#[arg(short, long, default_value = "info")]
log_level: String,
/// Validate configuration without starting
#[arg(long)]
validate: bool,
/// Run in management mode (JSON-over-stdin IPC for TypeScript wrapper)
#[arg(long)]
management: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
// Initialize tracing - write to stderr so stdout is reserved for management IPC
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
)
.init();
// Management mode: JSON IPC over stdin/stdout
if cli.management {
tracing::info!("RustProxy starting in management mode...");
return management::management_loop().await;
}
tracing::info!("RustProxy starting...");
// Load configuration
let options = RustProxyOptions::from_file(&cli.config)
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
tracing::info!(
"Loaded {} routes from {}",
options.routes.len(),
cli.config
);
// Validate-only mode
if cli.validate {
match rustproxy_config::validate_routes(&options.routes) {
Ok(()) => {
tracing::info!("Configuration is valid");
return Ok(());
}
Err(errors) => {
for err in &errors {
tracing::error!("Validation error: {}", err);
}
anyhow::bail!("{} validation errors found", errors.len());
}
}
}
// Create and start proxy
let mut proxy = RustProxy::new(options)?;
proxy.start().await?;
// Wait for shutdown signal
tracing::info!("RustProxy is running. Press Ctrl+C to stop.");
tokio::signal::ctrl_c().await?;
tracing::info!("Shutdown signal received");
proxy.stop().await?;
tracing::info!("RustProxy shutdown complete");
Ok(())
}

View File

@@ -0,0 +1,470 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error};
use crate::RustProxy;
use rustproxy_config::RustProxyOptions;
/// A management request from the TypeScript wrapper.
#[derive(Debug, Deserialize)]
pub struct ManagementRequest {
pub id: String,
pub method: String,
#[serde(default)]
pub params: serde_json::Value,
}
/// A management response back to the TypeScript wrapper.
#[derive(Debug, Serialize)]
pub struct ManagementResponse {
pub id: String,
pub success: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
/// An unsolicited event from the proxy to the TypeScript wrapper.
#[derive(Debug, Serialize)]
pub struct ManagementEvent {
pub event: String,
pub data: serde_json::Value,
}
impl ManagementResponse {
fn ok(id: String, result: serde_json::Value) -> Self {
Self {
id,
success: true,
result: Some(result),
error: None,
}
}
fn err(id: String, message: String) -> Self {
Self {
id,
success: false,
result: None,
error: Some(message),
}
}
}
fn send_line(line: &str) {
// Use blocking stdout write - we're writing short JSON lines
use std::io::Write;
let stdout = std::io::stdout();
let mut handle = stdout.lock();
let _ = handle.write_all(line.as_bytes());
let _ = handle.write_all(b"\n");
let _ = handle.flush();
}
fn send_response(response: &ManagementResponse) {
match serde_json::to_string(response) {
Ok(json) => send_line(&json),
Err(e) => error!("Failed to serialize management response: {}", e),
}
}
fn send_event(event: &str, data: serde_json::Value) {
let evt = ManagementEvent {
event: event.to_string(),
data,
};
match serde_json::to_string(&evt) {
Ok(json) => send_line(&json),
Err(e) => error!("Failed to serialize management event: {}", e),
}
}
/// Run the management loop, reading JSON commands from stdin and writing responses to stdout.
pub async fn management_loop() -> Result<()> {
let stdin = BufReader::new(tokio::io::stdin());
let mut lines = stdin.lines();
let mut proxy: Option<RustProxy> = None;
send_event("ready", serde_json::json!({}));
loop {
let line = match lines.next_line().await {
Ok(Some(line)) => line,
Ok(None) => {
// stdin closed - parent process exited
info!("Management stdin closed, shutting down");
if let Some(ref mut p) = proxy {
let _ = p.stop().await;
}
break;
}
Err(e) => {
error!("Error reading management stdin: {}", e);
break;
}
};
let line = line.trim().to_string();
if line.is_empty() {
continue;
}
let request: ManagementRequest = match serde_json::from_str(&line) {
Ok(r) => r,
Err(e) => {
error!("Failed to parse management request: {}", e);
// Send error response without an ID
send_response(&ManagementResponse::err(
"unknown".to_string(),
format!("Failed to parse request: {}", e),
));
continue;
}
};
let response = handle_request(&request, &mut proxy).await;
send_response(&response);
}
Ok(())
}
async fn handle_request(
request: &ManagementRequest,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let id = request.id.clone();
match request.method.as_str() {
"start" => handle_start(&id, &request.params, proxy).await,
"stop" => handle_stop(&id, proxy).await,
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
"getMetrics" => handle_get_metrics(&id, proxy),
"getStatistics" => handle_get_statistics(&id, proxy),
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
"getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
_ => ManagementResponse::err(id, format!("Unknown method: {}", request.method)),
}
}
async fn handle_start(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
if proxy.is_some() {
return ManagementResponse::err(id.to_string(), "Proxy is already running".to_string());
}
let config = match params.get("config") {
Some(config) => config,
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
};
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
Ok(o) => o,
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
};
match RustProxy::new(options) {
Ok(mut p) => {
match p.start().await {
Ok(()) => {
send_event("started", serde_json::json!({}));
*proxy = Some(p);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
}
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
}
}
async fn handle_stop(
id: &str,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_mut() {
Some(p) => {
match p.stop().await {
Ok(()) => {
*proxy = None;
send_event("stopped", serde_json::json!({}));
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
}
}
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
async fn handle_update_routes(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let routes = match params.get("routes") {
Some(routes) => routes,
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
};
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
Ok(r) => r,
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid routes: {}", e)),
};
match p.update_routes(routes).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
}
}
fn handle_get_metrics(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let metrics = p.get_metrics();
match serde_json::to_value(&metrics) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}
}
fn handle_get_statistics(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let stats = p.get_statistics();
match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}
}
async fn handle_provision_certificate(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
};
match p.provision_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
}
}
async fn handle_renew_certificate(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
};
match p.renew_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
}
}
async fn handle_get_certificate_status(
id: &str,
params: &serde_json::Value,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_ref() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name,
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
};
match p.get_certificate_status(route_name).await {
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
"domain": status.domain,
"source": status.source,
"expiresAt": status.expires_at,
"isValid": status.is_valid,
})),
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
}
}
fn handle_get_listening_ports(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let ports = p.get_listening_ports();
ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": ports }))
}
None => ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": [] })),
}
}
async fn handle_get_nftables_status(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
match p.get_nftables_status().await {
Ok(status) => {
match serde_json::to_value(&status) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)),
}
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)),
}
}
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
async fn handle_set_socket_handler_relay(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("setSocketHandlerRelay: socket_path={:?}", socket_path);
p.set_socket_handler_relay_path(socket_path);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
async fn handle_add_listening_port(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
};
match p.add_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
}
}
async fn handle_remove_listening_port(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
};
match p.remove_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
}
}
async fn handle_load_certificate(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let domain = match params.get("domain").and_then(|v| v.as_str()) {
Some(d) => d.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
};
let cert = match params.get("cert").and_then(|v| v.as_str()) {
Some(c) => c.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
};
let key = match params.get("key").and_then(|v| v.as_str()) {
Some(k) => k.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
};
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
info!("loadCertificate: domain={}", domain);
// Load cert into cert manager and hot-swap TLS config
match p.load_certificate(&domain, cert, key, ca).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
}
}

View File

@@ -0,0 +1,402 @@
use std::sync::atomic::{AtomicU16, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
/// Atomic port allocator starting at 19000 to avoid collisions.
static PORT_COUNTER: AtomicU16 = AtomicU16::new(19000);
/// Get the next available port for testing.
pub fn next_port() -> u16 {
PORT_COUNTER.fetch_add(1, Ordering::SeqCst)
}
/// Start a simple TCP echo server that echoes back whatever it receives.
/// Returns the join handle for the server task.
pub async fn start_echo_server(port: u16) -> JoinHandle<()> {
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind echo server");
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if stream.write_all(&buf[..n]).await.is_err() {
break;
}
}
});
}
})
}
/// Start a TCP echo server that prefixes responses to identify which backend responded.
pub async fn start_prefix_echo_server(port: u16, prefix: &str) -> JoinHandle<()> {
let prefix = prefix.to_string();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind prefix echo server");
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let pfx = prefix.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
let mut response = pfx.as_bytes().to_vec();
response.extend_from_slice(&buf[..n]);
if stream.write_all(&response).await.is_err() {
break;
}
}
});
}
})
}
/// Start a simple HTTP server that responds with a fixed status and body.
pub async fn start_http_server(port: u16, status: u16, body: &str) -> JoinHandle<()> {
let body = body.to_string();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind HTTP server");
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let b = body.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 8192];
// Read the request
let _n = stream.read(&mut buf).await.unwrap_or(0);
// Send response
let response = format!(
"HTTP/1.1 {} OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
status,
b.len(),
b,
);
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.shutdown().await;
});
}
})
}
/// Start an HTTP backend server that echoes back request details as JSON.
/// The response body contains: {"method":"GET","path":"/foo","host":"example.com","backend":"<name>"}
/// Supports keep-alive by reading HTTP requests properly.
pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandle<()> {
let name = backend_name.to_string();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap_or_else(|_| panic!("Failed to bind HTTP echo backend on port {}", port));
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let backend = name.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 16384];
// Read request data
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => n,
};
let req_str = String::from_utf8_lossy(&buf[..n]);
// Parse first line: METHOD PATH HTTP/x.x
let first_line = req_str.lines().next().unwrap_or("");
let parts: Vec<&str> = first_line.split_whitespace().collect();
let method = parts.first().copied().unwrap_or("UNKNOWN");
let path = parts.get(1).copied().unwrap_or("/");
// Extract Host header
let host = req_str.lines()
.find(|l| l.to_lowercase().starts_with("host:"))
.map(|l| l[5..].trim())
.unwrap_or("unknown");
let body = format!(
r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#,
method, path, host, backend
);
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body,
);
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.shutdown().await;
});
}
})
}
/// Wrap a future with a timeout, preventing tests from hanging.
pub async fn with_timeout<F, T>(future: F, secs: u64) -> Result<T, &'static str>
where
F: std::future::Future<Output = T>,
{
match tokio::time::timeout(std::time::Duration::from_secs(secs), future).await {
Ok(result) => Ok(result),
Err(_) => Err("Test timed out"),
}
}
/// Wait briefly for a server to be ready by attempting TCP connections.
pub async fn wait_for_port(port: u16, timeout_ms: u64) -> bool {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
while start.elapsed() < timeout {
if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.is_ok()
{
return true;
}
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
false
}
/// Helper to create a minimal route config for testing.
pub fn make_test_route(
port: u16,
domain: Option<&str>,
target_host: &str,
target_port: u16,
) -> rustproxy_config::RouteConfig {
rustproxy_config::RouteConfig {
id: None,
route_match: rustproxy_config::RouteMatch {
ports: rustproxy_config::PortRange::Single(port),
domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())),
path: None,
client_ip: None,
tls_version: None,
headers: None,
},
action: rustproxy_config::RouteAction {
action_type: rustproxy_config::RouteActionType::Forward,
targets: Some(vec![rustproxy_config::RouteTarget {
target_match: None,
host: rustproxy_config::HostSpec::Single(target_host.to_string()),
port: rustproxy_config::PortSpec::Fixed(target_port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
}
}
/// Start a simple WebSocket echo backend.
///
/// Accepts WebSocket upgrade requests (HTTP Upgrade: websocket), sends 101 back,
/// then echoes all data received on the connection.
pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap_or_else(|_| panic!("Failed to bind WS echo backend on port {}", port));
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
tokio::spawn(async move {
// Read the HTTP upgrade request
let mut buf = vec![0u8; 4096];
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => n,
};
let req_str = String::from_utf8_lossy(&buf[..n]);
// Extract Sec-WebSocket-Key for proper handshake
let ws_key = req_str.lines()
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
.unwrap_or_default();
// Compute Sec-WebSocket-Accept (simplified - just echo for test purposes)
// Real implementation would compute SHA-1 + base64
let accept_response = format!(
"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: {}\r\n\
\r\n",
ws_key
);
if stream.write_all(accept_response.as_bytes()).await.is_err() {
return;
}
// Echo all data back (raw TCP after upgrade)
let mut echo_buf = vec![0u8; 65536];
loop {
let n = match stream.read(&mut echo_buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if stream.write_all(&echo_buf[..n]).await.is_err() {
break;
}
}
});
}
})
}
/// Generate a self-signed certificate for testing using rcgen.
/// Returns (cert_pem, key_pem).
pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
use rcgen::{CertificateParams, KeyPair};
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap();
(cert.pem(), key_pair.serialize_pem())
}
/// Start a TLS echo server using the given cert/key.
/// Returns the join handle.
pub async fn start_tls_echo_server(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
.expect("Failed to build TLS acceptor");
let acceptor = Arc::new(acceptor);
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind TLS echo server");
tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let acc = acceptor.clone();
tokio::spawn(async move {
let mut tls_stream = match acc.accept(stream).await {
Ok(s) => s,
Err(_) => return,
};
let mut buf = vec![0u8; 65536];
loop {
let n = match tls_stream.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if tls_stream.write_all(&buf[..n]).await.is_err() {
break;
}
}
});
}
})
}
/// Helper to create a TLS terminate route with static cert for testing.
pub fn make_tls_terminate_route(
port: u16,
domain: &str,
target_host: &str,
target_port: u16,
cert_pem: &str,
key_pem: &str,
) -> rustproxy_config::RouteConfig {
let mut route = make_test_route(port, Some(domain), target_host, target_port);
route.action.tls = Some(rustproxy_config::RouteTls {
mode: rustproxy_config::TlsMode::Terminate,
certificate: Some(rustproxy_config::CertificateSpec::Static(
rustproxy_config::CertificateConfig {
cert: cert_pem.to_string(),
key: key_pem.to_string(),
ca: None,
key_file: None,
cert_file: None,
},
)),
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
/// Helper to create a TLS passthrough route for testing.
pub fn make_tls_passthrough_route(
port: u16,
domain: Option<&str>,
target_host: &str,
target_port: u16,
) -> rustproxy_config::RouteConfig {
let mut route = make_test_route(port, domain, target_host, target_port);
route.action.tls = Some(rustproxy_config::RouteTls {
mode: rustproxy_config::TlsMode::Passthrough,
certificate: None,
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}

View File

@@ -0,0 +1,453 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Send a raw HTTP request and return the full response as a string.
async fn send_http_request(port: u16, host: &str, method: &str, path: &str) -> String {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let request = format!(
"{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
method, path, host,
);
stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}
/// Extract the body from a raw HTTP response string (after the \r\n\r\n).
fn extract_body(response: &str) -> &str {
response.split("\r\n\r\n").nth(1).unwrap_or("")
}
#[tokio::test]
async fn test_http_forward_basic() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_http_echo_backend(backend_port, "main").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
let body = extract_body(&response);
body.to_string()
}, 10)
.await
.unwrap();
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_host_routing() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let _b1 = start_http_echo_backend(backend1_port, "alpha").await;
let _b2 = start_http_echo_backend(backend2_port, "beta").await;
let options = RustProxyOptions {
routes: vec![
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let alpha_result = with_timeout(async {
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
// Test beta domain
let beta_result = with_timeout(async {
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_path_routing() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let _b1 = start_http_echo_backend(backend1_port, "api").await;
let _b2 = start_http_echo_backend(backend2_port, "web").await;
let mut api_route = make_test_route(proxy_port, None, "127.0.0.1", backend1_port);
api_route.route_match.path = Some("/api/**".to_string());
api_route.priority = Some(10);
let web_route = make_test_route(proxy_port, None, "127.0.0.1", backend2_port);
let options = RustProxyOptions {
routes: vec![api_route, web_route],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Test API path
let api_result = with_timeout(async {
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
// Test web path (no /api prefix)
let web_result = with_timeout(async {
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_cors_preflight() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_http_echo_backend(backend_port, "main").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
// Send CORS preflight request
let request = format!(
"OPTIONS /api/data HTTP/1.1\r\nHost: example.com\r\nOrigin: http://localhost:3000\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n",
);
stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
.await
.unwrap();
// Should get 204 No Content with CORS headers
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
assert!(result.to_lowercase().contains("access-control-allow-origin"),
"Expected CORS header, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_backend_error() {
let backend_port = next_port();
let proxy_port = next_port();
// Start an HTTP server that returns 500
let _backend = start_http_server(backend_port, 500, "Internal Error").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
response
}, 10)
.await
.unwrap();
// Proxy should relay the 500 from backend
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_no_route_matched() {
let proxy_port = next_port();
// Create a route only for a specific domain
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
response
}, 10)
.await
.unwrap();
// Should get 502 Bad Gateway (no route matched)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_backend_unavailable() {
let proxy_port = next_port();
let dead_port = next_port(); // No server running here
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
response
}, 10)
.await
.unwrap();
// Should get 502 Bad Gateway (backend unavailable)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_https_terminate_http_forward() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "httpproxy.example.com";
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
let _backend = start_http_echo_backend(backend_port, "tls-backend").await;
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send HTTP request through TLS
let request = format!(
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
domain
);
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
.await
.unwrap();
let body = extract_body(&result);
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_websocket_through_proxy() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_ws_echo_backend(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
// Send WebSocket upgrade request
let request = format!(
"GET /ws HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
\r\n"
);
stream.write_all(request.as_bytes()).await.unwrap();
// Read the 101 response
let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1];
loop {
let n = stream.read(&mut temp).await.unwrap();
if n == 0 { break; }
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len-4..] == *b"\r\n\r\n" {
break;
}
}
}
let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
response_str
);
// After upgrade, send data and verify echo
let test_data = b"Hello WebSocket!";
stream.write_all(test_data).await.unwrap();
// Read echoed data
let mut echo_buf = vec![0u8; 256];
let n = stream.read(&mut echo_buf).await.unwrap();
let echoed = &echo_buf[..n];
assert_eq!(echoed, test_data, "Expected echo of sent data");
"ok".to_string()
}, 10)
.await
.unwrap();
assert_eq!(result, "ok");
proxy.stop().await.unwrap();
}
/// InsecureVerifier for test TLS client connections.
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
]
}
}

View File

@@ -0,0 +1,250 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[tokio::test]
async fn test_start_and_stop() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
// Not listening before start
assert!(!wait_for_port(port, 200).await);
proxy.start().await.unwrap();
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
proxy.stop().await.unwrap();
// Give the OS a moment to release the port
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
}
#[tokio::test]
async fn test_double_start_fails() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
// Second start should fail
let result = proxy.start().await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("already started"));
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_update_routes_hot_reload() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
// Update routes atomically
let new_routes = vec![
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
];
let result = proxy.update_routes(new_routes).await;
assert!(result.is_ok());
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_add_remove_listening_port() {
let port1 = next_port();
let port2 = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port1, None, "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await);
// Add a new port
proxy.add_listening_port(port2).await.unwrap();
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
// Remove the port
proxy.remove_listening_port(port2).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
// Original port should still be listening
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_get_statistics() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
let stats = proxy.get_statistics();
assert_eq!(stats.routes_count, 1);
assert!(stats.listening_ports.contains(&port));
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_invalid_routes_rejected() {
let options = RustProxyOptions {
routes: vec![{
let mut route = make_test_route(80, None, "127.0.0.1", 8080);
route.action.targets = None; // Invalid: forward without targets
route
}],
..Default::default()
};
let result = RustProxy::new(options);
assert!(result.is_err());
}
#[tokio::test]
async fn test_metrics_track_connections() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// No connections yet
let stats = proxy.get_statistics();
assert_eq!(stats.total_connections, 0);
// Make a connection and send data
{
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"hello").await.unwrap();
let mut buf = vec![0u8; 16];
let _ = stream.read(&mut buf).await;
}
// Small delay for metrics to update
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_metrics_track_bytes() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_http_echo_backend(backend_port, "metrics-test").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Send HTTP request through proxy
{
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let request = b"GET /test HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n";
stream.write_all(request).await.unwrap();
let mut response = Vec::new();
stream.read_to_end(&mut response).await.unwrap();
assert!(!response.is_empty(), "Expected non-empty response");
}
// Small delay for metrics to update
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0,
"Expected some connections tracked, got {}", stats.total_connections);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_hot_reload_port_changes() {
let port1 = next_port();
let port2 = next_port();
let backend_port = next_port();
let _backend = start_echo_server(backend_port).await;
// Start with port1
let options = RustProxyOptions {
routes: vec![make_test_route(port1, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await);
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
// Update routes to use port2 instead
let new_routes = vec![
make_test_route(port2, None, "127.0.0.1", backend_port),
];
proxy.update_routes(new_routes).await.unwrap();
// Port2 should now be listening, port1 should be closed
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
// Verify port2 works
let ports = proxy.get_listening_ports();
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,197 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
#[tokio::test]
async fn test_tcp_forward_echo() {
let backend_port = next_port();
let proxy_port = next_port();
// Start echo backend
let _backend = start_echo_server(backend_port).await;
// Configure proxy
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
// Wait for proxy to be ready
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
// Connect and send data
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"hello world").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert_eq!(result, "hello world");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tcp_forward_large_payload() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
// Send 1MB of data
let data = vec![b'A'; 1_000_000];
stream.write_all(&data).await.unwrap();
stream.shutdown().await.unwrap();
// Read all back
let mut received = Vec::new();
stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 10)
.await
.unwrap();
assert_eq!(result, 1_000_000);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tcp_forward_multiple_connections() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
handles.push(tokio::spawn(async move {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let msg = format!("connection-{}", i);
stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}));
}
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap());
}
results
}, 10)
.await
.unwrap();
assert_eq!(result.len(), 10);
for (i, r) in result.iter().enumerate() {
assert_eq!(r, &format!("connection-{}", i));
}
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tcp_forward_backend_unreachable() {
let proxy_port = next_port();
let dead_port = next_port(); // No server on this port
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Connection should complete (proxy accepts it) but data should not flow
let result = with_timeout(async {
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
stream.is_ok()
}, 5)
.await
.unwrap();
assert!(result, "Should be able to connect to proxy even if backend is down");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tcp_forward_bidirectional() {
let backend_port = next_port();
let proxy_port = next_port();
// Start a prefix echo server to verify data flows in both directions
let _backend = start_prefix_echo_server(backend_port, "REPLY:").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"test data").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert_eq!(result, "REPLY:test data");
proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,247 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
/// Build a minimal TLS ClientHello with the given SNI domain.
/// This is enough for the proxy's SNI parser to extract the domain.
fn build_client_hello(domain: &str) -> Vec<u8> {
let domain_bytes = domain.as_bytes();
let sni_length = domain_bytes.len() as u16;
// Server Name extension (type 0x0000)
let mut sni_ext = Vec::new();
sni_ext.extend_from_slice(&[0x00, 0x00]); // extension type: server_name
let sni_list_len = sni_length + 5; // 2 (list len) + 1 (type) + 2 (name len) + name
sni_ext.extend_from_slice(&(sni_list_len as u16).to_be_bytes()); // extension data length
sni_ext.extend_from_slice(&((sni_list_len - 2) as u16).to_be_bytes()); // server name list length
sni_ext.push(0x00); // host_name type
sni_ext.extend_from_slice(&sni_length.to_be_bytes());
sni_ext.extend_from_slice(domain_bytes);
let extensions_length = sni_ext.len() as u16;
// ClientHello message
let mut client_hello = Vec::new();
client_hello.extend_from_slice(&[0x03, 0x03]); // TLS 1.2 version
client_hello.extend_from_slice(&[0x00; 32]); // random
client_hello.push(0x00); // session_id length
client_hello.extend_from_slice(&[0x00, 0x02, 0x00, 0xff]); // cipher suites (1 suite)
client_hello.extend_from_slice(&[0x01, 0x00]); // compression methods (null)
client_hello.extend_from_slice(&extensions_length.to_be_bytes());
client_hello.extend_from_slice(&sni_ext);
let hello_len = client_hello.len() as u32;
// Handshake wrapper (type 1 = ClientHello)
let mut handshake = Vec::new();
handshake.push(0x01); // ClientHello
handshake.extend_from_slice(&hello_len.to_be_bytes()[1..4]); // 3-byte length
handshake.extend_from_slice(&client_hello);
let hs_len = handshake.len() as u16;
// TLS record
let mut record = Vec::new();
record.push(0x16); // ContentType: Handshake
record.extend_from_slice(&[0x03, 0x01]); // TLS 1.0 (record version)
record.extend_from_slice(&hs_len.to_be_bytes());
record.extend_from_slice(&handshake);
record
}
#[tokio::test]
async fn test_tls_passthrough_sni_routing() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let _b1 = start_prefix_echo_server(backend1_port, "BACKEND1:").await;
let _b2 = start_prefix_echo_server(backend2_port, "BACKEND2:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Send a fake ClientHello with SNI "one.example.com"
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("one.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
// Backend1 should have received the ClientHello and prefixed its response
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
// Now test routing to backend2
let result2 = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("two.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_passthrough_unknown_sni() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Send ClientHello with unknown SNI - should get no response (connection dropped)
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("unknown.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
// Should either get 0 bytes (closed) or an error
match stream.read(&mut buf).await {
Ok(0) => true, // Connection closed = no route matched
Ok(_) => false, // Got data = route shouldn't have matched
Err(_) => true, // Error = connection dropped
}
}, 5)
.await
.unwrap();
assert!(result, "Unknown SNI should result in dropped connection");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_passthrough_wildcard_domain() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Should match any subdomain of example.com
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("anything.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_passthrough_multiple_domains() {
let b1_port = next_port();
let b2_port = next_port();
let b3_port = next_port();
let proxy_port = next_port();
let _b1 = start_prefix_echo_server(b1_port, "B1:").await;
let _b2 = start_prefix_echo_server(b2_port, "B2:").await;
let _b3 = start_prefix_echo_server(b3_port, "B3:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", b1_port),
make_tls_passthrough_route(proxy_port, Some("beta.example.com"), "127.0.0.1", b2_port),
make_tls_passthrough_route(proxy_port, Some("gamma.example.com"), "127.0.0.1", b3_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
for (domain, expected_prefix) in [
("alpha.example.com", "B1:"),
("beta.example.com", "B2:"),
("gamma.example.com", "B3:"),
] {
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello(domain);
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert!(
result.starts_with(expected_prefix),
"Domain {} should route to {}, got: {}",
domain, expected_prefix, result
);
}
proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,324 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Create a rustls client config that trusts self-signed certs.
fn make_insecure_tls_client_config() -> Arc<rustls::ClientConfig> {
let _ = rustls::crypto::ring::default_provider().install_default();
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
Arc::new(config)
}
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
]
}
}
#[tokio::test]
async fn test_tls_terminate_basic() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "test.example.com";
// Generate self-signed cert
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
// Start plain TCP echo backend (proxy terminates TLS, sends plain to backend)
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Connect with TLS client
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello TLS").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
.await
.unwrap();
assert_eq!(result, "hello TLS");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_and_reencrypt() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "reencrypt.example.com";
let backend_domain = "backend.internal";
// Generate certs
let (proxy_cert, proxy_key) = generate_self_signed_cert(domain);
let (backend_cert, backend_key) = generate_self_signed_cert(backend_domain);
// Start TLS echo backend
let _backend = start_tls_echo_server(backend_port, &backend_cert, &backend_key).await;
// Create terminate-and-reencrypt route
let mut route = make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
);
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let options = RustProxyOptions {
routes: vec![route],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello reencrypt").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
.await
.unwrap();
assert_eq!(result, "hello reencrypt");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_sni_cert_selection() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let (cert1, key1) = generate_self_signed_cert("alpha.example.com");
let (cert2, key2) = generate_self_signed_cert("beta.example.com");
let _b1 = start_prefix_echo_server(backend1_port, "ALPHA:").await;
let _b2 = start_prefix_echo_server(backend2_port, "BETA:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"test").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
.await
.unwrap();
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_large_payload() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "large.example.com";
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send 1MB of data
let data = vec![b'X'; 1_000_000];
tls_stream.write_all(&data).await.unwrap();
tls_stream.shutdown().await.unwrap();
let mut received = Vec::new();
tls_stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 15)
.await
.unwrap();
assert_eq!(result, 1_000_000);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_concurrent() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "concurrent.example.com";
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
let dom = domain.to_string();
handles.push(tokio::spawn(async move {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let msg = format!("conn-{}", i);
tls_stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}));
}
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap());
}
results
}, 15)
.await
.unwrap();
assert_eq!(result.len(), 10);
for (i, r) in result.iter().enumerate() {
assert_eq!(r, &format!("conn-{}", i));
}
proxy.stop().await.unwrap();
}