use std::collections::HashMap; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::mpsc; /// Priority levels for IP packets. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u8)] pub enum Priority { High = 0, Normal = 1, Low = 2, } /// QoS statistics per priority level. pub struct QosStats { pub high_enqueued: AtomicU64, pub normal_enqueued: AtomicU64, pub low_enqueued: AtomicU64, pub high_dropped: AtomicU64, pub normal_dropped: AtomicU64, pub low_dropped: AtomicU64, } impl QosStats { pub fn new() -> Self { Self { high_enqueued: AtomicU64::new(0), normal_enqueued: AtomicU64::new(0), low_enqueued: AtomicU64::new(0), high_dropped: AtomicU64::new(0), normal_dropped: AtomicU64::new(0), low_dropped: AtomicU64::new(0), } } } impl Default for QosStats { fn default() -> Self { Self::new() } } // ============================================================================ // Packet classification // ============================================================================ /// 5-tuple flow key for tracking bulk flows. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct FlowKey { src_ip: u32, dst_ip: u32, src_port: u16, dst_port: u16, protocol: u8, } /// Per-flow state for bulk detection. struct FlowState { bytes_total: u64, window_start: Instant, } /// Tracks per-flow byte counts for bulk flow detection. struct FlowTracker { flows: HashMap, window_duration: Duration, max_flows: usize, } impl FlowTracker { fn new(window_duration: Duration, max_flows: usize) -> Self { Self { flows: HashMap::new(), window_duration, max_flows, } } /// Record bytes for a flow. Returns true if the flow exceeds the threshold. fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool { let now = Instant::now(); // Evict if at capacity — remove oldest entry if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) { if let Some(oldest_key) = self .flows .iter() .min_by_key(|(_, v)| v.window_start) .map(|(k, _)| *k) { self.flows.remove(&oldest_key); } } let state = self.flows.entry(key).or_insert(FlowState { bytes_total: 0, window_start: now, }); // Reset window if expired if now.duration_since(state.window_start) > self.window_duration { state.bytes_total = 0; state.window_start = now; } state.bytes_total += bytes; state.bytes_total > threshold } } /// Classifies raw IP packets into priority levels. pub struct PacketClassifier { flow_tracker: FlowTracker, /// Byte threshold for classifying a flow as "bulk" (Low priority). bulk_threshold_bytes: u64, } impl PacketClassifier { /// Create a new classifier. /// /// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB) pub fn new(bulk_threshold_bytes: u64) -> Self { Self { flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000), bulk_threshold_bytes, } } /// Classify a raw IPv4 packet into a priority level. /// /// The packet must start with the IPv4 header (as read from a TUN device). pub fn classify(&mut self, ip_packet: &[u8]) -> Priority { // Need at least 20 bytes for a minimal IPv4 header if ip_packet.len() < 20 { return Priority::Normal; } let version = ip_packet[0] >> 4; if version != 4 { return Priority::Normal; // Only classify IPv4 for now } let ihl = (ip_packet[0] & 0x0F) as usize; let header_len = ihl * 4; let protocol = ip_packet[9]; let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize; // ICMP is always high priority if protocol == 1 { return Priority::High; } // Small packets (<128 bytes) are high priority (likely interactive) if total_len < 128 { return Priority::High; } // Extract ports for TCP (6) and UDP (17) let (src_port, dst_port) = if (protocol == 6 || protocol == 17) && ip_packet.len() >= header_len + 4 { let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]); let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]); (sp, dp) } else { (0, 0) }; // DNS (port 53) and SSH (port 22) are high priority if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 { return Priority::High; } // Check for bulk flow if protocol == 6 || protocol == 17 { let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]); let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]); let key = FlowKey { src_ip, dst_ip, src_port, dst_port, protocol, }; if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) { return Priority::Low; } } Priority::Normal } } // ============================================================================ // Priority channel set // ============================================================================ /// Error returned when a packet is dropped. #[derive(Debug)] pub enum PacketDropped { LowPriorityDrop, NormalPriorityDrop, HighPriorityDrop, ChannelClosed, } /// Sending half of the priority channel set. pub struct PrioritySender { high_tx: mpsc::Sender>, normal_tx: mpsc::Sender>, low_tx: mpsc::Sender>, stats: Arc, } impl PrioritySender { /// Send a packet with the given priority. Implements smart dropping under backpressure. pub async fn send(&self, packet: Vec, priority: Priority) -> Result<(), PacketDropped> { let (tx, enqueued_counter) = match priority { Priority::High => (&self.high_tx, &self.stats.high_enqueued), Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued), Priority::Low => (&self.low_tx, &self.stats.low_enqueued), }; match tx.try_send(packet) { Ok(()) => { enqueued_counter.fetch_add(1, Ordering::Relaxed); Ok(()) } Err(mpsc::error::TrySendError::Full(packet)) => { self.handle_backpressure(packet, priority).await } Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed), } } async fn handle_backpressure( &self, packet: Vec, priority: Priority, ) -> Result<(), PacketDropped> { match priority { Priority::Low => { self.stats.low_dropped.fetch_add(1, Ordering::Relaxed); Err(PacketDropped::LowPriorityDrop) } Priority::Normal => { self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed); Err(PacketDropped::NormalPriorityDrop) } Priority::High => { // Last resort: briefly wait for space, then drop match tokio::time::timeout( Duration::from_millis(5), self.high_tx.send(packet), ) .await { Ok(Ok(())) => { self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed); Ok(()) } _ => { self.stats.high_dropped.fetch_add(1, Ordering::Relaxed); Err(PacketDropped::HighPriorityDrop) } } } } } } /// Receiving half of the priority channel set. pub struct PriorityReceiver { high_rx: mpsc::Receiver>, normal_rx: mpsc::Receiver>, low_rx: mpsc::Receiver>, } impl PriorityReceiver { /// Receive the next packet, draining high-priority first (biased select). pub async fn recv(&mut self) -> Option> { tokio::select! { biased; Some(pkt) = self.high_rx.recv() => Some(pkt), Some(pkt) = self.normal_rx.recv() => Some(pkt), Some(pkt) = self.low_rx.recv() => Some(pkt), else => None, } } } /// Create a priority channel set split into sender and receiver halves. /// /// - `high_cap`: capacity of the high-priority channel /// - `normal_cap`: capacity of the normal-priority channel /// - `low_cap`: capacity of the low-priority channel pub fn create_priority_channels( high_cap: usize, normal_cap: usize, low_cap: usize, ) -> (PrioritySender, PriorityReceiver) { let (high_tx, high_rx) = mpsc::channel(high_cap); let (normal_tx, normal_rx) = mpsc::channel(normal_cap); let (low_tx, low_rx) = mpsc::channel(low_cap); let stats = Arc::new(QosStats::new()); let sender = PrioritySender { high_tx, normal_tx, low_tx, stats, }; let receiver = PriorityReceiver { high_rx, normal_rx, low_rx, }; (sender, receiver) } /// Get a reference to the QoS stats from a sender. impl PrioritySender { pub fn stats(&self) -> &Arc { &self.stats } } #[cfg(test)] mod tests { use super::*; // Helper: craft a minimal IPv4 packet fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec { let mut pkt = vec![0u8; total_len.max(24) as usize]; pkt[0] = 0x45; // version 4, IHL 5 pkt[2..4].copy_from_slice(&total_len.to_be_bytes()); pkt[9] = protocol; // src IP pkt[12..16].copy_from_slice(&[10, 0, 0, 1]); // dst IP pkt[16..20].copy_from_slice(&[10, 0, 0, 2]); // ports (at offset 20 for IHL=5) pkt[20..22].copy_from_slice(&src_port.to_be_bytes()); pkt[22..24].copy_from_slice(&dst_port.to_be_bytes()); pkt } #[test] fn classify_icmp_as_high() { let mut c = PacketClassifier::new(1_000_000); let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP assert_eq!(c.classify(&pkt), Priority::High); } #[test] fn classify_dns_as_high() { let mut c = PacketClassifier::new(1_000_000); let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53 assert_eq!(c.classify(&pkt), Priority::High); } #[test] fn classify_ssh_as_high() { let mut c = PacketClassifier::new(1_000_000); let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22 assert_eq!(c.classify(&pkt), Priority::High); } #[test] fn classify_small_packet_as_high() { let mut c = PacketClassifier::new(1_000_000); let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet assert_eq!(c.classify(&pkt), Priority::High); } #[test] fn classify_normal_http() { let mut c = PacketClassifier::new(1_000_000); let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B assert_eq!(c.classify(&pkt), Priority::Normal); } #[test] fn classify_bulk_flow_as_low() { let mut c = PacketClassifier::new(10_000); // Low threshold for testing // Send enough traffic to exceed the threshold for _ in 0..100 { let pkt = make_ipv4_packet(6, 12345, 80, 500); c.classify(&pkt); } // Next packet from same flow should be Low let pkt = make_ipv4_packet(6, 12345, 80, 500); assert_eq!(c.classify(&pkt), Priority::Low); } #[test] fn classify_too_short_packet() { let mut c = PacketClassifier::new(1_000_000); let pkt = vec![0u8; 10]; // Too short for IPv4 header assert_eq!(c.classify(&pkt), Priority::Normal); } #[test] fn classify_non_ipv4() { let mut c = PacketClassifier::new(1_000_000); let mut pkt = vec![0u8; 40]; pkt[0] = 0x60; // IPv6 version nibble assert_eq!(c.classify(&pkt), Priority::Normal); } #[tokio::test] async fn priority_receiver_drains_high_first() { let (sender, mut receiver) = create_priority_channels(8, 8, 8); // Enqueue in reverse order sender.send(vec![3], Priority::Low).await.unwrap(); sender.send(vec![2], Priority::Normal).await.unwrap(); sender.send(vec![1], Priority::High).await.unwrap(); // Should drain High first assert_eq!(receiver.recv().await.unwrap(), vec![1]); assert_eq!(receiver.recv().await.unwrap(), vec![2]); assert_eq!(receiver.recv().await.unwrap(), vec![3]); } #[tokio::test] async fn smart_dropping_low_priority() { let (sender, _receiver) = create_priority_channels(8, 8, 1); // Fill the low channel sender.send(vec![0], Priority::Low).await.unwrap(); // Next low-priority send should be dropped let result = sender.send(vec![1], Priority::Low).await; assert!(matches!(result, Err(PacketDropped::LowPriorityDrop))); assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1); } #[tokio::test] async fn smart_dropping_normal_priority() { let (sender, _receiver) = create_priority_channels(8, 1, 8); sender.send(vec![0], Priority::Normal).await.unwrap(); let result = sender.send(vec![1], Priority::Normal).await; assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop))); assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1); } #[tokio::test] async fn stats_track_enqueued() { let (sender, _receiver) = create_priority_channels(8, 8, 8); sender.send(vec![1], Priority::High).await.unwrap(); sender.send(vec![2], Priority::High).await.unwrap(); sender.send(vec![3], Priority::Normal).await.unwrap(); sender.send(vec![4], Priority::Low).await.unwrap(); assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2); assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1); assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1); } #[test] fn flow_tracker_evicts_at_capacity() { let mut ft = FlowTracker::new(Duration::from_secs(60), 2); let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 }; let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 }; let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 }; ft.record(k1, 100, 1000); ft.record(k2, 100, 1000); // Should evict k1 (oldest) ft.record(k3, 100, 1000); assert_eq!(ft.flows.len(), 2); assert!(!ft.flows.contains_key(&k1)); } }