feat(smart-proxy): add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers

This commit is contained in:
2026-04-26 15:11:10 +00:00
parent 8fa3a51b03
commit af4908b63f
53 changed files with 2350 additions and 1196 deletions
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
@@ -7,8 +7,8 @@
use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
@@ -73,7 +73,9 @@ impl ConnectionRegistry {
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
let matches = entry.domain.as_deref()
let matches = entry
.domain
.as_deref()
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
.unwrap_or(false);
if matches {
@@ -100,7 +102,11 @@ impl ConnectionRegistry {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
if entry.route_id.as_deref() == Some(route_id) {
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) {
if !RequestFilter::check_ip_security(
new_security,
&entry.source_ip,
entry.domain.as_deref(),
) {
info!(
"Terminating connection from {} — IP now blocked on route '{}'",
entry.source_ip, route_id
@@ -31,7 +31,8 @@ impl ConnectionTracker {
pub fn try_accept(&self, ip: &IpAddr) -> bool {
// Check per-IP connection limit
if let Some(max) = self.max_per_ip {
let count = self.active
let count = self
.active
.get(ip)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0);
@@ -48,7 +49,10 @@ impl ConnectionTracker {
let timestamps = entry.value_mut();
// Remove timestamps older than 1 minute
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
while timestamps
.front()
.is_some_and(|t| now.duration_since(*t) >= one_minute)
{
timestamps.pop_front();
}
@@ -111,7 +115,6 @@ impl ConnectionTracker {
pub fn tracked_ips(&self) -> usize {
self.active.len()
}
}
#[cfg(test)]
@@ -1,8 +1,8 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;
use rustproxy_metrics::MetricsCollector;
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
if let Some(data) = initial_data {
backend.write_all(data).await?;
if let Some(ref ctx) = metrics {
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
data.len() as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_c2b {
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
n as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
backend_write.shutdown(),
).await;
let _ =
tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
total
});
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_b2c {
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
0,
n as u64,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
client_write.shutdown(),
).await;
let _ =
tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
total
});
+16 -16
View File
@@ -4,26 +4,26 @@
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
//! and UDP datagram session tracking with forwarding.
pub mod tcp_listener;
pub mod sni_parser;
pub mod connection_registry;
pub mod connection_tracker;
pub mod forwarder;
pub mod proxy_protocol;
pub mod tls_handler;
pub mod connection_tracker;
pub mod connection_registry;
pub mod socket_opts;
pub mod udp_session;
pub mod udp_listener;
pub mod quic_handler;
pub mod sni_parser;
pub mod socket_opts;
pub mod tcp_listener;
pub mod tls_handler;
pub mod udp_listener;
pub mod udp_session;
pub use tcp_listener::*;
pub use sni_parser::*;
pub use connection_registry::*;
pub use connection_tracker::*;
pub use forwarder::*;
pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_tracker::*;
pub use connection_registry::*;
pub use socket_opts::*;
pub use udp_session::*;
pub use udp_listener::*;
pub use quic_handler::*;
pub use sni_parser::*;
pub use socket_opts::*;
pub use tcp_listener::*;
pub use tls_handler::*;
pub use udp_listener::*;
pub use udp_session::*;
@@ -54,8 +54,8 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
.position(|w| w == b"\r\n")
.ok_or(ProxyProtocolError::InvalidHeader)?;
let line = std::str::from_utf8(&data[..line_end])
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
let line =
std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
if !line.starts_with("PROXY ") {
return Err(ProxyProtocolError::InvalidHeader);
@@ -148,7 +148,10 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
let command = data[12] & 0x0F;
// 0x0 = LOCAL, 0x1 = PROXY
if command > 1 {
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command)));
return Err(ProxyProtocolError::Parse(format!(
"Unknown command: {}",
command
)));
}
// Address family (high nibble) + transport (low nibble) of byte 13
@@ -182,7 +185,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET (0x1) + STREAM (0x1) = TCP4
(0x1, 0x1) => {
if addr_len < 12 {
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
}
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
@@ -200,7 +205,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET (0x1) + DGRAM (0x2) = UDP4
(0x1, 0x2) => {
if addr_len < 12 {
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
}
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
@@ -218,7 +225,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
(0x2, 0x1) => {
if addr_len < 36 {
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
}
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
@@ -236,7 +245,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
(0x2, 0x2) => {
if addr_len < 36 {
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
}
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
@@ -268,11 +279,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
}
/// Generate a PROXY protocol v2 binary header.
pub fn generate_v2(
source: &SocketAddr,
dest: &SocketAddr,
transport: ProxyV2Transport,
) -> Vec<u8> {
pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
let transport_nibble: u8 = match transport {
ProxyV2Transport::Stream => 0x1,
ProxyV2Transport::Datagram => 0x2,
@@ -462,7 +469,10 @@ mod tests {
header.push(0x11);
header.extend_from_slice(&12u16.to_be_bytes());
header.extend_from_slice(&[0u8; 12]);
assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion)));
assert!(matches!(
parse_v2(&header),
Err(ProxyProtocolError::UnsupportedVersion)
));
}
#[test]
@@ -26,11 +26,12 @@ use tracing::{debug, info, warn};
use rustproxy_config::{RouteConfig, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService;
use crate::connection_tracker::ConnectionTracker;
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
use crate::connection_tracker::ConnectionTracker;
/// Create a QUIC server endpoint on the given port with the provided TLS config.
///
@@ -48,8 +49,7 @@ pub fn create_quic_endpoint(
quinn::EndpointConfig::default(),
Some(server_config),
socket,
quinn::default_runtime()
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?;
info!("QUIC endpoint listening on port {}", port);
@@ -97,6 +97,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
port: u16,
tls_config: Arc<RustlsServerConfig>,
proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
cancel: CancellationToken,
) -> anyhow::Result<QuicProxyRelay> {
// Bind external socket on the real port
@@ -119,8 +120,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
quinn::EndpointConfig::default(),
Some(server_config),
internal_socket,
quinn::default_runtime()
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?;
let real_client_map = Arc::new(DashMap::new());
@@ -129,12 +129,20 @@ pub fn create_quic_endpoint_with_proxy_relay(
external_socket,
quinn_internal_addr,
proxy_ips,
security_policy,
Arc::clone(&real_client_map),
cancel,
));
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
info!(
"QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
port, quinn_internal_addr
);
Ok(QuicProxyRelay {
endpoint,
relay_task,
real_client_map,
})
}
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
@@ -144,6 +152,7 @@ async fn quic_proxy_relay_loop(
external_socket: Arc<UdpSocket>,
quinn_internal_addr: SocketAddr,
proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
cancel: CancellationToken,
) {
@@ -184,26 +193,43 @@ async fn quic_proxy_relay_loop(
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => {
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
debug!(
"QUIC PROXY v2 from {}: real client {}",
src_addr, header.source_addr
);
proxy_addr_map.insert(src_addr, header.source_addr);
continue; // consume the PROXY v2 datagram
}
Err(e) => {
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
debug!(
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
src_addr, e
);
}
}
}
}
// Determine real client address
let real_client = proxy_addr_map.get(&src_addr)
let real_client = proxy_addr_map
.get(&src_addr)
.map(|r| *r)
.unwrap_or(src_addr);
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
debug!(
"QUIC datagram from {} blocked by global security policy",
real_client
);
continue;
}
// Get or create relay session for this external source
let session = match relay_sessions.get(&src_addr) {
Some(s) => {
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
s.last_activity
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
Arc::clone(s.value())
}
None => {
@@ -216,7 +242,10 @@ async fn quic_proxy_relay_loop(
}
};
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
warn!(
"QUIC relay: failed to connect relay socket to {}: {}",
quinn_internal_addr, e
);
continue;
}
let relay_local_addr = match relay_socket.local_addr() {
@@ -248,8 +277,10 @@ async fn quic_proxy_relay_loop(
});
relay_sessions.insert(src_addr, Arc::clone(&session));
debug!("QUIC relay: new session for {} (relay {}), real client {}",
src_addr, relay_local_addr, real_client);
debug!(
"QUIC relay: new session for {} (relay {}), real client {}",
src_addr, relay_local_addr, real_client
);
session
}
@@ -264,9 +295,11 @@ async fn quic_proxy_relay_loop(
if last_cleanup.elapsed() >= cleanup_interval {
last_cleanup = Instant::now();
let now_ms = epoch.elapsed().as_millis() as u64;
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
let stale_keys: Vec<SocketAddr> = relay_sessions
.iter()
.filter(|entry| {
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
let age =
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
age > session_timeout_ms
})
.map(|entry| *entry.key())
@@ -287,13 +320,17 @@ async fn quic_proxy_relay_loop(
// Also clean orphaned proxy_addr_map entries (PROXY header received
// but no relay session was ever created, e.g. client never sent data)
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
let orphaned: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| relay_sessions.get(entry.key()).is_none())
.map(|entry| *entry.key())
.collect();
for key in orphaned {
proxy_addr_map.remove(&key);
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
debug!(
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
key
);
}
}
}
@@ -328,8 +365,14 @@ async fn relay_return_path(
}
};
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
if let Err(e) = external_socket
.send_to(&buf[..len], external_src_addr)
.await
{
debug!(
"QUIC relay return send error to {}: {}",
external_src_addr, e
);
break;
}
}
@@ -353,6 +396,7 @@ pub async fn quic_accept_loop(
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
route_cancels: Arc<DashMap<String, CancellationToken>>,
connection_registry: Arc<ConnectionRegistry>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) {
loop {
let incoming = tokio::select! {
@@ -374,11 +418,21 @@ pub async fn quic_accept_loop(
let remote_addr = incoming.remote_address();
// Resolve real client IP from PROXY protocol map if available
let real_addr = real_client_map.as_ref()
let real_addr = real_client_map
.as_ref()
.and_then(|map| map.get(&remote_addr).map(|r| *r))
.unwrap_or(remote_addr);
let ip = real_addr.ip();
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&ip) {
debug!(
"QUIC connection from {} blocked by global security policy",
real_addr
);
continue;
}
// Per-IP rate limiting
if !conn_tracker.try_accept(&ip) {
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
@@ -414,7 +468,10 @@ pub async fn quic_accept_loop(
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
security, &ip, ctx.domain,
) {
debug!("QUIC connection from {} blocked by route security", real_addr);
debug!(
"QUIC connection from {} blocked by route security",
real_addr
);
continue;
}
}
@@ -425,7 +482,8 @@ pub async fn quic_accept_loop(
// Resolve per-route cancel token (child of global cancel)
let route_cancel = match route_id.as_deref() {
Some(id) => route_cancels.entry(id.to_string())
Some(id) => route_cancels
.entry(id.to_string())
.or_insert_with(|| cancel.child_token())
.clone(),
None => cancel.child_token(),
@@ -445,7 +503,11 @@ pub async fn quic_accept_loop(
let metrics = Arc::clone(&metrics);
let conn_tracker = Arc::clone(&conn_tracker);
let h3_svc = h3_service.clone();
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
let real_client_addr = if real_addr != remote_addr {
Some(real_addr)
} else {
None
};
tokio::spawn(async move {
// Register in connection registry (RAII guard removes on drop)
@@ -462,7 +524,8 @@ pub async fn quic_accept_loop(
impl Drop for QuicConnGuard {
fn drop(&mut self) {
self.tracker.connection_closed(&self.ip);
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
self.metrics
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
}
}
let _guard = QuicConnGuard {
@@ -473,7 +536,17 @@ pub async fn quic_accept_loop(
route_id,
};
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &conn_cancel, h3_svc, real_client_addr).await {
match handle_quic_connection(
incoming,
route,
port,
Arc::clone(&metrics),
&conn_cancel,
h3_svc,
real_client_addr,
)
.await
{
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
}
@@ -501,17 +574,28 @@ async fn handle_quic_connection(
debug!("QUIC connection established from {}", effective_addr);
// Check if this route has HTTP/3 enabled
let enable_http3 = route.action.udp.as_ref()
let enable_http3 = route
.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref())
.and_then(|q| q.enable_http3)
.unwrap_or(false);
if enable_http3 {
if let Some(ref h3_svc) = h3_service {
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
debug!(
"HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
route.name
);
h3_svc
.handle_connection(connection, &route, port, real_client_addr, cancel)
.await
} else {
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
warn!(
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
route.name
);
// Keep connection alive until cancelled
tokio::select! {
_ = cancel.cancelled() => {}
@@ -523,7 +607,8 @@ async fn handle_quic_connection(
}
} else {
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
.await
}
}
@@ -545,7 +630,10 @@ async fn handle_quic_stream_forwarding(
let metrics_arc = metrics;
// Resolve backend target
let target = route.action.targets.as_ref()
let target = route
.action
.targets
.as_ref()
.and_then(|t| t.first())
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
let backend_host = target.host.first();
@@ -576,19 +664,20 @@ async fn handle_quic_stream_forwarding(
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
tokio::spawn(async move {
match forward_quic_stream_to_tcp(
send_stream,
recv_stream,
&backend_addr,
stream_cancel,
).await {
match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
.await
{
Ok((bytes_in, bytes_out)) => {
stream_metrics.record_bytes(
bytes_in, bytes_out,
bytes_in,
bytes_out,
stream_route_id.as_deref(),
Some(&ip_str),
);
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
debug!(
"QUIC stream forwarded: {}B in, {}B out",
bytes_in, bytes_out
);
}
Err(e) => {
debug!("QUIC stream forwarding error: {}", e);
@@ -640,10 +729,7 @@ async fn forward_quic_stream_to_tcp(
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
tcp_write.shutdown(),
).await;
let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
total
});
@@ -721,8 +807,8 @@ mod tests {
let _ = rustls::crypto::ring::default_provider().install_default();
// Generate a single self-signed cert and use its key pair
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
.unwrap();
let self_signed =
rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
let cert_der = self_signed.cert.der().clone();
let key_der = self_signed.key_pair.serialize_der();
@@ -737,6 +823,10 @@ mod tests {
// Port 0 = OS assigns a free port
let result = create_quic_endpoint(0, Arc::new(tls_config));
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
assert!(
result.is_ok(),
"QUIC endpoint creation failed: {:?}",
result.err()
);
}
}
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
}
// Handshake length (3 bytes) - informational, we parse incrementally
let _handshake_len = ((data[6] as usize) << 16)
| ((data[7] as usize) << 8)
| (data[8] as usize);
let _handshake_len =
((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
let hello = &data[9..];
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
pub fn extract_http_host(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?;
for line in text.split("\r\n") {
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
if let Some(value) = line
.strip_prefix("Host: ")
.or_else(|| line.strip_prefix("host: "))
{
// Strip port if present
let host = value.split(':').next().unwrap_or(value).trim();
if !host.is_empty() {
@@ -196,7 +198,7 @@ pub fn is_http(data: &[u8]) -> bool {
b"PATC",
b"OPTI",
b"CONN",
b"PRI ", // HTTP/2 connection preface
b"PRI ", // HTTP/2 connection preface
];
starts.iter().any(|s| data.starts_with(s))
}
@@ -213,7 +215,10 @@ mod tests {
#[test]
fn test_too_short() {
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
assert!(matches!(
extract_sni(&[0x16, 0x03]),
SniResult::NeedMoreData
));
}
#[test]
@@ -263,7 +268,8 @@ mod tests {
// Extension: type=0x0000 (SNI), length, data
let sni_extension = {
let mut e = Vec::new();
e.push(0x00); e.push(0x00); // SNI type
e.push(0x00);
e.push(0x00); // SNI type
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
e.push((sni_ext_data.len() & 0xFF) as u8);
e.extend_from_slice(&sni_ext_data);
@@ -283,16 +289,20 @@ mod tests {
let hello_body = {
let mut h = Vec::new();
// Client version TLS 1.2
h.push(0x03); h.push(0x03);
h.push(0x03);
h.push(0x03);
// Random (32 bytes)
h.extend_from_slice(&[0u8; 32]);
// Session ID length = 0
h.push(0x00);
// Cipher suites: length=2, one suite
h.push(0x00); h.push(0x02);
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
// Compression methods: length=1, null
h.push(0x01); h.push(0x00);
h.push(0x00);
h.push(0x02);
h.push(0x00);
h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
// Compression methods: length=1, null
h.push(0x01);
h.push(0x00);
// Extensions
h.extend_from_slice(&extensions);
h
@@ -302,7 +312,7 @@ mod tests {
let handshake = {
let mut hs = Vec::new();
hs.push(0x01); // ClientHello
// 3-byte length
// 3-byte length
hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
hs.push((hello_body.len() & 0xFF) as u8);
@@ -313,7 +323,8 @@ mod tests {
// TLS record: type=0x16, version TLS 1.0, length
let mut record = Vec::new();
record.push(0x16); // Handshake
record.push(0x03); record.push(0x01); // TLS 1.0
record.push(0x03);
record.push(0x01); // TLS 1.0
record.push(((handshake.len() >> 8) & 0xFF) as u8);
record.push((handshake.len() & 0xFF) as u8);
record.extend_from_slice(&handshake);
File diff suppressed because it is too large Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
use tracing::{debug, info};
use crate::tcp_listener::TlsCertConfig;
@@ -29,7 +29,9 @@ pub struct CertResolver {
impl CertResolver {
/// Build a resolver from PEM-encoded cert/key configs.
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
pub fn new(
configs: &HashMap<String, TlsCertConfig>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let provider = rustls::crypto::ring::default_provider();
let mut certs = HashMap::new();
@@ -38,8 +40,10 @@ impl CertResolver {
for (domain, cfg) in configs {
let cert_chain = load_certs(&cfg.cert_pem)?;
let key = load_private_key(&cfg.key_pem)?;
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
let ck = Arc::new(
CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
);
if domain == "*" {
fallback = Some(Arc::clone(&ck));
}
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
/// The returned acceptor can be reused across all connections (cheap Arc clone).
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_shared_tls_acceptor(
resolver: CertResolver,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let mut config = ServerConfig::builder()
.with_no_client_auth()
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
// Shared session cache — enables session ID resumption across connections
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
config.ticketer = rustls::crypto::ring::Ticketer::new()
.map_err(|e| format!("Ticketer: {}", e))?;
config.ticketer =
rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1");
info!(
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
);
Ok(TlsAcceptor::from(Arc::new(config)))
}
/// Build a TLS acceptor from PEM-encoded cert and key data.
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_tls_acceptor(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
build_tls_acceptor_with_config(cert_pem, key_pem, None)
}
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_tls_acceptor_h1_only(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let certs = load_certs(cert_pem)?;
let key = load_private_key(key_pem)?;
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
// Apply TLS version restrictions
let versions = resolve_tls_versions(route_tls.versions.as_deref());
let builder = ServerConfig::builder_with_protocol_versions(&versions);
builder
.with_no_client_auth()
.with_single_cert(certs, key)?
builder.with_no_client_auth().with_single_cert(certs, key)?
} else {
ServerConfig::builder()
.with_no_client_auth()
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
}
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
fn resolve_tls_versions(
versions: Option<&[String]>,
) -> Vec<&'static rustls::SupportedProtocolVersion> {
let versions = match versions {
Some(v) if !v.is_empty() => v,
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
@@ -207,15 +221,17 @@ pub async fn accept_tls(
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG.get_or_init(|| {
ensure_crypto_provider();
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
info!("Built shared backend TLS client config with session resumption");
Arc::new(config)
}).clone()
SHARED_CLIENT_CONFIG
.get_or_init(|| {
ensure_crypto_provider();
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
info!("Built shared backend TLS client config with session resumption");
Arc::new(config)
})
.clone()
}
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| {
ensure_crypto_provider();
let mut config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection");
Arc::new(config)
}).clone()
SHARED_CLIENT_CONFIG_ALPN
.get_or_init(|| {
ensure_crypto_provider();
let mut config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
info!(
"Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
);
Arc::new(config)
})
.clone()
}
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
@@ -249,7 +269,8 @@ pub async fn connect_tls(
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
stream.set_nodelay(true)?;
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) {
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
{
debug!("Failed to set keepalive on backend TLS socket: {}", e);
}
@@ -260,10 +281,12 @@ pub async fn connect_tls(
}
/// Load certificates from PEM string.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
fn load_certs(
pem: &str,
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes());
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()?;
let certs: Vec<CertificateDer<'static>> =
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() {
return Err("No certificates found in PEM data".into());
}
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
}
/// Load private key from PEM string.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
fn load_private_key(
pem: &str,
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes());
// Try PKCS8 first, then RSA, then EC
let key = rustls_pemfile::private_key(&mut reader)?
.ok_or("No private key found in PEM data")?;
let key =
rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
Ok(key)
}
@@ -17,14 +17,15 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use arc_swap::ArcSwap;
use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use rustproxy_config::{RouteActionType, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService;
@@ -62,6 +63,8 @@ pub struct UdpListenerManager {
route_cancels: Arc<DashMap<String, CancellationToken>>,
/// Shared connection registry for selective recycling.
connection_registry: Arc<ConnectionRegistry>,
/// Global ingress block policy, hot-reloadable without restarting listeners.
security_policy: Arc<ArcSwap<IpBlockList>>,
}
impl Drop for UdpListenerManager {
@@ -99,17 +102,26 @@ impl UdpListenerManager {
proxy_ips: Arc::new(Vec::new()),
route_cancels,
connection_registry,
security_policy: Arc::new(ArcSwap::from(Arc::new(IpBlockList::empty()))),
}
}
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
if !ips.is_empty() {
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len());
info!(
"UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs",
ips.len()
);
}
self.proxy_ips = Arc::new(ips);
}
/// Set the shared global ingress security policy.
pub fn set_security_policy(&mut self, policy: Arc<ArcSwap<IpBlockList>>) {
self.security_policy = policy;
}
/// Set the H3 proxy service for HTTP/3 request handling.
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
self.h3_service = Some(svc);
@@ -142,7 +154,9 @@ impl UdpListenerManager {
// Check if any route on this port uses QUIC
let rm = self.route_manager.load();
let has_quic = rm.routes_for_port(port).iter().any(|r| {
r.action.udp.as_ref()
r.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref())
.is_some()
});
@@ -164,8 +178,10 @@ impl UdpListenerManager {
None,
Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint started on port {}", port);
} else {
// Proxy relay path: we own external socket, quinn on localhost
@@ -173,6 +189,7 @@ impl UdpListenerManager {
port,
tls,
Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
self.cancel_token.child_token(),
)?;
let endpoint_for_updates = relay.endpoint.clone();
@@ -187,13 +204,18 @@ impl UdpListenerManager {
Some(relay.real_client_map),
Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint with PROXY relay started on port {}", port);
}
return Ok(());
} else {
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
warn!(
"QUIC routes on port {} but no TLS config provided, falling back to raw UDP",
port
);
}
}
@@ -214,6 +236,7 @@ impl UdpListenerManager {
Arc::clone(&self.relay_writer),
self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, None));
@@ -254,8 +277,10 @@ impl UdpListenerManager {
}
debug!("UDP listener stopped on port {}", port);
}
info!("All UDP listeners stopped, {} sessions remaining",
self.session_table.session_count());
info!(
"All UDP listeners stopped, {} sessions remaining",
self.session_table.session_count()
);
}
/// Update TLS config on all active QUIC endpoints (cert refresh).
@@ -288,11 +313,15 @@ impl UdpListenerManager {
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
let rm = self.route_manager.load();
let upgrade_ports: Vec<u16> = self.listeners.iter()
let upgrade_ports: Vec<u16> = self
.listeners
.iter()
.filter(|(_, (_, endpoint))| endpoint.is_none())
.filter(|(port, _)| {
rm.routes_for_port(**port).iter().any(|r| {
r.action.udp.as_ref()
r.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref())
.is_some()
})
@@ -301,17 +330,23 @@ impl UdpListenerManager {
.collect();
for port in upgrade_ports {
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port);
info!(
"Upgrading raw UDP listener on port {} to QUIC endpoint",
port
);
// Stop the raw UDP listener task and drain sessions to release the socket
if let Some((handle, _)) = self.listeners.remove(&port) {
handle.abort();
}
let drained = self.session_table.drain_port(
port, &self.metrics, &self.conn_tracker,
);
let drained = self
.session_table
.drain_port(port, &self.metrics, &self.conn_tracker);
if drained > 0 {
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port);
debug!(
"Drained {} UDP sessions on port {} for QUIC upgrade",
drained, port
);
}
// Brief yield to let aborted tasks drop their socket references
@@ -326,11 +361,17 @@ impl UdpListenerManager {
match create_result {
Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port);
info!(
"QUIC endpoint started on port {} (upgraded from raw UDP)",
port
);
}
Err(e) => {
// Port may still be held — retry once after a brief delay
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e);
warn!(
"QUIC endpoint creation failed on port {}, retrying: {}",
port, e
);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let retry_result = if self.proxy_ips.is_empty() {
@@ -341,11 +382,17 @@ impl UdpListenerManager {
match retry_result {
Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port);
info!(
"QUIC endpoint started on port {} (upgraded from raw UDP, retry)",
port
);
}
Err(e2) => {
error!("Failed to upgrade port {} to QUIC after retry: {}. \
Rebinding as raw UDP.", port, e2);
error!(
"Failed to upgrade port {} to QUIC after retry: {}. \
Rebinding as raw UDP.",
port, e2
);
// Fallback: rebind as raw UDP so the port isn't dead
if let Ok(()) = self.rebind_raw_udp(port).await {
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
@@ -358,7 +405,11 @@ impl UdpListenerManager {
}
/// Create a direct QUIC endpoint (quinn owns the socket).
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
fn create_quic_direct(
&mut self,
port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<()> {
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
let endpoint_for_updates = endpoint.clone();
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
@@ -372,17 +423,24 @@ impl UdpListenerManager {
None,
Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
Ok(())
}
/// Create a QUIC endpoint with PROXY protocol relay.
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
fn create_quic_with_relay(
&mut self,
port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<()> {
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
port,
tls_config,
Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
self.cancel_token.child_token(),
)?;
let endpoint_for_updates = relay.endpoint.clone();
@@ -397,8 +455,10 @@ impl UdpListenerManager {
Some(relay.real_client_map),
Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
Ok(())
}
@@ -419,6 +479,7 @@ impl UdpListenerManager {
Arc::clone(&self.relay_writer),
self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
));
self.listeners.insert(port, (handle, None));
@@ -458,7 +519,10 @@ impl UdpListenerManager {
info!("Datagram handler relay connected to {}", path);
}
Err(e) => {
error!("Failed to connect datagram handler relay to {}: {}", path, e);
error!(
"Failed to connect datagram handler relay to {}: {}",
path, e
);
}
}
}
@@ -514,6 +578,7 @@ impl UdpListenerManager {
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
cancel: CancellationToken,
proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) {
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
let mut buf = vec![0u8; 65535];
@@ -528,9 +593,11 @@ impl UdpListenerManager {
loop {
// Periodic cleanup: remove proxy_addr_map entries with no active session
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval {
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval
{
last_proxy_cleanup = tokio::time::Instant::now();
let stale: Vec<SocketAddr> = proxy_addr_map.iter()
let stale: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| {
let key: SessionKey = (*entry.key(), port);
session_table.get(&key).is_none()
@@ -538,7 +605,11 @@ impl UdpListenerManager {
.map(|entry| *entry.key())
.collect();
if !stale.is_empty() {
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port);
debug!(
"UDP proxy_addr_map cleanup: removing {} stale entries on port {}",
stale.len(),
port
);
for addr in stale {
proxy_addr_map.remove(&addr);
}
@@ -564,34 +635,50 @@ impl UdpListenerManager {
let datagram = &buf[..len];
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
let session_key: SessionKey = (client_addr, port);
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) {
// No session and no prior PROXY header — check for PROXY v2
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => {
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr);
proxy_addr_map.insert(client_addr, header.source_addr);
continue; // discard the PROXY v2 datagram
}
Err(e) => {
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
client_addr.ip()
let effective_client_ip =
if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
let session_key: SessionKey = (client_addr, port);
if session_table.get(&session_key).is_none()
&& !proxy_addr_map.contains_key(&client_addr)
{
// No session and no prior PROXY header — check for PROXY v2
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => {
debug!(
"UDP PROXY v2 from {}: real client {}",
client_addr, header.source_addr
);
proxy_addr_map.insert(client_addr, header.source_addr);
continue; // discard the PROXY v2 datagram
}
Err(e) => {
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
client_addr.ip()
}
}
} else {
client_addr.ip()
}
} else {
client_addr.ip()
// Use real client IP if we've previously seen a PROXY v2 header
proxy_addr_map
.get(&client_addr)
.map(|r| r.ip())
.unwrap_or_else(|| client_addr.ip())
}
} else {
// Use real client IP if we've previously seen a PROXY v2 header
proxy_addr_map.get(&client_addr)
.map(|r| r.ip())
.unwrap_or_else(|| client_addr.ip())
}
} else {
client_addr.ip()
};
client_addr.ip()
};
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&effective_client_ip) {
debug!(
"UDP datagram from {} blocked by global security policy",
effective_client_ip
);
continue;
}
// Route matching — use effective (real) client IP
let rm = route_manager.load();
@@ -611,7 +698,10 @@ impl UdpListenerManager {
let route_match = match rm.find_route(&ctx) {
Some(m) => m,
None => {
debug!("No UDP route matched for port {} from {}", port, client_addr);
debug!(
"No UDP route matched for port {} from {}",
port, client_addr
);
continue;
}
};
@@ -627,7 +717,9 @@ impl UdpListenerManager {
&client_addr,
port,
datagram,
).await {
)
.await
{
debug!("Failed to relay UDP datagram to TS: {}", e);
}
continue;
@@ -638,8 +730,10 @@ impl UdpListenerManager {
// Check datagram size
if len as u32 > udp_config.max_datagram_size {
debug!("UDP datagram too large ({} > {}) from {}, dropping",
len, udp_config.max_datagram_size, client_addr);
debug!(
"UDP datagram too large ({} > {}) from {}, dropping",
len, udp_config.max_datagram_size, client_addr
);
continue;
}
@@ -651,21 +745,27 @@ impl UdpListenerManager {
None => {
// New session — check per-IP limits using the real client IP
if !conn_tracker.try_accept(&effective_client_ip) {
debug!("UDP session rejected for {} (rate limit)", effective_client_ip);
debug!(
"UDP session rejected for {} (rate limit)",
effective_client_ip
);
continue;
}
if !session_table.can_create_session(
&effective_client_ip,
udp_config.max_sessions_per_ip,
) {
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip);
if !session_table
.can_create_session(&effective_client_ip, udp_config.max_sessions_per_ip)
{
debug!(
"UDP session rejected for {} (per-IP session limit)",
effective_client_ip
);
continue;
}
// Resolve target
let target = match route_match.target.or_else(|| {
route.action.targets.as_ref().and_then(|t| t.first())
}) {
let target = match route_match
.target
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
{
Some(t) => t,
None => {
warn!("No target for UDP route {:?}", route_id);
@@ -686,13 +786,18 @@ impl UdpListenerManager {
}
};
if let Err(e) = backend_socket.connect(&backend_addr).await {
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e);
error!(
"Failed to connect backend UDP socket to {}: {}",
backend_addr, e
);
continue;
}
let backend_socket = Arc::new(backend_socket);
debug!("New UDP session: {} -> {} (via port {}, real client {})",
client_addr, backend_addr, port, effective_client_ip);
debug!(
"New UDP session: {} -> {} (via port {}, real client {})",
client_addr, backend_addr, port, effective_client_ip
);
// Spawn return-path relay task
let session_cancel = CancellationToken::new();
@@ -709,7 +814,9 @@ impl UdpListenerManager {
let session = Arc::new(UdpSession {
backend_socket,
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
last_activity: std::sync::atomic::AtomicU64::new(
session_table.elapsed_ms(),
),
created_at: std::time::Instant::now(),
route_id: route_id.map(|s| s.to_string()),
source_ip: effective_client_ip,
@@ -718,7 +825,11 @@ impl UdpListenerManager {
cancel: session_cancel,
});
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) {
if !session_table.insert(
session_key,
Arc::clone(&session),
udp_config.max_sessions_per_ip,
) {
warn!("Failed to insert UDP session (race condition)");
continue;
}
@@ -735,7 +846,9 @@ impl UdpListenerManager {
// Forward datagram to backend
match session.backend_socket.send(datagram).await {
Ok(_) => {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
session
.last_activity
.store(session_table.elapsed_ms(), Ordering::Relaxed);
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
metrics.record_datagram_in();
}
@@ -779,7 +892,9 @@ impl UdpListenerManager {
Ok(_) => {
// Update session activity
if let Some(session) = session_table.get(&session_key) {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
session
.last_activity
.store(session_table.elapsed_ms(), Ordering::Relaxed);
}
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
metrics.record_datagram_out();
@@ -814,7 +929,8 @@ impl UdpListenerManager {
let json = serde_json::to_vec(&msg)?;
let mut guard = writer.lock().await;
let stream = guard.as_mut()
let stream = guard
.as_mut()
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
// Length-prefixed frame
@@ -879,9 +995,15 @@ impl UdpListenerManager {
}
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
let source_port = reply.get("sourcePort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let source_port = reply
.get("sourcePort")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u16;
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let payload_b64 = reply.get("payloadBase64").and_then(|v| v.as_str()).unwrap_or("");
let payload_b64 = reply
.get("payloadBase64")
.and_then(|v| v.as_str())
.unwrap_or("");
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
Ok(p) => p,
@@ -111,12 +111,15 @@ impl UdpSessionTable {
/// Look up an existing session.
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
self.sessions
.get(key)
.map(|entry| Arc::clone(entry.value()))
}
/// Check if we can create a new session for this IP (under the per-IP limit).
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
let count = self.ip_session_counts
let count = self
.ip_session_counts
.get(ip)
.map(|c| *c.value())
.unwrap_or(0);
@@ -124,12 +127,7 @@ impl UdpSessionTable {
}
/// Insert a new session. Returns false if per-IP limit exceeded.
pub fn insert(
&self,
key: SessionKey,
session: Arc<UdpSession>,
max_per_ip: u32,
) -> bool {
pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
let ip = session.source_ip;
// Atomically check and increment per-IP count
@@ -173,7 +171,9 @@ impl UdpSessionTable {
let mut removed = 0;
// Collect keys to remove (avoid holding DashMap refs during removal)
let stale_keys: Vec<SessionKey> = self.sessions.iter()
let stale_keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| {
let last = entry.value().last_activity.load(Ordering::Relaxed);
now_ms.saturating_sub(last) >= timeout_ms
@@ -185,7 +185,8 @@ impl UdpSessionTable {
if let Some(session) = self.remove(&key) {
debug!(
"UDP session expired: {} -> port {} (idle {}ms)",
session.client_addr, key.1,
session.client_addr,
key.1,
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
);
conn_tracker.connection_closed(&session.source_ip);
@@ -210,7 +211,9 @@ impl UdpSessionTable {
metrics: &MetricsCollector,
conn_tracker: &ConnectionTracker,
) -> usize {
let keys: Vec<SessionKey> = self.sessions.iter()
let keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| entry.key().1 == port)
.map(|entry| *entry.key())
.collect();
@@ -257,9 +260,8 @@ mod tests {
.enable_all()
.build()
.unwrap();
let backend_socket = rt.block_on(async {
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
});
let backend_socket =
rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
let child_cancel = cancel.child_token();
let return_task = rt.spawn(async move {