Files
smartvpn/rust/src/ratelimit.rs

142 lines
4.4 KiB
Rust

use std::time::Instant;
/// A token bucket rate limiter operating on byte granularity.
///
/// Tokens (bytes) are credited lazily: on each refill the time elapsed since
/// `last_refill` is converted into tokens at `rate` per second, and the total is
/// capped at `burst`. Callers spend tokens via [`TokenBucket::try_consume`].
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens (bytes) added per second.
    rate: f64,
    /// Maximum burst capacity in bytes.
    burst: f64,
    /// Currently available tokens.
    tokens: f64,
    /// Last time tokens were refilled.
    last_refill: Instant,
}
impl TokenBucket {
    /// Create a new token bucket that starts full.
    ///
    /// - `rate_bytes_per_sec`: sustained refill rate in bytes/second
    /// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
    pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
        let burst = burst_bytes as f64;
        Self {
            rate: rate_bytes_per_sec as f64,
            burst,
            tokens: burst, // start full
            last_refill: Instant::now(),
        }
    }

    /// Try to consume `bytes` tokens.
    ///
    /// Refills first, then returns `true` and deducts the tokens if enough are
    /// available; otherwise returns `false` and leaves the token count unchanged.
    /// Because tokens are capped at the burst size, a single request larger than
    /// the burst can never succeed.
    pub fn try_consume(&mut self, bytes: usize) -> bool {
        self.refill();
        let needed = bytes as f64;
        if needed > self.tokens {
            return false;
        }
        self.tokens -= needed;
        true
    }

    /// Update rate and burst limits dynamically (for live IPC reconfiguration).
    ///
    /// Time elapsed since the last refill is credited under the *old* limits
    /// before the new ones take effect, so a rate change never applies
    /// retroactively. The current token count is then capped at the new burst.
    pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
        // Settle the interval accrued under the old rate first; otherwise the next
        // refill would credit that whole interval at the new rate.
        self.refill();
        self.rate = rate_bytes_per_sec as f64;
        self.burst = burst_bytes as f64;
        // Cap current tokens at the new burst.
        self.tokens = self.tokens.min(self.burst);
    }

    /// Credit tokens for the time elapsed since the last refill, capped at `burst`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;
        self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Pretend `elapsed` has passed since the bucket's last refill by moving
    /// `last_refill` into the past, so refill behaviour is tested
    /// deterministically instead of with real sleeps.
    fn advance(tb: &mut TokenBucket, elapsed: Duration) {
        tb.last_refill = tb
            .last_refill
            .checked_sub(elapsed)
            .expect("cannot move Instant that far into the past");
    }

    #[test]
    fn allows_traffic_under_burst() {
        let mut tb = TokenBucket::new(1_000_000, 2_000_000);
        // Should allow up to burst size immediately
        assert!(tb.try_consume(1_500_000));
        assert!(tb.try_consume(400_000));
    }

    #[test]
    fn blocks_traffic_over_burst() {
        let mut tb = TokenBucket::new(1_000_000, 1_000_000);
        // Consume entire burst
        assert!(tb.try_consume(1_000_000));
        // Refilling 1 KB at 1 MB/s takes 1 ms, far longer than the gap between calls
        assert!(!tb.try_consume(1_000));
    }

    #[test]
    fn rejects_single_request_larger_than_burst() {
        let mut tb = TokenBucket::new(1_000_000, 100);
        // Even after a long idle period tokens are capped at the burst size
        advance(&mut tb, Duration::from_secs(60));
        assert!(!tb.try_consume(101));
        assert!(tb.try_consume(100));
    }

    #[test]
    fn zero_consume_always_succeeds() {
        let mut tb = TokenBucket::new(0, 0);
        assert!(tb.try_consume(0));
    }

    #[test]
    fn refills_over_time() {
        let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
        // Drain completely
        assert!(tb.try_consume(1_000_000));
        assert!(!tb.try_consume(1_000));
        // 100ms at 1MB/s refills ~100KB
        advance(&mut tb, Duration::from_millis(100));
        assert!(tb.try_consume(90_000));
        // ...but not much more than that (~10KB remain)
        assert!(!tb.try_consume(50_000));
    }

    #[test]
    fn update_limits_caps_tokens() {
        let mut tb = TokenBucket::new(1_000_000, 2_000_000);
        // Tokens start at burst (2MB)
        tb.update_limits(500_000, 500_000);
        // Tokens should be capped to new burst (500KB); without the cap ~1.5MB would remain
        assert!(tb.try_consume(500_000));
        assert!(!tb.try_consume(1_000));
    }

    #[test]
    fn update_limits_changes_rate() {
        let mut tb = TokenBucket::new(1_000_000, 1_000_000);
        assert!(tb.try_consume(1_000_000)); // drain
        // Change to higher rate
        tb.update_limits(10_000_000, 10_000_000);
        advance(&mut tb, Duration::from_millis(50));
        // At 10MB/s, 50ms should refill ~500KB
        assert!(tb.try_consume(200_000));
    }

    #[test]
    fn zero_rate_blocks_after_burst() {
        let mut tb = TokenBucket::new(0, 100);
        assert!(tb.try_consume(100));
        advance(&mut tb, Duration::from_secs(60));
        // Zero rate means no refill, no matter how much time passes
        assert!(!tb.try_consume(1));
    }

    #[test]
    fn tokens_do_not_exceed_burst() {
        // Low rate so refill between consecutive calls is negligible
        // (1 byte takes 10 ms at 100 bytes/sec)
        let mut tb = TokenBucket::new(100, 1_000);
        // 60 s would accrue 6_000 bytes uncapped; must stay capped at the 1_000 burst
        advance(&mut tb, Duration::from_secs(60));
        assert!(tb.try_consume(1_000));
        assert!(!tb.try_consume(1));
    }
}