// File: smartproxy/rust/crates/rustproxy-passthrough/src/connection_tracker.rs
// (248 lines, 8.2 KiB, Rust)

use dashmap::DashMap;
use std::collections::VecDeque;
use std::net::IpAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Tracks active connections per IP and enforces per-IP limits and rate limiting.
///
/// State lives in two concurrent maps keyed by client IP, so every method
/// takes `&self`. A limit set to `None` disables the corresponding check.
#[derive(Debug)]
pub struct ConnectionTracker {
    /// Number of currently open connections per IP.
    active: DashMap<IpAddr, AtomicU64>,
    /// Timestamps of recently accepted connection attempts per IP, oldest
    /// first; only populated when rate limiting is enabled.
    timestamps: DashMap<IpAddr, VecDeque<Instant>>,
    /// Maximum concurrent connections per IP (`None` = unlimited).
    max_per_ip: Option<u64>,
    /// Maximum new connections per minute per IP (`None` = unlimited).
    rate_limit_per_minute: Option<u64>,
}
impl ConnectionTracker {
    /// Length of the sliding window over which new connections are counted
    /// for rate limiting.
    const RATE_WINDOW: Duration = Duration::from_secs(60);

    /// Creates a tracker with the given limits.
    ///
    /// * `max_per_ip` - maximum concurrent connections per IP (`None` = unlimited).
    /// * `rate_limit_per_minute` - maximum accepted connection attempts per IP
    ///   within any 60-second window (`None` = unlimited).
    pub fn new(max_per_ip: Option<u64>, rate_limit_per_minute: Option<u64>) -> Self {
        Self {
            active: DashMap::new(),
            timestamps: DashMap::new(),
            max_per_ip,
            rate_limit_per_minute,
        }
    }

    /// Try to accept a new connection from the given IP.
    ///
    /// Returns `true` if allowed, `false` if the IP is at its concurrent
    /// connection limit or has used up its rate-limit budget for the current
    /// window. An allowed attempt is recorded against the rate limit, but the
    /// active count is only changed by [`Self::connection_opened`].
    pub fn try_accept(&self, ip: &IpAddr) -> bool {
        // Concurrent-connection limit.
        if let Some(max) = self.max_per_ip {
            if self.active_connections(ip) >= max {
                return false;
            }
        }

        // Sliding-window rate limit.
        if let Some(rate_limit) = self.rate_limit_per_minute {
            let mut entry = self.timestamps.entry(*ip).or_default();
            // Read the clock only once the entry lock is held: pushes for one
            // IP are then chronologically ordered, which the front-pop pruning
            // below relies on.
            let now = Instant::now();
            let recent = entry.value_mut();

            // Drop attempts that have fallen out of the window.
            while recent
                .front()
                .is_some_and(|t| now.duration_since(*t) >= Self::RATE_WINDOW)
            {
                recent.pop_front();
            }

            // `usize` -> `u64` is a lossless widening on supported targets.
            if recent.len() as u64 >= rate_limit {
                return false;
            }
            recent.push_back(now);
        }

        true
    }

    /// Record that a connection was opened from the given IP.
    pub fn connection_opened(&self, ip: &IpAddr) {
        self.active
            .entry(*ip)
            .or_insert_with(|| AtomicU64::new(0))
            .value()
            .fetch_add(1, Ordering::Relaxed);
    }

    /// Record that a connection was closed from the given IP.
    ///
    /// Closing an IP that has no tracked connections is a no-op. When the
    /// last connection of an IP closes, its entries are removed from both
    /// maps to prevent memory growth.
    ///
    /// NOTE(review): removing the timestamps also discards the IP's
    /// rate-limit history, so a client that opens and closes one connection
    /// at a time never accumulates history. Confirm this is intended.
    pub fn connection_closed(&self, ip: &IpAddr) {
        let reached_zero = match self.active.get(ip) {
            Some(counter) => {
                // Checked decrement: an unmatched close must not wrap the
                // counter around to `u64::MAX`.
                let previous = counter.value().fetch_update(
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                    |n| n.checked_sub(1),
                );
                matches!(previous, Ok(1) | Err(_))
            }
            None => return,
        };
        if !reached_zero {
            return;
        }

        // Remove the entry only if the count is *still* zero. A concurrent
        // `connection_opened` may have bumped it after the guard above was
        // released; removing unconditionally would lose that connection.
        let removed = self
            .active
            .remove_if(ip, |_, count| count.load(Ordering::Relaxed) == 0)
            .is_some();

        // The `active` lock is released before `timestamps` is touched, so the
        // two maps are never locked in the opposite order to
        // `cleanup_stale_timestamps`.
        if removed {
            self.timestamps.remove(ip);
        }
    }

    /// Get the current number of active connections for an IP.
    ///
    /// Returns 0 for an IP that is not tracked.
    pub fn active_connections(&self, ip: &IpAddr) -> u64 {
        self.active
            .get(ip)
            .map(|c| c.value().load(Ordering::Relaxed))
            .unwrap_or(0)
    }

    /// Prune stale timestamp entries for IPs that have no active connections
    /// and no recent timestamps.
    ///
    /// This cleans up entries left behind by IPs whose attempts were recorded
    /// by [`Self::try_accept`] but that never had `connection_opened` /
    /// `connection_closed` called.
    pub fn cleanup_stale_timestamps(&self) {
        if self.rate_limit_per_minute.is_none() {
            return; // No rate limiting: `try_accept` never records timestamps.
        }

        let now = Instant::now();
        self.timestamps.retain(|ip, recent| {
            recent.retain(|t| now.duration_since(*t) < Self::RATE_WINDOW);
            // Keep if there are recent timestamps or active connections.
            !recent.is_empty() || self.active.contains_key(ip)
        });
    }

    /// Get the total number of tracked IPs, i.e. IPs that currently have an
    /// entry in the active-connection map.
    pub fn tracked_ips(&self) -> usize {
        self.active.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an instant that already lies outside the 60-second rate window,
    /// or `None` if the platform cannot represent an instant that far back.
    /// `Instant` cannot be fast-forwarded, so expiry is tested by planting an
    /// old timestamp instead of waiting.
    fn expired_instant() -> Option<Instant> {
        Instant::now().checked_sub(Duration::from_secs(61))
    }

    #[test]
    fn test_basic_tracking() {
        let tracker = ConnectionTracker::new(None, None);
        let ip: IpAddr = "127.0.0.1".parse().unwrap();

        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);

        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 2);

        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);

        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 0);
    }

    #[test]
    fn test_per_ip_limit() {
        let tracker = ConnectionTracker::new(Some(2), None);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);
        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);

        // Third connection should be rejected
        assert!(!tracker.try_accept(&ip));

        // Different IP should still be allowed
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();
        assert!(tracker.try_accept(&ip2));
    }

    #[test]
    fn test_rate_limit() {
        let tracker = ConnectionTracker::new(None, Some(3));
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        assert!(tracker.try_accept(&ip));
        assert!(tracker.try_accept(&ip));
        assert!(tracker.try_accept(&ip));

        // 4th attempt within the minute should be rejected
        assert!(!tracker.try_accept(&ip));
    }

    #[test]
    fn test_rate_limit_ignores_expired_timestamps() {
        let tracker = ConnectionTracker::new(None, Some(1));
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let Some(expired) = expired_instant() else {
            return;
        };

        // An attempt older than the window must not count against the budget.
        tracker.timestamps.insert(ip, VecDeque::from([expired]));
        assert!(tracker.try_accept(&ip));

        // The fresh attempt just recorded does count.
        assert!(!tracker.try_accept(&ip));
    }

    #[test]
    fn test_no_limits() {
        let tracker = ConnectionTracker::new(None, None);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        for _ in 0..1000 {
            assert!(tracker.try_accept(&ip));
            tracker.connection_opened(&ip);
        }
        assert_eq!(tracker.active_connections(&ip), 1000);
    }

    #[test]
    fn test_tracked_ips() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.tracked_ips(), 0);

        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();

        tracker.connection_opened(&ip1);
        tracker.connection_opened(&ip2);
        assert_eq!(tracker.tracked_ips(), 2);

        tracker.connection_closed(&ip1);
        assert_eq!(tracker.tracked_ips(), 1);
    }

    #[test]
    fn test_close_without_open_is_noop() {
        let tracker = ConnectionTracker::new(None, None);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // An unmatched close must neither create an entry nor corrupt counts.
        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 0);
        assert_eq!(tracker.tracked_ips(), 0);

        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);
    }

    #[test]
    fn test_timestamps_cleaned_on_last_close() {
        let tracker = ConnectionTracker::new(None, Some(100));
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // try_accept populates the timestamps map (when rate limiting is enabled)
        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);
        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);

        // Timestamps should exist
        assert!(tracker.timestamps.get(&ip).is_some());

        // Close one connection: timestamps should still exist
        tracker.connection_closed(&ip);
        assert!(tracker.timestamps.get(&ip).is_some());

        // Close last connection: timestamps should be cleaned up
        tracker.connection_closed(&ip);
        assert!(tracker.timestamps.get(&ip).is_none());
        assert!(tracker.active.get(&ip).is_none());
    }

    #[test]
    fn test_cleanup_stale_timestamps() {
        // Rate limit of 100/min so timestamps are tracked
        let tracker = ConnectionTracker::new(None, Some(100));
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // try_accept adds a timestamp entry; connection_opened is never
        // called, so nothing would ever remove it via connection_closed.
        assert!(tracker.try_accept(&ip));
        assert!(tracker.timestamps.get(&ip).is_some());
        assert!(tracker.active.get(&ip).is_none()); // never opened

        // Cleanup keeps the entry while its timestamp is still recent
        tracker.cleanup_stale_timestamps();
        assert!(tracker.timestamps.get(&ip).is_some());

        // Without rate limiting, cleanup is a no-op and must not panic
        let tracker2 = ConnectionTracker::new(None, None);
        tracker2.cleanup_stale_timestamps();
        assert_eq!(tracker2.timestamps.len(), 0);
    }

    #[test]
    fn test_cleanup_removes_expired_timestamps() {
        let tracker = ConnectionTracker::new(None, Some(100));
        let idle: IpAddr = "10.0.0.1".parse().unwrap();
        let connected: IpAddr = "10.0.0.2".parse().unwrap();
        let Some(expired) = expired_instant() else {
            return;
        };

        tracker.timestamps.insert(idle, VecDeque::from([expired]));
        tracker
            .timestamps
            .insert(connected, VecDeque::from([expired]));
        tracker.connection_opened(&connected);

        tracker.cleanup_stale_timestamps();

        // Expired and no active connection: entry is dropped entirely
        assert!(tracker.timestamps.get(&idle).is_none());
        // Expired but still connected: entry is kept, with the old stamp pruned
        assert!(tracker
            .timestamps
            .get(&connected)
            .is_some_and(|stamps| stamps.is_empty()));
    }
}