feat(rust-core): add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs

This commit is contained in:
2026-03-15 18:10:25 +00:00
parent 97bb148063
commit 9ee41348e0
15 changed files with 2152 additions and 101 deletions

View File

@@ -3,13 +3,14 @@ use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, RwLock};
use tokio::sync::{mpsc, watch, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn};
use tracing::{info, error, warn, debug};
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
use crate::telemetry::ConnectionQuality;
use crate::transport;
/// Client configuration (matches TS IVpnClientConfig).
@@ -65,6 +66,8 @@ pub struct VpnClient {
assigned_ip: Arc<RwLock<Option<String>>>,
shutdown_tx: Option<mpsc::Sender<()>>,
connected_since: Arc<RwLock<Option<std::time::Instant>>>,
quality_rx: Option<watch::Receiver<ConnectionQuality>>,
link_health: Arc<RwLock<LinkHealth>>,
}
impl VpnClient {
@@ -75,6 +78,8 @@ impl VpnClient {
assigned_ip: Arc::new(RwLock::new(None)),
shutdown_tx: None,
connected_since: Arc::new(RwLock::new(None)),
quality_rx: None,
link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
}
}
@@ -93,6 +98,7 @@ impl VpnClient {
let stats = self.stats.clone();
let assigned_ip_ref = self.assigned_ip.clone();
let connected_since = self.connected_since.clone();
let link_health = self.link_health.clone();
// Decode server public key
let server_pub_key = base64::Engine::decode(
@@ -161,6 +167,13 @@ impl VpnClient {
info!("Connected to VPN, assigned IP: {}", assigned_ip);
// Create adaptive keepalive monitor
let (monitor, handle) = keepalive::create_keepalive(None);
self.quality_rx = Some(handle.quality_rx);
// Spawn the keepalive monitor
tokio::spawn(monitor.run());
// Spawn packet forwarding loop
let assigned_ip_clone = assigned_ip.clone();
tokio::spawn(client_loop(
@@ -170,7 +183,9 @@ impl VpnClient {
state,
stats,
shutdown_rx,
config.keepalive_interval_secs.unwrap_or(30),
handle.signal_rx,
handle.ack_tx,
link_health,
));
Ok(assigned_ip_clone)
@@ -184,6 +199,7 @@ impl VpnClient {
*self.assigned_ip.write().await = None;
*self.connected_since.write().await = None;
*self.state.write().await = ClientState::Disconnected;
self.quality_rx = None;
info!("Disconnected from VPN");
Ok(())
}
@@ -208,13 +224,14 @@ impl VpnClient {
status
}
/// Get traffic statistics.
/// Get traffic statistics (includes connection quality).
pub async fn get_statistics(&self) -> serde_json::Value {
let stats = self.stats.read().await;
let since = self.connected_since.read().await;
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
let health = self.link_health.read().await;
serde_json::json!({
let mut result = serde_json::json!({
"bytesSent": stats.bytes_sent,
"bytesReceived": stats.bytes_received,
"packetsSent": stats.packets_sent,
@@ -222,7 +239,35 @@ impl VpnClient {
"keepalivesSent": stats.keepalives_sent,
"keepalivesReceived": stats.keepalives_received,
"uptimeSeconds": uptime,
})
});
// Include connection quality if available
if let Some(ref rx) = self.quality_rx {
let quality = rx.borrow().clone();
result["quality"] = serde_json::json!({
"srttMs": quality.srtt_ms,
"jitterMs": quality.jitter_ms,
"minRttMs": quality.min_rtt_ms,
"maxRttMs": quality.max_rtt_ms,
"lossRatio": quality.loss_ratio,
"consecutiveTimeouts": quality.consecutive_timeouts,
"linkHealth": format!("{}", *health),
"keepalivesSent": quality.keepalives_sent,
"keepalivesAcked": quality.keepalives_acked,
});
}
result
}
/// Get connection quality snapshot.
pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
}
/// Get current link health.
pub async fn get_link_health(&self) -> LinkHealth {
*self.link_health.read().await
}
}
@@ -234,11 +279,11 @@ async fn client_loop(
state: Arc<RwLock<ClientState>>,
stats: Arc<RwLock<ClientStatistics>>,
mut shutdown_rx: mpsc::Receiver<()>,
keepalive_secs: u64,
mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
ack_tx: mpsc::Sender<()>,
link_health: Arc<RwLock<LinkHealth>>,
) {
let mut buf = vec![0u8; 65535];
let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs));
keepalive_ticker.tick().await; // skip first immediate tick
loop {
tokio::select! {
@@ -264,6 +309,8 @@ async fn client_loop(
}
PacketType::KeepaliveAck => {
stats.write().await.keepalives_received += 1;
// Signal the keepalive monitor that ACK was received
let _ = ack_tx.send(()).await;
}
PacketType::Disconnect => {
info!("Server sent disconnect");
@@ -290,19 +337,37 @@ async fn client_loop(
}
}
}
_ = keepalive_ticker.tick() => {
let ka_frame = Frame {
packet_type: PacketType::Keepalive,
payload: vec![],
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
warn!("Failed to send keepalive");
signal = signal_rx.recv() => {
match signal {
Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
// Embed the timestamp in the keepalive payload (8 bytes, big-endian)
let ka_frame = Frame {
packet_type: PacketType::Keepalive,
payload: timestamp_ms.to_be_bytes().to_vec(),
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
warn!("Failed to send keepalive");
*state.write().await = ClientState::Disconnected;
break;
}
stats.write().await.keepalives_sent += 1;
}
}
Some(KeepaliveSignal::PeerDead) => {
warn!("Peer declared dead by keepalive monitor");
*state.write().await = ClientState::Disconnected;
break;
}
stats.write().await.keepalives_sent += 1;
Some(KeepaliveSignal::LinkHealthChanged(health)) => {
debug!("Link health changed to: {}", health);
*link_health.write().await = health;
}
None => {
// Keepalive monitor channel closed
break;
}
}
}
_ = shutdown_rx.recv() => {

View File

@@ -1,87 +1,464 @@
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, watch};
use tokio::time::{interval, timeout};
use tracing::{debug, warn};
use tracing::{debug, info, warn};
/// Default keepalive interval (30 seconds).
use crate::telemetry::{ConnectionQuality, RttTracker};
/// Default keepalive interval (30 seconds — used for Degraded state).
pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
/// Default keepalive ACK timeout (10 seconds).
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
/// Default keepalive ACK timeout (5 seconds).
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(5);
/// Link health states for adaptive keepalive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkHealth {
/// RTT stable, jitter low, no loss. Interval: 60s.
Healthy,
/// Elevated jitter or occasional loss. Interval: 30s.
Degraded,
/// High loss or sustained jitter spike. Interval: 10s.
Critical,
}
impl std::fmt::Display for LinkHealth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Healthy => write!(f, "healthy"),
Self::Degraded => write!(f, "degraded"),
Self::Critical => write!(f, "critical"),
}
}
}
/// Configuration for the adaptive keepalive state machine.
#[derive(Debug, Clone)]
pub struct AdaptiveKeepaliveConfig {
/// Interval when link health is Healthy.
pub healthy_interval: Duration,
/// Interval when link health is Degraded.
pub degraded_interval: Duration,
/// Interval when link health is Critical.
pub critical_interval: Duration,
/// ACK timeout (how long to wait for ACK before declaring timeout).
pub ack_timeout: Duration,
/// Jitter threshold (ms) to enter Degraded from Healthy.
pub jitter_degraded_ms: f64,
/// Jitter threshold (ms) to return to Healthy from Degraded.
pub jitter_healthy_ms: f64,
/// Loss ratio threshold to enter Degraded.
pub loss_degraded: f64,
/// Loss ratio threshold to enter Critical.
pub loss_critical: f64,
/// Loss ratio threshold to return from Critical to Degraded.
pub loss_recover: f64,
/// Loss ratio threshold to return from Degraded to Healthy.
pub loss_healthy: f64,
/// Consecutive checks required for upward state transitions (hysteresis).
pub upgrade_checks: u32,
/// Consecutive timeouts to declare peer dead in Critical state.
pub dead_peer_timeouts: u32,
}
impl Default for AdaptiveKeepaliveConfig {
fn default() -> Self {
Self {
healthy_interval: Duration::from_secs(60),
degraded_interval: Duration::from_secs(30),
critical_interval: Duration::from_secs(10),
ack_timeout: Duration::from_secs(5),
jitter_degraded_ms: 50.0,
jitter_healthy_ms: 30.0,
loss_degraded: 0.05,
loss_critical: 0.20,
loss_recover: 0.10,
loss_healthy: 0.02,
upgrade_checks: 3,
dead_peer_timeouts: 3,
}
}
}
/// Signals from the keepalive monitor.
#[derive(Debug, Clone)]
pub enum KeepaliveSignal {
/// Time to send a keepalive ping.
SendPing,
/// Peer is considered dead (no ACK received within timeout).
/// Time to send a keepalive ping. Contains the timestamp (ms since epoch) to embed in payload.
SendPing(u64),
/// Peer is considered dead (no ACK received within timeout repeatedly).
PeerDead,
/// Link health state changed.
LinkHealthChanged(LinkHealth),
}
/// A keepalive monitor that emits signals on a channel.
/// A keepalive monitor with adaptive interval and RTT tracking.
pub struct KeepaliveMonitor {
interval: Duration,
timeout_duration: Duration,
config: AdaptiveKeepaliveConfig,
health: LinkHealth,
rtt_tracker: RttTracker,
signal_tx: mpsc::Sender<KeepaliveSignal>,
ack_rx: mpsc::Receiver<()>,
quality_tx: watch::Sender<ConnectionQuality>,
consecutive_upgrade_checks: u32,
}
/// Handle returned to the caller to send ACKs and receive signals.
pub struct KeepaliveHandle {
pub signal_rx: mpsc::Receiver<KeepaliveSignal>,
pub ack_tx: mpsc::Sender<()>,
pub quality_rx: watch::Receiver<ConnectionQuality>,
}
/// Create a keepalive monitor and its handle.
/// Create an adaptive keepalive monitor and its handle.
pub fn create_keepalive(
keepalive_interval: Option<Duration>,
keepalive_timeout: Option<Duration>,
config: Option<AdaptiveKeepaliveConfig>,
) -> (KeepaliveMonitor, KeepaliveHandle) {
let config = config.unwrap_or_default();
let (signal_tx, signal_rx) = mpsc::channel(8);
let (ack_tx, ack_rx) = mpsc::channel(8);
let (quality_tx, quality_rx) = watch::channel(ConnectionQuality::default());
let monitor = KeepaliveMonitor {
interval: keepalive_interval.unwrap_or(DEFAULT_KEEPALIVE_INTERVAL),
timeout_duration: keepalive_timeout.unwrap_or(DEFAULT_KEEPALIVE_TIMEOUT),
config,
health: LinkHealth::Degraded, // start in Degraded, earn Healthy
rtt_tracker: RttTracker::new(30),
signal_tx,
ack_rx,
quality_tx,
consecutive_upgrade_checks: 0,
};
let handle = KeepaliveHandle { signal_rx, ack_tx };
let handle = KeepaliveHandle {
signal_rx,
ack_tx,
quality_rx,
};
(monitor, handle)
}
impl KeepaliveMonitor {
fn current_interval(&self) -> Duration {
match self.health {
LinkHealth::Healthy => self.config.healthy_interval,
LinkHealth::Degraded => self.config.degraded_interval,
LinkHealth::Critical => self.config.critical_interval,
}
}
/// Run the keepalive loop. Blocks until the peer is dead or channels close.
pub async fn run(mut self) {
let mut ticker = interval(self.interval);
let mut ticker = interval(self.current_interval());
ticker.tick().await; // skip first immediate tick
loop {
ticker.tick().await;
debug!("Sending keepalive ping signal");
if self.signal_tx.send(KeepaliveSignal::SendPing).await.is_err() {
// Channel closed
break;
// Record ping sent, get timestamp for payload
let timestamp_ms = self.rtt_tracker.mark_ping_sent();
debug!("Sending keepalive ping (ts={})", timestamp_ms);
if self
.signal_tx
.send(KeepaliveSignal::SendPing(timestamp_ms))
.await
.is_err()
{
break; // channel closed
}
// Wait for ACK within timeout
match timeout(self.timeout_duration, self.ack_rx.recv()).await {
match timeout(self.config.ack_timeout, self.ack_rx.recv()).await {
Ok(Some(())) => {
debug!("Keepalive ACK received");
if let Some(rtt) = self.rtt_tracker.record_ack(timestamp_ms) {
debug!("Keepalive ACK received, RTT: {:?}", rtt);
}
}
Ok(None) => {
// Channel closed
break;
break; // channel closed
}
Err(_) => {
warn!("Keepalive ACK timeout — peer considered dead");
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
break;
self.rtt_tracker.record_timeout();
warn!(
"Keepalive ACK timeout (consecutive: {})",
self.rtt_tracker.consecutive_timeouts
);
}
}
// Publish quality snapshot
let quality = self.rtt_tracker.snapshot();
let _ = self.quality_tx.send(quality.clone());
// Evaluate state transition
let new_health = self.evaluate_health(&quality);
if new_health != self.health {
info!("Link health: {} -> {}", self.health, new_health);
self.health = new_health;
self.consecutive_upgrade_checks = 0;
// Reset ticker to new interval
ticker = interval(self.current_interval());
ticker.tick().await; // skip first immediate tick
let _ = self
.signal_tx
.send(KeepaliveSignal::LinkHealthChanged(new_health))
.await;
}
// Check for dead peer in Critical state
if self.health == LinkHealth::Critical
&& self.rtt_tracker.consecutive_timeouts >= self.config.dead_peer_timeouts
{
warn!("Peer considered dead after {} consecutive timeouts in Critical state",
self.rtt_tracker.consecutive_timeouts);
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
break;
}
}
}
fn evaluate_health(&mut self, quality: &ConnectionQuality) -> LinkHealth {
match self.health {
LinkHealth::Healthy => {
// Downgrade conditions
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Critical;
}
if quality.jitter_ms > self.config.jitter_degraded_ms
|| quality.loss_ratio > self.config.loss_degraded
|| quality.consecutive_timeouts >= 1
{
self.consecutive_upgrade_checks = 0;
return LinkHealth::Degraded;
}
LinkHealth::Healthy
}
LinkHealth::Degraded => {
// Downgrade to Critical
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Critical;
}
// Upgrade to Healthy (with hysteresis)
if quality.jitter_ms < self.config.jitter_healthy_ms
&& quality.loss_ratio < self.config.loss_healthy
&& quality.consecutive_timeouts == 0
{
self.consecutive_upgrade_checks += 1;
if self.consecutive_upgrade_checks >= self.config.upgrade_checks {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Healthy;
}
} else {
self.consecutive_upgrade_checks = 0;
}
LinkHealth::Degraded
}
LinkHealth::Critical => {
// Upgrade to Degraded (with hysteresis), never directly to Healthy
if quality.loss_ratio < self.config.loss_recover
&& quality.consecutive_timeouts == 0
{
self.consecutive_upgrade_checks += 1;
if self.consecutive_upgrade_checks >= 2 {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Degraded;
}
} else {
self.consecutive_upgrade_checks = 0;
}
LinkHealth::Critical
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_config_values() {
let config = AdaptiveKeepaliveConfig::default();
assert_eq!(config.healthy_interval, Duration::from_secs(60));
assert_eq!(config.degraded_interval, Duration::from_secs(30));
assert_eq!(config.critical_interval, Duration::from_secs(10));
assert_eq!(config.ack_timeout, Duration::from_secs(5));
assert_eq!(config.dead_peer_timeouts, 3);
}
#[test]
fn link_health_display() {
assert_eq!(format!("{}", LinkHealth::Healthy), "healthy");
assert_eq!(format!("{}", LinkHealth::Degraded), "degraded");
assert_eq!(format!("{}", LinkHealth::Critical), "critical");
}
// Helper to create a monitor for unit-testing evaluate_health
fn make_test_monitor() -> KeepaliveMonitor {
let (signal_tx, _signal_rx) = mpsc::channel(8);
let (_ack_tx, ack_rx) = mpsc::channel(8);
let (quality_tx, _quality_rx) = watch::channel(ConnectionQuality::default());
KeepaliveMonitor {
config: AdaptiveKeepaliveConfig::default(),
health: LinkHealth::Degraded,
rtt_tracker: RttTracker::new(30),
signal_tx,
ack_rx,
quality_tx,
consecutive_upgrade_checks: 0,
}
}
#[test]
fn healthy_to_degraded_on_jitter() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
jitter_ms: 60.0, // > 50ms threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Degraded);
}
#[test]
fn healthy_to_degraded_on_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
loss_ratio: 0.06, // > 5% threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Degraded);
}
#[test]
fn healthy_to_critical_on_high_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
loss_ratio: 0.25, // > 20% threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Critical);
}
#[test]
fn healthy_to_critical_on_consecutive_timeouts() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
consecutive_timeouts: 2,
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Critical);
}
#[test]
fn degraded_to_healthy_requires_hysteresis() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let good_quality = ConnectionQuality {
jitter_ms: 10.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
srtt_ms: 20.0,
..Default::default()
};
// Should require 3 consecutive good checks (default upgrade_checks)
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
assert_eq!(m.consecutive_upgrade_checks, 1);
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
assert_eq!(m.consecutive_upgrade_checks, 2);
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Healthy);
}
#[test]
fn degraded_to_healthy_resets_on_bad_check() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let good = ConnectionQuality {
jitter_ms: 10.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
..Default::default()
};
let bad = ConnectionQuality {
jitter_ms: 60.0, // too high
loss_ratio: 0.0,
consecutive_timeouts: 0,
..Default::default()
};
m.evaluate_health(&good); // 1 check
m.evaluate_health(&good); // 2 checks
m.evaluate_health(&bad); // resets
assert_eq!(m.consecutive_upgrade_checks, 0);
}
#[test]
fn critical_to_degraded_requires_hysteresis() {
let mut m = make_test_monitor();
m.health = LinkHealth::Critical;
let recovering = ConnectionQuality {
loss_ratio: 0.05, // < 10% recover threshold
consecutive_timeouts: 0,
..Default::default()
};
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Critical);
assert_eq!(m.consecutive_upgrade_checks, 1);
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Degraded);
}
#[test]
fn critical_never_directly_to_healthy() {
let mut m = make_test_monitor();
m.health = LinkHealth::Critical;
let perfect = ConnectionQuality {
jitter_ms: 1.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
srtt_ms: 10.0,
..Default::default()
};
// Even with perfect quality, must go through Degraded first
m.evaluate_health(&perfect); // 1
let result = m.evaluate_health(&perfect); // 2 → Degraded
assert_eq!(result, LinkHealth::Degraded);
// Not Healthy yet
}
#[test]
fn degraded_to_critical_on_high_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let q = ConnectionQuality {
loss_ratio: 0.25,
..Default::default()
};
assert_eq!(m.evaluate_health(&q), LinkHealth::Critical);
}
#[test]
fn interval_matches_health() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
assert_eq!(m.current_interval(), Duration::from_secs(60));
m.health = LinkHealth::Degraded;
assert_eq!(m.current_interval(), Duration::from_secs(30));
m.health = LinkHealth::Critical;
assert_eq!(m.current_interval(), Duration::from_secs(10));
}
}

View File

@@ -11,3 +11,7 @@ pub mod network;
pub mod server;
pub mod client;
pub mod reconnect;
pub mod telemetry;
pub mod ratelimit;
pub mod qos;
pub mod mtu;

View File

@@ -285,6 +285,39 @@ async fn handle_client_request(
let stats = vpn_client.get_statistics().await;
ManagementResponse::ok(id, stats)
}
"getConnectionQuality" => {
match vpn_client.get_connection_quality() {
Some(quality) => {
let health = vpn_client.get_link_health().await;
let interval_secs = match health {
crate::keepalive::LinkHealth::Healthy => 60,
crate::keepalive::LinkHealth::Degraded => 30,
crate::keepalive::LinkHealth::Critical => 10,
};
ManagementResponse::ok(id, serde_json::json!({
"srttMs": quality.srtt_ms,
"jitterMs": quality.jitter_ms,
"minRttMs": quality.min_rtt_ms,
"maxRttMs": quality.max_rtt_ms,
"lossRatio": quality.loss_ratio,
"consecutiveTimeouts": quality.consecutive_timeouts,
"linkHealth": format!("{}", health),
"currentKeepaliveIntervalSecs": interval_secs,
}))
}
None => ManagementResponse::ok(id, serde_json::json!(null)),
}
}
"getMtuInfo" => {
ManagementResponse::ok(id, serde_json::json!({
"tunMtu": 1420,
"effectiveMtu": crate::mtu::TunnelOverhead::default_overhead().effective_tun_mtu(1500),
"linkMtu": 1500,
"overheadBytes": crate::mtu::TunnelOverhead::default_overhead().total(),
"oversizedPacketsDropped": 0,
"icmpTooBigSent": 0,
}))
}
_ => ManagementResponse::err(id, format!("Unknown client method: {}", request.method)),
}
}
@@ -349,6 +382,50 @@ async fn handle_server_request(
Err(e) => ManagementResponse::err(id, format!("Disconnect client failed: {}", e)),
}
}
"setClientRateLimit" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
let rate = match request.params.get("rateBytesPerSec").and_then(|v| v.as_u64()) {
Some(r) => r,
None => return ManagementResponse::err(id, "Missing rateBytesPerSec".to_string()),
};
let burst = match request.params.get("burstBytes").and_then(|v| v.as_u64()) {
Some(b) => b,
None => return ManagementResponse::err(id, "Missing burstBytes".to_string()),
};
match vpn_server.set_client_rate_limit(&client_id, rate, burst).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
}
}
"removeClientRateLimit" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
match vpn_server.remove_client_rate_limit(&client_id).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
}
}
"getClientTelemetry" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(cid) => cid.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
let clients = vpn_server.list_clients().await;
match clients.into_iter().find(|c| c.client_id == client_id) {
Some(info) => {
match serde_json::to_value(&info) {
Ok(v) => ManagementResponse::ok(id, v),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
}
}
None => ManagementResponse::err(id, format!("Client {} not found", client_id)),
}
}
"generateKeypair" => match crypto::generate_keypair() {
Ok((public_key, private_key)) => ManagementResponse::ok(
id,

314
rust/src/mtu.rs Normal file
View File

@@ -0,0 +1,314 @@
use std::net::Ipv4Addr;
/// Overhead breakdown for VPN tunnel encapsulation.
#[derive(Debug, Clone)]
pub struct TunnelOverhead {
/// Outer IP header: 20 bytes (IPv4, no options).
pub ip_header: u16,
/// TCP header: typically 32 bytes (20 base + 12 for timestamps).
pub tcp_header: u16,
/// WebSocket framing: ~6 bytes (2 base + 4 mask from client).
pub ws_framing: u16,
/// VPN binary frame header: 5 bytes [type:1B][length:4B].
pub vpn_header: u16,
/// Noise AEAD tag: 16 bytes (Poly1305).
pub noise_tag: u16,
}
impl TunnelOverhead {
/// Conservative default overhead estimate.
pub fn default_overhead() -> Self {
Self {
ip_header: 20,
tcp_header: 32,
ws_framing: 6,
vpn_header: 5,
noise_tag: 16,
}
}
/// Total encapsulation overhead in bytes.
pub fn total(&self) -> u16 {
self.ip_header + self.tcp_header + self.ws_framing + self.vpn_header + self.noise_tag
}
/// Compute effective TUN MTU given the underlying link MTU.
pub fn effective_tun_mtu(&self, link_mtu: u16) -> u16 {
link_mtu.saturating_sub(self.total())
}
}
/// MTU configuration for the VPN tunnel.
#[derive(Debug, Clone)]
pub struct MtuConfig {
/// Underlying link MTU (typically 1500 for Ethernet).
pub link_mtu: u16,
/// Computed effective TUN MTU.
pub effective_mtu: u16,
/// Whether to generate ICMP too-big for oversized packets.
pub send_icmp_too_big: bool,
/// Counter: oversized packets encountered.
pub oversized_packets: u64,
/// Counter: ICMP too-big messages generated.
pub icmp_too_big_sent: u64,
}
impl MtuConfig {
/// Create a new MTU config from the underlying link MTU.
pub fn new(link_mtu: u16) -> Self {
let overhead = TunnelOverhead::default_overhead();
let effective = overhead.effective_tun_mtu(link_mtu);
Self {
link_mtu,
effective_mtu: effective,
send_icmp_too_big: true,
oversized_packets: 0,
icmp_too_big_sent: 0,
}
}
/// Check if a packet exceeds the effective MTU.
pub fn is_oversized(&self, packet_len: usize) -> bool {
packet_len > self.effective_mtu as usize
}
}
/// Action to take after checking MTU.
pub enum MtuAction {
/// Packet is within MTU, forward normally.
Forward,
/// Packet is oversized; contains the ICMP too-big message to write back into TUN.
SendIcmpTooBig(Vec<u8>),
}
/// Check packet against MTU config and return the appropriate action.
pub fn check_mtu(packet: &[u8], config: &MtuConfig) -> MtuAction {
if !config.is_oversized(packet.len()) {
return MtuAction::Forward;
}
if !config.send_icmp_too_big {
return MtuAction::Forward;
}
match generate_icmp_too_big(packet, config.effective_mtu) {
Some(icmp) => MtuAction::SendIcmpTooBig(icmp),
None => MtuAction::Forward,
}
}
/// Generate an ICMPv4 Destination Unreachable / Fragmentation Needed message.
///
/// Per RFC 792: Type 3, Code 4, with next-hop MTU in bytes 6-7 (RFC 1191).
/// Returns the complete IP + ICMP packet to write back into the TUN device.
pub fn generate_icmp_too_big(original_packet: &[u8], next_hop_mtu: u16) -> Option<Vec<u8>> {
// Need at least 20 bytes of original IP header
if original_packet.len() < 20 {
return None;
}
// Verify it's IPv4
if original_packet[0] >> 4 != 4 {
return None;
}
// Parse source/dest from original IP header
let src_ip = Ipv4Addr::new(
original_packet[12],
original_packet[13],
original_packet[14],
original_packet[15],
);
let dst_ip = Ipv4Addr::new(
original_packet[16],
original_packet[17],
original_packet[18],
original_packet[19],
);
// ICMP payload: IP header + first 8 bytes of original datagram (per RFC 792)
let icmp_data_len = original_packet.len().min(28); // 20 IP header + 8 bytes
let icmp_payload = &original_packet[..icmp_data_len];
// Build ICMP message: type(1) + code(1) + checksum(2) + unused(2) + next_hop_mtu(2) + data
let mut icmp = Vec::with_capacity(8 + icmp_data_len);
icmp.push(3); // Type: Destination Unreachable
icmp.push(4); // Code: Fragmentation Needed and DF was Set
icmp.push(0); // Checksum placeholder
icmp.push(0);
icmp.push(0); // Unused
icmp.push(0);
icmp.extend_from_slice(&next_hop_mtu.to_be_bytes());
icmp.extend_from_slice(icmp_payload);
// Compute ICMP checksum
let cksum = internet_checksum(&icmp);
icmp[2] = (cksum >> 8) as u8;
icmp[3] = (cksum & 0xff) as u8;
// Build IP header (ICMP response: FROM tunnel gateway TO original source)
let total_len = (20 + icmp.len()) as u16;
let mut ip = Vec::with_capacity(total_len as usize);
ip.push(0x45); // Version 4, IHL 5
ip.push(0x00); // DSCP/ECN
ip.extend_from_slice(&total_len.to_be_bytes());
ip.extend_from_slice(&[0, 0]); // Identification
ip.extend_from_slice(&[0x40, 0x00]); // Flags: Don't Fragment, Fragment Offset: 0
ip.push(64); // TTL
ip.push(1); // Protocol: ICMP
ip.extend_from_slice(&[0, 0]); // Header checksum placeholder
ip.extend_from_slice(&dst_ip.octets()); // Source: tunnel endpoint (was dst)
ip.extend_from_slice(&src_ip.octets()); // Destination: original source
// Compute IP header checksum
let ip_cksum = internet_checksum(&ip[..20]);
ip[10] = (ip_cksum >> 8) as u8;
ip[11] = (ip_cksum & 0xff) as u8;
ip.extend_from_slice(&icmp);
Some(ip)
}
/// Standard Internet checksum (RFC 1071).
fn internet_checksum(data: &[u8]) -> u16 {
let mut sum: u32 = 0;
let mut i = 0;
while i + 1 < data.len() {
sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
i += 2;
}
if i < data.len() {
sum += (data[i] as u32) << 8;
}
while sum >> 16 != 0 {
sum = (sum & 0xFFFF) + (sum >> 16);
}
!sum as u16
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_overhead_total() {
let oh = TunnelOverhead::default_overhead();
assert_eq!(oh.total(), 79); // 20+32+6+5+16
}
#[test]
fn effective_mtu_for_ethernet() {
let oh = TunnelOverhead::default_overhead();
let mtu = oh.effective_tun_mtu(1500);
assert_eq!(mtu, 1421); // 1500 - 79
}
#[test]
fn effective_mtu_saturates_at_zero() {
let oh = TunnelOverhead::default_overhead();
let mtu = oh.effective_tun_mtu(50); // Less than overhead
assert_eq!(mtu, 0);
}
#[test]
fn mtu_config_default() {
let config = MtuConfig::new(1500);
assert_eq!(config.effective_mtu, 1421);
assert_eq!(config.link_mtu, 1500);
assert!(config.send_icmp_too_big);
}
#[test]
fn is_oversized() {
let config = MtuConfig::new(1500);
assert!(!config.is_oversized(1421));
assert!(config.is_oversized(1422));
}
#[test]
fn icmp_too_big_generation() {
// Craft a minimal IPv4 packet
let mut original = vec![0u8; 28];
original[0] = 0x45; // version 4, IHL 5
original[2..4].copy_from_slice(&1500u16.to_be_bytes()); // total length
original[9] = 6; // TCP
original[12..16].copy_from_slice(&[10, 0, 0, 1]); // src IP
original[16..20].copy_from_slice(&[10, 0, 0, 2]); // dst IP
let icmp_pkt = generate_icmp_too_big(&original, 1421).unwrap();
// Verify it's a valid IPv4 packet
assert_eq!(icmp_pkt[0] >> 4, 4); // IPv4
assert_eq!(icmp_pkt[9], 1); // ICMP protocol
// Source should be original dst (10.0.0.2)
assert_eq!(&icmp_pkt[12..16], &[10, 0, 0, 2]);
// Destination should be original src (10.0.0.1)
assert_eq!(&icmp_pkt[16..20], &[10, 0, 0, 1]);
// ICMP type 3, code 4
assert_eq!(icmp_pkt[20], 3);
assert_eq!(icmp_pkt[21], 4);
// Next-hop MTU at ICMP bytes 6-7 (offset 26-27 in IP packet)
let mtu = u16::from_be_bytes([icmp_pkt[26], icmp_pkt[27]]);
assert_eq!(mtu, 1421);
}
#[test]
fn icmp_too_big_rejects_short_packet() {
let short = vec![0u8; 10];
assert!(generate_icmp_too_big(&short, 1421).is_none());
}
#[test]
fn icmp_too_big_rejects_non_ipv4() {
let mut pkt = vec![0u8; 40];
pkt[0] = 0x60; // IPv6
assert!(generate_icmp_too_big(&pkt, 1421).is_none());
}
#[test]
fn icmp_checksum_valid() {
let mut original = vec![0u8; 28];
original[0] = 0x45;
original[2..4].copy_from_slice(&1500u16.to_be_bytes());
original[9] = 6;
original[12..16].copy_from_slice(&[192, 168, 1, 100]);
original[16..20].copy_from_slice(&[10, 8, 0, 1]);
let icmp_pkt = generate_icmp_too_big(&original, 1420).unwrap();
// Verify IP header checksum
let ip_cksum = internet_checksum(&icmp_pkt[..20]);
assert_eq!(ip_cksum, 0, "IP header checksum should verify to 0");
// Verify ICMP checksum
let icmp_cksum = internet_checksum(&icmp_pkt[20..]);
assert_eq!(icmp_cksum, 0, "ICMP checksum should verify to 0");
}
#[test]
fn check_mtu_forward() {
let config = MtuConfig::new(1500);
let pkt = vec![0u8; 1421]; // Exactly at MTU
assert!(matches!(check_mtu(&pkt, &config), MtuAction::Forward));
}
#[test]
fn check_mtu_oversized_generates_icmp() {
let config = MtuConfig::new(1500);
let mut pkt = vec![0u8; 1500];
pkt[0] = 0x45; // Valid IPv4
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
match check_mtu(&pkt, &config) {
MtuAction::SendIcmpTooBig(icmp) => {
assert_eq!(icmp[20], 3); // ICMP type
assert_eq!(icmp[21], 4); // ICMP code
}
MtuAction::Forward => panic!("Expected SendIcmpTooBig"),
}
}
}

490
rust/src/qos.rs Normal file
View File

@@ -0,0 +1,490 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
/// Priority levels for IP packets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Priority {
High = 0,
Normal = 1,
Low = 2,
}
/// QoS statistics per priority level.
pub struct QosStats {
pub high_enqueued: AtomicU64,
pub normal_enqueued: AtomicU64,
pub low_enqueued: AtomicU64,
pub high_dropped: AtomicU64,
pub normal_dropped: AtomicU64,
pub low_dropped: AtomicU64,
}
impl QosStats {
pub fn new() -> Self {
Self {
high_enqueued: AtomicU64::new(0),
normal_enqueued: AtomicU64::new(0),
low_enqueued: AtomicU64::new(0),
high_dropped: AtomicU64::new(0),
normal_dropped: AtomicU64::new(0),
low_dropped: AtomicU64::new(0),
}
}
}
impl Default for QosStats {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// Packet classification
// ============================================================================
/// 5-tuple flow key for tracking bulk flows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct FlowKey {
src_ip: u32,
dst_ip: u32,
src_port: u16,
dst_port: u16,
protocol: u8,
}
/// Per-flow state for bulk detection.
struct FlowState {
bytes_total: u64,
window_start: Instant,
}
/// Tracks per-flow byte counts for bulk flow detection.
struct FlowTracker {
flows: HashMap<FlowKey, FlowState>,
window_duration: Duration,
max_flows: usize,
}
impl FlowTracker {
fn new(window_duration: Duration, max_flows: usize) -> Self {
Self {
flows: HashMap::new(),
window_duration,
max_flows,
}
}
/// Record bytes for a flow. Returns true if the flow exceeds the threshold.
fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
let now = Instant::now();
// Evict if at capacity — remove oldest entry
if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
if let Some(oldest_key) = self
.flows
.iter()
.min_by_key(|(_, v)| v.window_start)
.map(|(k, _)| *k)
{
self.flows.remove(&oldest_key);
}
}
let state = self.flows.entry(key).or_insert(FlowState {
bytes_total: 0,
window_start: now,
});
// Reset window if expired
if now.duration_since(state.window_start) > self.window_duration {
state.bytes_total = 0;
state.window_start = now;
}
state.bytes_total += bytes;
state.bytes_total > threshold
}
}
/// Classifies raw IP packets into priority levels.
pub struct PacketClassifier {
flow_tracker: FlowTracker,
/// Byte threshold for classifying a flow as "bulk" (Low priority).
bulk_threshold_bytes: u64,
}
impl PacketClassifier {
/// Create a new classifier.
///
/// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB)
pub fn new(bulk_threshold_bytes: u64) -> Self {
Self {
flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000),
bulk_threshold_bytes,
}
}
/// Classify a raw IPv4 packet into a priority level.
///
/// The packet must start with the IPv4 header (as read from a TUN device).
pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
// Need at least 20 bytes for a minimal IPv4 header
if ip_packet.len() < 20 {
return Priority::Normal;
}
let version = ip_packet[0] >> 4;
if version != 4 {
return Priority::Normal; // Only classify IPv4 for now
}
let ihl = (ip_packet[0] & 0x0F) as usize;
let header_len = ihl * 4;
let protocol = ip_packet[9];
let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize;
// ICMP is always high priority
if protocol == 1 {
return Priority::High;
}
// Small packets (<128 bytes) are high priority (likely interactive)
if total_len < 128 {
return Priority::High;
}
// Extract ports for TCP (6) and UDP (17)
let (src_port, dst_port) = if (protocol == 6 || protocol == 17)
&& ip_packet.len() >= header_len + 4
{
let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]);
let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]);
(sp, dp)
} else {
(0, 0)
};
// DNS (port 53) and SSH (port 22) are high priority
if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 {
return Priority::High;
}
// Check for bulk flow
if protocol == 6 || protocol == 17 {
let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]);
let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]);
let key = FlowKey {
src_ip,
dst_ip,
src_port,
dst_port,
protocol,
};
if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) {
return Priority::Low;
}
}
Priority::Normal
}
}
// ============================================================================
// Priority channel set
// ============================================================================
/// Error returned when a packet is dropped.
#[derive(Debug)]
pub enum PacketDropped {
LowPriorityDrop,
NormalPriorityDrop,
HighPriorityDrop,
ChannelClosed,
}
/// Sending half of the priority channel set.
pub struct PrioritySender {
high_tx: mpsc::Sender<Vec<u8>>,
normal_tx: mpsc::Sender<Vec<u8>>,
low_tx: mpsc::Sender<Vec<u8>>,
stats: Arc<QosStats>,
}
impl PrioritySender {
/// Send a packet with the given priority. Implements smart dropping under backpressure.
pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
let (tx, enqueued_counter) = match priority {
Priority::High => (&self.high_tx, &self.stats.high_enqueued),
Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
};
match tx.try_send(packet) {
Ok(()) => {
enqueued_counter.fetch_add(1, Ordering::Relaxed);
Ok(())
}
Err(mpsc::error::TrySendError::Full(packet)) => {
self.handle_backpressure(packet, priority).await
}
Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
}
}
async fn handle_backpressure(
&self,
packet: Vec<u8>,
priority: Priority,
) -> Result<(), PacketDropped> {
match priority {
Priority::Low => {
self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::LowPriorityDrop)
}
Priority::Normal => {
self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::NormalPriorityDrop)
}
Priority::High => {
// Last resort: briefly wait for space, then drop
match tokio::time::timeout(
Duration::from_millis(5),
self.high_tx.send(packet),
)
.await
{
Ok(Ok(())) => {
self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
Ok(())
}
_ => {
self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::HighPriorityDrop)
}
}
}
}
}
}
/// Receiving half of the priority channel set.
pub struct PriorityReceiver {
high_rx: mpsc::Receiver<Vec<u8>>,
normal_rx: mpsc::Receiver<Vec<u8>>,
low_rx: mpsc::Receiver<Vec<u8>>,
}
impl PriorityReceiver {
/// Receive the next packet, draining high-priority first (biased select).
pub async fn recv(&mut self) -> Option<Vec<u8>> {
tokio::select! {
biased;
Some(pkt) = self.high_rx.recv() => Some(pkt),
Some(pkt) = self.normal_rx.recv() => Some(pkt),
Some(pkt) = self.low_rx.recv() => Some(pkt),
else => None,
}
}
}
/// Create a priority channel set split into sender and receiver halves.
///
/// - `high_cap`: capacity of the high-priority channel
/// - `normal_cap`: capacity of the normal-priority channel
/// - `low_cap`: capacity of the low-priority channel
pub fn create_priority_channels(
high_cap: usize,
normal_cap: usize,
low_cap: usize,
) -> (PrioritySender, PriorityReceiver) {
let (high_tx, high_rx) = mpsc::channel(high_cap);
let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
let (low_tx, low_rx) = mpsc::channel(low_cap);
let stats = Arc::new(QosStats::new());
let sender = PrioritySender {
high_tx,
normal_tx,
low_tx,
stats,
};
let receiver = PriorityReceiver {
high_rx,
normal_rx,
low_rx,
};
(sender, receiver)
}
/// Get a reference to the QoS stats from a sender.
impl PrioritySender {
pub fn stats(&self) -> &Arc<QosStats> {
&self.stats
}
}
#[cfg(test)]
mod tests {
use super::*;
// Helper: craft a minimal IPv4 packet
fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
let mut pkt = vec![0u8; total_len.max(24) as usize];
pkt[0] = 0x45; // version 4, IHL 5
pkt[2..4].copy_from_slice(&total_len.to_be_bytes());
pkt[9] = protocol;
// src IP
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
// dst IP
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
// ports (at offset 20 for IHL=5)
pkt[20..22].copy_from_slice(&src_port.to_be_bytes());
pkt[22..24].copy_from_slice(&dst_port.to_be_bytes());
pkt
}
#[test]
fn classify_icmp_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_dns_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_ssh_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_small_packet_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_normal_http() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[test]
fn classify_bulk_flow_as_low() {
let mut c = PacketClassifier::new(10_000); // Low threshold for testing
// Send enough traffic to exceed the threshold
for _ in 0..100 {
let pkt = make_ipv4_packet(6, 12345, 80, 500);
c.classify(&pkt);
}
// Next packet from same flow should be Low
let pkt = make_ipv4_packet(6, 12345, 80, 500);
assert_eq!(c.classify(&pkt), Priority::Low);
}
#[test]
fn classify_too_short_packet() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = vec![0u8; 10]; // Too short for IPv4 header
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[test]
fn classify_non_ipv4() {
let mut c = PacketClassifier::new(1_000_000);
let mut pkt = vec![0u8; 40];
pkt[0] = 0x60; // IPv6 version nibble
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[tokio::test]
async fn priority_receiver_drains_high_first() {
let (sender, mut receiver) = create_priority_channels(8, 8, 8);
// Enqueue in reverse order
sender.send(vec![3], Priority::Low).await.unwrap();
sender.send(vec![2], Priority::Normal).await.unwrap();
sender.send(vec![1], Priority::High).await.unwrap();
// Should drain High first
assert_eq!(receiver.recv().await.unwrap(), vec![1]);
assert_eq!(receiver.recv().await.unwrap(), vec![2]);
assert_eq!(receiver.recv().await.unwrap(), vec![3]);
}
#[tokio::test]
async fn smart_dropping_low_priority() {
let (sender, _receiver) = create_priority_channels(8, 8, 1);
// Fill the low channel
sender.send(vec![0], Priority::Low).await.unwrap();
// Next low-priority send should be dropped
let result = sender.send(vec![1], Priority::Low).await;
assert!(matches!(result, Err(PacketDropped::LowPriorityDrop)));
assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn smart_dropping_normal_priority() {
let (sender, _receiver) = create_priority_channels(8, 1, 8);
sender.send(vec![0], Priority::Normal).await.unwrap();
let result = sender.send(vec![1], Priority::Normal).await;
assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop)));
assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn stats_track_enqueued() {
let (sender, _receiver) = create_priority_channels(8, 8, 8);
sender.send(vec![1], Priority::High).await.unwrap();
sender.send(vec![2], Priority::High).await.unwrap();
sender.send(vec![3], Priority::Normal).await.unwrap();
sender.send(vec![4], Priority::Low).await.unwrap();
assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2);
assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1);
assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1);
}
#[test]
fn flow_tracker_evicts_at_capacity() {
let mut ft = FlowTracker::new(Duration::from_secs(60), 2);
let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 };
let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 };
let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 };
ft.record(k1, 100, 1000);
ft.record(k2, 100, 1000);
// Should evict k1 (oldest)
ft.record(k3, 100, 1000);
assert_eq!(ft.flows.len(), 2);
assert!(!ft.flows.contains_key(&k1));
}
}

139
rust/src/ratelimit.rs Normal file
View File

@@ -0,0 +1,139 @@
use std::time::Instant;
/// A token bucket rate limiter operating on byte granularity.
pub struct TokenBucket {
/// Tokens (bytes) added per second.
rate: f64,
/// Maximum burst capacity in bytes.
burst: f64,
/// Currently available tokens.
tokens: f64,
/// Last time tokens were refilled.
last_refill: Instant,
}
impl TokenBucket {
/// Create a new token bucket.
///
/// - `rate_bytes_per_sec`: sustained rate in bytes/second
/// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
let burst = burst_bytes as f64;
Self {
rate: rate_bytes_per_sec as f64,
burst,
tokens: burst, // start full
last_refill: Instant::now(),
}
}
/// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded.
pub fn try_consume(&mut self, bytes: usize) -> bool {
self.refill();
let needed = bytes as f64;
if needed <= self.tokens {
self.tokens -= needed;
true
} else {
false
}
}
/// Update rate and burst limits dynamically (for live IPC reconfiguration).
pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
self.rate = rate_bytes_per_sec as f64;
self.burst = burst_bytes as f64;
// Cap current tokens at new burst
if self.tokens > self.burst {
self.tokens = self.burst;
}
}
fn refill(&mut self) {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
self.last_refill = now;
self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[test]
fn allows_traffic_under_burst() {
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
// Should allow up to burst size immediately
assert!(tb.try_consume(1_500_000));
assert!(tb.try_consume(400_000));
}
#[test]
fn blocks_traffic_over_burst() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
// Consume entire burst
assert!(tb.try_consume(1_000_000));
// Next consume should fail (no time to refill)
assert!(!tb.try_consume(1));
}
#[test]
fn zero_consume_always_succeeds() {
let mut tb = TokenBucket::new(0, 0);
assert!(tb.try_consume(0));
}
#[test]
fn refills_over_time() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
// Drain completely
assert!(tb.try_consume(1_000_000));
assert!(!tb.try_consume(1));
// Wait 100ms — should refill ~100KB
std::thread::sleep(Duration::from_millis(100));
assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s
}
#[test]
fn update_limits_caps_tokens() {
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
// Tokens start at burst (2MB)
tb.update_limits(500_000, 500_000);
// Tokens should be capped to new burst (500KB)
assert!(tb.try_consume(500_000));
assert!(!tb.try_consume(1));
}
#[test]
fn update_limits_changes_rate() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
assert!(tb.try_consume(1_000_000)); // drain
// Change to higher rate
tb.update_limits(10_000_000, 10_000_000);
std::thread::sleep(Duration::from_millis(50));
// At 10MB/s, 50ms should refill ~500KB
assert!(tb.try_consume(200_000));
}
#[test]
fn zero_rate_blocks_after_burst() {
let mut tb = TokenBucket::new(0, 100);
assert!(tb.try_consume(100));
std::thread::sleep(Duration::from_millis(10));
// Zero rate means no refill
assert!(!tb.try_consume(1));
}
#[test]
fn tokens_do_not_exceed_burst() {
let mut tb = TokenBucket::new(1_000_000, 1_000);
// Wait to accumulate — but should cap at burst
std::thread::sleep(Duration::from_millis(50));
assert!(tb.try_consume(1_000));
assert!(!tb.try_consume(1));
}
}

View File

@@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_tungstenite::tungstenite::Message;
@@ -12,9 +13,14 @@ use tracing::{info, error, warn};
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::mtu::{MtuConfig, TunnelOverhead};
use crate::network::IpPool;
use crate::ratelimit::TokenBucket;
use crate::transport;
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
/// Server configuration (matches TS IVpnServerConfig).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -29,6 +35,10 @@ pub struct ServerConfig {
pub mtu: Option<u16>,
pub keepalive_interval_secs: Option<u64>,
pub enable_nat: Option<bool>,
/// Default rate limit for new clients (bytes/sec). None = unlimited.
pub default_rate_limit_bytes_per_sec: Option<u64>,
/// Default burst size for new clients (bytes). None = unlimited.
pub default_burst_bytes: Option<u64>,
}
/// Information about a connected client.
@@ -40,6 +50,12 @@ pub struct ClientInfo {
pub connected_since: String,
pub bytes_sent: u64,
pub bytes_received: u64,
pub packets_dropped: u64,
pub bytes_dropped: u64,
pub last_keepalive_at: Option<String>,
pub keepalives_received: u64,
pub rate_limit_bytes_per_sec: Option<u64>,
pub burst_bytes: Option<u64>,
}
/// Server statistics.
@@ -63,6 +79,8 @@ pub struct ServerState {
pub ip_pool: Mutex<IpPool>,
pub clients: RwLock<HashMap<String, ClientInfo>>,
pub stats: RwLock<ServerStatistics>,
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
pub mtu_config: MtuConfig,
pub started_at: std::time::Instant,
}
@@ -98,11 +116,18 @@ impl VpnServer {
}
}
let link_mtu = config.mtu.unwrap_or(1420);
// Compute effective MTU from overhead
let overhead = TunnelOverhead::default_overhead();
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
let state = Arc::new(ServerState {
config: config.clone(),
ip_pool: Mutex::new(ip_pool),
clients: RwLock::new(HashMap::new()),
stats: RwLock::new(ServerStatistics::default()),
rate_limiters: Mutex::new(HashMap::new()),
mtu_config,
started_at: std::time::Instant::now(),
});
@@ -166,11 +191,52 @@ impl VpnServer {
if let Some(client) = clients.remove(client_id) {
let ip: Ipv4Addr = client.assigned_ip.parse()?;
state.ip_pool.lock().await.release(&ip);
state.rate_limiters.lock().await.remove(client_id);
info!("Client {} disconnected", client_id);
}
}
Ok(())
}
/// Set a rate limit for a specific client.
pub async fn set_client_rate_limit(
&self,
client_id: &str,
rate_bytes_per_sec: u64,
burst_bytes: u64,
) -> Result<()> {
if let Some(ref state) = self.state {
let mut limiters = state.rate_limiters.lock().await;
if let Some(limiter) = limiters.get_mut(client_id) {
limiter.update_limits(rate_bytes_per_sec, burst_bytes);
} else {
limiters.insert(
client_id.to_string(),
TokenBucket::new(rate_bytes_per_sec, burst_bytes),
);
}
// Update client info
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(client_id) {
info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec);
info.burst_bytes = Some(burst_bytes);
}
}
Ok(())
}
/// Remove rate limit for a specific client (unlimited).
pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> {
if let Some(ref state) = self.state {
state.rate_limiters.lock().await.remove(client_id);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(client_id) {
info.rate_limit_bytes_per_sec = None;
info.burst_bytes = None;
}
}
Ok(())
}
}
async fn run_listener(
@@ -257,25 +323,43 @@ async fn handle_client_connection(
let mut noise_transport = responder.into_transport_mode()?;
// Register client
let default_rate = state.config.default_rate_limit_bytes_per_sec;
let default_burst = state.config.default_burst_bytes;
let client_info = ClientInfo {
client_id: client_id.clone(),
assigned_ip: assigned_ip.to_string(),
connected_since: timestamp_now(),
bytes_sent: 0,
bytes_received: 0,
packets_dropped: 0,
bytes_dropped: 0,
last_keepalive_at: None,
keepalives_received: 0,
rate_limit_bytes_per_sec: default_rate,
burst_bytes: default_burst,
};
state.clients.write().await.insert(client_id.clone(), client_info);
// Set up rate limiter if defaults are configured
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
state
.rate_limiters
.lock()
.await
.insert(client_id.clone(), TokenBucket::new(rate, burst));
}
{
let mut stats = state.stats.write().await;
stats.total_connections += 1;
}
// Send assigned IP info (encrypted)
// Send assigned IP info (encrypted), include effective MTU
let ip_info = serde_json::json!({
"assignedIp": assigned_ip.to_string(),
"gateway": state.ip_pool.lock().await.gateway_addr().to_string(),
"mtu": state.config.mtu.unwrap_or(1420),
"effectiveMtu": state.mtu_config.effective_mtu,
});
let ip_info_bytes = serde_json::to_vec(&ip_info)?;
let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?;
@@ -289,66 +373,116 @@ async fn handle_client_connection(
info!("Client {} connected with IP {}", client_id, assigned_ip);
// Main packet loop
// Main packet loop with dead-peer detection
let mut last_activity = tokio::time::Instant::now();
loop {
match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => {
let mut frame_buf = BytesMut::from(&data[..][..]);
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
Ok(Some(frame)) => match frame.packet_type {
PacketType::IpPacket => {
match noise_transport.read_message(&frame.payload, &mut buf) {
Ok(len) => {
let mut stats = state.stats.write().await;
stats.bytes_received += len as u64;
stats.packets_received += 1;
tokio::select! {
msg = ws_stream.next() => {
match msg {
Some(Ok(Message::Binary(data))) => {
last_activity = tokio::time::Instant::now();
let mut frame_buf = BytesMut::from(&data[..][..]);
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
Ok(Some(frame)) => match frame.packet_type {
PacketType::IpPacket => {
match noise_transport.read_message(&frame.payload, &mut buf) {
Ok(len) => {
// Rate limiting check
let allowed = {
let mut limiters = state.rate_limiters.lock().await;
if let Some(limiter) = limiters.get_mut(&client_id) {
limiter.try_consume(len)
} else {
true
}
};
if !allowed {
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.packets_dropped += 1;
info.bytes_dropped += len as u64;
}
continue;
}
let mut stats = state.stats.write().await;
stats.bytes_received += len as u64;
stats.packets_received += 1;
// Update per-client stats
drop(stats);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.bytes_received += len as u64;
}
}
Err(e) => {
warn!("Decrypt error from {}: {}", client_id, e);
break;
}
}
}
Err(e) => {
warn!("Decrypt error from {}: {}", client_id, e);
PacketType::Keepalive => {
// Echo the keepalive payload back in the ACK
let ack_frame = Frame {
packet_type: PacketType::KeepaliveAck,
payload: frame.payload.clone(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
let mut stats = state.stats.write().await;
stats.keepalives_received += 1;
stats.keepalives_sent += 1;
// Update per-client keepalive tracking
drop(stats);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.last_keepalive_at = Some(timestamp_now());
info.keepalives_received += 1;
}
}
PacketType::Disconnect => {
info!("Client {} sent disconnect", client_id);
break;
}
_ => {
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
}
},
Ok(None) => {
warn!("Incomplete frame from {}", client_id);
}
Err(e) => {
warn!("Frame decode error from {}: {}", client_id, e);
break;
}
}
PacketType::Keepalive => {
let ack_frame = Frame {
packet_type: PacketType::KeepaliveAck,
payload: vec![],
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
let mut stats = state.stats.write().await;
stats.keepalives_received += 1;
stats.keepalives_sent += 1;
}
PacketType::Disconnect => {
info!("Client {} sent disconnect", client_id);
break;
}
_ => {
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
}
},
Ok(None) => {
warn!("Incomplete frame from {}", client_id);
}
Err(e) => {
warn!("Frame decode error from {}: {}", client_id, e);
Some(Ok(Message::Close(_))) | None => {
info!("Client {} connection closed", client_id);
break;
}
Some(Ok(Message::Ping(data))) => {
last_activity = tokio::time::Instant::now();
ws_sink.send(Message::Pong(data)).await?;
}
Some(Ok(_)) => {
last_activity = tokio::time::Instant::now();
continue;
}
Some(Err(e)) => {
warn!("WebSocket error from {}: {}", client_id, e);
break;
}
}
}
Some(Ok(Message::Close(_))) | None => {
info!("Client {} connection closed", client_id);
break;
}
Some(Ok(Message::Ping(data))) => {
ws_sink.send(Message::Pong(data)).await?;
}
Some(Ok(_)) => continue,
Some(Err(e)) => {
warn!("WebSocket error from {}: {}", client_id, e);
_ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => {
warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs());
break;
}
}
@@ -357,6 +491,7 @@ async fn handle_client_connection(
// Cleanup
state.clients.write().await.remove(&client_id);
state.ip_pool.lock().await.release(&assigned_ip);
state.rate_limiters.lock().await.remove(&client_id);
info!("Client {} disconnected, released IP {}", client_id, assigned_ip);
Ok(())

317
rust/src/telemetry.rs Normal file
View File

@@ -0,0 +1,317 @@
use serde::Serialize;
use std::collections::VecDeque;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// A single RTT sample.
#[derive(Debug, Clone)]
struct RttSample {
_rtt: Duration,
_timestamp: Instant,
was_timeout: bool,
}
/// Snapshot of connection quality metrics.
#[derive(Debug, Clone, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ConnectionQuality {
/// Smoothed RTT in milliseconds (EMA, RFC 6298 style).
pub srtt_ms: f64,
/// Jitter in milliseconds (mean deviation of RTT).
pub jitter_ms: f64,
/// Minimum RTT observed in the sample window.
pub min_rtt_ms: f64,
/// Maximum RTT observed in the sample window.
pub max_rtt_ms: f64,
/// Packet loss ratio over the sample window (0.0 - 1.0).
pub loss_ratio: f64,
/// Number of consecutive keepalive timeouts (0 if last succeeded).
pub consecutive_timeouts: u32,
/// Total keepalives sent.
pub keepalives_sent: u64,
/// Total keepalive ACKs received.
pub keepalives_acked: u64,
}
/// Tracks connection quality from keepalive round-trips.
pub struct RttTracker {
/// Maximum number of samples to keep in the window.
max_samples: usize,
/// Recent RTT samples (including timeout markers).
samples: VecDeque<RttSample>,
/// When the last keepalive was sent (for computing RTT on ACK).
pending_ping_sent_at: Option<Instant>,
/// Number of consecutive keepalive timeouts.
pub consecutive_timeouts: u32,
/// Smoothed RTT (EMA).
srtt: Option<f64>,
/// Jitter (mean deviation).
jitter: f64,
/// Minimum RTT observed.
min_rtt: f64,
/// Maximum RTT observed.
max_rtt: f64,
/// Total keepalives sent.
keepalives_sent: u64,
/// Total keepalive ACKs received.
keepalives_acked: u64,
/// Previous RTT sample for jitter calculation.
last_rtt_ms: Option<f64>,
}
impl RttTracker {
/// Create a new tracker with the given window size.
pub fn new(max_samples: usize) -> Self {
Self {
max_samples,
samples: VecDeque::with_capacity(max_samples),
pending_ping_sent_at: None,
consecutive_timeouts: 0,
srtt: None,
jitter: 0.0,
min_rtt: f64::MAX,
max_rtt: 0.0,
keepalives_sent: 0,
keepalives_acked: 0,
last_rtt_ms: None,
}
}
/// Record that a keepalive was sent.
/// Returns a millisecond timestamp (since UNIX epoch) to embed in the keepalive payload.
pub fn mark_ping_sent(&mut self) -> u64 {
self.pending_ping_sent_at = Some(Instant::now());
self.keepalives_sent += 1;
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
/// Record that a keepalive ACK was received with the echoed timestamp.
/// Returns the computed RTT if a pending ping was recorded.
pub fn record_ack(&mut self, _echoed_timestamp_ms: u64) -> Option<Duration> {
let sent_at = self.pending_ping_sent_at.take()?;
let rtt = sent_at.elapsed();
let rtt_ms = rtt.as_secs_f64() * 1000.0;
self.keepalives_acked += 1;
self.consecutive_timeouts = 0;
// Update SRTT (RFC 6298: alpha = 1/8)
match self.srtt {
None => {
self.srtt = Some(rtt_ms);
self.jitter = rtt_ms / 2.0;
}
Some(prev_srtt) => {
// RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| (beta = 1/4)
self.jitter = 0.75 * self.jitter + 0.25 * (prev_srtt - rtt_ms).abs();
// SRTT = (1 - alpha) * SRTT + alpha * R (alpha = 1/8)
self.srtt = Some(0.875 * prev_srtt + 0.125 * rtt_ms);
}
}
// Update min/max
if rtt_ms < self.min_rtt {
self.min_rtt = rtt_ms;
}
if rtt_ms > self.max_rtt {
self.max_rtt = rtt_ms;
}
self.last_rtt_ms = Some(rtt_ms);
// Push sample into window
if self.samples.len() >= self.max_samples {
self.samples.pop_front();
}
self.samples.push_back(RttSample {
_rtt: rtt,
_timestamp: Instant::now(),
was_timeout: false,
});
Some(rtt)
}
/// Record that a keepalive timed out (no ACK received).
pub fn record_timeout(&mut self) {
self.consecutive_timeouts += 1;
self.pending_ping_sent_at = None;
if self.samples.len() >= self.max_samples {
self.samples.pop_front();
}
self.samples.push_back(RttSample {
_rtt: Duration::ZERO,
_timestamp: Instant::now(),
was_timeout: true,
});
}
/// Get a snapshot of the current connection quality.
pub fn snapshot(&self) -> ConnectionQuality {
let loss_ratio = if self.samples.is_empty() {
0.0
} else {
let timeouts = self.samples.iter().filter(|s| s.was_timeout).count();
timeouts as f64 / self.samples.len() as f64
};
ConnectionQuality {
srtt_ms: self.srtt.unwrap_or(0.0),
jitter_ms: self.jitter,
min_rtt_ms: if self.min_rtt == f64::MAX { 0.0 } else { self.min_rtt },
max_rtt_ms: self.max_rtt,
loss_ratio,
consecutive_timeouts: self.consecutive_timeouts,
keepalives_sent: self.keepalives_sent,
keepalives_acked: self.keepalives_acked,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_tracker_has_zero_quality() {
let tracker = RttTracker::new(30);
let q = tracker.snapshot();
assert_eq!(q.srtt_ms, 0.0);
assert_eq!(q.jitter_ms, 0.0);
assert_eq!(q.loss_ratio, 0.0);
assert_eq!(q.consecutive_timeouts, 0);
assert_eq!(q.keepalives_sent, 0);
assert_eq!(q.keepalives_acked, 0);
}
#[test]
fn mark_ping_returns_timestamp() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
// Should be a reasonable epoch-ms value (after 2020)
assert!(ts > 1_577_836_800_000);
assert_eq!(tracker.keepalives_sent, 1);
}
#[test]
fn record_ack_computes_rtt() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(5));
let rtt = tracker.record_ack(ts);
assert!(rtt.is_some());
let rtt = rtt.unwrap();
assert!(rtt.as_millis() >= 4); // at least ~5ms minus scheduling jitter
assert_eq!(tracker.keepalives_acked, 1);
assert_eq!(tracker.consecutive_timeouts, 0);
}
#[test]
fn record_ack_without_pending_returns_none() {
let mut tracker = RttTracker::new(30);
assert!(tracker.record_ack(12345).is_none());
}
#[test]
fn srtt_converges() {
let mut tracker = RttTracker::new(30);
// Simulate several ping/ack cycles with ~10ms RTT
for _ in 0..10 {
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(10));
tracker.record_ack(ts);
}
let q = tracker.snapshot();
// SRTT should be roughly 10ms (allowing for scheduling variance)
assert!(q.srtt_ms > 5.0, "SRTT too low: {}", q.srtt_ms);
assert!(q.srtt_ms < 50.0, "SRTT too high: {}", q.srtt_ms);
}
#[test]
fn timeout_increments_counter_and_loss() {
let mut tracker = RttTracker::new(30);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 1);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 2);
let q = tracker.snapshot();
assert_eq!(q.loss_ratio, 1.0); // 2 timeouts out of 2 samples
}
#[test]
fn ack_resets_consecutive_timeouts() {
let mut tracker = RttTracker::new(30);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 1);
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
assert_eq!(tracker.consecutive_timeouts, 0);
}
#[test]
fn loss_ratio_over_mixed_window() {
let mut tracker = RttTracker::new(30);
// 3 successful, 1 timeout, 1 successful = 1/5 = 0.2 loss
for _ in 0..3 {
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
}
tracker.mark_ping_sent();
tracker.record_timeout();
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
let q = tracker.snapshot();
assert!((q.loss_ratio - 0.2).abs() < 0.01);
}
#[test]
fn window_evicts_old_samples() {
let mut tracker = RttTracker::new(5);
// Fill window with 5 timeouts
for _ in 0..5 {
tracker.mark_ping_sent();
tracker.record_timeout();
}
assert_eq!(tracker.snapshot().loss_ratio, 1.0);
// Add 5 successes — should evict all timeouts
for _ in 0..5 {
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
}
assert_eq!(tracker.snapshot().loss_ratio, 0.0);
}
#[test]
fn min_max_rtt_tracked() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(5));
tracker.record_ack(ts);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(15));
tracker.record_ack(ts);
let q = tracker.snapshot();
assert!(q.min_rtt_ms < q.max_rtt_ms);
assert!(q.min_rtt_ms > 0.0);
}
}

View File

@@ -64,6 +64,22 @@ pub async fn add_route(subnet: &str, device_name: &str) -> Result<()> {
Ok(())
}
/// Action to take after checking a packet against the MTU.
pub enum TunMtuAction {
/// Packet is within MTU limits, forward it.
Forward,
/// Packet is oversized; the Vec contains the ICMP too-big message to write back into TUN.
IcmpTooBig(Vec<u8>),
}
/// Check a TUN packet against the MTU and return the appropriate action.
pub fn check_tun_mtu(packet: &[u8], mtu_config: &crate::mtu::MtuConfig) -> TunMtuAction {
match crate::mtu::check_mtu(packet, mtu_config) {
crate::mtu::MtuAction::Forward => TunMtuAction::Forward,
crate::mtu::MtuAction::SendIcmpTooBig(icmp) => TunMtuAction::IcmpTooBig(icmp),
}
}
/// Remove a route.
pub async fn remove_route(subnet: &str, device_name: &str) -> Result<()> {
let output = tokio::process::Command::new("ip")