//! In-process SMTP rate limiter. //! //! Uses DashMap for lock-free concurrent access to rate counters. //! Tracks connections per IP, messages per sender, and auth failures. use dashmap::DashMap; use serde::{Deserialize, Serialize}; use std::time::{Duration, Instant}; /// Rate limiter configuration. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RateLimitConfig { /// Maximum connections per IP per window. pub max_connections_per_ip: u32, /// Maximum messages per sender per window. pub max_messages_per_sender: u32, /// Maximum auth failures per IP per window. pub max_auth_failures_per_ip: u32, /// Window duration in seconds. pub window_secs: u64, } impl Default for RateLimitConfig { fn default() -> Self { Self { max_connections_per_ip: 50, max_messages_per_sender: 100, max_auth_failures_per_ip: 5, window_secs: 60, } } } /// A timestamped counter entry. struct CounterEntry { count: u32, window_start: Instant, } /// In-process rate limiter using DashMap. pub struct RateLimiter { config: RateLimitConfig, window: Duration, connections: DashMap, messages: DashMap, auth_failures: DashMap, } impl RateLimiter { /// Create a new rate limiter with the given configuration. pub fn new(config: RateLimitConfig) -> Self { let window = Duration::from_secs(config.window_secs); Self { config, window, connections: DashMap::new(), messages: DashMap::new(), auth_failures: DashMap::new(), } } /// Update the configuration at runtime. pub fn update_config(&mut self, config: RateLimitConfig) { self.window = Duration::from_secs(config.window_secs); self.config = config; } /// Check and record a new connection from an IP. /// Returns `true` if the connection should be allowed. pub fn check_connection(&self, ip: &str) -> bool { self.increment_and_check( &self.connections, ip, self.config.max_connections_per_ip, ) } /// Check and record a message from a sender. /// Returns `true` if the message should be allowed. pub fn check_message(&self, sender: &str) -> bool { self.increment_and_check( &self.messages, sender, self.config.max_messages_per_sender, ) } /// Check and record an auth failure from an IP. /// Returns `true` if more attempts should be allowed. pub fn check_auth_failure(&self, ip: &str) -> bool { self.increment_and_check( &self.auth_failures, ip, self.config.max_auth_failures_per_ip, ) } /// Increment a counter and check against the limit. /// Returns `true` if within limits. fn increment_and_check( &self, map: &DashMap, key: &str, limit: u32, ) -> bool { let now = Instant::now(); let mut entry = map .entry(key.to_string()) .or_insert_with(|| CounterEntry { count: 0, window_start: now, }); // Reset window if expired if now.duration_since(entry.window_start) > self.window { entry.count = 0; entry.window_start = now; } entry.count += 1; entry.count <= limit } /// Clean up expired entries. Call periodically. pub fn cleanup(&self) { let now = Instant::now(); let window = self.window; self.connections .retain(|_, v| now.duration_since(v.window_start) <= window); self.messages .retain(|_, v| now.duration_since(v.window_start) <= window); self.auth_failures .retain(|_, v| now.duration_since(v.window_start) <= window); } } #[cfg(test)] mod tests { use super::*; #[test] fn test_connection_limit() { let limiter = RateLimiter::new(RateLimitConfig { max_connections_per_ip: 3, window_secs: 60, ..Default::default() }); assert!(limiter.check_connection("1.2.3.4")); assert!(limiter.check_connection("1.2.3.4")); assert!(limiter.check_connection("1.2.3.4")); assert!(!limiter.check_connection("1.2.3.4")); // 4th = over limit // Different IP is independent assert!(limiter.check_connection("5.6.7.8")); } #[test] fn test_message_limit() { let limiter = RateLimiter::new(RateLimitConfig { max_messages_per_sender: 2, window_secs: 60, ..Default::default() }); assert!(limiter.check_message("sender@example.com")); assert!(limiter.check_message("sender@example.com")); assert!(!limiter.check_message("sender@example.com")); } #[test] fn test_auth_failure_limit() { let limiter = RateLimiter::new(RateLimitConfig { max_auth_failures_per_ip: 2, window_secs: 60, ..Default::default() }); assert!(limiter.check_auth_failure("1.2.3.4")); assert!(limiter.check_auth_failure("1.2.3.4")); assert!(!limiter.check_auth_failure("1.2.3.4")); } #[test] fn test_cleanup() { let limiter = RateLimiter::new(RateLimitConfig { max_connections_per_ip: 1, window_secs: 60, ..Default::default() }); limiter.check_connection("1.2.3.4"); assert_eq!(limiter.connections.len(), 1); limiter.cleanup(); // entries not expired assert_eq!(limiter.connections.len(), 1); } }