feat(rust-core): add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs
This commit is contained in:
@@ -3,13 +3,14 @@ use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
use tracing::{info, error, warn, debug};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
|
||||
use crate::telemetry::ConnectionQuality;
|
||||
use crate::transport;
|
||||
|
||||
/// Client configuration (matches TS IVpnClientConfig).
|
||||
@@ -65,6 +66,8 @@ pub struct VpnClient {
|
||||
assigned_ip: Arc<RwLock<Option<String>>>,
|
||||
shutdown_tx: Option<mpsc::Sender<()>>,
|
||||
connected_since: Arc<RwLock<Option<std::time::Instant>>>,
|
||||
quality_rx: Option<watch::Receiver<ConnectionQuality>>,
|
||||
link_health: Arc<RwLock<LinkHealth>>,
|
||||
}
|
||||
|
||||
impl VpnClient {
|
||||
@@ -75,6 +78,8 @@ impl VpnClient {
|
||||
assigned_ip: Arc::new(RwLock::new(None)),
|
||||
shutdown_tx: None,
|
||||
connected_since: Arc::new(RwLock::new(None)),
|
||||
quality_rx: None,
|
||||
link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,6 +98,7 @@ impl VpnClient {
|
||||
let stats = self.stats.clone();
|
||||
let assigned_ip_ref = self.assigned_ip.clone();
|
||||
let connected_since = self.connected_since.clone();
|
||||
let link_health = self.link_health.clone();
|
||||
|
||||
// Decode server public key
|
||||
let server_pub_key = base64::Engine::decode(
|
||||
@@ -161,6 +167,13 @@ impl VpnClient {
|
||||
|
||||
info!("Connected to VPN, assigned IP: {}", assigned_ip);
|
||||
|
||||
// Create adaptive keepalive monitor
|
||||
let (monitor, handle) = keepalive::create_keepalive(None);
|
||||
self.quality_rx = Some(handle.quality_rx);
|
||||
|
||||
// Spawn the keepalive monitor
|
||||
tokio::spawn(monitor.run());
|
||||
|
||||
// Spawn packet forwarding loop
|
||||
let assigned_ip_clone = assigned_ip.clone();
|
||||
tokio::spawn(client_loop(
|
||||
@@ -170,7 +183,9 @@ impl VpnClient {
|
||||
state,
|
||||
stats,
|
||||
shutdown_rx,
|
||||
config.keepalive_interval_secs.unwrap_or(30),
|
||||
handle.signal_rx,
|
||||
handle.ack_tx,
|
||||
link_health,
|
||||
));
|
||||
|
||||
Ok(assigned_ip_clone)
|
||||
@@ -184,6 +199,7 @@ impl VpnClient {
|
||||
*self.assigned_ip.write().await = None;
|
||||
*self.connected_since.write().await = None;
|
||||
*self.state.write().await = ClientState::Disconnected;
|
||||
self.quality_rx = None;
|
||||
info!("Disconnected from VPN");
|
||||
Ok(())
|
||||
}
|
||||
@@ -208,13 +224,14 @@ impl VpnClient {
|
||||
status
|
||||
}
|
||||
|
||||
/// Get traffic statistics.
|
||||
/// Get traffic statistics (includes connection quality).
|
||||
pub async fn get_statistics(&self) -> serde_json::Value {
|
||||
let stats = self.stats.read().await;
|
||||
let since = self.connected_since.read().await;
|
||||
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
|
||||
let health = self.link_health.read().await;
|
||||
|
||||
serde_json::json!({
|
||||
let mut result = serde_json::json!({
|
||||
"bytesSent": stats.bytes_sent,
|
||||
"bytesReceived": stats.bytes_received,
|
||||
"packetsSent": stats.packets_sent,
|
||||
@@ -222,7 +239,35 @@ impl VpnClient {
|
||||
"keepalivesSent": stats.keepalives_sent,
|
||||
"keepalivesReceived": stats.keepalives_received,
|
||||
"uptimeSeconds": uptime,
|
||||
})
|
||||
});
|
||||
|
||||
// Include connection quality if available
|
||||
if let Some(ref rx) = self.quality_rx {
|
||||
let quality = rx.borrow().clone();
|
||||
result["quality"] = serde_json::json!({
|
||||
"srttMs": quality.srtt_ms,
|
||||
"jitterMs": quality.jitter_ms,
|
||||
"minRttMs": quality.min_rtt_ms,
|
||||
"maxRttMs": quality.max_rtt_ms,
|
||||
"lossRatio": quality.loss_ratio,
|
||||
"consecutiveTimeouts": quality.consecutive_timeouts,
|
||||
"linkHealth": format!("{}", *health),
|
||||
"keepalivesSent": quality.keepalives_sent,
|
||||
"keepalivesAcked": quality.keepalives_acked,
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Get connection quality snapshot.
|
||||
pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
|
||||
self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
|
||||
}
|
||||
|
||||
/// Get current link health.
|
||||
pub async fn get_link_health(&self) -> LinkHealth {
|
||||
*self.link_health.read().await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,11 +279,11 @@ async fn client_loop(
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
mut shutdown_rx: mpsc::Receiver<()>,
|
||||
keepalive_secs: u64,
|
||||
mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
link_health: Arc<RwLock<LinkHealth>>,
|
||||
) {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs));
|
||||
keepalive_ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -264,6 +309,8 @@ async fn client_loop(
|
||||
}
|
||||
PacketType::KeepaliveAck => {
|
||||
stats.write().await.keepalives_received += 1;
|
||||
// Signal the keepalive monitor that ACK was received
|
||||
let _ = ack_tx.send(()).await;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Server sent disconnect");
|
||||
@@ -290,19 +337,37 @@ async fn client_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = keepalive_ticker.tick() => {
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
signal = signal_rx.recv() => {
|
||||
match signal {
|
||||
Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
|
||||
// Embed the timestamp in the keepalive payload (8 bytes, big-endian)
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: timestamp_ms.to_be_bytes().to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
}
|
||||
}
|
||||
Some(KeepaliveSignal::PeerDead) => {
|
||||
warn!("Peer declared dead by keepalive monitor");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
Some(KeepaliveSignal::LinkHealthChanged(health)) => {
|
||||
debug!("Link health changed to: {}", health);
|
||||
*link_health.write().await = health;
|
||||
}
|
||||
None => {
|
||||
// Keepalive monitor channel closed
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
|
||||
@@ -1,87 +1,464 @@
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio::time::{interval, timeout};
|
||||
use tracing::{debug, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Default keepalive interval (30 seconds).
|
||||
use crate::telemetry::{ConnectionQuality, RttTracker};
|
||||
|
||||
/// Default keepalive interval (30 seconds — used for Degraded state).
|
||||
pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
/// Default keepalive ACK timeout (10 seconds).
|
||||
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// Default keepalive ACK timeout (5 seconds).
|
||||
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// Link health states for adaptive keepalive.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LinkHealth {
|
||||
/// RTT stable, jitter low, no loss. Interval: 60s.
|
||||
Healthy,
|
||||
/// Elevated jitter or occasional loss. Interval: 30s.
|
||||
Degraded,
|
||||
/// High loss or sustained jitter spike. Interval: 10s.
|
||||
Critical,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LinkHealth {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Healthy => write!(f, "healthy"),
|
||||
Self::Degraded => write!(f, "degraded"),
|
||||
Self::Critical => write!(f, "critical"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the adaptive keepalive state machine.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdaptiveKeepaliveConfig {
|
||||
/// Interval when link health is Healthy.
|
||||
pub healthy_interval: Duration,
|
||||
/// Interval when link health is Degraded.
|
||||
pub degraded_interval: Duration,
|
||||
/// Interval when link health is Critical.
|
||||
pub critical_interval: Duration,
|
||||
/// ACK timeout (how long to wait for ACK before declaring timeout).
|
||||
pub ack_timeout: Duration,
|
||||
/// Jitter threshold (ms) to enter Degraded from Healthy.
|
||||
pub jitter_degraded_ms: f64,
|
||||
/// Jitter threshold (ms) to return to Healthy from Degraded.
|
||||
pub jitter_healthy_ms: f64,
|
||||
/// Loss ratio threshold to enter Degraded.
|
||||
pub loss_degraded: f64,
|
||||
/// Loss ratio threshold to enter Critical.
|
||||
pub loss_critical: f64,
|
||||
/// Loss ratio threshold to return from Critical to Degraded.
|
||||
pub loss_recover: f64,
|
||||
/// Loss ratio threshold to return from Degraded to Healthy.
|
||||
pub loss_healthy: f64,
|
||||
/// Consecutive checks required for upward state transitions (hysteresis).
|
||||
pub upgrade_checks: u32,
|
||||
/// Consecutive timeouts to declare peer dead in Critical state.
|
||||
pub dead_peer_timeouts: u32,
|
||||
}
|
||||
|
||||
impl Default for AdaptiveKeepaliveConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
healthy_interval: Duration::from_secs(60),
|
||||
degraded_interval: Duration::from_secs(30),
|
||||
critical_interval: Duration::from_secs(10),
|
||||
ack_timeout: Duration::from_secs(5),
|
||||
jitter_degraded_ms: 50.0,
|
||||
jitter_healthy_ms: 30.0,
|
||||
loss_degraded: 0.05,
|
||||
loss_critical: 0.20,
|
||||
loss_recover: 0.10,
|
||||
loss_healthy: 0.02,
|
||||
upgrade_checks: 3,
|
||||
dead_peer_timeouts: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Signals from the keepalive monitor.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeepaliveSignal {
|
||||
/// Time to send a keepalive ping.
|
||||
SendPing,
|
||||
/// Peer is considered dead (no ACK received within timeout).
|
||||
/// Time to send a keepalive ping. Contains the timestamp (ms since epoch) to embed in payload.
|
||||
SendPing(u64),
|
||||
/// Peer is considered dead (no ACK received within timeout repeatedly).
|
||||
PeerDead,
|
||||
/// Link health state changed.
|
||||
LinkHealthChanged(LinkHealth),
|
||||
}
|
||||
|
||||
/// A keepalive monitor that emits signals on a channel.
|
||||
/// A keepalive monitor with adaptive interval and RTT tracking.
|
||||
pub struct KeepaliveMonitor {
|
||||
interval: Duration,
|
||||
timeout_duration: Duration,
|
||||
config: AdaptiveKeepaliveConfig,
|
||||
health: LinkHealth,
|
||||
rtt_tracker: RttTracker,
|
||||
signal_tx: mpsc::Sender<KeepaliveSignal>,
|
||||
ack_rx: mpsc::Receiver<()>,
|
||||
quality_tx: watch::Sender<ConnectionQuality>,
|
||||
consecutive_upgrade_checks: u32,
|
||||
}
|
||||
|
||||
/// Handle returned to the caller to send ACKs and receive signals.
|
||||
pub struct KeepaliveHandle {
|
||||
pub signal_rx: mpsc::Receiver<KeepaliveSignal>,
|
||||
pub ack_tx: mpsc::Sender<()>,
|
||||
pub quality_rx: watch::Receiver<ConnectionQuality>,
|
||||
}
|
||||
|
||||
/// Create a keepalive monitor and its handle.
|
||||
/// Create an adaptive keepalive monitor and its handle.
|
||||
pub fn create_keepalive(
|
||||
keepalive_interval: Option<Duration>,
|
||||
keepalive_timeout: Option<Duration>,
|
||||
config: Option<AdaptiveKeepaliveConfig>,
|
||||
) -> (KeepaliveMonitor, KeepaliveHandle) {
|
||||
let config = config.unwrap_or_default();
|
||||
let (signal_tx, signal_rx) = mpsc::channel(8);
|
||||
let (ack_tx, ack_rx) = mpsc::channel(8);
|
||||
let (quality_tx, quality_rx) = watch::channel(ConnectionQuality::default());
|
||||
|
||||
let monitor = KeepaliveMonitor {
|
||||
interval: keepalive_interval.unwrap_or(DEFAULT_KEEPALIVE_INTERVAL),
|
||||
timeout_duration: keepalive_timeout.unwrap_or(DEFAULT_KEEPALIVE_TIMEOUT),
|
||||
config,
|
||||
health: LinkHealth::Degraded, // start in Degraded, earn Healthy
|
||||
rtt_tracker: RttTracker::new(30),
|
||||
signal_tx,
|
||||
ack_rx,
|
||||
quality_tx,
|
||||
consecutive_upgrade_checks: 0,
|
||||
};
|
||||
|
||||
let handle = KeepaliveHandle { signal_rx, ack_tx };
|
||||
let handle = KeepaliveHandle {
|
||||
signal_rx,
|
||||
ack_tx,
|
||||
quality_rx,
|
||||
};
|
||||
|
||||
(monitor, handle)
|
||||
}
|
||||
|
||||
impl KeepaliveMonitor {
|
||||
fn current_interval(&self) -> Duration {
|
||||
match self.health {
|
||||
LinkHealth::Healthy => self.config.healthy_interval,
|
||||
LinkHealth::Degraded => self.config.degraded_interval,
|
||||
LinkHealth::Critical => self.config.critical_interval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the keepalive loop. Blocks until the peer is dead or channels close.
|
||||
pub async fn run(mut self) {
|
||||
let mut ticker = interval(self.interval);
|
||||
let mut ticker = interval(self.current_interval());
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
debug!("Sending keepalive ping signal");
|
||||
|
||||
if self.signal_tx.send(KeepaliveSignal::SendPing).await.is_err() {
|
||||
// Channel closed
|
||||
break;
|
||||
// Record ping sent, get timestamp for payload
|
||||
let timestamp_ms = self.rtt_tracker.mark_ping_sent();
|
||||
debug!("Sending keepalive ping (ts={})", timestamp_ms);
|
||||
|
||||
if self
|
||||
.signal_tx
|
||||
.send(KeepaliveSignal::SendPing(timestamp_ms))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break; // channel closed
|
||||
}
|
||||
|
||||
// Wait for ACK within timeout
|
||||
match timeout(self.timeout_duration, self.ack_rx.recv()).await {
|
||||
match timeout(self.config.ack_timeout, self.ack_rx.recv()).await {
|
||||
Ok(Some(())) => {
|
||||
debug!("Keepalive ACK received");
|
||||
if let Some(rtt) = self.rtt_tracker.record_ack(timestamp_ms) {
|
||||
debug!("Keepalive ACK received, RTT: {:?}", rtt);
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Channel closed
|
||||
break;
|
||||
break; // channel closed
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Keepalive ACK timeout — peer considered dead");
|
||||
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
|
||||
break;
|
||||
self.rtt_tracker.record_timeout();
|
||||
warn!(
|
||||
"Keepalive ACK timeout (consecutive: {})",
|
||||
self.rtt_tracker.consecutive_timeouts
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Publish quality snapshot
|
||||
let quality = self.rtt_tracker.snapshot();
|
||||
let _ = self.quality_tx.send(quality.clone());
|
||||
|
||||
// Evaluate state transition
|
||||
let new_health = self.evaluate_health(&quality);
|
||||
|
||||
if new_health != self.health {
|
||||
info!("Link health: {} -> {}", self.health, new_health);
|
||||
self.health = new_health;
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
|
||||
// Reset ticker to new interval
|
||||
ticker = interval(self.current_interval());
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
let _ = self
|
||||
.signal_tx
|
||||
.send(KeepaliveSignal::LinkHealthChanged(new_health))
|
||||
.await;
|
||||
}
|
||||
|
||||
// Check for dead peer in Critical state
|
||||
if self.health == LinkHealth::Critical
|
||||
&& self.rtt_tracker.consecutive_timeouts >= self.config.dead_peer_timeouts
|
||||
{
|
||||
warn!("Peer considered dead after {} consecutive timeouts in Critical state",
|
||||
self.rtt_tracker.consecutive_timeouts);
|
||||
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_health(&mut self, quality: &ConnectionQuality) -> LinkHealth {
|
||||
match self.health {
|
||||
LinkHealth::Healthy => {
|
||||
// Downgrade conditions
|
||||
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Critical;
|
||||
}
|
||||
if quality.jitter_ms > self.config.jitter_degraded_ms
|
||||
|| quality.loss_ratio > self.config.loss_degraded
|
||||
|| quality.consecutive_timeouts >= 1
|
||||
{
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Degraded;
|
||||
}
|
||||
LinkHealth::Healthy
|
||||
}
|
||||
LinkHealth::Degraded => {
|
||||
// Downgrade to Critical
|
||||
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Critical;
|
||||
}
|
||||
// Upgrade to Healthy (with hysteresis)
|
||||
if quality.jitter_ms < self.config.jitter_healthy_ms
|
||||
&& quality.loss_ratio < self.config.loss_healthy
|
||||
&& quality.consecutive_timeouts == 0
|
||||
{
|
||||
self.consecutive_upgrade_checks += 1;
|
||||
if self.consecutive_upgrade_checks >= self.config.upgrade_checks {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Healthy;
|
||||
}
|
||||
} else {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
}
|
||||
LinkHealth::Degraded
|
||||
}
|
||||
LinkHealth::Critical => {
|
||||
// Upgrade to Degraded (with hysteresis), never directly to Healthy
|
||||
if quality.loss_ratio < self.config.loss_recover
|
||||
&& quality.consecutive_timeouts == 0
|
||||
{
|
||||
self.consecutive_upgrade_checks += 1;
|
||||
if self.consecutive_upgrade_checks >= 2 {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Degraded;
|
||||
}
|
||||
} else {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
}
|
||||
LinkHealth::Critical
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let config = AdaptiveKeepaliveConfig::default();
|
||||
assert_eq!(config.healthy_interval, Duration::from_secs(60));
|
||||
assert_eq!(config.degraded_interval, Duration::from_secs(30));
|
||||
assert_eq!(config.critical_interval, Duration::from_secs(10));
|
||||
assert_eq!(config.ack_timeout, Duration::from_secs(5));
|
||||
assert_eq!(config.dead_peer_timeouts, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn link_health_display() {
|
||||
assert_eq!(format!("{}", LinkHealth::Healthy), "healthy");
|
||||
assert_eq!(format!("{}", LinkHealth::Degraded), "degraded");
|
||||
assert_eq!(format!("{}", LinkHealth::Critical), "critical");
|
||||
}
|
||||
|
||||
// Helper to create a monitor for unit-testing evaluate_health
|
||||
fn make_test_monitor() -> KeepaliveMonitor {
|
||||
let (signal_tx, _signal_rx) = mpsc::channel(8);
|
||||
let (_ack_tx, ack_rx) = mpsc::channel(8);
|
||||
let (quality_tx, _quality_rx) = watch::channel(ConnectionQuality::default());
|
||||
|
||||
KeepaliveMonitor {
|
||||
config: AdaptiveKeepaliveConfig::default(),
|
||||
health: LinkHealth::Degraded,
|
||||
rtt_tracker: RttTracker::new(30),
|
||||
signal_tx,
|
||||
ack_rx,
|
||||
quality_tx,
|
||||
consecutive_upgrade_checks: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_degraded_on_jitter() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
jitter_ms: 60.0, // > 50ms threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_degraded_on_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.06, // > 5% threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_critical_on_high_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.25, // > 20% threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_critical_on_consecutive_timeouts() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
consecutive_timeouts: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_healthy_requires_hysteresis() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let good_quality = ConnectionQuality {
|
||||
jitter_ms: 10.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
srtt_ms: 20.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Should require 3 consecutive good checks (default upgrade_checks)
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 1);
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 2);
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Healthy);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_healthy_resets_on_bad_check() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let good = ConnectionQuality {
|
||||
jitter_ms: 10.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let bad = ConnectionQuality {
|
||||
jitter_ms: 60.0, // too high
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
m.evaluate_health(&good); // 1 check
|
||||
m.evaluate_health(&good); // 2 checks
|
||||
m.evaluate_health(&bad); // resets
|
||||
assert_eq!(m.consecutive_upgrade_checks, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn critical_to_degraded_requires_hysteresis() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Critical;
|
||||
let recovering = ConnectionQuality {
|
||||
loss_ratio: 0.05, // < 10% recover threshold
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Critical);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 1);
|
||||
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn critical_never_directly_to_healthy() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Critical;
|
||||
let perfect = ConnectionQuality {
|
||||
jitter_ms: 1.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
srtt_ms: 10.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Even with perfect quality, must go through Degraded first
|
||||
m.evaluate_health(&perfect); // 1
|
||||
let result = m.evaluate_health(&perfect); // 2 → Degraded
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
// Not Healthy yet
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_critical_on_high_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.25,
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(m.evaluate_health(&q), LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interval_matches_health() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(60));
|
||||
m.health = LinkHealth::Degraded;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(30));
|
||||
m.health = LinkHealth::Critical;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(10));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,3 +11,7 @@ pub mod network;
|
||||
pub mod server;
|
||||
pub mod client;
|
||||
pub mod reconnect;
|
||||
pub mod telemetry;
|
||||
pub mod ratelimit;
|
||||
pub mod qos;
|
||||
pub mod mtu;
|
||||
|
||||
@@ -285,6 +285,39 @@ async fn handle_client_request(
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
}
|
||||
"getConnectionQuality" => {
|
||||
match vpn_client.get_connection_quality() {
|
||||
Some(quality) => {
|
||||
let health = vpn_client.get_link_health().await;
|
||||
let interval_secs = match health {
|
||||
crate::keepalive::LinkHealth::Healthy => 60,
|
||||
crate::keepalive::LinkHealth::Degraded => 30,
|
||||
crate::keepalive::LinkHealth::Critical => 10,
|
||||
};
|
||||
ManagementResponse::ok(id, serde_json::json!({
|
||||
"srttMs": quality.srtt_ms,
|
||||
"jitterMs": quality.jitter_ms,
|
||||
"minRttMs": quality.min_rtt_ms,
|
||||
"maxRttMs": quality.max_rtt_ms,
|
||||
"lossRatio": quality.loss_ratio,
|
||||
"consecutiveTimeouts": quality.consecutive_timeouts,
|
||||
"linkHealth": format!("{}", health),
|
||||
"currentKeepaliveIntervalSecs": interval_secs,
|
||||
}))
|
||||
}
|
||||
None => ManagementResponse::ok(id, serde_json::json!(null)),
|
||||
}
|
||||
}
|
||||
"getMtuInfo" => {
|
||||
ManagementResponse::ok(id, serde_json::json!({
|
||||
"tunMtu": 1420,
|
||||
"effectiveMtu": crate::mtu::TunnelOverhead::default_overhead().effective_tun_mtu(1500),
|
||||
"linkMtu": 1500,
|
||||
"overheadBytes": crate::mtu::TunnelOverhead::default_overhead().total(),
|
||||
"oversizedPacketsDropped": 0,
|
||||
"icmpTooBigSent": 0,
|
||||
}))
|
||||
}
|
||||
_ => ManagementResponse::err(id, format!("Unknown client method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
@@ -349,6 +382,50 @@ async fn handle_server_request(
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"setClientRateLimit" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let rate = match request.params.get("rateBytesPerSec").and_then(|v| v.as_u64()) {
|
||||
Some(r) => r,
|
||||
None => return ManagementResponse::err(id, "Missing rateBytesPerSec".to_string()),
|
||||
};
|
||||
let burst = match request.params.get("burstBytes").and_then(|v| v.as_u64()) {
|
||||
Some(b) => b,
|
||||
None => return ManagementResponse::err(id, "Missing burstBytes".to_string()),
|
||||
};
|
||||
match vpn_server.set_client_rate_limit(&client_id, rate, burst).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClientRateLimit" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_client_rate_limit(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClientTelemetry" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match clients.into_iter().find(|c| c.client_id == client_id) {
|
||||
Some(info) => {
|
||||
match serde_json::to_value(&info) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id, format!("Client {} not found", client_id)),
|
||||
}
|
||||
}
|
||||
"generateKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
|
||||
314
rust/src/mtu.rs
Normal file
314
rust/src/mtu.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
/// Overhead breakdown for VPN tunnel encapsulation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TunnelOverhead {
|
||||
/// Outer IP header: 20 bytes (IPv4, no options).
|
||||
pub ip_header: u16,
|
||||
/// TCP header: typically 32 bytes (20 base + 12 for timestamps).
|
||||
pub tcp_header: u16,
|
||||
/// WebSocket framing: ~6 bytes (2 base + 4 mask from client).
|
||||
pub ws_framing: u16,
|
||||
/// VPN binary frame header: 5 bytes [type:1B][length:4B].
|
||||
pub vpn_header: u16,
|
||||
/// Noise AEAD tag: 16 bytes (Poly1305).
|
||||
pub noise_tag: u16,
|
||||
}
|
||||
|
||||
impl TunnelOverhead {
|
||||
/// Conservative default overhead estimate.
|
||||
pub fn default_overhead() -> Self {
|
||||
Self {
|
||||
ip_header: 20,
|
||||
tcp_header: 32,
|
||||
ws_framing: 6,
|
||||
vpn_header: 5,
|
||||
noise_tag: 16,
|
||||
}
|
||||
}
|
||||
|
||||
/// Total encapsulation overhead in bytes.
|
||||
pub fn total(&self) -> u16 {
|
||||
self.ip_header + self.tcp_header + self.ws_framing + self.vpn_header + self.noise_tag
|
||||
}
|
||||
|
||||
/// Compute effective TUN MTU given the underlying link MTU.
|
||||
pub fn effective_tun_mtu(&self, link_mtu: u16) -> u16 {
|
||||
link_mtu.saturating_sub(self.total())
|
||||
}
|
||||
}
|
||||
|
||||
/// MTU configuration for the VPN tunnel.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MtuConfig {
|
||||
/// Underlying link MTU (typically 1500 for Ethernet).
|
||||
pub link_mtu: u16,
|
||||
/// Computed effective TUN MTU.
|
||||
pub effective_mtu: u16,
|
||||
/// Whether to generate ICMP too-big for oversized packets.
|
||||
pub send_icmp_too_big: bool,
|
||||
/// Counter: oversized packets encountered.
|
||||
pub oversized_packets: u64,
|
||||
/// Counter: ICMP too-big messages generated.
|
||||
pub icmp_too_big_sent: u64,
|
||||
}
|
||||
|
||||
impl MtuConfig {
|
||||
/// Create a new MTU config from the underlying link MTU.
|
||||
pub fn new(link_mtu: u16) -> Self {
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let effective = overhead.effective_tun_mtu(link_mtu);
|
||||
Self {
|
||||
link_mtu,
|
||||
effective_mtu: effective,
|
||||
send_icmp_too_big: true,
|
||||
oversized_packets: 0,
|
||||
icmp_too_big_sent: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a packet exceeds the effective MTU.
|
||||
pub fn is_oversized(&self, packet_len: usize) -> bool {
|
||||
packet_len > self.effective_mtu as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Action to take after checking MTU.
|
||||
pub enum MtuAction {
|
||||
/// Packet is within MTU, forward normally.
|
||||
Forward,
|
||||
/// Packet is oversized; contains the ICMP too-big message to write back into TUN.
|
||||
SendIcmpTooBig(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Check packet against MTU config and return the appropriate action.
|
||||
pub fn check_mtu(packet: &[u8], config: &MtuConfig) -> MtuAction {
|
||||
if !config.is_oversized(packet.len()) {
|
||||
return MtuAction::Forward;
|
||||
}
|
||||
|
||||
if !config.send_icmp_too_big {
|
||||
return MtuAction::Forward;
|
||||
}
|
||||
|
||||
match generate_icmp_too_big(packet, config.effective_mtu) {
|
||||
Some(icmp) => MtuAction::SendIcmpTooBig(icmp),
|
||||
None => MtuAction::Forward,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate an ICMPv4 Destination Unreachable / Fragmentation Needed message.
|
||||
///
|
||||
/// Per RFC 792: Type 3, Code 4, with next-hop MTU in bytes 6-7 (RFC 1191).
|
||||
/// Returns the complete IP + ICMP packet to write back into the TUN device.
|
||||
pub fn generate_icmp_too_big(original_packet: &[u8], next_hop_mtu: u16) -> Option<Vec<u8>> {
|
||||
// Need at least 20 bytes of original IP header
|
||||
if original_packet.len() < 20 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Verify it's IPv4
|
||||
if original_packet[0] >> 4 != 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Parse source/dest from original IP header
|
||||
let src_ip = Ipv4Addr::new(
|
||||
original_packet[12],
|
||||
original_packet[13],
|
||||
original_packet[14],
|
||||
original_packet[15],
|
||||
);
|
||||
let dst_ip = Ipv4Addr::new(
|
||||
original_packet[16],
|
||||
original_packet[17],
|
||||
original_packet[18],
|
||||
original_packet[19],
|
||||
);
|
||||
|
||||
// ICMP payload: IP header + first 8 bytes of original datagram (per RFC 792)
|
||||
let icmp_data_len = original_packet.len().min(28); // 20 IP header + 8 bytes
|
||||
let icmp_payload = &original_packet[..icmp_data_len];
|
||||
|
||||
// Build ICMP message: type(1) + code(1) + checksum(2) + unused(2) + next_hop_mtu(2) + data
|
||||
let mut icmp = Vec::with_capacity(8 + icmp_data_len);
|
||||
icmp.push(3); // Type: Destination Unreachable
|
||||
icmp.push(4); // Code: Fragmentation Needed and DF was Set
|
||||
icmp.push(0); // Checksum placeholder
|
||||
icmp.push(0);
|
||||
icmp.push(0); // Unused
|
||||
icmp.push(0);
|
||||
icmp.extend_from_slice(&next_hop_mtu.to_be_bytes());
|
||||
icmp.extend_from_slice(icmp_payload);
|
||||
|
||||
// Compute ICMP checksum
|
||||
let cksum = internet_checksum(&icmp);
|
||||
icmp[2] = (cksum >> 8) as u8;
|
||||
icmp[3] = (cksum & 0xff) as u8;
|
||||
|
||||
// Build IP header (ICMP response: FROM tunnel gateway TO original source)
|
||||
let total_len = (20 + icmp.len()) as u16;
|
||||
let mut ip = Vec::with_capacity(total_len as usize);
|
||||
ip.push(0x45); // Version 4, IHL 5
|
||||
ip.push(0x00); // DSCP/ECN
|
||||
ip.extend_from_slice(&total_len.to_be_bytes());
|
||||
ip.extend_from_slice(&[0, 0]); // Identification
|
||||
ip.extend_from_slice(&[0x40, 0x00]); // Flags: Don't Fragment, Fragment Offset: 0
|
||||
ip.push(64); // TTL
|
||||
ip.push(1); // Protocol: ICMP
|
||||
ip.extend_from_slice(&[0, 0]); // Header checksum placeholder
|
||||
ip.extend_from_slice(&dst_ip.octets()); // Source: tunnel endpoint (was dst)
|
||||
ip.extend_from_slice(&src_ip.octets()); // Destination: original source
|
||||
|
||||
// Compute IP header checksum
|
||||
let ip_cksum = internet_checksum(&ip[..20]);
|
||||
ip[10] = (ip_cksum >> 8) as u8;
|
||||
ip[11] = (ip_cksum & 0xff) as u8;
|
||||
|
||||
ip.extend_from_slice(&icmp);
|
||||
Some(ip)
|
||||
}
|
||||
|
||||
/// Standard Internet checksum (RFC 1071).
|
||||
fn internet_checksum(data: &[u8]) -> u16 {
|
||||
let mut sum: u32 = 0;
|
||||
let mut i = 0;
|
||||
while i + 1 < data.len() {
|
||||
sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
|
||||
i += 2;
|
||||
}
|
||||
if i < data.len() {
|
||||
sum += (data[i] as u32) << 8;
|
||||
}
|
||||
while sum >> 16 != 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16);
|
||||
}
|
||||
!sum as u16
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_overhead_total() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
assert_eq!(oh.total(), 79); // 20+32+6+5+16
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_mtu_for_ethernet() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
let mtu = oh.effective_tun_mtu(1500);
|
||||
assert_eq!(mtu, 1421); // 1500 - 79
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_mtu_saturates_at_zero() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
let mtu = oh.effective_tun_mtu(50); // Less than overhead
|
||||
assert_eq!(mtu, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mtu_config_default() {
|
||||
let config = MtuConfig::new(1500);
|
||||
assert_eq!(config.effective_mtu, 1421);
|
||||
assert_eq!(config.link_mtu, 1500);
|
||||
assert!(config.send_icmp_too_big);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_oversized() {
|
||||
let config = MtuConfig::new(1500);
|
||||
assert!(!config.is_oversized(1421));
|
||||
assert!(config.is_oversized(1422));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_generation() {
|
||||
// Craft a minimal IPv4 packet
|
||||
let mut original = vec![0u8; 28];
|
||||
original[0] = 0x45; // version 4, IHL 5
|
||||
original[2..4].copy_from_slice(&1500u16.to_be_bytes()); // total length
|
||||
original[9] = 6; // TCP
|
||||
original[12..16].copy_from_slice(&[10, 0, 0, 1]); // src IP
|
||||
original[16..20].copy_from_slice(&[10, 0, 0, 2]); // dst IP
|
||||
|
||||
let icmp_pkt = generate_icmp_too_big(&original, 1421).unwrap();
|
||||
|
||||
// Verify it's a valid IPv4 packet
|
||||
assert_eq!(icmp_pkt[0] >> 4, 4); // IPv4
|
||||
assert_eq!(icmp_pkt[9], 1); // ICMP protocol
|
||||
|
||||
// Source should be original dst (10.0.0.2)
|
||||
assert_eq!(&icmp_pkt[12..16], &[10, 0, 0, 2]);
|
||||
// Destination should be original src (10.0.0.1)
|
||||
assert_eq!(&icmp_pkt[16..20], &[10, 0, 0, 1]);
|
||||
|
||||
// ICMP type 3, code 4
|
||||
assert_eq!(icmp_pkt[20], 3);
|
||||
assert_eq!(icmp_pkt[21], 4);
|
||||
|
||||
// Next-hop MTU at ICMP bytes 6-7 (offset 26-27 in IP packet)
|
||||
let mtu = u16::from_be_bytes([icmp_pkt[26], icmp_pkt[27]]);
|
||||
assert_eq!(mtu, 1421);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_rejects_short_packet() {
|
||||
let short = vec![0u8; 10];
|
||||
assert!(generate_icmp_too_big(&short, 1421).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_rejects_non_ipv4() {
|
||||
let mut pkt = vec![0u8; 40];
|
||||
pkt[0] = 0x60; // IPv6
|
||||
assert!(generate_icmp_too_big(&pkt, 1421).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_checksum_valid() {
|
||||
let mut original = vec![0u8; 28];
|
||||
original[0] = 0x45;
|
||||
original[2..4].copy_from_slice(&1500u16.to_be_bytes());
|
||||
original[9] = 6;
|
||||
original[12..16].copy_from_slice(&[192, 168, 1, 100]);
|
||||
original[16..20].copy_from_slice(&[10, 8, 0, 1]);
|
||||
|
||||
let icmp_pkt = generate_icmp_too_big(&original, 1420).unwrap();
|
||||
|
||||
// Verify IP header checksum
|
||||
let ip_cksum = internet_checksum(&icmp_pkt[..20]);
|
||||
assert_eq!(ip_cksum, 0, "IP header checksum should verify to 0");
|
||||
|
||||
// Verify ICMP checksum
|
||||
let icmp_cksum = internet_checksum(&icmp_pkt[20..]);
|
||||
assert_eq!(icmp_cksum, 0, "ICMP checksum should verify to 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_mtu_forward() {
|
||||
let config = MtuConfig::new(1500);
|
||||
let pkt = vec![0u8; 1421]; // Exactly at MTU
|
||||
assert!(matches!(check_mtu(&pkt, &config), MtuAction::Forward));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_mtu_oversized_generates_icmp() {
|
||||
let config = MtuConfig::new(1500);
|
||||
let mut pkt = vec![0u8; 1500];
|
||||
pkt[0] = 0x45; // Valid IPv4
|
||||
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
|
||||
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
|
||||
|
||||
match check_mtu(&pkt, &config) {
|
||||
MtuAction::SendIcmpTooBig(icmp) => {
|
||||
assert_eq!(icmp[20], 3); // ICMP type
|
||||
assert_eq!(icmp[21], 4); // ICMP code
|
||||
}
|
||||
MtuAction::Forward => panic!("Expected SendIcmpTooBig"),
|
||||
}
|
||||
}
|
||||
}
|
||||
490
rust/src/qos.rs
Normal file
490
rust/src/qos.rs
Normal file
@@ -0,0 +1,490 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Priority levels for IP packets.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[repr(u8)]
|
||||
pub enum Priority {
|
||||
High = 0,
|
||||
Normal = 1,
|
||||
Low = 2,
|
||||
}
|
||||
|
||||
/// QoS statistics per priority level.
|
||||
pub struct QosStats {
|
||||
pub high_enqueued: AtomicU64,
|
||||
pub normal_enqueued: AtomicU64,
|
||||
pub low_enqueued: AtomicU64,
|
||||
pub high_dropped: AtomicU64,
|
||||
pub normal_dropped: AtomicU64,
|
||||
pub low_dropped: AtomicU64,
|
||||
}
|
||||
|
||||
impl QosStats {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
high_enqueued: AtomicU64::new(0),
|
||||
normal_enqueued: AtomicU64::new(0),
|
||||
low_enqueued: AtomicU64::new(0),
|
||||
high_dropped: AtomicU64::new(0),
|
||||
normal_dropped: AtomicU64::new(0),
|
||||
low_dropped: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QosStats {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Packet classification
|
||||
// ============================================================================
|
||||
|
||||
/// 5-tuple flow key for tracking bulk flows.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct FlowKey {
|
||||
src_ip: u32,
|
||||
dst_ip: u32,
|
||||
src_port: u16,
|
||||
dst_port: u16,
|
||||
protocol: u8,
|
||||
}
|
||||
|
||||
/// Per-flow state for bulk detection.
|
||||
struct FlowState {
|
||||
bytes_total: u64,
|
||||
window_start: Instant,
|
||||
}
|
||||
|
||||
/// Tracks per-flow byte counts for bulk flow detection.
|
||||
struct FlowTracker {
|
||||
flows: HashMap<FlowKey, FlowState>,
|
||||
window_duration: Duration,
|
||||
max_flows: usize,
|
||||
}
|
||||
|
||||
impl FlowTracker {
|
||||
fn new(window_duration: Duration, max_flows: usize) -> Self {
|
||||
Self {
|
||||
flows: HashMap::new(),
|
||||
window_duration,
|
||||
max_flows,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes for a flow. Returns true if the flow exceeds the threshold.
|
||||
fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
// Evict if at capacity — remove oldest entry
|
||||
if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
|
||||
if let Some(oldest_key) = self
|
||||
.flows
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.window_start)
|
||||
.map(|(k, _)| *k)
|
||||
{
|
||||
self.flows.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
|
||||
let state = self.flows.entry(key).or_insert(FlowState {
|
||||
bytes_total: 0,
|
||||
window_start: now,
|
||||
});
|
||||
|
||||
// Reset window if expired
|
||||
if now.duration_since(state.window_start) > self.window_duration {
|
||||
state.bytes_total = 0;
|
||||
state.window_start = now;
|
||||
}
|
||||
|
||||
state.bytes_total += bytes;
|
||||
state.bytes_total > threshold
|
||||
}
|
||||
}
|
||||
|
||||
/// Classifies raw IP packets into priority levels.
|
||||
pub struct PacketClassifier {
|
||||
flow_tracker: FlowTracker,
|
||||
/// Byte threshold for classifying a flow as "bulk" (Low priority).
|
||||
bulk_threshold_bytes: u64,
|
||||
}
|
||||
|
||||
impl PacketClassifier {
|
||||
/// Create a new classifier.
|
||||
///
|
||||
/// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB)
|
||||
pub fn new(bulk_threshold_bytes: u64) -> Self {
|
||||
Self {
|
||||
flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000),
|
||||
bulk_threshold_bytes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify a raw IPv4 packet into a priority level.
|
||||
///
|
||||
/// The packet must start with the IPv4 header (as read from a TUN device).
|
||||
pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
|
||||
// Need at least 20 bytes for a minimal IPv4 header
|
||||
if ip_packet.len() < 20 {
|
||||
return Priority::Normal;
|
||||
}
|
||||
|
||||
let version = ip_packet[0] >> 4;
|
||||
if version != 4 {
|
||||
return Priority::Normal; // Only classify IPv4 for now
|
||||
}
|
||||
|
||||
let ihl = (ip_packet[0] & 0x0F) as usize;
|
||||
let header_len = ihl * 4;
|
||||
let protocol = ip_packet[9];
|
||||
let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize;
|
||||
|
||||
// ICMP is always high priority
|
||||
if protocol == 1 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Small packets (<128 bytes) are high priority (likely interactive)
|
||||
if total_len < 128 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Extract ports for TCP (6) and UDP (17)
|
||||
let (src_port, dst_port) = if (protocol == 6 || protocol == 17)
|
||||
&& ip_packet.len() >= header_len + 4
|
||||
{
|
||||
let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]);
|
||||
let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]);
|
||||
(sp, dp)
|
||||
} else {
|
||||
(0, 0)
|
||||
};
|
||||
|
||||
// DNS (port 53) and SSH (port 22) are high priority
|
||||
if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Check for bulk flow
|
||||
if protocol == 6 || protocol == 17 {
|
||||
let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]);
|
||||
let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]);
|
||||
|
||||
let key = FlowKey {
|
||||
src_ip,
|
||||
dst_ip,
|
||||
src_port,
|
||||
dst_port,
|
||||
protocol,
|
||||
};
|
||||
|
||||
if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) {
|
||||
return Priority::Low;
|
||||
}
|
||||
}
|
||||
|
||||
Priority::Normal
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Priority channel set
|
||||
// ============================================================================
|
||||
|
||||
/// Error returned when a packet is dropped.
|
||||
#[derive(Debug)]
|
||||
pub enum PacketDropped {
|
||||
LowPriorityDrop,
|
||||
NormalPriorityDrop,
|
||||
HighPriorityDrop,
|
||||
ChannelClosed,
|
||||
}
|
||||
|
||||
/// Sending half of the priority channel set.
|
||||
pub struct PrioritySender {
|
||||
high_tx: mpsc::Sender<Vec<u8>>,
|
||||
normal_tx: mpsc::Sender<Vec<u8>>,
|
||||
low_tx: mpsc::Sender<Vec<u8>>,
|
||||
stats: Arc<QosStats>,
|
||||
}
|
||||
|
||||
impl PrioritySender {
|
||||
/// Send a packet with the given priority. Implements smart dropping under backpressure.
|
||||
pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
|
||||
let (tx, enqueued_counter) = match priority {
|
||||
Priority::High => (&self.high_tx, &self.stats.high_enqueued),
|
||||
Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
|
||||
Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
|
||||
};
|
||||
|
||||
match tx.try_send(packet) {
|
||||
Ok(()) => {
|
||||
enqueued_counter.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(packet)) => {
|
||||
self.handle_backpressure(packet, priority).await
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_backpressure(
|
||||
&self,
|
||||
packet: Vec<u8>,
|
||||
priority: Priority,
|
||||
) -> Result<(), PacketDropped> {
|
||||
match priority {
|
||||
Priority::Low => {
|
||||
self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::LowPriorityDrop)
|
||||
}
|
||||
Priority::Normal => {
|
||||
self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::NormalPriorityDrop)
|
||||
}
|
||||
Priority::High => {
|
||||
// Last resort: briefly wait for space, then drop
|
||||
match tokio::time::timeout(
|
||||
Duration::from_millis(5),
|
||||
self.high_tx.send(packet),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(())) => {
|
||||
self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::HighPriorityDrop)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Receiving half of the priority channel set.
|
||||
pub struct PriorityReceiver {
|
||||
high_rx: mpsc::Receiver<Vec<u8>>,
|
||||
normal_rx: mpsc::Receiver<Vec<u8>>,
|
||||
low_rx: mpsc::Receiver<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl PriorityReceiver {
|
||||
/// Receive the next packet, draining high-priority first (biased select).
|
||||
pub async fn recv(&mut self) -> Option<Vec<u8>> {
|
||||
tokio::select! {
|
||||
biased;
|
||||
Some(pkt) = self.high_rx.recv() => Some(pkt),
|
||||
Some(pkt) = self.normal_rx.recv() => Some(pkt),
|
||||
Some(pkt) = self.low_rx.recv() => Some(pkt),
|
||||
else => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a priority channel set split into sender and receiver halves.
|
||||
///
|
||||
/// - `high_cap`: capacity of the high-priority channel
|
||||
/// - `normal_cap`: capacity of the normal-priority channel
|
||||
/// - `low_cap`: capacity of the low-priority channel
|
||||
pub fn create_priority_channels(
|
||||
high_cap: usize,
|
||||
normal_cap: usize,
|
||||
low_cap: usize,
|
||||
) -> (PrioritySender, PriorityReceiver) {
|
||||
let (high_tx, high_rx) = mpsc::channel(high_cap);
|
||||
let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
|
||||
let (low_tx, low_rx) = mpsc::channel(low_cap);
|
||||
let stats = Arc::new(QosStats::new());
|
||||
|
||||
let sender = PrioritySender {
|
||||
high_tx,
|
||||
normal_tx,
|
||||
low_tx,
|
||||
stats,
|
||||
};
|
||||
|
||||
let receiver = PriorityReceiver {
|
||||
high_rx,
|
||||
normal_rx,
|
||||
low_rx,
|
||||
};
|
||||
|
||||
(sender, receiver)
|
||||
}
|
||||
|
||||
/// Get a reference to the QoS stats from a sender.
|
||||
impl PrioritySender {
|
||||
pub fn stats(&self) -> &Arc<QosStats> {
|
||||
&self.stats
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper: craft a minimal IPv4 packet
|
||||
fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
|
||||
let mut pkt = vec![0u8; total_len.max(24) as usize];
|
||||
pkt[0] = 0x45; // version 4, IHL 5
|
||||
pkt[2..4].copy_from_slice(&total_len.to_be_bytes());
|
||||
pkt[9] = protocol;
|
||||
// src IP
|
||||
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
|
||||
// dst IP
|
||||
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
|
||||
// ports (at offset 20 for IHL=5)
|
||||
pkt[20..22].copy_from_slice(&src_port.to_be_bytes());
|
||||
pkt[22..24].copy_from_slice(&dst_port.to_be_bytes());
|
||||
pkt
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_icmp_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_dns_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_ssh_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_small_packet_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_normal_http() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_bulk_flow_as_low() {
|
||||
let mut c = PacketClassifier::new(10_000); // Low threshold for testing
|
||||
|
||||
// Send enough traffic to exceed the threshold
|
||||
for _ in 0..100 {
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500);
|
||||
c.classify(&pkt);
|
||||
}
|
||||
|
||||
// Next packet from same flow should be Low
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500);
|
||||
assert_eq!(c.classify(&pkt), Priority::Low);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_too_short_packet() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = vec![0u8; 10]; // Too short for IPv4 header
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_non_ipv4() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let mut pkt = vec![0u8; 40];
|
||||
pkt[0] = 0x60; // IPv6 version nibble
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn priority_receiver_drains_high_first() {
|
||||
let (sender, mut receiver) = create_priority_channels(8, 8, 8);
|
||||
|
||||
// Enqueue in reverse order
|
||||
sender.send(vec![3], Priority::Low).await.unwrap();
|
||||
sender.send(vec![2], Priority::Normal).await.unwrap();
|
||||
sender.send(vec![1], Priority::High).await.unwrap();
|
||||
|
||||
// Should drain High first
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![1]);
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![2]);
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![3]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn smart_dropping_low_priority() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 8, 1);
|
||||
|
||||
// Fill the low channel
|
||||
sender.send(vec![0], Priority::Low).await.unwrap();
|
||||
|
||||
// Next low-priority send should be dropped
|
||||
let result = sender.send(vec![1], Priority::Low).await;
|
||||
assert!(matches!(result, Err(PacketDropped::LowPriorityDrop)));
|
||||
|
||||
assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn smart_dropping_normal_priority() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 1, 8);
|
||||
|
||||
sender.send(vec![0], Priority::Normal).await.unwrap();
|
||||
|
||||
let result = sender.send(vec![1], Priority::Normal).await;
|
||||
assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop)));
|
||||
|
||||
assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stats_track_enqueued() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 8, 8);
|
||||
|
||||
sender.send(vec![1], Priority::High).await.unwrap();
|
||||
sender.send(vec![2], Priority::High).await.unwrap();
|
||||
sender.send(vec![3], Priority::Normal).await.unwrap();
|
||||
sender.send(vec![4], Priority::Low).await.unwrap();
|
||||
|
||||
assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flow_tracker_evicts_at_capacity() {
|
||||
let mut ft = FlowTracker::new(Duration::from_secs(60), 2);
|
||||
|
||||
let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 };
|
||||
let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 };
|
||||
let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 };
|
||||
|
||||
ft.record(k1, 100, 1000);
|
||||
ft.record(k2, 100, 1000);
|
||||
// Should evict k1 (oldest)
|
||||
ft.record(k3, 100, 1000);
|
||||
|
||||
assert_eq!(ft.flows.len(), 2);
|
||||
assert!(!ft.flows.contains_key(&k1));
|
||||
}
|
||||
}
|
||||
139
rust/src/ratelimit.rs
Normal file
139
rust/src/ratelimit.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use std::time::Instant;
|
||||
|
||||
/// A token bucket rate limiter operating on byte granularity.
|
||||
pub struct TokenBucket {
|
||||
/// Tokens (bytes) added per second.
|
||||
rate: f64,
|
||||
/// Maximum burst capacity in bytes.
|
||||
burst: f64,
|
||||
/// Currently available tokens.
|
||||
tokens: f64,
|
||||
/// Last time tokens were refilled.
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
/// Create a new token bucket.
|
||||
///
|
||||
/// - `rate_bytes_per_sec`: sustained rate in bytes/second
|
||||
/// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
|
||||
pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
|
||||
let burst = burst_bytes as f64;
|
||||
Self {
|
||||
rate: rate_bytes_per_sec as f64,
|
||||
burst,
|
||||
tokens: burst, // start full
|
||||
last_refill: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded.
|
||||
pub fn try_consume(&mut self, bytes: usize) -> bool {
|
||||
self.refill();
|
||||
let needed = bytes as f64;
|
||||
if needed <= self.tokens {
|
||||
self.tokens -= needed;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Update rate and burst limits dynamically (for live IPC reconfiguration).
|
||||
pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
|
||||
self.rate = rate_bytes_per_sec as f64;
|
||||
self.burst = burst_bytes as f64;
|
||||
// Cap current tokens at new burst
|
||||
if self.tokens > self.burst {
|
||||
self.tokens = self.burst;
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self) {
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
|
||||
self.last_refill = now;
|
||||
self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn allows_traffic_under_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
|
||||
// Should allow up to burst size immediately
|
||||
assert!(tb.try_consume(1_500_000));
|
||||
assert!(tb.try_consume(400_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blocks_traffic_over_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
|
||||
// Consume entire burst
|
||||
assert!(tb.try_consume(1_000_000));
|
||||
// Next consume should fail (no time to refill)
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_consume_always_succeeds() {
|
||||
let mut tb = TokenBucket::new(0, 0);
|
||||
assert!(tb.try_consume(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refills_over_time() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
|
||||
// Drain completely
|
||||
assert!(tb.try_consume(1_000_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
|
||||
// Wait 100ms — should refill ~100KB
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_limits_caps_tokens() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
|
||||
// Tokens start at burst (2MB)
|
||||
tb.update_limits(500_000, 500_000);
|
||||
// Tokens should be capped to new burst (500KB)
|
||||
assert!(tb.try_consume(500_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_limits_changes_rate() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
|
||||
assert!(tb.try_consume(1_000_000)); // drain
|
||||
|
||||
// Change to higher rate
|
||||
tb.update_limits(10_000_000, 10_000_000);
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
// At 10MB/s, 50ms should refill ~500KB
|
||||
assert!(tb.try_consume(200_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_rate_blocks_after_burst() {
|
||||
let mut tb = TokenBucket::new(0, 100);
|
||||
assert!(tb.try_consume(100));
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
// Zero rate means no refill
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokens_do_not_exceed_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000);
|
||||
// Wait to accumulate — but should cap at burst
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
assert!(tb.try_consume(1_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
@@ -12,9 +13,14 @@ use tracing::{info, error, warn};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::mtu::{MtuConfig, TunnelOverhead};
|
||||
use crate::network::IpPool;
|
||||
use crate::ratelimit::TokenBucket;
|
||||
use crate::transport;
|
||||
|
||||
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
|
||||
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
|
||||
|
||||
/// Server configuration (matches TS IVpnServerConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -29,6 +35,10 @@ pub struct ServerConfig {
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
pub enable_nat: Option<bool>,
|
||||
/// Default rate limit for new clients (bytes/sec). None = unlimited.
|
||||
pub default_rate_limit_bytes_per_sec: Option<u64>,
|
||||
/// Default burst size for new clients (bytes). None = unlimited.
|
||||
pub default_burst_bytes: Option<u64>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -40,6 +50,12 @@ pub struct ClientInfo {
|
||||
pub connected_since: String,
|
||||
pub bytes_sent: u64,
|
||||
pub bytes_received: u64,
|
||||
pub packets_dropped: u64,
|
||||
pub bytes_dropped: u64,
|
||||
pub last_keepalive_at: Option<String>,
|
||||
pub keepalives_received: u64,
|
||||
pub rate_limit_bytes_per_sec: Option<u64>,
|
||||
pub burst_bytes: Option<u64>,
|
||||
}
|
||||
|
||||
/// Server statistics.
|
||||
@@ -63,6 +79,8 @@ pub struct ServerState {
|
||||
pub ip_pool: Mutex<IpPool>,
|
||||
pub clients: RwLock<HashMap<String, ClientInfo>>,
|
||||
pub stats: RwLock<ServerStatistics>,
|
||||
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
|
||||
pub mtu_config: MtuConfig,
|
||||
pub started_at: std::time::Instant,
|
||||
}
|
||||
|
||||
@@ -98,11 +116,18 @@ impl VpnServer {
|
||||
}
|
||||
}
|
||||
|
||||
let link_mtu = config.mtu.unwrap_or(1420);
|
||||
// Compute effective MTU from overhead
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
|
||||
|
||||
let state = Arc::new(ServerState {
|
||||
config: config.clone(),
|
||||
ip_pool: Mutex::new(ip_pool),
|
||||
clients: RwLock::new(HashMap::new()),
|
||||
stats: RwLock::new(ServerStatistics::default()),
|
||||
rate_limiters: Mutex::new(HashMap::new()),
|
||||
mtu_config,
|
||||
started_at: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
@@ -166,11 +191,52 @@ impl VpnServer {
|
||||
if let Some(client) = clients.remove(client_id) {
|
||||
let ip: Ipv4Addr = client.assigned_ip.parse()?;
|
||||
state.ip_pool.lock().await.release(&ip);
|
||||
state.rate_limiters.lock().await.remove(client_id);
|
||||
info!("Client {} disconnected", client_id);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a rate limit for a specific client.
|
||||
pub async fn set_client_rate_limit(
|
||||
&self,
|
||||
client_id: &str,
|
||||
rate_bytes_per_sec: u64,
|
||||
burst_bytes: u64,
|
||||
) -> Result<()> {
|
||||
if let Some(ref state) = self.state {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
if let Some(limiter) = limiters.get_mut(client_id) {
|
||||
limiter.update_limits(rate_bytes_per_sec, burst_bytes);
|
||||
} else {
|
||||
limiters.insert(
|
||||
client_id.to_string(),
|
||||
TokenBucket::new(rate_bytes_per_sec, burst_bytes),
|
||||
);
|
||||
}
|
||||
// Update client info
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(client_id) {
|
||||
info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec);
|
||||
info.burst_bytes = Some(burst_bytes);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove rate limit for a specific client (unlimited).
|
||||
pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> {
|
||||
if let Some(ref state) = self.state {
|
||||
state.rate_limiters.lock().await.remove(client_id);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(client_id) {
|
||||
info.rate_limit_bytes_per_sec = None;
|
||||
info.burst_bytes = None;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_listener(
|
||||
@@ -257,25 +323,43 @@ async fn handle_client_connection(
|
||||
let mut noise_transport = responder.into_transport_mode()?;
|
||||
|
||||
// Register client
|
||||
let default_rate = state.config.default_rate_limit_bytes_per_sec;
|
||||
let default_burst = state.config.default_burst_bytes;
|
||||
let client_info = ClientInfo {
|
||||
client_id: client_id.clone(),
|
||||
assigned_ip: assigned_ip.to_string(),
|
||||
connected_since: timestamp_now(),
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
packets_dropped: 0,
|
||||
bytes_dropped: 0,
|
||||
last_keepalive_at: None,
|
||||
keepalives_received: 0,
|
||||
rate_limit_bytes_per_sec: default_rate,
|
||||
burst_bytes: default_burst,
|
||||
};
|
||||
state.clients.write().await.insert(client_id.clone(), client_info);
|
||||
|
||||
// Set up rate limiter if defaults are configured
|
||||
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
|
||||
state
|
||||
.rate_limiters
|
||||
.lock()
|
||||
.await
|
||||
.insert(client_id.clone(), TokenBucket::new(rate, burst));
|
||||
}
|
||||
|
||||
{
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.total_connections += 1;
|
||||
}
|
||||
|
||||
// Send assigned IP info (encrypted)
|
||||
// Send assigned IP info (encrypted), include effective MTU
|
||||
let ip_info = serde_json::json!({
|
||||
"assignedIp": assigned_ip.to_string(),
|
||||
"gateway": state.ip_pool.lock().await.gateway_addr().to_string(),
|
||||
"mtu": state.config.mtu.unwrap_or(1420),
|
||||
"effectiveMtu": state.mtu_config.effective_mtu,
|
||||
});
|
||||
let ip_info_bytes = serde_json::to_vec(&ip_info)?;
|
||||
let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?;
|
||||
@@ -289,66 +373,116 @@ async fn handle_client_connection(
|
||||
|
||||
info!("Client {} connected with IP {}", client_id, assigned_ip);
|
||||
|
||||
// Main packet loop
|
||||
// Main packet loop with dead-peer detection
|
||||
let mut last_activity = tokio::time::Instant::now();
|
||||
|
||||
loop {
|
||||
match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.bytes_received += len as u64;
|
||||
stats.packets_received += 1;
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
// Rate limiting check
|
||||
let allowed = {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
if let Some(limiter) = limiters.get_mut(&client_id) {
|
||||
limiter.try_consume(len)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
if !allowed {
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.packets_dropped += 1;
|
||||
info.bytes_dropped += len as u64;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.bytes_received += len as u64;
|
||||
stats.packets_received += 1;
|
||||
|
||||
// Update per-client stats
|
||||
drop(stats);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.bytes_received += len as u64;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Decrypt error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Decrypt error from {}: {}", client_id, e);
|
||||
PacketType::Keepalive => {
|
||||
// Echo the keepalive payload back in the ACK
|
||||
let ack_frame = Frame {
|
||||
packet_type: PacketType::KeepaliveAck,
|
||||
payload: frame.payload.clone(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
stats.keepalives_sent += 1;
|
||||
|
||||
// Update per-client keepalive tracking
|
||||
drop(stats);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.last_keepalive_at = Some(timestamp_now());
|
||||
info.keepalives_received += 1;
|
||||
}
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Client {} sent disconnect", client_id);
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
warn!("Incomplete frame from {}", client_id);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Frame decode error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
PacketType::Keepalive => {
|
||||
let ack_frame = Frame {
|
||||
packet_type: PacketType::KeepaliveAck,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
stats.keepalives_sent += 1;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Client {} sent disconnect", client_id);
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
warn!("Incomplete frame from {}", client_id);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Frame decode error from {}: {}", client_id, e);
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
ws_sink.send(Message::Pong(data)).await?;
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
continue;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("WebSocket error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
ws_sink.send(Message::Pong(data)).await?;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
warn!("WebSocket error from {}: {}", client_id, e);
|
||||
_ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => {
|
||||
warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -357,6 +491,7 @@ async fn handle_client_connection(
|
||||
// Cleanup
|
||||
state.clients.write().await.remove(&client_id);
|
||||
state.ip_pool.lock().await.release(&assigned_ip);
|
||||
state.rate_limiters.lock().await.remove(&client_id);
|
||||
info!("Client {} disconnected, released IP {}", client_id, assigned_ip);
|
||||
|
||||
Ok(())
|
||||
|
||||
317
rust/src/telemetry.rs
Normal file
317
rust/src/telemetry.rs
Normal file
@@ -0,0 +1,317 @@
|
||||
use serde::Serialize;
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// A single RTT sample.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RttSample {
|
||||
_rtt: Duration,
|
||||
_timestamp: Instant,
|
||||
was_timeout: bool,
|
||||
}
|
||||
|
||||
/// Snapshot of connection quality metrics.
|
||||
#[derive(Debug, Clone, Serialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConnectionQuality {
|
||||
/// Smoothed RTT in milliseconds (EMA, RFC 6298 style).
|
||||
pub srtt_ms: f64,
|
||||
/// Jitter in milliseconds (mean deviation of RTT).
|
||||
pub jitter_ms: f64,
|
||||
/// Minimum RTT observed in the sample window.
|
||||
pub min_rtt_ms: f64,
|
||||
/// Maximum RTT observed in the sample window.
|
||||
pub max_rtt_ms: f64,
|
||||
/// Packet loss ratio over the sample window (0.0 - 1.0).
|
||||
pub loss_ratio: f64,
|
||||
/// Number of consecutive keepalive timeouts (0 if last succeeded).
|
||||
pub consecutive_timeouts: u32,
|
||||
/// Total keepalives sent.
|
||||
pub keepalives_sent: u64,
|
||||
/// Total keepalive ACKs received.
|
||||
pub keepalives_acked: u64,
|
||||
}
|
||||
|
||||
/// Tracks connection quality from keepalive round-trips.
|
||||
pub struct RttTracker {
|
||||
/// Maximum number of samples to keep in the window.
|
||||
max_samples: usize,
|
||||
/// Recent RTT samples (including timeout markers).
|
||||
samples: VecDeque<RttSample>,
|
||||
/// When the last keepalive was sent (for computing RTT on ACK).
|
||||
pending_ping_sent_at: Option<Instant>,
|
||||
/// Number of consecutive keepalive timeouts.
|
||||
pub consecutive_timeouts: u32,
|
||||
/// Smoothed RTT (EMA).
|
||||
srtt: Option<f64>,
|
||||
/// Jitter (mean deviation).
|
||||
jitter: f64,
|
||||
/// Minimum RTT observed.
|
||||
min_rtt: f64,
|
||||
/// Maximum RTT observed.
|
||||
max_rtt: f64,
|
||||
/// Total keepalives sent.
|
||||
keepalives_sent: u64,
|
||||
/// Total keepalive ACKs received.
|
||||
keepalives_acked: u64,
|
||||
/// Previous RTT sample for jitter calculation.
|
||||
last_rtt_ms: Option<f64>,
|
||||
}
|
||||
|
||||
impl RttTracker {
|
||||
/// Create a new tracker with the given window size.
|
||||
pub fn new(max_samples: usize) -> Self {
|
||||
Self {
|
||||
max_samples,
|
||||
samples: VecDeque::with_capacity(max_samples),
|
||||
pending_ping_sent_at: None,
|
||||
consecutive_timeouts: 0,
|
||||
srtt: None,
|
||||
jitter: 0.0,
|
||||
min_rtt: f64::MAX,
|
||||
max_rtt: 0.0,
|
||||
keepalives_sent: 0,
|
||||
keepalives_acked: 0,
|
||||
last_rtt_ms: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record that a keepalive was sent.
|
||||
/// Returns a millisecond timestamp (since UNIX epoch) to embed in the keepalive payload.
|
||||
pub fn mark_ping_sent(&mut self) -> u64 {
|
||||
self.pending_ping_sent_at = Some(Instant::now());
|
||||
self.keepalives_sent += 1;
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
/// Record that a keepalive ACK was received with the echoed timestamp.
|
||||
/// Returns the computed RTT if a pending ping was recorded.
|
||||
pub fn record_ack(&mut self, _echoed_timestamp_ms: u64) -> Option<Duration> {
|
||||
let sent_at = self.pending_ping_sent_at.take()?;
|
||||
let rtt = sent_at.elapsed();
|
||||
let rtt_ms = rtt.as_secs_f64() * 1000.0;
|
||||
|
||||
self.keepalives_acked += 1;
|
||||
self.consecutive_timeouts = 0;
|
||||
|
||||
// Update SRTT (RFC 6298: alpha = 1/8)
|
||||
match self.srtt {
|
||||
None => {
|
||||
self.srtt = Some(rtt_ms);
|
||||
self.jitter = rtt_ms / 2.0;
|
||||
}
|
||||
Some(prev_srtt) => {
|
||||
// RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| (beta = 1/4)
|
||||
self.jitter = 0.75 * self.jitter + 0.25 * (prev_srtt - rtt_ms).abs();
|
||||
// SRTT = (1 - alpha) * SRTT + alpha * R (alpha = 1/8)
|
||||
self.srtt = Some(0.875 * prev_srtt + 0.125 * rtt_ms);
|
||||
}
|
||||
}
|
||||
|
||||
// Update min/max
|
||||
if rtt_ms < self.min_rtt {
|
||||
self.min_rtt = rtt_ms;
|
||||
}
|
||||
if rtt_ms > self.max_rtt {
|
||||
self.max_rtt = rtt_ms;
|
||||
}
|
||||
|
||||
self.last_rtt_ms = Some(rtt_ms);
|
||||
|
||||
// Push sample into window
|
||||
if self.samples.len() >= self.max_samples {
|
||||
self.samples.pop_front();
|
||||
}
|
||||
self.samples.push_back(RttSample {
|
||||
_rtt: rtt,
|
||||
_timestamp: Instant::now(),
|
||||
was_timeout: false,
|
||||
});
|
||||
|
||||
Some(rtt)
|
||||
}
|
||||
|
||||
/// Record that a keepalive timed out (no ACK received).
|
||||
pub fn record_timeout(&mut self) {
|
||||
self.consecutive_timeouts += 1;
|
||||
self.pending_ping_sent_at = None;
|
||||
|
||||
if self.samples.len() >= self.max_samples {
|
||||
self.samples.pop_front();
|
||||
}
|
||||
self.samples.push_back(RttSample {
|
||||
_rtt: Duration::ZERO,
|
||||
_timestamp: Instant::now(),
|
||||
was_timeout: true,
|
||||
});
|
||||
}
|
||||
|
||||
/// Get a snapshot of the current connection quality.
|
||||
pub fn snapshot(&self) -> ConnectionQuality {
|
||||
let loss_ratio = if self.samples.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
let timeouts = self.samples.iter().filter(|s| s.was_timeout).count();
|
||||
timeouts as f64 / self.samples.len() as f64
|
||||
};
|
||||
|
||||
ConnectionQuality {
|
||||
srtt_ms: self.srtt.unwrap_or(0.0),
|
||||
jitter_ms: self.jitter,
|
||||
min_rtt_ms: if self.min_rtt == f64::MAX { 0.0 } else { self.min_rtt },
|
||||
max_rtt_ms: self.max_rtt,
|
||||
loss_ratio,
|
||||
consecutive_timeouts: self.consecutive_timeouts,
|
||||
keepalives_sent: self.keepalives_sent,
|
||||
keepalives_acked: self.keepalives_acked,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_tracker_has_zero_quality() {
|
||||
let tracker = RttTracker::new(30);
|
||||
let q = tracker.snapshot();
|
||||
assert_eq!(q.srtt_ms, 0.0);
|
||||
assert_eq!(q.jitter_ms, 0.0);
|
||||
assert_eq!(q.loss_ratio, 0.0);
|
||||
assert_eq!(q.consecutive_timeouts, 0);
|
||||
assert_eq!(q.keepalives_sent, 0);
|
||||
assert_eq!(q.keepalives_acked, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mark_ping_returns_timestamp() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
let ts = tracker.mark_ping_sent();
|
||||
// Should be a reasonable epoch-ms value (after 2020)
|
||||
assert!(ts > 1_577_836_800_000);
|
||||
assert_eq!(tracker.keepalives_sent, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_ack_computes_rtt() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
let rtt = tracker.record_ack(ts);
|
||||
assert!(rtt.is_some());
|
||||
let rtt = rtt.unwrap();
|
||||
assert!(rtt.as_millis() >= 4); // at least ~5ms minus scheduling jitter
|
||||
assert_eq!(tracker.keepalives_acked, 1);
|
||||
assert_eq!(tracker.consecutive_timeouts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_ack_without_pending_returns_none() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
assert!(tracker.record_ack(12345).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn srtt_converges() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
// Simulate several ping/ack cycles with ~10ms RTT
|
||||
for _ in 0..10 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
|
||||
let q = tracker.snapshot();
|
||||
// SRTT should be roughly 10ms (allowing for scheduling variance)
|
||||
assert!(q.srtt_ms > 5.0, "SRTT too low: {}", q.srtt_ms);
|
||||
assert!(q.srtt_ms < 50.0, "SRTT too high: {}", q.srtt_ms);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timeout_increments_counter_and_loss() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 1);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 2);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert_eq!(q.loss_ratio, 1.0); // 2 timeouts out of 2 samples
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ack_resets_consecutive_timeouts() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 1);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
assert_eq!(tracker.consecutive_timeouts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loss_ratio_over_mixed_window() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
// 3 successful, 1 timeout, 1 successful = 1/5 = 0.2 loss
|
||||
for _ in 0..3 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert!((q.loss_ratio - 0.2).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn window_evicts_old_samples() {
|
||||
let mut tracker = RttTracker::new(5);
|
||||
|
||||
// Fill window with 5 timeouts
|
||||
for _ in 0..5 {
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
}
|
||||
assert_eq!(tracker.snapshot().loss_ratio, 1.0);
|
||||
|
||||
// Add 5 successes — should evict all timeouts
|
||||
for _ in 0..5 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
assert_eq!(tracker.snapshot().loss_ratio, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_max_rtt_tracked() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(15));
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert!(q.min_rtt_ms < q.max_rtt_ms);
|
||||
assert!(q.min_rtt_ms > 0.0);
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,22 @@ pub async fn add_route(subnet: &str, device_name: &str) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Action to take after checking a packet against the MTU.
|
||||
pub enum TunMtuAction {
|
||||
/// Packet is within MTU limits, forward it.
|
||||
Forward,
|
||||
/// Packet is oversized; the Vec contains the ICMP too-big message to write back into TUN.
|
||||
IcmpTooBig(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Check a TUN packet against the MTU and return the appropriate action.
|
||||
pub fn check_tun_mtu(packet: &[u8], mtu_config: &crate::mtu::MtuConfig) -> TunMtuAction {
|
||||
match crate::mtu::check_mtu(packet, mtu_config) {
|
||||
crate::mtu::MtuAction::Forward => TunMtuAction::Forward,
|
||||
crate::mtu::MtuAction::SendIcmpTooBig(icmp) => TunMtuAction::IcmpTooBig(icmp),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a route.
|
||||
pub async fn remove_route(subnet: &str, device_name: &str) -> Result<()> {
|
||||
let output = tokio::process::Command::new("ip")
|
||||
|
||||
Reference in New Issue
Block a user