use dashmap::DashMap; use std::collections::VecDeque; use std::net::IpAddr; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; use tracing::{debug, warn}; use super::connection_record::ConnectionRecord; /// Thresholds for zombie detection (non-TLS connections). const HALF_ZOMBIE_TIMEOUT_PLAIN: Duration = Duration::from_secs(30); /// Thresholds for zombie detection (TLS connections). const HALF_ZOMBIE_TIMEOUT_TLS: Duration = Duration::from_secs(300); /// Stuck connection timeout (non-TLS): received data but never sent any. const STUCK_TIMEOUT_PLAIN: Duration = Duration::from_secs(60); /// Stuck connection timeout (TLS): received data but never sent any. const STUCK_TIMEOUT_TLS: Duration = Duration::from_secs(300); /// Tracks active connections per IP and enforces per-IP limits and rate limiting. /// Also maintains per-connection records for zombie detection. pub struct ConnectionTracker { /// Active connection counts per IP active: DashMap, /// Connection timestamps per IP for rate limiting timestamps: DashMap>, /// Maximum concurrent connections per IP (None = unlimited) max_per_ip: Option, /// Maximum new connections per minute per IP (None = unlimited) rate_limit_per_minute: Option, /// Per-connection tracking records for zombie detection connections: DashMap>, /// Monotonically increasing connection ID counter next_id: AtomicU64, } impl ConnectionTracker { pub fn new(max_per_ip: Option, rate_limit_per_minute: Option) -> Self { Self { active: DashMap::new(), timestamps: DashMap::new(), max_per_ip, rate_limit_per_minute, connections: DashMap::new(), next_id: AtomicU64::new(1), } } /// Try to accept a new connection from the given IP. /// Returns true if allowed, false if over limit. pub fn try_accept(&self, ip: &IpAddr) -> bool { // Check per-IP connection limit if let Some(max) = self.max_per_ip { let count = self.active .get(ip) .map(|c| c.value().load(Ordering::Relaxed)) .unwrap_or(0); if count >= max { return false; } } // Check rate limit if let Some(rate_limit) = self.rate_limit_per_minute { let now = Instant::now(); let one_minute = std::time::Duration::from_secs(60); let mut entry = self.timestamps.entry(*ip).or_default(); let timestamps = entry.value_mut(); // Remove timestamps older than 1 minute while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) { timestamps.pop_front(); } if timestamps.len() as u64 >= rate_limit { return false; } timestamps.push_back(now); } true } /// Record that a connection was opened from the given IP. pub fn connection_opened(&self, ip: &IpAddr) { self.active .entry(*ip) .or_insert_with(|| AtomicU64::new(0)) .value() .fetch_add(1, Ordering::Relaxed); } /// Record that a connection was closed from the given IP. pub fn connection_closed(&self, ip: &IpAddr) { if let Some(counter) = self.active.get(ip) { let prev = counter.value().fetch_sub(1, Ordering::Relaxed); // Clean up zero entries if prev <= 1 { drop(counter); self.active.remove(ip); } } } /// Get the current number of active connections for an IP. pub fn active_connections(&self, ip: &IpAddr) -> u64 { self.active .get(ip) .map(|c| c.value().load(Ordering::Relaxed)) .unwrap_or(0) } /// Get the total number of tracked IPs. pub fn tracked_ips(&self) -> usize { self.active.len() } /// Register a new connection and return its tracking record. /// /// The returned `Arc` should be passed to the forwarding /// loop so it can update bytes / activity atomics in real time. pub fn register_connection(&self, is_tls: bool) -> Arc { let id = self.next_id.fetch_add(1, Ordering::Relaxed); let record = Arc::new(ConnectionRecord::new(id)); record.is_tls.store(is_tls, Ordering::Relaxed); self.connections.insert(id, Arc::clone(&record)); record } /// Remove a connection record when the connection is fully closed. pub fn unregister_connection(&self, id: u64) { self.connections.remove(&id); } /// Scan all tracked connections and return IDs of zombie connections. /// /// A connection is considered a zombie in any of these cases: /// - **Full zombie**: both `client_closed` and `backend_closed` are true. /// - **Half zombie**: one side closed for longer than the threshold /// (5 min for TLS, 30s for non-TLS). /// - **Stuck**: `bytes_received > 0` but `bytes_sent == 0` for longer /// than the stuck threshold (5 min for TLS, 60s for non-TLS). pub fn scan_zombies(&self) -> Vec { let mut zombies = Vec::new(); for entry in self.connections.iter() { let record = entry.value(); let id = *entry.key(); let is_tls = record.is_tls.load(Ordering::Relaxed); let client_closed = record.client_closed.load(Ordering::Relaxed); let backend_closed = record.backend_closed.load(Ordering::Relaxed); let idle = record.idle_duration(); let bytes_in = record.bytes_received.load(Ordering::Relaxed); let bytes_out = record.bytes_sent.load(Ordering::Relaxed); // Full zombie: both sides closed if client_closed && backend_closed { zombies.push(id); continue; } // Half zombie: one side closed for too long let half_timeout = if is_tls { HALF_ZOMBIE_TIMEOUT_TLS } else { HALF_ZOMBIE_TIMEOUT_PLAIN }; if (client_closed || backend_closed) && idle >= half_timeout { zombies.push(id); continue; } // Stuck: received data but never sent anything for too long let stuck_timeout = if is_tls { STUCK_TIMEOUT_TLS } else { STUCK_TIMEOUT_PLAIN }; if bytes_in > 0 && bytes_out == 0 && idle >= stuck_timeout { zombies.push(id); } } zombies } /// Start a background task that periodically scans for zombie connections. /// /// The scanner runs every 10 seconds and logs any zombies it finds. /// It stops when the provided `CancellationToken` is cancelled. pub fn start_zombie_scanner(self: &Arc, cancel: CancellationToken) { let tracker = Arc::clone(self); tokio::spawn(async move { let interval = Duration::from_secs(10); loop { tokio::select! { _ = cancel.cancelled() => { debug!("Zombie scanner shutting down"); break; } _ = tokio::time::sleep(interval) => { let zombies = tracker.scan_zombies(); if !zombies.is_empty() { warn!( "Detected {} zombie connection(s): {:?}", zombies.len(), zombies ); } } } } }); } /// Get the total number of tracked connections (with records). pub fn total_connections(&self) -> usize { self.connections.len() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_basic_tracking() { let tracker = ConnectionTracker::new(None, None); let ip: IpAddr = "127.0.0.1".parse().unwrap(); assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); assert_eq!(tracker.active_connections(&ip), 1); tracker.connection_opened(&ip); assert_eq!(tracker.active_connections(&ip), 2); tracker.connection_closed(&ip); assert_eq!(tracker.active_connections(&ip), 1); tracker.connection_closed(&ip); assert_eq!(tracker.active_connections(&ip), 0); } #[test] fn test_per_ip_limit() { let tracker = ConnectionTracker::new(Some(2), None); let ip: IpAddr = "10.0.0.1".parse().unwrap(); assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); // Third connection should be rejected assert!(!tracker.try_accept(&ip)); // Different IP should still be allowed let ip2: IpAddr = "10.0.0.2".parse().unwrap(); assert!(tracker.try_accept(&ip2)); } #[test] fn test_rate_limit() { let tracker = ConnectionTracker::new(None, Some(3)); let ip: IpAddr = "10.0.0.1".parse().unwrap(); assert!(tracker.try_accept(&ip)); assert!(tracker.try_accept(&ip)); assert!(tracker.try_accept(&ip)); // 4th attempt within the minute should be rejected assert!(!tracker.try_accept(&ip)); } #[test] fn test_no_limits() { let tracker = ConnectionTracker::new(None, None); let ip: IpAddr = "10.0.0.1".parse().unwrap(); for _ in 0..1000 { assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); } assert_eq!(tracker.active_connections(&ip), 1000); } #[test] fn test_tracked_ips() { let tracker = ConnectionTracker::new(None, None); assert_eq!(tracker.tracked_ips(), 0); let ip1: IpAddr = "10.0.0.1".parse().unwrap(); let ip2: IpAddr = "10.0.0.2".parse().unwrap(); tracker.connection_opened(&ip1); tracker.connection_opened(&ip2); assert_eq!(tracker.tracked_ips(), 2); tracker.connection_closed(&ip1); assert_eq!(tracker.tracked_ips(), 1); } #[test] fn test_register_unregister_connection() { let tracker = ConnectionTracker::new(None, None); assert_eq!(tracker.total_connections(), 0); let record1 = tracker.register_connection(false); assert_eq!(tracker.total_connections(), 1); assert!(!record1.is_tls.load(Ordering::Relaxed)); let record2 = tracker.register_connection(true); assert_eq!(tracker.total_connections(), 2); assert!(record2.is_tls.load(Ordering::Relaxed)); // IDs should be unique assert_ne!(record1.id, record2.id); tracker.unregister_connection(record1.id); assert_eq!(tracker.total_connections(), 1); tracker.unregister_connection(record2.id); assert_eq!(tracker.total_connections(), 0); } #[test] fn test_full_zombie_detection() { let tracker = ConnectionTracker::new(None, None); let record = tracker.register_connection(false); // Not a zombie initially assert!(tracker.scan_zombies().is_empty()); // Set both sides closed -> full zombie record.client_closed.store(true, Ordering::Relaxed); record.backend_closed.store(true, Ordering::Relaxed); let zombies = tracker.scan_zombies(); assert_eq!(zombies.len(), 1); assert_eq!(zombies[0], record.id); } #[test] fn test_half_zombie_not_triggered_immediately() { let tracker = ConnectionTracker::new(None, None); let record = tracker.register_connection(false); record.touch(); // mark activity now // Only one side closed, but just now -> not a zombie yet record.client_closed.store(true, Ordering::Relaxed); assert!(tracker.scan_zombies().is_empty()); } #[test] fn test_stuck_connection_not_triggered_immediately() { let tracker = ConnectionTracker::new(None, None); let record = tracker.register_connection(false); record.touch(); // mark activity now // Has received data but sent nothing -> but just started, not stuck yet record.bytes_received.store(1000, Ordering::Relaxed); assert!(tracker.scan_zombies().is_empty()); } #[test] fn test_unregister_removes_from_zombie_scan() { let tracker = ConnectionTracker::new(None, None); let record = tracker.register_connection(false); let id = record.id; // Make it a full zombie record.client_closed.store(true, Ordering::Relaxed); record.backend_closed.store(true, Ordering::Relaxed); assert_eq!(tracker.scan_zombies().len(), 1); // Unregister should remove it tracker.unregister_connection(id); assert!(tracker.scan_zombies().is_empty()); } #[test] fn test_total_connections() { let tracker = ConnectionTracker::new(None, None); assert_eq!(tracker.total_connections(), 0); let r1 = tracker.register_connection(false); let r2 = tracker.register_connection(true); let r3 = tracker.register_connection(false); assert_eq!(tracker.total_connections(), 3); tracker.unregister_connection(r2.id); assert_eq!(tracker.total_connections(), 2); tracker.unregister_connection(r1.id); tracker.unregister_connection(r3.id); assert_eq!(tracker.total_connections(), 0); } }