Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e14c357ba0 | |||
| eb30825f72 | |||
| 835f0f791d | |||
| aec545fe8c | |||
| 4fab721d87 | |||
| 9ee41348e0 | |||
| 97bb148063 | |||
| c8d572b719 |
28
changelog.md
28
changelog.md
@@ -1,5 +1,33 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-03-17 - 1.3.0 - feat(tests,client)
|
||||
add flow control and load test coverage and honor configured keepalive intervals
|
||||
|
||||
- Adds end-to-end node tests for client/server flow control, keepalive exchange, connection quality telemetry, rate limiting, concurrent clients, and disconnect tracking.
|
||||
- Adds load testing with throttled proxy scenarios to validate behavior under constrained bandwidth and repeated client churn.
|
||||
- Updates the Rust client to pass configured keepaliveIntervalSecs into the adaptive keepalive monitor instead of always using defaults.
|
||||
|
||||
## 2026-03-15 - 1.2.0 - feat(readme)
|
||||
document QoS, telemetry, MTU, and rate limiting capabilities in the README
|
||||
|
||||
- Expand the architecture and feature overview to cover adaptive keepalive, telemetry, QoS, rate limiting, and MTU handling
|
||||
- Update client and server examples to show new APIs such as getConnectionQuality(), getMtuInfo(), setClientRateLimit(), and getClientTelemetry()
|
||||
- Add TypeScript interface documentation for connection quality, MTU info, enriched client statistics, and per-client telemetry
|
||||
|
||||
## 2026-03-15 - 1.1.0 - feat(rust-core)
|
||||
add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs
|
||||
|
||||
- adds adaptive keepalive monitoring with RTT, jitter, loss, and link health reporting to client statistics and management endpoints
|
||||
- introduces MTU overhead calculation and oversized-packet handling support, plus client MTU info APIs
|
||||
- adds token-bucket rate limiting with configurable default limits and server management commands to set, remove, and inspect per-client telemetry
|
||||
- extends TypeScript client and server interfaces with connection quality, MTU, and client telemetry methods
|
||||
|
||||
## 2026-02-27 - 1.0.3 - fix(build)
|
||||
add aarch64 linker configuration for cross-compilation
|
||||
|
||||
- Added rust/.cargo/config.toml to configure linker for target aarch64-unknown-linux-gnu
|
||||
- Sets linker to 'aarch64-linux-gnu-gcc' to enable cross-compilation to ARM64
|
||||
|
||||
## 2026-02-27 - 1.0.2 - fix()
|
||||
no changes detected - no code or content modifications
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartvpn",
|
||||
"version": "1.0.2",
|
||||
"version": "1.3.0",
|
||||
"private": false,
|
||||
"description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
|
||||
"type": "module",
|
||||
|
||||
269
readme.md
269
readme.md
@@ -1,6 +1,6 @@
|
||||
# @push.rocks/smartvpn
|
||||
|
||||
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, typed APIs while all networking heavy lifting — encryption, tunneling, packet forwarding — runs at native speed in Rust.
|
||||
A high-performance VPN with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, fully-typed APIs while all networking heavy lifting — encryption, tunneling, QoS, rate limiting — runs at native speed in Rust.
|
||||
|
||||
## Issue Reporting and Security
|
||||
|
||||
@@ -9,8 +9,6 @@ For reporting bugs, issues, or security vulnerabilities, please visit [community
|
||||
## Install
|
||||
|
||||
```bash
|
||||
npm install @push.rocks/smartvpn
|
||||
# or
|
||||
pnpm install @push.rocks/smartvpn
|
||||
```
|
||||
|
||||
@@ -18,17 +16,21 @@ pnpm install @push.rocks/smartvpn
|
||||
|
||||
```
|
||||
TypeScript (control plane) Rust (data plane)
|
||||
┌──────────────────────────┐ ┌───────────────────────────────┐
|
||||
│ VpnClient / VpnServer │ │ smartvpn_daemon │
|
||||
│ └─ VpnBridge │──stdio/──▶ │ ├─ management (JSON IPC) │
|
||||
│ └─ RustBridge │ socket │ ├─ transport (WebSocket/TLS) │
|
||||
│ (smartrust) │ │ ├─ crypto (Noise NK + XCha) │
|
||||
└──────────────────────────┘ │ ├─ codec (binary framing) │
|
||||
│ ├─ keepalive (app-level) │
|
||||
│ ├─ tunnel (TUN device) │
|
||||
│ ├─ network (NAT/IP pool) │
|
||||
│ └─ reconnect (backoff) │
|
||||
└───────────────────────────────┘
|
||||
┌──────────────────────────┐ ┌────────────────────────────────────┐
|
||||
│ VpnClient / VpnServer │ │ smartvpn_daemon │
|
||||
│ └─ VpnBridge │──stdio/──▶ │ ├─ management (JSON IPC) │
|
||||
│ └─ RustBridge │ socket │ ├─ transport (WebSocket/TLS) │
|
||||
│ (smartrust) │ │ ├─ crypto (Noise NK + XCha20) │
|
||||
└──────────────────────────┘ │ ├─ codec (binary framing) │
|
||||
│ ├─ keepalive (adaptive state FSM) │
|
||||
│ ├─ telemetry (RTT/jitter/loss) │
|
||||
│ ├─ qos (classify + priority Q) │
|
||||
│ ├─ ratelimit (token bucket) │
|
||||
│ ├─ mtu (overhead calc + ICMP) │
|
||||
│ ├─ tunnel (TUN device) │
|
||||
│ ├─ network (NAT/IP pool) │
|
||||
│ └─ reconnect (exp. backoff) │
|
||||
└────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key design decisions:**
|
||||
@@ -37,8 +39,10 @@ TypeScript (control plane) Rust (data plane)
|
||||
|----------|--------|-----|
|
||||
| Transport | WebSocket over HTTPS | Works through Cloudflare and other terminating proxies |
|
||||
| Encryption | Noise NK + XChaCha20-Poly1305 | Strong forward secrecy, large nonce space (no counter needed) |
|
||||
| Keepalive | App-level (not WS pings) | Cloudflare drops WS ping frames; app-level pings survive |
|
||||
| IPC | JSON lines over stdio/Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
|
||||
| Keepalive | Adaptive app-level pings | Cloudflare drops WS pings; interval adapts to link health (10–60s) |
|
||||
| QoS | Packet classification + priority queues | DNS/SSH/ICMP always drain first; bulk flows get deprioritized |
|
||||
| Rate limiting | Per-client token bucket | Byte-granular, dynamically reconfigurable via IPC |
|
||||
| IPC | JSON lines over stdio / Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
|
||||
| Binary protocol | `[type:1B][length:4B][payload:NB]` | Minimal overhead, easy to parse at wire speed |
|
||||
|
||||
## 🚀 Quick Start
|
||||
@@ -48,15 +52,12 @@ TypeScript (control plane) Rust (data plane)
|
||||
```typescript
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
// Development: spawn the Rust daemon as a child process
|
||||
const client = new VpnClient({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Start the daemon bridge
|
||||
await client.start();
|
||||
|
||||
// Connect to a VPN server
|
||||
const { assignedIp } = await client.connect({
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
|
||||
@@ -67,15 +68,23 @@ const { assignedIp } = await client.connect({
|
||||
|
||||
console.log(`Connected! Assigned IP: ${assignedIp}`);
|
||||
|
||||
// Check status
|
||||
const status = await client.getStatus();
|
||||
console.log(status); // { state: 'connected', assignedIp: '10.8.0.2', ... }
|
||||
// Connection quality (adaptive keepalive + telemetry)
|
||||
const quality = await client.getConnectionQuality();
|
||||
console.log(quality);
|
||||
// {
|
||||
// srttMs: 42.5, jitterMs: 3.2, minRttMs: 38.0, maxRttMs: 67.0,
|
||||
// lossRatio: 0.0, consecutiveTimeouts: 0,
|
||||
// linkHealth: 'healthy', currentKeepaliveIntervalSecs: 60
|
||||
// }
|
||||
|
||||
// Get traffic stats
|
||||
// MTU info
|
||||
const mtu = await client.getMtuInfo();
|
||||
console.log(mtu);
|
||||
// { tunMtu: 1420, effectiveMtu: 1421, linkMtu: 1500, overheadBytes: 79, ... }
|
||||
|
||||
// Traffic stats (includes quality snapshot)
|
||||
const stats = await client.getStatistics();
|
||||
console.log(stats); // { bytesSent, bytesReceived, packetsSent, ... }
|
||||
|
||||
// Disconnect
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
```
|
||||
@@ -89,33 +98,44 @@ const server = new VpnServer({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Start the daemon and the VPN server
|
||||
// Generate a Noise keypair first
|
||||
await server.start();
|
||||
// If you don't have keys yet:
|
||||
const keypair = await server.generateKeypair();
|
||||
|
||||
// Start the VPN listener (or pass config to start() directly)
|
||||
await server.start({
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'BASE64_PRIVATE_KEY',
|
||||
publicKey: 'BASE64_PUBLIC_KEY',
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
enableNat: true,
|
||||
// Optional: default rate limit for all new clients
|
||||
defaultRateLimitBytesPerSec: 10_000_000, // 10 MB/s
|
||||
defaultBurstBytes: 20_000_000, // 20 MB burst
|
||||
});
|
||||
|
||||
// Generate a Noise keypair
|
||||
const keypair = await server.generateKeypair();
|
||||
console.log(keypair); // { publicKey: '...', privateKey: '...' }
|
||||
|
||||
// List connected clients
|
||||
const clients = await server.listClients();
|
||||
// [{ clientId, assignedIp, connectedSince, bytesSent, bytesReceived }]
|
||||
|
||||
// Disconnect a specific client
|
||||
await server.disconnectClient('some-client-id');
|
||||
// Per-client rate limiting (live, no reconnect needed)
|
||||
await server.setClientRateLimit('client-id', 5_000_000, 10_000_000);
|
||||
await server.removeClientRateLimit('client-id'); // unlimited
|
||||
|
||||
// Get server stats
|
||||
const stats = await server.getStatistics();
|
||||
// { bytesSent, bytesReceived, activeClients, totalConnections, ... }
|
||||
// Per-client telemetry
|
||||
const telemetry = await server.getClientTelemetry('client-id');
|
||||
console.log(telemetry);
|
||||
// {
|
||||
// clientId, assignedIp, lastKeepaliveAt, keepalivesReceived,
|
||||
// packetsDropped, bytesDropped, bytesReceived, bytesSent,
|
||||
// rateLimitBytesPerSec, burstBytes
|
||||
// }
|
||||
|
||||
// Kick a client
|
||||
await server.disconnectClient('client-id');
|
||||
|
||||
// Stop
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
```
|
||||
@@ -151,7 +171,9 @@ When using socket transport, `client.stop()` closes the socket but **does not ki
|
||||
| `connect(config?)` | `Promise<{ assignedIp }>` | Connect to VPN server |
|
||||
| `disconnect()` | `Promise<void>` | Disconnect from VPN |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Current connection state |
|
||||
| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic statistics |
|
||||
| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic stats + connection quality |
|
||||
| `getConnectionQuality()` | `Promise<IVpnConnectionQuality>` | RTT, jitter, loss, link health |
|
||||
| `getMtuInfo()` | `Promise<IVpnMtuInfo>` | MTU info and overhead breakdown |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
| `running` | `boolean` | Whether bridge is active |
|
||||
|
||||
@@ -163,9 +185,12 @@ When using socket transport, `client.stop()` closes the socket but **does not ki
|
||||
| `stopServer()` | `Promise<void>` | Stop the VPN server |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Server connection state |
|
||||
| `getStatistics()` | `Promise<IVpnServerStatistics>` | Server stats (includes client counts) |
|
||||
| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients |
|
||||
| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients with QoS stats |
|
||||
| `disconnectClient(id)` | `Promise<void>` | Kick a client |
|
||||
| `generateKeypair()` | `Promise<IVpnKeypair>` | Generate Noise NK keypair |
|
||||
| `setClientRateLimit(id, rate, burst)` | `Promise<void>` | Set per-client rate limit (bytes/sec) |
|
||||
| `removeClientRateLimit(id)` | `Promise<void>` | Remove rate limit (unlimited) |
|
||||
| `getClientTelemetry(id)` | `Promise<IVpnClientTelemetry>` | Per-client telemetry + drop stats |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
|
||||
### `VpnConfig`
|
||||
@@ -191,26 +216,23 @@ Generate system service units for the daemon:
|
||||
```typescript
|
||||
import { VpnInstaller } from '@push.rocks/smartvpn';
|
||||
|
||||
// Auto-detect platform
|
||||
const platform = VpnInstaller.detectPlatform(); // 'linux' | 'macos' | 'windows' | 'unknown'
|
||||
|
||||
// Generate systemd unit (Linux)
|
||||
// Linux (systemd)
|
||||
const unit = VpnInstaller.generateSystemdUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'server',
|
||||
});
|
||||
// unit.content = full systemd .service file
|
||||
// unit.installPath = '/etc/systemd/system/smartvpn-server.service'
|
||||
|
||||
// Generate launchd plist (macOS)
|
||||
// macOS (launchd)
|
||||
const plist = VpnInstaller.generateLaunchdPlist({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'client',
|
||||
});
|
||||
|
||||
// Auto-detect and generate
|
||||
// Auto-detect platform
|
||||
const serviceUnit = VpnInstaller.generateServiceUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
@@ -223,8 +245,6 @@ const serviceUnit = VpnInstaller.generateServiceUnit({
|
||||
Both `VpnClient` and `VpnServer` extend `EventEmitter`:
|
||||
|
||||
```typescript
|
||||
client.on('status', (status) => { /* IVpnStatus */ });
|
||||
client.on('error', (err) => { /* { message, code? } */ });
|
||||
client.on('exit', ({ code, signal }) => { /* daemon exited */ });
|
||||
client.on('reconnected', () => { /* socket reconnected */ });
|
||||
|
||||
@@ -232,13 +252,84 @@ server.on('client-connected', (info) => { /* IVpnClientInfo */ });
|
||||
server.on('client-disconnected', ({ clientId, reason }) => { /* ... */ });
|
||||
```
|
||||
|
||||
## 📊 QoS System
|
||||
|
||||
The Rust daemon includes a full QoS stack that operates on decrypted IP packets:
|
||||
|
||||
### Adaptive Keepalive
|
||||
|
||||
The keepalive system automatically adjusts its interval based on connection quality:
|
||||
|
||||
| Link Health | Keepalive Interval | Triggered When |
|
||||
|-------------|-------------------|----------------|
|
||||
| 🟢 Healthy | 60s | Jitter < 30ms, loss < 2%, no timeouts |
|
||||
| 🟡 Degraded | 30s | Jitter > 50ms, loss > 5%, or 1+ timeout |
|
||||
| 🔴 Critical | 10s | Loss > 20% or 2+ consecutive timeouts |
|
||||
|
||||
State transitions include hysteresis (3 consecutive good checks to upgrade, 2 to recover) to prevent flapping. Dead peer detection fires after 3 consecutive timeouts in Critical state.
|
||||
|
||||
### Packet Classification
|
||||
|
||||
IP packets are classified into three priority levels by inspecting headers (no deep packet inspection):
|
||||
|
||||
| Priority | Traffic |
|
||||
|----------|---------|
|
||||
| **High** | ICMP, DNS (port 53), SSH (port 22), small packets (< 128 bytes) |
|
||||
| **Normal** | Everything else |
|
||||
| **Low** | Bulk flows exceeding 1 MB within a 60s window |
|
||||
|
||||
Priority channels drain with biased `tokio::select!` — high-priority packets always go first.
|
||||
|
||||
### Smart Packet Dropping
|
||||
|
||||
Under backpressure, packets are dropped intelligently:
|
||||
|
||||
1. **Low** queue full → drop silently
|
||||
2. **Normal** queue full → drop
|
||||
3. **High** queue full → wait 5ms, then drop as last resort
|
||||
|
||||
Drop statistics are tracked per priority level and exposed via telemetry.
|
||||
|
||||
### Per-Client Rate Limiting
|
||||
|
||||
Token bucket algorithm with byte granularity:
|
||||
|
||||
```typescript
|
||||
// Set: 10 MB/s sustained, 20 MB burst
|
||||
await server.setClientRateLimit('client-id', 10_000_000, 20_000_000);
|
||||
|
||||
// Check drops via telemetry
|
||||
const t = await server.getClientTelemetry('client-id');
|
||||
console.log(`Dropped: ${t.packetsDropped} packets, ${t.bytesDropped} bytes`);
|
||||
|
||||
// Remove limit
|
||||
await server.removeClientRateLimit('client-id');
|
||||
```
|
||||
|
||||
Rate limits can be changed live without disconnecting the client.
|
||||
|
||||
### Path MTU
|
||||
|
||||
Tunnel overhead is calculated precisely:
|
||||
|
||||
| Layer | Bytes |
|
||||
|-------|-------|
|
||||
| IP header | 20 |
|
||||
| TCP header (with timestamps) | 32 |
|
||||
| WebSocket framing | 6 |
|
||||
| VPN frame header | 5 |
|
||||
| Noise AEAD tag | 16 |
|
||||
| **Total overhead** | **79** |
|
||||
|
||||
For a standard 1500-byte Ethernet link, effective TUN MTU = **1421 bytes**. The default TUN MTU of 1420 is conservative and correct. Oversized packets get an ICMP "Fragmentation Needed" (Type 3, Code 4) written back into the TUN, so the source TCP adjusts its MSS automatically.
|
||||
|
||||
## 🔐 Security Model
|
||||
|
||||
The VPN uses a **Noise NK** handshake pattern:
|
||||
|
||||
1. **NK** = client does **N**ot authenticate, but **K**nows the server's static public key
|
||||
2. The client generates an ephemeral keypair, performs `e, es` (Diffie-Hellman with server's static key)
|
||||
3. Server responds with `e, ee` (Diffie-Hellman with both ephemeral keys)
|
||||
2. The client generates an ephemeral keypair, performs `e, es` (DH with server's static key)
|
||||
3. Server responds with `e, ee` (DH with both ephemeral keys)
|
||||
4. Result: forward-secret transport keys derived from both DH operations
|
||||
|
||||
Post-handshake, all IP packets are encrypted with **XChaCha20-Poly1305**:
|
||||
@@ -261,8 +352,8 @@ Inside the WebSocket tunnel, packets use a simple binary framing:
|
||||
| `HandshakeInit` | `0x01` | Client → Server handshake |
|
||||
| `HandshakeResp` | `0x02` | Server → Client handshake |
|
||||
| `IpPacket` | `0x10` | Encrypted IP packet |
|
||||
| `Keepalive` | `0x20` | App-level ping |
|
||||
| `KeepaliveAck` | `0x21` | App-level pong |
|
||||
| `Keepalive` | `0x20` | App-level ping (8-byte timestamp payload) |
|
||||
| `KeepaliveAck` | `0x21` | App-level pong (echoes timestamp for RTT) |
|
||||
| `SessionResume` | `0x30` | Resume a dropped session |
|
||||
| `SessionResumeOk` | `0x31` | Resume accepted |
|
||||
| `SessionResumeErr` | `0x32` | Resume rejected |
|
||||
@@ -270,8 +361,6 @@ Inside the WebSocket tunnel, packets use a simple binary framing:
|
||||
|
||||
## 🛠️ Rust Daemon CLI
|
||||
|
||||
The Rust binary supports several modes:
|
||||
|
||||
```bash
|
||||
# Development: stdio management (JSON lines on stdin/stdout)
|
||||
smartvpn_daemon --management --mode client
|
||||
@@ -290,16 +379,14 @@ smartvpn_daemon --generate-keypair
|
||||
# Install dependencies
|
||||
pnpm install
|
||||
|
||||
# Build TypeScript + cross-compile Rust
|
||||
# Build TypeScript + cross-compile Rust (amd64 + arm64)
|
||||
pnpm build
|
||||
|
||||
# Build Rust only (debug)
|
||||
cd rust && cargo build
|
||||
|
||||
# Run Rust tests
|
||||
# Run all tests (71 Rust + 32 TypeScript)
|
||||
cd rust && cargo test
|
||||
|
||||
# Run TypeScript tests
|
||||
pnpm test
|
||||
```
|
||||
|
||||
@@ -323,25 +410,27 @@ type TVpnTransportOptions =
|
||||
|
||||
// Client config
|
||||
interface IVpnClientConfig {
|
||||
serverUrl: string; // e.g. 'wss://vpn.example.com/tunnel'
|
||||
serverPublicKey: string; // base64-encoded Noise static key
|
||||
serverUrl: string;
|
||||
serverPublicKey: string;
|
||||
dns?: string[];
|
||||
mtu?: number; // default: 1420
|
||||
keepaliveIntervalSecs?: number; // default: 30
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
}
|
||||
|
||||
// Server config
|
||||
interface IVpnServerConfig {
|
||||
listenAddr: string; // e.g. '0.0.0.0:443'
|
||||
privateKey: string; // base64 Noise static private key
|
||||
publicKey: string; // base64 Noise static public key
|
||||
subnet: string; // e.g. '10.8.0.0/24'
|
||||
listenAddr: string;
|
||||
privateKey: string;
|
||||
publicKey: string;
|
||||
subnet: string;
|
||||
tlsCert?: string;
|
||||
tlsKey?: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
enableNat?: boolean;
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
defaultBurstBytes?: number;
|
||||
}
|
||||
|
||||
// Status
|
||||
@@ -365,6 +454,7 @@ interface IVpnStatistics {
|
||||
keepalivesSent: number;
|
||||
keepalivesReceived: number;
|
||||
uptimeSeconds: number;
|
||||
quality?: IVpnConnectionQuality;
|
||||
}
|
||||
|
||||
interface IVpnServerStatistics extends IVpnStatistics {
|
||||
@@ -372,12 +462,57 @@ interface IVpnServerStatistics extends IVpnStatistics {
|
||||
totalConnections: number;
|
||||
}
|
||||
|
||||
// Connection quality (QoS)
|
||||
type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
|
||||
|
||||
interface IVpnConnectionQuality {
|
||||
srttMs: number;
|
||||
jitterMs: number;
|
||||
minRttMs: number;
|
||||
maxRttMs: number;
|
||||
lossRatio: number;
|
||||
consecutiveTimeouts: number;
|
||||
linkHealth: TVpnLinkHealth;
|
||||
currentKeepaliveIntervalSecs: number;
|
||||
}
|
||||
|
||||
// MTU info
|
||||
interface IVpnMtuInfo {
|
||||
tunMtu: number;
|
||||
effectiveMtu: number;
|
||||
linkMtu: number;
|
||||
overheadBytes: number;
|
||||
oversizedPacketsDropped: number;
|
||||
icmpTooBigSent: number;
|
||||
}
|
||||
|
||||
// Client info (with QoS fields)
|
||||
interface IVpnClientInfo {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
connectedSince: string;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
// Per-client telemetry
|
||||
interface IVpnClientTelemetry {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
bytesReceived: number;
|
||||
bytesSent: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
interface IVpnKeypair {
|
||||
|
||||
2
rust/.cargo/config.toml
Normal file
2
rust/.cargo/config.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
linker = "aarch64-linux-gnu-gcc"
|
||||
@@ -3,13 +3,14 @@ use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
use tracing::{info, error, warn, debug};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
|
||||
use crate::telemetry::ConnectionQuality;
|
||||
use crate::transport;
|
||||
|
||||
/// Client configuration (matches TS IVpnClientConfig).
|
||||
@@ -65,6 +66,8 @@ pub struct VpnClient {
|
||||
assigned_ip: Arc<RwLock<Option<String>>>,
|
||||
shutdown_tx: Option<mpsc::Sender<()>>,
|
||||
connected_since: Arc<RwLock<Option<std::time::Instant>>>,
|
||||
quality_rx: Option<watch::Receiver<ConnectionQuality>>,
|
||||
link_health: Arc<RwLock<LinkHealth>>,
|
||||
}
|
||||
|
||||
impl VpnClient {
|
||||
@@ -75,6 +78,8 @@ impl VpnClient {
|
||||
assigned_ip: Arc::new(RwLock::new(None)),
|
||||
shutdown_tx: None,
|
||||
connected_since: Arc::new(RwLock::new(None)),
|
||||
quality_rx: None,
|
||||
link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,6 +98,7 @@ impl VpnClient {
|
||||
let stats = self.stats.clone();
|
||||
let assigned_ip_ref = self.assigned_ip.clone();
|
||||
let connected_since = self.connected_since.clone();
|
||||
let link_health = self.link_health.clone();
|
||||
|
||||
// Decode server public key
|
||||
let server_pub_key = base64::Engine::decode(
|
||||
@@ -161,6 +167,20 @@ impl VpnClient {
|
||||
|
||||
info!("Connected to VPN, assigned IP: {}", assigned_ip);
|
||||
|
||||
// Create adaptive keepalive monitor (use custom interval if configured)
|
||||
let ka_config = config.keepalive_interval_secs.map(|secs| {
|
||||
let mut cfg = keepalive::AdaptiveKeepaliveConfig::default();
|
||||
cfg.degraded_interval = std::time::Duration::from_secs(secs);
|
||||
cfg.healthy_interval = std::time::Duration::from_secs(secs * 2);
|
||||
cfg.critical_interval = std::time::Duration::from_secs((secs / 3).max(1));
|
||||
cfg
|
||||
});
|
||||
let (monitor, handle) = keepalive::create_keepalive(ka_config);
|
||||
self.quality_rx = Some(handle.quality_rx);
|
||||
|
||||
// Spawn the keepalive monitor
|
||||
tokio::spawn(monitor.run());
|
||||
|
||||
// Spawn packet forwarding loop
|
||||
let assigned_ip_clone = assigned_ip.clone();
|
||||
tokio::spawn(client_loop(
|
||||
@@ -170,7 +190,9 @@ impl VpnClient {
|
||||
state,
|
||||
stats,
|
||||
shutdown_rx,
|
||||
config.keepalive_interval_secs.unwrap_or(30),
|
||||
handle.signal_rx,
|
||||
handle.ack_tx,
|
||||
link_health,
|
||||
));
|
||||
|
||||
Ok(assigned_ip_clone)
|
||||
@@ -184,6 +206,7 @@ impl VpnClient {
|
||||
*self.assigned_ip.write().await = None;
|
||||
*self.connected_since.write().await = None;
|
||||
*self.state.write().await = ClientState::Disconnected;
|
||||
self.quality_rx = None;
|
||||
info!("Disconnected from VPN");
|
||||
Ok(())
|
||||
}
|
||||
@@ -208,13 +231,14 @@ impl VpnClient {
|
||||
status
|
||||
}
|
||||
|
||||
/// Get traffic statistics.
|
||||
/// Get traffic statistics (includes connection quality).
|
||||
pub async fn get_statistics(&self) -> serde_json::Value {
|
||||
let stats = self.stats.read().await;
|
||||
let since = self.connected_since.read().await;
|
||||
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
|
||||
let health = self.link_health.read().await;
|
||||
|
||||
serde_json::json!({
|
||||
let mut result = serde_json::json!({
|
||||
"bytesSent": stats.bytes_sent,
|
||||
"bytesReceived": stats.bytes_received,
|
||||
"packetsSent": stats.packets_sent,
|
||||
@@ -222,7 +246,35 @@ impl VpnClient {
|
||||
"keepalivesSent": stats.keepalives_sent,
|
||||
"keepalivesReceived": stats.keepalives_received,
|
||||
"uptimeSeconds": uptime,
|
||||
})
|
||||
});
|
||||
|
||||
// Include connection quality if available
|
||||
if let Some(ref rx) = self.quality_rx {
|
||||
let quality = rx.borrow().clone();
|
||||
result["quality"] = serde_json::json!({
|
||||
"srttMs": quality.srtt_ms,
|
||||
"jitterMs": quality.jitter_ms,
|
||||
"minRttMs": quality.min_rtt_ms,
|
||||
"maxRttMs": quality.max_rtt_ms,
|
||||
"lossRatio": quality.loss_ratio,
|
||||
"consecutiveTimeouts": quality.consecutive_timeouts,
|
||||
"linkHealth": format!("{}", *health),
|
||||
"keepalivesSent": quality.keepalives_sent,
|
||||
"keepalivesAcked": quality.keepalives_acked,
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Get connection quality snapshot.
|
||||
pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
|
||||
self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
|
||||
}
|
||||
|
||||
/// Get current link health.
|
||||
pub async fn get_link_health(&self) -> LinkHealth {
|
||||
*self.link_health.read().await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,11 +286,11 @@ async fn client_loop(
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
mut shutdown_rx: mpsc::Receiver<()>,
|
||||
keepalive_secs: u64,
|
||||
mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
link_health: Arc<RwLock<LinkHealth>>,
|
||||
) {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs));
|
||||
keepalive_ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -264,6 +316,8 @@ async fn client_loop(
|
||||
}
|
||||
PacketType::KeepaliveAck => {
|
||||
stats.write().await.keepalives_received += 1;
|
||||
// Signal the keepalive monitor that ACK was received
|
||||
let _ = ack_tx.send(()).await;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Server sent disconnect");
|
||||
@@ -290,19 +344,37 @@ async fn client_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = keepalive_ticker.tick() => {
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
signal = signal_rx.recv() => {
|
||||
match signal {
|
||||
Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
|
||||
// Embed the timestamp in the keepalive payload (8 bytes, big-endian)
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: timestamp_ms.to_be_bytes().to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
}
|
||||
}
|
||||
Some(KeepaliveSignal::PeerDead) => {
|
||||
warn!("Peer declared dead by keepalive monitor");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
Some(KeepaliveSignal::LinkHealthChanged(health)) => {
|
||||
debug!("Link health changed to: {}", health);
|
||||
*link_health.write().await = health;
|
||||
}
|
||||
None => {
|
||||
// Keepalive monitor channel closed
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
|
||||
@@ -1,87 +1,464 @@
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio::time::{interval, timeout};
|
||||
use tracing::{debug, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Default keepalive interval (30 seconds).
|
||||
use crate::telemetry::{ConnectionQuality, RttTracker};
|
||||
|
||||
/// Default keepalive interval (30 seconds — used for Degraded state).
|
||||
pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
/// Default keepalive ACK timeout (10 seconds).
|
||||
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// Default keepalive ACK timeout (5 seconds).
|
||||
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// Link health states for adaptive keepalive.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LinkHealth {
|
||||
/// RTT stable, jitter low, no loss. Interval: 60s.
|
||||
Healthy,
|
||||
/// Elevated jitter or occasional loss. Interval: 30s.
|
||||
Degraded,
|
||||
/// High loss or sustained jitter spike. Interval: 10s.
|
||||
Critical,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LinkHealth {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Healthy => write!(f, "healthy"),
|
||||
Self::Degraded => write!(f, "degraded"),
|
||||
Self::Critical => write!(f, "critical"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the adaptive keepalive state machine.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdaptiveKeepaliveConfig {
|
||||
/// Interval when link health is Healthy.
|
||||
pub healthy_interval: Duration,
|
||||
/// Interval when link health is Degraded.
|
||||
pub degraded_interval: Duration,
|
||||
/// Interval when link health is Critical.
|
||||
pub critical_interval: Duration,
|
||||
/// ACK timeout (how long to wait for ACK before declaring timeout).
|
||||
pub ack_timeout: Duration,
|
||||
/// Jitter threshold (ms) to enter Degraded from Healthy.
|
||||
pub jitter_degraded_ms: f64,
|
||||
/// Jitter threshold (ms) to return to Healthy from Degraded.
|
||||
pub jitter_healthy_ms: f64,
|
||||
/// Loss ratio threshold to enter Degraded.
|
||||
pub loss_degraded: f64,
|
||||
/// Loss ratio threshold to enter Critical.
|
||||
pub loss_critical: f64,
|
||||
/// Loss ratio threshold to return from Critical to Degraded.
|
||||
pub loss_recover: f64,
|
||||
/// Loss ratio threshold to return from Degraded to Healthy.
|
||||
pub loss_healthy: f64,
|
||||
/// Consecutive checks required for upward state transitions (hysteresis).
|
||||
pub upgrade_checks: u32,
|
||||
/// Consecutive timeouts to declare peer dead in Critical state.
|
||||
pub dead_peer_timeouts: u32,
|
||||
}
|
||||
|
||||
impl Default for AdaptiveKeepaliveConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
healthy_interval: Duration::from_secs(60),
|
||||
degraded_interval: Duration::from_secs(30),
|
||||
critical_interval: Duration::from_secs(10),
|
||||
ack_timeout: Duration::from_secs(5),
|
||||
jitter_degraded_ms: 50.0,
|
||||
jitter_healthy_ms: 30.0,
|
||||
loss_degraded: 0.05,
|
||||
loss_critical: 0.20,
|
||||
loss_recover: 0.10,
|
||||
loss_healthy: 0.02,
|
||||
upgrade_checks: 3,
|
||||
dead_peer_timeouts: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Signals from the keepalive monitor.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeepaliveSignal {
|
||||
/// Time to send a keepalive ping.
|
||||
SendPing,
|
||||
/// Peer is considered dead (no ACK received within timeout).
|
||||
/// Time to send a keepalive ping. Contains the timestamp (ms since epoch) to embed in payload.
|
||||
SendPing(u64),
|
||||
/// Peer is considered dead (no ACK received within timeout repeatedly).
|
||||
PeerDead,
|
||||
/// Link health state changed.
|
||||
LinkHealthChanged(LinkHealth),
|
||||
}
|
||||
|
||||
/// A keepalive monitor that emits signals on a channel.
|
||||
/// A keepalive monitor with adaptive interval and RTT tracking.
|
||||
pub struct KeepaliveMonitor {
|
||||
interval: Duration,
|
||||
timeout_duration: Duration,
|
||||
config: AdaptiveKeepaliveConfig,
|
||||
health: LinkHealth,
|
||||
rtt_tracker: RttTracker,
|
||||
signal_tx: mpsc::Sender<KeepaliveSignal>,
|
||||
ack_rx: mpsc::Receiver<()>,
|
||||
quality_tx: watch::Sender<ConnectionQuality>,
|
||||
consecutive_upgrade_checks: u32,
|
||||
}
|
||||
|
||||
/// Handle returned to the caller to send ACKs and receive signals.
|
||||
pub struct KeepaliveHandle {
|
||||
pub signal_rx: mpsc::Receiver<KeepaliveSignal>,
|
||||
pub ack_tx: mpsc::Sender<()>,
|
||||
pub quality_rx: watch::Receiver<ConnectionQuality>,
|
||||
}
|
||||
|
||||
/// Create a keepalive monitor and its handle.
|
||||
/// Create an adaptive keepalive monitor and its handle.
|
||||
pub fn create_keepalive(
|
||||
keepalive_interval: Option<Duration>,
|
||||
keepalive_timeout: Option<Duration>,
|
||||
config: Option<AdaptiveKeepaliveConfig>,
|
||||
) -> (KeepaliveMonitor, KeepaliveHandle) {
|
||||
let config = config.unwrap_or_default();
|
||||
let (signal_tx, signal_rx) = mpsc::channel(8);
|
||||
let (ack_tx, ack_rx) = mpsc::channel(8);
|
||||
let (quality_tx, quality_rx) = watch::channel(ConnectionQuality::default());
|
||||
|
||||
let monitor = KeepaliveMonitor {
|
||||
interval: keepalive_interval.unwrap_or(DEFAULT_KEEPALIVE_INTERVAL),
|
||||
timeout_duration: keepalive_timeout.unwrap_or(DEFAULT_KEEPALIVE_TIMEOUT),
|
||||
config,
|
||||
health: LinkHealth::Degraded, // start in Degraded, earn Healthy
|
||||
rtt_tracker: RttTracker::new(30),
|
||||
signal_tx,
|
||||
ack_rx,
|
||||
quality_tx,
|
||||
consecutive_upgrade_checks: 0,
|
||||
};
|
||||
|
||||
let handle = KeepaliveHandle { signal_rx, ack_tx };
|
||||
let handle = KeepaliveHandle {
|
||||
signal_rx,
|
||||
ack_tx,
|
||||
quality_rx,
|
||||
};
|
||||
|
||||
(monitor, handle)
|
||||
}
|
||||
|
||||
impl KeepaliveMonitor {
|
||||
fn current_interval(&self) -> Duration {
|
||||
match self.health {
|
||||
LinkHealth::Healthy => self.config.healthy_interval,
|
||||
LinkHealth::Degraded => self.config.degraded_interval,
|
||||
LinkHealth::Critical => self.config.critical_interval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the keepalive loop. Blocks until the peer is dead or channels close.
|
||||
pub async fn run(mut self) {
|
||||
let mut ticker = interval(self.interval);
|
||||
let mut ticker = interval(self.current_interval());
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
debug!("Sending keepalive ping signal");
|
||||
|
||||
if self.signal_tx.send(KeepaliveSignal::SendPing).await.is_err() {
|
||||
// Channel closed
|
||||
break;
|
||||
// Record ping sent, get timestamp for payload
|
||||
let timestamp_ms = self.rtt_tracker.mark_ping_sent();
|
||||
debug!("Sending keepalive ping (ts={})", timestamp_ms);
|
||||
|
||||
if self
|
||||
.signal_tx
|
||||
.send(KeepaliveSignal::SendPing(timestamp_ms))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break; // channel closed
|
||||
}
|
||||
|
||||
// Wait for ACK within timeout
|
||||
match timeout(self.timeout_duration, self.ack_rx.recv()).await {
|
||||
match timeout(self.config.ack_timeout, self.ack_rx.recv()).await {
|
||||
Ok(Some(())) => {
|
||||
debug!("Keepalive ACK received");
|
||||
if let Some(rtt) = self.rtt_tracker.record_ack(timestamp_ms) {
|
||||
debug!("Keepalive ACK received, RTT: {:?}", rtt);
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Channel closed
|
||||
break;
|
||||
break; // channel closed
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Keepalive ACK timeout — peer considered dead");
|
||||
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
|
||||
break;
|
||||
self.rtt_tracker.record_timeout();
|
||||
warn!(
|
||||
"Keepalive ACK timeout (consecutive: {})",
|
||||
self.rtt_tracker.consecutive_timeouts
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Publish quality snapshot
|
||||
let quality = self.rtt_tracker.snapshot();
|
||||
let _ = self.quality_tx.send(quality.clone());
|
||||
|
||||
// Evaluate state transition
|
||||
let new_health = self.evaluate_health(&quality);
|
||||
|
||||
if new_health != self.health {
|
||||
info!("Link health: {} -> {}", self.health, new_health);
|
||||
self.health = new_health;
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
|
||||
// Reset ticker to new interval
|
||||
ticker = interval(self.current_interval());
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
let _ = self
|
||||
.signal_tx
|
||||
.send(KeepaliveSignal::LinkHealthChanged(new_health))
|
||||
.await;
|
||||
}
|
||||
|
||||
// Check for dead peer in Critical state
|
||||
if self.health == LinkHealth::Critical
|
||||
&& self.rtt_tracker.consecutive_timeouts >= self.config.dead_peer_timeouts
|
||||
{
|
||||
warn!("Peer considered dead after {} consecutive timeouts in Critical state",
|
||||
self.rtt_tracker.consecutive_timeouts);
|
||||
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_health(&mut self, quality: &ConnectionQuality) -> LinkHealth {
|
||||
match self.health {
|
||||
LinkHealth::Healthy => {
|
||||
// Downgrade conditions
|
||||
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Critical;
|
||||
}
|
||||
if quality.jitter_ms > self.config.jitter_degraded_ms
|
||||
|| quality.loss_ratio > self.config.loss_degraded
|
||||
|| quality.consecutive_timeouts >= 1
|
||||
{
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Degraded;
|
||||
}
|
||||
LinkHealth::Healthy
|
||||
}
|
||||
LinkHealth::Degraded => {
|
||||
// Downgrade to Critical
|
||||
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Critical;
|
||||
}
|
||||
// Upgrade to Healthy (with hysteresis)
|
||||
if quality.jitter_ms < self.config.jitter_healthy_ms
|
||||
&& quality.loss_ratio < self.config.loss_healthy
|
||||
&& quality.consecutive_timeouts == 0
|
||||
{
|
||||
self.consecutive_upgrade_checks += 1;
|
||||
if self.consecutive_upgrade_checks >= self.config.upgrade_checks {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Healthy;
|
||||
}
|
||||
} else {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
}
|
||||
LinkHealth::Degraded
|
||||
}
|
||||
LinkHealth::Critical => {
|
||||
// Upgrade to Degraded (with hysteresis), never directly to Healthy
|
||||
if quality.loss_ratio < self.config.loss_recover
|
||||
&& quality.consecutive_timeouts == 0
|
||||
{
|
||||
self.consecutive_upgrade_checks += 1;
|
||||
if self.consecutive_upgrade_checks >= 2 {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Degraded;
|
||||
}
|
||||
} else {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
}
|
||||
LinkHealth::Critical
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let config = AdaptiveKeepaliveConfig::default();
|
||||
assert_eq!(config.healthy_interval, Duration::from_secs(60));
|
||||
assert_eq!(config.degraded_interval, Duration::from_secs(30));
|
||||
assert_eq!(config.critical_interval, Duration::from_secs(10));
|
||||
assert_eq!(config.ack_timeout, Duration::from_secs(5));
|
||||
assert_eq!(config.dead_peer_timeouts, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn link_health_display() {
|
||||
assert_eq!(format!("{}", LinkHealth::Healthy), "healthy");
|
||||
assert_eq!(format!("{}", LinkHealth::Degraded), "degraded");
|
||||
assert_eq!(format!("{}", LinkHealth::Critical), "critical");
|
||||
}
|
||||
|
||||
// Helper to create a monitor for unit-testing evaluate_health
|
||||
fn make_test_monitor() -> KeepaliveMonitor {
|
||||
let (signal_tx, _signal_rx) = mpsc::channel(8);
|
||||
let (_ack_tx, ack_rx) = mpsc::channel(8);
|
||||
let (quality_tx, _quality_rx) = watch::channel(ConnectionQuality::default());
|
||||
|
||||
KeepaliveMonitor {
|
||||
config: AdaptiveKeepaliveConfig::default(),
|
||||
health: LinkHealth::Degraded,
|
||||
rtt_tracker: RttTracker::new(30),
|
||||
signal_tx,
|
||||
ack_rx,
|
||||
quality_tx,
|
||||
consecutive_upgrade_checks: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_degraded_on_jitter() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
jitter_ms: 60.0, // > 50ms threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_degraded_on_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.06, // > 5% threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_critical_on_high_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.25, // > 20% threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_critical_on_consecutive_timeouts() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
consecutive_timeouts: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_healthy_requires_hysteresis() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let good_quality = ConnectionQuality {
|
||||
jitter_ms: 10.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
srtt_ms: 20.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Should require 3 consecutive good checks (default upgrade_checks)
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 1);
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 2);
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Healthy);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_healthy_resets_on_bad_check() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let good = ConnectionQuality {
|
||||
jitter_ms: 10.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let bad = ConnectionQuality {
|
||||
jitter_ms: 60.0, // too high
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
m.evaluate_health(&good); // 1 check
|
||||
m.evaluate_health(&good); // 2 checks
|
||||
m.evaluate_health(&bad); // resets
|
||||
assert_eq!(m.consecutive_upgrade_checks, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn critical_to_degraded_requires_hysteresis() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Critical;
|
||||
let recovering = ConnectionQuality {
|
||||
loss_ratio: 0.05, // < 10% recover threshold
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Critical);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 1);
|
||||
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn critical_never_directly_to_healthy() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Critical;
|
||||
let perfect = ConnectionQuality {
|
||||
jitter_ms: 1.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
srtt_ms: 10.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Even with perfect quality, must go through Degraded first
|
||||
m.evaluate_health(&perfect); // 1
|
||||
let result = m.evaluate_health(&perfect); // 2 → Degraded
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
// Not Healthy yet
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_critical_on_high_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.25,
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(m.evaluate_health(&q), LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interval_matches_health() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(60));
|
||||
m.health = LinkHealth::Degraded;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(30));
|
||||
m.health = LinkHealth::Critical;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(10));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,3 +11,7 @@ pub mod network;
|
||||
pub mod server;
|
||||
pub mod client;
|
||||
pub mod reconnect;
|
||||
pub mod telemetry;
|
||||
pub mod ratelimit;
|
||||
pub mod qos;
|
||||
pub mod mtu;
|
||||
|
||||
@@ -285,6 +285,39 @@ async fn handle_client_request(
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
}
|
||||
"getConnectionQuality" => {
|
||||
match vpn_client.get_connection_quality() {
|
||||
Some(quality) => {
|
||||
let health = vpn_client.get_link_health().await;
|
||||
let interval_secs = match health {
|
||||
crate::keepalive::LinkHealth::Healthy => 60,
|
||||
crate::keepalive::LinkHealth::Degraded => 30,
|
||||
crate::keepalive::LinkHealth::Critical => 10,
|
||||
};
|
||||
ManagementResponse::ok(id, serde_json::json!({
|
||||
"srttMs": quality.srtt_ms,
|
||||
"jitterMs": quality.jitter_ms,
|
||||
"minRttMs": quality.min_rtt_ms,
|
||||
"maxRttMs": quality.max_rtt_ms,
|
||||
"lossRatio": quality.loss_ratio,
|
||||
"consecutiveTimeouts": quality.consecutive_timeouts,
|
||||
"linkHealth": format!("{}", health),
|
||||
"currentKeepaliveIntervalSecs": interval_secs,
|
||||
}))
|
||||
}
|
||||
None => ManagementResponse::ok(id, serde_json::json!(null)),
|
||||
}
|
||||
}
|
||||
"getMtuInfo" => {
|
||||
ManagementResponse::ok(id, serde_json::json!({
|
||||
"tunMtu": 1420,
|
||||
"effectiveMtu": crate::mtu::TunnelOverhead::default_overhead().effective_tun_mtu(1500),
|
||||
"linkMtu": 1500,
|
||||
"overheadBytes": crate::mtu::TunnelOverhead::default_overhead().total(),
|
||||
"oversizedPacketsDropped": 0,
|
||||
"icmpTooBigSent": 0,
|
||||
}))
|
||||
}
|
||||
_ => ManagementResponse::err(id, format!("Unknown client method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
@@ -349,6 +382,50 @@ async fn handle_server_request(
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"setClientRateLimit" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let rate = match request.params.get("rateBytesPerSec").and_then(|v| v.as_u64()) {
|
||||
Some(r) => r,
|
||||
None => return ManagementResponse::err(id, "Missing rateBytesPerSec".to_string()),
|
||||
};
|
||||
let burst = match request.params.get("burstBytes").and_then(|v| v.as_u64()) {
|
||||
Some(b) => b,
|
||||
None => return ManagementResponse::err(id, "Missing burstBytes".to_string()),
|
||||
};
|
||||
match vpn_server.set_client_rate_limit(&client_id, rate, burst).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClientRateLimit" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_client_rate_limit(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClientTelemetry" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match clients.into_iter().find(|c| c.client_id == client_id) {
|
||||
Some(info) => {
|
||||
match serde_json::to_value(&info) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id, format!("Client {} not found", client_id)),
|
||||
}
|
||||
}
|
||||
"generateKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
|
||||
314
rust/src/mtu.rs
Normal file
314
rust/src/mtu.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
/// Overhead breakdown for VPN tunnel encapsulation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TunnelOverhead {
|
||||
/// Outer IP header: 20 bytes (IPv4, no options).
|
||||
pub ip_header: u16,
|
||||
/// TCP header: typically 32 bytes (20 base + 12 for timestamps).
|
||||
pub tcp_header: u16,
|
||||
/// WebSocket framing: ~6 bytes (2 base + 4 mask from client).
|
||||
pub ws_framing: u16,
|
||||
/// VPN binary frame header: 5 bytes [type:1B][length:4B].
|
||||
pub vpn_header: u16,
|
||||
/// Noise AEAD tag: 16 bytes (Poly1305).
|
||||
pub noise_tag: u16,
|
||||
}
|
||||
|
||||
impl TunnelOverhead {
|
||||
/// Conservative default overhead estimate.
|
||||
pub fn default_overhead() -> Self {
|
||||
Self {
|
||||
ip_header: 20,
|
||||
tcp_header: 32,
|
||||
ws_framing: 6,
|
||||
vpn_header: 5,
|
||||
noise_tag: 16,
|
||||
}
|
||||
}
|
||||
|
||||
/// Total encapsulation overhead in bytes.
|
||||
pub fn total(&self) -> u16 {
|
||||
self.ip_header + self.tcp_header + self.ws_framing + self.vpn_header + self.noise_tag
|
||||
}
|
||||
|
||||
/// Compute effective TUN MTU given the underlying link MTU.
|
||||
pub fn effective_tun_mtu(&self, link_mtu: u16) -> u16 {
|
||||
link_mtu.saturating_sub(self.total())
|
||||
}
|
||||
}
|
||||
|
||||
/// MTU configuration for the VPN tunnel.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MtuConfig {
|
||||
/// Underlying link MTU (typically 1500 for Ethernet).
|
||||
pub link_mtu: u16,
|
||||
/// Computed effective TUN MTU.
|
||||
pub effective_mtu: u16,
|
||||
/// Whether to generate ICMP too-big for oversized packets.
|
||||
pub send_icmp_too_big: bool,
|
||||
/// Counter: oversized packets encountered.
|
||||
pub oversized_packets: u64,
|
||||
/// Counter: ICMP too-big messages generated.
|
||||
pub icmp_too_big_sent: u64,
|
||||
}
|
||||
|
||||
impl MtuConfig {
|
||||
/// Create a new MTU config from the underlying link MTU.
|
||||
pub fn new(link_mtu: u16) -> Self {
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let effective = overhead.effective_tun_mtu(link_mtu);
|
||||
Self {
|
||||
link_mtu,
|
||||
effective_mtu: effective,
|
||||
send_icmp_too_big: true,
|
||||
oversized_packets: 0,
|
||||
icmp_too_big_sent: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a packet exceeds the effective MTU.
|
||||
pub fn is_oversized(&self, packet_len: usize) -> bool {
|
||||
packet_len > self.effective_mtu as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Action to take after checking MTU.
|
||||
pub enum MtuAction {
|
||||
/// Packet is within MTU, forward normally.
|
||||
Forward,
|
||||
/// Packet is oversized; contains the ICMP too-big message to write back into TUN.
|
||||
SendIcmpTooBig(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Check packet against MTU config and return the appropriate action.
|
||||
pub fn check_mtu(packet: &[u8], config: &MtuConfig) -> MtuAction {
|
||||
if !config.is_oversized(packet.len()) {
|
||||
return MtuAction::Forward;
|
||||
}
|
||||
|
||||
if !config.send_icmp_too_big {
|
||||
return MtuAction::Forward;
|
||||
}
|
||||
|
||||
match generate_icmp_too_big(packet, config.effective_mtu) {
|
||||
Some(icmp) => MtuAction::SendIcmpTooBig(icmp),
|
||||
None => MtuAction::Forward,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate an ICMPv4 Destination Unreachable / Fragmentation Needed message.
|
||||
///
|
||||
/// Per RFC 792: Type 3, Code 4, with next-hop MTU in bytes 6-7 (RFC 1191).
|
||||
/// Returns the complete IP + ICMP packet to write back into the TUN device.
|
||||
pub fn generate_icmp_too_big(original_packet: &[u8], next_hop_mtu: u16) -> Option<Vec<u8>> {
|
||||
// Need at least 20 bytes of original IP header
|
||||
if original_packet.len() < 20 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Verify it's IPv4
|
||||
if original_packet[0] >> 4 != 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Parse source/dest from original IP header
|
||||
let src_ip = Ipv4Addr::new(
|
||||
original_packet[12],
|
||||
original_packet[13],
|
||||
original_packet[14],
|
||||
original_packet[15],
|
||||
);
|
||||
let dst_ip = Ipv4Addr::new(
|
||||
original_packet[16],
|
||||
original_packet[17],
|
||||
original_packet[18],
|
||||
original_packet[19],
|
||||
);
|
||||
|
||||
// ICMP payload: IP header + first 8 bytes of original datagram (per RFC 792)
|
||||
let icmp_data_len = original_packet.len().min(28); // 20 IP header + 8 bytes
|
||||
let icmp_payload = &original_packet[..icmp_data_len];
|
||||
|
||||
// Build ICMP message: type(1) + code(1) + checksum(2) + unused(2) + next_hop_mtu(2) + data
|
||||
let mut icmp = Vec::with_capacity(8 + icmp_data_len);
|
||||
icmp.push(3); // Type: Destination Unreachable
|
||||
icmp.push(4); // Code: Fragmentation Needed and DF was Set
|
||||
icmp.push(0); // Checksum placeholder
|
||||
icmp.push(0);
|
||||
icmp.push(0); // Unused
|
||||
icmp.push(0);
|
||||
icmp.extend_from_slice(&next_hop_mtu.to_be_bytes());
|
||||
icmp.extend_from_slice(icmp_payload);
|
||||
|
||||
// Compute ICMP checksum
|
||||
let cksum = internet_checksum(&icmp);
|
||||
icmp[2] = (cksum >> 8) as u8;
|
||||
icmp[3] = (cksum & 0xff) as u8;
|
||||
|
||||
// Build IP header (ICMP response: FROM tunnel gateway TO original source)
|
||||
let total_len = (20 + icmp.len()) as u16;
|
||||
let mut ip = Vec::with_capacity(total_len as usize);
|
||||
ip.push(0x45); // Version 4, IHL 5
|
||||
ip.push(0x00); // DSCP/ECN
|
||||
ip.extend_from_slice(&total_len.to_be_bytes());
|
||||
ip.extend_from_slice(&[0, 0]); // Identification
|
||||
ip.extend_from_slice(&[0x40, 0x00]); // Flags: Don't Fragment, Fragment Offset: 0
|
||||
ip.push(64); // TTL
|
||||
ip.push(1); // Protocol: ICMP
|
||||
ip.extend_from_slice(&[0, 0]); // Header checksum placeholder
|
||||
ip.extend_from_slice(&dst_ip.octets()); // Source: tunnel endpoint (was dst)
|
||||
ip.extend_from_slice(&src_ip.octets()); // Destination: original source
|
||||
|
||||
// Compute IP header checksum
|
||||
let ip_cksum = internet_checksum(&ip[..20]);
|
||||
ip[10] = (ip_cksum >> 8) as u8;
|
||||
ip[11] = (ip_cksum & 0xff) as u8;
|
||||
|
||||
ip.extend_from_slice(&icmp);
|
||||
Some(ip)
|
||||
}
|
||||
|
||||
/// Standard Internet checksum (RFC 1071).
|
||||
fn internet_checksum(data: &[u8]) -> u16 {
|
||||
let mut sum: u32 = 0;
|
||||
let mut i = 0;
|
||||
while i + 1 < data.len() {
|
||||
sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
|
||||
i += 2;
|
||||
}
|
||||
if i < data.len() {
|
||||
sum += (data[i] as u32) << 8;
|
||||
}
|
||||
while sum >> 16 != 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16);
|
||||
}
|
||||
!sum as u16
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_overhead_total() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
assert_eq!(oh.total(), 79); // 20+32+6+5+16
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_mtu_for_ethernet() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
let mtu = oh.effective_tun_mtu(1500);
|
||||
assert_eq!(mtu, 1421); // 1500 - 79
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_mtu_saturates_at_zero() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
let mtu = oh.effective_tun_mtu(50); // Less than overhead
|
||||
assert_eq!(mtu, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mtu_config_default() {
|
||||
let config = MtuConfig::new(1500);
|
||||
assert_eq!(config.effective_mtu, 1421);
|
||||
assert_eq!(config.link_mtu, 1500);
|
||||
assert!(config.send_icmp_too_big);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_oversized() {
|
||||
let config = MtuConfig::new(1500);
|
||||
assert!(!config.is_oversized(1421));
|
||||
assert!(config.is_oversized(1422));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_generation() {
|
||||
// Craft a minimal IPv4 packet
|
||||
let mut original = vec![0u8; 28];
|
||||
original[0] = 0x45; // version 4, IHL 5
|
||||
original[2..4].copy_from_slice(&1500u16.to_be_bytes()); // total length
|
||||
original[9] = 6; // TCP
|
||||
original[12..16].copy_from_slice(&[10, 0, 0, 1]); // src IP
|
||||
original[16..20].copy_from_slice(&[10, 0, 0, 2]); // dst IP
|
||||
|
||||
let icmp_pkt = generate_icmp_too_big(&original, 1421).unwrap();
|
||||
|
||||
// Verify it's a valid IPv4 packet
|
||||
assert_eq!(icmp_pkt[0] >> 4, 4); // IPv4
|
||||
assert_eq!(icmp_pkt[9], 1); // ICMP protocol
|
||||
|
||||
// Source should be original dst (10.0.0.2)
|
||||
assert_eq!(&icmp_pkt[12..16], &[10, 0, 0, 2]);
|
||||
// Destination should be original src (10.0.0.1)
|
||||
assert_eq!(&icmp_pkt[16..20], &[10, 0, 0, 1]);
|
||||
|
||||
// ICMP type 3, code 4
|
||||
assert_eq!(icmp_pkt[20], 3);
|
||||
assert_eq!(icmp_pkt[21], 4);
|
||||
|
||||
// Next-hop MTU at ICMP bytes 6-7 (offset 26-27 in IP packet)
|
||||
let mtu = u16::from_be_bytes([icmp_pkt[26], icmp_pkt[27]]);
|
||||
assert_eq!(mtu, 1421);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_rejects_short_packet() {
|
||||
let short = vec![0u8; 10];
|
||||
assert!(generate_icmp_too_big(&short, 1421).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_rejects_non_ipv4() {
|
||||
let mut pkt = vec![0u8; 40];
|
||||
pkt[0] = 0x60; // IPv6
|
||||
assert!(generate_icmp_too_big(&pkt, 1421).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_checksum_valid() {
|
||||
let mut original = vec![0u8; 28];
|
||||
original[0] = 0x45;
|
||||
original[2..4].copy_from_slice(&1500u16.to_be_bytes());
|
||||
original[9] = 6;
|
||||
original[12..16].copy_from_slice(&[192, 168, 1, 100]);
|
||||
original[16..20].copy_from_slice(&[10, 8, 0, 1]);
|
||||
|
||||
let icmp_pkt = generate_icmp_too_big(&original, 1420).unwrap();
|
||||
|
||||
// Verify IP header checksum
|
||||
let ip_cksum = internet_checksum(&icmp_pkt[..20]);
|
||||
assert_eq!(ip_cksum, 0, "IP header checksum should verify to 0");
|
||||
|
||||
// Verify ICMP checksum
|
||||
let icmp_cksum = internet_checksum(&icmp_pkt[20..]);
|
||||
assert_eq!(icmp_cksum, 0, "ICMP checksum should verify to 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_mtu_forward() {
|
||||
let config = MtuConfig::new(1500);
|
||||
let pkt = vec![0u8; 1421]; // Exactly at MTU
|
||||
assert!(matches!(check_mtu(&pkt, &config), MtuAction::Forward));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_mtu_oversized_generates_icmp() {
|
||||
let config = MtuConfig::new(1500);
|
||||
let mut pkt = vec![0u8; 1500];
|
||||
pkt[0] = 0x45; // Valid IPv4
|
||||
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
|
||||
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
|
||||
|
||||
match check_mtu(&pkt, &config) {
|
||||
MtuAction::SendIcmpTooBig(icmp) => {
|
||||
assert_eq!(icmp[20], 3); // ICMP type
|
||||
assert_eq!(icmp[21], 4); // ICMP code
|
||||
}
|
||||
MtuAction::Forward => panic!("Expected SendIcmpTooBig"),
|
||||
}
|
||||
}
|
||||
}
|
||||
490
rust/src/qos.rs
Normal file
490
rust/src/qos.rs
Normal file
@@ -0,0 +1,490 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Priority levels for IP packets.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[repr(u8)]
|
||||
pub enum Priority {
|
||||
High = 0,
|
||||
Normal = 1,
|
||||
Low = 2,
|
||||
}
|
||||
|
||||
/// QoS statistics per priority level.
|
||||
pub struct QosStats {
|
||||
pub high_enqueued: AtomicU64,
|
||||
pub normal_enqueued: AtomicU64,
|
||||
pub low_enqueued: AtomicU64,
|
||||
pub high_dropped: AtomicU64,
|
||||
pub normal_dropped: AtomicU64,
|
||||
pub low_dropped: AtomicU64,
|
||||
}
|
||||
|
||||
impl QosStats {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
high_enqueued: AtomicU64::new(0),
|
||||
normal_enqueued: AtomicU64::new(0),
|
||||
low_enqueued: AtomicU64::new(0),
|
||||
high_dropped: AtomicU64::new(0),
|
||||
normal_dropped: AtomicU64::new(0),
|
||||
low_dropped: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QosStats {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Packet classification
|
||||
// ============================================================================
|
||||
|
||||
/// 5-tuple flow key for tracking bulk flows.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct FlowKey {
|
||||
src_ip: u32,
|
||||
dst_ip: u32,
|
||||
src_port: u16,
|
||||
dst_port: u16,
|
||||
protocol: u8,
|
||||
}
|
||||
|
||||
/// Per-flow state for bulk detection.
|
||||
struct FlowState {
|
||||
bytes_total: u64,
|
||||
window_start: Instant,
|
||||
}
|
||||
|
||||
/// Tracks per-flow byte counts for bulk flow detection.
|
||||
struct FlowTracker {
|
||||
flows: HashMap<FlowKey, FlowState>,
|
||||
window_duration: Duration,
|
||||
max_flows: usize,
|
||||
}
|
||||
|
||||
impl FlowTracker {
|
||||
fn new(window_duration: Duration, max_flows: usize) -> Self {
|
||||
Self {
|
||||
flows: HashMap::new(),
|
||||
window_duration,
|
||||
max_flows,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes for a flow. Returns true if the flow exceeds the threshold.
|
||||
fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
// Evict if at capacity — remove oldest entry
|
||||
if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
|
||||
if let Some(oldest_key) = self
|
||||
.flows
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.window_start)
|
||||
.map(|(k, _)| *k)
|
||||
{
|
||||
self.flows.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
|
||||
let state = self.flows.entry(key).or_insert(FlowState {
|
||||
bytes_total: 0,
|
||||
window_start: now,
|
||||
});
|
||||
|
||||
// Reset window if expired
|
||||
if now.duration_since(state.window_start) > self.window_duration {
|
||||
state.bytes_total = 0;
|
||||
state.window_start = now;
|
||||
}
|
||||
|
||||
state.bytes_total += bytes;
|
||||
state.bytes_total > threshold
|
||||
}
|
||||
}
|
||||
|
||||
/// Classifies raw IP packets into priority levels.
|
||||
pub struct PacketClassifier {
|
||||
flow_tracker: FlowTracker,
|
||||
/// Byte threshold for classifying a flow as "bulk" (Low priority).
|
||||
bulk_threshold_bytes: u64,
|
||||
}
|
||||
|
||||
impl PacketClassifier {
|
||||
/// Create a new classifier.
|
||||
///
|
||||
/// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB)
|
||||
pub fn new(bulk_threshold_bytes: u64) -> Self {
|
||||
Self {
|
||||
flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000),
|
||||
bulk_threshold_bytes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify a raw IPv4 packet into a priority level.
|
||||
///
|
||||
/// The packet must start with the IPv4 header (as read from a TUN device).
|
||||
pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
|
||||
// Need at least 20 bytes for a minimal IPv4 header
|
||||
if ip_packet.len() < 20 {
|
||||
return Priority::Normal;
|
||||
}
|
||||
|
||||
let version = ip_packet[0] >> 4;
|
||||
if version != 4 {
|
||||
return Priority::Normal; // Only classify IPv4 for now
|
||||
}
|
||||
|
||||
let ihl = (ip_packet[0] & 0x0F) as usize;
|
||||
let header_len = ihl * 4;
|
||||
let protocol = ip_packet[9];
|
||||
let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize;
|
||||
|
||||
// ICMP is always high priority
|
||||
if protocol == 1 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Small packets (<128 bytes) are high priority (likely interactive)
|
||||
if total_len < 128 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Extract ports for TCP (6) and UDP (17)
|
||||
let (src_port, dst_port) = if (protocol == 6 || protocol == 17)
|
||||
&& ip_packet.len() >= header_len + 4
|
||||
{
|
||||
let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]);
|
||||
let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]);
|
||||
(sp, dp)
|
||||
} else {
|
||||
(0, 0)
|
||||
};
|
||||
|
||||
// DNS (port 53) and SSH (port 22) are high priority
|
||||
if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Check for bulk flow
|
||||
if protocol == 6 || protocol == 17 {
|
||||
let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]);
|
||||
let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]);
|
||||
|
||||
let key = FlowKey {
|
||||
src_ip,
|
||||
dst_ip,
|
||||
src_port,
|
||||
dst_port,
|
||||
protocol,
|
||||
};
|
||||
|
||||
if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) {
|
||||
return Priority::Low;
|
||||
}
|
||||
}
|
||||
|
||||
Priority::Normal
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Priority channel set
|
||||
// ============================================================================
|
||||
|
||||
/// Error returned when a packet is dropped.
|
||||
#[derive(Debug)]
|
||||
pub enum PacketDropped {
|
||||
LowPriorityDrop,
|
||||
NormalPriorityDrop,
|
||||
HighPriorityDrop,
|
||||
ChannelClosed,
|
||||
}
|
||||
|
||||
/// Sending half of the priority channel set.
|
||||
pub struct PrioritySender {
|
||||
high_tx: mpsc::Sender<Vec<u8>>,
|
||||
normal_tx: mpsc::Sender<Vec<u8>>,
|
||||
low_tx: mpsc::Sender<Vec<u8>>,
|
||||
stats: Arc<QosStats>,
|
||||
}
|
||||
|
||||
impl PrioritySender {
|
||||
/// Send a packet with the given priority. Implements smart dropping under backpressure.
|
||||
pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
|
||||
let (tx, enqueued_counter) = match priority {
|
||||
Priority::High => (&self.high_tx, &self.stats.high_enqueued),
|
||||
Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
|
||||
Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
|
||||
};
|
||||
|
||||
match tx.try_send(packet) {
|
||||
Ok(()) => {
|
||||
enqueued_counter.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(packet)) => {
|
||||
self.handle_backpressure(packet, priority).await
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_backpressure(
|
||||
&self,
|
||||
packet: Vec<u8>,
|
||||
priority: Priority,
|
||||
) -> Result<(), PacketDropped> {
|
||||
match priority {
|
||||
Priority::Low => {
|
||||
self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::LowPriorityDrop)
|
||||
}
|
||||
Priority::Normal => {
|
||||
self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::NormalPriorityDrop)
|
||||
}
|
||||
Priority::High => {
|
||||
// Last resort: briefly wait for space, then drop
|
||||
match tokio::time::timeout(
|
||||
Duration::from_millis(5),
|
||||
self.high_tx.send(packet),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(())) => {
|
||||
self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::HighPriorityDrop)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Receiving half of the priority channel set.
|
||||
pub struct PriorityReceiver {
|
||||
high_rx: mpsc::Receiver<Vec<u8>>,
|
||||
normal_rx: mpsc::Receiver<Vec<u8>>,
|
||||
low_rx: mpsc::Receiver<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl PriorityReceiver {
|
||||
/// Receive the next packet, draining high-priority first (biased select).
|
||||
pub async fn recv(&mut self) -> Option<Vec<u8>> {
|
||||
tokio::select! {
|
||||
biased;
|
||||
Some(pkt) = self.high_rx.recv() => Some(pkt),
|
||||
Some(pkt) = self.normal_rx.recv() => Some(pkt),
|
||||
Some(pkt) = self.low_rx.recv() => Some(pkt),
|
||||
else => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a priority channel set split into sender and receiver halves.
|
||||
///
|
||||
/// - `high_cap`: capacity of the high-priority channel
|
||||
/// - `normal_cap`: capacity of the normal-priority channel
|
||||
/// - `low_cap`: capacity of the low-priority channel
|
||||
pub fn create_priority_channels(
|
||||
high_cap: usize,
|
||||
normal_cap: usize,
|
||||
low_cap: usize,
|
||||
) -> (PrioritySender, PriorityReceiver) {
|
||||
let (high_tx, high_rx) = mpsc::channel(high_cap);
|
||||
let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
|
||||
let (low_tx, low_rx) = mpsc::channel(low_cap);
|
||||
let stats = Arc::new(QosStats::new());
|
||||
|
||||
let sender = PrioritySender {
|
||||
high_tx,
|
||||
normal_tx,
|
||||
low_tx,
|
||||
stats,
|
||||
};
|
||||
|
||||
let receiver = PriorityReceiver {
|
||||
high_rx,
|
||||
normal_rx,
|
||||
low_rx,
|
||||
};
|
||||
|
||||
(sender, receiver)
|
||||
}
|
||||
|
||||
/// Get a reference to the QoS stats from a sender.
|
||||
impl PrioritySender {
|
||||
pub fn stats(&self) -> &Arc<QosStats> {
|
||||
&self.stats
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper: craft a minimal IPv4 packet
|
||||
fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
|
||||
let mut pkt = vec![0u8; total_len.max(24) as usize];
|
||||
pkt[0] = 0x45; // version 4, IHL 5
|
||||
pkt[2..4].copy_from_slice(&total_len.to_be_bytes());
|
||||
pkt[9] = protocol;
|
||||
// src IP
|
||||
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
|
||||
// dst IP
|
||||
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
|
||||
// ports (at offset 20 for IHL=5)
|
||||
pkt[20..22].copy_from_slice(&src_port.to_be_bytes());
|
||||
pkt[22..24].copy_from_slice(&dst_port.to_be_bytes());
|
||||
pkt
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_icmp_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_dns_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_ssh_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_small_packet_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_normal_http() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_bulk_flow_as_low() {
|
||||
let mut c = PacketClassifier::new(10_000); // Low threshold for testing
|
||||
|
||||
// Send enough traffic to exceed the threshold
|
||||
for _ in 0..100 {
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500);
|
||||
c.classify(&pkt);
|
||||
}
|
||||
|
||||
// Next packet from same flow should be Low
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500);
|
||||
assert_eq!(c.classify(&pkt), Priority::Low);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_too_short_packet() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = vec![0u8; 10]; // Too short for IPv4 header
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_non_ipv4() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let mut pkt = vec![0u8; 40];
|
||||
pkt[0] = 0x60; // IPv6 version nibble
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn priority_receiver_drains_high_first() {
|
||||
let (sender, mut receiver) = create_priority_channels(8, 8, 8);
|
||||
|
||||
// Enqueue in reverse order
|
||||
sender.send(vec![3], Priority::Low).await.unwrap();
|
||||
sender.send(vec![2], Priority::Normal).await.unwrap();
|
||||
sender.send(vec![1], Priority::High).await.unwrap();
|
||||
|
||||
// Should drain High first
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![1]);
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![2]);
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![3]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn smart_dropping_low_priority() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 8, 1);
|
||||
|
||||
// Fill the low channel
|
||||
sender.send(vec![0], Priority::Low).await.unwrap();
|
||||
|
||||
// Next low-priority send should be dropped
|
||||
let result = sender.send(vec![1], Priority::Low).await;
|
||||
assert!(matches!(result, Err(PacketDropped::LowPriorityDrop)));
|
||||
|
||||
assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn smart_dropping_normal_priority() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 1, 8);
|
||||
|
||||
sender.send(vec![0], Priority::Normal).await.unwrap();
|
||||
|
||||
let result = sender.send(vec![1], Priority::Normal).await;
|
||||
assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop)));
|
||||
|
||||
assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stats_track_enqueued() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 8, 8);
|
||||
|
||||
sender.send(vec![1], Priority::High).await.unwrap();
|
||||
sender.send(vec![2], Priority::High).await.unwrap();
|
||||
sender.send(vec![3], Priority::Normal).await.unwrap();
|
||||
sender.send(vec![4], Priority::Low).await.unwrap();
|
||||
|
||||
assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flow_tracker_evicts_at_capacity() {
|
||||
let mut ft = FlowTracker::new(Duration::from_secs(60), 2);
|
||||
|
||||
let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 };
|
||||
let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 };
|
||||
let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 };
|
||||
|
||||
ft.record(k1, 100, 1000);
|
||||
ft.record(k2, 100, 1000);
|
||||
// Should evict k1 (oldest)
|
||||
ft.record(k3, 100, 1000);
|
||||
|
||||
assert_eq!(ft.flows.len(), 2);
|
||||
assert!(!ft.flows.contains_key(&k1));
|
||||
}
|
||||
}
|
||||
139
rust/src/ratelimit.rs
Normal file
139
rust/src/ratelimit.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use std::time::Instant;
|
||||
|
||||
/// A token bucket rate limiter operating on byte granularity.
|
||||
pub struct TokenBucket {
|
||||
/// Tokens (bytes) added per second.
|
||||
rate: f64,
|
||||
/// Maximum burst capacity in bytes.
|
||||
burst: f64,
|
||||
/// Currently available tokens.
|
||||
tokens: f64,
|
||||
/// Last time tokens were refilled.
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
/// Create a new token bucket.
|
||||
///
|
||||
/// - `rate_bytes_per_sec`: sustained rate in bytes/second
|
||||
/// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
|
||||
pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
|
||||
let burst = burst_bytes as f64;
|
||||
Self {
|
||||
rate: rate_bytes_per_sec as f64,
|
||||
burst,
|
||||
tokens: burst, // start full
|
||||
last_refill: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded.
|
||||
pub fn try_consume(&mut self, bytes: usize) -> bool {
|
||||
self.refill();
|
||||
let needed = bytes as f64;
|
||||
if needed <= self.tokens {
|
||||
self.tokens -= needed;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Update rate and burst limits dynamically (for live IPC reconfiguration).
|
||||
pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
|
||||
self.rate = rate_bytes_per_sec as f64;
|
||||
self.burst = burst_bytes as f64;
|
||||
// Cap current tokens at new burst
|
||||
if self.tokens > self.burst {
|
||||
self.tokens = self.burst;
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self) {
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
|
||||
self.last_refill = now;
|
||||
self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn allows_traffic_under_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
|
||||
// Should allow up to burst size immediately
|
||||
assert!(tb.try_consume(1_500_000));
|
||||
assert!(tb.try_consume(400_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blocks_traffic_over_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
|
||||
// Consume entire burst
|
||||
assert!(tb.try_consume(1_000_000));
|
||||
// Next consume should fail (no time to refill)
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_consume_always_succeeds() {
|
||||
let mut tb = TokenBucket::new(0, 0);
|
||||
assert!(tb.try_consume(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refills_over_time() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
|
||||
// Drain completely
|
||||
assert!(tb.try_consume(1_000_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
|
||||
// Wait 100ms — should refill ~100KB
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_limits_caps_tokens() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
|
||||
// Tokens start at burst (2MB)
|
||||
tb.update_limits(500_000, 500_000);
|
||||
// Tokens should be capped to new burst (500KB)
|
||||
assert!(tb.try_consume(500_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_limits_changes_rate() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
|
||||
assert!(tb.try_consume(1_000_000)); // drain
|
||||
|
||||
// Change to higher rate
|
||||
tb.update_limits(10_000_000, 10_000_000);
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
// At 10MB/s, 50ms should refill ~500KB
|
||||
assert!(tb.try_consume(200_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_rate_blocks_after_burst() {
|
||||
let mut tb = TokenBucket::new(0, 100);
|
||||
assert!(tb.try_consume(100));
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
// Zero rate means no refill
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokens_do_not_exceed_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000);
|
||||
// Wait to accumulate — but should cap at burst
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
assert!(tb.try_consume(1_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
@@ -12,9 +13,14 @@ use tracing::{info, error, warn};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::mtu::{MtuConfig, TunnelOverhead};
|
||||
use crate::network::IpPool;
|
||||
use crate::ratelimit::TokenBucket;
|
||||
use crate::transport;
|
||||
|
||||
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
|
||||
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
|
||||
|
||||
/// Server configuration (matches TS IVpnServerConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -29,6 +35,10 @@ pub struct ServerConfig {
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
pub enable_nat: Option<bool>,
|
||||
/// Default rate limit for new clients (bytes/sec). None = unlimited.
|
||||
pub default_rate_limit_bytes_per_sec: Option<u64>,
|
||||
/// Default burst size for new clients (bytes). None = unlimited.
|
||||
pub default_burst_bytes: Option<u64>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -40,6 +50,12 @@ pub struct ClientInfo {
|
||||
pub connected_since: String,
|
||||
pub bytes_sent: u64,
|
||||
pub bytes_received: u64,
|
||||
pub packets_dropped: u64,
|
||||
pub bytes_dropped: u64,
|
||||
pub last_keepalive_at: Option<String>,
|
||||
pub keepalives_received: u64,
|
||||
pub rate_limit_bytes_per_sec: Option<u64>,
|
||||
pub burst_bytes: Option<u64>,
|
||||
}
|
||||
|
||||
/// Server statistics.
|
||||
@@ -63,6 +79,8 @@ pub struct ServerState {
|
||||
pub ip_pool: Mutex<IpPool>,
|
||||
pub clients: RwLock<HashMap<String, ClientInfo>>,
|
||||
pub stats: RwLock<ServerStatistics>,
|
||||
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
|
||||
pub mtu_config: MtuConfig,
|
||||
pub started_at: std::time::Instant,
|
||||
}
|
||||
|
||||
@@ -98,11 +116,18 @@ impl VpnServer {
|
||||
}
|
||||
}
|
||||
|
||||
let link_mtu = config.mtu.unwrap_or(1420);
|
||||
// Compute effective MTU from overhead
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
|
||||
|
||||
let state = Arc::new(ServerState {
|
||||
config: config.clone(),
|
||||
ip_pool: Mutex::new(ip_pool),
|
||||
clients: RwLock::new(HashMap::new()),
|
||||
stats: RwLock::new(ServerStatistics::default()),
|
||||
rate_limiters: Mutex::new(HashMap::new()),
|
||||
mtu_config,
|
||||
started_at: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
@@ -166,11 +191,52 @@ impl VpnServer {
|
||||
if let Some(client) = clients.remove(client_id) {
|
||||
let ip: Ipv4Addr = client.assigned_ip.parse()?;
|
||||
state.ip_pool.lock().await.release(&ip);
|
||||
state.rate_limiters.lock().await.remove(client_id);
|
||||
info!("Client {} disconnected", client_id);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a rate limit for a specific client.
|
||||
pub async fn set_client_rate_limit(
|
||||
&self,
|
||||
client_id: &str,
|
||||
rate_bytes_per_sec: u64,
|
||||
burst_bytes: u64,
|
||||
) -> Result<()> {
|
||||
if let Some(ref state) = self.state {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
if let Some(limiter) = limiters.get_mut(client_id) {
|
||||
limiter.update_limits(rate_bytes_per_sec, burst_bytes);
|
||||
} else {
|
||||
limiters.insert(
|
||||
client_id.to_string(),
|
||||
TokenBucket::new(rate_bytes_per_sec, burst_bytes),
|
||||
);
|
||||
}
|
||||
// Update client info
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(client_id) {
|
||||
info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec);
|
||||
info.burst_bytes = Some(burst_bytes);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove rate limit for a specific client (unlimited).
|
||||
pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> {
|
||||
if let Some(ref state) = self.state {
|
||||
state.rate_limiters.lock().await.remove(client_id);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(client_id) {
|
||||
info.rate_limit_bytes_per_sec = None;
|
||||
info.burst_bytes = None;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_listener(
|
||||
@@ -257,25 +323,43 @@ async fn handle_client_connection(
|
||||
let mut noise_transport = responder.into_transport_mode()?;
|
||||
|
||||
// Register client
|
||||
let default_rate = state.config.default_rate_limit_bytes_per_sec;
|
||||
let default_burst = state.config.default_burst_bytes;
|
||||
let client_info = ClientInfo {
|
||||
client_id: client_id.clone(),
|
||||
assigned_ip: assigned_ip.to_string(),
|
||||
connected_since: timestamp_now(),
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
packets_dropped: 0,
|
||||
bytes_dropped: 0,
|
||||
last_keepalive_at: None,
|
||||
keepalives_received: 0,
|
||||
rate_limit_bytes_per_sec: default_rate,
|
||||
burst_bytes: default_burst,
|
||||
};
|
||||
state.clients.write().await.insert(client_id.clone(), client_info);
|
||||
|
||||
// Set up rate limiter if defaults are configured
|
||||
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
|
||||
state
|
||||
.rate_limiters
|
||||
.lock()
|
||||
.await
|
||||
.insert(client_id.clone(), TokenBucket::new(rate, burst));
|
||||
}
|
||||
|
||||
{
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.total_connections += 1;
|
||||
}
|
||||
|
||||
// Send assigned IP info (encrypted)
|
||||
// Send assigned IP info (encrypted), include effective MTU
|
||||
let ip_info = serde_json::json!({
|
||||
"assignedIp": assigned_ip.to_string(),
|
||||
"gateway": state.ip_pool.lock().await.gateway_addr().to_string(),
|
||||
"mtu": state.config.mtu.unwrap_or(1420),
|
||||
"effectiveMtu": state.mtu_config.effective_mtu,
|
||||
});
|
||||
let ip_info_bytes = serde_json::to_vec(&ip_info)?;
|
||||
let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?;
|
||||
@@ -289,66 +373,116 @@ async fn handle_client_connection(
|
||||
|
||||
info!("Client {} connected with IP {}", client_id, assigned_ip);
|
||||
|
||||
// Main packet loop
|
||||
// Main packet loop with dead-peer detection
|
||||
let mut last_activity = tokio::time::Instant::now();
|
||||
|
||||
loop {
|
||||
match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.bytes_received += len as u64;
|
||||
stats.packets_received += 1;
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
// Rate limiting check
|
||||
let allowed = {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
if let Some(limiter) = limiters.get_mut(&client_id) {
|
||||
limiter.try_consume(len)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
if !allowed {
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.packets_dropped += 1;
|
||||
info.bytes_dropped += len as u64;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.bytes_received += len as u64;
|
||||
stats.packets_received += 1;
|
||||
|
||||
// Update per-client stats
|
||||
drop(stats);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.bytes_received += len as u64;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Decrypt error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Decrypt error from {}: {}", client_id, e);
|
||||
PacketType::Keepalive => {
|
||||
// Echo the keepalive payload back in the ACK
|
||||
let ack_frame = Frame {
|
||||
packet_type: PacketType::KeepaliveAck,
|
||||
payload: frame.payload.clone(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
stats.keepalives_sent += 1;
|
||||
|
||||
// Update per-client keepalive tracking
|
||||
drop(stats);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.last_keepalive_at = Some(timestamp_now());
|
||||
info.keepalives_received += 1;
|
||||
}
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Client {} sent disconnect", client_id);
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
warn!("Incomplete frame from {}", client_id);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Frame decode error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
PacketType::Keepalive => {
|
||||
let ack_frame = Frame {
|
||||
packet_type: PacketType::KeepaliveAck,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
stats.keepalives_sent += 1;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Client {} sent disconnect", client_id);
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
warn!("Incomplete frame from {}", client_id);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Frame decode error from {}: {}", client_id, e);
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
ws_sink.send(Message::Pong(data)).await?;
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
continue;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("WebSocket error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
ws_sink.send(Message::Pong(data)).await?;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
warn!("WebSocket error from {}: {}", client_id, e);
|
||||
_ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => {
|
||||
warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -357,6 +491,7 @@ async fn handle_client_connection(
|
||||
// Cleanup
|
||||
state.clients.write().await.remove(&client_id);
|
||||
state.ip_pool.lock().await.release(&assigned_ip);
|
||||
state.rate_limiters.lock().await.remove(&client_id);
|
||||
info!("Client {} disconnected, released IP {}", client_id, assigned_ip);
|
||||
|
||||
Ok(())
|
||||
|
||||
317
rust/src/telemetry.rs
Normal file
317
rust/src/telemetry.rs
Normal file
@@ -0,0 +1,317 @@
|
||||
use serde::Serialize;
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// A single RTT sample.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RttSample {
|
||||
_rtt: Duration,
|
||||
_timestamp: Instant,
|
||||
was_timeout: bool,
|
||||
}
|
||||
|
||||
/// Snapshot of connection quality metrics.
|
||||
#[derive(Debug, Clone, Serialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConnectionQuality {
|
||||
/// Smoothed RTT in milliseconds (EMA, RFC 6298 style).
|
||||
pub srtt_ms: f64,
|
||||
/// Jitter in milliseconds (mean deviation of RTT).
|
||||
pub jitter_ms: f64,
|
||||
/// Minimum RTT observed in the sample window.
|
||||
pub min_rtt_ms: f64,
|
||||
/// Maximum RTT observed in the sample window.
|
||||
pub max_rtt_ms: f64,
|
||||
/// Packet loss ratio over the sample window (0.0 - 1.0).
|
||||
pub loss_ratio: f64,
|
||||
/// Number of consecutive keepalive timeouts (0 if last succeeded).
|
||||
pub consecutive_timeouts: u32,
|
||||
/// Total keepalives sent.
|
||||
pub keepalives_sent: u64,
|
||||
/// Total keepalive ACKs received.
|
||||
pub keepalives_acked: u64,
|
||||
}
|
||||
|
||||
/// Tracks connection quality from keepalive round-trips.
|
||||
pub struct RttTracker {
|
||||
/// Maximum number of samples to keep in the window.
|
||||
max_samples: usize,
|
||||
/// Recent RTT samples (including timeout markers).
|
||||
samples: VecDeque<RttSample>,
|
||||
/// When the last keepalive was sent (for computing RTT on ACK).
|
||||
pending_ping_sent_at: Option<Instant>,
|
||||
/// Number of consecutive keepalive timeouts.
|
||||
pub consecutive_timeouts: u32,
|
||||
/// Smoothed RTT (EMA).
|
||||
srtt: Option<f64>,
|
||||
/// Jitter (mean deviation).
|
||||
jitter: f64,
|
||||
/// Minimum RTT observed.
|
||||
min_rtt: f64,
|
||||
/// Maximum RTT observed.
|
||||
max_rtt: f64,
|
||||
/// Total keepalives sent.
|
||||
keepalives_sent: u64,
|
||||
/// Total keepalive ACKs received.
|
||||
keepalives_acked: u64,
|
||||
/// Previous RTT sample for jitter calculation.
|
||||
last_rtt_ms: Option<f64>,
|
||||
}
|
||||
|
||||
impl RttTracker {
|
||||
/// Create a new tracker with the given window size.
|
||||
pub fn new(max_samples: usize) -> Self {
|
||||
Self {
|
||||
max_samples,
|
||||
samples: VecDeque::with_capacity(max_samples),
|
||||
pending_ping_sent_at: None,
|
||||
consecutive_timeouts: 0,
|
||||
srtt: None,
|
||||
jitter: 0.0,
|
||||
min_rtt: f64::MAX,
|
||||
max_rtt: 0.0,
|
||||
keepalives_sent: 0,
|
||||
keepalives_acked: 0,
|
||||
last_rtt_ms: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record that a keepalive was sent.
|
||||
/// Returns a millisecond timestamp (since UNIX epoch) to embed in the keepalive payload.
|
||||
pub fn mark_ping_sent(&mut self) -> u64 {
|
||||
self.pending_ping_sent_at = Some(Instant::now());
|
||||
self.keepalives_sent += 1;
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
/// Record that a keepalive ACK was received with the echoed timestamp.
|
||||
/// Returns the computed RTT if a pending ping was recorded.
|
||||
pub fn record_ack(&mut self, _echoed_timestamp_ms: u64) -> Option<Duration> {
|
||||
let sent_at = self.pending_ping_sent_at.take()?;
|
||||
let rtt = sent_at.elapsed();
|
||||
let rtt_ms = rtt.as_secs_f64() * 1000.0;
|
||||
|
||||
self.keepalives_acked += 1;
|
||||
self.consecutive_timeouts = 0;
|
||||
|
||||
// Update SRTT (RFC 6298: alpha = 1/8)
|
||||
match self.srtt {
|
||||
None => {
|
||||
self.srtt = Some(rtt_ms);
|
||||
self.jitter = rtt_ms / 2.0;
|
||||
}
|
||||
Some(prev_srtt) => {
|
||||
// RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| (beta = 1/4)
|
||||
self.jitter = 0.75 * self.jitter + 0.25 * (prev_srtt - rtt_ms).abs();
|
||||
// SRTT = (1 - alpha) * SRTT + alpha * R (alpha = 1/8)
|
||||
self.srtt = Some(0.875 * prev_srtt + 0.125 * rtt_ms);
|
||||
}
|
||||
}
|
||||
|
||||
// Update min/max
|
||||
if rtt_ms < self.min_rtt {
|
||||
self.min_rtt = rtt_ms;
|
||||
}
|
||||
if rtt_ms > self.max_rtt {
|
||||
self.max_rtt = rtt_ms;
|
||||
}
|
||||
|
||||
self.last_rtt_ms = Some(rtt_ms);
|
||||
|
||||
// Push sample into window
|
||||
if self.samples.len() >= self.max_samples {
|
||||
self.samples.pop_front();
|
||||
}
|
||||
self.samples.push_back(RttSample {
|
||||
_rtt: rtt,
|
||||
_timestamp: Instant::now(),
|
||||
was_timeout: false,
|
||||
});
|
||||
|
||||
Some(rtt)
|
||||
}
|
||||
|
||||
/// Record that a keepalive timed out (no ACK received).
|
||||
pub fn record_timeout(&mut self) {
|
||||
self.consecutive_timeouts += 1;
|
||||
self.pending_ping_sent_at = None;
|
||||
|
||||
if self.samples.len() >= self.max_samples {
|
||||
self.samples.pop_front();
|
||||
}
|
||||
self.samples.push_back(RttSample {
|
||||
_rtt: Duration::ZERO,
|
||||
_timestamp: Instant::now(),
|
||||
was_timeout: true,
|
||||
});
|
||||
}
|
||||
|
||||
/// Get a snapshot of the current connection quality.
|
||||
pub fn snapshot(&self) -> ConnectionQuality {
|
||||
let loss_ratio = if self.samples.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
let timeouts = self.samples.iter().filter(|s| s.was_timeout).count();
|
||||
timeouts as f64 / self.samples.len() as f64
|
||||
};
|
||||
|
||||
ConnectionQuality {
|
||||
srtt_ms: self.srtt.unwrap_or(0.0),
|
||||
jitter_ms: self.jitter,
|
||||
min_rtt_ms: if self.min_rtt == f64::MAX { 0.0 } else { self.min_rtt },
|
||||
max_rtt_ms: self.max_rtt,
|
||||
loss_ratio,
|
||||
consecutive_timeouts: self.consecutive_timeouts,
|
||||
keepalives_sent: self.keepalives_sent,
|
||||
keepalives_acked: self.keepalives_acked,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_tracker_has_zero_quality() {
|
||||
let tracker = RttTracker::new(30);
|
||||
let q = tracker.snapshot();
|
||||
assert_eq!(q.srtt_ms, 0.0);
|
||||
assert_eq!(q.jitter_ms, 0.0);
|
||||
assert_eq!(q.loss_ratio, 0.0);
|
||||
assert_eq!(q.consecutive_timeouts, 0);
|
||||
assert_eq!(q.keepalives_sent, 0);
|
||||
assert_eq!(q.keepalives_acked, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mark_ping_returns_timestamp() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
let ts = tracker.mark_ping_sent();
|
||||
// Should be a reasonable epoch-ms value (after 2020)
|
||||
assert!(ts > 1_577_836_800_000);
|
||||
assert_eq!(tracker.keepalives_sent, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_ack_computes_rtt() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
let rtt = tracker.record_ack(ts);
|
||||
assert!(rtt.is_some());
|
||||
let rtt = rtt.unwrap();
|
||||
assert!(rtt.as_millis() >= 4); // at least ~5ms minus scheduling jitter
|
||||
assert_eq!(tracker.keepalives_acked, 1);
|
||||
assert_eq!(tracker.consecutive_timeouts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_ack_without_pending_returns_none() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
assert!(tracker.record_ack(12345).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn srtt_converges() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
// Simulate several ping/ack cycles with ~10ms RTT
|
||||
for _ in 0..10 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
|
||||
let q = tracker.snapshot();
|
||||
// SRTT should be roughly 10ms (allowing for scheduling variance)
|
||||
assert!(q.srtt_ms > 5.0, "SRTT too low: {}", q.srtt_ms);
|
||||
assert!(q.srtt_ms < 50.0, "SRTT too high: {}", q.srtt_ms);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timeout_increments_counter_and_loss() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 1);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 2);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert_eq!(q.loss_ratio, 1.0); // 2 timeouts out of 2 samples
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ack_resets_consecutive_timeouts() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 1);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
assert_eq!(tracker.consecutive_timeouts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loss_ratio_over_mixed_window() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
// 3 successful, 1 timeout, 1 successful = 1/5 = 0.2 loss
|
||||
for _ in 0..3 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert!((q.loss_ratio - 0.2).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn window_evicts_old_samples() {
|
||||
let mut tracker = RttTracker::new(5);
|
||||
|
||||
// Fill window with 5 timeouts
|
||||
for _ in 0..5 {
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
}
|
||||
assert_eq!(tracker.snapshot().loss_ratio, 1.0);
|
||||
|
||||
// Add 5 successes — should evict all timeouts
|
||||
for _ in 0..5 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
assert_eq!(tracker.snapshot().loss_ratio, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_max_rtt_tracked() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(15));
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert!(q.min_rtt_ms < q.max_rtt_ms);
|
||||
assert!(q.min_rtt_ms > 0.0);
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,22 @@ pub async fn add_route(subnet: &str, device_name: &str) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Action to take after checking a packet against the MTU.
|
||||
pub enum TunMtuAction {
|
||||
/// Packet is within MTU limits, forward it.
|
||||
Forward,
|
||||
/// Packet is oversized; the Vec contains the ICMP too-big message to write back into TUN.
|
||||
IcmpTooBig(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Check a TUN packet against the MTU and return the appropriate action.
|
||||
pub fn check_tun_mtu(packet: &[u8], mtu_config: &crate::mtu::MtuConfig) -> TunMtuAction {
|
||||
match crate::mtu::check_mtu(packet, mtu_config) {
|
||||
crate::mtu::MtuAction::Forward => TunMtuAction::Forward,
|
||||
crate::mtu::MtuAction::SendIcmpTooBig(icmp) => TunMtuAction::IcmpTooBig(icmp),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a route.
|
||||
pub async fn remove_route(subnet: &str, device_name: &str) -> Result<()> {
|
||||
let output = tokio::process::Command::new("ip")
|
||||
|
||||
271
test/test.flowcontrol.node.ts
Normal file
271
test/test.flowcontrol.node.ts
Normal file
@@ -0,0 +1,271 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let client: VpnClient;
|
||||
const extraClients: VpnClient[] = [];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start VPN server', async () => {
|
||||
serverPort = await findFreePort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
|
||||
// Phase 1: start the daemon bridge
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
expect(server.running).toBeTrue();
|
||||
|
||||
// Phase 2: generate a keypair
|
||||
keypair = await server.generateKeypair();
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
|
||||
// Phase 3: start the VPN listener
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
// Verify server is now running
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('single client connects and gets IP', async () => {
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
// Verify client status
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
// Verify server sees the client
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 1;
|
||||
});
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(1);
|
||||
expect(clients[0].assignedIp).toEqual(result.assignedIp);
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange', async () => {
|
||||
// Wait for at least 2 keepalive cycles (interval=3s, so 8s should be enough)
|
||||
await delay(8000);
|
||||
|
||||
const clientStats = await client.getStatistics();
|
||||
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
const serverStats = await server.getStatistics();
|
||||
expect(serverStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
expect(serverStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Verify per-client keepalive tracking
|
||||
const clients = await server.listClients();
|
||||
expect(clients[0].keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('connection quality telemetry', async () => {
|
||||
const quality = await client.getConnectionQuality();
|
||||
|
||||
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.jitterMs).toBeTypeofNumber();
|
||||
expect(quality.minRttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.maxRttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.lossRatio).toBeTypeofNumber();
|
||||
expect(['healthy', 'degraded', 'critical']).toContain(quality.linkHealth);
|
||||
});
|
||||
|
||||
tap.test('rate limiting: set and verify', async () => {
|
||||
const clients = await server.listClients();
|
||||
const clientId = clients[0].clientId;
|
||||
|
||||
// Set a tight rate limit
|
||||
await server.setClientRateLimit(clientId, 100, 100);
|
||||
|
||||
// Verify via telemetry
|
||||
const telemetry = await server.getClientTelemetry(clientId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toEqual(100);
|
||||
expect(telemetry.burstBytes).toEqual(100);
|
||||
expect(telemetry.clientId).toEqual(clientId);
|
||||
});
|
||||
|
||||
tap.test('rate limiting: removal', async () => {
|
||||
const clients = await server.listClients();
|
||||
const clientId = clients[0].clientId;
|
||||
|
||||
await server.removeClientRateLimit(clientId);
|
||||
|
||||
// Verify telemetry no longer shows rate limit
|
||||
const telemetry = await server.getClientTelemetry(clientId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
|
||||
expect(telemetry.burstBytes).toBeNullOrUndefined();
|
||||
|
||||
// Connection still healthy
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('5 concurrent clients', async () => {
|
||||
const assignedIps = new Set<string>();
|
||||
|
||||
// Get the first client's IP
|
||||
const existingClients = await server.listClients();
|
||||
assignedIps.add(existingClients[0].assignedIp);
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
const result = await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
assignedIps.add(result.assignedIp);
|
||||
extraClients.push(c);
|
||||
}
|
||||
|
||||
// All IPs should be unique (6 total: original + 5 new)
|
||||
expect(assignedIps.size).toEqual(6);
|
||||
|
||||
// Server should see 6 clients
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 6;
|
||||
});
|
||||
const allClients = await server.listClients();
|
||||
expect(allClients.length).toEqual(6);
|
||||
});
|
||||
|
||||
tap.test('client disconnect tracking', async () => {
|
||||
// Disconnect 3 of the 5 extra clients
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const c = extraClients[i];
|
||||
await c.disconnect();
|
||||
c.stop();
|
||||
}
|
||||
|
||||
// Wait for server to detect disconnections
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 3;
|
||||
}, 15000);
|
||||
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(3);
|
||||
|
||||
const stats = await server.getStatistics();
|
||||
expect(stats.totalConnections).toBeGreaterThanOrEqual(6);
|
||||
});
|
||||
|
||||
tap.test('server-side client disconnection', async () => {
|
||||
const clients = await server.listClients();
|
||||
// Pick one of the remaining extra clients (not the original)
|
||||
const targetClient = clients.find((c) => {
|
||||
// Find a client that belongs to extraClients[3] or extraClients[4]
|
||||
return c.clientId !== clients[0].clientId;
|
||||
});
|
||||
expect(targetClient).toBeTruthy();
|
||||
|
||||
await server.disconnectClient(targetClient!.clientId);
|
||||
|
||||
// Wait for server to update
|
||||
await waitFor(async () => {
|
||||
const remaining = await server.listClients();
|
||||
return remaining.length === 2;
|
||||
});
|
||||
|
||||
const remaining = await server.listClients();
|
||||
expect(remaining.length).toEqual(2);
|
||||
});
|
||||
|
||||
tap.test('teardown: stop all', async () => {
|
||||
// Stop the original client
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
|
||||
// Stop remaining extra clients
|
||||
for (const c of extraClients) {
|
||||
if (c.running) {
|
||||
try {
|
||||
await c.disconnect();
|
||||
} catch {
|
||||
// May already be disconnected
|
||||
}
|
||||
c.stop();
|
||||
}
|
||||
}
|
||||
|
||||
await delay(500);
|
||||
|
||||
// Stop the server
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
await delay(500);
|
||||
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
357
test/test.loadtest.node.ts
Normal file
357
test/test.loadtest.node.ts
Normal file
@@ -0,0 +1,357 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as stream from 'stream';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ThrottleProxy (adapted from remoteingress)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class ThrottleTransform extends stream.Transform {
|
||||
private bytesPerSec: number;
|
||||
private bucket: number;
|
||||
private lastRefill: number;
|
||||
private destroyed_: boolean = false;
|
||||
|
||||
constructor(bytesPerSecond: number) {
|
||||
super();
|
||||
this.bytesPerSec = bytesPerSecond;
|
||||
this.bucket = bytesPerSecond;
|
||||
this.lastRefill = Date.now();
|
||||
}
|
||||
|
||||
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: stream.TransformCallback) {
|
||||
if (this.destroyed_) return;
|
||||
|
||||
const now = Date.now();
|
||||
const elapsed = (now - this.lastRefill) / 1000;
|
||||
this.bucket = Math.min(this.bytesPerSec, this.bucket + elapsed * this.bytesPerSec);
|
||||
this.lastRefill = now;
|
||||
|
||||
if (chunk.length <= this.bucket) {
|
||||
this.bucket -= chunk.length;
|
||||
callback(null, chunk);
|
||||
} else {
|
||||
const deficit = chunk.length - this.bucket;
|
||||
this.bucket = 0;
|
||||
const delayMs = Math.min((deficit / this.bytesPerSec) * 1000, 1000);
|
||||
setTimeout(() => {
|
||||
if (this.destroyed_) { callback(); return; }
|
||||
this.lastRefill = Date.now();
|
||||
this.bucket = 0;
|
||||
callback(null, chunk);
|
||||
}, delayMs);
|
||||
}
|
||||
}
|
||||
|
||||
_destroy(err: Error | null, callback: (error: Error | null) => void) {
|
||||
this.destroyed_ = true;
|
||||
callback(err);
|
||||
}
|
||||
}
|
||||
|
||||
interface ThrottleProxy {
|
||||
server: net.Server;
|
||||
close: () => Promise<void>;
|
||||
}
|
||||
|
||||
async function startThrottleProxy(
|
||||
listenPort: number,
|
||||
targetHost: string,
|
||||
targetPort: number,
|
||||
bytesPerSecond: number,
|
||||
): Promise<ThrottleProxy> {
|
||||
const connections = new Set<net.Socket>();
|
||||
const server = net.createServer((clientSock) => {
|
||||
connections.add(clientSock);
|
||||
const upstream = net.createConnection({ host: targetHost, port: targetPort });
|
||||
connections.add(upstream);
|
||||
|
||||
const throttleUp = new ThrottleTransform(bytesPerSecond);
|
||||
const throttleDown = new ThrottleTransform(bytesPerSecond);
|
||||
|
||||
clientSock.pipe(throttleUp).pipe(upstream);
|
||||
upstream.pipe(throttleDown).pipe(clientSock);
|
||||
|
||||
let cleaned = false;
|
||||
const cleanup = () => {
|
||||
if (cleaned) return;
|
||||
cleaned = true;
|
||||
throttleUp.destroy();
|
||||
throttleDown.destroy();
|
||||
clientSock.destroy();
|
||||
upstream.destroy();
|
||||
connections.delete(clientSock);
|
||||
connections.delete(upstream);
|
||||
};
|
||||
clientSock.on('error', () => cleanup());
|
||||
upstream.on('error', () => cleanup());
|
||||
throttleUp.on('error', () => cleanup());
|
||||
throttleDown.on('error', () => cleanup());
|
||||
clientSock.on('close', () => cleanup());
|
||||
upstream.on('close', () => cleanup());
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => server.listen(listenPort, '127.0.0.1', resolve));
|
||||
return {
|
||||
server,
|
||||
close: async () => {
|
||||
for (const c of connections) c.destroy();
|
||||
connections.clear();
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let proxyPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let throttle: ThrottleProxy;
|
||||
const allClients: VpnClient[] = [];
|
||||
|
||||
async function createConnectedClient(port: number): Promise<VpnClient> {
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${port}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
allClients.push(c);
|
||||
return c;
|
||||
}
|
||||
|
||||
async function stopClient(c: VpnClient): Promise<void> {
|
||||
if (c.running) {
|
||||
try { await c.disconnect(); } catch { /* already disconnected */ }
|
||||
c.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start throttled VPN tunnel (1 MB/s)', async () => {
|
||||
serverPort = await findFreePort();
|
||||
proxyPort = await findFreePort();
|
||||
|
||||
// Start VPN server
|
||||
server = new VpnServer({ transport: { transport: 'stdio' } });
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
keypair = await server.generateKeypair();
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Start throttle proxy: 1 MB/s
|
||||
throttle = await startThrottleProxy(proxyPort, '127.0.0.1', serverPort, 1024 * 1024);
|
||||
});
|
||||
|
||||
tap.test('throttled connection: handshake succeeds through throttle', async () => {
|
||||
const client = await createConnectedClient(proxyPort);
|
||||
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
expect(status.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 1;
|
||||
});
|
||||
});
|
||||
|
||||
tap.test('sustained keepalive under throttle', async () => {
|
||||
// Wait for at least 2 keepalive cycles (3s interval)
|
||||
await delay(8000);
|
||||
|
||||
const client = allClients[0];
|
||||
const stats = await client.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Throttle adds latency — RTT should be measurable
|
||||
const quality = await client.getConnectionQuality();
|
||||
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.jitterMs).toBeTypeofNumber();
|
||||
});
|
||||
|
||||
tap.test('3 concurrent throttled clients', async () => {
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await createConnectedClient(proxyPort);
|
||||
}
|
||||
|
||||
// All 4 clients should be visible
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 4;
|
||||
});
|
||||
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(4);
|
||||
|
||||
// Verify all IPs are unique
|
||||
const ips = new Set(clients.map((c) => c.assignedIp));
|
||||
expect(ips.size).toEqual(4);
|
||||
});
|
||||
|
||||
tap.test('rate limiting combined with network throttle', async () => {
|
||||
const clients = await server.listClients();
|
||||
const targetId = clients[0].clientId;
|
||||
|
||||
// Set rate limit on first client
|
||||
await server.setClientRateLimit(targetId, 500, 500);
|
||||
const telemetry = await server.getClientTelemetry(targetId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toEqual(500);
|
||||
expect(telemetry.burstBytes).toEqual(500);
|
||||
|
||||
// Verify another client has no rate limit
|
||||
const otherTelemetry = await server.getClientTelemetry(clients[1].clientId);
|
||||
expect(otherTelemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
|
||||
|
||||
// Clean up the rate limit
|
||||
await server.removeClientRateLimit(targetId);
|
||||
});
|
||||
|
||||
tap.test('burst waves: 3 waves of 3 clients', async () => {
|
||||
const initialCount = (await server.listClients()).length;
|
||||
|
||||
for (let wave = 0; wave < 3; wave++) {
|
||||
const waveClients: VpnClient[] = [];
|
||||
|
||||
// Connect 3 clients
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const c = await createConnectedClient(proxyPort);
|
||||
waveClients.push(c);
|
||||
}
|
||||
|
||||
// Verify all connected
|
||||
await waitFor(async () => {
|
||||
const all = await server.listClients();
|
||||
return all.length === initialCount + 3;
|
||||
});
|
||||
|
||||
// Disconnect all wave clients
|
||||
for (const c of waveClients) {
|
||||
await stopClient(c);
|
||||
}
|
||||
|
||||
// Wait for server to detect disconnections
|
||||
await waitFor(async () => {
|
||||
const all = await server.listClients();
|
||||
return all.length === initialCount;
|
||||
}, 15000);
|
||||
|
||||
await delay(500);
|
||||
}
|
||||
|
||||
// Verify total connections accumulated
|
||||
const stats = await server.getStatistics();
|
||||
expect(stats.totalConnections).toBeGreaterThanOrEqual(9 + initialCount);
|
||||
|
||||
// Original clients still connected
|
||||
const remaining = await server.listClients();
|
||||
expect(remaining.length).toEqual(initialCount);
|
||||
});
|
||||
|
||||
tap.test('aggressive throttle: 10 KB/s', async () => {
|
||||
// Close current throttle proxy and start an aggressive one
|
||||
await throttle.close();
|
||||
const aggressivePort = await findFreePort();
|
||||
throttle = await startThrottleProxy(aggressivePort, '127.0.0.1', serverPort, 10 * 1024);
|
||||
|
||||
// Connect a client through the aggressive throttle
|
||||
const client = await createConnectedClient(aggressivePort);
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Wait for keepalive exchange (might take longer due to throttle)
|
||||
await delay(10000);
|
||||
|
||||
const stats = await client.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('post-load health: direct connection still works', async () => {
|
||||
// Server should still be healthy after all load tests
|
||||
const serverStatus = await server.getStatus();
|
||||
expect(serverStatus.state).toEqual('connected');
|
||||
|
||||
// Connect one more client directly (no throttle)
|
||||
const directClient = await createConnectedClient(serverPort);
|
||||
const status = await directClient.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
await delay(5000);
|
||||
|
||||
const stats = await directClient.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('teardown: stop all', async () => {
|
||||
// Stop all clients
|
||||
for (const c of allClients) {
|
||||
await stopClient(c);
|
||||
}
|
||||
|
||||
await delay(500);
|
||||
|
||||
// Close throttle proxy
|
||||
if (throttle) {
|
||||
await throttle.close();
|
||||
}
|
||||
|
||||
// Stop server
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
await delay(500);
|
||||
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartvpn',
|
||||
version: '1.0.2',
|
||||
version: '1.3.0',
|
||||
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import type {
|
||||
IVpnClientConfig,
|
||||
IVpnStatus,
|
||||
IVpnStatistics,
|
||||
IVpnConnectionQuality,
|
||||
IVpnMtuInfo,
|
||||
TVpnClientCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -65,12 +67,26 @@ export class VpnClient extends plugins.events.EventEmitter {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get traffic statistics.
|
||||
* Get traffic statistics (includes connection quality when connected).
|
||||
*/
|
||||
public async getStatistics(): Promise<IVpnStatistics> {
|
||||
return this.bridge.sendCommand('getStatistics', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get connection quality metrics (RTT, jitter, loss, link health).
|
||||
*/
|
||||
public async getConnectionQuality(): Promise<IVpnConnectionQuality> {
|
||||
return this.bridge.sendCommand('getConnectionQuality', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get MTU information (overhead, effective MTU, oversized packet stats).
|
||||
*/
|
||||
public async getMtuInfo(): Promise<IVpnMtuInfo> {
|
||||
return this.bridge.sendCommand('getMtuInfo', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
@@ -7,6 +7,7 @@ import type {
|
||||
IVpnServerStatistics,
|
||||
IVpnClientInfo,
|
||||
IVpnKeypair,
|
||||
IVpnClientTelemetry,
|
||||
TVpnServerCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -91,6 +92,35 @@ export class VpnServer extends plugins.events.EventEmitter {
|
||||
return this.bridge.sendCommand('generateKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set rate limit for a specific client.
|
||||
*/
|
||||
public async setClientRateLimit(
|
||||
clientId: string,
|
||||
rateBytesPerSec: number,
|
||||
burstBytes: number,
|
||||
): Promise<void> {
|
||||
await this.bridge.sendCommand('setClientRateLimit', {
|
||||
clientId,
|
||||
rateBytesPerSec,
|
||||
burstBytes,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove rate limit for a specific client (unlimited).
|
||||
*/
|
||||
public async removeClientRateLimit(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeClientRateLimit', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get telemetry for a specific client.
|
||||
*/
|
||||
public async getClientTelemetry(clientId: string): Promise<IVpnClientTelemetry> {
|
||||
return this.bridge.sendCommand('getClientTelemetry', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
@@ -64,6 +64,10 @@ export interface IVpnServerConfig {
|
||||
keepaliveIntervalSecs?: number;
|
||||
/** Enable NAT/masquerade for client traffic */
|
||||
enableNat?: boolean;
|
||||
/** Default rate limit for new clients (bytes/sec). Omit for unlimited. */
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
/** Default burst size for new clients (bytes). Omit for unlimited. */
|
||||
defaultBurstBytes?: number;
|
||||
}
|
||||
|
||||
export interface IVpnServerOptions {
|
||||
@@ -99,6 +103,7 @@ export interface IVpnStatistics {
|
||||
keepalivesSent: number;
|
||||
keepalivesReceived: number;
|
||||
uptimeSeconds: number;
|
||||
quality?: IVpnConnectionQuality;
|
||||
}
|
||||
|
||||
export interface IVpnClientInfo {
|
||||
@@ -107,6 +112,12 @@ export interface IVpnClientInfo {
|
||||
connectedSince: string;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
export interface IVpnServerStatistics extends IVpnStatistics {
|
||||
@@ -119,6 +130,53 @@ export interface IVpnKeypair {
|
||||
privateKey: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: Connection quality
|
||||
// ============================================================================
|
||||
|
||||
export type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
|
||||
|
||||
export interface IVpnConnectionQuality {
|
||||
srttMs: number;
|
||||
jitterMs: number;
|
||||
minRttMs: number;
|
||||
maxRttMs: number;
|
||||
lossRatio: number;
|
||||
consecutiveTimeouts: number;
|
||||
linkHealth: TVpnLinkHealth;
|
||||
currentKeepaliveIntervalSecs: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: MTU info
|
||||
// ============================================================================
|
||||
|
||||
export interface IVpnMtuInfo {
|
||||
tunMtu: number;
|
||||
effectiveMtu: number;
|
||||
linkMtu: number;
|
||||
overheadBytes: number;
|
||||
oversizedPacketsDropped: number;
|
||||
icmpTooBigSent: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: Client telemetry (server-side per-client)
|
||||
// ============================================================================
|
||||
|
||||
export interface IVpnClientTelemetry {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
bytesReceived: number;
|
||||
bytesSent: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// IPC Command maps (used by smartrust RustBridge<TCommands>)
|
||||
// ============================================================================
|
||||
@@ -128,6 +186,8 @@ export type TVpnClientCommands = {
|
||||
disconnect: { params: Record<string, never>; result: void };
|
||||
getStatus: { params: Record<string, never>; result: IVpnStatus };
|
||||
getStatistics: { params: Record<string, never>; result: IVpnStatistics };
|
||||
getConnectionQuality: { params: Record<string, never>; result: IVpnConnectionQuality };
|
||||
getMtuInfo: { params: Record<string, never>; result: IVpnMtuInfo };
|
||||
};
|
||||
|
||||
export type TVpnServerCommands = {
|
||||
@@ -138,6 +198,9 @@ export type TVpnServerCommands = {
|
||||
listClients: { params: Record<string, never>; result: { clients: IVpnClientInfo[] } };
|
||||
disconnectClient: { params: { clientId: string }; result: void };
|
||||
generateKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void };
|
||||
removeClientRateLimit: { params: { clientId: string }; result: void };
|
||||
getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry };
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
Reference in New Issue
Block a user