Files
smartvpn/rust/src/qos.rs

491 lines
15 KiB
Rust

use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
/// Scheduling class assigned to an IP packet.
///
/// Discriminants are explicit and ascending, so under the derived `Ord` the ordering is
/// `High < Normal < Low` (most urgent first).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Priority {
    /// Most urgent class (e.g. ICMP, small packets, DNS/SSH).
    High = 0,
    /// Default class for traffic that matches no other rule.
    Normal = 1,
    /// Least urgent class, used for flows detected as bulk transfers.
    Low = 2,
}
/// Per-priority counters of packets enqueued into, or dropped from, the priority channels.
///
/// Every field is an independent atomic counter that starts at zero, so the struct can be
/// shared (e.g. behind an `Arc`) and updated without a lock.
#[derive(Default)]
pub struct QosStats {
    /// Packets accepted into the high-priority channel.
    pub high_enqueued: AtomicU64,
    /// Packets accepted into the normal-priority channel.
    pub normal_enqueued: AtomicU64,
    /// Packets accepted into the low-priority channel.
    pub low_enqueued: AtomicU64,
    /// High-priority packets dropped under backpressure.
    pub high_dropped: AtomicU64,
    /// Normal-priority packets dropped under backpressure.
    pub normal_dropped: AtomicU64,
    /// Low-priority packets dropped under backpressure.
    pub low_dropped: AtomicU64,
}

impl QosStats {
    /// Creates a stats block with every counter at zero (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }
}
// ============================================================================
// Packet classification
// ============================================================================
/// 5-tuple flow key for tracking bulk flows.
///
/// Addresses are the IPv4 source/destination fields read big-endian. Ports are `0` whenever
/// the packet does not carry a readable TCP/UDP header (e.g. non-initial fragments).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct FlowKey {
    src_ip: u32,
    dst_ip: u32,
    src_port: u16,
    dst_port: u16,
    protocol: u8,
}

/// Per-flow state for bulk detection.
struct FlowState {
    /// Bytes recorded for the flow since `window_start` (saturating, never wraps).
    bytes_total: u64,
    /// Start of the flow's current fixed accounting window.
    window_start: Instant,
}

/// Tracks per-flow byte counts over fixed windows for bulk flow detection.
///
/// The table holds at most `max_flows` entries (one, if `max_flows` is zero). When a new flow
/// arrives at capacity, expired flows are purged first; only if that frees nothing is the
/// flow with the oldest window start evicted.
struct FlowTracker {
    flows: HashMap<FlowKey, FlowState>,
    window_duration: Duration,
    max_flows: usize,
}

impl FlowTracker {
    /// Creates an empty tracker with the given window length and capacity.
    fn new(window_duration: Duration, max_flows: usize) -> Self {
        Self {
            flows: HashMap::new(),
            window_duration,
            max_flows,
        }
    }

    /// Record `bytes` for the flow `key`.
    ///
    /// Returns `true` if the flow has seen strictly more than `threshold` bytes within its
    /// current window. A window older than `window_duration` is restarted (count reset to 0)
    /// before the new bytes are added.
    fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
        let now = Instant::now();
        if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
            self.make_room(now);
        }

        let window = self.window_duration;
        let state = self.flows.entry(key).or_insert(FlowState {
            bytes_total: 0,
            window_start: now,
        });
        if now.duration_since(state.window_start) > window {
            *state = FlowState {
                bytes_total: 0,
                window_start: now,
            };
        }
        // Saturate rather than overflow: a wrapped counter would panic in debug builds and
        // silently reclassify a bulk flow as small in release builds.
        state.bytes_total = state.bytes_total.saturating_add(bytes);
        state.bytes_total > threshold
    }

    /// Frees space for one new flow when the table is at capacity.
    fn make_room(&mut self, now: Instant) {
        let window = self.window_duration;
        // An expired flow would be reset to zero on its next packet anyway, so dropping it
        // loses nothing. Purging all of them in one pass frees many slots at once, instead of
        // paying an O(n) eviction scan for every new flow.
        self.flows
            .retain(|_, state| now.duration_since(state.window_start) <= window);
        if self.flows.len() < self.max_flows {
            return;
        }
        let oldest = self
            .flows
            .iter()
            .min_by_key(|(_, state)| state.window_start)
            .map(|(key, _)| *key);
        if let Some(key) = oldest {
            self.flows.remove(&key);
        }
    }
}

/// Classifies raw IP packets into priority levels.
pub struct PacketClassifier {
    flow_tracker: FlowTracker,
    /// Byte threshold for classifying a flow as "bulk" (Low priority).
    bulk_threshold_bytes: u64,
}

impl PacketClassifier {
    /// IP protocol number for ICMP.
    const PROTO_ICMP: u8 = 1;
    /// IP protocol number for TCP.
    const PROTO_TCP: u8 = 6;
    /// IP protocol number for UDP.
    const PROTO_UDP: u8 = 17;
    /// Length of an IPv4 header without options (IHL = 5).
    const MIN_IPV4_HEADER_LEN: usize = 20;
    /// Packets whose header `total_len` is below this are treated as interactive.
    const SMALL_PACKET_LEN: u16 = 128;
    /// Selects the 13-bit fragment-offset field from the IPv4 flags/offset word.
    const FRAGMENT_OFFSET_MASK: u16 = 0x1FFF;
    /// Window used by [`PacketClassifier::new`] for per-flow byte accounting.
    const DEFAULT_FLOW_WINDOW: Duration = Duration::from_secs(60);
    /// Flow-table capacity used by [`PacketClassifier::new`].
    const DEFAULT_MAX_FLOWS: usize = 10_000;

    /// Create a new classifier with a 60 s flow window and room for 10 000 tracked flows.
    ///
    /// - `bulk_threshold_bytes`: bytes per flow within the window that trigger Low priority
    ///   (e.g. `1_000_000` for 1 MB).
    pub fn new(bulk_threshold_bytes: u64) -> Self {
        Self::with_flow_limits(
            bulk_threshold_bytes,
            Self::DEFAULT_FLOW_WINDOW,
            Self::DEFAULT_MAX_FLOWS,
        )
    }

    /// Create a classifier with explicit flow-tracking limits.
    ///
    /// - `bulk_threshold_bytes`: bytes per flow within `window` that trigger Low priority.
    /// - `window`: length of the fixed accounting window after which a flow's count resets.
    /// - `max_flows`: maximum number of flows tracked at once before eviction.
    pub fn with_flow_limits(bulk_threshold_bytes: u64, window: Duration, max_flows: usize) -> Self {
        Self {
            flow_tracker: FlowTracker::new(window, max_flows),
            bulk_threshold_bytes,
        }
    }

    /// Classify a raw IPv4 packet into a priority level.
    ///
    /// The packet must start with the IPv4 header (as read from a TUN device). Rules, in order:
    ///
    /// 1. Shorter than 20 bytes, not IPv4, or IHL below 5 (malformed) → `Normal`.
    /// 2. ICMP → `High`.
    /// 3. Header `total_len` below 128 bytes (likely interactive) → `High`.
    /// 4. Source or destination port 53 (DNS) or 22 (SSH) → `High`.
    /// 5. TCP/UDP flow whose bytes within the window exceed the bulk threshold → `Low`.
    /// 6. Everything else → `Normal`.
    ///
    /// Ports are read only when the packet actually carries the TCP/UDP header: the protocol
    /// is TCP or UDP, the fragment offset is zero, and four bytes follow the IP header.
    /// Later fragments are accounted under the port-less key `(src, dst, 0, 0, proto)`.
    pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
        if ip_packet.len() < Self::MIN_IPV4_HEADER_LEN || ip_packet[0] >> 4 != 4 {
            // Too short for an IPv4 header, or not IPv4 (only IPv4 is classified for now).
            return Priority::Normal;
        }
        let header_len = usize::from(ip_packet[0] & 0x0F) * 4;
        if header_len < Self::MIN_IPV4_HEADER_LEN {
            // IHL < 5 is malformed; trusting it would read "ports" out of the IP header itself.
            return Priority::Normal;
        }
        let protocol = ip_packet[9];
        let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]);

        if protocol == Self::PROTO_ICMP || total_len < Self::SMALL_PACKET_LEN {
            return Priority::High;
        }

        let is_tcp_or_udp = protocol == Self::PROTO_TCP || protocol == Self::PROTO_UDP;
        let fragment_offset =
            u16::from_be_bytes([ip_packet[6], ip_packet[7]]) & Self::FRAGMENT_OFFSET_MASK;
        // Only the first fragment (offset 0) starts with the transport header; in later
        // fragments these bytes are arbitrary payload.
        let transport_header = if is_tcp_or_udp && fragment_offset == 0 {
            ip_packet.get(header_len..header_len + 4)
        } else {
            None
        };
        let (src_port, dst_port) = transport_header.map_or((0, 0), |ports| {
            (
                u16::from_be_bytes([ports[0], ports[1]]),
                u16::from_be_bytes([ports[2], ports[3]]),
            )
        });

        // DNS (53) and SSH (22) are high priority in either direction.
        if matches!(src_port, 22 | 53) || matches!(dst_port, 22 | 53) {
            return Priority::High;
        }

        if is_tcp_or_udp {
            let key = FlowKey {
                src_ip: u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]),
                dst_ip: u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]),
                src_port,
                dst_port,
                protocol,
            };
            if self
                .flow_tracker
                .record(key, u64::from(total_len), self.bulk_threshold_bytes)
            {
                return Priority::Low;
            }
        }
        Priority::Normal
    }
}
// ============================================================================
// Priority channel set
// ============================================================================
/// Error returned when a packet could not be enqueued on a priority channel.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be boxed or
/// propagated with `?` alongside other errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketDropped {
    /// The low-priority channel was full; the packet was dropped without waiting.
    LowPriorityDrop,
    /// The normal-priority channel was full; the packet was dropped without waiting.
    NormalPriorityDrop,
    /// The high-priority channel could not accept the packet within its brief backpressure wait.
    HighPriorityDrop,
    /// The receiving side of the channel is gone.
    ChannelClosed,
}

impl std::fmt::Display for PacketDropped {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            PacketDropped::LowPriorityDrop => "low-priority packet dropped",
            PacketDropped::NormalPriorityDrop => "normal-priority packet dropped",
            PacketDropped::HighPriorityDrop => "high-priority packet dropped",
            PacketDropped::ChannelClosed => "priority channel closed",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for PacketDropped {}
/// Sending half of the priority channel set.
///
/// Holds one bounded `tokio::sync::mpsc` sender per [`Priority`] plus the shared
/// [`QosStats`] counters that every send updates.
pub struct PrioritySender {
    high_tx: mpsc::Sender<Vec<u8>>,
    normal_tx: mpsc::Sender<Vec<u8>>,
    low_tx: mpsc::Sender<Vec<u8>>,
    stats: Arc<QosStats>,
}

impl PrioritySender {
    /// How long a `High` packet may wait for space in a full channel before it is dropped.
    const HIGH_PRIORITY_GRACE: Duration = Duration::from_millis(5);

    /// Send a packet with the given priority, dropping it under backpressure.
    ///
    /// The packet is first offered without waiting. If its channel is full, `Low` and
    /// `Normal` packets are dropped immediately. `High` packets wait up to
    /// `HIGH_PRIORITY_GRACE` (5 ms) for space; this is the only case in which the returned
    /// future suspends. Successful sends bump the matching `*_enqueued` counter; drops bump
    /// the matching `*_dropped` counter.
    ///
    /// # Errors
    ///
    /// - [`PacketDropped::LowPriorityDrop`] / [`PacketDropped::NormalPriorityDrop`] when the
    ///   respective channel is full.
    /// - [`PacketDropped::HighPriorityDrop`] when the high channel stays full for the whole
    ///   grace period.
    /// - [`PacketDropped::ChannelClosed`] when the receiving half is gone, including when it
    ///   goes away while a `High` packet is waiting for space. No counter is updated.
    pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
        let (tx, enqueued_counter) = match priority {
            Priority::High => (&self.high_tx, &self.stats.high_enqueued),
            Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
            Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
        };
        match tx.try_send(packet) {
            Ok(()) => {
                enqueued_counter.fetch_add(1, Ordering::Relaxed);
                Ok(())
            }
            Err(mpsc::error::TrySendError::Full(packet)) => {
                self.handle_backpressure(packet, priority).await
            }
            Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
        }
    }

    /// Apply the per-priority drop policy to a packet whose channel was full.
    async fn handle_backpressure(
        &self,
        packet: Vec<u8>,
        priority: Priority,
    ) -> Result<(), PacketDropped> {
        match priority {
            Priority::Low => {
                self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
                Err(PacketDropped::LowPriorityDrop)
            }
            Priority::Normal => {
                self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
                Err(PacketDropped::NormalPriorityDrop)
            }
            Priority::High => {
                // Last resort: briefly wait for space, then drop.
                match tokio::time::timeout(Self::HIGH_PRIORITY_GRACE, self.high_tx.send(packet))
                    .await
                {
                    Ok(Ok(())) => {
                        self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
                        Ok(())
                    }
                    // The receiver closed while we waited. Report that, consistent with the
                    // `try_send` path, rather than disguising it as a capacity drop.
                    Ok(Err(_closed)) => Err(PacketDropped::ChannelClosed),
                    Err(_elapsed) => {
                        self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
                        Err(PacketDropped::HighPriorityDrop)
                    }
                }
            }
        }
    }
}
/// Receiving half of the priority channel set.
///
/// Owns one receiver per priority; [`PriorityReceiver::recv`] defines the drain order.
pub struct PriorityReceiver {
    high_rx: mpsc::Receiver<Vec<u8>>,
    normal_rx: mpsc::Receiver<Vec<u8>>,
    low_rx: mpsc::Receiver<Vec<u8>>,
}

impl PriorityReceiver {
    /// Wait for the next packet, always preferring higher priorities.
    ///
    /// The channels are polled in the fixed order High → Normal → Low (`biased`), so a ready
    /// High packet always wins over lower ones. A channel that yields `None` (closed and
    /// empty) drops out of the race. `None` is returned only once all three have done so.
    pub async fn recv(&mut self) -> Option<Vec<u8>> {
        let Self {
            high_rx,
            normal_rx,
            low_rx,
        } = self;
        tokio::select! {
            biased;
            Some(packet) = high_rx.recv() => Some(packet),
            Some(packet) = normal_rx.recv() => Some(packet),
            Some(packet) = low_rx.recv() => Some(packet),
            else => None,
        }
    }
}
/// Build the three bounded priority channels and split them into sender and receiver halves.
///
/// The sender owns a freshly zeroed [`QosStats`].
///
/// - `high_cap`: capacity of the high-priority channel
/// - `normal_cap`: capacity of the normal-priority channel
/// - `low_cap`: capacity of the low-priority channel
///
/// # Panics
///
/// Panics if any capacity is zero, because `tokio::sync::mpsc::channel` rejects a zero buffer.
pub fn create_priority_channels(
    high_cap: usize,
    normal_cap: usize,
    low_cap: usize,
) -> (PrioritySender, PriorityReceiver) {
    let (high_tx, high_rx) = mpsc::channel(high_cap);
    let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
    let (low_tx, low_rx) = mpsc::channel(low_cap);
    (
        PrioritySender {
            high_tx,
            normal_tx,
            low_tx,
            stats: Arc::new(QosStats::new()),
        },
        PriorityReceiver {
            high_rx,
            normal_rx,
            low_rx,
        },
    )
}
impl PrioritySender {
    /// Return the shared QoS counters that this sender updates on every send.
    ///
    /// Clone the returned `Arc` to keep reading the counters independently of the sender.
    pub fn stats(&self) -> &Arc<QosStats> {
        &self.stats
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal IPv4 packet (IHL 5, 10.0.0.1 -> 10.0.0.2) with the given protocol,
    /// L4 ports at offset 20, and header `total_len`. The buffer is at least 24 bytes long.
    fn ipv4(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
        let mut buf = vec![0u8; total_len.max(24) as usize];
        buf[0] = 0x45; // version 4, IHL 5
        buf[2..4].copy_from_slice(&total_len.to_be_bytes());
        buf[9] = protocol;
        buf[12..16].copy_from_slice(&[10, 0, 0, 1]);
        buf[16..20].copy_from_slice(&[10, 0, 0, 2]);
        buf[20..22].copy_from_slice(&src_port.to_be_bytes());
        buf[22..24].copy_from_slice(&dst_port.to_be_bytes());
        buf
    }

    /// Classifies a single packet with a brand-new classifier.
    fn classify_once(bulk_threshold: u64, packet: &[u8]) -> Priority {
        PacketClassifier::new(bulk_threshold).classify(packet)
    }

    #[test]
    fn classify_icmp_as_high() {
        assert_eq!(classify_once(1_000_000, &ipv4(1, 0, 0, 64)), Priority::High);
    }

    #[test]
    fn classify_dns_as_high() {
        assert_eq!(classify_once(1_000_000, &ipv4(17, 12345, 53, 200)), Priority::High);
    }

    #[test]
    fn classify_ssh_as_high() {
        assert_eq!(classify_once(1_000_000, &ipv4(6, 54321, 22, 200)), Priority::High);
    }

    #[test]
    fn classify_small_packet_as_high() {
        assert_eq!(classify_once(1_000_000, &ipv4(6, 12345, 8080, 64)), Priority::High);
    }

    #[test]
    fn classify_normal_http() {
        assert_eq!(classify_once(1_000_000, &ipv4(6, 12345, 80, 500)), Priority::Normal);
    }

    #[test]
    fn classify_bulk_flow_as_low() {
        // Low threshold so 100 x 500-byte packets comfortably exceed it.
        let mut classifier = PacketClassifier::new(10_000);
        let packet = ipv4(6, 12345, 80, 500);
        for _ in 0..100 {
            let _ = classifier.classify(&packet);
        }
        assert_eq!(classifier.classify(&packet), Priority::Low);
    }

    #[test]
    fn classify_too_short_packet() {
        assert_eq!(classify_once(1_000_000, &[0u8; 10]), Priority::Normal);
    }

    #[test]
    fn classify_non_ipv4() {
        let mut packet = [0u8; 40];
        packet[0] = 0x60; // IPv6 version nibble
        assert_eq!(classify_once(1_000_000, &packet), Priority::Normal);
    }

    #[tokio::test]
    async fn priority_receiver_drains_high_first() {
        let (tx, mut rx) = create_priority_channels(8, 8, 8);
        // Enqueue lowest priority first...
        for (payload, priority) in [(3u8, Priority::Low), (2, Priority::Normal), (1, Priority::High)] {
            tx.send(vec![payload], priority).await.unwrap();
        }
        // ...and expect them back highest priority first.
        for expected in 1u8..=3 {
            assert_eq!(rx.recv().await.unwrap(), vec![expected]);
        }
    }

    #[tokio::test]
    async fn smart_dropping_low_priority() {
        // `_rx` (not `_`) keeps the receiver alive so the channel is full, not closed.
        let (tx, _rx) = create_priority_channels(8, 8, 1);
        tx.send(vec![0], Priority::Low).await.unwrap();
        let outcome = tx.send(vec![1], Priority::Low).await;
        assert!(matches!(outcome, Err(PacketDropped::LowPriorityDrop)));
        assert_eq!(tx.stats().low_dropped.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn smart_dropping_normal_priority() {
        let (tx, _rx) = create_priority_channels(8, 1, 8);
        tx.send(vec![0], Priority::Normal).await.unwrap();
        let outcome = tx.send(vec![1], Priority::Normal).await;
        assert!(matches!(outcome, Err(PacketDropped::NormalPriorityDrop)));
        assert_eq!(tx.stats().normal_dropped.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn stats_track_enqueued() {
        let (tx, _rx) = create_priority_channels(8, 8, 8);
        let sends = [
            (1u8, Priority::High),
            (2, Priority::High),
            (3, Priority::Normal),
            (4, Priority::Low),
        ];
        for (payload, priority) in sends {
            tx.send(vec![payload], priority).await.unwrap();
        }
        let stats = tx.stats();
        assert_eq!(stats.high_enqueued.load(Ordering::Relaxed), 2);
        assert_eq!(stats.normal_enqueued.load(Ordering::Relaxed), 1);
        assert_eq!(stats.low_enqueued.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn flow_tracker_evicts_at_capacity() {
        let tcp_flow = |src_ip: u32, dst_ip: u32, src_port: u16, dst_port: u16| FlowKey {
            src_ip,
            dst_ip,
            src_port,
            dst_port,
            protocol: 6,
        };
        let first = tcp_flow(1, 2, 100, 200);
        let second = tcp_flow(3, 4, 300, 400);
        let third = tcp_flow(5, 6, 500, 600);

        let mut tracker = FlowTracker::new(Duration::from_secs(60), 2);
        tracker.record(first, 100, 1000);
        tracker.record(second, 100, 1000);
        // A third flow at capacity must push out the oldest one (`first`).
        tracker.record(third, 100, 1000);

        assert_eq!(tracker.flows.len(), 2);
        assert!(!tracker.flows.contains_key(&first));
    }
}