feat(smart-proxy): add UDP transport support with QUIC/HTTP3 routing and datagram handler relay
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
//! # rustproxy-passthrough
|
||||
//!
|
||||
//! Raw TCP/SNI passthrough engine for RustProxy.
|
||||
//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding.
|
||||
//! Raw TCP/SNI passthrough engine and UDP listener for RustProxy.
|
||||
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
|
||||
//! and UDP datagram session tracking with forwarding.
|
||||
|
||||
pub mod tcp_listener;
|
||||
pub mod sni_parser;
|
||||
@@ -11,6 +12,9 @@ pub mod tls_handler;
|
||||
pub mod connection_tracker;
|
||||
pub mod socket_relay;
|
||||
pub mod socket_opts;
|
||||
pub mod udp_session;
|
||||
pub mod udp_listener;
|
||||
pub mod quic_handler;
|
||||
|
||||
pub use tcp_listener::*;
|
||||
pub use sni_parser::*;
|
||||
@@ -20,3 +24,6 @@ pub use tls_handler::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use socket_relay::*;
|
||||
pub use socket_opts::*;
|
||||
pub use udp_session::*;
|
||||
pub use udp_listener::*;
|
||||
pub use quic_handler::*;
|
||||
|
||||
309
rust/crates/rustproxy-passthrough/src/quic_handler.rs
Normal file
309
rust/crates/rustproxy-passthrough/src/quic_handler.rs
Normal file
@@ -0,0 +1,309 @@
|
||||
//! QUIC connection handling.
|
||||
//!
|
||||
//! Manages QUIC endpoints (via quinn), accepts connections, and either:
|
||||
//! - Forwards streams bidirectionally to TCP backends (QUIC termination)
|
||||
//! - Dispatches to H3ProxyService for HTTP/3 handling (Phase 5)
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use quinn::{Endpoint, ServerConfig as QuinnServerConfig};
|
||||
use rustls::ServerConfig as RustlsServerConfig;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use rustproxy_config::{RouteConfig, TransportProtocol};
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::{MatchContext, RouteManager};
|
||||
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
use crate::forwarder::ForwardMetricsCtx;
|
||||
|
||||
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
||||
///
|
||||
/// The TLS config must have ALPN protocols set (e.g., `h3` for HTTP/3).
|
||||
pub fn create_quic_endpoint(
|
||||
port: u16,
|
||||
tls_config: Arc<RustlsServerConfig>,
|
||||
) -> anyhow::Result<Endpoint> {
|
||||
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
|
||||
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
|
||||
|
||||
let socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
|
||||
let endpoint = Endpoint::new(
|
||||
quinn::EndpointConfig::default(),
|
||||
Some(server_config),
|
||||
socket,
|
||||
quinn::default_runtime()
|
||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
)?;
|
||||
|
||||
info!("QUIC endpoint listening on port {}", port);
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
/// Run the QUIC accept loop for a single endpoint.
|
||||
///
|
||||
/// Accepts incoming QUIC connections and spawns a task per connection.
|
||||
pub async fn quic_accept_loop(
|
||||
endpoint: Endpoint,
|
||||
port: u16,
|
||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
let incoming = tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("QUIC accept loop on port {} cancelled", port);
|
||||
break;
|
||||
}
|
||||
incoming = endpoint.accept() => {
|
||||
match incoming {
|
||||
Some(conn) => conn,
|
||||
None => {
|
||||
debug!("QUIC endpoint on port {} closed", port);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let remote_addr = incoming.remote_address();
|
||||
let ip = remote_addr.ip();
|
||||
|
||||
// Per-IP rate limiting
|
||||
if !conn_tracker.try_accept(&ip) {
|
||||
debug!("QUIC connection rejected from {} (rate limit)", remote_addr);
|
||||
// Drop `incoming` to refuse the connection
|
||||
continue;
|
||||
}
|
||||
|
||||
// Route matching (port + client IP, no domain yet — QUIC Initial is encrypted)
|
||||
let rm = route_manager.load();
|
||||
let ip_str = ip.to_string();
|
||||
let ctx = MatchContext {
|
||||
port,
|
||||
domain: None,
|
||||
path: None,
|
||||
client_ip: Some(&ip_str),
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: true,
|
||||
protocol: Some("quic"),
|
||||
transport: Some(TransportProtocol::Udp),
|
||||
};
|
||||
|
||||
let route = match rm.find_route(&ctx) {
|
||||
Some(m) => m.route.clone(),
|
||||
None => {
|
||||
debug!("No QUIC route matched for port {} from {}", port, remote_addr);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
conn_tracker.connection_opened(&ip);
|
||||
let route_id = route.name.clone().or(route.id.clone());
|
||||
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
|
||||
|
||||
let metrics = Arc::clone(&metrics);
|
||||
let conn_tracker = Arc::clone(&conn_tracker);
|
||||
let cancel = cancel.child_token();
|
||||
|
||||
tokio::spawn(async move {
|
||||
match handle_quic_connection(incoming, route, port, &metrics, &cancel).await {
|
||||
Ok(()) => debug!("QUIC connection from {} completed", remote_addr),
|
||||
Err(e) => debug!("QUIC connection from {} error: {}", remote_addr, e),
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
conn_tracker.connection_closed(&ip);
|
||||
metrics.connection_closed(route_id.as_deref(), Some(&ip_str));
|
||||
});
|
||||
}
|
||||
|
||||
// Graceful shutdown: close endpoint and wait for in-flight connections
|
||||
endpoint.close(quinn::VarInt::from_u32(0), b"server shutting down");
|
||||
endpoint.wait_idle().await;
|
||||
info!("QUIC endpoint on port {} shut down", port);
|
||||
}
|
||||
|
||||
/// Handle a single accepted QUIC connection.
|
||||
async fn handle_quic_connection(
|
||||
incoming: quinn::Incoming,
|
||||
route: RouteConfig,
|
||||
port: u16,
|
||||
metrics: &MetricsCollector,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let connection = incoming.await?;
|
||||
let remote_addr = connection.remote_address();
|
||||
debug!("QUIC connection established from {}", remote_addr);
|
||||
|
||||
// Check if this route has HTTP/3 enabled
|
||||
let enable_http3 = route.action.udp.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.and_then(|q| q.enable_http3)
|
||||
.unwrap_or(false);
|
||||
|
||||
if enable_http3 {
|
||||
// Phase 5: dispatch to H3ProxyService
|
||||
// For now, log and accept streams for basic handling
|
||||
debug!("HTTP/3 enabled for route {:?}, dispatching to H3 handler", route.name);
|
||||
handle_h3_connection(connection, route, port, metrics, cancel).await
|
||||
} else {
|
||||
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward QUIC streams bidirectionally to a TCP backend.
|
||||
///
|
||||
/// For each accepted bidirectional QUIC stream, connects to the backend
|
||||
/// via TCP and forwards data in both directions. Quinn's RecvStream/SendStream
|
||||
/// implement AsyncRead/AsyncWrite, enabling reuse of existing forwarder patterns.
|
||||
async fn handle_quic_stream_forwarding(
|
||||
connection: quinn::Connection,
|
||||
route: RouteConfig,
|
||||
port: u16,
|
||||
_metrics: &MetricsCollector,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let remote_addr = connection.remote_address();
|
||||
let route_id = route.name.as_deref().or(route.id.as_deref());
|
||||
|
||||
// Resolve backend target
|
||||
let target = route.action.targets.as_ref()
|
||||
.and_then(|t| t.first())
|
||||
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
|
||||
let backend_host = target.host.first();
|
||||
let backend_port = target.port.resolve(port);
|
||||
let backend_addr = format!("{}:{}", backend_host, backend_port);
|
||||
|
||||
loop {
|
||||
let (send_stream, recv_stream) = tokio::select! {
|
||||
_ = cancel.cancelled() => break,
|
||||
result = connection.accept_bi() => {
|
||||
match result {
|
||||
Ok(streams) => streams,
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
|
||||
Err(quinn::ConnectionError::LocallyClosed) => break,
|
||||
Err(e) => {
|
||||
debug!("QUIC stream accept error from {}: {}", remote_addr, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let backend_addr = backend_addr.clone();
|
||||
let ip_str = remote_addr.ip().to_string();
|
||||
let _fwd_ctx = ForwardMetricsCtx {
|
||||
collector: Arc::new(MetricsCollector::new()), // TODO: share real metrics
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
source_ip: Some(ip_str),
|
||||
};
|
||||
|
||||
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
||||
tokio::spawn(async move {
|
||||
match forward_quic_stream_to_tcp(
|
||||
send_stream,
|
||||
recv_stream,
|
||||
&backend_addr,
|
||||
).await {
|
||||
Ok((bytes_in, bytes_out)) => {
|
||||
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("QUIC stream forwarding error: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Forward a single QUIC bidirectional stream to a TCP backend connection.
|
||||
async fn forward_quic_stream_to_tcp(
|
||||
mut quic_send: quinn::SendStream,
|
||||
mut quic_recv: quinn::RecvStream,
|
||||
backend_addr: &str,
|
||||
) -> anyhow::Result<(u64, u64)> {
|
||||
// Connect to backend TCP
|
||||
let tcp_stream = tokio::net::TcpStream::connect(backend_addr).await?;
|
||||
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
|
||||
|
||||
// Bidirectional copy
|
||||
let client_to_backend = tokio::io::copy(&mut quic_recv, &mut tcp_write);
|
||||
let backend_to_client = tokio::io::copy(&mut tcp_read, &mut quic_send);
|
||||
|
||||
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
|
||||
|
||||
let bytes_in = c2b.unwrap_or(0);
|
||||
let bytes_out = b2c.unwrap_or(0);
|
||||
|
||||
// Graceful shutdown
|
||||
let _ = quic_send.finish();
|
||||
let _ = tcp_write.shutdown().await;
|
||||
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
|
||||
/// Placeholder for HTTP/3 connection handling (Phase 5).
|
||||
///
|
||||
/// Once h3_service is implemented, this will delegate to it.
|
||||
async fn handle_h3_connection(
|
||||
connection: quinn::Connection,
|
||||
_route: RouteConfig,
|
||||
_port: u16,
|
||||
_metrics: &MetricsCollector,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
warn!("HTTP/3 handling not yet fully implemented — accepting connection but no request processing");
|
||||
|
||||
// Keep the connection alive until cancelled or closed
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {}
|
||||
reason = connection.closed() => {
|
||||
debug!("HTTP/3 connection closed: {}", reason);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_quic_endpoint_requires_tls_config() {
|
||||
// Install the ring crypto provider for tests
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Generate a single self-signed cert and use its key pair
|
||||
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
|
||||
.unwrap();
|
||||
let cert_der = self_signed.cert.der().clone();
|
||||
let key_der = self_signed.key_pair.serialize_der();
|
||||
|
||||
let mut tls_config = RustlsServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(
|
||||
vec![cert_der.into()],
|
||||
rustls::pki_types::PrivateKeyDer::try_from(key_der).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
tls_config.alpn_protocols = vec![b"h3".to_vec()];
|
||||
|
||||
// Port 0 = OS assigns a free port
|
||||
let result = create_quic_endpoint(0, Arc::new(tls_config));
|
||||
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
|
||||
}
|
||||
}
|
||||
@@ -625,6 +625,7 @@ impl TcpListenerManager {
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
protocol: None,
|
||||
transport: None,
|
||||
};
|
||||
|
||||
if let Some(quick_match) = route_manager.find_route(&quick_ctx) {
|
||||
@@ -814,6 +815,7 @@ impl TcpListenerManager {
|
||||
is_tls,
|
||||
// For TLS connections, protocol is unknown until after termination
|
||||
protocol: if is_http { Some("http") } else if !is_tls { Some("tcp") } else { None },
|
||||
transport: None,
|
||||
};
|
||||
|
||||
let route_match = route_manager.find_route(&ctx);
|
||||
|
||||
477
rust/crates/rustproxy-passthrough/src/udp_listener.rs
Normal file
477
rust/crates/rustproxy-passthrough/src/udp_listener.rs
Normal file
@@ -0,0 +1,477 @@
|
||||
//! UDP listener manager.
|
||||
//!
|
||||
//! Binds UDP sockets on configured ports, receives datagrams, matches routes,
|
||||
//! tracks sessions (flows), and forwards datagrams to backend UDP sockets.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use rustproxy_config::{RouteActionType, TransportProtocol};
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::{MatchContext, RouteManager};
|
||||
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
use crate::udp_session::{SessionKey, UdpSession, UdpSessionConfig, UdpSessionTable};
|
||||
|
||||
/// Manages UDP listeners across all configured ports.
|
||||
pub struct UdpListenerManager {
|
||||
/// Port → recv loop task handle
|
||||
listeners: HashMap<u16, JoinHandle<()>>,
|
||||
/// Hot-reloadable route table
|
||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
||||
/// Shared metrics collector
|
||||
metrics: Arc<MetricsCollector>,
|
||||
/// Per-IP session/rate limiting (shared with TCP)
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
/// Shared session table across all ports
|
||||
session_table: Arc<UdpSessionTable>,
|
||||
/// Cancellation for graceful shutdown
|
||||
cancel_token: CancellationToken,
|
||||
/// Unix socket path for datagram handler relay
|
||||
datagram_handler_relay: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl UdpListenerManager {
|
||||
pub fn new(
|
||||
route_manager: Arc<RouteManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
cancel_token: CancellationToken,
|
||||
) -> Self {
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager: Arc::new(ArcSwap::from(route_manager)),
|
||||
metrics,
|
||||
conn_tracker,
|
||||
session_table: Arc::new(UdpSessionTable::new()),
|
||||
cancel_token,
|
||||
datagram_handler_relay: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the route manager (for hot-reload).
|
||||
pub fn update_routes(&self, route_manager: Arc<RouteManager>) {
|
||||
self.route_manager.store(route_manager);
|
||||
}
|
||||
|
||||
/// Start listening on a UDP port.
|
||||
///
|
||||
/// If any route on this port has QUIC config (`action.udp.quic`), a quinn
|
||||
/// endpoint is created instead of a raw UDP socket.
|
||||
pub async fn add_port(&mut self, port: u16) -> anyhow::Result<()> {
|
||||
self.add_port_with_tls(port, None).await
|
||||
}
|
||||
|
||||
/// Start listening on a UDP port with optional TLS config for QUIC.
|
||||
pub async fn add_port_with_tls(
|
||||
&mut self,
|
||||
port: u16,
|
||||
tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
|
||||
) -> anyhow::Result<()> {
|
||||
if self.listeners.contains_key(&port) {
|
||||
debug!("UDP port {} already listening", port);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Check if any route on this port uses QUIC
|
||||
let rm = self.route_manager.load();
|
||||
let has_quic = rm.routes_for_port(port).iter().any(|r| {
|
||||
r.action.udp.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.is_some()
|
||||
});
|
||||
|
||||
if has_quic {
|
||||
if let Some(tls) = tls_config {
|
||||
// Create QUIC endpoint
|
||||
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls)?;
|
||||
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||
endpoint,
|
||||
port,
|
||||
Arc::clone(&self.route_manager),
|
||||
Arc::clone(&self.metrics),
|
||||
Arc::clone(&self.conn_tracker),
|
||||
self.cancel_token.child_token(),
|
||||
));
|
||||
self.listeners.insert(port, handle);
|
||||
info!("QUIC endpoint started on port {}", port);
|
||||
return Ok(());
|
||||
} else {
|
||||
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
|
||||
}
|
||||
}
|
||||
|
||||
// Raw UDP listener
|
||||
let addr: SocketAddr = ([0, 0, 0, 0], port).into();
|
||||
let socket = UdpSocket::bind(addr).await?;
|
||||
let socket = Arc::new(socket);
|
||||
info!("UDP listener bound on port {}", port);
|
||||
|
||||
let handle = tokio::spawn(Self::recv_loop(
|
||||
socket,
|
||||
port,
|
||||
Arc::clone(&self.route_manager),
|
||||
Arc::clone(&self.metrics),
|
||||
Arc::clone(&self.conn_tracker),
|
||||
Arc::clone(&self.session_table),
|
||||
Arc::clone(&self.datagram_handler_relay),
|
||||
self.cancel_token.child_token(),
|
||||
));
|
||||
|
||||
self.listeners.insert(port, handle);
|
||||
|
||||
// Start the session cleanup task if this is the first port
|
||||
if self.listeners.len() == 1 {
|
||||
self.start_cleanup_task();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop listening on a UDP port.
|
||||
pub fn remove_port(&mut self, port: u16) {
|
||||
if let Some(handle) = self.listeners.remove(&port) {
|
||||
handle.abort();
|
||||
info!("UDP listener removed from port {}", port);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all listening UDP ports.
|
||||
pub fn listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.listeners.keys().copied().collect();
|
||||
ports.sort();
|
||||
ports
|
||||
}
|
||||
|
||||
/// Stop all listeners and clean up.
|
||||
pub async fn stop(&mut self) {
|
||||
self.cancel_token.cancel();
|
||||
for (port, handle) in self.listeners.drain() {
|
||||
handle.abort();
|
||||
debug!("UDP listener stopped on port {}", port);
|
||||
}
|
||||
info!("All UDP listeners stopped, {} sessions remaining",
|
||||
self.session_table.session_count());
|
||||
}
|
||||
|
||||
/// Set the datagram handler relay socket path.
|
||||
pub async fn set_datagram_handler_relay(&self, path: String) {
|
||||
let mut relay = self.datagram_handler_relay.write().await;
|
||||
*relay = Some(path);
|
||||
}
|
||||
|
||||
/// Start periodic session cleanup task.
|
||||
fn start_cleanup_task(&self) {
|
||||
let session_table = Arc::clone(&self.session_table);
|
||||
let metrics = Arc::clone(&self.metrics);
|
||||
let cancel = self.cancel_token.child_token();
|
||||
let route_manager = Arc::clone(&self.route_manager);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(10));
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => break,
|
||||
_ = interval.tick() => {
|
||||
// Determine the timeout from routes (use the minimum configured timeout,
|
||||
// or default 60s if none configured)
|
||||
let rm = route_manager.load();
|
||||
let timeout_ms = Self::get_min_session_timeout(&rm);
|
||||
let removed = session_table.cleanup_idle(timeout_ms, &metrics);
|
||||
if removed > 0 {
|
||||
debug!("UDP session cleanup: removed {} idle sessions, {} remaining",
|
||||
removed, session_table.session_count());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Get the minimum session timeout across all UDP routes.
|
||||
fn get_min_session_timeout(_rm: &RouteManager) -> u64 {
|
||||
// Default to 60 seconds; actual per-route timeouts checked during cleanup
|
||||
60_000
|
||||
}
|
||||
|
||||
/// Main receive loop for a UDP port.
|
||||
async fn recv_loop(
|
||||
socket: Arc<UdpSocket>,
|
||||
port: u16,
|
||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
session_table: Arc<UdpSessionTable>,
|
||||
datagram_handler_relay: Arc<RwLock<Option<String>>>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
loop {
|
||||
let (len, client_addr) = tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("UDP recv loop on port {} cancelled", port);
|
||||
break;
|
||||
}
|
||||
result = socket.recv_from(&mut buf) => {
|
||||
match result {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("UDP recv error on port {}: {}", port, e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let datagram = &buf[..len];
|
||||
|
||||
// Route matching
|
||||
let rm = route_manager.load();
|
||||
let ip_str = client_addr.ip().to_string();
|
||||
let ctx = MatchContext {
|
||||
port,
|
||||
domain: None,
|
||||
path: None,
|
||||
client_ip: Some(&ip_str),
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
protocol: Some("udp"),
|
||||
transport: Some(TransportProtocol::Udp),
|
||||
};
|
||||
|
||||
let route_match = match rm.find_route(&ctx) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
debug!("No UDP route matched for port {} from {}", port, client_addr);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let route = route_match.route;
|
||||
let route_id = route.name.as_deref().or(route.id.as_deref());
|
||||
|
||||
// Socket handler routes → relay datagram to TS via Unix socket
|
||||
if route.action.action_type == RouteActionType::SocketHandler {
|
||||
let relay_path = datagram_handler_relay.read().await;
|
||||
if let Some(ref path) = *relay_path {
|
||||
if let Err(e) = Self::relay_datagram_to_ts(
|
||||
path,
|
||||
route_id.unwrap_or("unknown"),
|
||||
&client_addr,
|
||||
port,
|
||||
datagram,
|
||||
).await {
|
||||
debug!("Failed to relay UDP datagram to TS: {}", e);
|
||||
}
|
||||
} else {
|
||||
debug!("UDP datagram handler relay not configured for route {:?}", route_id);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get UDP config from route
|
||||
let udp_config = UdpSessionConfig::from_route_udp(route.action.udp.as_ref());
|
||||
|
||||
// Check datagram size
|
||||
if len as u32 > udp_config.max_datagram_size {
|
||||
debug!("UDP datagram too large ({} > {}) from {}, dropping",
|
||||
len, udp_config.max_datagram_size, client_addr);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Session lookup or create
|
||||
let session_key: SessionKey = (client_addr, port);
|
||||
let session = match session_table.get(&session_key) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
// New session — check per-IP limits
|
||||
if !conn_tracker.try_accept(&client_addr.ip()) {
|
||||
debug!("UDP session rejected for {} (rate limit)", client_addr);
|
||||
continue;
|
||||
}
|
||||
if !session_table.can_create_session(
|
||||
&client_addr.ip(),
|
||||
udp_config.max_sessions_per_ip,
|
||||
) {
|
||||
debug!("UDP session rejected for {} (per-IP session limit)", client_addr);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Resolve target
|
||||
let target = match route_match.target.or_else(|| {
|
||||
route.action.targets.as_ref().and_then(|t| t.first())
|
||||
}) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
warn!("No target for UDP route {:?}", route_id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let backend_host = target.host.first();
|
||||
let backend_port = target.port.resolve(port);
|
||||
let backend_addr = format!("{}:{}", backend_host, backend_port);
|
||||
|
||||
// Create backend socket
|
||||
let backend_socket = match UdpSocket::bind("0.0.0.0:0").await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to bind backend UDP socket: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if let Err(e) = backend_socket.connect(&backend_addr).await {
|
||||
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e);
|
||||
continue;
|
||||
}
|
||||
let backend_socket = Arc::new(backend_socket);
|
||||
|
||||
debug!("New UDP session: {} -> {} (via port {})",
|
||||
client_addr, backend_addr, port);
|
||||
|
||||
// Spawn return-path relay task
|
||||
let session_cancel = CancellationToken::new();
|
||||
let return_task = tokio::spawn(Self::return_relay(
|
||||
Arc::clone(&backend_socket),
|
||||
Arc::clone(&socket),
|
||||
client_addr,
|
||||
Arc::clone(&session_table),
|
||||
session_key,
|
||||
Arc::clone(&metrics),
|
||||
route_id.map(|s| s.to_string()),
|
||||
session_cancel.child_token(),
|
||||
));
|
||||
|
||||
let session = Arc::new(UdpSession {
|
||||
backend_socket,
|
||||
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
|
||||
created_at: std::time::Instant::now(),
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
source_ip: client_addr.ip(),
|
||||
client_addr,
|
||||
return_task,
|
||||
cancel: session_cancel,
|
||||
});
|
||||
|
||||
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) {
|
||||
warn!("Failed to insert UDP session (race condition)");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Track in metrics
|
||||
conn_tracker.connection_opened(&client_addr.ip());
|
||||
metrics.connection_opened(route_id, Some(&ip_str));
|
||||
metrics.udp_session_opened();
|
||||
|
||||
session
|
||||
}
|
||||
};
|
||||
|
||||
// Forward datagram to backend
|
||||
match session.backend_socket.send(datagram).await {
|
||||
Ok(_) => {
|
||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
|
||||
metrics.record_datagram_in();
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to send UDP datagram to backend: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return-path relay: backend → client.
|
||||
async fn return_relay(
|
||||
backend_socket: Arc<UdpSocket>,
|
||||
listener_socket: Arc<UdpSocket>,
|
||||
client_addr: SocketAddr,
|
||||
session_table: Arc<UdpSessionTable>,
|
||||
session_key: SessionKey,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
route_id: Option<String>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let ip_str = client_addr.ip().to_string();
|
||||
|
||||
loop {
|
||||
let len = tokio::select! {
|
||||
_ = cancel.cancelled() => break,
|
||||
result = backend_socket.recv(&mut buf) => {
|
||||
match result {
|
||||
Ok(len) => len,
|
||||
Err(e) => {
|
||||
debug!("UDP backend recv error for {}: {}", client_addr, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Send reply back to client
|
||||
match listener_socket.send_to(&buf[..len], client_addr).await {
|
||||
Ok(_) => {
|
||||
// Update session activity
|
||||
if let Some(session) = session_table.get(&session_key) {
|
||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
}
|
||||
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
|
||||
metrics.record_datagram_out();
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to send UDP reply to {}: {}", client_addr, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Relay a UDP datagram to the TypeScript handler via Unix socket.
|
||||
/// Uses length-prefixed JSON framing: [4-byte BE length][JSON payload]
|
||||
async fn relay_datagram_to_ts(
|
||||
relay_path: &str,
|
||||
route_key: &str,
|
||||
client_addr: &SocketAddr,
|
||||
dest_port: u16,
|
||||
datagram: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
use base64::Engine;
|
||||
|
||||
let payload_b64 = base64::engine::general_purpose::STANDARD.encode(datagram);
|
||||
let msg = serde_json::json!({
|
||||
"type": "datagram",
|
||||
"routeKey": route_key,
|
||||
"sourceIp": client_addr.ip().to_string(),
|
||||
"sourcePort": client_addr.port(),
|
||||
"destPort": dest_port,
|
||||
"payloadBase64": payload_b64,
|
||||
});
|
||||
let json = serde_json::to_vec(&msg)?;
|
||||
|
||||
// Connect to relay (one-shot for now; persistent connection optimization deferred)
|
||||
let mut stream = tokio::net::UnixStream::connect(relay_path).await?;
|
||||
|
||||
// Length-prefixed frame
|
||||
let len_bytes = (json.len() as u32).to_be_bytes();
|
||||
stream.write_all(&len_bytes).await?;
|
||||
stream.write_all(&json).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
320
rust/crates/rustproxy-passthrough/src/udp_session.rs
Normal file
320
rust/crates/rustproxy-passthrough/src/udp_session.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! UDP session (flow) tracking.
|
||||
//!
|
||||
//! A UDP "session" is a flow identified by (client_addr, listening_port).
|
||||
//! Each session maintains a backend socket bound to an ephemeral port and
|
||||
//! connected to the backend target, plus a background task that relays
|
||||
//! return datagrams from the backend back to the client.
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::debug;
|
||||
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
|
||||
/// A single UDP session (flow).
|
||||
pub struct UdpSession {
|
||||
/// Socket bound to ephemeral port, connected to backend
|
||||
pub backend_socket: Arc<UdpSocket>,
|
||||
/// Milliseconds since the session table's epoch
|
||||
pub last_activity: AtomicU64,
|
||||
/// When the session was created
|
||||
pub created_at: Instant,
|
||||
/// Route ID for metrics
|
||||
pub route_id: Option<String>,
|
||||
/// Source IP for metrics/tracking
|
||||
pub source_ip: IpAddr,
|
||||
/// Client address (for return path)
|
||||
pub client_addr: SocketAddr,
|
||||
/// Handle for the return-path relay task
|
||||
pub return_task: JoinHandle<()>,
|
||||
/// Per-session cancellation
|
||||
pub cancel: CancellationToken,
|
||||
}
|
||||
|
||||
impl Drop for UdpSession {
|
||||
fn drop(&mut self) {
|
||||
self.cancel.cancel();
|
||||
self.return_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for UDP session behavior.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UdpSessionConfig {
|
||||
/// Idle timeout in milliseconds. Default: 60000.
|
||||
pub session_timeout_ms: u64,
|
||||
/// Max concurrent sessions per source IP. Default: 1000.
|
||||
pub max_sessions_per_ip: u32,
|
||||
/// Max accepted datagram size in bytes. Default: 65535.
|
||||
pub max_datagram_size: u32,
|
||||
}
|
||||
|
||||
impl Default for UdpSessionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
session_timeout_ms: 60_000,
|
||||
max_sessions_per_ip: 1_000,
|
||||
max_datagram_size: 65_535,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpSessionConfig {
|
||||
/// Build from route's UDP config, falling back to defaults.
|
||||
pub fn from_route_udp(udp: Option<&rustproxy_config::RouteUdp>) -> Self {
|
||||
match udp {
|
||||
Some(u) => Self {
|
||||
session_timeout_ms: u.session_timeout.unwrap_or(60_000),
|
||||
max_sessions_per_ip: u.max_sessions_per_ip.unwrap_or(1_000),
|
||||
max_datagram_size: u.max_datagram_size.unwrap_or(65_535),
|
||||
},
|
||||
None => Self::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Session key: (client address, listening port).
|
||||
pub type SessionKey = (SocketAddr, u16);
|
||||
|
||||
/// Tracks all active UDP sessions across all ports.
|
||||
pub struct UdpSessionTable {
|
||||
/// Active sessions keyed by (client_addr, listen_port)
|
||||
sessions: DashMap<SessionKey, Arc<UdpSession>>,
|
||||
/// Per-IP session counts for enforcing limits
|
||||
ip_session_counts: DashMap<IpAddr, u32>,
|
||||
/// Time reference for last_activity
|
||||
epoch: Instant,
|
||||
}
|
||||
|
||||
impl UdpSessionTable {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: DashMap::new(),
|
||||
ip_session_counts: DashMap::new(),
|
||||
epoch: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get elapsed milliseconds since epoch (for last_activity tracking).
|
||||
pub fn elapsed_ms(&self) -> u64 {
|
||||
self.epoch.elapsed().as_millis() as u64
|
||||
}
|
||||
|
||||
/// Look up an existing session.
|
||||
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
||||
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
|
||||
}
|
||||
|
||||
/// Check if we can create a new session for this IP (under the per-IP limit).
|
||||
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
|
||||
let count = self.ip_session_counts
|
||||
.get(ip)
|
||||
.map(|c| *c.value())
|
||||
.unwrap_or(0);
|
||||
count < max_per_ip
|
||||
}
|
||||
|
||||
/// Insert a new session. Returns false if per-IP limit exceeded.
|
||||
pub fn insert(
|
||||
&self,
|
||||
key: SessionKey,
|
||||
session: Arc<UdpSession>,
|
||||
max_per_ip: u32,
|
||||
) -> bool {
|
||||
let ip = session.source_ip;
|
||||
|
||||
// Atomically check and increment per-IP count
|
||||
let mut count_entry = self.ip_session_counts.entry(ip).or_insert(0);
|
||||
if *count_entry.value() >= max_per_ip {
|
||||
return false;
|
||||
}
|
||||
*count_entry.value_mut() += 1;
|
||||
drop(count_entry);
|
||||
|
||||
self.sessions.insert(key, session);
|
||||
true
|
||||
}
|
||||
|
||||
/// Remove a session and decrement per-IP count.
|
||||
pub fn remove(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
||||
if let Some((_, session)) = self.sessions.remove(key) {
|
||||
let ip = session.source_ip;
|
||||
if let Some(mut count) = self.ip_session_counts.get_mut(&ip) {
|
||||
*count.value_mut() = count.value().saturating_sub(1);
|
||||
if *count.value() == 0 {
|
||||
drop(count);
|
||||
self.ip_session_counts.remove(&ip);
|
||||
}
|
||||
}
|
||||
Some(session)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean up idle sessions past the given timeout.
|
||||
/// Returns the number of sessions removed.
|
||||
pub fn cleanup_idle(
|
||||
&self,
|
||||
timeout_ms: u64,
|
||||
metrics: &MetricsCollector,
|
||||
) -> usize {
|
||||
let now_ms = self.elapsed_ms();
|
||||
let mut removed = 0;
|
||||
|
||||
// Collect keys to remove (avoid holding DashMap refs during removal)
|
||||
let stale_keys: Vec<SessionKey> = self.sessions.iter()
|
||||
.filter(|entry| {
|
||||
let last = entry.value().last_activity.load(Ordering::Relaxed);
|
||||
now_ms.saturating_sub(last) >= timeout_ms
|
||||
})
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
|
||||
for key in stale_keys {
|
||||
if let Some(session) = self.remove(&key) {
|
||||
debug!(
|
||||
"UDP session expired: {} -> port {} (idle {}ms)",
|
||||
session.client_addr, key.1,
|
||||
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
|
||||
);
|
||||
metrics.connection_closed(
|
||||
session.route_id.as_deref(),
|
||||
Some(&session.source_ip.to_string()),
|
||||
);
|
||||
metrics.udp_session_closed();
|
||||
removed += 1;
|
||||
}
|
||||
}
|
||||
|
||||
removed
|
||||
}
|
||||
|
||||
/// Total number of active sessions.
|
||||
pub fn session_count(&self) -> usize {
|
||||
self.sessions.len()
|
||||
}
|
||||
|
||||
/// Number of tracked IPs with active sessions.
|
||||
pub fn tracked_ips(&self) -> usize {
|
||||
self.ip_session_counts.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{Ipv4Addr, SocketAddrV4};
|
||||
|
||||
fn make_addr(port: u16) -> SocketAddr {
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), port))
|
||||
}
|
||||
|
||||
fn make_session(client_addr: SocketAddr, cancel: CancellationToken) -> Arc<UdpSession> {
|
||||
// Create a dummy backend socket for testing
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
let backend_socket = rt.block_on(async {
|
||||
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
|
||||
});
|
||||
|
||||
let child_cancel = cancel.child_token();
|
||||
let return_task = rt.spawn(async move {
|
||||
child_cancel.cancelled().await;
|
||||
});
|
||||
|
||||
Arc::new(UdpSession {
|
||||
backend_socket,
|
||||
last_activity: AtomicU64::new(0),
|
||||
created_at: Instant::now(),
|
||||
route_id: None,
|
||||
source_ip: client_addr.ip(),
|
||||
client_addr,
|
||||
return_task,
|
||||
cancel,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_table_insert_and_get() {
|
||||
let table = UdpSessionTable::new();
|
||||
let cancel = CancellationToken::new();
|
||||
let addr = make_addr(12345);
|
||||
let key: SessionKey = (addr, 53);
|
||||
let session = make_session(addr, cancel);
|
||||
|
||||
assert!(table.insert(key, session, 1000));
|
||||
assert!(table.get(&key).is_some());
|
||||
assert_eq!(table.session_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_table_per_ip_limit() {
|
||||
let table = UdpSessionTable::new();
|
||||
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||
|
||||
// Insert 2 sessions from same IP, limit is 2
|
||||
for port in [12345u16, 12346] {
|
||||
let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
|
||||
let cancel = CancellationToken::new();
|
||||
let session = make_session(addr, cancel);
|
||||
assert!(table.insert((addr, 53), session, 2));
|
||||
}
|
||||
|
||||
// Third should be rejected
|
||||
let addr3 = SocketAddr::V4(SocketAddrV4::new(ip, 12347));
|
||||
let cancel3 = CancellationToken::new();
|
||||
let session3 = make_session(addr3, cancel3);
|
||||
assert!(!table.insert((addr3, 53), session3, 2));
|
||||
|
||||
assert_eq!(table.session_count(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_table_remove() {
|
||||
let table = UdpSessionTable::new();
|
||||
let cancel = CancellationToken::new();
|
||||
let addr = make_addr(12345);
|
||||
let key: SessionKey = (addr, 53);
|
||||
let session = make_session(addr, cancel);
|
||||
|
||||
table.insert(key, session, 1000);
|
||||
assert_eq!(table.session_count(), 1);
|
||||
assert_eq!(table.tracked_ips(), 1);
|
||||
|
||||
table.remove(&key);
|
||||
assert_eq!(table.session_count(), 0);
|
||||
assert_eq!(table.tracked_ips(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_config_defaults() {
|
||||
let config = UdpSessionConfig::default();
|
||||
assert_eq!(config.session_timeout_ms, 60_000);
|
||||
assert_eq!(config.max_sessions_per_ip, 1_000);
|
||||
assert_eq!(config.max_datagram_size, 65_535);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_config_from_route() {
|
||||
let route_udp = rustproxy_config::RouteUdp {
|
||||
session_timeout: Some(10_000),
|
||||
max_sessions_per_ip: Some(500),
|
||||
max_datagram_size: Some(1400),
|
||||
quic: None,
|
||||
};
|
||||
let config = UdpSessionConfig::from_route_udp(Some(&route_udp));
|
||||
assert_eq!(config.session_timeout_ms, 10_000);
|
||||
assert_eq!(config.max_sessions_per_ip, 500);
|
||||
assert_eq!(config.max_datagram_size, 1400);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user