199 lines
5.7 KiB
Rust
199 lines
5.7 KiB
Rust
|
|
//! In-process SMTP rate limiter.
//!
//! Uses DashMap (a sharded concurrent hash map) for low-contention concurrent
//! access to rate counters.
//! Tracks connections per IP, messages per sender, and auth failures.
|
||
|
|
|
||
|
|
use dashmap::DashMap;
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use std::time::{Duration, Instant};
|
||
|
|
|
||
|
|
/// Rate limiter configuration.
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct RateLimitConfig {
|
||
|
|
/// Maximum connections per IP per window.
|
||
|
|
pub max_connections_per_ip: u32,
|
||
|
|
/// Maximum messages per sender per window.
|
||
|
|
pub max_messages_per_sender: u32,
|
||
|
|
/// Maximum auth failures per IP per window.
|
||
|
|
pub max_auth_failures_per_ip: u32,
|
||
|
|
/// Window duration in seconds.
|
||
|
|
pub window_secs: u64,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for RateLimitConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
max_connections_per_ip: 50,
|
||
|
|
max_messages_per_sender: 100,
|
||
|
|
max_auth_failures_per_ip: 5,
|
||
|
|
window_secs: 60,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// A timestamped counter entry.
///
/// Holds the number of events recorded for one key together with the instant
/// at which the current counting window began.
#[derive(Debug)]
struct CounterEntry {
    /// Events recorded since `window_start`.
    count: u32,
    /// Start of the window that `count` belongs to.
    window_start: Instant,
}
|
||
|
|
|
||
|
|
/// In-process rate limiter using DashMap.
|
||
|
|
pub struct RateLimiter {
|
||
|
|
config: RateLimitConfig,
|
||
|
|
window: Duration,
|
||
|
|
connections: DashMap<String, CounterEntry>,
|
||
|
|
messages: DashMap<String, CounterEntry>,
|
||
|
|
auth_failures: DashMap<String, CounterEntry>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl RateLimiter {
|
||
|
|
/// Create a new rate limiter with the given configuration.
|
||
|
|
pub fn new(config: RateLimitConfig) -> Self {
|
||
|
|
let window = Duration::from_secs(config.window_secs);
|
||
|
|
Self {
|
||
|
|
config,
|
||
|
|
window,
|
||
|
|
connections: DashMap::new(),
|
||
|
|
messages: DashMap::new(),
|
||
|
|
auth_failures: DashMap::new(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Update the configuration at runtime.
|
||
|
|
pub fn update_config(&mut self, config: RateLimitConfig) {
|
||
|
|
self.window = Duration::from_secs(config.window_secs);
|
||
|
|
self.config = config;
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check and record a new connection from an IP.
|
||
|
|
/// Returns `true` if the connection should be allowed.
|
||
|
|
pub fn check_connection(&self, ip: &str) -> bool {
|
||
|
|
self.increment_and_check(
|
||
|
|
&self.connections,
|
||
|
|
ip,
|
||
|
|
self.config.max_connections_per_ip,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check and record a message from a sender.
|
||
|
|
/// Returns `true` if the message should be allowed.
|
||
|
|
pub fn check_message(&self, sender: &str) -> bool {
|
||
|
|
self.increment_and_check(
|
||
|
|
&self.messages,
|
||
|
|
sender,
|
||
|
|
self.config.max_messages_per_sender,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check and record an auth failure from an IP.
|
||
|
|
/// Returns `true` if more attempts should be allowed.
|
||
|
|
pub fn check_auth_failure(&self, ip: &str) -> bool {
|
||
|
|
self.increment_and_check(
|
||
|
|
&self.auth_failures,
|
||
|
|
ip,
|
||
|
|
self.config.max_auth_failures_per_ip,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Increment a counter and check against the limit.
|
||
|
|
/// Returns `true` if within limits.
|
||
|
|
fn increment_and_check(
|
||
|
|
&self,
|
||
|
|
map: &DashMap<String, CounterEntry>,
|
||
|
|
key: &str,
|
||
|
|
limit: u32,
|
||
|
|
) -> bool {
|
||
|
|
let now = Instant::now();
|
||
|
|
let mut entry = map
|
||
|
|
.entry(key.to_string())
|
||
|
|
.or_insert_with(|| CounterEntry {
|
||
|
|
count: 0,
|
||
|
|
window_start: now,
|
||
|
|
});
|
||
|
|
|
||
|
|
// Reset window if expired
|
||
|
|
if now.duration_since(entry.window_start) > self.window {
|
||
|
|
entry.count = 0;
|
||
|
|
entry.window_start = now;
|
||
|
|
}
|
||
|
|
|
||
|
|
entry.count += 1;
|
||
|
|
entry.count <= limit
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Clean up expired entries. Call periodically.
|
||
|
|
pub fn cleanup(&self) {
|
||
|
|
let now = Instant::now();
|
||
|
|
let window = self.window;
|
||
|
|
self.connections
|
||
|
|
.retain(|_, v| now.duration_since(v.window_start) <= window);
|
||
|
|
self.messages
|
||
|
|
.retain(|_, v| now.duration_since(v.window_start) <= window);
|
||
|
|
self.auth_failures
|
||
|
|
.retain(|_, v| now.duration_since(v.window_start) <= window);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Pause long enough that a zero-length window is guaranteed to have
    /// expired by the time the next `Instant::now()` is taken.
    fn let_zero_window_expire() {
        std::thread::sleep(std::time::Duration::from_millis(5));
    }

    #[test]
    fn test_connection_limit() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_connections_per_ip: 3,
            window_secs: 60,
            ..Default::default()
        });

        assert!(limiter.check_connection("1.2.3.4"));
        assert!(limiter.check_connection("1.2.3.4"));
        assert!(limiter.check_connection("1.2.3.4"));
        assert!(!limiter.check_connection("1.2.3.4")); // 4th = over limit

        // Different IP is independent
        assert!(limiter.check_connection("5.6.7.8"));
    }

    #[test]
    fn test_message_limit() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_messages_per_sender: 2,
            window_secs: 60,
            ..Default::default()
        });

        assert!(limiter.check_message("sender@example.com"));
        assert!(limiter.check_message("sender@example.com"));
        assert!(!limiter.check_message("sender@example.com"));
    }

    #[test]
    fn test_auth_failure_limit() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_auth_failures_per_ip: 2,
            window_secs: 60,
            ..Default::default()
        });

        assert!(limiter.check_auth_failure("1.2.3.4"));
        assert!(limiter.check_auth_failure("1.2.3.4"));
        assert!(!limiter.check_auth_failure("1.2.3.4"));
    }

    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_connections_per_ip: 1,
            window_secs: 60,
            ..Default::default()
        });

        limiter.check_connection("1.2.3.4");
        assert_eq!(limiter.connections.len(), 1);

        limiter.cleanup(); // entries not expired
        assert_eq!(limiter.connections.len(), 1);
    }

    #[test]
    fn test_window_resets_after_expiry() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_connections_per_ip: 1,
            window_secs: 0,
            ..Default::default()
        });

        assert!(limiter.check_connection("1.2.3.4"));
        let_zero_window_expire();

        // The previous window has expired, so the counter starts over.
        assert!(limiter.check_connection("1.2.3.4"));
    }

    #[test]
    fn test_cleanup_removes_expired_entries() {
        let limiter = RateLimiter::new(RateLimitConfig {
            window_secs: 0,
            ..Default::default()
        });

        limiter.check_connection("1.2.3.4");
        limiter.check_message("sender@example.com");
        limiter.check_auth_failure("1.2.3.4");
        assert_eq!(limiter.connections.len(), 1);
        assert_eq!(limiter.messages.len(), 1);
        assert_eq!(limiter.auth_failures.len(), 1);

        let_zero_window_expire();
        limiter.cleanup();

        assert_eq!(limiter.connections.len(), 0);
        assert_eq!(limiter.messages.len(), 0);
        assert_eq!(limiter.auth_failures.len(), 0);
    }
}
|