Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e31086d0c2 | |||
| 01a0d8b9f4 | |||
| 187a69028b | |||
| 64dedd389e | |||
| 13d8cbe3fa | |||
| f46ea70286 | |||
| 26ee3634c8 | |||
| 049fa00563 | |||
| e4e59d72f9 | |||
| 51d33127bf | |||
| a4ba6806e5 | |||
| 6330921160 | |||
| e81dd377d8 | |||
| e14c357ba0 | |||
| eb30825f72 | |||
| 835f0f791d | |||
| aec545fe8c | |||
| 4fab721d87 | |||
| 9ee41348e0 | |||
| 97bb148063 | |||
| c8d572b719 |
74
changelog.md
74
changelog.md
@@ -1,5 +1,79 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-03-29 - 1.8.0 - feat(auth,client-registry)
|
||||
add Noise IK client authentication with managed client registry and per-client ACL controls
|
||||
|
||||
- switch the native tunnel handshake from Noise NK to Noise IK and require client keypairs in client configuration
|
||||
- add server-side client registry management APIs for creating, updating, disabling, rotating, listing, and exporting client configs
|
||||
- enforce client authorization from the registry during handshake and expose authenticated client metadata in server client info
|
||||
- introduce per-client security policies with source/destination ACLs and per-client rate limit settings
|
||||
- add Rust ACL matching support for exact IPs, CIDR ranges, wildcards, and IP ranges with test coverage
|
||||
|
||||
## 2026-03-29 - 1.7.0 - feat(rust-tests)
|
||||
add end-to-end WireGuard UDP integration tests and align TypeScript build configuration
|
||||
|
||||
- Add userspace Rust end-to-end tests that validate WireGuard handshake, encryption, peer isolation, and preshared-key data exchange over real UDP sockets.
|
||||
- Update the TypeScript build setup by removing the allowimplicitany build flag and explicitly including Node types in tsconfig.
|
||||
- Refresh development toolchain versions to support the updated test and build workflow.
|
||||
|
||||
## 2026-03-29 - 1.6.0 - feat(readme)
|
||||
document WireGuard transport support, configuration, and usage examples
|
||||
|
||||
- Expand the README from dual-transport to triple-transport support by adding WireGuard alongside WebSocket and QUIC
|
||||
- Add client and server WireGuard examples, including live peer management and .conf generation with WgConfigGenerator
|
||||
- Document new WireGuard-related API methods, config fields, transport modes, and security model details
|
||||
|
||||
## 2026-03-29 - 1.5.0 - feat(wireguard)
|
||||
add WireGuard transport support with management APIs and config generation
|
||||
|
||||
- add Rust WireGuard module integration using boringtun and route management through client/server management handlers
|
||||
- extend TypeScript client and server configuration schemas with WireGuard-specific options and validation
|
||||
- add server-side WireGuard peer management commands including keypair generation, peer add/remove, and peer listing
|
||||
- export a WireGuard config generator for producing client and server .conf files
|
||||
- add WireGuard-focused test coverage for config validation and config generation
|
||||
|
||||
## 2026-03-21 - 1.4.1 - fix(readme)
|
||||
preserve markdown line breaks in feature list
|
||||
|
||||
- Adds trailing spaces to the README feature list so each highlighted capability renders on its own line.
|
||||
|
||||
## 2026-03-19 - 1.4.0 - feat(vpn transport)
|
||||
add QUIC transport support with auto fallback to WebSocket
|
||||
|
||||
- introduces a transport abstraction in the Rust daemon so client and server can operate over WebSocket or QUIC
|
||||
- adds dual-mode server configuration with websocket, quic, and both transport modes plus QUIC idle timeout and listen address options
|
||||
- adds client transport selection with auto mode that attempts QUIC first and falls back to WebSocket
|
||||
- adds QUIC certificate hash pinning support and required Rust dependencies for QUIC and TLS
|
||||
- updates TypeScript interfaces, config validation, tests, and documentation to cover the new transport modes
|
||||
|
||||
## 2026-03-17 - 1.3.0 - feat(tests,client)
|
||||
add flow control and load test coverage and honor configured keepalive intervals
|
||||
|
||||
- Adds end-to-end node tests for client/server flow control, keepalive exchange, connection quality telemetry, rate limiting, concurrent clients, and disconnect tracking.
|
||||
- Adds load testing with throttled proxy scenarios to validate behavior under constrained bandwidth and repeated client churn.
|
||||
- Updates the Rust client to pass configured keepaliveIntervalSecs into the adaptive keepalive monitor instead of always using defaults.
|
||||
|
||||
## 2026-03-15 - 1.2.0 - feat(readme)
|
||||
document QoS, telemetry, MTU, and rate limiting capabilities in the README
|
||||
|
||||
- Expand the architecture and feature overview to cover adaptive keepalive, telemetry, QoS, rate limiting, and MTU handling
|
||||
- Update client and server examples to show new APIs such as getConnectionQuality(), getMtuInfo(), setClientRateLimit(), and getClientTelemetry()
|
||||
- Add TypeScript interface documentation for connection quality, MTU info, enriched client statistics, and per-client telemetry
|
||||
|
||||
## 2026-03-15 - 1.1.0 - feat(rust-core)
|
||||
add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs
|
||||
|
||||
- adds adaptive keepalive monitoring with RTT, jitter, loss, and link health reporting to client statistics and management endpoints
|
||||
- introduces MTU overhead calculation and oversized-packet handling support, plus client MTU info APIs
|
||||
- adds token-bucket rate limiting with configurable default limits and server management commands to set, remove, and inspect per-client telemetry
|
||||
- extends TypeScript client and server interfaces with connection quality, MTU, and client telemetry methods
|
||||
|
||||
## 2026-02-27 - 1.0.3 - fix(build)
|
||||
add aarch64 linker configuration for cross-compilation
|
||||
|
||||
- Added rust/.cargo/config.toml to configure linker for target aarch64-unknown-linux-gnu
|
||||
- Sets linker to 'aarch64-linux-gnu-gcc' to enable cross-compilation to ARM64
|
||||
|
||||
## 2026-02-27 - 1.0.2 - fix()
|
||||
no changes detected - no code or content modifications
|
||||
|
||||
|
||||
19
package.json
19
package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartvpn",
|
||||
"version": "1.0.2",
|
||||
"version": "1.8.0",
|
||||
"private": false,
|
||||
"description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
|
||||
"type": "module",
|
||||
@@ -10,7 +10,8 @@
|
||||
"main": "dist_ts/index.js",
|
||||
"typings": "dist_ts/index.d.ts",
|
||||
"scripts": {
|
||||
"build": "(tsbuild tsfolders --allowimplicitany) && (tsrust)",
|
||||
"build": "(tsbuild tsfolders) && (tsrust)",
|
||||
"test:before": "(tsrust)",
|
||||
"test": "tstest test/ --verbose",
|
||||
"buildDocs": "tsdoc"
|
||||
},
|
||||
@@ -28,15 +29,15 @@
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@push.rocks/smartrust": "^1.3.0",
|
||||
"@push.rocks/smartpath": "^5.0.18"
|
||||
"@push.rocks/smartpath": "^6.0.0",
|
||||
"@push.rocks/smartrust": "^1.3.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@git.zone/tsbuild": "^2.2.12",
|
||||
"@git.zone/tsrun": "^1.3.3",
|
||||
"@git.zone/tstest": "^1.0.96",
|
||||
"@git.zone/tsrust": "^1.0.29",
|
||||
"@types/node": "^22.0.0"
|
||||
"@git.zone/tsbuild": "^4.4.0",
|
||||
"@git.zone/tsrun": "^2.0.2",
|
||||
"@git.zone/tsrust": "^1.3.2",
|
||||
"@git.zone/tstest": "^3.6.3",
|
||||
"@types/node": "^25.5.0"
|
||||
},
|
||||
"files": [
|
||||
"ts/**/*",
|
||||
|
||||
3545
pnpm-lock.yaml
generated
3545
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
588
readme.md
588
readme.md
@@ -1,396 +1,366 @@
|
||||
# @push.rocks/smartvpn
|
||||
|
||||
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, typed APIs while all networking heavy lifting — encryption, tunneling, packet forwarding — runs at native speed in Rust.
|
||||
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Enterprise-ready client authentication, triple transport support (WebSocket + QUIC + WireGuard), and a typed hub API for managing clients from code.
|
||||
|
||||
🔐 **Noise IK** mutual authentication — per-client X25519 keypairs, server-side registry
|
||||
🚀 **Triple transport**: WebSocket (Cloudflare-friendly), raw **QUIC** (datagrams), and **WireGuard** (standard protocol)
|
||||
🛡️ **ACL engine** — deny-overrides-allow IP filtering, aligned with SmartProxy conventions
|
||||
📊 **Adaptive QoS**: per-client rate limiting, priority queues, connection quality tracking
|
||||
🔄 **Hub API**: one `createClient()` call generates keys, assigns IP, returns both SmartVPN + WireGuard configs
|
||||
📡 **Real-time telemetry**: RTT, jitter, loss ratio, link health — all via typed APIs
|
||||
|
||||
## Issue Reporting and Security
|
||||
|
||||
For reporting bugs, issues, or security vulnerabilities, please visit [community.foss.global/](https://community.foss.global/). This is the central community hub for all issue reporting. Developers who sign and comply with our contribution agreement and go through identification can also get a [code.foss.global/](https://code.foss.global/) account to submit Pull Requests directly.
|
||||
|
||||
## Install
|
||||
## Install 📦
|
||||
|
||||
```bash
|
||||
npm install @push.rocks/smartvpn
|
||||
# or
|
||||
pnpm install @push.rocks/smartvpn
|
||||
# or
|
||||
npm install @push.rocks/smartvpn
|
||||
```
|
||||
|
||||
## 🏗️ Architecture
|
||||
The package ships with pre-compiled Rust binaries for **linux/amd64** and **linux/arm64**. No Rust toolchain required at runtime.
|
||||
|
||||
## Architecture 🏗️
|
||||
|
||||
```
|
||||
TypeScript (control plane) Rust (data plane)
|
||||
┌──────────────────────────┐ ┌───────────────────────────────┐
|
||||
│ VpnClient / VpnServer │ │ smartvpn_daemon │
|
||||
│ └─ VpnBridge │──stdio/──▶ │ ├─ management (JSON IPC) │
|
||||
│ └─ RustBridge │ socket │ ├─ transport (WebSocket/TLS) │
|
||||
│ (smartrust) │ │ ├─ crypto (Noise NK + XCha) │
|
||||
└──────────────────────────┘ │ ├─ codec (binary framing) │
|
||||
│ ├─ keepalive (app-level) │
|
||||
│ ├─ tunnel (TUN device) │
|
||||
│ ├─ network (NAT/IP pool) │
|
||||
│ └─ reconnect (backoff) │
|
||||
└───────────────────────────────┘
|
||||
┌──────────────────────────────┐ JSON-lines IPC ┌───────────────────────────────┐
|
||||
│ TypeScript Control Plane │ ◄─────────────────────► │ Rust Data Plane Daemon │
|
||||
│ │ stdio or Unix sock │ │
|
||||
│ VpnServer / VpnClient │ │ Noise IK handshake │
|
||||
│ Typed IPC commands │ │ XChaCha20-Poly1305 │
|
||||
│ Config validation │ │ WS + QUIC + WireGuard │
|
||||
│ Hub: client management │ │ TUN device, IP pool, NAT │
|
||||
│ WireGuard .conf generation │ │ Rate limiting, ACLs, QoS │
|
||||
└──────────────────────────────┘ └───────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key design decisions:**
|
||||
**Split-plane design** — TypeScript handles orchestration, config, and DX; Rust handles every hot-path byte with zero-copy async I/O (tokio, mimalloc).
|
||||
|
||||
| Decision | Choice | Why |
|
||||
|----------|--------|-----|
|
||||
| Transport | WebSocket over HTTPS | Works through Cloudflare and other terminating proxies |
|
||||
| Encryption | Noise NK + XChaCha20-Poly1305 | Strong forward secrecy, large nonce space (no counter needed) |
|
||||
| Keepalive | App-level (not WS pings) | Cloudflare drops WS ping frames; app-level pings survive |
|
||||
| IPC | JSON lines over stdio/Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
|
||||
| Binary protocol | `[type:1B][length:4B][payload:NB]` | Minimal overhead, easy to parse at wire speed |
|
||||
## Quick Start 🚀
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### VPN Client
|
||||
|
||||
```typescript
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
// Development: spawn the Rust daemon as a child process
|
||||
const client = new VpnClient({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Start the daemon bridge
|
||||
await client.start();
|
||||
|
||||
// Connect to a VPN server
|
||||
const { assignedIp } = await client.connect({
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
});
|
||||
|
||||
console.log(`Connected! Assigned IP: ${assignedIp}`);
|
||||
|
||||
// Check status
|
||||
const status = await client.getStatus();
|
||||
console.log(status); // { state: 'connected', assignedIp: '10.8.0.2', ... }
|
||||
|
||||
// Get traffic stats
|
||||
const stats = await client.getStatistics();
|
||||
console.log(stats); // { bytesSent, bytesReceived, packetsSent, ... }
|
||||
|
||||
// Disconnect
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
```
|
||||
|
||||
### VPN Server
|
||||
### 1. Start a VPN Server (Hub)
|
||||
|
||||
```typescript
|
||||
import { VpnServer } from '@push.rocks/smartvpn';
|
||||
|
||||
const server = new VpnServer({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Start the daemon and the VPN server
|
||||
const server = new VpnServer({ transport: { transport: 'stdio' } });
|
||||
await server.start({
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'BASE64_PRIVATE_KEY',
|
||||
publicKey: 'BASE64_PUBLIC_KEY',
|
||||
privateKey: '<server-noise-private-key-base64>',
|
||||
publicKey: '<server-noise-public-key-base64>',
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
transportMode: 'both', // WebSocket + QUIC simultaneously
|
||||
enableNat: true,
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
});
|
||||
|
||||
// Generate a Noise keypair
|
||||
const keypair = await server.generateKeypair();
|
||||
console.log(keypair); // { publicKey: '...', privateKey: '...' }
|
||||
|
||||
// List connected clients
|
||||
const clients = await server.listClients();
|
||||
// [{ clientId, assignedIp, connectedSince, bytesSent, bytesReceived }]
|
||||
|
||||
// Disconnect a specific client
|
||||
await server.disconnectClient('some-client-id');
|
||||
|
||||
// Get server stats
|
||||
const stats = await server.getStatistics();
|
||||
// { bytesSent, bytesReceived, activeClients, totalConnections, ... }
|
||||
|
||||
// Stop
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
```
|
||||
|
||||
### Production: Socket Transport
|
||||
|
||||
In production, the daemon runs as a system service and you connect over a Unix socket:
|
||||
### 2. Create a Client (One Call = Everything)
|
||||
|
||||
```typescript
|
||||
const client = new VpnClient({
|
||||
transport: {
|
||||
transport: 'socket',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
autoReconnect: true,
|
||||
reconnectBaseDelayMs: 100,
|
||||
reconnectMaxDelayMs: 30000,
|
||||
maxReconnectAttempts: 10,
|
||||
const bundle = await server.createClient({
|
||||
clientId: 'alice-laptop',
|
||||
tags: ['engineering'],
|
||||
security: {
|
||||
destinationAllowList: ['10.0.0.0/8'], // can only reach internal network
|
||||
destinationBlockList: ['10.0.0.99'], // except this host
|
||||
rateLimit: { bytesPerSec: 10_000_000, burstBytes: 20_000_000 },
|
||||
},
|
||||
});
|
||||
|
||||
await client.start(); // connects to existing daemon (does not spawn)
|
||||
// bundle.smartvpnConfig → typed IVpnClientConfig, ready to use
|
||||
// bundle.wireguardConfig → standard WireGuard .conf string
|
||||
// bundle.secrets → { noisePrivateKey, wgPrivateKey } — shown ONCE
|
||||
```
|
||||
|
||||
When using socket transport, `client.stop()` closes the socket but **does not kill the daemon** — exactly what you want in production.
|
||||
|
||||
## 📋 API Reference
|
||||
|
||||
### `VpnClient`
|
||||
|
||||
| Method | Returns | Description |
|
||||
|--------|---------|-------------|
|
||||
| `start()` | `Promise<boolean>` | Start the daemon bridge (spawn or connect) |
|
||||
| `connect(config?)` | `Promise<{ assignedIp }>` | Connect to VPN server |
|
||||
| `disconnect()` | `Promise<void>` | Disconnect from VPN |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Current connection state |
|
||||
| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic statistics |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
| `running` | `boolean` | Whether bridge is active |
|
||||
|
||||
### `VpnServer`
|
||||
|
||||
| Method | Returns | Description |
|
||||
|--------|---------|-------------|
|
||||
| `start(config?)` | `Promise<void>` | Start daemon + VPN server |
|
||||
| `stopServer()` | `Promise<void>` | Stop the VPN server |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Server connection state |
|
||||
| `getStatistics()` | `Promise<IVpnServerStatistics>` | Server stats (includes client counts) |
|
||||
| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients |
|
||||
| `disconnectClient(id)` | `Promise<void>` | Kick a client |
|
||||
| `generateKeypair()` | `Promise<IVpnKeypair>` | Generate Noise NK keypair |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
|
||||
### `VpnConfig`
|
||||
|
||||
Static utility class for config validation and file I/O:
|
||||
### 3. Connect a Client
|
||||
|
||||
```typescript
|
||||
import { VpnConfig } from '@push.rocks/smartvpn';
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
// Validate (throws on invalid)
|
||||
VpnConfig.validateClientConfig(config);
|
||||
VpnConfig.validateServerConfig(config);
|
||||
const client = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await client.start();
|
||||
|
||||
// Load/save JSON configs
|
||||
const config = await VpnConfig.loadFromFile<IVpnClientConfig>('/etc/smartvpn/client.json');
|
||||
await VpnConfig.saveToFile('/etc/smartvpn/client.json', config);
|
||||
const { assignedIp } = await client.connect(bundle.smartvpnConfig);
|
||||
console.log(`Connected! VPN IP: ${assignedIp}`);
|
||||
```
|
||||
|
||||
### `VpnInstaller`
|
||||
## Features ✨
|
||||
|
||||
Generate system service units for the daemon:
|
||||
### 🔐 Enterprise Authentication (Noise IK)
|
||||
|
||||
Every client authenticates with a **Noise IK handshake** (`Noise_IK_25519_ChaChaPoly_BLAKE2s`). The server verifies the client's static public key against its registry — unauthorized clients are rejected before any data flows.
|
||||
|
||||
- Per-client X25519 keypair generated server-side
|
||||
- Client registry with enable/disable, expiry, tags
|
||||
- Key rotation with `rotateClientKey()` — generates new keys, returns fresh config bundle, disconnects old session
|
||||
|
||||
### 🌐 Triple Transport
|
||||
|
||||
| Transport | Protocol | Best For |
|
||||
|-----------|----------|----------|
|
||||
| **WebSocket** | TLS over TCP | Firewall-friendly, Cloudflare compatible |
|
||||
| **QUIC** | UDP (via quinn) | Low latency, datagram support for IP packets |
|
||||
| **WireGuard** | UDP (via boringtun) | Standard WG clients (iOS, Android, wg-quick) |
|
||||
|
||||
The server can run **all three simultaneously** with `transportMode: 'both'` (WS + QUIC) or `'wireguard'`. Clients auto-negotiate with `transport: 'auto'` (tries QUIC first, falls back to WS).
|
||||
|
||||
### 🛡️ ACL Engine (SmartProxy-Aligned)
|
||||
|
||||
Security policies per client, using the same `ipAllowList` / `ipBlockList` naming convention as `@push.rocks/smartproxy`:
|
||||
|
||||
```typescript
|
||||
security: {
|
||||
ipAllowList: ['192.168.1.0/24'], // source IPs allowed to connect
|
||||
ipBlockList: ['192.168.1.100'], // deny overrides allow
|
||||
destinationAllowList: ['10.0.0.0/8'], // VPN destinations permitted
|
||||
destinationBlockList: ['10.0.0.99'], // deny overrides allow
|
||||
maxConnections: 5,
|
||||
rateLimit: { bytesPerSec: 1_000_000, burstBytes: 2_000_000 },
|
||||
}
|
||||
```
|
||||
|
||||
Supports exact IPs, CIDR, wildcards (`192.168.1.*`), and ranges (`1.1.1.1-1.1.1.100`).
|
||||
|
||||
### 📊 Telemetry & QoS
|
||||
|
||||
- **Connection quality**: Smoothed RTT, jitter, min/max RTT, loss ratio, link health (`healthy` / `degraded` / `critical`)
|
||||
- **Adaptive keepalives**: Interval adjusts based on link health (60s → 30s → 10s)
|
||||
- **Per-client rate limiting**: Token bucket with configurable bytes/sec and burst
|
||||
- **Dead-peer detection**: 180s inactivity timeout
|
||||
- **MTU management**: Automatic overhead calculation (IP+TCP+WS+Noise = 79 bytes)
|
||||
|
||||
### 🔄 Hub Client Management
|
||||
|
||||
The server acts as a **hub** — one API to manage all clients:
|
||||
|
||||
```typescript
|
||||
// Create (generates keys, assigns IP, returns config bundle)
|
||||
const bundle = await server.createClient({ clientId: 'bob-phone' });
|
||||
|
||||
// Read
|
||||
const entry = await server.getClient('bob-phone');
|
||||
const all = await server.listRegisteredClients();
|
||||
|
||||
// Update (ACLs, tags, description, rate limits...)
|
||||
await server.updateClient('bob-phone', {
|
||||
security: { destinationAllowList: ['0.0.0.0/0'] },
|
||||
tags: ['mobile', 'field-ops'],
|
||||
});
|
||||
|
||||
// Enable / Disable
|
||||
await server.disableClient('bob-phone'); // disconnects + blocks reconnection
|
||||
await server.enableClient('bob-phone');
|
||||
|
||||
// Key rotation
|
||||
const newBundle = await server.rotateClientKey('bob-phone');
|
||||
|
||||
// Export config (without secrets)
|
||||
const wgConf = await server.exportClientConfig('bob-phone', 'wireguard');
|
||||
|
||||
// Remove
|
||||
await server.removeClient('bob-phone');
|
||||
```
|
||||
|
||||
### 📝 WireGuard Config Generation
|
||||
|
||||
Generate standard `.conf` files for any WireGuard client:
|
||||
|
||||
```typescript
|
||||
import { WgConfigGenerator } from '@push.rocks/smartvpn';
|
||||
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: '<client-wg-private-key>',
|
||||
address: '10.8.0.2/24',
|
||||
dns: ['1.1.1.1'],
|
||||
peer: {
|
||||
publicKey: '<server-wg-public-key>',
|
||||
endpoint: 'vpn.example.com:51820',
|
||||
allowedIps: ['0.0.0.0/0'],
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
});
|
||||
// → standard WireGuard .conf compatible with wg-quick, iOS, Android
|
||||
```
|
||||
|
||||
### 🖥️ System Service Installation
|
||||
|
||||
```typescript
|
||||
import { VpnInstaller } from '@push.rocks/smartvpn';
|
||||
|
||||
// Auto-detect platform
|
||||
const platform = VpnInstaller.detectPlatform(); // 'linux' | 'macos' | 'windows' | 'unknown'
|
||||
|
||||
// Generate systemd unit (Linux)
|
||||
const unit = VpnInstaller.generateSystemdUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'server',
|
||||
});
|
||||
// unit.content = full systemd .service file
|
||||
// unit.installPath = '/etc/systemd/system/smartvpn-server.service'
|
||||
|
||||
// Generate launchd plist (macOS)
|
||||
const plist = VpnInstaller.generateLaunchdPlist({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'client',
|
||||
});
|
||||
|
||||
// Auto-detect and generate
|
||||
const serviceUnit = VpnInstaller.generateServiceUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
const unit = VpnInstaller.generateServiceUnit({
|
||||
mode: 'server',
|
||||
configPath: '/etc/smartvpn/server.json',
|
||||
});
|
||||
// unit.platform → 'linux' | 'macos'
|
||||
// unit.content → systemd unit file or launchd plist
|
||||
// unit.installPath → /etc/systemd/system/smartvpn-server.service
|
||||
```
|
||||
|
||||
### Events
|
||||
## API Reference 📖
|
||||
|
||||
Both `VpnClient` and `VpnServer` extend `EventEmitter`:
|
||||
### Classes
|
||||
|
||||
| Class | Description |
|
||||
|-------|-------------|
|
||||
| `VpnServer` | Manages the Rust daemon in server mode. Hub methods for client CRUD. |
|
||||
| `VpnClient` | Manages the Rust daemon in client mode. Connect, disconnect, telemetry. |
|
||||
| `VpnBridge<T>` | Low-level typed IPC bridge (stdio or Unix socket). |
|
||||
| `VpnConfig` | Static config validation and file I/O. |
|
||||
| `VpnInstaller` | Generates systemd/launchd service files. |
|
||||
| `WgConfigGenerator` | Generates standard WireGuard `.conf` files. |
|
||||
|
||||
### Key Interfaces
|
||||
|
||||
| Interface | Purpose |
|
||||
|-----------|---------|
|
||||
| `IVpnServerConfig` | Server configuration (listen addr, keys, subnet, transport mode, clients) |
|
||||
| `IVpnClientConfig` | Client configuration (server URL, keys, transport, WG options) |
|
||||
| `IClientEntry` | Server-side client definition (ID, keys, security, priority, tags, expiry) |
|
||||
| `IClientSecurity` | Per-client ACLs and rate limits (SmartProxy-aligned naming) |
|
||||
| `IClientRateLimit` | Rate limiting config (bytesPerSec, burstBytes) |
|
||||
| `IClientConfigBundle` | Full config bundle returned by `createClient()` |
|
||||
| `IVpnClientInfo` | Connected client info (IP, stats, authenticated key) |
|
||||
| `IVpnConnectionQuality` | RTT, jitter, loss ratio, link health |
|
||||
| `IVpnKeypair` | Base64-encoded public/private key pair |
|
||||
|
||||
### Server IPC Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `start` / `stop` | Start/stop the VPN listener |
|
||||
| `createClient` | Generate keys, assign IP, return config bundle |
|
||||
| `removeClient` / `getClient` / `listRegisteredClients` | Client registry CRUD |
|
||||
| `updateClient` / `enableClient` / `disableClient` | Modify client state |
|
||||
| `rotateClientKey` | Fresh keypairs + new config bundle |
|
||||
| `exportClientConfig` | Re-export as SmartVPN config or WireGuard `.conf` |
|
||||
| `listClients` / `disconnectClient` | Manage live connections |
|
||||
| `setClientRateLimit` / `removeClientRateLimit` | Runtime rate limit adjustments |
|
||||
| `getStatus` / `getStatistics` / `getClientTelemetry` | Monitoring |
|
||||
| `generateKeypair` / `generateWgKeypair` / `generateClientKeypair` | Key generation |
|
||||
| `addWgPeer` / `removeWgPeer` / `listWgPeers` | WireGuard peer management |
|
||||
|
||||
### Client IPC Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `connect` / `disconnect` | Manage the tunnel |
|
||||
| `getStatus` / `getStatistics` | Connection state and traffic stats |
|
||||
| `getConnectionQuality` | RTT, jitter, loss, link health |
|
||||
| `getMtuInfo` | MTU and overhead details |
|
||||
|
||||
## Transport Modes 🔀
|
||||
|
||||
### Server Configuration
|
||||
|
||||
```typescript
|
||||
client.on('status', (status) => { /* IVpnStatus */ });
|
||||
client.on('error', (err) => { /* { message, code? } */ });
|
||||
client.on('exit', ({ code, signal }) => { /* daemon exited */ });
|
||||
client.on('reconnected', () => { /* socket reconnected */ });
|
||||
// WebSocket only
|
||||
{ transportMode: 'websocket', listenAddr: '0.0.0.0:443' }
|
||||
|
||||
server.on('client-connected', (info) => { /* IVpnClientInfo */ });
|
||||
server.on('client-disconnected', ({ clientId, reason }) => { /* ... */ });
|
||||
// QUIC only
|
||||
{ transportMode: 'quic', listenAddr: '0.0.0.0:443' }
|
||||
|
||||
// Both (WS + QUIC on same or different ports)
|
||||
{ transportMode: 'both', listenAddr: '0.0.0.0:443', quicListenAddr: '0.0.0.0:4433' }
|
||||
|
||||
// WireGuard
|
||||
{ transportMode: 'wireguard', wgListenPort: 51820, wgPeers: [...] }
|
||||
```
|
||||
|
||||
## 🔐 Security Model
|
||||
### Client Configuration
|
||||
|
||||
The VPN uses a **Noise NK** handshake pattern:
|
||||
```typescript
|
||||
// Auto (tries QUIC first, falls back to WS)
|
||||
{ transport: 'auto', serverUrl: 'wss://vpn.example.com' }
|
||||
|
||||
1. **NK** = client does **N**ot authenticate, but **K**nows the server's static public key
|
||||
2. The client generates an ephemeral keypair, performs `e, es` (Diffie-Hellman with server's static key)
|
||||
3. Server responds with `e, ee` (Diffie-Hellman with both ephemeral keys)
|
||||
4. Result: forward-secret transport keys derived from both DH operations
|
||||
// Explicit QUIC with certificate pinning
|
||||
{ transport: 'quic', serverUrl: '1.2.3.4:4433', serverCertHash: '<sha256-base64>' }
|
||||
|
||||
Post-handshake, all IP packets are encrypted with **XChaCha20-Poly1305**:
|
||||
- 24-byte random nonces (no counter synchronization needed)
|
||||
- 16-byte authentication tags
|
||||
- Wire format: `[nonce:24B][ciphertext:var][tag:16B]`
|
||||
|
||||
## 📦 Binary Protocol
|
||||
|
||||
Inside the WebSocket tunnel, packets use a simple binary framing:
|
||||
|
||||
```
|
||||
┌──────────┬──────────┬────────────────────┐
|
||||
│ Type (1B)│ Len (4B) │ Payload (variable) │
|
||||
└──────────┴──────────┴────────────────────┘
|
||||
// WireGuard
|
||||
{ transport: 'wireguard', wgPrivateKey: '...', wgEndpoint: 'vpn.example.com:51820', ... }
|
||||
```
|
||||
|
||||
| Type | Value | Description |
|
||||
|------|-------|-------------|
|
||||
| `HandshakeInit` | `0x01` | Client → Server handshake |
|
||||
| `HandshakeResp` | `0x02` | Server → Client handshake |
|
||||
| `IpPacket` | `0x10` | Encrypted IP packet |
|
||||
| `Keepalive` | `0x20` | App-level ping |
|
||||
| `KeepaliveAck` | `0x21` | App-level pong |
|
||||
| `SessionResume` | `0x30` | Resume a dropped session |
|
||||
| `SessionResumeOk` | `0x31` | Resume accepted |
|
||||
| `SessionResumeErr` | `0x32` | Resume rejected |
|
||||
| `Disconnect` | `0x3F` | Graceful disconnect |
|
||||
## Cryptography 🔑
|
||||
|
||||
## 🛠️ Rust Daemon CLI
|
||||
| Layer | Algorithm | Purpose |
|
||||
|-------|-----------|---------|
|
||||
| **Handshake** | Noise IK (X25519 + ChaChaPoly + BLAKE2s) | Mutual authentication + key exchange |
|
||||
| **Transport** | Noise transport state (ChaChaPoly) | All post-handshake data encryption |
|
||||
| **Additional** | XChaCha20-Poly1305 | Extended nonce space for data-at-rest |
|
||||
| **WireGuard** | X25519 + ChaCha20-Poly1305 (via boringtun) | Standard WireGuard crypto |
|
||||
|
||||
The Rust binary supports several modes:
|
||||
## Binary Protocol 📡
|
||||
|
||||
```bash
|
||||
# Development: stdio management (JSON lines on stdin/stdout)
|
||||
smartvpn_daemon --management --mode client
|
||||
smartvpn_daemon --management --mode server
|
||||
All frames use `[type:1B][length:4B][payload:NB]` with a 64KB max payload:
|
||||
|
||||
# Production: Unix socket management
|
||||
smartvpn_daemon --management-socket /var/run/smartvpn.sock --mode server
|
||||
| Type | Hex | Direction | Description |
|
||||
|------|-----|-----------|-------------|
|
||||
| HandshakeInit | `0x01` | Client → Server | Noise IK first message |
|
||||
| HandshakeResp | `0x02` | Server → Client | Noise IK response |
|
||||
| IpPacket | `0x10` | Bidirectional | Encrypted tunnel data |
|
||||
| Keepalive | `0x20` | Client → Server | App-level keepalive (not WS ping) |
|
||||
| KeepaliveAck | `0x21` | Server → Client | Keepalive response with RTT payload |
|
||||
| Disconnect | `0x3F` | Bidirectional | Graceful disconnect |
|
||||
|
||||
# Generate a Noise keypair
|
||||
smartvpn_daemon --generate-keypair
|
||||
```
|
||||
|
||||
## 🔧 Building from Source
|
||||
## Development 🛠️
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pnpm install
|
||||
|
||||
# Build TypeScript + cross-compile Rust
|
||||
# Build (TypeScript + Rust cross-compile)
|
||||
pnpm build
|
||||
|
||||
# Build Rust only (debug)
|
||||
cd rust && cargo build
|
||||
# Run all tests (79 TS + 121 Rust = 200 tests)
|
||||
pnpm test
|
||||
|
||||
# Run Rust tests
|
||||
# Run Rust tests directly
|
||||
cd rust && cargo test
|
||||
|
||||
# Run TypeScript tests
|
||||
pnpm test
|
||||
# Run a specific TS test
|
||||
tstest test/test.flowcontrol.node.ts --verbose
|
||||
```
|
||||
|
||||
## TypeScript Interfaces
|
||||
### Project Structure
|
||||
|
||||
<details>
|
||||
<summary>Click to expand full type definitions</summary>
|
||||
|
||||
```typescript
|
||||
// Transport options
|
||||
type TVpnTransportOptions =
|
||||
| { transport: 'stdio' }
|
||||
| {
|
||||
transport: 'socket';
|
||||
socketPath: string;
|
||||
autoReconnect?: boolean;
|
||||
reconnectBaseDelayMs?: number;
|
||||
reconnectMaxDelayMs?: number;
|
||||
maxReconnectAttempts?: number;
|
||||
};
|
||||
|
||||
// Client config
|
||||
interface IVpnClientConfig {
|
||||
serverUrl: string; // e.g. 'wss://vpn.example.com/tunnel'
|
||||
serverPublicKey: string; // base64-encoded Noise static key
|
||||
dns?: string[];
|
||||
mtu?: number; // default: 1420
|
||||
keepaliveIntervalSecs?: number; // default: 30
|
||||
}
|
||||
|
||||
// Server config
|
||||
interface IVpnServerConfig {
|
||||
listenAddr: string; // e.g. '0.0.0.0:443'
|
||||
privateKey: string; // base64 Noise static private key
|
||||
publicKey: string; // base64 Noise static public key
|
||||
subnet: string; // e.g. '10.8.0.0/24'
|
||||
tlsCert?: string;
|
||||
tlsKey?: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
enableNat?: boolean;
|
||||
}
|
||||
|
||||
// Status
|
||||
type TVpnConnectionState = 'disconnected' | 'connecting' | 'handshaking'
|
||||
| 'connected' | 'reconnecting' | 'error';
|
||||
|
||||
interface IVpnStatus {
|
||||
state: TVpnConnectionState;
|
||||
assignedIp?: string;
|
||||
serverAddr?: string;
|
||||
connectedSince?: string;
|
||||
lastError?: string;
|
||||
}
|
||||
|
||||
// Statistics
|
||||
interface IVpnStatistics {
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsSent: number;
|
||||
packetsReceived: number;
|
||||
keepalivesSent: number;
|
||||
keepalivesReceived: number;
|
||||
uptimeSeconds: number;
|
||||
}
|
||||
|
||||
interface IVpnServerStatistics extends IVpnStatistics {
|
||||
activeClients: number;
|
||||
totalConnections: number;
|
||||
}
|
||||
|
||||
interface IVpnClientInfo {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
connectedSince: string;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
}
|
||||
|
||||
interface IVpnKeypair {
|
||||
publicKey: string;
|
||||
privateKey: string;
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
smartvpn/
|
||||
├── ts/ # TypeScript control plane
|
||||
│ ├── index.ts # All exports
|
||||
│ ├── smartvpn.interfaces.ts # Interfaces, types, IPC command maps
|
||||
│ ├── smartvpn.classes.vpnserver.ts
|
||||
│ ├── smartvpn.classes.vpnclient.ts
|
||||
│ ├── smartvpn.classes.vpnbridge.ts
|
||||
│ ├── smartvpn.classes.vpnconfig.ts
|
||||
│ ├── smartvpn.classes.vpninstaller.ts
|
||||
│ └── smartvpn.classes.wgconfig.ts
|
||||
├── rust/ # Rust data plane daemon
|
||||
│ └── src/
|
||||
│ ├── main.rs # CLI entry point
|
||||
│ ├── server.rs # VPN server + hub methods
|
||||
│ ├── client.rs # VPN client
|
||||
│ ├── crypto.rs # Noise IK + XChaCha20
|
||||
│ ├── client_registry.rs # Client database
|
||||
│ ├── acl.rs # ACL engine
|
||||
│ ├── management.rs # JSON-lines IPC
|
||||
│ ├── transport.rs # WebSocket transport
|
||||
│ ├── quic_transport.rs # QUIC transport
|
||||
│ ├── wireguard.rs # WireGuard (boringtun)
|
||||
│ ├── codec.rs # Binary frame protocol
|
||||
│ ├── keepalive.rs # Adaptive keepalives
|
||||
│ ├── ratelimit.rs # Token bucket
|
||||
│ └── ... # tunnel, network, telemetry, qos, mtu, reconnect
|
||||
├── test/ # 9 test files (79 tests)
|
||||
├── dist_ts/ # Compiled TypeScript
|
||||
└── dist_rust/ # Cross-compiled binaries (linux amd64 + arm64)
|
||||
```
|
||||
|
||||
## License and Legal Information
|
||||
|
||||
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./LICENSE) file.
|
||||
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [license](./license.md) file.
|
||||
|
||||
**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.
|
||||
|
||||
|
||||
459
readme.plan.md
Normal file
459
readme.plan.md
Normal file
@@ -0,0 +1,459 @@
|
||||
# Enterprise Auth & Client Management for SmartVPN
|
||||
|
||||
## Context
|
||||
|
||||
SmartVPN's Noise NK mode currently allows **any client that knows the server's public key** to connect — no per-client identity or access control. The goal is to make SmartVPN enterprise-ready with:
|
||||
|
||||
1. **Per-client cryptographic authentication** (Noise IK handshake)
|
||||
2. **Rich client definitions** with ACLs, rate limits, and priority
|
||||
3. **Hub-generated configs** — server generates typed SmartVPN client configs AND WireGuard .conf files from the same client definition
|
||||
4. **Top-notch DX** — one `createClient()` call gives you everything
|
||||
|
||||
**This is a breaking change.** No backward compatibility with the old NK anonymous mode.
|
||||
|
||||
---
|
||||
|
||||
## Design Overview
|
||||
|
||||
### The Hub Model
|
||||
|
||||
The server acts as a **hub** that manages client definitions. Each client definition is the **single source of truth** from which both SmartVPN native configs and WireGuard configs are generated.
|
||||
|
||||
```
|
||||
Hub (Server)
|
||||
└── Client Registry
|
||||
├── "alice-laptop" → SmartVPN config OR WireGuard .conf
|
||||
├── "bob-phone" → SmartVPN config OR WireGuard .conf
|
||||
└── "office-gw" → SmartVPN config OR WireGuard .conf
|
||||
```
|
||||
|
||||
### Authentication: NK → IK (Breaking Change)
|
||||
|
||||
**Old (removed):** `Noise_NK_25519_ChaChaPoly_BLAKE2s` — client is anonymous
|
||||
**New (always):** `Noise_IK_25519_ChaChaPoly_BLAKE2s` — client presents its static key during handshake
|
||||
|
||||
IK is a 2-message handshake (same count as NK), so **the frame protocol stays identical**. Changes:
|
||||
- `create_initiator()` now requires `(client_private_key, server_public_key)` — always
|
||||
- `create_responder()` remains `(server_private_key)` — but now uses IK pattern
|
||||
- After handshake, server extracts client's public key via `get_remote_static()` and verifies against registry
|
||||
- Old NK functions are replaced, not kept alongside
|
||||
|
||||
**Every client must have a keypair. Every server must have a client registry.**
|
||||
|
||||
---
|
||||
|
||||
## Core Interface: `IClientEntry`
|
||||
|
||||
This is the server-side client definition — the central config object.
|
||||
Naming and structure are aligned with SmartProxy's `IRouteConfig` / `IRouteSecurity` patterns.
|
||||
|
||||
```typescript
|
||||
export interface IClientEntry {
|
||||
/** Human-readable client ID (e.g. "alice-laptop") */
|
||||
clientId: string;
|
||||
|
||||
/** Client's Noise IK public key (base64) — for SmartVPN native transport */
|
||||
publicKey: string;
|
||||
/** Client's WireGuard public key (base64) — for WireGuard transport */
|
||||
wgPublicKey?: string;
|
||||
|
||||
// ── Security (aligned with SmartProxy IRouteSecurity pattern) ─────────
|
||||
|
||||
security?: IClientSecurity;
|
||||
|
||||
// ── QoS ────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Traffic priority (lower = higher priority, default: 100) */
|
||||
priority?: number;
|
||||
|
||||
// ── Metadata (aligned with SmartProxy IRouteConfig pattern) ────────────
|
||||
|
||||
/** Whether this client is enabled (default: true) */
|
||||
enabled?: boolean;
|
||||
/** Tags for grouping (e.g. ["engineering", "office"]) */
|
||||
tags?: string[];
|
||||
/** Optional description */
|
||||
description?: string;
|
||||
/** Optional expiry (ISO 8601 timestamp, omit = never expires) */
|
||||
expiresAt?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Security settings per client — mirrors SmartProxy's IRouteSecurity structure.
|
||||
* Uses the same ipAllowList/ipBlockList naming convention.
|
||||
* Adds VPN-specific destination filtering (destinationAllowList/destinationBlockList).
|
||||
*/
|
||||
export interface IClientSecurity {
|
||||
/** Source IPs/CIDRs the client may connect FROM (empty = any).
|
||||
* Supports: exact IP, CIDR, wildcard (192.168.1.*), ranges (1.1.1.1-1.1.1.5).
|
||||
* Same format as SmartProxy's ipAllowList. */
|
||||
ipAllowList?: string[];
|
||||
/** Source IPs blocked — overrides ipAllowList (deny wins).
|
||||
* Same format as SmartProxy's ipBlockList. */
|
||||
ipBlockList?: string[];
|
||||
/** Destination IPs/CIDRs the client may reach through the VPN (empty = all) */
|
||||
destinationAllowList?: string[];
|
||||
/** Destination IPs blocked — overrides destinationAllowList (deny wins) */
|
||||
destinationBlockList?: string[];
|
||||
/** Max concurrent connections from this client */
|
||||
maxConnections?: number;
|
||||
/** Per-client rate limiting */
|
||||
rateLimit?: IClientRateLimit;
|
||||
}
|
||||
|
||||
export interface IClientRateLimit {
|
||||
/** Max throughput in bytes/sec */
|
||||
bytesPerSec: number;
|
||||
/** Burst allowance in bytes */
|
||||
burstBytes: number;
|
||||
}
|
||||
```
|
||||
|
||||
### SmartProxy Alignment Notes
|
||||
|
||||
| Pattern | SmartProxy | SmartVPN |
|
||||
|---------|-----------|---------|
|
||||
| ACL naming | `ipAllowList` / `ipBlockList` | Same — `ipAllowList` / `ipBlockList` |
|
||||
| Security grouping | `security: IRouteSecurity` sub-object | Same — `security: IClientSecurity` sub-object |
|
||||
| Rate limit structure | `rateLimit: IRouteRateLimit` object | Same pattern — `rateLimit: IClientRateLimit` object |
|
||||
| IP format support | Exact, CIDR, wildcard, ranges | Same formats |
|
||||
| Metadata fields | `priority`, `tags`, `enabled`, `description` | Same fields |
|
||||
| ACL evaluation | Block-first, then allow-list | Same — deny overrides allow |
|
||||
|
||||
### ACL Evaluation Order
|
||||
|
||||
```
|
||||
1. Check ipBlockList / destinationBlockList first (explicit deny wins)
|
||||
2. If denied, DROP
|
||||
3. Check ipAllowList / destinationAllowList (explicit allow)
|
||||
4. If ipAllowList is empty → allow any source
|
||||
5. If destinationAllowList is empty → allow all destinations
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Hub Config Generation
|
||||
|
||||
### `createClient()` — The One-Call DX
|
||||
|
||||
When the hub creates a client, it:
|
||||
1. Generates a Noise IK keypair for the client
|
||||
2. Generates a WireGuard keypair for the client
|
||||
3. Allocates a VPN IP address
|
||||
4. Stores the `IClientEntry` in the registry
|
||||
5. Returns a **complete config bundle** with everything the client needs
|
||||
|
||||
```typescript
|
||||
export interface IClientConfigBundle {
|
||||
/** The server-side client entry */
|
||||
entry: IClientEntry;
|
||||
/** Ready-to-use SmartVPN client config (typed object) */
|
||||
smartvpnConfig: IVpnClientConfig;
|
||||
/** Ready-to-use WireGuard .conf file content (string) */
|
||||
wireguardConfig: string;
|
||||
/** Client's private keys (ONLY returned at creation time, not stored server-side) */
|
||||
secrets: {
|
||||
noisePrivateKey: string;
|
||||
wgPrivateKey: string;
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
The `secrets` are returned **only at creation time** — the server stores only public keys.
|
||||
|
||||
### `exportClientConfig()` — Re-export (without secrets)
|
||||
|
||||
```typescript
|
||||
exportClientConfig(clientId: string, format: 'smartvpn' | 'wireguard'): IVpnClientConfig | string
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Updated `IVpnServerConfig`
|
||||
|
||||
```typescript
|
||||
export interface IVpnServerConfig {
|
||||
listenAddr: string;
|
||||
tlsCert?: string;
|
||||
tlsKey?: string;
|
||||
privateKey: string; // Server's Noise static private key (base64)
|
||||
publicKey: string; // Server's Noise static public key (base64)
|
||||
subnet: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
enableNat?: boolean;
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
defaultBurstBytes?: number;
|
||||
transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard';
|
||||
quicListenAddr?: string;
|
||||
quicIdleTimeoutSecs?: number;
|
||||
wgListenPort?: number;
|
||||
wgPeers?: IWgPeerConfig[]; // Keep for raw WG mode
|
||||
|
||||
/** Pre-registered clients — REQUIRED for SmartVPN native transport */
|
||||
clients: IClientEntry[];
|
||||
}
|
||||
```
|
||||
|
||||
Note: `clients` is now **required** (not optional), and there is no `authMode` field — IK is always used.
|
||||
|
||||
---
|
||||
|
||||
## Updated `IVpnClientConfig`
|
||||
|
||||
```typescript
|
||||
export interface IVpnClientConfig {
|
||||
serverUrl: string;
|
||||
serverPublicKey: string;
|
||||
/** Client's Noise IK private key (base64) — REQUIRED for SmartVPN native transport */
|
||||
clientPrivateKey: string;
|
||||
/** Client's Noise IK public key (base64) — for reference/display */
|
||||
clientPublicKey: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
transport?: 'auto' | 'websocket' | 'quic' | 'wireguard';
|
||||
serverCertHash?: string;
|
||||
// WireGuard fields unchanged...
|
||||
wgPrivateKey?: string;
|
||||
wgAddress?: string;
|
||||
wgAddressPrefix?: number;
|
||||
wgPresharedKey?: string;
|
||||
wgPersistentKeepalive?: number;
|
||||
wgEndpoint?: string;
|
||||
wgAllowedIps?: string[];
|
||||
}
|
||||
```
|
||||
|
||||
Note: `clientPrivateKey` and `clientPublicKey` are now **required** (not optional) for non-WireGuard transports.
|
||||
|
||||
---
|
||||
|
||||
## New IPC Commands
|
||||
|
||||
Added to `TVpnServerCommands`:
|
||||
|
||||
| Command | Params | Result | Description |
|
||||
|---------|--------|--------|-------------|
|
||||
| `createClient` | `{ client: Partial<IClientEntry> }` | `IClientConfigBundle` | Create client, generate keypairs, assign IP, return full config bundle |
|
||||
| `removeClient` | `{ clientId: string }` | `void` | Remove from registry + disconnect if connected |
|
||||
| `getClient` | `{ clientId: string }` | `IClientEntry` | Get a single client entry |
|
||||
| `listRegisteredClients` | `{}` | `{ clients: IClientEntry[] }` | List all registered clients |
|
||||
| `updateClient` | `{ clientId: string, update: Partial<IClientEntry> }` | `void` | Update ACLs, rate limits, tags, etc. |
|
||||
| `enableClient` | `{ clientId: string }` | `void` | Enable a disabled client |
|
||||
| `disableClient` | `{ clientId: string }` | `void` | Disable (but don't delete) |
|
||||
| `rotateClientKey` | `{ clientId: string }` | `IClientConfigBundle` | New keypairs, return fresh config bundle |
|
||||
| `exportClientConfig` | `{ clientId: string, format: 'smartvpn' \| 'wireguard' }` | `{ config: string }` | Re-export config (without secrets) |
|
||||
| `generateClientKeypair` | `{}` | `IVpnKeypair` | Generate a standalone Noise IK keypair |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
### Phase 1: Rust — Crypto (Replace NK with IK)
|
||||
|
||||
**File: `rust/src/crypto.rs`**
|
||||
|
||||
- Change `NOISE_PATTERN` from NK to IK: `"Noise_IK_25519_ChaChaPoly_BLAKE2s"`
|
||||
- Replace `create_initiator(server_public_key)` → `create_initiator(client_private_key, server_public_key)`
|
||||
- `create_responder(private_key)` stays the same signature (IK responder only needs its own key)
|
||||
- After handshake, `get_remote_static()` on the responder returns the client's public key
|
||||
- Update `perform_handshake()` to pass client keypair
|
||||
- Update all tests
|
||||
|
||||
### Phase 2: Rust — Client Registry module
|
||||
|
||||
**New file: `rust/src/client_registry.rs`**
|
||||
**Modify: `rust/src/lib.rs`** — add `pub mod client_registry;`
|
||||
|
||||
```rust
|
||||
pub struct ClientEntry {
|
||||
pub client_id: String,
|
||||
pub public_key: String,
|
||||
pub wg_public_key: Option<String>,
|
||||
pub security: Option<ClientSecurity>,
|
||||
pub priority: Option<u32>,
|
||||
pub enabled: Option<bool>,
|
||||
pub tags: Option<Vec<String>>,
|
||||
pub description: Option<String>,
|
||||
pub expires_at: Option<String>,
|
||||
pub assigned_ip: Option<String>,
|
||||
}
|
||||
|
||||
/// Mirrors IClientSecurity — aligned with SmartProxy's IRouteSecurity
|
||||
pub struct ClientSecurity {
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
pub destination_allow_list: Option<Vec<String>>,
|
||||
pub destination_block_list: Option<Vec<String>>,
|
||||
pub max_connections: Option<u32>,
|
||||
pub rate_limit: Option<ClientRateLimit>,
|
||||
}
|
||||
|
||||
pub struct ClientRateLimit {
|
||||
pub bytes_per_sec: u64,
|
||||
pub burst_bytes: u64,
|
||||
}
|
||||
|
||||
pub struct ClientRegistry {
|
||||
entries: HashMap<String, ClientEntry>, // keyed by clientId
|
||||
key_index: HashMap<String, String>, // publicKey → clientId (fast lookup)
|
||||
}
|
||||
```
|
||||
|
||||
Methods: `add`, `remove`, `get_by_id`, `get_by_key`, `update`, `list`, `is_authorized` (enabled + not expired + key exists), `rotate_key`.
|
||||
|
||||
### Phase 3: Rust — ACL enforcement module
|
||||
|
||||
**New file: `rust/src/acl.rs`**
|
||||
**Modify: `rust/src/lib.rs`** — add `pub mod acl;`
|
||||
|
||||
```rust
|
||||
/// IP matching supports: exact, CIDR, wildcard, ranges — same as SmartProxy's IpMatcher
|
||||
pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult {
|
||||
// 1. Check ip_block_list / destination_block_list (deny overrides)
|
||||
// 2. Check ip_allow_list / destination_allow_list (explicit allow)
|
||||
// 3. Empty list = allow all
|
||||
}
|
||||
```
|
||||
|
||||
Called in `server.rs` packet loop after decryption, before forwarding.
|
||||
|
||||
### Phase 4: Rust — Server changes
|
||||
|
||||
**File: `rust/src/server.rs`**
|
||||
|
||||
- Add `clients: Option<Vec<ClientEntry>>` to `ServerConfig`
|
||||
- Add `client_registry: RwLock<ClientRegistry>` to `ServerState` (no `auth_mode` — always IK)
|
||||
- Modify `handle_client_connection()`:
|
||||
- Always use `create_responder()` (now IK pattern)
|
||||
- Call `get_remote_static()` **before** `into_transport_mode()` to get client's public key
|
||||
- Verify against registry — reject unauthorized clients with Disconnect frame
|
||||
- Use registry entry for rate limits (overrides server defaults)
|
||||
- In packet loop: call `acl::check_acl()` on decrypted packets
|
||||
- Add `ClientInfo.authenticated_key: String` and `ClientInfo.registered_client_id: String` (no longer optional)
|
||||
- Add methods: `create_client()`, `remove_client()`, `update_client()`, `list_registered_clients()`, `rotate_client_key()`, `export_client_config()`
|
||||
|
||||
### Phase 5: Rust — Client changes
|
||||
|
||||
**File: `rust/src/client.rs`**
|
||||
|
||||
- Add `client_private_key: String` to `ClientConfig` (required, not optional)
|
||||
- `connect()` always uses `create_initiator(client_private_key, server_public_key)` (IK)
|
||||
|
||||
### Phase 6: Rust — Management IPC handlers
|
||||
|
||||
**File: `rust/src/management.rs`**
|
||||
|
||||
Add handlers for all 10 new IPC commands following existing patterns.
|
||||
|
||||
### Phase 7: TypeScript — Interfaces
|
||||
|
||||
**File: `ts/smartvpn.interfaces.ts`**
|
||||
|
||||
- Add `IClientEntry` interface
|
||||
- Add `IClientConfigBundle` interface
|
||||
- Update `IVpnServerConfig`: add required `clients: IClientEntry[]`
|
||||
- Update `IVpnClientConfig`: add required `clientPrivateKey: string`, `clientPublicKey: string`
|
||||
- Update `IVpnClientInfo`: add `authenticatedKey: string`, `registeredClientId: string`
|
||||
- Add new commands to `TVpnServerCommands`
|
||||
|
||||
### Phase 8: TypeScript — VpnServer class methods
|
||||
|
||||
**File: `ts/smartvpn.classes.vpnserver.ts`**
|
||||
|
||||
Add methods:
|
||||
- `createClient(opts)` → `IClientConfigBundle`
|
||||
- `removeClient(clientId)` → `void`
|
||||
- `getClient(clientId)` → `IClientEntry`
|
||||
- `listRegisteredClients()` → `IClientEntry[]`
|
||||
- `updateClient(clientId, update)` → `void`
|
||||
- `enableClient(clientId)` / `disableClient(clientId)`
|
||||
- `rotateClientKey(clientId)` → `IClientConfigBundle`
|
||||
- `exportClientConfig(clientId, format)` → `string | IVpnClientConfig`
|
||||
|
||||
### Phase 9: TypeScript — Config validation
|
||||
|
||||
**File: `ts/smartvpn.classes.vpnconfig.ts`**
|
||||
|
||||
- Server config: validate `clients` present, each entry has valid `clientId` + `publicKey`
|
||||
- Client config: validate `clientPrivateKey` and `clientPublicKey` present for non-WG transports
|
||||
- Validate CIDRs in ACL fields
|
||||
|
||||
### Phase 10: TypeScript — Hub config generation
|
||||
|
||||
**File: `ts/smartvpn.classes.wgconfig.ts`** (extend existing)
|
||||
|
||||
Add `generateClientConfigFromEntry(entry, serverConfig)` — produces WireGuard .conf from `IClientEntry`.
|
||||
|
||||
### Phase 11: Update existing tests
|
||||
|
||||
All existing tests that use the old NK handshake or old config shapes need updating:
|
||||
- Rust tests in `crypto.rs`, `server.rs`, `client.rs`
|
||||
- TS tests in `test/test.vpnconfig.node.ts`, `test/test.flowcontrol.node.ts`, etc.
|
||||
- Tests now must provide client keypairs and client registry entries
|
||||
|
||||
---
|
||||
|
||||
## DX Highlights
|
||||
|
||||
1. **One call to create a client:**
|
||||
```typescript
|
||||
const bundle = await server.createClient({ clientId: 'alice-laptop', tags: ['engineering'] });
|
||||
// bundle.smartvpnConfig — typed SmartVPN client config
|
||||
// bundle.wireguardConfig — standard WireGuard .conf string
|
||||
// bundle.secrets — private keys, shown only at creation time
|
||||
```
|
||||
|
||||
2. **Typed config objects throughout** — no raw strings or JSON blobs
|
||||
|
||||
3. **Dual transport from same definition** — register once, connect via SmartVPN or WireGuard
|
||||
|
||||
4. **ACLs are deny-overrides-allow** — intuitive enterprise model
|
||||
|
||||
5. **Hot management** — add/remove/update/disable clients at runtime
|
||||
|
||||
6. **Key rotation** — `rotateClientKey()` generates new keys and returns a fresh config bundle
|
||||
|
||||
---
|
||||
|
||||
## Verification Plan
|
||||
|
||||
1. **Rust unit tests:**
|
||||
- `crypto.rs`: IK handshake roundtrip, `get_remote_static()` returns correct key, wrong key fails
|
||||
- `client_registry.rs`: CRUD, `is_authorized` with enabled/disabled/expired
|
||||
- `acl.rs`: allow/deny logic, empty lists, deny-overrides-allow
|
||||
|
||||
2. **Rust integration tests:**
|
||||
- Server accepts authorized client
|
||||
- Server rejects unknown client public key
|
||||
- ACL filtering drops packets to blocked destinations
|
||||
- Runtime `createClient` / `removeClient` works
|
||||
- Disabled client rejected at handshake
|
||||
|
||||
3. **TypeScript tests:**
|
||||
- Config validation with required client fields
|
||||
- `createClient()` returns valid bundle with both formats
|
||||
- `exportClientConfig()` generates valid WireGuard .conf
|
||||
- Full IPC roundtrip: create client → connect → traffic → disconnect
|
||||
|
||||
4. **Build:** `pnpm build` (TS + Rust), `cargo test`, `pnpm test`
|
||||
|
||||
---
|
||||
|
||||
## Key Files to Modify
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `rust/src/crypto.rs` | Replace NK with IK pattern, update initiator signature |
|
||||
| `rust/src/client_registry.rs` | **NEW** — client registry module |
|
||||
| `rust/src/acl.rs` | **NEW** — ACL evaluation module |
|
||||
| `rust/src/server.rs` | Registry integration, IK auth in handshake, ACL in packet loop |
|
||||
| `rust/src/client.rs` | Required `client_private_key`, IK initiator |
|
||||
| `rust/src/management.rs` | 10 new IPC command handlers |
|
||||
| `rust/src/lib.rs` | Register new modules |
|
||||
| `ts/smartvpn.interfaces.ts` | `IClientEntry`, `IClientConfigBundle`, updated configs & commands |
|
||||
| `ts/smartvpn.classes.vpnserver.ts` | New hub methods |
|
||||
| `ts/smartvpn.classes.vpnconfig.ts` | Updated validation rules |
|
||||
| `ts/smartvpn.classes.wgconfig.ts` | Config generation from client entries |
|
||||
2
rust/.cargo/config.toml
Normal file
2
rust/.cargo/config.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
linker = "aarch64-linux-gnu-gcc"
|
||||
793
rust/Cargo.lock
generated
793
rust/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,18 @@ tun = { version = "0.7", features = ["async"] }
|
||||
bytes = "1"
|
||||
tokio-util = "0.7"
|
||||
futures-util = "0.3"
|
||||
async-trait = "0.1"
|
||||
quinn = "0.11"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
rcgen = "0.13"
|
||||
ring = "0.17"
|
||||
rustls-pki-types = "1"
|
||||
rustls-pemfile = "2"
|
||||
webpki-roots = "1"
|
||||
mimalloc = "0.1"
|
||||
boringtun = "0.7"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
ipnet = "2"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
|
||||
278
rust/src/acl.rs
Normal file
278
rust/src/acl.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
use std::net::Ipv4Addr;
|
||||
use ipnet::Ipv4Net;
|
||||
|
||||
use crate::client_registry::ClientSecurity;
|
||||
|
||||
/// Result of an ACL check.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AclResult {
|
||||
Allow,
|
||||
DenySrc,
|
||||
DenyDst,
|
||||
}
|
||||
|
||||
/// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy.
|
||||
///
|
||||
/// Evaluation order (deny overrides allow):
|
||||
/// 1. If src_ip is in ip_block_list → DenySrc
|
||||
/// 2. If dst_ip is in destination_block_list → DenyDst
|
||||
/// 3. If ip_allow_list is non-empty and src_ip is NOT in it → DenySrc
|
||||
/// 4. If destination_allow_list is non-empty and dst_ip is NOT in it → DenyDst
|
||||
/// 5. Otherwise → Allow
|
||||
pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult {
|
||||
// Step 1: Check source block list (deny overrides)
|
||||
if let Some(ref block_list) = security.ip_block_list {
|
||||
if ip_matches_any(src_ip, block_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Check destination block list (deny overrides)
|
||||
if let Some(ref block_list) = security.destination_block_list {
|
||||
if ip_matches_any(dst_ip, block_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Check source allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.ip_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(src_ip, allow_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Check destination allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.destination_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(dst_ip, allow_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
AclResult::Allow
|
||||
}
|
||||
|
||||
/// Check if `ip` matches any pattern in the list.
|
||||
/// Supports: exact IP, CIDR notation, wildcard patterns (192.168.1.*),
|
||||
/// and IP ranges (192.168.1.1-192.168.1.100).
|
||||
fn ip_matches_any(ip: Ipv4Addr, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if ip_matches(ip, pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if `ip` matches a single pattern.
|
||||
fn ip_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let pattern = pattern.trim();
|
||||
|
||||
// CIDR notation (e.g. 192.168.1.0/24)
|
||||
if pattern.contains('/') {
|
||||
if let Ok(net) = pattern.parse::<Ipv4Net>() {
|
||||
return net.contains(&ip);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// IP range (e.g. 192.168.1.1-192.168.1.100)
|
||||
if pattern.contains('-') {
|
||||
let parts: Vec<&str> = pattern.splitn(2, '-').collect();
|
||||
if parts.len() == 2 {
|
||||
if let (Ok(start), Ok(end)) = (parts[0].trim().parse::<Ipv4Addr>(), parts[1].trim().parse::<Ipv4Addr>()) {
|
||||
let ip_u32 = u32::from(ip);
|
||||
return ip_u32 >= u32::from(start) && ip_u32 <= u32::from(end);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wildcard pattern (e.g. 192.168.1.*)
|
||||
if pattern.contains('*') {
|
||||
return wildcard_matches(ip, pattern);
|
||||
}
|
||||
|
||||
// Exact IP match
|
||||
if let Ok(exact) = pattern.parse::<Ipv4Addr>() {
|
||||
return ip == exact;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Match an IP against a wildcard pattern like "192.168.1.*" or "10.*.*.*".
|
||||
fn wildcard_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let ip_octets = ip.octets();
|
||||
let pattern_parts: Vec<&str> = pattern.split('.').collect();
|
||||
if pattern_parts.len() != 4 {
|
||||
return false;
|
||||
}
|
||||
for (i, part) in pattern_parts.iter().enumerate() {
|
||||
if *part == "*" {
|
||||
continue;
|
||||
}
|
||||
if let Ok(octet) = part.parse::<u8>() {
|
||||
if ip_octets[i] != octet {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client_registry::{ClientRateLimit, ClientSecurity};
|
||||
|
||||
fn security(
|
||||
ip_allow: Option<Vec<&str>>,
|
||||
ip_block: Option<Vec<&str>>,
|
||||
dst_allow: Option<Vec<&str>>,
|
||||
dst_block: Option<Vec<&str>>,
|
||||
) -> ClientSecurity {
|
||||
ClientSecurity {
|
||||
ip_allow_list: ip_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
ip_block_list: ip_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_allow_list: dst_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_block_list: dst_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
max_connections: None,
|
||||
rate_limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ip(s: &str) -> Ipv4Addr {
|
||||
s.parse().unwrap()
|
||||
}
|
||||
|
||||
// ── No restrictions (empty security) ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn empty_security_allows_all() {
|
||||
let sec = security(None, None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_lists_allow_all() {
|
||||
let sec = security(Some(vec![]), Some(vec![]), Some(vec![]), Some(vec![]));
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
// ── Source IP allow list ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_allow_exact_match() {
|
||||
let sec = security(Some(vec!["10.0.0.1"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.2"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_cidr() {
|
||||
let sec = security(Some(vec!["192.168.1.0/24"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.2.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_wildcard() {
|
||||
let sec = security(Some(vec!["10.0.*.*"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.5.3"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.1.0.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_range() {
|
||||
let sec = security(Some(vec!["10.0.0.1-10.0.0.10"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.5"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.11"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Source IP block list (deny overrides) ───────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_block_overrides_allow() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
Some(vec!["192.168.1.100"]),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.100"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Destination allow list ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_allow_exact() {
|
||||
let sec = security(None, None, Some(vec!["8.8.8.8", "8.8.4.4"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dst_allow_cidr() {
|
||||
let sec = security(None, None, Some(vec!["10.0.0.0/8"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.5.3.2")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("172.16.0.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Destination block list (deny overrides) ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_block_overrides_allow() {
|
||||
let sec = security(
|
||||
None,
|
||||
None,
|
||||
Some(vec!["10.0.0.0/8"]),
|
||||
Some(vec!["10.0.0.99"]),
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.1")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.99")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Combined source + destination ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn combined_src_and_dst_filtering() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
None,
|
||||
Some(vec!["8.8.8.8"]),
|
||||
None,
|
||||
);
|
||||
// Valid source, valid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("8.8.8.8")), AclResult::Allow);
|
||||
// Invalid source
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::DenySrc);
|
||||
// Valid source, invalid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── IP matching edge cases ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wildcard_single_octet() {
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.*"));
|
||||
assert!(!ip_matches(ip("10.0.1.5"), "10.0.0.*"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range_boundaries() {
|
||||
assert!(ip_matches(ip("10.0.0.1"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.6"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.0"), "10.0.0.1-10.0.0.5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_pattern_no_match() {
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "not-an-ip"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0.1/99"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0"));
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,17 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, error, warn, debug};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
|
||||
use crate::telemetry::ConnectionQuality;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
|
||||
/// Client configuration (matches TS IVpnClientConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -18,9 +19,17 @@ use crate::transport;
|
||||
pub struct ClientConfig {
|
||||
pub server_url: String,
|
||||
pub server_public_key: String,
|
||||
/// Client's Noise IK static private key (base64) — required for authentication.
|
||||
pub client_private_key: String,
|
||||
/// Client's Noise IK static public key (base64) — for reference/display.
|
||||
pub client_public_key: String,
|
||||
pub dns: Option<Vec<String>>,
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
/// Transport type: "websocket" (default) or "quic".
|
||||
pub transport: Option<String>,
|
||||
/// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
|
||||
pub server_cert_hash: Option<String>,
|
||||
}
|
||||
|
||||
/// Client statistics.
|
||||
@@ -65,6 +74,8 @@ pub struct VpnClient {
|
||||
assigned_ip: Arc<RwLock<Option<String>>>,
|
||||
shutdown_tx: Option<mpsc::Sender<()>>,
|
||||
connected_since: Arc<RwLock<Option<std::time::Instant>>>,
|
||||
quality_rx: Option<watch::Receiver<ConnectionQuality>>,
|
||||
link_health: Arc<RwLock<LinkHealth>>,
|
||||
}
|
||||
|
||||
impl VpnClient {
|
||||
@@ -75,6 +86,8 @@ impl VpnClient {
|
||||
assigned_ip: Arc::new(RwLock::new(None)),
|
||||
shutdown_tx: None,
|
||||
connected_since: Arc::new(RwLock::new(None)),
|
||||
quality_rx: None,
|
||||
link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,23 +106,85 @@ impl VpnClient {
|
||||
let stats = self.stats.clone();
|
||||
let assigned_ip_ref = self.assigned_ip.clone();
|
||||
let connected_since = self.connected_since.clone();
|
||||
let link_health = self.link_health.clone();
|
||||
|
||||
// Decode server public key
|
||||
// Decode keys
|
||||
let server_pub_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.server_public_key,
|
||||
)?;
|
||||
let client_priv_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.client_private_key,
|
||||
)?;
|
||||
|
||||
// Connect to WebSocket server
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
// Create transport based on configuration
|
||||
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
|
||||
let transport_type = config.transport.as_deref().unwrap_or("auto");
|
||||
match transport_type {
|
||||
"quic" => {
|
||||
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
|
||||
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
info!("Connected via QUIC");
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
"websocket" => {
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
_ => {
|
||||
// "auto" (default): try QUIC first, fall back to WebSocket
|
||||
// Extract host:port from the URL for QUIC attempt
|
||||
let quic_addr = extract_host_port(&config.server_url);
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
|
||||
// Noise NK handshake (client side = initiator)
|
||||
if let Some(ref addr) = quic_addr {
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(3),
|
||||
try_quic_connect(addr, cert_hash),
|
||||
).await {
|
||||
Ok(Ok((quic_sink, quic_stream))) => {
|
||||
info!("Auto: connected via QUIC to {}", addr);
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC unavailable)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Auto: QUIC timed out, falling back to WebSocket");
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC timed out)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Can't extract host:port for QUIC, use WebSocket directly
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Noise IK handshake (client side = initiator, presents static key)
|
||||
*state.write().await = ClientState::Handshaking;
|
||||
let mut initiator = crypto::create_initiator(&server_pub_key)?;
|
||||
let mut initiator = crypto::create_initiator(&client_priv_key, &server_pub_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es
|
||||
// -> e, es, s, ss
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let init_frame = Frame {
|
||||
packet_type: PacketType::HandshakeInit,
|
||||
@@ -117,13 +192,11 @@ impl VpnClient {
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// <- e, ee
|
||||
let resp_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
|
||||
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
|
||||
// <- e, ee, se
|
||||
let resp_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed during handshake"),
|
||||
};
|
||||
|
||||
@@ -139,9 +212,9 @@ impl VpnClient {
|
||||
let mut noise_transport = initiator.into_transport_mode()?;
|
||||
|
||||
// Receive assigned IP info (encrypted)
|
||||
let info_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected IP info message"),
|
||||
let info_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before IP info"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&info_msg[..]);
|
||||
@@ -161,16 +234,32 @@ impl VpnClient {
|
||||
|
||||
info!("Connected to VPN, assigned IP: {}", assigned_ip);
|
||||
|
||||
// Create adaptive keepalive monitor (use custom interval if configured)
|
||||
let ka_config = config.keepalive_interval_secs.map(|secs| {
|
||||
let mut cfg = keepalive::AdaptiveKeepaliveConfig::default();
|
||||
cfg.degraded_interval = std::time::Duration::from_secs(secs);
|
||||
cfg.healthy_interval = std::time::Duration::from_secs(secs * 2);
|
||||
cfg.critical_interval = std::time::Duration::from_secs((secs / 3).max(1));
|
||||
cfg
|
||||
});
|
||||
let (monitor, handle) = keepalive::create_keepalive(ka_config);
|
||||
self.quality_rx = Some(handle.quality_rx);
|
||||
|
||||
// Spawn the keepalive monitor
|
||||
tokio::spawn(monitor.run());
|
||||
|
||||
// Spawn packet forwarding loop
|
||||
let assigned_ip_clone = assigned_ip.clone();
|
||||
tokio::spawn(client_loop(
|
||||
ws_sink,
|
||||
ws_stream,
|
||||
sink,
|
||||
stream,
|
||||
noise_transport,
|
||||
state,
|
||||
stats,
|
||||
shutdown_rx,
|
||||
config.keepalive_interval_secs.unwrap_or(30),
|
||||
handle.signal_rx,
|
||||
handle.ack_tx,
|
||||
link_health,
|
||||
));
|
||||
|
||||
Ok(assigned_ip_clone)
|
||||
@@ -184,6 +273,7 @@ impl VpnClient {
|
||||
*self.assigned_ip.write().await = None;
|
||||
*self.connected_since.write().await = None;
|
||||
*self.state.write().await = ClientState::Disconnected;
|
||||
self.quality_rx = None;
|
||||
info!("Disconnected from VPN");
|
||||
Ok(())
|
||||
}
|
||||
@@ -208,13 +298,14 @@ impl VpnClient {
|
||||
status
|
||||
}
|
||||
|
||||
/// Get traffic statistics.
|
||||
/// Get traffic statistics (includes connection quality).
|
||||
pub async fn get_statistics(&self) -> serde_json::Value {
|
||||
let stats = self.stats.read().await;
|
||||
let since = self.connected_since.read().await;
|
||||
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
|
||||
let health = self.link_health.read().await;
|
||||
|
||||
serde_json::json!({
|
||||
let mut result = serde_json::json!({
|
||||
"bytesSent": stats.bytes_sent,
|
||||
"bytesReceived": stats.bytes_received,
|
||||
"packetsSent": stats.packets_sent,
|
||||
@@ -222,30 +313,58 @@ impl VpnClient {
|
||||
"keepalivesSent": stats.keepalives_sent,
|
||||
"keepalivesReceived": stats.keepalives_received,
|
||||
"uptimeSeconds": uptime,
|
||||
})
|
||||
});
|
||||
|
||||
// Include connection quality if available
|
||||
if let Some(ref rx) = self.quality_rx {
|
||||
let quality = rx.borrow().clone();
|
||||
result["quality"] = serde_json::json!({
|
||||
"srttMs": quality.srtt_ms,
|
||||
"jitterMs": quality.jitter_ms,
|
||||
"minRttMs": quality.min_rtt_ms,
|
||||
"maxRttMs": quality.max_rtt_ms,
|
||||
"lossRatio": quality.loss_ratio,
|
||||
"consecutiveTimeouts": quality.consecutive_timeouts,
|
||||
"linkHealth": format!("{}", *health),
|
||||
"keepalivesSent": quality.keepalives_sent,
|
||||
"keepalivesAcked": quality.keepalives_acked,
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Get connection quality snapshot.
|
||||
pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
|
||||
self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
|
||||
}
|
||||
|
||||
/// Get current link health.
|
||||
pub async fn get_link_health(&self) -> LinkHealth {
|
||||
*self.link_health.read().await
|
||||
}
|
||||
}
|
||||
|
||||
/// The main client packet forwarding loop (runs in a spawned task).
|
||||
async fn client_loop(
|
||||
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
|
||||
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
mut noise_transport: snow::TransportState,
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
mut shutdown_rx: mpsc::Receiver<()>,
|
||||
keepalive_secs: u64,
|
||||
mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
link_health: Arc<RwLock<LinkHealth>>,
|
||||
) {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs));
|
||||
keepalive_ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
Ok(Some(data)) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
@@ -264,6 +383,8 @@ async fn client_loop(
|
||||
}
|
||||
PacketType::KeepaliveAck => {
|
||||
stats.write().await.keepalives_received += 1;
|
||||
// Signal the keepalive monitor that ACK was received
|
||||
let _ = ack_tx.send(()).await;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Server sent disconnect");
|
||||
@@ -274,35 +395,49 @@ async fn client_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
Ok(None) => {
|
||||
info!("Connection closed");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
let _ = ws_sink.send(Message::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
error!("WebSocket error: {}", e);
|
||||
Err(e) => {
|
||||
error!("Transport error: {}", e);
|
||||
*state.write().await = ClientState::Error(e.to_string());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = keepalive_ticker.tick() => {
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
signal = signal_rx.recv() => {
|
||||
match signal {
|
||||
Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
|
||||
// Embed the timestamp in the keepalive payload (8 bytes, big-endian)
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: timestamp_ms.to_be_bytes().to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
}
|
||||
}
|
||||
Some(KeepaliveSignal::PeerDead) => {
|
||||
warn!("Peer declared dead by keepalive monitor");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
Some(KeepaliveSignal::LinkHealthChanged(health)) => {
|
||||
debug!("Link health changed to: {}", health);
|
||||
*link_health.write().await = health;
|
||||
}
|
||||
None => {
|
||||
// Keepalive monitor channel closed
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
@@ -313,12 +448,51 @@ async fn client_loop(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
|
||||
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
}
|
||||
let _ = ws_sink.close().await;
|
||||
let _ = sink.close().await;
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect via QUIC. Returns transport halves on success.
|
||||
async fn try_quic_connect(
|
||||
addr: &str,
|
||||
cert_hash: Option<&str>,
|
||||
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
|
||||
let conn = quic_transport::connect_quic(addr, cert_hash).await?;
|
||||
let (sink, stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
Ok((sink, stream))
|
||||
}
|
||||
|
||||
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
|
||||
/// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080")
|
||||
/// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443")
|
||||
/// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port)
|
||||
fn extract_host_port(url: &str) -> Option<String> {
|
||||
if url.starts_with("ws://") || url.starts_with("wss://") {
|
||||
// Parse as URL
|
||||
let stripped = if url.starts_with("wss://") {
|
||||
&url[6..]
|
||||
} else {
|
||||
&url[5..]
|
||||
};
|
||||
// Remove path
|
||||
let host_port = stripped.split('/').next()?;
|
||||
if host_port.contains(':') {
|
||||
Some(host_port.to_string())
|
||||
} else {
|
||||
// Default port
|
||||
let default_port = if url.starts_with("wss://") { 443 } else { 80 };
|
||||
Some(format!("{}:{}", host_port, default_port))
|
||||
}
|
||||
} else if url.contains(':') {
|
||||
// Already host:port
|
||||
Some(url.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
362
rust/src/client_registry.rs
Normal file
362
rust/src/client_registry.rs
Normal file
@@ -0,0 +1,362 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Per-client rate limiting configuration.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientRateLimit {
|
||||
pub bytes_per_sec: u64,
|
||||
pub burst_bytes: u64,
|
||||
}
|
||||
|
||||
/// Per-client security settings — aligned with SmartProxy's IRouteSecurity pattern.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientSecurity {
|
||||
/// Source IPs/CIDRs the client may connect FROM (empty/None = any).
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
/// Source IPs blocked — overrides ip_allow_list (deny wins).
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
/// Destination IPs/CIDRs the client may reach (empty/None = all).
|
||||
pub destination_allow_list: Option<Vec<String>>,
|
||||
/// Destination IPs blocked — overrides destination_allow_list (deny wins).
|
||||
pub destination_block_list: Option<Vec<String>>,
|
||||
/// Max concurrent connections from this client.
|
||||
pub max_connections: Option<u32>,
|
||||
/// Per-client rate limiting.
|
||||
pub rate_limit: Option<ClientRateLimit>,
|
||||
}
|
||||
|
||||
/// A registered client entry — the server-side source of truth.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientEntry {
|
||||
/// Human-readable client ID (e.g. "alice-laptop").
|
||||
pub client_id: String,
|
||||
/// Client's Noise IK public key (base64).
|
||||
pub public_key: String,
|
||||
/// Client's WireGuard public key (base64) — optional.
|
||||
pub wg_public_key: Option<String>,
|
||||
/// Security settings (ACLs, rate limits).
|
||||
pub security: Option<ClientSecurity>,
|
||||
/// Traffic priority (lower = higher priority, default: 100).
|
||||
pub priority: Option<u32>,
|
||||
/// Whether this client is enabled (default: true).
|
||||
pub enabled: Option<bool>,
|
||||
/// Tags for grouping.
|
||||
pub tags: Option<Vec<String>>,
|
||||
/// Optional description.
|
||||
pub description: Option<String>,
|
||||
/// Optional expiry (ISO 8601 timestamp).
|
||||
pub expires_at: Option<String>,
|
||||
/// Assigned VPN IP address.
|
||||
pub assigned_ip: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientEntry {
|
||||
/// Whether this client is considered enabled (defaults to true).
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Whether this client has expired based on current time.
|
||||
pub fn is_expired(&self) -> bool {
|
||||
if let Some(ref expires) = self.expires_at {
|
||||
if let Ok(expiry) = chrono::DateTime::parse_from_rfc3339(expires) {
|
||||
return chrono::Utc::now() > expiry;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory client registry with dual-key indexing.
|
||||
pub struct ClientRegistry {
|
||||
/// Primary index: clientId → ClientEntry
|
||||
entries: HashMap<String, ClientEntry>,
|
||||
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
|
||||
key_index: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ClientRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
key_index: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a registry from a list of client entries.
|
||||
pub fn from_entries(entries: Vec<ClientEntry>) -> Result<Self> {
|
||||
let mut registry = Self::new();
|
||||
for entry in entries {
|
||||
registry.add(entry)?;
|
||||
}
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
/// Add a client to the registry.
|
||||
pub fn add(&mut self, entry: ClientEntry) -> Result<()> {
|
||||
if self.entries.contains_key(&entry.client_id) {
|
||||
anyhow::bail!("Client '{}' already exists", entry.client_id);
|
||||
}
|
||||
if self.key_index.contains_key(&entry.public_key) {
|
||||
anyhow::bail!("Public key already registered to another client");
|
||||
}
|
||||
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
|
||||
self.entries.insert(entry.client_id.clone(), entry);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a client by ID.
|
||||
pub fn remove(&mut self, client_id: &str) -> Result<ClientEntry> {
|
||||
let entry = self.entries.remove(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
self.key_index.remove(&entry.public_key);
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
/// Get a client by ID.
|
||||
pub fn get_by_id(&self, client_id: &str) -> Option<&ClientEntry> {
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Get a client by public key (used during IK handshake verification).
|
||||
pub fn get_by_key(&self, public_key: &str) -> Option<&ClientEntry> {
|
||||
let client_id = self.key_index.get(public_key)?;
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Check if a public key is authorized (exists, enabled, not expired).
|
||||
pub fn is_authorized(&self, public_key: &str) -> bool {
|
||||
match self.get_by_key(public_key) {
|
||||
Some(entry) => entry.is_enabled() && !entry.is_expired(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a client entry. The closure receives a mutable reference to the entry.
|
||||
pub fn update<F>(&mut self, client_id: &str, updater: F) -> Result<()>
|
||||
where
|
||||
F: FnOnce(&mut ClientEntry),
|
||||
{
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
let old_key = entry.public_key.clone();
|
||||
updater(entry);
|
||||
// If public key changed, update the index
|
||||
if entry.public_key != old_key {
|
||||
self.key_index.remove(&old_key);
|
||||
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all client entries.
|
||||
pub fn list(&self) -> Vec<&ClientEntry> {
|
||||
self.entries.values().collect()
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns the updated entry.
|
||||
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
// Update key index
|
||||
self.key_index.remove(&entry.public_key);
|
||||
entry.public_key = new_public_key.clone();
|
||||
entry.wg_public_key = new_wg_public_key;
|
||||
self.key_index.insert(new_public_key, client_id.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of registered clients.
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Whether the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(id: &str, key: &str) -> ClientEntry {
|
||||
ClientEntry {
|
||||
client_id: id.to_string(),
|
||||
public_key: key.to_string(),
|
||||
wg_public_key: None,
|
||||
security: None,
|
||||
priority: None,
|
||||
enabled: None,
|
||||
tags: None,
|
||||
description: None,
|
||||
expires_at: None,
|
||||
assigned_ip: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_and_lookup() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
assert!(reg.get_by_id("alice").is_some());
|
||||
assert!(reg.get_by_key("key_alice").is_some());
|
||||
assert_eq!(reg.get_by_key("key_alice").unwrap().client_id, "alice");
|
||||
assert!(reg.get_by_id("bob").is_none());
|
||||
assert!(reg.get_by_key("key_bob").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_id() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key1")).unwrap();
|
||||
assert!(reg.add(make_entry("alice", "key2")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "same_key")).unwrap();
|
||||
assert!(reg.add(make_entry("bob", "same_key")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert_eq!(reg.len(), 1);
|
||||
|
||||
let removed = reg.remove("alice").unwrap();
|
||||
assert_eq!(removed.client_id, "alice");
|
||||
assert_eq!(reg.len(), 0);
|
||||
assert!(reg.get_by_key("key_alice").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.remove("ghost").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_enabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert!(reg.is_authorized("key_alice")); // enabled by default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_disabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.enabled = Some(false);
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_expired() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2020-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_future_expiry() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2099-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_unknown_key() {
|
||||
let reg = ClientRegistry::new();
|
||||
assert!(!reg.is_authorized("nonexistent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
reg.update("alice", |entry| {
|
||||
entry.description = Some("Updated".to_string());
|
||||
entry.enabled = Some(false);
|
||||
}).unwrap();
|
||||
|
||||
let entry = reg.get_by_id("alice").unwrap();
|
||||
assert_eq!(entry.description.as_deref(), Some("Updated"));
|
||||
assert!(!entry.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.update("ghost", |_| {}).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rotate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "old_key")).unwrap();
|
||||
|
||||
reg.rotate_key("alice", "new_key".to_string(), None).unwrap();
|
||||
|
||||
assert!(reg.get_by_key("old_key").is_none());
|
||||
assert!(reg.get_by_key("new_key").is_some());
|
||||
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_entries() {
|
||||
let entries = vec![
|
||||
make_entry("alice", "key_a"),
|
||||
make_entry("bob", "key_b"),
|
||||
];
|
||||
let reg = ClientRegistry::from_entries(entries).unwrap();
|
||||
assert_eq!(reg.len(), 2);
|
||||
assert!(reg.get_by_key("key_a").is_some());
|
||||
assert!(reg.get_by_key("key_b").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_clients() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_a")).unwrap();
|
||||
reg.add(make_entry("bob", "key_b")).unwrap();
|
||||
let list = reg.list();
|
||||
assert_eq!(list.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_with_rate_limit() {
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.security = Some(ClientSecurity {
|
||||
ip_allow_list: Some(vec!["192.168.1.0/24".to_string()]),
|
||||
ip_block_list: Some(vec!["192.168.1.100".to_string()]),
|
||||
destination_allow_list: None,
|
||||
destination_block_list: None,
|
||||
max_connections: Some(5),
|
||||
rate_limit: Some(ClientRateLimit {
|
||||
bytes_per_sec: 1_000_000,
|
||||
burst_bytes: 2_000_000,
|
||||
}),
|
||||
});
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(entry).unwrap();
|
||||
let e = reg.get_by_id("alice").unwrap();
|
||||
let sec = e.security.as_ref().unwrap();
|
||||
assert_eq!(sec.rate_limit.as_ref().unwrap().bytes_per_sec, 1_000_000);
|
||||
assert_eq!(sec.max_connections, Some(5));
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,10 @@ use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use snow::Builder;
|
||||
|
||||
/// Noise protocol pattern: NK (client knows server pubkey, no client auth at Noise level)
|
||||
const NOISE_PATTERN: &str = "Noise_NK_25519_ChaChaPoly_BLAKE2s";
|
||||
/// Noise protocol pattern: IK (client presents static key, server authenticates client)
|
||||
/// IK = Initiator's static key is transmitted; responder's Key is pre-known.
|
||||
/// This provides mutual authentication: server verifies client identity via public key.
|
||||
const NOISE_PATTERN: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2s";
|
||||
|
||||
/// Generate a new Noise static keypair.
|
||||
/// Returns (public_key_base64, private_key_base64).
|
||||
@@ -22,18 +24,23 @@ pub fn generate_keypair_raw() -> Result<snow::Keypair> {
|
||||
Ok(builder.generate_keypair()?)
|
||||
}
|
||||
|
||||
/// Create a Noise NK initiator (client side).
|
||||
/// The client knows the server's static public key.
|
||||
pub fn create_initiator(server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
/// Create a Noise IK initiator (client side).
|
||||
/// The client provides its own static keypair AND the server's public key.
|
||||
/// The client's static key is transmitted (encrypted) during the handshake,
|
||||
/// allowing the server to authenticate the client.
|
||||
pub fn create_initiator(client_private_key: &[u8], server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
.local_private_key(client_private_key)
|
||||
.remote_public_key(server_public_key)
|
||||
.build_initiator()?;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Create a Noise NK responder (server side).
|
||||
/// Create a Noise IK responder (server side).
|
||||
/// The server uses its static private key.
|
||||
/// After the handshake, call `get_remote_static()` on the HandshakeState
|
||||
/// (before `into_transport_mode()`) to retrieve the client's public key.
|
||||
pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
@@ -42,19 +49,20 @@ pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Perform the full Noise NK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport).
|
||||
/// Perform the full Noise IK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport, client_public_key).
|
||||
/// The client_public_key is extracted from the responder before entering transport mode.
|
||||
pub fn perform_handshake(
|
||||
mut initiator: snow::HandshakeState,
|
||||
mut responder: snow::HandshakeState,
|
||||
) -> Result<(snow::TransportState, snow::TransportState)> {
|
||||
) -> Result<(snow::TransportState, snow::TransportState, Vec<u8>)> {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es (initiator sends)
|
||||
// -> e, es, s, ss (initiator sends ephemeral + encrypted static key)
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let msg1 = buf[..len].to_vec();
|
||||
|
||||
// <- e, ee (responder reads and responds)
|
||||
// <- e, ee, se (responder reads and responds)
|
||||
responder.read_message(&msg1, &mut buf)?;
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let msg2 = buf[..len].to_vec();
|
||||
@@ -62,10 +70,16 @@ pub fn perform_handshake(
|
||||
// Initiator reads response
|
||||
initiator.read_message(&msg2, &mut buf)?;
|
||||
|
||||
// Extract client's public key from responder BEFORE entering transport mode
|
||||
let client_public_key = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake did not provide client static key"))?
|
||||
.to_vec();
|
||||
|
||||
let i_transport = initiator.into_transport_mode()?;
|
||||
let r_transport = responder.into_transport_mode()?;
|
||||
|
||||
Ok((i_transport, r_transport))
|
||||
Ok((i_transport, r_transport, client_public_key))
|
||||
}
|
||||
|
||||
/// XChaCha20-Poly1305 encryption for post-handshake data.
|
||||
@@ -135,15 +149,19 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_handshake() {
|
||||
fn noise_ik_handshake() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
let initiator = create_initiator(&server_kp.public).unwrap();
|
||||
let initiator = create_initiator(&client_kp.private, &server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
let (mut i_transport, mut r_transport) =
|
||||
let (mut i_transport, mut r_transport, remote_key) =
|
||||
perform_handshake(initiator, responder).unwrap();
|
||||
|
||||
// Verify the server received the client's public key
|
||||
assert_eq!(remote_key, client_kp.public);
|
||||
|
||||
// Test encrypted communication
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let plaintext = b"hello from client";
|
||||
@@ -159,6 +177,20 @@ mod tests {
|
||||
assert_eq!(&out[..len], plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_ik_wrong_server_key_fails() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let wrong_server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
// Client uses wrong server public key
|
||||
let initiator = create_initiator(&client_kp.private, &wrong_server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
// Handshake should fail because client targeted wrong server
|
||||
assert!(perform_handshake(initiator, responder).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn xchacha_encrypt_decrypt() {
|
||||
let key = [42u8; 32];
|
||||
|
||||
@@ -1,87 +1,464 @@
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio::time::{interval, timeout};
|
||||
use tracing::{debug, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Default keepalive interval (30 seconds).
|
||||
use crate::telemetry::{ConnectionQuality, RttTracker};
|
||||
|
||||
/// Default keepalive interval (30 seconds — used for Degraded state).
|
||||
pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
/// Default keepalive ACK timeout (10 seconds).
|
||||
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// Default keepalive ACK timeout (5 seconds).
|
||||
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// Link health states for adaptive keepalive.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LinkHealth {
|
||||
/// RTT stable, jitter low, no loss. Interval: 60s.
|
||||
Healthy,
|
||||
/// Elevated jitter or occasional loss. Interval: 30s.
|
||||
Degraded,
|
||||
/// High loss or sustained jitter spike. Interval: 10s.
|
||||
Critical,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LinkHealth {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Healthy => write!(f, "healthy"),
|
||||
Self::Degraded => write!(f, "degraded"),
|
||||
Self::Critical => write!(f, "critical"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the adaptive keepalive state machine.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdaptiveKeepaliveConfig {
|
||||
/// Interval when link health is Healthy.
|
||||
pub healthy_interval: Duration,
|
||||
/// Interval when link health is Degraded.
|
||||
pub degraded_interval: Duration,
|
||||
/// Interval when link health is Critical.
|
||||
pub critical_interval: Duration,
|
||||
/// ACK timeout (how long to wait for ACK before declaring timeout).
|
||||
pub ack_timeout: Duration,
|
||||
/// Jitter threshold (ms) to enter Degraded from Healthy.
|
||||
pub jitter_degraded_ms: f64,
|
||||
/// Jitter threshold (ms) to return to Healthy from Degraded.
|
||||
pub jitter_healthy_ms: f64,
|
||||
/// Loss ratio threshold to enter Degraded.
|
||||
pub loss_degraded: f64,
|
||||
/// Loss ratio threshold to enter Critical.
|
||||
pub loss_critical: f64,
|
||||
/// Loss ratio threshold to return from Critical to Degraded.
|
||||
pub loss_recover: f64,
|
||||
/// Loss ratio threshold to return from Degraded to Healthy.
|
||||
pub loss_healthy: f64,
|
||||
/// Consecutive checks required for upward state transitions (hysteresis).
|
||||
pub upgrade_checks: u32,
|
||||
/// Consecutive timeouts to declare peer dead in Critical state.
|
||||
pub dead_peer_timeouts: u32,
|
||||
}
|
||||
|
||||
impl Default for AdaptiveKeepaliveConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
healthy_interval: Duration::from_secs(60),
|
||||
degraded_interval: Duration::from_secs(30),
|
||||
critical_interval: Duration::from_secs(10),
|
||||
ack_timeout: Duration::from_secs(5),
|
||||
jitter_degraded_ms: 50.0,
|
||||
jitter_healthy_ms: 30.0,
|
||||
loss_degraded: 0.05,
|
||||
loss_critical: 0.20,
|
||||
loss_recover: 0.10,
|
||||
loss_healthy: 0.02,
|
||||
upgrade_checks: 3,
|
||||
dead_peer_timeouts: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Signals from the keepalive monitor.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeepaliveSignal {
|
||||
/// Time to send a keepalive ping.
|
||||
SendPing,
|
||||
/// Peer is considered dead (no ACK received within timeout).
|
||||
/// Time to send a keepalive ping. Contains the timestamp (ms since epoch) to embed in payload.
|
||||
SendPing(u64),
|
||||
/// Peer is considered dead (no ACK received within timeout repeatedly).
|
||||
PeerDead,
|
||||
/// Link health state changed.
|
||||
LinkHealthChanged(LinkHealth),
|
||||
}
|
||||
|
||||
/// A keepalive monitor that emits signals on a channel.
|
||||
/// A keepalive monitor with adaptive interval and RTT tracking.
|
||||
pub struct KeepaliveMonitor {
|
||||
interval: Duration,
|
||||
timeout_duration: Duration,
|
||||
config: AdaptiveKeepaliveConfig,
|
||||
health: LinkHealth,
|
||||
rtt_tracker: RttTracker,
|
||||
signal_tx: mpsc::Sender<KeepaliveSignal>,
|
||||
ack_rx: mpsc::Receiver<()>,
|
||||
quality_tx: watch::Sender<ConnectionQuality>,
|
||||
consecutive_upgrade_checks: u32,
|
||||
}
|
||||
|
||||
/// Handle returned to the caller to send ACKs and receive signals.
|
||||
pub struct KeepaliveHandle {
|
||||
pub signal_rx: mpsc::Receiver<KeepaliveSignal>,
|
||||
pub ack_tx: mpsc::Sender<()>,
|
||||
pub quality_rx: watch::Receiver<ConnectionQuality>,
|
||||
}
|
||||
|
||||
/// Create a keepalive monitor and its handle.
|
||||
/// Create an adaptive keepalive monitor and its handle.
|
||||
pub fn create_keepalive(
|
||||
keepalive_interval: Option<Duration>,
|
||||
keepalive_timeout: Option<Duration>,
|
||||
config: Option<AdaptiveKeepaliveConfig>,
|
||||
) -> (KeepaliveMonitor, KeepaliveHandle) {
|
||||
let config = config.unwrap_or_default();
|
||||
let (signal_tx, signal_rx) = mpsc::channel(8);
|
||||
let (ack_tx, ack_rx) = mpsc::channel(8);
|
||||
let (quality_tx, quality_rx) = watch::channel(ConnectionQuality::default());
|
||||
|
||||
let monitor = KeepaliveMonitor {
|
||||
interval: keepalive_interval.unwrap_or(DEFAULT_KEEPALIVE_INTERVAL),
|
||||
timeout_duration: keepalive_timeout.unwrap_or(DEFAULT_KEEPALIVE_TIMEOUT),
|
||||
config,
|
||||
health: LinkHealth::Degraded, // start in Degraded, earn Healthy
|
||||
rtt_tracker: RttTracker::new(30),
|
||||
signal_tx,
|
||||
ack_rx,
|
||||
quality_tx,
|
||||
consecutive_upgrade_checks: 0,
|
||||
};
|
||||
|
||||
let handle = KeepaliveHandle { signal_rx, ack_tx };
|
||||
let handle = KeepaliveHandle {
|
||||
signal_rx,
|
||||
ack_tx,
|
||||
quality_rx,
|
||||
};
|
||||
|
||||
(monitor, handle)
|
||||
}
|
||||
|
||||
impl KeepaliveMonitor {
|
||||
fn current_interval(&self) -> Duration {
|
||||
match self.health {
|
||||
LinkHealth::Healthy => self.config.healthy_interval,
|
||||
LinkHealth::Degraded => self.config.degraded_interval,
|
||||
LinkHealth::Critical => self.config.critical_interval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the keepalive loop. Blocks until the peer is dead or channels close.
|
||||
pub async fn run(mut self) {
|
||||
let mut ticker = interval(self.interval);
|
||||
let mut ticker = interval(self.current_interval());
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
debug!("Sending keepalive ping signal");
|
||||
|
||||
if self.signal_tx.send(KeepaliveSignal::SendPing).await.is_err() {
|
||||
// Channel closed
|
||||
break;
|
||||
// Record ping sent, get timestamp for payload
|
||||
let timestamp_ms = self.rtt_tracker.mark_ping_sent();
|
||||
debug!("Sending keepalive ping (ts={})", timestamp_ms);
|
||||
|
||||
if self
|
||||
.signal_tx
|
||||
.send(KeepaliveSignal::SendPing(timestamp_ms))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break; // channel closed
|
||||
}
|
||||
|
||||
// Wait for ACK within timeout
|
||||
match timeout(self.timeout_duration, self.ack_rx.recv()).await {
|
||||
match timeout(self.config.ack_timeout, self.ack_rx.recv()).await {
|
||||
Ok(Some(())) => {
|
||||
debug!("Keepalive ACK received");
|
||||
if let Some(rtt) = self.rtt_tracker.record_ack(timestamp_ms) {
|
||||
debug!("Keepalive ACK received, RTT: {:?}", rtt);
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Channel closed
|
||||
break;
|
||||
break; // channel closed
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Keepalive ACK timeout — peer considered dead");
|
||||
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
|
||||
break;
|
||||
self.rtt_tracker.record_timeout();
|
||||
warn!(
|
||||
"Keepalive ACK timeout (consecutive: {})",
|
||||
self.rtt_tracker.consecutive_timeouts
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Publish quality snapshot
|
||||
let quality = self.rtt_tracker.snapshot();
|
||||
let _ = self.quality_tx.send(quality.clone());
|
||||
|
||||
// Evaluate state transition
|
||||
let new_health = self.evaluate_health(&quality);
|
||||
|
||||
if new_health != self.health {
|
||||
info!("Link health: {} -> {}", self.health, new_health);
|
||||
self.health = new_health;
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
|
||||
// Reset ticker to new interval
|
||||
ticker = interval(self.current_interval());
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
let _ = self
|
||||
.signal_tx
|
||||
.send(KeepaliveSignal::LinkHealthChanged(new_health))
|
||||
.await;
|
||||
}
|
||||
|
||||
// Check for dead peer in Critical state
|
||||
if self.health == LinkHealth::Critical
|
||||
&& self.rtt_tracker.consecutive_timeouts >= self.config.dead_peer_timeouts
|
||||
{
|
||||
warn!("Peer considered dead after {} consecutive timeouts in Critical state",
|
||||
self.rtt_tracker.consecutive_timeouts);
|
||||
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_health(&mut self, quality: &ConnectionQuality) -> LinkHealth {
|
||||
match self.health {
|
||||
LinkHealth::Healthy => {
|
||||
// Downgrade conditions
|
||||
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Critical;
|
||||
}
|
||||
if quality.jitter_ms > self.config.jitter_degraded_ms
|
||||
|| quality.loss_ratio > self.config.loss_degraded
|
||||
|| quality.consecutive_timeouts >= 1
|
||||
{
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Degraded;
|
||||
}
|
||||
LinkHealth::Healthy
|
||||
}
|
||||
LinkHealth::Degraded => {
|
||||
// Downgrade to Critical
|
||||
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Critical;
|
||||
}
|
||||
// Upgrade to Healthy (with hysteresis)
|
||||
if quality.jitter_ms < self.config.jitter_healthy_ms
|
||||
&& quality.loss_ratio < self.config.loss_healthy
|
||||
&& quality.consecutive_timeouts == 0
|
||||
{
|
||||
self.consecutive_upgrade_checks += 1;
|
||||
if self.consecutive_upgrade_checks >= self.config.upgrade_checks {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Healthy;
|
||||
}
|
||||
} else {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
}
|
||||
LinkHealth::Degraded
|
||||
}
|
||||
LinkHealth::Critical => {
|
||||
// Upgrade to Degraded (with hysteresis), never directly to Healthy
|
||||
if quality.loss_ratio < self.config.loss_recover
|
||||
&& quality.consecutive_timeouts == 0
|
||||
{
|
||||
self.consecutive_upgrade_checks += 1;
|
||||
if self.consecutive_upgrade_checks >= 2 {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Degraded;
|
||||
}
|
||||
} else {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
}
|
||||
LinkHealth::Critical
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let config = AdaptiveKeepaliveConfig::default();
|
||||
assert_eq!(config.healthy_interval, Duration::from_secs(60));
|
||||
assert_eq!(config.degraded_interval, Duration::from_secs(30));
|
||||
assert_eq!(config.critical_interval, Duration::from_secs(10));
|
||||
assert_eq!(config.ack_timeout, Duration::from_secs(5));
|
||||
assert_eq!(config.dead_peer_timeouts, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn link_health_display() {
|
||||
assert_eq!(format!("{}", LinkHealth::Healthy), "healthy");
|
||||
assert_eq!(format!("{}", LinkHealth::Degraded), "degraded");
|
||||
assert_eq!(format!("{}", LinkHealth::Critical), "critical");
|
||||
}
|
||||
|
||||
// Helper to create a monitor for unit-testing evaluate_health
|
||||
fn make_test_monitor() -> KeepaliveMonitor {
|
||||
let (signal_tx, _signal_rx) = mpsc::channel(8);
|
||||
let (_ack_tx, ack_rx) = mpsc::channel(8);
|
||||
let (quality_tx, _quality_rx) = watch::channel(ConnectionQuality::default());
|
||||
|
||||
KeepaliveMonitor {
|
||||
config: AdaptiveKeepaliveConfig::default(),
|
||||
health: LinkHealth::Degraded,
|
||||
rtt_tracker: RttTracker::new(30),
|
||||
signal_tx,
|
||||
ack_rx,
|
||||
quality_tx,
|
||||
consecutive_upgrade_checks: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_degraded_on_jitter() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
jitter_ms: 60.0, // > 50ms threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_degraded_on_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.06, // > 5% threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_critical_on_high_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.25, // > 20% threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_critical_on_consecutive_timeouts() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
consecutive_timeouts: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_healthy_requires_hysteresis() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let good_quality = ConnectionQuality {
|
||||
jitter_ms: 10.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
srtt_ms: 20.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Should require 3 consecutive good checks (default upgrade_checks)
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 1);
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 2);
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Healthy);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_healthy_resets_on_bad_check() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let good = ConnectionQuality {
|
||||
jitter_ms: 10.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let bad = ConnectionQuality {
|
||||
jitter_ms: 60.0, // too high
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
m.evaluate_health(&good); // 1 check
|
||||
m.evaluate_health(&good); // 2 checks
|
||||
m.evaluate_health(&bad); // resets
|
||||
assert_eq!(m.consecutive_upgrade_checks, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn critical_to_degraded_requires_hysteresis() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Critical;
|
||||
let recovering = ConnectionQuality {
|
||||
loss_ratio: 0.05, // < 10% recover threshold
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Critical);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 1);
|
||||
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn critical_never_directly_to_healthy() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Critical;
|
||||
let perfect = ConnectionQuality {
|
||||
jitter_ms: 1.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
srtt_ms: 10.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Even with perfect quality, must go through Degraded first
|
||||
m.evaluate_health(&perfect); // 1
|
||||
let result = m.evaluate_health(&perfect); // 2 → Degraded
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
// Not Healthy yet
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_critical_on_high_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.25,
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(m.evaluate_health(&q), LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interval_matches_health() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(60));
|
||||
m.health = LinkHealth::Degraded;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(30));
|
||||
m.health = LinkHealth::Critical;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(10));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,9 +5,18 @@ pub mod management;
|
||||
pub mod codec;
|
||||
pub mod crypto;
|
||||
pub mod transport;
|
||||
pub mod transport_trait;
|
||||
pub mod quic_transport;
|
||||
pub mod keepalive;
|
||||
pub mod tunnel;
|
||||
pub mod network;
|
||||
pub mod server;
|
||||
pub mod client;
|
||||
pub mod reconnect;
|
||||
pub mod telemetry;
|
||||
pub mod ratelimit;
|
||||
pub mod qos;
|
||||
pub mod mtu;
|
||||
pub mod wireguard;
|
||||
pub mod client_registry;
|
||||
pub mod acl;
|
||||
|
||||
@@ -7,6 +7,7 @@ use tracing::{info, error, warn};
|
||||
use crate::client::{ClientConfig, VpnClient};
|
||||
use crate::crypto;
|
||||
use crate::server::{ServerConfig, VpnServer};
|
||||
use crate::wireguard::{self, WgClient, WgClientConfig, WgPeerConfig, WgServer, WgServerConfig};
|
||||
|
||||
// ============================================================================
|
||||
// IPC protocol types
|
||||
@@ -93,6 +94,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
|
||||
let mut vpn_client = VpnClient::new();
|
||||
let mut vpn_server = VpnServer::new();
|
||||
let mut wg_client = WgClient::new();
|
||||
let mut wg_server = WgServer::new();
|
||||
|
||||
send_event_stdout("ready", serde_json::json!({ "mode": mode }));
|
||||
|
||||
@@ -127,8 +130,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
};
|
||||
|
||||
let response = match mode {
|
||||
"client" => handle_client_request(&request, &mut vpn_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server).await,
|
||||
"client" => handle_client_request(&request, &mut vpn_client, &mut wg_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server, &mut wg_server).await,
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
send_response_stdout(&response);
|
||||
@@ -150,6 +153,8 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
// Shared state behind Mutex for socket mode (multiple connections)
|
||||
let vpn_client = std::sync::Arc::new(Mutex::new(VpnClient::new()));
|
||||
let vpn_server = std::sync::Arc::new(Mutex::new(VpnServer::new()));
|
||||
let wg_client = std::sync::Arc::new(Mutex::new(WgClient::new()));
|
||||
let wg_server = std::sync::Arc::new(Mutex::new(WgServer::new()));
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
@@ -157,9 +162,11 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
let mode = mode.to_string();
|
||||
let client = vpn_client.clone();
|
||||
let server = vpn_server.clone();
|
||||
let wg_c = wg_client.clone();
|
||||
let wg_s = wg_server.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
handle_socket_connection(stream, &mode, client, server).await
|
||||
handle_socket_connection(stream, &mode, client, server, wg_c, wg_s).await
|
||||
{
|
||||
warn!("Socket connection error: {}", e);
|
||||
}
|
||||
@@ -177,6 +184,8 @@ async fn handle_socket_connection(
|
||||
mode: &str,
|
||||
vpn_client: std::sync::Arc<Mutex<VpnClient>>,
|
||||
vpn_server: std::sync::Arc<Mutex<VpnServer>>,
|
||||
wg_client: std::sync::Arc<Mutex<WgClient>>,
|
||||
wg_server: std::sync::Arc<Mutex<WgServer>>,
|
||||
) -> Result<()> {
|
||||
let (reader, mut writer) = stream.into_split();
|
||||
let buf_reader = BufReader::new(reader);
|
||||
@@ -227,11 +236,13 @@ async fn handle_socket_connection(
|
||||
let response = match mode {
|
||||
"client" => {
|
||||
let mut client = vpn_client.lock().await;
|
||||
handle_client_request(&request, &mut client).await
|
||||
let mut wg_c = wg_client.lock().await;
|
||||
handle_client_request(&request, &mut client, &mut wg_c).await
|
||||
}
|
||||
"server" => {
|
||||
let mut server = vpn_server.lock().await;
|
||||
handle_server_request(&request, &mut server).await
|
||||
let mut wg_s = wg_server.lock().await;
|
||||
handle_server_request(&request, &mut server, &mut wg_s).await
|
||||
}
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
@@ -252,38 +263,112 @@ async fn handle_socket_connection(
|
||||
async fn handle_client_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_client: &mut VpnClient,
|
||||
wg_client: &mut WgClient,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"connect" => {
|
||||
let config: ClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
// Check if transport is "wireguard"
|
||||
let transport = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transport"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match vpn_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
if transport == "wireguard" {
|
||||
let config: WgClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("WG connect failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"disconnect" => {
|
||||
if wg_client.is_running() {
|
||||
match wg_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG disconnect failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disconnect" => match vpn_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
|
||||
},
|
||||
"getStatus" => {
|
||||
let status = vpn_client.get_status().await;
|
||||
ManagementResponse::ok(id, status)
|
||||
if wg_client.is_running() {
|
||||
ManagementResponse::ok(id, wg_client.get_status().await)
|
||||
} else {
|
||||
let status = vpn_client.get_status().await;
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
}
|
||||
"getStatistics" => {
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
if wg_client.is_running() {
|
||||
ManagementResponse::ok(id, wg_client.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
}
|
||||
}
|
||||
"getConnectionQuality" => {
|
||||
match vpn_client.get_connection_quality() {
|
||||
Some(quality) => {
|
||||
let health = vpn_client.get_link_health().await;
|
||||
let interval_secs = match health {
|
||||
crate::keepalive::LinkHealth::Healthy => 60,
|
||||
crate::keepalive::LinkHealth::Degraded => 30,
|
||||
crate::keepalive::LinkHealth::Critical => 10,
|
||||
};
|
||||
ManagementResponse::ok(id, serde_json::json!({
|
||||
"srttMs": quality.srtt_ms,
|
||||
"jitterMs": quality.jitter_ms,
|
||||
"minRttMs": quality.min_rtt_ms,
|
||||
"maxRttMs": quality.max_rtt_ms,
|
||||
"lossRatio": quality.loss_ratio,
|
||||
"consecutiveTimeouts": quality.consecutive_timeouts,
|
||||
"linkHealth": format!("{}", health),
|
||||
"currentKeepaliveIntervalSecs": interval_secs,
|
||||
}))
|
||||
}
|
||||
None => ManagementResponse::ok(id, serde_json::json!(null)),
|
||||
}
|
||||
}
|
||||
"getMtuInfo" => {
|
||||
ManagementResponse::ok(id, serde_json::json!({
|
||||
"tunMtu": 1420,
|
||||
"effectiveMtu": crate::mtu::TunnelOverhead::default_overhead().effective_tun_mtu(1500),
|
||||
"linkMtu": 1500,
|
||||
"overheadBytes": crate::mtu::TunnelOverhead::default_overhead().total(),
|
||||
"oversizedPacketsDropped": 0,
|
||||
"icmpTooBigSent": 0,
|
||||
}))
|
||||
}
|
||||
_ => ManagementResponse::err(id, format!("Unknown client method: {}", request.method)),
|
||||
}
|
||||
@@ -296,45 +381,92 @@ async fn handle_client_request(
|
||||
async fn handle_server_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_server: &mut VpnServer,
|
||||
wg_server: &mut WgServer,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"start" => {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
// Check if transportMode is "wireguard"
|
||||
let transport_mode = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transportMode"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
if transport_mode == "wireguard" {
|
||||
let config: WgServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG start failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => {
|
||||
if wg_server.is_running() {
|
||||
match wg_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG stop failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
},
|
||||
"getStatus" => {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_status())
|
||||
} else {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
}
|
||||
"getStatistics" => {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"listClients" => {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
if wg_server.is_running() {
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"disconnectClient" => {
|
||||
@@ -349,6 +481,50 @@ async fn handle_server_request(
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"setClientRateLimit" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let rate = match request.params.get("rateBytesPerSec").and_then(|v| v.as_u64()) {
|
||||
Some(r) => r,
|
||||
None => return ManagementResponse::err(id, "Missing rateBytesPerSec".to_string()),
|
||||
};
|
||||
let burst = match request.params.get("burstBytes").and_then(|v| v.as_u64()) {
|
||||
Some(b) => b,
|
||||
None => return ManagementResponse::err(id, "Missing burstBytes".to_string()),
|
||||
};
|
||||
match vpn_server.set_client_rate_limit(&client_id, rate, burst).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClientRateLimit" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_client_rate_limit(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClientTelemetry" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match clients.into_iter().find(|c| c.client_id == client_id) {
|
||||
Some(info) => {
|
||||
match serde_json::to_value(&info) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id, format!("Client {} not found", client_id)),
|
||||
}
|
||||
}
|
||||
"generateKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
@@ -359,6 +535,153 @@ async fn handle_server_request(
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
"generateWgKeypair" => {
|
||||
let (public_key, private_key) = wireguard::generate_wg_keypair();
|
||||
ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
)
|
||||
}
|
||||
"addWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let config: WgPeerConfig = match serde_json::from_value(
|
||||
request.params.get("peer").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid peer config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.add_peer(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Add peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let public_key = match request.params.get("publicKey").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing publicKey".to_string()),
|
||||
};
|
||||
match wg_server.remove_peer(&public_key).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listWgPeers" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "peers": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
// ── Client Registry (Hub) Commands ────────────────────────────────
|
||||
"createClient" => {
|
||||
let client_partial = request.params.get("client").cloned().unwrap_or_default();
|
||||
match vpn_server.create_client(client_partial).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Create client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_registered_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.get_registered_client(&client_id).await {
|
||||
Ok(entry) => ManagementResponse::ok(id, entry),
|
||||
Err(e) => ManagementResponse::err(id, format!("Get client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listRegisteredClients" => {
|
||||
let clients = vpn_server.list_registered_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
"updateClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let update = request.params.get("update").cloned().unwrap_or_default();
|
||||
match vpn_server.update_registered_client(&client_id, update).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Update client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"enableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.enable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Enable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.disable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"rotateClientKey" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.rotate_client_key(&client_id).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Key rotation failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"exportClientConfig" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let format = request.params.get("format").and_then(|v| v.as_str()).unwrap_or("smartvpn");
|
||||
match vpn_server.export_client_config(&client_id, format).await {
|
||||
Ok(config) => ManagementResponse::ok(id, config),
|
||||
Err(e) => ManagementResponse::err(id, format!("Export failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"generateClientKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
_ => ManagementResponse::err(id, format!("Unknown server method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
314
rust/src/mtu.rs
Normal file
314
rust/src/mtu.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
/// Overhead breakdown for VPN tunnel encapsulation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TunnelOverhead {
|
||||
/// Outer IP header: 20 bytes (IPv4, no options).
|
||||
pub ip_header: u16,
|
||||
/// TCP header: typically 32 bytes (20 base + 12 for timestamps).
|
||||
pub tcp_header: u16,
|
||||
/// WebSocket framing: ~6 bytes (2 base + 4 mask from client).
|
||||
pub ws_framing: u16,
|
||||
/// VPN binary frame header: 5 bytes [type:1B][length:4B].
|
||||
pub vpn_header: u16,
|
||||
/// Noise AEAD tag: 16 bytes (Poly1305).
|
||||
pub noise_tag: u16,
|
||||
}
|
||||
|
||||
impl TunnelOverhead {
|
||||
/// Conservative default overhead estimate.
|
||||
pub fn default_overhead() -> Self {
|
||||
Self {
|
||||
ip_header: 20,
|
||||
tcp_header: 32,
|
||||
ws_framing: 6,
|
||||
vpn_header: 5,
|
||||
noise_tag: 16,
|
||||
}
|
||||
}
|
||||
|
||||
/// Total encapsulation overhead in bytes.
|
||||
pub fn total(&self) -> u16 {
|
||||
self.ip_header + self.tcp_header + self.ws_framing + self.vpn_header + self.noise_tag
|
||||
}
|
||||
|
||||
/// Compute effective TUN MTU given the underlying link MTU.
|
||||
pub fn effective_tun_mtu(&self, link_mtu: u16) -> u16 {
|
||||
link_mtu.saturating_sub(self.total())
|
||||
}
|
||||
}
|
||||
|
||||
/// MTU configuration for the VPN tunnel.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MtuConfig {
|
||||
/// Underlying link MTU (typically 1500 for Ethernet).
|
||||
pub link_mtu: u16,
|
||||
/// Computed effective TUN MTU.
|
||||
pub effective_mtu: u16,
|
||||
/// Whether to generate ICMP too-big for oversized packets.
|
||||
pub send_icmp_too_big: bool,
|
||||
/// Counter: oversized packets encountered.
|
||||
pub oversized_packets: u64,
|
||||
/// Counter: ICMP too-big messages generated.
|
||||
pub icmp_too_big_sent: u64,
|
||||
}
|
||||
|
||||
impl MtuConfig {
|
||||
/// Create a new MTU config from the underlying link MTU.
|
||||
pub fn new(link_mtu: u16) -> Self {
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let effective = overhead.effective_tun_mtu(link_mtu);
|
||||
Self {
|
||||
link_mtu,
|
||||
effective_mtu: effective,
|
||||
send_icmp_too_big: true,
|
||||
oversized_packets: 0,
|
||||
icmp_too_big_sent: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a packet exceeds the effective MTU.
|
||||
pub fn is_oversized(&self, packet_len: usize) -> bool {
|
||||
packet_len > self.effective_mtu as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Action to take after checking MTU.
|
||||
pub enum MtuAction {
|
||||
/// Packet is within MTU, forward normally.
|
||||
Forward,
|
||||
/// Packet is oversized; contains the ICMP too-big message to write back into TUN.
|
||||
SendIcmpTooBig(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Check packet against MTU config and return the appropriate action.
|
||||
pub fn check_mtu(packet: &[u8], config: &MtuConfig) -> MtuAction {
|
||||
if !config.is_oversized(packet.len()) {
|
||||
return MtuAction::Forward;
|
||||
}
|
||||
|
||||
if !config.send_icmp_too_big {
|
||||
return MtuAction::Forward;
|
||||
}
|
||||
|
||||
match generate_icmp_too_big(packet, config.effective_mtu) {
|
||||
Some(icmp) => MtuAction::SendIcmpTooBig(icmp),
|
||||
None => MtuAction::Forward,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate an ICMPv4 Destination Unreachable / Fragmentation Needed message.
|
||||
///
|
||||
/// Per RFC 792: Type 3, Code 4, with next-hop MTU in bytes 6-7 (RFC 1191).
|
||||
/// Returns the complete IP + ICMP packet to write back into the TUN device.
|
||||
pub fn generate_icmp_too_big(original_packet: &[u8], next_hop_mtu: u16) -> Option<Vec<u8>> {
|
||||
// Need at least 20 bytes of original IP header
|
||||
if original_packet.len() < 20 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Verify it's IPv4
|
||||
if original_packet[0] >> 4 != 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Parse source/dest from original IP header
|
||||
let src_ip = Ipv4Addr::new(
|
||||
original_packet[12],
|
||||
original_packet[13],
|
||||
original_packet[14],
|
||||
original_packet[15],
|
||||
);
|
||||
let dst_ip = Ipv4Addr::new(
|
||||
original_packet[16],
|
||||
original_packet[17],
|
||||
original_packet[18],
|
||||
original_packet[19],
|
||||
);
|
||||
|
||||
// ICMP payload: IP header + first 8 bytes of original datagram (per RFC 792)
|
||||
let icmp_data_len = original_packet.len().min(28); // 20 IP header + 8 bytes
|
||||
let icmp_payload = &original_packet[..icmp_data_len];
|
||||
|
||||
// Build ICMP message: type(1) + code(1) + checksum(2) + unused(2) + next_hop_mtu(2) + data
|
||||
let mut icmp = Vec::with_capacity(8 + icmp_data_len);
|
||||
icmp.push(3); // Type: Destination Unreachable
|
||||
icmp.push(4); // Code: Fragmentation Needed and DF was Set
|
||||
icmp.push(0); // Checksum placeholder
|
||||
icmp.push(0);
|
||||
icmp.push(0); // Unused
|
||||
icmp.push(0);
|
||||
icmp.extend_from_slice(&next_hop_mtu.to_be_bytes());
|
||||
icmp.extend_from_slice(icmp_payload);
|
||||
|
||||
// Compute ICMP checksum
|
||||
let cksum = internet_checksum(&icmp);
|
||||
icmp[2] = (cksum >> 8) as u8;
|
||||
icmp[3] = (cksum & 0xff) as u8;
|
||||
|
||||
// Build IP header (ICMP response: FROM tunnel gateway TO original source)
|
||||
let total_len = (20 + icmp.len()) as u16;
|
||||
let mut ip = Vec::with_capacity(total_len as usize);
|
||||
ip.push(0x45); // Version 4, IHL 5
|
||||
ip.push(0x00); // DSCP/ECN
|
||||
ip.extend_from_slice(&total_len.to_be_bytes());
|
||||
ip.extend_from_slice(&[0, 0]); // Identification
|
||||
ip.extend_from_slice(&[0x40, 0x00]); // Flags: Don't Fragment, Fragment Offset: 0
|
||||
ip.push(64); // TTL
|
||||
ip.push(1); // Protocol: ICMP
|
||||
ip.extend_from_slice(&[0, 0]); // Header checksum placeholder
|
||||
ip.extend_from_slice(&dst_ip.octets()); // Source: tunnel endpoint (was dst)
|
||||
ip.extend_from_slice(&src_ip.octets()); // Destination: original source
|
||||
|
||||
// Compute IP header checksum
|
||||
let ip_cksum = internet_checksum(&ip[..20]);
|
||||
ip[10] = (ip_cksum >> 8) as u8;
|
||||
ip[11] = (ip_cksum & 0xff) as u8;
|
||||
|
||||
ip.extend_from_slice(&icmp);
|
||||
Some(ip)
|
||||
}
|
||||
|
||||
/// Standard Internet checksum (RFC 1071).
|
||||
fn internet_checksum(data: &[u8]) -> u16 {
|
||||
let mut sum: u32 = 0;
|
||||
let mut i = 0;
|
||||
while i + 1 < data.len() {
|
||||
sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
|
||||
i += 2;
|
||||
}
|
||||
if i < data.len() {
|
||||
sum += (data[i] as u32) << 8;
|
||||
}
|
||||
while sum >> 16 != 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16);
|
||||
}
|
||||
!sum as u16
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_overhead_total() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
assert_eq!(oh.total(), 79); // 20+32+6+5+16
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_mtu_for_ethernet() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
let mtu = oh.effective_tun_mtu(1500);
|
||||
assert_eq!(mtu, 1421); // 1500 - 79
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_mtu_saturates_at_zero() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
let mtu = oh.effective_tun_mtu(50); // Less than overhead
|
||||
assert_eq!(mtu, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mtu_config_default() {
|
||||
let config = MtuConfig::new(1500);
|
||||
assert_eq!(config.effective_mtu, 1421);
|
||||
assert_eq!(config.link_mtu, 1500);
|
||||
assert!(config.send_icmp_too_big);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_oversized() {
|
||||
let config = MtuConfig::new(1500);
|
||||
assert!(!config.is_oversized(1421));
|
||||
assert!(config.is_oversized(1422));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_generation() {
|
||||
// Craft a minimal IPv4 packet
|
||||
let mut original = vec![0u8; 28];
|
||||
original[0] = 0x45; // version 4, IHL 5
|
||||
original[2..4].copy_from_slice(&1500u16.to_be_bytes()); // total length
|
||||
original[9] = 6; // TCP
|
||||
original[12..16].copy_from_slice(&[10, 0, 0, 1]); // src IP
|
||||
original[16..20].copy_from_slice(&[10, 0, 0, 2]); // dst IP
|
||||
|
||||
let icmp_pkt = generate_icmp_too_big(&original, 1421).unwrap();
|
||||
|
||||
// Verify it's a valid IPv4 packet
|
||||
assert_eq!(icmp_pkt[0] >> 4, 4); // IPv4
|
||||
assert_eq!(icmp_pkt[9], 1); // ICMP protocol
|
||||
|
||||
// Source should be original dst (10.0.0.2)
|
||||
assert_eq!(&icmp_pkt[12..16], &[10, 0, 0, 2]);
|
||||
// Destination should be original src (10.0.0.1)
|
||||
assert_eq!(&icmp_pkt[16..20], &[10, 0, 0, 1]);
|
||||
|
||||
// ICMP type 3, code 4
|
||||
assert_eq!(icmp_pkt[20], 3);
|
||||
assert_eq!(icmp_pkt[21], 4);
|
||||
|
||||
// Next-hop MTU at ICMP bytes 6-7 (offset 26-27 in IP packet)
|
||||
let mtu = u16::from_be_bytes([icmp_pkt[26], icmp_pkt[27]]);
|
||||
assert_eq!(mtu, 1421);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_rejects_short_packet() {
|
||||
let short = vec![0u8; 10];
|
||||
assert!(generate_icmp_too_big(&short, 1421).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_rejects_non_ipv4() {
|
||||
let mut pkt = vec![0u8; 40];
|
||||
pkt[0] = 0x60; // IPv6
|
||||
assert!(generate_icmp_too_big(&pkt, 1421).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_checksum_valid() {
|
||||
let mut original = vec![0u8; 28];
|
||||
original[0] = 0x45;
|
||||
original[2..4].copy_from_slice(&1500u16.to_be_bytes());
|
||||
original[9] = 6;
|
||||
original[12..16].copy_from_slice(&[192, 168, 1, 100]);
|
||||
original[16..20].copy_from_slice(&[10, 8, 0, 1]);
|
||||
|
||||
let icmp_pkt = generate_icmp_too_big(&original, 1420).unwrap();
|
||||
|
||||
// Verify IP header checksum
|
||||
let ip_cksum = internet_checksum(&icmp_pkt[..20]);
|
||||
assert_eq!(ip_cksum, 0, "IP header checksum should verify to 0");
|
||||
|
||||
// Verify ICMP checksum
|
||||
let icmp_cksum = internet_checksum(&icmp_pkt[20..]);
|
||||
assert_eq!(icmp_cksum, 0, "ICMP checksum should verify to 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_mtu_forward() {
|
||||
let config = MtuConfig::new(1500);
|
||||
let pkt = vec![0u8; 1421]; // Exactly at MTU
|
||||
assert!(matches!(check_mtu(&pkt, &config), MtuAction::Forward));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_mtu_oversized_generates_icmp() {
|
||||
let config = MtuConfig::new(1500);
|
||||
let mut pkt = vec![0u8; 1500];
|
||||
pkt[0] = 0x45; // Valid IPv4
|
||||
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
|
||||
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
|
||||
|
||||
match check_mtu(&pkt, &config) {
|
||||
MtuAction::SendIcmpTooBig(icmp) => {
|
||||
assert_eq!(icmp[20], 3); // ICMP type
|
||||
assert_eq!(icmp[21], 4); // ICMP code
|
||||
}
|
||||
MtuAction::Forward => panic!("Expected SendIcmpTooBig"),
|
||||
}
|
||||
}
|
||||
}
|
||||
490
rust/src/qos.rs
Normal file
490
rust/src/qos.rs
Normal file
@@ -0,0 +1,490 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Priority levels for IP packets.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[repr(u8)]
|
||||
pub enum Priority {
|
||||
High = 0,
|
||||
Normal = 1,
|
||||
Low = 2,
|
||||
}
|
||||
|
||||
/// QoS statistics per priority level.
|
||||
pub struct QosStats {
|
||||
pub high_enqueued: AtomicU64,
|
||||
pub normal_enqueued: AtomicU64,
|
||||
pub low_enqueued: AtomicU64,
|
||||
pub high_dropped: AtomicU64,
|
||||
pub normal_dropped: AtomicU64,
|
||||
pub low_dropped: AtomicU64,
|
||||
}
|
||||
|
||||
impl QosStats {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
high_enqueued: AtomicU64::new(0),
|
||||
normal_enqueued: AtomicU64::new(0),
|
||||
low_enqueued: AtomicU64::new(0),
|
||||
high_dropped: AtomicU64::new(0),
|
||||
normal_dropped: AtomicU64::new(0),
|
||||
low_dropped: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QosStats {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Packet classification
|
||||
// ============================================================================
|
||||
|
||||
/// 5-tuple flow key for tracking bulk flows.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct FlowKey {
|
||||
src_ip: u32,
|
||||
dst_ip: u32,
|
||||
src_port: u16,
|
||||
dst_port: u16,
|
||||
protocol: u8,
|
||||
}
|
||||
|
||||
/// Per-flow state for bulk detection.
|
||||
struct FlowState {
|
||||
bytes_total: u64,
|
||||
window_start: Instant,
|
||||
}
|
||||
|
||||
/// Tracks per-flow byte counts for bulk flow detection.
|
||||
struct FlowTracker {
|
||||
flows: HashMap<FlowKey, FlowState>,
|
||||
window_duration: Duration,
|
||||
max_flows: usize,
|
||||
}
|
||||
|
||||
impl FlowTracker {
|
||||
fn new(window_duration: Duration, max_flows: usize) -> Self {
|
||||
Self {
|
||||
flows: HashMap::new(),
|
||||
window_duration,
|
||||
max_flows,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes for a flow. Returns true if the flow exceeds the threshold.
|
||||
fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
// Evict if at capacity — remove oldest entry
|
||||
if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
|
||||
if let Some(oldest_key) = self
|
||||
.flows
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.window_start)
|
||||
.map(|(k, _)| *k)
|
||||
{
|
||||
self.flows.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
|
||||
let state = self.flows.entry(key).or_insert(FlowState {
|
||||
bytes_total: 0,
|
||||
window_start: now,
|
||||
});
|
||||
|
||||
// Reset window if expired
|
||||
if now.duration_since(state.window_start) > self.window_duration {
|
||||
state.bytes_total = 0;
|
||||
state.window_start = now;
|
||||
}
|
||||
|
||||
state.bytes_total += bytes;
|
||||
state.bytes_total > threshold
|
||||
}
|
||||
}
|
||||
|
||||
/// Classifies raw IP packets into priority levels.
|
||||
pub struct PacketClassifier {
|
||||
flow_tracker: FlowTracker,
|
||||
/// Byte threshold for classifying a flow as "bulk" (Low priority).
|
||||
bulk_threshold_bytes: u64,
|
||||
}
|
||||
|
||||
impl PacketClassifier {
|
||||
/// Create a new classifier.
|
||||
///
|
||||
/// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB)
|
||||
pub fn new(bulk_threshold_bytes: u64) -> Self {
|
||||
Self {
|
||||
flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000),
|
||||
bulk_threshold_bytes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify a raw IPv4 packet into a priority level.
|
||||
///
|
||||
/// The packet must start with the IPv4 header (as read from a TUN device).
|
||||
pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
|
||||
// Need at least 20 bytes for a minimal IPv4 header
|
||||
if ip_packet.len() < 20 {
|
||||
return Priority::Normal;
|
||||
}
|
||||
|
||||
let version = ip_packet[0] >> 4;
|
||||
if version != 4 {
|
||||
return Priority::Normal; // Only classify IPv4 for now
|
||||
}
|
||||
|
||||
let ihl = (ip_packet[0] & 0x0F) as usize;
|
||||
let header_len = ihl * 4;
|
||||
let protocol = ip_packet[9];
|
||||
let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize;
|
||||
|
||||
// ICMP is always high priority
|
||||
if protocol == 1 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Small packets (<128 bytes) are high priority (likely interactive)
|
||||
if total_len < 128 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Extract ports for TCP (6) and UDP (17)
|
||||
let (src_port, dst_port) = if (protocol == 6 || protocol == 17)
|
||||
&& ip_packet.len() >= header_len + 4
|
||||
{
|
||||
let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]);
|
||||
let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]);
|
||||
(sp, dp)
|
||||
} else {
|
||||
(0, 0)
|
||||
};
|
||||
|
||||
// DNS (port 53) and SSH (port 22) are high priority
|
||||
if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Check for bulk flow
|
||||
if protocol == 6 || protocol == 17 {
|
||||
let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]);
|
||||
let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]);
|
||||
|
||||
let key = FlowKey {
|
||||
src_ip,
|
||||
dst_ip,
|
||||
src_port,
|
||||
dst_port,
|
||||
protocol,
|
||||
};
|
||||
|
||||
if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) {
|
||||
return Priority::Low;
|
||||
}
|
||||
}
|
||||
|
||||
Priority::Normal
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Priority channel set
|
||||
// ============================================================================
|
||||
|
||||
/// Error returned when a packet is dropped.
|
||||
#[derive(Debug)]
|
||||
pub enum PacketDropped {
|
||||
LowPriorityDrop,
|
||||
NormalPriorityDrop,
|
||||
HighPriorityDrop,
|
||||
ChannelClosed,
|
||||
}
|
||||
|
||||
/// Sending half of the priority channel set.
|
||||
pub struct PrioritySender {
|
||||
high_tx: mpsc::Sender<Vec<u8>>,
|
||||
normal_tx: mpsc::Sender<Vec<u8>>,
|
||||
low_tx: mpsc::Sender<Vec<u8>>,
|
||||
stats: Arc<QosStats>,
|
||||
}
|
||||
|
||||
impl PrioritySender {
|
||||
/// Send a packet with the given priority. Implements smart dropping under backpressure.
|
||||
pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
|
||||
let (tx, enqueued_counter) = match priority {
|
||||
Priority::High => (&self.high_tx, &self.stats.high_enqueued),
|
||||
Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
|
||||
Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
|
||||
};
|
||||
|
||||
match tx.try_send(packet) {
|
||||
Ok(()) => {
|
||||
enqueued_counter.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(packet)) => {
|
||||
self.handle_backpressure(packet, priority).await
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_backpressure(
|
||||
&self,
|
||||
packet: Vec<u8>,
|
||||
priority: Priority,
|
||||
) -> Result<(), PacketDropped> {
|
||||
match priority {
|
||||
Priority::Low => {
|
||||
self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::LowPriorityDrop)
|
||||
}
|
||||
Priority::Normal => {
|
||||
self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::NormalPriorityDrop)
|
||||
}
|
||||
Priority::High => {
|
||||
// Last resort: briefly wait for space, then drop
|
||||
match tokio::time::timeout(
|
||||
Duration::from_millis(5),
|
||||
self.high_tx.send(packet),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(())) => {
|
||||
self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::HighPriorityDrop)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Receiving half of the priority channel set.
|
||||
pub struct PriorityReceiver {
|
||||
high_rx: mpsc::Receiver<Vec<u8>>,
|
||||
normal_rx: mpsc::Receiver<Vec<u8>>,
|
||||
low_rx: mpsc::Receiver<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl PriorityReceiver {
|
||||
/// Receive the next packet, draining high-priority first (biased select).
|
||||
pub async fn recv(&mut self) -> Option<Vec<u8>> {
|
||||
tokio::select! {
|
||||
biased;
|
||||
Some(pkt) = self.high_rx.recv() => Some(pkt),
|
||||
Some(pkt) = self.normal_rx.recv() => Some(pkt),
|
||||
Some(pkt) = self.low_rx.recv() => Some(pkt),
|
||||
else => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a priority channel set split into sender and receiver halves.
|
||||
///
|
||||
/// - `high_cap`: capacity of the high-priority channel
|
||||
/// - `normal_cap`: capacity of the normal-priority channel
|
||||
/// - `low_cap`: capacity of the low-priority channel
|
||||
pub fn create_priority_channels(
|
||||
high_cap: usize,
|
||||
normal_cap: usize,
|
||||
low_cap: usize,
|
||||
) -> (PrioritySender, PriorityReceiver) {
|
||||
let (high_tx, high_rx) = mpsc::channel(high_cap);
|
||||
let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
|
||||
let (low_tx, low_rx) = mpsc::channel(low_cap);
|
||||
let stats = Arc::new(QosStats::new());
|
||||
|
||||
let sender = PrioritySender {
|
||||
high_tx,
|
||||
normal_tx,
|
||||
low_tx,
|
||||
stats,
|
||||
};
|
||||
|
||||
let receiver = PriorityReceiver {
|
||||
high_rx,
|
||||
normal_rx,
|
||||
low_rx,
|
||||
};
|
||||
|
||||
(sender, receiver)
|
||||
}
|
||||
|
||||
/// Get a reference to the QoS stats from a sender.
|
||||
impl PrioritySender {
|
||||
pub fn stats(&self) -> &Arc<QosStats> {
|
||||
&self.stats
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper: craft a minimal IPv4 packet
|
||||
fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
|
||||
let mut pkt = vec![0u8; total_len.max(24) as usize];
|
||||
pkt[0] = 0x45; // version 4, IHL 5
|
||||
pkt[2..4].copy_from_slice(&total_len.to_be_bytes());
|
||||
pkt[9] = protocol;
|
||||
// src IP
|
||||
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
|
||||
// dst IP
|
||||
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
|
||||
// ports (at offset 20 for IHL=5)
|
||||
pkt[20..22].copy_from_slice(&src_port.to_be_bytes());
|
||||
pkt[22..24].copy_from_slice(&dst_port.to_be_bytes());
|
||||
pkt
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_icmp_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_dns_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_ssh_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_small_packet_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_normal_http() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_bulk_flow_as_low() {
|
||||
let mut c = PacketClassifier::new(10_000); // Low threshold for testing
|
||||
|
||||
// Send enough traffic to exceed the threshold
|
||||
for _ in 0..100 {
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500);
|
||||
c.classify(&pkt);
|
||||
}
|
||||
|
||||
// Next packet from same flow should be Low
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500);
|
||||
assert_eq!(c.classify(&pkt), Priority::Low);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_too_short_packet() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = vec![0u8; 10]; // Too short for IPv4 header
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_non_ipv4() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let mut pkt = vec![0u8; 40];
|
||||
pkt[0] = 0x60; // IPv6 version nibble
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn priority_receiver_drains_high_first() {
|
||||
let (sender, mut receiver) = create_priority_channels(8, 8, 8);
|
||||
|
||||
// Enqueue in reverse order
|
||||
sender.send(vec![3], Priority::Low).await.unwrap();
|
||||
sender.send(vec![2], Priority::Normal).await.unwrap();
|
||||
sender.send(vec![1], Priority::High).await.unwrap();
|
||||
|
||||
// Should drain High first
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![1]);
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![2]);
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![3]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn smart_dropping_low_priority() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 8, 1);
|
||||
|
||||
// Fill the low channel
|
||||
sender.send(vec![0], Priority::Low).await.unwrap();
|
||||
|
||||
// Next low-priority send should be dropped
|
||||
let result = sender.send(vec![1], Priority::Low).await;
|
||||
assert!(matches!(result, Err(PacketDropped::LowPriorityDrop)));
|
||||
|
||||
assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn smart_dropping_normal_priority() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 1, 8);
|
||||
|
||||
sender.send(vec![0], Priority::Normal).await.unwrap();
|
||||
|
||||
let result = sender.send(vec![1], Priority::Normal).await;
|
||||
assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop)));
|
||||
|
||||
assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stats_track_enqueued() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 8, 8);
|
||||
|
||||
sender.send(vec![1], Priority::High).await.unwrap();
|
||||
sender.send(vec![2], Priority::High).await.unwrap();
|
||||
sender.send(vec![3], Priority::Normal).await.unwrap();
|
||||
sender.send(vec![4], Priority::Low).await.unwrap();
|
||||
|
||||
assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flow_tracker_evicts_at_capacity() {
|
||||
let mut ft = FlowTracker::new(Duration::from_secs(60), 2);
|
||||
|
||||
let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 };
|
||||
let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 };
|
||||
let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 };
|
||||
|
||||
ft.record(k1, 100, 1000);
|
||||
ft.record(k2, 100, 1000);
|
||||
// Should evict k1 (oldest)
|
||||
ft.record(k3, 100, 1000);
|
||||
|
||||
assert_eq!(ft.flows.len(), 2);
|
||||
assert!(!ft.flows.contains_key(&k1));
|
||||
}
|
||||
}
|
||||
546
rust/src/quic_transport.rs
Normal file
546
rust/src/quic_transport.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use quinn::crypto::rustls::QuicClientConfig;
|
||||
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn, debug};
|
||||
|
||||
use crate::transport_trait::{TransportSink, TransportStream};
|
||||
|
||||
// ============================================================================
|
||||
// TLS / Certificate helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Generate a self-signed certificate and private key for QUIC.
|
||||
pub fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["smartvpn".to_string()])?;
|
||||
let cert_der = CertificateDer::from(cert.cert);
|
||||
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()));
|
||||
Ok((vec![cert_der], key_der))
|
||||
}
|
||||
|
||||
/// Compute the SHA-256 hash of a DER-encoded certificate and return it as base64.
|
||||
pub fn cert_hash(cert_der: &CertificateDer<'_>) -> String {
|
||||
use ring::digest;
|
||||
let hash = digest::digest(&digest::SHA256, cert_der.as_ref());
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash.as_ref())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Server-side QUIC endpoint
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for the QUIC server endpoint.
|
||||
pub struct QuicServerConfig {
|
||||
pub listen_addr: String,
|
||||
pub cert_chain: Vec<CertificateDer<'static>>,
|
||||
pub private_key: PrivateKeyDer<'static>,
|
||||
pub idle_timeout_secs: u64,
|
||||
}
|
||||
|
||||
/// Create a QUIC server endpoint bound to the given address.
|
||||
pub fn create_quic_server(config: QuicServerConfig) -> Result<quinn::Endpoint> {
|
||||
let addr: SocketAddr = config.listen_addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let mut tls_config = rustls::ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(config.cert_chain, config.private_key)?;
|
||||
tls_config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
|
||||
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(
|
||||
quinn::IdleTimeout::try_from(Duration::from_secs(config.idle_timeout_secs))?,
|
||||
));
|
||||
// Enable datagrams with a generous max size
|
||||
transport.datagram_receive_buffer_size(Some(65535));
|
||||
transport.datagram_send_buffer_size(65535);
|
||||
server_config.transport_config(Arc::new(transport));
|
||||
|
||||
let endpoint = quinn::Endpoint::server(server_config, addr)?;
|
||||
info!("QUIC server listening on {}", addr);
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client-side QUIC connection
|
||||
// ============================================================================
|
||||
|
||||
/// A certificate verifier that accepts any server certificate.
|
||||
/// Safe when Noise NK provides server authentication at the application layer.
|
||||
#[derive(Debug)]
|
||||
struct AcceptAnyCert;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for AcceptAnyCert {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// A certificate verifier that accepts any certificate matching a given SHA-256 hash.
|
||||
#[derive(Debug)]
|
||||
struct CertHashVerifier {
|
||||
expected_hash: String,
|
||||
}
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for CertHashVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
let actual_hash = cert_hash(end_entity);
|
||||
if actual_hash == self.expected_hash {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
} else {
|
||||
Err(rustls::Error::General(format!(
|
||||
"Certificate hash mismatch: expected {}, got {}",
|
||||
self.expected_hash, actual_hash
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
// QUIC always uses TLS 1.3
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a QUIC server.
|
||||
///
|
||||
/// - If `server_cert_hash` is provided, verifies the server certificate matches
|
||||
/// the given SHA-256 hash (cert pinning).
|
||||
/// - If `server_cert_hash` is `None`, accepts any server certificate. This is
|
||||
/// safe because the Noise NK handshake (which runs over the QUIC stream)
|
||||
/// authenticates the server via its pre-shared public key — the same trust
|
||||
/// model as WireGuard.
|
||||
pub async fn connect_quic(
|
||||
addr: &str,
|
||||
server_cert_hash: Option<&str>,
|
||||
) -> Result<quinn::Connection> {
|
||||
let remote: SocketAddr = addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let tls_config = if let Some(hash) = server_cert_hash {
|
||||
// Pin to a specific certificate hash
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash.to_string(),
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
} else {
|
||||
// Accept any cert — Noise NK provides server authentication
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(AcceptAnyCert))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
};
|
||||
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse()?)?;
|
||||
endpoint.set_default_client_config(client_config);
|
||||
|
||||
info!("Connecting to QUIC server at {}", addr);
|
||||
let connection = endpoint.connect(remote, "smartvpn")?.await?;
|
||||
info!("QUIC connection established");
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QUIC Transport Sink / Stream implementations
|
||||
// ============================================================================
|
||||
|
||||
/// QUIC transport sink — wraps a SendStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportSink {
|
||||
send_stream: quinn::SendStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportSink {
|
||||
pub fn new(send_stream: quinn::SendStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
send_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for QuicTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// Length-prefix framing: [4-byte big-endian length][payload]
|
||||
let len = data.len() as u32;
|
||||
self.send_stream.write_all(&len.to_be_bytes()).await?;
|
||||
self.send_stream.write_all(&data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
let max_size = self.connection.max_datagram_size();
|
||||
match max_size {
|
||||
Some(max) if data.len() <= max => {
|
||||
self.connection.send_datagram(data.into())?;
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
// Datagram too large or datagrams disabled — fall back to reliable
|
||||
debug!("Datagram too large ({}B), falling back to reliable stream", data.len());
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.send_stream.finish()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// QUIC transport stream — wraps a RecvStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportStream {
|
||||
recv_stream: quinn::RecvStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportStream {
|
||||
pub fn new(recv_stream: quinn::RecvStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
recv_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for QuicTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// Read length prefix
|
||||
let mut len_buf = [0u8; 4];
|
||||
match self.recv_stream.read_exact(&mut len_buf).await {
|
||||
Ok(()) => {}
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None),
|
||||
Err(quinn::ReadExactError::ReadError(quinn::ReadError::ConnectionLost(e))) => {
|
||||
warn!("QUIC connection lost: {}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
|
||||
let len = u32::from_be_bytes(len_buf) as usize;
|
||||
if len > 65536 {
|
||||
return Err(anyhow::anyhow!("Frame too large: {} bytes", len));
|
||||
}
|
||||
|
||||
let mut data = vec![0u8; len];
|
||||
match self.recv_stream.read_exact(&mut data).await {
|
||||
Ok(()) => Ok(Some(data)),
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
match self.connection.read_datagram().await {
|
||||
Ok(data) => Ok(Some(data.to_vec())),
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => Ok(None),
|
||||
Err(quinn::ConnectionError::LocallyClosed) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC datagram error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
self.connection.max_datagram_size().is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Accept a QUIC connection and open a bidirectional control stream.
|
||||
/// Returns the transport sink/stream pair ready for the VPN handshake.
|
||||
pub async fn accept_quic_connection(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
// The client opens the bidirectional control stream
|
||||
let (send, recv) = conn.accept_bi().await?;
|
||||
info!("QUIC bidirectional control stream accepted");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
/// Open a QUIC connection's bidirectional control stream (client side).
|
||||
pub async fn open_quic_streams(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
let (send, recv) = conn.open_bi().await?;
|
||||
info!("QUIC bidirectional control stream opened");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cert_generation_and_hash() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
assert_eq!(certs.len(), 1);
|
||||
let hash = cert_hash(&certs[0]);
|
||||
// SHA-256 base64 is 44 characters
|
||||
assert_eq!(hash.len(), 44);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cert_hash_deterministic() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
let hash1 = cert_hash(&certs[0]);
|
||||
let hash2 = cert_hash(&certs[0]);
|
||||
assert_eq!(hash1, hash2);
|
||||
}
|
||||
|
||||
/// Helper: create QUIC server and client endpoints.
|
||||
fn create_quic_endpoints() -> (quinn::Endpoint, quinn::Endpoint, String) {
|
||||
let (certs, key) = generate_self_signed_cert().unwrap();
|
||||
let hash = cert_hash(&certs[0]);
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
|
||||
let mut server_tls = rustls::ServerConfig::builder_with_provider(provider.clone())
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key).unwrap();
|
||||
server_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let server_qcfg = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(server_tls).unwrap(),
|
||||
));
|
||||
let server_ep = quinn::Endpoint::server(server_qcfg, "127.0.0.1:0".parse().unwrap()).unwrap();
|
||||
|
||||
let mut client_tls = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash,
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
client_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(client_tls).unwrap(),
|
||||
));
|
||||
let mut client_ep = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
|
||||
client_ep.set_default_client_config(client_config);
|
||||
|
||||
let server_addr = server_ep.local_addr().unwrap().to_string();
|
||||
(server_ep, client_ep, server_addr)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_server_client_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi, read, echo, finish
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (mut s_send, mut s_recv) = conn.accept_bi().await.unwrap();
|
||||
let data = s_recv.read_to_end(1024).await.unwrap();
|
||||
s_send.write_all(&data).await.unwrap();
|
||||
s_send.finish().unwrap();
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open_bi, write, finish, read
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_send, mut c_recv) = conn.open_bi().await.unwrap();
|
||||
c_send.write_all(b"hello quinn").await.unwrap();
|
||||
c_send.finish().unwrap();
|
||||
let data = c_recv.read_to_end(1024).await.unwrap();
|
||||
assert_eq!(&data[..], b"hello quinn");
|
||||
|
||||
let _ = server_task.await;
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test transport trait wrappers over QUIC.
|
||||
/// Key: client must send data first (QUIC streams are opened implicitly by data).
|
||||
/// The server accept_bi runs concurrently with the client's first send_reliable.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_transport_trait_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server task: accept connection, then accept_bi (blocks until client sends data)
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
(s_sink, s_stream, server_ep)
|
||||
});
|
||||
|
||||
// Client: connect, open_bi via wrapper
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, mut c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
|
||||
// Client sends first — this triggers the QUIC stream to become visible to the server
|
||||
c_sink.send_reliable(b"hello-from-client".to_vec()).await.unwrap();
|
||||
|
||||
// Now server's accept_bi unblocks
|
||||
let (mut s_sink, mut s_stream, _sep) = server_task.await.unwrap();
|
||||
|
||||
// Server reads the message
|
||||
let msg = s_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-client");
|
||||
|
||||
// Server -> Client
|
||||
s_sink.send_reliable(b"hello-from-server".to_vec()).await.unwrap();
|
||||
let msg = c_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-server");
|
||||
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test QUIC datagram support.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_datagram_exchange() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi (opens control stream), then read datagram
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
// Accept bi stream (control channel)
|
||||
let (_s_sink, _s_stream) = accept_quic_connection(conn.clone()).await.unwrap();
|
||||
// Read datagram
|
||||
let dgram = conn.read_datagram().await.unwrap();
|
||||
assert_eq!(&dgram[..], b"dgram-payload");
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open bi stream (triggers server accept_bi), then send datagram
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, _c_stream) = open_quic_streams(conn.clone()).await.unwrap();
|
||||
|
||||
// Send initial data to open the stream (required for QUIC)
|
||||
c_sink.send_reliable(b"init".to_vec()).await.unwrap();
|
||||
|
||||
// Small yield to let the server process the bi stream
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
// Send datagram
|
||||
assert!(conn.max_datagram_size().is_some());
|
||||
conn.send_datagram(bytes::Bytes::from_static(b"dgram-payload")).unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test that supports_datagrams returns true for QUIC transports.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_supports_datagrams() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (_s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
assert!(s_stream.supports_datagrams());
|
||||
server_ep
|
||||
});
|
||||
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
assert!(c_stream.supports_datagrams());
|
||||
|
||||
// Send data to trigger server's accept_bi
|
||||
c_sink.send_reliable(b"ping".to_vec()).await.unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
}
|
||||
141
rust/src/ratelimit.rs
Normal file
141
rust/src/ratelimit.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use std::time::Instant;
|
||||
|
||||
/// A token bucket rate limiter operating on byte granularity.
|
||||
pub struct TokenBucket {
|
||||
/// Tokens (bytes) added per second.
|
||||
rate: f64,
|
||||
/// Maximum burst capacity in bytes.
|
||||
burst: f64,
|
||||
/// Currently available tokens.
|
||||
tokens: f64,
|
||||
/// Last time tokens were refilled.
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
/// Create a new token bucket.
|
||||
///
|
||||
/// - `rate_bytes_per_sec`: sustained rate in bytes/second
|
||||
/// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
|
||||
pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
|
||||
let burst = burst_bytes as f64;
|
||||
Self {
|
||||
rate: rate_bytes_per_sec as f64,
|
||||
burst,
|
||||
tokens: burst, // start full
|
||||
last_refill: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded.
|
||||
pub fn try_consume(&mut self, bytes: usize) -> bool {
|
||||
self.refill();
|
||||
let needed = bytes as f64;
|
||||
if needed <= self.tokens {
|
||||
self.tokens -= needed;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Update rate and burst limits dynamically (for live IPC reconfiguration).
|
||||
pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
|
||||
self.rate = rate_bytes_per_sec as f64;
|
||||
self.burst = burst_bytes as f64;
|
||||
// Cap current tokens at new burst
|
||||
if self.tokens > self.burst {
|
||||
self.tokens = self.burst;
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self) {
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
|
||||
self.last_refill = now;
|
||||
self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn allows_traffic_under_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
|
||||
// Should allow up to burst size immediately
|
||||
assert!(tb.try_consume(1_500_000));
|
||||
assert!(tb.try_consume(400_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blocks_traffic_over_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
|
||||
// Consume entire burst
|
||||
assert!(tb.try_consume(1_000_000));
|
||||
// Next consume should fail (no time to refill)
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_consume_always_succeeds() {
|
||||
let mut tb = TokenBucket::new(0, 0);
|
||||
assert!(tb.try_consume(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refills_over_time() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
|
||||
// Drain completely
|
||||
assert!(tb.try_consume(1_000_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
|
||||
// Wait 100ms — should refill ~100KB
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_limits_caps_tokens() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
|
||||
// Tokens start at burst (2MB)
|
||||
tb.update_limits(500_000, 500_000);
|
||||
// Tokens should be capped to new burst (500KB)
|
||||
assert!(tb.try_consume(500_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_limits_changes_rate() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
|
||||
assert!(tb.try_consume(1_000_000)); // drain
|
||||
|
||||
// Change to higher rate
|
||||
tb.update_limits(10_000_000, 10_000_000);
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
// At 10MB/s, 50ms should refill ~500KB
|
||||
assert!(tb.try_consume(200_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_rate_blocks_after_burst() {
|
||||
let mut tb = TokenBucket::new(0, 100);
|
||||
assert!(tb.try_consume(100));
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
// Zero rate means no refill
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokens_do_not_exceed_burst() {
|
||||
// Use a low rate so refill between consecutive calls is negligible
|
||||
let mut tb = TokenBucket::new(100, 1_000);
|
||||
// Wait to accumulate — but should cap at burst
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
assert!(tb.try_consume(1_000));
|
||||
// At 100 bytes/sec, the few μs between calls add ~0 tokens
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,27 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
|
||||
use crate::acl;
|
||||
use crate::client_registry::{ClientEntry, ClientRegistry};
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::mtu::{MtuConfig, TunnelOverhead};
|
||||
use crate::network::IpPool;
|
||||
use crate::ratelimit::TokenBucket;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
|
||||
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
|
||||
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
|
||||
|
||||
/// Server configuration (matches TS IVpnServerConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -29,6 +37,18 @@ pub struct ServerConfig {
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
pub enable_nat: Option<bool>,
|
||||
/// Default rate limit for new clients (bytes/sec). None = unlimited.
|
||||
pub default_rate_limit_bytes_per_sec: Option<u64>,
|
||||
/// Default burst size for new clients (bytes). None = unlimited.
|
||||
pub default_burst_bytes: Option<u64>,
|
||||
/// Transport mode: "websocket" (default), "quic", or "both".
|
||||
pub transport_mode: Option<String>,
|
||||
/// QUIC listen address (host:port). Defaults to listen_addr.
|
||||
pub quic_listen_addr: Option<String>,
|
||||
/// QUIC idle timeout in seconds (default: 30).
|
||||
pub quic_idle_timeout_secs: Option<u64>,
|
||||
/// Pre-registered clients for IK authentication.
|
||||
pub clients: Option<Vec<ClientEntry>>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -40,6 +60,16 @@ pub struct ClientInfo {
|
||||
pub connected_since: String,
|
||||
pub bytes_sent: u64,
|
||||
pub bytes_received: u64,
|
||||
pub packets_dropped: u64,
|
||||
pub bytes_dropped: u64,
|
||||
pub last_keepalive_at: Option<String>,
|
||||
pub keepalives_received: u64,
|
||||
pub rate_limit_bytes_per_sec: Option<u64>,
|
||||
pub burst_bytes: Option<u64>,
|
||||
/// Client's authenticated Noise IK public key (base64).
|
||||
pub authenticated_key: String,
|
||||
/// Registered client ID from the client registry.
|
||||
pub registered_client_id: String,
|
||||
}
|
||||
|
||||
/// Server statistics.
|
||||
@@ -63,7 +93,10 @@ pub struct ServerState {
|
||||
pub ip_pool: Mutex<IpPool>,
|
||||
pub clients: RwLock<HashMap<String, ClientInfo>>,
|
||||
pub stats: RwLock<ServerStatistics>,
|
||||
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
|
||||
pub mtu_config: MtuConfig,
|
||||
pub started_at: std::time::Instant,
|
||||
pub client_registry: RwLock<ClientRegistry>,
|
||||
}
|
||||
|
||||
/// The VPN server.
|
||||
@@ -98,26 +131,84 @@ impl VpnServer {
|
||||
}
|
||||
}
|
||||
|
||||
let link_mtu = config.mtu.unwrap_or(1420);
|
||||
// Compute effective MTU from overhead
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
|
||||
|
||||
// Build client registry from config
|
||||
let registry = ClientRegistry::from_entries(
|
||||
config.clients.clone().unwrap_or_default()
|
||||
)?;
|
||||
info!("Client registry loaded with {} entries", registry.len());
|
||||
|
||||
let state = Arc::new(ServerState {
|
||||
config: config.clone(),
|
||||
ip_pool: Mutex::new(ip_pool),
|
||||
clients: RwLock::new(HashMap::new()),
|
||||
stats: RwLock::new(ServerStatistics::default()),
|
||||
rate_limiters: Mutex::new(HashMap::new()),
|
||||
mtu_config,
|
||||
started_at: std::time::Instant::now(),
|
||||
client_registry: RwLock::new(registry),
|
||||
});
|
||||
|
||||
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
||||
self.state = Some(state.clone());
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
|
||||
let listen_addr = config.listen_addr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
info!("VPN server started");
|
||||
match transport_mode {
|
||||
"quic" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
"both" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
let state2 = state.clone();
|
||||
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
|
||||
// Store second shutdown sender so both listeners stop
|
||||
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
|
||||
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
|
||||
self.shutdown_tx = Some(combined_tx);
|
||||
|
||||
// Forward combined shutdown to both listeners
|
||||
tokio::spawn(async move {
|
||||
combined_rx.recv().await;
|
||||
let _ = shutdown_tx_orig.send(()).await;
|
||||
let _ = shutdown_tx2.send(()).await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("WebSocket listener error: {}", e);
|
||||
}
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
// "websocket" (default)
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!("VPN server started (transport: {})", transport_mode);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -166,14 +257,314 @@ impl VpnServer {
|
||||
if let Some(client) = clients.remove(client_id) {
|
||||
let ip: Ipv4Addr = client.assigned_ip.parse()?;
|
||||
state.ip_pool.lock().await.release(&ip);
|
||||
state.rate_limiters.lock().await.remove(client_id);
|
||||
info!("Client {} disconnected", client_id);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a rate limit for a specific client.
|
||||
pub async fn set_client_rate_limit(
|
||||
&self,
|
||||
client_id: &str,
|
||||
rate_bytes_per_sec: u64,
|
||||
burst_bytes: u64,
|
||||
) -> Result<()> {
|
||||
if let Some(ref state) = self.state {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
if let Some(limiter) = limiters.get_mut(client_id) {
|
||||
limiter.update_limits(rate_bytes_per_sec, burst_bytes);
|
||||
} else {
|
||||
limiters.insert(
|
||||
client_id.to_string(),
|
||||
TokenBucket::new(rate_bytes_per_sec, burst_bytes),
|
||||
);
|
||||
}
|
||||
// Update client info
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(client_id) {
|
||||
info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec);
|
||||
info.burst_bytes = Some(burst_bytes);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove rate limit for a specific client (unlimited).
|
||||
pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> {
|
||||
if let Some(ref state) = self.state {
|
||||
state.rate_limiters.lock().await.remove(client_id);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(client_id) {
|
||||
info.rate_limit_bytes_per_sec = None;
|
||||
info.burst_bytes = None;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ───────────────────────────────────
|
||||
|
||||
/// Create a new client entry. Generates keypairs and assigns an IP.
|
||||
/// Returns a JSON value with the full config bundle including secrets.
|
||||
pub async fn create_client(&self, partial: serde_json::Value) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
|
||||
let client_id = partial.get("clientId")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("clientId is required"))?
|
||||
.to_string();
|
||||
|
||||
// Generate Noise IK keypair for the client
|
||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||
|
||||
// Generate WireGuard keypair for the client
|
||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||
|
||||
// Allocate a VPN IP
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
// Build entry from partial + generated values
|
||||
let entry = ClientEntry {
|
||||
client_id: client_id.clone(),
|
||||
public_key: noise_pub.clone(),
|
||||
wg_public_key: Some(wg_pub.clone()),
|
||||
security: serde_json::from_value(
|
||||
partial.get("security").cloned().unwrap_or(serde_json::Value::Null)
|
||||
).ok(),
|
||||
priority: partial.get("priority").and_then(|v| v.as_u64()).map(|v| v as u32),
|
||||
enabled: partial.get("enabled").and_then(|v| v.as_bool()).or(Some(true)),
|
||||
tags: partial.get("tags").and_then(|v| {
|
||||
v.as_array().map(|a| a.iter().filter_map(|s| s.as_str().map(String::from)).collect())
|
||||
}),
|
||||
description: partial.get("description").and_then(|v| v.as_str()).map(String::from),
|
||||
expires_at: partial.get("expiresAt").and_then(|v| v.as_str()).map(String::from),
|
||||
assigned_ip: Some(assigned_ip.to_string()),
|
||||
};
|
||||
|
||||
// Add to registry
|
||||
state.client_registry.write().await.add(entry.clone())?;
|
||||
|
||||
// Build SmartVPN client config
|
||||
let smartvpn_config = serde_json::json!({
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPrivateKey": noise_priv,
|
||||
"clientPublicKey": noise_pub,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
});
|
||||
|
||||
// Build WireGuard config string
|
||||
let wg_config = format!(
|
||||
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
wg_priv,
|
||||
assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
|
||||
let entry_json = serde_json::to_value(&entry)?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"entry": entry_json,
|
||||
"smartvpnConfig": smartvpn_config,
|
||||
"wireguardConfig": wg_config,
|
||||
"secrets": {
|
||||
"noisePrivateKey": noise_priv,
|
||||
"wgPrivateKey": wg_priv,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Remove a registered client from the registry (and disconnect if connected).
|
||||
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let entry = state.client_registry.write().await.remove(client_id)?;
|
||||
// Release the IP if assigned
|
||||
if let Some(ref ip_str) = entry.assigned_ip {
|
||||
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
|
||||
state.ip_pool.lock().await.release(&ip);
|
||||
}
|
||||
}
|
||||
// Disconnect if currently connected
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a registered client by ID.
|
||||
pub async fn get_registered_client(&self, client_id: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let registry = state.client_registry.read().await;
|
||||
let entry = registry.get_by_id(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
Ok(serde_json::to_value(entry)?)
|
||||
}
|
||||
|
||||
/// List all registered clients.
|
||||
pub async fn list_registered_clients(&self) -> Vec<ClientEntry> {
|
||||
if let Some(ref state) = self.state {
|
||||
state.client_registry.read().await.list().into_iter().cloned().collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a registered client's fields.
|
||||
pub async fn update_registered_client(&self, client_id: &str, update: serde_json::Value) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
if let Some(security) = update.get("security") {
|
||||
entry.security = serde_json::from_value(security.clone()).ok();
|
||||
}
|
||||
if let Some(priority) = update.get("priority").and_then(|v| v.as_u64()) {
|
||||
entry.priority = Some(priority as u32);
|
||||
}
|
||||
if let Some(enabled) = update.get("enabled").and_then(|v| v.as_bool()) {
|
||||
entry.enabled = Some(enabled);
|
||||
}
|
||||
if let Some(tags) = update.get("tags").and_then(|v| v.as_array()) {
|
||||
entry.tags = Some(tags.iter().filter_map(|s| s.as_str().map(String::from)).collect());
|
||||
}
|
||||
if let Some(desc) = update.get("description").and_then(|v| v.as_str()) {
|
||||
entry.description = Some(desc.to_string());
|
||||
}
|
||||
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
|
||||
entry.expires_at = Some(expires.to_string());
|
||||
}
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enable a registered client.
|
||||
pub async fn enable_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
entry.enabled = Some(true);
|
||||
})
|
||||
}
|
||||
|
||||
/// Disable a registered client (also disconnects if connected).
|
||||
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
entry.enabled = Some(false);
|
||||
})?;
|
||||
// Disconnect if currently connected
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
pub async fn rotate_client_key(&self, client_id: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
|
||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||
|
||||
state.client_registry.write().await.rotate_key(
|
||||
client_id,
|
||||
noise_pub.clone(),
|
||||
Some(wg_pub.clone()),
|
||||
)?;
|
||||
|
||||
// Disconnect existing connection (old key is no longer valid)
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
|
||||
// Get updated entry for the config bundle
|
||||
let entry_json = self.get_registered_client(client_id).await?;
|
||||
let assigned_ip = entry_json.get("assignedIp")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("0.0.0.0");
|
||||
|
||||
let smartvpn_config = serde_json::json!({
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPrivateKey": noise_priv,
|
||||
"clientPublicKey": noise_pub,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
});
|
||||
|
||||
let wg_config = format!(
|
||||
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
wg_priv, assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"entry": entry_json,
|
||||
"smartvpnConfig": smartvpn_config,
|
||||
"wireguardConfig": wg_config,
|
||||
"secrets": {
|
||||
"noisePrivateKey": noise_priv,
|
||||
"wgPrivateKey": wg_priv,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Export a client config (without secrets) in the specified format.
|
||||
pub async fn export_client_config(&self, client_id: &str, format: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let registry = state.client_registry.read().await;
|
||||
let entry = registry.get_by_id(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
|
||||
match format {
|
||||
"smartvpn" => {
|
||||
Ok(serde_json::json!({
|
||||
"config": {
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPublicKey": entry.public_key,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
}
|
||||
}))
|
||||
}
|
||||
"wireguard" => {
|
||||
let assigned_ip = entry.assigned_ip.as_deref().unwrap_or("0.0.0.0");
|
||||
let config = format!(
|
||||
"[Interface]\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
Ok(serde_json::json!({ "config": config }))
|
||||
}
|
||||
_ => anyhow::bail!("Unknown format: {}", format),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_listener(
|
||||
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
|
||||
/// to the transport-agnostic `handle_client_connection`.
|
||||
async fn run_ws_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
shutdown_rx: &mut mpsc::Receiver<()>,
|
||||
@@ -189,8 +580,20 @@ async fn run_listener(
|
||||
info!("New connection from {}", addr);
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_client_connection(state, stream).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
match transport::accept_connection(stream).await {
|
||||
Ok(ws) => {
|
||||
let (sink, stream) = transport_trait::split_ws(ws);
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("WebSocket upgrade failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -209,29 +612,108 @@ async fn run_listener(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic
|
||||
/// `handle_client_connection`.
|
||||
async fn run_quic_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
idle_timeout_secs: u64,
|
||||
shutdown_rx: &mut mpsc::Receiver<()>,
|
||||
) -> Result<()> {
|
||||
// Generate or use configured TLS certificate for QUIC
|
||||
let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) =
|
||||
(&state.config.tls_cert, &state.config.tls_key)
|
||||
{
|
||||
// Parse PEM certificates
|
||||
let certs: Vec<rustls_pki_types::CertificateDer<'static>> =
|
||||
rustls_pemfile::certs(&mut cert_pem.as_bytes())
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
|
||||
.ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
|
||||
(certs, key)
|
||||
} else {
|
||||
// Generate self-signed certificate
|
||||
let (certs, key) = quic_transport::generate_self_signed_cert()?;
|
||||
info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0]));
|
||||
(certs, key)
|
||||
};
|
||||
|
||||
let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig {
|
||||
listen_addr,
|
||||
cert_chain,
|
||||
private_key,
|
||||
idle_timeout_secs,
|
||||
})?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
incoming = endpoint.accept() => {
|
||||
match incoming {
|
||||
Some(incoming) => {
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
match incoming.await {
|
||||
Ok(conn) => {
|
||||
let remote = conn.remote_address();
|
||||
info!("New QUIC connection from {}", remote);
|
||||
match quic_transport::accept_quic_connection(conn).await {
|
||||
Ok((sink, stream)) => {
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
).await {
|
||||
warn!("QUIC client error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("QUIC stream accept failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("QUIC handshake failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
None => {
|
||||
info!("QUIC endpoint closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("QUIC shutdown signal received");
|
||||
endpoint.close(0u32.into(), b"shutdown");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transport-agnostic client handler. Performs the Noise IK handshake, authenticates
|
||||
/// the client against the registry, and runs the main packet forwarding loop.
|
||||
async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
stream: tokio::net::TcpStream,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
) -> Result<()> {
|
||||
let ws = transport::accept_connection(stream).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
|
||||
let client_id = uuid_v4();
|
||||
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
let server_private_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&state.config.private_key,
|
||||
)?;
|
||||
|
||||
// Noise IK handshake (server side = responder)
|
||||
let mut responder = crypto::create_responder(&server_private_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// Receive handshake init
|
||||
let init_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected handshake init message"),
|
||||
// Receive handshake init (-> e, es, s, ss)
|
||||
let init_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before handshake"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&init_msg[..]);
|
||||
@@ -243,6 +725,47 @@ async fn handle_client_connection(
|
||||
}
|
||||
|
||||
responder.read_message(&frame.payload, &mut buf)?;
|
||||
|
||||
// Extract client's static public key BEFORE entering transport mode
|
||||
let client_pub_key_bytes = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake: no client static key received"))?
|
||||
.to_vec();
|
||||
let client_pub_key_b64 = base64::Engine::encode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&client_pub_key_bytes,
|
||||
);
|
||||
|
||||
// Verify client against registry
|
||||
let (registered_client_id, client_security) = {
|
||||
let registry = state.client_registry.read().await;
|
||||
if !registry.is_authorized(&client_pub_key_b64) {
|
||||
warn!("Rejecting unauthorized client with key {}", &client_pub_key_b64[..8]);
|
||||
// Send handshake response but then disconnect
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let response_frame = Frame {
|
||||
packet_type: PacketType::HandshakeResp,
|
||||
payload: buf[..len].to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// Send disconnect frame
|
||||
let disconnect_frame = Frame {
|
||||
packet_type: PacketType::Disconnect,
|
||||
payload: Vec::new(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
anyhow::bail!("Client not authorized");
|
||||
}
|
||||
let entry = registry.get_by_key(&client_pub_key_b64).unwrap();
|
||||
(entry.client_id.clone(), entry.security.clone())
|
||||
};
|
||||
|
||||
// Complete handshake (<- e, ee, se)
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let response_payload = buf[..len].to_vec();
|
||||
|
||||
@@ -252,30 +775,65 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
let mut noise_transport = responder.into_transport_mode()?;
|
||||
|
||||
// Register client
|
||||
// Use the registered client ID as the connection ID
|
||||
let client_id = registered_client_id.clone();
|
||||
|
||||
// Allocate IP
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
// Determine rate limits: per-client security overrides server defaults
|
||||
let (rate_limit, burst) = if let Some(ref sec) = client_security {
|
||||
if let Some(ref rl) = sec.rate_limit {
|
||||
(Some(rl.bytes_per_sec), Some(rl.burst_bytes))
|
||||
} else {
|
||||
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
|
||||
}
|
||||
} else {
|
||||
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
|
||||
};
|
||||
|
||||
// Register connected client
|
||||
let client_info = ClientInfo {
|
||||
client_id: client_id.clone(),
|
||||
assigned_ip: assigned_ip.to_string(),
|
||||
connected_since: timestamp_now(),
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
packets_dropped: 0,
|
||||
bytes_dropped: 0,
|
||||
last_keepalive_at: None,
|
||||
keepalives_received: 0,
|
||||
rate_limit_bytes_per_sec: rate_limit,
|
||||
burst_bytes: burst,
|
||||
authenticated_key: client_pub_key_b64.clone(),
|
||||
registered_client_id: registered_client_id.clone(),
|
||||
};
|
||||
state.clients.write().await.insert(client_id.clone(), client_info);
|
||||
|
||||
// Set up rate limiter
|
||||
if let (Some(rate), Some(burst)) = (rate_limit, burst) {
|
||||
state
|
||||
.rate_limiters
|
||||
.lock()
|
||||
.await
|
||||
.insert(client_id.clone(), TokenBucket::new(rate, burst));
|
||||
}
|
||||
|
||||
{
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.total_connections += 1;
|
||||
}
|
||||
|
||||
// Send assigned IP info (encrypted)
|
||||
// Send assigned IP info (encrypted), include effective MTU
|
||||
let ip_info = serde_json::json!({
|
||||
"assignedIp": assigned_ip.to_string(),
|
||||
"gateway": state.ip_pool.lock().await.gateway_addr().to_string(),
|
||||
"mtu": state.config.mtu.unwrap_or(1420),
|
||||
"effectiveMtu": state.mtu_config.effective_mtu,
|
||||
});
|
||||
let ip_info_bytes = serde_json::to_vec(&ip_info)?;
|
||||
let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?;
|
||||
@@ -285,70 +843,130 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
info!("Client {} connected with IP {}", client_id, assigned_ip);
|
||||
info!("Client {} ({}) connected with IP {}", registered_client_id, &client_pub_key_b64[..8], assigned_ip);
|
||||
|
||||
// Main packet loop with dead-peer detection
|
||||
let mut last_activity = tokio::time::Instant::now();
|
||||
|
||||
// Main packet loop
|
||||
loop {
|
||||
match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.bytes_received += len as u64;
|
||||
stats.packets_received += 1;
|
||||
tokio::select! {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Ok(Some(data)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
// ACL check on decrypted packet
|
||||
if let Some(ref sec) = client_security {
|
||||
if len >= 20 {
|
||||
// Extract src/dst from IPv4 header
|
||||
let src = Ipv4Addr::new(buf[12], buf[13], buf[14], buf[15]);
|
||||
let dst = Ipv4Addr::new(buf[16], buf[17], buf[18], buf[19]);
|
||||
let acl_result = acl::check_acl(sec, src, dst);
|
||||
if acl_result != acl::AclResult::Allow {
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.packets_dropped += 1;
|
||||
info.bytes_dropped += len as u64;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting check
|
||||
let allowed = {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
if let Some(limiter) = limiters.get_mut(&client_id) {
|
||||
limiter.try_consume(len)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
if !allowed {
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.packets_dropped += 1;
|
||||
info.bytes_dropped += len as u64;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.bytes_received += len as u64;
|
||||
stats.packets_received += 1;
|
||||
|
||||
// Update per-client stats
|
||||
drop(stats);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.bytes_received += len as u64;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Decrypt error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Decrypt error from {}: {}", client_id, e);
|
||||
PacketType::Keepalive => {
|
||||
// Echo the keepalive payload back in the ACK
|
||||
let ack_frame = Frame {
|
||||
packet_type: PacketType::KeepaliveAck,
|
||||
payload: frame.payload.clone(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
stats.keepalives_sent += 1;
|
||||
|
||||
// Update per-client keepalive tracking
|
||||
drop(stats);
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.last_keepalive_at = Some(timestamp_now());
|
||||
info.keepalives_received += 1;
|
||||
}
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Client {} sent disconnect", client_id);
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
warn!("Incomplete frame from {}", client_id);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Frame decode error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
PacketType::Keepalive => {
|
||||
let ack_frame = Frame {
|
||||
packet_type: PacketType::KeepaliveAck,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
stats.keepalives_sent += 1;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Client {} sent disconnect", client_id);
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
|
||||
}
|
||||
},
|
||||
}
|
||||
Ok(None) => {
|
||||
warn!("Incomplete frame from {}", client_id);
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Frame decode error from {}: {}", client_id, e);
|
||||
warn!("Transport error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
ws_sink.send(Message::Pong(data)).await?;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
warn!("WebSocket error from {}: {}", client_id, e);
|
||||
_ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => {
|
||||
warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -357,25 +975,12 @@ async fn handle_client_connection(
|
||||
// Cleanup
|
||||
state.clients.write().await.remove(&client_id);
|
||||
state.ip_pool.lock().await.release(&assigned_ip);
|
||||
state.rate_limiters.lock().await.remove(&client_id);
|
||||
info!("Client {} disconnected, released IP {}", client_id, assigned_ip);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn uuid_v4() -> String {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let bytes: [u8; 16] = rng.gen();
|
||||
format!(
|
||||
"{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
|
||||
bytes[0], bytes[1], bytes[2], bytes[3],
|
||||
bytes[4], bytes[5],
|
||||
bytes[6], bytes[7],
|
||||
bytes[8], bytes[9],
|
||||
bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
|
||||
)
|
||||
}
|
||||
|
||||
fn timestamp_now() -> String {
|
||||
use std::time::SystemTime;
|
||||
let duration = SystemTime::now()
|
||||
|
||||
317
rust/src/telemetry.rs
Normal file
317
rust/src/telemetry.rs
Normal file
@@ -0,0 +1,317 @@
|
||||
use serde::Serialize;
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// A single RTT sample.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RttSample {
|
||||
_rtt: Duration,
|
||||
_timestamp: Instant,
|
||||
was_timeout: bool,
|
||||
}
|
||||
|
||||
/// Snapshot of connection quality metrics.
|
||||
#[derive(Debug, Clone, Serialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConnectionQuality {
|
||||
/// Smoothed RTT in milliseconds (EMA, RFC 6298 style).
|
||||
pub srtt_ms: f64,
|
||||
/// Jitter in milliseconds (mean deviation of RTT).
|
||||
pub jitter_ms: f64,
|
||||
/// Minimum RTT observed in the sample window.
|
||||
pub min_rtt_ms: f64,
|
||||
/// Maximum RTT observed in the sample window.
|
||||
pub max_rtt_ms: f64,
|
||||
/// Packet loss ratio over the sample window (0.0 - 1.0).
|
||||
pub loss_ratio: f64,
|
||||
/// Number of consecutive keepalive timeouts (0 if last succeeded).
|
||||
pub consecutive_timeouts: u32,
|
||||
/// Total keepalives sent.
|
||||
pub keepalives_sent: u64,
|
||||
/// Total keepalive ACKs received.
|
||||
pub keepalives_acked: u64,
|
||||
}
|
||||
|
||||
/// Tracks connection quality from keepalive round-trips.
|
||||
pub struct RttTracker {
|
||||
/// Maximum number of samples to keep in the window.
|
||||
max_samples: usize,
|
||||
/// Recent RTT samples (including timeout markers).
|
||||
samples: VecDeque<RttSample>,
|
||||
/// When the last keepalive was sent (for computing RTT on ACK).
|
||||
pending_ping_sent_at: Option<Instant>,
|
||||
/// Number of consecutive keepalive timeouts.
|
||||
pub consecutive_timeouts: u32,
|
||||
/// Smoothed RTT (EMA).
|
||||
srtt: Option<f64>,
|
||||
/// Jitter (mean deviation).
|
||||
jitter: f64,
|
||||
/// Minimum RTT observed.
|
||||
min_rtt: f64,
|
||||
/// Maximum RTT observed.
|
||||
max_rtt: f64,
|
||||
/// Total keepalives sent.
|
||||
keepalives_sent: u64,
|
||||
/// Total keepalive ACKs received.
|
||||
keepalives_acked: u64,
|
||||
/// Previous RTT sample for jitter calculation.
|
||||
last_rtt_ms: Option<f64>,
|
||||
}
|
||||
|
||||
impl RttTracker {
|
||||
/// Create a new tracker with the given window size.
|
||||
pub fn new(max_samples: usize) -> Self {
|
||||
Self {
|
||||
max_samples,
|
||||
samples: VecDeque::with_capacity(max_samples),
|
||||
pending_ping_sent_at: None,
|
||||
consecutive_timeouts: 0,
|
||||
srtt: None,
|
||||
jitter: 0.0,
|
||||
min_rtt: f64::MAX,
|
||||
max_rtt: 0.0,
|
||||
keepalives_sent: 0,
|
||||
keepalives_acked: 0,
|
||||
last_rtt_ms: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record that a keepalive was sent.
|
||||
/// Returns a millisecond timestamp (since UNIX epoch) to embed in the keepalive payload.
|
||||
pub fn mark_ping_sent(&mut self) -> u64 {
|
||||
self.pending_ping_sent_at = Some(Instant::now());
|
||||
self.keepalives_sent += 1;
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
/// Record that a keepalive ACK was received with the echoed timestamp.
|
||||
/// Returns the computed RTT if a pending ping was recorded.
|
||||
pub fn record_ack(&mut self, _echoed_timestamp_ms: u64) -> Option<Duration> {
|
||||
let sent_at = self.pending_ping_sent_at.take()?;
|
||||
let rtt = sent_at.elapsed();
|
||||
let rtt_ms = rtt.as_secs_f64() * 1000.0;
|
||||
|
||||
self.keepalives_acked += 1;
|
||||
self.consecutive_timeouts = 0;
|
||||
|
||||
// Update SRTT (RFC 6298: alpha = 1/8)
|
||||
match self.srtt {
|
||||
None => {
|
||||
self.srtt = Some(rtt_ms);
|
||||
self.jitter = rtt_ms / 2.0;
|
||||
}
|
||||
Some(prev_srtt) => {
|
||||
// RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| (beta = 1/4)
|
||||
self.jitter = 0.75 * self.jitter + 0.25 * (prev_srtt - rtt_ms).abs();
|
||||
// SRTT = (1 - alpha) * SRTT + alpha * R (alpha = 1/8)
|
||||
self.srtt = Some(0.875 * prev_srtt + 0.125 * rtt_ms);
|
||||
}
|
||||
}
|
||||
|
||||
// Update min/max
|
||||
if rtt_ms < self.min_rtt {
|
||||
self.min_rtt = rtt_ms;
|
||||
}
|
||||
if rtt_ms > self.max_rtt {
|
||||
self.max_rtt = rtt_ms;
|
||||
}
|
||||
|
||||
self.last_rtt_ms = Some(rtt_ms);
|
||||
|
||||
// Push sample into window
|
||||
if self.samples.len() >= self.max_samples {
|
||||
self.samples.pop_front();
|
||||
}
|
||||
self.samples.push_back(RttSample {
|
||||
_rtt: rtt,
|
||||
_timestamp: Instant::now(),
|
||||
was_timeout: false,
|
||||
});
|
||||
|
||||
Some(rtt)
|
||||
}
|
||||
|
||||
/// Record that a keepalive timed out (no ACK received).
|
||||
pub fn record_timeout(&mut self) {
|
||||
self.consecutive_timeouts += 1;
|
||||
self.pending_ping_sent_at = None;
|
||||
|
||||
if self.samples.len() >= self.max_samples {
|
||||
self.samples.pop_front();
|
||||
}
|
||||
self.samples.push_back(RttSample {
|
||||
_rtt: Duration::ZERO,
|
||||
_timestamp: Instant::now(),
|
||||
was_timeout: true,
|
||||
});
|
||||
}
|
||||
|
||||
/// Get a snapshot of the current connection quality.
|
||||
pub fn snapshot(&self) -> ConnectionQuality {
|
||||
let loss_ratio = if self.samples.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
let timeouts = self.samples.iter().filter(|s| s.was_timeout).count();
|
||||
timeouts as f64 / self.samples.len() as f64
|
||||
};
|
||||
|
||||
ConnectionQuality {
|
||||
srtt_ms: self.srtt.unwrap_or(0.0),
|
||||
jitter_ms: self.jitter,
|
||||
min_rtt_ms: if self.min_rtt == f64::MAX { 0.0 } else { self.min_rtt },
|
||||
max_rtt_ms: self.max_rtt,
|
||||
loss_ratio,
|
||||
consecutive_timeouts: self.consecutive_timeouts,
|
||||
keepalives_sent: self.keepalives_sent,
|
||||
keepalives_acked: self.keepalives_acked,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_tracker_has_zero_quality() {
|
||||
let tracker = RttTracker::new(30);
|
||||
let q = tracker.snapshot();
|
||||
assert_eq!(q.srtt_ms, 0.0);
|
||||
assert_eq!(q.jitter_ms, 0.0);
|
||||
assert_eq!(q.loss_ratio, 0.0);
|
||||
assert_eq!(q.consecutive_timeouts, 0);
|
||||
assert_eq!(q.keepalives_sent, 0);
|
||||
assert_eq!(q.keepalives_acked, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mark_ping_returns_timestamp() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
let ts = tracker.mark_ping_sent();
|
||||
// Should be a reasonable epoch-ms value (after 2020)
|
||||
assert!(ts > 1_577_836_800_000);
|
||||
assert_eq!(tracker.keepalives_sent, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_ack_computes_rtt() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
let rtt = tracker.record_ack(ts);
|
||||
assert!(rtt.is_some());
|
||||
let rtt = rtt.unwrap();
|
||||
assert!(rtt.as_millis() >= 4); // at least ~5ms minus scheduling jitter
|
||||
assert_eq!(tracker.keepalives_acked, 1);
|
||||
assert_eq!(tracker.consecutive_timeouts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_ack_without_pending_returns_none() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
assert!(tracker.record_ack(12345).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn srtt_converges() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
// Simulate several ping/ack cycles with ~10ms RTT
|
||||
for _ in 0..10 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
|
||||
let q = tracker.snapshot();
|
||||
// SRTT should be roughly 10ms (allowing for scheduling variance)
|
||||
assert!(q.srtt_ms > 5.0, "SRTT too low: {}", q.srtt_ms);
|
||||
assert!(q.srtt_ms < 50.0, "SRTT too high: {}", q.srtt_ms);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timeout_increments_counter_and_loss() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 1);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 2);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert_eq!(q.loss_ratio, 1.0); // 2 timeouts out of 2 samples
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ack_resets_consecutive_timeouts() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 1);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
assert_eq!(tracker.consecutive_timeouts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loss_ratio_over_mixed_window() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
// 3 successful, 1 timeout, 1 successful = 1/5 = 0.2 loss
|
||||
for _ in 0..3 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert!((q.loss_ratio - 0.2).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn window_evicts_old_samples() {
|
||||
let mut tracker = RttTracker::new(5);
|
||||
|
||||
// Fill window with 5 timeouts
|
||||
for _ in 0..5 {
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
}
|
||||
assert_eq!(tracker.snapshot().loss_ratio, 1.0);
|
||||
|
||||
// Add 5 successes — should evict all timeouts
|
||||
for _ in 0..5 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
assert_eq!(tracker.snapshot().loss_ratio, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_max_rtt_tracked() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(15));
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert!(q.min_rtt_ms < q.max_rtt_ms);
|
||||
assert!(q.min_rtt_ms > 0.0);
|
||||
}
|
||||
}
|
||||
116
rust/src/transport_trait.rs
Normal file
116
rust/src/transport_trait.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
use crate::transport::WsStream;
|
||||
|
||||
// ============================================================================
|
||||
// Transport trait abstraction
|
||||
// ============================================================================
|
||||
|
||||
/// Outbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportSink: Send + 'static {
|
||||
/// Send a framed binary message on the reliable channel.
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Send a datagram (unreliable, best-effort).
|
||||
/// Falls back to reliable if the transport does not support datagrams.
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Gracefully close the transport.
|
||||
async fn close(&mut self) -> Result<()>;
|
||||
}
|
||||
|
||||
/// Inbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportStream: Send + 'static {
|
||||
/// Receive the next reliable binary message. Returns `None` on close.
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Receive the next datagram. Returns `None` if datagrams are unsupported
|
||||
/// or the connection is closed.
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Whether this transport supports unreliable datagrams.
|
||||
fn supports_datagrams(&self) -> bool;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket implementation
|
||||
// ============================================================================
|
||||
|
||||
/// WebSocket transport sink (wraps the write half of a split WsStream).
|
||||
pub struct WsTransportSink {
|
||||
inner: futures_util::stream::SplitSink<WsStream, Message>,
|
||||
}
|
||||
|
||||
impl WsTransportSink {
|
||||
pub fn new(inner: futures_util::stream::SplitSink<WsStream, Message>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for WsTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
self.inner.send(Message::Binary(data.into())).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// WebSocket has no datagram support — fall back to reliable.
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.inner.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket transport stream (wraps the read half of a split WsStream).
|
||||
pub struct WsTransportStream {
|
||||
inner: futures_util::stream::SplitStream<WsStream>,
|
||||
}
|
||||
|
||||
impl WsTransportStream {
|
||||
pub fn new(inner: futures_util::stream::SplitStream<WsStream>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for WsTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
loop {
|
||||
match self.inner.next().await {
|
||||
Some(Ok(Message::Binary(data))) => return Ok(Some(data.to_vec())),
|
||||
Some(Ok(Message::Close(_))) | None => return Ok(None),
|
||||
Some(Ok(Message::Ping(_))) => {
|
||||
// Ping handling is done at the tungstenite layer automatically
|
||||
// when the sink side is alive. Just skip here.
|
||||
continue;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => return Err(anyhow::anyhow!("WebSocket error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// WebSocket does not support datagrams.
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a WebSocket stream into transport sink and stream halves.
|
||||
pub fn split_ws(ws: WsStream) -> (WsTransportSink, WsTransportStream) {
|
||||
let (sink, stream) = ws.split();
|
||||
(WsTransportSink::new(sink), WsTransportStream::new(stream))
|
||||
}
|
||||
@@ -64,6 +64,22 @@ pub async fn add_route(subnet: &str, device_name: &str) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Action to take after checking a packet against the MTU.
|
||||
pub enum TunMtuAction {
|
||||
/// Packet is within MTU limits, forward it.
|
||||
Forward,
|
||||
/// Packet is oversized; the Vec contains the ICMP too-big message to write back into TUN.
|
||||
IcmpTooBig(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Check a TUN packet against the MTU and return the appropriate action.
|
||||
pub fn check_tun_mtu(packet: &[u8], mtu_config: &crate::mtu::MtuConfig) -> TunMtuAction {
|
||||
match crate::mtu::check_mtu(packet, mtu_config) {
|
||||
crate::mtu::MtuAction::Forward => TunMtuAction::Forward,
|
||||
crate::mtu::MtuAction::SendIcmpTooBig(icmp) => TunMtuAction::IcmpTooBig(icmp),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a route.
|
||||
pub async fn remove_route(subnet: &str, device_name: &str) -> Result<()> {
|
||||
let output = tokio::process::Command::new("ip")
|
||||
|
||||
1329
rust/src/wireguard.rs
Normal file
1329
rust/src/wireguard.rs
Normal file
File diff suppressed because it is too large
Load Diff
320
rust/tests/wg_e2e.rs
Normal file
320
rust/tests/wg_e2e.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! End-to-end WireGuard protocol tests over real UDP sockets.
|
||||
//!
|
||||
//! Entirely userspace — no root, no TUN devices.
|
||||
//! Two boringtun `Tunn` instances exchange real WireGuard packets
|
||||
//! over loopback UDP, validating handshake, encryption, and data flow.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::time;
|
||||
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
|
||||
use smartvpn_daemon::wireguard::generate_wg_keypair;
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
fn parse_key_pair(pub_b64: &str, priv_b64: &str) -> (PublicKey, StaticSecret) {
|
||||
let pub_bytes: [u8; 32] = BASE64.decode(pub_b64).unwrap().try_into().unwrap();
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
(PublicKey::from(pub_bytes), StaticSecret::from(priv_bytes))
|
||||
}
|
||||
|
||||
fn clone_secret(priv_b64: &str) -> StaticSecret {
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
StaticSecret::from(priv_bytes)
|
||||
}
|
||||
|
||||
fn make_ipv4_packet(src: Ipv4Addr, dst: Ipv4Addr, payload: &[u8]) -> Vec<u8> {
|
||||
let total_len = 20 + payload.len();
|
||||
let mut pkt = vec![0u8; total_len];
|
||||
pkt[0] = 0x45;
|
||||
pkt[2] = (total_len >> 8) as u8;
|
||||
pkt[3] = total_len as u8;
|
||||
pkt[9] = 0x11;
|
||||
pkt[12..16].copy_from_slice(&src.octets());
|
||||
pkt[16..20].copy_from_slice(&dst.octets());
|
||||
pkt[20..].copy_from_slice(payload);
|
||||
pkt
|
||||
}
|
||||
|
||||
/// Send any WriteToNetwork result, then drain the tunn for more packets.
|
||||
async fn send_and_drain(
|
||||
tunn: &mut Tunn,
|
||||
pkt: &[u8],
|
||||
socket: &UdpSocket,
|
||||
peer: SocketAddr,
|
||||
) {
|
||||
socket.send_to(pkt, peer).await.unwrap();
|
||||
let mut drain_buf = vec![0u8; 2048];
|
||||
loop {
|
||||
match tunn.decapsulate(None, &[], &mut drain_buf) {
|
||||
TunnResult::WriteToNetwork(p) => { socket.send_to(p, peer).await.unwrap(); }
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to receive a UDP packet and decapsulate it. Returns decrypted IP data if any.
|
||||
async fn try_recv_decap(
|
||||
tunn: &mut Tunn,
|
||||
socket: &UdpSocket,
|
||||
timeout_ms: u64,
|
||||
) -> Option<(Vec<u8>, Ipv4Addr, SocketAddr)> {
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
let (n, src_addr) = match time::timeout(
|
||||
Duration::from_millis(timeout_ms),
|
||||
socket.recv_from(&mut recv_buf),
|
||||
).await {
|
||||
Ok(Ok(r)) => r,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let result = tunn.decapsulate(Some(src_addr.ip()), &recv_buf[..n], &mut dst_buf);
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(tunn, pkt, socket, src_addr).await;
|
||||
None
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(pkt, addr) => Some((pkt.to_vec(), addr, src_addr)),
|
||||
TunnResult::WriteToTunnelV6(_, _) => None,
|
||||
TunnResult::Done => None,
|
||||
TunnResult::Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive the full WireGuard handshake between client and server over real UDP.
|
||||
async fn do_handshake(
|
||||
client_tunn: &mut Tunn,
|
||||
server_tunn: &mut Tunn,
|
||||
client_socket: &UdpSocket,
|
||||
server_socket: &UdpSocket,
|
||||
server_addr: SocketAddr,
|
||||
) {
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
// Step 1: Client initiates handshake
|
||||
match client_tunn.encapsulate(&[], &mut buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => panic!("Expected handshake init"),
|
||||
}
|
||||
|
||||
// Step 2: Server receives init → sends response
|
||||
let (n, client_from) = server_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match server_tunn.decapsulate(Some(client_from.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(server_tunn, pkt, server_socket, client_from).await;
|
||||
}
|
||||
other => panic!("Expected WriteToNetwork from server, got variant {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Step 3: Client receives response
|
||||
let (n, _) = client_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match client_tunn.decapsulate(Some(server_addr.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(client_tunn, pkt, client_socket, server_addr).await;
|
||||
}
|
||||
TunnResult::Done => {}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Step 4: Process any remaining handshake packets
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 200).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 100).await;
|
||||
|
||||
// Step 5: Timer ticks to settle
|
||||
for _ in 0..3 {
|
||||
match server_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
server_socket.send_to(pkt, client_from).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
match client_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 50).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 50).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn variant_name(r: &TunnResult) -> &'static str {
|
||||
match r {
|
||||
TunnResult::Done => "Done",
|
||||
TunnResult::Err(_) => "Err",
|
||||
TunnResult::WriteToNetwork(_) => "WriteToNetwork",
|
||||
TunnResult::WriteToTunnelV4(_, _) => "WriteToTunnelV4",
|
||||
TunnResult::WriteToTunnelV6(_, _) => "WriteToTunnelV6",
|
||||
}
|
||||
}
|
||||
|
||||
/// Encapsulate an IP packet and send it, then loop-receive on the other side until decrypted.
|
||||
async fn send_and_expect_data(
|
||||
sender_tunn: &mut Tunn,
|
||||
receiver_tunn: &mut Tunn,
|
||||
sender_socket: &UdpSocket,
|
||||
receiver_socket: &UdpSocket,
|
||||
dest_addr: SocketAddr,
|
||||
ip_packet: &[u8],
|
||||
) -> (Vec<u8>, Ipv4Addr) {
|
||||
let mut enc_buf = vec![0u8; 65536];
|
||||
|
||||
match sender_tunn.encapsulate(ip_packet, &mut enc_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
sender_socket.send_to(pkt, dest_addr).await.unwrap();
|
||||
}
|
||||
TunnResult::Err(e) => panic!("Encapsulate failed: {:?}", e),
|
||||
other => panic!("Expected WriteToNetwork, got {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Receive — may need a few rounds for control packets
|
||||
for _ in 0..10 {
|
||||
if let Some((data, addr, _)) = try_recv_decap(receiver_tunn, receiver_socket, 1000).await {
|
||||
return (data, addr);
|
||||
}
|
||||
}
|
||||
panic!("Did not receive decrypted IP packet");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 1: Single client ↔ server bidirectional data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_single_client_bidirectional() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
let client_addr = client_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, None, None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, None, None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
// Client → Server
|
||||
let pkt_c2s = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"Hello from client!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt_c2s,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt_c2s.len()], &pkt_c2s[..]);
|
||||
|
||||
// Server → Client
|
||||
let pkt_s2c = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 1), Ipv4Addr::new(10, 0, 0, 2), b"Hello from server!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut server_tunn, &mut client_tunn,
|
||||
&server_socket, &client_socket,
|
||||
client_addr, &pkt_s2c,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 1));
|
||||
assert_eq!(&decrypted[..pkt_s2c.len()], &pkt_s2c[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 2: Two clients ↔ one server (peer routing)
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_two_clients_peer_routing() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client1_pub_b64, client1_priv_b64) = generate_wg_keypair();
|
||||
let (client2_pub_b64, client2_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, _) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client1_public, client1_secret) = parse_key_pair(&client1_pub_b64, &client1_priv_b64);
|
||||
let (client2_public, client2_secret) = parse_key_pair(&client2_pub_b64, &client2_priv_b64);
|
||||
|
||||
// Separate server socket per peer to avoid UDP mux complexity in test
|
||||
let server_socket_1 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_socket_2 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client1_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client2_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr_1 = server_socket_1.local_addr().unwrap();
|
||||
let server_addr_2 = server_socket_2.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn_1 = Tunn::new(clone_secret(&server_priv_b64), client1_public, None, None, 0, None);
|
||||
let mut server_tunn_2 = Tunn::new(clone_secret(&server_priv_b64), client2_public, None, None, 1, None);
|
||||
let mut client1_tunn = Tunn::new(client1_secret, server_public.clone(), None, None, 2, None);
|
||||
let mut client2_tunn = Tunn::new(client2_secret, server_public, None, None, 3, None);
|
||||
|
||||
do_handshake(&mut client1_tunn, &mut server_tunn_1, &client1_socket, &server_socket_1, server_addr_1).await;
|
||||
do_handshake(&mut client2_tunn, &mut server_tunn_2, &client2_socket, &server_socket_2, server_addr_2).await;
|
||||
|
||||
// Client 1 → Server
|
||||
let pkt1 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"From client 1");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client1_tunn, &mut server_tunn_1,
|
||||
&client1_socket, &server_socket_1,
|
||||
server_addr_1, &pkt1,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt1.len()], &pkt1[..]);
|
||||
|
||||
// Client 2 → Server
|
||||
let pkt2 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 3), Ipv4Addr::new(10, 0, 0, 1), b"From client 2");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client2_tunn, &mut server_tunn_2,
|
||||
&client2_socket, &server_socket_2,
|
||||
server_addr_2, &pkt2,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 3));
|
||||
assert_eq!(&decrypted[..pkt2.len()], &pkt2[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 3: Preshared key handshake + data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_preshared_key() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let psk: [u8; 32] = rand::random();
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, Some(psk), None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, Some(psk), None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
let pkt = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"PSK-protected data");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt.len()], &pkt[..]);
|
||||
}
|
||||
284
test/test.flowcontrol.node.ts
Normal file
284
test/test.flowcontrol.node.ts
Normal file
@@ -0,0 +1,284 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let client: VpnClient;
|
||||
let clientBundle: IClientConfigBundle;
|
||||
const extraClients: VpnClient[] = [];
|
||||
const extraBundles: IClientConfigBundle[] = [];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start VPN server', async () => {
|
||||
serverPort = await findFreePort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
|
||||
// Phase 1: start the daemon bridge
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
expect(server.running).toBeTrue();
|
||||
|
||||
// Phase 2: generate a keypair
|
||||
keypair = await server.generateKeypair();
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
|
||||
// Phase 3: start the VPN listener (empty clients, will use createClient at runtime)
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
// Verify server is now running
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Phase 4: create the first client via the hub
|
||||
clientBundle = await server.createClient({ clientId: 'test-client-0' });
|
||||
expect(clientBundle.secrets.noisePrivateKey).toBeTypeofString();
|
||||
expect(clientBundle.smartvpnConfig.clientPublicKey).toBeTypeofString();
|
||||
});
|
||||
|
||||
tap.test('single client connects and gets IP', async () => {
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: clientBundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: clientBundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
// Verify client status
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
// Verify server sees the client
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 1;
|
||||
});
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(1);
|
||||
expect(clients[0].assignedIp).toEqual(result.assignedIp);
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange', async () => {
|
||||
// Wait for at least 2 keepalive cycles (interval=3s, so 8s should be enough)
|
||||
await delay(8000);
|
||||
|
||||
const clientStats = await client.getStatistics();
|
||||
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
const serverStats = await server.getStatistics();
|
||||
expect(serverStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
expect(serverStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Verify per-client keepalive tracking
|
||||
const clients = await server.listClients();
|
||||
expect(clients[0].keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('connection quality telemetry', async () => {
|
||||
const quality = await client.getConnectionQuality();
|
||||
|
||||
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.jitterMs).toBeTypeofNumber();
|
||||
expect(quality.minRttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.maxRttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.lossRatio).toBeTypeofNumber();
|
||||
expect(['healthy', 'degraded', 'critical']).toContain(quality.linkHealth);
|
||||
});
|
||||
|
||||
tap.test('rate limiting: set and verify', async () => {
|
||||
const clients = await server.listClients();
|
||||
const clientId = clients[0].clientId;
|
||||
|
||||
// Set a tight rate limit
|
||||
await server.setClientRateLimit(clientId, 100, 100);
|
||||
|
||||
// Verify via telemetry
|
||||
const telemetry = await server.getClientTelemetry(clientId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toEqual(100);
|
||||
expect(telemetry.burstBytes).toEqual(100);
|
||||
expect(telemetry.clientId).toEqual(clientId);
|
||||
});
|
||||
|
||||
tap.test('rate limiting: removal', async () => {
|
||||
const clients = await server.listClients();
|
||||
const clientId = clients[0].clientId;
|
||||
|
||||
await server.removeClientRateLimit(clientId);
|
||||
|
||||
// Verify telemetry no longer shows rate limit
|
||||
const telemetry = await server.getClientTelemetry(clientId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
|
||||
expect(telemetry.burstBytes).toBeNullOrUndefined();
|
||||
|
||||
// Connection still healthy
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('5 concurrent clients', async () => {
|
||||
const assignedIps = new Set<string>();
|
||||
|
||||
// Get the first client's IP
|
||||
const existingClients = await server.listClients();
|
||||
assignedIps.add(existingClients[0].assignedIp);
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const bundle = await server.createClient({ clientId: `test-client-${i + 1}` });
|
||||
extraBundles.push(bundle);
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
const result = await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
assignedIps.add(result.assignedIp);
|
||||
extraClients.push(c);
|
||||
}
|
||||
|
||||
// All IPs should be unique (6 total: original + 5 new)
|
||||
expect(assignedIps.size).toEqual(6);
|
||||
|
||||
// Server should see 6 clients
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 6;
|
||||
});
|
||||
const allClients = await server.listClients();
|
||||
expect(allClients.length).toEqual(6);
|
||||
});
|
||||
|
||||
tap.test('client disconnect tracking', async () => {
|
||||
// Disconnect 3 of the 5 extra clients
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const c = extraClients[i];
|
||||
await c.disconnect();
|
||||
c.stop();
|
||||
}
|
||||
|
||||
// Wait for server to detect disconnections
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 3;
|
||||
}, 15000);
|
||||
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(3);
|
||||
|
||||
const stats = await server.getStatistics();
|
||||
expect(stats.totalConnections).toBeGreaterThanOrEqual(6);
|
||||
});
|
||||
|
||||
tap.test('server-side client disconnection', async () => {
|
||||
const clients = await server.listClients();
|
||||
// Pick one of the remaining extra clients (not the original)
|
||||
const targetClient = clients.find((c) => {
|
||||
// Find a client that belongs to extraClients[3] or extraClients[4]
|
||||
return c.clientId !== clients[0].clientId;
|
||||
});
|
||||
expect(targetClient).toBeTruthy();
|
||||
|
||||
await server.disconnectClient(targetClient!.clientId);
|
||||
|
||||
// Wait for server to update
|
||||
await waitFor(async () => {
|
||||
const remaining = await server.listClients();
|
||||
return remaining.length === 2;
|
||||
});
|
||||
|
||||
const remaining = await server.listClients();
|
||||
expect(remaining.length).toEqual(2);
|
||||
});
|
||||
|
||||
tap.test('teardown: stop all', async () => {
|
||||
// Stop the original client
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
|
||||
// Stop remaining extra clients
|
||||
for (const c of extraClients) {
|
||||
if (c.running) {
|
||||
try {
|
||||
await c.disconnect();
|
||||
} catch {
|
||||
// May already be disconnected
|
||||
}
|
||||
c.stop();
|
||||
}
|
||||
}
|
||||
|
||||
await delay(500);
|
||||
|
||||
// Stop the server
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
await delay(500);
|
||||
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
362
test/test.loadtest.node.ts
Normal file
362
test/test.loadtest.node.ts
Normal file
@@ -0,0 +1,362 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as stream from 'stream';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ThrottleProxy (adapted from remoteingress)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class ThrottleTransform extends stream.Transform {
|
||||
private bytesPerSec: number;
|
||||
private bucket: number;
|
||||
private lastRefill: number;
|
||||
private destroyed_: boolean = false;
|
||||
|
||||
constructor(bytesPerSecond: number) {
|
||||
super();
|
||||
this.bytesPerSec = bytesPerSecond;
|
||||
this.bucket = bytesPerSecond;
|
||||
this.lastRefill = Date.now();
|
||||
}
|
||||
|
||||
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: stream.TransformCallback) {
|
||||
if (this.destroyed_) return;
|
||||
|
||||
const now = Date.now();
|
||||
const elapsed = (now - this.lastRefill) / 1000;
|
||||
this.bucket = Math.min(this.bytesPerSec, this.bucket + elapsed * this.bytesPerSec);
|
||||
this.lastRefill = now;
|
||||
|
||||
if (chunk.length <= this.bucket) {
|
||||
this.bucket -= chunk.length;
|
||||
callback(null, chunk);
|
||||
} else {
|
||||
const deficit = chunk.length - this.bucket;
|
||||
this.bucket = 0;
|
||||
const delayMs = Math.min((deficit / this.bytesPerSec) * 1000, 1000);
|
||||
setTimeout(() => {
|
||||
if (this.destroyed_) { callback(); return; }
|
||||
this.lastRefill = Date.now();
|
||||
this.bucket = 0;
|
||||
callback(null, chunk);
|
||||
}, delayMs);
|
||||
}
|
||||
}
|
||||
|
||||
_destroy(err: Error | null, callback: (error: Error | null) => void) {
|
||||
this.destroyed_ = true;
|
||||
callback(err);
|
||||
}
|
||||
}
|
||||
|
||||
interface ThrottleProxy {
|
||||
server: net.Server;
|
||||
close: () => Promise<void>;
|
||||
}
|
||||
|
||||
async function startThrottleProxy(
|
||||
listenPort: number,
|
||||
targetHost: string,
|
||||
targetPort: number,
|
||||
bytesPerSecond: number,
|
||||
): Promise<ThrottleProxy> {
|
||||
const connections = new Set<net.Socket>();
|
||||
const server = net.createServer((clientSock) => {
|
||||
connections.add(clientSock);
|
||||
const upstream = net.createConnection({ host: targetHost, port: targetPort });
|
||||
connections.add(upstream);
|
||||
|
||||
const throttleUp = new ThrottleTransform(bytesPerSecond);
|
||||
const throttleDown = new ThrottleTransform(bytesPerSecond);
|
||||
|
||||
clientSock.pipe(throttleUp).pipe(upstream);
|
||||
upstream.pipe(throttleDown).pipe(clientSock);
|
||||
|
||||
let cleaned = false;
|
||||
const cleanup = () => {
|
||||
if (cleaned) return;
|
||||
cleaned = true;
|
||||
throttleUp.destroy();
|
||||
throttleDown.destroy();
|
||||
clientSock.destroy();
|
||||
upstream.destroy();
|
||||
connections.delete(clientSock);
|
||||
connections.delete(upstream);
|
||||
};
|
||||
clientSock.on('error', () => cleanup());
|
||||
upstream.on('error', () => cleanup());
|
||||
throttleUp.on('error', () => cleanup());
|
||||
throttleDown.on('error', () => cleanup());
|
||||
clientSock.on('close', () => cleanup());
|
||||
upstream.on('close', () => cleanup());
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => server.listen(listenPort, '127.0.0.1', resolve));
|
||||
return {
|
||||
server,
|
||||
close: async () => {
|
||||
for (const c of connections) c.destroy();
|
||||
connections.clear();
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let proxyPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let throttle: ThrottleProxy;
|
||||
const allClients: VpnClient[] = [];
|
||||
|
||||
let clientCounter = 0;
|
||||
async function createConnectedClient(port: number): Promise<VpnClient> {
|
||||
clientCounter++;
|
||||
const bundle = await server.createClient({ clientId: `load-client-${clientCounter}` });
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${port}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
allClients.push(c);
|
||||
return c;
|
||||
}
|
||||
|
||||
async function stopClient(c: VpnClient): Promise<void> {
|
||||
if (c.running) {
|
||||
try { await c.disconnect(); } catch { /* already disconnected */ }
|
||||
c.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start throttled VPN tunnel (1 MB/s)', async () => {
|
||||
serverPort = await findFreePort();
|
||||
proxyPort = await findFreePort();
|
||||
|
||||
// Start VPN server
|
||||
server = new VpnServer({ transport: { transport: 'stdio' } });
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
keypair = await server.generateKeypair();
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Start throttle proxy: 1 MB/s
|
||||
throttle = await startThrottleProxy(proxyPort, '127.0.0.1', serverPort, 1024 * 1024);
|
||||
});
|
||||
|
||||
tap.test('throttled connection: handshake succeeds through throttle', async () => {
|
||||
const client = await createConnectedClient(proxyPort);
|
||||
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
expect(status.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 1;
|
||||
});
|
||||
});
|
||||
|
||||
tap.test('sustained keepalive under throttle', async () => {
|
||||
// Wait for at least 2 keepalive cycles (3s interval)
|
||||
await delay(8000);
|
||||
|
||||
const client = allClients[0];
|
||||
const stats = await client.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Throttle adds latency — RTT should be measurable
|
||||
const quality = await client.getConnectionQuality();
|
||||
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.jitterMs).toBeTypeofNumber();
|
||||
});
|
||||
|
||||
tap.test('3 concurrent throttled clients', async () => {
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await createConnectedClient(proxyPort);
|
||||
}
|
||||
|
||||
// All 4 clients should be visible
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 4;
|
||||
});
|
||||
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(4);
|
||||
|
||||
// Verify all IPs are unique
|
||||
const ips = new Set(clients.map((c) => c.assignedIp));
|
||||
expect(ips.size).toEqual(4);
|
||||
});
|
||||
|
||||
tap.test('rate limiting combined with network throttle', async () => {
|
||||
const clients = await server.listClients();
|
||||
const targetId = clients[0].clientId;
|
||||
|
||||
// Set rate limit on first client
|
||||
await server.setClientRateLimit(targetId, 500, 500);
|
||||
const telemetry = await server.getClientTelemetry(targetId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toEqual(500);
|
||||
expect(telemetry.burstBytes).toEqual(500);
|
||||
|
||||
// Verify another client has no rate limit
|
||||
const otherTelemetry = await server.getClientTelemetry(clients[1].clientId);
|
||||
expect(otherTelemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
|
||||
|
||||
// Clean up the rate limit
|
||||
await server.removeClientRateLimit(targetId);
|
||||
});
|
||||
|
||||
tap.test('burst waves: 3 waves of 3 clients', async () => {
|
||||
const initialCount = (await server.listClients()).length;
|
||||
|
||||
for (let wave = 0; wave < 3; wave++) {
|
||||
const waveClients: VpnClient[] = [];
|
||||
|
||||
// Connect 3 clients
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const c = await createConnectedClient(proxyPort);
|
||||
waveClients.push(c);
|
||||
}
|
||||
|
||||
// Verify all connected
|
||||
await waitFor(async () => {
|
||||
const all = await server.listClients();
|
||||
return all.length === initialCount + 3;
|
||||
});
|
||||
|
||||
// Disconnect all wave clients
|
||||
for (const c of waveClients) {
|
||||
await stopClient(c);
|
||||
}
|
||||
|
||||
// Wait for server to detect disconnections
|
||||
await waitFor(async () => {
|
||||
const all = await server.listClients();
|
||||
return all.length === initialCount;
|
||||
}, 15000);
|
||||
|
||||
await delay(500);
|
||||
}
|
||||
|
||||
// Verify total connections accumulated
|
||||
const stats = await server.getStatistics();
|
||||
expect(stats.totalConnections).toBeGreaterThanOrEqual(9 + initialCount);
|
||||
|
||||
// Original clients still connected
|
||||
const remaining = await server.listClients();
|
||||
expect(remaining.length).toEqual(initialCount);
|
||||
});
|
||||
|
||||
tap.test('aggressive throttle: 10 KB/s', async () => {
|
||||
// Close current throttle proxy and start an aggressive one
|
||||
await throttle.close();
|
||||
const aggressivePort = await findFreePort();
|
||||
throttle = await startThrottleProxy(aggressivePort, '127.0.0.1', serverPort, 10 * 1024);
|
||||
|
||||
// Connect a client through the aggressive throttle
|
||||
const client = await createConnectedClient(aggressivePort);
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Wait for keepalive exchange (might take longer due to throttle)
|
||||
await delay(10000);
|
||||
|
||||
const stats = await client.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('post-load health: direct connection still works', async () => {
|
||||
// Server should still be healthy after all load tests
|
||||
const serverStatus = await server.getStatus();
|
||||
expect(serverStatus.state).toEqual('connected');
|
||||
|
||||
// Connect one more client directly (no throttle)
|
||||
const directClient = await createConnectedClient(serverPort);
|
||||
const status = await directClient.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
await delay(5000);
|
||||
|
||||
const stats = await directClient.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('teardown: stop all', async () => {
|
||||
// Stop all clients
|
||||
for (const c of allClients) {
|
||||
await stopClient(c);
|
||||
}
|
||||
|
||||
await delay(500);
|
||||
|
||||
// Close throttle proxy
|
||||
if (throttle) {
|
||||
await throttle.close();
|
||||
}
|
||||
|
||||
// Stop server
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
await delay(500);
|
||||
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
258
test/test.quic.node.ts
Normal file
258
test/test.quic.node.ts
Normal file
@@ -0,0 +1,258 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as dgram from 'dgram';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
async function findFreeUdpPort(): Promise<number> {
|
||||
const sock = dgram.createSocket('udp4');
|
||||
await new Promise<void>((resolve) => sock.bind(0, '127.0.0.1', resolve));
|
||||
const port = (sock.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => sock.close(resolve));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let wsPort: number;
|
||||
let quicPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: QUIC-only server + QUIC client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start VPN server in QUIC mode', async () => {
|
||||
quicPort = await findFreeUdpPort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
keypair = await server.generateKeypair();
|
||||
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${quicPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.9.0.0/24',
|
||||
transportMode: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('QUIC client connects and gets IP', async () => {
|
||||
const bundle = await server.createClient({ clientId: 'quic-client-1' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${quicPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.9.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
// Verify server sees the client
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length >= 1;
|
||||
});
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('teardown: stop QUIC server', async () => {
|
||||
await server.stop();
|
||||
await delay(500);
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: dual-mode server (both) + auto client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let dualServer: VpnServer;
|
||||
let dualWsPort: number;
|
||||
let dualQuicPort: number;
|
||||
let dualKeypair: IVpnKeypair;
|
||||
|
||||
tap.test('setup: start VPN server in both mode', async () => {
|
||||
dualWsPort = await findFreePort();
|
||||
dualQuicPort = await findFreeUdpPort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
dualServer = new VpnServer(options);
|
||||
|
||||
const started = await dualServer['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
dualKeypair = await dualServer.generateKeypair();
|
||||
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${dualWsPort}`,
|
||||
privateKey: dualKeypair.privateKey,
|
||||
publicKey: dualKeypair.publicKey,
|
||||
subnet: '10.10.0.0/24',
|
||||
transportMode: 'both',
|
||||
quicListenAddr: `127.0.0.1:${dualQuicPort}`,
|
||||
keepaliveIntervalSecs: 3,
|
||||
};
|
||||
await dualServer['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await dualServer.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('auto client connects to dual-mode server (QUIC preferred)', async () => {
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-auto-client' });
|
||||
|
||||
// "auto" mode (default): tries QUIC first at same host:port, falls back to WS
|
||||
// Since the WS port and QUIC port differ, auto will try QUIC on WS port (fail),
|
||||
// then fall back to WebSocket
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${dualWsPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
// transport defaults to 'auto'
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.10.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
await waitFor(async () => {
|
||||
const clients = await dualServer.listClients();
|
||||
return clients.length >= 1;
|
||||
});
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('explicit QUIC client connects to dual-mode server', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-quic-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.10.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange over QUIC', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-keepalive-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
await client.start();
|
||||
|
||||
await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
// Wait for keepalive exchange
|
||||
await delay(8000);
|
||||
|
||||
const clientStats = await client.getStatistics();
|
||||
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('teardown: stop dual-mode server', async () => {
|
||||
await dualServer.stop();
|
||||
await delay(500);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -2,10 +2,17 @@ import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { VpnConfig } from '../ts/index.js';
|
||||
import type { IVpnClientConfig, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// Valid 32-byte base64 keys for testing
|
||||
const TEST_KEY_A = 'YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE=';
|
||||
const TEST_KEY_B = 'YmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmI=';
|
||||
const TEST_KEY_C = 'Y2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2M=';
|
||||
|
||||
tap.test('VpnConfig: validate valid client config', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
@@ -16,7 +23,9 @@ tap.test('VpnConfig: validate valid client config', async () => {
|
||||
|
||||
tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
const config = {
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -31,7 +40,9 @@ tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'http://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -43,10 +54,28 @@ tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config without clientPrivateKey', async () => {
|
||||
const config = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('clientPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
mtu: 100,
|
||||
};
|
||||
let threw = false;
|
||||
@@ -62,7 +91,9 @@ tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['not-an-ip'],
|
||||
};
|
||||
let threw = false;
|
||||
@@ -78,12 +109,15 @@ tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
tap.test('VpnConfig: validate valid server config', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
enableNat: true,
|
||||
clients: [
|
||||
{ clientId: 'test-client', publicKey: TEST_KEY_C },
|
||||
],
|
||||
};
|
||||
// Should not throw
|
||||
VpnConfig.validateServerConfig(config);
|
||||
@@ -92,8 +126,8 @@ tap.test('VpnConfig: validate valid server config', async () => {
|
||||
tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: 'invalid',
|
||||
};
|
||||
let threw = false;
|
||||
@@ -109,7 +143,7 @@ tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
const config = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
} as IVpnServerConfig;
|
||||
let threw = false;
|
||||
@@ -122,4 +156,24 @@ tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject server config with invalid client publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
clients: [
|
||||
{ clientId: 'bad-client', publicKey: 'short-key' },
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
|
||||
353
test/test.wireguard.node.ts
Normal file
353
test/test.wireguard.node.ts
Normal file
@@ -0,0 +1,353 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import {
|
||||
VpnConfig,
|
||||
VpnServer,
|
||||
WgConfigGenerator,
|
||||
} from '../ts/index.js';
|
||||
import type {
|
||||
IVpnClientConfig,
|
||||
IVpnServerConfig,
|
||||
IVpnServerOptions,
|
||||
IWgPeerConfig,
|
||||
} from '../ts/index.js';
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config validation — client
|
||||
// ============================================================================
|
||||
|
||||
// A valid 32-byte key in base64 (44 chars)
|
||||
const VALID_KEY = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=';
|
||||
const VALID_KEY_2 = 'BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB=';
|
||||
|
||||
tap.test('WG client config: valid wireguard config passes validation', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '', // not needed for WG
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
wgAllowedIps: ['0.0.0.0/0'],
|
||||
};
|
||||
VpnConfig.validateClientConfig(config);
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgPrivateKey', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgAddress', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgAddress');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgEndpoint', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgEndpoint');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects invalid key length', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: 'tooshort',
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('44 characters');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects invalid CIDR in allowedIps', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
wgAllowedIps: ['not-a-cidr'],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('CIDR');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config validation — server
|
||||
// ============================================================================
|
||||
|
||||
tap.test('WG server config: valid config passes validation', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: VALID_KEY_2,
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
VpnConfig.validateServerConfig(config);
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects empty wgPeers', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgPeers');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects peer without publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: '',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects invalid wgListenPort', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgListenPort: 0,
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: VALID_KEY_2,
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgListenPort');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard keypair generation via daemon
|
||||
// ============================================================================
|
||||
|
||||
let server: VpnServer;
|
||||
|
||||
tap.test('WG: spawn server daemon for keypair generation', async () => {
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
expect(server.running).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG: generateWgKeypair returns valid keypair', async () => {
|
||||
const keypair = await server.generateWgKeypair();
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
// WireGuard keys: base64 of 32 bytes = 44 characters
|
||||
expect(keypair.publicKey.length).toEqual(44);
|
||||
expect(keypair.privateKey.length).toEqual(44);
|
||||
// Verify they decode to 32 bytes
|
||||
const pubBuf = Buffer.from(keypair.publicKey, 'base64');
|
||||
const privBuf = Buffer.from(keypair.privateKey, 'base64');
|
||||
expect(pubBuf.length).toEqual(32);
|
||||
expect(privBuf.length).toEqual(32);
|
||||
});
|
||||
|
||||
tap.test('WG: generateWgKeypair returns unique keys each time', async () => {
|
||||
const kp1 = await server.generateWgKeypair();
|
||||
const kp2 = await server.generateWgKeypair();
|
||||
expect(kp1.publicKey).not.toEqual(kp2.publicKey);
|
||||
expect(kp1.privateKey).not.toEqual(kp2.privateKey);
|
||||
});
|
||||
|
||||
tap.test('WG: stop server daemon', async () => {
|
||||
server.stop();
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config file generation
|
||||
// ============================================================================
|
||||
|
||||
tap.test('WgConfigGenerator: generate client config', async () => {
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: 'clientPrivateKeyBase64====================',
|
||||
address: '10.8.0.2/24',
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
peer: {
|
||||
publicKey: 'serverPublicKeyBase64====================',
|
||||
endpoint: 'vpn.example.com:51820',
|
||||
allowedIps: ['0.0.0.0/0', '::/0'],
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).toContain('PrivateKey = clientPrivateKeyBase64====================');
|
||||
expect(conf).toContain('Address = 10.8.0.2/24');
|
||||
expect(conf).toContain('DNS = 1.1.1.1, 8.8.8.8');
|
||||
expect(conf).toContain('MTU = 1420');
|
||||
expect(conf).toContain('[Peer]');
|
||||
expect(conf).toContain('PublicKey = serverPublicKeyBase64====================');
|
||||
expect(conf).toContain('Endpoint = vpn.example.com:51820');
|
||||
expect(conf).toContain('AllowedIPs = 0.0.0.0/0, ::/0');
|
||||
expect(conf).toContain('PersistentKeepalive = 25');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate client config without optional fields', async () => {
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: 'key1',
|
||||
address: '10.0.0.2/32',
|
||||
peer: {
|
||||
publicKey: 'key2',
|
||||
endpoint: 'server:51820',
|
||||
allowedIps: ['10.0.0.0/24'],
|
||||
},
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).not.toContain('DNS');
|
||||
expect(conf).not.toContain('MTU');
|
||||
expect(conf).not.toContain('PresharedKey');
|
||||
expect(conf).not.toContain('PersistentKeepalive');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate server config with NAT', async () => {
|
||||
const conf = WgConfigGenerator.generateServerConfig({
|
||||
privateKey: 'serverPrivKey',
|
||||
address: '10.8.0.1/24',
|
||||
listenPort: 51820,
|
||||
dns: ['1.1.1.1'],
|
||||
enableNat: true,
|
||||
natInterface: 'ens3',
|
||||
peers: [
|
||||
{
|
||||
publicKey: 'peer1PubKey',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
presharedKey: 'psk1',
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
{
|
||||
publicKey: 'peer2PubKey',
|
||||
allowedIps: ['10.8.0.3/32'],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).toContain('ListenPort = 51820');
|
||||
expect(conf).toContain('PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ens3 -j MASQUERADE');
|
||||
expect(conf).toContain('PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ens3 -j MASQUERADE');
|
||||
// Two [Peer] sections
|
||||
const peerCount = (conf.match(/\[Peer\]/g) || []).length;
|
||||
expect(peerCount).toEqual(2);
|
||||
expect(conf).toContain('PresharedKey = psk1');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate server config without NAT', async () => {
|
||||
const conf = WgConfigGenerator.generateServerConfig({
|
||||
privateKey: 'serverPrivKey',
|
||||
address: '10.8.0.1/24',
|
||||
listenPort: 51820,
|
||||
peers: [
|
||||
{
|
||||
publicKey: 'peerKey',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(conf).not.toContain('PostUp');
|
||||
expect(conf).not.toContain('PostDown');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartvpn',
|
||||
version: '1.0.2',
|
||||
version: '1.8.0',
|
||||
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
|
||||
}
|
||||
|
||||
@@ -4,3 +4,4 @@ export { VpnClient } from './smartvpn.classes.vpnclient.js';
|
||||
export { VpnServer } from './smartvpn.classes.vpnserver.js';
|
||||
export { VpnConfig } from './smartvpn.classes.vpnconfig.js';
|
||||
export { VpnInstaller } from './smartvpn.classes.vpninstaller.js';
|
||||
export { WgConfigGenerator } from './smartvpn.classes.wgconfig.js';
|
||||
|
||||
@@ -5,6 +5,8 @@ import type {
|
||||
IVpnClientConfig,
|
||||
IVpnStatus,
|
||||
IVpnStatistics,
|
||||
IVpnConnectionQuality,
|
||||
IVpnMtuInfo,
|
||||
TVpnClientCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -65,12 +67,26 @@ export class VpnClient extends plugins.events.EventEmitter {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get traffic statistics.
|
||||
* Get traffic statistics (includes connection quality when connected).
|
||||
*/
|
||||
public async getStatistics(): Promise<IVpnStatistics> {
|
||||
return this.bridge.sendCommand('getStatistics', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get connection quality metrics (RTT, jitter, loss, link health).
|
||||
*/
|
||||
public async getConnectionQuality(): Promise<IVpnConnectionQuality> {
|
||||
return this.bridge.sendCommand('getConnectionQuality', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get MTU information (overhead, effective MTU, oversized packet stats).
|
||||
*/
|
||||
public async getMtuInfo(): Promise<IVpnMtuInfo> {
|
||||
return this.bridge.sendCommand('getMtuInfo', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
@@ -12,14 +12,54 @@ export class VpnConfig {
|
||||
* Validate a client config object. Throws on invalid config.
|
||||
*/
|
||||
public static validateClientConfig(config: IVpnClientConfig): void {
|
||||
if (!config.serverUrl) {
|
||||
throw new Error('VpnConfig: serverUrl is required');
|
||||
}
|
||||
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
|
||||
throw new Error('VpnConfig: serverUrl must start with wss:// or ws://');
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
if (config.transport === 'wireguard') {
|
||||
// WireGuard-specific validation
|
||||
if (!config.wgPrivateKey) {
|
||||
throw new Error('VpnConfig: wgPrivateKey is required for WireGuard transport');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.wgPrivateKey, 'wgPrivateKey');
|
||||
if (!config.wgAddress) {
|
||||
throw new Error('VpnConfig: wgAddress is required for WireGuard transport');
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required for WireGuard transport');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.serverPublicKey, 'serverPublicKey');
|
||||
if (!config.wgEndpoint) {
|
||||
throw new Error('VpnConfig: wgEndpoint is required for WireGuard transport');
|
||||
}
|
||||
if (config.wgPresharedKey) {
|
||||
VpnConfig.validateBase64Key(config.wgPresharedKey, 'wgPresharedKey');
|
||||
}
|
||||
if (config.wgAllowedIps) {
|
||||
for (const cidr of config.wgAllowedIps) {
|
||||
if (!VpnConfig.isValidCidr(cidr)) {
|
||||
throw new Error(`VpnConfig: invalid allowedIp CIDR: ${cidr}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!config.serverUrl) {
|
||||
throw new Error('VpnConfig: serverUrl is required');
|
||||
}
|
||||
// For QUIC-only transport, serverUrl is a host:port address; for WebSocket/auto it must be ws:// or wss://
|
||||
if (config.transport !== 'quic') {
|
||||
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
|
||||
throw new Error('VpnConfig: serverUrl must start with wss:// or ws:// (for WebSocket transport)');
|
||||
}
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
}
|
||||
// Noise IK requires client keypair
|
||||
if (!config.clientPrivateKey) {
|
||||
throw new Error('VpnConfig: clientPrivateKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPrivateKey, 'clientPrivateKey');
|
||||
if (!config.clientPublicKey) {
|
||||
throw new Error('VpnConfig: clientPublicKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPublicKey, 'clientPublicKey');
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -40,20 +80,63 @@ export class VpnConfig {
|
||||
* Validate a server config object. Throws on invalid config.
|
||||
*/
|
||||
public static validateServerConfig(config: IVpnServerConfig): void {
|
||||
if (!config.listenAddr) {
|
||||
throw new Error('VpnConfig: listenAddr is required');
|
||||
}
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
if (!config.publicKey) {
|
||||
throw new Error('VpnConfig: publicKey is required');
|
||||
}
|
||||
if (!config.subnet) {
|
||||
throw new Error('VpnConfig: subnet is required');
|
||||
}
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
if (config.transportMode === 'wireguard') {
|
||||
// WireGuard server validation
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.privateKey, 'privateKey');
|
||||
if (!config.wgPeers || config.wgPeers.length === 0) {
|
||||
throw new Error('VpnConfig: at least one wgPeers entry is required for WireGuard mode');
|
||||
}
|
||||
for (const peer of config.wgPeers) {
|
||||
if (!peer.publicKey) {
|
||||
throw new Error('VpnConfig: peer publicKey is required');
|
||||
}
|
||||
VpnConfig.validateBase64Key(peer.publicKey, 'peer.publicKey');
|
||||
if (!peer.allowedIps || peer.allowedIps.length === 0) {
|
||||
throw new Error('VpnConfig: peer allowedIps is required');
|
||||
}
|
||||
for (const cidr of peer.allowedIps) {
|
||||
if (!VpnConfig.isValidCidr(cidr)) {
|
||||
throw new Error(`VpnConfig: invalid peer allowedIp CIDR: ${cidr}`);
|
||||
}
|
||||
}
|
||||
if (peer.presharedKey) {
|
||||
VpnConfig.validateBase64Key(peer.presharedKey, 'peer.presharedKey');
|
||||
}
|
||||
}
|
||||
if (config.wgListenPort !== undefined && (config.wgListenPort < 1 || config.wgListenPort > 65535)) {
|
||||
throw new Error('VpnConfig: wgListenPort must be between 1 and 65535');
|
||||
}
|
||||
} else {
|
||||
if (!config.listenAddr) {
|
||||
throw new Error('VpnConfig: listenAddr is required');
|
||||
}
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
if (!config.publicKey) {
|
||||
throw new Error('VpnConfig: publicKey is required');
|
||||
}
|
||||
if (!config.subnet) {
|
||||
throw new Error('VpnConfig: subnet is required');
|
||||
}
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
}
|
||||
// Validate client entries if provided
|
||||
if (config.clients) {
|
||||
for (const client of config.clients) {
|
||||
if (!client.clientId) {
|
||||
throw new Error('VpnConfig: client entry must have a clientId');
|
||||
}
|
||||
if (!client.publicKey) {
|
||||
throw new Error(`VpnConfig: client '${client.clientId}' must have a publicKey`);
|
||||
}
|
||||
VpnConfig.validateBase64Key(client.publicKey, `client '${client.clientId}' publicKey`);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -101,4 +184,41 @@ export class VpnConfig {
|
||||
const prefixNum = parseInt(prefix, 10);
|
||||
return !isNaN(prefixNum) && prefixNum >= 0 && prefixNum <= 32;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a CIDR string (IPv4 or IPv6).
|
||||
*/
|
||||
private static isValidCidr(cidr: string): boolean {
|
||||
const parts = cidr.split('/');
|
||||
if (parts.length !== 2) return false;
|
||||
const prefixNum = parseInt(parts[1], 10);
|
||||
if (isNaN(prefixNum) || prefixNum < 0) return false;
|
||||
// IPv4
|
||||
if (VpnConfig.isValidIp(parts[0])) {
|
||||
return prefixNum <= 32;
|
||||
}
|
||||
// IPv6 (basic check)
|
||||
if (parts[0].includes(':')) {
|
||||
return prefixNum <= 128;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a base64-encoded 32-byte key (WireGuard X25519 format).
|
||||
*/
|
||||
private static validateBase64Key(key: string, fieldName: string): void {
|
||||
if (key.length !== 44) {
|
||||
throw new Error(`VpnConfig: ${fieldName} must be 44 characters (base64 of 32 bytes), got ${key.length}`);
|
||||
}
|
||||
try {
|
||||
const buf = Buffer.from(key, 'base64');
|
||||
if (buf.length !== 32) {
|
||||
throw new Error(`VpnConfig: ${fieldName} must decode to 32 bytes, got ${buf.length}`);
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.message.startsWith('VpnConfig:')) throw e;
|
||||
throw new Error(`VpnConfig: ${fieldName} is not valid base64`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,11 @@ import type {
|
||||
IVpnServerStatistics,
|
||||
IVpnClientInfo,
|
||||
IVpnKeypair,
|
||||
IVpnClientTelemetry,
|
||||
IWgPeerConfig,
|
||||
IWgPeerInfo,
|
||||
IClientEntry,
|
||||
IClientConfigBundle,
|
||||
TVpnServerCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -91,6 +96,139 @@ export class VpnServer extends plugins.events.EventEmitter {
|
||||
return this.bridge.sendCommand('generateKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set rate limit for a specific client.
|
||||
*/
|
||||
public async setClientRateLimit(
|
||||
clientId: string,
|
||||
rateBytesPerSec: number,
|
||||
burstBytes: number,
|
||||
): Promise<void> {
|
||||
await this.bridge.sendCommand('setClientRateLimit', {
|
||||
clientId,
|
||||
rateBytesPerSec,
|
||||
burstBytes,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove rate limit for a specific client (unlimited).
|
||||
*/
|
||||
public async removeClientRateLimit(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeClientRateLimit', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get telemetry for a specific client.
|
||||
*/
|
||||
public async getClientTelemetry(clientId: string): Promise<IVpnClientTelemetry> {
|
||||
return this.bridge.sendCommand('getClientTelemetry', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a WireGuard-compatible X25519 keypair.
|
||||
*/
|
||||
public async generateWgKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateWgKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a WireGuard peer (server must be running in wireguard mode).
|
||||
*/
|
||||
public async addWgPeer(peer: IWgPeerConfig): Promise<void> {
|
||||
await this.bridge.sendCommand('addWgPeer', { peer });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a WireGuard peer by public key.
|
||||
*/
|
||||
public async removeWgPeer(publicKey: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeWgPeer', { publicKey });
|
||||
}
|
||||
|
||||
/**
|
||||
* List WireGuard peers with stats.
|
||||
*/
|
||||
public async listWgPeers(): Promise<IWgPeerInfo[]> {
|
||||
const result = await this.bridge.sendCommand('listWgPeers', {} as Record<string, never>);
|
||||
return result.peers;
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ─────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Create a new client. Generates keypairs, assigns IP, returns full config bundle.
|
||||
* The secrets (private keys) are only returned at creation time.
|
||||
*/
|
||||
public async createClient(opts: Partial<IClientEntry>): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('createClient', { client: opts });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a registered client (also disconnects if connected).
|
||||
*/
|
||||
public async removeClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a registered client by ID.
|
||||
*/
|
||||
public async getClient(clientId: string): Promise<IClientEntry> {
|
||||
return this.bridge.sendCommand('getClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* List all registered clients.
|
||||
*/
|
||||
public async listRegisteredClients(): Promise<IClientEntry[]> {
|
||||
const result = await this.bridge.sendCommand('listRegisteredClients', {} as Record<string, never>);
|
||||
return result.clients;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a registered client's fields (ACLs, tags, description, etc.).
|
||||
*/
|
||||
public async updateClient(clientId: string, update: Partial<IClientEntry>): Promise<void> {
|
||||
await this.bridge.sendCommand('updateClient', { clientId, update });
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a previously disabled client.
|
||||
*/
|
||||
public async enableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('enableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable a client (also disconnects if connected).
|
||||
*/
|
||||
public async disableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('disableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
*/
|
||||
public async rotateClientKey(clientId: string): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('rotateClientKey', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Export a client config (without secrets) in the specified format.
|
||||
*/
|
||||
public async exportClientConfig(clientId: string, format: 'smartvpn' | 'wireguard'): Promise<string> {
|
||||
const result = await this.bridge.sendCommand('exportClientConfig', { clientId, format });
|
||||
return result.config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a standalone Noise IK keypair (not tied to a client).
|
||||
*/
|
||||
public async generateClientKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateClientKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
123
ts/smartvpn.classes.wgconfig.ts
Normal file
123
ts/smartvpn.classes.wgconfig.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
import type { IWgPeerConfig } from './smartvpn.interfaces.js';
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard .conf file generator
|
||||
// ============================================================================
|
||||
|
||||
export interface IWgClientConfOptions {
|
||||
/** Client private key (base64) */
|
||||
privateKey: string;
|
||||
/** Client TUN address with prefix (e.g. 10.8.0.2/24) */
|
||||
address: string;
|
||||
/** DNS servers */
|
||||
dns?: string[];
|
||||
/** TUN MTU */
|
||||
mtu?: number;
|
||||
/** Server peer config */
|
||||
peer: {
|
||||
publicKey: string;
|
||||
presharedKey?: string;
|
||||
endpoint: string;
|
||||
allowedIps: string[];
|
||||
persistentKeepalive?: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface IWgServerConfOptions {
|
||||
/** Server private key (base64) */
|
||||
privateKey: string;
|
||||
/** Server TUN address with prefix (e.g. 10.8.0.1/24) */
|
||||
address: string;
|
||||
/** UDP listen port */
|
||||
listenPort: number;
|
||||
/** DNS servers */
|
||||
dns?: string[];
|
||||
/** TUN MTU */
|
||||
mtu?: number;
|
||||
/** Enable NAT — adds PostUp/PostDown iptables rules */
|
||||
enableNat?: boolean;
|
||||
/** Network interface for NAT (e.g. eth0). Auto-detected if omitted. */
|
||||
natInterface?: string;
|
||||
/** Configured peers */
|
||||
peers: IWgPeerConfig[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates standard WireGuard .conf files compatible with wg-quick,
|
||||
* WireGuard iOS/Android apps, and other standard WireGuard clients.
|
||||
*/
|
||||
export class WgConfigGenerator {
|
||||
/**
|
||||
* Generate a client .conf file content.
|
||||
*/
|
||||
public static generateClientConfig(opts: IWgClientConfOptions): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push('[Interface]');
|
||||
lines.push(`PrivateKey = ${opts.privateKey}`);
|
||||
lines.push(`Address = ${opts.address}`);
|
||||
if (opts.dns && opts.dns.length > 0) {
|
||||
lines.push(`DNS = ${opts.dns.join(', ')}`);
|
||||
}
|
||||
if (opts.mtu) {
|
||||
lines.push(`MTU = ${opts.mtu}`);
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
lines.push('[Peer]');
|
||||
lines.push(`PublicKey = ${opts.peer.publicKey}`);
|
||||
if (opts.peer.presharedKey) {
|
||||
lines.push(`PresharedKey = ${opts.peer.presharedKey}`);
|
||||
}
|
||||
lines.push(`Endpoint = ${opts.peer.endpoint}`);
|
||||
lines.push(`AllowedIPs = ${opts.peer.allowedIps.join(', ')}`);
|
||||
if (opts.peer.persistentKeepalive) {
|
||||
lines.push(`PersistentKeepalive = ${opts.peer.persistentKeepalive}`);
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
return lines.join('\n');
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a server .conf file content.
|
||||
*/
|
||||
public static generateServerConfig(opts: IWgServerConfOptions): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push('[Interface]');
|
||||
lines.push(`PrivateKey = ${opts.privateKey}`);
|
||||
lines.push(`Address = ${opts.address}`);
|
||||
lines.push(`ListenPort = ${opts.listenPort}`);
|
||||
if (opts.dns && opts.dns.length > 0) {
|
||||
lines.push(`DNS = ${opts.dns.join(', ')}`);
|
||||
}
|
||||
if (opts.mtu) {
|
||||
lines.push(`MTU = ${opts.mtu}`);
|
||||
}
|
||||
if (opts.enableNat) {
|
||||
const iface = opts.natInterface || 'eth0';
|
||||
lines.push(`PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ${iface} -j MASQUERADE`);
|
||||
lines.push(`PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ${iface} -j MASQUERADE`);
|
||||
}
|
||||
|
||||
for (const peer of opts.peers) {
|
||||
lines.push('');
|
||||
lines.push('[Peer]');
|
||||
lines.push(`PublicKey = ${peer.publicKey}`);
|
||||
if (peer.presharedKey) {
|
||||
lines.push(`PresharedKey = ${peer.presharedKey}`);
|
||||
}
|
||||
lines.push(`AllowedIPs = ${peer.allowedIps.join(', ')}`);
|
||||
if (peer.endpoint) {
|
||||
lines.push(`Endpoint = ${peer.endpoint}`);
|
||||
}
|
||||
if (peer.persistentKeepalive) {
|
||||
lines.push(`PersistentKeepalive = ${peer.persistentKeepalive}`);
|
||||
}
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
return lines.join('\n');
|
||||
}
|
||||
}
|
||||
@@ -24,14 +24,36 @@ export type TVpnTransportOptions = IVpnTransportStdio | IVpnTransportSocket;
|
||||
export interface IVpnClientConfig {
|
||||
/** Server WebSocket URL, e.g. wss://vpn.example.com/tunnel */
|
||||
serverUrl: string;
|
||||
/** Server's static public key (base64) for Noise NK handshake */
|
||||
/** Server's static public key (base64) for Noise IK handshake */
|
||||
serverPublicKey: string;
|
||||
/** Client's Noise IK private key (base64) — required for SmartVPN native transport */
|
||||
clientPrivateKey: string;
|
||||
/** Client's Noise IK public key (base64) — for reference/display */
|
||||
clientPublicKey: string;
|
||||
/** Optional DNS servers to use while connected */
|
||||
dns?: string[];
|
||||
/** Optional MTU for the TUN device */
|
||||
mtu?: number;
|
||||
/** Keepalive interval in seconds (default: 30) */
|
||||
keepaliveIntervalSecs?: number;
|
||||
/** Transport protocol: 'auto' (default, tries QUIC then WS), 'websocket', 'quic', or 'wireguard' */
|
||||
transport?: 'auto' | 'websocket' | 'quic' | 'wireguard';
|
||||
/** For QUIC: SHA-256 hash of server certificate (base64) for cert pinning */
|
||||
serverCertHash?: string;
|
||||
/** WireGuard: client private key (base64, X25519) */
|
||||
wgPrivateKey?: string;
|
||||
/** WireGuard: client TUN address (e.g. 10.8.0.2) */
|
||||
wgAddress?: string;
|
||||
/** WireGuard: client TUN address prefix length (default: 24) */
|
||||
wgAddressPrefix?: number;
|
||||
/** WireGuard: preshared key (base64, optional) */
|
||||
wgPresharedKey?: string;
|
||||
/** WireGuard: persistent keepalive interval in seconds */
|
||||
wgPersistentKeepalive?: number;
|
||||
/** WireGuard: server endpoint (host:port, e.g. vpn.example.com:51820) */
|
||||
wgEndpoint?: string;
|
||||
/** WireGuard: allowed IPs (CIDR strings, e.g. ['0.0.0.0/0']) */
|
||||
wgAllowedIps?: string[];
|
||||
}
|
||||
|
||||
export interface IVpnClientOptions {
|
||||
@@ -64,6 +86,22 @@ export interface IVpnServerConfig {
|
||||
keepaliveIntervalSecs?: number;
|
||||
/** Enable NAT/masquerade for client traffic */
|
||||
enableNat?: boolean;
|
||||
/** Default rate limit for new clients (bytes/sec). Omit for unlimited. */
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
/** Default burst size for new clients (bytes). Omit for unlimited. */
|
||||
defaultBurstBytes?: number;
|
||||
/** Transport mode: 'both' (default, WS+QUIC), 'websocket', 'quic', or 'wireguard' */
|
||||
transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard';
|
||||
/** QUIC listen address (host:port). Defaults to listenAddr. */
|
||||
quicListenAddr?: string;
|
||||
/** QUIC idle timeout in seconds (default: 30) */
|
||||
quicIdleTimeoutSecs?: number;
|
||||
/** WireGuard: UDP listen port (default: 51820) */
|
||||
wgListenPort?: number;
|
||||
/** WireGuard: configured peers */
|
||||
wgPeers?: IWgPeerConfig[];
|
||||
/** Pre-registered clients for Noise IK authentication */
|
||||
clients?: IClientEntry[];
|
||||
}
|
||||
|
||||
export interface IVpnServerOptions {
|
||||
@@ -99,6 +137,7 @@ export interface IVpnStatistics {
|
||||
keepalivesSent: number;
|
||||
keepalivesReceived: number;
|
||||
uptimeSeconds: number;
|
||||
quality?: IVpnConnectionQuality;
|
||||
}
|
||||
|
||||
export interface IVpnClientInfo {
|
||||
@@ -107,6 +146,16 @@ export interface IVpnClientInfo {
|
||||
connectedSince: string;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
/** Client's authenticated Noise IK public key (base64) */
|
||||
authenticatedKey: string;
|
||||
/** Registered client ID from the client registry */
|
||||
registeredClientId: string;
|
||||
}
|
||||
|
||||
export interface IVpnServerStatistics extends IVpnStatistics {
|
||||
@@ -119,6 +168,160 @@ export interface IVpnKeypair {
|
||||
privateKey: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: Connection quality
|
||||
// ============================================================================
|
||||
|
||||
export type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
|
||||
|
||||
export interface IVpnConnectionQuality {
|
||||
srttMs: number;
|
||||
jitterMs: number;
|
||||
minRttMs: number;
|
||||
maxRttMs: number;
|
||||
lossRatio: number;
|
||||
consecutiveTimeouts: number;
|
||||
linkHealth: TVpnLinkHealth;
|
||||
currentKeepaliveIntervalSecs: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: MTU info
|
||||
// ============================================================================
|
||||
|
||||
export interface IVpnMtuInfo {
|
||||
tunMtu: number;
|
||||
effectiveMtu: number;
|
||||
linkMtu: number;
|
||||
overheadBytes: number;
|
||||
oversizedPacketsDropped: number;
|
||||
icmpTooBigSent: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: Client telemetry (server-side per-client)
|
||||
// ============================================================================
|
||||
|
||||
export interface IVpnClientTelemetry {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
bytesReceived: number;
|
||||
bytesSent: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client Registry (Hub) types — aligned with SmartProxy IRouteSecurity pattern
|
||||
// ============================================================================
|
||||
|
||||
/** Per-client rate limiting. */
|
||||
export interface IClientRateLimit {
|
||||
/** Max throughput in bytes/sec */
|
||||
bytesPerSec: number;
|
||||
/** Burst allowance in bytes */
|
||||
burstBytes: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Per-client security settings.
|
||||
* Mirrors SmartProxy's IRouteSecurity: ipAllowList/ipBlockList naming + deny-overrides-allow.
|
||||
* Adds VPN-specific destination filtering.
|
||||
*/
|
||||
export interface IClientSecurity {
|
||||
/** Source IPs/CIDRs the client may connect FROM (empty = any).
|
||||
* Supports: exact IP, CIDR, wildcard (192.168.1.*), ranges (1.1.1.1-1.1.1.5). */
|
||||
ipAllowList?: string[];
|
||||
/** Source IPs blocked — overrides ipAllowList (deny wins). */
|
||||
ipBlockList?: string[];
|
||||
/** Destination IPs/CIDRs the client may reach through the VPN (empty = all). */
|
||||
destinationAllowList?: string[];
|
||||
/** Destination IPs blocked — overrides destinationAllowList (deny wins). */
|
||||
destinationBlockList?: string[];
|
||||
/** Max concurrent connections from this client. */
|
||||
maxConnections?: number;
|
||||
/** Per-client rate limiting. */
|
||||
rateLimit?: IClientRateLimit;
|
||||
}
|
||||
|
||||
/**
|
||||
* Server-side client definition — the central config object for the Hub.
|
||||
* Naming and structure aligned with SmartProxy's IRouteConfig / IRouteSecurity.
|
||||
*/
|
||||
export interface IClientEntry {
|
||||
/** Human-readable client ID (e.g. "alice-laptop") */
|
||||
clientId: string;
|
||||
/** Client's Noise IK public key (base64) — for SmartVPN native transport */
|
||||
publicKey: string;
|
||||
/** Client's WireGuard public key (base64) — for WireGuard transport */
|
||||
wgPublicKey?: string;
|
||||
/** Security settings (ACLs, rate limits) */
|
||||
security?: IClientSecurity;
|
||||
/** Traffic priority (lower = higher priority, default: 100) */
|
||||
priority?: number;
|
||||
/** Whether this client is enabled (default: true) */
|
||||
enabled?: boolean;
|
||||
/** Tags for grouping (e.g. ["engineering", "office"]) */
|
||||
tags?: string[];
|
||||
/** Optional description */
|
||||
description?: string;
|
||||
/** Optional expiry (ISO 8601 timestamp, omit = never expires) */
|
||||
expiresAt?: string;
|
||||
/** Assigned VPN IP address (set by server) */
|
||||
assignedIp?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Complete client config bundle — returned by createClient() and rotateClientKey().
|
||||
* Contains everything the client needs to connect.
|
||||
*/
|
||||
export interface IClientConfigBundle {
|
||||
/** The server-side client entry */
|
||||
entry: IClientEntry;
|
||||
/** Ready-to-use SmartVPN client config (typed object) */
|
||||
smartvpnConfig: IVpnClientConfig;
|
||||
/** Ready-to-use WireGuard .conf file content (string) */
|
||||
wireguardConfig: string;
|
||||
/** Client's private keys (ONLY returned at creation time, not stored server-side) */
|
||||
secrets: {
|
||||
noisePrivateKey: string;
|
||||
wgPrivateKey: string;
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard-specific types
|
||||
// ============================================================================
|
||||
|
||||
export interface IWgPeerConfig {
|
||||
/** Peer's public key (base64, X25519) */
|
||||
publicKey: string;
|
||||
/** Optional preshared key (base64) */
|
||||
presharedKey?: string;
|
||||
/** Allowed IP ranges (CIDR strings) */
|
||||
allowedIps: string[];
|
||||
/** Peer endpoint (host:port) — optional for server peers, required for client */
|
||||
endpoint?: string;
|
||||
/** Persistent keepalive interval in seconds */
|
||||
persistentKeepalive?: number;
|
||||
}
|
||||
|
||||
export interface IWgPeerInfo {
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint?: string;
|
||||
persistentKeepalive?: number;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsSent: number;
|
||||
packetsReceived: number;
|
||||
lastHandshakeTime?: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// IPC Command maps (used by smartrust RustBridge<TCommands>)
|
||||
// ============================================================================
|
||||
@@ -128,6 +331,8 @@ export type TVpnClientCommands = {
|
||||
disconnect: { params: Record<string, never>; result: void };
|
||||
getStatus: { params: Record<string, never>; result: IVpnStatus };
|
||||
getStatistics: { params: Record<string, never>; result: IVpnStatistics };
|
||||
getConnectionQuality: { params: Record<string, never>; result: IVpnConnectionQuality };
|
||||
getMtuInfo: { params: Record<string, never>; result: IVpnMtuInfo };
|
||||
};
|
||||
|
||||
export type TVpnServerCommands = {
|
||||
@@ -138,6 +343,24 @@ export type TVpnServerCommands = {
|
||||
listClients: { params: Record<string, never>; result: { clients: IVpnClientInfo[] } };
|
||||
disconnectClient: { params: { clientId: string }; result: void };
|
||||
generateKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void };
|
||||
removeClientRateLimit: { params: { clientId: string }; result: void };
|
||||
getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry };
|
||||
generateWgKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
addWgPeer: { params: { peer: IWgPeerConfig }; result: void };
|
||||
removeWgPeer: { params: { publicKey: string }; result: void };
|
||||
listWgPeers: { params: Record<string, never>; result: { peers: IWgPeerInfo[] } };
|
||||
// Client Registry (Hub) commands
|
||||
createClient: { params: { client: Partial<IClientEntry> }; result: IClientConfigBundle };
|
||||
removeClient: { params: { clientId: string }; result: void };
|
||||
getClient: { params: { clientId: string }; result: IClientEntry };
|
||||
listRegisteredClients: { params: Record<string, never>; result: { clients: IClientEntry[] } };
|
||||
updateClient: { params: { clientId: string; update: Partial<IClientEntry> }; result: void };
|
||||
enableClient: { params: { clientId: string }; result: void };
|
||||
disableClient: { params: { clientId: string }; result: void };
|
||||
rotateClientKey: { params: { clientId: string }; result: IClientConfigBundle };
|
||||
exportClientConfig: { params: { clientId: string; format: 'smartvpn' | 'wireguard' }; result: { config: string } };
|
||||
generateClientKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
"module": "NodeNext",
|
||||
"moduleResolution": "NodeNext",
|
||||
"esModuleInterop": true,
|
||||
"verbatimModuleSyntax": true
|
||||
"verbatimModuleSyntax": true,
|
||||
"types": ["node"]
|
||||
},
|
||||
"exclude": [
|
||||
"dist_ts/**/*.d.ts"
|
||||
|
||||
Reference in New Issue
Block a user