feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates
This commit is contained in:
1760
rust/Cargo.lock
generated
Normal file
1760
rust/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
98
rust/Cargo.toml
Normal file
98
rust/Cargo.toml
Normal file
@@ -0,0 +1,98 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/rustproxy",
|
||||
"crates/rustproxy-config",
|
||||
"crates/rustproxy-routing",
|
||||
"crates/rustproxy-tls",
|
||||
"crates/rustproxy-passthrough",
|
||||
"crates/rustproxy-http",
|
||||
"crates/rustproxy-nftables",
|
||||
"crates/rustproxy-metrics",
|
||||
"crates/rustproxy-security",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
authors = ["Lossless GmbH <hello@lossless.com>"]
|
||||
|
||||
[workspace.dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
# HTTP proxy engine (hyper-based)
|
||||
hyper = { version = "1", features = ["http1", "http2", "server", "client"] }
|
||||
hyper-util = { version = "0.1", features = ["tokio", "http1", "http2", "client-legacy", "server-auto"] }
|
||||
http-body-util = "0.1"
|
||||
bytes = "1"
|
||||
|
||||
# ACME / Let's Encrypt
|
||||
instant-acme = { version = "0.7", features = ["hyper-rustls"] }
|
||||
|
||||
# TLS for passthrough SNI
|
||||
rustls = { version = "0.23", features = ["ring"] }
|
||||
tokio-rustls = "0.26"
|
||||
rustls-pemfile = "2"
|
||||
|
||||
# Self-signed cert generation for tests
|
||||
rcgen = "0.13"
|
||||
|
||||
# Temp directories for tests
|
||||
tempfile = "3"
|
||||
|
||||
# Lock-free atomics
|
||||
arc-swap = "1"
|
||||
|
||||
# Concurrent maps
|
||||
dashmap = "6"
|
||||
|
||||
# Domain wildcard matching
|
||||
glob-match = "0.2"
|
||||
|
||||
# IP/CIDR parsing
|
||||
ipnet = "2"
|
||||
|
||||
# JWT authentication
|
||||
jsonwebtoken = "9"
|
||||
|
||||
# Structured logging
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
# Error handling
|
||||
thiserror = "2"
|
||||
anyhow = "1"
|
||||
|
||||
# CLI
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
|
||||
# Regex for URL rewriting
|
||||
regex = "1"
|
||||
|
||||
# Base64 for basic auth
|
||||
base64 = "0.22"
|
||||
|
||||
# Cancellation / utility
|
||||
tokio-util = "0.7"
|
||||
|
||||
# Async traits
|
||||
async-trait = "0.1"
|
||||
|
||||
# libc for uid checks
|
||||
libc = "0.2"
|
||||
|
||||
# Internal crates
|
||||
rustproxy-config = { path = "crates/rustproxy-config" }
|
||||
rustproxy-routing = { path = "crates/rustproxy-routing" }
|
||||
rustproxy-tls = { path = "crates/rustproxy-tls" }
|
||||
rustproxy-passthrough = { path = "crates/rustproxy-passthrough" }
|
||||
rustproxy-http = { path = "crates/rustproxy-http" }
|
||||
rustproxy-nftables = { path = "crates/rustproxy-nftables" }
|
||||
rustproxy-metrics = { path = "crates/rustproxy-metrics" }
|
||||
rustproxy-security = { path = "crates/rustproxy-security" }
|
||||
145
rust/config/example.json
Normal file
145
rust/config/example.json
Normal file
@@ -0,0 +1,145 @@
|
||||
{
|
||||
"routes": [
|
||||
{
|
||||
"id": "https-passthrough",
|
||||
"name": "HTTPS Passthrough to Backend",
|
||||
"match": {
|
||||
"ports": 443,
|
||||
"domains": "backend.example.com"
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [
|
||||
{
|
||||
"host": "10.0.0.1",
|
||||
"port": 443
|
||||
}
|
||||
],
|
||||
"tls": {
|
||||
"mode": "passthrough"
|
||||
}
|
||||
},
|
||||
"priority": 10,
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "https-terminate",
|
||||
"name": "HTTPS Terminate for API",
|
||||
"match": {
|
||||
"ports": 443,
|
||||
"domains": "api.example.com"
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [
|
||||
{
|
||||
"host": "localhost",
|
||||
"port": 8080
|
||||
}
|
||||
],
|
||||
"tls": {
|
||||
"mode": "terminate",
|
||||
"certificate": "auto"
|
||||
}
|
||||
},
|
||||
"priority": 20,
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "http-redirect",
|
||||
"name": "HTTP to HTTPS Redirect",
|
||||
"match": {
|
||||
"ports": 80,
|
||||
"domains": ["api.example.com", "www.example.com"]
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [
|
||||
{
|
||||
"host": "localhost",
|
||||
"port": 8080
|
||||
}
|
||||
]
|
||||
},
|
||||
"priority": 0
|
||||
},
|
||||
{
|
||||
"id": "load-balanced",
|
||||
"name": "Load Balanced Backend",
|
||||
"match": {
|
||||
"ports": 443,
|
||||
"domains": "*.example.com"
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [
|
||||
{
|
||||
"host": "backend1.internal",
|
||||
"port": 8080
|
||||
},
|
||||
{
|
||||
"host": "backend2.internal",
|
||||
"port": 8080
|
||||
},
|
||||
{
|
||||
"host": "backend3.internal",
|
||||
"port": 8080
|
||||
}
|
||||
],
|
||||
"tls": {
|
||||
"mode": "terminate",
|
||||
"certificate": "auto"
|
||||
},
|
||||
"loadBalancing": {
|
||||
"algorithm": "round-robin",
|
||||
"healthCheck": {
|
||||
"path": "/health",
|
||||
"interval": 30,
|
||||
"timeout": 5,
|
||||
"unhealthyThreshold": 3,
|
||||
"healthyThreshold": 2
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": {
|
||||
"ipAllowList": ["10.0.0.0/8", "192.168.0.0/16"],
|
||||
"maxConnections": 1000,
|
||||
"rateLimit": {
|
||||
"enabled": true,
|
||||
"maxRequests": 100,
|
||||
"window": 60,
|
||||
"keyBy": "ip"
|
||||
}
|
||||
},
|
||||
"headers": {
|
||||
"request": {
|
||||
"X-Forwarded-For": "{clientIp}",
|
||||
"X-Real-IP": "{clientIp}"
|
||||
},
|
||||
"response": {
|
||||
"X-Powered-By": "RustProxy"
|
||||
},
|
||||
"cors": {
|
||||
"enabled": true,
|
||||
"allowOrigin": "*",
|
||||
"allowMethods": "GET,POST,PUT,DELETE,OPTIONS",
|
||||
"allowHeaders": "Content-Type,Authorization",
|
||||
"allowCredentials": false,
|
||||
"maxAge": 86400
|
||||
}
|
||||
},
|
||||
"priority": 5
|
||||
}
|
||||
],
|
||||
"acme": {
|
||||
"email": "admin@example.com",
|
||||
"useProduction": false,
|
||||
"port": 80
|
||||
},
|
||||
"connectionTimeout": 30000,
|
||||
"socketTimeout": 3600000,
|
||||
"maxConnectionsPerIp": 100,
|
||||
"connectionRateLimitPerMinute": 300,
|
||||
"keepAliveTreatment": "extended",
|
||||
"enableDetailedLogging": false
|
||||
}
|
||||
13
rust/crates/rustproxy-config/Cargo.toml
Normal file
13
rust/crates/rustproxy-config/Cargo.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[package]
|
||||
name = "rustproxy-config"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Configuration types for RustProxy, compatible with SmartProxy JSON schema"
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
334
rust/crates/rustproxy-config/src/helpers.rs
Normal file
334
rust/crates/rustproxy-config/src/helpers.rs
Normal file
@@ -0,0 +1,334 @@
|
||||
use crate::route_types::*;
|
||||
use crate::tls_types::*;
|
||||
|
||||
/// Create a simple HTTP forwarding route.
|
||||
/// Equivalent to SmartProxy's `createHttpRoute()`.
|
||||
pub fn create_http_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(80),
|
||||
domains: Some(domains.into()),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(vec![RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single(target_host.into()),
|
||||
port: PortSpec::Fixed(target_port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an HTTPS termination route.
|
||||
/// Equivalent to SmartProxy's `createHttpsTerminateRoute()`.
|
||||
pub fn create_https_terminate_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
let mut route = create_http_route(domains, target_host, target_port);
|
||||
route.route_match.ports = PortRange::Single(443);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Terminate,
|
||||
certificate: Some(CertificateSpec::Auto("auto".to_string())),
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
/// Create a TLS passthrough route.
|
||||
/// Equivalent to SmartProxy's `createHttpsPassthroughRoute()`.
|
||||
pub fn create_https_passthrough_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
let mut route = create_http_route(domains, target_host, target_port);
|
||||
route.route_match.ports = PortRange::Single(443);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
/// Create an HTTP-to-HTTPS redirect route.
|
||||
/// Equivalent to SmartProxy's `createHttpToHttpsRedirect()`.
|
||||
pub fn create_http_to_https_redirect(
|
||||
domains: impl Into<DomainSpec>,
|
||||
) -> RouteConfig {
|
||||
let domains = domains.into();
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(80),
|
||||
domains: Some(domains),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: None,
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: Some(RouteAdvanced {
|
||||
timeout: None,
|
||||
headers: None,
|
||||
keep_alive: None,
|
||||
static_files: None,
|
||||
test_response: Some(RouteTestResponse {
|
||||
status: 301,
|
||||
headers: {
|
||||
let mut h = std::collections::HashMap::new();
|
||||
h.insert("Location".to_string(), "https://{domain}{path}".to_string());
|
||||
h
|
||||
},
|
||||
body: String::new(),
|
||||
}),
|
||||
url_rewrite: None,
|
||||
}),
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: Some("HTTP to HTTPS Redirect".to_string()),
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a complete HTTPS server with HTTP redirect.
|
||||
/// Equivalent to SmartProxy's `createCompleteHttpsServer()`.
|
||||
pub fn create_complete_https_server(
|
||||
domain: impl Into<String>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> Vec<RouteConfig> {
|
||||
let domain = domain.into();
|
||||
let target_host = target_host.into();
|
||||
|
||||
vec![
|
||||
create_http_to_https_redirect(DomainSpec::Single(domain.clone())),
|
||||
create_https_terminate_route(
|
||||
DomainSpec::Single(domain),
|
||||
target_host,
|
||||
target_port,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
/// Create a load balancer route.
|
||||
/// Equivalent to SmartProxy's `createLoadBalancerRoute()`.
|
||||
pub fn create_load_balancer_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
targets: Vec<(String, u16)>,
|
||||
tls: Option<RouteTls>,
|
||||
) -> RouteConfig {
|
||||
let route_targets: Vec<RouteTarget> = targets
|
||||
.into_iter()
|
||||
.map(|(host, port)| RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single(host),
|
||||
port: PortSpec::Fixed(port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let port = if tls.is_some() { 443 } else { 80 };
|
||||
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(port),
|
||||
domains: Some(domains.into()),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(route_targets),
|
||||
tls,
|
||||
websocket: None,
|
||||
load_balancing: Some(RouteLoadBalancing {
|
||||
algorithm: LoadBalancingAlgorithm::RoundRobin,
|
||||
health_check: None,
|
||||
}),
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: Some("Load Balancer".to_string()),
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience conversions for DomainSpec
|
||||
impl From<&str> for DomainSpec {
|
||||
fn from(s: &str) -> Self {
|
||||
DomainSpec::Single(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for DomainSpec {
|
||||
fn from(s: String) -> Self {
|
||||
DomainSpec::Single(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<String>> for DomainSpec {
|
||||
fn from(v: Vec<String>) -> Self {
|
||||
DomainSpec::List(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<&str>> for DomainSpec {
|
||||
fn from(v: Vec<&str>) -> Self {
|
||||
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tls_types::TlsMode;
|
||||
|
||||
#[test]
|
||||
fn test_create_http_route() {
|
||||
let route = create_http_route("example.com", "localhost", 8080);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
let domains = route.route_match.domains.as_ref().unwrap().to_vec();
|
||||
assert_eq!(domains, vec!["example.com"]);
|
||||
let target = &route.action.targets.as_ref().unwrap()[0];
|
||||
assert_eq!(target.host.first(), "localhost");
|
||||
assert_eq!(target.port.resolve(80), 8080);
|
||||
assert!(route.action.tls.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_https_terminate_route() {
|
||||
let route = create_https_terminate_route("api.example.com", "backend", 3000);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
|
||||
let tls = route.action.tls.as_ref().unwrap();
|
||||
assert_eq!(tls.mode, TlsMode::Terminate);
|
||||
assert!(tls.certificate.as_ref().unwrap().is_auto());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_https_passthrough_route() {
|
||||
let route = create_https_passthrough_route("secure.example.com", "backend", 443);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
|
||||
let tls = route.action.tls.as_ref().unwrap();
|
||||
assert_eq!(tls.mode, TlsMode::Passthrough);
|
||||
assert!(tls.certificate.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_http_to_https_redirect() {
|
||||
let route = create_http_to_https_redirect("example.com");
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
assert!(route.action.targets.is_none());
|
||||
let test_response = route.action.advanced.as_ref().unwrap().test_response.as_ref().unwrap();
|
||||
assert_eq!(test_response.status, 301);
|
||||
assert!(test_response.headers.contains_key("Location"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_complete_https_server() {
|
||||
let routes = create_complete_https_server("example.com", "backend", 8080);
|
||||
assert_eq!(routes.len(), 2);
|
||||
// First route is HTTP redirect
|
||||
assert_eq!(routes[0].route_match.ports.to_ports(), vec![80]);
|
||||
// Second route is HTTPS terminate
|
||||
assert_eq!(routes[1].route_match.ports.to_ports(), vec![443]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_load_balancer_route() {
|
||||
let targets = vec![
|
||||
("backend1".to_string(), 8080),
|
||||
("backend2".to_string(), 8080),
|
||||
("backend3".to_string(), 8080),
|
||||
];
|
||||
let route = create_load_balancer_route("*.example.com", targets, None);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
assert_eq!(route.action.targets.as_ref().unwrap().len(), 3);
|
||||
let lb = route.action.load_balancing.as_ref().unwrap();
|
||||
assert_eq!(lb.algorithm, LoadBalancingAlgorithm::RoundRobin);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_spec_from_str() {
|
||||
let spec: DomainSpec = "example.com".into();
|
||||
assert_eq!(spec.to_vec(), vec!["example.com"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_spec_from_vec() {
|
||||
let spec: DomainSpec = vec!["a.com", "b.com"].into();
|
||||
assert_eq!(spec.to_vec(), vec!["a.com", "b.com"]);
|
||||
}
|
||||
}
|
||||
19
rust/crates/rustproxy-config/src/lib.rs
Normal file
19
rust/crates/rustproxy-config/src/lib.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
//! # rustproxy-config
|
||||
//!
|
||||
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
|
||||
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
|
||||
|
||||
pub mod route_types;
|
||||
pub mod proxy_options;
|
||||
pub mod tls_types;
|
||||
pub mod security_types;
|
||||
pub mod validation;
|
||||
pub mod helpers;
|
||||
|
||||
// Re-export all primary types
|
||||
pub use route_types::*;
|
||||
pub use proxy_options::*;
|
||||
pub use tls_types::*;
|
||||
pub use security_types::*;
|
||||
pub use validation::*;
|
||||
pub use helpers::*;
|
||||
439
rust/crates/rustproxy-config/src/proxy_options.rs
Normal file
439
rust/crates/rustproxy-config/src/proxy_options.rs
Normal file
@@ -0,0 +1,439 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::route_types::RouteConfig;
|
||||
|
||||
/// Global ACME configuration options.
|
||||
/// Matches TypeScript: `IAcmeOptions`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AcmeOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enabled: Option<bool>,
|
||||
/// Required when any route uses certificate: 'auto'
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub email: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub environment: Option<AcmeEnvironment>,
|
||||
/// Alias for email
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub account_email: Option<String>,
|
||||
/// Port for HTTP-01 challenges (default: 80)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub port: Option<u16>,
|
||||
/// Use Let's Encrypt production (default: false)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_production: Option<bool>,
|
||||
/// Days before expiry to renew (default: 30)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub renew_threshold_days: Option<u32>,
|
||||
/// Enable automatic renewal (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub auto_renew: Option<bool>,
|
||||
/// Directory to store certificates (default: './certs')
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub certificate_store: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub skip_configured_certs: Option<bool>,
|
||||
/// How often to check for renewals (default: 24)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub renew_check_interval_hours: Option<u32>,
|
||||
}
|
||||
|
||||
/// ACME environment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AcmeEnvironment {
|
||||
Production,
|
||||
Staging,
|
||||
}
|
||||
|
||||
/// Default target configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DefaultTarget {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
/// Default security configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DefaultSecurity {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connections: Option<u64>,
|
||||
}
|
||||
|
||||
/// Default configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DefaultConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub target: Option<DefaultTarget>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub security: Option<DefaultSecurity>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preserve_source_ip: Option<bool>,
|
||||
}
|
||||
|
||||
/// Keep-alive treatment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum KeepAliveTreatment {
|
||||
Standard,
|
||||
Extended,
|
||||
Immortal,
|
||||
}
|
||||
|
||||
/// Metrics configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct MetricsConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enabled: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sample_interval_ms: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub retention_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
/// RustProxy configuration options.
|
||||
/// Matches TypeScript: `ISmartProxyOptions`
|
||||
///
|
||||
/// This is the top-level configuration that can be loaded from a JSON file
|
||||
/// or constructed programmatically.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RustProxyOptions {
|
||||
/// The unified configuration array (required)
|
||||
pub routes: Vec<RouteConfig>,
|
||||
|
||||
/// Preserve client IP when forwarding
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preserve_source_ip: Option<bool>,
|
||||
|
||||
/// List of trusted proxy IPs that can send PROXY protocol
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub proxy_ips: Option<Vec<String>>,
|
||||
|
||||
/// Global option to accept PROXY protocol
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub accept_proxy_protocol: Option<bool>,
|
||||
|
||||
/// Global option to send PROXY protocol to all targets
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub send_proxy_protocol: Option<bool>,
|
||||
|
||||
/// Global/default settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub defaults: Option<DefaultConfig>,
|
||||
|
||||
// ─── Timeout Settings ────────────────────────────────────────────
|
||||
|
||||
/// Timeout for establishing connection to backend (ms), default: 30000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub connection_timeout: Option<u64>,
|
||||
|
||||
/// Timeout for initial data/SNI (ms), default: 60000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub initial_data_timeout: Option<u64>,
|
||||
|
||||
/// Socket inactivity timeout (ms), default: 3600000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub socket_timeout: Option<u64>,
|
||||
|
||||
/// How often to check for inactive connections (ms), default: 60000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub inactivity_check_interval: Option<u64>,
|
||||
|
||||
/// Default max connection lifetime (ms), default: 86400000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connection_lifetime: Option<u64>,
|
||||
|
||||
/// Inactivity timeout (ms), default: 14400000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub inactivity_timeout: Option<u64>,
|
||||
|
||||
/// Maximum time to wait for connections to close during shutdown (ms)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub graceful_shutdown_timeout: Option<u64>,
|
||||
|
||||
// ─── Socket Optimization ─────────────────────────────────────────
|
||||
|
||||
/// Disable Nagle's algorithm (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub no_delay: Option<bool>,
|
||||
|
||||
/// Enable TCP keepalive (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive: Option<bool>,
|
||||
|
||||
/// Initial delay before sending keepalive probes (ms)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive_initial_delay: Option<u64>,
|
||||
|
||||
/// Maximum bytes to buffer during connection setup
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_pending_data_size: Option<u64>,
|
||||
|
||||
// ─── Enhanced Features ───────────────────────────────────────────
|
||||
|
||||
/// Disable inactivity checking entirely
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub disable_inactivity_check: Option<bool>,
|
||||
|
||||
/// Enable TCP keep-alive probes
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_keep_alive_probes: Option<bool>,
|
||||
|
||||
/// Enable detailed connection logging
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_detailed_logging: Option<bool>,
|
||||
|
||||
/// Enable TLS handshake debug logging
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_tls_debug_logging: Option<bool>,
|
||||
|
||||
/// Randomize timeouts to prevent thundering herd
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_randomized_timeouts: Option<bool>,
|
||||
|
||||
// ─── Rate Limiting ───────────────────────────────────────────────
|
||||
|
||||
/// Maximum simultaneous connections from a single IP
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connections_per_ip: Option<u64>,
|
||||
|
||||
/// Max new connections per minute from a single IP
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub connection_rate_limit_per_minute: Option<u64>,
|
||||
|
||||
// ─── Keep-Alive Settings ─────────────────────────────────────────
|
||||
|
||||
/// How to treat keep-alive connections
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive_treatment: Option<KeepAliveTreatment>,
|
||||
|
||||
/// Multiplier for inactivity timeout for keep-alive connections
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive_inactivity_multiplier: Option<f64>,
|
||||
|
||||
/// Extended lifetime for keep-alive connections (ms)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub extended_keep_alive_lifetime: Option<u64>,
|
||||
|
||||
// ─── HttpProxy Integration ───────────────────────────────────────
|
||||
|
||||
/// Array of ports to forward to HttpProxy
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_http_proxy: Option<Vec<u16>>,
|
||||
|
||||
/// Port where HttpProxy is listening (default: 8443)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub http_proxy_port: Option<u16>,
|
||||
|
||||
// ─── Metrics ─────────────────────────────────────────────────────
|
||||
|
||||
/// Metrics configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metrics: Option<MetricsConfig>,
|
||||
|
||||
// ─── ACME ────────────────────────────────────────────────────────
|
||||
|
||||
/// Global ACME configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub acme: Option<AcmeOptions>,
|
||||
}
|
||||
|
||||
impl Default for RustProxyOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
routes: Vec::new(),
|
||||
preserve_source_ip: None,
|
||||
proxy_ips: None,
|
||||
accept_proxy_protocol: None,
|
||||
send_proxy_protocol: None,
|
||||
defaults: None,
|
||||
connection_timeout: None,
|
||||
initial_data_timeout: None,
|
||||
socket_timeout: None,
|
||||
inactivity_check_interval: None,
|
||||
max_connection_lifetime: None,
|
||||
inactivity_timeout: None,
|
||||
graceful_shutdown_timeout: None,
|
||||
no_delay: None,
|
||||
keep_alive: None,
|
||||
keep_alive_initial_delay: None,
|
||||
max_pending_data_size: None,
|
||||
disable_inactivity_check: None,
|
||||
enable_keep_alive_probes: None,
|
||||
enable_detailed_logging: None,
|
||||
enable_tls_debug_logging: None,
|
||||
enable_randomized_timeouts: None,
|
||||
max_connections_per_ip: None,
|
||||
connection_rate_limit_per_minute: None,
|
||||
keep_alive_treatment: None,
|
||||
keep_alive_inactivity_multiplier: None,
|
||||
extended_keep_alive_lifetime: None,
|
||||
use_http_proxy: None,
|
||||
http_proxy_port: None,
|
||||
metrics: None,
|
||||
acme: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RustProxyOptions {
|
||||
/// Load configuration from a JSON file.
|
||||
pub fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let options: Self = serde_json::from_str(&content)?;
|
||||
Ok(options)
|
||||
}
|
||||
|
||||
/// Get the effective connection timeout in milliseconds.
|
||||
pub fn effective_connection_timeout(&self) -> u64 {
|
||||
self.connection_timeout.unwrap_or(30_000)
|
||||
}
|
||||
|
||||
/// Get the effective initial data timeout in milliseconds.
|
||||
pub fn effective_initial_data_timeout(&self) -> u64 {
|
||||
self.initial_data_timeout.unwrap_or(60_000)
|
||||
}
|
||||
|
||||
/// Get the effective socket timeout in milliseconds.
|
||||
pub fn effective_socket_timeout(&self) -> u64 {
|
||||
self.socket_timeout.unwrap_or(3_600_000)
|
||||
}
|
||||
|
||||
/// Get the effective max connection lifetime in milliseconds.
|
||||
pub fn effective_max_connection_lifetime(&self) -> u64 {
|
||||
self.max_connection_lifetime.unwrap_or(86_400_000)
|
||||
}
|
||||
|
||||
/// Get all unique ports that routes listen on.
|
||||
pub fn all_listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.routes
|
||||
.iter()
|
||||
.flat_map(|r| r.listening_ports())
|
||||
.collect();
|
||||
ports.sort();
|
||||
ports.dedup();
|
||||
ports
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::helpers::*;
|
||||
|
||||
#[test]
|
||||
fn test_serde_roundtrip_minimal() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![create_http_route("example.com", "localhost", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
let json = serde_json::to_string(&options).unwrap();
|
||||
let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.routes.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serde_roundtrip_full() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
create_http_route("a.com", "backend1", 8080),
|
||||
create_https_passthrough_route("b.com", "backend2", 443),
|
||||
],
|
||||
connection_timeout: Some(5000),
|
||||
socket_timeout: Some(60000),
|
||||
max_connections_per_ip: Some(100),
|
||||
acme: Some(AcmeOptions {
|
||||
enabled: Some(true),
|
||||
email: Some("admin@example.com".to_string()),
|
||||
environment: Some(AcmeEnvironment::Staging),
|
||||
account_email: None,
|
||||
port: None,
|
||||
use_production: None,
|
||||
renew_threshold_days: None,
|
||||
auto_renew: None,
|
||||
certificate_store: None,
|
||||
skip_configured_certs: None,
|
||||
renew_check_interval_hours: None,
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
let json = serde_json::to_string_pretty(&options).unwrap();
|
||||
let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.routes.len(), 2);
|
||||
assert_eq!(parsed.connection_timeout, Some(5000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_timeouts() {
|
||||
let options = RustProxyOptions::default();
|
||||
assert_eq!(options.effective_connection_timeout(), 30_000);
|
||||
assert_eq!(options.effective_initial_data_timeout(), 60_000);
|
||||
assert_eq!(options.effective_socket_timeout(), 3_600_000);
|
||||
assert_eq!(options.effective_max_connection_lifetime(), 86_400_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_timeouts() {
|
||||
let options = RustProxyOptions {
|
||||
connection_timeout: Some(5000),
|
||||
initial_data_timeout: Some(10000),
|
||||
socket_timeout: Some(30000),
|
||||
max_connection_lifetime: Some(60000),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(options.effective_connection_timeout(), 5000);
|
||||
assert_eq!(options.effective_initial_data_timeout(), 10000);
|
||||
assert_eq!(options.effective_socket_timeout(), 30000);
|
||||
assert_eq!(options.effective_max_connection_lifetime(), 60000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_listening_ports() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
create_http_route("a.com", "backend", 8080), // port 80
|
||||
create_https_passthrough_route("b.com", "backend", 443), // port 443
|
||||
create_http_route("c.com", "backend", 9090), // port 80 (duplicate)
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
let ports = options.all_listening_ports();
|
||||
assert_eq!(ports, vec![80, 443]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_camel_case_field_names() {
|
||||
let options = RustProxyOptions {
|
||||
connection_timeout: Some(5000),
|
||||
max_connections_per_ip: Some(100),
|
||||
keep_alive_treatment: Some(KeepAliveTreatment::Extended),
|
||||
..Default::default()
|
||||
};
|
||||
let json = serde_json::to_string(&options).unwrap();
|
||||
assert!(json.contains("connectionTimeout"));
|
||||
assert!(json.contains("maxConnectionsPerIp"));
|
||||
assert!(json.contains("keepAliveTreatment"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_example_json() {
|
||||
let content = std::fs::read_to_string(
|
||||
concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json")
|
||||
).unwrap();
|
||||
let options: RustProxyOptions = serde_json::from_str(&content).unwrap();
|
||||
assert_eq!(options.routes.len(), 4);
|
||||
let ports = options.all_listening_ports();
|
||||
assert!(ports.contains(&80));
|
||||
assert!(ports.contains(&443));
|
||||
}
|
||||
}
|
||||
603
rust/crates/rustproxy-config/src/route_types.rs
Normal file
603
rust/crates/rustproxy-config/src/route_types.rs
Normal file
@@ -0,0 +1,603 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tls_types::RouteTls;
|
||||
use crate::security_types::RouteSecurity;
|
||||
|
||||
// ─── Port Range ──────────────────────────────────────────────────────
|
||||
|
||||
/// Port range specification format.
|
||||
/// Matches TypeScript: `type TPortRange = number | number[] | Array<{ from: number; to: number }>`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PortRange {
|
||||
/// Single port number
|
||||
Single(u16),
|
||||
/// Array of port numbers
|
||||
List(Vec<u16>),
|
||||
/// Array of port ranges
|
||||
Ranges(Vec<PortRangeSpec>),
|
||||
}
|
||||
|
||||
impl PortRange {
|
||||
/// Expand the port range into a flat list of ports.
|
||||
pub fn to_ports(&self) -> Vec<u16> {
|
||||
match self {
|
||||
PortRange::Single(p) => vec![*p],
|
||||
PortRange::List(ports) => ports.clone(),
|
||||
PortRange::Ranges(ranges) => {
|
||||
ranges.iter().flat_map(|r| r.from..=r.to).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A from-to port range.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PortRangeSpec {
|
||||
pub from: u16,
|
||||
pub to: u16,
|
||||
}
|
||||
|
||||
// ─── Route Action Type ───────────────────────────────────────────────
|
||||
|
||||
/// Supported action types for route configurations.
|
||||
/// Matches TypeScript: `type TRouteActionType = 'forward' | 'socket-handler'`
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum RouteActionType {
|
||||
Forward,
|
||||
SocketHandler,
|
||||
}
|
||||
|
||||
// ─── Forwarding Engine ───────────────────────────────────────────────
|
||||
|
||||
/// Forwarding engine specification.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ForwardingEngine {
|
||||
Node,
|
||||
Nftables,
|
||||
}
|
||||
|
||||
// ─── Route Match ─────────────────────────────────────────────────────
|
||||
|
||||
/// Domain specification: single string or array.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum DomainSpec {
|
||||
Single(String),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
impl DomainSpec {
|
||||
pub fn to_vec(&self) -> Vec<&str> {
|
||||
match self {
|
||||
DomainSpec::Single(s) => vec![s.as_str()],
|
||||
DomainSpec::List(v) => v.iter().map(|s| s.as_str()).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Header match value: either exact string or regex pattern.
|
||||
/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum HeaderMatchValue {
|
||||
Exact(String),
|
||||
}
|
||||
|
||||
/// Route match criteria for incoming requests.
|
||||
/// Matches TypeScript: `IRouteMatch`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteMatch {
|
||||
/// Listen on these ports (required)
|
||||
pub ports: PortRange,
|
||||
|
||||
/// Optional domain patterns to match (default: all domains)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub domains: Option<DomainSpec>,
|
||||
|
||||
/// Match specific paths
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub path: Option<String>,
|
||||
|
||||
/// Match specific client IPs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub client_ip: Option<Vec<String>>,
|
||||
|
||||
/// Match specific TLS versions
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tls_version: Option<Vec<String>>,
|
||||
|
||||
/// Match specific HTTP headers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
// ─── Target Match ────────────────────────────────────────────────────
|
||||
|
||||
/// Target-specific match criteria for sub-routing within a route.
|
||||
/// Matches TypeScript: `ITargetMatch`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TargetMatch {
|
||||
/// Match specific ports from the route
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ports: Option<Vec<u16>>,
|
||||
/// Match specific paths (supports wildcards like /api/*)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub path: Option<String>,
|
||||
/// Match specific HTTP headers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
/// Match specific HTTP methods
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub method: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
// ─── WebSocket Config ────────────────────────────────────────────────
|
||||
|
||||
/// WebSocket configuration.
|
||||
/// Matches TypeScript: `IRouteWebSocket`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteWebSocket {
|
||||
pub enabled: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ping_interval: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ping_timeout: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_payload_size: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub custom_headers: Option<HashMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub subprotocols: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rewrite_path: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_origins: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub authenticate_request: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Load Balancing ──────────────────────────────────────────────────
|
||||
|
||||
/// Load balancing algorithm.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum LoadBalancingAlgorithm {
|
||||
RoundRobin,
|
||||
LeastConnections,
|
||||
IpHash,
|
||||
}
|
||||
|
||||
/// Health check configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HealthCheck {
|
||||
pub path: String,
|
||||
pub interval: u64,
|
||||
pub timeout: u64,
|
||||
pub unhealthy_threshold: u32,
|
||||
pub healthy_threshold: u32,
|
||||
}
|
||||
|
||||
/// Load balancing configuration.
|
||||
/// Matches TypeScript: `IRouteLoadBalancing`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteLoadBalancing {
|
||||
pub algorithm: LoadBalancingAlgorithm,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub health_check: Option<HealthCheck>,
|
||||
}
|
||||
|
||||
// ─── CORS ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Allowed origin specification.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AllowOrigin {
|
||||
Single(String),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
/// CORS configuration for a route.
|
||||
/// Matches TypeScript: `IRouteCors`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteCors {
|
||||
pub enabled: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_origin: Option<AllowOrigin>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_methods: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_headers: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_credentials: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expose_headers: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_age: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preflight: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Headers ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Headers configuration.
|
||||
/// Matches TypeScript: `IRouteHeaders`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteHeaders {
|
||||
/// Headers to add/modify for requests to backend
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request: Option<HashMap<String, String>>,
|
||||
/// Headers to add/modify for responses to client
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response: Option<HashMap<String, String>>,
|
||||
/// CORS configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cors: Option<RouteCors>,
|
||||
}
|
||||
|
||||
// ─── Static Files ────────────────────────────────────────────────────
|
||||
|
||||
/// Static file server configuration.
|
||||
/// Matches TypeScript: `IRouteStaticFiles`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteStaticFiles {
|
||||
pub root: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub index: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub directory: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub index_files: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expires: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub follow_symlinks: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub disable_directory_listing: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Test Response ───────────────────────────────────────────────────
|
||||
|
||||
/// Test route response configuration.
|
||||
/// Matches TypeScript: `IRouteTestResponse`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteTestResponse {
|
||||
pub status: u16,
|
||||
pub headers: HashMap<String, String>,
|
||||
pub body: String,
|
||||
}
|
||||
|
||||
// ─── URL Rewriting ───────────────────────────────────────────────────
|
||||
|
||||
/// URL rewriting configuration.
|
||||
/// Matches TypeScript: `IRouteUrlRewrite`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteUrlRewrite {
|
||||
/// RegExp pattern to match in URL
|
||||
pub pattern: String,
|
||||
/// Replacement pattern
|
||||
pub target: String,
|
||||
/// RegExp flags
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub flags: Option<String>,
|
||||
/// Only apply to path, not query string
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub only_rewrite_path: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Advanced Options ────────────────────────────────────────────────
|
||||
|
||||
/// Advanced options for route actions.
|
||||
/// Matches TypeScript: `IRouteAdvanced`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteAdvanced {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timeout: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub static_files: Option<RouteStaticFiles>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub test_response: Option<RouteTestResponse>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub url_rewrite: Option<RouteUrlRewrite>,
|
||||
}
|
||||
|
||||
// ─── NFTables Options ────────────────────────────────────────────────
|
||||
|
||||
/// NFTables protocol type.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum NfTablesProtocol {
|
||||
Tcp,
|
||||
Udp,
|
||||
All,
|
||||
}
|
||||
|
||||
/// NFTables-specific configuration options.
|
||||
/// Matches TypeScript: `INfTablesOptions`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NfTablesOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preserve_source_ip: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub protocol: Option<NfTablesProtocol>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_rate: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_ip_sets: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_advanced_nat: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Backend Protocol ────────────────────────────────────────────────
|
||||
|
||||
/// Backend protocol.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum BackendProtocol {
|
||||
Http1,
|
||||
Http2,
|
||||
}
|
||||
|
||||
/// Action options.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ActionOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub backend_protocol: Option<BackendProtocol>,
|
||||
/// Catch-all for additional options
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
// ─── Route Target ────────────────────────────────────────────────────
|
||||
|
||||
/// Host specification: single string or array of strings.
|
||||
/// Note: Dynamic host functions are only available via programmatic API, not JSON.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum HostSpec {
|
||||
Single(String),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
impl HostSpec {
|
||||
pub fn to_vec(&self) -> Vec<&str> {
|
||||
match self {
|
||||
HostSpec::Single(s) => vec![s.as_str()],
|
||||
HostSpec::List(v) => v.iter().map(|s| s.as_str()).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn first(&self) -> &str {
|
||||
match self {
|
||||
HostSpec::Single(s) => s.as_str(),
|
||||
HostSpec::List(v) => v.first().map(|s| s.as_str()).unwrap_or(""),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Port specification: number or "preserve".
|
||||
/// Note: Dynamic port functions are only available via programmatic API, not JSON.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PortSpec {
|
||||
/// Fixed port number
|
||||
Fixed(u16),
|
||||
/// Special string value like "preserve"
|
||||
Special(String),
|
||||
}
|
||||
|
||||
impl PortSpec {
|
||||
/// Resolve the port, using incoming_port when "preserve" is specified.
|
||||
pub fn resolve(&self, incoming_port: u16) -> u16 {
|
||||
match self {
|
||||
PortSpec::Fixed(p) => *p,
|
||||
PortSpec::Special(s) if s == "preserve" => incoming_port,
|
||||
PortSpec::Special(_) => incoming_port, // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Target configuration for forwarding with sub-matching and overrides.
|
||||
/// Matches TypeScript: `IRouteTarget`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteTarget {
|
||||
/// Optional sub-matching criteria within the route
|
||||
#[serde(rename = "match")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub target_match: Option<TargetMatch>,
|
||||
|
||||
/// Target host(s)
|
||||
pub host: HostSpec,
|
||||
|
||||
/// Target port
|
||||
pub port: PortSpec,
|
||||
|
||||
/// Override route-level TLS settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tls: Option<RouteTls>,
|
||||
|
||||
/// Override route-level WebSocket settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub websocket: Option<RouteWebSocket>,
|
||||
|
||||
/// Override route-level load balancing
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub load_balancing: Option<RouteLoadBalancing>,
|
||||
|
||||
/// Override route-level proxy protocol setting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub send_proxy_protocol: Option<bool>,
|
||||
|
||||
/// Override route-level headers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<RouteHeaders>,
|
||||
|
||||
/// Override route-level advanced settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub advanced: Option<RouteAdvanced>,
|
||||
|
||||
/// Priority for matching (higher values checked first, default: 0)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<i32>,
|
||||
}
|
||||
|
||||
// ─── Route Action ────────────────────────────────────────────────────
|
||||
|
||||
/// Action configuration for route handling.
|
||||
/// Matches TypeScript: `IRouteAction`
|
||||
///
|
||||
/// Note: `socketHandler` is not serializable in JSON. Use the programmatic API
|
||||
/// for socket handler routes.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteAction {
|
||||
/// Basic routing type
|
||||
#[serde(rename = "type")]
|
||||
pub action_type: RouteActionType,
|
||||
|
||||
/// Targets for forwarding (array supports multiple targets with sub-matching)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub targets: Option<Vec<RouteTarget>>,
|
||||
|
||||
/// TLS handling (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tls: Option<RouteTls>,
|
||||
|
||||
/// WebSocket support (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub websocket: Option<RouteWebSocket>,
|
||||
|
||||
/// Load balancing options (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub load_balancing: Option<RouteLoadBalancing>,
|
||||
|
||||
/// Advanced options (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub advanced: Option<RouteAdvanced>,
|
||||
|
||||
/// Additional options
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<ActionOptions>,
|
||||
|
||||
/// Forwarding engine specification
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub forwarding_engine: Option<ForwardingEngine>,
|
||||
|
||||
/// NFTables-specific options
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub nftables: Option<NfTablesOptions>,
|
||||
|
||||
/// PROXY protocol support (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub send_proxy_protocol: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Route Config ────────────────────────────────────────────────────
|
||||
|
||||
/// The core unified configuration interface.
|
||||
/// Matches TypeScript: `IRouteConfig`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteConfig {
|
||||
/// Unique identifier
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
/// What to match
|
||||
#[serde(rename = "match")]
|
||||
pub route_match: RouteMatch,
|
||||
|
||||
/// What to do with matched traffic
|
||||
pub action: RouteAction,
|
||||
|
||||
/// Custom headers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<RouteHeaders>,
|
||||
|
||||
/// Security features
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub security: Option<RouteSecurity>,
|
||||
|
||||
/// Human-readable name for this route
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Description of the route's purpose
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Controls matching order (higher = matched first)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<i32>,
|
||||
|
||||
/// Arbitrary tags for categorization
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tags: Option<Vec<String>>,
|
||||
|
||||
/// Whether the route is active (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
impl RouteConfig {
|
||||
/// Check if this route is enabled (defaults to true).
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Get the effective priority (defaults to 0).
|
||||
pub fn effective_priority(&self) -> i32 {
|
||||
self.priority.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Get all ports this route listens on.
|
||||
pub fn listening_ports(&self) -> Vec<u16> {
|
||||
self.route_match.ports.to_ports()
|
||||
}
|
||||
|
||||
/// Get the TLS mode for this route (from action-level or first target).
|
||||
pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
|
||||
// Check action-level TLS first
|
||||
if let Some(tls) = &self.action.tls {
|
||||
return Some(&tls.mode);
|
||||
}
|
||||
// Check first target's TLS
|
||||
if let Some(targets) = &self.action.targets {
|
||||
if let Some(first) = targets.first() {
|
||||
if let Some(tls) = &first.tls {
|
||||
return Some(&tls.mode);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
132
rust/crates/rustproxy-config/src/security_types.rs
Normal file
132
rust/crates/rustproxy-config/src/security_types.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Rate limiting configuration.
|
||||
/// Matches TypeScript: `IRouteRateLimit`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteRateLimit {
|
||||
pub enabled: bool,
|
||||
pub max_requests: u64,
|
||||
/// Time window in seconds
|
||||
pub window: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub key_by: Option<RateLimitKeyBy>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub header_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error_message: Option<String>,
|
||||
}
|
||||
|
||||
/// Rate limit key selection.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum RateLimitKeyBy {
|
||||
Ip,
|
||||
Path,
|
||||
Header,
|
||||
}
|
||||
|
||||
/// Authentication type.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AuthenticationType {
|
||||
Basic,
|
||||
Digest,
|
||||
Oauth,
|
||||
Jwt,
|
||||
}
|
||||
|
||||
/// Authentication credentials.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AuthCredentials {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
/// Authentication options.
|
||||
/// Matches TypeScript: `IRouteAuthentication`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteAuthentication {
|
||||
#[serde(rename = "type")]
|
||||
pub auth_type: AuthenticationType,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub credentials: Option<Vec<AuthCredentials>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub realm: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwt_secret: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwt_issuer: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_provider: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_client_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_client_secret: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_redirect_uri: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Basic auth configuration.
|
||||
/// Matches TypeScript: `IRouteSecurity.basicAuth`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct BasicAuthConfig {
|
||||
pub enabled: bool,
|
||||
pub users: Vec<AuthCredentials>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub realm: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub exclude_paths: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// JWT auth configuration.
|
||||
/// Matches TypeScript: `IRouteSecurity.jwtAuth`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct JwtAuthConfig {
|
||||
pub enabled: bool,
|
||||
pub secret: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub algorithm: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub issuer: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub audience: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expires_in: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub exclude_paths: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Security options for routes.
|
||||
/// Matches TypeScript: `IRouteSecurity`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteSecurity {
|
||||
/// IP addresses that are allowed to connect
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
/// IP addresses that are blocked from connecting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
/// Maximum concurrent connections
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connections: Option<u64>,
|
||||
/// Authentication configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub authentication: Option<RouteAuthentication>,
|
||||
/// Rate limiting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rate_limit: Option<RouteRateLimit>,
|
||||
/// Basic auth
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub basic_auth: Option<BasicAuthConfig>,
|
||||
/// JWT auth
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwt_auth: Option<JwtAuthConfig>,
|
||||
}
|
||||
93
rust/crates/rustproxy-config/src/tls_types.rs
Normal file
93
rust/crates/rustproxy-config/src/tls_types.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// TLS handling modes for route configurations.
|
||||
/// Matches TypeScript: `type TTlsMode = 'passthrough' | 'terminate' | 'terminate-and-reencrypt'`
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum TlsMode {
|
||||
Passthrough,
|
||||
Terminate,
|
||||
TerminateAndReencrypt,
|
||||
}
|
||||
|
||||
/// Static certificate configuration (PEM-encoded).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CertificateConfig {
|
||||
/// PEM-encoded private key
|
||||
pub key: String,
|
||||
/// PEM-encoded certificate
|
||||
pub cert: String,
|
||||
/// PEM-encoded CA chain
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ca: Option<String>,
|
||||
/// Path to key file (overrides key)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub key_file: Option<String>,
|
||||
/// Path to cert file (overrides cert)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cert_file: Option<String>,
|
||||
}
|
||||
|
||||
/// Certificate specification: either automatic (ACME) or static.
|
||||
/// Matches TypeScript: `certificate?: 'auto' | { key, cert, ca?, keyFile?, certFile? }`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CertificateSpec {
|
||||
/// Use ACME (Let's Encrypt) for automatic provisioning
|
||||
Auto(String), // "auto"
|
||||
/// Static certificate configuration
|
||||
Static(CertificateConfig),
|
||||
}
|
||||
|
||||
impl CertificateSpec {
|
||||
/// Check if this is an auto (ACME) certificate
|
||||
pub fn is_auto(&self) -> bool {
|
||||
matches!(self, CertificateSpec::Auto(s) if s == "auto")
|
||||
}
|
||||
}
|
||||
|
||||
/// ACME configuration for automatic certificate provisioning.
|
||||
/// Matches TypeScript: `IRouteAcme`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteAcme {
|
||||
/// Contact email for ACME account
|
||||
pub email: String,
|
||||
/// Use production ACME servers (default: false)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_production: Option<bool>,
|
||||
/// Port for HTTP-01 challenges (default: 80)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub challenge_port: Option<u16>,
|
||||
/// Days before expiry to renew (default: 30)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub renew_before_days: Option<u32>,
|
||||
}
|
||||
|
||||
/// TLS configuration for route actions.
|
||||
/// Matches TypeScript: `IRouteTls`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteTls {
|
||||
/// TLS mode (passthrough, terminate, terminate-and-reencrypt)
|
||||
pub mode: TlsMode,
|
||||
/// Certificate configuration (auto or static)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub certificate: Option<CertificateSpec>,
|
||||
/// ACME options when certificate is 'auto'
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub acme: Option<RouteAcme>,
|
||||
/// Allowed TLS versions
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub versions: Option<Vec<String>>,
|
||||
/// OpenSSL cipher string
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ciphers: Option<String>,
|
||||
/// Use server's cipher preferences
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub honor_cipher_order: Option<bool>,
|
||||
/// TLS session timeout in seconds
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_timeout: Option<u64>,
|
||||
}
|
||||
158
rust/crates/rustproxy-config/src/validation.rs
Normal file
158
rust/crates/rustproxy-config/src/validation.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::route_types::{RouteConfig, RouteActionType};
|
||||
|
||||
/// Validation errors for route configurations.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ValidationError {
|
||||
#[error("Route '{name}' has no targets but action type is 'forward'")]
|
||||
MissingTargets { name: String },
|
||||
|
||||
#[error("Route '{name}' has empty targets list")]
|
||||
EmptyTargets { name: String },
|
||||
|
||||
#[error("Route '{name}' has no ports specified")]
|
||||
NoPorts { name: String },
|
||||
|
||||
#[error("Route '{name}' port {port} is invalid (must be 1-65535)")]
|
||||
InvalidPort { name: String, port: u16 },
|
||||
|
||||
#[error("Route '{name}': socket-handler action type is not supported in JSON config")]
|
||||
SocketHandlerInJson { name: String },
|
||||
|
||||
#[error("Route '{name}': duplicate route ID '{id}'")]
|
||||
DuplicateId { name: String, id: String },
|
||||
|
||||
#[error("Route '{name}': {message}")]
|
||||
Custom { name: String, message: String },
|
||||
}
|
||||
|
||||
/// Validate a single route configuration.
|
||||
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
|
||||
let mut errors = Vec::new();
|
||||
let name = route.name.clone().unwrap_or_else(|| {
|
||||
route.id.clone().unwrap_or_else(|| "unnamed".to_string())
|
||||
});
|
||||
|
||||
// Check ports
|
||||
let ports = route.listening_ports();
|
||||
if ports.is_empty() {
|
||||
errors.push(ValidationError::NoPorts { name: name.clone() });
|
||||
}
|
||||
for &port in &ports {
|
||||
if port == 0 {
|
||||
errors.push(ValidationError::InvalidPort {
|
||||
name: name.clone(),
|
||||
port,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check forward action has targets
|
||||
if route.action.action_type == RouteActionType::Forward {
|
||||
match &route.action.targets {
|
||||
None => {
|
||||
errors.push(ValidationError::MissingTargets { name: name.clone() });
|
||||
}
|
||||
Some(targets) if targets.is_empty() => {
|
||||
errors.push(ValidationError::EmptyTargets { name: name.clone() });
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors)
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate an entire list of routes.
|
||||
pub fn validate_routes(routes: &[RouteConfig]) -> Result<(), Vec<ValidationError>> {
|
||||
let mut all_errors = Vec::new();
|
||||
let mut seen_ids = std::collections::HashSet::new();
|
||||
|
||||
for route in routes {
|
||||
// Check for duplicate IDs
|
||||
if let Some(id) = &route.id {
|
||||
if !seen_ids.insert(id.clone()) {
|
||||
let name = route.name.clone().unwrap_or_else(|| id.clone());
|
||||
all_errors.push(ValidationError::DuplicateId {
|
||||
name,
|
||||
id: id.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Validate individual route
|
||||
if let Err(errors) = validate_route(route) {
|
||||
all_errors.extend(errors);
|
||||
}
|
||||
}
|
||||
|
||||
if all_errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(all_errors)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::route_types::*;
|
||||
|
||||
fn make_valid_route() -> RouteConfig {
|
||||
crate::helpers::create_http_route("example.com", "localhost", 8080)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_route_passes() {
|
||||
let route = make_valid_route();
|
||||
assert!(validate_route(&route).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_targets() {
|
||||
let mut route = make_valid_route();
|
||||
route.action.targets = None;
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_targets() {
|
||||
let mut route = make_valid_route();
|
||||
route.action.targets = Some(vec![]);
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_port_zero() {
|
||||
let mut route = make_valid_route();
|
||||
route.route_match.ports = PortRange::Single(0);
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_ids() {
|
||||
let mut r1 = make_valid_route();
|
||||
r1.id = Some("route-1".to_string());
|
||||
let mut r2 = make_valid_route();
|
||||
r2.id = Some("route-1".to_string());
|
||||
let errors = validate_routes(&[r1, r2]).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_errors_collected() {
|
||||
let mut r1 = make_valid_route();
|
||||
r1.action.targets = None; // MissingTargets
|
||||
r1.route_match.ports = PortRange::Single(0); // InvalidPort
|
||||
let errors = validate_route(&r1).unwrap_err();
|
||||
assert!(errors.len() >= 2);
|
||||
}
|
||||
}
|
||||
24
rust/crates/rustproxy-http/Cargo.toml
Normal file
24
rust/crates/rustproxy-http/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "rustproxy-http"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Hyper-based HTTP proxy service for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-security = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
hyper = { workspace = true }
|
||||
hyper-util = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
http-body-util = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
arc-swap = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
14
rust/crates/rustproxy-http/src/lib.rs
Normal file
14
rust/crates/rustproxy-http/src/lib.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
//! # rustproxy-http
|
||||
//!
|
||||
//! Hyper-based HTTP proxy service for RustProxy.
|
||||
//! Handles HTTP request parsing, route-based forwarding, and response filtering.
|
||||
|
||||
pub mod proxy_service;
|
||||
pub mod request_filter;
|
||||
pub mod response_filter;
|
||||
pub mod template;
|
||||
pub mod upstream_selector;
|
||||
|
||||
pub use proxy_service::*;
|
||||
pub use template::*;
|
||||
pub use upstream_selector::*;
|
||||
827
rust/crates/rustproxy-http/src/proxy_service.rs
Normal file
827
rust/crates/rustproxy-http/src/proxy_service.rs
Normal file
@@ -0,0 +1,827 @@
|
||||
//! Hyper-based HTTP proxy service.
|
||||
//!
|
||||
//! Accepts decrypted TCP streams (from TLS termination or plain TCP),
|
||||
//! parses HTTP requests, matches routes, and forwards to upstream backends.
|
||||
//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt, Full, combinators::BoxBody};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use regex::Regex;
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
|
||||
use crate::request_filter::RequestFilter;
|
||||
use crate::response_filter::ResponseFilter;
|
||||
use crate::upstream_selector::UpstreamSelector;
|
||||
|
||||
/// HTTP proxy service that processes HTTP traffic.
|
||||
pub struct HttpProxyService {
|
||||
route_manager: Arc<RouteManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
upstream_selector: UpstreamSelector,
|
||||
}
|
||||
|
||||
impl HttpProxyService {
|
||||
pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
|
||||
Self {
|
||||
route_manager,
|
||||
metrics,
|
||||
upstream_selector: UpstreamSelector::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an incoming HTTP connection on a plain TCP stream.
|
||||
pub async fn handle_connection(
|
||||
self: Arc<Self>,
|
||||
stream: TcpStream,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
) {
|
||||
self.handle_io(stream, peer_addr, port).await;
|
||||
}
|
||||
|
||||
/// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated).
|
||||
///
|
||||
/// Uses HTTP/1.1 with upgrade support. For clients that negotiate HTTP/2,
|
||||
/// use `handle_io_auto` instead.
|
||||
pub async fn handle_io<I>(
|
||||
self: Arc<Self>,
|
||||
stream: I,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
)
|
||||
where
|
||||
I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
|
||||
let svc = Arc::clone(&self);
|
||||
let peer = peer_addr;
|
||||
async move {
|
||||
svc.handle_request(req, peer, port).await
|
||||
}
|
||||
});
|
||||
|
||||
// Use http1::Builder with upgrades for WebSocket support
|
||||
let conn = hyper::server::conn::http1::Builder::new()
|
||||
.keep_alive(true)
|
||||
.serve_connection(io, service)
|
||||
.with_upgrades();
|
||||
|
||||
if let Err(e) = conn.await {
|
||||
debug!("HTTP connection error from {}: {}", peer_addr, e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single HTTP request.
|
||||
async fn handle_request(
|
||||
&self,
|
||||
req: Request<Incoming>,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let host = req.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|h| {
|
||||
// Strip port from host header
|
||||
h.split(':').next().unwrap_or(h).to_string()
|
||||
});
|
||||
|
||||
let path = req.uri().path().to_string();
|
||||
let method = req.method().clone();
|
||||
|
||||
// Extract headers for matching
|
||||
let headers: HashMap<String, String> = req.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr);
|
||||
|
||||
// Check for CORS preflight
|
||||
if method == hyper::Method::OPTIONS {
|
||||
if let Some(response) = RequestFilter::handle_cors_preflight(&req) {
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Match route
|
||||
let ctx = rustproxy_routing::MatchContext {
|
||||
port,
|
||||
domain: host.as_deref(),
|
||||
path: Some(&path),
|
||||
client_ip: Some(&peer_addr.ip().to_string()),
|
||||
tls_version: None,
|
||||
headers: Some(&headers),
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let route_match = match self.route_manager.find_route(&ctx) {
|
||||
Some(rm) => rm,
|
||||
None => {
|
||||
debug!("No route matched for HTTP request to {:?}{}", host, path);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched"));
|
||||
}
|
||||
};
|
||||
|
||||
let route_id = route_match.route.id.as_deref();
|
||||
self.metrics.connection_opened(route_id);
|
||||
|
||||
// Apply request filters (IP check, rate limiting, auth)
|
||||
if let Some(ref security) = route_match.route.security {
|
||||
if let Some(response) = RequestFilter::apply(security, &req, &peer_addr) {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for test response (returns immediately, no upstream needed)
|
||||
if let Some(ref advanced) = route_match.route.action.advanced {
|
||||
if let Some(ref test_response) = advanced.test_response {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(Self::build_test_response(test_response));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for static file serving
|
||||
if let Some(ref advanced) = route_match.route.action.advanced {
|
||||
if let Some(ref static_files) = advanced.static_files {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(Self::serve_static_file(&path, static_files));
|
||||
}
|
||||
}
|
||||
|
||||
// Select upstream
|
||||
let target = match route_match.target {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available"));
|
||||
}
|
||||
};
|
||||
|
||||
let upstream = self.upstream_selector.select(target, &peer_addr, port);
|
||||
let upstream_key = format!("{}:{}", upstream.host, upstream.port);
|
||||
self.upstream_selector.connection_started(&upstream_key);
|
||||
|
||||
// Check for WebSocket upgrade
|
||||
let is_websocket = req.headers()
|
||||
.get("upgrade")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.eq_ignore_ascii_case("websocket"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_websocket {
|
||||
let result = self.handle_websocket_upgrade(
|
||||
req, peer_addr, &upstream, route_match.route, route_id, &upstream_key,
|
||||
).await;
|
||||
// Note: for WebSocket, connection_ended is called inside
|
||||
// the spawned tunnel task when the connection closes.
|
||||
return result;
|
||||
}
|
||||
|
||||
// Determine backend protocol
|
||||
let use_h2 = route_match.route.action.options.as_ref()
|
||||
.and_then(|o| o.backend_protocol.as_ref())
|
||||
.map(|p| *p == rustproxy_config::BackendProtocol::Http2)
|
||||
.unwrap_or(false);
|
||||
|
||||
// Build the upstream path (path + query), applying URL rewriting if configured
|
||||
let upstream_path = {
|
||||
let raw_path = match req.uri().query() {
|
||||
Some(q) => format!("{}?{}", path, q),
|
||||
None => path.clone(),
|
||||
};
|
||||
Self::apply_url_rewrite(&raw_path, &route_match.route)
|
||||
};
|
||||
|
||||
// Build upstream request - stream body instead of buffering
|
||||
let (parts, body) = req.into_parts();
|
||||
|
||||
// Apply request headers from route config
|
||||
let mut upstream_headers = parts.headers.clone();
|
||||
if let Some(ref route_headers) = route_match.route.headers {
|
||||
if let Some(ref request_headers) = route_headers.request {
|
||||
for (key, value) in request_headers {
|
||||
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
|
||||
upstream_headers.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to upstream
|
||||
let upstream_stream = match TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
|
||||
self.upstream_selector.connection_ended(&upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
|
||||
}
|
||||
};
|
||||
upstream_stream.set_nodelay(true).ok();
|
||||
|
||||
let io = TokioIo::new(upstream_stream);
|
||||
|
||||
let result = if use_h2 {
|
||||
// HTTP/2 backend
|
||||
self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
|
||||
} else {
|
||||
// HTTP/1.1 backend (default)
|
||||
self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
|
||||
};
|
||||
self.upstream_selector.connection_ended(&upstream_key);
|
||||
result
|
||||
}
|
||||
|
||||
/// Forward request to backend via HTTP/1.1 with body streaming.
|
||||
async fn forward_h1(
|
||||
&self,
|
||||
io: TokioIo<TcpStream>,
|
||||
parts: hyper::http::request::Parts,
|
||||
body: Incoming,
|
||||
upstream_headers: hyper::HeaderMap,
|
||||
upstream_path: &str,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
error!("Upstream handshake failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed"));
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
debug!("Upstream connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_path)
|
||||
.version(parts.version);
|
||||
|
||||
if let Some(headers) = upstream_req.headers_mut() {
|
||||
*headers = upstream_headers;
|
||||
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
|
||||
&format!("{}:{}", upstream.host, upstream.port)
|
||||
) {
|
||||
headers.insert(hyper::header::HOST, host_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the request body through to upstream
|
||||
let upstream_req = upstream_req.body(body).unwrap();
|
||||
|
||||
let upstream_response = match sender.send_request(upstream_req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
error!("Upstream request failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed"));
|
||||
}
|
||||
};
|
||||
|
||||
self.build_streaming_response(upstream_response, route, route_id).await
|
||||
}
|
||||
|
||||
/// Forward request to backend via HTTP/2 with body streaming.
|
||||
async fn forward_h2(
|
||||
&self,
|
||||
io: TokioIo<TcpStream>,
|
||||
parts: hyper::http::request::Parts,
|
||||
body: Incoming,
|
||||
upstream_headers: hyper::HeaderMap,
|
||||
upstream_path: &str,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let exec = hyper_util::rt::TokioExecutor::new();
|
||||
let (mut sender, conn) = match hyper::client::conn::http2::handshake(exec, io).await {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
error!("HTTP/2 upstream handshake failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed"));
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
debug!("HTTP/2 upstream connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_path);
|
||||
|
||||
if let Some(headers) = upstream_req.headers_mut() {
|
||||
*headers = upstream_headers;
|
||||
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
|
||||
&format!("{}:{}", upstream.host, upstream.port)
|
||||
) {
|
||||
headers.insert(hyper::header::HOST, host_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the request body through to upstream
|
||||
let upstream_req = upstream_req.body(body).unwrap();
|
||||
|
||||
let upstream_response = match sender.send_request(upstream_req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
error!("HTTP/2 upstream request failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed"));
|
||||
}
|
||||
};
|
||||
|
||||
self.build_streaming_response(upstream_response, route, route_id).await
|
||||
}
|
||||
|
||||
/// Build the client-facing response from an upstream response, streaming the body.
|
||||
async fn build_streaming_response(
|
||||
&self,
|
||||
upstream_response: Response<Incoming>,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let (resp_parts, resp_body) = upstream_response.into_parts();
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(resp_parts.status);
|
||||
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
*headers = resp_parts.headers;
|
||||
ResponseFilter::apply_headers(route, headers, None);
|
||||
}
|
||||
|
||||
self.metrics.connection_closed(route_id);
|
||||
|
||||
// Stream the response body directly from upstream to client
|
||||
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(resp_body);
|
||||
|
||||
Ok(response.body(body).unwrap())
|
||||
}
|
||||
|
||||
/// Handle a WebSocket upgrade request.
|
||||
async fn handle_websocket_upgrade(
|
||||
&self,
|
||||
req: Request<Incoming>,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
upstream_key: &str,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// Get WebSocket config from route
|
||||
let ws_config = route.action.websocket.as_ref();
|
||||
|
||||
// Check allowed origins if configured
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref allowed_origins) = ws.allowed_origins {
|
||||
let origin = req.headers()
|
||||
.get("origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) {
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);
|
||||
|
||||
let mut upstream_stream = match TcpStream::connect(
|
||||
format!("{}:{}", upstream.host, upstream.port)
|
||||
).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
|
||||
}
|
||||
};
|
||||
upstream_stream.set_nodelay(true).ok();
|
||||
|
||||
let path = req.uri().path().to_string();
|
||||
let upstream_path = {
|
||||
let raw = match req.uri().query() {
|
||||
Some(q) => format!("{}?{}", path, q),
|
||||
None => path,
|
||||
};
|
||||
// Apply rewrite_path if configured
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref rewrite_path) = ws.rewrite_path {
|
||||
rewrite_path.clone()
|
||||
} else {
|
||||
raw
|
||||
}
|
||||
} else {
|
||||
raw
|
||||
}
|
||||
};
|
||||
|
||||
let (parts, _body) = req.into_parts();
|
||||
|
||||
let mut raw_request = format!(
|
||||
"{} {} HTTP/1.1\r\n",
|
||||
parts.method, upstream_path
|
||||
);
|
||||
|
||||
let upstream_host = format!("{}:{}", upstream.host, upstream.port);
|
||||
for (name, value) in parts.headers.iter() {
|
||||
if name == hyper::header::HOST {
|
||||
raw_request.push_str(&format!("host: {}\r\n", upstream_host));
|
||||
} else {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or("")));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref route_headers) = route.headers {
|
||||
if let Some(ref request_headers) = route_headers.request {
|
||||
for (key, value) in request_headers {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", key, value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply WebSocket custom headers
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref custom_headers) = ws.custom_headers {
|
||||
for (key, value) in custom_headers {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", key, value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
raw_request.push_str("\r\n");
|
||||
|
||||
if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await {
|
||||
error!("WebSocket: failed to send upgrade request to upstream: {}", e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed"));
|
||||
}
|
||||
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
match upstream_stream.read(&mut temp).await {
|
||||
Ok(0) => {
|
||||
error!("WebSocket: upstream closed before completing handshake");
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed"));
|
||||
}
|
||||
Ok(_) => {
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len-4..] == *b"\r\n\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if response_buf.len() > 8192 {
|
||||
error!("WebSocket: upstream response headers too large");
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large"));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("WebSocket: failed to read upstream response: {}", e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response_buf);
|
||||
|
||||
let status_line = response_str.lines().next().unwrap_or("");
|
||||
let status_code = status_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.and_then(|s| s.parse::<u16>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
if status_code != 101 {
|
||||
debug!("WebSocket: upstream rejected upgrade with status {}", status_code);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(
|
||||
StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY),
|
||||
"WebSocket upgrade rejected by backend",
|
||||
));
|
||||
}
|
||||
|
||||
let mut client_resp = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS);
|
||||
|
||||
if let Some(resp_headers) = client_resp.headers_mut() {
|
||||
for line in response_str.lines().skip(1) {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
break;
|
||||
}
|
||||
if let Some((name, value)) = line.split_once(':') {
|
||||
let name = name.trim();
|
||||
let value = value.trim();
|
||||
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
|
||||
if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
|
||||
resp_headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let on_client_upgrade = hyper::upgrade::on(
|
||||
Request::from_parts(parts, http_body_util::Empty::<Bytes>::new())
|
||||
);
|
||||
|
||||
let metrics = Arc::clone(&self.metrics);
|
||||
let route_id_owned = route_id.map(|s| s.to_string());
|
||||
let upstream_selector = self.upstream_selector.clone();
|
||||
let upstream_key_owned = upstream_key.to_string();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let client_upgraded = match on_client_upgrade.await {
|
||||
Ok(upgraded) => upgraded,
|
||||
Err(e) => {
|
||||
debug!("WebSocket: client upgrade failed: {}", e);
|
||||
upstream_selector.connection_ended(&upstream_key_owned);
|
||||
if let Some(ref rid) = route_id_owned {
|
||||
metrics.connection_closed(Some(rid.as_str()));
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let client_io = TokioIo::new(client_upgraded);
|
||||
|
||||
let (mut cr, mut cw) = tokio::io::split(client_io);
|
||||
let (mut ur, mut uw) = tokio::io::split(upstream_stream);
|
||||
|
||||
let c2u = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match cr.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if uw.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
}
|
||||
let _ = uw.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let u2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match ur.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if cw.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
}
|
||||
let _ = cw.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let bytes_in = c2u.await.unwrap_or(0);
|
||||
let bytes_out = u2c.await.unwrap_or(0);
|
||||
|
||||
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
|
||||
|
||||
upstream_selector.connection_ended(&upstream_key_owned);
|
||||
if let Some(ref rid) = route_id_owned {
|
||||
metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()));
|
||||
metrics.connection_closed(Some(rid.as_str()));
|
||||
}
|
||||
});
|
||||
|
||||
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
|
||||
http_body_util::Empty::<Bytes>::new().map_err(|never| match never {})
|
||||
);
|
||||
Ok(client_resp.body(body).unwrap())
|
||||
}
|
||||
|
||||
/// Build a test response from config (no upstream connection needed).
|
||||
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK));
|
||||
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
for (key, value) in &config.headers {
|
||||
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
|
||||
headers.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let body = Full::new(Bytes::from(config.body.clone()))
|
||||
.map_err(|never| match never {});
|
||||
response.body(BoxBody::new(body)).unwrap()
|
||||
}
|
||||
|
||||
/// Apply URL rewriting rules from route config.
|
||||
fn apply_url_rewrite(path: &str, route: &rustproxy_config::RouteConfig) -> String {
|
||||
let rewrite = match route.action.advanced.as_ref()
|
||||
.and_then(|a| a.url_rewrite.as_ref())
|
||||
{
|
||||
Some(r) => r,
|
||||
None => return path.to_string(),
|
||||
};
|
||||
|
||||
// Determine what to rewrite
|
||||
let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) {
|
||||
// Only rewrite the path portion (before ?)
|
||||
match path.split_once('?') {
|
||||
Some((p, q)) => (p.to_string(), format!("?{}", q)),
|
||||
None => (path.to_string(), String::new()),
|
||||
}
|
||||
} else {
|
||||
(path.to_string(), String::new())
|
||||
};
|
||||
|
||||
match Regex::new(&rewrite.pattern) {
|
||||
Ok(re) => {
|
||||
let result = re.replace_all(&subject, rewrite.target.as_str());
|
||||
format!("{}{}", result, suffix)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e);
|
||||
path.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serve a static file from the configured directory.
|
||||
fn serve_static_file(
|
||||
path: &str,
|
||||
config: &rustproxy_config::RouteStaticFiles,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
use std::path::Path;
|
||||
|
||||
let root = Path::new(&config.root);
|
||||
|
||||
// Sanitize path to prevent directory traversal
|
||||
let clean_path = path.trim_start_matches('/');
|
||||
let clean_path = clean_path.replace("..", "");
|
||||
|
||||
let mut file_path = root.join(&clean_path);
|
||||
|
||||
// If path points to a directory, try index files
|
||||
if file_path.is_dir() || clean_path.is_empty() {
|
||||
let index_files = config.index_files.as_deref()
|
||||
.or(config.index.as_deref())
|
||||
.unwrap_or(&[]);
|
||||
let default_index = vec!["index.html".to_string()];
|
||||
let index_files = if index_files.is_empty() { &default_index } else { index_files };
|
||||
|
||||
let mut found = false;
|
||||
for index in index_files {
|
||||
let candidate = if clean_path.is_empty() {
|
||||
root.join(index)
|
||||
} else {
|
||||
file_path.join(index)
|
||||
};
|
||||
if candidate.is_file() {
|
||||
file_path = candidate;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return error_response(StatusCode::NOT_FOUND, "Not found");
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the resolved path is within the root (prevent traversal)
|
||||
let canonical_root = match root.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
};
|
||||
let canonical_file = match file_path.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
};
|
||||
if !canonical_file.starts_with(&canonical_root) {
|
||||
return error_response(StatusCode::FORBIDDEN, "Forbidden");
|
||||
}
|
||||
|
||||
// Check if symlinks are allowed
|
||||
if config.follow_symlinks == Some(false) && canonical_file != file_path {
|
||||
return error_response(StatusCode::FORBIDDEN, "Forbidden");
|
||||
}
|
||||
|
||||
// Read the file
|
||||
match std::fs::read(&file_path) {
|
||||
Ok(content) => {
|
||||
let content_type = guess_content_type(&file_path);
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", content_type);
|
||||
|
||||
// Apply cache-control if configured
|
||||
if let Some(ref cache_control) = config.cache_control {
|
||||
response = response.header("Cache-Control", cache_control.as_str());
|
||||
}
|
||||
|
||||
// Apply custom headers
|
||||
if let Some(ref headers) = config.headers {
|
||||
for (key, value) in headers {
|
||||
response = response.header(key.as_str(), value.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
let body = Full::new(Bytes::from(content))
|
||||
.map_err(|never| match never {});
|
||||
response.body(BoxBody::new(body)).unwrap()
|
||||
}
|
||||
Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Guess MIME content type from file extension.
|
||||
fn guess_content_type(path: &std::path::Path) -> &'static str {
|
||||
match path.extension().and_then(|e| e.to_str()) {
|
||||
Some("html") | Some("htm") => "text/html; charset=utf-8",
|
||||
Some("css") => "text/css; charset=utf-8",
|
||||
Some("js") | Some("mjs") => "application/javascript; charset=utf-8",
|
||||
Some("json") => "application/json; charset=utf-8",
|
||||
Some("xml") => "application/xml; charset=utf-8",
|
||||
Some("txt") => "text/plain; charset=utf-8",
|
||||
Some("png") => "image/png",
|
||||
Some("jpg") | Some("jpeg") => "image/jpeg",
|
||||
Some("gif") => "image/gif",
|
||||
Some("svg") => "image/svg+xml",
|
||||
Some("ico") => "image/x-icon",
|
||||
Some("woff") => "font/woff",
|
||||
Some("woff2") => "font/woff2",
|
||||
Some("ttf") => "font/ttf",
|
||||
Some("pdf") => "application/pdf",
|
||||
Some("wasm") => "application/wasm",
|
||||
_ => "application/octet-stream",
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HttpProxyService {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
route_manager: Arc::new(RouteManager::new(vec![])),
|
||||
metrics: Arc::new(MetricsCollector::new()),
|
||||
upstream_selector: UpstreamSelector::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let body = Full::new(Bytes::from(message.to_string()))
|
||||
.map_err(|never| match never {});
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(BoxBody::new(body))
|
||||
.unwrap()
|
||||
}
|
||||
263
rust/crates/rustproxy-http/src/request_filter.rs
Normal file
263
rust/crates/rustproxy-http/src/request_filter.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
//! Request filtering: security checks, auth, CORS preflight.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
|
||||
use rustproxy_config::RouteSecurity;
|
||||
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
|
||||
|
||||
pub struct RequestFilter;
|
||||
|
||||
impl RequestFilter {
|
||||
/// Apply security filters. Returns Some(response) if the request should be blocked.
|
||||
pub fn apply(
|
||||
security: &RouteSecurity,
|
||||
req: &Request<Incoming>,
|
||||
peer_addr: &SocketAddr,
|
||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
Self::apply_with_rate_limiter(security, req, peer_addr, None)
|
||||
}
|
||||
|
||||
/// Apply security filters with an optional shared rate limiter.
|
||||
/// Returns Some(response) if the request should be blocked.
|
||||
pub fn apply_with_rate_limiter(
|
||||
security: &RouteSecurity,
|
||||
req: &Request<Incoming>,
|
||||
peer_addr: &SocketAddr,
|
||||
rate_limiter: Option<&Arc<RateLimiter>>,
|
||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
let client_ip = peer_addr.ip();
|
||||
let request_path = req.uri().path();
|
||||
|
||||
// IP filter
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(&client_ip);
|
||||
if !filter.is_allowed(&normalized) {
|
||||
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if let Some(ref rate_limit_config) = security.rate_limit {
|
||||
if rate_limit_config.enabled {
|
||||
// Use shared rate limiter if provided, otherwise create ephemeral one
|
||||
let should_block = if let Some(limiter) = rate_limiter {
|
||||
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
|
||||
!limiter.check(&key)
|
||||
} else {
|
||||
// Create a per-check limiter (less ideal but works for non-shared case)
|
||||
let limiter = RateLimiter::new(
|
||||
rate_limit_config.max_requests,
|
||||
rate_limit_config.window,
|
||||
);
|
||||
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
|
||||
!limiter.check(&key)
|
||||
};
|
||||
|
||||
if should_block {
|
||||
let message = rate_limit_config.error_message
|
||||
.as_deref()
|
||||
.unwrap_or("Rate limit exceeded");
|
||||
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check exclude paths before auth
|
||||
let should_skip_auth = Self::path_matches_exclude_list(request_path, security);
|
||||
|
||||
if !should_skip_auth {
|
||||
// Basic auth
|
||||
if let Some(ref basic_auth) = security.basic_auth {
|
||||
if basic_auth.enabled {
|
||||
// Check basic auth exclude paths
|
||||
let skip_basic = basic_auth.exclude_paths.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !skip_basic {
|
||||
let users: Vec<(String, String)> = basic_auth.users.iter()
|
||||
.map(|c| (c.username.clone(), c.password.clone()))
|
||||
.collect();
|
||||
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
|
||||
|
||||
let auth_header = req.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header {
|
||||
Some(header) => {
|
||||
if validator.validate(header).is_none() {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Invalid credentials"))
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Authentication required"))
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JWT auth
|
||||
if let Some(ref jwt_auth) = security.jwt_auth {
|
||||
if jwt_auth.enabled {
|
||||
// Check JWT auth exclude paths
|
||||
let skip_jwt = jwt_auth.exclude_paths.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !skip_jwt {
|
||||
let validator = JwtValidator::new(
|
||||
&jwt_auth.secret,
|
||||
jwt_auth.algorithm.as_deref(),
|
||||
jwt_auth.issuer.as_deref(),
|
||||
jwt_auth.audience.as_deref(),
|
||||
);
|
||||
|
||||
let auth_header = req.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header.and_then(JwtValidator::extract_token) {
|
||||
Some(token) => {
|
||||
if validator.validate(token).is_err() {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a request path matches any pattern in the exclude list.
|
||||
fn path_matches_exclude_list(_path: &str, _security: &RouteSecurity) -> bool {
|
||||
// No global exclude paths on RouteSecurity currently,
|
||||
// but we check per-auth exclude paths above.
|
||||
// This can be extended if a global exclude_paths is added.
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if a path matches any pattern in the list.
|
||||
/// Supports simple glob patterns: `/health*` matches `/health`, `/healthz`, `/health/check`
|
||||
fn path_matches_any(path: &str, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if pattern.ends_with('*') {
|
||||
let prefix = &pattern[..pattern.len() - 1];
|
||||
if path.starts_with(prefix) {
|
||||
return true;
|
||||
}
|
||||
} else if path == pattern {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Determine the rate limit key based on configuration.
|
||||
fn rate_limit_key(
|
||||
config: &rustproxy_config::RouteRateLimit,
|
||||
req: &Request<Incoming>,
|
||||
peer_addr: &SocketAddr,
|
||||
) -> String {
|
||||
use rustproxy_config::RateLimitKeyBy;
|
||||
match config.key_by.as_ref().unwrap_or(&RateLimitKeyBy::Ip) {
|
||||
RateLimitKeyBy::Ip => peer_addr.ip().to_string(),
|
||||
RateLimitKeyBy::Path => req.uri().path().to_string(),
|
||||
RateLimitKeyBy::Header => {
|
||||
if let Some(ref header_name) = config.header_name {
|
||||
req.headers()
|
||||
.get(header_name.as_str())
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string()
|
||||
} else {
|
||||
peer_addr.ip().to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check IP-based security (for use in passthrough / TCP-level connections).
|
||||
/// Returns true if allowed, false if blocked.
|
||||
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(client_ip);
|
||||
filter.is_allowed(&normalized)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle CORS preflight (OPTIONS) requests.
|
||||
/// Returns Some(response) if this is a CORS preflight that should be handled.
|
||||
pub fn handle_cors_preflight(
|
||||
req: &Request<Incoming>,
|
||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
if req.method() != hyper::Method::OPTIONS {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check for CORS preflight indicators
|
||||
let has_origin = req.headers().contains_key("origin");
|
||||
let has_request_method = req.headers().contains_key("access-control-request-method");
|
||||
|
||||
if !has_origin || !has_request_method {
|
||||
return None;
|
||||
}
|
||||
|
||||
let origin = req.headers()
|
||||
.get("origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("*");
|
||||
|
||||
Some(Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
||||
.header("Access-Control-Max-Age", "86400")
|
||||
.body(boxed_body(""))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(boxed_body(message))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
|
||||
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
|
||||
}
|
||||
92
rust/crates/rustproxy-http/src/response_filter.rs
Normal file
92
rust/crates/rustproxy-http/src/response_filter.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! Response filtering: CORS headers, custom headers, security headers.
|
||||
|
||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use rustproxy_config::RouteConfig;
|
||||
|
||||
use crate::template::{RequestContext, expand_template};
|
||||
|
||||
pub struct ResponseFilter;
|
||||
|
||||
impl ResponseFilter {
|
||||
/// Apply response headers from route config and CORS settings.
|
||||
/// If a `RequestContext` is provided, template variables in header values will be expanded.
|
||||
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
|
||||
// Apply custom response headers from route config
|
||||
if let Some(ref route_headers) = route.headers {
|
||||
if let Some(ref response_headers) = route_headers.response {
|
||||
for (key, value) in response_headers {
|
||||
if let Ok(name) = HeaderName::from_bytes(key.as_bytes()) {
|
||||
let expanded = match req_ctx {
|
||||
Some(ctx) => expand_template(value, ctx),
|
||||
None => value.clone(),
|
||||
};
|
||||
if let Ok(val) = HeaderValue::from_str(&expanded) {
|
||||
headers.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply CORS headers if configured
|
||||
if let Some(ref cors) = route_headers.cors {
|
||||
if cors.enabled {
|
||||
Self::apply_cors_headers(cors, headers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_cors_headers(cors: &rustproxy_config::RouteCors, headers: &mut HeaderMap) {
|
||||
// Allow-Origin
|
||||
if let Some(ref origin) = cors.allow_origin {
|
||||
let origin_str = match origin {
|
||||
rustproxy_config::AllowOrigin::Single(s) => s.clone(),
|
||||
rustproxy_config::AllowOrigin::List(list) => list.join(", "),
|
||||
};
|
||||
if let Ok(val) = HeaderValue::from_str(&origin_str) {
|
||||
headers.insert("access-control-allow-origin", val);
|
||||
}
|
||||
} else {
|
||||
headers.insert(
|
||||
"access-control-allow-origin",
|
||||
HeaderValue::from_static("*"),
|
||||
);
|
||||
}
|
||||
|
||||
// Allow-Methods
|
||||
if let Some(ref methods) = cors.allow_methods {
|
||||
if let Ok(val) = HeaderValue::from_str(methods) {
|
||||
headers.insert("access-control-allow-methods", val);
|
||||
}
|
||||
}
|
||||
|
||||
// Allow-Headers
|
||||
if let Some(ref allow_headers) = cors.allow_headers {
|
||||
if let Ok(val) = HeaderValue::from_str(allow_headers) {
|
||||
headers.insert("access-control-allow-headers", val);
|
||||
}
|
||||
}
|
||||
|
||||
// Allow-Credentials
|
||||
if cors.allow_credentials == Some(true) {
|
||||
headers.insert(
|
||||
"access-control-allow-credentials",
|
||||
HeaderValue::from_static("true"),
|
||||
);
|
||||
}
|
||||
|
||||
// Expose-Headers
|
||||
if let Some(ref expose) = cors.expose_headers {
|
||||
if let Ok(val) = HeaderValue::from_str(expose) {
|
||||
headers.insert("access-control-expose-headers", val);
|
||||
}
|
||||
}
|
||||
|
||||
// Max-Age
|
||||
if let Some(max_age) = cors.max_age {
|
||||
if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) {
|
||||
headers.insert("access-control-max-age", val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
162
rust/crates/rustproxy-http/src/template.rs
Normal file
162
rust/crates/rustproxy-http/src/template.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
//! Header template variable expansion.
|
||||
//!
|
||||
//! Supports expanding template variables like `{clientIp}`, `{domain}`, etc.
|
||||
//! in header values before they are applied to requests or responses.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// Context for template variable expansion.
|
||||
pub struct RequestContext {
|
||||
pub client_ip: String,
|
||||
pub domain: String,
|
||||
pub port: u16,
|
||||
pub path: String,
|
||||
pub route_name: String,
|
||||
pub connection_id: u64,
|
||||
}
|
||||
|
||||
/// Expand template variables in a header value.
|
||||
/// Supported variables: {clientIp}, {domain}, {port}, {path}, {routeName}, {connectionId}, {timestamp}
|
||||
pub fn expand_template(template: &str, ctx: &RequestContext) -> String {
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
template
|
||||
.replace("{clientIp}", &ctx.client_ip)
|
||||
.replace("{domain}", &ctx.domain)
|
||||
.replace("{port}", &ctx.port.to_string())
|
||||
.replace("{path}", &ctx.path)
|
||||
.replace("{routeName}", &ctx.route_name)
|
||||
.replace("{connectionId}", &ctx.connection_id.to_string())
|
||||
.replace("{timestamp}", ×tamp.to_string())
|
||||
}
|
||||
|
||||
/// Expand templates in a map of header key-value pairs.
|
||||
pub fn expand_headers(
|
||||
headers: &HashMap<String, String>,
|
||||
ctx: &RequestContext,
|
||||
) -> HashMap<String, String> {
|
||||
headers.iter()
|
||||
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_context() -> RequestContext {
|
||||
RequestContext {
|
||||
client_ip: "192.168.1.100".to_string(),
|
||||
domain: "example.com".to_string(),
|
||||
port: 443,
|
||||
path: "/api/v1/users".to_string(),
|
||||
route_name: "api-route".to_string(),
|
||||
connection_id: 42,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_client_ip() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{clientIp}", &ctx), "192.168.1.100");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_domain() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{domain}", &ctx), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_port() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{port}", &ctx), "443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_path() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{path}", &ctx), "/api/v1/users");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_route_name() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{routeName}", &ctx), "api-route");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_connection_id() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{connectionId}", &ctx), "42");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_timestamp() {
|
||||
let ctx = test_context();
|
||||
let result = expand_template("{timestamp}", &ctx);
|
||||
// Timestamp should be a valid number
|
||||
let ts: u64 = result.parse().expect("timestamp should be a number");
|
||||
// Should be a reasonable Unix timestamp (after 2020)
|
||||
assert!(ts > 1_577_836_800);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_mixed_template() {
|
||||
let ctx = test_context();
|
||||
let result = expand_template("client={clientIp}, host={domain}:{port}", &ctx);
|
||||
assert_eq!(result, "client=192.168.1.100, host=example.com:443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_no_variables() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("plain-value", &ctx), "plain-value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_empty_string() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("", &ctx), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_multiple_same_variable() {
|
||||
let ctx = test_context();
|
||||
let result = expand_template("{clientIp}-{clientIp}", &ctx);
|
||||
assert_eq!(result, "192.168.1.100-192.168.1.100");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_headers_map() {
|
||||
let ctx = test_context();
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-Forwarded-For".to_string(), "{clientIp}".to_string());
|
||||
headers.insert("X-Route".to_string(), "{routeName}".to_string());
|
||||
headers.insert("X-Static".to_string(), "no-template".to_string());
|
||||
|
||||
let result = expand_headers(&headers, &ctx);
|
||||
assert_eq!(result.get("X-Forwarded-For").unwrap(), "192.168.1.100");
|
||||
assert_eq!(result.get("X-Route").unwrap(), "api-route");
|
||||
assert_eq!(result.get("X-Static").unwrap(), "no-template");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_all_variables_in_one() {
|
||||
let ctx = test_context();
|
||||
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
|
||||
let result = expand_template(template, &ctx);
|
||||
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_unknown_variable_left_as_is() {
|
||||
let ctx = test_context();
|
||||
let result = expand_template("{unknownVar}", &ctx);
|
||||
assert_eq!(result, "{unknownVar}");
|
||||
}
|
||||
}
|
||||
222
rust/crates/rustproxy-http/src/upstream_selector.rs
Normal file
222
rust/crates/rustproxy-http/src/upstream_selector.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
//! Route-aware upstream selection with load balancing.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
|
||||
|
||||
/// Upstream selection result.
|
||||
pub struct UpstreamSelection {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub use_tls: bool,
|
||||
}
|
||||
|
||||
/// Selects upstream backends with load balancing support.
|
||||
pub struct UpstreamSelector {
|
||||
/// Round-robin counters per route (keyed by first target host:port)
|
||||
round_robin: Mutex<HashMap<String, AtomicUsize>>,
|
||||
/// Active connection counts per host (keyed by "host:port")
|
||||
active_connections: Arc<DashMap<String, AtomicU64>>,
|
||||
}
|
||||
|
||||
impl UpstreamSelector {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
round_robin: Mutex::new(HashMap::new()),
|
||||
active_connections: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Select an upstream target based on the route target config and load balancing.
|
||||
pub fn select(
|
||||
&self,
|
||||
target: &RouteTarget,
|
||||
client_addr: &SocketAddr,
|
||||
incoming_port: u16,
|
||||
) -> UpstreamSelection {
|
||||
let hosts = target.host.to_vec();
|
||||
let port = target.port.resolve(incoming_port);
|
||||
|
||||
if hosts.len() <= 1 {
|
||||
return UpstreamSelection {
|
||||
host: hosts.first().map(|s| s.to_string()).unwrap_or_default(),
|
||||
port,
|
||||
use_tls: target.tls.is_some(),
|
||||
};
|
||||
}
|
||||
|
||||
// Determine load balancing algorithm
|
||||
let algorithm = target.load_balancing.as_ref()
|
||||
.map(|lb| &lb.algorithm)
|
||||
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
|
||||
|
||||
let idx = match algorithm {
|
||||
LoadBalancingAlgorithm::RoundRobin => {
|
||||
self.round_robin_select(&hosts, port)
|
||||
}
|
||||
LoadBalancingAlgorithm::IpHash => {
|
||||
let hash = Self::ip_hash(client_addr);
|
||||
hash % hosts.len()
|
||||
}
|
||||
LoadBalancingAlgorithm::LeastConnections => {
|
||||
self.least_connections_select(&hosts, port)
|
||||
}
|
||||
};
|
||||
|
||||
UpstreamSelection {
|
||||
host: hosts[idx].to_string(),
|
||||
port,
|
||||
use_tls: target.tls.is_some(),
|
||||
}
|
||||
}
|
||||
|
||||
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
|
||||
let key = format!("{}:{}", hosts[0], port);
|
||||
let mut counters = self.round_robin.lock().unwrap();
|
||||
let counter = counters
|
||||
.entry(key)
|
||||
.or_insert_with(|| AtomicUsize::new(0));
|
||||
let idx = counter.fetch_add(1, Ordering::Relaxed);
|
||||
idx % hosts.len()
|
||||
}
|
||||
|
||||
fn least_connections_select(&self, hosts: &[&str], port: u16) -> usize {
|
||||
let mut min_conns = u64::MAX;
|
||||
let mut min_idx = 0;
|
||||
|
||||
for (i, host) in hosts.iter().enumerate() {
|
||||
let key = format!("{}:{}", host, port);
|
||||
let conns = self.active_connections
|
||||
.get(&key)
|
||||
.map(|entry| entry.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
if conns < min_conns {
|
||||
min_conns = conns;
|
||||
min_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
min_idx
|
||||
}
|
||||
|
||||
/// Record that a connection to the given host has started.
|
||||
pub fn connection_started(&self, host: &str) {
|
||||
self.active_connections
|
||||
.entry(host.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Record that a connection to the given host has ended.
|
||||
pub fn connection_ended(&self, host: &str) {
|
||||
if let Some(counter) = self.active_connections.get(host) {
|
||||
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
|
||||
// Guard against underflow (shouldn't happen, but be safe)
|
||||
if prev == 0 {
|
||||
counter.value().store(0, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ip_hash(addr: &SocketAddr) -> usize {
|
||||
let ip_str = addr.ip().to_string();
|
||||
let mut hash: usize = 5381;
|
||||
for byte in ip_str.bytes() {
|
||||
hash = hash.wrapping_mul(33).wrapping_add(byte as usize);
|
||||
}
|
||||
hash
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for UpstreamSelector {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for UpstreamSelector {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
round_robin: Mutex::new(HashMap::new()),
|
||||
active_connections: Arc::clone(&self.active_connections),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rustproxy_config::*;
|
||||
|
||||
fn make_target(hosts: Vec<&str>, port: u16) -> RouteTarget {
|
||||
RouteTarget {
|
||||
target_match: None,
|
||||
host: if hosts.len() == 1 {
|
||||
HostSpec::Single(hosts[0].to_string())
|
||||
} else {
|
||||
HostSpec::List(hosts.iter().map(|s| s.to_string()).collect())
|
||||
},
|
||||
port: PortSpec::Fixed(port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_host() {
|
||||
let selector = UpstreamSelector::new();
|
||||
let target = make_target(vec!["backend"], 8080);
|
||||
let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
|
||||
let result = selector.select(&target, &addr, 80);
|
||||
assert_eq!(result.host, "backend");
|
||||
assert_eq!(result.port, 8080);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_robin() {
|
||||
let selector = UpstreamSelector::new();
|
||||
let mut target = make_target(vec!["a", "b", "c"], 8080);
|
||||
target.load_balancing = Some(RouteLoadBalancing {
|
||||
algorithm: LoadBalancingAlgorithm::RoundRobin,
|
||||
health_check: None,
|
||||
});
|
||||
let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
|
||||
|
||||
let r1 = selector.select(&target, &addr, 80);
|
||||
let r2 = selector.select(&target, &addr, 80);
|
||||
let r3 = selector.select(&target, &addr, 80);
|
||||
let r4 = selector.select(&target, &addr, 80);
|
||||
|
||||
// Should cycle through a, b, c, a
|
||||
assert_eq!(r1.host, "a");
|
||||
assert_eq!(r2.host, "b");
|
||||
assert_eq!(r3.host, "c");
|
||||
assert_eq!(r4.host, "a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_hash_consistent() {
|
||||
let selector = UpstreamSelector::new();
|
||||
let mut target = make_target(vec!["a", "b", "c"], 8080);
|
||||
target.load_balancing = Some(RouteLoadBalancing {
|
||||
algorithm: LoadBalancingAlgorithm::IpHash,
|
||||
health_check: None,
|
||||
});
|
||||
let addr: SocketAddr = "10.0.0.5:1234".parse().unwrap();
|
||||
|
||||
let r1 = selector.select(&target, &addr, 80);
|
||||
let r2 = selector.select(&target, &addr, 80);
|
||||
// Same IP should always get same backend
|
||||
assert_eq!(r1.host, r2.host);
|
||||
}
|
||||
}
|
||||
15
rust/crates/rustproxy-metrics/Cargo.toml
Normal file
15
rust/crates/rustproxy-metrics/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "rustproxy-metrics"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Metrics and throughput tracking for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
dashmap = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
251
rust/crates/rustproxy-metrics/src/collector.rs
Normal file
251
rust/crates/rustproxy-metrics/src/collector.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
/// Aggregated metrics snapshot.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Metrics {
|
||||
pub active_connections: u64,
|
||||
pub total_connections: u64,
|
||||
pub bytes_in: u64,
|
||||
pub bytes_out: u64,
|
||||
pub throughput_in_bytes_per_sec: u64,
|
||||
pub throughput_out_bytes_per_sec: u64,
|
||||
pub routes: std::collections::HashMap<String, RouteMetrics>,
|
||||
}
|
||||
|
||||
/// Per-route metrics.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteMetrics {
|
||||
pub active_connections: u64,
|
||||
pub total_connections: u64,
|
||||
pub bytes_in: u64,
|
||||
pub bytes_out: u64,
|
||||
pub throughput_in_bytes_per_sec: u64,
|
||||
pub throughput_out_bytes_per_sec: u64,
|
||||
}
|
||||
|
||||
/// Statistics snapshot.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Statistics {
|
||||
pub active_connections: u64,
|
||||
pub total_connections: u64,
|
||||
pub routes_count: u64,
|
||||
pub listening_ports: Vec<u16>,
|
||||
pub uptime_seconds: u64,
|
||||
}
|
||||
|
||||
/// Metrics collector tracking connections and throughput.
|
||||
pub struct MetricsCollector {
|
||||
active_connections: AtomicU64,
|
||||
total_connections: AtomicU64,
|
||||
total_bytes_in: AtomicU64,
|
||||
total_bytes_out: AtomicU64,
|
||||
/// Per-route active connection counts
|
||||
route_connections: DashMap<String, AtomicU64>,
|
||||
/// Per-route total connection counts
|
||||
route_total_connections: DashMap<String, AtomicU64>,
|
||||
/// Per-route byte counters
|
||||
route_bytes_in: DashMap<String, AtomicU64>,
|
||||
route_bytes_out: DashMap<String, AtomicU64>,
|
||||
}
|
||||
|
||||
impl MetricsCollector {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active_connections: AtomicU64::new(0),
|
||||
total_connections: AtomicU64::new(0),
|
||||
total_bytes_in: AtomicU64::new(0),
|
||||
total_bytes_out: AtomicU64::new(0),
|
||||
route_connections: DashMap::new(),
|
||||
route_total_connections: DashMap::new(),
|
||||
route_bytes_in: DashMap::new(),
|
||||
route_bytes_out: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a new connection.
|
||||
pub fn connection_opened(&self, route_id: Option<&str>) {
|
||||
self.active_connections.fetch_add(1, Ordering::Relaxed);
|
||||
self.total_connections.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if let Some(route_id) = route_id {
|
||||
self.route_connections
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.route_total_connections
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a connection closing.
|
||||
pub fn connection_closed(&self, route_id: Option<&str>) {
|
||||
self.active_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
|
||||
if let Some(route_id) = route_id {
|
||||
if let Some(counter) = self.route_connections.get(route_id) {
|
||||
let val = counter.load(Ordering::Relaxed);
|
||||
if val > 0 {
|
||||
counter.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes transferred.
|
||||
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64, route_id: Option<&str>) {
|
||||
self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
|
||||
if let Some(route_id) = route_id {
|
||||
self.route_bytes_in
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.route_bytes_out
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current active connection count.
|
||||
pub fn active_connections(&self) -> u64 {
|
||||
self.active_connections.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get total connection count.
|
||||
pub fn total_connections(&self) -> u64 {
|
||||
self.total_connections.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get total bytes received.
|
||||
pub fn total_bytes_in(&self) -> u64 {
|
||||
self.total_bytes_in.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get total bytes sent.
|
||||
pub fn total_bytes_out(&self) -> u64 {
|
||||
self.total_bytes_out.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get a full metrics snapshot including per-route data.
|
||||
pub fn snapshot(&self) -> Metrics {
|
||||
let mut routes = std::collections::HashMap::new();
|
||||
|
||||
// Collect per-route metrics
|
||||
for entry in self.route_total_connections.iter() {
|
||||
let route_id = entry.key().clone();
|
||||
let total = entry.value().load(Ordering::Relaxed);
|
||||
let active = self.route_connections
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let bytes_in = self.route_bytes_in
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let bytes_out = self.route_bytes_out
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
|
||||
routes.insert(route_id, RouteMetrics {
|
||||
active_connections: active,
|
||||
total_connections: total,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
throughput_in_bytes_per_sec: 0,
|
||||
throughput_out_bytes_per_sec: 0,
|
||||
});
|
||||
}
|
||||
|
||||
Metrics {
|
||||
active_connections: self.active_connections(),
|
||||
total_connections: self.total_connections(),
|
||||
bytes_in: self.total_bytes_in(),
|
||||
bytes_out: self.total_bytes_out(),
|
||||
throughput_in_bytes_per_sec: 0,
|
||||
throughput_out_bytes_per_sec: 0,
|
||||
routes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MetricsCollector {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_initial_state_zeros() {
|
||||
let collector = MetricsCollector::new();
|
||||
assert_eq!(collector.active_connections(), 0);
|
||||
assert_eq!(collector.total_connections(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connection_opened_increments() {
|
||||
let collector = MetricsCollector::new();
|
||||
collector.connection_opened(None);
|
||||
assert_eq!(collector.active_connections(), 1);
|
||||
assert_eq!(collector.total_connections(), 1);
|
||||
collector.connection_opened(None);
|
||||
assert_eq!(collector.active_connections(), 2);
|
||||
assert_eq!(collector.total_connections(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connection_closed_decrements() {
|
||||
let collector = MetricsCollector::new();
|
||||
collector.connection_opened(None);
|
||||
collector.connection_opened(None);
|
||||
assert_eq!(collector.active_connections(), 2);
|
||||
collector.connection_closed(None);
|
||||
assert_eq!(collector.active_connections(), 1);
|
||||
// total_connections should stay at 2
|
||||
assert_eq!(collector.total_connections(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_route_specific_tracking() {
|
||||
let collector = MetricsCollector::new();
|
||||
collector.connection_opened(Some("route-a"));
|
||||
collector.connection_opened(Some("route-a"));
|
||||
collector.connection_opened(Some("route-b"));
|
||||
|
||||
assert_eq!(collector.active_connections(), 3);
|
||||
assert_eq!(collector.total_connections(), 3);
|
||||
|
||||
collector.connection_closed(Some("route-a"));
|
||||
assert_eq!(collector.active_connections(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_bytes() {
|
||||
let collector = MetricsCollector::new();
|
||||
collector.record_bytes(100, 200, Some("route-a"));
|
||||
collector.record_bytes(50, 75, Some("route-a"));
|
||||
collector.record_bytes(25, 30, None);
|
||||
|
||||
let total_in = collector.total_bytes_in.load(Ordering::Relaxed);
|
||||
let total_out = collector.total_bytes_out.load(Ordering::Relaxed);
|
||||
assert_eq!(total_in, 175);
|
||||
assert_eq!(total_out, 305);
|
||||
|
||||
// Route-specific bytes
|
||||
let route_in = collector.route_bytes_in.get("route-a").unwrap();
|
||||
assert_eq!(route_in.load(Ordering::Relaxed), 150);
|
||||
}
|
||||
}
|
||||
11
rust/crates/rustproxy-metrics/src/lib.rs
Normal file
11
rust/crates/rustproxy-metrics/src/lib.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! # rustproxy-metrics
|
||||
//!
|
||||
//! Metrics and throughput tracking for RustProxy.
|
||||
|
||||
pub mod throughput;
|
||||
pub mod collector;
|
||||
pub mod log_dedup;
|
||||
|
||||
pub use throughput::*;
|
||||
pub use collector::*;
|
||||
pub use log_dedup::*;
|
||||
219
rust/crates/rustproxy-metrics/src/log_dedup.rs
Normal file
219
rust/crates/rustproxy-metrics/src/log_dedup.rs
Normal file
@@ -0,0 +1,219 @@
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
/// An aggregated event during the deduplication window.
|
||||
struct AggregatedEvent {
|
||||
category: String,
|
||||
first_message: String,
|
||||
count: AtomicU64,
|
||||
first_seen: Instant,
|
||||
#[allow(dead_code)]
|
||||
last_seen: Instant,
|
||||
}
|
||||
|
||||
/// Log deduplicator that batches similar events over a time window.
|
||||
///
|
||||
/// Events are grouped by a composite key of `category:key`. Within each
|
||||
/// deduplication window (`flush_interval`) identical events are counted
|
||||
/// instead of being emitted individually. When the window expires (or the
|
||||
/// batch reaches `max_batch_size`) a single summary line is written via
|
||||
/// `tracing::info!`.
|
||||
pub struct LogDeduplicator {
|
||||
events: DashMap<String, AggregatedEvent>,
|
||||
flush_interval: Duration,
|
||||
max_batch_size: u64,
|
||||
#[allow(dead_code)]
|
||||
rapid_threshold: u64, // events/sec that triggers immediate flush
|
||||
}
|
||||
|
||||
impl LogDeduplicator {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
events: DashMap::new(),
|
||||
flush_interval: Duration::from_secs(5),
|
||||
max_batch_size: 100,
|
||||
rapid_threshold: 50,
|
||||
}
|
||||
}
|
||||
|
||||
/// Log an event, deduplicating by `category` + `key`.
|
||||
///
|
||||
/// If the batch for this composite key reaches `max_batch_size` the
|
||||
/// accumulated events are flushed immediately.
|
||||
pub fn log(&self, category: &str, key: &str, message: &str) {
|
||||
let map_key = format!("{}:{}", category, key);
|
||||
let now = Instant::now();
|
||||
|
||||
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
|
||||
category: category.to_string(),
|
||||
first_message: message.to_string(),
|
||||
count: AtomicU64::new(0),
|
||||
first_seen: now,
|
||||
last_seen: now,
|
||||
});
|
||||
|
||||
let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
// Check if we should flush (batch size exceeded)
|
||||
if count >= self.max_batch_size {
|
||||
drop(entry);
|
||||
self.flush();
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush all accumulated events, emitting summary log lines.
|
||||
pub fn flush(&self) {
|
||||
// Collect and remove all events
|
||||
self.events.retain(|_key, event| {
|
||||
let count = event.count.load(Ordering::Relaxed);
|
||||
if count > 0 {
|
||||
let elapsed = event.first_seen.elapsed();
|
||||
if count == 1 {
|
||||
info!("[{}] {}", event.category, event.first_message);
|
||||
} else {
|
||||
info!(
|
||||
"[SUMMARY] {} {} events in {:.1}s: {}",
|
||||
count,
|
||||
event.category,
|
||||
elapsed.as_secs_f64(),
|
||||
event.first_message
|
||||
);
|
||||
}
|
||||
}
|
||||
false // remove all entries after flushing
|
||||
});
|
||||
}
|
||||
|
||||
/// Start a background flush task that periodically drains accumulated
|
||||
/// events. The task runs until the supplied `CancellationToken` is
|
||||
/// cancelled, at which point it performs one final flush before exiting.
|
||||
pub fn start_flush_task(self: &Arc<Self>, cancel: tokio_util::sync::CancellationToken) {
|
||||
let dedup = Arc::clone(self);
|
||||
let interval = self.flush_interval;
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
dedup.flush();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(interval) => {
|
||||
dedup.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LogDeduplicator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_single_event_emitted_as_is() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
dedup.log("conn", "open", "connection opened from 1.2.3.4");
|
||||
// One event should exist
|
||||
assert_eq!(dedup.events.len(), 1);
|
||||
let entry = dedup.events.get("conn:open").unwrap();
|
||||
assert_eq!(entry.count.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(entry.first_message, "connection opened from 1.2.3.4");
|
||||
drop(entry);
|
||||
dedup.flush();
|
||||
// After flush, map should be empty
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_events_aggregated() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
for _ in 0..10 {
|
||||
dedup.log("conn", "timeout", "connection timed out");
|
||||
}
|
||||
assert_eq!(dedup.events.len(), 1);
|
||||
let entry = dedup.events.get("conn:timeout").unwrap();
|
||||
assert_eq!(entry.count.load(Ordering::Relaxed), 10);
|
||||
drop(entry);
|
||||
dedup.flush();
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_keys_separate() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
dedup.log("conn", "open", "opened");
|
||||
dedup.log("conn", "close", "closed");
|
||||
dedup.log("tls", "handshake", "TLS handshake");
|
||||
assert_eq!(dedup.events.len(), 3);
|
||||
dedup.flush();
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flush_clears_events() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
dedup.log("a", "b", "msg1");
|
||||
dedup.log("a", "b", "msg2");
|
||||
dedup.flush();
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
// Logging after flush creates a new entry
|
||||
dedup.log("a", "b", "msg3");
|
||||
assert_eq!(dedup.events.len(), 1);
|
||||
let entry = dedup.events.get("a:b").unwrap();
|
||||
assert_eq!(entry.count.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(entry.first_message, "msg3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_batch_triggers_flush() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
// max_batch_size defaults to 100
|
||||
for i in 0..100 {
|
||||
dedup.log("flood", "key", &format!("event {}", i));
|
||||
}
|
||||
// After hitting max_batch_size the events map should have been flushed
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_trait() {
|
||||
let dedup = LogDeduplicator::default();
|
||||
assert_eq!(dedup.flush_interval, Duration::from_secs(5));
|
||||
assert_eq!(dedup.max_batch_size, 100);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_background_flush_task() {
|
||||
let dedup = Arc::new(LogDeduplicator {
|
||||
events: DashMap::new(),
|
||||
flush_interval: Duration::from_millis(50),
|
||||
max_batch_size: 100,
|
||||
rapid_threshold: 50,
|
||||
});
|
||||
|
||||
let cancel = tokio_util::sync::CancellationToken::new();
|
||||
dedup.start_flush_task(cancel.clone());
|
||||
|
||||
// Log some events
|
||||
dedup.log("bg", "test", "background flush test");
|
||||
assert_eq!(dedup.events.len(), 1);
|
||||
|
||||
// Wait for the background task to flush
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
|
||||
// Cancel the task
|
||||
cancel.cancel();
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
}
|
||||
173
rust/crates/rustproxy-metrics/src/throughput.rs
Normal file
173
rust/crates/rustproxy-metrics/src/throughput.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// A single throughput sample.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ThroughputSample {
|
||||
pub timestamp_ms: u64,
|
||||
pub bytes_in: u64,
|
||||
pub bytes_out: u64,
|
||||
}
|
||||
|
||||
/// Circular buffer for 1Hz throughput sampling.
|
||||
/// Matches smartproxy's ThroughputTracker.
|
||||
pub struct ThroughputTracker {
|
||||
/// Circular buffer of samples
|
||||
samples: Vec<ThroughputSample>,
|
||||
/// Current write index
|
||||
write_index: usize,
|
||||
/// Number of valid samples
|
||||
count: usize,
|
||||
/// Maximum number of samples to retain
|
||||
capacity: usize,
|
||||
/// Accumulated bytes since last sample
|
||||
pending_bytes_in: AtomicU64,
|
||||
pending_bytes_out: AtomicU64,
|
||||
/// When the tracker was created
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
impl ThroughputTracker {
|
||||
/// Create a new tracker with the given capacity (seconds of retention).
|
||||
pub fn new(retention_seconds: usize) -> Self {
|
||||
Self {
|
||||
samples: Vec::with_capacity(retention_seconds),
|
||||
write_index: 0,
|
||||
count: 0,
|
||||
capacity: retention_seconds,
|
||||
pending_bytes_in: AtomicU64::new(0),
|
||||
pending_bytes_out: AtomicU64::new(0),
|
||||
created_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes (called from data flow callbacks).
|
||||
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
|
||||
self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Take a sample (called at 1Hz).
|
||||
pub fn sample(&mut self) {
|
||||
let bytes_in = self.pending_bytes_in.swap(0, Ordering::Relaxed);
|
||||
let bytes_out = self.pending_bytes_out.swap(0, Ordering::Relaxed);
|
||||
let timestamp_ms = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64;
|
||||
|
||||
let sample = ThroughputSample {
|
||||
timestamp_ms,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
};
|
||||
|
||||
if self.samples.len() < self.capacity {
|
||||
self.samples.push(sample);
|
||||
} else {
|
||||
self.samples[self.write_index] = sample;
|
||||
}
|
||||
self.write_index = (self.write_index + 1) % self.capacity;
|
||||
self.count = (self.count + 1).min(self.capacity);
|
||||
}
|
||||
|
||||
/// Get throughput over the last N seconds.
|
||||
pub fn throughput(&self, window_seconds: usize) -> (u64, u64) {
|
||||
let window = window_seconds.min(self.count);
|
||||
if window == 0 {
|
||||
return (0, 0);
|
||||
}
|
||||
|
||||
let mut total_in = 0u64;
|
||||
let mut total_out = 0u64;
|
||||
|
||||
for i in 0..window {
|
||||
let idx = if self.write_index >= i + 1 {
|
||||
self.write_index - i - 1
|
||||
} else {
|
||||
self.capacity - (i + 1 - self.write_index)
|
||||
};
|
||||
if idx < self.samples.len() {
|
||||
total_in += self.samples[idx].bytes_in;
|
||||
total_out += self.samples[idx].bytes_out;
|
||||
}
|
||||
}
|
||||
|
||||
(total_in / window as u64, total_out / window as u64)
|
||||
}
|
||||
|
||||
/// Get instant throughput (last 1 second).
|
||||
pub fn instant(&self) -> (u64, u64) {
|
||||
self.throughput(1)
|
||||
}
|
||||
|
||||
/// Get recent throughput (last 10 seconds).
|
||||
pub fn recent(&self) -> (u64, u64) {
|
||||
self.throughput(10)
|
||||
}
|
||||
|
||||
/// How long this tracker has been alive.
|
||||
pub fn uptime(&self) -> std::time::Duration {
|
||||
self.created_at.elapsed()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_throughput() {
|
||||
let tracker = ThroughputTracker::new(60);
|
||||
let (bytes_in, bytes_out) = tracker.throughput(10);
|
||||
assert_eq!(bytes_in, 0);
|
||||
assert_eq!(bytes_out, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_sample() {
|
||||
let mut tracker = ThroughputTracker::new(60);
|
||||
tracker.record_bytes(1000, 2000);
|
||||
tracker.sample();
|
||||
let (bytes_in, bytes_out) = tracker.instant();
|
||||
assert_eq!(bytes_in, 1000);
|
||||
assert_eq!(bytes_out, 2000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circular_buffer_wrap() {
|
||||
let mut tracker = ThroughputTracker::new(3); // Small capacity
|
||||
for i in 0..5 {
|
||||
tracker.record_bytes(i * 100, i * 200);
|
||||
tracker.sample();
|
||||
}
|
||||
// Should still work after wrapping
|
||||
let (bytes_in, bytes_out) = tracker.throughput(3);
|
||||
assert!(bytes_in > 0);
|
||||
assert!(bytes_out > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_window_averaging() {
|
||||
let mut tracker = ThroughputTracker::new(60);
|
||||
// Record 3 samples of different sizes
|
||||
tracker.record_bytes(100, 200);
|
||||
tracker.sample();
|
||||
tracker.record_bytes(200, 400);
|
||||
tracker.sample();
|
||||
tracker.record_bytes(300, 600);
|
||||
tracker.sample();
|
||||
|
||||
// Average over 3 samples: (100+200+300)/3 = 200, (200+400+600)/3 = 400
|
||||
let (avg_in, avg_out) = tracker.throughput(3);
|
||||
assert_eq!(avg_in, 200);
|
||||
assert_eq!(avg_out, 400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uptime_positive() {
|
||||
let tracker = ThroughputTracker::new(60);
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
assert!(tracker.uptime().as_millis() >= 10);
|
||||
}
|
||||
}
|
||||
17
rust/crates/rustproxy-nftables/Cargo.toml
Normal file
17
rust/crates/rustproxy-nftables/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "rustproxy-nftables"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "NFTables kernel-level forwarding for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
10
rust/crates/rustproxy-nftables/src/lib.rs
Normal file
10
rust/crates/rustproxy-nftables/src/lib.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
//! # rustproxy-nftables
|
||||
//!
|
||||
//! NFTables kernel-level forwarding for RustProxy.
|
||||
//! Generates and manages nft CLI rules for DNAT/SNAT.
|
||||
|
||||
pub mod nft_manager;
|
||||
pub mod rule_builder;
|
||||
|
||||
pub use nft_manager::*;
|
||||
pub use rule_builder::*;
|
||||
238
rust/crates/rustproxy-nftables/src/nft_manager.rs
Normal file
238
rust/crates/rustproxy-nftables/src/nft_manager.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use thiserror::Error;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NftError {
|
||||
#[error("nft command failed: {0}")]
|
||||
CommandFailed(String),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Not running as root")]
|
||||
NotRoot,
|
||||
}
|
||||
|
||||
/// Manager for nftables rules.
|
||||
///
|
||||
/// Executes `nft` CLI commands to manage kernel-level packet forwarding.
|
||||
/// Requires root privileges; operations are skipped gracefully if not root.
|
||||
pub struct NftManager {
|
||||
table_name: String,
|
||||
/// Active rules indexed by route ID
|
||||
active_rules: HashMap<String, Vec<String>>,
|
||||
/// Whether the table has been initialized
|
||||
table_initialized: bool,
|
||||
}
|
||||
|
||||
impl NftManager {
|
||||
pub fn new(table_name: Option<String>) -> Self {
|
||||
Self {
|
||||
table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()),
|
||||
active_rules: HashMap::new(),
|
||||
table_initialized: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if we are running as root.
|
||||
fn is_root() -> bool {
|
||||
unsafe { libc::geteuid() == 0 }
|
||||
}
|
||||
|
||||
/// Execute a single nft command via the CLI.
|
||||
async fn exec_nft(command: &str) -> Result<String, NftError> {
|
||||
// The command starts with "nft ", strip it to get the args
|
||||
let args = if command.starts_with("nft ") {
|
||||
&command[4..]
|
||||
} else {
|
||||
command
|
||||
};
|
||||
|
||||
let output = tokio::process::Command::new("nft")
|
||||
.args(args.split_whitespace())
|
||||
.output()
|
||||
.await
|
||||
.map_err(NftError::Io)?;
|
||||
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
Err(NftError::CommandFailed(format!(
|
||||
"Command '{}' failed: {}",
|
||||
command, stderr
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure the nftables table and chains are set up.
|
||||
async fn ensure_table(&mut self) -> Result<(), NftError> {
|
||||
if self.table_initialized {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let setup_commands = crate::rule_builder::build_table_setup(&self.table_name);
|
||||
for cmd in &setup_commands {
|
||||
Self::exec_nft(cmd).await?;
|
||||
}
|
||||
|
||||
self.table_initialized = true;
|
||||
info!("NFTables table '{}' initialized", self.table_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply rules for a route.
|
||||
///
|
||||
/// Executes the nft commands via the CLI. If not running as root,
|
||||
/// the rules are stored locally but not applied to the kernel.
|
||||
pub async fn apply_rules(&mut self, route_id: &str, rules: Vec<String>) -> Result<(), NftError> {
|
||||
if !Self::is_root() {
|
||||
warn!("Not running as root, nftables rules will not be applied to kernel");
|
||||
self.active_rules.insert(route_id.to_string(), rules);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.ensure_table().await?;
|
||||
|
||||
for cmd in &rules {
|
||||
Self::exec_nft(cmd).await?;
|
||||
debug!("Applied nft rule: {}", cmd);
|
||||
}
|
||||
|
||||
info!("Applied {} nftables rules for route '{}'", rules.len(), route_id);
|
||||
self.active_rules.insert(route_id.to_string(), rules);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove rules for a route.
|
||||
///
|
||||
/// Currently removes the route from tracking. To fully remove specific
|
||||
/// rules would require handle-based tracking; for now, cleanup() removes
|
||||
/// the entire table.
|
||||
pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> {
|
||||
if let Some(rules) = self.active_rules.remove(route_id) {
|
||||
info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clean up all managed rules by deleting the entire nftables table.
|
||||
pub async fn cleanup(&mut self) -> Result<(), NftError> {
|
||||
if !Self::is_root() {
|
||||
warn!("Not running as root, skipping nftables cleanup");
|
||||
self.active_rules.clear();
|
||||
self.table_initialized = false;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if self.table_initialized {
|
||||
let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name);
|
||||
for cmd in &cleanup_commands {
|
||||
match Self::exec_nft(cmd).await {
|
||||
Ok(_) => debug!("Cleanup: {}", cmd),
|
||||
Err(e) => warn!("Cleanup command failed (may be ok): {}", e),
|
||||
}
|
||||
}
|
||||
info!("NFTables table '{}' cleaned up", self.table_name);
|
||||
}
|
||||
|
||||
self.active_rules.clear();
|
||||
self.table_initialized = false;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the table name.
|
||||
pub fn table_name(&self) -> &str {
|
||||
&self.table_name
|
||||
}
|
||||
|
||||
/// Whether the table has been initialized in the kernel.
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.table_initialized
|
||||
}
|
||||
|
||||
/// Get the number of active route rule sets.
|
||||
pub fn active_route_count(&self) -> usize {
|
||||
self.active_rules.len()
|
||||
}
|
||||
|
||||
/// Get the status of all active rules.
|
||||
pub fn status(&self) -> HashMap<String, serde_json::Value> {
|
||||
let mut status = HashMap::new();
|
||||
for (route_id, rules) in &self.active_rules {
|
||||
status.insert(
|
||||
route_id.clone(),
|
||||
serde_json::json!({
|
||||
"ruleCount": rules.len(),
|
||||
"rules": rules,
|
||||
}),
|
||||
);
|
||||
}
|
||||
status
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_new_default_table_name() {
|
||||
let mgr = NftManager::new(None);
|
||||
assert_eq!(mgr.table_name(), "rustproxy");
|
||||
assert!(!mgr.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_custom_table_name() {
|
||||
let mgr = NftManager::new(Some("custom".to_string()));
|
||||
assert_eq!(mgr.table_name(), "custom");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_rules_non_root() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
// When not root, rules are stored but not applied to kernel
|
||||
let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 1);
|
||||
|
||||
let status = mgr.status();
|
||||
assert!(status.contains_key("route-1"));
|
||||
assert_eq!(status["route-1"]["ruleCount"], 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_rules() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
let rules = vec!["nft add rule test".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 1);
|
||||
|
||||
mgr.remove_rules("route-1").await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_non_root() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
let rules = vec!["nft add rule test".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap();
|
||||
|
||||
mgr.cleanup().await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 0);
|
||||
assert!(!mgr.is_initialized());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_status_multiple_routes() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap();
|
||||
mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap();
|
||||
|
||||
let status = mgr.status();
|
||||
assert_eq!(status.len(), 2);
|
||||
assert_eq!(status["web"]["ruleCount"], 2);
|
||||
assert_eq!(status["api"]["ruleCount"], 1);
|
||||
}
|
||||
}
|
||||
123
rust/crates/rustproxy-nftables/src/rule_builder.rs
Normal file
123
rust/crates/rustproxy-nftables/src/rule_builder.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use rustproxy_config::{NfTablesOptions, NfTablesProtocol};
|
||||
|
||||
/// Build nftables DNAT rule for port forwarding.
|
||||
pub fn build_dnat_rule(
|
||||
table_name: &str,
|
||||
chain_name: &str,
|
||||
source_port: u16,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
options: &NfTablesOptions,
|
||||
) -> Vec<String> {
|
||||
let protocol = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
|
||||
NfTablesProtocol::Tcp => "tcp",
|
||||
NfTablesProtocol::Udp => "udp",
|
||||
NfTablesProtocol::All => "tcp", // TODO: handle "all"
|
||||
};
|
||||
|
||||
let mut rules = Vec::new();
|
||||
|
||||
// DNAT rule
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} {} {} dport {} dnat to {}:{}",
|
||||
table_name, chain_name, protocol, source_port, target_host, target_port,
|
||||
));
|
||||
|
||||
// SNAT rule if preserving source IP is not enabled
|
||||
if !options.preserve_source_ip.unwrap_or(false) {
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} postrouting {} dport {} masquerade",
|
||||
table_name, protocol, target_port,
|
||||
));
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if let Some(max_rate) = &options.max_rate {
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} {} {} dport {} limit rate {} accept",
|
||||
table_name, chain_name, protocol, source_port, max_rate,
|
||||
));
|
||||
}
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
/// Build the initial table and chain setup commands.
|
||||
pub fn build_table_setup(table_name: &str) -> Vec<String> {
|
||||
vec![
|
||||
format!("nft add table ip {}", table_name),
|
||||
format!("nft add chain ip {} prerouting {{ type nat hook prerouting priority 0 \\; }}", table_name),
|
||||
format!("nft add chain ip {} postrouting {{ type nat hook postrouting priority 100 \\; }}", table_name),
|
||||
]
|
||||
}
|
||||
|
||||
/// Build cleanup commands to remove the table.
|
||||
pub fn build_table_cleanup(table_name: &str) -> Vec<String> {
|
||||
vec![format!("nft delete table ip {}", table_name)]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_options() -> NfTablesOptions {
|
||||
NfTablesOptions {
|
||||
preserve_source_ip: None,
|
||||
protocol: None,
|
||||
max_rate: None,
|
||||
priority: None,
|
||||
table_name: None,
|
||||
use_ip_sets: None,
|
||||
use_advanced_nat: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_dnat_rule() {
|
||||
let options = make_options();
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
assert!(rules.len() >= 1);
|
||||
assert!(rules[0].contains("dnat to 10.0.0.1:8443"));
|
||||
assert!(rules[0].contains("dport 443"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_preserve_source_ip() {
|
||||
let mut options = make_options();
|
||||
options.preserve_source_ip = Some(true);
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
// When preserving source IP, no masquerade rule
|
||||
assert!(rules.iter().all(|r| !r.contains("masquerade")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_without_preserve_source_ip() {
|
||||
let options = make_options();
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
assert!(rules.iter().any(|r| r.contains("masquerade")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limited_rule() {
|
||||
let mut options = make_options();
|
||||
options.max_rate = Some("100/second".to_string());
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 80, "10.0.0.1", 8080, &options);
|
||||
assert!(rules.iter().any(|r| r.contains("limit rate 100/second")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_table_setup_commands() {
|
||||
let commands = build_table_setup("rustproxy");
|
||||
assert_eq!(commands.len(), 3);
|
||||
assert!(commands[0].contains("add table ip rustproxy"));
|
||||
assert!(commands[1].contains("prerouting"));
|
||||
assert!(commands[2].contains("postrouting"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_table_cleanup() {
|
||||
let commands = build_table_cleanup("rustproxy");
|
||||
assert_eq!(commands.len(), 1);
|
||||
assert!(commands[0].contains("delete table ip rustproxy"));
|
||||
}
|
||||
}
|
||||
25
rust/crates/rustproxy-passthrough/Cargo.toml
Normal file
25
rust/crates/rustproxy-passthrough/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "rustproxy-passthrough"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Raw TCP/SNI passthrough engine for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
arc-swap = { workspace = true }
|
||||
rustproxy-http = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
tokio-rustls = { workspace = true }
|
||||
rustls-pemfile = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
155
rust/crates/rustproxy-passthrough/src/connection_record.rs
Normal file
155
rust/crates/rustproxy-passthrough/src/connection_record.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Per-connection tracking record with atomics for lock-free updates.
|
||||
///
|
||||
/// Each field uses atomics so that the forwarding tasks can update
|
||||
/// bytes_received / bytes_sent / last_activity without holding any lock,
|
||||
/// while the zombie scanner reads them concurrently.
|
||||
pub struct ConnectionRecord {
|
||||
/// Unique connection ID assigned by the ConnectionTracker.
|
||||
pub id: u64,
|
||||
/// Wall-clock instant when this connection was created.
|
||||
pub created_at: Instant,
|
||||
/// Milliseconds since `created_at` when the last activity occurred.
|
||||
/// Updated atomically by the forwarding loops.
|
||||
pub last_activity: AtomicU64,
|
||||
/// Total bytes received from the client (inbound).
|
||||
pub bytes_received: AtomicU64,
|
||||
/// Total bytes sent to the client (outbound / from backend).
|
||||
pub bytes_sent: AtomicU64,
|
||||
/// True once the client side of the connection has closed.
|
||||
pub client_closed: AtomicBool,
|
||||
/// True once the backend side of the connection has closed.
|
||||
pub backend_closed: AtomicBool,
|
||||
/// Whether this connection uses TLS (affects zombie thresholds).
|
||||
pub is_tls: AtomicBool,
|
||||
/// Whether this connection has keep-alive semantics.
|
||||
pub has_keep_alive: AtomicBool,
|
||||
}
|
||||
|
||||
impl ConnectionRecord {
|
||||
/// Create a new connection record with the given ID.
|
||||
/// All counters start at zero, all flags start as false.
|
||||
pub fn new(id: u64) -> Self {
|
||||
Self {
|
||||
id,
|
||||
created_at: Instant::now(),
|
||||
last_activity: AtomicU64::new(0),
|
||||
bytes_received: AtomicU64::new(0),
|
||||
bytes_sent: AtomicU64::new(0),
|
||||
client_closed: AtomicBool::new(false),
|
||||
backend_closed: AtomicBool::new(false),
|
||||
is_tls: AtomicBool::new(false),
|
||||
has_keep_alive: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update `last_activity` to reflect the current elapsed time.
|
||||
pub fn touch(&self) {
|
||||
let elapsed_ms = self.created_at.elapsed().as_millis() as u64;
|
||||
self.last_activity.store(elapsed_ms, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Record `n` bytes received from the client (inbound).
|
||||
pub fn record_bytes_in(&self, n: u64) {
|
||||
self.bytes_received.fetch_add(n, Ordering::Relaxed);
|
||||
self.touch();
|
||||
}
|
||||
|
||||
/// Record `n` bytes sent to the client (outbound / from backend).
|
||||
pub fn record_bytes_out(&self, n: u64) {
|
||||
self.bytes_sent.fetch_add(n, Ordering::Relaxed);
|
||||
self.touch();
|
||||
}
|
||||
|
||||
/// How long since the last activity on this connection.
|
||||
pub fn idle_duration(&self) -> Duration {
|
||||
let last_ms = self.last_activity.load(Ordering::Relaxed);
|
||||
let age_ms = self.created_at.elapsed().as_millis() as u64;
|
||||
Duration::from_millis(age_ms.saturating_sub(last_ms))
|
||||
}
|
||||
|
||||
/// Total age of this connection (time since creation).
|
||||
pub fn age(&self) -> Duration {
|
||||
self.created_at.elapsed()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::thread;
|
||||
|
||||
#[test]
|
||||
fn test_new_record() {
|
||||
let record = ConnectionRecord::new(42);
|
||||
assert_eq!(record.id, 42);
|
||||
assert_eq!(record.bytes_received.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 0);
|
||||
assert!(!record.client_closed.load(Ordering::Relaxed));
|
||||
assert!(!record.backend_closed.load(Ordering::Relaxed));
|
||||
assert!(!record.is_tls.load(Ordering::Relaxed));
|
||||
assert!(!record.has_keep_alive.load(Ordering::Relaxed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_bytes() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
record.record_bytes_in(100);
|
||||
record.record_bytes_in(200);
|
||||
assert_eq!(record.bytes_received.load(Ordering::Relaxed), 300);
|
||||
|
||||
record.record_bytes_out(50);
|
||||
record.record_bytes_out(75);
|
||||
assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 125);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_touch_updates_activity() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
assert_eq!(record.last_activity.load(Ordering::Relaxed), 0);
|
||||
|
||||
// Sleep briefly so elapsed time is nonzero
|
||||
thread::sleep(Duration::from_millis(10));
|
||||
record.touch();
|
||||
|
||||
let activity = record.last_activity.load(Ordering::Relaxed);
|
||||
assert!(activity >= 10, "last_activity should be at least 10ms, got {}", activity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_idle_duration() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
// Initially idle_duration ~ age since last_activity is 0
|
||||
thread::sleep(Duration::from_millis(20));
|
||||
let idle = record.idle_duration();
|
||||
assert!(idle >= Duration::from_millis(20));
|
||||
|
||||
// After touch, idle should be near zero
|
||||
record.touch();
|
||||
let idle = record.idle_duration();
|
||||
assert!(idle < Duration::from_millis(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_age() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
thread::sleep(Duration::from_millis(20));
|
||||
let age = record.age();
|
||||
assert!(age >= Duration::from_millis(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flags() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
record.client_closed.store(true, Ordering::Relaxed);
|
||||
record.is_tls.store(true, Ordering::Relaxed);
|
||||
record.has_keep_alive.store(true, Ordering::Relaxed);
|
||||
|
||||
assert!(record.client_closed.load(Ordering::Relaxed));
|
||||
assert!(!record.backend_closed.load(Ordering::Relaxed));
|
||||
assert!(record.is_tls.load(Ordering::Relaxed));
|
||||
assert!(record.has_keep_alive.load(Ordering::Relaxed));
|
||||
}
|
||||
}
|
||||
402
rust/crates/rustproxy-passthrough/src/connection_tracker.rs
Normal file
402
rust/crates/rustproxy-passthrough/src/connection_tracker.rs
Normal file
@@ -0,0 +1,402 @@
|
||||
use dashmap::DashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::connection_record::ConnectionRecord;
|
||||
|
||||
/// Thresholds for zombie detection (non-TLS connections).
|
||||
const HALF_ZOMBIE_TIMEOUT_PLAIN: Duration = Duration::from_secs(30);
|
||||
/// Thresholds for zombie detection (TLS connections).
|
||||
const HALF_ZOMBIE_TIMEOUT_TLS: Duration = Duration::from_secs(300);
|
||||
/// Stuck connection timeout (non-TLS): received data but never sent any.
|
||||
const STUCK_TIMEOUT_PLAIN: Duration = Duration::from_secs(60);
|
||||
/// Stuck connection timeout (TLS): received data but never sent any.
|
||||
const STUCK_TIMEOUT_TLS: Duration = Duration::from_secs(300);
|
||||
|
||||
/// Tracks active connections per IP and enforces per-IP limits and rate limiting.
|
||||
/// Also maintains per-connection records for zombie detection.
|
||||
pub struct ConnectionTracker {
|
||||
/// Active connection counts per IP
|
||||
active: DashMap<IpAddr, AtomicU64>,
|
||||
/// Connection timestamps per IP for rate limiting
|
||||
timestamps: DashMap<IpAddr, VecDeque<Instant>>,
|
||||
/// Maximum concurrent connections per IP (None = unlimited)
|
||||
max_per_ip: Option<u64>,
|
||||
/// Maximum new connections per minute per IP (None = unlimited)
|
||||
rate_limit_per_minute: Option<u64>,
|
||||
/// Per-connection tracking records for zombie detection
|
||||
connections: DashMap<u64, Arc<ConnectionRecord>>,
|
||||
/// Monotonically increasing connection ID counter
|
||||
next_id: AtomicU64,
|
||||
}
|
||||
|
||||
impl ConnectionTracker {
|
||||
pub fn new(max_per_ip: Option<u64>, rate_limit_per_minute: Option<u64>) -> Self {
|
||||
Self {
|
||||
active: DashMap::new(),
|
||||
timestamps: DashMap::new(),
|
||||
max_per_ip,
|
||||
rate_limit_per_minute,
|
||||
connections: DashMap::new(),
|
||||
next_id: AtomicU64::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to accept a new connection from the given IP.
|
||||
/// Returns true if allowed, false if over limit.
|
||||
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
||||
// Check per-IP connection limit
|
||||
if let Some(max) = self.max_per_ip {
|
||||
let count = self.active
|
||||
.get(ip)
|
||||
.map(|c| c.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
if count >= max {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
if let Some(rate_limit) = self.rate_limit_per_minute {
|
||||
let now = Instant::now();
|
||||
let one_minute = std::time::Duration::from_secs(60);
|
||||
let mut entry = self.timestamps.entry(*ip).or_default();
|
||||
let timestamps = entry.value_mut();
|
||||
|
||||
// Remove timestamps older than 1 minute
|
||||
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
|
||||
timestamps.pop_front();
|
||||
}
|
||||
|
||||
if timestamps.len() as u64 >= rate_limit {
|
||||
return false;
|
||||
}
|
||||
timestamps.push_back(now);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Record that a connection was opened from the given IP.
|
||||
pub fn connection_opened(&self, ip: &IpAddr) {
|
||||
self.active
|
||||
.entry(*ip)
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.value()
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Record that a connection was closed from the given IP.
|
||||
pub fn connection_closed(&self, ip: &IpAddr) {
|
||||
if let Some(counter) = self.active.get(ip) {
|
||||
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
|
||||
// Clean up zero entries
|
||||
if prev <= 1 {
|
||||
drop(counter);
|
||||
self.active.remove(ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current number of active connections for an IP.
|
||||
pub fn active_connections(&self, ip: &IpAddr) -> u64 {
|
||||
self.active
|
||||
.get(ip)
|
||||
.map(|c| c.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Get the total number of tracked IPs.
|
||||
pub fn tracked_ips(&self) -> usize {
|
||||
self.active.len()
|
||||
}
|
||||
|
||||
/// Register a new connection and return its tracking record.
|
||||
///
|
||||
/// The returned `Arc<ConnectionRecord>` should be passed to the forwarding
|
||||
/// loop so it can update bytes / activity atomics in real time.
|
||||
pub fn register_connection(&self, is_tls: bool) -> Arc<ConnectionRecord> {
|
||||
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
let record = Arc::new(ConnectionRecord::new(id));
|
||||
record.is_tls.store(is_tls, Ordering::Relaxed);
|
||||
self.connections.insert(id, Arc::clone(&record));
|
||||
record
|
||||
}
|
||||
|
||||
/// Remove a connection record when the connection is fully closed.
|
||||
pub fn unregister_connection(&self, id: u64) {
|
||||
self.connections.remove(&id);
|
||||
}
|
||||
|
||||
/// Scan all tracked connections and return IDs of zombie connections.
|
||||
///
|
||||
/// A connection is considered a zombie in any of these cases:
|
||||
/// - **Full zombie**: both `client_closed` and `backend_closed` are true.
|
||||
/// - **Half zombie**: one side closed for longer than the threshold
|
||||
/// (5 min for TLS, 30s for non-TLS).
|
||||
/// - **Stuck**: `bytes_received > 0` but `bytes_sent == 0` for longer
|
||||
/// than the stuck threshold (5 min for TLS, 60s for non-TLS).
|
||||
pub fn scan_zombies(&self) -> Vec<u64> {
|
||||
let mut zombies = Vec::new();
|
||||
|
||||
for entry in self.connections.iter() {
|
||||
let record = entry.value();
|
||||
let id = *entry.key();
|
||||
let is_tls = record.is_tls.load(Ordering::Relaxed);
|
||||
let client_closed = record.client_closed.load(Ordering::Relaxed);
|
||||
let backend_closed = record.backend_closed.load(Ordering::Relaxed);
|
||||
let idle = record.idle_duration();
|
||||
let bytes_in = record.bytes_received.load(Ordering::Relaxed);
|
||||
let bytes_out = record.bytes_sent.load(Ordering::Relaxed);
|
||||
|
||||
// Full zombie: both sides closed
|
||||
if client_closed && backend_closed {
|
||||
zombies.push(id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Half zombie: one side closed for too long
|
||||
let half_timeout = if is_tls {
|
||||
HALF_ZOMBIE_TIMEOUT_TLS
|
||||
} else {
|
||||
HALF_ZOMBIE_TIMEOUT_PLAIN
|
||||
};
|
||||
|
||||
if (client_closed || backend_closed) && idle >= half_timeout {
|
||||
zombies.push(id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Stuck: received data but never sent anything for too long
|
||||
let stuck_timeout = if is_tls {
|
||||
STUCK_TIMEOUT_TLS
|
||||
} else {
|
||||
STUCK_TIMEOUT_PLAIN
|
||||
};
|
||||
|
||||
if bytes_in > 0 && bytes_out == 0 && idle >= stuck_timeout {
|
||||
zombies.push(id);
|
||||
}
|
||||
}
|
||||
|
||||
zombies
|
||||
}
|
||||
|
||||
/// Start a background task that periodically scans for zombie connections.
|
||||
///
|
||||
/// The scanner runs every 10 seconds and logs any zombies it finds.
|
||||
/// It stops when the provided `CancellationToken` is cancelled.
|
||||
pub fn start_zombie_scanner(self: &Arc<Self>, cancel: CancellationToken) {
|
||||
let tracker = Arc::clone(self);
|
||||
tokio::spawn(async move {
|
||||
let interval = Duration::from_secs(10);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Zombie scanner shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(interval) => {
|
||||
let zombies = tracker.scan_zombies();
|
||||
if !zombies.is_empty() {
|
||||
warn!(
|
||||
"Detected {} zombie connection(s): {:?}",
|
||||
zombies.len(),
|
||||
zombies
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Get the total number of tracked connections (with records).
|
||||
pub fn total_connections(&self) -> usize {
|
||||
self.connections.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_basic_tracking() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||||
|
||||
assert!(tracker.try_accept(&ip));
|
||||
tracker.connection_opened(&ip);
|
||||
assert_eq!(tracker.active_connections(&ip), 1);
|
||||
|
||||
tracker.connection_opened(&ip);
|
||||
assert_eq!(tracker.active_connections(&ip), 2);
|
||||
|
||||
tracker.connection_closed(&ip);
|
||||
assert_eq!(tracker.active_connections(&ip), 1);
|
||||
|
||||
tracker.connection_closed(&ip);
|
||||
assert_eq!(tracker.active_connections(&ip), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_per_ip_limit() {
|
||||
let tracker = ConnectionTracker::new(Some(2), None);
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
|
||||
assert!(tracker.try_accept(&ip));
|
||||
tracker.connection_opened(&ip);
|
||||
|
||||
assert!(tracker.try_accept(&ip));
|
||||
tracker.connection_opened(&ip);
|
||||
|
||||
// Third connection should be rejected
|
||||
assert!(!tracker.try_accept(&ip));
|
||||
|
||||
// Different IP should still be allowed
|
||||
let ip2: IpAddr = "10.0.0.2".parse().unwrap();
|
||||
assert!(tracker.try_accept(&ip2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limit() {
|
||||
let tracker = ConnectionTracker::new(None, Some(3));
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
|
||||
assert!(tracker.try_accept(&ip));
|
||||
assert!(tracker.try_accept(&ip));
|
||||
assert!(tracker.try_accept(&ip));
|
||||
// 4th attempt within the minute should be rejected
|
||||
assert!(!tracker.try_accept(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_limits() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
|
||||
for _ in 0..1000 {
|
||||
assert!(tracker.try_accept(&ip));
|
||||
tracker.connection_opened(&ip);
|
||||
}
|
||||
assert_eq!(tracker.active_connections(&ip), 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tracked_ips() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
assert_eq!(tracker.tracked_ips(), 0);
|
||||
|
||||
let ip1: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let ip2: IpAddr = "10.0.0.2".parse().unwrap();
|
||||
|
||||
tracker.connection_opened(&ip1);
|
||||
tracker.connection_opened(&ip2);
|
||||
assert_eq!(tracker.tracked_ips(), 2);
|
||||
|
||||
tracker.connection_closed(&ip1);
|
||||
assert_eq!(tracker.tracked_ips(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_unregister_connection() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
assert_eq!(tracker.total_connections(), 0);
|
||||
|
||||
let record1 = tracker.register_connection(false);
|
||||
assert_eq!(tracker.total_connections(), 1);
|
||||
assert!(!record1.is_tls.load(Ordering::Relaxed));
|
||||
|
||||
let record2 = tracker.register_connection(true);
|
||||
assert_eq!(tracker.total_connections(), 2);
|
||||
assert!(record2.is_tls.load(Ordering::Relaxed));
|
||||
|
||||
// IDs should be unique
|
||||
assert_ne!(record1.id, record2.id);
|
||||
|
||||
tracker.unregister_connection(record1.id);
|
||||
assert_eq!(tracker.total_connections(), 1);
|
||||
|
||||
tracker.unregister_connection(record2.id);
|
||||
assert_eq!(tracker.total_connections(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_zombie_detection() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let record = tracker.register_connection(false);
|
||||
|
||||
// Not a zombie initially
|
||||
assert!(tracker.scan_zombies().is_empty());
|
||||
|
||||
// Set both sides closed -> full zombie
|
||||
record.client_closed.store(true, Ordering::Relaxed);
|
||||
record.backend_closed.store(true, Ordering::Relaxed);
|
||||
|
||||
let zombies = tracker.scan_zombies();
|
||||
assert_eq!(zombies.len(), 1);
|
||||
assert_eq!(zombies[0], record.id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_half_zombie_not_triggered_immediately() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let record = tracker.register_connection(false);
|
||||
record.touch(); // mark activity now
|
||||
|
||||
// Only one side closed, but just now -> not a zombie yet
|
||||
record.client_closed.store(true, Ordering::Relaxed);
|
||||
assert!(tracker.scan_zombies().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stuck_connection_not_triggered_immediately() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let record = tracker.register_connection(false);
|
||||
record.touch(); // mark activity now
|
||||
|
||||
// Has received data but sent nothing -> but just started, not stuck yet
|
||||
record.bytes_received.store(1000, Ordering::Relaxed);
|
||||
assert!(tracker.scan_zombies().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unregister_removes_from_zombie_scan() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let record = tracker.register_connection(false);
|
||||
let id = record.id;
|
||||
|
||||
// Make it a full zombie
|
||||
record.client_closed.store(true, Ordering::Relaxed);
|
||||
record.backend_closed.store(true, Ordering::Relaxed);
|
||||
assert_eq!(tracker.scan_zombies().len(), 1);
|
||||
|
||||
// Unregister should remove it
|
||||
tracker.unregister_connection(id);
|
||||
assert!(tracker.scan_zombies().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_total_connections() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
assert_eq!(tracker.total_connections(), 0);
|
||||
|
||||
let r1 = tracker.register_connection(false);
|
||||
let r2 = tracker.register_connection(true);
|
||||
let r3 = tracker.register_connection(false);
|
||||
assert_eq!(tracker.total_connections(), 3);
|
||||
|
||||
tracker.unregister_connection(r2.id);
|
||||
assert_eq!(tracker.total_connections(), 2);
|
||||
|
||||
tracker.unregister_connection(r1.id);
|
||||
tracker.unregister_connection(r3.id);
|
||||
assert_eq!(tracker.total_connections(), 0);
|
||||
}
|
||||
}
|
||||
325
rust/crates/rustproxy-passthrough/src/forwarder.rs
Normal file
325
rust/crates/rustproxy-passthrough/src/forwarder.rs
Normal file
@@ -0,0 +1,325 @@
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::debug;
|
||||
|
||||
use super::connection_record::ConnectionRecord;
|
||||
|
||||
/// Statistics for a forwarded connection.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ForwardStats {
|
||||
pub bytes_in: AtomicU64,
|
||||
pub bytes_out: AtomicU64,
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding between client and backend.
|
||||
///
|
||||
/// This is the core data path for passthrough connections.
|
||||
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes.
|
||||
pub async fn forward_bidirectional(
|
||||
mut client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.split();
|
||||
let (mut backend_read, mut backend_write) = backend.split();
|
||||
|
||||
let client_to_backend = async {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
loop {
|
||||
let n = client_read.read(&mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
backend_write.write_all(&buf[..n]).await?;
|
||||
total += n as u64;
|
||||
}
|
||||
backend_write.shutdown().await?;
|
||||
Ok::<u64, std::io::Error>(total)
|
||||
};
|
||||
|
||||
let backend_to_client = async {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = backend_read.read(&mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
client_write.write_all(&buf[..n]).await?;
|
||||
total += n as u64;
|
||||
}
|
||||
client_write.shutdown().await?;
|
||||
Ok::<u64, std::io::Error>(total)
|
||||
};
|
||||
|
||||
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
|
||||
|
||||
Ok((c2b.unwrap_or(0), b2c.unwrap_or(0)))
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
|
||||
///
|
||||
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
|
||||
pub async fn forward_bidirectional_with_timeouts(
|
||||
client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
cancel: CancellationToken,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut backend_read, mut backend_write) = backend.into_split();
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_len;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog: inactivity, max lifetime, and cancellation
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Connection cancelled by shutdown");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(check_interval) => {
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
|
||||
/// Forward bidirectional with a callback for byte counting.
|
||||
pub async fn forward_bidirectional_with_stats(
|
||||
client: TcpStream,
|
||||
backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
stats: Arc<ForwardStats>,
|
||||
) -> std::io::Result<()> {
|
||||
let (bytes_in, bytes_out) = forward_bidirectional(client, backend, initial_data).await?;
|
||||
stats.bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
stats.bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts,
|
||||
/// updating a `ConnectionRecord` with byte counts and activity timestamps
|
||||
/// in real time for zombie detection.
|
||||
///
|
||||
/// When `record` is `None`, this behaves identically to
|
||||
/// `forward_bidirectional_with_timeouts`.
|
||||
///
|
||||
/// The record's `client_closed` / `backend_closed` flags are set when the
|
||||
/// respective copy loop terminates, giving the zombie scanner visibility
|
||||
/// into half-open connections.
|
||||
pub async fn forward_bidirectional_with_record(
|
||||
client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
cancel: CancellationToken,
|
||||
record: Option<Arc<ConnectionRecord>>,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
if let Some(ref r) = record {
|
||||
r.record_bytes_in(data.len() as u64);
|
||||
}
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut backend_read, mut backend_write) = backend.into_split();
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
let rec1 = record.clone();
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_len;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
let now_ms = start.elapsed().as_millis() as u64;
|
||||
la1.store(now_ms, Ordering::Relaxed);
|
||||
if let Some(ref r) = rec1 {
|
||||
r.record_bytes_in(n as u64);
|
||||
}
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
// Mark client side as closed
|
||||
if let Some(ref r) = rec1 {
|
||||
r.client_closed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let rec2 = record.clone();
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
let now_ms = start.elapsed().as_millis() as u64;
|
||||
la2.store(now_ms, Ordering::Relaxed);
|
||||
if let Some(ref r) = rec2 {
|
||||
r.record_bytes_out(n as u64);
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
// Mark backend side as closed
|
||||
if let Some(ref r) = rec2 {
|
||||
r.backend_closed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog: inactivity, max lifetime, and cancellation
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Connection cancelled by shutdown");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(check_interval) => {
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
22
rust/crates/rustproxy-passthrough/src/lib.rs
Normal file
22
rust/crates/rustproxy-passthrough/src/lib.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! # rustproxy-passthrough
|
||||
//!
|
||||
//! Raw TCP/SNI passthrough engine for RustProxy.
|
||||
//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding.
|
||||
|
||||
pub mod tcp_listener;
|
||||
pub mod sni_parser;
|
||||
pub mod forwarder;
|
||||
pub mod proxy_protocol;
|
||||
pub mod tls_handler;
|
||||
pub mod connection_record;
|
||||
pub mod connection_tracker;
|
||||
pub mod socket_relay;
|
||||
|
||||
pub use tcp_listener::*;
|
||||
pub use sni_parser::*;
|
||||
pub use forwarder::*;
|
||||
pub use proxy_protocol::*;
|
||||
pub use tls_handler::*;
|
||||
pub use connection_record::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use socket_relay::*;
|
||||
129
rust/crates/rustproxy-passthrough/src/proxy_protocol.rs
Normal file
129
rust/crates/rustproxy-passthrough/src/proxy_protocol.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use std::net::SocketAddr;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProxyProtocolError {
|
||||
#[error("Invalid PROXY protocol header")]
|
||||
InvalidHeader,
|
||||
#[error("Unsupported PROXY protocol version")]
|
||||
UnsupportedVersion,
|
||||
#[error("Parse error: {0}")]
|
||||
Parse(String),
|
||||
}
|
||||
|
||||
/// Parsed PROXY protocol v1 header.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyProtocolHeader {
|
||||
pub source_addr: SocketAddr,
|
||||
pub dest_addr: SocketAddr,
|
||||
pub protocol: ProxyProtocol,
|
||||
}
|
||||
|
||||
/// Protocol in PROXY header.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ProxyProtocol {
|
||||
Tcp4,
|
||||
Tcp6,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
/// Parse a PROXY protocol v1 header from data.
|
||||
///
|
||||
/// Format: `PROXY TCP4 <src_ip> <dst_ip> <src_port> <dst_port>\r\n`
|
||||
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
|
||||
// Find the end of the header line
|
||||
let line_end = data
|
||||
.windows(2)
|
||||
.position(|w| w == b"\r\n")
|
||||
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
let line = std::str::from_utf8(&data[..line_end])
|
||||
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
if !line.starts_with("PROXY ") {
|
||||
return Err(ProxyProtocolError::InvalidHeader);
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = line.split(' ').collect();
|
||||
if parts.len() != 6 {
|
||||
return Err(ProxyProtocolError::InvalidHeader);
|
||||
}
|
||||
|
||||
let protocol = match parts[1] {
|
||||
"TCP4" => ProxyProtocol::Tcp4,
|
||||
"TCP6" => ProxyProtocol::Tcp6,
|
||||
"UNKNOWN" => ProxyProtocol::Unknown,
|
||||
_ => return Err(ProxyProtocolError::UnsupportedVersion),
|
||||
};
|
||||
|
||||
let src_ip: std::net::IpAddr = parts[2]
|
||||
.parse()
|
||||
.map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
|
||||
let dst_ip: std::net::IpAddr = parts[3]
|
||||
.parse()
|
||||
.map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;
|
||||
let src_port: u16 = parts[4]
|
||||
.parse()
|
||||
.map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?;
|
||||
let dst_port: u16 = parts[5]
|
||||
.parse()
|
||||
.map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?;
|
||||
|
||||
let header = ProxyProtocolHeader {
|
||||
source_addr: SocketAddr::new(src_ip, src_port),
|
||||
dest_addr: SocketAddr::new(dst_ip, dst_port),
|
||||
protocol,
|
||||
};
|
||||
|
||||
// Consumed bytes = line + \r\n
|
||||
Ok((header, line_end + 2))
|
||||
}
|
||||
|
||||
/// Generate a PROXY protocol v1 header string.
|
||||
pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String {
|
||||
let proto = if source.is_ipv4() { "TCP4" } else { "TCP6" };
|
||||
format!(
|
||||
"PROXY {} {} {} {} {}\r\n",
|
||||
proto,
|
||||
source.ip(),
|
||||
dest.ip(),
|
||||
source.port(),
|
||||
dest.port()
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if data starts with a PROXY protocol v1 header.
|
||||
pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
|
||||
data.starts_with(b"PROXY ")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_v1_tcp4() {
|
||||
let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n";
|
||||
let (parsed, consumed) = parse_v1(header).unwrap();
|
||||
assert_eq!(consumed, header.len());
|
||||
assert_eq!(parsed.protocol, ProxyProtocol::Tcp4);
|
||||
assert_eq!(parsed.source_addr.ip().to_string(), "192.168.1.100");
|
||||
assert_eq!(parsed.source_addr.port(), 12345);
|
||||
assert_eq!(parsed.dest_addr.ip().to_string(), "10.0.0.1");
|
||||
assert_eq!(parsed.dest_addr.port(), 443);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_v1() {
|
||||
let source: SocketAddr = "192.168.1.100:12345".parse().unwrap();
|
||||
let dest: SocketAddr = "10.0.0.1:443".parse().unwrap();
|
||||
let header = generate_v1(&source, &dest);
|
||||
assert_eq!(header, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_proxy_protocol() {
|
||||
assert!(is_proxy_protocol_v1(b"PROXY TCP4 ..."));
|
||||
assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1"));
|
||||
}
|
||||
}
|
||||
287
rust/crates/rustproxy-passthrough/src/sni_parser.rs
Normal file
287
rust/crates/rustproxy-passthrough/src/sni_parser.rs
Normal file
@@ -0,0 +1,287 @@
|
||||
//! ClientHello SNI extraction via manual byte parsing.
|
||||
//! No TLS stack needed - we just parse enough of the ClientHello to extract the SNI.
|
||||
|
||||
/// Result of SNI extraction.
|
||||
#[derive(Debug)]
|
||||
pub enum SniResult {
|
||||
/// Successfully extracted SNI hostname.
|
||||
Found(String),
|
||||
/// TLS ClientHello detected but no SNI extension present.
|
||||
NoSni,
|
||||
/// Not a TLS ClientHello (plain HTTP or other protocol).
|
||||
NotTls,
|
||||
/// Need more data to determine.
|
||||
NeedMoreData,
|
||||
}
|
||||
|
||||
/// Extract the SNI hostname from a TLS ClientHello message.
|
||||
///
|
||||
/// This parses just enough of the TLS record to find the SNI extension,
|
||||
/// without performing any actual TLS operations.
|
||||
pub fn extract_sni(data: &[u8]) -> SniResult {
|
||||
// Minimum TLS record header is 5 bytes
|
||||
if data.len() < 5 {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Check for TLS record: content_type=22 (Handshake)
|
||||
if data[0] != 0x16 {
|
||||
return SniResult::NotTls;
|
||||
}
|
||||
|
||||
// TLS version (major.minor) - accept any
|
||||
// data[1..2] = version
|
||||
|
||||
// Record length
|
||||
let record_len = ((data[3] as usize) << 8) | (data[4] as usize);
|
||||
let _total_len = 5 + record_len;
|
||||
|
||||
// We need at least the handshake header (5 TLS + 4 handshake = 9)
|
||||
if data.len() < 9 {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Handshake type = 1 (ClientHello)
|
||||
if data[5] != 0x01 {
|
||||
return SniResult::NotTls;
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes) - informational, we parse incrementally
|
||||
let _handshake_len = ((data[6] as usize) << 16)
|
||||
| ((data[7] as usize) << 8)
|
||||
| (data[8] as usize);
|
||||
|
||||
let hello = &data[9..];
|
||||
|
||||
// ClientHello structure:
|
||||
// 2 bytes: client version
|
||||
// 32 bytes: random
|
||||
// 1 byte: session_id length + session_id
|
||||
let mut pos = 2 + 32; // skip version + random
|
||||
|
||||
if pos >= hello.len() {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Session ID
|
||||
let session_id_len = hello[pos] as usize;
|
||||
pos += 1 + session_id_len;
|
||||
|
||||
if pos + 2 > hello.len() {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Cipher suites
|
||||
let cipher_suites_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
|
||||
pos += 2 + cipher_suites_len;
|
||||
|
||||
if pos + 1 > hello.len() {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Compression methods
|
||||
let compression_len = hello[pos] as usize;
|
||||
pos += 1 + compression_len;
|
||||
|
||||
if pos + 2 > hello.len() {
|
||||
// No extensions
|
||||
return SniResult::NoSni;
|
||||
}
|
||||
|
||||
// Extensions length
|
||||
let extensions_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
|
||||
pos += 2;
|
||||
|
||||
let extensions_end = pos + extensions_len;
|
||||
if extensions_end > hello.len() {
|
||||
// Partial extensions, try to parse what we have
|
||||
}
|
||||
|
||||
// Parse extensions looking for SNI (type 0x0000)
|
||||
while pos + 4 <= hello.len() && pos < extensions_end {
|
||||
let ext_type = ((hello[pos] as u16) << 8) | (hello[pos + 1] as u16);
|
||||
let ext_len = ((hello[pos + 2] as usize) << 8) | (hello[pos + 3] as usize);
|
||||
pos += 4;
|
||||
|
||||
if ext_type == 0x0000 {
|
||||
// SNI extension
|
||||
return parse_sni_extension(&hello[pos..(pos + ext_len).min(hello.len())], ext_len);
|
||||
}
|
||||
|
||||
pos += ext_len;
|
||||
}
|
||||
|
||||
SniResult::NoSni
|
||||
}
|
||||
|
||||
/// Parse the SNI extension data.
|
||||
fn parse_sni_extension(data: &[u8], _ext_len: usize) -> SniResult {
|
||||
if data.len() < 5 {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Server name list length
|
||||
let _list_len = ((data[0] as usize) << 8) | (data[1] as usize);
|
||||
|
||||
// Server name type (0 = hostname)
|
||||
if data[2] != 0x00 {
|
||||
return SniResult::NoSni;
|
||||
}
|
||||
|
||||
// Hostname length
|
||||
let name_len = ((data[3] as usize) << 8) | (data[4] as usize);
|
||||
|
||||
if data.len() < 5 + name_len {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
match std::str::from_utf8(&data[5..5 + name_len]) {
|
||||
Ok(hostname) => SniResult::Found(hostname.to_lowercase()),
|
||||
Err(_) => SniResult::NoSni,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the initial bytes look like a TLS ClientHello.
|
||||
pub fn is_tls(data: &[u8]) -> bool {
|
||||
data.len() >= 3 && data[0] == 0x16 && data[1] == 0x03
|
||||
}
|
||||
|
||||
/// Check if the initial bytes look like HTTP.
|
||||
pub fn is_http(data: &[u8]) -> bool {
|
||||
if data.len() < 4 {
|
||||
return false;
|
||||
}
|
||||
// Check for common HTTP methods
|
||||
let starts = [
|
||||
b"GET " as &[u8],
|
||||
b"POST",
|
||||
b"PUT ",
|
||||
b"HEAD",
|
||||
b"DELE",
|
||||
b"PATC",
|
||||
b"OPTI",
|
||||
b"CONN",
|
||||
];
|
||||
starts.iter().any(|s| data.starts_with(s))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_not_tls() {
|
||||
let http_data = b"GET / HTTP/1.1\r\n";
|
||||
assert!(matches!(extract_sni(http_data), SniResult::NotTls));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_too_short() {
|
||||
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_tls() {
|
||||
assert!(is_tls(&[0x16, 0x03, 0x01]));
|
||||
assert!(!is_tls(&[0x47, 0x45, 0x54])); // "GET"
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_http() {
|
||||
assert!(is_http(b"GET /"));
|
||||
assert!(is_http(b"POST /api"));
|
||||
assert!(!is_http(&[0x16, 0x03, 0x01]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_real_client_hello() {
|
||||
// A minimal TLS 1.2 ClientHello with SNI "example.com"
|
||||
let client_hello: Vec<u8> = build_test_client_hello("example.com");
|
||||
match extract_sni(&client_hello) {
|
||||
SniResult::Found(sni) => assert_eq!(sni, "example.com"),
|
||||
other => panic!("Expected Found, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a minimal TLS ClientHello for testing.
|
||||
fn build_test_client_hello(hostname: &str) -> Vec<u8> {
|
||||
let hostname_bytes = hostname.as_bytes();
|
||||
|
||||
// SNI extension
|
||||
let sni_ext_data = {
|
||||
let mut d = Vec::new();
|
||||
// Server name list length
|
||||
let name_entry_len = 3 + hostname_bytes.len(); // type(1) + len(2) + name
|
||||
d.push(((name_entry_len >> 8) & 0xFF) as u8);
|
||||
d.push((name_entry_len & 0xFF) as u8);
|
||||
// Host name type = 0
|
||||
d.push(0x00);
|
||||
// Host name length
|
||||
d.push(((hostname_bytes.len() >> 8) & 0xFF) as u8);
|
||||
d.push((hostname_bytes.len() & 0xFF) as u8);
|
||||
// Host name
|
||||
d.extend_from_slice(hostname_bytes);
|
||||
d
|
||||
};
|
||||
|
||||
// Extension: type=0x0000 (SNI), length, data
|
||||
let sni_extension = {
|
||||
let mut e = Vec::new();
|
||||
e.push(0x00); e.push(0x00); // SNI type
|
||||
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
|
||||
e.push((sni_ext_data.len() & 0xFF) as u8);
|
||||
e.extend_from_slice(&sni_ext_data);
|
||||
e
|
||||
};
|
||||
|
||||
// Extensions block
|
||||
let extensions = {
|
||||
let mut ext = Vec::new();
|
||||
ext.push(((sni_extension.len() >> 8) & 0xFF) as u8);
|
||||
ext.push((sni_extension.len() & 0xFF) as u8);
|
||||
ext.extend_from_slice(&sni_extension);
|
||||
ext
|
||||
};
|
||||
|
||||
// ClientHello body
|
||||
let hello_body = {
|
||||
let mut h = Vec::new();
|
||||
// Client version TLS 1.2
|
||||
h.push(0x03); h.push(0x03);
|
||||
// Random (32 bytes)
|
||||
h.extend_from_slice(&[0u8; 32]);
|
||||
// Session ID length = 0
|
||||
h.push(0x00);
|
||||
// Cipher suites: length=2, one suite
|
||||
h.push(0x00); h.push(0x02);
|
||||
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
// Compression methods: length=1, null
|
||||
h.push(0x01); h.push(0x00);
|
||||
// Extensions
|
||||
h.extend_from_slice(&extensions);
|
||||
h
|
||||
};
|
||||
|
||||
// Handshake: type=1 (ClientHello), length
|
||||
let handshake = {
|
||||
let mut hs = Vec::new();
|
||||
hs.push(0x01); // ClientHello
|
||||
// 3-byte length
|
||||
hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
|
||||
hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
|
||||
hs.push((hello_body.len() & 0xFF) as u8);
|
||||
hs.extend_from_slice(&hello_body);
|
||||
hs
|
||||
};
|
||||
|
||||
// TLS record: type=0x16, version TLS 1.0, length
|
||||
let mut record = Vec::new();
|
||||
record.push(0x16); // Handshake
|
||||
record.push(0x03); record.push(0x01); // TLS 1.0
|
||||
record.push(((handshake.len() >> 8) & 0xFF) as u8);
|
||||
record.push((handshake.len() & 0xFF) as u8);
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
record
|
||||
}
|
||||
}
|
||||
126
rust/crates/rustproxy-passthrough/src/socket_relay.rs
Normal file
126
rust/crates/rustproxy-passthrough/src/socket_relay.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
//! Socket handler relay for connecting client connections to a TypeScript handler
|
||||
//! via a Unix domain socket.
|
||||
//!
|
||||
//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
|
||||
|
||||
use tokio::net::UnixStream;
|
||||
use tokio::io::{AsyncWriteExt, AsyncReadExt};
|
||||
use tokio::net::TcpStream;
|
||||
use serde::Serialize;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct RelayMetadata {
|
||||
connection_id: u64,
|
||||
remote_ip: String,
|
||||
remote_port: u16,
|
||||
local_port: u16,
|
||||
sni: Option<String>,
|
||||
route_name: String,
|
||||
initial_data_base64: Option<String>,
|
||||
}
|
||||
|
||||
/// Relay a client connection to a TypeScript handler via Unix domain socket.
|
||||
///
|
||||
/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
|
||||
pub async fn relay_to_handler(
|
||||
client: TcpStream,
|
||||
relay_socket_path: &str,
|
||||
connection_id: u64,
|
||||
remote_ip: String,
|
||||
remote_port: u16,
|
||||
local_port: u16,
|
||||
sni: Option<String>,
|
||||
route_name: String,
|
||||
initial_data: Option<&[u8]>,
|
||||
) -> std::io::Result<()> {
|
||||
debug!(
|
||||
"Relaying connection {} to handler socket {}",
|
||||
connection_id, relay_socket_path
|
||||
);
|
||||
|
||||
// Connect to TypeScript handler Unix socket
|
||||
let mut handler = UnixStream::connect(relay_socket_path).await?;
|
||||
|
||||
// Build and send metadata header
|
||||
let initial_data_base64 = initial_data.map(base64_encode);
|
||||
|
||||
let metadata = RelayMetadata {
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
local_port,
|
||||
sni,
|
||||
route_name,
|
||||
initial_data_base64,
|
||||
};
|
||||
|
||||
let metadata_json = serde_json::to_string(&metadata)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
|
||||
handler.write_all(metadata_json.as_bytes()).await?;
|
||||
handler.write_all(b"\n").await?;
|
||||
|
||||
// Bidirectional relay between client and handler
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut handler_read, mut handler_write) = handler.into_split();
|
||||
|
||||
let c2h = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if handler_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = handler_write.shutdown().await;
|
||||
});
|
||||
|
||||
let h2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match handler_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
});
|
||||
|
||||
let _ = tokio::join!(c2h, h2c);
|
||||
|
||||
debug!("Relay connection {} completed", connection_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Simple base64 encoding without external dependency.
|
||||
fn base64_encode(data: &[u8]) -> String {
|
||||
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
let mut result = String::new();
|
||||
for chunk in data.chunks(3) {
|
||||
let b0 = chunk[0] as u32;
|
||||
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
|
||||
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
|
||||
let n = (b0 << 16) | (b1 << 8) | b2;
|
||||
result.push(CHARS[((n >> 18) & 0x3F) as usize] as char);
|
||||
result.push(CHARS[((n >> 12) & 0x3F) as usize] as char);
|
||||
if chunk.len() > 1 {
|
||||
result.push(CHARS[((n >> 6) & 0x3F) as usize] as char);
|
||||
} else {
|
||||
result.push('=');
|
||||
}
|
||||
if chunk.len() > 2 {
|
||||
result.push(CHARS[(n & 0x3F) as usize] as char);
|
||||
} else {
|
||||
result.push('=');
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
874
rust/crates/rustproxy-passthrough/src/tcp_listener.rs
Normal file
874
rust/crates/rustproxy-passthrough/src/tcp_listener.rs
Normal file
@@ -0,0 +1,874 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, error, debug, warn};
|
||||
use thiserror::Error;
|
||||
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_http::HttpProxyService;
|
||||
use crate::sni_parser;
|
||||
use crate::forwarder;
|
||||
use crate::tls_handler;
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ListenerError {
|
||||
#[error("Failed to bind port {port}: {source}")]
|
||||
BindFailed { port: u16, source: std::io::Error },
|
||||
#[error("Port {0} already bound")]
|
||||
AlreadyBound(u16),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
/// TLS configuration for a specific domain.
|
||||
#[derive(Clone)]
|
||||
pub struct TlsCertConfig {
|
||||
pub cert_pem: String,
|
||||
pub key_pem: String,
|
||||
}
|
||||
|
||||
/// Timeout and connection management configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnectionConfig {
|
||||
/// Timeout for establishing connection to backend (ms)
|
||||
pub connection_timeout_ms: u64,
|
||||
/// Timeout for initial data/SNI peek (ms)
|
||||
pub initial_data_timeout_ms: u64,
|
||||
/// Socket inactivity timeout (ms)
|
||||
pub socket_timeout_ms: u64,
|
||||
/// Maximum connection lifetime (ms)
|
||||
pub max_connection_lifetime_ms: u64,
|
||||
/// Graceful shutdown timeout (ms)
|
||||
pub graceful_shutdown_timeout_ms: u64,
|
||||
/// Maximum connections per IP (None = unlimited)
|
||||
pub max_connections_per_ip: Option<u64>,
|
||||
/// Connection rate limit per minute per IP (None = unlimited)
|
||||
pub connection_rate_limit_per_minute: Option<u64>,
|
||||
/// Keep-alive treatment
|
||||
pub keep_alive_treatment: Option<rustproxy_config::KeepAliveTreatment>,
|
||||
/// Inactivity multiplier for keep-alive connections
|
||||
pub keep_alive_inactivity_multiplier: Option<f64>,
|
||||
/// Extended keep-alive lifetime (ms) for Extended treatment mode
|
||||
pub extended_keep_alive_lifetime_ms: Option<u64>,
|
||||
/// Whether to accept PROXY protocol
|
||||
pub accept_proxy_protocol: bool,
|
||||
/// Whether to send PROXY protocol
|
||||
pub send_proxy_protocol: bool,
|
||||
}
|
||||
|
||||
impl Default for ConnectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
connection_timeout_ms: 30_000,
|
||||
initial_data_timeout_ms: 60_000,
|
||||
socket_timeout_ms: 3_600_000,
|
||||
max_connection_lifetime_ms: 86_400_000,
|
||||
graceful_shutdown_timeout_ms: 30_000,
|
||||
max_connections_per_ip: None,
|
||||
connection_rate_limit_per_minute: None,
|
||||
keep_alive_treatment: None,
|
||||
keep_alive_inactivity_multiplier: None,
|
||||
extended_keep_alive_lifetime_ms: None,
|
||||
accept_proxy_protocol: false,
|
||||
send_proxy_protocol: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages TCP listeners for all configured ports.
|
||||
pub struct TcpListenerManager {
|
||||
/// Active listeners indexed by port
|
||||
listeners: HashMap<u16, tokio::task::JoinHandle<()>>,
|
||||
/// Shared route manager
|
||||
route_manager: Arc<RouteManager>,
|
||||
/// Shared metrics collector
|
||||
metrics: Arc<MetricsCollector>,
|
||||
/// TLS acceptors indexed by domain
|
||||
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
|
||||
/// HTTP proxy service for HTTP-level forwarding
|
||||
http_proxy: Arc<HttpProxyService>,
|
||||
/// Connection configuration
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
/// Connection tracker for per-IP limits
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
/// Cancellation token for graceful shutdown
|
||||
cancel_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl TcpListenerManager {
|
||||
pub fn new(route_manager: Arc<RouteManager>) -> Self {
|
||||
let metrics = Arc::new(MetricsCollector::new());
|
||||
let http_proxy = Arc::new(HttpProxyService::new(
|
||||
Arc::clone(&route_manager),
|
||||
Arc::clone(&metrics),
|
||||
));
|
||||
let conn_config = ConnectionConfig::default();
|
||||
let conn_tracker = Arc::new(ConnectionTracker::new(
|
||||
conn_config.max_connections_per_ip,
|
||||
conn_config.connection_rate_limit_per_minute,
|
||||
));
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager,
|
||||
metrics,
|
||||
tls_configs: Arc::new(HashMap::new()),
|
||||
http_proxy,
|
||||
conn_config: Arc::new(conn_config),
|
||||
conn_tracker,
|
||||
cancel_token: CancellationToken::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a metrics collector.
|
||||
pub fn with_metrics(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
|
||||
let http_proxy = Arc::new(HttpProxyService::new(
|
||||
Arc::clone(&route_manager),
|
||||
Arc::clone(&metrics),
|
||||
));
|
||||
let conn_config = ConnectionConfig::default();
|
||||
let conn_tracker = Arc::new(ConnectionTracker::new(
|
||||
conn_config.max_connections_per_ip,
|
||||
conn_config.connection_rate_limit_per_minute,
|
||||
));
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager,
|
||||
metrics,
|
||||
tls_configs: Arc::new(HashMap::new()),
|
||||
http_proxy,
|
||||
conn_config: Arc::new(conn_config),
|
||||
conn_tracker,
|
||||
cancel_token: CancellationToken::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set connection configuration.
|
||||
pub fn set_connection_config(&mut self, config: ConnectionConfig) {
|
||||
self.conn_tracker = Arc::new(ConnectionTracker::new(
|
||||
config.max_connections_per_ip,
|
||||
config.connection_rate_limit_per_minute,
|
||||
));
|
||||
self.conn_config = Arc::new(config);
|
||||
}
|
||||
|
||||
/// Set TLS certificate configurations.
|
||||
pub fn set_tls_configs(&mut self, configs: HashMap<String, TlsCertConfig>) {
|
||||
self.tls_configs = Arc::new(configs);
|
||||
}
|
||||
|
||||
/// Start listening on a port.
|
||||
pub async fn add_port(&mut self, port: u16) -> Result<(), ListenerError> {
|
||||
if self.listeners.contains_key(&port) {
|
||||
return Err(ListenerError::AlreadyBound(port));
|
||||
}
|
||||
|
||||
let addr = format!("0.0.0.0:{}", port);
|
||||
let listener = TcpListener::bind(&addr).await.map_err(|e| {
|
||||
ListenerError::BindFailed { port, source: e }
|
||||
})?;
|
||||
|
||||
info!("Listening on port {}", port);
|
||||
|
||||
let route_manager = Arc::clone(&self.route_manager);
|
||||
let metrics = Arc::clone(&self.metrics);
|
||||
let tls_configs = Arc::clone(&self.tls_configs);
|
||||
let http_proxy = Arc::clone(&self.http_proxy);
|
||||
let conn_config = Arc::clone(&self.conn_config);
|
||||
let conn_tracker = Arc::clone(&self.conn_tracker);
|
||||
let cancel = self.cancel_token.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
Self::accept_loop(
|
||||
listener, port, route_manager, metrics, tls_configs,
|
||||
http_proxy, conn_config, conn_tracker, cancel,
|
||||
).await;
|
||||
});
|
||||
|
||||
self.listeners.insert(port, handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop listening on a port.
|
||||
pub fn remove_port(&mut self, port: u16) -> bool {
|
||||
if let Some(handle) = self.listeners.remove(&port) {
|
||||
handle.abort();
|
||||
info!("Stopped listening on port {}", port);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all currently listening ports.
|
||||
pub fn listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.listeners.keys().copied().collect();
|
||||
ports.sort();
|
||||
ports
|
||||
}
|
||||
|
||||
/// Stop all listeners gracefully.
|
||||
///
|
||||
/// Signals cancellation and waits up to `graceful_shutdown_timeout_ms` for
|
||||
/// connections to drain, then aborts remaining tasks.
|
||||
pub async fn graceful_stop(&mut self) {
|
||||
let timeout_ms = self.conn_config.graceful_shutdown_timeout_ms;
|
||||
info!("Initiating graceful shutdown (timeout: {}ms)", timeout_ms);
|
||||
|
||||
// Signal all accept loops to stop accepting new connections
|
||||
self.cancel_token.cancel();
|
||||
|
||||
// Wait for existing connections to drain
|
||||
let timeout = std::time::Duration::from_millis(timeout_ms);
|
||||
let deadline = tokio::time::Instant::now() + timeout;
|
||||
|
||||
for (port, handle) in self.listeners.drain() {
|
||||
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
|
||||
if remaining.is_zero() {
|
||||
handle.abort();
|
||||
warn!("Force-stopped listener on port {} (timeout exceeded)", port);
|
||||
} else {
|
||||
match tokio::time::timeout(remaining, handle).await {
|
||||
Ok(_) => info!("Listener on port {} stopped gracefully", port),
|
||||
Err(_) => {
|
||||
warn!("Listener on port {} did not stop in time, aborting", port);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset cancellation token for potential restart
|
||||
self.cancel_token = CancellationToken::new();
|
||||
info!("Graceful shutdown complete");
|
||||
}
|
||||
|
||||
/// Stop all listeners immediately (backward compatibility).
|
||||
pub fn stop_all(&mut self) {
|
||||
self.cancel_token.cancel();
|
||||
for (port, handle) in self.listeners.drain() {
|
||||
handle.abort();
|
||||
info!("Stopped listening on port {}", port);
|
||||
}
|
||||
self.cancel_token = CancellationToken::new();
|
||||
}
|
||||
|
||||
/// Update the route manager (for hot-reload).
|
||||
pub fn update_route_manager(&mut self, route_manager: Arc<RouteManager>) {
|
||||
self.route_manager = route_manager;
|
||||
}
|
||||
|
||||
/// Get a reference to the metrics collector.
|
||||
pub fn metrics(&self) -> &Arc<MetricsCollector> {
|
||||
&self.metrics
|
||||
}
|
||||
|
||||
/// Accept loop for a single port.
|
||||
async fn accept_loop(
|
||||
listener: TcpListener,
|
||||
port: u16,
|
||||
route_manager: Arc<RouteManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
|
||||
http_proxy: Arc<HttpProxyService>,
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
info!("Accept loop on port {} shutting down", port);
|
||||
break;
|
||||
}
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok((stream, peer_addr)) => {
|
||||
let ip = peer_addr.ip();
|
||||
|
||||
// Check per-IP limits and rate limiting
|
||||
if !conn_tracker.try_accept(&ip) {
|
||||
debug!("Rejected connection from {} (per-IP limit or rate limit)", peer_addr);
|
||||
drop(stream);
|
||||
continue;
|
||||
}
|
||||
|
||||
conn_tracker.connection_opened(&ip);
|
||||
|
||||
let rm = Arc::clone(&route_manager);
|
||||
let m = Arc::clone(&metrics);
|
||||
let tc = Arc::clone(&tls_configs);
|
||||
let hp = Arc::clone(&http_proxy);
|
||||
let cc = Arc::clone(&conn_config);
|
||||
let ct = Arc::clone(&conn_tracker);
|
||||
let cn = cancel.clone();
|
||||
debug!("Accepted connection from {} on port {}", peer_addr, port);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let result = Self::handle_connection(
|
||||
stream, port, peer_addr, rm, m, tc, hp, cc, cn,
|
||||
).await;
|
||||
if let Err(e) = result {
|
||||
debug!("Connection error from {}: {}", peer_addr, e);
|
||||
}
|
||||
ct.connection_closed(&ip);
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Accept error on port {}: {}", port, e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single incoming connection.
|
||||
async fn handle_connection(
|
||||
mut stream: tokio::net::TcpStream,
|
||||
port: u16,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
route_manager: Arc<RouteManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
|
||||
http_proxy: Arc<HttpProxyService>,
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
stream.set_nodelay(true)?;
|
||||
|
||||
// Handle PROXY protocol if configured
|
||||
let mut effective_peer_addr = peer_addr;
|
||||
if conn_config.accept_proxy_protocol {
|
||||
let mut proxy_peek = vec![0u8; 256];
|
||||
let pn = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
stream.peek(&mut proxy_peek),
|
||||
).await {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Initial data timeout (proxy protocol peek)".into()),
|
||||
};
|
||||
|
||||
if pn > 0 && crate::proxy_protocol::is_proxy_protocol_v1(&proxy_peek[..pn]) {
|
||||
match crate::proxy_protocol::parse_v1(&proxy_peek[..pn]) {
|
||||
Ok((header, consumed)) => {
|
||||
debug!("PROXY protocol: real client {} -> {}", header.source_addr, header.dest_addr);
|
||||
effective_peer_addr = header.source_addr;
|
||||
// Consume the proxy protocol header bytes
|
||||
let mut discard = vec![0u8; consumed];
|
||||
stream.read_exact(&mut discard).await?;
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to parse PROXY protocol header: {}", e);
|
||||
// Not a PROXY protocol header, continue normally
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let peer_addr = effective_peer_addr;
|
||||
|
||||
// Peek at initial bytes with timeout
|
||||
let mut peek_buf = vec![0u8; 4096];
|
||||
let n = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
stream.peek(&mut peek_buf),
|
||||
).await {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Initial data timeout".into()),
|
||||
};
|
||||
let initial_data = &peek_buf[..n];
|
||||
|
||||
// Determine connection type and extract SNI if TLS
|
||||
let is_tls = sni_parser::is_tls(initial_data);
|
||||
let is_http = sni_parser::is_http(initial_data);
|
||||
let domain = if is_tls {
|
||||
match sni_parser::extract_sni(initial_data) {
|
||||
sni_parser::SniResult::Found(sni) => Some(sni),
|
||||
sni_parser::SniResult::NoSni => None,
|
||||
sni_parser::SniResult::NeedMoreData => {
|
||||
let mut bigger_buf = vec![0u8; 16384];
|
||||
let n = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
stream.peek(&mut bigger_buf),
|
||||
).await {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("SNI data timeout".into()),
|
||||
};
|
||||
match sni_parser::extract_sni(&bigger_buf[..n]) {
|
||||
sni_parser::SniResult::Found(sni) => Some(sni),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
sni_parser::SniResult::NotTls => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Match route
|
||||
let ctx = rustproxy_routing::MatchContext {
|
||||
port,
|
||||
domain: domain.as_deref(),
|
||||
path: None,
|
||||
client_ip: Some(&peer_addr.ip().to_string()),
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls,
|
||||
};
|
||||
|
||||
let route_match = route_manager.find_route(&ctx);
|
||||
|
||||
let route_match = match route_match {
|
||||
Some(rm) => rm,
|
||||
None => {
|
||||
debug!("No route matched for port {} domain {:?}", port, domain);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let route_id = route_match.route.id.as_deref();
|
||||
|
||||
// Check route-level IP security for passthrough connections
|
||||
if let Some(ref security) = route_match.route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security,
|
||||
&peer_addr.ip(),
|
||||
) {
|
||||
debug!("Connection from {} blocked by route security", peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Track connection in metrics
|
||||
metrics.connection_opened(route_id);
|
||||
|
||||
let target = match route_match.target {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
debug!("Route matched but no target available");
|
||||
metrics.connection_closed(route_id);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let target_host = target.host.first().to_string();
|
||||
let target_port = target.port.resolve(port);
|
||||
let tls_mode = route_match.route.tls_mode();
|
||||
|
||||
// Connection timeout for backend connections
|
||||
let connect_timeout = std::time::Duration::from_millis(conn_config.connection_timeout_ms);
|
||||
let base_inactivity_ms = conn_config.socket_timeout_ms;
|
||||
let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() {
|
||||
Some(rustproxy_config::KeepAliveTreatment::Extended) => {
|
||||
let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0);
|
||||
let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms
|
||||
.unwrap_or(7 * 24 * 3600 * 1000); // 7 days default
|
||||
(
|
||||
std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64),
|
||||
std::time::Duration::from_millis(extended_lifetime),
|
||||
)
|
||||
}
|
||||
Some(rustproxy_config::KeepAliveTreatment::Immortal) => {
|
||||
(
|
||||
std::time::Duration::from_millis(base_inactivity_ms),
|
||||
std::time::Duration::from_secs(u64::MAX / 2),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
// Standard
|
||||
(
|
||||
std::time::Duration::from_millis(base_inactivity_ms),
|
||||
std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Determine if we should send PROXY protocol to backend
|
||||
let should_send_proxy = conn_config.send_proxy_protocol
|
||||
|| route_match.route.action.send_proxy_protocol.unwrap_or(false)
|
||||
|| target.send_proxy_protocol.unwrap_or(false);
|
||||
|
||||
// Generate PROXY protocol header if needed
|
||||
let proxy_header = if should_send_proxy {
|
||||
let dest = std::net::SocketAddr::new(
|
||||
target_host.parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)),
|
||||
target_port,
|
||||
);
|
||||
Some(crate::proxy_protocol::generate_v1(&peer_addr, &dest))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let result = match tls_mode {
|
||||
Some(rustproxy_config::TlsMode::Passthrough) => {
|
||||
// Raw TCP passthrough - connect to backend and forward
|
||||
let mut backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Backend connection timeout".into()),
|
||||
};
|
||||
backend.set_nodelay(true)?;
|
||||
|
||||
// Send PROXY protocol header if configured
|
||||
if let Some(ref header) = proxy_header {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
backend.write_all(header.as_bytes()).await?;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Passthrough: {} -> {}:{} (SNI: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
|
||||
let mut actual_buf = vec![0u8; n];
|
||||
stream.read_exact(&mut actual_buf).await?;
|
||||
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, Some(&actual_buf),
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
Some(rustproxy_config::TlsMode::Terminate) => {
|
||||
let tls_config = Self::find_tls_config(&domain, &tls_configs)?;
|
||||
|
||||
// TLS accept with timeout, applying route-level TLS settings
|
||||
let route_tls = route_match.route.action.tls.as_ref();
|
||||
let acceptor = tls_handler::build_tls_acceptor_with_config(
|
||||
&tls_config.cert_pem, &tls_config.key_pem, route_tls,
|
||||
)?;
|
||||
let tls_stream = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
tls_handler::accept_tls(stream, &acceptor),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => return Err("TLS handshake timeout".into()),
|
||||
};
|
||||
|
||||
// Peek at decrypted data to determine if HTTP
|
||||
let mut buf_stream = tokio::io::BufReader::new(tls_stream);
|
||||
let peeked = {
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
match buf_stream.fill_buf().await {
|
||||
Ok(data) => sni_parser::is_http(data),
|
||||
Err(_) => false,
|
||||
}
|
||||
};
|
||||
|
||||
if peeked {
|
||||
debug!(
|
||||
"TLS Terminate + HTTP: {} -> {}:{} (domain: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
http_proxy.handle_io(buf_stream, peer_addr, port).await;
|
||||
} else {
|
||||
debug!(
|
||||
"TLS Terminate + TCP: {} -> {}:{} (domain: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
// Raw TCP forwarding of decrypted stream
|
||||
let backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Backend connection timeout".into()),
|
||||
};
|
||||
backend.set_nodelay(true)?;
|
||||
|
||||
let (tls_read, tls_write) = tokio::io::split(buf_stream);
|
||||
let (backend_read, backend_write) = tokio::io::split(backend);
|
||||
|
||||
let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
|
||||
tls_read, tls_write, backend_read, backend_write,
|
||||
inactivity_timeout, max_lifetime,
|
||||
).await;
|
||||
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Some(rustproxy_config::TlsMode::TerminateAndReencrypt) => {
|
||||
let route_tls = route_match.route.action.tls.as_ref();
|
||||
Self::handle_tls_terminate_reencrypt(
|
||||
stream, n, &domain, &target_host, target_port,
|
||||
peer_addr, &tls_configs, &metrics, route_id, &conn_config, route_tls,
|
||||
).await
|
||||
}
|
||||
None => {
|
||||
if is_http {
|
||||
// Plain HTTP - use HTTP proxy for request-level routing
|
||||
debug!("HTTP proxy: {} on port {}", peer_addr, port);
|
||||
http_proxy.handle_connection(stream, peer_addr, port).await;
|
||||
Ok(())
|
||||
} else {
|
||||
// Plain TCP forwarding (non-HTTP)
|
||||
let mut backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Backend connection timeout".into()),
|
||||
};
|
||||
backend.set_nodelay(true)?;
|
||||
|
||||
// Send PROXY protocol header if configured
|
||||
if let Some(ref header) = proxy_header {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
backend.write_all(header.as_bytes()).await?;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Forward: {} -> {}:{}",
|
||||
peer_addr, target_host, target_port
|
||||
);
|
||||
|
||||
let mut actual_buf = vec![0u8; n];
|
||||
stream.read_exact(&mut actual_buf).await?;
|
||||
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, Some(&actual_buf),
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
metrics.connection_closed(route_id);
|
||||
result
|
||||
}
|
||||
|
||||
/// Handle TLS terminate-and-reencrypt: accept TLS from client, connect TLS to backend.
|
||||
async fn handle_tls_terminate_reencrypt(
|
||||
stream: tokio::net::TcpStream,
|
||||
_peek_len: usize,
|
||||
domain: &Option<String>,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
tls_configs: &HashMap<String, TlsCertConfig>,
|
||||
metrics: &MetricsCollector,
|
||||
route_id: Option<&str>,
|
||||
conn_config: &ConnectionConfig,
|
||||
route_tls: Option<&rustproxy_config::RouteTls>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let tls_config = Self::find_tls_config(domain, tls_configs)?;
|
||||
let acceptor = tls_handler::build_tls_acceptor_with_config(
|
||||
&tls_config.cert_pem, &tls_config.key_pem, route_tls,
|
||||
)?;
|
||||
|
||||
// Accept TLS from client with timeout
|
||||
let client_tls = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
tls_handler::accept_tls(stream, &acceptor),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => return Err("TLS handshake timeout".into()),
|
||||
};
|
||||
|
||||
debug!(
|
||||
"TLS Terminate+Reencrypt: {} -> {}:{} (domain: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
|
||||
// Connect to backend over TLS with timeout
|
||||
let backend_tls = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.connection_timeout_ms),
|
||||
tls_handler::connect_tls(target_host, target_port),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => return Err("Backend TLS connection timeout".into()),
|
||||
};
|
||||
|
||||
// Forward between two TLS streams
|
||||
let (client_read, client_write) = tokio::io::split(client_tls);
|
||||
let (backend_read, backend_write) = tokio::io::split(backend_tls);
|
||||
|
||||
let base_inactivity_ms = conn_config.socket_timeout_ms;
|
||||
let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() {
|
||||
Some(rustproxy_config::KeepAliveTreatment::Extended) => {
|
||||
let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0);
|
||||
let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms
|
||||
.unwrap_or(7 * 24 * 3600 * 1000); // 7 days default
|
||||
(
|
||||
std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64),
|
||||
std::time::Duration::from_millis(extended_lifetime),
|
||||
)
|
||||
}
|
||||
Some(rustproxy_config::KeepAliveTreatment::Immortal) => {
|
||||
(
|
||||
std::time::Duration::from_millis(base_inactivity_ms),
|
||||
std::time::Duration::from_secs(u64::MAX / 2),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
// Standard
|
||||
(
|
||||
std::time::Duration::from_millis(base_inactivity_ms),
|
||||
std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
|
||||
client_read, client_write, backend_read, backend_write,
|
||||
inactivity_timeout, max_lifetime,
|
||||
).await;
|
||||
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find the TLS config for a given domain.
|
||||
fn find_tls_config<'a>(
|
||||
domain: &Option<String>,
|
||||
tls_configs: &'a HashMap<String, TlsCertConfig>,
|
||||
) -> Result<&'a TlsCertConfig, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if let Some(domain) = domain {
|
||||
// Try exact match
|
||||
if let Some(config) = tls_configs.get(domain) {
|
||||
return Ok(config);
|
||||
}
|
||||
// Try wildcard
|
||||
if let Some(dot_pos) = domain.find('.') {
|
||||
let wildcard = format!("*.{}", &domain[dot_pos + 1..]);
|
||||
if let Some(config) = tls_configs.get(&wildcard) {
|
||||
return Ok(config);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try default/fallback cert
|
||||
if let Some(config) = tls_configs.get("*") {
|
||||
return Ok(config);
|
||||
}
|
||||
// Try first available cert
|
||||
if let Some((_key, config)) = tls_configs.iter().next() {
|
||||
return Ok(config);
|
||||
}
|
||||
Err("No TLS certificate available for this domain".into())
|
||||
}
|
||||
|
||||
/// Forward bidirectional between two split streams with inactivity and lifetime timeouts.
|
||||
async fn forward_bidirectional_split_with_timeouts<R1, W1, R2, W2>(
|
||||
mut client_read: R1,
|
||||
mut client_write: W1,
|
||||
mut backend_read: R2,
|
||||
mut backend_write: W2,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
) -> (u64, u64)
|
||||
where
|
||||
R1: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
W1: tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
R2: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
W2: tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la1.store(
|
||||
start.elapsed().as_millis() as u64,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la2.store(
|
||||
start.elapsed().as_millis() as u64,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog task: check for inactivity and max lifetime
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::time::sleep(check_interval).await;
|
||||
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
// No activity since last check
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
(bytes_in, bytes_out)
|
||||
}
|
||||
}
|
||||
190
rust/crates/rustproxy-passthrough/src/tls_handler.rs
Normal file
190
rust/crates/rustproxy-passthrough/src/tls_handler.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use std::io::BufReader;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use rustls::ServerConfig;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
|
||||
use tracing::debug;
|
||||
|
||||
/// Ensure the default crypto provider is installed.
|
||||
fn ensure_crypto_provider() {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor from PEM-encoded cert and key data.
|
||||
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
build_tls_acceptor_with_config(cert_pem, key_pem, None)
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor with optional RouteTls configuration for version/cipher tuning.
|
||||
pub fn build_tls_acceptor_with_config(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
tls_config: Option<&rustproxy_config::RouteTls>,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let certs = load_certs(cert_pem)?;
|
||||
let key = load_private_key(key_pem)?;
|
||||
|
||||
let mut config = if let Some(route_tls) = tls_config {
|
||||
// Apply TLS version restrictions
|
||||
let versions = resolve_tls_versions(route_tls.versions.as_deref());
|
||||
let builder = ServerConfig::builder_with_protocol_versions(&versions);
|
||||
builder
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?
|
||||
} else {
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?
|
||||
};
|
||||
|
||||
// Apply session timeout if configured
|
||||
if let Some(route_tls) = tls_config {
|
||||
if let Some(timeout_secs) = route_tls.session_timeout {
|
||||
config.session_storage = rustls::server::ServerSessionMemoryCache::new(
|
||||
256, // max sessions
|
||||
);
|
||||
debug!("TLS session timeout configured: {}s", timeout_secs);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
|
||||
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||
let versions = match versions {
|
||||
Some(v) if !v.is_empty() => v,
|
||||
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
|
||||
};
|
||||
|
||||
let mut result = Vec::new();
|
||||
for v in versions {
|
||||
match v.as_str() {
|
||||
"TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => {
|
||||
if !result.contains(&&rustls::version::TLS12) {
|
||||
result.push(&rustls::version::TLS12);
|
||||
}
|
||||
}
|
||||
"TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => {
|
||||
if !result.contains(&&rustls::version::TLS13) {
|
||||
result.push(&rustls::version::TLS13);
|
||||
}
|
||||
}
|
||||
other => {
|
||||
debug!("Unknown TLS version '{}', ignoring", other);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.is_empty() {
|
||||
// Fallback to both if no valid versions specified
|
||||
vec![&rustls::version::TLS12, &rustls::version::TLS13]
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Accept a TLS connection from a client stream.
|
||||
pub async fn accept_tls(
|
||||
stream: TcpStream,
|
||||
acceptor: &TlsAcceptor,
|
||||
) -> Result<ServerTlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let tls_stream = acceptor.accept(stream).await?;
|
||||
debug!("TLS handshake completed");
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
|
||||
pub async fn connect_tls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
|
||||
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
|
||||
let tls_stream = connector.connect(server_name, stream).await?;
|
||||
debug!("Backend TLS connection established to {}:{}", host, port);
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// Load certificates from PEM string.
|
||||
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if certs.is_empty() {
|
||||
return Err("No certificates found in PEM data".into());
|
||||
}
|
||||
Ok(certs)
|
||||
}
|
||||
|
||||
/// Load private key from PEM string.
|
||||
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
// Try PKCS8 first, then RSA, then EC
|
||||
let key = rustls_pemfile::private_key(&mut reader)?
|
||||
.ok_or("No private key found in PEM data")?;
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
/// Insecure certificate verifier for backend connections (terminate-and-reencrypt).
|
||||
/// In internal networks, backends may use self-signed certs.
|
||||
#[derive(Debug)]
|
||||
struct InsecureVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
16
rust/crates/rustproxy-routing/Cargo.toml
Normal file
16
rust/crates/rustproxy-routing/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "rustproxy-routing"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Route matching engine for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
glob-match = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
arc-swap = { workspace = true }
|
||||
9
rust/crates/rustproxy-routing/src/lib.rs
Normal file
9
rust/crates/rustproxy-routing/src/lib.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
//! # rustproxy-routing
|
||||
//!
|
||||
//! Route matching engine for RustProxy.
|
||||
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
|
||||
|
||||
pub mod route_manager;
|
||||
pub mod matchers;
|
||||
|
||||
pub use route_manager::*;
|
||||
86
rust/crates/rustproxy-routing/src/matchers/domain.rs
Normal file
86
rust/crates/rustproxy-routing/src/matchers/domain.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
/// Match a domain against a pattern supporting wildcards.
|
||||
///
|
||||
/// Supported patterns:
|
||||
/// - `*` matches any domain
|
||||
/// - `*.example.com` matches any subdomain of example.com
|
||||
/// - `example.com` exact match
|
||||
/// - `**.example.com` matches any depth of subdomain
|
||||
pub fn domain_matches(pattern: &str, domain: &str) -> bool {
|
||||
let pattern = pattern.trim().to_lowercase();
|
||||
let domain = domain.trim().to_lowercase();
|
||||
|
||||
if pattern == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
if pattern == domain {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Wildcard patterns
|
||||
if pattern.starts_with("*.") {
|
||||
let suffix = &pattern[2..]; // e.g., "example.com"
|
||||
// Match exact parent or any single-level subdomain
|
||||
if domain == suffix {
|
||||
return true;
|
||||
}
|
||||
if domain.ends_with(&format!(".{}", suffix)) {
|
||||
// Check it's a single level subdomain for `*.`
|
||||
let prefix = &domain[..domain.len() - suffix.len() - 1];
|
||||
return !prefix.contains('.');
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if pattern.starts_with("**.") {
|
||||
let suffix = &pattern[3..];
|
||||
// Match exact parent or any depth of subdomain
|
||||
return domain == suffix || domain.ends_with(&format!(".{}", suffix));
|
||||
}
|
||||
|
||||
// Use glob-match for more complex patterns
|
||||
glob_match::glob_match(&pattern, &domain)
|
||||
}
|
||||
|
||||
/// Check if a domain matches any of the given patterns.
|
||||
pub fn domain_matches_any(patterns: &[&str], domain: &str) -> bool {
|
||||
patterns.iter().any(|p| domain_matches(p, domain))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exact_match() {
|
||||
assert!(domain_matches("example.com", "example.com"));
|
||||
assert!(!domain_matches("example.com", "other.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_all() {
|
||||
assert!(domain_matches("*", "anything.com"));
|
||||
assert!(domain_matches("*", "sub.domain.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_subdomain() {
|
||||
assert!(domain_matches("*.example.com", "www.example.com"));
|
||||
assert!(domain_matches("*.example.com", "api.example.com"));
|
||||
assert!(domain_matches("*.example.com", "example.com"));
|
||||
assert!(!domain_matches("*.example.com", "deep.sub.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_double_wildcard() {
|
||||
assert!(domain_matches("**.example.com", "www.example.com"));
|
||||
assert!(domain_matches("**.example.com", "deep.sub.example.com"));
|
||||
assert!(domain_matches("**.example.com", "example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_insensitive() {
|
||||
assert!(domain_matches("Example.COM", "example.com"));
|
||||
assert!(domain_matches("*.EXAMPLE.com", "WWW.example.COM"));
|
||||
}
|
||||
}
|
||||
98
rust/crates/rustproxy-routing/src/matchers/header.rs
Normal file
98
rust/crates/rustproxy-routing/src/matchers/header.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
use std::collections::HashMap;
|
||||
use regex::Regex;
|
||||
|
||||
/// Match HTTP headers against a set of patterns.
|
||||
///
|
||||
/// Pattern values can be:
|
||||
/// - Exact string: `"application/json"`
|
||||
/// - Regex (surrounded by /): `"/^text\/.*/"`
|
||||
pub fn headers_match(
|
||||
patterns: &HashMap<String, String>,
|
||||
headers: &HashMap<String, String>,
|
||||
) -> bool {
|
||||
for (key, pattern) in patterns {
|
||||
let key_lower = key.to_lowercase();
|
||||
|
||||
// Find the header (case-insensitive)
|
||||
let header_value = headers
|
||||
.iter()
|
||||
.find(|(k, _)| k.to_lowercase() == key_lower)
|
||||
.map(|(_, v)| v.as_str());
|
||||
|
||||
let header_value = match header_value {
|
||||
Some(v) => v,
|
||||
None => return false, // Required header not present
|
||||
};
|
||||
|
||||
// Check if pattern is a regex (surrounded by /)
|
||||
if pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2 {
|
||||
let regex_str = &pattern[1..pattern.len() - 1];
|
||||
match Regex::new(regex_str) {
|
||||
Ok(re) => {
|
||||
if !re.is_match(header_value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Invalid regex, fall back to exact match
|
||||
if header_value != pattern {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Exact match
|
||||
if header_value != pattern {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exact_header_match() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("Content-Type".to_string(), "application/json".to_string());
|
||||
m
|
||||
};
|
||||
let headers: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("content-type".to_string(), "application/json".to_string());
|
||||
m
|
||||
};
|
||||
assert!(headers_match(&patterns, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_regex_header_match() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("Content-Type".to_string(), "/^text\\/.*/".to_string());
|
||||
m
|
||||
};
|
||||
let headers: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("content-type".to_string(), "text/html".to_string());
|
||||
m
|
||||
};
|
||||
assert!(headers_match(&patterns, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_header() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("X-Custom".to_string(), "value".to_string());
|
||||
m
|
||||
};
|
||||
let headers: HashMap<String, String> = HashMap::new();
|
||||
assert!(!headers_match(&patterns, &headers));
|
||||
}
|
||||
}
|
||||
126
rust/crates/rustproxy-routing/src/matchers/ip.rs
Normal file
126
rust/crates/rustproxy-routing/src/matchers/ip.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use ipnet::IpNet;
|
||||
|
||||
/// Match an IP address against a pattern.
|
||||
///
|
||||
/// Supported patterns:
|
||||
/// - `*` matches any IP
|
||||
/// - `192.168.1.0/24` CIDR range
|
||||
/// - `192.168.1.100` exact match
|
||||
/// - `192.168.1.*` wildcard (converted to CIDR)
|
||||
/// - `::ffff:192.168.1.100` IPv6-mapped IPv4
|
||||
pub fn ip_matches(pattern: &str, ip: &str) -> bool {
|
||||
let pattern = pattern.trim();
|
||||
|
||||
if pattern == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Normalize IPv4-mapped IPv6
|
||||
let normalized_ip = normalize_ip_str(ip);
|
||||
|
||||
// Try CIDR match
|
||||
if pattern.contains('/') {
|
||||
if let Ok(net) = IpNet::from_str(pattern) {
|
||||
if let Ok(addr) = IpAddr::from_str(&normalized_ip) {
|
||||
return net.contains(&addr);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle wildcard patterns like 192.168.1.*
|
||||
if pattern.contains('*') {
|
||||
let pattern_cidr = wildcard_to_cidr(pattern);
|
||||
if let Some(cidr) = pattern_cidr {
|
||||
if let Ok(net) = IpNet::from_str(&cidr) {
|
||||
if let Ok(addr) = IpAddr::from_str(&normalized_ip) {
|
||||
return net.contains(&addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Exact match
|
||||
let normalized_pattern = normalize_ip_str(pattern);
|
||||
normalized_ip == normalized_pattern
|
||||
}
|
||||
|
||||
/// Check if an IP matches any of the given patterns.
|
||||
pub fn ip_matches_any(patterns: &[String], ip: &str) -> bool {
|
||||
patterns.iter().any(|p| ip_matches(p, ip))
|
||||
}
|
||||
|
||||
/// Normalize IPv4-mapped IPv6 addresses.
|
||||
fn normalize_ip_str(ip: &str) -> String {
|
||||
let ip = ip.trim();
|
||||
if ip.starts_with("::ffff:") {
|
||||
return ip[7..].to_string();
|
||||
}
|
||||
ip.to_string()
|
||||
}
|
||||
|
||||
/// Convert a wildcard IP pattern to CIDR notation.
|
||||
/// e.g., "192.168.1.*" -> "192.168.1.0/24"
|
||||
fn wildcard_to_cidr(pattern: &str) -> Option<String> {
|
||||
let parts: Vec<&str> = pattern.split('.').collect();
|
||||
if parts.len() != 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut octets = [0u8; 4];
|
||||
let mut prefix_len = 0;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if *part == "*" {
|
||||
break;
|
||||
}
|
||||
if let Ok(n) = part.parse::<u8>() {
|
||||
octets[i] = n;
|
||||
prefix_len += 8;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_all() {
|
||||
assert!(ip_matches("*", "192.168.1.100"));
|
||||
assert!(ip_matches("*", "::1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_match() {
|
||||
assert!(ip_matches("192.168.1.100", "192.168.1.100"));
|
||||
assert!(!ip_matches("192.168.1.100", "192.168.1.101"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cidr() {
|
||||
assert!(ip_matches("192.168.1.0/24", "192.168.1.100"));
|
||||
assert!(ip_matches("192.168.1.0/24", "192.168.1.1"));
|
||||
assert!(!ip_matches("192.168.1.0/24", "192.168.2.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_pattern() {
|
||||
assert!(ip_matches("192.168.1.*", "192.168.1.100"));
|
||||
assert!(ip_matches("192.168.1.*", "192.168.1.1"));
|
||||
assert!(!ip_matches("192.168.1.*", "192.168.2.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_mapped() {
|
||||
assert!(ip_matches("192.168.1.100", "::ffff:192.168.1.100"));
|
||||
assert!(ip_matches("192.168.1.0/24", "::ffff:192.168.1.50"));
|
||||
}
|
||||
}
|
||||
9
rust/crates/rustproxy-routing/src/matchers/mod.rs
Normal file
9
rust/crates/rustproxy-routing/src/matchers/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
pub mod domain;
|
||||
pub mod path;
|
||||
pub mod ip;
|
||||
pub mod header;
|
||||
|
||||
pub use domain::*;
|
||||
pub use path::*;
|
||||
pub use ip::*;
|
||||
pub use header::*;
|
||||
65
rust/crates/rustproxy-routing/src/matchers/path.rs
Normal file
65
rust/crates/rustproxy-routing/src/matchers/path.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
/// Match a URL path against a pattern supporting wildcards.
|
||||
///
|
||||
/// Supported patterns:
|
||||
/// - `/api/*` matches `/api/anything` (single level)
|
||||
/// - `/api/**` matches `/api/any/depth/here`
|
||||
/// - `/exact/path` exact match
|
||||
/// - `/prefix*` prefix match
|
||||
pub fn path_matches(pattern: &str, path: &str) -> bool {
|
||||
// Exact match
|
||||
if pattern == path {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Double-star: match any depth
|
||||
if pattern.ends_with("/**") {
|
||||
let prefix = &pattern[..pattern.len() - 3];
|
||||
return path == prefix || path.starts_with(&format!("{}/", prefix));
|
||||
}
|
||||
|
||||
// Single-star at end: match single path segment
|
||||
if pattern.ends_with("/*") {
|
||||
let prefix = &pattern[..pattern.len() - 2];
|
||||
if path == prefix {
|
||||
return true;
|
||||
}
|
||||
if path.starts_with(&format!("{}/", prefix)) {
|
||||
let rest = &path[prefix.len() + 1..];
|
||||
// Single level means no more slashes
|
||||
return !rest.contains('/');
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Star anywhere: use glob matching
|
||||
if pattern.contains('*') {
|
||||
return glob_match::glob_match(pattern, path);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exact_path() {
|
||||
assert!(path_matches("/api/users", "/api/users"));
|
||||
assert!(!path_matches("/api/users", "/api/posts"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_wildcard() {
|
||||
assert!(path_matches("/api/*", "/api/users"));
|
||||
assert!(path_matches("/api/*", "/api/posts"));
|
||||
assert!(!path_matches("/api/*", "/api/users/123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_double_wildcard() {
|
||||
assert!(path_matches("/api/**", "/api/users"));
|
||||
assert!(path_matches("/api/**", "/api/users/123"));
|
||||
assert!(path_matches("/api/**", "/api/users/123/posts"));
|
||||
}
|
||||
}
|
||||
545
rust/crates/rustproxy-routing/src/route_manager.rs
Normal file
545
rust/crates/rustproxy-routing/src/route_manager.rs
Normal file
@@ -0,0 +1,545 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode};
|
||||
use crate::matchers;
|
||||
|
||||
/// Context for route matching (subset of connection info).
|
||||
pub struct MatchContext<'a> {
|
||||
pub port: u16,
|
||||
pub domain: Option<&'a str>,
|
||||
pub path: Option<&'a str>,
|
||||
pub client_ip: Option<&'a str>,
|
||||
pub tls_version: Option<&'a str>,
|
||||
pub headers: Option<&'a HashMap<String, String>>,
|
||||
pub is_tls: bool,
|
||||
}
|
||||
|
||||
/// Result of a route match.
|
||||
pub struct RouteMatchResult<'a> {
|
||||
pub route: &'a RouteConfig,
|
||||
pub target: Option<&'a RouteTarget>,
|
||||
}
|
||||
|
||||
/// Port-indexed route lookup with priority-based matching.
|
||||
/// This is the core routing engine.
|
||||
pub struct RouteManager {
|
||||
/// Routes indexed by port for O(1) port lookup.
|
||||
port_index: HashMap<u16, Vec<usize>>,
|
||||
/// All routes, sorted by priority (highest first).
|
||||
routes: Vec<RouteConfig>,
|
||||
}
|
||||
|
||||
impl RouteManager {
|
||||
/// Create a new RouteManager from a list of routes.
|
||||
pub fn new(routes: Vec<RouteConfig>) -> Self {
|
||||
let mut manager = Self {
|
||||
port_index: HashMap::new(),
|
||||
routes: Vec::new(),
|
||||
};
|
||||
|
||||
// Filter enabled routes and sort by priority
|
||||
let mut enabled_routes: Vec<RouteConfig> = routes
|
||||
.into_iter()
|
||||
.filter(|r| r.is_enabled())
|
||||
.collect();
|
||||
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
|
||||
|
||||
// Build port index
|
||||
for (idx, route) in enabled_routes.iter().enumerate() {
|
||||
for port in route.listening_ports() {
|
||||
manager.port_index
|
||||
.entry(port)
|
||||
.or_default()
|
||||
.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
manager.routes = enabled_routes;
|
||||
manager
|
||||
}
|
||||
|
||||
/// Find the best matching route for the given context.
|
||||
pub fn find_route<'a>(&'a self, ctx: &MatchContext<'_>) -> Option<RouteMatchResult<'a>> {
|
||||
// Get routes for this port
|
||||
let route_indices = self.port_index.get(&ctx.port)?;
|
||||
|
||||
for &idx in route_indices {
|
||||
let route = &self.routes[idx];
|
||||
|
||||
if self.matches_route(route, ctx) {
|
||||
// Find the best matching target within the route
|
||||
let target = self.find_target(route, ctx);
|
||||
return Some(RouteMatchResult { route, target });
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a route matches the given context.
|
||||
fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool {
|
||||
let rm = &route.route_match;
|
||||
|
||||
// Domain matching
|
||||
if let Some(ref domains) = rm.domains {
|
||||
if let Some(domain) = ctx.domain {
|
||||
let patterns = domains.to_vec();
|
||||
if !matchers::domain_matches_any(&patterns, domain) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If no domain provided but route requires domain, it depends on context
|
||||
// For TLS passthrough, we need SNI; for other cases we may still match
|
||||
}
|
||||
|
||||
// Path matching
|
||||
if let Some(ref pattern) = rm.path {
|
||||
if let Some(path) = ctx.path {
|
||||
if !matchers::path_matches(pattern, path) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Route requires path but none provided
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Client IP matching
|
||||
if let Some(ref client_ips) = rm.client_ip {
|
||||
if let Some(ip) = ctx.client_ip {
|
||||
if !matchers::ip_matches_any(client_ips, ip) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// TLS version matching
|
||||
if let Some(ref tls_versions) = rm.tls_version {
|
||||
if let Some(version) = ctx.tls_version {
|
||||
if !tls_versions.iter().any(|v| v == version) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Header matching
|
||||
if let Some(ref patterns) = rm.headers {
|
||||
if let Some(headers) = ctx.headers {
|
||||
if !matchers::headers_match(patterns, headers) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Find the best matching target within a route.
|
||||
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
|
||||
let targets = route.action.targets.as_ref()?;
|
||||
|
||||
if targets.len() == 1 && targets[0].target_match.is_none() {
|
||||
return Some(&targets[0]);
|
||||
}
|
||||
|
||||
// Sort candidates by priority (already in order from config)
|
||||
let mut best: Option<&RouteTarget> = None;
|
||||
let mut best_priority = i32::MIN;
|
||||
|
||||
for target in targets {
|
||||
let priority = target.priority.unwrap_or(0);
|
||||
|
||||
if let Some(ref tm) = target.target_match {
|
||||
if !self.matches_target(tm, ctx) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if priority > best_priority || best.is_none() {
|
||||
best = Some(target);
|
||||
best_priority = priority;
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to first target without match criteria
|
||||
best.or_else(|| {
|
||||
targets.iter().find(|t| t.target_match.is_none())
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a target match criteria matches the context.
|
||||
fn matches_target(
|
||||
&self,
|
||||
tm: &rustproxy_config::TargetMatch,
|
||||
ctx: &MatchContext<'_>,
|
||||
) -> bool {
|
||||
// Port matching
|
||||
if let Some(ref ports) = tm.ports {
|
||||
if !ports.contains(&ctx.port) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Path matching
|
||||
if let Some(ref pattern) = tm.path {
|
||||
if let Some(path) = ctx.path {
|
||||
if !matchers::path_matches(pattern, path) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Header matching
|
||||
if let Some(ref patterns) = tm.headers {
|
||||
if let Some(headers) = ctx.headers {
|
||||
if !matchers::headers_match(patterns, headers) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Get all unique listening ports.
|
||||
pub fn listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.port_index.keys().copied().collect();
|
||||
ports.sort();
|
||||
ports
|
||||
}
|
||||
|
||||
/// Get all routes for a specific port.
|
||||
pub fn routes_for_port(&self, port: u16) -> Vec<&RouteConfig> {
|
||||
self.port_index
|
||||
.get(&port)
|
||||
.map(|indices| indices.iter().map(|&i| &self.routes[i]).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get the total number of enabled routes.
|
||||
pub fn route_count(&self) -> usize {
|
||||
self.routes.len()
|
||||
}
|
||||
|
||||
/// Check if any route on the given port requires SNI.
|
||||
pub fn port_requires_sni(&self, port: u16) -> bool {
|
||||
let routes = self.routes_for_port(port);
|
||||
|
||||
// If multiple passthrough routes on same port, SNI is needed
|
||||
let passthrough_routes: Vec<_> = routes
|
||||
.iter()
|
||||
.filter(|r| {
|
||||
r.tls_mode() == Some(&TlsMode::Passthrough)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if passthrough_routes.len() > 1 {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Single passthrough route with specific domain restriction needs SNI
|
||||
if let Some(route) = passthrough_routes.first() {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
let domain_list = domains.to_vec();
|
||||
// If it's not just a wildcard, SNI is needed
|
||||
if !domain_list.iter().all(|d| *d == "*") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rustproxy_config::*;
|
||||
|
||||
fn make_route(port: u16, domain: Option<&str>, priority: i32) -> RouteConfig {
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(port),
|
||||
domains: domain.map(|d| DomainSpec::Single(d.to_string())),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(vec![RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single("localhost".to_string()),
|
||||
port: PortSpec::Fixed(8080),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: Some(priority),
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_routing() {
|
||||
let routes = vec![
|
||||
make_route(80, Some("example.com"), 0),
|
||||
make_route(80, Some("other.com"), 0),
|
||||
];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let result = manager.find_route(&ctx);
|
||||
assert!(result.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_priority_ordering() {
|
||||
let routes = vec![
|
||||
make_route(80, Some("*.example.com"), 0),
|
||||
make_route(80, Some("api.example.com"), 10), // Higher priority
|
||||
];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("api.example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let result = manager.find_route(&ctx).unwrap();
|
||||
// Should match the higher-priority specific route
|
||||
assert!(result.route.route_match.domains.as_ref()
|
||||
.map(|d| d.to_vec())
|
||||
.unwrap()
|
||||
.contains(&"api.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_match() {
|
||||
let routes = vec![make_route(80, Some("example.com"), 0)];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 443, // Different port
|
||||
domain: Some("example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_disabled_routes_excluded() {
|
||||
let mut route = make_route(80, Some("example.com"), 0);
|
||||
route.enabled = Some(false);
|
||||
let manager = RouteManager::new(vec![route]);
|
||||
assert_eq!(manager.route_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_listening_ports() {
|
||||
let routes = vec![
|
||||
make_route(80, Some("a.com"), 0),
|
||||
make_route(443, Some("b.com"), 0),
|
||||
make_route(80, Some("c.com"), 0), // duplicate port
|
||||
];
|
||||
let manager = RouteManager::new(routes);
|
||||
let ports = manager.listening_ports();
|
||||
assert_eq!(ports, vec![80, 443]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_port_requires_sni_single_passthrough() {
|
||||
let mut route = make_route(443, Some("example.com"), 0);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
let manager = RouteManager::new(vec![route]);
|
||||
// Single passthrough route with specific domain needs SNI
|
||||
assert!(manager.port_requires_sni(443));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_port_requires_sni_wildcard_only() {
|
||||
let mut route = make_route(443, Some("*"), 0);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
let manager = RouteManager::new(vec![route]);
|
||||
// Single passthrough route with wildcard doesn't need SNI
|
||||
assert!(!manager.port_requires_sni(443));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_routes_for_port() {
|
||||
let routes = vec![
|
||||
make_route(80, Some("a.com"), 0),
|
||||
make_route(80, Some("b.com"), 0),
|
||||
make_route(443, Some("c.com"), 0),
|
||||
];
|
||||
let manager = RouteManager::new(routes);
|
||||
assert_eq!(manager.routes_for_port(80).len(), 2);
|
||||
assert_eq!(manager.routes_for_port(443).len(), 1);
|
||||
assert_eq!(manager.routes_for_port(8080).len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_domain_matches_any() {
|
||||
let routes = vec![make_route(80, Some("*"), 0)];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("anything.example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_domain_route_matches_any_domain() {
|
||||
let routes = vec![make_route(80, None, 0)];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_target_sub_matching() {
|
||||
let mut route = make_route(80, Some("example.com"), 0);
|
||||
route.action.targets = Some(vec![
|
||||
RouteTarget {
|
||||
target_match: Some(rustproxy_config::TargetMatch {
|
||||
ports: None,
|
||||
path: Some("/api/*".to_string()),
|
||||
headers: None,
|
||||
method: None,
|
||||
}),
|
||||
host: HostSpec::Single("api-backend".to_string()),
|
||||
port: PortSpec::Fixed(3000),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: Some(10),
|
||||
},
|
||||
RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single("default-backend".to_string()),
|
||||
port: PortSpec::Fixed(8080),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
},
|
||||
]);
|
||||
let manager = RouteManager::new(vec![route]);
|
||||
|
||||
// Should match the API target
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("example.com"),
|
||||
path: Some("/api/users"),
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
let result = manager.find_route(&ctx).unwrap();
|
||||
assert_eq!(result.target.unwrap().host.first(), "api-backend");
|
||||
|
||||
// Should fall back to default target
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("example.com"),
|
||||
path: Some("/home"),
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
let result = manager.find_route(&ctx).unwrap();
|
||||
assert_eq!(result.target.unwrap().host.first(), "default-backend");
|
||||
}
|
||||
}
|
||||
17
rust/crates/rustproxy-security/Cargo.toml
Normal file
17
rust/crates/rustproxy-security/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "rustproxy-security"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "IP filtering, rate limiting, and authentication for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
111
rust/crates/rustproxy-security/src/basic_auth.rs
Normal file
111
rust/crates/rustproxy-security/src/basic_auth.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
|
||||
/// Basic auth validator.
|
||||
pub struct BasicAuthValidator {
|
||||
users: Vec<(String, String)>,
|
||||
realm: String,
|
||||
}
|
||||
|
||||
impl BasicAuthValidator {
|
||||
pub fn new(users: Vec<(String, String)>, realm: Option<String>) -> Self {
|
||||
Self {
|
||||
users,
|
||||
realm: realm.unwrap_or_else(|| "Restricted".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate an Authorization header value.
|
||||
/// Returns the username if valid.
|
||||
pub fn validate(&self, auth_header: &str) -> Option<String> {
|
||||
let auth_header = auth_header.trim();
|
||||
if !auth_header.starts_with("Basic ") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let encoded = &auth_header[6..];
|
||||
let decoded = BASE64.decode(encoded).ok()?;
|
||||
let credentials = String::from_utf8(decoded).ok()?;
|
||||
|
||||
let mut parts = credentials.splitn(2, ':');
|
||||
let username = parts.next()?;
|
||||
let password = parts.next()?;
|
||||
|
||||
for (u, p) in &self.users {
|
||||
if u == username && p == password {
|
||||
return Some(username.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the realm for WWW-Authenticate header.
|
||||
pub fn realm(&self) -> &str {
|
||||
&self.realm
|
||||
}
|
||||
|
||||
/// Generate the WWW-Authenticate header value.
|
||||
pub fn www_authenticate(&self) -> String {
|
||||
format!("Basic realm=\"{}\"", self.realm)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use base64::Engine;
|
||||
|
||||
fn make_validator() -> BasicAuthValidator {
|
||||
BasicAuthValidator::new(
|
||||
vec![
|
||||
("admin".to_string(), "secret".to_string()),
|
||||
("user".to_string(), "pass".to_string()),
|
||||
],
|
||||
Some("TestRealm".to_string()),
|
||||
)
|
||||
}
|
||||
|
||||
fn encode_basic(user: &str, pass: &str) -> String {
|
||||
let encoded = BASE64.encode(format!("{}:{}", user, pass));
|
||||
format!("Basic {}", encoded)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_credentials() {
|
||||
let validator = make_validator();
|
||||
let header = encode_basic("admin", "secret");
|
||||
assert_eq!(validator.validate(&header), Some("admin".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_password() {
|
||||
let validator = make_validator();
|
||||
let header = encode_basic("admin", "wrong");
|
||||
assert_eq!(validator.validate(&header), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_basic_scheme() {
|
||||
let validator = make_validator();
|
||||
assert_eq!(validator.validate("Bearer sometoken"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_malformed_base64() {
|
||||
let validator = make_validator();
|
||||
assert_eq!(validator.validate("Basic !!!not-base64!!!"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_www_authenticate_format() {
|
||||
let validator = make_validator();
|
||||
assert_eq!(validator.www_authenticate(), "Basic realm=\"TestRealm\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_realm() {
|
||||
let validator = BasicAuthValidator::new(vec![], None);
|
||||
assert_eq!(validator.www_authenticate(), "Basic realm=\"Restricted\"");
|
||||
}
|
||||
}
|
||||
189
rust/crates/rustproxy-security/src/ip_filter.rs
Normal file
189
rust/crates/rustproxy-security/src/ip_filter.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use ipnet::IpNet;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// IP filter supporting CIDR ranges, wildcards, and exact matches.
|
||||
pub struct IpFilter {
|
||||
allow_list: Vec<IpPattern>,
|
||||
block_list: Vec<IpPattern>,
|
||||
}
|
||||
|
||||
/// Represents an IP pattern for matching.
|
||||
#[derive(Debug)]
|
||||
enum IpPattern {
|
||||
/// Exact IP match
|
||||
Exact(IpAddr),
|
||||
/// CIDR range match
|
||||
Cidr(IpNet),
|
||||
/// Wildcard (matches everything)
|
||||
Wildcard,
|
||||
}
|
||||
|
||||
impl IpPattern {
|
||||
fn parse(s: &str) -> Self {
|
||||
let s = s.trim();
|
||||
if s == "*" {
|
||||
return IpPattern::Wildcard;
|
||||
}
|
||||
if let Ok(net) = IpNet::from_str(s) {
|
||||
return IpPattern::Cidr(net);
|
||||
}
|
||||
if let Ok(addr) = IpAddr::from_str(s) {
|
||||
return IpPattern::Exact(addr);
|
||||
}
|
||||
// Try as CIDR by appending default prefix
|
||||
if let Ok(addr) = IpAddr::from_str(s) {
|
||||
return IpPattern::Exact(addr);
|
||||
}
|
||||
// Fallback: treat as exact, will never match an invalid string
|
||||
IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap())
|
||||
}
|
||||
|
||||
fn matches(&self, ip: &IpAddr) -> bool {
|
||||
match self {
|
||||
IpPattern::Wildcard => true,
|
||||
IpPattern::Exact(addr) => addr == ip,
|
||||
IpPattern::Cidr(net) => net.contains(ip),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IpFilter {
|
||||
/// Create a new IP filter from allow and block lists.
|
||||
pub fn new(allow_list: &[String], block_list: &[String]) -> Self {
|
||||
Self {
|
||||
allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an IP is allowed.
|
||||
/// If allow_list is non-empty, IP must match at least one entry.
|
||||
/// If block_list is non-empty, IP must NOT match any entry.
|
||||
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
|
||||
// Check block list first
|
||||
if !self.block_list.is_empty() {
|
||||
for pattern in &self.block_list {
|
||||
if pattern.matches(ip) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If allow list is non-empty, must match at least one
|
||||
if !self.allow_list.is_empty() {
|
||||
return self.allow_list.iter().any(|p| p.matches(ip));
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x)
|
||||
pub fn normalize_ip(ip: &IpAddr) -> IpAddr {
|
||||
match ip {
|
||||
IpAddr::V6(v6) => {
|
||||
if let Some(v4) = v6.to_ipv4_mapped() {
|
||||
IpAddr::V4(v4)
|
||||
} else {
|
||||
*ip
|
||||
}
|
||||
}
|
||||
_ => *ip,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_lists_allow_all() {
|
||||
let filter = IpFilter::new(&[], &[]);
|
||||
let ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(filter.is_allowed(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_exact() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.1".to_string()],
|
||||
&[],
|
||||
);
|
||||
let allowed: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let denied: IpAddr = "10.0.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
assert!(!filter.is_allowed(&denied));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_cidr() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.0/8".to_string()],
|
||||
&[],
|
||||
);
|
||||
let allowed: IpAddr = "10.255.255.255".parse().unwrap();
|
||||
let denied: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
assert!(!filter.is_allowed(&denied));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_list() {
|
||||
let filter = IpFilter::new(
|
||||
&[],
|
||||
&["192.168.1.100".to_string()],
|
||||
);
|
||||
let blocked: IpAddr = "192.168.1.100".parse().unwrap();
|
||||
let allowed: IpAddr = "192.168.1.101".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&blocked));
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_trumps_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.0/8".to_string()],
|
||||
&["10.0.0.5".to_string()],
|
||||
);
|
||||
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
|
||||
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&blocked));
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&["*".to_string()],
|
||||
&[],
|
||||
);
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
assert!(filter.is_allowed(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_block() {
|
||||
let filter = IpFilter::new(
|
||||
&[],
|
||||
&["*".to_string()],
|
||||
);
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_ipv4_mapped_ipv6() {
|
||||
let mapped: IpAddr = "::ffff:192.168.1.1".parse().unwrap();
|
||||
let normalized = IpFilter::normalize_ip(&mapped);
|
||||
let expected: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert_eq!(normalized, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_pure_ipv4() {
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let normalized = IpFilter::normalize_ip(&ip);
|
||||
assert_eq!(normalized, ip);
|
||||
}
|
||||
}
|
||||
174
rust/crates/rustproxy-security/src/jwt_auth.rs
Normal file
174
rust/crates/rustproxy-security/src/jwt_auth.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// JWT claims (minimal structure).
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: Option<String>,
|
||||
pub exp: Option<u64>,
|
||||
pub iss: Option<String>,
|
||||
pub aud: Option<String>,
|
||||
}
|
||||
|
||||
/// JWT auth validator.
|
||||
pub struct JwtValidator {
|
||||
decoding_key: DecodingKey,
|
||||
validation: Validation,
|
||||
}
|
||||
|
||||
impl JwtValidator {
|
||||
pub fn new(
|
||||
secret: &str,
|
||||
algorithm: Option<&str>,
|
||||
issuer: Option<&str>,
|
||||
audience: Option<&str>,
|
||||
) -> Self {
|
||||
let algo = match algorithm {
|
||||
Some("HS384") => Algorithm::HS384,
|
||||
Some("HS512") => Algorithm::HS512,
|
||||
Some("RS256") => Algorithm::RS256,
|
||||
_ => Algorithm::HS256,
|
||||
};
|
||||
|
||||
let mut validation = Validation::new(algo);
|
||||
if let Some(iss) = issuer {
|
||||
validation.set_issuer(&[iss]);
|
||||
}
|
||||
if let Some(aud) = audience {
|
||||
validation.set_audience(&[aud]);
|
||||
}
|
||||
|
||||
Self {
|
||||
decoding_key: DecodingKey::from_secret(secret.as_bytes()),
|
||||
validation,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a JWT token string (without "Bearer " prefix).
|
||||
/// Returns the claims if valid.
|
||||
pub fn validate(&self, token: &str) -> Result<Claims, String> {
|
||||
decode::<Claims>(token, &self.decoding_key, &self.validation)
|
||||
.map(|data| data.claims)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// Extract token from Authorization header.
|
||||
pub fn extract_token(auth_header: &str) -> Option<&str> {
|
||||
let header = auth_header.trim();
|
||||
if header.starts_with("Bearer ") {
|
||||
Some(&header[7..])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
|
||||
fn make_token(secret: &str, claims: &Claims) -> String {
|
||||
encode(
|
||||
&Header::default(),
|
||||
claims,
|
||||
&EncodingKey::from_secret(secret.as_bytes()),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn future_exp() -> u64 {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
+ 3600
|
||||
}
|
||||
|
||||
fn past_exp() -> u64 {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
- 3600
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_token() {
|
||||
let secret = "test-secret";
|
||||
let claims = Claims {
|
||||
sub: Some("user123".to_string()),
|
||||
exp: Some(future_exp()),
|
||||
iss: None,
|
||||
aud: None,
|
||||
};
|
||||
let token = make_token(secret, &claims);
|
||||
let validator = JwtValidator::new(secret, None, None, None);
|
||||
let result = validator.validate(&token);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().sub, Some("user123".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expired_token() {
|
||||
let secret = "test-secret";
|
||||
let claims = Claims {
|
||||
sub: Some("user123".to_string()),
|
||||
exp: Some(past_exp()),
|
||||
iss: None,
|
||||
aud: None,
|
||||
};
|
||||
let token = make_token(secret, &claims);
|
||||
let validator = JwtValidator::new(secret, None, None, None);
|
||||
assert!(validator.validate(&token).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrong_secret() {
|
||||
let claims = Claims {
|
||||
sub: Some("user123".to_string()),
|
||||
exp: Some(future_exp()),
|
||||
iss: None,
|
||||
aud: None,
|
||||
};
|
||||
let token = make_token("correct-secret", &claims);
|
||||
let validator = JwtValidator::new("wrong-secret", None, None, None);
|
||||
assert!(validator.validate(&token).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_issuer_validation() {
|
||||
let secret = "test-secret";
|
||||
let claims = Claims {
|
||||
sub: Some("user123".to_string()),
|
||||
exp: Some(future_exp()),
|
||||
iss: Some("my-issuer".to_string()),
|
||||
aud: None,
|
||||
};
|
||||
let token = make_token(secret, &claims);
|
||||
|
||||
// Correct issuer
|
||||
let validator = JwtValidator::new(secret, None, Some("my-issuer"), None);
|
||||
assert!(validator.validate(&token).is_ok());
|
||||
|
||||
// Wrong issuer
|
||||
let validator = JwtValidator::new(secret, None, Some("other-issuer"), None);
|
||||
assert!(validator.validate(&token).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_bearer() {
|
||||
assert_eq!(
|
||||
JwtValidator::extract_token("Bearer abc123"),
|
||||
Some("abc123")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_non_bearer() {
|
||||
assert_eq!(JwtValidator::extract_token("Basic abc123"), None);
|
||||
assert_eq!(JwtValidator::extract_token("abc123"), None);
|
||||
}
|
||||
}
|
||||
13
rust/crates/rustproxy-security/src/lib.rs
Normal file
13
rust/crates/rustproxy-security/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
//! # rustproxy-security
|
||||
//!
|
||||
//! IP filtering, rate limiting, and authentication for RustProxy.
|
||||
|
||||
pub mod ip_filter;
|
||||
pub mod rate_limiter;
|
||||
pub mod basic_auth;
|
||||
pub mod jwt_auth;
|
||||
|
||||
pub use ip_filter::*;
|
||||
pub use rate_limiter::*;
|
||||
pub use basic_auth::*;
|
||||
pub use jwt_auth::*;
|
||||
97
rust/crates/rustproxy-security/src/rate_limiter.rs
Normal file
97
rust/crates/rustproxy-security/src/rate_limiter.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use dashmap::DashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Sliding window rate limiter.
|
||||
pub struct RateLimiter {
|
||||
/// Map of key -> list of request timestamps
|
||||
windows: DashMap<String, Vec<Instant>>,
|
||||
/// Maximum requests per window
|
||||
max_requests: u64,
|
||||
/// Window duration in seconds
|
||||
window_seconds: u64,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(max_requests: u64, window_seconds: u64) -> Self {
|
||||
Self {
|
||||
windows: DashMap::new(),
|
||||
max_requests,
|
||||
window_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a request is allowed for the given key.
|
||||
/// Returns true if allowed, false if rate limited.
|
||||
pub fn check(&self, key: &str) -> bool {
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(self.window_seconds);
|
||||
|
||||
let mut entry = self.windows.entry(key.to_string()).or_default();
|
||||
let timestamps = entry.value_mut();
|
||||
|
||||
// Remove expired entries
|
||||
timestamps.retain(|t| now.duration_since(*t) < window);
|
||||
|
||||
if timestamps.len() as u64 >= self.max_requests {
|
||||
false
|
||||
} else {
|
||||
timestamps.push(now);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean up expired entries (call periodically).
|
||||
pub fn cleanup(&self) {
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(self.window_seconds);
|
||||
|
||||
self.windows.retain(|_, timestamps| {
|
||||
timestamps.retain(|t| now.duration_since(*t) < window);
|
||||
!timestamps.is_empty()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_allow_under_limit() {
|
||||
let limiter = RateLimiter::new(5, 60);
|
||||
for _ in 0..5 {
|
||||
assert!(limiter.check("client-1"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_over_limit() {
|
||||
let limiter = RateLimiter::new(3, 60);
|
||||
assert!(limiter.check("client-1"));
|
||||
assert!(limiter.check("client-1"));
|
||||
assert!(limiter.check("client-1"));
|
||||
assert!(!limiter.check("client-1")); // 4th request blocked
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_keys_independent() {
|
||||
let limiter = RateLimiter::new(2, 60);
|
||||
assert!(limiter.check("client-a"));
|
||||
assert!(limiter.check("client-a"));
|
||||
assert!(!limiter.check("client-a")); // blocked
|
||||
// Different key should still be allowed
|
||||
assert!(limiter.check("client-b"));
|
||||
assert!(limiter.check("client-b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cleanup_removes_expired() {
|
||||
let limiter = RateLimiter::new(100, 0); // 0 second window = immediately expired
|
||||
limiter.check("client-1");
|
||||
// Sleep briefly to let entries expire
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
limiter.cleanup();
|
||||
// After cleanup, the key should be allowed again (entries expired)
|
||||
assert!(limiter.check("client-1"));
|
||||
}
|
||||
}
|
||||
22
rust/crates/rustproxy-tls/Cargo.toml
Normal file
22
rust/crates/rustproxy-tls/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "rustproxy-tls"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "TLS certificate management for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
instant-acme = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
rcgen = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
360
rust/crates/rustproxy-tls/src/acme.rs
Normal file
360
rust/crates/rustproxy-tls/src/acme.rs
Normal file
@@ -0,0 +1,360 @@
|
||||
//! ACME (Let's Encrypt) integration using instant-acme.
|
||||
//!
|
||||
//! This module handles HTTP-01 challenge creation and certificate provisioning.
|
||||
//! Supports persisting ACME account credentials to disk for reuse across restarts.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use instant_acme::{
|
||||
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
|
||||
AccountCredentials,
|
||||
};
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AcmeError {
|
||||
#[error("ACME account creation failed: {0}")]
|
||||
AccountCreation(String),
|
||||
#[error("ACME order failed: {0}")]
|
||||
OrderFailed(String),
|
||||
#[error("Challenge failed: {0}")]
|
||||
ChallengeFailed(String),
|
||||
#[error("Certificate finalization failed: {0}")]
|
||||
FinalizationFailed(String),
|
||||
#[error("No HTTP-01 challenge found")]
|
||||
NoHttp01Challenge,
|
||||
#[error("Timeout waiting for order: {0}")]
|
||||
Timeout(String),
|
||||
#[error("Account persistence error: {0}")]
|
||||
Persistence(String),
|
||||
}
|
||||
|
||||
/// Pending HTTP-01 challenge that needs to be served.
|
||||
pub struct PendingChallenge {
|
||||
pub token: String,
|
||||
pub key_authorization: String,
|
||||
pub domain: String,
|
||||
}
|
||||
|
||||
/// ACME client wrapper around instant-acme.
|
||||
pub struct AcmeClient {
|
||||
use_production: bool,
|
||||
email: String,
|
||||
/// Optional directory where account.json is persisted.
|
||||
account_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl AcmeClient {
|
||||
pub fn new(email: String, use_production: bool) -> Self {
|
||||
Self {
|
||||
use_production,
|
||||
email,
|
||||
account_dir: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with account persistence at the given directory.
|
||||
pub fn with_persistence(email: String, use_production: bool, account_dir: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
use_production,
|
||||
email,
|
||||
account_dir: Some(account_dir.as_ref().to_path_buf()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create an ACME account, persisting credentials if account_dir is set.
|
||||
async fn get_or_create_account(&self) -> Result<Account, AcmeError> {
|
||||
let directory_url = self.directory_url();
|
||||
|
||||
// Try to restore from persisted credentials
|
||||
if let Some(ref dir) = self.account_dir {
|
||||
let account_file = dir.join("account.json");
|
||||
if account_file.exists() {
|
||||
match std::fs::read_to_string(&account_file) {
|
||||
Ok(json) => {
|
||||
match serde_json::from_str::<AccountCredentials>(&json) {
|
||||
Ok(credentials) => {
|
||||
match Account::from_credentials(credentials).await {
|
||||
Ok(account) => {
|
||||
debug!("Restored ACME account from {}", account_file.display());
|
||||
return Ok(account);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to restore ACME account, creating new: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Invalid account.json, creating new account: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Could not read account.json: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new account
|
||||
let contact = format!("mailto:{}", self.email);
|
||||
let (account, credentials) = Account::create(
|
||||
&NewAccount {
|
||||
contact: &[&contact],
|
||||
terms_of_service_agreed: true,
|
||||
only_return_existing: false,
|
||||
},
|
||||
directory_url,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AcmeError::AccountCreation(e.to_string()))?;
|
||||
|
||||
debug!("ACME account created");
|
||||
|
||||
// Persist credentials if we have a directory
|
||||
if let Some(ref dir) = self.account_dir {
|
||||
if let Err(e) = std::fs::create_dir_all(dir) {
|
||||
warn!("Failed to create account directory {}: {}", dir.display(), e);
|
||||
} else {
|
||||
let account_file = dir.join("account.json");
|
||||
match serde_json::to_string_pretty(&credentials) {
|
||||
Ok(json) => {
|
||||
if let Err(e) = std::fs::write(&account_file, &json) {
|
||||
warn!("Failed to persist ACME account to {}: {}", account_file.display(), e);
|
||||
} else {
|
||||
info!("ACME account credentials persisted to {}", account_file.display());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize account credentials: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(account)
|
||||
}
|
||||
|
||||
/// Request a certificate for a domain using the HTTP-01 challenge.
|
||||
///
|
||||
/// Returns (cert_chain_pem, private_key_pem) on success.
|
||||
///
|
||||
/// The caller must serve the HTTP-01 challenge at:
|
||||
/// `http://<domain>/.well-known/acme-challenge/<token>`
|
||||
///
|
||||
/// The `challenge_handler` closure is called with a `PendingChallenge`
|
||||
/// and must arrange for the challenge response to be served. It should
|
||||
/// return once the challenge is ready to be validated.
|
||||
pub async fn provision<F, Fut>(
|
||||
&self,
|
||||
domain: &str,
|
||||
challenge_handler: F,
|
||||
) -> Result<(String, String), AcmeError>
|
||||
where
|
||||
F: FnOnce(PendingChallenge) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(), AcmeError>>,
|
||||
{
|
||||
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
|
||||
|
||||
// 1. Get or create ACME account (with persistence)
|
||||
let account = self.get_or_create_account().await?;
|
||||
|
||||
// 2. Create order
|
||||
let identifier = Identifier::Dns(domain.to_string());
|
||||
let mut order = account
|
||||
.new_order(&NewOrder {
|
||||
identifiers: &[identifier],
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
|
||||
|
||||
debug!("ACME order created");
|
||||
|
||||
// 3. Get authorizations and find HTTP-01 challenge
|
||||
let authorizations = order
|
||||
.authorizations()
|
||||
.await
|
||||
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
|
||||
|
||||
// Find the HTTP-01 challenge
|
||||
let (challenge_token, challenge_url) = authorizations
|
||||
.iter()
|
||||
.flat_map(|auth| auth.challenges.iter())
|
||||
.find(|c| c.r#type == ChallengeType::Http01)
|
||||
.map(|c| {
|
||||
let key_auth = order.key_authorization(c);
|
||||
(
|
||||
PendingChallenge {
|
||||
token: c.token.clone(),
|
||||
key_authorization: key_auth.as_str().to_string(),
|
||||
domain: domain.to_string(),
|
||||
},
|
||||
c.url.clone(),
|
||||
)
|
||||
})
|
||||
.ok_or(AcmeError::NoHttp01Challenge)?;
|
||||
|
||||
// Call the handler to set up challenge serving
|
||||
challenge_handler(challenge_token).await?;
|
||||
|
||||
// 4. Notify ACME server that challenge is ready
|
||||
order
|
||||
.set_challenge_ready(&challenge_url)
|
||||
.await
|
||||
.map_err(|e| AcmeError::ChallengeFailed(e.to_string()))?;
|
||||
|
||||
debug!("Challenge marked as ready, waiting for validation...");
|
||||
|
||||
// 5. Poll for order to become ready
|
||||
let mut attempts = 0;
|
||||
let state = loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
let state = order
|
||||
.refresh()
|
||||
.await
|
||||
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
|
||||
|
||||
match state.status {
|
||||
OrderStatus::Ready | OrderStatus::Valid => break state.status,
|
||||
OrderStatus::Invalid => {
|
||||
return Err(AcmeError::ChallengeFailed(
|
||||
"Order became invalid (challenge failed)".to_string(),
|
||||
));
|
||||
}
|
||||
_ => {
|
||||
attempts += 1;
|
||||
if attempts > 30 {
|
||||
return Err(AcmeError::Timeout(
|
||||
"Order did not become ready within 60 seconds".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Order ready, finalizing...");
|
||||
|
||||
// 6. Generate CSR and finalize
|
||||
let key_pair = KeyPair::generate().map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
|
||||
})?;
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
|
||||
})?;
|
||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let csr = params.serialize_request(&key_pair).map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
|
||||
})?;
|
||||
|
||||
if state == OrderStatus::Ready {
|
||||
order
|
||||
.finalize(csr.der())
|
||||
.await
|
||||
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?;
|
||||
}
|
||||
|
||||
// 7. Wait for certificate to be issued
|
||||
let mut attempts = 0;
|
||||
loop {
|
||||
let state = order
|
||||
.refresh()
|
||||
.await
|
||||
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
|
||||
if state.status == OrderStatus::Valid {
|
||||
break;
|
||||
}
|
||||
if state.status == OrderStatus::Invalid {
|
||||
return Err(AcmeError::FinalizationFailed(
|
||||
"Order became invalid during finalization".to_string(),
|
||||
));
|
||||
}
|
||||
attempts += 1;
|
||||
if attempts > 15 {
|
||||
return Err(AcmeError::Timeout(
|
||||
"Certificate not issued within 30 seconds".to_string(),
|
||||
));
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
}
|
||||
|
||||
// 8. Download certificate
|
||||
let cert_chain_pem = order
|
||||
.certificate()
|
||||
.await
|
||||
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
|
||||
.ok_or_else(|| {
|
||||
AcmeError::FinalizationFailed("No certificate returned".to_string())
|
||||
})?;
|
||||
|
||||
let private_key_pem = key_pair.serialize_pem();
|
||||
|
||||
info!("Certificate provisioned successfully for {}", domain);
|
||||
|
||||
Ok((cert_chain_pem, private_key_pem))
|
||||
}
|
||||
|
||||
/// Restore an ACME account from stored credentials.
|
||||
pub async fn restore_account(
|
||||
&self,
|
||||
credentials: AccountCredentials,
|
||||
) -> Result<Account, AcmeError> {
|
||||
Account::from_credentials(credentials)
|
||||
.await
|
||||
.map_err(|e| AcmeError::AccountCreation(e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the ACME directory URL based on production/staging.
|
||||
pub fn directory_url(&self) -> &str {
|
||||
if self.use_production {
|
||||
"https://acme-v02.api.letsencrypt.org/directory"
|
||||
} else {
|
||||
"https://acme-staging-v02.api.letsencrypt.org/directory"
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether this client is configured for production.
|
||||
pub fn is_production(&self) -> bool {
|
||||
self.use_production
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_directory_url_staging() {
|
||||
let client = AcmeClient::new("test@example.com".to_string(), false);
|
||||
assert!(client.directory_url().contains("staging"));
|
||||
assert!(!client.is_production());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_directory_url_production() {
|
||||
let client = AcmeClient::new("test@example.com".to_string(), true);
|
||||
assert!(!client.directory_url().contains("staging"));
|
||||
assert!(client.is_production());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_persistence_sets_account_dir() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let client = AcmeClient::with_persistence(
|
||||
"test@example.com".to_string(),
|
||||
false,
|
||||
tmp.path(),
|
||||
);
|
||||
assert!(client.account_dir.is_some());
|
||||
assert_eq!(client.account_dir.unwrap(), tmp.path());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_without_persistence_no_account_dir() {
|
||||
let client = AcmeClient::new("test@example.com".to_string(), false);
|
||||
assert!(client.account_dir.is_none());
|
||||
}
|
||||
}
|
||||
183
rust/crates/rustproxy-tls/src/cert_manager.rs
Normal file
183
rust/crates/rustproxy-tls/src/cert_manager.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
|
||||
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
|
||||
use crate::acme::AcmeClient;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CertManagerError {
|
||||
#[error("ACME provisioning failed for {domain}: {message}")]
|
||||
AcmeFailure { domain: String, message: String },
|
||||
#[error("Certificate store error: {0}")]
|
||||
Store(#[from] crate::cert_store::CertStoreError),
|
||||
#[error("No ACME email configured")]
|
||||
NoEmail,
|
||||
}
|
||||
|
||||
/// Certificate lifecycle manager.
|
||||
/// Handles ACME provisioning, static cert loading, and renewal.
|
||||
pub struct CertManager {
|
||||
store: CertStore,
|
||||
acme_email: Option<String>,
|
||||
use_production: bool,
|
||||
renew_before_days: u32,
|
||||
}
|
||||
|
||||
impl CertManager {
|
||||
pub fn new(
|
||||
store: CertStore,
|
||||
acme_email: Option<String>,
|
||||
use_production: bool,
|
||||
renew_before_days: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
acme_email,
|
||||
use_production,
|
||||
renew_before_days,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a certificate for a domain (from cache).
|
||||
pub fn get_cert(&self, domain: &str) -> Option<&CertBundle> {
|
||||
self.store.get(domain)
|
||||
}
|
||||
|
||||
/// Create an ACME client using this manager's configuration.
|
||||
/// Returns None if no ACME email is configured.
|
||||
/// Account credentials are persisted in the cert store base directory.
|
||||
pub fn acme_client(&self) -> Option<AcmeClient> {
|
||||
self.acme_email.as_ref().map(|email| {
|
||||
AcmeClient::with_persistence(
|
||||
email.clone(),
|
||||
self.use_production,
|
||||
self.store.base_dir(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Load a static certificate into the store.
|
||||
pub fn load_static(
|
||||
&mut self,
|
||||
domain: String,
|
||||
bundle: CertBundle,
|
||||
) -> Result<(), CertManagerError> {
|
||||
self.store.store(domain, bundle)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check and return domains that need certificate renewal.
|
||||
///
|
||||
/// A certificate needs renewal if it expires within `renew_before_days`.
|
||||
/// Returns a list of domain names needing renewal.
|
||||
pub fn check_renewals(&self) -> Vec<String> {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let renewal_threshold = self.renew_before_days as u64 * 86400;
|
||||
let mut needs_renewal = Vec::new();
|
||||
|
||||
for (domain, bundle) in self.store.iter() {
|
||||
// Only auto-renew ACME certs
|
||||
if bundle.metadata.source != CertSource::Acme {
|
||||
continue;
|
||||
}
|
||||
|
||||
let time_until_expiry = bundle.metadata.expires_at.saturating_sub(now);
|
||||
if time_until_expiry < renewal_threshold {
|
||||
info!(
|
||||
"Certificate for {} needs renewal (expires in {} days)",
|
||||
domain,
|
||||
time_until_expiry / 86400
|
||||
);
|
||||
needs_renewal.push(domain.clone());
|
||||
}
|
||||
}
|
||||
|
||||
needs_renewal
|
||||
}
|
||||
|
||||
/// Renew a certificate for a domain.
|
||||
///
|
||||
/// Performs the full ACME provision+store flow. The `challenge_setup` closure
|
||||
/// is called to arrange for the HTTP-01 challenge to be served. It receives
|
||||
/// (token, key_authorization) and must make the challenge response available.
|
||||
///
|
||||
/// Returns the new CertBundle on success.
|
||||
pub async fn renew_domain<F, Fut>(
|
||||
&mut self,
|
||||
domain: &str,
|
||||
challenge_setup: F,
|
||||
) -> Result<CertBundle, CertManagerError>
|
||||
where
|
||||
F: FnOnce(String, String) -> Fut,
|
||||
Fut: std::future::Future<Output = ()>,
|
||||
{
|
||||
let acme_client = self.acme_client()
|
||||
.ok_or(CertManagerError::NoEmail)?;
|
||||
|
||||
info!("Renewing certificate for {}", domain);
|
||||
|
||||
let domain_owned = domain.to_string();
|
||||
let result = acme_client.provision(&domain_owned, |pending| {
|
||||
let token = pending.token.clone();
|
||||
let key_auth = pending.key_authorization.clone();
|
||||
async move {
|
||||
challenge_setup(token, key_auth).await;
|
||||
Ok(())
|
||||
}
|
||||
}).await.map_err(|e| CertManagerError::AcmeFailure {
|
||||
domain: domain.to_string(),
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
|
||||
let (cert_pem, key_pem) = result;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let bundle = CertBundle {
|
||||
cert_pem,
|
||||
key_pem,
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Acme,
|
||||
issued_at: now,
|
||||
expires_at: now + 90 * 86400,
|
||||
renewed_at: Some(now),
|
||||
},
|
||||
};
|
||||
|
||||
self.store.store(domain.to_string(), bundle.clone())?;
|
||||
info!("Certificate renewed and stored for {}", domain);
|
||||
|
||||
Ok(bundle)
|
||||
}
|
||||
|
||||
/// Load all certificates from disk.
|
||||
pub fn load_all(&mut self) -> Result<usize, CertManagerError> {
|
||||
let loaded = self.store.load_all()?;
|
||||
info!("Loaded {} certificates from store", loaded);
|
||||
Ok(loaded)
|
||||
}
|
||||
|
||||
/// Whether this manager has an ACME email configured.
|
||||
pub fn has_acme(&self) -> bool {
|
||||
self.acme_email.is_some()
|
||||
}
|
||||
|
||||
/// Get reference to the underlying store.
|
||||
pub fn store(&self) -> &CertStore {
|
||||
&self.store
|
||||
}
|
||||
|
||||
/// Get mutable reference to the underlying store.
|
||||
pub fn store_mut(&mut self) -> &mut CertStore {
|
||||
&mut self.store
|
||||
}
|
||||
}
|
||||
314
rust/crates/rustproxy-tls/src/cert_store.rs
Normal file
314
rust/crates/rustproxy-tls/src/cert_store.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CertStoreError {
|
||||
#[error("Certificate not found for domain: {0}")]
|
||||
NotFound(String),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Invalid certificate: {0}")]
|
||||
Invalid(String),
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
/// Certificate metadata stored alongside certs on disk.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CertMetadata {
|
||||
pub domain: String,
|
||||
pub source: CertSource,
|
||||
pub issued_at: u64,
|
||||
pub expires_at: u64,
|
||||
pub renewed_at: Option<u64>,
|
||||
}
|
||||
|
||||
/// How a certificate was obtained.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum CertSource {
|
||||
Acme,
|
||||
Static,
|
||||
Custom,
|
||||
SelfSigned,
|
||||
}
|
||||
|
||||
/// An in-memory certificate bundle.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CertBundle {
|
||||
pub key_pem: String,
|
||||
pub cert_pem: String,
|
||||
pub ca_pem: Option<String>,
|
||||
pub metadata: CertMetadata,
|
||||
}
|
||||
|
||||
/// Filesystem-backed certificate store.
|
||||
///
|
||||
/// File layout per domain:
|
||||
/// ```text
|
||||
/// {base_dir}/{domain}/
|
||||
/// key.pem
|
||||
/// cert.pem
|
||||
/// ca.pem (optional)
|
||||
/// metadata.json
|
||||
/// ```
|
||||
pub struct CertStore {
|
||||
base_dir: PathBuf,
|
||||
/// In-memory cache of loaded certs
|
||||
cache: HashMap<String, CertBundle>,
|
||||
}
|
||||
|
||||
impl CertStore {
|
||||
/// Create a new cert store at the given directory.
|
||||
pub fn new(base_dir: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
base_dir: base_dir.as_ref().to_path_buf(),
|
||||
cache: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a certificate by domain.
|
||||
pub fn get(&self, domain: &str) -> Option<&CertBundle> {
|
||||
self.cache.get(domain)
|
||||
}
|
||||
|
||||
/// Store a certificate to both cache and filesystem.
|
||||
pub fn store(&mut self, domain: String, bundle: CertBundle) -> Result<(), CertStoreError> {
|
||||
// Sanitize domain for directory name (replace wildcards)
|
||||
let dir_name = domain.replace('*', "_wildcard_");
|
||||
let cert_dir = self.base_dir.join(&dir_name);
|
||||
|
||||
// Create directory
|
||||
std::fs::create_dir_all(&cert_dir)?;
|
||||
|
||||
// Write key
|
||||
std::fs::write(cert_dir.join("key.pem"), &bundle.key_pem)?;
|
||||
|
||||
// Write cert
|
||||
std::fs::write(cert_dir.join("cert.pem"), &bundle.cert_pem)?;
|
||||
|
||||
// Write CA cert if present
|
||||
if let Some(ref ca) = bundle.ca_pem {
|
||||
std::fs::write(cert_dir.join("ca.pem"), ca)?;
|
||||
}
|
||||
|
||||
// Write metadata
|
||||
let metadata_json = serde_json::to_string_pretty(&bundle.metadata)?;
|
||||
std::fs::write(cert_dir.join("metadata.json"), metadata_json)?;
|
||||
|
||||
// Update cache
|
||||
self.cache.insert(domain, bundle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a certificate exists for a domain.
|
||||
pub fn has(&self, domain: &str) -> bool {
|
||||
self.cache.contains_key(domain)
|
||||
}
|
||||
|
||||
/// Load all certificates from the base directory.
|
||||
pub fn load_all(&mut self) -> Result<usize, CertStoreError> {
|
||||
if !self.base_dir.exists() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let entries = std::fs::read_dir(&self.base_dir)?;
|
||||
let mut loaded = 0;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if !path.is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let metadata_path = path.join("metadata.json");
|
||||
let key_path = path.join("key.pem");
|
||||
let cert_path = path.join("cert.pem");
|
||||
|
||||
// All three files must exist
|
||||
if !metadata_path.exists() || !key_path.exists() || !cert_path.exists() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Load metadata
|
||||
let metadata_str = std::fs::read_to_string(&metadata_path)?;
|
||||
let metadata: CertMetadata = serde_json::from_str(&metadata_str)?;
|
||||
|
||||
// Load key and cert
|
||||
let key_pem = std::fs::read_to_string(&key_path)?;
|
||||
let cert_pem = std::fs::read_to_string(&cert_path)?;
|
||||
|
||||
// Load CA cert if present
|
||||
let ca_path = path.join("ca.pem");
|
||||
let ca_pem = if ca_path.exists() {
|
||||
Some(std::fs::read_to_string(&ca_path)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let domain = metadata.domain.clone();
|
||||
let bundle = CertBundle {
|
||||
key_pem,
|
||||
cert_pem,
|
||||
ca_pem,
|
||||
metadata,
|
||||
};
|
||||
|
||||
self.cache.insert(domain, bundle);
|
||||
loaded += 1;
|
||||
}
|
||||
|
||||
Ok(loaded)
|
||||
}
|
||||
|
||||
/// Get the base directory.
|
||||
pub fn base_dir(&self) -> &Path {
|
||||
&self.base_dir
|
||||
}
|
||||
|
||||
/// Get the number of cached certificates.
|
||||
pub fn count(&self) -> usize {
|
||||
self.cache.len()
|
||||
}
|
||||
|
||||
/// Iterate over all cached certificates.
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&String, &CertBundle)> {
|
||||
self.cache.iter()
|
||||
}
|
||||
|
||||
/// Remove a certificate from cache and filesystem.
|
||||
pub fn remove(&mut self, domain: &str) -> Result<bool, CertStoreError> {
|
||||
let removed = self.cache.remove(domain).is_some();
|
||||
if removed {
|
||||
let dir_name = domain.replace('*', "_wildcard_");
|
||||
let cert_dir = self.base_dir.join(&dir_name);
|
||||
if cert_dir.exists() {
|
||||
std::fs::remove_dir_all(&cert_dir)?;
|
||||
}
|
||||
}
|
||||
Ok(removed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_test_bundle(domain: &str) -> CertBundle {
|
||||
CertBundle {
|
||||
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
|
||||
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Static,
|
||||
issued_at: 1700000000,
|
||||
expires_at: 1700000000 + 90 * 86400,
|
||||
renewed_at: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_store_and_load_roundtrip() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
let bundle = make_test_bundle("example.com");
|
||||
store.store("example.com".to_string(), bundle.clone()).unwrap();
|
||||
|
||||
// Verify files exist
|
||||
let cert_dir = tmp.path().join("example.com");
|
||||
assert!(cert_dir.join("key.pem").exists());
|
||||
assert!(cert_dir.join("cert.pem").exists());
|
||||
assert!(cert_dir.join("metadata.json").exists());
|
||||
assert!(!cert_dir.join("ca.pem").exists()); // No CA cert
|
||||
|
||||
// Load into a fresh store
|
||||
let mut store2 = CertStore::new(tmp.path());
|
||||
let loaded = store2.load_all().unwrap();
|
||||
assert_eq!(loaded, 1);
|
||||
|
||||
let loaded_bundle = store2.get("example.com").unwrap();
|
||||
assert_eq!(loaded_bundle.key_pem, bundle.key_pem);
|
||||
assert_eq!(loaded_bundle.cert_pem, bundle.cert_pem);
|
||||
assert_eq!(loaded_bundle.metadata.domain, "example.com");
|
||||
assert_eq!(loaded_bundle.metadata.source, CertSource::Static);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_store_with_ca_cert() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
let mut bundle = make_test_bundle("secure.com");
|
||||
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
|
||||
store.store("secure.com".to_string(), bundle).unwrap();
|
||||
|
||||
let cert_dir = tmp.path().join("secure.com");
|
||||
assert!(cert_dir.join("ca.pem").exists());
|
||||
|
||||
let mut store2 = CertStore::new(tmp.path());
|
||||
store2.load_all().unwrap();
|
||||
let loaded = store2.get("secure.com").unwrap();
|
||||
assert!(loaded.ca_pem.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_multiple_certs() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
store.store("a.com".to_string(), make_test_bundle("a.com")).unwrap();
|
||||
store.store("b.com".to_string(), make_test_bundle("b.com")).unwrap();
|
||||
store.store("c.com".to_string(), make_test_bundle("c.com")).unwrap();
|
||||
|
||||
let mut store2 = CertStore::new(tmp.path());
|
||||
let loaded = store2.load_all().unwrap();
|
||||
assert_eq!(loaded, 3);
|
||||
assert!(store2.has("a.com"));
|
||||
assert!(store2.has("b.com"));
|
||||
assert!(store2.has("c.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_missing_directory() {
|
||||
let mut store = CertStore::new("/nonexistent/path/to/certs");
|
||||
let loaded = store.load_all().unwrap();
|
||||
assert_eq!(loaded, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_cert() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com")).unwrap();
|
||||
assert!(store.has("remove-me.com"));
|
||||
|
||||
let removed = store.remove("remove-me.com").unwrap();
|
||||
assert!(removed);
|
||||
assert!(!store.has("remove-me.com"));
|
||||
assert!(!tmp.path().join("remove-me.com").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_domain_storage() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
store.store("*.example.com".to_string(), make_test_bundle("*.example.com")).unwrap();
|
||||
|
||||
// Directory should use sanitized name
|
||||
assert!(tmp.path().join("_wildcard_.example.com").exists());
|
||||
|
||||
let mut store2 = CertStore::new(tmp.path());
|
||||
store2.load_all().unwrap();
|
||||
assert!(store2.has("*.example.com"));
|
||||
}
|
||||
}
|
||||
13
rust/crates/rustproxy-tls/src/lib.rs
Normal file
13
rust/crates/rustproxy-tls/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
//! # rustproxy-tls
|
||||
//!
|
||||
//! TLS certificate management for RustProxy.
|
||||
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
|
||||
|
||||
pub mod cert_store;
|
||||
pub mod cert_manager;
|
||||
pub mod acme;
|
||||
pub mod sni_resolver;
|
||||
|
||||
pub use cert_store::*;
|
||||
pub use cert_manager::*;
|
||||
pub use sni_resolver::*;
|
||||
139
rust/crates/rustproxy-tls/src/sni_resolver.rs
Normal file
139
rust/crates/rustproxy-tls/src/sni_resolver.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use crate::cert_store::CertBundle;
|
||||
|
||||
/// Dynamic SNI-based certificate resolver.
|
||||
/// Used by the TLS stack to select the right certificate based on client SNI.
|
||||
pub struct SniResolver {
|
||||
/// Domain -> certificate bundle mapping
|
||||
certs: RwLock<HashMap<String, Arc<CertBundle>>>,
|
||||
/// Fallback certificate (used when no SNI or no match)
|
||||
fallback: RwLock<Option<Arc<CertBundle>>>,
|
||||
}
|
||||
|
||||
impl SniResolver {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
certs: RwLock::new(HashMap::new()),
|
||||
fallback: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a certificate for a domain.
|
||||
pub fn add_cert(&self, domain: String, bundle: CertBundle) {
|
||||
let mut certs = self.certs.write().unwrap();
|
||||
certs.insert(domain, Arc::new(bundle));
|
||||
}
|
||||
|
||||
/// Set the fallback certificate.
|
||||
pub fn set_fallback(&self, bundle: CertBundle) {
|
||||
let mut fallback = self.fallback.write().unwrap();
|
||||
*fallback = Some(Arc::new(bundle));
|
||||
}
|
||||
|
||||
/// Resolve a certificate for the given SNI domain.
|
||||
pub fn resolve(&self, domain: &str) -> Option<Arc<CertBundle>> {
|
||||
let certs = self.certs.read().unwrap();
|
||||
|
||||
// Try exact match
|
||||
if let Some(bundle) = certs.get(domain) {
|
||||
return Some(Arc::clone(bundle));
|
||||
}
|
||||
|
||||
// Try wildcard match (e.g., *.example.com)
|
||||
if let Some(dot_pos) = domain.find('.') {
|
||||
let wildcard = format!("*.{}", &domain[dot_pos + 1..]);
|
||||
if let Some(bundle) = certs.get(&wildcard) {
|
||||
return Some(Arc::clone(bundle));
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback
|
||||
let fallback = self.fallback.read().unwrap();
|
||||
fallback.clone()
|
||||
}
|
||||
|
||||
/// Remove a certificate for a domain.
|
||||
pub fn remove_cert(&self, domain: &str) {
|
||||
let mut certs = self.certs.write().unwrap();
|
||||
certs.remove(domain);
|
||||
}
|
||||
|
||||
/// Get the number of registered certificates.
|
||||
pub fn cert_count(&self) -> usize {
|
||||
self.certs.read().unwrap().len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SniResolver {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::cert_store::{CertBundle, CertMetadata, CertSource};
|
||||
|
||||
fn make_bundle(domain: &str) -> CertBundle {
|
||||
CertBundle {
|
||||
key_pem: format!("KEY-{}", domain),
|
||||
cert_pem: format!("CERT-{}", domain),
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Static,
|
||||
issued_at: 0,
|
||||
expires_at: 0,
|
||||
renewed_at: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_domain_resolve() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
|
||||
let result = resolver.resolve("example.com");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().cert_pem, "CERT-example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_resolve() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com"));
|
||||
let result = resolver.resolve("sub.example.com");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().cert_pem, "CERT-*.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fallback() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.set_fallback(make_bundle("fallback"));
|
||||
let result = resolver.resolve("unknown.com");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().cert_pem, "CERT-fallback");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_match_no_fallback() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
|
||||
let result = resolver.resolve("other.com");
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_cert() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
|
||||
assert_eq!(resolver.cert_count(), 1);
|
||||
resolver.remove_cert("example.com");
|
||||
assert_eq!(resolver.cert_count(), 0);
|
||||
assert!(resolver.resolve("example.com").is_none());
|
||||
}
|
||||
}
|
||||
44
rust/crates/rustproxy/Cargo.toml
Normal file
44
rust/crates/rustproxy/Cargo.toml
Normal file
@@ -0,0 +1,44 @@
|
||||
[package]
|
||||
name = "rustproxy"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "High-performance multi-protocol proxy built on Pingora, compatible with SmartProxy configuration"
|
||||
|
||||
[[bin]]
|
||||
name = "rustproxy"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lib]
|
||||
name = "rustproxy"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-tls = { workspace = true }
|
||||
rustproxy-passthrough = { workspace = true }
|
||||
rustproxy-http = { workspace = true }
|
||||
rustproxy-nftables = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
rustproxy-security = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
arc-swap = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
tokio-rustls = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
hyper = { workspace = true }
|
||||
hyper-util = { workspace = true }
|
||||
http-body-util = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
rcgen = { workspace = true }
|
||||
177
rust/crates/rustproxy/src/challenge_server.rs
Normal file
177
rust/crates/rustproxy/src/challenge_server.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
//! HTTP-01 ACME challenge server.
|
||||
//!
|
||||
//! A lightweight HTTP server that serves ACME challenge responses at
|
||||
//! `/.well-known/acme-challenge/<token>`.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use http_body_util::Full;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, error};
|
||||
|
||||
/// ACME HTTP-01 challenge server.
|
||||
pub struct ChallengeServer {
|
||||
/// Token -> key authorization mapping
|
||||
challenges: Arc<DashMap<String, String>>,
|
||||
/// Cancellation token to stop the server
|
||||
cancel: CancellationToken,
|
||||
/// Server task handle
|
||||
handle: Option<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl ChallengeServer {
|
||||
/// Create a new challenge server (not yet started).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
challenges: Arc::new(DashMap::new()),
|
||||
cancel: CancellationToken::new(),
|
||||
handle: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a challenge token -> key_authorization mapping.
|
||||
pub fn set_challenge(&self, token: String, key_authorization: String) {
|
||||
debug!("Registered ACME challenge: token={}", token);
|
||||
self.challenges.insert(token, key_authorization);
|
||||
}
|
||||
|
||||
/// Remove a challenge token.
|
||||
pub fn remove_challenge(&self, token: &str) {
|
||||
self.challenges.remove(token);
|
||||
}
|
||||
|
||||
/// Start the challenge server on the given port.
|
||||
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let addr = format!("0.0.0.0:{}", port);
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
info!("ACME challenge server listening on port {}", port);
|
||||
|
||||
let challenges = Arc::clone(&self.challenges);
|
||||
let cancel = self.cancel.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
info!("ACME challenge server stopping");
|
||||
break;
|
||||
}
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok((stream, _)) => {
|
||||
let challenges = Arc::clone(&challenges);
|
||||
tokio::spawn(async move {
|
||||
let io = TokioIo::new(stream);
|
||||
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
|
||||
let challenges = Arc::clone(&challenges);
|
||||
async move {
|
||||
Self::handle_request(req, &challenges)
|
||||
}
|
||||
});
|
||||
|
||||
let conn = hyper::server::conn::http1::Builder::new()
|
||||
.serve_connection(io, service);
|
||||
|
||||
if let Err(e) = conn.await {
|
||||
debug!("Challenge server connection error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Challenge server accept error: {}", e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
self.handle = Some(handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the challenge server.
|
||||
pub async fn stop(&mut self) {
|
||||
self.cancel.cancel();
|
||||
if let Some(handle) = self.handle.take() {
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
handle,
|
||||
).await;
|
||||
}
|
||||
self.challenges.clear();
|
||||
self.cancel = CancellationToken::new();
|
||||
info!("ACME challenge server stopped");
|
||||
}
|
||||
|
||||
/// Handle an HTTP request for ACME challenges.
|
||||
fn handle_request(
|
||||
req: Request<Incoming>,
|
||||
challenges: &DashMap<String, String>,
|
||||
) -> Result<Response<Full<Bytes>>, hyper::Error> {
|
||||
let path = req.uri().path();
|
||||
|
||||
if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
|
||||
if let Some(key_auth) = challenges.get(token) {
|
||||
debug!("Serving ACME challenge for token: {}", token);
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("content-type", "text/plain")
|
||||
.body(Full::new(Bytes::from(key_auth.value().clone())))
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Full::new(Bytes::from("Not Found")))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_challenge_server_lifecycle() {
|
||||
let mut server = ChallengeServer::new();
|
||||
|
||||
// Set a challenge before starting
|
||||
server.set_challenge("test-token".to_string(), "test-key-auth".to_string());
|
||||
|
||||
// Start on a random port
|
||||
server.start(19900).await.unwrap();
|
||||
|
||||
// Give server a moment to start
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
// Fetch the challenge
|
||||
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
|
||||
let io = TokioIo::new(client);
|
||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
|
||||
tokio::spawn(async move { let _ = conn.await; });
|
||||
|
||||
let req = Request::get("/.well-known/acme-challenge/test-token")
|
||||
.body(Full::new(Bytes::new()))
|
||||
.unwrap();
|
||||
let resp = sender.send_request(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
// Test 404 for unknown token
|
||||
let req = Request::get("/.well-known/acme-challenge/unknown")
|
||||
.body(Full::new(Bytes::new()))
|
||||
.unwrap();
|
||||
let resp = sender.send_request(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
server.stop().await;
|
||||
}
|
||||
}
|
||||
931
rust/crates/rustproxy/src/lib.rs
Normal file
931
rust/crates/rustproxy/src/lib.rs
Normal file
@@ -0,0 +1,931 @@
|
||||
//! # RustProxy
|
||||
//!
|
||||
//! High-performance multi-protocol proxy built on Rust,
|
||||
//! compatible with SmartProxy configuration.
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use rustproxy::RustProxy;
|
||||
//! use rustproxy_config::{RustProxyOptions, create_https_passthrough_route};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let options = RustProxyOptions {
|
||||
//! routes: vec![
|
||||
//! create_https_passthrough_route("example.com", "backend", 443),
|
||||
//! ],
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
//!
|
||||
//! let mut proxy = RustProxy::new(options)?;
|
||||
//! proxy.start().await?;
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod challenge_server;
|
||||
pub mod management;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use anyhow::Result;
|
||||
use tracing::{info, warn, debug, error};
|
||||
|
||||
// Re-export key types
|
||||
pub use rustproxy_config;
|
||||
pub use rustproxy_routing;
|
||||
pub use rustproxy_passthrough;
|
||||
pub use rustproxy_tls;
|
||||
pub use rustproxy_http;
|
||||
pub use rustproxy_nftables;
|
||||
pub use rustproxy_metrics;
|
||||
pub use rustproxy_security;
|
||||
|
||||
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine};
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_passthrough::{TcpListenerManager, TlsCertConfig, ConnectionConfig};
|
||||
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
|
||||
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
|
||||
use rustproxy_nftables::{NftManager, rule_builder};
|
||||
|
||||
/// Certificate status.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CertStatus {
|
||||
pub domain: String,
|
||||
pub source: String,
|
||||
pub expires_at: u64,
|
||||
pub is_valid: bool,
|
||||
}
|
||||
|
||||
/// The main RustProxy struct.
|
||||
/// This is the primary public API matching SmartProxy's interface.
|
||||
pub struct RustProxy {
|
||||
options: RustProxyOptions,
|
||||
route_table: ArcSwap<RouteManager>,
|
||||
listener_manager: Option<TcpListenerManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
cert_manager: Option<Arc<tokio::sync::Mutex<CertManager>>>,
|
||||
challenge_server: Option<challenge_server::ChallengeServer>,
|
||||
renewal_handle: Option<tokio::task::JoinHandle<()>>,
|
||||
nft_manager: Option<NftManager>,
|
||||
started: bool,
|
||||
started_at: Option<Instant>,
|
||||
/// Path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
|
||||
socket_handler_relay_path: Option<String>,
|
||||
}
|
||||
|
||||
impl RustProxy {
|
||||
/// Create a new RustProxy instance with the given configuration.
|
||||
pub fn new(mut options: RustProxyOptions) -> Result<Self> {
|
||||
// Apply defaults to routes before validation
|
||||
Self::apply_defaults(&mut options);
|
||||
|
||||
// Validate routes
|
||||
if let Err(errors) = rustproxy_config::validate_routes(&options.routes) {
|
||||
for err in &errors {
|
||||
warn!("Route validation error: {}", err);
|
||||
}
|
||||
if !errors.is_empty() {
|
||||
anyhow::bail!("Route validation failed with {} errors", errors.len());
|
||||
}
|
||||
}
|
||||
|
||||
let route_manager = RouteManager::new(options.routes.clone());
|
||||
|
||||
// Set up certificate manager if ACME is configured
|
||||
let cert_manager = Self::build_cert_manager(&options)
|
||||
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
|
||||
|
||||
Ok(Self {
|
||||
options,
|
||||
route_table: ArcSwap::from(Arc::new(route_manager)),
|
||||
listener_manager: None,
|
||||
metrics: Arc::new(MetricsCollector::new()),
|
||||
cert_manager,
|
||||
challenge_server: None,
|
||||
renewal_handle: None,
|
||||
nft_manager: None,
|
||||
started: false,
|
||||
started_at: None,
|
||||
socket_handler_relay_path: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Apply default configuration to routes that lack targets or security.
|
||||
fn apply_defaults(options: &mut RustProxyOptions) {
|
||||
let defaults = match &options.defaults {
|
||||
Some(d) => d.clone(),
|
||||
None => return,
|
||||
};
|
||||
|
||||
for route in &mut options.routes {
|
||||
// Apply default target if route has no targets
|
||||
if route.action.targets.is_none() {
|
||||
if let Some(ref default_target) = defaults.target {
|
||||
debug!("Applying default target {}:{} to route {:?}",
|
||||
default_target.host, default_target.port,
|
||||
route.name.as_deref().unwrap_or("unnamed"));
|
||||
route.action.targets = Some(vec![
|
||||
rustproxy_config::RouteTarget {
|
||||
target_match: None,
|
||||
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
|
||||
port: rustproxy_config::PortSpec::Fixed(default_target.port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply default security if route has no security
|
||||
if route.security.is_none() {
|
||||
if let Some(ref default_security) = defaults.security {
|
||||
let mut security = rustproxy_config::RouteSecurity {
|
||||
ip_allow_list: None,
|
||||
ip_block_list: None,
|
||||
max_connections: default_security.max_connections,
|
||||
authentication: None,
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
};
|
||||
|
||||
if let Some(ref allow_list) = default_security.ip_allow_list {
|
||||
security.ip_allow_list = Some(allow_list.clone());
|
||||
}
|
||||
if let Some(ref block_list) = default_security.ip_block_list {
|
||||
security.ip_block_list = Some(block_list.clone());
|
||||
}
|
||||
|
||||
// Only apply if there's something meaningful
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
debug!("Applying default security to route {:?}",
|
||||
route.name.as_deref().unwrap_or("unnamed"));
|
||||
route.security = Some(security);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a CertManager from options.
|
||||
fn build_cert_manager(options: &RustProxyOptions) -> Option<CertManager> {
|
||||
let acme = options.acme.as_ref()?;
|
||||
if !acme.enabled.unwrap_or(false) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let store_path = acme.certificate_store
|
||||
.as_deref()
|
||||
.unwrap_or("./certs");
|
||||
let email = acme.email.clone()
|
||||
.or_else(|| acme.account_email.clone());
|
||||
let use_production = acme.use_production.unwrap_or(false);
|
||||
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
|
||||
|
||||
let store = CertStore::new(store_path);
|
||||
Some(CertManager::new(store, email, use_production, renew_before_days))
|
||||
}
|
||||
|
||||
/// Build ConnectionConfig from RustProxyOptions.
|
||||
fn build_connection_config(options: &RustProxyOptions) -> ConnectionConfig {
|
||||
ConnectionConfig {
|
||||
connection_timeout_ms: options.effective_connection_timeout(),
|
||||
initial_data_timeout_ms: options.effective_initial_data_timeout(),
|
||||
socket_timeout_ms: options.effective_socket_timeout(),
|
||||
max_connection_lifetime_ms: options.effective_max_connection_lifetime(),
|
||||
graceful_shutdown_timeout_ms: options.graceful_shutdown_timeout.unwrap_or(30_000),
|
||||
max_connections_per_ip: options.max_connections_per_ip,
|
||||
connection_rate_limit_per_minute: options.connection_rate_limit_per_minute,
|
||||
keep_alive_treatment: options.keep_alive_treatment.clone(),
|
||||
keep_alive_inactivity_multiplier: options.keep_alive_inactivity_multiplier,
|
||||
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
|
||||
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
|
||||
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the proxy, binding to all configured ports.
|
||||
pub async fn start(&mut self) -> Result<()> {
|
||||
if self.started {
|
||||
anyhow::bail!("Proxy is already started");
|
||||
}
|
||||
|
||||
info!("Starting RustProxy...");
|
||||
|
||||
// Load persisted certificates
|
||||
if let Some(ref cm) = self.cert_manager {
|
||||
let mut cm = cm.lock().await;
|
||||
match cm.load_all() {
|
||||
Ok(count) => {
|
||||
if count > 0 {
|
||||
info!("Loaded {} persisted certificates", count);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to load persisted certificates: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-provision certificates for routes with certificate: 'auto'
|
||||
self.auto_provision_certificates().await;
|
||||
|
||||
let route_manager = self.route_table.load();
|
||||
let ports = route_manager.listening_ports();
|
||||
|
||||
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
|
||||
|
||||
// Create TCP listener manager with metrics
|
||||
let mut listener = TcpListenerManager::with_metrics(
|
||||
Arc::clone(&*route_manager),
|
||||
Arc::clone(&self.metrics),
|
||||
);
|
||||
|
||||
// Apply connection config from options
|
||||
let conn_config = Self::build_connection_config(&self.options);
|
||||
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
|
||||
conn_config.connection_timeout_ms,
|
||||
conn_config.initial_data_timeout_ms,
|
||||
conn_config.socket_timeout_ms,
|
||||
conn_config.max_connection_lifetime_ms,
|
||||
);
|
||||
listener.set_connection_config(conn_config);
|
||||
|
||||
// Extract TLS configurations from routes and cert manager
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
|
||||
// Also load certs from cert manager into TLS config
|
||||
if let Some(ref cm) = self.cert_manager {
|
||||
let cm = cm.lock().await;
|
||||
for (domain, bundle) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(domain) {
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !tls_configs.is_empty() {
|
||||
debug!("Loaded TLS certificates for {} domains", tls_configs.len());
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
// Bind all ports
|
||||
for port in &ports {
|
||||
listener.add_port(*port).await?;
|
||||
}
|
||||
|
||||
self.listener_manager = Some(listener);
|
||||
self.started = true;
|
||||
self.started_at = Some(Instant::now());
|
||||
|
||||
// Apply NFTables rules for routes using nftables forwarding engine
|
||||
self.apply_nftables_rules(&self.options.routes.clone()).await;
|
||||
|
||||
// Start renewal timer if ACME is enabled
|
||||
self.start_renewal_timer();
|
||||
|
||||
info!("RustProxy started successfully on ports: {:?}", ports);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Auto-provision certificates for routes that use certificate: 'auto'.
|
||||
async fn auto_provision_certificates(&mut self) {
|
||||
let cm_arc = match self.cert_manager {
|
||||
Some(ref cm) => Arc::clone(cm),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let mut domains_to_provision = Vec::new();
|
||||
|
||||
for route in &self.options.routes {
|
||||
let tls_mode = route.tls_mode();
|
||||
let needs_cert = matches!(
|
||||
tls_mode,
|
||||
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
|
||||
);
|
||||
if !needs_cert {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cert_spec = route.action.tls.as_ref()
|
||||
.and_then(|tls| tls.certificate.as_ref());
|
||||
|
||||
if let Some(CertificateSpec::Auto(_)) = cert_spec {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
for domain in domains.to_vec() {
|
||||
let domain = domain.to_string();
|
||||
// Skip if we already have a valid cert
|
||||
let cm = cm_arc.lock().await;
|
||||
if cm.store().has(&domain) {
|
||||
debug!("Already have cert for {}, skipping auto-provision", domain);
|
||||
continue;
|
||||
}
|
||||
drop(cm);
|
||||
domains_to_provision.push(domain);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if domains_to_provision.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
|
||||
|
||||
// Start challenge server
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let mut challenge_server = challenge_server::ChallengeServer::new();
|
||||
if let Err(e) = challenge_server.start(acme_port).await {
|
||||
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
|
||||
return;
|
||||
}
|
||||
|
||||
for domain in &domains_to_provision {
|
||||
info!("Provisioning certificate for {}", domain);
|
||||
|
||||
let cm = cm_arc.lock().await;
|
||||
let acme_client = cm.acme_client();
|
||||
drop(cm);
|
||||
|
||||
if let Some(acme_client) = acme_client {
|
||||
let challenge_server_ref = &challenge_server;
|
||||
let result = acme_client.provision(domain, |pending| {
|
||||
challenge_server_ref.set_challenge(
|
||||
pending.token.clone(),
|
||||
pending.key_authorization.clone(),
|
||||
);
|
||||
async move { Ok(()) }
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok((cert_pem, key_pem)) => {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let bundle = CertBundle {
|
||||
cert_pem,
|
||||
key_pem,
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.clone(),
|
||||
source: CertSource::Acme,
|
||||
issued_at: now,
|
||||
expires_at: now + 90 * 86400, // 90 days
|
||||
renewed_at: None,
|
||||
},
|
||||
};
|
||||
|
||||
let mut cm = cm_arc.lock().await;
|
||||
if let Err(e) = cm.load_static(domain.clone(), bundle) {
|
||||
error!("Failed to store certificate for {}: {}", domain, e);
|
||||
}
|
||||
|
||||
info!("Certificate provisioned for {}", domain);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to provision certificate for {}: {}", domain, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
challenge_server.stop().await;
|
||||
}
|
||||
|
||||
/// Start the renewal timer background task.
|
||||
/// The background task checks for expiring certificates and renews them.
|
||||
fn start_renewal_timer(&mut self) {
|
||||
let cm_arc = match self.cert_manager {
|
||||
Some(ref cm) => Arc::clone(cm),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let auto_renew = self.options.acme.as_ref()
|
||||
.and_then(|a| a.auto_renew)
|
||||
.unwrap_or(true);
|
||||
|
||||
if !auto_renew {
|
||||
return;
|
||||
}
|
||||
|
||||
let check_interval_hours = self.options.acme.as_ref()
|
||||
.and_then(|a| a.renew_check_interval_hours)
|
||||
.unwrap_or(24);
|
||||
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let interval = std::time::Duration::from_secs(check_interval_hours as u64 * 3600);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
debug!("Certificate renewal check triggered (interval: {}h)", check_interval_hours);
|
||||
|
||||
// Check which domains need renewal
|
||||
let domains = {
|
||||
let cm = cm_arc.lock().await;
|
||||
cm.check_renewals()
|
||||
};
|
||||
|
||||
if domains.is_empty() {
|
||||
debug!("No certificates need renewal");
|
||||
continue;
|
||||
}
|
||||
|
||||
info!("Renewing {} certificate(s)", domains.len());
|
||||
|
||||
// Start challenge server for renewals
|
||||
let mut cs = challenge_server::ChallengeServer::new();
|
||||
if let Err(e) = cs.start(acme_port).await {
|
||||
error!("Failed to start challenge server for renewal: {}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
for domain in &domains {
|
||||
let cs_ref = &cs;
|
||||
let mut cm = cm_arc.lock().await;
|
||||
let result = cm.renew_domain(domain, |token, key_auth| {
|
||||
cs_ref.set_challenge(token, key_auth);
|
||||
async {}
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok(_bundle) => {
|
||||
info!("Successfully renewed certificate for {}", domain);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to renew certificate for {}: {}", domain, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cs.stop().await;
|
||||
}
|
||||
});
|
||||
|
||||
self.renewal_handle = Some(handle);
|
||||
}
|
||||
|
||||
/// Stop the proxy gracefully.
|
||||
pub async fn stop(&mut self) -> Result<()> {
|
||||
if !self.started {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Stopping RustProxy...");
|
||||
|
||||
// Stop renewal timer
|
||||
if let Some(handle) = self.renewal_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
|
||||
// Stop challenge server if running
|
||||
if let Some(ref mut cs) = self.challenge_server {
|
||||
cs.stop().await;
|
||||
}
|
||||
self.challenge_server = None;
|
||||
|
||||
// Clean up NFTables rules
|
||||
if let Some(ref mut nft) = self.nft_manager {
|
||||
if let Err(e) = nft.cleanup().await {
|
||||
warn!("NFTables cleanup failed: {}", e);
|
||||
}
|
||||
}
|
||||
self.nft_manager = None;
|
||||
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.graceful_stop().await;
|
||||
}
|
||||
self.listener_manager = None;
|
||||
self.started = false;
|
||||
|
||||
info!("RustProxy stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update routes atomically (hot-reload).
|
||||
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
|
||||
// Validate new routes
|
||||
rustproxy_config::validate_routes(&routes)
|
||||
.map_err(|errors| {
|
||||
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
|
||||
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
|
||||
})?;
|
||||
|
||||
let new_manager = RouteManager::new(routes.clone());
|
||||
let new_ports = new_manager.listening_ports();
|
||||
|
||||
info!("Updating routes: {} routes on {} ports",
|
||||
new_manager.route_count(), new_ports.len());
|
||||
|
||||
// Get old ports
|
||||
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
|
||||
listener.listening_ports()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Atomically swap the route table
|
||||
let new_manager = Arc::new(new_manager);
|
||||
self.route_table.store(Arc::clone(&new_manager));
|
||||
|
||||
// Update listener manager
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.update_route_manager(Arc::clone(&new_manager));
|
||||
|
||||
// Update TLS configs
|
||||
let mut tls_configs = Self::extract_tls_configs(&routes);
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
for (domain, bundle) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(domain) {
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
listener.set_tls_configs(tls_configs);
|
||||
|
||||
// Add new ports
|
||||
for port in &new_ports {
|
||||
if !old_ports.contains(port) {
|
||||
listener.add_port(*port).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old ports no longer needed
|
||||
for port in &old_ports {
|
||||
if !new_ports.contains(port) {
|
||||
listener.remove_port(*port);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update NFTables rules: remove old, apply new
|
||||
self.update_nftables_rules(&routes).await;
|
||||
|
||||
self.options.routes = routes;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Provision a certificate for a named route.
|
||||
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
|
||||
let cm_arc = self.cert_manager.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
|
||||
|
||||
// Find the route by name
|
||||
let route = self.options.routes.iter()
|
||||
.find(|r| r.name.as_deref() == Some(route_name))
|
||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
|
||||
|
||||
let domain = route.route_match.domains.as_ref()
|
||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
|
||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
|
||||
|
||||
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
|
||||
|
||||
// Start challenge server
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let mut cs = challenge_server::ChallengeServer::new();
|
||||
cs.start(acme_port).await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
|
||||
|
||||
let cs_ref = &cs;
|
||||
let mut cm = cm_arc.lock().await;
|
||||
let result = cm.renew_domain(&domain, |token, key_auth| {
|
||||
cs_ref.set_challenge(token, key_auth);
|
||||
async {}
|
||||
}).await;
|
||||
drop(cm);
|
||||
|
||||
cs.stop().await;
|
||||
|
||||
let bundle = result
|
||||
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
||||
|
||||
// Hot-swap into TLS configs
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
let cm = cm_arc.lock().await;
|
||||
for (d, b) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(d) {
|
||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
info!("Certificate provisioned and loaded for route '{}'", route_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Renew a certificate for a named route.
|
||||
pub async fn renew_certificate(&mut self, route_name: &str) -> Result<()> {
|
||||
// Renewal is just re-provisioning
|
||||
self.provision_certificate(route_name).await
|
||||
}
|
||||
|
||||
/// Get the status of a certificate for a named route.
|
||||
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
|
||||
let route = self.options.routes.iter()
|
||||
.find(|r| r.name.as_deref() == Some(route_name))?;
|
||||
|
||||
let domain = route.route_match.domains.as_ref()
|
||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
|
||||
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
if let Some(bundle) = cm.get_cert(&domain) {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
return Some(CertStatus {
|
||||
domain,
|
||||
source: format!("{:?}", bundle.metadata.source),
|
||||
expires_at: bundle.metadata.expires_at,
|
||||
is_valid: bundle.metadata.expires_at > now,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get current metrics snapshot.
|
||||
pub fn get_metrics(&self) -> Metrics {
|
||||
self.metrics.snapshot()
|
||||
}
|
||||
|
||||
/// Add a listening port at runtime.
|
||||
pub async fn add_listening_port(&mut self, port: u16) -> Result<()> {
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.add_port(port).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a listening port at runtime.
|
||||
pub async fn remove_listening_port(&mut self, port: u16) -> Result<()> {
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.remove_port(port);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get all currently listening ports.
|
||||
pub fn get_listening_ports(&self) -> Vec<u16> {
|
||||
self.listener_manager
|
||||
.as_ref()
|
||||
.map(|l| l.listening_ports())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get statistics snapshot.
|
||||
pub fn get_statistics(&self) -> Statistics {
|
||||
let uptime = self.started_at
|
||||
.map(|t| t.elapsed().as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
Statistics {
|
||||
active_connections: self.metrics.active_connections(),
|
||||
total_connections: self.metrics.total_connections(),
|
||||
routes_count: self.route_table.load().route_count() as u64,
|
||||
listening_ports: self.get_listening_ports(),
|
||||
uptime_seconds: uptime,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
|
||||
pub fn set_socket_handler_relay_path(&mut self, path: Option<String>) {
|
||||
info!("Socket handler relay path set to: {:?}", path);
|
||||
self.socket_handler_relay_path = path;
|
||||
}
|
||||
|
||||
/// Get the current socket handler relay path.
|
||||
pub fn get_socket_handler_relay_path(&self) -> Option<&str> {
|
||||
self.socket_handler_relay_path.as_deref()
|
||||
}
|
||||
|
||||
/// Load a certificate for a domain and hot-swap the TLS configuration.
|
||||
pub async fn load_certificate(
|
||||
&mut self,
|
||||
domain: &str,
|
||||
cert_pem: String,
|
||||
key_pem: String,
|
||||
ca_pem: Option<String>,
|
||||
) -> Result<()> {
|
||||
info!("Loading certificate for domain: {}", domain);
|
||||
|
||||
// Store in cert manager if available
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let bundle = CertBundle {
|
||||
cert_pem: cert_pem.clone(),
|
||||
key_pem: key_pem.clone(),
|
||||
ca_pem: ca_pem.clone(),
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Static,
|
||||
issued_at: now,
|
||||
expires_at: now + 90 * 86400, // assume 90 days
|
||||
renewed_at: None,
|
||||
},
|
||||
};
|
||||
|
||||
let mut cm = cm_arc.lock().await;
|
||||
cm.load_static(domain.to_string(), bundle)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to store certificate: {}", e))?;
|
||||
}
|
||||
|
||||
// Hot-swap TLS config on the listener
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
|
||||
// Add the new cert
|
||||
tls_configs.insert(domain.to_string(), TlsCertConfig {
|
||||
cert_pem: cert_pem.clone(),
|
||||
key_pem: key_pem.clone(),
|
||||
});
|
||||
|
||||
// Also include all existing certs from cert manager
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
for (d, b) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(d) {
|
||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
info!("Certificate loaded and TLS config updated for {}", domain);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get NFTables status.
|
||||
pub async fn get_nftables_status(&self) -> Result<HashMap<String, serde_json::Value>> {
|
||||
match &self.nft_manager {
|
||||
Some(nft) => Ok(nft.status()),
|
||||
None => Ok(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply NFTables rules for routes using the nftables forwarding engine.
|
||||
async fn apply_nftables_rules(&mut self, routes: &[RouteConfig]) {
|
||||
let nft_routes: Vec<&RouteConfig> = routes.iter()
|
||||
.filter(|r| r.action.forwarding_engine.as_ref() == Some(&ForwardingEngine::Nftables))
|
||||
.collect();
|
||||
|
||||
if nft_routes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Applying NFTables rules for {} routes", nft_routes.len());
|
||||
|
||||
let table_name = nft_routes.iter()
|
||||
.find_map(|r| r.action.nftables.as_ref()?.table_name.clone())
|
||||
.unwrap_or_else(|| "rustproxy".to_string());
|
||||
|
||||
let mut nft = NftManager::new(Some(table_name));
|
||||
|
||||
for route in &nft_routes {
|
||||
let route_id = route.id.as_deref()
|
||||
.or(route.name.as_deref())
|
||||
.unwrap_or("unnamed");
|
||||
|
||||
let nft_options = match &route.action.nftables {
|
||||
Some(opts) => opts.clone(),
|
||||
None => rustproxy_config::NfTablesOptions {
|
||||
preserve_source_ip: None,
|
||||
protocol: None,
|
||||
max_rate: None,
|
||||
priority: None,
|
||||
table_name: None,
|
||||
use_ip_sets: None,
|
||||
use_advanced_nat: None,
|
||||
},
|
||||
};
|
||||
|
||||
let targets = match &route.action.targets {
|
||||
Some(targets) => targets,
|
||||
None => {
|
||||
warn!("NFTables route '{}' has no targets, skipping", route_id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let source_ports = route.route_match.ports.to_ports();
|
||||
for target in targets {
|
||||
let target_host = target.host.first().to_string();
|
||||
let target_port_spec = &target.port;
|
||||
|
||||
for &source_port in &source_ports {
|
||||
let resolved_port = target_port_spec.resolve(source_port);
|
||||
let rules = rule_builder::build_dnat_rule(
|
||||
nft.table_name(),
|
||||
"prerouting",
|
||||
source_port,
|
||||
&target_host,
|
||||
resolved_port,
|
||||
&nft_options,
|
||||
);
|
||||
|
||||
let rule_id = format!("{}-{}-{}", route_id, source_port, resolved_port);
|
||||
if let Err(e) = nft.apply_rules(&rule_id, rules).await {
|
||||
error!("Failed to apply NFTables rules for route '{}': {}", route_id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.nft_manager = Some(nft);
|
||||
}
|
||||
|
||||
/// Update NFTables rules when routes change.
|
||||
async fn update_nftables_rules(&mut self, new_routes: &[RouteConfig]) {
|
||||
// Clean up old rules
|
||||
if let Some(ref mut nft) = self.nft_manager {
|
||||
if let Err(e) = nft.cleanup().await {
|
||||
warn!("NFTables cleanup during update failed: {}", e);
|
||||
}
|
||||
}
|
||||
self.nft_manager = None;
|
||||
|
||||
// Apply new rules
|
||||
self.apply_nftables_rules(new_routes).await;
|
||||
}
|
||||
|
||||
/// Extract TLS configurations from route configs.
|
||||
fn extract_tls_configs(routes: &[RouteConfig]) -> HashMap<String, TlsCertConfig> {
|
||||
let mut configs = HashMap::new();
|
||||
|
||||
for route in routes {
|
||||
let tls_mode = route.tls_mode();
|
||||
let needs_cert = matches!(
|
||||
tls_mode,
|
||||
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
|
||||
);
|
||||
if !needs_cert {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cert_spec = route.action.tls.as_ref()
|
||||
.and_then(|tls| tls.certificate.as_ref());
|
||||
|
||||
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
for domain in domains.to_vec() {
|
||||
configs.insert(domain.to_string(), TlsCertConfig {
|
||||
cert_pem: cert_config.cert.clone(),
|
||||
key_pem: cert_config.key.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
configs
|
||||
}
|
||||
}
|
||||
90
rust/crates/rustproxy/src/main.rs
Normal file
90
rust/crates/rustproxy/src/main.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use clap::Parser;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use anyhow::Result;
|
||||
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy::management;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
|
||||
/// RustProxy - High-performance multi-protocol proxy
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "rustproxy", version, about)]
|
||||
struct Cli {
|
||||
/// Path to JSON configuration file
|
||||
#[arg(short, long, default_value = "config.json")]
|
||||
config: String,
|
||||
|
||||
/// Log level (trace, debug, info, warn, error)
|
||||
#[arg(short, long, default_value = "info")]
|
||||
log_level: String,
|
||||
|
||||
/// Validate configuration without starting
|
||||
#[arg(long)]
|
||||
validate: bool,
|
||||
|
||||
/// Run in management mode (JSON-over-stdin IPC for TypeScript wrapper)
|
||||
#[arg(long)]
|
||||
management: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize tracing - write to stderr so stdout is reserved for management IPC
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
|
||||
)
|
||||
.init();
|
||||
|
||||
// Management mode: JSON IPC over stdin/stdout
|
||||
if cli.management {
|
||||
tracing::info!("RustProxy starting in management mode...");
|
||||
return management::management_loop().await;
|
||||
}
|
||||
|
||||
tracing::info!("RustProxy starting...");
|
||||
|
||||
// Load configuration
|
||||
let options = RustProxyOptions::from_file(&cli.config)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
|
||||
|
||||
tracing::info!(
|
||||
"Loaded {} routes from {}",
|
||||
options.routes.len(),
|
||||
cli.config
|
||||
);
|
||||
|
||||
// Validate-only mode
|
||||
if cli.validate {
|
||||
match rustproxy_config::validate_routes(&options.routes) {
|
||||
Ok(()) => {
|
||||
tracing::info!("Configuration is valid");
|
||||
return Ok(());
|
||||
}
|
||||
Err(errors) => {
|
||||
for err in &errors {
|
||||
tracing::error!("Validation error: {}", err);
|
||||
}
|
||||
anyhow::bail!("{} validation errors found", errors.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create and start proxy
|
||||
let mut proxy = RustProxy::new(options)?;
|
||||
proxy.start().await?;
|
||||
|
||||
// Wait for shutdown signal
|
||||
tracing::info!("RustProxy is running. Press Ctrl+C to stop.");
|
||||
tokio::signal::ctrl_c().await?;
|
||||
|
||||
tracing::info!("Shutdown signal received");
|
||||
proxy.stop().await?;
|
||||
|
||||
tracing::info!("RustProxy shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
470
rust/crates/rustproxy/src/management.rs
Normal file
470
rust/crates/rustproxy/src/management.rs
Normal file
@@ -0,0 +1,470 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tracing::{info, error};
|
||||
|
||||
use crate::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
|
||||
/// A management request from the TypeScript wrapper.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ManagementRequest {
|
||||
pub id: String,
|
||||
pub method: String,
|
||||
#[serde(default)]
|
||||
pub params: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A management response back to the TypeScript wrapper.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ManagementResponse {
|
||||
pub id: String,
|
||||
pub success: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// An unsolicited event from the proxy to the TypeScript wrapper.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ManagementEvent {
|
||||
pub event: String,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
impl ManagementResponse {
|
||||
fn ok(id: String, result: serde_json::Value) -> Self {
|
||||
Self {
|
||||
id,
|
||||
success: true,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn err(id: String, message: String) -> Self {
|
||||
Self {
|
||||
id,
|
||||
success: false,
|
||||
result: None,
|
||||
error: Some(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_line(line: &str) {
|
||||
// Use blocking stdout write - we're writing short JSON lines
|
||||
use std::io::Write;
|
||||
let stdout = std::io::stdout();
|
||||
let mut handle = stdout.lock();
|
||||
let _ = handle.write_all(line.as_bytes());
|
||||
let _ = handle.write_all(b"\n");
|
||||
let _ = handle.flush();
|
||||
}
|
||||
|
||||
fn send_response(response: &ManagementResponse) {
|
||||
match serde_json::to_string(response) {
|
||||
Ok(json) => send_line(&json),
|
||||
Err(e) => error!("Failed to serialize management response: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_event(event: &str, data: serde_json::Value) {
|
||||
let evt = ManagementEvent {
|
||||
event: event.to_string(),
|
||||
data,
|
||||
};
|
||||
match serde_json::to_string(&evt) {
|
||||
Ok(json) => send_line(&json),
|
||||
Err(e) => error!("Failed to serialize management event: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the management loop, reading JSON commands from stdin and writing responses to stdout.
|
||||
pub async fn management_loop() -> Result<()> {
|
||||
let stdin = BufReader::new(tokio::io::stdin());
|
||||
let mut lines = stdin.lines();
|
||||
let mut proxy: Option<RustProxy> = None;
|
||||
|
||||
send_event("ready", serde_json::json!({}));
|
||||
|
||||
loop {
|
||||
let line = match lines.next_line().await {
|
||||
Ok(Some(line)) => line,
|
||||
Ok(None) => {
|
||||
// stdin closed - parent process exited
|
||||
info!("Management stdin closed, shutting down");
|
||||
if let Some(ref mut p) = proxy {
|
||||
let _ = p.stop().await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error reading management stdin: {}", e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let line = line.trim().to_string();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let request: ManagementRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Failed to parse management request: {}", e);
|
||||
// Send error response without an ID
|
||||
send_response(&ManagementResponse::err(
|
||||
"unknown".to_string(),
|
||||
format!("Failed to parse request: {}", e),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let response = handle_request(&request, &mut proxy).await;
|
||||
send_response(&response);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
request: &ManagementRequest,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"start" => handle_start(&id, &request.params, proxy).await,
|
||||
"stop" => handle_stop(&id, proxy).await,
|
||||
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
|
||||
"getMetrics" => handle_get_metrics(&id, proxy),
|
||||
"getStatistics" => handle_get_statistics(&id, proxy),
|
||||
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
|
||||
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
|
||||
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
|
||||
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
|
||||
"getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
|
||||
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
|
||||
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
|
||||
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
|
||||
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
|
||||
_ => ManagementResponse::err(id, format!("Unknown method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_start(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
if proxy.is_some() {
|
||||
return ManagementResponse::err(id.to_string(), "Proxy is already running".to_string());
|
||||
}
|
||||
|
||||
let config = match params.get("config") {
|
||||
Some(config) => config,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
|
||||
};
|
||||
|
||||
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
|
||||
Ok(o) => o,
|
||||
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
|
||||
};
|
||||
|
||||
match RustProxy::new(options) {
|
||||
Ok(mut p) => {
|
||||
match p.start().await {
|
||||
Ok(()) => {
|
||||
send_event("started", serde_json::json!({}));
|
||||
*proxy = Some(p);
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
Err(e) => {
|
||||
send_event("error", serde_json::json!({"message": format!("{}", e)}));
|
||||
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stop(
|
||||
id: &str,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_mut() {
|
||||
Some(p) => {
|
||||
match p.stop().await {
|
||||
Ok(()) => {
|
||||
*proxy = None;
|
||||
send_event("stopped", serde_json::json!({}));
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_update_routes(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let routes = match params.get("routes") {
|
||||
Some(routes) => routes,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
|
||||
};
|
||||
|
||||
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid routes: {}", e)),
|
||||
};
|
||||
|
||||
match p.update_routes(routes).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_metrics(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let metrics = p.get_metrics();
|
||||
match serde_json::to_value(&metrics) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_statistics(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let stats = p.get_statistics();
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_provision_certificate(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.provision_certificate(&route_name).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_renew_certificate(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.renew_certificate(&route_name).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_get_certificate_status(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_ref() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.get_certificate_status(route_name).await {
|
||||
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
|
||||
"domain": status.domain,
|
||||
"source": status.source,
|
||||
"expiresAt": status.expires_at,
|
||||
"isValid": status.is_valid,
|
||||
})),
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_listening_ports(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let ports = p.get_listening_ports();
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": ports }))
|
||||
}
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": [] })),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_get_nftables_status(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
match p.get_nftables_status().await {
|
||||
Ok(status) => {
|
||||
match serde_json::to_value(&status) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)),
|
||||
}
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_set_socket_handler_relay(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let socket_path = params.get("socketPath")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
info!("setSocketHandlerRelay: socket_path={:?}", socket_path);
|
||||
p.set_socket_handler_relay_path(socket_path);
|
||||
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
|
||||
async fn handle_add_listening_port(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||
Some(port) => port as u16,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.add_listening_port(port).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_remove_listening_port(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||
Some(port) => port as u16,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.remove_listening_port(port).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_load_certificate(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let domain = match params.get("domain").and_then(|v| v.as_str()) {
|
||||
Some(d) => d.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
|
||||
};
|
||||
|
||||
let cert = match params.get("cert").and_then(|v| v.as_str()) {
|
||||
Some(c) => c.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
|
||||
};
|
||||
|
||||
let key = match params.get("key").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
|
||||
};
|
||||
|
||||
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
|
||||
|
||||
info!("loadCertificate: domain={}", domain);
|
||||
|
||||
// Load cert into cert manager and hot-swap TLS config
|
||||
match p.load_certificate(&domain, cert, key, ca).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
|
||||
}
|
||||
}
|
||||
402
rust/crates/rustproxy/tests/common/mod.rs
Normal file
402
rust/crates/rustproxy/tests/common/mod.rs
Normal file
@@ -0,0 +1,402 @@
|
||||
use std::sync::atomic::{AtomicU16, Ordering};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Atomic port allocator starting at 19000 to avoid collisions.
|
||||
static PORT_COUNTER: AtomicU16 = AtomicU16::new(19000);
|
||||
|
||||
/// Get the next available port for testing.
|
||||
pub fn next_port() -> u16 {
|
||||
PORT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Start a simple TCP echo server that echoes back whatever it receives.
|
||||
/// Returns the join handle for the server task.
|
||||
pub async fn start_echo_server(port: u16) -> JoinHandle<()> {
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.expect("Failed to bind echo server");
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if stream.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Start a TCP echo server that prefixes responses to identify which backend responded.
|
||||
pub async fn start_prefix_echo_server(port: u16, prefix: &str) -> JoinHandle<()> {
|
||||
let prefix = prefix.to_string();
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.expect("Failed to bind prefix echo server");
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
let pfx = prefix.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
let mut response = pfx.as_bytes().to_vec();
|
||||
response.extend_from_slice(&buf[..n]);
|
||||
if stream.write_all(&response).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Start a simple HTTP server that responds with a fixed status and body.
|
||||
pub async fn start_http_server(port: u16, status: u16, body: &str) -> JoinHandle<()> {
|
||||
let body = body.to_string();
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.expect("Failed to bind HTTP server");
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
let b = body.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 8192];
|
||||
// Read the request
|
||||
let _n = stream.read(&mut buf).await.unwrap_or(0);
|
||||
// Send response
|
||||
let response = format!(
|
||||
"HTTP/1.1 {} OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||
status,
|
||||
b.len(),
|
||||
b,
|
||||
);
|
||||
let _ = stream.write_all(response.as_bytes()).await;
|
||||
let _ = stream.shutdown().await;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Start an HTTP backend server that echoes back request details as JSON.
|
||||
/// The response body contains: {"method":"GET","path":"/foo","host":"example.com","backend":"<name>"}
|
||||
/// Supports keep-alive by reading HTTP requests properly.
|
||||
pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandle<()> {
|
||||
let name = backend_name.to_string();
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("Failed to bind HTTP echo backend on port {}", port));
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
let backend = name.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 16384];
|
||||
// Read request data
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => return,
|
||||
Ok(n) => n,
|
||||
};
|
||||
let req_str = String::from_utf8_lossy(&buf[..n]);
|
||||
|
||||
// Parse first line: METHOD PATH HTTP/x.x
|
||||
let first_line = req_str.lines().next().unwrap_or("");
|
||||
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
||||
let method = parts.first().copied().unwrap_or("UNKNOWN");
|
||||
let path = parts.get(1).copied().unwrap_or("/");
|
||||
|
||||
// Extract Host header
|
||||
let host = req_str.lines()
|
||||
.find(|l| l.to_lowercase().starts_with("host:"))
|
||||
.map(|l| l[5..].trim())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let body = format!(
|
||||
r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#,
|
||||
method, path, host, backend
|
||||
);
|
||||
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||
body.len(),
|
||||
body,
|
||||
);
|
||||
let _ = stream.write_all(response.as_bytes()).await;
|
||||
let _ = stream.shutdown().await;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Wrap a future with a timeout, preventing tests from hanging.
|
||||
pub async fn with_timeout<F, T>(future: F, secs: u64) -> Result<T, &'static str>
|
||||
where
|
||||
F: std::future::Future<Output = T>,
|
||||
{
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(secs), future).await {
|
||||
Ok(result) => Ok(result),
|
||||
Err(_) => Err("Test timed out"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait briefly for a server to be ready by attempting TCP connections.
|
||||
pub async fn wait_for_port(port: u16, timeout_ms: u64) -> bool {
|
||||
let start = std::time::Instant::now();
|
||||
let timeout = std::time::Duration::from_millis(timeout_ms);
|
||||
while start.elapsed() < timeout {
|
||||
if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Helper to create a minimal route config for testing.
|
||||
pub fn make_test_route(
|
||||
port: u16,
|
||||
domain: Option<&str>,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
) -> rustproxy_config::RouteConfig {
|
||||
rustproxy_config::RouteConfig {
|
||||
id: None,
|
||||
route_match: rustproxy_config::RouteMatch {
|
||||
ports: rustproxy_config::PortRange::Single(port),
|
||||
domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: rustproxy_config::RouteAction {
|
||||
action_type: rustproxy_config::RouteActionType::Forward,
|
||||
targets: Some(vec![rustproxy_config::RouteTarget {
|
||||
target_match: None,
|
||||
host: rustproxy_config::HostSpec::Single(target_host.to_string()),
|
||||
port: rustproxy_config::PortSpec::Fixed(target_port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a simple WebSocket echo backend.
|
||||
///
|
||||
/// Accepts WebSocket upgrade requests (HTTP Upgrade: websocket), sends 101 back,
|
||||
/// then echoes all data received on the connection.
|
||||
pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("Failed to bind WS echo backend on port {}", port));
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
// Read the HTTP upgrade request
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => return,
|
||||
Ok(n) => n,
|
||||
};
|
||||
|
||||
let req_str = String::from_utf8_lossy(&buf[..n]);
|
||||
|
||||
// Extract Sec-WebSocket-Key for proper handshake
|
||||
let ws_key = req_str.lines()
|
||||
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
|
||||
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Compute Sec-WebSocket-Accept (simplified - just echo for test purposes)
|
||||
// Real implementation would compute SHA-1 + base64
|
||||
let accept_response = format!(
|
||||
"HTTP/1.1 101 Switching Protocols\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-Accept: {}\r\n\
|
||||
\r\n",
|
||||
ws_key
|
||||
);
|
||||
|
||||
if stream.write_all(accept_response.as_bytes()).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Echo all data back (raw TCP after upgrade)
|
||||
let mut echo_buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match stream.read(&mut echo_buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if stream.write_all(&echo_buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a self-signed certificate for testing using rcgen.
|
||||
/// Returns (cert_pem, key_pem).
|
||||
pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
|
||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let key_pair = KeyPair::generate().unwrap();
|
||||
let cert = params.self_signed(&key_pair).unwrap();
|
||||
|
||||
(cert.pem(), key_pair.serialize_pem())
|
||||
}
|
||||
|
||||
/// Start a TLS echo server using the given cert/key.
|
||||
/// Returns the join handle.
|
||||
pub async fn start_tls_echo_server(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
|
||||
use std::sync::Arc;
|
||||
|
||||
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
||||
.expect("Failed to build TLS acceptor");
|
||||
let acceptor = Arc::new(acceptor);
|
||||
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.expect("Failed to bind TLS echo server");
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
let acc = acceptor.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut tls_stream = match acc.accept(stream).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match tls_stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if tls_stream.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper to create a TLS terminate route with static cert for testing.
|
||||
pub fn make_tls_terminate_route(
|
||||
port: u16,
|
||||
domain: &str,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> rustproxy_config::RouteConfig {
|
||||
let mut route = make_test_route(port, Some(domain), target_host, target_port);
|
||||
route.action.tls = Some(rustproxy_config::RouteTls {
|
||||
mode: rustproxy_config::TlsMode::Terminate,
|
||||
certificate: Some(rustproxy_config::CertificateSpec::Static(
|
||||
rustproxy_config::CertificateConfig {
|
||||
cert: cert_pem.to_string(),
|
||||
key: key_pem.to_string(),
|
||||
ca: None,
|
||||
key_file: None,
|
||||
cert_file: None,
|
||||
},
|
||||
)),
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
/// Helper to create a TLS passthrough route for testing.
|
||||
pub fn make_tls_passthrough_route(
|
||||
port: u16,
|
||||
domain: Option<&str>,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
) -> rustproxy_config::RouteConfig {
|
||||
let mut route = make_test_route(port, domain, target_host, target_port);
|
||||
route.action.tls = Some(rustproxy_config::RouteTls {
|
||||
mode: rustproxy_config::TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
453
rust/crates/rustproxy/tests/integration_http_proxy.rs
Normal file
453
rust/crates/rustproxy/tests/integration_http_proxy.rs
Normal file
@@ -0,0 +1,453 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
/// Send a raw HTTP request and return the full response as a string.
|
||||
async fn send_http_request(port: u16, host: &str, method: &str, path: &str) -> String {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request = format!(
|
||||
"{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
|
||||
method, path, host,
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}
|
||||
|
||||
/// Extract the body from a raw HTTP response string (after the \r\n\r\n).
|
||||
fn extract_body(response: &str) -> &str {
|
||||
response.split("\r\n\r\n").nth(1).unwrap_or("")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_basic() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_http_echo_backend(backend_port, "main").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
|
||||
let body = extract_body(&response);
|
||||
body.to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
|
||||
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
|
||||
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_host_routing() {
|
||||
let backend1_port = next_port();
|
||||
let backend2_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _b1 = start_http_echo_backend(backend1_port, "alpha").await;
|
||||
let _b2 = start_http_echo_backend(backend2_port, "beta").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
|
||||
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain
|
||||
let alpha_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
|
||||
|
||||
// Test beta domain
|
||||
let beta_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_path_routing() {
|
||||
let backend1_port = next_port();
|
||||
let backend2_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _b1 = start_http_echo_backend(backend1_port, "api").await;
|
||||
let _b2 = start_http_echo_backend(backend2_port, "web").await;
|
||||
|
||||
let mut api_route = make_test_route(proxy_port, None, "127.0.0.1", backend1_port);
|
||||
api_route.route_match.path = Some("/api/**".to_string());
|
||||
api_route.priority = Some(10);
|
||||
|
||||
let web_route = make_test_route(proxy_port, None, "127.0.0.1", backend2_port);
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![api_route, web_route],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test API path
|
||||
let api_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
|
||||
|
||||
// Test web path (no /api prefix)
|
||||
let web_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_cors_preflight() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_http_echo_backend(backend_port, "main").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send CORS preflight request
|
||||
let request = format!(
|
||||
"OPTIONS /api/data HTTP/1.1\r\nHost: example.com\r\nOrigin: http://localhost:3000\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n",
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 204 No Content with CORS headers
|
||||
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
|
||||
assert!(result.to_lowercase().contains("access-control-allow-origin"),
|
||||
"Expected CORS header, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_backend_error() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
// Start an HTTP server that returns 500
|
||||
let _backend = start_http_server(backend_port, 500, "Internal Error").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
|
||||
response
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Proxy should relay the 500 from backend
|
||||
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_no_route_matched() {
|
||||
let proxy_port = next_port();
|
||||
|
||||
// Create a route only for a specific domain
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
|
||||
response
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 502 Bad Gateway (no route matched)
|
||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_backend_unavailable() {
|
||||
let proxy_port = next_port();
|
||||
let dead_port = next_port(); // No server running here
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
|
||||
response
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 502 Bad Gateway (backend unavailable)
|
||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_https_terminate_http_forward() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "httpproxy.example.com";
|
||||
|
||||
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
|
||||
let _backend = start_http_echo_backend(backend_port, "tls-backend").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send HTTP request through TLS
|
||||
let request = format!(
|
||||
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
|
||||
domain
|
||||
);
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let body = extract_body(&result);
|
||||
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
|
||||
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
|
||||
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_websocket_through_proxy() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_ws_echo_backend(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send WebSocket upgrade request
|
||||
let request = format!(
|
||||
"GET /ws HTTP/1.1\r\n\
|
||||
Host: example.com\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
\r\n"
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
// Read the 101 response
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
let n = stream.read(&mut temp).await.unwrap();
|
||||
if n == 0 { break; }
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len-4..] == *b"\r\n\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response_buf).to_string();
|
||||
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
|
||||
assert!(
|
||||
response_str.to_lowercase().contains("upgrade: websocket"),
|
||||
"Expected Upgrade header, got: {}",
|
||||
response_str
|
||||
);
|
||||
|
||||
// After upgrade, send data and verify echo
|
||||
let test_data = b"Hello WebSocket!";
|
||||
stream.write_all(test_data).await.unwrap();
|
||||
|
||||
// Read echoed data
|
||||
let mut echo_buf = vec![0u8; 256];
|
||||
let n = stream.read(&mut echo_buf).await.unwrap();
|
||||
let echoed = &echo_buf[..n];
|
||||
|
||||
assert_eq!(echoed, test_data, "Expected echo of sent data");
|
||||
|
||||
"ok".to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "ok");
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
/// InsecureVerifier for test TLS client connections.
|
||||
#[derive(Debug)]
|
||||
struct InsecureVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
]
|
||||
}
|
||||
}
|
||||
250
rust/crates/rustproxy/tests/integration_proxy_lifecycle.rs
Normal file
250
rust/crates/rustproxy/tests/integration_proxy_lifecycle.rs
Normal file
@@ -0,0 +1,250 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_start_and_stop() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
|
||||
// Not listening before start
|
||||
assert!(!wait_for_port(port, 200).await);
|
||||
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
|
||||
// Give the OS a moment to release the port
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_double_start_fails() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Second start should fail
|
||||
let result = proxy.start().await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("already started"));
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_routes_hot_reload() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Update routes atomically
|
||||
let new_routes = vec![
|
||||
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
|
||||
];
|
||||
let result = proxy.update_routes(new_routes).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_remove_listening_port() {
|
||||
let port1 = next_port();
|
||||
let port2 = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port1, None, "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port1, 2000).await);
|
||||
|
||||
// Add a new port
|
||||
proxy.add_listening_port(port2).await.unwrap();
|
||||
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
|
||||
|
||||
// Remove the port
|
||||
proxy.remove_listening_port(port2).await.unwrap();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
|
||||
|
||||
// Original port should still be listening
|
||||
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_statistics() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert_eq!(stats.routes_count, 1);
|
||||
assert!(stats.listening_ports.contains(&port));
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_routes_rejected() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![{
|
||||
let mut route = make_test_route(80, None, "127.0.0.1", 8080);
|
||||
route.action.targets = None; // Invalid: forward without targets
|
||||
route
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = RustProxy::new(options);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_metrics_track_connections() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// No connections yet
|
||||
let stats = proxy.get_statistics();
|
||||
assert_eq!(stats.total_connections, 0);
|
||||
|
||||
// Make a connection and send data
|
||||
{
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"hello").await.unwrap();
|
||||
let mut buf = vec![0u8; 16];
|
||||
let _ = stream.read(&mut buf).await;
|
||||
}
|
||||
|
||||
// Small delay for metrics to update
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_metrics_track_bytes() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_http_echo_backend(backend_port, "metrics-test").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send HTTP request through proxy
|
||||
{
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let request = b"GET /test HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n";
|
||||
stream.write_all(request).await.unwrap();
|
||||
let mut response = Vec::new();
|
||||
stream.read_to_end(&mut response).await.unwrap();
|
||||
assert!(!response.is_empty(), "Expected non-empty response");
|
||||
}
|
||||
|
||||
// Small delay for metrics to update
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert!(stats.total_connections > 0,
|
||||
"Expected some connections tracked, got {}", stats.total_connections);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hot_reload_port_changes() {
|
||||
let port1 = next_port();
|
||||
let port2 = next_port();
|
||||
let backend_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
// Start with port1
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port1, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port1, 2000).await);
|
||||
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
|
||||
|
||||
// Update routes to use port2 instead
|
||||
let new_routes = vec![
|
||||
make_test_route(port2, None, "127.0.0.1", backend_port),
|
||||
];
|
||||
proxy.update_routes(new_routes).await.unwrap();
|
||||
|
||||
// Port2 should now be listening, port1 should be closed
|
||||
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
|
||||
|
||||
// Verify port2 works
|
||||
let ports = proxy.get_listening_ports();
|
||||
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
|
||||
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
197
rust/crates/rustproxy/tests/integration_tcp_passthrough.rs
Normal file
197
rust/crates/rustproxy/tests/integration_tcp_passthrough.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_echo() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
// Start echo backend
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
// Configure proxy
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Wait for proxy to be ready
|
||||
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
|
||||
|
||||
// Connect and send data
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"hello world").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "hello world");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_large_payload() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send 1MB of data
|
||||
let data = vec![b'A'; 1_000_000];
|
||||
stream.write_all(&data).await.unwrap();
|
||||
stream.shutdown().await.unwrap();
|
||||
|
||||
// Read all back
|
||||
let mut received = Vec::new();
|
||||
stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, 1_000_000);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_multiple_connections() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
handles.push(tokio::spawn(async move {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
let msg = format!("connection-{}", i);
|
||||
stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 10);
|
||||
for (i, r) in result.iter().enumerate() {
|
||||
assert_eq!(r, &format!("connection-{}", i));
|
||||
}
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_backend_unreachable() {
|
||||
let proxy_port = next_port();
|
||||
let dead_port = next_port(); // No server on this port
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Connection should complete (proxy accepts it) but data should not flow
|
||||
let result = with_timeout(async {
|
||||
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
|
||||
stream.is_ok()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result, "Should be able to connect to proxy even if backend is down");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_bidirectional() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
// Start a prefix echo server to verify data flows in both directions
|
||||
let _backend = start_prefix_echo_server(backend_port, "REPLY:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"test data").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "REPLY:test data");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
247
rust/crates/rustproxy/tests/integration_tls_passthrough.rs
Normal file
247
rust/crates/rustproxy/tests/integration_tls_passthrough.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// Build a minimal TLS ClientHello with the given SNI domain.
|
||||
/// This is enough for the proxy's SNI parser to extract the domain.
|
||||
fn build_client_hello(domain: &str) -> Vec<u8> {
|
||||
let domain_bytes = domain.as_bytes();
|
||||
let sni_length = domain_bytes.len() as u16;
|
||||
|
||||
// Server Name extension (type 0x0000)
|
||||
let mut sni_ext = Vec::new();
|
||||
sni_ext.extend_from_slice(&[0x00, 0x00]); // extension type: server_name
|
||||
let sni_list_len = sni_length + 5; // 2 (list len) + 1 (type) + 2 (name len) + name
|
||||
sni_ext.extend_from_slice(&(sni_list_len as u16).to_be_bytes()); // extension data length
|
||||
sni_ext.extend_from_slice(&((sni_list_len - 2) as u16).to_be_bytes()); // server name list length
|
||||
sni_ext.push(0x00); // host_name type
|
||||
sni_ext.extend_from_slice(&sni_length.to_be_bytes());
|
||||
sni_ext.extend_from_slice(domain_bytes);
|
||||
|
||||
let extensions_length = sni_ext.len() as u16;
|
||||
|
||||
// ClientHello message
|
||||
let mut client_hello = Vec::new();
|
||||
client_hello.extend_from_slice(&[0x03, 0x03]); // TLS 1.2 version
|
||||
client_hello.extend_from_slice(&[0x00; 32]); // random
|
||||
client_hello.push(0x00); // session_id length
|
||||
client_hello.extend_from_slice(&[0x00, 0x02, 0x00, 0xff]); // cipher suites (1 suite)
|
||||
client_hello.extend_from_slice(&[0x01, 0x00]); // compression methods (null)
|
||||
client_hello.extend_from_slice(&extensions_length.to_be_bytes());
|
||||
client_hello.extend_from_slice(&sni_ext);
|
||||
|
||||
let hello_len = client_hello.len() as u32;
|
||||
|
||||
// Handshake wrapper (type 1 = ClientHello)
|
||||
let mut handshake = Vec::new();
|
||||
handshake.push(0x01); // ClientHello
|
||||
handshake.extend_from_slice(&hello_len.to_be_bytes()[1..4]); // 3-byte length
|
||||
handshake.extend_from_slice(&client_hello);
|
||||
|
||||
let hs_len = handshake.len() as u16;
|
||||
|
||||
// TLS record
|
||||
let mut record = Vec::new();
|
||||
record.push(0x16); // ContentType: Handshake
|
||||
record.extend_from_slice(&[0x03, 0x01]); // TLS 1.0 (record version)
|
||||
record.extend_from_slice(&hs_len.to_be_bytes());
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
record
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_passthrough_sni_routing() {
|
||||
let backend1_port = next_port();
|
||||
let backend2_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _b1 = start_prefix_echo_server(backend1_port, "BACKEND1:").await;
|
||||
let _b2 = start_prefix_echo_server(backend2_port, "BACKEND2:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
|
||||
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send a fake ClientHello with SNI "one.example.com"
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("one.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Backend1 should have received the ClientHello and prefixed its response
|
||||
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
|
||||
|
||||
// Now test routing to backend2
|
||||
let result2 = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("two.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_passthrough_unknown_sni() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send ClientHello with unknown SNI - should get no response (connection dropped)
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("unknown.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
// Should either get 0 bytes (closed) or an error
|
||||
match stream.read(&mut buf).await {
|
||||
Ok(0) => true, // Connection closed = no route matched
|
||||
Ok(_) => false, // Got data = route shouldn't have matched
|
||||
Err(_) => true, // Error = connection dropped
|
||||
}
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result, "Unknown SNI should result in dropped connection");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_passthrough_wildcard_domain() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Should match any subdomain of example.com
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("anything.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_passthrough_multiple_domains() {
|
||||
let b1_port = next_port();
|
||||
let b2_port = next_port();
|
||||
let b3_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _b1 = start_prefix_echo_server(b1_port, "B1:").await;
|
||||
let _b2 = start_prefix_echo_server(b2_port, "B2:").await;
|
||||
let _b3 = start_prefix_echo_server(b3_port, "B3:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", b1_port),
|
||||
make_tls_passthrough_route(proxy_port, Some("beta.example.com"), "127.0.0.1", b2_port),
|
||||
make_tls_passthrough_route(proxy_port, Some("gamma.example.com"), "127.0.0.1", b3_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
for (domain, expected_prefix) in [
|
||||
("alpha.example.com", "B1:"),
|
||||
("beta.example.com", "B2:"),
|
||||
("gamma.example.com", "B3:"),
|
||||
] {
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello(domain);
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
result.starts_with(expected_prefix),
|
||||
"Domain {} should route to {}, got: {}",
|
||||
domain, expected_prefix, result
|
||||
);
|
||||
}
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
324
rust/crates/rustproxy/tests/integration_tls_terminate.rs
Normal file
324
rust/crates/rustproxy/tests/integration_tls_terminate.rs
Normal file
@@ -0,0 +1,324 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
/// Create a rustls client config that trusts self-signed certs.
|
||||
fn make_insecure_tls_client_config() -> Arc<rustls::ClientConfig> {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
Arc::new(config)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InsecureVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_basic() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "test.example.com";
|
||||
|
||||
// Generate self-signed cert
|
||||
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
|
||||
|
||||
// Start plain TCP echo backend (proxy terminates TLS, sends plain to backend)
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Connect with TLS client
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"hello TLS").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "hello TLS");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_and_reencrypt() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "reencrypt.example.com";
|
||||
let backend_domain = "backend.internal";
|
||||
|
||||
// Generate certs
|
||||
let (proxy_cert, proxy_key) = generate_self_signed_cert(domain);
|
||||
let (backend_cert, backend_key) = generate_self_signed_cert(backend_domain);
|
||||
|
||||
// Start TLS echo backend
|
||||
let _backend = start_tls_echo_server(backend_port, &backend_cert, &backend_key).await;
|
||||
|
||||
// Create terminate-and-reencrypt route
|
||||
let mut route = make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
|
||||
);
|
||||
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![route],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"hello reencrypt").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "hello reencrypt");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_sni_cert_selection() {
|
||||
let backend1_port = next_port();
|
||||
let backend2_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let (cert1, key1) = generate_self_signed_cert("alpha.example.com");
|
||||
let (cert2, key2) = generate_self_signed_cert("beta.example.com");
|
||||
|
||||
let _b1 = start_prefix_echo_server(backend1_port, "ALPHA:").await;
|
||||
let _b2 = start_prefix_echo_server(backend2_port, "BETA:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
|
||||
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"test").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_large_payload() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "large.example.com";
|
||||
|
||||
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send 1MB of data
|
||||
let data = vec![b'X'; 1_000_000];
|
||||
tls_stream.write_all(&data).await.unwrap();
|
||||
tls_stream.shutdown().await.unwrap();
|
||||
|
||||
let mut received = Vec::new();
|
||||
tls_stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
}, 15)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, 1_000_000);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_concurrent() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "concurrent.example.com";
|
||||
|
||||
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
let dom = domain.to_string();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
let msg = format!("conn-{}", i);
|
||||
tls_stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
}, 15)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 10);
|
||||
for (i, r) in result.iter().enumerate() {
|
||||
assert_eq!(r, &format!("conn-{}", i));
|
||||
}
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
Reference in New Issue
Block a user