feat(smart-proxy): add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers

This commit is contained in:
2026-04-26 15:11:10 +00:00
parent 8fa3a51b03
commit af4908b63f
53 changed files with 2350 additions and 1196 deletions
+7
View File
@@ -1,5 +1,12 @@
# Changelog # Changelog
## 2026-04-26 - 27.9.0 - feat(smart-proxy)
add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers
- adds global securityPolicy config with blocked IP and CIDR support to SmartProxy and RustProxy options
- introduces management IPC support to update the security policy at runtime via setSecurityPolicy
- enforces the global block list early for TCP, UDP, and QUIC traffic before route selection and backend handling
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics) ## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
retain inactive per-IP metric buckets briefly to capture final throughput before pruning retain inactive per-IP metric buckets briefly to capture final throughput before pruning
+1
View File
@@ -1319,6 +1319,7 @@ dependencies = [
"rustproxy-http", "rustproxy-http",
"rustproxy-metrics", "rustproxy-metrics",
"rustproxy-routing", "rustproxy-routing",
"rustproxy-security",
"serde", "serde",
"serde_json", "serde_json",
"socket2 0.5.10", "socket2 0.5.10",
+4 -4
View File
@@ -3,15 +3,15 @@
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema. //! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming. //! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
pub mod route_types;
pub mod proxy_options; pub mod proxy_options;
pub mod tls_types; pub mod route_types;
pub mod security_types; pub mod security_types;
pub mod tls_types;
pub mod validation; pub mod validation;
// Re-export all primary types // Re-export all primary types
pub use route_types::*;
pub use proxy_options::*; pub use proxy_options::*;
pub use tls_types::*; pub use route_types::*;
pub use security_types::*; pub use security_types::*;
pub use tls_types::*;
pub use validation::*; pub use validation::*;
@@ -97,6 +97,16 @@ pub struct MetricsConfig {
pub retention_seconds: Option<u64>, pub retention_seconds: Option<u64>,
} }
/// Global ingress security policy.
///
/// Deserialized from the camelCase JSON fields `blockedIps` / `blockedCidrs`
/// to stay schema-compatible with the TypeScript SmartProxy options.
/// `Default` yields both fields as `None`, i.e. an empty policy that blocks
/// nothing; `skip_serializing_if` keeps absent fields out of emitted JSON.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SecurityPolicy {
    // Individual IP addresses to block, as strings (e.g. "203.0.113.7").
    // NOTE(review): parsing/validation appears to happen elsewhere (IpBlockList) — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blocked_ips: Option<Vec<String>>,
    // CIDR ranges to block, as strings (e.g. "10.0.0.0/8").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blocked_cidrs: Option<Vec<String>>,
}
/// RustProxy configuration options. /// RustProxy configuration options.
/// Matches TypeScript: `ISmartProxyOptions` /// Matches TypeScript: `ISmartProxyOptions`
/// ///
@@ -235,6 +245,10 @@ pub struct RustProxyOptions {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub metrics: Option<MetricsConfig>, pub metrics: Option<MetricsConfig>,
/// Global ingress security policy, enforced before route selection.
#[serde(skip_serializing_if = "Option::is_none")]
pub security_policy: Option<SecurityPolicy>,
// ─── ACME ──────────────────────────────────────────────────────── // ─── ACME ────────────────────────────────────────────────────────
/// Global ACME configuration /// Global ACME configuration
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@@ -275,6 +289,7 @@ impl Default for RustProxyOptions {
use_http_proxy: None, use_http_proxy: None,
http_proxy_port: None, http_proxy_port: None,
metrics: None, metrics: None,
security_policy: None,
acme: None, acme: None,
} }
} }
@@ -111,10 +111,7 @@ pub enum IpAllowEntry {
/// Plain IP/CIDR — allowed for all domains on this route /// Plain IP/CIDR — allowed for all domains on this route
Plain(String), Plain(String),
/// Domain-scoped — allowed only when the requested domain matches /// Domain-scoped — allowed only when the requested domain matches
DomainScoped { DomainScoped { ip: String, domains: Vec<String> },
ip: String,
domains: Vec<String>,
},
} }
/// Security options for routes. /// Security options for routes.
+17 -8
View File
@@ -1,6 +1,6 @@
use thiserror::Error; use thiserror::Error;
use crate::route_types::{RouteConfig, RouteActionType}; use crate::route_types::{RouteActionType, RouteConfig};
/// Validation errors for route configurations. /// Validation errors for route configurations.
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -30,9 +30,10 @@ pub enum ValidationError {
/// Validate a single route configuration. /// Validate a single route configuration.
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> { pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
let mut errors = Vec::new(); let mut errors = Vec::new();
let name = route.name.clone().unwrap_or_else(|| { let name = route
route.id.clone().unwrap_or_else(|| "unnamed".to_string()) .name
}); .clone()
.unwrap_or_else(|| route.id.clone().unwrap_or_else(|| "unnamed".to_string()));
// Check ports // Check ports
let ports = route.listening_ports(); let ports = route.listening_ports();
@@ -160,7 +161,9 @@ mod tests {
let mut route = make_valid_route(); let mut route = make_valid_route();
route.action.targets = None; route.action.targets = None;
let errors = validate_route(&route).unwrap_err(); let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. }))); assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::MissingTargets { .. })));
} }
#[test] #[test]
@@ -168,7 +171,9 @@ mod tests {
let mut route = make_valid_route(); let mut route = make_valid_route();
route.action.targets = Some(vec![]); route.action.targets = Some(vec![]);
let errors = validate_route(&route).unwrap_err(); let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. }))); assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
} }
#[test] #[test]
@@ -176,7 +181,9 @@ mod tests {
let mut route = make_valid_route(); let mut route = make_valid_route();
route.route_match.ports = PortRange::Single(0); route.route_match.ports = PortRange::Single(0);
let errors = validate_route(&route).unwrap_err(); let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. }))); assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
} }
#[test] #[test]
@@ -186,7 +193,9 @@ mod tests {
let mut r2 = make_valid_route(); let mut r2 = make_valid_route();
r2.id = Some("route-1".to_string()); r2.id = Some("route-1".to_string());
let errors = validate_routes(&[r1, r2]).unwrap_err(); let errors = validate_routes(&[r1, r2]).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. }))); assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::DuplicateId { .. })));
} }
#[test] #[test]
+2 -2
View File
@@ -2,10 +2,10 @@
//! //!
//! Metrics and throughput tracking for RustProxy. //! Metrics and throughput tracking for RustProxy.
pub mod throughput;
pub mod collector; pub mod collector;
pub mod log_dedup; pub mod log_dedup;
pub mod throughput;
pub use throughput::*;
pub use collector::*; pub use collector::*;
pub use log_dedup::*; pub use log_dedup::*;
pub use throughput::*;
+11 -8
View File
@@ -1,6 +1,6 @@
use dashmap::DashMap; use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tracing::info; use tracing::info;
@@ -47,13 +47,16 @@ impl LogDeduplicator {
let map_key = format!("{}:{}", category, key); let map_key = format!("{}:{}", category, key);
let now = Instant::now(); let now = Instant::now();
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent { let entry = self
category: category.to_string(), .events
first_message: message.to_string(), .entry(map_key)
count: AtomicU64::new(0), .or_insert_with(|| AggregatedEvent {
first_seen: now, category: category.to_string(),
last_seen: now, first_message: message.to_string(),
}); count: AtomicU64::new(0),
first_seen: now,
last_seen: now,
});
let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1; let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1;
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
rustproxy-config = { workspace = true } rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true } rustproxy-routing = { workspace = true }
rustproxy-metrics = { workspace = true } rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
@@ -7,8 +7,8 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
@@ -73,7 +73,9 @@ impl ConnectionRegistry {
pub fn recycle_for_cert_change(&self, cert_domain: &str) { pub fn recycle_for_cert_change(&self, cert_domain: &str) {
let mut recycled = 0u64; let mut recycled = 0u64;
self.connections.retain(|_, entry| { self.connections.retain(|_, entry| {
let matches = entry.domain.as_deref() let matches = entry
.domain
.as_deref()
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain)) .map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
.unwrap_or(false); .unwrap_or(false);
if matches { if matches {
@@ -100,7 +102,11 @@ impl ConnectionRegistry {
let mut recycled = 0u64; let mut recycled = 0u64;
self.connections.retain(|_, entry| { self.connections.retain(|_, entry| {
if entry.route_id.as_deref() == Some(route_id) { if entry.route_id.as_deref() == Some(route_id) {
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) { if !RequestFilter::check_ip_security(
new_security,
&entry.source_ip,
entry.domain.as_deref(),
) {
info!( info!(
"Terminating connection from {} — IP now blocked on route '{}'", "Terminating connection from {} — IP now blocked on route '{}'",
entry.source_ip, route_id entry.source_ip, route_id
@@ -31,7 +31,8 @@ impl ConnectionTracker {
pub fn try_accept(&self, ip: &IpAddr) -> bool { pub fn try_accept(&self, ip: &IpAddr) -> bool {
// Check per-IP connection limit // Check per-IP connection limit
if let Some(max) = self.max_per_ip { if let Some(max) = self.max_per_ip {
let count = self.active let count = self
.active
.get(ip) .get(ip)
.map(|c| c.value().load(Ordering::Relaxed)) .map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
@@ -48,7 +49,10 @@ impl ConnectionTracker {
let timestamps = entry.value_mut(); let timestamps = entry.value_mut();
// Remove timestamps older than 1 minute // Remove timestamps older than 1 minute
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) { while timestamps
.front()
.is_some_and(|t| now.duration_since(*t) >= one_minute)
{
timestamps.pop_front(); timestamps.pop_front();
} }
@@ -111,7 +115,6 @@ impl ConnectionTracker {
pub fn tracked_ips(&self) -> usize { pub fn tracked_ips(&self) -> usize {
self.active.len() self.active.len()
} }
} }
#[cfg(test)] #[cfg(test)]
@@ -1,8 +1,8 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug; use tracing::debug;
use rustproxy_metrics::MetricsCollector; use rustproxy_metrics::MetricsCollector;
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
if let Some(data) = initial_data { if let Some(data) = initial_data {
backend.write_all(data).await?; backend.write_all(data).await?;
if let Some(ref ctx) = metrics { if let Some(ref ctx) = metrics {
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); ctx.collector.record_bytes(
data.len() as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
} }
} }
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64; total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_c2b { if let Some(ref ctx) = metrics_c2b {
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); ctx.collector.record_bytes(
n as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
} }
} }
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify) // Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout( let _ =
std::time::Duration::from_secs(2), tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
backend_write.shutdown(),
).await;
total total
}); });
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64; total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_b2c { if let Some(ref ctx) = metrics_b2c {
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); ctx.collector.record_bytes(
0,
n as u64,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
} }
} }
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify) // Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout( let _ =
std::time::Duration::from_secs(2), tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
client_write.shutdown(),
).await;
total total
}); });
+16 -16
View File
@@ -4,26 +4,26 @@
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding, //! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
//! and UDP datagram session tracking with forwarding. //! and UDP datagram session tracking with forwarding.
pub mod tcp_listener; pub mod connection_registry;
pub mod sni_parser; pub mod connection_tracker;
pub mod forwarder; pub mod forwarder;
pub mod proxy_protocol; pub mod proxy_protocol;
pub mod tls_handler;
pub mod connection_tracker;
pub mod connection_registry;
pub mod socket_opts;
pub mod udp_session;
pub mod udp_listener;
pub mod quic_handler; pub mod quic_handler;
pub mod sni_parser;
pub mod socket_opts;
pub mod tcp_listener;
pub mod tls_handler;
pub mod udp_listener;
pub mod udp_session;
pub use tcp_listener::*; pub use connection_registry::*;
pub use sni_parser::*; pub use connection_tracker::*;
pub use forwarder::*; pub use forwarder::*;
pub use proxy_protocol::*; pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_tracker::*;
pub use connection_registry::*;
pub use socket_opts::*;
pub use udp_session::*;
pub use udp_listener::*;
pub use quic_handler::*; pub use quic_handler::*;
pub use sni_parser::*;
pub use socket_opts::*;
pub use tcp_listener::*;
pub use tls_handler::*;
pub use udp_listener::*;
pub use udp_session::*;
@@ -54,8 +54,8 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
.position(|w| w == b"\r\n") .position(|w| w == b"\r\n")
.ok_or(ProxyProtocolError::InvalidHeader)?; .ok_or(ProxyProtocolError::InvalidHeader)?;
let line = std::str::from_utf8(&data[..line_end]) let line =
.map_err(|_| ProxyProtocolError::InvalidHeader)?; std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
if !line.starts_with("PROXY ") { if !line.starts_with("PROXY ") {
return Err(ProxyProtocolError::InvalidHeader); return Err(ProxyProtocolError::InvalidHeader);
@@ -148,7 +148,10 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
let command = data[12] & 0x0F; let command = data[12] & 0x0F;
// 0x0 = LOCAL, 0x1 = PROXY // 0x0 = LOCAL, 0x1 = PROXY
if command > 1 { if command > 1 {
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command))); return Err(ProxyProtocolError::Parse(format!(
"Unknown command: {}",
command
)));
} }
// Address family (high nibble) + transport (low nibble) of byte 13 // Address family (high nibble) + transport (low nibble) of byte 13
@@ -182,7 +185,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET (0x1) + STREAM (0x1) = TCP4 // AF_INET (0x1) + STREAM (0x1) = TCP4
(0x1, 0x1) => { (0x1, 0x1) => {
if addr_len < 12 { if addr_len < 12 {
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string())); return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
} }
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]); let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]); let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
@@ -200,7 +205,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET (0x1) + DGRAM (0x2) = UDP4 // AF_INET (0x1) + DGRAM (0x2) = UDP4
(0x1, 0x2) => { (0x1, 0x2) => {
if addr_len < 12 { if addr_len < 12 {
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string())); return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
} }
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]); let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]); let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
@@ -218,7 +225,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET6 (0x2) + STREAM (0x1) = TCP6 // AF_INET6 (0x2) + STREAM (0x1) = TCP6
(0x2, 0x1) => { (0x2, 0x1) => {
if addr_len < 36 { if addr_len < 36 {
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string())); return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
} }
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap()); let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap()); let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
@@ -236,7 +245,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6 // AF_INET6 (0x2) + DGRAM (0x2) = UDP6
(0x2, 0x2) => { (0x2, 0x2) => {
if addr_len < 36 { if addr_len < 36 {
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string())); return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
} }
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap()); let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap()); let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
@@ -268,11 +279,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
} }
/// Generate a PROXY protocol v2 binary header. /// Generate a PROXY protocol v2 binary header.
pub fn generate_v2( pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
source: &SocketAddr,
dest: &SocketAddr,
transport: ProxyV2Transport,
) -> Vec<u8> {
let transport_nibble: u8 = match transport { let transport_nibble: u8 = match transport {
ProxyV2Transport::Stream => 0x1, ProxyV2Transport::Stream => 0x1,
ProxyV2Transport::Datagram => 0x2, ProxyV2Transport::Datagram => 0x2,
@@ -462,7 +469,10 @@ mod tests {
header.push(0x11); header.push(0x11);
header.extend_from_slice(&12u16.to_be_bytes()); header.extend_from_slice(&12u16.to_be_bytes());
header.extend_from_slice(&[0u8; 12]); header.extend_from_slice(&[0u8; 12]);
assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion))); assert!(matches!(
parse_v2(&header),
Err(ProxyProtocolError::UnsupportedVersion)
));
} }
#[test] #[test]
@@ -26,11 +26,12 @@ use tracing::{debug, info, warn};
use rustproxy_config::{RouteConfig, TransportProtocol}; use rustproxy_config::{RouteConfig, TransportProtocol};
use rustproxy_metrics::MetricsCollector; use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager}; use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService; use rustproxy_http::h3_service::H3ProxyService;
use crate::connection_tracker::ConnectionTracker;
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry}; use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
use crate::connection_tracker::ConnectionTracker;
/// Create a QUIC server endpoint on the given port with the provided TLS config. /// Create a QUIC server endpoint on the given port with the provided TLS config.
/// ///
@@ -48,8 +49,7 @@ pub fn create_quic_endpoint(
quinn::EndpointConfig::default(), quinn::EndpointConfig::default(),
Some(server_config), Some(server_config),
socket, socket,
quinn::default_runtime() quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?; )?;
info!("QUIC endpoint listening on port {}", port); info!("QUIC endpoint listening on port {}", port);
@@ -97,6 +97,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
port: u16, port: u16,
tls_config: Arc<RustlsServerConfig>, tls_config: Arc<RustlsServerConfig>,
proxy_ips: Arc<Vec<IpAddr>>, proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
cancel: CancellationToken, cancel: CancellationToken,
) -> anyhow::Result<QuicProxyRelay> { ) -> anyhow::Result<QuicProxyRelay> {
// Bind external socket on the real port // Bind external socket on the real port
@@ -119,8 +120,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
quinn::EndpointConfig::default(), quinn::EndpointConfig::default(),
Some(server_config), Some(server_config),
internal_socket, internal_socket,
quinn::default_runtime() quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?; )?;
let real_client_map = Arc::new(DashMap::new()); let real_client_map = Arc::new(DashMap::new());
@@ -129,12 +129,20 @@ pub fn create_quic_endpoint_with_proxy_relay(
external_socket, external_socket,
quinn_internal_addr, quinn_internal_addr,
proxy_ips, proxy_ips,
security_policy,
Arc::clone(&real_client_map), Arc::clone(&real_client_map),
cancel, cancel,
)); ));
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr); info!(
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map }) "QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
port, quinn_internal_addr
);
Ok(QuicProxyRelay {
endpoint,
relay_task,
real_client_map,
})
} }
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2 /// Main relay loop: reads datagrams from the external socket, filters PROXY v2
@@ -144,6 +152,7 @@ async fn quic_proxy_relay_loop(
external_socket: Arc<UdpSocket>, external_socket: Arc<UdpSocket>,
quinn_internal_addr: SocketAddr, quinn_internal_addr: SocketAddr,
proxy_ips: Arc<Vec<IpAddr>>, proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>, real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
cancel: CancellationToken, cancel: CancellationToken,
) { ) {
@@ -184,26 +193,43 @@ async fn quic_proxy_relay_loop(
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) { if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) { match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => { Ok((header, _consumed)) => {
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr); debug!(
"QUIC PROXY v2 from {}: real client {}",
src_addr, header.source_addr
);
proxy_addr_map.insert(src_addr, header.source_addr); proxy_addr_map.insert(src_addr, header.source_addr);
continue; // consume the PROXY v2 datagram continue; // consume the PROXY v2 datagram
} }
Err(e) => { Err(e) => {
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e); debug!(
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
src_addr, e
);
} }
} }
} }
} }
// Determine real client address // Determine real client address
let real_client = proxy_addr_map.get(&src_addr) let real_client = proxy_addr_map
.get(&src_addr)
.map(|r| *r) .map(|r| *r)
.unwrap_or(src_addr); .unwrap_or(src_addr);
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
debug!(
"QUIC datagram from {} blocked by global security policy",
real_client
);
continue;
}
// Get or create relay session for this external source // Get or create relay session for this external source
let session = match relay_sessions.get(&src_addr) { let session = match relay_sessions.get(&src_addr) {
Some(s) => { Some(s) => {
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed); s.last_activity
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
Arc::clone(s.value()) Arc::clone(s.value())
} }
None => { None => {
@@ -216,7 +242,10 @@ async fn quic_proxy_relay_loop(
} }
}; };
if let Err(e) = relay_socket.connect(quinn_internal_addr).await { if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e); warn!(
"QUIC relay: failed to connect relay socket to {}: {}",
quinn_internal_addr, e
);
continue; continue;
} }
let relay_local_addr = match relay_socket.local_addr() { let relay_local_addr = match relay_socket.local_addr() {
@@ -248,8 +277,10 @@ async fn quic_proxy_relay_loop(
}); });
relay_sessions.insert(src_addr, Arc::clone(&session)); relay_sessions.insert(src_addr, Arc::clone(&session));
debug!("QUIC relay: new session for {} (relay {}), real client {}", debug!(
src_addr, relay_local_addr, real_client); "QUIC relay: new session for {} (relay {}), real client {}",
src_addr, relay_local_addr, real_client
);
session session
} }
@@ -264,9 +295,11 @@ async fn quic_proxy_relay_loop(
if last_cleanup.elapsed() >= cleanup_interval { if last_cleanup.elapsed() >= cleanup_interval {
last_cleanup = Instant::now(); last_cleanup = Instant::now();
let now_ms = epoch.elapsed().as_millis() as u64; let now_ms = epoch.elapsed().as_millis() as u64;
let stale_keys: Vec<SocketAddr> = relay_sessions.iter() let stale_keys: Vec<SocketAddr> = relay_sessions
.iter()
.filter(|entry| { .filter(|entry| {
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed)); let age =
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
age > session_timeout_ms age > session_timeout_ms
}) })
.map(|entry| *entry.key()) .map(|entry| *entry.key())
@@ -287,13 +320,17 @@ async fn quic_proxy_relay_loop(
// Also clean orphaned proxy_addr_map entries (PROXY header received // Also clean orphaned proxy_addr_map entries (PROXY header received
// but no relay session was ever created, e.g. client never sent data) // but no relay session was ever created, e.g. client never sent data)
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter() let orphaned: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| relay_sessions.get(entry.key()).is_none()) .filter(|entry| relay_sessions.get(entry.key()).is_none())
.map(|entry| *entry.key()) .map(|entry| *entry.key())
.collect(); .collect();
for key in orphaned { for key in orphaned {
proxy_addr_map.remove(&key); proxy_addr_map.remove(&key);
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key); debug!(
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
key
);
} }
} }
} }
@@ -328,8 +365,14 @@ async fn relay_return_path(
} }
}; };
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await { if let Err(e) = external_socket
debug!("QUIC relay return send error to {}: {}", external_src_addr, e); .send_to(&buf[..len], external_src_addr)
.await
{
debug!(
"QUIC relay return send error to {}: {}",
external_src_addr, e
);
break; break;
} }
} }
@@ -353,6 +396,7 @@ pub async fn quic_accept_loop(
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>, real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
route_cancels: Arc<DashMap<String, CancellationToken>>, route_cancels: Arc<DashMap<String, CancellationToken>>,
connection_registry: Arc<ConnectionRegistry>, connection_registry: Arc<ConnectionRegistry>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) { ) {
loop { loop {
let incoming = tokio::select! { let incoming = tokio::select! {
@@ -374,11 +418,21 @@ pub async fn quic_accept_loop(
let remote_addr = incoming.remote_address(); let remote_addr = incoming.remote_address();
// Resolve real client IP from PROXY protocol map if available // Resolve real client IP from PROXY protocol map if available
let real_addr = real_client_map.as_ref() let real_addr = real_client_map
.as_ref()
.and_then(|map| map.get(&remote_addr).map(|r| *r)) .and_then(|map| map.get(&remote_addr).map(|r| *r))
.unwrap_or(remote_addr); .unwrap_or(remote_addr);
let ip = real_addr.ip(); let ip = real_addr.ip();
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&ip) {
debug!(
"QUIC connection from {} blocked by global security policy",
real_addr
);
continue;
}
// Per-IP rate limiting // Per-IP rate limiting
if !conn_tracker.try_accept(&ip) { if !conn_tracker.try_accept(&ip) {
debug!("QUIC connection rejected from {} (rate limit)", real_addr); debug!("QUIC connection rejected from {} (rate limit)", real_addr);
@@ -414,7 +468,10 @@ pub async fn quic_accept_loop(
if !rustproxy_http::request_filter::RequestFilter::check_ip_security( if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
security, &ip, ctx.domain, security, &ip, ctx.domain,
) { ) {
debug!("QUIC connection from {} blocked by route security", real_addr); debug!(
"QUIC connection from {} blocked by route security",
real_addr
);
continue; continue;
} }
} }
@@ -425,7 +482,8 @@ pub async fn quic_accept_loop(
// Resolve per-route cancel token (child of global cancel) // Resolve per-route cancel token (child of global cancel)
let route_cancel = match route_id.as_deref() { let route_cancel = match route_id.as_deref() {
Some(id) => route_cancels.entry(id.to_string()) Some(id) => route_cancels
.entry(id.to_string())
.or_insert_with(|| cancel.child_token()) .or_insert_with(|| cancel.child_token())
.clone(), .clone(),
None => cancel.child_token(), None => cancel.child_token(),
@@ -445,7 +503,11 @@ pub async fn quic_accept_loop(
let metrics = Arc::clone(&metrics); let metrics = Arc::clone(&metrics);
let conn_tracker = Arc::clone(&conn_tracker); let conn_tracker = Arc::clone(&conn_tracker);
let h3_svc = h3_service.clone(); let h3_svc = h3_service.clone();
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None }; let real_client_addr = if real_addr != remote_addr {
Some(real_addr)
} else {
None
};
tokio::spawn(async move { tokio::spawn(async move {
// Register in connection registry (RAII guard removes on drop) // Register in connection registry (RAII guard removes on drop)
@@ -462,7 +524,8 @@ pub async fn quic_accept_loop(
impl Drop for QuicConnGuard { impl Drop for QuicConnGuard {
fn drop(&mut self) { fn drop(&mut self) {
self.tracker.connection_closed(&self.ip); self.tracker.connection_closed(&self.ip);
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str)); self.metrics
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
} }
} }
let _guard = QuicConnGuard { let _guard = QuicConnGuard {
@@ -473,7 +536,17 @@ pub async fn quic_accept_loop(
route_id, route_id,
}; };
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &conn_cancel, h3_svc, real_client_addr).await { match handle_quic_connection(
incoming,
route,
port,
Arc::clone(&metrics),
&conn_cancel,
h3_svc,
real_client_addr,
)
.await
{
Ok(()) => debug!("QUIC connection from {} completed", real_addr), Ok(()) => debug!("QUIC connection from {} completed", real_addr),
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e), Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
} }
@@ -501,17 +574,28 @@ async fn handle_quic_connection(
debug!("QUIC connection established from {}", effective_addr); debug!("QUIC connection established from {}", effective_addr);
// Check if this route has HTTP/3 enabled // Check if this route has HTTP/3 enabled
let enable_http3 = route.action.udp.as_ref() let enable_http3 = route
.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref()) .and_then(|u| u.quic.as_ref())
.and_then(|q| q.enable_http3) .and_then(|q| q.enable_http3)
.unwrap_or(false); .unwrap_or(false);
if enable_http3 { if enable_http3 {
if let Some(ref h3_svc) = h3_service { if let Some(ref h3_svc) = h3_service {
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name); debug!(
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await "HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
route.name
);
h3_svc
.handle_connection(connection, &route, port, real_client_addr, cancel)
.await
} else { } else {
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name); warn!(
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
route.name
);
// Keep connection alive until cancelled // Keep connection alive until cancelled
tokio::select! { tokio::select! {
_ = cancel.cancelled() => {} _ = cancel.cancelled() => {}
@@ -523,7 +607,8 @@ async fn handle_quic_connection(
} }
} else { } else {
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend // Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
.await
} }
} }
@@ -545,7 +630,10 @@ async fn handle_quic_stream_forwarding(
let metrics_arc = metrics; let metrics_arc = metrics;
// Resolve backend target // Resolve backend target
let target = route.action.targets.as_ref() let target = route
.action
.targets
.as_ref()
.and_then(|t| t.first()) .and_then(|t| t.first())
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?; .ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
let backend_host = target.host.first(); let backend_host = target.host.first();
@@ -576,19 +664,20 @@ async fn handle_quic_stream_forwarding(
// Spawn a task for each QUIC stream → TCP bidirectional forwarding // Spawn a task for each QUIC stream → TCP bidirectional forwarding
tokio::spawn(async move { tokio::spawn(async move {
match forward_quic_stream_to_tcp( match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
send_stream, .await
recv_stream, {
&backend_addr,
stream_cancel,
).await {
Ok((bytes_in, bytes_out)) => { Ok((bytes_in, bytes_out)) => {
stream_metrics.record_bytes( stream_metrics.record_bytes(
bytes_in, bytes_out, bytes_in,
bytes_out,
stream_route_id.as_deref(), stream_route_id.as_deref(),
Some(&ip_str), Some(&ip_str),
); );
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out); debug!(
"QUIC stream forwarded: {}B in, {}B out",
bytes_in, bytes_out
);
} }
Err(e) => { Err(e) => {
debug!("QUIC stream forwarding error: {}", e); debug!("QUIC stream forwarding error: {}", e);
@@ -640,10 +729,7 @@ async fn forward_quic_stream_to_tcp(
total += n as u64; total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
} }
let _ = tokio::time::timeout( let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
std::time::Duration::from_secs(2),
tcp_write.shutdown(),
).await;
total total
}); });
@@ -721,8 +807,8 @@ mod tests {
let _ = rustls::crypto::ring::default_provider().install_default(); let _ = rustls::crypto::ring::default_provider().install_default();
// Generate a single self-signed cert and use its key pair // Generate a single self-signed cert and use its key pair
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]) let self_signed =
.unwrap(); rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
let cert_der = self_signed.cert.der().clone(); let cert_der = self_signed.cert.der().clone();
let key_der = self_signed.key_pair.serialize_der(); let key_der = self_signed.key_pair.serialize_der();
@@ -737,6 +823,10 @@ mod tests {
// Port 0 = OS assigns a free port // Port 0 = OS assigns a free port
let result = create_quic_endpoint(0, Arc::new(tls_config)); let result = create_quic_endpoint(0, Arc::new(tls_config));
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err()); assert!(
result.is_ok(),
"QUIC endpoint creation failed: {:?}",
result.err()
);
} }
} }
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
} }
// Handshake length (3 bytes) - informational, we parse incrementally // Handshake length (3 bytes) - informational, we parse incrementally
let _handshake_len = ((data[6] as usize) << 16) let _handshake_len =
| ((data[7] as usize) << 8) ((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
| (data[8] as usize);
let hello = &data[9..]; let hello = &data[9..];
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
pub fn extract_http_host(data: &[u8]) -> Option<String> { pub fn extract_http_host(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?; let text = std::str::from_utf8(data).ok()?;
for line in text.split("\r\n") { for line in text.split("\r\n") {
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) { if let Some(value) = line
.strip_prefix("Host: ")
.or_else(|| line.strip_prefix("host: "))
{
// Strip port if present // Strip port if present
let host = value.split(':').next().unwrap_or(value).trim(); let host = value.split(':').next().unwrap_or(value).trim();
if !host.is_empty() { if !host.is_empty() {
@@ -196,7 +198,7 @@ pub fn is_http(data: &[u8]) -> bool {
b"PATC", b"PATC",
b"OPTI", b"OPTI",
b"CONN", b"CONN",
b"PRI ", // HTTP/2 connection preface b"PRI ", // HTTP/2 connection preface
]; ];
starts.iter().any(|s| data.starts_with(s)) starts.iter().any(|s| data.starts_with(s))
} }
@@ -213,7 +215,10 @@ mod tests {
#[test] #[test]
fn test_too_short() { fn test_too_short() {
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData)); assert!(matches!(
extract_sni(&[0x16, 0x03]),
SniResult::NeedMoreData
));
} }
#[test] #[test]
@@ -263,7 +268,8 @@ mod tests {
// Extension: type=0x0000 (SNI), length, data // Extension: type=0x0000 (SNI), length, data
let sni_extension = { let sni_extension = {
let mut e = Vec::new(); let mut e = Vec::new();
e.push(0x00); e.push(0x00); // SNI type e.push(0x00);
e.push(0x00); // SNI type
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8); e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
e.push((sni_ext_data.len() & 0xFF) as u8); e.push((sni_ext_data.len() & 0xFF) as u8);
e.extend_from_slice(&sni_ext_data); e.extend_from_slice(&sni_ext_data);
@@ -283,16 +289,20 @@ mod tests {
let hello_body = { let hello_body = {
let mut h = Vec::new(); let mut h = Vec::new();
// Client version TLS 1.2 // Client version TLS 1.2
h.push(0x03); h.push(0x03); h.push(0x03);
h.push(0x03);
// Random (32 bytes) // Random (32 bytes)
h.extend_from_slice(&[0u8; 32]); h.extend_from_slice(&[0u8; 32]);
// Session ID length = 0 // Session ID length = 0
h.push(0x00); h.push(0x00);
// Cipher suites: length=2, one suite // Cipher suites: length=2, one suite
h.push(0x00); h.push(0x02); h.push(0x00);
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA h.push(0x02);
// Compression methods: length=1, null h.push(0x00);
h.push(0x01); h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
// Compression methods: length=1, null
h.push(0x01);
h.push(0x00);
// Extensions // Extensions
h.extend_from_slice(&extensions); h.extend_from_slice(&extensions);
h h
@@ -302,7 +312,7 @@ mod tests {
let handshake = { let handshake = {
let mut hs = Vec::new(); let mut hs = Vec::new();
hs.push(0x01); // ClientHello hs.push(0x01); // ClientHello
// 3-byte length // 3-byte length
hs.push(((hello_body.len() >> 16) & 0xFF) as u8); hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
hs.push(((hello_body.len() >> 8) & 0xFF) as u8); hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
hs.push((hello_body.len() & 0xFF) as u8); hs.push((hello_body.len() & 0xFF) as u8);
@@ -313,7 +323,8 @@ mod tests {
// TLS record: type=0x16, version TLS 1.0, length // TLS record: type=0x16, version TLS 1.0, length
let mut record = Vec::new(); let mut record = Vec::new();
record.push(0x16); // Handshake record.push(0x16); // Handshake
record.push(0x03); record.push(0x01); // TLS 1.0 record.push(0x03);
record.push(0x01); // TLS 1.0
record.push(((handshake.len() >> 8) & 0xFF) as u8); record.push(((handshake.len() >> 8) & 0xFF) as u8);
record.push((handshake.len() & 0xFF) as u8); record.push((handshake.len() & 0xFF) as u8);
record.extend_from_slice(&handshake); record.extend_from_slice(&handshake);
File diff suppressed because it is too large Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey; use rustls::sign::CertifiedKey;
use rustls::ServerConfig; use rustls::ServerConfig;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream}; use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
use tracing::{debug, info}; use tracing::{debug, info};
use crate::tcp_listener::TlsCertConfig; use crate::tcp_listener::TlsCertConfig;
@@ -29,7 +29,9 @@ pub struct CertResolver {
impl CertResolver { impl CertResolver {
/// Build a resolver from PEM-encoded cert/key configs. /// Build a resolver from PEM-encoded cert/key configs.
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup. /// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> { pub fn new(
configs: &HashMap<String, TlsCertConfig>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider(); ensure_crypto_provider();
let provider = rustls::crypto::ring::default_provider(); let provider = rustls::crypto::ring::default_provider();
let mut certs = HashMap::new(); let mut certs = HashMap::new();
@@ -38,8 +40,10 @@ impl CertResolver {
for (domain, cfg) in configs { for (domain, cfg) in configs {
let cert_chain = load_certs(&cfg.cert_pem)?; let cert_chain = load_certs(&cfg.cert_pem)?;
let key = load_private_key(&cfg.key_pem)?; let key = load_private_key(&cfg.key_pem)?;
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider) let ck = Arc::new(
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?); CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
);
if domain == "*" { if domain == "*" {
fallback = Some(Arc::clone(&ck)); fallback = Some(Arc::clone(&ck));
} }
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets. /// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
/// The returned acceptor can be reused across all connections (cheap Arc clone). /// The returned acceptor can be reused across all connections (cheap Arc clone).
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> { pub fn build_shared_tls_acceptor(
resolver: CertResolver,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider(); ensure_crypto_provider();
let mut config = ServerConfig::builder() let mut config = ServerConfig::builder()
.with_no_client_auth() .with_no_client_auth()
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
// Shared session cache — enables session ID resumption across connections // Shared session cache — enables session ID resumption across connections
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096); config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted) // Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
config.ticketer = rustls::crypto::ring::Ticketer::new() config.ticketer =
.map_err(|e| format!("Ticketer: {}", e))?; rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"); info!(
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
);
Ok(TlsAcceptor::from(Arc::new(config))) Ok(TlsAcceptor::from(Arc::new(config)))
} }
/// Build a TLS acceptor from PEM-encoded cert and key data. /// Build a TLS acceptor from PEM-encoded cert and key data.
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections). /// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> { pub fn build_tls_acceptor(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
build_tls_acceptor_with_config(cert_pem, key_pem, None) build_tls_acceptor_with_config(cert_pem, key_pem, None)
} }
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1. /// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection. /// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> { pub fn build_tls_acceptor_h1_only(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider(); ensure_crypto_provider();
let certs = load_certs(cert_pem)?; let certs = load_certs(cert_pem)?;
let key = load_private_key(key_pem)?; let key = load_private_key(key_pem)?;
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
// Apply TLS version restrictions // Apply TLS version restrictions
let versions = resolve_tls_versions(route_tls.versions.as_deref()); let versions = resolve_tls_versions(route_tls.versions.as_deref());
let builder = ServerConfig::builder_with_protocol_versions(&versions); let builder = ServerConfig::builder_with_protocol_versions(&versions);
builder builder.with_no_client_auth().with_single_cert(certs, key)?
.with_no_client_auth()
.with_single_cert(certs, key)?
} else { } else {
ServerConfig::builder() ServerConfig::builder()
.with_no_client_auth() .with_no_client_auth()
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
} }
/// Resolve TLS version strings to rustls SupportedProtocolVersion. /// Resolve TLS version strings to rustls SupportedProtocolVersion.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> { fn resolve_tls_versions(
versions: Option<&[String]>,
) -> Vec<&'static rustls::SupportedProtocolVersion> {
let versions = match versions { let versions = match versions {
Some(v) if !v.is_empty() => v, Some(v) if !v.is_empty() => v,
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13], _ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
@@ -207,15 +221,17 @@ pub async fn accept_tls(
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new(); static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> { pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG.get_or_init(|| { SHARED_CLIENT_CONFIG
ensure_crypto_provider(); .get_or_init(|| {
let config = rustls::ClientConfig::builder() ensure_crypto_provider();
.dangerous() let config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(Arc::new(InsecureVerifier))
info!("Built shared backend TLS client config with session resumption"); .with_no_client_auth();
Arc::new(config) info!("Built shared backend TLS client config with session resumption");
}).clone() Arc::new(config)
})
.clone()
} }
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`. /// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new(); static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> { pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| { SHARED_CLIENT_CONFIG_ALPN
ensure_crypto_provider(); .get_or_init(|| {
let mut config = rustls::ClientConfig::builder() ensure_crypto_provider();
.dangerous() let mut config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(Arc::new(InsecureVerifier))
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; .with_no_client_auth();
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"); config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(config) info!(
}).clone() "Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
);
Arc::new(config)
})
.clone()
} }
/// Connect to a backend with TLS (for terminate-and-reencrypt mode). /// Connect to a backend with TLS (for terminate-and-reencrypt mode).
@@ -249,7 +269,8 @@ pub async fn connect_tls(
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
stream.set_nodelay(true)?; stream.set_nodelay(true)?;
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access) // Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) { if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
{
debug!("Failed to set keepalive on backend TLS socket: {}", e); debug!("Failed to set keepalive on backend TLS socket: {}", e);
} }
@@ -260,10 +281,12 @@ pub async fn connect_tls(
} }
/// Load certificates from PEM string. /// Load certificates from PEM string.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> { fn load_certs(
pem: &str,
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes()); let mut reader = BufReader::new(pem.as_bytes());
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader) let certs: Vec<CertificateDer<'static>> =
.collect::<Result<Vec<_>, _>>()?; rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() { if certs.is_empty() {
return Err("No certificates found in PEM data".into()); return Err("No certificates found in PEM data".into());
} }
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
} }
/// Load private key from PEM string. /// Load private key from PEM string.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> { fn load_private_key(
pem: &str,
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes()); let mut reader = BufReader::new(pem.as_bytes());
// Try PKCS8 first, then RSA, then EC // Try PKCS8 first, then RSA, then EC
let key = rustls_pemfile::private_key(&mut reader)? let key =
.ok_or("No private key found in PEM data")?; rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
Ok(key) Ok(key)
} }
@@ -17,14 +17,15 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use rustproxy_config::{RouteActionType, TransportProtocol}; use rustproxy_config::{RouteActionType, TransportProtocol};
use rustproxy_metrics::MetricsCollector; use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager}; use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService; use rustproxy_http::h3_service::H3ProxyService;
@@ -62,6 +63,8 @@ pub struct UdpListenerManager {
route_cancels: Arc<DashMap<String, CancellationToken>>, route_cancels: Arc<DashMap<String, CancellationToken>>,
/// Shared connection registry for selective recycling. /// Shared connection registry for selective recycling.
connection_registry: Arc<ConnectionRegistry>, connection_registry: Arc<ConnectionRegistry>,
/// Global ingress block policy, hot-reloadable without restarting listeners.
security_policy: Arc<ArcSwap<IpBlockList>>,
} }
impl Drop for UdpListenerManager { impl Drop for UdpListenerManager {
@@ -99,17 +102,26 @@ impl UdpListenerManager {
proxy_ips: Arc::new(Vec::new()), proxy_ips: Arc::new(Vec::new()),
route_cancels, route_cancels,
connection_registry, connection_registry,
security_policy: Arc::new(ArcSwap::from(Arc::new(IpBlockList::empty()))),
} }
} }
/// Set the trusted proxy IPs for PROXY protocol v2 detection. /// Set the trusted proxy IPs for PROXY protocol v2 detection.
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) { pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
if !ips.is_empty() { if !ips.is_empty() {
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len()); info!(
"UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs",
ips.len()
);
} }
self.proxy_ips = Arc::new(ips); self.proxy_ips = Arc::new(ips);
} }
/// Set the shared global ingress security policy.
pub fn set_security_policy(&mut self, policy: Arc<ArcSwap<IpBlockList>>) {
self.security_policy = policy;
}
/// Set the H3 proxy service for HTTP/3 request handling. /// Set the H3 proxy service for HTTP/3 request handling.
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) { pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
self.h3_service = Some(svc); self.h3_service = Some(svc);
@@ -142,7 +154,9 @@ impl UdpListenerManager {
// Check if any route on this port uses QUIC // Check if any route on this port uses QUIC
let rm = self.route_manager.load(); let rm = self.route_manager.load();
let has_quic = rm.routes_for_port(port).iter().any(|r| { let has_quic = rm.routes_for_port(port).iter().any(|r| {
r.action.udp.as_ref() r.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref()) .and_then(|u| u.quic.as_ref())
.is_some() .is_some()
}); });
@@ -164,8 +178,10 @@ impl UdpListenerManager {
None, None,
Arc::clone(&self.route_cancels), Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry), Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, Some(endpoint_for_updates))); self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint started on port {}", port); info!("QUIC endpoint started on port {}", port);
} else { } else {
// Proxy relay path: we own external socket, quinn on localhost // Proxy relay path: we own external socket, quinn on localhost
@@ -173,6 +189,7 @@ impl UdpListenerManager {
port, port,
tls, tls,
Arc::clone(&self.proxy_ips), Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
self.cancel_token.child_token(), self.cancel_token.child_token(),
)?; )?;
let endpoint_for_updates = relay.endpoint.clone(); let endpoint_for_updates = relay.endpoint.clone();
@@ -187,13 +204,18 @@ impl UdpListenerManager {
Some(relay.real_client_map), Some(relay.real_client_map),
Arc::clone(&self.route_cancels), Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry), Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, Some(endpoint_for_updates))); self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint with PROXY relay started on port {}", port); info!("QUIC endpoint with PROXY relay started on port {}", port);
} }
return Ok(()); return Ok(());
} else { } else {
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port); warn!(
"QUIC routes on port {} but no TLS config provided, falling back to raw UDP",
port
);
} }
} }
@@ -214,6 +236,7 @@ impl UdpListenerManager {
Arc::clone(&self.relay_writer), Arc::clone(&self.relay_writer),
self.cancel_token.child_token(), self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips), Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, None)); self.listeners.insert(port, (handle, None));
@@ -254,8 +277,10 @@ impl UdpListenerManager {
} }
debug!("UDP listener stopped on port {}", port); debug!("UDP listener stopped on port {}", port);
} }
info!("All UDP listeners stopped, {} sessions remaining", info!(
self.session_table.session_count()); "All UDP listeners stopped, {} sessions remaining",
self.session_table.session_count()
);
} }
/// Update TLS config on all active QUIC endpoints (cert refresh). /// Update TLS config on all active QUIC endpoints (cert refresh).
@@ -288,11 +313,15 @@ impl UdpListenerManager {
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) { pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes // Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
let rm = self.route_manager.load(); let rm = self.route_manager.load();
let upgrade_ports: Vec<u16> = self.listeners.iter() let upgrade_ports: Vec<u16> = self
.listeners
.iter()
.filter(|(_, (_, endpoint))| endpoint.is_none()) .filter(|(_, (_, endpoint))| endpoint.is_none())
.filter(|(port, _)| { .filter(|(port, _)| {
rm.routes_for_port(**port).iter().any(|r| { rm.routes_for_port(**port).iter().any(|r| {
r.action.udp.as_ref() r.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref()) .and_then(|u| u.quic.as_ref())
.is_some() .is_some()
}) })
@@ -301,17 +330,23 @@ impl UdpListenerManager {
.collect(); .collect();
for port in upgrade_ports { for port in upgrade_ports {
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port); info!(
"Upgrading raw UDP listener on port {} to QUIC endpoint",
port
);
// Stop the raw UDP listener task and drain sessions to release the socket // Stop the raw UDP listener task and drain sessions to release the socket
if let Some((handle, _)) = self.listeners.remove(&port) { if let Some((handle, _)) = self.listeners.remove(&port) {
handle.abort(); handle.abort();
} }
let drained = self.session_table.drain_port( let drained = self
port, &self.metrics, &self.conn_tracker, .session_table
); .drain_port(port, &self.metrics, &self.conn_tracker);
if drained > 0 { if drained > 0 {
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port); debug!(
"Drained {} UDP sessions on port {} for QUIC upgrade",
drained, port
);
} }
// Brief yield to let aborted tasks drop their socket references // Brief yield to let aborted tasks drop their socket references
@@ -326,11 +361,17 @@ impl UdpListenerManager {
match create_result { match create_result {
Ok(()) => { Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port); info!(
"QUIC endpoint started on port {} (upgraded from raw UDP)",
port
);
} }
Err(e) => { Err(e) => {
// Port may still be held — retry once after a brief delay // Port may still be held — retry once after a brief delay
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e); warn!(
"QUIC endpoint creation failed on port {}, retrying: {}",
port, e
);
tokio::time::sleep(std::time::Duration::from_millis(50)).await; tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let retry_result = if self.proxy_ips.is_empty() { let retry_result = if self.proxy_ips.is_empty() {
@@ -341,11 +382,17 @@ impl UdpListenerManager {
match retry_result { match retry_result {
Ok(()) => { Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port); info!(
"QUIC endpoint started on port {} (upgraded from raw UDP, retry)",
port
);
} }
Err(e2) => { Err(e2) => {
error!("Failed to upgrade port {} to QUIC after retry: {}. \ error!(
Rebinding as raw UDP.", port, e2); "Failed to upgrade port {} to QUIC after retry: {}. \
Rebinding as raw UDP.",
port, e2
);
// Fallback: rebind as raw UDP so the port isn't dead // Fallback: rebind as raw UDP so the port isn't dead
if let Ok(()) = self.rebind_raw_udp(port).await { if let Ok(()) = self.rebind_raw_udp(port).await {
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port); warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
@@ -358,7 +405,11 @@ impl UdpListenerManager {
} }
/// Create a direct QUIC endpoint (quinn owns the socket). /// Create a direct QUIC endpoint (quinn owns the socket).
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> { fn create_quic_direct(
&mut self,
port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<()> {
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?; let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
let endpoint_for_updates = endpoint.clone(); let endpoint_for_updates = endpoint.clone();
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop( let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
@@ -372,17 +423,24 @@ impl UdpListenerManager {
None, None,
Arc::clone(&self.route_cancels), Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry), Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, Some(endpoint_for_updates))); self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
Ok(()) Ok(())
} }
/// Create a QUIC endpoint with PROXY protocol relay. /// Create a QUIC endpoint with PROXY protocol relay.
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> { fn create_quic_with_relay(
&mut self,
port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<()> {
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay( let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
port, port,
tls_config, tls_config,
Arc::clone(&self.proxy_ips), Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
self.cancel_token.child_token(), self.cancel_token.child_token(),
)?; )?;
let endpoint_for_updates = relay.endpoint.clone(); let endpoint_for_updates = relay.endpoint.clone();
@@ -397,8 +455,10 @@ impl UdpListenerManager {
Some(relay.real_client_map), Some(relay.real_client_map),
Arc::clone(&self.route_cancels), Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry), Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, Some(endpoint_for_updates))); self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
Ok(()) Ok(())
} }
@@ -419,6 +479,7 @@ impl UdpListenerManager {
Arc::clone(&self.relay_writer), Arc::clone(&self.relay_writer),
self.cancel_token.child_token(), self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips), Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, None)); self.listeners.insert(port, (handle, None));
@@ -458,7 +519,10 @@ impl UdpListenerManager {
info!("Datagram handler relay connected to {}", path); info!("Datagram handler relay connected to {}", path);
} }
Err(e) => { Err(e) => {
error!("Failed to connect datagram handler relay to {}: {}", path, e); error!(
"Failed to connect datagram handler relay to {}: {}",
path, e
);
} }
} }
} }
@@ -514,6 +578,7 @@ impl UdpListenerManager {
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>, relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
cancel: CancellationToken, cancel: CancellationToken,
proxy_ips: Arc<Vec<IpAddr>>, proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) { ) {
// Use a reasonably large buffer; actual max is per-route but we need a single buffer // Use a reasonably large buffer; actual max is per-route but we need a single buffer
let mut buf = vec![0u8; 65535]; let mut buf = vec![0u8; 65535];
@@ -528,9 +593,11 @@ impl UdpListenerManager {
loop { loop {
// Periodic cleanup: remove proxy_addr_map entries with no active session // Periodic cleanup: remove proxy_addr_map entries with no active session
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval { if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval
{
last_proxy_cleanup = tokio::time::Instant::now(); last_proxy_cleanup = tokio::time::Instant::now();
let stale: Vec<SocketAddr> = proxy_addr_map.iter() let stale: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| { .filter(|entry| {
let key: SessionKey = (*entry.key(), port); let key: SessionKey = (*entry.key(), port);
session_table.get(&key).is_none() session_table.get(&key).is_none()
@@ -538,7 +605,11 @@ impl UdpListenerManager {
.map(|entry| *entry.key()) .map(|entry| *entry.key())
.collect(); .collect();
if !stale.is_empty() { if !stale.is_empty() {
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port); debug!(
"UDP proxy_addr_map cleanup: removing {} stale entries on port {}",
stale.len(),
port
);
for addr in stale { for addr in stale {
proxy_addr_map.remove(&addr); proxy_addr_map.remove(&addr);
} }
@@ -564,34 +635,50 @@ impl UdpListenerManager {
let datagram = &buf[..len]; let datagram = &buf[..len];
// PROXY protocol v2 detection for datagrams from trusted proxy IPs // PROXY protocol v2 detection for datagrams from trusted proxy IPs
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) { let effective_client_ip =
let session_key: SessionKey = (client_addr, port); if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) { let session_key: SessionKey = (client_addr, port);
// No session and no prior PROXY header — check for PROXY v2 if session_table.get(&session_key).is_none()
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) { && !proxy_addr_map.contains_key(&client_addr)
match crate::proxy_protocol::parse_v2(datagram) { {
Ok((header, _consumed)) => { // No session and no prior PROXY header — check for PROXY v2
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr); if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
proxy_addr_map.insert(client_addr, header.source_addr); match crate::proxy_protocol::parse_v2(datagram) {
continue; // discard the PROXY v2 datagram Ok((header, _consumed)) => {
} debug!(
Err(e) => { "UDP PROXY v2 from {}: real client {}",
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e); client_addr, header.source_addr
client_addr.ip() );
proxy_addr_map.insert(client_addr, header.source_addr);
continue; // discard the PROXY v2 datagram
}
Err(e) => {
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
client_addr.ip()
}
} }
} else {
client_addr.ip()
} }
} else { } else {
client_addr.ip() // Use real client IP if we've previously seen a PROXY v2 header
proxy_addr_map
.get(&client_addr)
.map(|r| r.ip())
.unwrap_or_else(|| client_addr.ip())
} }
} else { } else {
// Use real client IP if we've previously seen a PROXY v2 header client_addr.ip()
proxy_addr_map.get(&client_addr) };
.map(|r| r.ip())
.unwrap_or_else(|| client_addr.ip()) let block_list = security_policy.load();
} if !block_list.is_empty() && block_list.is_blocked(&effective_client_ip) {
} else { debug!(
client_addr.ip() "UDP datagram from {} blocked by global security policy",
}; effective_client_ip
);
continue;
}
// Route matching — use effective (real) client IP // Route matching — use effective (real) client IP
let rm = route_manager.load(); let rm = route_manager.load();
@@ -611,7 +698,10 @@ impl UdpListenerManager {
let route_match = match rm.find_route(&ctx) { let route_match = match rm.find_route(&ctx) {
Some(m) => m, Some(m) => m,
None => { None => {
debug!("No UDP route matched for port {} from {}", port, client_addr); debug!(
"No UDP route matched for port {} from {}",
port, client_addr
);
continue; continue;
} }
}; };
@@ -627,7 +717,9 @@ impl UdpListenerManager {
&client_addr, &client_addr,
port, port,
datagram, datagram,
).await { )
.await
{
debug!("Failed to relay UDP datagram to TS: {}", e); debug!("Failed to relay UDP datagram to TS: {}", e);
} }
continue; continue;
@@ -638,8 +730,10 @@ impl UdpListenerManager {
// Check datagram size // Check datagram size
if len as u32 > udp_config.max_datagram_size { if len as u32 > udp_config.max_datagram_size {
debug!("UDP datagram too large ({} > {}) from {}, dropping", debug!(
len, udp_config.max_datagram_size, client_addr); "UDP datagram too large ({} > {}) from {}, dropping",
len, udp_config.max_datagram_size, client_addr
);
continue; continue;
} }
@@ -651,21 +745,27 @@ impl UdpListenerManager {
None => { None => {
// New session — check per-IP limits using the real client IP // New session — check per-IP limits using the real client IP
if !conn_tracker.try_accept(&effective_client_ip) { if !conn_tracker.try_accept(&effective_client_ip) {
debug!("UDP session rejected for {} (rate limit)", effective_client_ip); debug!(
"UDP session rejected for {} (rate limit)",
effective_client_ip
);
continue; continue;
} }
if !session_table.can_create_session( if !session_table
&effective_client_ip, .can_create_session(&effective_client_ip, udp_config.max_sessions_per_ip)
udp_config.max_sessions_per_ip, {
) { debug!(
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip); "UDP session rejected for {} (per-IP session limit)",
effective_client_ip
);
continue; continue;
} }
// Resolve target // Resolve target
let target = match route_match.target.or_else(|| { let target = match route_match
route.action.targets.as_ref().and_then(|t| t.first()) .target
}) { .or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
{
Some(t) => t, Some(t) => t,
None => { None => {
warn!("No target for UDP route {:?}", route_id); warn!("No target for UDP route {:?}", route_id);
@@ -686,13 +786,18 @@ impl UdpListenerManager {
} }
}; };
if let Err(e) = backend_socket.connect(&backend_addr).await { if let Err(e) = backend_socket.connect(&backend_addr).await {
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e); error!(
"Failed to connect backend UDP socket to {}: {}",
backend_addr, e
);
continue; continue;
} }
let backend_socket = Arc::new(backend_socket); let backend_socket = Arc::new(backend_socket);
debug!("New UDP session: {} -> {} (via port {}, real client {})", debug!(
client_addr, backend_addr, port, effective_client_ip); "New UDP session: {} -> {} (via port {}, real client {})",
client_addr, backend_addr, port, effective_client_ip
);
// Spawn return-path relay task // Spawn return-path relay task
let session_cancel = CancellationToken::new(); let session_cancel = CancellationToken::new();
@@ -709,7 +814,9 @@ impl UdpListenerManager {
let session = Arc::new(UdpSession { let session = Arc::new(UdpSession {
backend_socket, backend_socket,
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()), last_activity: std::sync::atomic::AtomicU64::new(
session_table.elapsed_ms(),
),
created_at: std::time::Instant::now(), created_at: std::time::Instant::now(),
route_id: route_id.map(|s| s.to_string()), route_id: route_id.map(|s| s.to_string()),
source_ip: effective_client_ip, source_ip: effective_client_ip,
@@ -718,7 +825,11 @@ impl UdpListenerManager {
cancel: session_cancel, cancel: session_cancel,
}); });
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) { if !session_table.insert(
session_key,
Arc::clone(&session),
udp_config.max_sessions_per_ip,
) {
warn!("Failed to insert UDP session (race condition)"); warn!("Failed to insert UDP session (race condition)");
continue; continue;
} }
@@ -735,7 +846,9 @@ impl UdpListenerManager {
// Forward datagram to backend // Forward datagram to backend
match session.backend_socket.send(datagram).await { match session.backend_socket.send(datagram).await {
Ok(_) => { Ok(_) => {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed); session
.last_activity
.store(session_table.elapsed_ms(), Ordering::Relaxed);
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str)); metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
metrics.record_datagram_in(); metrics.record_datagram_in();
} }
@@ -779,7 +892,9 @@ impl UdpListenerManager {
Ok(_) => { Ok(_) => {
// Update session activity // Update session activity
if let Some(session) = session_table.get(&session_key) { if let Some(session) = session_table.get(&session_key) {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed); session
.last_activity
.store(session_table.elapsed_ms(), Ordering::Relaxed);
} }
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str)); metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
metrics.record_datagram_out(); metrics.record_datagram_out();
@@ -814,7 +929,8 @@ impl UdpListenerManager {
let json = serde_json::to_vec(&msg)?; let json = serde_json::to_vec(&msg)?;
let mut guard = writer.lock().await; let mut guard = writer.lock().await;
let stream = guard.as_mut() let stream = guard
.as_mut()
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?; .ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
// Length-prefixed frame // Length-prefixed frame
@@ -879,9 +995,15 @@ impl UdpListenerManager {
} }
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or(""); let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
let source_port = reply.get("sourcePort").and_then(|v| v.as_u64()).unwrap_or(0) as u16; let source_port = reply
.get("sourcePort")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u16;
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16; let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let payload_b64 = reply.get("payloadBase64").and_then(|v| v.as_str()).unwrap_or(""); let payload_b64 = reply
.get("payloadBase64")
.and_then(|v| v.as_str())
.unwrap_or("");
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) { let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
Ok(p) => p, Ok(p) => p,
@@ -111,12 +111,15 @@ impl UdpSessionTable {
/// Look up an existing session. /// Look up an existing session.
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> { pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
self.sessions.get(key).map(|entry| Arc::clone(entry.value())) self.sessions
.get(key)
.map(|entry| Arc::clone(entry.value()))
} }
/// Check if we can create a new session for this IP (under the per-IP limit). /// Check if we can create a new session for this IP (under the per-IP limit).
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool { pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
let count = self.ip_session_counts let count = self
.ip_session_counts
.get(ip) .get(ip)
.map(|c| *c.value()) .map(|c| *c.value())
.unwrap_or(0); .unwrap_or(0);
@@ -124,12 +127,7 @@ impl UdpSessionTable {
} }
/// Insert a new session. Returns false if per-IP limit exceeded. /// Insert a new session. Returns false if per-IP limit exceeded.
pub fn insert( pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
&self,
key: SessionKey,
session: Arc<UdpSession>,
max_per_ip: u32,
) -> bool {
let ip = session.source_ip; let ip = session.source_ip;
// Atomically check and increment per-IP count // Atomically check and increment per-IP count
@@ -173,7 +171,9 @@ impl UdpSessionTable {
let mut removed = 0; let mut removed = 0;
// Collect keys to remove (avoid holding DashMap refs during removal) // Collect keys to remove (avoid holding DashMap refs during removal)
let stale_keys: Vec<SessionKey> = self.sessions.iter() let stale_keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| { .filter(|entry| {
let last = entry.value().last_activity.load(Ordering::Relaxed); let last = entry.value().last_activity.load(Ordering::Relaxed);
now_ms.saturating_sub(last) >= timeout_ms now_ms.saturating_sub(last) >= timeout_ms
@@ -185,7 +185,8 @@ impl UdpSessionTable {
if let Some(session) = self.remove(&key) { if let Some(session) = self.remove(&key) {
debug!( debug!(
"UDP session expired: {} -> port {} (idle {}ms)", "UDP session expired: {} -> port {} (idle {}ms)",
session.client_addr, key.1, session.client_addr,
key.1,
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed)) now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
); );
conn_tracker.connection_closed(&session.source_ip); conn_tracker.connection_closed(&session.source_ip);
@@ -210,7 +211,9 @@ impl UdpSessionTable {
metrics: &MetricsCollector, metrics: &MetricsCollector,
conn_tracker: &ConnectionTracker, conn_tracker: &ConnectionTracker,
) -> usize { ) -> usize {
let keys: Vec<SessionKey> = self.sessions.iter() let keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| entry.key().1 == port) .filter(|entry| entry.key().1 == port)
.map(|entry| *entry.key()) .map(|entry| *entry.key())
.collect(); .collect();
@@ -257,9 +260,8 @@ mod tests {
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
let backend_socket = rt.block_on(async { let backend_socket =
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
});
let child_cancel = cancel.child_token(); let child_cancel = cancel.child_token();
let return_task = rt.spawn(async move { let return_task = rt.spawn(async move {
+1 -1
View File
@@ -3,7 +3,7 @@
//! Route matching engine for RustProxy. //! Route matching engine for RustProxy.
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager. //! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
pub mod route_manager;
pub mod matchers; pub mod matchers;
pub mod route_manager;
pub use route_manager::*; pub use route_manager::*;
@@ -20,7 +20,7 @@ pub fn domain_matches(pattern: &str, domain: &str) -> bool {
// Wildcard patterns // Wildcard patterns
if pattern.starts_with("*.") || pattern.starts_with("*.") { if pattern.starts_with("*.") || pattern.starts_with("*.") {
let suffix = &pattern[2..]; // e.g., "example.com" let suffix = &pattern[2..]; // e.g., "example.com"
// Match exact parent or any single-level subdomain // Match exact parent or any single-level subdomain
if domain.eq_ignore_ascii_case(suffix) { if domain.eq_ignore_ascii_case(suffix) {
return true; return true;
} }
@@ -1,6 +1,6 @@
use ipnet::IpNet;
use std::net::IpAddr; use std::net::IpAddr;
use std::str::FromStr; use std::str::FromStr;
use ipnet::IpNet;
/// Match an IP address against a pattern. /// Match an IP address against a pattern.
/// ///
@@ -85,7 +85,10 @@ fn wildcard_to_cidr(pattern: &str) -> Option<String> {
} }
} }
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len)) Some(format!(
"{}.{}.{}.{}/{}",
octets[0], octets[1], octets[2], octets[3], prefix_len
))
} }
#[cfg(test)] #[cfg(test)]
@@ -1,9 +1,9 @@
pub mod domain; pub mod domain;
pub mod path;
pub mod ip;
pub mod header; pub mod header;
pub mod ip;
pub mod path;
pub use domain::*; pub use domain::*;
pub use path::*;
pub use ip::*;
pub use header::*; pub use header::*;
pub use ip::*;
pub use path::*;
@@ -1,7 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use rustproxy_config::{RouteConfig, RouteTarget, TransportProtocol, TlsMode};
use crate::matchers; use crate::matchers;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode, TransportProtocol};
/// Context for route matching (subset of connection info). /// Context for route matching (subset of connection info).
pub struct MatchContext<'a> { pub struct MatchContext<'a> {
@@ -42,19 +42,14 @@ impl RouteManager {
}; };
// Filter enabled routes and sort by priority // Filter enabled routes and sort by priority
let mut enabled_routes: Vec<RouteConfig> = routes let mut enabled_routes: Vec<RouteConfig> =
.into_iter() routes.into_iter().filter(|r| r.is_enabled()).collect();
.filter(|r| r.is_enabled())
.collect();
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority())); enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
// Build port index // Build port index
for (idx, route) in enabled_routes.iter().enumerate() { for (idx, route) in enabled_routes.iter().enumerate() {
for port in route.listening_ports() { for port in route.listening_ports() {
manager.port_index manager.port_index.entry(port).or_default().push(idx);
.entry(port)
.or_default()
.push(idx);
} }
} }
@@ -66,7 +61,9 @@ impl RouteManager {
/// Used to skip expensive header HashMap construction when no route needs it. /// Used to skip expensive header HashMap construction when no route needs it.
pub fn any_route_has_headers(&self, port: u16) -> bool { pub fn any_route_has_headers(&self, port: u16) -> bool {
if let Some(indices) = self.port_index.get(&port) { if let Some(indices) = self.port_index.get(&port) {
indices.iter().any(|&idx| self.routes[idx].route_match.headers.is_some()) indices
.iter()
.any(|&idx| self.routes[idx].route_match.headers.is_some())
} else { } else {
false false
} }
@@ -99,8 +96,8 @@ impl RouteManager {
let ctx_transport = ctx.transport.as_ref(); let ctx_transport = ctx.transport.as_ref();
match (route_transport, ctx_transport) { match (route_transport, ctx_transport) {
// Route requires UDP only — reject non-UDP contexts // Route requires UDP only — reject non-UDP contexts
(Some(TransportProtocol::Udp), None) | (Some(TransportProtocol::Udp), None)
(Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false, | (Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
// Route requires TCP only — reject UDP contexts // Route requires TCP only — reject UDP contexts
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false, (Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
// Route has no transport (default = TCP) — reject UDP contexts // Route has no transport (default = TCP) — reject UDP contexts
@@ -196,7 +193,11 @@ impl RouteManager {
} }
/// Find the best matching target within a route. /// Find the best matching target within a route.
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> { fn find_target<'a>(
&self,
route: &'a RouteConfig,
ctx: &MatchContext<'_>,
) -> Option<&'a RouteTarget> {
let targets = route.action.targets.as_ref()?; let targets = route.action.targets.as_ref()?;
if targets.len() == 1 && targets[0].target_match.is_none() { if targets.len() == 1 && targets[0].target_match.is_none() {
@@ -223,17 +224,11 @@ impl RouteManager {
} }
// Fall back to first target without match criteria // Fall back to first target without match criteria
best.or_else(|| { best.or_else(|| targets.iter().find(|t| t.target_match.is_none()))
targets.iter().find(|t| t.target_match.is_none())
})
} }
/// Check if a target match criteria matches the context. /// Check if a target match criteria matches the context.
fn matches_target( fn matches_target(&self, tm: &rustproxy_config::TargetMatch, ctx: &MatchContext<'_>) -> bool {
&self,
tm: &rustproxy_config::TargetMatch,
ctx: &MatchContext<'_>,
) -> bool {
// Port matching // Port matching
if let Some(ref ports) = tm.ports { if let Some(ref ports) = tm.ports {
if !ports.contains(&ctx.port) { if !ports.contains(&ctx.port) {
@@ -298,9 +293,7 @@ impl RouteManager {
// If multiple passthrough routes on same port, SNI is needed // If multiple passthrough routes on same port, SNI is needed
let passthrough_routes: Vec<_> = routes let passthrough_routes: Vec<_> = routes
.iter() .iter()
.filter(|r| { .filter(|r| r.tls_mode() == Some(&TlsMode::Passthrough))
r.tls_mode() == Some(&TlsMode::Passthrough)
})
.collect(); .collect();
if passthrough_routes.len() > 1 { if passthrough_routes.len() > 1 {
@@ -419,7 +412,11 @@ mod tests {
let result = manager.find_route(&ctx).unwrap(); let result = manager.find_route(&ctx).unwrap();
// Should match the higher-priority specific route // Should match the higher-priority specific route
assert!(result.route.route_match.domains.as_ref() assert!(result
.route
.route_match
.domains
.as_ref()
.map(|d| d.to_vec()) .map(|d| d.to_vec())
.unwrap() .unwrap()
.contains(&"api.example.com")); .contains(&"api.example.com"));
@@ -619,8 +616,14 @@ mod tests {
let result = manager.find_route(&ctx); let result = manager.find_route(&ctx);
assert!(result.is_some()); assert!(result.is_some());
let matched_domains = result.unwrap().route.route_match.domains.as_ref() let matched_domains = result
.map(|d| d.to_vec()).unwrap(); .unwrap()
.route
.route_match
.domains
.as_ref()
.map(|d| d.to_vec())
.unwrap();
assert!(matched_domains.contains(&"*")); assert!(matched_domains.contains(&"*"));
} }
@@ -735,7 +738,11 @@ mod tests {
assert_eq!(result.target.unwrap().host.first(), "default-backend"); assert_eq!(result.target.unwrap().host.first(), "default-backend");
} }
fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig { fn make_route_with_protocol(
port: u16,
domain: Option<&str>,
protocol: Option<&str>,
) -> RouteConfig {
let mut route = make_route(port, domain, 0); let mut route = make_route(port, domain, 0);
route.route_match.protocol = protocol.map(|s| s.to_string()); route.route_match.protocol = protocol.map(|s| s.to_string());
route route
@@ -1029,8 +1036,10 @@ mod tests {
transport: Some(TransportProtocol::Udp), transport: Some(TransportProtocol::Udp),
}; };
assert!(manager.find_route(&ctx).is_some(), assert!(
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"); manager.find_route(&ctx).is_some(),
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"
);
} }
#[test] #[test]
@@ -1051,7 +1060,9 @@ mod tests {
transport: None, // TCP (default) transport: None, // TCP (default)
}; };
assert!(manager.find_route(&ctx).is_none(), assert!(
"TCP TLS without SNI should NOT match domain-restricted routes"); manager.find_route(&ctx).is_none(),
"TCP TLS without SNI should NOT match domain-restricted routes"
);
} }
} }
@@ -1,5 +1,5 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64; use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
/// Basic auth validator. /// Basic auth validator.
pub struct BasicAuthValidator { pub struct BasicAuthValidator {
+45 -26
View File
@@ -21,7 +21,7 @@ struct DomainScopedEntry {
} }
/// Represents an IP pattern for matching. /// Represents an IP pattern for matching.
#[derive(Debug)] #[derive(Debug, Clone)]
enum IpPattern { enum IpPattern {
/// Exact IP match /// Exact IP match
Exact(IpAddr), Exact(IpAddr),
@@ -31,6 +31,37 @@ enum IpPattern {
Wildcard, Wildcard,
} }
/// Compiled block list for early ingress filtering.
#[derive(Debug, Clone)]
pub struct IpBlockList {
block_list: Vec<IpPattern>,
}
impl IpBlockList {
pub fn new(block_list: &[String]) -> Self {
Self {
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
}
}
pub fn empty() -> Self {
Self {
block_list: Vec::new(),
}
}
pub fn is_empty(&self) -> bool {
self.block_list.is_empty()
}
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
let normalized = IpFilter::normalize_ip(ip);
self.block_list
.iter()
.any(|pattern| pattern.matches(&normalized))
}
}
impl IpPattern { impl IpPattern {
fn parse(s: &str) -> Self { fn parse(s: &str) -> Self {
let s = s.trim(); let s = s.trim();
@@ -68,8 +99,7 @@ fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
} }
if p.starts_with("*.") { if p.starts_with("*.") {
let suffix = &p[1..]; // e.g., ".abc.xyz" let suffix = &p[1..]; // e.g., ".abc.xyz"
d.len() > suffix.len() d.len() > suffix.len() && d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
&& d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
} else { } else {
false false
} }
@@ -127,7 +157,11 @@ impl IpFilter {
if let Some(req_domain) = domain { if let Some(req_domain) = domain {
for entry in &self.domain_scoped { for entry in &self.domain_scoped {
if entry.pattern.matches(ip) { if entry.pattern.matches(ip) {
if entry.domains.iter().any(|d| domain_matches_pattern(d, req_domain)) { if entry
.domains
.iter()
.any(|d| domain_matches_pattern(d, req_domain))
{
return true; return true;
} }
} }
@@ -212,10 +246,7 @@ mod tests {
#[test] #[test]
fn test_block_trumps_allow() { fn test_block_trumps_allow() {
let filter = IpFilter::new( let filter = IpFilter::new(&[plain("10.0.0.0/8")], &["10.0.0.5".to_string()]);
&[plain("10.0.0.0/8")],
&["10.0.0.5".to_string()],
);
let blocked: IpAddr = "10.0.0.5".parse().unwrap(); let blocked: IpAddr = "10.0.0.5".parse().unwrap();
let allowed: IpAddr = "10.0.0.6".parse().unwrap(); let allowed: IpAddr = "10.0.0.6".parse().unwrap();
assert!(!filter.is_allowed(&blocked)); assert!(!filter.is_allowed(&blocked));
@@ -255,30 +286,21 @@ mod tests {
#[test] #[test]
fn test_domain_scoped_allows_matching_domain() { fn test_domain_scoped_allows_matching_domain() {
let filter = IpFilter::new( let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
&[],
);
let ip: IpAddr = "10.8.0.2".parse().unwrap(); let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz"))); assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
} }
#[test] #[test]
fn test_domain_scoped_denies_non_matching_domain() { fn test_domain_scoped_denies_non_matching_domain() {
let filter = IpFilter::new( let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
&[],
);
let ip: IpAddr = "10.8.0.2".parse().unwrap(); let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz"))); assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
} }
#[test] #[test]
fn test_domain_scoped_denies_without_domain() { fn test_domain_scoped_denies_without_domain() {
let filter = IpFilter::new( let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
&[],
);
let ip: IpAddr = "10.8.0.2".parse().unwrap(); let ip: IpAddr = "10.8.0.2".parse().unwrap();
// Without domain context, domain-scoped entries cannot match // Without domain context, domain-scoped entries cannot match
assert!(!filter.is_allowed_for_domain(&ip, None)); assert!(!filter.is_allowed_for_domain(&ip, None));
@@ -286,10 +308,7 @@ mod tests {
#[test] #[test]
fn test_domain_scoped_wildcard_domain() { fn test_domain_scoped_wildcard_domain() {
let filter = IpFilter::new( let filter = IpFilter::new(&[scoped("10.8.0.2", &["*.abc.xyz"])], &[]);
&[scoped("10.8.0.2", &["*.abc.xyz"])],
&[],
);
let ip: IpAddr = "10.8.0.2".parse().unwrap(); let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz"))); assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz"))); assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
@@ -300,8 +319,8 @@ mod tests {
fn test_plain_and_domain_scoped_coexist() { fn test_plain_and_domain_scoped_coexist() {
let filter = IpFilter::new( let filter = IpFilter::new(
&[ &[
plain("1.2.3.4"), // full route access plain("1.2.3.4"), // full route access
scoped("10.8.0.2", &["outline.abc.xyz"]), // scoped access scoped("10.8.0.2", &["outline.abc.xyz"]), // scoped access
], ],
&[], &[],
); );
@@ -1,4 +1,4 @@
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm}; use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// JWT claims (minimal structure). /// JWT claims (minimal structure).
@@ -160,10 +160,7 @@ mod tests {
#[test] #[test]
fn test_extract_token_bearer() { fn test_extract_token_bearer() {
assert_eq!( assert_eq!(JwtValidator::extract_token("Bearer abc123"), Some("abc123"));
JwtValidator::extract_token("Bearer abc123"),
Some("abc123")
);
} }
#[test] #[test]
+4 -4
View File
@@ -2,12 +2,12 @@
//! //!
//! IP filtering, rate limiting, and authentication for RustProxy. //! IP filtering, rate limiting, and authentication for RustProxy.
pub mod ip_filter;
pub mod rate_limiter;
pub mod basic_auth; pub mod basic_auth;
pub mod ip_filter;
pub mod jwt_auth; pub mod jwt_auth;
pub mod rate_limiter;
pub use ip_filter::*;
pub use rate_limiter::*;
pub use basic_auth::*; pub use basic_auth::*;
pub use ip_filter::*;
pub use jwt_auth::*; pub use jwt_auth::*;
pub use rate_limiter::*;
@@ -79,7 +79,7 @@ mod tests {
assert!(limiter.check("client-a")); assert!(limiter.check("client-a"));
assert!(limiter.check("client-a")); assert!(limiter.check("client-a"));
assert!(!limiter.check("client-a")); // blocked assert!(!limiter.check("client-a")); // blocked
// Different key should still be allowed // Different key should still be allowed
assert!(limiter.check("client-b")); assert!(limiter.check("client-b"));
assert!(limiter.check("client-b")); assert!(limiter.check("client-b"));
} }
+14 -13
View File
@@ -4,8 +4,7 @@
//! Account credentials are ephemeral — the consumer owns all persistence. //! Account credentials are ephemeral — the consumer owns all persistence.
use instant_acme::{ use instant_acme::{
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus, Account, AccountCredentials, ChallengeType, Identifier, NewAccount, NewOrder, OrderStatus,
AccountCredentials,
}; };
use rcgen::{CertificateParams, KeyPair}; use rcgen::{CertificateParams, KeyPair};
use thiserror::Error; use thiserror::Error;
@@ -89,7 +88,11 @@ impl AcmeClient {
F: FnOnce(PendingChallenge) -> Fut, F: FnOnce(PendingChallenge) -> Fut,
Fut: std::future::Future<Output = Result<(), AcmeError>>, Fut: std::future::Future<Output = Result<(), AcmeError>>,
{ {
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url()); info!(
"Starting ACME provisioning for {} via {}",
domain,
self.directory_url()
);
// 1. Get or create ACME account // 1. Get or create ACME account
let account = self.get_or_create_account().await?; let account = self.get_or_create_account().await?;
@@ -170,14 +173,14 @@ impl AcmeClient {
debug!("Order ready, finalizing..."); debug!("Order ready, finalizing...");
// 6. Generate CSR and finalize // 6. Generate CSR and finalize
let key_pair = KeyPair::generate().map_err(|e| { let key_pair = KeyPair::generate()
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)) .map_err(|e| AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)))?;
})?;
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| { let mut params = CertificateParams::new(vec![domain.to_string()])
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)) .map_err(|e| AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)))?;
})?; params
params.distinguished_name.push(rcgen::DnType::CommonName, domain); .distinguished_name
.push(rcgen::DnType::CommonName, domain);
let csr = params.serialize_request(&key_pair).map_err(|e| { let csr = params.serialize_request(&key_pair).map_err(|e| {
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e)) AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
@@ -219,9 +222,7 @@ impl AcmeClient {
.certificate() .certificate()
.await .await
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))? .map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
.ok_or_else(|| { .ok_or_else(|| AcmeError::FinalizationFailed("No certificate returned".to_string()))?;
AcmeError::FinalizationFailed("No certificate returned".to_string())
})?;
let private_key_pem = key_pair.serialize_pem(); let private_key_pem = key_pair.serialize_pem();
+20 -22
View File
@@ -2,8 +2,8 @@ use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error; use thiserror::Error;
use tracing::info; use tracing::info;
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
use crate::acme::AcmeClient; use crate::acme::AcmeClient;
use crate::cert_store::{CertBundle, CertMetadata, CertSource, CertStore};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum CertManagerError { pub enum CertManagerError {
@@ -45,17 +45,13 @@ impl CertManager {
/// Create an ACME client using this manager's configuration. /// Create an ACME client using this manager's configuration.
/// Returns None if no ACME email is configured. /// Returns None if no ACME email is configured.
pub fn acme_client(&self) -> Option<AcmeClient> { pub fn acme_client(&self) -> Option<AcmeClient> {
self.acme_email.as_ref().map(|email| { self.acme_email
AcmeClient::new(email.clone(), self.use_production) .as_ref()
}) .map(|email| AcmeClient::new(email.clone(), self.use_production))
} }
/// Load a static certificate into the store (infallible — pure cache insert). /// Load a static certificate into the store (infallible — pure cache insert).
pub fn load_static( pub fn load_static(&mut self, domain: String, bundle: CertBundle) {
&mut self,
domain: String,
bundle: CertBundle,
) {
self.store.store(domain, bundle); self.store.store(domain, bundle);
} }
@@ -108,23 +104,25 @@ impl CertManager {
F: FnOnce(String, String) -> Fut, F: FnOnce(String, String) -> Fut,
Fut: std::future::Future<Output = ()>, Fut: std::future::Future<Output = ()>,
{ {
let acme_client = self.acme_client() let acme_client = self.acme_client().ok_or(CertManagerError::NoEmail)?;
.ok_or(CertManagerError::NoEmail)?;
info!("Renewing certificate for {}", domain); info!("Renewing certificate for {}", domain);
let domain_owned = domain.to_string(); let domain_owned = domain.to_string();
let result = acme_client.provision(&domain_owned, |pending| { let result = acme_client
let token = pending.token.clone(); .provision(&domain_owned, |pending| {
let key_auth = pending.key_authorization.clone(); let token = pending.token.clone();
async move { let key_auth = pending.key_authorization.clone();
challenge_setup(token, key_auth).await; async move {
Ok(()) challenge_setup(token, key_auth).await;
} Ok(())
}).await.map_err(|e| CertManagerError::AcmeFailure { }
domain: domain.to_string(), })
message: e.to_string(), .await
})?; .map_err(|e| CertManagerError::AcmeFailure {
domain: domain.to_string(),
message: e.to_string(),
})?;
let (cert_pem, key_pem) = result; let (cert_pem, key_pem) = result;
let now = SystemTime::now() let now = SystemTime::now()
+15 -6
View File
@@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Certificate metadata stored alongside certs. /// Certificate metadata stored alongside certs.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -90,8 +90,10 @@ mod tests {
fn make_test_bundle(domain: &str) -> CertBundle { fn make_test_bundle(domain: &str) -> CertBundle {
CertBundle { CertBundle {
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(), key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n"
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(), .to_string(),
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n"
.to_string(),
ca_pem: None, ca_pem: None,
metadata: CertMetadata { metadata: CertMetadata {
domain: domain.to_string(), domain: domain.to_string(),
@@ -122,7 +124,8 @@ mod tests {
let mut store = CertStore::new(); let mut store = CertStore::new();
let mut bundle = make_test_bundle("secure.com"); let mut bundle = make_test_bundle("secure.com");
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string()); bundle.ca_pem =
Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
store.store("secure.com".to_string(), bundle); store.store("secure.com".to_string(), bundle);
let loaded = store.get("secure.com").unwrap(); let loaded = store.get("secure.com").unwrap();
@@ -147,7 +150,10 @@ mod tests {
fn test_remove_cert() { fn test_remove_cert() {
let mut store = CertStore::new(); let mut store = CertStore::new();
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com")); store.store(
"remove-me.com".to_string(),
make_test_bundle("remove-me.com"),
);
assert!(store.has("remove-me.com")); assert!(store.has("remove-me.com"));
let removed = store.remove("remove-me.com"); let removed = store.remove("remove-me.com");
@@ -165,7 +171,10 @@ mod tests {
fn test_wildcard_domain() { fn test_wildcard_domain() {
let mut store = CertStore::new(); let mut store = CertStore::new();
store.store("*.example.com".to_string(), make_test_bundle("*.example.com")); store.store(
"*.example.com".to_string(),
make_test_bundle("*.example.com"),
);
assert!(store.has("*.example.com")); assert!(store.has("*.example.com"));
let loaded = store.get("*.example.com").unwrap(); let loaded = store.get("*.example.com").unwrap();
+3 -3
View File
@@ -3,11 +3,11 @@
//! TLS certificate management for RustProxy. //! TLS certificate management for RustProxy.
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution. //! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
pub mod cert_store;
pub mod cert_manager;
pub mod acme; pub mod acme;
pub mod cert_manager;
pub mod cert_store;
pub mod sni_resolver; pub mod sni_resolver;
pub use cert_store::*;
pub use cert_manager::*; pub use cert_manager::*;
pub use cert_store::*;
pub use sni_resolver::*; pub use sni_resolver::*;
+12 -8
View File
@@ -13,7 +13,7 @@ use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, info, error}; use tracing::{debug, error, info};
/// ACME HTTP-01 challenge server. /// ACME HTTP-01 challenge server.
pub struct ChallengeServer { pub struct ChallengeServer {
@@ -47,7 +47,10 @@ impl ChallengeServer {
} }
/// Start the challenge server on the given port. /// Start the challenge server on the given port.
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { pub async fn start(
&mut self,
port: u16,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = format!("0.0.0.0:{}", port); let addr = format!("0.0.0.0:{}", port);
let listener = TcpListener::bind(&addr).await?; let listener = TcpListener::bind(&addr).await?;
info!("ACME challenge server listening on port {}", port); info!("ACME challenge server listening on port {}", port);
@@ -101,10 +104,7 @@ impl ChallengeServer {
pub async fn stop(&mut self) { pub async fn stop(&mut self) {
self.cancel.cancel(); self.cancel.cancel();
if let Some(handle) = self.handle.take() { if let Some(handle) = self.handle.take() {
let _ = tokio::time::timeout( let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
std::time::Duration::from_secs(5),
handle,
).await;
} }
self.challenges.clear(); self.challenges.clear();
self.cancel = CancellationToken::new(); self.cancel = CancellationToken::new();
@@ -154,10 +154,14 @@ mod tests {
tokio::time::sleep(std::time::Duration::from_millis(50)).await; tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Fetch the challenge // Fetch the challenge
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap(); let client = tokio::net::TcpStream::connect("127.0.0.1:19900")
.await
.unwrap();
let io = TokioIo::new(client); let io = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move { let _ = conn.await; }); tokio::spawn(async move {
let _ = conn.await;
});
let req = Request::get("/.well-known/acme-challenge/test-token") let req = Request::get("/.well-known/acme-challenge/test-token")
.body(Full::new(Bytes::new())) .body(Full::new(Bytes::new()))
+297 -140
View File
@@ -57,24 +57,27 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use arc_swap::ArcSwap;
use anyhow::Result; use anyhow::Result;
use tracing::{info, warn, debug, error}; use arc_swap::ArcSwap;
use tracing::{debug, error, info, warn};
// Re-export key types // Re-export key types
pub use rustproxy_config; pub use rustproxy_config;
pub use rustproxy_routing;
pub use rustproxy_passthrough;
pub use rustproxy_tls;
pub use rustproxy_http; pub use rustproxy_http;
pub use rustproxy_metrics; pub use rustproxy_metrics;
pub use rustproxy_passthrough;
pub use rustproxy_routing;
pub use rustproxy_security; pub use rustproxy_security;
pub use rustproxy_tls;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec}; use rustproxy_config::{CertificateSpec, RouteConfig, RustProxyOptions, TlsMode};
use rustproxy_metrics::{Metrics, MetricsCollector, Statistics};
use rustproxy_passthrough::{
ConnectionConfig, TcpListenerManager, TlsCertConfig, UdpListenerManager,
};
use rustproxy_routing::RouteManager; use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, UdpListenerManager, TlsCertConfig, ConnectionConfig}; use rustproxy_security::IpBlockList;
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics}; use rustproxy_tls::{CertBundle, CertManager, CertMetadata, CertSource, CertStore};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
/// Certificate status. /// Certificate status.
@@ -106,6 +109,8 @@ pub struct RustProxy {
loaded_certs: HashMap<String, TlsCertConfig>, loaded_certs: HashMap<String, TlsCertConfig>,
/// Cancellation token for cooperative shutdown of background tasks. /// Cancellation token for cooperative shutdown of background tasks.
cancel_token: CancellationToken, cancel_token: CancellationToken,
/// Shared global ingress blocklist, hot-reloadable across TCP/UDP listeners.
security_policy: Arc<ArcSwap<IpBlockList>>,
} }
impl RustProxy { impl RustProxy {
@@ -127,13 +132,19 @@ impl RustProxy {
let route_manager = RouteManager::new(options.routes.clone()); let route_manager = RouteManager::new(options.routes.clone());
// Set up certificate manager if ACME is configured // Set up certificate manager if ACME is configured
let cert_manager = Self::build_cert_manager(&options) let cert_manager =
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm))); Self::build_cert_manager(&options).map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
let retention = options.metrics.as_ref() let retention = options
.metrics
.as_ref()
.and_then(|m| m.retention_seconds) .and_then(|m| m.retention_seconds)
.unwrap_or(3600) as usize; .unwrap_or(3600) as usize;
let security_policy = Arc::new(ArcSwap::from(Arc::new(Self::build_ip_block_list(
options.security_policy.as_ref(),
))));
Ok(Self { Ok(Self {
options, options,
route_table: ArcSwap::from(Arc::new(route_manager)), route_table: ArcSwap::from(Arc::new(route_manager)),
@@ -149,6 +160,7 @@ impl RustProxy {
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)), socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
loaded_certs: HashMap::new(), loaded_certs: HashMap::new(),
cancel_token: CancellationToken::new(), cancel_token: CancellationToken::new(),
security_policy,
}) })
} }
@@ -163,24 +175,25 @@ impl RustProxy {
// Apply default target if route has no targets // Apply default target if route has no targets
if route.action.targets.is_none() { if route.action.targets.is_none() {
if let Some(ref default_target) = defaults.target { if let Some(ref default_target) = defaults.target {
debug!("Applying default target {}:{} to route {:?}", debug!(
default_target.host, default_target.port, "Applying default target {}:{} to route {:?}",
route.name.as_deref().unwrap_or("unnamed")); default_target.host,
route.action.targets = Some(vec![ default_target.port,
rustproxy_config::RouteTarget { route.name.as_deref().unwrap_or("unnamed")
target_match: None, );
host: rustproxy_config::HostSpec::Single(default_target.host.clone()), route.action.targets = Some(vec![rustproxy_config::RouteTarget {
port: rustproxy_config::PortSpec::Fixed(default_target.port), target_match: None,
tls: None, host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
websocket: None, port: rustproxy_config::PortSpec::Fixed(default_target.port),
load_balancing: None, tls: None,
send_proxy_protocol: None, websocket: None,
headers: None, load_balancing: None,
advanced: None, send_proxy_protocol: None,
backend_transport: None, headers: None,
priority: None, advanced: None,
} backend_transport: None,
]); priority: None,
}]);
} }
} }
@@ -199,7 +212,10 @@ impl RustProxy {
if let Some(ref allow_list) = default_security.ip_allow_list { if let Some(ref allow_list) = default_security.ip_allow_list {
security.ip_allow_list = Some( security.ip_allow_list = Some(
allow_list.iter().map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone())).collect() allow_list
.iter()
.map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone()))
.collect(),
); );
} }
if let Some(ref block_list) = default_security.ip_block_list { if let Some(ref block_list) = default_security.ip_block_list {
@@ -208,8 +224,10 @@ impl RustProxy {
// Only apply if there's something meaningful // Only apply if there's something meaningful
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
debug!("Applying default security to route {:?}", debug!(
route.name.as_deref().unwrap_or("unnamed")); "Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed")
);
route.security = Some(security); route.security = Some(security);
} }
} }
@@ -224,13 +242,17 @@ impl RustProxy {
return None; return None;
} }
let email = acme.email.clone() let email = acme.email.clone().or_else(|| acme.account_email.clone());
.or_else(|| acme.account_email.clone());
let use_production = acme.use_production.unwrap_or(false); let use_production = acme.use_production.unwrap_or(false);
let renew_before_days = acme.renew_threshold_days.unwrap_or(30); let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
let store = CertStore::new(); let store = CertStore::new();
Some(CertManager::new(store, email, use_production, renew_before_days)) Some(CertManager::new(
store,
email,
use_production,
renew_before_days,
))
} }
/// Build ConnectionConfig from RustProxyOptions. /// Build ConnectionConfig from RustProxyOptions.
@@ -248,7 +270,10 @@ impl RustProxy {
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime, extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false), accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false), send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
proxy_ips: options.proxy_ips.as_deref().unwrap_or(&[]) proxy_ips: options
.proxy_ips
.as_deref()
.unwrap_or(&[])
.iter() .iter()
.filter_map(|s| s.parse::<std::net::IpAddr>().ok()) .filter_map(|s| s.parse::<std::net::IpAddr>().ok())
.collect(), .collect(),
@@ -258,6 +283,22 @@ impl RustProxy {
} }
} }
fn build_ip_block_list(policy: Option<&rustproxy_config::SecurityPolicy>) -> IpBlockList {
let Some(policy) = policy else {
return IpBlockList::empty();
};
let mut entries = Vec::new();
if let Some(blocked_ips) = &policy.blocked_ips {
entries.extend(blocked_ips.iter().cloned());
}
if let Some(blocked_cidrs) = &policy.blocked_cidrs {
entries.extend(blocked_cidrs.iter().cloned());
}
IpBlockList::new(&entries)
}
/// Start the proxy, binding to all configured ports. /// Start the proxy, binding to all configured ports.
pub async fn start(&mut self) -> Result<()> { pub async fn start(&mut self) -> Result<()> {
if self.started { if self.started {
@@ -272,7 +313,11 @@ impl RustProxy {
let route_manager = self.route_table.load(); let route_manager = self.route_table.load();
let ports = route_manager.listening_ports(); let ports = route_manager.listening_ports();
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len()); info!(
"Configured {} routes on {} ports",
route_manager.route_count(),
ports.len()
);
// Create TCP listener manager with metrics // Create TCP listener manager with metrics
let mut listener = TcpListenerManager::with_metrics( let mut listener = TcpListenerManager::with_metrics(
@@ -282,7 +327,8 @@ impl RustProxy {
// Apply connection config from options // Apply connection config from options
let conn_config = Self::build_connection_config(&self.options); let conn_config = Self::build_connection_config(&self.options);
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms", debug!(
"Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
conn_config.connection_timeout_ms, conn_config.connection_timeout_ms,
conn_config.initial_data_timeout_ms, conn_config.initial_data_timeout_ms,
conn_config.socket_timeout_ms, conn_config.socket_timeout_ms,
@@ -291,6 +337,7 @@ impl RustProxy {
// Clone proxy_ips before conn_config is moved into the TCP listener // Clone proxy_ips before conn_config is moved into the TCP listener
let udp_proxy_ips = conn_config.proxy_ips.clone(); let udp_proxy_ips = conn_config.proxy_ips.clone();
listener.set_connection_config(conn_config); listener.set_connection_config(conn_config);
listener.set_security_policy(Arc::clone(&self.security_policy));
// Share the socket-handler relay path with the listener // Share the socket-handler relay path with the listener
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay)); listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
@@ -303,10 +350,13 @@ impl RustProxy {
let cm = cm.lock().await; let cm = cm.lock().await;
for (domain, bundle) in cm.store().iter() { for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) { if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig { tls_configs.insert(
cert_pem: bundle.cert_pem.clone(), domain.clone(),
key_pem: bundle.key_pem.clone(), TlsCertConfig {
}); cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
} }
} }
} }
@@ -330,7 +380,9 @@ impl RustProxy {
let mut tcp_ports = std::collections::HashSet::new(); let mut tcp_ports = std::collections::HashSet::new();
let mut udp_ports = std::collections::HashSet::new(); let mut udp_ports = std::collections::HashSet::new();
for route in &self.options.routes { for route in &self.options.routes {
if !route.is_enabled() { continue; } if !route.is_enabled() {
continue;
}
let transport = route.route_match.transport.as_ref(); let transport = route.route_match.transport.as_ref();
let route_ports = route.route_match.ports.to_ports(); let route_ports = route.route_match.ports.to_ports();
for port in route_ports { for port in route_ports {
@@ -371,6 +423,7 @@ impl RustProxy {
connection_registry, connection_registry,
); );
udp_mgr.set_proxy_ips(udp_proxy_ips.clone()); udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
// Share HttpProxyService with H3 — same route matching, connection // Share HttpProxyService with H3 — same route matching, connection
// pool, and ALPN protocol detection as the TCP/HTTP path. // pool, and ALPN protocol detection as the TCP/HTTP path.
@@ -379,10 +432,15 @@ impl RustProxy {
udp_mgr.set_h3_service(Arc::new(h3_svc)); udp_mgr.set_h3_service(Arc::new(h3_svc));
for port in &udp_ports { for port in &udp_ports {
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?; udp_mgr
.add_port_with_tls(*port, quic_tls_config.clone())
.await?;
} }
info!("UDP listeners started on {} ports: {:?}", info!(
udp_ports.len(), udp_mgr.listening_ports()); "UDP listeners started on {} ports: {:?}",
udp_ports.len(),
udp_mgr.listening_ports()
);
self.udp_listener_manager = Some(udp_mgr); self.udp_listener_manager = Some(udp_mgr);
} }
@@ -391,16 +449,22 @@ impl RustProxy {
// Start the throughput sampling task with cooperative cancellation // Start the throughput sampling task with cooperative cancellation
let metrics = Arc::clone(&self.metrics); let metrics = Arc::clone(&self.metrics);
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone(); let conn_tracker = self
.listener_manager
.as_ref()
.unwrap()
.conn_tracker()
.clone();
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone(); let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
let interval_ms = self.options.metrics.as_ref() let interval_ms = self
.options
.metrics
.as_ref()
.and_then(|m| m.sample_interval_ms) .and_then(|m| m.sample_interval_ms)
.unwrap_or(1000); .unwrap_or(1000);
let sampling_cancel = self.cancel_token.clone(); let sampling_cancel = self.cancel_token.clone();
self.sampling_handle = Some(tokio::spawn(async move { self.sampling_handle = Some(tokio::spawn(async move {
let mut interval = tokio::time::interval( let mut interval = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
std::time::Duration::from_millis(interval_ms)
);
loop { loop {
tokio::select! { tokio::select! {
_ = sampling_cancel.cancelled() => break, _ = sampling_cancel.cancelled() => break,
@@ -442,7 +506,10 @@ impl RustProxy {
continue; continue;
} }
let cert_spec = route.action.tls.as_ref() let cert_spec = route
.action
.tls
.as_ref()
.and_then(|tls| tls.certificate.as_ref()); .and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Auto(_)) = cert_spec { if let Some(CertificateSpec::Auto(_)) = cert_spec {
@@ -466,16 +533,25 @@ impl RustProxy {
return; return;
} }
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len()); info!(
"Auto-provisioning certificates for {} domains",
domains_to_provision.len()
);
// Start challenge server // Start challenge server
let acme_port = self.options.acme.as_ref() let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port) .and_then(|a| a.port)
.unwrap_or(80); .unwrap_or(80);
let mut challenge_server = challenge_server::ChallengeServer::new(); let mut challenge_server = challenge_server::ChallengeServer::new();
if let Err(e) = challenge_server.start(acme_port).await { if let Err(e) = challenge_server.start(acme_port).await {
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e); error!(
"Failed to start ACME challenge server on port {}: {}",
acme_port, e
);
return; return;
} }
@@ -488,13 +564,15 @@ impl RustProxy {
if let Some(acme_client) = acme_client { if let Some(acme_client) = acme_client {
let challenge_server_ref = &challenge_server; let challenge_server_ref = &challenge_server;
let result = acme_client.provision(domain, |pending| { let result = acme_client
challenge_server_ref.set_challenge( .provision(domain, |pending| {
pending.token.clone(), challenge_server_ref.set_challenge(
pending.key_authorization.clone(), pending.token.clone(),
); pending.key_authorization.clone(),
async move { Ok(()) } );
}).await; async move { Ok(()) }
})
.await;
match result { match result {
Ok((cert_pem, key_pem)) => { Ok((cert_pem, key_pem)) => {
@@ -539,7 +617,10 @@ impl RustProxy {
None => return, None => return,
}; };
let auto_renew = self.options.acme.as_ref() let auto_renew = self
.options
.acme
.as_ref()
.and_then(|a| a.auto_renew) .and_then(|a| a.auto_renew)
.unwrap_or(true); .unwrap_or(true);
@@ -547,11 +628,17 @@ impl RustProxy {
return; return;
} }
let check_interval_hours = self.options.acme.as_ref() let check_interval_hours = self
.options
.acme
.as_ref()
.and_then(|a| a.renew_check_interval_hours) .and_then(|a| a.renew_check_interval_hours)
.unwrap_or(24); .unwrap_or(24);
let acme_port = self.options.acme.as_ref() let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port) .and_then(|a| a.port)
.unwrap_or(80); .unwrap_or(80);
@@ -664,17 +751,19 @@ impl RustProxy {
/// Update routes atomically (hot-reload). /// Update routes atomically (hot-reload).
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> { pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
// Validate new routes // Validate new routes
rustproxy_config::validate_routes(&routes) rustproxy_config::validate_routes(&routes).map_err(|errors| {
.map_err(|errors| { let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect(); anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
anyhow::anyhow!("Route validation failed: {}", msgs.join(", ")) })?;
})?;
let new_manager = RouteManager::new(routes.clone()); let new_manager = RouteManager::new(routes.clone());
let new_ports = new_manager.listening_ports(); let new_ports = new_manager.listening_ports();
info!("Updating routes: {} routes on {} ports", info!(
new_manager.route_count(), new_ports.len()); "Updating routes: {} routes on {} ports",
new_manager.route_count(),
new_ports.len()
);
// Get old ports // Get old ports
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager { let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
@@ -684,28 +773,35 @@ impl RustProxy {
}; };
// Prune per-route metrics for route IDs that no longer exist // Prune per-route metrics for route IDs that no longer exist
let active_route_ids: HashSet<String> = routes.iter() let active_route_ids: HashSet<String> =
.filter_map(|r| r.id.clone()) routes.iter().filter_map(|r| r.id.clone()).collect();
.collect();
self.metrics.retain_routes(&active_route_ids); self.metrics.retain_routes(&active_route_ids);
// Prune per-backend metrics for backends no longer in any route target. // Prune per-backend metrics for backends no longer in any route target.
// For PortSpec::Preserve routes, expand across all listening ports since // For PortSpec::Preserve routes, expand across all listening ports since
// the actual runtime port depends on the incoming connection. // the actual runtime port depends on the incoming connection.
let listening_ports = self.get_listening_ports(); let listening_ports = self.get_listening_ports();
let active_backends: HashSet<String> = routes.iter() let active_backends: HashSet<String> = routes
.iter()
.filter_map(|r| r.action.targets.as_ref()) .filter_map(|r| r.action.targets.as_ref())
.flat_map(|targets| targets.iter()) .flat_map(|targets| targets.iter())
.flat_map(|target| { .flat_map(|target| {
let hosts: Vec<String> = target.host.to_vec().into_iter().map(|s| s.to_string()).collect(); let hosts: Vec<String> = target
.host
.to_vec()
.into_iter()
.map(|s| s.to_string())
.collect();
match &target.port { match &target.port {
rustproxy_config::PortSpec::Fixed(p) => { rustproxy_config::PortSpec::Fixed(p) => hosts
hosts.into_iter().map(|h| format!("{}:{}", h, p)).collect::<Vec<_>>() .into_iter()
} .map(|h| format!("{}:{}", h, p))
.collect::<Vec<_>>(),
_ => { _ => {
// Preserve/special: expand across all listening ports // Preserve/special: expand across all listening ports
let lp = &listening_ports; let lp = &listening_ports;
hosts.into_iter() hosts
.into_iter()
.flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p))) .flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p)))
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
@@ -733,10 +829,13 @@ impl RustProxy {
let cm = cm_arc.lock().await; let cm = cm_arc.lock().await;
for (domain, bundle) in cm.store().iter() { for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) { if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig { tls_configs.insert(
cert_pem: bundle.cert_pem.clone(), domain.clone(),
key_pem: bundle.key_pem.clone(), TlsCertConfig {
}); cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
} }
} }
} }
@@ -753,7 +852,9 @@ impl RustProxy {
// Cancel connections on routes that were removed or disabled // Cancel connections on routes that were removed or disabled
listener.invalidate_removed_routes(&active_route_ids); listener.invalidate_removed_routes(&active_route_ids);
// Clean up registry entries for removed routes // Clean up registry entries for removed routes
listener.connection_registry().cleanup_removed_routes(&active_route_ids); listener
.connection_registry()
.cleanup_removed_routes(&active_route_ids);
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters) // Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
listener.prune_http_proxy_caches(&active_route_ids); listener.prune_http_proxy_caches(&active_route_ids);
@@ -766,9 +867,10 @@ impl RustProxy {
None => continue, None => continue,
}; };
// Find corresponding old route // Find corresponding old route
let old_route = old_manager.routes().iter().find(|r| { let old_route = old_manager
r.id.as_deref() == Some(new_id) .routes()
}); .iter()
.find(|r| r.id.as_deref() == Some(new_id));
let old_route = match old_route { let old_route = match old_route {
Some(r) => r, Some(r) => r,
None => continue, // new route, no existing connections to recycle None => continue, // new route, no existing connections to recycle
@@ -812,11 +914,13 @@ impl RustProxy {
{ {
let mut new_udp_ports = HashSet::new(); let mut new_udp_ports = HashSet::new();
for route in &routes { for route in &routes {
if !route.is_enabled() { continue; } if !route.is_enabled() {
continue;
}
let transport = route.route_match.transport.as_ref(); let transport = route.route_match.transport.as_ref();
match transport { match transport {
Some(rustproxy_config::TransportProtocol::Udp) | Some(rustproxy_config::TransportProtocol::Udp)
Some(rustproxy_config::TransportProtocol::All) => { | Some(rustproxy_config::TransportProtocol::All) => {
for port in route.route_match.ports.to_ports() { for port in route.route_match.ports.to_ports() {
new_udp_ports.insert(port); new_udp_ports.insert(port);
} }
@@ -825,7 +929,8 @@ impl RustProxy {
} }
} }
let old_udp_ports: HashSet<u16> = self.udp_listener_manager let old_udp_ports: HashSet<u16> = self
.udp_listener_manager
.as_ref() .as_ref()
.map(|u| u.listening_ports().into_iter().collect()) .map(|u| u.listening_ports().into_iter().collect())
.unwrap_or_default(); .unwrap_or_default();
@@ -847,6 +952,7 @@ impl RustProxy {
connection_registry, connection_registry,
); );
udp_mgr.set_proxy_ips(conn_config.proxy_ips); udp_mgr.set_proxy_ips(conn_config.proxy_ips);
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
// Wire up H3ProxyService so QUIC connections can serve HTTP/3 // Wire up H3ProxyService so QUIC connections can serve HTTP/3
let http_proxy = listener.http_proxy().clone(); let http_proxy = listener.http_proxy().clone();
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy); let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
@@ -898,56 +1004,77 @@ impl RustProxy {
/// Provision a certificate for a named route. /// Provision a certificate for a named route.
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> { pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
let cm_arc = self.cert_manager.as_ref() let cm_arc = self.cert_manager.as_ref().ok_or_else(|| {
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?; anyhow::anyhow!("No certificate manager configured (ACME not enabled)")
})?;
// Find the route by name // Find the route by name
let route = self.options.routes.iter() let route = self
.options
.routes
.iter()
.find(|r| r.name.as_deref() == Some(route_name)) .find(|r| r.name.as_deref() == Some(route_name))
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?; .ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
let domain = route.route_match.domains.as_ref() let domain = route
.route_match
.domains
.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string())) .and_then(|d| d.to_vec().first().map(|s| s.to_string()))
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?; .ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain); info!(
"Provisioning certificate for route '{}' (domain: {})",
route_name, domain
);
// Start challenge server // Start challenge server
let acme_port = self.options.acme.as_ref() let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port) .and_then(|a| a.port)
.unwrap_or(80); .unwrap_or(80);
let mut cs = challenge_server::ChallengeServer::new(); let mut cs = challenge_server::ChallengeServer::new();
cs.start(acme_port).await cs.start(acme_port)
.await
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
let cs_ref = &cs; let cs_ref = &cs;
let mut cm = cm_arc.lock().await; let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(&domain, |token, key_auth| { let result = cm
cs_ref.set_challenge(token, key_auth); .renew_domain(&domain, |token, key_auth| {
async {} cs_ref.set_challenge(token, key_auth);
}).await; async {}
})
.await;
drop(cm); drop(cm);
cs.stop().await; cs.stop().await;
let bundle = result let bundle = result.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
// Hot-swap into TLS configs // Hot-swap into TLS configs
let mut tls_configs = Self::extract_tls_configs(&self.options.routes); let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig { tls_configs.insert(
cert_pem: bundle.cert_pem.clone(), domain.clone(),
key_pem: bundle.key_pem.clone(), TlsCertConfig {
}); cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
{ {
let cm = cm_arc.lock().await; let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() { for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) { if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig { tls_configs.insert(
cert_pem: b.cert_pem.clone(), d.clone(),
key_pem: b.key_pem.clone(), TlsCertConfig {
}); cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
},
);
} }
} }
} }
@@ -966,7 +1093,10 @@ impl RustProxy {
} }
} }
info!("Certificate provisioned and loaded for route '{}'", route_name); info!(
"Certificate provisioned and loaded for route '{}'",
route_name
);
Ok(()) Ok(())
} }
@@ -978,10 +1108,16 @@ impl RustProxy {
/// Get the status of a certificate for a named route. /// Get the status of a certificate for a named route.
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> { pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
let route = self.options.routes.iter() let route = self
.options
.routes
.iter()
.find(|r| r.name.as_deref() == Some(route_name))?; .find(|r| r.name.as_deref() == Some(route_name))?;
let domain = route.route_match.domains.as_ref() let domain = route
.route_match
.domains
.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?; .and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
if let Some(ref cm_arc) = self.cert_manager { if let Some(ref cm_arc) = self.cert_manager {
@@ -1010,8 +1146,9 @@ impl RustProxy {
let mut metrics = self.metrics.snapshot(); let mut metrics = self.metrics.snapshot();
if let Some(ref lm) = self.listener_manager { if let Some(ref lm) = self.listener_manager {
let entries = lm.http_proxy().protocol_cache_snapshot(); let entries = lm.http_proxy().protocol_cache_snapshot();
metrics.detected_protocols = entries.into_iter().map(|e| { metrics.detected_protocols = entries
rustproxy_metrics::ProtocolCacheEntryMetric { .into_iter()
.map(|e| rustproxy_metrics::ProtocolCacheEntryMetric {
host: e.host, host: e.host,
port: e.port, port: e.port,
domain: e.domain, domain: e.domain,
@@ -1026,8 +1163,8 @@ impl RustProxy {
h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs, h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs,
h2_consecutive_failures: e.h2_consecutive_failures, h2_consecutive_failures: e.h2_consecutive_failures,
h3_consecutive_failures: e.h3_consecutive_failures, h3_consecutive_failures: e.h3_consecutive_failures,
} })
}).collect(); .collect();
} }
metrics metrics
} }
@@ -1058,9 +1195,7 @@ impl RustProxy {
/// Get statistics snapshot. /// Get statistics snapshot.
pub fn get_statistics(&self) -> Statistics { pub fn get_statistics(&self) -> Statistics {
let uptime = self.started_at let uptime = self.started_at.map(|t| t.elapsed().as_secs()).unwrap_or(0);
.map(|t| t.elapsed().as_secs())
.unwrap_or(0);
Statistics { Statistics {
active_connections: self.metrics.active_connections(), active_connections: self.metrics.active_connections(),
@@ -1071,6 +1206,13 @@ impl RustProxy {
} }
} }
/// Update the global ingress security policy.
pub fn set_security_policy(&mut self, policy: rustproxy_config::SecurityPolicy) {
self.security_policy
.store(Arc::new(Self::build_ip_block_list(Some(&policy))));
self.options.security_policy = Some(policy);
}
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript. /// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates /// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
/// take effect immediately for all new connections. /// take effect immediately for all new connections.
@@ -1130,10 +1272,13 @@ impl RustProxy {
let cm = cm_arc.lock().await; let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() { for (d, b) in cm.store().iter() {
if !configs.contains_key(d) { if !configs.contains_key(d) {
configs.insert(d.clone(), TlsCertConfig { configs.insert(
cert_pem: b.cert_pem.clone(), d.clone(),
key_pem: b.key_pem.clone(), TlsCertConfig {
}); cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
},
);
} }
} }
} }
@@ -1166,7 +1311,8 @@ impl RustProxy {
info!("Loading certificate for domain: {}", domain); info!("Loading certificate for domain: {}", domain);
// Check if the cert actually changed (for selective connection recycling) // Check if the cert actually changed (for selective connection recycling)
let cert_changed = self.loaded_certs let cert_changed = self
.loaded_certs
.get(domain) .get(domain)
.map(|existing| existing.cert_pem != cert_pem) .map(|existing| existing.cert_pem != cert_pem)
.unwrap_or(false); // new domain = no existing connections to recycle .unwrap_or(false); // new domain = no existing connections to recycle
@@ -1196,10 +1342,13 @@ impl RustProxy {
} }
// Persist in loaded_certs so future rebuild calls include this cert // Persist in loaded_certs so future rebuild calls include this cert
self.loaded_certs.insert(domain.to_string(), TlsCertConfig { self.loaded_certs.insert(
cert_pem: cert_pem.clone(), domain.to_string(),
key_pem: key_pem.clone(), TlsCertConfig {
}); cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
},
);
// Hot-swap TLS config on TCP and QUIC listeners // Hot-swap TLS config on TCP and QUIC listeners
let tls_configs = self.current_tls_configs().await; let tls_configs = self.current_tls_configs().await;
@@ -1222,7 +1371,9 @@ impl RustProxy {
// Recycle existing connections if cert actually changed // Recycle existing connections if cert actually changed
if cert_changed { if cert_changed {
if let Some(ref listener) = self.listener_manager { if let Some(ref listener) = self.listener_manager {
listener.connection_registry().recycle_for_cert_change(domain); listener
.connection_registry()
.recycle_for_cert_change(domain);
} }
} }
@@ -1244,16 +1395,22 @@ impl RustProxy {
continue; continue;
} }
let cert_spec = route.action.tls.as_ref() let cert_spec = route
.action
.tls
.as_ref()
.and_then(|tls| tls.certificate.as_ref()); .and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Static(cert_config)) = cert_spec { if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
if let Some(ref domains) = route.route_match.domains { if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() { for domain in domains.to_vec() {
configs.insert(domain.to_string(), TlsCertConfig { configs.insert(
cert_pem: cert_config.cert.clone(), domain.to_string(),
key_pem: cert_config.key.clone(), TlsCertConfig {
}); cert_pem: cert_config.cert.clone(),
key_pem: cert_config.key.clone(),
},
);
} }
} }
} }
+4 -9
View File
@@ -1,12 +1,12 @@
#[global_allocator] #[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use anyhow::Result;
use clap::Parser; use clap::Parser;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustproxy::RustProxy;
use rustproxy::management; use rustproxy::management;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions; use rustproxy_config::RustProxyOptions;
/// RustProxy - High-performance multi-protocol proxy /// RustProxy - High-performance multi-protocol proxy
@@ -43,8 +43,7 @@ async fn main() -> Result<()> {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_writer(std::io::stderr) .with_writer(std::io::stderr)
.with_env_filter( .with_env_filter(
EnvFilter::try_from_default_env() EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
) )
.init(); .init();
@@ -60,11 +59,7 @@ async fn main() -> Result<()> {
let options = RustProxyOptions::from_file(&cli.config) let options = RustProxyOptions::from_file(&cli.config)
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?; .map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
tracing::info!( tracing::info!("Loaded {} routes from {}", options.routes.len(), cli.config);
"Loaded {} routes from {}",
options.routes.len(),
cli.config
);
// Validate-only mode // Validate-only mode
if cli.validate { if cli.validate {
+157 -65
View File
@@ -1,7 +1,7 @@
use anyhow::Result; use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error}; use tracing::{error, info};
use crate::RustProxy; use crate::RustProxy;
use rustproxy_config::RustProxyOptions; use rustproxy_config::RustProxyOptions;
@@ -141,14 +141,19 @@ async fn handle_request(
"start" => handle_start(&id, &request.params, proxy).await, "start" => handle_start(&id, &request.params, proxy).await,
"stop" => handle_stop(&id, proxy).await, "stop" => handle_stop(&id, proxy).await,
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await, "updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
"setSecurityPolicy" => handle_set_security_policy(&id, &request.params, proxy),
"getMetrics" => handle_get_metrics(&id, proxy), "getMetrics" => handle_get_metrics(&id, proxy),
"getStatistics" => handle_get_statistics(&id, proxy), "getStatistics" => handle_get_statistics(&id, proxy),
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await, "provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await, "renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await, "getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
"getListeningPorts" => handle_get_listening_ports(&id, proxy), "getListeningPorts" => handle_get_listening_ports(&id, proxy),
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await, "setSocketHandlerRelay" => {
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await, handle_set_socket_handler_relay(&id, &request.params, proxy).await
}
"setDatagramHandlerRelay" => {
handle_set_datagram_handler_relay(&id, &request.params, proxy).await
}
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await, "addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await, "removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await, "loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
@@ -167,7 +172,12 @@ async fn handle_start(
let config = match params.get("config") { let config = match params.get("config") {
Some(config) => config, Some(config) => config,
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'config' parameter".to_string(),
)
}
}; };
let options: RustProxyOptions = match serde_json::from_value(config.clone()) { let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
@@ -176,38 +186,31 @@ async fn handle_start(
}; };
match RustProxy::new(options) { match RustProxy::new(options) {
Ok(mut p) => { Ok(mut p) => match p.start().await {
match p.start().await { Ok(()) => {
Ok(()) => { send_event("started", serde_json::json!({}));
send_event("started", serde_json::json!({})); *proxy = Some(p);
*proxy = Some(p); ManagementResponse::ok(id.to_string(), serde_json::json!({}))
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
} }
} Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
},
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)), Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
} }
} }
async fn handle_stop( async fn handle_stop(id: &str, proxy: &mut Option<RustProxy>) -> ManagementResponse {
id: &str,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_mut() { match proxy.as_mut() {
Some(p) => { Some(p) => match p.stop().await {
match p.stop().await { Ok(()) => {
Ok(()) => { *proxy = None;
*proxy = None; send_event("stopped", serde_json::json!({}));
send_event("stopped", serde_json::json!({})); ManagementResponse::ok(id.to_string(), serde_json::json!({}))
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
} }
} Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
},
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})), None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
} }
} }
@@ -224,7 +227,12 @@ async fn handle_update_routes(
let routes = match params.get("routes") { let routes = match params.get("routes") {
Some(routes) => routes, Some(routes) => routes,
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routes' parameter".to_string(),
)
}
}; };
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) { let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
@@ -234,36 +242,72 @@ async fn handle_update_routes(
match p.update_routes(routes).await { match p.update_routes(routes).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)), Err(e) => {
ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e))
}
} }
} }
fn handle_get_metrics( fn handle_set_security_policy(
id: &str, id: &str,
proxy: &Option<RustProxy>, params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse { ) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let policy = match params.get("policy") {
Some(policy) => policy,
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'policy' parameter".to_string(),
)
}
};
let policy: rustproxy_config::SecurityPolicy = match serde_json::from_value(policy.clone()) {
Ok(policy) => policy,
Err(e) => {
return ManagementResponse::err(
id.to_string(),
format!("Invalid security policy: {}", e),
)
}
};
p.set_security_policy(policy);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
fn handle_get_metrics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() { match proxy.as_ref() {
Some(p) => { Some(p) => {
let metrics = p.get_metrics(); let metrics = p.get_metrics();
match serde_json::to_value(&metrics) { match serde_json::to_value(&metrics) {
Ok(v) => ManagementResponse::ok(id.to_string(), v), Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize metrics: {}", e),
),
} }
} }
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
} }
} }
fn handle_get_statistics( fn handle_get_statistics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() { match proxy.as_ref() {
Some(p) => { Some(p) => {
let stats = p.get_statistics(); let stats = p.get_statistics();
match serde_json::to_value(&stats) { match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id.to_string(), v), Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize statistics: {}", e),
),
} }
} }
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
@@ -282,12 +326,20 @@ async fn handle_provision_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) { let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(), Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
}; };
match p.provision_certificate(&route_name).await { match p.provision_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to provision certificate: {}", e),
),
} }
} }
@@ -303,12 +355,20 @@ async fn handle_renew_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) { let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(), Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
}; };
match p.renew_certificate(&route_name).await { match p.renew_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to renew certificate: {}", e),
),
} }
} }
@@ -324,24 +384,29 @@ async fn handle_get_certificate_status(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) { let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name, Some(name) => name,
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
}; };
match p.get_certificate_status(route_name).await { match p.get_certificate_status(route_name).await {
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({ Some(status) => ManagementResponse::ok(
"domain": status.domain, id.to_string(),
"source": status.source, serde_json::json!({
"expiresAt": status.expires_at, "domain": status.domain,
"isValid": status.is_valid, "source": status.source,
})), "expiresAt": status.expires_at,
"isValid": status.is_valid,
}),
),
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null), None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
} }
} }
fn handle_get_listening_ports( fn handle_get_listening_ports(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() { match proxy.as_ref() {
Some(p) => { Some(p) => {
let ports = p.get_listening_ports(); let ports = p.get_listening_ports();
@@ -361,7 +426,8 @@ async fn handle_set_socket_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}; };
let socket_path = params.get("socketPath") let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(|s| s.to_string()); .map(|s| s.to_string());
@@ -381,7 +447,8 @@ async fn handle_set_datagram_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}; };
let socket_path = params.get("socketPath") let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(|s| s.to_string()); .map(|s| s.to_string());
@@ -403,12 +470,17 @@ async fn handle_add_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) { let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16, Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()), None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
}; };
match p.add_listening_port(port).await { match p.add_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to add port {}: {}", port, e),
),
} }
} }
@@ -424,12 +496,17 @@ async fn handle_remove_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) { let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16, Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()), None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
}; };
match p.remove_listening_port(port).await { match p.remove_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to remove port {}: {}", port, e),
),
} }
} }
@@ -445,26 +522,41 @@ async fn handle_load_certificate(
let domain = match params.get("domain").and_then(|v| v.as_str()) { let domain = match params.get("domain").and_then(|v| v.as_str()) {
Some(d) => d.to_string(), Some(d) => d.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'domain' parameter".to_string(),
)
}
}; };
let cert = match params.get("cert").and_then(|v| v.as_str()) { let cert = match params.get("cert").and_then(|v| v.as_str()) {
Some(c) => c.to_string(), Some(c) => c.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()), None => {
return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string())
}
}; };
let key = match params.get("key").and_then(|v| v.as_str()) { let key = match params.get("key").and_then(|v| v.as_str()) {
Some(k) => k.to_string(), Some(k) => k.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()), None => {
return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string())
}
}; };
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string()); let ca = params
.get("ca")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("loadCertificate: domain={}", domain); info!("loadCertificate: domain={}", domain);
// Load cert into cert manager and hot-swap TLS config // Load cert into cert manager and hot-swap TLS config
match p.load_certificate(&domain, cert, key, ca).await { match p.load_certificate(&domain, cert, key, ca).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to load certificate for {}: {}", domain, e),
),
} }
} }
+8 -8
View File
@@ -136,7 +136,8 @@ pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandl
let path = parts.get(1).copied().unwrap_or("/"); let path = parts.get(1).copied().unwrap_or("/");
// Extract Host header // Extract Host header
let host = req_str.lines() let host = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("host:")) .find(|l| l.to_lowercase().starts_with("host:"))
.map(|l| l[5..].trim()) .map(|l| l[5..].trim())
.unwrap_or("unknown"); .unwrap_or("unknown");
@@ -336,7 +337,8 @@ pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
let req_str = String::from_utf8_lossy(&buf[..n]); let req_str = String::from_utf8_lossy(&buf[..n]);
// Extract Sec-WebSocket-Key for proper handshake // Extract Sec-WebSocket-Key for proper handshake
let ws_key = req_str.lines() let ws_key = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:")) .find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string()) .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
.unwrap_or_default(); .unwrap_or_default();
@@ -378,7 +380,9 @@ pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
use rcgen::{CertificateParams, KeyPair}; use rcgen::{CertificateParams, KeyPair};
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap(); let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
params.distinguished_name.push(rcgen::DnType::CommonName, domain); params
.distinguished_name
.push(rcgen::DnType::CommonName, domain);
let key_pair = KeyPair::generate().unwrap(); let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap(); let cert = params.self_signed(&key_pair).unwrap();
@@ -458,11 +462,7 @@ pub fn make_tls_terminate_route(
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data. /// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`). /// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
pub async fn start_tls_ws_echo_backend( pub async fn start_tls_ws_echo_backend(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
port: u16,
cert_pem: &str,
key_pem: &str,
) -> JoinHandle<()> {
use std::sync::Arc; use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem) let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
@@ -1,9 +1,9 @@
mod common; mod common;
use bytes::Buf;
use common::*; use common::*;
use rustproxy::RustProxy; use rustproxy::RustProxy;
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic}; use rustproxy_config::{RouteQuic, RouteUdp, RustProxyOptions, TransportProtocol};
use bytes::Buf;
use std::sync::Arc; use std::sync::Arc;
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate. /// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
@@ -14,7 +14,14 @@ fn make_h3_route(
cert_pem: &str, cert_pem: &str,
key_pem: &str, key_pem: &str,
) -> rustproxy_config::RouteConfig { ) -> rustproxy_config::RouteConfig {
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem); let mut route = make_tls_terminate_route(
port,
"localhost",
target_host,
target_port,
cert_pem,
key_pem,
);
route.route_match.transport = Some(TransportProtocol::All); route.route_match.transport = Some(TransportProtocol::All);
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction // Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
route.action.udp = Some(RouteUdp { route.action.udp = Some(RouteUdp {
@@ -89,11 +96,9 @@ async fn test_h3_response_stream_finishes() {
.await .await
.expect("QUIC handshake failed"); .expect("QUIC handshake failed");
let (mut driver, mut send_request) = h3::client::new( let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(connection))
h3_quinn::Connection::new(connection), .await
) .expect("H3 connection setup failed");
.await
.expect("H3 connection setup failed");
// Drive the H3 connection in background // Drive the H3 connection in background
tokio::spawn(async move { tokio::spawn(async move {
@@ -108,33 +113,46 @@ async fn test_h3_response_stream_finishes() {
.body(()) .body(())
.unwrap(); .unwrap();
let mut stream = send_request.send_request(req).await let mut stream = send_request
.send_request(req)
.await
.expect("Failed to send H3 request"); .expect("Failed to send H3 request");
stream.finish().await stream
.finish()
.await
.expect("Failed to finish sending H3 request body"); .expect("Failed to finish sending H3 request body");
// 6. Read response headers // 6. Read response headers
let resp = stream.recv_response().await let resp = stream
.recv_response()
.await
.expect("Failed to receive H3 response"); .expect("Failed to receive H3 response");
assert_eq!(resp.status(), http::StatusCode::OK, assert_eq!(
"Expected 200 OK, got {}", resp.status()); resp.status(),
http::StatusCode::OK,
"Expected 200 OK, got {}",
resp.status()
);
// 7. Read body and verify stream ends (FIN received) // 7. Read body and verify stream ends (FIN received)
// This is the critical assertion: recv_data() must return None (stream ended) // This is the critical assertion: recv_data() must return None (stream ended)
// within the timeout, NOT hang forever waiting for a FIN that never arrives. // within the timeout, NOT hang forever waiting for a FIN that never arrives.
let result = with_timeout(async { let result = with_timeout(
let mut total = 0usize; async {
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") { let mut total = 0usize;
total += chunk.remaining(); while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
} total += chunk.remaining();
// recv_data() returned None => stream ended (FIN received) }
total // recv_data() returned None => stream ended (FIN received)
}, 10) total
},
10,
)
.await; .await;
let bytes_received = result.expect( let bytes_received = result.expect(
"TIMEOUT: H3 stream never ended (FIN not received by client). \ "TIMEOUT: H3 stream never ended (FIN not received by client). \
The proxy sent all response data but failed to send the QUIC stream FIN." The proxy sent all response data but failed to send the QUIC stream FIN.",
); );
assert_eq!( assert_eq!(
bytes_received, bytes_received,
@@ -43,17 +43,32 @@ async fn test_http_forward_basic() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await; async {
let body = extract_body(&response); let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
body.to_string() let body = extract_body(&response);
}, 10) body.to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result); assert!(
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result); result.contains(r#""method":"GET"#),
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result); "Expected GET method, got: {}",
result
);
assert!(
result.contains(r#""path":"/hello"#),
"Expected /hello path, got: {}",
result
);
assert!(
result.contains(r#""backend":"main"#),
"Expected main backend, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -69,8 +84,18 @@ async fn test_http_forward_host_routing() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port), make_test_route(
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port), proxy_port,
Some("alpha.example.com"),
"127.0.0.1",
backend1_port,
),
make_test_route(
proxy_port,
Some("beta.example.com"),
"127.0.0.1",
backend2_port,
),
], ],
..Default::default() ..Default::default()
}; };
@@ -80,24 +105,38 @@ async fn test_http_forward_host_routing() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain // Test alpha domain
let alpha_result = with_timeout(async { let alpha_result = with_timeout(
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result); assert!(
alpha_result.contains(r#""backend":"alpha"#),
"Expected alpha backend, got: {}",
alpha_result
);
// Test beta domain // Test beta domain
let beta_result = with_timeout(async { let beta_result = with_timeout(
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result); assert!(
beta_result.contains(r#""backend":"beta"#),
"Expected beta backend, got: {}",
beta_result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -127,24 +166,38 @@ async fn test_http_forward_path_routing() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Test API path // Test API path
let api_result = with_timeout(async { let api_result = with_timeout(
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result); assert!(
api_result.contains(r#""backend":"api"#),
"Expected api backend, got: {}",
api_result
);
// Test web path (no /api prefix) // Test web path (no /api prefix)
let web_result = with_timeout(async { let web_result = with_timeout(
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result); assert!(
web_result.contains(r#""backend":"web"#),
"Expected web backend, got: {}",
web_result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -184,9 +237,18 @@ async fn test_http_forward_cors_preflight() {
.unwrap(); .unwrap();
// Should get 204 No Content with CORS headers // Should get 204 No Content with CORS headers
assert!(result.contains("204"), "Expected 204 status, got: {}", result); assert!(
assert!(result.to_lowercase().contains("access-control-allow-origin"), result.contains("204"),
"Expected CORS header, got: {}", result); "Expected 204 status, got: {}",
result
);
assert!(
result
.to_lowercase()
.contains("access-control-allow-origin"),
"Expected CORS header, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -208,15 +270,22 @@ async fn test_http_forward_backend_error() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await; async {
response let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
}, 10) response
},
10,
)
.await .await
.unwrap(); .unwrap();
// Proxy should relay the 500 from backend // Proxy should relay the 500 from backend
assert!(result.contains("500"), "Expected 500 status, got: {}", result); assert!(
result.contains("500"),
"Expected 500 status, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -227,7 +296,12 @@ async fn test_http_forward_no_route_matched() {
// Create a route only for a specific domain // Create a route only for a specific domain
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)], routes: vec![make_test_route(
proxy_port,
Some("known.example.com"),
"127.0.0.1",
9999,
)],
..Default::default() ..Default::default()
}; };
@@ -235,15 +309,22 @@ async fn test_http_forward_no_route_matched() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await; async {
response let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
}, 10) response
},
10,
)
.await .await
.unwrap(); .unwrap();
// Should get 502 Bad Gateway (no route matched) // Should get 502 Bad Gateway (no route matched)
assert!(result.contains("502"), "Expected 502 status, got: {}", result); assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -262,15 +343,22 @@ async fn test_http_forward_backend_unavailable() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "example.com", "GET", "/").await; async {
response let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
}, 10) response
},
10,
)
.await .await
.unwrap(); .unwrap();
// Should get 502 Bad Gateway (backend unavailable) // Should get 502 Bad Gateway (backend unavailable)
assert!(result.contains("502"), "Expected 502 status, got: {}", result); assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -286,7 +374,12 @@ async fn test_https_terminate_http_forward() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_tls_terminate_route( routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)], )],
..Default::default() ..Default::default()
}; };
@@ -295,38 +388,53 @@ async fn test_https_terminate_http_forward() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let _ = rustls::crypto::ring::default_provider().install_default(); async {
let tls_config = rustls::ClientConfig::builder() let _ = rustls::crypto::ring::default_provider().install_default();
.dangerous() let tls_config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config)); .with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send HTTP request through TLS // Send HTTP request through TLS
let request = format!( let request = format!(
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", "GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
domain domain
); );
tls_stream.write_all(request.as_bytes()).await.unwrap(); tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new(); let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap(); tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string() String::from_utf8_lossy(&response).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
let body = extract_body(&result); let body = extract_body(&result);
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body); assert!(
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body); body.contains(r#""method":"GET"#),
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body); "Expected GET, got: {}",
body
);
assert!(
body.contains(r#""path":"/api/data"#),
"Expected /api/data, got: {}",
body
);
assert!(
body.contains(r#""backend":"tls-backend"#),
"Expected tls-backend, got: {}",
body
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -347,59 +455,68 @@ async fn test_websocket_through_proxy() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
.unwrap();
// Send WebSocket upgrade request // Send WebSocket upgrade request
let request = format!( let request = format!(
"GET /ws HTTP/1.1\r\n\ "GET /ws HTTP/1.1\r\n\
Host: example.com\r\n\ Host: example.com\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Version: 13\r\n\
\r\n" \r\n"
); );
stream.write_all(request.as_bytes()).await.unwrap(); stream.write_all(request.as_bytes()).await.unwrap();
// Read the 101 response // Read the 101 response
let mut response_buf = Vec::with_capacity(4096); let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1]; let mut temp = [0u8; 1];
loop { loop {
let n = stream.read(&mut temp).await.unwrap(); let n = stream.read(&mut temp).await.unwrap();
if n == 0 { break; } if n == 0 {
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len-4..] == *b"\r\n\r\n" {
break; break;
} }
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len - 4..] == *b"\r\n\r\n" {
break;
}
}
} }
}
let response_str = String::from_utf8_lossy(&response_buf).to_string(); let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str); assert!(
assert!( response_str.contains("101"),
response_str.to_lowercase().contains("upgrade: websocket"), "Expected 101 Switching Protocols, got: {}",
"Expected Upgrade header, got: {}", response_str
response_str );
); assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
response_str
);
// After upgrade, send data and verify echo // After upgrade, send data and verify echo
let test_data = b"Hello WebSocket!"; let test_data = b"Hello WebSocket!";
stream.write_all(test_data).await.unwrap(); stream.write_all(test_data).await.unwrap();
// Read echoed data // Read echoed data
let mut echo_buf = vec![0u8; 256]; let mut echo_buf = vec![0u8; 256];
let n = stream.read(&mut echo_buf).await.unwrap(); let n = stream.read(&mut echo_buf).await.unwrap();
let echoed = &echo_buf[..n]; let echoed = &echo_buf[..n];
assert_eq!(echoed, test_data, "Expected echo of sent data"); assert_eq!(echoed, test_data, "Expected echo of sent data");
"ok".to_string() "ok".to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -431,12 +548,22 @@ async fn test_terminate_and_reencrypt_http_routing() {
// Create terminate-and-reencrypt routes // Create terminate-and-reencrypt routes
let mut route1 = make_tls_terminate_route( let mut route1 = make_tls_terminate_route(
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1, proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
); );
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt; route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let mut route2 = make_tls_terminate_route( let mut route2 = make_tls_terminate_route(
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2, proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
); );
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt; route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -450,27 +577,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt // Test alpha domain - HTTP request through TLS terminate-and-reencrypt
let alpha_result = with_timeout(async { let alpha_result = with_timeout(
let _ = rustls::crypto::ring::default_provider().install_default(); async {
let tls_config = rustls::ClientConfig::builder() let _ = rustls::crypto::ring::default_provider().install_default();
.dangerous() let tls_config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config)); .with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap(); let server_name =
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n"; let request =
tls_stream.write_all(request.as_bytes()).await.unwrap(); "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new(); let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap(); tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string() String::from_utf8_lossy(&response).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -498,27 +630,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
); );
// Test beta domain - different host goes to different backend // Test beta domain - different host goes to different backend
let beta_result = with_timeout(async { let beta_result = with_timeout(
let _ = rustls::crypto::ring::default_provider().install_default(); async {
let tls_config = rustls::ClientConfig::builder() let _ = rustls::crypto::ring::default_provider().install_default();
.dangerous() let tls_config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config)); .with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap(); let server_name =
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n"; let request =
tls_stream.write_all(request.as_bytes()).await.unwrap(); "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new(); let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap(); tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string() String::from_utf8_lossy(&response).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -589,14 +726,12 @@ async fn test_terminate_and_reencrypt_websocket() {
.dangerous() .dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth(); .with_no_client_auth();
let connector = let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send WebSocket upgrade request through TLS // Send WebSocket upgrade request through TLS
@@ -685,10 +820,13 @@ async fn test_protocol_field_in_route_config() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// HTTP request should match the route and get proxied // HTTP request should match the route and get proxied
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
@@ -20,13 +20,19 @@ async fn test_start_and_stop() {
assert!(!wait_for_port(port, 200).await); assert!(!wait_for_port(port, 200).await);
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(port, 2000).await, "Port should be listening after start"); assert!(
wait_for_port(port, 2000).await,
"Port should be listening after start"
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
// Give the OS a moment to release the port // Give the OS a moment to release the port
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop"); assert!(
!wait_for_port(port, 200).await,
"Port should not be listening after stop"
);
} }
#[tokio::test] #[tokio::test]
@@ -54,7 +60,12 @@ async fn test_update_routes_hot_reload() {
let port = next_port(); let port = next_port();
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)], routes: vec![make_test_route(
port,
Some("old.example.com"),
"127.0.0.1",
8080,
)],
..Default::default() ..Default::default()
}; };
@@ -62,9 +73,12 @@ async fn test_update_routes_hot_reload() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
// Update routes atomically // Update routes atomically
let new_routes = vec![ let new_routes = vec![make_test_route(
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090), port,
]; Some("new.example.com"),
"127.0.0.1",
9090,
)];
let result = proxy.update_routes(new_routes).await; let result = proxy.update_routes(new_routes).await;
assert!(result.is_ok()); assert!(result.is_ok());
@@ -87,15 +101,24 @@ async fn test_add_remove_listening_port() {
// Add a new port // Add a new port
proxy.add_listening_port(port2).await.unwrap(); proxy.add_listening_port(port2).await.unwrap();
assert!(wait_for_port(port2, 2000).await, "New port should be listening"); assert!(
wait_for_port(port2, 2000).await,
"New port should be listening"
);
// Remove the port // Remove the port
proxy.remove_listening_port(port2).await.unwrap(); proxy.remove_listening_port(port2).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening"); assert!(
!wait_for_port(port2, 200).await,
"Removed port should not be listening"
);
// Original port should still be listening // Original port should still be listening
assert!(wait_for_port(port1, 200).await, "Original port should still be listening"); assert!(
wait_for_port(port1, 200).await,
"Original port should still be listening"
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -168,7 +191,11 @@ async fn test_metrics_track_connections() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics(); let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections); assert!(
stats.total_connections > 0,
"Expected total_connections > 0, got {}",
stats.total_connections
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -205,8 +232,11 @@ async fn test_metrics_track_bytes() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics(); let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, assert!(
"Expected some connections tracked, got {}", stats.total_connections); stats.total_connections > 0,
"Expected some connections tracked, got {}",
stats.total_connections
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -228,23 +258,38 @@ async fn test_hot_reload_port_changes() {
let mut proxy = RustProxy::new(options).unwrap(); let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await); assert!(wait_for_port(port1, 2000).await);
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet"); assert!(
!wait_for_port(port2, 200).await,
"port2 should not be listening yet"
);
// Update routes to use port2 instead // Update routes to use port2 instead
let new_routes = vec![ let new_routes = vec![make_test_route(port2, None, "127.0.0.1", backend_port)];
make_test_route(port2, None, "127.0.0.1", backend_port),
];
proxy.update_routes(new_routes).await.unwrap(); proxy.update_routes(new_routes).await.unwrap();
// Port2 should now be listening, port1 should be closed // Port2 should now be listening, port1 should be closed
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload"); assert!(
wait_for_port(port2, 2000).await,
"port2 should be listening after reload"
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload"); assert!(
!wait_for_port(port1, 200).await,
"port1 should be closed after reload"
);
// Verify port2 works // Verify port2 works
let ports = proxy.get_listening_ports(); let ports = proxy.get_listening_ports();
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports); assert!(
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports); ports.contains(&port2),
"Expected port2 in listening ports: {:?}",
ports
);
assert!(
!ports.contains(&port1),
"port1 should not be in listening ports: {:?}",
ports
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -24,19 +24,25 @@ async fn test_tcp_forward_echo() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
// Wait for proxy to be ready // Wait for proxy to be ready
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready"); assert!(
wait_for_port(proxy_port, 2000).await,
"Proxy port not ready"
);
// Connect and send data // Connect and send data
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
stream.write_all(b"hello world").await.unwrap(); .unwrap();
stream.write_all(b"hello world").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
@@ -61,21 +67,24 @@ async fn test_tcp_forward_large_payload() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
.unwrap();
// Send 1MB of data // Send 1MB of data
let data = vec![b'A'; 1_000_000]; let data = vec![b'A'; 1_000_000];
stream.write_all(&data).await.unwrap(); stream.write_all(&data).await.unwrap();
stream.shutdown().await.unwrap(); stream.shutdown().await.unwrap();
// Read all back // Read all back
let mut received = Vec::new(); let mut received = Vec::new();
stream.read_to_end(&mut received).await.unwrap(); stream.read_to_end(&mut received).await.unwrap();
received.len() received.len()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -100,29 +109,32 @@ async fn test_tcp_forward_multiple_connections() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut handles = Vec::new(); async {
for i in 0..10 { let mut handles = Vec::new();
let port = proxy_port; for i in 0..10 {
handles.push(tokio::spawn(async move { let port = proxy_port;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)) handles.push(tokio::spawn(async move {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.unwrap(); .await
let msg = format!("connection-{}", i); .unwrap();
stream.write_all(msg.as_bytes()).await.unwrap(); let msg = format!("connection-{}", i);
stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
})); }));
} }
let mut results = Vec::new(); let mut results = Vec::new();
for handle in handles { for handle in handles {
results.push(handle.await.unwrap()); results.push(handle.await.unwrap());
} }
results results
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -149,14 +161,20 @@ async fn test_tcp_forward_backend_unreachable() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Connection should complete (proxy accepts it) but data should not flow // Connection should complete (proxy accepts it) but data should not flow
let result = with_timeout(async { let result = with_timeout(
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await; async {
stream.is_ok() let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
}, 5) stream.is_ok()
},
5,
)
.await .await
.unwrap(); .unwrap();
assert!(result, "Should be able to connect to proxy even if backend is down"); assert!(
result,
"Should be able to connect to proxy even if backend is down"
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -178,16 +196,19 @@ async fn test_tcp_forward_bidirectional() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
stream.write_all(b"test data").await.unwrap(); .unwrap();
stream.write_all(b"test data").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
@@ -65,8 +65,18 @@ async fn test_tls_passthrough_sni_routing() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port), make_tls_passthrough_route(
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port), proxy_port,
Some("one.example.com"),
"127.0.0.1",
backend1_port,
),
make_tls_passthrough_route(
proxy_port,
Some("two.example.com"),
"127.0.0.1",
backend2_port,
),
], ],
..Default::default() ..Default::default()
}; };
@@ -76,39 +86,53 @@ async fn test_tls_passthrough_sni_routing() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Send a fake ClientHello with SNI "one.example.com" // Send a fake ClientHello with SNI "one.example.com"
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello("one.example.com"); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello("one.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
// Backend1 should have received the ClientHello and prefixed its response // Backend1 should have received the ClientHello and prefixed its response
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result); assert!(
result.starts_with("BACKEND1:"),
"Expected BACKEND1 prefix, got: {}",
result
);
// Now test routing to backend2 // Now test routing to backend2
let result2 = with_timeout(async { let result2 = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello("two.example.com"); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello("two.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2); assert!(
result2.starts_with("BACKEND2:"),
"Expected BACKEND2 prefix, got: {}",
result2
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -121,9 +145,12 @@ async fn test_tls_passthrough_unknown_sni() {
let _backend = start_echo_server(backend_port).await; let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![make_tls_passthrough_route(
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port), proxy_port,
], Some("known.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default() ..Default::default()
}; };
@@ -132,21 +159,24 @@ async fn test_tls_passthrough_unknown_sni() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Send ClientHello with unknown SNI - should get no response (connection dropped) // Send ClientHello with unknown SNI - should get no response (connection dropped)
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello("unknown.example.com"); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello("unknown.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
// Should either get 0 bytes (closed) or an error // Should either get 0 bytes (closed) or an error
match stream.read(&mut buf).await { match stream.read(&mut buf).await {
Ok(0) => true, // Connection closed = no route matched Ok(0) => true, // Connection closed = no route matched
Ok(_) => false, // Got data = route shouldn't have matched Ok(_) => false, // Got data = route shouldn't have matched
Err(_) => true, // Error = connection dropped Err(_) => true, // Error = connection dropped
} }
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
@@ -163,9 +193,12 @@ async fn test_tls_passthrough_wildcard_domain() {
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await; let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![make_tls_passthrough_route(
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port), proxy_port,
], Some("*.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default() ..Default::default()
}; };
@@ -174,21 +207,28 @@ async fn test_tls_passthrough_wildcard_domain() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Should match any subdomain of example.com // Should match any subdomain of example.com
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello("anything.example.com"); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello("anything.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result); assert!(
result.starts_with("WILDCARD:"),
"Expected WILDCARD prefix, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -222,24 +262,29 @@ async fn test_tls_passthrough_multiple_domains() {
("beta.example.com", "B2:"), ("beta.example.com", "B2:"),
("gamma.example.com", "B3:"), ("gamma.example.com", "B3:"),
] { ] {
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello(domain); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello(domain);
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
assert!( assert!(
result.starts_with(expected_prefix), result.starts_with(expected_prefix),
"Domain {} should route to {}, got: {}", "Domain {} should route to {}, got: {}",
domain, expected_prefix, result domain,
expected_prefix,
result
); );
} }
@@ -74,7 +74,12 @@ async fn test_tls_terminate_basic() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_tls_terminate_route( routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)], )],
..Default::default() ..Default::default()
}; };
@@ -84,23 +89,26 @@ async fn test_tls_terminate_basic() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Connect with TLS client // Connect with TLS client
let result = with_timeout(async { let result = with_timeout(
let tls_config = make_insecure_tls_client_config(); async {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello TLS").await.unwrap(); tls_stream.write_all(b"hello TLS").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap(); let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -125,7 +133,12 @@ async fn test_tls_terminate_and_reencrypt() {
// Create terminate-and-reencrypt route // Create terminate-and-reencrypt route
let mut route = make_tls_terminate_route( let mut route = make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key, proxy_port,
domain,
"127.0.0.1",
backend_port,
&proxy_cert,
&proxy_key,
); );
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt; route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -138,23 +151,26 @@ async fn test_tls_terminate_and_reencrypt() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let tls_config = make_insecure_tls_client_config(); async {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello reencrypt").await.unwrap(); tls_stream.write_all(b"hello reencrypt").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap(); let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -177,8 +193,22 @@ async fn test_tls_terminate_sni_cert_selection() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1), make_tls_terminate_route(
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2), proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
),
make_tls_terminate_route(
proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
),
], ],
..Default::default() ..Default::default()
}; };
@@ -188,27 +218,35 @@ async fn test_tls_terminate_sni_cert_selection() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain // Test alpha domain
let result = with_timeout(async { let result = with_timeout(
let tls_config = make_insecure_tls_client_config(); async {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap(); let server_name =
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"test").await.unwrap(); tls_stream.write_all(b"test").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap(); let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result); assert!(
result.starts_with("ALPHA:"),
"Expected ALPHA prefix, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -224,7 +262,12 @@ async fn test_tls_terminate_large_payload() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_tls_terminate_route( routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)], )],
..Default::default() ..Default::default()
}; };
@@ -233,26 +276,29 @@ async fn test_tls_terminate_large_payload() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let tls_config = make_insecure_tls_client_config(); async {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send 1MB of data // Send 1MB of data
let data = vec![b'X'; 1_000_000]; let data = vec![b'X'; 1_000_000];
tls_stream.write_all(&data).await.unwrap(); tls_stream.write_all(&data).await.unwrap();
tls_stream.shutdown().await.unwrap(); tls_stream.shutdown().await.unwrap();
let mut received = Vec::new(); let mut received = Vec::new();
tls_stream.read_to_end(&mut received).await.unwrap(); tls_stream.read_to_end(&mut received).await.unwrap();
received.len() received.len()
}, 15) },
15,
)
.await .await
.unwrap(); .unwrap();
@@ -272,7 +318,12 @@ async fn test_tls_terminate_concurrent() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_tls_terminate_route( routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)], )],
..Default::default() ..Default::default()
}; };
@@ -281,37 +332,40 @@ async fn test_tls_terminate_concurrent() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut handles = Vec::new(); async {
for i in 0..10 { let mut handles = Vec::new();
let port = proxy_port; for i in 0..10 {
let dom = domain.to_string(); let port = proxy_port;
handles.push(tokio::spawn(async move { let dom = domain.to_string();
let tls_config = make_insecure_tls_client_config(); handles.push(tokio::spawn(async move {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let msg = format!("conn-{}", i); let msg = format!("conn-{}", i);
tls_stream.write_all(msg.as_bytes()).await.unwrap(); tls_stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap(); let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
})); }));
} }
let mut results = Vec::new(); let mut results = Vec::new();
for handle in handles { for handle in handles {
results.push(handle.await.unwrap()); results.push(handle.await.unwrap());
} }
results results
}, 15) },
15,
)
.await .await
.unwrap(); .unwrap();
+1 -1
View File
@@ -3,6 +3,6 @@
*/ */
export const commitinfo = { export const commitinfo = {
name: '@push.rocks/smartproxy', name: '@push.rocks/smartproxy',
version: '27.8.2', version: '27.9.0',
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.' description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
} }
+1 -1
View File
@@ -7,7 +7,7 @@ export { SmartProxy } from './proxies/smart-proxy/index.js';
export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js'; export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js';
// Export smart-proxy models // Export smart-proxy models
export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js'; export type { ISmartProxyOptions, ISmartProxySecurityPolicy, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js'; export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
export * from './proxies/smart-proxy/utils/index.js'; export * from './proxies/smart-proxy/utils/index.js';
+1 -1
View File
@@ -2,6 +2,6 @@
* SmartProxy models * SmartProxy models
*/ */
// Export everything except IAcmeOptions from interfaces // Export everything except IAcmeOptions from interfaces
export type { ISmartProxyOptions, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js'; export type { ISmartProxyOptions, ISmartProxySecurityPolicy, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
export * from './route-types.js'; export * from './route-types.js';
export * from './metrics-types.js'; export * from './metrics-types.js';
@@ -29,6 +29,11 @@ export interface ISmartProxyCertStore {
} }
import type { IRouteConfig } from './route-types.js'; import type { IRouteConfig } from './route-types.js';
export interface ISmartProxySecurityPolicy {
blockedIps?: string[];
blockedCidrs?: string[];
}
/** /**
* Provision object for static or HTTP-01 certificate * Provision object for static or HTTP-01 certificate
*/ */
@@ -137,6 +142,7 @@ export interface ISmartProxyOptions {
// Rate limiting and security // Rate limiting and security
maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP
connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP
securityPolicy?: ISmartProxySecurityPolicy; // Global ingress block policy, enforced before routing
// Enhanced keep-alive settings // Enhanced keep-alive settings
keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections
+2 -1
View File
@@ -1,5 +1,5 @@
import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js'; import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js';
import type { IAcmeOptions, ISmartProxyOptions } from './interfaces.js'; import type { IAcmeOptions, ISmartProxyOptions, ISmartProxySecurityPolicy } from './interfaces.js';
import type { import type {
IRouteAction, IRouteAction,
IRouteConfig, IRouteConfig,
@@ -75,6 +75,7 @@ export interface IRustProxyOptions {
keepAliveInactivityMultiplier?: number; keepAliveInactivityMultiplier?: number;
extendedKeepAliveLifetime?: number; extendedKeepAliveLifetime?: number;
metrics?: ISmartProxyOptions['metrics']; metrics?: ISmartProxyOptions['metrics'];
securityPolicy?: ISmartProxySecurityPolicy;
acme?: IRustAcmeOptions; acme?: IRustAcmeOptions;
} }
@@ -7,6 +7,7 @@ import type {
IRustRouteConfig, IRustRouteConfig,
IRustStatistics, IRustStatistics,
} from './models/rust-types.js'; } from './models/rust-types.js';
import type { ISmartProxySecurityPolicy } from './models/interfaces.js';
/** /**
* Type-safe command definitions for the Rust proxy IPC protocol. * Type-safe command definitions for the Rust proxy IPC protocol.
@@ -15,6 +16,7 @@ type TSmartProxyCommands = {
start: { params: { config: IRustProxyOptions }; result: void }; start: { params: { config: IRustProxyOptions }; result: void };
stop: { params: Record<string, never>; result: void }; stop: { params: Record<string, never>; result: void };
updateRoutes: { params: { routes: IRustRouteConfig[] }; result: void }; updateRoutes: { params: { routes: IRustRouteConfig[] }; result: void };
setSecurityPolicy: { params: { policy: ISmartProxySecurityPolicy }; result: void };
getMetrics: { params: Record<string, never>; result: IRustMetricsSnapshot }; getMetrics: { params: Record<string, never>; result: IRustMetricsSnapshot };
getStatistics: { params: Record<string, never>; result: IRustStatistics }; getStatistics: { params: Record<string, never>; result: IRustStatistics };
provisionCertificate: { params: { routeName: string }; result: void }; provisionCertificate: { params: { routeName: string }; result: void };
@@ -139,6 +141,10 @@ export class RustProxyBridge extends plugins.EventEmitter {
await this.bridge.sendCommand('updateRoutes', { routes }); await this.bridge.sendCommand('updateRoutes', { routes });
} }
public async setSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
await this.bridge.sendCommand('setSecurityPolicy', { policy });
}
public async getMetrics(): Promise<IRustMetricsSnapshot> { public async getMetrics(): Promise<IRustMetricsSnapshot> {
return this.bridge.sendCommand('getMetrics', {} as Record<string, never>); return this.bridge.sendCommand('getMetrics', {} as Record<string, never>);
} }
+10 -1
View File
@@ -17,7 +17,7 @@ import { Mutex } from './utils/mutex.js';
import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js'; import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js';
// Types // Types
import type { ISmartProxyOptions, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js'; import type { ISmartProxyOptions, ISmartProxySecurityPolicy, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
import type { IRouteConfig } from './models/route-types.js'; import type { IRouteConfig } from './models/route-types.js';
import type { IMetrics } from './models/metrics-types.js'; import type { IMetrics } from './models/metrics-types.js';
import type { IRustCertificateStatus, IRustProxyOptions, IRustStatistics } from './models/rust-types.js'; import type { IRustCertificateStatus, IRustProxyOptions, IRustStatistics } from './models/rust-types.js';
@@ -350,6 +350,15 @@ export class SmartProxy extends plugins.EventEmitter {
.catch((err) => logger.log('error', `Unexpected error in cert provisioning after route update: ${err.message}`, { component: 'smart-proxy' })); .catch((err) => logger.log('error', `Unexpected error in cert provisioning after route update: ${err.message}`, { component: 'smart-proxy' }));
} }
/**
* Update the global ingress security policy without changing routes.
* The Rust engine applies this before route selection and backend connection.
*/
public async updateSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
this.settings.securityPolicy = policy;
await this.bridge.setSecurityPolicy(policy);
}
/** /**
* Provision a certificate for a named route. * Provision a certificate for a named route.
*/ */
@@ -182,6 +182,7 @@ export function buildRustProxyOptions(
keepAliveInactivityMultiplier: settings.keepAliveInactivityMultiplier, keepAliveInactivityMultiplier: settings.keepAliveInactivityMultiplier,
extendedKeepAliveLifetime: settings.extendedKeepAliveLifetime, extendedKeepAliveLifetime: settings.extendedKeepAliveLifetime,
metrics: settings.metrics, metrics: settings.metrics,
securityPolicy: settings.securityPolicy,
acme: serializeAcmeForRust(acme), acme: serializeAcmeForRust(acme),
}; };
} }