211 lines
7.0 KiB
Rust
211 lines
7.0 KiB
Rust
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use tokio::time::Instant;
|
|
|
|
/// Identity of one UDP "session": a single client endpoint talking to a single
/// destination port.
///
/// The type is small, `Copy`, and hashable, so it can be used by value as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct UdpSessionKey {
    /// Socket address of the client endpoint.
    pub client_addr: SocketAddr,
    /// Destination port the client is sending to.
    pub dest_port: u16,
}
|
|
|
|
/// A single UDP session tracked by the edge.
///
/// `client_addr` and `dest_port` repeat the values of the [`UdpSessionKey`] the
/// session was created under, so a session looked up by stream id still knows
/// which client and destination port it belongs to.
#[derive(Debug, Clone)]
pub struct UdpSession {
    /// Identifier of the stream associated with this session.
    pub stream_id: u32,
    /// Socket address of the client endpoint.
    pub client_addr: SocketAddr,
    /// Destination port the client is sending to.
    pub dest_port: u16,
    /// Time of the most recent lookup or insertion; compared against the
    /// manager's idle timeout when expiring sessions.
    pub last_activity: Instant,
}
|
|
|
|
/// Manages UDP sessions with idle timeout expiry.
|
|
pub struct UdpSessionManager {
|
|
/// Forward map: session key → session data.
|
|
sessions: HashMap<UdpSessionKey, UdpSession>,
|
|
/// Reverse map: stream_id → session key (for dispatching return traffic).
|
|
by_stream_id: HashMap<u32, UdpSessionKey>,
|
|
/// Idle timeout duration.
|
|
idle_timeout: std::time::Duration,
|
|
}
|
|
|
|
impl UdpSessionManager {
|
|
pub fn new(idle_timeout: std::time::Duration) -> Self {
|
|
Self {
|
|
sessions: HashMap::new(),
|
|
by_stream_id: HashMap::new(),
|
|
idle_timeout,
|
|
}
|
|
}
|
|
|
|
/// Look up an existing session by key. Updates last_activity on hit.
|
|
pub fn get_mut(&mut self, key: &UdpSessionKey) -> Option<&mut UdpSession> {
|
|
let session = self.sessions.get_mut(key)?;
|
|
session.last_activity = Instant::now();
|
|
Some(session)
|
|
}
|
|
|
|
/// Look up a session's client address by stream_id (for return traffic).
|
|
pub fn client_addr_for_stream(&self, stream_id: u32) -> Option<SocketAddr> {
|
|
let key = self.by_stream_id.get(&stream_id)?;
|
|
self.sessions.get(key).map(|s| s.client_addr)
|
|
}
|
|
|
|
/// Look up a session by stream_id. Updates last_activity on hit.
|
|
pub fn get_by_stream_id(&mut self, stream_id: u32) -> Option<&mut UdpSession> {
|
|
let key = self.by_stream_id.get(&stream_id)?;
|
|
let session = self.sessions.get_mut(key)?;
|
|
session.last_activity = Instant::now();
|
|
Some(session)
|
|
}
|
|
|
|
/// Insert a new session. Returns a mutable reference to it.
|
|
pub fn insert(&mut self, key: UdpSessionKey, stream_id: u32) -> &mut UdpSession {
|
|
let session = UdpSession {
|
|
stream_id,
|
|
client_addr: key.client_addr,
|
|
dest_port: key.dest_port,
|
|
last_activity: Instant::now(),
|
|
};
|
|
self.by_stream_id.insert(stream_id, key);
|
|
self.sessions.entry(key).or_insert(session)
|
|
}
|
|
|
|
/// Remove a session by stream_id.
|
|
pub fn remove_by_stream_id(&mut self, stream_id: u32) -> Option<UdpSession> {
|
|
if let Some(key) = self.by_stream_id.remove(&stream_id) {
|
|
self.sessions.remove(&key)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Expire idle sessions. Returns the stream_ids of expired sessions.
|
|
pub fn expire_idle(&mut self) -> Vec<u32> {
|
|
let now = Instant::now();
|
|
let timeout = self.idle_timeout;
|
|
let expired_keys: Vec<UdpSessionKey> = self
|
|
.sessions
|
|
.iter()
|
|
.filter(|(_, s)| now.duration_since(s.last_activity) >= timeout)
|
|
.map(|(k, _)| *k)
|
|
.collect();
|
|
|
|
let mut expired_ids = Vec::with_capacity(expired_keys.len());
|
|
for key in expired_keys {
|
|
if let Some(session) = self.sessions.remove(&key) {
|
|
self.by_stream_id.remove(&session.stream_id);
|
|
expired_ids.push(session.stream_id);
|
|
}
|
|
}
|
|
expired_ids
|
|
}
|
|
|
|
/// Number of active sessions.
|
|
pub fn len(&self) -> usize {
|
|
self.sessions.len()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Loopback socket address on the given port.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Session key for a loopback client on `client_port` talking to `dest_port`.
    fn session_key(client_port: u16, dest_port: u16) -> UdpSessionKey {
        UdpSessionKey {
            client_addr: addr(client_port),
            dest_port,
        }
    }

    /// An inserted session is counted and can be found again by its key.
    #[test]
    fn test_insert_and_lookup() {
        let mut manager = UdpSessionManager::new(Duration::from_secs(60));
        let dns = session_key(5000, 53);
        manager.insert(dns, 1);

        assert_eq!(manager.len(), 1);
        assert!(manager.get_mut(&dns).is_some());
        assert_eq!(manager.get_mut(&dns).unwrap().stream_id, 1);
    }

    /// The reverse lookup resolves a known stream id and rejects an unknown one.
    #[test]
    fn test_client_addr_for_stream() {
        let mut manager = UdpSessionManager::new(Duration::from_secs(60));
        manager.insert(session_key(5000, 53), 42);

        assert_eq!(manager.client_addr_for_stream(42), Some(addr(5000)));
        assert_eq!(manager.client_addr_for_stream(99), None);
    }

    /// Removing by stream id clears the session from both lookup paths.
    #[test]
    fn test_remove_by_stream_id() {
        let mut manager = UdpSessionManager::new(Duration::from_secs(60));
        let dns = session_key(5000, 53);
        manager.insert(dns, 1);

        let taken = manager.remove_by_stream_id(1);
        assert!(taken.is_some());
        assert_eq!(manager.len(), 0);
        assert!(manager.get_mut(&dns).is_none());
        assert_eq!(manager.client_addr_for_stream(1), None);
    }

    /// Removing an unknown stream id yields nothing.
    #[test]
    fn test_remove_nonexistent() {
        let mut manager = UdpSessionManager::new(Duration::from_secs(60));
        assert!(manager.remove_by_stream_id(999).is_none());
    }

    /// Sessions survive until the idle timeout has elapsed, then all expire.
    #[tokio::test]
    async fn test_expire_idle() {
        let mut manager = UdpSessionManager::new(Duration::from_millis(50));
        manager.insert(session_key(5000, 53), 1);
        manager.insert(session_key(5001, 53), 2);

        // Immediately after insertion nothing is idle yet.
        assert!(manager.expire_idle().is_empty());
        assert_eq!(manager.len(), 2);

        // Let the idle timeout pass.
        tokio::time::sleep(Duration::from_millis(60)).await;

        let gone = manager.expire_idle();
        assert_eq!(gone.len(), 2);
        assert_eq!(manager.len(), 0);
    }

    /// A lookup refreshes the session and postpones its expiry.
    #[tokio::test]
    async fn test_activity_prevents_expiry() {
        let mut manager = UdpSessionManager::new(Duration::from_millis(100));
        let dns = session_key(5000, 53);
        manager.insert(dns, 1);

        // Touch the session at 50ms, before the 100ms timeout.
        tokio::time::sleep(Duration::from_millis(50)).await;
        manager.get_mut(&dns); // refreshes last_activity

        // 80ms after the last touch the session should still be alive.
        tokio::time::sleep(Duration::from_millis(80)).await;
        assert!(manager.expire_idle().is_empty());
        assert_eq!(manager.len(), 1);

        // Wait out the rest of the timeout, measured from the last activity.
        tokio::time::sleep(Duration::from_millis(30)).await;
        let gone = manager.expire_idle();
        assert_eq!(gone.len(), 1);
    }

    /// One client talking to two destination ports gets two separate sessions.
    #[test]
    fn test_multiple_sessions_same_client_different_ports() {
        let mut manager = UdpSessionManager::new(Duration::from_secs(60));
        let dns = session_key(5000, 53);
        let https = session_key(5000, 443);
        manager.insert(dns, 1);
        manager.insert(https, 2);

        assert_eq!(manager.len(), 2);
        assert_eq!(manager.get_mut(&dns).unwrap().stream_id, 1);
        assert_eq!(manager.get_mut(&https).unwrap().stream_id, 2);
    }
}
|