403 lines
14 KiB
Rust
403 lines
14 KiB
Rust
use dashmap::DashMap;
|
|
use std::collections::VecDeque;
|
|
use std::net::IpAddr;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{debug, warn};
|
|
|
|
use super::connection_record::ConnectionRecord;
|
|
|
|
/// Half-zombie threshold for plain (non-TLS) connections: how long a connection
/// with exactly one side closed may stay idle (per `ConnectionRecord::idle_duration`)
/// before the zombie scan reports it.
const HALF_ZOMBIE_TIMEOUT_PLAIN: Duration = Duration::from_secs(30);

/// Half-zombie threshold for TLS connections (five minutes); same meaning as
/// `HALF_ZOMBIE_TIMEOUT_PLAIN`.
const HALF_ZOMBIE_TIMEOUT_TLS: Duration = Duration::from_secs(5 * 60);

/// Stuck threshold for plain (non-TLS) connections: how long a connection that has
/// received bytes but never sent any may stay idle before the zombie scan reports it.
const STUCK_TIMEOUT_PLAIN: Duration = Duration::from_secs(60);

/// Stuck threshold for TLS connections (five minutes); same meaning as
/// `STUCK_TIMEOUT_PLAIN`.
const STUCK_TIMEOUT_TLS: Duration = Duration::from_secs(5 * 60);
|
|
|
|
/// Tracks active connections per IP and enforces per-IP limits and rate limiting.
|
|
/// Also maintains per-connection records for zombie detection.
|
|
pub struct ConnectionTracker {
|
|
/// Active connection counts per IP
|
|
active: DashMap<IpAddr, AtomicU64>,
|
|
/// Connection timestamps per IP for rate limiting
|
|
timestamps: DashMap<IpAddr, VecDeque<Instant>>,
|
|
/// Maximum concurrent connections per IP (None = unlimited)
|
|
max_per_ip: Option<u64>,
|
|
/// Maximum new connections per minute per IP (None = unlimited)
|
|
rate_limit_per_minute: Option<u64>,
|
|
/// Per-connection tracking records for zombie detection
|
|
connections: DashMap<u64, Arc<ConnectionRecord>>,
|
|
/// Monotonically increasing connection ID counter
|
|
next_id: AtomicU64,
|
|
}
|
|
|
|
impl ConnectionTracker {
|
|
pub fn new(max_per_ip: Option<u64>, rate_limit_per_minute: Option<u64>) -> Self {
|
|
Self {
|
|
active: DashMap::new(),
|
|
timestamps: DashMap::new(),
|
|
max_per_ip,
|
|
rate_limit_per_minute,
|
|
connections: DashMap::new(),
|
|
next_id: AtomicU64::new(1),
|
|
}
|
|
}
|
|
|
|
/// Try to accept a new connection from the given IP.
|
|
/// Returns true if allowed, false if over limit.
|
|
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
|
// Check per-IP connection limit
|
|
if let Some(max) = self.max_per_ip {
|
|
let count = self.active
|
|
.get(ip)
|
|
.map(|c| c.value().load(Ordering::Relaxed))
|
|
.unwrap_or(0);
|
|
if count >= max {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Check rate limit
|
|
if let Some(rate_limit) = self.rate_limit_per_minute {
|
|
let now = Instant::now();
|
|
let one_minute = std::time::Duration::from_secs(60);
|
|
let mut entry = self.timestamps.entry(*ip).or_default();
|
|
let timestamps = entry.value_mut();
|
|
|
|
// Remove timestamps older than 1 minute
|
|
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
|
|
timestamps.pop_front();
|
|
}
|
|
|
|
if timestamps.len() as u64 >= rate_limit {
|
|
return false;
|
|
}
|
|
timestamps.push_back(now);
|
|
}
|
|
|
|
true
|
|
}
|
|
|
|
/// Record that a connection was opened from the given IP.
|
|
pub fn connection_opened(&self, ip: &IpAddr) {
|
|
self.active
|
|
.entry(*ip)
|
|
.or_insert_with(|| AtomicU64::new(0))
|
|
.value()
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
|
|
/// Record that a connection was closed from the given IP.
|
|
pub fn connection_closed(&self, ip: &IpAddr) {
|
|
if let Some(counter) = self.active.get(ip) {
|
|
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
|
|
// Clean up zero entries
|
|
if prev <= 1 {
|
|
drop(counter);
|
|
self.active.remove(ip);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Get the current number of active connections for an IP.
|
|
pub fn active_connections(&self, ip: &IpAddr) -> u64 {
|
|
self.active
|
|
.get(ip)
|
|
.map(|c| c.value().load(Ordering::Relaxed))
|
|
.unwrap_or(0)
|
|
}
|
|
|
|
/// Get the total number of tracked IPs.
|
|
pub fn tracked_ips(&self) -> usize {
|
|
self.active.len()
|
|
}
|
|
|
|
/// Register a new connection and return its tracking record.
|
|
///
|
|
/// The returned `Arc<ConnectionRecord>` should be passed to the forwarding
|
|
/// loop so it can update bytes / activity atomics in real time.
|
|
pub fn register_connection(&self, is_tls: bool) -> Arc<ConnectionRecord> {
|
|
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
|
let record = Arc::new(ConnectionRecord::new(id));
|
|
record.is_tls.store(is_tls, Ordering::Relaxed);
|
|
self.connections.insert(id, Arc::clone(&record));
|
|
record
|
|
}
|
|
|
|
/// Remove a connection record when the connection is fully closed.
|
|
pub fn unregister_connection(&self, id: u64) {
|
|
self.connections.remove(&id);
|
|
}
|
|
|
|
/// Scan all tracked connections and return IDs of zombie connections.
|
|
///
|
|
/// A connection is considered a zombie in any of these cases:
|
|
/// - **Full zombie**: both `client_closed` and `backend_closed` are true.
|
|
/// - **Half zombie**: one side closed for longer than the threshold
|
|
/// (5 min for TLS, 30s for non-TLS).
|
|
/// - **Stuck**: `bytes_received > 0` but `bytes_sent == 0` for longer
|
|
/// than the stuck threshold (5 min for TLS, 60s for non-TLS).
|
|
pub fn scan_zombies(&self) -> Vec<u64> {
|
|
let mut zombies = Vec::new();
|
|
|
|
for entry in self.connections.iter() {
|
|
let record = entry.value();
|
|
let id = *entry.key();
|
|
let is_tls = record.is_tls.load(Ordering::Relaxed);
|
|
let client_closed = record.client_closed.load(Ordering::Relaxed);
|
|
let backend_closed = record.backend_closed.load(Ordering::Relaxed);
|
|
let idle = record.idle_duration();
|
|
let bytes_in = record.bytes_received.load(Ordering::Relaxed);
|
|
let bytes_out = record.bytes_sent.load(Ordering::Relaxed);
|
|
|
|
// Full zombie: both sides closed
|
|
if client_closed && backend_closed {
|
|
zombies.push(id);
|
|
continue;
|
|
}
|
|
|
|
// Half zombie: one side closed for too long
|
|
let half_timeout = if is_tls {
|
|
HALF_ZOMBIE_TIMEOUT_TLS
|
|
} else {
|
|
HALF_ZOMBIE_TIMEOUT_PLAIN
|
|
};
|
|
|
|
if (client_closed || backend_closed) && idle >= half_timeout {
|
|
zombies.push(id);
|
|
continue;
|
|
}
|
|
|
|
// Stuck: received data but never sent anything for too long
|
|
let stuck_timeout = if is_tls {
|
|
STUCK_TIMEOUT_TLS
|
|
} else {
|
|
STUCK_TIMEOUT_PLAIN
|
|
};
|
|
|
|
if bytes_in > 0 && bytes_out == 0 && idle >= stuck_timeout {
|
|
zombies.push(id);
|
|
}
|
|
}
|
|
|
|
zombies
|
|
}
|
|
|
|
/// Start a background task that periodically scans for zombie connections.
|
|
///
|
|
/// The scanner runs every 10 seconds and logs any zombies it finds.
|
|
/// It stops when the provided `CancellationToken` is cancelled.
|
|
pub fn start_zombie_scanner(self: &Arc<Self>, cancel: CancellationToken) {
|
|
let tracker = Arc::clone(self);
|
|
tokio::spawn(async move {
|
|
let interval = Duration::from_secs(10);
|
|
loop {
|
|
tokio::select! {
|
|
_ = cancel.cancelled() => {
|
|
debug!("Zombie scanner shutting down");
|
|
break;
|
|
}
|
|
_ = tokio::time::sleep(interval) => {
|
|
let zombies = tracker.scan_zombies();
|
|
if !zombies.is_empty() {
|
|
warn!(
|
|
"Detected {} zombie connection(s): {:?}",
|
|
zombies.len(),
|
|
zombies
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Get the total number of tracked connections (with records).
|
|
pub fn total_connections(&self) -> usize {
|
|
self.connections.len()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse an IP literal for test fixtures.
    fn ip(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    #[test]
    fn test_basic_tracking() {
        let tracker = ConnectionTracker::new(None, None);
        let ip = ip("127.0.0.1");

        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);

        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 2);

        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);

        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 0);
    }

    #[test]
    fn test_per_ip_limit() {
        let tracker = ConnectionTracker::new(Some(2), None);
        let ip1 = ip("10.0.0.1");

        assert!(tracker.try_accept(&ip1));
        tracker.connection_opened(&ip1);

        assert!(tracker.try_accept(&ip1));
        tracker.connection_opened(&ip1);

        // Third connection should be rejected
        assert!(!tracker.try_accept(&ip1));

        // Different IP should still be allowed
        assert!(tracker.try_accept(&ip("10.0.0.2")));
    }

    #[test]
    fn test_rate_limit() {
        let tracker = ConnectionTracker::new(None, Some(3));
        let ip = ip("10.0.0.1");

        assert!(tracker.try_accept(&ip));
        assert!(tracker.try_accept(&ip));
        assert!(tracker.try_accept(&ip));
        // 4th attempt within the minute should be rejected
        assert!(!tracker.try_accept(&ip));
    }

    #[test]
    fn test_rate_limit_is_per_ip() {
        let tracker = ConnectionTracker::new(None, Some(1));
        let a = ip("10.0.0.1");
        let b = ip("10.0.0.2");

        assert!(tracker.try_accept(&a));
        assert!(!tracker.try_accept(&a));
        // Exhausting one IP's budget must not affect another IP.
        assert!(tracker.try_accept(&b));
    }

    #[test]
    fn test_zero_rate_limit_rejects_everything() {
        let tracker = ConnectionTracker::new(None, Some(0));
        assert!(!tracker.try_accept(&ip("10.0.0.1")));
    }

    #[test]
    fn test_per_ip_rejection_does_not_consume_rate_budget() {
        let tracker = ConnectionTracker::new(Some(1), Some(2));
        let ip = ip("10.0.0.1");

        assert!(tracker.try_accept(&ip)); // budget used: 1
        tracker.connection_opened(&ip);
        assert!(!tracker.try_accept(&ip)); // rejected by per-IP cap; budget still 1
        tracker.connection_closed(&ip);

        assert!(tracker.try_accept(&ip)); // budget used: 2
        tracker.connection_opened(&ip);
        tracker.connection_closed(&ip);

        assert!(!tracker.try_accept(&ip)); // rate limit exhausted
    }

    #[test]
    fn test_no_limits() {
        let tracker = ConnectionTracker::new(None, None);
        let ip = ip("10.0.0.1");

        for _ in 0..1000 {
            assert!(tracker.try_accept(&ip));
            tracker.connection_opened(&ip);
        }
        assert_eq!(tracker.active_connections(&ip), 1000);
    }

    #[test]
    fn test_tracked_ips() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.tracked_ips(), 0);

        let ip1 = ip("10.0.0.1");
        let ip2 = ip("10.0.0.2");

        tracker.connection_opened(&ip1);
        tracker.connection_opened(&ip2);
        assert_eq!(tracker.tracked_ips(), 2);

        tracker.connection_closed(&ip1);
        assert_eq!(tracker.tracked_ips(), 1);
    }

    #[test]
    fn test_close_untracked_ip_is_noop() {
        let tracker = ConnectionTracker::new(None, None);
        let ip = ip("10.0.0.9");

        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 0);
        assert_eq!(tracker.tracked_ips(), 0);
    }

    #[test]
    fn test_extra_close_does_not_block_ip() {
        let tracker = ConnectionTracker::new(Some(1), None);
        let ip = ip("10.0.0.1");

        tracker.connection_opened(&ip);
        tracker.connection_closed(&ip);
        tracker.connection_closed(&ip); // unmatched close

        assert_eq!(tracker.active_connections(&ip), 0);
        assert!(tracker.try_accept(&ip));
    }

    #[test]
    fn test_reopen_after_full_close() {
        let tracker = ConnectionTracker::new(None, None);
        let ip = ip("10.0.0.1");

        tracker.connection_opened(&ip);
        tracker.connection_closed(&ip);
        tracker.connection_opened(&ip);

        assert_eq!(tracker.active_connections(&ip), 1);
        assert_eq!(tracker.tracked_ips(), 1);
    }

    #[test]
    fn test_register_unregister_connection() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.total_connections(), 0);

        let record1 = tracker.register_connection(false);
        assert_eq!(tracker.total_connections(), 1);
        assert!(!record1.is_tls.load(Ordering::Relaxed));

        let record2 = tracker.register_connection(true);
        assert_eq!(tracker.total_connections(), 2);
        assert!(record2.is_tls.load(Ordering::Relaxed));

        // IDs should be unique
        assert_ne!(record1.id, record2.id);

        tracker.unregister_connection(record1.id);
        assert_eq!(tracker.total_connections(), 1);

        tracker.unregister_connection(record2.id);
        assert_eq!(tracker.total_connections(), 0);
    }

    #[test]
    fn test_connection_ids_are_sequential() {
        let tracker = ConnectionTracker::new(None, None);
        let r1 = tracker.register_connection(false);
        let r2 = tracker.register_connection(false);
        assert_eq!(r2.id, r1.id + 1);
    }

    #[test]
    fn test_full_zombie_detection() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(false);

        // Not a zombie initially
        assert!(tracker.scan_zombies().is_empty());

        // Set both sides closed -> full zombie
        record.client_closed.store(true, Ordering::Relaxed);
        record.backend_closed.store(true, Ordering::Relaxed);

        let zombies = tracker.scan_zombies();
        assert_eq!(zombies.len(), 1);
        assert_eq!(zombies[0], record.id);
    }

    #[test]
    fn test_full_zombie_detection_tls() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(true);

        // Full zombies are reported immediately regardless of the TLS thresholds.
        record.client_closed.store(true, Ordering::Relaxed);
        record.backend_closed.store(true, Ordering::Relaxed);

        assert_eq!(tracker.scan_zombies(), vec![record.id]);
    }

    #[test]
    fn test_half_zombie_not_triggered_immediately() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(false);
        record.touch(); // mark activity now

        // Only one side closed, but just now -> not a zombie yet
        record.client_closed.store(true, Ordering::Relaxed);
        assert!(tracker.scan_zombies().is_empty());
    }

    #[test]
    fn test_stuck_connection_not_triggered_immediately() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(false);
        record.touch(); // mark activity now

        // Has received data but sent nothing -> but just started, not stuck yet
        record.bytes_received.store(1000, Ordering::Relaxed);
        assert!(tracker.scan_zombies().is_empty());
    }

    #[test]
    fn test_unregister_removes_from_zombie_scan() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(false);
        let id = record.id;

        // Make it a full zombie
        record.client_closed.store(true, Ordering::Relaxed);
        record.backend_closed.store(true, Ordering::Relaxed);
        assert_eq!(tracker.scan_zombies().len(), 1);

        // Unregister should remove it
        tracker.unregister_connection(id);
        assert!(tracker.scan_zombies().is_empty());
    }

    #[test]
    fn test_total_connections() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.total_connections(), 0);

        let r1 = tracker.register_connection(false);
        let r2 = tracker.register_connection(true);
        let r3 = tracker.register_connection(false);
        assert_eq!(tracker.total_connections(), 3);

        tracker.unregister_connection(r2.id);
        assert_eq!(tracker.total_connections(), 2);

        tracker.unregister_connection(r1.id);
        tracker.unregister_connection(r3.id);
        assert_eq!(tracker.total_connections(), 0);
    }
}
|