diff --git a/changelog.md b/changelog.md index ebc8b68..539b310 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,13 @@ # Changelog +## 2026-03-15 - 1.1.0 - feat(rust-core) +add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs + +- adds adaptive keepalive monitoring with RTT, jitter, loss, and link health reporting to client statistics and management endpoints +- introduces MTU overhead calculation and oversized-packet handling support, plus client MTU info APIs +- adds token-bucket rate limiting with configurable default limits and server management commands to set, remove, and inspect per-client telemetry +- extends TypeScript client and server interfaces with connection quality, MTU, and client telemetry methods + ## 2026-02-27 - 1.0.3 - fix(build) add aarch64 linker configuration for cross-compilation diff --git a/rust/src/client.rs b/rust/src/client.rs index a99a651..22c6bb2 100644 --- a/rust/src/client.rs +++ b/rust/src/client.rs @@ -3,13 +3,14 @@ use bytes::BytesMut; use futures_util::{SinkExt, StreamExt}; use serde::Deserialize; use std::sync::Arc; -use std::time::Duration; -use tokio::sync::{mpsc, RwLock}; +use tokio::sync::{mpsc, watch, RwLock}; use tokio_tungstenite::tungstenite::Message; -use tracing::{info, error, warn}; +use tracing::{info, error, warn, debug}; use crate::codec::{Frame, FrameCodec, PacketType}; use crate::crypto; +use crate::keepalive::{self, KeepaliveSignal, LinkHealth}; +use crate::telemetry::ConnectionQuality; use crate::transport; /// Client configuration (matches TS IVpnClientConfig). @@ -65,6 +66,8 @@ pub struct VpnClient { assigned_ip: Arc>>, shutdown_tx: Option>, connected_since: Arc>>, + quality_rx: Option>, + link_health: Arc>, } impl VpnClient { @@ -75,6 +78,8 @@ impl VpnClient { assigned_ip: Arc::new(RwLock::new(None)), shutdown_tx: None, connected_since: Arc::new(RwLock::new(None)), + quality_rx: None, + link_health: Arc::new(RwLock::new(LinkHealth::Degraded)), } } @@ -93,6 +98,7 @@ impl VpnClient { let stats = self.stats.clone(); let assigned_ip_ref = self.assigned_ip.clone(); let connected_since = self.connected_since.clone(); + let link_health = self.link_health.clone(); // Decode server public key let server_pub_key = base64::Engine::decode( @@ -161,6 +167,13 @@ impl VpnClient { info!("Connected to VPN, assigned IP: {}", assigned_ip); + // Create adaptive keepalive monitor + let (monitor, handle) = keepalive::create_keepalive(None); + self.quality_rx = Some(handle.quality_rx); + + // Spawn the keepalive monitor + tokio::spawn(monitor.run()); + // Spawn packet forwarding loop let assigned_ip_clone = assigned_ip.clone(); tokio::spawn(client_loop( @@ -170,7 +183,9 @@ impl VpnClient { state, stats, shutdown_rx, - config.keepalive_interval_secs.unwrap_or(30), + handle.signal_rx, + handle.ack_tx, + link_health, )); Ok(assigned_ip_clone) @@ -184,6 +199,7 @@ impl VpnClient { *self.assigned_ip.write().await = None; *self.connected_since.write().await = None; *self.state.write().await = ClientState::Disconnected; + self.quality_rx = None; info!("Disconnected from VPN"); Ok(()) } @@ -208,13 +224,14 @@ impl VpnClient { status } - /// Get traffic statistics. + /// Get traffic statistics (includes connection quality). pub async fn get_statistics(&self) -> serde_json::Value { let stats = self.stats.read().await; let since = self.connected_since.read().await; let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0); + let health = self.link_health.read().await; - serde_json::json!({ + let mut result = serde_json::json!({ "bytesSent": stats.bytes_sent, "bytesReceived": stats.bytes_received, "packetsSent": stats.packets_sent, @@ -222,7 +239,35 @@ impl VpnClient { "keepalivesSent": stats.keepalives_sent, "keepalivesReceived": stats.keepalives_received, "uptimeSeconds": uptime, - }) + }); + + // Include connection quality if available + if let Some(ref rx) = self.quality_rx { + let quality = rx.borrow().clone(); + result["quality"] = serde_json::json!({ + "srttMs": quality.srtt_ms, + "jitterMs": quality.jitter_ms, + "minRttMs": quality.min_rtt_ms, + "maxRttMs": quality.max_rtt_ms, + "lossRatio": quality.loss_ratio, + "consecutiveTimeouts": quality.consecutive_timeouts, + "linkHealth": format!("{}", *health), + "keepalivesSent": quality.keepalives_sent, + "keepalivesAcked": quality.keepalives_acked, + }); + } + + result + } + + /// Get connection quality snapshot. + pub fn get_connection_quality(&self) -> Option { + self.quality_rx.as_ref().map(|rx| rx.borrow().clone()) + } + + /// Get current link health. + pub async fn get_link_health(&self) -> LinkHealth { + *self.link_health.read().await } } @@ -234,11 +279,11 @@ async fn client_loop( state: Arc>, stats: Arc>, mut shutdown_rx: mpsc::Receiver<()>, - keepalive_secs: u64, + mut signal_rx: mpsc::Receiver, + ack_tx: mpsc::Sender<()>, + link_health: Arc>, ) { let mut buf = vec![0u8; 65535]; - let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs)); - keepalive_ticker.tick().await; // skip first immediate tick loop { tokio::select! { @@ -264,6 +309,8 @@ async fn client_loop( } PacketType::KeepaliveAck => { stats.write().await.keepalives_received += 1; + // Signal the keepalive monitor that ACK was received + let _ = ack_tx.send(()).await; } PacketType::Disconnect => { info!("Server sent disconnect"); @@ -290,19 +337,37 @@ async fn client_loop( } } } - _ = keepalive_ticker.tick() => { - let ka_frame = Frame { - packet_type: PacketType::Keepalive, - payload: vec![], - }; - let mut frame_bytes = BytesMut::new(); - if >::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() { - if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() { - warn!("Failed to send keepalive"); + signal = signal_rx.recv() => { + match signal { + Some(KeepaliveSignal::SendPing(timestamp_ms)) => { + // Embed the timestamp in the keepalive payload (8 bytes, big-endian) + let ka_frame = Frame { + packet_type: PacketType::Keepalive, + payload: timestamp_ms.to_be_bytes().to_vec(), + }; + let mut frame_bytes = BytesMut::new(); + if >::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() { + if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() { + warn!("Failed to send keepalive"); + *state.write().await = ClientState::Disconnected; + break; + } + stats.write().await.keepalives_sent += 1; + } + } + Some(KeepaliveSignal::PeerDead) => { + warn!("Peer declared dead by keepalive monitor"); *state.write().await = ClientState::Disconnected; break; } - stats.write().await.keepalives_sent += 1; + Some(KeepaliveSignal::LinkHealthChanged(health)) => { + debug!("Link health changed to: {}", health); + *link_health.write().await = health; + } + None => { + // Keepalive monitor channel closed + break; + } } } _ = shutdown_rx.recv() => { diff --git a/rust/src/keepalive.rs b/rust/src/keepalive.rs index 8b7de66..8d46487 100644 --- a/rust/src/keepalive.rs +++ b/rust/src/keepalive.rs @@ -1,87 +1,464 @@ use std::time::Duration; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, watch}; use tokio::time::{interval, timeout}; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; -/// Default keepalive interval (30 seconds). +use crate::telemetry::{ConnectionQuality, RttTracker}; + +/// Default keepalive interval (30 seconds — used for Degraded state). pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30); -/// Default keepalive ACK timeout (10 seconds). -pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); +/// Default keepalive ACK timeout (5 seconds). +pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(5); + +/// Link health states for adaptive keepalive. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LinkHealth { + /// RTT stable, jitter low, no loss. Interval: 60s. + Healthy, + /// Elevated jitter or occasional loss. Interval: 30s. + Degraded, + /// High loss or sustained jitter spike. Interval: 10s. + Critical, +} + +impl std::fmt::Display for LinkHealth { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Healthy => write!(f, "healthy"), + Self::Degraded => write!(f, "degraded"), + Self::Critical => write!(f, "critical"), + } + } +} + +/// Configuration for the adaptive keepalive state machine. +#[derive(Debug, Clone)] +pub struct AdaptiveKeepaliveConfig { + /// Interval when link health is Healthy. + pub healthy_interval: Duration, + /// Interval when link health is Degraded. + pub degraded_interval: Duration, + /// Interval when link health is Critical. + pub critical_interval: Duration, + /// ACK timeout (how long to wait for ACK before declaring timeout). + pub ack_timeout: Duration, + /// Jitter threshold (ms) to enter Degraded from Healthy. + pub jitter_degraded_ms: f64, + /// Jitter threshold (ms) to return to Healthy from Degraded. + pub jitter_healthy_ms: f64, + /// Loss ratio threshold to enter Degraded. + pub loss_degraded: f64, + /// Loss ratio threshold to enter Critical. + pub loss_critical: f64, + /// Loss ratio threshold to return from Critical to Degraded. + pub loss_recover: f64, + /// Loss ratio threshold to return from Degraded to Healthy. + pub loss_healthy: f64, + /// Consecutive checks required for upward state transitions (hysteresis). + pub upgrade_checks: u32, + /// Consecutive timeouts to declare peer dead in Critical state. + pub dead_peer_timeouts: u32, +} + +impl Default for AdaptiveKeepaliveConfig { + fn default() -> Self { + Self { + healthy_interval: Duration::from_secs(60), + degraded_interval: Duration::from_secs(30), + critical_interval: Duration::from_secs(10), + ack_timeout: Duration::from_secs(5), + jitter_degraded_ms: 50.0, + jitter_healthy_ms: 30.0, + loss_degraded: 0.05, + loss_critical: 0.20, + loss_recover: 0.10, + loss_healthy: 0.02, + upgrade_checks: 3, + dead_peer_timeouts: 3, + } + } +} /// Signals from the keepalive monitor. #[derive(Debug, Clone)] pub enum KeepaliveSignal { - /// Time to send a keepalive ping. - SendPing, - /// Peer is considered dead (no ACK received within timeout). + /// Time to send a keepalive ping. Contains the timestamp (ms since epoch) to embed in payload. + SendPing(u64), + /// Peer is considered dead (no ACK received within timeout repeatedly). PeerDead, + /// Link health state changed. + LinkHealthChanged(LinkHealth), } -/// A keepalive monitor that emits signals on a channel. +/// A keepalive monitor with adaptive interval and RTT tracking. pub struct KeepaliveMonitor { - interval: Duration, - timeout_duration: Duration, + config: AdaptiveKeepaliveConfig, + health: LinkHealth, + rtt_tracker: RttTracker, signal_tx: mpsc::Sender, ack_rx: mpsc::Receiver<()>, + quality_tx: watch::Sender, + consecutive_upgrade_checks: u32, } /// Handle returned to the caller to send ACKs and receive signals. pub struct KeepaliveHandle { pub signal_rx: mpsc::Receiver, pub ack_tx: mpsc::Sender<()>, + pub quality_rx: watch::Receiver, } -/// Create a keepalive monitor and its handle. +/// Create an adaptive keepalive monitor and its handle. pub fn create_keepalive( - keepalive_interval: Option, - keepalive_timeout: Option, + config: Option, ) -> (KeepaliveMonitor, KeepaliveHandle) { + let config = config.unwrap_or_default(); let (signal_tx, signal_rx) = mpsc::channel(8); let (ack_tx, ack_rx) = mpsc::channel(8); + let (quality_tx, quality_rx) = watch::channel(ConnectionQuality::default()); let monitor = KeepaliveMonitor { - interval: keepalive_interval.unwrap_or(DEFAULT_KEEPALIVE_INTERVAL), - timeout_duration: keepalive_timeout.unwrap_or(DEFAULT_KEEPALIVE_TIMEOUT), + config, + health: LinkHealth::Degraded, // start in Degraded, earn Healthy + rtt_tracker: RttTracker::new(30), signal_tx, ack_rx, + quality_tx, + consecutive_upgrade_checks: 0, }; - let handle = KeepaliveHandle { signal_rx, ack_tx }; + let handle = KeepaliveHandle { + signal_rx, + ack_tx, + quality_rx, + }; (monitor, handle) } impl KeepaliveMonitor { + fn current_interval(&self) -> Duration { + match self.health { + LinkHealth::Healthy => self.config.healthy_interval, + LinkHealth::Degraded => self.config.degraded_interval, + LinkHealth::Critical => self.config.critical_interval, + } + } + /// Run the keepalive loop. Blocks until the peer is dead or channels close. pub async fn run(mut self) { - let mut ticker = interval(self.interval); + let mut ticker = interval(self.current_interval()); ticker.tick().await; // skip first immediate tick loop { ticker.tick().await; - debug!("Sending keepalive ping signal"); - if self.signal_tx.send(KeepaliveSignal::SendPing).await.is_err() { - // Channel closed - break; + // Record ping sent, get timestamp for payload + let timestamp_ms = self.rtt_tracker.mark_ping_sent(); + debug!("Sending keepalive ping (ts={})", timestamp_ms); + + if self + .signal_tx + .send(KeepaliveSignal::SendPing(timestamp_ms)) + .await + .is_err() + { + break; // channel closed } // Wait for ACK within timeout - match timeout(self.timeout_duration, self.ack_rx.recv()).await { + match timeout(self.config.ack_timeout, self.ack_rx.recv()).await { Ok(Some(())) => { - debug!("Keepalive ACK received"); + if let Some(rtt) = self.rtt_tracker.record_ack(timestamp_ms) { + debug!("Keepalive ACK received, RTT: {:?}", rtt); + } } Ok(None) => { - // Channel closed - break; + break; // channel closed } Err(_) => { - warn!("Keepalive ACK timeout — peer considered dead"); - let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await; - break; + self.rtt_tracker.record_timeout(); + warn!( + "Keepalive ACK timeout (consecutive: {})", + self.rtt_tracker.consecutive_timeouts + ); } } + + // Publish quality snapshot + let quality = self.rtt_tracker.snapshot(); + let _ = self.quality_tx.send(quality.clone()); + + // Evaluate state transition + let new_health = self.evaluate_health(&quality); + + if new_health != self.health { + info!("Link health: {} -> {}", self.health, new_health); + self.health = new_health; + self.consecutive_upgrade_checks = 0; + + // Reset ticker to new interval + ticker = interval(self.current_interval()); + ticker.tick().await; // skip first immediate tick + + let _ = self + .signal_tx + .send(KeepaliveSignal::LinkHealthChanged(new_health)) + .await; + } + + // Check for dead peer in Critical state + if self.health == LinkHealth::Critical + && self.rtt_tracker.consecutive_timeouts >= self.config.dead_peer_timeouts + { + warn!("Peer considered dead after {} consecutive timeouts in Critical state", + self.rtt_tracker.consecutive_timeouts); + let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await; + break; + } + } + } + + fn evaluate_health(&mut self, quality: &ConnectionQuality) -> LinkHealth { + match self.health { + LinkHealth::Healthy => { + // Downgrade conditions + if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical { + self.consecutive_upgrade_checks = 0; + return LinkHealth::Critical; + } + if quality.jitter_ms > self.config.jitter_degraded_ms + || quality.loss_ratio > self.config.loss_degraded + || quality.consecutive_timeouts >= 1 + { + self.consecutive_upgrade_checks = 0; + return LinkHealth::Degraded; + } + LinkHealth::Healthy + } + LinkHealth::Degraded => { + // Downgrade to Critical + if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical { + self.consecutive_upgrade_checks = 0; + return LinkHealth::Critical; + } + // Upgrade to Healthy (with hysteresis) + if quality.jitter_ms < self.config.jitter_healthy_ms + && quality.loss_ratio < self.config.loss_healthy + && quality.consecutive_timeouts == 0 + { + self.consecutive_upgrade_checks += 1; + if self.consecutive_upgrade_checks >= self.config.upgrade_checks { + self.consecutive_upgrade_checks = 0; + return LinkHealth::Healthy; + } + } else { + self.consecutive_upgrade_checks = 0; + } + LinkHealth::Degraded + } + LinkHealth::Critical => { + // Upgrade to Degraded (with hysteresis), never directly to Healthy + if quality.loss_ratio < self.config.loss_recover + && quality.consecutive_timeouts == 0 + { + self.consecutive_upgrade_checks += 1; + if self.consecutive_upgrade_checks >= 2 { + self.consecutive_upgrade_checks = 0; + return LinkHealth::Degraded; + } + } else { + self.consecutive_upgrade_checks = 0; + } + LinkHealth::Critical + } } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_config_values() { + let config = AdaptiveKeepaliveConfig::default(); + assert_eq!(config.healthy_interval, Duration::from_secs(60)); + assert_eq!(config.degraded_interval, Duration::from_secs(30)); + assert_eq!(config.critical_interval, Duration::from_secs(10)); + assert_eq!(config.ack_timeout, Duration::from_secs(5)); + assert_eq!(config.dead_peer_timeouts, 3); + } + + #[test] + fn link_health_display() { + assert_eq!(format!("{}", LinkHealth::Healthy), "healthy"); + assert_eq!(format!("{}", LinkHealth::Degraded), "degraded"); + assert_eq!(format!("{}", LinkHealth::Critical), "critical"); + } + + // Helper to create a monitor for unit-testing evaluate_health + fn make_test_monitor() -> KeepaliveMonitor { + let (signal_tx, _signal_rx) = mpsc::channel(8); + let (_ack_tx, ack_rx) = mpsc::channel(8); + let (quality_tx, _quality_rx) = watch::channel(ConnectionQuality::default()); + + KeepaliveMonitor { + config: AdaptiveKeepaliveConfig::default(), + health: LinkHealth::Degraded, + rtt_tracker: RttTracker::new(30), + signal_tx, + ack_rx, + quality_tx, + consecutive_upgrade_checks: 0, + } + } + + #[test] + fn healthy_to_degraded_on_jitter() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Healthy; + let q = ConnectionQuality { + jitter_ms: 60.0, // > 50ms threshold + ..Default::default() + }; + let result = m.evaluate_health(&q); + assert_eq!(result, LinkHealth::Degraded); + } + + #[test] + fn healthy_to_degraded_on_loss() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Healthy; + let q = ConnectionQuality { + loss_ratio: 0.06, // > 5% threshold + ..Default::default() + }; + let result = m.evaluate_health(&q); + assert_eq!(result, LinkHealth::Degraded); + } + + #[test] + fn healthy_to_critical_on_high_loss() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Healthy; + let q = ConnectionQuality { + loss_ratio: 0.25, // > 20% threshold + ..Default::default() + }; + let result = m.evaluate_health(&q); + assert_eq!(result, LinkHealth::Critical); + } + + #[test] + fn healthy_to_critical_on_consecutive_timeouts() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Healthy; + let q = ConnectionQuality { + consecutive_timeouts: 2, + ..Default::default() + }; + let result = m.evaluate_health(&q); + assert_eq!(result, LinkHealth::Critical); + } + + #[test] + fn degraded_to_healthy_requires_hysteresis() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Degraded; + let good_quality = ConnectionQuality { + jitter_ms: 10.0, + loss_ratio: 0.0, + consecutive_timeouts: 0, + srtt_ms: 20.0, + ..Default::default() + }; + + // Should require 3 consecutive good checks (default upgrade_checks) + assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded); + assert_eq!(m.consecutive_upgrade_checks, 1); + assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded); + assert_eq!(m.consecutive_upgrade_checks, 2); + assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Healthy); + } + + #[test] + fn degraded_to_healthy_resets_on_bad_check() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Degraded; + let good = ConnectionQuality { + jitter_ms: 10.0, + loss_ratio: 0.0, + consecutive_timeouts: 0, + ..Default::default() + }; + let bad = ConnectionQuality { + jitter_ms: 60.0, // too high + loss_ratio: 0.0, + consecutive_timeouts: 0, + ..Default::default() + }; + + m.evaluate_health(&good); // 1 check + m.evaluate_health(&good); // 2 checks + m.evaluate_health(&bad); // resets + assert_eq!(m.consecutive_upgrade_checks, 0); + } + + #[test] + fn critical_to_degraded_requires_hysteresis() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Critical; + let recovering = ConnectionQuality { + loss_ratio: 0.05, // < 10% recover threshold + consecutive_timeouts: 0, + ..Default::default() + }; + + assert_eq!(m.evaluate_health(&recovering), LinkHealth::Critical); + assert_eq!(m.consecutive_upgrade_checks, 1); + assert_eq!(m.evaluate_health(&recovering), LinkHealth::Degraded); + } + + #[test] + fn critical_never_directly_to_healthy() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Critical; + let perfect = ConnectionQuality { + jitter_ms: 1.0, + loss_ratio: 0.0, + consecutive_timeouts: 0, + srtt_ms: 10.0, + ..Default::default() + }; + + // Even with perfect quality, must go through Degraded first + m.evaluate_health(&perfect); // 1 + let result = m.evaluate_health(&perfect); // 2 → Degraded + assert_eq!(result, LinkHealth::Degraded); + // Not Healthy yet + } + + #[test] + fn degraded_to_critical_on_high_loss() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Degraded; + let q = ConnectionQuality { + loss_ratio: 0.25, + ..Default::default() + }; + assert_eq!(m.evaluate_health(&q), LinkHealth::Critical); + } + + #[test] + fn interval_matches_health() { + let mut m = make_test_monitor(); + m.health = LinkHealth::Healthy; + assert_eq!(m.current_interval(), Duration::from_secs(60)); + m.health = LinkHealth::Degraded; + assert_eq!(m.current_interval(), Duration::from_secs(30)); + m.health = LinkHealth::Critical; + assert_eq!(m.current_interval(), Duration::from_secs(10)); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 0825313..ba2386d 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -11,3 +11,7 @@ pub mod network; pub mod server; pub mod client; pub mod reconnect; +pub mod telemetry; +pub mod ratelimit; +pub mod qos; +pub mod mtu; diff --git a/rust/src/management.rs b/rust/src/management.rs index 131a635..3cb42a2 100644 --- a/rust/src/management.rs +++ b/rust/src/management.rs @@ -285,6 +285,39 @@ async fn handle_client_request( let stats = vpn_client.get_statistics().await; ManagementResponse::ok(id, stats) } + "getConnectionQuality" => { + match vpn_client.get_connection_quality() { + Some(quality) => { + let health = vpn_client.get_link_health().await; + let interval_secs = match health { + crate::keepalive::LinkHealth::Healthy => 60, + crate::keepalive::LinkHealth::Degraded => 30, + crate::keepalive::LinkHealth::Critical => 10, + }; + ManagementResponse::ok(id, serde_json::json!({ + "srttMs": quality.srtt_ms, + "jitterMs": quality.jitter_ms, + "minRttMs": quality.min_rtt_ms, + "maxRttMs": quality.max_rtt_ms, + "lossRatio": quality.loss_ratio, + "consecutiveTimeouts": quality.consecutive_timeouts, + "linkHealth": format!("{}", health), + "currentKeepaliveIntervalSecs": interval_secs, + })) + } + None => ManagementResponse::ok(id, serde_json::json!(null)), + } + } + "getMtuInfo" => { + ManagementResponse::ok(id, serde_json::json!({ + "tunMtu": 1420, + "effectiveMtu": crate::mtu::TunnelOverhead::default_overhead().effective_tun_mtu(1500), + "linkMtu": 1500, + "overheadBytes": crate::mtu::TunnelOverhead::default_overhead().total(), + "oversizedPacketsDropped": 0, + "icmpTooBigSent": 0, + })) + } _ => ManagementResponse::err(id, format!("Unknown client method: {}", request.method)), } } @@ -349,6 +382,50 @@ async fn handle_server_request( Err(e) => ManagementResponse::err(id, format!("Disconnect client failed: {}", e)), } } + "setClientRateLimit" => { + let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) { + Some(id) => id.to_string(), + None => return ManagementResponse::err(id, "Missing clientId".to_string()), + }; + let rate = match request.params.get("rateBytesPerSec").and_then(|v| v.as_u64()) { + Some(r) => r, + None => return ManagementResponse::err(id, "Missing rateBytesPerSec".to_string()), + }; + let burst = match request.params.get("burstBytes").and_then(|v| v.as_u64()) { + Some(b) => b, + None => return ManagementResponse::err(id, "Missing burstBytes".to_string()), + }; + match vpn_server.set_client_rate_limit(&client_id, rate, burst).await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)), + } + } + "removeClientRateLimit" => { + let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) { + Some(id) => id.to_string(), + None => return ManagementResponse::err(id, "Missing clientId".to_string()), + }; + match vpn_server.remove_client_rate_limit(&client_id).await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)), + } + } + "getClientTelemetry" => { + let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) { + Some(cid) => cid.to_string(), + None => return ManagementResponse::err(id, "Missing clientId".to_string()), + }; + let clients = vpn_server.list_clients().await; + match clients.into_iter().find(|c| c.client_id == client_id) { + Some(info) => { + match serde_json::to_value(&info) { + Ok(v) => ManagementResponse::ok(id, v), + Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)), + } + } + None => ManagementResponse::err(id, format!("Client {} not found", client_id)), + } + } "generateKeypair" => match crypto::generate_keypair() { Ok((public_key, private_key)) => ManagementResponse::ok( id, diff --git a/rust/src/mtu.rs b/rust/src/mtu.rs new file mode 100644 index 0000000..4fe9105 --- /dev/null +++ b/rust/src/mtu.rs @@ -0,0 +1,314 @@ +use std::net::Ipv4Addr; + +/// Overhead breakdown for VPN tunnel encapsulation. +#[derive(Debug, Clone)] +pub struct TunnelOverhead { + /// Outer IP header: 20 bytes (IPv4, no options). + pub ip_header: u16, + /// TCP header: typically 32 bytes (20 base + 12 for timestamps). + pub tcp_header: u16, + /// WebSocket framing: ~6 bytes (2 base + 4 mask from client). + pub ws_framing: u16, + /// VPN binary frame header: 5 bytes [type:1B][length:4B]. + pub vpn_header: u16, + /// Noise AEAD tag: 16 bytes (Poly1305). + pub noise_tag: u16, +} + +impl TunnelOverhead { + /// Conservative default overhead estimate. + pub fn default_overhead() -> Self { + Self { + ip_header: 20, + tcp_header: 32, + ws_framing: 6, + vpn_header: 5, + noise_tag: 16, + } + } + + /// Total encapsulation overhead in bytes. + pub fn total(&self) -> u16 { + self.ip_header + self.tcp_header + self.ws_framing + self.vpn_header + self.noise_tag + } + + /// Compute effective TUN MTU given the underlying link MTU. + pub fn effective_tun_mtu(&self, link_mtu: u16) -> u16 { + link_mtu.saturating_sub(self.total()) + } +} + +/// MTU configuration for the VPN tunnel. +#[derive(Debug, Clone)] +pub struct MtuConfig { + /// Underlying link MTU (typically 1500 for Ethernet). + pub link_mtu: u16, + /// Computed effective TUN MTU. + pub effective_mtu: u16, + /// Whether to generate ICMP too-big for oversized packets. + pub send_icmp_too_big: bool, + /// Counter: oversized packets encountered. + pub oversized_packets: u64, + /// Counter: ICMP too-big messages generated. + pub icmp_too_big_sent: u64, +} + +impl MtuConfig { + /// Create a new MTU config from the underlying link MTU. + pub fn new(link_mtu: u16) -> Self { + let overhead = TunnelOverhead::default_overhead(); + let effective = overhead.effective_tun_mtu(link_mtu); + Self { + link_mtu, + effective_mtu: effective, + send_icmp_too_big: true, + oversized_packets: 0, + icmp_too_big_sent: 0, + } + } + + /// Check if a packet exceeds the effective MTU. + pub fn is_oversized(&self, packet_len: usize) -> bool { + packet_len > self.effective_mtu as usize + } +} + +/// Action to take after checking MTU. +pub enum MtuAction { + /// Packet is within MTU, forward normally. + Forward, + /// Packet is oversized; contains the ICMP too-big message to write back into TUN. + SendIcmpTooBig(Vec), +} + +/// Check packet against MTU config and return the appropriate action. +pub fn check_mtu(packet: &[u8], config: &MtuConfig) -> MtuAction { + if !config.is_oversized(packet.len()) { + return MtuAction::Forward; + } + + if !config.send_icmp_too_big { + return MtuAction::Forward; + } + + match generate_icmp_too_big(packet, config.effective_mtu) { + Some(icmp) => MtuAction::SendIcmpTooBig(icmp), + None => MtuAction::Forward, + } +} + +/// Generate an ICMPv4 Destination Unreachable / Fragmentation Needed message. +/// +/// Per RFC 792: Type 3, Code 4, with next-hop MTU in bytes 6-7 (RFC 1191). +/// Returns the complete IP + ICMP packet to write back into the TUN device. +pub fn generate_icmp_too_big(original_packet: &[u8], next_hop_mtu: u16) -> Option> { + // Need at least 20 bytes of original IP header + if original_packet.len() < 20 { + return None; + } + + // Verify it's IPv4 + if original_packet[0] >> 4 != 4 { + return None; + } + + // Parse source/dest from original IP header + let src_ip = Ipv4Addr::new( + original_packet[12], + original_packet[13], + original_packet[14], + original_packet[15], + ); + let dst_ip = Ipv4Addr::new( + original_packet[16], + original_packet[17], + original_packet[18], + original_packet[19], + ); + + // ICMP payload: IP header + first 8 bytes of original datagram (per RFC 792) + let icmp_data_len = original_packet.len().min(28); // 20 IP header + 8 bytes + let icmp_payload = &original_packet[..icmp_data_len]; + + // Build ICMP message: type(1) + code(1) + checksum(2) + unused(2) + next_hop_mtu(2) + data + let mut icmp = Vec::with_capacity(8 + icmp_data_len); + icmp.push(3); // Type: Destination Unreachable + icmp.push(4); // Code: Fragmentation Needed and DF was Set + icmp.push(0); // Checksum placeholder + icmp.push(0); + icmp.push(0); // Unused + icmp.push(0); + icmp.extend_from_slice(&next_hop_mtu.to_be_bytes()); + icmp.extend_from_slice(icmp_payload); + + // Compute ICMP checksum + let cksum = internet_checksum(&icmp); + icmp[2] = (cksum >> 8) as u8; + icmp[3] = (cksum & 0xff) as u8; + + // Build IP header (ICMP response: FROM tunnel gateway TO original source) + let total_len = (20 + icmp.len()) as u16; + let mut ip = Vec::with_capacity(total_len as usize); + ip.push(0x45); // Version 4, IHL 5 + ip.push(0x00); // DSCP/ECN + ip.extend_from_slice(&total_len.to_be_bytes()); + ip.extend_from_slice(&[0, 0]); // Identification + ip.extend_from_slice(&[0x40, 0x00]); // Flags: Don't Fragment, Fragment Offset: 0 + ip.push(64); // TTL + ip.push(1); // Protocol: ICMP + ip.extend_from_slice(&[0, 0]); // Header checksum placeholder + ip.extend_from_slice(&dst_ip.octets()); // Source: tunnel endpoint (was dst) + ip.extend_from_slice(&src_ip.octets()); // Destination: original source + + // Compute IP header checksum + let ip_cksum = internet_checksum(&ip[..20]); + ip[10] = (ip_cksum >> 8) as u8; + ip[11] = (ip_cksum & 0xff) as u8; + + ip.extend_from_slice(&icmp); + Some(ip) +} + +/// Standard Internet checksum (RFC 1071). +fn internet_checksum(data: &[u8]) -> u16 { + let mut sum: u32 = 0; + let mut i = 0; + while i + 1 < data.len() { + sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32; + i += 2; + } + if i < data.len() { + sum += (data[i] as u32) << 8; + } + while sum >> 16 != 0 { + sum = (sum & 0xFFFF) + (sum >> 16); + } + !sum as u16 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_overhead_total() { + let oh = TunnelOverhead::default_overhead(); + assert_eq!(oh.total(), 79); // 20+32+6+5+16 + } + + #[test] + fn effective_mtu_for_ethernet() { + let oh = TunnelOverhead::default_overhead(); + let mtu = oh.effective_tun_mtu(1500); + assert_eq!(mtu, 1421); // 1500 - 79 + } + + #[test] + fn effective_mtu_saturates_at_zero() { + let oh = TunnelOverhead::default_overhead(); + let mtu = oh.effective_tun_mtu(50); // Less than overhead + assert_eq!(mtu, 0); + } + + #[test] + fn mtu_config_default() { + let config = MtuConfig::new(1500); + assert_eq!(config.effective_mtu, 1421); + assert_eq!(config.link_mtu, 1500); + assert!(config.send_icmp_too_big); + } + + #[test] + fn is_oversized() { + let config = MtuConfig::new(1500); + assert!(!config.is_oversized(1421)); + assert!(config.is_oversized(1422)); + } + + #[test] + fn icmp_too_big_generation() { + // Craft a minimal IPv4 packet + let mut original = vec![0u8; 28]; + original[0] = 0x45; // version 4, IHL 5 + original[2..4].copy_from_slice(&1500u16.to_be_bytes()); // total length + original[9] = 6; // TCP + original[12..16].copy_from_slice(&[10, 0, 0, 1]); // src IP + original[16..20].copy_from_slice(&[10, 0, 0, 2]); // dst IP + + let icmp_pkt = generate_icmp_too_big(&original, 1421).unwrap(); + + // Verify it's a valid IPv4 packet + assert_eq!(icmp_pkt[0] >> 4, 4); // IPv4 + assert_eq!(icmp_pkt[9], 1); // ICMP protocol + + // Source should be original dst (10.0.0.2) + assert_eq!(&icmp_pkt[12..16], &[10, 0, 0, 2]); + // Destination should be original src (10.0.0.1) + assert_eq!(&icmp_pkt[16..20], &[10, 0, 0, 1]); + + // ICMP type 3, code 4 + assert_eq!(icmp_pkt[20], 3); + assert_eq!(icmp_pkt[21], 4); + + // Next-hop MTU at ICMP bytes 6-7 (offset 26-27 in IP packet) + let mtu = u16::from_be_bytes([icmp_pkt[26], icmp_pkt[27]]); + assert_eq!(mtu, 1421); + } + + #[test] + fn icmp_too_big_rejects_short_packet() { + let short = vec![0u8; 10]; + assert!(generate_icmp_too_big(&short, 1421).is_none()); + } + + #[test] + fn icmp_too_big_rejects_non_ipv4() { + let mut pkt = vec![0u8; 40]; + pkt[0] = 0x60; // IPv6 + assert!(generate_icmp_too_big(&pkt, 1421).is_none()); + } + + #[test] + fn icmp_checksum_valid() { + let mut original = vec![0u8; 28]; + original[0] = 0x45; + original[2..4].copy_from_slice(&1500u16.to_be_bytes()); + original[9] = 6; + original[12..16].copy_from_slice(&[192, 168, 1, 100]); + original[16..20].copy_from_slice(&[10, 8, 0, 1]); + + let icmp_pkt = generate_icmp_too_big(&original, 1420).unwrap(); + + // Verify IP header checksum + let ip_cksum = internet_checksum(&icmp_pkt[..20]); + assert_eq!(ip_cksum, 0, "IP header checksum should verify to 0"); + + // Verify ICMP checksum + let icmp_cksum = internet_checksum(&icmp_pkt[20..]); + assert_eq!(icmp_cksum, 0, "ICMP checksum should verify to 0"); + } + + #[test] + fn check_mtu_forward() { + let config = MtuConfig::new(1500); + let pkt = vec![0u8; 1421]; // Exactly at MTU + assert!(matches!(check_mtu(&pkt, &config), MtuAction::Forward)); + } + + #[test] + fn check_mtu_oversized_generates_icmp() { + let config = MtuConfig::new(1500); + let mut pkt = vec![0u8; 1500]; + pkt[0] = 0x45; // Valid IPv4 + pkt[12..16].copy_from_slice(&[10, 0, 0, 1]); + pkt[16..20].copy_from_slice(&[10, 0, 0, 2]); + + match check_mtu(&pkt, &config) { + MtuAction::SendIcmpTooBig(icmp) => { + assert_eq!(icmp[20], 3); // ICMP type + assert_eq!(icmp[21], 4); // ICMP code + } + MtuAction::Forward => panic!("Expected SendIcmpTooBig"), + } + } +} diff --git a/rust/src/qos.rs b/rust/src/qos.rs new file mode 100644 index 0000000..45ae4ad --- /dev/null +++ b/rust/src/qos.rs @@ -0,0 +1,490 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::mpsc; + +/// Priority levels for IP packets. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum Priority { + High = 0, + Normal = 1, + Low = 2, +} + +/// QoS statistics per priority level. +pub struct QosStats { + pub high_enqueued: AtomicU64, + pub normal_enqueued: AtomicU64, + pub low_enqueued: AtomicU64, + pub high_dropped: AtomicU64, + pub normal_dropped: AtomicU64, + pub low_dropped: AtomicU64, +} + +impl QosStats { + pub fn new() -> Self { + Self { + high_enqueued: AtomicU64::new(0), + normal_enqueued: AtomicU64::new(0), + low_enqueued: AtomicU64::new(0), + high_dropped: AtomicU64::new(0), + normal_dropped: AtomicU64::new(0), + low_dropped: AtomicU64::new(0), + } + } +} + +impl Default for QosStats { + fn default() -> Self { + Self::new() + } +} + +// ============================================================================ +// Packet classification +// ============================================================================ + +/// 5-tuple flow key for tracking bulk flows. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct FlowKey { + src_ip: u32, + dst_ip: u32, + src_port: u16, + dst_port: u16, + protocol: u8, +} + +/// Per-flow state for bulk detection. +struct FlowState { + bytes_total: u64, + window_start: Instant, +} + +/// Tracks per-flow byte counts for bulk flow detection. +struct FlowTracker { + flows: HashMap, + window_duration: Duration, + max_flows: usize, +} + +impl FlowTracker { + fn new(window_duration: Duration, max_flows: usize) -> Self { + Self { + flows: HashMap::new(), + window_duration, + max_flows, + } + } + + /// Record bytes for a flow. Returns true if the flow exceeds the threshold. + fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool { + let now = Instant::now(); + + // Evict if at capacity — remove oldest entry + if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) { + if let Some(oldest_key) = self + .flows + .iter() + .min_by_key(|(_, v)| v.window_start) + .map(|(k, _)| *k) + { + self.flows.remove(&oldest_key); + } + } + + let state = self.flows.entry(key).or_insert(FlowState { + bytes_total: 0, + window_start: now, + }); + + // Reset window if expired + if now.duration_since(state.window_start) > self.window_duration { + state.bytes_total = 0; + state.window_start = now; + } + + state.bytes_total += bytes; + state.bytes_total > threshold + } +} + +/// Classifies raw IP packets into priority levels. +pub struct PacketClassifier { + flow_tracker: FlowTracker, + /// Byte threshold for classifying a flow as "bulk" (Low priority). + bulk_threshold_bytes: u64, +} + +impl PacketClassifier { + /// Create a new classifier. + /// + /// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB) + pub fn new(bulk_threshold_bytes: u64) -> Self { + Self { + flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000), + bulk_threshold_bytes, + } + } + + /// Classify a raw IPv4 packet into a priority level. + /// + /// The packet must start with the IPv4 header (as read from a TUN device). + pub fn classify(&mut self, ip_packet: &[u8]) -> Priority { + // Need at least 20 bytes for a minimal IPv4 header + if ip_packet.len() < 20 { + return Priority::Normal; + } + + let version = ip_packet[0] >> 4; + if version != 4 { + return Priority::Normal; // Only classify IPv4 for now + } + + let ihl = (ip_packet[0] & 0x0F) as usize; + let header_len = ihl * 4; + let protocol = ip_packet[9]; + let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize; + + // ICMP is always high priority + if protocol == 1 { + return Priority::High; + } + + // Small packets (<128 bytes) are high priority (likely interactive) + if total_len < 128 { + return Priority::High; + } + + // Extract ports for TCP (6) and UDP (17) + let (src_port, dst_port) = if (protocol == 6 || protocol == 17) + && ip_packet.len() >= header_len + 4 + { + let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]); + let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]); + (sp, dp) + } else { + (0, 0) + }; + + // DNS (port 53) and SSH (port 22) are high priority + if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 { + return Priority::High; + } + + // Check for bulk flow + if protocol == 6 || protocol == 17 { + let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]); + let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]); + + let key = FlowKey { + src_ip, + dst_ip, + src_port, + dst_port, + protocol, + }; + + if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) { + return Priority::Low; + } + } + + Priority::Normal + } +} + +// ============================================================================ +// Priority channel set +// ============================================================================ + +/// Error returned when a packet is dropped. +#[derive(Debug)] +pub enum PacketDropped { + LowPriorityDrop, + NormalPriorityDrop, + HighPriorityDrop, + ChannelClosed, +} + +/// Sending half of the priority channel set. +pub struct PrioritySender { + high_tx: mpsc::Sender>, + normal_tx: mpsc::Sender>, + low_tx: mpsc::Sender>, + stats: Arc, +} + +impl PrioritySender { + /// Send a packet with the given priority. Implements smart dropping under backpressure. + pub async fn send(&self, packet: Vec, priority: Priority) -> Result<(), PacketDropped> { + let (tx, enqueued_counter) = match priority { + Priority::High => (&self.high_tx, &self.stats.high_enqueued), + Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued), + Priority::Low => (&self.low_tx, &self.stats.low_enqueued), + }; + + match tx.try_send(packet) { + Ok(()) => { + enqueued_counter.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + Err(mpsc::error::TrySendError::Full(packet)) => { + self.handle_backpressure(packet, priority).await + } + Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed), + } + } + + async fn handle_backpressure( + &self, + packet: Vec, + priority: Priority, + ) -> Result<(), PacketDropped> { + match priority { + Priority::Low => { + self.stats.low_dropped.fetch_add(1, Ordering::Relaxed); + Err(PacketDropped::LowPriorityDrop) + } + Priority::Normal => { + self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed); + Err(PacketDropped::NormalPriorityDrop) + } + Priority::High => { + // Last resort: briefly wait for space, then drop + match tokio::time::timeout( + Duration::from_millis(5), + self.high_tx.send(packet), + ) + .await + { + Ok(Ok(())) => { + self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + _ => { + self.stats.high_dropped.fetch_add(1, Ordering::Relaxed); + Err(PacketDropped::HighPriorityDrop) + } + } + } + } + } +} + +/// Receiving half of the priority channel set. +pub struct PriorityReceiver { + high_rx: mpsc::Receiver>, + normal_rx: mpsc::Receiver>, + low_rx: mpsc::Receiver>, +} + +impl PriorityReceiver { + /// Receive the next packet, draining high-priority first (biased select). + pub async fn recv(&mut self) -> Option> { + tokio::select! { + biased; + Some(pkt) = self.high_rx.recv() => Some(pkt), + Some(pkt) = self.normal_rx.recv() => Some(pkt), + Some(pkt) = self.low_rx.recv() => Some(pkt), + else => None, + } + } +} + +/// Create a priority channel set split into sender and receiver halves. +/// +/// - `high_cap`: capacity of the high-priority channel +/// - `normal_cap`: capacity of the normal-priority channel +/// - `low_cap`: capacity of the low-priority channel +pub fn create_priority_channels( + high_cap: usize, + normal_cap: usize, + low_cap: usize, +) -> (PrioritySender, PriorityReceiver) { + let (high_tx, high_rx) = mpsc::channel(high_cap); + let (normal_tx, normal_rx) = mpsc::channel(normal_cap); + let (low_tx, low_rx) = mpsc::channel(low_cap); + let stats = Arc::new(QosStats::new()); + + let sender = PrioritySender { + high_tx, + normal_tx, + low_tx, + stats, + }; + + let receiver = PriorityReceiver { + high_rx, + normal_rx, + low_rx, + }; + + (sender, receiver) +} + +/// Get a reference to the QoS stats from a sender. +impl PrioritySender { + pub fn stats(&self) -> &Arc { + &self.stats + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Helper: craft a minimal IPv4 packet + fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec { + let mut pkt = vec![0u8; total_len.max(24) as usize]; + pkt[0] = 0x45; // version 4, IHL 5 + pkt[2..4].copy_from_slice(&total_len.to_be_bytes()); + pkt[9] = protocol; + // src IP + pkt[12..16].copy_from_slice(&[10, 0, 0, 1]); + // dst IP + pkt[16..20].copy_from_slice(&[10, 0, 0, 2]); + // ports (at offset 20 for IHL=5) + pkt[20..22].copy_from_slice(&src_port.to_be_bytes()); + pkt[22..24].copy_from_slice(&dst_port.to_be_bytes()); + pkt + } + + #[test] + fn classify_icmp_as_high() { + let mut c = PacketClassifier::new(1_000_000); + let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP + assert_eq!(c.classify(&pkt), Priority::High); + } + + #[test] + fn classify_dns_as_high() { + let mut c = PacketClassifier::new(1_000_000); + let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53 + assert_eq!(c.classify(&pkt), Priority::High); + } + + #[test] + fn classify_ssh_as_high() { + let mut c = PacketClassifier::new(1_000_000); + let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22 + assert_eq!(c.classify(&pkt), Priority::High); + } + + #[test] + fn classify_small_packet_as_high() { + let mut c = PacketClassifier::new(1_000_000); + let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet + assert_eq!(c.classify(&pkt), Priority::High); + } + + #[test] + fn classify_normal_http() { + let mut c = PacketClassifier::new(1_000_000); + let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B + assert_eq!(c.classify(&pkt), Priority::Normal); + } + + #[test] + fn classify_bulk_flow_as_low() { + let mut c = PacketClassifier::new(10_000); // Low threshold for testing + + // Send enough traffic to exceed the threshold + for _ in 0..100 { + let pkt = make_ipv4_packet(6, 12345, 80, 500); + c.classify(&pkt); + } + + // Next packet from same flow should be Low + let pkt = make_ipv4_packet(6, 12345, 80, 500); + assert_eq!(c.classify(&pkt), Priority::Low); + } + + #[test] + fn classify_too_short_packet() { + let mut c = PacketClassifier::new(1_000_000); + let pkt = vec![0u8; 10]; // Too short for IPv4 header + assert_eq!(c.classify(&pkt), Priority::Normal); + } + + #[test] + fn classify_non_ipv4() { + let mut c = PacketClassifier::new(1_000_000); + let mut pkt = vec![0u8; 40]; + pkt[0] = 0x60; // IPv6 version nibble + assert_eq!(c.classify(&pkt), Priority::Normal); + } + + #[tokio::test] + async fn priority_receiver_drains_high_first() { + let (sender, mut receiver) = create_priority_channels(8, 8, 8); + + // Enqueue in reverse order + sender.send(vec![3], Priority::Low).await.unwrap(); + sender.send(vec![2], Priority::Normal).await.unwrap(); + sender.send(vec![1], Priority::High).await.unwrap(); + + // Should drain High first + assert_eq!(receiver.recv().await.unwrap(), vec![1]); + assert_eq!(receiver.recv().await.unwrap(), vec![2]); + assert_eq!(receiver.recv().await.unwrap(), vec![3]); + } + + #[tokio::test] + async fn smart_dropping_low_priority() { + let (sender, _receiver) = create_priority_channels(8, 8, 1); + + // Fill the low channel + sender.send(vec![0], Priority::Low).await.unwrap(); + + // Next low-priority send should be dropped + let result = sender.send(vec![1], Priority::Low).await; + assert!(matches!(result, Err(PacketDropped::LowPriorityDrop))); + + assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1); + } + + #[tokio::test] + async fn smart_dropping_normal_priority() { + let (sender, _receiver) = create_priority_channels(8, 1, 8); + + sender.send(vec![0], Priority::Normal).await.unwrap(); + + let result = sender.send(vec![1], Priority::Normal).await; + assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop))); + + assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1); + } + + #[tokio::test] + async fn stats_track_enqueued() { + let (sender, _receiver) = create_priority_channels(8, 8, 8); + + sender.send(vec![1], Priority::High).await.unwrap(); + sender.send(vec![2], Priority::High).await.unwrap(); + sender.send(vec![3], Priority::Normal).await.unwrap(); + sender.send(vec![4], Priority::Low).await.unwrap(); + + assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2); + assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1); + assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1); + } + + #[test] + fn flow_tracker_evicts_at_capacity() { + let mut ft = FlowTracker::new(Duration::from_secs(60), 2); + + let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 }; + let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 }; + let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 }; + + ft.record(k1, 100, 1000); + ft.record(k2, 100, 1000); + // Should evict k1 (oldest) + ft.record(k3, 100, 1000); + + assert_eq!(ft.flows.len(), 2); + assert!(!ft.flows.contains_key(&k1)); + } +} diff --git a/rust/src/ratelimit.rs b/rust/src/ratelimit.rs new file mode 100644 index 0000000..0720eb4 --- /dev/null +++ b/rust/src/ratelimit.rs @@ -0,0 +1,139 @@ +use std::time::Instant; + +/// A token bucket rate limiter operating on byte granularity. +pub struct TokenBucket { + /// Tokens (bytes) added per second. + rate: f64, + /// Maximum burst capacity in bytes. + burst: f64, + /// Currently available tokens. + tokens: f64, + /// Last time tokens were refilled. + last_refill: Instant, +} + +impl TokenBucket { + /// Create a new token bucket. + /// + /// - `rate_bytes_per_sec`: sustained rate in bytes/second + /// - `burst_bytes`: maximum burst size in bytes (also the initial token count) + pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self { + let burst = burst_bytes as f64; + Self { + rate: rate_bytes_per_sec as f64, + burst, + tokens: burst, // start full + last_refill: Instant::now(), + } + } + + /// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded. + pub fn try_consume(&mut self, bytes: usize) -> bool { + self.refill(); + let needed = bytes as f64; + if needed <= self.tokens { + self.tokens -= needed; + true + } else { + false + } + } + + /// Update rate and burst limits dynamically (for live IPC reconfiguration). + pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) { + self.rate = rate_bytes_per_sec as f64; + self.burst = burst_bytes as f64; + // Cap current tokens at new burst + if self.tokens > self.burst { + self.tokens = self.burst; + } + } + + fn refill(&mut self) { + let now = Instant::now(); + let elapsed = now.duration_since(self.last_refill).as_secs_f64(); + self.last_refill = now; + self.tokens = (self.tokens + elapsed * self.rate).min(self.burst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn allows_traffic_under_burst() { + let mut tb = TokenBucket::new(1_000_000, 2_000_000); + // Should allow up to burst size immediately + assert!(tb.try_consume(1_500_000)); + assert!(tb.try_consume(400_000)); + } + + #[test] + fn blocks_traffic_over_burst() { + let mut tb = TokenBucket::new(1_000_000, 1_000_000); + // Consume entire burst + assert!(tb.try_consume(1_000_000)); + // Next consume should fail (no time to refill) + assert!(!tb.try_consume(1)); + } + + #[test] + fn zero_consume_always_succeeds() { + let mut tb = TokenBucket::new(0, 0); + assert!(tb.try_consume(0)); + } + + #[test] + fn refills_over_time() { + let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst + // Drain completely + assert!(tb.try_consume(1_000_000)); + assert!(!tb.try_consume(1)); + + // Wait 100ms — should refill ~100KB + std::thread::sleep(Duration::from_millis(100)); + assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s + } + + #[test] + fn update_limits_caps_tokens() { + let mut tb = TokenBucket::new(1_000_000, 2_000_000); + // Tokens start at burst (2MB) + tb.update_limits(500_000, 500_000); + // Tokens should be capped to new burst (500KB) + assert!(tb.try_consume(500_000)); + assert!(!tb.try_consume(1)); + } + + #[test] + fn update_limits_changes_rate() { + let mut tb = TokenBucket::new(1_000_000, 1_000_000); + assert!(tb.try_consume(1_000_000)); // drain + + // Change to higher rate + tb.update_limits(10_000_000, 10_000_000); + std::thread::sleep(Duration::from_millis(50)); + // At 10MB/s, 50ms should refill ~500KB + assert!(tb.try_consume(200_000)); + } + + #[test] + fn zero_rate_blocks_after_burst() { + let mut tb = TokenBucket::new(0, 100); + assert!(tb.try_consume(100)); + std::thread::sleep(Duration::from_millis(10)); + // Zero rate means no refill + assert!(!tb.try_consume(1)); + } + + #[test] + fn tokens_do_not_exceed_burst() { + let mut tb = TokenBucket::new(1_000_000, 1_000); + // Wait to accumulate — but should cap at burst + std::thread::sleep(Duration::from_millis(50)); + assert!(tb.try_consume(1_000)); + assert!(!tb.try_consume(1)); + } +} diff --git a/rust/src/server.rs b/rust/src/server.rs index 368eb9c..580b689 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::Ipv4Addr; use std::sync::Arc; +use std::time::Duration; use tokio::net::TcpListener; use tokio::sync::{mpsc, Mutex, RwLock}; use tokio_tungstenite::tungstenite::Message; @@ -12,9 +13,14 @@ use tracing::{info, error, warn}; use crate::codec::{Frame, FrameCodec, PacketType}; use crate::crypto; +use crate::mtu::{MtuConfig, TunnelOverhead}; use crate::network::IpPool; +use crate::ratelimit::TokenBucket; use crate::transport; +/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s). +const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180); + /// Server configuration (matches TS IVpnServerConfig). #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] @@ -29,6 +35,10 @@ pub struct ServerConfig { pub mtu: Option, pub keepalive_interval_secs: Option, pub enable_nat: Option, + /// Default rate limit for new clients (bytes/sec). None = unlimited. + pub default_rate_limit_bytes_per_sec: Option, + /// Default burst size for new clients (bytes). None = unlimited. + pub default_burst_bytes: Option, } /// Information about a connected client. @@ -40,6 +50,12 @@ pub struct ClientInfo { pub connected_since: String, pub bytes_sent: u64, pub bytes_received: u64, + pub packets_dropped: u64, + pub bytes_dropped: u64, + pub last_keepalive_at: Option, + pub keepalives_received: u64, + pub rate_limit_bytes_per_sec: Option, + pub burst_bytes: Option, } /// Server statistics. @@ -63,6 +79,8 @@ pub struct ServerState { pub ip_pool: Mutex, pub clients: RwLock>, pub stats: RwLock, + pub rate_limiters: Mutex>, + pub mtu_config: MtuConfig, pub started_at: std::time::Instant, } @@ -98,11 +116,18 @@ impl VpnServer { } } + let link_mtu = config.mtu.unwrap_or(1420); + // Compute effective MTU from overhead + let overhead = TunnelOverhead::default_overhead(); + let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu)); + let state = Arc::new(ServerState { config: config.clone(), ip_pool: Mutex::new(ip_pool), clients: RwLock::new(HashMap::new()), stats: RwLock::new(ServerStatistics::default()), + rate_limiters: Mutex::new(HashMap::new()), + mtu_config, started_at: std::time::Instant::now(), }); @@ -166,11 +191,52 @@ impl VpnServer { if let Some(client) = clients.remove(client_id) { let ip: Ipv4Addr = client.assigned_ip.parse()?; state.ip_pool.lock().await.release(&ip); + state.rate_limiters.lock().await.remove(client_id); info!("Client {} disconnected", client_id); } } Ok(()) } + + /// Set a rate limit for a specific client. + pub async fn set_client_rate_limit( + &self, + client_id: &str, + rate_bytes_per_sec: u64, + burst_bytes: u64, + ) -> Result<()> { + if let Some(ref state) = self.state { + let mut limiters = state.rate_limiters.lock().await; + if let Some(limiter) = limiters.get_mut(client_id) { + limiter.update_limits(rate_bytes_per_sec, burst_bytes); + } else { + limiters.insert( + client_id.to_string(), + TokenBucket::new(rate_bytes_per_sec, burst_bytes), + ); + } + // Update client info + let mut clients = state.clients.write().await; + if let Some(info) = clients.get_mut(client_id) { + info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec); + info.burst_bytes = Some(burst_bytes); + } + } + Ok(()) + } + + /// Remove rate limit for a specific client (unlimited). + pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> { + if let Some(ref state) = self.state { + state.rate_limiters.lock().await.remove(client_id); + let mut clients = state.clients.write().await; + if let Some(info) = clients.get_mut(client_id) { + info.rate_limit_bytes_per_sec = None; + info.burst_bytes = None; + } + } + Ok(()) + } } async fn run_listener( @@ -257,25 +323,43 @@ async fn handle_client_connection( let mut noise_transport = responder.into_transport_mode()?; // Register client + let default_rate = state.config.default_rate_limit_bytes_per_sec; + let default_burst = state.config.default_burst_bytes; let client_info = ClientInfo { client_id: client_id.clone(), assigned_ip: assigned_ip.to_string(), connected_since: timestamp_now(), bytes_sent: 0, bytes_received: 0, + packets_dropped: 0, + bytes_dropped: 0, + last_keepalive_at: None, + keepalives_received: 0, + rate_limit_bytes_per_sec: default_rate, + burst_bytes: default_burst, }; state.clients.write().await.insert(client_id.clone(), client_info); + // Set up rate limiter if defaults are configured + if let (Some(rate), Some(burst)) = (default_rate, default_burst) { + state + .rate_limiters + .lock() + .await + .insert(client_id.clone(), TokenBucket::new(rate, burst)); + } + { let mut stats = state.stats.write().await; stats.total_connections += 1; } - // Send assigned IP info (encrypted) + // Send assigned IP info (encrypted), include effective MTU let ip_info = serde_json::json!({ "assignedIp": assigned_ip.to_string(), "gateway": state.ip_pool.lock().await.gateway_addr().to_string(), "mtu": state.config.mtu.unwrap_or(1420), + "effectiveMtu": state.mtu_config.effective_mtu, }); let ip_info_bytes = serde_json::to_vec(&ip_info)?; let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?; @@ -289,66 +373,116 @@ async fn handle_client_connection( info!("Client {} connected with IP {}", client_id, assigned_ip); - // Main packet loop + // Main packet loop with dead-peer detection + let mut last_activity = tokio::time::Instant::now(); + loop { - match ws_stream.next().await { - Some(Ok(Message::Binary(data))) => { - let mut frame_buf = BytesMut::from(&data[..][..]); - match ::decode(&mut FrameCodec, &mut frame_buf) { - Ok(Some(frame)) => match frame.packet_type { - PacketType::IpPacket => { - match noise_transport.read_message(&frame.payload, &mut buf) { - Ok(len) => { - let mut stats = state.stats.write().await; - stats.bytes_received += len as u64; - stats.packets_received += 1; + tokio::select! { + msg = ws_stream.next() => { + match msg { + Some(Ok(Message::Binary(data))) => { + last_activity = tokio::time::Instant::now(); + let mut frame_buf = BytesMut::from(&data[..][..]); + match ::decode(&mut FrameCodec, &mut frame_buf) { + Ok(Some(frame)) => match frame.packet_type { + PacketType::IpPacket => { + match noise_transport.read_message(&frame.payload, &mut buf) { + Ok(len) => { + // Rate limiting check + let allowed = { + let mut limiters = state.rate_limiters.lock().await; + if let Some(limiter) = limiters.get_mut(&client_id) { + limiter.try_consume(len) + } else { + true + } + }; + + if !allowed { + let mut clients = state.clients.write().await; + if let Some(info) = clients.get_mut(&client_id) { + info.packets_dropped += 1; + info.bytes_dropped += len as u64; + } + continue; + } + + let mut stats = state.stats.write().await; + stats.bytes_received += len as u64; + stats.packets_received += 1; + + // Update per-client stats + drop(stats); + let mut clients = state.clients.write().await; + if let Some(info) = clients.get_mut(&client_id) { + info.bytes_received += len as u64; + } + } + Err(e) => { + warn!("Decrypt error from {}: {}", client_id, e); + break; + } + } } - Err(e) => { - warn!("Decrypt error from {}: {}", client_id, e); + PacketType::Keepalive => { + // Echo the keepalive payload back in the ACK + let ack_frame = Frame { + packet_type: PacketType::KeepaliveAck, + payload: frame.payload.clone(), + }; + let mut frame_bytes = BytesMut::new(); + >::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?; + ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?; + + let mut stats = state.stats.write().await; + stats.keepalives_received += 1; + stats.keepalives_sent += 1; + + // Update per-client keepalive tracking + drop(stats); + let mut clients = state.clients.write().await; + if let Some(info) = clients.get_mut(&client_id) { + info.last_keepalive_at = Some(timestamp_now()); + info.keepalives_received += 1; + } + } + PacketType::Disconnect => { + info!("Client {} sent disconnect", client_id); break; } + _ => { + warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type); + } + }, + Ok(None) => { + warn!("Incomplete frame from {}", client_id); + } + Err(e) => { + warn!("Frame decode error from {}: {}", client_id, e); + break; } } - PacketType::Keepalive => { - let ack_frame = Frame { - packet_type: PacketType::KeepaliveAck, - payload: vec![], - }; - let mut frame_bytes = BytesMut::new(); - >::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?; - ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?; - - let mut stats = state.stats.write().await; - stats.keepalives_received += 1; - stats.keepalives_sent += 1; - } - PacketType::Disconnect => { - info!("Client {} sent disconnect", client_id); - break; - } - _ => { - warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type); - } - }, - Ok(None) => { - warn!("Incomplete frame from {}", client_id); } - Err(e) => { - warn!("Frame decode error from {}: {}", client_id, e); + Some(Ok(Message::Close(_))) | None => { + info!("Client {} connection closed", client_id); + break; + } + Some(Ok(Message::Ping(data))) => { + last_activity = tokio::time::Instant::now(); + ws_sink.send(Message::Pong(data)).await?; + } + Some(Ok(_)) => { + last_activity = tokio::time::Instant::now(); + continue; + } + Some(Err(e)) => { + warn!("WebSocket error from {}: {}", client_id, e); break; } } } - Some(Ok(Message::Close(_))) | None => { - info!("Client {} connection closed", client_id); - break; - } - Some(Ok(Message::Ping(data))) => { - ws_sink.send(Message::Pong(data)).await?; - } - Some(Ok(_)) => continue, - Some(Err(e)) => { - warn!("WebSocket error from {}: {}", client_id, e); + _ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => { + warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs()); break; } } @@ -357,6 +491,7 @@ async fn handle_client_connection( // Cleanup state.clients.write().await.remove(&client_id); state.ip_pool.lock().await.release(&assigned_ip); + state.rate_limiters.lock().await.remove(&client_id); info!("Client {} disconnected, released IP {}", client_id, assigned_ip); Ok(()) diff --git a/rust/src/telemetry.rs b/rust/src/telemetry.rs new file mode 100644 index 0000000..d131b97 --- /dev/null +++ b/rust/src/telemetry.rs @@ -0,0 +1,317 @@ +use serde::Serialize; +use std::collections::VecDeque; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +/// A single RTT sample. +#[derive(Debug, Clone)] +struct RttSample { + _rtt: Duration, + _timestamp: Instant, + was_timeout: bool, +} + +/// Snapshot of connection quality metrics. +#[derive(Debug, Clone, Serialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct ConnectionQuality { + /// Smoothed RTT in milliseconds (EMA, RFC 6298 style). + pub srtt_ms: f64, + /// Jitter in milliseconds (mean deviation of RTT). + pub jitter_ms: f64, + /// Minimum RTT observed in the sample window. + pub min_rtt_ms: f64, + /// Maximum RTT observed in the sample window. + pub max_rtt_ms: f64, + /// Packet loss ratio over the sample window (0.0 - 1.0). + pub loss_ratio: f64, + /// Number of consecutive keepalive timeouts (0 if last succeeded). + pub consecutive_timeouts: u32, + /// Total keepalives sent. + pub keepalives_sent: u64, + /// Total keepalive ACKs received. + pub keepalives_acked: u64, +} + +/// Tracks connection quality from keepalive round-trips. +pub struct RttTracker { + /// Maximum number of samples to keep in the window. + max_samples: usize, + /// Recent RTT samples (including timeout markers). + samples: VecDeque, + /// When the last keepalive was sent (for computing RTT on ACK). + pending_ping_sent_at: Option, + /// Number of consecutive keepalive timeouts. + pub consecutive_timeouts: u32, + /// Smoothed RTT (EMA). + srtt: Option, + /// Jitter (mean deviation). + jitter: f64, + /// Minimum RTT observed. + min_rtt: f64, + /// Maximum RTT observed. + max_rtt: f64, + /// Total keepalives sent. + keepalives_sent: u64, + /// Total keepalive ACKs received. + keepalives_acked: u64, + /// Previous RTT sample for jitter calculation. + last_rtt_ms: Option, +} + +impl RttTracker { + /// Create a new tracker with the given window size. + pub fn new(max_samples: usize) -> Self { + Self { + max_samples, + samples: VecDeque::with_capacity(max_samples), + pending_ping_sent_at: None, + consecutive_timeouts: 0, + srtt: None, + jitter: 0.0, + min_rtt: f64::MAX, + max_rtt: 0.0, + keepalives_sent: 0, + keepalives_acked: 0, + last_rtt_ms: None, + } + } + + /// Record that a keepalive was sent. + /// Returns a millisecond timestamp (since UNIX epoch) to embed in the keepalive payload. + pub fn mark_ping_sent(&mut self) -> u64 { + self.pending_ping_sent_at = Some(Instant::now()); + self.keepalives_sent += 1; + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64 + } + + /// Record that a keepalive ACK was received with the echoed timestamp. + /// Returns the computed RTT if a pending ping was recorded. + pub fn record_ack(&mut self, _echoed_timestamp_ms: u64) -> Option { + let sent_at = self.pending_ping_sent_at.take()?; + let rtt = sent_at.elapsed(); + let rtt_ms = rtt.as_secs_f64() * 1000.0; + + self.keepalives_acked += 1; + self.consecutive_timeouts = 0; + + // Update SRTT (RFC 6298: alpha = 1/8) + match self.srtt { + None => { + self.srtt = Some(rtt_ms); + self.jitter = rtt_ms / 2.0; + } + Some(prev_srtt) => { + // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| (beta = 1/4) + self.jitter = 0.75 * self.jitter + 0.25 * (prev_srtt - rtt_ms).abs(); + // SRTT = (1 - alpha) * SRTT + alpha * R (alpha = 1/8) + self.srtt = Some(0.875 * prev_srtt + 0.125 * rtt_ms); + } + } + + // Update min/max + if rtt_ms < self.min_rtt { + self.min_rtt = rtt_ms; + } + if rtt_ms > self.max_rtt { + self.max_rtt = rtt_ms; + } + + self.last_rtt_ms = Some(rtt_ms); + + // Push sample into window + if self.samples.len() >= self.max_samples { + self.samples.pop_front(); + } + self.samples.push_back(RttSample { + _rtt: rtt, + _timestamp: Instant::now(), + was_timeout: false, + }); + + Some(rtt) + } + + /// Record that a keepalive timed out (no ACK received). + pub fn record_timeout(&mut self) { + self.consecutive_timeouts += 1; + self.pending_ping_sent_at = None; + + if self.samples.len() >= self.max_samples { + self.samples.pop_front(); + } + self.samples.push_back(RttSample { + _rtt: Duration::ZERO, + _timestamp: Instant::now(), + was_timeout: true, + }); + } + + /// Get a snapshot of the current connection quality. + pub fn snapshot(&self) -> ConnectionQuality { + let loss_ratio = if self.samples.is_empty() { + 0.0 + } else { + let timeouts = self.samples.iter().filter(|s| s.was_timeout).count(); + timeouts as f64 / self.samples.len() as f64 + }; + + ConnectionQuality { + srtt_ms: self.srtt.unwrap_or(0.0), + jitter_ms: self.jitter, + min_rtt_ms: if self.min_rtt == f64::MAX { 0.0 } else { self.min_rtt }, + max_rtt_ms: self.max_rtt, + loss_ratio, + consecutive_timeouts: self.consecutive_timeouts, + keepalives_sent: self.keepalives_sent, + keepalives_acked: self.keepalives_acked, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_tracker_has_zero_quality() { + let tracker = RttTracker::new(30); + let q = tracker.snapshot(); + assert_eq!(q.srtt_ms, 0.0); + assert_eq!(q.jitter_ms, 0.0); + assert_eq!(q.loss_ratio, 0.0); + assert_eq!(q.consecutive_timeouts, 0); + assert_eq!(q.keepalives_sent, 0); + assert_eq!(q.keepalives_acked, 0); + } + + #[test] + fn mark_ping_returns_timestamp() { + let mut tracker = RttTracker::new(30); + let ts = tracker.mark_ping_sent(); + // Should be a reasonable epoch-ms value (after 2020) + assert!(ts > 1_577_836_800_000); + assert_eq!(tracker.keepalives_sent, 1); + } + + #[test] + fn record_ack_computes_rtt() { + let mut tracker = RttTracker::new(30); + let ts = tracker.mark_ping_sent(); + std::thread::sleep(Duration::from_millis(5)); + let rtt = tracker.record_ack(ts); + assert!(rtt.is_some()); + let rtt = rtt.unwrap(); + assert!(rtt.as_millis() >= 4); // at least ~5ms minus scheduling jitter + assert_eq!(tracker.keepalives_acked, 1); + assert_eq!(tracker.consecutive_timeouts, 0); + } + + #[test] + fn record_ack_without_pending_returns_none() { + let mut tracker = RttTracker::new(30); + assert!(tracker.record_ack(12345).is_none()); + } + + #[test] + fn srtt_converges() { + let mut tracker = RttTracker::new(30); + + // Simulate several ping/ack cycles with ~10ms RTT + for _ in 0..10 { + let ts = tracker.mark_ping_sent(); + std::thread::sleep(Duration::from_millis(10)); + tracker.record_ack(ts); + } + + let q = tracker.snapshot(); + // SRTT should be roughly 10ms (allowing for scheduling variance) + assert!(q.srtt_ms > 5.0, "SRTT too low: {}", q.srtt_ms); + assert!(q.srtt_ms < 50.0, "SRTT too high: {}", q.srtt_ms); + } + + #[test] + fn timeout_increments_counter_and_loss() { + let mut tracker = RttTracker::new(30); + + tracker.mark_ping_sent(); + tracker.record_timeout(); + assert_eq!(tracker.consecutive_timeouts, 1); + + tracker.mark_ping_sent(); + tracker.record_timeout(); + assert_eq!(tracker.consecutive_timeouts, 2); + + let q = tracker.snapshot(); + assert_eq!(q.loss_ratio, 1.0); // 2 timeouts out of 2 samples + } + + #[test] + fn ack_resets_consecutive_timeouts() { + let mut tracker = RttTracker::new(30); + + tracker.mark_ping_sent(); + tracker.record_timeout(); + assert_eq!(tracker.consecutive_timeouts, 1); + + let ts = tracker.mark_ping_sent(); + tracker.record_ack(ts); + assert_eq!(tracker.consecutive_timeouts, 0); + } + + #[test] + fn loss_ratio_over_mixed_window() { + let mut tracker = RttTracker::new(30); + + // 3 successful, 1 timeout, 1 successful = 1/5 = 0.2 loss + for _ in 0..3 { + let ts = tracker.mark_ping_sent(); + tracker.record_ack(ts); + } + tracker.mark_ping_sent(); + tracker.record_timeout(); + let ts = tracker.mark_ping_sent(); + tracker.record_ack(ts); + + let q = tracker.snapshot(); + assert!((q.loss_ratio - 0.2).abs() < 0.01); + } + + #[test] + fn window_evicts_old_samples() { + let mut tracker = RttTracker::new(5); + + // Fill window with 5 timeouts + for _ in 0..5 { + tracker.mark_ping_sent(); + tracker.record_timeout(); + } + assert_eq!(tracker.snapshot().loss_ratio, 1.0); + + // Add 5 successes — should evict all timeouts + for _ in 0..5 { + let ts = tracker.mark_ping_sent(); + tracker.record_ack(ts); + } + assert_eq!(tracker.snapshot().loss_ratio, 0.0); + } + + #[test] + fn min_max_rtt_tracked() { + let mut tracker = RttTracker::new(30); + + let ts = tracker.mark_ping_sent(); + std::thread::sleep(Duration::from_millis(5)); + tracker.record_ack(ts); + + let ts = tracker.mark_ping_sent(); + std::thread::sleep(Duration::from_millis(15)); + tracker.record_ack(ts); + + let q = tracker.snapshot(); + assert!(q.min_rtt_ms < q.max_rtt_ms); + assert!(q.min_rtt_ms > 0.0); + } +} diff --git a/rust/src/tunnel.rs b/rust/src/tunnel.rs index ed29bba..faa221a 100644 --- a/rust/src/tunnel.rs +++ b/rust/src/tunnel.rs @@ -64,6 +64,22 @@ pub async fn add_route(subnet: &str, device_name: &str) -> Result<()> { Ok(()) } +/// Action to take after checking a packet against the MTU. +pub enum TunMtuAction { + /// Packet is within MTU limits, forward it. + Forward, + /// Packet is oversized; the Vec contains the ICMP too-big message to write back into TUN. + IcmpTooBig(Vec), +} + +/// Check a TUN packet against the MTU and return the appropriate action. +pub fn check_tun_mtu(packet: &[u8], mtu_config: &crate::mtu::MtuConfig) -> TunMtuAction { + match crate::mtu::check_mtu(packet, mtu_config) { + crate::mtu::MtuAction::Forward => TunMtuAction::Forward, + crate::mtu::MtuAction::SendIcmpTooBig(icmp) => TunMtuAction::IcmpTooBig(icmp), + } +} + /// Remove a route. pub async fn remove_route(subnet: &str, device_name: &str) -> Result<()> { let output = tokio::process::Command::new("ip") diff --git a/ts/00_commitinfo_data.ts b/ts/00_commitinfo_data.ts index e99a78d..1500640 100644 --- a/ts/00_commitinfo_data.ts +++ b/ts/00_commitinfo_data.ts @@ -3,6 +3,6 @@ */ export const commitinfo = { name: '@push.rocks/smartvpn', - version: '1.0.3', + version: '1.1.0', description: 'A VPN solution with TypeScript control plane and Rust data plane daemon' } diff --git a/ts/smartvpn.classes.vpnclient.ts b/ts/smartvpn.classes.vpnclient.ts index 5aa456b..4d0536d 100644 --- a/ts/smartvpn.classes.vpnclient.ts +++ b/ts/smartvpn.classes.vpnclient.ts @@ -5,6 +5,8 @@ import type { IVpnClientConfig, IVpnStatus, IVpnStatistics, + IVpnConnectionQuality, + IVpnMtuInfo, TVpnClientCommands, } from './smartvpn.interfaces.js'; @@ -65,12 +67,26 @@ export class VpnClient extends plugins.events.EventEmitter { } /** - * Get traffic statistics. + * Get traffic statistics (includes connection quality when connected). */ public async getStatistics(): Promise { return this.bridge.sendCommand('getStatistics', {} as Record); } + /** + * Get connection quality metrics (RTT, jitter, loss, link health). + */ + public async getConnectionQuality(): Promise { + return this.bridge.sendCommand('getConnectionQuality', {} as Record); + } + + /** + * Get MTU information (overhead, effective MTU, oversized packet stats). + */ + public async getMtuInfo(): Promise { + return this.bridge.sendCommand('getMtuInfo', {} as Record); + } + /** * Stop the daemon bridge. */ diff --git a/ts/smartvpn.classes.vpnserver.ts b/ts/smartvpn.classes.vpnserver.ts index f472e37..7e36ca2 100644 --- a/ts/smartvpn.classes.vpnserver.ts +++ b/ts/smartvpn.classes.vpnserver.ts @@ -7,6 +7,7 @@ import type { IVpnServerStatistics, IVpnClientInfo, IVpnKeypair, + IVpnClientTelemetry, TVpnServerCommands, } from './smartvpn.interfaces.js'; @@ -91,6 +92,35 @@ export class VpnServer extends plugins.events.EventEmitter { return this.bridge.sendCommand('generateKeypair', {} as Record); } + /** + * Set rate limit for a specific client. + */ + public async setClientRateLimit( + clientId: string, + rateBytesPerSec: number, + burstBytes: number, + ): Promise { + await this.bridge.sendCommand('setClientRateLimit', { + clientId, + rateBytesPerSec, + burstBytes, + }); + } + + /** + * Remove rate limit for a specific client (unlimited). + */ + public async removeClientRateLimit(clientId: string): Promise { + await this.bridge.sendCommand('removeClientRateLimit', { clientId }); + } + + /** + * Get telemetry for a specific client. + */ + public async getClientTelemetry(clientId: string): Promise { + return this.bridge.sendCommand('getClientTelemetry', { clientId }); + } + /** * Stop the daemon bridge. */ diff --git a/ts/smartvpn.interfaces.ts b/ts/smartvpn.interfaces.ts index 52ceb9b..0da6e48 100644 --- a/ts/smartvpn.interfaces.ts +++ b/ts/smartvpn.interfaces.ts @@ -64,6 +64,10 @@ export interface IVpnServerConfig { keepaliveIntervalSecs?: number; /** Enable NAT/masquerade for client traffic */ enableNat?: boolean; + /** Default rate limit for new clients (bytes/sec). Omit for unlimited. */ + defaultRateLimitBytesPerSec?: number; + /** Default burst size for new clients (bytes). Omit for unlimited. */ + defaultBurstBytes?: number; } export interface IVpnServerOptions { @@ -99,6 +103,7 @@ export interface IVpnStatistics { keepalivesSent: number; keepalivesReceived: number; uptimeSeconds: number; + quality?: IVpnConnectionQuality; } export interface IVpnClientInfo { @@ -107,6 +112,12 @@ export interface IVpnClientInfo { connectedSince: string; bytesSent: number; bytesReceived: number; + packetsDropped: number; + bytesDropped: number; + lastKeepaliveAt?: string; + keepalivesReceived: number; + rateLimitBytesPerSec?: number; + burstBytes?: number; } export interface IVpnServerStatistics extends IVpnStatistics { @@ -119,6 +130,53 @@ export interface IVpnKeypair { privateKey: string; } +// ============================================================================ +// QoS: Connection quality +// ============================================================================ + +export type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical'; + +export interface IVpnConnectionQuality { + srttMs: number; + jitterMs: number; + minRttMs: number; + maxRttMs: number; + lossRatio: number; + consecutiveTimeouts: number; + linkHealth: TVpnLinkHealth; + currentKeepaliveIntervalSecs: number; +} + +// ============================================================================ +// QoS: MTU info +// ============================================================================ + +export interface IVpnMtuInfo { + tunMtu: number; + effectiveMtu: number; + linkMtu: number; + overheadBytes: number; + oversizedPacketsDropped: number; + icmpTooBigSent: number; +} + +// ============================================================================ +// QoS: Client telemetry (server-side per-client) +// ============================================================================ + +export interface IVpnClientTelemetry { + clientId: string; + assignedIp: string; + lastKeepaliveAt?: string; + keepalivesReceived: number; + packetsDropped: number; + bytesDropped: number; + bytesReceived: number; + bytesSent: number; + rateLimitBytesPerSec?: number; + burstBytes?: number; +} + // ============================================================================ // IPC Command maps (used by smartrust RustBridge) // ============================================================================ @@ -128,6 +186,8 @@ export type TVpnClientCommands = { disconnect: { params: Record; result: void }; getStatus: { params: Record; result: IVpnStatus }; getStatistics: { params: Record; result: IVpnStatistics }; + getConnectionQuality: { params: Record; result: IVpnConnectionQuality }; + getMtuInfo: { params: Record; result: IVpnMtuInfo }; }; export type TVpnServerCommands = { @@ -138,6 +198,9 @@ export type TVpnServerCommands = { listClients: { params: Record; result: { clients: IVpnClientInfo[] } }; disconnectClient: { params: { clientId: string }; result: void }; generateKeypair: { params: Record; result: IVpnKeypair }; + setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void }; + removeClientRateLimit: { params: { clientId: string }; result: void }; + getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry }; }; // ============================================================================