feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates

This commit is contained in:
2026-02-09 10:55:46 +00:00
parent a31fee41df
commit 1df3b7af4a
151 changed files with 16927 additions and 19432 deletions

1760
rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

98
rust/Cargo.toml Normal file
View File

@@ -0,0 +1,98 @@
# Workspace manifest for the RustProxy proxy engine.
# All dependency versions are centralized under [workspace.dependencies];
# member crates opt in with `{ workspace = true }` so the whole tree stays
# on a single version of each library.
[workspace]
resolver = "2"
members = [
"crates/rustproxy",
"crates/rustproxy-config",
"crates/rustproxy-routing",
"crates/rustproxy-tls",
"crates/rustproxy-passthrough",
"crates/rustproxy-http",
"crates/rustproxy-nftables",
"crates/rustproxy-metrics",
"crates/rustproxy-security",
]
# Shared package metadata inherited by every member crate.
[workspace.package]
version = "0.1.0"
edition = "2021"
license = "MIT"
authors = ["Lossless GmbH <hello@lossless.com>"]
[workspace.dependencies]
# Async runtime
tokio = { version = "1", features = ["full"] }
# Serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# HTTP proxy engine (hyper-based)
hyper = { version = "1", features = ["http1", "http2", "server", "client"] }
hyper-util = { version = "0.1", features = ["tokio", "http1", "http2", "client-legacy", "server-auto"] }
http-body-util = "0.1"
bytes = "1"
# ACME / Let's Encrypt
instant-acme = { version = "0.7", features = ["hyper-rustls"] }
# TLS for passthrough SNI
rustls = { version = "0.23", features = ["ring"] }
tokio-rustls = "0.26"
rustls-pemfile = "2"
# Self-signed cert generation for tests
rcgen = "0.13"
# Temp directories for tests
tempfile = "3"
# Lock-free atomics
arc-swap = "1"
# Concurrent maps
dashmap = "6"
# Domain wildcard matching
glob-match = "0.2"
# IP/CIDR parsing
ipnet = "2"
# JWT authentication
jsonwebtoken = "9"
# Structured logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Error handling
thiserror = "2"
anyhow = "1"
# CLI
clap = { version = "4", features = ["derive"] }
# Regex for URL rewriting
regex = "1"
# Base64 for basic auth
base64 = "0.22"
# Cancellation / utility
tokio-util = "0.7"
# Async traits
async-trait = "0.1"
# libc for uid checks
libc = "0.2"
# Internal crates (path deps so members can depend on each other via
# `{ workspace = true }` as well)
rustproxy-config = { path = "crates/rustproxy-config" }
rustproxy-routing = { path = "crates/rustproxy-routing" }
rustproxy-tls = { path = "crates/rustproxy-tls" }
rustproxy-passthrough = { path = "crates/rustproxy-passthrough" }
rustproxy-http = { path = "crates/rustproxy-http" }
rustproxy-nftables = { path = "crates/rustproxy-nftables" }
rustproxy-metrics = { path = "crates/rustproxy-metrics" }
rustproxy-security = { path = "crates/rustproxy-security" }

145
rust/config/example.json Normal file
View File

@@ -0,0 +1,145 @@
{
"routes": [
{
"id": "https-passthrough",
"name": "HTTPS Passthrough to Backend",
"match": {
"ports": 443,
"domains": "backend.example.com"
},
"action": {
"type": "forward",
"targets": [
{
"host": "10.0.0.1",
"port": 443
}
],
"tls": {
"mode": "passthrough"
}
},
"priority": 10,
"enabled": true
},
{
"id": "https-terminate",
"name": "HTTPS Terminate for API",
"match": {
"ports": 443,
"domains": "api.example.com"
},
"action": {
"type": "forward",
"targets": [
{
"host": "localhost",
"port": 8080
}
],
"tls": {
"mode": "terminate",
"certificate": "auto"
}
},
"priority": 20,
"enabled": true
},
{
"id": "http-redirect",
"name": "HTTP to HTTPS Redirect",
"match": {
"ports": 80,
"domains": ["api.example.com", "www.example.com"]
},
"action": {
"type": "forward",
"targets": [
{
"host": "localhost",
"port": 8080
}
]
},
"priority": 0
},
{
"id": "load-balanced",
"name": "Load Balanced Backend",
"match": {
"ports": 443,
"domains": "*.example.com"
},
"action": {
"type": "forward",
"targets": [
{
"host": "backend1.internal",
"port": 8080
},
{
"host": "backend2.internal",
"port": 8080
},
{
"host": "backend3.internal",
"port": 8080
}
],
"tls": {
"mode": "terminate",
"certificate": "auto"
},
"loadBalancing": {
"algorithm": "round-robin",
"healthCheck": {
"path": "/health",
"interval": 30,
"timeout": 5,
"unhealthyThreshold": 3,
"healthyThreshold": 2
}
}
},
"security": {
"ipAllowList": ["10.0.0.0/8", "192.168.0.0/16"],
"maxConnections": 1000,
"rateLimit": {
"enabled": true,
"maxRequests": 100,
"window": 60,
"keyBy": "ip"
}
},
"headers": {
"request": {
"X-Forwarded-For": "{clientIp}",
"X-Real-IP": "{clientIp}"
},
"response": {
"X-Powered-By": "RustProxy"
},
"cors": {
"enabled": true,
"allowOrigin": "*",
"allowMethods": "GET,POST,PUT,DELETE,OPTIONS",
"allowHeaders": "Content-Type,Authorization",
"allowCredentials": false,
"maxAge": 86400
}
},
"priority": 5
}
],
"acme": {
"email": "admin@example.com",
"useProduction": false,
"port": 80
},
"connectionTimeout": 30000,
"socketTimeout": 3600000,
"maxConnectionsPerIp": 100,
"connectionRateLimitPerMinute": 300,
"keepAliveTreatment": "extended",
"enableDetailedLogging": false
}

View File

@@ -0,0 +1,13 @@
# Manifest for the rustproxy-config crate: serde-backed configuration
# types mirroring SmartProxy's JSON schema. All versions come from the
# workspace manifest.
[package]
name = "rustproxy-config"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Configuration types for RustProxy, compatible with SmartProxy JSON schema"
[dependencies]
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
ipnet = { workspace = true }

View File

@@ -0,0 +1,334 @@
use crate::route_types::*;
use crate::tls_types::*;
/// Build a plain HTTP forwarding route listening on port 80 with a
/// single backend target and no TLS.
/// Equivalent to SmartProxy's `createHttpRoute()`.
pub fn create_http_route(
    domains: impl Into<DomainSpec>,
    target_host: impl Into<String>,
    target_port: u16,
) -> RouteConfig {
    // The single backend this route forwards to.
    let target = RouteTarget {
        target_match: None,
        host: HostSpec::Single(target_host.into()),
        port: PortSpec::Fixed(target_port),
        tls: None,
        websocket: None,
        load_balancing: None,
        send_proxy_protocol: None,
        headers: None,
        advanced: None,
        priority: None,
    };
    // Match everything arriving on port 80 for the given domains.
    let matcher = RouteMatch {
        ports: PortRange::Single(80),
        domains: Some(domains.into()),
        path: None,
        client_ip: None,
        tls_version: None,
        headers: None,
    };
    let action = RouteAction {
        action_type: RouteActionType::Forward,
        targets: Some(vec![target]),
        tls: None,
        websocket: None,
        load_balancing: None,
        advanced: None,
        options: None,
        forwarding_engine: None,
        nftables: None,
        send_proxy_protocol: None,
    };
    RouteConfig {
        id: None,
        route_match: matcher,
        action,
        headers: None,
        security: None,
        name: None,
        description: None,
        priority: None,
        tags: None,
        enabled: None,
    }
}
/// Build an HTTPS route on port 443 that terminates TLS at the proxy,
/// using an automatically provisioned ("auto") certificate.
/// Equivalent to SmartProxy's `createHttpsTerminateRoute()`.
pub fn create_https_terminate_route(
    domains: impl Into<DomainSpec>,
    target_host: impl Into<String>,
    target_port: u16,
) -> RouteConfig {
    // Start from the plain HTTP route and upgrade it to 443 + TLS.
    let mut route = create_http_route(domains, target_host, target_port);
    route.route_match.ports = PortRange::Single(443);
    let tls = RouteTls {
        mode: TlsMode::Terminate,
        certificate: Some(CertificateSpec::Auto("auto".to_string())),
        acme: None,
        versions: None,
        ciphers: None,
        honor_cipher_order: None,
        session_timeout: None,
    };
    route.action.tls = Some(tls);
    route
}
/// Build a port-443 route that forwards the TLS stream untouched
/// (SNI-based passthrough); no certificate is needed at the proxy.
/// Equivalent to SmartProxy's `createHttpsPassthroughRoute()`.
pub fn create_https_passthrough_route(
    domains: impl Into<DomainSpec>,
    target_host: impl Into<String>,
    target_port: u16,
) -> RouteConfig {
    // Reuse the HTTP constructor, then switch to 443 with passthrough TLS.
    let mut route = create_http_route(domains, target_host, target_port);
    route.route_match.ports = PortRange::Single(443);
    let tls = RouteTls {
        mode: TlsMode::Passthrough,
        certificate: None,
        acme: None,
        versions: None,
        ciphers: None,
        honor_cipher_order: None,
        session_timeout: None,
    };
    route.action.tls = Some(tls);
    route
}
/// Build a port-80 route that answers every request with a 301 redirect
/// to the HTTPS equivalent of the same host and path (via a synthetic
/// test response; the route has no backend targets).
/// Equivalent to SmartProxy's `createHttpToHttpsRedirect()`.
pub fn create_http_to_https_redirect(
    domains: impl Into<DomainSpec>,
) -> RouteConfig {
    // Synthetic response: 301 with a templated Location header.
    let redirect = RouteTestResponse {
        status: 301,
        headers: std::collections::HashMap::from([(
            "Location".to_string(),
            "https://{domain}{path}".to_string(),
        )]),
        body: String::new(),
    };
    let advanced = RouteAdvanced {
        timeout: None,
        headers: None,
        keep_alive: None,
        static_files: None,
        test_response: Some(redirect),
        url_rewrite: None,
    };
    RouteConfig {
        id: None,
        route_match: RouteMatch {
            ports: PortRange::Single(80),
            domains: Some(domains.into()),
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
        },
        action: RouteAction {
            action_type: RouteActionType::Forward,
            targets: None,
            tls: None,
            websocket: None,
            load_balancing: None,
            advanced: Some(advanced),
            options: None,
            forwarding_engine: None,
            nftables: None,
            send_proxy_protocol: None,
        },
        headers: None,
        security: None,
        name: Some("HTTP to HTTPS Redirect".to_string()),
        description: None,
        priority: None,
        tags: None,
        enabled: None,
    }
}
/// Build the standard two-route HTTPS setup for one domain: an HTTP→HTTPS
/// redirect on port 80 followed by a TLS-terminating route on port 443.
/// Equivalent to SmartProxy's `createCompleteHttpsServer()`.
pub fn create_complete_https_server(
    domain: impl Into<String>,
    target_host: impl Into<String>,
    target_port: u16,
) -> Vec<RouteConfig> {
    let domain = domain.into();
    let redirect = create_http_to_https_redirect(DomainSpec::Single(domain.clone()));
    let terminate = create_https_terminate_route(
        DomainSpec::Single(domain),
        target_host,
        target_port,
    );
    vec![redirect, terminate]
}
/// Build a round-robin load-balancing route over several `(host, port)`
/// backends. Listens on 443 when TLS settings are supplied, otherwise 80.
/// Equivalent to SmartProxy's `createLoadBalancerRoute()`.
pub fn create_load_balancer_route(
    domains: impl Into<DomainSpec>,
    targets: Vec<(String, u16)>,
    tls: Option<RouteTls>,
) -> RouteConfig {
    // Turn each (host, port) pair into a bare RouteTarget.
    let mut route_targets = Vec::with_capacity(targets.len());
    for (host, port) in targets {
        route_targets.push(RouteTarget {
            target_match: None,
            host: HostSpec::Single(host),
            port: PortSpec::Fixed(port),
            tls: None,
            websocket: None,
            load_balancing: None,
            send_proxy_protocol: None,
            headers: None,
            advanced: None,
            priority: None,
        });
    }
    // TLS routes listen on 443; plain HTTP on 80.
    let listen_port = if tls.is_none() { 80 } else { 443 };
    RouteConfig {
        id: None,
        route_match: RouteMatch {
            ports: PortRange::Single(listen_port),
            domains: Some(domains.into()),
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
        },
        action: RouteAction {
            action_type: RouteActionType::Forward,
            targets: Some(route_targets),
            tls,
            websocket: None,
            load_balancing: Some(RouteLoadBalancing {
                algorithm: LoadBalancingAlgorithm::RoundRobin,
                health_check: None,
            }),
            advanced: None,
            options: None,
            forwarding_engine: None,
            nftables: None,
            send_proxy_protocol: None,
        },
        headers: None,
        security: None,
        name: Some("Load Balancer".to_string()),
        description: None,
        priority: None,
        tags: None,
        enabled: None,
    }
}
// Convenience conversions for DomainSpec
impl From<&str> for DomainSpec {
fn from(s: &str) -> Self {
DomainSpec::Single(s.to_string())
}
}
impl From<String> for DomainSpec {
fn from(s: String) -> Self {
DomainSpec::Single(s)
}
}
impl From<Vec<String>> for DomainSpec {
fn from(v: Vec<String>) -> Self {
DomainSpec::List(v)
}
}
impl From<Vec<&str>> for DomainSpec {
fn from(v: Vec<&str>) -> Self {
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
}
}
// Unit tests for the SmartProxy-compatible route constructors. They rely
// on accessor helpers (`to_ports`, `to_vec`, `first`, `resolve`, `is_auto`)
// defined on the route/TLS types.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls_types::TlsMode;
    #[test]
    fn test_create_http_route() {
        let route = create_http_route("example.com", "localhost", 8080);
        // HTTP routes listen on 80 and carry no TLS settings.
        assert_eq!(route.route_match.ports.to_ports(), vec![80]);
        let domains = route.route_match.domains.as_ref().unwrap().to_vec();
        assert_eq!(domains, vec!["example.com"]);
        let target = &route.action.targets.as_ref().unwrap()[0];
        assert_eq!(target.host.first(), "localhost");
        assert_eq!(target.port.resolve(80), 8080);
        assert!(route.action.tls.is_none());
    }
    #[test]
    fn test_create_https_terminate_route() {
        let route = create_https_terminate_route("api.example.com", "backend", 3000);
        assert_eq!(route.route_match.ports.to_ports(), vec![443]);
        // Terminate mode must request an auto-provisioned certificate.
        let tls = route.action.tls.as_ref().unwrap();
        assert_eq!(tls.mode, TlsMode::Terminate);
        assert!(tls.certificate.as_ref().unwrap().is_auto());
    }
    #[test]
    fn test_create_https_passthrough_route() {
        let route = create_https_passthrough_route("secure.example.com", "backend", 443);
        assert_eq!(route.route_match.ports.to_ports(), vec![443]);
        // Passthrough never carries a certificate (TLS ends at the backend).
        let tls = route.action.tls.as_ref().unwrap();
        assert_eq!(tls.mode, TlsMode::Passthrough);
        assert!(tls.certificate.is_none());
    }
    #[test]
    fn test_create_http_to_https_redirect() {
        let route = create_http_to_https_redirect("example.com");
        assert_eq!(route.route_match.ports.to_ports(), vec![80]);
        // Redirect routes have no targets — only a synthetic 301 response.
        assert!(route.action.targets.is_none());
        let test_response = route.action.advanced.as_ref().unwrap().test_response.as_ref().unwrap();
        assert_eq!(test_response.status, 301);
        assert!(test_response.headers.contains_key("Location"));
    }
    #[test]
    fn test_create_complete_https_server() {
        let routes = create_complete_https_server("example.com", "backend", 8080);
        assert_eq!(routes.len(), 2);
        // First route is HTTP redirect
        assert_eq!(routes[0].route_match.ports.to_ports(), vec![80]);
        // Second route is HTTPS terminate
        assert_eq!(routes[1].route_match.ports.to_ports(), vec![443]);
    }
    #[test]
    fn test_create_load_balancer_route() {
        let targets = vec![
            ("backend1".to_string(), 8080),
            ("backend2".to_string(), 8080),
            ("backend3".to_string(), 8080),
        ];
        // No TLS supplied, so the route listens on port 80.
        let route = create_load_balancer_route("*.example.com", targets, None);
        assert_eq!(route.route_match.ports.to_ports(), vec![80]);
        assert_eq!(route.action.targets.as_ref().unwrap().len(), 3);
        let lb = route.action.load_balancing.as_ref().unwrap();
        assert_eq!(lb.algorithm, LoadBalancingAlgorithm::RoundRobin);
    }
    #[test]
    fn test_domain_spec_from_str() {
        let spec: DomainSpec = "example.com".into();
        assert_eq!(spec.to_vec(), vec!["example.com"]);
    }
    #[test]
    fn test_domain_spec_from_vec() {
        let spec: DomainSpec = vec!["a.com", "b.com"].into();
        assert_eq!(spec.to_vec(), vec!["a.com", "b.com"]);
    }
}

View File

@@ -0,0 +1,19 @@
//! # rustproxy-config
//!
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
pub mod route_types;    // route matching, actions, targets (RouteConfig etc.)
pub mod proxy_options;  // top-level RustProxyOptions, ACME and timeout settings
pub mod tls_types;      // RouteTls, TlsMode, certificate specs
pub mod security_types; // RouteSecurity (per-route security settings)
pub mod validation;     // configuration validation (not shown in this view)
pub mod helpers;        // convenience route constructors
// Re-export all primary types so consumers can `use rustproxy_config::*`
// without knowing the internal module layout.
pub use route_types::*;
pub use proxy_options::*;
pub use tls_types::*;
pub use security_types::*;
pub use validation::*;
pub use helpers::*;

View File

@@ -0,0 +1,439 @@
use serde::{Deserialize, Serialize};
use crate::route_types::RouteConfig;
/// Global ACME configuration options.
/// Matches TypeScript: `IAcmeOptions`
///
/// Every field is optional so a partial JSON object deserializes cleanly;
/// `skip_serializing_if` keeps unset fields out of serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AcmeOptions {
    /// Master switch for ACME certificate management.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
    /// Required when any route uses certificate: 'auto'
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// Production vs staging directory selection.
    // NOTE(review): `use_production` below expresses the same choice as a
    // bool — precedence between the two is decided by the consumer, not here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub environment: Option<AcmeEnvironment>,
    /// Alias for email
    #[serde(skip_serializing_if = "Option::is_none")]
    pub account_email: Option<String>,
    /// Port for HTTP-01 challenges (default: 80)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub port: Option<u16>,
    /// Use Let's Encrypt production (default: false)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_production: Option<bool>,
    /// Days before expiry to renew (default: 30)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub renew_threshold_days: Option<u32>,
    /// Enable automatic renewal (default: true)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_renew: Option<bool>,
    /// Directory to store certificates (default: './certs')
    #[serde(skip_serializing_if = "Option::is_none")]
    pub certificate_store: Option<String>,
    /// Skip ACME for routes that already have a configured certificate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub skip_configured_certs: Option<bool>,
    /// How often to check for renewals (default: 24)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub renew_check_interval_hours: Option<u32>,
}
/// ACME environment: which Let's Encrypt directory to use.
/// Serialized lowercase (`"production"` / `"staging"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AcmeEnvironment {
    /// Real certificates; subject to rate limits.
    Production,
    /// Staging directory for testing; certificates are not browser-trusted.
    Staging,
}
/// Default target configuration: the fallback backend (host + port) used
/// when no route-specific target applies. Both fields are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultTarget {
    pub host: String,
    pub port: u16,
}
/// Default security configuration applied when a route has none of its own.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultSecurity {
    /// IPs/CIDRs allowed to connect (example config uses CIDR strings).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ip_allow_list: Option<Vec<String>>,
    /// IPs/CIDRs explicitly denied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ip_block_list: Option<Vec<String>>,
    /// Cap on simultaneous connections.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_connections: Option<u64>,
}
/// Default configuration: global fallbacks for target, security, and
/// source-IP preservation when routes don't specify them.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultConfig {
    /// Fallback backend target.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target: Option<DefaultTarget>,
    /// Fallback security settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub security: Option<DefaultSecurity>,
    /// Preserve the client IP when forwarding (default behavior unspecified here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preserve_source_ip: Option<bool>,
}
/// Keep-alive treatment: how aggressively keep-alive connections are
/// timed out. Serialized lowercase. The variant semantics are defined by
/// the connection manager consuming this enum (see the multiplier and
/// extended-lifetime fields on `RustProxyOptions`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KeepAliveTreatment {
    Standard,
    Extended,
    Immortal,
}
/// Metrics configuration: sampling cadence and retention for the metrics
/// subsystem (rustproxy-metrics crate).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MetricsConfig {
    /// Master switch for metrics collection.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
    /// Interval between samples, in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sample_interval_ms: Option<u64>,
    /// How long samples are retained, in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retention_seconds: Option<u64>,
}
/// RustProxy configuration options.
/// Matches TypeScript: `ISmartProxyOptions`
///
/// This is the top-level configuration that can be loaded from a JSON file
/// or constructed programmatically. Only `routes` is required; every other
/// field is optional and omitted from serialized output when unset (the
/// `effective_*` accessors supply documented fallbacks for the timeouts).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RustProxyOptions {
    /// The unified configuration array (required)
    pub routes: Vec<RouteConfig>,
    /// Preserve client IP when forwarding
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preserve_source_ip: Option<bool>,
    /// List of trusted proxy IPs that can send PROXY protocol
    #[serde(skip_serializing_if = "Option::is_none")]
    pub proxy_ips: Option<Vec<String>>,
    /// Global option to accept PROXY protocol
    #[serde(skip_serializing_if = "Option::is_none")]
    pub accept_proxy_protocol: Option<bool>,
    /// Global option to send PROXY protocol to all targets
    #[serde(skip_serializing_if = "Option::is_none")]
    pub send_proxy_protocol: Option<bool>,
    /// Global/default settings
    #[serde(skip_serializing_if = "Option::is_none")]
    pub defaults: Option<DefaultConfig>,
    // ─── Timeout Settings ────────────────────────────────────────────
    // All timeout/lifetime values below are in milliseconds.
    /// Timeout for establishing connection to backend (ms), default: 30000
    #[serde(skip_serializing_if = "Option::is_none")]
    pub connection_timeout: Option<u64>,
    /// Timeout for initial data/SNI (ms), default: 60000
    #[serde(skip_serializing_if = "Option::is_none")]
    pub initial_data_timeout: Option<u64>,
    /// Socket inactivity timeout (ms), default: 3600000
    #[serde(skip_serializing_if = "Option::is_none")]
    pub socket_timeout: Option<u64>,
    /// How often to check for inactive connections (ms), default: 60000
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inactivity_check_interval: Option<u64>,
    /// Default max connection lifetime (ms), default: 86400000
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_connection_lifetime: Option<u64>,
    /// Inactivity timeout (ms), default: 14400000
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inactivity_timeout: Option<u64>,
    /// Maximum time to wait for connections to close during shutdown (ms)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub graceful_shutdown_timeout: Option<u64>,
    // ─── Socket Optimization ─────────────────────────────────────────
    /// Disable Nagle's algorithm (default: true)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_delay: Option<bool>,
    /// Enable TCP keepalive (default: true)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<bool>,
    /// Initial delay before sending keepalive probes (ms)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive_initial_delay: Option<u64>,
    /// Maximum bytes to buffer during connection setup
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_pending_data_size: Option<u64>,
    // ─── Enhanced Features ───────────────────────────────────────────
    /// Disable inactivity checking entirely
    #[serde(skip_serializing_if = "Option::is_none")]
    pub disable_inactivity_check: Option<bool>,
    /// Enable TCP keep-alive probes
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_keep_alive_probes: Option<bool>,
    /// Enable detailed connection logging
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_detailed_logging: Option<bool>,
    /// Enable TLS handshake debug logging
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_tls_debug_logging: Option<bool>,
    /// Randomize timeouts to prevent thundering herd
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_randomized_timeouts: Option<bool>,
    // ─── Rate Limiting ───────────────────────────────────────────────
    /// Maximum simultaneous connections from a single IP
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_connections_per_ip: Option<u64>,
    /// Max new connections per minute from a single IP
    #[serde(skip_serializing_if = "Option::is_none")]
    pub connection_rate_limit_per_minute: Option<u64>,
    // ─── Keep-Alive Settings ─────────────────────────────────────────
    /// How to treat keep-alive connections
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive_treatment: Option<KeepAliveTreatment>,
    /// Multiplier for inactivity timeout for keep-alive connections
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive_inactivity_multiplier: Option<f64>,
    /// Extended lifetime for keep-alive connections (ms)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extended_keep_alive_lifetime: Option<u64>,
    // ─── HttpProxy Integration ───────────────────────────────────────
    /// Array of ports to forward to HttpProxy
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_http_proxy: Option<Vec<u16>>,
    /// Port where HttpProxy is listening (default: 8443)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub http_proxy_port: Option<u16>,
    // ─── Metrics ─────────────────────────────────────────────────────
    /// Metrics configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metrics: Option<MetricsConfig>,
    // ─── ACME ────────────────────────────────────────────────────────
    /// Global ACME configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub acme: Option<AcmeOptions>,
}
impl Default for RustProxyOptions {
    /// An empty configuration: no routes and every tunable unset, so the
    /// `effective_*` accessors fall back to their documented defaults.
    // NOTE(review): every field is None/empty, so `#[derive(Default)]` on
    // the struct would be equivalent; kept as a manual impl here to avoid
    // touching the type definition.
    fn default() -> Self {
        Self {
            routes: Vec::new(),
            preserve_source_ip: None,
            proxy_ips: None,
            accept_proxy_protocol: None,
            send_proxy_protocol: None,
            defaults: None,
            connection_timeout: None,
            initial_data_timeout: None,
            socket_timeout: None,
            inactivity_check_interval: None,
            max_connection_lifetime: None,
            inactivity_timeout: None,
            graceful_shutdown_timeout: None,
            no_delay: None,
            keep_alive: None,
            keep_alive_initial_delay: None,
            max_pending_data_size: None,
            disable_inactivity_check: None,
            enable_keep_alive_probes: None,
            enable_detailed_logging: None,
            enable_tls_debug_logging: None,
            enable_randomized_timeouts: None,
            max_connections_per_ip: None,
            connection_rate_limit_per_minute: None,
            keep_alive_treatment: None,
            keep_alive_inactivity_multiplier: None,
            extended_keep_alive_lifetime: None,
            use_http_proxy: None,
            http_proxy_port: None,
            metrics: None,
            acme: None,
        }
    }
}
impl RustProxyOptions {
    /// Load configuration from a JSON file.
    ///
    /// Accepts anything path-like (`&str`, `String`, `&Path`, `PathBuf`),
    /// which is backward compatible with the previous `&str`-only signature.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or its contents are not
    /// valid JSON for this schema.
    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, Box<dyn std::error::Error>> {
        let content = std::fs::read_to_string(path)?;
        let options: Self = serde_json::from_str(&content)?;
        Ok(options)
    }
    /// Get the effective connection timeout in milliseconds (default: 30 s).
    pub fn effective_connection_timeout(&self) -> u64 {
        self.connection_timeout.unwrap_or(30_000)
    }
    /// Get the effective initial data timeout in milliseconds (default: 60 s).
    pub fn effective_initial_data_timeout(&self) -> u64 {
        self.initial_data_timeout.unwrap_or(60_000)
    }
    /// Get the effective socket timeout in milliseconds (default: 1 h).
    pub fn effective_socket_timeout(&self) -> u64 {
        self.socket_timeout.unwrap_or(3_600_000)
    }
    /// Get the effective max connection lifetime in milliseconds (default: 24 h).
    pub fn effective_max_connection_lifetime(&self) -> u64 {
        self.max_connection_lifetime.unwrap_or(86_400_000)
    }
    /// Get all unique ports that routes listen on, sorted ascending.
    pub fn all_listening_ports(&self) -> Vec<u16> {
        let mut ports: Vec<u16> = self.routes
            .iter()
            .flat_map(|r| r.listening_ports())
            .collect();
        // sort_unstable is sufficient for u16 and avoids the stable
        // sort's allocation; dedup requires the sorted order.
        ports.sort_unstable();
        ports.dedup();
        ports
    }
}
// Unit tests for RustProxyOptions: serde round-trips, default/custom
// timeouts, port aggregation, camelCase field naming, and parsing the
// repository's example config.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::helpers::*;
    #[test]
    fn test_serde_roundtrip_minimal() {
        let options = RustProxyOptions {
            routes: vec![create_http_route("example.com", "localhost", 8080)],
            ..Default::default()
        };
        let json = serde_json::to_string(&options).unwrap();
        let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.routes.len(), 1);
    }
    #[test]
    fn test_serde_roundtrip_full() {
        // Exercise a configuration with several optional sections populated.
        let options = RustProxyOptions {
            routes: vec![
                create_http_route("a.com", "backend1", 8080),
                create_https_passthrough_route("b.com", "backend2", 443),
            ],
            connection_timeout: Some(5000),
            socket_timeout: Some(60000),
            max_connections_per_ip: Some(100),
            acme: Some(AcmeOptions {
                enabled: Some(true),
                email: Some("admin@example.com".to_string()),
                environment: Some(AcmeEnvironment::Staging),
                account_email: None,
                port: None,
                use_production: None,
                renew_threshold_days: None,
                auto_renew: None,
                certificate_store: None,
                skip_configured_certs: None,
                renew_check_interval_hours: None,
            }),
            ..Default::default()
        };
        let json = serde_json::to_string_pretty(&options).unwrap();
        let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.routes.len(), 2);
        assert_eq!(parsed.connection_timeout, Some(5000));
    }
    #[test]
    fn test_default_timeouts() {
        // Unset fields fall back to the documented defaults.
        let options = RustProxyOptions::default();
        assert_eq!(options.effective_connection_timeout(), 30_000);
        assert_eq!(options.effective_initial_data_timeout(), 60_000);
        assert_eq!(options.effective_socket_timeout(), 3_600_000);
        assert_eq!(options.effective_max_connection_lifetime(), 86_400_000);
    }
    #[test]
    fn test_custom_timeouts() {
        // Explicit values override the defaults.
        let options = RustProxyOptions {
            connection_timeout: Some(5000),
            initial_data_timeout: Some(10000),
            socket_timeout: Some(30000),
            max_connection_lifetime: Some(60000),
            ..Default::default()
        };
        assert_eq!(options.effective_connection_timeout(), 5000);
        assert_eq!(options.effective_initial_data_timeout(), 10000);
        assert_eq!(options.effective_socket_timeout(), 30000);
        assert_eq!(options.effective_max_connection_lifetime(), 60000);
    }
    #[test]
    fn test_all_listening_ports() {
        // Duplicated listen ports must be collapsed; output is sorted.
        let options = RustProxyOptions {
            routes: vec![
                create_http_route("a.com", "backend", 8080), // port 80
                create_https_passthrough_route("b.com", "backend", 443), // port 443
                create_http_route("c.com", "backend", 9090), // port 80 (duplicate)
            ],
            ..Default::default()
        };
        let ports = options.all_listening_ports();
        assert_eq!(ports, vec![80, 443]);
    }
    #[test]
    fn test_camel_case_field_names() {
        // Serialized field names must match the TypeScript schema exactly.
        let options = RustProxyOptions {
            connection_timeout: Some(5000),
            max_connections_per_ip: Some(100),
            keep_alive_treatment: Some(KeepAliveTreatment::Extended),
            ..Default::default()
        };
        let json = serde_json::to_string(&options).unwrap();
        assert!(json.contains("connectionTimeout"));
        assert!(json.contains("maxConnectionsPerIp"));
        assert!(json.contains("keepAliveTreatment"));
    }
    #[test]
    fn test_deserialize_example_json() {
        // The checked-in example config must parse with this schema
        // (path is resolved relative to this crate's manifest).
        let content = std::fs::read_to_string(
            concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json")
        ).unwrap();
        let options: RustProxyOptions = serde_json::from_str(&content).unwrap();
        assert_eq!(options.routes.len(), 4);
        let ports = options.all_listening_ports();
        assert!(ports.contains(&80));
        assert!(ports.contains(&443));
    }
}

View File

@@ -0,0 +1,603 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::tls_types::RouteTls;
use crate::security_types::RouteSecurity;
// ─── Port Range ──────────────────────────────────────────────────────
/// Port range specification format.
/// Matches TypeScript: `type TPortRange = number | number[] | Array<{ from: number; to: number }>`
///
/// `#[serde(untagged)]` lets all three JSON shapes (number, number array,
/// array of from/to objects) deserialize without a discriminator field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PortRange {
    /// Single port number
    Single(u16),
    /// Array of port numbers
    List(Vec<u16>),
    /// Array of port ranges
    Ranges(Vec<PortRangeSpec>),
}
impl PortRange {
/// Expand the port range into a flat list of ports.
pub fn to_ports(&self) -> Vec<u16> {
match self {
PortRange::Single(p) => vec![*p],
PortRange::List(ports) => ports.clone(),
PortRange::Ranges(ranges) => {
ranges.iter().flat_map(|r| r.from..=r.to).collect()
}
}
}
}
/// A from-to port range, inclusive on both ends (`to_ports` expands it
/// as `from..=to`). No derived rename: JSON keys are `from` / `to`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PortRangeSpec {
    pub from: u16,
    pub to: u16,
}
// ─── Route Action Type ───────────────────────────────────────────────
/// Supported action types for route configurations.
/// Matches TypeScript: `type TRouteActionType = 'forward' | 'socket-handler'`
/// Serialized in kebab-case: `"forward"` / `"socket-handler"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum RouteActionType {
    /// Proxy the connection to the configured target(s).
    Forward,
    /// Hand the raw socket to custom handling code instead of forwarding.
    SocketHandler,
}
// ─── Forwarding Engine ───────────────────────────────────────────────
/// Forwarding engine specification: which mechanism forwards traffic.
/// Serialized lowercase (`"node"` / `"nftables"`); the `node` name is
/// kept for SmartProxy compatibility and means user-space forwarding.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ForwardingEngine {
    /// User-space (in-process) forwarding.
    Node,
    /// Kernel-level forwarding via nftables (see rustproxy-nftables crate).
    Nftables,
}
// ─── Route Match ─────────────────────────────────────────────────────
/// Domain specification: single string or array of strings, matching the
/// TypeScript `string | string[]` shape via `#[serde(untagged)]`.
/// Wildcard patterns like `"*.example.com"` appear in the example config;
/// the matching itself happens in the routing layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum DomainSpec {
    Single(String),
    List(Vec<String>),
}
impl DomainSpec {
pub fn to_vec(&self) -> Vec<&str> {
match self {
DomainSpec::Single(s) => vec![s.as_str()],
DomainSpec::List(v) => v.iter().map(|s| s.as_str()).collect(),
}
}
}
/// Header match value.
// NOTE(review): the original doc claimed "exact string or regex pattern
// (prefixed and suffixed with `/`)", but only an `Exact` variant is
// modelled here — regex-style values would currently deserialize as
// exact strings. Either add a `Regex` variant (with `/.../ ` detection
// in a custom Deserialize) or drop the regex claim from the schema docs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum HeaderMatchValue {
    /// Exact, case-sensitive string comparison value.
    Exact(String),
}
/// Route match criteria for incoming requests.
/// Matches TypeScript: `IRouteMatch`
///
/// Only `ports` is required; every omitted criterion matches everything.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteMatch {
    /// Listen on these ports (required)
    pub ports: PortRange,
    /// Optional domain patterns to match (default: all domains)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub domains: Option<DomainSpec>,
    /// Match specific paths
    #[serde(skip_serializing_if = "Option::is_none")]
    pub path: Option<String>,
    /// Match specific client IPs
    // NOTE(review): entries are plain strings; whether CIDR notation is
    // accepted here (as in the security lists) is decided by the matcher.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_ip: Option<Vec<String>>,
    /// Match specific TLS versions
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_version: Option<Vec<String>>,
    /// Match specific HTTP headers
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}
// ─── Target Match ────────────────────────────────────────────────────
/// Target-specific match criteria for sub-routing within a route.
/// Matches TypeScript: `ITargetMatch`
///
/// Used on individual `RouteTarget`s to select among several targets of
/// one route; all criteria are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TargetMatch {
    /// Match specific ports from the route
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ports: Option<Vec<u16>>,
    /// Match specific paths (supports wildcards like /api/*)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub path: Option<String>,
    /// Match specific HTTP headers
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
    /// Match specific HTTP methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub method: Option<Vec<String>>,
}
// ─── WebSocket Config ────────────────────────────────────────────────
/// WebSocket configuration.
/// Matches TypeScript: `IRouteWebSocket`
///
/// `enabled` is the only required field; everything else is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteWebSocket {
    /// Whether WebSocket upgrades are handled on this route.
    pub enabled: bool,
    /// Ping interval — presumably milliseconds like the other proxy timeouts; confirm in consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ping_interval: Option<u64>,
    /// Ping timeout — same unit caveat as `ping_interval`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ping_timeout: Option<u64>,
    /// Maximum WebSocket payload size, in bytes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_payload_size: Option<u64>,
    /// Extra headers to add to the upgrade request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub custom_headers: Option<HashMap<String, String>>,
    /// Allowed WebSocket subprotocols.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub subprotocols: Option<Vec<String>>,
    /// Rewrite the request path before forwarding the upgrade.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rewrite_path: Option<String>,
    /// Origins accepted for the upgrade handshake.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_origins: Option<Vec<String>>,
    /// Whether to run authentication on the upgrade request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authenticate_request: Option<bool>,
}
// ─── Load Balancing ──────────────────────────────────────────────────
/// Load balancing algorithm.
/// Serialized in kebab-case: "round-robin", "least-connections", "ip-hash".
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum LoadBalancingAlgorithm {
    RoundRobin,
    LeastConnections,
    IpHash,
}
/// Health check configuration.
/// All fields are required when a health check is configured.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HealthCheck {
    /// Path probed on the backend (e.g. "/healthz").
    pub path: String,
    /// Probe interval — unit not stated here; TODO confirm against the checker.
    pub interval: u64,
    /// Per-probe timeout (presumably same unit as `interval`).
    pub timeout: u64,
    /// Consecutive failures before a backend is marked unhealthy.
    pub unhealthy_threshold: u32,
    /// Consecutive successes before it is marked healthy again.
    pub healthy_threshold: u32,
}
/// Load balancing configuration.
/// Matches TypeScript: `IRouteLoadBalancing`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteLoadBalancing {
    /// Algorithm used to pick a backend target.
    pub algorithm: LoadBalancingAlgorithm,
    /// Optional active health checking of backends.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub health_check: Option<HealthCheck>,
}
// ─── CORS ────────────────────────────────────────────────────────────
/// Allowed origin specification.
/// Untagged: a bare JSON string or an array of strings both deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AllowOrigin {
    /// A single origin (or a wildcard like "*").
    Single(String),
    /// An explicit list of origins.
    List(Vec<String>),
}
/// CORS configuration for a route.
/// Matches TypeScript: `IRouteCors`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteCors {
    /// Master switch for CORS header injection.
    pub enabled: bool,
    /// Value(s) for `Access-Control-Allow-Origin`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_origin: Option<AllowOrigin>,
    /// Comma-separated method list for `Access-Control-Allow-Methods`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_methods: Option<String>,
    /// Comma-separated header list for `Access-Control-Allow-Headers`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_headers: Option<String>,
    /// Whether to emit `Access-Control-Allow-Credentials: true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_credentials: Option<bool>,
    /// Comma-separated header list for `Access-Control-Expose-Headers`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expose_headers: Option<String>,
    /// Preflight cache lifetime in seconds (`Access-Control-Max-Age`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_age: Option<u64>,
    /// Whether the proxy answers OPTIONS preflights itself.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preflight: Option<bool>,
}
// ─── Headers ─────────────────────────────────────────────────────────
/// Headers configuration.
/// Matches TypeScript: `IRouteHeaders`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteHeaders {
    /// Headers to add/modify for requests to backend
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request: Option<HashMap<String, String>>,
    /// Headers to add/modify for responses to client
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<HashMap<String, String>>,
    /// CORS configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cors: Option<RouteCors>,
}
// ─── Static Files ────────────────────────────────────────────────────
/// Static file server configuration.
/// Matches TypeScript: `IRouteStaticFiles`
///
/// NOTE(review): both `index`/`index_files` and `root`/`directory` pairs
/// exist, mirroring the TypeScript interface — confirm which member the
/// serving code actually consults.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteStaticFiles {
    /// Filesystem root directory served by this route (required).
    pub root: String,
    /// Index file names tried for directory requests.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<Vec<String>>,
    /// Extra headers added to every static response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
    /// Alternate directory field carried over from the TS interface.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub directory: Option<String>,
    /// Alternate index-file list carried over from the TS interface.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index_files: Option<Vec<String>>,
    /// Value for the `Cache-Control` response header.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<String>,
    /// Expiry horizon — unit not stated here; TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires: Option<u64>,
    /// Whether symlinks under `root` may be followed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub follow_symlinks: Option<bool>,
    /// Whether directory listings are suppressed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub disable_directory_listing: Option<bool>,
}
// ─── Test Response ───────────────────────────────────────────────────
/// Test route response configuration.
/// Matches TypeScript: `IRouteTestResponse`
///
/// A fixed canned response returned without contacting any upstream —
/// useful for health endpoints and smoke tests.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteTestResponse {
    /// HTTP status code to return.
    pub status: u16,
    /// Response headers to set.
    pub headers: HashMap<String, String>,
    /// Literal response body.
    pub body: String,
}
// ─── URL Rewriting ───────────────────────────────────────────────────
/// URL rewriting configuration.
/// Matches TypeScript: `IRouteUrlRewrite`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteUrlRewrite {
    /// RegExp pattern to match in URL
    pub pattern: String,
    /// Replacement pattern
    pub target: String,
    /// RegExp flags
    /// NOTE(review): flags follow the JS RegExp convention of the TS config;
    /// confirm how they are translated to Rust's regex engine.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub flags: Option<String>,
    /// Only apply to path, not query string
    #[serde(skip_serializing_if = "Option::is_none")]
    pub only_rewrite_path: Option<bool>,
}
// ─── Advanced Options ────────────────────────────────────────────────
/// Advanced options for route actions.
/// Matches TypeScript: `IRouteAdvanced`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteAdvanced {
    /// Upstream timeout — unit not stated here; TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<u64>,
    /// Extra headers applied at the advanced-options level.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
    /// Whether to keep upstream connections alive.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<bool>,
    /// Serve static files instead of proxying.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub static_files: Option<RouteStaticFiles>,
    /// Return a canned response instead of proxying.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub test_response: Option<RouteTestResponse>,
    /// Rewrite the URL before forwarding.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url_rewrite: Option<RouteUrlRewrite>,
}
// ─── NFTables Options ────────────────────────────────────────────────
/// NFTables protocol type.
/// Serialized lowercase: "tcp", "udp", "all".
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NfTablesProtocol {
    Tcp,
    Udp,
    All,
}
/// NFTables-specific configuration options.
/// Matches TypeScript: `INfTablesOptions`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NfTablesOptions {
    /// Keep the original client IP when NATing (no source masquerade).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preserve_source_ip: Option<bool>,
    /// Which L4 protocol(s) the rules apply to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub protocol: Option<NfTablesProtocol>,
    /// Rate limit expressed as an nftables rate string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_rate: Option<String>,
    /// Rule priority within the chain.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<i32>,
    /// Custom nftables table name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_name: Option<String>,
    /// Use nftables named sets for IP lists.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_ip_sets: Option<bool>,
    /// Enable the advanced NAT rule variants.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_advanced_nat: Option<bool>,
}
// ─── Backend Protocol ────────────────────────────────────────────────
/// Backend protocol.
/// Serialized lowercase: "http1", "http2".
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BackendProtocol {
    Http1,
    Http2,
}
/// Action options.
///
/// Unknown keys are preserved in `extra` via `#[serde(flatten)]` so that
/// options added by other components round-trip through (de)serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ActionOptions {
    /// Protocol used when talking to the backend (default behavior when
    /// absent is decided by the proxy, not here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub backend_protocol: Option<BackendProtocol>,
    /// Catch-all for additional options
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
// ─── Route Target ────────────────────────────────────────────────────
/// Host specification: single string or array of strings.
/// Note: Dynamic host functions are only available via programmatic API, not JSON.
/// Untagged: a bare JSON string or an array of strings both deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum HostSpec {
    /// A single host name or address.
    Single(String),
    /// Several candidate hosts (e.g. for load balancing).
    List(Vec<String>),
}
impl HostSpec {
    /// View all configured hosts as a vector of string slices.
    pub fn to_vec(&self) -> Vec<&str> {
        match self {
            HostSpec::Single(host) => std::iter::once(host.as_str()).collect(),
            HostSpec::List(hosts) => hosts.iter().map(String::as_str).collect(),
        }
    }

    /// Return the first configured host; an empty list yields "".
    pub fn first(&self) -> &str {
        match self {
            HostSpec::Single(host) => host,
            HostSpec::List(hosts) => hosts.first().map_or("", String::as_str),
        }
    }
}
/// Port specification: number or "preserve".
/// Note: Dynamic port functions are only available via programmatic API, not JSON.
/// Untagged: a JSON number deserializes to `Fixed`, a string to `Special`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PortSpec {
    /// Fixed port number
    Fixed(u16),
    /// Special string value like "preserve"
    Special(String),
}
impl PortSpec {
    /// Resolve the effective upstream port for a connection that arrived on
    /// `incoming_port`.
    ///
    /// `Fixed(p)` always yields `p`. Every special string — "preserve" as
    /// well as any unrecognized value — falls back to the incoming port,
    /// so the two string cases collapse into one arm here.
    pub fn resolve(&self, incoming_port: u16) -> u16 {
        match self {
            PortSpec::Fixed(port) => *port,
            PortSpec::Special(_) => incoming_port,
        }
    }
}
/// Target configuration for forwarding with sub-matching and overrides.
/// Matches TypeScript: `IRouteTarget`
///
/// Every `Option` field, when set, overrides the corresponding route-level
/// default for requests that hit this target.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteTarget {
    /// Optional sub-matching criteria within the route
    #[serde(rename = "match")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_match: Option<TargetMatch>,
    /// Target host(s)
    pub host: HostSpec,
    /// Target port
    pub port: PortSpec,
    /// Override route-level TLS settings
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls: Option<RouteTls>,
    /// Override route-level WebSocket settings
    #[serde(skip_serializing_if = "Option::is_none")]
    pub websocket: Option<RouteWebSocket>,
    /// Override route-level load balancing
    #[serde(skip_serializing_if = "Option::is_none")]
    pub load_balancing: Option<RouteLoadBalancing>,
    /// Override route-level proxy protocol setting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub send_proxy_protocol: Option<bool>,
    /// Override route-level headers
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<RouteHeaders>,
    /// Override route-level advanced settings
    #[serde(skip_serializing_if = "Option::is_none")]
    pub advanced: Option<RouteAdvanced>,
    /// Priority for matching (higher values checked first, default: 0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<i32>,
}
// ─── Route Action ────────────────────────────────────────────────────
/// Action configuration for route handling.
/// Matches TypeScript: `IRouteAction`
///
/// Note: `socketHandler` is not serializable in JSON. Use the programmatic API
/// for socket handler routes.
///
/// Fields marked "default for all targets" can be overridden per-target via
/// the corresponding field on [`RouteTarget`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteAction {
    /// Basic routing type
    #[serde(rename = "type")]
    pub action_type: RouteActionType,
    /// Targets for forwarding (array supports multiple targets with sub-matching)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub targets: Option<Vec<RouteTarget>>,
    /// TLS handling (default for all targets)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls: Option<RouteTls>,
    /// WebSocket support (default for all targets)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub websocket: Option<RouteWebSocket>,
    /// Load balancing options (default for all targets)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub load_balancing: Option<RouteLoadBalancing>,
    /// Advanced options (default for all targets)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub advanced: Option<RouteAdvanced>,
    /// Additional options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<ActionOptions>,
    /// Forwarding engine specification
    #[serde(skip_serializing_if = "Option::is_none")]
    pub forwarding_engine: Option<ForwardingEngine>,
    /// NFTables-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nftables: Option<NfTablesOptions>,
    /// PROXY protocol support (default for all targets)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub send_proxy_protocol: Option<bool>,
}
// ─── Route Config ────────────────────────────────────────────────────
/// The core unified configuration interface.
/// Matches TypeScript: `IRouteConfig`
///
/// A route pairs match criteria (`route_match`) with an action, plus optional
/// metadata. See the `impl` block for the accessors that apply the documented
/// defaults (`enabled` → true, `priority` → 0).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteConfig {
    /// Unique identifier
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// What to match
    #[serde(rename = "match")]
    pub route_match: RouteMatch,
    /// What to do with matched traffic
    pub action: RouteAction,
    /// Custom headers
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<RouteHeaders>,
    /// Security features
    #[serde(skip_serializing_if = "Option::is_none")]
    pub security: Option<RouteSecurity>,
    /// Human-readable name for this route
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Description of the route's purpose
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Controls matching order (higher = matched first)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<i32>,
    /// Arbitrary tags for categorization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,
    /// Whether the route is active (default: true)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
impl RouteConfig {
    /// A route is enabled unless explicitly set to `Some(false)`.
    pub fn is_enabled(&self) -> bool {
        !matches!(self.enabled, Some(false))
    }

    /// Matching priority; an unset priority counts as 0.
    pub fn effective_priority(&self) -> i32 {
        self.priority.unwrap_or_default()
    }

    /// All ports this route listens on, expanded from the port range spec.
    pub fn listening_ports(&self) -> Vec<u16> {
        self.route_match.ports.to_ports()
    }

    /// The effective TLS mode: the action-level setting wins; otherwise the
    /// first target's TLS override is consulted; `None` if neither is set.
    pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
        let action = &self.action;
        action
            .tls
            .as_ref()
            .or_else(|| {
                action
                    .targets
                    .as_ref()
                    .and_then(|targets| targets.first())
                    .and_then(|target| target.tls.as_ref())
            })
            .map(|tls| &tls.mode)
    }
}

View File

@@ -0,0 +1,132 @@
use serde::{Deserialize, Serialize};
/// Rate limiting configuration.
/// Matches TypeScript: `IRouteRateLimit`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteRateLimit {
    /// Master switch for rate limiting on this route.
    pub enabled: bool,
    /// Maximum requests allowed within one window.
    pub max_requests: u64,
    /// Time window in seconds
    pub window: u64,
    /// What the limit is keyed on (IP, path, or a header value).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub key_by: Option<RateLimitKeyBy>,
    /// Header name used when `key_by` is `Header`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub header_name: Option<String>,
    /// Custom message returned when the limit is exceeded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_message: Option<String>,
}
/// Rate limit key selection.
/// Serialized lowercase: "ip", "path", "header".
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RateLimitKeyBy {
    Ip,
    Path,
    Header,
}
/// Authentication type.
/// Serialized lowercase: "basic", "digest", "oauth", "jwt".
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AuthenticationType {
    Basic,
    Digest,
    Oauth,
    Jwt,
}
/// Authentication credentials.
/// NOTE(review): password is stored/transported in plain text in the config —
/// confirm whether hashed values are expected here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthCredentials {
    pub username: String,
    pub password: String,
}
/// Authentication options.
/// Matches TypeScript: `IRouteAuthentication`
///
/// Which optional fields are relevant depends on `auth_type`: `credentials`/
/// `realm` for Basic/Digest, `jwt_*` for Jwt, `oauth_*` for Oauth.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteAuthentication {
    /// Which authentication scheme applies.
    #[serde(rename = "type")]
    pub auth_type: AuthenticationType,
    /// Accepted user/password pairs (Basic/Digest).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub credentials: Option<Vec<AuthCredentials>>,
    /// Realm announced in the auth challenge.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub realm: Option<String>,
    /// Secret used to verify JWT signatures.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwt_secret: Option<String>,
    /// Expected `iss` claim for JWTs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwt_issuer: Option<String>,
    /// OAuth provider identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oauth_provider: Option<String>,
    /// OAuth client id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oauth_client_id: Option<String>,
    /// OAuth client secret.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oauth_client_secret: Option<String>,
    /// OAuth redirect URI.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oauth_redirect_uri: Option<String>,
    /// Free-form scheme-specific extras.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<serde_json::Value>,
}
/// Basic auth configuration.
/// Matches TypeScript: `IRouteSecurity.basicAuth`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BasicAuthConfig {
    /// Master switch for basic auth on this route.
    pub enabled: bool,
    /// Accepted user/password pairs.
    pub users: Vec<AuthCredentials>,
    /// Realm announced in the `WWW-Authenticate` challenge.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub realm: Option<String>,
    /// Path patterns that bypass authentication.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude_paths: Option<Vec<String>>,
}
/// JWT auth configuration.
/// Matches TypeScript: `IRouteSecurity.jwtAuth`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JwtAuthConfig {
    /// Master switch for JWT auth on this route.
    pub enabled: bool,
    /// Shared secret (or key material) used to verify token signatures.
    pub secret: String,
    /// Signature algorithm name (e.g. "HS256"); verifier default applies when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub algorithm: Option<String>,
    /// Expected `iss` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,
    /// Expected `aud` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audience: Option<String>,
    /// Token lifetime — unit not stated here; presumably seconds, TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_in: Option<u64>,
    /// Path patterns that bypass authentication.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude_paths: Option<Vec<String>>,
}
/// Security options for routes.
/// Matches TypeScript: `IRouteSecurity`
///
/// NOTE(review): both the generic `authentication` field and the dedicated
/// `basic_auth`/`jwt_auth` blocks exist, mirroring the TS interface — confirm
/// precedence when more than one is configured.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteSecurity {
    /// IP addresses that are allowed to connect
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ip_allow_list: Option<Vec<String>>,
    /// IP addresses that are blocked from connecting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ip_block_list: Option<Vec<String>>,
    /// Maximum concurrent connections
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_connections: Option<u64>,
    /// Authentication configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authentication: Option<RouteAuthentication>,
    /// Rate limiting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate_limit: Option<RouteRateLimit>,
    /// Basic auth
    #[serde(skip_serializing_if = "Option::is_none")]
    pub basic_auth: Option<BasicAuthConfig>,
    /// JWT auth
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwt_auth: Option<JwtAuthConfig>,
}

View File

@@ -0,0 +1,93 @@
use serde::{Deserialize, Serialize};
/// TLS handling modes for route configurations.
/// Matches TypeScript: `type TTlsMode = 'passthrough' | 'terminate' | 'terminate-and-reencrypt'`
/// Serialized kebab-case to match the TS string union.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    /// Forward encrypted bytes untouched (SNI-based routing only).
    Passthrough,
    /// Terminate TLS at the proxy; talk plaintext to the backend.
    Terminate,
    /// Terminate TLS, then open a new TLS session to the backend.
    TerminateAndReencrypt,
}
/// Static certificate configuration (PEM-encoded).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CertificateConfig {
    /// PEM-encoded private key
    pub key: String,
    /// PEM-encoded certificate
    pub cert: String,
    /// PEM-encoded CA chain
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca: Option<String>,
    /// Path to key file (overrides key)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub key_file: Option<String>,
    /// Path to cert file (overrides cert)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cert_file: Option<String>,
}
/// Certificate specification: either automatic (ACME) or static.
/// Matches TypeScript: `certificate?: 'auto' | { key, cert, ca?, keyFile?, certFile? }`
///
/// Untagged: a JSON object deserializes to `Static`; a bare string lands in
/// `Auto` — only the literal "auto" is treated as automatic (see `is_auto`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CertificateSpec {
    /// Use ACME (Let's Encrypt) for automatic provisioning
    Auto(String), // "auto"
    /// Static certificate configuration
    Static(CertificateConfig),
}
impl CertificateSpec {
    /// True only for the `Auto` variant holding the exact string "auto";
    /// any other string (or a static config) is not automatic.
    pub fn is_auto(&self) -> bool {
        match self {
            CertificateSpec::Auto(value) => value == "auto",
            CertificateSpec::Static(_) => false,
        }
    }
}
/// ACME configuration for automatic certificate provisioning.
/// Matches TypeScript: `IRouteAcme`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteAcme {
    /// Contact email for ACME account
    pub email: String,
    /// Use production ACME servers (default: false)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_production: Option<bool>,
    /// Port for HTTP-01 challenges (default: 80)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub challenge_port: Option<u16>,
    /// Days before expiry to renew (default: 30)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub renew_before_days: Option<u32>,
}
/// TLS configuration for route actions.
/// Matches TypeScript: `IRouteTls`
///
/// Only `mode` is required; certificate/ACME fields apply when the proxy
/// terminates TLS.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteTls {
    /// TLS mode (passthrough, terminate, terminate-and-reencrypt)
    pub mode: TlsMode,
    /// Certificate configuration (auto or static)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub certificate: Option<CertificateSpec>,
    /// ACME options when certificate is 'auto'
    #[serde(skip_serializing_if = "Option::is_none")]
    pub acme: Option<RouteAcme>,
    /// Allowed TLS versions
    #[serde(skip_serializing_if = "Option::is_none")]
    pub versions: Option<Vec<String>>,
    /// OpenSSL cipher string
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ciphers: Option<String>,
    /// Use server's cipher preferences
    #[serde(skip_serializing_if = "Option::is_none")]
    pub honor_cipher_order: Option<bool>,
    /// TLS session timeout in seconds
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_timeout: Option<u64>,
}

View File

@@ -0,0 +1,158 @@
use thiserror::Error;
use crate::route_types::{RouteConfig, RouteActionType};
/// Validation errors for route configurations.
///
/// Each variant carries the route's display name (name, falling back to id,
/// falling back to "unnamed") so messages can identify the offending route.
#[derive(Debug, Error)]
pub enum ValidationError {
    #[error("Route '{name}' has no targets but action type is 'forward'")]
    MissingTargets { name: String },
    #[error("Route '{name}' has empty targets list")]
    EmptyTargets { name: String },
    #[error("Route '{name}' has no ports specified")]
    NoPorts { name: String },
    #[error("Route '{name}' port {port} is invalid (must be 1-65535)")]
    InvalidPort { name: String, port: u16 },
    // NOTE(review): this variant is not produced by the validators visible in
    // this file — confirm it is emitted by the JSON loading path.
    #[error("Route '{name}': socket-handler action type is not supported in JSON config")]
    SocketHandlerInJson { name: String },
    #[error("Route '{name}': duplicate route ID '{id}'")]
    DuplicateId { name: String, id: String },
    /// Escape hatch for validations added by other components.
    #[error("Route '{name}': {message}")]
    Custom { name: String, message: String },
}
/// Validate a single route configuration.
///
/// Collects every problem rather than stopping at the first one. Returns
/// `Ok(())` for a clean route, or `Err` with all errors found, in order:
/// missing ports, invalid (zero) ports, then missing/empty forward targets.
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
    // Display name: explicit name, else id, else a placeholder.
    let display_name = route
        .name
        .clone()
        .or_else(|| route.id.clone())
        .unwrap_or_else(|| "unnamed".to_string());
    let mut problems = Vec::new();

    // Port checks: at least one port, and none may be 0 (u16 caps the top end).
    let ports = route.listening_ports();
    if ports.is_empty() {
        problems.push(ValidationError::NoPorts {
            name: display_name.clone(),
        });
    }
    problems.extend(
        ports
            .iter()
            .filter(|&&port| port == 0)
            .map(|&port| ValidationError::InvalidPort {
                name: display_name.clone(),
                port,
            }),
    );

    // A 'forward' action must have a non-empty target list.
    if route.action.action_type == RouteActionType::Forward {
        match route.action.targets.as_deref() {
            None => problems.push(ValidationError::MissingTargets {
                name: display_name.clone(),
            }),
            Some([]) => problems.push(ValidationError::EmptyTargets {
                name: display_name.clone(),
            }),
            Some(_) => {}
        }
    }

    if problems.is_empty() {
        Ok(())
    } else {
        Err(problems)
    }
}
/// Validate an entire list of routes.
///
/// Adds a cross-route check for duplicate IDs on top of the per-route
/// validation, and aggregates every error from every route.
pub fn validate_routes(routes: &[RouteConfig]) -> Result<(), Vec<ValidationError>> {
    let mut problems = Vec::new();
    let mut seen_ids = std::collections::HashSet::new();

    for route in routes {
        // A repeated id is reported against the later occurrence.
        if let Some(id) = &route.id {
            let first_occurrence = seen_ids.insert(id.clone());
            if !first_occurrence {
                problems.push(ValidationError::DuplicateId {
                    name: route.name.clone().unwrap_or_else(|| id.clone()),
                    id: id.clone(),
                });
            }
        }
        // Fold in this route's own validation errors, if any.
        if let Err(route_problems) = validate_route(route) {
            problems.extend(route_problems);
        }
    }

    if problems.is_empty() {
        Ok(())
    } else {
        Err(problems)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::route_types::*;

    /// Baseline known-good route mutated by the individual tests below.
    fn make_valid_route() -> RouteConfig {
        crate::helpers::create_http_route("example.com", "localhost", 8080)
    }

    #[test]
    fn test_valid_route_passes() {
        assert!(validate_route(&make_valid_route()).is_ok());
    }

    #[test]
    fn test_missing_targets() {
        let mut route = make_valid_route();
        route.action.targets = None;
        let problems = validate_route(&route).unwrap_err();
        let reported = problems
            .iter()
            .any(|e| matches!(e, ValidationError::MissingTargets { .. }));
        assert!(reported);
    }

    #[test]
    fn test_empty_targets() {
        let mut route = make_valid_route();
        route.action.targets = Some(Vec::new());
        let problems = validate_route(&route).unwrap_err();
        let reported = problems
            .iter()
            .any(|e| matches!(e, ValidationError::EmptyTargets { .. }));
        assert!(reported);
    }

    #[test]
    fn test_invalid_port_zero() {
        let mut route = make_valid_route();
        route.route_match.ports = PortRange::Single(0);
        let problems = validate_route(&route).unwrap_err();
        let reported = problems
            .iter()
            .any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. }));
        assert!(reported);
    }

    #[test]
    fn test_duplicate_ids() {
        let mut first = make_valid_route();
        first.id = Some("route-1".to_string());
        let mut second = make_valid_route();
        second.id = Some("route-1".to_string());
        let problems = validate_routes(&[first, second]).unwrap_err();
        let reported = problems
            .iter()
            .any(|e| matches!(e, ValidationError::DuplicateId { .. }));
        assert!(reported);
    }

    #[test]
    fn test_multiple_errors_collected() {
        // One route with two independent defects must report both.
        let mut route = make_valid_route();
        route.action.targets = None; // MissingTargets
        route.route_match.ports = PortRange::Single(0); // InvalidPort
        let problems = validate_route(&route).unwrap_err();
        assert!(problems.len() >= 2);
    }
}

View File

@@ -0,0 +1,24 @@
[package]
name = "rustproxy-http"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Hyper-based HTTP proxy service for RustProxy"

[dependencies]
# Workspace-internal crates.
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-security = { workspace = true }
rustproxy-metrics = { workspace = true }
# HTTP engine and body/byte plumbing.
hyper = { workspace = true }
hyper-util = { workspace = true }
# NOTE(review): `regex` and `anyhow` do not appear in the visible
# [workspace.dependencies] table of the root Cargo.toml — confirm they are
# declared there, otherwise `workspace = true` fails to resolve at build time.
regex = { workspace = true }
http-body-util = { workspace = true }
bytes = { workspace = true }
# Async runtime, logging, error handling, shared state.
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
arc-swap = { workspace = true }
dashmap = { workspace = true }

View File

@@ -0,0 +1,14 @@
//! # rustproxy-http
//!
//! Hyper-based HTTP proxy service for RustProxy.
//! Handles HTTP request parsing, route-based forwarding, and response filtering.

pub mod proxy_service;
pub mod request_filter;
pub mod response_filter;
pub mod template;
pub mod upstream_selector;

// Re-export the main service types at the crate root.
// NOTE(review): `request_filter` and `response_filter` are not glob
// re-exported here — callers address them by module path; confirm this is
// deliberate.
pub use proxy_service::*;
pub use template::*;
pub use upstream_selector::*;

View File

@@ -0,0 +1,827 @@
//! Hyper-based HTTP proxy service.
//!
//! Accepts decrypted TCP streams (from TLS termination or plain TCP),
//! parses HTTP requests, matches routes, and forwards to upstream backends.
//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade.
use std::collections::HashMap;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use regex::Regex;
use tokio::net::TcpStream;
use tracing::{debug, error, info, warn};
use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;
use crate::request_filter::RequestFilter;
use crate::response_filter::ResponseFilter;
use crate::upstream_selector::UpstreamSelector;
/// HTTP proxy service that processes HTTP traffic.
///
/// Shared across connections via `Arc`; bundles the route table, the metrics
/// sink, and the upstream selector that tracks per-backend connection counts.
pub struct HttpProxyService {
    // Route table consulted to match each incoming request.
    route_manager: Arc<RouteManager>,
    // Connection open/close metrics sink.
    metrics: Arc<MetricsCollector>,
    // Picks a backend and tracks in-flight connections per upstream key.
    upstream_selector: UpstreamSelector,
}
impl HttpProxyService {
/// Build a service over the given route table and metrics sink, with a
/// fresh (empty) upstream selector.
pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
    let upstream_selector = UpstreamSelector::new();
    Self {
        route_manager,
        metrics,
        upstream_selector,
    }
}
/// Handle an incoming HTTP connection on a plain TCP stream.
///
/// Thin convenience wrapper: delegates to `handle_io`, which accepts any
/// async IO type (plain TCP here; TLS-terminated streams elsewhere).
pub async fn handle_connection(
    self: Arc<Self>,
    stream: TcpStream,
    peer_addr: std::net::SocketAddr,
    port: u16,
) {
    self.handle_io(stream, peer_addr, port).await;
}
/// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated).
///
/// Uses HTTP/1.1 with upgrade support. For clients that negotiate HTTP/2,
/// use `handle_io_auto` instead.
pub async fn handle_io<I>(
self: Arc<Self>,
stream: I,
peer_addr: std::net::SocketAddr,
port: u16,
)
where
I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
let io = TokioIo::new(stream);
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
let svc = Arc::clone(&self);
let peer = peer_addr;
async move {
svc.handle_request(req, peer, port).await
}
});
// Use http1::Builder with upgrades for WebSocket support
let conn = hyper::server::conn::http1::Builder::new()
.keep_alive(true)
.serve_connection(io, service)
.with_upgrades();
if let Err(e) = conn.await {
debug!("HTTP connection error from {}: {}", peer_addr, e);
}
}
/// Handle a single HTTP request.
///
/// Pipeline: CORS preflight → route match → security filters → canned
/// test-response / static-file short-circuits → upstream selection →
/// WebSocket upgrade or HTTP/1.1/HTTP/2 forwarding.
///
/// Upstream failures are mapped to 502 responses; the `Err(hyper::Error)`
/// return path is for failures on the client-facing connection. Metrics
/// pairing invariant: every `connection_opened(route_id)` below is matched
/// by a `connection_closed` on each early-return path (the WebSocket branch
/// defers its close to the spawned tunnel task).
async fn handle_request(
    &self,
    req: Request<Incoming>,
    peer_addr: std::net::SocketAddr,
    port: u16,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    // Host header without any `:port` suffix, used for domain matching.
    let host = req.headers()
        .get("host")
        .and_then(|h| h.to_str().ok())
        .map(|h| {
            // Strip port from host header
            h.split(':').next().unwrap_or(h).to_string()
        });
    let path = req.uri().path().to_string();
    let method = req.method().clone();
    // Extract headers for matching
    // (duplicate header names collapse to one entry; non-UTF-8 values become "")
    let headers: HashMap<String, String> = req.headers()
        .iter()
        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
        .collect();
    debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr);
    // Check for CORS preflight
    // NOTE(review): preflight is answered before route matching, so it does
    // not consult per-route CORS config here — confirm RequestFilter covers that.
    if method == hyper::Method::OPTIONS {
        if let Some(response) = RequestFilter::handle_cors_preflight(&req) {
            return Ok(response);
        }
    }
    // Match route
    let ctx = rustproxy_routing::MatchContext {
        port,
        domain: host.as_deref(),
        path: Some(&path),
        client_ip: Some(&peer_addr.ip().to_string()),
        tls_version: None,
        headers: Some(&headers),
        is_tls: false,
    };
    let route_match = match self.route_manager.find_route(&ctx) {
        Some(rm) => rm,
        None => {
            debug!("No route matched for HTTP request to {:?}{}", host, path);
            return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched"));
        }
    };
    let route_id = route_match.route.id.as_deref();
    // From here on, every return path must call connection_closed(route_id)
    // (except the WebSocket branch — see note below).
    self.metrics.connection_opened(route_id);
    // Apply request filters (IP check, rate limiting, auth)
    if let Some(ref security) = route_match.route.security {
        if let Some(response) = RequestFilter::apply(security, &req, &peer_addr) {
            self.metrics.connection_closed(route_id);
            return Ok(response);
        }
    }
    // Check for test response (returns immediately, no upstream needed)
    if let Some(ref advanced) = route_match.route.action.advanced {
        if let Some(ref test_response) = advanced.test_response {
            self.metrics.connection_closed(route_id);
            return Ok(Self::build_test_response(test_response));
        }
    }
    // Check for static file serving
    if let Some(ref advanced) = route_match.route.action.advanced {
        if let Some(ref static_files) = advanced.static_files {
            self.metrics.connection_closed(route_id);
            return Ok(Self::serve_static_file(&path, static_files));
        }
    }
    // Select upstream
    let target = match route_match.target {
        Some(t) => t,
        None => {
            self.metrics.connection_closed(route_id);
            return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available"));
        }
    };
    // The selector's started/ended pair tracks in-flight connections per
    // "host:port" key (used by least-connections balancing).
    let upstream = self.upstream_selector.select(target, &peer_addr, port);
    let upstream_key = format!("{}:{}", upstream.host, upstream.port);
    self.upstream_selector.connection_started(&upstream_key);
    // Check for WebSocket upgrade
    let is_websocket = req.headers()
        .get("upgrade")
        .and_then(|v| v.to_str().ok())
        .map(|v| v.eq_ignore_ascii_case("websocket"))
        .unwrap_or(false);
    if is_websocket {
        let result = self.handle_websocket_upgrade(
            req, peer_addr, &upstream, route_match.route, route_id, &upstream_key,
        ).await;
        // Note: for WebSocket, connection_ended is called inside
        // the spawned tunnel task when the connection closes.
        return result;
    }
    // Determine backend protocol
    let use_h2 = route_match.route.action.options.as_ref()
        .and_then(|o| o.backend_protocol.as_ref())
        .map(|p| *p == rustproxy_config::BackendProtocol::Http2)
        .unwrap_or(false);
    // Build the upstream path (path + query), applying URL rewriting if configured
    let upstream_path = {
        let raw_path = match req.uri().query() {
            Some(q) => format!("{}?{}", path, q),
            None => path.clone(),
        };
        Self::apply_url_rewrite(&raw_path, &route_match.route)
    };
    // Build upstream request - stream body instead of buffering
    let (parts, body) = req.into_parts();
    // Apply request headers from route config
    // (invalid header names/values from config are silently skipped)
    let mut upstream_headers = parts.headers.clone();
    if let Some(ref route_headers) = route_match.route.headers {
        if let Some(ref request_headers) = route_headers.request {
            for (key, value) in request_headers {
                if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
                    if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
                        upstream_headers.insert(name, val);
                    }
                }
            }
        }
    }
    // Connect to upstream
    let upstream_stream = match TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)).await {
        Ok(s) => s,
        Err(e) => {
            error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
            self.upstream_selector.connection_ended(&upstream_key);
            self.metrics.connection_closed(route_id);
            return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
        }
    };
    // Disable Nagle for lower proxy latency; failure is non-fatal.
    upstream_stream.set_nodelay(true).ok();
    let io = TokioIo::new(upstream_stream);
    let result = if use_h2 {
        // HTTP/2 backend
        self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
    } else {
        // HTTP/1.1 backend (default)
        self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
    };
    self.upstream_selector.connection_ended(&upstream_key);
    result
}
/// Forward request to backend via HTTP/1.1 with body streaming.
async fn forward_h1(
&self,
io: TokioIo<TcpStream>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
Ok(h) => h,
Err(e) => {
error!("Upstream handshake failed: {}", e);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed"));
}
};
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("Upstream connection error: {}", e);
}
});
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_path)
.version(parts.version);
if let Some(headers) = upstream_req.headers_mut() {
*headers = upstream_headers;
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
&format!("{}:{}", upstream.host, upstream.port)
) {
headers.insert(hyper::header::HOST, host_val);
}
}
// Stream the request body through to upstream
let upstream_req = upstream_req.body(body).unwrap();
let upstream_response = match sender.send_request(upstream_req).await {
Ok(resp) => resp,
Err(e) => {
error!("Upstream request failed: {}", e);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed"));
}
};
self.build_streaming_response(upstream_response, route, route_id).await
}
/// Forward request to backend via HTTP/2 with body streaming.
async fn forward_h2(
&self,
io: TokioIo<TcpStream>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let exec = hyper_util::rt::TokioExecutor::new();
let (mut sender, conn) = match hyper::client::conn::http2::handshake(exec, io).await {
Ok(h) => h,
Err(e) => {
error!("HTTP/2 upstream handshake failed: {}", e);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed"));
}
};
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("HTTP/2 upstream connection error: {}", e);
}
});
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_path);
if let Some(headers) = upstream_req.headers_mut() {
*headers = upstream_headers;
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
&format!("{}:{}", upstream.host, upstream.port)
) {
headers.insert(hyper::header::HOST, host_val);
}
}
// Stream the request body through to upstream
let upstream_req = upstream_req.body(body).unwrap();
let upstream_response = match sender.send_request(upstream_req).await {
Ok(resp) => resp,
Err(e) => {
error!("HTTP/2 upstream request failed: {}", e);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed"));
}
};
self.build_streaming_response(upstream_response, route, route_id).await
}
/// Build the client-facing response from an upstream response, streaming the body.
///
/// Copies the upstream status and headers, lets `ResponseFilter` apply the
/// route's response-header/CORS configuration, and forwards the body
/// unbuffered. Also closes the per-route metrics connection.
async fn build_streaming_response(
    &self,
    upstream_response: Response<Incoming>,
    route: &rustproxy_config::RouteConfig,
    route_id: Option<&str>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    // Request/response accounting for this route ends here.
    self.metrics.connection_closed(route_id);
    let (upstream_parts, upstream_body) = upstream_response.into_parts();
    let mut builder = Response::builder().status(upstream_parts.status);
    if let Some(out_headers) = builder.headers_mut() {
        *out_headers = upstream_parts.headers;
        ResponseFilter::apply_headers(route, out_headers, None);
    }
    // Hand the upstream body straight to the client without buffering.
    Ok(builder.body(BoxBody::new(upstream_body)).unwrap())
}
/// Handle a WebSocket upgrade request.
///
/// Performs a manual HTTP/1.1 upgrade handshake against the upstream by
/// writing the raw request bytes and reading the raw response headers, so
/// that the same TCP stream can then be reused directly as the tunnel.
/// On any early error this method closes both the upstream-selector
/// accounting and the per-route metrics connection; on success those are
/// closed by the spawned tunnel task when the connection ends.
async fn handle_websocket_upgrade(
    &self,
    req: Request<Incoming>,
    peer_addr: std::net::SocketAddr,
    upstream: &crate::upstream_selector::UpstreamSelection,
    route: &rustproxy_config::RouteConfig,
    route_id: Option<&str>,
    upstream_key: &str,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    // Get WebSocket config from route
    let ws_config = route.action.websocket.as_ref();
    // Check allowed origins if configured. An empty list means "no
    // restriction"; "*" matches any origin.
    if let Some(ws) = ws_config {
        if let Some(ref allowed_origins) = ws.allowed_origins {
            let origin = req.headers()
                .get("origin")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("");
            if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) {
                self.upstream_selector.connection_ended(upstream_key);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed"));
            }
        }
    }
    info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);
    let mut upstream_stream = match TcpStream::connect(
        format!("{}:{}", upstream.host, upstream.port)
    ).await {
        Ok(s) => s,
        Err(e) => {
            error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
            self.upstream_selector.connection_ended(upstream_key);
            self.metrics.connection_closed(route_id);
            return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
        }
    };
    // Low-latency tunnel: disable Nagle. Failure to set is non-fatal.
    upstream_stream.set_nodelay(true).ok();
    let path = req.uri().path().to_string();
    // Build the request-target for the upstream: path + query, unless a
    // fixed rewrite_path is configured, which replaces it wholesale.
    let upstream_path = {
        let raw = match req.uri().query() {
            Some(q) => format!("{}?{}", path, q),
            None => path,
        };
        // Apply rewrite_path if configured
        if let Some(ws) = ws_config {
            if let Some(ref rewrite_path) = ws.rewrite_path {
                rewrite_path.clone()
            } else {
                raw
            }
        } else {
            raw
        }
    };
    let (parts, _body) = req.into_parts();
    // Hand-build the raw HTTP/1.1 upgrade request; hyper's client
    // connection types are bypassed on purpose so the socket stays ours.
    let mut raw_request = format!(
        "{} {} HTTP/1.1\r\n",
        parts.method, upstream_path
    );
    let upstream_host = format!("{}:{}", upstream.host, upstream.port);
    for (name, value) in parts.headers.iter() {
        if name == hyper::header::HOST {
            // Rewrite Host to point at the upstream.
            raw_request.push_str(&format!("host: {}\r\n", upstream_host));
        } else {
            // Non-UTF8 header values are forwarded as empty strings.
            raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or("")));
        }
    }
    // Append configured per-route request headers.
    if let Some(ref route_headers) = route.headers {
        if let Some(ref request_headers) = route_headers.request {
            for (key, value) in request_headers {
                raw_request.push_str(&format!("{}: {}\r\n", key, value));
            }
        }
    }
    // Apply WebSocket custom headers
    if let Some(ws) = ws_config {
        if let Some(ref custom_headers) = ws.custom_headers {
            for (key, value) in custom_headers {
                raw_request.push_str(&format!("{}: {}\r\n", key, value));
            }
        }
    }
    raw_request.push_str("\r\n");
    if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await {
        error!("WebSocket: failed to send upgrade request to upstream: {}", e);
        self.upstream_selector.connection_ended(upstream_key);
        self.metrics.connection_closed(route_id);
        return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed"));
    }
    // Read the upstream's response headers one byte at a time, stopping at
    // the blank line (CRLFCRLF), so no tunnel bytes are consumed from the
    // stream. Capped at 8 KiB to guard against a misbehaving upstream.
    let mut response_buf = Vec::with_capacity(4096);
    let mut temp = [0u8; 1];
    loop {
        match upstream_stream.read(&mut temp).await {
            Ok(0) => {
                error!("WebSocket: upstream closed before completing handshake");
                self.upstream_selector.connection_ended(upstream_key);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed"));
            }
            Ok(_) => {
                response_buf.push(temp[0]);
                if response_buf.len() >= 4 {
                    let len = response_buf.len();
                    if response_buf[len-4..] == *b"\r\n\r\n" {
                        break;
                    }
                }
                if response_buf.len() > 8192 {
                    error!("WebSocket: upstream response headers too large");
                    self.upstream_selector.connection_ended(upstream_key);
                    self.metrics.connection_closed(route_id);
                    return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large"));
                }
            }
            Err(e) => {
                error!("WebSocket: failed to read upstream response: {}", e);
                self.upstream_selector.connection_ended(upstream_key);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed"));
            }
        }
    }
    // Parse the status code out of the status line; anything but 101 is
    // relayed to the client as a rejection.
    let response_str = String::from_utf8_lossy(&response_buf);
    let status_line = response_str.lines().next().unwrap_or("");
    let status_code = status_line
        .split_whitespace()
        .nth(1)
        .and_then(|s| s.parse::<u16>().ok())
        .unwrap_or(0);
    if status_code != 101 {
        debug!("WebSocket: upstream rejected upgrade with status {}", status_code);
        self.upstream_selector.connection_ended(upstream_key);
        self.metrics.connection_closed(route_id);
        return Ok(error_response(
            StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY),
            "WebSocket upgrade rejected by backend",
        ));
    }
    // Mirror the upstream's 101 response headers back to the client,
    // skipping any that fail to parse.
    let mut client_resp = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS);
    if let Some(resp_headers) = client_resp.headers_mut() {
        for line in response_str.lines().skip(1) {
            let line = line.trim();
            if line.is_empty() {
                break;
            }
            if let Some((name, value)) = line.split_once(':') {
                let name = name.trim();
                let value = value.trim();
                if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
                    if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
                        resp_headers.insert(header_name, header_value);
                    }
                }
            }
        }
    }
    // Register for the client-side upgrade; it completes only after the
    // 101 response below has been written to the client.
    let on_client_upgrade = hyper::upgrade::on(
        Request::from_parts(parts, http_body_util::Empty::<Bytes>::new())
    );
    let metrics = Arc::clone(&self.metrics);
    let route_id_owned = route_id.map(|s| s.to_string());
    let upstream_selector = self.upstream_selector.clone();
    let upstream_key_owned = upstream_key.to_string();
    tokio::spawn(async move {
        let client_upgraded = match on_client_upgrade.await {
            Ok(upgraded) => upgraded,
            Err(e) => {
                debug!("WebSocket: client upgrade failed: {}", e);
                upstream_selector.connection_ended(&upstream_key_owned);
                if let Some(ref rid) = route_id_owned {
                    metrics.connection_closed(Some(rid.as_str()));
                }
                return;
            }
        };
        let client_io = TokioIo::new(client_upgraded);
        let (mut cr, mut cw) = tokio::io::split(client_io);
        let (mut ur, mut uw) = tokio::io::split(upstream_stream);
        // client -> upstream pump; returns total bytes forwarded.
        let c2u = tokio::spawn(async move {
            let mut buf = vec![0u8; 65536];
            let mut total = 0u64;
            loop {
                let n = match cr.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => n,
                };
                if uw.write_all(&buf[..n]).await.is_err() {
                    break;
                }
                total += n as u64;
            }
            // Half-close so the upstream sees EOF when the client is done.
            let _ = uw.shutdown().await;
            total
        });
        // upstream -> client pump; returns total bytes forwarded.
        let u2c = tokio::spawn(async move {
            let mut buf = vec![0u8; 65536];
            let mut total = 0u64;
            loop {
                let n = match ur.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => n,
                };
                if cw.write_all(&buf[..n]).await.is_err() {
                    break;
                }
                total += n as u64;
            }
            let _ = cw.shutdown().await;
            total
        });
        let bytes_in = c2u.await.unwrap_or(0);
        let bytes_out = u2c.await.unwrap_or(0);
        debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
        // Tunnel done: release accounting that was opened by the caller.
        upstream_selector.connection_ended(&upstream_key_owned);
        if let Some(ref rid) = route_id_owned {
            metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()));
            metrics.connection_closed(Some(rid.as_str()));
        }
    });
    // Return 101 to the client immediately so the upgrade can proceed.
    let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
        http_body_util::Empty::<Bytes>::new().map_err(|never| match never {})
    );
    Ok(client_resp.body(body).unwrap())
}
/// Build a test response from config (no upstream connection needed).
///
/// Falls back to 200 OK for an invalid configured status; configured
/// headers that are not valid HTTP names/values are silently skipped.
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response<BoxBody<Bytes, hyper::Error>> {
    let status = StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK);
    let mut builder = Response::builder().status(status);
    if let Some(header_map) = builder.headers_mut() {
        for (key, value) in &config.headers {
            let name = hyper::header::HeaderName::from_bytes(key.as_bytes());
            let val = hyper::header::HeaderValue::from_str(value);
            if let (Ok(name), Ok(val)) = (name, val) {
                header_map.insert(name, val);
            }
        }
    }
    let payload = Full::new(Bytes::from(config.body.clone()))
        .map_err(|never| match never {});
    builder.body(BoxBody::new(payload)).unwrap()
}
/// Apply URL rewriting rules from route config.
///
/// When no rewrite is configured, or when the configured pattern does not
/// compile, the input path is returned unchanged (with a warning in the
/// latter case).
fn apply_url_rewrite(path: &str, route: &rustproxy_config::RouteConfig) -> String {
    let Some(rewrite) = route.action.advanced.as_ref().and_then(|a| a.url_rewrite.as_ref()) else {
        return path.to_string();
    };
    // With only_rewrite_path the query string is preserved verbatim and
    // only the part before '?' is run through the regex.
    let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) {
        if let Some((p, q)) = path.split_once('?') {
            (p.to_string(), format!("?{}", q))
        } else {
            (path.to_string(), String::new())
        }
    } else {
        (path.to_string(), String::new())
    };
    let re = match Regex::new(&rewrite.pattern) {
        Ok(re) => re,
        Err(e) => {
            warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e);
            return path.to_string();
        }
    };
    let rewritten = re.replace_all(&subject, rewrite.target.as_str());
    format!("{}{}", rewritten, suffix)
}
/// Serve a static file from the configured directory.
///
/// Resolves the request path under `config.root`, tries the configured
/// index files (default `index.html`) for directory requests, and rejects
/// any resolved path that escapes the root. When `follow_symlinks` is
/// explicitly false, paths that traverse a symlink are rejected too.
fn serve_static_file(
    path: &str,
    config: &rustproxy_config::RouteStaticFiles,
) -> Response<BoxBody<Bytes, hyper::Error>> {
    use std::path::Path;
    let root = Path::new(&config.root);
    // Sanitize path to prevent directory traversal before joining.
    let clean_path = path.trim_start_matches('/');
    let clean_path = clean_path.replace("..", "");
    let mut file_path = root.join(&clean_path);
    // If path points to a directory, try index files
    if file_path.is_dir() || clean_path.is_empty() {
        let index_files = config.index_files.as_deref()
            .or(config.index.as_deref())
            .unwrap_or(&[]);
        let default_index = vec!["index.html".to_string()];
        let index_files = if index_files.is_empty() { &default_index } else { index_files };
        let mut found = false;
        for index in index_files {
            let candidate = if clean_path.is_empty() {
                root.join(index)
            } else {
                file_path.join(index)
            };
            if candidate.is_file() {
                file_path = candidate;
                found = true;
                break;
            }
        }
        if !found {
            return error_response(StatusCode::NOT_FOUND, "Not found");
        }
    }
    // Ensure the resolved path is within the root (prevent traversal).
    let canonical_root = match root.canonicalize() {
        Ok(p) => p,
        Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
    };
    let canonical_file = match file_path.canonicalize() {
        Ok(p) => p,
        Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
    };
    if !canonical_file.starts_with(&canonical_root) {
        return error_response(StatusCode::FORBIDDEN, "Forbidden");
    }
    // Symlink check: the canonical file must equal the canonical root plus
    // the relative path — any difference means a symlink was traversed.
    // (The previous comparison of canonical_file against the raw file_path
    // always differed for non-canonical roots, rejecting every file.)
    if config.follow_symlinks == Some(false) {
        let expected = match file_path.strip_prefix(root) {
            Ok(rel) => canonical_root.join(rel),
            // file_path is always built by joining onto root; treat a
            // failed strip as a traversal attempt and deny.
            Err(_) => return error_response(StatusCode::FORBIDDEN, "Forbidden"),
        };
        if canonical_file != expected {
            return error_response(StatusCode::FORBIDDEN, "Forbidden");
        }
    }
    // Read the file and build the response.
    match std::fs::read(&file_path) {
        Ok(content) => {
            let content_type = guess_content_type(&file_path);
            let mut response = Response::builder()
                .status(StatusCode::OK)
                .header("Content-Type", content_type);
            // Apply cache-control if configured
            if let Some(ref cache_control) = config.cache_control {
                response = response.header("Cache-Control", cache_control.as_str());
            }
            // Apply custom headers
            if let Some(ref headers) = config.headers {
                for (key, value) in headers {
                    response = response.header(key.as_str(), value.as_str());
                }
            }
            let body = Full::new(Bytes::from(content))
                .map_err(|never| match never {});
            response.body(BoxBody::new(body)).unwrap()
        }
        Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"),
    }
}
}
/// Guess MIME content type from a file extension (case-insensitive).
///
/// Unknown or missing extensions fall back to `application/octet-stream`.
fn guess_content_type(path: &std::path::Path) -> &'static str {
    // Lowercase so `.HTML`, `.Jpg` etc. map correctly too.
    let ext = match path.extension().and_then(|e| e.to_str()) {
        Some(e) => e.to_ascii_lowercase(),
        None => return "application/octet-stream",
    };
    match ext.as_str() {
        "html" | "htm" => "text/html; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" => "application/javascript; charset=utf-8",
        "json" => "application/json; charset=utf-8",
        "xml" => "application/xml; charset=utf-8",
        "txt" => "text/plain; charset=utf-8",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "webp" => "image/webp",
        "avif" => "image/avif",
        "mp4" => "video/mp4",
        "webm" => "video/webm",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "pdf" => "application/pdf",
        "wasm" => "application/wasm",
        _ => "application/octet-stream",
    }
}
impl Default for HttpProxyService {
fn default() -> Self {
Self {
route_manager: Arc::new(RouteManager::new(vec![])),
metrics: Arc::new(MetricsCollector::new()),
upstream_selector: UpstreamSelector::new(),
}
}
}
/// Build a plain-text error response with the given status and message.
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
    let payload = Full::new(Bytes::from(message.to_string()))
        .map_err(|never| match never {});
    let builder = Response::builder()
        .status(status)
        .header("Content-Type", "text/plain");
    builder.body(BoxBody::new(payload)).unwrap()
}

View File

@@ -0,0 +1,263 @@
//! Request filtering: security checks, auth, CORS preflight.
use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
/// Stateless request-level security filter: IP allow/block lists, rate
/// limiting, basic auth, JWT auth, and CORS preflight handling.
pub struct RequestFilter;

impl RequestFilter {
    /// Apply security filters. Returns Some(response) if the request should be blocked.
    ///
    /// Convenience wrapper around [`Self::apply_with_rate_limiter`] without
    /// a shared rate limiter.
    pub fn apply(
        security: &RouteSecurity,
        req: &Request<Incoming>,
        peer_addr: &SocketAddr,
    ) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
        Self::apply_with_rate_limiter(security, req, peer_addr, None)
    }

    /// Apply security filters with an optional shared rate limiter.
    /// Returns Some(response) if the request should be blocked.
    ///
    /// Checks run in order: IP allow/block lists, rate limiting, then
    /// basic auth and JWT auth. Each auth mechanism honors its own
    /// `exclude_paths` list.
    pub fn apply_with_rate_limiter(
        security: &RouteSecurity,
        req: &Request<Incoming>,
        peer_addr: &SocketAddr,
        rate_limiter: Option<&Arc<RateLimiter>>,
    ) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
        let client_ip = peer_addr.ip();
        let request_path = req.uri().path();
        // IP filter: only constructed when at least one list is configured.
        if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
            let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
            let block = security.ip_block_list.as_deref().unwrap_or(&[]);
            let filter = IpFilter::new(allow, block);
            // NOTE(review): normalize_ip presumably unmaps IPv4-mapped IPv6
            // addresses — confirm against rustproxy-security.
            let normalized = IpFilter::normalize_ip(&client_ip);
            if !filter.is_allowed(&normalized) {
                return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
            }
        }
        // Rate limiting
        if let Some(ref rate_limit_config) = security.rate_limit {
            if rate_limit_config.enabled {
                // Use shared rate limiter if provided, otherwise create ephemeral one
                let should_block = if let Some(limiter) = rate_limiter {
                    let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
                    !limiter.check(&key)
                } else {
                    // Create a per-check limiter (less ideal but works for non-shared case)
                    // NOTE(review): a fresh limiter only ever observes this one
                    // request, so this branch presumably never blocks unless
                    // max_requests is 0 — pass a shared limiter for real
                    // enforcement. Confirm RateLimiter::check semantics.
                    let limiter = RateLimiter::new(
                        rate_limit_config.max_requests,
                        rate_limit_config.window,
                    );
                    let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
                    !limiter.check(&key)
                };
                if should_block {
                    let message = rate_limit_config.error_message
                        .as_deref()
                        .unwrap_or("Rate limit exceeded");
                    return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
                }
            }
        }
        // Check exclude paths before auth
        let should_skip_auth = Self::path_matches_exclude_list(request_path, security);
        if !should_skip_auth {
            // Basic auth
            if let Some(ref basic_auth) = security.basic_auth {
                if basic_auth.enabled {
                    // Check basic auth exclude paths
                    let skip_basic = basic_auth.exclude_paths.as_ref()
                        .map(|paths| Self::path_matches_any(request_path, paths))
                        .unwrap_or(false);
                    if !skip_basic {
                        let users: Vec<(String, String)> = basic_auth.users.iter()
                            .map(|c| (c.username.clone(), c.password.clone()))
                            .collect();
                        let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
                        let auth_header = req.headers()
                            .get("authorization")
                            .and_then(|v| v.to_str().ok());
                        // Missing and invalid credentials both yield 401 with
                        // a WWW-Authenticate challenge, but different bodies.
                        match auth_header {
                            Some(header) => {
                                if validator.validate(header).is_none() {
                                    return Some(Response::builder()
                                        .status(StatusCode::UNAUTHORIZED)
                                        .header("WWW-Authenticate", validator.www_authenticate())
                                        .body(boxed_body("Invalid credentials"))
                                        .unwrap());
                                }
                            }
                            None => {
                                return Some(Response::builder()
                                    .status(StatusCode::UNAUTHORIZED)
                                    .header("WWW-Authenticate", validator.www_authenticate())
                                    .body(boxed_body("Authentication required"))
                                    .unwrap());
                            }
                        }
                    }
                }
            }
            // JWT auth
            if let Some(ref jwt_auth) = security.jwt_auth {
                if jwt_auth.enabled {
                    // Check JWT auth exclude paths
                    let skip_jwt = jwt_auth.exclude_paths.as_ref()
                        .map(|paths| Self::path_matches_any(request_path, paths))
                        .unwrap_or(false);
                    if !skip_jwt {
                        let validator = JwtValidator::new(
                            &jwt_auth.secret,
                            jwt_auth.algorithm.as_deref(),
                            jwt_auth.issuer.as_deref(),
                            jwt_auth.audience.as_deref(),
                        );
                        let auth_header = req.headers()
                            .get("authorization")
                            .and_then(|v| v.to_str().ok());
                        // extract_token pulls the bearer token out of the
                        // Authorization header value.
                        match auth_header.and_then(JwtValidator::extract_token) {
                            Some(token) => {
                                if validator.validate(token).is_err() {
                                    return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
                                }
                            }
                            None => {
                                return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
                            }
                        }
                    }
                }
            }
        }
        // All configured checks passed: let the request through.
        None
    }

    /// Check if a request path matches any pattern in the exclude list.
    fn path_matches_exclude_list(_path: &str, _security: &RouteSecurity) -> bool {
        // No global exclude paths on RouteSecurity currently,
        // but we check per-auth exclude paths above.
        // This can be extended if a global exclude_paths is added.
        false
    }

    /// Check if a path matches any pattern in the list.
    /// Supports simple glob patterns: `/health*` matches `/health`, `/healthz`, `/health/check`
    fn path_matches_any(path: &str, patterns: &[String]) -> bool {
        for pattern in patterns {
            if pattern.ends_with('*') {
                // Trailing '*' means prefix match.
                let prefix = &pattern[..pattern.len() - 1];
                if path.starts_with(prefix) {
                    return true;
                }
            } else if path == pattern {
                return true;
            }
        }
        false
    }

    /// Determine the rate limit key based on configuration.
    ///
    /// Defaults to keying by client IP; header-based keying falls back to
    /// the client IP when no header name is configured, and to "unknown"
    /// when the configured header is absent or non-UTF8.
    fn rate_limit_key(
        config: &rustproxy_config::RouteRateLimit,
        req: &Request<Incoming>,
        peer_addr: &SocketAddr,
    ) -> String {
        use rustproxy_config::RateLimitKeyBy;
        match config.key_by.as_ref().unwrap_or(&RateLimitKeyBy::Ip) {
            RateLimitKeyBy::Ip => peer_addr.ip().to_string(),
            RateLimitKeyBy::Path => req.uri().path().to_string(),
            RateLimitKeyBy::Header => {
                if let Some(ref header_name) = config.header_name {
                    req.headers()
                        .get(header_name.as_str())
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("unknown")
                        .to_string()
                } else {
                    peer_addr.ip().to_string()
                }
            }
        }
    }

    /// Check IP-based security (for use in passthrough / TCP-level connections).
    /// Returns true if allowed, false if blocked.
    ///
    /// With neither list configured, every IP is allowed.
    pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
        if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
            let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
            let block = security.ip_block_list.as_deref().unwrap_or(&[]);
            let filter = IpFilter::new(allow, block);
            let normalized = IpFilter::normalize_ip(client_ip);
            filter.is_allowed(&normalized)
        } else {
            true
        }
    }

    /// Handle CORS preflight (OPTIONS) requests.
    /// Returns Some(response) if this is a CORS preflight that should be handled.
    ///
    /// A request only counts as a preflight when it is an OPTIONS request
    /// carrying both `Origin` and `Access-Control-Request-Method` headers;
    /// the response echoes the origin and allows common methods/headers.
    pub fn handle_cors_preflight(
        req: &Request<Incoming>,
    ) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
        if req.method() != hyper::Method::OPTIONS {
            return None;
        }
        // Check for CORS preflight indicators
        let has_origin = req.headers().contains_key("origin");
        let has_request_method = req.headers().contains_key("access-control-request-method");
        if !has_origin || !has_request_method {
            return None;
        }
        let origin = req.headers()
            .get("origin")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("*");
        Some(Response::builder()
            .status(StatusCode::NO_CONTENT)
            .header("Access-Control-Allow-Origin", origin)
            .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
            .header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
            .header("Access-Control-Max-Age", "86400")
            .body(boxed_body(""))
            .unwrap())
    }
}
/// Build a plain-text error response with the given status and message.
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
    let builder = Response::builder()
        .status(status)
        .header("Content-Type", "text/plain");
    builder.body(boxed_body(message)).unwrap()
}
/// Wrap a string slice in a boxed, infallible full body.
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
    let bytes = Bytes::from(data.to_owned());
    let full = Full::new(bytes).map_err(|never| match never {});
    BoxBody::new(full)
}

View File

@@ -0,0 +1,92 @@
//! Response filtering: CORS headers, custom headers, security headers.
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig;
use crate::template::{RequestContext, expand_template};
/// Applies configured response headers and CORS policy to outgoing responses.
pub struct ResponseFilter;

impl ResponseFilter {
    /// Apply response headers from route config and CORS settings.
    /// If a `RequestContext` is provided, template variables in header values will be expanded.
    pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
        let Some(route_headers) = route.headers.as_ref() else { return };
        // Custom response headers; entries with invalid names or values
        // are skipped silently.
        if let Some(response_headers) = route_headers.response.as_ref() {
            for (key, value) in response_headers {
                let Ok(name) = HeaderName::from_bytes(key.as_bytes()) else { continue };
                let expanded = match req_ctx {
                    Some(ctx) => expand_template(value, ctx),
                    None => value.clone(),
                };
                if let Ok(val) = HeaderValue::from_str(&expanded) {
                    headers.insert(name, val);
                }
            }
        }
        // CORS headers, only when explicitly enabled.
        if let Some(cors) = route_headers.cors.as_ref() {
            if cors.enabled {
                Self::apply_cors_headers(cors, headers);
            }
        }
    }

    /// Translate the route's CORS config into response headers.
    fn apply_cors_headers(cors: &rustproxy_config::RouteCors, headers: &mut HeaderMap) {
        // Allow-Origin: defaults to the wildcard when not configured.
        match cors.allow_origin.as_ref() {
            Some(origin) => {
                let origin_str = match origin {
                    rustproxy_config::AllowOrigin::Single(s) => s.clone(),
                    rustproxy_config::AllowOrigin::List(list) => list.join(", "),
                };
                if let Ok(val) = HeaderValue::from_str(&origin_str) {
                    headers.insert("access-control-allow-origin", val);
                }
            }
            None => {
                headers.insert(
                    "access-control-allow-origin",
                    HeaderValue::from_static("*"),
                );
            }
        }
        // Allow-Methods
        if let Some(methods) = cors.allow_methods.as_deref() {
            if let Ok(val) = HeaderValue::from_str(methods) {
                headers.insert("access-control-allow-methods", val);
            }
        }
        // Allow-Headers
        if let Some(allow) = cors.allow_headers.as_deref() {
            if let Ok(val) = HeaderValue::from_str(allow) {
                headers.insert("access-control-allow-headers", val);
            }
        }
        // Allow-Credentials
        if cors.allow_credentials == Some(true) {
            headers.insert(
                "access-control-allow-credentials",
                HeaderValue::from_static("true"),
            );
        }
        // Expose-Headers
        if let Some(expose) = cors.expose_headers.as_deref() {
            if let Ok(val) = HeaderValue::from_str(expose) {
                headers.insert("access-control-expose-headers", val);
            }
        }
        // Max-Age
        if let Some(max_age) = cors.max_age {
            if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) {
                headers.insert("access-control-max-age", val);
            }
        }
    }
}

View File

@@ -0,0 +1,162 @@
//! Header template variable expansion.
//!
//! Supports expanding template variables like `{clientIp}`, `{domain}`, etc.
//! in header values before they are applied to requests or responses.
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
/// Context for template variable expansion.
pub struct RequestContext {
    /// Remote peer IP, pre-rendered as a string.
    pub client_ip: String,
    /// Hostname the request matched on.
    pub domain: String,
    /// Listening port the request arrived on.
    pub port: u16,
    /// Request path (no query string).
    pub path: String,
    /// Name of the matched route.
    pub route_name: String,
    /// Proxy-assigned connection identifier.
    pub connection_id: u64,
}

/// Expand template variables in a header value.
/// Supported variables: {clientIp}, {domain}, {port}, {path}, {routeName}, {connectionId}, {timestamp}
///
/// Unknown variables are left untouched. `{timestamp}` expands to the
/// current Unix time in seconds; the clock is only read when the variable
/// is actually present, so the common templated headers skip the syscall.
pub fn expand_template(template: &str, ctx: &RequestContext) -> String {
    let expanded = template
        .replace("{clientIp}", &ctx.client_ip)
        .replace("{domain}", &ctx.domain)
        .replace("{port}", &ctx.port.to_string())
        .replace("{path}", &ctx.path)
        .replace("{routeName}", &ctx.route_name)
        .replace("{connectionId}", &ctx.connection_id.to_string());
    // {timestamp} is replaced last, matching the original substitution order.
    if !expanded.contains("{timestamp}") {
        return expanded;
    }
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    expanded.replace("{timestamp}", &timestamp.to_string())
}

/// Expand templates in a map of header key-value pairs.
pub fn expand_headers(
    headers: &HashMap<String, String>,
    ctx: &RequestContext,
) -> HashMap<String, String> {
    headers.iter()
        .map(|(k, v)| (k.clone(), expand_template(v, ctx)))
        .collect()
}
#[cfg(test)]
mod tests {
use super::*;
fn test_context() -> RequestContext {
RequestContext {
client_ip: "192.168.1.100".to_string(),
domain: "example.com".to_string(),
port: 443,
path: "/api/v1/users".to_string(),
route_name: "api-route".to_string(),
connection_id: 42,
}
}
#[test]
fn test_expand_client_ip() {
let ctx = test_context();
assert_eq!(expand_template("{clientIp}", &ctx), "192.168.1.100");
}
#[test]
fn test_expand_domain() {
let ctx = test_context();
assert_eq!(expand_template("{domain}", &ctx), "example.com");
}
#[test]
fn test_expand_port() {
let ctx = test_context();
assert_eq!(expand_template("{port}", &ctx), "443");
}
#[test]
fn test_expand_path() {
let ctx = test_context();
assert_eq!(expand_template("{path}", &ctx), "/api/v1/users");
}
#[test]
fn test_expand_route_name() {
let ctx = test_context();
assert_eq!(expand_template("{routeName}", &ctx), "api-route");
}
#[test]
fn test_expand_connection_id() {
let ctx = test_context();
assert_eq!(expand_template("{connectionId}", &ctx), "42");
}
#[test]
fn test_expand_timestamp() {
let ctx = test_context();
let result = expand_template("{timestamp}", &ctx);
// Timestamp should be a valid number
let ts: u64 = result.parse().expect("timestamp should be a number");
// Should be a reasonable Unix timestamp (after 2020)
assert!(ts > 1_577_836_800);
}
#[test]
fn test_expand_mixed_template() {
let ctx = test_context();
let result = expand_template("client={clientIp}, host={domain}:{port}", &ctx);
assert_eq!(result, "client=192.168.1.100, host=example.com:443");
}
#[test]
fn test_expand_no_variables() {
let ctx = test_context();
assert_eq!(expand_template("plain-value", &ctx), "plain-value");
}
#[test]
fn test_expand_empty_string() {
let ctx = test_context();
assert_eq!(expand_template("", &ctx), "");
}
#[test]
fn test_expand_multiple_same_variable() {
let ctx = test_context();
let result = expand_template("{clientIp}-{clientIp}", &ctx);
assert_eq!(result, "192.168.1.100-192.168.1.100");
}
// expand_headers applies template expansion to each header value
// independently; values without placeholders pass through untouched.
#[test]
fn test_expand_headers_map() {
    let ctx = test_context();
    let mut headers = HashMap::new();
    headers.insert("X-Forwarded-For".to_string(), "{clientIp}".to_string());
    headers.insert("X-Route".to_string(), "{routeName}".to_string());
    headers.insert("X-Static".to_string(), "no-template".to_string());
    let result = expand_headers(&headers, &ctx);
    assert_eq!(result.get("X-Forwarded-For").unwrap(), "192.168.1.100");
    assert_eq!(result.get("X-Route").unwrap(), "api-route");
    assert_eq!(result.get("X-Static").unwrap(), "no-template");
}
// Every supported variable (except {timestamp}, which is time-dependent)
// expands correctly within a single template.
#[test]
fn test_expand_all_variables_in_one() {
    let ctx = test_context();
    let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
    let result = expand_template(template, &ctx);
    assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
}
// Unrecognized placeholders are preserved verbatim rather than being
// stripped or causing an error.
#[test]
fn test_expand_unknown_variable_left_as_is() {
    let ctx = test_context();
    let result = expand_template("{unknownVar}", &ctx);
    assert_eq!(result, "{unknownVar}");
}
}

View File

@@ -0,0 +1,222 @@
//! Route-aware upstream selection with load balancing.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::Mutex;
use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
/// Upstream selection result.
pub struct UpstreamSelection {
    /// Hostname or IP of the chosen backend (empty when the target had no hosts).
    pub host: String,
    /// Resolved backend port.
    pub port: u16,
    /// Whether the upstream connection should use TLS (derived from the target's `tls` config).
    pub use_tls: bool,
}
/// Selects upstream backends with load balancing support.
pub struct UpstreamSelector {
    /// Round-robin counters per route (keyed by first target host:port)
    round_robin: Mutex<HashMap<String, AtomicUsize>>,
    /// Active connection counts per host (keyed by "host:port");
    /// shared via Arc so clones observe the same live counts.
    active_connections: Arc<DashMap<String, AtomicU64>>,
}
impl UpstreamSelector {
    /// Create a selector with empty round-robin and connection-count state.
    pub fn new() -> Self {
        Self {
            round_robin: Mutex::new(HashMap::new()),
            active_connections: Arc::new(DashMap::new()),
        }
    }

    /// Select an upstream target based on the route target config and load balancing.
    ///
    /// With zero or one configured hosts the single host (or an empty string)
    /// is returned directly; otherwise the route's load-balancing algorithm
    /// (round-robin when unspecified) picks an index into the host list.
    pub fn select(
        &self,
        target: &RouteTarget,
        client_addr: &SocketAddr,
        incoming_port: u16,
    ) -> UpstreamSelection {
        let hosts = target.host.to_vec();
        let port = target.port.resolve(incoming_port);
        if hosts.len() <= 1 {
            return UpstreamSelection {
                host: hosts.first().map(|s| s.to_string()).unwrap_or_default(),
                port,
                use_tls: target.tls.is_some(),
            };
        }
        // Determine load balancing algorithm (default: round-robin).
        let algorithm = target.load_balancing.as_ref()
            .map(|lb| &lb.algorithm)
            .unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
        let idx = match algorithm {
            LoadBalancingAlgorithm::RoundRobin => {
                self.round_robin_select(&hosts, port)
            }
            LoadBalancingAlgorithm::IpHash => {
                // Hash only the client IP (not the ephemeral port) so the
                // same client consistently maps to the same backend.
                let hash = Self::ip_hash(client_addr);
                hash % hosts.len()
            }
            LoadBalancingAlgorithm::LeastConnections => {
                self.least_connections_select(&hosts, port)
            }
        };
        UpstreamSelection {
            host: hosts[idx].to_string(),
            port,
            use_tls: target.tls.is_some(),
        }
    }

    /// Pick the next host in rotation. The counter is keyed by the first
    /// host:port pair so all selections for the same target share one cursor.
    fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
        let key = format!("{}:{}", hosts[0], port);
        let mut counters = self.round_robin.lock().unwrap();
        let counter = counters
            .entry(key)
            .or_insert_with(|| AtomicUsize::new(0));
        let idx = counter.fetch_add(1, Ordering::Relaxed);
        idx % hosts.len()
    }

    /// Pick the host with the fewest currently-active connections
    /// (hosts with no recorded connections count as zero; ties go to the
    /// earliest host in the list).
    fn least_connections_select(&self, hosts: &[&str], port: u16) -> usize {
        let mut min_conns = u64::MAX;
        let mut min_idx = 0;
        for (i, host) in hosts.iter().enumerate() {
            let key = format!("{}:{}", host, port);
            let conns = self.active_connections
                .get(&key)
                .map(|entry| entry.value().load(Ordering::Relaxed))
                .unwrap_or(0);
            if conns < min_conns {
                min_conns = conns;
                min_idx = i;
            }
        }
        min_idx
    }

    /// Record that a connection to the given host has started.
    pub fn connection_started(&self, host: &str) {
        self.active_connections
            .entry(host.to_string())
            .or_insert_with(|| AtomicU64::new(0))
            .fetch_add(1, Ordering::Relaxed);
    }

    /// Record that a connection to the given host has ended.
    ///
    /// Uses a saturating CAS decrement so an unmatched `connection_ended`
    /// call can never wrap the counter to `u64::MAX` — the previous
    /// fetch_sub-then-store-zero approach left a transient wrapped value
    /// visible to `least_connections_select`.
    pub fn connection_ended(&self, host: &str) {
        if let Some(counter) = self.active_connections.get(host) {
            let _ = counter.value().fetch_update(
                Ordering::Relaxed,
                Ordering::Relaxed,
                |v| Some(v.saturating_sub(1)),
            );
        }
    }

    /// djb2-style hash of the client IP's string form; only used to spread
    /// clients across backends, not for security.
    fn ip_hash(addr: &SocketAddr) -> usize {
        let ip_str = addr.ip().to_string();
        let mut hash: usize = 5381;
        for byte in ip_str.bytes() {
            hash = hash.wrapping_mul(33).wrapping_add(byte as usize);
        }
        hash
    }
}
// Default is equivalent to `new()`: empty state.
impl Default for UpstreamSelector {
    fn default() -> Self {
        Self::new()
    }
}
// Deliberately not a derived Clone: clones share the live per-host
// connection counts (Arc) but start with a fresh round-robin cursor,
// since Mutex/AtomicUsize state is not meaningfully clonable.
impl Clone for UpstreamSelector {
    fn clone(&self) -> Self {
        Self {
            round_robin: Mutex::new(HashMap::new()),
            active_connections: Arc::clone(&self.active_connections),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustproxy_config::*;

    /// Build a minimal RouteTarget with the given hosts and a fixed port;
    /// every optional routing feature is disabled.
    fn make_target(hosts: Vec<&str>, port: u16) -> RouteTarget {
        RouteTarget {
            target_match: None,
            host: if hosts.len() == 1 {
                HostSpec::Single(hosts[0].to_string())
            } else {
                HostSpec::List(hosts.iter().map(|s| s.to_string()).collect())
            },
            port: PortSpec::Fixed(port),
            tls: None,
            websocket: None,
            load_balancing: None,
            send_proxy_protocol: None,
            headers: None,
            advanced: None,
            priority: None,
        }
    }

    // A single-host target is returned directly with the fixed port resolved.
    #[test]
    fn test_single_host() {
        let selector = UpstreamSelector::new();
        let target = make_target(vec!["backend"], 8080);
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
        let result = selector.select(&target, &addr, 80);
        assert_eq!(result.host, "backend");
        assert_eq!(result.port, 8080);
    }

    // Round-robin cycles deterministically through the host list.
    #[test]
    fn test_round_robin() {
        let selector = UpstreamSelector::new();
        let mut target = make_target(vec!["a", "b", "c"], 8080);
        target.load_balancing = Some(RouteLoadBalancing {
            algorithm: LoadBalancingAlgorithm::RoundRobin,
            health_check: None,
        });
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
        let r1 = selector.select(&target, &addr, 80);
        let r2 = selector.select(&target, &addr, 80);
        let r3 = selector.select(&target, &addr, 80);
        let r4 = selector.select(&target, &addr, 80);
        // Should cycle through a, b, c, a
        assert_eq!(r1.host, "a");
        assert_eq!(r2.host, "b");
        assert_eq!(r3.host, "c");
        assert_eq!(r4.host, "a");
    }

    // IP-hash must be stable: the same client IP always gets the same backend.
    #[test]
    fn test_ip_hash_consistent() {
        let selector = UpstreamSelector::new();
        let mut target = make_target(vec!["a", "b", "c"], 8080);
        target.load_balancing = Some(RouteLoadBalancing {
            algorithm: LoadBalancingAlgorithm::IpHash,
            health_check: None,
        });
        let addr: SocketAddr = "10.0.0.5:1234".parse().unwrap();
        let r1 = selector.select(&target, &addr, 80);
        let r2 = selector.select(&target, &addr, 80);
        // Same IP should always get same backend
        assert_eq!(r1.host, r2.host);
    }
}

View File

@@ -0,0 +1,15 @@
[package]
name = "rustproxy-metrics"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Metrics and throughput tracking for RustProxy"
[dependencies]
dashmap = { workspace = true }
tracing = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }

View File

@@ -0,0 +1,251 @@
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering};
/// Aggregated metrics snapshot.
///
/// Serialized with camelCase keys for consumption by the TypeScript side.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Metrics {
    /// Connections currently open.
    pub active_connections: u64,
    /// Connections opened since the collector was created.
    pub total_connections: u64,
    /// Cumulative bytes received.
    pub bytes_in: u64,
    /// Cumulative bytes sent.
    pub bytes_out: u64,
    /// Inbound throughput; filled in by the caller (snapshot() leaves it 0).
    pub throughput_in_bytes_per_sec: u64,
    /// Outbound throughput; filled in by the caller (snapshot() leaves it 0).
    pub throughput_out_bytes_per_sec: u64,
    /// Per-route metrics keyed by route id.
    pub routes: std::collections::HashMap<String, RouteMetrics>,
}
/// Per-route metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteMetrics {
    /// Connections currently open for this route.
    pub active_connections: u64,
    /// Connections ever opened for this route.
    pub total_connections: u64,
    /// Cumulative bytes received on this route.
    pub bytes_in: u64,
    /// Cumulative bytes sent on this route.
    pub bytes_out: u64,
    /// Inbound throughput; filled in by the caller (snapshot() leaves it 0).
    pub throughput_in_bytes_per_sec: u64,
    /// Outbound throughput; filled in by the caller (snapshot() leaves it 0).
    pub throughput_out_bytes_per_sec: u64,
}
/// Statistics snapshot.
///
/// Coarse process-level stats; populated by the caller, not by
/// `MetricsCollector` itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Statistics {
    pub active_connections: u64,
    pub total_connections: u64,
    /// Number of configured routes.
    pub routes_count: u64,
    /// Ports the proxy is currently bound to.
    pub listening_ports: Vec<u16>,
    pub uptime_seconds: u64,
}
/// Metrics collector tracking connections and throughput.
///
/// All counters are lock-free atomics; per-route maps use DashMap so the
/// collector can be shared across tasks without external locking.
pub struct MetricsCollector {
    active_connections: AtomicU64,
    total_connections: AtomicU64,
    total_bytes_in: AtomicU64,
    total_bytes_out: AtomicU64,
    /// Per-route active connection counts
    route_connections: DashMap<String, AtomicU64>,
    /// Per-route total connection counts
    route_total_connections: DashMap<String, AtomicU64>,
    /// Per-route byte counters
    route_bytes_in: DashMap<String, AtomicU64>,
    route_bytes_out: DashMap<String, AtomicU64>,
}
impl MetricsCollector {
    /// Create a collector with every counter at zero.
    pub fn new() -> Self {
        Self {
            active_connections: AtomicU64::new(0),
            total_connections: AtomicU64::new(0),
            total_bytes_in: AtomicU64::new(0),
            total_bytes_out: AtomicU64::new(0),
            route_connections: DashMap::new(),
            route_total_connections: DashMap::new(),
            route_bytes_in: DashMap::new(),
            route_bytes_out: DashMap::new(),
        }
    }

    /// Record a new connection, optionally attributed to a route.
    pub fn connection_opened(&self, route_id: Option<&str>) {
        self.active_connections.fetch_add(1, Ordering::Relaxed);
        self.total_connections.fetch_add(1, Ordering::Relaxed);
        if let Some(route_id) = route_id {
            self.route_connections
                .entry(route_id.to_string())
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(1, Ordering::Relaxed);
            self.route_total_connections
                .entry(route_id.to_string())
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Record a connection closing.
    ///
    /// Both the global and per-route active counters use a saturating CAS
    /// decrement so an unmatched `connection_closed` can never wrap a
    /// counter to `u64::MAX`. (Previously the global counter decremented
    /// unconditionally, and the per-route guard was a racy load-then-sub.)
    pub fn connection_closed(&self, route_id: Option<&str>) {
        let _ = self.active_connections.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |v| Some(v.saturating_sub(1)),
        );
        if let Some(route_id) = route_id {
            if let Some(counter) = self.route_connections.get(route_id) {
                let _ = counter.fetch_update(
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                    |v| Some(v.saturating_sub(1)),
                );
            }
        }
    }

    /// Record bytes transferred, globally and (optionally) per route.
    pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64, route_id: Option<&str>) {
        self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
        self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
        if let Some(route_id) = route_id {
            self.route_bytes_in
                .entry(route_id.to_string())
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(bytes_in, Ordering::Relaxed);
            self.route_bytes_out
                .entry(route_id.to_string())
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(bytes_out, Ordering::Relaxed);
        }
    }

    /// Get current active connection count.
    pub fn active_connections(&self) -> u64 {
        self.active_connections.load(Ordering::Relaxed)
    }

    /// Get total connection count.
    pub fn total_connections(&self) -> u64 {
        self.total_connections.load(Ordering::Relaxed)
    }

    /// Get total bytes received.
    pub fn total_bytes_in(&self) -> u64 {
        self.total_bytes_in.load(Ordering::Relaxed)
    }

    /// Get total bytes sent.
    pub fn total_bytes_out(&self) -> u64 {
        self.total_bytes_out.load(Ordering::Relaxed)
    }

    /// Get a full metrics snapshot including per-route data.
    ///
    /// Throughput fields are left at 0 here; they are computed elsewhere
    /// (the ThroughputTracker owns rate data). The snapshot is not a single
    /// atomic read — counters may move between loads — which is acceptable
    /// for monitoring output.
    pub fn snapshot(&self) -> Metrics {
        let mut routes = std::collections::HashMap::new();
        // Iterate total-connection keys: a route appears in the snapshot once
        // it has seen at least one connection.
        for entry in self.route_total_connections.iter() {
            let route_id = entry.key().clone();
            let total = entry.value().load(Ordering::Relaxed);
            let active = self.route_connections
                .get(&route_id)
                .map(|c| c.load(Ordering::Relaxed))
                .unwrap_or(0);
            let bytes_in = self.route_bytes_in
                .get(&route_id)
                .map(|c| c.load(Ordering::Relaxed))
                .unwrap_or(0);
            let bytes_out = self.route_bytes_out
                .get(&route_id)
                .map(|c| c.load(Ordering::Relaxed))
                .unwrap_or(0);
            routes.insert(route_id, RouteMetrics {
                active_connections: active,
                total_connections: total,
                bytes_in,
                bytes_out,
                throughput_in_bytes_per_sec: 0,
                throughput_out_bytes_per_sec: 0,
            });
        }
        Metrics {
            active_connections: self.active_connections(),
            total_connections: self.total_connections(),
            bytes_in: self.total_bytes_in(),
            bytes_out: self.total_bytes_out(),
            throughput_in_bytes_per_sec: 0,
            throughput_out_bytes_per_sec: 0,
            routes,
        }
    }
}
// Default is equivalent to `new()`: all counters zeroed.
impl Default for MetricsCollector {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh collector reports zero everywhere.
    #[test]
    fn test_initial_state_zeros() {
        let collector = MetricsCollector::new();
        assert_eq!(collector.active_connections(), 0);
        assert_eq!(collector.total_connections(), 0);
    }

    // Opening bumps both the active and cumulative counters.
    #[test]
    fn test_connection_opened_increments() {
        let collector = MetricsCollector::new();
        collector.connection_opened(None);
        assert_eq!(collector.active_connections(), 1);
        assert_eq!(collector.total_connections(), 1);
        collector.connection_opened(None);
        assert_eq!(collector.active_connections(), 2);
        assert_eq!(collector.total_connections(), 2);
    }

    // Closing decrements only the active count; totals are monotonic.
    #[test]
    fn test_connection_closed_decrements() {
        let collector = MetricsCollector::new();
        collector.connection_opened(None);
        collector.connection_opened(None);
        assert_eq!(collector.active_connections(), 2);
        collector.connection_closed(None);
        assert_eq!(collector.active_connections(), 1);
        // total_connections should stay at 2
        assert_eq!(collector.total_connections(), 2);
    }

    // Route-attributed opens/closes also update the global counters.
    #[test]
    fn test_route_specific_tracking() {
        let collector = MetricsCollector::new();
        collector.connection_opened(Some("route-a"));
        collector.connection_opened(Some("route-a"));
        collector.connection_opened(Some("route-b"));
        assert_eq!(collector.active_connections(), 3);
        assert_eq!(collector.total_connections(), 3);
        collector.connection_closed(Some("route-a"));
        assert_eq!(collector.active_connections(), 2);
    }

    // Byte totals accumulate globally and per route; None skips route maps.
    #[test]
    fn test_record_bytes() {
        let collector = MetricsCollector::new();
        collector.record_bytes(100, 200, Some("route-a"));
        collector.record_bytes(50, 75, Some("route-a"));
        collector.record_bytes(25, 30, None);
        let total_in = collector.total_bytes_in.load(Ordering::Relaxed);
        let total_out = collector.total_bytes_out.load(Ordering::Relaxed);
        assert_eq!(total_in, 175);
        assert_eq!(total_out, 305);
        // Route-specific bytes
        let route_in = collector.route_bytes_in.get("route-a").unwrap();
        assert_eq!(route_in.load(Ordering::Relaxed), 150);
    }
}

View File

@@ -0,0 +1,11 @@
//! # rustproxy-metrics
//!
//! Metrics and throughput tracking for RustProxy.

/// 1Hz circular-buffer throughput sampling.
pub mod throughput;
/// Connection and byte counters with per-route breakdown.
pub mod collector;
/// Time-windowed log deduplication/batching.
pub mod log_dedup;

// Flat re-exports so consumers can `use rustproxy_metrics::MetricsCollector` etc.
pub use throughput::*;
pub use collector::*;
pub use log_dedup::*;

View File

@@ -0,0 +1,219 @@
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tracing::info;
/// An aggregated event during the deduplication window.
struct AggregatedEvent {
    /// Category label used as the log prefix.
    category: String,
    /// Message of the first occurrence; later duplicates keep this message.
    first_message: String,
    /// Number of occurrences accumulated in the current window.
    count: AtomicU64,
    /// When the first occurrence was recorded (used for the summary duration).
    first_seen: Instant,
    // Set at creation but never refreshed on later occurrences; kept for a
    // future rapid-event heuristic.
    #[allow(dead_code)]
    last_seen: Instant,
}
/// Log deduplicator that batches similar events over a time window.
///
/// Events are grouped by a composite key of `category:key`. Within each
/// deduplication window (`flush_interval`) identical events are counted
/// instead of being emitted individually. When the window expires (or the
/// batch reaches `max_batch_size`) a single summary line is written via
/// `tracing::info!`.
pub struct LogDeduplicator {
    /// Pending events keyed by "category:key".
    events: DashMap<String, AggregatedEvent>,
    /// How often the background task drains accumulated events.
    flush_interval: Duration,
    /// Per-key occurrence count that forces an immediate flush.
    max_batch_size: u64,
    // Reserved for a rapid-event heuristic; not consulted yet.
    #[allow(dead_code)]
    rapid_threshold: u64, // events/sec that triggers immediate flush
}
impl LogDeduplicator {
    /// Create a deduplicator with a 5s window and a 100-event batch cap.
    pub fn new() -> Self {
        Self {
            events: DashMap::new(),
            flush_interval: Duration::from_secs(5),
            max_batch_size: 100,
            rapid_threshold: 50,
        }
    }

    /// Log an event, deduplicating by `category` + `key`.
    ///
    /// If the batch for this composite key reaches `max_batch_size` the
    /// accumulated events are flushed immediately. Only the first message
    /// seen for a key is retained; later messages just bump the count.
    pub fn log(&self, category: &str, key: &str, message: &str) {
        let map_key = format!("{}:{}", category, key);
        let now = Instant::now();
        let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
            category: category.to_string(),
            first_message: message.to_string(),
            count: AtomicU64::new(0),
            first_seen: now,
            last_seen: now,
        });
        // fetch_add returns the previous value, so +1 gives the new count.
        let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1;
        // Check if we should flush (batch size exceeded)
        if count >= self.max_batch_size {
            // Must release the shard guard before flush(): retain() locks
            // every shard and would deadlock against our own entry guard.
            drop(entry);
            self.flush();
        }
    }

    /// Flush all accumulated events, emitting summary log lines.
    ///
    /// Single events are emitted verbatim; repeated events become one
    /// "[SUMMARY]" line with the count and elapsed window. The map is
    /// emptied regardless.
    pub fn flush(&self) {
        // Collect and remove all events
        self.events.retain(|_key, event| {
            let count = event.count.load(Ordering::Relaxed);
            if count > 0 {
                let elapsed = event.first_seen.elapsed();
                if count == 1 {
                    info!("[{}] {}", event.category, event.first_message);
                } else {
                    info!(
                        "[SUMMARY] {} {} events in {:.1}s: {}",
                        count,
                        event.category,
                        elapsed.as_secs_f64(),
                        event.first_message
                    );
                }
            }
            false // remove all entries after flushing
        });
    }

    /// Start a background flush task that periodically drains accumulated
    /// events. The task runs until the supplied `CancellationToken` is
    /// cancelled, at which point it performs one final flush before exiting.
    pub fn start_flush_task(self: &Arc<Self>, cancel: tokio_util::sync::CancellationToken) {
        let dedup = Arc::clone(self);
        let interval = self.flush_interval;
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = cancel.cancelled() => {
                        // Final drain so nothing is lost on shutdown.
                        dedup.flush();
                        break;
                    }
                    _ = tokio::time::sleep(interval) => {
                        dedup.flush();
                    }
                }
            }
        });
    }
}
// Default is equivalent to `new()`: 5s window, batch cap of 100.
impl Default for LogDeduplicator {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A lone event stays pending until flush, then the map is emptied.
    #[test]
    fn test_single_event_emitted_as_is() {
        let dedup = LogDeduplicator::new();
        dedup.log("conn", "open", "connection opened from 1.2.3.4");
        // One event should exist
        assert_eq!(dedup.events.len(), 1);
        let entry = dedup.events.get("conn:open").unwrap();
        assert_eq!(entry.count.load(Ordering::Relaxed), 1);
        assert_eq!(entry.first_message, "connection opened from 1.2.3.4");
        // Release the shard guard before flush() to avoid self-deadlock.
        drop(entry);
        dedup.flush();
        // After flush, map should be empty
        assert_eq!(dedup.events.len(), 0);
    }

    // Repeats of the same category:key collapse into one counted entry.
    #[test]
    fn test_duplicate_events_aggregated() {
        let dedup = LogDeduplicator::new();
        for _ in 0..10 {
            dedup.log("conn", "timeout", "connection timed out");
        }
        assert_eq!(dedup.events.len(), 1);
        let entry = dedup.events.get("conn:timeout").unwrap();
        assert_eq!(entry.count.load(Ordering::Relaxed), 10);
        drop(entry);
        dedup.flush();
        assert_eq!(dedup.events.len(), 0);
    }

    // Distinct category:key pairs are tracked independently.
    #[test]
    fn test_different_keys_separate() {
        let dedup = LogDeduplicator::new();
        dedup.log("conn", "open", "opened");
        dedup.log("conn", "close", "closed");
        dedup.log("tls", "handshake", "TLS handshake");
        assert_eq!(dedup.events.len(), 3);
        dedup.flush();
        assert_eq!(dedup.events.len(), 0);
    }

    // flush() resets state; the next log() starts a fresh entry whose
    // first_message is the new message.
    #[test]
    fn test_flush_clears_events() {
        let dedup = LogDeduplicator::new();
        dedup.log("a", "b", "msg1");
        dedup.log("a", "b", "msg2");
        dedup.flush();
        assert_eq!(dedup.events.len(), 0);
        // Logging after flush creates a new entry
        dedup.log("a", "b", "msg3");
        assert_eq!(dedup.events.len(), 1);
        let entry = dedup.events.get("a:b").unwrap();
        assert_eq!(entry.count.load(Ordering::Relaxed), 1);
        assert_eq!(entry.first_message, "msg3");
    }

    // Hitting max_batch_size (100) triggers an immediate flush of all events.
    #[test]
    fn test_max_batch_triggers_flush() {
        let dedup = LogDeduplicator::new();
        // max_batch_size defaults to 100
        for i in 0..100 {
            dedup.log("flood", "key", &format!("event {}", i));
        }
        // After hitting max_batch_size the events map should have been flushed
        assert_eq!(dedup.events.len(), 0);
    }

    // Default must match new()'s tuning parameters.
    #[test]
    fn test_default_trait() {
        let dedup = LogDeduplicator::default();
        assert_eq!(dedup.flush_interval, Duration::from_secs(5));
        assert_eq!(dedup.max_batch_size, 100);
    }

    // The spawned task flushes on its interval and stops on cancellation.
    #[tokio::test]
    async fn test_background_flush_task() {
        let dedup = Arc::new(LogDeduplicator {
            events: DashMap::new(),
            flush_interval: Duration::from_millis(50),
            max_batch_size: 100,
            rapid_threshold: 50,
        });
        let cancel = tokio_util::sync::CancellationToken::new();
        dedup.start_flush_task(cancel.clone());
        // Log some events
        dedup.log("bg", "test", "background flush test");
        assert_eq!(dedup.events.len(), 1);
        // Wait for the background task to flush
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(dedup.events.len(), 0);
        // Cancel the task
        cancel.cancel();
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
}

View File

@@ -0,0 +1,173 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
/// A single throughput sample.
#[derive(Debug, Clone, Copy)]
pub struct ThroughputSample {
    /// Wall-clock time the sample was taken (ms since the Unix epoch).
    pub timestamp_ms: u64,
    /// Bytes received during the sample interval.
    pub bytes_in: u64,
    /// Bytes sent during the sample interval.
    pub bytes_out: u64,
}

/// Circular buffer for 1Hz throughput sampling.
/// Matches smartproxy's ThroughputTracker.
pub struct ThroughputTracker {
    /// Circular buffer of samples
    samples: Vec<ThroughputSample>,
    /// Current write index
    write_index: usize,
    /// Number of valid samples
    count: usize,
    /// Maximum number of samples to retain (always >= 1; see `new`)
    capacity: usize,
    /// Accumulated bytes since last sample
    pending_bytes_in: AtomicU64,
    pending_bytes_out: AtomicU64,
    /// When the tracker was created
    created_at: Instant,
}

impl ThroughputTracker {
    /// Create a new tracker with the given capacity (seconds of retention).
    ///
    /// A retention of 0 is clamped to 1: with a zero capacity `sample()`
    /// would panic on `write_index % capacity` (division by zero).
    pub fn new(retention_seconds: usize) -> Self {
        let capacity = retention_seconds.max(1);
        Self {
            samples: Vec::with_capacity(capacity),
            write_index: 0,
            count: 0,
            capacity,
            pending_bytes_in: AtomicU64::new(0),
            pending_bytes_out: AtomicU64::new(0),
            created_at: Instant::now(),
        }
    }

    /// Record bytes (called from data flow callbacks).
    /// Only accumulates; nothing is visible until the next `sample()`.
    pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
        self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
        self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
    }

    /// Take a sample (called at 1Hz).
    ///
    /// Atomically drains the pending byte counters into a new sample,
    /// overwriting the oldest slot once the buffer is full.
    pub fn sample(&mut self) {
        let bytes_in = self.pending_bytes_in.swap(0, Ordering::Relaxed);
        let bytes_out = self.pending_bytes_out.swap(0, Ordering::Relaxed);
        let timestamp_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;
        let sample = ThroughputSample {
            timestamp_ms,
            bytes_in,
            bytes_out,
        };
        if self.samples.len() < self.capacity {
            self.samples.push(sample);
        } else {
            self.samples[self.write_index] = sample;
        }
        self.write_index = (self.write_index + 1) % self.capacity;
        self.count = (self.count + 1).min(self.capacity);
    }

    /// Get average throughput (bytes_in/sec, bytes_out/sec) over the last
    /// N seconds, walking backwards from the most recent sample. The window
    /// is clamped to the number of samples actually taken; returns (0, 0)
    /// when no samples exist.
    pub fn throughput(&self, window_seconds: usize) -> (u64, u64) {
        let window = window_seconds.min(self.count);
        if window == 0 {
            return (0, 0);
        }
        let mut total_in = 0u64;
        let mut total_out = 0u64;
        for i in 0..window {
            // Index of the (i+1)-th most recent sample, wrapping around the ring.
            let idx = if self.write_index >= i + 1 {
                self.write_index - i - 1
            } else {
                self.capacity - (i + 1 - self.write_index)
            };
            if idx < self.samples.len() {
                total_in += self.samples[idx].bytes_in;
                total_out += self.samples[idx].bytes_out;
            }
        }
        (total_in / window as u64, total_out / window as u64)
    }

    /// Get instant throughput (last 1 second).
    pub fn instant(&self) -> (u64, u64) {
        self.throughput(1)
    }

    /// Get recent throughput (last 10 seconds).
    pub fn recent(&self) -> (u64, u64) {
        self.throughput(10)
    }

    /// How long this tracker has been alive.
    pub fn uptime(&self) -> std::time::Duration {
        self.created_at.elapsed()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // No samples yet: throughput reports zero for any window.
    #[test]
    fn test_empty_throughput() {
        let tracker = ThroughputTracker::new(60);
        let (bytes_in, bytes_out) = tracker.throughput(10);
        assert_eq!(bytes_in, 0);
        assert_eq!(bytes_out, 0);
    }

    // One sample: instant() returns exactly what was recorded.
    #[test]
    fn test_single_sample() {
        let mut tracker = ThroughputTracker::new(60);
        tracker.record_bytes(1000, 2000);
        tracker.sample();
        let (bytes_in, bytes_out) = tracker.instant();
        assert_eq!(bytes_in, 1000);
        assert_eq!(bytes_out, 2000);
    }

    // More samples than capacity: ring indexing must stay valid after wrap.
    #[test]
    fn test_circular_buffer_wrap() {
        let mut tracker = ThroughputTracker::new(3); // Small capacity
        for i in 0..5 {
            tracker.record_bytes(i * 100, i * 200);
            tracker.sample();
        }
        // Should still work after wrapping
        let (bytes_in, bytes_out) = tracker.throughput(3);
        assert!(bytes_in > 0);
        assert!(bytes_out > 0);
    }

    // throughput(n) returns the arithmetic mean over the last n samples.
    #[test]
    fn test_window_averaging() {
        let mut tracker = ThroughputTracker::new(60);
        // Record 3 samples of different sizes
        tracker.record_bytes(100, 200);
        tracker.sample();
        tracker.record_bytes(200, 400);
        tracker.sample();
        tracker.record_bytes(300, 600);
        tracker.sample();
        // Average over 3 samples: (100+200+300)/3 = 200, (200+400+600)/3 = 400
        let (avg_in, avg_out) = tracker.throughput(3);
        assert_eq!(avg_in, 200);
        assert_eq!(avg_out, 400);
    }

    // uptime() advances with wall-clock time.
    #[test]
    fn test_uptime_positive() {
        let tracker = ThroughputTracker::new(60);
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(tracker.uptime().as_millis() >= 10);
    }
}

View File

@@ -0,0 +1,17 @@
[package]
name = "rustproxy-nftables"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "NFTables kernel-level forwarding for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
libc = { workspace = true }

View File

@@ -0,0 +1,10 @@
//! # rustproxy-nftables
//!
//! NFTables kernel-level forwarding for RustProxy.
//! Generates and manages nft CLI rules for DNAT/SNAT.

/// Lifecycle management: executes nft commands, tracks applied rules.
pub mod nft_manager;
/// Pure string builders for nft commands (no process execution).
pub mod rule_builder;

// Flat re-exports for crate consumers.
pub use nft_manager::*;
pub use rule_builder::*;

View File

@@ -0,0 +1,238 @@
use thiserror::Error;
use std::collections::HashMap;
use tracing::{debug, info, warn};
/// Errors produced while managing nftables rules.
#[derive(Debug, Error)]
pub enum NftError {
    /// The `nft` binary ran but exited non-zero; payload contains the
    /// failing command and its stderr.
    #[error("nft command failed: {0}")]
    CommandFailed(String),
    /// Spawning or communicating with the `nft` process failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Declared for callers; not currently constructed here (non-root is
    /// handled by skipping kernel operations with a warning instead).
    #[error("Not running as root")]
    NotRoot,
}
/// Manager for nftables rules.
///
/// Executes `nft` CLI commands to manage kernel-level packet forwarding.
/// Requires root privileges; operations are skipped gracefully if not root.
pub struct NftManager {
    /// Name of the managed nftables table (default: "rustproxy").
    table_name: String,
    /// Active rules indexed by route ID
    active_rules: HashMap<String, Vec<String>>,
    /// Whether the table has been initialized
    table_initialized: bool,
}
impl NftManager {
    /// Create a manager; `table_name` defaults to "rustproxy" when `None`.
    /// No kernel state is touched until rules are applied.
    pub fn new(table_name: Option<String>) -> Self {
        Self {
            table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()),
            active_rules: HashMap::new(),
            table_initialized: false,
        }
    }

    /// Check if we are running as root.
    fn is_root() -> bool {
        // SAFETY: geteuid() takes no arguments, has no preconditions and
        // cannot fail.
        unsafe { libc::geteuid() == 0 }
    }

    /// Execute a single nft command via the CLI.
    ///
    /// Returns stdout on success, `NftError::CommandFailed` (with the
    /// command and stderr) on non-zero exit, or `NftError::Io` if the
    /// process could not be spawned.
    ///
    /// NOTE(review): arguments are derived by naive whitespace splitting and
    /// passed directly to `nft` (no shell), so commands containing quoted
    /// strings or embedded spaces cannot be represented — confirm
    /// rule_builder never emits such commands.
    async fn exec_nft(command: &str) -> Result<String, NftError> {
        // The command starts with "nft ", strip it to get the args
        let args = if command.starts_with("nft ") {
            &command[4..]
        } else {
            command
        };
        let output = tokio::process::Command::new("nft")
            .args(args.split_whitespace())
            .output()
            .await
            .map_err(NftError::Io)?;
        if output.status.success() {
            Ok(String::from_utf8_lossy(&output.stdout).to_string())
        } else {
            let stderr = String::from_utf8_lossy(&output.stderr);
            Err(NftError::CommandFailed(format!(
                "Command '{}' failed: {}",
                command, stderr
            )))
        }
    }

    /// Ensure the nftables table and chains are set up.
    /// Idempotent per manager instance: does nothing after the first success.
    async fn ensure_table(&mut self) -> Result<(), NftError> {
        if self.table_initialized {
            return Ok(());
        }
        let setup_commands = crate::rule_builder::build_table_setup(&self.table_name);
        for cmd in &setup_commands {
            Self::exec_nft(cmd).await?;
        }
        self.table_initialized = true;
        info!("NFTables table '{}' initialized", self.table_name);
        Ok(())
    }

    /// Apply rules for a route.
    ///
    /// Executes the nft commands via the CLI. If not running as root,
    /// the rules are stored locally but not applied to the kernel.
    ///
    /// On a mid-list command failure, earlier rules remain applied in the
    /// kernel but the route is not tracked (the error propagates before the
    /// insert below).
    pub async fn apply_rules(&mut self, route_id: &str, rules: Vec<String>) -> Result<(), NftError> {
        if !Self::is_root() {
            warn!("Not running as root, nftables rules will not be applied to kernel");
            self.active_rules.insert(route_id.to_string(), rules);
            return Ok(());
        }
        self.ensure_table().await?;
        for cmd in &rules {
            Self::exec_nft(cmd).await?;
            debug!("Applied nft rule: {}", cmd);
        }
        info!("Applied {} nftables rules for route '{}'", rules.len(), route_id);
        self.active_rules.insert(route_id.to_string(), rules);
        Ok(())
    }

    /// Remove rules for a route.
    ///
    /// Currently removes the route from tracking. To fully remove specific
    /// rules would require handle-based tracking; for now, cleanup() removes
    /// the entire table.
    pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> {
        if let Some(rules) = self.active_rules.remove(route_id) {
            info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id);
        }
        Ok(())
    }

    /// Clean up all managed rules by deleting the entire nftables table.
    /// Individual cleanup-command failures are logged but not fatal
    /// (the table may already be gone).
    pub async fn cleanup(&mut self) -> Result<(), NftError> {
        if !Self::is_root() {
            warn!("Not running as root, skipping nftables cleanup");
            self.active_rules.clear();
            self.table_initialized = false;
            return Ok(());
        }
        if self.table_initialized {
            let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name);
            for cmd in &cleanup_commands {
                match Self::exec_nft(cmd).await {
                    Ok(_) => debug!("Cleanup: {}", cmd),
                    Err(e) => warn!("Cleanup command failed (may be ok): {}", e),
                }
            }
            info!("NFTables table '{}' cleaned up", self.table_name);
        }
        self.active_rules.clear();
        self.table_initialized = false;
        Ok(())
    }

    /// Get the table name.
    pub fn table_name(&self) -> &str {
        &self.table_name
    }

    /// Whether the table has been initialized in the kernel.
    pub fn is_initialized(&self) -> bool {
        self.table_initialized
    }

    /// Get the number of active route rule sets.
    pub fn active_route_count(&self) -> usize {
        self.active_rules.len()
    }

    /// Get the status of all active rules as JSON values keyed by route id,
    /// each with a "ruleCount" and the raw "rules" strings.
    pub fn status(&self) -> HashMap<String, serde_json::Value> {
        let mut status = HashMap::new();
        for (route_id, rules) in &self.active_rules {
            status.insert(
                route_id.clone(),
                serde_json::json!({
                    "ruleCount": rules.len(),
                    "rules": rules,
                }),
            );
        }
        status
    }
}
// These tests assume they are NOT run as root, so apply_rules/cleanup take
// the non-root path and never spawn the `nft` binary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_default_table_name() {
        let mgr = NftManager::new(None);
        assert_eq!(mgr.table_name(), "rustproxy");
        assert!(!mgr.is_initialized());
    }

    #[test]
    fn test_new_custom_table_name() {
        let mgr = NftManager::new(Some("custom".to_string()));
        assert_eq!(mgr.table_name(), "custom");
    }

    #[tokio::test]
    async fn test_apply_rules_non_root() {
        let mut mgr = NftManager::new(None);
        // When not root, rules are stored but not applied to kernel
        let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()];
        mgr.apply_rules("route-1", rules).await.unwrap();
        assert_eq!(mgr.active_route_count(), 1);
        let status = mgr.status();
        assert!(status.contains_key("route-1"));
        assert_eq!(status["route-1"]["ruleCount"], 1);
    }

    #[tokio::test]
    async fn test_remove_rules() {
        let mut mgr = NftManager::new(None);
        let rules = vec!["nft add rule test".to_string()];
        mgr.apply_rules("route-1", rules).await.unwrap();
        assert_eq!(mgr.active_route_count(), 1);
        mgr.remove_rules("route-1").await.unwrap();
        assert_eq!(mgr.active_route_count(), 0);
    }

    // cleanup() clears tracking state even when kernel ops were skipped.
    #[tokio::test]
    async fn test_cleanup_non_root() {
        let mut mgr = NftManager::new(None);
        let rules = vec!["nft add rule test".to_string()];
        mgr.apply_rules("route-1", rules).await.unwrap();
        mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap();
        mgr.cleanup().await.unwrap();
        assert_eq!(mgr.active_route_count(), 0);
        assert!(!mgr.is_initialized());
    }

    // status() reports one JSON entry per tracked route with its rule count.
    #[tokio::test]
    async fn test_status_multiple_routes() {
        let mut mgr = NftManager::new(None);
        mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap();
        mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap();
        let status = mgr.status();
        assert_eq!(status.len(), 2);
        assert_eq!(status["web"]["ruleCount"], 2);
        assert_eq!(status["api"]["ruleCount"], 1);
    }
}

View File

@@ -0,0 +1,123 @@
use rustproxy_config::{NfTablesOptions, NfTablesProtocol};
/// Build nftables DNAT rule for port forwarding.
///
/// Returns the nft command strings for: the DNAT rule itself, a masquerade
/// (SNAT) rule unless `preserve_source_ip` is set, and an optional
/// rate-limit rule when `max_rate` is configured.
///
/// NOTE(review): the rate-limit rule is appended AFTER the DNAT rule in the
/// same chain — since the DNAT rule already accepts/translates matching
/// packets, the limit rule may never match; confirm the intended ordering.
pub fn build_dnat_rule(
    table_name: &str,
    chain_name: &str,
    source_port: u16,
    target_host: &str,
    target_port: u16,
    options: &NfTablesOptions,
) -> Vec<String> {
    // Protocol defaults to TCP when unspecified.
    let protocol = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
        NfTablesProtocol::Tcp => "tcp",
        NfTablesProtocol::Udp => "udp",
        NfTablesProtocol::All => "tcp", // TODO: handle "all" (currently falls back to TCP only)
    };
    let mut rules = Vec::new();
    // DNAT rule
    rules.push(format!(
        "nft add rule ip {} {} {} dport {} dnat to {}:{}",
        table_name, chain_name, protocol, source_port, target_host, target_port,
    ));
    // SNAT rule if preserving source IP is not enabled
    // (matches on the post-DNAT target port in the postrouting chain).
    if !options.preserve_source_ip.unwrap_or(false) {
        rules.push(format!(
            "nft add rule ip {} postrouting {} dport {} masquerade",
            table_name, protocol, target_port,
        ));
    }
    // Rate limiting
    if let Some(max_rate) = &options.max_rate {
        rules.push(format!(
            "nft add rule ip {} {} {} dport {} limit rate {} accept",
            table_name, chain_name, protocol, source_port, max_rate,
        ));
    }
    rules
}
/// Build the initial table and chain setup commands.
///
/// The returned strings are consumed by `NftManager::exec_nft`, which strips
/// the leading "nft " and passes the remaining whitespace-separated tokens
/// straight to the `nft` binary via exec — no shell is involved. The
/// semicolon terminating each chain spec must therefore be a plain `;`:
/// the previous shell-style `\;` escaping reached nft as a literal `\;`
/// token, which its parser rejects (backslash escaping is only needed when
/// the command goes through a shell).
pub fn build_table_setup(table_name: &str) -> Vec<String> {
    vec![
        format!("nft add table ip {}", table_name),
        format!(
            "nft add chain ip {} prerouting {{ type nat hook prerouting priority 0 ; }}",
            table_name
        ),
        format!(
            "nft add chain ip {} postrouting {{ type nat hook postrouting priority 100 ; }}",
            table_name
        ),
    ]
}
/// Build cleanup commands to remove the table.
///
/// Deleting the table removes every chain and rule inside it in one step.
pub fn build_table_cleanup(table_name: &str) -> Vec<String> {
    vec![format!("nft delete table ip {table_name}")]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline options with every field unset.
    fn make_options() -> NfTablesOptions {
        NfTablesOptions {
            preserve_source_ip: None,
            protocol: None,
            max_rate: None,
            priority: None,
            table_name: None,
            use_ip_sets: None,
            use_advanced_nat: None,
        }
    }

    #[test]
    fn test_basic_dnat_rule() {
        let opts = make_options();
        let cmds = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &opts);
        assert!(!cmds.is_empty());
        assert!(cmds[0].contains("dnat to 10.0.0.1:8443"));
        assert!(cmds[0].contains("dport 443"));
    }

    #[test]
    fn test_preserve_source_ip() {
        let opts = NfTablesOptions { preserve_source_ip: Some(true), ..make_options() };
        // Preserving the source IP must suppress the masquerade rule.
        let cmds = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &opts);
        assert!(cmds.iter().all(|c| !c.contains("masquerade")));
    }

    #[test]
    fn test_without_preserve_source_ip() {
        let cmds =
            build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &make_options());
        assert!(cmds.iter().any(|c| c.contains("masquerade")));
    }

    #[test]
    fn test_rate_limited_rule() {
        let opts = NfTablesOptions { max_rate: Some("100/second".to_string()), ..make_options() };
        let cmds = build_dnat_rule("rustproxy", "prerouting", 80, "10.0.0.1", 8080, &opts);
        assert!(cmds.iter().any(|c| c.contains("limit rate 100/second")));
    }

    #[test]
    fn test_table_setup_commands() {
        let cmds = build_table_setup("rustproxy");
        assert_eq!(cmds.len(), 3);
        assert!(cmds[0].contains("add table ip rustproxy"));
        assert!(cmds[1].contains("prerouting"));
        assert!(cmds[2].contains("postrouting"));
    }

    #[test]
    fn test_table_cleanup() {
        let cmds = build_table_cleanup("rustproxy");
        assert_eq!(cmds, vec!["nft delete table ip rustproxy".to_string()]);
    }
}

View File

@@ -0,0 +1,25 @@
[package]
name = "rustproxy-passthrough"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Raw TCP/SNI passthrough engine for RustProxy"

[dependencies]
# Sibling workspace crates.
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-metrics = { workspace = true }
# Async runtime, structured logging, error handling.
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
# NOTE(review): `anyhow` and `tokio-util` are pulled with `workspace = true`
# but do not appear in the root [workspace.dependencies] table accompanying
# this change — confirm they are declared there, otherwise workspace
# dependency resolution fails.
anyhow = { workspace = true }
dashmap = { workspace = true }
arc-swap = { workspace = true }
rustproxy-http = { workspace = true }
# TLS primitives used for SNI / passthrough handling.
rustls = { workspace = true }
tokio-rustls = { workspace = true }
rustls-pemfile = { workspace = true }
tokio-util = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@@ -0,0 +1,155 @@
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Per-connection tracking record with atomics for lock-free updates.
///
/// Each field uses atomics so that the forwarding tasks can update
/// bytes_received / bytes_sent / last_activity without holding any lock,
/// while the zombie scanner reads them concurrently.
pub struct ConnectionRecord {
    /// Unique connection ID assigned by the ConnectionTracker.
    pub id: u64,
    /// Wall-clock instant when this connection was created.
    pub created_at: Instant,
    /// Milliseconds since `created_at` at the moment of the last activity;
    /// written atomically by the forwarding loops via `touch`.
    pub last_activity: AtomicU64,
    /// Total bytes received from the client (inbound).
    pub bytes_received: AtomicU64,
    /// Total bytes sent to the client (outbound / from backend).
    pub bytes_sent: AtomicU64,
    /// True once the client side of the connection has closed.
    pub client_closed: AtomicBool,
    /// True once the backend side of the connection has closed.
    pub backend_closed: AtomicBool,
    /// Whether this connection uses TLS (affects zombie thresholds).
    pub is_tls: AtomicBool,
    /// Whether this connection has keep-alive semantics.
    pub has_keep_alive: AtomicBool,
}

impl ConnectionRecord {
    /// Build a fresh record for connection `id`: counters zeroed,
    /// flags cleared, creation time stamped now.
    pub fn new(id: u64) -> Self {
        Self {
            id,
            created_at: Instant::now(),
            last_activity: AtomicU64::new(0),
            bytes_received: AtomicU64::new(0),
            bytes_sent: AtomicU64::new(0),
            client_closed: AtomicBool::new(false),
            backend_closed: AtomicBool::new(false),
            is_tls: AtomicBool::new(false),
            has_keep_alive: AtomicBool::new(false),
        }
    }

    /// Stamp `last_activity` with the milliseconds elapsed since creation.
    pub fn touch(&self) {
        let now_ms = self.created_at.elapsed().as_millis() as u64;
        self.last_activity.store(now_ms, Ordering::Relaxed);
    }

    /// Account for `n` inbound bytes (from the client) and refresh activity.
    pub fn record_bytes_in(&self, n: u64) {
        self.bytes_received.fetch_add(n, Ordering::Relaxed);
        self.touch();
    }

    /// Account for `n` outbound bytes (to the client) and refresh activity.
    pub fn record_bytes_out(&self, n: u64) {
        self.bytes_sent.fetch_add(n, Ordering::Relaxed);
        self.touch();
    }

    /// Time elapsed since the most recently recorded activity.
    pub fn idle_duration(&self) -> Duration {
        let age_ms = self.created_at.elapsed().as_millis() as u64;
        let last_ms = self.last_activity.load(Ordering::Relaxed);
        Duration::from_millis(age_ms.saturating_sub(last_ms))
    }

    /// Total age of this connection (time since creation).
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_record() {
        let rec = ConnectionRecord::new(42);
        assert_eq!(rec.id, 42);
        // Counters start at zero...
        assert_eq!(rec.bytes_received.load(Ordering::Relaxed), 0);
        assert_eq!(rec.bytes_sent.load(Ordering::Relaxed), 0);
        // ...and every flag starts cleared.
        assert!(!rec.client_closed.load(Ordering::Relaxed));
        assert!(!rec.backend_closed.load(Ordering::Relaxed));
        assert!(!rec.is_tls.load(Ordering::Relaxed));
        assert!(!rec.has_keep_alive.load(Ordering::Relaxed));
    }

    #[test]
    fn test_record_bytes() {
        let rec = ConnectionRecord::new(1);
        for n in [100, 200] {
            rec.record_bytes_in(n);
        }
        assert_eq!(rec.bytes_received.load(Ordering::Relaxed), 300);
        for n in [50, 75] {
            rec.record_bytes_out(n);
        }
        assert_eq!(rec.bytes_sent.load(Ordering::Relaxed), 125);
    }

    #[test]
    fn test_touch_updates_activity() {
        let rec = ConnectionRecord::new(1);
        assert_eq!(rec.last_activity.load(Ordering::Relaxed), 0);
        // Sleep so the elapsed time is measurably nonzero.
        thread::sleep(Duration::from_millis(10));
        rec.touch();
        let activity = rec.last_activity.load(Ordering::Relaxed);
        assert!(activity >= 10, "last_activity should be at least 10ms, got {}", activity);
    }

    #[test]
    fn test_idle_duration() {
        let rec = ConnectionRecord::new(1);
        // With no recorded activity, idleness tracks the record's age.
        thread::sleep(Duration::from_millis(20));
        assert!(rec.idle_duration() >= Duration::from_millis(20));
        // Touching resets idleness to (near) zero.
        rec.touch();
        assert!(rec.idle_duration() < Duration::from_millis(10));
    }

    #[test]
    fn test_age() {
        let rec = ConnectionRecord::new(1);
        thread::sleep(Duration::from_millis(20));
        assert!(rec.age() >= Duration::from_millis(20));
    }

    #[test]
    fn test_flags() {
        let rec = ConnectionRecord::new(1);
        for flag in [&rec.client_closed, &rec.is_tls, &rec.has_keep_alive] {
            flag.store(true, Ordering::Relaxed);
        }
        assert!(rec.client_closed.load(Ordering::Relaxed));
        assert!(!rec.backend_closed.load(Ordering::Relaxed));
        assert!(rec.is_tls.load(Ordering::Relaxed));
        assert!(rec.has_keep_alive.load(Ordering::Relaxed));
    }
}

View File

@@ -0,0 +1,402 @@
use dashmap::DashMap;
use std::collections::VecDeque;
use std::net::IpAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use super::connection_record::ConnectionRecord;
/// Half-zombie threshold for non-TLS connections: one side closed and no
/// activity for at least this long (see `ConnectionTracker::scan_zombies`).
const HALF_ZOMBIE_TIMEOUT_PLAIN: Duration = Duration::from_secs(30);
/// Half-zombie threshold for TLS connections.
const HALF_ZOMBIE_TIMEOUT_TLS: Duration = Duration::from_secs(300);
/// Stuck connection timeout (non-TLS): received data but never sent any.
const STUCK_TIMEOUT_PLAIN: Duration = Duration::from_secs(60);
/// Stuck connection timeout (TLS): received data but never sent any.
const STUCK_TIMEOUT_TLS: Duration = Duration::from_secs(300);
/// Tracks active connections per IP and enforces per-IP limits and rate limiting.
/// Also maintains per-connection records for zombie detection.
pub struct ConnectionTracker {
    /// Active connection counts per IP
    active: DashMap<IpAddr, AtomicU64>,
    /// Connection-attempt timestamps per IP; pruned to a sliding
    /// one-minute window for rate limiting in `try_accept`.
    timestamps: DashMap<IpAddr, VecDeque<Instant>>,
    /// Maximum concurrent connections per IP (None = unlimited)
    max_per_ip: Option<u64>,
    /// Maximum new connections per minute per IP (None = unlimited)
    rate_limit_per_minute: Option<u64>,
    /// Per-connection tracking records for zombie detection, keyed by the
    /// ID handed out from `next_id`.
    connections: DashMap<u64, Arc<ConnectionRecord>>,
    /// Monotonically increasing connection ID counter (starts at 1).
    next_id: AtomicU64,
}
impl ConnectionTracker {
    /// Create a tracker; `None` for either limit disables that check.
    pub fn new(max_per_ip: Option<u64>, rate_limit_per_minute: Option<u64>) -> Self {
        Self {
            active: DashMap::new(),
            timestamps: DashMap::new(),
            max_per_ip,
            rate_limit_per_minute,
            connections: DashMap::new(),
            next_id: AtomicU64::new(1),
        }
    }
    /// Try to accept a new connection from the given IP.
    /// Returns true if allowed, false if over limit.
    ///
    /// NOTE(review): this is a check, not a reservation — the per-IP count
    /// is only incremented by a later `connection_opened` call, so two
    /// concurrent accepts can both pass and briefly exceed `max_per_ip`.
    /// Confirm that slack is acceptable.
    pub fn try_accept(&self, ip: &IpAddr) -> bool {
        // Check per-IP connection limit
        if let Some(max) = self.max_per_ip {
            let count = self.active
                .get(ip)
                .map(|c| c.value().load(Ordering::Relaxed))
                .unwrap_or(0);
            if count >= max {
                return false;
            }
        }
        // Check rate limit
        if let Some(rate_limit) = self.rate_limit_per_minute {
            let now = Instant::now();
            let one_minute = std::time::Duration::from_secs(60);
            // The DashMap entry guard serializes rate bookkeeping per IP.
            let mut entry = self.timestamps.entry(*ip).or_default();
            let timestamps = entry.value_mut();
            // Remove timestamps older than 1 minute (sliding window)
            while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
                timestamps.pop_front();
            }
            if timestamps.len() as u64 >= rate_limit {
                return false;
            }
            // Every accepted attempt consumes rate budget, even if the
            // caller never follows up with `connection_opened`.
            timestamps.push_back(now);
        }
        true
    }
    /// Record that a connection was opened from the given IP.
    pub fn connection_opened(&self, ip: &IpAddr) {
        self.active
            .entry(*ip)
            .or_insert_with(|| AtomicU64::new(0))
            .value()
            .fetch_add(1, Ordering::Relaxed);
    }
    /// Record that a connection was closed from the given IP.
    ///
    /// NOTE(review): an unpaired close makes `fetch_sub` wrap the counter
    /// to u64::MAX, and the decrement-then-remove below can race with a
    /// concurrent `connection_opened` on the same IP — confirm callers
    /// always pair open/close and that the race is tolerable.
    pub fn connection_closed(&self, ip: &IpAddr) {
        if let Some(counter) = self.active.get(ip) {
            let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
            // Clean up zero entries
            if prev <= 1 {
                drop(counter); // release the map guard before removing
                self.active.remove(ip);
            }
        }
    }
    /// Get the current number of active connections for an IP.
    pub fn active_connections(&self, ip: &IpAddr) -> u64 {
        self.active
            .get(ip)
            .map(|c| c.value().load(Ordering::Relaxed))
            .unwrap_or(0)
    }
    /// Get the total number of tracked IPs.
    pub fn tracked_ips(&self) -> usize {
        self.active.len()
    }
    /// Register a new connection and return its tracking record.
    ///
    /// The returned `Arc<ConnectionRecord>` should be passed to the forwarding
    /// loop so it can update bytes / activity atomics in real time.
    pub fn register_connection(&self, is_tls: bool) -> Arc<ConnectionRecord> {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let record = Arc::new(ConnectionRecord::new(id));
        record.is_tls.store(is_tls, Ordering::Relaxed);
        self.connections.insert(id, Arc::clone(&record));
        record
    }
    /// Remove a connection record when the connection is fully closed.
    pub fn unregister_connection(&self, id: u64) {
        self.connections.remove(&id);
    }
    /// Scan all tracked connections and return IDs of zombie connections.
    ///
    /// A connection is considered a zombie in any of these cases:
    /// - **Full zombie**: both `client_closed` and `backend_closed` are true.
    /// - **Half zombie**: one side closed for longer than the threshold
    ///   (5 min for TLS, 30s for non-TLS).
    /// - **Stuck**: `bytes_received > 0` but `bytes_sent == 0` for longer
    ///   than the stuck threshold (5 min for TLS, 60s for non-TLS).
    pub fn scan_zombies(&self) -> Vec<u64> {
        let mut zombies = Vec::new();
        for entry in self.connections.iter() {
            let record = entry.value();
            let id = *entry.key();
            let is_tls = record.is_tls.load(Ordering::Relaxed);
            let client_closed = record.client_closed.load(Ordering::Relaxed);
            let backend_closed = record.backend_closed.load(Ordering::Relaxed);
            let idle = record.idle_duration();
            let bytes_in = record.bytes_received.load(Ordering::Relaxed);
            let bytes_out = record.bytes_sent.load(Ordering::Relaxed);
            // Full zombie: both sides closed
            if client_closed && backend_closed {
                zombies.push(id);
                continue;
            }
            // Half zombie: one side closed for too long
            let half_timeout = if is_tls {
                HALF_ZOMBIE_TIMEOUT_TLS
            } else {
                HALF_ZOMBIE_TIMEOUT_PLAIN
            };
            if (client_closed || backend_closed) && idle >= half_timeout {
                zombies.push(id);
                continue;
            }
            // Stuck: received data but never sent anything for too long
            let stuck_timeout = if is_tls {
                STUCK_TIMEOUT_TLS
            } else {
                STUCK_TIMEOUT_PLAIN
            };
            if bytes_in > 0 && bytes_out == 0 && idle >= stuck_timeout {
                zombies.push(id);
            }
        }
        zombies
    }
    /// Start a background task that periodically scans for zombie connections.
    ///
    /// The scanner runs every 10 seconds and logs any zombies it finds.
    /// It stops when the provided `CancellationToken` is cancelled.
    ///
    /// NOTE(review): zombies are only logged here, never force-closed or
    /// unregistered — confirm cleanup is handled elsewhere.
    pub fn start_zombie_scanner(self: &Arc<Self>, cancel: CancellationToken) {
        let tracker = Arc::clone(self);
        tokio::spawn(async move {
            let interval = Duration::from_secs(10);
            loop {
                tokio::select! {
                    _ = cancel.cancelled() => {
                        debug!("Zombie scanner shutting down");
                        break;
                    }
                    _ = tokio::time::sleep(interval) => {
                        let zombies = tracker.scan_zombies();
                        if !zombies.is_empty() {
                            warn!(
                                "Detected {} zombie connection(s): {:?}",
                                zombies.len(),
                                zombies
                            );
                        }
                    }
                }
            }
        });
    }
    /// Get the total number of tracked connections (with records).
    pub fn total_connections(&self) -> usize {
        self.connections.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse an IP literal; panics only on malformed test input.
    fn addr(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    #[test]
    fn test_basic_tracking() {
        let tracker = ConnectionTracker::new(None, None);
        let ip = addr("127.0.0.1");
        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);
        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 2);
        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);
        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 0);
    }

    #[test]
    fn test_per_ip_limit() {
        let tracker = ConnectionTracker::new(Some(2), None);
        let first = addr("10.0.0.1");
        for _ in 0..2 {
            assert!(tracker.try_accept(&first));
            tracker.connection_opened(&first);
        }
        // A third connection from the same IP is rejected.
        assert!(!tracker.try_accept(&first));
        // A different IP has its own budget.
        assert!(tracker.try_accept(&addr("10.0.0.2")));
    }

    #[test]
    fn test_rate_limit() {
        let tracker = ConnectionTracker::new(None, Some(3));
        let ip = addr("10.0.0.1");
        for _ in 0..3 {
            assert!(tracker.try_accept(&ip));
        }
        // The fourth attempt inside the same minute is rejected.
        assert!(!tracker.try_accept(&ip));
    }

    #[test]
    fn test_no_limits() {
        let tracker = ConnectionTracker::new(None, None);
        let ip = addr("10.0.0.1");
        for _ in 0..1000 {
            assert!(tracker.try_accept(&ip));
            tracker.connection_opened(&ip);
        }
        assert_eq!(tracker.active_connections(&ip), 1000);
    }

    #[test]
    fn test_tracked_ips() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.tracked_ips(), 0);
        tracker.connection_opened(&addr("10.0.0.1"));
        tracker.connection_opened(&addr("10.0.0.2"));
        assert_eq!(tracker.tracked_ips(), 2);
        tracker.connection_closed(&addr("10.0.0.1"));
        assert_eq!(tracker.tracked_ips(), 1);
    }

    #[test]
    fn test_register_unregister_connection() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.total_connections(), 0);
        let plain = tracker.register_connection(false);
        assert_eq!(tracker.total_connections(), 1);
        assert!(!plain.is_tls.load(Ordering::Relaxed));
        let tls = tracker.register_connection(true);
        assert_eq!(tracker.total_connections(), 2);
        assert!(tls.is_tls.load(Ordering::Relaxed));
        // IDs are unique across registrations.
        assert_ne!(plain.id, tls.id);
        tracker.unregister_connection(plain.id);
        assert_eq!(tracker.total_connections(), 1);
        tracker.unregister_connection(tls.id);
        assert_eq!(tracker.total_connections(), 0);
    }

    #[test]
    fn test_full_zombie_detection() {
        let tracker = ConnectionTracker::new(None, None);
        let conn = tracker.register_connection(false);
        // Freshly registered connections are never zombies.
        assert!(tracker.scan_zombies().is_empty());
        // Closing both sides makes it a full zombie immediately.
        conn.client_closed.store(true, Ordering::Relaxed);
        conn.backend_closed.store(true, Ordering::Relaxed);
        assert_eq!(tracker.scan_zombies(), vec![conn.id]);
    }

    #[test]
    fn test_half_zombie_not_triggered_immediately() {
        let tracker = ConnectionTracker::new(None, None);
        let conn = tracker.register_connection(false);
        conn.touch(); // mark activity now
        // One side just closed: still below the half-zombie threshold.
        conn.client_closed.store(true, Ordering::Relaxed);
        assert!(tracker.scan_zombies().is_empty());
    }

    #[test]
    fn test_stuck_connection_not_triggered_immediately() {
        let tracker = ConnectionTracker::new(None, None);
        let conn = tracker.register_connection(false);
        conn.touch(); // mark activity now
        // Inbound-only traffic, but far too recent to count as stuck.
        conn.bytes_received.store(1000, Ordering::Relaxed);
        assert!(tracker.scan_zombies().is_empty());
    }

    #[test]
    fn test_unregister_removes_from_zombie_scan() {
        let tracker = ConnectionTracker::new(None, None);
        let conn = tracker.register_connection(false);
        let id = conn.id;
        // Turn it into a full zombie.
        conn.client_closed.store(true, Ordering::Relaxed);
        conn.backend_closed.store(true, Ordering::Relaxed);
        assert_eq!(tracker.scan_zombies().len(), 1);
        // Unregistering removes it from subsequent scans.
        tracker.unregister_connection(id);
        assert!(tracker.scan_zombies().is_empty());
    }

    #[test]
    fn test_total_connections() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.total_connections(), 0);
        let a = tracker.register_connection(false);
        let b = tracker.register_connection(true);
        let c = tracker.register_connection(false);
        assert_eq!(tracker.total_connections(), 3);
        tracker.unregister_connection(b.id);
        assert_eq!(tracker.total_connections(), 2);
        tracker.unregister_connection(a.id);
        tracker.unregister_connection(c.id);
        assert_eq!(tracker.total_connections(), 0);
    }
}

View File

@@ -0,0 +1,325 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;
use super::connection_record::ConnectionRecord;
/// Statistics for a forwarded connection.
#[derive(Debug, Default)]
pub struct ForwardStats {
    /// Bytes forwarded from the client to the backend (accumulated).
    pub bytes_in: AtomicU64,
    /// Bytes forwarded from the backend to the client (accumulated).
    pub bytes_out: AtomicU64,
}
/// Perform bidirectional TCP forwarding between client and backend.
///
/// This is the core data path for passthrough connections.
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes.
///
/// `initial_data` (bytes already read/peeked from the client, e.g. during
/// SNI sniffing) is written to the backend first and counted toward the
/// client-side byte total.
///
/// NOTE(review): a read/write error in either direction after startup is
/// swallowed — `unwrap_or(0)` below reports that direction as 0 bytes and
/// the function still returns `Ok`. Only a failure writing `initial_data`
/// propagates as `Err`. Confirm this best-effort accounting is intended.
pub async fn forward_bidirectional(
    mut client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
    }
    let (mut client_read, mut client_write) = client.split();
    let (mut backend_read, mut backend_write) = backend.split();
    let client_to_backend = async {
        let mut buf = vec![0u8; 65536];
        // Initial (peeked) bytes count toward the client->backend total.
        let mut total = initial_data.map_or(0u64, |d| d.len() as u64);
        loop {
            let n = client_read.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            backend_write.write_all(&buf[..n]).await?;
            total += n as u64;
        }
        // Half-close toward the backend so it observes EOF.
        backend_write.shutdown().await?;
        Ok::<u64, std::io::Error>(total)
    };
    let backend_to_client = async {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = backend_read.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            client_write.write_all(&buf[..n]).await?;
            total += n as u64;
        }
        client_write.shutdown().await?;
        Ok::<u64, std::io::Error>(total)
    };
    // Both directions are polled concurrently on this task (no spawn).
    let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
    Ok((c2b.unwrap_or(0), b2c.unwrap_or(0)))
}
/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
///
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
pub async fn forward_bidirectional_with_timeouts(
client: TcpStream,
mut backend: TcpStream,
initial_data: Option<&[u8]>,
inactivity_timeout: std::time::Duration,
max_lifetime: std::time::Duration,
cancel: CancellationToken,
) -> std::io::Result<(u64, u64)> {
// Send initial data (peeked bytes) to backend
if let Some(data) = initial_data {
backend.write_all(data).await?;
}
let (mut client_read, mut client_write) = client.into_split();
let (mut backend_read, mut backend_write) = backend.into_split();
let last_activity = Arc::new(AtomicU64::new(0));
let start = std::time::Instant::now();
let la1 = Arc::clone(&last_activity);
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
let c2b = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = initial_len;
loop {
let n = match client_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if backend_write.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = backend_write.shutdown().await;
total
});
let la2 = Arc::clone(&last_activity);
let b2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match backend_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = client_write.shutdown().await;
total
});
// Watchdog: inactivity, max lifetime, and cancellation
let la_watch = Arc::clone(&last_activity);
let c2b_handle = c2b.abort_handle();
let b2c_handle = b2c.abort_handle();
let watchdog = tokio::spawn(async move {
let check_interval = std::time::Duration::from_secs(5);
let mut last_seen = 0u64;
loop {
tokio::select! {
_ = cancel.cancelled() => {
debug!("Connection cancelled by shutdown");
c2b_handle.abort();
b2c_handle.abort();
break;
}
_ = tokio::time::sleep(check_interval) => {
// Check max lifetime
if start.elapsed() >= max_lifetime {
debug!("Connection exceeded max lifetime, closing");
c2b_handle.abort();
b2c_handle.abort();
break;
}
// Check inactivity
let current = la_watch.load(Ordering::Relaxed);
if current == last_seen {
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
c2b_handle.abort();
b2c_handle.abort();
break;
}
}
last_seen = current;
}
}
}
});
let bytes_in = c2b.await.unwrap_or(0);
let bytes_out = b2c.await.unwrap_or(0);
watchdog.abort();
Ok((bytes_in, bytes_out))
}
/// Forward bidirectionally and fold the resulting byte counts into `stats`.
///
/// Thin wrapper around `forward_bidirectional`; counts are *added to*
/// `stats` (not overwritten), so one `ForwardStats` can aggregate several
/// connections.
pub async fn forward_bidirectional_with_stats(
    client: TcpStream,
    backend: TcpStream,
    initial_data: Option<&[u8]>,
    stats: Arc<ForwardStats>,
) -> std::io::Result<()> {
    let (from_client, from_backend) =
        forward_bidirectional(client, backend, initial_data).await?;
    stats.bytes_in.fetch_add(from_client, Ordering::Relaxed);
    stats.bytes_out.fetch_add(from_backend, Ordering::Relaxed);
    Ok(())
}
/// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts,
/// updating a `ConnectionRecord` with byte counts and activity timestamps
/// in real time for zombie detection.
///
/// When `record` is `None`, this behaves identically to
/// `forward_bidirectional_with_timeouts`.
///
/// The record's `client_closed` / `backend_closed` flags are set when the
/// respective copy loop terminates, giving the zombie scanner visibility
/// into half-open connections.
///
/// NOTE(review): the watchdog polls every 5 seconds, so `inactivity_timeout`
/// and `max_lifetime` are enforced with up to ~5s of slack. When the
/// watchdog (or `cancel`) aborts the copy tasks, their join results are
/// errors and the corresponding byte counts are reported as 0 via
/// `unwrap_or(0)` at the bottom — confirm that loss of counts on abort is
/// acceptable.
pub async fn forward_bidirectional_with_record(
    client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
    inactivity_timeout: std::time::Duration,
    max_lifetime: std::time::Duration,
    cancel: CancellationToken,
    record: Option<Arc<ConnectionRecord>>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
        if let Some(ref r) = record {
            r.record_bytes_in(data.len() as u64);
        }
    }
    // Owned halves (into_split) so each direction can move into its own task.
    let (mut client_read, mut client_write) = client.into_split();
    let (mut backend_read, mut backend_write) = backend.into_split();
    // Shared activity clock: ms since `start`, written by both copy loops.
    let last_activity = Arc::new(AtomicU64::new(0));
    let start = std::time::Instant::now();
    let la1 = Arc::clone(&last_activity);
    let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
    let rec1 = record.clone();
    let c2b = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = initial_len;
        loop {
            let n = match client_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if backend_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            let now_ms = start.elapsed().as_millis() as u64;
            la1.store(now_ms, Ordering::Relaxed);
            if let Some(ref r) = rec1 {
                r.record_bytes_in(n as u64);
            }
        }
        let _ = backend_write.shutdown().await;
        // Mark client side as closed (EOF, read error, or backend write error)
        if let Some(ref r) = rec1 {
            r.client_closed.store(true, Ordering::Relaxed);
        }
        total
    });
    let la2 = Arc::clone(&last_activity);
    let rec2 = record.clone();
    let b2c = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = match backend_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            let now_ms = start.elapsed().as_millis() as u64;
            la2.store(now_ms, Ordering::Relaxed);
            if let Some(ref r) = rec2 {
                r.record_bytes_out(n as u64);
            }
        }
        let _ = client_write.shutdown().await;
        // Mark backend side as closed
        if let Some(ref r) = rec2 {
            r.backend_closed.store(true, Ordering::Relaxed);
        }
        total
    });
    // Watchdog: inactivity, max lifetime, and cancellation
    let la_watch = Arc::clone(&last_activity);
    let c2b_handle = c2b.abort_handle();
    let b2c_handle = b2c.abort_handle();
    let watchdog = tokio::spawn(async move {
        let check_interval = std::time::Duration::from_secs(5);
        let mut last_seen = 0u64;
        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    debug!("Connection cancelled by shutdown");
                    c2b_handle.abort();
                    b2c_handle.abort();
                    break;
                }
                _ = tokio::time::sleep(check_interval) => {
                    // Check max lifetime
                    if start.elapsed() >= max_lifetime {
                        debug!("Connection exceeded max lifetime, closing");
                        c2b_handle.abort();
                        b2c_handle.abort();
                        break;
                    }
                    // Check inactivity: only when no new activity was seen
                    // since the previous tick do we measure the idle gap.
                    let current = la_watch.load(Ordering::Relaxed);
                    if current == last_seen {
                        let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
                        if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
                            debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
                            c2b_handle.abort();
                            b2c_handle.abort();
                            break;
                        }
                    }
                    last_seen = current;
                }
            }
        }
    });
    // Aborted tasks yield a JoinError here, which collapses to 0 bytes.
    let bytes_in = c2b.await.unwrap_or(0);
    let bytes_out = b2c.await.unwrap_or(0);
    // The watchdog is no longer needed once both directions are done.
    watchdog.abort();
    Ok((bytes_in, bytes_out))
}

View File

@@ -0,0 +1,22 @@
//! # rustproxy-passthrough
//!
//! Raw TCP/SNI passthrough engine for RustProxy.
//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding.
pub mod tcp_listener;
pub mod sni_parser;
pub mod forwarder;
pub mod proxy_protocol;
pub mod tls_handler;
pub mod connection_record;
pub mod connection_tracker;
pub mod socket_relay;
pub use tcp_listener::*;
pub use sni_parser::*;
pub use forwarder::*;
pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_record::*;
pub use connection_tracker::*;
pub use socket_relay::*;

View File

@@ -0,0 +1,129 @@
use std::net::SocketAddr;
use thiserror::Error;
/// Errors produced while parsing a PROXY protocol v1 header.
#[derive(Debug, Error)]
pub enum ProxyProtocolError {
    /// The data does not form a syntactically valid v1 header line.
    #[error("Invalid PROXY protocol header")]
    InvalidHeader,
    /// The protocol token was not one of TCP4 / TCP6 / UNKNOWN.
    #[error("Unsupported PROXY protocol version")]
    UnsupportedVersion,
    /// An address or port field failed to parse.
    #[error("Parse error: {0}")]
    Parse(String),
}
/// Parsed PROXY protocol v1 header.
#[derive(Debug, Clone)]
pub struct ProxyProtocolHeader {
    /// Source address parsed from the header's IP/port fields.
    pub source_addr: SocketAddr,
    /// Destination address parsed from the header's IP/port fields.
    pub dest_addr: SocketAddr,
    /// Transport protocol named in the header (TCP4 / TCP6 / UNKNOWN).
    pub protocol: ProxyProtocol,
}
/// Protocol in PROXY header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocol {
    /// "TCP4" — TCP over IPv4.
    Tcp4,
    /// "TCP6" — TCP over IPv6.
    Tcp6,
    /// "UNKNOWN" — the proxy could not determine the protocol.
    Unknown,
}
/// Parse a PROXY protocol v1 header from data.
///
/// Format: `PROXY TCP4 <src_ip> <dst_ip> <src_port> <dst_port>\r\n`
///
/// On success returns the parsed header plus the number of bytes consumed
/// (the header line including the trailing CRLF), so the caller can
/// forward the remainder of the buffer unchanged.
///
/// # Errors
///
/// - `InvalidHeader`: no CRLF within the spec's 107-byte limit, non-UTF-8
///   header bytes, missing `PROXY ` prefix, or wrong field count.
/// - `UnsupportedVersion`: unrecognized protocol token.
/// - `Parse`: an address or port field failed to parse.
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
    // A v1 header is at most 107 bytes including the CRLF (PROXY protocol
    // spec). Bounding the CRLF search keeps a large non-header payload
    // from being scanned in full and rejects over-long headers outright.
    const MAX_V1_HEADER_LEN: usize = 107;
    let window = &data[..data.len().min(MAX_V1_HEADER_LEN)];
    // Find the end of the header line
    let line_end = window
        .windows(2)
        .position(|w| w == b"\r\n")
        .ok_or(ProxyProtocolError::InvalidHeader)?;
    let line = std::str::from_utf8(&data[..line_end])
        .map_err(|_| ProxyProtocolError::InvalidHeader)?;
    if !line.starts_with("PROXY ") {
        return Err(ProxyProtocolError::InvalidHeader);
    }
    let parts: Vec<&str> = line.split(' ').collect();
    // NOTE(review): the spec also allows `PROXY UNKNOWN\r\n` with no
    // address fields; this implementation requires all six fields even for
    // UNKNOWN. Confirm whether address-less UNKNOWN must be accepted.
    if parts.len() != 6 {
        return Err(ProxyProtocolError::InvalidHeader);
    }
    let protocol = match parts[1] {
        "TCP4" => ProxyProtocol::Tcp4,
        "TCP6" => ProxyProtocol::Tcp6,
        "UNKNOWN" => ProxyProtocol::Unknown,
        _ => return Err(ProxyProtocolError::UnsupportedVersion),
    };
    let src_ip: std::net::IpAddr = parts[2]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
    let dst_ip: std::net::IpAddr = parts[3]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;
    let src_port: u16 = parts[4]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?;
    let dst_port: u16 = parts[5]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?;
    let header = ProxyProtocolHeader {
        source_addr: SocketAddr::new(src_ip, src_port),
        dest_addr: SocketAddr::new(dst_ip, dst_port),
        protocol,
    };
    // Consumed bytes = line + \r\n
    Ok((header, line_end + 2))
}
/// Generate a PROXY protocol v1 header string.
///
/// The address family token (TCP4 / TCP6) is chosen from the source
/// address, matching the original behavior.
pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String {
    let family = match source {
        SocketAddr::V4(_) => "TCP4",
        SocketAddr::V6(_) => "TCP6",
    };
    format!(
        "PROXY {family} {src_ip} {dst_ip} {src_port} {dst_port}\r\n",
        src_ip = source.ip(),
        dst_ip = dest.ip(),
        src_port = source.port(),
        dst_port = dest.port(),
    )
}
/// Check whether a buffer begins with the PROXY protocol v1 signature
/// (`PROXY` followed by a space).
pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
    const SIGNATURE: &[u8] = b"PROXY ";
    data.len() >= SIGNATURE.len() && &data[..SIGNATURE.len()] == SIGNATURE
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_v1_tcp4() {
        let raw = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n";
        let (hdr, consumed) = parse_v1(raw).unwrap();
        // The whole line including CRLF is consumed.
        assert_eq!(consumed, raw.len());
        assert_eq!(hdr.protocol, ProxyProtocol::Tcp4);
        assert_eq!(hdr.source_addr, "192.168.1.100:12345".parse::<SocketAddr>().unwrap());
        assert_eq!(hdr.dest_addr, "10.0.0.1:443".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn test_generate_v1() {
        let src: SocketAddr = "192.168.1.100:12345".parse().unwrap();
        let dst: SocketAddr = "10.0.0.1:443".parse().unwrap();
        assert_eq!(
            generate_v1(&src, &dst),
            "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n"
        );
    }

    #[test]
    fn test_is_proxy_protocol() {
        assert!(is_proxy_protocol_v1(b"PROXY TCP4 ..."));
        assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1"));
    }
}

View File

@@ -0,0 +1,287 @@
//! ClientHello SNI extraction via manual byte parsing.
//! No TLS stack needed - we just parse enough of the ClientHello to extract the SNI.
/// Result of SNI extraction.
///
/// Returned by `extract_sni`. `NeedMoreData` tells the caller to read more
/// bytes and retry; the other variants are terminal for this buffer.
#[derive(Debug)]
pub enum SniResult {
    /// Successfully extracted SNI hostname.
    Found(String),
    /// TLS ClientHello detected but no SNI extension present.
    NoSni,
    /// Not a TLS ClientHello (plain HTTP or other protocol).
    NotTls,
    /// Need more data to determine.
    NeedMoreData,
}
/// Extract the SNI hostname from a TLS ClientHello message.
///
/// This parses just enough of the TLS record to find the SNI extension,
/// without performing any actual TLS operations.
///
/// Returns `Found` with the lowercased hostname, `NoSni` for a ClientHello
/// without SNI, `NotTls` for non-TLS data, or `NeedMoreData` when the caller
/// should peek a larger window and retry.
pub fn extract_sni(data: &[u8]) -> SniResult {
    // Minimum TLS record header is 5 bytes: type(1) + version(2) + length(2)
    if data.len() < 5 {
        return SniResult::NeedMoreData;
    }
    // Check for TLS record: content_type=22 (Handshake)
    if data[0] != 0x16 {
        return SniResult::NotTls;
    }
    // Record length tells us whether the buffer holds the full record. This
    // distinguishes "definitely no SNI" from "just haven't peeked enough".
    // (Bug fix: the original ignored this and reported NoSni for truncated
    // ClientHellos, so the caller never re-peeked with a bigger buffer.)
    let record_len = ((data[3] as usize) << 8) | (data[4] as usize);
    let record_complete = data.len() >= 5 + record_len;
    // We need at least the handshake header (5 TLS + 4 handshake = 9)
    if data.len() < 9 {
        return SniResult::NeedMoreData;
    }
    // Handshake type = 1 (ClientHello)
    if data[5] != 0x01 {
        return SniResult::NotTls;
    }
    let hello = &data[9..];
    // ClientHello structure:
    //   2 bytes: client version
    //   32 bytes: random
    //   1 byte session_id length + session_id
    //   2 bytes cipher_suites length + suites
    //   1 byte compression length + methods
    //   2 bytes extensions length + extensions
    let mut pos = 2 + 32; // skip version + random
    if pos >= hello.len() {
        return SniResult::NeedMoreData;
    }
    // Session ID
    let session_id_len = hello[pos] as usize;
    pos += 1 + session_id_len;
    if pos + 2 > hello.len() {
        return SniResult::NeedMoreData;
    }
    // Cipher suites
    let cipher_suites_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
    pos += 2 + cipher_suites_len;
    if pos + 1 > hello.len() {
        return SniResult::NeedMoreData;
    }
    // Compression methods
    let compression_len = hello[pos] as usize;
    pos += 1 + compression_len;
    if pos + 2 > hello.len() {
        // Ran out of bytes before the extensions block: only a complete
        // record proves the ClientHello truly has no extensions.
        return if record_complete {
            SniResult::NoSni
        } else {
            SniResult::NeedMoreData
        };
    }
    // Extensions length
    let extensions_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
    pos += 2;
    let extensions_end = pos + extensions_len;
    // Parse extensions looking for SNI (type 0x0000); tolerate a partially
    // peeked extensions block by walking only what we have.
    while pos + 4 <= hello.len() && pos < extensions_end {
        let ext_type = ((hello[pos] as u16) << 8) | (hello[pos + 1] as u16);
        let ext_len = ((hello[pos + 2] as usize) << 8) | (hello[pos + 3] as usize);
        pos += 4;
        if ext_type == 0x0000 {
            // SNI extension; the sub-parser reports NeedMoreData when the
            // extension body itself is truncated.
            return parse_sni_extension(&hello[pos..(pos + ext_len).min(hello.len())], ext_len);
        }
        pos += ext_len;
    }
    // Walked off the available bytes without finding SNI. If the extensions
    // block is incomplete, more data could still contain it.
    if !record_complete && pos < extensions_end {
        return SniResult::NeedMoreData;
    }
    SniResult::NoSni
}
/// Parse the body of an SNI (server_name) extension.
///
/// Layout: server_name_list length (2 bytes), then the first entry:
/// name_type (1 byte, 0 = host_name), host_name length (2 bytes), host_name.
/// Only the first list entry is examined.
fn parse_sni_extension(data: &[u8], _ext_len: usize) -> SniResult {
    // Need list length (2) + name type (1) + name length (2) at minimum.
    if data.len() < 5 {
        return SniResult::NeedMoreData;
    }
    // Only name_type 0 (host_name) carries a DNS name.
    if data[2] != 0x00 {
        return SniResult::NoSni;
    }
    let name_len = u16::from_be_bytes([data[3], data[4]]) as usize;
    let name_end = 5 + name_len;
    if data.len() < name_end {
        return SniResult::NeedMoreData;
    }
    // Hostnames are case-insensitive; normalize to lowercase.
    match std::str::from_utf8(&data[5..name_end]) {
        Ok(hostname) => SniResult::Found(hostname.to_lowercase()),
        Err(_) => SniResult::NoSni,
    }
}
/// Check if the initial bytes look like a TLS ClientHello.
///
/// 0x16 is the Handshake content type; 0x03 the SSL3/TLS major version byte.
pub fn is_tls(data: &[u8]) -> bool {
    matches!(data, [0x16, 0x03, _, ..])
}
/// Check if the initial bytes look like HTTP.
///
/// Matches the first four bytes against the prefixes of the common HTTP
/// request methods (GET, POST, PUT, HEAD, DELETE, PATCH, OPTIONS, CONNECT).
pub fn is_http(data: &[u8]) -> bool {
    if data.len() < 4 {
        return false;
    }
    matches!(
        &data[..4],
        b"GET " | b"POST" | b"PUT " | b"HEAD" | b"DELE" | b"PATC" | b"OPTI" | b"CONN"
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_not_tls() {
        // Plain HTTP must classify as NotTls (not NeedMoreData).
        let http_data = b"GET / HTTP/1.1\r\n";
        assert!(matches!(extract_sni(http_data), SniResult::NotTls));
    }

    #[test]
    fn test_too_short() {
        // Fewer than 5 bytes cannot even hold a TLS record header.
        assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
    }

    #[test]
    fn test_is_tls() {
        assert!(is_tls(&[0x16, 0x03, 0x01]));
        assert!(!is_tls(&[0x47, 0x45, 0x54])); // "GET"
    }

    #[test]
    fn test_is_http() {
        assert!(is_http(b"GET /"));
        assert!(is_http(b"POST /api"));
        assert!(!is_http(&[0x16, 0x03, 0x01]));
    }

    #[test]
    fn test_real_client_hello() {
        // A minimal TLS 1.2 ClientHello with SNI "example.com"
        let client_hello: Vec<u8> = build_test_client_hello("example.com");
        match extract_sni(&client_hello) {
            SniResult::Found(sni) => assert_eq!(sni, "example.com"),
            other => panic!("Expected Found, got {:?}", other),
        }
    }

    /// Build a minimal TLS ClientHello for testing.
    ///
    /// Assembled inside-out, each layer prepending its own big-endian length
    /// prefix: SNI extension data -> SNI extension -> extensions block ->
    /// ClientHello body -> handshake message -> TLS record.
    fn build_test_client_hello(hostname: &str) -> Vec<u8> {
        let hostname_bytes = hostname.as_bytes();
        // SNI extension body: server_name_list with a single host_name entry.
        let sni_ext_data = {
            let mut d = Vec::new();
            // Server name list length
            let name_entry_len = 3 + hostname_bytes.len(); // type(1) + len(2) + name
            d.push(((name_entry_len >> 8) & 0xFF) as u8);
            d.push((name_entry_len & 0xFF) as u8);
            // Host name type = 0
            d.push(0x00);
            // Host name length (2 bytes, big-endian)
            d.push(((hostname_bytes.len() >> 8) & 0xFF) as u8);
            d.push((hostname_bytes.len() & 0xFF) as u8);
            // Host name
            d.extend_from_slice(hostname_bytes);
            d
        };
        // Extension: type=0x0000 (SNI), length, data
        let sni_extension = {
            let mut e = Vec::new();
            e.push(0x00); e.push(0x00); // SNI type
            e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
            e.push((sni_ext_data.len() & 0xFF) as u8);
            e.extend_from_slice(&sni_ext_data);
            e
        };
        // Extensions block: 2-byte total length + the single extension.
        let extensions = {
            let mut ext = Vec::new();
            ext.push(((sni_extension.len() >> 8) & 0xFF) as u8);
            ext.push((sni_extension.len() & 0xFF) as u8);
            ext.extend_from_slice(&sni_extension);
            ext
        };
        // ClientHello body
        let hello_body = {
            let mut h = Vec::new();
            // Client version TLS 1.2
            h.push(0x03); h.push(0x03);
            // Random (32 bytes, all-zero is fine for parsing tests)
            h.extend_from_slice(&[0u8; 32]);
            // Session ID length = 0
            h.push(0x00);
            // Cipher suites: length=2, one suite
            h.push(0x00); h.push(0x02);
            h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
            // Compression methods: length=1, null
            h.push(0x01); h.push(0x00);
            // Extensions
            h.extend_from_slice(&extensions);
            h
        };
        // Handshake: type=1 (ClientHello), 3-byte length
        let handshake = {
            let mut hs = Vec::new();
            hs.push(0x01); // ClientHello
            // 3-byte length
            hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
            hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
            hs.push((hello_body.len() & 0xFF) as u8);
            hs.extend_from_slice(&hello_body);
            hs
        };
        // TLS record: type=0x16, version TLS 1.0 (customary for the record
        // layer of a ClientHello), 2-byte length
        let mut record = Vec::new();
        record.push(0x16); // Handshake
        record.push(0x03); record.push(0x01); // TLS 1.0
        record.push(((handshake.len() >> 8) & 0xFF) as u8);
        record.push((handshake.len() & 0xFF) as u8);
        record.extend_from_slice(&handshake);
        record
    }
}

View File

@@ -0,0 +1,126 @@
//! Socket handler relay for connecting client connections to a TypeScript handler
//! via a Unix domain socket.
//!
//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
use tokio::net::UnixStream;
use tokio::io::{AsyncWriteExt, AsyncReadExt};
use tokio::net::TcpStream;
use serde::Serialize;
use tracing::debug;
/// Metadata line sent to the TypeScript handler before the byte relay begins.
/// Serialized as one camelCase JSON object terminated by `\n`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct RelayMetadata {
    /// Caller-assigned identifier for this connection.
    connection_id: u64,
    /// Client address/port as seen by the proxy.
    remote_ip: String,
    remote_port: u16,
    /// Local proxy port the client connected to.
    local_port: u16,
    /// TLS SNI hostname, when the connection presented one.
    sni: Option<String>,
    /// Name of the matched route.
    route_name: String,
    /// Bytes already consumed from the client, base64-encoded so the handler
    /// can replay them before reading the live stream.
    initial_data_base64: Option<String>,
}
/// Relay a client connection to a TypeScript handler via Unix domain socket.
///
/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
///
/// Bytes already read from the client should be passed via `initial_data`;
/// they travel base64-encoded inside the metadata rather than on the wire.
///
/// Returns after both relay directions have terminated. Per-direction read
/// or write failures end that direction (and half-close the peer) but are
/// not surfaced as `Err`; only the initial connect/handshake can fail.
pub async fn relay_to_handler(
    client: TcpStream,
    relay_socket_path: &str,
    connection_id: u64,
    remote_ip: String,
    remote_port: u16,
    local_port: u16,
    sni: Option<String>,
    route_name: String,
    initial_data: Option<&[u8]>,
) -> std::io::Result<()> {
    debug!(
        "Relaying connection {} to handler socket {}",
        connection_id, relay_socket_path
    );
    // Connect to TypeScript handler Unix socket
    let mut handler = UnixStream::connect(relay_socket_path).await?;
    // Build and send metadata header (one JSON line, newline-terminated)
    let initial_data_base64 = initial_data.map(base64_encode);
    let metadata = RelayMetadata {
        connection_id,
        remote_ip,
        remote_port,
        local_port,
        sni,
        route_name,
        initial_data_base64,
    };
    let metadata_json = serde_json::to_string(&metadata)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
    handler.write_all(metadata_json.as_bytes()).await?;
    handler.write_all(b"\n").await?;
    // Bidirectional relay between client and handler. Each direction runs on
    // its own task; when one side stops, its writer is shut down so the peer
    // observes EOF (half-close) instead of hanging.
    let (mut client_read, mut client_write) = client.into_split();
    let (mut handler_read, mut handler_write) = handler.into_split();
    let c2h = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        loop {
            let n = match client_read.read(&mut buf).await {
                // 0 bytes = clean EOF; read errors end the direction the same way.
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if handler_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
        }
        let _ = handler_write.shutdown().await;
    });
    let h2c = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        loop {
            let n = match handler_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
        }
        let _ = client_write.shutdown().await;
    });
    // Wait for both directions; join errors (panic/abort) are ignored.
    let _ = tokio::join!(c2h, h2c);
    debug!("Relay connection {} completed", connection_id);
    Ok(())
}
/// Simple standard-alphabet (RFC 4648) base64 encoding with `=` padding,
/// kept dependency-free.
///
/// The output is preallocated: every 3-byte input group (including a partial
/// final group) expands to exactly 4 output characters, so repeated `String`
/// reallocation is avoided.
fn base64_encode(data: &[u8]) -> String {
    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut result = String::with_capacity((data.len() + 2) / 3 * 4);
    for chunk in data.chunks(3) {
        // Pack up to 3 bytes into a 24-bit group; missing bytes are zero.
        let b0 = chunk[0] as u32;
        let b1 = chunk.get(1).copied().unwrap_or(0) as u32;
        let b2 = chunk.get(2).copied().unwrap_or(0) as u32;
        let n = (b0 << 16) | (b1 << 8) | b2;
        // First two sextets always exist; the rest pad with '=' when the
        // chunk is short.
        result.push(CHARS[((n >> 18) & 0x3F) as usize] as char);
        result.push(CHARS[((n >> 12) & 0x3F) as usize] as char);
        if chunk.len() > 1 {
            result.push(CHARS[((n >> 6) & 0x3F) as usize] as char);
        } else {
            result.push('=');
        }
        if chunk.len() > 2 {
            result.push(CHARS[(n & 0x3F) as usize] as char);
        } else {
            result.push('=');
        }
    }
    result
}

View File

@@ -0,0 +1,874 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{info, error, debug, warn};
use thiserror::Error;
use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;
use rustproxy_http::HttpProxyService;
use crate::sni_parser;
use crate::forwarder;
use crate::tls_handler;
use crate::connection_tracker::ConnectionTracker;
/// Errors produced while managing TCP listeners.
#[derive(Debug, Error)]
pub enum ListenerError {
    /// The OS bind call failed (port in use, insufficient privileges, ...).
    #[error("Failed to bind port {port}: {source}")]
    BindFailed { port: u16, source: std::io::Error },
    /// The manager already runs a listener on this port.
    #[error("Port {0} already bound")]
    AlreadyBound(u16),
    /// Any other I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
/// TLS configuration for a specific domain.
#[derive(Clone)]
pub struct TlsCertConfig {
    /// PEM-encoded certificate data.
    pub cert_pem: String,
    /// PEM-encoded private key.
    pub key_pem: String,
}
/// Timeout and connection management configuration.
///
/// All durations are in milliseconds. `None` on the optional fields means
/// "unlimited" for the limits, or "use the built-in default at the point of
/// use" for the keep-alive tuning knobs.
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    /// Timeout for establishing connection to backend (ms)
    pub connection_timeout_ms: u64,
    /// Timeout for initial data/SNI peek (ms); also reused for TLS handshakes
    pub initial_data_timeout_ms: u64,
    /// Socket inactivity timeout (ms)
    pub socket_timeout_ms: u64,
    /// Maximum connection lifetime (ms)
    pub max_connection_lifetime_ms: u64,
    /// Graceful shutdown timeout (ms)
    pub graceful_shutdown_timeout_ms: u64,
    /// Maximum connections per IP (None = unlimited)
    pub max_connections_per_ip: Option<u64>,
    /// Connection rate limit per minute per IP (None = unlimited)
    pub connection_rate_limit_per_minute: Option<u64>,
    /// Keep-alive treatment (None = standard timeouts)
    pub keep_alive_treatment: Option<rustproxy_config::KeepAliveTreatment>,
    /// Inactivity multiplier for keep-alive connections (call sites default to 6.0)
    pub keep_alive_inactivity_multiplier: Option<f64>,
    /// Extended keep-alive lifetime (ms) for Extended treatment mode
    /// (call sites default to 7 days)
    pub extended_keep_alive_lifetime_ms: Option<u64>,
    /// Whether to accept PROXY protocol
    pub accept_proxy_protocol: bool,
    /// Whether to send PROXY protocol
    pub send_proxy_protocol: bool,
}
impl Default for ConnectionConfig {
    /// Conservative defaults: generous data-plane timeouts, no per-IP
    /// limits, and PROXY protocol disabled on both sides.
    fn default() -> Self {
        Self {
            connection_timeout_ms: 30_000,          // 30 s to reach a backend
            initial_data_timeout_ms: 60_000,        // 60 s for first bytes / SNI
            socket_timeout_ms: 3_600_000,           // 1 h of inactivity
            max_connection_lifetime_ms: 86_400_000, // 24 h hard cap
            graceful_shutdown_timeout_ms: 30_000,   // 30 s drain window
            max_connections_per_ip: None,
            connection_rate_limit_per_minute: None,
            keep_alive_treatment: None,
            keep_alive_inactivity_multiplier: None,
            extended_keep_alive_lifetime_ms: None,
            accept_proxy_protocol: false,
            send_proxy_protocol: false,
        }
    }
}
/// Manages TCP listeners for all configured ports.
///
/// Each added port spawns an accept-loop task that owns `Arc` clones of the
/// shared state below; `cancel_token` is the common shutdown signal for all
/// of those loops.
pub struct TcpListenerManager {
    /// Active listeners indexed by port (join handle of the accept-loop task)
    listeners: HashMap<u16, tokio::task::JoinHandle<()>>,
    /// Shared route manager
    route_manager: Arc<RouteManager>,
    /// Shared metrics collector
    metrics: Arc<MetricsCollector>,
    /// TLS certificate configs indexed by domain (exact name or `*.wildcard`)
    tls_configs: Arc<HashMap<String, TlsCertConfig>>,
    /// HTTP proxy service for HTTP-level forwarding
    http_proxy: Arc<HttpProxyService>,
    /// Connection configuration
    conn_config: Arc<ConnectionConfig>,
    /// Connection tracker for per-IP limits
    conn_tracker: Arc<ConnectionTracker>,
    /// Cancellation token for graceful shutdown
    cancel_token: CancellationToken,
}
impl TcpListenerManager {
pub fn new(route_manager: Arc<RouteManager>) -> Self {
let metrics = Arc::new(MetricsCollector::new());
let http_proxy = Arc::new(HttpProxyService::new(
Arc::clone(&route_manager),
Arc::clone(&metrics),
));
let conn_config = ConnectionConfig::default();
let conn_tracker = Arc::new(ConnectionTracker::new(
conn_config.max_connections_per_ip,
conn_config.connection_rate_limit_per_minute,
));
Self {
listeners: HashMap::new(),
route_manager,
metrics,
tls_configs: Arc::new(HashMap::new()),
http_proxy,
conn_config: Arc::new(conn_config),
conn_tracker,
cancel_token: CancellationToken::new(),
}
}
/// Create a listener manager that reports into an existing metrics collector.
pub fn with_metrics(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
    // Per-IP tracking starts from the default limits; set_connection_config
    // can replace them later.
    let defaults = ConnectionConfig::default();
    let tracker = ConnectionTracker::new(
        defaults.max_connections_per_ip,
        defaults.connection_rate_limit_per_minute,
    );
    let proxy = HttpProxyService::new(Arc::clone(&route_manager), Arc::clone(&metrics));
    Self {
        listeners: HashMap::new(),
        route_manager,
        metrics,
        tls_configs: Arc::new(HashMap::new()),
        http_proxy: Arc::new(proxy),
        conn_config: Arc::new(defaults),
        conn_tracker: Arc::new(tracker),
        cancel_token: CancellationToken::new(),
    }
}
/// Replace the connection configuration, rebuilding the per-IP tracker so
/// the new limits apply to subsequent connections.
pub fn set_connection_config(&mut self, config: ConnectionConfig) {
    let tracker = ConnectionTracker::new(
        config.max_connections_per_ip,
        config.connection_rate_limit_per_minute,
    );
    self.conn_tracker = Arc::new(tracker);
    self.conn_config = Arc::new(config);
}
/// Set TLS certificate configurations.
///
/// Replaces the whole map. Listeners added afterwards see the new map;
/// already-running accept loops keep the `Arc` they cloned at spawn time.
pub fn set_tls_configs(&mut self, configs: HashMap<String, TlsCertConfig>) {
    self.tls_configs = Arc::new(configs);
}
/// Bind a new listener on `port` and spawn its accept loop.
///
/// # Errors
/// `AlreadyBound` when this manager already listens on the port;
/// `BindFailed` when the OS bind call fails.
pub async fn add_port(&mut self, port: u16) -> Result<(), ListenerError> {
    if self.listeners.contains_key(&port) {
        return Err(ListenerError::AlreadyBound(port));
    }
    let listener = TcpListener::bind(format!("0.0.0.0:{}", port))
        .await
        .map_err(|source| ListenerError::BindFailed { port, source })?;
    info!("Listening on port {}", port);
    // Each accept loop owns its own clones of the shared state.
    let route_manager = Arc::clone(&self.route_manager);
    let metrics = Arc::clone(&self.metrics);
    let tls_configs = Arc::clone(&self.tls_configs);
    let http_proxy = Arc::clone(&self.http_proxy);
    let conn_config = Arc::clone(&self.conn_config);
    let conn_tracker = Arc::clone(&self.conn_tracker);
    let cancel = self.cancel_token.clone();
    let handle = tokio::spawn(Self::accept_loop(
        listener, port, route_manager, metrics, tls_configs,
        http_proxy, conn_config, conn_tracker, cancel,
    ));
    self.listeners.insert(port, handle);
    Ok(())
}
/// Abort the accept loop for `port`; returns whether a listener existed.
pub fn remove_port(&mut self, port: u16) -> bool {
    match self.listeners.remove(&port) {
        Some(handle) => {
            handle.abort();
            info!("Stopped listening on port {}", port);
            true
        }
        None => false,
    }
}
/// All ports with an active listener, in ascending order.
pub fn listening_ports(&self) -> Vec<u16> {
    let mut ports = self.listeners.keys().copied().collect::<Vec<u16>>();
    ports.sort_unstable();
    ports
}
/// Stop all listeners gracefully.
///
/// Signals cancellation so the accept loops stop taking new connections,
/// then waits up to `graceful_shutdown_timeout_ms` (shared budget across all
/// listeners) for the loop tasks to exit, aborting any that remain.
pub async fn graceful_stop(&mut self) {
    let timeout_ms = self.conn_config.graceful_shutdown_timeout_ms;
    info!("Initiating graceful shutdown (timeout: {}ms)", timeout_ms);
    // Signal all accept loops to stop accepting new connections
    self.cancel_token.cancel();
    // Wait for the accept-loop tasks to wind down within a shared deadline.
    let timeout = std::time::Duration::from_millis(timeout_ms);
    let deadline = tokio::time::Instant::now() + timeout;
    for (port, handle) in self.listeners.drain() {
        let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
        if remaining.is_zero() {
            handle.abort();
            warn!("Force-stopped listener on port {} (timeout exceeded)", port);
        } else {
            // BUG FIX: `tokio::time::timeout` consumes the JoinHandle, and a
            // dropped JoinHandle merely *detaches* the task. The original
            // logged "aborting" on timeout but had no handle left to abort,
            // leaving the task running forever. Keep an AbortHandle so the
            // timeout branch can actually stop it.
            let abort = handle.abort_handle();
            match tokio::time::timeout(remaining, handle).await {
                Ok(_) => info!("Listener on port {} stopped gracefully", port),
                Err(_) => {
                    abort.abort();
                    warn!("Listener on port {} did not stop in time, aborting", port);
                }
            }
        }
    }
    // Reset cancellation token for potential restart
    self.cancel_token = CancellationToken::new();
    info!("Graceful shutdown complete");
}
/// Immediately abort every listener with no drain period (backward compatibility).
pub fn stop_all(&mut self) {
    self.cancel_token.cancel();
    self.listeners.drain().for_each(|(port, handle)| {
        handle.abort();
        info!("Stopped listening on port {}", port);
    });
    // Fresh token so the manager can be restarted after a full stop.
    self.cancel_token = CancellationToken::new();
}
/// Update the route manager (for hot-reload).
///
/// Note: accept loops capture their own `Arc<RouteManager>` when a port is
/// added, so this only affects listeners started after the call.
pub fn update_route_manager(&mut self, route_manager: Arc<RouteManager>) {
    self.route_manager = route_manager;
}
/// Get a reference to the shared metrics collector.
pub fn metrics(&self) -> &Arc<MetricsCollector> {
    &self.metrics
}
/// Accept loop for a single port.
///
/// Runs until `cancel` fires. Each accepted connection is vetted against the
/// per-IP limits, then handled on its own spawned task so a slow connection
/// never blocks the accept path.
async fn accept_loop(
    listener: TcpListener,
    port: u16,
    route_manager: Arc<RouteManager>,
    metrics: Arc<MetricsCollector>,
    tls_configs: Arc<HashMap<String, TlsCertConfig>>,
    http_proxy: Arc<HttpProxyService>,
    conn_config: Arc<ConnectionConfig>,
    conn_tracker: Arc<ConnectionTracker>,
    cancel: CancellationToken,
) {
    loop {
        tokio::select! {
            _ = cancel.cancelled() => {
                info!("Accept loop on port {} shutting down", port);
                break;
            }
            result = listener.accept() => {
                match result {
                    Ok((stream, peer_addr)) => {
                        let ip = peer_addr.ip();
                        // Check per-IP limits and rate limiting.
                        // NOTE(review): limit check (try_accept) and
                        // registration (connection_opened) are two separate
                        // tracker calls — presumably fine with one accept
                        // loop per port, but confirm ConnectionTracker's
                        // semantics if ports can share IPs concurrently.
                        if !conn_tracker.try_accept(&ip) {
                            debug!("Rejected connection from {} (per-IP limit or rate limit)", peer_addr);
                            // Dropping the stream closes it immediately.
                            drop(stream);
                            continue;
                        }
                        conn_tracker.connection_opened(&ip);
                        let rm = Arc::clone(&route_manager);
                        let m = Arc::clone(&metrics);
                        let tc = Arc::clone(&tls_configs);
                        let hp = Arc::clone(&http_proxy);
                        let cc = Arc::clone(&conn_config);
                        let ct = Arc::clone(&conn_tracker);
                        let cn = cancel.clone();
                        debug!("Accepted connection from {} on port {}", peer_addr, port);
                        tokio::spawn(async move {
                            let result = Self::handle_connection(
                                stream, port, peer_addr, rm, m, tc, hp, cc, cn,
                            ).await;
                            if let Err(e) = result {
                                debug!("Connection error from {}: {}", peer_addr, e);
                            }
                            // Balance the connection_opened above, success or error.
                            ct.connection_closed(&ip);
                        });
                    }
                    Err(e) => {
                        error!("Accept error on port {}: {}", port, e);
                        // Brief back-off so a persistent accept failure
                        // (e.g. fd exhaustion) doesn't spin the loop.
                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                    }
                }
            }
        }
    }
}
/// Handle a single incoming connection.
///
/// Flow: optional PROXY protocol unwrapping, initial-byte peek and protocol
/// sniffing (TLS + SNI vs. HTTP vs. raw TCP), route matching, route-level IP
/// security, then dispatch according to the matched route's TLS mode.
async fn handle_connection(
    mut stream: tokio::net::TcpStream,
    port: u16,
    peer_addr: std::net::SocketAddr,
    route_manager: Arc<RouteManager>,
    metrics: Arc<MetricsCollector>,
    tls_configs: Arc<HashMap<String, TlsCertConfig>>,
    http_proxy: Arc<HttpProxyService>,
    conn_config: Arc<ConnectionConfig>,
    cancel: CancellationToken,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    use tokio::io::AsyncReadExt;
    stream.set_nodelay(true)?;
    // Handle PROXY protocol if configured
    let mut effective_peer_addr = peer_addr;
    if conn_config.accept_proxy_protocol {
        let mut proxy_peek = vec![0u8; 256];
        let pn = match tokio::time::timeout(
            std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
            stream.peek(&mut proxy_peek),
        ).await {
            Ok(Ok(n)) => n,
            Ok(Err(e)) => return Err(e.into()),
            Err(_) => return Err("Initial data timeout (proxy protocol peek)".into()),
        };
        if pn > 0 && crate::proxy_protocol::is_proxy_protocol_v1(&proxy_peek[..pn]) {
            match crate::proxy_protocol::parse_v1(&proxy_peek[..pn]) {
                Ok((header, consumed)) => {
                    debug!("PROXY protocol: real client {} -> {}", header.source_addr, header.dest_addr);
                    effective_peer_addr = header.source_addr;
                    // Consume the proxy protocol header bytes
                    let mut discard = vec![0u8; consumed];
                    stream.read_exact(&mut discard).await?;
                }
                Err(e) => {
                    debug!("Failed to parse PROXY protocol header: {}", e);
                    // Not a PROXY protocol header, continue normally
                }
            }
        }
    }
    // From here on, "peer" means the real client when PROXY protocol supplied it.
    let peer_addr = effective_peer_addr;
    // Peek at initial bytes with timeout; peek leaves the bytes in the socket
    // buffer so passthrough modes can still forward them verbatim.
    let mut peek_buf = vec![0u8; 4096];
    let n = match tokio::time::timeout(
        std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
        stream.peek(&mut peek_buf),
    ).await {
        Ok(Ok(n)) => n,
        Ok(Err(e)) => return Err(e.into()),
        Err(_) => return Err("Initial data timeout".into()),
    };
    let initial_data = &peek_buf[..n];
    // Determine connection type and extract SNI if TLS
    let is_tls = sni_parser::is_tls(initial_data);
    let is_http = sni_parser::is_http(initial_data);
    let domain = if is_tls {
        match sni_parser::extract_sni(initial_data) {
            sni_parser::SniResult::Found(sni) => Some(sni),
            sni_parser::SniResult::NoSni => None,
            sni_parser::SniResult::NeedMoreData => {
                // ClientHello larger than the 4 KiB window (e.g. many
                // extensions): retry once with a bigger peek buffer.
                let mut bigger_buf = vec![0u8; 16384];
                let n = match tokio::time::timeout(
                    std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
                    stream.peek(&mut bigger_buf),
                ).await {
                    Ok(Ok(n)) => n,
                    Ok(Err(e)) => return Err(e.into()),
                    Err(_) => return Err("SNI data timeout".into()),
                };
                match sni_parser::extract_sni(&bigger_buf[..n]) {
                    sni_parser::SniResult::Found(sni) => Some(sni),
                    _ => None,
                }
            }
            sni_parser::SniResult::NotTls => None,
        }
    } else {
        None
    };
    // Match route
    let ctx = rustproxy_routing::MatchContext {
        port,
        domain: domain.as_deref(),
        path: None,
        client_ip: Some(&peer_addr.ip().to_string()),
        tls_version: None,
        headers: None,
        is_tls,
    };
    let route_match = route_manager.find_route(&ctx);
    let route_match = match route_match {
        Some(rm) => rm,
        None => {
            debug!("No route matched for port {} domain {:?}", port, domain);
            return Ok(());
        }
    };
    let route_id = route_match.route.id.as_deref();
    // Check route-level IP security for passthrough connections
    if let Some(ref security) = route_match.route.security {
        if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
            security,
            &peer_addr.ip(),
        ) {
            debug!("Connection from {} blocked by route security", peer_addr);
            return Ok(());
        }
    }
    // Track connection in metrics
    metrics.connection_opened(route_id);
    let target = match route_match.target {
        Some(t) => t,
        None => {
            debug!("Route matched but no target available");
            metrics.connection_closed(route_id);
            return Ok(());
        }
    };
    let target_host = target.host.first().to_string();
    let target_port = target.port.resolve(port);
    let tls_mode = route_match.route.tls_mode();
    // Connection timeout for backend connections
    let connect_timeout = std::time::Duration::from_millis(conn_config.connection_timeout_ms);
    // Derive effective inactivity/lifetime timeouts from the keep-alive treatment.
    let base_inactivity_ms = conn_config.socket_timeout_ms;
    let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() {
        Some(rustproxy_config::KeepAliveTreatment::Extended) => {
            let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0);
            let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms
                .unwrap_or(7 * 24 * 3600 * 1000); // 7 days default
            (
                std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64),
                std::time::Duration::from_millis(extended_lifetime),
            )
        }
        Some(rustproxy_config::KeepAliveTreatment::Immortal) => {
            (
                std::time::Duration::from_millis(base_inactivity_ms),
                std::time::Duration::from_secs(u64::MAX / 2),
            )
        }
        _ => {
            // Standard
            (
                std::time::Duration::from_millis(base_inactivity_ms),
                std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
            )
        }
    };
    // Determine if we should send PROXY protocol to backend
    let should_send_proxy = conn_config.send_proxy_protocol
        || route_match.route.action.send_proxy_protocol.unwrap_or(false)
        || target.send_proxy_protocol.unwrap_or(false);
    // Generate PROXY protocol header if needed
    let proxy_header = if should_send_proxy {
        let dest = std::net::SocketAddr::new(
            target_host.parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)),
            target_port,
        );
        Some(crate::proxy_protocol::generate_v1(&peer_addr, &dest))
    } else {
        None
    };
    // BUG FIX: the dispatch arms below bail out with `return Err(...)` / `?`
    // on backend or TLS failures. Previously those returns exited the whole
    // function and skipped the `metrics.connection_closed` call below,
    // permanently inflating the open-connection count. Inside this
    // immediately-awaited async block, `return` and `?` only exit the block,
    // so the metric always balances `connection_opened` above.
    let result: Result<(), Box<dyn std::error::Error + Send + Sync>> = async {
        match tls_mode {
            Some(rustproxy_config::TlsMode::Passthrough) => {
                // Raw TCP passthrough - connect to backend and forward
                let mut backend = match tokio::time::timeout(
                    connect_timeout,
                    tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
                ).await {
                    Ok(Ok(s)) => s,
                    Ok(Err(e)) => return Err(e.into()),
                    Err(_) => return Err("Backend connection timeout".into()),
                };
                backend.set_nodelay(true)?;
                // Send PROXY protocol header if configured
                if let Some(ref header) = proxy_header {
                    use tokio::io::AsyncWriteExt;
                    backend.write_all(header.as_bytes()).await?;
                }
                debug!(
                    "Passthrough: {} -> {}:{} (SNI: {:?})",
                    peer_addr, target_host, target_port, domain
                );
                // Consume the peeked bytes for real and replay them to the backend.
                let mut actual_buf = vec![0u8; n];
                stream.read_exact(&mut actual_buf).await?;
                let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
                    stream, backend, Some(&actual_buf),
                    inactivity_timeout, max_lifetime, cancel,
                ).await?;
                metrics.record_bytes(bytes_in, bytes_out, route_id);
                Ok(())
            }
            Some(rustproxy_config::TlsMode::Terminate) => {
                let tls_config = Self::find_tls_config(&domain, &tls_configs)?;
                // TLS accept with timeout, applying route-level TLS settings
                let route_tls = route_match.route.action.tls.as_ref();
                let acceptor = tls_handler::build_tls_acceptor_with_config(
                    &tls_config.cert_pem, &tls_config.key_pem, route_tls,
                )?;
                let tls_stream = match tokio::time::timeout(
                    std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
                    tls_handler::accept_tls(stream, &acceptor),
                ).await {
                    Ok(Ok(s)) => s,
                    Ok(Err(e)) => return Err(e),
                    Err(_) => return Err("TLS handshake timeout".into()),
                };
                // Peek at decrypted data to determine if HTTP
                let mut buf_stream = tokio::io::BufReader::new(tls_stream);
                let peeked = {
                    use tokio::io::AsyncBufReadExt;
                    match buf_stream.fill_buf().await {
                        Ok(data) => sni_parser::is_http(data),
                        Err(_) => false,
                    }
                };
                if peeked {
                    debug!(
                        "TLS Terminate + HTTP: {} -> {}:{} (domain: {:?})",
                        peer_addr, target_host, target_port, domain
                    );
                    http_proxy.handle_io(buf_stream, peer_addr, port).await;
                } else {
                    debug!(
                        "TLS Terminate + TCP: {} -> {}:{} (domain: {:?})",
                        peer_addr, target_host, target_port, domain
                    );
                    // Raw TCP forwarding of decrypted stream
                    let backend = match tokio::time::timeout(
                        connect_timeout,
                        tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
                    ).await {
                        Ok(Ok(s)) => s,
                        Ok(Err(e)) => return Err(e.into()),
                        Err(_) => return Err("Backend connection timeout".into()),
                    };
                    backend.set_nodelay(true)?;
                    let (tls_read, tls_write) = tokio::io::split(buf_stream);
                    let (backend_read, backend_write) = tokio::io::split(backend);
                    let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
                        tls_read, tls_write, backend_read, backend_write,
                        inactivity_timeout, max_lifetime,
                    ).await;
                    metrics.record_bytes(bytes_in, bytes_out, route_id);
                }
                Ok(())
            }
            Some(rustproxy_config::TlsMode::TerminateAndReencrypt) => {
                let route_tls = route_match.route.action.tls.as_ref();
                Self::handle_tls_terminate_reencrypt(
                    stream, n, &domain, &target_host, target_port,
                    peer_addr, &tls_configs, &metrics, route_id, &conn_config, route_tls,
                ).await
            }
            None => {
                if is_http {
                    // Plain HTTP - use HTTP proxy for request-level routing
                    debug!("HTTP proxy: {} on port {}", peer_addr, port);
                    http_proxy.handle_connection(stream, peer_addr, port).await;
                    Ok(())
                } else {
                    // Plain TCP forwarding (non-HTTP)
                    let mut backend = match tokio::time::timeout(
                        connect_timeout,
                        tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
                    ).await {
                        Ok(Ok(s)) => s,
                        Ok(Err(e)) => return Err(e.into()),
                        Err(_) => return Err("Backend connection timeout".into()),
                    };
                    backend.set_nodelay(true)?;
                    // Send PROXY protocol header if configured
                    if let Some(ref header) = proxy_header {
                        use tokio::io::AsyncWriteExt;
                        backend.write_all(header.as_bytes()).await?;
                    }
                    debug!(
                        "Forward: {} -> {}:{}",
                        peer_addr, target_host, target_port
                    );
                    // Consume the peeked bytes and replay them to the backend.
                    let mut actual_buf = vec![0u8; n];
                    stream.read_exact(&mut actual_buf).await?;
                    let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
                        stream, backend, Some(&actual_buf),
                        inactivity_timeout, max_lifetime, cancel,
                    ).await?;
                    metrics.record_bytes(bytes_in, bytes_out, route_id);
                    Ok(())
                }
            }
        }
    }
    .await;
    metrics.connection_closed(route_id);
    result
}
/// Handle TLS terminate-and-reencrypt: accept TLS from client, connect TLS to backend.
///
/// Both legs are independent TLS sessions: the client leg uses the locally
/// configured certificate for `domain`, the backend leg is a fresh outbound
/// TLS connection to `target_host:target_port`. Plaintext exists only in
/// this process while being copied between the two.
async fn handle_tls_terminate_reencrypt(
    stream: tokio::net::TcpStream,
    _peek_len: usize,
    domain: &Option<String>,
    target_host: &str,
    target_port: u16,
    peer_addr: std::net::SocketAddr,
    tls_configs: &HashMap<String, TlsCertConfig>,
    metrics: &MetricsCollector,
    route_id: Option<&str>,
    conn_config: &ConnectionConfig,
    route_tls: Option<&rustproxy_config::RouteTls>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let tls_config = Self::find_tls_config(domain, tls_configs)?;
    let acceptor = tls_handler::build_tls_acceptor_with_config(
        &tls_config.cert_pem, &tls_config.key_pem, route_tls,
    )?;
    // Accept TLS from client with timeout (the handshake is the first thing
    // on the wire, so the initial-data timeout is reused here).
    let client_tls = match tokio::time::timeout(
        std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
        tls_handler::accept_tls(stream, &acceptor),
    ).await {
        Ok(Ok(s)) => s,
        Ok(Err(e)) => return Err(e),
        Err(_) => return Err("TLS handshake timeout".into()),
    };
    debug!(
        "TLS Terminate+Reencrypt: {} -> {}:{} (domain: {:?})",
        peer_addr, target_host, target_port, domain
    );
    // Connect to backend over TLS with timeout
    let backend_tls = match tokio::time::timeout(
        std::time::Duration::from_millis(conn_config.connection_timeout_ms),
        tls_handler::connect_tls(target_host, target_port),
    ).await {
        Ok(Ok(s)) => s,
        Ok(Err(e)) => return Err(e),
        Err(_) => return Err("Backend TLS connection timeout".into()),
    };
    // Forward between two TLS streams
    let (client_read, client_write) = tokio::io::split(client_tls);
    let (backend_read, backend_write) = tokio::io::split(backend_tls);
    // Derive effective timeouts from the keep-alive treatment.
    // NOTE(review): this duplicates the computation in handle_connection;
    // consider extracting a shared helper so the two cannot drift apart.
    let base_inactivity_ms = conn_config.socket_timeout_ms;
    let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() {
        Some(rustproxy_config::KeepAliveTreatment::Extended) => {
            let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0);
            let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms
                .unwrap_or(7 * 24 * 3600 * 1000); // 7 days default
            (
                std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64),
                std::time::Duration::from_millis(extended_lifetime),
            )
        }
        Some(rustproxy_config::KeepAliveTreatment::Immortal) => {
            // Effectively unlimited lifetime (presumably halved to avoid
            // overflow when added to a deadline — confirm in the forwarder).
            (
                std::time::Duration::from_millis(base_inactivity_ms),
                std::time::Duration::from_secs(u64::MAX / 2),
            )
        }
        _ => {
            // Standard
            (
                std::time::Duration::from_millis(base_inactivity_ms),
                std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
            )
        }
    };
    let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
        client_read, client_write, backend_read, backend_write,
        inactivity_timeout, max_lifetime,
    ).await;
    metrics.record_bytes(bytes_in, bytes_out, route_id);
    Ok(())
}
/// Pick the best TLS config for `domain`: exact match first, then the
/// `*.parent` wildcard, then the `"*"` default entry, and finally any
/// available certificate as a last resort.
fn find_tls_config<'a>(
    domain: &Option<String>,
    tls_configs: &'a HashMap<String, TlsCertConfig>,
) -> Result<&'a TlsCertConfig, Box<dyn std::error::Error + Send + Sync>> {
    if let Some(name) = domain {
        if let Some(config) = tls_configs.get(name) {
            return Ok(config);
        }
        // Wildcard: replace the first label with '*'.
        if let Some((_, parent)) = name.split_once('.') {
            if let Some(config) = tls_configs.get(&format!("*.{}", parent)) {
                return Ok(config);
            }
        }
    }
    tls_configs
        .get("*")
        .or_else(|| tls_configs.values().next())
        .ok_or_else(|| "No TLS certificate available for this domain".into())
}
/// Forward bidirectional between two split streams with inactivity and lifetime timeouts.
///
/// Spawns one copy task per direction plus a watchdog that aborts both copy
/// tasks once `max_lifetime` has elapsed or no byte has been forwarded for
/// at least `inactivity_timeout` (checked at a 5s granularity).
///
/// Returns `(bytes_in, bytes_out)`:
/// - `bytes_in`: bytes copied client -> backend
/// - `bytes_out`: bytes copied backend -> client
async fn forward_bidirectional_split_with_timeouts<R1, W1, R2, W2>(
    mut client_read: R1,
    mut client_write: W1,
    mut backend_read: R2,
    mut backend_write: W2,
    inactivity_timeout: std::time::Duration,
    max_lifetime: std::time::Duration,
) -> (u64, u64)
where
    R1: tokio::io::AsyncRead + Unpin + Send + 'static,
    W1: tokio::io::AsyncWrite + Unpin + Send + 'static,
    R2: tokio::io::AsyncRead + Unpin + Send + 'static,
    W2: tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    // Milliseconds since `start` of the most recent forwarded chunk, shared
    // with the watchdog. 0 means "no data forwarded yet", so before any
    // traffic the inactivity clock effectively measures from connection start.
    let last_activity = Arc::new(AtomicU64::new(0));
    let start = std::time::Instant::now();
    let la1 = Arc::clone(&last_activity);
    // client -> backend copy loop.
    let c2b = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = match client_read.read(&mut buf).await {
                // EOF or read error ends this direction.
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if backend_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            la1.store(
                start.elapsed().as_millis() as u64,
                Ordering::Relaxed,
            );
        }
        // Half-close the backend side so it observes EOF.
        let _ = backend_write.shutdown().await;
        total
    });
    let la2 = Arc::clone(&last_activity);
    // backend -> client copy loop (mirror of the above).
    let b2c = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = match backend_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            la2.store(
                start.elapsed().as_millis() as u64,
                Ordering::Relaxed,
            );
        }
        let _ = client_write.shutdown().await;
        total
    });
    // Watchdog task: check for inactivity and max lifetime
    let la_watch = Arc::clone(&last_activity);
    let c2b_handle = c2b.abort_handle();
    let b2c_handle = b2c.abort_handle();
    let watchdog = tokio::spawn(async move {
        let check_interval = std::time::Duration::from_secs(5);
        let mut last_seen = 0u64;
        loop {
            tokio::time::sleep(check_interval).await;
            // Check max lifetime
            if start.elapsed() >= max_lifetime {
                debug!("Connection exceeded max lifetime, closing");
                c2b_handle.abort();
                b2c_handle.abort();
                break;
            }
            // Check inactivity
            let current = la_watch.load(Ordering::Relaxed);
            if current == last_seen {
                // No activity since last check
                // `current` is a past value of `start.elapsed()`, so this
                // subtraction cannot underflow; with `current == 0` (no
                // traffic yet) it measures from connection start.
                let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
                if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
                    debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
                    c2b_handle.abort();
                    b2c_handle.abort();
                    break;
                }
            }
            last_seen = current;
        }
    });
    // Aborted copy tasks resolve to Err(JoinError); count 0 bytes for them.
    let bytes_in = c2b.await.unwrap_or(0);
    let bytes_out = b2c.await.unwrap_or(0);
    watchdog.abort();
    (bytes_in, bytes_out)
}
}

View File

@@ -0,0 +1,190 @@
use std::io::BufReader;
use std::sync::Arc;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
use tracing::debug;
/// Ensure the default crypto provider is installed.
///
/// rustls 0.23 requires a process-wide `CryptoProvider`. Installing a second
/// time returns an `Err`, which is deliberately ignored — this makes the
/// setup idempotent and safe to call from every entry point.
fn ensure_crypto_provider() {
    let _ = rustls::crypto::ring::default_provider().install_default();
}
/// Build a TLS acceptor from PEM-encoded cert and key data.
///
/// Convenience wrapper over [`build_tls_acceptor_with_config`] with no
/// route-level TLS tuning applied (defaults: TLS 1.2 + 1.3, no session
/// cache tweaks).
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    build_tls_acceptor_with_config(cert_pem, key_pem, None)
}
/// Build a TLS acceptor with optional RouteTls configuration for version/cipher tuning.
///
/// # Errors
/// Fails when the PEM data cannot be parsed or rustls rejects the
/// certificate/key pair.
///
/// NOTE(review): of the `RouteTls` knobs, only `versions` is applied here.
/// When `session_timeout` is set, a fixed 256-entry in-memory session cache
/// is installed but the timeout *value* is only logged, not enforced; the
/// `ciphers` and `honor_cipher_order` fields are not applied at all.
/// Confirm whether this partial application is intentional.
pub fn build_tls_acceptor_with_config(
    cert_pem: &str,
    key_pem: &str,
    tls_config: Option<&rustproxy_config::RouteTls>,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    ensure_crypto_provider();
    let certs = load_certs(cert_pem)?;
    let key = load_private_key(key_pem)?;
    let mut config = if let Some(route_tls) = tls_config {
        // Apply TLS version restrictions
        let versions = resolve_tls_versions(route_tls.versions.as_deref());
        let builder = ServerConfig::builder_with_protocol_versions(&versions);
        builder
            .with_no_client_auth()
            .with_single_cert(certs, key)?
    } else {
        ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)?
    };
    // Apply session timeout if configured
    if let Some(route_tls) = tls_config {
        if let Some(timeout_secs) = route_tls.session_timeout {
            config.session_storage = rustls::server::ServerSessionMemoryCache::new(
                256, // max sessions
            );
            debug!("TLS session timeout configured: {}s", timeout_secs);
        }
    }
    Ok(TlsAcceptor::from(Arc::new(config)))
}
/// Map configured TLS version strings onto rustls protocol handles.
///
/// Accepts several spellings per version ("TLSv1.2", "TLS1.2", "1.2",
/// "TLSv12", and the 1.3 equivalents). Unknown strings are logged and
/// skipped; an absent, empty, or fully-unrecognized list falls back to
/// enabling both TLS 1.2 and TLS 1.3.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
    let both: Vec<&'static rustls::SupportedProtocolVersion> =
        vec![&rustls::version::TLS12, &rustls::version::TLS13];
    let requested = match versions {
        Some(list) if !list.is_empty() => list,
        _ => return both,
    };
    let mut resolved: Vec<&'static rustls::SupportedProtocolVersion> = Vec::new();
    for name in requested {
        let parsed = match name.as_str() {
            "TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => Some(&rustls::version::TLS12),
            "TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => Some(&rustls::version::TLS13),
            other => {
                debug!("Unknown TLS version '{}', ignoring", other);
                None
            }
        };
        // Deduplicate while preserving the order versions were listed in.
        if let Some(version) = parsed {
            if !resolved.contains(&version) {
                resolved.push(version);
            }
        }
    }
    if resolved.is_empty() {
        both
    } else {
        resolved
    }
}
/// Accept a TLS connection from a client stream.
///
/// Drives the server-side handshake on `stream` using `acceptor`.
///
/// # Errors
/// Returns the underlying I/O or TLS error if the handshake fails.
pub async fn accept_tls(
    stream: TcpStream,
    acceptor: &TlsAcceptor,
) -> Result<ServerTlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
    let tls_stream = acceptor.accept(stream).await?;
    debug!("TLS handshake completed");
    Ok(tls_stream)
}
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
///
/// Opens a TCP connection to `host:port` (with `TCP_NODELAY`) and performs
/// a client-side TLS handshake using `host` as the SNI server name.
///
/// SECURITY NOTE: certificate verification is disabled via `InsecureVerifier`
/// (below) — any backend certificate is accepted. This is intentional for
/// internal networks with self-signed backend certs; do not use this path to
/// reach untrusted networks.
///
/// # Errors
/// Fails on TCP connect errors, invalid server names, or handshake failures.
pub async fn connect_tls(
    host: &str,
    port: u16,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
    ensure_crypto_provider();
    let config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(InsecureVerifier))
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));
    let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
    stream.set_nodelay(true)?;
    let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
    let tls_stream = connector.connect(server_name, stream).await?;
    debug!("Backend TLS connection established to {}:{}", host, port);
    Ok(tls_stream)
}
/// Parse every certificate from a PEM string.
///
/// # Errors
/// Fails when the PEM data is malformed or contains no certificates.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
    let mut cursor = BufReader::new(pem.as_bytes());
    // Short-circuit on the first malformed entry, mirroring a fallible collect.
    let mut parsed = Vec::new();
    for item in rustls_pemfile::certs(&mut cursor) {
        parsed.push(item?);
    }
    if parsed.is_empty() {
        return Err("No certificates found in PEM data".into());
    }
    Ok(parsed)
}
/// Load private key from PEM string.
///
/// Delegates format detection to `rustls_pemfile::private_key`, which
/// recognizes PKCS#8, PKCS#1 (RSA) and SEC1 (EC) encoded keys.
///
/// # Errors
/// Fails when the PEM is malformed or contains no recognizable private key.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
    let mut reader = BufReader::new(pem.as_bytes());
    // Try PKCS8 first, then RSA, then EC
    let key = rustls_pemfile::private_key(&mut reader)?
        .ok_or("No private key found in PEM data")?;
    Ok(key)
}
/// Insecure certificate verifier for backend connections (terminate-and-reencrypt).
/// In internal networks, backends may use self-signed certs.
///
/// SECURITY NOTE: every certificate and signature is accepted without any
/// validation. This type must only be used for `connect_tls` backend
/// connections, never for client-facing TLS.
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
    // Accept any server certificate chain unconditionally.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }
    // Accept any TLS 1.2 handshake signature unconditionally.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Accept any TLS 1.3 handshake signature unconditionally.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Advertise a broad set of schemes so handshakes don't fail on algorithm
    // negotiation even though signatures are never actually checked.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
        ]
    }
}

View File

@@ -0,0 +1,16 @@
# Manifest for the rustproxy-routing crate (domain/path/IP/header matchers
# and the port-indexed RouteManager).
[package]
name = "rustproxy-routing"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Route matching engine for RustProxy"

[dependencies]
# NOTE(review): every `workspace = true` entry below needs a matching entry
# under [workspace.dependencies] in the workspace root. `regex` and the
# `rustproxy-config` path dependency do not appear in the root manifest shown
# in this change — confirm they are declared there or the build will fail.
rustproxy-config = { workspace = true }
glob-match = { workspace = true }
ipnet = { workspace = true }
regex = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
arc-swap = { workspace = true }

View File

@@ -0,0 +1,9 @@
//! # rustproxy-routing
//!
//! Route matching engine for RustProxy.
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
// Matcher helper functions live in `matchers`; the lookup engine in
// `route_manager`.
pub mod route_manager;
pub mod matchers;
// Only the manager types are re-exported at the crate root; matcher
// functions are reached through the `matchers` module path.
pub use route_manager::*;

View File

@@ -0,0 +1,86 @@
/// Match a domain against a pattern supporting wildcards.
///
/// Supported patterns:
/// - `*` matches any domain
/// - `*.example.com` matches the bare parent or any single-level subdomain
/// - `example.com` exact match
/// - `**.example.com` matches the bare parent or any depth of subdomain
///
/// Matching is case-insensitive and ignores surrounding whitespace. Any
/// other pattern containing glob syntax is delegated to `glob_match`.
pub fn domain_matches(pattern: &str, domain: &str) -> bool {
    let pattern = pattern.trim().to_lowercase();
    let domain = domain.trim().to_lowercase();
    // Universal wildcard and exact match are the cheap fast paths.
    if pattern == "*" || pattern == domain {
        return true;
    }
    // `*.suffix`: the bare parent or exactly one subdomain level.
    // `strip_suffix` replaces the previous `format!(".{}", suffix)` +
    // `ends_with` combination, avoiding a String allocation per call in the
    // routing hot path while matching the same shape.
    if let Some(suffix) = pattern.strip_prefix("*.") {
        if domain == suffix {
            return true;
        }
        // Succeeds only for "<prefix>.<suffix>"; single level means the
        // remaining prefix contains no further dots.
        return match domain.strip_suffix(suffix).and_then(|p| p.strip_suffix('.')) {
            Some(prefix) => !prefix.contains('.'),
            None => false,
        };
    }
    // `**.suffix`: the bare parent or any depth of subdomain.
    if let Some(suffix) = pattern.strip_prefix("**.") {
        return domain == suffix
            || domain
                .strip_suffix(suffix)
                .map_or(false, |p| p.ends_with('.'));
    }
    // Use glob-match for more complex patterns
    glob_match::glob_match(&pattern, &domain)
}
/// Check if a domain matches any of the given patterns.
pub fn domain_matches_any(patterns: &[&str], domain: &str) -> bool {
    for pattern in patterns {
        if domain_matches(pattern, domain) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    // Covers exact, `*`, `*.`, `**.` patterns and case-insensitivity.
    #[test]
    fn test_exact_match() {
        assert!(domain_matches("example.com", "example.com"));
        assert!(!domain_matches("example.com", "other.com"));
    }
    #[test]
    fn test_wildcard_all() {
        assert!(domain_matches("*", "anything.com"));
        assert!(domain_matches("*", "sub.domain.example.com"));
    }
    #[test]
    fn test_wildcard_subdomain() {
        assert!(domain_matches("*.example.com", "www.example.com"));
        assert!(domain_matches("*.example.com", "api.example.com"));
        // `*.` also accepts the bare parent domain.
        assert!(domain_matches("*.example.com", "example.com"));
        assert!(!domain_matches("*.example.com", "deep.sub.example.com"));
    }
    #[test]
    fn test_double_wildcard() {
        assert!(domain_matches("**.example.com", "www.example.com"));
        assert!(domain_matches("**.example.com", "deep.sub.example.com"));
        assert!(domain_matches("**.example.com", "example.com"));
    }
    #[test]
    fn test_case_insensitive() {
        assert!(domain_matches("Example.COM", "example.com"));
        assert!(domain_matches("*.EXAMPLE.com", "WWW.example.COM"));
    }
}

View File

@@ -0,0 +1,98 @@
use std::collections::HashMap;
use regex::Regex;
/// Match HTTP headers against a set of patterns.
///
/// Every entry in `patterns` must be satisfied by a request header; header
/// names are compared case-insensitively. Pattern values can be:
/// - Exact string: `"application/json"`
/// - Regex (surrounded by /): `"/^text\/.*/"` — an invalid regex degrades
///   to an exact comparison against the literal pattern text.
pub fn headers_match(
    patterns: &HashMap<String, String>,
    headers: &HashMap<String, String>,
) -> bool {
    patterns.iter().all(|(name, pattern)| {
        let wanted = name.to_lowercase();
        // Case-insensitive lookup of the corresponding request header.
        let found = headers
            .iter()
            .find(|(k, _)| k.to_lowercase() == wanted)
            .map(|(_, v)| v.as_str());
        let value = match found {
            Some(v) => v,
            // A required header that is absent fails the whole match.
            None => return false,
        };
        let is_regex = pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2;
        if is_regex {
            match Regex::new(&pattern[1..pattern.len() - 1]) {
                Ok(re) => re.is_match(value),
                // Invalid regex: fall back to a literal comparison.
                Err(_) => value == pattern,
            }
        } else {
            value == pattern
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    // Covers exact matching, regex patterns, and the missing-header case.
    #[test]
    fn test_exact_header_match() {
        let patterns: HashMap<String, String> = {
            let mut m = HashMap::new();
            m.insert("Content-Type".to_string(), "application/json".to_string());
            m
        };
        let headers: HashMap<String, String> = {
            let mut m = HashMap::new();
            m.insert("content-type".to_string(), "application/json".to_string());
            m
        };
        assert!(headers_match(&patterns, &headers));
    }
    #[test]
    fn test_regex_header_match() {
        let patterns: HashMap<String, String> = {
            let mut m = HashMap::new();
            m.insert("Content-Type".to_string(), "/^text\\/.*/".to_string());
            m
        };
        let headers: HashMap<String, String> = {
            let mut m = HashMap::new();
            m.insert("content-type".to_string(), "text/html".to_string());
            m
        };
        assert!(headers_match(&patterns, &headers));
    }
    #[test]
    fn test_missing_header() {
        let patterns: HashMap<String, String> = {
            let mut m = HashMap::new();
            m.insert("X-Custom".to_string(), "value".to_string());
            m
        };
        let headers: HashMap<String, String> = HashMap::new();
        assert!(!headers_match(&patterns, &headers));
    }
}

View File

@@ -0,0 +1,126 @@
use std::net::IpAddr;
use std::str::FromStr;
use ipnet::IpNet;
/// Match an IP address against a pattern.
///
/// Supported patterns:
/// - `*` matches any IP
/// - `192.168.1.0/24` CIDR range
/// - `192.168.1.100` exact match
/// - `192.168.1.*` wildcard (converted to CIDR)
/// - `::ffff:192.168.1.100` IPv6-mapped IPv4
pub fn ip_matches(pattern: &str, ip: &str) -> bool {
    let pattern = pattern.trim();
    if pattern == "*" {
        return true;
    }
    // IPv4-mapped IPv6 clients are compared via their embedded IPv4 address.
    let candidate = normalize_ip_str(ip);
    // CIDR ranges and wildcard patterns both reduce to an IpNet containment
    // check; everything else is an exact textual comparison.
    let cidr = if pattern.contains('/') {
        Some(pattern.to_string())
    } else if pattern.contains('*') {
        match wildcard_to_cidr(pattern) {
            Some(converted) => Some(converted),
            None => return false,
        }
    } else {
        None
    };
    match cidr {
        Some(cidr) => match (IpNet::from_str(&cidr), IpAddr::from_str(&candidate)) {
            (Ok(net), Ok(addr)) => net.contains(&addr),
            // Unparseable pattern or address never matches.
            _ => false,
        },
        None => candidate == normalize_ip_str(pattern),
    }
}
/// Check if an IP matches any of the given patterns.
pub fn ip_matches_any(patterns: &[String], ip: &str) -> bool {
    for pattern in patterns {
        if ip_matches(pattern, ip) {
            return true;
        }
    }
    false
}
/// Normalize IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) to plain IPv4 text.
///
/// The prefix comparison is ASCII case-insensitive, since the IPv6 textual
/// form permits `::FFFF:` as well as `::ffff:`. Inputs are trimmed; anything
/// without the mapped prefix is returned unchanged.
fn normalize_ip_str(ip: &str) -> String {
    let ip = ip.trim();
    // `get(..7)` avoids a potential byte-boundary panic on exotic input and
    // returns None for strings shorter than the prefix.
    if let Some(prefix) = ip.get(..7) {
        if prefix.eq_ignore_ascii_case("::ffff:") {
            return ip[7..].to_string();
        }
    }
    ip.to_string()
}
/// Convert a wildcard IP pattern to CIDR notation.
/// e.g., "192.168.1.*" -> "192.168.1.0/24"
///
/// Returns `None` for anything that is not four dot-separated parts, for
/// non-numeric octets, and — unlike the previous implementation, which
/// silently dropped them — for patterns with concrete octets *after* a
/// wildcard (e.g. "192.168.*.5"), since those cannot be expressed as a
/// single prefix and would otherwise match far too broadly.
fn wildcard_to_cidr(pattern: &str) -> Option<String> {
    let parts: Vec<&str> = pattern.split('.').collect();
    if parts.len() != 4 {
        return None;
    }
    let mut octets = [0u8; 4];
    let mut prefix_len = 0u8;
    let mut seen_wildcard = false;
    for (i, part) in parts.iter().enumerate() {
        if *part == "*" {
            seen_wildcard = true;
        } else if seen_wildcard {
            // Concrete octet after a wildcard: not representable as a CIDR.
            return None;
        } else {
            octets[i] = part.parse::<u8>().ok()?;
            prefix_len += 8;
        }
    }
    Some(format!(
        "{}.{}.{}.{}/{}",
        octets[0], octets[1], octets[2], octets[3], prefix_len
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    // Covers `*`, exact, CIDR, trailing-wildcard, and IPv4-mapped IPv6 input.
    #[test]
    fn test_wildcard_all() {
        assert!(ip_matches("*", "192.168.1.100"));
        assert!(ip_matches("*", "::1"));
    }
    #[test]
    fn test_exact_match() {
        assert!(ip_matches("192.168.1.100", "192.168.1.100"));
        assert!(!ip_matches("192.168.1.100", "192.168.1.101"));
    }
    #[test]
    fn test_cidr() {
        assert!(ip_matches("192.168.1.0/24", "192.168.1.100"));
        assert!(ip_matches("192.168.1.0/24", "192.168.1.1"));
        assert!(!ip_matches("192.168.1.0/24", "192.168.2.1"));
    }
    #[test]
    fn test_wildcard_pattern() {
        assert!(ip_matches("192.168.1.*", "192.168.1.100"));
        assert!(ip_matches("192.168.1.*", "192.168.1.1"));
        assert!(!ip_matches("192.168.1.*", "192.168.2.1"));
    }
    #[test]
    fn test_ipv6_mapped() {
        assert!(ip_matches("192.168.1.100", "::ffff:192.168.1.100"));
        assert!(ip_matches("192.168.1.0/24", "::ffff:192.168.1.50"));
    }
}

View File

@@ -0,0 +1,9 @@
// Matcher submodules — each provides pure matching functions for one
// dimension of route selection.
pub mod domain;
pub mod path;
pub mod ip;
pub mod header;
// Flat re-exports so callers can write `matchers::domain_matches` etc.
pub use domain::*;
pub use path::*;
pub use ip::*;
pub use header::*;

View File

@@ -0,0 +1,65 @@
/// Match a URL path against a pattern supporting wildcards.
///
/// Supported patterns:
/// - `/api/*` matches `/api/anything` (single level)
/// - `/api/**` matches `/api/any/depth/here`
/// - `/exact/path` exact match
/// - `/prefix*` prefix match (delegated to glob matching)
///
/// A pattern without any `*` only ever matches exactly.
pub fn path_matches(pattern: &str, path: &str) -> bool {
    // Exact match
    if pattern == path {
        return true;
    }
    // `prefix/**`: the bare prefix or any depth below it. `strip_prefix`
    // replaces the previous `format!("{}/", prefix)` + `starts_with`
    // combination, avoiding a String allocation per call in the routing
    // hot path.
    if let Some(prefix) = pattern.strip_suffix("/**") {
        return match path.strip_prefix(prefix) {
            Some(rest) => rest.is_empty() || rest.starts_with('/'),
            None => false,
        };
    }
    // `prefix/*`: the bare prefix or exactly one additional path segment.
    if let Some(prefix) = pattern.strip_suffix("/*") {
        return match path.strip_prefix(prefix) {
            Some("") => true,
            Some(rest) => match rest.strip_prefix('/') {
                // Single level means the remaining segment has no slashes.
                Some(segment) => !segment.contains('/'),
                None => false,
            },
            None => false,
        };
    }
    // Star anywhere else: use glob matching
    if pattern.contains('*') {
        return glob_match::glob_match(pattern, path);
    }
    false
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_path() {
assert!(path_matches("/api/users", "/api/users"));
assert!(!path_matches("/api/users", "/api/posts"));
}
#[test]
fn test_single_wildcard() {
assert!(path_matches("/api/*", "/api/users"));
assert!(path_matches("/api/*", "/api/posts"));
assert!(!path_matches("/api/*", "/api/users/123"));
}
#[test]
fn test_double_wildcard() {
assert!(path_matches("/api/**", "/api/users"));
assert!(path_matches("/api/**", "/api/users/123"));
assert!(path_matches("/api/**", "/api/users/123/posts"));
}
}

View File

@@ -0,0 +1,545 @@
use std::collections::HashMap;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode};
use crate::matchers;
/// Context for route matching (subset of connection info).
pub struct MatchContext<'a> {
    /// Local port the connection arrived on.
    pub port: u16,
    /// Domain from SNI or the Host header, when known.
    pub domain: Option<&'a str>,
    /// Request path (HTTP routes only).
    pub path: Option<&'a str>,
    /// Remote peer IP as text, when known.
    pub client_ip: Option<&'a str>,
    /// Negotiated TLS protocol version string, when applicable.
    pub tls_version: Option<&'a str>,
    /// Request headers (HTTP routes only).
    pub headers: Option<&'a HashMap<String, String>>,
    /// Whether the inbound connection is TLS.
    pub is_tls: bool,
}
/// Result of a route match.
pub struct RouteMatchResult<'a> {
    /// The winning route (highest priority that matched the context).
    pub route: &'a RouteConfig,
    /// The selected backend target within the route, if any was resolved.
    pub target: Option<&'a RouteTarget>,
}
/// Port-indexed route lookup with priority-based matching.
/// This is the core routing engine.
pub struct RouteManager {
    /// Routes indexed by port for O(1) port lookup.
    /// Values are indices into `routes`, kept in priority order.
    port_index: HashMap<u16, Vec<usize>>,
    /// All routes, sorted by priority (highest first).
    routes: Vec<RouteConfig>,
}
impl RouteManager {
    /// Create a new RouteManager from a list of routes.
    ///
    /// Disabled routes are dropped; the rest are sorted by
    /// `effective_priority()` descending so the first match found during
    /// lookup is also the highest-priority match, then a port -> route-index
    /// map is built for O(1) port dispatch.
    pub fn new(routes: Vec<RouteConfig>) -> Self {
        let mut manager = Self {
            port_index: HashMap::new(),
            routes: Vec::new(),
        };
        // Filter enabled routes and sort by priority
        let mut enabled_routes: Vec<RouteConfig> = routes
            .into_iter()
            .filter(|r| r.is_enabled())
            .collect();
        enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
        // Build port index (indices refer into the sorted `routes` vec).
        for (idx, route) in enabled_routes.iter().enumerate() {
            for port in route.listening_ports() {
                manager.port_index
                    .entry(port)
                    .or_default()
                    .push(idx);
            }
        }
        manager.routes = enabled_routes;
        manager
    }
    /// Find the best matching route for the given context.
    ///
    /// Returns `None` when no route listens on the port or none matches.
    pub fn find_route<'a>(&'a self, ctx: &MatchContext<'_>) -> Option<RouteMatchResult<'a>> {
        // Get routes for this port
        let route_indices = self.port_index.get(&ctx.port)?;
        // Indices are in priority order, so the first hit wins.
        for &idx in route_indices {
            let route = &self.routes[idx];
            if self.matches_route(route, ctx) {
                // Find the best matching target within the route
                let target = self.find_target(route, ctx);
                return Some(RouteMatchResult { route, target });
            }
        }
        None
    }
    /// Check if a route matches the given context.
    ///
    /// Each criterion present on the route must be satisfied. Domain is the
    /// one exception: when the route restricts domains but the context has
    /// none (e.g. no SNI yet), the route still matches — see the inline note.
    fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool {
        let rm = &route.route_match;
        // Domain matching
        if let Some(ref domains) = rm.domains {
            if let Some(domain) = ctx.domain {
                let patterns = domains.to_vec();
                if !matchers::domain_matches_any(&patterns, domain) {
                    return false;
                }
            }
            // If no domain provided but route requires domain, it depends on context
            // For TLS passthrough, we need SNI; for other cases we may still match
        }
        // Path matching
        if let Some(ref pattern) = rm.path {
            if let Some(path) = ctx.path {
                if !matchers::path_matches(pattern, path) {
                    return false;
                }
            } else {
                // Route requires path but none provided
                return false;
            }
        }
        // Client IP matching
        if let Some(ref client_ips) = rm.client_ip {
            if let Some(ip) = ctx.client_ip {
                if !matchers::ip_matches_any(client_ips, ip) {
                    return false;
                }
            } else {
                return false;
            }
        }
        // TLS version matching
        if let Some(ref tls_versions) = rm.tls_version {
            if let Some(version) = ctx.tls_version {
                if !tls_versions.iter().any(|v| v == version) {
                    return false;
                }
            } else {
                return false;
            }
        }
        // Header matching
        if let Some(ref patterns) = rm.headers {
            if let Some(headers) = ctx.headers {
                if !matchers::headers_match(patterns, headers) {
                    return false;
                }
            } else {
                return false;
            }
        }
        true
    }
    /// Find the best matching target within a route.
    ///
    /// Targets with `target_match` criteria are filtered against the context;
    /// among matching candidates the highest `priority` wins (first wins on
    /// ties). A target without criteria matches everything.
    fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
        let targets = route.action.targets.as_ref()?;
        // Fast path: a single unconditional target.
        if targets.len() == 1 && targets[0].target_match.is_none() {
            return Some(&targets[0]);
        }
        // Sort candidates by priority (already in order from config)
        let mut best: Option<&RouteTarget> = None;
        let mut best_priority = i32::MIN;
        for target in targets {
            let priority = target.priority.unwrap_or(0);
            if let Some(ref tm) = target.target_match {
                if !self.matches_target(tm, ctx) {
                    continue;
                }
            }
            if priority > best_priority || best.is_none() {
                best = Some(target);
                best_priority = priority;
            }
        }
        // Fall back to first target without match criteria
        // (NOTE(review): unconditional targets already participate in the
        // loop above, so this fallback only fires when `targets` is empty
        // of candidates — confirm it is intentionally defensive.)
        best.or_else(|| {
            targets.iter().find(|t| t.target_match.is_none())
        })
    }
    /// Check if a target match criteria matches the context.
    /// All criteria present on the target must be satisfied.
    fn matches_target(
        &self,
        tm: &rustproxy_config::TargetMatch,
        ctx: &MatchContext<'_>,
    ) -> bool {
        // Port matching
        if let Some(ref ports) = tm.ports {
            if !ports.contains(&ctx.port) {
                return false;
            }
        }
        // Path matching
        if let Some(ref pattern) = tm.path {
            if let Some(path) = ctx.path {
                if !matchers::path_matches(pattern, path) {
                    return false;
                }
            } else {
                return false;
            }
        }
        // Header matching
        if let Some(ref patterns) = tm.headers {
            if let Some(headers) = ctx.headers {
                if !matchers::headers_match(patterns, headers) {
                    return false;
                }
            } else {
                return false;
            }
        }
        true
    }
    /// Get all unique listening ports (sorted ascending).
    pub fn listening_ports(&self) -> Vec<u16> {
        let mut ports: Vec<u16> = self.port_index.keys().copied().collect();
        ports.sort();
        ports
    }
    /// Get all routes for a specific port (in priority order).
    pub fn routes_for_port(&self, port: u16) -> Vec<&RouteConfig> {
        self.port_index
            .get(&port)
            .map(|indices| indices.iter().map(|&i| &self.routes[i]).collect())
            .unwrap_or_default()
    }
    /// Get the total number of enabled routes.
    pub fn route_count(&self) -> usize {
        self.routes.len()
    }
    /// Check if any route on the given port requires SNI.
    ///
    /// SNI is needed when several TLS-passthrough routes share the port, or
    /// when a single passthrough route restricts domains beyond `"*"`.
    pub fn port_requires_sni(&self, port: u16) -> bool {
        let routes = self.routes_for_port(port);
        // If multiple passthrough routes on same port, SNI is needed
        let passthrough_routes: Vec<_> = routes
            .iter()
            .filter(|r| {
                r.tls_mode() == Some(&TlsMode::Passthrough)
            })
            .collect();
        if passthrough_routes.len() > 1 {
            return true;
        }
        // Single passthrough route with specific domain restriction needs SNI
        if let Some(route) = passthrough_routes.first() {
            if let Some(ref domains) = route.route_match.domains {
                let domain_list = domains.to_vec();
                // If it's not just a wildcard, SNI is needed
                if !domain_list.iter().all(|d| *d == "*") {
                    return true;
                }
            }
        }
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustproxy_config::*;
    // Builds a minimal forward-action route listening on `port`, optionally
    // restricted to `domain`, with the given priority.
    fn make_route(port: u16, domain: Option<&str>, priority: i32) -> RouteConfig {
        RouteConfig {
            id: None,
            route_match: RouteMatch {
                ports: PortRange::Single(port),
                domains: domain.map(|d| DomainSpec::Single(d.to_string())),
                path: None,
                client_ip: None,
                tls_version: None,
                headers: None,
            },
            action: RouteAction {
                action_type: RouteActionType::Forward,
                targets: Some(vec![RouteTarget {
                    target_match: None,
                    host: HostSpec::Single("localhost".to_string()),
                    port: PortSpec::Fixed(8080),
                    tls: None,
                    websocket: None,
                    load_balancing: None,
                    send_proxy_protocol: None,
                    headers: None,
                    advanced: None,
                    priority: None,
                }]),
                tls: None,
                websocket: None,
                load_balancing: None,
                advanced: None,
                options: None,
                forwarding_engine: None,
                nftables: None,
                send_proxy_protocol: None,
            },
            headers: None,
            security: None,
            name: None,
            description: None,
            priority: Some(priority),
            tags: None,
            enabled: None,
        }
    }
    #[test]
    fn test_basic_routing() {
        let routes = vec![
            make_route(80, Some("example.com"), 0),
            make_route(80, Some("other.com"), 0),
        ];
        let manager = RouteManager::new(routes);
        let ctx = MatchContext {
            port: 80,
            domain: Some("example.com"),
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
            is_tls: false,
        };
        let result = manager.find_route(&ctx);
        assert!(result.is_some());
    }
    #[test]
    fn test_priority_ordering() {
        let routes = vec![
            make_route(80, Some("*.example.com"), 0),
            make_route(80, Some("api.example.com"), 10), // Higher priority
        ];
        let manager = RouteManager::new(routes);
        let ctx = MatchContext {
            port: 80,
            domain: Some("api.example.com"),
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
            is_tls: false,
        };
        let result = manager.find_route(&ctx).unwrap();
        // Should match the higher-priority specific route
        assert!(result.route.route_match.domains.as_ref()
            .map(|d| d.to_vec())
            .unwrap()
            .contains(&"api.example.com"));
    }
    #[test]
    fn test_no_match() {
        let routes = vec![make_route(80, Some("example.com"), 0)];
        let manager = RouteManager::new(routes);
        let ctx = MatchContext {
            port: 443, // Different port
            domain: Some("example.com"),
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
            is_tls: false,
        };
        assert!(manager.find_route(&ctx).is_none());
    }
    #[test]
    fn test_disabled_routes_excluded() {
        let mut route = make_route(80, Some("example.com"), 0);
        route.enabled = Some(false);
        let manager = RouteManager::new(vec![route]);
        assert_eq!(manager.route_count(), 0);
    }
    #[test]
    fn test_listening_ports() {
        let routes = vec![
            make_route(80, Some("a.com"), 0),
            make_route(443, Some("b.com"), 0),
            make_route(80, Some("c.com"), 0), // duplicate port
        ];
        let manager = RouteManager::new(routes);
        let ports = manager.listening_ports();
        assert_eq!(ports, vec![80, 443]);
    }
    #[test]
    fn test_port_requires_sni_single_passthrough() {
        let mut route = make_route(443, Some("example.com"), 0);
        route.action.tls = Some(RouteTls {
            mode: TlsMode::Passthrough,
            certificate: None,
            acme: None,
            versions: None,
            ciphers: None,
            honor_cipher_order: None,
            session_timeout: None,
        });
        let manager = RouteManager::new(vec![route]);
        // Single passthrough route with specific domain needs SNI
        assert!(manager.port_requires_sni(443));
    }
    #[test]
    fn test_port_requires_sni_wildcard_only() {
        let mut route = make_route(443, Some("*"), 0);
        route.action.tls = Some(RouteTls {
            mode: TlsMode::Passthrough,
            certificate: None,
            acme: None,
            versions: None,
            ciphers: None,
            honor_cipher_order: None,
            session_timeout: None,
        });
        let manager = RouteManager::new(vec![route]);
        // Single passthrough route with wildcard doesn't need SNI
        assert!(!manager.port_requires_sni(443));
    }
    #[test]
    fn test_routes_for_port() {
        let routes = vec![
            make_route(80, Some("a.com"), 0),
            make_route(80, Some("b.com"), 0),
            make_route(443, Some("c.com"), 0),
        ];
        let manager = RouteManager::new(routes);
        assert_eq!(manager.routes_for_port(80).len(), 2);
        assert_eq!(manager.routes_for_port(443).len(), 1);
        assert_eq!(manager.routes_for_port(8080).len(), 0);
    }
    #[test]
    fn test_wildcard_domain_matches_any() {
        let routes = vec![make_route(80, Some("*"), 0)];
        let manager = RouteManager::new(routes);
        let ctx = MatchContext {
            port: 80,
            domain: Some("anything.example.com"),
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
            is_tls: false,
        };
        assert!(manager.find_route(&ctx).is_some());
    }
    #[test]
    fn test_no_domain_route_matches_any_domain() {
        let routes = vec![make_route(80, None, 0)];
        let manager = RouteManager::new(routes);
        let ctx = MatchContext {
            port: 80,
            domain: Some("example.com"),
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
            is_tls: false,
        };
        assert!(manager.find_route(&ctx).is_some());
    }
    // Verifies per-target sub-matching: a path-restricted target wins for
    // matching paths, with fallback to the unconditional default target.
    #[test]
    fn test_target_sub_matching() {
        let mut route = make_route(80, Some("example.com"), 0);
        route.action.targets = Some(vec![
            RouteTarget {
                target_match: Some(rustproxy_config::TargetMatch {
                    ports: None,
                    path: Some("/api/*".to_string()),
                    headers: None,
                    method: None,
                }),
                host: HostSpec::Single("api-backend".to_string()),
                port: PortSpec::Fixed(3000),
                tls: None,
                websocket: None,
                load_balancing: None,
                send_proxy_protocol: None,
                headers: None,
                advanced: None,
                priority: Some(10),
            },
            RouteTarget {
                target_match: None,
                host: HostSpec::Single("default-backend".to_string()),
                port: PortSpec::Fixed(8080),
                tls: None,
                websocket: None,
                load_balancing: None,
                send_proxy_protocol: None,
                headers: None,
                advanced: None,
                priority: None,
            },
        ]);
        let manager = RouteManager::new(vec![route]);
        // Should match the API target
        let ctx = MatchContext {
            port: 80,
            domain: Some("example.com"),
            path: Some("/api/users"),
            client_ip: None,
            tls_version: None,
            headers: None,
            is_tls: false,
        };
        let result = manager.find_route(&ctx).unwrap();
        assert_eq!(result.target.unwrap().host.first(), "api-backend");
        // Should fall back to default target
        let ctx = MatchContext {
            port: 80,
            domain: Some("example.com"),
            path: Some("/home"),
            client_ip: None,
            tls_version: None,
            headers: None,
            is_tls: false,
        };
        let result = manager.find_route(&ctx).unwrap();
        assert_eq!(result.target.unwrap().host.first(), "default-backend");
    }
}

View File

@@ -0,0 +1,17 @@
# Manifest for the rustproxy-security crate (IP filtering, rate limiting,
# and authentication helpers).
[package]
name = "rustproxy-security"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "IP filtering, rate limiting, and authentication for RustProxy"

[dependencies]
# NOTE(review): every `workspace = true` entry below needs a matching entry
# under [workspace.dependencies] in the workspace root. `base64` and the
# `rustproxy-config` path dependency do not appear in the root manifest shown
# in this change — confirm they are declared there or the build will fail.
rustproxy-config = { workspace = true }
dashmap = { workspace = true }
ipnet = { workspace = true }
jsonwebtoken = { workspace = true }
base64 = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true }
View File

@@ -0,0 +1,111 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
/// Basic auth validator.
///
/// Holds a static user/password list and the realm announced in
/// `WWW-Authenticate` challenges.
pub struct BasicAuthValidator {
    // (username, password) pairs accepted by `validate`.
    users: Vec<(String, String)>,
    // Realm string used in the WWW-Authenticate challenge.
    realm: String,
}
impl BasicAuthValidator {
    /// Create a validator from `(username, password)` pairs and an optional
    /// realm (defaults to `"Restricted"`).
    pub fn new(users: Vec<(String, String)>, realm: Option<String>) -> Self {
        Self {
            users,
            realm: realm.unwrap_or_else(|| "Restricted".to_string()),
        }
    }
    /// Validate an Authorization header value.
    /// Returns the username if valid.
    ///
    /// The `Basic` scheme token is matched case-insensitively (RFC 7235
    /// makes auth-scheme names case-insensitive), and passwords are compared
    /// in constant time so response timing does not reveal how many leading
    /// characters of a guess were correct.
    pub fn validate(&self, auth_header: &str) -> Option<String> {
        let auth_header = auth_header.trim();
        // `get(..6)` avoids a byte-boundary panic on exotic input and
        // returns None for headers shorter than the scheme prefix.
        let scheme_ok = auth_header
            .get(..6)
            .map_or(false, |s| s.eq_ignore_ascii_case("Basic "));
        if !scheme_ok {
            return None;
        }
        let encoded = &auth_header[6..];
        let decoded = BASE64.decode(encoded).ok()?;
        let credentials = String::from_utf8(decoded).ok()?;
        // Credentials are "user:pass"; the password may itself contain ':'.
        let (username, password) = credentials.split_once(':')?;
        for (u, p) in &self.users {
            if u == username && Self::eq_constant_time(p, password) {
                return Some(username.to_string());
            }
        }
        None
    }
    /// Constant-time equality for same-length strings; only the length
    /// check short-circuits.
    fn eq_constant_time(a: &str, b: &str) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.bytes()
            .zip(b.bytes())
            .fold(0u8, |acc, (x, y)| acc | (x ^ y))
            == 0
    }
    /// Get the realm for WWW-Authenticate header.
    pub fn realm(&self) -> &str {
        &self.realm
    }
    /// Generate the WWW-Authenticate header value.
    pub fn www_authenticate(&self) -> String {
        format!("Basic realm=\"{}\"", self.realm)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    // Validator with two known users and an explicit realm.
    fn make_validator() -> BasicAuthValidator {
        BasicAuthValidator::new(
            vec![
                ("admin".to_string(), "secret".to_string()),
                ("user".to_string(), "pass".to_string()),
            ],
            Some("TestRealm".to_string()),
        )
    }
    // Build a well-formed `Basic <base64(user:pass)>` header value.
    fn encode_basic(user: &str, pass: &str) -> String {
        let encoded = BASE64.encode(format!("{}:{}", user, pass));
        format!("Basic {}", encoded)
    }
    #[test]
    fn test_valid_credentials() {
        let validator = make_validator();
        let header = encode_basic("admin", "secret");
        assert_eq!(validator.validate(&header), Some("admin".to_string()));
    }
    #[test]
    fn test_invalid_password() {
        let validator = make_validator();
        let header = encode_basic("admin", "wrong");
        assert_eq!(validator.validate(&header), None);
    }
    #[test]
    fn test_not_basic_scheme() {
        let validator = make_validator();
        assert_eq!(validator.validate("Bearer sometoken"), None);
    }
    #[test]
    fn test_malformed_base64() {
        let validator = make_validator();
        assert_eq!(validator.validate("Basic !!!not-base64!!!"), None);
    }
    #[test]
    fn test_www_authenticate_format() {
        let validator = make_validator();
        assert_eq!(validator.www_authenticate(), "Basic realm=\"TestRealm\"");
    }
    #[test]
    fn test_default_realm() {
        // No realm supplied -> "Restricted" fallback.
        let validator = BasicAuthValidator::new(vec![], None);
        assert_eq!(validator.www_authenticate(), "Basic realm=\"Restricted\"");
    }
}

View File

@@ -0,0 +1,189 @@
use ipnet::IpNet;
use std::net::IpAddr;
use std::str::FromStr;
/// IP filter supporting CIDR ranges, wildcards, and exact matches.
pub struct IpFilter {
    /// When non-empty, an IP must match at least one of these patterns.
    allow_list: Vec<IpPattern>,
    /// A match here always denies, checked before the allow list.
    block_list: Vec<IpPattern>,
}
/// Represents an IP pattern for matching.
#[derive(Debug)]
enum IpPattern {
    /// Exact IP match
    Exact(IpAddr),
    /// CIDR range match
    Cidr(IpNet),
    /// Wildcard (matches everything)
    Wildcard,
    /// Unparseable pattern string; matches nothing.
    Invalid,
}
impl IpPattern {
    /// Parse a pattern string: `*`, a CIDR such as `10.0.0.0/8`,
    /// or a plain IPv4/IPv6 address. Anything else becomes `Invalid`.
    fn parse(s: &str) -> Self {
        let s = s.trim();
        if s == "*" {
            return IpPattern::Wildcard;
        }
        if let Ok(net) = IpNet::from_str(s) {
            return IpPattern::Cidr(net);
        }
        if let Ok(addr) = IpAddr::from_str(s) {
            return IpPattern::Exact(addr);
        }
        // Fix: the previous fallback mapped invalid patterns to
        // Exact(0.0.0.0), which *would* match the real address 0.0.0.0
        // (e.g. an invalid block-list entry could block it, or an
        // invalid allow-list entry admit it). `Invalid` matches nothing.
        IpPattern::Invalid
    }
    /// Whether the given address matches this pattern.
    fn matches(&self, ip: &IpAddr) -> bool {
        match self {
            IpPattern::Wildcard => true,
            IpPattern::Exact(addr) => addr == ip,
            IpPattern::Cidr(net) => net.contains(ip),
            IpPattern::Invalid => false,
        }
    }
}
impl IpFilter {
/// Create a new IP filter from allow and block lists.
pub fn new(allow_list: &[String], block_list: &[String]) -> Self {
Self {
allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(),
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
}
}
/// Check if an IP is allowed.
/// If allow_list is non-empty, IP must match at least one entry.
/// If block_list is non-empty, IP must NOT match any entry.
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
// Check block list first
if !self.block_list.is_empty() {
for pattern in &self.block_list {
if pattern.matches(ip) {
return false;
}
}
}
// If allow list is non-empty, must match at least one
if !self.allow_list.is_empty() {
return self.allow_list.iter().any(|p| p.matches(ip));
}
true
}
/// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x)
pub fn normalize_ip(ip: &IpAddr) -> IpAddr {
match ip {
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
IpAddr::V4(v4)
} else {
*ip
}
}
_ => *ip,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_empty_lists_allow_all() {
        // No patterns at all -> everything is allowed.
        let filter = IpFilter::new(&[], &[]);
        let ip: IpAddr = "192.168.1.1".parse().unwrap();
        assert!(filter.is_allowed(&ip));
    }
    #[test]
    fn test_allow_list_exact() {
        let filter = IpFilter::new(
            &["10.0.0.1".to_string()],
            &[],
        );
        let allowed: IpAddr = "10.0.0.1".parse().unwrap();
        let denied: IpAddr = "10.0.0.2".parse().unwrap();
        assert!(filter.is_allowed(&allowed));
        assert!(!filter.is_allowed(&denied));
    }
    #[test]
    fn test_allow_list_cidr() {
        let filter = IpFilter::new(
            &["10.0.0.0/8".to_string()],
            &[],
        );
        let allowed: IpAddr = "10.255.255.255".parse().unwrap();
        let denied: IpAddr = "192.168.1.1".parse().unwrap();
        assert!(filter.is_allowed(&allowed));
        assert!(!filter.is_allowed(&denied));
    }
    #[test]
    fn test_block_list() {
        let filter = IpFilter::new(
            &[],
            &["192.168.1.100".to_string()],
        );
        let blocked: IpAddr = "192.168.1.100".parse().unwrap();
        let allowed: IpAddr = "192.168.1.101".parse().unwrap();
        assert!(!filter.is_allowed(&blocked));
        assert!(filter.is_allowed(&allowed));
    }
    #[test]
    fn test_block_trumps_allow() {
        // An IP in both lists must be denied: block wins over allow.
        let filter = IpFilter::new(
            &["10.0.0.0/8".to_string()],
            &["10.0.0.5".to_string()],
        );
        let blocked: IpAddr = "10.0.0.5".parse().unwrap();
        let allowed: IpAddr = "10.0.0.6".parse().unwrap();
        assert!(!filter.is_allowed(&blocked));
        assert!(filter.is_allowed(&allowed));
    }
    #[test]
    fn test_wildcard_allow() {
        let filter = IpFilter::new(
            &["*".to_string()],
            &[],
        );
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        assert!(filter.is_allowed(&ip));
    }
    #[test]
    fn test_wildcard_block() {
        let filter = IpFilter::new(
            &[],
            &["*".to_string()],
        );
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        assert!(!filter.is_allowed(&ip));
    }
    #[test]
    fn test_normalize_ipv4_mapped_ipv6() {
        let mapped: IpAddr = "::ffff:192.168.1.1".parse().unwrap();
        let normalized = IpFilter::normalize_ip(&mapped);
        let expected: IpAddr = "192.168.1.1".parse().unwrap();
        assert_eq!(normalized, expected);
    }
    #[test]
    fn test_normalize_pure_ipv4() {
        // Plain IPv4 passes through unchanged.
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let normalized = IpFilter::normalize_ip(&ip);
        assert_eq!(normalized, ip);
    }
}

View File

@@ -0,0 +1,174 @@
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
use serde::{Deserialize, Serialize};
/// JWT claims (minimal structure).
///
/// All fields are optional so tokens carrying only a subset of the
/// standard claims still deserialize.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Subject.
    pub sub: Option<String>,
    /// Expiry, seconds since the Unix epoch.
    pub exp: Option<u64>,
    /// Issuer.
    pub iss: Option<String>,
    /// Audience. NOTE(review): a JWT `aud` claim may also be an array
    /// of strings; such tokens would fail to deserialize here — confirm
    /// this single-string shape is sufficient.
    pub aud: Option<String>,
}
/// JWT auth validator.
pub struct JwtValidator {
    /// Key used to verify token signatures.
    decoding_key: DecodingKey,
    /// Algorithm plus optional issuer/audience constraints.
    validation: Validation,
}
impl JwtValidator {
    /// Build a validator for the given shared secret.
    ///
    /// `algorithm` selects HS384, HS512, or RS256; `None` or any
    /// unrecognized name falls back to HS256. Optional `issuer` and
    /// `audience` are enforced during validation when provided.
    ///
    /// NOTE(review): for "RS256" the key is still built via
    /// `DecodingKey::from_secret` (an HMAC-style key); RSA verification
    /// would normally need a PEM key (`from_rsa_pem`) — confirm the
    /// intended usage.
    pub fn new(
        secret: &str,
        algorithm: Option<&str>,
        issuer: Option<&str>,
        audience: Option<&str>,
    ) -> Self {
        let algo = algorithm
            .map(|name| match name {
                "HS384" => Algorithm::HS384,
                "HS512" => Algorithm::HS512,
                "RS256" => Algorithm::RS256,
                _ => Algorithm::HS256,
            })
            .unwrap_or(Algorithm::HS256);
        let mut validation = Validation::new(algo);
        if let Some(iss) = issuer {
            validation.set_issuer(&[iss]);
        }
        if let Some(aud) = audience {
            validation.set_audience(&[aud]);
        }
        Self {
            decoding_key: DecodingKey::from_secret(secret.as_bytes()),
            validation,
        }
    }

    /// Validate a JWT token string (without "Bearer " prefix).
    /// Returns the claims if valid.
    pub fn validate(&self, token: &str) -> Result<Claims, String> {
        match decode::<Claims>(token, &self.decoding_key, &self.validation) {
            Ok(data) => Ok(data.claims),
            Err(err) => Err(err.to_string()),
        }
    }

    /// Extract token from Authorization header.
    pub fn extract_token(auth_header: &str) -> Option<&str> {
        auth_header.trim().strip_prefix("Bearer ")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};
    // Sign `claims` with HS256 (jsonwebtoken's Header::default()).
    fn make_token(secret: &str, claims: &Claims) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }
    // Expiry one hour in the future (token still valid).
    fn future_exp() -> u64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + 3600
    }
    // Expiry one hour in the past (token expired).
    fn past_exp() -> u64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            - 3600
    }
    #[test]
    fn test_valid_token() {
        let secret = "test-secret";
        let claims = Claims {
            sub: Some("user123".to_string()),
            exp: Some(future_exp()),
            iss: None,
            aud: None,
        };
        let token = make_token(secret, &claims);
        let validator = JwtValidator::new(secret, None, None, None);
        let result = validator.validate(&token);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().sub, Some("user123".to_string()));
    }
    #[test]
    fn test_expired_token() {
        let secret = "test-secret";
        let claims = Claims {
            sub: Some("user123".to_string()),
            exp: Some(past_exp()),
            iss: None,
            aud: None,
        };
        let token = make_token(secret, &claims);
        let validator = JwtValidator::new(secret, None, None, None);
        assert!(validator.validate(&token).is_err());
    }
    #[test]
    fn test_wrong_secret() {
        // Signature verification must fail with a different secret.
        let claims = Claims {
            sub: Some("user123".to_string()),
            exp: Some(future_exp()),
            iss: None,
            aud: None,
        };
        let token = make_token("correct-secret", &claims);
        let validator = JwtValidator::new("wrong-secret", None, None, None);
        assert!(validator.validate(&token).is_err());
    }
    #[test]
    fn test_issuer_validation() {
        let secret = "test-secret";
        let claims = Claims {
            sub: Some("user123".to_string()),
            exp: Some(future_exp()),
            iss: Some("my-issuer".to_string()),
            aud: None,
        };
        let token = make_token(secret, &claims);
        // Correct issuer
        let validator = JwtValidator::new(secret, None, Some("my-issuer"), None);
        assert!(validator.validate(&token).is_ok());
        // Wrong issuer
        let validator = JwtValidator::new(secret, None, Some("other-issuer"), None);
        assert!(validator.validate(&token).is_err());
    }
    #[test]
    fn test_extract_token_bearer() {
        assert_eq!(
            JwtValidator::extract_token("Bearer abc123"),
            Some("abc123")
        );
    }
    #[test]
    fn test_extract_token_non_bearer() {
        assert_eq!(JwtValidator::extract_token("Basic abc123"), None);
        assert_eq!(JwtValidator::extract_token("abc123"), None);
    }
}

View File

@@ -0,0 +1,13 @@
//! # rustproxy-security
//!
//! IP filtering, rate limiting, and authentication for RustProxy.
pub mod ip_filter;
pub mod rate_limiter;
pub mod basic_auth;
pub mod jwt_auth;
// Re-export each module's public items at the crate root so callers
// can use e.g. `rustproxy_security::IpFilter` directly.
pub use ip_filter::*;
pub use rate_limiter::*;
pub use basic_auth::*;
pub use jwt_auth::*;

View File

@@ -0,0 +1,97 @@
use dashmap::DashMap;
use std::time::Instant;
/// Sliding window rate limiter.
///
/// Timestamps are stored per key; expired entries are pruned lazily on
/// each `check` call and in bulk via `cleanup`.
pub struct RateLimiter {
    /// Map of key -> list of request timestamps
    windows: DashMap<String, Vec<Instant>>,
    /// Maximum requests per window
    max_requests: u64,
    /// Window duration in seconds
    window_seconds: u64,
}
impl RateLimiter {
    /// Create a limiter allowing `max_requests` per `window_seconds`.
    pub fn new(max_requests: u64, window_seconds: u64) -> Self {
        Self {
            windows: DashMap::new(),
            max_requests,
            window_seconds,
        }
    }

    /// The sliding-window duration.
    fn window(&self) -> std::time::Duration {
        std::time::Duration::from_secs(self.window_seconds)
    }

    /// Check if a request is allowed for the given key.
    /// Returns true if allowed, false if rate limited.
    pub fn check(&self, key: &str) -> bool {
        let now = Instant::now();
        let window = self.window();
        let mut entry = self.windows.entry(key.to_string()).or_default();
        let stamps = entry.value_mut();
        // Drop timestamps that have fallen out of the window.
        stamps.retain(|t| now.duration_since(*t) < window);
        let allowed = (stamps.len() as u64) < self.max_requests;
        if allowed {
            // Record this request only when it is admitted.
            stamps.push(now);
        }
        allowed
    }

    /// Clean up expired entries (call periodically).
    pub fn cleanup(&self) {
        let now = Instant::now();
        let window = self.window();
        self.windows.retain(|_, stamps| {
            stamps.retain(|t| now.duration_since(*t) < window);
            !stamps.is_empty()
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_allow_under_limit() {
        let limiter = RateLimiter::new(5, 60);
        for _ in 0..5 {
            assert!(limiter.check("client-1"));
        }
    }
    #[test]
    fn test_block_over_limit() {
        let limiter = RateLimiter::new(3, 60);
        assert!(limiter.check("client-1"));
        assert!(limiter.check("client-1"));
        assert!(limiter.check("client-1"));
        assert!(!limiter.check("client-1")); // 4th request blocked
    }
    #[test]
    fn test_different_keys_independent() {
        // Each key tracks its own window.
        let limiter = RateLimiter::new(2, 60);
        assert!(limiter.check("client-a"));
        assert!(limiter.check("client-a"));
        assert!(!limiter.check("client-a")); // blocked
        // Different key should still be allowed
        assert!(limiter.check("client-b"));
        assert!(limiter.check("client-b"));
    }
    #[test]
    fn test_cleanup_removes_expired() {
        let limiter = RateLimiter::new(100, 0); // 0 second window = immediately expired
        limiter.check("client-1");
        // Sleep briefly to let entries expire
        std::thread::sleep(std::time::Duration::from_millis(10));
        limiter.cleanup();
        // After cleanup, the key should be allowed again (entries expired)
        assert!(limiter.check("client-1"));
    }
}

View File

@@ -0,0 +1,22 @@
[package]
name = "rustproxy-tls"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "TLS certificate management for RustProxy"
# All dependency versions are inherited from [workspace.dependencies]
# in the workspace root manifest.
[dependencies]
rustproxy-config = { workspace = true }
tokio = { workspace = true }
rustls = { workspace = true }
instant-acme = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
# NOTE(review): `anyhow` does not appear in the visible portion of the
# root [workspace.dependencies] table — confirm it is declared there,
# otherwise `workspace = true` will fail to resolve at build time.
anyhow = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
rcgen = { workspace = true }
[dev-dependencies]
tempfile = { workspace = true }

View File

@@ -0,0 +1,360 @@
//! ACME (Let's Encrypt) integration using instant-acme.
//!
//! This module handles HTTP-01 challenge creation and certificate provisioning.
//! Supports persisting ACME account credentials to disk for reuse across restarts.
use std::path::{Path, PathBuf};
use instant_acme::{
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
AccountCredentials,
};
use rcgen::{CertificateParams, KeyPair};
use thiserror::Error;
use tracing::{debug, info, warn};
/// Errors produced during ACME account handling and certificate
/// provisioning. Messages carry the stringified cause from the
/// underlying `instant-acme` call.
#[derive(Debug, Error)]
pub enum AcmeError {
    #[error("ACME account creation failed: {0}")]
    AccountCreation(String),
    #[error("ACME order failed: {0}")]
    OrderFailed(String),
    #[error("Challenge failed: {0}")]
    ChallengeFailed(String),
    #[error("Certificate finalization failed: {0}")]
    FinalizationFailed(String),
    #[error("No HTTP-01 challenge found")]
    NoHttp01Challenge,
    #[error("Timeout waiting for order: {0}")]
    Timeout(String),
    #[error("Account persistence error: {0}")]
    Persistence(String),
}
/// Pending HTTP-01 challenge that needs to be served.
pub struct PendingChallenge {
    /// Challenge token; the response must be served at
    /// `/.well-known/acme-challenge/<token>`.
    pub token: String,
    /// Response body the ACME server expects for the token.
    pub key_authorization: String,
    /// Domain this challenge authorizes.
    pub domain: String,
}
/// ACME client wrapper around instant-acme.
pub struct AcmeClient {
    /// true = Let's Encrypt production directory, false = staging.
    use_production: bool,
    /// Contact email used when registering the ACME account.
    email: String,
    /// Optional directory where account.json is persisted.
    account_dir: Option<PathBuf>,
}
impl AcmeClient {
    /// Create a client without on-disk account persistence; a fresh
    /// ACME account is registered on every provisioning run.
    pub fn new(email: String, use_production: bool) -> Self {
        Self {
            use_production,
            email,
            account_dir: None,
        }
    }

    /// Create a new client with account persistence at the given directory.
    pub fn with_persistence(email: String, use_production: bool, account_dir: impl AsRef<Path>) -> Self {
        Self {
            use_production,
            email,
            account_dir: Some(account_dir.as_ref().to_path_buf()),
        }
    }

    /// Get or create an ACME account, persisting credentials if account_dir is set.
    ///
    /// Restore path: read `account.json` from `account_dir`. Any read,
    /// parse, or restore failure is only logged — the code then falls
    /// through to registering a brand-new account.
    async fn get_or_create_account(&self) -> Result<Account, AcmeError> {
        let directory_url = self.directory_url();
        // Try to restore from persisted credentials
        if let Some(ref dir) = self.account_dir {
            let account_file = dir.join("account.json");
            if account_file.exists() {
                match std::fs::read_to_string(&account_file) {
                    Ok(json) => {
                        match serde_json::from_str::<AccountCredentials>(&json) {
                            Ok(credentials) => {
                                match Account::from_credentials(credentials).await {
                                    Ok(account) => {
                                        debug!("Restored ACME account from {}", account_file.display());
                                        return Ok(account);
                                    }
                                    Err(e) => {
                                        warn!("Failed to restore ACME account, creating new: {}", e);
                                    }
                                }
                            }
                            Err(e) => {
                                warn!("Invalid account.json, creating new account: {}", e);
                            }
                        }
                    }
                    Err(e) => {
                        warn!("Could not read account.json: {}", e);
                    }
                }
            }
        }
        // Create a new account
        let contact = format!("mailto:{}", self.email);
        let (account, credentials) = Account::create(
            &NewAccount {
                contact: &[&contact],
                terms_of_service_agreed: true,
                only_return_existing: false,
            },
            directory_url,
            None,
        )
        .await
        .map_err(|e| AcmeError::AccountCreation(e.to_string()))?;
        debug!("ACME account created");
        // Persist credentials if we have a directory. Persistence
        // failures are logged but never fail account creation itself.
        if let Some(ref dir) = self.account_dir {
            if let Err(e) = std::fs::create_dir_all(dir) {
                warn!("Failed to create account directory {}: {}", dir.display(), e);
            } else {
                let account_file = dir.join("account.json");
                match serde_json::to_string_pretty(&credentials) {
                    Ok(json) => {
                        if let Err(e) = std::fs::write(&account_file, &json) {
                            warn!("Failed to persist ACME account to {}: {}", account_file.display(), e);
                        } else {
                            info!("ACME account credentials persisted to {}", account_file.display());
                        }
                    }
                    Err(e) => {
                        warn!("Failed to serialize account credentials: {}", e);
                    }
                }
            }
        }
        Ok(account)
    }

    /// Request a certificate for a domain using the HTTP-01 challenge.
    ///
    /// Returns (cert_chain_pem, private_key_pem) on success.
    ///
    /// The caller must serve the HTTP-01 challenge at:
    /// `http://<domain>/.well-known/acme-challenge/<token>`
    ///
    /// The `challenge_handler` closure is called with a `PendingChallenge`
    /// and must arrange for the challenge response to be served. It should
    /// return once the challenge is ready to be validated.
    pub async fn provision<F, Fut>(
        &self,
        domain: &str,
        challenge_handler: F,
    ) -> Result<(String, String), AcmeError>
    where
        F: FnOnce(PendingChallenge) -> Fut,
        Fut: std::future::Future<Output = Result<(), AcmeError>>,
    {
        info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
        // 1. Get or create ACME account (with persistence)
        let account = self.get_or_create_account().await?;
        // 2. Create order
        let identifier = Identifier::Dns(domain.to_string());
        let mut order = account
            .new_order(&NewOrder {
                identifiers: &[identifier],
            })
            .await
            .map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
        debug!("ACME order created");
        // 3. Get authorizations and find HTTP-01 challenge
        let authorizations = order
            .authorizations()
            .await
            .map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
        // Find the HTTP-01 challenge. First tuple element is the full
        // PendingChallenge handed to the caller; second is the challenge
        // URL used to signal readiness below.
        let (challenge_token, challenge_url) = authorizations
            .iter()
            .flat_map(|auth| auth.challenges.iter())
            .find(|c| c.r#type == ChallengeType::Http01)
            .map(|c| {
                let key_auth = order.key_authorization(c);
                (
                    PendingChallenge {
                        token: c.token.clone(),
                        key_authorization: key_auth.as_str().to_string(),
                        domain: domain.to_string(),
                    },
                    c.url.clone(),
                )
            })
            .ok_or(AcmeError::NoHttp01Challenge)?;
        // Call the handler to set up challenge serving
        challenge_handler(challenge_token).await?;
        // 4. Notify ACME server that challenge is ready
        order
            .set_challenge_ready(&challenge_url)
            .await
            .map_err(|e| AcmeError::ChallengeFailed(e.to_string()))?;
        debug!("Challenge marked as ready, waiting for validation...");
        // 5. Poll for order to become ready (2s interval, >30 attempts
        // matches the "60 seconds" in the timeout message)
        let mut attempts = 0;
        let state = loop {
            tokio::time::sleep(std::time::Duration::from_secs(2)).await;
            let state = order
                .refresh()
                .await
                .map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
            match state.status {
                OrderStatus::Ready | OrderStatus::Valid => break state.status,
                OrderStatus::Invalid => {
                    return Err(AcmeError::ChallengeFailed(
                        "Order became invalid (challenge failed)".to_string(),
                    ));
                }
                _ => {
                    attempts += 1;
                    if attempts > 30 {
                        return Err(AcmeError::Timeout(
                            "Order did not become ready within 60 seconds".to_string(),
                        ));
                    }
                }
            }
        };
        debug!("Order ready, finalizing...");
        // 6. Generate CSR and finalize
        let key_pair = KeyPair::generate().map_err(|e| {
            AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
        })?;
        let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
            AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
        })?;
        params.distinguished_name.push(rcgen::DnType::CommonName, domain);
        let csr = params.serialize_request(&key_pair).map_err(|e| {
            AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
        })?;
        // Finalize only when the order is Ready; a Valid order from the
        // polling loop above needs no finalization.
        if state == OrderStatus::Ready {
            order
                .finalize(csr.der())
                .await
                .map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?;
        }
        // 7. Wait for certificate to be issued (2s interval, >15 attempts
        // matches the "30 seconds" in the timeout message)
        let mut attempts = 0;
        loop {
            let state = order
                .refresh()
                .await
                .map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
            if state.status == OrderStatus::Valid {
                break;
            }
            if state.status == OrderStatus::Invalid {
                return Err(AcmeError::FinalizationFailed(
                    "Order became invalid during finalization".to_string(),
                ));
            }
            attempts += 1;
            if attempts > 15 {
                return Err(AcmeError::Timeout(
                    "Certificate not issued within 30 seconds".to_string(),
                ));
            }
            tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        }
        // 8. Download certificate
        let cert_chain_pem = order
            .certificate()
            .await
            .map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
            .ok_or_else(|| {
                AcmeError::FinalizationFailed("No certificate returned".to_string())
            })?;
        let private_key_pem = key_pair.serialize_pem();
        info!("Certificate provisioned successfully for {}", domain);
        Ok((cert_chain_pem, private_key_pem))
    }

    /// Restore an ACME account from stored credentials.
    pub async fn restore_account(
        &self,
        credentials: AccountCredentials,
    ) -> Result<Account, AcmeError> {
        Account::from_credentials(credentials)
            .await
            .map_err(|e| AcmeError::AccountCreation(e.to_string()))
    }

    /// Get the ACME directory URL based on production/staging.
    pub fn directory_url(&self) -> &str {
        if self.use_production {
            "https://acme-v02.api.letsencrypt.org/directory"
        } else {
            "https://acme-staging-v02.api.letsencrypt.org/directory"
        }
    }

    /// Whether this client is configured for production.
    pub fn is_production(&self) -> bool {
        self.use_production
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_directory_url_staging() {
let client = AcmeClient::new("test@example.com".to_string(), false);
assert!(client.directory_url().contains("staging"));
assert!(!client.is_production());
}
#[test]
fn test_directory_url_production() {
let client = AcmeClient::new("test@example.com".to_string(), true);
assert!(!client.directory_url().contains("staging"));
assert!(client.is_production());
}
#[test]
fn test_with_persistence_sets_account_dir() {
let tmp = tempfile::tempdir().unwrap();
let client = AcmeClient::with_persistence(
"test@example.com".to_string(),
false,
tmp.path(),
);
assert!(client.account_dir.is_some());
assert_eq!(client.account_dir.unwrap(), tmp.path());
}
#[test]
fn test_without_persistence_no_account_dir() {
let client = AcmeClient::new("test@example.com".to_string(), false);
assert!(client.account_dir.is_none());
}
}

View File

@@ -0,0 +1,183 @@
use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use tracing::info;
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
use crate::acme::AcmeClient;
/// Errors surfaced by certificate-manager operations.
#[derive(Debug, Error)]
pub enum CertManagerError {
    /// ACME provisioning failed for a specific domain.
    #[error("ACME provisioning failed for {domain}: {message}")]
    AcmeFailure { domain: String, message: String },
    /// Underlying certificate store failure (auto-converted via `?`).
    #[error("Certificate store error: {0}")]
    Store(#[from] crate::cert_store::CertStoreError),
    /// ACME was requested but no contact email is configured.
    #[error("No ACME email configured")]
    NoEmail,
}
/// Certificate lifecycle manager.
/// Handles ACME provisioning, static cert loading, and renewal.
pub struct CertManager {
    /// Filesystem-backed certificate store and in-memory cache.
    store: CertStore,
    /// Contact email for ACME; `None` disables ACME entirely.
    acme_email: Option<String>,
    /// Use the Let's Encrypt production directory instead of staging.
    use_production: bool,
    /// Renew certificates that expire within this many days.
    renew_before_days: u32,
}
impl CertManager {
    /// Build a manager over `store`.
    ///
    /// `acme_email` enables ACME when set; `use_production` selects the
    /// production vs staging ACME directory; certificates are flagged
    /// for renewal once within `renew_before_days` of expiry.
    pub fn new(
        store: CertStore,
        acme_email: Option<String>,
        use_production: bool,
        renew_before_days: u32,
    ) -> Self {
        Self {
            store,
            acme_email,
            use_production,
            renew_before_days,
        }
    }

    /// Get a certificate for a domain (from cache).
    pub fn get_cert(&self, domain: &str) -> Option<&CertBundle> {
        self.store.get(domain)
    }

    /// Create an ACME client using this manager's configuration.
    /// Returns None if no ACME email is configured.
    /// Account credentials are persisted in the cert store base directory.
    pub fn acme_client(&self) -> Option<AcmeClient> {
        self.acme_email.as_ref().map(|email| {
            AcmeClient::with_persistence(
                email.clone(),
                self.use_production,
                self.store.base_dir(),
            )
        })
    }

    /// Load a static certificate into the store.
    pub fn load_static(
        &mut self,
        domain: String,
        bundle: CertBundle,
    ) -> Result<(), CertManagerError> {
        self.store.store(domain, bundle)?;
        Ok(())
    }

    /// Check and return domains that need certificate renewal.
    ///
    /// A certificate needs renewal if it expires within `renew_before_days`.
    /// Returns a list of domain names needing renewal.
    pub fn check_renewals(&self) -> Vec<String> {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let renewal_threshold = self.renew_before_days as u64 * 86400;
        let mut needs_renewal = Vec::new();
        for (domain, bundle) in self.store.iter() {
            // Only auto-renew ACME certs
            if bundle.metadata.source != CertSource::Acme {
                continue;
            }
            // saturating_sub: already-expired certs yield 0 and are renewed.
            let time_until_expiry = bundle.metadata.expires_at.saturating_sub(now);
            if time_until_expiry < renewal_threshold {
                info!(
                    "Certificate for {} needs renewal (expires in {} days)",
                    domain,
                    time_until_expiry / 86400
                );
                needs_renewal.push(domain.clone());
            }
        }
        needs_renewal
    }

    /// Renew a certificate for a domain.
    ///
    /// Performs the full ACME provision+store flow. The `challenge_setup` closure
    /// is called to arrange for the HTTP-01 challenge to be served. It receives
    /// (token, key_authorization) and must make the challenge response available.
    ///
    /// Returns the new CertBundle on success.
    pub async fn renew_domain<F, Fut>(
        &mut self,
        domain: &str,
        challenge_setup: F,
    ) -> Result<CertBundle, CertManagerError>
    where
        F: FnOnce(String, String) -> Fut,
        Fut: std::future::Future<Output = ()>,
    {
        let acme_client = self.acme_client()
            .ok_or(CertManagerError::NoEmail)?;
        info!("Renewing certificate for {}", domain);
        let domain_owned = domain.to_string();
        let result = acme_client.provision(&domain_owned, |pending| {
            let token = pending.token.clone();
            let key_auth = pending.key_authorization.clone();
            async move {
                challenge_setup(token, key_auth).await;
                Ok(())
            }
        }).await.map_err(|e| CertManagerError::AcmeFailure {
            domain: domain.to_string(),
            message: e.to_string(),
        })?;
        let (cert_pem, key_pem) = result;
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let bundle = CertBundle {
            cert_pem,
            key_pem,
            ca_pem: None,
            metadata: CertMetadata {
                domain: domain.to_string(),
                source: CertSource::Acme,
                issued_at: now,
                // NOTE(review): expiry is assumed to be 90 days rather
                // than parsed from the certificate's NotAfter field —
                // confirm this approximation is acceptable.
                expires_at: now + 90 * 86400,
                renewed_at: Some(now),
            },
        };
        self.store.store(domain.to_string(), bundle.clone())?;
        info!("Certificate renewed and stored for {}", domain);
        Ok(bundle)
    }

    /// Load all certificates from disk.
    pub fn load_all(&mut self) -> Result<usize, CertManagerError> {
        let loaded = self.store.load_all()?;
        info!("Loaded {} certificates from store", loaded);
        Ok(loaded)
    }

    /// Whether this manager has an ACME email configured.
    pub fn has_acme(&self) -> bool {
        self.acme_email.is_some()
    }

    /// Get reference to the underlying store.
    pub fn store(&self) -> &CertStore {
        &self.store
    }

    /// Get mutable reference to the underlying store.
    pub fn store_mut(&mut self) -> &mut CertStore {
        &mut self.store
    }
}

View File

@@ -0,0 +1,314 @@
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors produced by certificate-store operations.
#[derive(Debug, Error)]
pub enum CertStoreError {
    /// No certificate cached for the requested domain.
    #[error("Certificate not found for domain: {0}")]
    NotFound(String),
    /// Filesystem read/write failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Certificate material rejected as invalid.
    #[error("Invalid certificate: {0}")]
    Invalid(String),
    /// metadata.json (de)serialization failure.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
}
/// Certificate metadata stored alongside certs on disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CertMetadata {
    /// Domain the certificate covers.
    pub domain: String,
    /// How the certificate was obtained.
    pub source: CertSource,
    /// Issue time, seconds since the Unix epoch.
    pub issued_at: u64,
    /// Expiry time, seconds since the Unix epoch.
    pub expires_at: u64,
    /// Last renewal time, if the certificate has been renewed.
    pub renewed_at: Option<u64>,
}
/// How a certificate was obtained.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CertSource {
    /// Provisioned via ACME; the only source auto-renewed.
    Acme,
    /// Loaded as a static certificate.
    Static,
    /// Custom-supplied certificate.
    Custom,
    /// Self-signed certificate.
    SelfSigned,
}
/// An in-memory certificate bundle.
#[derive(Debug, Clone)]
pub struct CertBundle {
    /// Private key, PEM-encoded.
    pub key_pem: String,
    /// Certificate (chain), PEM-encoded.
    pub cert_pem: String,
    /// Optional CA certificate, PEM-encoded.
    pub ca_pem: Option<String>,
    /// Metadata persisted as metadata.json next to the PEM files.
    pub metadata: CertMetadata,
}
/// Filesystem-backed certificate store.
///
/// File layout per domain:
/// ```text
/// {base_dir}/{domain}/
///     key.pem
///     cert.pem
///     ca.pem        (optional)
///     metadata.json
/// ```
pub struct CertStore {
    /// Root directory holding one subdirectory per domain.
    base_dir: PathBuf,
    /// In-memory cache of loaded certs
    cache: HashMap<String, CertBundle>,
}
impl CertStore {
/// Create a new cert store at the given directory.
pub fn new(base_dir: impl AsRef<Path>) -> Self {
Self {
base_dir: base_dir.as_ref().to_path_buf(),
cache: HashMap::new(),
}
}
/// Get a certificate by domain.
pub fn get(&self, domain: &str) -> Option<&CertBundle> {
self.cache.get(domain)
}
/// Store a certificate to both cache and filesystem.
pub fn store(&mut self, domain: String, bundle: CertBundle) -> Result<(), CertStoreError> {
// Sanitize domain for directory name (replace wildcards)
let dir_name = domain.replace('*', "_wildcard_");
let cert_dir = self.base_dir.join(&dir_name);
// Create directory
std::fs::create_dir_all(&cert_dir)?;
// Write key
std::fs::write(cert_dir.join("key.pem"), &bundle.key_pem)?;
// Write cert
std::fs::write(cert_dir.join("cert.pem"), &bundle.cert_pem)?;
// Write CA cert if present
if let Some(ref ca) = bundle.ca_pem {
std::fs::write(cert_dir.join("ca.pem"), ca)?;
}
// Write metadata
let metadata_json = serde_json::to_string_pretty(&bundle.metadata)?;
std::fs::write(cert_dir.join("metadata.json"), metadata_json)?;
// Update cache
self.cache.insert(domain, bundle);
Ok(())
}
/// Check if a certificate exists for a domain.
pub fn has(&self, domain: &str) -> bool {
self.cache.contains_key(domain)
}
/// Load all certificates from the base directory.
pub fn load_all(&mut self) -> Result<usize, CertStoreError> {
if !self.base_dir.exists() {
return Ok(0);
}
let entries = std::fs::read_dir(&self.base_dir)?;
let mut loaded = 0;
for entry in entries {
let entry = entry?;
let path = entry.path();
if !path.is_dir() {
continue;
}
let metadata_path = path.join("metadata.json");
let key_path = path.join("key.pem");
let cert_path = path.join("cert.pem");
// All three files must exist
if !metadata_path.exists() || !key_path.exists() || !cert_path.exists() {
continue;
}
// Load metadata
let metadata_str = std::fs::read_to_string(&metadata_path)?;
let metadata: CertMetadata = serde_json::from_str(&metadata_str)?;
// Load key and cert
let key_pem = std::fs::read_to_string(&key_path)?;
let cert_pem = std::fs::read_to_string(&cert_path)?;
// Load CA cert if present
let ca_path = path.join("ca.pem");
let ca_pem = if ca_path.exists() {
Some(std::fs::read_to_string(&ca_path)?)
} else {
None
};
let domain = metadata.domain.clone();
let bundle = CertBundle {
key_pem,
cert_pem,
ca_pem,
metadata,
};
self.cache.insert(domain, bundle);
loaded += 1;
}
Ok(loaded)
}
/// Get the base directory.
pub fn base_dir(&self) -> &Path {
&self.base_dir
}
/// Get the number of cached certificates.
pub fn count(&self) -> usize {
self.cache.len()
}
/// Iterate over all cached certificates.
pub fn iter(&self) -> impl Iterator<Item = (&String, &CertBundle)> {
self.cache.iter()
}
/// Remove a certificate from cache and filesystem.
pub fn remove(&mut self, domain: &str) -> Result<bool, CertStoreError> {
let removed = self.cache.remove(domain).is_some();
if removed {
let dir_name = domain.replace('*', "_wildcard_");
let cert_dir = self.base_dir.join(&dir_name);
if cert_dir.exists() {
std::fs::remove_dir_all(&cert_dir)?;
}
}
Ok(removed)
}
}
#[cfg(test)]
mod tests {
use super::*;
// Fixed-timestamp static bundle with dummy PEM payloads — enough to
// exercise store/load/remove without real key material.
fn make_test_bundle(domain: &str) -> CertBundle {
CertBundle {
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
ca_pem: None,
metadata: CertMetadata {
domain: domain.to_string(),
source: CertSource::Static,
issued_at: 1700000000,
expires_at: 1700000000 + 90 * 86400,
renewed_at: None,
},
}
}
// Storing writes key/cert/metadata files; a fresh store over the same
// directory must reconstruct an identical bundle.
#[test]
fn test_store_and_load_roundtrip() {
let tmp = tempfile::tempdir().unwrap();
let mut store = CertStore::new(tmp.path());
let bundle = make_test_bundle("example.com");
store.store("example.com".to_string(), bundle.clone()).unwrap();
// Verify files exist
let cert_dir = tmp.path().join("example.com");
assert!(cert_dir.join("key.pem").exists());
assert!(cert_dir.join("cert.pem").exists());
assert!(cert_dir.join("metadata.json").exists());
assert!(!cert_dir.join("ca.pem").exists()); // No CA cert
// Load into a fresh store
let mut store2 = CertStore::new(tmp.path());
let loaded = store2.load_all().unwrap();
assert_eq!(loaded, 1);
let loaded_bundle = store2.get("example.com").unwrap();
assert_eq!(loaded_bundle.key_pem, bundle.key_pem);
assert_eq!(loaded_bundle.cert_pem, bundle.cert_pem);
assert_eq!(loaded_bundle.metadata.domain, "example.com");
assert_eq!(loaded_bundle.metadata.source, CertSource::Static);
}
// ca.pem is optional: it must round-trip when present.
#[test]
fn test_store_with_ca_cert() {
let tmp = tempfile::tempdir().unwrap();
let mut store = CertStore::new(tmp.path());
let mut bundle = make_test_bundle("secure.com");
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
store.store("secure.com".to_string(), bundle).unwrap();
let cert_dir = tmp.path().join("secure.com");
assert!(cert_dir.join("ca.pem").exists());
let mut store2 = CertStore::new(tmp.path());
store2.load_all().unwrap();
let loaded = store2.get("secure.com").unwrap();
assert!(loaded.ca_pem.is_some());
}
// load_all must pick up every per-domain subdirectory.
#[test]
fn test_load_all_multiple_certs() {
let tmp = tempfile::tempdir().unwrap();
let mut store = CertStore::new(tmp.path());
store.store("a.com".to_string(), make_test_bundle("a.com")).unwrap();
store.store("b.com".to_string(), make_test_bundle("b.com")).unwrap();
store.store("c.com".to_string(), make_test_bundle("c.com")).unwrap();
let mut store2 = CertStore::new(tmp.path());
let loaded = store2.load_all().unwrap();
assert_eq!(loaded, 3);
assert!(store2.has("a.com"));
assert!(store2.has("b.com"));
assert!(store2.has("c.com"));
}
// A missing base directory is not an error — it just loads nothing.
#[test]
fn test_load_all_missing_directory() {
let mut store = CertStore::new("/nonexistent/path/to/certs");
let loaded = store.load_all().unwrap();
assert_eq!(loaded, 0);
}
// remove() must clear both the cache entry and the on-disk directory.
#[test]
fn test_remove_cert() {
let tmp = tempfile::tempdir().unwrap();
let mut store = CertStore::new(tmp.path());
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com")).unwrap();
assert!(store.has("remove-me.com"));
let removed = store.remove("remove-me.com").unwrap();
assert!(removed);
assert!(!store.has("remove-me.com"));
assert!(!tmp.path().join("remove-me.com").exists());
}
// '*' is illegal in many filesystems; wildcard certs use a sanitized
// directory name but keep the original domain as the cache key.
#[test]
fn test_wildcard_domain_storage() {
let tmp = tempfile::tempdir().unwrap();
let mut store = CertStore::new(tmp.path());
store.store("*.example.com".to_string(), make_test_bundle("*.example.com")).unwrap();
// Directory should use sanitized name
assert!(tmp.path().join("_wildcard_.example.com").exists());
let mut store2 = CertStore::new(tmp.path());
store2.load_all().unwrap();
assert!(store2.has("*.example.com"));
}
}

View File

@@ -0,0 +1,13 @@
//! # rustproxy-tls
//!
//! TLS certificate management for RustProxy.
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
// On-disk certificate persistence plus an in-memory cache.
pub mod cert_store;
// Lifecycle orchestration: loading, issuing, and renewing certificates.
pub mod cert_manager;
// ACME protocol client integration.
pub mod acme;
// SNI-based certificate selection for TLS handshakes.
pub mod sni_resolver;
// Flatten the most-used types into the crate root.
// NOTE(review): `acme` is not glob re-exported here — confirm callers are
// meant to reach it via the `acme` module path rather than the crate root.
pub use cert_store::*;
pub use cert_manager::*;
pub use sni_resolver::*;

View File

@@ -0,0 +1,139 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::cert_store::CertBundle;
/// Dynamic SNI-based certificate resolver.
/// Used by the TLS stack to select the right certificate based on client SNI.
///
/// Interior mutability via `RwLock` lets a shared resolver be updated
/// (e.g. on certificate rotation) through `&self` while handshakes read
/// concurrently; bundles are `Arc`-shared so resolution never copies PEM data.
pub struct SniResolver {
/// Domain -> certificate bundle mapping
certs: RwLock<HashMap<String, Arc<CertBundle>>>,
/// Fallback certificate (used when no SNI or no match)
fallback: RwLock<Option<Arc<CertBundle>>>,
}
impl SniResolver {
    /// Create an empty resolver: no certificates registered, no fallback.
    pub fn new() -> Self {
        Self {
            certs: RwLock::new(HashMap::new()),
            fallback: RwLock::new(None),
        }
    }

    /// Register (or replace) the certificate for a domain.
    pub fn add_cert(&self, domain: String, bundle: CertBundle) {
        self.certs.write().unwrap().insert(domain, Arc::new(bundle));
    }

    /// Set the fallback certificate used when no domain matches.
    pub fn set_fallback(&self, bundle: CertBundle) {
        *self.fallback.write().unwrap() = Some(Arc::new(bundle));
    }

    /// Resolve a certificate for the given SNI domain.
    ///
    /// Lookup order: exact match first, then a single-level wildcard
    /// (`sub.example.com` -> `*.example.com`), finally the fallback.
    pub fn resolve(&self, domain: &str) -> Option<Arc<CertBundle>> {
        let certs = self.certs.read().unwrap();
        if let Some(found) = certs.get(domain) {
            return Some(Arc::clone(found));
        }
        // Strip the leftmost label and look for a wildcard entry.
        if let Some((_, parent)) = domain.split_once('.') {
            if let Some(found) = certs.get(&format!("*.{}", parent)) {
                return Some(Arc::clone(found));
            }
        }
        self.fallback.read().unwrap().clone()
    }

    /// Unregister the certificate for a domain (no-op if absent).
    pub fn remove_cert(&self, domain: &str) {
        self.certs.write().unwrap().remove(domain);
    }

    /// Number of registered (non-fallback) certificates.
    pub fn cert_count(&self) -> usize {
        self.certs.read().unwrap().len()
    }
}
// An empty resolver (no certs, no fallback) is the natural default.
impl Default for SniResolver {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cert_store::{CertBundle, CertMetadata, CertSource};

    /// Minimal bundle whose PEM fields encode the domain, so asserts can
    /// verify which certificate was selected.
    fn make_bundle(domain: &str) -> CertBundle {
        let metadata = CertMetadata {
            domain: domain.to_string(),
            source: CertSource::Static,
            issued_at: 0,
            expires_at: 0,
            renewed_at: None,
        };
        CertBundle {
            key_pem: format!("KEY-{}", domain),
            cert_pem: format!("CERT-{}", domain),
            ca_pem: None,
            metadata,
        }
    }

    #[test]
    fn test_exact_domain_resolve() {
        let resolver = SniResolver::new();
        resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
        let hit = resolver.resolve("example.com").expect("exact match should resolve");
        assert_eq!(hit.cert_pem, "CERT-example.com");
    }

    #[test]
    fn test_wildcard_resolve() {
        let resolver = SniResolver::new();
        resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com"));
        let hit = resolver.resolve("sub.example.com").expect("wildcard should resolve");
        assert_eq!(hit.cert_pem, "CERT-*.example.com");
    }

    #[test]
    fn test_fallback() {
        let resolver = SniResolver::new();
        resolver.set_fallback(make_bundle("fallback"));
        let hit = resolver.resolve("unknown.com").expect("fallback should resolve");
        assert_eq!(hit.cert_pem, "CERT-fallback");
    }

    #[test]
    fn test_no_match_no_fallback() {
        let resolver = SniResolver::new();
        resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
        assert!(resolver.resolve("other.com").is_none());
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniResolver::new();
        resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
        assert_eq!(resolver.cert_count(), 1);
        resolver.remove_cert("example.com");
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.resolve("example.com").is_none());
    }
}

View File

@@ -0,0 +1,44 @@
# Crate manifest for the top-level `rustproxy` crate: a library exposing the
# embeddable proxy API plus a thin CLI binary around it.
[package]
name = "rustproxy"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
# NOTE(review): description says "built on Pingora", but the crate's docs and
# dependencies (hyper/tokio) suggest a hyper-based engine — confirm wording.
description = "High-performance multi-protocol proxy built on Pingora, compatible with SmartProxy configuration"
[[bin]]
name = "rustproxy"
path = "src/main.rs"
[lib]
name = "rustproxy"
path = "src/lib.rs"
[dependencies]
# Sibling workspace crates that make up the proxy engine.
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-tls = { workspace = true }
rustproxy-passthrough = { workspace = true }
rustproxy-http = { workspace = true }
rustproxy-nftables = { workspace = true }
rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
# NOTE(review): clap, anyhow, and tokio-util use `workspace = true` — confirm
# they are declared under [workspace.dependencies] in the root manifest.
clap = { workspace = true }
anyhow = { workspace = true }
arc-swap = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
rustls = { workspace = true }
tokio-rustls = { workspace = true }
tokio-util = { workspace = true }
dashmap = { workspace = true }
hyper = { workspace = true }
hyper-util = { workspace = true }
http-body-util = { workspace = true }
bytes = { workspace = true }
[dev-dependencies]
# Self-signed certificate generation for TLS tests.
rcgen = { workspace = true }

View File

@@ -0,0 +1,177 @@
//! HTTP-01 ACME challenge server.
//!
//! A lightweight HTTP server that serves ACME challenge responses at
//! `/.well-known/acme-challenge/<token>`.
use std::sync::Arc;
use bytes::Bytes;
use dashmap::DashMap;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, error};
/// ACME HTTP-01 challenge server.
///
/// `DashMap` lets challenges be registered/removed through `&self` while the
/// spawned accept-loop task reads them concurrently.
pub struct ChallengeServer {
/// Token -> key authorization mapping
challenges: Arc<DashMap<String, String>>,
/// Cancellation token to stop the server
cancel: CancellationToken,
/// Server task handle
handle: Option<tokio::task::JoinHandle<()>>,
}
impl ChallengeServer {
/// Create a new challenge server (not yet started).
pub fn new() -> Self {
Self {
challenges: Arc::new(DashMap::new()),
cancel: CancellationToken::new(),
handle: None,
}
}
/// Register a challenge token -> key_authorization mapping.
///
/// Safe to call before or after `start`; the running server sees the new
/// entry immediately via the shared `DashMap`.
pub fn set_challenge(&self, token: String, key_authorization: String) {
debug!("Registered ACME challenge: token={}", token);
self.challenges.insert(token, key_authorization);
}
/// Remove a challenge token.
pub fn remove_challenge(&self, token: &str) {
self.challenges.remove(token);
}
/// Start the challenge server on the given port.
///
/// Binds `0.0.0.0:<port>` and spawns a detached accept loop; the loop only
/// exits when `stop` cancels it. Returns an error if the bind fails
/// (e.g. port already in use or requires privileges).
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = format!("0.0.0.0:{}", port);
let listener = TcpListener::bind(&addr).await?;
info!("ACME challenge server listening on port {}", port);
let challenges = Arc::clone(&self.challenges);
let cancel = self.cancel.clone();
let handle = tokio::spawn(async move {
loop {
// Race shutdown against the next inbound connection.
tokio::select! {
_ = cancel.cancelled() => {
info!("ACME challenge server stopping");
break;
}
result = listener.accept() => {
match result {
Ok((stream, _)) => {
let challenges = Arc::clone(&challenges);
// One task per connection; HTTP/1 only, which is all
// ACME validators use for HTTP-01.
tokio::spawn(async move {
let io = TokioIo::new(stream);
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
let challenges = Arc::clone(&challenges);
async move {
Self::handle_request(req, &challenges)
}
});
let conn = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service);
if let Err(e) = conn.await {
debug!("Challenge server connection error: {}", e);
}
});
}
Err(e) => {
// Brief backoff so a persistent accept error (e.g.
// fd exhaustion) doesn't spin the loop.
error!("Challenge server accept error: {}", e);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
}
}
}
});
self.handle = Some(handle);
Ok(())
}
/// Stop the challenge server.
///
/// Cancels the accept loop, waits up to 5s for it to exit, clears all
/// registered challenges, and installs a fresh cancellation token so the
/// server can be `start`ed again afterwards.
pub async fn stop(&mut self) {
self.cancel.cancel();
if let Some(handle) = self.handle.take() {
let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
handle,
).await;
}
self.challenges.clear();
self.cancel = CancellationToken::new();
info!("ACME challenge server stopped");
}
/// Handle an HTTP request for ACME challenges.
///
/// Serves the key authorization as `text/plain` for known tokens under
/// `/.well-known/acme-challenge/`; everything else gets 404.
fn handle_request(
req: Request<Incoming>,
challenges: &DashMap<String, String>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
let path = req.uri().path();
if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
if let Some(key_auth) = challenges.get(token) {
debug!("Serving ACME challenge for token: {}", token);
return Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/plain")
.body(Full::new(Bytes::from(key_auth.value().clone())))
.unwrap());
}
}
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found")))
.unwrap())
}
}
#[cfg(test)]
mod tests {
use super::*;
// End-to-end lifecycle: start, serve a known token, 404 an unknown one,
// then stop.
// NOTE(review): port 19900 is hardcoded; a parallel test run or busy host
// port will make this flaky — consider binding port 0 if start() ever
// exposes the bound address.
#[tokio::test]
async fn test_challenge_server_lifecycle() {
let mut server = ChallengeServer::new();
// Set a challenge before starting
server.set_challenge("test-token".to_string(), "test-key-auth".to_string());
// Start on a random port
server.start(19900).await.unwrap();
// Give server a moment to start
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Fetch the challenge
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
let io = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move { let _ = conn.await; });
let req = Request::get("/.well-known/acme-challenge/test-token")
.body(Full::new(Bytes::new()))
.unwrap();
let resp = sender.send_request(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// Test 404 for unknown token
let req = Request::get("/.well-known/acme-challenge/unknown")
.body(Full::new(Bytes::new()))
.unwrap();
let resp = sender.send_request(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
server.stop().await;
}
}

View File

@@ -0,0 +1,931 @@
//! # RustProxy
//!
//! High-performance multi-protocol proxy built on Rust,
//! compatible with SmartProxy configuration.
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use rustproxy::RustProxy;
//! use rustproxy_config::{RustProxyOptions, create_https_passthrough_route};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! let options = RustProxyOptions {
//! routes: vec![
//! create_https_passthrough_route("example.com", "backend", 443),
//! ],
//! ..Default::default()
//! };
//!
//! let mut proxy = RustProxy::new(options)?;
//! proxy.start().await?;
//! Ok(())
//! }
//! ```
pub mod challenge_server;
pub mod management;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwap;
use anyhow::Result;
use tracing::{info, warn, debug, error};
// Re-export key types
pub use rustproxy_config;
pub use rustproxy_routing;
pub use rustproxy_passthrough;
pub use rustproxy_tls;
pub use rustproxy_http;
pub use rustproxy_nftables;
pub use rustproxy_metrics;
pub use rustproxy_security;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine};
use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, TlsCertConfig, ConnectionConfig};
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use rustproxy_nftables::{NftManager, rule_builder};
/// Certificate status.
#[derive(Debug, Clone)]
pub struct CertStatus {
// Domain the certificate covers.
pub domain: String,
// Debug-formatted `CertSource` (e.g. "Static", "Acme").
pub source: String,
// Expiry as Unix seconds.
pub expires_at: u64,
// True while `expires_at` lies in the future.
pub is_valid: bool,
}
/// The main RustProxy struct.
/// This is the primary public API matching SmartProxy's interface.
pub struct RustProxy {
// Full configuration; `routes` is kept in sync by `update_routes`.
options: RustProxyOptions,
// Lock-free swappable route table for hot-reload.
route_table: ArcSwap<RouteManager>,
// Present only while started; owns the bound TCP listeners.
listener_manager: Option<TcpListenerManager>,
metrics: Arc<MetricsCollector>,
// Present only when ACME is enabled in options.
cert_manager: Option<Arc<tokio::sync::Mutex<CertManager>>>,
challenge_server: Option<challenge_server::ChallengeServer>,
// Background certificate-renewal task; aborted on stop.
renewal_handle: Option<tokio::task::JoinHandle<()>>,
nft_manager: Option<NftManager>,
started: bool,
// Set on start; used for uptime in statistics.
started_at: Option<Instant>,
/// Path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
socket_handler_relay_path: Option<String>,
}
impl RustProxy {
/// Create a new RustProxy instance with the given configuration.
///
/// Defaults from `options.defaults` are folded into each route before
/// validation, and an ACME certificate manager is constructed only when
/// the options enable it. Fails when route validation reports errors.
pub fn new(mut options: RustProxyOptions) -> Result<Self> {
    // Fill in per-route defaults first so validation sees the final config.
    Self::apply_defaults(&mut options);
    if let Err(errors) = rustproxy_config::validate_routes(&options.routes) {
        for err in &errors {
            warn!("Route validation error: {}", err);
        }
        if !errors.is_empty() {
            anyhow::bail!("Route validation failed with {} errors", errors.len());
        }
    }
    // ACME cert manager is optional; wrap it for shared async access.
    let cert_manager = Self::build_cert_manager(&options)
        .map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
    let routes = RouteManager::new(options.routes.clone());
    Ok(Self {
        route_table: ArcSwap::from(Arc::new(routes)),
        options,
        listener_manager: None,
        metrics: Arc::new(MetricsCollector::new()),
        cert_manager,
        challenge_server: None,
        renewal_handle: None,
        nft_manager: None,
        started: false,
        started_at: None,
        socket_handler_relay_path: None,
    })
}
/// Apply default configuration to routes that lack targets or security.
///
/// Mutates `options.routes` in place; a route keeps its own targets/security
/// when already set. No-op when `options.defaults` is absent.
fn apply_defaults(options: &mut RustProxyOptions) {
let defaults = match &options.defaults {
Some(d) => d.clone(),
None => return,
};
for route in &mut options.routes {
// Apply default target if route has no targets
if route.action.targets.is_none() {
if let Some(ref default_target) = defaults.target {
debug!("Applying default target {}:{} to route {:?}",
default_target.host, default_target.port,
route.name.as_deref().unwrap_or("unnamed"));
route.action.targets = Some(vec![
rustproxy_config::RouteTarget {
target_match: None,
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
port: rustproxy_config::PortSpec::Fixed(default_target.port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}
]);
}
}
// Apply default security if route has no security
if route.security.is_none() {
if let Some(ref default_security) = defaults.security {
let mut security = rustproxy_config::RouteSecurity {
ip_allow_list: None,
ip_block_list: None,
max_connections: default_security.max_connections,
authentication: None,
rate_limit: None,
basic_auth: None,
jwt_auth: None,
};
if let Some(ref allow_list) = default_security.ip_allow_list {
security.ip_allow_list = Some(allow_list.clone());
}
if let Some(ref block_list) = default_security.ip_block_list {
security.ip_block_list = Some(block_list.clone());
}
// Only apply if there's something meaningful
// NOTE(review): this gate only checks the IP lists, so a defaults
// security that carries *only* max_connections is silently dropped —
// confirm whether max_connections alone should also count as
// "meaningful" here.
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
debug!("Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed"));
route.security = Some(security);
}
}
}
}
}
/// Build a CertManager from options.
///
/// Returns `None` unless ACME configuration is present *and* enabled.
/// The store path defaults to "./certs", the staging (non-production)
/// directory is used unless `use_production` is set, and renewal kicks in
/// 30 days before expiry by default.
fn build_cert_manager(options: &RustProxyOptions) -> Option<CertManager> {
    let acme = options
        .acme
        .as_ref()
        .filter(|a| a.enabled.unwrap_or(false))?;
    // `email` takes precedence over the legacy `account_email` field.
    let email = acme.email.clone().or_else(|| acme.account_email.clone());
    let store = CertStore::new(acme.certificate_store.as_deref().unwrap_or("./certs"));
    Some(CertManager::new(
        store,
        email,
        acme.use_production.unwrap_or(false),
        acme.renew_threshold_days.unwrap_or(30),
    ))
}
/// Build ConnectionConfig from RustProxyOptions.
///
/// Pure field mapping: timeouts come from the options' `effective_*`
/// accessors (which encapsulate their defaults); graceful shutdown and the
/// PROXY-protocol flags fall back to explicit defaults here.
fn build_connection_config(options: &RustProxyOptions) -> ConnectionConfig {
ConnectionConfig {
connection_timeout_ms: options.effective_connection_timeout(),
initial_data_timeout_ms: options.effective_initial_data_timeout(),
socket_timeout_ms: options.effective_socket_timeout(),
max_connection_lifetime_ms: options.effective_max_connection_lifetime(),
graceful_shutdown_timeout_ms: options.graceful_shutdown_timeout.unwrap_or(30_000),
max_connections_per_ip: options.max_connections_per_ip,
connection_rate_limit_per_minute: options.connection_rate_limit_per_minute,
keep_alive_treatment: options.keep_alive_treatment.clone(),
keep_alive_inactivity_multiplier: options.keep_alive_inactivity_multiplier,
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
}
}
/// Start the proxy, binding to all configured ports.
///
/// Startup order matters: persisted certs are loaded and missing ones
/// auto-provisioned *before* listeners bind, so TLS configs are complete
/// when traffic arrives; nftables rules and the renewal timer start last.
/// Fails if already started or if any port bind fails.
pub async fn start(&mut self) -> Result<()> {
if self.started {
anyhow::bail!("Proxy is already started");
}
info!("Starting RustProxy...");
// Load persisted certificates
if let Some(ref cm) = self.cert_manager {
let mut cm = cm.lock().await;
match cm.load_all() {
Ok(count) => {
if count > 0 {
info!("Loaded {} persisted certificates", count);
}
}
// Best-effort: a broken store should not prevent startup.
Err(e) => warn!("Failed to load persisted certificates: {}", e),
}
}
// Auto-provision certificates for routes with certificate: 'auto'
self.auto_provision_certificates().await;
let route_manager = self.route_table.load();
let ports = route_manager.listening_ports();
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
// Create TCP listener manager with metrics
let mut listener = TcpListenerManager::with_metrics(
Arc::clone(&*route_manager),
Arc::clone(&self.metrics),
);
// Apply connection config from options
let conn_config = Self::build_connection_config(&self.options);
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
conn_config.connection_timeout_ms,
conn_config.initial_data_timeout_ms,
conn_config.socket_timeout_ms,
conn_config.max_connection_lifetime_ms,
);
listener.set_connection_config(conn_config);
// Extract TLS configurations from routes and cert manager
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
// Also load certs from cert manager into TLS config
// (route-declared certs win over managed ones on domain collisions).
if let Some(ref cm) = self.cert_manager {
let cm = cm.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
}
}
}
if !tls_configs.is_empty() {
debug!("Loaded TLS certificates for {} domains", tls_configs.len());
listener.set_tls_configs(tls_configs);
}
// Bind all ports
for port in &ports {
listener.add_port(*port).await?;
}
self.listener_manager = Some(listener);
self.started = true;
self.started_at = Some(Instant::now());
// Apply NFTables rules for routes using nftables forwarding engine
// (routes cloned to avoid borrowing self.options across the &mut call).
self.apply_nftables_rules(&self.options.routes.clone()).await;
// Start renewal timer if ACME is enabled
self.start_renewal_timer();
info!("RustProxy started successfully on ports: {:?}", ports);
Ok(())
}
/// Auto-provision certificates for routes that use certificate: 'auto'.
///
/// Collects every domain from TLS-terminating routes with an `Auto`
/// certificate spec that has no cached cert, then runs HTTP-01 issuance for
/// each behind a temporary challenge server. Failures are logged, never
/// propagated — startup proceeds without the missing certs.
async fn auto_provision_certificates(&mut self) {
let cm_arc = match self.cert_manager {
Some(ref cm) => Arc::clone(cm),
None => return,
};
let mut domains_to_provision = Vec::new();
for route in &self.options.routes {
let tls_mode = route.tls_mode();
// Only modes that terminate TLS locally need a certificate.
let needs_cert = matches!(
tls_mode,
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
);
if !needs_cert {
continue;
}
let cert_spec = route.action.tls.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Auto(_)) = cert_spec {
if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() {
let domain = domain.to_string();
// Skip if we already have a valid cert
let cm = cm_arc.lock().await;
if cm.store().has(&domain) {
debug!("Already have cert for {}, skipping auto-provision", domain);
continue;
}
drop(cm);
domains_to_provision.push(domain);
}
}
}
}
if domains_to_provision.is_empty() {
return;
}
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
// Start challenge server
// (HTTP-01 validators default to port 80 unless configured otherwise).
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut challenge_server = challenge_server::ChallengeServer::new();
if let Err(e) = challenge_server.start(acme_port).await {
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
return;
}
for domain in &domains_to_provision {
info!("Provisioning certificate for {}", domain);
// Take the client out of the manager lock so the (slow) ACME
// exchange doesn't hold it.
let cm = cm_arc.lock().await;
let acme_client = cm.acme_client();
drop(cm);
if let Some(acme_client) = acme_client {
let challenge_server_ref = &challenge_server;
let result = acme_client.provision(domain, |pending| {
challenge_server_ref.set_challenge(
pending.token.clone(),
pending.key_authorization.clone(),
);
async move { Ok(()) }
}).await;
match result {
Ok((cert_pem, key_pem)) => {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let bundle = CertBundle {
cert_pem,
key_pem,
ca_pem: None,
metadata: CertMetadata {
domain: domain.clone(),
source: CertSource::Acme,
issued_at: now,
expires_at: now + 90 * 86400, // 90 days
renewed_at: None,
},
};
let mut cm = cm_arc.lock().await;
if let Err(e) = cm.load_static(domain.clone(), bundle) {
error!("Failed to store certificate for {}: {}", domain, e);
}
// NOTE(review): this success log fires even when load_static
// just failed above — consider moving it into an else branch.
info!("Certificate provisioned for {}", domain);
}
Err(e) => {
error!("Failed to provision certificate for {}: {}", domain, e);
}
}
}
}
challenge_server.stop().await;
}
/// Start the renewal timer background task.
/// The background task checks for expiring certificates and renews them.
///
/// No-op when ACME is disabled or `auto_renew` is false. The spawned task
/// wakes every `renew_check_interval_hours` (default 24h), asks the cert
/// manager which domains are near expiry, and renews each behind a
/// temporary challenge server. The handle is kept so `stop` can abort it.
fn start_renewal_timer(&mut self) {
let cm_arc = match self.cert_manager {
Some(ref cm) => Arc::clone(cm),
None => return,
};
let auto_renew = self.options.acme.as_ref()
.and_then(|a| a.auto_renew)
.unwrap_or(true);
if !auto_renew {
return;
}
let check_interval_hours = self.options.acme.as_ref()
.and_then(|a| a.renew_check_interval_hours)
.unwrap_or(24);
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let interval = std::time::Duration::from_secs(check_interval_hours as u64 * 3600);
let handle = tokio::spawn(async move {
loop {
// Sleep first: certs were just loaded/provisioned at startup.
tokio::time::sleep(interval).await;
debug!("Certificate renewal check triggered (interval: {}h)", check_interval_hours);
// Check which domains need renewal
// (lock scoped tightly so renewals below don't block other users).
let domains = {
let cm = cm_arc.lock().await;
cm.check_renewals()
};
if domains.is_empty() {
debug!("No certificates need renewal");
continue;
}
info!("Renewing {} certificate(s)", domains.len());
// Start challenge server for renewals
let mut cs = challenge_server::ChallengeServer::new();
if let Err(e) = cs.start(acme_port).await {
error!("Failed to start challenge server for renewal: {}", e);
continue;
}
for domain in &domains {
let cs_ref = &cs;
let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
}).await;
match result {
Ok(_bundle) => {
info!("Successfully renewed certificate for {}", domain);
}
Err(e) => {
error!("Failed to renew certificate for {}: {}", domain, e);
}
}
}
cs.stop().await;
}
});
self.renewal_handle = Some(handle);
}
/// Stop the proxy gracefully.
pub async fn stop(&mut self) -> Result<()> {
if !self.started {
return Ok(());
}
info!("Stopping RustProxy...");
// Stop renewal timer
if let Some(handle) = self.renewal_handle.take() {
handle.abort();
}
// Stop challenge server if running
if let Some(ref mut cs) = self.challenge_server {
cs.stop().await;
}
self.challenge_server = None;
// Clean up NFTables rules
if let Some(ref mut nft) = self.nft_manager {
if let Err(e) = nft.cleanup().await {
warn!("NFTables cleanup failed: {}", e);
}
}
self.nft_manager = None;
if let Some(ref mut listener) = self.listener_manager {
listener.graceful_stop().await;
}
self.listener_manager = None;
self.started = false;
info!("RustProxy stopped");
Ok(())
}
/// Update routes atomically (hot-reload).
///
/// Validates, then swaps the lock-free route table so in-flight connections
/// keep the old table while new ones see the update. Listener ports and TLS
/// configs are reconciled against the new route set, and nftables rules are
/// re-applied. `options.routes` is only replaced once everything succeeded.
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
// Validate new routes
rustproxy_config::validate_routes(&routes)
.map_err(|errors| {
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
})?;
let new_manager = RouteManager::new(routes.clone());
let new_ports = new_manager.listening_ports();
info!("Updating routes: {} routes on {} ports",
new_manager.route_count(), new_ports.len());
// Get old ports
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
listener.listening_ports()
} else {
vec![]
};
// Atomically swap the route table
let new_manager = Arc::new(new_manager);
self.route_table.store(Arc::clone(&new_manager));
// Update listener manager
if let Some(ref mut listener) = self.listener_manager {
listener.update_route_manager(Arc::clone(&new_manager));
// Update TLS configs
// (route-declared certs first, then managed certs fill the gaps).
let mut tls_configs = Self::extract_tls_configs(&routes);
if let Some(ref cm_arc) = self.cert_manager {
let cm = cm_arc.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
}
}
}
listener.set_tls_configs(tls_configs);
// Add new ports
for port in &new_ports {
if !old_ports.contains(port) {
listener.add_port(*port).await?;
}
}
// Remove old ports no longer needed
for port in &old_ports {
if !new_ports.contains(port) {
listener.remove_port(*port);
}
}
}
// Update NFTables rules: remove old, apply new
self.update_nftables_rules(&routes).await;
self.options.routes = routes;
Ok(())
}
/// Provision a certificate for a named route.
///
/// Runs HTTP-01 issuance for the route's *first* domain behind a temporary
/// challenge server, then hot-swaps the resulting certificate into the
/// running listeners. Fails when ACME is disabled, the route is unknown,
/// the route has no domain, or the ACME exchange itself fails.
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
let cm_arc = self.cert_manager.as_ref()
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
// Find the route by name
let route = self.options.routes.iter()
.find(|r| r.name.as_deref() == Some(route_name))
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
let domain = route.route_match.domains.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
// Start challenge server
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut cs = challenge_server::ChallengeServer::new();
cs.start(acme_port).await
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
let cs_ref = &cs;
let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(&domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
}).await;
drop(cm);
// Stop the challenge server before propagating any ACME error so the
// port is always released.
cs.stop().await;
let bundle = result
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
// Hot-swap into TLS configs
if let Some(ref mut listener) = self.listener_manager {
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
}
}
listener.set_tls_configs(tls_configs);
}
info!("Certificate provisioned and loaded for route '{}'", route_name);
Ok(())
}
/// Renew a certificate for a named route.
///
/// Delegates to `provision_certificate`: issuance and renewal follow the
/// same HTTP-01 flow and the same hot-swap into the running listeners.
pub async fn renew_certificate(&mut self, route_name: &str) -> Result<()> {
// Renewal is just re-provisioning
self.provision_certificate(route_name).await
}
/// Get the status of a certificate for a named route.
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
let route = self.options.routes.iter()
.find(|r| r.name.as_deref() == Some(route_name))?;
let domain = route.route_match.domains.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
if let Some(ref cm_arc) = self.cert_manager {
let cm = cm_arc.lock().await;
if let Some(bundle) = cm.get_cert(&domain) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
return Some(CertStatus {
domain,
source: format!("{:?}", bundle.metadata.source),
expires_at: bundle.metadata.expires_at,
is_valid: bundle.metadata.expires_at > now,
});
}
}
None
}
/// Get current metrics snapshot.
///
/// A point-in-time copy from the shared collector; safe to call at any time.
pub fn get_metrics(&self) -> Metrics {
self.metrics.snapshot()
}
/// Add a listening port at runtime.
pub async fn add_listening_port(&mut self, port: u16) -> Result<()> {
if let Some(ref mut listener) = self.listener_manager {
listener.add_port(port).await?;
}
Ok(())
}
/// Remove a listening port at runtime.
pub async fn remove_listening_port(&mut self, port: u16) -> Result<()> {
if let Some(ref mut listener) = self.listener_manager {
listener.remove_port(port);
}
Ok(())
}
/// Get all currently listening ports.
///
/// Empty when the proxy has not been started.
pub fn get_listening_ports(&self) -> Vec<u16> {
    match &self.listener_manager {
        Some(listener) => listener.listening_ports(),
        None => Vec::new(),
    }
}
/// Get statistics snapshot.
///
/// Aggregates connection counters from the metrics collector with the
/// current route count, bound ports, and uptime (0 before start).
pub fn get_statistics(&self) -> Statistics {
    Statistics {
        active_connections: self.metrics.active_connections(),
        total_connections: self.metrics.total_connections(),
        routes_count: self.route_table.load().route_count() as u64,
        listening_ports: self.get_listening_ports(),
        uptime_seconds: self.started_at.map_or(0, |t| t.elapsed().as_secs()),
    }
}
    /// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
    ///
    /// Passing `None` clears the relay path.
    pub fn set_socket_handler_relay_path(&mut self, path: Option<String>) {
        info!("Socket handler relay path set to: {:?}", path);
        self.socket_handler_relay_path = path;
    }
    /// Get the current socket handler relay path.
    ///
    /// Returns a borrowed `&str` view of the stored path, if one is set.
    pub fn get_socket_handler_relay_path(&self) -> Option<&str> {
        self.socket_handler_relay_path.as_deref()
    }
    /// Load a certificate for a domain and hot-swap the TLS configuration.
    ///
    /// Two steps: (1) store the PEM bundle in the certificate manager (when
    /// present), stamped with an assumed 90-day lifetime; (2) rebuild the
    /// listener's TLS config from the routes' static certs, this new cert,
    /// and all certs already held by the manager, then swap it in.
    pub async fn load_certificate(
        &mut self,
        domain: &str,
        cert_pem: String,
        key_pem: String,
        ca_pem: Option<String>,
    ) -> Result<()> {
        info!("Loading certificate for domain: {}", domain);
        // Store in cert manager if available
        if let Some(ref cm_arc) = self.cert_manager {
            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            let bundle = CertBundle {
                cert_pem: cert_pem.clone(),
                key_pem: key_pem.clone(),
                ca_pem: ca_pem.clone(),
                metadata: CertMetadata {
                    domain: domain.to_string(),
                    source: CertSource::Static,
                    issued_at: now,
                    // No expiry is parsed from the PEM; a 90-day lifetime
                    // is assumed for status reporting.
                    expires_at: now + 90 * 86400, // assume 90 days
                    renewed_at: None,
                },
            };
            let mut cm = cm_arc.lock().await;
            cm.load_static(domain.to_string(), bundle)
                .map_err(|e| anyhow::anyhow!("Failed to store certificate: {}", e))?;
        }
        // Hot-swap TLS config on the listener
        if let Some(ref mut listener) = self.listener_manager {
            // Start from the static certs declared on the routes.
            let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
            // Add the new cert
            tls_configs.insert(domain.to_string(), TlsCertConfig {
                cert_pem: cert_pem.clone(),
                key_pem: key_pem.clone(),
            });
            // Also include all existing certs from cert manager
            // (route-declared certs take precedence over stored ones).
            if let Some(ref cm_arc) = self.cert_manager {
                let cm = cm_arc.lock().await;
                for (d, b) in cm.store().iter() {
                    if !tls_configs.contains_key(d) {
                        tls_configs.insert(d.clone(), TlsCertConfig {
                            cert_pem: b.cert_pem.clone(),
                            key_pem: b.key_pem.clone(),
                        });
                    }
                }
            }
            listener.set_tls_configs(tls_configs);
        }
        info!("Certificate loaded and TLS config updated for {}", domain);
        Ok(())
    }
/// Get NFTables status.
pub async fn get_nftables_status(&self) -> Result<HashMap<String, serde_json::Value>> {
match &self.nft_manager {
Some(nft) => Ok(nft.status()),
None => Ok(HashMap::new()),
}
}
    /// Apply NFTables rules for routes using the nftables forwarding engine.
    ///
    /// Builds one DNAT rule set per (source port, resolved target port)
    /// pair for every route whose `forwarding_engine` is `Nftables`.
    /// Per-rule failures are logged and skipped; the manager is retained
    /// either way so rules can be cleaned up later.
    async fn apply_nftables_rules(&mut self, routes: &[RouteConfig]) {
        let nft_routes: Vec<&RouteConfig> = routes.iter()
            .filter(|r| r.action.forwarding_engine.as_ref() == Some(&ForwardingEngine::Nftables))
            .collect();
        if nft_routes.is_empty() {
            return;
        }
        info!("Applying NFTables rules for {} routes", nft_routes.len());
        // The first route that specifies a table name wins; otherwise the
        // default "rustproxy" table is used.
        let table_name = nft_routes.iter()
            .find_map(|r| r.action.nftables.as_ref()?.table_name.clone())
            .unwrap_or_else(|| "rustproxy".to_string())
        ;
        let mut nft = NftManager::new(Some(table_name));
        for route in &nft_routes {
            // Prefer the explicit id, fall back to the name, then a stub.
            let route_id = route.id.as_deref()
                .or(route.name.as_deref())
                .unwrap_or("unnamed");
            // Routes without per-route nftables options get all-None defaults.
            let nft_options = match &route.action.nftables {
                Some(opts) => opts.clone(),
                None => rustproxy_config::NfTablesOptions {
                    preserve_source_ip: None,
                    protocol: None,
                    max_rate: None,
                    priority: None,
                    table_name: None,
                    use_ip_sets: None,
                    use_advanced_nat: None,
                },
            };
            let targets = match &route.action.targets {
                Some(targets) => targets,
                None => {
                    warn!("NFTables route '{}' has no targets, skipping", route_id);
                    continue;
                }
            };
            let source_ports = route.route_match.ports.to_ports();
            for target in targets {
                // Only the first host of a multi-host target is used here.
                let target_host = target.host.first().to_string();
                let target_port_spec = &target.port;
                for &source_port in &source_ports {
                    // Resolve the target port spec against this source port.
                    let resolved_port = target_port_spec.resolve(source_port);
                    let rules = rule_builder::build_dnat_rule(
                        nft.table_name(),
                        "prerouting",
                        source_port,
                        &target_host,
                        resolved_port,
                        &nft_options,
                    );
                    let rule_id = format!("{}-{}-{}", route_id, source_port, resolved_port);
                    if let Err(e) = nft.apply_rules(&rule_id, rules).await {
                        error!("Failed to apply NFTables rules for route '{}': {}", route_id, e);
                    }
                }
            }
        }
        self.nft_manager = Some(nft);
    }
/// Update NFTables rules when routes change.
async fn update_nftables_rules(&mut self, new_routes: &[RouteConfig]) {
// Clean up old rules
if let Some(ref mut nft) = self.nft_manager {
if let Err(e) = nft.cleanup().await {
warn!("NFTables cleanup during update failed: {}", e);
}
}
self.nft_manager = None;
// Apply new rules
self.apply_nftables_rules(new_routes).await;
}
/// Extract TLS configurations from route configs.
fn extract_tls_configs(routes: &[RouteConfig]) -> HashMap<String, TlsCertConfig> {
let mut configs = HashMap::new();
for route in routes {
let tls_mode = route.tls_mode();
let needs_cert = matches!(
tls_mode,
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
);
if !needs_cert {
continue;
}
let cert_spec = route.action.tls.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() {
configs.insert(domain.to_string(), TlsCertConfig {
cert_pem: cert_config.cert.clone(),
key_pem: cert_config.key.clone(),
});
}
}
}
}
configs
}
}

View File

@@ -0,0 +1,90 @@
use clap::Parser;
use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustproxy::RustProxy;
use rustproxy::management;
use rustproxy_config::RustProxyOptions;
// NOTE: the `///` doc comments below double as clap's --help text; they are
// intentionally left untouched so user-visible help output is unchanged.
/// RustProxy - High-performance multi-protocol proxy
#[derive(Parser, Debug)]
#[command(name = "rustproxy", version, about)]
struct Cli {
    /// Path to JSON configuration file
    #[arg(short, long, default_value = "config.json")]
    config: String,
    // Used as the fallback tracing filter when RUST_LOG is not set.
    /// Log level (trace, debug, info, warn, error)
    #[arg(short, long, default_value = "info")]
    log_level: String,
    /// Validate configuration without starting
    #[arg(long)]
    validate: bool,
    // In this mode stdout is reserved for JSON IPC; logs go to stderr.
    /// Run in management mode (JSON-over-stdin IPC for TypeScript wrapper)
    #[arg(long)]
    management: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();
    // Logging goes to stderr so stdout stays clean for management-mode IPC.
    // RUST_LOG (env) wins; otherwise fall back to the --log-level flag.
    let filter = EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| EnvFilter::new(&cli.log_level));
    tracing_subscriber::fmt()
        .with_writer(std::io::stderr)
        .with_env_filter(filter)
        .init();
    // Management mode speaks JSON over stdin/stdout and bypasses config loading.
    if cli.management {
        tracing::info!("RustProxy starting in management mode...");
        return management::management_loop().await;
    }
    tracing::info!("RustProxy starting...");
    // Load configuration from the JSON file given on the command line.
    let options = RustProxyOptions::from_file(&cli.config)
        .map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
    tracing::info!(
        "Loaded {} routes from {}",
        options.routes.len(),
        cli.config
    );
    // --validate: check the route set and exit without binding anything.
    if cli.validate {
        if let Err(errors) = rustproxy_config::validate_routes(&options.routes) {
            for err in &errors {
                tracing::error!("Validation error: {}", err);
            }
            anyhow::bail!("{} validation errors found", errors.len());
        }
        tracing::info!("Configuration is valid");
        return Ok(());
    }
    // Normal mode: start the proxy and run until Ctrl+C.
    let mut proxy = RustProxy::new(options)?;
    proxy.start().await?;
    tracing::info!("RustProxy is running. Press Ctrl+C to stop.");
    tokio::signal::ctrl_c().await?;
    tracing::info!("Shutdown signal received");
    proxy.stop().await?;
    tracing::info!("RustProxy shutdown complete");
    Ok(())
}

View File

@@ -0,0 +1,470 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error};
use crate::RustProxy;
use rustproxy_config::RustProxyOptions;
/// A management request from the TypeScript wrapper.
#[derive(Debug, Deserialize)]
pub struct ManagementRequest {
    // Correlation id, echoed back in the matching response.
    pub id: String,
    // Method name, e.g. "start", "stop", "updateRoutes".
    pub method: String,
    // Method-specific parameters; serde's default yields JSON null when omitted.
    #[serde(default)]
    pub params: serde_json::Value,
}
/// A management response back to the TypeScript wrapper.
#[derive(Debug, Serialize)]
pub struct ManagementResponse {
    // Correlation id copied from the request ("unknown" for unparseable requests).
    pub id: String,
    // Whether the request was handled successfully.
    pub success: bool,
    // Success payload; the key is omitted from JSON entirely when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    // Error message; the key is omitted from JSON entirely when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
/// An unsolicited event from the proxy to the TypeScript wrapper.
#[derive(Debug, Serialize)]
pub struct ManagementEvent {
    // Event name; this module emits "ready", "started", "stopped", "error".
    pub event: String,
    // Event-specific payload (often an empty JSON object).
    pub data: serde_json::Value,
}
impl ManagementResponse {
fn ok(id: String, result: serde_json::Value) -> Self {
Self {
id,
success: true,
result: Some(result),
error: None,
}
}
fn err(id: String, message: String) -> Self {
Self {
id,
success: false,
result: None,
error: Some(message),
}
}
}
/// Write a single line (plus trailing newline) to stdout and flush it.
///
/// Write errors are deliberately ignored: once the IPC channel itself is
/// broken there is nowhere useful to report them.
fn send_line(line: &str) {
    // Use blocking stdout write - we're writing short JSON lines
    use std::io::Write;
    let out = std::io::stdout();
    let mut locked = out.lock();
    let _ = writeln!(locked, "{}", line);
    let _ = locked.flush();
}
/// Serialize a response and emit it as one stdout line.
fn send_response(response: &ManagementResponse) {
    let serialized = serde_json::to_string(response);
    match serialized {
        Ok(line) => send_line(&line),
        Err(e) => error!("Failed to serialize management response: {}", e),
    }
}
/// Emit an unsolicited event line to the TypeScript wrapper via stdout.
fn send_event(event: &str, data: serde_json::Value) {
    let payload = ManagementEvent {
        event: event.to_string(),
        data,
    };
    match serde_json::to_string(&payload) {
        Ok(line) => send_line(&line),
        Err(e) => error!("Failed to serialize management event: {}", e),
    }
}
/// Run the management loop, reading JSON commands from stdin and writing responses to stdout.
///
/// Emits a "ready" event on startup, then processes one JSON request per
/// line until stdin closes (parent exited) or a read error occurs. When
/// stdin closes, any running proxy is stopped before returning.
pub async fn management_loop() -> Result<()> {
    let stdin = BufReader::new(tokio::io::stdin());
    let mut lines = stdin.lines();
    // The proxy instance is owned by this loop; handlers borrow it mutably.
    let mut proxy: Option<RustProxy> = None;
    send_event("ready", serde_json::json!({}));
    loop {
        let line = match lines.next_line().await {
            Ok(Some(line)) => line,
            Ok(None) => {
                // stdin closed - parent process exited
                info!("Management stdin closed, shutting down");
                if let Some(ref mut p) = proxy {
                    let _ = p.stop().await;
                }
                break;
            }
            Err(e) => {
                // NOTE(review): a read error exits the loop without stopping
                // a running proxy — confirm this is intended.
                error!("Error reading management stdin: {}", e);
                break;
            }
        };
        let line = line.trim().to_string();
        // Blank lines are tolerated and skipped.
        if line.is_empty() {
            continue;
        }
        let request: ManagementRequest = match serde_json::from_str(&line) {
            Ok(r) => r,
            Err(e) => {
                error!("Failed to parse management request: {}", e);
                // The request id could not be parsed, so respond with the
                // placeholder id "unknown".
                send_response(&ManagementResponse::err(
                    "unknown".to_string(),
                    format!("Failed to parse request: {}", e),
                ));
                continue;
            }
        };
        let response = handle_request(&request, &mut proxy).await;
        send_response(&response);
    }
    Ok(())
}
/// Dispatch a parsed management request to its per-method handler.
///
/// Unknown method names produce an error response rather than a panic,
/// keeping the IPC channel alive.
async fn handle_request(
    request: &ManagementRequest,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let id = request.id.clone();
    match request.method.as_str() {
        "start" => handle_start(&id, &request.params, proxy).await,
        "stop" => handle_stop(&id, proxy).await,
        "updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
        "getMetrics" => handle_get_metrics(&id, proxy),
        "getStatistics" => handle_get_statistics(&id, proxy),
        "provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
        "renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
        "getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
        "getListeningPorts" => handle_get_listening_ports(&id, proxy),
        "getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
        "setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
        "addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
        "removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
        "loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
        _ => ManagementResponse::err(id, format!("Unknown method: {}", request.method)),
    }
}
/// Handle `start`: build a proxy from the supplied config and start it.
///
/// Errors if a proxy is already running, the config is missing/invalid,
/// or construction/startup fails. Emits "started" / "error" events.
async fn handle_start(
    id: &str,
    params: &serde_json::Value,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    if proxy.is_some() {
        return ManagementResponse::err(id.to_string(), "Proxy is already running".to_string());
    }
    let Some(config) = params.get("config") else {
        return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string());
    };
    let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
        Ok(o) => o,
        Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
    };
    let mut instance = match RustProxy::new(options) {
        Ok(p) => p,
        Err(e) => return ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
    };
    match instance.start().await {
        Ok(()) => {
            send_event("started", serde_json::json!({}));
            // Only keep the instance once startup actually succeeded.
            *proxy = Some(instance);
            ManagementResponse::ok(id.to_string(), serde_json::json!({}))
        }
        Err(e) => {
            send_event("error", serde_json::json!({"message": format!("{}", e)}));
            ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
        }
    }
}
/// Handle `stop`: stop and drop the running proxy, if any.
///
/// Stopping when nothing is running is treated as success (idempotent).
async fn handle_stop(
    id: &str,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_mut() else {
        return ManagementResponse::ok(id.to_string(), serde_json::json!({}));
    };
    if let Err(e) = p.stop().await {
        return ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e));
    }
    *proxy = None;
    send_event("stopped", serde_json::json!({}));
    ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
/// Handle `updateRoutes`: hot-swap the route table on the running proxy.
async fn handle_update_routes(
    id: &str,
    params: &serde_json::Value,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_mut() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    let Some(raw_routes) = params.get("routes") else {
        return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string());
    };
    let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(raw_routes.clone()) {
        Ok(r) => r,
        Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid routes: {}", e)),
    };
    match p.update_routes(routes).await {
        Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
    }
}
/// Handle `getMetrics`: serialize the current metrics snapshot.
fn handle_get_metrics(
    id: &str,
    proxy: &Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_ref() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    match serde_json::to_value(&p.get_metrics()) {
        Ok(v) => ManagementResponse::ok(id.to_string(), v),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
    }
}
/// Handle `getStatistics`: serialize the current statistics snapshot.
fn handle_get_statistics(
    id: &str,
    proxy: &Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_ref() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    match serde_json::to_value(&p.get_statistics()) {
        Ok(v) => ManagementResponse::ok(id.to_string(), v),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
    }
}
/// Handle `provisionCertificate`: trigger certificate provisioning for a route.
async fn handle_provision_certificate(
    id: &str,
    params: &serde_json::Value,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_mut() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    let Some(route_name) = params.get("routeName").and_then(|v| v.as_str()) else {
        return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string());
    };
    match p.provision_certificate(route_name).await {
        Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
    }
}
/// Handle `renewCertificate`: re-provision the certificate for a route.
async fn handle_renew_certificate(
    id: &str,
    params: &serde_json::Value,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_mut() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    let Some(route_name) = params.get("routeName").and_then(|v| v.as_str()) else {
        return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string());
    };
    match p.renew_certificate(route_name).await {
        Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
    }
}
/// Handle `getCertificateStatus`: report certificate status for a route.
///
/// A route without a stored certificate yields a JSON `null` result
/// (success), not an error.
async fn handle_get_certificate_status(
    id: &str,
    params: &serde_json::Value,
    proxy: &Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_ref() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    let Some(route_name) = params.get("routeName").and_then(|v| v.as_str()) else {
        return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string());
    };
    match p.get_certificate_status(route_name).await {
        Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
            "domain": status.domain,
            "source": status.source,
            "expiresAt": status.expires_at,
            "isValid": status.is_valid,
        })),
        None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
    }
}
/// Handle `getListeningPorts`: list bound ports (empty when not running).
fn handle_get_listening_ports(
    id: &str,
    proxy: &Option<RustProxy>,
) -> ManagementResponse {
    let ports = match proxy.as_ref() {
        Some(p) => p.get_listening_ports(),
        None => Vec::new(),
    };
    ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": ports }))
}
/// Handle `getNftablesStatus`: report NFTables state ({} when not running).
async fn handle_get_nftables_status(
    id: &str,
    proxy: &Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_ref() else {
        return ManagementResponse::ok(id.to_string(), serde_json::json!({}));
    };
    let status = match p.get_nftables_status().await {
        Ok(s) => s,
        Err(e) => return ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)),
    };
    match serde_json::to_value(&status) {
        Ok(v) => ManagementResponse::ok(id.to_string(), v),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)),
    }
}
/// Handle `setSocketHandlerRelay`: set or clear the Unix-socket relay path.
async fn handle_set_socket_handler_relay(
    id: &str,
    params: &serde_json::Value,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_mut() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    // An absent or non-string "socketPath" clears the relay (None).
    let socket_path = params.get("socketPath")
        .and_then(|v| v.as_str())
        .map(str::to_string);
    info!("setSocketHandlerRelay: socket_path={:?}", socket_path);
    p.set_socket_handler_relay_path(socket_path);
    ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
/// Handle `addListeningPort`: bind an additional port at runtime.
async fn handle_add_listening_port(
    id: &str,
    params: &serde_json::Value,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_mut() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    let Some(raw_port) = params.get("port").and_then(|v| v.as_u64()) else {
        return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string());
    };
    // Matches the original's behavior: values above u16::MAX are truncated.
    let port = raw_port as u16;
    match p.add_listening_port(port).await {
        Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
    }
}
/// Handle `removeListeningPort`: unbind a port at runtime.
async fn handle_remove_listening_port(
    id: &str,
    params: &serde_json::Value,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_mut() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    let Some(raw_port) = params.get("port").and_then(|v| v.as_u64()) else {
        return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string());
    };
    // Matches the original's behavior: values above u16::MAX are truncated.
    let port = raw_port as u16;
    match p.remove_listening_port(port).await {
        Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
    }
}
/// Handle `loadCertificate`: install a static cert and hot-swap TLS config.
async fn handle_load_certificate(
    id: &str,
    params: &serde_json::Value,
    proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
    let Some(p) = proxy.as_mut() else {
        return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string());
    };
    let Some(domain) = params.get("domain").and_then(|v| v.as_str()).map(str::to_string) else {
        return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string());
    };
    let Some(cert) = params.get("cert").and_then(|v| v.as_str()).map(str::to_string) else {
        return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string());
    };
    let Some(key) = params.get("key").and_then(|v| v.as_str()).map(str::to_string) else {
        return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string());
    };
    // The CA chain is optional.
    let ca = params.get("ca").and_then(|v| v.as_str()).map(str::to_string);
    info!("loadCertificate: domain={}", domain);
    // Load cert into cert manager and hot-swap TLS config
    match p.load_certificate(&domain, cert, key, ca).await {
        Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
        Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
    }
}

View File

@@ -0,0 +1,402 @@
use std::sync::atomic::{AtomicU16, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
/// Monotonic port allocator for tests; starts at 19000 to stay clear of
/// well-known service ports and avoid cross-test collisions.
static NEXT_TEST_PORT: AtomicU16 = AtomicU16::new(19000);
/// Get the next available port for testing.
pub fn next_port() -> u16 {
    NEXT_TEST_PORT.fetch_add(1, Ordering::SeqCst)
}
/// Start a simple TCP echo server that echoes back whatever it receives.
/// Returns the join handle for the server task.
pub async fn start_echo_server(port: u16) -> JoinHandle<()> {
    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .expect("Failed to bind echo server");
    tokio::spawn(async move {
        // Accept until the listener errors (e.g. socket closed).
        while let Ok((mut socket, _)) = listener.accept().await {
            tokio::spawn(async move {
                let mut chunk = vec![0u8; 65536];
                loop {
                    match socket.read(&mut chunk).await {
                        // EOF or read error ends the connection.
                        Ok(0) | Err(_) => break,
                        Ok(n) => {
                            if socket.write_all(&chunk[..n]).await.is_err() {
                                break;
                            }
                        }
                    }
                }
            });
        }
    })
}
/// Start a TCP echo server that prefixes responses to identify which backend responded.
pub async fn start_prefix_echo_server(port: u16, prefix: &str) -> JoinHandle<()> {
    let tag = prefix.to_string();
    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .expect("Failed to bind prefix echo server");
    tokio::spawn(async move {
        while let Ok((mut socket, _)) = listener.accept().await {
            let tag = tag.clone();
            tokio::spawn(async move {
                let mut chunk = vec![0u8; 65536];
                loop {
                    let n = match socket.read(&mut chunk).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => n,
                    };
                    // Reply with "<prefix><payload>" so a test can tell
                    // which backend served the connection.
                    let mut reply = Vec::with_capacity(tag.len() + n);
                    reply.extend_from_slice(tag.as_bytes());
                    reply.extend_from_slice(&chunk[..n]);
                    if socket.write_all(&reply).await.is_err() {
                        break;
                    }
                }
            });
        }
    })
}
/// Start a simple HTTP server that responds with a fixed status and body.
///
/// One request per connection; the request bytes are read and discarded.
pub async fn start_http_server(port: u16, status: u16, body: &str) -> JoinHandle<()> {
    let payload = body.to_string();
    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .expect("Failed to bind HTTP server");
    tokio::spawn(async move {
        while let Ok((mut socket, _)) = listener.accept().await {
            let payload = payload.clone();
            tokio::spawn(async move {
                // Consume (and ignore) the request before replying.
                let mut req = vec![0u8; 8192];
                let _ = socket.read(&mut req).await.unwrap_or(0);
                let reply = format!(
                    "HTTP/1.1 {} OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                    status,
                    payload.len(),
                    payload,
                );
                let _ = socket.write_all(reply.as_bytes()).await;
                let _ = socket.shutdown().await;
            });
        }
    })
}
/// Start an HTTP backend server that echoes back request details as JSON.
/// The response body contains: {"method":"GET","path":"/foo","host":"example.com","backend":"<name>"}
/// Handles exactly one request per connection and replies with
/// `Connection: close`; there is no keep-alive support, so clients must
/// reconnect for each request.
pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandle<()> {
    let name = backend_name.to_string();
    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap_or_else(|_| panic!("Failed to bind HTTP echo backend on port {}", port));
    tokio::spawn(async move {
        loop {
            let (mut stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(_) => break,
            };
            let backend = name.clone();
            tokio::spawn(async move {
                let mut buf = vec![0u8; 16384];
                // Read request data (a single read; assumes the request
                // arrives in one chunk, which holds for small test requests)
                let n = match stream.read(&mut buf).await {
                    Ok(0) | Err(_) => return,
                    Ok(n) => n,
                };
                let req_str = String::from_utf8_lossy(&buf[..n]);
                // Parse first line: METHOD PATH HTTP/x.x
                let first_line = req_str.lines().next().unwrap_or("");
                let parts: Vec<&str> = first_line.split_whitespace().collect();
                let method = parts.first().copied().unwrap_or("UNKNOWN");
                let path = parts.get(1).copied().unwrap_or("/");
                // Extract Host header (case-insensitive match on the name)
                let host = req_str.lines()
                    .find(|l| l.to_lowercase().starts_with("host:"))
                    .map(|l| l[5..].trim())
                    .unwrap_or("unknown");
                let body = format!(
                    r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#,
                    method, path, host, backend
                );
                let response = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                    body.len(),
                    body,
                );
                let _ = stream.write_all(response.as_bytes()).await;
                let _ = stream.shutdown().await;
            });
        }
    })
}
/// Wrap a future with a timeout, preventing tests from hanging.
///
/// Resolves to `Ok(output)` on completion or `Err("Test timed out")` when
/// `secs` seconds elapse first.
pub async fn with_timeout<F, T>(future: F, secs: u64) -> Result<T, &'static str>
where
    F: std::future::Future<Output = T>,
{
    tokio::time::timeout(std::time::Duration::from_secs(secs), future)
        .await
        .map_err(|_| "Test timed out")
}
/// Wait briefly for a server to be ready by attempting TCP connections.
///
/// Polls every 10 ms until a connect succeeds (true) or `timeout_ms`
/// elapses (false).
pub async fn wait_for_port(port: u16, timeout_ms: u64) -> bool {
    let deadline = std::time::Instant::now() + std::time::Duration::from_millis(timeout_ms);
    let addr = format!("127.0.0.1:{}", port);
    while std::time::Instant::now() < deadline {
        if tokio::net::TcpStream::connect(&addr).await.is_ok() {
            return true;
        }
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }
    false
}
/// Helper to create a minimal route config for testing.
///
/// Builds a single-port forward route: traffic on `port` (optionally
/// restricted to `domain`) is forwarded to `target_host:target_port`.
/// All other options are left unset.
pub fn make_test_route(
    port: u16,
    domain: Option<&str>,
    target_host: &str,
    target_port: u16,
) -> rustproxy_config::RouteConfig {
    rustproxy_config::RouteConfig {
        id: None,
        route_match: rustproxy_config::RouteMatch {
            ports: rustproxy_config::PortRange::Single(port),
            // None leaves the route without a domain matcher.
            domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())),
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
        },
        action: rustproxy_config::RouteAction {
            action_type: rustproxy_config::RouteActionType::Forward,
            // A single fixed-port target; no TLS, LB, or header rewriting.
            targets: Some(vec![rustproxy_config::RouteTarget {
                target_match: None,
                host: rustproxy_config::HostSpec::Single(target_host.to_string()),
                port: rustproxy_config::PortSpec::Fixed(target_port),
                tls: None,
                websocket: None,
                load_balancing: None,
                send_proxy_protocol: None,
                headers: None,
                advanced: None,
                priority: None,
            }]),
            tls: None,
            websocket: None,
            load_balancing: None,
            advanced: None,
            options: None,
            forwarding_engine: None,
            nftables: None,
            send_proxy_protocol: None,
        },
        headers: None,
        security: None,
        name: None,
        description: None,
        priority: None,
        tags: None,
        enabled: None,
    }
}
/// Start a simple WebSocket echo backend.
///
/// Accepts WebSocket upgrade requests (HTTP Upgrade: websocket), sends 101 back,
/// then echoes all data received on the connection.
///
/// NOTE(review): the handshake is NOT RFC 6455 compliant — see the
/// Sec-WebSocket-Accept comment below. Suitable only for tests that speak
/// raw TCP after the 101, not for real WebSocket clients.
pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap_or_else(|_| panic!("Failed to bind WS echo backend on port {}", port));
    tokio::spawn(async move {
        loop {
            let (mut stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(_) => break,
            };
            tokio::spawn(async move {
                // Read the HTTP upgrade request
                let mut buf = vec![0u8; 4096];
                let n = match stream.read(&mut buf).await {
                    Ok(0) | Err(_) => return,
                    Ok(n) => n,
                };
                let req_str = String::from_utf8_lossy(&buf[..n]);
                // Extract Sec-WebSocket-Key for proper handshake
                let ws_key = req_str.lines()
                    .find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
                    .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
                    .unwrap_or_default();
                // Simplified handshake: echoes the client key instead of the
                // RFC 6455 SHA-1 + base64 accept value, so spec-compliant
                // WebSocket clients would reject this response.
                let accept_response = format!(
                    "HTTP/1.1 101 Switching Protocols\r\n\
                     Upgrade: websocket\r\n\
                     Connection: Upgrade\r\n\
                     Sec-WebSocket-Accept: {}\r\n\
                     \r\n",
                    ws_key
                );
                if stream.write_all(accept_response.as_bytes()).await.is_err() {
                    return;
                }
                // Echo all data back (raw TCP after upgrade)
                let mut echo_buf = vec![0u8; 65536];
                loop {
                    let n = match stream.read(&mut echo_buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => n,
                    };
                    if stream.write_all(&echo_buf[..n]).await.is_err() {
                        break;
                    }
                }
            });
        }
    })
}
/// Generate a self-signed certificate for testing using rcgen.
/// Returns (cert_pem, key_pem).
pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
    use rcgen::{CertificateParams, KeyPair};
    // SAN list contains only the requested domain; CN is set to match.
    let mut cert_params = CertificateParams::new(vec![domain.to_string()]).unwrap();
    cert_params.distinguished_name.push(rcgen::DnType::CommonName, domain);
    let keys = KeyPair::generate().unwrap();
    let certificate = cert_params.self_signed(&keys).unwrap();
    (certificate.pem(), keys.serialize_pem())
}
/// Start a TLS echo server using the given cert/key.
/// Returns the join handle.
pub async fn start_tls_echo_server(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
.expect("Failed to build TLS acceptor");
let acceptor = Arc::new(acceptor);
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind TLS echo server");
tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let acc = acceptor.clone();
tokio::spawn(async move {
let mut tls_stream = match acc.accept(stream).await {
Ok(s) => s,
Err(_) => return,
};
let mut buf = vec![0u8; 65536];
loop {
let n = match tls_stream.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if tls_stream.write_all(&buf[..n]).await.is_err() {
break;
}
}
});
}
})
}
/// Helper to create a TLS terminate route with static cert for testing.
///
/// Starts from `make_test_route` and attaches a `Terminate` TLS config
/// carrying the inline PEM certificate and key.
pub fn make_tls_terminate_route(
    port: u16,
    domain: &str,
    target_host: &str,
    target_port: u16,
    cert_pem: &str,
    key_pem: &str,
) -> rustproxy_config::RouteConfig {
    let mut route = make_test_route(port, Some(domain), target_host, target_port);
    route.action.tls = Some(rustproxy_config::RouteTls {
        mode: rustproxy_config::TlsMode::Terminate,
        // Inline (static) PEM strings; the *_file variants stay unset.
        certificate: Some(rustproxy_config::CertificateSpec::Static(
            rustproxy_config::CertificateConfig {
                cert: cert_pem.to_string(),
                key: key_pem.to_string(),
                ca: None,
                key_file: None,
                cert_file: None,
            },
        )),
        acme: None,
        versions: None,
        ciphers: None,
        honor_cipher_order: None,
        session_timeout: None,
    });
    route
}
/// Helper to create a TLS passthrough route for testing.
///
/// Starts from `make_test_route` and attaches a `Passthrough` TLS config;
/// no certificate is needed since TLS is not terminated by the proxy.
pub fn make_tls_passthrough_route(
    port: u16,
    domain: Option<&str>,
    target_host: &str,
    target_port: u16,
) -> rustproxy_config::RouteConfig {
    let mut route = make_test_route(port, domain, target_host, target_port);
    route.action.tls = Some(rustproxy_config::RouteTls {
        mode: rustproxy_config::TlsMode::Passthrough,
        certificate: None,
        acme: None,
        versions: None,
        ciphers: None,
        honor_cipher_order: None,
        session_timeout: None,
    });
    route
}

View File

@@ -0,0 +1,453 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Send a raw HTTP request and return the full response as a string.
///
/// Uses `Connection: close` so `read_to_end` terminates at server close.
async fn send_http_request(port: u16, host: &str, method: &str, path: &str) -> String {
    let mut conn = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
        .await
        .unwrap();
    let raw = format!(
        "{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
        method, path, host,
    );
    conn.write_all(raw.as_bytes()).await.unwrap();
    let mut bytes = Vec::new();
    conn.read_to_end(&mut bytes).await.unwrap();
    String::from_utf8_lossy(&bytes).to_string()
}
/// Extract the body from a raw HTTP response string (everything after the
/// first `\r\n\r\n` header terminator); returns `""` if no terminator exists.
///
/// Uses `split_once` so a body that itself contains `\r\n\r\n` is returned
/// in full; the previous `split(..).nth(1)` form truncated the body at the
/// next occurrence of the delimiter.
fn extract_body(response: &str) -> &str {
    response
        .split_once("\r\n\r\n")
        .map(|(_headers, body)| body)
        .unwrap_or("")
}
/// Plain HTTP forwarding: a request through the proxy reaches the backend
/// with method, path, and backend identity intact.
#[tokio::test]
async fn test_http_forward_basic() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_http_echo_backend(backend_port, "main").await;
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let body = with_timeout(
        async {
            let raw = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
            extract_body(&raw).to_string()
        },
        10,
    )
    .await
    .unwrap();
    assert!(body.contains(r#""method":"GET"#), "Expected GET method, got: {}", body);
    assert!(body.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", body);
    assert!(body.contains(r#""backend":"main"#), "Expected main backend, got: {}", body);
    proxy.stop().await.unwrap();
}
/// Host-header routing: two routes on the same port dispatch to different
/// backends based on the request's Host header.
#[tokio::test]
async fn test_http_forward_host_routing() {
    let backend1_port = next_port();
    let backend2_port = next_port();
    let proxy_port = next_port();
    let _b1 = start_http_echo_backend(backend1_port, "alpha").await;
    let _b2 = start_http_echo_backend(backend2_port, "beta").await;
    let opts = RustProxyOptions {
        routes: vec![
            make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
            make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
        ],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    // Requests for the alpha host must land on the alpha backend.
    let alpha_body = with_timeout(
        async {
            let raw = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
            extract_body(&raw).to_string()
        },
        10,
    )
    .await
    .unwrap();
    assert!(alpha_body.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_body);
    // And the beta host must land on the beta backend.
    let beta_body = with_timeout(
        async {
            let raw = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
            extract_body(&raw).to_string()
        },
        10,
    )
    .await
    .unwrap();
    assert!(beta_body.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_body);
    proxy.stop().await.unwrap();
}
/// Path-based routing: a higher-priority route captures /api/** and every
/// other path falls through to the default route.
#[tokio::test]
async fn test_http_forward_path_routing() {
    let backend1_port = next_port();
    let backend2_port = next_port();
    let proxy_port = next_port();
    let _b1 = start_http_echo_backend(backend1_port, "api").await;
    let _b2 = start_http_echo_backend(backend2_port, "web").await;
    // API route carries a path glob and an explicit priority so it wins.
    let mut api_route = make_test_route(proxy_port, None, "127.0.0.1", backend1_port);
    api_route.route_match.path = Some("/api/**".to_string());
    api_route.priority = Some(10);
    let fallback_route = make_test_route(proxy_port, None, "127.0.0.1", backend2_port);
    let opts = RustProxyOptions {
        routes: vec![api_route, fallback_route],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    // A request under /api/ goes to the api backend.
    let api_body = with_timeout(
        async {
            let raw = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
            extract_body(&raw).to_string()
        },
        10,
    )
    .await
    .unwrap();
    assert!(api_body.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_body);
    // Any other path goes to the fallback web backend.
    let web_body = with_timeout(
        async {
            let raw = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
            extract_body(&raw).to_string()
        },
        10,
    )
    .await
    .unwrap();
    assert!(web_body.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_body);
    proxy.stop().await.unwrap();
}
/// A CORS preflight (OPTIONS with Origin and Access-Control-Request-Method
/// headers) should be answered with 204 No Content plus CORS headers.
#[tokio::test]
async fn test_http_forward_cors_preflight() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_http_echo_backend(backend_port, "main").await;
    let options = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(options).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let result = with_timeout(async {
        let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        // Send CORS preflight request. The request text is fixed, so a plain
        // &str literal replaces the previous zero-argument `format!` call
        // (clippy::useless_format).
        let request = "OPTIONS /api/data HTTP/1.1\r\nHost: example.com\r\nOrigin: http://localhost:3000\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n";
        stream.write_all(request.as_bytes()).await.unwrap();
        let mut response = Vec::new();
        stream.read_to_end(&mut response).await.unwrap();
        String::from_utf8_lossy(&response).to_string()
    }, 10)
    .await
    .unwrap();
    // Should get 204 No Content with CORS headers
    assert!(result.contains("204"), "Expected 204 status, got: {}", result);
    assert!(result.to_lowercase().contains("access-control-allow-origin"),
        "Expected CORS header, got: {}", result);
    proxy.stop().await.unwrap();
}
/// Backend error relay: a 500 emitted by the backend passes through the
/// proxy unchanged.
#[tokio::test]
async fn test_http_forward_backend_error() {
    let backend_port = next_port();
    let proxy_port = next_port();
    // Backend that always answers with status 500.
    let _backend = start_http_server(backend_port, 500, "Internal Error").await;
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let response = with_timeout(
        async { send_http_request(proxy_port, "example.com", "GET", "/fail").await },
        10,
    )
    .await
    .unwrap();
    // Proxy should relay the 500 from backend
    assert!(response.contains("500"), "Expected 500 status, got: {}", response);
    proxy.stop().await.unwrap();
}
/// Unmatched Host header: when no route matches, the proxy answers 502.
#[tokio::test]
async fn test_http_forward_no_route_matched() {
    let proxy_port = next_port();
    // Only requests for known.example.com have a route.
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let response = with_timeout(
        async { send_http_request(proxy_port, "unknown.example.com", "GET", "/").await },
        10,
    )
    .await
    .unwrap();
    // Should get 502 Bad Gateway (no route matched)
    assert!(response.contains("502"), "Expected 502 status, got: {}", response);
    proxy.stop().await.unwrap();
}
/// Dead backend: the proxy accepts the request but answers 502 because the
/// forward target is not reachable.
#[tokio::test]
async fn test_http_forward_backend_unavailable() {
    let proxy_port = next_port();
    let dead_port = next_port(); // No server running here
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let response = with_timeout(
        async { send_http_request(proxy_port, "example.com", "GET", "/").await },
        10,
    )
    .await
    .unwrap();
    // Should get 502 Bad Gateway (backend unavailable)
    assert!(response.contains("502"), "Expected 502 status, got: {}", response);
    proxy.stop().await.unwrap();
}
/// Full TLS-terminate flow: the proxy terminates TLS for the test domain,
/// forwards the decrypted HTTP request to a plain-HTTP backend, and relays
/// the backend's echo response back over the TLS session.
#[tokio::test]
async fn test_https_terminate_http_forward() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let domain = "httpproxy.example.com";
    // Self-signed cert for the test domain; the client below skips verification.
    let (cert_pem, key_pem) = generate_self_signed_cert(domain);
    let _backend = start_http_echo_backend(backend_port, "tls-backend").await;
    let options = RustProxyOptions {
        routes: vec![make_tls_terminate_route(
            proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
        )],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(options).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let result = with_timeout(async {
        // install_default fails if another test already installed a crypto
        // provider; that is harmless, so the result is ignored.
        let _ = rustls::crypto::ring::default_provider().install_default();
        // The server cert is self-signed, so certificate verification is
        // disabled via the test-only InsecureVerifier.
        let tls_config = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
        let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        // SNI must carry the route's domain so the proxy selects the cert.
        let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
        let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
        // Send HTTP request through TLS
        let request = format!(
            "GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
            domain
        );
        tls_stream.write_all(request.as_bytes()).await.unwrap();
        let mut response = Vec::new();
        tls_stream.read_to_end(&mut response).await.unwrap();
        String::from_utf8_lossy(&response).to_string()
    }, 10)
    .await
    .unwrap();
    // The echo backend reports method/path/name, proving the request made it
    // through TLS termination and HTTP forwarding.
    let body = extract_body(&result);
    assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
    assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
    assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
    proxy.stop().await.unwrap();
}
/// Verifies a WebSocket upgrade is proxied end-to-end: the 101 handshake
/// response comes back through the proxy and post-upgrade bytes flow both
/// ways to the echo backend.
#[tokio::test]
async fn test_websocket_through_proxy() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_ws_echo_backend(backend_port).await;
    let options = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(options).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let result = with_timeout(async {
        let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        // Send WebSocket upgrade request
        // (Sec-WebSocket-Key is the fixed sample nonce from RFC 6455.)
        let request = format!(
            "GET /ws HTTP/1.1\r\n\
            Host: example.com\r\n\
            Upgrade: websocket\r\n\
            Connection: Upgrade\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            Sec-WebSocket-Version: 13\r\n\
            \r\n"
        );
        stream.write_all(request.as_bytes()).await.unwrap();
        // Read the 101 response
        // One byte at a time until the \r\n\r\n header terminator, so no
        // bytes belonging to the post-upgrade stream are consumed.
        let mut response_buf = Vec::with_capacity(4096);
        let mut temp = [0u8; 1];
        loop {
            let n = stream.read(&mut temp).await.unwrap();
            if n == 0 { break; }
            response_buf.push(temp[0]);
            if response_buf.len() >= 4 {
                let len = response_buf.len();
                if response_buf[len-4..] == *b"\r\n\r\n" {
                    break;
                }
            }
        }
        let response_str = String::from_utf8_lossy(&response_buf).to_string();
        assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
        assert!(
            response_str.to_lowercase().contains("upgrade: websocket"),
            "Expected Upgrade header, got: {}",
            response_str
        );
        // After upgrade, send data and verify echo
        // (raw bytes, not WebSocket frames — the backend echoes bytes as-is).
        let test_data = b"Hello WebSocket!";
        stream.write_all(test_data).await.unwrap();
        // Read echoed data
        let mut echo_buf = vec![0u8; 256];
        let n = stream.read(&mut echo_buf).await.unwrap();
        let echoed = &echo_buf[..n];
        assert_eq!(echoed, test_data, "Expected echo of sent data");
        "ok".to_string()
    }, 10)
    .await
    .unwrap();
    assert_eq!(result, "ok");
    proxy.stop().await.unwrap();
}
/// InsecureVerifier for test TLS client connections.
///
/// Accepts ANY server certificate and handshake signature without checking
/// them, so tests can connect to proxies serving self-signed certs.
/// Test-only — never use outside tests.
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
    // Accept every certificate chain without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }
    // Accept any TLS 1.2 handshake signature unchecked.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Accept any TLS 1.3 handshake signature unchecked.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Advertise a fixed set of common signature schemes.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::RSA_PSS_SHA256,
        ]
    }
}

View File

@@ -0,0 +1,250 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Lifecycle smoke test: the configured port is bound only between start()
/// and stop().
#[tokio::test]
async fn test_start_and_stop() {
    let port = next_port();
    let route = make_test_route(port, None, "127.0.0.1", 8080);
    let opts = RustProxyOptions {
        routes: vec![route],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    // Nothing is bound before start() is called.
    assert!(!wait_for_port(port, 200).await);
    proxy.start().await.unwrap();
    assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
    proxy.stop().await.unwrap();
    // Give the OS a moment to release the port
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
}
/// Calling start() twice on the same proxy instance must be rejected.
#[tokio::test]
async fn test_double_start_fails() {
    let port = next_port();
    let route = make_test_route(port, None, "127.0.0.1", 8080);
    let opts = RustProxyOptions {
        routes: vec![route],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    // Second start should fail
    let second = proxy.start().await;
    assert!(second.is_err());
    assert!(second.unwrap_err().to_string().contains("already started"));
    proxy.stop().await.unwrap();
}
/// Routes can be swapped atomically while the proxy is running.
#[tokio::test]
async fn test_update_routes_hot_reload() {
    let port = next_port();
    let initial = make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080);
    let opts = RustProxyOptions {
        routes: vec![initial],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    // Swap in a replacement route set without restarting.
    let replacement = vec![make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090)];
    let outcome = proxy.update_routes(replacement).await;
    assert!(outcome.is_ok());
    proxy.stop().await.unwrap();
}
/// Ports can be opened and closed dynamically without affecting the
/// originally configured listener.
#[tokio::test]
async fn test_add_remove_listening_port() {
    let port1 = next_port();
    let port2 = next_port();
    let opts = RustProxyOptions {
        routes: vec![make_test_route(port1, None, "127.0.0.1", 8080)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(port1, 2000).await);
    // Dynamically open a second port.
    proxy.add_listening_port(port2).await.unwrap();
    assert!(wait_for_port(port2, 2000).await, "New port should be listening");
    // And close it again.
    proxy.remove_listening_port(port2).await.unwrap();
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
    // The first port must be unaffected by the add/remove cycle.
    assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
    proxy.stop().await.unwrap();
}
/// Statistics report the configured route count and listening ports.
#[tokio::test]
async fn test_get_statistics() {
    let port = next_port();
    let opts = RustProxyOptions {
        routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    let snapshot = proxy.get_statistics();
    assert_eq!(snapshot.routes_count, 1);
    assert!(snapshot.listening_ports.contains(&port));
    proxy.stop().await.unwrap();
}
/// A forward route with no targets is malformed and must be rejected at
/// construction time.
#[tokio::test]
async fn test_invalid_routes_rejected() {
    let mut bad_route = make_test_route(80, None, "127.0.0.1", 8080);
    bad_route.action.targets = None; // Invalid: forward without targets
    let opts = RustProxyOptions {
        routes: vec![bad_route],
        ..Default::default()
    };
    assert!(RustProxy::new(opts).is_err());
}
/// The connection counter starts at zero and increments after a client
/// connects through the proxy.
#[tokio::test]
async fn test_metrics_track_connections() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_echo_server(backend_port).await;
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    // No connections yet
    assert_eq!(proxy.get_statistics().total_connections, 0);
    // Open one connection, push some bytes through, then drop it.
    {
        let mut conn = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        conn.write_all(b"hello").await.unwrap();
        let mut scratch = vec![0u8; 16];
        let _ = conn.read(&mut scratch).await;
    }
    // Small delay for metrics to update
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    let stats = proxy.get_statistics();
    assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
    proxy.stop().await.unwrap();
}
/// A full HTTP exchange through the proxy is reflected in the connection
/// statistics.
#[tokio::test]
async fn test_metrics_track_bytes() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_http_echo_backend(backend_port, "metrics-test").await;
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    // Drive one complete HTTP request/response through the proxy.
    {
        let mut conn = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        conn.write_all(b"GET /test HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n")
            .await
            .unwrap();
        let mut reply = Vec::new();
        conn.read_to_end(&mut reply).await.unwrap();
        assert!(!reply.is_empty(), "Expected non-empty response");
    }
    // Small delay for metrics to update
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    let stats = proxy.get_statistics();
    assert!(stats.total_connections > 0,
        "Expected some connections tracked, got {}", stats.total_connections);
    proxy.stop().await.unwrap();
}
/// Hot-reloading routes onto a different port closes the old listener and
/// opens the new one.
#[tokio::test]
async fn test_hot_reload_port_changes() {
    let port1 = next_port();
    let port2 = next_port();
    let backend_port = next_port();
    let _backend = start_echo_server(backend_port).await;
    // Initially the proxy listens on port1 only.
    let opts = RustProxyOptions {
        routes: vec![make_test_route(port1, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(port1, 2000).await);
    assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
    // Reload with a route set referencing port2 only.
    proxy
        .update_routes(vec![make_test_route(port2, None, "127.0.0.1", backend_port)])
        .await
        .unwrap();
    assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
    // The reported listening set must mirror the reload.
    let ports = proxy.get_listening_ports();
    assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
    assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
    proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,197 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
/// Basic TCP forwarding: bytes written to the proxy come back from the echo
/// backend unchanged.
#[tokio::test]
async fn test_tcp_forward_echo() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_echo_server(backend_port).await;
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
    let echoed = with_timeout(
        async {
            let mut conn = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                .await
                .unwrap();
            conn.write_all(b"hello world").await.unwrap();
            let mut buf = vec![0u8; 1024];
            let n = conn.read(&mut buf).await.unwrap();
            String::from_utf8_lossy(&buf[..n]).to_string()
        },
        5,
    )
    .await
    .unwrap();
    assert_eq!(echoed, "hello world");
    proxy.stop().await.unwrap();
}
/// Large-payload forwarding: 1 MB sent through the proxy is echoed back in
/// full after the write side is half-closed.
#[tokio::test]
async fn test_tcp_forward_large_payload() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_echo_server(backend_port).await;
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let total = with_timeout(
        async {
            let mut conn = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                .await
                .unwrap();
            // Push 1 MB, half-close the write side, then drain the echo.
            let payload = vec![b'A'; 1_000_000];
            conn.write_all(&payload).await.unwrap();
            conn.shutdown().await.unwrap();
            let mut echoed = Vec::new();
            conn.read_to_end(&mut echoed).await.unwrap();
            echoed.len()
        },
        10,
    )
    .await
    .unwrap();
    assert_eq!(total, 1_000_000);
    proxy.stop().await.unwrap();
}
/// Ten concurrent clients each get their own message echoed back, proving
/// per-connection isolation through the proxy.
#[tokio::test]
async fn test_tcp_forward_multiple_connections() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_echo_server(backend_port).await;
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let replies = with_timeout(
        async {
            // Spawn ten parallel clients, each sending a unique message.
            let tasks: Vec<_> = (0..10)
                .map(|i| {
                    let port = proxy_port;
                    tokio::spawn(async move {
                        let mut conn = TcpStream::connect(format!("127.0.0.1:{}", port))
                            .await
                            .unwrap();
                        let msg = format!("connection-{}", i);
                        conn.write_all(msg.as_bytes()).await.unwrap();
                        let mut buf = vec![0u8; 1024];
                        let n = conn.read(&mut buf).await.unwrap();
                        String::from_utf8_lossy(&buf[..n]).to_string()
                    })
                })
                .collect();
            let mut collected = Vec::new();
            for task in tasks {
                collected.push(task.await.unwrap());
            }
            collected
        },
        10,
    )
    .await
    .unwrap();
    assert_eq!(replies.len(), 10);
    for (i, reply) in replies.iter().enumerate() {
        assert_eq!(reply, &format!("connection-{}", i));
    }
    proxy.stop().await.unwrap();
}
/// The proxy still accepts the client's TCP connection even when the
/// forward target is down (data simply won't flow).
#[tokio::test]
async fn test_tcp_forward_backend_unreachable() {
    let proxy_port = next_port();
    let dead_port = next_port(); // No server on this port
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let connected = with_timeout(
        async {
            TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                .await
                .is_ok()
        },
        5,
    )
    .await
    .unwrap();
    assert!(connected, "Should be able to connect to proxy even if backend is down");
    proxy.stop().await.unwrap();
}
/// Bidirectional flow: the prefix-echo backend proves bytes travel both
/// client→backend and backend→client through the proxy.
#[tokio::test]
async fn test_tcp_forward_bidirectional() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_prefix_echo_server(backend_port, "REPLY:").await;
    let opts = RustProxyOptions {
        routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let reply = with_timeout(
        async {
            let mut conn = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                .await
                .unwrap();
            conn.write_all(b"test data").await.unwrap();
            let mut buf = vec![0u8; 1024];
            let n = conn.read(&mut buf).await.unwrap();
            String::from_utf8_lossy(&buf[..n]).to_string()
        },
        5,
    )
    .await
    .unwrap();
    assert_eq!(reply, "REPLY:test data");
    proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,247 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
/// Build a minimal TLS ClientHello record carrying the given SNI domain.
/// Just complete enough for the proxy's SNI parser to extract the name.
fn build_client_hello(domain: &str) -> Vec<u8> {
    let name = domain.as_bytes();
    let name_len = name.len() as u16;
    // server_name extension: type(2) + data_len(2) + list_len(2) + entry
    // type(1) + name_len(2) + name bytes.
    let mut ext = vec![0x00, 0x00]; // extension type: server_name
    let data_len = name_len + 5;
    ext.extend_from_slice(&data_len.to_be_bytes()); // extension data length
    ext.extend_from_slice(&(data_len - 2).to_be_bytes()); // server name list length
    ext.push(0x00); // entry type: host_name
    ext.extend_from_slice(&name_len.to_be_bytes());
    ext.extend_from_slice(name);
    // ClientHello body.
    let mut body = Vec::new();
    body.extend_from_slice(&[0x03, 0x03]); // legacy_version: TLS 1.2
    body.extend_from_slice(&[0x00; 32]); // random
    body.push(0x00); // empty session_id
    body.extend_from_slice(&[0x00, 0x02, 0x00, 0xff]); // one cipher suite
    body.extend_from_slice(&[0x01, 0x00]); // null compression only
    body.extend_from_slice(&(ext.len() as u16).to_be_bytes());
    body.extend_from_slice(&ext);
    // Handshake header: msg type 1 (ClientHello) + 24-bit length.
    let mut handshake = vec![0x01];
    handshake.extend_from_slice(&(body.len() as u32).to_be_bytes()[1..]);
    handshake.extend_from_slice(&body);
    // TLS record header: content type 22 (handshake), record version TLS 1.0.
    let mut record = vec![0x16, 0x03, 0x01];
    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
    record.extend_from_slice(&handshake);
    record
}
/// SNI-based passthrough routing: the proxy peeks at the ClientHello's SNI
/// and streams the raw bytes to the backend registered for that domain,
/// without terminating TLS itself.
#[tokio::test]
async fn test_tls_passthrough_sni_routing() {
    let backend1_port = next_port();
    let backend2_port = next_port();
    let proxy_port = next_port();
    // Prefix echo servers identify which backend received the bytes.
    let _b1 = start_prefix_echo_server(backend1_port, "BACKEND1:").await;
    let _b2 = start_prefix_echo_server(backend2_port, "BACKEND2:").await;
    let options = RustProxyOptions {
        routes: vec![
            make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
            make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
        ],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(options).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    // Send a fake ClientHello with SNI "one.example.com"
    let result = with_timeout(async {
        let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        let hello = build_client_hello("one.example.com");
        stream.write_all(&hello).await.unwrap();
        let mut buf = vec![0u8; 4096];
        let n = stream.read(&mut buf).await.unwrap();
        String::from_utf8_lossy(&buf[..n]).to_string()
    }, 5)
    .await
    .unwrap();
    // Backend1 should have received the ClientHello and prefixed its response
    assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
    // Now test routing to backend2
    let result2 = with_timeout(async {
        let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        let hello = build_client_hello("two.example.com");
        stream.write_all(&hello).await.unwrap();
        let mut buf = vec![0u8; 4096];
        let n = stream.read(&mut buf).await.unwrap();
        String::from_utf8_lossy(&buf[..n]).to_string()
    }, 5)
    .await
    .unwrap();
    assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
    proxy.stop().await.unwrap();
}
/// A ClientHello naming a domain with no matching route must not reach any
/// backend; the proxy simply drops the connection.
#[tokio::test]
async fn test_tls_passthrough_unknown_sni() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_echo_server(backend_port).await;
    let opts = RustProxyOptions {
        routes: vec![
            make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
        ],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let dropped = with_timeout(
        async {
            let mut conn = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                .await
                .unwrap();
            conn.write_all(&build_client_hello("unknown.example.com"))
                .await
                .unwrap();
            let mut buf = vec![0u8; 4096];
            // Expect a clean close or a reset; any data means a route matched.
            match conn.read(&mut buf).await {
                Ok(0) | Err(_) => true,
                Ok(_) => false,
            }
        },
        5,
    )
    .await
    .unwrap();
    assert!(dropped, "Unknown SNI should result in dropped connection");
    proxy.stop().await.unwrap();
}
/// Wildcard SNI routing: any subdomain of example.com matches the
/// "*.example.com" passthrough route.
#[tokio::test]
async fn test_tls_passthrough_wildcard_domain() {
    let backend_port = next_port();
    let proxy_port = next_port();
    let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
    let opts = RustProxyOptions {
        routes: vec![
            make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
        ],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let reply = with_timeout(
        async {
            let mut conn = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                .await
                .unwrap();
            conn.write_all(&build_client_hello("anything.example.com"))
                .await
                .unwrap();
            let mut buf = vec![0u8; 4096];
            let n = conn.read(&mut buf).await.unwrap();
            String::from_utf8_lossy(&buf[..n]).to_string()
        },
        5,
    )
    .await
    .unwrap();
    assert!(reply.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", reply);
    proxy.stop().await.unwrap();
}
/// Three SNI domains on one port each route to their own backend.
#[tokio::test]
async fn test_tls_passthrough_multiple_domains() {
    let b1_port = next_port();
    let b2_port = next_port();
    let b3_port = next_port();
    let proxy_port = next_port();
    let _b1 = start_prefix_echo_server(b1_port, "B1:").await;
    let _b2 = start_prefix_echo_server(b2_port, "B2:").await;
    let _b3 = start_prefix_echo_server(b3_port, "B3:").await;
    let opts = RustProxyOptions {
        routes: vec![
            make_tls_passthrough_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", b1_port),
            make_tls_passthrough_route(proxy_port, Some("beta.example.com"), "127.0.0.1", b2_port),
            make_tls_passthrough_route(proxy_port, Some("gamma.example.com"), "127.0.0.1", b3_port),
        ],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(opts).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let cases = [
        ("alpha.example.com", "B1:"),
        ("beta.example.com", "B2:"),
        ("gamma.example.com", "B3:"),
    ];
    for (domain, expected_prefix) in cases {
        let reply = with_timeout(
            async {
                let mut conn = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                    .await
                    .unwrap();
                conn.write_all(&build_client_hello(domain)).await.unwrap();
                let mut buf = vec![0u8; 4096];
                let n = conn.read(&mut buf).await.unwrap();
                String::from_utf8_lossy(&buf[..n]).to_string()
            },
            5,
        )
        .await
        .unwrap();
        assert!(
            reply.starts_with(expected_prefix),
            "Domain {} should route to {}, got: {}",
            domain, expected_prefix, reply
        );
    }
    proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,324 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Build a TLS client configuration that accepts any server certificate,
/// for connecting to test proxies that present self-signed certs.
fn make_insecure_tls_client_config() -> Arc<rustls::ClientConfig> {
    // Installing the ring provider may have been done by another test
    // already; the result is deliberately ignored.
    let _ = rustls::crypto::ring::default_provider().install_default();
    Arc::new(
        rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(InsecureVerifier))
            .with_no_client_auth(),
    )
}
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
]
}
}
#[tokio::test]
async fn test_tls_terminate_basic() {
    // The proxy terminates TLS on the front side and forwards plaintext to
    // the backend; the echoed bytes must round-trip intact.
    let backend_port = next_port();
    let proxy_port = next_port();
    let domain = "test.example.com";
    let (cert_pem, key_pem) = generate_self_signed_cert(domain);
    // Plain TCP echo backend — it sees the decrypted stream.
    let _backend = start_echo_server(backend_port).await;
    let route = make_tls_terminate_route(
        proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
    );
    let options = RustProxyOptions {
        routes: vec![route],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(options).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    // Handshake as a TLS client against the proxy and exchange one message.
    let echoed = with_timeout(async {
        let connector = tokio_rustls::TlsConnector::from(make_insecure_tls_client_config());
        let tcp = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        let sni = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
        let mut session = connector.connect(sni, tcp).await.unwrap();
        session.write_all(b"hello TLS").await.unwrap();
        let mut reply = vec![0u8; 1024];
        let len = session.read(&mut reply).await.unwrap();
        String::from_utf8_lossy(&reply[..len]).to_string()
    }, 10)
    .await
    .unwrap();
    assert_eq!(echoed, "hello TLS");
    proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_and_reencrypt() {
    // Client TLS is terminated at the proxy, then a fresh TLS session is
    // opened to the backend (terminate-and-reencrypt mode).
    let backend_port = next_port();
    let proxy_port = next_port();
    let domain = "reencrypt.example.com";
    let backend_domain = "backend.internal";
    let (proxy_cert, proxy_key) = generate_self_signed_cert(domain);
    let (backend_cert, backend_key) = generate_self_signed_cert(backend_domain);
    // The backend itself speaks TLS, so the proxy must re-encrypt.
    let _backend = start_tls_echo_server(backend_port, &backend_cert, &backend_key).await;
    // Start from a plain terminate route, then flip the mode.
    let mut route = make_tls_terminate_route(
        proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
    );
    route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
    let mut proxy = RustProxy::new(RustProxyOptions {
        routes: vec![route],
        ..Default::default()
    })
    .unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let echoed = with_timeout(async {
        let connector = tokio_rustls::TlsConnector::from(make_insecure_tls_client_config());
        let tcp = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        let sni = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
        let mut session = connector.connect(sni, tcp).await.unwrap();
        session.write_all(b"hello reencrypt").await.unwrap();
        let mut reply = vec![0u8; 1024];
        let len = session.read(&mut reply).await.unwrap();
        String::from_utf8_lossy(&reply[..len]).to_string()
    }, 10)
    .await
    .unwrap();
    assert_eq!(echoed, "hello reencrypt");
    proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_sni_cert_selection() {
    // Two terminate routes share one port and differ only by SNI name; each
    // connection must be served with the matching certificate and routed to
    // the matching backend.
    let backend1_port = next_port();
    let backend2_port = next_port();
    let proxy_port = next_port();
    let (cert1, key1) = generate_self_signed_cert("alpha.example.com");
    let (cert2, key2) = generate_self_signed_cert("beta.example.com");
    let _b1 = start_prefix_echo_server(backend1_port, "ALPHA:").await;
    let _b2 = start_prefix_echo_server(backend2_port, "BETA:").await;
    let options = RustProxyOptions {
        routes: vec![
            make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
            make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
        ],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(options).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    // Exercise BOTH domains. The test previously only connected with the
    // alpha SNI, leaving the beta route/backend set up but never verified —
    // a routing bug sending everything to backend1 would have passed.
    for (domain, expected_prefix) in [
        ("alpha.example.com", "ALPHA:"),
        ("beta.example.com", "BETA:"),
    ] {
        let result = with_timeout(async {
            let tls_config = make_insecure_tls_client_config();
            let connector = tokio_rustls::TlsConnector::from(tls_config);
            let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                .await
                .unwrap();
            let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
            let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
            tls_stream.write_all(b"test").await.unwrap();
            let mut buf = vec![0u8; 1024];
            let n = tls_stream.read(&mut buf).await.unwrap();
            String::from_utf8_lossy(&buf[..n]).to_string()
        }, 10)
        .await
        .unwrap();
        assert!(
            result.starts_with(expected_prefix),
            "Domain {} should route to {}, got: {}",
            domain, expected_prefix, result
        );
    }
    proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_large_payload() {
    // A 1 MB payload must survive TLS termination plus plaintext forwarding
    // without truncation.
    let backend_port = next_port();
    let proxy_port = next_port();
    let domain = "large.example.com";
    let (cert_pem, key_pem) = generate_self_signed_cert(domain);
    let _backend = start_echo_server(backend_port).await;
    let options = RustProxyOptions {
        routes: vec![make_tls_terminate_route(
            proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
        )],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(options).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let echoed_len = with_timeout(async {
        let connector = tokio_rustls::TlsConnector::from(make_insecure_tls_client_config());
        let tcp = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
            .await
            .unwrap();
        let sni = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
        let mut session = connector.connect(sni, tcp).await.unwrap();
        // Send 1 MB, then signal we are done writing so the echoed bytes can
        // be drained all the way to EOF.
        let payload = vec![b'X'; 1_000_000];
        session.write_all(&payload).await.unwrap();
        session.shutdown().await.unwrap();
        let mut echoed = Vec::new();
        session.read_to_end(&mut echoed).await.unwrap();
        echoed.len()
    }, 15)
    .await
    .unwrap();
    assert_eq!(echoed_len, 1_000_000);
    proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_concurrent() {
    // Ten simultaneous TLS clients; each must get exactly its own message
    // echoed back, proving sessions are not cross-wired under concurrency.
    let backend_port = next_port();
    let proxy_port = next_port();
    let domain = "concurrent.example.com";
    let (cert_pem, key_pem) = generate_self_signed_cert(domain);
    let _backend = start_echo_server(backend_port).await;
    let options = RustProxyOptions {
        routes: vec![make_tls_terminate_route(
            proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
        )],
        ..Default::default()
    };
    let mut proxy = RustProxy::new(options).unwrap();
    proxy.start().await.unwrap();
    assert!(wait_for_port(proxy_port, 2000).await);
    let echoes = with_timeout(async {
        // Spawn all clients first so they actually run concurrently, then
        // join them in order.
        let tasks: Vec<_> = (0..10)
            .map(|i| {
                let dom = domain.to_string();
                tokio::spawn(async move {
                    let connector =
                        tokio_rustls::TlsConnector::from(make_insecure_tls_client_config());
                    let tcp = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
                        .await
                        .unwrap();
                    let sni = rustls::pki_types::ServerName::try_from(dom).unwrap();
                    let mut session = connector.connect(sni, tcp).await.unwrap();
                    let payload = format!("conn-{}", i);
                    session.write_all(payload.as_bytes()).await.unwrap();
                    let mut reply = vec![0u8; 1024];
                    let len = session.read(&mut reply).await.unwrap();
                    String::from_utf8_lossy(&reply[..len]).to_string()
                })
            })
            .collect();
        let mut collected = Vec::new();
        for task in tasks {
            collected.push(task.await.unwrap());
        }
        collected
    }, 15)
    .await
    .unwrap();
    assert_eq!(echoes.len(), 10);
    for (i, echoed) in echoes.iter().enumerate() {
        assert_eq!(echoed, &format!("conn-{}", i));
    }
    proxy.stop().await.unwrap();
}