use dashmap::DashMap; use std::collections::VecDeque; use std::net::IpAddr; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; /// Tracks active connections per IP and enforces per-IP limits and rate limiting. pub struct ConnectionTracker { /// Active connection counts per IP active: DashMap, /// Connection timestamps per IP for rate limiting timestamps: DashMap>, /// Maximum concurrent connections per IP (None = unlimited) max_per_ip: Option, /// Maximum new connections per minute per IP (None = unlimited) rate_limit_per_minute: Option, } impl ConnectionTracker { pub fn new(max_per_ip: Option, rate_limit_per_minute: Option) -> Self { Self { active: DashMap::new(), timestamps: DashMap::new(), max_per_ip, rate_limit_per_minute, } } /// Try to accept a new connection from the given IP. /// Returns true if allowed, false if over limit. pub fn try_accept(&self, ip: &IpAddr) -> bool { // Check per-IP connection limit if let Some(max) = self.max_per_ip { let count = self.active .get(ip) .map(|c| c.value().load(Ordering::Relaxed)) .unwrap_or(0); if count >= max { return false; } } // Check rate limit if let Some(rate_limit) = self.rate_limit_per_minute { let now = Instant::now(); let one_minute = std::time::Duration::from_secs(60); let mut entry = self.timestamps.entry(*ip).or_default(); let timestamps = entry.value_mut(); // Remove timestamps older than 1 minute while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) { timestamps.pop_front(); } if timestamps.len() as u64 >= rate_limit { return false; } timestamps.push_back(now); } true } /// Record that a connection was opened from the given IP. pub fn connection_opened(&self, ip: &IpAddr) { self.active .entry(*ip) .or_insert_with(|| AtomicU64::new(0)) .value() .fetch_add(1, Ordering::Relaxed); } /// Record that a connection was closed from the given IP. pub fn connection_closed(&self, ip: &IpAddr) { if let Some(counter) = self.active.get(ip) { let prev = counter.value().fetch_sub(1, Ordering::Relaxed); // Clean up zero entries to prevent memory growth if prev <= 1 { drop(counter); self.active.remove(ip); self.timestamps.remove(ip); } } } /// Get the current number of active connections for an IP. pub fn active_connections(&self, ip: &IpAddr) -> u64 { self.active .get(ip) .map(|c| c.value().load(Ordering::Relaxed)) .unwrap_or(0) } /// Prune stale timestamp entries for IPs that have no active connections /// and no recent timestamps. This cleans up entries left by rate-limited IPs /// that never had connection_opened called. pub fn cleanup_stale_timestamps(&self) { if self.rate_limit_per_minute.is_none() { return; // No rate limiting — timestamps map should be empty } let now = Instant::now(); let one_minute = Duration::from_secs(60); self.timestamps.retain(|ip, timestamps| { timestamps.retain(|t| now.duration_since(*t) < one_minute); // Keep if there are active connections or recent timestamps !timestamps.is_empty() || self.active.contains_key(ip) }); } /// Get the total number of tracked IPs. pub fn tracked_ips(&self) -> usize { self.active.len() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_basic_tracking() { let tracker = ConnectionTracker::new(None, None); let ip: IpAddr = "127.0.0.1".parse().unwrap(); assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); assert_eq!(tracker.active_connections(&ip), 1); tracker.connection_opened(&ip); assert_eq!(tracker.active_connections(&ip), 2); tracker.connection_closed(&ip); assert_eq!(tracker.active_connections(&ip), 1); tracker.connection_closed(&ip); assert_eq!(tracker.active_connections(&ip), 0); } #[test] fn test_per_ip_limit() { let tracker = ConnectionTracker::new(Some(2), None); let ip: IpAddr = "10.0.0.1".parse().unwrap(); assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); // Third connection should be rejected assert!(!tracker.try_accept(&ip)); // Different IP should still be allowed let ip2: IpAddr = "10.0.0.2".parse().unwrap(); assert!(tracker.try_accept(&ip2)); } #[test] fn test_rate_limit() { let tracker = ConnectionTracker::new(None, Some(3)); let ip: IpAddr = "10.0.0.1".parse().unwrap(); assert!(tracker.try_accept(&ip)); assert!(tracker.try_accept(&ip)); assert!(tracker.try_accept(&ip)); // 4th attempt within the minute should be rejected assert!(!tracker.try_accept(&ip)); } #[test] fn test_no_limits() { let tracker = ConnectionTracker::new(None, None); let ip: IpAddr = "10.0.0.1".parse().unwrap(); for _ in 0..1000 { assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); } assert_eq!(tracker.active_connections(&ip), 1000); } #[test] fn test_tracked_ips() { let tracker = ConnectionTracker::new(None, None); assert_eq!(tracker.tracked_ips(), 0); let ip1: IpAddr = "10.0.0.1".parse().unwrap(); let ip2: IpAddr = "10.0.0.2".parse().unwrap(); tracker.connection_opened(&ip1); tracker.connection_opened(&ip2); assert_eq!(tracker.tracked_ips(), 2); tracker.connection_closed(&ip1); assert_eq!(tracker.tracked_ips(), 1); } #[test] fn test_timestamps_cleaned_on_last_close() { let tracker = ConnectionTracker::new(None, Some(100)); let ip: IpAddr = "10.0.0.1".parse().unwrap(); // try_accept populates the timestamps map (when rate limiting is enabled) assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); assert!(tracker.try_accept(&ip)); tracker.connection_opened(&ip); // Timestamps should exist assert!(tracker.timestamps.get(&ip).is_some()); // Close one connection — timestamps should still exist tracker.connection_closed(&ip); assert!(tracker.timestamps.get(&ip).is_some()); // Close last connection — timestamps should be cleaned up tracker.connection_closed(&ip); assert!(tracker.timestamps.get(&ip).is_none()); assert!(tracker.active.get(&ip).is_none()); } #[test] fn test_cleanup_stale_timestamps() { // Rate limit of 100/min so timestamps are tracked let tracker = ConnectionTracker::new(None, Some(100)); let ip: IpAddr = "10.0.0.1".parse().unwrap(); // try_accept adds a timestamp entry assert!(tracker.try_accept(&ip)); // Simulate: connection was rate-limited and never accepted, // so no connection_opened / connection_closed pair assert!(tracker.timestamps.get(&ip).is_some()); assert!(tracker.active.get(&ip).is_none()); // never opened // Cleanup won't remove it yet because timestamp is recent tracker.cleanup_stale_timestamps(); assert!(tracker.timestamps.get(&ip).is_some()); // After expiry (use 0-second window trick: create tracker with 0 rate) // Actually, we can't fast-forward time easily, so just verify the cleanup // doesn't panic and handles the no-rate-limit case let tracker2 = ConnectionTracker::new(None, None); tracker2.cleanup_stale_timestamps(); // should be a no-op } }