491 lines
15 KiB
Rust
491 lines
15 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
use tokio::sync::mpsc;
|
|
|
|
/// Scheduling class assigned to an IP packet.
///
/// The explicit discriminants together with the derived `Ord` give
/// `High < Normal < Low`: a smaller value means a more urgent packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Priority {
    /// Most urgent class (discriminant `0`).
    High = 0,
    /// Default class (discriminant `1`).
    Normal = 1,
    /// Least urgent class (discriminant `2`).
    Low = 2,
}
|
|
|
|
/// QoS statistics per priority level.
///
/// Every counter is an [`AtomicU64`], so the struct can be updated through a
/// shared reference (for example behind an `Arc`).
#[derive(Debug)]
pub struct QosStats {
    /// Number of high-priority packets accepted into their queue.
    pub high_enqueued: AtomicU64,
    /// Number of normal-priority packets accepted into their queue.
    pub normal_enqueued: AtomicU64,
    /// Number of low-priority packets accepted into their queue.
    pub low_enqueued: AtomicU64,
    /// Number of high-priority packets discarded.
    pub high_dropped: AtomicU64,
    /// Number of normal-priority packets discarded.
    pub normal_dropped: AtomicU64,
    /// Number of low-priority packets discarded.
    pub low_dropped: AtomicU64,
}
|
|
|
|
impl QosStats {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
high_enqueued: AtomicU64::new(0),
|
|
normal_enqueued: AtomicU64::new(0),
|
|
low_enqueued: AtomicU64::new(0),
|
|
high_dropped: AtomicU64::new(0),
|
|
normal_dropped: AtomicU64::new(0),
|
|
low_dropped: AtomicU64::new(0),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for QosStats {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Packet classification
|
|
// ============================================================================
|
|
|
|
/// Identifies one flow by its 5-tuple; used as the map key for bulk-flow tracking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct FlowKey {
    /// Source IPv4 address as a `u32`.
    src_ip: u32,
    /// Destination IPv4 address as a `u32`.
    dst_ip: u32,
    /// Source transport port.
    src_port: u16,
    /// Destination transport port.
    dst_port: u16,
    /// IP protocol number (for example 6 = TCP, 17 = UDP).
    protocol: u8,
}
|
|
|
|
/// Byte accounting for a single flow within its current measurement window.
struct FlowState {
    /// Bytes attributed to the flow since `window_start`.
    bytes_total: u64,
    /// Instant at which the current window began.
    window_start: Instant,
}
|
|
|
|
/// Tracks per-flow byte counts for bulk flow detection.
|
|
struct FlowTracker {
|
|
flows: HashMap<FlowKey, FlowState>,
|
|
window_duration: Duration,
|
|
max_flows: usize,
|
|
}
|
|
|
|
impl FlowTracker {
|
|
fn new(window_duration: Duration, max_flows: usize) -> Self {
|
|
Self {
|
|
flows: HashMap::new(),
|
|
window_duration,
|
|
max_flows,
|
|
}
|
|
}
|
|
|
|
/// Record bytes for a flow. Returns true if the flow exceeds the threshold.
|
|
fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
|
|
let now = Instant::now();
|
|
|
|
// Evict if at capacity — remove oldest entry
|
|
if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
|
|
if let Some(oldest_key) = self
|
|
.flows
|
|
.iter()
|
|
.min_by_key(|(_, v)| v.window_start)
|
|
.map(|(k, _)| *k)
|
|
{
|
|
self.flows.remove(&oldest_key);
|
|
}
|
|
}
|
|
|
|
let state = self.flows.entry(key).or_insert(FlowState {
|
|
bytes_total: 0,
|
|
window_start: now,
|
|
});
|
|
|
|
// Reset window if expired
|
|
if now.duration_since(state.window_start) > self.window_duration {
|
|
state.bytes_total = 0;
|
|
state.window_start = now;
|
|
}
|
|
|
|
state.bytes_total += bytes;
|
|
state.bytes_total > threshold
|
|
}
|
|
}
|
|
|
|
/// Classifies raw IP packets into priority levels.
|
|
pub struct PacketClassifier {
|
|
flow_tracker: FlowTracker,
|
|
/// Byte threshold for classifying a flow as "bulk" (Low priority).
|
|
bulk_threshold_bytes: u64,
|
|
}
|
|
|
|
impl PacketClassifier {
|
|
/// Create a new classifier.
|
|
///
|
|
/// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB)
|
|
pub fn new(bulk_threshold_bytes: u64) -> Self {
|
|
Self {
|
|
flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000),
|
|
bulk_threshold_bytes,
|
|
}
|
|
}
|
|
|
|
/// Classify a raw IPv4 packet into a priority level.
|
|
///
|
|
/// The packet must start with the IPv4 header (as read from a TUN device).
|
|
pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
|
|
// Need at least 20 bytes for a minimal IPv4 header
|
|
if ip_packet.len() < 20 {
|
|
return Priority::Normal;
|
|
}
|
|
|
|
let version = ip_packet[0] >> 4;
|
|
if version != 4 {
|
|
return Priority::Normal; // Only classify IPv4 for now
|
|
}
|
|
|
|
let ihl = (ip_packet[0] & 0x0F) as usize;
|
|
let header_len = ihl * 4;
|
|
let protocol = ip_packet[9];
|
|
let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize;
|
|
|
|
// ICMP is always high priority
|
|
if protocol == 1 {
|
|
return Priority::High;
|
|
}
|
|
|
|
// Small packets (<128 bytes) are high priority (likely interactive)
|
|
if total_len < 128 {
|
|
return Priority::High;
|
|
}
|
|
|
|
// Extract ports for TCP (6) and UDP (17)
|
|
let (src_port, dst_port) = if (protocol == 6 || protocol == 17)
|
|
&& ip_packet.len() >= header_len + 4
|
|
{
|
|
let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]);
|
|
let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]);
|
|
(sp, dp)
|
|
} else {
|
|
(0, 0)
|
|
};
|
|
|
|
// DNS (port 53) and SSH (port 22) are high priority
|
|
if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 {
|
|
return Priority::High;
|
|
}
|
|
|
|
// Check for bulk flow
|
|
if protocol == 6 || protocol == 17 {
|
|
let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]);
|
|
let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]);
|
|
|
|
let key = FlowKey {
|
|
src_ip,
|
|
dst_ip,
|
|
src_port,
|
|
dst_port,
|
|
protocol,
|
|
};
|
|
|
|
if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) {
|
|
return Priority::Low;
|
|
}
|
|
}
|
|
|
|
Priority::Normal
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Priority channel set
|
|
// ============================================================================
|
|
|
|
/// Error returned when a packet is dropped.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// propagated with `?` into boxed or dynamic error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketDropped {
    /// A low-priority packet was discarded.
    LowPriorityDrop,
    /// A normal-priority packet was discarded.
    NormalPriorityDrop,
    /// A high-priority packet was discarded.
    HighPriorityDrop,
    /// The channel the packet was destined for is closed.
    ChannelClosed,
}

impl std::fmt::Display for PacketDropped {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::LowPriorityDrop => "low-priority packet dropped",
            Self::NormalPriorityDrop => "normal-priority packet dropped",
            Self::HighPriorityDrop => "high-priority packet dropped",
            Self::ChannelClosed => "priority channel closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PacketDropped {}
|
|
|
|
/// Sending half of the priority channel set.
|
|
pub struct PrioritySender {
|
|
high_tx: mpsc::Sender<Vec<u8>>,
|
|
normal_tx: mpsc::Sender<Vec<u8>>,
|
|
low_tx: mpsc::Sender<Vec<u8>>,
|
|
stats: Arc<QosStats>,
|
|
}
|
|
|
|
impl PrioritySender {
|
|
/// Send a packet with the given priority. Implements smart dropping under backpressure.
|
|
pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
|
|
let (tx, enqueued_counter) = match priority {
|
|
Priority::High => (&self.high_tx, &self.stats.high_enqueued),
|
|
Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
|
|
Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
|
|
};
|
|
|
|
match tx.try_send(packet) {
|
|
Ok(()) => {
|
|
enqueued_counter.fetch_add(1, Ordering::Relaxed);
|
|
Ok(())
|
|
}
|
|
Err(mpsc::error::TrySendError::Full(packet)) => {
|
|
self.handle_backpressure(packet, priority).await
|
|
}
|
|
Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
|
|
}
|
|
}
|
|
|
|
async fn handle_backpressure(
|
|
&self,
|
|
packet: Vec<u8>,
|
|
priority: Priority,
|
|
) -> Result<(), PacketDropped> {
|
|
match priority {
|
|
Priority::Low => {
|
|
self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
|
|
Err(PacketDropped::LowPriorityDrop)
|
|
}
|
|
Priority::Normal => {
|
|
self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
|
|
Err(PacketDropped::NormalPriorityDrop)
|
|
}
|
|
Priority::High => {
|
|
// Last resort: briefly wait for space, then drop
|
|
match tokio::time::timeout(
|
|
Duration::from_millis(5),
|
|
self.high_tx.send(packet),
|
|
)
|
|
.await
|
|
{
|
|
Ok(Ok(())) => {
|
|
self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
|
|
Ok(())
|
|
}
|
|
_ => {
|
|
self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
|
|
Err(PacketDropped::HighPriorityDrop)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Receiving half of the priority channel set.
|
|
pub struct PriorityReceiver {
|
|
high_rx: mpsc::Receiver<Vec<u8>>,
|
|
normal_rx: mpsc::Receiver<Vec<u8>>,
|
|
low_rx: mpsc::Receiver<Vec<u8>>,
|
|
}
|
|
|
|
impl PriorityReceiver {
|
|
/// Receive the next packet, draining high-priority first (biased select).
|
|
pub async fn recv(&mut self) -> Option<Vec<u8>> {
|
|
tokio::select! {
|
|
biased;
|
|
Some(pkt) = self.high_rx.recv() => Some(pkt),
|
|
Some(pkt) = self.normal_rx.recv() => Some(pkt),
|
|
Some(pkt) = self.low_rx.recv() => Some(pkt),
|
|
else => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Create a priority channel set split into sender and receiver halves.
|
|
///
|
|
/// - `high_cap`: capacity of the high-priority channel
|
|
/// - `normal_cap`: capacity of the normal-priority channel
|
|
/// - `low_cap`: capacity of the low-priority channel
|
|
pub fn create_priority_channels(
|
|
high_cap: usize,
|
|
normal_cap: usize,
|
|
low_cap: usize,
|
|
) -> (PrioritySender, PriorityReceiver) {
|
|
let (high_tx, high_rx) = mpsc::channel(high_cap);
|
|
let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
|
|
let (low_tx, low_rx) = mpsc::channel(low_cap);
|
|
let stats = Arc::new(QosStats::new());
|
|
|
|
let sender = PrioritySender {
|
|
high_tx,
|
|
normal_tx,
|
|
low_tx,
|
|
stats,
|
|
};
|
|
|
|
let receiver = PriorityReceiver {
|
|
high_rx,
|
|
normal_rx,
|
|
low_rx,
|
|
};
|
|
|
|
(sender, receiver)
|
|
}
|
|
|
|
/// Get a reference to the QoS stats from a sender.
|
|
impl PrioritySender {
|
|
pub fn stats(&self) -> &Arc<QosStats> {
|
|
&self.stats
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a zero-filled IPv4 packet (version 4, IHL 5) from 10.0.0.1 to 10.0.0.2.
    ///
    /// The buffer is `total_len` bytes long but never shorter than 24, so the
    /// two port fields written at offset 20 always fit.
    fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
        let buf_len = total_len.max(24) as usize;
        let mut bytes = vec![0u8; buf_len];
        bytes[0] = 0x45; // version 4, IHL 5
        bytes[2..4].copy_from_slice(&total_len.to_be_bytes());
        bytes[9] = protocol;
        bytes[12..16].copy_from_slice(&[10, 0, 0, 1]); // source address
        bytes[16..20].copy_from_slice(&[10, 0, 0, 2]); // destination address
        // With IHL = 5 the transport header starts at offset 20.
        bytes[20..22].copy_from_slice(&src_port.to_be_bytes());
        bytes[22..24].copy_from_slice(&dst_port.to_be_bytes());
        bytes
    }

    #[test]
    fn classify_icmp_as_high() {
        let mut classifier = PacketClassifier::new(1_000_000);
        let icmp = make_ipv4_packet(1, 0, 0, 64);
        assert_eq!(classifier.classify(&icmp), Priority::High);
    }

    #[test]
    fn classify_dns_as_high() {
        let mut classifier = PacketClassifier::new(1_000_000);
        let udp_to_dns = make_ipv4_packet(17, 12345, 53, 200);
        assert_eq!(classifier.classify(&udp_to_dns), Priority::High);
    }

    #[test]
    fn classify_ssh_as_high() {
        let mut classifier = PacketClassifier::new(1_000_000);
        let tcp_to_ssh = make_ipv4_packet(6, 54321, 22, 200);
        assert_eq!(classifier.classify(&tcp_to_ssh), Priority::High);
    }

    #[test]
    fn classify_small_packet_as_high() {
        let mut classifier = PacketClassifier::new(1_000_000);
        let small_tcp = make_ipv4_packet(6, 12345, 8080, 64);
        assert_eq!(classifier.classify(&small_tcp), Priority::High);
    }

    #[test]
    fn classify_normal_http() {
        let mut classifier = PacketClassifier::new(1_000_000);
        // TCP to port 80, larger than the small-packet cutoff.
        let http = make_ipv4_packet(6, 12345, 80, 500);
        assert_eq!(classifier.classify(&http), Priority::Normal);
    }

    #[test]
    fn classify_bulk_flow_as_low() {
        // A low threshold keeps the test short.
        let mut classifier = PacketClassifier::new(10_000);
        let bulk = make_ipv4_packet(6, 12345, 80, 500);

        // 100 x 500 bytes pushes the flow well past the threshold.
        (0..100).for_each(|_| {
            classifier.classify(&bulk);
        });

        // The next packet of the same flow must be demoted.
        assert_eq!(classifier.classify(&bulk), Priority::Low);
    }

    #[test]
    fn classify_too_short_packet() {
        let mut classifier = PacketClassifier::new(1_000_000);
        let truncated = vec![0u8; 10]; // shorter than an IPv4 header
        assert_eq!(classifier.classify(&truncated), Priority::Normal);
    }

    #[test]
    fn classify_non_ipv4() {
        let mut classifier = PacketClassifier::new(1_000_000);
        let mut ipv6 = vec![0u8; 40];
        ipv6[0] = 0x60; // version nibble 6
        assert_eq!(classifier.classify(&ipv6), Priority::Normal);
    }

    #[tokio::test]
    async fn priority_receiver_drains_high_first() {
        let (tx, mut rx) = create_priority_channels(8, 8, 8);

        // Enqueue lowest priority first, highest last.
        tx.send(vec![3], Priority::Low).await.unwrap();
        tx.send(vec![2], Priority::Normal).await.unwrap();
        tx.send(vec![1], Priority::High).await.unwrap();

        // Packets must come back out in priority order.
        for expected in [1u8, 2, 3] {
            assert_eq!(rx.recv().await.unwrap(), vec![expected]);
        }
    }

    #[tokio::test]
    async fn smart_dropping_low_priority() {
        // Keep the receiver alive so the channel stays open.
        let (tx, _rx) = create_priority_channels(8, 8, 1);

        // The single low-priority slot is taken by the first packet.
        tx.send(vec![0], Priority::Low).await.unwrap();

        // The second one has nowhere to go and is dropped.
        let outcome = tx.send(vec![1], Priority::Low).await;
        assert!(matches!(outcome, Err(PacketDropped::LowPriorityDrop)));

        assert_eq!(tx.stats().low_dropped.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn smart_dropping_normal_priority() {
        // Keep the receiver alive so the channel stays open.
        let (tx, _rx) = create_priority_channels(8, 1, 8);

        tx.send(vec![0], Priority::Normal).await.unwrap();

        let outcome = tx.send(vec![1], Priority::Normal).await;
        assert!(matches!(outcome, Err(PacketDropped::NormalPriorityDrop)));

        assert_eq!(tx.stats().normal_dropped.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn stats_track_enqueued() {
        let (tx, _rx) = create_priority_channels(8, 8, 8);

        tx.send(vec![1], Priority::High).await.unwrap();
        tx.send(vec![2], Priority::High).await.unwrap();
        tx.send(vec![3], Priority::Normal).await.unwrap();
        tx.send(vec![4], Priority::Low).await.unwrap();

        let stats = tx.stats();
        assert_eq!(stats.high_enqueued.load(Ordering::Relaxed), 2);
        assert_eq!(stats.normal_enqueued.load(Ordering::Relaxed), 1);
        assert_eq!(stats.low_enqueued.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn flow_tracker_evicts_at_capacity() {
        let mut tracker = FlowTracker::new(Duration::from_secs(60), 2);

        let first = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 };
        let second = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 };
        let third = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 };

        tracker.record(first, 100, 1000);
        tracker.record(second, 100, 1000);
        // The tracker is full, so admitting `third` evicts `first` (the oldest).
        tracker.record(third, 100, 1000);

        assert_eq!(tracker.flows.len(), 2);
        assert!(!tracker.flows.contains_key(&first));
    }
}
|