Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e806f7257f | |||
| af4908b63f | |||
| 8fa3a51b03 | |||
| 088ef6ab09 | |||
| fdb5ec59bc | |||
| 1ea290a085 | |||
| cb71f32b90 | |||
| 46155ab12c | |||
| 490a310b54 | |||
| 6c5180573a | |||
| 30e5ab308f | |||
| d2a54b3491 |
@@ -1,5 +1,44 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-04-26 - 27.9.0 - feat(smart-proxy)
|
||||
add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers
|
||||
|
||||
- adds global securityPolicy config with blocked IP and CIDR support to SmartProxy and RustProxy options
|
||||
- introduces management IPC support to update the security policy at runtime via setSecurityPolicy
|
||||
- enforces the global block list early for TCP, UDP, and QUIC traffic before route selection and backend handling
|
||||
|
||||
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
|
||||
retain inactive per-IP metric buckets briefly to capture final throughput before pruning
|
||||
|
||||
- adds a bounded retention window for closed IP buckets so short-lived transfers are still included in per-IP throughput sampling
|
||||
- prunes expired inactive IP tracking by TTL and hard cap to prevent unbounded metric map growth
|
||||
- updates Rust and throughput tests to expect zero active connections during the temporary retention period
|
||||
|
||||
## 2026-04-26 - 27.8.1 - fix(rustproxy-metrics)
|
||||
preserve high-throughput IPs in metrics snapshots when active-connection rankings are saturated
|
||||
|
||||
- Select snapshot IPs using a blend of active-connection and throughput rankings instead of only active connections
|
||||
- Adds a regression test to ensure a high-bandwidth IP remains included when many other IPs have more active connections
|
||||
|
||||
## 2026-04-14 - 27.8.0 - feat(metrics)
|
||||
add per-domain HTTP request rate metrics
|
||||
|
||||
- Record canonicalized HTTP request rates per domain in the Rust metrics collector and expose per-second and last-minute values in snapshots.
|
||||
- Add TypeScript metrics interfaces and adapter support for requests.byDomain().
|
||||
- Cover HTTP domain rate tracking and ensure TLS passthrough SNI traffic does not affect HTTP request rate metrics.
|
||||
|
||||
## 2026-04-14 - 27.7.4 - fix(rustproxy metrics)
|
||||
use stable route metrics keys across HTTP and passthrough listeners
|
||||
|
||||
- adds a shared RouteConfig::metrics_key helper that prefers route name and falls back to route id
|
||||
- updates HTTP, TCP, UDP, and QUIC metrics labeling to use the shared route metrics key consistently
|
||||
- keeps route cancellation and rate limiter indexing bound to route config ids where required
|
||||
- adds tests covering metrics key selection behavior
|
||||
|
||||
## 2026-04-14 - 27.7.3 - fix(repo)
|
||||
no changes detected
|
||||
|
||||
|
||||
## 2026-04-14 - 27.7.2 - fix(docs)
|
||||
clarify metrics documentation for domain normalization and saturating gauges
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartproxy",
|
||||
"version": "27.7.2",
|
||||
"version": "27.9.0",
|
||||
"private": false,
|
||||
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
|
||||
"main": "dist_ts/index.js",
|
||||
|
||||
Generated
+1
@@ -1319,6 +1319,7 @@ dependencies = [
|
||||
"rustproxy-http",
|
||||
"rustproxy-metrics",
|
||||
"rustproxy-routing",
|
||||
"rustproxy-security",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socket2 0.5.10",
|
||||
|
||||
@@ -3,15 +3,15 @@
|
||||
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
|
||||
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
|
||||
|
||||
pub mod route_types;
|
||||
pub mod proxy_options;
|
||||
pub mod tls_types;
|
||||
pub mod route_types;
|
||||
pub mod security_types;
|
||||
pub mod tls_types;
|
||||
pub mod validation;
|
||||
|
||||
// Re-export all primary types
|
||||
pub use route_types::*;
|
||||
pub use proxy_options::*;
|
||||
pub use tls_types::*;
|
||||
pub use route_types::*;
|
||||
pub use security_types::*;
|
||||
pub use tls_types::*;
|
||||
pub use validation::*;
|
||||
|
||||
@@ -97,6 +97,16 @@ pub struct MetricsConfig {
|
||||
pub retention_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
/// Global ingress security policy.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SecurityPolicy {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub blocked_ips: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub blocked_cidrs: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// RustProxy configuration options.
|
||||
/// Matches TypeScript: `ISmartProxyOptions`
|
||||
///
|
||||
@@ -235,6 +245,10 @@ pub struct RustProxyOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metrics: Option<MetricsConfig>,
|
||||
|
||||
/// Global ingress security policy, enforced before route selection.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub security_policy: Option<SecurityPolicy>,
|
||||
|
||||
// ─── ACME ────────────────────────────────────────────────────────
|
||||
/// Global ACME configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -275,6 +289,7 @@ impl Default for RustProxyOptions {
|
||||
use_http_proxy: None,
|
||||
http_proxy_port: None,
|
||||
metrics: None,
|
||||
security_policy: None,
|
||||
acme: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -656,6 +656,11 @@ impl RouteConfig {
|
||||
self.route_match.ports.to_ports()
|
||||
}
|
||||
|
||||
/// Stable key used for frontend route-scoped metrics.
|
||||
pub fn metrics_key(&self) -> Option<&str> {
|
||||
self.name.as_deref().or(self.id.as_deref())
|
||||
}
|
||||
|
||||
/// Get the TLS mode for this route (from action-level or first target).
|
||||
pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
|
||||
// Check action-level TLS first
|
||||
@@ -673,3 +678,63 @@ impl RouteConfig {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_route(name: Option<&str>, id: Option<&str>) -> RouteConfig {
|
||||
RouteConfig {
|
||||
id: id.map(str::to_string),
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(443),
|
||||
transport: None,
|
||||
domains: None,
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
protocol: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: None,
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: name.map(str::to_string),
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metrics_key_prefers_name() {
|
||||
let route = test_route(Some("named-route"), Some("route-id"));
|
||||
|
||||
assert_eq!(route.metrics_key(), Some("named-route"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metrics_key_falls_back_to_id() {
|
||||
let route = test_route(None, Some("route-id"));
|
||||
|
||||
assert_eq!(route.metrics_key(), Some("route-id"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metrics_key_is_absent_without_name_or_id() {
|
||||
let route = test_route(None, None);
|
||||
|
||||
assert_eq!(route.metrics_key(), None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,10 +111,7 @@ pub enum IpAllowEntry {
|
||||
/// Plain IP/CIDR — allowed for all domains on this route
|
||||
Plain(String),
|
||||
/// Domain-scoped — allowed only when the requested domain matches
|
||||
DomainScoped {
|
||||
ip: String,
|
||||
domains: Vec<String>,
|
||||
},
|
||||
DomainScoped { ip: String, domains: Vec<String> },
|
||||
}
|
||||
|
||||
/// Security options for routes.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::route_types::{RouteConfig, RouteActionType};
|
||||
use crate::route_types::{RouteActionType, RouteConfig};
|
||||
|
||||
/// Validation errors for route configurations.
|
||||
#[derive(Debug, Error)]
|
||||
@@ -30,9 +30,10 @@ pub enum ValidationError {
|
||||
/// Validate a single route configuration.
|
||||
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
|
||||
let mut errors = Vec::new();
|
||||
let name = route.name.clone().unwrap_or_else(|| {
|
||||
route.id.clone().unwrap_or_else(|| "unnamed".to_string())
|
||||
});
|
||||
let name = route
|
||||
.name
|
||||
.clone()
|
||||
.unwrap_or_else(|| route.id.clone().unwrap_or_else(|| "unnamed".to_string()));
|
||||
|
||||
// Check ports
|
||||
let ports = route.listening_ports();
|
||||
@@ -160,7 +161,9 @@ mod tests {
|
||||
let mut route = make_valid_route();
|
||||
route.action.targets = None;
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. })));
|
||||
assert!(errors
|
||||
.iter()
|
||||
.any(|e| matches!(e, ValidationError::MissingTargets { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -168,7 +171,9 @@ mod tests {
|
||||
let mut route = make_valid_route();
|
||||
route.action.targets = Some(vec![]);
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
|
||||
assert!(errors
|
||||
.iter()
|
||||
.any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -176,7 +181,9 @@ mod tests {
|
||||
let mut route = make_valid_route();
|
||||
route.route_match.ports = PortRange::Single(0);
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
|
||||
assert!(errors
|
||||
.iter()
|
||||
.any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -186,7 +193,9 @@ mod tests {
|
||||
let mut r2 = make_valid_route();
|
||||
r2.id = Some("route-1".to_string());
|
||||
let errors = validate_routes(&[r1, r2]).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. })));
|
||||
assert!(errors
|
||||
.iter()
|
||||
.any(|e| matches!(e, ValidationError::DuplicateId { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -639,10 +639,12 @@ impl HttpProxyService {
|
||||
}
|
||||
};
|
||||
|
||||
let route_id = route_match.route.id.as_deref();
|
||||
let route_config_id = route_match.route.id.as_deref();
|
||||
let route_id = route_match.route.metrics_key();
|
||||
let ip_str = ip_string; // reuse from above (avoid redundant to_string())
|
||||
self.metrics.record_http_request();
|
||||
if let Some(ref h) = host {
|
||||
self.metrics.record_http_domain_request(h);
|
||||
self.metrics.record_ip_domain_request(&ip_str, h);
|
||||
}
|
||||
|
||||
@@ -654,7 +656,7 @@ impl HttpProxyService {
|
||||
.as_ref()
|
||||
.filter(|rl| rl.enabled)
|
||||
.map(|rl| {
|
||||
let route_key = route_id.unwrap_or("__default__").to_string();
|
||||
let route_key = route_config_id.unwrap_or("__default__").to_string();
|
||||
self.route_rate_limiters
|
||||
.entry(route_key)
|
||||
.or_insert_with(|| Arc::new(RateLimiter::new(rl.max_requests, rl.window)))
|
||||
|
||||
@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::throughput::{ThroughputSample, ThroughputTracker};
|
||||
use crate::throughput::{RequestRateTracker, ThroughputSample, ThroughputTracker};
|
||||
|
||||
/// Aggregated metrics snapshot.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -26,6 +26,7 @@ pub struct Metrics {
|
||||
pub total_http_requests: u64,
|
||||
pub http_requests_per_sec: u64,
|
||||
pub http_requests_per_sec_recent: u64,
|
||||
pub http_domain_requests: std::collections::HashMap<String, HttpDomainRequestMetrics>,
|
||||
// UDP metrics
|
||||
pub active_udp_sessions: u64,
|
||||
pub total_udp_sessions: u64,
|
||||
@@ -66,6 +67,14 @@ pub struct IpMetrics {
|
||||
pub domain_requests: HashMap<String, u64>,
|
||||
}
|
||||
|
||||
/// Per-domain HTTP request rate metrics.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HttpDomainRequestMetrics {
|
||||
pub requests_per_second: u64,
|
||||
pub requests_last_minute: u64,
|
||||
}
|
||||
|
||||
/// Per-backend metrics (keyed by "host:port").
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -135,15 +144,24 @@ pub struct Statistics {
|
||||
/// Default retention for throughput samples (1 hour).
|
||||
const DEFAULT_RETENTION_SECONDS: usize = 3600;
|
||||
|
||||
/// Maximum number of IPs to include in a snapshot (top by active connections).
|
||||
/// Maximum number of IPs to include in a snapshot.
|
||||
const MAX_IPS_IN_SNAPSHOT: usize = 100;
|
||||
|
||||
/// How long to retain inactive IP metric buckets after the last connection closes.
|
||||
const INACTIVE_IP_RETENTION_MS: u64 = 15_000;
|
||||
|
||||
/// Hard cap for inactive IP metric buckets retained between sampler passes.
|
||||
const MAX_INACTIVE_IPS_RETAINED: usize = MAX_IPS_IN_SNAPSHOT * 10;
|
||||
|
||||
/// Maximum number of backends to include in a snapshot (top by total connections).
|
||||
const MAX_BACKENDS_IN_SNAPSHOT: usize = 100;
|
||||
|
||||
/// Maximum number of distinct domains tracked per IP (prevents subdomain-spray abuse).
|
||||
const MAX_DOMAINS_PER_IP: usize = 256;
|
||||
|
||||
/// Number of one-second HTTP request samples retained per domain.
|
||||
const HTTP_DOMAIN_REQUEST_WINDOW_SECONDS: usize = 60;
|
||||
|
||||
fn canonicalize_domain_key(domain: &str) -> Option<String> {
|
||||
let normalized = domain.trim().trim_end_matches('.').to_ascii_lowercase();
|
||||
if normalized.is_empty() {
|
||||
@@ -153,6 +171,13 @@ fn canonicalize_domain_key(domain: &str) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
fn current_time_ms() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
/// Metrics collector tracking connections and throughput.
|
||||
///
|
||||
/// Design: The hot path (`record_bytes`) is entirely lock-free — it only touches
|
||||
@@ -174,6 +199,7 @@ pub struct MetricsCollector {
|
||||
|
||||
// ── Per-IP tracking ──
|
||||
ip_connections: DashMap<String, AtomicU64>,
|
||||
ip_closed_at_ms: DashMap<String, AtomicU64>,
|
||||
ip_total_connections: DashMap<String, AtomicU64>,
|
||||
ip_bytes_in: DashMap<String, AtomicU64>,
|
||||
ip_bytes_out: DashMap<String, AtomicU64>,
|
||||
@@ -201,6 +227,7 @@ pub struct MetricsCollector {
|
||||
total_http_requests: AtomicU64,
|
||||
pending_http_requests: AtomicU64,
|
||||
http_request_throughput: Mutex<ThroughputTracker>,
|
||||
http_domain_request_rates: DashMap<String, Mutex<RequestRateTracker>>,
|
||||
|
||||
// ── UDP metrics ──
|
||||
active_udp_sessions: AtomicU64,
|
||||
@@ -260,6 +287,7 @@ impl MetricsCollector {
|
||||
route_bytes_in: DashMap::new(),
|
||||
route_bytes_out: DashMap::new(),
|
||||
ip_connections: DashMap::new(),
|
||||
ip_closed_at_ms: DashMap::new(),
|
||||
ip_total_connections: DashMap::new(),
|
||||
ip_bytes_in: DashMap::new(),
|
||||
ip_bytes_out: DashMap::new(),
|
||||
@@ -284,6 +312,7 @@ impl MetricsCollector {
|
||||
total_http_requests: AtomicU64::new(0),
|
||||
pending_http_requests: AtomicU64::new(0),
|
||||
http_request_throughput: Mutex::new(ThroughputTracker::new(retention_seconds)),
|
||||
http_domain_request_rates: DashMap::new(),
|
||||
frontend_h1_active: AtomicU64::new(0),
|
||||
frontend_h1_total: AtomicU64::new(0),
|
||||
frontend_h2_active: AtomicU64::new(0),
|
||||
@@ -334,6 +363,7 @@ impl MetricsCollector {
|
||||
.entry(ip.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.ip_closed_at_ms.remove(ip);
|
||||
self.ip_total_connections
|
||||
.entry(ip.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
@@ -378,22 +408,88 @@ impl MetricsCollector {
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
// Clean up zero-count entries to prevent memory growth
|
||||
// Keep inactive IP buckets briefly so pending bytes can still
|
||||
// be sampled into per-IP throughput after short-lived transfers.
|
||||
if matches!(prev, Some(v) if v <= 1) {
|
||||
drop(counter);
|
||||
self.ip_connections.remove(ip);
|
||||
// Evict all per-IP tracking data for this IP
|
||||
self.ip_total_connections.remove(ip);
|
||||
self.ip_bytes_in.remove(ip);
|
||||
self.ip_bytes_out.remove(ip);
|
||||
self.ip_pending_tp.remove(ip);
|
||||
self.ip_throughput.remove(ip);
|
||||
self.ip_domain_requests.remove(ip);
|
||||
self.ip_closed_at_ms
|
||||
.entry(ip.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.store(current_time_ms(), Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_inactive_ip_tracking(&self, ip: &str) {
|
||||
if self
|
||||
.ip_connections
|
||||
.get(ip)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0)
|
||||
> 0
|
||||
{
|
||||
self.ip_closed_at_ms.remove(ip);
|
||||
return;
|
||||
}
|
||||
|
||||
self.ip_connections.remove(ip);
|
||||
self.ip_closed_at_ms.remove(ip);
|
||||
self.ip_total_connections.remove(ip);
|
||||
self.ip_bytes_in.remove(ip);
|
||||
self.ip_bytes_out.remove(ip);
|
||||
self.ip_pending_tp.remove(ip);
|
||||
self.ip_throughput.remove(ip);
|
||||
self.ip_domain_requests.remove(ip);
|
||||
}
|
||||
|
||||
fn prune_inactive_ip_tracking(&self, now_ms: u64) {
|
||||
let cutoff_ms = now_ms.saturating_sub(INACTIVE_IP_RETENTION_MS);
|
||||
let mut inactive_ips: Vec<(String, u64)> = Vec::new();
|
||||
let mut active_markers: Vec<String> = Vec::new();
|
||||
|
||||
for entry in self.ip_closed_at_ms.iter() {
|
||||
let ip = entry.key().clone();
|
||||
let active = self
|
||||
.ip_connections
|
||||
.get(&ip)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
if active > 0 {
|
||||
active_markers.push(ip);
|
||||
} else {
|
||||
inactive_ips.push((ip, entry.value().load(Ordering::Relaxed)));
|
||||
}
|
||||
}
|
||||
|
||||
for ip in active_markers {
|
||||
self.ip_closed_at_ms.remove(&ip);
|
||||
}
|
||||
|
||||
let mut remove_ips: HashSet<String> = inactive_ips
|
||||
.iter()
|
||||
.filter(|(_, closed_at_ms)| *closed_at_ms <= cutoff_ms)
|
||||
.map(|(ip, _)| ip.clone())
|
||||
.collect();
|
||||
|
||||
let retained_after_ttl = inactive_ips.len().saturating_sub(remove_ips.len());
|
||||
if retained_after_ttl > MAX_INACTIVE_IPS_RETAINED {
|
||||
inactive_ips.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
|
||||
let mut overflow = retained_after_ttl - MAX_INACTIVE_IPS_RETAINED;
|
||||
for (ip, closed_at_ms) in inactive_ips {
|
||||
if overflow == 0 {
|
||||
break;
|
||||
}
|
||||
if closed_at_ms > cutoff_ms && remove_ips.insert(ip) {
|
||||
overflow -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for ip in remove_ips {
|
||||
self.remove_inactive_ip_tracking(&ip);
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes transferred (lock-free hot path).
|
||||
///
|
||||
/// Called per-chunk in the TCP copy loop. Only touches AtomicU64 counters —
|
||||
@@ -467,9 +563,8 @@ impl MetricsCollector {
|
||||
|
||||
// Per-IP tracking: same get()-first pattern to avoid String allocation on hot path.
|
||||
if let Some(ip) = source_ip {
|
||||
// Only record per-IP stats if the IP still has active connections.
|
||||
// This prevents orphaned entries when record_bytes races with
|
||||
// connection_closed (which evicts all per-IP data on last close).
|
||||
// Only record per-IP stats if the IP is active or still within the
|
||||
// bounded inactive retention window after its last connection closed.
|
||||
if self.ip_connections.contains_key(ip) {
|
||||
if bytes_in > 0 {
|
||||
if let Some(counter) = self.ip_bytes_in.get(ip) {
|
||||
@@ -522,6 +617,24 @@ impl MetricsCollector {
|
||||
self.pending_http_requests.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Record a real HTTP request for a canonicalized domain.
|
||||
pub fn record_http_domain_request(&self, domain: &str) {
|
||||
let Some(domain) = canonicalize_domain_key(domain) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.http_domain_request_rates
|
||||
.entry(domain.clone())
|
||||
.or_insert_with(|| {
|
||||
Mutex::new(RequestRateTracker::new(HTTP_DOMAIN_REQUEST_WINDOW_SECONDS))
|
||||
});
|
||||
if let Some(tracker_ref) = self.http_domain_request_rates.get(domain.as_str()) {
|
||||
if let Ok(mut tracker) = tracker_ref.value().lock() {
|
||||
tracker.record_event();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a domain request/connection for a frontend IP.
|
||||
///
|
||||
/// Called per HTTP request (with Host header) and per TCP passthrough
|
||||
@@ -791,8 +904,7 @@ impl MetricsCollector {
|
||||
/// Take a throughput sample on all trackers (cold path, call at 1Hz or configured interval).
|
||||
///
|
||||
/// Drains the lock-free pending counters and feeds the accumulated bytes
|
||||
/// into the throughput trackers (under Mutex). This is the only place
|
||||
/// the Mutex is locked.
|
||||
/// into the throughput trackers under their sampling mutexes.
|
||||
pub fn sample_all(&self) {
|
||||
// Drain global pending bytes and feed into the tracker
|
||||
let global_in = self.global_pending_tp_in.swap(0, Ordering::Relaxed);
|
||||
@@ -873,9 +985,27 @@ impl MetricsCollector {
|
||||
tracker.sample();
|
||||
}
|
||||
|
||||
// Advance HTTP domain request windows and prune fully idle domains.
|
||||
let mut stale_http_domains = Vec::new();
|
||||
for entry in self.http_domain_request_rates.iter() {
|
||||
if let Ok(mut tracker) = entry.value().lock() {
|
||||
tracker.advance_to_now();
|
||||
if tracker.is_idle() {
|
||||
stale_http_domains.push(entry.key().clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
for domain in stale_http_domains {
|
||||
self.http_domain_request_rates.remove(&domain);
|
||||
}
|
||||
|
||||
// Keep closed IP buckets only for a short, bounded window so the sampler
|
||||
// can attribute short-lived transfers without leaking per-IP maps.
|
||||
self.prune_inactive_ip_tracking(current_time_ms());
|
||||
|
||||
// Safety-net: prune orphaned per-IP entries that have no corresponding
|
||||
// ip_connections entry. This catches any entries created by a race between
|
||||
// record_bytes and connection_closed.
|
||||
// ip_connections entry. This catches any entries created by older races or
|
||||
// by code paths that manually inserted partial per-IP state.
|
||||
self.ip_bytes_in
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_bytes_out
|
||||
@@ -888,6 +1018,8 @@ impl MetricsCollector {
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_domain_requests
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_closed_at_ms
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
|
||||
// Safety-net: prune orphaned backend error/stats entries for backends
|
||||
// that have no active or total connections (error-only backends).
|
||||
@@ -1019,8 +1151,8 @@ impl MetricsCollector {
|
||||
);
|
||||
}
|
||||
|
||||
// Collect per-IP metrics — only IPs with active connections or total > 0,
|
||||
// capped at top MAX_IPS_IN_SNAPSHOT sorted by active count
|
||||
// Collect per-IP metrics — capped to the IPs most relevant for either
|
||||
// active connection visibility or bandwidth attribution.
|
||||
let mut ip_entries: Vec<(String, u64, u64, u64, u64, u64, u64, HashMap<String, u64>)> =
|
||||
Vec::new();
|
||||
for entry in self.ip_total_connections.iter() {
|
||||
@@ -1068,9 +1200,54 @@ impl MetricsCollector {
|
||||
domain_requests,
|
||||
));
|
||||
}
|
||||
// Sort by active connections descending, then cap
|
||||
ip_entries.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
ip_entries.truncate(MAX_IPS_IN_SNAPSHOT);
|
||||
if ip_entries.len() > MAX_IPS_IN_SNAPSHOT {
|
||||
let mut selected = vec![false; ip_entries.len()];
|
||||
let mut selected_count = 0usize;
|
||||
|
||||
let mut active_rank: Vec<usize> = (0..ip_entries.len()).collect();
|
||||
active_rank.sort_by(|&a, &b| {
|
||||
ip_entries[b]
|
||||
.1
|
||||
.cmp(&ip_entries[a].1)
|
||||
.then_with(|| ip_entries[b].2.cmp(&ip_entries[a].2))
|
||||
.then_with(|| ip_entries[a].0.cmp(&ip_entries[b].0))
|
||||
});
|
||||
|
||||
let mut throughput_rank: Vec<usize> = (0..ip_entries.len()).collect();
|
||||
throughput_rank.sort_by(|&a, &b| {
|
||||
let a_tp = ip_entries[a].5.saturating_add(ip_entries[a].6);
|
||||
let b_tp = ip_entries[b].5.saturating_add(ip_entries[b].6);
|
||||
let a_bytes = ip_entries[a].3.saturating_add(ip_entries[a].4);
|
||||
let b_bytes = ip_entries[b].3.saturating_add(ip_entries[b].4);
|
||||
b_tp.cmp(&a_tp)
|
||||
.then_with(|| b_bytes.cmp(&a_bytes))
|
||||
.then_with(|| ip_entries[b].1.cmp(&ip_entries[a].1))
|
||||
.then_with(|| ip_entries[a].0.cmp(&ip_entries[b].0))
|
||||
});
|
||||
|
||||
for idx in active_rank.into_iter().take(MAX_IPS_IN_SNAPSHOT / 2) {
|
||||
if !selected[idx] {
|
||||
selected[idx] = true;
|
||||
selected_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for idx in throughput_rank {
|
||||
if selected_count >= MAX_IPS_IN_SNAPSHOT {
|
||||
break;
|
||||
}
|
||||
if !selected[idx] {
|
||||
selected[idx] = true;
|
||||
selected_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
ip_entries = ip_entries
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, entry)| selected[idx].then_some(entry))
|
||||
.collect();
|
||||
}
|
||||
|
||||
let mut ips = std::collections::HashMap::new();
|
||||
for (ip, active, total, bytes_in, bytes_out, tp_in, tp_out, domain_requests) in ip_entries {
|
||||
@@ -1179,6 +1356,24 @@ impl MetricsCollector {
|
||||
})
|
||||
.unwrap_or((0, 0));
|
||||
|
||||
let mut http_domain_requests = std::collections::HashMap::new();
|
||||
for entry in self.http_domain_request_rates.iter() {
|
||||
if let Ok(mut tracker) = entry.value().lock() {
|
||||
tracker.advance_to_now();
|
||||
let requests_per_second = tracker.last_second();
|
||||
let requests_last_minute = tracker.last_minute();
|
||||
if requests_per_second > 0 || requests_last_minute > 0 {
|
||||
http_domain_requests.insert(
|
||||
entry.key().clone(),
|
||||
HttpDomainRequestMetrics {
|
||||
requests_per_second,
|
||||
requests_last_minute,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Metrics {
|
||||
active_connections: self.active_connections(),
|
||||
total_connections: self.total_connections(),
|
||||
@@ -1195,6 +1390,7 @@ impl MetricsCollector {
|
||||
total_http_requests: self.total_http_requests.load(Ordering::Relaxed),
|
||||
http_requests_per_sec: http_rps,
|
||||
http_requests_per_sec_recent: http_rps_recent,
|
||||
http_domain_requests,
|
||||
active_udp_sessions: self.active_udp_sessions.load(Ordering::Relaxed),
|
||||
total_udp_sessions: self.total_udp_sessions.load(Ordering::Relaxed),
|
||||
total_datagrams_in: self.total_datagrams_in.load(Ordering::Relaxed),
|
||||
@@ -1383,9 +1579,42 @@ mod tests {
|
||||
1
|
||||
);
|
||||
|
||||
// Close last connection for IP — should be cleaned up
|
||||
// Close last connection for IP — active count should drop to zero,
|
||||
// while the inactive bucket is retained briefly for final sampling.
|
||||
collector.connection_closed(Some("route-a"), Some("1.2.3.4"));
|
||||
assert!(collector.ip_connections.get("1.2.3.4").is_none());
|
||||
assert_eq!(
|
||||
collector
|
||||
.ip_connections
|
||||
.get("1.2.3.4")
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0),
|
||||
0
|
||||
);
|
||||
assert!(collector.ip_closed_at_ms.get("1.2.3.4").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_snapshot_retains_high_throughput_ip_over_many_active_ips() {
|
||||
let collector = MetricsCollector::with_retention(60);
|
||||
|
||||
for i in 1..=(MAX_IPS_IN_SNAPSHOT + 20) {
|
||||
let ip = format!("10.0.0.{}", i);
|
||||
collector.connection_opened(Some("scanner-route"), Some(&ip));
|
||||
collector.connection_opened(Some("scanner-route"), Some(&ip));
|
||||
}
|
||||
|
||||
let busy_ip = "203.0.113.10";
|
||||
collector.connection_opened(Some("download-route"), Some(busy_ip));
|
||||
collector.record_bytes(0, 900_000, Some("download-route"), Some(busy_ip));
|
||||
collector.sample_all();
|
||||
|
||||
let snapshot = collector.snapshot();
|
||||
let busy_metrics = snapshot.ips.get(busy_ip).unwrap();
|
||||
|
||||
assert_eq!(snapshot.ips.len(), MAX_IPS_IN_SNAPSHOT);
|
||||
assert_eq!(busy_metrics.active_connections, 1);
|
||||
assert_eq!(busy_metrics.bytes_out, 900_000);
|
||||
assert_eq!(busy_metrics.throughput_out_bytes_per_sec, 900_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1411,13 +1640,36 @@ mod tests {
|
||||
collector.connection_closed(Some("route-a"), Some("10.0.0.1"));
|
||||
collector.connection_closed(Some("route-a"), Some("10.0.0.1"));
|
||||
|
||||
// All per-IP data for 10.0.0.1 should be evicted
|
||||
// Per-IP data for 10.0.0.1 should be retained until the inactive TTL expires.
|
||||
assert_eq!(
|
||||
collector
|
||||
.ip_connections
|
||||
.get("10.0.0.1")
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0),
|
||||
0
|
||||
);
|
||||
assert!(collector.ip_total_connections.get("10.0.0.1").is_some());
|
||||
assert!(collector.ip_bytes_in.get("10.0.0.1").is_some());
|
||||
assert!(collector.ip_bytes_out.get("10.0.0.1").is_some());
|
||||
assert!(collector.ip_pending_tp.get("10.0.0.1").is_some());
|
||||
assert!(collector.ip_throughput.get("10.0.0.1").is_some());
|
||||
|
||||
collector
|
||||
.ip_closed_at_ms
|
||||
.get("10.0.0.1")
|
||||
.unwrap()
|
||||
.store(0, Ordering::Relaxed);
|
||||
collector.sample_all();
|
||||
|
||||
// Expired inactive buckets are fully evicted to prevent leaks.
|
||||
assert!(collector.ip_connections.get("10.0.0.1").is_none());
|
||||
assert!(collector.ip_total_connections.get("10.0.0.1").is_none());
|
||||
assert!(collector.ip_bytes_in.get("10.0.0.1").is_none());
|
||||
assert!(collector.ip_bytes_out.get("10.0.0.1").is_none());
|
||||
assert!(collector.ip_pending_tp.get("10.0.0.1").is_none());
|
||||
assert!(collector.ip_throughput.get("10.0.0.1").is_none());
|
||||
assert!(collector.ip_closed_at_ms.get("10.0.0.1").is_none());
|
||||
|
||||
// 10.0.0.2 should still have data
|
||||
assert!(collector.ip_connections.get("10.0.0.2").is_some());
|
||||
@@ -1444,7 +1696,14 @@ mod tests {
|
||||
.unwrap_or(0),
|
||||
0
|
||||
);
|
||||
assert!(collector.ip_connections.get("10.0.0.1").is_none());
|
||||
assert_eq!(
|
||||
collector
|
||||
.ip_connections
|
||||
.get("10.0.0.1")
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0),
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1514,6 +1773,47 @@ mod tests {
|
||||
assert_eq!(snapshot.http_requests_per_sec, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_domain_request_rates_are_canonicalized() {
|
||||
let collector = MetricsCollector::with_retention(60);
|
||||
|
||||
collector.record_http_domain_request("Example.COM");
|
||||
collector.record_http_domain_request("example.com.");
|
||||
collector.record_http_domain_request(" example.com ");
|
||||
|
||||
let now_sec = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
if let Some(tracker) = collector.http_domain_request_rates.get("example.com") {
|
||||
tracker.value().lock().unwrap().advance_to(now_sec + 1);
|
||||
}
|
||||
|
||||
let snapshot = collector.snapshot();
|
||||
let metrics = snapshot.http_domain_requests.get("example.com").unwrap();
|
||||
assert_eq!(snapshot.http_domain_requests.len(), 1);
|
||||
assert_eq!(metrics.requests_per_second, 3);
|
||||
assert_eq!(metrics.requests_last_minute, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_domain_requests_do_not_affect_http_domain_request_rates() {
|
||||
let collector = MetricsCollector::with_retention(60);
|
||||
|
||||
collector.connection_opened(Some("route-a"), Some("10.0.0.1"));
|
||||
collector.record_ip_domain_request("10.0.0.1", "example.com");
|
||||
|
||||
let snapshot = collector.snapshot();
|
||||
assert!(snapshot.http_domain_requests.is_empty());
|
||||
assert_eq!(
|
||||
snapshot
|
||||
.ips
|
||||
.get("10.0.0.1")
|
||||
.and_then(|ip| ip.domain_requests.get("example.com")),
|
||||
Some(&1)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retain_routes_prunes_stale() {
|
||||
let collector = MetricsCollector::with_retention(60);
|
||||
@@ -1544,26 +1844,52 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_bytes_after_close_no_orphan() {
|
||||
fn test_closed_ip_keeps_pending_throughput_until_sample() {
|
||||
let collector = MetricsCollector::with_retention(60);
|
||||
|
||||
// Open a connection, record bytes, then close
|
||||
collector.connection_opened(Some("route-a"), Some("10.0.0.1"));
|
||||
collector.record_bytes(100, 200, Some("route-a"), Some("10.0.0.1"));
|
||||
collector.record_bytes(100, 2000, Some("route-a"), Some("10.0.0.1"));
|
||||
collector.connection_closed(Some("route-a"), Some("10.0.0.1"));
|
||||
|
||||
// IP should be fully evicted
|
||||
assert!(collector.ip_connections.get("10.0.0.1").is_none());
|
||||
collector.sample_all();
|
||||
let snapshot = collector.snapshot();
|
||||
let ip_metrics = snapshot.ips.get("10.0.0.1").unwrap();
|
||||
|
||||
// Now record_bytes arrives late (simulates race) — should NOT re-create entries
|
||||
collector.record_bytes(50, 75, Some("route-a"), Some("10.0.0.1"));
|
||||
assert!(collector.ip_bytes_in.get("10.0.0.1").is_none());
|
||||
assert!(collector.ip_bytes_out.get("10.0.0.1").is_none());
|
||||
assert!(collector.ip_pending_tp.get("10.0.0.1").is_none());
|
||||
assert_eq!(ip_metrics.active_connections, 0);
|
||||
assert_eq!(ip_metrics.bytes_in, 100);
|
||||
assert_eq!(ip_metrics.bytes_out, 2000);
|
||||
assert_eq!(ip_metrics.throughput_in_bytes_per_sec, 100);
|
||||
assert_eq!(ip_metrics.throughput_out_bytes_per_sec, 2000);
|
||||
assert_eq!(snapshot.throughput_out_bytes_per_sec, 2000);
|
||||
assert_eq!(
|
||||
snapshot
|
||||
.routes
|
||||
.get("route-a")
|
||||
.unwrap()
|
||||
.throughput_out_bytes_per_sec,
|
||||
2000
|
||||
);
|
||||
}
|
||||
|
||||
// Global bytes should still be counted
|
||||
assert_eq!(collector.total_bytes_in.load(Ordering::Relaxed), 150);
|
||||
assert_eq!(collector.total_bytes_out.load(Ordering::Relaxed), 275);
|
||||
#[test]
|
||||
fn test_inactive_ip_retention_hard_cap_prunes_oldest() {
|
||||
let collector = MetricsCollector::with_retention(60);
|
||||
|
||||
for i in 0..(MAX_INACTIVE_IPS_RETAINED + 5) {
|
||||
let ip = format!("198.51.{}.{}", i / 255, i % 255);
|
||||
collector.connection_opened(Some("route-a"), Some(&ip));
|
||||
collector.connection_closed(Some("route-a"), Some(&ip));
|
||||
collector
|
||||
.ip_closed_at_ms
|
||||
.get(&ip)
|
||||
.unwrap()
|
||||
.store(current_time_ms().saturating_sub(1), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
collector.sample_all();
|
||||
|
||||
assert!(collector.ip_closed_at_ms.len() <= MAX_INACTIVE_IPS_RETAINED);
|
||||
assert!(collector.ip_connections.len() <= MAX_INACTIVE_IPS_RETAINED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
//!
|
||||
//! Metrics and throughput tracking for RustProxy.
|
||||
|
||||
pub mod throughput;
|
||||
pub mod collector;
|
||||
pub mod log_dedup;
|
||||
pub mod throughput;
|
||||
|
||||
pub use throughput::*;
|
||||
pub use collector::*;
|
||||
pub use log_dedup::*;
|
||||
pub use throughput::*;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
@@ -47,13 +47,16 @@ impl LogDeduplicator {
|
||||
let map_key = format!("{}:{}", category, key);
|
||||
let now = Instant::now();
|
||||
|
||||
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
|
||||
category: category.to_string(),
|
||||
first_message: message.to_string(),
|
||||
count: AtomicU64::new(0),
|
||||
first_seen: now,
|
||||
last_seen: now,
|
||||
});
|
||||
let entry = self
|
||||
.events
|
||||
.entry(map_key)
|
||||
.or_insert_with(|| AggregatedEvent {
|
||||
category: category.to_string(),
|
||||
first_message: message.to_string(),
|
||||
count: AtomicU64::new(0),
|
||||
first_seen: now,
|
||||
last_seen: now,
|
||||
});
|
||||
|
||||
let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
|
||||
@@ -29,6 +29,113 @@ pub struct ThroughputTracker {
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
fn unix_timestamp_seconds() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
/// Circular buffer for per-second event counts.
|
||||
///
|
||||
/// Unlike `ThroughputTracker`, events are recorded directly into the current
|
||||
/// second so request counts remain stable even when the collector is sampled
|
||||
/// more frequently than once per second.
|
||||
pub(crate) struct RequestRateTracker {
|
||||
samples: Vec<u64>,
|
||||
write_index: usize,
|
||||
count: usize,
|
||||
capacity: usize,
|
||||
current_second: Option<u64>,
|
||||
current_count: u64,
|
||||
}
|
||||
|
||||
impl RequestRateTracker {
|
||||
pub(crate) fn new(retention_seconds: usize) -> Self {
|
||||
Self {
|
||||
samples: Vec::with_capacity(retention_seconds.max(1)),
|
||||
write_index: 0,
|
||||
count: 0,
|
||||
capacity: retention_seconds.max(1),
|
||||
current_second: None,
|
||||
current_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn push_sample(&mut self, count: u64) {
|
||||
if self.samples.len() < self.capacity {
|
||||
self.samples.push(count);
|
||||
} else {
|
||||
self.samples[self.write_index] = count;
|
||||
}
|
||||
self.write_index = (self.write_index + 1) % self.capacity;
|
||||
self.count = (self.count + 1).min(self.capacity);
|
||||
}
|
||||
|
||||
pub(crate) fn record_event(&mut self) {
|
||||
self.record_events_at(unix_timestamp_seconds(), 1);
|
||||
}
|
||||
|
||||
pub(crate) fn record_events_at(&mut self, now_sec: u64, count: u64) {
|
||||
self.advance_to(now_sec);
|
||||
self.current_count = self.current_count.saturating_add(count);
|
||||
}
|
||||
|
||||
pub(crate) fn advance_to_now(&mut self) {
|
||||
self.advance_to(unix_timestamp_seconds());
|
||||
}
|
||||
|
||||
pub(crate) fn advance_to(&mut self, now_sec: u64) {
|
||||
match self.current_second {
|
||||
Some(current_second) if now_sec > current_second => {
|
||||
self.push_sample(self.current_count);
|
||||
for _ in 1..(now_sec - current_second) {
|
||||
self.push_sample(0);
|
||||
}
|
||||
self.current_second = Some(now_sec);
|
||||
self.current_count = 0;
|
||||
}
|
||||
Some(_) => {}
|
||||
None => {
|
||||
self.current_second = Some(now_sec);
|
||||
self.current_count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sum_recent(&self, window_seconds: usize) -> u64 {
|
||||
let window = window_seconds.min(self.count);
|
||||
if window == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let mut total = 0u64;
|
||||
for i in 0..window {
|
||||
let idx = if self.write_index >= i + 1 {
|
||||
self.write_index - i - 1
|
||||
} else {
|
||||
self.capacity - (i + 1 - self.write_index)
|
||||
};
|
||||
if idx < self.samples.len() {
|
||||
total += self.samples[idx];
|
||||
}
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
pub(crate) fn last_second(&self) -> u64 {
|
||||
self.sum_recent(1)
|
||||
}
|
||||
|
||||
pub(crate) fn last_minute(&self) -> u64 {
|
||||
self.sum_recent(60)
|
||||
}
|
||||
|
||||
pub(crate) fn is_idle(&self) -> bool {
|
||||
self.current_count == 0 && self.sum_recent(self.capacity) == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl ThroughputTracker {
|
||||
/// Create a new tracker with the given capacity (seconds of retention).
|
||||
pub fn new(retention_seconds: usize) -> Self {
|
||||
@@ -46,7 +153,8 @@ impl ThroughputTracker {
|
||||
/// Record bytes (called from data flow callbacks).
|
||||
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
|
||||
self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
self.pending_bytes_out
|
||||
.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Take a sample (called at 1Hz).
|
||||
@@ -229,4 +337,41 @@ mod tests {
|
||||
let history = tracker.history(10);
|
||||
assert!(history.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_rate_tracker_counts_last_second_and_last_minute() {
|
||||
let mut tracker = RequestRateTracker::new(60);
|
||||
|
||||
tracker.record_events_at(100, 2);
|
||||
tracker.record_events_at(100, 3);
|
||||
tracker.advance_to(101);
|
||||
|
||||
assert_eq!(tracker.last_second(), 5);
|
||||
assert_eq!(tracker.last_minute(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_rate_tracker_adds_zero_samples_for_gaps() {
|
||||
let mut tracker = RequestRateTracker::new(60);
|
||||
|
||||
tracker.record_events_at(100, 4);
|
||||
tracker.record_events_at(102, 1);
|
||||
tracker.advance_to(103);
|
||||
|
||||
assert_eq!(tracker.last_second(), 1);
|
||||
assert_eq!(tracker.last_minute(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_rate_tracker_decays_to_zero_over_window() {
|
||||
let mut tracker = RequestRateTracker::new(60);
|
||||
|
||||
tracker.record_events_at(100, 7);
|
||||
tracker.advance_to(101);
|
||||
tracker.advance_to(161);
|
||||
|
||||
assert_eq!(tracker.last_second(), 0);
|
||||
assert_eq!(tracker.last_minute(), 0);
|
||||
assert!(tracker.is_idle());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
rustproxy-security = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -73,7 +73,9 @@ impl ConnectionRegistry {
|
||||
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
let matches = entry.domain.as_deref()
|
||||
let matches = entry
|
||||
.domain
|
||||
.as_deref()
|
||||
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
|
||||
.unwrap_or(false);
|
||||
if matches {
|
||||
@@ -100,7 +102,11 @@ impl ConnectionRegistry {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
if entry.route_id.as_deref() == Some(route_id) {
|
||||
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) {
|
||||
if !RequestFilter::check_ip_security(
|
||||
new_security,
|
||||
&entry.source_ip,
|
||||
entry.domain.as_deref(),
|
||||
) {
|
||||
info!(
|
||||
"Terminating connection from {} — IP now blocked on route '{}'",
|
||||
entry.source_ip, route_id
|
||||
|
||||
@@ -31,7 +31,8 @@ impl ConnectionTracker {
|
||||
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
||||
// Check per-IP connection limit
|
||||
if let Some(max) = self.max_per_ip {
|
||||
let count = self.active
|
||||
let count = self
|
||||
.active
|
||||
.get(ip)
|
||||
.map(|c| c.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
@@ -48,7 +49,10 @@ impl ConnectionTracker {
|
||||
let timestamps = entry.value_mut();
|
||||
|
||||
// Remove timestamps older than 1 minute
|
||||
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
|
||||
while timestamps
|
||||
.front()
|
||||
.is_some_and(|t| now.duration_since(*t) >= one_minute)
|
||||
{
|
||||
timestamps.pop_front();
|
||||
}
|
||||
|
||||
@@ -111,7 +115,6 @@ impl ConnectionTracker {
|
||||
pub fn tracked_ips(&self) -> usize {
|
||||
self.active.len()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::debug;
|
||||
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
if let Some(ref ctx) = metrics {
|
||||
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
data.len() as u64,
|
||||
0,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
if let Some(ref ctx) = metrics_c2b {
|
||||
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
n as u64,
|
||||
0,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
backend_write.shutdown(),
|
||||
).await;
|
||||
let _ =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
total += n as u64;
|
||||
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
if let Some(ref ctx) = metrics_b2c {
|
||||
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
0,
|
||||
n as u64,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
client_write.shutdown(),
|
||||
).await;
|
||||
let _ =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
|
||||
@@ -4,26 +4,26 @@
|
||||
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
|
||||
//! and UDP datagram session tracking with forwarding.
|
||||
|
||||
pub mod tcp_listener;
|
||||
pub mod sni_parser;
|
||||
pub mod connection_registry;
|
||||
pub mod connection_tracker;
|
||||
pub mod forwarder;
|
||||
pub mod proxy_protocol;
|
||||
pub mod tls_handler;
|
||||
pub mod connection_tracker;
|
||||
pub mod connection_registry;
|
||||
pub mod socket_opts;
|
||||
pub mod udp_session;
|
||||
pub mod udp_listener;
|
||||
pub mod quic_handler;
|
||||
pub mod sni_parser;
|
||||
pub mod socket_opts;
|
||||
pub mod tcp_listener;
|
||||
pub mod tls_handler;
|
||||
pub mod udp_listener;
|
||||
pub mod udp_session;
|
||||
|
||||
pub use tcp_listener::*;
|
||||
pub use sni_parser::*;
|
||||
pub use connection_registry::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use forwarder::*;
|
||||
pub use proxy_protocol::*;
|
||||
pub use tls_handler::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use connection_registry::*;
|
||||
pub use socket_opts::*;
|
||||
pub use udp_session::*;
|
||||
pub use udp_listener::*;
|
||||
pub use quic_handler::*;
|
||||
pub use sni_parser::*;
|
||||
pub use socket_opts::*;
|
||||
pub use tcp_listener::*;
|
||||
pub use tls_handler::*;
|
||||
pub use udp_listener::*;
|
||||
pub use udp_session::*;
|
||||
|
||||
@@ -54,8 +54,8 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
.position(|w| w == b"\r\n")
|
||||
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
let line = std::str::from_utf8(&data[..line_end])
|
||||
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||
let line =
|
||||
std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
if !line.starts_with("PROXY ") {
|
||||
return Err(ProxyProtocolError::InvalidHeader);
|
||||
@@ -148,7 +148,10 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
let command = data[12] & 0x0F;
|
||||
// 0x0 = LOCAL, 0x1 = PROXY
|
||||
if command > 1 {
|
||||
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command)));
|
||||
return Err(ProxyProtocolError::Parse(format!(
|
||||
"Unknown command: {}",
|
||||
command
|
||||
)));
|
||||
}
|
||||
|
||||
// Address family (high nibble) + transport (low nibble) of byte 13
|
||||
@@ -182,7 +185,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET (0x1) + STREAM (0x1) = TCP4
|
||||
(0x1, 0x1) => {
|
||||
if addr_len < 12 {
|
||||
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv4 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||
@@ -200,7 +205,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET (0x1) + DGRAM (0x2) = UDP4
|
||||
(0x1, 0x2) => {
|
||||
if addr_len < 12 {
|
||||
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv4 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||
@@ -218,7 +225,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
|
||||
(0x2, 0x1) => {
|
||||
if addr_len < 36 {
|
||||
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv6 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||
@@ -236,7 +245,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
|
||||
(0x2, 0x2) => {
|
||||
if addr_len < 36 {
|
||||
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv6 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||
@@ -268,11 +279,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
}
|
||||
|
||||
/// Generate a PROXY protocol v2 binary header.
|
||||
pub fn generate_v2(
|
||||
source: &SocketAddr,
|
||||
dest: &SocketAddr,
|
||||
transport: ProxyV2Transport,
|
||||
) -> Vec<u8> {
|
||||
pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
|
||||
let transport_nibble: u8 = match transport {
|
||||
ProxyV2Transport::Stream => 0x1,
|
||||
ProxyV2Transport::Datagram => 0x2,
|
||||
@@ -462,7 +469,10 @@ mod tests {
|
||||
header.push(0x11);
|
||||
header.extend_from_slice(&12u16.to_be_bytes());
|
||||
header.extend_from_slice(&[0u8; 12]);
|
||||
assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion)));
|
||||
assert!(matches!(
|
||||
parse_v2(&header),
|
||||
Err(ProxyProtocolError::UnsupportedVersion)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -26,11 +26,12 @@ use tracing::{debug, info, warn};
|
||||
use rustproxy_config::{RouteConfig, TransportProtocol};
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::{MatchContext, RouteManager};
|
||||
use rustproxy_security::IpBlockList;
|
||||
|
||||
use rustproxy_http::h3_service::H3ProxyService;
|
||||
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
|
||||
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
||||
///
|
||||
@@ -48,8 +49,7 @@ pub fn create_quic_endpoint(
|
||||
quinn::EndpointConfig::default(),
|
||||
Some(server_config),
|
||||
socket,
|
||||
quinn::default_runtime()
|
||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
)?;
|
||||
|
||||
info!("QUIC endpoint listening on port {}", port);
|
||||
@@ -97,6 +97,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
port: u16,
|
||||
tls_config: Arc<RustlsServerConfig>,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<QuicProxyRelay> {
|
||||
// Bind external socket on the real port
|
||||
@@ -119,8 +120,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
quinn::EndpointConfig::default(),
|
||||
Some(server_config),
|
||||
internal_socket,
|
||||
quinn::default_runtime()
|
||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
)?;
|
||||
|
||||
let real_client_map = Arc::new(DashMap::new());
|
||||
@@ -129,12 +129,20 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
external_socket,
|
||||
quinn_internal_addr,
|
||||
proxy_ips,
|
||||
security_policy,
|
||||
Arc::clone(&real_client_map),
|
||||
cancel,
|
||||
));
|
||||
|
||||
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
|
||||
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
|
||||
info!(
|
||||
"QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
|
||||
port, quinn_internal_addr
|
||||
);
|
||||
Ok(QuicProxyRelay {
|
||||
endpoint,
|
||||
relay_task,
|
||||
real_client_map,
|
||||
})
|
||||
}
|
||||
|
||||
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
|
||||
@@ -144,6 +152,7 @@ async fn quic_proxy_relay_loop(
|
||||
external_socket: Arc<UdpSocket>,
|
||||
quinn_internal_addr: SocketAddr,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
@@ -184,26 +193,43 @@ async fn quic_proxy_relay_loop(
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
|
||||
debug!(
|
||||
"QUIC PROXY v2 from {}: real client {}",
|
||||
src_addr, header.source_addr
|
||||
);
|
||||
proxy_addr_map.insert(src_addr, header.source_addr);
|
||||
continue; // consume the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
|
||||
debug!(
|
||||
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
|
||||
src_addr, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine real client address
|
||||
let real_client = proxy_addr_map.get(&src_addr)
|
||||
let real_client = proxy_addr_map
|
||||
.get(&src_addr)
|
||||
.map(|r| *r)
|
||||
.unwrap_or(src_addr);
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
|
||||
debug!(
|
||||
"QUIC datagram from {} blocked by global security policy",
|
||||
real_client
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get or create relay session for this external source
|
||||
let session = match relay_sessions.get(&src_addr) {
|
||||
Some(s) => {
|
||||
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
s.last_activity
|
||||
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
Arc::clone(s.value())
|
||||
}
|
||||
None => {
|
||||
@@ -216,7 +242,10 @@ async fn quic_proxy_relay_loop(
|
||||
}
|
||||
};
|
||||
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
|
||||
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
|
||||
warn!(
|
||||
"QUIC relay: failed to connect relay socket to {}: {}",
|
||||
quinn_internal_addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let relay_local_addr = match relay_socket.local_addr() {
|
||||
@@ -248,8 +277,10 @@ async fn quic_proxy_relay_loop(
|
||||
});
|
||||
|
||||
relay_sessions.insert(src_addr, Arc::clone(&session));
|
||||
debug!("QUIC relay: new session for {} (relay {}), real client {}",
|
||||
src_addr, relay_local_addr, real_client);
|
||||
debug!(
|
||||
"QUIC relay: new session for {} (relay {}), real client {}",
|
||||
src_addr, relay_local_addr, real_client
|
||||
);
|
||||
|
||||
session
|
||||
}
|
||||
@@ -264,9 +295,11 @@ async fn quic_proxy_relay_loop(
|
||||
if last_cleanup.elapsed() >= cleanup_interval {
|
||||
last_cleanup = Instant::now();
|
||||
let now_ms = epoch.elapsed().as_millis() as u64;
|
||||
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
|
||||
let stale_keys: Vec<SocketAddr> = relay_sessions
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
||||
let age =
|
||||
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
||||
age > session_timeout_ms
|
||||
})
|
||||
.map(|entry| *entry.key())
|
||||
@@ -287,13 +320,17 @@ async fn quic_proxy_relay_loop(
|
||||
|
||||
// Also clean orphaned proxy_addr_map entries (PROXY header received
|
||||
// but no relay session was ever created, e.g. client never sent data)
|
||||
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
|
||||
let orphaned: Vec<SocketAddr> = proxy_addr_map
|
||||
.iter()
|
||||
.filter(|entry| relay_sessions.get(entry.key()).is_none())
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
for key in orphaned {
|
||||
proxy_addr_map.remove(&key);
|
||||
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
|
||||
debug!(
|
||||
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
|
||||
key
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -328,8 +365,14 @@ async fn relay_return_path(
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
|
||||
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
|
||||
if let Err(e) = external_socket
|
||||
.send_to(&buf[..len], external_src_addr)
|
||||
.await
|
||||
{
|
||||
debug!(
|
||||
"QUIC relay return send error to {}: {}",
|
||||
external_src_addr, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -353,6 +396,7 @@ pub async fn quic_accept_loop(
|
||||
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
) {
|
||||
loop {
|
||||
let incoming = tokio::select! {
|
||||
@@ -374,11 +418,21 @@ pub async fn quic_accept_loop(
|
||||
let remote_addr = incoming.remote_address();
|
||||
|
||||
// Resolve real client IP from PROXY protocol map if available
|
||||
let real_addr = real_client_map.as_ref()
|
||||
let real_addr = real_client_map
|
||||
.as_ref()
|
||||
.and_then(|map| map.get(&remote_addr).map(|r| *r))
|
||||
.unwrap_or(remote_addr);
|
||||
let ip = real_addr.ip();
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&ip) {
|
||||
debug!(
|
||||
"QUIC connection from {} blocked by global security policy",
|
||||
real_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Per-IP rate limiting
|
||||
if !conn_tracker.try_accept(&ip) {
|
||||
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
|
||||
@@ -414,18 +468,22 @@ pub async fn quic_accept_loop(
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security, &ip, ctx.domain,
|
||||
) {
|
||||
debug!("QUIC connection from {} blocked by route security", real_addr);
|
||||
debug!(
|
||||
"QUIC connection from {} blocked by route security",
|
||||
real_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
conn_tracker.connection_opened(&ip);
|
||||
let route_id = route.name.clone().or(route.id.clone());
|
||||
let route_id = route.metrics_key().map(str::to_string);
|
||||
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
|
||||
|
||||
// Resolve per-route cancel token (child of global cancel)
|
||||
let route_cancel = match route_id.as_deref() {
|
||||
Some(id) => route_cancels.entry(id.to_string())
|
||||
Some(id) => route_cancels
|
||||
.entry(id.to_string())
|
||||
.or_insert_with(|| cancel.child_token())
|
||||
.clone(),
|
||||
None => cancel.child_token(),
|
||||
@@ -445,7 +503,11 @@ pub async fn quic_accept_loop(
|
||||
let metrics = Arc::clone(&metrics);
|
||||
let conn_tracker = Arc::clone(&conn_tracker);
|
||||
let h3_svc = h3_service.clone();
|
||||
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
|
||||
let real_client_addr = if real_addr != remote_addr {
|
||||
Some(real_addr)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Register in connection registry (RAII guard removes on drop)
|
||||
@@ -462,7 +524,8 @@ pub async fn quic_accept_loop(
|
||||
impl Drop for QuicConnGuard {
|
||||
fn drop(&mut self) {
|
||||
self.tracker.connection_closed(&self.ip);
|
||||
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
||||
self.metrics
|
||||
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
||||
}
|
||||
}
|
||||
let _guard = QuicConnGuard {
|
||||
@@ -473,7 +536,17 @@ pub async fn quic_accept_loop(
|
||||
route_id,
|
||||
};
|
||||
|
||||
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &conn_cancel, h3_svc, real_client_addr).await {
|
||||
match handle_quic_connection(
|
||||
incoming,
|
||||
route,
|
||||
port,
|
||||
Arc::clone(&metrics),
|
||||
&conn_cancel,
|
||||
h3_svc,
|
||||
real_client_addr,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
|
||||
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
|
||||
}
|
||||
@@ -501,17 +574,28 @@ async fn handle_quic_connection(
|
||||
debug!("QUIC connection established from {}", effective_addr);
|
||||
|
||||
// Check if this route has HTTP/3 enabled
|
||||
let enable_http3 = route.action.udp.as_ref()
|
||||
let enable_http3 = route
|
||||
.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.and_then(|q| q.enable_http3)
|
||||
.unwrap_or(false);
|
||||
|
||||
if enable_http3 {
|
||||
if let Some(ref h3_svc) = h3_service {
|
||||
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
|
||||
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
|
||||
debug!(
|
||||
"HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
|
||||
route.name
|
||||
);
|
||||
h3_svc
|
||||
.handle_connection(connection, &route, port, real_client_addr, cancel)
|
||||
.await
|
||||
} else {
|
||||
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
|
||||
warn!(
|
||||
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
|
||||
route.name
|
||||
);
|
||||
// Keep connection alive until cancelled
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {}
|
||||
@@ -523,7 +607,8 @@ async fn handle_quic_connection(
|
||||
}
|
||||
} else {
|
||||
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
|
||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -541,11 +626,14 @@ async fn handle_quic_stream_forwarding(
|
||||
real_client_addr: Option<SocketAddr>,
|
||||
) -> anyhow::Result<()> {
|
||||
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
||||
let route_id = route.name.as_deref().or(route.id.as_deref());
|
||||
let route_id = route.metrics_key();
|
||||
let metrics_arc = metrics;
|
||||
|
||||
// Resolve backend target
|
||||
let target = route.action.targets.as_ref()
|
||||
let target = route
|
||||
.action
|
||||
.targets
|
||||
.as_ref()
|
||||
.and_then(|t| t.first())
|
||||
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
|
||||
let backend_host = target.host.first();
|
||||
@@ -576,19 +664,20 @@ async fn handle_quic_stream_forwarding(
|
||||
|
||||
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
||||
tokio::spawn(async move {
|
||||
match forward_quic_stream_to_tcp(
|
||||
send_stream,
|
||||
recv_stream,
|
||||
&backend_addr,
|
||||
stream_cancel,
|
||||
).await {
|
||||
match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
|
||||
.await
|
||||
{
|
||||
Ok((bytes_in, bytes_out)) => {
|
||||
stream_metrics.record_bytes(
|
||||
bytes_in, bytes_out,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
stream_route_id.as_deref(),
|
||||
Some(&ip_str),
|
||||
);
|
||||
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
|
||||
debug!(
|
||||
"QUIC stream forwarded: {}B in, {}B out",
|
||||
bytes_in, bytes_out
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("QUIC stream forwarding error: {}", e);
|
||||
@@ -640,10 +729,7 @@ async fn forward_quic_stream_to_tcp(
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
tcp_write.shutdown(),
|
||||
).await;
|
||||
let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
@@ -721,8 +807,8 @@ mod tests {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Generate a single self-signed cert and use its key pair
|
||||
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
|
||||
.unwrap();
|
||||
let self_signed =
|
||||
rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
|
||||
let cert_der = self_signed.cert.der().clone();
|
||||
let key_der = self_signed.key_pair.serialize_der();
|
||||
|
||||
@@ -737,6 +823,10 @@ mod tests {
|
||||
|
||||
// Port 0 = OS assigns a free port
|
||||
let result = create_quic_endpoint(0, Arc::new(tls_config));
|
||||
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"QUIC endpoint creation failed: {:?}",
|
||||
result.err()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes) - informational, we parse incrementally
|
||||
let _handshake_len = ((data[6] as usize) << 16)
|
||||
| ((data[7] as usize) << 8)
|
||||
| (data[8] as usize);
|
||||
let _handshake_len =
|
||||
((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
|
||||
|
||||
let hello = &data[9..];
|
||||
|
||||
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
|
||||
pub fn extract_http_host(data: &[u8]) -> Option<String> {
|
||||
let text = std::str::from_utf8(data).ok()?;
|
||||
for line in text.split("\r\n") {
|
||||
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
|
||||
if let Some(value) = line
|
||||
.strip_prefix("Host: ")
|
||||
.or_else(|| line.strip_prefix("host: "))
|
||||
{
|
||||
// Strip port if present
|
||||
let host = value.split(':').next().unwrap_or(value).trim();
|
||||
if !host.is_empty() {
|
||||
@@ -196,7 +198,7 @@ pub fn is_http(data: &[u8]) -> bool {
|
||||
b"PATC",
|
||||
b"OPTI",
|
||||
b"CONN",
|
||||
b"PRI ", // HTTP/2 connection preface
|
||||
b"PRI ", // HTTP/2 connection preface
|
||||
];
|
||||
starts.iter().any(|s| data.starts_with(s))
|
||||
}
|
||||
@@ -213,7 +215,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_too_short() {
|
||||
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
|
||||
assert!(matches!(
|
||||
extract_sni(&[0x16, 0x03]),
|
||||
SniResult::NeedMoreData
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -263,7 +268,8 @@ mod tests {
|
||||
// Extension: type=0x0000 (SNI), length, data
|
||||
let sni_extension = {
|
||||
let mut e = Vec::new();
|
||||
e.push(0x00); e.push(0x00); // SNI type
|
||||
e.push(0x00);
|
||||
e.push(0x00); // SNI type
|
||||
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
|
||||
e.push((sni_ext_data.len() & 0xFF) as u8);
|
||||
e.extend_from_slice(&sni_ext_data);
|
||||
@@ -283,16 +289,20 @@ mod tests {
|
||||
let hello_body = {
|
||||
let mut h = Vec::new();
|
||||
// Client version TLS 1.2
|
||||
h.push(0x03); h.push(0x03);
|
||||
h.push(0x03);
|
||||
h.push(0x03);
|
||||
// Random (32 bytes)
|
||||
h.extend_from_slice(&[0u8; 32]);
|
||||
// Session ID length = 0
|
||||
h.push(0x00);
|
||||
// Cipher suites: length=2, one suite
|
||||
h.push(0x00); h.push(0x02);
|
||||
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
// Compression methods: length=1, null
|
||||
h.push(0x01); h.push(0x00);
|
||||
h.push(0x00);
|
||||
h.push(0x02);
|
||||
h.push(0x00);
|
||||
h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
// Compression methods: length=1, null
|
||||
h.push(0x01);
|
||||
h.push(0x00);
|
||||
// Extensions
|
||||
h.extend_from_slice(&extensions);
|
||||
h
|
||||
@@ -302,7 +312,7 @@ mod tests {
|
||||
let handshake = {
|
||||
let mut hs = Vec::new();
|
||||
hs.push(0x01); // ClientHello
|
||||
// 3-byte length
|
||||
// 3-byte length
|
||||
hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
|
||||
hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
|
||||
hs.push((hello_body.len() & 0xFF) as u8);
|
||||
@@ -313,7 +323,8 @@ mod tests {
|
||||
// TLS record: type=0x16, version TLS 1.0, length
|
||||
let mut record = Vec::new();
|
||||
record.push(0x16); // Handshake
|
||||
record.push(0x03); record.push(0x01); // TLS 1.0
|
||||
record.push(0x03);
|
||||
record.push(0x01); // TLS 1.0
|
||||
record.push(((handshake.len() >> 8) & 0xFF) as u8);
|
||||
record.push((handshake.len() & 0xFF) as u8);
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
|
||||
use rustls::sign::CertifiedKey;
|
||||
use rustls::ServerConfig;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
|
||||
use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::tcp_listener::TlsCertConfig;
|
||||
@@ -29,7 +29,9 @@ pub struct CertResolver {
|
||||
impl CertResolver {
|
||||
/// Build a resolver from PEM-encoded cert/key configs.
|
||||
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
|
||||
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn new(
|
||||
configs: &HashMap<String, TlsCertConfig>,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let provider = rustls::crypto::ring::default_provider();
|
||||
let mut certs = HashMap::new();
|
||||
@@ -38,8 +40,10 @@ impl CertResolver {
|
||||
for (domain, cfg) in configs {
|
||||
let cert_chain = load_certs(&cfg.cert_pem)?;
|
||||
let key = load_private_key(&cfg.key_pem)?;
|
||||
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
|
||||
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
|
||||
let ck = Arc::new(
|
||||
CertifiedKey::from_der(cert_chain, key, &provider)
|
||||
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
|
||||
);
|
||||
if domain == "*" {
|
||||
fallback = Some(Arc::clone(&ck));
|
||||
}
|
||||
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
|
||||
|
||||
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
|
||||
/// The returned acceptor can be reused across all connections (cheap Arc clone).
|
||||
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_shared_tls_acceptor(
|
||||
resolver: CertResolver,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let mut config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
|
||||
// Shared session cache — enables session ID resumption across connections
|
||||
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
|
||||
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
|
||||
config.ticketer = rustls::crypto::ring::Ticketer::new()
|
||||
.map_err(|e| format!("Ticketer: {}", e))?;
|
||||
config.ticketer =
|
||||
rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
|
||||
|
||||
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1");
|
||||
info!(
|
||||
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
|
||||
);
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor from PEM-encoded cert and key data.
|
||||
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
|
||||
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_tls_acceptor(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
build_tls_acceptor_with_config(cert_pem, key_pem, None)
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
|
||||
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
|
||||
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_tls_acceptor_h1_only(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let certs = load_certs(cert_pem)?;
|
||||
let key = load_private_key(key_pem)?;
|
||||
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
|
||||
// Apply TLS version restrictions
|
||||
let versions = resolve_tls_versions(route_tls.versions.as_deref());
|
||||
let builder = ServerConfig::builder_with_protocol_versions(&versions);
|
||||
builder
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?
|
||||
builder.with_no_client_auth().with_single_cert(certs, key)?
|
||||
} else {
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
|
||||
}
|
||||
|
||||
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
|
||||
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||
fn resolve_tls_versions(
|
||||
versions: Option<&[String]>,
|
||||
) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||
let versions = match versions {
|
||||
Some(v) if !v.is_empty() => v,
|
||||
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
|
||||
@@ -207,15 +221,17 @@ pub async fn accept_tls(
|
||||
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
||||
|
||||
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
||||
SHARED_CLIENT_CONFIG.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
info!("Built shared backend TLS client config with session resumption");
|
||||
Arc::new(config)
|
||||
}).clone()
|
||||
SHARED_CLIENT_CONFIG
|
||||
.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
info!("Built shared backend TLS client config with session resumption");
|
||||
Arc::new(config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
|
||||
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
||||
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
||||
|
||||
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
|
||||
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection");
|
||||
Arc::new(config)
|
||||
}).clone()
|
||||
SHARED_CLIENT_CONFIG_ALPN
|
||||
.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
info!(
|
||||
"Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
|
||||
);
|
||||
Arc::new(config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
|
||||
@@ -249,7 +269,8 @@ pub async fn connect_tls(
|
||||
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
|
||||
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) {
|
||||
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
|
||||
{
|
||||
debug!("Failed to set keepalive on backend TLS socket: {}", e);
|
||||
}
|
||||
|
||||
@@ -260,10 +281,12 @@ pub async fn connect_tls(
|
||||
}
|
||||
|
||||
/// Load certificates from PEM string.
|
||||
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
fn load_certs(
|
||||
pem: &str,
|
||||
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let certs: Vec<CertificateDer<'static>> =
|
||||
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
|
||||
if certs.is_empty() {
|
||||
return Err("No certificates found in PEM data".into());
|
||||
}
|
||||
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
|
||||
}
|
||||
|
||||
/// Load private key from PEM string.
|
||||
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
fn load_private_key(
|
||||
pem: &str,
|
||||
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
// Try PKCS8 first, then RSA, then EC
|
||||
let key = rustls_pemfile::private_key(&mut reader)?
|
||||
.ok_or("No private key found in PEM data")?;
|
||||
let key =
|
||||
rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,14 +17,15 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use rustproxy_config::{RouteActionType, TransportProtocol};
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::{MatchContext, RouteManager};
|
||||
use rustproxy_security::IpBlockList;
|
||||
|
||||
use rustproxy_http::h3_service::H3ProxyService;
|
||||
|
||||
@@ -62,6 +63,8 @@ pub struct UdpListenerManager {
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
/// Shared connection registry for selective recycling.
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
/// Global ingress block policy, hot-reloadable without restarting listeners.
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
}
|
||||
|
||||
impl Drop for UdpListenerManager {
|
||||
@@ -99,17 +102,26 @@ impl UdpListenerManager {
|
||||
proxy_ips: Arc::new(Vec::new()),
|
||||
route_cancels,
|
||||
connection_registry,
|
||||
security_policy: Arc::new(ArcSwap::from(Arc::new(IpBlockList::empty()))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
|
||||
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
|
||||
if !ips.is_empty() {
|
||||
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len());
|
||||
info!(
|
||||
"UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs",
|
||||
ips.len()
|
||||
);
|
||||
}
|
||||
self.proxy_ips = Arc::new(ips);
|
||||
}
|
||||
|
||||
/// Set the shared global ingress security policy.
|
||||
pub fn set_security_policy(&mut self, policy: Arc<ArcSwap<IpBlockList>>) {
|
||||
self.security_policy = policy;
|
||||
}
|
||||
|
||||
/// Set the H3 proxy service for HTTP/3 request handling.
|
||||
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
|
||||
self.h3_service = Some(svc);
|
||||
@@ -142,7 +154,9 @@ impl UdpListenerManager {
|
||||
// Check if any route on this port uses QUIC
|
||||
let rm = self.route_manager.load();
|
||||
let has_quic = rm.routes_for_port(port).iter().any(|r| {
|
||||
r.action.udp.as_ref()
|
||||
r.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.is_some()
|
||||
});
|
||||
@@ -164,8 +178,10 @@ impl UdpListenerManager {
|
||||
None,
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
info!("QUIC endpoint started on port {}", port);
|
||||
} else {
|
||||
// Proxy relay path: we own external socket, quinn on localhost
|
||||
@@ -173,6 +189,7 @@ impl UdpListenerManager {
|
||||
port,
|
||||
tls,
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
self.cancel_token.child_token(),
|
||||
)?;
|
||||
let endpoint_for_updates = relay.endpoint.clone();
|
||||
@@ -187,13 +204,18 @@ impl UdpListenerManager {
|
||||
Some(relay.real_client_map),
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
info!("QUIC endpoint with PROXY relay started on port {}", port);
|
||||
}
|
||||
return Ok(());
|
||||
} else {
|
||||
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
|
||||
warn!(
|
||||
"QUIC routes on port {} but no TLS config provided, falling back to raw UDP",
|
||||
port
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,6 +236,7 @@ impl UdpListenerManager {
|
||||
Arc::clone(&self.relay_writer),
|
||||
self.cancel_token.child_token(),
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
|
||||
self.listeners.insert(port, (handle, None));
|
||||
@@ -254,8 +277,10 @@ impl UdpListenerManager {
|
||||
}
|
||||
debug!("UDP listener stopped on port {}", port);
|
||||
}
|
||||
info!("All UDP listeners stopped, {} sessions remaining",
|
||||
self.session_table.session_count());
|
||||
info!(
|
||||
"All UDP listeners stopped, {} sessions remaining",
|
||||
self.session_table.session_count()
|
||||
);
|
||||
}
|
||||
|
||||
/// Update TLS config on all active QUIC endpoints (cert refresh).
|
||||
@@ -288,11 +313,15 @@ impl UdpListenerManager {
|
||||
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
|
||||
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
|
||||
let rm = self.route_manager.load();
|
||||
let upgrade_ports: Vec<u16> = self.listeners.iter()
|
||||
let upgrade_ports: Vec<u16> = self
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|(_, (_, endpoint))| endpoint.is_none())
|
||||
.filter(|(port, _)| {
|
||||
rm.routes_for_port(**port).iter().any(|r| {
|
||||
r.action.udp.as_ref()
|
||||
r.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.is_some()
|
||||
})
|
||||
@@ -301,17 +330,23 @@ impl UdpListenerManager {
|
||||
.collect();
|
||||
|
||||
for port in upgrade_ports {
|
||||
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port);
|
||||
info!(
|
||||
"Upgrading raw UDP listener on port {} to QUIC endpoint",
|
||||
port
|
||||
);
|
||||
|
||||
// Stop the raw UDP listener task and drain sessions to release the socket
|
||||
if let Some((handle, _)) = self.listeners.remove(&port) {
|
||||
handle.abort();
|
||||
}
|
||||
let drained = self.session_table.drain_port(
|
||||
port, &self.metrics, &self.conn_tracker,
|
||||
);
|
||||
let drained = self
|
||||
.session_table
|
||||
.drain_port(port, &self.metrics, &self.conn_tracker);
|
||||
if drained > 0 {
|
||||
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port);
|
||||
debug!(
|
||||
"Drained {} UDP sessions on port {} for QUIC upgrade",
|
||||
drained, port
|
||||
);
|
||||
}
|
||||
|
||||
// Brief yield to let aborted tasks drop their socket references
|
||||
@@ -326,11 +361,17 @@ impl UdpListenerManager {
|
||||
|
||||
match create_result {
|
||||
Ok(()) => {
|
||||
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port);
|
||||
info!(
|
||||
"QUIC endpoint started on port {} (upgraded from raw UDP)",
|
||||
port
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Port may still be held — retry once after a brief delay
|
||||
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e);
|
||||
warn!(
|
||||
"QUIC endpoint creation failed on port {}, retrying: {}",
|
||||
port, e
|
||||
);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
let retry_result = if self.proxy_ips.is_empty() {
|
||||
@@ -341,11 +382,17 @@ impl UdpListenerManager {
|
||||
|
||||
match retry_result {
|
||||
Ok(()) => {
|
||||
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port);
|
||||
info!(
|
||||
"QUIC endpoint started on port {} (upgraded from raw UDP, retry)",
|
||||
port
|
||||
);
|
||||
}
|
||||
Err(e2) => {
|
||||
error!("Failed to upgrade port {} to QUIC after retry: {}. \
|
||||
Rebinding as raw UDP.", port, e2);
|
||||
error!(
|
||||
"Failed to upgrade port {} to QUIC after retry: {}. \
|
||||
Rebinding as raw UDP.",
|
||||
port, e2
|
||||
);
|
||||
// Fallback: rebind as raw UDP so the port isn't dead
|
||||
if let Ok(()) = self.rebind_raw_udp(port).await {
|
||||
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
|
||||
@@ -358,7 +405,11 @@ impl UdpListenerManager {
|
||||
}
|
||||
|
||||
/// Create a direct QUIC endpoint (quinn owns the socket).
|
||||
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
||||
fn create_quic_direct(
|
||||
&mut self,
|
||||
port: u16,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
|
||||
let endpoint_for_updates = endpoint.clone();
|
||||
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||
@@ -372,17 +423,24 @@ impl UdpListenerManager {
|
||||
None,
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a QUIC endpoint with PROXY protocol relay.
|
||||
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
||||
fn create_quic_with_relay(
|
||||
&mut self,
|
||||
port: u16,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
|
||||
port,
|
||||
tls_config,
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
self.cancel_token.child_token(),
|
||||
)?;
|
||||
let endpoint_for_updates = relay.endpoint.clone();
|
||||
@@ -397,8 +455,10 @@ impl UdpListenerManager {
|
||||
Some(relay.real_client_map),
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -419,6 +479,7 @@ impl UdpListenerManager {
|
||||
Arc::clone(&self.relay_writer),
|
||||
self.cancel_token.child_token(),
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
|
||||
self.listeners.insert(port, (handle, None));
|
||||
@@ -458,7 +519,10 @@ impl UdpListenerManager {
|
||||
info!("Datagram handler relay connected to {}", path);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to connect datagram handler relay to {}: {}", path, e);
|
||||
error!(
|
||||
"Failed to connect datagram handler relay to {}: {}",
|
||||
path, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -514,6 +578,7 @@ impl UdpListenerManager {
|
||||
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
||||
cancel: CancellationToken,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
) {
|
||||
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
|
||||
let mut buf = vec![0u8; 65535];
|
||||
@@ -528,9 +593,11 @@ impl UdpListenerManager {
|
||||
|
||||
loop {
|
||||
// Periodic cleanup: remove proxy_addr_map entries with no active session
|
||||
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval {
|
||||
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval
|
||||
{
|
||||
last_proxy_cleanup = tokio::time::Instant::now();
|
||||
let stale: Vec<SocketAddr> = proxy_addr_map.iter()
|
||||
let stale: Vec<SocketAddr> = proxy_addr_map
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let key: SessionKey = (*entry.key(), port);
|
||||
session_table.get(&key).is_none()
|
||||
@@ -538,7 +605,11 @@ impl UdpListenerManager {
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
if !stale.is_empty() {
|
||||
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port);
|
||||
debug!(
|
||||
"UDP proxy_addr_map cleanup: removing {} stale entries on port {}",
|
||||
stale.len(),
|
||||
port
|
||||
);
|
||||
for addr in stale {
|
||||
proxy_addr_map.remove(&addr);
|
||||
}
|
||||
@@ -564,34 +635,50 @@ impl UdpListenerManager {
|
||||
let datagram = &buf[..len];
|
||||
|
||||
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
|
||||
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
||||
let session_key: SessionKey = (client_addr, port);
|
||||
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) {
|
||||
// No session and no prior PROXY header — check for PROXY v2
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr);
|
||||
proxy_addr_map.insert(client_addr, header.source_addr);
|
||||
continue; // discard the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
|
||||
client_addr.ip()
|
||||
let effective_client_ip =
|
||||
if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
||||
let session_key: SessionKey = (client_addr, port);
|
||||
if session_table.get(&session_key).is_none()
|
||||
&& !proxy_addr_map.contains_key(&client_addr)
|
||||
{
|
||||
// No session and no prior PROXY header — check for PROXY v2
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!(
|
||||
"UDP PROXY v2 from {}: real client {}",
|
||||
client_addr, header.source_addr
|
||||
);
|
||||
proxy_addr_map.insert(client_addr, header.source_addr);
|
||||
continue; // discard the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
|
||||
client_addr.ip()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
// Use real client IP if we've previously seen a PROXY v2 header
|
||||
proxy_addr_map
|
||||
.get(&client_addr)
|
||||
.map(|r| r.ip())
|
||||
.unwrap_or_else(|| client_addr.ip())
|
||||
}
|
||||
} else {
|
||||
// Use real client IP if we've previously seen a PROXY v2 header
|
||||
proxy_addr_map.get(&client_addr)
|
||||
.map(|r| r.ip())
|
||||
.unwrap_or_else(|| client_addr.ip())
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
};
|
||||
client_addr.ip()
|
||||
};
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&effective_client_ip) {
|
||||
debug!(
|
||||
"UDP datagram from {} blocked by global security policy",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Route matching — use effective (real) client IP
|
||||
let rm = route_manager.load();
|
||||
@@ -611,13 +698,16 @@ impl UdpListenerManager {
|
||||
let route_match = match rm.find_route(&ctx) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
debug!("No UDP route matched for port {} from {}", port, client_addr);
|
||||
debug!(
|
||||
"No UDP route matched for port {} from {}",
|
||||
port, client_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let route = route_match.route;
|
||||
let route_id = route.name.as_deref().or(route.id.as_deref());
|
||||
let route_id = route.metrics_key();
|
||||
|
||||
// Socket handler routes → relay datagram to TS via persistent Unix socket
|
||||
if route.action.action_type == RouteActionType::SocketHandler {
|
||||
@@ -627,7 +717,9 @@ impl UdpListenerManager {
|
||||
&client_addr,
|
||||
port,
|
||||
datagram,
|
||||
).await {
|
||||
)
|
||||
.await
|
||||
{
|
||||
debug!("Failed to relay UDP datagram to TS: {}", e);
|
||||
}
|
||||
continue;
|
||||
@@ -638,8 +730,10 @@ impl UdpListenerManager {
|
||||
|
||||
// Check datagram size
|
||||
if len as u32 > udp_config.max_datagram_size {
|
||||
debug!("UDP datagram too large ({} > {}) from {}, dropping",
|
||||
len, udp_config.max_datagram_size, client_addr);
|
||||
debug!(
|
||||
"UDP datagram too large ({} > {}) from {}, dropping",
|
||||
len, udp_config.max_datagram_size, client_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -651,21 +745,27 @@ impl UdpListenerManager {
|
||||
None => {
|
||||
// New session — check per-IP limits using the real client IP
|
||||
if !conn_tracker.try_accept(&effective_client_ip) {
|
||||
debug!("UDP session rejected for {} (rate limit)", effective_client_ip);
|
||||
debug!(
|
||||
"UDP session rejected for {} (rate limit)",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if !session_table.can_create_session(
|
||||
&effective_client_ip,
|
||||
udp_config.max_sessions_per_ip,
|
||||
) {
|
||||
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip);
|
||||
if !session_table
|
||||
.can_create_session(&effective_client_ip, udp_config.max_sessions_per_ip)
|
||||
{
|
||||
debug!(
|
||||
"UDP session rejected for {} (per-IP session limit)",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Resolve target
|
||||
let target = match route_match.target.or_else(|| {
|
||||
route.action.targets.as_ref().and_then(|t| t.first())
|
||||
}) {
|
||||
let target = match route_match
|
||||
.target
|
||||
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
|
||||
{
|
||||
Some(t) => t,
|
||||
None => {
|
||||
warn!("No target for UDP route {:?}", route_id);
|
||||
@@ -686,13 +786,18 @@ impl UdpListenerManager {
|
||||
}
|
||||
};
|
||||
if let Err(e) = backend_socket.connect(&backend_addr).await {
|
||||
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e);
|
||||
error!(
|
||||
"Failed to connect backend UDP socket to {}: {}",
|
||||
backend_addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let backend_socket = Arc::new(backend_socket);
|
||||
|
||||
debug!("New UDP session: {} -> {} (via port {}, real client {})",
|
||||
client_addr, backend_addr, port, effective_client_ip);
|
||||
debug!(
|
||||
"New UDP session: {} -> {} (via port {}, real client {})",
|
||||
client_addr, backend_addr, port, effective_client_ip
|
||||
);
|
||||
|
||||
// Spawn return-path relay task
|
||||
let session_cancel = CancellationToken::new();
|
||||
@@ -709,7 +814,9 @@ impl UdpListenerManager {
|
||||
|
||||
let session = Arc::new(UdpSession {
|
||||
backend_socket,
|
||||
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
|
||||
last_activity: std::sync::atomic::AtomicU64::new(
|
||||
session_table.elapsed_ms(),
|
||||
),
|
||||
created_at: std::time::Instant::now(),
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
source_ip: effective_client_ip,
|
||||
@@ -718,7 +825,11 @@ impl UdpListenerManager {
|
||||
cancel: session_cancel,
|
||||
});
|
||||
|
||||
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) {
|
||||
if !session_table.insert(
|
||||
session_key,
|
||||
Arc::clone(&session),
|
||||
udp_config.max_sessions_per_ip,
|
||||
) {
|
||||
warn!("Failed to insert UDP session (race condition)");
|
||||
continue;
|
||||
}
|
||||
@@ -735,7 +846,9 @@ impl UdpListenerManager {
|
||||
// Forward datagram to backend
|
||||
match session.backend_socket.send(datagram).await {
|
||||
Ok(_) => {
|
||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
session
|
||||
.last_activity
|
||||
.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
|
||||
metrics.record_datagram_in();
|
||||
}
|
||||
@@ -779,7 +892,9 @@ impl UdpListenerManager {
|
||||
Ok(_) => {
|
||||
// Update session activity
|
||||
if let Some(session) = session_table.get(&session_key) {
|
||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
session
|
||||
.last_activity
|
||||
.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
}
|
||||
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
|
||||
metrics.record_datagram_out();
|
||||
@@ -814,7 +929,8 @@ impl UdpListenerManager {
|
||||
let json = serde_json::to_vec(&msg)?;
|
||||
|
||||
let mut guard = writer.lock().await;
|
||||
let stream = guard.as_mut()
|
||||
let stream = guard
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
|
||||
|
||||
// Length-prefixed frame
|
||||
@@ -879,9 +995,15 @@ impl UdpListenerManager {
|
||||
}
|
||||
|
||||
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let source_port = reply.get("sourcePort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let source_port = reply
|
||||
.get("sourcePort")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u16;
|
||||
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let payload_b64 = reply.get("payloadBase64").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let payload_b64 = reply
|
||||
.get("payloadBase64")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
|
||||
Ok(p) => p,
|
||||
|
||||
@@ -111,12 +111,15 @@ impl UdpSessionTable {
|
||||
|
||||
/// Look up an existing session.
|
||||
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
||||
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
|
||||
self.sessions
|
||||
.get(key)
|
||||
.map(|entry| Arc::clone(entry.value()))
|
||||
}
|
||||
|
||||
/// Check if we can create a new session for this IP (under the per-IP limit).
|
||||
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
|
||||
let count = self.ip_session_counts
|
||||
let count = self
|
||||
.ip_session_counts
|
||||
.get(ip)
|
||||
.map(|c| *c.value())
|
||||
.unwrap_or(0);
|
||||
@@ -124,12 +127,7 @@ impl UdpSessionTable {
|
||||
}
|
||||
|
||||
/// Insert a new session. Returns false if per-IP limit exceeded.
|
||||
pub fn insert(
|
||||
&self,
|
||||
key: SessionKey,
|
||||
session: Arc<UdpSession>,
|
||||
max_per_ip: u32,
|
||||
) -> bool {
|
||||
pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
|
||||
let ip = session.source_ip;
|
||||
|
||||
// Atomically check and increment per-IP count
|
||||
@@ -173,7 +171,9 @@ impl UdpSessionTable {
|
||||
let mut removed = 0;
|
||||
|
||||
// Collect keys to remove (avoid holding DashMap refs during removal)
|
||||
let stale_keys: Vec<SessionKey> = self.sessions.iter()
|
||||
let stale_keys: Vec<SessionKey> = self
|
||||
.sessions
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let last = entry.value().last_activity.load(Ordering::Relaxed);
|
||||
now_ms.saturating_sub(last) >= timeout_ms
|
||||
@@ -185,7 +185,8 @@ impl UdpSessionTable {
|
||||
if let Some(session) = self.remove(&key) {
|
||||
debug!(
|
||||
"UDP session expired: {} -> port {} (idle {}ms)",
|
||||
session.client_addr, key.1,
|
||||
session.client_addr,
|
||||
key.1,
|
||||
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
|
||||
);
|
||||
conn_tracker.connection_closed(&session.source_ip);
|
||||
@@ -210,7 +211,9 @@ impl UdpSessionTable {
|
||||
metrics: &MetricsCollector,
|
||||
conn_tracker: &ConnectionTracker,
|
||||
) -> usize {
|
||||
let keys: Vec<SessionKey> = self.sessions.iter()
|
||||
let keys: Vec<SessionKey> = self
|
||||
.sessions
|
||||
.iter()
|
||||
.filter(|entry| entry.key().1 == port)
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
@@ -257,9 +260,8 @@ mod tests {
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
let backend_socket = rt.block_on(async {
|
||||
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
|
||||
});
|
||||
let backend_socket =
|
||||
rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
|
||||
|
||||
let child_cancel = cancel.child_token();
|
||||
let return_task = rt.spawn(async move {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! Route matching engine for RustProxy.
|
||||
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
|
||||
|
||||
pub mod route_manager;
|
||||
pub mod matchers;
|
||||
pub mod route_manager;
|
||||
|
||||
pub use route_manager::*;
|
||||
|
||||
@@ -20,7 +20,7 @@ pub fn domain_matches(pattern: &str, domain: &str) -> bool {
|
||||
// Wildcard patterns
|
||||
if pattern.starts_with("*.") || pattern.starts_with("*.") {
|
||||
let suffix = &pattern[2..]; // e.g., "example.com"
|
||||
// Match exact parent or any single-level subdomain
|
||||
// Match exact parent or any single-level subdomain
|
||||
if domain.eq_ignore_ascii_case(suffix) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use ipnet::IpNet;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use ipnet::IpNet;
|
||||
|
||||
/// Match an IP address against a pattern.
|
||||
///
|
||||
@@ -85,7 +85,10 @@ fn wildcard_to_cidr(pattern: &str) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len))
|
||||
Some(format!(
|
||||
"{}.{}.{}.{}/{}",
|
||||
octets[0], octets[1], octets[2], octets[3], prefix_len
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
pub mod domain;
|
||||
pub mod path;
|
||||
pub mod ip;
|
||||
pub mod header;
|
||||
pub mod ip;
|
||||
pub mod path;
|
||||
|
||||
pub use domain::*;
|
||||
pub use path::*;
|
||||
pub use ip::*;
|
||||
pub use header::*;
|
||||
pub use ip::*;
|
||||
pub use path::*;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rustproxy_config::{RouteConfig, RouteTarget, TransportProtocol, TlsMode};
|
||||
use crate::matchers;
|
||||
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode, TransportProtocol};
|
||||
|
||||
/// Context for route matching (subset of connection info).
|
||||
pub struct MatchContext<'a> {
|
||||
@@ -42,19 +42,14 @@ impl RouteManager {
|
||||
};
|
||||
|
||||
// Filter enabled routes and sort by priority
|
||||
let mut enabled_routes: Vec<RouteConfig> = routes
|
||||
.into_iter()
|
||||
.filter(|r| r.is_enabled())
|
||||
.collect();
|
||||
let mut enabled_routes: Vec<RouteConfig> =
|
||||
routes.into_iter().filter(|r| r.is_enabled()).collect();
|
||||
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
|
||||
|
||||
// Build port index
|
||||
for (idx, route) in enabled_routes.iter().enumerate() {
|
||||
for port in route.listening_ports() {
|
||||
manager.port_index
|
||||
.entry(port)
|
||||
.or_default()
|
||||
.push(idx);
|
||||
manager.port_index.entry(port).or_default().push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +61,9 @@ impl RouteManager {
|
||||
/// Used to skip expensive header HashMap construction when no route needs it.
|
||||
pub fn any_route_has_headers(&self, port: u16) -> bool {
|
||||
if let Some(indices) = self.port_index.get(&port) {
|
||||
indices.iter().any(|&idx| self.routes[idx].route_match.headers.is_some())
|
||||
indices
|
||||
.iter()
|
||||
.any(|&idx| self.routes[idx].route_match.headers.is_some())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@@ -99,8 +96,8 @@ impl RouteManager {
|
||||
let ctx_transport = ctx.transport.as_ref();
|
||||
match (route_transport, ctx_transport) {
|
||||
// Route requires UDP only — reject non-UDP contexts
|
||||
(Some(TransportProtocol::Udp), None) |
|
||||
(Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
|
||||
(Some(TransportProtocol::Udp), None)
|
||||
| (Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
|
||||
// Route requires TCP only — reject UDP contexts
|
||||
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
|
||||
// Route has no transport (default = TCP) — reject UDP contexts
|
||||
@@ -196,7 +193,11 @@ impl RouteManager {
|
||||
}
|
||||
|
||||
/// Find the best matching target within a route.
|
||||
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
|
||||
fn find_target<'a>(
|
||||
&self,
|
||||
route: &'a RouteConfig,
|
||||
ctx: &MatchContext<'_>,
|
||||
) -> Option<&'a RouteTarget> {
|
||||
let targets = route.action.targets.as_ref()?;
|
||||
|
||||
if targets.len() == 1 && targets[0].target_match.is_none() {
|
||||
@@ -223,17 +224,11 @@ impl RouteManager {
|
||||
}
|
||||
|
||||
// Fall back to first target without match criteria
|
||||
best.or_else(|| {
|
||||
targets.iter().find(|t| t.target_match.is_none())
|
||||
})
|
||||
best.or_else(|| targets.iter().find(|t| t.target_match.is_none()))
|
||||
}
|
||||
|
||||
/// Check if a target match criteria matches the context.
|
||||
fn matches_target(
|
||||
&self,
|
||||
tm: &rustproxy_config::TargetMatch,
|
||||
ctx: &MatchContext<'_>,
|
||||
) -> bool {
|
||||
fn matches_target(&self, tm: &rustproxy_config::TargetMatch, ctx: &MatchContext<'_>) -> bool {
|
||||
// Port matching
|
||||
if let Some(ref ports) = tm.ports {
|
||||
if !ports.contains(&ctx.port) {
|
||||
@@ -298,9 +293,7 @@ impl RouteManager {
|
||||
// If multiple passthrough routes on same port, SNI is needed
|
||||
let passthrough_routes: Vec<_> = routes
|
||||
.iter()
|
||||
.filter(|r| {
|
||||
r.tls_mode() == Some(&TlsMode::Passthrough)
|
||||
})
|
||||
.filter(|r| r.tls_mode() == Some(&TlsMode::Passthrough))
|
||||
.collect();
|
||||
|
||||
if passthrough_routes.len() > 1 {
|
||||
@@ -419,7 +412,11 @@ mod tests {
|
||||
|
||||
let result = manager.find_route(&ctx).unwrap();
|
||||
// Should match the higher-priority specific route
|
||||
assert!(result.route.route_match.domains.as_ref()
|
||||
assert!(result
|
||||
.route
|
||||
.route_match
|
||||
.domains
|
||||
.as_ref()
|
||||
.map(|d| d.to_vec())
|
||||
.unwrap()
|
||||
.contains(&"api.example.com"));
|
||||
@@ -619,8 +616,14 @@ mod tests {
|
||||
|
||||
let result = manager.find_route(&ctx);
|
||||
assert!(result.is_some());
|
||||
let matched_domains = result.unwrap().route.route_match.domains.as_ref()
|
||||
.map(|d| d.to_vec()).unwrap();
|
||||
let matched_domains = result
|
||||
.unwrap()
|
||||
.route
|
||||
.route_match
|
||||
.domains
|
||||
.as_ref()
|
||||
.map(|d| d.to_vec())
|
||||
.unwrap();
|
||||
assert!(matched_domains.contains(&"*"));
|
||||
}
|
||||
|
||||
@@ -735,7 +738,11 @@ mod tests {
|
||||
assert_eq!(result.target.unwrap().host.first(), "default-backend");
|
||||
}
|
||||
|
||||
fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig {
|
||||
fn make_route_with_protocol(
|
||||
port: u16,
|
||||
domain: Option<&str>,
|
||||
protocol: Option<&str>,
|
||||
) -> RouteConfig {
|
||||
let mut route = make_route(port, domain, 0);
|
||||
route.route_match.protocol = protocol.map(|s| s.to_string());
|
||||
route
|
||||
@@ -1029,8 +1036,10 @@ mod tests {
|
||||
transport: Some(TransportProtocol::Udp),
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_some(),
|
||||
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes");
|
||||
assert!(
|
||||
manager.find_route(&ctx).is_some(),
|
||||
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1051,7 +1060,9 @@ mod tests {
|
||||
transport: None, // TCP (default)
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_none(),
|
||||
"TCP TLS without SNI should NOT match domain-restricted routes");
|
||||
assert!(
|
||||
manager.find_route(&ctx).is_none(),
|
||||
"TCP TLS without SNI should NOT match domain-restricted routes"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
|
||||
/// Basic auth validator.
|
||||
pub struct BasicAuthValidator {
|
||||
|
||||
@@ -21,7 +21,7 @@ struct DomainScopedEntry {
|
||||
}
|
||||
|
||||
/// Represents an IP pattern for matching.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
enum IpPattern {
|
||||
/// Exact IP match
|
||||
Exact(IpAddr),
|
||||
@@ -31,6 +31,37 @@ enum IpPattern {
|
||||
Wildcard,
|
||||
}
|
||||
|
||||
/// Compiled block list for early ingress filtering.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IpBlockList {
|
||||
block_list: Vec<IpPattern>,
|
||||
}
|
||||
|
||||
impl IpBlockList {
|
||||
pub fn new(block_list: &[String]) -> Self {
|
||||
Self {
|
||||
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
block_list: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.block_list.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
|
||||
let normalized = IpFilter::normalize_ip(ip);
|
||||
self.block_list
|
||||
.iter()
|
||||
.any(|pattern| pattern.matches(&normalized))
|
||||
}
|
||||
}
|
||||
|
||||
impl IpPattern {
|
||||
fn parse(s: &str) -> Self {
|
||||
let s = s.trim();
|
||||
@@ -68,8 +99,7 @@ fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
|
||||
}
|
||||
if p.starts_with("*.") {
|
||||
let suffix = &p[1..]; // e.g., ".abc.xyz"
|
||||
d.len() > suffix.len()
|
||||
&& d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
|
||||
d.len() > suffix.len() && d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@@ -127,7 +157,11 @@ impl IpFilter {
|
||||
if let Some(req_domain) = domain {
|
||||
for entry in &self.domain_scoped {
|
||||
if entry.pattern.matches(ip) {
|
||||
if entry.domains.iter().any(|d| domain_matches_pattern(d, req_domain)) {
|
||||
if entry
|
||||
.domains
|
||||
.iter()
|
||||
.any(|d| domain_matches_pattern(d, req_domain))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -212,10 +246,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_block_trumps_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&[plain("10.0.0.0/8")],
|
||||
&["10.0.0.5".to_string()],
|
||||
);
|
||||
let filter = IpFilter::new(&[plain("10.0.0.0/8")], &["10.0.0.5".to_string()]);
|
||||
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
|
||||
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&blocked));
|
||||
@@ -255,30 +286,21 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_allows_matching_domain() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
||||
&[],
|
||||
);
|
||||
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_denies_non_matching_domain() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
||||
&[],
|
||||
);
|
||||
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_denies_without_domain() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
||||
&[],
|
||||
);
|
||||
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
// Without domain context, domain-scoped entries cannot match
|
||||
assert!(!filter.is_allowed_for_domain(&ip, None));
|
||||
@@ -286,10 +308,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_wildcard_domain() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["*.abc.xyz"])],
|
||||
&[],
|
||||
);
|
||||
let filter = IpFilter::new(&[scoped("10.8.0.2", &["*.abc.xyz"])], &[]);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
||||
@@ -300,8 +319,8 @@ mod tests {
|
||||
fn test_plain_and_domain_scoped_coexist() {
|
||||
let filter = IpFilter::new(
|
||||
&[
|
||||
plain("1.2.3.4"), // full route access
|
||||
scoped("10.8.0.2", &["outline.abc.xyz"]), // scoped access
|
||||
plain("1.2.3.4"), // full route access
|
||||
scoped("10.8.0.2", &["outline.abc.xyz"]), // scoped access
|
||||
],
|
||||
&[],
|
||||
);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
|
||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// JWT claims (minimal structure).
|
||||
@@ -160,10 +160,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_bearer() {
|
||||
assert_eq!(
|
||||
JwtValidator::extract_token("Bearer abc123"),
|
||||
Some("abc123")
|
||||
);
|
||||
assert_eq!(JwtValidator::extract_token("Bearer abc123"), Some("abc123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
//!
|
||||
//! IP filtering, rate limiting, and authentication for RustProxy.
|
||||
|
||||
pub mod ip_filter;
|
||||
pub mod rate_limiter;
|
||||
pub mod basic_auth;
|
||||
pub mod ip_filter;
|
||||
pub mod jwt_auth;
|
||||
pub mod rate_limiter;
|
||||
|
||||
pub use ip_filter::*;
|
||||
pub use rate_limiter::*;
|
||||
pub use basic_auth::*;
|
||||
pub use ip_filter::*;
|
||||
pub use jwt_auth::*;
|
||||
pub use rate_limiter::*;
|
||||
|
||||
@@ -79,7 +79,7 @@ mod tests {
|
||||
assert!(limiter.check("client-a"));
|
||||
assert!(limiter.check("client-a"));
|
||||
assert!(!limiter.check("client-a")); // blocked
|
||||
// Different key should still be allowed
|
||||
// Different key should still be allowed
|
||||
assert!(limiter.check("client-b"));
|
||||
assert!(limiter.check("client-b"));
|
||||
}
|
||||
|
||||
@@ -4,8 +4,7 @@
|
||||
//! Account credentials are ephemeral — the consumer owns all persistence.
|
||||
|
||||
use instant_acme::{
|
||||
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
|
||||
AccountCredentials,
|
||||
Account, AccountCredentials, ChallengeType, Identifier, NewAccount, NewOrder, OrderStatus,
|
||||
};
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
use thiserror::Error;
|
||||
@@ -89,7 +88,11 @@ impl AcmeClient {
|
||||
F: FnOnce(PendingChallenge) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(), AcmeError>>,
|
||||
{
|
||||
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
|
||||
info!(
|
||||
"Starting ACME provisioning for {} via {}",
|
||||
domain,
|
||||
self.directory_url()
|
||||
);
|
||||
|
||||
// 1. Get or create ACME account
|
||||
let account = self.get_or_create_account().await?;
|
||||
@@ -170,14 +173,14 @@ impl AcmeClient {
|
||||
debug!("Order ready, finalizing...");
|
||||
|
||||
// 6. Generate CSR and finalize
|
||||
let key_pair = KeyPair::generate().map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
|
||||
})?;
|
||||
let key_pair = KeyPair::generate()
|
||||
.map_err(|e| AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)))?;
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
|
||||
})?;
|
||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()])
|
||||
.map_err(|e| AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)))?;
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let csr = params.serialize_request(&key_pair).map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
|
||||
@@ -219,9 +222,7 @@ impl AcmeClient {
|
||||
.certificate()
|
||||
.await
|
||||
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
|
||||
.ok_or_else(|| {
|
||||
AcmeError::FinalizationFailed("No certificate returned".to_string())
|
||||
})?;
|
||||
.ok_or_else(|| AcmeError::FinalizationFailed("No certificate returned".to_string()))?;
|
||||
|
||||
let private_key_pem = key_pair.serialize_pem();
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
|
||||
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
|
||||
use crate::acme::AcmeClient;
|
||||
use crate::cert_store::{CertBundle, CertMetadata, CertSource, CertStore};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CertManagerError {
|
||||
@@ -45,17 +45,13 @@ impl CertManager {
|
||||
/// Create an ACME client using this manager's configuration.
|
||||
/// Returns None if no ACME email is configured.
|
||||
pub fn acme_client(&self) -> Option<AcmeClient> {
|
||||
self.acme_email.as_ref().map(|email| {
|
||||
AcmeClient::new(email.clone(), self.use_production)
|
||||
})
|
||||
self.acme_email
|
||||
.as_ref()
|
||||
.map(|email| AcmeClient::new(email.clone(), self.use_production))
|
||||
}
|
||||
|
||||
/// Load a static certificate into the store (infallible — pure cache insert).
|
||||
pub fn load_static(
|
||||
&mut self,
|
||||
domain: String,
|
||||
bundle: CertBundle,
|
||||
) {
|
||||
pub fn load_static(&mut self, domain: String, bundle: CertBundle) {
|
||||
self.store.store(domain, bundle);
|
||||
}
|
||||
|
||||
@@ -108,23 +104,25 @@ impl CertManager {
|
||||
F: FnOnce(String, String) -> Fut,
|
||||
Fut: std::future::Future<Output = ()>,
|
||||
{
|
||||
let acme_client = self.acme_client()
|
||||
.ok_or(CertManagerError::NoEmail)?;
|
||||
let acme_client = self.acme_client().ok_or(CertManagerError::NoEmail)?;
|
||||
|
||||
info!("Renewing certificate for {}", domain);
|
||||
|
||||
let domain_owned = domain.to_string();
|
||||
let result = acme_client.provision(&domain_owned, |pending| {
|
||||
let token = pending.token.clone();
|
||||
let key_auth = pending.key_authorization.clone();
|
||||
async move {
|
||||
challenge_setup(token, key_auth).await;
|
||||
Ok(())
|
||||
}
|
||||
}).await.map_err(|e| CertManagerError::AcmeFailure {
|
||||
domain: domain.to_string(),
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
let result = acme_client
|
||||
.provision(&domain_owned, |pending| {
|
||||
let token = pending.token.clone();
|
||||
let key_auth = pending.key_authorization.clone();
|
||||
async move {
|
||||
challenge_setup(token, key_auth).await;
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await
|
||||
.map_err(|e| CertManagerError::AcmeFailure {
|
||||
domain: domain.to_string(),
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
|
||||
let (cert_pem, key_pem) = result;
|
||||
let now = SystemTime::now()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Certificate metadata stored alongside certs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -90,8 +90,10 @@ mod tests {
|
||||
|
||||
fn make_test_bundle(domain: &str) -> CertBundle {
|
||||
CertBundle {
|
||||
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
|
||||
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
|
||||
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n"
|
||||
.to_string(),
|
||||
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n"
|
||||
.to_string(),
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
@@ -122,7 +124,8 @@ mod tests {
|
||||
let mut store = CertStore::new();
|
||||
|
||||
let mut bundle = make_test_bundle("secure.com");
|
||||
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
|
||||
bundle.ca_pem =
|
||||
Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
|
||||
store.store("secure.com".to_string(), bundle);
|
||||
|
||||
let loaded = store.get("secure.com").unwrap();
|
||||
@@ -147,7 +150,10 @@ mod tests {
|
||||
fn test_remove_cert() {
|
||||
let mut store = CertStore::new();
|
||||
|
||||
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com"));
|
||||
store.store(
|
||||
"remove-me.com".to_string(),
|
||||
make_test_bundle("remove-me.com"),
|
||||
);
|
||||
assert!(store.has("remove-me.com"));
|
||||
|
||||
let removed = store.remove("remove-me.com");
|
||||
@@ -165,7 +171,10 @@ mod tests {
|
||||
fn test_wildcard_domain() {
|
||||
let mut store = CertStore::new();
|
||||
|
||||
store.store("*.example.com".to_string(), make_test_bundle("*.example.com"));
|
||||
store.store(
|
||||
"*.example.com".to_string(),
|
||||
make_test_bundle("*.example.com"),
|
||||
);
|
||||
assert!(store.has("*.example.com"));
|
||||
|
||||
let loaded = store.get("*.example.com").unwrap();
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
//! TLS certificate management for RustProxy.
|
||||
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
|
||||
|
||||
pub mod cert_store;
|
||||
pub mod cert_manager;
|
||||
pub mod acme;
|
||||
pub mod cert_manager;
|
||||
pub mod cert_store;
|
||||
pub mod sni_resolver;
|
||||
|
||||
pub use cert_store::*;
|
||||
pub use cert_manager::*;
|
||||
pub use cert_store::*;
|
||||
pub use sni_resolver::*;
|
||||
|
||||
@@ -13,7 +13,7 @@ use hyper::{Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, error};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
/// ACME HTTP-01 challenge server.
|
||||
pub struct ChallengeServer {
|
||||
@@ -47,7 +47,10 @@ impl ChallengeServer {
|
||||
}
|
||||
|
||||
/// Start the challenge server on the given port.
|
||||
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub async fn start(
|
||||
&mut self,
|
||||
port: u16,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let addr = format!("0.0.0.0:{}", port);
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
info!("ACME challenge server listening on port {}", port);
|
||||
@@ -101,10 +104,7 @@ impl ChallengeServer {
|
||||
pub async fn stop(&mut self) {
|
||||
self.cancel.cancel();
|
||||
if let Some(handle) = self.handle.take() {
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
handle,
|
||||
).await;
|
||||
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
|
||||
}
|
||||
self.challenges.clear();
|
||||
self.cancel = CancellationToken::new();
|
||||
@@ -154,10 +154,14 @@ mod tests {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
// Fetch the challenge
|
||||
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
|
||||
let client = tokio::net::TcpStream::connect("127.0.0.1:19900")
|
||||
.await
|
||||
.unwrap();
|
||||
let io = TokioIo::new(client);
|
||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
|
||||
tokio::spawn(async move { let _ = conn.await; });
|
||||
tokio::spawn(async move {
|
||||
let _ = conn.await;
|
||||
});
|
||||
|
||||
let req = Request::get("/.well-known/acme-challenge/test-token")
|
||||
.body(Full::new(Bytes::new()))
|
||||
|
||||
+297
-140
@@ -57,24 +57,27 @@ use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use anyhow::Result;
|
||||
use tracing::{info, warn, debug, error};
|
||||
use arc_swap::ArcSwap;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
// Re-export key types
|
||||
pub use rustproxy_config;
|
||||
pub use rustproxy_routing;
|
||||
pub use rustproxy_passthrough;
|
||||
pub use rustproxy_tls;
|
||||
pub use rustproxy_http;
|
||||
pub use rustproxy_metrics;
|
||||
pub use rustproxy_passthrough;
|
||||
pub use rustproxy_routing;
|
||||
pub use rustproxy_security;
|
||||
pub use rustproxy_tls;
|
||||
|
||||
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec};
|
||||
use rustproxy_config::{CertificateSpec, RouteConfig, RustProxyOptions, TlsMode};
|
||||
use rustproxy_metrics::{Metrics, MetricsCollector, Statistics};
|
||||
use rustproxy_passthrough::{
|
||||
ConnectionConfig, TcpListenerManager, TlsCertConfig, UdpListenerManager,
|
||||
};
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_passthrough::{TcpListenerManager, UdpListenerManager, TlsCertConfig, ConnectionConfig};
|
||||
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
|
||||
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
|
||||
use rustproxy_security::IpBlockList;
|
||||
use rustproxy_tls::{CertBundle, CertManager, CertMetadata, CertSource, CertStore};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
/// Certificate status.
|
||||
@@ -106,6 +109,8 @@ pub struct RustProxy {
|
||||
loaded_certs: HashMap<String, TlsCertConfig>,
|
||||
/// Cancellation token for cooperative shutdown of background tasks.
|
||||
cancel_token: CancellationToken,
|
||||
/// Shared global ingress blocklist, hot-reloadable across TCP/UDP listeners.
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
}
|
||||
|
||||
impl RustProxy {
|
||||
@@ -127,13 +132,19 @@ impl RustProxy {
|
||||
let route_manager = RouteManager::new(options.routes.clone());
|
||||
|
||||
// Set up certificate manager if ACME is configured
|
||||
let cert_manager = Self::build_cert_manager(&options)
|
||||
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
|
||||
let cert_manager =
|
||||
Self::build_cert_manager(&options).map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
|
||||
|
||||
let retention = options.metrics.as_ref()
|
||||
let retention = options
|
||||
.metrics
|
||||
.as_ref()
|
||||
.and_then(|m| m.retention_seconds)
|
||||
.unwrap_or(3600) as usize;
|
||||
|
||||
let security_policy = Arc::new(ArcSwap::from(Arc::new(Self::build_ip_block_list(
|
||||
options.security_policy.as_ref(),
|
||||
))));
|
||||
|
||||
Ok(Self {
|
||||
options,
|
||||
route_table: ArcSwap::from(Arc::new(route_manager)),
|
||||
@@ -149,6 +160,7 @@ impl RustProxy {
|
||||
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
||||
loaded_certs: HashMap::new(),
|
||||
cancel_token: CancellationToken::new(),
|
||||
security_policy,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -163,24 +175,25 @@ impl RustProxy {
|
||||
// Apply default target if route has no targets
|
||||
if route.action.targets.is_none() {
|
||||
if let Some(ref default_target) = defaults.target {
|
||||
debug!("Applying default target {}:{} to route {:?}",
|
||||
default_target.host, default_target.port,
|
||||
route.name.as_deref().unwrap_or("unnamed"));
|
||||
route.action.targets = Some(vec![
|
||||
rustproxy_config::RouteTarget {
|
||||
target_match: None,
|
||||
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
|
||||
port: rustproxy_config::PortSpec::Fixed(default_target.port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
backend_transport: None,
|
||||
priority: None,
|
||||
}
|
||||
]);
|
||||
debug!(
|
||||
"Applying default target {}:{} to route {:?}",
|
||||
default_target.host,
|
||||
default_target.port,
|
||||
route.name.as_deref().unwrap_or("unnamed")
|
||||
);
|
||||
route.action.targets = Some(vec![rustproxy_config::RouteTarget {
|
||||
target_match: None,
|
||||
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
|
||||
port: rustproxy_config::PortSpec::Fixed(default_target.port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
backend_transport: None,
|
||||
priority: None,
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,7 +212,10 @@ impl RustProxy {
|
||||
|
||||
if let Some(ref allow_list) = default_security.ip_allow_list {
|
||||
security.ip_allow_list = Some(
|
||||
allow_list.iter().map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone())).collect()
|
||||
allow_list
|
||||
.iter()
|
||||
.map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone()))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
if let Some(ref block_list) = default_security.ip_block_list {
|
||||
@@ -208,8 +224,10 @@ impl RustProxy {
|
||||
|
||||
// Only apply if there's something meaningful
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
debug!("Applying default security to route {:?}",
|
||||
route.name.as_deref().unwrap_or("unnamed"));
|
||||
debug!(
|
||||
"Applying default security to route {:?}",
|
||||
route.name.as_deref().unwrap_or("unnamed")
|
||||
);
|
||||
route.security = Some(security);
|
||||
}
|
||||
}
|
||||
@@ -224,13 +242,17 @@ impl RustProxy {
|
||||
return None;
|
||||
}
|
||||
|
||||
let email = acme.email.clone()
|
||||
.or_else(|| acme.account_email.clone());
|
||||
let email = acme.email.clone().or_else(|| acme.account_email.clone());
|
||||
let use_production = acme.use_production.unwrap_or(false);
|
||||
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
|
||||
|
||||
let store = CertStore::new();
|
||||
Some(CertManager::new(store, email, use_production, renew_before_days))
|
||||
Some(CertManager::new(
|
||||
store,
|
||||
email,
|
||||
use_production,
|
||||
renew_before_days,
|
||||
))
|
||||
}
|
||||
|
||||
/// Build ConnectionConfig from RustProxyOptions.
|
||||
@@ -248,7 +270,10 @@ impl RustProxy {
|
||||
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
|
||||
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
|
||||
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
|
||||
proxy_ips: options.proxy_ips.as_deref().unwrap_or(&[])
|
||||
proxy_ips: options
|
||||
.proxy_ips
|
||||
.as_deref()
|
||||
.unwrap_or(&[])
|
||||
.iter()
|
||||
.filter_map(|s| s.parse::<std::net::IpAddr>().ok())
|
||||
.collect(),
|
||||
@@ -258,6 +283,22 @@ impl RustProxy {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_ip_block_list(policy: Option<&rustproxy_config::SecurityPolicy>) -> IpBlockList {
|
||||
let Some(policy) = policy else {
|
||||
return IpBlockList::empty();
|
||||
};
|
||||
|
||||
let mut entries = Vec::new();
|
||||
if let Some(blocked_ips) = &policy.blocked_ips {
|
||||
entries.extend(blocked_ips.iter().cloned());
|
||||
}
|
||||
if let Some(blocked_cidrs) = &policy.blocked_cidrs {
|
||||
entries.extend(blocked_cidrs.iter().cloned());
|
||||
}
|
||||
|
||||
IpBlockList::new(&entries)
|
||||
}
|
||||
|
||||
/// Start the proxy, binding to all configured ports.
|
||||
pub async fn start(&mut self) -> Result<()> {
|
||||
if self.started {
|
||||
@@ -272,7 +313,11 @@ impl RustProxy {
|
||||
let route_manager = self.route_table.load();
|
||||
let ports = route_manager.listening_ports();
|
||||
|
||||
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
|
||||
info!(
|
||||
"Configured {} routes on {} ports",
|
||||
route_manager.route_count(),
|
||||
ports.len()
|
||||
);
|
||||
|
||||
// Create TCP listener manager with metrics
|
||||
let mut listener = TcpListenerManager::with_metrics(
|
||||
@@ -282,7 +327,8 @@ impl RustProxy {
|
||||
|
||||
// Apply connection config from options
|
||||
let conn_config = Self::build_connection_config(&self.options);
|
||||
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
|
||||
debug!(
|
||||
"Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
|
||||
conn_config.connection_timeout_ms,
|
||||
conn_config.initial_data_timeout_ms,
|
||||
conn_config.socket_timeout_ms,
|
||||
@@ -291,6 +337,7 @@ impl RustProxy {
|
||||
// Clone proxy_ips before conn_config is moved into the TCP listener
|
||||
let udp_proxy_ips = conn_config.proxy_ips.clone();
|
||||
listener.set_connection_config(conn_config);
|
||||
listener.set_security_policy(Arc::clone(&self.security_policy));
|
||||
|
||||
// Share the socket-handler relay path with the listener
|
||||
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
|
||||
@@ -303,10 +350,13 @@ impl RustProxy {
|
||||
let cm = cm.lock().await;
|
||||
for (domain, bundle) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(domain) {
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
tls_configs.insert(
|
||||
domain.clone(),
|
||||
TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -330,7 +380,9 @@ impl RustProxy {
|
||||
let mut tcp_ports = std::collections::HashSet::new();
|
||||
let mut udp_ports = std::collections::HashSet::new();
|
||||
for route in &self.options.routes {
|
||||
if !route.is_enabled() { continue; }
|
||||
if !route.is_enabled() {
|
||||
continue;
|
||||
}
|
||||
let transport = route.route_match.transport.as_ref();
|
||||
let route_ports = route.route_match.ports.to_ports();
|
||||
for port in route_ports {
|
||||
@@ -371,6 +423,7 @@ impl RustProxy {
|
||||
connection_registry,
|
||||
);
|
||||
udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
|
||||
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
|
||||
|
||||
// Share HttpProxyService with H3 — same route matching, connection
|
||||
// pool, and ALPN protocol detection as the TCP/HTTP path.
|
||||
@@ -379,10 +432,15 @@ impl RustProxy {
|
||||
udp_mgr.set_h3_service(Arc::new(h3_svc));
|
||||
|
||||
for port in &udp_ports {
|
||||
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?;
|
||||
udp_mgr
|
||||
.add_port_with_tls(*port, quic_tls_config.clone())
|
||||
.await?;
|
||||
}
|
||||
info!("UDP listeners started on {} ports: {:?}",
|
||||
udp_ports.len(), udp_mgr.listening_ports());
|
||||
info!(
|
||||
"UDP listeners started on {} ports: {:?}",
|
||||
udp_ports.len(),
|
||||
udp_mgr.listening_ports()
|
||||
);
|
||||
self.udp_listener_manager = Some(udp_mgr);
|
||||
}
|
||||
|
||||
@@ -391,16 +449,22 @@ impl RustProxy {
|
||||
|
||||
// Start the throughput sampling task with cooperative cancellation
|
||||
let metrics = Arc::clone(&self.metrics);
|
||||
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone();
|
||||
let conn_tracker = self
|
||||
.listener_manager
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.conn_tracker()
|
||||
.clone();
|
||||
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
|
||||
let interval_ms = self.options.metrics.as_ref()
|
||||
let interval_ms = self
|
||||
.options
|
||||
.metrics
|
||||
.as_ref()
|
||||
.and_then(|m| m.sample_interval_ms)
|
||||
.unwrap_or(1000);
|
||||
let sampling_cancel = self.cancel_token.clone();
|
||||
self.sampling_handle = Some(tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(
|
||||
std::time::Duration::from_millis(interval_ms)
|
||||
);
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = sampling_cancel.cancelled() => break,
|
||||
@@ -442,7 +506,10 @@ impl RustProxy {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cert_spec = route.action.tls.as_ref()
|
||||
let cert_spec = route
|
||||
.action
|
||||
.tls
|
||||
.as_ref()
|
||||
.and_then(|tls| tls.certificate.as_ref());
|
||||
|
||||
if let Some(CertificateSpec::Auto(_)) = cert_spec {
|
||||
@@ -466,16 +533,25 @@ impl RustProxy {
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
|
||||
info!(
|
||||
"Auto-provisioning certificates for {} domains",
|
||||
domains_to_provision.len()
|
||||
);
|
||||
|
||||
// Start challenge server
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
let acme_port = self
|
||||
.options
|
||||
.acme
|
||||
.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let mut challenge_server = challenge_server::ChallengeServer::new();
|
||||
if let Err(e) = challenge_server.start(acme_port).await {
|
||||
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
|
||||
error!(
|
||||
"Failed to start ACME challenge server on port {}: {}",
|
||||
acme_port, e
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -488,13 +564,15 @@ impl RustProxy {
|
||||
|
||||
if let Some(acme_client) = acme_client {
|
||||
let challenge_server_ref = &challenge_server;
|
||||
let result = acme_client.provision(domain, |pending| {
|
||||
challenge_server_ref.set_challenge(
|
||||
pending.token.clone(),
|
||||
pending.key_authorization.clone(),
|
||||
);
|
||||
async move { Ok(()) }
|
||||
}).await;
|
||||
let result = acme_client
|
||||
.provision(domain, |pending| {
|
||||
challenge_server_ref.set_challenge(
|
||||
pending.token.clone(),
|
||||
pending.key_authorization.clone(),
|
||||
);
|
||||
async move { Ok(()) }
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok((cert_pem, key_pem)) => {
|
||||
@@ -539,7 +617,10 @@ impl RustProxy {
|
||||
None => return,
|
||||
};
|
||||
|
||||
let auto_renew = self.options.acme.as_ref()
|
||||
let auto_renew = self
|
||||
.options
|
||||
.acme
|
||||
.as_ref()
|
||||
.and_then(|a| a.auto_renew)
|
||||
.unwrap_or(true);
|
||||
|
||||
@@ -547,11 +628,17 @@ impl RustProxy {
|
||||
return;
|
||||
}
|
||||
|
||||
let check_interval_hours = self.options.acme.as_ref()
|
||||
let check_interval_hours = self
|
||||
.options
|
||||
.acme
|
||||
.as_ref()
|
||||
.and_then(|a| a.renew_check_interval_hours)
|
||||
.unwrap_or(24);
|
||||
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
let acme_port = self
|
||||
.options
|
||||
.acme
|
||||
.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
@@ -664,17 +751,19 @@ impl RustProxy {
|
||||
/// Update routes atomically (hot-reload).
|
||||
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
|
||||
// Validate new routes
|
||||
rustproxy_config::validate_routes(&routes)
|
||||
.map_err(|errors| {
|
||||
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
|
||||
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
|
||||
})?;
|
||||
rustproxy_config::validate_routes(&routes).map_err(|errors| {
|
||||
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
|
||||
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
|
||||
})?;
|
||||
|
||||
let new_manager = RouteManager::new(routes.clone());
|
||||
let new_ports = new_manager.listening_ports();
|
||||
|
||||
info!("Updating routes: {} routes on {} ports",
|
||||
new_manager.route_count(), new_ports.len());
|
||||
info!(
|
||||
"Updating routes: {} routes on {} ports",
|
||||
new_manager.route_count(),
|
||||
new_ports.len()
|
||||
);
|
||||
|
||||
// Get old ports
|
||||
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
|
||||
@@ -684,28 +773,35 @@ impl RustProxy {
|
||||
};
|
||||
|
||||
// Prune per-route metrics for route IDs that no longer exist
|
||||
let active_route_ids: HashSet<String> = routes.iter()
|
||||
.filter_map(|r| r.id.clone())
|
||||
.collect();
|
||||
let active_route_ids: HashSet<String> =
|
||||
routes.iter().filter_map(|r| r.id.clone()).collect();
|
||||
self.metrics.retain_routes(&active_route_ids);
|
||||
|
||||
// Prune per-backend metrics for backends no longer in any route target.
|
||||
// For PortSpec::Preserve routes, expand across all listening ports since
|
||||
// the actual runtime port depends on the incoming connection.
|
||||
let listening_ports = self.get_listening_ports();
|
||||
let active_backends: HashSet<String> = routes.iter()
|
||||
let active_backends: HashSet<String> = routes
|
||||
.iter()
|
||||
.filter_map(|r| r.action.targets.as_ref())
|
||||
.flat_map(|targets| targets.iter())
|
||||
.flat_map(|target| {
|
||||
let hosts: Vec<String> = target.host.to_vec().into_iter().map(|s| s.to_string()).collect();
|
||||
let hosts: Vec<String> = target
|
||||
.host
|
||||
.to_vec()
|
||||
.into_iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
match &target.port {
|
||||
rustproxy_config::PortSpec::Fixed(p) => {
|
||||
hosts.into_iter().map(|h| format!("{}:{}", h, p)).collect::<Vec<_>>()
|
||||
}
|
||||
rustproxy_config::PortSpec::Fixed(p) => hosts
|
||||
.into_iter()
|
||||
.map(|h| format!("{}:{}", h, p))
|
||||
.collect::<Vec<_>>(),
|
||||
_ => {
|
||||
// Preserve/special: expand across all listening ports
|
||||
let lp = &listening_ports;
|
||||
hosts.into_iter()
|
||||
hosts
|
||||
.into_iter()
|
||||
.flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p)))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
@@ -733,10 +829,13 @@ impl RustProxy {
|
||||
let cm = cm_arc.lock().await;
|
||||
for (domain, bundle) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(domain) {
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
tls_configs.insert(
|
||||
domain.clone(),
|
||||
TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -753,7 +852,9 @@ impl RustProxy {
|
||||
// Cancel connections on routes that were removed or disabled
|
||||
listener.invalidate_removed_routes(&active_route_ids);
|
||||
// Clean up registry entries for removed routes
|
||||
listener.connection_registry().cleanup_removed_routes(&active_route_ids);
|
||||
listener
|
||||
.connection_registry()
|
||||
.cleanup_removed_routes(&active_route_ids);
|
||||
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
|
||||
listener.prune_http_proxy_caches(&active_route_ids);
|
||||
|
||||
@@ -766,9 +867,10 @@ impl RustProxy {
|
||||
None => continue,
|
||||
};
|
||||
// Find corresponding old route
|
||||
let old_route = old_manager.routes().iter().find(|r| {
|
||||
r.id.as_deref() == Some(new_id)
|
||||
});
|
||||
let old_route = old_manager
|
||||
.routes()
|
||||
.iter()
|
||||
.find(|r| r.id.as_deref() == Some(new_id));
|
||||
let old_route = match old_route {
|
||||
Some(r) => r,
|
||||
None => continue, // new route, no existing connections to recycle
|
||||
@@ -812,11 +914,13 @@ impl RustProxy {
|
||||
{
|
||||
let mut new_udp_ports = HashSet::new();
|
||||
for route in &routes {
|
||||
if !route.is_enabled() { continue; }
|
||||
if !route.is_enabled() {
|
||||
continue;
|
||||
}
|
||||
let transport = route.route_match.transport.as_ref();
|
||||
match transport {
|
||||
Some(rustproxy_config::TransportProtocol::Udp) |
|
||||
Some(rustproxy_config::TransportProtocol::All) => {
|
||||
Some(rustproxy_config::TransportProtocol::Udp)
|
||||
| Some(rustproxy_config::TransportProtocol::All) => {
|
||||
for port in route.route_match.ports.to_ports() {
|
||||
new_udp_ports.insert(port);
|
||||
}
|
||||
@@ -825,7 +929,8 @@ impl RustProxy {
|
||||
}
|
||||
}
|
||||
|
||||
let old_udp_ports: HashSet<u16> = self.udp_listener_manager
|
||||
let old_udp_ports: HashSet<u16> = self
|
||||
.udp_listener_manager
|
||||
.as_ref()
|
||||
.map(|u| u.listening_ports().into_iter().collect())
|
||||
.unwrap_or_default();
|
||||
@@ -847,6 +952,7 @@ impl RustProxy {
|
||||
connection_registry,
|
||||
);
|
||||
udp_mgr.set_proxy_ips(conn_config.proxy_ips);
|
||||
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
|
||||
// Wire up H3ProxyService so QUIC connections can serve HTTP/3
|
||||
let http_proxy = listener.http_proxy().clone();
|
||||
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
|
||||
@@ -898,56 +1004,77 @@ impl RustProxy {
|
||||
|
||||
/// Provision a certificate for a named route.
|
||||
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
|
||||
let cm_arc = self.cert_manager.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
|
||||
let cm_arc = self.cert_manager.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("No certificate manager configured (ACME not enabled)")
|
||||
})?;
|
||||
|
||||
// Find the route by name
|
||||
let route = self.options.routes.iter()
|
||||
let route = self
|
||||
.options
|
||||
.routes
|
||||
.iter()
|
||||
.find(|r| r.name.as_deref() == Some(route_name))
|
||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
|
||||
|
||||
let domain = route.route_match.domains.as_ref()
|
||||
let domain = route
|
||||
.route_match
|
||||
.domains
|
||||
.as_ref()
|
||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
|
||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
|
||||
|
||||
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
|
||||
info!(
|
||||
"Provisioning certificate for route '{}' (domain: {})",
|
||||
route_name, domain
|
||||
);
|
||||
|
||||
// Start challenge server
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
let acme_port = self
|
||||
.options
|
||||
.acme
|
||||
.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let mut cs = challenge_server::ChallengeServer::new();
|
||||
cs.start(acme_port).await
|
||||
cs.start(acme_port)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
|
||||
|
||||
let cs_ref = &cs;
|
||||
let mut cm = cm_arc.lock().await;
|
||||
let result = cm.renew_domain(&domain, |token, key_auth| {
|
||||
cs_ref.set_challenge(token, key_auth);
|
||||
async {}
|
||||
}).await;
|
||||
let result = cm
|
||||
.renew_domain(&domain, |token, key_auth| {
|
||||
cs_ref.set_challenge(token, key_auth);
|
||||
async {}
|
||||
})
|
||||
.await;
|
||||
drop(cm);
|
||||
|
||||
cs.stop().await;
|
||||
|
||||
let bundle = result
|
||||
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
||||
let bundle = result.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
||||
|
||||
// Hot-swap into TLS configs
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
tls_configs.insert(
|
||||
domain.clone(),
|
||||
TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
},
|
||||
);
|
||||
{
|
||||
let cm = cm_arc.lock().await;
|
||||
for (d, b) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(d) {
|
||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
});
|
||||
tls_configs.insert(
|
||||
d.clone(),
|
||||
TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -966,7 +1093,10 @@ impl RustProxy {
|
||||
}
|
||||
}
|
||||
|
||||
info!("Certificate provisioned and loaded for route '{}'", route_name);
|
||||
info!(
|
||||
"Certificate provisioned and loaded for route '{}'",
|
||||
route_name
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -978,10 +1108,16 @@ impl RustProxy {
|
||||
|
||||
/// Get the status of a certificate for a named route.
|
||||
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
|
||||
let route = self.options.routes.iter()
|
||||
let route = self
|
||||
.options
|
||||
.routes
|
||||
.iter()
|
||||
.find(|r| r.name.as_deref() == Some(route_name))?;
|
||||
|
||||
let domain = route.route_match.domains.as_ref()
|
||||
let domain = route
|
||||
.route_match
|
||||
.domains
|
||||
.as_ref()
|
||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
|
||||
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
@@ -1010,8 +1146,9 @@ impl RustProxy {
|
||||
let mut metrics = self.metrics.snapshot();
|
||||
if let Some(ref lm) = self.listener_manager {
|
||||
let entries = lm.http_proxy().protocol_cache_snapshot();
|
||||
metrics.detected_protocols = entries.into_iter().map(|e| {
|
||||
rustproxy_metrics::ProtocolCacheEntryMetric {
|
||||
metrics.detected_protocols = entries
|
||||
.into_iter()
|
||||
.map(|e| rustproxy_metrics::ProtocolCacheEntryMetric {
|
||||
host: e.host,
|
||||
port: e.port,
|
||||
domain: e.domain,
|
||||
@@ -1026,8 +1163,8 @@ impl RustProxy {
|
||||
h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs,
|
||||
h2_consecutive_failures: e.h2_consecutive_failures,
|
||||
h3_consecutive_failures: e.h3_consecutive_failures,
|
||||
}
|
||||
}).collect();
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
metrics
|
||||
}
|
||||
@@ -1058,9 +1195,7 @@ impl RustProxy {
|
||||
|
||||
/// Get statistics snapshot.
|
||||
pub fn get_statistics(&self) -> Statistics {
|
||||
let uptime = self.started_at
|
||||
.map(|t| t.elapsed().as_secs())
|
||||
.unwrap_or(0);
|
||||
let uptime = self.started_at.map(|t| t.elapsed().as_secs()).unwrap_or(0);
|
||||
|
||||
Statistics {
|
||||
active_connections: self.metrics.active_connections(),
|
||||
@@ -1071,6 +1206,13 @@ impl RustProxy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the global ingress security policy.
|
||||
pub fn set_security_policy(&mut self, policy: rustproxy_config::SecurityPolicy) {
|
||||
self.security_policy
|
||||
.store(Arc::new(Self::build_ip_block_list(Some(&policy))));
|
||||
self.options.security_policy = Some(policy);
|
||||
}
|
||||
|
||||
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
|
||||
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
|
||||
/// take effect immediately for all new connections.
|
||||
@@ -1130,10 +1272,13 @@ impl RustProxy {
|
||||
let cm = cm_arc.lock().await;
|
||||
for (d, b) in cm.store().iter() {
|
||||
if !configs.contains_key(d) {
|
||||
configs.insert(d.clone(), TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
});
|
||||
configs.insert(
|
||||
d.clone(),
|
||||
TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1166,7 +1311,8 @@ impl RustProxy {
|
||||
info!("Loading certificate for domain: {}", domain);
|
||||
|
||||
// Check if the cert actually changed (for selective connection recycling)
|
||||
let cert_changed = self.loaded_certs
|
||||
let cert_changed = self
|
||||
.loaded_certs
|
||||
.get(domain)
|
||||
.map(|existing| existing.cert_pem != cert_pem)
|
||||
.unwrap_or(false); // new domain = no existing connections to recycle
|
||||
@@ -1196,10 +1342,13 @@ impl RustProxy {
|
||||
}
|
||||
|
||||
// Persist in loaded_certs so future rebuild calls include this cert
|
||||
self.loaded_certs.insert(domain.to_string(), TlsCertConfig {
|
||||
cert_pem: cert_pem.clone(),
|
||||
key_pem: key_pem.clone(),
|
||||
});
|
||||
self.loaded_certs.insert(
|
||||
domain.to_string(),
|
||||
TlsCertConfig {
|
||||
cert_pem: cert_pem.clone(),
|
||||
key_pem: key_pem.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
// Hot-swap TLS config on TCP and QUIC listeners
|
||||
let tls_configs = self.current_tls_configs().await;
|
||||
@@ -1222,7 +1371,9 @@ impl RustProxy {
|
||||
// Recycle existing connections if cert actually changed
|
||||
if cert_changed {
|
||||
if let Some(ref listener) = self.listener_manager {
|
||||
listener.connection_registry().recycle_for_cert_change(domain);
|
||||
listener
|
||||
.connection_registry()
|
||||
.recycle_for_cert_change(domain);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1244,16 +1395,22 @@ impl RustProxy {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cert_spec = route.action.tls.as_ref()
|
||||
let cert_spec = route
|
||||
.action
|
||||
.tls
|
||||
.as_ref()
|
||||
.and_then(|tls| tls.certificate.as_ref());
|
||||
|
||||
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
for domain in domains.to_vec() {
|
||||
configs.insert(domain.to_string(), TlsCertConfig {
|
||||
cert_pem: cert_config.cert.clone(),
|
||||
key_pem: cert_config.key.clone(),
|
||||
});
|
||||
configs.insert(
|
||||
domain.to_string(),
|
||||
TlsCertConfig {
|
||||
cert_pem: cert_config.cert.clone(),
|
||||
key_pem: cert_config.key.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
#[global_allocator]
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use anyhow::Result;
|
||||
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy::management;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
|
||||
/// RustProxy - High-performance multi-protocol proxy
|
||||
@@ -43,8 +43,7 @@ async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
|
||||
)
|
||||
.init();
|
||||
|
||||
@@ -60,11 +59,7 @@ async fn main() -> Result<()> {
|
||||
let options = RustProxyOptions::from_file(&cli.config)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
|
||||
|
||||
tracing::info!(
|
||||
"Loaded {} routes from {}",
|
||||
options.routes.len(),
|
||||
cli.config
|
||||
);
|
||||
tracing::info!("Loaded {} routes from {}", options.routes.len(), cli.config);
|
||||
|
||||
// Validate-only mode
|
||||
if cli.validate {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tracing::{info, error};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
@@ -141,14 +141,19 @@ async fn handle_request(
|
||||
"start" => handle_start(&id, &request.params, proxy).await,
|
||||
"stop" => handle_stop(&id, proxy).await,
|
||||
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
|
||||
"setSecurityPolicy" => handle_set_security_policy(&id, &request.params, proxy),
|
||||
"getMetrics" => handle_get_metrics(&id, proxy),
|
||||
"getStatistics" => handle_get_statistics(&id, proxy),
|
||||
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
|
||||
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
|
||||
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
|
||||
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
|
||||
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
|
||||
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await,
|
||||
"setSocketHandlerRelay" => {
|
||||
handle_set_socket_handler_relay(&id, &request.params, proxy).await
|
||||
}
|
||||
"setDatagramHandlerRelay" => {
|
||||
handle_set_datagram_handler_relay(&id, &request.params, proxy).await
|
||||
}
|
||||
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
|
||||
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
|
||||
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
|
||||
@@ -167,7 +172,12 @@ async fn handle_start(
|
||||
|
||||
let config = match params.get("config") {
|
||||
Some(config) => config,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'config' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
|
||||
@@ -176,38 +186,31 @@ async fn handle_start(
|
||||
};
|
||||
|
||||
match RustProxy::new(options) {
|
||||
Ok(mut p) => {
|
||||
match p.start().await {
|
||||
Ok(()) => {
|
||||
send_event("started", serde_json::json!({}));
|
||||
*proxy = Some(p);
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
Err(e) => {
|
||||
send_event("error", serde_json::json!({"message": format!("{}", e)}));
|
||||
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
|
||||
}
|
||||
Ok(mut p) => match p.start().await {
|
||||
Ok(()) => {
|
||||
send_event("started", serde_json::json!({}));
|
||||
*proxy = Some(p);
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
send_event("error", serde_json::json!({"message": format!("{}", e)}));
|
||||
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
|
||||
}
|
||||
},
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stop(
|
||||
id: &str,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
async fn handle_stop(id: &str, proxy: &mut Option<RustProxy>) -> ManagementResponse {
|
||||
match proxy.as_mut() {
|
||||
Some(p) => {
|
||||
match p.stop().await {
|
||||
Ok(()) => {
|
||||
*proxy = None;
|
||||
send_event("stopped", serde_json::json!({}));
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
|
||||
Some(p) => match p.stop().await {
|
||||
Ok(()) => {
|
||||
*proxy = None;
|
||||
send_event("stopped", serde_json::json!({}));
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
|
||||
},
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
}
|
||||
}
|
||||
@@ -224,7 +227,12 @@ async fn handle_update_routes(
|
||||
|
||||
let routes = match params.get("routes") {
|
||||
Some(routes) => routes,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'routes' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
|
||||
@@ -234,36 +242,72 @@ async fn handle_update_routes(
|
||||
|
||||
match p.update_routes(routes).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
|
||||
Err(e) => {
|
||||
ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_metrics(
|
||||
fn handle_set_security_policy(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let policy = match params.get("policy") {
|
||||
Some(policy) => policy,
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'policy' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let policy: rustproxy_config::SecurityPolicy = match serde_json::from_value(policy.clone()) {
|
||||
Ok(policy) => policy,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Invalid security policy: {}", e),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
p.set_security_policy(policy);
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
|
||||
fn handle_get_metrics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let metrics = p.get_metrics();
|
||||
match serde_json::to_value(&metrics) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to serialize metrics: {}", e),
|
||||
),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_statistics(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
fn handle_get_statistics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let stats = p.get_statistics();
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to serialize statistics: {}", e),
|
||||
),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
@@ -282,12 +326,20 @@ async fn handle_provision_certificate(
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'routeName' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match p.provision_certificate(&route_name).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to provision certificate: {}", e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,12 +355,20 @@ async fn handle_renew_certificate(
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'routeName' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match p.renew_certificate(&route_name).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to renew certificate: {}", e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,24 +384,29 @@ async fn handle_get_certificate_status(
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'routeName' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match p.get_certificate_status(route_name).await {
|
||||
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
|
||||
"domain": status.domain,
|
||||
"source": status.source,
|
||||
"expiresAt": status.expires_at,
|
||||
"isValid": status.is_valid,
|
||||
})),
|
||||
Some(status) => ManagementResponse::ok(
|
||||
id.to_string(),
|
||||
serde_json::json!({
|
||||
"domain": status.domain,
|
||||
"source": status.source,
|
||||
"expiresAt": status.expires_at,
|
||||
"isValid": status.is_valid,
|
||||
}),
|
||||
),
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_listening_ports(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
fn handle_get_listening_ports(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let ports = p.get_listening_ports();
|
||||
@@ -361,7 +426,8 @@ async fn handle_set_socket_handler_relay(
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let socket_path = params.get("socketPath")
|
||||
let socket_path = params
|
||||
.get("socketPath")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
@@ -381,7 +447,8 @@ async fn handle_set_datagram_handler_relay(
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let socket_path = params.get("socketPath")
|
||||
let socket_path = params
|
||||
.get("socketPath")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
@@ -403,12 +470,17 @@ async fn handle_add_listening_port(
|
||||
|
||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||
Some(port) => port as u16,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
match p.add_listening_port(port).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to add port {}: {}", port, e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -424,12 +496,17 @@ async fn handle_remove_listening_port(
|
||||
|
||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||
Some(port) => port as u16,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
match p.remove_listening_port(port).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to remove port {}: {}", port, e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -445,26 +522,41 @@ async fn handle_load_certificate(
|
||||
|
||||
let domain = match params.get("domain").and_then(|v| v.as_str()) {
|
||||
Some(d) => d.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'domain' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let cert = match params.get("cert").and_then(|v| v.as_str()) {
|
||||
Some(c) => c.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
let key = match params.get("key").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
|
||||
let ca = params
|
||||
.get("ca")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
info!("loadCertificate: domain={}", domain);
|
||||
|
||||
// Load cert into cert manager and hot-swap TLS config
|
||||
match p.load_certificate(&domain, cert, key, ca).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to load certificate for {}: {}", domain, e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,8 @@ pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandl
|
||||
let path = parts.get(1).copied().unwrap_or("/");
|
||||
|
||||
// Extract Host header
|
||||
let host = req_str.lines()
|
||||
let host = req_str
|
||||
.lines()
|
||||
.find(|l| l.to_lowercase().starts_with("host:"))
|
||||
.map(|l| l[5..].trim())
|
||||
.unwrap_or("unknown");
|
||||
@@ -336,7 +337,8 @@ pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
|
||||
let req_str = String::from_utf8_lossy(&buf[..n]);
|
||||
|
||||
// Extract Sec-WebSocket-Key for proper handshake
|
||||
let ws_key = req_str.lines()
|
||||
let ws_key = req_str
|
||||
.lines()
|
||||
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
|
||||
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
||||
.unwrap_or_default();
|
||||
@@ -378,7 +380,9 @@ pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
|
||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let key_pair = KeyPair::generate().unwrap();
|
||||
let cert = params.self_signed(&key_pair).unwrap();
|
||||
@@ -458,11 +462,7 @@ pub fn make_tls_terminate_route(
|
||||
|
||||
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
|
||||
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
|
||||
pub async fn start_tls_ws_echo_backend(
|
||||
port: u16,
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> JoinHandle<()> {
|
||||
pub async fn start_tls_ws_echo_backend(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
|
||||
use std::sync::Arc;
|
||||
|
||||
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
mod common;
|
||||
|
||||
use bytes::Buf;
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic};
|
||||
use bytes::Buf;
|
||||
use rustproxy_config::{RouteQuic, RouteUdp, RustProxyOptions, TransportProtocol};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
|
||||
@@ -14,7 +14,14 @@ fn make_h3_route(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> rustproxy_config::RouteConfig {
|
||||
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem);
|
||||
let mut route = make_tls_terminate_route(
|
||||
port,
|
||||
"localhost",
|
||||
target_host,
|
||||
target_port,
|
||||
cert_pem,
|
||||
key_pem,
|
||||
);
|
||||
route.route_match.transport = Some(TransportProtocol::All);
|
||||
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
|
||||
route.action.udp = Some(RouteUdp {
|
||||
@@ -89,11 +96,9 @@ async fn test_h3_response_stream_finishes() {
|
||||
.await
|
||||
.expect("QUIC handshake failed");
|
||||
|
||||
let (mut driver, mut send_request) = h3::client::new(
|
||||
h3_quinn::Connection::new(connection),
|
||||
)
|
||||
.await
|
||||
.expect("H3 connection setup failed");
|
||||
let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(connection))
|
||||
.await
|
||||
.expect("H3 connection setup failed");
|
||||
|
||||
// Drive the H3 connection in background
|
||||
tokio::spawn(async move {
|
||||
@@ -108,33 +113,46 @@ async fn test_h3_response_stream_finishes() {
|
||||
.body(())
|
||||
.unwrap();
|
||||
|
||||
let mut stream = send_request.send_request(req).await
|
||||
let mut stream = send_request
|
||||
.send_request(req)
|
||||
.await
|
||||
.expect("Failed to send H3 request");
|
||||
stream.finish().await
|
||||
stream
|
||||
.finish()
|
||||
.await
|
||||
.expect("Failed to finish sending H3 request body");
|
||||
|
||||
// 6. Read response headers
|
||||
let resp = stream.recv_response().await
|
||||
let resp = stream
|
||||
.recv_response()
|
||||
.await
|
||||
.expect("Failed to receive H3 response");
|
||||
assert_eq!(resp.status(), http::StatusCode::OK,
|
||||
"Expected 200 OK, got {}", resp.status());
|
||||
assert_eq!(
|
||||
resp.status(),
|
||||
http::StatusCode::OK,
|
||||
"Expected 200 OK, got {}",
|
||||
resp.status()
|
||||
);
|
||||
|
||||
// 7. Read body and verify stream ends (FIN received)
|
||||
// This is the critical assertion: recv_data() must return None (stream ended)
|
||||
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
|
||||
let result = with_timeout(async {
|
||||
let mut total = 0usize;
|
||||
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
|
||||
total += chunk.remaining();
|
||||
}
|
||||
// recv_data() returned None => stream ended (FIN received)
|
||||
total
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut total = 0usize;
|
||||
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
|
||||
total += chunk.remaining();
|
||||
}
|
||||
// recv_data() returned None => stream ended (FIN received)
|
||||
total
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await;
|
||||
|
||||
let bytes_received = result.expect(
|
||||
"TIMEOUT: H3 stream never ended (FIN not received by client). \
|
||||
The proxy sent all response data but failed to send the QUIC stream FIN."
|
||||
The proxy sent all response data but failed to send the QUIC stream FIN.",
|
||||
);
|
||||
assert_eq!(
|
||||
bytes_received,
|
||||
|
||||
@@ -43,17 +43,32 @@ async fn test_http_forward_basic() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
|
||||
let body = extract_body(&response);
|
||||
body.to_string()
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
|
||||
let body = extract_body(&response);
|
||||
body.to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
|
||||
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
|
||||
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
|
||||
assert!(
|
||||
result.contains(r#""method":"GET"#),
|
||||
"Expected GET method, got: {}",
|
||||
result
|
||||
);
|
||||
assert!(
|
||||
result.contains(r#""path":"/hello"#),
|
||||
"Expected /hello path, got: {}",
|
||||
result
|
||||
);
|
||||
assert!(
|
||||
result.contains(r#""backend":"main"#),
|
||||
"Expected main backend, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -69,8 +84,18 @@ async fn test_http_forward_host_routing() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
|
||||
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
|
||||
make_test_route(
|
||||
proxy_port,
|
||||
Some("alpha.example.com"),
|
||||
"127.0.0.1",
|
||||
backend1_port,
|
||||
),
|
||||
make_test_route(
|
||||
proxy_port,
|
||||
Some("beta.example.com"),
|
||||
"127.0.0.1",
|
||||
backend2_port,
|
||||
),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -80,24 +105,38 @@ async fn test_http_forward_host_routing() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain
|
||||
let alpha_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let alpha_result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
|
||||
assert!(
|
||||
alpha_result.contains(r#""backend":"alpha"#),
|
||||
"Expected alpha backend, got: {}",
|
||||
alpha_result
|
||||
);
|
||||
|
||||
// Test beta domain
|
||||
let beta_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let beta_result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
|
||||
assert!(
|
||||
beta_result.contains(r#""backend":"beta"#),
|
||||
"Expected beta backend, got: {}",
|
||||
beta_result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -127,24 +166,38 @@ async fn test_http_forward_path_routing() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test API path
|
||||
let api_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let api_result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
|
||||
assert!(
|
||||
api_result.contains(r#""backend":"api"#),
|
||||
"Expected api backend, got: {}",
|
||||
api_result
|
||||
);
|
||||
|
||||
// Test web path (no /api prefix)
|
||||
let web_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let web_result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
|
||||
assert!(
|
||||
web_result.contains(r#""backend":"web"#),
|
||||
"Expected web backend, got: {}",
|
||||
web_result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -184,9 +237,18 @@ async fn test_http_forward_cors_preflight() {
|
||||
.unwrap();
|
||||
|
||||
// Should get 204 No Content with CORS headers
|
||||
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
|
||||
assert!(result.to_lowercase().contains("access-control-allow-origin"),
|
||||
"Expected CORS header, got: {}", result);
|
||||
assert!(
|
||||
result.contains("204"),
|
||||
"Expected 204 status, got: {}",
|
||||
result
|
||||
);
|
||||
assert!(
|
||||
result
|
||||
.to_lowercase()
|
||||
.contains("access-control-allow-origin"),
|
||||
"Expected CORS header, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -208,15 +270,22 @@ async fn test_http_forward_backend_error() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
|
||||
response
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
|
||||
response
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Proxy should relay the 500 from backend
|
||||
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
|
||||
assert!(
|
||||
result.contains("500"),
|
||||
"Expected 500 status, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -227,7 +296,12 @@ async fn test_http_forward_no_route_matched() {
|
||||
|
||||
// Create a route only for a specific domain
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
|
||||
routes: vec![make_test_route(
|
||||
proxy_port,
|
||||
Some("known.example.com"),
|
||||
"127.0.0.1",
|
||||
9999,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -235,15 +309,22 @@ async fn test_http_forward_no_route_matched() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
|
||||
response
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
|
||||
response
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 502 Bad Gateway (no route matched)
|
||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
||||
assert!(
|
||||
result.contains("502"),
|
||||
"Expected 502 status, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -262,15 +343,22 @@ async fn test_http_forward_backend_unavailable() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
|
||||
response
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
|
||||
response
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 502 Bad Gateway (backend unavailable)
|
||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
||||
assert!(
|
||||
result.contains("502"),
|
||||
"Expected 502 status, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -286,7 +374,12 @@ async fn test_https_terminate_http_forward() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&cert_pem,
|
||||
&key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -295,38 +388,53 @@ async fn test_https_terminate_http_forward() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send HTTP request through TLS
|
||||
let request = format!(
|
||||
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
|
||||
domain
|
||||
);
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
// Send HTTP request through TLS
|
||||
let request = format!(
|
||||
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
|
||||
domain
|
||||
);
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let body = extract_body(&result);
|
||||
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
|
||||
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
|
||||
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
|
||||
assert!(
|
||||
body.contains(r#""method":"GET"#),
|
||||
"Expected GET, got: {}",
|
||||
body
|
||||
);
|
||||
assert!(
|
||||
body.contains(r#""path":"/api/data"#),
|
||||
"Expected /api/data, got: {}",
|
||||
body
|
||||
);
|
||||
assert!(
|
||||
body.contains(r#""backend":"tls-backend"#),
|
||||
"Expected tls-backend, got: {}",
|
||||
body
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -347,59 +455,68 @@ async fn test_websocket_through_proxy() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send WebSocket upgrade request
|
||||
let request = format!(
|
||||
"GET /ws HTTP/1.1\r\n\
|
||||
// Send WebSocket upgrade request
|
||||
let request = format!(
|
||||
"GET /ws HTTP/1.1\r\n\
|
||||
Host: example.com\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
\r\n"
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
// Read the 101 response
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
let n = stream.read(&mut temp).await.unwrap();
|
||||
if n == 0 { break; }
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len-4..] == *b"\r\n\r\n" {
|
||||
// Read the 101 response
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
let n = stream.read(&mut temp).await.unwrap();
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len - 4..] == *b"\r\n\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response_buf).to_string();
|
||||
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
|
||||
assert!(
|
||||
response_str.to_lowercase().contains("upgrade: websocket"),
|
||||
"Expected Upgrade header, got: {}",
|
||||
response_str
|
||||
);
|
||||
let response_str = String::from_utf8_lossy(&response_buf).to_string();
|
||||
assert!(
|
||||
response_str.contains("101"),
|
||||
"Expected 101 Switching Protocols, got: {}",
|
||||
response_str
|
||||
);
|
||||
assert!(
|
||||
response_str.to_lowercase().contains("upgrade: websocket"),
|
||||
"Expected Upgrade header, got: {}",
|
||||
response_str
|
||||
);
|
||||
|
||||
// After upgrade, send data and verify echo
|
||||
let test_data = b"Hello WebSocket!";
|
||||
stream.write_all(test_data).await.unwrap();
|
||||
// After upgrade, send data and verify echo
|
||||
let test_data = b"Hello WebSocket!";
|
||||
stream.write_all(test_data).await.unwrap();
|
||||
|
||||
// Read echoed data
|
||||
let mut echo_buf = vec![0u8; 256];
|
||||
let n = stream.read(&mut echo_buf).await.unwrap();
|
||||
let echoed = &echo_buf[..n];
|
||||
// Read echoed data
|
||||
let mut echo_buf = vec![0u8; 256];
|
||||
let n = stream.read(&mut echo_buf).await.unwrap();
|
||||
let echoed = &echo_buf[..n];
|
||||
|
||||
assert_eq!(echoed, test_data, "Expected echo of sent data");
|
||||
assert_eq!(echoed, test_data, "Expected echo of sent data");
|
||||
|
||||
"ok".to_string()
|
||||
}, 10)
|
||||
"ok".to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -431,12 +548,22 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
||||
|
||||
// Create terminate-and-reencrypt routes
|
||||
let mut route1 = make_tls_terminate_route(
|
||||
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1,
|
||||
proxy_port,
|
||||
"alpha.example.com",
|
||||
"127.0.0.1",
|
||||
backend1_port,
|
||||
&cert1,
|
||||
&key1,
|
||||
);
|
||||
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||
|
||||
let mut route2 = make_tls_terminate_route(
|
||||
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2,
|
||||
proxy_port,
|
||||
"beta.example.com",
|
||||
"127.0.0.1",
|
||||
backend2_port,
|
||||
&cert2,
|
||||
&key2,
|
||||
);
|
||||
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||
|
||||
@@ -450,27 +577,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt
|
||||
let alpha_result = with_timeout(async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
let alpha_result = with_timeout(
|
||||
async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name =
|
||||
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
let request =
|
||||
"GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -498,27 +630,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
||||
);
|
||||
|
||||
// Test beta domain - different host goes to different backend
|
||||
let beta_result = with_timeout(async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
let beta_result = with_timeout(
|
||||
async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name =
|
||||
rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
let request =
|
||||
"GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -589,14 +726,12 @@ async fn test_terminate_and_reencrypt_websocket() {
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector =
|
||||
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name =
|
||||
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send WebSocket upgrade request through TLS
|
||||
@@ -685,10 +820,13 @@ async fn test_protocol_field_in_route_config() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// HTTP request should match the route and get proxied
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -20,13 +20,19 @@ async fn test_start_and_stop() {
|
||||
assert!(!wait_for_port(port, 200).await);
|
||||
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
|
||||
assert!(
|
||||
wait_for_port(port, 2000).await,
|
||||
"Port should be listening after start"
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
|
||||
// Give the OS a moment to release the port
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
|
||||
assert!(
|
||||
!wait_for_port(port, 200).await,
|
||||
"Port should not be listening after stop"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -54,7 +60,12 @@ async fn test_update_routes_hot_reload() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
|
||||
routes: vec![make_test_route(
|
||||
port,
|
||||
Some("old.example.com"),
|
||||
"127.0.0.1",
|
||||
8080,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -62,9 +73,12 @@ async fn test_update_routes_hot_reload() {
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Update routes atomically
|
||||
let new_routes = vec![
|
||||
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
|
||||
];
|
||||
let new_routes = vec![make_test_route(
|
||||
port,
|
||||
Some("new.example.com"),
|
||||
"127.0.0.1",
|
||||
9090,
|
||||
)];
|
||||
let result = proxy.update_routes(new_routes).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -87,15 +101,24 @@ async fn test_add_remove_listening_port() {
|
||||
|
||||
// Add a new port
|
||||
proxy.add_listening_port(port2).await.unwrap();
|
||||
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
|
||||
assert!(
|
||||
wait_for_port(port2, 2000).await,
|
||||
"New port should be listening"
|
||||
);
|
||||
|
||||
// Remove the port
|
||||
proxy.remove_listening_port(port2).await.unwrap();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
|
||||
assert!(
|
||||
!wait_for_port(port2, 200).await,
|
||||
"Removed port should not be listening"
|
||||
);
|
||||
|
||||
// Original port should still be listening
|
||||
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
|
||||
assert!(
|
||||
wait_for_port(port1, 200).await,
|
||||
"Original port should still be listening"
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -168,7 +191,11 @@ async fn test_metrics_track_connections() {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
|
||||
assert!(
|
||||
stats.total_connections > 0,
|
||||
"Expected total_connections > 0, got {}",
|
||||
stats.total_connections
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -205,8 +232,11 @@ async fn test_metrics_track_bytes() {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert!(stats.total_connections > 0,
|
||||
"Expected some connections tracked, got {}", stats.total_connections);
|
||||
assert!(
|
||||
stats.total_connections > 0,
|
||||
"Expected some connections tracked, got {}",
|
||||
stats.total_connections
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -228,23 +258,38 @@ async fn test_hot_reload_port_changes() {
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port1, 2000).await);
|
||||
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
|
||||
assert!(
|
||||
!wait_for_port(port2, 200).await,
|
||||
"port2 should not be listening yet"
|
||||
);
|
||||
|
||||
// Update routes to use port2 instead
|
||||
let new_routes = vec![
|
||||
make_test_route(port2, None, "127.0.0.1", backend_port),
|
||||
];
|
||||
let new_routes = vec![make_test_route(port2, None, "127.0.0.1", backend_port)];
|
||||
proxy.update_routes(new_routes).await.unwrap();
|
||||
|
||||
// Port2 should now be listening, port1 should be closed
|
||||
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
|
||||
assert!(
|
||||
wait_for_port(port2, 2000).await,
|
||||
"port2 should be listening after reload"
|
||||
);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
|
||||
assert!(
|
||||
!wait_for_port(port1, 200).await,
|
||||
"port1 should be closed after reload"
|
||||
);
|
||||
|
||||
// Verify port2 works
|
||||
let ports = proxy.get_listening_ports();
|
||||
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
|
||||
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
|
||||
assert!(
|
||||
ports.contains(&port2),
|
||||
"Expected port2 in listening ports: {:?}",
|
||||
ports
|
||||
);
|
||||
assert!(
|
||||
!ports.contains(&port1),
|
||||
"port1 should not be in listening ports: {:?}",
|
||||
ports
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
@@ -24,19 +24,25 @@ async fn test_tcp_forward_echo() {
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Wait for proxy to be ready
|
||||
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
|
||||
assert!(
|
||||
wait_for_port(proxy_port, 2000).await,
|
||||
"Proxy port not ready"
|
||||
);
|
||||
|
||||
// Connect and send data
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"hello world").await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"hello world").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -61,21 +67,24 @@ async fn test_tcp_forward_large_payload() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send 1MB of data
|
||||
let data = vec![b'A'; 1_000_000];
|
||||
stream.write_all(&data).await.unwrap();
|
||||
stream.shutdown().await.unwrap();
|
||||
// Send 1MB of data
|
||||
let data = vec![b'A'; 1_000_000];
|
||||
stream.write_all(&data).await.unwrap();
|
||||
stream.shutdown().await.unwrap();
|
||||
|
||||
// Read all back
|
||||
let mut received = Vec::new();
|
||||
stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
}, 10)
|
||||
// Read all back
|
||||
let mut received = Vec::new();
|
||||
stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -100,29 +109,32 @@ async fn test_tcp_forward_multiple_connections() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
handles.push(tokio::spawn(async move {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
let msg = format!("connection-{}", i);
|
||||
stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
handles.push(tokio::spawn(async move {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
let msg = format!("connection-{}", i);
|
||||
stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
}, 10)
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -149,14 +161,20 @@ async fn test_tcp_forward_backend_unreachable() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Connection should complete (proxy accepts it) but data should not flow
|
||||
let result = with_timeout(async {
|
||||
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
|
||||
stream.is_ok()
|
||||
}, 5)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
|
||||
stream.is_ok()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result, "Should be able to connect to proxy even if backend is down");
|
||||
assert!(
|
||||
result,
|
||||
"Should be able to connect to proxy even if backend is down"
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -178,16 +196,19 @@ async fn test_tcp_forward_bidirectional() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"test data").await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"test data").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -65,8 +65,18 @@ async fn test_tls_passthrough_sni_routing() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
|
||||
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
|
||||
make_tls_passthrough_route(
|
||||
proxy_port,
|
||||
Some("one.example.com"),
|
||||
"127.0.0.1",
|
||||
backend1_port,
|
||||
),
|
||||
make_tls_passthrough_route(
|
||||
proxy_port,
|
||||
Some("two.example.com"),
|
||||
"127.0.0.1",
|
||||
backend2_port,
|
||||
),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -76,39 +86,53 @@ async fn test_tls_passthrough_sni_routing() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send a fake ClientHello with SNI "one.example.com"
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("one.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("one.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Backend1 should have received the ClientHello and prefixed its response
|
||||
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
|
||||
assert!(
|
||||
result.starts_with("BACKEND1:"),
|
||||
"Expected BACKEND1 prefix, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
// Now test routing to backend2
|
||||
let result2 = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("two.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result2 = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("two.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
|
||||
assert!(
|
||||
result2.starts_with("BACKEND2:"),
|
||||
"Expected BACKEND2 prefix, got: {}",
|
||||
result2
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -121,9 +145,12 @@ async fn test_tls_passthrough_unknown_sni() {
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
|
||||
],
|
||||
routes: vec![make_tls_passthrough_route(
|
||||
proxy_port,
|
||||
Some("known.example.com"),
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -132,21 +159,24 @@ async fn test_tls_passthrough_unknown_sni() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send ClientHello with unknown SNI - should get no response (connection dropped)
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("unknown.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("unknown.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
// Should either get 0 bytes (closed) or an error
|
||||
match stream.read(&mut buf).await {
|
||||
Ok(0) => true, // Connection closed = no route matched
|
||||
Ok(_) => false, // Got data = route shouldn't have matched
|
||||
Err(_) => true, // Error = connection dropped
|
||||
}
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
// Should either get 0 bytes (closed) or an error
|
||||
match stream.read(&mut buf).await {
|
||||
Ok(0) => true, // Connection closed = no route matched
|
||||
Ok(_) => false, // Got data = route shouldn't have matched
|
||||
Err(_) => true, // Error = connection dropped
|
||||
}
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -163,9 +193,12 @@ async fn test_tls_passthrough_wildcard_domain() {
|
||||
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
|
||||
],
|
||||
routes: vec![make_tls_passthrough_route(
|
||||
proxy_port,
|
||||
Some("*.example.com"),
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -174,21 +207,28 @@ async fn test_tls_passthrough_wildcard_domain() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Should match any subdomain of example.com
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("anything.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("anything.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
|
||||
assert!(
|
||||
result.starts_with("WILDCARD:"),
|
||||
"Expected WILDCARD prefix, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -222,24 +262,29 @@ async fn test_tls_passthrough_multiple_domains() {
|
||||
("beta.example.com", "B2:"),
|
||||
("gamma.example.com", "B3:"),
|
||||
] {
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello(domain);
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello(domain);
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
result.starts_with(expected_prefix),
|
||||
"Domain {} should route to {}, got: {}",
|
||||
domain, expected_prefix, result
|
||||
domain,
|
||||
expected_prefix,
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,12 @@ async fn test_tls_terminate_basic() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&cert_pem,
|
||||
&key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -84,23 +89,26 @@ async fn test_tls_terminate_basic() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Connect with TLS client
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"hello TLS").await.unwrap();
|
||||
tls_stream.write_all(b"hello TLS").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -125,7 +133,12 @@ async fn test_tls_terminate_and_reencrypt() {
|
||||
|
||||
// Create terminate-and-reencrypt route
|
||||
let mut route = make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&proxy_cert,
|
||||
&proxy_key,
|
||||
);
|
||||
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||
|
||||
@@ -138,23 +151,26 @@ async fn test_tls_terminate_and_reencrypt() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"hello reencrypt").await.unwrap();
|
||||
tls_stream.write_all(b"hello reencrypt").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -177,8 +193,22 @@ async fn test_tls_terminate_sni_cert_selection() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
|
||||
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
|
||||
make_tls_terminate_route(
|
||||
proxy_port,
|
||||
"alpha.example.com",
|
||||
"127.0.0.1",
|
||||
backend1_port,
|
||||
&cert1,
|
||||
&key1,
|
||||
),
|
||||
make_tls_terminate_route(
|
||||
proxy_port,
|
||||
"beta.example.com",
|
||||
"127.0.0.1",
|
||||
backend2_port,
|
||||
&cert2,
|
||||
&key2,
|
||||
),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -188,27 +218,35 @@ async fn test_tls_terminate_sni_cert_selection() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name =
|
||||
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"test").await.unwrap();
|
||||
tls_stream.write_all(b"test").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
|
||||
assert!(
|
||||
result.starts_with("ALPHA:"),
|
||||
"Expected ALPHA prefix, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -224,7 +262,12 @@ async fn test_tls_terminate_large_payload() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&cert_pem,
|
||||
&key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -233,26 +276,29 @@ async fn test_tls_terminate_large_payload() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send 1MB of data
|
||||
let data = vec![b'X'; 1_000_000];
|
||||
tls_stream.write_all(&data).await.unwrap();
|
||||
tls_stream.shutdown().await.unwrap();
|
||||
// Send 1MB of data
|
||||
let data = vec![b'X'; 1_000_000];
|
||||
tls_stream.write_all(&data).await.unwrap();
|
||||
tls_stream.shutdown().await.unwrap();
|
||||
|
||||
let mut received = Vec::new();
|
||||
tls_stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
}, 15)
|
||||
let mut received = Vec::new();
|
||||
tls_stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
},
|
||||
15,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -272,7 +318,12 @@ async fn test_tls_terminate_concurrent() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&cert_pem,
|
||||
&key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -281,37 +332,40 @@ async fn test_tls_terminate_concurrent() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
let dom = domain.to_string();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
let dom = domain.to_string();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
let msg = format!("conn-{}", i);
|
||||
tls_stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
let msg = format!("conn-{}", i);
|
||||
tls_stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
}, 15)
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
},
|
||||
15,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as http from 'http';
|
||||
import * as net from 'net';
|
||||
import * as tls from 'tls';
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import { fileURLToPath } from 'url';
|
||||
import { assertPortsFree, findFreePorts } from './helpers/port-allocator.js';
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = path.dirname(__filename);
|
||||
|
||||
const CERT_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'cert.pem'), 'utf8');
|
||||
const KEY_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'key.pem'), 'utf8');
|
||||
|
||||
let httpBackendPort: number;
|
||||
let tlsBackendPort: number;
|
||||
let httpProxyPort: number;
|
||||
let tlsProxyPort: number;
|
||||
|
||||
let httpBackend: http.Server;
|
||||
let tlsBackend: tls.Server;
|
||||
let proxy: SmartProxy;
|
||||
|
||||
async function pollMetrics(proxyToPoll: SmartProxy): Promise<void> {
|
||||
await (proxyToPoll as any).metricsAdapter.poll();
|
||||
}
|
||||
|
||||
async function waitForCondition(
|
||||
callback: () => Promise<boolean>,
|
||||
timeoutMs: number = 5000,
|
||||
stepMs: number = 100,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await callback()) {
|
||||
return;
|
||||
}
|
||||
await new Promise((resolve) => setTimeout(resolve, stepMs));
|
||||
}
|
||||
throw new Error(`Condition not met within ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
function hasIpDomainRequest(domain: string): boolean {
|
||||
const byIp = proxy.getMetrics().connections.domainRequestsByIP();
|
||||
for (const domainMap of byIp.values()) {
|
||||
if (domainMap.has(domain)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
tap.test('setup - backend servers for HTTP domain rate metrics', async () => {
|
||||
[httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort] = await findFreePorts(4);
|
||||
|
||||
httpBackend = http.createServer((req, res) => {
|
||||
let body = '';
|
||||
req.on('data', (chunk) => {
|
||||
body += chunk;
|
||||
});
|
||||
req.on('end', () => {
|
||||
res.writeHead(200, { 'Content-Type': 'text/plain' });
|
||||
res.end(`ok:${body}`);
|
||||
});
|
||||
});
|
||||
await new Promise<void>((resolve) => {
|
||||
httpBackend.listen(httpBackendPort, () => resolve());
|
||||
});
|
||||
|
||||
tlsBackend = tls.createServer({ cert: CERT_PEM, key: KEY_PEM }, (socket) => {
|
||||
socket.on('data', (data) => {
|
||||
socket.write(data);
|
||||
});
|
||||
socket.on('error', () => {});
|
||||
});
|
||||
await new Promise<void>((resolve) => {
|
||||
tlsBackend.listen(tlsBackendPort, () => resolve());
|
||||
});
|
||||
});
|
||||
|
||||
tap.test('setup - start proxy with HTTP and TLS passthrough routes', async () => {
|
||||
proxy = new SmartProxy({
|
||||
routes: [
|
||||
{
|
||||
id: 'http-domain-rates',
|
||||
name: 'http-domain-rates',
|
||||
match: { ports: httpProxyPort, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: httpBackendPort }],
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'tls-passthrough-domain-rates',
|
||||
name: 'tls-passthrough-domain-rates',
|
||||
match: { ports: tlsProxyPort, domains: 'passthrough.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
tls: { mode: 'passthrough' },
|
||||
targets: [{ host: 'localhost', port: tlsBackendPort }],
|
||||
},
|
||||
},
|
||||
],
|
||||
metrics: { enabled: true, sampleIntervalMs: 100, retentionSeconds: 60 },
|
||||
});
|
||||
|
||||
await proxy.start();
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
});
|
||||
|
||||
tap.test('HTTP requests populate per-domain HTTP request rates', async () => {
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const body = `payload-${i}`;
|
||||
const req = http.request(
|
||||
{
|
||||
hostname: 'localhost',
|
||||
port: httpProxyPort,
|
||||
path: '/echo',
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Host: 'Example.COM',
|
||||
'Content-Type': 'text/plain',
|
||||
'Content-Length': String(body.length),
|
||||
},
|
||||
},
|
||||
(res) => {
|
||||
res.resume();
|
||||
res.on('end', () => resolve());
|
||||
},
|
||||
);
|
||||
req.on('error', reject);
|
||||
req.end(body);
|
||||
});
|
||||
}
|
||||
|
||||
await waitForCondition(async () => {
|
||||
await pollMetrics(proxy);
|
||||
const domainMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
|
||||
return (domainMetrics?.lastMinute ?? 0) >= 3 && (domainMetrics?.perSecond ?? 0) > 0;
|
||||
});
|
||||
|
||||
const exampleMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
|
||||
expect(exampleMetrics).toBeTruthy();
|
||||
expect(exampleMetrics?.lastMinute).toEqual(3);
|
||||
expect(exampleMetrics?.perSecond).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
tap.test('TLS passthrough SNI does not inflate HTTP domain request rates', async () => {
|
||||
const tlsClient = tls.connect({
|
||||
host: 'localhost',
|
||||
port: tlsProxyPort,
|
||||
servername: 'passthrough.example.com',
|
||||
rejectUnauthorized: false,
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
tlsClient.once('secureConnect', () => resolve());
|
||||
tlsClient.once('error', reject);
|
||||
});
|
||||
|
||||
const echoPromise = new Promise<void>((resolve, reject) => {
|
||||
tlsClient.once('data', () => resolve());
|
||||
tlsClient.once('error', reject);
|
||||
});
|
||||
tlsClient.write(Buffer.from('hello over tls passthrough'));
|
||||
await echoPromise;
|
||||
|
||||
await waitForCondition(async () => {
|
||||
await pollMetrics(proxy);
|
||||
return hasIpDomainRequest('passthrough.example.com');
|
||||
});
|
||||
|
||||
const requestRates = proxy.getMetrics().requests.byDomain();
|
||||
expect(requestRates.has('passthrough.example.com')).toBeFalse();
|
||||
expect(requestRates.get('example.com')?.lastMinute).toEqual(3);
|
||||
expect(hasIpDomainRequest('passthrough.example.com')).toBeTrue();
|
||||
|
||||
tlsClient.destroy();
|
||||
});
|
||||
|
||||
tap.test('cleanup - stop proxy and close backend servers', async () => {
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => httpBackend.close(() => resolve()));
|
||||
await new Promise<void>((resolve) => tlsBackend.close(() => resolve()));
|
||||
await assertPortsFree([httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort]);
|
||||
});
|
||||
|
||||
export default tap.start()
|
||||
@@ -83,6 +83,9 @@ tap.test('should verify new metrics API structure', async () => {
|
||||
expect(metrics.throughput).toHaveProperty('history');
|
||||
expect(metrics.throughput).toHaveProperty('byRoute');
|
||||
expect(metrics.throughput).toHaveProperty('byIP');
|
||||
|
||||
// Check request methods
|
||||
expect(metrics.requests).toHaveProperty('byDomain');
|
||||
});
|
||||
|
||||
tap.test('should track active connections', async (tools) => {
|
||||
@@ -273,4 +276,4 @@ tap.test('should clean up resources', async () => {
|
||||
await assertPortsFree([echoServerPort, proxyPort]);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
@@ -188,10 +188,12 @@ tap.test('TCP forward - real-time byte tracking', async (tools) => {
|
||||
const byRoute = m.throughput.byRoute();
|
||||
console.log('TCP forward — throughput byRoute:', Array.from(byRoute.entries()));
|
||||
|
||||
// After close, per-IP data should be evicted (memory leak fix)
|
||||
// After close, per-IP buckets are retained briefly for final throughput sampling,
|
||||
// but active connection counts must already be zero.
|
||||
const byIPAfter = m.connections.byIP();
|
||||
console.log('TCP forward — connections byIP after close:', Array.from(byIPAfter.entries()));
|
||||
expect(byIPAfter.size).toEqual(0);
|
||||
expect(byIPAfter.size).toBeGreaterThan(0);
|
||||
expect(Array.from(byIPAfter.values()).every((count) => count === 0)).toEqual(true);
|
||||
|
||||
await proxy.stop();
|
||||
await tools.delayFor(200);
|
||||
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartproxy',
|
||||
version: '27.7.2',
|
||||
version: '27.9.0',
|
||||
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
|
||||
}
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ export { SmartProxy } from './proxies/smart-proxy/index.js';
|
||||
export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js';
|
||||
|
||||
// Export smart-proxy models
|
||||
export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
|
||||
export type { ISmartProxyOptions, ISmartProxySecurityPolicy, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
|
||||
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
|
||||
export * from './proxies/smart-proxy/utils/index.js';
|
||||
|
||||
|
||||
@@ -2,6 +2,6 @@
|
||||
* SmartProxy models
|
||||
*/
|
||||
// Export everything except IAcmeOptions from interfaces
|
||||
export type { ISmartProxyOptions, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
|
||||
export type { ISmartProxyOptions, ISmartProxySecurityPolicy, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
|
||||
export * from './route-types.js';
|
||||
export * from './metrics-types.js';
|
||||
|
||||
@@ -29,6 +29,11 @@ export interface ISmartProxyCertStore {
|
||||
}
|
||||
import type { IRouteConfig } from './route-types.js';
|
||||
|
||||
export interface ISmartProxySecurityPolicy {
|
||||
blockedIps?: string[];
|
||||
blockedCidrs?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Provision object for static or HTTP-01 certificate
|
||||
*/
|
||||
@@ -137,6 +142,7 @@ export interface ISmartProxyOptions {
|
||||
// Rate limiting and security
|
||||
maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP
|
||||
connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP
|
||||
securityPolicy?: ISmartProxySecurityPolicy; // Global ingress block policy, enforced before routing
|
||||
|
||||
// Enhanced keep-alive settings
|
||||
keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections
|
||||
@@ -276,4 +282,4 @@ export interface IConnectionRecord {
|
||||
path?: string;
|
||||
headers?: Record<string, string>;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,11 @@ export interface IThroughputHistoryPoint {
|
||||
out: number;
|
||||
}
|
||||
|
||||
export interface IRequestRateMetrics {
|
||||
perSecond: number;
|
||||
lastMinute: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main metrics interface with clean, grouped API
|
||||
*/
|
||||
@@ -81,6 +86,7 @@ export interface IMetrics {
|
||||
perSecond(): number;
|
||||
perMinute(): number;
|
||||
total(): number;
|
||||
byDomain(): Map<string, IRequestRateMetrics>;
|
||||
};
|
||||
|
||||
// Cumulative totals
|
||||
@@ -185,4 +191,4 @@ export interface IByteTracker {
|
||||
bytesOut: number;
|
||||
startTime: number;
|
||||
lastUpdate: number;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js';
|
||||
import type { IAcmeOptions, ISmartProxyOptions } from './interfaces.js';
|
||||
import type { IAcmeOptions, ISmartProxyOptions, ISmartProxySecurityPolicy } from './interfaces.js';
|
||||
import type {
|
||||
IRouteAction,
|
||||
IRouteConfig,
|
||||
@@ -75,6 +75,7 @@ export interface IRustProxyOptions {
|
||||
keepAliveInactivityMultiplier?: number;
|
||||
extendedKeepAliveLifetime?: number;
|
||||
metrics?: ISmartProxyOptions['metrics'];
|
||||
securityPolicy?: ISmartProxySecurityPolicy;
|
||||
acme?: IRustAcmeOptions;
|
||||
}
|
||||
|
||||
@@ -134,6 +135,11 @@ export interface IRustBackendMetrics {
|
||||
h2Failures: number;
|
||||
}
|
||||
|
||||
export interface IRustHttpDomainRequestMetrics {
|
||||
requestsPerSecond: number;
|
||||
requestsLastMinute: number;
|
||||
}
|
||||
|
||||
export interface IRustMetricsSnapshot {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
@@ -150,6 +156,7 @@ export interface IRustMetricsSnapshot {
|
||||
totalHttpRequests: number;
|
||||
httpRequestsPerSec: number;
|
||||
httpRequestsPerSecRecent: number;
|
||||
httpDomainRequests: Record<string, IRustHttpDomainRequestMetrics>;
|
||||
activeUdpSessions: number;
|
||||
totalUdpSessions: number;
|
||||
totalDatagramsIn: number;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { IMetrics, IBackendMetrics, IProtocolCacheEntry, IProtocolDistribution, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js';
|
||||
import type { IMetrics, IBackendMetrics, IProtocolCacheEntry, IProtocolDistribution, IRequestRateMetrics, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js';
|
||||
import type { RustProxyBridge } from './rust-proxy-bridge.js';
|
||||
import type { IRustBackendMetrics, IRustIpMetrics, IRustMetricsSnapshot, IRustRouteMetrics } from './models/rust-types.js';
|
||||
import type { IRustBackendMetrics, IRustHttpDomainRequestMetrics, IRustIpMetrics, IRustMetricsSnapshot, IRustRouteMetrics } from './models/rust-types.js';
|
||||
|
||||
/**
|
||||
* Adapts Rust JSON metrics to the IMetrics interface.
|
||||
@@ -219,6 +219,18 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
total: (): number => {
|
||||
return this.cache?.totalHttpRequests ?? this.cache?.totalConnections ?? 0;
|
||||
},
|
||||
byDomain: (): Map<string, IRequestRateMetrics> => {
|
||||
const result = new Map<string, IRequestRateMetrics>();
|
||||
if (this.cache?.httpDomainRequests) {
|
||||
for (const [domain, metrics] of Object.entries(this.cache.httpDomainRequests) as Array<[string, IRustHttpDomainRequestMetrics]>) {
|
||||
result.set(domain, {
|
||||
perSecond: metrics.requestsPerSecond ?? 0,
|
||||
lastMinute: metrics.requestsLastMinute ?? 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
return result;
|
||||
},
|
||||
};
|
||||
|
||||
public totals = {
|
||||
|
||||
@@ -7,6 +7,7 @@ import type {
|
||||
IRustRouteConfig,
|
||||
IRustStatistics,
|
||||
} from './models/rust-types.js';
|
||||
import type { ISmartProxySecurityPolicy } from './models/interfaces.js';
|
||||
|
||||
/**
|
||||
* Type-safe command definitions for the Rust proxy IPC protocol.
|
||||
@@ -15,6 +16,7 @@ type TSmartProxyCommands = {
|
||||
start: { params: { config: IRustProxyOptions }; result: void };
|
||||
stop: { params: Record<string, never>; result: void };
|
||||
updateRoutes: { params: { routes: IRustRouteConfig[] }; result: void };
|
||||
setSecurityPolicy: { params: { policy: ISmartProxySecurityPolicy }; result: void };
|
||||
getMetrics: { params: Record<string, never>; result: IRustMetricsSnapshot };
|
||||
getStatistics: { params: Record<string, never>; result: IRustStatistics };
|
||||
provisionCertificate: { params: { routeName: string }; result: void };
|
||||
@@ -139,6 +141,10 @@ export class RustProxyBridge extends plugins.EventEmitter {
|
||||
await this.bridge.sendCommand('updateRoutes', { routes });
|
||||
}
|
||||
|
||||
public async setSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
|
||||
await this.bridge.sendCommand('setSecurityPolicy', { policy });
|
||||
}
|
||||
|
||||
public async getMetrics(): Promise<IRustMetricsSnapshot> {
|
||||
return this.bridge.sendCommand('getMetrics', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import { Mutex } from './utils/mutex.js';
|
||||
import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js';
|
||||
|
||||
// Types
|
||||
import type { ISmartProxyOptions, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
|
||||
import type { ISmartProxyOptions, ISmartProxySecurityPolicy, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
|
||||
import type { IRouteConfig } from './models/route-types.js';
|
||||
import type { IMetrics } from './models/metrics-types.js';
|
||||
import type { IRustCertificateStatus, IRustProxyOptions, IRustStatistics } from './models/rust-types.js';
|
||||
@@ -350,6 +350,15 @@ export class SmartProxy extends plugins.EventEmitter {
|
||||
.catch((err) => logger.log('error', `Unexpected error in cert provisioning after route update: ${err.message}`, { component: 'smart-proxy' }));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the global ingress security policy without changing routes.
|
||||
* The Rust engine applies this before route selection and backend connection.
|
||||
*/
|
||||
public async updateSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
|
||||
this.settings.securityPolicy = policy;
|
||||
await this.bridge.setSecurityPolicy(policy);
|
||||
}
|
||||
|
||||
/**
|
||||
* Provision a certificate for a named route.
|
||||
*/
|
||||
|
||||
@@ -182,6 +182,7 @@ export function buildRustProxyOptions(
|
||||
keepAliveInactivityMultiplier: settings.keepAliveInactivityMultiplier,
|
||||
extendedKeepAliveLifetime: settings.extendedKeepAliveLifetime,
|
||||
metrics: settings.metrics,
|
||||
securityPolicy: settings.securityPolicy,
|
||||
acme: serializeAcmeForRust(acme),
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user