use std::collections::HashMap; use std::net::SocketAddr; use tokio::time::Instant; /// Key identifying a unique UDP "session" (one client endpoint talking to one destination port). #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct UdpSessionKey { pub client_addr: SocketAddr, pub dest_port: u16, } /// A single UDP session tracked by the edge. pub struct UdpSession { pub stream_id: u32, pub client_addr: SocketAddr, pub dest_port: u16, pub last_activity: Instant, } /// Manages UDP sessions with idle timeout expiry. pub struct UdpSessionManager { /// Forward map: session key → session data. sessions: HashMap, /// Reverse map: stream_id → session key (for dispatching return traffic). by_stream_id: HashMap, /// Idle timeout duration. idle_timeout: std::time::Duration, } impl UdpSessionManager { pub fn new(idle_timeout: std::time::Duration) -> Self { Self { sessions: HashMap::new(), by_stream_id: HashMap::new(), idle_timeout, } } /// Look up an existing session by key. Updates last_activity on hit. pub fn get_mut(&mut self, key: &UdpSessionKey) -> Option<&mut UdpSession> { let session = self.sessions.get_mut(key)?; session.last_activity = Instant::now(); Some(session) } /// Look up a session's client address by stream_id (for return traffic). pub fn client_addr_for_stream(&self, stream_id: u32) -> Option { let key = self.by_stream_id.get(&stream_id)?; self.sessions.get(key).map(|s| s.client_addr) } /// Look up a session by stream_id. Updates last_activity on hit. pub fn get_by_stream_id(&mut self, stream_id: u32) -> Option<&mut UdpSession> { let key = self.by_stream_id.get(&stream_id)?; let session = self.sessions.get_mut(key)?; session.last_activity = Instant::now(); Some(session) } /// Insert a new session. Returns a mutable reference to it. pub fn insert(&mut self, key: UdpSessionKey, stream_id: u32) -> &mut UdpSession { let session = UdpSession { stream_id, client_addr: key.client_addr, dest_port: key.dest_port, last_activity: Instant::now(), }; self.by_stream_id.insert(stream_id, key); self.sessions.entry(key).or_insert(session) } /// Remove a session by stream_id. pub fn remove_by_stream_id(&mut self, stream_id: u32) -> Option { if let Some(key) = self.by_stream_id.remove(&stream_id) { self.sessions.remove(&key) } else { None } } /// Expire idle sessions. Returns the stream_ids of expired sessions. pub fn expire_idle(&mut self) -> Vec { let now = Instant::now(); let timeout = self.idle_timeout; let expired_keys: Vec = self .sessions .iter() .filter(|(_, s)| now.duration_since(s.last_activity) >= timeout) .map(|(k, _)| *k) .collect(); let mut expired_ids = Vec::with_capacity(expired_keys.len()); for key in expired_keys { if let Some(session) = self.sessions.remove(&key) { self.by_stream_id.remove(&session.stream_id); expired_ids.push(session.stream_id); } } expired_ids } /// Number of active sessions. pub fn len(&self) -> usize { self.sessions.len() } } #[cfg(test)] mod tests { use super::*; use std::time::Duration; fn addr(port: u16) -> SocketAddr { SocketAddr::from(([127, 0, 0, 1], port)) } #[test] fn test_insert_and_lookup() { let mut mgr = UdpSessionManager::new(Duration::from_secs(60)); let key = UdpSessionKey { client_addr: addr(5000), dest_port: 53 }; mgr.insert(key, 1); assert_eq!(mgr.len(), 1); assert!(mgr.get_mut(&key).is_some()); assert_eq!(mgr.get_mut(&key).unwrap().stream_id, 1); } #[test] fn test_client_addr_for_stream() { let mut mgr = UdpSessionManager::new(Duration::from_secs(60)); let key = UdpSessionKey { client_addr: addr(5000), dest_port: 53 }; mgr.insert(key, 42); assert_eq!(mgr.client_addr_for_stream(42), Some(addr(5000))); assert_eq!(mgr.client_addr_for_stream(99), None); } #[test] fn test_remove_by_stream_id() { let mut mgr = UdpSessionManager::new(Duration::from_secs(60)); let key = UdpSessionKey { client_addr: addr(5000), dest_port: 53 }; mgr.insert(key, 1); let removed = mgr.remove_by_stream_id(1); assert!(removed.is_some()); assert_eq!(mgr.len(), 0); assert!(mgr.get_mut(&key).is_none()); assert_eq!(mgr.client_addr_for_stream(1), None); } #[test] fn test_remove_nonexistent() { let mut mgr = UdpSessionManager::new(Duration::from_secs(60)); assert!(mgr.remove_by_stream_id(999).is_none()); } #[tokio::test] async fn test_expire_idle() { let mut mgr = UdpSessionManager::new(Duration::from_millis(50)); let key1 = UdpSessionKey { client_addr: addr(5000), dest_port: 53 }; let key2 = UdpSessionKey { client_addr: addr(5001), dest_port: 53 }; mgr.insert(key1, 1); mgr.insert(key2, 2); // Nothing expired yet assert!(mgr.expire_idle().is_empty()); assert_eq!(mgr.len(), 2); // Wait for timeout tokio::time::sleep(Duration::from_millis(60)).await; let expired = mgr.expire_idle(); assert_eq!(expired.len(), 2); assert_eq!(mgr.len(), 0); } #[tokio::test] async fn test_activity_prevents_expiry() { let mut mgr = UdpSessionManager::new(Duration::from_millis(100)); let key = UdpSessionKey { client_addr: addr(5000), dest_port: 53 }; mgr.insert(key, 1); // Touch session at 50ms (before 100ms timeout) tokio::time::sleep(Duration::from_millis(50)).await; mgr.get_mut(&key); // refreshes last_activity // At 80ms from last touch, should still be alive tokio::time::sleep(Duration::from_millis(80)).await; assert!(mgr.expire_idle().is_empty()); assert_eq!(mgr.len(), 1); // Wait for full timeout from last activity tokio::time::sleep(Duration::from_millis(30)).await; let expired = mgr.expire_idle(); assert_eq!(expired.len(), 1); } #[test] fn test_multiple_sessions_same_client_different_ports() { let mut mgr = UdpSessionManager::new(Duration::from_secs(60)); let key1 = UdpSessionKey { client_addr: addr(5000), dest_port: 53 }; let key2 = UdpSessionKey { client_addr: addr(5000), dest_port: 443 }; mgr.insert(key1, 1); mgr.insert(key2, 2); assert_eq!(mgr.len(), 2); assert_eq!(mgr.get_mut(&key1).unwrap().stream_id, 1); assert_eq!(mgr.get_mut(&key2).unwrap().stream_id, 2); } }