//! UDP session (flow) tracking. //! //! A UDP "session" is a flow identified by (client_addr, listening_port). //! Each session maintains a backend socket bound to an ephemeral port and //! connected to the backend target, plus a background task that relays //! return datagrams from the backend back to the client. use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Instant; use dashmap::DashMap; use tokio::net::UdpSocket; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::debug; use rustproxy_metrics::MetricsCollector; use crate::connection_tracker::ConnectionTracker; /// A single UDP session (flow). pub struct UdpSession { /// Socket bound to ephemeral port, connected to backend pub backend_socket: Arc, /// Milliseconds since the session table's epoch pub last_activity: AtomicU64, /// When the session was created pub created_at: Instant, /// Route ID for metrics pub route_id: Option, /// Source IP for metrics/tracking pub source_ip: IpAddr, /// Client address (for return path) pub client_addr: SocketAddr, /// Handle for the return-path relay task pub return_task: JoinHandle<()>, /// Per-session cancellation pub cancel: CancellationToken, } impl Drop for UdpSession { fn drop(&mut self) { self.cancel.cancel(); self.return_task.abort(); } } /// Configuration for UDP session behavior. #[derive(Debug, Clone)] pub struct UdpSessionConfig { /// Idle timeout in milliseconds. Default: 60000. pub session_timeout_ms: u64, /// Max concurrent sessions per source IP. Default: 1000. pub max_sessions_per_ip: u32, /// Max accepted datagram size in bytes. Default: 65535. pub max_datagram_size: u32, } impl Default for UdpSessionConfig { fn default() -> Self { Self { session_timeout_ms: 60_000, max_sessions_per_ip: 1_000, max_datagram_size: 65_535, } } } impl UdpSessionConfig { /// Build from route's UDP config, falling back to defaults. pub fn from_route_udp(udp: Option<&rustproxy_config::RouteUdp>) -> Self { match udp { Some(u) => Self { session_timeout_ms: u.session_timeout.unwrap_or(60_000), max_sessions_per_ip: u.max_sessions_per_ip.unwrap_or(1_000), max_datagram_size: u.max_datagram_size.unwrap_or(65_535), }, None => Self::default(), } } } /// Session key: (client address, listening port). pub type SessionKey = (SocketAddr, u16); /// Tracks all active UDP sessions across all ports. pub struct UdpSessionTable { /// Active sessions keyed by (client_addr, listen_port) sessions: DashMap>, /// Per-IP session counts for enforcing limits ip_session_counts: DashMap, /// Time reference for last_activity epoch: Instant, } impl UdpSessionTable { pub fn new() -> Self { Self { sessions: DashMap::new(), ip_session_counts: DashMap::new(), epoch: Instant::now(), } } /// Get elapsed milliseconds since epoch (for last_activity tracking). pub fn elapsed_ms(&self) -> u64 { self.epoch.elapsed().as_millis() as u64 } /// Look up an existing session. pub fn get(&self, key: &SessionKey) -> Option> { self.sessions.get(key).map(|entry| Arc::clone(entry.value())) } /// Check if we can create a new session for this IP (under the per-IP limit). pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool { let count = self.ip_session_counts .get(ip) .map(|c| *c.value()) .unwrap_or(0); count < max_per_ip } /// Insert a new session. Returns false if per-IP limit exceeded. pub fn insert( &self, key: SessionKey, session: Arc, max_per_ip: u32, ) -> bool { let ip = session.source_ip; // Atomically check and increment per-IP count let mut count_entry = self.ip_session_counts.entry(ip).or_insert(0); if *count_entry.value() >= max_per_ip { return false; } *count_entry.value_mut() += 1; drop(count_entry); self.sessions.insert(key, session); true } /// Remove a session and decrement per-IP count. pub fn remove(&self, key: &SessionKey) -> Option> { if let Some((_, session)) = self.sessions.remove(key) { let ip = session.source_ip; if let Some(mut count) = self.ip_session_counts.get_mut(&ip) { *count.value_mut() = count.value().saturating_sub(1); if *count.value() == 0 { drop(count); self.ip_session_counts.remove(&ip); } } Some(session) } else { None } } /// Clean up idle sessions past the given timeout. /// Returns the number of sessions removed. pub fn cleanup_idle( &self, timeout_ms: u64, metrics: &MetricsCollector, conn_tracker: &ConnectionTracker, ) -> usize { let now_ms = self.elapsed_ms(); let mut removed = 0; // Collect keys to remove (avoid holding DashMap refs during removal) let stale_keys: Vec = self.sessions.iter() .filter(|entry| { let last = entry.value().last_activity.load(Ordering::Relaxed); now_ms.saturating_sub(last) >= timeout_ms }) .map(|entry| *entry.key()) .collect(); for key in stale_keys { if let Some(session) = self.remove(&key) { debug!( "UDP session expired: {} -> port {} (idle {}ms)", session.client_addr, key.1, now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed)) ); conn_tracker.connection_closed(&session.source_ip); metrics.connection_closed( session.route_id.as_deref(), Some(&session.source_ip.to_string()), ); metrics.udp_session_closed(); removed += 1; } } removed } /// Drain all sessions on a given listening port, releasing socket references. /// Used when upgrading a raw UDP listener to QUIC — the raw UDP socket's /// Arc refcount must drop to zero so the port can be rebound. pub fn drain_port( &self, port: u16, metrics: &MetricsCollector, conn_tracker: &ConnectionTracker, ) -> usize { let keys: Vec = self.sessions.iter() .filter(|entry| entry.key().1 == port) .map(|entry| *entry.key()) .collect(); let mut removed = 0; for key in keys { if let Some(session) = self.remove(&key) { session.cancel.cancel(); conn_tracker.connection_closed(&session.source_ip); metrics.connection_closed( session.route_id.as_deref(), Some(&session.source_ip.to_string()), ); metrics.udp_session_closed(); removed += 1; } } removed } /// Total number of active sessions. pub fn session_count(&self) -> usize { self.sessions.len() } /// Number of tracked IPs with active sessions. pub fn tracked_ips(&self) -> usize { self.ip_session_counts.len() } } #[cfg(test)] mod tests { use super::*; use std::net::{Ipv4Addr, SocketAddrV4}; fn make_addr(port: u16) -> SocketAddr { SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), port)) } fn make_session(client_addr: SocketAddr, cancel: CancellationToken) -> Arc { // Create a dummy backend socket for testing let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); let backend_socket = rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) }); let child_cancel = cancel.child_token(); let return_task = rt.spawn(async move { child_cancel.cancelled().await; }); Arc::new(UdpSession { backend_socket, last_activity: AtomicU64::new(0), created_at: Instant::now(), route_id: None, source_ip: client_addr.ip(), client_addr, return_task, cancel, }) } #[test] fn test_session_table_insert_and_get() { let table = UdpSessionTable::new(); let cancel = CancellationToken::new(); let addr = make_addr(12345); let key: SessionKey = (addr, 53); let session = make_session(addr, cancel); assert!(table.insert(key, session, 1000)); assert!(table.get(&key).is_some()); assert_eq!(table.session_count(), 1); } #[test] fn test_session_table_per_ip_limit() { let table = UdpSessionTable::new(); let ip = Ipv4Addr::new(10, 0, 0, 1); // Insert 2 sessions from same IP, limit is 2 for port in [12345u16, 12346] { let addr = SocketAddr::V4(SocketAddrV4::new(ip, port)); let cancel = CancellationToken::new(); let session = make_session(addr, cancel); assert!(table.insert((addr, 53), session, 2)); } // Third should be rejected let addr3 = SocketAddr::V4(SocketAddrV4::new(ip, 12347)); let cancel3 = CancellationToken::new(); let session3 = make_session(addr3, cancel3); assert!(!table.insert((addr3, 53), session3, 2)); assert_eq!(table.session_count(), 2); } #[test] fn test_session_table_remove() { let table = UdpSessionTable::new(); let cancel = CancellationToken::new(); let addr = make_addr(12345); let key: SessionKey = (addr, 53); let session = make_session(addr, cancel); table.insert(key, session, 1000); assert_eq!(table.session_count(), 1); assert_eq!(table.tracked_ips(), 1); table.remove(&key); assert_eq!(table.session_count(), 0); assert_eq!(table.tracked_ips(), 0); } #[test] fn test_session_config_defaults() { let config = UdpSessionConfig::default(); assert_eq!(config.session_timeout_ms, 60_000); assert_eq!(config.max_sessions_per_ip, 1_000); assert_eq!(config.max_datagram_size, 65_535); } #[test] fn test_session_config_from_route() { let route_udp = rustproxy_config::RouteUdp { session_timeout: Some(10_000), max_sessions_per_ip: Some(500), max_datagram_size: Some(1400), quic: None, }; let config = UdpSessionConfig::from_route_udp(Some(&route_udp)); assert_eq!(config.session_timeout_ms, 10_000); assert_eq!(config.max_sessions_per_ip, 500); assert_eq!(config.max_datagram_size, 1400); } }