248 lines
8.2 KiB
Rust
248 lines
8.2 KiB
Rust
use dashmap::DashMap;
|
|
use std::collections::VecDeque;
|
|
use std::net::IpAddr;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::time::{Duration, Instant};
|
|
|
|
/// Tracks active connections per IP and enforces per-IP limits and rate limiting.
|
|
pub struct ConnectionTracker {
|
|
/// Active connection counts per IP
|
|
active: DashMap<IpAddr, AtomicU64>,
|
|
/// Connection timestamps per IP for rate limiting
|
|
timestamps: DashMap<IpAddr, VecDeque<Instant>>,
|
|
/// Maximum concurrent connections per IP (None = unlimited)
|
|
max_per_ip: Option<u64>,
|
|
/// Maximum new connections per minute per IP (None = unlimited)
|
|
rate_limit_per_minute: Option<u64>,
|
|
}
|
|
|
|
impl ConnectionTracker {
|
|
pub fn new(max_per_ip: Option<u64>, rate_limit_per_minute: Option<u64>) -> Self {
|
|
Self {
|
|
active: DashMap::new(),
|
|
timestamps: DashMap::new(),
|
|
max_per_ip,
|
|
rate_limit_per_minute,
|
|
}
|
|
}
|
|
|
|
/// Try to accept a new connection from the given IP.
|
|
/// Returns true if allowed, false if over limit.
|
|
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
|
// Check per-IP connection limit
|
|
if let Some(max) = self.max_per_ip {
|
|
let count = self.active
|
|
.get(ip)
|
|
.map(|c| c.value().load(Ordering::Relaxed))
|
|
.unwrap_or(0);
|
|
if count >= max {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Check rate limit
|
|
if let Some(rate_limit) = self.rate_limit_per_minute {
|
|
let now = Instant::now();
|
|
let one_minute = std::time::Duration::from_secs(60);
|
|
let mut entry = self.timestamps.entry(*ip).or_default();
|
|
let timestamps = entry.value_mut();
|
|
|
|
// Remove timestamps older than 1 minute
|
|
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
|
|
timestamps.pop_front();
|
|
}
|
|
|
|
if timestamps.len() as u64 >= rate_limit {
|
|
return false;
|
|
}
|
|
timestamps.push_back(now);
|
|
}
|
|
|
|
true
|
|
}
|
|
|
|
/// Record that a connection was opened from the given IP.
|
|
pub fn connection_opened(&self, ip: &IpAddr) {
|
|
self.active
|
|
.entry(*ip)
|
|
.or_insert_with(|| AtomicU64::new(0))
|
|
.value()
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
|
|
/// Record that a connection was closed from the given IP.
|
|
pub fn connection_closed(&self, ip: &IpAddr) {
|
|
if let Some(counter) = self.active.get(ip) {
|
|
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
|
|
// Clean up zero entries to prevent memory growth
|
|
if prev <= 1 {
|
|
drop(counter);
|
|
self.active.remove(ip);
|
|
self.timestamps.remove(ip);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Get the current number of active connections for an IP.
|
|
pub fn active_connections(&self, ip: &IpAddr) -> u64 {
|
|
self.active
|
|
.get(ip)
|
|
.map(|c| c.value().load(Ordering::Relaxed))
|
|
.unwrap_or(0)
|
|
}
|
|
|
|
/// Prune stale timestamp entries for IPs that have no active connections
|
|
/// and no recent timestamps. This cleans up entries left by rate-limited IPs
|
|
/// that never had connection_opened called.
|
|
pub fn cleanup_stale_timestamps(&self) {
|
|
if self.rate_limit_per_minute.is_none() {
|
|
return; // No rate limiting — timestamps map should be empty
|
|
}
|
|
let now = Instant::now();
|
|
let one_minute = Duration::from_secs(60);
|
|
self.timestamps.retain(|ip, timestamps| {
|
|
timestamps.retain(|t| now.duration_since(*t) < one_minute);
|
|
// Keep if there are active connections or recent timestamps
|
|
!timestamps.is_empty() || self.active.contains_key(ip)
|
|
});
|
|
}
|
|
|
|
/// Get the total number of tracked IPs.
|
|
pub fn tracked_ips(&self) -> usize {
|
|
self.active.len()
|
|
}
|
|
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_basic_tracking() {
|
|
let tracker = ConnectionTracker::new(None, None);
|
|
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
|
|
|
assert!(tracker.try_accept(&ip));
|
|
tracker.connection_opened(&ip);
|
|
assert_eq!(tracker.active_connections(&ip), 1);
|
|
|
|
tracker.connection_opened(&ip);
|
|
assert_eq!(tracker.active_connections(&ip), 2);
|
|
|
|
tracker.connection_closed(&ip);
|
|
assert_eq!(tracker.active_connections(&ip), 1);
|
|
|
|
tracker.connection_closed(&ip);
|
|
assert_eq!(tracker.active_connections(&ip), 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_per_ip_limit() {
|
|
let tracker = ConnectionTracker::new(Some(2), None);
|
|
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
|
|
|
assert!(tracker.try_accept(&ip));
|
|
tracker.connection_opened(&ip);
|
|
|
|
assert!(tracker.try_accept(&ip));
|
|
tracker.connection_opened(&ip);
|
|
|
|
// Third connection should be rejected
|
|
assert!(!tracker.try_accept(&ip));
|
|
|
|
// Different IP should still be allowed
|
|
let ip2: IpAddr = "10.0.0.2".parse().unwrap();
|
|
assert!(tracker.try_accept(&ip2));
|
|
}
|
|
|
|
#[test]
|
|
fn test_rate_limit() {
|
|
let tracker = ConnectionTracker::new(None, Some(3));
|
|
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
|
|
|
assert!(tracker.try_accept(&ip));
|
|
assert!(tracker.try_accept(&ip));
|
|
assert!(tracker.try_accept(&ip));
|
|
// 4th attempt within the minute should be rejected
|
|
assert!(!tracker.try_accept(&ip));
|
|
}
|
|
|
|
#[test]
|
|
fn test_no_limits() {
|
|
let tracker = ConnectionTracker::new(None, None);
|
|
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
|
|
|
for _ in 0..1000 {
|
|
assert!(tracker.try_accept(&ip));
|
|
tracker.connection_opened(&ip);
|
|
}
|
|
assert_eq!(tracker.active_connections(&ip), 1000);
|
|
}
|
|
|
|
#[test]
|
|
fn test_tracked_ips() {
|
|
let tracker = ConnectionTracker::new(None, None);
|
|
assert_eq!(tracker.tracked_ips(), 0);
|
|
|
|
let ip1: IpAddr = "10.0.0.1".parse().unwrap();
|
|
let ip2: IpAddr = "10.0.0.2".parse().unwrap();
|
|
|
|
tracker.connection_opened(&ip1);
|
|
tracker.connection_opened(&ip2);
|
|
assert_eq!(tracker.tracked_ips(), 2);
|
|
|
|
tracker.connection_closed(&ip1);
|
|
assert_eq!(tracker.tracked_ips(), 1);
|
|
}
|
|
|
|
#[test]
|
|
fn test_timestamps_cleaned_on_last_close() {
|
|
let tracker = ConnectionTracker::new(None, Some(100));
|
|
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
|
|
|
// try_accept populates the timestamps map (when rate limiting is enabled)
|
|
assert!(tracker.try_accept(&ip));
|
|
tracker.connection_opened(&ip);
|
|
assert!(tracker.try_accept(&ip));
|
|
tracker.connection_opened(&ip);
|
|
|
|
// Timestamps should exist
|
|
assert!(tracker.timestamps.get(&ip).is_some());
|
|
|
|
// Close one connection — timestamps should still exist
|
|
tracker.connection_closed(&ip);
|
|
assert!(tracker.timestamps.get(&ip).is_some());
|
|
|
|
// Close last connection — timestamps should be cleaned up
|
|
tracker.connection_closed(&ip);
|
|
assert!(tracker.timestamps.get(&ip).is_none());
|
|
assert!(tracker.active.get(&ip).is_none());
|
|
}
|
|
|
|
#[test]
|
|
fn test_cleanup_stale_timestamps() {
|
|
// Rate limit of 100/min so timestamps are tracked
|
|
let tracker = ConnectionTracker::new(None, Some(100));
|
|
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
|
|
|
// try_accept adds a timestamp entry
|
|
assert!(tracker.try_accept(&ip));
|
|
|
|
// Simulate: connection was rate-limited and never accepted,
|
|
// so no connection_opened / connection_closed pair
|
|
assert!(tracker.timestamps.get(&ip).is_some());
|
|
assert!(tracker.active.get(&ip).is_none()); // never opened
|
|
|
|
// Cleanup won't remove it yet because timestamp is recent
|
|
tracker.cleanup_stale_timestamps();
|
|
assert!(tracker.timestamps.get(&ip).is_some());
|
|
|
|
// After expiry (use 0-second window trick: create tracker with 0 rate)
|
|
// Actually, we can't fast-forward time easily, so just verify the cleanup
|
|
// doesn't panic and handles the no-rate-limit case
|
|
let tracker2 = ConnectionTracker::new(None, None);
|
|
tracker2.cleanup_stale_timestamps(); // should be a no-op
|
|
}
|
|
}
|