Compare commits

...

2 Commits

Author SHA1 Message Date
jkunz e806f7257f v27.9.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-26 15:11:10 +00:00
jkunz af4908b63f feat(smart-proxy): add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers 2026-04-26 15:11:10 +00:00
54 changed files with 2351 additions and 1197 deletions
+7
View File
@@ -1,5 +1,12 @@
# Changelog
## 2026-04-26 - 27.9.0 - feat(smart-proxy)
add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers
- adds global securityPolicy config with blocked IP and CIDR support to SmartProxy and RustProxy options
- introduces management IPC support to update the security policy at runtime via setSecurityPolicy
- enforces the global block list early for TCP, UDP, and QUIC traffic before route selection and backend handling
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
retain inactive per-IP metric buckets briefly to capture final throughput before pruning
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "@push.rocks/smartproxy",
"version": "27.8.2",
"version": "27.9.0",
"private": false,
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
"main": "dist_ts/index.js",
+1
View File
@@ -1319,6 +1319,7 @@ dependencies = [
"rustproxy-http",
"rustproxy-metrics",
"rustproxy-routing",
"rustproxy-security",
"serde",
"serde_json",
"socket2 0.5.10",
+4 -4
View File
@@ -3,15 +3,15 @@
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
pub mod route_types;
pub mod proxy_options;
pub mod tls_types;
pub mod route_types;
pub mod security_types;
pub mod tls_types;
pub mod validation;
// Re-export all primary types
pub use route_types::*;
pub use proxy_options::*;
pub use tls_types::*;
pub use route_types::*;
pub use security_types::*;
pub use tls_types::*;
pub use validation::*;
@@ -97,6 +97,16 @@ pub struct MetricsConfig {
pub retention_seconds: Option<u64>,
}
/// Global ingress security policy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SecurityPolicy {
#[serde(skip_serializing_if = "Option::is_none")]
pub blocked_ips: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub blocked_cidrs: Option<Vec<String>>,
}
/// RustProxy configuration options.
/// Matches TypeScript: `ISmartProxyOptions`
///
@@ -235,6 +245,10 @@ pub struct RustProxyOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub metrics: Option<MetricsConfig>,
/// Global ingress security policy, enforced before route selection.
#[serde(skip_serializing_if = "Option::is_none")]
pub security_policy: Option<SecurityPolicy>,
// ─── ACME ────────────────────────────────────────────────────────
/// Global ACME configuration
#[serde(skip_serializing_if = "Option::is_none")]
@@ -275,6 +289,7 @@ impl Default for RustProxyOptions {
use_http_proxy: None,
http_proxy_port: None,
metrics: None,
security_policy: None,
acme: None,
}
}
@@ -111,10 +111,7 @@ pub enum IpAllowEntry {
/// Plain IP/CIDR — allowed for all domains on this route
Plain(String),
/// Domain-scoped — allowed only when the requested domain matches
DomainScoped {
ip: String,
domains: Vec<String>,
},
DomainScoped { ip: String, domains: Vec<String> },
}
/// Security options for routes.
+17 -8
View File
@@ -1,6 +1,6 @@
use thiserror::Error;
use crate::route_types::{RouteConfig, RouteActionType};
use crate::route_types::{RouteActionType, RouteConfig};
/// Validation errors for route configurations.
#[derive(Debug, Error)]
@@ -30,9 +30,10 @@ pub enum ValidationError {
/// Validate a single route configuration.
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
let mut errors = Vec::new();
let name = route.name.clone().unwrap_or_else(|| {
route.id.clone().unwrap_or_else(|| "unnamed".to_string())
});
let name = route
.name
.clone()
.unwrap_or_else(|| route.id.clone().unwrap_or_else(|| "unnamed".to_string()));
// Check ports
let ports = route.listening_ports();
@@ -160,7 +161,9 @@ mod tests {
let mut route = make_valid_route();
route.action.targets = None;
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. })));
assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::MissingTargets { .. })));
}
#[test]
@@ -168,7 +171,9 @@ mod tests {
let mut route = make_valid_route();
route.action.targets = Some(vec![]);
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
}
#[test]
@@ -176,7 +181,9 @@ mod tests {
let mut route = make_valid_route();
route.route_match.ports = PortRange::Single(0);
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
}
#[test]
@@ -186,7 +193,9 @@ mod tests {
let mut r2 = make_valid_route();
r2.id = Some("route-1".to_string());
let errors = validate_routes(&[r1, r2]).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. })));
assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::DuplicateId { .. })));
}
#[test]
+2 -2
View File
@@ -2,10 +2,10 @@
//!
//! Metrics and throughput tracking for RustProxy.
pub mod throughput;
pub mod collector;
pub mod log_dedup;
pub mod throughput;
pub use throughput::*;
pub use collector::*;
pub use log_dedup::*;
pub use throughput::*;
@@ -1,6 +1,6 @@
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::info;
@@ -47,7 +47,10 @@ impl LogDeduplicator {
let map_key = format!("{}:{}", category, key);
let now = Instant::now();
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
let entry = self
.events
.entry(map_key)
.or_insert_with(|| AggregatedEvent {
category: category.to_string(),
first_message: message.to_string(),
count: AtomicU64::new(0),
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
@@ -7,8 +7,8 @@
use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
@@ -73,7 +73,9 @@ impl ConnectionRegistry {
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
let matches = entry.domain.as_deref()
let matches = entry
.domain
.as_deref()
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
.unwrap_or(false);
if matches {
@@ -100,7 +102,11 @@ impl ConnectionRegistry {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
if entry.route_id.as_deref() == Some(route_id) {
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) {
if !RequestFilter::check_ip_security(
new_security,
&entry.source_ip,
entry.domain.as_deref(),
) {
info!(
"Terminating connection from {} — IP now blocked on route '{}'",
entry.source_ip, route_id
@@ -31,7 +31,8 @@ impl ConnectionTracker {
pub fn try_accept(&self, ip: &IpAddr) -> bool {
// Check per-IP connection limit
if let Some(max) = self.max_per_ip {
let count = self.active
let count = self
.active
.get(ip)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0);
@@ -48,7 +49,10 @@ impl ConnectionTracker {
let timestamps = entry.value_mut();
// Remove timestamps older than 1 minute
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
while timestamps
.front()
.is_some_and(|t| now.duration_since(*t) >= one_minute)
{
timestamps.pop_front();
}
@@ -111,7 +115,6 @@ impl ConnectionTracker {
pub fn tracked_ips(&self) -> usize {
self.active.len()
}
}
#[cfg(test)]
@@ -1,8 +1,8 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;
use rustproxy_metrics::MetricsCollector;
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
if let Some(data) = initial_data {
backend.write_all(data).await?;
if let Some(ref ctx) = metrics {
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
data.len() as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_c2b {
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
n as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
backend_write.shutdown(),
).await;
let _ =
tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
total
});
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_b2c {
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
0,
n as u64,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
client_write.shutdown(),
).await;
let _ =
tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
total
});
+16 -16
View File
@@ -4,26 +4,26 @@
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
//! and UDP datagram session tracking with forwarding.
pub mod tcp_listener;
pub mod sni_parser;
pub mod connection_registry;
pub mod connection_tracker;
pub mod forwarder;
pub mod proxy_protocol;
pub mod tls_handler;
pub mod connection_tracker;
pub mod connection_registry;
pub mod socket_opts;
pub mod udp_session;
pub mod udp_listener;
pub mod quic_handler;
pub mod sni_parser;
pub mod socket_opts;
pub mod tcp_listener;
pub mod tls_handler;
pub mod udp_listener;
pub mod udp_session;
pub use tcp_listener::*;
pub use sni_parser::*;
pub use connection_registry::*;
pub use connection_tracker::*;
pub use forwarder::*;
pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_tracker::*;
pub use connection_registry::*;
pub use socket_opts::*;
pub use udp_session::*;
pub use udp_listener::*;
pub use quic_handler::*;
pub use sni_parser::*;
pub use socket_opts::*;
pub use tcp_listener::*;
pub use tls_handler::*;
pub use udp_listener::*;
pub use udp_session::*;
@@ -54,8 +54,8 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
.position(|w| w == b"\r\n")
.ok_or(ProxyProtocolError::InvalidHeader)?;
let line = std::str::from_utf8(&data[..line_end])
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
let line =
std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
if !line.starts_with("PROXY ") {
return Err(ProxyProtocolError::InvalidHeader);
@@ -148,7 +148,10 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
let command = data[12] & 0x0F;
// 0x0 = LOCAL, 0x1 = PROXY
if command > 1 {
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command)));
return Err(ProxyProtocolError::Parse(format!(
"Unknown command: {}",
command
)));
}
// Address family (high nibble) + transport (low nibble) of byte 13
@@ -182,7 +185,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET (0x1) + STREAM (0x1) = TCP4
(0x1, 0x1) => {
if addr_len < 12 {
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
}
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
@@ -200,7 +205,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET (0x1) + DGRAM (0x2) = UDP4
(0x1, 0x2) => {
if addr_len < 12 {
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
}
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
@@ -218,7 +225,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
(0x2, 0x1) => {
if addr_len < 36 {
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
}
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
@@ -236,7 +245,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
(0x2, 0x2) => {
if addr_len < 36 {
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
}
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
@@ -268,11 +279,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
}
/// Generate a PROXY protocol v2 binary header.
pub fn generate_v2(
source: &SocketAddr,
dest: &SocketAddr,
transport: ProxyV2Transport,
) -> Vec<u8> {
pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
let transport_nibble: u8 = match transport {
ProxyV2Transport::Stream => 0x1,
ProxyV2Transport::Datagram => 0x2,
@@ -462,7 +469,10 @@ mod tests {
header.push(0x11);
header.extend_from_slice(&12u16.to_be_bytes());
header.extend_from_slice(&[0u8; 12]);
assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion)));
assert!(matches!(
parse_v2(&header),
Err(ProxyProtocolError::UnsupportedVersion)
));
}
#[test]
@@ -26,11 +26,12 @@ use tracing::{debug, info, warn};
use rustproxy_config::{RouteConfig, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService;
use crate::connection_tracker::ConnectionTracker;
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
use crate::connection_tracker::ConnectionTracker;
/// Create a QUIC server endpoint on the given port with the provided TLS config.
///
@@ -48,8 +49,7 @@ pub fn create_quic_endpoint(
quinn::EndpointConfig::default(),
Some(server_config),
socket,
quinn::default_runtime()
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?;
info!("QUIC endpoint listening on port {}", port);
@@ -97,6 +97,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
port: u16,
tls_config: Arc<RustlsServerConfig>,
proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
cancel: CancellationToken,
) -> anyhow::Result<QuicProxyRelay> {
// Bind external socket on the real port
@@ -119,8 +120,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
quinn::EndpointConfig::default(),
Some(server_config),
internal_socket,
quinn::default_runtime()
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?;
let real_client_map = Arc::new(DashMap::new());
@@ -129,12 +129,20 @@ pub fn create_quic_endpoint_with_proxy_relay(
external_socket,
quinn_internal_addr,
proxy_ips,
security_policy,
Arc::clone(&real_client_map),
cancel,
));
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
info!(
"QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
port, quinn_internal_addr
);
Ok(QuicProxyRelay {
endpoint,
relay_task,
real_client_map,
})
}
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
@@ -144,6 +152,7 @@ async fn quic_proxy_relay_loop(
external_socket: Arc<UdpSocket>,
quinn_internal_addr: SocketAddr,
proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
cancel: CancellationToken,
) {
@@ -184,26 +193,43 @@ async fn quic_proxy_relay_loop(
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => {
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
debug!(
"QUIC PROXY v2 from {}: real client {}",
src_addr, header.source_addr
);
proxy_addr_map.insert(src_addr, header.source_addr);
continue; // consume the PROXY v2 datagram
}
Err(e) => {
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
debug!(
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
src_addr, e
);
}
}
}
}
// Determine real client address
let real_client = proxy_addr_map.get(&src_addr)
let real_client = proxy_addr_map
.get(&src_addr)
.map(|r| *r)
.unwrap_or(src_addr);
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
debug!(
"QUIC datagram from {} blocked by global security policy",
real_client
);
continue;
}
// Get or create relay session for this external source
let session = match relay_sessions.get(&src_addr) {
Some(s) => {
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
s.last_activity
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
Arc::clone(s.value())
}
None => {
@@ -216,7 +242,10 @@ async fn quic_proxy_relay_loop(
}
};
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
warn!(
"QUIC relay: failed to connect relay socket to {}: {}",
quinn_internal_addr, e
);
continue;
}
let relay_local_addr = match relay_socket.local_addr() {
@@ -248,8 +277,10 @@ async fn quic_proxy_relay_loop(
});
relay_sessions.insert(src_addr, Arc::clone(&session));
debug!("QUIC relay: new session for {} (relay {}), real client {}",
src_addr, relay_local_addr, real_client);
debug!(
"QUIC relay: new session for {} (relay {}), real client {}",
src_addr, relay_local_addr, real_client
);
session
}
@@ -264,9 +295,11 @@ async fn quic_proxy_relay_loop(
if last_cleanup.elapsed() >= cleanup_interval {
last_cleanup = Instant::now();
let now_ms = epoch.elapsed().as_millis() as u64;
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
let stale_keys: Vec<SocketAddr> = relay_sessions
.iter()
.filter(|entry| {
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
let age =
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
age > session_timeout_ms
})
.map(|entry| *entry.key())
@@ -287,13 +320,17 @@ async fn quic_proxy_relay_loop(
// Also clean orphaned proxy_addr_map entries (PROXY header received
// but no relay session was ever created, e.g. client never sent data)
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
let orphaned: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| relay_sessions.get(entry.key()).is_none())
.map(|entry| *entry.key())
.collect();
for key in orphaned {
proxy_addr_map.remove(&key);
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
debug!(
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
key
);
}
}
}
@@ -328,8 +365,14 @@ async fn relay_return_path(
}
};
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
if let Err(e) = external_socket
.send_to(&buf[..len], external_src_addr)
.await
{
debug!(
"QUIC relay return send error to {}: {}",
external_src_addr, e
);
break;
}
}
@@ -353,6 +396,7 @@ pub async fn quic_accept_loop(
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
route_cancels: Arc<DashMap<String, CancellationToken>>,
connection_registry: Arc<ConnectionRegistry>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) {
loop {
let incoming = tokio::select! {
@@ -374,11 +418,21 @@ pub async fn quic_accept_loop(
let remote_addr = incoming.remote_address();
// Resolve real client IP from PROXY protocol map if available
let real_addr = real_client_map.as_ref()
let real_addr = real_client_map
.as_ref()
.and_then(|map| map.get(&remote_addr).map(|r| *r))
.unwrap_or(remote_addr);
let ip = real_addr.ip();
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&ip) {
debug!(
"QUIC connection from {} blocked by global security policy",
real_addr
);
continue;
}
// Per-IP rate limiting
if !conn_tracker.try_accept(&ip) {
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
@@ -414,7 +468,10 @@ pub async fn quic_accept_loop(
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
security, &ip, ctx.domain,
) {
debug!("QUIC connection from {} blocked by route security", real_addr);
debug!(
"QUIC connection from {} blocked by route security",
real_addr
);
continue;
}
}
@@ -425,7 +482,8 @@ pub async fn quic_accept_loop(
// Resolve per-route cancel token (child of global cancel)
let route_cancel = match route_id.as_deref() {
Some(id) => route_cancels.entry(id.to_string())
Some(id) => route_cancels
.entry(id.to_string())
.or_insert_with(|| cancel.child_token())
.clone(),
None => cancel.child_token(),
@@ -445,7 +503,11 @@ pub async fn quic_accept_loop(
let metrics = Arc::clone(&metrics);
let conn_tracker = Arc::clone(&conn_tracker);
let h3_svc = h3_service.clone();
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
let real_client_addr = if real_addr != remote_addr {
Some(real_addr)
} else {
None
};
tokio::spawn(async move {
// Register in connection registry (RAII guard removes on drop)
@@ -462,7 +524,8 @@ pub async fn quic_accept_loop(
impl Drop for QuicConnGuard {
fn drop(&mut self) {
self.tracker.connection_closed(&self.ip);
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
self.metrics
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
}
}
let _guard = QuicConnGuard {
@@ -473,7 +536,17 @@ pub async fn quic_accept_loop(
route_id,
};
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &conn_cancel, h3_svc, real_client_addr).await {
match handle_quic_connection(
incoming,
route,
port,
Arc::clone(&metrics),
&conn_cancel,
h3_svc,
real_client_addr,
)
.await
{
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
}
@@ -501,17 +574,28 @@ async fn handle_quic_connection(
debug!("QUIC connection established from {}", effective_addr);
// Check if this route has HTTP/3 enabled
let enable_http3 = route.action.udp.as_ref()
let enable_http3 = route
.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref())
.and_then(|q| q.enable_http3)
.unwrap_or(false);
if enable_http3 {
if let Some(ref h3_svc) = h3_service {
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
debug!(
"HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
route.name
);
h3_svc
.handle_connection(connection, &route, port, real_client_addr, cancel)
.await
} else {
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
warn!(
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
route.name
);
// Keep connection alive until cancelled
tokio::select! {
_ = cancel.cancelled() => {}
@@ -523,7 +607,8 @@ async fn handle_quic_connection(
}
} else {
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
.await
}
}
@@ -545,7 +630,10 @@ async fn handle_quic_stream_forwarding(
let metrics_arc = metrics;
// Resolve backend target
let target = route.action.targets.as_ref()
let target = route
.action
.targets
.as_ref()
.and_then(|t| t.first())
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
let backend_host = target.host.first();
@@ -576,19 +664,20 @@ async fn handle_quic_stream_forwarding(
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
tokio::spawn(async move {
match forward_quic_stream_to_tcp(
send_stream,
recv_stream,
&backend_addr,
stream_cancel,
).await {
match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
.await
{
Ok((bytes_in, bytes_out)) => {
stream_metrics.record_bytes(
bytes_in, bytes_out,
bytes_in,
bytes_out,
stream_route_id.as_deref(),
Some(&ip_str),
);
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
debug!(
"QUIC stream forwarded: {}B in, {}B out",
bytes_in, bytes_out
);
}
Err(e) => {
debug!("QUIC stream forwarding error: {}", e);
@@ -640,10 +729,7 @@ async fn forward_quic_stream_to_tcp(
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
tcp_write.shutdown(),
).await;
let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
total
});
@@ -721,8 +807,8 @@ mod tests {
let _ = rustls::crypto::ring::default_provider().install_default();
// Generate a single self-signed cert and use its key pair
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
.unwrap();
let self_signed =
rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
let cert_der = self_signed.cert.der().clone();
let key_der = self_signed.key_pair.serialize_der();
@@ -737,6 +823,10 @@ mod tests {
// Port 0 = OS assigns a free port
let result = create_quic_endpoint(0, Arc::new(tls_config));
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
assert!(
result.is_ok(),
"QUIC endpoint creation failed: {:?}",
result.err()
);
}
}
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
}
// Handshake length (3 bytes) - informational, we parse incrementally
let _handshake_len = ((data[6] as usize) << 16)
| ((data[7] as usize) << 8)
| (data[8] as usize);
let _handshake_len =
((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
let hello = &data[9..];
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
pub fn extract_http_host(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?;
for line in text.split("\r\n") {
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
if let Some(value) = line
.strip_prefix("Host: ")
.or_else(|| line.strip_prefix("host: "))
{
// Strip port if present
let host = value.split(':').next().unwrap_or(value).trim();
if !host.is_empty() {
@@ -213,7 +215,10 @@ mod tests {
#[test]
fn test_too_short() {
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
assert!(matches!(
extract_sni(&[0x16, 0x03]),
SniResult::NeedMoreData
));
}
#[test]
@@ -263,7 +268,8 @@ mod tests {
// Extension: type=0x0000 (SNI), length, data
let sni_extension = {
let mut e = Vec::new();
e.push(0x00); e.push(0x00); // SNI type
e.push(0x00);
e.push(0x00); // SNI type
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
e.push((sni_ext_data.len() & 0xFF) as u8);
e.extend_from_slice(&sni_ext_data);
@@ -283,16 +289,20 @@ mod tests {
let hello_body = {
let mut h = Vec::new();
// Client version TLS 1.2
h.push(0x03); h.push(0x03);
h.push(0x03);
h.push(0x03);
// Random (32 bytes)
h.extend_from_slice(&[0u8; 32]);
// Session ID length = 0
h.push(0x00);
// Cipher suites: length=2, one suite
h.push(0x00); h.push(0x02);
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
h.push(0x00);
h.push(0x02);
h.push(0x00);
h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
// Compression methods: length=1, null
h.push(0x01); h.push(0x00);
h.push(0x01);
h.push(0x00);
// Extensions
h.extend_from_slice(&extensions);
h
@@ -313,7 +323,8 @@ mod tests {
// TLS record: type=0x16, version TLS 1.0, length
let mut record = Vec::new();
record.push(0x16); // Handshake
record.push(0x03); record.push(0x01); // TLS 1.0
record.push(0x03);
record.push(0x01); // TLS 1.0
record.push(((handshake.len() >> 8) & 0xFF) as u8);
record.push((handshake.len() & 0xFF) as u8);
record.extend_from_slice(&handshake);
File diff suppressed because it is too large Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
use tracing::{debug, info};
use crate::tcp_listener::TlsCertConfig;
@@ -29,7 +29,9 @@ pub struct CertResolver {
impl CertResolver {
/// Build a resolver from PEM-encoded cert/key configs.
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
pub fn new(
configs: &HashMap<String, TlsCertConfig>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let provider = rustls::crypto::ring::default_provider();
let mut certs = HashMap::new();
@@ -38,8 +40,10 @@ impl CertResolver {
for (domain, cfg) in configs {
let cert_chain = load_certs(&cfg.cert_pem)?;
let key = load_private_key(&cfg.key_pem)?;
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
let ck = Arc::new(
CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
);
if domain == "*" {
fallback = Some(Arc::clone(&ck));
}
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
/// The returned acceptor can be reused across all connections (cheap Arc clone).
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_shared_tls_acceptor(
resolver: CertResolver,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let mut config = ServerConfig::builder()
.with_no_client_auth()
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
// Shared session cache — enables session ID resumption across connections
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
config.ticketer = rustls::crypto::ring::Ticketer::new()
.map_err(|e| format!("Ticketer: {}", e))?;
config.ticketer =
rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1");
info!(
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
);
Ok(TlsAcceptor::from(Arc::new(config)))
}
/// Build a TLS acceptor from PEM-encoded cert and key data.
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_tls_acceptor(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
build_tls_acceptor_with_config(cert_pem, key_pem, None)
}
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_tls_acceptor_h1_only(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let certs = load_certs(cert_pem)?;
let key = load_private_key(key_pem)?;
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
// Apply TLS version restrictions
let versions = resolve_tls_versions(route_tls.versions.as_deref());
let builder = ServerConfig::builder_with_protocol_versions(&versions);
builder
.with_no_client_auth()
.with_single_cert(certs, key)?
builder.with_no_client_auth().with_single_cert(certs, key)?
} else {
ServerConfig::builder()
.with_no_client_auth()
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
}
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
fn resolve_tls_versions(
versions: Option<&[String]>,
) -> Vec<&'static rustls::SupportedProtocolVersion> {
let versions = match versions {
Some(v) if !v.is_empty() => v,
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
@@ -207,7 +221,8 @@ pub async fn accept_tls(
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG.get_or_init(|| {
SHARED_CLIENT_CONFIG
.get_or_init(|| {
ensure_crypto_provider();
let config = rustls::ClientConfig::builder()
.dangerous()
@@ -215,7 +230,8 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
.with_no_client_auth();
info!("Built shared backend TLS client config with session resumption");
Arc::new(config)
}).clone()
})
.clone()
}
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| {
SHARED_CLIENT_CONFIG_ALPN
.get_or_init(|| {
ensure_crypto_provider();
let mut config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection");
info!(
"Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
);
Arc::new(config)
}).clone()
})
.clone()
}
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
@@ -249,7 +269,8 @@ pub async fn connect_tls(
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
stream.set_nodelay(true)?;
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) {
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
{
debug!("Failed to set keepalive on backend TLS socket: {}", e);
}
@@ -260,10 +281,12 @@ pub async fn connect_tls(
}
/// Load certificates from PEM string.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
fn load_certs(
pem: &str,
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes());
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()?;
let certs: Vec<CertificateDer<'static>> =
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() {
return Err("No certificates found in PEM data".into());
}
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
}
/// Load private key from PEM string.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
fn load_private_key(
pem: &str,
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes());
// Try PKCS8 first, then RSA, then EC
let key = rustls_pemfile::private_key(&mut reader)?
.ok_or("No private key found in PEM data")?;
let key =
rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
Ok(key)
}
@@ -17,14 +17,15 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use arc_swap::ArcSwap;
use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use rustproxy_config::{RouteActionType, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService;
@@ -62,6 +63,8 @@ pub struct UdpListenerManager {
route_cancels: Arc<DashMap<String, CancellationToken>>,
/// Shared connection registry for selective recycling.
connection_registry: Arc<ConnectionRegistry>,
/// Global ingress block policy, hot-reloadable without restarting listeners.
security_policy: Arc<ArcSwap<IpBlockList>>,
}
impl Drop for UdpListenerManager {
@@ -99,17 +102,26 @@ impl UdpListenerManager {
proxy_ips: Arc::new(Vec::new()),
route_cancels,
connection_registry,
security_policy: Arc::new(ArcSwap::from(Arc::new(IpBlockList::empty()))),
}
}
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
if !ips.is_empty() {
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len());
info!(
"UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs",
ips.len()
);
}
self.proxy_ips = Arc::new(ips);
}
/// Install the shared global ingress security policy handle.
///
/// The same `ArcSwap` is held by every listener task, so swapping its
/// contents later updates the block list at runtime without restarting
/// any listener.
pub fn set_security_policy(&mut self, shared_policy: Arc<ArcSwap<IpBlockList>>) {
    self.security_policy = shared_policy;
}
/// Set the H3 proxy service for HTTP/3 request handling.
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
self.h3_service = Some(svc);
@@ -142,7 +154,9 @@ impl UdpListenerManager {
// Check if any route on this port uses QUIC
let rm = self.route_manager.load();
let has_quic = rm.routes_for_port(port).iter().any(|r| {
r.action.udp.as_ref()
r.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref())
.is_some()
});
@@ -164,8 +178,10 @@ impl UdpListenerManager {
None,
Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint started on port {}", port);
} else {
// Proxy relay path: we own external socket, quinn on localhost
@@ -173,6 +189,7 @@ impl UdpListenerManager {
port,
tls,
Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
self.cancel_token.child_token(),
)?;
let endpoint_for_updates = relay.endpoint.clone();
@@ -187,13 +204,18 @@ impl UdpListenerManager {
Some(relay.real_client_map),
Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint with PROXY relay started on port {}", port);
}
return Ok(());
} else {
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
warn!(
"QUIC routes on port {} but no TLS config provided, falling back to raw UDP",
port
);
}
}
@@ -214,6 +236,7 @@ impl UdpListenerManager {
Arc::clone(&self.relay_writer),
self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, None));
@@ -254,8 +277,10 @@ impl UdpListenerManager {
}
debug!("UDP listener stopped on port {}", port);
}
info!("All UDP listeners stopped, {} sessions remaining",
self.session_table.session_count());
info!(
"All UDP listeners stopped, {} sessions remaining",
self.session_table.session_count()
);
}
/// Update TLS config on all active QUIC endpoints (cert refresh).
@@ -288,11 +313,15 @@ impl UdpListenerManager {
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
let rm = self.route_manager.load();
let upgrade_ports: Vec<u16> = self.listeners.iter()
let upgrade_ports: Vec<u16> = self
.listeners
.iter()
.filter(|(_, (_, endpoint))| endpoint.is_none())
.filter(|(port, _)| {
rm.routes_for_port(**port).iter().any(|r| {
r.action.udp.as_ref()
r.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref())
.is_some()
})
@@ -301,17 +330,23 @@ impl UdpListenerManager {
.collect();
for port in upgrade_ports {
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port);
info!(
"Upgrading raw UDP listener on port {} to QUIC endpoint",
port
);
// Stop the raw UDP listener task and drain sessions to release the socket
if let Some((handle, _)) = self.listeners.remove(&port) {
handle.abort();
}
let drained = self.session_table.drain_port(
port, &self.metrics, &self.conn_tracker,
);
let drained = self
.session_table
.drain_port(port, &self.metrics, &self.conn_tracker);
if drained > 0 {
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port);
debug!(
"Drained {} UDP sessions on port {} for QUIC upgrade",
drained, port
);
}
// Brief yield to let aborted tasks drop their socket references
@@ -326,11 +361,17 @@ impl UdpListenerManager {
match create_result {
Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port);
info!(
"QUIC endpoint started on port {} (upgraded from raw UDP)",
port
);
}
Err(e) => {
// Port may still be held — retry once after a brief delay
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e);
warn!(
"QUIC endpoint creation failed on port {}, retrying: {}",
port, e
);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let retry_result = if self.proxy_ips.is_empty() {
@@ -341,11 +382,17 @@ impl UdpListenerManager {
match retry_result {
Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port);
info!(
"QUIC endpoint started on port {} (upgraded from raw UDP, retry)",
port
);
}
Err(e2) => {
error!("Failed to upgrade port {} to QUIC after retry: {}. \
Rebinding as raw UDP.", port, e2);
error!(
"Failed to upgrade port {} to QUIC after retry: {}. \
Rebinding as raw UDP.",
port, e2
);
// Fallback: rebind as raw UDP so the port isn't dead
if let Ok(()) = self.rebind_raw_udp(port).await {
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
@@ -358,7 +405,11 @@ impl UdpListenerManager {
}
/// Create a direct QUIC endpoint (quinn owns the socket).
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
fn create_quic_direct(
&mut self,
port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<()> {
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
let endpoint_for_updates = endpoint.clone();
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
@@ -372,17 +423,24 @@ impl UdpListenerManager {
None,
Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
Ok(())
}
/// Create a QUIC endpoint with PROXY protocol relay.
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
fn create_quic_with_relay(
&mut self,
port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<()> {
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
port,
tls_config,
Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
self.cancel_token.child_token(),
)?;
let endpoint_for_updates = relay.endpoint.clone();
@@ -397,8 +455,10 @@ impl UdpListenerManager {
Some(relay.real_client_map),
Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
Ok(())
}
@@ -419,6 +479,7 @@ impl UdpListenerManager {
Arc::clone(&self.relay_writer),
self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, None));
@@ -458,7 +519,10 @@ impl UdpListenerManager {
info!("Datagram handler relay connected to {}", path);
}
Err(e) => {
error!("Failed to connect datagram handler relay to {}: {}", path, e);
error!(
"Failed to connect datagram handler relay to {}: {}",
path, e
);
}
}
}
@@ -514,6 +578,7 @@ impl UdpListenerManager {
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
cancel: CancellationToken,
proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) {
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
let mut buf = vec![0u8; 65535];
@@ -528,9 +593,11 @@ impl UdpListenerManager {
loop {
// Periodic cleanup: remove proxy_addr_map entries with no active session
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval {
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval
{
last_proxy_cleanup = tokio::time::Instant::now();
let stale: Vec<SocketAddr> = proxy_addr_map.iter()
let stale: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| {
let key: SessionKey = (*entry.key(), port);
session_table.get(&key).is_none()
@@ -538,7 +605,11 @@ impl UdpListenerManager {
.map(|entry| *entry.key())
.collect();
if !stale.is_empty() {
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port);
debug!(
"UDP proxy_addr_map cleanup: removing {} stale entries on port {}",
stale.len(),
port
);
for addr in stale {
proxy_addr_map.remove(&addr);
}
@@ -564,14 +635,20 @@ impl UdpListenerManager {
let datagram = &buf[..len];
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
let effective_client_ip =
if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
let session_key: SessionKey = (client_addr, port);
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) {
if session_table.get(&session_key).is_none()
&& !proxy_addr_map.contains_key(&client_addr)
{
// No session and no prior PROXY header — check for PROXY v2
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => {
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr);
debug!(
"UDP PROXY v2 from {}: real client {}",
client_addr, header.source_addr
);
proxy_addr_map.insert(client_addr, header.source_addr);
continue; // discard the PROXY v2 datagram
}
@@ -585,7 +662,8 @@ impl UdpListenerManager {
}
} else {
// Use real client IP if we've previously seen a PROXY v2 header
proxy_addr_map.get(&client_addr)
proxy_addr_map
.get(&client_addr)
.map(|r| r.ip())
.unwrap_or_else(|| client_addr.ip())
}
@@ -593,6 +671,15 @@ impl UdpListenerManager {
client_addr.ip()
};
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&effective_client_ip) {
debug!(
"UDP datagram from {} blocked by global security policy",
effective_client_ip
);
continue;
}
// Route matching — use effective (real) client IP
let rm = route_manager.load();
let ip_str = effective_client_ip.to_string();
@@ -611,7 +698,10 @@ impl UdpListenerManager {
let route_match = match rm.find_route(&ctx) {
Some(m) => m,
None => {
debug!("No UDP route matched for port {} from {}", port, client_addr);
debug!(
"No UDP route matched for port {} from {}",
port, client_addr
);
continue;
}
};
@@ -627,7 +717,9 @@ impl UdpListenerManager {
&client_addr,
port,
datagram,
).await {
)
.await
{
debug!("Failed to relay UDP datagram to TS: {}", e);
}
continue;
@@ -638,8 +730,10 @@ impl UdpListenerManager {
// Check datagram size
if len as u32 > udp_config.max_datagram_size {
debug!("UDP datagram too large ({} > {}) from {}, dropping",
len, udp_config.max_datagram_size, client_addr);
debug!(
"UDP datagram too large ({} > {}) from {}, dropping",
len, udp_config.max_datagram_size, client_addr
);
continue;
}
@@ -651,21 +745,27 @@ impl UdpListenerManager {
None => {
// New session — check per-IP limits using the real client IP
if !conn_tracker.try_accept(&effective_client_ip) {
debug!("UDP session rejected for {} (rate limit)", effective_client_ip);
debug!(
"UDP session rejected for {} (rate limit)",
effective_client_ip
);
continue;
}
if !session_table.can_create_session(
&effective_client_ip,
udp_config.max_sessions_per_ip,
) {
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip);
if !session_table
.can_create_session(&effective_client_ip, udp_config.max_sessions_per_ip)
{
debug!(
"UDP session rejected for {} (per-IP session limit)",
effective_client_ip
);
continue;
}
// Resolve target
let target = match route_match.target.or_else(|| {
route.action.targets.as_ref().and_then(|t| t.first())
}) {
let target = match route_match
.target
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
{
Some(t) => t,
None => {
warn!("No target for UDP route {:?}", route_id);
@@ -686,13 +786,18 @@ impl UdpListenerManager {
}
};
if let Err(e) = backend_socket.connect(&backend_addr).await {
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e);
error!(
"Failed to connect backend UDP socket to {}: {}",
backend_addr, e
);
continue;
}
let backend_socket = Arc::new(backend_socket);
debug!("New UDP session: {} -> {} (via port {}, real client {})",
client_addr, backend_addr, port, effective_client_ip);
debug!(
"New UDP session: {} -> {} (via port {}, real client {})",
client_addr, backend_addr, port, effective_client_ip
);
// Spawn return-path relay task
let session_cancel = CancellationToken::new();
@@ -709,7 +814,9 @@ impl UdpListenerManager {
let session = Arc::new(UdpSession {
backend_socket,
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
last_activity: std::sync::atomic::AtomicU64::new(
session_table.elapsed_ms(),
),
created_at: std::time::Instant::now(),
route_id: route_id.map(|s| s.to_string()),
source_ip: effective_client_ip,
@@ -718,7 +825,11 @@ impl UdpListenerManager {
cancel: session_cancel,
});
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) {
if !session_table.insert(
session_key,
Arc::clone(&session),
udp_config.max_sessions_per_ip,
) {
warn!("Failed to insert UDP session (race condition)");
continue;
}
@@ -735,7 +846,9 @@ impl UdpListenerManager {
// Forward datagram to backend
match session.backend_socket.send(datagram).await {
Ok(_) => {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
session
.last_activity
.store(session_table.elapsed_ms(), Ordering::Relaxed);
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
metrics.record_datagram_in();
}
@@ -779,7 +892,9 @@ impl UdpListenerManager {
Ok(_) => {
// Update session activity
if let Some(session) = session_table.get(&session_key) {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
session
.last_activity
.store(session_table.elapsed_ms(), Ordering::Relaxed);
}
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
metrics.record_datagram_out();
@@ -814,7 +929,8 @@ impl UdpListenerManager {
let json = serde_json::to_vec(&msg)?;
let mut guard = writer.lock().await;
let stream = guard.as_mut()
let stream = guard
.as_mut()
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
// Length-prefixed frame
@@ -879,9 +995,15 @@ impl UdpListenerManager {
}
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
let source_port = reply.get("sourcePort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let source_port = reply
.get("sourcePort")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u16;
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let payload_b64 = reply.get("payloadBase64").and_then(|v| v.as_str()).unwrap_or("");
let payload_b64 = reply
.get("payloadBase64")
.and_then(|v| v.as_str())
.unwrap_or("");
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
Ok(p) => p,
@@ -111,12 +111,15 @@ impl UdpSessionTable {
/// Look up an existing session.
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
self.sessions
.get(key)
.map(|entry| Arc::clone(entry.value()))
}
/// Check if we can create a new session for this IP (under the per-IP limit).
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
let count = self.ip_session_counts
let count = self
.ip_session_counts
.get(ip)
.map(|c| *c.value())
.unwrap_or(0);
@@ -124,12 +127,7 @@ impl UdpSessionTable {
}
/// Insert a new session. Returns false if per-IP limit exceeded.
pub fn insert(
&self,
key: SessionKey,
session: Arc<UdpSession>,
max_per_ip: u32,
) -> bool {
pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
let ip = session.source_ip;
// Atomically check and increment per-IP count
@@ -173,7 +171,9 @@ impl UdpSessionTable {
let mut removed = 0;
// Collect keys to remove (avoid holding DashMap refs during removal)
let stale_keys: Vec<SessionKey> = self.sessions.iter()
let stale_keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| {
let last = entry.value().last_activity.load(Ordering::Relaxed);
now_ms.saturating_sub(last) >= timeout_ms
@@ -185,7 +185,8 @@ impl UdpSessionTable {
if let Some(session) = self.remove(&key) {
debug!(
"UDP session expired: {} -> port {} (idle {}ms)",
session.client_addr, key.1,
session.client_addr,
key.1,
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
);
conn_tracker.connection_closed(&session.source_ip);
@@ -210,7 +211,9 @@ impl UdpSessionTable {
metrics: &MetricsCollector,
conn_tracker: &ConnectionTracker,
) -> usize {
let keys: Vec<SessionKey> = self.sessions.iter()
let keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| entry.key().1 == port)
.map(|entry| *entry.key())
.collect();
@@ -257,9 +260,8 @@ mod tests {
.enable_all()
.build()
.unwrap();
let backend_socket = rt.block_on(async {
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
});
let backend_socket =
rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
let child_cancel = cancel.child_token();
let return_task = rt.spawn(async move {
+1 -1
View File
@@ -3,7 +3,7 @@
//! Route matching engine for RustProxy.
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
pub mod route_manager;
pub mod matchers;
pub mod route_manager;
pub use route_manager::*;
@@ -1,6 +1,6 @@
use ipnet::IpNet;
use std::net::IpAddr;
use std::str::FromStr;
use ipnet::IpNet;
/// Match an IP address against a pattern.
///
@@ -85,7 +85,10 @@ fn wildcard_to_cidr(pattern: &str) -> Option<String> {
}
}
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len))
Some(format!(
"{}.{}.{}.{}/{}",
octets[0], octets[1], octets[2], octets[3], prefix_len
))
}
#[cfg(test)]
@@ -1,9 +1,9 @@
pub mod domain;
pub mod path;
pub mod ip;
pub mod header;
pub mod ip;
pub mod path;
pub use domain::*;
pub use path::*;
pub use ip::*;
pub use header::*;
pub use ip::*;
pub use path::*;
@@ -1,7 +1,7 @@
use std::collections::HashMap;
use rustproxy_config::{RouteConfig, RouteTarget, TransportProtocol, TlsMode};
use crate::matchers;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode, TransportProtocol};
/// Context for route matching (subset of connection info).
pub struct MatchContext<'a> {
@@ -42,19 +42,14 @@ impl RouteManager {
};
// Filter enabled routes and sort by priority
let mut enabled_routes: Vec<RouteConfig> = routes
.into_iter()
.filter(|r| r.is_enabled())
.collect();
let mut enabled_routes: Vec<RouteConfig> =
routes.into_iter().filter(|r| r.is_enabled()).collect();
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
// Build port index
for (idx, route) in enabled_routes.iter().enumerate() {
for port in route.listening_ports() {
manager.port_index
.entry(port)
.or_default()
.push(idx);
manager.port_index.entry(port).or_default().push(idx);
}
}
@@ -66,7 +61,9 @@ impl RouteManager {
/// Used to skip expensive header HashMap construction when no route needs it.
pub fn any_route_has_headers(&self, port: u16) -> bool {
if let Some(indices) = self.port_index.get(&port) {
indices.iter().any(|&idx| self.routes[idx].route_match.headers.is_some())
indices
.iter()
.any(|&idx| self.routes[idx].route_match.headers.is_some())
} else {
false
}
@@ -99,8 +96,8 @@ impl RouteManager {
let ctx_transport = ctx.transport.as_ref();
match (route_transport, ctx_transport) {
// Route requires UDP only — reject non-UDP contexts
(Some(TransportProtocol::Udp), None) |
(Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
(Some(TransportProtocol::Udp), None)
| (Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
// Route requires TCP only — reject UDP contexts
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
// Route has no transport (default = TCP) — reject UDP contexts
@@ -196,7 +193,11 @@ impl RouteManager {
}
/// Find the best matching target within a route.
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
fn find_target<'a>(
&self,
route: &'a RouteConfig,
ctx: &MatchContext<'_>,
) -> Option<&'a RouteTarget> {
let targets = route.action.targets.as_ref()?;
if targets.len() == 1 && targets[0].target_match.is_none() {
@@ -223,17 +224,11 @@ impl RouteManager {
}
// Fall back to first target without match criteria
best.or_else(|| {
targets.iter().find(|t| t.target_match.is_none())
})
best.or_else(|| targets.iter().find(|t| t.target_match.is_none()))
}
/// Check if a target match criteria matches the context.
fn matches_target(
&self,
tm: &rustproxy_config::TargetMatch,
ctx: &MatchContext<'_>,
) -> bool {
fn matches_target(&self, tm: &rustproxy_config::TargetMatch, ctx: &MatchContext<'_>) -> bool {
// Port matching
if let Some(ref ports) = tm.ports {
if !ports.contains(&ctx.port) {
@@ -298,9 +293,7 @@ impl RouteManager {
// If multiple passthrough routes on same port, SNI is needed
let passthrough_routes: Vec<_> = routes
.iter()
.filter(|r| {
r.tls_mode() == Some(&TlsMode::Passthrough)
})
.filter(|r| r.tls_mode() == Some(&TlsMode::Passthrough))
.collect();
if passthrough_routes.len() > 1 {
@@ -419,7 +412,11 @@ mod tests {
let result = manager.find_route(&ctx).unwrap();
// Should match the higher-priority specific route
assert!(result.route.route_match.domains.as_ref()
assert!(result
.route
.route_match
.domains
.as_ref()
.map(|d| d.to_vec())
.unwrap()
.contains(&"api.example.com"));
@@ -619,8 +616,14 @@ mod tests {
let result = manager.find_route(&ctx);
assert!(result.is_some());
let matched_domains = result.unwrap().route.route_match.domains.as_ref()
.map(|d| d.to_vec()).unwrap();
let matched_domains = result
.unwrap()
.route
.route_match
.domains
.as_ref()
.map(|d| d.to_vec())
.unwrap();
assert!(matched_domains.contains(&"*"));
}
@@ -735,7 +738,11 @@ mod tests {
assert_eq!(result.target.unwrap().host.first(), "default-backend");
}
fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig {
fn make_route_with_protocol(
port: u16,
domain: Option<&str>,
protocol: Option<&str>,
) -> RouteConfig {
let mut route = make_route(port, domain, 0);
route.route_match.protocol = protocol.map(|s| s.to_string());
route
@@ -1029,8 +1036,10 @@ mod tests {
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&ctx).is_some(),
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes");
assert!(
manager.find_route(&ctx).is_some(),
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"
);
}
#[test]
@@ -1051,7 +1060,9 @@ mod tests {
transport: None, // TCP (default)
};
assert!(manager.find_route(&ctx).is_none(),
"TCP TLS without SNI should NOT match domain-restricted routes");
assert!(
manager.find_route(&ctx).is_none(),
"TCP TLS without SNI should NOT match domain-restricted routes"
);
}
}
@@ -1,5 +1,5 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
/// Basic auth validator.
pub struct BasicAuthValidator {
+43 -24
View File
@@ -21,7 +21,7 @@ struct DomainScopedEntry {
}
/// Represents an IP pattern for matching.
#[derive(Debug)]
#[derive(Debug, Clone)]
enum IpPattern {
/// Exact IP match
Exact(IpAddr),
@@ -31,6 +31,37 @@ enum IpPattern {
Wildcard,
}
/// Compiled block list for early ingress filtering.
///
/// Holds pre-parsed IP patterns (exact address, CIDR range, or wildcard)
/// so the per-connection / per-datagram check is a cheap linear scan with
/// no string parsing on the hot path.
#[derive(Debug, Clone)]
pub struct IpBlockList {
    // Patterns compiled once at construction time via `IpPattern::parse`.
    block_list: Vec<IpPattern>,
}

impl IpBlockList {
    /// Compile a block list from string patterns.
    ///
    /// Each entry is parsed by `IpPattern::parse` (exact IPs, CIDR, or
    /// wildcard syntax — see that function for the accepted forms), so
    /// parsing cost is paid once, up front.
    pub fn new(block_list: &[String]) -> Self {
        Self {
            block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
        }
    }

    /// An empty list that blocks nothing — the initial hot-swappable
    /// policy before any configuration arrives.
    pub fn empty() -> Self {
        Self {
            block_list: Vec::new(),
        }
    }

    /// True when no patterns are configured; callers use this to skip
    /// the per-packet check entirely.
    pub fn is_empty(&self) -> bool {
        self.block_list.is_empty()
    }

    /// Check whether `ip` matches any blocked pattern.
    ///
    /// The address is normalized first via `IpFilter::normalize_ip`
    /// (presumably canonicalizing IPv4-mapped IPv6 forms — confirm there)
    /// so equivalent representations compare consistently. Short-circuits
    /// on the first matching pattern.
    pub fn is_blocked(&self, ip: &IpAddr) -> bool {
        let normalized = IpFilter::normalize_ip(ip);
        self.block_list
            .iter()
            .any(|pattern| pattern.matches(&normalized))
    }
}

impl Default for IpBlockList {
    /// Equivalent to [`IpBlockList::empty`]: a list that blocks nothing.
    fn default() -> Self {
        Self::empty()
    }
}
impl IpPattern {
fn parse(s: &str) -> Self {
let s = s.trim();
@@ -68,8 +99,7 @@ fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
}
if p.starts_with("*.") {
let suffix = &p[1..]; // e.g., ".abc.xyz"
d.len() > suffix.len()
&& d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
d.len() > suffix.len() && d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
} else {
false
}
@@ -127,7 +157,11 @@ impl IpFilter {
if let Some(req_domain) = domain {
for entry in &self.domain_scoped {
if entry.pattern.matches(ip) {
if entry.domains.iter().any(|d| domain_matches_pattern(d, req_domain)) {
if entry
.domains
.iter()
.any(|d| domain_matches_pattern(d, req_domain))
{
return true;
}
}
@@ -212,10 +246,7 @@ mod tests {
#[test]
fn test_block_trumps_allow() {
let filter = IpFilter::new(
&[plain("10.0.0.0/8")],
&["10.0.0.5".to_string()],
);
let filter = IpFilter::new(&[plain("10.0.0.0/8")], &["10.0.0.5".to_string()]);
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
assert!(!filter.is_allowed(&blocked));
@@ -255,30 +286,21 @@ mod tests {
#[test]
fn test_domain_scoped_allows_matching_domain() {
let filter = IpFilter::new(
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
&[],
);
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
}
#[test]
fn test_domain_scoped_denies_non_matching_domain() {
let filter = IpFilter::new(
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
&[],
);
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
}
#[test]
fn test_domain_scoped_denies_without_domain() {
let filter = IpFilter::new(
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
&[],
);
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
// Without domain context, domain-scoped entries cannot match
assert!(!filter.is_allowed_for_domain(&ip, None));
@@ -286,10 +308,7 @@ mod tests {
#[test]
fn test_domain_scoped_wildcard_domain() {
let filter = IpFilter::new(
&[scoped("10.8.0.2", &["*.abc.xyz"])],
&[],
);
let filter = IpFilter::new(&[scoped("10.8.0.2", &["*.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
@@ -1,4 +1,4 @@
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
/// JWT claims (minimal structure).
@@ -160,10 +160,7 @@ mod tests {
#[test]
fn test_extract_token_bearer() {
assert_eq!(
JwtValidator::extract_token("Bearer abc123"),
Some("abc123")
);
assert_eq!(JwtValidator::extract_token("Bearer abc123"), Some("abc123"));
}
#[test]
+4 -4
View File
@@ -2,12 +2,12 @@
//!
//! IP filtering, rate limiting, and authentication for RustProxy.
pub mod ip_filter;
pub mod rate_limiter;
pub mod basic_auth;
pub mod ip_filter;
pub mod jwt_auth;
pub mod rate_limiter;
pub use ip_filter::*;
pub use rate_limiter::*;
pub use basic_auth::*;
pub use ip_filter::*;
pub use jwt_auth::*;
pub use rate_limiter::*;
+14 -13
View File
@@ -4,8 +4,7 @@
//! Account credentials are ephemeral — the consumer owns all persistence.
use instant_acme::{
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
AccountCredentials,
Account, AccountCredentials, ChallengeType, Identifier, NewAccount, NewOrder, OrderStatus,
};
use rcgen::{CertificateParams, KeyPair};
use thiserror::Error;
@@ -89,7 +88,11 @@ impl AcmeClient {
F: FnOnce(PendingChallenge) -> Fut,
Fut: std::future::Future<Output = Result<(), AcmeError>>,
{
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
info!(
"Starting ACME provisioning for {} via {}",
domain,
self.directory_url()
);
// 1. Get or create ACME account
let account = self.get_or_create_account().await?;
@@ -170,14 +173,14 @@ impl AcmeClient {
debug!("Order ready, finalizing...");
// 6. Generate CSR and finalize
let key_pair = KeyPair::generate().map_err(|e| {
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
})?;
let key_pair = KeyPair::generate()
.map_err(|e| AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)))?;
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
})?;
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
let mut params = CertificateParams::new(vec![domain.to_string()])
.map_err(|e| AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)))?;
params
.distinguished_name
.push(rcgen::DnType::CommonName, domain);
let csr = params.serialize_request(&key_pair).map_err(|e| {
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
@@ -219,9 +222,7 @@ impl AcmeClient {
.certificate()
.await
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
.ok_or_else(|| {
AcmeError::FinalizationFailed("No certificate returned".to_string())
})?;
.ok_or_else(|| AcmeError::FinalizationFailed("No certificate returned".to_string()))?;
let private_key_pem = key_pair.serialize_pem();
+11 -13
View File
@@ -2,8 +2,8 @@ use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use tracing::info;
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
use crate::acme::AcmeClient;
use crate::cert_store::{CertBundle, CertMetadata, CertSource, CertStore};
#[derive(Debug, Error)]
pub enum CertManagerError {
@@ -45,17 +45,13 @@ impl CertManager {
/// Create an ACME client using this manager's configuration.
/// Returns None if no ACME email is configured.
pub fn acme_client(&self) -> Option<AcmeClient> {
self.acme_email.as_ref().map(|email| {
AcmeClient::new(email.clone(), self.use_production)
})
self.acme_email
.as_ref()
.map(|email| AcmeClient::new(email.clone(), self.use_production))
}
/// Load a static certificate into the store (infallible — pure cache insert).
pub fn load_static(
&mut self,
domain: String,
bundle: CertBundle,
) {
pub fn load_static(&mut self, domain: String, bundle: CertBundle) {
self.store.store(domain, bundle);
}
@@ -108,20 +104,22 @@ impl CertManager {
F: FnOnce(String, String) -> Fut,
Fut: std::future::Future<Output = ()>,
{
let acme_client = self.acme_client()
.ok_or(CertManagerError::NoEmail)?;
let acme_client = self.acme_client().ok_or(CertManagerError::NoEmail)?;
info!("Renewing certificate for {}", domain);
let domain_owned = domain.to_string();
let result = acme_client.provision(&domain_owned, |pending| {
let result = acme_client
.provision(&domain_owned, |pending| {
let token = pending.token.clone();
let key_auth = pending.key_authorization.clone();
async move {
challenge_setup(token, key_auth).await;
Ok(())
}
}).await.map_err(|e| CertManagerError::AcmeFailure {
})
.await
.map_err(|e| CertManagerError::AcmeFailure {
domain: domain.to_string(),
message: e.to_string(),
})?;
+15 -6
View File
@@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Certificate metadata stored alongside certs.
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -90,8 +90,10 @@ mod tests {
fn make_test_bundle(domain: &str) -> CertBundle {
CertBundle {
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n"
.to_string(),
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n"
.to_string(),
ca_pem: None,
metadata: CertMetadata {
domain: domain.to_string(),
@@ -122,7 +124,8 @@ mod tests {
let mut store = CertStore::new();
let mut bundle = make_test_bundle("secure.com");
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
bundle.ca_pem =
Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
store.store("secure.com".to_string(), bundle);
let loaded = store.get("secure.com").unwrap();
@@ -147,7 +150,10 @@ mod tests {
fn test_remove_cert() {
let mut store = CertStore::new();
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com"));
store.store(
"remove-me.com".to_string(),
make_test_bundle("remove-me.com"),
);
assert!(store.has("remove-me.com"));
let removed = store.remove("remove-me.com");
@@ -165,7 +171,10 @@ mod tests {
fn test_wildcard_domain() {
let mut store = CertStore::new();
store.store("*.example.com".to_string(), make_test_bundle("*.example.com"));
store.store(
"*.example.com".to_string(),
make_test_bundle("*.example.com"),
);
assert!(store.has("*.example.com"));
let loaded = store.get("*.example.com").unwrap();
+3 -3
View File
@@ -3,11 +3,11 @@
//! TLS certificate management for RustProxy.
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
pub mod cert_store;
pub mod cert_manager;
pub mod acme;
pub mod cert_manager;
pub mod cert_store;
pub mod sni_resolver;
pub use cert_store::*;
pub use cert_manager::*;
pub use cert_store::*;
pub use sni_resolver::*;
+12 -8
View File
@@ -13,7 +13,7 @@ use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, error};
use tracing::{debug, error, info};
/// ACME HTTP-01 challenge server.
pub struct ChallengeServer {
@@ -47,7 +47,10 @@ impl ChallengeServer {
}
/// Start the challenge server on the given port.
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pub async fn start(
&mut self,
port: u16,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = format!("0.0.0.0:{}", port);
let listener = TcpListener::bind(&addr).await?;
info!("ACME challenge server listening on port {}", port);
@@ -101,10 +104,7 @@ impl ChallengeServer {
pub async fn stop(&mut self) {
self.cancel.cancel();
if let Some(handle) = self.handle.take() {
let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
handle,
).await;
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
}
self.challenges.clear();
self.cancel = CancellationToken::new();
@@ -154,10 +154,14 @@ mod tests {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Fetch the challenge
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
let client = tokio::net::TcpStream::connect("127.0.0.1:19900")
.await
.unwrap();
let io = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move { let _ = conn.await; });
tokio::spawn(async move {
let _ = conn.await;
});
let req = Request::get("/.well-known/acme-challenge/test-token")
.body(Full::new(Bytes::new()))
+262 -105
View File
@@ -57,24 +57,27 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwap;
use anyhow::Result;
use tracing::{info, warn, debug, error};
use arc_swap::ArcSwap;
use tracing::{debug, error, info, warn};
// Re-export key types
pub use rustproxy_config;
pub use rustproxy_routing;
pub use rustproxy_passthrough;
pub use rustproxy_tls;
pub use rustproxy_http;
pub use rustproxy_metrics;
pub use rustproxy_passthrough;
pub use rustproxy_routing;
pub use rustproxy_security;
pub use rustproxy_tls;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec};
use rustproxy_config::{CertificateSpec, RouteConfig, RustProxyOptions, TlsMode};
use rustproxy_metrics::{Metrics, MetricsCollector, Statistics};
use rustproxy_passthrough::{
ConnectionConfig, TcpListenerManager, TlsCertConfig, UdpListenerManager,
};
use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, UdpListenerManager, TlsCertConfig, ConnectionConfig};
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use rustproxy_security::IpBlockList;
use rustproxy_tls::{CertBundle, CertManager, CertMetadata, CertSource, CertStore};
use tokio_util::sync::CancellationToken;
/// Certificate status.
@@ -106,6 +109,8 @@ pub struct RustProxy {
loaded_certs: HashMap<String, TlsCertConfig>,
/// Cancellation token for cooperative shutdown of background tasks.
cancel_token: CancellationToken,
/// Shared global ingress blocklist, hot-reloadable across TCP/UDP listeners.
security_policy: Arc<ArcSwap<IpBlockList>>,
}
impl RustProxy {
@@ -127,13 +132,19 @@ impl RustProxy {
let route_manager = RouteManager::new(options.routes.clone());
// Set up certificate manager if ACME is configured
let cert_manager = Self::build_cert_manager(&options)
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
let cert_manager =
Self::build_cert_manager(&options).map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
let retention = options.metrics.as_ref()
let retention = options
.metrics
.as_ref()
.and_then(|m| m.retention_seconds)
.unwrap_or(3600) as usize;
let security_policy = Arc::new(ArcSwap::from(Arc::new(Self::build_ip_block_list(
options.security_policy.as_ref(),
))));
Ok(Self {
options,
route_table: ArcSwap::from(Arc::new(route_manager)),
@@ -149,6 +160,7 @@ impl RustProxy {
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
loaded_certs: HashMap::new(),
cancel_token: CancellationToken::new(),
security_policy,
})
}
@@ -163,11 +175,13 @@ impl RustProxy {
// Apply default target if route has no targets
if route.action.targets.is_none() {
if let Some(ref default_target) = defaults.target {
debug!("Applying default target {}:{} to route {:?}",
default_target.host, default_target.port,
route.name.as_deref().unwrap_or("unnamed"));
route.action.targets = Some(vec![
rustproxy_config::RouteTarget {
debug!(
"Applying default target {}:{} to route {:?}",
default_target.host,
default_target.port,
route.name.as_deref().unwrap_or("unnamed")
);
route.action.targets = Some(vec![rustproxy_config::RouteTarget {
target_match: None,
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
port: rustproxy_config::PortSpec::Fixed(default_target.port),
@@ -179,8 +193,7 @@ impl RustProxy {
advanced: None,
backend_transport: None,
priority: None,
}
]);
}]);
}
}
@@ -199,7 +212,10 @@ impl RustProxy {
if let Some(ref allow_list) = default_security.ip_allow_list {
security.ip_allow_list = Some(
allow_list.iter().map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone())).collect()
allow_list
.iter()
.map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone()))
.collect(),
);
}
if let Some(ref block_list) = default_security.ip_block_list {
@@ -208,8 +224,10 @@ impl RustProxy {
// Only apply if there's something meaningful
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
debug!("Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed"));
debug!(
"Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed")
);
route.security = Some(security);
}
}
@@ -224,13 +242,17 @@ impl RustProxy {
return None;
}
let email = acme.email.clone()
.or_else(|| acme.account_email.clone());
let email = acme.email.clone().or_else(|| acme.account_email.clone());
let use_production = acme.use_production.unwrap_or(false);
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
let store = CertStore::new();
Some(CertManager::new(store, email, use_production, renew_before_days))
Some(CertManager::new(
store,
email,
use_production,
renew_before_days,
))
}
/// Build ConnectionConfig from RustProxyOptions.
@@ -248,7 +270,10 @@ impl RustProxy {
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
proxy_ips: options.proxy_ips.as_deref().unwrap_or(&[])
proxy_ips: options
.proxy_ips
.as_deref()
.unwrap_or(&[])
.iter()
.filter_map(|s| s.parse::<std::net::IpAddr>().ok())
.collect(),
@@ -258,6 +283,22 @@ impl RustProxy {
}
}
/// Compile the global ingress blocklist from an optional security policy.
///
/// Merges `blocked_ips` and `blocked_cidrs` into a single pattern list; a
/// missing policy (or one with neither field set) yields an empty block list,
/// i.e. nothing is blocked.
fn build_ip_block_list(policy: Option<&rustproxy_config::SecurityPolicy>) -> IpBlockList {
    match policy {
        None => IpBlockList::empty(),
        Some(p) => {
            // Both fields are Option<Vec<String>>; `iter().flatten()` walks the
            // inner entries when present and yields nothing when absent.
            let patterns: Vec<String> = p
                .blocked_ips
                .iter()
                .flatten()
                .chain(p.blocked_cidrs.iter().flatten())
                .cloned()
                .collect();
            IpBlockList::new(&patterns)
        }
    }
}
/// Start the proxy, binding to all configured ports.
pub async fn start(&mut self) -> Result<()> {
if self.started {
@@ -272,7 +313,11 @@ impl RustProxy {
let route_manager = self.route_table.load();
let ports = route_manager.listening_ports();
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
info!(
"Configured {} routes on {} ports",
route_manager.route_count(),
ports.len()
);
// Create TCP listener manager with metrics
let mut listener = TcpListenerManager::with_metrics(
@@ -282,7 +327,8 @@ impl RustProxy {
// Apply connection config from options
let conn_config = Self::build_connection_config(&self.options);
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
debug!(
"Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
conn_config.connection_timeout_ms,
conn_config.initial_data_timeout_ms,
conn_config.socket_timeout_ms,
@@ -291,6 +337,7 @@ impl RustProxy {
// Clone proxy_ips before conn_config is moved into the TCP listener
let udp_proxy_ips = conn_config.proxy_ips.clone();
listener.set_connection_config(conn_config);
listener.set_security_policy(Arc::clone(&self.security_policy));
// Share the socket-handler relay path with the listener
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
@@ -303,10 +350,13 @@ impl RustProxy {
let cm = cm.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
tls_configs.insert(
domain.clone(),
TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
},
);
}
}
}
@@ -330,7 +380,9 @@ impl RustProxy {
let mut tcp_ports = std::collections::HashSet::new();
let mut udp_ports = std::collections::HashSet::new();
for route in &self.options.routes {
if !route.is_enabled() { continue; }
if !route.is_enabled() {
continue;
}
let transport = route.route_match.transport.as_ref();
let route_ports = route.route_match.ports.to_ports();
for port in route_ports {
@@ -371,6 +423,7 @@ impl RustProxy {
connection_registry,
);
udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
// Share HttpProxyService with H3 — same route matching, connection
// pool, and ALPN protocol detection as the TCP/HTTP path.
@@ -379,10 +432,15 @@ impl RustProxy {
udp_mgr.set_h3_service(Arc::new(h3_svc));
for port in &udp_ports {
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?;
udp_mgr
.add_port_with_tls(*port, quic_tls_config.clone())
.await?;
}
info!("UDP listeners started on {} ports: {:?}",
udp_ports.len(), udp_mgr.listening_ports());
info!(
"UDP listeners started on {} ports: {:?}",
udp_ports.len(),
udp_mgr.listening_ports()
);
self.udp_listener_manager = Some(udp_mgr);
}
@@ -391,16 +449,22 @@ impl RustProxy {
// Start the throughput sampling task with cooperative cancellation
let metrics = Arc::clone(&self.metrics);
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone();
let conn_tracker = self
.listener_manager
.as_ref()
.unwrap()
.conn_tracker()
.clone();
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
let interval_ms = self.options.metrics.as_ref()
let interval_ms = self
.options
.metrics
.as_ref()
.and_then(|m| m.sample_interval_ms)
.unwrap_or(1000);
let sampling_cancel = self.cancel_token.clone();
self.sampling_handle = Some(tokio::spawn(async move {
let mut interval = tokio::time::interval(
std::time::Duration::from_millis(interval_ms)
);
let mut interval = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
loop {
tokio::select! {
_ = sampling_cancel.cancelled() => break,
@@ -442,7 +506,10 @@ impl RustProxy {
continue;
}
let cert_spec = route.action.tls.as_ref()
let cert_spec = route
.action
.tls
.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Auto(_)) = cert_spec {
@@ -466,16 +533,25 @@ impl RustProxy {
return;
}
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
info!(
"Auto-provisioning certificates for {} domains",
domains_to_provision.len()
);
// Start challenge server
let acme_port = self.options.acme.as_ref()
let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut challenge_server = challenge_server::ChallengeServer::new();
if let Err(e) = challenge_server.start(acme_port).await {
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
error!(
"Failed to start ACME challenge server on port {}: {}",
acme_port, e
);
return;
}
@@ -488,13 +564,15 @@ impl RustProxy {
if let Some(acme_client) = acme_client {
let challenge_server_ref = &challenge_server;
let result = acme_client.provision(domain, |pending| {
let result = acme_client
.provision(domain, |pending| {
challenge_server_ref.set_challenge(
pending.token.clone(),
pending.key_authorization.clone(),
);
async move { Ok(()) }
}).await;
})
.await;
match result {
Ok((cert_pem, key_pem)) => {
@@ -539,7 +617,10 @@ impl RustProxy {
None => return,
};
let auto_renew = self.options.acme.as_ref()
let auto_renew = self
.options
.acme
.as_ref()
.and_then(|a| a.auto_renew)
.unwrap_or(true);
@@ -547,11 +628,17 @@ impl RustProxy {
return;
}
let check_interval_hours = self.options.acme.as_ref()
let check_interval_hours = self
.options
.acme
.as_ref()
.and_then(|a| a.renew_check_interval_hours)
.unwrap_or(24);
let acme_port = self.options.acme.as_ref()
let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
@@ -664,8 +751,7 @@ impl RustProxy {
/// Update routes atomically (hot-reload).
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
// Validate new routes
rustproxy_config::validate_routes(&routes)
.map_err(|errors| {
rustproxy_config::validate_routes(&routes).map_err(|errors| {
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
})?;
@@ -673,8 +759,11 @@ impl RustProxy {
let new_manager = RouteManager::new(routes.clone());
let new_ports = new_manager.listening_ports();
info!("Updating routes: {} routes on {} ports",
new_manager.route_count(), new_ports.len());
info!(
"Updating routes: {} routes on {} ports",
new_manager.route_count(),
new_ports.len()
);
// Get old ports
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
@@ -684,28 +773,35 @@ impl RustProxy {
};
// Prune per-route metrics for route IDs that no longer exist
let active_route_ids: HashSet<String> = routes.iter()
.filter_map(|r| r.id.clone())
.collect();
let active_route_ids: HashSet<String> =
routes.iter().filter_map(|r| r.id.clone()).collect();
self.metrics.retain_routes(&active_route_ids);
// Prune per-backend metrics for backends no longer in any route target.
// For PortSpec::Preserve routes, expand across all listening ports since
// the actual runtime port depends on the incoming connection.
let listening_ports = self.get_listening_ports();
let active_backends: HashSet<String> = routes.iter()
let active_backends: HashSet<String> = routes
.iter()
.filter_map(|r| r.action.targets.as_ref())
.flat_map(|targets| targets.iter())
.flat_map(|target| {
let hosts: Vec<String> = target.host.to_vec().into_iter().map(|s| s.to_string()).collect();
let hosts: Vec<String> = target
.host
.to_vec()
.into_iter()
.map(|s| s.to_string())
.collect();
match &target.port {
rustproxy_config::PortSpec::Fixed(p) => {
hosts.into_iter().map(|h| format!("{}:{}", h, p)).collect::<Vec<_>>()
}
rustproxy_config::PortSpec::Fixed(p) => hosts
.into_iter()
.map(|h| format!("{}:{}", h, p))
.collect::<Vec<_>>(),
_ => {
// Preserve/special: expand across all listening ports
let lp = &listening_ports;
hosts.into_iter()
hosts
.into_iter()
.flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p)))
.collect::<Vec<_>>()
}
@@ -733,10 +829,13 @@ impl RustProxy {
let cm = cm_arc.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
tls_configs.insert(
domain.clone(),
TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
},
);
}
}
}
@@ -753,7 +852,9 @@ impl RustProxy {
// Cancel connections on routes that were removed or disabled
listener.invalidate_removed_routes(&active_route_ids);
// Clean up registry entries for removed routes
listener.connection_registry().cleanup_removed_routes(&active_route_ids);
listener
.connection_registry()
.cleanup_removed_routes(&active_route_ids);
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
listener.prune_http_proxy_caches(&active_route_ids);
@@ -766,9 +867,10 @@ impl RustProxy {
None => continue,
};
// Find corresponding old route
let old_route = old_manager.routes().iter().find(|r| {
r.id.as_deref() == Some(new_id)
});
let old_route = old_manager
.routes()
.iter()
.find(|r| r.id.as_deref() == Some(new_id));
let old_route = match old_route {
Some(r) => r,
None => continue, // new route, no existing connections to recycle
@@ -812,11 +914,13 @@ impl RustProxy {
{
let mut new_udp_ports = HashSet::new();
for route in &routes {
if !route.is_enabled() { continue; }
if !route.is_enabled() {
continue;
}
let transport = route.route_match.transport.as_ref();
match transport {
Some(rustproxy_config::TransportProtocol::Udp) |
Some(rustproxy_config::TransportProtocol::All) => {
Some(rustproxy_config::TransportProtocol::Udp)
| Some(rustproxy_config::TransportProtocol::All) => {
for port in route.route_match.ports.to_ports() {
new_udp_ports.insert(port);
}
@@ -825,7 +929,8 @@ impl RustProxy {
}
}
let old_udp_ports: HashSet<u16> = self.udp_listener_manager
let old_udp_ports: HashSet<u16> = self
.udp_listener_manager
.as_ref()
.map(|u| u.listening_ports().into_iter().collect())
.unwrap_or_default();
@@ -847,6 +952,7 @@ impl RustProxy {
connection_registry,
);
udp_mgr.set_proxy_ips(conn_config.proxy_ips);
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
// Wire up H3ProxyService so QUIC connections can serve HTTP/3
let http_proxy = listener.http_proxy().clone();
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
@@ -898,56 +1004,77 @@ impl RustProxy {
/// Provision a certificate for a named route.
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
let cm_arc = self.cert_manager.as_ref()
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
let cm_arc = self.cert_manager.as_ref().ok_or_else(|| {
anyhow::anyhow!("No certificate manager configured (ACME not enabled)")
})?;
// Find the route by name
let route = self.options.routes.iter()
let route = self
.options
.routes
.iter()
.find(|r| r.name.as_deref() == Some(route_name))
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
let domain = route.route_match.domains.as_ref()
let domain = route
.route_match
.domains
.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
info!(
"Provisioning certificate for route '{}' (domain: {})",
route_name, domain
);
// Start challenge server
let acme_port = self.options.acme.as_ref()
let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut cs = challenge_server::ChallengeServer::new();
cs.start(acme_port).await
cs.start(acme_port)
.await
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
let cs_ref = &cs;
let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(&domain, |token, key_auth| {
let result = cm
.renew_domain(&domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
}).await;
})
.await;
drop(cm);
cs.stop().await;
let bundle = result
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
let bundle = result.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
// Hot-swap into TLS configs
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig {
tls_configs.insert(
domain.clone(),
TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
},
);
{
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig {
tls_configs.insert(
d.clone(),
TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
},
);
}
}
}
@@ -966,7 +1093,10 @@ impl RustProxy {
}
}
info!("Certificate provisioned and loaded for route '{}'", route_name);
info!(
"Certificate provisioned and loaded for route '{}'",
route_name
);
Ok(())
}
@@ -978,10 +1108,16 @@ impl RustProxy {
/// Get the status of a certificate for a named route.
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
let route = self.options.routes.iter()
let route = self
.options
.routes
.iter()
.find(|r| r.name.as_deref() == Some(route_name))?;
let domain = route.route_match.domains.as_ref()
let domain = route
.route_match
.domains
.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
if let Some(ref cm_arc) = self.cert_manager {
@@ -1010,8 +1146,9 @@ impl RustProxy {
let mut metrics = self.metrics.snapshot();
if let Some(ref lm) = self.listener_manager {
let entries = lm.http_proxy().protocol_cache_snapshot();
metrics.detected_protocols = entries.into_iter().map(|e| {
rustproxy_metrics::ProtocolCacheEntryMetric {
metrics.detected_protocols = entries
.into_iter()
.map(|e| rustproxy_metrics::ProtocolCacheEntryMetric {
host: e.host,
port: e.port,
domain: e.domain,
@@ -1026,8 +1163,8 @@ impl RustProxy {
h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs,
h2_consecutive_failures: e.h2_consecutive_failures,
h3_consecutive_failures: e.h3_consecutive_failures,
}
}).collect();
})
.collect();
}
metrics
}
@@ -1058,9 +1195,7 @@ impl RustProxy {
/// Get statistics snapshot.
pub fn get_statistics(&self) -> Statistics {
let uptime = self.started_at
.map(|t| t.elapsed().as_secs())
.unwrap_or(0);
let uptime = self.started_at.map(|t| t.elapsed().as_secs()).unwrap_or(0);
Statistics {
active_connections: self.metrics.active_connections(),
@@ -1071,6 +1206,13 @@ impl RustProxy {
}
}
/// Update the global ingress security policy.
///
/// Rebuilds the compiled `IpBlockList` from `policy` and atomically swaps it
/// into the shared `ArcSwap`, so all TCP/UDP listeners holding a clone of the
/// handle observe the new block list without a restart. The raw policy is
/// also stored back into `self.options` so later rebuilds/reads see it.
pub fn set_security_policy(&mut self, policy: rustproxy_config::SecurityPolicy) {
    self.security_policy
        .store(Arc::new(Self::build_ip_block_list(Some(&policy))));
    self.options.security_policy = Some(policy);
}
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
/// take effect immediately for all new connections.
@@ -1130,10 +1272,13 @@ impl RustProxy {
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !configs.contains_key(d) {
configs.insert(d.clone(), TlsCertConfig {
configs.insert(
d.clone(),
TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
},
);
}
}
}
@@ -1166,7 +1311,8 @@ impl RustProxy {
info!("Loading certificate for domain: {}", domain);
// Check if the cert actually changed (for selective connection recycling)
let cert_changed = self.loaded_certs
let cert_changed = self
.loaded_certs
.get(domain)
.map(|existing| existing.cert_pem != cert_pem)
.unwrap_or(false); // new domain = no existing connections to recycle
@@ -1196,10 +1342,13 @@ impl RustProxy {
}
// Persist in loaded_certs so future rebuild calls include this cert
self.loaded_certs.insert(domain.to_string(), TlsCertConfig {
self.loaded_certs.insert(
domain.to_string(),
TlsCertConfig {
cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
});
},
);
// Hot-swap TLS config on TCP and QUIC listeners
let tls_configs = self.current_tls_configs().await;
@@ -1222,7 +1371,9 @@ impl RustProxy {
// Recycle existing connections if cert actually changed
if cert_changed {
if let Some(ref listener) = self.listener_manager {
listener.connection_registry().recycle_for_cert_change(domain);
listener
.connection_registry()
.recycle_for_cert_change(domain);
}
}
@@ -1244,16 +1395,22 @@ impl RustProxy {
continue;
}
let cert_spec = route.action.tls.as_ref()
let cert_spec = route
.action
.tls
.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() {
configs.insert(domain.to_string(), TlsCertConfig {
configs.insert(
domain.to_string(),
TlsCertConfig {
cert_pem: cert_config.cert.clone(),
key_pem: cert_config.key.clone(),
});
},
);
}
}
}
+4 -9
View File
@@ -1,12 +1,12 @@
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use anyhow::Result;
use clap::Parser;
use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustproxy::RustProxy;
use rustproxy::management;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
/// RustProxy - High-performance multi-protocol proxy
@@ -43,8 +43,7 @@ async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
)
.init();
@@ -60,11 +59,7 @@ async fn main() -> Result<()> {
let options = RustProxyOptions::from_file(&cli.config)
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
tracing::info!(
"Loaded {} routes from {}",
options.routes.len(),
cli.config
);
tracing::info!("Loaded {} routes from {}", options.routes.len(), cli.config);
// Validate-only mode
if cli.validate {
+140 -48
View File
@@ -1,7 +1,7 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error};
use tracing::{error, info};
use crate::RustProxy;
use rustproxy_config::RustProxyOptions;
@@ -141,14 +141,19 @@ async fn handle_request(
"start" => handle_start(&id, &request.params, proxy).await,
"stop" => handle_stop(&id, proxy).await,
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
"setSecurityPolicy" => handle_set_security_policy(&id, &request.params, proxy),
"getMetrics" => handle_get_metrics(&id, proxy),
"getStatistics" => handle_get_statistics(&id, proxy),
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await,
"setSocketHandlerRelay" => {
handle_set_socket_handler_relay(&id, &request.params, proxy).await
}
"setDatagramHandlerRelay" => {
handle_set_datagram_handler_relay(&id, &request.params, proxy).await
}
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
@@ -167,7 +172,12 @@ async fn handle_start(
let config = match params.get("config") {
Some(config) => config,
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'config' parameter".to_string(),
)
}
};
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
@@ -176,8 +186,7 @@ async fn handle_start(
};
match RustProxy::new(options) {
Ok(mut p) => {
match p.start().await {
Ok(mut p) => match p.start().await {
Ok(()) => {
send_event("started", serde_json::json!({}));
*proxy = Some(p);
@@ -187,27 +196,21 @@ async fn handle_start(
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
}
}
},
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
}
}
async fn handle_stop(
id: &str,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
async fn handle_stop(id: &str, proxy: &mut Option<RustProxy>) -> ManagementResponse {
match proxy.as_mut() {
Some(p) => {
match p.stop().await {
Some(p) => match p.stop().await {
Ok(()) => {
*proxy = None;
send_event("stopped", serde_json::json!({}));
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
}
}
},
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
@@ -224,7 +227,12 @@ async fn handle_update_routes(
let routes = match params.get("routes") {
Some(routes) => routes,
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routes' parameter".to_string(),
)
}
};
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
@@ -234,36 +242,72 @@ async fn handle_update_routes(
match p.update_routes(routes).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
Err(e) => {
ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e))
}
}
}
fn handle_get_metrics(
fn handle_set_security_policy(
id: &str,
proxy: &Option<RustProxy>,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let policy = match params.get("policy") {
Some(policy) => policy,
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'policy' parameter".to_string(),
)
}
};
let policy: rustproxy_config::SecurityPolicy = match serde_json::from_value(policy.clone()) {
Ok(policy) => policy,
Err(e) => {
return ManagementResponse::err(
id.to_string(),
format!("Invalid security policy: {}", e),
)
}
};
p.set_security_policy(policy);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
fn handle_get_metrics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let metrics = p.get_metrics();
match serde_json::to_value(&metrics) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize metrics: {}", e),
),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}
}
fn handle_get_statistics(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
fn handle_get_statistics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let stats = p.get_statistics();
match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize statistics: {}", e),
),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
@@ -282,12 +326,20 @@ async fn handle_provision_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.provision_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to provision certificate: {}", e),
),
}
}
@@ -303,12 +355,20 @@ async fn handle_renew_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.renew_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to renew certificate: {}", e),
),
}
}
@@ -324,24 +384,29 @@ async fn handle_get_certificate_status(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name,
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.get_certificate_status(route_name).await {
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
Some(status) => ManagementResponse::ok(
id.to_string(),
serde_json::json!({
"domain": status.domain,
"source": status.source,
"expiresAt": status.expires_at,
"isValid": status.is_valid,
})),
}),
),
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
}
}
fn handle_get_listening_ports(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
fn handle_get_listening_ports(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let ports = p.get_listening_ports();
@@ -361,7 +426,8 @@ async fn handle_set_socket_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params.get("socketPath")
let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
@@ -381,7 +447,8 @@ async fn handle_set_datagram_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params.get("socketPath")
let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
@@ -403,12 +470,17 @@ async fn handle_add_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
};
match p.add_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to add port {}: {}", port, e),
),
}
}
@@ -424,12 +496,17 @@ async fn handle_remove_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
};
match p.remove_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to remove port {}: {}", port, e),
),
}
}
@@ -445,26 +522,41 @@ async fn handle_load_certificate(
let domain = match params.get("domain").and_then(|v| v.as_str()) {
Some(d) => d.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'domain' parameter".to_string(),
)
}
};
let cert = match params.get("cert").and_then(|v| v.as_str()) {
Some(c) => c.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string())
}
};
let key = match params.get("key").and_then(|v| v.as_str()) {
Some(k) => k.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string())
}
};
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
let ca = params
.get("ca")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("loadCertificate: domain={}", domain);
// Load cert into cert manager and hot-swap TLS config
match p.load_certificate(&domain, cert, key, ca).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to load certificate for {}: {}", domain, e),
),
}
}
+8 -8
View File
@@ -136,7 +136,8 @@ pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandl
let path = parts.get(1).copied().unwrap_or("/");
// Extract Host header
let host = req_str.lines()
let host = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("host:"))
.map(|l| l[5..].trim())
.unwrap_or("unknown");
@@ -336,7 +337,8 @@ pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
let req_str = String::from_utf8_lossy(&buf[..n]);
// Extract Sec-WebSocket-Key for proper handshake
let ws_key = req_str.lines()
let ws_key = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
.unwrap_or_default();
@@ -378,7 +380,9 @@ pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
use rcgen::{CertificateParams, KeyPair};
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
params
.distinguished_name
.push(rcgen::DnType::CommonName, domain);
let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap();
@@ -458,11 +462,7 @@ pub fn make_tls_terminate_route(
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
pub async fn start_tls_ws_echo_backend(
port: u16,
cert_pem: &str,
key_pem: &str,
) -> JoinHandle<()> {
pub async fn start_tls_ws_echo_backend(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
@@ -1,9 +1,9 @@
mod common;
use bytes::Buf;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic};
use bytes::Buf;
use rustproxy_config::{RouteQuic, RouteUdp, RustProxyOptions, TransportProtocol};
use std::sync::Arc;
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
@@ -14,7 +14,14 @@ fn make_h3_route(
cert_pem: &str,
key_pem: &str,
) -> rustproxy_config::RouteConfig {
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem);
let mut route = make_tls_terminate_route(
port,
"localhost",
target_host,
target_port,
cert_pem,
key_pem,
);
route.route_match.transport = Some(TransportProtocol::All);
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
route.action.udp = Some(RouteUdp {
@@ -89,9 +96,7 @@ async fn test_h3_response_stream_finishes() {
.await
.expect("QUIC handshake failed");
let (mut driver, mut send_request) = h3::client::new(
h3_quinn::Connection::new(connection),
)
let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(connection))
.await
.expect("H3 connection setup failed");
@@ -108,33 +113,46 @@ async fn test_h3_response_stream_finishes() {
.body(())
.unwrap();
let mut stream = send_request.send_request(req).await
let mut stream = send_request
.send_request(req)
.await
.expect("Failed to send H3 request");
stream.finish().await
stream
.finish()
.await
.expect("Failed to finish sending H3 request body");
// 6. Read response headers
let resp = stream.recv_response().await
let resp = stream
.recv_response()
.await
.expect("Failed to receive H3 response");
assert_eq!(resp.status(), http::StatusCode::OK,
"Expected 200 OK, got {}", resp.status());
assert_eq!(
resp.status(),
http::StatusCode::OK,
"Expected 200 OK, got {}",
resp.status()
);
// 7. Read body and verify stream ends (FIN received)
// This is the critical assertion: recv_data() must return None (stream ended)
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
let result = with_timeout(async {
let result = with_timeout(
async {
let mut total = 0usize;
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
total += chunk.remaining();
}
// recv_data() returned None => stream ended (FIN received)
total
}, 10)
},
10,
)
.await;
let bytes_received = result.expect(
"TIMEOUT: H3 stream never ended (FIN not received by client). \
The proxy sent all response data but failed to send the QUIC stream FIN."
The proxy sent all response data but failed to send the QUIC stream FIN.",
);
assert_eq!(
bytes_received,
@@ -43,17 +43,32 @@ async fn test_http_forward_basic() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
let body = extract_body(&response);
body.to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
assert!(
result.contains(r#""method":"GET"#),
"Expected GET method, got: {}",
result
);
assert!(
result.contains(r#""path":"/hello"#),
"Expected /hello path, got: {}",
result
);
assert!(
result.contains(r#""backend":"main"#),
"Expected main backend, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -69,8 +84,18 @@ async fn test_http_forward_host_routing() {
let options = RustProxyOptions {
routes: vec![
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
make_test_route(
proxy_port,
Some("alpha.example.com"),
"127.0.0.1",
backend1_port,
),
make_test_route(
proxy_port,
Some("beta.example.com"),
"127.0.0.1",
backend2_port,
),
],
..Default::default()
};
@@ -80,24 +105,38 @@ async fn test_http_forward_host_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let alpha_result = with_timeout(async {
let alpha_result = with_timeout(
async {
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
assert!(
alpha_result.contains(r#""backend":"alpha"#),
"Expected alpha backend, got: {}",
alpha_result
);
// Test beta domain
let beta_result = with_timeout(async {
let beta_result = with_timeout(
async {
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
assert!(
beta_result.contains(r#""backend":"beta"#),
"Expected beta backend, got: {}",
beta_result
);
proxy.stop().await.unwrap();
}
@@ -127,24 +166,38 @@ async fn test_http_forward_path_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test API path
let api_result = with_timeout(async {
let api_result = with_timeout(
async {
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
assert!(
api_result.contains(r#""backend":"api"#),
"Expected api backend, got: {}",
api_result
);
// Test web path (no /api prefix)
let web_result = with_timeout(async {
let web_result = with_timeout(
async {
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
assert!(
web_result.contains(r#""backend":"web"#),
"Expected web backend, got: {}",
web_result
);
proxy.stop().await.unwrap();
}
@@ -184,9 +237,18 @@ async fn test_http_forward_cors_preflight() {
.unwrap();
// Should get 204 No Content with CORS headers
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
assert!(result.to_lowercase().contains("access-control-allow-origin"),
"Expected CORS header, got: {}", result);
assert!(
result.contains("204"),
"Expected 204 status, got: {}",
result
);
assert!(
result
.to_lowercase()
.contains("access-control-allow-origin"),
"Expected CORS header, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -208,15 +270,22 @@ async fn test_http_forward_backend_error() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
response
}, 10)
},
10,
)
.await
.unwrap();
// Proxy should relay the 500 from backend
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
assert!(
result.contains("500"),
"Expected 500 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -227,7 +296,12 @@ async fn test_http_forward_no_route_matched() {
// Create a route only for a specific domain
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
routes: vec![make_test_route(
proxy_port,
Some("known.example.com"),
"127.0.0.1",
9999,
)],
..Default::default()
};
@@ -235,15 +309,22 @@ async fn test_http_forward_no_route_matched() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
response
}, 10)
},
10,
)
.await
.unwrap();
// Should get 502 Bad Gateway (no route matched)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -262,15 +343,22 @@ async fn test_http_forward_backend_unavailable() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
response
}, 10)
},
10,
)
.await
.unwrap();
// Should get 502 Bad Gateway (backend unavailable)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -286,7 +374,12 @@ async fn test_https_terminate_http_forward() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -295,7 +388,8 @@ async fn test_https_terminate_http_forward() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
@@ -319,14 +413,28 @@ async fn test_https_terminate_http_forward() {
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
let body = extract_body(&result);
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
assert!(
body.contains(r#""method":"GET"#),
"Expected GET, got: {}",
body
);
assert!(
body.contains(r#""path":"/api/data"#),
"Expected /api/data, got: {}",
body
);
assert!(
body.contains(r#""backend":"tls-backend"#),
"Expected tls-backend, got: {}",
body
);
proxy.stop().await.unwrap();
}
@@ -347,7 +455,8 @@ async fn test_websocket_through_proxy() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -369,18 +478,24 @@ async fn test_websocket_through_proxy() {
let mut temp = [0u8; 1];
loop {
let n = stream.read(&mut temp).await.unwrap();
if n == 0 { break; }
if n == 0 {
break;
}
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len-4..] == *b"\r\n\r\n" {
if response_buf[len - 4..] == *b"\r\n\r\n" {
break;
}
}
}
let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
assert!(
response_str.contains("101"),
"Expected 101 Switching Protocols, got: {}",
response_str
);
assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
@@ -399,7 +514,9 @@ async fn test_websocket_through_proxy() {
assert_eq!(echoed, test_data, "Expected echo of sent data");
"ok".to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -431,12 +548,22 @@ async fn test_terminate_and_reencrypt_http_routing() {
// Create terminate-and-reencrypt routes
let mut route1 = make_tls_terminate_route(
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1,
proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
);
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let mut route2 = make_tls_terminate_route(
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2,
proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
);
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -450,7 +577,8 @@ async fn test_terminate_and_reencrypt_http_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt
let alpha_result = with_timeout(async {
let alpha_result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
@@ -461,16 +589,20 @@ async fn test_terminate_and_reencrypt_http_routing() {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
let request =
"GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -498,7 +630,8 @@ async fn test_terminate_and_reencrypt_http_routing() {
);
// Test beta domain - different host goes to different backend
let beta_result = with_timeout(async {
let beta_result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
@@ -509,16 +642,20 @@ async fn test_terminate_and_reencrypt_http_routing() {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
let request =
"GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -589,14 +726,12 @@ async fn test_terminate_and_reencrypt_websocket() {
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector =
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name =
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send WebSocket upgrade request through TLS
@@ -685,10 +820,13 @@ async fn test_protocol_field_in_route_config() {
assert!(wait_for_port(proxy_port, 2000).await);
// HTTP request should match the route and get proxied
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -20,13 +20,19 @@ async fn test_start_and_stop() {
assert!(!wait_for_port(port, 200).await);
proxy.start().await.unwrap();
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
assert!(
wait_for_port(port, 2000).await,
"Port should be listening after start"
);
proxy.stop().await.unwrap();
// Give the OS a moment to release the port
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
assert!(
!wait_for_port(port, 200).await,
"Port should not be listening after stop"
);
}
#[tokio::test]
@@ -54,7 +60,12 @@ async fn test_update_routes_hot_reload() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
routes: vec![make_test_route(
port,
Some("old.example.com"),
"127.0.0.1",
8080,
)],
..Default::default()
};
@@ -62,9 +73,12 @@ async fn test_update_routes_hot_reload() {
proxy.start().await.unwrap();
// Update routes atomically
let new_routes = vec![
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
];
let new_routes = vec![make_test_route(
port,
Some("new.example.com"),
"127.0.0.1",
9090,
)];
let result = proxy.update_routes(new_routes).await;
assert!(result.is_ok());
@@ -87,15 +101,24 @@ async fn test_add_remove_listening_port() {
// Add a new port
proxy.add_listening_port(port2).await.unwrap();
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
assert!(
wait_for_port(port2, 2000).await,
"New port should be listening"
);
// Remove the port
proxy.remove_listening_port(port2).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
assert!(
!wait_for_port(port2, 200).await,
"Removed port should not be listening"
);
// Original port should still be listening
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
assert!(
wait_for_port(port1, 200).await,
"Original port should still be listening"
);
proxy.stop().await.unwrap();
}
@@ -168,7 +191,11 @@ async fn test_metrics_track_connections() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
assert!(
stats.total_connections > 0,
"Expected total_connections > 0, got {}",
stats.total_connections
);
proxy.stop().await.unwrap();
}
@@ -205,8 +232,11 @@ async fn test_metrics_track_bytes() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0,
"Expected some connections tracked, got {}", stats.total_connections);
assert!(
stats.total_connections > 0,
"Expected some connections tracked, got {}",
stats.total_connections
);
proxy.stop().await.unwrap();
}
@@ -228,23 +258,38 @@ async fn test_hot_reload_port_changes() {
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await);
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
assert!(
!wait_for_port(port2, 200).await,
"port2 should not be listening yet"
);
// Update routes to use port2 instead
let new_routes = vec![
make_test_route(port2, None, "127.0.0.1", backend_port),
];
let new_routes = vec![make_test_route(port2, None, "127.0.0.1", backend_port)];
proxy.update_routes(new_routes).await.unwrap();
// Port2 should now be listening, port1 should be closed
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
assert!(
wait_for_port(port2, 2000).await,
"port2 should be listening after reload"
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
assert!(
!wait_for_port(port1, 200).await,
"port1 should be closed after reload"
);
// Verify port2 works
let ports = proxy.get_listening_ports();
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
assert!(
ports.contains(&port2),
"Expected port2 in listening ports: {:?}",
ports
);
assert!(
!ports.contains(&port1),
"port1 should not be in listening ports: {:?}",
ports
);
proxy.stop().await.unwrap();
}
@@ -24,10 +24,14 @@ async fn test_tcp_forward_echo() {
proxy.start().await.unwrap();
// Wait for proxy to be ready
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
assert!(
wait_for_port(proxy_port, 2000).await,
"Proxy port not ready"
);
// Connect and send data
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -36,7 +40,9 @@ async fn test_tcp_forward_echo() {
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
@@ -61,7 +67,8 @@ async fn test_tcp_forward_large_payload() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -75,7 +82,9 @@ async fn test_tcp_forward_large_payload() {
let mut received = Vec::new();
stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 10)
},
10,
)
.await
.unwrap();
@@ -100,7 +109,8 @@ async fn test_tcp_forward_multiple_connections() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
@@ -122,7 +132,9 @@ async fn test_tcp_forward_multiple_connections() {
results.push(handle.await.unwrap());
}
results
}, 10)
},
10,
)
.await
.unwrap();
@@ -149,14 +161,20 @@ async fn test_tcp_forward_backend_unreachable() {
assert!(wait_for_port(proxy_port, 2000).await);
// Connection should complete (proxy accepts it) but data should not flow
let result = with_timeout(async {
let result = with_timeout(
async {
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
stream.is_ok()
}, 5)
},
5,
)
.await
.unwrap();
assert!(result, "Should be able to connect to proxy even if backend is down");
assert!(
result,
"Should be able to connect to proxy even if backend is down"
);
proxy.stop().await.unwrap();
}
@@ -178,7 +196,8 @@ async fn test_tcp_forward_bidirectional() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -187,7 +206,9 @@ async fn test_tcp_forward_bidirectional() {
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
@@ -65,8 +65,18 @@ async fn test_tls_passthrough_sni_routing() {
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
make_tls_passthrough_route(
proxy_port,
Some("one.example.com"),
"127.0.0.1",
backend1_port,
),
make_tls_passthrough_route(
proxy_port,
Some("two.example.com"),
"127.0.0.1",
backend2_port,
),
],
..Default::default()
};
@@ -76,7 +86,8 @@ async fn test_tls_passthrough_sni_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Send a fake ClientHello with SNI "one.example.com"
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -86,15 +97,22 @@ async fn test_tls_passthrough_sni_routing() {
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
// Backend1 should have received the ClientHello and prefixed its response
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
assert!(
result.starts_with("BACKEND1:"),
"Expected BACKEND1 prefix, got: {}",
result
);
// Now test routing to backend2
let result2 = with_timeout(async {
let result2 = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -104,11 +122,17 @@ async fn test_tls_passthrough_sni_routing() {
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
assert!(
result2.starts_with("BACKEND2:"),
"Expected BACKEND2 prefix, got: {}",
result2
);
proxy.stop().await.unwrap();
}
@@ -121,9 +145,12 @@ async fn test_tls_passthrough_unknown_sni() {
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
],
routes: vec![make_tls_passthrough_route(
proxy_port,
Some("known.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default()
};
@@ -132,7 +159,8 @@ async fn test_tls_passthrough_unknown_sni() {
assert!(wait_for_port(proxy_port, 2000).await);
// Send ClientHello with unknown SNI - should get no response (connection dropped)
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -146,7 +174,9 @@ async fn test_tls_passthrough_unknown_sni() {
Ok(_) => false, // Got data = route shouldn't have matched
Err(_) => true, // Error = connection dropped
}
}, 5)
},
5,
)
.await
.unwrap();
@@ -163,9 +193,12 @@ async fn test_tls_passthrough_wildcard_domain() {
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
],
routes: vec![make_tls_passthrough_route(
proxy_port,
Some("*.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default()
};
@@ -174,7 +207,8 @@ async fn test_tls_passthrough_wildcard_domain() {
assert!(wait_for_port(proxy_port, 2000).await);
// Should match any subdomain of example.com
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -184,11 +218,17 @@ async fn test_tls_passthrough_wildcard_domain() {
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
assert!(
result.starts_with("WILDCARD:"),
"Expected WILDCARD prefix, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -222,7 +262,8 @@ async fn test_tls_passthrough_multiple_domains() {
("beta.example.com", "B2:"),
("gamma.example.com", "B3:"),
] {
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -232,14 +273,18 @@ async fn test_tls_passthrough_multiple_domains() {
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
assert!(
result.starts_with(expected_prefix),
"Domain {} should route to {}, got: {}",
domain, expected_prefix, result
domain,
expected_prefix,
result
);
}
@@ -74,7 +74,12 @@ async fn test_tls_terminate_basic() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -84,7 +89,8 @@ async fn test_tls_terminate_basic() {
assert!(wait_for_port(proxy_port, 2000).await);
// Connect with TLS client
let result = with_timeout(async {
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
@@ -100,7 +106,9 @@ async fn test_tls_terminate_basic() {
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -125,7 +133,12 @@ async fn test_tls_terminate_and_reencrypt() {
// Create terminate-and-reencrypt route
let mut route = make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&proxy_cert,
&proxy_key,
);
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -138,7 +151,8 @@ async fn test_tls_terminate_and_reencrypt() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
@@ -154,7 +168,9 @@ async fn test_tls_terminate_and_reencrypt() {
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -177,8 +193,22 @@ async fn test_tls_terminate_sni_cert_selection() {
let options = RustProxyOptions {
routes: vec![
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
make_tls_terminate_route(
proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
),
make_tls_terminate_route(
proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
),
],
..Default::default()
};
@@ -188,7 +218,8 @@ async fn test_tls_terminate_sni_cert_selection() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let result = with_timeout(async {
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
@@ -196,7 +227,8 @@ async fn test_tls_terminate_sni_cert_selection() {
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"test").await.unwrap();
@@ -204,11 +236,17 @@ async fn test_tls_terminate_sni_cert_selection() {
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
assert!(
result.starts_with("ALPHA:"),
"Expected ALPHA prefix, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -224,7 +262,12 @@ async fn test_tls_terminate_large_payload() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -233,7 +276,8 @@ async fn test_tls_terminate_large_payload() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
@@ -252,7 +296,9 @@ async fn test_tls_terminate_large_payload() {
let mut received = Vec::new();
tls_stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 15)
},
15,
)
.await
.unwrap();
@@ -272,7 +318,12 @@ async fn test_tls_terminate_concurrent() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -281,7 +332,8 @@ async fn test_tls_terminate_concurrent() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
@@ -311,7 +363,9 @@ async fn test_tls_terminate_concurrent() {
results.push(handle.await.unwrap());
}
results
}, 15)
},
15,
)
.await
.unwrap();
+1 -1
View File
@@ -3,6 +3,6 @@
*/
export const commitinfo = {
name: '@push.rocks/smartproxy',
version: '27.8.2',
version: '27.9.0',
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
}
+1 -1
View File
@@ -7,7 +7,7 @@ export { SmartProxy } from './proxies/smart-proxy/index.js';
export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js';
// Export smart-proxy models
export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
export type { ISmartProxyOptions, ISmartProxySecurityPolicy, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
export * from './proxies/smart-proxy/utils/index.js';
+1 -1
View File
@@ -2,6 +2,6 @@
* SmartProxy models
*/
// Export everything except IAcmeOptions from interfaces
export type { ISmartProxyOptions, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
export type { ISmartProxyOptions, ISmartProxySecurityPolicy, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
export * from './route-types.js';
export * from './metrics-types.js';
@@ -29,6 +29,11 @@ export interface ISmartProxyCertStore {
}
import type { IRouteConfig } from './route-types.js';
export interface ISmartProxySecurityPolicy {
blockedIps?: string[];
blockedCidrs?: string[];
}
/**
* Provision object for static or HTTP-01 certificate
*/
@@ -137,6 +142,7 @@ export interface ISmartProxyOptions {
// Rate limiting and security
maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP
connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP
securityPolicy?: ISmartProxySecurityPolicy; // Global ingress block policy, enforced before routing
// Enhanced keep-alive settings
keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections
+2 -1
View File
@@ -1,5 +1,5 @@
import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js';
import type { IAcmeOptions, ISmartProxyOptions } from './interfaces.js';
import type { IAcmeOptions, ISmartProxyOptions, ISmartProxySecurityPolicy } from './interfaces.js';
import type {
IRouteAction,
IRouteConfig,
@@ -75,6 +75,7 @@ export interface IRustProxyOptions {
keepAliveInactivityMultiplier?: number;
extendedKeepAliveLifetime?: number;
metrics?: ISmartProxyOptions['metrics'];
securityPolicy?: ISmartProxySecurityPolicy;
acme?: IRustAcmeOptions;
}
@@ -7,6 +7,7 @@ import type {
IRustRouteConfig,
IRustStatistics,
} from './models/rust-types.js';
import type { ISmartProxySecurityPolicy } from './models/interfaces.js';
/**
* Type-safe command definitions for the Rust proxy IPC protocol.
@@ -15,6 +16,7 @@ type TSmartProxyCommands = {
start: { params: { config: IRustProxyOptions }; result: void };
stop: { params: Record<string, never>; result: void };
updateRoutes: { params: { routes: IRustRouteConfig[] }; result: void };
setSecurityPolicy: { params: { policy: ISmartProxySecurityPolicy }; result: void };
getMetrics: { params: Record<string, never>; result: IRustMetricsSnapshot };
getStatistics: { params: Record<string, never>; result: IRustStatistics };
provisionCertificate: { params: { routeName: string }; result: void };
@@ -139,6 +141,10 @@ export class RustProxyBridge extends plugins.EventEmitter {
await this.bridge.sendCommand('updateRoutes', { routes });
}
public async setSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
await this.bridge.sendCommand('setSecurityPolicy', { policy });
}
public async getMetrics(): Promise<IRustMetricsSnapshot> {
return this.bridge.sendCommand('getMetrics', {} as Record<string, never>);
}
+10 -1
View File
@@ -17,7 +17,7 @@ import { Mutex } from './utils/mutex.js';
import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js';
// Types
import type { ISmartProxyOptions, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
import type { ISmartProxyOptions, ISmartProxySecurityPolicy, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
import type { IRouteConfig } from './models/route-types.js';
import type { IMetrics } from './models/metrics-types.js';
import type { IRustCertificateStatus, IRustProxyOptions, IRustStatistics } from './models/rust-types.js';
@@ -350,6 +350,15 @@ export class SmartProxy extends plugins.EventEmitter {
.catch((err) => logger.log('error', `Unexpected error in cert provisioning after route update: ${err.message}`, { component: 'smart-proxy' }));
}
/**
* Update the global ingress security policy without changing routes.
* The Rust engine applies this before route selection and backend connection.
*/
public async updateSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
this.settings.securityPolicy = policy;
await this.bridge.setSecurityPolicy(policy);
}
/**
* Provision a certificate for a named route.
*/
@@ -182,6 +182,7 @@ export function buildRustProxyOptions(
keepAliveInactivityMultiplier: settings.keepAliveInactivityMultiplier,
extendedKeepAliveLifetime: settings.extendedKeepAliveLifetime,
metrics: settings.metrics,
securityPolicy: settings.securityPolicy,
acme: serializeAcmeForRust(acme),
};
}