6 Commits

19 changed files with 3004 additions and 169 deletions


@@ -1,5 +1,27 @@
# Changelog
## 2026-03-17 - 1.3.0 - feat(tests,client)
add flow control and load test coverage and honor configured keepalive intervals
- Adds end-to-end node tests for client/server flow control, keepalive exchange, connection quality telemetry, rate limiting, concurrent clients, and disconnect tracking.
- Adds load testing with throttled proxy scenarios to validate behavior under constrained bandwidth and repeated client churn.
- Updates the Rust client to pass configured keepaliveIntervalSecs into the adaptive keepalive monitor instead of always using defaults.
## 2026-03-15 - 1.2.0 - feat(readme)
document QoS, telemetry, MTU, and rate limiting capabilities in the README
- Expand the architecture and feature overview to cover adaptive keepalive, telemetry, QoS, rate limiting, and MTU handling
- Update client and server examples to show new APIs such as getConnectionQuality(), getMtuInfo(), setClientRateLimit(), and getClientTelemetry()
- Add TypeScript interface documentation for connection quality, MTU info, enriched client statistics, and per-client telemetry
## 2026-03-15 - 1.1.0 - feat(rust-core)
add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs
- adds adaptive keepalive monitoring with RTT, jitter, loss, and link health reporting to client statistics and management endpoints
- introduces MTU overhead calculation and oversized-packet handling support, plus client MTU info APIs
- adds token-bucket rate limiting with configurable default limits and server management commands to set, remove, and inspect per-client telemetry
- extends TypeScript client and server interfaces with connection quality, MTU, and client telemetry methods
## 2026-02-27 - 1.0.3 - fix(build)
add aarch64 linker configuration for cross-compilation


@@ -1,6 +1,6 @@
{
  "name": "@push.rocks/smartvpn",
-  "version": "1.0.3",
+  "version": "1.3.0",
  "private": false,
  "description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
  "type": "module",

readme.md

@@ -1,6 +1,6 @@
# @push.rocks/smartvpn

-A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, typed APIs while all networking heavy lifting — encryption, tunneling, packet forwarding — runs at native speed in Rust.
+A high-performance VPN with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, fully-typed APIs while all networking heavy lifting — encryption, tunneling, QoS, rate limiting — runs at native speed in Rust.

## Issue Reporting and Security
@@ -9,8 +9,6 @@ For reporting bugs, issues, or security vulnerabilities, please visit [community
## Install

```bash
-npm install @push.rocks/smartvpn
-# or
pnpm install @push.rocks/smartvpn
```
@@ -18,17 +16,21 @@ pnpm install @push.rocks/smartvpn
```
TypeScript (control plane)             Rust (data plane)
┌──────────────────────────┐           ┌──────────────────────────────────┐
│ VpnClient / VpnServer    │           │ smartvpn_daemon                  │
│  └─ VpnBridge            │──stdio/──▶│ ├─ management (JSON IPC)         │
│      └─ RustBridge       │  socket   │ ├─ transport (WebSocket/TLS)     │
│         (smartrust)      │           │ ├─ crypto (Noise NK + XChaCha20) │
└──────────────────────────┘           │ ├─ codec (binary framing)        │
                                       │ ├─ keepalive (adaptive state FSM)│
                                       │ ├─ telemetry (RTT/jitter/loss)   │
                                       │ ├─ qos (classify + priority Q)   │
                                       │ ├─ ratelimit (token bucket)      │
                                       │ ├─ mtu (overhead calc + ICMP)    │
                                       │ ├─ tunnel (TUN device)           │
                                       │ ├─ network (NAT/IP pool)         │
                                       │ └─ reconnect (exp. backoff)      │
                                       └──────────────────────────────────┘
```
**Key design decisions:**
@@ -37,8 +39,10 @@ TypeScript (control plane) Rust (data plane)
|----------|--------|-----|
| Transport | WebSocket over HTTPS | Works through Cloudflare and other terminating proxies |
| Encryption | Noise NK + XChaCha20-Poly1305 | Strong forward secrecy, large nonce space (no counter needed) |
-| Keepalive | App-level (not WS pings) | Cloudflare drops WS ping frames; app-level pings survive |
-| IPC | JSON lines over stdio/Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
+| Keepalive | Adaptive app-level pings | Cloudflare drops WS pings; interval adapts to link health (10–60s) |
+| QoS | Packet classification + priority queues | DNS/SSH/ICMP always drain first; bulk flows get deprioritized |
+| Rate limiting | Per-client token bucket | Byte-granular, dynamically reconfigurable via IPC |
+| IPC | JSON lines over stdio / Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
| Binary protocol | `[type:1B][length:4B][payload:NB]` | Minimal overhead, easy to parse at wire speed |
## 🚀 Quick Start
@@ -48,15 +52,12 @@ TypeScript (control plane) Rust (data plane)
```typescript
import { VpnClient } from '@push.rocks/smartvpn';

-// Development: spawn the Rust daemon as a child process
const client = new VpnClient({
  transport: { transport: 'stdio' },
});

-// Start the daemon bridge
await client.start();

-// Connect to a VPN server
const { assignedIp } = await client.connect({
  serverUrl: 'wss://vpn.example.com/tunnel',
  serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
@@ -67,15 +68,23 @@ const { assignedIp } = await client.connect({
console.log(`Connected! Assigned IP: ${assignedIp}`);

-// Check status
+// Connection quality (adaptive keepalive + telemetry)
-const status = await client.getStatus();
+const quality = await client.getConnectionQuality();
-console.log(status); // { state: 'connected', assignedIp: '10.8.0.2', ... }
+console.log(quality);
// {
//   srttMs: 42.5, jitterMs: 3.2, minRttMs: 38.0, maxRttMs: 67.0,
//   lossRatio: 0.0, consecutiveTimeouts: 0,
//   linkHealth: 'healthy', currentKeepaliveIntervalSecs: 60
// }

-// Get traffic stats
+// MTU info
const mtu = await client.getMtuInfo();
console.log(mtu);
// { tunMtu: 1420, effectiveMtu: 1421, linkMtu: 1500, overheadBytes: 79, ... }

// Traffic stats (includes quality snapshot)
const stats = await client.getStatistics();
-console.log(stats); // { bytesSent, bytesReceived, packetsSent, ... }

-// Disconnect
await client.disconnect();
client.stop();
```
@@ -89,33 +98,44 @@ const server = new VpnServer({
  transport: { transport: 'stdio' },
});

-// Start the daemon and the VPN server
+// Generate a Noise keypair first
-await server.start();
// If you don't have keys yet:
const keypair = await server.generateKeypair();

// Start the VPN listener (or pass config to start() directly)
await server.start({
  listenAddr: '0.0.0.0:443',
-  privateKey: 'BASE64_PRIVATE_KEY',
+  privateKey: keypair.privateKey,
-  publicKey: 'BASE64_PUBLIC_KEY',
+  publicKey: keypair.publicKey,
  subnet: '10.8.0.0/24',
  dns: ['1.1.1.1'],
  mtu: 1420,
  enableNat: true,
  // Optional: default rate limit for all new clients
  defaultRateLimitBytesPerSec: 10_000_000, // 10 MB/s
  defaultBurstBytes: 20_000_000,           // 20 MB burst
});

-// Generate a Noise keypair
-const keypair = await server.generateKeypair();
-console.log(keypair); // { publicKey: '...', privateKey: '...' }

// List connected clients
const clients = await server.listClients();
-// [{ clientId, assignedIp, connectedSince, bytesSent, bytesReceived }]

-// Disconnect a specific client
+// Per-client rate limiting (live, no reconnect needed)
-await server.disconnectClient('some-client-id');
+await server.setClientRateLimit('client-id', 5_000_000, 10_000_000);
await server.removeClientRateLimit('client-id'); // unlimited

-// Get server stats
+// Per-client telemetry
-const stats = await server.getStatistics();
+const telemetry = await server.getClientTelemetry('client-id');
-// { bytesSent, bytesReceived, activeClients, totalConnections, ... }
+console.log(telemetry);
// {
//   clientId, assignedIp, lastKeepaliveAt, keepalivesReceived,
//   packetsDropped, bytesDropped, bytesReceived, bytesSent,
//   rateLimitBytesPerSec, burstBytes
// }

// Kick a client
await server.disconnectClient('client-id');

// Stop
await server.stopServer();
server.stop();
```
@@ -151,7 +171,9 @@ When using socket transport, `client.stop()` closes the socket but **does not ki
| `connect(config?)` | `Promise<{ assignedIp }>` | Connect to VPN server |
| `disconnect()` | `Promise<void>` | Disconnect from VPN |
| `getStatus()` | `Promise<IVpnStatus>` | Current connection state |
-| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic statistics |
+| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic stats + connection quality |
| `getConnectionQuality()` | `Promise<IVpnConnectionQuality>` | RTT, jitter, loss, link health |
| `getMtuInfo()` | `Promise<IVpnMtuInfo>` | MTU info and overhead breakdown |
| `stop()` | `void` | Kill/close the daemon bridge |
| `running` | `boolean` | Whether bridge is active |
@@ -163,9 +185,12 @@ When using socket transport, `client.stop()` closes the socket but **does not ki
| `stopServer()` | `Promise<void>` | Stop the VPN server |
| `getStatus()` | `Promise<IVpnStatus>` | Server connection state |
| `getStatistics()` | `Promise<IVpnServerStatistics>` | Server stats (includes client counts) |
-| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients |
+| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients with QoS stats |
| `disconnectClient(id)` | `Promise<void>` | Kick a client |
| `generateKeypair()` | `Promise<IVpnKeypair>` | Generate Noise NK keypair |
| `setClientRateLimit(id, rate, burst)` | `Promise<void>` | Set per-client rate limit (bytes/sec) |
| `removeClientRateLimit(id)` | `Promise<void>` | Remove rate limit (unlimited) |
| `getClientTelemetry(id)` | `Promise<IVpnClientTelemetry>` | Per-client telemetry + drop stats |
| `stop()` | `void` | Kill/close the daemon bridge |
### `VpnConfig`
@@ -191,26 +216,23 @@ Generate system service units for the daemon:
```typescript
import { VpnInstaller } from '@push.rocks/smartvpn';

-// Auto-detect platform
const platform = VpnInstaller.detectPlatform(); // 'linux' | 'macos' | 'windows' | 'unknown'

-// Generate systemd unit (Linux)
+// Linux (systemd)
const unit = VpnInstaller.generateSystemdUnit({
  binaryPath: '/usr/local/bin/smartvpn_daemon',
  socketPath: '/var/run/smartvpn.sock',
  mode: 'server',
});
-// unit.content = full systemd .service file
-// unit.installPath = '/etc/systemd/system/smartvpn-server.service'

-// Generate launchd plist (macOS)
+// macOS (launchd)
const plist = VpnInstaller.generateLaunchdPlist({
  binaryPath: '/usr/local/bin/smartvpn_daemon',
  socketPath: '/var/run/smartvpn.sock',
  mode: 'client',
});

-// Auto-detect and generate
+// Auto-detect platform
const serviceUnit = VpnInstaller.generateServiceUnit({
  binaryPath: '/usr/local/bin/smartvpn_daemon',
  socketPath: '/var/run/smartvpn.sock',
@@ -223,8 +245,6 @@ const serviceUnit = VpnInstaller.generateServiceUnit({
Both `VpnClient` and `VpnServer` extend `EventEmitter`:

```typescript
-client.on('status', (status) => { /* IVpnStatus */ });
-client.on('error', (err) => { /* { message, code? } */ });
client.on('exit', ({ code, signal }) => { /* daemon exited */ });
client.on('reconnected', () => { /* socket reconnected */ });
@@ -232,13 +252,84 @@ server.on('client-connected', (info) => { /* IVpnClientInfo */ });
server.on('client-disconnected', ({ clientId, reason }) => { /* ... */ });
```
## 📊 QoS System
The Rust daemon includes a full QoS stack that operates on decrypted IP packets:
### Adaptive Keepalive
The keepalive system automatically adjusts its interval based on connection quality:
| Link Health | Keepalive Interval | Triggered When |
|-------------|-------------------|----------------|
| 🟢 Healthy | 60s | Jitter < 30ms, loss < 2%, no timeouts |
| 🟡 Degraded | 30s | Jitter > 50ms, loss > 5%, or 1+ timeout |
| 🔴 Critical | 10s | Loss > 20% or 2+ consecutive timeouts |
State transitions include hysteresis (3 consecutive good checks to upgrade, 2 to recover) to prevent flapping. Dead peer detection fires after 3 consecutive timeouts in Critical state.
### Packet Classification
IP packets are classified into three priority levels by inspecting headers (no deep packet inspection):
| Priority | Traffic |
|----------|---------|
| **High** | ICMP, DNS (port 53), SSH (port 22), small packets (< 128 bytes) |
| **Normal** | Everything else |
| **Low** | Bulk flows exceeding 1 MB within a 60s window |
Priority channels drain with biased `tokio::select!` — high-priority packets always go first.
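The classification rules boil down to a few header checks. A minimal sketch, assuming hypothetical `PacketMeta` fields rather than the daemon's real types:

```typescript
// Illustrative header-only classifier matching the table above.
type Priority = 'high' | 'normal' | 'low';

interface PacketMeta {
  protocol: 'tcp' | 'udp' | 'icmp' | 'other';
  srcPort?: number;
  dstPort?: number;
  lengthBytes: number;
}

const BULK_THRESHOLD_BYTES = 1_000_000; // 1 MB within the 60s window

function classify(pkt: PacketMeta, flowBytesIn60s: number): Priority {
  const ports = [pkt.srcPort, pkt.dstPort];
  if (
    pkt.protocol === 'icmp' ||
    ports.includes(53) ||   // DNS
    ports.includes(22) ||   // SSH
    pkt.lengthBytes < 128   // small packets (ACKs, handshakes)
  ) {
    return 'high';
  }
  if (flowBytesIn60s > BULK_THRESHOLD_BYTES) return 'low';
  return 'normal';
}
```

Because only protocol, ports, and size are inspected, classification stays O(1) per packet and never touches encrypted payloads.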
### Smart Packet Dropping
Under backpressure, packets are dropped intelligently:
1. **Low** queue full → drop silently
2. **Normal** queue full → drop
3. **High** queue full → wait 5ms, then drop as last resort
Drop statistics are tracked per priority level and exposed via telemetry.
### Per-Client Rate Limiting
Token bucket algorithm with byte granularity:
```typescript
// Set: 10 MB/s sustained, 20 MB burst
await server.setClientRateLimit('client-id', 10_000_000, 20_000_000);
// Check drops via telemetry
const t = await server.getClientTelemetry('client-id');
console.log(`Dropped: ${t.packetsDropped} packets, ${t.bytesDropped} bytes`);
// Remove limit
await server.removeClientRateLimit('client-id');
```
Rate limits can be changed live without disconnecting the client.
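The semantics above can be sketched as a classic token bucket; this is an illustrative model of the behavior, not the daemon's Rust implementation:

```typescript
// Minimal token-bucket sketch: sustained rate in bytes/sec, burst-sized bucket.
class TokenBucket {
  private tokens: number;
  private lastRefillMs: number;

  constructor(
    private ratePerSec: number, // sustained bytes/sec
    private burst: number,      // bucket capacity in bytes
    nowMs: number = Date.now(),
  ) {
    this.tokens = burst; // start full: a fresh client may burst immediately
    this.lastRefillMs = nowMs;
  }

  /** Returns true if a packet of `bytes` may pass, false if it should be dropped. */
  tryConsume(bytes: number, nowMs: number = Date.now()): boolean {
    const elapsedSec = (nowMs - this.lastRefillMs) / 1000;
    this.tokens = Math.min(this.burst, this.tokens + elapsedSec * this.ratePerSec);
    this.lastRefillMs = nowMs;
    if (bytes <= this.tokens) {
      this.tokens -= bytes;
      return true;
    }
    return false; // over limit: counted as packetsDropped / bytesDropped
  }
}
```

Refilling lazily on each check (rather than on a timer) is what makes byte-granular limits cheap, and swapping `ratePerSec`/`burst` at runtime is all a live reconfiguration needs.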
### Path MTU
Tunnel overhead is calculated precisely:
| Layer | Bytes |
|-------|-------|
| IP header | 20 |
| TCP header (with timestamps) | 32 |
| WebSocket framing | 6 |
| VPN frame header | 5 |
| Noise AEAD tag | 16 |
| **Total overhead** | **79** |
For a standard 1500-byte Ethernet link, effective TUN MTU = **1421 bytes**. The default TUN MTU of 1420 is conservative and correct. Oversized packets get an ICMP "Fragmentation Needed" (Type 3, Code 4) written back into the TUN, so the source TCP adjusts its MSS automatically.
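The arithmetic is simple enough to restate in a few lines. The constants come straight from the table above; the helper name is illustrative:

```typescript
// Per-layer tunnel overhead in bytes, as listed in the table above.
const OVERHEAD = {
  ipHeader: 20,  // IPv4, no options
  tcpHeader: 32, // 20 base + 12 for the timestamps option
  wsFraming: 6,  // WebSocket frame header
  vpnFrame: 5,   // [type:1B][length:4B]
  noiseTag: 16,  // Poly1305 AEAD tag
};

function effectiveTunMtu(linkMtu: number): number {
  const total = Object.values(OVERHEAD).reduce((a, b) => a + b, 0); // 79
  return linkMtu - total;
}
```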
## 🔐 Security Model

The VPN uses a **Noise NK** handshake pattern:

1. **NK** = client does **N**ot authenticate, but **K**nows the server's static public key
-2. The client generates an ephemeral keypair, performs `e, es` (Diffie-Hellman with server's static key)
+2. The client generates an ephemeral keypair, performs `e, es` (DH with server's static key)
-3. Server responds with `e, ee` (Diffie-Hellman with both ephemeral keys)
+3. Server responds with `e, ee` (DH with both ephemeral keys)
4. Result: forward-secret transport keys derived from both DH operations

Post-handshake, all IP packets are encrypted with **XChaCha20-Poly1305**:
@@ -261,8 +352,8 @@ Inside the WebSocket tunnel, packets use a simple binary framing:
| `HandshakeInit` | `0x01` | Client → Server handshake |
| `HandshakeResp` | `0x02` | Server → Client handshake |
| `IpPacket` | `0x10` | Encrypted IP packet |
-| `Keepalive` | `0x20` | App-level ping |
+| `Keepalive` | `0x20` | App-level ping (8-byte timestamp payload) |
-| `KeepaliveAck` | `0x21` | App-level pong |
+| `KeepaliveAck` | `0x21` | App-level pong (echoes timestamp for RTT) |
| `SessionResume` | `0x30` | Resume a dropped session |
| `SessionResumeOk` | `0x31` | Resume accepted |
| `SessionResumeErr` | `0x32` | Resume rejected |
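Encoding the `[type:1B][length:4B][payload:NB]` layout is straightforward; here is a hedged sketch (the byte order of the length field is an assumption, big-endian being the usual network convention):

```typescript
// Sketch of the binary framing described above: [type:1B][length:4B][payload:NB].
function encodeFrame(type: number, payload: Uint8Array): Uint8Array {
  const out = new Uint8Array(5 + payload.length);
  const view = new DataView(out.buffer);
  view.setUint8(0, type);
  view.setUint32(1, payload.length, false); // assumed big-endian length
  out.set(payload, 5);
  return out;
}

function decodeFrame(buf: Uint8Array): { type: number; payload: Uint8Array } {
  const view = new DataView(buf.buffer, buf.byteOffset, buf.byteLength);
  const type = view.getUint8(0);
  const length = view.getUint32(1, false);
  return { type, payload: buf.slice(5, 5 + length) };
}

// A Keepalive (0x20) would carry an 8-byte timestamp for RTT measurement:
const ts = new Uint8Array(8);
new DataView(ts.buffer).setBigUint64(0, BigInt(Date.now()), false);
const keepalive = encodeFrame(0x20, ts);
```

With a fixed 5-byte header, a receiver can read the length field once and know exactly how many bytes to pull off the wire, which is what makes parsing cheap at line rate.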
@@ -270,8 +361,6 @@ Inside the WebSocket tunnel, packets use a simple binary framing:
## 🛠️ Rust Daemon CLI

-The Rust binary supports several modes:

```bash
# Development: stdio management (JSON lines on stdin/stdout)
smartvpn_daemon --management --mode client
@@ -290,16 +379,14 @@ smartvpn_daemon --generate-keypair
# Install dependencies
pnpm install

-# Build TypeScript + cross-compile Rust
+# Build TypeScript + cross-compile Rust (amd64 + arm64)
pnpm build

# Build Rust only (debug)
cd rust && cargo build

-# Run Rust tests
+# Run all tests (71 Rust + 32 TypeScript)
cd rust && cargo test
-
-# Run TypeScript tests
pnpm test
``` ```
@@ -323,25 +410,27 @@ type TVpnTransportOptions =
// Client config
interface IVpnClientConfig {
-  serverUrl: string; // e.g. 'wss://vpn.example.com/tunnel'
+  serverUrl: string;
-  serverPublicKey: string; // base64-encoded Noise static key
+  serverPublicKey: string;
  dns?: string[];
-  mtu?: number; // default: 1420
+  mtu?: number;
-  keepaliveIntervalSecs?: number; // default: 30
+  keepaliveIntervalSecs?: number;
}

// Server config
interface IVpnServerConfig {
-  listenAddr: string; // e.g. '0.0.0.0:443'
+  listenAddr: string;
-  privateKey: string; // base64 Noise static private key
+  privateKey: string;
-  publicKey: string; // base64 Noise static public key
+  publicKey: string;
-  subnet: string; // e.g. '10.8.0.0/24'
+  subnet: string;
  tlsCert?: string;
  tlsKey?: string;
  dns?: string[];
  mtu?: number;
  keepaliveIntervalSecs?: number;
  enableNat?: boolean;
  defaultRateLimitBytesPerSec?: number;
  defaultBurstBytes?: number;
}
// Status // Status
@@ -365,6 +454,7 @@ interface IVpnStatistics {
  keepalivesSent: number;
  keepalivesReceived: number;
  uptimeSeconds: number;
  quality?: IVpnConnectionQuality;
}

interface IVpnServerStatistics extends IVpnStatistics {
@@ -372,12 +462,57 @@ interface IVpnServerStatistics extends IVpnStatistics {
  totalConnections: number;
}
// Connection quality (QoS)
type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
interface IVpnConnectionQuality {
  srttMs: number;
  jitterMs: number;
  minRttMs: number;
  maxRttMs: number;
  lossRatio: number;
  consecutiveTimeouts: number;
  linkHealth: TVpnLinkHealth;
  currentKeepaliveIntervalSecs: number;
}
// MTU info
interface IVpnMtuInfo {
  tunMtu: number;
  effectiveMtu: number;
  linkMtu: number;
  overheadBytes: number;
  oversizedPacketsDropped: number;
  icmpTooBigSent: number;
}
// Client info (with QoS fields)
interface IVpnClientInfo {
  clientId: string;
  assignedIp: string;
  connectedSince: string;
  bytesSent: number;
  bytesReceived: number;
  packetsDropped: number;
  bytesDropped: number;
  lastKeepaliveAt?: string;
  keepalivesReceived: number;
  rateLimitBytesPerSec?: number;
  burstBytes?: number;
}
// Per-client telemetry
interface IVpnClientTelemetry {
  clientId: string;
  assignedIp: string;
  lastKeepaliveAt?: string;
  keepalivesReceived: number;
  packetsDropped: number;
  bytesDropped: number;
  bytesReceived: number;
  bytesSent: number;
  rateLimitBytesPerSec?: number;
  burstBytes?: number;
}
interface IVpnKeypair {


@@ -3,13 +3,14 @@ use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::sync::Arc;
-use std::time::Duration;
-use tokio::sync::{mpsc, RwLock};
+use tokio::sync::{mpsc, watch, RwLock};
use tokio_tungstenite::tungstenite::Message;
-use tracing::{info, error, warn};
+use tracing::{info, error, warn, debug};

use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
use crate::telemetry::ConnectionQuality;
use crate::transport;

/// Client configuration (matches TS IVpnClientConfig).
@@ -65,6 +66,8 @@ pub struct VpnClient {
    assigned_ip: Arc<RwLock<Option<String>>>,
    shutdown_tx: Option<mpsc::Sender<()>>,
    connected_since: Arc<RwLock<Option<std::time::Instant>>>,
    quality_rx: Option<watch::Receiver<ConnectionQuality>>,
    link_health: Arc<RwLock<LinkHealth>>,
}

impl VpnClient {
@@ -75,6 +78,8 @@ impl VpnClient {
            assigned_ip: Arc::new(RwLock::new(None)),
            shutdown_tx: None,
            connected_since: Arc::new(RwLock::new(None)),
            quality_rx: None,
            link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
        }
    }
@@ -93,6 +98,7 @@ impl VpnClient {
        let stats = self.stats.clone();
        let assigned_ip_ref = self.assigned_ip.clone();
        let connected_since = self.connected_since.clone();
        let link_health = self.link_health.clone();

        // Decode server public key
        let server_pub_key = base64::Engine::decode(
@@ -161,6 +167,20 @@ impl VpnClient {
        info!("Connected to VPN, assigned IP: {}", assigned_ip);

        // Create adaptive keepalive monitor (use custom interval if configured)
        let ka_config = config.keepalive_interval_secs.map(|secs| {
            let mut cfg = keepalive::AdaptiveKeepaliveConfig::default();
            cfg.degraded_interval = std::time::Duration::from_secs(secs);
            cfg.healthy_interval = std::time::Duration::from_secs(secs * 2);
            cfg.critical_interval = std::time::Duration::from_secs((secs / 3).max(1));
            cfg
        });
        let (monitor, handle) = keepalive::create_keepalive(ka_config);
        self.quality_rx = Some(handle.quality_rx);

        // Spawn the keepalive monitor
        tokio::spawn(monitor.run());

        // Spawn packet forwarding loop
        let assigned_ip_clone = assigned_ip.clone();
        tokio::spawn(client_loop(
@@ -170,7 +190,9 @@ impl VpnClient {
            state,
            stats,
            shutdown_rx,
-            config.keepalive_interval_secs.unwrap_or(30),
+            handle.signal_rx,
            handle.ack_tx,
            link_health,
        ));

        Ok(assigned_ip_clone)
@@ -184,6 +206,7 @@ impl VpnClient {
        *self.assigned_ip.write().await = None;
        *self.connected_since.write().await = None;
        *self.state.write().await = ClientState::Disconnected;
        self.quality_rx = None;

        info!("Disconnected from VPN");
        Ok(())
    }
@@ -208,13 +231,14 @@ impl VpnClient {
        status
    }

-    /// Get traffic statistics.
+    /// Get traffic statistics (includes connection quality).
    pub async fn get_statistics(&self) -> serde_json::Value {
        let stats = self.stats.read().await;
        let since = self.connected_since.read().await;
        let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
        let health = self.link_health.read().await;

-        serde_json::json!({
+        let mut result = serde_json::json!({
            "bytesSent": stats.bytes_sent,
            "bytesReceived": stats.bytes_received,
            "packetsSent": stats.packets_sent,
@@ -222,7 +246,35 @@ impl VpnClient {
            "keepalivesSent": stats.keepalives_sent,
            "keepalivesReceived": stats.keepalives_received,
            "uptimeSeconds": uptime,
-        })
+        });

        // Include connection quality if available
        if let Some(ref rx) = self.quality_rx {
            let quality = rx.borrow().clone();
            result["quality"] = serde_json::json!({
                "srttMs": quality.srtt_ms,
                "jitterMs": quality.jitter_ms,
                "minRttMs": quality.min_rtt_ms,
                "maxRttMs": quality.max_rtt_ms,
                "lossRatio": quality.loss_ratio,
                "consecutiveTimeouts": quality.consecutive_timeouts,
                "linkHealth": format!("{}", *health),
                "keepalivesSent": quality.keepalives_sent,
                "keepalivesAcked": quality.keepalives_acked,
            });
        }

        result
    }

    /// Get connection quality snapshot.
    pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
        self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
    }

    /// Get current link health.
    pub async fn get_link_health(&self) -> LinkHealth {
        *self.link_health.read().await
    }
}
@@ -234,11 +286,11 @@ async fn client_loop(
    state: Arc<RwLock<ClientState>>,
    stats: Arc<RwLock<ClientStatistics>>,
    mut shutdown_rx: mpsc::Receiver<()>,
-    keepalive_secs: u64,
+    mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
    ack_tx: mpsc::Sender<()>,
    link_health: Arc<RwLock<LinkHealth>>,
) {
    let mut buf = vec![0u8; 65535];
-    let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs));
-    keepalive_ticker.tick().await; // skip first immediate tick

    loop {
        tokio::select! {
@@ -264,6 +316,8 @@ async fn client_loop(
                    }
                    PacketType::KeepaliveAck => {
                        stats.write().await.keepalives_received += 1;
                        // Signal the keepalive monitor that ACK was received
                        let _ = ack_tx.send(()).await;
                    }
                    PacketType::Disconnect => {
                        info!("Server sent disconnect");
@@ -290,10 +344,13 @@ async fn client_loop(
                    }
                }
            }
-            _ = keepalive_ticker.tick() => {
+            signal = signal_rx.recv() => {
                match signal {
                    Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
                        // Embed the timestamp in the keepalive payload (8 bytes, big-endian)
                        let ka_frame = Frame {
                            packet_type: PacketType::Keepalive,
-                            payload: vec![],
+                            payload: timestamp_ms.to_be_bytes().to_vec(),
                        };
                        let mut frame_bytes = BytesMut::new();
                        if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
@@ -305,6 +362,21 @@ async fn client_loop(
                            stats.write().await.keepalives_sent += 1;
                        }
                    }
                    Some(KeepaliveSignal::PeerDead) => {
                        warn!("Peer declared dead by keepalive monitor");
                        *state.write().await = ClientState::Disconnected;
                        break;
                    }
                    Some(KeepaliveSignal::LinkHealthChanged(health)) => {
                        debug!("Link health changed to: {}", health);
                        *link_health.write().await = health;
                    }
                    None => {
                        // Keepalive monitor channel closed
                        break;
                    }
                }
            }
            _ = shutdown_rx.recv() => {
                // Send disconnect frame
                let dc_frame = Frame {


@@ -1,87 +1,464 @@
use std::time::Duration;
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, watch};
use tokio::time::{interval, timeout};
-use tracing::{debug, warn};
+use tracing::{debug, info, warn};
use crate::telemetry::{ConnectionQuality, RttTracker};

-/// Default keepalive interval (30 seconds).
+/// Default keepalive interval (30 seconds — used for Degraded state).
pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);

-/// Default keepalive ACK timeout (10 seconds).
+/// Default keepalive ACK timeout (5 seconds).
-pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
+pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(5);
/// Link health states for adaptive keepalive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkHealth {
    /// RTT stable, jitter low, no loss. Interval: 60s.
    Healthy,
    /// Elevated jitter or occasional loss. Interval: 30s.
    Degraded,
    /// High loss or sustained jitter spike. Interval: 10s.
    Critical,
}

impl std::fmt::Display for LinkHealth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Healthy => write!(f, "healthy"),
            Self::Degraded => write!(f, "degraded"),
            Self::Critical => write!(f, "critical"),
        }
    }
}
/// Configuration for the adaptive keepalive state machine.
#[derive(Debug, Clone)]
pub struct AdaptiveKeepaliveConfig {
/// Interval when link health is Healthy.
pub healthy_interval: Duration,
/// Interval when link health is Degraded.
pub degraded_interval: Duration,
/// Interval when link health is Critical.
pub critical_interval: Duration,
/// ACK timeout (how long to wait for ACK before declaring timeout).
pub ack_timeout: Duration,
/// Jitter threshold (ms) to enter Degraded from Healthy.
pub jitter_degraded_ms: f64,
/// Jitter threshold (ms) to return to Healthy from Degraded.
pub jitter_healthy_ms: f64,
/// Loss ratio threshold to enter Degraded.
pub loss_degraded: f64,
/// Loss ratio threshold to enter Critical.
pub loss_critical: f64,
/// Loss ratio threshold to return from Critical to Degraded.
pub loss_recover: f64,
/// Loss ratio threshold to return from Degraded to Healthy.
pub loss_healthy: f64,
/// Consecutive checks required for upward state transitions (hysteresis).
pub upgrade_checks: u32,
/// Consecutive timeouts to declare peer dead in Critical state.
pub dead_peer_timeouts: u32,
}
impl Default for AdaptiveKeepaliveConfig {
fn default() -> Self {
Self {
healthy_interval: Duration::from_secs(60),
degraded_interval: Duration::from_secs(30),
critical_interval: Duration::from_secs(10),
ack_timeout: Duration::from_secs(5),
jitter_degraded_ms: 50.0,
jitter_healthy_ms: 30.0,
loss_degraded: 0.05,
loss_critical: 0.20,
loss_recover: 0.10,
loss_healthy: 0.02,
upgrade_checks: 3,
dead_peer_timeouts: 3,
}
}
}
/// Signals from the keepalive monitor.
#[derive(Debug, Clone)]
pub enum KeepaliveSignal {
/// Time to send a keepalive ping. Contains the timestamp (ms since epoch) to embed in the payload.
SendPing(u64),
/// Peer is considered dead (no ACK received within timeout repeatedly).
PeerDead,
/// Link health state changed.
LinkHealthChanged(LinkHealth),
}
/// A keepalive monitor with adaptive interval and RTT tracking.
pub struct KeepaliveMonitor {
config: AdaptiveKeepaliveConfig,
health: LinkHealth,
rtt_tracker: RttTracker,
signal_tx: mpsc::Sender<KeepaliveSignal>,
ack_rx: mpsc::Receiver<()>,
quality_tx: watch::Sender<ConnectionQuality>,
consecutive_upgrade_checks: u32,
}
/// Handle returned to the caller to send ACKs and receive signals.
pub struct KeepaliveHandle {
pub signal_rx: mpsc::Receiver<KeepaliveSignal>,
pub ack_tx: mpsc::Sender<()>,
pub quality_rx: watch::Receiver<ConnectionQuality>,
}
/// Create an adaptive keepalive monitor and its handle.
pub fn create_keepalive(
config: Option<AdaptiveKeepaliveConfig>,
) -> (KeepaliveMonitor, KeepaliveHandle) {
let config = config.unwrap_or_default();
let (signal_tx, signal_rx) = mpsc::channel(8);
let (ack_tx, ack_rx) = mpsc::channel(8);
let (quality_tx, quality_rx) = watch::channel(ConnectionQuality::default());
let monitor = KeepaliveMonitor {
config,
health: LinkHealth::Degraded, // start in Degraded, earn Healthy
rtt_tracker: RttTracker::new(30),
signal_tx,
ack_rx,
quality_tx,
consecutive_upgrade_checks: 0,
};
let handle = KeepaliveHandle {
signal_rx,
ack_tx,
quality_rx,
};
(monitor, handle)
}
impl KeepaliveMonitor {
fn current_interval(&self) -> Duration {
match self.health {
LinkHealth::Healthy => self.config.healthy_interval,
LinkHealth::Degraded => self.config.degraded_interval,
LinkHealth::Critical => self.config.critical_interval,
}
}
/// Run the keepalive loop. Blocks until the peer is dead or channels close.
pub async fn run(mut self) {
let mut ticker = interval(self.current_interval());
ticker.tick().await; // skip first immediate tick
loop {
ticker.tick().await;
// Record ping sent, get timestamp for payload
let timestamp_ms = self.rtt_tracker.mark_ping_sent();
debug!("Sending keepalive ping (ts={})", timestamp_ms);
if self
.signal_tx
.send(KeepaliveSignal::SendPing(timestamp_ms))
.await
.is_err()
{
break; // channel closed
}
// Wait for ACK within timeout
match timeout(self.config.ack_timeout, self.ack_rx.recv()).await {
Ok(Some(())) => {
if let Some(rtt) = self.rtt_tracker.record_ack(timestamp_ms) {
debug!("Keepalive ACK received, RTT: {:?}", rtt);
}
}
Ok(None) => {
break; // channel closed
}
Err(_) => {
self.rtt_tracker.record_timeout();
warn!(
"Keepalive ACK timeout (consecutive: {})",
self.rtt_tracker.consecutive_timeouts
);
}
}
// Publish quality snapshot
let quality = self.rtt_tracker.snapshot();
let _ = self.quality_tx.send(quality.clone());
// Evaluate state transition
let new_health = self.evaluate_health(&quality);
if new_health != self.health {
info!("Link health: {} -> {}", self.health, new_health);
self.health = new_health;
self.consecutive_upgrade_checks = 0;
// Reset ticker to new interval
ticker = interval(self.current_interval());
ticker.tick().await; // skip first immediate tick
let _ = self
.signal_tx
.send(KeepaliveSignal::LinkHealthChanged(new_health))
.await;
}
// Check for dead peer in Critical state
if self.health == LinkHealth::Critical
&& self.rtt_tracker.consecutive_timeouts >= self.config.dead_peer_timeouts
{
warn!("Peer considered dead after {} consecutive timeouts in Critical state",
self.rtt_tracker.consecutive_timeouts);
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
break;
}
}
}
fn evaluate_health(&mut self, quality: &ConnectionQuality) -> LinkHealth {
match self.health {
LinkHealth::Healthy => {
// Downgrade conditions
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Critical;
}
if quality.jitter_ms > self.config.jitter_degraded_ms
|| quality.loss_ratio > self.config.loss_degraded
|| quality.consecutive_timeouts >= 1
{
self.consecutive_upgrade_checks = 0;
return LinkHealth::Degraded;
}
LinkHealth::Healthy
}
LinkHealth::Degraded => {
// Downgrade to Critical
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Critical;
}
// Upgrade to Healthy (with hysteresis)
if quality.jitter_ms < self.config.jitter_healthy_ms
&& quality.loss_ratio < self.config.loss_healthy
&& quality.consecutive_timeouts == 0
{
self.consecutive_upgrade_checks += 1;
if self.consecutive_upgrade_checks >= self.config.upgrade_checks {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Healthy;
}
} else {
self.consecutive_upgrade_checks = 0;
}
LinkHealth::Degraded
}
LinkHealth::Critical => {
// Upgrade to Degraded (with hysteresis), never directly to Healthy
if quality.loss_ratio < self.config.loss_recover
&& quality.consecutive_timeouts == 0
{
self.consecutive_upgrade_checks += 1;
if self.consecutive_upgrade_checks >= 2 {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Degraded;
}
} else {
self.consecutive_upgrade_checks = 0;
}
LinkHealth::Critical
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_config_values() {
let config = AdaptiveKeepaliveConfig::default();
assert_eq!(config.healthy_interval, Duration::from_secs(60));
assert_eq!(config.degraded_interval, Duration::from_secs(30));
assert_eq!(config.critical_interval, Duration::from_secs(10));
assert_eq!(config.ack_timeout, Duration::from_secs(5));
assert_eq!(config.dead_peer_timeouts, 3);
}
#[test]
fn link_health_display() {
assert_eq!(format!("{}", LinkHealth::Healthy), "healthy");
assert_eq!(format!("{}", LinkHealth::Degraded), "degraded");
assert_eq!(format!("{}", LinkHealth::Critical), "critical");
}
// Helper to create a monitor for unit-testing evaluate_health
fn make_test_monitor() -> KeepaliveMonitor {
let (signal_tx, _signal_rx) = mpsc::channel(8);
let (_ack_tx, ack_rx) = mpsc::channel(8);
let (quality_tx, _quality_rx) = watch::channel(ConnectionQuality::default());
KeepaliveMonitor {
config: AdaptiveKeepaliveConfig::default(),
health: LinkHealth::Degraded,
rtt_tracker: RttTracker::new(30),
signal_tx,
ack_rx,
quality_tx,
consecutive_upgrade_checks: 0,
}
}
#[test]
fn healthy_to_degraded_on_jitter() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
jitter_ms: 60.0, // > 50ms threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Degraded);
}
#[test]
fn healthy_to_degraded_on_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
loss_ratio: 0.06, // > 5% threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Degraded);
}
#[test]
fn healthy_to_critical_on_high_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
loss_ratio: 0.25, // > 20% threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Critical);
}
#[test]
fn healthy_to_critical_on_consecutive_timeouts() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
consecutive_timeouts: 2,
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Critical);
}
#[test]
fn degraded_to_healthy_requires_hysteresis() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let good_quality = ConnectionQuality {
jitter_ms: 10.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
srtt_ms: 20.0,
..Default::default()
};
// Should require 3 consecutive good checks (default upgrade_checks)
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
assert_eq!(m.consecutive_upgrade_checks, 1);
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
assert_eq!(m.consecutive_upgrade_checks, 2);
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Healthy);
}
#[test]
fn degraded_to_healthy_resets_on_bad_check() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let good = ConnectionQuality {
jitter_ms: 10.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
..Default::default()
};
let bad = ConnectionQuality {
jitter_ms: 60.0, // too high
loss_ratio: 0.0,
consecutive_timeouts: 0,
..Default::default()
};
m.evaluate_health(&good); // 1 check
m.evaluate_health(&good); // 2 checks
m.evaluate_health(&bad); // resets
assert_eq!(m.consecutive_upgrade_checks, 0);
}
#[test]
fn critical_to_degraded_requires_hysteresis() {
let mut m = make_test_monitor();
m.health = LinkHealth::Critical;
let recovering = ConnectionQuality {
loss_ratio: 0.05, // < 10% recover threshold
consecutive_timeouts: 0,
..Default::default()
};
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Critical);
assert_eq!(m.consecutive_upgrade_checks, 1);
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Degraded);
}
#[test]
fn critical_never_directly_to_healthy() {
let mut m = make_test_monitor();
m.health = LinkHealth::Critical;
let perfect = ConnectionQuality {
jitter_ms: 1.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
srtt_ms: 10.0,
..Default::default()
};
// Even with perfect quality, must go through Degraded first
m.evaluate_health(&perfect); // 1
let result = m.evaluate_health(&perfect); // 2 → Degraded
assert_eq!(result, LinkHealth::Degraded);
// Not Healthy yet
}
#[test]
fn degraded_to_critical_on_high_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let q = ConnectionQuality {
loss_ratio: 0.25,
..Default::default()
};
assert_eq!(m.evaluate_health(&q), LinkHealth::Critical);
}
#[test]
fn interval_matches_health() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
assert_eq!(m.current_interval(), Duration::from_secs(60));
m.health = LinkHealth::Degraded;
assert_eq!(m.current_interval(), Duration::from_secs(30));
m.health = LinkHealth::Critical;
assert_eq!(m.current_interval(), Duration::from_secs(10));
}
}
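The Degraded-to-Healthy promotion above relies on a consecutive-good-checks counter that resets on any bad sample. The same hysteresis pattern can be modeled in isolation; the `Hysteresis` type below is illustrative and not part of the crate's API:

```rust
// Standalone model of the promotion hysteresis in evaluate_health():
// an upgrade requires `required` consecutive good samples, and any bad
// sample resets the streak. Names here are hypothetical.
struct Hysteresis {
    required: u32,
    streak: u32,
}

impl Hysteresis {
    fn new(required: u32) -> Self {
        Self { required, streak: 0 }
    }

    /// Feed one health check; returns true when the promotion is earned.
    fn observe(&mut self, good: bool) -> bool {
        if good {
            self.streak += 1;
            if self.streak >= self.required {
                self.streak = 0;
                return true;
            }
        } else {
            self.streak = 0;
        }
        false
    }
}

fn main() {
    let mut h = Hysteresis::new(3);
    assert!(!h.observe(true));
    assert!(!h.observe(true));
    assert!(!h.observe(false)); // bad sample resets the streak
    assert!(!h.observe(true));
    assert!(!h.observe(true));
    assert!(h.observe(true)); // third consecutive good check promotes
    println!("ok");
}
```

This mirrors why `degraded_to_healthy_resets_on_bad_check` sees the counter return to 0: a single bad sample discards all accumulated progress toward the upgrade.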


@@ -11,3 +11,7 @@ pub mod network;
pub mod server;
pub mod client;
pub mod reconnect;
pub mod telemetry;
pub mod ratelimit;
pub mod qos;
pub mod mtu;


@@ -285,6 +285,39 @@ async fn handle_client_request(
let stats = vpn_client.get_statistics().await;
ManagementResponse::ok(id, stats)
}
"getConnectionQuality" => {
match vpn_client.get_connection_quality() {
Some(quality) => {
let health = vpn_client.get_link_health().await;
let interval_secs = match health {
crate::keepalive::LinkHealth::Healthy => 60,
crate::keepalive::LinkHealth::Degraded => 30,
crate::keepalive::LinkHealth::Critical => 10,
};
ManagementResponse::ok(id, serde_json::json!({
"srttMs": quality.srtt_ms,
"jitterMs": quality.jitter_ms,
"minRttMs": quality.min_rtt_ms,
"maxRttMs": quality.max_rtt_ms,
"lossRatio": quality.loss_ratio,
"consecutiveTimeouts": quality.consecutive_timeouts,
"linkHealth": format!("{}", health),
"currentKeepaliveIntervalSecs": interval_secs,
}))
}
None => ManagementResponse::ok(id, serde_json::json!(null)),
}
}
"getMtuInfo" => {
ManagementResponse::ok(id, serde_json::json!({
"tunMtu": 1420,
"effectiveMtu": crate::mtu::TunnelOverhead::default_overhead().effective_tun_mtu(1500),
"linkMtu": 1500,
"overheadBytes": crate::mtu::TunnelOverhead::default_overhead().total(),
"oversizedPacketsDropped": 0,
"icmpTooBigSent": 0,
}))
}
_ => ManagementResponse::err(id, format!("Unknown client method: {}", request.method)),
}
}
@@ -349,6 +382,50 @@ async fn handle_server_request(
Err(e) => ManagementResponse::err(id, format!("Disconnect client failed: {}", e)),
}
}
"setClientRateLimit" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
let rate = match request.params.get("rateBytesPerSec").and_then(|v| v.as_u64()) {
Some(r) => r,
None => return ManagementResponse::err(id, "Missing rateBytesPerSec".to_string()),
};
let burst = match request.params.get("burstBytes").and_then(|v| v.as_u64()) {
Some(b) => b,
None => return ManagementResponse::err(id, "Missing burstBytes".to_string()),
};
match vpn_server.set_client_rate_limit(&client_id, rate, burst).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
}
}
"removeClientRateLimit" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
match vpn_server.remove_client_rate_limit(&client_id).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
}
}
"getClientTelemetry" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(cid) => cid.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
let clients = vpn_server.list_clients().await;
match clients.into_iter().find(|c| c.client_id == client_id) {
Some(info) => {
match serde_json::to_value(&info) {
Ok(v) => ManagementResponse::ok(id, v),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
}
}
None => ManagementResponse::err(id, format!("Client {} not found", client_id)),
}
}
"generateKeypair" => match crypto::generate_keypair() { "generateKeypair" => match crypto::generate_keypair() {
Ok((public_key, private_key)) => ManagementResponse::ok( Ok((public_key, private_key)) => ManagementResponse::ok(
id, id,

rust/src/mtu.rs Normal file (+314)

@@ -0,0 +1,314 @@
use std::net::Ipv4Addr;
/// Overhead breakdown for VPN tunnel encapsulation.
#[derive(Debug, Clone)]
pub struct TunnelOverhead {
/// Outer IP header: 20 bytes (IPv4, no options).
pub ip_header: u16,
/// TCP header: typically 32 bytes (20 base + 12 for timestamps).
pub tcp_header: u16,
/// WebSocket framing: ~6 bytes (2 base + 4 mask from client).
pub ws_framing: u16,
/// VPN binary frame header: 5 bytes [type:1B][length:4B].
pub vpn_header: u16,
/// Noise AEAD tag: 16 bytes (Poly1305).
pub noise_tag: u16,
}
impl TunnelOverhead {
/// Conservative default overhead estimate.
pub fn default_overhead() -> Self {
Self {
ip_header: 20,
tcp_header: 32,
ws_framing: 6,
vpn_header: 5,
noise_tag: 16,
}
}
/// Total encapsulation overhead in bytes.
pub fn total(&self) -> u16 {
self.ip_header + self.tcp_header + self.ws_framing + self.vpn_header + self.noise_tag
}
/// Compute effective TUN MTU given the underlying link MTU.
pub fn effective_tun_mtu(&self, link_mtu: u16) -> u16 {
link_mtu.saturating_sub(self.total())
}
}
/// MTU configuration for the VPN tunnel.
#[derive(Debug, Clone)]
pub struct MtuConfig {
/// Underlying link MTU (typically 1500 for Ethernet).
pub link_mtu: u16,
/// Computed effective TUN MTU.
pub effective_mtu: u16,
/// Whether to generate ICMP too-big for oversized packets.
pub send_icmp_too_big: bool,
/// Counter: oversized packets encountered.
pub oversized_packets: u64,
/// Counter: ICMP too-big messages generated.
pub icmp_too_big_sent: u64,
}
impl MtuConfig {
/// Create a new MTU config from the underlying link MTU.
pub fn new(link_mtu: u16) -> Self {
let overhead = TunnelOverhead::default_overhead();
let effective = overhead.effective_tun_mtu(link_mtu);
Self {
link_mtu,
effective_mtu: effective,
send_icmp_too_big: true,
oversized_packets: 0,
icmp_too_big_sent: 0,
}
}
/// Check if a packet exceeds the effective MTU.
pub fn is_oversized(&self, packet_len: usize) -> bool {
packet_len > self.effective_mtu as usize
}
}
/// Action to take after checking MTU.
pub enum MtuAction {
/// Packet is within MTU, forward normally.
Forward,
/// Packet is oversized; contains the ICMP too-big message to write back into TUN.
SendIcmpTooBig(Vec<u8>),
}
/// Check packet against MTU config and return the appropriate action.
pub fn check_mtu(packet: &[u8], config: &MtuConfig) -> MtuAction {
if !config.is_oversized(packet.len()) {
return MtuAction::Forward;
}
if !config.send_icmp_too_big {
return MtuAction::Forward;
}
match generate_icmp_too_big(packet, config.effective_mtu) {
Some(icmp) => MtuAction::SendIcmpTooBig(icmp),
None => MtuAction::Forward,
}
}
/// Generate an ICMPv4 Destination Unreachable / Fragmentation Needed message.
///
/// Per RFC 792: Type 3, Code 4, with next-hop MTU in bytes 6-7 (RFC 1191).
/// Returns the complete IP + ICMP packet to write back into the TUN device.
pub fn generate_icmp_too_big(original_packet: &[u8], next_hop_mtu: u16) -> Option<Vec<u8>> {
// Need at least 20 bytes of original IP header
if original_packet.len() < 20 {
return None;
}
// Verify it's IPv4
if original_packet[0] >> 4 != 4 {
return None;
}
// Parse source/dest from original IP header
let src_ip = Ipv4Addr::new(
original_packet[12],
original_packet[13],
original_packet[14],
original_packet[15],
);
let dst_ip = Ipv4Addr::new(
original_packet[16],
original_packet[17],
original_packet[18],
original_packet[19],
);
// ICMP payload: IP header + first 8 bytes of original datagram (per RFC 792)
let icmp_data_len = original_packet.len().min(28); // 20 IP header + 8 bytes
let icmp_payload = &original_packet[..icmp_data_len];
// Build ICMP message: type(1) + code(1) + checksum(2) + unused(2) + next_hop_mtu(2) + data
let mut icmp = Vec::with_capacity(8 + icmp_data_len);
icmp.push(3); // Type: Destination Unreachable
icmp.push(4); // Code: Fragmentation Needed and DF was Set
icmp.push(0); // Checksum placeholder
icmp.push(0);
icmp.push(0); // Unused
icmp.push(0);
icmp.extend_from_slice(&next_hop_mtu.to_be_bytes());
icmp.extend_from_slice(icmp_payload);
// Compute ICMP checksum
let cksum = internet_checksum(&icmp);
icmp[2] = (cksum >> 8) as u8;
icmp[3] = (cksum & 0xff) as u8;
// Build IP header (ICMP response: FROM tunnel gateway TO original source)
let total_len = (20 + icmp.len()) as u16;
let mut ip = Vec::with_capacity(total_len as usize);
ip.push(0x45); // Version 4, IHL 5
ip.push(0x00); // DSCP/ECN
ip.extend_from_slice(&total_len.to_be_bytes());
ip.extend_from_slice(&[0, 0]); // Identification
ip.extend_from_slice(&[0x40, 0x00]); // Flags: Don't Fragment, Fragment Offset: 0
ip.push(64); // TTL
ip.push(1); // Protocol: ICMP
ip.extend_from_slice(&[0, 0]); // Header checksum placeholder
ip.extend_from_slice(&dst_ip.octets()); // Source: tunnel endpoint (was dst)
ip.extend_from_slice(&src_ip.octets()); // Destination: original source
// Compute IP header checksum
let ip_cksum = internet_checksum(&ip[..20]);
ip[10] = (ip_cksum >> 8) as u8;
ip[11] = (ip_cksum & 0xff) as u8;
ip.extend_from_slice(&icmp);
Some(ip)
}
/// Standard Internet checksum (RFC 1071).
fn internet_checksum(data: &[u8]) -> u16 {
let mut sum: u32 = 0;
let mut i = 0;
while i + 1 < data.len() {
sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
i += 2;
}
if i < data.len() {
sum += (data[i] as u32) << 8;
}
while sum >> 16 != 0 {
sum = (sum & 0xFFFF) + (sum >> 16);
}
!sum as u16
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_overhead_total() {
let oh = TunnelOverhead::default_overhead();
assert_eq!(oh.total(), 79); // 20+32+6+5+16
}
#[test]
fn effective_mtu_for_ethernet() {
let oh = TunnelOverhead::default_overhead();
let mtu = oh.effective_tun_mtu(1500);
assert_eq!(mtu, 1421); // 1500 - 79
}
#[test]
fn effective_mtu_saturates_at_zero() {
let oh = TunnelOverhead::default_overhead();
let mtu = oh.effective_tun_mtu(50); // Less than overhead
assert_eq!(mtu, 0);
}
#[test]
fn mtu_config_default() {
let config = MtuConfig::new(1500);
assert_eq!(config.effective_mtu, 1421);
assert_eq!(config.link_mtu, 1500);
assert!(config.send_icmp_too_big);
}
#[test]
fn is_oversized() {
let config = MtuConfig::new(1500);
assert!(!config.is_oversized(1421));
assert!(config.is_oversized(1422));
}
#[test]
fn icmp_too_big_generation() {
// Craft a minimal IPv4 packet
let mut original = vec![0u8; 28];
original[0] = 0x45; // version 4, IHL 5
original[2..4].copy_from_slice(&1500u16.to_be_bytes()); // total length
original[9] = 6; // TCP
original[12..16].copy_from_slice(&[10, 0, 0, 1]); // src IP
original[16..20].copy_from_slice(&[10, 0, 0, 2]); // dst IP
let icmp_pkt = generate_icmp_too_big(&original, 1421).unwrap();
// Verify it's a valid IPv4 packet
assert_eq!(icmp_pkt[0] >> 4, 4); // IPv4
assert_eq!(icmp_pkt[9], 1); // ICMP protocol
// Source should be original dst (10.0.0.2)
assert_eq!(&icmp_pkt[12..16], &[10, 0, 0, 2]);
// Destination should be original src (10.0.0.1)
assert_eq!(&icmp_pkt[16..20], &[10, 0, 0, 1]);
// ICMP type 3, code 4
assert_eq!(icmp_pkt[20], 3);
assert_eq!(icmp_pkt[21], 4);
// Next-hop MTU at ICMP bytes 6-7 (offset 26-27 in IP packet)
let mtu = u16::from_be_bytes([icmp_pkt[26], icmp_pkt[27]]);
assert_eq!(mtu, 1421);
}
#[test]
fn icmp_too_big_rejects_short_packet() {
let short = vec![0u8; 10];
assert!(generate_icmp_too_big(&short, 1421).is_none());
}
#[test]
fn icmp_too_big_rejects_non_ipv4() {
let mut pkt = vec![0u8; 40];
pkt[0] = 0x60; // IPv6
assert!(generate_icmp_too_big(&pkt, 1421).is_none());
}
#[test]
fn icmp_checksum_valid() {
let mut original = vec![0u8; 28];
original[0] = 0x45;
original[2..4].copy_from_slice(&1500u16.to_be_bytes());
original[9] = 6;
original[12..16].copy_from_slice(&[192, 168, 1, 100]);
original[16..20].copy_from_slice(&[10, 8, 0, 1]);
let icmp_pkt = generate_icmp_too_big(&original, 1420).unwrap();
// Verify IP header checksum
let ip_cksum = internet_checksum(&icmp_pkt[..20]);
assert_eq!(ip_cksum, 0, "IP header checksum should verify to 0");
// Verify ICMP checksum
let icmp_cksum = internet_checksum(&icmp_pkt[20..]);
assert_eq!(icmp_cksum, 0, "ICMP checksum should verify to 0");
}
#[test]
fn check_mtu_forward() {
let config = MtuConfig::new(1500);
let pkt = vec![0u8; 1421]; // Exactly at MTU
assert!(matches!(check_mtu(&pkt, &config), MtuAction::Forward));
}
#[test]
fn check_mtu_oversized_generates_icmp() {
let config = MtuConfig::new(1500);
let mut pkt = vec![0u8; 1500];
pkt[0] = 0x45; // Valid IPv4
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
match check_mtu(&pkt, &config) {
MtuAction::SendIcmpTooBig(icmp) => {
assert_eq!(icmp[20], 3); // ICMP type
assert_eq!(icmp[21], 4); // ICMP code
}
MtuAction::Forward => panic!("Expected SendIcmpTooBig"),
}
}
}
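The overhead budget above can be restated as plain arithmetic, which makes the 1421-byte effective MTU easy to sanity-check by hand:

```rust
fn main() {
    // Per-layer overhead from TunnelOverhead::default_overhead():
    let ip_header = 20u16;  // outer IPv4, no options
    let tcp_header = 32u16; // 20 base + 12 for timestamps
    let ws_framing = 6u16;  // 2 base + 4 client mask
    let vpn_header = 5u16;  // [type:1B][length:4B]
    let noise_tag = 16u16;  // Poly1305 AEAD tag
    let total = ip_header + tcp_header + ws_framing + vpn_header + noise_tag;
    assert_eq!(total, 79);

    // Effective TUN MTU on a standard 1500-byte Ethernet link.
    assert_eq!(1500u16.saturating_sub(total), 1421);

    // A link smaller than the overhead saturates to zero, not underflow.
    assert_eq!(50u16.saturating_sub(total), 0);
}
```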

rust/src/qos.rs Normal file (+490)

@@ -0,0 +1,490 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
/// Priority levels for IP packets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Priority {
High = 0,
Normal = 1,
Low = 2,
}
/// QoS statistics per priority level.
pub struct QosStats {
pub high_enqueued: AtomicU64,
pub normal_enqueued: AtomicU64,
pub low_enqueued: AtomicU64,
pub high_dropped: AtomicU64,
pub normal_dropped: AtomicU64,
pub low_dropped: AtomicU64,
}
impl QosStats {
pub fn new() -> Self {
Self {
high_enqueued: AtomicU64::new(0),
normal_enqueued: AtomicU64::new(0),
low_enqueued: AtomicU64::new(0),
high_dropped: AtomicU64::new(0),
normal_dropped: AtomicU64::new(0),
low_dropped: AtomicU64::new(0),
}
}
}
impl Default for QosStats {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// Packet classification
// ============================================================================
/// 5-tuple flow key for tracking bulk flows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct FlowKey {
src_ip: u32,
dst_ip: u32,
src_port: u16,
dst_port: u16,
protocol: u8,
}
/// Per-flow state for bulk detection.
struct FlowState {
bytes_total: u64,
window_start: Instant,
}
/// Tracks per-flow byte counts for bulk flow detection.
struct FlowTracker {
flows: HashMap<FlowKey, FlowState>,
window_duration: Duration,
max_flows: usize,
}
impl FlowTracker {
fn new(window_duration: Duration, max_flows: usize) -> Self {
Self {
flows: HashMap::new(),
window_duration,
max_flows,
}
}
/// Record bytes for a flow. Returns true if the flow exceeds the threshold.
fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
let now = Instant::now();
// Evict if at capacity — remove oldest entry
if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
if let Some(oldest_key) = self
.flows
.iter()
.min_by_key(|(_, v)| v.window_start)
.map(|(k, _)| *k)
{
self.flows.remove(&oldest_key);
}
}
let state = self.flows.entry(key).or_insert(FlowState {
bytes_total: 0,
window_start: now,
});
// Reset window if expired
if now.duration_since(state.window_start) > self.window_duration {
state.bytes_total = 0;
state.window_start = now;
}
state.bytes_total += bytes;
state.bytes_total > threshold
}
}
/// Classifies raw IP packets into priority levels.
pub struct PacketClassifier {
flow_tracker: FlowTracker,
/// Byte threshold for classifying a flow as "bulk" (Low priority).
bulk_threshold_bytes: u64,
}
impl PacketClassifier {
/// Create a new classifier.
///
/// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB)
pub fn new(bulk_threshold_bytes: u64) -> Self {
Self {
flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000),
bulk_threshold_bytes,
}
}
/// Classify a raw IPv4 packet into a priority level.
///
/// The packet must start with the IPv4 header (as read from a TUN device).
pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
// Need at least 20 bytes for a minimal IPv4 header
if ip_packet.len() < 20 {
return Priority::Normal;
}
let version = ip_packet[0] >> 4;
if version != 4 {
return Priority::Normal; // Only classify IPv4 for now
}
let ihl = (ip_packet[0] & 0x0F) as usize;
let header_len = ihl * 4;
let protocol = ip_packet[9];
let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize;
// ICMP is always high priority
if protocol == 1 {
return Priority::High;
}
// Small packets (<128 bytes) are high priority (likely interactive)
if total_len < 128 {
return Priority::High;
}
// Extract ports for TCP (6) and UDP (17)
let (src_port, dst_port) = if (protocol == 6 || protocol == 17)
&& ip_packet.len() >= header_len + 4
{
let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]);
let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]);
(sp, dp)
} else {
(0, 0)
};
// DNS (port 53) and SSH (port 22) are high priority
if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 {
return Priority::High;
}
// Check for bulk flow
if protocol == 6 || protocol == 17 {
let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]);
let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]);
let key = FlowKey {
src_ip,
dst_ip,
src_port,
dst_port,
protocol,
};
if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) {
return Priority::Low;
}
}
Priority::Normal
}
}
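The header-based rules in `classify()` (ICMP, small packets, and DNS/SSH ports go High) can be exercised with a hand-built packet. The sketch below restates those rules in a standalone function; it omits the bulk-flow tracking and is illustrative, not the crate's API:

```rust
// Simplified restatement of the port/size rules in PacketClassifier::classify.
// `quick_priority` is a hypothetical helper, not part of the module.
fn quick_priority(ip: &[u8]) -> &'static str {
    if ip.len() < 20 || ip[0] >> 4 != 4 {
        return "normal"; // too short, or not IPv4
    }
    let ihl = (ip[0] & 0x0F) as usize * 4;
    let protocol = ip[9];
    let total_len = u16::from_be_bytes([ip[2], ip[3]]) as usize;
    if protocol == 1 || total_len < 128 {
        return "high"; // ICMP, or likely-interactive small packet
    }
    if (protocol == 6 || protocol == 17) && ip.len() >= ihl + 4 {
        let sp = u16::from_be_bytes([ip[ihl], ip[ihl + 1]]);
        let dp = u16::from_be_bytes([ip[ihl + 2], ip[ihl + 3]]);
        if sp == 53 || dp == 53 || sp == 22 || dp == 22 {
            return "high"; // DNS or SSH
        }
    }
    "normal"
}

fn main() {
    // Minimal IPv4/UDP packet to port 53 with total length 512.
    let mut pkt = vec![0u8; 28];
    pkt[0] = 0x45; // version 4, IHL 5
    pkt[2..4].copy_from_slice(&512u16.to_be_bytes());
    pkt[9] = 17; // UDP
    pkt[22..24].copy_from_slice(&53u16.to_be_bytes()); // dst port 53
    assert_eq!(quick_priority(&pkt), "high");

    // Same packet aimed at port 8080 no longer matches a High rule.
    pkt[22..24].copy_from_slice(&8080u16.to_be_bytes());
    assert_eq!(quick_priority(&pkt), "normal");
}
```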
// ============================================================================
// Priority channel set
// ============================================================================
/// Error returned when a packet is dropped.
#[derive(Debug)]
pub enum PacketDropped {
LowPriorityDrop,
NormalPriorityDrop,
HighPriorityDrop,
ChannelClosed,
}
/// Sending half of the priority channel set.
pub struct PrioritySender {
high_tx: mpsc::Sender<Vec<u8>>,
normal_tx: mpsc::Sender<Vec<u8>>,
low_tx: mpsc::Sender<Vec<u8>>,
stats: Arc<QosStats>,
}
impl PrioritySender {
/// Send a packet with the given priority. Implements smart dropping under backpressure.
pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
let (tx, enqueued_counter) = match priority {
Priority::High => (&self.high_tx, &self.stats.high_enqueued),
Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
};
match tx.try_send(packet) {
Ok(()) => {
enqueued_counter.fetch_add(1, Ordering::Relaxed);
Ok(())
}
Err(mpsc::error::TrySendError::Full(packet)) => {
self.handle_backpressure(packet, priority).await
}
Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
}
}
async fn handle_backpressure(
&self,
packet: Vec<u8>,
priority: Priority,
) -> Result<(), PacketDropped> {
match priority {
Priority::Low => {
self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::LowPriorityDrop)
}
Priority::Normal => {
self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::NormalPriorityDrop)
}
Priority::High => {
// Last resort: briefly wait for space, then drop
match tokio::time::timeout(
Duration::from_millis(5),
self.high_tx.send(packet),
)
.await
{
Ok(Ok(())) => {
self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
Ok(())
}
_ => {
self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::HighPriorityDrop)
}
}
}
}
}
}
/// Receiving half of the priority channel set.
pub struct PriorityReceiver {
high_rx: mpsc::Receiver<Vec<u8>>,
normal_rx: mpsc::Receiver<Vec<u8>>,
low_rx: mpsc::Receiver<Vec<u8>>,
}
impl PriorityReceiver {
/// Receive the next packet, draining high-priority first (biased select).
pub async fn recv(&mut self) -> Option<Vec<u8>> {
tokio::select! {
biased;
Some(pkt) = self.high_rx.recv() => Some(pkt),
Some(pkt) = self.normal_rx.recv() => Some(pkt),
Some(pkt) = self.low_rx.recv() => Some(pkt),
else => None,
}
}
}
/// Create a priority channel set split into sender and receiver halves.
///
/// - `high_cap`: capacity of the high-priority channel
/// - `normal_cap`: capacity of the normal-priority channel
/// - `low_cap`: capacity of the low-priority channel
pub fn create_priority_channels(
high_cap: usize,
normal_cap: usize,
low_cap: usize,
) -> (PrioritySender, PriorityReceiver) {
let (high_tx, high_rx) = mpsc::channel(high_cap);
let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
let (low_tx, low_rx) = mpsc::channel(low_cap);
let stats = Arc::new(QosStats::new());
let sender = PrioritySender {
high_tx,
normal_tx,
low_tx,
stats,
};
let receiver = PriorityReceiver {
high_rx,
normal_rx,
low_rx,
};
(sender, receiver)
}
impl PrioritySender {
/// Get a reference to the QoS stats from this sender.
pub fn stats(&self) -> &Arc<QosStats> {
&self.stats
}
}
#[cfg(test)]
mod tests {
use super::*;
// Helper: craft a minimal IPv4 packet
fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
let mut pkt = vec![0u8; total_len.max(24) as usize];
pkt[0] = 0x45; // version 4, IHL 5
pkt[2..4].copy_from_slice(&total_len.to_be_bytes());
pkt[9] = protocol;
// src IP
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
// dst IP
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
// ports (at offset 20 for IHL=5)
pkt[20..22].copy_from_slice(&src_port.to_be_bytes());
pkt[22..24].copy_from_slice(&dst_port.to_be_bytes());
pkt
}
#[test]
fn classify_icmp_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_dns_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_ssh_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_small_packet_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_normal_http() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[test]
fn classify_bulk_flow_as_low() {
let mut c = PacketClassifier::new(10_000); // Low threshold for testing
// Send enough traffic to exceed the threshold
for _ in 0..100 {
let pkt = make_ipv4_packet(6, 12345, 80, 500);
c.classify(&pkt);
}
// Next packet from same flow should be Low
let pkt = make_ipv4_packet(6, 12345, 80, 500);
assert_eq!(c.classify(&pkt), Priority::Low);
}
#[test]
fn classify_too_short_packet() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = vec![0u8; 10]; // Too short for IPv4 header
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[test]
fn classify_non_ipv4() {
let mut c = PacketClassifier::new(1_000_000);
let mut pkt = vec![0u8; 40];
pkt[0] = 0x60; // IPv6 version nibble
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[tokio::test]
async fn priority_receiver_drains_high_first() {
let (sender, mut receiver) = create_priority_channels(8, 8, 8);
// Enqueue in reverse order
sender.send(vec![3], Priority::Low).await.unwrap();
sender.send(vec![2], Priority::Normal).await.unwrap();
sender.send(vec![1], Priority::High).await.unwrap();
// Should drain High first
assert_eq!(receiver.recv().await.unwrap(), vec![1]);
assert_eq!(receiver.recv().await.unwrap(), vec![2]);
assert_eq!(receiver.recv().await.unwrap(), vec![3]);
}
#[tokio::test]
async fn smart_dropping_low_priority() {
let (sender, _receiver) = create_priority_channels(8, 8, 1);
// Fill the low channel
sender.send(vec![0], Priority::Low).await.unwrap();
// Next low-priority send should be dropped
let result = sender.send(vec![1], Priority::Low).await;
assert!(matches!(result, Err(PacketDropped::LowPriorityDrop)));
assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn smart_dropping_normal_priority() {
let (sender, _receiver) = create_priority_channels(8, 1, 8);
sender.send(vec![0], Priority::Normal).await.unwrap();
let result = sender.send(vec![1], Priority::Normal).await;
assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop)));
assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn stats_track_enqueued() {
let (sender, _receiver) = create_priority_channels(8, 8, 8);
sender.send(vec![1], Priority::High).await.unwrap();
sender.send(vec![2], Priority::High).await.unwrap();
sender.send(vec![3], Priority::Normal).await.unwrap();
sender.send(vec![4], Priority::Low).await.unwrap();
assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2);
assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1);
assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1);
}
#[test]
fn flow_tracker_evicts_at_capacity() {
let mut ft = FlowTracker::new(Duration::from_secs(60), 2);
let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 };
let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 };
let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 };
ft.record(k1, 100, 1000);
ft.record(k2, 100, 1000);
// Should evict k1 (oldest)
ft.record(k3, 100, 1000);
assert_eq!(ft.flows.len(), 2);
assert!(!ft.flows.contains_key(&k1));
}
}
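The precedence encoded in `PacketClassifier::classify` above (ICMP first, then small packets, then DNS/SSH ports, with bulk demotion last) can be sketched as a standalone pure function. This is an illustrative reduction only, not the crate's API; the flow-tracker-based Low demotion is omitted for brevity:

```rust
#[derive(Debug, PartialEq)]
enum Priority {
    High,
    Normal,
}

// Mirror of the classifier's High/Normal precedence. Bulk (Low) demotion
// requires per-flow byte accounting and is left out of this sketch.
fn classify(protocol: u8, total_len: usize, src_port: u16, dst_port: u16) -> Priority {
    if protocol == 1 {
        return Priority::High; // ICMP is latency-sensitive
    }
    if total_len < 128 {
        return Priority::High; // small packets are likely interactive
    }
    if matches!(src_port, 22 | 53) || matches!(dst_port, 22 | 53) {
        return Priority::High; // SSH and DNS
    }
    Priority::Normal
}

fn main() {
    assert_eq!(classify(1, 64, 0, 0), Priority::High); // ICMP
    assert_eq!(classify(17, 200, 12345, 53), Priority::High); // DNS query
    assert_eq!(classify(6, 500, 12345, 80), Priority::Normal); // large HTTP
}
```

Note the ordering matters: a 64-byte TCP packet to port 80 is still High because the size check runs before any port inspection, matching the `classify_small_packet_as_high` test.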

rust/src/ratelimit.rs (new file, 139 lines)
use std::time::Instant;
/// A token bucket rate limiter operating on byte granularity.
pub struct TokenBucket {
/// Tokens (bytes) added per second.
rate: f64,
/// Maximum burst capacity in bytes.
burst: f64,
/// Currently available tokens.
tokens: f64,
/// Last time tokens were refilled.
last_refill: Instant,
}
impl TokenBucket {
/// Create a new token bucket.
///
/// - `rate_bytes_per_sec`: sustained rate in bytes/second
/// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
let burst = burst_bytes as f64;
Self {
rate: rate_bytes_per_sec as f64,
burst,
tokens: burst, // start full
last_refill: Instant::now(),
}
}
/// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded.
pub fn try_consume(&mut self, bytes: usize) -> bool {
self.refill();
let needed = bytes as f64;
if needed <= self.tokens {
self.tokens -= needed;
true
} else {
false
}
}
/// Update rate and burst limits dynamically (for live IPC reconfiguration).
pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
self.rate = rate_bytes_per_sec as f64;
self.burst = burst_bytes as f64;
// Cap current tokens at new burst
if self.tokens > self.burst {
self.tokens = self.burst;
}
}
fn refill(&mut self) {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
self.last_refill = now;
self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[test]
fn allows_traffic_under_burst() {
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
// Should allow up to burst size immediately
assert!(tb.try_consume(1_500_000));
assert!(tb.try_consume(400_000));
}
#[test]
fn blocks_traffic_over_burst() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
// Consume entire burst
assert!(tb.try_consume(1_000_000));
// Next consume should fail (no time to refill)
assert!(!tb.try_consume(1));
}
#[test]
fn zero_consume_always_succeeds() {
let mut tb = TokenBucket::new(0, 0);
assert!(tb.try_consume(0));
}
#[test]
fn refills_over_time() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
// Drain completely
assert!(tb.try_consume(1_000_000));
assert!(!tb.try_consume(1));
// Wait 100ms — should refill ~100KB
std::thread::sleep(Duration::from_millis(100));
assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s
}
#[test]
fn update_limits_caps_tokens() {
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
// Tokens start at burst (2MB)
tb.update_limits(500_000, 500_000);
// Tokens should be capped to new burst (500KB)
assert!(tb.try_consume(500_000));
assert!(!tb.try_consume(1));
}
#[test]
fn update_limits_changes_rate() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
assert!(tb.try_consume(1_000_000)); // drain
// Change to higher rate
tb.update_limits(10_000_000, 10_000_000);
std::thread::sleep(Duration::from_millis(50));
// At 10MB/s, 50ms should refill ~500KB
assert!(tb.try_consume(200_000));
}
#[test]
fn zero_rate_blocks_after_burst() {
let mut tb = TokenBucket::new(0, 100);
assert!(tb.try_consume(100));
std::thread::sleep(Duration::from_millis(10));
// Zero rate means no refill
assert!(!tb.try_consume(1));
}
#[test]
fn tokens_do_not_exceed_burst() {
let mut tb = TokenBucket::new(1_000_000, 1_000);
// Wait to accumulate — but should cap at burst
std::thread::sleep(Duration::from_millis(50));
assert!(tb.try_consume(1_000));
assert!(!tb.try_consume(1));
}
}
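A useful back-of-envelope check on the refill math in `TokenBucket::refill`: the wait before `bytes` can be admitted is `(bytes - tokens) / rate`. The helper below is hypothetical, not part of the crate, but its arithmetic matches the bucket's behavior (including the zero-rate case exercised by `zero_rate_blocks_after_burst`):

```rust
// Given a bucket's current state, how long until `bytes` tokens are
// available? Returns 0.0 if they fit now, infinity if the rate is zero.
fn time_until_admitted(tokens: f64, rate_bytes_per_sec: f64, bytes: f64) -> f64 {
    if bytes <= tokens {
        0.0
    } else if rate_bytes_per_sec <= 0.0 {
        f64::INFINITY // zero rate never refills
    } else {
        (bytes - tokens) / rate_bytes_per_sec
    }
}

fn main() {
    // Fully drained bucket at 1 MB/s: a 50 KB packet waits ~50 ms,
    // consistent with the refills_over_time test above.
    let wait = time_until_admitted(0.0, 1_000_000.0, 50_000.0);
    assert!((wait - 0.05).abs() < 1e-9);
    assert!(time_until_admitted(0.0, 0.0, 1.0).is_infinite());
}
```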


@@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_tungstenite::tungstenite::Message;
@@ -12,9 +13,14 @@ use tracing::{info, error, warn};
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::mtu::{MtuConfig, TunnelOverhead};
use crate::network::IpPool;
use crate::ratelimit::TokenBucket;
use crate::transport;
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
/// Server configuration (matches TS IVpnServerConfig).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -29,6 +35,10 @@ pub struct ServerConfig {
pub mtu: Option<u16>,
pub keepalive_interval_secs: Option<u64>,
pub enable_nat: Option<bool>,
/// Default rate limit for new clients (bytes/sec). None = unlimited.
pub default_rate_limit_bytes_per_sec: Option<u64>,
/// Default burst size for new clients (bytes). None = unlimited.
pub default_burst_bytes: Option<u64>,
}
/// Information about a connected client.
@@ -40,6 +50,12 @@ pub struct ClientInfo {
pub connected_since: String,
pub bytes_sent: u64,
pub bytes_received: u64,
pub packets_dropped: u64,
pub bytes_dropped: u64,
pub last_keepalive_at: Option<String>,
pub keepalives_received: u64,
pub rate_limit_bytes_per_sec: Option<u64>,
pub burst_bytes: Option<u64>,
}
/// Server statistics.
@@ -63,6 +79,8 @@ pub struct ServerState {
pub ip_pool: Mutex<IpPool>,
pub clients: RwLock<HashMap<String, ClientInfo>>,
pub stats: RwLock<ServerStatistics>,
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
pub mtu_config: MtuConfig,
pub started_at: std::time::Instant,
}
@@ -98,11 +116,18 @@ impl VpnServer {
}
}
let link_mtu = config.mtu.unwrap_or(1420);
// Compute effective MTU from overhead
let overhead = TunnelOverhead::default_overhead();
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
let state = Arc::new(ServerState {
config: config.clone(),
ip_pool: Mutex::new(ip_pool),
clients: RwLock::new(HashMap::new()),
stats: RwLock::new(ServerStatistics::default()),
rate_limiters: Mutex::new(HashMap::new()),
mtu_config,
started_at: std::time::Instant::now(),
});
@@ -166,11 +191,52 @@ impl VpnServer {
if let Some(client) = clients.remove(client_id) {
let ip: Ipv4Addr = client.assigned_ip.parse()?;
state.ip_pool.lock().await.release(&ip);
state.rate_limiters.lock().await.remove(client_id);
info!("Client {} disconnected", client_id);
}
}
Ok(())
}
/// Set a rate limit for a specific client.
pub async fn set_client_rate_limit(
&self,
client_id: &str,
rate_bytes_per_sec: u64,
burst_bytes: u64,
) -> Result<()> {
if let Some(ref state) = self.state {
let mut limiters = state.rate_limiters.lock().await;
if let Some(limiter) = limiters.get_mut(client_id) {
limiter.update_limits(rate_bytes_per_sec, burst_bytes);
} else {
limiters.insert(
client_id.to_string(),
TokenBucket::new(rate_bytes_per_sec, burst_bytes),
);
}
// Update client info
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(client_id) {
info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec);
info.burst_bytes = Some(burst_bytes);
}
}
Ok(())
}
/// Remove rate limit for a specific client (unlimited).
pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> {
if let Some(ref state) = self.state {
state.rate_limiters.lock().await.remove(client_id);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(client_id) {
info.rate_limit_bytes_per_sec = None;
info.burst_bytes = None;
}
}
Ok(())
}
}
async fn run_listener(
@@ -257,25 +323,43 @@ async fn handle_client_connection(
let mut noise_transport = responder.into_transport_mode()?;
// Register client
let default_rate = state.config.default_rate_limit_bytes_per_sec;
let default_burst = state.config.default_burst_bytes;
let client_info = ClientInfo {
client_id: client_id.clone(),
assigned_ip: assigned_ip.to_string(),
connected_since: timestamp_now(),
bytes_sent: 0,
bytes_received: 0,
packets_dropped: 0,
bytes_dropped: 0,
last_keepalive_at: None,
keepalives_received: 0,
rate_limit_bytes_per_sec: default_rate,
burst_bytes: default_burst,
};
state.clients.write().await.insert(client_id.clone(), client_info);
// Set up rate limiter if defaults are configured
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
state
.rate_limiters
.lock()
.await
.insert(client_id.clone(), TokenBucket::new(rate, burst));
}
{
let mut stats = state.stats.write().await;
stats.total_connections += 1;
}
// Send assigned IP info (encrypted), include effective MTU
let ip_info = serde_json::json!({
"assignedIp": assigned_ip.to_string(),
"gateway": state.ip_pool.lock().await.gateway_addr().to_string(),
"mtu": state.config.mtu.unwrap_or(1420),
"effectiveMtu": state.mtu_config.effective_mtu,
});
let ip_info_bytes = serde_json::to_vec(&ip_info)?;
let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?;
@@ -289,19 +373,50 @@ async fn handle_client_connection(
info!("Client {} connected with IP {}", client_id, assigned_ip);
// Main packet loop with dead-peer detection
let mut last_activity = tokio::time::Instant::now();
loop {
tokio::select! {
msg = ws_stream.next() => {
match msg {
Some(Ok(Message::Binary(data))) => {
last_activity = tokio::time::Instant::now();
let mut frame_buf = BytesMut::from(&data[..][..]);
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
Ok(Some(frame)) => match frame.packet_type {
PacketType::IpPacket => {
match noise_transport.read_message(&frame.payload, &mut buf) {
Ok(len) => {
// Rate limiting check
let allowed = {
let mut limiters = state.rate_limiters.lock().await;
if let Some(limiter) = limiters.get_mut(&client_id) {
limiter.try_consume(len)
} else {
true
}
};
if !allowed {
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.packets_dropped += 1;
info.bytes_dropped += len as u64;
}
continue;
}
let mut stats = state.stats.write().await;
stats.bytes_received += len as u64;
stats.packets_received += 1;
// Update per-client stats
drop(stats);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.bytes_received += len as u64;
}
}
Err(e) => {
warn!("Decrypt error from {}: {}", client_id, e);
@@ -310,9 +425,10 @@ async fn handle_client_connection(
}
}
PacketType::Keepalive => {
// Echo the keepalive payload back in the ACK
let ack_frame = Frame {
packet_type: PacketType::KeepaliveAck,
payload: frame.payload.clone(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
@@ -321,6 +437,14 @@ async fn handle_client_connection(
let mut stats = state.stats.write().await;
stats.keepalives_received += 1;
stats.keepalives_sent += 1;
// Update per-client keepalive tracking
drop(stats);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.last_keepalive_at = Some(timestamp_now());
info.keepalives_received += 1;
}
}
PacketType::Disconnect => {
info!("Client {} sent disconnect", client_id);
@@ -344,19 +468,30 @@ async fn handle_client_connection(
break;
}
Some(Ok(Message::Ping(data))) => {
last_activity = tokio::time::Instant::now();
ws_sink.send(Message::Pong(data)).await?;
}
Some(Ok(_)) => {
last_activity = tokio::time::Instant::now();
continue;
}
Some(Err(e)) => {
warn!("WebSocket error from {}: {}", client_id, e);
break;
}
}
}
_ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => {
warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs());
break;
}
}
}
// Cleanup
state.clients.write().await.remove(&client_id);
state.ip_pool.lock().await.release(&assigned_ip);
state.rate_limiters.lock().await.remove(&client_id);
info!("Client {} disconnected, released IP {}", client_id, assigned_ip);
Ok(())

rust/src/telemetry.rs (new file, 317 lines)
use serde::Serialize;
use std::collections::VecDeque;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// A single RTT sample.
#[derive(Debug, Clone)]
struct RttSample {
_rtt: Duration,
_timestamp: Instant,
was_timeout: bool,
}
/// Snapshot of connection quality metrics.
#[derive(Debug, Clone, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ConnectionQuality {
/// Smoothed RTT in milliseconds (EMA, RFC 6298 style).
pub srtt_ms: f64,
/// Jitter in milliseconds (mean deviation of RTT).
pub jitter_ms: f64,
/// Minimum RTT observed in the sample window.
pub min_rtt_ms: f64,
/// Maximum RTT observed in the sample window.
pub max_rtt_ms: f64,
/// Packet loss ratio over the sample window (0.0 - 1.0).
pub loss_ratio: f64,
/// Number of consecutive keepalive timeouts (0 if last succeeded).
pub consecutive_timeouts: u32,
/// Total keepalives sent.
pub keepalives_sent: u64,
/// Total keepalive ACKs received.
pub keepalives_acked: u64,
}
/// Tracks connection quality from keepalive round-trips.
pub struct RttTracker {
/// Maximum number of samples to keep in the window.
max_samples: usize,
/// Recent RTT samples (including timeout markers).
samples: VecDeque<RttSample>,
/// When the last keepalive was sent (for computing RTT on ACK).
pending_ping_sent_at: Option<Instant>,
/// Number of consecutive keepalive timeouts.
pub consecutive_timeouts: u32,
/// Smoothed RTT (EMA).
srtt: Option<f64>,
/// Jitter (mean deviation).
jitter: f64,
/// Minimum RTT observed.
min_rtt: f64,
/// Maximum RTT observed.
max_rtt: f64,
/// Total keepalives sent.
keepalives_sent: u64,
/// Total keepalive ACKs received.
keepalives_acked: u64,
/// Previous RTT sample for jitter calculation.
last_rtt_ms: Option<f64>,
}
impl RttTracker {
/// Create a new tracker with the given window size.
pub fn new(max_samples: usize) -> Self {
Self {
max_samples,
samples: VecDeque::with_capacity(max_samples),
pending_ping_sent_at: None,
consecutive_timeouts: 0,
srtt: None,
jitter: 0.0,
min_rtt: f64::MAX,
max_rtt: 0.0,
keepalives_sent: 0,
keepalives_acked: 0,
last_rtt_ms: None,
}
}
/// Record that a keepalive was sent.
/// Returns a millisecond timestamp (since UNIX epoch) to embed in the keepalive payload.
pub fn mark_ping_sent(&mut self) -> u64 {
self.pending_ping_sent_at = Some(Instant::now());
self.keepalives_sent += 1;
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
/// Record that a keepalive ACK was received with the echoed timestamp.
/// Returns the computed RTT if a pending ping was recorded.
pub fn record_ack(&mut self, _echoed_timestamp_ms: u64) -> Option<Duration> {
let sent_at = self.pending_ping_sent_at.take()?;
let rtt = sent_at.elapsed();
let rtt_ms = rtt.as_secs_f64() * 1000.0;
self.keepalives_acked += 1;
self.consecutive_timeouts = 0;
// Update SRTT (RFC 6298: alpha = 1/8)
match self.srtt {
None => {
self.srtt = Some(rtt_ms);
self.jitter = rtt_ms / 2.0;
}
Some(prev_srtt) => {
// RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| (beta = 1/4)
self.jitter = 0.75 * self.jitter + 0.25 * (prev_srtt - rtt_ms).abs();
// SRTT = (1 - alpha) * SRTT + alpha * R (alpha = 1/8)
self.srtt = Some(0.875 * prev_srtt + 0.125 * rtt_ms);
}
}
// Update min/max
if rtt_ms < self.min_rtt {
self.min_rtt = rtt_ms;
}
if rtt_ms > self.max_rtt {
self.max_rtt = rtt_ms;
}
self.last_rtt_ms = Some(rtt_ms);
// Push sample into window
if self.samples.len() >= self.max_samples {
self.samples.pop_front();
}
self.samples.push_back(RttSample {
_rtt: rtt,
_timestamp: Instant::now(),
was_timeout: false,
});
Some(rtt)
}
/// Record that a keepalive timed out (no ACK received).
pub fn record_timeout(&mut self) {
self.consecutive_timeouts += 1;
self.pending_ping_sent_at = None;
if self.samples.len() >= self.max_samples {
self.samples.pop_front();
}
self.samples.push_back(RttSample {
_rtt: Duration::ZERO,
_timestamp: Instant::now(),
was_timeout: true,
});
}
/// Get a snapshot of the current connection quality.
pub fn snapshot(&self) -> ConnectionQuality {
let loss_ratio = if self.samples.is_empty() {
0.0
} else {
let timeouts = self.samples.iter().filter(|s| s.was_timeout).count();
timeouts as f64 / self.samples.len() as f64
};
ConnectionQuality {
srtt_ms: self.srtt.unwrap_or(0.0),
jitter_ms: self.jitter,
min_rtt_ms: if self.min_rtt == f64::MAX { 0.0 } else { self.min_rtt },
max_rtt_ms: self.max_rtt,
loss_ratio,
consecutive_timeouts: self.consecutive_timeouts,
keepalives_sent: self.keepalives_sent,
keepalives_acked: self.keepalives_acked,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_tracker_has_zero_quality() {
let tracker = RttTracker::new(30);
let q = tracker.snapshot();
assert_eq!(q.srtt_ms, 0.0);
assert_eq!(q.jitter_ms, 0.0);
assert_eq!(q.loss_ratio, 0.0);
assert_eq!(q.consecutive_timeouts, 0);
assert_eq!(q.keepalives_sent, 0);
assert_eq!(q.keepalives_acked, 0);
}
#[test]
fn mark_ping_returns_timestamp() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
// Should be a reasonable epoch-ms value (after 2020)
assert!(ts > 1_577_836_800_000);
assert_eq!(tracker.keepalives_sent, 1);
}
#[test]
fn record_ack_computes_rtt() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(5));
let rtt = tracker.record_ack(ts);
assert!(rtt.is_some());
let rtt = rtt.unwrap();
assert!(rtt.as_millis() >= 4); // at least ~5ms minus scheduling jitter
assert_eq!(tracker.keepalives_acked, 1);
assert_eq!(tracker.consecutive_timeouts, 0);
}
#[test]
fn record_ack_without_pending_returns_none() {
let mut tracker = RttTracker::new(30);
assert!(tracker.record_ack(12345).is_none());
}
#[test]
fn srtt_converges() {
let mut tracker = RttTracker::new(30);
// Simulate several ping/ack cycles with ~10ms RTT
for _ in 0..10 {
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(10));
tracker.record_ack(ts);
}
let q = tracker.snapshot();
// SRTT should be roughly 10ms (allowing for scheduling variance)
assert!(q.srtt_ms > 5.0, "SRTT too low: {}", q.srtt_ms);
assert!(q.srtt_ms < 50.0, "SRTT too high: {}", q.srtt_ms);
}
#[test]
fn timeout_increments_counter_and_loss() {
let mut tracker = RttTracker::new(30);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 1);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 2);
let q = tracker.snapshot();
assert_eq!(q.loss_ratio, 1.0); // 2 timeouts out of 2 samples
}
#[test]
fn ack_resets_consecutive_timeouts() {
let mut tracker = RttTracker::new(30);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 1);
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
assert_eq!(tracker.consecutive_timeouts, 0);
}
#[test]
fn loss_ratio_over_mixed_window() {
let mut tracker = RttTracker::new(30);
// 3 successful, 1 timeout, 1 successful = 1/5 = 0.2 loss
for _ in 0..3 {
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
}
tracker.mark_ping_sent();
tracker.record_timeout();
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
let q = tracker.snapshot();
assert!((q.loss_ratio - 0.2).abs() < 0.01);
}
#[test]
fn window_evicts_old_samples() {
let mut tracker = RttTracker::new(5);
// Fill window with 5 timeouts
for _ in 0..5 {
tracker.mark_ping_sent();
tracker.record_timeout();
}
assert_eq!(tracker.snapshot().loss_ratio, 1.0);
// Add 5 successes — should evict all timeouts
for _ in 0..5 {
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
}
assert_eq!(tracker.snapshot().loss_ratio, 0.0);
}
#[test]
fn min_max_rtt_tracked() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(5));
tracker.record_ack(ts);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(15));
tracker.record_ack(ts);
let q = tracker.snapshot();
assert!(q.min_rtt_ms < q.max_rtt_ms);
assert!(q.min_rtt_ms > 0.0);
}
}
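The EMA update in `RttTracker::record_ack` above follows RFC 6298 (alpha = 1/8 for SRTT, beta = 1/4 for the deviation term). A standalone sketch of just that arithmetic, with an illustrative helper name that is not part of the crate:

```rust
// One RFC 6298-style smoothing step: returns the new (srtt, jitter)
// given the previous estimates and a fresh RTT sample in milliseconds.
fn update_srtt(srtt: f64, jitter: f64, rtt_ms: f64) -> (f64, f64) {
    // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R|, beta = 1/4
    let jitter = 0.75 * jitter + 0.25 * (srtt - rtt_ms).abs();
    // SRTT = (1 - alpha) * SRTT + alpha * R, alpha = 1/8
    let srtt = 0.875 * srtt + 0.125 * rtt_ms;
    (srtt, jitter)
}

fn main() {
    // Start at 10 ms and feed a 20 ms spike: SRTT moves only 1/8 of the
    // way toward the sample, so a single outlier barely shifts the estimate.
    let (srtt, jitter) = update_srtt(10.0, 0.0, 20.0);
    assert!((srtt - 11.25).abs() < 1e-9);
    assert!((jitter - 2.5).abs() < 1e-9);
}
```

This damping is why the `srtt_converges` test needs several ping/ack cycles before asserting on the estimate.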


@@ -64,6 +64,22 @@ pub async fn add_route(subnet: &str, device_name: &str) -> Result<()> {
Ok(())
}
/// Action to take after checking a packet against the MTU.
pub enum TunMtuAction {
/// Packet is within MTU limits, forward it.
Forward,
/// Packet is oversized; the Vec contains the ICMP too-big message to write back into TUN.
IcmpTooBig(Vec<u8>),
}
/// Check a TUN packet against the MTU and return the appropriate action.
pub fn check_tun_mtu(packet: &[u8], mtu_config: &crate::mtu::MtuConfig) -> TunMtuAction {
match crate::mtu::check_mtu(packet, mtu_config) {
crate::mtu::MtuAction::Forward => TunMtuAction::Forward,
crate::mtu::MtuAction::SendIcmpTooBig(icmp) => TunMtuAction::IcmpTooBig(icmp),
}
}
/// Remove a route.
pub async fn remove_route(subnet: &str, device_name: &str) -> Result<()> {
let output = tokio::process::Command::new("ip")
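The MTU check above hinges on subtracting tunnel overhead from the link MTU, as `TunnelOverhead::effective_tun_mtu` does. A sketch of that arithmetic with assumed per-layer byte counts (the real constants live in `crate::mtu` and may differ):

```rust
// Illustrative overhead subtraction. The frame-header, AEAD-tag, and
// WebSocket byte counts below are assumptions for this sketch only.
fn effective_tun_mtu(link_mtu: u16, frame_header: u16, aead_tag: u16, ws_overhead: u16) -> u16 {
    // saturating_sub guards against a pathological tiny link MTU
    link_mtu.saturating_sub(frame_header + aead_tag + ws_overhead)
}

fn main() {
    // A 1500-byte link minus 34 bytes of assumed tunnel overhead
    // leaves 1466 bytes for inner IP packets.
    assert_eq!(effective_tun_mtu(1500, 4, 16, 14), 1466);
    // Tiny link MTU clamps to zero rather than underflowing.
    assert_eq!(effective_tun_mtu(20, 4, 16, 14), 0);
}
```

Packets larger than this effective MTU are the ones `check_tun_mtu` answers with an ICMP too-big message instead of forwarding.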


@@ -0,0 +1,271 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import { VpnClient, VpnServer } from '../ts/index.js';
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
async function findFreePort(): Promise<number> {
const server = net.createServer();
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
const port = (server.address() as net.AddressInfo).port;
await new Promise<void>((resolve) => server.close(() => resolve()));
return port;
}
function delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
async function waitFor(
fn: () => Promise<boolean>,
timeoutMs: number = 10000,
pollMs: number = 500,
): Promise<void> {
const deadline = Date.now() + timeoutMs;
while (Date.now() < deadline) {
if (await fn()) return;
await delay(pollMs);
}
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
}
// ---------------------------------------------------------------------------
// Test state
// ---------------------------------------------------------------------------
let server: VpnServer;
let serverPort: number;
let keypair: IVpnKeypair;
let client: VpnClient;
const extraClients: VpnClient[] = [];
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
tap.test('setup: start VPN server', async () => {
serverPort = await findFreePort();
const options: IVpnServerOptions = {
transport: { transport: 'stdio' },
};
server = new VpnServer(options);
// Phase 1: start the daemon bridge
const started = await server['bridge'].start();
expect(started).toBeTrue();
expect(server.running).toBeTrue();
// Phase 2: generate a keypair
keypair = await server.generateKeypair();
expect(keypair.publicKey).toBeTypeofString();
expect(keypair.privateKey).toBeTypeofString();
// Phase 3: start the VPN listener
const serverConfig: IVpnServerConfig = {
listenAddr: `127.0.0.1:${serverPort}`,
privateKey: keypair.privateKey,
publicKey: keypair.publicKey,
subnet: '10.8.0.0/24',
};
await server['bridge'].sendCommand('start', { config: serverConfig });
// Verify server is now running
const status = await server.getStatus();
expect(status.state).toEqual('connected');
});
tap.test('single client connects and gets IP', async () => {
const options: IVpnClientOptions = {
transport: { transport: 'stdio' },
};
client = new VpnClient(options);
const started = await client.start();
expect(started).toBeTrue();
const result = await client.connect({
serverUrl: `ws://127.0.0.1:${serverPort}`,
serverPublicKey: keypair.publicKey,
keepaliveIntervalSecs: 3,
});
expect(result.assignedIp).toBeTypeofString();
expect(result.assignedIp).toStartWith('10.8.0.');
// Verify client status
const clientStatus = await client.getStatus();
expect(clientStatus.state).toEqual('connected');
// Verify server sees the client
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 1;
});
const clients = await server.listClients();
expect(clients.length).toEqual(1);
expect(clients[0].assignedIp).toEqual(result.assignedIp);
});
tap.test('keepalive exchange', async () => {
// Wait for at least 2 keepalive cycles (interval=3s, so 8s should be enough)
await delay(8000);
const clientStats = await client.getStatistics();
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
const serverStats = await server.getStatistics();
expect(serverStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
expect(serverStats.keepalivesSent).toBeGreaterThanOrEqual(1);
// Verify per-client keepalive tracking
const clients = await server.listClients();
expect(clients[0].keepalivesReceived).toBeGreaterThanOrEqual(1);
});
tap.test('connection quality telemetry', async () => {
const quality = await client.getConnectionQuality();
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
expect(quality.jitterMs).toBeTypeofNumber();
expect(quality.minRttMs).toBeGreaterThanOrEqual(0);
expect(quality.maxRttMs).toBeGreaterThanOrEqual(0);
expect(quality.lossRatio).toBeTypeofNumber();
expect(['healthy', 'degraded', 'critical']).toContain(quality.linkHealth);
});
tap.test('rate limiting: set and verify', async () => {
const clients = await server.listClients();
const clientId = clients[0].clientId;
// Set a tight rate limit
await server.setClientRateLimit(clientId, 100, 100);
// Verify via telemetry
const telemetry = await server.getClientTelemetry(clientId);
expect(telemetry.rateLimitBytesPerSec).toEqual(100);
expect(telemetry.burstBytes).toEqual(100);
expect(telemetry.clientId).toEqual(clientId);
});
tap.test('rate limiting: removal', async () => {
const clients = await server.listClients();
const clientId = clients[0].clientId;
await server.removeClientRateLimit(clientId);
// Verify telemetry no longer shows rate limit
const telemetry = await server.getClientTelemetry(clientId);
expect(telemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
expect(telemetry.burstBytes).toBeNullOrUndefined();
// Connection still healthy
const status = await client.getStatus();
expect(status.state).toEqual('connected');
});
tap.test('5 concurrent clients', async () => {
const assignedIps = new Set<string>();
// Get the first client's IP
const existingClients = await server.listClients();
assignedIps.add(existingClients[0].assignedIp);
for (let i = 0; i < 5; i++) {
const c = new VpnClient({ transport: { transport: 'stdio' } });
await c.start();
const result = await c.connect({
serverUrl: `ws://127.0.0.1:${serverPort}`,
serverPublicKey: keypair.publicKey,
keepaliveIntervalSecs: 3,
});
expect(result.assignedIp).toStartWith('10.8.0.');
assignedIps.add(result.assignedIp);
extraClients.push(c);
}
// All IPs should be unique (6 total: original + 5 new)
expect(assignedIps.size).toEqual(6);
// Server should see 6 clients
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 6;
});
const allClients = await server.listClients();
expect(allClients.length).toEqual(6);
});
tap.test('client disconnect tracking', async () => {
// Disconnect 3 of the 5 extra clients
for (let i = 0; i < 3; i++) {
const c = extraClients[i];
await c.disconnect();
c.stop();
}
// Wait for server to detect disconnections
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 3;
}, 15000);
const clients = await server.listClients();
expect(clients.length).toEqual(3);
const stats = await server.getStatistics();
expect(stats.totalConnections).toBeGreaterThanOrEqual(6);
});
tap.test('server-side client disconnection', async () => {
const clients = await server.listClients();
// Pick one of the remaining extra clients (not the original)
const targetClient = clients.find((c) => {
// Pick any client other than the first listed one, so at least one stays connected
return c.clientId !== clients[0].clientId;
});
expect(targetClient).toBeTruthy();
await server.disconnectClient(targetClient!.clientId);
// Wait for server to update
await waitFor(async () => {
const remaining = await server.listClients();
return remaining.length === 2;
});
const remaining = await server.listClients();
expect(remaining.length).toEqual(2);
});
tap.test('teardown: stop all', async () => {
// Stop the original client
await client.disconnect();
client.stop();
// Stop remaining extra clients
for (const c of extraClients) {
if (c.running) {
try {
await c.disconnect();
} catch {
// May already be disconnected
}
c.stop();
}
}
await delay(500);
// Stop the server
await server.stopServer();
server.stop();
await delay(500);
expect(server.running).toBeFalse();
});
export default tap.start();

test/test.loadtest.node.ts Normal file

@@ -0,0 +1,357 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import * as stream from 'stream';
import { VpnClient, VpnServer } from '../ts/index.js';
import type { IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
async function findFreePort(): Promise<number> {
const server = net.createServer();
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
const port = (server.address() as net.AddressInfo).port;
await new Promise<void>((resolve) => server.close(() => resolve()));
return port;
}
function delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
async function waitFor(
fn: () => Promise<boolean>,
timeoutMs: number = 10000,
pollMs: number = 500,
): Promise<void> {
const deadline = Date.now() + timeoutMs;
while (Date.now() < deadline) {
if (await fn()) return;
await delay(pollMs);
}
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
}
// ---------------------------------------------------------------------------
// ThrottleProxy (adapted from remoteingress)
// ---------------------------------------------------------------------------
class ThrottleTransform extends stream.Transform {
private bytesPerSec: number;
private bucket: number;
private lastRefill: number;
private destroyed_: boolean = false;
constructor(bytesPerSecond: number) {
super();
this.bytesPerSec = bytesPerSecond;
this.bucket = bytesPerSecond;
this.lastRefill = Date.now();
}
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: stream.TransformCallback) {
if (this.destroyed_) return;
const now = Date.now();
const elapsed = (now - this.lastRefill) / 1000;
this.bucket = Math.min(this.bytesPerSec, this.bucket + elapsed * this.bytesPerSec);
this.lastRefill = now;
if (chunk.length <= this.bucket) {
this.bucket -= chunk.length;
callback(null, chunk);
} else {
const deficit = chunk.length - this.bucket;
this.bucket = 0;
const delayMs = Math.min((deficit / this.bytesPerSec) * 1000, 1000);
setTimeout(() => {
if (this.destroyed_) { callback(); return; }
this.lastRefill = Date.now();
this.bucket = 0;
callback(null, chunk);
}, delayMs);
}
}
_destroy(err: Error | null, callback: (error: Error | null) => void) {
this.destroyed_ = true;
callback(err);
}
}
interface ThrottleProxy {
server: net.Server;
close: () => Promise<void>;
}
async function startThrottleProxy(
listenPort: number,
targetHost: string,
targetPort: number,
bytesPerSecond: number,
): Promise<ThrottleProxy> {
const connections = new Set<net.Socket>();
const server = net.createServer((clientSock) => {
connections.add(clientSock);
const upstream = net.createConnection({ host: targetHost, port: targetPort });
connections.add(upstream);
const throttleUp = new ThrottleTransform(bytesPerSecond);
const throttleDown = new ThrottleTransform(bytesPerSecond);
clientSock.pipe(throttleUp).pipe(upstream);
upstream.pipe(throttleDown).pipe(clientSock);
let cleaned = false;
const cleanup = () => {
if (cleaned) return;
cleaned = true;
throttleUp.destroy();
throttleDown.destroy();
clientSock.destroy();
upstream.destroy();
connections.delete(clientSock);
connections.delete(upstream);
};
clientSock.on('error', () => cleanup());
upstream.on('error', () => cleanup());
throttleUp.on('error', () => cleanup());
throttleDown.on('error', () => cleanup());
clientSock.on('close', () => cleanup());
upstream.on('close', () => cleanup());
});
await new Promise<void>((resolve) => server.listen(listenPort, '127.0.0.1', resolve));
return {
server,
close: async () => {
for (const c of connections) c.destroy();
connections.clear();
await new Promise<void>((resolve) => server.close(() => resolve()));
},
};
}
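The `ThrottleTransform` above is a token bucket: tokens refill continuously at the target rate, a chunk that fits passes immediately, and a chunk that overdraws the bucket is delayed until the deficit refills. The core refill/consume arithmetic, isolated as a standalone helper (hypothetical function, not part of the library), looks like this:

```typescript
// Standalone token bucket: returns how long (ms) to wait before `size` bytes
// may pass, refilling at `ratePerSec` up to a capacity of the same value.
function tokenBucketDelay(
  bucket: { tokens: number; lastRefill: number },
  size: number,
  ratePerSec: number,
  now: number,
): number {
  const elapsed = (now - bucket.lastRefill) / 1000;
  bucket.tokens = Math.min(ratePerSec, bucket.tokens + elapsed * ratePerSec);
  bucket.lastRefill = now;
  if (size <= bucket.tokens) {
    bucket.tokens -= size;
    return 0; // enough tokens: pass immediately
  }
  const deficit = size - bucket.tokens;
  bucket.tokens = 0;
  return (deficit / ratePerSec) * 1000; // wait for the deficit to refill
}

const b = { tokens: 1024, lastRefill: Date.now() };
tokenBucketDelay(b, 512, 1024, b.lastRefill); // → 0 (bucket has room)
```

The transform additionally caps the computed delay at 1000 ms so a single oversized chunk cannot stall the pipe indefinitely.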
// ---------------------------------------------------------------------------
// Test state
// ---------------------------------------------------------------------------
let server: VpnServer;
let serverPort: number;
let proxyPort: number;
let keypair: IVpnKeypair;
let throttle: ThrottleProxy;
const allClients: VpnClient[] = [];
async function createConnectedClient(port: number): Promise<VpnClient> {
const c = new VpnClient({ transport: { transport: 'stdio' } });
await c.start();
await c.connect({
serverUrl: `ws://127.0.0.1:${port}`,
serverPublicKey: keypair.publicKey,
keepaliveIntervalSecs: 3,
});
allClients.push(c);
return c;
}
async function stopClient(c: VpnClient): Promise<void> {
if (c.running) {
try { await c.disconnect(); } catch { /* already disconnected */ }
c.stop();
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
tap.test('setup: start throttled VPN tunnel (1 MB/s)', async () => {
serverPort = await findFreePort();
proxyPort = await findFreePort();
// Start VPN server
server = new VpnServer({ transport: { transport: 'stdio' } });
const started = await server['bridge'].start();
expect(started).toBeTrue();
keypair = await server.generateKeypair();
const serverConfig: IVpnServerConfig = {
listenAddr: `127.0.0.1:${serverPort}`,
privateKey: keypair.privateKey,
publicKey: keypair.publicKey,
subnet: '10.8.0.0/24',
};
await server['bridge'].sendCommand('start', { config: serverConfig });
const status = await server.getStatus();
expect(status.state).toEqual('connected');
// Start throttle proxy: 1 MB/s
throttle = await startThrottleProxy(proxyPort, '127.0.0.1', serverPort, 1024 * 1024);
});
tap.test('throttled connection: handshake succeeds through throttle', async () => {
const client = await createConnectedClient(proxyPort);
const status = await client.getStatus();
expect(status.state).toEqual('connected');
expect(status.assignedIp).toStartWith('10.8.0.');
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 1;
});
});
tap.test('sustained keepalive under throttle', async () => {
// Wait for at least 2 keepalive cycles (3s interval)
await delay(8000);
const client = allClients[0];
const stats = await client.getStatistics();
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
// Throttle adds latency — RTT should be measurable
const quality = await client.getConnectionQuality();
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
expect(quality.jitterMs).toBeTypeofNumber();
});
tap.test('3 concurrent throttled clients', async () => {
for (let i = 0; i < 3; i++) {
await createConnectedClient(proxyPort);
}
// All 4 clients should be visible
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 4;
});
const clients = await server.listClients();
expect(clients.length).toEqual(4);
// Verify all IPs are unique
const ips = new Set(clients.map((c) => c.assignedIp));
expect(ips.size).toEqual(4);
});
tap.test('rate limiting combined with network throttle', async () => {
const clients = await server.listClients();
const targetId = clients[0].clientId;
// Set rate limit on first client
await server.setClientRateLimit(targetId, 500, 500);
const telemetry = await server.getClientTelemetry(targetId);
expect(telemetry.rateLimitBytesPerSec).toEqual(500);
expect(telemetry.burstBytes).toEqual(500);
// Verify another client has no rate limit
const otherTelemetry = await server.getClientTelemetry(clients[1].clientId);
expect(otherTelemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
// Clean up the rate limit
await server.removeClientRateLimit(targetId);
});
tap.test('burst waves: 3 waves of 3 clients', async () => {
const initialCount = (await server.listClients()).length;
for (let wave = 0; wave < 3; wave++) {
const waveClients: VpnClient[] = [];
// Connect 3 clients
for (let i = 0; i < 3; i++) {
const c = await createConnectedClient(proxyPort);
waveClients.push(c);
}
// Verify all connected
await waitFor(async () => {
const all = await server.listClients();
return all.length === initialCount + 3;
});
// Disconnect all wave clients
for (const c of waveClients) {
await stopClient(c);
}
// Wait for server to detect disconnections
await waitFor(async () => {
const all = await server.listClients();
return all.length === initialCount;
}, 15000);
await delay(500);
}
// Verify total connections accumulated
const stats = await server.getStatistics();
expect(stats.totalConnections).toBeGreaterThanOrEqual(9 + initialCount);
// Original clients still connected
const remaining = await server.listClients();
expect(remaining.length).toEqual(initialCount);
});
tap.test('aggressive throttle: 10 KB/s', async () => {
// Close current throttle proxy and start an aggressive one
await throttle.close();
const aggressivePort = await findFreePort();
throttle = await startThrottleProxy(aggressivePort, '127.0.0.1', serverPort, 10 * 1024);
// Connect a client through the aggressive throttle
const client = await createConnectedClient(aggressivePort);
const status = await client.getStatus();
expect(status.state).toEqual('connected');
// Wait for keepalive exchange (might take longer due to throttle)
await delay(10000);
const stats = await client.getStatistics();
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
});
tap.test('post-load health: direct connection still works', async () => {
// Server should still be healthy after all load tests
const serverStatus = await server.getStatus();
expect(serverStatus.state).toEqual('connected');
// Connect one more client directly (no throttle)
const directClient = await createConnectedClient(serverPort);
const status = await directClient.getStatus();
expect(status.state).toEqual('connected');
await delay(5000);
const stats = await directClient.getStatistics();
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
});
tap.test('teardown: stop all', async () => {
// Stop all clients
for (const c of allClients) {
await stopClient(c);
}
await delay(500);
// Close throttle proxy
if (throttle) {
await throttle.close();
}
// Stop server
await server.stopServer();
server.stop();
await delay(500);
expect(server.running).toBeFalse();
});
export default tap.start();


@@ -3,6 +3,6 @@
*/
export const commitinfo = {
name: '@push.rocks/smartvpn',
version: '1.3.0',
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
}


@@ -5,6 +5,8 @@ import type {
IVpnClientConfig,
IVpnStatus,
IVpnStatistics,
IVpnConnectionQuality,
IVpnMtuInfo,
TVpnClientCommands,
} from './smartvpn.interfaces.js';
@@ -65,12 +67,26 @@ export class VpnClient extends plugins.events.EventEmitter {
}
/**
* Get traffic statistics (includes connection quality when connected).
*/
public async getStatistics(): Promise<IVpnStatistics> {
return this.bridge.sendCommand('getStatistics', {} as Record<string, never>);
}
/**
* Get connection quality metrics (RTT, jitter, loss, link health).
*/
public async getConnectionQuality(): Promise<IVpnConnectionQuality> {
return this.bridge.sendCommand('getConnectionQuality', {} as Record<string, never>);
}
/**
* Get MTU information (overhead, effective MTU, oversized packet stats).
*/
public async getMtuInfo(): Promise<IVpnMtuInfo> {
return this.bridge.sendCommand('getMtuInfo', {} as Record<string, never>);
}
/**
* Stop the daemon bridge.
*/


@@ -7,6 +7,7 @@ import type {
IVpnServerStatistics, IVpnServerStatistics,
IVpnClientInfo, IVpnClientInfo,
IVpnKeypair, IVpnKeypair,
IVpnClientTelemetry,
TVpnServerCommands, TVpnServerCommands,
} from './smartvpn.interfaces.js'; } from './smartvpn.interfaces.js';
@@ -91,6 +92,35 @@ export class VpnServer extends plugins.events.EventEmitter {
return this.bridge.sendCommand('generateKeypair', {} as Record<string, never>);
}
/**
* Set rate limit for a specific client.
*/
public async setClientRateLimit(
clientId: string,
rateBytesPerSec: number,
burstBytes: number,
): Promise<void> {
await this.bridge.sendCommand('setClientRateLimit', {
clientId,
rateBytesPerSec,
burstBytes,
});
}
/**
* Remove rate limit for a specific client (unlimited).
*/
public async removeClientRateLimit(clientId: string): Promise<void> {
await this.bridge.sendCommand('removeClientRateLimit', { clientId });
}
/**
* Get telemetry for a specific client.
*/
public async getClientTelemetry(clientId: string): Promise<IVpnClientTelemetry> {
return this.bridge.sendCommand('getClientTelemetry', { clientId });
}
/**
* Stop the daemon bridge.
*/


@@ -64,6 +64,10 @@ export interface IVpnServerConfig {
keepaliveIntervalSecs?: number;
/** Enable NAT/masquerade for client traffic */
enableNat?: boolean;
/** Default rate limit for new clients (bytes/sec). Omit for unlimited. */
defaultRateLimitBytesPerSec?: number;
/** Default burst size for new clients (bytes). Omit for unlimited. */
defaultBurstBytes?: number;
}
export interface IVpnServerOptions {
@@ -99,6 +103,7 @@ export interface IVpnStatistics {
keepalivesSent: number;
keepalivesReceived: number;
uptimeSeconds: number;
quality?: IVpnConnectionQuality;
}
export interface IVpnClientInfo {
@@ -107,6 +112,12 @@ export interface IVpnClientInfo {
connectedSince: string;
bytesSent: number;
bytesReceived: number;
packetsDropped: number;
bytesDropped: number;
lastKeepaliveAt?: string;
keepalivesReceived: number;
rateLimitBytesPerSec?: number;
burstBytes?: number;
} }
export interface IVpnServerStatistics extends IVpnStatistics {
@@ -119,6 +130,53 @@ export interface IVpnKeypair {
privateKey: string;
}
// ============================================================================
// QoS: Connection quality
// ============================================================================
export type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
export interface IVpnConnectionQuality {
srttMs: number;
jitterMs: number;
minRttMs: number;
maxRttMs: number;
lossRatio: number;
consecutiveTimeouts: number;
linkHealth: TVpnLinkHealth;
currentKeepaliveIntervalSecs: number;
}
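How `linkHealth` is derived from the other quality fields is not spelled out in this diff; a plausible classification, with purely illustrative thresholds (the Rust daemon's actual cutoffs may differ), could look like:

```typescript
type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';

// Illustrative thresholds only; not the daemon's real values.
function classifyLinkHealth(
  lossRatio: number,
  consecutiveTimeouts: number,
  jitterMs: number,
): TVpnLinkHealth {
  if (consecutiveTimeouts >= 3 || lossRatio > 0.25) return 'critical';
  if (lossRatio > 0.05 || jitterMs > 100) return 'degraded';
  return 'healthy';
}

classifyLinkHealth(0, 0, 5);   // → 'healthy'
classifyLinkHealth(0.1, 0, 5); // → 'degraded'
classifyLinkHealth(0.3, 0, 5); // → 'critical'
```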
// ============================================================================
// QoS: MTU info
// ============================================================================
export interface IVpnMtuInfo {
tunMtu: number;
effectiveMtu: number;
linkMtu: number;
overheadBytes: number;
oversizedPacketsDropped: number;
icmpTooBigSent: number;
}
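The relationship between these fields follows the usual tunnel arithmetic: `effectiveMtu` is what remains of the link MTU after per-packet tunnel overhead, and packets larger than that trigger the ICMP too-big path. A sketch of that decision, mirroring the `Forward`/`IcmpTooBig` split in the Rust `check_tun_mtu` earlier in this diff (the 80-byte overhead figure is a hypothetical example; the real value depends on the transport framing and encryption):

```typescript
// Simplified MTU check; real code builds an actual ICMP
// "fragmentation needed" / "packet too big" reply.
type MtuAction = { kind: 'forward' } | { kind: 'icmpTooBig'; mtu: number };

function checkTunMtu(
  packetLen: number,
  linkMtu: number,
  overheadBytes: number,
): MtuAction {
  const effective = linkMtu - overheadBytes; // corresponds to effectiveMtu
  return packetLen <= effective
    ? { kind: 'forward' }
    : { kind: 'icmpTooBig', mtu: effective };
}

checkTunMtu(1400, 1500, 80); // → { kind: 'forward' }
```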
// ============================================================================
// QoS: Client telemetry (server-side per-client)
// ============================================================================
export interface IVpnClientTelemetry {
clientId: string;
assignedIp: string;
lastKeepaliveAt?: string;
keepalivesReceived: number;
packetsDropped: number;
bytesDropped: number;
bytesReceived: number;
bytesSent: number;
rateLimitBytesPerSec?: number;
burstBytes?: number;
}
// ============================================================================
// IPC Command maps (used by smartrust RustBridge<TCommands>)
// ============================================================================
@@ -128,6 +186,8 @@ export type TVpnClientCommands = {
disconnect: { params: Record<string, never>; result: void };
getStatus: { params: Record<string, never>; result: IVpnStatus };
getStatistics: { params: Record<string, never>; result: IVpnStatistics };
getConnectionQuality: { params: Record<string, never>; result: IVpnConnectionQuality };
getMtuInfo: { params: Record<string, never>; result: IVpnMtuInfo };
};
export type TVpnServerCommands = {
@@ -138,6 +198,9 @@ export type TVpnServerCommands = {
listClients: { params: Record<string, never>; result: { clients: IVpnClientInfo[] } };
disconnectClient: { params: { clientId: string }; result: void };
generateKeypair: { params: Record<string, never>; result: IVpnKeypair };
setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void };
removeClientRateLimit: { params: { clientId: string }; result: void };
getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry };
};
// ============================================================================