feat(smart-proxy): add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers
This commit is contained in:
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
rustproxy-security = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -73,7 +73,9 @@ impl ConnectionRegistry {
|
||||
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
let matches = entry.domain.as_deref()
|
||||
let matches = entry
|
||||
.domain
|
||||
.as_deref()
|
||||
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
|
||||
.unwrap_or(false);
|
||||
if matches {
|
||||
@@ -100,7 +102,11 @@ impl ConnectionRegistry {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
if entry.route_id.as_deref() == Some(route_id) {
|
||||
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) {
|
||||
if !RequestFilter::check_ip_security(
|
||||
new_security,
|
||||
&entry.source_ip,
|
||||
entry.domain.as_deref(),
|
||||
) {
|
||||
info!(
|
||||
"Terminating connection from {} — IP now blocked on route '{}'",
|
||||
entry.source_ip, route_id
|
||||
|
||||
@@ -31,7 +31,8 @@ impl ConnectionTracker {
|
||||
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
||||
// Check per-IP connection limit
|
||||
if let Some(max) = self.max_per_ip {
|
||||
let count = self.active
|
||||
let count = self
|
||||
.active
|
||||
.get(ip)
|
||||
.map(|c| c.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
@@ -48,7 +49,10 @@ impl ConnectionTracker {
|
||||
let timestamps = entry.value_mut();
|
||||
|
||||
// Remove timestamps older than 1 minute
|
||||
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
|
||||
while timestamps
|
||||
.front()
|
||||
.is_some_and(|t| now.duration_since(*t) >= one_minute)
|
||||
{
|
||||
timestamps.pop_front();
|
||||
}
|
||||
|
||||
@@ -111,7 +115,6 @@ impl ConnectionTracker {
|
||||
pub fn tracked_ips(&self) -> usize {
|
||||
self.active.len()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::debug;
|
||||
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
if let Some(ref ctx) = metrics {
|
||||
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
data.len() as u64,
|
||||
0,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
if let Some(ref ctx) = metrics_c2b {
|
||||
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
n as u64,
|
||||
0,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
backend_write.shutdown(),
|
||||
).await;
|
||||
let _ =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
total += n as u64;
|
||||
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
if let Some(ref ctx) = metrics_b2c {
|
||||
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
0,
|
||||
n as u64,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
client_write.shutdown(),
|
||||
).await;
|
||||
let _ =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
|
||||
@@ -4,26 +4,26 @@
|
||||
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
|
||||
//! and UDP datagram session tracking with forwarding.
|
||||
|
||||
pub mod tcp_listener;
|
||||
pub mod sni_parser;
|
||||
pub mod connection_registry;
|
||||
pub mod connection_tracker;
|
||||
pub mod forwarder;
|
||||
pub mod proxy_protocol;
|
||||
pub mod tls_handler;
|
||||
pub mod connection_tracker;
|
||||
pub mod connection_registry;
|
||||
pub mod socket_opts;
|
||||
pub mod udp_session;
|
||||
pub mod udp_listener;
|
||||
pub mod quic_handler;
|
||||
pub mod sni_parser;
|
||||
pub mod socket_opts;
|
||||
pub mod tcp_listener;
|
||||
pub mod tls_handler;
|
||||
pub mod udp_listener;
|
||||
pub mod udp_session;
|
||||
|
||||
pub use tcp_listener::*;
|
||||
pub use sni_parser::*;
|
||||
pub use connection_registry::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use forwarder::*;
|
||||
pub use proxy_protocol::*;
|
||||
pub use tls_handler::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use connection_registry::*;
|
||||
pub use socket_opts::*;
|
||||
pub use udp_session::*;
|
||||
pub use udp_listener::*;
|
||||
pub use quic_handler::*;
|
||||
pub use sni_parser::*;
|
||||
pub use socket_opts::*;
|
||||
pub use tcp_listener::*;
|
||||
pub use tls_handler::*;
|
||||
pub use udp_listener::*;
|
||||
pub use udp_session::*;
|
||||
|
||||
@@ -54,8 +54,8 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
.position(|w| w == b"\r\n")
|
||||
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
let line = std::str::from_utf8(&data[..line_end])
|
||||
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||
let line =
|
||||
std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
if !line.starts_with("PROXY ") {
|
||||
return Err(ProxyProtocolError::InvalidHeader);
|
||||
@@ -148,7 +148,10 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
let command = data[12] & 0x0F;
|
||||
// 0x0 = LOCAL, 0x1 = PROXY
|
||||
if command > 1 {
|
||||
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command)));
|
||||
return Err(ProxyProtocolError::Parse(format!(
|
||||
"Unknown command: {}",
|
||||
command
|
||||
)));
|
||||
}
|
||||
|
||||
// Address family (high nibble) + transport (low nibble) of byte 13
|
||||
@@ -182,7 +185,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET (0x1) + STREAM (0x1) = TCP4
|
||||
(0x1, 0x1) => {
|
||||
if addr_len < 12 {
|
||||
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv4 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||
@@ -200,7 +205,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET (0x1) + DGRAM (0x2) = UDP4
|
||||
(0x1, 0x2) => {
|
||||
if addr_len < 12 {
|
||||
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv4 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||
@@ -218,7 +225,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
|
||||
(0x2, 0x1) => {
|
||||
if addr_len < 36 {
|
||||
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv6 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||
@@ -236,7 +245,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
|
||||
(0x2, 0x2) => {
|
||||
if addr_len < 36 {
|
||||
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv6 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||
@@ -268,11 +279,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
}
|
||||
|
||||
/// Generate a PROXY protocol v2 binary header.
|
||||
pub fn generate_v2(
|
||||
source: &SocketAddr,
|
||||
dest: &SocketAddr,
|
||||
transport: ProxyV2Transport,
|
||||
) -> Vec<u8> {
|
||||
pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
|
||||
let transport_nibble: u8 = match transport {
|
||||
ProxyV2Transport::Stream => 0x1,
|
||||
ProxyV2Transport::Datagram => 0x2,
|
||||
@@ -462,7 +469,10 @@ mod tests {
|
||||
header.push(0x11);
|
||||
header.extend_from_slice(&12u16.to_be_bytes());
|
||||
header.extend_from_slice(&[0u8; 12]);
|
||||
assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion)));
|
||||
assert!(matches!(
|
||||
parse_v2(&header),
|
||||
Err(ProxyProtocolError::UnsupportedVersion)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -26,11 +26,12 @@ use tracing::{debug, info, warn};
|
||||
use rustproxy_config::{RouteConfig, TransportProtocol};
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::{MatchContext, RouteManager};
|
||||
use rustproxy_security::IpBlockList;
|
||||
|
||||
use rustproxy_http::h3_service::H3ProxyService;
|
||||
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
|
||||
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
||||
///
|
||||
@@ -48,8 +49,7 @@ pub fn create_quic_endpoint(
|
||||
quinn::EndpointConfig::default(),
|
||||
Some(server_config),
|
||||
socket,
|
||||
quinn::default_runtime()
|
||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
)?;
|
||||
|
||||
info!("QUIC endpoint listening on port {}", port);
|
||||
@@ -97,6 +97,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
port: u16,
|
||||
tls_config: Arc<RustlsServerConfig>,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<QuicProxyRelay> {
|
||||
// Bind external socket on the real port
|
||||
@@ -119,8 +120,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
quinn::EndpointConfig::default(),
|
||||
Some(server_config),
|
||||
internal_socket,
|
||||
quinn::default_runtime()
|
||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
)?;
|
||||
|
||||
let real_client_map = Arc::new(DashMap::new());
|
||||
@@ -129,12 +129,20 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
external_socket,
|
||||
quinn_internal_addr,
|
||||
proxy_ips,
|
||||
security_policy,
|
||||
Arc::clone(&real_client_map),
|
||||
cancel,
|
||||
));
|
||||
|
||||
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
|
||||
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
|
||||
info!(
|
||||
"QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
|
||||
port, quinn_internal_addr
|
||||
);
|
||||
Ok(QuicProxyRelay {
|
||||
endpoint,
|
||||
relay_task,
|
||||
real_client_map,
|
||||
})
|
||||
}
|
||||
|
||||
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
|
||||
@@ -144,6 +152,7 @@ async fn quic_proxy_relay_loop(
|
||||
external_socket: Arc<UdpSocket>,
|
||||
quinn_internal_addr: SocketAddr,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
@@ -184,26 +193,43 @@ async fn quic_proxy_relay_loop(
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
|
||||
debug!(
|
||||
"QUIC PROXY v2 from {}: real client {}",
|
||||
src_addr, header.source_addr
|
||||
);
|
||||
proxy_addr_map.insert(src_addr, header.source_addr);
|
||||
continue; // consume the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
|
||||
debug!(
|
||||
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
|
||||
src_addr, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine real client address
|
||||
let real_client = proxy_addr_map.get(&src_addr)
|
||||
let real_client = proxy_addr_map
|
||||
.get(&src_addr)
|
||||
.map(|r| *r)
|
||||
.unwrap_or(src_addr);
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
|
||||
debug!(
|
||||
"QUIC datagram from {} blocked by global security policy",
|
||||
real_client
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get or create relay session for this external source
|
||||
let session = match relay_sessions.get(&src_addr) {
|
||||
Some(s) => {
|
||||
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
s.last_activity
|
||||
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
Arc::clone(s.value())
|
||||
}
|
||||
None => {
|
||||
@@ -216,7 +242,10 @@ async fn quic_proxy_relay_loop(
|
||||
}
|
||||
};
|
||||
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
|
||||
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
|
||||
warn!(
|
||||
"QUIC relay: failed to connect relay socket to {}: {}",
|
||||
quinn_internal_addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let relay_local_addr = match relay_socket.local_addr() {
|
||||
@@ -248,8 +277,10 @@ async fn quic_proxy_relay_loop(
|
||||
});
|
||||
|
||||
relay_sessions.insert(src_addr, Arc::clone(&session));
|
||||
debug!("QUIC relay: new session for {} (relay {}), real client {}",
|
||||
src_addr, relay_local_addr, real_client);
|
||||
debug!(
|
||||
"QUIC relay: new session for {} (relay {}), real client {}",
|
||||
src_addr, relay_local_addr, real_client
|
||||
);
|
||||
|
||||
session
|
||||
}
|
||||
@@ -264,9 +295,11 @@ async fn quic_proxy_relay_loop(
|
||||
if last_cleanup.elapsed() >= cleanup_interval {
|
||||
last_cleanup = Instant::now();
|
||||
let now_ms = epoch.elapsed().as_millis() as u64;
|
||||
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
|
||||
let stale_keys: Vec<SocketAddr> = relay_sessions
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
||||
let age =
|
||||
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
||||
age > session_timeout_ms
|
||||
})
|
||||
.map(|entry| *entry.key())
|
||||
@@ -287,13 +320,17 @@ async fn quic_proxy_relay_loop(
|
||||
|
||||
// Also clean orphaned proxy_addr_map entries (PROXY header received
|
||||
// but no relay session was ever created, e.g. client never sent data)
|
||||
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
|
||||
let orphaned: Vec<SocketAddr> = proxy_addr_map
|
||||
.iter()
|
||||
.filter(|entry| relay_sessions.get(entry.key()).is_none())
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
for key in orphaned {
|
||||
proxy_addr_map.remove(&key);
|
||||
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
|
||||
debug!(
|
||||
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
|
||||
key
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -328,8 +365,14 @@ async fn relay_return_path(
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
|
||||
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
|
||||
if let Err(e) = external_socket
|
||||
.send_to(&buf[..len], external_src_addr)
|
||||
.await
|
||||
{
|
||||
debug!(
|
||||
"QUIC relay return send error to {}: {}",
|
||||
external_src_addr, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -353,6 +396,7 @@ pub async fn quic_accept_loop(
|
||||
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
) {
|
||||
loop {
|
||||
let incoming = tokio::select! {
|
||||
@@ -374,11 +418,21 @@ pub async fn quic_accept_loop(
|
||||
let remote_addr = incoming.remote_address();
|
||||
|
||||
// Resolve real client IP from PROXY protocol map if available
|
||||
let real_addr = real_client_map.as_ref()
|
||||
let real_addr = real_client_map
|
||||
.as_ref()
|
||||
.and_then(|map| map.get(&remote_addr).map(|r| *r))
|
||||
.unwrap_or(remote_addr);
|
||||
let ip = real_addr.ip();
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&ip) {
|
||||
debug!(
|
||||
"QUIC connection from {} blocked by global security policy",
|
||||
real_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Per-IP rate limiting
|
||||
if !conn_tracker.try_accept(&ip) {
|
||||
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
|
||||
@@ -414,7 +468,10 @@ pub async fn quic_accept_loop(
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security, &ip, ctx.domain,
|
||||
) {
|
||||
debug!("QUIC connection from {} blocked by route security", real_addr);
|
||||
debug!(
|
||||
"QUIC connection from {} blocked by route security",
|
||||
real_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -425,7 +482,8 @@ pub async fn quic_accept_loop(
|
||||
|
||||
// Resolve per-route cancel token (child of global cancel)
|
||||
let route_cancel = match route_id.as_deref() {
|
||||
Some(id) => route_cancels.entry(id.to_string())
|
||||
Some(id) => route_cancels
|
||||
.entry(id.to_string())
|
||||
.or_insert_with(|| cancel.child_token())
|
||||
.clone(),
|
||||
None => cancel.child_token(),
|
||||
@@ -445,7 +503,11 @@ pub async fn quic_accept_loop(
|
||||
let metrics = Arc::clone(&metrics);
|
||||
let conn_tracker = Arc::clone(&conn_tracker);
|
||||
let h3_svc = h3_service.clone();
|
||||
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
|
||||
let real_client_addr = if real_addr != remote_addr {
|
||||
Some(real_addr)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Register in connection registry (RAII guard removes on drop)
|
||||
@@ -462,7 +524,8 @@ pub async fn quic_accept_loop(
|
||||
impl Drop for QuicConnGuard {
|
||||
fn drop(&mut self) {
|
||||
self.tracker.connection_closed(&self.ip);
|
||||
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
||||
self.metrics
|
||||
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
||||
}
|
||||
}
|
||||
let _guard = QuicConnGuard {
|
||||
@@ -473,7 +536,17 @@ pub async fn quic_accept_loop(
|
||||
route_id,
|
||||
};
|
||||
|
||||
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &conn_cancel, h3_svc, real_client_addr).await {
|
||||
match handle_quic_connection(
|
||||
incoming,
|
||||
route,
|
||||
port,
|
||||
Arc::clone(&metrics),
|
||||
&conn_cancel,
|
||||
h3_svc,
|
||||
real_client_addr,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
|
||||
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
|
||||
}
|
||||
@@ -501,17 +574,28 @@ async fn handle_quic_connection(
|
||||
debug!("QUIC connection established from {}", effective_addr);
|
||||
|
||||
// Check if this route has HTTP/3 enabled
|
||||
let enable_http3 = route.action.udp.as_ref()
|
||||
let enable_http3 = route
|
||||
.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.and_then(|q| q.enable_http3)
|
||||
.unwrap_or(false);
|
||||
|
||||
if enable_http3 {
|
||||
if let Some(ref h3_svc) = h3_service {
|
||||
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
|
||||
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
|
||||
debug!(
|
||||
"HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
|
||||
route.name
|
||||
);
|
||||
h3_svc
|
||||
.handle_connection(connection, &route, port, real_client_addr, cancel)
|
||||
.await
|
||||
} else {
|
||||
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
|
||||
warn!(
|
||||
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
|
||||
route.name
|
||||
);
|
||||
// Keep connection alive until cancelled
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {}
|
||||
@@ -523,7 +607,8 @@ async fn handle_quic_connection(
|
||||
}
|
||||
} else {
|
||||
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
|
||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -545,7 +630,10 @@ async fn handle_quic_stream_forwarding(
|
||||
let metrics_arc = metrics;
|
||||
|
||||
// Resolve backend target
|
||||
let target = route.action.targets.as_ref()
|
||||
let target = route
|
||||
.action
|
||||
.targets
|
||||
.as_ref()
|
||||
.and_then(|t| t.first())
|
||||
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
|
||||
let backend_host = target.host.first();
|
||||
@@ -576,19 +664,20 @@ async fn handle_quic_stream_forwarding(
|
||||
|
||||
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
||||
tokio::spawn(async move {
|
||||
match forward_quic_stream_to_tcp(
|
||||
send_stream,
|
||||
recv_stream,
|
||||
&backend_addr,
|
||||
stream_cancel,
|
||||
).await {
|
||||
match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
|
||||
.await
|
||||
{
|
||||
Ok((bytes_in, bytes_out)) => {
|
||||
stream_metrics.record_bytes(
|
||||
bytes_in, bytes_out,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
stream_route_id.as_deref(),
|
||||
Some(&ip_str),
|
||||
);
|
||||
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
|
||||
debug!(
|
||||
"QUIC stream forwarded: {}B in, {}B out",
|
||||
bytes_in, bytes_out
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("QUIC stream forwarding error: {}", e);
|
||||
@@ -640,10 +729,7 @@ async fn forward_quic_stream_to_tcp(
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
tcp_write.shutdown(),
|
||||
).await;
|
||||
let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
@@ -721,8 +807,8 @@ mod tests {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Generate a single self-signed cert and use its key pair
|
||||
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
|
||||
.unwrap();
|
||||
let self_signed =
|
||||
rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
|
||||
let cert_der = self_signed.cert.der().clone();
|
||||
let key_der = self_signed.key_pair.serialize_der();
|
||||
|
||||
@@ -737,6 +823,10 @@ mod tests {
|
||||
|
||||
// Port 0 = OS assigns a free port
|
||||
let result = create_quic_endpoint(0, Arc::new(tls_config));
|
||||
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"QUIC endpoint creation failed: {:?}",
|
||||
result.err()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes) - informational, we parse incrementally
|
||||
let _handshake_len = ((data[6] as usize) << 16)
|
||||
| ((data[7] as usize) << 8)
|
||||
| (data[8] as usize);
|
||||
let _handshake_len =
|
||||
((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
|
||||
|
||||
let hello = &data[9..];
|
||||
|
||||
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
|
||||
pub fn extract_http_host(data: &[u8]) -> Option<String> {
|
||||
let text = std::str::from_utf8(data).ok()?;
|
||||
for line in text.split("\r\n") {
|
||||
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
|
||||
if let Some(value) = line
|
||||
.strip_prefix("Host: ")
|
||||
.or_else(|| line.strip_prefix("host: "))
|
||||
{
|
||||
// Strip port if present
|
||||
let host = value.split(':').next().unwrap_or(value).trim();
|
||||
if !host.is_empty() {
|
||||
@@ -196,7 +198,7 @@ pub fn is_http(data: &[u8]) -> bool {
|
||||
b"PATC",
|
||||
b"OPTI",
|
||||
b"CONN",
|
||||
b"PRI ", // HTTP/2 connection preface
|
||||
b"PRI ", // HTTP/2 connection preface
|
||||
];
|
||||
starts.iter().any(|s| data.starts_with(s))
|
||||
}
|
||||
@@ -213,7 +215,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_too_short() {
|
||||
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
|
||||
assert!(matches!(
|
||||
extract_sni(&[0x16, 0x03]),
|
||||
SniResult::NeedMoreData
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -263,7 +268,8 @@ mod tests {
|
||||
// Extension: type=0x0000 (SNI), length, data
|
||||
let sni_extension = {
|
||||
let mut e = Vec::new();
|
||||
e.push(0x00); e.push(0x00); // SNI type
|
||||
e.push(0x00);
|
||||
e.push(0x00); // SNI type
|
||||
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
|
||||
e.push((sni_ext_data.len() & 0xFF) as u8);
|
||||
e.extend_from_slice(&sni_ext_data);
|
||||
@@ -283,16 +289,20 @@ mod tests {
|
||||
let hello_body = {
|
||||
let mut h = Vec::new();
|
||||
// Client version TLS 1.2
|
||||
h.push(0x03); h.push(0x03);
|
||||
h.push(0x03);
|
||||
h.push(0x03);
|
||||
// Random (32 bytes)
|
||||
h.extend_from_slice(&[0u8; 32]);
|
||||
// Session ID length = 0
|
||||
h.push(0x00);
|
||||
// Cipher suites: length=2, one suite
|
||||
h.push(0x00); h.push(0x02);
|
||||
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
// Compression methods: length=1, null
|
||||
h.push(0x01); h.push(0x00);
|
||||
h.push(0x00);
|
||||
h.push(0x02);
|
||||
h.push(0x00);
|
||||
h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
// Compression methods: length=1, null
|
||||
h.push(0x01);
|
||||
h.push(0x00);
|
||||
// Extensions
|
||||
h.extend_from_slice(&extensions);
|
||||
h
|
||||
@@ -302,7 +312,7 @@ mod tests {
|
||||
let handshake = {
|
||||
let mut hs = Vec::new();
|
||||
hs.push(0x01); // ClientHello
|
||||
// 3-byte length
|
||||
// 3-byte length
|
||||
hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
|
||||
hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
|
||||
hs.push((hello_body.len() & 0xFF) as u8);
|
||||
@@ -313,7 +323,8 @@ mod tests {
|
||||
// TLS record: type=0x16, version TLS 1.0, length
|
||||
let mut record = Vec::new();
|
||||
record.push(0x16); // Handshake
|
||||
record.push(0x03); record.push(0x01); // TLS 1.0
|
||||
record.push(0x03);
|
||||
record.push(0x01); // TLS 1.0
|
||||
record.push(((handshake.len() >> 8) & 0xFF) as u8);
|
||||
record.push((handshake.len() & 0xFF) as u8);
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
|
||||
use rustls::sign::CertifiedKey;
|
||||
use rustls::ServerConfig;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
|
||||
use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::tcp_listener::TlsCertConfig;
|
||||
@@ -29,7 +29,9 @@ pub struct CertResolver {
|
||||
impl CertResolver {
|
||||
/// Build a resolver from PEM-encoded cert/key configs.
|
||||
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
|
||||
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn new(
|
||||
configs: &HashMap<String, TlsCertConfig>,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let provider = rustls::crypto::ring::default_provider();
|
||||
let mut certs = HashMap::new();
|
||||
@@ -38,8 +40,10 @@ impl CertResolver {
|
||||
for (domain, cfg) in configs {
|
||||
let cert_chain = load_certs(&cfg.cert_pem)?;
|
||||
let key = load_private_key(&cfg.key_pem)?;
|
||||
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
|
||||
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
|
||||
let ck = Arc::new(
|
||||
CertifiedKey::from_der(cert_chain, key, &provider)
|
||||
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
|
||||
);
|
||||
if domain == "*" {
|
||||
fallback = Some(Arc::clone(&ck));
|
||||
}
|
||||
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
|
||||
|
||||
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
|
||||
/// The returned acceptor can be reused across all connections (cheap Arc clone).
|
||||
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_shared_tls_acceptor(
|
||||
resolver: CertResolver,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let mut config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
|
||||
// Shared session cache — enables session ID resumption across connections
|
||||
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
|
||||
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
|
||||
config.ticketer = rustls::crypto::ring::Ticketer::new()
|
||||
.map_err(|e| format!("Ticketer: {}", e))?;
|
||||
config.ticketer =
|
||||
rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
|
||||
|
||||
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1");
|
||||
info!(
|
||||
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
|
||||
);
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor from PEM-encoded cert and key data.
|
||||
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
|
||||
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_tls_acceptor(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
build_tls_acceptor_with_config(cert_pem, key_pem, None)
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
|
||||
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
|
||||
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_tls_acceptor_h1_only(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let certs = load_certs(cert_pem)?;
|
||||
let key = load_private_key(key_pem)?;
|
||||
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
|
||||
// Apply TLS version restrictions
|
||||
let versions = resolve_tls_versions(route_tls.versions.as_deref());
|
||||
let builder = ServerConfig::builder_with_protocol_versions(&versions);
|
||||
builder
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?
|
||||
builder.with_no_client_auth().with_single_cert(certs, key)?
|
||||
} else {
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
|
||||
}
|
||||
|
||||
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
|
||||
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||
fn resolve_tls_versions(
|
||||
versions: Option<&[String]>,
|
||||
) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||
let versions = match versions {
|
||||
Some(v) if !v.is_empty() => v,
|
||||
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
|
||||
@@ -207,15 +221,17 @@ pub async fn accept_tls(
|
||||
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
||||
|
||||
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
||||
SHARED_CLIENT_CONFIG.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
info!("Built shared backend TLS client config with session resumption");
|
||||
Arc::new(config)
|
||||
}).clone()
|
||||
SHARED_CLIENT_CONFIG
|
||||
.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
info!("Built shared backend TLS client config with session resumption");
|
||||
Arc::new(config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
|
||||
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
||||
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
||||
|
||||
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
|
||||
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection");
|
||||
Arc::new(config)
|
||||
}).clone()
|
||||
SHARED_CLIENT_CONFIG_ALPN
|
||||
.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
info!(
|
||||
"Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
|
||||
);
|
||||
Arc::new(config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
|
||||
@@ -249,7 +269,8 @@ pub async fn connect_tls(
|
||||
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
|
||||
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) {
|
||||
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
|
||||
{
|
||||
debug!("Failed to set keepalive on backend TLS socket: {}", e);
|
||||
}
|
||||
|
||||
@@ -260,10 +281,12 @@ pub async fn connect_tls(
|
||||
}
|
||||
|
||||
/// Load certificates from PEM string.
|
||||
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
fn load_certs(
|
||||
pem: &str,
|
||||
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let certs: Vec<CertificateDer<'static>> =
|
||||
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
|
||||
if certs.is_empty() {
|
||||
return Err("No certificates found in PEM data".into());
|
||||
}
|
||||
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
|
||||
}
|
||||
|
||||
/// Load private key from PEM string.
|
||||
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
fn load_private_key(
|
||||
pem: &str,
|
||||
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
// Try PKCS8 first, then RSA, then EC
|
||||
let key = rustls_pemfile::private_key(&mut reader)?
|
||||
.ok_or("No private key found in PEM data")?;
|
||||
let key =
|
||||
rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,14 +17,15 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use rustproxy_config::{RouteActionType, TransportProtocol};
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::{MatchContext, RouteManager};
|
||||
use rustproxy_security::IpBlockList;
|
||||
|
||||
use rustproxy_http::h3_service::H3ProxyService;
|
||||
|
||||
@@ -62,6 +63,8 @@ pub struct UdpListenerManager {
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
/// Shared connection registry for selective recycling.
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
/// Global ingress block policy, hot-reloadable without restarting listeners.
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
}
|
||||
|
||||
impl Drop for UdpListenerManager {
|
||||
@@ -99,17 +102,26 @@ impl UdpListenerManager {
|
||||
proxy_ips: Arc::new(Vec::new()),
|
||||
route_cancels,
|
||||
connection_registry,
|
||||
security_policy: Arc::new(ArcSwap::from(Arc::new(IpBlockList::empty()))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
|
||||
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
|
||||
if !ips.is_empty() {
|
||||
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len());
|
||||
info!(
|
||||
"UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs",
|
||||
ips.len()
|
||||
);
|
||||
}
|
||||
self.proxy_ips = Arc::new(ips);
|
||||
}
|
||||
|
||||
/// Set the shared global ingress security policy.
|
||||
pub fn set_security_policy(&mut self, policy: Arc<ArcSwap<IpBlockList>>) {
|
||||
self.security_policy = policy;
|
||||
}
|
||||
|
||||
/// Set the H3 proxy service for HTTP/3 request handling.
|
||||
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
|
||||
self.h3_service = Some(svc);
|
||||
@@ -142,7 +154,9 @@ impl UdpListenerManager {
|
||||
// Check if any route on this port uses QUIC
|
||||
let rm = self.route_manager.load();
|
||||
let has_quic = rm.routes_for_port(port).iter().any(|r| {
|
||||
r.action.udp.as_ref()
|
||||
r.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.is_some()
|
||||
});
|
||||
@@ -164,8 +178,10 @@ impl UdpListenerManager {
|
||||
None,
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
info!("QUIC endpoint started on port {}", port);
|
||||
} else {
|
||||
// Proxy relay path: we own external socket, quinn on localhost
|
||||
@@ -173,6 +189,7 @@ impl UdpListenerManager {
|
||||
port,
|
||||
tls,
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
self.cancel_token.child_token(),
|
||||
)?;
|
||||
let endpoint_for_updates = relay.endpoint.clone();
|
||||
@@ -187,13 +204,18 @@ impl UdpListenerManager {
|
||||
Some(relay.real_client_map),
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
info!("QUIC endpoint with PROXY relay started on port {}", port);
|
||||
}
|
||||
return Ok(());
|
||||
} else {
|
||||
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
|
||||
warn!(
|
||||
"QUIC routes on port {} but no TLS config provided, falling back to raw UDP",
|
||||
port
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,6 +236,7 @@ impl UdpListenerManager {
|
||||
Arc::clone(&self.relay_writer),
|
||||
self.cancel_token.child_token(),
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
|
||||
self.listeners.insert(port, (handle, None));
|
||||
@@ -254,8 +277,10 @@ impl UdpListenerManager {
|
||||
}
|
||||
debug!("UDP listener stopped on port {}", port);
|
||||
}
|
||||
info!("All UDP listeners stopped, {} sessions remaining",
|
||||
self.session_table.session_count());
|
||||
info!(
|
||||
"All UDP listeners stopped, {} sessions remaining",
|
||||
self.session_table.session_count()
|
||||
);
|
||||
}
|
||||
|
||||
/// Update TLS config on all active QUIC endpoints (cert refresh).
|
||||
@@ -288,11 +313,15 @@ impl UdpListenerManager {
|
||||
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
|
||||
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
|
||||
let rm = self.route_manager.load();
|
||||
let upgrade_ports: Vec<u16> = self.listeners.iter()
|
||||
let upgrade_ports: Vec<u16> = self
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|(_, (_, endpoint))| endpoint.is_none())
|
||||
.filter(|(port, _)| {
|
||||
rm.routes_for_port(**port).iter().any(|r| {
|
||||
r.action.udp.as_ref()
|
||||
r.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.is_some()
|
||||
})
|
||||
@@ -301,17 +330,23 @@ impl UdpListenerManager {
|
||||
.collect();
|
||||
|
||||
for port in upgrade_ports {
|
||||
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port);
|
||||
info!(
|
||||
"Upgrading raw UDP listener on port {} to QUIC endpoint",
|
||||
port
|
||||
);
|
||||
|
||||
// Stop the raw UDP listener task and drain sessions to release the socket
|
||||
if let Some((handle, _)) = self.listeners.remove(&port) {
|
||||
handle.abort();
|
||||
}
|
||||
let drained = self.session_table.drain_port(
|
||||
port, &self.metrics, &self.conn_tracker,
|
||||
);
|
||||
let drained = self
|
||||
.session_table
|
||||
.drain_port(port, &self.metrics, &self.conn_tracker);
|
||||
if drained > 0 {
|
||||
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port);
|
||||
debug!(
|
||||
"Drained {} UDP sessions on port {} for QUIC upgrade",
|
||||
drained, port
|
||||
);
|
||||
}
|
||||
|
||||
// Brief yield to let aborted tasks drop their socket references
|
||||
@@ -326,11 +361,17 @@ impl UdpListenerManager {
|
||||
|
||||
match create_result {
|
||||
Ok(()) => {
|
||||
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port);
|
||||
info!(
|
||||
"QUIC endpoint started on port {} (upgraded from raw UDP)",
|
||||
port
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Port may still be held — retry once after a brief delay
|
||||
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e);
|
||||
warn!(
|
||||
"QUIC endpoint creation failed on port {}, retrying: {}",
|
||||
port, e
|
||||
);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
let retry_result = if self.proxy_ips.is_empty() {
|
||||
@@ -341,11 +382,17 @@ impl UdpListenerManager {
|
||||
|
||||
match retry_result {
|
||||
Ok(()) => {
|
||||
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port);
|
||||
info!(
|
||||
"QUIC endpoint started on port {} (upgraded from raw UDP, retry)",
|
||||
port
|
||||
);
|
||||
}
|
||||
Err(e2) => {
|
||||
error!("Failed to upgrade port {} to QUIC after retry: {}. \
|
||||
Rebinding as raw UDP.", port, e2);
|
||||
error!(
|
||||
"Failed to upgrade port {} to QUIC after retry: {}. \
|
||||
Rebinding as raw UDP.",
|
||||
port, e2
|
||||
);
|
||||
// Fallback: rebind as raw UDP so the port isn't dead
|
||||
if let Ok(()) = self.rebind_raw_udp(port).await {
|
||||
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
|
||||
@@ -358,7 +405,11 @@ impl UdpListenerManager {
|
||||
}
|
||||
|
||||
/// Create a direct QUIC endpoint (quinn owns the socket).
|
||||
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
||||
fn create_quic_direct(
|
||||
&mut self,
|
||||
port: u16,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
|
||||
let endpoint_for_updates = endpoint.clone();
|
||||
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||
@@ -372,17 +423,24 @@ impl UdpListenerManager {
|
||||
None,
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a QUIC endpoint with PROXY protocol relay.
|
||||
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
||||
fn create_quic_with_relay(
|
||||
&mut self,
|
||||
port: u16,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
|
||||
port,
|
||||
tls_config,
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
self.cancel_token.child_token(),
|
||||
)?;
|
||||
let endpoint_for_updates = relay.endpoint.clone();
|
||||
@@ -397,8 +455,10 @@ impl UdpListenerManager {
|
||||
Some(relay.real_client_map),
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -419,6 +479,7 @@ impl UdpListenerManager {
|
||||
Arc::clone(&self.relay_writer),
|
||||
self.cancel_token.child_token(),
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
|
||||
self.listeners.insert(port, (handle, None));
|
||||
@@ -458,7 +519,10 @@ impl UdpListenerManager {
|
||||
info!("Datagram handler relay connected to {}", path);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to connect datagram handler relay to {}: {}", path, e);
|
||||
error!(
|
||||
"Failed to connect datagram handler relay to {}: {}",
|
||||
path, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -514,6 +578,7 @@ impl UdpListenerManager {
|
||||
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
||||
cancel: CancellationToken,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
) {
|
||||
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
|
||||
let mut buf = vec![0u8; 65535];
|
||||
@@ -528,9 +593,11 @@ impl UdpListenerManager {
|
||||
|
||||
loop {
|
||||
// Periodic cleanup: remove proxy_addr_map entries with no active session
|
||||
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval {
|
||||
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval
|
||||
{
|
||||
last_proxy_cleanup = tokio::time::Instant::now();
|
||||
let stale: Vec<SocketAddr> = proxy_addr_map.iter()
|
||||
let stale: Vec<SocketAddr> = proxy_addr_map
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let key: SessionKey = (*entry.key(), port);
|
||||
session_table.get(&key).is_none()
|
||||
@@ -538,7 +605,11 @@ impl UdpListenerManager {
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
if !stale.is_empty() {
|
||||
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port);
|
||||
debug!(
|
||||
"UDP proxy_addr_map cleanup: removing {} stale entries on port {}",
|
||||
stale.len(),
|
||||
port
|
||||
);
|
||||
for addr in stale {
|
||||
proxy_addr_map.remove(&addr);
|
||||
}
|
||||
@@ -564,34 +635,50 @@ impl UdpListenerManager {
|
||||
let datagram = &buf[..len];
|
||||
|
||||
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
|
||||
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
||||
let session_key: SessionKey = (client_addr, port);
|
||||
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) {
|
||||
// No session and no prior PROXY header — check for PROXY v2
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr);
|
||||
proxy_addr_map.insert(client_addr, header.source_addr);
|
||||
continue; // discard the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
|
||||
client_addr.ip()
|
||||
let effective_client_ip =
|
||||
if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
||||
let session_key: SessionKey = (client_addr, port);
|
||||
if session_table.get(&session_key).is_none()
|
||||
&& !proxy_addr_map.contains_key(&client_addr)
|
||||
{
|
||||
// No session and no prior PROXY header — check for PROXY v2
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!(
|
||||
"UDP PROXY v2 from {}: real client {}",
|
||||
client_addr, header.source_addr
|
||||
);
|
||||
proxy_addr_map.insert(client_addr, header.source_addr);
|
||||
continue; // discard the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
|
||||
client_addr.ip()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
// Use real client IP if we've previously seen a PROXY v2 header
|
||||
proxy_addr_map
|
||||
.get(&client_addr)
|
||||
.map(|r| r.ip())
|
||||
.unwrap_or_else(|| client_addr.ip())
|
||||
}
|
||||
} else {
|
||||
// Use real client IP if we've previously seen a PROXY v2 header
|
||||
proxy_addr_map.get(&client_addr)
|
||||
.map(|r| r.ip())
|
||||
.unwrap_or_else(|| client_addr.ip())
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
};
|
||||
client_addr.ip()
|
||||
};
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&effective_client_ip) {
|
||||
debug!(
|
||||
"UDP datagram from {} blocked by global security policy",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Route matching — use effective (real) client IP
|
||||
let rm = route_manager.load();
|
||||
@@ -611,7 +698,10 @@ impl UdpListenerManager {
|
||||
let route_match = match rm.find_route(&ctx) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
debug!("No UDP route matched for port {} from {}", port, client_addr);
|
||||
debug!(
|
||||
"No UDP route matched for port {} from {}",
|
||||
port, client_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -627,7 +717,9 @@ impl UdpListenerManager {
|
||||
&client_addr,
|
||||
port,
|
||||
datagram,
|
||||
).await {
|
||||
)
|
||||
.await
|
||||
{
|
||||
debug!("Failed to relay UDP datagram to TS: {}", e);
|
||||
}
|
||||
continue;
|
||||
@@ -638,8 +730,10 @@ impl UdpListenerManager {
|
||||
|
||||
// Check datagram size
|
||||
if len as u32 > udp_config.max_datagram_size {
|
||||
debug!("UDP datagram too large ({} > {}) from {}, dropping",
|
||||
len, udp_config.max_datagram_size, client_addr);
|
||||
debug!(
|
||||
"UDP datagram too large ({} > {}) from {}, dropping",
|
||||
len, udp_config.max_datagram_size, client_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -651,21 +745,27 @@ impl UdpListenerManager {
|
||||
None => {
|
||||
// New session — check per-IP limits using the real client IP
|
||||
if !conn_tracker.try_accept(&effective_client_ip) {
|
||||
debug!("UDP session rejected for {} (rate limit)", effective_client_ip);
|
||||
debug!(
|
||||
"UDP session rejected for {} (rate limit)",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if !session_table.can_create_session(
|
||||
&effective_client_ip,
|
||||
udp_config.max_sessions_per_ip,
|
||||
) {
|
||||
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip);
|
||||
if !session_table
|
||||
.can_create_session(&effective_client_ip, udp_config.max_sessions_per_ip)
|
||||
{
|
||||
debug!(
|
||||
"UDP session rejected for {} (per-IP session limit)",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Resolve target
|
||||
let target = match route_match.target.or_else(|| {
|
||||
route.action.targets.as_ref().and_then(|t| t.first())
|
||||
}) {
|
||||
let target = match route_match
|
||||
.target
|
||||
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
|
||||
{
|
||||
Some(t) => t,
|
||||
None => {
|
||||
warn!("No target for UDP route {:?}", route_id);
|
||||
@@ -686,13 +786,18 @@ impl UdpListenerManager {
|
||||
}
|
||||
};
|
||||
if let Err(e) = backend_socket.connect(&backend_addr).await {
|
||||
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e);
|
||||
error!(
|
||||
"Failed to connect backend UDP socket to {}: {}",
|
||||
backend_addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let backend_socket = Arc::new(backend_socket);
|
||||
|
||||
debug!("New UDP session: {} -> {} (via port {}, real client {})",
|
||||
client_addr, backend_addr, port, effective_client_ip);
|
||||
debug!(
|
||||
"New UDP session: {} -> {} (via port {}, real client {})",
|
||||
client_addr, backend_addr, port, effective_client_ip
|
||||
);
|
||||
|
||||
// Spawn return-path relay task
|
||||
let session_cancel = CancellationToken::new();
|
||||
@@ -709,7 +814,9 @@ impl UdpListenerManager {
|
||||
|
||||
let session = Arc::new(UdpSession {
|
||||
backend_socket,
|
||||
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
|
||||
last_activity: std::sync::atomic::AtomicU64::new(
|
||||
session_table.elapsed_ms(),
|
||||
),
|
||||
created_at: std::time::Instant::now(),
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
source_ip: effective_client_ip,
|
||||
@@ -718,7 +825,11 @@ impl UdpListenerManager {
|
||||
cancel: session_cancel,
|
||||
});
|
||||
|
||||
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) {
|
||||
if !session_table.insert(
|
||||
session_key,
|
||||
Arc::clone(&session),
|
||||
udp_config.max_sessions_per_ip,
|
||||
) {
|
||||
warn!("Failed to insert UDP session (race condition)");
|
||||
continue;
|
||||
}
|
||||
@@ -735,7 +846,9 @@ impl UdpListenerManager {
|
||||
// Forward datagram to backend
|
||||
match session.backend_socket.send(datagram).await {
|
||||
Ok(_) => {
|
||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
session
|
||||
.last_activity
|
||||
.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
|
||||
metrics.record_datagram_in();
|
||||
}
|
||||
@@ -779,7 +892,9 @@ impl UdpListenerManager {
|
||||
Ok(_) => {
|
||||
// Update session activity
|
||||
if let Some(session) = session_table.get(&session_key) {
|
||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
session
|
||||
.last_activity
|
||||
.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
}
|
||||
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
|
||||
metrics.record_datagram_out();
|
||||
@@ -814,7 +929,8 @@ impl UdpListenerManager {
|
||||
let json = serde_json::to_vec(&msg)?;
|
||||
|
||||
let mut guard = writer.lock().await;
|
||||
let stream = guard.as_mut()
|
||||
let stream = guard
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
|
||||
|
||||
// Length-prefixed frame
|
||||
@@ -879,9 +995,15 @@ impl UdpListenerManager {
|
||||
}
|
||||
|
||||
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let source_port = reply.get("sourcePort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let source_port = reply
|
||||
.get("sourcePort")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u16;
|
||||
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let payload_b64 = reply.get("payloadBase64").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let payload_b64 = reply
|
||||
.get("payloadBase64")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
|
||||
Ok(p) => p,
|
||||
|
||||
@@ -111,12 +111,15 @@ impl UdpSessionTable {
|
||||
|
||||
/// Look up an existing session.
|
||||
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
||||
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
|
||||
self.sessions
|
||||
.get(key)
|
||||
.map(|entry| Arc::clone(entry.value()))
|
||||
}
|
||||
|
||||
/// Check if we can create a new session for this IP (under the per-IP limit).
|
||||
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
|
||||
let count = self.ip_session_counts
|
||||
let count = self
|
||||
.ip_session_counts
|
||||
.get(ip)
|
||||
.map(|c| *c.value())
|
||||
.unwrap_or(0);
|
||||
@@ -124,12 +127,7 @@ impl UdpSessionTable {
|
||||
}
|
||||
|
||||
/// Insert a new session. Returns false if per-IP limit exceeded.
|
||||
pub fn insert(
|
||||
&self,
|
||||
key: SessionKey,
|
||||
session: Arc<UdpSession>,
|
||||
max_per_ip: u32,
|
||||
) -> bool {
|
||||
pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
|
||||
let ip = session.source_ip;
|
||||
|
||||
// Atomically check and increment per-IP count
|
||||
@@ -173,7 +171,9 @@ impl UdpSessionTable {
|
||||
let mut removed = 0;
|
||||
|
||||
// Collect keys to remove (avoid holding DashMap refs during removal)
|
||||
let stale_keys: Vec<SessionKey> = self.sessions.iter()
|
||||
let stale_keys: Vec<SessionKey> = self
|
||||
.sessions
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let last = entry.value().last_activity.load(Ordering::Relaxed);
|
||||
now_ms.saturating_sub(last) >= timeout_ms
|
||||
@@ -185,7 +185,8 @@ impl UdpSessionTable {
|
||||
if let Some(session) = self.remove(&key) {
|
||||
debug!(
|
||||
"UDP session expired: {} -> port {} (idle {}ms)",
|
||||
session.client_addr, key.1,
|
||||
session.client_addr,
|
||||
key.1,
|
||||
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
|
||||
);
|
||||
conn_tracker.connection_closed(&session.source_ip);
|
||||
@@ -210,7 +211,9 @@ impl UdpSessionTable {
|
||||
metrics: &MetricsCollector,
|
||||
conn_tracker: &ConnectionTracker,
|
||||
) -> usize {
|
||||
let keys: Vec<SessionKey> = self.sessions.iter()
|
||||
let keys: Vec<SessionKey> = self
|
||||
.sessions
|
||||
.iter()
|
||||
.filter(|entry| entry.key().1 == port)
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
@@ -257,9 +260,8 @@ mod tests {
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
let backend_socket = rt.block_on(async {
|
||||
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
|
||||
});
|
||||
let backend_socket =
|
||||
rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
|
||||
|
||||
let child_cancel = cancel.child_token();
|
||||
let return_task = rt.spawn(async move {
|
||||
|
||||
Reference in New Issue
Block a user