Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c97beed6e0 | |||
| c3cc237db5 | |||
| 17c27a92d6 | |||
| 9d105e8034 | |||
| e9cf575271 | |||
| 229db4be38 | |||
| e31086d0c2 | |||
| 01a0d8b9f4 | |||
| 187a69028b | |||
| 64dedd389e | |||
| 13d8cbe3fa | |||
| f46ea70286 | |||
| 26ee3634c8 | |||
| 049fa00563 | |||
| e4e59d72f9 | |||
| 51d33127bf | |||
| a4ba6806e5 | |||
| 6330921160 | |||
| e81dd377d8 | |||
| e14c357ba0 | |||
| eb30825f72 | |||
| 835f0f791d | |||
| aec545fe8c | |||
| 4fab721d87 | |||
| 9ee41348e0 | |||
| 97bb148063 | |||
| c8d572b719 | |||
| a46188ce07 | |||
| 9e1ec93814 |
102
changelog.md
102
changelog.md
@@ -1,5 +1,107 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-03-29 - 1.10.1 - fix(test, docs, scripts)
|
||||
correct test command verbosity, shorten load test timings, and document forwarding modes
|
||||
|
||||
- Fixes the test script by removing the duplicated verbose flag in package.json.
|
||||
- Reduces load test delays and burst sizes to keep keepalive and connection tests faster and more stable.
|
||||
- Updates the README to describe forwardingMode options, userspace NAT support, and related configuration examples.
|
||||
|
||||
## 2026-03-29 - 1.10.0 - feat(rust-server, rust-client, ts-interfaces)
|
||||
add configurable packet forwarding with TUN and userspace NAT modes
|
||||
|
||||
- introduce forwardingMode options for client and server configuration interfaces
|
||||
- add server-side forwarding engines for kernel TUN, userspace socket NAT, and testing mode
|
||||
- add a smoltcp-based userspace NAT implementation for packet forwarding without root-only TUN routing
|
||||
- enable client-side TUN forwarding support with route setup, packet I/O, and cleanup
|
||||
- centralize raw packet destination IP extraction in tunnel utilities for shared routing logic
|
||||
- update test command timeout and logging flags
|
||||
|
||||
## 2026-03-29 - 1.9.0 - feat(server)
|
||||
add PROXY protocol v2 support for real client IP handling and connection ACLs
|
||||
|
||||
- add PROXY protocol v2 parsing for WebSocket connections, including IPv4/IPv6 support, LOCAL command handling, and header read timeout protection
|
||||
- apply server-level connection IP block lists before the Noise handshake and enforce per-client source IP allow/block lists using the resolved remote address
|
||||
- expose proxy protocol configuration and remote client address fields in Rust and TypeScript interfaces, and document reverse-proxy usage in the README
|
||||
|
||||
## 2026-03-29 - 1.8.0 - feat(auth,client-registry)
|
||||
add Noise IK client authentication with managed client registry and per-client ACL controls
|
||||
|
||||
- switch the native tunnel handshake from Noise NK to Noise IK and require client keypairs in client configuration
|
||||
- add server-side client registry management APIs for creating, updating, disabling, rotating, listing, and exporting client configs
|
||||
- enforce client authorization from the registry during handshake and expose authenticated client metadata in server client info
|
||||
- introduce per-client security policies with source/destination ACLs and per-client rate limit settings
|
||||
- add Rust ACL matching support for exact IPs, CIDR ranges, wildcards, and IP ranges with test coverage
|
||||
|
||||
## 2026-03-29 - 1.7.0 - feat(rust-tests)
|
||||
add end-to-end WireGuard UDP integration tests and align TypeScript build configuration
|
||||
|
||||
- Add userspace Rust end-to-end tests that validate WireGuard handshake, encryption, peer isolation, and preshared-key data exchange over real UDP sockets.
|
||||
- Update the TypeScript build setup by removing the allowimplicitany build flag and explicitly including Node types in tsconfig.
|
||||
- Refresh development toolchain versions to support the updated test and build workflow.
|
||||
|
||||
## 2026-03-29 - 1.6.0 - feat(readme)
|
||||
document WireGuard transport support, configuration, and usage examples
|
||||
|
||||
- Expand the README from dual-transport to triple-transport support by adding WireGuard alongside WebSocket and QUIC
|
||||
- Add client and server WireGuard examples, including live peer management and .conf generation with WgConfigGenerator
|
||||
- Document new WireGuard-related API methods, config fields, transport modes, and security model details
|
||||
|
||||
## 2026-03-29 - 1.5.0 - feat(wireguard)
|
||||
add WireGuard transport support with management APIs and config generation
|
||||
|
||||
- add Rust WireGuard module integration using boringtun and route management through client/server management handlers
|
||||
- extend TypeScript client and server configuration schemas with WireGuard-specific options and validation
|
||||
- add server-side WireGuard peer management commands including keypair generation, peer add/remove, and peer listing
|
||||
- export a WireGuard config generator for producing client and server .conf files
|
||||
- add WireGuard-focused test coverage for config validation and config generation
|
||||
|
||||
## 2026-03-21 - 1.4.1 - fix(readme)
|
||||
preserve markdown line breaks in feature list
|
||||
|
||||
- Adds trailing spaces to the README feature list so each highlighted capability renders on its own line.
|
||||
|
||||
## 2026-03-19 - 1.4.0 - feat(vpn transport)
|
||||
add QUIC transport support with auto fallback to WebSocket
|
||||
|
||||
- introduces a transport abstraction in the Rust daemon so client and server can operate over WebSocket or QUIC
|
||||
- adds dual-mode server configuration with websocket, quic, and both transport modes plus QUIC idle timeout and listen address options
|
||||
- adds client transport selection with auto mode that attempts QUIC first and falls back to WebSocket
|
||||
- adds QUIC certificate hash pinning support and required Rust dependencies for QUIC and TLS
|
||||
- updates TypeScript interfaces, config validation, tests, and documentation to cover the new transport modes
|
||||
|
||||
## 2026-03-17 - 1.3.0 - feat(tests,client)
|
||||
add flow control and load test coverage and honor configured keepalive intervals
|
||||
|
||||
- Adds end-to-end node tests for client/server flow control, keepalive exchange, connection quality telemetry, rate limiting, concurrent clients, and disconnect tracking.
|
||||
- Adds load testing with throttled proxy scenarios to validate behavior under constrained bandwidth and repeated client churn.
|
||||
- Updates the Rust client to pass configured keepaliveIntervalSecs into the adaptive keepalive monitor instead of always using defaults.
|
||||
|
||||
## 2026-03-15 - 1.2.0 - feat(readme)
|
||||
document QoS, telemetry, MTU, and rate limiting capabilities in the README
|
||||
|
||||
- Expand the architecture and feature overview to cover adaptive keepalive, telemetry, QoS, rate limiting, and MTU handling
|
||||
- Update client and server examples to show new APIs such as getConnectionQuality(), getMtuInfo(), setClientRateLimit(), and getClientTelemetry()
|
||||
- Add TypeScript interface documentation for connection quality, MTU info, enriched client statistics, and per-client telemetry
|
||||
|
||||
## 2026-03-15 - 1.1.0 - feat(rust-core)
|
||||
add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs
|
||||
|
||||
- adds adaptive keepalive monitoring with RTT, jitter, loss, and link health reporting to client statistics and management endpoints
|
||||
- introduces MTU overhead calculation and oversized-packet handling support, plus client MTU info APIs
|
||||
- adds token-bucket rate limiting with configurable default limits and server management commands to set, remove, and inspect per-client telemetry
|
||||
- extends TypeScript client and server interfaces with connection quality, MTU, and client telemetry methods
|
||||
|
||||
## 2026-02-27 - 1.0.3 - fix(build)
|
||||
add aarch64 linker configuration for cross-compilation
|
||||
|
||||
- Added rust/.cargo/config.toml to configure linker for target aarch64-unknown-linux-gnu
|
||||
- Sets linker to 'aarch64-linux-gnu-gcc' to enable cross-compilation to ARM64
|
||||
|
||||
## 2026-02-27 - 1.0.2 - fix()
|
||||
no changes detected - no code or content modifications
|
||||
|
||||
|
||||
## 2026-02-27 - 1.0.1 - fix(release)
|
||||
bump patch version (no code changes)
|
||||
|
||||
|
||||
21
package.json
21
package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartvpn",
|
||||
"version": "1.0.1",
|
||||
"version": "1.10.1",
|
||||
"private": false,
|
||||
"description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
|
||||
"type": "module",
|
||||
@@ -10,8 +10,9 @@
|
||||
"main": "dist_ts/index.js",
|
||||
"typings": "dist_ts/index.d.ts",
|
||||
"scripts": {
|
||||
"build": "(tsbuild tsfolders --allowimplicitany) && (tsrust)",
|
||||
"test": "tstest test/ --verbose",
|
||||
"build": "(tsbuild tsfolders) && (tsrust)",
|
||||
"test:before": "(tsrust)",
|
||||
"test": "tstest test/ --verbose --logfile --timeout 60",
|
||||
"buildDocs": "tsdoc"
|
||||
},
|
||||
"repository": {
|
||||
@@ -28,15 +29,15 @@
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@push.rocks/smartrust": "^1.3.0",
|
||||
"@push.rocks/smartpath": "^5.0.18"
|
||||
"@push.rocks/smartpath": "^6.0.0",
|
||||
"@push.rocks/smartrust": "^1.3.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@git.zone/tsbuild": "^2.2.12",
|
||||
"@git.zone/tsrun": "^1.3.3",
|
||||
"@git.zone/tstest": "^1.0.96",
|
||||
"@git.zone/tsrust": "^1.0.29",
|
||||
"@types/node": "^22.0.0"
|
||||
"@git.zone/tsbuild": "^4.4.0",
|
||||
"@git.zone/tsrun": "^2.0.2",
|
||||
"@git.zone/tsrust": "^1.3.2",
|
||||
"@git.zone/tstest": "^3.6.3",
|
||||
"@types/node": "^25.5.0"
|
||||
},
|
||||
"files": [
|
||||
"ts/**/*",
|
||||
|
||||
3545
pnpm-lock.yaml
generated
3545
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
644
readme.md
644
readme.md
@@ -1,392 +1,420 @@
|
||||
# @push.rocks/smartvpn
|
||||
|
||||
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, typed APIs while all networking heavy lifting — encryption, tunneling, packet forwarding — runs at native speed in Rust.
|
||||
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Enterprise-ready client authentication, triple transport support (WebSocket + QUIC + WireGuard), and a typed hub API for managing clients from code.
|
||||
|
||||
🔐 **Noise IK** mutual authentication — per-client X25519 keypairs, server-side registry
|
||||
🚀 **Triple transport**: WebSocket (Cloudflare-friendly), raw **QUIC** (datagrams), and **WireGuard** (standard protocol)
|
||||
🛡️ **ACL engine** — deny-overrides-allow IP filtering, aligned with SmartProxy conventions
|
||||
🔀 **PROXY protocol v2** — real client IPs behind reverse proxies (HAProxy, SmartProxy, Cloudflare Spectrum)
|
||||
📊 **Adaptive QoS**: per-client rate limiting, priority queues, connection quality tracking
|
||||
🔄 **Hub API**: one `createClient()` call generates keys, assigns IP, returns both SmartVPN + WireGuard configs
|
||||
📡 **Real-time telemetry**: RTT, jitter, loss ratio, link health — all via typed APIs
|
||||
🌐 **Flexible forwarding**: TUN device (kernel), userspace NAT (no root), or testing mode
|
||||
|
||||
## Issue Reporting and Security
|
||||
|
||||
For reporting bugs, issues, or security vulnerabilities, please visit [community.foss.global/](https://community.foss.global/). This is the central community hub for all issue reporting. Developers who sign and comply with our contribution agreement and go through identification can also get a [code.foss.global/](https://code.foss.global/) account to submit Pull Requests directly.
|
||||
|
||||
## Install
|
||||
## Install 📦
|
||||
|
||||
```bash
|
||||
npm install @push.rocks/smartvpn
|
||||
# or
|
||||
pnpm install @push.rocks/smartvpn
|
||||
# or
|
||||
npm install @push.rocks/smartvpn
|
||||
```
|
||||
|
||||
## 🏗️ Architecture
|
||||
The package ships with pre-compiled Rust binaries for **linux/amd64** and **linux/arm64**. No Rust toolchain required at runtime.
|
||||
|
||||
## Architecture 🏗️
|
||||
|
||||
```
|
||||
TypeScript (control plane) Rust (data plane)
|
||||
┌──────────────────────────┐ ┌───────────────────────────────┐
|
||||
│ VpnClient / VpnServer │ │ smartvpn_daemon │
|
||||
│ └─ VpnBridge │──stdio/──▶ │ ├─ management (JSON IPC) │
|
||||
│ └─ RustBridge │ socket │ ├─ transport (WebSocket/TLS) │
|
||||
│ (smartrust) │ │ ├─ crypto (Noise NK + XCha) │
|
||||
└──────────────────────────┘ │ ├─ codec (binary framing) │
|
||||
│ ├─ keepalive (app-level) │
|
||||
│ ├─ tunnel (TUN device) │
|
||||
│ ├─ network (NAT/IP pool) │
|
||||
│ └─ reconnect (backoff) │
|
||||
└───────────────────────────────┘
|
||||
┌──────────────────────────────┐ JSON-lines IPC ┌───────────────────────────────┐
|
||||
│ TypeScript Control Plane │ ◄─────────────────────► │ Rust Data Plane Daemon │
|
||||
│ │ stdio or Unix sock │ │
|
||||
│ VpnServer / VpnClient │ │ Noise IK handshake │
|
||||
│ Typed IPC commands │ │ XChaCha20-Poly1305 │
|
||||
│ Config validation │ │ WS + QUIC + WireGuard │
|
||||
│ Hub: client management │ │ TUN device, IP pool, NAT │
|
||||
│ WireGuard .conf generation │ │ Rate limiting, ACLs, QoS │
|
||||
└──────────────────────────────┘ └───────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key design decisions:**
|
||||
**Split-plane design** — TypeScript handles orchestration, config, and DX; Rust handles every hot-path byte with zero-copy async I/O (tokio, mimalloc).
|
||||
|
||||
| Decision | Choice | Why |
|
||||
|----------|--------|-----|
|
||||
| Transport | WebSocket over HTTPS | Works through Cloudflare and other terminating proxies |
|
||||
| Encryption | Noise NK + XChaCha20-Poly1305 | Strong forward secrecy, large nonce space (no counter needed) |
|
||||
| Keepalive | App-level (not WS pings) | Cloudflare drops WS ping frames; app-level pings survive |
|
||||
| IPC | JSON lines over stdio/Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
|
||||
| Binary protocol | `[type:1B][length:4B][payload:NB]` | Minimal overhead, easy to parse at wire speed |
|
||||
## Quick Start 🚀
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### VPN Client
|
||||
|
||||
```typescript
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
// Development: spawn the Rust daemon as a child process
|
||||
const client = new VpnClient({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Start the daemon bridge
|
||||
await client.start();
|
||||
|
||||
// Connect to a VPN server
|
||||
const { assignedIp } = await client.connect({
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
});
|
||||
|
||||
console.log(`Connected! Assigned IP: ${assignedIp}`);
|
||||
|
||||
// Check status
|
||||
const status = await client.getStatus();
|
||||
console.log(status); // { state: 'connected', assignedIp: '10.8.0.2', ... }
|
||||
|
||||
// Get traffic stats
|
||||
const stats = await client.getStatistics();
|
||||
console.log(stats); // { bytesSent, bytesReceived, packetsSent, ... }
|
||||
|
||||
// Disconnect
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
```
|
||||
|
||||
### VPN Server
|
||||
### 1. Start a VPN Server (Hub)
|
||||
|
||||
```typescript
|
||||
import { VpnServer } from '@push.rocks/smartvpn';
|
||||
|
||||
const server = new VpnServer({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Start the daemon and the VPN server
|
||||
const server = new VpnServer({ transport: { transport: 'stdio' } });
|
||||
await server.start({
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'BASE64_PRIVATE_KEY',
|
||||
publicKey: 'BASE64_PUBLIC_KEY',
|
||||
privateKey: '<server-noise-private-key-base64>',
|
||||
publicKey: '<server-noise-public-key-base64>',
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
transportMode: 'both', // WebSocket + QUIC simultaneously
|
||||
forwardingMode: 'tun', // 'tun' (kernel), 'socket' (userspace NAT), or 'testing'
|
||||
enableNat: true,
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
});
|
||||
|
||||
// Generate a Noise keypair
|
||||
const keypair = await server.generateKeypair();
|
||||
console.log(keypair); // { publicKey: '...', privateKey: '...' }
|
||||
|
||||
// List connected clients
|
||||
const clients = await server.listClients();
|
||||
// [{ clientId, assignedIp, connectedSince, bytesSent, bytesReceived }]
|
||||
|
||||
// Disconnect a specific client
|
||||
await server.disconnectClient('some-client-id');
|
||||
|
||||
// Get server stats
|
||||
const stats = await server.getStatistics();
|
||||
// { bytesSent, bytesReceived, activeClients, totalConnections, ... }
|
||||
|
||||
// Stop
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
```
|
||||
|
||||
### Production: Socket Transport
|
||||
|
||||
In production, the daemon runs as a system service and you connect over a Unix socket:
|
||||
### 2. Create a Client (One Call = Everything)
|
||||
|
||||
```typescript
|
||||
const client = new VpnClient({
|
||||
transport: {
|
||||
transport: 'socket',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
autoReconnect: true,
|
||||
reconnectBaseDelayMs: 100,
|
||||
reconnectMaxDelayMs: 30000,
|
||||
maxReconnectAttempts: 10,
|
||||
const bundle = await server.createClient({
|
||||
clientId: 'alice-laptop',
|
||||
tags: ['engineering'],
|
||||
security: {
|
||||
destinationAllowList: ['10.0.0.0/8'], // can only reach internal network
|
||||
destinationBlockList: ['10.0.0.99'], // except this host
|
||||
rateLimit: { bytesPerSec: 10_000_000, burstBytes: 20_000_000 },
|
||||
},
|
||||
});
|
||||
|
||||
await client.start(); // connects to existing daemon (does not spawn)
|
||||
// bundle.smartvpnConfig → typed IVpnClientConfig, ready to use
|
||||
// bundle.wireguardConfig → standard WireGuard .conf string
|
||||
// bundle.secrets → { noisePrivateKey, wgPrivateKey } — shown ONCE
|
||||
```
|
||||
|
||||
When using socket transport, `client.stop()` closes the socket but **does not kill the daemon** — exactly what you want in production.
|
||||
|
||||
## 📋 API Reference
|
||||
|
||||
### `VpnClient`
|
||||
|
||||
| Method | Returns | Description |
|
||||
|--------|---------|-------------|
|
||||
| `start()` | `Promise<boolean>` | Start the daemon bridge (spawn or connect) |
|
||||
| `connect(config?)` | `Promise<{ assignedIp }>` | Connect to VPN server |
|
||||
| `disconnect()` | `Promise<void>` | Disconnect from VPN |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Current connection state |
|
||||
| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic statistics |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
| `running` | `boolean` | Whether bridge is active |
|
||||
|
||||
### `VpnServer`
|
||||
|
||||
| Method | Returns | Description |
|
||||
|--------|---------|-------------|
|
||||
| `start(config?)` | `Promise<void>` | Start daemon + VPN server |
|
||||
| `stopServer()` | `Promise<void>` | Stop the VPN server |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Server connection state |
|
||||
| `getStatistics()` | `Promise<IVpnServerStatistics>` | Server stats (includes client counts) |
|
||||
| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients |
|
||||
| `disconnectClient(id)` | `Promise<void>` | Kick a client |
|
||||
| `generateKeypair()` | `Promise<IVpnKeypair>` | Generate Noise NK keypair |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
|
||||
### `VpnConfig`
|
||||
|
||||
Static utility class for config validation and file I/O:
|
||||
### 3. Connect a Client
|
||||
|
||||
```typescript
|
||||
import { VpnConfig } from '@push.rocks/smartvpn';
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
// Validate (throws on invalid)
|
||||
VpnConfig.validateClientConfig(config);
|
||||
VpnConfig.validateServerConfig(config);
|
||||
const client = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await client.start();
|
||||
|
||||
// Load/save JSON configs
|
||||
const config = await VpnConfig.loadFromFile<IVpnClientConfig>('/etc/smartvpn/client.json');
|
||||
await VpnConfig.saveToFile('/etc/smartvpn/client.json', config);
|
||||
const { assignedIp } = await client.connect(bundle.smartvpnConfig);
|
||||
console.log(`Connected! VPN IP: ${assignedIp}`);
|
||||
```
|
||||
|
||||
### `VpnInstaller`
|
||||
## Features ✨
|
||||
|
||||
Generate system service units for the daemon:
|
||||
### 🔐 Enterprise Authentication (Noise IK)
|
||||
|
||||
Every client authenticates with a **Noise IK handshake** (`Noise_IK_25519_ChaChaPoly_BLAKE2s`). The server verifies the client's static public key against its registry — unauthorized clients are rejected before any data flows.
|
||||
|
||||
- Per-client X25519 keypair generated server-side
|
||||
- Client registry with enable/disable, expiry, tags
|
||||
- Key rotation with `rotateClientKey()` — generates new keys, returns fresh config bundle, disconnects old session
|
||||
|
||||
### 🌐 Triple Transport
|
||||
|
||||
| Transport | Protocol | Best For |
|
||||
|-----------|----------|----------|
|
||||
| **WebSocket** | TLS over TCP | Firewall-friendly, Cloudflare compatible |
|
||||
| **QUIC** | UDP (via quinn) | Low latency, datagram support for IP packets |
|
||||
| **WireGuard** | UDP (via boringtun) | Standard WG clients (iOS, Android, wg-quick) |
|
||||
|
||||
The server can run **all three simultaneously** with `transportMode: 'both'` (WS + QUIC) or `'wireguard'`. Clients auto-negotiate with `transport: 'auto'` (tries QUIC first, falls back to WS).
|
||||
|
||||
### 🛡️ ACL Engine (SmartProxy-Aligned)
|
||||
|
||||
Security policies per client, using the same `ipAllowList` / `ipBlockList` naming convention as `@push.rocks/smartproxy`:
|
||||
|
||||
```typescript
|
||||
security: {
|
||||
ipAllowList: ['192.168.1.0/24'], // source IPs allowed to connect
|
||||
ipBlockList: ['192.168.1.100'], // deny overrides allow
|
||||
destinationAllowList: ['10.0.0.0/8'], // VPN destinations permitted
|
||||
destinationBlockList: ['10.0.0.99'], // deny overrides allow
|
||||
maxConnections: 5,
|
||||
rateLimit: { bytesPerSec: 1_000_000, burstBytes: 2_000_000 },
|
||||
}
|
||||
```
|
||||
|
||||
Supports exact IPs, CIDR, wildcards (`192.168.1.*`), and ranges (`1.1.1.1-1.1.1.100`).
|
||||
|
||||
### 🔀 PROXY Protocol v2
|
||||
|
||||
When the VPN server sits behind a reverse proxy, enable PROXY protocol v2 to receive the **real client IP** instead of the proxy's address. This makes `ipAllowList` / `ipBlockList` ACLs work correctly through load balancers.
|
||||
|
||||
```typescript
|
||||
await server.start({
|
||||
// ... other config ...
|
||||
proxyProtocol: true, // parse PP v2 headers on WS connections
|
||||
connectionIpBlockList: ['198.51.100.0/24'], // server-wide block list (pre-handshake)
|
||||
});
|
||||
```
|
||||
|
||||
**Two-phase ACL with real IPs:**
|
||||
|
||||
| Phase | When | What Happens |
|
||||
|-------|------|-------------|
|
||||
| **Pre-handshake** | After TCP accept | Server-level `connectionIpBlockList` rejects known-bad IPs — zero crypto cost |
|
||||
| **Post-handshake** | After Noise IK identifies client | Per-client `ipAllowList` / `ipBlockList` checked against real source IP |
|
||||
|
||||
- Parses the PP v2 binary header from raw TCP before WebSocket upgrade
|
||||
- 5-second timeout protects against stalling attacks
|
||||
- LOCAL command (proxy health checks) handled gracefully
|
||||
- IPv4 and IPv6 addresses supported
|
||||
- `remoteAddr` field on `IVpnClientInfo` exposes the real client IP for monitoring
|
||||
- **Security**: must be `false` (default) when accepting direct connections — only enable behind a trusted proxy
|
||||
|
||||
### 📦 Packet Forwarding Modes
|
||||
|
||||
SmartVPN supports three forwarding modes, configurable per-server and per-client:
|
||||
|
||||
| Mode | Flag | Description | Root Required |
|
||||
|------|------|-------------|---------------|
|
||||
| **TUN** | `'tun'` | Kernel TUN device — real packet forwarding with system routing | ✅ Yes |
|
||||
| **Userspace NAT** | `'socket'` | Userspace TCP/UDP proxy via `connect(2)` — no TUN, no root needed | ❌ No |
|
||||
| **Testing** | `'testing'` | Monitoring only — packets are counted but not forwarded | ❌ No |
|
||||
|
||||
```typescript
|
||||
// Server with userspace NAT (no root required)
|
||||
await server.start({
|
||||
// ...
|
||||
forwardingMode: 'socket',
|
||||
enableNat: true,
|
||||
});
|
||||
|
||||
// Client with TUN device
|
||||
const { assignedIp } = await client.connect({
|
||||
// ...
|
||||
forwardingMode: 'tun',
|
||||
});
|
||||
```
|
||||
|
||||
The userspace NAT mode extracts destination IP/port from IP packets, opens a real socket to the destination, and relays data — supporting both TCP streams and UDP datagrams without requiring `CAP_NET_ADMIN` or root privileges.
|
||||
|
||||
### 📊 Telemetry & QoS
|
||||
|
||||
- **Connection quality**: Smoothed RTT, jitter, min/max RTT, loss ratio, link health (`healthy` / `degraded` / `critical`)
|
||||
- **Adaptive keepalives**: Interval adjusts based on link health (60s → 30s → 10s)
|
||||
- **Per-client rate limiting**: Token bucket with configurable bytes/sec and burst
|
||||
- **Dead-peer detection**: 180s inactivity timeout
|
||||
- **MTU management**: Automatic overhead calculation (IP+TCP+WS+Noise = 79 bytes)
|
||||
|
||||
### 🔄 Hub Client Management
|
||||
|
||||
The server acts as a **hub** — one API to manage all clients:
|
||||
|
||||
```typescript
|
||||
// Create (generates keys, assigns IP, returns config bundle)
|
||||
const bundle = await server.createClient({ clientId: 'bob-phone' });
|
||||
|
||||
// Read
|
||||
const entry = await server.getClient('bob-phone');
|
||||
const all = await server.listRegisteredClients();
|
||||
|
||||
// Update (ACLs, tags, description, rate limits...)
|
||||
await server.updateClient('bob-phone', {
|
||||
security: { destinationAllowList: ['0.0.0.0/0'] },
|
||||
tags: ['mobile', 'field-ops'],
|
||||
});
|
||||
|
||||
// Enable / Disable
|
||||
await server.disableClient('bob-phone'); // disconnects + blocks reconnection
|
||||
await server.enableClient('bob-phone');
|
||||
|
||||
// Key rotation
|
||||
const newBundle = await server.rotateClientKey('bob-phone');
|
||||
|
||||
// Export config (without secrets)
|
||||
const wgConf = await server.exportClientConfig('bob-phone', 'wireguard');
|
||||
|
||||
// Remove
|
||||
await server.removeClient('bob-phone');
|
||||
```
|
||||
|
||||
### 📝 WireGuard Config Generation
|
||||
|
||||
Generate standard `.conf` files for any WireGuard client:
|
||||
|
||||
```typescript
|
||||
import { WgConfigGenerator } from '@push.rocks/smartvpn';
|
||||
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: '<client-wg-private-key>',
|
||||
address: '10.8.0.2/24',
|
||||
dns: ['1.1.1.1'],
|
||||
peer: {
|
||||
publicKey: '<server-wg-public-key>',
|
||||
endpoint: 'vpn.example.com:51820',
|
||||
allowedIps: ['0.0.0.0/0'],
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
});
|
||||
// → standard WireGuard .conf compatible with wg-quick, iOS, Android
|
||||
```
|
||||
|
||||
### 🖥️ System Service Installation
|
||||
|
||||
```typescript
|
||||
import { VpnInstaller } from '@push.rocks/smartvpn';
|
||||
|
||||
// Auto-detect platform
|
||||
const platform = VpnInstaller.detectPlatform(); // 'linux' | 'macos' | 'windows' | 'unknown'
|
||||
|
||||
// Generate systemd unit (Linux)
|
||||
const unit = VpnInstaller.generateSystemdUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'server',
|
||||
});
|
||||
// unit.content = full systemd .service file
|
||||
// unit.installPath = '/etc/systemd/system/smartvpn-server.service'
|
||||
|
||||
// Generate launchd plist (macOS)
|
||||
const plist = VpnInstaller.generateLaunchdPlist({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'client',
|
||||
});
|
||||
|
||||
// Auto-detect and generate
|
||||
const serviceUnit = VpnInstaller.generateServiceUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
const unit = VpnInstaller.generateServiceUnit({
|
||||
mode: 'server',
|
||||
configPath: '/etc/smartvpn/server.json',
|
||||
});
|
||||
// unit.platform → 'linux' | 'macos'
|
||||
// unit.content → systemd unit file or launchd plist
|
||||
// unit.installPath → /etc/systemd/system/smartvpn-server.service
|
||||
```
|
||||
|
||||
### Events
|
||||
## API Reference 📖
|
||||
|
||||
Both `VpnClient` and `VpnServer` extend `EventEmitter`:
|
||||
### Classes
|
||||
|
||||
| Class | Description |
|
||||
|-------|-------------|
|
||||
| `VpnServer` | Manages the Rust daemon in server mode. Hub methods for client CRUD. |
|
||||
| `VpnClient` | Manages the Rust daemon in client mode. Connect, disconnect, telemetry. |
|
||||
| `VpnBridge<T>` | Low-level typed IPC bridge (stdio or Unix socket). |
|
||||
| `VpnConfig` | Static config validation and file I/O. |
|
||||
| `VpnInstaller` | Generates systemd/launchd service files. |
|
||||
| `WgConfigGenerator` | Generates standard WireGuard `.conf` files. |
|
||||
|
||||
### Key Interfaces
|
||||
|
||||
| Interface | Purpose |
|
||||
|-----------|---------|
|
||||
| `IVpnServerConfig` | Server configuration (listen addr, keys, subnet, transport mode, forwarding mode, clients, proxy protocol) |
|
||||
| `IVpnClientConfig` | Client configuration (server URL, keys, transport, forwarding mode, WG options) |
|
||||
| `IClientEntry` | Server-side client definition (ID, keys, security, priority, tags, expiry) |
|
||||
| `IClientSecurity` | Per-client ACLs and rate limits (SmartProxy-aligned naming) |
|
||||
| `IClientRateLimit` | Rate limiting config (bytesPerSec, burstBytes) |
|
||||
| `IClientConfigBundle` | Full config bundle returned by `createClient()` |
|
||||
| `IVpnClientInfo` | Connected client info (IP, stats, authenticated key, remote addr) |
|
||||
| `IVpnConnectionQuality` | RTT, jitter, loss ratio, link health |
|
||||
| `IVpnKeypair` | Base64-encoded public/private key pair |
|
||||
|
||||
### Server IPC Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `start` / `stop` | Start/stop the VPN listener |
|
||||
| `createClient` | Generate keys, assign IP, return config bundle |
|
||||
| `removeClient` / `getClient` / `listRegisteredClients` | Client registry CRUD |
|
||||
| `updateClient` / `enableClient` / `disableClient` | Modify client state |
|
||||
| `rotateClientKey` | Fresh keypairs + new config bundle |
|
||||
| `exportClientConfig` | Re-export as SmartVPN config or WireGuard `.conf` |
|
||||
| `listClients` / `disconnectClient` | Manage live connections |
|
||||
| `setClientRateLimit` / `removeClientRateLimit` | Runtime rate limit adjustments |
|
||||
| `getStatus` / `getStatistics` / `getClientTelemetry` | Monitoring |
|
||||
| `generateKeypair` / `generateWgKeypair` / `generateClientKeypair` | Key generation |
|
||||
| `addWgPeer` / `removeWgPeer` / `listWgPeers` | WireGuard peer management |
|
||||
|
||||
### Client IPC Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `connect` / `disconnect` | Manage the tunnel |
|
||||
| `getStatus` / `getStatistics` | Connection state and traffic stats |
|
||||
| `getConnectionQuality` | RTT, jitter, loss, link health |
|
||||
| `getMtuInfo` | MTU and overhead details |
|
||||
|
||||
## Transport Modes 🔀
|
||||
|
||||
### Server Configuration
|
||||
|
||||
```typescript
|
||||
client.on('status', (status) => { /* IVpnStatus */ });
|
||||
client.on('error', (err) => { /* { message, code? } */ });
|
||||
client.on('exit', ({ code, signal }) => { /* daemon exited */ });
|
||||
client.on('reconnected', () => { /* socket reconnected */ });
|
||||
// WebSocket only
|
||||
{ transportMode: 'websocket', listenAddr: '0.0.0.0:443' }
|
||||
|
||||
server.on('client-connected', (info) => { /* IVpnClientInfo */ });
|
||||
server.on('client-disconnected', ({ clientId, reason }) => { /* ... */ });
|
||||
// QUIC only
|
||||
{ transportMode: 'quic', listenAddr: '0.0.0.0:443' }
|
||||
|
||||
// Both (WS + QUIC on same or different ports)
|
||||
{ transportMode: 'both', listenAddr: '0.0.0.0:443', quicListenAddr: '0.0.0.0:4433' }
|
||||
|
||||
// WireGuard
|
||||
{ transportMode: 'wireguard', wgListenPort: 51820, wgPeers: [...] }
|
||||
```
|
||||
|
||||
## 🔐 Security Model
|
||||
### Client Configuration
|
||||
|
||||
The VPN uses a **Noise NK** handshake pattern:
|
||||
```typescript
|
||||
// Auto (tries QUIC first, falls back to WS)
|
||||
{ transport: 'auto', serverUrl: 'wss://vpn.example.com' }
|
||||
|
||||
1. **NK** = client does **N**ot authenticate, but **K**nows the server's static public key
|
||||
2. The client generates an ephemeral keypair, performs `e, es` (Diffie-Hellman with server's static key)
|
||||
3. Server responds with `e, ee` (Diffie-Hellman with both ephemeral keys)
|
||||
4. Result: forward-secret transport keys derived from both DH operations
|
||||
// Explicit QUIC with certificate pinning
|
||||
{ transport: 'quic', serverUrl: '1.2.3.4:4433', serverCertHash: '<sha256-base64>' }
|
||||
|
||||
Post-handshake, all IP packets are encrypted with **XChaCha20-Poly1305**:
|
||||
- 24-byte random nonces (no counter synchronization needed)
|
||||
- 16-byte authentication tags
|
||||
- Wire format: `[nonce:24B][ciphertext:var][tag:16B]`
|
||||
|
||||
## 📦 Binary Protocol
|
||||
|
||||
Inside the WebSocket tunnel, packets use a simple binary framing:
|
||||
|
||||
```
|
||||
┌──────────┬──────────┬────────────────────┐
|
||||
│ Type (1B)│ Len (4B) │ Payload (variable) │
|
||||
└──────────┴──────────┴────────────────────┘
|
||||
// WireGuard
|
||||
{ transport: 'wireguard', wgPrivateKey: '...', wgEndpoint: 'vpn.example.com:51820', ... }
|
||||
```
|
||||
|
||||
| Type | Value | Description |
|
||||
|------|-------|-------------|
|
||||
| `HandshakeInit` | `0x01` | Client → Server handshake |
|
||||
| `HandshakeResp` | `0x02` | Server → Client handshake |
|
||||
| `IpPacket` | `0x10` | Encrypted IP packet |
|
||||
| `Keepalive` | `0x20` | App-level ping |
|
||||
| `KeepaliveAck` | `0x21` | App-level pong |
|
||||
| `SessionResume` | `0x30` | Resume a dropped session |
|
||||
| `SessionResumeOk` | `0x31` | Resume accepted |
|
||||
| `SessionResumeErr` | `0x32` | Resume rejected |
|
||||
| `Disconnect` | `0x3F` | Graceful disconnect |
|
||||
## Cryptography 🔑
|
||||
|
||||
## 🛠️ Rust Daemon CLI
|
||||
| Layer | Algorithm | Purpose |
|
||||
|-------|-----------|---------|
|
||||
| **Handshake** | Noise IK (X25519 + ChaChaPoly + BLAKE2s) | Mutual authentication + key exchange |
|
||||
| **Transport** | Noise transport state (ChaChaPoly) | All post-handshake data encryption |
|
||||
| **Additional** | XChaCha20-Poly1305 | Extended nonce space for data-at-rest |
|
||||
| **WireGuard** | X25519 + ChaCha20-Poly1305 (via boringtun) | Standard WireGuard crypto |
|
||||
|
||||
The Rust binary supports several modes:
|
||||
## Binary Protocol 📡
|
||||
|
||||
```bash
|
||||
# Development: stdio management (JSON lines on stdin/stdout)
|
||||
smartvpn_daemon --management --mode client
|
||||
smartvpn_daemon --management --mode server
|
||||
All frames use `[type:1B][length:4B][payload:NB]` with a 64KB max payload:
|
||||
|
||||
# Production: Unix socket management
|
||||
smartvpn_daemon --management-socket /var/run/smartvpn.sock --mode server
|
||||
| Type | Hex | Direction | Description |
|
||||
|------|-----|-----------|-------------|
|
||||
| HandshakeInit | `0x01` | Client → Server | Noise IK first message |
|
||||
| HandshakeResp | `0x02` | Server → Client | Noise IK response |
|
||||
| IpPacket | `0x10` | Bidirectional | Encrypted tunnel data |
|
||||
| Keepalive | `0x20` | Client → Server | App-level keepalive (not WS ping) |
|
||||
| KeepaliveAck | `0x21` | Server → Client | Keepalive response with RTT payload |
|
||||
| Disconnect | `0x3F` | Bidirectional | Graceful disconnect |
|
||||
|
||||
# Generate a Noise keypair
|
||||
smartvpn_daemon --generate-keypair
|
||||
```
|
||||
|
||||
## 🔧 Building from Source
|
||||
## Development 🛠️
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pnpm install
|
||||
|
||||
# Build TypeScript + cross-compile Rust
|
||||
# Build (TypeScript + Rust cross-compile)
|
||||
pnpm build
|
||||
|
||||
# Build Rust only (debug)
|
||||
cd rust && cargo build
|
||||
# Run all tests (79 TS + 132 Rust = 211 tests)
|
||||
pnpm test
|
||||
|
||||
# Run Rust tests
|
||||
# Run Rust tests directly
|
||||
cd rust && cargo test
|
||||
|
||||
# Run TypeScript tests
|
||||
pnpm test
|
||||
# Run a specific TS test
|
||||
tstest test/test.flowcontrol.node.ts --verbose
|
||||
```
|
||||
|
||||
## TypeScript Interfaces
|
||||
### Project Structure
|
||||
|
||||
<details>
|
||||
<summary>Click to expand full type definitions</summary>
|
||||
|
||||
```typescript
|
||||
// Transport options
|
||||
type TVpnTransportOptions =
|
||||
| { transport: 'stdio' }
|
||||
| {
|
||||
transport: 'socket';
|
||||
socketPath: string;
|
||||
autoReconnect?: boolean;
|
||||
reconnectBaseDelayMs?: number;
|
||||
reconnectMaxDelayMs?: number;
|
||||
maxReconnectAttempts?: number;
|
||||
};
|
||||
|
||||
// Client config
|
||||
interface IVpnClientConfig {
|
||||
serverUrl: string; // e.g. 'wss://vpn.example.com/tunnel'
|
||||
serverPublicKey: string; // base64-encoded Noise static key
|
||||
dns?: string[];
|
||||
mtu?: number; // default: 1420
|
||||
keepaliveIntervalSecs?: number; // default: 30
|
||||
}
|
||||
|
||||
// Server config
|
||||
interface IVpnServerConfig {
|
||||
listenAddr: string; // e.g. '0.0.0.0:443'
|
||||
privateKey: string; // base64 Noise static private key
|
||||
publicKey: string; // base64 Noise static public key
|
||||
subnet: string; // e.g. '10.8.0.0/24'
|
||||
tlsCert?: string;
|
||||
tlsKey?: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
enableNat?: boolean;
|
||||
}
|
||||
|
||||
// Status
|
||||
type TVpnConnectionState = 'disconnected' | 'connecting' | 'handshaking'
|
||||
| 'connected' | 'reconnecting' | 'error';
|
||||
|
||||
interface IVpnStatus {
|
||||
state: TVpnConnectionState;
|
||||
assignedIp?: string;
|
||||
serverAddr?: string;
|
||||
connectedSince?: string;
|
||||
lastError?: string;
|
||||
}
|
||||
|
||||
// Statistics
|
||||
interface IVpnStatistics {
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsSent: number;
|
||||
packetsReceived: number;
|
||||
keepalivesSent: number;
|
||||
keepalivesReceived: number;
|
||||
uptimeSeconds: number;
|
||||
}
|
||||
|
||||
interface IVpnServerStatistics extends IVpnStatistics {
|
||||
activeClients: number;
|
||||
totalConnections: number;
|
||||
}
|
||||
|
||||
interface IVpnClientInfo {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
connectedSince: string;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
}
|
||||
|
||||
interface IVpnKeypair {
|
||||
publicKey: string;
|
||||
privateKey: string;
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
smartvpn/
|
||||
├── ts/ # TypeScript control plane
|
||||
│ ├── index.ts # All exports
|
||||
│ ├── smartvpn.interfaces.ts # Interfaces, types, IPC command maps
|
||||
│ ├── smartvpn.classes.vpnserver.ts
|
||||
│ ├── smartvpn.classes.vpnclient.ts
|
||||
│ ├── smartvpn.classes.vpnbridge.ts
|
||||
│ ├── smartvpn.classes.vpnconfig.ts
|
||||
│ ├── smartvpn.classes.vpninstaller.ts
|
||||
│ └── smartvpn.classes.wgconfig.ts
|
||||
├── rust/ # Rust data plane daemon
|
||||
│ └── src/
|
||||
│ ├── main.rs # CLI entry point
|
||||
│ ├── server.rs # VPN server + hub methods
|
||||
│ ├── client.rs # VPN client
|
||||
│ ├── crypto.rs # Noise IK + XChaCha20
|
||||
│ ├── client_registry.rs # Client database
|
||||
│ ├── acl.rs # ACL engine
|
||||
│ ├── proxy_protocol.rs # PROXY protocol v2 parser
|
||||
│ ├── management.rs # JSON-lines IPC
|
||||
│ ├── transport.rs # WebSocket transport
|
||||
│ ├── quic_transport.rs # QUIC transport
|
||||
│ ├── wireguard.rs # WireGuard (boringtun)
|
||||
│ ├── codec.rs # Binary frame protocol
|
||||
│ ├── keepalive.rs # Adaptive keepalives
|
||||
│ ├── ratelimit.rs # Token bucket
|
||||
│ ├── userspace_nat.rs # Userspace TCP/UDP NAT proxy
|
||||
│ └── ... # tunnel, network, telemetry, qos, mtu, reconnect
|
||||
├── test/ # 9 test files (79 tests)
|
||||
├── dist_ts/ # Compiled TypeScript
|
||||
└── dist_rust/ # Cross-compiled binaries (linux amd64 + arm64)
|
||||
```
|
||||
|
||||
## License and Legal Information
|
||||
|
||||
|
||||
253
readme.plan.md
Normal file
253
readme.plan.md
Normal file
@@ -0,0 +1,253 @@
|
||||
# PROXY Protocol v2 Support for SmartVPN WebSocket Transport
|
||||
|
||||
## Context
|
||||
|
||||
SmartVPN's WebSocket transport is designed to sit behind reverse proxies (Cloudflare, HAProxy, SmartProxy). The recently added ACL engine has `ipAllowList`/`ipBlockList` per client, but without PROXY protocol support the server only sees the proxy's IP — not the real client's. This makes source-IP ACLs useless behind a proxy.
|
||||
|
||||
PROXY protocol v2 solves this by letting the proxy prepend a binary header with the real client IP/port before the WebSocket upgrade.
|
||||
|
||||
---
|
||||
|
||||
## Design
|
||||
|
||||
### Two-Phase ACL with Real Client IP
|
||||
|
||||
```
|
||||
TCP accept → Read PP v2 header → Extract real IP
|
||||
│
|
||||
├─ Phase 1 (pre-handshake): Check server-level connectionIpBlockList → reject early
|
||||
│
|
||||
├─ WebSocket upgrade → Noise IK handshake → Client identity known
|
||||
│
|
||||
└─ Phase 2 (post-handshake): Check per-client ipAllowList/ipBlockList → reject if denied
|
||||
```
|
||||
|
||||
- **Phase 1**: Server-wide block list (`connectionIpBlockList` on `IVpnServerConfig`). Rejects before any crypto work. Protects server resources.
|
||||
- **Phase 2**: Per-client ACL from `IClientSecurity.ipAllowList`/`ipBlockList`. Applied after the Noise IK handshake identifies the client.
|
||||
|
||||
### No New Dependencies
|
||||
|
||||
PROXY protocol v2 is a fixed-format binary header (16-byte signature + variable address block). Manual parsing (~80 lines) follows the same pattern as `codec.rs`. No crate needed.
|
||||
|
||||
### Scope: WebSocket Only
|
||||
|
||||
- **WebSocket**: Needs PP v2 (sits behind reverse proxies)
|
||||
- **QUIC**: Direct UDP, just use `conn.remote_address()`
|
||||
- **WireGuard**: Direct UDP, uses boringtun peer tracking
|
||||
|
||||
---
|
||||
|
||||
## Implementation
|
||||
|
||||
### Phase 1: New Rust module `proxy_protocol.rs`
|
||||
|
||||
**New file: `rust/src/proxy_protocol.rs`**
|
||||
|
||||
PP v2 binary format:
|
||||
```
|
||||
Bytes 0-11: Signature \x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A
|
||||
Byte 12: Version (high nibble = 0x2) | Command (low nibble: 0x0=LOCAL, 0x1=PROXY)
|
||||
Byte 13: Address family | Protocol (0x11 = IPv4/TCP, 0x21 = IPv6/TCP)
|
||||
Bytes 14-15: Address data length (big-endian u16)
|
||||
Bytes 16+: IPv4: 4 src_ip + 4 dst_ip + 2 src_port + 2 dst_port (12 bytes)
|
||||
IPv6: 16 src_ip + 16 dst_ip + 2 src_port + 2 dst_port (36 bytes)
|
||||
```
|
||||
|
||||
```rust
|
||||
pub struct ProxyHeader {
|
||||
pub src_addr: SocketAddr,
|
||||
pub dst_addr: SocketAddr,
|
||||
pub is_local: bool, // LOCAL command = health check probe
|
||||
}
|
||||
|
||||
/// Read and parse a PROXY protocol v2 header from a TCP stream.
|
||||
/// Reads exactly the header bytes — the stream is clean for WS upgrade after.
|
||||
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader>
|
||||
```
|
||||
|
||||
- 5-second timeout on header read (constant `PROXY_HEADER_TIMEOUT`)
|
||||
- Validates 12-byte signature, version nibble, command type
|
||||
- Parses IPv4 and IPv6 address blocks
|
||||
- LOCAL command returns `is_local: true` (caller closes connection gracefully)
|
||||
- Unit tests: valid IPv4/IPv6 headers, LOCAL command, invalid signature, truncated data
|
||||
|
||||
**Modify: `rust/src/lib.rs`** — add `pub mod proxy_protocol;`
|
||||
|
||||
### Phase 2: Server config + client info fields
|
||||
|
||||
**File: `rust/src/server.rs` — `ServerConfig`**
|
||||
|
||||
Add:
|
||||
```rust
|
||||
/// Enable PROXY protocol v2 parsing on WebSocket connections.
|
||||
/// SECURITY: Must be false when accepting direct client connections.
|
||||
pub proxy_protocol: Option<bool>,
|
||||
/// Server-level IP block list — applied at TCP accept time, before Noise handshake.
|
||||
pub connection_ip_block_list: Option<Vec<String>>,
|
||||
```
|
||||
|
||||
**File: `rust/src/server.rs` — `ClientInfo`**
|
||||
|
||||
Add:
|
||||
```rust
|
||||
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
|
||||
pub remote_addr: Option<String>,
|
||||
```
|
||||
|
||||
### Phase 3: ACL helper
|
||||
|
||||
**File: `rust/src/acl.rs`**
|
||||
|
||||
Add a public function for the server-level pre-handshake check:
|
||||
```rust
|
||||
/// Check whether a connection source IP is in a block list.
|
||||
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
|
||||
ip_matches_any(ip, block_list)
|
||||
}
|
||||
```
|
||||
|
||||
(Keeps `ip_matches_any` private; exposes only the specific check needed.)
|
||||
|
||||
### Phase 4: WebSocket listener integration
|
||||
|
||||
**File: `rust/src/server.rs` — `run_ws_listener()`**
|
||||
|
||||
Between `listener.accept()` and `transport::accept_connection()`:
|
||||
|
||||
```rust
|
||||
// Determine real client address
|
||||
let remote_addr = if state.config.proxy_protocol.unwrap_or(false) {
|
||||
match proxy_protocol::read_proxy_header(&mut tcp_stream).await {
|
||||
Ok(header) if header.is_local => {
|
||||
// Health check probe — close gracefully
|
||||
return;
|
||||
}
|
||||
Ok(header) => {
|
||||
info!("PP v2: real client {} -> {}", header.src_addr, header.dst_addr);
|
||||
Some(header.src_addr)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("PP v2 parse failed from {}: {}", tcp_addr, e);
|
||||
return; // Drop connection
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Some(tcp_addr) // Direct connection — use TCP SocketAddr
|
||||
};
|
||||
|
||||
// Pre-handshake server-level block list check
|
||||
if let (Some(ref block_list), Some(ref addr)) = (&state.config.connection_ip_block_list, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, block_list) {
|
||||
warn!("Connection blocked by server IP block list: {}", addr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Then proceed with WS upgrade + handle_client_connection as before
|
||||
```
|
||||
|
||||
Key correctness note: `read_proxy_header` reads *exactly* the PP header bytes via `read_exact`. The `TcpStream` is then in a clean state for the WS HTTP upgrade. No buffered wrapper needed.
|
||||
|
||||
### Phase 5: Update `handle_client_connection` signature
|
||||
|
||||
**File: `rust/src/server.rs`**
|
||||
|
||||
Change signature:
|
||||
```rust
|
||||
async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
remote_addr: Option<std::net::SocketAddr>, // NEW
|
||||
) -> Result<()>
|
||||
```
|
||||
|
||||
After Noise IK handshake + registry lookup (where `client_security` is available), add connection-level per-client ACL:
|
||||
|
||||
```rust
|
||||
if let (Some(ref sec), Some(addr)) = (&client_security, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, sec.ip_block_list.as_deref().unwrap_or(&[])) {
|
||||
anyhow::bail!("Client {} connection denied: source IP {} blocked", registered_client_id, addr);
|
||||
}
|
||||
if let Some(ref allow) = sec.ip_allow_list {
|
||||
if !allow.is_empty() && !acl::is_ip_allowed(v4, allow) {
|
||||
anyhow::bail!("Client {} connection denied: source IP {} not in allow list", registered_client_id, addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Populate `remote_addr` when building `ClientInfo`:
|
||||
```rust
|
||||
remote_addr: remote_addr.map(|a| a.to_string()),
|
||||
```
|
||||
|
||||
### Phase 6: QUIC listener — pass remote addr through
|
||||
|
||||
**File: `rust/src/server.rs` — `run_quic_listener()`**
|
||||
|
||||
QUIC doesn't use PROXY protocol. Just pass `conn.remote_address()` through:
|
||||
```rust
|
||||
let remote = conn.remote_address();
|
||||
// ...
|
||||
handle_client_connection(state, Box::new(sink), Box::new(stream), Some(remote)).await
|
||||
```
|
||||
|
||||
### Phase 7: TypeScript interface updates
|
||||
|
||||
**File: `ts/smartvpn.interfaces.ts`**
|
||||
|
||||
Add to `IVpnServerConfig`:
|
||||
```typescript
|
||||
/** Enable PROXY protocol v2 on incoming WebSocket connections.
|
||||
* Required when behind a reverse proxy that sends PP v2 headers. */
|
||||
proxyProtocol?: boolean;
|
||||
/** Server-level IP block list — applied at TCP accept time, before Noise handshake. */
|
||||
connectionIpBlockList?: string[];
|
||||
```
|
||||
|
||||
Add to `IVpnClientInfo`:
|
||||
```typescript
|
||||
/** Real client IP:port (from PROXY protocol or direct TCP). */
|
||||
remoteAddr?: string;
|
||||
```
|
||||
|
||||
### Phase 8: Tests
|
||||
|
||||
**Rust unit tests in `proxy_protocol.rs`:**
|
||||
- `parse_valid_ipv4_header` — construct a valid PP v2 header with known IPs, verify parsed correctly
|
||||
- `parse_valid_ipv6_header` — same for IPv6
|
||||
- `parse_local_command` — health check probe returns `is_local: true`
|
||||
- `reject_invalid_signature` — random bytes rejected
|
||||
- `reject_truncated_header` — short reads fail gracefully
|
||||
- `reject_v1_header` — PROXY v1 text format rejected (we only support v2)
|
||||
|
||||
**Rust unit tests in `acl.rs`:**
|
||||
- `is_connection_blocked` with various IP patterns
|
||||
|
||||
**TypeScript tests:**
|
||||
- Config validation accepts `proxyProtocol: true` + `connectionIpBlockList`
|
||||
|
||||
---
|
||||
|
||||
## Key Files to Modify
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `rust/src/proxy_protocol.rs` | **NEW** — PP v2 parser + tests |
|
||||
| `rust/src/lib.rs` | Add `pub mod proxy_protocol;` |
|
||||
| `rust/src/server.rs` | `ServerConfig` + `ClientInfo` fields, `run_ws_listener` PP integration, `handle_client_connection` signature + connection ACL, `run_quic_listener` pass-through |
|
||||
| `rust/src/acl.rs` | Add `is_connection_blocked` public function |
|
||||
| `ts/smartvpn.interfaces.ts` | `proxyProtocol`, `connectionIpBlockList`, `remoteAddr` |
|
||||
|
||||
---
|
||||
|
||||
## Verification
|
||||
|
||||
1. `cargo test` — all existing 121 tests + new PP parser tests pass
|
||||
2. `pnpm test` — all 79 TS tests pass (no PP in test setup, just config validation)
|
||||
3. Manual: `socat` or test harness to send a PP v2 header before WS upgrade, verify server logs real IP
|
||||
2
rust/.cargo/config.toml
Normal file
2
rust/.cargo/config.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
linker = "aarch64-linux-gnu-gcc"
|
||||
908
rust/Cargo.lock
generated
908
rust/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,19 @@ tun = { version = "0.7", features = ["async"] }
|
||||
bytes = "1"
|
||||
tokio-util = "0.7"
|
||||
futures-util = "0.3"
|
||||
async-trait = "0.1"
|
||||
quinn = "0.11"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
rcgen = "0.13"
|
||||
ring = "0.17"
|
||||
rustls-pki-types = "1"
|
||||
rustls-pemfile = "2"
|
||||
webpki-roots = "1"
|
||||
mimalloc = "0.1"
|
||||
boringtun = "0.7"
|
||||
smoltcp = { version = "0.13", default-features = false, features = ["medium-ip", "proto-ipv4", "socket-tcp", "socket-udp", "alloc"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
ipnet = "2"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
|
||||
302
rust/src/acl.rs
Normal file
302
rust/src/acl.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
use std::net::Ipv4Addr;
|
||||
use ipnet::Ipv4Net;
|
||||
|
||||
use crate::client_registry::ClientSecurity;
|
||||
|
||||
/// Result of an ACL check.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AclResult {
|
||||
Allow,
|
||||
DenySrc,
|
||||
DenyDst,
|
||||
}
|
||||
|
||||
/// Check whether a connection source IP is in a server-level block list.
|
||||
/// Used for pre-handshake rejection of known-bad IPs.
|
||||
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
|
||||
ip_matches_any(ip, block_list)
|
||||
}
|
||||
|
||||
/// Check whether a source IP is allowed by allow/block lists.
|
||||
/// Returns true if the IP is permitted (not blocked and passes allow check).
|
||||
pub fn is_source_allowed(ip: Ipv4Addr, allow_list: Option<&[String]>, block_list: Option<&[String]>) -> bool {
|
||||
// Deny overrides allow
|
||||
if let Some(bl) = block_list {
|
||||
if ip_matches_any(ip, bl) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If allow list exists and is non-empty, IP must match
|
||||
if let Some(al) = allow_list {
|
||||
if !al.is_empty() && !ip_matches_any(ip, al) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy.
|
||||
///
|
||||
/// Evaluation order (deny overrides allow):
|
||||
/// 1. If src_ip is in ip_block_list → DenySrc
|
||||
/// 2. If dst_ip is in destination_block_list → DenyDst
|
||||
/// 3. If ip_allow_list is non-empty and src_ip is NOT in it → DenySrc
|
||||
/// 4. If destination_allow_list is non-empty and dst_ip is NOT in it → DenyDst
|
||||
/// 5. Otherwise → Allow
|
||||
pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult {
|
||||
// Step 1: Check source block list (deny overrides)
|
||||
if let Some(ref block_list) = security.ip_block_list {
|
||||
if ip_matches_any(src_ip, block_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Check destination block list (deny overrides)
|
||||
if let Some(ref block_list) = security.destination_block_list {
|
||||
if ip_matches_any(dst_ip, block_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Check source allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.ip_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(src_ip, allow_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Check destination allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.destination_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(dst_ip, allow_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
AclResult::Allow
|
||||
}
|
||||
|
||||
/// Check if `ip` matches any pattern in the list.
|
||||
/// Supports: exact IP, CIDR notation, wildcard patterns (192.168.1.*),
|
||||
/// and IP ranges (192.168.1.1-192.168.1.100).
|
||||
fn ip_matches_any(ip: Ipv4Addr, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if ip_matches(ip, pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if `ip` matches a single pattern.
|
||||
fn ip_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let pattern = pattern.trim();
|
||||
|
||||
// CIDR notation (e.g. 192.168.1.0/24)
|
||||
if pattern.contains('/') {
|
||||
if let Ok(net) = pattern.parse::<Ipv4Net>() {
|
||||
return net.contains(&ip);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// IP range (e.g. 192.168.1.1-192.168.1.100)
|
||||
if pattern.contains('-') {
|
||||
let parts: Vec<&str> = pattern.splitn(2, '-').collect();
|
||||
if parts.len() == 2 {
|
||||
if let (Ok(start), Ok(end)) = (parts[0].trim().parse::<Ipv4Addr>(), parts[1].trim().parse::<Ipv4Addr>()) {
|
||||
let ip_u32 = u32::from(ip);
|
||||
return ip_u32 >= u32::from(start) && ip_u32 <= u32::from(end);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wildcard pattern (e.g. 192.168.1.*)
|
||||
if pattern.contains('*') {
|
||||
return wildcard_matches(ip, pattern);
|
||||
}
|
||||
|
||||
// Exact IP match
|
||||
if let Ok(exact) = pattern.parse::<Ipv4Addr>() {
|
||||
return ip == exact;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Match an IP against a wildcard pattern like "192.168.1.*" or "10.*.*.*".
|
||||
fn wildcard_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let ip_octets = ip.octets();
|
||||
let pattern_parts: Vec<&str> = pattern.split('.').collect();
|
||||
if pattern_parts.len() != 4 {
|
||||
return false;
|
||||
}
|
||||
for (i, part) in pattern_parts.iter().enumerate() {
|
||||
if *part == "*" {
|
||||
continue;
|
||||
}
|
||||
if let Ok(octet) = part.parse::<u8>() {
|
||||
if ip_octets[i] != octet {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client_registry::{ClientRateLimit, ClientSecurity};
|
||||
|
||||
fn security(
|
||||
ip_allow: Option<Vec<&str>>,
|
||||
ip_block: Option<Vec<&str>>,
|
||||
dst_allow: Option<Vec<&str>>,
|
||||
dst_block: Option<Vec<&str>>,
|
||||
) -> ClientSecurity {
|
||||
ClientSecurity {
|
||||
ip_allow_list: ip_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
ip_block_list: ip_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_allow_list: dst_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_block_list: dst_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
max_connections: None,
|
||||
rate_limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ip(s: &str) -> Ipv4Addr {
|
||||
s.parse().unwrap()
|
||||
}
|
||||
|
||||
// ── No restrictions (empty security) ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn empty_security_allows_all() {
|
||||
let sec = security(None, None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_lists_allow_all() {
|
||||
let sec = security(Some(vec![]), Some(vec![]), Some(vec![]), Some(vec![]));
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
// ── Source IP allow list ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_allow_exact_match() {
|
||||
let sec = security(Some(vec!["10.0.0.1"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.2"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_cidr() {
|
||||
let sec = security(Some(vec!["192.168.1.0/24"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.2.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_wildcard() {
|
||||
let sec = security(Some(vec!["10.0.*.*"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.5.3"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.1.0.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_range() {
|
||||
let sec = security(Some(vec!["10.0.0.1-10.0.0.10"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.5"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.11"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Source IP block list (deny overrides) ───────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_block_overrides_allow() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
Some(vec!["192.168.1.100"]),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.100"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Destination allow list ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_allow_exact() {
|
||||
let sec = security(None, None, Some(vec!["8.8.8.8", "8.8.4.4"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dst_allow_cidr() {
|
||||
let sec = security(None, None, Some(vec!["10.0.0.0/8"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.5.3.2")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("172.16.0.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Destination block list (deny overrides) ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_block_overrides_allow() {
|
||||
let sec = security(
|
||||
None,
|
||||
None,
|
||||
Some(vec!["10.0.0.0/8"]),
|
||||
Some(vec!["10.0.0.99"]),
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.1")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.99")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Combined source + destination ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn combined_src_and_dst_filtering() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
None,
|
||||
Some(vec!["8.8.8.8"]),
|
||||
None,
|
||||
);
|
||||
// Valid source, valid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("8.8.8.8")), AclResult::Allow);
|
||||
// Invalid source
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::DenySrc);
|
||||
// Valid source, invalid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── IP matching edge cases ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wildcard_single_octet() {
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.*"));
|
||||
assert!(!ip_matches(ip("10.0.1.5"), "10.0.0.*"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range_boundaries() {
|
||||
assert!(ip_matches(ip("10.0.0.1"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.6"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.0"), "10.0.0.1-10.0.0.5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_pattern_no_match() {
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "not-an-ip"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0.1/99"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0"));
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,19 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, error, warn, debug};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
|
||||
use crate::telemetry::ConnectionQuality;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
use crate::tunnel::{self, TunConfig};
|
||||
|
||||
/// Client configuration (matches TS IVpnClientConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -18,9 +21,20 @@ use crate::transport;
|
||||
pub struct ClientConfig {
|
||||
pub server_url: String,
|
||||
pub server_public_key: String,
|
||||
/// Client's Noise IK static private key (base64) — required for authentication.
|
||||
pub client_private_key: String,
|
||||
/// Client's Noise IK static public key (base64) — for reference/display.
|
||||
pub client_public_key: String,
|
||||
pub dns: Option<Vec<String>>,
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
/// Transport type: "websocket" (default) or "quic".
|
||||
pub transport: Option<String>,
|
||||
/// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
|
||||
pub server_cert_hash: Option<String>,
|
||||
/// Forwarding mode: "tun" (TUN device, requires root) or "testing" (no TUN).
|
||||
/// Default: "testing".
|
||||
pub forwarding_mode: Option<String>,
|
||||
}
|
||||
|
||||
/// Client statistics.
|
||||
@@ -65,6 +79,8 @@ pub struct VpnClient {
|
||||
assigned_ip: Arc<RwLock<Option<String>>>,
|
||||
shutdown_tx: Option<mpsc::Sender<()>>,
|
||||
connected_since: Arc<RwLock<Option<std::time::Instant>>>,
|
||||
quality_rx: Option<watch::Receiver<ConnectionQuality>>,
|
||||
link_health: Arc<RwLock<LinkHealth>>,
|
||||
}
|
||||
|
||||
impl VpnClient {
|
||||
@@ -75,6 +91,8 @@ impl VpnClient {
|
||||
assigned_ip: Arc::new(RwLock::new(None)),
|
||||
shutdown_tx: None,
|
||||
connected_since: Arc::new(RwLock::new(None)),
|
||||
quality_rx: None,
|
||||
link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,23 +111,85 @@ impl VpnClient {
|
||||
let stats = self.stats.clone();
|
||||
let assigned_ip_ref = self.assigned_ip.clone();
|
||||
let connected_since = self.connected_since.clone();
|
||||
let link_health = self.link_health.clone();
|
||||
|
||||
// Decode server public key
|
||||
// Decode keys
|
||||
let server_pub_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.server_public_key,
|
||||
)?;
|
||||
let client_priv_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.client_private_key,
|
||||
)?;
|
||||
|
||||
// Connect to WebSocket server
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
// Create transport based on configuration
|
||||
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
|
||||
let transport_type = config.transport.as_deref().unwrap_or("auto");
|
||||
match transport_type {
|
||||
"quic" => {
|
||||
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
|
||||
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
info!("Connected via QUIC");
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
"websocket" => {
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
_ => {
|
||||
// "auto" (default): try QUIC first, fall back to WebSocket
|
||||
// Extract host:port from the URL for QUIC attempt
|
||||
let quic_addr = extract_host_port(&config.server_url);
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
|
||||
// Noise NK handshake (client side = initiator)
|
||||
if let Some(ref addr) = quic_addr {
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(3),
|
||||
try_quic_connect(addr, cert_hash),
|
||||
).await {
|
||||
Ok(Ok((quic_sink, quic_stream))) => {
|
||||
info!("Auto: connected via QUIC to {}", addr);
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC unavailable)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Auto: QUIC timed out, falling back to WebSocket");
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC timed out)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Can't extract host:port for QUIC, use WebSocket directly
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Noise IK handshake (client side = initiator, presents static key)
|
||||
*state.write().await = ClientState::Handshaking;
|
||||
let mut initiator = crypto::create_initiator(&server_pub_key)?;
|
||||
let mut initiator = crypto::create_initiator(&client_priv_key, &server_pub_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es
|
||||
// -> e, es, s, ss
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let init_frame = Frame {
|
||||
packet_type: PacketType::HandshakeInit,
|
||||
@@ -117,13 +197,11 @@ impl VpnClient {
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// <- e, ee
|
||||
let resp_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
|
||||
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
|
||||
// <- e, ee, se
|
||||
let resp_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed during handshake"),
|
||||
};
|
||||
|
||||
@@ -139,9 +217,9 @@ impl VpnClient {
|
||||
let mut noise_transport = initiator.into_transport_mode()?;
|
||||
|
||||
// Receive assigned IP info (encrypted)
|
||||
let info_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected IP info message"),
|
||||
let info_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before IP info"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&info_msg[..]);
|
||||
@@ -161,16 +239,60 @@ impl VpnClient {
|
||||
|
||||
info!("Connected to VPN, assigned IP: {}", assigned_ip);
|
||||
|
||||
// Optionally create TUN device for IP packet forwarding (requires root)
|
||||
let tun_enabled = config.forwarding_mode.as_deref() == Some("tun");
|
||||
let (tun_reader, tun_writer, tun_subnet) = if tun_enabled {
|
||||
let client_tun_ip: Ipv4Addr = assigned_ip.parse()?;
|
||||
let mtu = ip_info["mtu"].as_u64().unwrap_or(1420) as u16;
|
||||
let tun_config = TunConfig {
|
||||
name: "svpn-client0".to_string(),
|
||||
address: client_tun_ip,
|
||||
netmask: Ipv4Addr::new(255, 255, 255, 0),
|
||||
mtu,
|
||||
};
|
||||
let tun_device = tunnel::create_tun(&tun_config)?;
|
||||
|
||||
// Add route for VPN subnet through the TUN device
|
||||
let gateway_str = ip_info["gateway"].as_str().unwrap_or("10.8.0.1");
|
||||
let gateway: Ipv4Addr = gateway_str.parse().unwrap_or(Ipv4Addr::new(10, 8, 0, 1));
|
||||
let subnet = format!("{}/24", Ipv4Addr::from(u32::from(gateway) & 0xFFFFFF00));
|
||||
tunnel::add_route(&subnet, &tun_config.name).await?;
|
||||
|
||||
let (reader, writer) = tokio::io::split(tun_device);
|
||||
(Some(reader), Some(writer), Some(subnet))
|
||||
} else {
|
||||
(None, None, None)
|
||||
};
|
||||
|
||||
// Create adaptive keepalive monitor (use custom interval if configured)
|
||||
let ka_config = config.keepalive_interval_secs.map(|secs| {
|
||||
let mut cfg = keepalive::AdaptiveKeepaliveConfig::default();
|
||||
cfg.degraded_interval = std::time::Duration::from_secs(secs);
|
||||
cfg.healthy_interval = std::time::Duration::from_secs(secs * 2);
|
||||
cfg.critical_interval = std::time::Duration::from_secs((secs / 3).max(1));
|
||||
cfg
|
||||
});
|
||||
let (monitor, handle) = keepalive::create_keepalive(ka_config);
|
||||
self.quality_rx = Some(handle.quality_rx);
|
||||
|
||||
// Spawn the keepalive monitor
|
||||
tokio::spawn(monitor.run());
|
||||
|
||||
// Spawn packet forwarding loop
|
||||
let assigned_ip_clone = assigned_ip.clone();
|
||||
tokio::spawn(client_loop(
|
||||
ws_sink,
|
||||
ws_stream,
|
||||
sink,
|
||||
stream,
|
||||
noise_transport,
|
||||
state,
|
||||
stats,
|
||||
shutdown_rx,
|
||||
config.keepalive_interval_secs.unwrap_or(30),
|
||||
handle.signal_rx,
|
||||
handle.ack_tx,
|
||||
link_health,
|
||||
tun_reader,
|
||||
tun_writer,
|
||||
tun_subnet,
|
||||
));
|
||||
|
||||
Ok(assigned_ip_clone)
|
||||
@@ -184,6 +306,7 @@ impl VpnClient {
|
||||
*self.assigned_ip.write().await = None;
|
||||
*self.connected_since.write().await = None;
|
||||
*self.state.write().await = ClientState::Disconnected;
|
||||
self.quality_rx = None;
|
||||
info!("Disconnected from VPN");
|
||||
Ok(())
|
||||
}
|
||||
@@ -208,13 +331,14 @@ impl VpnClient {
|
||||
status
|
||||
}
|
||||
|
||||
/// Get traffic statistics.
|
||||
/// Get traffic statistics (includes connection quality).
|
||||
pub async fn get_statistics(&self) -> serde_json::Value {
|
||||
let stats = self.stats.read().await;
|
||||
let since = self.connected_since.read().await;
|
||||
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
|
||||
let health = self.link_health.read().await;
|
||||
|
||||
serde_json::json!({
|
||||
let mut result = serde_json::json!({
|
||||
"bytesSent": stats.bytes_sent,
|
||||
"bytesReceived": stats.bytes_received,
|
||||
"packetsSent": stats.packets_sent,
|
||||
@@ -222,30 +346,64 @@ impl VpnClient {
|
||||
"keepalivesSent": stats.keepalives_sent,
|
||||
"keepalivesReceived": stats.keepalives_received,
|
||||
"uptimeSeconds": uptime,
|
||||
})
|
||||
});
|
||||
|
||||
// Include connection quality if available
|
||||
if let Some(ref rx) = self.quality_rx {
|
||||
let quality = rx.borrow().clone();
|
||||
result["quality"] = serde_json::json!({
|
||||
"srttMs": quality.srtt_ms,
|
||||
"jitterMs": quality.jitter_ms,
|
||||
"minRttMs": quality.min_rtt_ms,
|
||||
"maxRttMs": quality.max_rtt_ms,
|
||||
"lossRatio": quality.loss_ratio,
|
||||
"consecutiveTimeouts": quality.consecutive_timeouts,
|
||||
"linkHealth": format!("{}", *health),
|
||||
"keepalivesSent": quality.keepalives_sent,
|
||||
"keepalivesAcked": quality.keepalives_acked,
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Get connection quality snapshot.
|
||||
pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
|
||||
self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
|
||||
}
|
||||
|
||||
/// Get current link health.
|
||||
pub async fn get_link_health(&self) -> LinkHealth {
|
||||
*self.link_health.read().await
|
||||
}
|
||||
}
|
||||
|
||||
/// The main client packet forwarding loop (runs in a spawned task).
|
||||
async fn client_loop(
|
||||
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
|
||||
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
mut noise_transport: snow::TransportState,
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
mut shutdown_rx: mpsc::Receiver<()>,
|
||||
keepalive_secs: u64,
|
||||
mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
link_health: Arc<RwLock<LinkHealth>>,
|
||||
mut tun_reader: Option<tokio::io::ReadHalf<tun::AsyncDevice>>,
|
||||
mut tun_writer: Option<tokio::io::WriteHalf<tun::AsyncDevice>>,
|
||||
tun_subnet: Option<String>,
|
||||
) {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs));
|
||||
keepalive_ticker.tick().await; // skip first immediate tick
|
||||
let mut tun_buf = vec![0u8; 65536];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
Ok(Some(data)) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
@@ -254,6 +412,14 @@ async fn client_loop(
|
||||
let mut s = stats.write().await;
|
||||
s.bytes_received += len as u64;
|
||||
s.packets_received += 1;
|
||||
drop(s);
|
||||
|
||||
// Write decrypted packet to TUN device (if enabled)
|
||||
if let Some(ref mut writer) = tun_writer {
|
||||
if let Err(e) = writer.write_all(&buf[..len]).await {
|
||||
warn!("TUN write error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Decrypt error: {}", e);
|
||||
@@ -264,6 +430,8 @@ async fn client_loop(
|
||||
}
|
||||
PacketType::KeepaliveAck => {
|
||||
stats.write().await.keepalives_received += 1;
|
||||
// Signal the keepalive monitor that ACK was received
|
||||
let _ = ack_tx.send(()).await;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Server sent disconnect");
|
||||
@@ -274,35 +442,93 @@ async fn client_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
Ok(None) => {
|
||||
info!("Connection closed");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
let _ = ws_sink.send(Message::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
error!("WebSocket error: {}", e);
|
||||
Err(e) => {
|
||||
error!("Transport error: {}", e);
|
||||
*state.write().await = ClientState::Error(e.to_string());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = keepalive_ticker.tick() => {
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
// Read outbound packets from TUN and send to server (only when TUN enabled)
|
||||
result = async {
|
||||
match tun_reader {
|
||||
Some(ref mut reader) => reader.read(&mut tun_buf).await,
|
||||
None => std::future::pending::<std::io::Result<usize>>().await,
|
||||
}
|
||||
} => {
|
||||
match result {
|
||||
Ok(0) => {
|
||||
info!("TUN device closed");
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
match noise_transport.write_message(&tun_buf[..n], &mut buf) {
|
||||
Ok(len) => {
|
||||
let frame = Frame {
|
||||
packet_type: PacketType::IpPacket,
|
||||
payload: buf[..len].to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(
|
||||
&mut FrameCodec, frame, &mut frame_bytes
|
||||
).is_ok() {
|
||||
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
|
||||
warn!("Failed to send TUN packet to server");
|
||||
break;
|
||||
}
|
||||
let mut s = stats.write().await;
|
||||
s.bytes_sent += n as u64;
|
||||
s.packets_sent += 1;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Noise encrypt error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("TUN read error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
signal = signal_rx.recv() => {
|
||||
match signal {
|
||||
Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
|
||||
// Embed the timestamp in the keepalive payload (8 bytes, big-endian)
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: timestamp_ms.to_be_bytes().to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
}
|
||||
}
|
||||
Some(KeepaliveSignal::PeerDead) => {
|
||||
warn!("Peer declared dead by keepalive monitor");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
Some(KeepaliveSignal::LinkHealthChanged(health)) => {
|
||||
debug!("Link health changed to: {}", health);
|
||||
*link_health.write().await = health;
|
||||
}
|
||||
None => {
|
||||
// Keepalive monitor channel closed
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
@@ -313,12 +539,58 @@ async fn client_loop(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
|
||||
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
}
|
||||
let _ = ws_sink.close().await;
|
||||
let _ = sink.close().await;
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup: remove TUN route if enabled
|
||||
if let Some(ref subnet) = tun_subnet {
|
||||
if let Err(e) = tunnel::remove_route(subnet, "svpn-client0").await {
|
||||
warn!("Failed to remove client TUN route: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect via QUIC. Returns transport halves on success.
|
||||
async fn try_quic_connect(
|
||||
addr: &str,
|
||||
cert_hash: Option<&str>,
|
||||
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
|
||||
let conn = quic_transport::connect_quic(addr, cert_hash).await?;
|
||||
let (sink, stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
Ok((sink, stream))
|
||||
}
|
||||
|
||||
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
|
||||
/// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080")
|
||||
/// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443")
|
||||
/// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port)
|
||||
fn extract_host_port(url: &str) -> Option<String> {
|
||||
if url.starts_with("ws://") || url.starts_with("wss://") {
|
||||
// Parse as URL
|
||||
let stripped = if url.starts_with("wss://") {
|
||||
&url[6..]
|
||||
} else {
|
||||
&url[5..]
|
||||
};
|
||||
// Remove path
|
||||
let host_port = stripped.split('/').next()?;
|
||||
if host_port.contains(':') {
|
||||
Some(host_port.to_string())
|
||||
} else {
|
||||
// Default port
|
||||
let default_port = if url.starts_with("wss://") { 443 } else { 80 };
|
||||
Some(format!("{}:{}", host_port, default_port))
|
||||
}
|
||||
} else if url.contains(':') {
|
||||
// Already host:port
|
||||
Some(url.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
362
rust/src/client_registry.rs
Normal file
362
rust/src/client_registry.rs
Normal file
@@ -0,0 +1,362 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Per-client rate limiting configuration.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientRateLimit {
|
||||
pub bytes_per_sec: u64,
|
||||
pub burst_bytes: u64,
|
||||
}
|
||||
|
||||
/// Per-client security settings — aligned with SmartProxy's IRouteSecurity pattern.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientSecurity {
|
||||
/// Source IPs/CIDRs the client may connect FROM (empty/None = any).
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
/// Source IPs blocked — overrides ip_allow_list (deny wins).
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
/// Destination IPs/CIDRs the client may reach (empty/None = all).
|
||||
pub destination_allow_list: Option<Vec<String>>,
|
||||
/// Destination IPs blocked — overrides destination_allow_list (deny wins).
|
||||
pub destination_block_list: Option<Vec<String>>,
|
||||
/// Max concurrent connections from this client.
|
||||
pub max_connections: Option<u32>,
|
||||
/// Per-client rate limiting.
|
||||
pub rate_limit: Option<ClientRateLimit>,
|
||||
}
|
||||
|
||||
/// A registered client entry — the server-side source of truth.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientEntry {
|
||||
/// Human-readable client ID (e.g. "alice-laptop").
|
||||
pub client_id: String,
|
||||
/// Client's Noise IK public key (base64).
|
||||
pub public_key: String,
|
||||
/// Client's WireGuard public key (base64) — optional.
|
||||
pub wg_public_key: Option<String>,
|
||||
/// Security settings (ACLs, rate limits).
|
||||
pub security: Option<ClientSecurity>,
|
||||
/// Traffic priority (lower = higher priority, default: 100).
|
||||
pub priority: Option<u32>,
|
||||
/// Whether this client is enabled (default: true).
|
||||
pub enabled: Option<bool>,
|
||||
/// Tags for grouping.
|
||||
pub tags: Option<Vec<String>>,
|
||||
/// Optional description.
|
||||
pub description: Option<String>,
|
||||
/// Optional expiry (ISO 8601 timestamp).
|
||||
pub expires_at: Option<String>,
|
||||
/// Assigned VPN IP address.
|
||||
pub assigned_ip: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientEntry {
|
||||
/// Whether this client is considered enabled (defaults to true).
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Whether this client has expired based on current time.
|
||||
pub fn is_expired(&self) -> bool {
|
||||
if let Some(ref expires) = self.expires_at {
|
||||
if let Ok(expiry) = chrono::DateTime::parse_from_rfc3339(expires) {
|
||||
return chrono::Utc::now() > expiry;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory client registry with dual-key indexing.
|
||||
pub struct ClientRegistry {
|
||||
/// Primary index: clientId → ClientEntry
|
||||
entries: HashMap<String, ClientEntry>,
|
||||
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
|
||||
key_index: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ClientRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
key_index: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a registry from a list of client entries.
|
||||
pub fn from_entries(entries: Vec<ClientEntry>) -> Result<Self> {
|
||||
let mut registry = Self::new();
|
||||
for entry in entries {
|
||||
registry.add(entry)?;
|
||||
}
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
/// Add a client to the registry.
|
||||
pub fn add(&mut self, entry: ClientEntry) -> Result<()> {
|
||||
if self.entries.contains_key(&entry.client_id) {
|
||||
anyhow::bail!("Client '{}' already exists", entry.client_id);
|
||||
}
|
||||
if self.key_index.contains_key(&entry.public_key) {
|
||||
anyhow::bail!("Public key already registered to another client");
|
||||
}
|
||||
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
|
||||
self.entries.insert(entry.client_id.clone(), entry);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a client by ID.
|
||||
pub fn remove(&mut self, client_id: &str) -> Result<ClientEntry> {
|
||||
let entry = self.entries.remove(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
self.key_index.remove(&entry.public_key);
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
/// Get a client by ID.
|
||||
pub fn get_by_id(&self, client_id: &str) -> Option<&ClientEntry> {
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Get a client by public key (used during IK handshake verification).
|
||||
pub fn get_by_key(&self, public_key: &str) -> Option<&ClientEntry> {
|
||||
let client_id = self.key_index.get(public_key)?;
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Check if a public key is authorized (exists, enabled, not expired).
|
||||
pub fn is_authorized(&self, public_key: &str) -> bool {
|
||||
match self.get_by_key(public_key) {
|
||||
Some(entry) => entry.is_enabled() && !entry.is_expired(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a client entry. The closure receives a mutable reference to the entry.
|
||||
pub fn update<F>(&mut self, client_id: &str, updater: F) -> Result<()>
|
||||
where
|
||||
F: FnOnce(&mut ClientEntry),
|
||||
{
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
let old_key = entry.public_key.clone();
|
||||
updater(entry);
|
||||
// If public key changed, update the index
|
||||
if entry.public_key != old_key {
|
||||
self.key_index.remove(&old_key);
|
||||
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all client entries.
|
||||
pub fn list(&self) -> Vec<&ClientEntry> {
|
||||
self.entries.values().collect()
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns the updated entry.
|
||||
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
// Update key index
|
||||
self.key_index.remove(&entry.public_key);
|
||||
entry.public_key = new_public_key.clone();
|
||||
entry.wg_public_key = new_wg_public_key;
|
||||
self.key_index.insert(new_public_key, client_id.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of registered clients.
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Whether the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(id: &str, key: &str) -> ClientEntry {
|
||||
ClientEntry {
|
||||
client_id: id.to_string(),
|
||||
public_key: key.to_string(),
|
||||
wg_public_key: None,
|
||||
security: None,
|
||||
priority: None,
|
||||
enabled: None,
|
||||
tags: None,
|
||||
description: None,
|
||||
expires_at: None,
|
||||
assigned_ip: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_and_lookup() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
assert!(reg.get_by_id("alice").is_some());
|
||||
assert!(reg.get_by_key("key_alice").is_some());
|
||||
assert_eq!(reg.get_by_key("key_alice").unwrap().client_id, "alice");
|
||||
assert!(reg.get_by_id("bob").is_none());
|
||||
assert!(reg.get_by_key("key_bob").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_id() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key1")).unwrap();
|
||||
assert!(reg.add(make_entry("alice", "key2")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "same_key")).unwrap();
|
||||
assert!(reg.add(make_entry("bob", "same_key")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert_eq!(reg.len(), 1);
|
||||
|
||||
let removed = reg.remove("alice").unwrap();
|
||||
assert_eq!(removed.client_id, "alice");
|
||||
assert_eq!(reg.len(), 0);
|
||||
assert!(reg.get_by_key("key_alice").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.remove("ghost").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_enabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert!(reg.is_authorized("key_alice")); // enabled by default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_disabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.enabled = Some(false);
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_expired() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2020-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_future_expiry() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2099-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_unknown_key() {
|
||||
let reg = ClientRegistry::new();
|
||||
assert!(!reg.is_authorized("nonexistent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
reg.update("alice", |entry| {
|
||||
entry.description = Some("Updated".to_string());
|
||||
entry.enabled = Some(false);
|
||||
}).unwrap();
|
||||
|
||||
let entry = reg.get_by_id("alice").unwrap();
|
||||
assert_eq!(entry.description.as_deref(), Some("Updated"));
|
||||
assert!(!entry.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.update("ghost", |_| {}).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rotate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "old_key")).unwrap();
|
||||
|
||||
reg.rotate_key("alice", "new_key".to_string(), None).unwrap();
|
||||
|
||||
assert!(reg.get_by_key("old_key").is_none());
|
||||
assert!(reg.get_by_key("new_key").is_some());
|
||||
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_entries() {
|
||||
let entries = vec![
|
||||
make_entry("alice", "key_a"),
|
||||
make_entry("bob", "key_b"),
|
||||
];
|
||||
let reg = ClientRegistry::from_entries(entries).unwrap();
|
||||
assert_eq!(reg.len(), 2);
|
||||
assert!(reg.get_by_key("key_a").is_some());
|
||||
assert!(reg.get_by_key("key_b").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_clients() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_a")).unwrap();
|
||||
reg.add(make_entry("bob", "key_b")).unwrap();
|
||||
let list = reg.list();
|
||||
assert_eq!(list.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_with_rate_limit() {
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.security = Some(ClientSecurity {
|
||||
ip_allow_list: Some(vec!["192.168.1.0/24".to_string()]),
|
||||
ip_block_list: Some(vec!["192.168.1.100".to_string()]),
|
||||
destination_allow_list: None,
|
||||
destination_block_list: None,
|
||||
max_connections: Some(5),
|
||||
rate_limit: Some(ClientRateLimit {
|
||||
bytes_per_sec: 1_000_000,
|
||||
burst_bytes: 2_000_000,
|
||||
}),
|
||||
});
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(entry).unwrap();
|
||||
let e = reg.get_by_id("alice").unwrap();
|
||||
let sec = e.security.as_ref().unwrap();
|
||||
assert_eq!(sec.rate_limit.as_ref().unwrap().bytes_per_sec, 1_000_000);
|
||||
assert_eq!(sec.max_connections, Some(5));
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,10 @@ use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use snow::Builder;
|
||||
|
||||
/// Noise protocol pattern: NK (client knows server pubkey, no client auth at Noise level)
|
||||
const NOISE_PATTERN: &str = "Noise_NK_25519_ChaChaPoly_BLAKE2s";
|
||||
/// Noise protocol pattern: IK (client presents static key, server authenticates client)
|
||||
/// IK = Initiator's static key is transmitted; responder's Key is pre-known.
|
||||
/// This provides mutual authentication: server verifies client identity via public key.
|
||||
const NOISE_PATTERN: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2s";
|
||||
|
||||
/// Generate a new Noise static keypair.
|
||||
/// Returns (public_key_base64, private_key_base64).
|
||||
@@ -22,18 +24,23 @@ pub fn generate_keypair_raw() -> Result<snow::Keypair> {
|
||||
Ok(builder.generate_keypair()?)
|
||||
}
|
||||
|
||||
/// Create a Noise NK initiator (client side).
|
||||
/// The client knows the server's static public key.
|
||||
pub fn create_initiator(server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
/// Create a Noise IK initiator (client side).
|
||||
/// The client provides its own static keypair AND the server's public key.
|
||||
/// The client's static key is transmitted (encrypted) during the handshake,
|
||||
/// allowing the server to authenticate the client.
|
||||
pub fn create_initiator(client_private_key: &[u8], server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
.local_private_key(client_private_key)
|
||||
.remote_public_key(server_public_key)
|
||||
.build_initiator()?;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Create a Noise NK responder (server side).
|
||||
/// Create a Noise IK responder (server side).
|
||||
/// The server uses its static private key.
|
||||
/// After the handshake, call `get_remote_static()` on the HandshakeState
|
||||
/// (before `into_transport_mode()`) to retrieve the client's public key.
|
||||
pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
@@ -42,19 +49,20 @@ pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Perform the full Noise NK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport).
|
||||
/// Perform the full Noise IK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport, client_public_key).
|
||||
/// The client_public_key is extracted from the responder before entering transport mode.
|
||||
pub fn perform_handshake(
|
||||
mut initiator: snow::HandshakeState,
|
||||
mut responder: snow::HandshakeState,
|
||||
) -> Result<(snow::TransportState, snow::TransportState)> {
|
||||
) -> Result<(snow::TransportState, snow::TransportState, Vec<u8>)> {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es (initiator sends)
|
||||
// -> e, es, s, ss (initiator sends ephemeral + encrypted static key)
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let msg1 = buf[..len].to_vec();
|
||||
|
||||
// <- e, ee (responder reads and responds)
|
||||
// <- e, ee, se (responder reads and responds)
|
||||
responder.read_message(&msg1, &mut buf)?;
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let msg2 = buf[..len].to_vec();
|
||||
@@ -62,10 +70,16 @@ pub fn perform_handshake(
|
||||
// Initiator reads response
|
||||
initiator.read_message(&msg2, &mut buf)?;
|
||||
|
||||
// Extract client's public key from responder BEFORE entering transport mode
|
||||
let client_public_key = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake did not provide client static key"))?
|
||||
.to_vec();
|
||||
|
||||
let i_transport = initiator.into_transport_mode()?;
|
||||
let r_transport = responder.into_transport_mode()?;
|
||||
|
||||
Ok((i_transport, r_transport))
|
||||
Ok((i_transport, r_transport, client_public_key))
|
||||
}
|
||||
|
||||
/// XChaCha20-Poly1305 encryption for post-handshake data.
|
||||
@@ -135,15 +149,19 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_handshake() {
|
||||
fn noise_ik_handshake() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
let initiator = create_initiator(&server_kp.public).unwrap();
|
||||
let initiator = create_initiator(&client_kp.private, &server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
let (mut i_transport, mut r_transport) =
|
||||
let (mut i_transport, mut r_transport, remote_key) =
|
||||
perform_handshake(initiator, responder).unwrap();
|
||||
|
||||
// Verify the server received the client's public key
|
||||
assert_eq!(remote_key, client_kp.public);
|
||||
|
||||
// Test encrypted communication
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let plaintext = b"hello from client";
|
||||
@@ -159,6 +177,20 @@ mod tests {
|
||||
assert_eq!(&out[..len], plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_ik_wrong_server_key_fails() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let wrong_server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
// Client uses wrong server public key
|
||||
let initiator = create_initiator(&client_kp.private, &wrong_server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
// Handshake should fail because client targeted wrong server
|
||||
assert!(perform_handshake(initiator, responder).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn xchacha_encrypt_decrypt() {
|
||||
let key = [42u8; 32];
|
||||
|
||||
@@ -1,87 +1,464 @@
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio::time::{interval, timeout};
|
||||
use tracing::{debug, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Default keepalive interval (30 seconds).
|
||||
use crate::telemetry::{ConnectionQuality, RttTracker};
|
||||
|
||||
/// Default keepalive interval (30 seconds — used for Degraded state).
|
||||
pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
/// Default keepalive ACK timeout (10 seconds).
|
||||
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// Default keepalive ACK timeout (5 seconds).
|
||||
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// Link health states for adaptive keepalive.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LinkHealth {
|
||||
/// RTT stable, jitter low, no loss. Interval: 60s.
|
||||
Healthy,
|
||||
/// Elevated jitter or occasional loss. Interval: 30s.
|
||||
Degraded,
|
||||
/// High loss or sustained jitter spike. Interval: 10s.
|
||||
Critical,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LinkHealth {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Healthy => write!(f, "healthy"),
|
||||
Self::Degraded => write!(f, "degraded"),
|
||||
Self::Critical => write!(f, "critical"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the adaptive keepalive state machine.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdaptiveKeepaliveConfig {
|
||||
/// Interval when link health is Healthy.
|
||||
pub healthy_interval: Duration,
|
||||
/// Interval when link health is Degraded.
|
||||
pub degraded_interval: Duration,
|
||||
/// Interval when link health is Critical.
|
||||
pub critical_interval: Duration,
|
||||
/// ACK timeout (how long to wait for ACK before declaring timeout).
|
||||
pub ack_timeout: Duration,
|
||||
/// Jitter threshold (ms) to enter Degraded from Healthy.
|
||||
pub jitter_degraded_ms: f64,
|
||||
/// Jitter threshold (ms) to return to Healthy from Degraded.
|
||||
pub jitter_healthy_ms: f64,
|
||||
/// Loss ratio threshold to enter Degraded.
|
||||
pub loss_degraded: f64,
|
||||
/// Loss ratio threshold to enter Critical.
|
||||
pub loss_critical: f64,
|
||||
/// Loss ratio threshold to return from Critical to Degraded.
|
||||
pub loss_recover: f64,
|
||||
/// Loss ratio threshold to return from Degraded to Healthy.
|
||||
pub loss_healthy: f64,
|
||||
/// Consecutive checks required for upward state transitions (hysteresis).
|
||||
pub upgrade_checks: u32,
|
||||
/// Consecutive timeouts to declare peer dead in Critical state.
|
||||
pub dead_peer_timeouts: u32,
|
||||
}
|
||||
|
||||
impl Default for AdaptiveKeepaliveConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
healthy_interval: Duration::from_secs(60),
|
||||
degraded_interval: Duration::from_secs(30),
|
||||
critical_interval: Duration::from_secs(10),
|
||||
ack_timeout: Duration::from_secs(5),
|
||||
jitter_degraded_ms: 50.0,
|
||||
jitter_healthy_ms: 30.0,
|
||||
loss_degraded: 0.05,
|
||||
loss_critical: 0.20,
|
||||
loss_recover: 0.10,
|
||||
loss_healthy: 0.02,
|
||||
upgrade_checks: 3,
|
||||
dead_peer_timeouts: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Signals from the keepalive monitor.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeepaliveSignal {
|
||||
/// Time to send a keepalive ping.
|
||||
SendPing,
|
||||
/// Peer is considered dead (no ACK received within timeout).
|
||||
/// Time to send a keepalive ping. Contains the timestamp (ms since epoch) to embed in payload.
|
||||
SendPing(u64),
|
||||
/// Peer is considered dead (no ACK received within timeout repeatedly).
|
||||
PeerDead,
|
||||
/// Link health state changed.
|
||||
LinkHealthChanged(LinkHealth),
|
||||
}
|
||||
|
||||
/// A keepalive monitor that emits signals on a channel.
|
||||
/// A keepalive monitor with adaptive interval and RTT tracking.
|
||||
pub struct KeepaliveMonitor {
|
||||
interval: Duration,
|
||||
timeout_duration: Duration,
|
||||
config: AdaptiveKeepaliveConfig,
|
||||
health: LinkHealth,
|
||||
rtt_tracker: RttTracker,
|
||||
signal_tx: mpsc::Sender<KeepaliveSignal>,
|
||||
ack_rx: mpsc::Receiver<()>,
|
||||
quality_tx: watch::Sender<ConnectionQuality>,
|
||||
consecutive_upgrade_checks: u32,
|
||||
}
|
||||
|
||||
/// Handle returned to the caller to send ACKs and receive signals.
|
||||
pub struct KeepaliveHandle {
|
||||
pub signal_rx: mpsc::Receiver<KeepaliveSignal>,
|
||||
pub ack_tx: mpsc::Sender<()>,
|
||||
pub quality_rx: watch::Receiver<ConnectionQuality>,
|
||||
}
|
||||
|
||||
/// Create a keepalive monitor and its handle.
|
||||
/// Create an adaptive keepalive monitor and its handle.
|
||||
pub fn create_keepalive(
|
||||
keepalive_interval: Option<Duration>,
|
||||
keepalive_timeout: Option<Duration>,
|
||||
config: Option<AdaptiveKeepaliveConfig>,
|
||||
) -> (KeepaliveMonitor, KeepaliveHandle) {
|
||||
let config = config.unwrap_or_default();
|
||||
let (signal_tx, signal_rx) = mpsc::channel(8);
|
||||
let (ack_tx, ack_rx) = mpsc::channel(8);
|
||||
let (quality_tx, quality_rx) = watch::channel(ConnectionQuality::default());
|
||||
|
||||
let monitor = KeepaliveMonitor {
|
||||
interval: keepalive_interval.unwrap_or(DEFAULT_KEEPALIVE_INTERVAL),
|
||||
timeout_duration: keepalive_timeout.unwrap_or(DEFAULT_KEEPALIVE_TIMEOUT),
|
||||
config,
|
||||
health: LinkHealth::Degraded, // start in Degraded, earn Healthy
|
||||
rtt_tracker: RttTracker::new(30),
|
||||
signal_tx,
|
||||
ack_rx,
|
||||
quality_tx,
|
||||
consecutive_upgrade_checks: 0,
|
||||
};
|
||||
|
||||
let handle = KeepaliveHandle { signal_rx, ack_tx };
|
||||
let handle = KeepaliveHandle {
|
||||
signal_rx,
|
||||
ack_tx,
|
||||
quality_rx,
|
||||
};
|
||||
|
||||
(monitor, handle)
|
||||
}
|
||||
|
||||
impl KeepaliveMonitor {
|
||||
fn current_interval(&self) -> Duration {
|
||||
match self.health {
|
||||
LinkHealth::Healthy => self.config.healthy_interval,
|
||||
LinkHealth::Degraded => self.config.degraded_interval,
|
||||
LinkHealth::Critical => self.config.critical_interval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the keepalive loop. Blocks until the peer is dead or channels close.
|
||||
pub async fn run(mut self) {
|
||||
let mut ticker = interval(self.interval);
|
||||
let mut ticker = interval(self.current_interval());
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
debug!("Sending keepalive ping signal");
|
||||
|
||||
if self.signal_tx.send(KeepaliveSignal::SendPing).await.is_err() {
|
||||
// Channel closed
|
||||
break;
|
||||
// Record ping sent, get timestamp for payload
|
||||
let timestamp_ms = self.rtt_tracker.mark_ping_sent();
|
||||
debug!("Sending keepalive ping (ts={})", timestamp_ms);
|
||||
|
||||
if self
|
||||
.signal_tx
|
||||
.send(KeepaliveSignal::SendPing(timestamp_ms))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break; // channel closed
|
||||
}
|
||||
|
||||
// Wait for ACK within timeout
|
||||
match timeout(self.timeout_duration, self.ack_rx.recv()).await {
|
||||
match timeout(self.config.ack_timeout, self.ack_rx.recv()).await {
|
||||
Ok(Some(())) => {
|
||||
debug!("Keepalive ACK received");
|
||||
if let Some(rtt) = self.rtt_tracker.record_ack(timestamp_ms) {
|
||||
debug!("Keepalive ACK received, RTT: {:?}", rtt);
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Channel closed
|
||||
break;
|
||||
break; // channel closed
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Keepalive ACK timeout — peer considered dead");
|
||||
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
|
||||
break;
|
||||
self.rtt_tracker.record_timeout();
|
||||
warn!(
|
||||
"Keepalive ACK timeout (consecutive: {})",
|
||||
self.rtt_tracker.consecutive_timeouts
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Publish quality snapshot
|
||||
let quality = self.rtt_tracker.snapshot();
|
||||
let _ = self.quality_tx.send(quality.clone());
|
||||
|
||||
// Evaluate state transition
|
||||
let new_health = self.evaluate_health(&quality);
|
||||
|
||||
if new_health != self.health {
|
||||
info!("Link health: {} -> {}", self.health, new_health);
|
||||
self.health = new_health;
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
|
||||
// Reset ticker to new interval
|
||||
ticker = interval(self.current_interval());
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
let _ = self
|
||||
.signal_tx
|
||||
.send(KeepaliveSignal::LinkHealthChanged(new_health))
|
||||
.await;
|
||||
}
|
||||
|
||||
// Check for dead peer in Critical state
|
||||
if self.health == LinkHealth::Critical
|
||||
&& self.rtt_tracker.consecutive_timeouts >= self.config.dead_peer_timeouts
|
||||
{
|
||||
warn!("Peer considered dead after {} consecutive timeouts in Critical state",
|
||||
self.rtt_tracker.consecutive_timeouts);
|
||||
let _ = self.signal_tx.send(KeepaliveSignal::PeerDead).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_health(&mut self, quality: &ConnectionQuality) -> LinkHealth {
|
||||
match self.health {
|
||||
LinkHealth::Healthy => {
|
||||
// Downgrade conditions
|
||||
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Critical;
|
||||
}
|
||||
if quality.jitter_ms > self.config.jitter_degraded_ms
|
||||
|| quality.loss_ratio > self.config.loss_degraded
|
||||
|| quality.consecutive_timeouts >= 1
|
||||
{
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Degraded;
|
||||
}
|
||||
LinkHealth::Healthy
|
||||
}
|
||||
LinkHealth::Degraded => {
|
||||
// Downgrade to Critical
|
||||
if quality.consecutive_timeouts >= 2 || quality.loss_ratio > self.config.loss_critical {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Critical;
|
||||
}
|
||||
// Upgrade to Healthy (with hysteresis)
|
||||
if quality.jitter_ms < self.config.jitter_healthy_ms
|
||||
&& quality.loss_ratio < self.config.loss_healthy
|
||||
&& quality.consecutive_timeouts == 0
|
||||
{
|
||||
self.consecutive_upgrade_checks += 1;
|
||||
if self.consecutive_upgrade_checks >= self.config.upgrade_checks {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Healthy;
|
||||
}
|
||||
} else {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
}
|
||||
LinkHealth::Degraded
|
||||
}
|
||||
LinkHealth::Critical => {
|
||||
// Upgrade to Degraded (with hysteresis), never directly to Healthy
|
||||
if quality.loss_ratio < self.config.loss_recover
|
||||
&& quality.consecutive_timeouts == 0
|
||||
{
|
||||
self.consecutive_upgrade_checks += 1;
|
||||
if self.consecutive_upgrade_checks >= 2 {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
return LinkHealth::Degraded;
|
||||
}
|
||||
} else {
|
||||
self.consecutive_upgrade_checks = 0;
|
||||
}
|
||||
LinkHealth::Critical
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let config = AdaptiveKeepaliveConfig::default();
|
||||
assert_eq!(config.healthy_interval, Duration::from_secs(60));
|
||||
assert_eq!(config.degraded_interval, Duration::from_secs(30));
|
||||
assert_eq!(config.critical_interval, Duration::from_secs(10));
|
||||
assert_eq!(config.ack_timeout, Duration::from_secs(5));
|
||||
assert_eq!(config.dead_peer_timeouts, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn link_health_display() {
|
||||
assert_eq!(format!("{}", LinkHealth::Healthy), "healthy");
|
||||
assert_eq!(format!("{}", LinkHealth::Degraded), "degraded");
|
||||
assert_eq!(format!("{}", LinkHealth::Critical), "critical");
|
||||
}
|
||||
|
||||
// Helper to create a monitor for unit-testing evaluate_health
|
||||
fn make_test_monitor() -> KeepaliveMonitor {
|
||||
let (signal_tx, _signal_rx) = mpsc::channel(8);
|
||||
let (_ack_tx, ack_rx) = mpsc::channel(8);
|
||||
let (quality_tx, _quality_rx) = watch::channel(ConnectionQuality::default());
|
||||
|
||||
KeepaliveMonitor {
|
||||
config: AdaptiveKeepaliveConfig::default(),
|
||||
health: LinkHealth::Degraded,
|
||||
rtt_tracker: RttTracker::new(30),
|
||||
signal_tx,
|
||||
ack_rx,
|
||||
quality_tx,
|
||||
consecutive_upgrade_checks: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_degraded_on_jitter() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
jitter_ms: 60.0, // > 50ms threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_degraded_on_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.06, // > 5% threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_critical_on_high_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.25, // > 20% threshold
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn healthy_to_critical_on_consecutive_timeouts() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
let q = ConnectionQuality {
|
||||
consecutive_timeouts: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let result = m.evaluate_health(&q);
|
||||
assert_eq!(result, LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_healthy_requires_hysteresis() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let good_quality = ConnectionQuality {
|
||||
jitter_ms: 10.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
srtt_ms: 20.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Should require 3 consecutive good checks (default upgrade_checks)
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 1);
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Degraded);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 2);
|
||||
assert_eq!(m.evaluate_health(&good_quality), LinkHealth::Healthy);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_healthy_resets_on_bad_check() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let good = ConnectionQuality {
|
||||
jitter_ms: 10.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let bad = ConnectionQuality {
|
||||
jitter_ms: 60.0, // too high
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
m.evaluate_health(&good); // 1 check
|
||||
m.evaluate_health(&good); // 2 checks
|
||||
m.evaluate_health(&bad); // resets
|
||||
assert_eq!(m.consecutive_upgrade_checks, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn critical_to_degraded_requires_hysteresis() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Critical;
|
||||
let recovering = ConnectionQuality {
|
||||
loss_ratio: 0.05, // < 10% recover threshold
|
||||
consecutive_timeouts: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Critical);
|
||||
assert_eq!(m.consecutive_upgrade_checks, 1);
|
||||
assert_eq!(m.evaluate_health(&recovering), LinkHealth::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn critical_never_directly_to_healthy() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Critical;
|
||||
let perfect = ConnectionQuality {
|
||||
jitter_ms: 1.0,
|
||||
loss_ratio: 0.0,
|
||||
consecutive_timeouts: 0,
|
||||
srtt_ms: 10.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Even with perfect quality, must go through Degraded first
|
||||
m.evaluate_health(&perfect); // 1
|
||||
let result = m.evaluate_health(&perfect); // 2 → Degraded
|
||||
assert_eq!(result, LinkHealth::Degraded);
|
||||
// Not Healthy yet
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_to_critical_on_high_loss() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Degraded;
|
||||
let q = ConnectionQuality {
|
||||
loss_ratio: 0.25,
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(m.evaluate_health(&q), LinkHealth::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interval_matches_health() {
|
||||
let mut m = make_test_monitor();
|
||||
m.health = LinkHealth::Healthy;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(60));
|
||||
m.health = LinkHealth::Degraded;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(30));
|
||||
m.health = LinkHealth::Critical;
|
||||
assert_eq!(m.current_interval(), Duration::from_secs(10));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,9 +5,20 @@ pub mod management;
|
||||
pub mod codec;
|
||||
pub mod crypto;
|
||||
pub mod transport;
|
||||
pub mod transport_trait;
|
||||
pub mod quic_transport;
|
||||
pub mod keepalive;
|
||||
pub mod tunnel;
|
||||
pub mod network;
|
||||
pub mod server;
|
||||
pub mod client;
|
||||
pub mod reconnect;
|
||||
pub mod telemetry;
|
||||
pub mod ratelimit;
|
||||
pub mod qos;
|
||||
pub mod mtu;
|
||||
pub mod wireguard;
|
||||
pub mod client_registry;
|
||||
pub mod acl;
|
||||
pub mod proxy_protocol;
|
||||
pub mod userspace_nat;
|
||||
|
||||
@@ -7,6 +7,7 @@ use tracing::{info, error, warn};
|
||||
use crate::client::{ClientConfig, VpnClient};
|
||||
use crate::crypto;
|
||||
use crate::server::{ServerConfig, VpnServer};
|
||||
use crate::wireguard::{self, WgClient, WgClientConfig, WgPeerConfig, WgServer, WgServerConfig};
|
||||
|
||||
// ============================================================================
|
||||
// IPC protocol types
|
||||
@@ -93,6 +94,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
|
||||
let mut vpn_client = VpnClient::new();
|
||||
let mut vpn_server = VpnServer::new();
|
||||
let mut wg_client = WgClient::new();
|
||||
let mut wg_server = WgServer::new();
|
||||
|
||||
send_event_stdout("ready", serde_json::json!({ "mode": mode }));
|
||||
|
||||
@@ -127,8 +130,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
};
|
||||
|
||||
let response = match mode {
|
||||
"client" => handle_client_request(&request, &mut vpn_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server).await,
|
||||
"client" => handle_client_request(&request, &mut vpn_client, &mut wg_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server, &mut wg_server).await,
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
send_response_stdout(&response);
|
||||
@@ -150,6 +153,8 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
// Shared state behind Mutex for socket mode (multiple connections)
|
||||
let vpn_client = std::sync::Arc::new(Mutex::new(VpnClient::new()));
|
||||
let vpn_server = std::sync::Arc::new(Mutex::new(VpnServer::new()));
|
||||
let wg_client = std::sync::Arc::new(Mutex::new(WgClient::new()));
|
||||
let wg_server = std::sync::Arc::new(Mutex::new(WgServer::new()));
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
@@ -157,9 +162,11 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
let mode = mode.to_string();
|
||||
let client = vpn_client.clone();
|
||||
let server = vpn_server.clone();
|
||||
let wg_c = wg_client.clone();
|
||||
let wg_s = wg_server.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
handle_socket_connection(stream, &mode, client, server).await
|
||||
handle_socket_connection(stream, &mode, client, server, wg_c, wg_s).await
|
||||
{
|
||||
warn!("Socket connection error: {}", e);
|
||||
}
|
||||
@@ -177,6 +184,8 @@ async fn handle_socket_connection(
|
||||
mode: &str,
|
||||
vpn_client: std::sync::Arc<Mutex<VpnClient>>,
|
||||
vpn_server: std::sync::Arc<Mutex<VpnServer>>,
|
||||
wg_client: std::sync::Arc<Mutex<WgClient>>,
|
||||
wg_server: std::sync::Arc<Mutex<WgServer>>,
|
||||
) -> Result<()> {
|
||||
let (reader, mut writer) = stream.into_split();
|
||||
let buf_reader = BufReader::new(reader);
|
||||
@@ -227,11 +236,13 @@ async fn handle_socket_connection(
|
||||
let response = match mode {
|
||||
"client" => {
|
||||
let mut client = vpn_client.lock().await;
|
||||
handle_client_request(&request, &mut client).await
|
||||
let mut wg_c = wg_client.lock().await;
|
||||
handle_client_request(&request, &mut client, &mut wg_c).await
|
||||
}
|
||||
"server" => {
|
||||
let mut server = vpn_server.lock().await;
|
||||
handle_server_request(&request, &mut server).await
|
||||
let mut wg_s = wg_server.lock().await;
|
||||
handle_server_request(&request, &mut server, &mut wg_s).await
|
||||
}
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
@@ -252,38 +263,112 @@ async fn handle_socket_connection(
|
||||
async fn handle_client_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_client: &mut VpnClient,
|
||||
wg_client: &mut WgClient,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"connect" => {
|
||||
let config: ClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
// Check if transport is "wireguard"
|
||||
let transport = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transport"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match vpn_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
if transport == "wireguard" {
|
||||
let config: WgClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("WG connect failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"disconnect" => {
|
||||
if wg_client.is_running() {
|
||||
match wg_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG disconnect failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disconnect" => match vpn_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
|
||||
},
|
||||
"getStatus" => {
|
||||
let status = vpn_client.get_status().await;
|
||||
ManagementResponse::ok(id, status)
|
||||
if wg_client.is_running() {
|
||||
ManagementResponse::ok(id, wg_client.get_status().await)
|
||||
} else {
|
||||
let status = vpn_client.get_status().await;
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
}
|
||||
"getStatistics" => {
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
if wg_client.is_running() {
|
||||
ManagementResponse::ok(id, wg_client.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
}
|
||||
}
|
||||
"getConnectionQuality" => {
|
||||
match vpn_client.get_connection_quality() {
|
||||
Some(quality) => {
|
||||
let health = vpn_client.get_link_health().await;
|
||||
let interval_secs = match health {
|
||||
crate::keepalive::LinkHealth::Healthy => 60,
|
||||
crate::keepalive::LinkHealth::Degraded => 30,
|
||||
crate::keepalive::LinkHealth::Critical => 10,
|
||||
};
|
||||
ManagementResponse::ok(id, serde_json::json!({
|
||||
"srttMs": quality.srtt_ms,
|
||||
"jitterMs": quality.jitter_ms,
|
||||
"minRttMs": quality.min_rtt_ms,
|
||||
"maxRttMs": quality.max_rtt_ms,
|
||||
"lossRatio": quality.loss_ratio,
|
||||
"consecutiveTimeouts": quality.consecutive_timeouts,
|
||||
"linkHealth": format!("{}", health),
|
||||
"currentKeepaliveIntervalSecs": interval_secs,
|
||||
}))
|
||||
}
|
||||
None => ManagementResponse::ok(id, serde_json::json!(null)),
|
||||
}
|
||||
}
|
||||
"getMtuInfo" => {
|
||||
ManagementResponse::ok(id, serde_json::json!({
|
||||
"tunMtu": 1420,
|
||||
"effectiveMtu": crate::mtu::TunnelOverhead::default_overhead().effective_tun_mtu(1500),
|
||||
"linkMtu": 1500,
|
||||
"overheadBytes": crate::mtu::TunnelOverhead::default_overhead().total(),
|
||||
"oversizedPacketsDropped": 0,
|
||||
"icmpTooBigSent": 0,
|
||||
}))
|
||||
}
|
||||
_ => ManagementResponse::err(id, format!("Unknown client method: {}", request.method)),
|
||||
}
|
||||
@@ -296,45 +381,92 @@ async fn handle_client_request(
|
||||
async fn handle_server_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_server: &mut VpnServer,
|
||||
wg_server: &mut WgServer,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"start" => {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
// Check if transportMode is "wireguard"
|
||||
let transport_mode = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transportMode"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
if transport_mode == "wireguard" {
|
||||
let config: WgServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG start failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => {
|
||||
if wg_server.is_running() {
|
||||
match wg_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG stop failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
},
|
||||
"getStatus" => {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_status())
|
||||
} else {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
}
|
||||
"getStatistics" => {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"listClients" => {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
if wg_server.is_running() {
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"disconnectClient" => {
|
||||
@@ -349,6 +481,50 @@ async fn handle_server_request(
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"setClientRateLimit" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let rate = match request.params.get("rateBytesPerSec").and_then(|v| v.as_u64()) {
|
||||
Some(r) => r,
|
||||
None => return ManagementResponse::err(id, "Missing rateBytesPerSec".to_string()),
|
||||
};
|
||||
let burst = match request.params.get("burstBytes").and_then(|v| v.as_u64()) {
|
||||
Some(b) => b,
|
||||
None => return ManagementResponse::err(id, "Missing burstBytes".to_string()),
|
||||
};
|
||||
match vpn_server.set_client_rate_limit(&client_id, rate, burst).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClientRateLimit" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_client_rate_limit(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClientTelemetry" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match clients.into_iter().find(|c| c.client_id == client_id) {
|
||||
Some(info) => {
|
||||
match serde_json::to_value(&info) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id, format!("Client {} not found", client_id)),
|
||||
}
|
||||
}
|
||||
"generateKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
@@ -359,6 +535,153 @@ async fn handle_server_request(
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
"generateWgKeypair" => {
|
||||
let (public_key, private_key) = wireguard::generate_wg_keypair();
|
||||
ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
)
|
||||
}
|
||||
"addWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let config: WgPeerConfig = match serde_json::from_value(
|
||||
request.params.get("peer").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid peer config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.add_peer(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Add peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let public_key = match request.params.get("publicKey").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing publicKey".to_string()),
|
||||
};
|
||||
match wg_server.remove_peer(&public_key).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listWgPeers" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "peers": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
// ── Client Registry (Hub) Commands ────────────────────────────────
|
||||
"createClient" => {
|
||||
let client_partial = request.params.get("client").cloned().unwrap_or_default();
|
||||
match vpn_server.create_client(client_partial).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Create client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_registered_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.get_registered_client(&client_id).await {
|
||||
Ok(entry) => ManagementResponse::ok(id, entry),
|
||||
Err(e) => ManagementResponse::err(id, format!("Get client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listRegisteredClients" => {
|
||||
let clients = vpn_server.list_registered_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
"updateClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let update = request.params.get("update").cloned().unwrap_or_default();
|
||||
match vpn_server.update_registered_client(&client_id, update).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Update client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"enableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.enable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Enable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.disable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"rotateClientKey" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.rotate_client_key(&client_id).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Key rotation failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"exportClientConfig" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let format = request.params.get("format").and_then(|v| v.as_str()).unwrap_or("smartvpn");
|
||||
match vpn_server.export_client_config(&client_id, format).await {
|
||||
Ok(config) => ManagementResponse::ok(id, config),
|
||||
Err(e) => ManagementResponse::err(id, format!("Export failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"generateClientKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
_ => ManagementResponse::err(id, format!("Unknown server method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
314
rust/src/mtu.rs
Normal file
314
rust/src/mtu.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
/// Overhead breakdown for VPN tunnel encapsulation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TunnelOverhead {
|
||||
/// Outer IP header: 20 bytes (IPv4, no options).
|
||||
pub ip_header: u16,
|
||||
/// TCP header: typically 32 bytes (20 base + 12 for timestamps).
|
||||
pub tcp_header: u16,
|
||||
/// WebSocket framing: ~6 bytes (2 base + 4 mask from client).
|
||||
pub ws_framing: u16,
|
||||
/// VPN binary frame header: 5 bytes [type:1B][length:4B].
|
||||
pub vpn_header: u16,
|
||||
/// Noise AEAD tag: 16 bytes (Poly1305).
|
||||
pub noise_tag: u16,
|
||||
}
|
||||
|
||||
impl TunnelOverhead {
|
||||
/// Conservative default overhead estimate.
|
||||
pub fn default_overhead() -> Self {
|
||||
Self {
|
||||
ip_header: 20,
|
||||
tcp_header: 32,
|
||||
ws_framing: 6,
|
||||
vpn_header: 5,
|
||||
noise_tag: 16,
|
||||
}
|
||||
}
|
||||
|
||||
/// Total encapsulation overhead in bytes.
|
||||
pub fn total(&self) -> u16 {
|
||||
self.ip_header + self.tcp_header + self.ws_framing + self.vpn_header + self.noise_tag
|
||||
}
|
||||
|
||||
/// Compute effective TUN MTU given the underlying link MTU.
|
||||
pub fn effective_tun_mtu(&self, link_mtu: u16) -> u16 {
|
||||
link_mtu.saturating_sub(self.total())
|
||||
}
|
||||
}
|
||||
|
||||
/// MTU configuration for the VPN tunnel.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MtuConfig {
|
||||
/// Underlying link MTU (typically 1500 for Ethernet).
|
||||
pub link_mtu: u16,
|
||||
/// Computed effective TUN MTU.
|
||||
pub effective_mtu: u16,
|
||||
/// Whether to generate ICMP too-big for oversized packets.
|
||||
pub send_icmp_too_big: bool,
|
||||
/// Counter: oversized packets encountered.
|
||||
pub oversized_packets: u64,
|
||||
/// Counter: ICMP too-big messages generated.
|
||||
pub icmp_too_big_sent: u64,
|
||||
}
|
||||
|
||||
impl MtuConfig {
|
||||
/// Create a new MTU config from the underlying link MTU.
|
||||
pub fn new(link_mtu: u16) -> Self {
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let effective = overhead.effective_tun_mtu(link_mtu);
|
||||
Self {
|
||||
link_mtu,
|
||||
effective_mtu: effective,
|
||||
send_icmp_too_big: true,
|
||||
oversized_packets: 0,
|
||||
icmp_too_big_sent: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a packet exceeds the effective MTU.
|
||||
pub fn is_oversized(&self, packet_len: usize) -> bool {
|
||||
packet_len > self.effective_mtu as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Action to take after checking MTU.
|
||||
pub enum MtuAction {
|
||||
/// Packet is within MTU, forward normally.
|
||||
Forward,
|
||||
/// Packet is oversized; contains the ICMP too-big message to write back into TUN.
|
||||
SendIcmpTooBig(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Check packet against MTU config and return the appropriate action.
|
||||
pub fn check_mtu(packet: &[u8], config: &MtuConfig) -> MtuAction {
|
||||
if !config.is_oversized(packet.len()) {
|
||||
return MtuAction::Forward;
|
||||
}
|
||||
|
||||
if !config.send_icmp_too_big {
|
||||
return MtuAction::Forward;
|
||||
}
|
||||
|
||||
match generate_icmp_too_big(packet, config.effective_mtu) {
|
||||
Some(icmp) => MtuAction::SendIcmpTooBig(icmp),
|
||||
None => MtuAction::Forward,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate an ICMPv4 Destination Unreachable / Fragmentation Needed message.
|
||||
///
|
||||
/// Per RFC 792: Type 3, Code 4, with next-hop MTU in bytes 6-7 (RFC 1191).
|
||||
/// Returns the complete IP + ICMP packet to write back into the TUN device.
|
||||
pub fn generate_icmp_too_big(original_packet: &[u8], next_hop_mtu: u16) -> Option<Vec<u8>> {
|
||||
// Need at least 20 bytes of original IP header
|
||||
if original_packet.len() < 20 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Verify it's IPv4
|
||||
if original_packet[0] >> 4 != 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Parse source/dest from original IP header
|
||||
let src_ip = Ipv4Addr::new(
|
||||
original_packet[12],
|
||||
original_packet[13],
|
||||
original_packet[14],
|
||||
original_packet[15],
|
||||
);
|
||||
let dst_ip = Ipv4Addr::new(
|
||||
original_packet[16],
|
||||
original_packet[17],
|
||||
original_packet[18],
|
||||
original_packet[19],
|
||||
);
|
||||
|
||||
// ICMP payload: IP header + first 8 bytes of original datagram (per RFC 792)
|
||||
let icmp_data_len = original_packet.len().min(28); // 20 IP header + 8 bytes
|
||||
let icmp_payload = &original_packet[..icmp_data_len];
|
||||
|
||||
// Build ICMP message: type(1) + code(1) + checksum(2) + unused(2) + next_hop_mtu(2) + data
|
||||
let mut icmp = Vec::with_capacity(8 + icmp_data_len);
|
||||
icmp.push(3); // Type: Destination Unreachable
|
||||
icmp.push(4); // Code: Fragmentation Needed and DF was Set
|
||||
icmp.push(0); // Checksum placeholder
|
||||
icmp.push(0);
|
||||
icmp.push(0); // Unused
|
||||
icmp.push(0);
|
||||
icmp.extend_from_slice(&next_hop_mtu.to_be_bytes());
|
||||
icmp.extend_from_slice(icmp_payload);
|
||||
|
||||
// Compute ICMP checksum
|
||||
let cksum = internet_checksum(&icmp);
|
||||
icmp[2] = (cksum >> 8) as u8;
|
||||
icmp[3] = (cksum & 0xff) as u8;
|
||||
|
||||
// Build IP header (ICMP response: FROM tunnel gateway TO original source)
|
||||
let total_len = (20 + icmp.len()) as u16;
|
||||
let mut ip = Vec::with_capacity(total_len as usize);
|
||||
ip.push(0x45); // Version 4, IHL 5
|
||||
ip.push(0x00); // DSCP/ECN
|
||||
ip.extend_from_slice(&total_len.to_be_bytes());
|
||||
ip.extend_from_slice(&[0, 0]); // Identification
|
||||
ip.extend_from_slice(&[0x40, 0x00]); // Flags: Don't Fragment, Fragment Offset: 0
|
||||
ip.push(64); // TTL
|
||||
ip.push(1); // Protocol: ICMP
|
||||
ip.extend_from_slice(&[0, 0]); // Header checksum placeholder
|
||||
ip.extend_from_slice(&dst_ip.octets()); // Source: tunnel endpoint (was dst)
|
||||
ip.extend_from_slice(&src_ip.octets()); // Destination: original source
|
||||
|
||||
// Compute IP header checksum
|
||||
let ip_cksum = internet_checksum(&ip[..20]);
|
||||
ip[10] = (ip_cksum >> 8) as u8;
|
||||
ip[11] = (ip_cksum & 0xff) as u8;
|
||||
|
||||
ip.extend_from_slice(&icmp);
|
||||
Some(ip)
|
||||
}
|
||||
|
||||
/// Standard Internet checksum (RFC 1071).
|
||||
fn internet_checksum(data: &[u8]) -> u16 {
|
||||
let mut sum: u32 = 0;
|
||||
let mut i = 0;
|
||||
while i + 1 < data.len() {
|
||||
sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
|
||||
i += 2;
|
||||
}
|
||||
if i < data.len() {
|
||||
sum += (data[i] as u32) << 8;
|
||||
}
|
||||
while sum >> 16 != 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16);
|
||||
}
|
||||
!sum as u16
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_overhead_total() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
assert_eq!(oh.total(), 79); // 20+32+6+5+16
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_mtu_for_ethernet() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
let mtu = oh.effective_tun_mtu(1500);
|
||||
assert_eq!(mtu, 1421); // 1500 - 79
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_mtu_saturates_at_zero() {
|
||||
let oh = TunnelOverhead::default_overhead();
|
||||
let mtu = oh.effective_tun_mtu(50); // Less than overhead
|
||||
assert_eq!(mtu, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mtu_config_default() {
|
||||
let config = MtuConfig::new(1500);
|
||||
assert_eq!(config.effective_mtu, 1421);
|
||||
assert_eq!(config.link_mtu, 1500);
|
||||
assert!(config.send_icmp_too_big);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_oversized() {
|
||||
let config = MtuConfig::new(1500);
|
||||
assert!(!config.is_oversized(1421));
|
||||
assert!(config.is_oversized(1422));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_generation() {
|
||||
// Craft a minimal IPv4 packet
|
||||
let mut original = vec![0u8; 28];
|
||||
original[0] = 0x45; // version 4, IHL 5
|
||||
original[2..4].copy_from_slice(&1500u16.to_be_bytes()); // total length
|
||||
original[9] = 6; // TCP
|
||||
original[12..16].copy_from_slice(&[10, 0, 0, 1]); // src IP
|
||||
original[16..20].copy_from_slice(&[10, 0, 0, 2]); // dst IP
|
||||
|
||||
let icmp_pkt = generate_icmp_too_big(&original, 1421).unwrap();
|
||||
|
||||
// Verify it's a valid IPv4 packet
|
||||
assert_eq!(icmp_pkt[0] >> 4, 4); // IPv4
|
||||
assert_eq!(icmp_pkt[9], 1); // ICMP protocol
|
||||
|
||||
// Source should be original dst (10.0.0.2)
|
||||
assert_eq!(&icmp_pkt[12..16], &[10, 0, 0, 2]);
|
||||
// Destination should be original src (10.0.0.1)
|
||||
assert_eq!(&icmp_pkt[16..20], &[10, 0, 0, 1]);
|
||||
|
||||
// ICMP type 3, code 4
|
||||
assert_eq!(icmp_pkt[20], 3);
|
||||
assert_eq!(icmp_pkt[21], 4);
|
||||
|
||||
// Next-hop MTU at ICMP bytes 6-7 (offset 26-27 in IP packet)
|
||||
let mtu = u16::from_be_bytes([icmp_pkt[26], icmp_pkt[27]]);
|
||||
assert_eq!(mtu, 1421);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_rejects_short_packet() {
|
||||
let short = vec![0u8; 10];
|
||||
assert!(generate_icmp_too_big(&short, 1421).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_too_big_rejects_non_ipv4() {
|
||||
let mut pkt = vec![0u8; 40];
|
||||
pkt[0] = 0x60; // IPv6
|
||||
assert!(generate_icmp_too_big(&pkt, 1421).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn icmp_checksum_valid() {
|
||||
let mut original = vec![0u8; 28];
|
||||
original[0] = 0x45;
|
||||
original[2..4].copy_from_slice(&1500u16.to_be_bytes());
|
||||
original[9] = 6;
|
||||
original[12..16].copy_from_slice(&[192, 168, 1, 100]);
|
||||
original[16..20].copy_from_slice(&[10, 8, 0, 1]);
|
||||
|
||||
let icmp_pkt = generate_icmp_too_big(&original, 1420).unwrap();
|
||||
|
||||
// Verify IP header checksum
|
||||
let ip_cksum = internet_checksum(&icmp_pkt[..20]);
|
||||
assert_eq!(ip_cksum, 0, "IP header checksum should verify to 0");
|
||||
|
||||
// Verify ICMP checksum
|
||||
let icmp_cksum = internet_checksum(&icmp_pkt[20..]);
|
||||
assert_eq!(icmp_cksum, 0, "ICMP checksum should verify to 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_mtu_forward() {
|
||||
let config = MtuConfig::new(1500);
|
||||
let pkt = vec![0u8; 1421]; // Exactly at MTU
|
||||
assert!(matches!(check_mtu(&pkt, &config), MtuAction::Forward));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_mtu_oversized_generates_icmp() {
|
||||
let config = MtuConfig::new(1500);
|
||||
let mut pkt = vec![0u8; 1500];
|
||||
pkt[0] = 0x45; // Valid IPv4
|
||||
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
|
||||
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
|
||||
|
||||
match check_mtu(&pkt, &config) {
|
||||
MtuAction::SendIcmpTooBig(icmp) => {
|
||||
assert_eq!(icmp[20], 3); // ICMP type
|
||||
assert_eq!(icmp[21], 4); // ICMP code
|
||||
}
|
||||
MtuAction::Forward => panic!("Expected SendIcmpTooBig"),
|
||||
}
|
||||
}
|
||||
}
|
||||
261
rust/src/proxy_protocol.rs
Normal file
261
rust/src/proxy_protocol.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
//! PROXY protocol v2 parser for extracting real client addresses
|
||||
//! when SmartVPN sits behind a reverse proxy (HAProxy, SmartProxy, etc.).
|
||||
//!
|
||||
//! Spec: <https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt>
|
||||
|
||||
use anyhow::Result;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::time::Duration;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// Timeout for reading the PROXY protocol header from a new connection.
|
||||
const PROXY_HEADER_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// The 12-byte PP v2 signature.
|
||||
const PP_V2_SIGNATURE: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
/// Parsed PROXY protocol v2 header.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyHeader {
|
||||
/// Real client source address.
|
||||
pub src_addr: SocketAddr,
|
||||
/// Proxy-to-server destination address.
|
||||
pub dst_addr: SocketAddr,
|
||||
/// True if this is a LOCAL command (health check probe from proxy).
|
||||
pub is_local: bool,
|
||||
}
|
||||
|
||||
/// Read and parse a PROXY protocol v2 header from a TCP stream.
|
||||
///
|
||||
/// Reads exactly the header bytes — the stream is in a clean state for
|
||||
/// WebSocket upgrade afterward. Returns an error on timeout, invalid
|
||||
/// signature, or malformed header.
|
||||
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
||||
tokio::time::timeout(PROXY_HEADER_TIMEOUT, read_proxy_header_inner(stream))
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("PROXY protocol header read timed out ({}s)", PROXY_HEADER_TIMEOUT.as_secs()))?
|
||||
}
|
||||
|
||||
async fn read_proxy_header_inner(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
||||
// Read the 16-byte fixed prefix
|
||||
let mut prefix = [0u8; 16];
|
||||
stream.read_exact(&mut prefix).await?;
|
||||
|
||||
// Validate the 12-byte signature
|
||||
if prefix[..12] != PP_V2_SIGNATURE {
|
||||
anyhow::bail!("Invalid PROXY protocol v2 signature");
|
||||
}
|
||||
|
||||
// Byte 12: version (high nibble) | command (low nibble)
|
||||
let version = (prefix[12] & 0xF0) >> 4;
|
||||
let command = prefix[12] & 0x0F;
|
||||
|
||||
if version != 2 {
|
||||
anyhow::bail!("Unsupported PROXY protocol version: {}", version);
|
||||
}
|
||||
|
||||
// Byte 13: address family (high nibble) | protocol (low nibble)
|
||||
let addr_family = (prefix[13] & 0xF0) >> 4;
|
||||
let _protocol = prefix[13] & 0x0F; // 1 = STREAM (TCP)
|
||||
|
||||
// Bytes 14-15: address data length (big-endian)
|
||||
let addr_len = u16::from_be_bytes([prefix[14], prefix[15]]) as usize;
|
||||
|
||||
// Read the address data
|
||||
let mut addr_data = vec![0u8; addr_len];
|
||||
if addr_len > 0 {
|
||||
stream.read_exact(&mut addr_data).await?;
|
||||
}
|
||||
|
||||
// LOCAL command (0x00) = health check, no real address
|
||||
if command == 0x00 {
|
||||
return Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
dst_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
is_local: true,
|
||||
});
|
||||
}
|
||||
|
||||
// PROXY command (0x01) — parse address block
|
||||
if command != 0x01 {
|
||||
anyhow::bail!("Unknown PROXY protocol command: {}", command);
|
||||
}
|
||||
|
||||
match addr_family {
|
||||
// AF_INET (IPv4): 4 src + 4 dst + 2 src_port + 2 dst_port = 12 bytes
|
||||
1 => {
|
||||
if addr_data.len() < 12 {
|
||||
anyhow::bail!("IPv4 address block too short: {} bytes", addr_data.len());
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]);
|
||||
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
|
||||
Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)),
|
||||
dst_addr: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)),
|
||||
is_local: false,
|
||||
})
|
||||
}
|
||||
// AF_INET6 (IPv6): 16 src + 16 dst + 2 src_port + 2 dst_port = 36 bytes
|
||||
2 => {
|
||||
if addr_data.len() < 36 {
|
||||
anyhow::bail!("IPv6 address block too short: {} bytes", addr_data.len());
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap());
|
||||
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
|
||||
Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V6(SocketAddrV6::new(src_ip, src_port, 0, 0)),
|
||||
dst_addr: SocketAddr::V6(SocketAddrV6::new(dst_ip, dst_port, 0, 0)),
|
||||
is_local: false,
|
||||
})
|
||||
}
|
||||
// AF_UNSPEC or unknown
|
||||
_ => {
|
||||
anyhow::bail!("Unsupported address family: {}", addr_family);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a PROXY protocol v2 header (for testing / proxy implementations).
|
||||
pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
||||
|
||||
match (src, dst) {
|
||||
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
|
||||
buf.push(0x21); // version 2 | PROXY command
|
||||
buf.push(0x11); // AF_INET | STREAM
|
||||
buf.extend_from_slice(&12u16.to_be_bytes()); // addr length
|
||||
buf.extend_from_slice(&s.ip().octets());
|
||||
buf.extend_from_slice(&d.ip().octets());
|
||||
buf.extend_from_slice(&s.port().to_be_bytes());
|
||||
buf.extend_from_slice(&d.port().to_be_bytes());
|
||||
}
|
||||
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
|
||||
buf.push(0x21); // version 2 | PROXY command
|
||||
buf.push(0x21); // AF_INET6 | STREAM
|
||||
buf.extend_from_slice(&36u16.to_be_bytes()); // addr length
|
||||
buf.extend_from_slice(&s.ip().octets());
|
||||
buf.extend_from_slice(&d.ip().octets());
|
||||
buf.extend_from_slice(&s.port().to_be_bytes());
|
||||
buf.extend_from_slice(&d.port().to_be_bytes());
|
||||
}
|
||||
_ => panic!("Mismatched address families"),
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a PROXY protocol v2 LOCAL header (health check probe).
|
||||
pub fn build_pp_v2_local() -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
||||
buf.push(0x20); // version 2 | LOCAL command
|
||||
buf.push(0x00); // AF_UNSPEC
|
||||
buf.extend_from_slice(&0u16.to_be_bytes()); // no address data
|
||||
buf
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Helper: create a TCP pair and write data to the client side, then parse from server side.
|
||||
async fn parse_header_from_bytes(header_bytes: &[u8]) -> Result<ProxyHeader> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let data = header_bytes.to_vec();
|
||||
let client_task = tokio::spawn(async move {
|
||||
let mut client = TcpStream::connect(addr).await.unwrap();
|
||||
client.write_all(&data).await.unwrap();
|
||||
client // keep alive
|
||||
});
|
||||
|
||||
let (mut server_stream, _) = listener.accept().await.unwrap();
|
||||
let result = read_proxy_header(&mut server_stream).await;
|
||||
let _client = client_task.await.unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_valid_ipv4_header() {
|
||||
let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
|
||||
let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(!parsed.is_local);
|
||||
assert_eq!(parsed.src_addr, src);
|
||||
assert_eq!(parsed.dst_addr, dst);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_valid_ipv6_header() {
|
||||
let src = "[2001:db8::1]:54321".parse::<SocketAddr>().unwrap();
|
||||
let dst = "[2001:db8::2]:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(!parsed.is_local);
|
||||
assert_eq!(parsed.src_addr, src);
|
||||
assert_eq!(parsed.dst_addr, dst);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_local_command() {
|
||||
let header = build_pp_v2_local();
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(parsed.is_local);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_invalid_signature() {
|
||||
let mut header = build_pp_v2_local();
|
||||
header[0] = 0xFF; // corrupt signature
|
||||
let result = parse_header_from_bytes(&header).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("signature"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_wrong_version() {
|
||||
let mut header = build_pp_v2_local();
|
||||
header[12] = 0x10; // version 1 instead of 2
|
||||
let result = parse_header_from_bytes(&header).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("version"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_truncated_header() {
|
||||
// Only 10 bytes — not even the full signature
|
||||
let result = parse_header_from_bytes(&[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49]).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv4_header_is_exactly_28_bytes() {
|
||||
let src = "1.2.3.4:80".parse::<SocketAddr>().unwrap();
|
||||
let dst = "5.6.7.8:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 12 addrs = 28
|
||||
assert_eq!(header.len(), 28);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv6_header_is_exactly_52_bytes() {
|
||||
let src = "[::1]:80".parse::<SocketAddr>().unwrap();
|
||||
let dst = "[::2]:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 36 addrs = 52
|
||||
assert_eq!(header.len(), 52);
|
||||
}
|
||||
}
|
||||
490
rust/src/qos.rs
Normal file
490
rust/src/qos.rs
Normal file
@@ -0,0 +1,490 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Priority levels for IP packets.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[repr(u8)]
|
||||
pub enum Priority {
|
||||
High = 0,
|
||||
Normal = 1,
|
||||
Low = 2,
|
||||
}
|
||||
|
||||
/// QoS statistics per priority level.
|
||||
pub struct QosStats {
|
||||
pub high_enqueued: AtomicU64,
|
||||
pub normal_enqueued: AtomicU64,
|
||||
pub low_enqueued: AtomicU64,
|
||||
pub high_dropped: AtomicU64,
|
||||
pub normal_dropped: AtomicU64,
|
||||
pub low_dropped: AtomicU64,
|
||||
}
|
||||
|
||||
impl QosStats {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
high_enqueued: AtomicU64::new(0),
|
||||
normal_enqueued: AtomicU64::new(0),
|
||||
low_enqueued: AtomicU64::new(0),
|
||||
high_dropped: AtomicU64::new(0),
|
||||
normal_dropped: AtomicU64::new(0),
|
||||
low_dropped: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QosStats {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Packet classification
|
||||
// ============================================================================
|
||||
|
||||
/// 5-tuple flow key for tracking bulk flows.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct FlowKey {
|
||||
src_ip: u32,
|
||||
dst_ip: u32,
|
||||
src_port: u16,
|
||||
dst_port: u16,
|
||||
protocol: u8,
|
||||
}
|
||||
|
||||
/// Per-flow state for bulk detection.
|
||||
struct FlowState {
|
||||
bytes_total: u64,
|
||||
window_start: Instant,
|
||||
}
|
||||
|
||||
/// Tracks per-flow byte counts for bulk flow detection.
|
||||
struct FlowTracker {
|
||||
flows: HashMap<FlowKey, FlowState>,
|
||||
window_duration: Duration,
|
||||
max_flows: usize,
|
||||
}
|
||||
|
||||
impl FlowTracker {
|
||||
fn new(window_duration: Duration, max_flows: usize) -> Self {
|
||||
Self {
|
||||
flows: HashMap::new(),
|
||||
window_duration,
|
||||
max_flows,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes for a flow. Returns true if the flow exceeds the threshold.
|
||||
fn record(&mut self, key: FlowKey, bytes: u64, threshold: u64) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
// Evict if at capacity — remove oldest entry
|
||||
if self.flows.len() >= self.max_flows && !self.flows.contains_key(&key) {
|
||||
if let Some(oldest_key) = self
|
||||
.flows
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.window_start)
|
||||
.map(|(k, _)| *k)
|
||||
{
|
||||
self.flows.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
|
||||
let state = self.flows.entry(key).or_insert(FlowState {
|
||||
bytes_total: 0,
|
||||
window_start: now,
|
||||
});
|
||||
|
||||
// Reset window if expired
|
||||
if now.duration_since(state.window_start) > self.window_duration {
|
||||
state.bytes_total = 0;
|
||||
state.window_start = now;
|
||||
}
|
||||
|
||||
state.bytes_total += bytes;
|
||||
state.bytes_total > threshold
|
||||
}
|
||||
}
|
||||
|
||||
/// Classifies raw IP packets into priority levels.
|
||||
pub struct PacketClassifier {
|
||||
flow_tracker: FlowTracker,
|
||||
/// Byte threshold for classifying a flow as "bulk" (Low priority).
|
||||
bulk_threshold_bytes: u64,
|
||||
}
|
||||
|
||||
impl PacketClassifier {
|
||||
/// Create a new classifier.
|
||||
///
|
||||
/// - `bulk_threshold_bytes`: bytes per flow within window to trigger Low priority (default: 1MB)
|
||||
pub fn new(bulk_threshold_bytes: u64) -> Self {
|
||||
Self {
|
||||
flow_tracker: FlowTracker::new(Duration::from_secs(60), 10_000),
|
||||
bulk_threshold_bytes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify a raw IPv4 packet into a priority level.
|
||||
///
|
||||
/// The packet must start with the IPv4 header (as read from a TUN device).
|
||||
pub fn classify(&mut self, ip_packet: &[u8]) -> Priority {
|
||||
// Need at least 20 bytes for a minimal IPv4 header
|
||||
if ip_packet.len() < 20 {
|
||||
return Priority::Normal;
|
||||
}
|
||||
|
||||
let version = ip_packet[0] >> 4;
|
||||
if version != 4 {
|
||||
return Priority::Normal; // Only classify IPv4 for now
|
||||
}
|
||||
|
||||
let ihl = (ip_packet[0] & 0x0F) as usize;
|
||||
let header_len = ihl * 4;
|
||||
let protocol = ip_packet[9];
|
||||
let total_len = u16::from_be_bytes([ip_packet[2], ip_packet[3]]) as usize;
|
||||
|
||||
// ICMP is always high priority
|
||||
if protocol == 1 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Small packets (<128 bytes) are high priority (likely interactive)
|
||||
if total_len < 128 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Extract ports for TCP (6) and UDP (17)
|
||||
let (src_port, dst_port) = if (protocol == 6 || protocol == 17)
|
||||
&& ip_packet.len() >= header_len + 4
|
||||
{
|
||||
let sp = u16::from_be_bytes([ip_packet[header_len], ip_packet[header_len + 1]]);
|
||||
let dp = u16::from_be_bytes([ip_packet[header_len + 2], ip_packet[header_len + 3]]);
|
||||
(sp, dp)
|
||||
} else {
|
||||
(0, 0)
|
||||
};
|
||||
|
||||
// DNS (port 53) and SSH (port 22) are high priority
|
||||
if dst_port == 53 || src_port == 53 || dst_port == 22 || src_port == 22 {
|
||||
return Priority::High;
|
||||
}
|
||||
|
||||
// Check for bulk flow
|
||||
if protocol == 6 || protocol == 17 {
|
||||
let src_ip = u32::from_be_bytes([ip_packet[12], ip_packet[13], ip_packet[14], ip_packet[15]]);
|
||||
let dst_ip = u32::from_be_bytes([ip_packet[16], ip_packet[17], ip_packet[18], ip_packet[19]]);
|
||||
|
||||
let key = FlowKey {
|
||||
src_ip,
|
||||
dst_ip,
|
||||
src_port,
|
||||
dst_port,
|
||||
protocol,
|
||||
};
|
||||
|
||||
if self.flow_tracker.record(key, total_len as u64, self.bulk_threshold_bytes) {
|
||||
return Priority::Low;
|
||||
}
|
||||
}
|
||||
|
||||
Priority::Normal
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Priority channel set
|
||||
// ============================================================================
|
||||
|
||||
/// Error returned when a packet is dropped.
|
||||
#[derive(Debug)]
|
||||
pub enum PacketDropped {
|
||||
LowPriorityDrop,
|
||||
NormalPriorityDrop,
|
||||
HighPriorityDrop,
|
||||
ChannelClosed,
|
||||
}
|
||||
|
||||
/// Sending half of the priority channel set.
|
||||
pub struct PrioritySender {
|
||||
high_tx: mpsc::Sender<Vec<u8>>,
|
||||
normal_tx: mpsc::Sender<Vec<u8>>,
|
||||
low_tx: mpsc::Sender<Vec<u8>>,
|
||||
stats: Arc<QosStats>,
|
||||
}
|
||||
|
||||
impl PrioritySender {
|
||||
/// Send a packet with the given priority. Implements smart dropping under backpressure.
|
||||
pub async fn send(&self, packet: Vec<u8>, priority: Priority) -> Result<(), PacketDropped> {
|
||||
let (tx, enqueued_counter) = match priority {
|
||||
Priority::High => (&self.high_tx, &self.stats.high_enqueued),
|
||||
Priority::Normal => (&self.normal_tx, &self.stats.normal_enqueued),
|
||||
Priority::Low => (&self.low_tx, &self.stats.low_enqueued),
|
||||
};
|
||||
|
||||
match tx.try_send(packet) {
|
||||
Ok(()) => {
|
||||
enqueued_counter.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(packet)) => {
|
||||
self.handle_backpressure(packet, priority).await
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => Err(PacketDropped::ChannelClosed),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_backpressure(
|
||||
&self,
|
||||
packet: Vec<u8>,
|
||||
priority: Priority,
|
||||
) -> Result<(), PacketDropped> {
|
||||
match priority {
|
||||
Priority::Low => {
|
||||
self.stats.low_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::LowPriorityDrop)
|
||||
}
|
||||
Priority::Normal => {
|
||||
self.stats.normal_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::NormalPriorityDrop)
|
||||
}
|
||||
Priority::High => {
|
||||
// Last resort: briefly wait for space, then drop
|
||||
match tokio::time::timeout(
|
||||
Duration::from_millis(5),
|
||||
self.high_tx.send(packet),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(())) => {
|
||||
self.stats.high_enqueued.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
self.stats.high_dropped.fetch_add(1, Ordering::Relaxed);
|
||||
Err(PacketDropped::HighPriorityDrop)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Receiving half of the priority channel set.
|
||||
pub struct PriorityReceiver {
|
||||
high_rx: mpsc::Receiver<Vec<u8>>,
|
||||
normal_rx: mpsc::Receiver<Vec<u8>>,
|
||||
low_rx: mpsc::Receiver<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl PriorityReceiver {
|
||||
/// Receive the next packet, draining high-priority first (biased select).
|
||||
pub async fn recv(&mut self) -> Option<Vec<u8>> {
|
||||
tokio::select! {
|
||||
biased;
|
||||
Some(pkt) = self.high_rx.recv() => Some(pkt),
|
||||
Some(pkt) = self.normal_rx.recv() => Some(pkt),
|
||||
Some(pkt) = self.low_rx.recv() => Some(pkt),
|
||||
else => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a priority channel set split into sender and receiver halves.
|
||||
///
|
||||
/// - `high_cap`: capacity of the high-priority channel
|
||||
/// - `normal_cap`: capacity of the normal-priority channel
|
||||
/// - `low_cap`: capacity of the low-priority channel
|
||||
pub fn create_priority_channels(
|
||||
high_cap: usize,
|
||||
normal_cap: usize,
|
||||
low_cap: usize,
|
||||
) -> (PrioritySender, PriorityReceiver) {
|
||||
let (high_tx, high_rx) = mpsc::channel(high_cap);
|
||||
let (normal_tx, normal_rx) = mpsc::channel(normal_cap);
|
||||
let (low_tx, low_rx) = mpsc::channel(low_cap);
|
||||
let stats = Arc::new(QosStats::new());
|
||||
|
||||
let sender = PrioritySender {
|
||||
high_tx,
|
||||
normal_tx,
|
||||
low_tx,
|
||||
stats,
|
||||
};
|
||||
|
||||
let receiver = PriorityReceiver {
|
||||
high_rx,
|
||||
normal_rx,
|
||||
low_rx,
|
||||
};
|
||||
|
||||
(sender, receiver)
|
||||
}
|
||||
|
||||
/// Get a reference to the QoS stats from a sender.
|
||||
impl PrioritySender {
|
||||
pub fn stats(&self) -> &Arc<QosStats> {
|
||||
&self.stats
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper: craft a minimal IPv4 packet
|
||||
fn make_ipv4_packet(protocol: u8, src_port: u16, dst_port: u16, total_len: u16) -> Vec<u8> {
|
||||
let mut pkt = vec![0u8; total_len.max(24) as usize];
|
||||
pkt[0] = 0x45; // version 4, IHL 5
|
||||
pkt[2..4].copy_from_slice(&total_len.to_be_bytes());
|
||||
pkt[9] = protocol;
|
||||
// src IP
|
||||
pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
|
||||
// dst IP
|
||||
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
|
||||
// ports (at offset 20 for IHL=5)
|
||||
pkt[20..22].copy_from_slice(&src_port.to_be_bytes());
|
||||
pkt[22..24].copy_from_slice(&dst_port.to_be_bytes());
|
||||
pkt
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_icmp_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(1, 0, 0, 64); // ICMP
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_dns_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(17, 12345, 53, 200); // UDP to port 53
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_ssh_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 54321, 22, 200); // TCP to port 22
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_small_packet_as_high() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 12345, 8080, 64); // Small TCP packet
|
||||
assert_eq!(c.classify(&pkt), Priority::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_normal_http() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500); // TCP to port 80, >128B
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_bulk_flow_as_low() {
|
||||
let mut c = PacketClassifier::new(10_000); // Low threshold for testing
|
||||
|
||||
// Send enough traffic to exceed the threshold
|
||||
for _ in 0..100 {
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500);
|
||||
c.classify(&pkt);
|
||||
}
|
||||
|
||||
// Next packet from same flow should be Low
|
||||
let pkt = make_ipv4_packet(6, 12345, 80, 500);
|
||||
assert_eq!(c.classify(&pkt), Priority::Low);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_too_short_packet() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let pkt = vec![0u8; 10]; // Too short for IPv4 header
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_non_ipv4() {
|
||||
let mut c = PacketClassifier::new(1_000_000);
|
||||
let mut pkt = vec![0u8; 40];
|
||||
pkt[0] = 0x60; // IPv6 version nibble
|
||||
assert_eq!(c.classify(&pkt), Priority::Normal);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn priority_receiver_drains_high_first() {
|
||||
let (sender, mut receiver) = create_priority_channels(8, 8, 8);
|
||||
|
||||
// Enqueue in reverse order
|
||||
sender.send(vec![3], Priority::Low).await.unwrap();
|
||||
sender.send(vec![2], Priority::Normal).await.unwrap();
|
||||
sender.send(vec![1], Priority::High).await.unwrap();
|
||||
|
||||
// Should drain High first
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![1]);
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![2]);
|
||||
assert_eq!(receiver.recv().await.unwrap(), vec![3]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn smart_dropping_low_priority() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 8, 1);
|
||||
|
||||
// Fill the low channel
|
||||
sender.send(vec![0], Priority::Low).await.unwrap();
|
||||
|
||||
// Next low-priority send should be dropped
|
||||
let result = sender.send(vec![1], Priority::Low).await;
|
||||
assert!(matches!(result, Err(PacketDropped::LowPriorityDrop)));
|
||||
|
||||
assert_eq!(sender.stats().low_dropped.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn smart_dropping_normal_priority() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 1, 8);
|
||||
|
||||
sender.send(vec![0], Priority::Normal).await.unwrap();
|
||||
|
||||
let result = sender.send(vec![1], Priority::Normal).await;
|
||||
assert!(matches!(result, Err(PacketDropped::NormalPriorityDrop)));
|
||||
|
||||
assert_eq!(sender.stats().normal_dropped.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stats_track_enqueued() {
|
||||
let (sender, _receiver) = create_priority_channels(8, 8, 8);
|
||||
|
||||
sender.send(vec![1], Priority::High).await.unwrap();
|
||||
sender.send(vec![2], Priority::High).await.unwrap();
|
||||
sender.send(vec![3], Priority::Normal).await.unwrap();
|
||||
sender.send(vec![4], Priority::Low).await.unwrap();
|
||||
|
||||
assert_eq!(sender.stats().high_enqueued.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(sender.stats().normal_enqueued.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(sender.stats().low_enqueued.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flow_tracker_evicts_at_capacity() {
|
||||
let mut ft = FlowTracker::new(Duration::from_secs(60), 2);
|
||||
|
||||
let k1 = FlowKey { src_ip: 1, dst_ip: 2, src_port: 100, dst_port: 200, protocol: 6 };
|
||||
let k2 = FlowKey { src_ip: 3, dst_ip: 4, src_port: 300, dst_port: 400, protocol: 6 };
|
||||
let k3 = FlowKey { src_ip: 5, dst_ip: 6, src_port: 500, dst_port: 600, protocol: 6 };
|
||||
|
||||
ft.record(k1, 100, 1000);
|
||||
ft.record(k2, 100, 1000);
|
||||
// Should evict k1 (oldest)
|
||||
ft.record(k3, 100, 1000);
|
||||
|
||||
assert_eq!(ft.flows.len(), 2);
|
||||
assert!(!ft.flows.contains_key(&k1));
|
||||
}
|
||||
}
|
||||
546
rust/src/quic_transport.rs
Normal file
546
rust/src/quic_transport.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use quinn::crypto::rustls::QuicClientConfig;
|
||||
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn, debug};
|
||||
|
||||
use crate::transport_trait::{TransportSink, TransportStream};
|
||||
|
||||
// ============================================================================
|
||||
// TLS / Certificate helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Generate a self-signed certificate and private key for QUIC.
|
||||
pub fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["smartvpn".to_string()])?;
|
||||
let cert_der = CertificateDer::from(cert.cert);
|
||||
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()));
|
||||
Ok((vec![cert_der], key_der))
|
||||
}
|
||||
|
||||
/// Compute the SHA-256 hash of a DER-encoded certificate and return it as base64.
|
||||
pub fn cert_hash(cert_der: &CertificateDer<'_>) -> String {
|
||||
use ring::digest;
|
||||
let hash = digest::digest(&digest::SHA256, cert_der.as_ref());
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash.as_ref())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Server-side QUIC endpoint
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for the QUIC server endpoint.
|
||||
pub struct QuicServerConfig {
|
||||
pub listen_addr: String,
|
||||
pub cert_chain: Vec<CertificateDer<'static>>,
|
||||
pub private_key: PrivateKeyDer<'static>,
|
||||
pub idle_timeout_secs: u64,
|
||||
}
|
||||
|
||||
/// Create a QUIC server endpoint bound to the given address.
|
||||
pub fn create_quic_server(config: QuicServerConfig) -> Result<quinn::Endpoint> {
|
||||
let addr: SocketAddr = config.listen_addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let mut tls_config = rustls::ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(config.cert_chain, config.private_key)?;
|
||||
tls_config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
|
||||
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(
|
||||
quinn::IdleTimeout::try_from(Duration::from_secs(config.idle_timeout_secs))?,
|
||||
));
|
||||
// Enable datagrams with a generous max size
|
||||
transport.datagram_receive_buffer_size(Some(65535));
|
||||
transport.datagram_send_buffer_size(65535);
|
||||
server_config.transport_config(Arc::new(transport));
|
||||
|
||||
let endpoint = quinn::Endpoint::server(server_config, addr)?;
|
||||
info!("QUIC server listening on {}", addr);
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client-side QUIC connection
|
||||
// ============================================================================
|
||||
|
||||
/// A certificate verifier that accepts any server certificate.
|
||||
/// Safe when Noise NK provides server authentication at the application layer.
|
||||
#[derive(Debug)]
|
||||
struct AcceptAnyCert;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for AcceptAnyCert {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// A certificate verifier that accepts any certificate matching a given SHA-256 hash.
|
||||
#[derive(Debug)]
|
||||
struct CertHashVerifier {
|
||||
expected_hash: String,
|
||||
}
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for CertHashVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
let actual_hash = cert_hash(end_entity);
|
||||
if actual_hash == self.expected_hash {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
} else {
|
||||
Err(rustls::Error::General(format!(
|
||||
"Certificate hash mismatch: expected {}, got {}",
|
||||
self.expected_hash, actual_hash
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
// QUIC always uses TLS 1.3
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a QUIC server.
|
||||
///
|
||||
/// - If `server_cert_hash` is provided, verifies the server certificate matches
|
||||
/// the given SHA-256 hash (cert pinning).
|
||||
/// - If `server_cert_hash` is `None`, accepts any server certificate. This is
|
||||
/// safe because the Noise NK handshake (which runs over the QUIC stream)
|
||||
/// authenticates the server via its pre-shared public key — the same trust
|
||||
/// model as WireGuard.
|
||||
pub async fn connect_quic(
|
||||
addr: &str,
|
||||
server_cert_hash: Option<&str>,
|
||||
) -> Result<quinn::Connection> {
|
||||
let remote: SocketAddr = addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let tls_config = if let Some(hash) = server_cert_hash {
|
||||
// Pin to a specific certificate hash
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash.to_string(),
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
} else {
|
||||
// Accept any cert — Noise NK provides server authentication
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(AcceptAnyCert))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
};
|
||||
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse()?)?;
|
||||
endpoint.set_default_client_config(client_config);
|
||||
|
||||
info!("Connecting to QUIC server at {}", addr);
|
||||
let connection = endpoint.connect(remote, "smartvpn")?.await?;
|
||||
info!("QUIC connection established");
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QUIC Transport Sink / Stream implementations
|
||||
// ============================================================================
|
||||
|
||||
/// QUIC transport sink — wraps a SendStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportSink {
|
||||
send_stream: quinn::SendStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportSink {
|
||||
pub fn new(send_stream: quinn::SendStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
send_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for QuicTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// Length-prefix framing: [4-byte big-endian length][payload]
|
||||
let len = data.len() as u32;
|
||||
self.send_stream.write_all(&len.to_be_bytes()).await?;
|
||||
self.send_stream.write_all(&data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
let max_size = self.connection.max_datagram_size();
|
||||
match max_size {
|
||||
Some(max) if data.len() <= max => {
|
||||
self.connection.send_datagram(data.into())?;
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
// Datagram too large or datagrams disabled — fall back to reliable
|
||||
debug!("Datagram too large ({}B), falling back to reliable stream", data.len());
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.send_stream.finish()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// QUIC transport stream — wraps a RecvStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportStream {
|
||||
recv_stream: quinn::RecvStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportStream {
|
||||
pub fn new(recv_stream: quinn::RecvStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
recv_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for QuicTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// Read length prefix
|
||||
let mut len_buf = [0u8; 4];
|
||||
match self.recv_stream.read_exact(&mut len_buf).await {
|
||||
Ok(()) => {}
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None),
|
||||
Err(quinn::ReadExactError::ReadError(quinn::ReadError::ConnectionLost(e))) => {
|
||||
warn!("QUIC connection lost: {}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
|
||||
let len = u32::from_be_bytes(len_buf) as usize;
|
||||
if len > 65536 {
|
||||
return Err(anyhow::anyhow!("Frame too large: {} bytes", len));
|
||||
}
|
||||
|
||||
let mut data = vec![0u8; len];
|
||||
match self.recv_stream.read_exact(&mut data).await {
|
||||
Ok(()) => Ok(Some(data)),
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
match self.connection.read_datagram().await {
|
||||
Ok(data) => Ok(Some(data.to_vec())),
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => Ok(None),
|
||||
Err(quinn::ConnectionError::LocallyClosed) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC datagram error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
self.connection.max_datagram_size().is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Accept a QUIC connection and open a bidirectional control stream.
|
||||
/// Returns the transport sink/stream pair ready for the VPN handshake.
|
||||
pub async fn accept_quic_connection(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
// The client opens the bidirectional control stream
|
||||
let (send, recv) = conn.accept_bi().await?;
|
||||
info!("QUIC bidirectional control stream accepted");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
/// Open a QUIC connection's bidirectional control stream (client side).
|
||||
pub async fn open_quic_streams(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
let (send, recv) = conn.open_bi().await?;
|
||||
info!("QUIC bidirectional control stream opened");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cert_generation_and_hash() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
assert_eq!(certs.len(), 1);
|
||||
let hash = cert_hash(&certs[0]);
|
||||
// SHA-256 base64 is 44 characters
|
||||
assert_eq!(hash.len(), 44);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cert_hash_deterministic() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
let hash1 = cert_hash(&certs[0]);
|
||||
let hash2 = cert_hash(&certs[0]);
|
||||
assert_eq!(hash1, hash2);
|
||||
}
|
||||
|
||||
/// Helper: create QUIC server and client endpoints.
|
||||
fn create_quic_endpoints() -> (quinn::Endpoint, quinn::Endpoint, String) {
|
||||
let (certs, key) = generate_self_signed_cert().unwrap();
|
||||
let hash = cert_hash(&certs[0]);
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
|
||||
let mut server_tls = rustls::ServerConfig::builder_with_provider(provider.clone())
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key).unwrap();
|
||||
server_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let server_qcfg = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(server_tls).unwrap(),
|
||||
));
|
||||
let server_ep = quinn::Endpoint::server(server_qcfg, "127.0.0.1:0".parse().unwrap()).unwrap();
|
||||
|
||||
let mut client_tls = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash,
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
client_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(client_tls).unwrap(),
|
||||
));
|
||||
let mut client_ep = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
|
||||
client_ep.set_default_client_config(client_config);
|
||||
|
||||
let server_addr = server_ep.local_addr().unwrap().to_string();
|
||||
(server_ep, client_ep, server_addr)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_server_client_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi, read, echo, finish
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (mut s_send, mut s_recv) = conn.accept_bi().await.unwrap();
|
||||
let data = s_recv.read_to_end(1024).await.unwrap();
|
||||
s_send.write_all(&data).await.unwrap();
|
||||
s_send.finish().unwrap();
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open_bi, write, finish, read
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_send, mut c_recv) = conn.open_bi().await.unwrap();
|
||||
c_send.write_all(b"hello quinn").await.unwrap();
|
||||
c_send.finish().unwrap();
|
||||
let data = c_recv.read_to_end(1024).await.unwrap();
|
||||
assert_eq!(&data[..], b"hello quinn");
|
||||
|
||||
let _ = server_task.await;
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test transport trait wrappers over QUIC.
|
||||
/// Key: client must send data first (QUIC streams are opened implicitly by data).
|
||||
/// The server accept_bi runs concurrently with the client's first send_reliable.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_transport_trait_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server task: accept connection, then accept_bi (blocks until client sends data)
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
(s_sink, s_stream, server_ep)
|
||||
});
|
||||
|
||||
// Client: connect, open_bi via wrapper
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, mut c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
|
||||
// Client sends first — this triggers the QUIC stream to become visible to the server
|
||||
c_sink.send_reliable(b"hello-from-client".to_vec()).await.unwrap();
|
||||
|
||||
// Now server's accept_bi unblocks
|
||||
let (mut s_sink, mut s_stream, _sep) = server_task.await.unwrap();
|
||||
|
||||
// Server reads the message
|
||||
let msg = s_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-client");
|
||||
|
||||
// Server -> Client
|
||||
s_sink.send_reliable(b"hello-from-server".to_vec()).await.unwrap();
|
||||
let msg = c_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-server");
|
||||
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test QUIC datagram support.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_datagram_exchange() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi (opens control stream), then read datagram
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
// Accept bi stream (control channel)
|
||||
let (_s_sink, _s_stream) = accept_quic_connection(conn.clone()).await.unwrap();
|
||||
// Read datagram
|
||||
let dgram = conn.read_datagram().await.unwrap();
|
||||
assert_eq!(&dgram[..], b"dgram-payload");
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open bi stream (triggers server accept_bi), then send datagram
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, _c_stream) = open_quic_streams(conn.clone()).await.unwrap();
|
||||
|
||||
// Send initial data to open the stream (required for QUIC)
|
||||
c_sink.send_reliable(b"init".to_vec()).await.unwrap();
|
||||
|
||||
// Small yield to let the server process the bi stream
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
// Send datagram
|
||||
assert!(conn.max_datagram_size().is_some());
|
||||
conn.send_datagram(bytes::Bytes::from_static(b"dgram-payload")).unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test that supports_datagrams returns true for QUIC transports.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_supports_datagrams() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (_s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
assert!(s_stream.supports_datagrams());
|
||||
server_ep
|
||||
});
|
||||
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
assert!(c_stream.supports_datagrams());
|
||||
|
||||
// Send data to trigger server's accept_bi
|
||||
c_sink.send_reliable(b"ping".to_vec()).await.unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
}
|
||||
141
rust/src/ratelimit.rs
Normal file
141
rust/src/ratelimit.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use std::time::Instant;
|
||||
|
||||
/// A token bucket rate limiter operating on byte granularity.
|
||||
pub struct TokenBucket {
|
||||
/// Tokens (bytes) added per second.
|
||||
rate: f64,
|
||||
/// Maximum burst capacity in bytes.
|
||||
burst: f64,
|
||||
/// Currently available tokens.
|
||||
tokens: f64,
|
||||
/// Last time tokens were refilled.
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
/// Create a new token bucket.
|
||||
///
|
||||
/// - `rate_bytes_per_sec`: sustained rate in bytes/second
|
||||
/// - `burst_bytes`: maximum burst size in bytes (also the initial token count)
|
||||
pub fn new(rate_bytes_per_sec: u64, burst_bytes: u64) -> Self {
|
||||
let burst = burst_bytes as f64;
|
||||
Self {
|
||||
rate: rate_bytes_per_sec as f64,
|
||||
burst,
|
||||
tokens: burst, // start full
|
||||
last_refill: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to consume `bytes` tokens. Returns `true` if allowed, `false` if rate exceeded.
|
||||
pub fn try_consume(&mut self, bytes: usize) -> bool {
|
||||
self.refill();
|
||||
let needed = bytes as f64;
|
||||
if needed <= self.tokens {
|
||||
self.tokens -= needed;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Update rate and burst limits dynamically (for live IPC reconfiguration).
|
||||
pub fn update_limits(&mut self, rate_bytes_per_sec: u64, burst_bytes: u64) {
|
||||
self.rate = rate_bytes_per_sec as f64;
|
||||
self.burst = burst_bytes as f64;
|
||||
// Cap current tokens at new burst
|
||||
if self.tokens > self.burst {
|
||||
self.tokens = self.burst;
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self) {
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
|
||||
self.last_refill = now;
|
||||
self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn allows_traffic_under_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
|
||||
// Should allow up to burst size immediately
|
||||
assert!(tb.try_consume(1_500_000));
|
||||
assert!(tb.try_consume(400_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blocks_traffic_over_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
|
||||
// Consume entire burst
|
||||
assert!(tb.try_consume(1_000_000));
|
||||
// Next consume should fail (no time to refill)
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_consume_always_succeeds() {
|
||||
let mut tb = TokenBucket::new(0, 0);
|
||||
assert!(tb.try_consume(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refills_over_time() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000); // 1MB/s, 1MB burst
|
||||
// Drain completely
|
||||
assert!(tb.try_consume(1_000_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
|
||||
// Wait 100ms — should refill ~100KB
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
assert!(tb.try_consume(50_000)); // 50KB should be available after ~100ms at 1MB/s
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_limits_caps_tokens() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 2_000_000);
|
||||
// Tokens start at burst (2MB)
|
||||
tb.update_limits(500_000, 500_000);
|
||||
// Tokens should be capped to new burst (500KB)
|
||||
assert!(tb.try_consume(500_000));
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_limits_changes_rate() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000_000);
|
||||
assert!(tb.try_consume(1_000_000)); // drain
|
||||
|
||||
// Change to higher rate
|
||||
tb.update_limits(10_000_000, 10_000_000);
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
// At 10MB/s, 50ms should refill ~500KB
|
||||
assert!(tb.try_consume(200_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_rate_blocks_after_burst() {
|
||||
let mut tb = TokenBucket::new(0, 100);
|
||||
assert!(tb.try_consume(100));
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
// Zero rate means no refill
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokens_do_not_exceed_burst() {
|
||||
// Use a low rate so refill between consecutive calls is negligible
|
||||
let mut tb = TokenBucket::new(100, 1_000);
|
||||
// Wait to accumulate — but should cap at burst
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
assert!(tb.try_consume(1_000));
|
||||
// At 100 bytes/sec, the few μs between calls add ~0 tokens
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
}
|
||||
1082
rust/src/server.rs
1082
rust/src/server.rs
File diff suppressed because it is too large
Load Diff
317
rust/src/telemetry.rs
Normal file
317
rust/src/telemetry.rs
Normal file
@@ -0,0 +1,317 @@
|
||||
use serde::Serialize;
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// A single RTT sample.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RttSample {
|
||||
_rtt: Duration,
|
||||
_timestamp: Instant,
|
||||
was_timeout: bool,
|
||||
}
|
||||
|
||||
/// Snapshot of connection quality metrics.
|
||||
#[derive(Debug, Clone, Serialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConnectionQuality {
|
||||
/// Smoothed RTT in milliseconds (EMA, RFC 6298 style).
|
||||
pub srtt_ms: f64,
|
||||
/// Jitter in milliseconds (mean deviation of RTT).
|
||||
pub jitter_ms: f64,
|
||||
/// Minimum RTT observed in the sample window.
|
||||
pub min_rtt_ms: f64,
|
||||
/// Maximum RTT observed in the sample window.
|
||||
pub max_rtt_ms: f64,
|
||||
/// Packet loss ratio over the sample window (0.0 - 1.0).
|
||||
pub loss_ratio: f64,
|
||||
/// Number of consecutive keepalive timeouts (0 if last succeeded).
|
||||
pub consecutive_timeouts: u32,
|
||||
/// Total keepalives sent.
|
||||
pub keepalives_sent: u64,
|
||||
/// Total keepalive ACKs received.
|
||||
pub keepalives_acked: u64,
|
||||
}
|
||||
|
||||
/// Tracks connection quality from keepalive round-trips.
|
||||
pub struct RttTracker {
|
||||
/// Maximum number of samples to keep in the window.
|
||||
max_samples: usize,
|
||||
/// Recent RTT samples (including timeout markers).
|
||||
samples: VecDeque<RttSample>,
|
||||
/// When the last keepalive was sent (for computing RTT on ACK).
|
||||
pending_ping_sent_at: Option<Instant>,
|
||||
/// Number of consecutive keepalive timeouts.
|
||||
pub consecutive_timeouts: u32,
|
||||
/// Smoothed RTT (EMA).
|
||||
srtt: Option<f64>,
|
||||
/// Jitter (mean deviation).
|
||||
jitter: f64,
|
||||
/// Minimum RTT observed.
|
||||
min_rtt: f64,
|
||||
/// Maximum RTT observed.
|
||||
max_rtt: f64,
|
||||
/// Total keepalives sent.
|
||||
keepalives_sent: u64,
|
||||
/// Total keepalive ACKs received.
|
||||
keepalives_acked: u64,
|
||||
/// Previous RTT sample for jitter calculation.
|
||||
last_rtt_ms: Option<f64>,
|
||||
}
|
||||
|
||||
impl RttTracker {
|
||||
/// Create a new tracker with the given window size.
|
||||
pub fn new(max_samples: usize) -> Self {
|
||||
Self {
|
||||
max_samples,
|
||||
samples: VecDeque::with_capacity(max_samples),
|
||||
pending_ping_sent_at: None,
|
||||
consecutive_timeouts: 0,
|
||||
srtt: None,
|
||||
jitter: 0.0,
|
||||
min_rtt: f64::MAX,
|
||||
max_rtt: 0.0,
|
||||
keepalives_sent: 0,
|
||||
keepalives_acked: 0,
|
||||
last_rtt_ms: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record that a keepalive was sent.
|
||||
/// Returns a millisecond timestamp (since UNIX epoch) to embed in the keepalive payload.
|
||||
pub fn mark_ping_sent(&mut self) -> u64 {
|
||||
self.pending_ping_sent_at = Some(Instant::now());
|
||||
self.keepalives_sent += 1;
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
/// Record that a keepalive ACK was received with the echoed timestamp.
|
||||
/// Returns the computed RTT if a pending ping was recorded.
|
||||
pub fn record_ack(&mut self, _echoed_timestamp_ms: u64) -> Option<Duration> {
|
||||
let sent_at = self.pending_ping_sent_at.take()?;
|
||||
let rtt = sent_at.elapsed();
|
||||
let rtt_ms = rtt.as_secs_f64() * 1000.0;
|
||||
|
||||
self.keepalives_acked += 1;
|
||||
self.consecutive_timeouts = 0;
|
||||
|
||||
// Update SRTT (RFC 6298: alpha = 1/8)
|
||||
match self.srtt {
|
||||
None => {
|
||||
self.srtt = Some(rtt_ms);
|
||||
self.jitter = rtt_ms / 2.0;
|
||||
}
|
||||
Some(prev_srtt) => {
|
||||
// RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| (beta = 1/4)
|
||||
self.jitter = 0.75 * self.jitter + 0.25 * (prev_srtt - rtt_ms).abs();
|
||||
// SRTT = (1 - alpha) * SRTT + alpha * R (alpha = 1/8)
|
||||
self.srtt = Some(0.875 * prev_srtt + 0.125 * rtt_ms);
|
||||
}
|
||||
}
|
||||
|
||||
// Update min/max
|
||||
if rtt_ms < self.min_rtt {
|
||||
self.min_rtt = rtt_ms;
|
||||
}
|
||||
if rtt_ms > self.max_rtt {
|
||||
self.max_rtt = rtt_ms;
|
||||
}
|
||||
|
||||
self.last_rtt_ms = Some(rtt_ms);
|
||||
|
||||
// Push sample into window
|
||||
if self.samples.len() >= self.max_samples {
|
||||
self.samples.pop_front();
|
||||
}
|
||||
self.samples.push_back(RttSample {
|
||||
_rtt: rtt,
|
||||
_timestamp: Instant::now(),
|
||||
was_timeout: false,
|
||||
});
|
||||
|
||||
Some(rtt)
|
||||
}
|
||||
|
||||
/// Record that a keepalive timed out (no ACK received).
|
||||
pub fn record_timeout(&mut self) {
|
||||
self.consecutive_timeouts += 1;
|
||||
self.pending_ping_sent_at = None;
|
||||
|
||||
if self.samples.len() >= self.max_samples {
|
||||
self.samples.pop_front();
|
||||
}
|
||||
self.samples.push_back(RttSample {
|
||||
_rtt: Duration::ZERO,
|
||||
_timestamp: Instant::now(),
|
||||
was_timeout: true,
|
||||
});
|
||||
}
|
||||
|
||||
/// Get a snapshot of the current connection quality.
|
||||
pub fn snapshot(&self) -> ConnectionQuality {
|
||||
let loss_ratio = if self.samples.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
let timeouts = self.samples.iter().filter(|s| s.was_timeout).count();
|
||||
timeouts as f64 / self.samples.len() as f64
|
||||
};
|
||||
|
||||
ConnectionQuality {
|
||||
srtt_ms: self.srtt.unwrap_or(0.0),
|
||||
jitter_ms: self.jitter,
|
||||
min_rtt_ms: if self.min_rtt == f64::MAX { 0.0 } else { self.min_rtt },
|
||||
max_rtt_ms: self.max_rtt,
|
||||
loss_ratio,
|
||||
consecutive_timeouts: self.consecutive_timeouts,
|
||||
keepalives_sent: self.keepalives_sent,
|
||||
keepalives_acked: self.keepalives_acked,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_tracker_has_zero_quality() {
|
||||
let tracker = RttTracker::new(30);
|
||||
let q = tracker.snapshot();
|
||||
assert_eq!(q.srtt_ms, 0.0);
|
||||
assert_eq!(q.jitter_ms, 0.0);
|
||||
assert_eq!(q.loss_ratio, 0.0);
|
||||
assert_eq!(q.consecutive_timeouts, 0);
|
||||
assert_eq!(q.keepalives_sent, 0);
|
||||
assert_eq!(q.keepalives_acked, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mark_ping_returns_timestamp() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
let ts = tracker.mark_ping_sent();
|
||||
// Should be a reasonable epoch-ms value (after 2020)
|
||||
assert!(ts > 1_577_836_800_000);
|
||||
assert_eq!(tracker.keepalives_sent, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_ack_computes_rtt() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
let rtt = tracker.record_ack(ts);
|
||||
assert!(rtt.is_some());
|
||||
let rtt = rtt.unwrap();
|
||||
assert!(rtt.as_millis() >= 4); // at least ~5ms minus scheduling jitter
|
||||
assert_eq!(tracker.keepalives_acked, 1);
|
||||
assert_eq!(tracker.consecutive_timeouts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_ack_without_pending_returns_none() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
assert!(tracker.record_ack(12345).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn srtt_converges() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
// Simulate several ping/ack cycles with ~10ms RTT
|
||||
for _ in 0..10 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
|
||||
let q = tracker.snapshot();
|
||||
// SRTT should be roughly 10ms (allowing for scheduling variance)
|
||||
assert!(q.srtt_ms > 5.0, "SRTT too low: {}", q.srtt_ms);
|
||||
assert!(q.srtt_ms < 50.0, "SRTT too high: {}", q.srtt_ms);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timeout_increments_counter_and_loss() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 1);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 2);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert_eq!(q.loss_ratio, 1.0); // 2 timeouts out of 2 samples
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ack_resets_consecutive_timeouts() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
assert_eq!(tracker.consecutive_timeouts, 1);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
assert_eq!(tracker.consecutive_timeouts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loss_ratio_over_mixed_window() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
// 3 successful, 1 timeout, 1 successful = 1/5 = 0.2 loss
|
||||
for _ in 0..3 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert!((q.loss_ratio - 0.2).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn window_evicts_old_samples() {
|
||||
let mut tracker = RttTracker::new(5);
|
||||
|
||||
// Fill window with 5 timeouts
|
||||
for _ in 0..5 {
|
||||
tracker.mark_ping_sent();
|
||||
tracker.record_timeout();
|
||||
}
|
||||
assert_eq!(tracker.snapshot().loss_ratio, 1.0);
|
||||
|
||||
// Add 5 successes — should evict all timeouts
|
||||
for _ in 0..5 {
|
||||
let ts = tracker.mark_ping_sent();
|
||||
tracker.record_ack(ts);
|
||||
}
|
||||
assert_eq!(tracker.snapshot().loss_ratio, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_max_rtt_tracked() {
|
||||
let mut tracker = RttTracker::new(30);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let ts = tracker.mark_ping_sent();
|
||||
std::thread::sleep(Duration::from_millis(15));
|
||||
tracker.record_ack(ts);
|
||||
|
||||
let q = tracker.snapshot();
|
||||
assert!(q.min_rtt_ms < q.max_rtt_ms);
|
||||
assert!(q.min_rtt_ms > 0.0);
|
||||
}
|
||||
}
|
||||
116
rust/src/transport_trait.rs
Normal file
116
rust/src/transport_trait.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
use crate::transport::WsStream;
|
||||
|
||||
// ============================================================================
|
||||
// Transport trait abstraction
|
||||
// ============================================================================
|
||||
|
||||
/// Outbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportSink: Send + 'static {
|
||||
/// Send a framed binary message on the reliable channel.
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Send a datagram (unreliable, best-effort).
|
||||
/// Falls back to reliable if the transport does not support datagrams.
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Gracefully close the transport.
|
||||
async fn close(&mut self) -> Result<()>;
|
||||
}
|
||||
|
||||
/// Inbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportStream: Send + 'static {
|
||||
/// Receive the next reliable binary message. Returns `None` on close.
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Receive the next datagram. Returns `None` if datagrams are unsupported
|
||||
/// or the connection is closed.
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Whether this transport supports unreliable datagrams.
|
||||
fn supports_datagrams(&self) -> bool;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket implementation
|
||||
// ============================================================================
|
||||
|
||||
/// WebSocket transport sink (wraps the write half of a split WsStream).
|
||||
pub struct WsTransportSink {
|
||||
inner: futures_util::stream::SplitSink<WsStream, Message>,
|
||||
}
|
||||
|
||||
impl WsTransportSink {
|
||||
pub fn new(inner: futures_util::stream::SplitSink<WsStream, Message>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for WsTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
self.inner.send(Message::Binary(data.into())).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// WebSocket has no datagram support — fall back to reliable.
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.inner.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket transport stream (wraps the read half of a split WsStream).
|
||||
pub struct WsTransportStream {
|
||||
inner: futures_util::stream::SplitStream<WsStream>,
|
||||
}
|
||||
|
||||
impl WsTransportStream {
|
||||
pub fn new(inner: futures_util::stream::SplitStream<WsStream>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for WsTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
loop {
|
||||
match self.inner.next().await {
|
||||
Some(Ok(Message::Binary(data))) => return Ok(Some(data.to_vec())),
|
||||
Some(Ok(Message::Close(_))) | None => return Ok(None),
|
||||
Some(Ok(Message::Ping(_))) => {
|
||||
// Ping handling is done at the tungstenite layer automatically
|
||||
// when the sink side is alive. Just skip here.
|
||||
continue;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => return Err(anyhow::anyhow!("WebSocket error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// WebSocket does not support datagrams.
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a WebSocket stream into transport sink and stream halves.
|
||||
pub fn split_ws(ws: WsStream) -> (WsTransportSink, WsTransportStream) {
|
||||
let (sink, stream) = ws.split();
|
||||
(WsTransportSink::new(sink), WsTransportStream::new(stream))
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use tracing::info;
|
||||
|
||||
/// Configuration for creating a TUN device.
|
||||
@@ -64,6 +64,42 @@ pub async fn add_route(subnet: &str, device_name: &str) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Action to take after checking a packet against the MTU.
|
||||
pub enum TunMtuAction {
|
||||
/// Packet is within MTU limits, forward it.
|
||||
Forward,
|
||||
/// Packet is oversized; the Vec contains the ICMP too-big message to write back into TUN.
|
||||
IcmpTooBig(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Check a TUN packet against the MTU and return the appropriate action.
|
||||
pub fn check_tun_mtu(packet: &[u8], mtu_config: &crate::mtu::MtuConfig) -> TunMtuAction {
|
||||
match crate::mtu::check_mtu(packet, mtu_config) {
|
||||
crate::mtu::MtuAction::Forward => TunMtuAction::Forward,
|
||||
crate::mtu::MtuAction::SendIcmpTooBig(icmp) => TunMtuAction::IcmpTooBig(icmp),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract destination IP from a raw IP packet header.
|
||||
pub fn extract_dst_ip(packet: &[u8]) -> Option<IpAddr> {
|
||||
if packet.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let version = packet[0] >> 4;
|
||||
match version {
|
||||
4 if packet.len() >= 20 => {
|
||||
let dst = Ipv4Addr::new(packet[16], packet[17], packet[18], packet[19]);
|
||||
Some(IpAddr::V4(dst))
|
||||
}
|
||||
6 if packet.len() >= 40 => {
|
||||
let mut octets = [0u8; 16];
|
||||
octets.copy_from_slice(&packet[24..40]);
|
||||
Some(IpAddr::V6(Ipv6Addr::from(octets)))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a route.
|
||||
pub async fn remove_route(subnet: &str, device_name: &str) -> Result<()> {
|
||||
let output = tokio::process::Command::new("ip")
|
||||
|
||||
640
rust/src/userspace_nat.rs
Normal file
640
rust/src/userspace_nat.rs
Normal file
@@ -0,0 +1,640 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use smoltcp::iface::{Config, Interface, SocketHandle, SocketSet};
|
||||
use smoltcp::phy::{self, Device, DeviceCapabilities, Medium};
|
||||
use smoltcp::socket::{tcp, udp};
|
||||
use smoltcp::wire::{HardwareAddress, IpAddress, IpCidr, IpEndpoint};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpStream, UdpSocket};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::server::ServerState;
|
||||
use crate::tunnel;
|
||||
|
||||
// ============================================================================
|
||||
// Virtual IP device for smoltcp
|
||||
// ============================================================================
|
||||
|
||||
pub struct VirtualIpDevice {
|
||||
rx_queue: VecDeque<Vec<u8>>,
|
||||
tx_queue: VecDeque<Vec<u8>>,
|
||||
mtu: usize,
|
||||
}
|
||||
|
||||
impl VirtualIpDevice {
|
||||
pub fn new(mtu: usize) -> Self {
|
||||
Self {
|
||||
rx_queue: VecDeque::new(),
|
||||
tx_queue: VecDeque::new(),
|
||||
mtu,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inject_packet(&mut self, packet: Vec<u8>) {
|
||||
self.rx_queue.push_back(packet);
|
||||
}
|
||||
|
||||
pub fn drain_tx(&mut self) -> impl Iterator<Item = Vec<u8>> + '_ {
|
||||
self.tx_queue.drain(..)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VirtualRxToken {
|
||||
buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl phy::RxToken for VirtualRxToken {
|
||||
fn consume<R, F>(self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&[u8]) -> R,
|
||||
{
|
||||
f(&self.buffer)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VirtualTxToken<'a> {
|
||||
queue: &'a mut VecDeque<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl<'a> phy::TxToken for VirtualTxToken<'a> {
|
||||
fn consume<R, F>(self, len: usize, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> R,
|
||||
{
|
||||
let mut buffer = vec![0u8; len];
|
||||
let result = f(&mut buffer);
|
||||
self.queue.push_back(buffer);
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl Device for VirtualIpDevice {
|
||||
type RxToken<'a> = VirtualRxToken;
|
||||
type TxToken<'a> = VirtualTxToken<'a>;
|
||||
|
||||
fn receive(
|
||||
&mut self,
|
||||
_timestamp: smoltcp::time::Instant,
|
||||
) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
|
||||
self.rx_queue.pop_front().map(|buffer| {
|
||||
let rx = VirtualRxToken { buffer };
|
||||
let tx = VirtualTxToken {
|
||||
queue: &mut self.tx_queue,
|
||||
};
|
||||
(rx, tx)
|
||||
})
|
||||
}
|
||||
|
||||
fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
|
||||
Some(VirtualTxToken {
|
||||
queue: &mut self.tx_queue,
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> DeviceCapabilities {
|
||||
let mut caps = DeviceCapabilities::default();
|
||||
caps.medium = Medium::Ip;
|
||||
caps.max_transmission_unit = self.mtu;
|
||||
caps.max_burst_size = Some(1);
|
||||
caps
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session tracking
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
struct SessionKey {
|
||||
src_ip: Ipv4Addr,
|
||||
src_port: u16,
|
||||
dst_ip: Ipv4Addr,
|
||||
dst_port: u16,
|
||||
protocol: u8,
|
||||
}
|
||||
|
||||
struct TcpSession {
|
||||
smoltcp_handle: SocketHandle,
|
||||
bridge_data_tx: mpsc::Sender<Vec<u8>>,
|
||||
#[allow(dead_code)]
|
||||
client_ip: Ipv4Addr,
|
||||
}
|
||||
|
||||
struct UdpSession {
|
||||
smoltcp_handle: SocketHandle,
|
||||
bridge_data_tx: mpsc::Sender<Vec<u8>>,
|
||||
#[allow(dead_code)]
|
||||
client_ip: Ipv4Addr,
|
||||
last_activity: tokio::time::Instant,
|
||||
}
|
||||
|
||||
enum BridgeMessage {
|
||||
TcpData { key: SessionKey, data: Vec<u8> },
|
||||
TcpClosed { key: SessionKey },
|
||||
UdpData { key: SessionKey, data: Vec<u8> },
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// IP packet parsing helpers
|
||||
// ============================================================================
|
||||
|
||||
fn parse_ipv4_header(packet: &[u8]) -> Option<(u8, Ipv4Addr, Ipv4Addr, u8)> {
|
||||
if packet.len() < 20 {
|
||||
return None;
|
||||
}
|
||||
let version = packet[0] >> 4;
|
||||
if version != 4 {
|
||||
return None;
|
||||
}
|
||||
let ihl = (packet[0] & 0x0F) as usize * 4;
|
||||
let protocol = packet[9];
|
||||
let src = Ipv4Addr::new(packet[12], packet[13], packet[14], packet[15]);
|
||||
let dst = Ipv4Addr::new(packet[16], packet[17], packet[18], packet[19]);
|
||||
Some((ihl as u8, src, dst, protocol))
|
||||
}
|
||||
|
||||
fn parse_tcp_ports(packet: &[u8], ihl: usize) -> Option<(u16, u16, u8)> {
|
||||
if packet.len() < ihl + 14 {
|
||||
return None;
|
||||
}
|
||||
let src_port = u16::from_be_bytes([packet[ihl], packet[ihl + 1]]);
|
||||
let dst_port = u16::from_be_bytes([packet[ihl + 2], packet[ihl + 3]]);
|
||||
let flags = packet[ihl + 13];
|
||||
Some((src_port, dst_port, flags))
|
||||
}
|
||||
|
||||
fn parse_udp_ports(packet: &[u8], ihl: usize) -> Option<(u16, u16)> {
|
||||
if packet.len() < ihl + 4 {
|
||||
return None;
|
||||
}
|
||||
let src_port = u16::from_be_bytes([packet[ihl], packet[ihl + 1]]);
|
||||
let dst_port = u16::from_be_bytes([packet[ihl + 2], packet[ihl + 3]]);
|
||||
Some((src_port, dst_port))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// NAT Engine
|
||||
// ============================================================================
|
||||
|
||||
pub struct NatEngine {
|
||||
device: VirtualIpDevice,
|
||||
iface: Interface,
|
||||
sockets: SocketSet<'static>,
|
||||
tcp_sessions: HashMap<SessionKey, TcpSession>,
|
||||
udp_sessions: HashMap<SessionKey, UdpSession>,
|
||||
state: Arc<ServerState>,
|
||||
bridge_rx: mpsc::Receiver<BridgeMessage>,
|
||||
bridge_tx: mpsc::Sender<BridgeMessage>,
|
||||
start_time: std::time::Instant,
|
||||
}
|
||||
|
||||
impl NatEngine {
|
||||
pub fn new(gateway_ip: Ipv4Addr, mtu: usize, state: Arc<ServerState>) -> Self {
|
||||
let mut device = VirtualIpDevice::new(mtu);
|
||||
let config = Config::new(HardwareAddress::Ip);
|
||||
let now = smoltcp::time::Instant::from_millis(0);
|
||||
let mut iface = Interface::new(config, &mut device, now);
|
||||
|
||||
// Accept packets to ANY destination IP (essential for NAT)
|
||||
iface.set_any_ip(true);
|
||||
|
||||
// Assign the gateway IP as the interface address
|
||||
iface.update_ip_addrs(|addrs| {
|
||||
addrs
|
||||
.push(IpCidr::new(IpAddress::Ipv4(gateway_ip.into()), 24))
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
// Add a default route so smoltcp knows where to send packets
|
||||
iface.routes_mut().add_default_ipv4_route(gateway_ip.into()).unwrap();
|
||||
|
||||
let sockets = SocketSet::new(Vec::with_capacity(256));
|
||||
let (bridge_tx, bridge_rx) = mpsc::channel(4096);
|
||||
|
||||
Self {
|
||||
device,
|
||||
iface,
|
||||
sockets,
|
||||
tcp_sessions: HashMap::new(),
|
||||
udp_sessions: HashMap::new(),
|
||||
state,
|
||||
bridge_rx,
|
||||
bridge_tx,
|
||||
start_time: std::time::Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
fn smoltcp_now(&self) -> smoltcp::time::Instant {
|
||||
smoltcp::time::Instant::from_millis(self.start_time.elapsed().as_millis() as i64)
|
||||
}
|
||||
|
||||
/// Inject a raw IP packet from a VPN client and handle new session creation.
|
||||
fn inject_packet(&mut self, packet: Vec<u8>) {
|
||||
let Some((ihl, src_ip, dst_ip, protocol)) = parse_ipv4_header(&packet) else {
|
||||
return;
|
||||
};
|
||||
let ihl = ihl as usize;
|
||||
|
||||
match protocol {
|
||||
6 => {
|
||||
// TCP
|
||||
let Some((src_port, dst_port, flags)) = parse_tcp_ports(&packet, ihl) else {
|
||||
return;
|
||||
};
|
||||
let key = SessionKey {
|
||||
src_ip,
|
||||
src_port,
|
||||
dst_ip,
|
||||
dst_port,
|
||||
protocol: 6,
|
||||
};
|
||||
|
||||
// SYN without ACK = new connection
|
||||
let is_syn = (flags & 0x02) != 0 && (flags & 0x10) == 0;
|
||||
if is_syn && !self.tcp_sessions.contains_key(&key) {
|
||||
self.create_tcp_session(&key);
|
||||
}
|
||||
}
|
||||
17 => {
|
||||
// UDP
|
||||
let Some((src_port, dst_port)) = parse_udp_ports(&packet, ihl) else {
|
||||
return;
|
||||
};
|
||||
let key = SessionKey {
|
||||
src_ip,
|
||||
src_port,
|
||||
dst_ip,
|
||||
dst_port,
|
||||
protocol: 17,
|
||||
};
|
||||
|
||||
if !self.udp_sessions.contains_key(&key) {
|
||||
self.create_udp_session(&key);
|
||||
}
|
||||
|
||||
// Update last_activity for existing sessions
|
||||
if let Some(session) = self.udp_sessions.get_mut(&key) {
|
||||
session.last_activity = tokio::time::Instant::now();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// ICMP and other protocols — not forwarded in socket mode
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
self.device.inject_packet(packet);
|
||||
}
|
||||
|
||||
fn create_tcp_session(&mut self, key: &SessionKey) {
|
||||
// Create smoltcp TCP socket
|
||||
let tcp_rx_buf = tcp::SocketBuffer::new(vec![0u8; 65535]);
|
||||
let tcp_tx_buf = tcp::SocketBuffer::new(vec![0u8; 65535]);
|
||||
let mut socket = tcp::Socket::new(tcp_rx_buf, tcp_tx_buf);
|
||||
|
||||
// Listen on the destination address so smoltcp accepts the SYN
|
||||
let endpoint = IpEndpoint::new(
|
||||
IpAddress::Ipv4(key.dst_ip.into()),
|
||||
key.dst_port,
|
||||
);
|
||||
if socket.listen(endpoint).is_err() {
|
||||
warn!("NAT: failed to listen on {:?}", endpoint);
|
||||
return;
|
||||
}
|
||||
|
||||
let handle = self.sockets.add(socket);
|
||||
|
||||
// Channel for sending data from NAT engine to bridge task
|
||||
let (data_tx, data_rx) = mpsc::channel::<Vec<u8>>(256);
|
||||
|
||||
let session = TcpSession {
|
||||
smoltcp_handle: handle,
|
||||
bridge_data_tx: data_tx,
|
||||
client_ip: key.src_ip,
|
||||
};
|
||||
self.tcp_sessions.insert(key.clone(), session);
|
||||
|
||||
// Spawn bridge task that connects to the real destination
|
||||
let bridge_tx = self.bridge_tx.clone();
|
||||
let key_clone = key.clone();
|
||||
tokio::spawn(async move {
|
||||
tcp_bridge_task(key_clone, data_rx, bridge_tx).await;
|
||||
});
|
||||
|
||||
debug!(
|
||||
"NAT: new TCP session {}:{} -> {}:{}",
|
||||
key.src_ip, key.src_port, key.dst_ip, key.dst_port
|
||||
);
|
||||
}
|
||||
|
||||
fn create_udp_session(&mut self, key: &SessionKey) {
|
||||
// Create smoltcp UDP socket
|
||||
let udp_rx_buf = udp::PacketBuffer::new(
|
||||
vec![udp::PacketMetadata::EMPTY; 32],
|
||||
vec![0u8; 65535],
|
||||
);
|
||||
let udp_tx_buf = udp::PacketBuffer::new(
|
||||
vec![udp::PacketMetadata::EMPTY; 32],
|
||||
vec![0u8; 65535],
|
||||
);
|
||||
let mut socket = udp::Socket::new(udp_rx_buf, udp_tx_buf);
|
||||
|
||||
let endpoint = IpEndpoint::new(
|
||||
IpAddress::Ipv4(key.dst_ip.into()),
|
||||
key.dst_port,
|
||||
);
|
||||
if socket.bind(endpoint).is_err() {
|
||||
warn!("NAT: failed to bind UDP on {:?}", endpoint);
|
||||
return;
|
||||
}
|
||||
|
||||
let handle = self.sockets.add(socket);
|
||||
|
||||
let (data_tx, data_rx) = mpsc::channel::<Vec<u8>>(256);
|
||||
|
||||
let session = UdpSession {
|
||||
smoltcp_handle: handle,
|
||||
bridge_data_tx: data_tx,
|
||||
client_ip: key.src_ip,
|
||||
last_activity: tokio::time::Instant::now(),
|
||||
};
|
||||
self.udp_sessions.insert(key.clone(), session);
|
||||
|
||||
let bridge_tx = self.bridge_tx.clone();
|
||||
let key_clone = key.clone();
|
||||
tokio::spawn(async move {
|
||||
udp_bridge_task(key_clone, data_rx, bridge_tx).await;
|
||||
});
|
||||
|
||||
debug!(
|
||||
"NAT: new UDP session {}:{} -> {}:{}",
|
||||
key.src_ip, key.src_port, key.dst_ip, key.dst_port
|
||||
);
|
||||
}
|
||||
|
||||
/// Poll smoltcp, bridge data between smoltcp sockets and bridge tasks,
|
||||
/// and dispatch outgoing packets to VPN clients.
|
||||
async fn process(&mut self) {
|
||||
let now = self.smoltcp_now();
|
||||
self.iface
|
||||
.poll(now, &mut self.device, &mut self.sockets);
|
||||
|
||||
// Bridge: read data from smoltcp TCP sockets → send to bridge tasks
|
||||
let mut closed_tcp: Vec<SessionKey> = Vec::new();
|
||||
for (key, session) in &self.tcp_sessions {
|
||||
let socket = self.sockets.get_mut::<tcp::Socket>(session.smoltcp_handle);
|
||||
if socket.can_recv() {
|
||||
let _ = socket.recv(|data| {
|
||||
let _ = session.bridge_data_tx.try_send(data.to_vec());
|
||||
(data.len(), ())
|
||||
});
|
||||
}
|
||||
// Detect closed connections
|
||||
if !socket.is_open() && !socket.is_listening() {
|
||||
closed_tcp.push(key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up closed TCP sessions
|
||||
for key in closed_tcp {
|
||||
if let Some(session) = self.tcp_sessions.remove(&key) {
|
||||
self.sockets.remove(session.smoltcp_handle);
|
||||
debug!("NAT: TCP session closed {}:{} -> {}:{}", key.src_ip, key.src_port, key.dst_ip, key.dst_port);
|
||||
}
|
||||
}
|
||||
|
||||
// Bridge: read data from smoltcp UDP sockets → send to bridge tasks
|
||||
for (_key, session) in &self.udp_sessions {
|
||||
let socket = self.sockets.get_mut::<udp::Socket>(session.smoltcp_handle);
|
||||
while let Ok((data, _meta)) = socket.recv() {
|
||||
let _ = session.bridge_data_tx.try_send(data.to_vec());
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch outgoing packets from smoltcp to VPN clients
|
||||
let routes = self.state.tun_routes.read().await;
|
||||
for packet in self.device.drain_tx() {
|
||||
if let Some(std::net::IpAddr::V4(dst_ip)) = tunnel::extract_dst_ip(&packet) {
|
||||
if let Some(sender) = routes.get(&dst_ip) {
|
||||
let _ = sender.try_send(packet);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_bridge_message(&mut self, msg: BridgeMessage) {
|
||||
match msg {
|
||||
BridgeMessage::TcpData { key, data } => {
|
||||
if let Some(session) = self.tcp_sessions.get(&key) {
|
||||
let socket =
|
||||
self.sockets.get_mut::<tcp::Socket>(session.smoltcp_handle);
|
||||
if socket.can_send() {
|
||||
let _ = socket.send_slice(&data);
|
||||
}
|
||||
}
|
||||
}
|
||||
BridgeMessage::TcpClosed { key } => {
|
||||
if let Some(session) = self.tcp_sessions.remove(&key) {
|
||||
let socket =
|
||||
self.sockets.get_mut::<tcp::Socket>(session.smoltcp_handle);
|
||||
socket.close();
|
||||
// Don't remove from SocketSet yet — let smoltcp send FIN
|
||||
// It will be cleaned up in process() when is_open() returns false
|
||||
self.tcp_sessions.insert(key, session);
|
||||
}
|
||||
}
|
||||
BridgeMessage::UdpData { key, data } => {
|
||||
if let Some(session) = self.udp_sessions.get_mut(&key) {
|
||||
session.last_activity = tokio::time::Instant::now();
|
||||
let socket =
|
||||
self.sockets.get_mut::<udp::Socket>(session.smoltcp_handle);
|
||||
let dst_endpoint = IpEndpoint::new(
|
||||
IpAddress::Ipv4(key.src_ip.into()),
|
||||
key.src_port,
|
||||
);
|
||||
// Send response: from the "server" (dst) back to the "client" (src)
|
||||
let _ = socket.send_slice(&data, dst_endpoint);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cleanup_idle_udp_sessions(&mut self) {
|
||||
let timeout = Duration::from_secs(60);
|
||||
let now = tokio::time::Instant::now();
|
||||
let expired: Vec<SessionKey> = self
|
||||
.udp_sessions
|
||||
.iter()
|
||||
.filter(|(_, s)| now.duration_since(s.last_activity) > timeout)
|
||||
.map(|(k, _)| k.clone())
|
||||
.collect();
|
||||
|
||||
for key in expired {
|
||||
if let Some(session) = self.udp_sessions.remove(&key) {
|
||||
self.sockets.remove(session.smoltcp_handle);
|
||||
debug!(
|
||||
"NAT: UDP session timed out {}:{} -> {}:{}",
|
||||
key.src_ip, key.src_port, key.dst_ip, key.dst_port
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Main async event loop for the NAT engine.
|
||||
pub async fn run(
|
||||
mut self,
|
||||
mut packet_rx: mpsc::Receiver<Vec<u8>>,
|
||||
mut shutdown_rx: mpsc::Receiver<()>,
|
||||
) -> Result<()> {
|
||||
info!("Userspace NAT engine started");
|
||||
let mut timer = tokio::time::interval(Duration::from_millis(50));
|
||||
let mut cleanup_timer = tokio::time::interval(Duration::from_secs(10));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(packet) = packet_rx.recv() => {
|
||||
self.inject_packet(packet);
|
||||
self.process().await;
|
||||
}
|
||||
Some(msg) = self.bridge_rx.recv() => {
|
||||
self.handle_bridge_message(msg);
|
||||
self.process().await;
|
||||
}
|
||||
_ = timer.tick() => {
|
||||
// Periodic poll for smoltcp maintenance (TCP retransmit, etc.)
|
||||
self.process().await;
|
||||
}
|
||||
_ = cleanup_timer.tick() => {
|
||||
self.cleanup_idle_udp_sessions();
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("Userspace NAT engine shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Bridge tasks
|
||||
// ============================================================================
|
||||
|
||||
async fn tcp_bridge_task(
|
||||
key: SessionKey,
|
||||
mut data_rx: mpsc::Receiver<Vec<u8>>,
|
||||
bridge_tx: mpsc::Sender<BridgeMessage>,
|
||||
) {
|
||||
let addr = SocketAddr::new(key.dst_ip.into(), key.dst_port);
|
||||
|
||||
// Connect to real destination with timeout
|
||||
let stream = match tokio::time::timeout(Duration::from_secs(30), TcpStream::connect(addr)).await
|
||||
{
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => {
|
||||
debug!("NAT TCP connect to {} failed: {}", addr, e);
|
||||
let _ = bridge_tx.send(BridgeMessage::TcpClosed { key }).await;
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("NAT TCP connect to {} timed out", addr);
|
||||
let _ = bridge_tx.send(BridgeMessage::TcpClosed { key }).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let (mut reader, mut writer) = stream.into_split();
|
||||
|
||||
// Read from real socket → send to NAT engine
|
||||
let bridge_tx2 = bridge_tx.clone();
|
||||
let key2 = key.clone();
|
||||
let read_task = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
match reader.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
if bridge_tx2
|
||||
.send(BridgeMessage::TcpData {
|
||||
key: key2.clone(),
|
||||
data: buf[..n].to_vec(),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
let _ = bridge_tx2
|
||||
.send(BridgeMessage::TcpClosed { key: key2 })
|
||||
.await;
|
||||
});
|
||||
|
||||
// Receive from NAT engine → write to real socket
|
||||
while let Some(data) = data_rx.recv().await {
|
||||
if writer.write_all(&data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
read_task.abort();
|
||||
}
|
||||
|
||||
async fn udp_bridge_task(
|
||||
key: SessionKey,
|
||||
mut data_rx: mpsc::Receiver<Vec<u8>>,
|
||||
bridge_tx: mpsc::Sender<BridgeMessage>,
|
||||
) {
|
||||
let socket = match UdpSocket::bind("0.0.0.0:0").await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!("NAT UDP bind failed: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let dest = SocketAddr::new(key.dst_ip.into(), key.dst_port);
|
||||
|
||||
let socket = Arc::new(socket);
|
||||
let socket2 = socket.clone();
|
||||
let bridge_tx2 = bridge_tx.clone();
|
||||
let key2 = key.clone();
|
||||
|
||||
// Read responses from real socket
|
||||
let read_task = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
match socket2.recv_from(&mut buf).await {
|
||||
Ok((n, _src)) => {
|
||||
if bridge_tx2
|
||||
.send(BridgeMessage::UdpData {
|
||||
key: key2.clone(),
|
||||
data: buf[..n].to_vec(),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Forward data from NAT engine to real socket
|
||||
while let Some(data) = data_rx.recv().await {
|
||||
let _ = socket.send_to(&data, dest).await;
|
||||
}
|
||||
|
||||
read_task.abort();
|
||||
}
|
||||
1311
rust/src/wireguard.rs
Normal file
1311
rust/src/wireguard.rs
Normal file
File diff suppressed because it is too large
Load Diff
320
rust/tests/wg_e2e.rs
Normal file
320
rust/tests/wg_e2e.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! End-to-end WireGuard protocol tests over real UDP sockets.
|
||||
//!
|
||||
//! Entirely userspace — no root, no TUN devices.
|
||||
//! Two boringtun `Tunn` instances exchange real WireGuard packets
|
||||
//! over loopback UDP, validating handshake, encryption, and data flow.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::time;
|
||||
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
|
||||
use smartvpn_daemon::wireguard::generate_wg_keypair;
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
fn parse_key_pair(pub_b64: &str, priv_b64: &str) -> (PublicKey, StaticSecret) {
|
||||
let pub_bytes: [u8; 32] = BASE64.decode(pub_b64).unwrap().try_into().unwrap();
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
(PublicKey::from(pub_bytes), StaticSecret::from(priv_bytes))
|
||||
}
|
||||
|
||||
fn clone_secret(priv_b64: &str) -> StaticSecret {
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
StaticSecret::from(priv_bytes)
|
||||
}
|
||||
|
||||
fn make_ipv4_packet(src: Ipv4Addr, dst: Ipv4Addr, payload: &[u8]) -> Vec<u8> {
|
||||
let total_len = 20 + payload.len();
|
||||
let mut pkt = vec![0u8; total_len];
|
||||
pkt[0] = 0x45;
|
||||
pkt[2] = (total_len >> 8) as u8;
|
||||
pkt[3] = total_len as u8;
|
||||
pkt[9] = 0x11;
|
||||
pkt[12..16].copy_from_slice(&src.octets());
|
||||
pkt[16..20].copy_from_slice(&dst.octets());
|
||||
pkt[20..].copy_from_slice(payload);
|
||||
pkt
|
||||
}
|
||||
|
||||
/// Send any WriteToNetwork result, then drain the tunn for more packets.
|
||||
async fn send_and_drain(
|
||||
tunn: &mut Tunn,
|
||||
pkt: &[u8],
|
||||
socket: &UdpSocket,
|
||||
peer: SocketAddr,
|
||||
) {
|
||||
socket.send_to(pkt, peer).await.unwrap();
|
||||
let mut drain_buf = vec![0u8; 2048];
|
||||
loop {
|
||||
match tunn.decapsulate(None, &[], &mut drain_buf) {
|
||||
TunnResult::WriteToNetwork(p) => { socket.send_to(p, peer).await.unwrap(); }
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to receive a UDP packet and decapsulate it. Returns decrypted IP data if any.
|
||||
async fn try_recv_decap(
|
||||
tunn: &mut Tunn,
|
||||
socket: &UdpSocket,
|
||||
timeout_ms: u64,
|
||||
) -> Option<(Vec<u8>, Ipv4Addr, SocketAddr)> {
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
let (n, src_addr) = match time::timeout(
|
||||
Duration::from_millis(timeout_ms),
|
||||
socket.recv_from(&mut recv_buf),
|
||||
).await {
|
||||
Ok(Ok(r)) => r,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let result = tunn.decapsulate(Some(src_addr.ip()), &recv_buf[..n], &mut dst_buf);
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(tunn, pkt, socket, src_addr).await;
|
||||
None
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(pkt, addr) => Some((pkt.to_vec(), addr, src_addr)),
|
||||
TunnResult::WriteToTunnelV6(_, _) => None,
|
||||
TunnResult::Done => None,
|
||||
TunnResult::Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive the full WireGuard handshake between client and server over real UDP.
|
||||
async fn do_handshake(
|
||||
client_tunn: &mut Tunn,
|
||||
server_tunn: &mut Tunn,
|
||||
client_socket: &UdpSocket,
|
||||
server_socket: &UdpSocket,
|
||||
server_addr: SocketAddr,
|
||||
) {
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
// Step 1: Client initiates handshake
|
||||
match client_tunn.encapsulate(&[], &mut buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => panic!("Expected handshake init"),
|
||||
}
|
||||
|
||||
// Step 2: Server receives init → sends response
|
||||
let (n, client_from) = server_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match server_tunn.decapsulate(Some(client_from.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(server_tunn, pkt, server_socket, client_from).await;
|
||||
}
|
||||
other => panic!("Expected WriteToNetwork from server, got variant {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Step 3: Client receives response
|
||||
let (n, _) = client_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match client_tunn.decapsulate(Some(server_addr.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(client_tunn, pkt, client_socket, server_addr).await;
|
||||
}
|
||||
TunnResult::Done => {}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Step 4: Process any remaining handshake packets
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 200).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 100).await;
|
||||
|
||||
// Step 5: Timer ticks to settle
|
||||
for _ in 0..3 {
|
||||
match server_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
server_socket.send_to(pkt, client_from).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
match client_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 50).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 50).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn variant_name(r: &TunnResult) -> &'static str {
|
||||
match r {
|
||||
TunnResult::Done => "Done",
|
||||
TunnResult::Err(_) => "Err",
|
||||
TunnResult::WriteToNetwork(_) => "WriteToNetwork",
|
||||
TunnResult::WriteToTunnelV4(_, _) => "WriteToTunnelV4",
|
||||
TunnResult::WriteToTunnelV6(_, _) => "WriteToTunnelV6",
|
||||
}
|
||||
}
|
||||
|
||||
/// Encapsulate an IP packet and send it, then loop-receive on the other side until decrypted.
|
||||
async fn send_and_expect_data(
|
||||
sender_tunn: &mut Tunn,
|
||||
receiver_tunn: &mut Tunn,
|
||||
sender_socket: &UdpSocket,
|
||||
receiver_socket: &UdpSocket,
|
||||
dest_addr: SocketAddr,
|
||||
ip_packet: &[u8],
|
||||
) -> (Vec<u8>, Ipv4Addr) {
|
||||
let mut enc_buf = vec![0u8; 65536];
|
||||
|
||||
match sender_tunn.encapsulate(ip_packet, &mut enc_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
sender_socket.send_to(pkt, dest_addr).await.unwrap();
|
||||
}
|
||||
TunnResult::Err(e) => panic!("Encapsulate failed: {:?}", e),
|
||||
other => panic!("Expected WriteToNetwork, got {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Receive — may need a few rounds for control packets
|
||||
for _ in 0..10 {
|
||||
if let Some((data, addr, _)) = try_recv_decap(receiver_tunn, receiver_socket, 1000).await {
|
||||
return (data, addr);
|
||||
}
|
||||
}
|
||||
panic!("Did not receive decrypted IP packet");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 1: Single client ↔ server bidirectional data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_single_client_bidirectional() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
let client_addr = client_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, None, None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, None, None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
// Client → Server
|
||||
let pkt_c2s = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"Hello from client!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt_c2s,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt_c2s.len()], &pkt_c2s[..]);
|
||||
|
||||
// Server → Client
|
||||
let pkt_s2c = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 1), Ipv4Addr::new(10, 0, 0, 2), b"Hello from server!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut server_tunn, &mut client_tunn,
|
||||
&server_socket, &client_socket,
|
||||
client_addr, &pkt_s2c,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 1));
|
||||
assert_eq!(&decrypted[..pkt_s2c.len()], &pkt_s2c[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 2: Two clients ↔ one server (peer routing)
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_two_clients_peer_routing() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client1_pub_b64, client1_priv_b64) = generate_wg_keypair();
|
||||
let (client2_pub_b64, client2_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, _) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client1_public, client1_secret) = parse_key_pair(&client1_pub_b64, &client1_priv_b64);
|
||||
let (client2_public, client2_secret) = parse_key_pair(&client2_pub_b64, &client2_priv_b64);
|
||||
|
||||
// Separate server socket per peer to avoid UDP mux complexity in test
|
||||
let server_socket_1 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_socket_2 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client1_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client2_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr_1 = server_socket_1.local_addr().unwrap();
|
||||
let server_addr_2 = server_socket_2.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn_1 = Tunn::new(clone_secret(&server_priv_b64), client1_public, None, None, 0, None);
|
||||
let mut server_tunn_2 = Tunn::new(clone_secret(&server_priv_b64), client2_public, None, None, 1, None);
|
||||
let mut client1_tunn = Tunn::new(client1_secret, server_public.clone(), None, None, 2, None);
|
||||
let mut client2_tunn = Tunn::new(client2_secret, server_public, None, None, 3, None);
|
||||
|
||||
do_handshake(&mut client1_tunn, &mut server_tunn_1, &client1_socket, &server_socket_1, server_addr_1).await;
|
||||
do_handshake(&mut client2_tunn, &mut server_tunn_2, &client2_socket, &server_socket_2, server_addr_2).await;
|
||||
|
||||
// Client 1 → Server
|
||||
let pkt1 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"From client 1");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client1_tunn, &mut server_tunn_1,
|
||||
&client1_socket, &server_socket_1,
|
||||
server_addr_1, &pkt1,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt1.len()], &pkt1[..]);
|
||||
|
||||
// Client 2 → Server
|
||||
let pkt2 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 3), Ipv4Addr::new(10, 0, 0, 1), b"From client 2");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client2_tunn, &mut server_tunn_2,
|
||||
&client2_socket, &server_socket_2,
|
||||
server_addr_2, &pkt2,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 3));
|
||||
assert_eq!(&decrypted[..pkt2.len()], &pkt2[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 3: Preshared key handshake + data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_preshared_key() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let psk: [u8; 32] = rand::random();
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, Some(psk), None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, Some(psk), None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
let pkt = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"PSK-protected data");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt.len()], &pkt[..]);
|
||||
}
|
||||
284
test/test.flowcontrol.node.ts
Normal file
284
test/test.flowcontrol.node.ts
Normal file
@@ -0,0 +1,284 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let client: VpnClient;
|
||||
let clientBundle: IClientConfigBundle;
|
||||
const extraClients: VpnClient[] = [];
|
||||
const extraBundles: IClientConfigBundle[] = [];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start VPN server', async () => {
|
||||
serverPort = await findFreePort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
|
||||
// Phase 1: start the daemon bridge
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
expect(server.running).toBeTrue();
|
||||
|
||||
// Phase 2: generate a keypair
|
||||
keypair = await server.generateKeypair();
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
|
||||
// Phase 3: start the VPN listener (empty clients, will use createClient at runtime)
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
// Verify server is now running
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Phase 4: create the first client via the hub
|
||||
clientBundle = await server.createClient({ clientId: 'test-client-0' });
|
||||
expect(clientBundle.secrets.noisePrivateKey).toBeTypeofString();
|
||||
expect(clientBundle.smartvpnConfig.clientPublicKey).toBeTypeofString();
|
||||
});
|
||||
|
||||
tap.test('single client connects and gets IP', async () => {
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: clientBundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: clientBundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
// Verify client status
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
// Verify server sees the client
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 1;
|
||||
});
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(1);
|
||||
expect(clients[0].assignedIp).toEqual(result.assignedIp);
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange', async () => {
|
||||
// Wait for at least 2 keepalive cycles (interval=3s, so 8s should be enough)
|
||||
await delay(8000);
|
||||
|
||||
const clientStats = await client.getStatistics();
|
||||
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
const serverStats = await server.getStatistics();
|
||||
expect(serverStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
expect(serverStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Verify per-client keepalive tracking
|
||||
const clients = await server.listClients();
|
||||
expect(clients[0].keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('connection quality telemetry', async () => {
|
||||
const quality = await client.getConnectionQuality();
|
||||
|
||||
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.jitterMs).toBeTypeofNumber();
|
||||
expect(quality.minRttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.maxRttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.lossRatio).toBeTypeofNumber();
|
||||
expect(['healthy', 'degraded', 'critical']).toContain(quality.linkHealth);
|
||||
});
|
||||
|
||||
tap.test('rate limiting: set and verify', async () => {
|
||||
const clients = await server.listClients();
|
||||
const clientId = clients[0].clientId;
|
||||
|
||||
// Set a tight rate limit
|
||||
await server.setClientRateLimit(clientId, 100, 100);
|
||||
|
||||
// Verify via telemetry
|
||||
const telemetry = await server.getClientTelemetry(clientId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toEqual(100);
|
||||
expect(telemetry.burstBytes).toEqual(100);
|
||||
expect(telemetry.clientId).toEqual(clientId);
|
||||
});
|
||||
|
||||
tap.test('rate limiting: removal', async () => {
|
||||
const clients = await server.listClients();
|
||||
const clientId = clients[0].clientId;
|
||||
|
||||
await server.removeClientRateLimit(clientId);
|
||||
|
||||
// Verify telemetry no longer shows rate limit
|
||||
const telemetry = await server.getClientTelemetry(clientId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
|
||||
expect(telemetry.burstBytes).toBeNullOrUndefined();
|
||||
|
||||
// Connection still healthy
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('5 concurrent clients', async () => {
|
||||
const assignedIps = new Set<string>();
|
||||
|
||||
// Get the first client's IP
|
||||
const existingClients = await server.listClients();
|
||||
assignedIps.add(existingClients[0].assignedIp);
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const bundle = await server.createClient({ clientId: `test-client-${i + 1}` });
|
||||
extraBundles.push(bundle);
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
const result = await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
assignedIps.add(result.assignedIp);
|
||||
extraClients.push(c);
|
||||
}
|
||||
|
||||
// All IPs should be unique (6 total: original + 5 new)
|
||||
expect(assignedIps.size).toEqual(6);
|
||||
|
||||
// Server should see 6 clients
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 6;
|
||||
});
|
||||
const allClients = await server.listClients();
|
||||
expect(allClients.length).toEqual(6);
|
||||
});
|
||||
|
||||
tap.test('client disconnect tracking', async () => {
|
||||
// Disconnect 3 of the 5 extra clients
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const c = extraClients[i];
|
||||
await c.disconnect();
|
||||
c.stop();
|
||||
}
|
||||
|
||||
// Wait for server to detect disconnections
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 3;
|
||||
}, 15000);
|
||||
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(3);
|
||||
|
||||
const stats = await server.getStatistics();
|
||||
expect(stats.totalConnections).toBeGreaterThanOrEqual(6);
|
||||
});
|
||||
|
||||
tap.test('server-side client disconnection', async () => {
|
||||
const clients = await server.listClients();
|
||||
// Pick one of the remaining extra clients (not the original)
|
||||
const targetClient = clients.find((c) => {
|
||||
// Find a client that belongs to extraClients[3] or extraClients[4]
|
||||
return c.clientId !== clients[0].clientId;
|
||||
});
|
||||
expect(targetClient).toBeTruthy();
|
||||
|
||||
await server.disconnectClient(targetClient!.clientId);
|
||||
|
||||
// Wait for server to update
|
||||
await waitFor(async () => {
|
||||
const remaining = await server.listClients();
|
||||
return remaining.length === 2;
|
||||
});
|
||||
|
||||
const remaining = await server.listClients();
|
||||
expect(remaining.length).toEqual(2);
|
||||
});
|
||||
|
||||
tap.test('teardown: stop all', async () => {
|
||||
// Stop the original client
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
|
||||
// Stop remaining extra clients
|
||||
for (const c of extraClients) {
|
||||
if (c.running) {
|
||||
try {
|
||||
await c.disconnect();
|
||||
} catch {
|
||||
// May already be disconnected
|
||||
}
|
||||
c.stop();
|
||||
}
|
||||
}
|
||||
|
||||
await delay(500);
|
||||
|
||||
// Stop the server
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
await delay(500);
|
||||
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
362
test/test.loadtest.node.ts
Normal file
362
test/test.loadtest.node.ts
Normal file
@@ -0,0 +1,362 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as stream from 'stream';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ThrottleProxy (adapted from remoteingress)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class ThrottleTransform extends stream.Transform {
|
||||
private bytesPerSec: number;
|
||||
private bucket: number;
|
||||
private lastRefill: number;
|
||||
private destroyed_: boolean = false;
|
||||
|
||||
constructor(bytesPerSecond: number) {
|
||||
super();
|
||||
this.bytesPerSec = bytesPerSecond;
|
||||
this.bucket = bytesPerSecond;
|
||||
this.lastRefill = Date.now();
|
||||
}
|
||||
|
||||
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: stream.TransformCallback) {
|
||||
if (this.destroyed_) return;
|
||||
|
||||
const now = Date.now();
|
||||
const elapsed = (now - this.lastRefill) / 1000;
|
||||
this.bucket = Math.min(this.bytesPerSec, this.bucket + elapsed * this.bytesPerSec);
|
||||
this.lastRefill = now;
|
||||
|
||||
if (chunk.length <= this.bucket) {
|
||||
this.bucket -= chunk.length;
|
||||
callback(null, chunk);
|
||||
} else {
|
||||
const deficit = chunk.length - this.bucket;
|
||||
this.bucket = 0;
|
||||
const delayMs = Math.min((deficit / this.bytesPerSec) * 1000, 1000);
|
||||
setTimeout(() => {
|
||||
if (this.destroyed_) { callback(); return; }
|
||||
this.lastRefill = Date.now();
|
||||
this.bucket = 0;
|
||||
callback(null, chunk);
|
||||
}, delayMs);
|
||||
}
|
||||
}
|
||||
|
||||
_destroy(err: Error | null, callback: (error: Error | null) => void) {
|
||||
this.destroyed_ = true;
|
||||
callback(err);
|
||||
}
|
||||
}
|
||||
|
||||
interface ThrottleProxy {
|
||||
server: net.Server;
|
||||
close: () => Promise<void>;
|
||||
}
|
||||
|
||||
async function startThrottleProxy(
|
||||
listenPort: number,
|
||||
targetHost: string,
|
||||
targetPort: number,
|
||||
bytesPerSecond: number,
|
||||
): Promise<ThrottleProxy> {
|
||||
const connections = new Set<net.Socket>();
|
||||
const server = net.createServer((clientSock) => {
|
||||
connections.add(clientSock);
|
||||
const upstream = net.createConnection({ host: targetHost, port: targetPort });
|
||||
connections.add(upstream);
|
||||
|
||||
const throttleUp = new ThrottleTransform(bytesPerSecond);
|
||||
const throttleDown = new ThrottleTransform(bytesPerSecond);
|
||||
|
||||
clientSock.pipe(throttleUp).pipe(upstream);
|
||||
upstream.pipe(throttleDown).pipe(clientSock);
|
||||
|
||||
let cleaned = false;
|
||||
const cleanup = () => {
|
||||
if (cleaned) return;
|
||||
cleaned = true;
|
||||
throttleUp.destroy();
|
||||
throttleDown.destroy();
|
||||
clientSock.destroy();
|
||||
upstream.destroy();
|
||||
connections.delete(clientSock);
|
||||
connections.delete(upstream);
|
||||
};
|
||||
clientSock.on('error', () => cleanup());
|
||||
upstream.on('error', () => cleanup());
|
||||
throttleUp.on('error', () => cleanup());
|
||||
throttleDown.on('error', () => cleanup());
|
||||
clientSock.on('close', () => cleanup());
|
||||
upstream.on('close', () => cleanup());
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => server.listen(listenPort, '127.0.0.1', resolve));
|
||||
return {
|
||||
server,
|
||||
close: async () => {
|
||||
for (const c of connections) c.destroy();
|
||||
connections.clear();
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let proxyPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let throttle: ThrottleProxy;
|
||||
const allClients: VpnClient[] = [];
|
||||
|
||||
let clientCounter = 0;
|
||||
async function createConnectedClient(port: number): Promise<VpnClient> {
|
||||
clientCounter++;
|
||||
const bundle = await server.createClient({ clientId: `load-client-${clientCounter}` });
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${port}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
allClients.push(c);
|
||||
return c;
|
||||
}
|
||||
|
||||
async function stopClient(c: VpnClient): Promise<void> {
|
||||
if (c.running) {
|
||||
try { await c.disconnect(); } catch { /* already disconnected */ }
|
||||
c.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start throttled VPN tunnel (1 MB/s)', async () => {
|
||||
serverPort = await findFreePort();
|
||||
proxyPort = await findFreePort();
|
||||
|
||||
// Start VPN server
|
||||
server = new VpnServer({ transport: { transport: 'stdio' } });
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
keypair = await server.generateKeypair();
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Start throttle proxy: 1 MB/s
|
||||
throttle = await startThrottleProxy(proxyPort, '127.0.0.1', serverPort, 1024 * 1024);
|
||||
});
|
||||
|
||||
tap.test('throttled connection: handshake succeeds through throttle', async () => {
|
||||
const client = await createConnectedClient(proxyPort);
|
||||
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
expect(status.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 1;
|
||||
});
|
||||
});
|
||||
|
||||
tap.test('sustained keepalive under throttle', async () => {
|
||||
// Wait for at least 1 keepalive cycle (3s interval)
|
||||
await delay(4000);
|
||||
|
||||
const client = allClients[0];
|
||||
const stats = await client.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Throttle adds latency — RTT should be measurable
|
||||
const quality = await client.getConnectionQuality();
|
||||
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.jitterMs).toBeTypeofNumber();
|
||||
});
|
||||
|
||||
tap.test('3 concurrent throttled clients', async () => {
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await createConnectedClient(proxyPort);
|
||||
}
|
||||
|
||||
// All 4 clients should be visible
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 4;
|
||||
});
|
||||
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(4);
|
||||
|
||||
// Verify all IPs are unique
|
||||
const ips = new Set(clients.map((c) => c.assignedIp));
|
||||
expect(ips.size).toEqual(4);
|
||||
});
|
||||
|
||||
tap.test('rate limiting combined with network throttle', async () => {
|
||||
const clients = await server.listClients();
|
||||
const targetId = clients[0].clientId;
|
||||
|
||||
// Set rate limit on first client
|
||||
await server.setClientRateLimit(targetId, 500, 500);
|
||||
const telemetry = await server.getClientTelemetry(targetId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toEqual(500);
|
||||
expect(telemetry.burstBytes).toEqual(500);
|
||||
|
||||
// Verify another client has no rate limit
|
||||
const otherTelemetry = await server.getClientTelemetry(clients[1].clientId);
|
||||
expect(otherTelemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
|
||||
|
||||
// Clean up the rate limit
|
||||
await server.removeClientRateLimit(targetId);
|
||||
});
|
||||
|
||||
tap.test('burst waves: 2 waves of 2 clients', async () => {
|
||||
const initialCount = (await server.listClients()).length;
|
||||
|
||||
for (let wave = 0; wave < 2; wave++) {
|
||||
const waveClients: VpnClient[] = [];
|
||||
|
||||
// Connect 2 clients
|
||||
for (let i = 0; i < 2; i++) {
|
||||
const c = await createConnectedClient(proxyPort);
|
||||
waveClients.push(c);
|
||||
}
|
||||
|
||||
// Verify all connected
|
||||
await waitFor(async () => {
|
||||
const all = await server.listClients();
|
||||
return all.length === initialCount + 2;
|
||||
});
|
||||
|
||||
// Disconnect all wave clients
|
||||
for (const c of waveClients) {
|
||||
await stopClient(c);
|
||||
}
|
||||
|
||||
// Wait for server to detect disconnections
|
||||
await waitFor(async () => {
|
||||
const all = await server.listClients();
|
||||
return all.length === initialCount;
|
||||
}, 15000);
|
||||
|
||||
await delay(500);
|
||||
}
|
||||
|
||||
// Verify total connections accumulated
|
||||
const stats = await server.getStatistics();
|
||||
expect(stats.totalConnections).toBeGreaterThanOrEqual(4 + initialCount);
|
||||
|
||||
// Original clients still connected
|
||||
const remaining = await server.listClients();
|
||||
expect(remaining.length).toEqual(initialCount);
|
||||
});
|
||||
|
||||
tap.test('aggressive throttle: 10 KB/s', async () => {
|
||||
// Close current throttle proxy and start an aggressive one
|
||||
await throttle.close();
|
||||
const aggressivePort = await findFreePort();
|
||||
throttle = await startThrottleProxy(aggressivePort, '127.0.0.1', serverPort, 10 * 1024);
|
||||
|
||||
// Connect a client through the aggressive throttle
|
||||
const client = await createConnectedClient(aggressivePort);
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Wait for keepalive exchange (might take longer due to throttle)
|
||||
await delay(4000);
|
||||
|
||||
const stats = await client.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('post-load health: direct connection still works', async () => {
|
||||
// Server should still be healthy after all load tests
|
||||
const serverStatus = await server.getStatus();
|
||||
expect(serverStatus.state).toEqual('connected');
|
||||
|
||||
// Connect one more client directly (no throttle)
|
||||
const directClient = await createConnectedClient(serverPort);
|
||||
const status = await directClient.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
await delay(3500);
|
||||
|
||||
const stats = await directClient.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('teardown: stop all', async () => {
|
||||
// Stop all clients
|
||||
for (const c of allClients) {
|
||||
await stopClient(c);
|
||||
}
|
||||
|
||||
await delay(500);
|
||||
|
||||
// Close throttle proxy
|
||||
if (throttle) {
|
||||
await throttle.close();
|
||||
}
|
||||
|
||||
// Stop server
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
await delay(500);
|
||||
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
258
test/test.quic.node.ts
Normal file
258
test/test.quic.node.ts
Normal file
@@ -0,0 +1,258 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as dgram from 'dgram';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
async function findFreeUdpPort(): Promise<number> {
|
||||
const sock = dgram.createSocket('udp4');
|
||||
await new Promise<void>((resolve) => sock.bind(0, '127.0.0.1', resolve));
|
||||
const port = (sock.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => sock.close(resolve));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let wsPort: number;
|
||||
let quicPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: QUIC-only server + QUIC client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start VPN server in QUIC mode', async () => {
|
||||
quicPort = await findFreeUdpPort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
keypair = await server.generateKeypair();
|
||||
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${quicPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.9.0.0/24',
|
||||
transportMode: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('QUIC client connects and gets IP', async () => {
|
||||
const bundle = await server.createClient({ clientId: 'quic-client-1' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${quicPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.9.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
// Verify server sees the client
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length >= 1;
|
||||
});
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('teardown: stop QUIC server', async () => {
|
||||
await server.stop();
|
||||
await delay(500);
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: dual-mode server (both) + auto client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let dualServer: VpnServer;
|
||||
let dualWsPort: number;
|
||||
let dualQuicPort: number;
|
||||
let dualKeypair: IVpnKeypair;
|
||||
|
||||
tap.test('setup: start VPN server in both mode', async () => {
|
||||
dualWsPort = await findFreePort();
|
||||
dualQuicPort = await findFreeUdpPort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
dualServer = new VpnServer(options);
|
||||
|
||||
const started = await dualServer['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
dualKeypair = await dualServer.generateKeypair();
|
||||
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${dualWsPort}`,
|
||||
privateKey: dualKeypair.privateKey,
|
||||
publicKey: dualKeypair.publicKey,
|
||||
subnet: '10.10.0.0/24',
|
||||
transportMode: 'both',
|
||||
quicListenAddr: `127.0.0.1:${dualQuicPort}`,
|
||||
keepaliveIntervalSecs: 3,
|
||||
};
|
||||
await dualServer['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await dualServer.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('auto client connects to dual-mode server (QUIC preferred)', async () => {
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-auto-client' });
|
||||
|
||||
// "auto" mode (default): tries QUIC first at same host:port, falls back to WS
|
||||
// Since the WS port and QUIC port differ, auto will try QUIC on WS port (fail),
|
||||
// then fall back to WebSocket
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${dualWsPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
// transport defaults to 'auto'
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.10.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
await waitFor(async () => {
|
||||
const clients = await dualServer.listClients();
|
||||
return clients.length >= 1;
|
||||
});
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('explicit QUIC client connects to dual-mode server', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-quic-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.10.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange over QUIC', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-keepalive-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
await client.start();
|
||||
|
||||
await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
// Wait for keepalive exchange
|
||||
await delay(8000);
|
||||
|
||||
const clientStats = await client.getStatistics();
|
||||
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('teardown: stop dual-mode server', async () => {
|
||||
await dualServer.stop();
|
||||
await delay(500);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -2,10 +2,17 @@ import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { VpnConfig } from '../ts/index.js';
|
||||
import type { IVpnClientConfig, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// Valid 32-byte base64 keys for testing
|
||||
const TEST_KEY_A = 'YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE=';
|
||||
const TEST_KEY_B = 'YmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmI=';
|
||||
const TEST_KEY_C = 'Y2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2M=';
|
||||
|
||||
tap.test('VpnConfig: validate valid client config', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
@@ -16,7 +23,9 @@ tap.test('VpnConfig: validate valid client config', async () => {
|
||||
|
||||
tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
const config = {
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -31,7 +40,9 @@ tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'http://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -43,10 +54,28 @@ tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config without clientPrivateKey', async () => {
|
||||
const config = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('clientPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
mtu: 100,
|
||||
};
|
||||
let threw = false;
|
||||
@@ -62,7 +91,9 @@ tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['not-an-ip'],
|
||||
};
|
||||
let threw = false;
|
||||
@@ -78,12 +109,15 @@ tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
tap.test('VpnConfig: validate valid server config', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
enableNat: true,
|
||||
clients: [
|
||||
{ clientId: 'test-client', publicKey: TEST_KEY_C },
|
||||
],
|
||||
};
|
||||
// Should not throw
|
||||
VpnConfig.validateServerConfig(config);
|
||||
@@ -92,8 +126,8 @@ tap.test('VpnConfig: validate valid server config', async () => {
|
||||
tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: 'invalid',
|
||||
};
|
||||
let threw = false;
|
||||
@@ -109,7 +143,7 @@ tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
const config = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
} as IVpnServerConfig;
|
||||
let threw = false;
|
||||
@@ -122,4 +156,24 @@ tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject server config with invalid client publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
clients: [
|
||||
{ clientId: 'bad-client', publicKey: 'short-key' },
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
|
||||
353
test/test.wireguard.node.ts
Normal file
353
test/test.wireguard.node.ts
Normal file
@@ -0,0 +1,353 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import {
|
||||
VpnConfig,
|
||||
VpnServer,
|
||||
WgConfigGenerator,
|
||||
} from '../ts/index.js';
|
||||
import type {
|
||||
IVpnClientConfig,
|
||||
IVpnServerConfig,
|
||||
IVpnServerOptions,
|
||||
IWgPeerConfig,
|
||||
} from '../ts/index.js';
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config validation — client
|
||||
// ============================================================================
|
||||
|
||||
// A valid 32-byte key in base64 (44 chars)
|
||||
const VALID_KEY = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=';
|
||||
const VALID_KEY_2 = 'BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB=';
|
||||
|
||||
tap.test('WG client config: valid wireguard config passes validation', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '', // not needed for WG
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
wgAllowedIps: ['0.0.0.0/0'],
|
||||
};
|
||||
VpnConfig.validateClientConfig(config);
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgPrivateKey', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgAddress', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgAddress');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgEndpoint', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgEndpoint');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects invalid key length', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: 'tooshort',
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('44 characters');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects invalid CIDR in allowedIps', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
wgAllowedIps: ['not-a-cidr'],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('CIDR');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config validation — server
|
||||
// ============================================================================
|
||||
|
||||
tap.test('WG server config: valid config passes validation', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: VALID_KEY_2,
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
VpnConfig.validateServerConfig(config);
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects empty wgPeers', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgPeers');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects peer without publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: '',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects invalid wgListenPort', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgListenPort: 0,
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: VALID_KEY_2,
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgListenPort');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard keypair generation via daemon
|
||||
// ============================================================================
|
||||
|
||||
let server: VpnServer;
|
||||
|
||||
tap.test('WG: spawn server daemon for keypair generation', async () => {
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
expect(server.running).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG: generateWgKeypair returns valid keypair', async () => {
|
||||
const keypair = await server.generateWgKeypair();
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
// WireGuard keys: base64 of 32 bytes = 44 characters
|
||||
expect(keypair.publicKey.length).toEqual(44);
|
||||
expect(keypair.privateKey.length).toEqual(44);
|
||||
// Verify they decode to 32 bytes
|
||||
const pubBuf = Buffer.from(keypair.publicKey, 'base64');
|
||||
const privBuf = Buffer.from(keypair.privateKey, 'base64');
|
||||
expect(pubBuf.length).toEqual(32);
|
||||
expect(privBuf.length).toEqual(32);
|
||||
});
|
||||
|
||||
tap.test('WG: generateWgKeypair returns unique keys each time', async () => {
|
||||
const kp1 = await server.generateWgKeypair();
|
||||
const kp2 = await server.generateWgKeypair();
|
||||
expect(kp1.publicKey).not.toEqual(kp2.publicKey);
|
||||
expect(kp1.privateKey).not.toEqual(kp2.privateKey);
|
||||
});
|
||||
|
||||
tap.test('WG: stop server daemon', async () => {
|
||||
server.stop();
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config file generation
|
||||
// ============================================================================
|
||||
|
||||
tap.test('WgConfigGenerator: generate client config', async () => {
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: 'clientPrivateKeyBase64====================',
|
||||
address: '10.8.0.2/24',
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
peer: {
|
||||
publicKey: 'serverPublicKeyBase64====================',
|
||||
endpoint: 'vpn.example.com:51820',
|
||||
allowedIps: ['0.0.0.0/0', '::/0'],
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).toContain('PrivateKey = clientPrivateKeyBase64====================');
|
||||
expect(conf).toContain('Address = 10.8.0.2/24');
|
||||
expect(conf).toContain('DNS = 1.1.1.1, 8.8.8.8');
|
||||
expect(conf).toContain('MTU = 1420');
|
||||
expect(conf).toContain('[Peer]');
|
||||
expect(conf).toContain('PublicKey = serverPublicKeyBase64====================');
|
||||
expect(conf).toContain('Endpoint = vpn.example.com:51820');
|
||||
expect(conf).toContain('AllowedIPs = 0.0.0.0/0, ::/0');
|
||||
expect(conf).toContain('PersistentKeepalive = 25');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate client config without optional fields', async () => {
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: 'key1',
|
||||
address: '10.0.0.2/32',
|
||||
peer: {
|
||||
publicKey: 'key2',
|
||||
endpoint: 'server:51820',
|
||||
allowedIps: ['10.0.0.0/24'],
|
||||
},
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).not.toContain('DNS');
|
||||
expect(conf).not.toContain('MTU');
|
||||
expect(conf).not.toContain('PresharedKey');
|
||||
expect(conf).not.toContain('PersistentKeepalive');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate server config with NAT', async () => {
|
||||
const conf = WgConfigGenerator.generateServerConfig({
|
||||
privateKey: 'serverPrivKey',
|
||||
address: '10.8.0.1/24',
|
||||
listenPort: 51820,
|
||||
dns: ['1.1.1.1'],
|
||||
enableNat: true,
|
||||
natInterface: 'ens3',
|
||||
peers: [
|
||||
{
|
||||
publicKey: 'peer1PubKey',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
presharedKey: 'psk1',
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
{
|
||||
publicKey: 'peer2PubKey',
|
||||
allowedIps: ['10.8.0.3/32'],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).toContain('ListenPort = 51820');
|
||||
expect(conf).toContain('PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ens3 -j MASQUERADE');
|
||||
expect(conf).toContain('PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ens3 -j MASQUERADE');
|
||||
// Two [Peer] sections
|
||||
const peerCount = (conf.match(/\[Peer\]/g) || []).length;
|
||||
expect(peerCount).toEqual(2);
|
||||
expect(conf).toContain('PresharedKey = psk1');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate server config without NAT', async () => {
|
||||
const conf = WgConfigGenerator.generateServerConfig({
|
||||
privateKey: 'serverPrivKey',
|
||||
address: '10.8.0.1/24',
|
||||
listenPort: 51820,
|
||||
peers: [
|
||||
{
|
||||
publicKey: 'peerKey',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(conf).not.toContain('PostUp');
|
||||
expect(conf).not.toContain('PostDown');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartvpn',
|
||||
version: '1.0.1',
|
||||
version: '1.10.1',
|
||||
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
|
||||
}
|
||||
|
||||
@@ -4,3 +4,4 @@ export { VpnClient } from './smartvpn.classes.vpnclient.js';
|
||||
export { VpnServer } from './smartvpn.classes.vpnserver.js';
|
||||
export { VpnConfig } from './smartvpn.classes.vpnconfig.js';
|
||||
export { VpnInstaller } from './smartvpn.classes.vpninstaller.js';
|
||||
export { WgConfigGenerator } from './smartvpn.classes.wgconfig.js';
|
||||
|
||||
@@ -5,6 +5,8 @@ import type {
|
||||
IVpnClientConfig,
|
||||
IVpnStatus,
|
||||
IVpnStatistics,
|
||||
IVpnConnectionQuality,
|
||||
IVpnMtuInfo,
|
||||
TVpnClientCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -65,12 +67,26 @@ export class VpnClient extends plugins.events.EventEmitter {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get traffic statistics.
|
||||
* Get traffic statistics (includes connection quality when connected).
|
||||
*/
|
||||
public async getStatistics(): Promise<IVpnStatistics> {
|
||||
return this.bridge.sendCommand('getStatistics', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get connection quality metrics (RTT, jitter, loss, link health).
|
||||
*/
|
||||
public async getConnectionQuality(): Promise<IVpnConnectionQuality> {
|
||||
return this.bridge.sendCommand('getConnectionQuality', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get MTU information (overhead, effective MTU, oversized packet stats).
|
||||
*/
|
||||
public async getMtuInfo(): Promise<IVpnMtuInfo> {
|
||||
return this.bridge.sendCommand('getMtuInfo', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
@@ -12,14 +12,54 @@ export class VpnConfig {
|
||||
* Validate a client config object. Throws on invalid config.
|
||||
*/
|
||||
public static validateClientConfig(config: IVpnClientConfig): void {
|
||||
if (!config.serverUrl) {
|
||||
throw new Error('VpnConfig: serverUrl is required');
|
||||
}
|
||||
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
|
||||
throw new Error('VpnConfig: serverUrl must start with wss:// or ws://');
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
if (config.transport === 'wireguard') {
|
||||
// WireGuard-specific validation
|
||||
if (!config.wgPrivateKey) {
|
||||
throw new Error('VpnConfig: wgPrivateKey is required for WireGuard transport');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.wgPrivateKey, 'wgPrivateKey');
|
||||
if (!config.wgAddress) {
|
||||
throw new Error('VpnConfig: wgAddress is required for WireGuard transport');
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required for WireGuard transport');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.serverPublicKey, 'serverPublicKey');
|
||||
if (!config.wgEndpoint) {
|
||||
throw new Error('VpnConfig: wgEndpoint is required for WireGuard transport');
|
||||
}
|
||||
if (config.wgPresharedKey) {
|
||||
VpnConfig.validateBase64Key(config.wgPresharedKey, 'wgPresharedKey');
|
||||
}
|
||||
if (config.wgAllowedIps) {
|
||||
for (const cidr of config.wgAllowedIps) {
|
||||
if (!VpnConfig.isValidCidr(cidr)) {
|
||||
throw new Error(`VpnConfig: invalid allowedIp CIDR: ${cidr}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!config.serverUrl) {
|
||||
throw new Error('VpnConfig: serverUrl is required');
|
||||
}
|
||||
// For QUIC-only transport, serverUrl is a host:port address; for WebSocket/auto it must be ws:// or wss://
|
||||
if (config.transport !== 'quic') {
|
||||
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
|
||||
throw new Error('VpnConfig: serverUrl must start with wss:// or ws:// (for WebSocket transport)');
|
||||
}
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
}
|
||||
// Noise IK requires client keypair
|
||||
if (!config.clientPrivateKey) {
|
||||
throw new Error('VpnConfig: clientPrivateKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPrivateKey, 'clientPrivateKey');
|
||||
if (!config.clientPublicKey) {
|
||||
throw new Error('VpnConfig: clientPublicKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPublicKey, 'clientPublicKey');
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -40,20 +80,63 @@ export class VpnConfig {
|
||||
* Validate a server config object. Throws on invalid config.
|
||||
*/
|
||||
public static validateServerConfig(config: IVpnServerConfig): void {
|
||||
if (!config.listenAddr) {
|
||||
throw new Error('VpnConfig: listenAddr is required');
|
||||
}
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
if (!config.publicKey) {
|
||||
throw new Error('VpnConfig: publicKey is required');
|
||||
}
|
||||
if (!config.subnet) {
|
||||
throw new Error('VpnConfig: subnet is required');
|
||||
}
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
if (config.transportMode === 'wireguard') {
|
||||
// WireGuard server validation
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.privateKey, 'privateKey');
|
||||
if (!config.wgPeers || config.wgPeers.length === 0) {
|
||||
throw new Error('VpnConfig: at least one wgPeers entry is required for WireGuard mode');
|
||||
}
|
||||
for (const peer of config.wgPeers) {
|
||||
if (!peer.publicKey) {
|
||||
throw new Error('VpnConfig: peer publicKey is required');
|
||||
}
|
||||
VpnConfig.validateBase64Key(peer.publicKey, 'peer.publicKey');
|
||||
if (!peer.allowedIps || peer.allowedIps.length === 0) {
|
||||
throw new Error('VpnConfig: peer allowedIps is required');
|
||||
}
|
||||
for (const cidr of peer.allowedIps) {
|
||||
if (!VpnConfig.isValidCidr(cidr)) {
|
||||
throw new Error(`VpnConfig: invalid peer allowedIp CIDR: ${cidr}`);
|
||||
}
|
||||
}
|
||||
if (peer.presharedKey) {
|
||||
VpnConfig.validateBase64Key(peer.presharedKey, 'peer.presharedKey');
|
||||
}
|
||||
}
|
||||
if (config.wgListenPort !== undefined && (config.wgListenPort < 1 || config.wgListenPort > 65535)) {
|
||||
throw new Error('VpnConfig: wgListenPort must be between 1 and 65535');
|
||||
}
|
||||
} else {
|
||||
if (!config.listenAddr) {
|
||||
throw new Error('VpnConfig: listenAddr is required');
|
||||
}
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
if (!config.publicKey) {
|
||||
throw new Error('VpnConfig: publicKey is required');
|
||||
}
|
||||
if (!config.subnet) {
|
||||
throw new Error('VpnConfig: subnet is required');
|
||||
}
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
}
|
||||
// Validate client entries if provided
|
||||
if (config.clients) {
|
||||
for (const client of config.clients) {
|
||||
if (!client.clientId) {
|
||||
throw new Error('VpnConfig: client entry must have a clientId');
|
||||
}
|
||||
if (!client.publicKey) {
|
||||
throw new Error(`VpnConfig: client '${client.clientId}' must have a publicKey`);
|
||||
}
|
||||
VpnConfig.validateBase64Key(client.publicKey, `client '${client.clientId}' publicKey`);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -101,4 +184,41 @@ export class VpnConfig {
|
||||
const prefixNum = parseInt(prefix, 10);
|
||||
return !isNaN(prefixNum) && prefixNum >= 0 && prefixNum <= 32;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a CIDR string (IPv4 or IPv6).
|
||||
*/
|
||||
private static isValidCidr(cidr: string): boolean {
|
||||
const parts = cidr.split('/');
|
||||
if (parts.length !== 2) return false;
|
||||
const prefixNum = parseInt(parts[1], 10);
|
||||
if (isNaN(prefixNum) || prefixNum < 0) return false;
|
||||
// IPv4
|
||||
if (VpnConfig.isValidIp(parts[0])) {
|
||||
return prefixNum <= 32;
|
||||
}
|
||||
// IPv6 (basic check)
|
||||
if (parts[0].includes(':')) {
|
||||
return prefixNum <= 128;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a base64-encoded 32-byte key (WireGuard X25519 format).
|
||||
*/
|
||||
private static validateBase64Key(key: string, fieldName: string): void {
|
||||
if (key.length !== 44) {
|
||||
throw new Error(`VpnConfig: ${fieldName} must be 44 characters (base64 of 32 bytes), got ${key.length}`);
|
||||
}
|
||||
try {
|
||||
const buf = Buffer.from(key, 'base64');
|
||||
if (buf.length !== 32) {
|
||||
throw new Error(`VpnConfig: ${fieldName} must decode to 32 bytes, got ${buf.length}`);
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.message.startsWith('VpnConfig:')) throw e;
|
||||
throw new Error(`VpnConfig: ${fieldName} is not valid base64`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,11 @@ import type {
|
||||
IVpnServerStatistics,
|
||||
IVpnClientInfo,
|
||||
IVpnKeypair,
|
||||
IVpnClientTelemetry,
|
||||
IWgPeerConfig,
|
||||
IWgPeerInfo,
|
||||
IClientEntry,
|
||||
IClientConfigBundle,
|
||||
TVpnServerCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -91,6 +96,139 @@ export class VpnServer extends plugins.events.EventEmitter {
|
||||
return this.bridge.sendCommand('generateKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set rate limit for a specific client.
|
||||
*/
|
||||
public async setClientRateLimit(
|
||||
clientId: string,
|
||||
rateBytesPerSec: number,
|
||||
burstBytes: number,
|
||||
): Promise<void> {
|
||||
await this.bridge.sendCommand('setClientRateLimit', {
|
||||
clientId,
|
||||
rateBytesPerSec,
|
||||
burstBytes,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove rate limit for a specific client (unlimited).
|
||||
*/
|
||||
public async removeClientRateLimit(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeClientRateLimit', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get telemetry for a specific client.
|
||||
*/
|
||||
public async getClientTelemetry(clientId: string): Promise<IVpnClientTelemetry> {
|
||||
return this.bridge.sendCommand('getClientTelemetry', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a WireGuard-compatible X25519 keypair.
|
||||
*/
|
||||
public async generateWgKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateWgKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a WireGuard peer (server must be running in wireguard mode).
|
||||
*/
|
||||
public async addWgPeer(peer: IWgPeerConfig): Promise<void> {
|
||||
await this.bridge.sendCommand('addWgPeer', { peer });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a WireGuard peer by public key.
|
||||
*/
|
||||
public async removeWgPeer(publicKey: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeWgPeer', { publicKey });
|
||||
}
|
||||
|
||||
/**
|
||||
* List WireGuard peers with stats.
|
||||
*/
|
||||
public async listWgPeers(): Promise<IWgPeerInfo[]> {
|
||||
const result = await this.bridge.sendCommand('listWgPeers', {} as Record<string, never>);
|
||||
return result.peers;
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ─────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Create a new client. Generates keypairs, assigns IP, returns full config bundle.
|
||||
* The secrets (private keys) are only returned at creation time.
|
||||
*/
|
||||
public async createClient(opts: Partial<IClientEntry>): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('createClient', { client: opts });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a registered client (also disconnects if connected).
|
||||
*/
|
||||
public async removeClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a registered client by ID.
|
||||
*/
|
||||
public async getClient(clientId: string): Promise<IClientEntry> {
|
||||
return this.bridge.sendCommand('getClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* List all registered clients.
|
||||
*/
|
||||
public async listRegisteredClients(): Promise<IClientEntry[]> {
|
||||
const result = await this.bridge.sendCommand('listRegisteredClients', {} as Record<string, never>);
|
||||
return result.clients;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a registered client's fields (ACLs, tags, description, etc.).
|
||||
*/
|
||||
public async updateClient(clientId: string, update: Partial<IClientEntry>): Promise<void> {
|
||||
await this.bridge.sendCommand('updateClient', { clientId, update });
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a previously disabled client.
|
||||
*/
|
||||
public async enableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('enableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable a client (also disconnects if connected).
|
||||
*/
|
||||
public async disableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('disableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
*/
|
||||
public async rotateClientKey(clientId: string): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('rotateClientKey', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Export a client config (without secrets) in the specified format.
|
||||
*/
|
||||
public async exportClientConfig(clientId: string, format: 'smartvpn' | 'wireguard'): Promise<string> {
|
||||
const result = await this.bridge.sendCommand('exportClientConfig', { clientId, format });
|
||||
return result.config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a standalone Noise IK keypair (not tied to a client).
|
||||
*/
|
||||
public async generateClientKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateClientKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
123
ts/smartvpn.classes.wgconfig.ts
Normal file
123
ts/smartvpn.classes.wgconfig.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
import type { IWgPeerConfig } from './smartvpn.interfaces.js';
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard .conf file generator
|
||||
// ============================================================================
|
||||
|
||||
export interface IWgClientConfOptions {
|
||||
/** Client private key (base64) */
|
||||
privateKey: string;
|
||||
/** Client TUN address with prefix (e.g. 10.8.0.2/24) */
|
||||
address: string;
|
||||
/** DNS servers */
|
||||
dns?: string[];
|
||||
/** TUN MTU */
|
||||
mtu?: number;
|
||||
/** Server peer config */
|
||||
peer: {
|
||||
publicKey: string;
|
||||
presharedKey?: string;
|
||||
endpoint: string;
|
||||
allowedIps: string[];
|
||||
persistentKeepalive?: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface IWgServerConfOptions {
|
||||
/** Server private key (base64) */
|
||||
privateKey: string;
|
||||
/** Server TUN address with prefix (e.g. 10.8.0.1/24) */
|
||||
address: string;
|
||||
/** UDP listen port */
|
||||
listenPort: number;
|
||||
/** DNS servers */
|
||||
dns?: string[];
|
||||
/** TUN MTU */
|
||||
mtu?: number;
|
||||
/** Enable NAT — adds PostUp/PostDown iptables rules */
|
||||
enableNat?: boolean;
|
||||
/** Network interface for NAT (e.g. eth0). Auto-detected if omitted. */
|
||||
natInterface?: string;
|
||||
/** Configured peers */
|
||||
peers: IWgPeerConfig[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates standard WireGuard .conf files compatible with wg-quick,
|
||||
* WireGuard iOS/Android apps, and other standard WireGuard clients.
|
||||
*/
|
||||
export class WgConfigGenerator {
|
||||
/**
|
||||
* Generate a client .conf file content.
|
||||
*/
|
||||
public static generateClientConfig(opts: IWgClientConfOptions): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push('[Interface]');
|
||||
lines.push(`PrivateKey = ${opts.privateKey}`);
|
||||
lines.push(`Address = ${opts.address}`);
|
||||
if (opts.dns && opts.dns.length > 0) {
|
||||
lines.push(`DNS = ${opts.dns.join(', ')}`);
|
||||
}
|
||||
if (opts.mtu) {
|
||||
lines.push(`MTU = ${opts.mtu}`);
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
lines.push('[Peer]');
|
||||
lines.push(`PublicKey = ${opts.peer.publicKey}`);
|
||||
if (opts.peer.presharedKey) {
|
||||
lines.push(`PresharedKey = ${opts.peer.presharedKey}`);
|
||||
}
|
||||
lines.push(`Endpoint = ${opts.peer.endpoint}`);
|
||||
lines.push(`AllowedIPs = ${opts.peer.allowedIps.join(', ')}`);
|
||||
if (opts.peer.persistentKeepalive) {
|
||||
lines.push(`PersistentKeepalive = ${opts.peer.persistentKeepalive}`);
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
return lines.join('\n');
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a server .conf file content.
|
||||
*/
|
||||
public static generateServerConfig(opts: IWgServerConfOptions): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push('[Interface]');
|
||||
lines.push(`PrivateKey = ${opts.privateKey}`);
|
||||
lines.push(`Address = ${opts.address}`);
|
||||
lines.push(`ListenPort = ${opts.listenPort}`);
|
||||
if (opts.dns && opts.dns.length > 0) {
|
||||
lines.push(`DNS = ${opts.dns.join(', ')}`);
|
||||
}
|
||||
if (opts.mtu) {
|
||||
lines.push(`MTU = ${opts.mtu}`);
|
||||
}
|
||||
if (opts.enableNat) {
|
||||
const iface = opts.natInterface || 'eth0';
|
||||
lines.push(`PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ${iface} -j MASQUERADE`);
|
||||
lines.push(`PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ${iface} -j MASQUERADE`);
|
||||
}
|
||||
|
||||
for (const peer of opts.peers) {
|
||||
lines.push('');
|
||||
lines.push('[Peer]');
|
||||
lines.push(`PublicKey = ${peer.publicKey}`);
|
||||
if (peer.presharedKey) {
|
||||
lines.push(`PresharedKey = ${peer.presharedKey}`);
|
||||
}
|
||||
lines.push(`AllowedIPs = ${peer.allowedIps.join(', ')}`);
|
||||
if (peer.endpoint) {
|
||||
lines.push(`Endpoint = ${peer.endpoint}`);
|
||||
}
|
||||
if (peer.persistentKeepalive) {
|
||||
lines.push(`PersistentKeepalive = ${peer.persistentKeepalive}`);
|
||||
}
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
return lines.join('\n');
|
||||
}
|
||||
}
|
||||
@@ -24,14 +24,39 @@ export type TVpnTransportOptions = IVpnTransportStdio | IVpnTransportSocket;
|
||||
export interface IVpnClientConfig {
|
||||
/** Server WebSocket URL, e.g. wss://vpn.example.com/tunnel */
|
||||
serverUrl: string;
|
||||
/** Server's static public key (base64) for Noise NK handshake */
|
||||
/** Server's static public key (base64) for Noise IK handshake */
|
||||
serverPublicKey: string;
|
||||
/** Client's Noise IK private key (base64) — required for SmartVPN native transport */
|
||||
clientPrivateKey: string;
|
||||
/** Client's Noise IK public key (base64) — for reference/display */
|
||||
clientPublicKey: string;
|
||||
/** Optional DNS servers to use while connected */
|
||||
dns?: string[];
|
||||
/** Optional MTU for the TUN device */
|
||||
mtu?: number;
|
||||
/** Keepalive interval in seconds (default: 30) */
|
||||
keepaliveIntervalSecs?: number;
|
||||
/** Transport protocol: 'auto' (default, tries QUIC then WS), 'websocket', 'quic', or 'wireguard' */
|
||||
transport?: 'auto' | 'websocket' | 'quic' | 'wireguard';
|
||||
/** For QUIC: SHA-256 hash of server certificate (base64) for cert pinning */
|
||||
serverCertHash?: string;
|
||||
/** Forwarding mode: 'tun' (TUN device, requires root) or 'testing' (no TUN).
|
||||
* Default: 'testing'. */
|
||||
forwardingMode?: 'tun' | 'testing';
|
||||
/** WireGuard: client private key (base64, X25519) */
|
||||
wgPrivateKey?: string;
|
||||
/** WireGuard: client TUN address (e.g. 10.8.0.2) */
|
||||
wgAddress?: string;
|
||||
/** WireGuard: client TUN address prefix length (default: 24) */
|
||||
wgAddressPrefix?: number;
|
||||
/** WireGuard: preshared key (base64, optional) */
|
||||
wgPresharedKey?: string;
|
||||
/** WireGuard: persistent keepalive interval in seconds */
|
||||
wgPersistentKeepalive?: number;
|
||||
/** WireGuard: server endpoint (host:port, e.g. vpn.example.com:51820) */
|
||||
wgEndpoint?: string;
|
||||
/** WireGuard: allowed IPs (CIDR strings, e.g. ['0.0.0.0/0']) */
|
||||
wgAllowedIps?: string[];
|
||||
}
|
||||
|
||||
export interface IVpnClientOptions {
|
||||
@@ -64,6 +89,32 @@ export interface IVpnServerConfig {
|
||||
keepaliveIntervalSecs?: number;
|
||||
/** Enable NAT/masquerade for client traffic */
|
||||
enableNat?: boolean;
|
||||
/** Forwarding mode: 'tun' (kernel TUN, requires root), 'socket' (userspace NAT),
|
||||
* or 'testing' (monitoring only). Default: 'testing'. */
|
||||
forwardingMode?: 'tun' | 'socket' | 'testing';
|
||||
/** Default rate limit for new clients (bytes/sec). Omit for unlimited. */
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
/** Default burst size for new clients (bytes). Omit for unlimited. */
|
||||
defaultBurstBytes?: number;
|
||||
/** Transport mode: 'both' (default, WS+QUIC), 'websocket', 'quic', or 'wireguard' */
|
||||
transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard';
|
||||
/** QUIC listen address (host:port). Defaults to listenAddr. */
|
||||
quicListenAddr?: string;
|
||||
/** QUIC idle timeout in seconds (default: 30) */
|
||||
quicIdleTimeoutSecs?: number;
|
||||
/** WireGuard: UDP listen port (default: 51820) */
|
||||
wgListenPort?: number;
|
||||
/** WireGuard: configured peers */
|
||||
wgPeers?: IWgPeerConfig[];
|
||||
/** Pre-registered clients for Noise IK authentication */
|
||||
clients?: IClientEntry[];
|
||||
/** Enable PROXY protocol v2 on incoming WebSocket connections.
|
||||
* Required when behind a reverse proxy that sends PP v2 headers (HAProxy, SmartProxy).
|
||||
* SECURITY: Must be false when accepting direct client connections. */
|
||||
proxyProtocol?: boolean;
|
||||
/** Server-level IP block list — applied at TCP accept, before Noise handshake.
|
||||
* Supports exact IPs, CIDR, wildcards, ranges. */
|
||||
connectionIpBlockList?: string[];
|
||||
}
|
||||
|
||||
export interface IVpnServerOptions {
|
||||
@@ -99,6 +150,7 @@ export interface IVpnStatistics {
|
||||
keepalivesSent: number;
|
||||
keepalivesReceived: number;
|
||||
uptimeSeconds: number;
|
||||
quality?: IVpnConnectionQuality;
|
||||
}
|
||||
|
||||
export interface IVpnClientInfo {
|
||||
@@ -107,6 +159,18 @@ export interface IVpnClientInfo {
|
||||
connectedSince: string;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
/** Client's authenticated Noise IK public key (base64) */
|
||||
authenticatedKey: string;
|
||||
/** Registered client ID from the client registry */
|
||||
registeredClientId: string;
|
||||
/** Real client IP:port (from PROXY protocol or direct TCP connection) */
|
||||
remoteAddr?: string;
|
||||
}
|
||||
|
||||
export interface IVpnServerStatistics extends IVpnStatistics {
|
||||
@@ -119,6 +183,160 @@ export interface IVpnKeypair {
|
||||
privateKey: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: Connection quality
|
||||
// ============================================================================
|
||||
|
||||
export type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
|
||||
|
||||
export interface IVpnConnectionQuality {
|
||||
srttMs: number;
|
||||
jitterMs: number;
|
||||
minRttMs: number;
|
||||
maxRttMs: number;
|
||||
lossRatio: number;
|
||||
consecutiveTimeouts: number;
|
||||
linkHealth: TVpnLinkHealth;
|
||||
currentKeepaliveIntervalSecs: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: MTU info
|
||||
// ============================================================================
|
||||
|
||||
export interface IVpnMtuInfo {
|
||||
tunMtu: number;
|
||||
effectiveMtu: number;
|
||||
linkMtu: number;
|
||||
overheadBytes: number;
|
||||
oversizedPacketsDropped: number;
|
||||
icmpTooBigSent: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QoS: Client telemetry (server-side per-client)
|
||||
// ============================================================================
|
||||
|
||||
export interface IVpnClientTelemetry {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
bytesReceived: number;
|
||||
bytesSent: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client Registry (Hub) types — aligned with SmartProxy IRouteSecurity pattern
|
||||
// ============================================================================
|
||||
|
||||
/** Per-client rate limiting. */
|
||||
export interface IClientRateLimit {
|
||||
/** Max throughput in bytes/sec */
|
||||
bytesPerSec: number;
|
||||
/** Burst allowance in bytes */
|
||||
burstBytes: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Per-client security settings.
|
||||
* Mirrors SmartProxy's IRouteSecurity: ipAllowList/ipBlockList naming + deny-overrides-allow.
|
||||
* Adds VPN-specific destination filtering.
|
||||
*/
|
||||
export interface IClientSecurity {
|
||||
/** Source IPs/CIDRs the client may connect FROM (empty = any).
|
||||
* Supports: exact IP, CIDR, wildcard (192.168.1.*), ranges (1.1.1.1-1.1.1.5). */
|
||||
ipAllowList?: string[];
|
||||
/** Source IPs blocked — overrides ipAllowList (deny wins). */
|
||||
ipBlockList?: string[];
|
||||
/** Destination IPs/CIDRs the client may reach through the VPN (empty = all). */
|
||||
destinationAllowList?: string[];
|
||||
/** Destination IPs blocked — overrides destinationAllowList (deny wins). */
|
||||
destinationBlockList?: string[];
|
||||
/** Max concurrent connections from this client. */
|
||||
maxConnections?: number;
|
||||
/** Per-client rate limiting. */
|
||||
rateLimit?: IClientRateLimit;
|
||||
}
|
||||
|
||||
/**
|
||||
* Server-side client definition — the central config object for the Hub.
|
||||
* Naming and structure aligned with SmartProxy's IRouteConfig / IRouteSecurity.
|
||||
*/
|
||||
export interface IClientEntry {
|
||||
/** Human-readable client ID (e.g. "alice-laptop") */
|
||||
clientId: string;
|
||||
/** Client's Noise IK public key (base64) — for SmartVPN native transport */
|
||||
publicKey: string;
|
||||
/** Client's WireGuard public key (base64) — for WireGuard transport */
|
||||
wgPublicKey?: string;
|
||||
/** Security settings (ACLs, rate limits) */
|
||||
security?: IClientSecurity;
|
||||
/** Traffic priority (lower = higher priority, default: 100) */
|
||||
priority?: number;
|
||||
/** Whether this client is enabled (default: true) */
|
||||
enabled?: boolean;
|
||||
/** Tags for grouping (e.g. ["engineering", "office"]) */
|
||||
tags?: string[];
|
||||
/** Optional description */
|
||||
description?: string;
|
||||
/** Optional expiry (ISO 8601 timestamp, omit = never expires) */
|
||||
expiresAt?: string;
|
||||
/** Assigned VPN IP address (set by server) */
|
||||
assignedIp?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Complete client config bundle — returned by createClient() and rotateClientKey().
|
||||
* Contains everything the client needs to connect.
|
||||
*/
|
||||
export interface IClientConfigBundle {
|
||||
/** The server-side client entry */
|
||||
entry: IClientEntry;
|
||||
/** Ready-to-use SmartVPN client config (typed object) */
|
||||
smartvpnConfig: IVpnClientConfig;
|
||||
/** Ready-to-use WireGuard .conf file content (string) */
|
||||
wireguardConfig: string;
|
||||
/** Client's private keys (ONLY returned at creation time, not stored server-side) */
|
||||
secrets: {
|
||||
noisePrivateKey: string;
|
||||
wgPrivateKey: string;
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard-specific types
|
||||
// ============================================================================
|
||||
|
||||
export interface IWgPeerConfig {
|
||||
/** Peer's public key (base64, X25519) */
|
||||
publicKey: string;
|
||||
/** Optional preshared key (base64) */
|
||||
presharedKey?: string;
|
||||
/** Allowed IP ranges (CIDR strings) */
|
||||
allowedIps: string[];
|
||||
/** Peer endpoint (host:port) — optional for server peers, required for client */
|
||||
endpoint?: string;
|
||||
/** Persistent keepalive interval in seconds */
|
||||
persistentKeepalive?: number;
|
||||
}
|
||||
|
||||
export interface IWgPeerInfo {
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint?: string;
|
||||
persistentKeepalive?: number;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsSent: number;
|
||||
packetsReceived: number;
|
||||
lastHandshakeTime?: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// IPC Command maps (used by smartrust RustBridge<TCommands>)
|
||||
// ============================================================================
|
||||
@@ -128,6 +346,8 @@ export type TVpnClientCommands = {
|
||||
disconnect: { params: Record<string, never>; result: void };
|
||||
getStatus: { params: Record<string, never>; result: IVpnStatus };
|
||||
getStatistics: { params: Record<string, never>; result: IVpnStatistics };
|
||||
getConnectionQuality: { params: Record<string, never>; result: IVpnConnectionQuality };
|
||||
getMtuInfo: { params: Record<string, never>; result: IVpnMtuInfo };
|
||||
};
|
||||
|
||||
export type TVpnServerCommands = {
|
||||
@@ -138,6 +358,24 @@ export type TVpnServerCommands = {
|
||||
listClients: { params: Record<string, never>; result: { clients: IVpnClientInfo[] } };
|
||||
disconnectClient: { params: { clientId: string }; result: void };
|
||||
generateKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void };
|
||||
removeClientRateLimit: { params: { clientId: string }; result: void };
|
||||
getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry };
|
||||
generateWgKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
addWgPeer: { params: { peer: IWgPeerConfig }; result: void };
|
||||
removeWgPeer: { params: { publicKey: string }; result: void };
|
||||
listWgPeers: { params: Record<string, never>; result: { peers: IWgPeerInfo[] } };
|
||||
// Client Registry (Hub) commands
|
||||
createClient: { params: { client: Partial<IClientEntry> }; result: IClientConfigBundle };
|
||||
removeClient: { params: { clientId: string }; result: void };
|
||||
getClient: { params: { clientId: string }; result: IClientEntry };
|
||||
listRegisteredClients: { params: Record<string, never>; result: { clients: IClientEntry[] } };
|
||||
updateClient: { params: { clientId: string; update: Partial<IClientEntry> }; result: void };
|
||||
enableClient: { params: { clientId: string }; result: void };
|
||||
disableClient: { params: { clientId: string }; result: void };
|
||||
rotateClientKey: { params: { clientId: string }; result: IClientConfigBundle };
|
||||
exportClientConfig: { params: { clientId: string; format: 'smartvpn' | 'wireguard' }; result: { config: string } };
|
||||
generateClientKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
"module": "NodeNext",
|
||||
"moduleResolution": "NodeNext",
|
||||
"esModuleInterop": true,
|
||||
"verbatimModuleSyntax": true
|
||||
"verbatimModuleSyntax": true,
|
||||
"types": ["node"]
|
||||
},
|
||||
"exclude": [
|
||||
"dist_ts/**/*.d.ts"
|
||||
|
||||
Reference in New Issue
Block a user