142 lines
4.4 KiB
Rust
142 lines
4.4 KiB
Rust
use std::time::Instant;
|
|
|
|
/// A token bucket rate limiter operating on byte granularity.
///
/// The bucket holds at most `burst` tokens (one token = one byte) and refills
/// continuously at `rate` tokens per second. Refilling is lazy: it is computed
/// from the monotonic [`Instant`] time elapsed since the previous refill, each
/// time [`TokenBucket::try_consume`] or [`TokenBucket::update_limits`] runs.
// `Clone` is intentionally not derived: a cloned bucket would carry the same
// allowance as the original, silently multiplying the permitted throughput.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens (bytes) added per second.
    rate: f64,
    /// Maximum burst capacity in bytes.
    burst: f64,
    /// Currently available tokens; kept within `0.0..=burst`.
    tokens: f64,
    /// Last time tokens were refilled.
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a new token bucket.
    ///
    /// - `rate_bytes_per_sec`: sustained rate in bytes/second
    /// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
    ///
    /// A `rate_bytes_per_sec` of zero means the bucket never refills: once the
    /// initial burst is spent, every non-empty request is rejected.
    pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
        let burst = burst_bytes as f64;
        Self {
            rate: rate_bytes_per_sec as f64,
            burst,
            tokens: burst, // start full
            last_refill: Instant::now(),
        }
    }

    /// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded.
    ///
    /// A rejected request consumes nothing. A request larger than the burst
    /// size can never succeed, because the bucket never holds more than
    /// `burst_bytes` tokens.
    #[must_use]
    pub fn try_consume(&mut self, bytes: usize) -> bool {
        self.refill();
        let needed = bytes as f64;
        if needed <= self.tokens {
            self.tokens -= needed;
            true
        } else {
            false
        }
    }

    /// Update rate and burst limits dynamically (for live IPC reconfiguration).
    ///
    /// Time that elapsed before this call is credited at the *old* rate; the
    /// new rate applies only from this call onward. The current token count is
    /// capped at the new burst size. Raising the burst does not add tokens.
    pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
        // Settle the pending interval under the old configuration first.
        // Otherwise the next refill would apply the new rate retroactively to
        // time that passed before the limits changed.
        self.refill();
        self.rate = rate_bytes_per_sec as f64;
        self.burst = burst_bytes as f64;
        // Cap current tokens at new burst
        self.tokens = self.tokens.min(self.burst);
    }

    /// Credit the tokens earned since `last_refill` (capped at `burst`) and
    /// advance `last_refill` to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;
        self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // Timing note: an assertion that "the next request is rejected" only holds
    // if the refill during the gap between the two calls stays below the
    // requested amount. At 1 MB/s a single byte refills in 1 µs, which a
    // preempted test thread can easily exceed. Such assertions therefore use
    // either a very low rate or a request size that needs milliseconds to refill.

    #[test]
    fn allows_traffic_under_burst() {
        let mut tb = TokenBucket::new(1_000_000, 2_000_000);
        // Should allow up to burst size immediately
        assert!(tb.try_consume(1_500_000));
        assert!(tb.try_consume(400_000));
    }

    #[test]
    fn blocks_traffic_over_burst() {
        // 1 byte/s: refill between the two calls below is negligible.
        let mut tb = TokenBucket::new(1, 1_000_000);
        // Consume entire burst
        assert!(tb.try_consume(1_000_000));
        // Next consume should fail (no time to refill)
        assert!(!tb.try_consume(1));
    }

    #[test]
    fn oversized_request_is_rejected_without_consuming() {
        let mut tb = TokenBucket::new(1, 100);
        // More than the burst can never fit, even in a full bucket...
        assert!(!tb.try_consume(101));
        // ...and the rejection must not have deducted anything.
        assert!(tb.try_consume(100));
    }

    #[test]
    fn zero_consume_always_succeeds() {
        let mut tb = TokenBucket::new(0, 0);
        assert!(tb.try_consume(0));
    }

    #[test]
    fn refills_over_time() {
        let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
        // Drain completely
        assert!(tb.try_consume(1_000_000));
        // 10KB needs ~10ms to refill at 1MB/s, so it must still be rejected now.
        assert!(!tb.try_consume(10_000));

        // Wait 100ms — should refill ~100KB
        std::thread::sleep(Duration::from_millis(100));
        assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s
    }

    #[test]
    fn update_limits_caps_tokens() {
        let mut tb = TokenBucket::new(1_000_000, 2_000_000);
        // Tokens start at burst (2MB). Drop the rate to 1 B/s so refill between
        // the calls below is negligible.
        tb.update_limits(1, 500_000);
        // Tokens should be capped to new burst (500KB)
        assert!(tb.try_consume(500_000));
        assert!(!tb.try_consume(1));
    }

    #[test]
    fn update_limits_changes_rate() {
        let mut tb = TokenBucket::new(1_000_000, 1_000_000);
        assert!(tb.try_consume(1_000_000)); // drain

        // Change to higher rate
        tb.update_limits(10_000_000, 10_000_000);
        std::thread::sleep(Duration::from_millis(50));
        // At 10MB/s, 50ms should refill ~500KB. 200KB is well above the ~50KB
        // that the old 1MB/s rate would have provided in the same time.
        assert!(tb.try_consume(200_000));
    }

    #[test]
    fn zero_rate_blocks_after_burst() {
        let mut tb = TokenBucket::new(0, 100);
        assert!(tb.try_consume(100));
        std::thread::sleep(Duration::from_millis(10));
        // Zero rate means no refill
        assert!(!tb.try_consume(1));
    }

    #[test]
    fn tokens_do_not_exceed_burst() {
        // Use a low rate so refill between consecutive calls is negligible
        let mut tb = TokenBucket::new(100, 1_000);
        // Wait to accumulate — but should cap at burst
        std::thread::sleep(Duration::from_millis(50));
        assert!(tb.try_consume(1_000));
        // At 100 bytes/sec, the few μs between calls add ~0 tokens
        assert!(!tb.try_consume(1));
    }
}
|