325 lines
10 KiB
Rust
325 lines
10 KiB
Rust
//! UDP session (flow) tracking.
//!
//! A UDP "session" is a flow identified by `(client_addr, listening_port)`.
//! Each session maintains a backend socket bound to an ephemeral port and
//! connected to the backend target, plus a background task that relays
//! return datagrams from the backend back to the client.
|
|
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
use dashmap::DashMap;
|
|
use tokio::net::UdpSocket;
|
|
use tokio::task::JoinHandle;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::debug;
|
|
|
|
use rustproxy_metrics::MetricsCollector;
|
|
|
|
use crate::connection_tracker::ConnectionTracker;
|
|
|
|
/// A single UDP session (flow).
|
|
pub struct UdpSession {
|
|
/// Socket bound to ephemeral port, connected to backend
|
|
pub backend_socket: Arc<UdpSocket>,
|
|
/// Milliseconds since the session table's epoch
|
|
pub last_activity: AtomicU64,
|
|
/// When the session was created
|
|
pub created_at: Instant,
|
|
/// Route ID for metrics
|
|
pub route_id: Option<String>,
|
|
/// Source IP for metrics/tracking
|
|
pub source_ip: IpAddr,
|
|
/// Client address (for return path)
|
|
pub client_addr: SocketAddr,
|
|
/// Handle for the return-path relay task
|
|
pub return_task: JoinHandle<()>,
|
|
/// Per-session cancellation
|
|
pub cancel: CancellationToken,
|
|
}
|
|
|
|
impl Drop for UdpSession {
|
|
fn drop(&mut self) {
|
|
self.cancel.cancel();
|
|
self.return_task.abort();
|
|
}
|
|
}
|
|
|
|
/// Tunables that govern how UDP sessions are created and expired.
#[derive(Debug, Clone)]
pub struct UdpSessionConfig {
    /// How long (ms) a session may sit idle before cleanup; default 60000.
    pub session_timeout_ms: u64,
    /// Upper bound on simultaneous sessions from one source IP; default 1000.
    pub max_sessions_per_ip: u32,
    /// Largest datagram (bytes) that will be accepted; default 65535.
    pub max_datagram_size: u32,
}
|
|
|
|
impl Default for UdpSessionConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
session_timeout_ms: 60_000,
|
|
max_sessions_per_ip: 1_000,
|
|
max_datagram_size: 65_535,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl UdpSessionConfig {
|
|
/// Build from route's UDP config, falling back to defaults.
|
|
pub fn from_route_udp(udp: Option<&rustproxy_config::RouteUdp>) -> Self {
|
|
match udp {
|
|
Some(u) => Self {
|
|
session_timeout_ms: u.session_timeout.unwrap_or(60_000),
|
|
max_sessions_per_ip: u.max_sessions_per_ip.unwrap_or(1_000),
|
|
max_datagram_size: u.max_datagram_size.unwrap_or(65_535),
|
|
},
|
|
None => Self::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Identity of a UDP flow: the client's socket address paired with the local
/// listening port it sent to.
pub type SessionKey = (SocketAddr, u16);
|
|
|
|
/// Tracks all active UDP sessions across all ports.
|
|
pub struct UdpSessionTable {
|
|
/// Active sessions keyed by (client_addr, listen_port)
|
|
sessions: DashMap<SessionKey, Arc<UdpSession>>,
|
|
/// Per-IP session counts for enforcing limits
|
|
ip_session_counts: DashMap<IpAddr, u32>,
|
|
/// Time reference for last_activity
|
|
epoch: Instant,
|
|
}
|
|
|
|
impl UdpSessionTable {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
sessions: DashMap::new(),
|
|
ip_session_counts: DashMap::new(),
|
|
epoch: Instant::now(),
|
|
}
|
|
}
|
|
|
|
/// Get elapsed milliseconds since epoch (for last_activity tracking).
|
|
pub fn elapsed_ms(&self) -> u64 {
|
|
self.epoch.elapsed().as_millis() as u64
|
|
}
|
|
|
|
/// Look up an existing session.
|
|
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
|
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
|
|
}
|
|
|
|
/// Check if we can create a new session for this IP (under the per-IP limit).
|
|
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
|
|
let count = self.ip_session_counts
|
|
.get(ip)
|
|
.map(|c| *c.value())
|
|
.unwrap_or(0);
|
|
count < max_per_ip
|
|
}
|
|
|
|
/// Insert a new session. Returns false if per-IP limit exceeded.
|
|
pub fn insert(
|
|
&self,
|
|
key: SessionKey,
|
|
session: Arc<UdpSession>,
|
|
max_per_ip: u32,
|
|
) -> bool {
|
|
let ip = session.source_ip;
|
|
|
|
// Atomically check and increment per-IP count
|
|
let mut count_entry = self.ip_session_counts.entry(ip).or_insert(0);
|
|
if *count_entry.value() >= max_per_ip {
|
|
return false;
|
|
}
|
|
*count_entry.value_mut() += 1;
|
|
drop(count_entry);
|
|
|
|
self.sessions.insert(key, session);
|
|
true
|
|
}
|
|
|
|
/// Remove a session and decrement per-IP count.
|
|
pub fn remove(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
|
if let Some((_, session)) = self.sessions.remove(key) {
|
|
let ip = session.source_ip;
|
|
if let Some(mut count) = self.ip_session_counts.get_mut(&ip) {
|
|
*count.value_mut() = count.value().saturating_sub(1);
|
|
if *count.value() == 0 {
|
|
drop(count);
|
|
self.ip_session_counts.remove(&ip);
|
|
}
|
|
}
|
|
Some(session)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Clean up idle sessions past the given timeout.
|
|
/// Returns the number of sessions removed.
|
|
pub fn cleanup_idle(
|
|
&self,
|
|
timeout_ms: u64,
|
|
metrics: &MetricsCollector,
|
|
conn_tracker: &ConnectionTracker,
|
|
) -> usize {
|
|
let now_ms = self.elapsed_ms();
|
|
let mut removed = 0;
|
|
|
|
// Collect keys to remove (avoid holding DashMap refs during removal)
|
|
let stale_keys: Vec<SessionKey> = self.sessions.iter()
|
|
.filter(|entry| {
|
|
let last = entry.value().last_activity.load(Ordering::Relaxed);
|
|
now_ms.saturating_sub(last) >= timeout_ms
|
|
})
|
|
.map(|entry| *entry.key())
|
|
.collect();
|
|
|
|
for key in stale_keys {
|
|
if let Some(session) = self.remove(&key) {
|
|
debug!(
|
|
"UDP session expired: {} -> port {} (idle {}ms)",
|
|
session.client_addr, key.1,
|
|
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
|
|
);
|
|
conn_tracker.connection_closed(&session.source_ip);
|
|
metrics.connection_closed(
|
|
session.route_id.as_deref(),
|
|
Some(&session.source_ip.to_string()),
|
|
);
|
|
metrics.udp_session_closed();
|
|
removed += 1;
|
|
}
|
|
}
|
|
|
|
removed
|
|
}
|
|
|
|
/// Total number of active sessions.
|
|
pub fn session_count(&self) -> usize {
|
|
self.sessions.len()
|
|
}
|
|
|
|
/// Number of tracked IPs with active sessions.
|
|
pub fn tracked_ips(&self) -> usize {
|
|
self.ip_session_counts.len()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Client address `10.0.0.1:<port>`.
    fn make_addr(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), port))
    }

    /// Build a session backed by a loopback socket (bound, not connected) and a
    /// relay task that only waits for cancellation.
    ///
    /// A throwaway current-thread runtime is used so these can stay plain
    /// `#[test]`s. The runtime is dropped when this function returns.
    fn make_session(client_addr: SocketAddr, cancel: CancellationToken) -> Arc<UdpSession> {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let backend_socket = rt.block_on(async {
            Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
        });

        let child_cancel = cancel.child_token();
        let return_task = rt.spawn(async move {
            child_cancel.cancelled().await;
        });

        Arc::new(UdpSession {
            backend_socket,
            last_activity: AtomicU64::new(0),
            created_at: Instant::now(),
            route_id: None,
            source_ip: client_addr.ip(),
            client_addr,
            return_task,
            cancel,
        })
    }

    #[test]
    fn test_session_table_insert_and_get() {
        let table = UdpSessionTable::new();
        let cancel = CancellationToken::new();
        let addr = make_addr(12345);
        let key: SessionKey = (addr, 53);
        let session = make_session(addr, cancel);

        assert!(table.insert(key, session, 1000));
        assert!(table.get(&key).is_some());
        assert_eq!(table.session_count(), 1);
    }

    #[test]
    fn test_session_table_per_ip_limit() {
        let table = UdpSessionTable::new();
        let ip = Ipv4Addr::new(10, 0, 0, 1);

        // Insert 2 sessions from same IP, limit is 2
        for port in [12345u16, 12346] {
            let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
            let cancel = CancellationToken::new();
            let session = make_session(addr, cancel);
            assert!(table.insert((addr, 53), session, 2));
        }

        // Third should be rejected
        let addr3 = SocketAddr::V4(SocketAddrV4::new(ip, 12347));
        let cancel3 = CancellationToken::new();
        let session3 = make_session(addr3, cancel3);
        assert!(!table.insert((addr3, 53), session3, 2));

        assert_eq!(table.session_count(), 2);
    }

    #[test]
    fn test_session_table_remove() {
        let table = UdpSessionTable::new();
        let cancel = CancellationToken::new();
        let addr = make_addr(12345);
        let key: SessionKey = (addr, 53);
        let session = make_session(addr, cancel);

        table.insert(key, session, 1000);
        assert_eq!(table.session_count(), 1);
        assert_eq!(table.tracked_ips(), 1);

        table.remove(&key);
        assert_eq!(table.session_count(), 0);
        assert_eq!(table.tracked_ips(), 0);
    }

    #[test]
    fn test_can_create_session_tracks_limit() {
        let table = UdpSessionTable::new();
        let addr = make_addr(12345);
        let ip = addr.ip();
        let key: SessionKey = (addr, 53);

        // No sessions yet: any positive limit admits, a zero limit never does.
        assert!(table.can_create_session(&ip, 1));
        assert!(!table.can_create_session(&ip, 0));

        assert!(table.insert(key, make_session(addr, CancellationToken::new()), 1));
        assert!(!table.can_create_session(&ip, 1));
        assert!(table.can_create_session(&ip, 2));

        assert!(table.remove(&key).is_some());
        assert!(table.can_create_session(&ip, 1));
    }

    #[test]
    fn test_remove_missing_key_is_noop() {
        let table = UdpSessionTable::new();
        let key: SessionKey = (make_addr(1), 53);

        assert!(table.remove(&key).is_none());
        assert_eq!(table.session_count(), 0);
        assert_eq!(table.tracked_ips(), 0);
    }

    #[test]
    fn test_elapsed_ms_is_monotonic() {
        let table = UdpSessionTable::new();
        let first = table.elapsed_ms();
        assert!(table.elapsed_ms() >= first);
    }

    #[test]
    fn test_session_config_defaults() {
        let config = UdpSessionConfig::default();
        assert_eq!(config.session_timeout_ms, 60_000);
        assert_eq!(config.max_sessions_per_ip, 1_000);
        assert_eq!(config.max_datagram_size, 65_535);
    }

    #[test]
    fn test_session_config_from_route() {
        let route_udp = rustproxy_config::RouteUdp {
            session_timeout: Some(10_000),
            max_sessions_per_ip: Some(500),
            max_datagram_size: Some(1400),
            quic: None,
        };
        let config = UdpSessionConfig::from_route_udp(Some(&route_udp));
        assert_eq!(config.session_timeout_ms, 10_000);
        assert_eq!(config.max_sessions_per_ip, 500);
        assert_eq!(config.max_datagram_size, 1400);
    }
}
|