16 Commits

Author SHA1 Message Date
f46ea70286 v1.6.0 2026-03-29 15:25:22 +00:00
26ee3634c8 feat(readme): document WireGuard transport support, configuration, and usage examples 2026-03-29 15:25:22 +00:00
049fa00563 v1.5.0 2026-03-29 15:24:41 +00:00
e4e59d72f9 feat(wireguard): add WireGuard transport support with management APIs and config generation 2026-03-29 15:24:41 +00:00
51d33127bf v1.4.1 2026-03-21 20:50:11 +00:00
a4ba6806e5 fix(readme): preserve markdown line breaks in feature list 2026-03-21 20:50:11 +00:00
6330921160 v1.4.0 2026-03-19 21:53:30 +00:00
e81dd377d8 feat(vpn transport): add QUIC transport support with auto fallback to WebSocket 2026-03-19 21:53:30 +00:00
e14c357ba0 v1.3.0 2026-03-17 19:15:43 +00:00
eb30825f72 feat(tests,client): add flow control and load test coverage and honor configured keepalive intervals 2026-03-17 19:15:43 +00:00
835f0f791d v1.2.0 2026-03-15 18:16:15 +00:00
aec545fe8c feat(readme): document QoS, telemetry, MTU, and rate limiting capabilities in the README 2026-03-15 18:16:15 +00:00
4fab721d87 v1.1.0 2026-03-15 18:10:25 +00:00
9ee41348e0 feat(rust-core): add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs 2026-03-15 18:10:25 +00:00
97bb148063 v1.0.3 2026-02-27 10:26:13 +00:00
c8d572b719 fix(build): add aarch64 linker configuration for cross-compilation 2026-02-27 10:26:13 +00:00
32 changed files with 8519 additions and 2103 deletions

View File

@@ -1,5 +1,63 @@
# Changelog
## 2026-03-29 - 1.6.0 - feat(readme)
document WireGuard transport support, configuration, and usage examples
- Expand the README from dual-transport to triple-transport support by adding WireGuard alongside WebSocket and QUIC
- Add client and server WireGuard examples, including live peer management and .conf generation with WgConfigGenerator
- Document new WireGuard-related API methods, config fields, transport modes, and security model details
## 2026-03-29 - 1.5.0 - feat(wireguard)
add WireGuard transport support with management APIs and config generation
- add Rust WireGuard module integration using boringtun and route management through client/server management handlers
- extend TypeScript client and server configuration schemas with WireGuard-specific options and validation
- add server-side WireGuard peer management commands including keypair generation, peer add/remove, and peer listing
- export a WireGuard config generator for producing client and server .conf files
- add WireGuard-focused test coverage for config validation and config generation
## 2026-03-21 - 1.4.1 - fix(readme)
preserve markdown line breaks in feature list
- Adds trailing spaces to the README feature list so each highlighted capability renders on its own line.
## 2026-03-19 - 1.4.0 - feat(vpn transport)
add QUIC transport support with auto fallback to WebSocket
- introduces a transport abstraction in the Rust daemon so client and server can operate over WebSocket or QUIC
- adds dual-mode server configuration with websocket, quic, and both transport modes plus QUIC idle timeout and listen address options
- adds client transport selection with auto mode that attempts QUIC first and falls back to WebSocket
- adds QUIC certificate hash pinning support and required Rust dependencies for QUIC and TLS
- updates TypeScript interfaces, config validation, tests, and documentation to cover the new transport modes
## 2026-03-17 - 1.3.0 - feat(tests,client)
add flow control and load test coverage and honor configured keepalive intervals
- Adds end-to-end node tests for client/server flow control, keepalive exchange, connection quality telemetry, rate limiting, concurrent clients, and disconnect tracking.
- Adds load testing with throttled proxy scenarios to validate behavior under constrained bandwidth and repeated client churn.
- Updates the Rust client to pass configured keepaliveIntervalSecs into the adaptive keepalive monitor instead of always using defaults.
## 2026-03-15 - 1.2.0 - feat(readme)
document QoS, telemetry, MTU, and rate limiting capabilities in the README
- Expand the architecture and feature overview to cover adaptive keepalive, telemetry, QoS, rate limiting, and MTU handling
- Update client and server examples to show new APIs such as getConnectionQuality(), getMtuInfo(), setClientRateLimit(), and getClientTelemetry()
- Add TypeScript interface documentation for connection quality, MTU info, enriched client statistics, and per-client telemetry
## 2026-03-15 - 1.1.0 - feat(rust-core)
add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs
- adds adaptive keepalive monitoring with RTT, jitter, loss, and link health reporting to client statistics and management endpoints
- introduces MTU overhead calculation and oversized-packet handling support, plus client MTU info APIs
- adds token-bucket rate limiting with configurable default limits and server management commands to set, remove, and inspect per-client telemetry
- extends TypeScript client and server interfaces with connection quality, MTU, and client telemetry methods
## 2026-02-27 - 1.0.3 - fix(build)
add aarch64 linker configuration for cross-compilation
- Added rust/.cargo/config.toml to configure linker for target aarch64-unknown-linux-gnu
- Sets linker to 'aarch64-linux-gnu-gcc' to enable cross-compilation to ARM64
## 2026-02-27 - 1.0.2 - fix()
no changes detected - no code or content modifications

View File

@@ -1,6 +1,6 @@
{
"name": "@push.rocks/smartvpn",
"version": "1.0.2",
"version": "1.6.0",
"private": false,
"description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
"type": "module",
@@ -11,6 +11,7 @@
"typings": "dist_ts/index.d.ts",
"scripts": {
"build": "(tsbuild tsfolders --allowimplicitany) && (tsrust)",
"test:before": "(tsrust)",
"test": "tstest test/ --verbose",
"buildDocs": "tsdoc"
},
@@ -28,15 +29,15 @@
],
"license": "MIT",
"dependencies": {
"@push.rocks/smartrust": "^1.3.0",
"@push.rocks/smartpath": "^5.0.18"
"@push.rocks/smartpath": "^6.0.0",
"@push.rocks/smartrust": "^1.3.2"
},
"devDependencies": {
"@git.zone/tsbuild": "^2.2.12",
"@git.zone/tsrun": "^1.3.3",
"@git.zone/tstest": "^1.0.96",
"@git.zone/tsrust": "^1.0.29",
"@types/node": "^22.0.0"
"@git.zone/tsbuild": "^4.3.0",
"@git.zone/tsrun": "^2.0.1",
"@git.zone/tsrust": "^1.3.0",
"@git.zone/tstest": "^3.5.0",
"@types/node": "^25.5.0"
},
"files": [
"ts/**/*",

2806
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

643
readme.md
View File

@@ -1,6 +1,13 @@
# @push.rocks/smartvpn
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, typed APIs while all networking heavy lifting — encryption, tunneling, packet forwarding — runs at native speed in Rust.
A high-performance VPN with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, fully-typed APIs while all networking heavy lifting — encryption, tunneling, QoS, rate limiting — runs at native speed in Rust.
🔒 **Noise NK** handshake + **XChaCha20-Poly1305** encryption
🚀 **Triple transport**: WebSocket (Cloudflare-friendly), raw **QUIC** (datagrams), and **WireGuard** (standard protocol)
📊 **Adaptive QoS**: packet classification, priority queues, per-client rate limiting
🔄 **Auto-transport**: tries QUIC first, falls back to WebSocket seamlessly
📡 **Real-time telemetry**: RTT, jitter, loss, link health — all exposed via typed APIs
🛡️ **WireGuard mode**: full userspace WireGuard via `boringtun` — generate `.conf` files, manage peers live
## Issue Reporting and Security
@@ -9,8 +16,6 @@ For reporting bugs, issues, or security vulnerabilities, please visit [community
## Install
```bash
npm install @push.rocks/smartvpn
# or
pnpm install @push.rocks/smartvpn
```
@@ -18,27 +23,39 @@ pnpm install @push.rocks/smartvpn
```
TypeScript (control plane) Rust (data plane)
┌──────────────────────────┐ ┌───────────────────────────────┐
│ VpnClient / VpnServer │ │ smartvpn_daemon │
│ └─ VpnBridge │──stdio/──▶ │ ├─ management (JSON IPC) │
│ └─ RustBridge │ socket │ ├─ transport (WebSocket/TLS)
│ (smartrust) │ │ ├─ crypto (Noise NK + XCha)
└──────────────────────────┘ │ ├─ codec (binary framing)
│ ├─ keepalive (app-level)
│ ├─ tunnel (TUN device)
│ ├─ network (NAT/IP pool)
reconnect (backoff)
└───────────────────────────────┘
┌──────────────────────────┐ ┌────────────────────────────────────
│ VpnClient / VpnServer │ │ smartvpn_daemon
│ └─ VpnBridge │──stdio/──▶ │ ├─ management (JSON IPC)
│ └─ RustBridge │ socket │ ├─ transport_trait (abstraction)
│ (smartrust) │ │ │ ├─ transport (WebSocket/TLS)
│ │ │ │ └─ quic_transport (QUIC/UDP)
│ WgConfigGenerator │ ├─ wireguard (boringtun WG)
└─ .conf file output │ │ ├─ crypto (Noise NK + XCha20)
└──────────────────────────┘ │ ├─ codec (binary framing)
keepalive (adaptive state FSM)
│ ├─ telemetry (RTT/jitter/loss) │
│ ├─ qos (classify + priority Q) │
│ ├─ ratelimit (token bucket) │
│ ├─ mtu (overhead calc + ICMP) │
│ ├─ tunnel (TUN device) │
│ ├─ network (NAT/IP pool) │
│ └─ reconnect (exp. backoff) │
└────────────────────────────────────┘
```
**Key design decisions:**
| Decision | Choice | Why |
|----------|--------|-----|
| Transport | WebSocket over HTTPS | Works through Cloudflare and other terminating proxies |
| Encryption | Noise NK + XChaCha20-Poly1305 | Strong forward secrecy, large nonce space (no counter needed) |
| Keepalive | App-level (not WS pings) | Cloudflare drops WS ping frames; app-level pings survive |
| IPC | JSON lines over stdio/Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
| Transport | WebSocket + QUIC + WireGuard | WS works through Cloudflare; QUIC gives low latency + datagrams; WG for standard protocol interop |
| Auto-transport | QUIC first, WS fallback | Best performance when QUIC is available, graceful degradation when it's not |
| WireGuard | Userspace via `boringtun` | No kernel module needed, runs on any platform, full peer management via IPC |
| Encryption | Noise NK + XChaCha20-Poly1305 | Strong forward secrecy, large nonce space (no counter sync needed) |
| QUIC auth | Certificate hash pinning | WireGuard-style trust model — no CA needed, just pin the server cert hash |
| Keepalive | Adaptive app-level pings | Cloudflare drops WS pings; interval adapts to link health (1060s) |
| QoS | Packet classification + priority queues | DNS/SSH/ICMP always drain first; bulk flows get deprioritized |
| Rate limiting | Per-client token bucket | Byte-granular, dynamically reconfigurable via IPC |
| IPC | JSON lines over stdio / Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
| Binary protocol | `[type:1B][length:4B][payload:NB]` | Minimal overhead, easy to parse at wire speed |
## 🚀 Quick Start
@@ -48,15 +65,12 @@ TypeScript (control plane) Rust (data plane)
```typescript
import { VpnClient } from '@push.rocks/smartvpn';
// Development: spawn the Rust daemon as a child process
const client = new VpnClient({
transport: { transport: 'stdio' },
});
// Start the daemon bridge
await client.start();
// Connect to a VPN server
const { assignedIp } = await client.connect({
serverUrl: 'wss://vpn.example.com/tunnel',
serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
@@ -67,19 +81,91 @@ const { assignedIp } = await client.connect({
console.log(`Connected! Assigned IP: ${assignedIp}`);
// Check status
const status = await client.getStatus();
console.log(status); // { state: 'connected', assignedIp: '10.8.0.2', ... }
// Connection quality (adaptive keepalive + telemetry)
const quality = await client.getConnectionQuality();
console.log(quality);
// {
// srttMs: 42.5, jitterMs: 3.2, minRttMs: 38.0, maxRttMs: 67.0,
// lossRatio: 0.0, consecutiveTimeouts: 0,
// linkHealth: 'healthy', currentKeepaliveIntervalSecs: 60
// }
// Get traffic stats
// MTU info
const mtu = await client.getMtuInfo();
console.log(mtu);
// { tunMtu: 1420, effectiveMtu: 1421, linkMtu: 1500, overheadBytes: 79, ... }
// Traffic stats (includes quality snapshot)
const stats = await client.getStatistics();
console.log(stats); // { bytesSent, bytesReceived, packetsSent, ... }
// Disconnect
await client.disconnect();
client.stop();
```
### VPN Client with QUIC
```typescript
import { VpnClient } from '@push.rocks/smartvpn';
// Explicit QUIC — serverUrl is host:port, pinned by cert hash
const quicClient = new VpnClient({
transport: { transport: 'stdio' },
});
await quicClient.start();
const { assignedIp } = await quicClient.connect({
serverUrl: 'vpn.example.com:443',
serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
transport: 'quic',
serverCertHash: 'BASE64_SHA256_CERT_HASH', // printed by server on startup
});
// Or use auto-transport: tries QUIC first (3s timeout), falls back to WS
const autoClient = new VpnClient({
transport: { transport: 'stdio' },
});
await autoClient.start();
await autoClient.connect({
serverUrl: 'wss://vpn.example.com/tunnel', // WS URL — host:port extracted for QUIC attempt
serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
transport: 'auto', // default — QUIC first, then WS
});
```
### VPN Client with WireGuard
```typescript
import { VpnClient } from '@push.rocks/smartvpn';
const wgClient = new VpnClient({
transport: { transport: 'stdio' },
});
await wgClient.start();
const { assignedIp } = await wgClient.connect({
serverPublicKey: 'BASE64_SERVER_WG_PUBLIC_KEY',
serverUrl: '', // not used for WireGuard
transport: 'wireguard',
wgPrivateKey: 'BASE64_CLIENT_PRIVATE_KEY',
wgAddress: '10.8.0.2',
wgAddressPrefix: 24,
wgEndpoint: 'vpn.example.com:51820',
wgAllowedIps: ['0.0.0.0/0'], // route all traffic
wgPersistentKeepalive: 25,
wgPresharedKey: 'OPTIONAL_PSK', // optional extra layer
dns: ['1.1.1.1'],
mtu: 1420,
});
console.log(`WireGuard connected! IP: ${assignedIp}`);
await wgClient.disconnect();
wgClient.stop();
```
### VPN Server
```typescript
@@ -89,37 +175,172 @@ const server = new VpnServer({
transport: { transport: 'stdio' },
});
// Start the daemon and the VPN server
// Generate a Noise keypair first
await server.start();
const keypair = await server.generateKeypair();
// Start the VPN listener
await server.start({
listenAddr: '0.0.0.0:443',
privateKey: 'BASE64_PRIVATE_KEY',
publicKey: 'BASE64_PUBLIC_KEY',
privateKey: keypair.privateKey,
publicKey: keypair.publicKey,
subnet: '10.8.0.0/24',
dns: ['1.1.1.1'],
mtu: 1420,
enableNat: true,
// Transport mode: 'websocket', 'quic', 'both', or 'wireguard'
transportMode: 'both',
// Optional: separate QUIC listen address
quicListenAddr: '0.0.0.0:4433',
// Optional: QUIC idle timeout
quicIdleTimeoutSecs: 30,
// Optional: default rate limit for all new clients
defaultRateLimitBytesPerSec: 10_000_000, // 10 MB/s
defaultBurstBytes: 20_000_000, // 20 MB burst
});
// Generate a Noise keypair
const keypair = await server.generateKeypair();
console.log(keypair); // { publicKey: '...', privateKey: '...' }
// List connected clients
const clients = await server.listClients();
// [{ clientId, assignedIp, connectedSince, bytesSent, bytesReceived }]
// Disconnect a specific client
await server.disconnectClient('some-client-id');
// Per-client rate limiting (live, no reconnect needed)
await server.setClientRateLimit('client-id', 5_000_000, 10_000_000);
await server.removeClientRateLimit('client-id'); // unlimited
// Get server stats
const stats = await server.getStatistics();
// { bytesSent, bytesReceived, activeClients, totalConnections, ... }
// Per-client telemetry
const telemetry = await server.getClientTelemetry('client-id');
console.log(telemetry);
// {
// clientId, assignedIp, lastKeepaliveAt, keepalivesReceived,
// packetsDropped, bytesDropped, bytesReceived, bytesSent,
// rateLimitBytesPerSec, burstBytes
// }
// Kick a client
await server.disconnectClient('client-id');
// Stop
await server.stopServer();
server.stop();
```
### WireGuard Server Mode
```typescript
import { VpnServer } from '@push.rocks/smartvpn';
const wgServer = new VpnServer({
transport: { transport: 'stdio' },
});
// Generate a WireGuard X25519 keypair
await wgServer.start();
const keypair = await wgServer.generateWgKeypair();
console.log(`Server public key: ${keypair.publicKey}`);
// Start in WireGuard mode
await wgServer.start({
listenAddr: '0.0.0.0:51820',
privateKey: keypair.privateKey,
publicKey: keypair.publicKey,
subnet: '10.8.0.0/24',
transportMode: 'wireguard',
wgListenPort: 51820,
wgPeers: [
{
publicKey: 'CLIENT_PUBLIC_KEY_BASE64',
allowedIps: ['10.8.0.2/32'],
persistentKeepalive: 25,
},
],
enableNat: true,
dns: ['1.1.1.1'],
mtu: 1420,
});
// Live peer management — add/remove peers without restart
await wgServer.addWgPeer({
publicKey: 'NEW_CLIENT_PUBLIC_KEY',
allowedIps: ['10.8.0.3/32'],
persistentKeepalive: 25,
});
// List peers with live stats
const peers = await wgServer.listWgPeers();
for (const peer of peers) {
console.log(`${peer.publicKey}: ↑${peer.bytesSent}${peer.bytesReceived}`);
}
// Remove a peer by public key
await wgServer.removeWgPeer('CLIENT_PUBLIC_KEY_BASE64');
await wgServer.stopServer();
wgServer.stop();
```
### Generating WireGuard .conf Files
The `WgConfigGenerator` creates standard WireGuard `.conf` files compatible with `wg-quick`, iOS/Android apps, and all standard WireGuard clients:
```typescript
import { WgConfigGenerator } from '@push.rocks/smartvpn';
// Client config (for wg-quick or mobile apps)
const clientConf = WgConfigGenerator.generateClientConfig({
privateKey: 'CLIENT_PRIVATE_KEY_BASE64',
address: '10.8.0.2/24',
dns: ['1.1.1.1', '8.8.8.8'],
mtu: 1420,
peer: {
publicKey: 'SERVER_PUBLIC_KEY_BASE64',
endpoint: 'vpn.example.com:51820',
allowedIps: ['0.0.0.0/0', '::/0'],
persistentKeepalive: 25,
presharedKey: 'OPTIONAL_PSK_BASE64',
},
});
// Server config (for wg-quick)
const serverConf = WgConfigGenerator.generateServerConfig({
privateKey: 'SERVER_PRIVATE_KEY_BASE64',
address: '10.8.0.1/24',
listenPort: 51820,
dns: ['1.1.1.1'],
mtu: 1420,
enableNat: true,
natInterface: 'eth0', // auto-detected if omitted
peers: [
{
publicKey: 'CLIENT_PUBLIC_KEY_BASE64',
allowedIps: ['10.8.0.2/32'],
persistentKeepalive: 25,
},
],
});
// Write to disk
import * as fs from 'fs';
fs.writeFileSync('/etc/wireguard/wg0.conf', serverConf);
```
<details>
<summary>Example output: client .conf</summary>
```ini
[Interface]
PrivateKey = CLIENT_PRIVATE_KEY_BASE64
Address = 10.8.0.2/24
DNS = 1.1.1.1, 8.8.8.8
MTU = 1420
[Peer]
PublicKey = SERVER_PUBLIC_KEY_BASE64
PresharedKey = OPTIONAL_PSK_BASE64
Endpoint = vpn.example.com:51820
AllowedIPs = 0.0.0.0/0, ::/0
PersistentKeepalive = 25
```
</details>
### Production: Socket Transport
In production, the daemon runs as a system service and you connect over a Unix socket:
@@ -148,10 +369,12 @@ When using socket transport, `client.stop()` closes the socket but **does not ki
| Method | Returns | Description |
|--------|---------|-------------|
| `start()` | `Promise<boolean>` | Start the daemon bridge (spawn or connect) |
| `connect(config?)` | `Promise<{ assignedIp }>` | Connect to VPN server |
| `connect(config?)` | `Promise<{ assignedIp }>` | Connect to VPN server (WS, QUIC, or WireGuard) |
| `disconnect()` | `Promise<void>` | Disconnect from VPN |
| `getStatus()` | `Promise<IVpnStatus>` | Current connection state |
| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic statistics |
| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic stats + connection quality |
| `getConnectionQuality()` | `Promise<IVpnConnectionQuality>` | RTT, jitter, loss, link health |
| `getMtuInfo()` | `Promise<IVpnMtuInfo>` | MTU info and overhead breakdown |
| `stop()` | `void` | Kill/close the daemon bridge |
| `running` | `boolean` | Whether bridge is active |
@@ -163,9 +386,16 @@ When using socket transport, `client.stop()` closes the socket but **does not ki
| `stopServer()` | `Promise<void>` | Stop the VPN server |
| `getStatus()` | `Promise<IVpnStatus>` | Server connection state |
| `getStatistics()` | `Promise<IVpnServerStatistics>` | Server stats (includes client counts) |
| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients |
| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients with QoS stats |
| `disconnectClient(id)` | `Promise<void>` | Kick a client |
| `generateKeypair()` | `Promise<IVpnKeypair>` | Generate Noise NK keypair |
| `setClientRateLimit(id, rate, burst)` | `Promise<void>` | Set per-client rate limit (bytes/sec) |
| `removeClientRateLimit(id)` | `Promise<void>` | Remove rate limit (unlimited) |
| `getClientTelemetry(id)` | `Promise<IVpnClientTelemetry>` | Per-client telemetry + drop stats |
| `generateWgKeypair()` | `Promise<IVpnKeypair>` | Generate WireGuard X25519 keypair |
| `addWgPeer(peer)` | `Promise<void>` | Add a WireGuard peer at runtime |
| `removeWgPeer(publicKey)` | `Promise<void>` | Remove a WireGuard peer by key |
| `listWgPeers()` | `Promise<IWgPeerInfo[]>` | List WG peers with traffic stats |
| `stop()` | `void` | Kill/close the daemon bridge |
### `VpnConfig`
@@ -184,6 +414,19 @@ const config = await VpnConfig.loadFromFile<IVpnClientConfig>('/etc/smartvpn/cli
await VpnConfig.saveToFile('/etc/smartvpn/client.json', config);
```
Validation covers both smartvpn-native configs and WireGuard configs — base64 key format, CIDR ranges, port ranges, and required fields are all checked.
### `WgConfigGenerator`
Static generator for standard WireGuard `.conf` files:
| Method | Returns | Description |
|--------|---------|-------------|
| `generateClientConfig(opts)` | `string` | Generate a `wg-quick` compatible client `.conf` |
| `generateServerConfig(opts)` | `string` | Generate a `wg-quick` compatible server `.conf` with NAT rules |
Output is compatible with `wg-quick`, WireGuard iOS/Android apps, and any standard WireGuard implementation.
### `VpnInstaller`
Generate system service units for the daemon:
@@ -191,26 +434,23 @@ Generate system service units for the daemon:
```typescript
import { VpnInstaller } from '@push.rocks/smartvpn';
// Auto-detect platform
const platform = VpnInstaller.detectPlatform(); // 'linux' | 'macos' | 'windows' | 'unknown'
// Generate systemd unit (Linux)
// Linux (systemd)
const unit = VpnInstaller.generateSystemdUnit({
binaryPath: '/usr/local/bin/smartvpn_daemon',
socketPath: '/var/run/smartvpn.sock',
mode: 'server',
});
// unit.content = full systemd .service file
// unit.installPath = '/etc/systemd/system/smartvpn-server.service'
// Generate launchd plist (macOS)
// macOS (launchd)
const plist = VpnInstaller.generateLaunchdPlist({
binaryPath: '/usr/local/bin/smartvpn_daemon',
socketPath: '/var/run/smartvpn.sock',
mode: 'client',
});
// Auto-detect and generate
// Auto-detect platform
const serviceUnit = VpnInstaller.generateServiceUnit({
binaryPath: '/usr/local/bin/smartvpn_daemon',
socketPath: '/var/run/smartvpn.sock',
@@ -223,22 +463,168 @@ const serviceUnit = VpnInstaller.generateServiceUnit({
Both `VpnClient` and `VpnServer` extend `EventEmitter`:
```typescript
client.on('status', (status) => { /* IVpnStatus */ });
client.on('error', (err) => { /* { message, code? } */ });
client.on('exit', ({ code, signal }) => { /* daemon exited */ });
client.on('reconnected', () => { /* socket reconnected */ });
client.on('status', (status) => { /* IVpnStatus update */ });
client.on('error', (error) => { /* error from daemon */ });
server.on('client-connected', (info) => { /* IVpnClientInfo */ });
server.on('client-disconnected', ({ clientId, reason }) => { /* ... */ });
server.on('started', () => { /* server listener started */ });
server.on('stopped', () => { /* server listener stopped */ });
```
## 🌐 Transport Modes
smartvpn supports three transport protocols. The smartvpn-native transports (WebSocket + QUIC) share the same encryption, framing, and QoS pipeline. WireGuard mode uses the standard WireGuard protocol for broad interoperability.
### WebSocket (default for smartvpn-native)
- Works through Cloudflare, reverse proxies, and HTTP load balancers
- Reliable delivery only (no datagram support)
- URL format: `wss://host/path` or `ws://host:port/path`
### QUIC
- Lower latency, built-in multiplexing, 0-RTT connection establishment
- Supports **unreliable datagrams** for IP packets (with automatic fallback to reliable if oversized)
- Certificate hash pinning — no CA chain needed, WireGuard-style trust
- URL format: `host:port`
- ALPN protocol: `smartvpn`
### WireGuard
- Standard WireGuard protocol via `boringtun` (userspace, no kernel module)
- Compatible with **all WireGuard clients** — iOS, Android, macOS, Windows, Linux, routers
- X25519 key exchange, ChaCha20-Poly1305 encryption
- Dynamic peer management at runtime (add/remove without restart)
- Optional preshared keys for post-quantum defense-in-depth
- Generate `.conf` files for standard clients via `WgConfigGenerator`
- Default port: `51820/UDP`
### Auto-Transport (Recommended for smartvpn-native)
The default `transport: 'auto'` mode gives you the best of both worlds:
1. Extract `host:port` from the WebSocket URL
2. Attempt QUIC connection (3-second timeout)
3. If QUIC fails or times out → fall back to WebSocket
4. Completely transparent to the application
```typescript
await client.connect({
serverUrl: 'wss://vpn.example.com/tunnel',
serverPublicKey: '...',
transport: 'auto', // default — QUIC first, WS fallback
});
```
### Server Dual-Mode / Multi-Mode
The server can listen on multiple transports simultaneously:
```typescript
// WebSocket + QUIC (dual mode)
await server.start({
listenAddr: '0.0.0.0:443', // WebSocket listener
quicListenAddr: '0.0.0.0:4433', // QUIC listener (optional, defaults to listenAddr)
transportMode: 'both', // 'websocket' | 'quic' | 'both' | 'wireguard'
quicIdleTimeoutSecs: 30,
// ... other config
});
// WireGuard standalone
await server.start({
listenAddr: '0.0.0.0:51820',
transportMode: 'wireguard',
wgListenPort: 51820,
wgPeers: [{ publicKey: '...', allowedIps: ['10.8.0.2/32'] }],
// ... other config
});
```
When using `'both'` mode, the server logs the QUIC certificate hash on startup — share this with clients for cert pinning.
## 📊 QoS System
The Rust daemon includes a full QoS stack that operates on decrypted IP packets:
### Adaptive Keepalive
The keepalive system automatically adjusts its interval based on connection quality:
| Link Health | Keepalive Interval | Triggered When |
|-------------|-------------------|----------------|
| 🟢 Healthy | 60s | Jitter < 30ms, loss < 2%, no timeouts |
| 🟡 Degraded | 30s | Jitter > 50ms, loss > 5%, or 1+ timeout |
| 🔴 Critical | 10s | Loss > 20% or 2+ consecutive timeouts |
State transitions include hysteresis (3 consecutive good checks to upgrade, 2 to recover) to prevent flapping. Dead peer detection fires after 3 consecutive timeouts in Critical state.
### Packet Classification
IP packets are classified into three priority levels by inspecting headers (no deep packet inspection):
| Priority | Traffic |
|----------|---------|
| **High** | ICMP, DNS (port 53), SSH (port 22), small packets (< 128 bytes) |
| **Normal** | Everything else |
| **Low** | Bulk flows exceeding 1 MB within a 60s window |
Priority channels drain with biased `tokio::select!` — high-priority packets always go first.
### Smart Packet Dropping
Under backpressure, packets are dropped intelligently:
1. **Low** queue full → drop silently
2. **Normal** queue full → drop
3. **High** queue full → wait 5ms, then drop as last resort
Drop statistics are tracked per priority level and exposed via telemetry.
### Per-Client Rate Limiting
Token bucket algorithm with byte granularity:
```typescript
// Set: 10 MB/s sustained, 20 MB burst
await server.setClientRateLimit('client-id', 10_000_000, 20_000_000);
// Check drops via telemetry
const t = await server.getClientTelemetry('client-id');
console.log(`Dropped: ${t.packetsDropped} packets, ${t.bytesDropped} bytes`);
// Remove limit
await server.removeClientRateLimit('client-id');
```
Rate limits can be changed live without disconnecting the client.
### Path MTU
Tunnel overhead is calculated precisely:
| Layer | Bytes |
|-------|-------|
| IP header | 20 |
| TCP header (with timestamps) | 32 |
| WebSocket framing | 6 |
| VPN frame header | 5 |
| Noise AEAD tag | 16 |
| **Total overhead** | **79** |
For a standard 1500-byte Ethernet link, effective TUN MTU = **1421 bytes**. The default TUN MTU of 1420 is conservative and correct. Oversized packets get an ICMP "Fragmentation Needed" (Type 3, Code 4) written back into the TUN, so the source TCP adjusts its MSS automatically.
## 🔐 Security Model
### smartvpn-native (WebSocket / QUIC)
The VPN uses a **Noise NK** handshake pattern:
1. **NK** = client does **N**ot authenticate, but **K**nows the server's static public key
2. The client generates an ephemeral keypair, performs `e, es` (Diffie-Hellman with server's static key)
3. Server responds with `e, ee` (Diffie-Hellman with both ephemeral keys)
2. The client generates an ephemeral keypair, performs `e, es` (DH with server's static key)
3. Server responds with `e, ee` (DH with both ephemeral keys)
4. Result: forward-secret transport keys derived from both DH operations
Post-handshake, all IP packets are encrypted with **XChaCha20-Poly1305**:
@@ -246,9 +632,34 @@ Post-handshake, all IP packets are encrypted with **XChaCha20-Poly1305**:
- 16-byte authentication tags
- Wire format: `[nonce:24B][ciphertext:var][tag:16B]`
### WireGuard Mode
Uses the standard [Noise IKpsk2](https://www.wireguard.com/protocol/) handshake:
- **X25519** key exchange (Curve25519 Diffie-Hellman)
- **ChaCha20-Poly1305** AEAD encryption
- Optional **preshared keys** for post-quantum defense-in-depth
- Implemented via `boringtun` — Cloudflare's userspace WireGuard in Rust
### QUIC Certificate Pinning
When using QUIC transport, the server generates a self-signed TLS certificate (or uses a configured PEM). Instead of relying on a CA chain, clients pin the server's certificate by its **SHA-256 hash** (base64-encoded) — a WireGuard-inspired trust model:
```typescript
// Server logs the cert hash on startup:
// "QUIC cert hash: <BASE64_HASH>"
// Client pins it:
await client.connect({
serverUrl: 'vpn.example.com:443',
transport: 'quic',
serverCertHash: '<BASE64_HASH>',
serverPublicKey: '...',
});
```
## 📦 Binary Protocol
Inside the WebSocket tunnel, packets use a simple binary framing:
Inside the tunnel (both WebSocket and QUIC reliable channels), packets use a simple binary framing:
```
┌──────────┬──────────┬────────────────────┐
@@ -261,16 +672,18 @@ Inside the WebSocket tunnel, packets use a simple binary framing:
| `HandshakeInit` | `0x01` | Client → Server handshake |
| `HandshakeResp` | `0x02` | Server → Client handshake |
| `IpPacket` | `0x10` | Encrypted IP packet |
| `Keepalive` | `0x20` | App-level ping |
| `KeepaliveAck` | `0x21` | App-level pong |
| `Keepalive` | `0x20` | App-level ping (8-byte timestamp payload) |
| `KeepaliveAck` | `0x21` | App-level pong (echoes timestamp for RTT) |
| `SessionResume` | `0x30` | Resume a dropped session |
| `SessionResumeOk` | `0x31` | Resume accepted |
| `SessionResumeErr` | `0x32` | Resume rejected |
| `Disconnect` | `0x3F` | Graceful disconnect |
## 🛠️ Rust Daemon CLI
When QUIC datagrams are available, IP packets can optionally be sent via the unreliable datagram channel for lower latency. Packets that exceed the max datagram size automatically fall back to the reliable stream.
The Rust binary supports several modes:
> **Note:** WireGuard mode uses the standard WireGuard wire protocol, not this binary framing.
## 🛠️ Rust Daemon CLI
```bash
# Development: stdio management (JSON lines on stdin/stdout)
@@ -290,20 +703,18 @@ smartvpn_daemon --generate-keypair
# Install dependencies
pnpm install
# Build TypeScript + cross-compile Rust
# Build TypeScript + cross-compile Rust (amd64 + arm64)
pnpm build
# Build Rust only (debug)
cd rust && cargo build
# Run Rust tests
# Run all tests (82 Rust + 77 TypeScript)
cd rust && cargo test
# Run TypeScript tests
pnpm test
```
## TypeScript Interfaces
## 📘 TypeScript Interfaces
<details>
<summary>Click to expand full type definitions</summary>
@@ -323,25 +734,65 @@ type TVpnTransportOptions =
// Client config
interface IVpnClientConfig {
serverUrl: string; // e.g. 'wss://vpn.example.com/tunnel'
serverPublicKey: string; // base64-encoded Noise static key
serverUrl: string; // WS: 'wss://host/path' | QUIC: 'host:port'
serverPublicKey: string; // Base64-encoded Noise static key (or WG public key)
transport?: 'auto' | 'websocket' | 'quic' | 'wireguard'; // Default: 'auto'
serverCertHash?: string; // SHA-256 cert hash (base64) for QUIC pinning
dns?: string[];
mtu?: number; // default: 1420
keepaliveIntervalSecs?: number; // default: 30
mtu?: number;
keepaliveIntervalSecs?: number;
// WireGuard-specific
wgPrivateKey?: string; // Client private key (base64, X25519)
wgAddress?: string; // Client TUN address (e.g. 10.8.0.2)
wgAddressPrefix?: number; // Address prefix length (default: 24)
wgPresharedKey?: string; // Optional preshared key (base64)
wgPersistentKeepalive?: number; // Persistent keepalive interval (seconds)
wgEndpoint?: string; // Server endpoint (host:port)
wgAllowedIps?: string[]; // Allowed IPs (CIDR strings)
}
// Server config
interface IVpnServerConfig {
listenAddr: string; // e.g. '0.0.0.0:443'
privateKey: string; // base64 Noise static private key
publicKey: string; // base64 Noise static public key
subnet: string; // e.g. '10.8.0.0/24'
listenAddr: string;
privateKey: string;
publicKey: string;
subnet: string;
tlsCert?: string;
tlsKey?: string;
dns?: string[];
mtu?: number;
keepaliveIntervalSecs?: number;
enableNat?: boolean;
transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard';
quicListenAddr?: string;
quicIdleTimeoutSecs?: number;
defaultRateLimitBytesPerSec?: number;
defaultBurstBytes?: number;
// WireGuard-specific
wgListenPort?: number; // UDP port (default: 51820)
wgPeers?: IWgPeerConfig[]; // Initial peers
}
// WireGuard peer config
interface IWgPeerConfig {
publicKey: string; // Peer's X25519 public key (base64)
presharedKey?: string; // Optional preshared key (base64)
allowedIps: string[]; // Allowed IP ranges (CIDR)
endpoint?: string; // Peer endpoint (host:port)
persistentKeepalive?: number; // Keepalive interval (seconds)
}
// WireGuard peer info (with live stats)
interface IWgPeerInfo {
publicKey: string;
allowedIps: string[];
endpoint?: string;
persistentKeepalive?: number;
bytesSent: number;
bytesReceived: number;
packetsSent: number;
packetsReceived: number;
lastHandshakeTime?: string;
}
// Status
@@ -365,6 +816,7 @@ interface IVpnStatistics {
keepalivesSent: number;
keepalivesReceived: number;
uptimeSeconds: number;
quality?: IVpnConnectionQuality;
}
interface IVpnServerStatistics extends IVpnStatistics {
@@ -372,12 +824,57 @@ interface IVpnServerStatistics extends IVpnStatistics {
totalConnections: number;
}
// Connection quality (QoS)
type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
interface IVpnConnectionQuality {
srttMs: number;
jitterMs: number;
minRttMs: number;
maxRttMs: number;
lossRatio: number;
consecutiveTimeouts: number;
linkHealth: TVpnLinkHealth;
currentKeepaliveIntervalSecs: number;
}
// MTU info
interface IVpnMtuInfo {
tunMtu: number;
effectiveMtu: number;
linkMtu: number;
overheadBytes: number;
oversizedPacketsDropped: number;
icmpTooBigSent: number;
}
// Client info (with QoS fields)
interface IVpnClientInfo {
clientId: string;
assignedIp: string;
connectedSince: string;
bytesSent: number;
bytesReceived: number;
packetsDropped: number;
bytesDropped: number;
lastKeepaliveAt?: string;
keepalivesReceived: number;
rateLimitBytesPerSec?: number;
burstBytes?: number;
}
// Per-client telemetry
interface IVpnClientTelemetry {
clientId: string;
assignedIp: string;
lastKeepaliveAt?: string;
keepalivesReceived: number;
packetsDropped: number;
bytesDropped: number;
bytesReceived: number;
bytesSent: number;
rateLimitBytesPerSec?: number;
burstBytes?: number;
}
interface IVpnKeypair {
@@ -390,7 +887,7 @@ interface IVpnKeypair {
## License and Legal Information
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./LICENSE) file.
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./license.md) file.
**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.

2
rust/.cargo/config.toml Normal file
View File

@@ -0,0 +1,2 @@
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"

682
rust/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,16 @@ tun = { version = "0.7", features = ["async"] }
bytes = "1"
tokio-util = "0.7"
futures-util = "0.3"
async-trait = "0.1"
quinn = "0.11"
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
rcgen = "0.13"
ring = "0.17"
rustls-pki-types = "1"
rustls-pemfile = "2"
webpki-roots = "1"
mimalloc = "0.1"
boringtun = "0.7"
[profile.release]
opt-level = 3

View File

@@ -1,16 +1,17 @@
use anyhow::Result;
use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn};
use tokio::sync::{mpsc, watch, RwLock};
use tracing::{info, error, warn, debug};
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
use crate::telemetry::ConnectionQuality;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;
/// Client configuration (matches TS IVpnClientConfig).
#[derive(Debug, Clone, Deserialize)]
@@ -21,6 +22,10 @@ pub struct ClientConfig {
pub dns: Option<Vec<String>>,
pub mtu: Option<u16>,
pub keepalive_interval_secs: Option<u64>,
/// Transport type: "websocket" (default) or "quic".
pub transport: Option<String>,
/// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
pub server_cert_hash: Option<String>,
}
/// Client statistics.
@@ -65,6 +70,8 @@ pub struct VpnClient {
assigned_ip: Arc<RwLock<Option<String>>>,
shutdown_tx: Option<mpsc::Sender<()>>,
connected_since: Arc<RwLock<Option<std::time::Instant>>>,
quality_rx: Option<watch::Receiver<ConnectionQuality>>,
link_health: Arc<RwLock<LinkHealth>>,
}
impl VpnClient {
@@ -75,6 +82,8 @@ impl VpnClient {
assigned_ip: Arc::new(RwLock::new(None)),
shutdown_tx: None,
connected_since: Arc::new(RwLock::new(None)),
quality_rx: None,
link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
}
}
@@ -93,6 +102,7 @@ impl VpnClient {
let stats = self.stats.clone();
let assigned_ip_ref = self.assigned_ip.clone();
let connected_since = self.connected_since.clone();
let link_health = self.link_health.clone();
// Decode server public key
let server_pub_key = base64::Engine::decode(
@@ -100,9 +110,66 @@ impl VpnClient {
&config.server_public_key,
)?;
// Connect to WebSocket server
let ws = transport::connect_to_server(&config.server_url).await?;
let (mut ws_sink, mut ws_stream) = ws.split();
// Create transport based on configuration
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
let transport_type = config.transport.as_deref().unwrap_or("auto");
match transport_type {
"quic" => {
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
let cert_hash = config.server_cert_hash.as_deref();
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
info!("Connected via QUIC");
(Box::new(quic_sink) as Box<dyn TransportSink>,
Box::new(quic_stream) as Box<dyn TransportStream>)
}
"websocket" => {
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Connected via WebSocket");
(Box::new(ws_sink), Box::new(ws_stream))
}
_ => {
// "auto" (default): try QUIC first, fall back to WebSocket
// Extract host:port from the URL for QUIC attempt
let quic_addr = extract_host_port(&config.server_url);
let cert_hash = config.server_cert_hash.as_deref();
if let Some(ref addr) = quic_addr {
match tokio::time::timeout(
std::time::Duration::from_secs(3),
try_quic_connect(addr, cert_hash),
).await {
Ok(Ok((quic_sink, quic_stream))) => {
info!("Auto: connected via QUIC to {}", addr);
(Box::new(quic_sink) as Box<dyn TransportSink>,
Box::new(quic_stream) as Box<dyn TransportStream>)
}
Ok(Err(e)) => {
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Auto: connected via WebSocket (QUIC unavailable)");
(Box::new(ws_sink), Box::new(ws_stream))
}
Err(_) => {
debug!("Auto: QUIC timed out, falling back to WebSocket");
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Auto: connected via WebSocket (QUIC timed out)");
(Box::new(ws_sink), Box::new(ws_stream))
}
}
} else {
// Can't extract host:port for QUIC, use WebSocket directly
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Connected via WebSocket");
(Box::new(ws_sink), Box::new(ws_stream))
}
}
}
};
// Noise NK handshake (client side = initiator)
*state.write().await = ClientState::Handshaking;
@@ -117,13 +184,11 @@ impl VpnClient {
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
// <- e, ee
let resp_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
let resp_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed during handshake"),
};
@@ -139,9 +204,9 @@ impl VpnClient {
let mut noise_transport = initiator.into_transport_mode()?;
// Receive assigned IP info (encrypted)
let info_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
_ => anyhow::bail!("Expected IP info message"),
let info_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before IP info"),
};
let mut frame_buf = BytesMut::from(&info_msg[..]);
@@ -161,16 +226,32 @@ impl VpnClient {
info!("Connected to VPN, assigned IP: {}", assigned_ip);
// Create adaptive keepalive monitor (use custom interval if configured)
let ka_config = config.keepalive_interval_secs.map(|secs| {
let mut cfg = keepalive::AdaptiveKeepaliveConfig::default();
cfg.degraded_interval = std::time::Duration::from_secs(secs);
cfg.healthy_interval = std::time::Duration::from_secs(secs * 2);
cfg.critical_interval = std::time::Duration::from_secs((secs / 3).max(1));
cfg
});
let (monitor, handle) = keepalive::create_keepalive(ka_config);
self.quality_rx = Some(handle.quality_rx);
// Spawn the keepalive monitor
tokio::spawn(monitor.run());
// Spawn packet forwarding loop
let assigned_ip_clone = assigned_ip.clone();
tokio::spawn(client_loop(
ws_sink,
ws_stream,
sink,
stream,
noise_transport,
state,
stats,
shutdown_rx,
config.keepalive_interval_secs.unwrap_or(30),
handle.signal_rx,
handle.ack_tx,
link_health,
));
Ok(assigned_ip_clone)
@@ -184,6 +265,7 @@ impl VpnClient {
*self.assigned_ip.write().await = None;
*self.connected_since.write().await = None;
*self.state.write().await = ClientState::Disconnected;
self.quality_rx = None;
info!("Disconnected from VPN");
Ok(())
}
@@ -208,13 +290,14 @@ impl VpnClient {
status
}
/// Get traffic statistics.
/// Get traffic statistics (includes connection quality).
pub async fn get_statistics(&self) -> serde_json::Value {
let stats = self.stats.read().await;
let since = self.connected_since.read().await;
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
let health = self.link_health.read().await;
serde_json::json!({
let mut result = serde_json::json!({
"bytesSent": stats.bytes_sent,
"bytesReceived": stats.bytes_received,
"packetsSent": stats.packets_sent,
@@ -222,30 +305,58 @@ impl VpnClient {
"keepalivesSent": stats.keepalives_sent,
"keepalivesReceived": stats.keepalives_received,
"uptimeSeconds": uptime,
})
});
// Include connection quality if available
if let Some(ref rx) = self.quality_rx {
let quality = rx.borrow().clone();
result["quality"] = serde_json::json!({
"srttMs": quality.srtt_ms,
"jitterMs": quality.jitter_ms,
"minRttMs": quality.min_rtt_ms,
"maxRttMs": quality.max_rtt_ms,
"lossRatio": quality.loss_ratio,
"consecutiveTimeouts": quality.consecutive_timeouts,
"linkHealth": format!("{}", *health),
"keepalivesSent": quality.keepalives_sent,
"keepalivesAcked": quality.keepalives_acked,
});
}
result
}
/// Get connection quality snapshot.
pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
}
/// Get current link health.
pub async fn get_link_health(&self) -> LinkHealth {
*self.link_health.read().await
}
}
/// The main client packet forwarding loop (runs in a spawned task).
async fn client_loop(
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
mut noise_transport: snow::TransportState,
state: Arc<RwLock<ClientState>>,
stats: Arc<RwLock<ClientStatistics>>,
mut shutdown_rx: mpsc::Receiver<()>,
keepalive_secs: u64,
mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
ack_tx: mpsc::Sender<()>,
link_health: Arc<RwLock<LinkHealth>>,
) {
let mut buf = vec![0u8; 65535];
let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs));
keepalive_ticker.tick().await; // skip first immediate tick
loop {
tokio::select! {
msg = ws_stream.next() => {
msg = stream.recv_reliable() => {
match msg {
Some(Ok(Message::Binary(data))) => {
let mut frame_buf = BytesMut::from(&data[..][..]);
Ok(Some(data)) => {
let mut frame_buf = BytesMut::from(&data[..]);
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
match frame.packet_type {
PacketType::IpPacket => {
@@ -264,6 +375,8 @@ async fn client_loop(
}
PacketType::KeepaliveAck => {
stats.write().await.keepalives_received += 1;
// Signal the keepalive monitor that ACK was received
let _ = ack_tx.send(()).await;
}
PacketType::Disconnect => {
info!("Server sent disconnect");
@@ -274,35 +387,49 @@ async fn client_loop(
}
}
}
Some(Ok(Message::Close(_))) | None => {
Ok(None) => {
info!("Connection closed");
*state.write().await = ClientState::Disconnected;
break;
}
Some(Ok(Message::Ping(data))) => {
let _ = ws_sink.send(Message::Pong(data)).await;
}
Some(Ok(_)) => continue,
Some(Err(e)) => {
error!("WebSocket error: {}", e);
Err(e) => {
error!("Transport error: {}", e);
*state.write().await = ClientState::Error(e.to_string());
break;
}
}
}
_ = keepalive_ticker.tick() => {
let ka_frame = Frame {
packet_type: PacketType::Keepalive,
payload: vec![],
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
warn!("Failed to send keepalive");
signal = signal_rx.recv() => {
match signal {
Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
// Embed the timestamp in the keepalive payload (8 bytes, big-endian)
let ka_frame = Frame {
packet_type: PacketType::Keepalive,
payload: timestamp_ms.to_be_bytes().to_vec(),
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
warn!("Failed to send keepalive");
*state.write().await = ClientState::Disconnected;
break;
}
stats.write().await.keepalives_sent += 1;
}
}
Some(KeepaliveSignal::PeerDead) => {
warn!("Peer declared dead by keepalive monitor");
*state.write().await = ClientState::Disconnected;
break;
}
stats.write().await.keepalives_sent += 1;
Some(KeepaliveSignal::LinkHealthChanged(health)) => {
debug!("Link health changed to: {}", health);
*link_health.write().await = health;
}
None => {
// Keepalive monitor channel closed
break;
}
}
}
_ = shutdown_rx.recv() => {
@@ -313,12 +440,51 @@ async fn client_loop(
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
}
let _ = ws_sink.close().await;
let _ = sink.close().await;
*state.write().await = ClientState::Disconnected;
break;
}
}
}
}
/// Try to connect via QUIC. Returns transport halves on success.
async fn try_quic_connect(
addr: &str,
cert_hash: Option<&str>,
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
let conn = quic_transport::connect_quic(addr, cert_hash).await?;
let (sink, stream) = quic_transport::open_quic_streams(conn).await?;
Ok((sink, stream))
}
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
/// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080")
/// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443")
/// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port)
fn extract_host_port(url: &str) -> Option<String> {
if url.starts_with("ws://") || url.starts_with("wss://") {
// Parse as URL
let stripped = if url.starts_with("wss://") {
&url[6..]
} else {
&url[5..]
};
// Remove path
let host_port = stripped.split('/').next()?;
if host_port.contains(':') {
Some(host_port.to_string())
} else {
// Default port
let default_port = if url.starts_with("wss://") { 443 } else { 80 };
Some(format!("{}:{}", host_port, default_port))
}
} else if url.contains(':') {
// Already host:port
Some(url.to_string())
} else {
None
}
}

View File

@@ -1,87 +1,464 @@
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, watch};
use tokio::time::{interval, timeout};
use tracing::{debug, warn};
use tracing::{debug, info, warn};
/// Default keepalive interval (30 seconds).
use crate::telemetry::{ConnectionQuality, RttTracker};
/// Default keepalive interval (30 seconds — used for Degraded state).
pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
/// Default keepalive ACK timeout (10 seconds).
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
/// Default keepalive ACK timeout (5 seconds).
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(5);
/// Link health states for adaptive keepalive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkHealth {
/// RTT stable, jitter low, no loss. Interval: 60s.
Healthy,
/// Elevated jitter or occasional loss. Interval: 30s.
Degraded,
/// High loss or sustained jitter spike. Interval: 10s.
Critical,
}
impl std::fmt::Display for LinkHealth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Healthy => write!(f, "healthy"),
Self::Degraded => write!(f, "degraded"),
Self::Critical => write!(f, "critical"),
}
}
}
/// Configuration for the adaptive keepalive state machine.
#[derive(Debug, Clone)]
pub struct AdaptiveKeepaliveConfig {
/// Interval when link health is Healthy.
pub healthy_interval: Duration,
/// Interval when link health is Degraded.
pub degraded_interval: Duration,
/// Interval when link health is Critical.
pub critical_interval: Duration,
/// ACK timeout (how long to wait for ACK before declaring timeout).
pub ack_timeout: Duration,
/// Jitter threshold (ms) to enter Degraded from Healthy.
pub jitter_degraded_ms: f64,
/// Jitter threshold (ms) to return to Healthy from Degraded.
pub jitter_healthy_ms: f64,
/// Loss ratio threshold to enter Degraded.
pub loss_degraded: f64,
/// Loss ratio threshold to enter Critical.
pub loss_critical: f64,
/// Loss ratio threshold to return from Critical to Degraded.
pub loss_recover: f64,
/// Loss ratio threshold to return from Degraded to Healthy.
pub loss_healthy: f64,
/// Consecutive checks required for upward state transitions (hysteresis).
pub upgrade_checks: u32,
/// Consecutive timeouts to declare peer dead in Critical state.
pub dead_peer_timeouts: u32,
}
impl Default for AdaptiveKeepaliveConfig {
fn default() -> Self {
Self {
healthy_interval: Duration::from_secs(60),
degraded_interval: Duration::from_secs(30),
critical_interval: Duration::from_secs(10),
ack_timeout: Duration::from_secs(5),
jitter_degraded_ms: 50.0,
jitter_healthy_ms: 30.0,
loss_degraded: 0.05,
loss_critical: 0.20,
loss_recover: 0.10,
loss_healthy: 0.02,
upgrade_checks: 3,
dead_peer_timeouts: 3,
}
}
}
/// Signals from the keepalive monitor.
#[derive(Debug, Clone)]
pub enum KeepaliveSignal {
/// Time to send a keepalive ping.
SendPing,
/// Peer is considered dead (no ACK received within timeout).
/// Time to send a keepalive ping. Contains the timestamp (ms since epoch) to embed in payload.
SendPing(u64),
/// Peer is considered dead (no ACK received within timeout repeatedly).
PeerDead,
/// Link health state changed.
LinkHealthChanged(LinkHealth),
}
/// A keepalive monitor that emits signals on a channel.
/// A keepalive monitor with adaptive interval and RTT tracking.
pub struct KeepaliveMonitor {
interval: Duration,
timeout_duration: Duration,
config: AdaptiveKeepaliveConfig,
health: LinkHealth,
rtt_tracker: RttTracker,
signal_tx: mpsc::Sender<KeepaliveSignal>,
ack_rx: mpsc::Receiver<()>,
quality_tx: watch::Sender<ConnectionQuality>,
consecutive_upgrade_checks: u32,
}
/// Handle returned to the caller to send ACKs and receive signals.
pub struct KeepaliveHandle {
pub signal_rx: mpsc::Receiver<KeepaliveSignal>,
pub ack_tx: mpsc::Sender<()>,
pub quality_rx: watch::Receiver<ConnectionQuality>,
}
/// Create a keepalive monitor and its handle.
/// Create an adaptive keepalive monitor and its handle.
pub fn create_keepalive(
keepalive_interval: Option<Duration>,
keepalive_timeout: Option<Duration>,
config: Option<AdaptiveKeepaliveConfig>,
) -> (KeepaliveMonitor, KeepaliveHandle) {
let config = config.unwrap_or_default();
let (signal_tx, signal_rx) = mpsc::channel(8);
let (ack_tx, ack_rx) = mpsc::channel(8);
let (quality_tx, quality_rx) = watch::channel(ConnectionQuality::default());
let monitor = KeepaliveMonitor {
interval: keepalive_interval.unwrap_or(DEFAULT_KEEPALIVE_INTERVAL),
timeout_duration: keepalive_timeout.unwrap_or(DEFAULT_KEEPALIVE_TIMEOUT),
config,
health: LinkHealth::Degraded, // start in Degraded, earn Healthy
rtt_tracker: RttTracker::new(30),
signal_tx,
ack_rx,
quality_tx,
consecutive_upgrade_checks: 0,
};
let handle = KeepaliveHandle { signal_rx, ack_tx };
let handle = KeepaliveHandle {
signal_rx,
ack_tx,
quality_rx,
};
(monitor, handle)
}
impl KeepaliveMonitor {
fn current_interval(&self) -> Duration {
match self.health {
LinkHealth::Healthy => self.config.healthy_interval,
LinkHealth::Degraded => self.config.degraded_interval,
LinkHealth::Critical => self.config.critical_interval,
}
}
/// Run the keepalive loop. Blocks until the peer is dead or channels close.
pub async fn run(mut self) {
let mut ticker = interval(self.interval);
let mut ticker = interval(self.current_interval());
ticker.tick().await; // skip first immediate tick
loop {
ticker.tick().await;
debug!("Sending keepalive ping signal");
if self.signal_tx.send(KeepaliveSignal::SendPing).await.is_err() {
// Channel closed
break;
// Record ping sent, get timestamp for payload
let timestamp_ms = self.rtt_tracker.mark_ping_sent();
debug!("Sending keepalive ping (ts={})", timestamp_ms);
if self
.signal_tx
.send(KeepaliveSignal::SendPing(timestamp_ms))
.await
.is_err()
{
break; // channel closed
}
// Wait for ACK within timeout
match timeout(self.timeout_duration, self.ack_rx.recv()).await {
match timeout(self.config.ack_timeout, self.ack_rx.recv()).await {
Ok(Some(())) => {
debug!("Keepalive ACK received");
if let Some(rtt) = self.rtt_tracker.record_ack(timestamp_ms) {
debug!("Keepalive ACK received, RTT: {:?}", rtt);
}
}
Ok(None) => {
// Channel closed
break;
break; // channel closed
}
Err(_) => {
warn!("Keepalive ACK timeout — peer considered dead");
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
break;
self.rtt_tracker.record_timeout();
warn!(
"Keepalive ACK timeout (consecutive: {})",
self.rtt_tracker.consecutive_timeouts
);
}
}
// Publish quality snapshot
let quality = self.rtt_tracker.snapshot();
let _ = self.quality_tx.send(quality.clone());
// Evaluate state transition
let new_health = self.evaluate_health(&quality);
if new_health != self.health {
info!("Link health: {} -> {}", self.health, new_health);
self.health = new_health;
self.consecutive_upgrade_checks = 0;
// Reset ticker to new interval
ticker = interval(self.current_interval());
ticker.tick().await; // skip first immediate tick
let _ = self
.signal_tx
.send(KeepaliveSignal::LinkHealthChanged(new_health))
.await;
}
// Check for dead peer in Critical state
if self.health == LinkHealth::Critical
&& self.rtt_tracker.consecutive_timeouts >= self.config.dead_peer_timeouts
{
warn!("Peer considered dead after {} consecutive timeouts in Critical state",
self.rtt_tracker.consecutive_timeouts);
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
break;
}
}
}
fn evaluate_health(&mut self, quality: &ConnectionQuality) -> LinkHealth {
match self.health {
LinkHealth::Healthy => {
// Downgrade conditions
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Critical;
}
if quality.jitter_ms > self.config.jitter_degraded_ms
|| quality.loss_ratio > self.config.loss_degraded
|| quality.consecutive_timeouts >= 1
{
self.consecutive_upgrade_checks = 0;
return LinkHealth::Degraded;
}
LinkHealth::Healthy
}
LinkHealth::Degraded => {
// Downgrade to Critical
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Critical;
}
// Upgrade to Healthy (with hysteresis)
if quality.jitter_ms < self.config.jitter_healthy_ms
&& quality.loss_ratio < self.config.loss_healthy
&& quality.consecutive_timeouts == 0
{
self.consecutive_upgrade_checks += 1;
if self.consecutive_upgrade_checks >= self.config.upgrade_checks {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Healthy;
}
} else {
self.consecutive_upgrade_checks = 0;
}
LinkHealth::Degraded
}
LinkHealth::Critical => {
// Upgrade to Degraded (with hysteresis), never directly to Healthy
if quality.loss_ratio < self.config.loss_recover
&& quality.consecutive_timeouts == 0
{
self.consecutive_upgrade_checks += 1;
if self.consecutive_upgrade_checks >= 2 {
self.consecutive_upgrade_checks = 0;
return LinkHealth::Degraded;
}
} else {
self.consecutive_upgrade_checks = 0;
}
LinkHealth::Critical
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_config_values() {
let config = AdaptiveKeepaliveConfig::default();
assert_eq!(config.healthy_interval, Duration::from_secs(60));
assert_eq!(config.degraded_interval, Duration::from_secs(30));
assert_eq!(config.critical_interval, Duration::from_secs(10));
assert_eq!(config.ack_timeout, Duration::from_secs(5));
assert_eq!(config.dead_peer_timeouts, 3);
}
#[test]
fn link_health_display() {
assert_eq!(format!("{}", LinkHealth::Healthy), "healthy");
assert_eq!(format!("{}", LinkHealth::Degraded), "degraded");
assert_eq!(format!("{}", LinkHealth::Critical), "critical");
}
// Helper to create a monitor for unit-testing evaluate_health
fn make_test_monitor() -> KeepaliveMonitor {
let (signal_tx, _signal_rx) = mpsc::channel(8);
let (_ack_tx, ack_rx) = mpsc::channel(8);
let (quality_tx, _quality_rx) = watch::channel(ConnectionQuality::default());
KeepaliveMonitor {
config: AdaptiveKeepaliveConfig::default(),
health: LinkHealth::Degraded,
rtt_tracker: RttTracker::new(30),
signal_tx,
ack_rx,
quality_tx,
consecutive_upgrade_checks: 0,
}
}
#[test]
fn healthy_to_degraded_on_jitter() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
jitter_ms: 60.0, // > 50ms threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Degraded);
}
#[test]
fn healthy_to_degraded_on_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
loss_ratio: 0.06, // > 5% threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Degraded);
}
#[test]
fn healthy_to_critical_on_high_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
loss_ratio: 0.25, // > 20% threshold
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Critical);
}
#[test]
fn healthy_to_critical_on_consecutive_timeouts() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
let q = ConnectionQuality {
consecutive_timeouts: 2,
..Default::default()
};
let result = m.evaluate_health(&q);
assert_eq!(result, LinkHealth::Critical);
}
#[test]
fn degraded_to_healthy_requires_hysteresis() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let good_quality = ConnectionQuality {
jitter_ms: 10.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
srtt_ms: 20.0,
..Default::default()
};
// Should require 3 consecutive good checks (default upgrade_checks)
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
assert_eq!(m.consecutive_upgrade_checks, 1);
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
assert_eq!(m.consecutive_upgrade_checks, 2);
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Healthy);
}
#[test]
fn degraded_to_healthy_resets_on_bad_check() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let good = ConnectionQuality {
jitter_ms: 10.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
..Default::default()
};
let bad = ConnectionQuality {
jitter_ms: 60.0, // too high
loss_ratio: 0.0,
consecutive_timeouts: 0,
..Default::default()
};
m.evaluate_health(&good); // 1 check
m.evaluate_health(&good); // 2 checks
m.evaluate_health(&bad); // resets
assert_eq!(m.consecutive_upgrade_checks, 0);
}
#[test]
fn critical_to_degraded_requires_hysteresis() {
let mut m = make_test_monitor();
m.health = LinkHealth::Critical;
let recovering = ConnectionQuality {
loss_ratio: 0.05, // < 10% recover threshold
consecutive_timeouts: 0,
..Default::default()
};
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Critical);
assert_eq!(m.consecutive_upgrade_checks, 1);
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Degraded);
}
#[test]
fn critical_never_directly_to_healthy() {
let mut m = make_test_monitor();
m.health = LinkHealth::Critical;
let perfect = ConnectionQuality {
jitter_ms: 1.0,
loss_ratio: 0.0,
consecutive_timeouts: 0,
srtt_ms: 10.0,
..Default::default()
};
// Even with perfect quality, must go through Degraded first
m.evaluate_health(&perfect); // 1
let result = m.evaluate_health(&perfect); // 2 → Degraded
assert_eq!(result, LinkHealth::Degraded);
// Not Healthy yet
}
#[test]
fn degraded_to_critical_on_high_loss() {
let mut m = make_test_monitor();
m.health = LinkHealth::Degraded;
let q = ConnectionQuality {
loss_ratio: 0.25,
..Default::default()
};
assert_eq!(m.evaluate_health(&q), LinkHealth::Critical);
}
#[test]
fn interval_matches_health() {
let mut m = make_test_monitor();
m.health = LinkHealth::Healthy;
assert_eq!(m.current_interval(), Duration::from_secs(60));
m.health = LinkHealth::Degraded;
assert_eq!(m.current_interval(), Duration::from_secs(30));
m.health = LinkHealth::Critical;
assert_eq!(m.current_interval(), Duration::from_secs(10));
}
}

View File

@@ -5,9 +5,16 @@ pub mod management;
pub mod codec;
pub mod crypto;
pub mod transport;
pub mod transport_trait;
pub mod quic_transport;
pub mod keepalive;
pub mod tunnel;
pub mod network;
pub mod server;
pub mod client;
pub mod reconnect;
pub mod telemetry;
pub mod ratelimit;
pub mod qos;
pub mod mtu;
pub mod wireguard;

View File

@@ -7,6 +7,7 @@ use tracing::{info, error, warn};
use crate::client::{ClientConfig, VpnClient};
use crate::crypto;
use crate::server::{ServerConfig, VpnServer};
use crate::wireguard::{self, WgClient, WgClientConfig, WgPeerConfig, WgServer, WgServerConfig};
// ============================================================================
// IPC protocol types
@@ -93,6 +94,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
let mut vpn_client = VpnClient::new();
let mut vpn_server = VpnServer::new();
let mut wg_client = WgClient::new();
let mut wg_server = WgServer::new();
send_event_stdout("ready", serde_json::json!({ "mode": mode }));
@@ -127,8 +130,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
};
let response = match mode {
"client" => handle_client_request(&request, &mut vpn_client).await,
"server" => handle_server_request(&request, &mut vpn_server).await,
"client" => handle_client_request(&request, &mut vpn_client, &mut wg_client).await,
"server" => handle_server_request(&request, &mut vpn_server, &mut wg_server).await,
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
};
send_response_stdout(&response);
@@ -150,6 +153,8 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
// Shared state behind Mutex for socket mode (multiple connections)
let vpn_client = std::sync::Arc::new(Mutex::new(VpnClient::new()));
let vpn_server = std::sync::Arc::new(Mutex::new(VpnServer::new()));
let wg_client = std::sync::Arc::new(Mutex::new(WgClient::new()));
let wg_server = std::sync::Arc::new(Mutex::new(WgServer::new()));
loop {
match listener.accept().await {
@@ -157,9 +162,11 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
let mode = mode.to_string();
let client = vpn_client.clone();
let server = vpn_server.clone();
let wg_c = wg_client.clone();
let wg_s = wg_server.clone();
tokio::spawn(async move {
if let Err(e) =
handle_socket_connection(stream, &mode, client, server).await
handle_socket_connection(stream, &mode, client, server, wg_c, wg_s).await
{
warn!("Socket connection error: {}", e);
}
@@ -177,6 +184,8 @@ async fn handle_socket_connection(
mode: &str,
vpn_client: std::sync::Arc<Mutex<VpnClient>>,
vpn_server: std::sync::Arc<Mutex<VpnServer>>,
wg_client: std::sync::Arc<Mutex<WgClient>>,
wg_server: std::sync::Arc<Mutex<WgServer>>,
) -> Result<()> {
let (reader, mut writer) = stream.into_split();
let buf_reader = BufReader::new(reader);
@@ -227,11 +236,13 @@ async fn handle_socket_connection(
let response = match mode {
"client" => {
let mut client = vpn_client.lock().await;
handle_client_request(&request, &mut client).await
let mut wg_c = wg_client.lock().await;
handle_client_request(&request, &mut client, &mut wg_c).await
}
"server" => {
let mut server = vpn_server.lock().await;
handle_server_request(&request, &mut server).await
let mut wg_s = wg_server.lock().await;
handle_server_request(&request, &mut server, &mut wg_s).await
}
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
};
@@ -252,38 +263,112 @@ async fn handle_socket_connection(
async fn handle_client_request(
request: &ManagementRequest,
vpn_client: &mut VpnClient,
wg_client: &mut WgClient,
) -> ManagementResponse {
let id = request.id.clone();
match request.method.as_str() {
"connect" => {
let config: ClientConfig = match serde_json::from_value(
request.params.get("config").cloned().unwrap_or_default(),
) {
Ok(c) => c,
Err(e) => {
return ManagementResponse::err(id, format!("Invalid config: {}", e));
}
};
// Check if transport is "wireguard"
let transport = request.params
.get("config")
.and_then(|c| c.get("transport"))
.and_then(|t| t.as_str())
.unwrap_or("");
match vpn_client.connect(config).await {
Ok(assigned_ip) => {
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
if transport == "wireguard" {
let config: WgClientConfig = match serde_json::from_value(
request.params.get("config").cloned().unwrap_or_default(),
) {
Ok(c) => c,
Err(e) => {
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
}
};
match wg_client.connect(config).await {
Ok(assigned_ip) => {
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
}
Err(e) => ManagementResponse::err(id, format!("WG connect failed: {}", e)),
}
} else {
let config: ClientConfig = match serde_json::from_value(
request.params.get("config").cloned().unwrap_or_default(),
) {
Ok(c) => c,
Err(e) => {
return ManagementResponse::err(id, format!("Invalid config: {}", e));
}
};
match vpn_client.connect(config).await {
Ok(assigned_ip) => {
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
}
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
}
}
}
"disconnect" => {
if wg_client.is_running() {
match wg_client.disconnect().await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("WG disconnect failed: {}", e)),
}
} else {
match vpn_client.disconnect().await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
}
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
}
}
"disconnect" => match vpn_client.disconnect().await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
},
"getStatus" => {
let status = vpn_client.get_status().await;
ManagementResponse::ok(id, status)
if wg_client.is_running() {
ManagementResponse::ok(id, wg_client.get_status().await)
} else {
let status = vpn_client.get_status().await;
ManagementResponse::ok(id, status)
}
}
"getStatistics" => {
let stats = vpn_client.get_statistics().await;
ManagementResponse::ok(id, stats)
if wg_client.is_running() {
ManagementResponse::ok(id, wg_client.get_statistics().await)
} else {
let stats = vpn_client.get_statistics().await;
ManagementResponse::ok(id, stats)
}
}
"getConnectionQuality" => {
match vpn_client.get_connection_quality() {
Some(quality) => {
let health = vpn_client.get_link_health().await;
let interval_secs = match health {
crate::keepalive::LinkHealth::Healthy => 60,
crate::keepalive::LinkHealth::Degraded => 30,
crate::keepalive::LinkHealth::Critical => 10,
};
ManagementResponse::ok(id, serde_json::json!({
"srttMs": quality.srtt_ms,
"jitterMs": quality.jitter_ms,
"minRttMs": quality.min_rtt_ms,
"maxRttMs": quality.max_rtt_ms,
"lossRatio": quality.loss_ratio,
"consecutiveTimeouts": quality.consecutive_timeouts,
"linkHealth": format!("{}", health),
"currentKeepaliveIntervalSecs": interval_secs,
}))
}
None => ManagementResponse::ok(id, serde_json::json!(null)),
}
}
"getMtuInfo" => {
ManagementResponse::ok(id, serde_json::json!({
"tunMtu": 1420,
"effectiveMtu": crate::mtu::TunnelOverhead::default_overhead().effective_tun_mtu(1500),
"linkMtu": 1500,
"overheadBytes": crate::mtu::TunnelOverhead::default_overhead().total(),
"oversizedPacketsDropped": 0,
"icmpTooBigSent": 0,
}))
}
_ => ManagementResponse::err(id, format!("Unknown client method: {}", request.method)),
}
@@ -296,45 +381,92 @@ async fn handle_client_request(
async fn handle_server_request(
request: &ManagementRequest,
vpn_server: &mut VpnServer,
wg_server: &mut WgServer,
) -> ManagementResponse {
let id = request.id.clone();
match request.method.as_str() {
"start" => {
let config: ServerConfig = match serde_json::from_value(
request.params.get("config").cloned().unwrap_or_default(),
) {
Ok(c) => c,
Err(e) => {
return ManagementResponse::err(id, format!("Invalid config: {}", e));
}
};
// Check if transportMode is "wireguard"
let transport_mode = request.params
.get("config")
.and_then(|c| c.get("transportMode"))
.and_then(|t| t.as_str())
.unwrap_or("");
match vpn_server.start(config).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
if transport_mode == "wireguard" {
let config: WgServerConfig = match serde_json::from_value(
request.params.get("config").cloned().unwrap_or_default(),
) {
Ok(c) => c,
Err(e) => {
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
}
};
match wg_server.start(config).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("WG start failed: {}", e)),
}
} else {
let config: ServerConfig = match serde_json::from_value(
request.params.get("config").cloned().unwrap_or_default(),
) {
Ok(c) => c,
Err(e) => {
return ManagementResponse::err(id, format!("Invalid config: {}", e));
}
};
match vpn_server.start(config).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
}
}
}
"stop" => {
if wg_server.is_running() {
match wg_server.stop().await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("WG stop failed: {}", e)),
}
} else {
match vpn_server.stop().await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
}
}
}
"stop" => match vpn_server.stop().await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
},
"getStatus" => {
let status = vpn_server.get_status();
ManagementResponse::ok(id, status)
if wg_server.is_running() {
ManagementResponse::ok(id, wg_server.get_status())
} else {
let status = vpn_server.get_status();
ManagementResponse::ok(id, status)
}
}
"getStatistics" => {
let stats = vpn_server.get_statistics().await;
match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id, v),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
if wg_server.is_running() {
ManagementResponse::ok(id, wg_server.get_statistics().await)
} else {
let stats = vpn_server.get_statistics().await;
match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id, v),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
}
}
}
"listClients" => {
let clients = vpn_server.list_clients().await;
match serde_json::to_value(&clients) {
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
if wg_server.is_running() {
let peers = wg_server.list_peers().await;
match serde_json::to_value(&peers) {
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
}
} else {
let clients = vpn_server.list_clients().await;
match serde_json::to_value(&clients) {
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
}
}
}
"disconnectClient" => {
@@ -349,6 +481,50 @@ async fn handle_server_request(
Err(e) => ManagementResponse::err(id, format!("Disconnect client failed: {}", e)),
}
}
"setClientRateLimit" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
let rate = match request.params.get("rateBytesPerSec").and_then(|v| v.as_u64()) {
Some(r) => r,
None => return ManagementResponse::err(id, "Missing rateBytesPerSec".to_string()),
};
let burst = match request.params.get("burstBytes").and_then(|v| v.as_u64()) {
Some(b) => b,
None => return ManagementResponse::err(id, "Missing burstBytes".to_string()),
};
match vpn_server.set_client_rate_limit(&client_id, rate, burst).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
}
}
"removeClientRateLimit" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
match vpn_server.remove_client_rate_limit(&client_id).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
}
}
"getClientTelemetry" => {
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
Some(cid) => cid.to_string(),
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
};
let clients = vpn_server.list_clients().await;
match clients.into_iter().find(|c| c.client_id == client_id) {
Some(info) => {
match serde_json::to_value(&info) {
Ok(v) => ManagementResponse::ok(id, v),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
}
}
None => ManagementResponse::err(id, format!("Client {} not found", client_id)),
}
}
"generateKeypair" => match crypto::generate_keypair() {
Ok((public_key, private_key)) => ManagementResponse::ok(
id,
@@ -359,6 +535,56 @@ async fn handle_server_request(
),
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
},
"generateWgKeypair" => {
let (public_key, private_key) = wireguard::generate_wg_keypair();
ManagementResponse::ok(
id,
serde_json::json!({
"publicKey": public_key,
"privateKey": private_key,
}),
)
}
"addWgPeer" => {
if !wg_server.is_running() {
return ManagementResponse::err(id, "WireGuard server not running".to_string());
}
let config: WgPeerConfig = match serde_json::from_value(
request.params.get("peer").cloned().unwrap_or_default(),
) {
Ok(c) => c,
Err(e) => {
return ManagementResponse::err(id, format!("Invalid peer config: {}", e));
}
};
match wg_server.add_peer(config).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Add peer failed: {}", e)),
}
}
"removeWgPeer" => {
if !wg_server.is_running() {
return ManagementResponse::err(id, "WireGuard server not running".to_string());
}
let public_key = match request.params.get("publicKey").and_then(|v| v.as_str()) {
Some(k) => k.to_string(),
None => return ManagementResponse::err(id, "Missing publicKey".to_string()),
};
match wg_server.remove_peer(&public_key).await {
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
Err(e) => ManagementResponse::err(id, format!("Remove peer failed: {}", e)),
}
}
"listWgPeers" => {
if !wg_server.is_running() {
return ManagementResponse::err(id, "WireGuard server not running".to_string());
}
let peers = wg_server.list_peers().await;
match serde_json::to_value(&peers) {
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "peers": v })),
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
}
}
_ => ManagementResponse::err(id, format!("Unknown server method: {}", request.method)),
}
}

314
rust/src/mtu.rs Normal file
View File

@@ -0,0 +1,314 @@
use std::net::Ipv4Addr;
/// Overhead breakdown for VPN tunnel encapsulation.
#[derive(Debug, Clone)]
pub struct TunnelOverhead {
/// Outer IP header: 20 bytes (IPv4, no options).
pub ip_header: u16,
/// TCP header: typically 32 bytes (20 base + 12 for timestamps).
pub tcp_header: u16,
/// WebSocket framing: ~6 bytes (2 base + 4 mask from client).
pub ws_framing: u16,
/// VPN binary frame header: 5 bytes [type:1B][length:4B].
pub vpn_header: u16,
/// Noise AEAD tag: 16 bytes (Poly1305).
pub noise_tag: u16,
}
impl TunnelOverhead {
/// Conservative default overhead estimate.
pub fn default_overhead() -> Self {
Self {
ip_header: 20,
tcp_header: 32,
ws_framing: 6,
vpn_header: 5,
noise_tag: 16,
}
}
/// Total encapsulation overhead in bytes.
pub fn total(&self) -> u16 {
self.ip_header + self.tcp_header + self.ws_framing + self.vpn_header + self.noise_tag
}
/// Compute effective TUN MTU given the underlying link MTU.
pub fn effective_tun_mtu(&self, link_mtu: u16) -> u16 {
link_mtu.saturating_sub(self.total())
}
}
/// MTU configuration for the VPN tunnel.
#[derive(Debug, Clone)]
pub struct MtuConfig {
/// Underlying link MTU (typically 1500 for Ethernet).
pub link_mtu: u16,
/// Computed effective TUN MTU.
pub effective_mtu: u16,
/// Whether to generate ICMP too-big for oversized packets.
pub send_icmp_too_big: bool,
/// Counter: oversized packets encountered.
pub oversized_packets: u64,
/// Counter: ICMP too-big messages generated.
pub icmp_too_big_sent: u64,
}
impl MtuConfig {
/// Create a new MTU config from the underlying link MTU.
pub fn new(link_mtu: u16) -> Self {
let overhead = TunnelOverhead::default_overhead();
let effective = overhead.effective_tun_mtu(link_mtu);
Self {
link_mtu,
effective_mtu: effective,
send_icmp_too_big: true,
oversized_packets: 0,
icmp_too_big_sent: 0,
}
}
/// Check if a packet exceeds the effective MTU.
pub fn is_oversized(&self, packet_len: usize) -> bool {
packet_len > self.effective_mtu as usize
}
}
/// Action to take after checking MTU.
pub enum MtuAction {
/// Packet is within MTU, forward normally.
Forward,
/// Packet is oversized; contains the ICMP too-big message to write back into TUN.
SendIcmpTooBig(Vec<u8>),
}
/// Check packet against MTU config and return the appropriate action.
pub fn check_mtu(packet: &[u8], config: &MtuConfig) -> MtuAction {
if !config.is_oversized(packet.len()) {
return MtuAction::Forward;
}
if !config.send_icmp_too_big {
return MtuAction::Forward;
}
match generate_icmp_too_big(packet, config.effective_mtu) {
Some(icmp) => MtuAction::SendIcmpTooBig(icmp),
None => MtuAction::Forward,
}
}
/// Generate an ICMPv4 Destination Unreachable / Fragmentation Needed message.
///
/// Per RFC 792: Type 3, Code 4, with next-hop MTU in bytes 6-7 (RFC 1191).
/// Returns the complete IP + ICMP packet to write back into the TUN device.
pub fn generate_icmp_too_big(original_packet: &[u8], next_hop_mtu: u16) -> Option<Vec<u8>> {
// Need at least 20 bytes of original IP header
if original_packet.len() < 20 {
return None;
}
// Verify it's IPv4
if original_packet[0] >> 4 != 4 {
return None;
}
// Parse source/dest from original IP header
let src_ip = Ipv4Addr::new(
original_packet[12],
original_packet[13],
original_packet[14],
original_packet[15],
);
let dst_ip = Ipv4Addr::new(
original_packet[16],
original_packet[17],
original_packet[18],
original_packet[19],
);
// ICMP payload: IP header + first 8 bytes of original datagram (per RFC 792)
let icmp_data_len = original_packet.len().min(28); // 20 IP header + 8 bytes
let icmp_payload = &original_packet[..icmp_data_len];
// Build ICMP message: type(1) + code(1) + checksum(2) + unused(2) + next_hop_mtu(2) + data
let mut icmp = Vec::with_capacity(8 + icmp_data_len);
icmp.push(3); // Type: Destination Unreachable
icmp.push(4); // Code: Fragmentation Needed and DF was Set
icmp.push(0); // Checksum placeholder
icmp.push(0);
icmp.push(0); // Unused
icmp.push(0);
icmp.extend_from_slice(&next_hop_mtu.to_be_bytes());
icmp.extend_from_slice(icmp_payload);
// Compute ICMP checksum
let cksum = internet_checksum(&icmp);
icmp[2] = (cksum >> 8) as u8;
icmp[3] = (cksum & 0xff) as u8;
// Build IP header (ICMP response: FROM tunnel gateway TO original source)
let total_len = (20 + icmp.len()) as u16;
let mut ip = Vec::with_capacity(total_len as usize);
ip.push(0x45); // Version 4, IHL 5
ip.push(0x00); // DSCP/ECN
ip.extend_from_slice(&total_len.to_be_bytes());
ip.extend_from_slice(&[0, 0]); // Identification
ip.extend_from_slice(&[0x40, 0x00]); // Flags: Don't Fragment, Fragment Offset: 0
ip.push(64); // TTL
ip.push(1); // Protocol: ICMP
ip.extend_from_slice(&[0, 0]); // Header checksum placeholder
ip.extend_from_slice(&dst_ip.octets()); // Source: tunnel endpoint (was dst)
ip.extend_from_slice(&src_ip.octets()); // Destination: original source
// Compute IP header checksum
let ip_cksum = internet_checksum(&ip[..20]);
ip[10] = (ip_cksum >> 8) as u8;
ip[11] = (ip_cksum & 0xff) as u8;
ip.extend_from_slice(&icmp);
Some(ip)
}
/// Standard Internet checksum (RFC 1071).
fn internet_checksum(data: &[u8]) -> u16 {
let mut sum: u32 = 0;
let mut i = 0;
while i + 1 < data.len() {
sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
i += 2;
}
if i < data.len() {
sum += (data[i] as u32) << 8;
}
while sum >> 16 != 0 {
sum = (sum & 0xFFFF) + (sum >> 16);
}
!sum as u16
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_overhead_total() {
let oh = TunnelOverhead::default_overhead();
assert_eq!(oh.total(), 79); // 20+32+6+5+16
}
#[test]
fn effective_mtu_for_ethernet() {
let oh = TunnelOverhead::default_overhead();
let mtu = oh.effective_tun_mtu(1500);
assert_eq!(mtu, 1421); // 1500 - 79
}
#[test]
fn effective_mtu_saturates_at_zero() {
let oh = TunnelOverhead::default_overhead();
let mtu = oh.effective_tun_mtu(50); // Less than overhead
assert_eq!(mtu, 0);
}
#[test]
fn mtu_config_default() {
let config = MtuConfig::new(1500);
assert_eq!(config.effective_mtu, 1421);
assert_eq!(config.link_mtu, 1500);
assert!(config.send_icmp_too_big);
}
#[test]
fn is_oversized() {
let config = MtuConfig::new(1500);
assert!(!config.is_oversized(1421));
assert!(config.is_oversized(1422));
}
#[test]
fn icmp_too_big_generation() {
// Craft a minimal IPv4 packet
let mut original = vec![0u8; 28];
original[0] = 0x45; // version 4, IHL 5
original[2..4].copy_from_slice(&1500u16.to_be_bytes()); // total length
original[9] = 6; // TCP
original[12..16].copy_from_slice(&[10, 0, 0, 1]); // src IP
original[16..20].copy_from_slice(&[10, 0, 0, 2]); // dst IP
let icmp_pkt = generate_icmp_too_big(&original, 1421).unwrap();
// Verify it's a valid IPv4 packet
assert_eq!(icmp_pkt[0] >> 4, 4); // IPv4
assert_eq!(icmp_pkt[9], 1); // ICMP protocol
// Source should be original dst (10.0.0.2)
assert_eq!(&icmp_pkt[12..16], &[10, 0, 0, 2]);
// Destination should be original src (10.0.0.1)
assert_eq!(&icmp_pkt[16..20], &[10, 0, 0, 1]);
// ICMP type 3, code 4
assert_eq!(icmp_pkt[20], 3);
assert_eq!(icmp_pkt[21], 4);
// Next-hop MTU at ICMP bytes 6-7 (offset 26-27 in IP packet)
let mtu = u16::from_be_bytes([icmp_pkt[26], icmp_pkt[27]]);
assert_eq!(mtu, 1421);
}
#[test]
fn icmp_too_big_rejects_short_packet() {
let short = vec![0u8; 10];
assert!(generate_icmp_too_big(&short, 1421).is_none());
}
#[test]
fn icmp_too_big_rejects_non_ipv4() {
let mut pkt = vec![0u8; 40];
pkt[0] = 0x60; // IPv6
assert!(generate_icmp_too_big(&pkt, 1421).is_none());
}
#[test]
fn icmp_checksum_valid() {
let mut original = vec![0u8; 28];
original[0] = 0x45;
original[2..4].copy_from_slice(&1500u16.to_be_bytes());
original[9] = 6;
original[12..16].copy_from_slice(&[192, 168, 1, 100]);
original[16..20].copy_from_slice(&[10, 8, 0, 1]);
let icmp_pkt = generate_icmp_too_big(&original, 1420).unwrap();
// Verify IP header checksum
let ip_cksum = internet_checksum(&icmp_pkt[..20]);
assert_eq!(ip_cksum, 0, "IP header checksum should verify to 0");
// Verify ICMP checksum
let icmp_cksum = internet_checksum(&icmp_pkt[20..]);
assert_eq!(icmp_cksum, 0, "ICMP checksum should verify to 0");
}
#[test]
fn check_mtu_forward() {
let config = MtuConfig::new(1500);
let pkt = vec![0u8; 1421]; // Exactly at MTU
assert!(matches!(check_mtu(&pkt, &config), MtuAction::Forward));
}
#[test]
fn check_mtu_oversized_generates_icmp() {
let config = MtuConfig::new(1500);
let mut pkt = vec![0u8; 1500];
pkt[0] = 0x45; // Valid IPv4
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
match check_mtu(&pkt, &config) {
MtuAction::SendIcmpTooBig(icmp) => {
assert_eq!(icmp[20], 3); // ICMP type
assert_eq!(icmp[21], 4); // ICMP code
}
MtuAction::Forward => panic!("Expected SendIcmpTooBig"),
}
}
}

490
rust/src/qos.rs Normal file
View File

@@ -0,0 +1,490 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
/// Priority levels for IP packets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Priority {
High = 0,
Normal = 1,
Low = 2,
}
/// QoS statistics per priority level.
pub struct QosStats {
pub high_enqueued: AtomicU64,
pub normal_enqueued: AtomicU64,
pub low_enqueued: AtomicU64,
pub high_dropped: AtomicU64,
pub normal_dropped: AtomicU64,
pub low_dropped: AtomicU64,
}
impl QosStats {
pub fn new() -> Self {
Self {
high_enqueued: AtomicU64::new(0),
normal_enqueued: AtomicU64::new(0),
low_enqueued: AtomicU64::new(0),
high_dropped: AtomicU64::new(0),
normal_dropped: AtomicU64::new(0),
low_dropped: AtomicU64::new(0),
}
}
}
impl Default for QosStats {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// Packet classification
// ============================================================================
/// 5-tuple flow key for tracking bulk flows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct FlowKey {
src_ip: u32,
dst_ip: u32,
src_port: u16,
dst_port: u16,
protocol: u8,
}
/// Per-flow state for bulk detection.
struct FlowState {
bytes_total: u64,
window_start: Instant,
}
/// Tracks per-flow byte counts for bulk flow detection.
struct FlowTracker {
flows: HashMap<FlowKey, FlowState>,
window_duration: Duration,
max_flows: usize,
}
impl FlowTracker {
fn new(window_duration: Duration, max_flows: usize) -> Self {
Self {
flows: HashMap::new(),
window_duration,
max_flows,
}
}
/// Record bytes for a flow. Returns true if the flow exceeds the threshold.
fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
let now = Instant::now();
// Evict if at capacity — remove oldest entry
if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
if let Some(oldest_key) = self
.flows
.iter()
.min_by_key(|(_, v)| v.window_start)
.map(|(k, _)| *k)
{
self.flows.remove(&oldest_key);
}
}
let state = self.flows.entry(key).or_insert(FlowState {
bytes_total: 0,
window_start: now,
});
// Reset window if expired
if now.duration_since(state.window_start) > self.window_duration {
state.bytes_total = 0;
state.window_start = now;
}
state.bytes_total += bytes;
state.bytes_total > threshold
}
}
/// Classifies raw IP packets into priority levels.
pub struct PacketClassifier {
flow_tracker: FlowTracker,
/// Byte threshold for classifying a flow as "bulk" (Low priority).
bulk_threshold_bytes: u64,
}
impl PacketClassifier {
/// Create a new classifier.
///
/// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB)
pub fn new(bulk_threshold_bytes: u64) -> Self {
Self {
flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000),
bulk_threshold_bytes,
}
}
/// Classify a raw IPv4 packet into a priority level.
///
/// The packet must start with the IPv4 header (as read from a TUN device).
pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
// Need at least 20 bytes for a minimal IPv4 header
if ip_packet.len() < 20 {
return Priority::Normal;
}
let version = ip_packet[0] >> 4;
if version != 4 {
return Priority::Normal; // Only classify IPv4 for now
}
let ihl = (ip_packet[0] & 0x0F) as usize;
let header_len = ihl * 4;
let protocol = ip_packet[9];
let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize;
// ICMP is always high priority
if protocol == 1 {
return Priority::High;
}
// Small packets (<128 bytes) are high priority (likely interactive)
if total_len < 128 {
return Priority::High;
}
// Extract ports for TCP (6) and UDP (17)
let (src_port, dst_port) = if (protocol == 6 || protocol == 17)
&& ip_packet.len() >= header_len + 4
{
let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]);
let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]);
(sp, dp)
} else {
(0, 0)
};
// DNS (port 53) and SSH (port 22) are high priority
if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 {
return Priority::High;
}
// Check for bulk flow
if protocol == 6 || protocol == 17 {
let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]);
let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]);
let key = FlowKey {
src_ip,
dst_ip,
src_port,
dst_port,
protocol,
};
if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) {
return Priority::Low;
}
}
Priority::Normal
}
}
// ============================================================================
// Priority channel set
// ============================================================================
/// Error returned when a packet is dropped.
#[derive(Debug)]
pub enum PacketDropped {
LowPriorityDrop,
NormalPriorityDrop,
HighPriorityDrop,
ChannelClosed,
}
/// Sending half of the priority channel set.
pub struct PrioritySender {
high_tx: mpsc::Sender<Vec<u8>>,
normal_tx: mpsc::Sender<Vec<u8>>,
low_tx: mpsc::Sender<Vec<u8>>,
stats: Arc<QosStats>,
}
impl PrioritySender {
/// Send a packet with the given priority. Implements smart dropping under backpressure.
pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
let (tx, enqueued_counter) = match priority {
Priority::High => (&self.high_tx, &self.stats.high_enqueued),
Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
};
match tx.try_send(packet) {
Ok(()) => {
enqueued_counter.fetch_add(1, Ordering::Relaxed);
Ok(())
}
Err(mpsc::error::TrySendError::Full(packet)) => {
self.handle_backpressure(packet, priority).await
}
Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
}
}
async fn handle_backpressure(
&self,
packet: Vec<u8>,
priority: Priority,
) -> Result<(), PacketDropped> {
match priority {
Priority::Low => {
self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::LowPriorityDrop)
}
Priority::Normal => {
self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::NormalPriorityDrop)
}
Priority::High => {
// Last resort: briefly wait for space, then drop
match tokio::time::timeout(
Duration::from_millis(5),
self.high_tx.send(packet),
)
.await
{
Ok(Ok(())) => {
self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
Ok(())
}
_ => {
self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
Err(PacketDropped::HighPriorityDrop)
}
}
}
}
}
}
/// Receiving half of the priority channel set.
pub struct PriorityReceiver {
high_rx: mpsc::Receiver<Vec<u8>>,
normal_rx: mpsc::Receiver<Vec<u8>>,
low_rx: mpsc::Receiver<Vec<u8>>,
}
impl PriorityReceiver {
/// Receive the next packet, draining high-priority first (biased select).
pub async fn recv(&mut self) -> Option<Vec<u8>> {
tokio::select! {
biased;
Some(pkt) = self.high_rx.recv() => Some(pkt),
Some(pkt) = self.normal_rx.recv() => Some(pkt),
Some(pkt) = self.low_rx.recv() => Some(pkt),
else => None,
}
}
}
/// Create a priority channel set split into sender and receiver halves.
///
/// - `high_cap`: capacity of the high-priority channel
/// - `normal_cap`: capacity of the normal-priority channel
/// - `low_cap`: capacity of the low-priority channel
pub fn create_priority_channels(
high_cap: usize,
normal_cap: usize,
low_cap: usize,
) -> (PrioritySender, PriorityReceiver) {
let (high_tx, high_rx) = mpsc::channel(high_cap);
let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
let (low_tx, low_rx) = mpsc::channel(low_cap);
let stats = Arc::new(QosStats::new());
let sender = PrioritySender {
high_tx,
normal_tx,
low_tx,
stats,
};
let receiver = PriorityReceiver {
high_rx,
normal_rx,
low_rx,
};
(sender, receiver)
}
/// Get a reference to the QoS stats from a sender.
impl PrioritySender {
pub fn stats(&self) -> &Arc<QosStats> {
&self.stats
}
}
#[cfg(test)]
mod tests {
use super::*;
// Helper: craft a minimal IPv4 packet
fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
let mut pkt = vec![0u8; total_len.max(24) as usize];
pkt[0] = 0x45; // version 4, IHL 5
pkt[2..4].copy_from_slice(&total_len.to_be_bytes());
pkt[9] = protocol;
// src IP
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
// dst IP
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
// ports (at offset 20 for IHL=5)
pkt[20..22].copy_from_slice(&src_port.to_be_bytes());
pkt[22..24].copy_from_slice(&dst_port.to_be_bytes());
pkt
}
#[test]
fn classify_icmp_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_dns_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_ssh_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_small_packet_as_high() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet
assert_eq!(c.classify(&pkt), Priority::High);
}
#[test]
fn classify_normal_http() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[test]
fn classify_bulk_flow_as_low() {
let mut c = PacketClassifier::new(10_000); // Low threshold for testing
// Send enough traffic to exceed the threshold
for _ in 0..100 {
let pkt = make_ipv4_packet(6, 12345, 80, 500);
c.classify(&pkt);
}
// Next packet from same flow should be Low
let pkt = make_ipv4_packet(6, 12345, 80, 500);
assert_eq!(c.classify(&pkt), Priority::Low);
}
#[test]
fn classify_too_short_packet() {
let mut c = PacketClassifier::new(1_000_000);
let pkt = vec![0u8; 10]; // Too short for IPv4 header
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[test]
fn classify_non_ipv4() {
let mut c = PacketClassifier::new(1_000_000);
let mut pkt = vec![0u8; 40];
pkt[0] = 0x60; // IPv6 version nibble
assert_eq!(c.classify(&pkt), Priority::Normal);
}
#[tokio::test]
async fn priority_receiver_drains_high_first() {
let (sender, mut receiver) = create_priority_channels(8, 8, 8);
// Enqueue in reverse order
sender.send(vec![3], Priority::Low).await.unwrap();
sender.send(vec![2], Priority::Normal).await.unwrap();
sender.send(vec![1], Priority::High).await.unwrap();
// Should drain High first
assert_eq!(receiver.recv().await.unwrap(), vec![1]);
assert_eq!(receiver.recv().await.unwrap(), vec![2]);
assert_eq!(receiver.recv().await.unwrap(), vec![3]);
}
#[tokio::test]
async fn smart_dropping_low_priority() {
let (sender, _receiver) = create_priority_channels(8, 8, 1);
// Fill the low channel
sender.send(vec![0], Priority::Low).await.unwrap();
// Next low-priority send should be dropped
let result = sender.send(vec![1], Priority::Low).await;
assert!(matches!(result, Err(PacketDropped::LowPriorityDrop)));
assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn smart_dropping_normal_priority() {
let (sender, _receiver) = create_priority_channels(8, 1, 8);
sender.send(vec![0], Priority::Normal).await.unwrap();
let result = sender.send(vec![1], Priority::Normal).await;
assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop)));
assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn stats_track_enqueued() {
let (sender, _receiver) = create_priority_channels(8, 8, 8);
sender.send(vec![1], Priority::High).await.unwrap();
sender.send(vec![2], Priority::High).await.unwrap();
sender.send(vec![3], Priority::Normal).await.unwrap();
sender.send(vec![4], Priority::Low).await.unwrap();
assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2);
assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1);
assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1);
}
#[test]
fn flow_tracker_evicts_at_capacity() {
let mut ft = FlowTracker::new(Duration::from_secs(60), 2);
let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 };
let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 };
let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 };
ft.record(k1, 100, 1000);
ft.record(k2, 100, 1000);
// Should evict k1 (oldest)
ft.record(k3, 100, 1000);
assert_eq!(ft.flows.len(), 2);
assert!(!ft.flows.contains_key(&k1));
}
}

546
rust/src/quic_transport.rs Normal file
View File

@@ -0,0 +1,546 @@
use anyhow::Result;
use async_trait::async_trait;
use quinn::crypto::rustls::QuicClientConfig;
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn, debug};
use crate::transport_trait::{TransportSink, TransportStream};
// ============================================================================
// TLS / Certificate helpers
// ============================================================================
/// Generate a self-signed certificate and private key for QUIC.
pub fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
let cert = rcgen::generate_simple_self_signed(vec!["smartvpn".to_string()])?;
let cert_der = CertificateDer::from(cert.cert);
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()));
Ok((vec![cert_der], key_der))
}
/// Compute the SHA-256 hash of a DER-encoded certificate and return it as base64.
pub fn cert_hash(cert_der: &CertificateDer<'_>) -> String {
use ring::digest;
let hash = digest::digest(&digest::SHA256, cert_der.as_ref());
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash.as_ref())
}
// ============================================================================
// Server-side QUIC endpoint
// ============================================================================
/// Configuration for the QUIC server endpoint.
pub struct QuicServerConfig {
pub listen_addr: String,
pub cert_chain: Vec<CertificateDer<'static>>,
pub private_key: PrivateKeyDer<'static>,
pub idle_timeout_secs: u64,
}
/// Create a QUIC server endpoint bound to the given address.
pub fn create_quic_server(config: QuicServerConfig) -> Result<quinn::Endpoint> {
let addr: SocketAddr = config.listen_addr.parse()?;
let provider = Arc::new(rustls::crypto::ring::default_provider());
let mut tls_config = rustls::ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.with_no_client_auth()
.with_single_cert(config.cert_chain, config.private_key)?;
tls_config.alpn_protocols = vec![b"smartvpn".to_vec()];
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)?,
));
let mut transport = quinn::TransportConfig::default();
transport.max_idle_timeout(Some(
quinn::IdleTimeout::try_from(Duration::from_secs(config.idle_timeout_secs))?,
));
// Enable datagrams with a generous max size
transport.datagram_receive_buffer_size(Some(65535));
transport.datagram_send_buffer_size(65535);
server_config.transport_config(Arc::new(transport));
let endpoint = quinn::Endpoint::server(server_config, addr)?;
info!("QUIC server listening on {}", addr);
Ok(endpoint)
}
// ============================================================================
// Client-side QUIC connection
// ============================================================================
/// A certificate verifier that accepts any server certificate.
/// Safe when Noise NK provides server authentication at the application layer.
#[derive(Debug)]
struct AcceptAnyCert;
impl rustls::client::danger::ServerCertVerifier for AcceptAnyCert {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// A certificate verifier that accepts any certificate matching a given SHA-256 hash.
#[derive(Debug)]
struct CertHashVerifier {
expected_hash: String,
}
impl rustls::client::danger::ServerCertVerifier for CertHashVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
let actual_hash = cert_hash(end_entity);
if actual_hash == self.expected_hash {
Ok(rustls::client::danger::ServerCertVerified::assertion())
} else {
Err(rustls::Error::General(format!(
"Certificate hash mismatch: expected {}, got {}",
self.expected_hash, actual_hash
)))
}
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
// QUIC always uses TLS 1.3
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// Connect to a QUIC server.
///
/// - If `server_cert_hash` is provided, verifies the server certificate matches
/// the given SHA-256 hash (cert pinning).
/// - If `server_cert_hash` is `None`, accepts any server certificate. This is
/// safe because the Noise NK handshake (which runs over the QUIC stream)
/// authenticates the server via its pre-shared public key — the same trust
/// model as WireGuard.
pub async fn connect_quic(
addr: &str,
server_cert_hash: Option<&str>,
) -> Result<quinn::Connection> {
let remote: SocketAddr = addr.parse()?;
let provider = Arc::new(rustls::crypto::ring::default_provider());
let tls_config = if let Some(hash) = server_cert_hash {
// Pin to a specific certificate hash
let mut config = rustls::ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.dangerous()
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
expected_hash: hash.to_string(),
}))
.with_no_client_auth();
config.alpn_protocols = vec![b"smartvpn".to_vec()];
config
} else {
// Accept any cert — Noise NK provides server authentication
let mut config = rustls::ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.dangerous()
.with_custom_certificate_verifier(Arc::new(AcceptAnyCert))
.with_no_client_auth();
config.alpn_protocols = vec![b"smartvpn".to_vec()];
config
};
let client_config = quinn::ClientConfig::new(Arc::new(
QuicClientConfig::try_from(tls_config)?,
));
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse()?)?;
endpoint.set_default_client_config(client_config);
info!("Connecting to QUIC server at {}", addr);
let connection = endpoint.connect(remote, "smartvpn")?.await?;
info!("QUIC connection established");
Ok(connection)
}
// ============================================================================
// QUIC Transport Sink / Stream implementations
// ============================================================================
/// QUIC transport sink — wraps a SendStream (reliable) and Connection (datagrams).
pub struct QuicTransportSink {
send_stream: quinn::SendStream,
connection: quinn::Connection,
}
impl QuicTransportSink {
pub fn new(send_stream: quinn::SendStream, connection: quinn::Connection) -> Self {
Self {
send_stream,
connection,
}
}
}
#[async_trait]
impl TransportSink for QuicTransportSink {
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
// Length-prefix framing: [4-byte big-endian length][payload]
let len = data.len() as u32;
self.send_stream.write_all(&len.to_be_bytes()).await?;
self.send_stream.write_all(&data).await?;
Ok(())
}
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
let max_size = self.connection.max_datagram_size();
match max_size {
Some(max) if data.len() <= max => {
self.connection.send_datagram(data.into())?;
Ok(())
}
_ => {
// Datagram too large or datagrams disabled — fall back to reliable
debug!("Datagram too large ({}B), falling back to reliable stream", data.len());
self.send_reliable(data).await
}
}
}
async fn close(&mut self) -> Result<()> {
self.send_stream.finish()?;
Ok(())
}
}
/// QUIC transport stream — wraps a RecvStream (reliable) and Connection (datagrams).
pub struct QuicTransportStream {
recv_stream: quinn::RecvStream,
connection: quinn::Connection,
}
impl QuicTransportStream {
pub fn new(recv_stream: quinn::RecvStream, connection: quinn::Connection) -> Self {
Self {
recv_stream,
connection,
}
}
}
#[async_trait]
impl TransportStream for QuicTransportStream {
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
// Read length prefix
let mut len_buf = [0u8; 4];
match self.recv_stream.read_exact(&mut len_buf).await {
Ok(()) => {}
Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None),
Err(quinn::ReadExactError::ReadError(quinn::ReadError::ConnectionLost(e))) => {
warn!("QUIC connection lost: {}", e);
return Ok(None);
}
Err(e) => return Err(anyhow::anyhow!("QUIC read error: {}", e)),
}
let len = u32::from_be_bytes(len_buf) as usize;
if len > 65536 {
return Err(anyhow::anyhow!("Frame too large: {} bytes", len));
}
let mut data = vec![0u8; len];
match self.recv_stream.read_exact(&mut data).await {
Ok(()) => Ok(Some(data)),
Err(quinn::ReadExactError::FinishedEarly(_)) => Ok(None),
Err(e) => Err(anyhow::anyhow!("QUIC read error: {}", e)),
}
}
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
match self.connection.read_datagram().await {
Ok(data) => Ok(Some(data.to_vec())),
Err(quinn::ConnectionError::ApplicationClosed(_)) => Ok(None),
Err(quinn::ConnectionError::LocallyClosed) => Ok(None),
Err(e) => Err(anyhow::anyhow!("QUIC datagram error: {}", e)),
}
}
fn supports_datagrams(&self) -> bool {
self.connection.max_datagram_size().is_some()
}
}
/// Accept a QUIC connection and open a bidirectional control stream.
/// Returns the transport sink/stream pair ready for the VPN handshake.
pub async fn accept_quic_connection(
conn: quinn::Connection,
) -> Result<(QuicTransportSink, QuicTransportStream)> {
// The client opens the bidirectional control stream
let (send, recv) = conn.accept_bi().await?;
info!("QUIC bidirectional control stream accepted");
Ok((
QuicTransportSink::new(send, conn.clone()),
QuicTransportStream::new(recv, conn),
))
}
/// Open a QUIC connection's bidirectional control stream (client side).
pub async fn open_quic_streams(
conn: quinn::Connection,
) -> Result<(QuicTransportSink, QuicTransportStream)> {
let (send, recv) = conn.open_bi().await?;
info!("QUIC bidirectional control stream opened");
Ok((
QuicTransportSink::new(send, conn.clone()),
QuicTransportStream::new(recv, conn),
))
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cert_generation_and_hash() {
let (certs, _key) = generate_self_signed_cert().unwrap();
assert_eq!(certs.len(), 1);
let hash = cert_hash(&certs[0]);
// SHA-256 base64 is 44 characters
assert_eq!(hash.len(), 44);
}
#[test]
fn test_cert_hash_deterministic() {
let (certs, _key) = generate_self_signed_cert().unwrap();
let hash1 = cert_hash(&certs[0]);
let hash2 = cert_hash(&certs[0]);
assert_eq!(hash1, hash2);
}
/// Helper: create QUIC server and client endpoints.
fn create_quic_endpoints() -> (quinn::Endpoint, quinn::Endpoint, String) {
let (certs, key) = generate_self_signed_cert().unwrap();
let hash = cert_hash(&certs[0]);
let provider = Arc::new(rustls::crypto::ring::default_provider());
let mut server_tls = rustls::ServerConfig::builder_with_provider(provider.clone())
.with_safe_default_protocol_versions().unwrap()
.with_no_client_auth()
.with_single_cert(certs, key).unwrap();
server_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
let server_qcfg = quinn::ServerConfig::with_crypto(Arc::new(
quinn::crypto::rustls::QuicServerConfig::try_from(server_tls).unwrap(),
));
let server_ep = quinn::Endpoint::server(server_qcfg, "127.0.0.1:0".parse().unwrap()).unwrap();
let mut client_tls = rustls::ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions().unwrap()
.dangerous()
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
expected_hash: hash,
}))
.with_no_client_auth();
client_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
let client_config = quinn::ClientConfig::new(Arc::new(
QuicClientConfig::try_from(client_tls).unwrap(),
));
let mut client_ep = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
client_ep.set_default_client_config(client_config);
let server_addr = server_ep.local_addr().unwrap().to_string();
(server_ep, client_ep, server_addr)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_quic_server_client_roundtrip() {
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
// Server: accept, accept_bi, read, echo, finish
let server_task = tokio::spawn(async move {
let conn = server_ep.accept().await.unwrap().await.unwrap();
let (mut s_send, mut s_recv) = conn.accept_bi().await.unwrap();
let data = s_recv.read_to_end(1024).await.unwrap();
s_send.write_all(&data).await.unwrap();
s_send.finish().unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
server_ep
});
// Client: connect, open_bi, write, finish, read
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
let (mut c_send, mut c_recv) = conn.open_bi().await.unwrap();
c_send.write_all(b"hello quinn").await.unwrap();
c_send.finish().unwrap();
let data = c_recv.read_to_end(1024).await.unwrap();
assert_eq!(&data[..], b"hello quinn");
let _ = server_task.await;
drop(client_ep);
}
/// Test transport trait wrappers over QUIC.
/// Key: client must send data first (QUIC streams are opened implicitly by data).
/// The server accept_bi runs concurrently with the client's first send_reliable.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_quic_transport_trait_roundtrip() {
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
// Server task: accept connection, then accept_bi (blocks until client sends data)
let server_task = tokio::spawn(async move {
let conn = server_ep.accept().await.unwrap().await.unwrap();
let (s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
(s_sink, s_stream, server_ep)
});
// Client: connect, open_bi via wrapper
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
let (mut c_sink, mut c_stream) = open_quic_streams(conn).await.unwrap();
// Client sends first — this triggers the QUIC stream to become visible to the server
c_sink.send_reliable(b"hello-from-client".to_vec()).await.unwrap();
// Now server's accept_bi unblocks
let (mut s_sink, mut s_stream, _sep) = server_task.await.unwrap();
// Server reads the message
let msg = s_stream.recv_reliable().await.unwrap().unwrap();
assert_eq!(msg, b"hello-from-client");
// Server -> Client
s_sink.send_reliable(b"hello-from-server".to_vec()).await.unwrap();
let msg = c_stream.recv_reliable().await.unwrap().unwrap();
assert_eq!(msg, b"hello-from-server");
drop(client_ep);
}
/// Test QUIC datagram support.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_quic_datagram_exchange() {
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
// Server: accept, accept_bi (opens control stream), then read datagram
let server_task = tokio::spawn(async move {
let conn = server_ep.accept().await.unwrap().await.unwrap();
// Accept bi stream (control channel)
let (_s_sink, _s_stream) = accept_quic_connection(conn.clone()).await.unwrap();
// Read datagram
let dgram = conn.read_datagram().await.unwrap();
assert_eq!(&dgram[..], b"dgram-payload");
server_ep
});
// Client: connect, open bi stream (triggers server accept_bi), then send datagram
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
let (mut c_sink, _c_stream) = open_quic_streams(conn.clone()).await.unwrap();
// Send initial data to open the stream (required for QUIC)
c_sink.send_reliable(b"init".to_vec()).await.unwrap();
// Small yield to let the server process the bi stream
tokio::task::yield_now().await;
// Send datagram
assert!(conn.max_datagram_size().is_some());
conn.send_datagram(bytes::Bytes::from_static(b"dgram-payload")).unwrap();
let _ = server_task.await.unwrap();
drop(client_ep);
}
/// Test that supports_datagrams returns true for QUIC transports.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_quic_supports_datagrams() {
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
let server_task = tokio::spawn(async move {
let conn = server_ep.accept().await.unwrap().await.unwrap();
let (_s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
assert!(s_stream.supports_datagrams());
server_ep
});
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
let (mut c_sink, c_stream) = open_quic_streams(conn).await.unwrap();
assert!(c_stream.supports_datagrams());
// Send data to trigger server's accept_bi
c_sink.send_reliable(b"ping".to_vec()).await.unwrap();
let _ = server_task.await.unwrap();
drop(client_ep);
}
}

141
rust/src/ratelimit.rs Normal file
View File

@@ -0,0 +1,141 @@
use std::time::Instant;
/// A token bucket rate limiter operating on byte granularity.
pub struct TokenBucket {
/// Tokens (bytes) added per second.
rate: f64,
/// Maximum burst capacity in bytes.
burst: f64,
/// Currently available tokens.
tokens: f64,
/// Last time tokens were refilled.
last_refill: Instant,
}
impl TokenBucket {
/// Create a new token bucket.
///
/// - `rate_bytes_per_sec`: sustained rate in bytes/second
/// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
let burst = burst_bytes as f64;
Self {
rate: rate_bytes_per_sec as f64,
burst,
tokens: burst, // start full
last_refill: Instant::now(),
}
}
/// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded.
pub fn try_consume(&mut self, bytes: usize) -> bool {
self.refill();
let needed = bytes as f64;
if needed <= self.tokens {
self.tokens -= needed;
true
} else {
false
}
}
/// Update rate and burst limits dynamically (for live IPC reconfiguration).
pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
self.rate = rate_bytes_per_sec as f64;
self.burst = burst_bytes as f64;
// Cap current tokens at new burst
if self.tokens > self.burst {
self.tokens = self.burst;
}
}
fn refill(&mut self) {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
self.last_refill = now;
self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[test]
fn allows_traffic_under_burst() {
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
// Should allow up to burst size immediately
assert!(tb.try_consume(1_500_000));
assert!(tb.try_consume(400_000));
}
#[test]
fn blocks_traffic_over_burst() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
// Consume entire burst
assert!(tb.try_consume(1_000_000));
// Next consume should fail (no time to refill)
assert!(!tb.try_consume(1));
}
#[test]
fn zero_consume_always_succeeds() {
let mut tb = TokenBucket::new(0, 0);
assert!(tb.try_consume(0));
}
#[test]
fn refills_over_time() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
// Drain completely
assert!(tb.try_consume(1_000_000));
assert!(!tb.try_consume(1));
// Wait 100ms — should refill ~100KB
std::thread::sleep(Duration::from_millis(100));
assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s
}
#[test]
fn update_limits_caps_tokens() {
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
// Tokens start at burst (2MB)
tb.update_limits(500_000, 500_000);
// Tokens should be capped to new burst (500KB)
assert!(tb.try_consume(500_000));
assert!(!tb.try_consume(1));
}
#[test]
fn update_limits_changes_rate() {
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
assert!(tb.try_consume(1_000_000)); // drain
// Change to higher rate
tb.update_limits(10_000_000, 10_000_000);
std::thread::sleep(Duration::from_millis(50));
// At 10MB/s, 50ms should refill ~500KB
assert!(tb.try_consume(200_000));
}
#[test]
fn zero_rate_blocks_after_burst() {
let mut tb = TokenBucket::new(0, 100);
assert!(tb.try_consume(100));
std::thread::sleep(Duration::from_millis(10));
// Zero rate means no refill
assert!(!tb.try_consume(1));
}
#[test]
fn tokens_do_not_exceed_burst() {
// Use a low rate so refill between consecutive calls is negligible
let mut tb = TokenBucket::new(100, 1_000);
// Wait to accumulate — but should cap at burst
std::thread::sleep(Duration::from_millis(50));
assert!(tb.try_consume(1_000));
// At 100 bytes/sec, the few μs between calls add ~0 tokens
assert!(!tb.try_consume(1));
}
}

View File

@@ -1,19 +1,25 @@
use anyhow::Result;
use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn};
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::mtu::{MtuConfig, TunnelOverhead};
use crate::network::IpPool;
use crate::ratelimit::TokenBucket;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
/// Server configuration (matches TS IVpnServerConfig).
#[derive(Debug, Clone, Deserialize)]
@@ -29,6 +35,16 @@ pub struct ServerConfig {
pub mtu: Option<u16>,
pub keepalive_interval_secs: Option<u64>,
pub enable_nat: Option<bool>,
/// Default rate limit for new clients (bytes/sec). None = unlimited.
pub default_rate_limit_bytes_per_sec: Option<u64>,
/// Default burst size for new clients (bytes). None = unlimited.
pub default_burst_bytes: Option<u64>,
/// Transport mode: "websocket" (default), "quic", or "both".
pub transport_mode: Option<String>,
/// QUIC listen address (host:port). Defaults to listen_addr.
pub quic_listen_addr: Option<String>,
/// QUIC idle timeout in seconds (default: 30).
pub quic_idle_timeout_secs: Option<u64>,
}
/// Information about a connected client.
@@ -40,6 +56,12 @@ pub struct ClientInfo {
pub connected_since: String,
pub bytes_sent: u64,
pub bytes_received: u64,
pub packets_dropped: u64,
pub bytes_dropped: u64,
pub last_keepalive_at: Option<String>,
pub keepalives_received: u64,
pub rate_limit_bytes_per_sec: Option<u64>,
pub burst_bytes: Option<u64>,
}
/// Server statistics.
@@ -63,6 +85,8 @@ pub struct ServerState {
pub ip_pool: Mutex<IpPool>,
pub clients: RwLock<HashMap<String, ClientInfo>>,
pub stats: RwLock<ServerStatistics>,
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
pub mtu_config: MtuConfig,
pub started_at: std::time::Instant,
}
@@ -98,11 +122,18 @@ impl VpnServer {
}
}
let link_mtu = config.mtu.unwrap_or(1420);
// Compute effective MTU from overhead
let overhead = TunnelOverhead::default_overhead();
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
let state = Arc::new(ServerState {
config: config.clone(),
ip_pool: Mutex::new(ip_pool),
clients: RwLock::new(HashMap::new()),
stats: RwLock::new(ServerStatistics::default()),
rate_limiters: Mutex::new(HashMap::new()),
mtu_config,
started_at: std::time::Instant::now(),
});
@@ -110,14 +141,58 @@ impl VpnServer {
self.state = Some(state.clone());
self.shutdown_tx = Some(shutdown_tx);
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
let listen_addr = config.listen_addr.clone();
tokio::spawn(async move {
if let Err(e) = run_listener(state, listen_addr, &mut shutdown_rx).await {
error!("Server listener error: {}", e);
}
});
info!("VPN server started");
match transport_mode {
"quic" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
error!("QUIC listener error: {}", e);
}
});
}
"both" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
let state2 = state.clone();
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
// Store second shutdown sender so both listeners stop
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
self.shutdown_tx = Some(combined_tx);
// Forward combined shutdown to both listeners
tokio::spawn(async move {
combined_rx.recv().await;
let _ = shutdown_tx_orig.send(()).await;
let _ = shutdown_tx2.send(()).await;
});
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("WebSocket listener error: {}", e);
}
});
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
error!("QUIC listener error: {}", e);
}
});
}
_ => {
// "websocket" (default)
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("Server listener error: {}", e);
}
});
}
}
info!("VPN server started (transport: {})", transport_mode);
Ok(())
}
@@ -166,14 +241,57 @@ impl VpnServer {
if let Some(client) = clients.remove(client_id) {
let ip: Ipv4Addr = client.assigned_ip.parse()?;
state.ip_pool.lock().await.release(&ip);
state.rate_limiters.lock().await.remove(client_id);
info!("Client {} disconnected", client_id);
}
}
Ok(())
}
/// Set a rate limit for a specific client.
pub async fn set_client_rate_limit(
&self,
client_id: &str,
rate_bytes_per_sec: u64,
burst_bytes: u64,
) -> Result<()> {
if let Some(ref state) = self.state {
let mut limiters = state.rate_limiters.lock().await;
if let Some(limiter) = limiters.get_mut(client_id) {
limiter.update_limits(rate_bytes_per_sec, burst_bytes);
} else {
limiters.insert(
client_id.to_string(),
TokenBucket::new(rate_bytes_per_sec, burst_bytes),
);
}
// Update client info
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(client_id) {
info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec);
info.burst_bytes = Some(burst_bytes);
}
}
Ok(())
}
/// Remove rate limit for a specific client (unlimited).
pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> {
if let Some(ref state) = self.state {
state.rate_limiters.lock().await.remove(client_id);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(client_id) {
info.rate_limit_bytes_per_sec = None;
info.burst_bytes = None;
}
}
Ok(())
}
}
async fn run_listener(
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
/// to the transport-agnostic `handle_client_connection`.
async fn run_ws_listener(
state: Arc<ServerState>,
listen_addr: String,
shutdown_rx: &mut mpsc::Receiver<()>,
@@ -189,8 +307,20 @@ async fn run_listener(
info!("New connection from {}", addr);
let state = state.clone();
tokio::spawn(async move {
if let Err(e) = handle_client_connection(state, stream).await {
warn!("Client connection error: {}", e);
match transport::accept_connection(stream).await {
Ok(ws) => {
let (sink, stream) = transport_trait::split_ws(ws);
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
).await {
warn!("Client connection error: {}", e);
}
}
Err(e) => {
warn!("WebSocket upgrade failed: {}", e);
}
}
});
}
@@ -209,13 +339,95 @@ async fn run_listener(
Ok(())
}
/// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic
/// `handle_client_connection`.
async fn run_quic_listener(
state: Arc<ServerState>,
listen_addr: String,
idle_timeout_secs: u64,
shutdown_rx: &mut mpsc::Receiver<()>,
) -> Result<()> {
// Generate or use configured TLS certificate for QUIC
let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) =
(&state.config.tls_cert, &state.config.tls_key)
{
// Parse PEM certificates
let certs: Vec<rustls_pki_types::CertificateDer<'static>> =
rustls_pemfile::certs(&mut cert_pem.as_bytes())
.collect::<Result<Vec<_>, _>>()?;
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
.ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
(certs, key)
} else {
// Generate self-signed certificate
let (certs, key) = quic_transport::generate_self_signed_cert()?;
info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0]));
(certs, key)
};
let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig {
listen_addr,
cert_chain,
private_key,
idle_timeout_secs,
})?;
loop {
tokio::select! {
incoming = endpoint.accept() => {
match incoming {
Some(incoming) => {
let state = state.clone();
tokio::spawn(async move {
match incoming.await {
Ok(conn) => {
let remote = conn.remote_address();
info!("New QUIC connection from {}", remote);
match quic_transport::accept_quic_connection(conn).await {
Ok((sink, stream)) => {
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
).await {
warn!("QUIC client error: {}", e);
}
}
Err(e) => {
warn!("QUIC stream accept failed: {}", e);
}
}
}
Err(e) => {
warn!("QUIC handshake failed: {}", e);
}
}
});
}
None => {
info!("QUIC endpoint closed");
break;
}
}
}
_ = shutdown_rx.recv() => {
info!("QUIC shutdown signal received");
endpoint.close(0u32.into(), b"shutdown");
break;
}
}
}
Ok(())
}
/// Transport-agnostic client handler. Performs the Noise NK handshake, registers
/// the client, and runs the main packet forwarding loop.
async fn handle_client_connection(
state: Arc<ServerState>,
stream: tokio::net::TcpStream,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
) -> Result<()> {
let ws = transport::accept_connection(stream).await?;
let (mut ws_sink, mut ws_stream) = ws.split();
let client_id = uuid_v4();
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
@@ -229,9 +441,9 @@ async fn handle_client_connection(
let mut buf = vec![0u8; 65535];
// Receive handshake init
let init_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
_ => anyhow::bail!("Expected handshake init message"),
let init_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before handshake"),
};
let mut frame_buf = BytesMut::from(&init_msg[..]);
@@ -252,30 +464,48 @@ async fn handle_client_connection(
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
let mut noise_transport = responder.into_transport_mode()?;
// Register client
let default_rate = state.config.default_rate_limit_bytes_per_sec;
let default_burst = state.config.default_burst_bytes;
let client_info = ClientInfo {
client_id: client_id.clone(),
assigned_ip: assigned_ip.to_string(),
connected_since: timestamp_now(),
bytes_sent: 0,
bytes_received: 0,
packets_dropped: 0,
bytes_dropped: 0,
last_keepalive_at: None,
keepalives_received: 0,
rate_limit_bytes_per_sec: default_rate,
burst_bytes: default_burst,
};
state.clients.write().await.insert(client_id.clone(), client_info);
// Set up rate limiter if defaults are configured
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
state
.rate_limiters
.lock()
.await
.insert(client_id.clone(), TokenBucket::new(rate, burst));
}
{
let mut stats = state.stats.write().await;
stats.total_connections += 1;
}
// Send assigned IP info (encrypted)
// Send assigned IP info (encrypted), include effective MTU
let ip_info = serde_json::json!({
"assignedIp": assigned_ip.to_string(),
"gateway": state.ip_pool.lock().await.gateway_addr().to_string(),
"mtu": state.config.mtu.unwrap_or(1420),
"effectiveMtu": state.mtu_config.effective_mtu,
});
let ip_info_bytes = serde_json::to_vec(&ip_info)?;
let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?;
@@ -285,70 +515,112 @@ async fn handle_client_connection(
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
info!("Client {} connected with IP {}", client_id, assigned_ip);
// Main packet loop
// Main packet loop with dead-peer detection
let mut last_activity = tokio::time::Instant::now();
loop {
match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => {
let mut frame_buf = BytesMut::from(&data[..][..]);
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
Ok(Some(frame)) => match frame.packet_type {
PacketType::IpPacket => {
match noise_transport.read_message(&frame.payload, &mut buf) {
Ok(len) => {
let mut stats = state.stats.write().await;
stats.bytes_received += len as u64;
stats.packets_received += 1;
tokio::select! {
msg = stream.recv_reliable() => {
match msg {
Ok(Some(data)) => {
last_activity = tokio::time::Instant::now();
let mut frame_buf = BytesMut::from(&data[..]);
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
Ok(Some(frame)) => match frame.packet_type {
PacketType::IpPacket => {
match noise_transport.read_message(&frame.payload, &mut buf) {
Ok(len) => {
// Rate limiting check
let allowed = {
let mut limiters = state.rate_limiters.lock().await;
if let Some(limiter) = limiters.get_mut(&client_id) {
limiter.try_consume(len)
} else {
true
}
};
if !allowed {
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.packets_dropped += 1;
info.bytes_dropped += len as u64;
}
continue;
}
let mut stats = state.stats.write().await;
stats.bytes_received += len as u64;
stats.packets_received += 1;
// Update per-client stats
drop(stats);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.bytes_received += len as u64;
}
}
Err(e) => {
warn!("Decrypt error from {}: {}", client_id, e);
break;
}
}
}
Err(e) => {
warn!("Decrypt error from {}: {}", client_id, e);
PacketType::Keepalive => {
// Echo the keepalive payload back in the ACK
let ack_frame = Frame {
packet_type: PacketType::KeepaliveAck,
payload: frame.payload.clone(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
sink.send_reliable(frame_bytes.to_vec()).await?;
let mut stats = state.stats.write().await;
stats.keepalives_received += 1;
stats.keepalives_sent += 1;
// Update per-client keepalive tracking
drop(stats);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.last_keepalive_at = Some(timestamp_now());
info.keepalives_received += 1;
}
}
PacketType::Disconnect => {
info!("Client {} sent disconnect", client_id);
break;
}
_ => {
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
}
},
Ok(None) => {
warn!("Incomplete frame from {}", client_id);
}
Err(e) => {
warn!("Frame decode error from {}: {}", client_id, e);
break;
}
}
PacketType::Keepalive => {
let ack_frame = Frame {
packet_type: PacketType::KeepaliveAck,
payload: vec![],
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
let mut stats = state.stats.write().await;
stats.keepalives_received += 1;
stats.keepalives_sent += 1;
}
PacketType::Disconnect => {
info!("Client {} sent disconnect", client_id);
break;
}
_ => {
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
}
},
}
Ok(None) => {
warn!("Incomplete frame from {}", client_id);
info!("Client {} connection closed", client_id);
break;
}
Err(e) => {
warn!("Frame decode error from {}: {}", client_id, e);
warn!("Transport error from {}: {}", client_id, e);
break;
}
}
}
Some(Ok(Message::Close(_))) | None => {
info!("Client {} connection closed", client_id);
break;
}
Some(Ok(Message::Ping(data))) => {
ws_sink.send(Message::Pong(data)).await?;
}
Some(Ok(_)) => continue,
Some(Err(e)) => {
warn!("WebSocket error from {}: {}", client_id, e);
_ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => {
warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs());
break;
}
}
@@ -357,6 +629,7 @@ async fn handle_client_connection(
// Cleanup
state.clients.write().await.remove(&client_id);
state.ip_pool.lock().await.release(&assigned_ip);
state.rate_limiters.lock().await.remove(&client_id);
info!("Client {} disconnected, released IP {}", client_id, assigned_ip);
Ok(())

317
rust/src/telemetry.rs Normal file
View File

@@ -0,0 +1,317 @@
use serde::Serialize;
use std::collections::VecDeque;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// A single RTT sample.
#[derive(Debug, Clone)]
struct RttSample {
_rtt: Duration,
_timestamp: Instant,
was_timeout: bool,
}
/// Snapshot of connection quality metrics.
#[derive(Debug, Clone, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ConnectionQuality {
/// Smoothed RTT in milliseconds (EMA, RFC 6298 style).
pub srtt_ms: f64,
/// Jitter in milliseconds (mean deviation of RTT).
pub jitter_ms: f64,
/// Minimum RTT observed in the sample window.
pub min_rtt_ms: f64,
/// Maximum RTT observed in the sample window.
pub max_rtt_ms: f64,
/// Packet loss ratio over the sample window (0.0 - 1.0).
pub loss_ratio: f64,
/// Number of consecutive keepalive timeouts (0 if last succeeded).
pub consecutive_timeouts: u32,
/// Total keepalives sent.
pub keepalives_sent: u64,
/// Total keepalive ACKs received.
pub keepalives_acked: u64,
}
/// Tracks connection quality from keepalive round-trips.
pub struct RttTracker {
/// Maximum number of samples to keep in the window.
max_samples: usize,
/// Recent RTT samples (including timeout markers).
samples: VecDeque<RttSample>,
/// When the last keepalive was sent (for computing RTT on ACK).
pending_ping_sent_at: Option<Instant>,
/// Number of consecutive keepalive timeouts.
pub consecutive_timeouts: u32,
/// Smoothed RTT (EMA).
srtt: Option<f64>,
/// Jitter (mean deviation).
jitter: f64,
/// Minimum RTT observed.
min_rtt: f64,
/// Maximum RTT observed.
max_rtt: f64,
/// Total keepalives sent.
keepalives_sent: u64,
/// Total keepalive ACKs received.
keepalives_acked: u64,
/// Previous RTT sample for jitter calculation.
last_rtt_ms: Option<f64>,
}
impl RttTracker {
/// Create a new tracker with the given window size.
pub fn new(max_samples: usize) -> Self {
Self {
max_samples,
samples: VecDeque::with_capacity(max_samples),
pending_ping_sent_at: None,
consecutive_timeouts: 0,
srtt: None,
jitter: 0.0,
min_rtt: f64::MAX,
max_rtt: 0.0,
keepalives_sent: 0,
keepalives_acked: 0,
last_rtt_ms: None,
}
}
/// Record that a keepalive was sent.
/// Returns a millisecond timestamp (since UNIX epoch) to embed in the keepalive payload.
pub fn mark_ping_sent(&mut self) -> u64 {
self.pending_ping_sent_at = Some(Instant::now());
self.keepalives_sent += 1;
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
/// Record that a keepalive ACK was received with the echoed timestamp.
/// Returns the computed RTT if a pending ping was recorded.
pub fn record_ack(&mut self, _echoed_timestamp_ms: u64) -> Option<Duration> {
let sent_at = self.pending_ping_sent_at.take()?;
let rtt = sent_at.elapsed();
let rtt_ms = rtt.as_secs_f64() * 1000.0;
self.keepalives_acked += 1;
self.consecutive_timeouts = 0;
// Update SRTT (RFC 6298: alpha = 1/8)
match self.srtt {
None => {
self.srtt = Some(rtt_ms);
self.jitter = rtt_ms / 2.0;
}
Some(prev_srtt) => {
// RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| (beta = 1/4)
self.jitter = 0.75 * self.jitter + 0.25 * (prev_srtt - rtt_ms).abs();
// SRTT = (1 - alpha) * SRTT + alpha * R (alpha = 1/8)
self.srtt = Some(0.875 * prev_srtt + 0.125 * rtt_ms);
}
}
// Update min/max
if rtt_ms < self.min_rtt {
self.min_rtt = rtt_ms;
}
if rtt_ms > self.max_rtt {
self.max_rtt = rtt_ms;
}
self.last_rtt_ms = Some(rtt_ms);
// Push sample into window
if self.samples.len() >= self.max_samples {
self.samples.pop_front();
}
self.samples.push_back(RttSample {
_rtt: rtt,
_timestamp: Instant::now(),
was_timeout: false,
});
Some(rtt)
}
/// Record that a keepalive timed out (no ACK received).
pub fn record_timeout(&mut self) {
self.consecutive_timeouts += 1;
self.pending_ping_sent_at = None;
if self.samples.len() >= self.max_samples {
self.samples.pop_front();
}
self.samples.push_back(RttSample {
_rtt: Duration::ZERO,
_timestamp: Instant::now(),
was_timeout: true,
});
}
/// Get a snapshot of the current connection quality.
pub fn snapshot(&self) -> ConnectionQuality {
let loss_ratio = if self.samples.is_empty() {
0.0
} else {
let timeouts = self.samples.iter().filter(|s| s.was_timeout).count();
timeouts as f64 / self.samples.len() as f64
};
ConnectionQuality {
srtt_ms: self.srtt.unwrap_or(0.0),
jitter_ms: self.jitter,
min_rtt_ms: if self.min_rtt == f64::MAX { 0.0 } else { self.min_rtt },
max_rtt_ms: self.max_rtt,
loss_ratio,
consecutive_timeouts: self.consecutive_timeouts,
keepalives_sent: self.keepalives_sent,
keepalives_acked: self.keepalives_acked,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_tracker_has_zero_quality() {
let tracker = RttTracker::new(30);
let q = tracker.snapshot();
assert_eq!(q.srtt_ms, 0.0);
assert_eq!(q.jitter_ms, 0.0);
assert_eq!(q.loss_ratio, 0.0);
assert_eq!(q.consecutive_timeouts, 0);
assert_eq!(q.keepalives_sent, 0);
assert_eq!(q.keepalives_acked, 0);
}
#[test]
fn mark_ping_returns_timestamp() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
// Should be a reasonable epoch-ms value (after 2020)
assert!(ts > 1_577_836_800_000);
assert_eq!(tracker.keepalives_sent, 1);
}
#[test]
fn record_ack_computes_rtt() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(5));
let rtt = tracker.record_ack(ts);
assert!(rtt.is_some());
let rtt = rtt.unwrap();
assert!(rtt.as_millis() >= 4); // at least ~5ms minus scheduling jitter
assert_eq!(tracker.keepalives_acked, 1);
assert_eq!(tracker.consecutive_timeouts, 0);
}
#[test]
fn record_ack_without_pending_returns_none() {
let mut tracker = RttTracker::new(30);
assert!(tracker.record_ack(12345).is_none());
}
#[test]
fn srtt_converges() {
let mut tracker = RttTracker::new(30);
// Simulate several ping/ack cycles with ~10ms RTT
for _ in 0..10 {
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(10));
tracker.record_ack(ts);
}
let q = tracker.snapshot();
// SRTT should be roughly 10ms (allowing for scheduling variance)
assert!(q.srtt_ms > 5.0, "SRTT too low: {}", q.srtt_ms);
assert!(q.srtt_ms < 50.0, "SRTT too high: {}", q.srtt_ms);
}
#[test]
fn timeout_increments_counter_and_loss() {
let mut tracker = RttTracker::new(30);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 1);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 2);
let q = tracker.snapshot();
assert_eq!(q.loss_ratio, 1.0); // 2 timeouts out of 2 samples
}
#[test]
fn ack_resets_consecutive_timeouts() {
let mut tracker = RttTracker::new(30);
tracker.mark_ping_sent();
tracker.record_timeout();
assert_eq!(tracker.consecutive_timeouts, 1);
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
assert_eq!(tracker.consecutive_timeouts, 0);
}
#[test]
fn loss_ratio_over_mixed_window() {
let mut tracker = RttTracker::new(30);
// 3 successful, 1 timeout, 1 successful = 1/5 = 0.2 loss
for _ in 0..3 {
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
}
tracker.mark_ping_sent();
tracker.record_timeout();
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
let q = tracker.snapshot();
assert!((q.loss_ratio - 0.2).abs() < 0.01);
}
#[test]
fn window_evicts_old_samples() {
let mut tracker = RttTracker::new(5);
// Fill window with 5 timeouts
for _ in 0..5 {
tracker.mark_ping_sent();
tracker.record_timeout();
}
assert_eq!(tracker.snapshot().loss_ratio, 1.0);
// Add 5 successes — should evict all timeouts
for _ in 0..5 {
let ts = tracker.mark_ping_sent();
tracker.record_ack(ts);
}
assert_eq!(tracker.snapshot().loss_ratio, 0.0);
}
#[test]
fn min_max_rtt_tracked() {
let mut tracker = RttTracker::new(30);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(5));
tracker.record_ack(ts);
let ts = tracker.mark_ping_sent();
std::thread::sleep(Duration::from_millis(15));
tracker.record_ack(ts);
let q = tracker.snapshot();
assert!(q.min_rtt_ms < q.max_rtt_ms);
assert!(q.min_rtt_ms > 0.0);
}
}

116
rust/src/transport_trait.rs Normal file
View File

@@ -0,0 +1,116 @@
use anyhow::Result;
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message;
use crate::transport::WsStream;
// ============================================================================
// Transport trait abstraction
// ============================================================================
/// Outbound half of a VPN transport connection.
#[async_trait]
pub trait TransportSink: Send + 'static {
/// Send a framed binary message on the reliable channel.
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()>;
/// Send a datagram (unreliable, best-effort).
/// Falls back to reliable if the transport does not support datagrams.
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()>;
/// Gracefully close the transport.
async fn close(&mut self) -> Result<()>;
}
/// Inbound half of a VPN transport connection.
#[async_trait]
pub trait TransportStream: Send + 'static {
/// Receive the next reliable binary message. Returns `None` on close.
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>>;
/// Receive the next datagram. Returns `None` if datagrams are unsupported
/// or the connection is closed.
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>>;
/// Whether this transport supports unreliable datagrams.
fn supports_datagrams(&self) -> bool;
}
// ============================================================================
// WebSocket implementation
// ============================================================================
/// WebSocket transport sink (wraps the write half of a split WsStream).
pub struct WsTransportSink {
inner: futures_util::stream::SplitSink<WsStream, Message>,
}
impl WsTransportSink {
pub fn new(inner: futures_util::stream::SplitSink<WsStream, Message>) -> Self {
Self { inner }
}
}
#[async_trait]
impl TransportSink for WsTransportSink {
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
self.inner.send(Message::Binary(data.into())).await?;
Ok(())
}
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
// WebSocket has no datagram support — fall back to reliable.
self.send_reliable(data).await
}
async fn close(&mut self) -> Result<()> {
self.inner.close().await?;
Ok(())
}
}
/// WebSocket transport stream (wraps the read half of a split WsStream).
pub struct WsTransportStream {
inner: futures_util::stream::SplitStream<WsStream>,
}
impl WsTransportStream {
pub fn new(inner: futures_util::stream::SplitStream<WsStream>) -> Self {
Self { inner }
}
}
#[async_trait]
impl TransportStream for WsTransportStream {
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
loop {
match self.inner.next().await {
Some(Ok(Message::Binary(data))) => return Ok(Some(data.to_vec())),
Some(Ok(Message::Close(_))) | None => return Ok(None),
Some(Ok(Message::Ping(_))) => {
// Ping handling is done at the tungstenite layer automatically
// when the sink side is alive. Just skip here.
continue;
}
Some(Ok(_)) => continue,
Some(Err(e)) => return Err(anyhow::anyhow!("WebSocket error: {}", e)),
}
}
}
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
// WebSocket does not support datagrams.
Ok(None)
}
fn supports_datagrams(&self) -> bool {
false
}
}
/// Split a WebSocket stream into transport sink and stream halves.
pub fn split_ws(ws: WsStream) -> (WsTransportSink, WsTransportStream) {
let (sink, stream) = ws.split();
(WsTransportSink::new(sink), WsTransportStream::new(stream))
}

View File

@@ -64,6 +64,22 @@ pub async fn add_route(subnet: &str, device_name: &str) -> Result<()> {
Ok(())
}
/// Action to take after checking a packet against the MTU.
pub enum TunMtuAction {
/// Packet is within MTU limits, forward it.
Forward,
/// Packet is oversized; the Vec contains the ICMP too-big message to write back into TUN.
IcmpTooBig(Vec<u8>),
}
/// Check a TUN packet against the MTU and return the appropriate action.
pub fn check_tun_mtu(packet: &[u8], mtu_config: &crate::mtu::MtuConfig) -> TunMtuAction {
match crate::mtu::check_mtu(packet, mtu_config) {
crate::mtu::MtuAction::Forward => TunMtuAction::Forward,
crate::mtu::MtuAction::SendIcmpTooBig(icmp) => TunMtuAction::IcmpTooBig(icmp),
}
}
/// Remove a route.
pub async fn remove_route(subnet: &str, device_name: &str) -> Result<()> {
let output = tokio::process::Command::new("ip")

1329
rust/src/wireguard.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,271 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import { VpnClient, VpnServer } from '../ts/index.js';
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
async function findFreePort(): Promise<number> {
const server = net.createServer();
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
const port = (server.address() as net.AddressInfo).port;
await new Promise<void>((resolve) => server.close(() => resolve()));
return port;
}
function delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
async function waitFor(
fn: () => Promise<boolean>,
timeoutMs: number = 10000,
pollMs: number = 500,
): Promise<void> {
const deadline = Date.now() + timeoutMs;
while (Date.now() < deadline) {
if (await fn()) return;
await delay(pollMs);
}
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
}
// ---------------------------------------------------------------------------
// Test state
// ---------------------------------------------------------------------------
let server: VpnServer;
let serverPort: number;
let keypair: IVpnKeypair;
let client: VpnClient;
const extraClients: VpnClient[] = [];
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
tap.test('setup: start VPN server', async () => {
serverPort = await findFreePort();
const options: IVpnServerOptions = {
transport: { transport: 'stdio' },
};
server = new VpnServer(options);
// Phase 1: start the daemon bridge
const started = await server['bridge'].start();
expect(started).toBeTrue();
expect(server.running).toBeTrue();
// Phase 2: generate a keypair
keypair = await server.generateKeypair();
expect(keypair.publicKey).toBeTypeofString();
expect(keypair.privateKey).toBeTypeofString();
// Phase 3: start the VPN listener
const serverConfig: IVpnServerConfig = {
listenAddr: `127.0.0.1:${serverPort}`,
privateKey: keypair.privateKey,
publicKey: keypair.publicKey,
subnet: '10.8.0.0/24',
};
await server['bridge'].sendCommand('start', { config: serverConfig });
// Verify server is now running
const status = await server.getStatus();
expect(status.state).toEqual('connected');
});
tap.test('single client connects and gets IP', async () => {
const options: IVpnClientOptions = {
transport: { transport: 'stdio' },
};
client = new VpnClient(options);
const started = await client.start();
expect(started).toBeTrue();
const result = await client.connect({
serverUrl: `ws://127.0.0.1:${serverPort}`,
serverPublicKey: keypair.publicKey,
keepaliveIntervalSecs: 3,
});
expect(result.assignedIp).toBeTypeofString();
expect(result.assignedIp).toStartWith('10.8.0.');
// Verify client status
const clientStatus = await client.getStatus();
expect(clientStatus.state).toEqual('connected');
// Verify server sees the client
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 1;
});
const clients = await server.listClients();
expect(clients.length).toEqual(1);
expect(clients[0].assignedIp).toEqual(result.assignedIp);
});
tap.test('keepalive exchange', async () => {
// Wait for at least 2 keepalive cycles (interval=3s, so 8s should be enough)
await delay(8000);
const clientStats = await client.getStatistics();
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
const serverStats = await server.getStatistics();
expect(serverStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
expect(serverStats.keepalivesSent).toBeGreaterThanOrEqual(1);
// Verify per-client keepalive tracking
const clients = await server.listClients();
expect(clients[0].keepalivesReceived).toBeGreaterThanOrEqual(1);
});
tap.test('connection quality telemetry', async () => {
const quality = await client.getConnectionQuality();
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
expect(quality.jitterMs).toBeTypeofNumber();
expect(quality.minRttMs).toBeGreaterThanOrEqual(0);
expect(quality.maxRttMs).toBeGreaterThanOrEqual(0);
expect(quality.lossRatio).toBeTypeofNumber();
expect(['healthy', 'degraded', 'critical']).toContain(quality.linkHealth);
});
tap.test('rate limiting: set and verify', async () => {
const clients = await server.listClients();
const clientId = clients[0].clientId;
// Set a tight rate limit
await server.setClientRateLimit(clientId, 100, 100);
// Verify via telemetry
const telemetry = await server.getClientTelemetry(clientId);
expect(telemetry.rateLimitBytesPerSec).toEqual(100);
expect(telemetry.burstBytes).toEqual(100);
expect(telemetry.clientId).toEqual(clientId);
});
tap.test('rate limiting: removal', async () => {
const clients = await server.listClients();
const clientId = clients[0].clientId;
await server.removeClientRateLimit(clientId);
// Verify telemetry no longer shows rate limit
const telemetry = await server.getClientTelemetry(clientId);
expect(telemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
expect(telemetry.burstBytes).toBeNullOrUndefined();
// Connection still healthy
const status = await client.getStatus();
expect(status.state).toEqual('connected');
});
tap.test('5 concurrent clients', async () => {
const assignedIps = new Set<string>();
// Get the first client's IP
const existingClients = await server.listClients();
assignedIps.add(existingClients[0].assignedIp);
for (let i = 0; i < 5; i++) {
const c = new VpnClient({ transport: { transport: 'stdio' } });
await c.start();
const result = await c.connect({
serverUrl: `ws://127.0.0.1:${serverPort}`,
serverPublicKey: keypair.publicKey,
keepaliveIntervalSecs: 3,
});
expect(result.assignedIp).toStartWith('10.8.0.');
assignedIps.add(result.assignedIp);
extraClients.push(c);
}
// All IPs should be unique (6 total: original + 5 new)
expect(assignedIps.size).toEqual(6);
// Server should see 6 clients
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 6;
});
const allClients = await server.listClients();
expect(allClients.length).toEqual(6);
});
tap.test('client disconnect tracking', async () => {
// Disconnect 3 of the 5 extra clients
for (let i = 0; i < 3; i++) {
const c = extraClients[i];
await c.disconnect();
c.stop();
}
// Wait for server to detect disconnections
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 3;
}, 15000);
const clients = await server.listClients();
expect(clients.length).toEqual(3);
const stats = await server.getStatistics();
expect(stats.totalConnections).toBeGreaterThanOrEqual(6);
});
tap.test('server-side client disconnection', async () => {
const clients = await server.listClients();
// Pick one of the remaining extra clients (not the original)
const targetClient = clients.find((c) => {
// Find a client that belongs to extraClients[3] or extraClients[4]
return c.clientId !== clients[0].clientId;
});
expect(targetClient).toBeTruthy();
await server.disconnectClient(targetClient!.clientId);
// Wait for server to update
await waitFor(async () => {
const remaining = await server.listClients();
return remaining.length === 2;
});
const remaining = await server.listClients();
expect(remaining.length).toEqual(2);
});
tap.test('teardown: stop all', async () => {
// Stop the original client
await client.disconnect();
client.stop();
// Stop remaining extra clients
for (const c of extraClients) {
if (c.running) {
try {
await c.disconnect();
} catch {
// May already be disconnected
}
c.stop();
}
}
await delay(500);
// Stop the server
await server.stopServer();
server.stop();
await delay(500);
expect(server.running).toBeFalse();
});
export default tap.start();

357
test/test.loadtest.node.ts Normal file
View File

@@ -0,0 +1,357 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import * as stream from 'stream';
import { VpnClient, VpnServer } from '../ts/index.js';
import type { IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
async function findFreePort(): Promise<number> {
const server = net.createServer();
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
const port = (server.address() as net.AddressInfo).port;
await new Promise<void>((resolve) => server.close(() => resolve()));
return port;
}
function delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
async function waitFor(
fn: () => Promise<boolean>,
timeoutMs: number = 10000,
pollMs: number = 500,
): Promise<void> {
const deadline = Date.now() + timeoutMs;
while (Date.now() < deadline) {
if (await fn()) return;
await delay(pollMs);
}
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
}
// ---------------------------------------------------------------------------
// ThrottleProxy (adapted from remoteingress)
// ---------------------------------------------------------------------------
class ThrottleTransform extends stream.Transform {
private bytesPerSec: number;
private bucket: number;
private lastRefill: number;
private destroyed_: boolean = false;
constructor(bytesPerSecond: number) {
super();
this.bytesPerSec = bytesPerSecond;
this.bucket = bytesPerSecond;
this.lastRefill = Date.now();
}
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: stream.TransformCallback) {
if (this.destroyed_) return;
const now = Date.now();
const elapsed = (now - this.lastRefill) / 1000;
this.bucket = Math.min(this.bytesPerSec, this.bucket + elapsed * this.bytesPerSec);
this.lastRefill = now;
if (chunk.length <= this.bucket) {
this.bucket -= chunk.length;
callback(null, chunk);
} else {
const deficit = chunk.length - this.bucket;
this.bucket = 0;
const delayMs = Math.min((deficit / this.bytesPerSec) * 1000, 1000);
setTimeout(() => {
if (this.destroyed_) { callback(); return; }
this.lastRefill = Date.now();
this.bucket = 0;
callback(null, chunk);
}, delayMs);
}
}
_destroy(err: Error | null, callback: (error: Error | null) => void) {
this.destroyed_ = true;
callback(err);
}
}
interface ThrottleProxy {
server: net.Server;
close: () => Promise<void>;
}
async function startThrottleProxy(
listenPort: number,
targetHost: string,
targetPort: number,
bytesPerSecond: number,
): Promise<ThrottleProxy> {
const connections = new Set<net.Socket>();
const server = net.createServer((clientSock) => {
connections.add(clientSock);
const upstream = net.createConnection({ host: targetHost, port: targetPort });
connections.add(upstream);
const throttleUp = new ThrottleTransform(bytesPerSecond);
const throttleDown = new ThrottleTransform(bytesPerSecond);
clientSock.pipe(throttleUp).pipe(upstream);
upstream.pipe(throttleDown).pipe(clientSock);
let cleaned = false;
const cleanup = () => {
if (cleaned) return;
cleaned = true;
throttleUp.destroy();
throttleDown.destroy();
clientSock.destroy();
upstream.destroy();
connections.delete(clientSock);
connections.delete(upstream);
};
clientSock.on('error', () => cleanup());
upstream.on('error', () => cleanup());
throttleUp.on('error', () => cleanup());
throttleDown.on('error', () => cleanup());
clientSock.on('close', () => cleanup());
upstream.on('close', () => cleanup());
});
await new Promise<void>((resolve) => server.listen(listenPort, '127.0.0.1', resolve));
return {
server,
close: async () => {
for (const c of connections) c.destroy();
connections.clear();
await new Promise<void>((resolve) => server.close(() => resolve()));
},
};
}
// ---------------------------------------------------------------------------
// Test state
// ---------------------------------------------------------------------------
let server: VpnServer;
let serverPort: number;
let proxyPort: number;
let keypair: IVpnKeypair;
let throttle: ThrottleProxy;
const allClients: VpnClient[] = [];
async function createConnectedClient(port: number): Promise<VpnClient> {
const c = new VpnClient({ transport: { transport: 'stdio' } });
await c.start();
await c.connect({
serverUrl: `ws://127.0.0.1:${port}`,
serverPublicKey: keypair.publicKey,
keepaliveIntervalSecs: 3,
});
allClients.push(c);
return c;
}
async function stopClient(c: VpnClient): Promise<void> {
if (c.running) {
try { await c.disconnect(); } catch { /* already disconnected */ }
c.stop();
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
tap.test('setup: start throttled VPN tunnel (1 MB/s)', async () => {
serverPort = await findFreePort();
proxyPort = await findFreePort();
// Start VPN server
server = new VpnServer({ transport: { transport: 'stdio' } });
const started = await server['bridge'].start();
expect(started).toBeTrue();
keypair = await server.generateKeypair();
const serverConfig: IVpnServerConfig = {
listenAddr: `127.0.0.1:${serverPort}`,
privateKey: keypair.privateKey,
publicKey: keypair.publicKey,
subnet: '10.8.0.0/24',
};
await server['bridge'].sendCommand('start', { config: serverConfig });
const status = await server.getStatus();
expect(status.state).toEqual('connected');
// Start throttle proxy: 1 MB/s
throttle = await startThrottleProxy(proxyPort, '127.0.0.1', serverPort, 1024 * 1024);
});
tap.test('throttled connection: handshake succeeds through throttle', async () => {
const client = await createConnectedClient(proxyPort);
const status = await client.getStatus();
expect(status.state).toEqual('connected');
expect(status.assignedIp).toStartWith('10.8.0.');
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 1;
});
});
tap.test('sustained keepalive under throttle', async () => {
// Wait for at least 2 keepalive cycles (3s interval)
await delay(8000);
const client = allClients[0];
const stats = await client.getStatistics();
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
// Throttle adds latency — RTT should be measurable
const quality = await client.getConnectionQuality();
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
expect(quality.jitterMs).toBeTypeofNumber();
});
tap.test('3 concurrent throttled clients', async () => {
for (let i = 0; i < 3; i++) {
await createConnectedClient(proxyPort);
}
// All 4 clients should be visible
await waitFor(async () => {
const clients = await server.listClients();
return clients.length === 4;
});
const clients = await server.listClients();
expect(clients.length).toEqual(4);
// Verify all IPs are unique
const ips = new Set(clients.map((c) => c.assignedIp));
expect(ips.size).toEqual(4);
});
tap.test('rate limiting combined with network throttle', async () => {
const clients = await server.listClients();
const targetId = clients[0].clientId;
// Set rate limit on first client
await server.setClientRateLimit(targetId, 500, 500);
const telemetry = await server.getClientTelemetry(targetId);
expect(telemetry.rateLimitBytesPerSec).toEqual(500);
expect(telemetry.burstBytes).toEqual(500);
// Verify another client has no rate limit
const otherTelemetry = await server.getClientTelemetry(clients[1].clientId);
expect(otherTelemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
// Clean up the rate limit
await server.removeClientRateLimit(targetId);
});
tap.test('burst waves: 3 waves of 3 clients', async () => {
const initialCount = (await server.listClients()).length;
for (let wave = 0; wave < 3; wave++) {
const waveClients: VpnClient[] = [];
// Connect 3 clients
for (let i = 0; i < 3; i++) {
const c = await createConnectedClient(proxyPort);
waveClients.push(c);
}
// Verify all connected
await waitFor(async () => {
const all = await server.listClients();
return all.length === initialCount + 3;
});
// Disconnect all wave clients
for (const c of waveClients) {
await stopClient(c);
}
// Wait for server to detect disconnections
await waitFor(async () => {
const all = await server.listClients();
return all.length === initialCount;
}, 15000);
await delay(500);
}
// Verify total connections accumulated
const stats = await server.getStatistics();
expect(stats.totalConnections).toBeGreaterThanOrEqual(9 + initialCount);
// Original clients still connected
const remaining = await server.listClients();
expect(remaining.length).toEqual(initialCount);
});
tap.test('aggressive throttle: 10 KB/s', async () => {
// Close current throttle proxy and start an aggressive one
await throttle.close();
const aggressivePort = await findFreePort();
throttle = await startThrottleProxy(aggressivePort, '127.0.0.1', serverPort, 10 * 1024);
// Connect a client through the aggressive throttle
const client = await createConnectedClient(aggressivePort);
const status = await client.getStatus();
expect(status.state).toEqual('connected');
// Wait for keepalive exchange (might take longer due to throttle)
await delay(10000);
const stats = await client.getStatistics();
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
});
tap.test('post-load health: direct connection still works', async () => {
// Server should still be healthy after all load tests
const serverStatus = await server.getStatus();
expect(serverStatus.state).toEqual('connected');
// Connect one more client directly (no throttle)
const directClient = await createConnectedClient(serverPort);
const status = await directClient.getStatus();
expect(status.state).toEqual('connected');
await delay(5000);
const stats = await directClient.getStatistics();
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
});
tap.test('teardown: stop all', async () => {
// Stop all clients
for (const c of allClients) {
await stopClient(c);
}
await delay(500);
// Close throttle proxy
if (throttle) {
await throttle.close();
}
// Stop server
await server.stopServer();
server.stop();
await delay(500);
expect(server.running).toBeFalse();
});
export default tap.start();

242
test/test.quic.node.ts Normal file
View File

@@ -0,0 +1,242 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import * as dgram from 'dgram';
import { VpnClient, VpnServer } from '../ts/index.js';
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
async function findFreePort(): Promise<number> {
const server = net.createServer();
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
const port = (server.address() as net.AddressInfo).port;
await new Promise<void>((resolve) => server.close(() => resolve()));
return port;
}
async function findFreeUdpPort(): Promise<number> {
const sock = dgram.createSocket('udp4');
await new Promise<void>((resolve) => sock.bind(0, '127.0.0.1', resolve));
const port = (sock.address() as net.AddressInfo).port;
await new Promise<void>((resolve) => sock.close(resolve));
return port;
}
function delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
async function waitFor(
fn: () => Promise<boolean>,
timeoutMs: number = 10000,
pollMs: number = 500,
): Promise<void> {
const deadline = Date.now() + timeoutMs;
while (Date.now() < deadline) {
if (await fn()) return;
await delay(pollMs);
}
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
}
// ---------------------------------------------------------------------------
// Test state
// ---------------------------------------------------------------------------
let server: VpnServer;
let wsPort: number;
let quicPort: number;
let keypair: IVpnKeypair;
// ---------------------------------------------------------------------------
// Tests: QUIC-only server + QUIC client
// ---------------------------------------------------------------------------
tap.test('setup: start VPN server in QUIC mode', async () => {
quicPort = await findFreeUdpPort();
const options: IVpnServerOptions = {
transport: { transport: 'stdio' },
};
server = new VpnServer(options);
const started = await server['bridge'].start();
expect(started).toBeTrue();
keypair = await server.generateKeypair();
const serverConfig: IVpnServerConfig = {
listenAddr: `127.0.0.1:${quicPort}`,
privateKey: keypair.privateKey,
publicKey: keypair.publicKey,
subnet: '10.9.0.0/24',
transportMode: 'quic',
keepaliveIntervalSecs: 3,
};
await server['bridge'].sendCommand('start', { config: serverConfig });
const status = await server.getStatus();
expect(status.state).toEqual('connected');
});
tap.test('QUIC client connects and gets IP', async () => {
const options: IVpnClientOptions = {
transport: { transport: 'stdio' },
};
const client = new VpnClient(options);
const started = await client.start();
expect(started).toBeTrue();
const result = await client.connect({
serverUrl: `127.0.0.1:${quicPort}`,
serverPublicKey: keypair.publicKey,
transport: 'quic',
keepaliveIntervalSecs: 3,
});
expect(result.assignedIp).toBeTypeofString();
expect(result.assignedIp).toStartWith('10.9.0.');
const clientStatus = await client.getStatus();
expect(clientStatus.state).toEqual('connected');
// Verify server sees the client
await waitFor(async () => {
const clients = await server.listClients();
return clients.length >= 1;
});
await client.stop();
});
tap.test('teardown: stop QUIC server', async () => {
await server.stop();
await delay(500);
});
// ---------------------------------------------------------------------------
// Tests: dual-mode server (both) + auto client
// ---------------------------------------------------------------------------
let dualServer: VpnServer;
let dualWsPort: number;
let dualQuicPort: number;
let dualKeypair: IVpnKeypair;
tap.test('setup: start VPN server in both mode', async () => {
dualWsPort = await findFreePort();
dualQuicPort = await findFreeUdpPort();
const options: IVpnServerOptions = {
transport: { transport: 'stdio' },
};
dualServer = new VpnServer(options);
const started = await dualServer['bridge'].start();
expect(started).toBeTrue();
dualKeypair = await dualServer.generateKeypair();
const serverConfig: IVpnServerConfig = {
listenAddr: `127.0.0.1:${dualWsPort}`,
privateKey: dualKeypair.privateKey,
publicKey: dualKeypair.publicKey,
subnet: '10.10.0.0/24',
transportMode: 'both',
quicListenAddr: `127.0.0.1:${dualQuicPort}`,
keepaliveIntervalSecs: 3,
};
await dualServer['bridge'].sendCommand('start', { config: serverConfig });
const status = await dualServer.getStatus();
expect(status.state).toEqual('connected');
});
tap.test('auto client connects to dual-mode server (QUIC preferred)', async () => {
const options: IVpnClientOptions = {
transport: { transport: 'stdio' },
};
const client = new VpnClient(options);
const started = await client.start();
expect(started).toBeTrue();
// "auto" mode (default): tries QUIC first at same host:port, falls back to WS
// Since the WS port and QUIC port differ, auto will try QUIC on WS port (fail),
// then fall back to WebSocket
const result = await client.connect({
serverUrl: `ws://127.0.0.1:${dualWsPort}`,
serverPublicKey: dualKeypair.publicKey,
// transport defaults to 'auto'
keepaliveIntervalSecs: 3,
});
expect(result.assignedIp).toBeTypeofString();
expect(result.assignedIp).toStartWith('10.10.0.');
const clientStatus = await client.getStatus();
expect(clientStatus.state).toEqual('connected');
await waitFor(async () => {
const clients = await dualServer.listClients();
return clients.length >= 1;
});
await client.stop();
});
tap.test('explicit QUIC client connects to dual-mode server', async () => {
const options: IVpnClientOptions = {
transport: { transport: 'stdio' },
};
const client = new VpnClient(options);
const started = await client.start();
expect(started).toBeTrue();
const result = await client.connect({
serverUrl: `127.0.0.1:${dualQuicPort}`,
serverPublicKey: dualKeypair.publicKey,
transport: 'quic',
keepaliveIntervalSecs: 3,
});
expect(result.assignedIp).toBeTypeofString();
expect(result.assignedIp).toStartWith('10.10.0.');
const clientStatus = await client.getStatus();
expect(clientStatus.state).toEqual('connected');
await client.stop();
});
tap.test('keepalive exchange over QUIC', async () => {
const options: IVpnClientOptions = {
transport: { transport: 'stdio' },
};
const client = new VpnClient(options);
await client.start();
await client.connect({
serverUrl: `127.0.0.1:${dualQuicPort}`,
serverPublicKey: dualKeypair.publicKey,
transport: 'quic',
keepaliveIntervalSecs: 3,
});
// Wait for keepalive exchange
await delay(8000);
const clientStats = await client.getStatistics();
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
await client.stop();
});
tap.test('teardown: stop dual-mode server', async () => {
await dualServer.stop();
await delay(500);
});
export default tap.start();

353
test/test.wireguard.node.ts Normal file
View File

@@ -0,0 +1,353 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
VpnConfig,
VpnServer,
WgConfigGenerator,
} from '../ts/index.js';
import type {
IVpnClientConfig,
IVpnServerConfig,
IVpnServerOptions,
IWgPeerConfig,
} from '../ts/index.js';
// ============================================================================
// WireGuard config validation — client
// ============================================================================
// A valid 32-byte key in base64 (44 chars)
const VALID_KEY = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=';
const VALID_KEY_2 = 'BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB=';
tap.test('WG client config: valid wireguard config passes validation', async () => {
const config: IVpnClientConfig = {
serverUrl: '', // not needed for WG
serverPublicKey: VALID_KEY,
transport: 'wireguard',
wgPrivateKey: VALID_KEY_2,
wgAddress: '10.8.0.2',
wgEndpoint: 'vpn.example.com:51820',
wgAllowedIps: ['0.0.0.0/0'],
};
VpnConfig.validateClientConfig(config);
});
tap.test('WG client config: rejects missing wgPrivateKey', async () => {
const config: IVpnClientConfig = {
serverUrl: '',
serverPublicKey: VALID_KEY,
transport: 'wireguard',
wgAddress: '10.8.0.2',
wgEndpoint: 'vpn.example.com:51820',
};
let threw = false;
try {
VpnConfig.validateClientConfig(config);
} catch (e) {
threw = true;
expect((e as Error).message).toContain('wgPrivateKey');
}
expect(threw).toBeTrue();
});
tap.test('WG client config: rejects missing wgAddress', async () => {
const config: IVpnClientConfig = {
serverUrl: '',
serverPublicKey: VALID_KEY,
transport: 'wireguard',
wgPrivateKey: VALID_KEY_2,
wgEndpoint: 'vpn.example.com:51820',
};
let threw = false;
try {
VpnConfig.validateClientConfig(config);
} catch (e) {
threw = true;
expect((e as Error).message).toContain('wgAddress');
}
expect(threw).toBeTrue();
});
tap.test('WG client config: rejects missing wgEndpoint', async () => {
const config: IVpnClientConfig = {
serverUrl: '',
serverPublicKey: VALID_KEY,
transport: 'wireguard',
wgPrivateKey: VALID_KEY_2,
wgAddress: '10.8.0.2',
};
let threw = false;
try {
VpnConfig.validateClientConfig(config);
} catch (e) {
threw = true;
expect((e as Error).message).toContain('wgEndpoint');
}
expect(threw).toBeTrue();
});
tap.test('WG client config: rejects invalid key length', async () => {
const config: IVpnClientConfig = {
serverUrl: '',
serverPublicKey: VALID_KEY,
transport: 'wireguard',
wgPrivateKey: 'tooshort',
wgAddress: '10.8.0.2',
wgEndpoint: 'vpn.example.com:51820',
};
let threw = false;
try {
VpnConfig.validateClientConfig(config);
} catch (e) {
threw = true;
expect((e as Error).message).toContain('44 characters');
}
expect(threw).toBeTrue();
});
tap.test('WG client config: rejects invalid CIDR in allowedIps', async () => {
const config: IVpnClientConfig = {
serverUrl: '',
serverPublicKey: VALID_KEY,
transport: 'wireguard',
wgPrivateKey: VALID_KEY_2,
wgAddress: '10.8.0.2',
wgEndpoint: 'vpn.example.com:51820',
wgAllowedIps: ['not-a-cidr'],
};
let threw = false;
try {
VpnConfig.validateClientConfig(config);
} catch (e) {
threw = true;
expect((e as Error).message).toContain('CIDR');
}
expect(threw).toBeTrue();
});
// ============================================================================
// WireGuard config validation — server
// ============================================================================
tap.test('WG server config: valid config passes validation', async () => {
const config: IVpnServerConfig = {
listenAddr: '',
privateKey: VALID_KEY,
publicKey: VALID_KEY_2,
subnet: '10.8.0.0/24',
transportMode: 'wireguard',
wgPeers: [
{
publicKey: VALID_KEY_2,
allowedIps: ['10.8.0.2/32'],
},
],
};
VpnConfig.validateServerConfig(config);
});
tap.test('WG server config: rejects empty wgPeers', async () => {
const config: IVpnServerConfig = {
listenAddr: '',
privateKey: VALID_KEY,
publicKey: VALID_KEY_2,
subnet: '10.8.0.0/24',
transportMode: 'wireguard',
wgPeers: [],
};
let threw = false;
try {
VpnConfig.validateServerConfig(config);
} catch (e) {
threw = true;
expect((e as Error).message).toContain('wgPeers');
}
expect(threw).toBeTrue();
});
tap.test('WG server config: rejects peer without publicKey', async () => {
const config: IVpnServerConfig = {
listenAddr: '',
privateKey: VALID_KEY,
publicKey: VALID_KEY_2,
subnet: '10.8.0.0/24',
transportMode: 'wireguard',
wgPeers: [
{
publicKey: '',
allowedIps: ['10.8.0.2/32'],
},
],
};
let threw = false;
try {
VpnConfig.validateServerConfig(config);
} catch (e) {
threw = true;
expect((e as Error).message).toContain('publicKey');
}
expect(threw).toBeTrue();
});
tap.test('WG server config: rejects invalid wgListenPort', async () => {
const config: IVpnServerConfig = {
listenAddr: '',
privateKey: VALID_KEY,
publicKey: VALID_KEY_2,
subnet: '10.8.0.0/24',
transportMode: 'wireguard',
wgListenPort: 0,
wgPeers: [
{
publicKey: VALID_KEY_2,
allowedIps: ['10.8.0.2/32'],
},
],
};
let threw = false;
try {
VpnConfig.validateServerConfig(config);
} catch (e) {
threw = true;
expect((e as Error).message).toContain('wgListenPort');
}
expect(threw).toBeTrue();
});
// ============================================================================
// WireGuard keypair generation via daemon
// ============================================================================
let server: VpnServer;
tap.test('WG: spawn server daemon for keypair generation', async () => {
const options: IVpnServerOptions = {
transport: { transport: 'stdio' },
};
server = new VpnServer(options);
const started = await server['bridge'].start();
expect(started).toBeTrue();
expect(server.running).toBeTrue();
});
tap.test('WG: generateWgKeypair returns valid keypair', async () => {
const keypair = await server.generateWgKeypair();
expect(keypair.publicKey).toBeTypeofString();
expect(keypair.privateKey).toBeTypeofString();
// WireGuard keys: base64 of 32 bytes = 44 characters
expect(keypair.publicKey.length).toEqual(44);
expect(keypair.privateKey.length).toEqual(44);
// Verify they decode to 32 bytes
const pubBuf = Buffer.from(keypair.publicKey, 'base64');
const privBuf = Buffer.from(keypair.privateKey, 'base64');
expect(pubBuf.length).toEqual(32);
expect(privBuf.length).toEqual(32);
});
tap.test('WG: generateWgKeypair returns unique keys each time', async () => {
const kp1 = await server.generateWgKeypair();
const kp2 = await server.generateWgKeypair();
expect(kp1.publicKey).not.toEqual(kp2.publicKey);
expect(kp1.privateKey).not.toEqual(kp2.privateKey);
});
tap.test('WG: stop server daemon', async () => {
server.stop();
await new Promise((resolve) => setTimeout(resolve, 500));
expect(server.running).toBeFalse();
});
// ============================================================================
// WireGuard config file generation
// ============================================================================
tap.test('WgConfigGenerator: generate client config', async () => {
const conf = WgConfigGenerator.generateClientConfig({
privateKey: 'clientPrivateKeyBase64====================',
address: '10.8.0.2/24',
dns: ['1.1.1.1', '8.8.8.8'],
mtu: 1420,
peer: {
publicKey: 'serverPublicKeyBase64====================',
endpoint: 'vpn.example.com:51820',
allowedIps: ['0.0.0.0/0', '::/0'],
persistentKeepalive: 25,
},
});
expect(conf).toContain('[Interface]');
expect(conf).toContain('PrivateKey = clientPrivateKeyBase64====================');
expect(conf).toContain('Address = 10.8.0.2/24');
expect(conf).toContain('DNS = 1.1.1.1, 8.8.8.8');
expect(conf).toContain('MTU = 1420');
expect(conf).toContain('[Peer]');
expect(conf).toContain('PublicKey = serverPublicKeyBase64====================');
expect(conf).toContain('Endpoint = vpn.example.com:51820');
expect(conf).toContain('AllowedIPs = 0.0.0.0/0, ::/0');
expect(conf).toContain('PersistentKeepalive = 25');
});
tap.test('WgConfigGenerator: generate client config without optional fields', async () => {
const conf = WgConfigGenerator.generateClientConfig({
privateKey: 'key1',
address: '10.0.0.2/32',
peer: {
publicKey: 'key2',
endpoint: 'server:51820',
allowedIps: ['10.0.0.0/24'],
},
});
expect(conf).toContain('[Interface]');
expect(conf).not.toContain('DNS');
expect(conf).not.toContain('MTU');
expect(conf).not.toContain('PresharedKey');
expect(conf).not.toContain('PersistentKeepalive');
});
tap.test('WgConfigGenerator: generate server config with NAT', async () => {
const conf = WgConfigGenerator.generateServerConfig({
privateKey: 'serverPrivKey',
address: '10.8.0.1/24',
listenPort: 51820,
dns: ['1.1.1.1'],
enableNat: true,
natInterface: 'ens3',
peers: [
{
publicKey: 'peer1PubKey',
allowedIps: ['10.8.0.2/32'],
presharedKey: 'psk1',
persistentKeepalive: 25,
},
{
publicKey: 'peer2PubKey',
allowedIps: ['10.8.0.3/32'],
},
],
});
expect(conf).toContain('[Interface]');
expect(conf).toContain('ListenPort = 51820');
expect(conf).toContain('PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ens3 -j MASQUERADE');
expect(conf).toContain('PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ens3 -j MASQUERADE');
// Two [Peer] sections
const peerCount = (conf.match(/\[Peer\]/g) || []).length;
expect(peerCount).toEqual(2);
expect(conf).toContain('PresharedKey = psk1');
});
tap.test('WgConfigGenerator: generate server config without NAT', async () => {
const conf = WgConfigGenerator.generateServerConfig({
privateKey: 'serverPrivKey',
address: '10.8.0.1/24',
listenPort: 51820,
peers: [
{
publicKey: 'peerKey',
allowedIps: ['10.8.0.2/32'],
},
],
});
expect(conf).not.toContain('PostUp');
expect(conf).not.toContain('PostDown');
});
export default tap.start();

View File

@@ -3,6 +3,6 @@
*/
export const commitinfo = {
name: '@push.rocks/smartvpn',
version: '1.0.2',
version: '1.6.0',
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
}

View File

@@ -4,3 +4,4 @@ export { VpnClient } from './smartvpn.classes.vpnclient.js';
export { VpnServer } from './smartvpn.classes.vpnserver.js';
export { VpnConfig } from './smartvpn.classes.vpnconfig.js';
export { VpnInstaller } from './smartvpn.classes.vpninstaller.js';
export { WgConfigGenerator } from './smartvpn.classes.wgconfig.js';

View File

@@ -5,6 +5,8 @@ import type {
IVpnClientConfig,
IVpnStatus,
IVpnStatistics,
IVpnConnectionQuality,
IVpnMtuInfo,
TVpnClientCommands,
} from './smartvpn.interfaces.js';
@@ -65,12 +67,26 @@ export class VpnClient extends plugins.events.EventEmitter {
}
/**
* Get traffic statistics.
* Get traffic statistics (includes connection quality when connected).
*/
public async getStatistics(): Promise<IVpnStatistics> {
return this.bridge.sendCommand('getStatistics', {} as Record<string, never>);
}
/**
* Get connection quality metrics (RTT, jitter, loss, link health).
*/
public async getConnectionQuality(): Promise<IVpnConnectionQuality> {
return this.bridge.sendCommand('getConnectionQuality', {} as Record<string, never>);
}
/**
* Get MTU information (overhead, effective MTU, oversized packet stats).
*/
public async getMtuInfo(): Promise<IVpnMtuInfo> {
return this.bridge.sendCommand('getMtuInfo', {} as Record<string, never>);
}
/**
* Stop the daemon bridge.
*/

View File

@@ -12,14 +12,45 @@ export class VpnConfig {
* Validate a client config object. Throws on invalid config.
*/
public static validateClientConfig(config: IVpnClientConfig): void {
if (!config.serverUrl) {
throw new Error('VpnConfig: serverUrl is required');
}
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
throw new Error('VpnConfig: serverUrl must start with wss:// or ws://');
}
if (!config.serverPublicKey) {
throw new Error('VpnConfig: serverPublicKey is required');
if (config.transport === 'wireguard') {
// WireGuard-specific validation
if (!config.wgPrivateKey) {
throw new Error('VpnConfig: wgPrivateKey is required for WireGuard transport');
}
VpnConfig.validateBase64Key(config.wgPrivateKey, 'wgPrivateKey');
if (!config.wgAddress) {
throw new Error('VpnConfig: wgAddress is required for WireGuard transport');
}
if (!config.serverPublicKey) {
throw new Error('VpnConfig: serverPublicKey is required for WireGuard transport');
}
VpnConfig.validateBase64Key(config.serverPublicKey, 'serverPublicKey');
if (!config.wgEndpoint) {
throw new Error('VpnConfig: wgEndpoint is required for WireGuard transport');
}
if (config.wgPresharedKey) {
VpnConfig.validateBase64Key(config.wgPresharedKey, 'wgPresharedKey');
}
if (config.wgAllowedIps) {
for (const cidr of config.wgAllowedIps) {
if (!VpnConfig.isValidCidr(cidr)) {
throw new Error(`VpnConfig: invalid allowedIp CIDR: ${cidr}`);
}
}
}
} else {
if (!config.serverUrl) {
throw new Error('VpnConfig: serverUrl is required');
}
// For QUIC-only transport, serverUrl is a host:port address; for WebSocket/auto it must be ws:// or wss://
if (config.transport !== 'quic') {
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
throw new Error('VpnConfig: serverUrl must start with wss:// or ws:// (for WebSocket transport)');
}
}
if (!config.serverPublicKey) {
throw new Error('VpnConfig: serverPublicKey is required');
}
}
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
throw new Error('VpnConfig: mtu must be between 576 and 65535');
@@ -40,20 +71,51 @@ export class VpnConfig {
* Validate a server config object. Throws on invalid config.
*/
public static validateServerConfig(config: IVpnServerConfig): void {
if (!config.listenAddr) {
throw new Error('VpnConfig: listenAddr is required');
}
if (!config.privateKey) {
throw new Error('VpnConfig: privateKey is required');
}
if (!config.publicKey) {
throw new Error('VpnConfig: publicKey is required');
}
if (!config.subnet) {
throw new Error('VpnConfig: subnet is required');
}
if (!VpnConfig.isValidSubnet(config.subnet)) {
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
if (config.transportMode === 'wireguard') {
// WireGuard server validation
if (!config.privateKey) {
throw new Error('VpnConfig: privateKey is required');
}
VpnConfig.validateBase64Key(config.privateKey, 'privateKey');
if (!config.wgPeers || config.wgPeers.length === 0) {
throw new Error('VpnConfig: at least one wgPeers entry is required for WireGuard mode');
}
for (const peer of config.wgPeers) {
if (!peer.publicKey) {
throw new Error('VpnConfig: peer publicKey is required');
}
VpnConfig.validateBase64Key(peer.publicKey, 'peer.publicKey');
if (!peer.allowedIps || peer.allowedIps.length === 0) {
throw new Error('VpnConfig: peer allowedIps is required');
}
for (const cidr of peer.allowedIps) {
if (!VpnConfig.isValidCidr(cidr)) {
throw new Error(`VpnConfig: invalid peer allowedIp CIDR: ${cidr}`);
}
}
if (peer.presharedKey) {
VpnConfig.validateBase64Key(peer.presharedKey, 'peer.presharedKey');
}
}
if (config.wgListenPort !== undefined && (config.wgListenPort < 1 || config.wgListenPort > 65535)) {
throw new Error('VpnConfig: wgListenPort must be between 1 and 65535');
}
} else {
if (!config.listenAddr) {
throw new Error('VpnConfig: listenAddr is required');
}
if (!config.privateKey) {
throw new Error('VpnConfig: privateKey is required');
}
if (!config.publicKey) {
throw new Error('VpnConfig: publicKey is required');
}
if (!config.subnet) {
throw new Error('VpnConfig: subnet is required');
}
if (!VpnConfig.isValidSubnet(config.subnet)) {
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
}
}
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
throw new Error('VpnConfig: mtu must be between 576 and 65535');
@@ -101,4 +163,41 @@ export class VpnConfig {
const prefixNum = parseInt(prefix, 10);
return !isNaN(prefixNum) && prefixNum >= 0 && prefixNum <= 32;
}
/**
* Validate a CIDR string (IPv4 or IPv6).
*/
private static isValidCidr(cidr: string): boolean {
const parts = cidr.split('/');
if (parts.length !== 2) return false;
const prefixNum = parseInt(parts[1], 10);
if (isNaN(prefixNum) || prefixNum < 0) return false;
// IPv4
if (VpnConfig.isValidIp(parts[0])) {
return prefixNum <= 32;
}
// IPv6 (basic check)
if (parts[0].includes(':')) {
return prefixNum <= 128;
}
return false;
}
/**
* Validate a base64-encoded 32-byte key (WireGuard X25519 format).
*/
private static validateBase64Key(key: string, fieldName: string): void {
if (key.length !== 44) {
throw new Error(`VpnConfig: ${fieldName} must be 44 characters (base64 of 32 bytes), got ${key.length}`);
}
try {
const buf = Buffer.from(key, 'base64');
if (buf.length !== 32) {
throw new Error(`VpnConfig: ${fieldName} must decode to 32 bytes, got ${buf.length}`);
}
} catch (e) {
if (e instanceof Error && e.message.startsWith('VpnConfig:')) throw e;
throw new Error(`VpnConfig: ${fieldName} is not valid base64`);
}
}
}

View File

@@ -7,6 +7,9 @@ import type {
IVpnServerStatistics,
IVpnClientInfo,
IVpnKeypair,
IVpnClientTelemetry,
IWgPeerConfig,
IWgPeerInfo,
TVpnServerCommands,
} from './smartvpn.interfaces.js';
@@ -91,6 +94,64 @@ export class VpnServer extends plugins.events.EventEmitter {
return this.bridge.sendCommand('generateKeypair', {} as Record<string, never>);
}
/**
* Set rate limit for a specific client.
*/
public async setClientRateLimit(
clientId: string,
rateBytesPerSec: number,
burstBytes: number,
): Promise<void> {
await this.bridge.sendCommand('setClientRateLimit', {
clientId,
rateBytesPerSec,
burstBytes,
});
}
/**
* Remove rate limit for a specific client (unlimited).
*/
public async removeClientRateLimit(clientId: string): Promise<void> {
await this.bridge.sendCommand('removeClientRateLimit', { clientId });
}
/**
* Get telemetry for a specific client.
*/
public async getClientTelemetry(clientId: string): Promise<IVpnClientTelemetry> {
return this.bridge.sendCommand('getClientTelemetry', { clientId });
}
/**
* Generate a WireGuard-compatible X25519 keypair.
*/
public async generateWgKeypair(): Promise<IVpnKeypair> {
return this.bridge.sendCommand('generateWgKeypair', {} as Record<string, never>);
}
/**
* Add a WireGuard peer (server must be running in wireguard mode).
*/
public async addWgPeer(peer: IWgPeerConfig): Promise<void> {
await this.bridge.sendCommand('addWgPeer', { peer });
}
/**
* Remove a WireGuard peer by public key.
*/
public async removeWgPeer(publicKey: string): Promise<void> {
await this.bridge.sendCommand('removeWgPeer', { publicKey });
}
/**
* List WireGuard peers with stats.
*/
public async listWgPeers(): Promise<IWgPeerInfo[]> {
const result = await this.bridge.sendCommand('listWgPeers', {} as Record<string, never>);
return result.peers;
}
/**
* Stop the daemon bridge.
*/

View File

@@ -0,0 +1,123 @@
import type { IWgPeerConfig } from './smartvpn.interfaces.js';
// ============================================================================
// WireGuard .conf file generator
// ============================================================================
export interface IWgClientConfOptions {
/** Client private key (base64) */
privateKey: string;
/** Client TUN address with prefix (e.g. 10.8.0.2/24) */
address: string;
/** DNS servers */
dns?: string[];
/** TUN MTU */
mtu?: number;
/** Server peer config */
peer: {
publicKey: string;
presharedKey?: string;
endpoint: string;
allowedIps: string[];
persistentKeepalive?: number;
};
}
export interface IWgServerConfOptions {
/** Server private key (base64) */
privateKey: string;
/** Server TUN address with prefix (e.g. 10.8.0.1/24) */
address: string;
/** UDP listen port */
listenPort: number;
/** DNS servers */
dns?: string[];
/** TUN MTU */
mtu?: number;
/** Enable NAT — adds PostUp/PostDown iptables rules */
enableNat?: boolean;
/** Network interface for NAT (e.g. eth0). Auto-detected if omitted. */
natInterface?: string;
/** Configured peers */
peers: IWgPeerConfig[];
}
/**
* Generates standard WireGuard .conf files compatible with wg-quick,
* WireGuard iOS/Android apps, and other standard WireGuard clients.
*/
export class WgConfigGenerator {
/**
* Generate a client .conf file content.
*/
public static generateClientConfig(opts: IWgClientConfOptions): string {
const lines: string[] = [];
lines.push('[Interface]');
lines.push(`PrivateKey = ${opts.privateKey}`);
lines.push(`Address = ${opts.address}`);
if (opts.dns && opts.dns.length > 0) {
lines.push(`DNS = ${opts.dns.join(', ')}`);
}
if (opts.mtu) {
lines.push(`MTU = ${opts.mtu}`);
}
lines.push('');
lines.push('[Peer]');
lines.push(`PublicKey = ${opts.peer.publicKey}`);
if (opts.peer.presharedKey) {
lines.push(`PresharedKey = ${opts.peer.presharedKey}`);
}
lines.push(`Endpoint = ${opts.peer.endpoint}`);
lines.push(`AllowedIPs = ${opts.peer.allowedIps.join(', ')}`);
if (opts.peer.persistentKeepalive) {
lines.push(`PersistentKeepalive = ${opts.peer.persistentKeepalive}`);
}
lines.push('');
return lines.join('\n');
}
/**
* Generate a server .conf file content.
*/
public static generateServerConfig(opts: IWgServerConfOptions): string {
const lines: string[] = [];
lines.push('[Interface]');
lines.push(`PrivateKey = ${opts.privateKey}`);
lines.push(`Address = ${opts.address}`);
lines.push(`ListenPort = ${opts.listenPort}`);
if (opts.dns && opts.dns.length > 0) {
lines.push(`DNS = ${opts.dns.join(', ')}`);
}
if (opts.mtu) {
lines.push(`MTU = ${opts.mtu}`);
}
if (opts.enableNat) {
const iface = opts.natInterface || 'eth0';
lines.push(`PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ${iface} -j MASQUERADE`);
lines.push(`PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ${iface} -j MASQUERADE`);
}
for (const peer of opts.peers) {
lines.push('');
lines.push('[Peer]');
lines.push(`PublicKey = ${peer.publicKey}`);
if (peer.presharedKey) {
lines.push(`PresharedKey = ${peer.presharedKey}`);
}
lines.push(`AllowedIPs = ${peer.allowedIps.join(', ')}`);
if (peer.endpoint) {
lines.push(`Endpoint = ${peer.endpoint}`);
}
if (peer.persistentKeepalive) {
lines.push(`PersistentKeepalive = ${peer.persistentKeepalive}`);
}
}
lines.push('');
return lines.join('\n');
}
}

View File

@@ -32,6 +32,24 @@ export interface IVpnClientConfig {
mtu?: number;
/** Keepalive interval in seconds (default: 30) */
keepaliveIntervalSecs?: number;
/** Transport protocol: 'auto' (default, tries QUIC then WS), 'websocket', 'quic', or 'wireguard' */
transport?: 'auto' | 'websocket' | 'quic' | 'wireguard';
/** For QUIC: SHA-256 hash of server certificate (base64) for cert pinning */
serverCertHash?: string;
/** WireGuard: client private key (base64, X25519) */
wgPrivateKey?: string;
/** WireGuard: client TUN address (e.g. 10.8.0.2) */
wgAddress?: string;
/** WireGuard: client TUN address prefix length (default: 24) */
wgAddressPrefix?: number;
/** WireGuard: preshared key (base64, optional) */
wgPresharedKey?: string;
/** WireGuard: persistent keepalive interval in seconds */
wgPersistentKeepalive?: number;
/** WireGuard: server endpoint (host:port, e.g. vpn.example.com:51820) */
wgEndpoint?: string;
/** WireGuard: allowed IPs (CIDR strings, e.g. ['0.0.0.0/0']) */
wgAllowedIps?: string[];
}
export interface IVpnClientOptions {
@@ -64,6 +82,20 @@ export interface IVpnServerConfig {
keepaliveIntervalSecs?: number;
/** Enable NAT/masquerade for client traffic */
enableNat?: boolean;
/** Default rate limit for new clients (bytes/sec). Omit for unlimited. */
defaultRateLimitBytesPerSec?: number;
/** Default burst size for new clients (bytes). Omit for unlimited. */
defaultBurstBytes?: number;
/** Transport mode: 'both' (default, WS+QUIC), 'websocket', 'quic', or 'wireguard' */
transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard';
/** QUIC listen address (host:port). Defaults to listenAddr. */
quicListenAddr?: string;
/** QUIC idle timeout in seconds (default: 30) */
quicIdleTimeoutSecs?: number;
/** WireGuard: UDP listen port (default: 51820) */
wgListenPort?: number;
/** WireGuard: configured peers */
wgPeers?: IWgPeerConfig[];
}
export interface IVpnServerOptions {
@@ -99,6 +131,7 @@ export interface IVpnStatistics {
keepalivesSent: number;
keepalivesReceived: number;
uptimeSeconds: number;
quality?: IVpnConnectionQuality;
}
export interface IVpnClientInfo {
@@ -107,6 +140,12 @@ export interface IVpnClientInfo {
connectedSince: string;
bytesSent: number;
bytesReceived: number;
packetsDropped: number;
bytesDropped: number;
lastKeepaliveAt?: string;
keepalivesReceived: number;
rateLimitBytesPerSec?: number;
burstBytes?: number;
}
export interface IVpnServerStatistics extends IVpnStatistics {
@@ -119,6 +158,82 @@ export interface IVpnKeypair {
privateKey: string;
}
// ============================================================================
// QoS: Connection quality
// ============================================================================
export type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
export interface IVpnConnectionQuality {
srttMs: number;
jitterMs: number;
minRttMs: number;
maxRttMs: number;
lossRatio: number;
consecutiveTimeouts: number;
linkHealth: TVpnLinkHealth;
currentKeepaliveIntervalSecs: number;
}
// ============================================================================
// QoS: MTU info
// ============================================================================
export interface IVpnMtuInfo {
tunMtu: number;
effectiveMtu: number;
linkMtu: number;
overheadBytes: number;
oversizedPacketsDropped: number;
icmpTooBigSent: number;
}
// ============================================================================
// QoS: Client telemetry (server-side per-client)
// ============================================================================
export interface IVpnClientTelemetry {
clientId: string;
assignedIp: string;
lastKeepaliveAt?: string;
keepalivesReceived: number;
packetsDropped: number;
bytesDropped: number;
bytesReceived: number;
bytesSent: number;
rateLimitBytesPerSec?: number;
burstBytes?: number;
}
// ============================================================================
// WireGuard-specific types
// ============================================================================
export interface IWgPeerConfig {
/** Peer's public key (base64, X25519) */
publicKey: string;
/** Optional preshared key (base64) */
presharedKey?: string;
/** Allowed IP ranges (CIDR strings) */
allowedIps: string[];
/** Peer endpoint (host:port) — optional for server peers, required for client */
endpoint?: string;
/** Persistent keepalive interval in seconds */
persistentKeepalive?: number;
}
export interface IWgPeerInfo {
publicKey: string;
allowedIps: string[];
endpoint?: string;
persistentKeepalive?: number;
bytesSent: number;
bytesReceived: number;
packetsSent: number;
packetsReceived: number;
lastHandshakeTime?: string;
}
// ============================================================================
// IPC Command maps (used by smartrust RustBridge<TCommands>)
// ============================================================================
@@ -128,6 +243,8 @@ export type TVpnClientCommands = {
disconnect: { params: Record<string, never>; result: void };
getStatus: { params: Record<string, never>; result: IVpnStatus };
getStatistics: { params: Record<string, never>; result: IVpnStatistics };
getConnectionQuality: { params: Record<string, never>; result: IVpnConnectionQuality };
getMtuInfo: { params: Record<string, never>; result: IVpnMtuInfo };
};
export type TVpnServerCommands = {
@@ -138,6 +255,13 @@ export type TVpnServerCommands = {
listClients: { params: Record<string, never>; result: { clients: IVpnClientInfo[] } };
disconnectClient: { params: { clientId: string }; result: void };
generateKeypair: { params: Record<string, never>; result: IVpnKeypair };
setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void };
removeClientRateLimit: { params: { clientId: string }; result: void };
getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry };
generateWgKeypair: { params: Record<string, never>; result: IVpnKeypair };
addWgPeer: { params: { peer: IWgPeerConfig }; result: void };
removeWgPeer: { params: { publicKey: string }; result: void };
listWgPeers: { params: Record<string, never>; result: { peers: IWgPeerInfo[] } };
};
// ============================================================================