Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e31086d0c2 | |||
| 01a0d8b9f4 | |||
| 187a69028b | |||
| 64dedd389e | |||
| 13d8cbe3fa | |||
| f46ea70286 | |||
| 26ee3634c8 | |||
| 049fa00563 | |||
| e4e59d72f9 | |||
| 51d33127bf | |||
| a4ba6806e5 | |||
| 6330921160 | |||
| e81dd377d8 | |||
| e14c357ba0 | |||
| eb30825f72 | |||
| 835f0f791d | |||
| aec545fe8c |
60
changelog.md
60
changelog.md
@@ -1,5 +1,65 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-03-29 - 1.8.0 - feat(auth,client-registry)
|
||||
add Noise IK client authentication with managed client registry and per-client ACL controls
|
||||
|
||||
- switch the native tunnel handshake from Noise NK to Noise IK and require client keypairs in client configuration
|
||||
- add server-side client registry management APIs for creating, updating, disabling, rotating, listing, and exporting client configs
|
||||
- enforce client authorization from the registry during handshake and expose authenticated client metadata in server client info
|
||||
- introduce per-client security policies with source/destination ACLs and per-client rate limit settings
|
||||
- add Rust ACL matching support for exact IPs, CIDR ranges, wildcards, and IP ranges with test coverage
|
||||
|
||||
## 2026-03-29 - 1.7.0 - feat(rust-tests)
|
||||
add end-to-end WireGuard UDP integration tests and align TypeScript build configuration
|
||||
|
||||
- Add userspace Rust end-to-end tests that validate WireGuard handshake, encryption, peer isolation, and preshared-key data exchange over real UDP sockets.
|
||||
- Update the TypeScript build setup by removing the allowimplicitany build flag and explicitly including Node types in tsconfig.
|
||||
- Refresh development toolchain versions to support the updated test and build workflow.
|
||||
|
||||
## 2026-03-29 - 1.6.0 - feat(readme)
|
||||
document WireGuard transport support, configuration, and usage examples
|
||||
|
||||
- Expand the README from dual-transport to triple-transport support by adding WireGuard alongside WebSocket and QUIC
|
||||
- Add client and server WireGuard examples, including live peer management and .conf generation with WgConfigGenerator
|
||||
- Document new WireGuard-related API methods, config fields, transport modes, and security model details
|
||||
|
||||
## 2026-03-29 - 1.5.0 - feat(wireguard)
|
||||
add WireGuard transport support with management APIs and config generation
|
||||
|
||||
- add Rust WireGuard module integration using boringtun and route management through client/server management handlers
|
||||
- extend TypeScript client and server configuration schemas with WireGuard-specific options and validation
|
||||
- add server-side WireGuard peer management commands including keypair generation, peer add/remove, and peer listing
|
||||
- export a WireGuard config generator for producing client and server .conf files
|
||||
- add WireGuard-focused test coverage for config validation and config generation
|
||||
|
||||
## 2026-03-21 - 1.4.1 - fix(readme)
|
||||
preserve markdown line breaks in feature list
|
||||
|
||||
- Adds trailing spaces to the README feature list so each highlighted capability renders on its own line.
|
||||
|
||||
## 2026-03-19 - 1.4.0 - feat(vpn transport)
|
||||
add QUIC transport support with auto fallback to WebSocket
|
||||
|
||||
- introduces a transport abstraction in the Rust daemon so client and server can operate over WebSocket or QUIC
|
||||
- adds dual-mode server configuration with websocket, quic, and both transport modes plus QUIC idle timeout and listen address options
|
||||
- adds client transport selection with auto mode that attempts QUIC first and falls back to WebSocket
|
||||
- adds QUIC certificate hash pinning support and required Rust dependencies for QUIC and TLS
|
||||
- updates TypeScript interfaces, config validation, tests, and documentation to cover the new transport modes
|
||||
|
||||
## 2026-03-17 - 1.3.0 - feat(tests,client)
|
||||
add flow control and load test coverage and honor configured keepalive intervals
|
||||
|
||||
- Adds end-to-end node tests for client/server flow control, keepalive exchange, connection quality telemetry, rate limiting, concurrent clients, and disconnect tracking.
|
||||
- Adds load testing with throttled proxy scenarios to validate behavior under constrained bandwidth and repeated client churn.
|
||||
- Updates the Rust client to pass configured keepaliveIntervalSecs into the adaptive keepalive monitor instead of always using defaults.
|
||||
|
||||
## 2026-03-15 - 1.2.0 - feat(readme)
|
||||
document QoS, telemetry, MTU, and rate limiting capabilities in the README
|
||||
|
||||
- Expand the architecture and feature overview to cover adaptive keepalive, telemetry, QoS, rate limiting, and MTU handling
|
||||
- Update client and server examples to show new APIs such as getConnectionQuality(), getMtuInfo(), setClientRateLimit(), and getClientTelemetry()
|
||||
- Add TypeScript interface documentation for connection quality, MTU info, enriched client statistics, and per-client telemetry
|
||||
|
||||
## 2026-03-15 - 1.1.0 - feat(rust-core)
|
||||
add adaptive keepalive telemetry, MTU handling, and per-client rate limiting APIs
|
||||
|
||||
|
||||
19
package.json
19
package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartvpn",
|
||||
"version": "1.1.0",
|
||||
"version": "1.8.0",
|
||||
"private": false,
|
||||
"description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
|
||||
"type": "module",
|
||||
@@ -10,7 +10,8 @@
|
||||
"main": "dist_ts/index.js",
|
||||
"typings": "dist_ts/index.d.ts",
|
||||
"scripts": {
|
||||
"build": "(tsbuild tsfolders --allowimplicitany) && (tsrust)",
|
||||
"build": "(tsbuild tsfolders) && (tsrust)",
|
||||
"test:before": "(tsrust)",
|
||||
"test": "tstest test/ --verbose",
|
||||
"buildDocs": "tsdoc"
|
||||
},
|
||||
@@ -28,15 +29,15 @@
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@push.rocks/smartrust": "^1.3.0",
|
||||
"@push.rocks/smartpath": "^5.0.18"
|
||||
"@push.rocks/smartpath": "^6.0.0",
|
||||
"@push.rocks/smartrust": "^1.3.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@git.zone/tsbuild": "^2.2.12",
|
||||
"@git.zone/tsrun": "^1.3.3",
|
||||
"@git.zone/tstest": "^1.0.96",
|
||||
"@git.zone/tsrust": "^1.0.29",
|
||||
"@types/node": "^22.0.0"
|
||||
"@git.zone/tsbuild": "^4.4.0",
|
||||
"@git.zone/tsrun": "^2.0.2",
|
||||
"@git.zone/tsrust": "^1.3.2",
|
||||
"@git.zone/tstest": "^3.6.3",
|
||||
"@types/node": "^25.5.0"
|
||||
},
|
||||
"files": [
|
||||
"ts/**/*",
|
||||
|
||||
3545
pnpm-lock.yaml
generated
3545
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
588
readme.md
588
readme.md
@@ -1,396 +1,366 @@
|
||||
# @push.rocks/smartvpn
|
||||
|
||||
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, typed APIs while all networking heavy lifting — encryption, tunneling, packet forwarding — runs at native speed in Rust.
|
||||
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Enterprise-ready client authentication, triple transport support (WebSocket + QUIC + WireGuard), and a typed hub API for managing clients from code.
|
||||
|
||||
🔐 **Noise IK** mutual authentication — per-client X25519 keypairs, server-side registry
|
||||
🚀 **Triple transport**: WebSocket (Cloudflare-friendly), raw **QUIC** (datagrams), and **WireGuard** (standard protocol)
|
||||
🛡️ **ACL engine** — deny-overrides-allow IP filtering, aligned with SmartProxy conventions
|
||||
📊 **Adaptive QoS**: per-client rate limiting, priority queues, connection quality tracking
|
||||
🔄 **Hub API**: one `createClient()` call generates keys, assigns IP, returns both SmartVPN + WireGuard configs
|
||||
📡 **Real-time telemetry**: RTT, jitter, loss ratio, link health — all via typed APIs
|
||||
|
||||
## Issue Reporting and Security
|
||||
|
||||
For reporting bugs, issues, or security vulnerabilities, please visit [community.foss.global/](https://community.foss.global/). This is the central community hub for all issue reporting. Developers who sign and comply with our contribution agreement and go through identification can also get a [code.foss.global/](https://code.foss.global/) account to submit Pull Requests directly.
|
||||
|
||||
## Install
|
||||
## Install 📦
|
||||
|
||||
```bash
|
||||
npm install @push.rocks/smartvpn
|
||||
# or
|
||||
pnpm install @push.rocks/smartvpn
|
||||
# or
|
||||
npm install @push.rocks/smartvpn
|
||||
```
|
||||
|
||||
## 🏗️ Architecture
|
||||
The package ships with pre-compiled Rust binaries for **linux/amd64** and **linux/arm64**. No Rust toolchain required at runtime.
|
||||
|
||||
## Architecture 🏗️
|
||||
|
||||
```
|
||||
TypeScript (control plane) Rust (data plane)
|
||||
┌──────────────────────────┐ ┌───────────────────────────────┐
|
||||
│ VpnClient / VpnServer │ │ smartvpn_daemon │
|
||||
│ └─ VpnBridge │──stdio/──▶ │ ├─ management (JSON IPC) │
|
||||
│ └─ RustBridge │ socket │ ├─ transport (WebSocket/TLS) │
|
||||
│ (smartrust) │ │ ├─ crypto (Noise NK + XCha) │
|
||||
└──────────────────────────┘ │ ├─ codec (binary framing) │
|
||||
│ ├─ keepalive (app-level) │
|
||||
│ ├─ tunnel (TUN device) │
|
||||
│ ├─ network (NAT/IP pool) │
|
||||
│ └─ reconnect (backoff) │
|
||||
└───────────────────────────────┘
|
||||
┌──────────────────────────────┐ JSON-lines IPC ┌───────────────────────────────┐
|
||||
│ TypeScript Control Plane │ ◄─────────────────────► │ Rust Data Plane Daemon │
|
||||
│ │ stdio or Unix sock │ │
|
||||
│ VpnServer / VpnClient │ │ Noise IK handshake │
|
||||
│ Typed IPC commands │ │ XChaCha20-Poly1305 │
|
||||
│ Config validation │ │ WS + QUIC + WireGuard │
|
||||
│ Hub: client management │ │ TUN device, IP pool, NAT │
|
||||
│ WireGuard .conf generation │ │ Rate limiting, ACLs, QoS │
|
||||
└──────────────────────────────┘ └───────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key design decisions:**
|
||||
**Split-plane design** — TypeScript handles orchestration, config, and DX; Rust handles every hot-path byte with zero-copy async I/O (tokio, mimalloc).
|
||||
|
||||
| Decision | Choice | Why |
|
||||
|----------|--------|-----|
|
||||
| Transport | WebSocket over HTTPS | Works through Cloudflare and other terminating proxies |
|
||||
| Encryption | Noise NK + XChaCha20-Poly1305 | Strong forward secrecy, large nonce space (no counter needed) |
|
||||
| Keepalive | App-level (not WS pings) | Cloudflare drops WS ping frames; app-level pings survive |
|
||||
| IPC | JSON lines over stdio/Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
|
||||
| Binary protocol | `[type:1B][length:4B][payload:NB]` | Minimal overhead, easy to parse at wire speed |
|
||||
## Quick Start 🚀
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### VPN Client
|
||||
|
||||
```typescript
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
// Development: spawn the Rust daemon as a child process
|
||||
const client = new VpnClient({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Start the daemon bridge
|
||||
await client.start();
|
||||
|
||||
// Connect to a VPN server
|
||||
const { assignedIp } = await client.connect({
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
});
|
||||
|
||||
console.log(`Connected! Assigned IP: ${assignedIp}`);
|
||||
|
||||
// Check status
|
||||
const status = await client.getStatus();
|
||||
console.log(status); // { state: 'connected', assignedIp: '10.8.0.2', ... }
|
||||
|
||||
// Get traffic stats
|
||||
const stats = await client.getStatistics();
|
||||
console.log(stats); // { bytesSent, bytesReceived, packetsSent, ... }
|
||||
|
||||
// Disconnect
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
```
|
||||
|
||||
### VPN Server
|
||||
### 1. Start a VPN Server (Hub)
|
||||
|
||||
```typescript
|
||||
import { VpnServer } from '@push.rocks/smartvpn';
|
||||
|
||||
const server = new VpnServer({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Start the daemon and the VPN server
|
||||
const server = new VpnServer({ transport: { transport: 'stdio' } });
|
||||
await server.start({
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'BASE64_PRIVATE_KEY',
|
||||
publicKey: 'BASE64_PUBLIC_KEY',
|
||||
privateKey: '<server-noise-private-key-base64>',
|
||||
publicKey: '<server-noise-public-key-base64>',
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
transportMode: 'both', // WebSocket + QUIC simultaneously
|
||||
enableNat: true,
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
});
|
||||
|
||||
// Generate a Noise keypair
|
||||
const keypair = await server.generateKeypair();
|
||||
console.log(keypair); // { publicKey: '...', privateKey: '...' }
|
||||
|
||||
// List connected clients
|
||||
const clients = await server.listClients();
|
||||
// [{ clientId, assignedIp, connectedSince, bytesSent, bytesReceived }]
|
||||
|
||||
// Disconnect a specific client
|
||||
await server.disconnectClient('some-client-id');
|
||||
|
||||
// Get server stats
|
||||
const stats = await server.getStatistics();
|
||||
// { bytesSent, bytesReceived, activeClients, totalConnections, ... }
|
||||
|
||||
// Stop
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
```
|
||||
|
||||
### Production: Socket Transport
|
||||
|
||||
In production, the daemon runs as a system service and you connect over a Unix socket:
|
||||
### 2. Create a Client (One Call = Everything)
|
||||
|
||||
```typescript
|
||||
const client = new VpnClient({
|
||||
transport: {
|
||||
transport: 'socket',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
autoReconnect: true,
|
||||
reconnectBaseDelayMs: 100,
|
||||
reconnectMaxDelayMs: 30000,
|
||||
maxReconnectAttempts: 10,
|
||||
const bundle = await server.createClient({
|
||||
clientId: 'alice-laptop',
|
||||
tags: ['engineering'],
|
||||
security: {
|
||||
destinationAllowList: ['10.0.0.0/8'], // can only reach internal network
|
||||
destinationBlockList: ['10.0.0.99'], // except this host
|
||||
rateLimit: { bytesPerSec: 10_000_000, burstBytes: 20_000_000 },
|
||||
},
|
||||
});
|
||||
|
||||
await client.start(); // connects to existing daemon (does not spawn)
|
||||
// bundle.smartvpnConfig → typed IVpnClientConfig, ready to use
|
||||
// bundle.wireguardConfig → standard WireGuard .conf string
|
||||
// bundle.secrets → { noisePrivateKey, wgPrivateKey } — shown ONCE
|
||||
```
|
||||
|
||||
When using socket transport, `client.stop()` closes the socket but **does not kill the daemon** — exactly what you want in production.
|
||||
|
||||
## 📋 API Reference
|
||||
|
||||
### `VpnClient`
|
||||
|
||||
| Method | Returns | Description |
|
||||
|--------|---------|-------------|
|
||||
| `start()` | `Promise<boolean>` | Start the daemon bridge (spawn or connect) |
|
||||
| `connect(config?)` | `Promise<{ assignedIp }>` | Connect to VPN server |
|
||||
| `disconnect()` | `Promise<void>` | Disconnect from VPN |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Current connection state |
|
||||
| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic statistics |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
| `running` | `boolean` | Whether bridge is active |
|
||||
|
||||
### `VpnServer`
|
||||
|
||||
| Method | Returns | Description |
|
||||
|--------|---------|-------------|
|
||||
| `start(config?)` | `Promise<void>` | Start daemon + VPN server |
|
||||
| `stopServer()` | `Promise<void>` | Stop the VPN server |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Server connection state |
|
||||
| `getStatistics()` | `Promise<IVpnServerStatistics>` | Server stats (includes client counts) |
|
||||
| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients |
|
||||
| `disconnectClient(id)` | `Promise<void>` | Kick a client |
|
||||
| `generateKeypair()` | `Promise<IVpnKeypair>` | Generate Noise NK keypair |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
|
||||
### `VpnConfig`
|
||||
|
||||
Static utility class for config validation and file I/O:
|
||||
### 3. Connect a Client
|
||||
|
||||
```typescript
|
||||
import { VpnConfig } from '@push.rocks/smartvpn';
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
// Validate (throws on invalid)
|
||||
VpnConfig.validateClientConfig(config);
|
||||
VpnConfig.validateServerConfig(config);
|
||||
const client = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await client.start();
|
||||
|
||||
// Load/save JSON configs
|
||||
const config = await VpnConfig.loadFromFile<IVpnClientConfig>('/etc/smartvpn/client.json');
|
||||
await VpnConfig.saveToFile('/etc/smartvpn/client.json', config);
|
||||
const { assignedIp } = await client.connect(bundle.smartvpnConfig);
|
||||
console.log(`Connected! VPN IP: ${assignedIp}`);
|
||||
```
|
||||
|
||||
### `VpnInstaller`
|
||||
## Features ✨
|
||||
|
||||
Generate system service units for the daemon:
|
||||
### 🔐 Enterprise Authentication (Noise IK)
|
||||
|
||||
Every client authenticates with a **Noise IK handshake** (`Noise_IK_25519_ChaChaPoly_BLAKE2s`). The server verifies the client's static public key against its registry — unauthorized clients are rejected before any data flows.
|
||||
|
||||
- Per-client X25519 keypair generated server-side
|
||||
- Client registry with enable/disable, expiry, tags
|
||||
- Key rotation with `rotateClientKey()` — generates new keys, returns fresh config bundle, disconnects old session
|
||||
|
||||
### 🌐 Triple Transport
|
||||
|
||||
| Transport | Protocol | Best For |
|
||||
|-----------|----------|----------|
|
||||
| **WebSocket** | TLS over TCP | Firewall-friendly, Cloudflare compatible |
|
||||
| **QUIC** | UDP (via quinn) | Low latency, datagram support for IP packets |
|
||||
| **WireGuard** | UDP (via boringtun) | Standard WG clients (iOS, Android, wg-quick) |
|
||||
|
||||
The server can run **all three simultaneously** with `transportMode: 'both'` (WS + QUIC) or `'wireguard'`. Clients auto-negotiate with `transport: 'auto'` (tries QUIC first, falls back to WS).
|
||||
|
||||
### 🛡️ ACL Engine (SmartProxy-Aligned)
|
||||
|
||||
Security policies per client, using the same `ipAllowList` / `ipBlockList` naming convention as `@push.rocks/smartproxy`:
|
||||
|
||||
```typescript
|
||||
security: {
|
||||
ipAllowList: ['192.168.1.0/24'], // source IPs allowed to connect
|
||||
ipBlockList: ['192.168.1.100'], // deny overrides allow
|
||||
destinationAllowList: ['10.0.0.0/8'], // VPN destinations permitted
|
||||
destinationBlockList: ['10.0.0.99'], // deny overrides allow
|
||||
maxConnections: 5,
|
||||
rateLimit: { bytesPerSec: 1_000_000, burstBytes: 2_000_000 },
|
||||
}
|
||||
```
|
||||
|
||||
Supports exact IPs, CIDR, wildcards (`192.168.1.*`), and ranges (`1.1.1.1-1.1.1.100`).
|
||||
|
||||
### 📊 Telemetry & QoS
|
||||
|
||||
- **Connection quality**: Smoothed RTT, jitter, min/max RTT, loss ratio, link health (`healthy` / `degraded` / `critical`)
|
||||
- **Adaptive keepalives**: Interval adjusts based on link health (60s → 30s → 10s)
|
||||
- **Per-client rate limiting**: Token bucket with configurable bytes/sec and burst
|
||||
- **Dead-peer detection**: 180s inactivity timeout
|
||||
- **MTU management**: Automatic overhead calculation (IP+TCP+WS+Noise = 79 bytes)
|
||||
|
||||
### 🔄 Hub Client Management
|
||||
|
||||
The server acts as a **hub** — one API to manage all clients:
|
||||
|
||||
```typescript
|
||||
// Create (generates keys, assigns IP, returns config bundle)
|
||||
const bundle = await server.createClient({ clientId: 'bob-phone' });
|
||||
|
||||
// Read
|
||||
const entry = await server.getClient('bob-phone');
|
||||
const all = await server.listRegisteredClients();
|
||||
|
||||
// Update (ACLs, tags, description, rate limits...)
|
||||
await server.updateClient('bob-phone', {
|
||||
security: { destinationAllowList: ['0.0.0.0/0'] },
|
||||
tags: ['mobile', 'field-ops'],
|
||||
});
|
||||
|
||||
// Enable / Disable
|
||||
await server.disableClient('bob-phone'); // disconnects + blocks reconnection
|
||||
await server.enableClient('bob-phone');
|
||||
|
||||
// Key rotation
|
||||
const newBundle = await server.rotateClientKey('bob-phone');
|
||||
|
||||
// Export config (without secrets)
|
||||
const wgConf = await server.exportClientConfig('bob-phone', 'wireguard');
|
||||
|
||||
// Remove
|
||||
await server.removeClient('bob-phone');
|
||||
```
|
||||
|
||||
### 📝 WireGuard Config Generation
|
||||
|
||||
Generate standard `.conf` files for any WireGuard client:
|
||||
|
||||
```typescript
|
||||
import { WgConfigGenerator } from '@push.rocks/smartvpn';
|
||||
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: '<client-wg-private-key>',
|
||||
address: '10.8.0.2/24',
|
||||
dns: ['1.1.1.1'],
|
||||
peer: {
|
||||
publicKey: '<server-wg-public-key>',
|
||||
endpoint: 'vpn.example.com:51820',
|
||||
allowedIps: ['0.0.0.0/0'],
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
});
|
||||
// → standard WireGuard .conf compatible with wg-quick, iOS, Android
|
||||
```
|
||||
|
||||
### 🖥️ System Service Installation
|
||||
|
||||
```typescript
|
||||
import { VpnInstaller } from '@push.rocks/smartvpn';
|
||||
|
||||
// Auto-detect platform
|
||||
const platform = VpnInstaller.detectPlatform(); // 'linux' | 'macos' | 'windows' | 'unknown'
|
||||
|
||||
// Generate systemd unit (Linux)
|
||||
const unit = VpnInstaller.generateSystemdUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'server',
|
||||
});
|
||||
// unit.content = full systemd .service file
|
||||
// unit.installPath = '/etc/systemd/system/smartvpn-server.service'
|
||||
|
||||
// Generate launchd plist (macOS)
|
||||
const plist = VpnInstaller.generateLaunchdPlist({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'client',
|
||||
});
|
||||
|
||||
// Auto-detect and generate
|
||||
const serviceUnit = VpnInstaller.generateServiceUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
const unit = VpnInstaller.generateServiceUnit({
|
||||
mode: 'server',
|
||||
configPath: '/etc/smartvpn/server.json',
|
||||
});
|
||||
// unit.platform → 'linux' | 'macos'
|
||||
// unit.content → systemd unit file or launchd plist
|
||||
// unit.installPath → /etc/systemd/system/smartvpn-server.service
|
||||
```
|
||||
|
||||
### Events
|
||||
## API Reference 📖
|
||||
|
||||
Both `VpnClient` and `VpnServer` extend `EventEmitter`:
|
||||
### Classes
|
||||
|
||||
| Class | Description |
|
||||
|-------|-------------|
|
||||
| `VpnServer` | Manages the Rust daemon in server mode. Hub methods for client CRUD. |
|
||||
| `VpnClient` | Manages the Rust daemon in client mode. Connect, disconnect, telemetry. |
|
||||
| `VpnBridge<T>` | Low-level typed IPC bridge (stdio or Unix socket). |
|
||||
| `VpnConfig` | Static config validation and file I/O. |
|
||||
| `VpnInstaller` | Generates systemd/launchd service files. |
|
||||
| `WgConfigGenerator` | Generates standard WireGuard `.conf` files. |
|
||||
|
||||
### Key Interfaces
|
||||
|
||||
| Interface | Purpose |
|
||||
|-----------|---------|
|
||||
| `IVpnServerConfig` | Server configuration (listen addr, keys, subnet, transport mode, clients) |
|
||||
| `IVpnClientConfig` | Client configuration (server URL, keys, transport, WG options) |
|
||||
| `IClientEntry` | Server-side client definition (ID, keys, security, priority, tags, expiry) |
|
||||
| `IClientSecurity` | Per-client ACLs and rate limits (SmartProxy-aligned naming) |
|
||||
| `IClientRateLimit` | Rate limiting config (bytesPerSec, burstBytes) |
|
||||
| `IClientConfigBundle` | Full config bundle returned by `createClient()` |
|
||||
| `IVpnClientInfo` | Connected client info (IP, stats, authenticated key) |
|
||||
| `IVpnConnectionQuality` | RTT, jitter, loss ratio, link health |
|
||||
| `IVpnKeypair` | Base64-encoded public/private key pair |
|
||||
|
||||
### Server IPC Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `start` / `stop` | Start/stop the VPN listener |
|
||||
| `createClient` | Generate keys, assign IP, return config bundle |
|
||||
| `removeClient` / `getClient` / `listRegisteredClients` | Client registry CRUD |
|
||||
| `updateClient` / `enableClient` / `disableClient` | Modify client state |
|
||||
| `rotateClientKey` | Fresh keypairs + new config bundle |
|
||||
| `exportClientConfig` | Re-export as SmartVPN config or WireGuard `.conf` |
|
||||
| `listClients` / `disconnectClient` | Manage live connections |
|
||||
| `setClientRateLimit` / `removeClientRateLimit` | Runtime rate limit adjustments |
|
||||
| `getStatus` / `getStatistics` / `getClientTelemetry` | Monitoring |
|
||||
| `generateKeypair` / `generateWgKeypair` / `generateClientKeypair` | Key generation |
|
||||
| `addWgPeer` / `removeWgPeer` / `listWgPeers` | WireGuard peer management |
|
||||
|
||||
### Client IPC Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `connect` / `disconnect` | Manage the tunnel |
|
||||
| `getStatus` / `getStatistics` | Connection state and traffic stats |
|
||||
| `getConnectionQuality` | RTT, jitter, loss, link health |
|
||||
| `getMtuInfo` | MTU and overhead details |
|
||||
|
||||
## Transport Modes 🔀
|
||||
|
||||
### Server Configuration
|
||||
|
||||
```typescript
|
||||
client.on('status', (status) => { /* IVpnStatus */ });
|
||||
client.on('error', (err) => { /* { message, code? } */ });
|
||||
client.on('exit', ({ code, signal }) => { /* daemon exited */ });
|
||||
client.on('reconnected', () => { /* socket reconnected */ });
|
||||
// WebSocket only
|
||||
{ transportMode: 'websocket', listenAddr: '0.0.0.0:443' }
|
||||
|
||||
server.on('client-connected', (info) => { /* IVpnClientInfo */ });
|
||||
server.on('client-disconnected', ({ clientId, reason }) => { /* ... */ });
|
||||
// QUIC only
|
||||
{ transportMode: 'quic', listenAddr: '0.0.0.0:443' }
|
||||
|
||||
// Both (WS + QUIC on same or different ports)
|
||||
{ transportMode: 'both', listenAddr: '0.0.0.0:443', quicListenAddr: '0.0.0.0:4433' }
|
||||
|
||||
// WireGuard
|
||||
{ transportMode: 'wireguard', wgListenPort: 51820, wgPeers: [...] }
|
||||
```
|
||||
|
||||
## 🔐 Security Model
|
||||
### Client Configuration
|
||||
|
||||
The VPN uses a **Noise NK** handshake pattern:
|
||||
```typescript
|
||||
// Auto (tries QUIC first, falls back to WS)
|
||||
{ transport: 'auto', serverUrl: 'wss://vpn.example.com' }
|
||||
|
||||
1. **NK** = client does **N**ot authenticate, but **K**nows the server's static public key
|
||||
2. The client generates an ephemeral keypair, performs `e, es` (Diffie-Hellman with server's static key)
|
||||
3. Server responds with `e, ee` (Diffie-Hellman with both ephemeral keys)
|
||||
4. Result: forward-secret transport keys derived from both DH operations
|
||||
// Explicit QUIC with certificate pinning
|
||||
{ transport: 'quic', serverUrl: '1.2.3.4:4433', serverCertHash: '<sha256-base64>' }
|
||||
|
||||
Post-handshake, all IP packets are encrypted with **XChaCha20-Poly1305**:
|
||||
- 24-byte random nonces (no counter synchronization needed)
|
||||
- 16-byte authentication tags
|
||||
- Wire format: `[nonce:24B][ciphertext:var][tag:16B]`
|
||||
|
||||
## 📦 Binary Protocol
|
||||
|
||||
Inside the WebSocket tunnel, packets use a simple binary framing:
|
||||
|
||||
```
|
||||
┌──────────┬──────────┬────────────────────┐
|
||||
│ Type (1B)│ Len (4B) │ Payload (variable) │
|
||||
└──────────┴──────────┴────────────────────┘
|
||||
// WireGuard
|
||||
{ transport: 'wireguard', wgPrivateKey: '...', wgEndpoint: 'vpn.example.com:51820', ... }
|
||||
```
|
||||
|
||||
| Type | Value | Description |
|
||||
|------|-------|-------------|
|
||||
| `HandshakeInit` | `0x01` | Client → Server handshake |
|
||||
| `HandshakeResp` | `0x02` | Server → Client handshake |
|
||||
| `IpPacket` | `0x10` | Encrypted IP packet |
|
||||
| `Keepalive` | `0x20` | App-level ping |
|
||||
| `KeepaliveAck` | `0x21` | App-level pong |
|
||||
| `SessionResume` | `0x30` | Resume a dropped session |
|
||||
| `SessionResumeOk` | `0x31` | Resume accepted |
|
||||
| `SessionResumeErr` | `0x32` | Resume rejected |
|
||||
| `Disconnect` | `0x3F` | Graceful disconnect |
|
||||
## Cryptography 🔑
|
||||
|
||||
## 🛠️ Rust Daemon CLI
|
||||
| Layer | Algorithm | Purpose |
|
||||
|-------|-----------|---------|
|
||||
| **Handshake** | Noise IK (X25519 + ChaChaPoly + BLAKE2s) | Mutual authentication + key exchange |
|
||||
| **Transport** | Noise transport state (ChaChaPoly) | All post-handshake data encryption |
|
||||
| **Additional** | XChaCha20-Poly1305 | Extended nonce space for data-at-rest |
|
||||
| **WireGuard** | X25519 + ChaCha20-Poly1305 (via boringtun) | Standard WireGuard crypto |
|
||||
|
||||
The Rust binary supports several modes:
|
||||
## Binary Protocol 📡
|
||||
|
||||
```bash
|
||||
# Development: stdio management (JSON lines on stdin/stdout)
|
||||
smartvpn_daemon --management --mode client
|
||||
smartvpn_daemon --management --mode server
|
||||
All frames use `[type:1B][length:4B][payload:NB]` with a 64KB max payload:
|
||||
|
||||
# Production: Unix socket management
|
||||
smartvpn_daemon --management-socket /var/run/smartvpn.sock --mode server
|
||||
| Type | Hex | Direction | Description |
|
||||
|------|-----|-----------|-------------|
|
||||
| HandshakeInit | `0x01` | Client → Server | Noise IK first message |
|
||||
| HandshakeResp | `0x02` | Server → Client | Noise IK response |
|
||||
| IpPacket | `0x10` | Bidirectional | Encrypted tunnel data |
|
||||
| Keepalive | `0x20` | Client → Server | App-level keepalive (not WS ping) |
|
||||
| KeepaliveAck | `0x21` | Server → Client | Keepalive response with RTT payload |
|
||||
| Disconnect | `0x3F` | Bidirectional | Graceful disconnect |
|
||||
|
||||
# Generate a Noise keypair
|
||||
smartvpn_daemon --generate-keypair
|
||||
```
|
||||
|
||||
## 🔧 Building from Source
|
||||
## Development 🛠️
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pnpm install
|
||||
|
||||
# Build TypeScript + cross-compile Rust
|
||||
# Build (TypeScript + Rust cross-compile)
|
||||
pnpm build
|
||||
|
||||
# Build Rust only (debug)
|
||||
cd rust && cargo build
|
||||
# Run all tests (79 TS + 121 Rust = 200 tests)
|
||||
pnpm test
|
||||
|
||||
# Run Rust tests
|
||||
# Run Rust tests directly
|
||||
cd rust && cargo test
|
||||
|
||||
# Run TypeScript tests
|
||||
pnpm test
|
||||
# Run a specific TS test
|
||||
tstest test/test.flowcontrol.node.ts --verbose
|
||||
```
|
||||
|
||||
## TypeScript Interfaces
|
||||
### Project Structure
|
||||
|
||||
<details>
|
||||
<summary>Click to expand full type definitions</summary>
|
||||
|
||||
```typescript
|
||||
// Transport options
|
||||
type TVpnTransportOptions =
|
||||
| { transport: 'stdio' }
|
||||
| {
|
||||
transport: 'socket';
|
||||
socketPath: string;
|
||||
autoReconnect?: boolean;
|
||||
reconnectBaseDelayMs?: number;
|
||||
reconnectMaxDelayMs?: number;
|
||||
maxReconnectAttempts?: number;
|
||||
};
|
||||
|
||||
// Client config
|
||||
interface IVpnClientConfig {
|
||||
serverUrl: string; // e.g. 'wss://vpn.example.com/tunnel'
|
||||
serverPublicKey: string; // base64-encoded Noise static key
|
||||
dns?: string[];
|
||||
mtu?: number; // default: 1420
|
||||
keepaliveIntervalSecs?: number; // default: 30
|
||||
}
|
||||
|
||||
// Server config
|
||||
interface IVpnServerConfig {
|
||||
listenAddr: string; // e.g. '0.0.0.0:443'
|
||||
privateKey: string; // base64 Noise static private key
|
||||
publicKey: string; // base64 Noise static public key
|
||||
subnet: string; // e.g. '10.8.0.0/24'
|
||||
tlsCert?: string;
|
||||
tlsKey?: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
enableNat?: boolean;
|
||||
}
|
||||
|
||||
// Status
|
||||
type TVpnConnectionState = 'disconnected' | 'connecting' | 'handshaking'
|
||||
| 'connected' | 'reconnecting' | 'error';
|
||||
|
||||
interface IVpnStatus {
|
||||
state: TVpnConnectionState;
|
||||
assignedIp?: string;
|
||||
serverAddr?: string;
|
||||
connectedSince?: string;
|
||||
lastError?: string;
|
||||
}
|
||||
|
||||
// Statistics
|
||||
interface IVpnStatistics {
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsSent: number;
|
||||
packetsReceived: number;
|
||||
keepalivesSent: number;
|
||||
keepalivesReceived: number;
|
||||
uptimeSeconds: number;
|
||||
}
|
||||
|
||||
interface IVpnServerStatistics extends IVpnStatistics {
|
||||
activeClients: number;
|
||||
totalConnections: number;
|
||||
}
|
||||
|
||||
interface IVpnClientInfo {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
connectedSince: string;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
}
|
||||
|
||||
interface IVpnKeypair {
|
||||
publicKey: string;
|
||||
privateKey: string;
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
smartvpn/
|
||||
├── ts/ # TypeScript control plane
|
||||
│ ├── index.ts # All exports
|
||||
│ ├── smartvpn.interfaces.ts # Interfaces, types, IPC command maps
|
||||
│ ├── smartvpn.classes.vpnserver.ts
|
||||
│ ├── smartvpn.classes.vpnclient.ts
|
||||
│ ├── smartvpn.classes.vpnbridge.ts
|
||||
│ ├── smartvpn.classes.vpnconfig.ts
|
||||
│ ├── smartvpn.classes.vpninstaller.ts
|
||||
│ └── smartvpn.classes.wgconfig.ts
|
||||
├── rust/ # Rust data plane daemon
|
||||
│ └── src/
|
||||
│ ├── main.rs # CLI entry point
|
||||
│ ├── server.rs # VPN server + hub methods
|
||||
│ ├── client.rs # VPN client
|
||||
│ ├── crypto.rs # Noise IK + XChaCha20
|
||||
│ ├── client_registry.rs # Client database
|
||||
│ ├── acl.rs # ACL engine
|
||||
│ ├── management.rs # JSON-lines IPC
|
||||
│ ├── transport.rs # WebSocket transport
|
||||
│ ├── quic_transport.rs # QUIC transport
|
||||
│ ├── wireguard.rs # WireGuard (boringtun)
|
||||
│ ├── codec.rs # Binary frame protocol
|
||||
│ ├── keepalive.rs # Adaptive keepalives
|
||||
│ ├── ratelimit.rs # Token bucket
|
||||
│ └── ... # tunnel, network, telemetry, qos, mtu, reconnect
|
||||
├── test/ # 9 test files (79 tests)
|
||||
├── dist_ts/ # Compiled TypeScript
|
||||
└── dist_rust/ # Cross-compiled binaries (linux amd64 + arm64)
|
||||
```
|
||||
|
||||
## License and Legal Information
|
||||
|
||||
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./LICENSE) file.
|
||||
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [license](./license.md) file.
|
||||
|
||||
**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.
|
||||
|
||||
|
||||
459
readme.plan.md
Normal file
459
readme.plan.md
Normal file
@@ -0,0 +1,459 @@
|
||||
# Enterprise Auth & Client Management for SmartVPN
|
||||
|
||||
## Context
|
||||
|
||||
SmartVPN's Noise NK mode currently allows **any client that knows the server's public key** to connect — no per-client identity or access control. The goal is to make SmartVPN enterprise-ready with:
|
||||
|
||||
1. **Per-client cryptographic authentication** (Noise IK handshake)
|
||||
2. **Rich client definitions** with ACLs, rate limits, and priority
|
||||
3. **Hub-generated configs** — server generates typed SmartVPN client configs AND WireGuard .conf files from the same client definition
|
||||
4. **Top-notch DX** — one `createClient()` call gives you everything
|
||||
|
||||
**This is a breaking change.** No backward compatibility with the old NK anonymous mode.
|
||||
|
||||
---
|
||||
|
||||
## Design Overview
|
||||
|
||||
### The Hub Model
|
||||
|
||||
The server acts as a **hub** that manages client definitions. Each client definition is the **single source of truth** from which both SmartVPN native configs and WireGuard configs are generated.
|
||||
|
||||
```
|
||||
Hub (Server)
|
||||
└── Client Registry
|
||||
├── "alice-laptop" → SmartVPN config OR WireGuard .conf
|
||||
├── "bob-phone" → SmartVPN config OR WireGuard .conf
|
||||
└── "office-gw" → SmartVPN config OR WireGuard .conf
|
||||
```
|
||||
|
||||
### Authentication: NK → IK (Breaking Change)
|
||||
|
||||
**Old (removed):** `Noise_NK_25519_ChaChaPoly_BLAKE2s` — client is anonymous
|
||||
**New (always):** `Noise_IK_25519_ChaChaPoly_BLAKE2s` — client presents its static key during handshake
|
||||
|
||||
IK is a 2-message handshake (same count as NK), so **the frame protocol stays identical**. Changes:
|
||||
- `create_initiator()` now requires `(client_private_key, server_public_key)` — always
|
||||
- `create_responder()` remains `(server_private_key)` — but now uses IK pattern
|
||||
- After handshake, server extracts client's public key via `get_remote_static()` and verifies against registry
|
||||
- Old NK functions are replaced, not kept alongside
|
||||
|
||||
**Every client must have a keypair. Every server must have a client registry.**
|
||||
|
||||
---
|
||||
|
||||
## Core Interface: `IClientEntry`
|
||||
|
||||
This is the server-side client definition — the central config object.
|
||||
Naming and structure are aligned with SmartProxy's `IRouteConfig` / `IRouteSecurity` patterns.
|
||||
|
||||
```typescript
|
||||
export interface IClientEntry {
|
||||
/** Human-readable client ID (e.g. "alice-laptop") */
|
||||
clientId: string;
|
||||
|
||||
/** Client's Noise IK public key (base64) — for SmartVPN native transport */
|
||||
publicKey: string;
|
||||
/** Client's WireGuard public key (base64) — for WireGuard transport */
|
||||
wgPublicKey?: string;
|
||||
|
||||
// ── Security (aligned with SmartProxy IRouteSecurity pattern) ─────────
|
||||
|
||||
security?: IClientSecurity;
|
||||
|
||||
// ── QoS ────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Traffic priority (lower = higher priority, default: 100) */
|
||||
priority?: number;
|
||||
|
||||
// ── Metadata (aligned with SmartProxy IRouteConfig pattern) ────────────
|
||||
|
||||
/** Whether this client is enabled (default: true) */
|
||||
enabled?: boolean;
|
||||
/** Tags for grouping (e.g. ["engineering", "office"]) */
|
||||
tags?: string[];
|
||||
/** Optional description */
|
||||
description?: string;
|
||||
/** Optional expiry (ISO 8601 timestamp, omit = never expires) */
|
||||
expiresAt?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Security settings per client — mirrors SmartProxy's IRouteSecurity structure.
|
||||
* Uses the same ipAllowList/ipBlockList naming convention.
|
||||
* Adds VPN-specific destination filtering (destinationAllowList/destinationBlockList).
|
||||
*/
|
||||
export interface IClientSecurity {
|
||||
/** Source IPs/CIDRs the client may connect FROM (empty = any).
|
||||
* Supports: exact IP, CIDR, wildcard (192.168.1.*), ranges (1.1.1.1-1.1.1.5).
|
||||
* Same format as SmartProxy's ipAllowList. */
|
||||
ipAllowList?: string[];
|
||||
/** Source IPs blocked — overrides ipAllowList (deny wins).
|
||||
* Same format as SmartProxy's ipBlockList. */
|
||||
ipBlockList?: string[];
|
||||
/** Destination IPs/CIDRs the client may reach through the VPN (empty = all) */
|
||||
destinationAllowList?: string[];
|
||||
/** Destination IPs blocked — overrides destinationAllowList (deny wins) */
|
||||
destinationBlockList?: string[];
|
||||
/** Max concurrent connections from this client */
|
||||
maxConnections?: number;
|
||||
/** Per-client rate limiting */
|
||||
rateLimit?: IClientRateLimit;
|
||||
}
|
||||
|
||||
export interface IClientRateLimit {
|
||||
/** Max throughput in bytes/sec */
|
||||
bytesPerSec: number;
|
||||
/** Burst allowance in bytes */
|
||||
burstBytes: number;
|
||||
}
|
||||
```
|
||||
|
||||
### SmartProxy Alignment Notes
|
||||
|
||||
| Pattern | SmartProxy | SmartVPN |
|
||||
|---------|-----------|---------|
|
||||
| ACL naming | `ipAllowList` / `ipBlockList` | Same — `ipAllowList` / `ipBlockList` |
|
||||
| Security grouping | `security: IRouteSecurity` sub-object | Same — `security: IClientSecurity` sub-object |
|
||||
| Rate limit structure | `rateLimit: IRouteRateLimit` object | Same pattern — `rateLimit: IClientRateLimit` object |
|
||||
| IP format support | Exact, CIDR, wildcard, ranges | Same formats |
|
||||
| Metadata fields | `priority`, `tags`, `enabled`, `description` | Same fields |
|
||||
| ACL evaluation | Block-first, then allow-list | Same — deny overrides allow |
|
||||
|
||||
### ACL Evaluation Order
|
||||
|
||||
```
|
||||
1. Check ipBlockList / destinationBlockList first (explicit deny wins)
|
||||
2. If denied, DROP
|
||||
3. Check ipAllowList / destinationAllowList (explicit allow)
|
||||
4. If ipAllowList is empty → allow any source
|
||||
5. If destinationAllowList is empty → allow all destinations
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Hub Config Generation
|
||||
|
||||
### `createClient()` — The One-Call DX
|
||||
|
||||
When the hub creates a client, it:
|
||||
1. Generates a Noise IK keypair for the client
|
||||
2. Generates a WireGuard keypair for the client
|
||||
3. Allocates a VPN IP address
|
||||
4. Stores the `IClientEntry` in the registry
|
||||
5. Returns a **complete config bundle** with everything the client needs
|
||||
|
||||
```typescript
|
||||
export interface IClientConfigBundle {
|
||||
/** The server-side client entry */
|
||||
entry: IClientEntry;
|
||||
/** Ready-to-use SmartVPN client config (typed object) */
|
||||
smartvpnConfig: IVpnClientConfig;
|
||||
/** Ready-to-use WireGuard .conf file content (string) */
|
||||
wireguardConfig: string;
|
||||
/** Client's private keys (ONLY returned at creation time, not stored server-side) */
|
||||
secrets: {
|
||||
noisePrivateKey: string;
|
||||
wgPrivateKey: string;
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
The `secrets` are returned **only at creation time** — the server stores only public keys.
|
||||
|
||||
### `exportClientConfig()` — Re-export (without secrets)
|
||||
|
||||
```typescript
|
||||
exportClientConfig(clientId: string, format: 'smartvpn' | 'wireguard'): IVpnClientConfig | string
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Updated `IVpnServerConfig`
|
||||
|
||||
```typescript
|
||||
export interface IVpnServerConfig {
|
||||
listenAddr: string;
|
||||
tlsCert?: string;
|
||||
tlsKey?: string;
|
||||
privateKey: string; // Server's Noise static private key (base64)
|
||||
publicKey: string; // Server's Noise static public key (base64)
|
||||
subnet: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
enableNat?: boolean;
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
defaultBurstBytes?: number;
|
||||
transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard';
|
||||
quicListenAddr?: string;
|
||||
quicIdleTimeoutSecs?: number;
|
||||
wgListenPort?: number;
|
||||
wgPeers?: IWgPeerConfig[]; // Keep for raw WG mode
|
||||
|
||||
/** Pre-registered clients — REQUIRED for SmartVPN native transport */
|
||||
clients: IClientEntry[];
|
||||
}
|
||||
```
|
||||
|
||||
Note: `clients` is now **required** (not optional), and there is no `authMode` field — IK is always used.
|
||||
|
||||
---
|
||||
|
||||
## Updated `IVpnClientConfig`
|
||||
|
||||
```typescript
|
||||
export interface IVpnClientConfig {
|
||||
serverUrl: string;
|
||||
serverPublicKey: string;
|
||||
/** Client's Noise IK private key (base64) — REQUIRED for SmartVPN native transport */
|
||||
clientPrivateKey: string;
|
||||
/** Client's Noise IK public key (base64) — for reference/display */
|
||||
clientPublicKey: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
transport?: 'auto' | 'websocket' | 'quic' | 'wireguard';
|
||||
serverCertHash?: string;
|
||||
// WireGuard fields unchanged...
|
||||
wgPrivateKey?: string;
|
||||
wgAddress?: string;
|
||||
wgAddressPrefix?: number;
|
||||
wgPresharedKey?: string;
|
||||
wgPersistentKeepalive?: number;
|
||||
wgEndpoint?: string;
|
||||
wgAllowedIps?: string[];
|
||||
}
|
||||
```
|
||||
|
||||
Note: `clientPrivateKey` and `clientPublicKey` are now **required** (not optional) for non-WireGuard transports.
|
||||
|
||||
---
|
||||
|
||||
## New IPC Commands
|
||||
|
||||
Added to `TVpnServerCommands`:
|
||||
|
||||
| Command | Params | Result | Description |
|
||||
|---------|--------|--------|-------------|
|
||||
| `createClient` | `{ client: Partial<IClientEntry> }` | `IClientConfigBundle` | Create client, generate keypairs, assign IP, return full config bundle |
|
||||
| `removeClient` | `{ clientId: string }` | `void` | Remove from registry + disconnect if connected |
|
||||
| `getClient` | `{ clientId: string }` | `IClientEntry` | Get a single client entry |
|
||||
| `listRegisteredClients` | `{}` | `{ clients: IClientEntry[] }` | List all registered clients |
|
||||
| `updateClient` | `{ clientId: string, update: Partial<IClientEntry> }` | `void` | Update ACLs, rate limits, tags, etc. |
|
||||
| `enableClient` | `{ clientId: string }` | `void` | Enable a disabled client |
|
||||
| `disableClient` | `{ clientId: string }` | `void` | Disable (but don't delete) |
|
||||
| `rotateClientKey` | `{ clientId: string }` | `IClientConfigBundle` | New keypairs, return fresh config bundle |
|
||||
| `exportClientConfig` | `{ clientId: string, format: 'smartvpn' \| 'wireguard' }` | `{ config: string }` | Re-export config (without secrets) |
|
||||
| `generateClientKeypair` | `{}` | `IVpnKeypair` | Generate a standalone Noise IK keypair |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
### Phase 1: Rust — Crypto (Replace NK with IK)
|
||||
|
||||
**File: `rust/src/crypto.rs`**
|
||||
|
||||
- Change `NOISE_PATTERN` from NK to IK: `"Noise_IK_25519_ChaChaPoly_BLAKE2s"`
|
||||
- Replace `create_initiator(server_public_key)` → `create_initiator(client_private_key, server_public_key)`
|
||||
- `create_responder(private_key)` stays the same signature (IK responder only needs its own key)
|
||||
- After handshake, `get_remote_static()` on the responder returns the client's public key
|
||||
- Update `perform_handshake()` to pass client keypair
|
||||
- Update all tests
|
||||
|
||||
### Phase 2: Rust — Client Registry module
|
||||
|
||||
**New file: `rust/src/client_registry.rs`**
|
||||
**Modify: `rust/src/lib.rs`** — add `pub mod client_registry;`
|
||||
|
||||
```rust
|
||||
pub struct ClientEntry {
|
||||
pub client_id: String,
|
||||
pub public_key: String,
|
||||
pub wg_public_key: Option<String>,
|
||||
pub security: Option<ClientSecurity>,
|
||||
pub priority: Option<u32>,
|
||||
pub enabled: Option<bool>,
|
||||
pub tags: Option<Vec<String>>,
|
||||
pub description: Option<String>,
|
||||
pub expires_at: Option<String>,
|
||||
pub assigned_ip: Option<String>,
|
||||
}
|
||||
|
||||
/// Mirrors IClientSecurity — aligned with SmartProxy's IRouteSecurity
|
||||
pub struct ClientSecurity {
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
pub destination_allow_list: Option<Vec<String>>,
|
||||
pub destination_block_list: Option<Vec<String>>,
|
||||
pub max_connections: Option<u32>,
|
||||
pub rate_limit: Option<ClientRateLimit>,
|
||||
}
|
||||
|
||||
pub struct ClientRateLimit {
|
||||
pub bytes_per_sec: u64,
|
||||
pub burst_bytes: u64,
|
||||
}
|
||||
|
||||
pub struct ClientRegistry {
|
||||
entries: HashMap<String, ClientEntry>, // keyed by clientId
|
||||
key_index: HashMap<String, String>, // publicKey → clientId (fast lookup)
|
||||
}
|
||||
```
|
||||
|
||||
Methods: `add`, `remove`, `get_by_id`, `get_by_key`, `update`, `list`, `is_authorized` (enabled + not expired + key exists), `rotate_key`.
|
||||
|
||||
### Phase 3: Rust — ACL enforcement module
|
||||
|
||||
**New file: `rust/src/acl.rs`**
|
||||
**Modify: `rust/src/lib.rs`** — add `pub mod acl;`
|
||||
|
||||
```rust
|
||||
/// IP matching supports: exact, CIDR, wildcard, ranges — same as SmartProxy's IpMatcher
|
||||
pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult {
|
||||
// 1. Check ip_block_list / destination_block_list (deny overrides)
|
||||
// 2. Check ip_allow_list / destination_allow_list (explicit allow)
|
||||
// 3. Empty list = allow all
|
||||
}
|
||||
```
|
||||
|
||||
Called in `server.rs` packet loop after decryption, before forwarding.
|
||||
|
||||
### Phase 4: Rust — Server changes
|
||||
|
||||
**File: `rust/src/server.rs`**
|
||||
|
||||
- Add `clients: Option<Vec<ClientEntry>>` to `ServerConfig`
|
||||
- Add `client_registry: RwLock<ClientRegistry>` to `ServerState` (no `auth_mode` — always IK)
|
||||
- Modify `handle_client_connection()`:
|
||||
- Always use `create_responder()` (now IK pattern)
|
||||
- Call `get_remote_static()` **before** `into_transport_mode()` to get client's public key
|
||||
- Verify against registry — reject unauthorized clients with Disconnect frame
|
||||
- Use registry entry for rate limits (overrides server defaults)
|
||||
- In packet loop: call `acl::check_acl()` on decrypted packets
|
||||
- Add `ClientInfo.authenticated_key: String` and `ClientInfo.registered_client_id: String` (no longer optional)
|
||||
- Add methods: `create_client()`, `remove_client()`, `update_client()`, `list_registered_clients()`, `rotate_client_key()`, `export_client_config()`
|
||||
|
||||
### Phase 5: Rust — Client changes
|
||||
|
||||
**File: `rust/src/client.rs`**
|
||||
|
||||
- Add `client_private_key: String` to `ClientConfig` (required, not optional)
|
||||
- `connect()` always uses `create_initiator(client_private_key, server_public_key)` (IK)
|
||||
|
||||
### Phase 6: Rust — Management IPC handlers
|
||||
|
||||
**File: `rust/src/management.rs`**
|
||||
|
||||
Add handlers for all 10 new IPC commands following existing patterns.
|
||||
|
||||
### Phase 7: TypeScript — Interfaces
|
||||
|
||||
**File: `ts/smartvpn.interfaces.ts`**
|
||||
|
||||
- Add `IClientEntry` interface
|
||||
- Add `IClientConfigBundle` interface
|
||||
- Update `IVpnServerConfig`: add required `clients: IClientEntry[]`
|
||||
- Update `IVpnClientConfig`: add required `clientPrivateKey: string`, `clientPublicKey: string`
|
||||
- Update `IVpnClientInfo`: add `authenticatedKey: string`, `registeredClientId: string`
|
||||
- Add new commands to `TVpnServerCommands`
|
||||
|
||||
### Phase 8: TypeScript — VpnServer class methods
|
||||
|
||||
**File: `ts/smartvpn.classes.vpnserver.ts`**
|
||||
|
||||
Add methods:
|
||||
- `createClient(opts)` → `IClientConfigBundle`
|
||||
- `removeClient(clientId)` → `void`
|
||||
- `getClient(clientId)` → `IClientEntry`
|
||||
- `listRegisteredClients()` → `IClientEntry[]`
|
||||
- `updateClient(clientId, update)` → `void`
|
||||
- `enableClient(clientId)` / `disableClient(clientId)`
|
||||
- `rotateClientKey(clientId)` → `IClientConfigBundle`
|
||||
- `exportClientConfig(clientId, format)` → `string | IVpnClientConfig`
|
||||
|
||||
### Phase 9: TypeScript — Config validation
|
||||
|
||||
**File: `ts/smartvpn.classes.vpnconfig.ts`**
|
||||
|
||||
- Server config: validate `clients` present, each entry has valid `clientId` + `publicKey`
|
||||
- Client config: validate `clientPrivateKey` and `clientPublicKey` present for non-WG transports
|
||||
- Validate CIDRs in ACL fields
|
||||
|
||||
### Phase 10: TypeScript — Hub config generation
|
||||
|
||||
**File: `ts/smartvpn.classes.wgconfig.ts`** (extend existing)
|
||||
|
||||
Add `generateClientConfigFromEntry(entry, serverConfig)` — produces WireGuard .conf from `IClientEntry`.
|
||||
|
||||
### Phase 11: Update existing tests
|
||||
|
||||
All existing tests that use the old NK handshake or old config shapes need updating:
|
||||
- Rust tests in `crypto.rs`, `server.rs`, `client.rs`
|
||||
- TS tests in `test/test.vpnconfig.node.ts`, `test/test.flowcontrol.node.ts`, etc.
|
||||
- Tests now must provide client keypairs and client registry entries
|
||||
|
||||
---
|
||||
|
||||
## DX Highlights
|
||||
|
||||
1. **One call to create a client:**
|
||||
```typescript
|
||||
const bundle = await server.createClient({ clientId: 'alice-laptop', tags: ['engineering'] });
|
||||
// bundle.smartvpnConfig — typed SmartVPN client config
|
||||
// bundle.wireguardConfig — standard WireGuard .conf string
|
||||
// bundle.secrets — private keys, shown only at creation time
|
||||
```
|
||||
|
||||
2. **Typed config objects throughout** — no raw strings or JSON blobs
|
||||
|
||||
3. **Dual transport from same definition** — register once, connect via SmartVPN or WireGuard
|
||||
|
||||
4. **ACLs are deny-overrides-allow** — intuitive enterprise model
|
||||
|
||||
5. **Hot management** — add/remove/update/disable clients at runtime
|
||||
|
||||
6. **Key rotation** — `rotateClientKey()` generates new keys and returns a fresh config bundle
|
||||
|
||||
---
|
||||
|
||||
## Verification Plan
|
||||
|
||||
1. **Rust unit tests:**
|
||||
- `crypto.rs`: IK handshake roundtrip, `get_remote_static()` returns correct key, wrong key fails
|
||||
- `client_registry.rs`: CRUD, `is_authorized` with enabled/disabled/expired
|
||||
- `acl.rs`: allow/deny logic, empty lists, deny-overrides-allow
|
||||
|
||||
2. **Rust integration tests:**
|
||||
- Server accepts authorized client
|
||||
- Server rejects unknown client public key
|
||||
- ACL filtering drops packets to blocked destinations
|
||||
- Runtime `createClient` / `removeClient` works
|
||||
- Disabled client rejected at handshake
|
||||
|
||||
3. **TypeScript tests:**
|
||||
- Config validation with required client fields
|
||||
- `createClient()` returns valid bundle with both formats
|
||||
- `exportClientConfig()` generates valid WireGuard .conf
|
||||
- Full IPC roundtrip: create client → connect → traffic → disconnect
|
||||
|
||||
4. **Build:** `pnpm build` (TS + Rust), `cargo test`, `pnpm test`
|
||||
|
||||
---
|
||||
|
||||
## Key Files to Modify
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `rust/src/crypto.rs` | Replace NK with IK pattern, update initiator signature |
|
||||
| `rust/src/client_registry.rs` | **NEW** — client registry module |
|
||||
| `rust/src/acl.rs` | **NEW** — ACL evaluation module |
|
||||
| `rust/src/server.rs` | Registry integration, IK auth in handshake, ACL in packet loop |
|
||||
| `rust/src/client.rs` | Required `client_private_key`, IK initiator |
|
||||
| `rust/src/management.rs` | 10 new IPC command handlers |
|
||||
| `rust/src/lib.rs` | Register new modules |
|
||||
| `ts/smartvpn.interfaces.ts` | `IClientEntry`, `IClientConfigBundle`, updated configs & commands |
|
||||
| `ts/smartvpn.classes.vpnserver.ts` | New hub methods |
|
||||
| `ts/smartvpn.classes.vpnconfig.ts` | Updated validation rules |
|
||||
| `ts/smartvpn.classes.wgconfig.ts` | Config generation from client entries |
|
||||
793
rust/Cargo.lock
generated
793
rust/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,18 @@ tun = { version = "0.7", features = ["async"] }
|
||||
bytes = "1"
|
||||
tokio-util = "0.7"
|
||||
futures-util = "0.3"
|
||||
async-trait = "0.1"
|
||||
quinn = "0.11"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
rcgen = "0.13"
|
||||
ring = "0.17"
|
||||
rustls-pki-types = "1"
|
||||
rustls-pemfile = "2"
|
||||
webpki-roots = "1"
|
||||
mimalloc = "0.1"
|
||||
boringtun = "0.7"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
ipnet = "2"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
|
||||
278
rust/src/acl.rs
Normal file
278
rust/src/acl.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
use std::net::Ipv4Addr;
|
||||
use ipnet::Ipv4Net;
|
||||
|
||||
use crate::client_registry::ClientSecurity;
|
||||
|
||||
/// Result of an ACL check.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AclResult {
|
||||
Allow,
|
||||
DenySrc,
|
||||
DenyDst,
|
||||
}
|
||||
|
||||
/// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy.
|
||||
///
|
||||
/// Evaluation order (deny overrides allow):
|
||||
/// 1. If src_ip is in ip_block_list → DenySrc
|
||||
/// 2. If dst_ip is in destination_block_list → DenyDst
|
||||
/// 3. If ip_allow_list is non-empty and src_ip is NOT in it → DenySrc
|
||||
/// 4. If destination_allow_list is non-empty and dst_ip is NOT in it → DenyDst
|
||||
/// 5. Otherwise → Allow
|
||||
pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult {
|
||||
// Step 1: Check source block list (deny overrides)
|
||||
if let Some(ref block_list) = security.ip_block_list {
|
||||
if ip_matches_any(src_ip, block_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Check destination block list (deny overrides)
|
||||
if let Some(ref block_list) = security.destination_block_list {
|
||||
if ip_matches_any(dst_ip, block_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Check source allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.ip_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(src_ip, allow_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Check destination allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.destination_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(dst_ip, allow_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
AclResult::Allow
|
||||
}
|
||||
|
||||
/// Check if `ip` matches any pattern in the list.
|
||||
/// Supports: exact IP, CIDR notation, wildcard patterns (192.168.1.*),
|
||||
/// and IP ranges (192.168.1.1-192.168.1.100).
|
||||
fn ip_matches_any(ip: Ipv4Addr, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if ip_matches(ip, pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if `ip` matches a single pattern.
|
||||
fn ip_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let pattern = pattern.trim();
|
||||
|
||||
// CIDR notation (e.g. 192.168.1.0/24)
|
||||
if pattern.contains('/') {
|
||||
if let Ok(net) = pattern.parse::<Ipv4Net>() {
|
||||
return net.contains(&ip);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// IP range (e.g. 192.168.1.1-192.168.1.100)
|
||||
if pattern.contains('-') {
|
||||
let parts: Vec<&str> = pattern.splitn(2, '-').collect();
|
||||
if parts.len() == 2 {
|
||||
if let (Ok(start), Ok(end)) = (parts[0].trim().parse::<Ipv4Addr>(), parts[1].trim().parse::<Ipv4Addr>()) {
|
||||
let ip_u32 = u32::from(ip);
|
||||
return ip_u32 >= u32::from(start) && ip_u32 <= u32::from(end);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wildcard pattern (e.g. 192.168.1.*)
|
||||
if pattern.contains('*') {
|
||||
return wildcard_matches(ip, pattern);
|
||||
}
|
||||
|
||||
// Exact IP match
|
||||
if let Ok(exact) = pattern.parse::<Ipv4Addr>() {
|
||||
return ip == exact;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Match an IP against a wildcard pattern like "192.168.1.*" or "10.*.*.*".
|
||||
fn wildcard_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let ip_octets = ip.octets();
|
||||
let pattern_parts: Vec<&str> = pattern.split('.').collect();
|
||||
if pattern_parts.len() != 4 {
|
||||
return false;
|
||||
}
|
||||
for (i, part) in pattern_parts.iter().enumerate() {
|
||||
if *part == "*" {
|
||||
continue;
|
||||
}
|
||||
if let Ok(octet) = part.parse::<u8>() {
|
||||
if ip_octets[i] != octet {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client_registry::{ClientRateLimit, ClientSecurity};
|
||||
|
||||
fn security(
|
||||
ip_allow: Option<Vec<&str>>,
|
||||
ip_block: Option<Vec<&str>>,
|
||||
dst_allow: Option<Vec<&str>>,
|
||||
dst_block: Option<Vec<&str>>,
|
||||
) -> ClientSecurity {
|
||||
ClientSecurity {
|
||||
ip_allow_list: ip_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
ip_block_list: ip_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_allow_list: dst_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_block_list: dst_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
max_connections: None,
|
||||
rate_limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ip(s: &str) -> Ipv4Addr {
|
||||
s.parse().unwrap()
|
||||
}
|
||||
|
||||
// ── No restrictions (empty security) ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn empty_security_allows_all() {
|
||||
let sec = security(None, None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_lists_allow_all() {
|
||||
let sec = security(Some(vec![]), Some(vec![]), Some(vec![]), Some(vec![]));
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
// ── Source IP allow list ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_allow_exact_match() {
|
||||
let sec = security(Some(vec!["10.0.0.1"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.2"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_cidr() {
|
||||
let sec = security(Some(vec!["192.168.1.0/24"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.2.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_wildcard() {
|
||||
let sec = security(Some(vec!["10.0.*.*"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.5.3"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.1.0.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_range() {
|
||||
let sec = security(Some(vec!["10.0.0.1-10.0.0.10"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.5"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.11"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Source IP block list (deny overrides) ───────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_block_overrides_allow() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
Some(vec!["192.168.1.100"]),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.100"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Destination allow list ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_allow_exact() {
|
||||
let sec = security(None, None, Some(vec!["8.8.8.8", "8.8.4.4"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dst_allow_cidr() {
|
||||
let sec = security(None, None, Some(vec!["10.0.0.0/8"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.5.3.2")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("172.16.0.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Destination block list (deny overrides) ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_block_overrides_allow() {
|
||||
let sec = security(
|
||||
None,
|
||||
None,
|
||||
Some(vec!["10.0.0.0/8"]),
|
||||
Some(vec!["10.0.0.99"]),
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.1")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.99")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Combined source + destination ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn combined_src_and_dst_filtering() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
None,
|
||||
Some(vec!["8.8.8.8"]),
|
||||
None,
|
||||
);
|
||||
// Valid source, valid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("8.8.8.8")), AclResult::Allow);
|
||||
// Invalid source
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::DenySrc);
|
||||
// Valid source, invalid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── IP matching edge cases ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wildcard_single_octet() {
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.*"));
|
||||
assert!(!ip_matches(ip("10.0.1.5"), "10.0.0.*"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range_boundaries() {
|
||||
assert!(ip_matches(ip("10.0.0.1"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.6"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.0"), "10.0.0.1-10.0.0.5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_pattern_no_match() {
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "not-an-ip"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0.1/99"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0"));
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,8 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn, debug};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
@@ -12,6 +10,8 @@ use crate::crypto;
|
||||
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
|
||||
use crate::telemetry::ConnectionQuality;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
|
||||
/// Client configuration (matches TS IVpnClientConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -19,9 +19,17 @@ use crate::transport;
|
||||
pub struct ClientConfig {
|
||||
pub server_url: String,
|
||||
pub server_public_key: String,
|
||||
/// Client's Noise IK static private key (base64) — required for authentication.
|
||||
pub client_private_key: String,
|
||||
/// Client's Noise IK static public key (base64) — for reference/display.
|
||||
pub client_public_key: String,
|
||||
pub dns: Option<Vec<String>>,
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
/// Transport type: "websocket" (default) or "quic".
|
||||
pub transport: Option<String>,
|
||||
/// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
|
||||
pub server_cert_hash: Option<String>,
|
||||
}
|
||||
|
||||
/// Client statistics.
|
||||
@@ -100,22 +108,83 @@ impl VpnClient {
|
||||
let connected_since = self.connected_since.clone();
|
||||
let link_health = self.link_health.clone();
|
||||
|
||||
// Decode server public key
|
||||
// Decode keys
|
||||
let server_pub_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.server_public_key,
|
||||
)?;
|
||||
let client_priv_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.client_private_key,
|
||||
)?;
|
||||
|
||||
// Connect to WebSocket server
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
// Create transport based on configuration
|
||||
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
|
||||
let transport_type = config.transport.as_deref().unwrap_or("auto");
|
||||
match transport_type {
|
||||
"quic" => {
|
||||
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
|
||||
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
info!("Connected via QUIC");
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
"websocket" => {
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
_ => {
|
||||
// "auto" (default): try QUIC first, fall back to WebSocket
|
||||
// Extract host:port from the URL for QUIC attempt
|
||||
let quic_addr = extract_host_port(&config.server_url);
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
|
||||
// Noise NK handshake (client side = initiator)
|
||||
if let Some(ref addr) = quic_addr {
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(3),
|
||||
try_quic_connect(addr, cert_hash),
|
||||
).await {
|
||||
Ok(Ok((quic_sink, quic_stream))) => {
|
||||
info!("Auto: connected via QUIC to {}", addr);
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC unavailable)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Auto: QUIC timed out, falling back to WebSocket");
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC timed out)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Can't extract host:port for QUIC, use WebSocket directly
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Noise IK handshake (client side = initiator, presents static key)
|
||||
*state.write().await = ClientState::Handshaking;
|
||||
let mut initiator = crypto::create_initiator(&server_pub_key)?;
|
||||
let mut initiator = crypto::create_initiator(&client_priv_key, &server_pub_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es
|
||||
// -> e, es, s, ss
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let init_frame = Frame {
|
||||
packet_type: PacketType::HandshakeInit,
|
||||
@@ -123,13 +192,11 @@ impl VpnClient {
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// <- e, ee
|
||||
let resp_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
|
||||
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
|
||||
// <- e, ee, se
|
||||
let resp_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed during handshake"),
|
||||
};
|
||||
|
||||
@@ -145,9 +212,9 @@ impl VpnClient {
|
||||
let mut noise_transport = initiator.into_transport_mode()?;
|
||||
|
||||
// Receive assigned IP info (encrypted)
|
||||
let info_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected IP info message"),
|
||||
let info_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before IP info"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&info_msg[..]);
|
||||
@@ -167,8 +234,15 @@ impl VpnClient {
|
||||
|
||||
info!("Connected to VPN, assigned IP: {}", assigned_ip);
|
||||
|
||||
// Create adaptive keepalive monitor
|
||||
let (monitor, handle) = keepalive::create_keepalive(None);
|
||||
// Create adaptive keepalive monitor (use custom interval if configured)
|
||||
let ka_config = config.keepalive_interval_secs.map(|secs| {
|
||||
let mut cfg = keepalive::AdaptiveKeepaliveConfig::default();
|
||||
cfg.degraded_interval = std::time::Duration::from_secs(secs);
|
||||
cfg.healthy_interval = std::time::Duration::from_secs(secs * 2);
|
||||
cfg.critical_interval = std::time::Duration::from_secs((secs / 3).max(1));
|
||||
cfg
|
||||
});
|
||||
let (monitor, handle) = keepalive::create_keepalive(ka_config);
|
||||
self.quality_rx = Some(handle.quality_rx);
|
||||
|
||||
// Spawn the keepalive monitor
|
||||
@@ -177,8 +251,8 @@ impl VpnClient {
|
||||
// Spawn packet forwarding loop
|
||||
let assigned_ip_clone = assigned_ip.clone();
|
||||
tokio::spawn(client_loop(
|
||||
ws_sink,
|
||||
ws_stream,
|
||||
sink,
|
||||
stream,
|
||||
noise_transport,
|
||||
state,
|
||||
stats,
|
||||
@@ -273,8 +347,8 @@ impl VpnClient {
|
||||
|
||||
/// The main client packet forwarding loop (runs in a spawned task).
|
||||
async fn client_loop(
|
||||
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
|
||||
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
mut noise_transport: snow::TransportState,
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
@@ -287,10 +361,10 @@ async fn client_loop(
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
Ok(Some(data)) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
@@ -321,17 +395,13 @@ async fn client_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
Ok(None) => {
|
||||
info!("Connection closed");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
let _ = ws_sink.send(Message::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
error!("WebSocket error: {}", e);
|
||||
Err(e) => {
|
||||
error!("Transport error: {}", e);
|
||||
*state.write().await = ClientState::Error(e.to_string());
|
||||
break;
|
||||
}
|
||||
@@ -347,7 +417,7 @@ async fn client_loop(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
@@ -378,12 +448,51 @@ async fn client_loop(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
|
||||
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
}
|
||||
let _ = ws_sink.close().await;
|
||||
let _ = sink.close().await;
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect via QUIC. Returns transport halves on success.
|
||||
async fn try_quic_connect(
|
||||
addr: &str,
|
||||
cert_hash: Option<&str>,
|
||||
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
|
||||
let conn = quic_transport::connect_quic(addr, cert_hash).await?;
|
||||
let (sink, stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
Ok((sink, stream))
|
||||
}
|
||||
|
||||
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
|
||||
/// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080")
|
||||
/// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443")
|
||||
/// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port)
|
||||
fn extract_host_port(url: &str) -> Option<String> {
|
||||
if url.starts_with("ws://") || url.starts_with("wss://") {
|
||||
// Parse as URL
|
||||
let stripped = if url.starts_with("wss://") {
|
||||
&url[6..]
|
||||
} else {
|
||||
&url[5..]
|
||||
};
|
||||
// Remove path
|
||||
let host_port = stripped.split('/').next()?;
|
||||
if host_port.contains(':') {
|
||||
Some(host_port.to_string())
|
||||
} else {
|
||||
// Default port
|
||||
let default_port = if url.starts_with("wss://") { 443 } else { 80 };
|
||||
Some(format!("{}:{}", host_port, default_port))
|
||||
}
|
||||
} else if url.contains(':') {
|
||||
// Already host:port
|
||||
Some(url.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
362
rust/src/client_registry.rs
Normal file
362
rust/src/client_registry.rs
Normal file
@@ -0,0 +1,362 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Per-client rate limiting configuration.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientRateLimit {
|
||||
pub bytes_per_sec: u64,
|
||||
pub burst_bytes: u64,
|
||||
}
|
||||
|
||||
/// Per-client security settings — aligned with SmartProxy's IRouteSecurity pattern.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientSecurity {
|
||||
/// Source IPs/CIDRs the client may connect FROM (empty/None = any).
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
/// Source IPs blocked — overrides ip_allow_list (deny wins).
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
/// Destination IPs/CIDRs the client may reach (empty/None = all).
|
||||
pub destination_allow_list: Option<Vec<String>>,
|
||||
/// Destination IPs blocked — overrides destination_allow_list (deny wins).
|
||||
pub destination_block_list: Option<Vec<String>>,
|
||||
/// Max concurrent connections from this client.
|
||||
pub max_connections: Option<u32>,
|
||||
/// Per-client rate limiting.
|
||||
pub rate_limit: Option<ClientRateLimit>,
|
||||
}
|
||||
|
||||
/// A registered client entry — the server-side source of truth.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientEntry {
|
||||
/// Human-readable client ID (e.g. "alice-laptop").
|
||||
pub client_id: String,
|
||||
/// Client's Noise IK public key (base64).
|
||||
pub public_key: String,
|
||||
/// Client's WireGuard public key (base64) — optional.
|
||||
pub wg_public_key: Option<String>,
|
||||
/// Security settings (ACLs, rate limits).
|
||||
pub security: Option<ClientSecurity>,
|
||||
/// Traffic priority (lower = higher priority, default: 100).
|
||||
pub priority: Option<u32>,
|
||||
/// Whether this client is enabled (default: true).
|
||||
pub enabled: Option<bool>,
|
||||
/// Tags for grouping.
|
||||
pub tags: Option<Vec<String>>,
|
||||
/// Optional description.
|
||||
pub description: Option<String>,
|
||||
/// Optional expiry (ISO 8601 timestamp).
|
||||
pub expires_at: Option<String>,
|
||||
/// Assigned VPN IP address.
|
||||
pub assigned_ip: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientEntry {
|
||||
/// Whether this client is considered enabled (defaults to true).
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Whether this client has expired based on current time.
|
||||
pub fn is_expired(&self) -> bool {
|
||||
if let Some(ref expires) = self.expires_at {
|
||||
if let Ok(expiry) = chrono::DateTime::parse_from_rfc3339(expires) {
|
||||
return chrono::Utc::now() > expiry;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory client registry with dual-key indexing.
|
||||
pub struct ClientRegistry {
|
||||
/// Primary index: clientId → ClientEntry
|
||||
entries: HashMap<String, ClientEntry>,
|
||||
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
|
||||
key_index: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ClientRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
key_index: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a registry from a list of client entries.
|
||||
pub fn from_entries(entries: Vec<ClientEntry>) -> Result<Self> {
|
||||
let mut registry = Self::new();
|
||||
for entry in entries {
|
||||
registry.add(entry)?;
|
||||
}
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
/// Add a client to the registry.
|
||||
pub fn add(&mut self, entry: ClientEntry) -> Result<()> {
|
||||
if self.entries.contains_key(&entry.client_id) {
|
||||
anyhow::bail!("Client '{}' already exists", entry.client_id);
|
||||
}
|
||||
if self.key_index.contains_key(&entry.public_key) {
|
||||
anyhow::bail!("Public key already registered to another client");
|
||||
}
|
||||
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
|
||||
self.entries.insert(entry.client_id.clone(), entry);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a client by ID.
|
||||
pub fn remove(&mut self, client_id: &str) -> Result<ClientEntry> {
|
||||
let entry = self.entries.remove(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
self.key_index.remove(&entry.public_key);
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
/// Get a client by ID.
|
||||
pub fn get_by_id(&self, client_id: &str) -> Option<&ClientEntry> {
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Get a client by public key (used during IK handshake verification).
|
||||
pub fn get_by_key(&self, public_key: &str) -> Option<&ClientEntry> {
|
||||
let client_id = self.key_index.get(public_key)?;
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Check if a public key is authorized (exists, enabled, not expired).
|
||||
pub fn is_authorized(&self, public_key: &str) -> bool {
|
||||
match self.get_by_key(public_key) {
|
||||
Some(entry) => entry.is_enabled() && !entry.is_expired(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a client entry. The closure receives a mutable reference to the entry.
|
||||
pub fn update<F>(&mut self, client_id: &str, updater: F) -> Result<()>
|
||||
where
|
||||
F: FnOnce(&mut ClientEntry),
|
||||
{
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
let old_key = entry.public_key.clone();
|
||||
updater(entry);
|
||||
// If public key changed, update the index
|
||||
if entry.public_key != old_key {
|
||||
self.key_index.remove(&old_key);
|
||||
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all client entries.
|
||||
pub fn list(&self) -> Vec<&ClientEntry> {
|
||||
self.entries.values().collect()
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns the updated entry.
|
||||
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
// Update key index
|
||||
self.key_index.remove(&entry.public_key);
|
||||
entry.public_key = new_public_key.clone();
|
||||
entry.wg_public_key = new_wg_public_key;
|
||||
self.key_index.insert(new_public_key, client_id.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of registered clients.
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Whether the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(id: &str, key: &str) -> ClientEntry {
|
||||
ClientEntry {
|
||||
client_id: id.to_string(),
|
||||
public_key: key.to_string(),
|
||||
wg_public_key: None,
|
||||
security: None,
|
||||
priority: None,
|
||||
enabled: None,
|
||||
tags: None,
|
||||
description: None,
|
||||
expires_at: None,
|
||||
assigned_ip: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_and_lookup() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
assert!(reg.get_by_id("alice").is_some());
|
||||
assert!(reg.get_by_key("key_alice").is_some());
|
||||
assert_eq!(reg.get_by_key("key_alice").unwrap().client_id, "alice");
|
||||
assert!(reg.get_by_id("bob").is_none());
|
||||
assert!(reg.get_by_key("key_bob").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_id() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key1")).unwrap();
|
||||
assert!(reg.add(make_entry("alice", "key2")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "same_key")).unwrap();
|
||||
assert!(reg.add(make_entry("bob", "same_key")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert_eq!(reg.len(), 1);
|
||||
|
||||
let removed = reg.remove("alice").unwrap();
|
||||
assert_eq!(removed.client_id, "alice");
|
||||
assert_eq!(reg.len(), 0);
|
||||
assert!(reg.get_by_key("key_alice").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.remove("ghost").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_enabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert!(reg.is_authorized("key_alice")); // enabled by default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_disabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.enabled = Some(false);
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_expired() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2020-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_future_expiry() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2099-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_unknown_key() {
|
||||
let reg = ClientRegistry::new();
|
||||
assert!(!reg.is_authorized("nonexistent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
reg.update("alice", |entry| {
|
||||
entry.description = Some("Updated".to_string());
|
||||
entry.enabled = Some(false);
|
||||
}).unwrap();
|
||||
|
||||
let entry = reg.get_by_id("alice").unwrap();
|
||||
assert_eq!(entry.description.as_deref(), Some("Updated"));
|
||||
assert!(!entry.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.update("ghost", |_| {}).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rotate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "old_key")).unwrap();
|
||||
|
||||
reg.rotate_key("alice", "new_key".to_string(), None).unwrap();
|
||||
|
||||
assert!(reg.get_by_key("old_key").is_none());
|
||||
assert!(reg.get_by_key("new_key").is_some());
|
||||
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_entries() {
|
||||
let entries = vec![
|
||||
make_entry("alice", "key_a"),
|
||||
make_entry("bob", "key_b"),
|
||||
];
|
||||
let reg = ClientRegistry::from_entries(entries).unwrap();
|
||||
assert_eq!(reg.len(), 2);
|
||||
assert!(reg.get_by_key("key_a").is_some());
|
||||
assert!(reg.get_by_key("key_b").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_clients() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_a")).unwrap();
|
||||
reg.add(make_entry("bob", "key_b")).unwrap();
|
||||
let list = reg.list();
|
||||
assert_eq!(list.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_with_rate_limit() {
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.security = Some(ClientSecurity {
|
||||
ip_allow_list: Some(vec!["192.168.1.0/24".to_string()]),
|
||||
ip_block_list: Some(vec!["192.168.1.100".to_string()]),
|
||||
destination_allow_list: None,
|
||||
destination_block_list: None,
|
||||
max_connections: Some(5),
|
||||
rate_limit: Some(ClientRateLimit {
|
||||
bytes_per_sec: 1_000_000,
|
||||
burst_bytes: 2_000_000,
|
||||
}),
|
||||
});
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(entry).unwrap();
|
||||
let e = reg.get_by_id("alice").unwrap();
|
||||
let sec = e.security.as_ref().unwrap();
|
||||
assert_eq!(sec.rate_limit.as_ref().unwrap().bytes_per_sec, 1_000_000);
|
||||
assert_eq!(sec.max_connections, Some(5));
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,10 @@ use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use snow::Builder;
|
||||
|
||||
/// Noise protocol pattern: NK (client knows server pubkey, no client auth at Noise level)
|
||||
const NOISE_PATTERN: &str = "Noise_NK_25519_ChaChaPoly_BLAKE2s";
|
||||
/// Noise protocol pattern: IK (client presents static key, server authenticates client)
|
||||
/// IK = Initiator's static key is transmitted; responder's Key is pre-known.
|
||||
/// This provides mutual authentication: server verifies client identity via public key.
|
||||
const NOISE_PATTERN: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2s";
|
||||
|
||||
/// Generate a new Noise static keypair.
|
||||
/// Returns (public_key_base64, private_key_base64).
|
||||
@@ -22,18 +24,23 @@ pub fn generate_keypair_raw() -> Result<snow::Keypair> {
|
||||
Ok(builder.generate_keypair()?)
|
||||
}
|
||||
|
||||
/// Create a Noise NK initiator (client side).
|
||||
/// The client knows the server's static public key.
|
||||
pub fn create_initiator(server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
/// Create a Noise IK initiator (client side).
|
||||
/// The client provides its own static keypair AND the server's public key.
|
||||
/// The client's static key is transmitted (encrypted) during the handshake,
|
||||
/// allowing the server to authenticate the client.
|
||||
pub fn create_initiator(client_private_key: &[u8], server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
.local_private_key(client_private_key)
|
||||
.remote_public_key(server_public_key)
|
||||
.build_initiator()?;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Create a Noise NK responder (server side).
|
||||
/// Create a Noise IK responder (server side).
|
||||
/// The server uses its static private key.
|
||||
/// After the handshake, call `get_remote_static()` on the HandshakeState
|
||||
/// (before `into_transport_mode()`) to retrieve the client's public key.
|
||||
pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
@@ -42,19 +49,20 @@ pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Perform the full Noise NK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport).
|
||||
/// Perform the full Noise IK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport, client_public_key).
|
||||
/// The client_public_key is extracted from the responder before entering transport mode.
|
||||
pub fn perform_handshake(
|
||||
mut initiator: snow::HandshakeState,
|
||||
mut responder: snow::HandshakeState,
|
||||
) -> Result<(snow::TransportState, snow::TransportState)> {
|
||||
) -> Result<(snow::TransportState, snow::TransportState, Vec<u8>)> {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es (initiator sends)
|
||||
// -> e, es, s, ss (initiator sends ephemeral + encrypted static key)
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let msg1 = buf[..len].to_vec();
|
||||
|
||||
// <- e, ee (responder reads and responds)
|
||||
// <- e, ee, se (responder reads and responds)
|
||||
responder.read_message(&msg1, &mut buf)?;
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let msg2 = buf[..len].to_vec();
|
||||
@@ -62,10 +70,16 @@ pub fn perform_handshake(
|
||||
// Initiator reads response
|
||||
initiator.read_message(&msg2, &mut buf)?;
|
||||
|
||||
// Extract client's public key from responder BEFORE entering transport mode
|
||||
let client_public_key = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake did not provide client static key"))?
|
||||
.to_vec();
|
||||
|
||||
let i_transport = initiator.into_transport_mode()?;
|
||||
let r_transport = responder.into_transport_mode()?;
|
||||
|
||||
Ok((i_transport, r_transport))
|
||||
Ok((i_transport, r_transport, client_public_key))
|
||||
}
|
||||
|
||||
/// XChaCha20-Poly1305 encryption for post-handshake data.
|
||||
@@ -135,15 +149,19 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_handshake() {
|
||||
fn noise_ik_handshake() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
let initiator = create_initiator(&server_kp.public).unwrap();
|
||||
let initiator = create_initiator(&client_kp.private, &server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
let (mut i_transport, mut r_transport) =
|
||||
let (mut i_transport, mut r_transport, remote_key) =
|
||||
perform_handshake(initiator, responder).unwrap();
|
||||
|
||||
// Verify the server received the client's public key
|
||||
assert_eq!(remote_key, client_kp.public);
|
||||
|
||||
// Test encrypted communication
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let plaintext = b"hello from client";
|
||||
@@ -159,6 +177,20 @@ mod tests {
|
||||
assert_eq!(&out[..len], plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_ik_wrong_server_key_fails() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let wrong_server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
// Client uses wrong server public key
|
||||
let initiator = create_initiator(&client_kp.private, &wrong_server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
// Handshake should fail because client targeted wrong server
|
||||
assert!(perform_handshake(initiator, responder).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn xchacha_encrypt_decrypt() {
|
||||
let key = [42u8; 32];
|
||||
|
||||
@@ -5,6 +5,8 @@ pub mod management;
|
||||
pub mod codec;
|
||||
pub mod crypto;
|
||||
pub mod transport;
|
||||
pub mod transport_trait;
|
||||
pub mod quic_transport;
|
||||
pub mod keepalive;
|
||||
pub mod tunnel;
|
||||
pub mod network;
|
||||
@@ -15,3 +17,6 @@ pub mod telemetry;
|
||||
pub mod ratelimit;
|
||||
pub mod qos;
|
||||
pub mod mtu;
|
||||
pub mod wireguard;
|
||||
pub mod client_registry;
|
||||
pub mod acl;
|
||||
|
||||
@@ -7,6 +7,7 @@ use tracing::{info, error, warn};
|
||||
use crate::client::{ClientConfig, VpnClient};
|
||||
use crate::crypto;
|
||||
use crate::server::{ServerConfig, VpnServer};
|
||||
use crate::wireguard::{self, WgClient, WgClientConfig, WgPeerConfig, WgServer, WgServerConfig};
|
||||
|
||||
// ============================================================================
|
||||
// IPC protocol types
|
||||
@@ -93,6 +94,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
|
||||
let mut vpn_client = VpnClient::new();
|
||||
let mut vpn_server = VpnServer::new();
|
||||
let mut wg_client = WgClient::new();
|
||||
let mut wg_server = WgServer::new();
|
||||
|
||||
send_event_stdout("ready", serde_json::json!({ "mode": mode }));
|
||||
|
||||
@@ -127,8 +130,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
};
|
||||
|
||||
let response = match mode {
|
||||
"client" => handle_client_request(&request, &mut vpn_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server).await,
|
||||
"client" => handle_client_request(&request, &mut vpn_client, &mut wg_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server, &mut wg_server).await,
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
send_response_stdout(&response);
|
||||
@@ -150,6 +153,8 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
// Shared state behind Mutex for socket mode (multiple connections)
|
||||
let vpn_client = std::sync::Arc::new(Mutex::new(VpnClient::new()));
|
||||
let vpn_server = std::sync::Arc::new(Mutex::new(VpnServer::new()));
|
||||
let wg_client = std::sync::Arc::new(Mutex::new(WgClient::new()));
|
||||
let wg_server = std::sync::Arc::new(Mutex::new(WgServer::new()));
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
@@ -157,9 +162,11 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
let mode = mode.to_string();
|
||||
let client = vpn_client.clone();
|
||||
let server = vpn_server.clone();
|
||||
let wg_c = wg_client.clone();
|
||||
let wg_s = wg_server.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
handle_socket_connection(stream, &mode, client, server).await
|
||||
handle_socket_connection(stream, &mode, client, server, wg_c, wg_s).await
|
||||
{
|
||||
warn!("Socket connection error: {}", e);
|
||||
}
|
||||
@@ -177,6 +184,8 @@ async fn handle_socket_connection(
|
||||
mode: &str,
|
||||
vpn_client: std::sync::Arc<Mutex<VpnClient>>,
|
||||
vpn_server: std::sync::Arc<Mutex<VpnServer>>,
|
||||
wg_client: std::sync::Arc<Mutex<WgClient>>,
|
||||
wg_server: std::sync::Arc<Mutex<WgServer>>,
|
||||
) -> Result<()> {
|
||||
let (reader, mut writer) = stream.into_split();
|
||||
let buf_reader = BufReader::new(reader);
|
||||
@@ -227,11 +236,13 @@ async fn handle_socket_connection(
|
||||
let response = match mode {
|
||||
"client" => {
|
||||
let mut client = vpn_client.lock().await;
|
||||
handle_client_request(&request, &mut client).await
|
||||
let mut wg_c = wg_client.lock().await;
|
||||
handle_client_request(&request, &mut client, &mut wg_c).await
|
||||
}
|
||||
"server" => {
|
||||
let mut server = vpn_server.lock().await;
|
||||
handle_server_request(&request, &mut server).await
|
||||
let mut wg_s = wg_server.lock().await;
|
||||
handle_server_request(&request, &mut server, &mut wg_s).await
|
||||
}
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
@@ -252,38 +263,79 @@ async fn handle_socket_connection(
|
||||
async fn handle_client_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_client: &mut VpnClient,
|
||||
wg_client: &mut WgClient,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"connect" => {
|
||||
let config: ClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
// Check if transport is "wireguard"
|
||||
let transport = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transport"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match vpn_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
if transport == "wireguard" {
|
||||
let config: WgClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("WG connect failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"disconnect" => {
|
||||
if wg_client.is_running() {
|
||||
match wg_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG disconnect failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disconnect" => match vpn_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
|
||||
},
|
||||
"getStatus" => {
|
||||
let status = vpn_client.get_status().await;
|
||||
ManagementResponse::ok(id, status)
|
||||
if wg_client.is_running() {
|
||||
ManagementResponse::ok(id, wg_client.get_status().await)
|
||||
} else {
|
||||
let status = vpn_client.get_status().await;
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
}
|
||||
"getStatistics" => {
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
if wg_client.is_running() {
|
||||
ManagementResponse::ok(id, wg_client.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
}
|
||||
}
|
||||
"getConnectionQuality" => {
|
||||
match vpn_client.get_connection_quality() {
|
||||
@@ -329,45 +381,92 @@ async fn handle_client_request(
|
||||
async fn handle_server_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_server: &mut VpnServer,
|
||||
wg_server: &mut WgServer,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"start" => {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
// Check if transportMode is "wireguard"
|
||||
let transport_mode = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transportMode"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
if transport_mode == "wireguard" {
|
||||
let config: WgServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG start failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => {
|
||||
if wg_server.is_running() {
|
||||
match wg_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG stop failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
},
|
||||
"getStatus" => {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_status())
|
||||
} else {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
}
|
||||
"getStatistics" => {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"listClients" => {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
if wg_server.is_running() {
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"disconnectClient" => {
|
||||
@@ -436,6 +535,153 @@ async fn handle_server_request(
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
"generateWgKeypair" => {
|
||||
let (public_key, private_key) = wireguard::generate_wg_keypair();
|
||||
ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
)
|
||||
}
|
||||
"addWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let config: WgPeerConfig = match serde_json::from_value(
|
||||
request.params.get("peer").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid peer config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.add_peer(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Add peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let public_key = match request.params.get("publicKey").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing publicKey".to_string()),
|
||||
};
|
||||
match wg_server.remove_peer(&public_key).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listWgPeers" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "peers": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
// ── Client Registry (Hub) Commands ────────────────────────────────
|
||||
"createClient" => {
|
||||
let client_partial = request.params.get("client").cloned().unwrap_or_default();
|
||||
match vpn_server.create_client(client_partial).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Create client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_registered_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.get_registered_client(&client_id).await {
|
||||
Ok(entry) => ManagementResponse::ok(id, entry),
|
||||
Err(e) => ManagementResponse::err(id, format!("Get client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listRegisteredClients" => {
|
||||
let clients = vpn_server.list_registered_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
"updateClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let update = request.params.get("update").cloned().unwrap_or_default();
|
||||
match vpn_server.update_registered_client(&client_id, update).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Update client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"enableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.enable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Enable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.disable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"rotateClientKey" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.rotate_client_key(&client_id).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Key rotation failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"exportClientConfig" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let format = request.params.get("format").and_then(|v| v.as_str()).unwrap_or("smartvpn");
|
||||
match vpn_server.export_client_config(&client_id, format).await {
|
||||
Ok(config) => ManagementResponse::ok(id, config),
|
||||
Err(e) => ManagementResponse::err(id, format!("Export failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"generateClientKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
_ => ManagementResponse::err(id, format!("Unknown server method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
546
rust/src/quic_transport.rs
Normal file
546
rust/src/quic_transport.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use quinn::crypto::rustls::QuicClientConfig;
|
||||
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn, debug};
|
||||
|
||||
use crate::transport_trait::{TransportSink, TransportStream};
|
||||
|
||||
// ============================================================================
|
||||
// TLS / Certificate helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Generate a self-signed certificate and private key for QUIC.
|
||||
pub fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["smartvpn".to_string()])?;
|
||||
let cert_der = CertificateDer::from(cert.cert);
|
||||
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()));
|
||||
Ok((vec![cert_der], key_der))
|
||||
}
|
||||
|
||||
/// Compute the SHA-256 hash of a DER-encoded certificate and return it as base64.
|
||||
pub fn cert_hash(cert_der: &CertificateDer<'_>) -> String {
|
||||
use ring::digest;
|
||||
let hash = digest::digest(&digest::SHA256, cert_der.as_ref());
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash.as_ref())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Server-side QUIC endpoint
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for the QUIC server endpoint.
|
||||
pub struct QuicServerConfig {
|
||||
pub listen_addr: String,
|
||||
pub cert_chain: Vec<CertificateDer<'static>>,
|
||||
pub private_key: PrivateKeyDer<'static>,
|
||||
pub idle_timeout_secs: u64,
|
||||
}
|
||||
|
||||
/// Create a QUIC server endpoint bound to the given address.
|
||||
pub fn create_quic_server(config: QuicServerConfig) -> Result<quinn::Endpoint> {
|
||||
let addr: SocketAddr = config.listen_addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let mut tls_config = rustls::ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(config.cert_chain, config.private_key)?;
|
||||
tls_config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
|
||||
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(
|
||||
quinn::IdleTimeout::try_from(Duration::from_secs(config.idle_timeout_secs))?,
|
||||
));
|
||||
// Enable datagrams with a generous max size
|
||||
transport.datagram_receive_buffer_size(Some(65535));
|
||||
transport.datagram_send_buffer_size(65535);
|
||||
server_config.transport_config(Arc::new(transport));
|
||||
|
||||
let endpoint = quinn::Endpoint::server(server_config, addr)?;
|
||||
info!("QUIC server listening on {}", addr);
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client-side QUIC connection
|
||||
// ============================================================================
|
||||
|
||||
/// A certificate verifier that accepts any server certificate.
|
||||
/// Safe when Noise NK provides server authentication at the application layer.
|
||||
#[derive(Debug)]
|
||||
struct AcceptAnyCert;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for AcceptAnyCert {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// A certificate verifier that accepts any certificate matching a given SHA-256 hash.
|
||||
#[derive(Debug)]
|
||||
struct CertHashVerifier {
|
||||
expected_hash: String,
|
||||
}
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for CertHashVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
let actual_hash = cert_hash(end_entity);
|
||||
if actual_hash == self.expected_hash {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
} else {
|
||||
Err(rustls::Error::General(format!(
|
||||
"Certificate hash mismatch: expected {}, got {}",
|
||||
self.expected_hash, actual_hash
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
// QUIC always uses TLS 1.3
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a QUIC server.
|
||||
///
|
||||
/// - If `server_cert_hash` is provided, verifies the server certificate matches
|
||||
/// the given SHA-256 hash (cert pinning).
|
||||
/// - If `server_cert_hash` is `None`, accepts any server certificate. This is
|
||||
/// safe because the Noise NK handshake (which runs over the QUIC stream)
|
||||
/// authenticates the server via its pre-shared public key — the same trust
|
||||
/// model as WireGuard.
|
||||
pub async fn connect_quic(
|
||||
addr: &str,
|
||||
server_cert_hash: Option<&str>,
|
||||
) -> Result<quinn::Connection> {
|
||||
let remote: SocketAddr = addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let tls_config = if let Some(hash) = server_cert_hash {
|
||||
// Pin to a specific certificate hash
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash.to_string(),
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
} else {
|
||||
// Accept any cert — Noise NK provides server authentication
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(AcceptAnyCert))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
};
|
||||
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse()?)?;
|
||||
endpoint.set_default_client_config(client_config);
|
||||
|
||||
info!("Connecting to QUIC server at {}", addr);
|
||||
let connection = endpoint.connect(remote, "smartvpn")?.await?;
|
||||
info!("QUIC connection established");
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QUIC Transport Sink / Stream implementations
|
||||
// ============================================================================
|
||||
|
||||
/// QUIC transport sink — wraps a SendStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportSink {
|
||||
send_stream: quinn::SendStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportSink {
|
||||
pub fn new(send_stream: quinn::SendStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
send_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for QuicTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// Length-prefix framing: [4-byte big-endian length][payload]
|
||||
let len = data.len() as u32;
|
||||
self.send_stream.write_all(&len.to_be_bytes()).await?;
|
||||
self.send_stream.write_all(&data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
let max_size = self.connection.max_datagram_size();
|
||||
match max_size {
|
||||
Some(max) if data.len() <= max => {
|
||||
self.connection.send_datagram(data.into())?;
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
// Datagram too large or datagrams disabled — fall back to reliable
|
||||
debug!("Datagram too large ({}B), falling back to reliable stream", data.len());
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.send_stream.finish()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// QUIC transport stream — wraps a RecvStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportStream {
|
||||
recv_stream: quinn::RecvStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportStream {
|
||||
pub fn new(recv_stream: quinn::RecvStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
recv_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for QuicTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// Read length prefix
|
||||
let mut len_buf = [0u8; 4];
|
||||
match self.recv_stream.read_exact(&mut len_buf).await {
|
||||
Ok(()) => {}
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None),
|
||||
Err(quinn::ReadExactError::ReadError(quinn::ReadError::ConnectionLost(e))) => {
|
||||
warn!("QUIC connection lost: {}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
|
||||
let len = u32::from_be_bytes(len_buf) as usize;
|
||||
if len > 65536 {
|
||||
return Err(anyhow::anyhow!("Frame too large: {} bytes", len));
|
||||
}
|
||||
|
||||
let mut data = vec![0u8; len];
|
||||
match self.recv_stream.read_exact(&mut data).await {
|
||||
Ok(()) => Ok(Some(data)),
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
match self.connection.read_datagram().await {
|
||||
Ok(data) => Ok(Some(data.to_vec())),
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => Ok(None),
|
||||
Err(quinn::ConnectionError::LocallyClosed) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC datagram error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
self.connection.max_datagram_size().is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Accept a QUIC connection and open a bidirectional control stream.
|
||||
/// Returns the transport sink/stream pair ready for the VPN handshake.
|
||||
pub async fn accept_quic_connection(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
// The client opens the bidirectional control stream
|
||||
let (send, recv) = conn.accept_bi().await?;
|
||||
info!("QUIC bidirectional control stream accepted");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
/// Open a QUIC connection's bidirectional control stream (client side).
|
||||
pub async fn open_quic_streams(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
let (send, recv) = conn.open_bi().await?;
|
||||
info!("QUIC bidirectional control stream opened");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cert_generation_and_hash() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
assert_eq!(certs.len(), 1);
|
||||
let hash = cert_hash(&certs[0]);
|
||||
// SHA-256 base64 is 44 characters
|
||||
assert_eq!(hash.len(), 44);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cert_hash_deterministic() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
let hash1 = cert_hash(&certs[0]);
|
||||
let hash2 = cert_hash(&certs[0]);
|
||||
assert_eq!(hash1, hash2);
|
||||
}
|
||||
|
||||
/// Helper: create QUIC server and client endpoints.
|
||||
fn create_quic_endpoints() -> (quinn::Endpoint, quinn::Endpoint, String) {
|
||||
let (certs, key) = generate_self_signed_cert().unwrap();
|
||||
let hash = cert_hash(&certs[0]);
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
|
||||
let mut server_tls = rustls::ServerConfig::builder_with_provider(provider.clone())
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key).unwrap();
|
||||
server_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let server_qcfg = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(server_tls).unwrap(),
|
||||
));
|
||||
let server_ep = quinn::Endpoint::server(server_qcfg, "127.0.0.1:0".parse().unwrap()).unwrap();
|
||||
|
||||
let mut client_tls = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash,
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
client_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(client_tls).unwrap(),
|
||||
));
|
||||
let mut client_ep = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
|
||||
client_ep.set_default_client_config(client_config);
|
||||
|
||||
let server_addr = server_ep.local_addr().unwrap().to_string();
|
||||
(server_ep, client_ep, server_addr)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_server_client_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi, read, echo, finish
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (mut s_send, mut s_recv) = conn.accept_bi().await.unwrap();
|
||||
let data = s_recv.read_to_end(1024).await.unwrap();
|
||||
s_send.write_all(&data).await.unwrap();
|
||||
s_send.finish().unwrap();
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open_bi, write, finish, read
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_send, mut c_recv) = conn.open_bi().await.unwrap();
|
||||
c_send.write_all(b"hello quinn").await.unwrap();
|
||||
c_send.finish().unwrap();
|
||||
let data = c_recv.read_to_end(1024).await.unwrap();
|
||||
assert_eq!(&data[..], b"hello quinn");
|
||||
|
||||
let _ = server_task.await;
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test transport trait wrappers over QUIC.
|
||||
/// Key: client must send data first (QUIC streams are opened implicitly by data).
|
||||
/// The server accept_bi runs concurrently with the client's first send_reliable.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_transport_trait_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server task: accept connection, then accept_bi (blocks until client sends data)
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
(s_sink, s_stream, server_ep)
|
||||
});
|
||||
|
||||
// Client: connect, open_bi via wrapper
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, mut c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
|
||||
// Client sends first — this triggers the QUIC stream to become visible to the server
|
||||
c_sink.send_reliable(b"hello-from-client".to_vec()).await.unwrap();
|
||||
|
||||
// Now server's accept_bi unblocks
|
||||
let (mut s_sink, mut s_stream, _sep) = server_task.await.unwrap();
|
||||
|
||||
// Server reads the message
|
||||
let msg = s_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-client");
|
||||
|
||||
// Server -> Client
|
||||
s_sink.send_reliable(b"hello-from-server".to_vec()).await.unwrap();
|
||||
let msg = c_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-server");
|
||||
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test QUIC datagram support.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_datagram_exchange() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi (opens control stream), then read datagram
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
// Accept bi stream (control channel)
|
||||
let (_s_sink, _s_stream) = accept_quic_connection(conn.clone()).await.unwrap();
|
||||
// Read datagram
|
||||
let dgram = conn.read_datagram().await.unwrap();
|
||||
assert_eq!(&dgram[..], b"dgram-payload");
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open bi stream (triggers server accept_bi), then send datagram
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, _c_stream) = open_quic_streams(conn.clone()).await.unwrap();
|
||||
|
||||
// Send initial data to open the stream (required for QUIC)
|
||||
c_sink.send_reliable(b"init".to_vec()).await.unwrap();
|
||||
|
||||
// Small yield to let the server process the bi stream
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
// Send datagram
|
||||
assert!(conn.max_datagram_size().is_some());
|
||||
conn.send_datagram(bytes::Bytes::from_static(b"dgram-payload")).unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test that supports_datagrams returns true for QUIC transports.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_supports_datagrams() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (_s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
assert!(s_stream.supports_datagrams());
|
||||
server_ep
|
||||
});
|
||||
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
assert!(c_stream.supports_datagrams());
|
||||
|
||||
// Send data to trigger server's accept_bi
|
||||
c_sink.send_reliable(b"ping".to_vec()).await.unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
}
|
||||
@@ -130,10 +130,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn tokens_do_not_exceed_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000);
|
||||
// Use a low rate so refill between consecutive calls is negligible
|
||||
let mut tb = TokenBucket::new(100, 1_000);
|
||||
// Wait to accumulate — but should cap at burst
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
assert!(tb.try_consume(1_000));
|
||||
// At 100 bytes/sec, the few μs between calls add ~0 tokens
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv4Addr;
|
||||
@@ -8,15 +7,18 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
|
||||
use crate::acl;
|
||||
use crate::client_registry::{ClientEntry, ClientRegistry};
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::mtu::{MtuConfig, TunnelOverhead};
|
||||
use crate::network::IpPool;
|
||||
use crate::ratelimit::TokenBucket;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
|
||||
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
|
||||
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
|
||||
@@ -39,6 +41,14 @@ pub struct ServerConfig {
|
||||
pub default_rate_limit_bytes_per_sec: Option<u64>,
|
||||
/// Default burst size for new clients (bytes). None = unlimited.
|
||||
pub default_burst_bytes: Option<u64>,
|
||||
/// Transport mode: "websocket" (default), "quic", or "both".
|
||||
pub transport_mode: Option<String>,
|
||||
/// QUIC listen address (host:port). Defaults to listen_addr.
|
||||
pub quic_listen_addr: Option<String>,
|
||||
/// QUIC idle timeout in seconds (default: 30).
|
||||
pub quic_idle_timeout_secs: Option<u64>,
|
||||
/// Pre-registered clients for IK authentication.
|
||||
pub clients: Option<Vec<ClientEntry>>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -56,6 +66,10 @@ pub struct ClientInfo {
|
||||
pub keepalives_received: u64,
|
||||
pub rate_limit_bytes_per_sec: Option<u64>,
|
||||
pub burst_bytes: Option<u64>,
|
||||
/// Client's authenticated Noise IK public key (base64).
|
||||
pub authenticated_key: String,
|
||||
/// Registered client ID from the client registry.
|
||||
pub registered_client_id: String,
|
||||
}
|
||||
|
||||
/// Server statistics.
|
||||
@@ -82,6 +96,7 @@ pub struct ServerState {
|
||||
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
|
||||
pub mtu_config: MtuConfig,
|
||||
pub started_at: std::time::Instant,
|
||||
pub client_registry: RwLock<ClientRegistry>,
|
||||
}
|
||||
|
||||
/// The VPN server.
|
||||
@@ -121,6 +136,12 @@ impl VpnServer {
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
|
||||
|
||||
// Build client registry from config
|
||||
let registry = ClientRegistry::from_entries(
|
||||
config.clients.clone().unwrap_or_default()
|
||||
)?;
|
||||
info!("Client registry loaded with {} entries", registry.len());
|
||||
|
||||
let state = Arc::new(ServerState {
|
||||
config: config.clone(),
|
||||
ip_pool: Mutex::new(ip_pool),
|
||||
@@ -129,20 +150,65 @@ impl VpnServer {
|
||||
rate_limiters: Mutex::new(HashMap::new()),
|
||||
mtu_config,
|
||||
started_at: std::time::Instant::now(),
|
||||
client_registry: RwLock::new(registry),
|
||||
});
|
||||
|
||||
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
||||
self.state = Some(state.clone());
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
|
||||
let listen_addr = config.listen_addr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
info!("VPN server started");
|
||||
match transport_mode {
|
||||
"quic" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
"both" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
let state2 = state.clone();
|
||||
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
|
||||
// Store second shutdown sender so both listeners stop
|
||||
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
|
||||
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
|
||||
self.shutdown_tx = Some(combined_tx);
|
||||
|
||||
// Forward combined shutdown to both listeners
|
||||
tokio::spawn(async move {
|
||||
combined_rx.recv().await;
|
||||
let _ = shutdown_tx_orig.send(()).await;
|
||||
let _ = shutdown_tx2.send(()).await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("WebSocket listener error: {}", e);
|
||||
}
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
// "websocket" (default)
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!("VPN server started (transport: {})", transport_mode);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -237,9 +303,268 @@ impl VpnServer {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ───────────────────────────────────
|
||||
|
||||
/// Create a new client entry. Generates keypairs and assigns an IP.
|
||||
/// Returns a JSON value with the full config bundle including secrets.
|
||||
pub async fn create_client(&self, partial: serde_json::Value) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
|
||||
let client_id = partial.get("clientId")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("clientId is required"))?
|
||||
.to_string();
|
||||
|
||||
// Generate Noise IK keypair for the client
|
||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||
|
||||
// Generate WireGuard keypair for the client
|
||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||
|
||||
// Allocate a VPN IP
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
// Build entry from partial + generated values
|
||||
let entry = ClientEntry {
|
||||
client_id: client_id.clone(),
|
||||
public_key: noise_pub.clone(),
|
||||
wg_public_key: Some(wg_pub.clone()),
|
||||
security: serde_json::from_value(
|
||||
partial.get("security").cloned().unwrap_or(serde_json::Value::Null)
|
||||
).ok(),
|
||||
priority: partial.get("priority").and_then(|v| v.as_u64()).map(|v| v as u32),
|
||||
enabled: partial.get("enabled").and_then(|v| v.as_bool()).or(Some(true)),
|
||||
tags: partial.get("tags").and_then(|v| {
|
||||
v.as_array().map(|a| a.iter().filter_map(|s| s.as_str().map(String::from)).collect())
|
||||
}),
|
||||
description: partial.get("description").and_then(|v| v.as_str()).map(String::from),
|
||||
expires_at: partial.get("expiresAt").and_then(|v| v.as_str()).map(String::from),
|
||||
assigned_ip: Some(assigned_ip.to_string()),
|
||||
};
|
||||
|
||||
// Add to registry
|
||||
state.client_registry.write().await.add(entry.clone())?;
|
||||
|
||||
// Build SmartVPN client config
|
||||
let smartvpn_config = serde_json::json!({
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPrivateKey": noise_priv,
|
||||
"clientPublicKey": noise_pub,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
});
|
||||
|
||||
// Build WireGuard config string
|
||||
let wg_config = format!(
|
||||
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
wg_priv,
|
||||
assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
|
||||
let entry_json = serde_json::to_value(&entry)?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"entry": entry_json,
|
||||
"smartvpnConfig": smartvpn_config,
|
||||
"wireguardConfig": wg_config,
|
||||
"secrets": {
|
||||
"noisePrivateKey": noise_priv,
|
||||
"wgPrivateKey": wg_priv,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Remove a registered client from the registry (and disconnect if connected).
|
||||
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let entry = state.client_registry.write().await.remove(client_id)?;
|
||||
// Release the IP if assigned
|
||||
if let Some(ref ip_str) = entry.assigned_ip {
|
||||
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
|
||||
state.ip_pool.lock().await.release(&ip);
|
||||
}
|
||||
}
|
||||
// Disconnect if currently connected
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a registered client by ID.
|
||||
pub async fn get_registered_client(&self, client_id: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let registry = state.client_registry.read().await;
|
||||
let entry = registry.get_by_id(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
Ok(serde_json::to_value(entry)?)
|
||||
}
|
||||
|
||||
/// List all registered clients.
|
||||
pub async fn list_registered_clients(&self) -> Vec<ClientEntry> {
|
||||
if let Some(ref state) = self.state {
|
||||
state.client_registry.read().await.list().into_iter().cloned().collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a registered client's fields.
|
||||
pub async fn update_registered_client(&self, client_id: &str, update: serde_json::Value) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
if let Some(security) = update.get("security") {
|
||||
entry.security = serde_json::from_value(security.clone()).ok();
|
||||
}
|
||||
if let Some(priority) = update.get("priority").and_then(|v| v.as_u64()) {
|
||||
entry.priority = Some(priority as u32);
|
||||
}
|
||||
if let Some(enabled) = update.get("enabled").and_then(|v| v.as_bool()) {
|
||||
entry.enabled = Some(enabled);
|
||||
}
|
||||
if let Some(tags) = update.get("tags").and_then(|v| v.as_array()) {
|
||||
entry.tags = Some(tags.iter().filter_map(|s| s.as_str().map(String::from)).collect());
|
||||
}
|
||||
if let Some(desc) = update.get("description").and_then(|v| v.as_str()) {
|
||||
entry.description = Some(desc.to_string());
|
||||
}
|
||||
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
|
||||
entry.expires_at = Some(expires.to_string());
|
||||
}
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enable a registered client.
|
||||
pub async fn enable_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
entry.enabled = Some(true);
|
||||
})
|
||||
}
|
||||
|
||||
/// Disable a registered client (also disconnects if connected).
|
||||
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
entry.enabled = Some(false);
|
||||
})?;
|
||||
// Disconnect if currently connected
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
pub async fn rotate_client_key(&self, client_id: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
|
||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||
|
||||
state.client_registry.write().await.rotate_key(
|
||||
client_id,
|
||||
noise_pub.clone(),
|
||||
Some(wg_pub.clone()),
|
||||
)?;
|
||||
|
||||
// Disconnect existing connection (old key is no longer valid)
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
|
||||
// Get updated entry for the config bundle
|
||||
let entry_json = self.get_registered_client(client_id).await?;
|
||||
let assigned_ip = entry_json.get("assignedIp")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("0.0.0.0");
|
||||
|
||||
let smartvpn_config = serde_json::json!({
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPrivateKey": noise_priv,
|
||||
"clientPublicKey": noise_pub,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
});
|
||||
|
||||
let wg_config = format!(
|
||||
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
wg_priv, assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"entry": entry_json,
|
||||
"smartvpnConfig": smartvpn_config,
|
||||
"wireguardConfig": wg_config,
|
||||
"secrets": {
|
||||
"noisePrivateKey": noise_priv,
|
||||
"wgPrivateKey": wg_priv,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Export a client config (without secrets) in the specified format.
|
||||
pub async fn export_client_config(&self, client_id: &str, format: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let registry = state.client_registry.read().await;
|
||||
let entry = registry.get_by_id(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
|
||||
match format {
|
||||
"smartvpn" => {
|
||||
Ok(serde_json::json!({
|
||||
"config": {
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPublicKey": entry.public_key,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
}
|
||||
}))
|
||||
}
|
||||
"wireguard" => {
|
||||
let assigned_ip = entry.assigned_ip.as_deref().unwrap_or("0.0.0.0");
|
||||
let config = format!(
|
||||
"[Interface]\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
Ok(serde_json::json!({ "config": config }))
|
||||
}
|
||||
_ => anyhow::bail!("Unknown format: {}", format),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_listener(
|
||||
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
|
||||
/// to the transport-agnostic `handle_client_connection`.
|
||||
async fn run_ws_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
shutdown_rx: &mut mpsc::Receiver<()>,
|
||||
@@ -255,8 +580,20 @@ async fn run_listener(
|
||||
info!("New connection from {}", addr);
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_client_connection(state, stream).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
match transport::accept_connection(stream).await {
|
||||
Ok(ws) => {
|
||||
let (sink, stream) = transport_trait::split_ws(ws);
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("WebSocket upgrade failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -275,29 +612,108 @@ async fn run_listener(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic
|
||||
/// `handle_client_connection`.
|
||||
async fn run_quic_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
idle_timeout_secs: u64,
|
||||
shutdown_rx: &mut mpsc::Receiver<()>,
|
||||
) -> Result<()> {
|
||||
// Generate or use configured TLS certificate for QUIC
|
||||
let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) =
|
||||
(&state.config.tls_cert, &state.config.tls_key)
|
||||
{
|
||||
// Parse PEM certificates
|
||||
let certs: Vec<rustls_pki_types::CertificateDer<'static>> =
|
||||
rustls_pemfile::certs(&mut cert_pem.as_bytes())
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
|
||||
.ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
|
||||
(certs, key)
|
||||
} else {
|
||||
// Generate self-signed certificate
|
||||
let (certs, key) = quic_transport::generate_self_signed_cert()?;
|
||||
info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0]));
|
||||
(certs, key)
|
||||
};
|
||||
|
||||
let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig {
|
||||
listen_addr,
|
||||
cert_chain,
|
||||
private_key,
|
||||
idle_timeout_secs,
|
||||
})?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
incoming = endpoint.accept() => {
|
||||
match incoming {
|
||||
Some(incoming) => {
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
match incoming.await {
|
||||
Ok(conn) => {
|
||||
let remote = conn.remote_address();
|
||||
info!("New QUIC connection from {}", remote);
|
||||
match quic_transport::accept_quic_connection(conn).await {
|
||||
Ok((sink, stream)) => {
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
).await {
|
||||
warn!("QUIC client error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("QUIC stream accept failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("QUIC handshake failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
None => {
|
||||
info!("QUIC endpoint closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("QUIC shutdown signal received");
|
||||
endpoint.close(0u32.into(), b"shutdown");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transport-agnostic client handler. Performs the Noise IK handshake, authenticates
|
||||
/// the client against the registry, and runs the main packet forwarding loop.
|
||||
async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
stream: tokio::net::TcpStream,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
) -> Result<()> {
|
||||
let ws = transport::accept_connection(stream).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
|
||||
let client_id = uuid_v4();
|
||||
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
let server_private_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&state.config.private_key,
|
||||
)?;
|
||||
|
||||
// Noise IK handshake (server side = responder)
|
||||
let mut responder = crypto::create_responder(&server_private_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// Receive handshake init
|
||||
let init_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected handshake init message"),
|
||||
// Receive handshake init (-> e, es, s, ss)
|
||||
let init_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before handshake"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&init_msg[..]);
|
||||
@@ -309,6 +725,47 @@ async fn handle_client_connection(
|
||||
}
|
||||
|
||||
responder.read_message(&frame.payload, &mut buf)?;
|
||||
|
||||
// Extract client's static public key BEFORE entering transport mode
|
||||
let client_pub_key_bytes = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake: no client static key received"))?
|
||||
.to_vec();
|
||||
let client_pub_key_b64 = base64::Engine::encode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&client_pub_key_bytes,
|
||||
);
|
||||
|
||||
// Verify client against registry
|
||||
let (registered_client_id, client_security) = {
|
||||
let registry = state.client_registry.read().await;
|
||||
if !registry.is_authorized(&client_pub_key_b64) {
|
||||
warn!("Rejecting unauthorized client with key {}", &client_pub_key_b64[..8]);
|
||||
// Send handshake response but then disconnect
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let response_frame = Frame {
|
||||
packet_type: PacketType::HandshakeResp,
|
||||
payload: buf[..len].to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// Send disconnect frame
|
||||
let disconnect_frame = Frame {
|
||||
packet_type: PacketType::Disconnect,
|
||||
payload: Vec::new(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
anyhow::bail!("Client not authorized");
|
||||
}
|
||||
let entry = registry.get_by_key(&client_pub_key_b64).unwrap();
|
||||
(entry.client_id.clone(), entry.security.clone())
|
||||
};
|
||||
|
||||
// Complete handshake (<- e, ee, se)
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let response_payload = buf[..len].to_vec();
|
||||
|
||||
@@ -318,13 +775,28 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
let mut noise_transport = responder.into_transport_mode()?;
|
||||
|
||||
// Register client
|
||||
let default_rate = state.config.default_rate_limit_bytes_per_sec;
|
||||
let default_burst = state.config.default_burst_bytes;
|
||||
// Use the registered client ID as the connection ID
|
||||
let client_id = registered_client_id.clone();
|
||||
|
||||
// Allocate IP
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
// Determine rate limits: per-client security overrides server defaults
|
||||
let (rate_limit, burst) = if let Some(ref sec) = client_security {
|
||||
if let Some(ref rl) = sec.rate_limit {
|
||||
(Some(rl.bytes_per_sec), Some(rl.burst_bytes))
|
||||
} else {
|
||||
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
|
||||
}
|
||||
} else {
|
||||
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
|
||||
};
|
||||
|
||||
// Register connected client
|
||||
let client_info = ClientInfo {
|
||||
client_id: client_id.clone(),
|
||||
assigned_ip: assigned_ip.to_string(),
|
||||
@@ -335,13 +807,15 @@ async fn handle_client_connection(
|
||||
bytes_dropped: 0,
|
||||
last_keepalive_at: None,
|
||||
keepalives_received: 0,
|
||||
rate_limit_bytes_per_sec: default_rate,
|
||||
burst_bytes: default_burst,
|
||||
rate_limit_bytes_per_sec: rate_limit,
|
||||
burst_bytes: burst,
|
||||
authenticated_key: client_pub_key_b64.clone(),
|
||||
registered_client_id: registered_client_id.clone(),
|
||||
};
|
||||
state.clients.write().await.insert(client_id.clone(), client_info);
|
||||
|
||||
// Set up rate limiter if defaults are configured
|
||||
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
|
||||
// Set up rate limiter
|
||||
if let (Some(rate), Some(burst)) = (rate_limit, burst) {
|
||||
state
|
||||
.rate_limiters
|
||||
.lock()
|
||||
@@ -369,25 +843,43 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
info!("Client {} connected with IP {}", client_id, assigned_ip);
|
||||
info!("Client {} ({}) connected with IP {}", registered_client_id, &client_pub_key_b64[..8], assigned_ip);
|
||||
|
||||
// Main packet loop with dead-peer detection
|
||||
let mut last_activity = tokio::time::Instant::now();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
Ok(Some(data)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
// ACL check on decrypted packet
|
||||
if let Some(ref sec) = client_security {
|
||||
if len >= 20 {
|
||||
// Extract src/dst from IPv4 header
|
||||
let src = Ipv4Addr::new(buf[12], buf[13], buf[14], buf[15]);
|
||||
let dst = Ipv4Addr::new(buf[16], buf[17], buf[18], buf[19]);
|
||||
let acl_result = acl::check_acl(sec, src, dst);
|
||||
if acl_result != acl::AclResult::Allow {
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.packets_dropped += 1;
|
||||
info.bytes_dropped += len as u64;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting check
|
||||
let allowed = {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
@@ -432,7 +924,7 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
@@ -463,20 +955,12 @@ async fn handle_client_connection(
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
Ok(None) => {
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
ws_sink.send(Message::Pong(data)).await?;
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
continue;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("WebSocket error from {}: {}", client_id, e);
|
||||
Err(e) => {
|
||||
warn!("Transport error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -497,20 +981,6 @@ async fn handle_client_connection(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn uuid_v4() -> String {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let bytes: [u8; 16] = rng.gen();
|
||||
format!(
|
||||
"{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
|
||||
bytes[0], bytes[1], bytes[2], bytes[3],
|
||||
bytes[4], bytes[5],
|
||||
bytes[6], bytes[7],
|
||||
bytes[8], bytes[9],
|
||||
bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
|
||||
)
|
||||
}
|
||||
|
||||
fn timestamp_now() -> String {
|
||||
use std::time::SystemTime;
|
||||
let duration = SystemTime::now()
|
||||
|
||||
116
rust/src/transport_trait.rs
Normal file
116
rust/src/transport_trait.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
use crate::transport::WsStream;
|
||||
|
||||
// ============================================================================
|
||||
// Transport trait abstraction
|
||||
// ============================================================================
|
||||
|
||||
/// Outbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportSink: Send + 'static {
|
||||
/// Send a framed binary message on the reliable channel.
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Send a datagram (unreliable, best-effort).
|
||||
/// Falls back to reliable if the transport does not support datagrams.
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Gracefully close the transport.
|
||||
async fn close(&mut self) -> Result<()>;
|
||||
}
|
||||
|
||||
/// Inbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportStream: Send + 'static {
|
||||
/// Receive the next reliable binary message. Returns `None` on close.
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Receive the next datagram. Returns `None` if datagrams are unsupported
|
||||
/// or the connection is closed.
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Whether this transport supports unreliable datagrams.
|
||||
fn supports_datagrams(&self) -> bool;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket implementation
|
||||
// ============================================================================
|
||||
|
||||
/// WebSocket transport sink (wraps the write half of a split WsStream).
|
||||
pub struct WsTransportSink {
|
||||
inner: futures_util::stream::SplitSink<WsStream, Message>,
|
||||
}
|
||||
|
||||
impl WsTransportSink {
|
||||
pub fn new(inner: futures_util::stream::SplitSink<WsStream, Message>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for WsTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
self.inner.send(Message::Binary(data.into())).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// WebSocket has no datagram support — fall back to reliable.
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.inner.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket transport stream (wraps the read half of a split WsStream).
|
||||
pub struct WsTransportStream {
|
||||
inner: futures_util::stream::SplitStream<WsStream>,
|
||||
}
|
||||
|
||||
impl WsTransportStream {
|
||||
pub fn new(inner: futures_util::stream::SplitStream<WsStream>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for WsTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
loop {
|
||||
match self.inner.next().await {
|
||||
Some(Ok(Message::Binary(data))) => return Ok(Some(data.to_vec())),
|
||||
Some(Ok(Message::Close(_))) | None => return Ok(None),
|
||||
Some(Ok(Message::Ping(_))) => {
|
||||
// Ping handling is done at the tungstenite layer automatically
|
||||
// when the sink side is alive. Just skip here.
|
||||
continue;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => return Err(anyhow::anyhow!("WebSocket error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// WebSocket does not support datagrams.
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a WebSocket stream into transport sink and stream halves.
|
||||
pub fn split_ws(ws: WsStream) -> (WsTransportSink, WsTransportStream) {
|
||||
let (sink, stream) = ws.split();
|
||||
(WsTransportSink::new(sink), WsTransportStream::new(stream))
|
||||
}
|
||||
1329
rust/src/wireguard.rs
Normal file
1329
rust/src/wireguard.rs
Normal file
File diff suppressed because it is too large
Load Diff
320
rust/tests/wg_e2e.rs
Normal file
320
rust/tests/wg_e2e.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! End-to-end WireGuard protocol tests over real UDP sockets.
|
||||
//!
|
||||
//! Entirely userspace — no root, no TUN devices.
|
||||
//! Two boringtun `Tunn` instances exchange real WireGuard packets
|
||||
//! over loopback UDP, validating handshake, encryption, and data flow.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::time;
|
||||
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
|
||||
use smartvpn_daemon::wireguard::generate_wg_keypair;
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
fn parse_key_pair(pub_b64: &str, priv_b64: &str) -> (PublicKey, StaticSecret) {
|
||||
let pub_bytes: [u8; 32] = BASE64.decode(pub_b64).unwrap().try_into().unwrap();
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
(PublicKey::from(pub_bytes), StaticSecret::from(priv_bytes))
|
||||
}
|
||||
|
||||
fn clone_secret(priv_b64: &str) -> StaticSecret {
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
StaticSecret::from(priv_bytes)
|
||||
}
|
||||
|
||||
fn make_ipv4_packet(src: Ipv4Addr, dst: Ipv4Addr, payload: &[u8]) -> Vec<u8> {
|
||||
let total_len = 20 + payload.len();
|
||||
let mut pkt = vec![0u8; total_len];
|
||||
pkt[0] = 0x45;
|
||||
pkt[2] = (total_len >> 8) as u8;
|
||||
pkt[3] = total_len as u8;
|
||||
pkt[9] = 0x11;
|
||||
pkt[12..16].copy_from_slice(&src.octets());
|
||||
pkt[16..20].copy_from_slice(&dst.octets());
|
||||
pkt[20..].copy_from_slice(payload);
|
||||
pkt
|
||||
}
|
||||
|
||||
/// Send any WriteToNetwork result, then drain the tunn for more packets.
|
||||
async fn send_and_drain(
|
||||
tunn: &mut Tunn,
|
||||
pkt: &[u8],
|
||||
socket: &UdpSocket,
|
||||
peer: SocketAddr,
|
||||
) {
|
||||
socket.send_to(pkt, peer).await.unwrap();
|
||||
let mut drain_buf = vec![0u8; 2048];
|
||||
loop {
|
||||
match tunn.decapsulate(None, &[], &mut drain_buf) {
|
||||
TunnResult::WriteToNetwork(p) => { socket.send_to(p, peer).await.unwrap(); }
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to receive a UDP packet and decapsulate it. Returns decrypted IP data if any.
|
||||
async fn try_recv_decap(
|
||||
tunn: &mut Tunn,
|
||||
socket: &UdpSocket,
|
||||
timeout_ms: u64,
|
||||
) -> Option<(Vec<u8>, Ipv4Addr, SocketAddr)> {
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
let (n, src_addr) = match time::timeout(
|
||||
Duration::from_millis(timeout_ms),
|
||||
socket.recv_from(&mut recv_buf),
|
||||
).await {
|
||||
Ok(Ok(r)) => r,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let result = tunn.decapsulate(Some(src_addr.ip()), &recv_buf[..n], &mut dst_buf);
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(tunn, pkt, socket, src_addr).await;
|
||||
None
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(pkt, addr) => Some((pkt.to_vec(), addr, src_addr)),
|
||||
TunnResult::WriteToTunnelV6(_, _) => None,
|
||||
TunnResult::Done => None,
|
||||
TunnResult::Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive the full WireGuard handshake between client and server over real UDP.
|
||||
async fn do_handshake(
|
||||
client_tunn: &mut Tunn,
|
||||
server_tunn: &mut Tunn,
|
||||
client_socket: &UdpSocket,
|
||||
server_socket: &UdpSocket,
|
||||
server_addr: SocketAddr,
|
||||
) {
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
// Step 1: Client initiates handshake
|
||||
match client_tunn.encapsulate(&[], &mut buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => panic!("Expected handshake init"),
|
||||
}
|
||||
|
||||
// Step 2: Server receives init → sends response
|
||||
let (n, client_from) = server_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match server_tunn.decapsulate(Some(client_from.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(server_tunn, pkt, server_socket, client_from).await;
|
||||
}
|
||||
other => panic!("Expected WriteToNetwork from server, got variant {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Step 3: Client receives response
|
||||
let (n, _) = client_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match client_tunn.decapsulate(Some(server_addr.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(client_tunn, pkt, client_socket, server_addr).await;
|
||||
}
|
||||
TunnResult::Done => {}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Step 4: Process any remaining handshake packets
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 200).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 100).await;
|
||||
|
||||
// Step 5: Timer ticks to settle
|
||||
for _ in 0..3 {
|
||||
match server_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
server_socket.send_to(pkt, client_from).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
match client_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 50).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 50).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn variant_name(r: &TunnResult) -> &'static str {
|
||||
match r {
|
||||
TunnResult::Done => "Done",
|
||||
TunnResult::Err(_) => "Err",
|
||||
TunnResult::WriteToNetwork(_) => "WriteToNetwork",
|
||||
TunnResult::WriteToTunnelV4(_, _) => "WriteToTunnelV4",
|
||||
TunnResult::WriteToTunnelV6(_, _) => "WriteToTunnelV6",
|
||||
}
|
||||
}
|
||||
|
||||
/// Encapsulate an IP packet and send it, then loop-receive on the other side until decrypted.
|
||||
async fn send_and_expect_data(
|
||||
sender_tunn: &mut Tunn,
|
||||
receiver_tunn: &mut Tunn,
|
||||
sender_socket: &UdpSocket,
|
||||
receiver_socket: &UdpSocket,
|
||||
dest_addr: SocketAddr,
|
||||
ip_packet: &[u8],
|
||||
) -> (Vec<u8>, Ipv4Addr) {
|
||||
let mut enc_buf = vec![0u8; 65536];
|
||||
|
||||
match sender_tunn.encapsulate(ip_packet, &mut enc_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
sender_socket.send_to(pkt, dest_addr).await.unwrap();
|
||||
}
|
||||
TunnResult::Err(e) => panic!("Encapsulate failed: {:?}", e),
|
||||
other => panic!("Expected WriteToNetwork, got {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Receive — may need a few rounds for control packets
|
||||
for _ in 0..10 {
|
||||
if let Some((data, addr, _)) = try_recv_decap(receiver_tunn, receiver_socket, 1000).await {
|
||||
return (data, addr);
|
||||
}
|
||||
}
|
||||
panic!("Did not receive decrypted IP packet");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 1: Single client ↔ server bidirectional data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_single_client_bidirectional() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
let client_addr = client_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, None, None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, None, None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
// Client → Server
|
||||
let pkt_c2s = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"Hello from client!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt_c2s,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt_c2s.len()], &pkt_c2s[..]);
|
||||
|
||||
// Server → Client
|
||||
let pkt_s2c = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 1), Ipv4Addr::new(10, 0, 0, 2), b"Hello from server!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut server_tunn, &mut client_tunn,
|
||||
&server_socket, &client_socket,
|
||||
client_addr, &pkt_s2c,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 1));
|
||||
assert_eq!(&decrypted[..pkt_s2c.len()], &pkt_s2c[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 2: Two clients ↔ one server (peer routing)
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_two_clients_peer_routing() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client1_pub_b64, client1_priv_b64) = generate_wg_keypair();
|
||||
let (client2_pub_b64, client2_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, _) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client1_public, client1_secret) = parse_key_pair(&client1_pub_b64, &client1_priv_b64);
|
||||
let (client2_public, client2_secret) = parse_key_pair(&client2_pub_b64, &client2_priv_b64);
|
||||
|
||||
// Separate server socket per peer to avoid UDP mux complexity in test
|
||||
let server_socket_1 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_socket_2 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client1_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client2_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr_1 = server_socket_1.local_addr().unwrap();
|
||||
let server_addr_2 = server_socket_2.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn_1 = Tunn::new(clone_secret(&server_priv_b64), client1_public, None, None, 0, None);
|
||||
let mut server_tunn_2 = Tunn::new(clone_secret(&server_priv_b64), client2_public, None, None, 1, None);
|
||||
let mut client1_tunn = Tunn::new(client1_secret, server_public.clone(), None, None, 2, None);
|
||||
let mut client2_tunn = Tunn::new(client2_secret, server_public, None, None, 3, None);
|
||||
|
||||
do_handshake(&mut client1_tunn, &mut server_tunn_1, &client1_socket, &server_socket_1, server_addr_1).await;
|
||||
do_handshake(&mut client2_tunn, &mut server_tunn_2, &client2_socket, &server_socket_2, server_addr_2).await;
|
||||
|
||||
// Client 1 → Server
|
||||
let pkt1 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"From client 1");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client1_tunn, &mut server_tunn_1,
|
||||
&client1_socket, &server_socket_1,
|
||||
server_addr_1, &pkt1,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt1.len()], &pkt1[..]);
|
||||
|
||||
// Client 2 → Server
|
||||
let pkt2 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 3), Ipv4Addr::new(10, 0, 0, 1), b"From client 2");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client2_tunn, &mut server_tunn_2,
|
||||
&client2_socket, &server_socket_2,
|
||||
server_addr_2, &pkt2,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 3));
|
||||
assert_eq!(&decrypted[..pkt2.len()], &pkt2[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 3: Preshared key handshake + data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_preshared_key() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let psk: [u8; 32] = rand::random();
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, Some(psk), None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, Some(psk), None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
let pkt = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"PSK-protected data");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt.len()], &pkt[..]);
|
||||
}
|
||||
284
test/test.flowcontrol.node.ts
Normal file
284
test/test.flowcontrol.node.ts
Normal file
@@ -0,0 +1,284 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let client: VpnClient;
|
||||
let clientBundle: IClientConfigBundle;
|
||||
const extraClients: VpnClient[] = [];
|
||||
const extraBundles: IClientConfigBundle[] = [];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start VPN server', async () => {
|
||||
serverPort = await findFreePort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
|
||||
// Phase 1: start the daemon bridge
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
expect(server.running).toBeTrue();
|
||||
|
||||
// Phase 2: generate a keypair
|
||||
keypair = await server.generateKeypair();
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
|
||||
// Phase 3: start the VPN listener (empty clients, will use createClient at runtime)
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
// Verify server is now running
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Phase 4: create the first client via the hub
|
||||
clientBundle = await server.createClient({ clientId: 'test-client-0' });
|
||||
expect(clientBundle.secrets.noisePrivateKey).toBeTypeofString();
|
||||
expect(clientBundle.smartvpnConfig.clientPublicKey).toBeTypeofString();
|
||||
});
|
||||
|
||||
tap.test('single client connects and gets IP', async () => {
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: clientBundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: clientBundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
// Verify client status
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
// Verify server sees the client
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 1;
|
||||
});
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(1);
|
||||
expect(clients[0].assignedIp).toEqual(result.assignedIp);
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange', async () => {
|
||||
// Wait for at least 2 keepalive cycles (interval=3s, so 8s should be enough)
|
||||
await delay(8000);
|
||||
|
||||
const clientStats = await client.getStatistics();
|
||||
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
const serverStats = await server.getStatistics();
|
||||
expect(serverStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
expect(serverStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Verify per-client keepalive tracking
|
||||
const clients = await server.listClients();
|
||||
expect(clients[0].keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('connection quality telemetry', async () => {
|
||||
const quality = await client.getConnectionQuality();
|
||||
|
||||
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.jitterMs).toBeTypeofNumber();
|
||||
expect(quality.minRttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.maxRttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.lossRatio).toBeTypeofNumber();
|
||||
expect(['healthy', 'degraded', 'critical']).toContain(quality.linkHealth);
|
||||
});
|
||||
|
||||
tap.test('rate limiting: set and verify', async () => {
|
||||
const clients = await server.listClients();
|
||||
const clientId = clients[0].clientId;
|
||||
|
||||
// Set a tight rate limit
|
||||
await server.setClientRateLimit(clientId, 100, 100);
|
||||
|
||||
// Verify via telemetry
|
||||
const telemetry = await server.getClientTelemetry(clientId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toEqual(100);
|
||||
expect(telemetry.burstBytes).toEqual(100);
|
||||
expect(telemetry.clientId).toEqual(clientId);
|
||||
});
|
||||
|
||||
tap.test('rate limiting: removal', async () => {
|
||||
const clients = await server.listClients();
|
||||
const clientId = clients[0].clientId;
|
||||
|
||||
await server.removeClientRateLimit(clientId);
|
||||
|
||||
// Verify telemetry no longer shows rate limit
|
||||
const telemetry = await server.getClientTelemetry(clientId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
|
||||
expect(telemetry.burstBytes).toBeNullOrUndefined();
|
||||
|
||||
// Connection still healthy
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('5 concurrent clients', async () => {
|
||||
const assignedIps = new Set<string>();
|
||||
|
||||
// Get the first client's IP
|
||||
const existingClients = await server.listClients();
|
||||
assignedIps.add(existingClients[0].assignedIp);
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const bundle = await server.createClient({ clientId: `test-client-${i + 1}` });
|
||||
extraBundles.push(bundle);
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
const result = await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
assignedIps.add(result.assignedIp);
|
||||
extraClients.push(c);
|
||||
}
|
||||
|
||||
// All IPs should be unique (6 total: original + 5 new)
|
||||
expect(assignedIps.size).toEqual(6);
|
||||
|
||||
// Server should see 6 clients
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 6;
|
||||
});
|
||||
const allClients = await server.listClients();
|
||||
expect(allClients.length).toEqual(6);
|
||||
});
|
||||
|
||||
tap.test('client disconnect tracking', async () => {
|
||||
// Disconnect 3 of the 5 extra clients
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const c = extraClients[i];
|
||||
await c.disconnect();
|
||||
c.stop();
|
||||
}
|
||||
|
||||
// Wait for server to detect disconnections
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 3;
|
||||
}, 15000);
|
||||
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(3);
|
||||
|
||||
const stats = await server.getStatistics();
|
||||
expect(stats.totalConnections).toBeGreaterThanOrEqual(6);
|
||||
});
|
||||
|
||||
tap.test('server-side client disconnection', async () => {
|
||||
const clients = await server.listClients();
|
||||
// Pick one of the remaining extra clients (not the original)
|
||||
const targetClient = clients.find((c) => {
|
||||
// Find a client that belongs to extraClients[3] or extraClients[4]
|
||||
return c.clientId !== clients[0].clientId;
|
||||
});
|
||||
expect(targetClient).toBeTruthy();
|
||||
|
||||
await server.disconnectClient(targetClient!.clientId);
|
||||
|
||||
// Wait for server to update
|
||||
await waitFor(async () => {
|
||||
const remaining = await server.listClients();
|
||||
return remaining.length === 2;
|
||||
});
|
||||
|
||||
const remaining = await server.listClients();
|
||||
expect(remaining.length).toEqual(2);
|
||||
});
|
||||
|
||||
tap.test('teardown: stop all', async () => {
|
||||
// Stop the original client
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
|
||||
// Stop remaining extra clients
|
||||
for (const c of extraClients) {
|
||||
if (c.running) {
|
||||
try {
|
||||
await c.disconnect();
|
||||
} catch {
|
||||
// May already be disconnected
|
||||
}
|
||||
c.stop();
|
||||
}
|
||||
}
|
||||
|
||||
await delay(500);
|
||||
|
||||
// Stop the server
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
await delay(500);
|
||||
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
362
test/test.loadtest.node.ts
Normal file
362
test/test.loadtest.node.ts
Normal file
@@ -0,0 +1,362 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as stream from 'stream';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ThrottleProxy (adapted from remoteingress)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class ThrottleTransform extends stream.Transform {
|
||||
private bytesPerSec: number;
|
||||
private bucket: number;
|
||||
private lastRefill: number;
|
||||
private destroyed_: boolean = false;
|
||||
|
||||
constructor(bytesPerSecond: number) {
|
||||
super();
|
||||
this.bytesPerSec = bytesPerSecond;
|
||||
this.bucket = bytesPerSecond;
|
||||
this.lastRefill = Date.now();
|
||||
}
|
||||
|
||||
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: stream.TransformCallback) {
|
||||
if (this.destroyed_) return;
|
||||
|
||||
const now = Date.now();
|
||||
const elapsed = (now - this.lastRefill) / 1000;
|
||||
this.bucket = Math.min(this.bytesPerSec, this.bucket + elapsed * this.bytesPerSec);
|
||||
this.lastRefill = now;
|
||||
|
||||
if (chunk.length <= this.bucket) {
|
||||
this.bucket -= chunk.length;
|
||||
callback(null, chunk);
|
||||
} else {
|
||||
const deficit = chunk.length - this.bucket;
|
||||
this.bucket = 0;
|
||||
const delayMs = Math.min((deficit / this.bytesPerSec) * 1000, 1000);
|
||||
setTimeout(() => {
|
||||
if (this.destroyed_) { callback(); return; }
|
||||
this.lastRefill = Date.now();
|
||||
this.bucket = 0;
|
||||
callback(null, chunk);
|
||||
}, delayMs);
|
||||
}
|
||||
}
|
||||
|
||||
_destroy(err: Error | null, callback: (error: Error | null) => void) {
|
||||
this.destroyed_ = true;
|
||||
callback(err);
|
||||
}
|
||||
}
|
||||
|
||||
interface ThrottleProxy {
|
||||
server: net.Server;
|
||||
close: () => Promise<void>;
|
||||
}
|
||||
|
||||
async function startThrottleProxy(
|
||||
listenPort: number,
|
||||
targetHost: string,
|
||||
targetPort: number,
|
||||
bytesPerSecond: number,
|
||||
): Promise<ThrottleProxy> {
|
||||
const connections = new Set<net.Socket>();
|
||||
const server = net.createServer((clientSock) => {
|
||||
connections.add(clientSock);
|
||||
const upstream = net.createConnection({ host: targetHost, port: targetPort });
|
||||
connections.add(upstream);
|
||||
|
||||
const throttleUp = new ThrottleTransform(bytesPerSecond);
|
||||
const throttleDown = new ThrottleTransform(bytesPerSecond);
|
||||
|
||||
clientSock.pipe(throttleUp).pipe(upstream);
|
||||
upstream.pipe(throttleDown).pipe(clientSock);
|
||||
|
||||
let cleaned = false;
|
||||
const cleanup = () => {
|
||||
if (cleaned) return;
|
||||
cleaned = true;
|
||||
throttleUp.destroy();
|
||||
throttleDown.destroy();
|
||||
clientSock.destroy();
|
||||
upstream.destroy();
|
||||
connections.delete(clientSock);
|
||||
connections.delete(upstream);
|
||||
};
|
||||
clientSock.on('error', () => cleanup());
|
||||
upstream.on('error', () => cleanup());
|
||||
throttleUp.on('error', () => cleanup());
|
||||
throttleDown.on('error', () => cleanup());
|
||||
clientSock.on('close', () => cleanup());
|
||||
upstream.on('close', () => cleanup());
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => server.listen(listenPort, '127.0.0.1', resolve));
|
||||
return {
|
||||
server,
|
||||
close: async () => {
|
||||
for (const c of connections) c.destroy();
|
||||
connections.clear();
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let proxyPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let throttle: ThrottleProxy;
|
||||
const allClients: VpnClient[] = [];
|
||||
|
||||
let clientCounter = 0;
|
||||
async function createConnectedClient(port: number): Promise<VpnClient> {
|
||||
clientCounter++;
|
||||
const bundle = await server.createClient({ clientId: `load-client-${clientCounter}` });
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${port}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
allClients.push(c);
|
||||
return c;
|
||||
}
|
||||
|
||||
async function stopClient(c: VpnClient): Promise<void> {
|
||||
if (c.running) {
|
||||
try { await c.disconnect(); } catch { /* already disconnected */ }
|
||||
c.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start throttled VPN tunnel (1 MB/s)', async () => {
|
||||
serverPort = await findFreePort();
|
||||
proxyPort = await findFreePort();
|
||||
|
||||
// Start VPN server
|
||||
server = new VpnServer({ transport: { transport: 'stdio' } });
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
keypair = await server.generateKeypair();
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.8.0.0/24',
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Start throttle proxy: 1 MB/s
|
||||
throttle = await startThrottleProxy(proxyPort, '127.0.0.1', serverPort, 1024 * 1024);
|
||||
});
|
||||
|
||||
tap.test('throttled connection: handshake succeeds through throttle', async () => {
|
||||
const client = await createConnectedClient(proxyPort);
|
||||
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
expect(status.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 1;
|
||||
});
|
||||
});
|
||||
|
||||
tap.test('sustained keepalive under throttle', async () => {
|
||||
// Wait for at least 2 keepalive cycles (3s interval)
|
||||
await delay(8000);
|
||||
|
||||
const client = allClients[0];
|
||||
const stats = await client.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Throttle adds latency — RTT should be measurable
|
||||
const quality = await client.getConnectionQuality();
|
||||
expect(quality.srttMs).toBeGreaterThanOrEqual(0);
|
||||
expect(quality.jitterMs).toBeTypeofNumber();
|
||||
});
|
||||
|
||||
tap.test('3 concurrent throttled clients', async () => {
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await createConnectedClient(proxyPort);
|
||||
}
|
||||
|
||||
// All 4 clients should be visible
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length === 4;
|
||||
});
|
||||
|
||||
const clients = await server.listClients();
|
||||
expect(clients.length).toEqual(4);
|
||||
|
||||
// Verify all IPs are unique
|
||||
const ips = new Set(clients.map((c) => c.assignedIp));
|
||||
expect(ips.size).toEqual(4);
|
||||
});
|
||||
|
||||
tap.test('rate limiting combined with network throttle', async () => {
|
||||
const clients = await server.listClients();
|
||||
const targetId = clients[0].clientId;
|
||||
|
||||
// Set rate limit on first client
|
||||
await server.setClientRateLimit(targetId, 500, 500);
|
||||
const telemetry = await server.getClientTelemetry(targetId);
|
||||
expect(telemetry.rateLimitBytesPerSec).toEqual(500);
|
||||
expect(telemetry.burstBytes).toEqual(500);
|
||||
|
||||
// Verify another client has no rate limit
|
||||
const otherTelemetry = await server.getClientTelemetry(clients[1].clientId);
|
||||
expect(otherTelemetry.rateLimitBytesPerSec).toBeNullOrUndefined();
|
||||
|
||||
// Clean up the rate limit
|
||||
await server.removeClientRateLimit(targetId);
|
||||
});
|
||||
|
||||
tap.test('burst waves: 3 waves of 3 clients', async () => {
|
||||
const initialCount = (await server.listClients()).length;
|
||||
|
||||
for (let wave = 0; wave < 3; wave++) {
|
||||
const waveClients: VpnClient[] = [];
|
||||
|
||||
// Connect 3 clients
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const c = await createConnectedClient(proxyPort);
|
||||
waveClients.push(c);
|
||||
}
|
||||
|
||||
// Verify all connected
|
||||
await waitFor(async () => {
|
||||
const all = await server.listClients();
|
||||
return all.length === initialCount + 3;
|
||||
});
|
||||
|
||||
// Disconnect all wave clients
|
||||
for (const c of waveClients) {
|
||||
await stopClient(c);
|
||||
}
|
||||
|
||||
// Wait for server to detect disconnections
|
||||
await waitFor(async () => {
|
||||
const all = await server.listClients();
|
||||
return all.length === initialCount;
|
||||
}, 15000);
|
||||
|
||||
await delay(500);
|
||||
}
|
||||
|
||||
// Verify total connections accumulated
|
||||
const stats = await server.getStatistics();
|
||||
expect(stats.totalConnections).toBeGreaterThanOrEqual(9 + initialCount);
|
||||
|
||||
// Original clients still connected
|
||||
const remaining = await server.listClients();
|
||||
expect(remaining.length).toEqual(initialCount);
|
||||
});
|
||||
|
||||
tap.test('aggressive throttle: 10 KB/s', async () => {
|
||||
// Close current throttle proxy and start an aggressive one
|
||||
await throttle.close();
|
||||
const aggressivePort = await findFreePort();
|
||||
throttle = await startThrottleProxy(aggressivePort, '127.0.0.1', serverPort, 10 * 1024);
|
||||
|
||||
// Connect a client through the aggressive throttle
|
||||
const client = await createConnectedClient(aggressivePort);
|
||||
const status = await client.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Wait for keepalive exchange (might take longer due to throttle)
|
||||
await delay(10000);
|
||||
|
||||
const stats = await client.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(stats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('post-load health: direct connection still works', async () => {
|
||||
// Server should still be healthy after all load tests
|
||||
const serverStatus = await server.getStatus();
|
||||
expect(serverStatus.state).toEqual('connected');
|
||||
|
||||
// Connect one more client directly (no throttle)
|
||||
const directClient = await createConnectedClient(serverPort);
|
||||
const status = await directClient.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
await delay(5000);
|
||||
|
||||
const stats = await directClient.getStatistics();
|
||||
expect(stats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('teardown: stop all', async () => {
|
||||
// Stop all clients
|
||||
for (const c of allClients) {
|
||||
await stopClient(c);
|
||||
}
|
||||
|
||||
await delay(500);
|
||||
|
||||
// Close throttle proxy
|
||||
if (throttle) {
|
||||
await throttle.close();
|
||||
}
|
||||
|
||||
// Stop server
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
await delay(500);
|
||||
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
258
test/test.quic.node.ts
Normal file
258
test/test.quic.node.ts
Normal file
@@ -0,0 +1,258 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as dgram from 'dgram';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
async function findFreeUdpPort(): Promise<number> {
|
||||
const sock = dgram.createSocket('udp4');
|
||||
await new Promise<void>((resolve) => sock.bind(0, '127.0.0.1', resolve));
|
||||
const port = (sock.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => sock.close(resolve));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let wsPort: number;
|
||||
let quicPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: QUIC-only server + QUIC client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start VPN server in QUIC mode', async () => {
|
||||
quicPort = await findFreeUdpPort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
keypair = await server.generateKeypair();
|
||||
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${quicPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.9.0.0/24',
|
||||
transportMode: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('QUIC client connects and gets IP', async () => {
|
||||
const bundle = await server.createClient({ clientId: 'quic-client-1' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${quicPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.9.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
// Verify server sees the client
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length >= 1;
|
||||
});
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('teardown: stop QUIC server', async () => {
|
||||
await server.stop();
|
||||
await delay(500);
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: dual-mode server (both) + auto client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let dualServer: VpnServer;
|
||||
let dualWsPort: number;
|
||||
let dualQuicPort: number;
|
||||
let dualKeypair: IVpnKeypair;
|
||||
|
||||
tap.test('setup: start VPN server in both mode', async () => {
|
||||
dualWsPort = await findFreePort();
|
||||
dualQuicPort = await findFreeUdpPort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
dualServer = new VpnServer(options);
|
||||
|
||||
const started = await dualServer['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
dualKeypair = await dualServer.generateKeypair();
|
||||
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${dualWsPort}`,
|
||||
privateKey: dualKeypair.privateKey,
|
||||
publicKey: dualKeypair.publicKey,
|
||||
subnet: '10.10.0.0/24',
|
||||
transportMode: 'both',
|
||||
quicListenAddr: `127.0.0.1:${dualQuicPort}`,
|
||||
keepaliveIntervalSecs: 3,
|
||||
};
|
||||
await dualServer['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await dualServer.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('auto client connects to dual-mode server (QUIC preferred)', async () => {
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-auto-client' });
|
||||
|
||||
// "auto" mode (default): tries QUIC first at same host:port, falls back to WS
|
||||
// Since the WS port and QUIC port differ, auto will try QUIC on WS port (fail),
|
||||
// then fall back to WebSocket
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${dualWsPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
// transport defaults to 'auto'
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.10.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
await waitFor(async () => {
|
||||
const clients = await dualServer.listClients();
|
||||
return clients.length >= 1;
|
||||
});
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('explicit QUIC client connects to dual-mode server', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-quic-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.10.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange over QUIC', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-keepalive-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
await client.start();
|
||||
|
||||
await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
// Wait for keepalive exchange
|
||||
await delay(8000);
|
||||
|
||||
const clientStats = await client.getStatistics();
|
||||
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('teardown: stop dual-mode server', async () => {
|
||||
await dualServer.stop();
|
||||
await delay(500);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -2,10 +2,17 @@ import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { VpnConfig } from '../ts/index.js';
|
||||
import type { IVpnClientConfig, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// Valid 32-byte base64 keys for testing
|
||||
const TEST_KEY_A = 'YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE=';
|
||||
const TEST_KEY_B = 'YmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmI=';
|
||||
const TEST_KEY_C = 'Y2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2M=';
|
||||
|
||||
tap.test('VpnConfig: validate valid client config', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
@@ -16,7 +23,9 @@ tap.test('VpnConfig: validate valid client config', async () => {
|
||||
|
||||
tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
const config = {
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -31,7 +40,9 @@ tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'http://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -43,10 +54,28 @@ tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config without clientPrivateKey', async () => {
|
||||
const config = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('clientPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
mtu: 100,
|
||||
};
|
||||
let threw = false;
|
||||
@@ -62,7 +91,9 @@ tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['not-an-ip'],
|
||||
};
|
||||
let threw = false;
|
||||
@@ -78,12 +109,15 @@ tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
tap.test('VpnConfig: validate valid server config', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
enableNat: true,
|
||||
clients: [
|
||||
{ clientId: 'test-client', publicKey: TEST_KEY_C },
|
||||
],
|
||||
};
|
||||
// Should not throw
|
||||
VpnConfig.validateServerConfig(config);
|
||||
@@ -92,8 +126,8 @@ tap.test('VpnConfig: validate valid server config', async () => {
|
||||
tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: 'invalid',
|
||||
};
|
||||
let threw = false;
|
||||
@@ -109,7 +143,7 @@ tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
const config = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
} as IVpnServerConfig;
|
||||
let threw = false;
|
||||
@@ -122,4 +156,24 @@ tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject server config with invalid client publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
clients: [
|
||||
{ clientId: 'bad-client', publicKey: 'short-key' },
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
|
||||
353
test/test.wireguard.node.ts
Normal file
353
test/test.wireguard.node.ts
Normal file
@@ -0,0 +1,353 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import {
|
||||
VpnConfig,
|
||||
VpnServer,
|
||||
WgConfigGenerator,
|
||||
} from '../ts/index.js';
|
||||
import type {
|
||||
IVpnClientConfig,
|
||||
IVpnServerConfig,
|
||||
IVpnServerOptions,
|
||||
IWgPeerConfig,
|
||||
} from '../ts/index.js';
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config validation — client
|
||||
// ============================================================================
|
||||
|
||||
// A valid 32-byte key in base64 (44 chars)
|
||||
const VALID_KEY = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=';
|
||||
const VALID_KEY_2 = 'BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB=';
|
||||
|
||||
tap.test('WG client config: valid wireguard config passes validation', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '', // not needed for WG
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
wgAllowedIps: ['0.0.0.0/0'],
|
||||
};
|
||||
VpnConfig.validateClientConfig(config);
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgPrivateKey', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgAddress', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgAddress');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgEndpoint', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgEndpoint');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects invalid key length', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: 'tooshort',
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('44 characters');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects invalid CIDR in allowedIps', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
wgAllowedIps: ['not-a-cidr'],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('CIDR');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config validation — server
|
||||
// ============================================================================
|
||||
|
||||
tap.test('WG server config: valid config passes validation', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: VALID_KEY_2,
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
VpnConfig.validateServerConfig(config);
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects empty wgPeers', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgPeers');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects peer without publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: '',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects invalid wgListenPort', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgListenPort: 0,
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: VALID_KEY_2,
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgListenPort');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard keypair generation via daemon
|
||||
// ============================================================================
|
||||
|
||||
let server: VpnServer;
|
||||
|
||||
tap.test('WG: spawn server daemon for keypair generation', async () => {
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
expect(server.running).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG: generateWgKeypair returns valid keypair', async () => {
|
||||
const keypair = await server.generateWgKeypair();
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
// WireGuard keys: base64 of 32 bytes = 44 characters
|
||||
expect(keypair.publicKey.length).toEqual(44);
|
||||
expect(keypair.privateKey.length).toEqual(44);
|
||||
// Verify they decode to 32 bytes
|
||||
const pubBuf = Buffer.from(keypair.publicKey, 'base64');
|
||||
const privBuf = Buffer.from(keypair.privateKey, 'base64');
|
||||
expect(pubBuf.length).toEqual(32);
|
||||
expect(privBuf.length).toEqual(32);
|
||||
});
|
||||
|
||||
tap.test('WG: generateWgKeypair returns unique keys each time', async () => {
|
||||
const kp1 = await server.generateWgKeypair();
|
||||
const kp2 = await server.generateWgKeypair();
|
||||
expect(kp1.publicKey).not.toEqual(kp2.publicKey);
|
||||
expect(kp1.privateKey).not.toEqual(kp2.privateKey);
|
||||
});
|
||||
|
||||
tap.test('WG: stop server daemon', async () => {
|
||||
server.stop();
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config file generation
|
||||
// ============================================================================
|
||||
|
||||
tap.test('WgConfigGenerator: generate client config', async () => {
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: 'clientPrivateKeyBase64====================',
|
||||
address: '10.8.0.2/24',
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
peer: {
|
||||
publicKey: 'serverPublicKeyBase64====================',
|
||||
endpoint: 'vpn.example.com:51820',
|
||||
allowedIps: ['0.0.0.0/0', '::/0'],
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).toContain('PrivateKey = clientPrivateKeyBase64====================');
|
||||
expect(conf).toContain('Address = 10.8.0.2/24');
|
||||
expect(conf).toContain('DNS = 1.1.1.1, 8.8.8.8');
|
||||
expect(conf).toContain('MTU = 1420');
|
||||
expect(conf).toContain('[Peer]');
|
||||
expect(conf).toContain('PublicKey = serverPublicKeyBase64====================');
|
||||
expect(conf).toContain('Endpoint = vpn.example.com:51820');
|
||||
expect(conf).toContain('AllowedIPs = 0.0.0.0/0, ::/0');
|
||||
expect(conf).toContain('PersistentKeepalive = 25');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate client config without optional fields', async () => {
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: 'key1',
|
||||
address: '10.0.0.2/32',
|
||||
peer: {
|
||||
publicKey: 'key2',
|
||||
endpoint: 'server:51820',
|
||||
allowedIps: ['10.0.0.0/24'],
|
||||
},
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).not.toContain('DNS');
|
||||
expect(conf).not.toContain('MTU');
|
||||
expect(conf).not.toContain('PresharedKey');
|
||||
expect(conf).not.toContain('PersistentKeepalive');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate server config with NAT', async () => {
|
||||
const conf = WgConfigGenerator.generateServerConfig({
|
||||
privateKey: 'serverPrivKey',
|
||||
address: '10.8.0.1/24',
|
||||
listenPort: 51820,
|
||||
dns: ['1.1.1.1'],
|
||||
enableNat: true,
|
||||
natInterface: 'ens3',
|
||||
peers: [
|
||||
{
|
||||
publicKey: 'peer1PubKey',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
presharedKey: 'psk1',
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
{
|
||||
publicKey: 'peer2PubKey',
|
||||
allowedIps: ['10.8.0.3/32'],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).toContain('ListenPort = 51820');
|
||||
expect(conf).toContain('PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ens3 -j MASQUERADE');
|
||||
expect(conf).toContain('PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ens3 -j MASQUERADE');
|
||||
// Two [Peer] sections
|
||||
const peerCount = (conf.match(/\[Peer\]/g) || []).length;
|
||||
expect(peerCount).toEqual(2);
|
||||
expect(conf).toContain('PresharedKey = psk1');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate server config without NAT', async () => {
|
||||
const conf = WgConfigGenerator.generateServerConfig({
|
||||
privateKey: 'serverPrivKey',
|
||||
address: '10.8.0.1/24',
|
||||
listenPort: 51820,
|
||||
peers: [
|
||||
{
|
||||
publicKey: 'peerKey',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(conf).not.toContain('PostUp');
|
||||
expect(conf).not.toContain('PostDown');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartvpn',
|
||||
version: '1.1.0',
|
||||
version: '1.8.0',
|
||||
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
|
||||
}
|
||||
|
||||
@@ -4,3 +4,4 @@ export { VpnClient } from './smartvpn.classes.vpnclient.js';
|
||||
export { VpnServer } from './smartvpn.classes.vpnserver.js';
|
||||
export { VpnConfig } from './smartvpn.classes.vpnconfig.js';
|
||||
export { VpnInstaller } from './smartvpn.classes.vpninstaller.js';
|
||||
export { WgConfigGenerator } from './smartvpn.classes.wgconfig.js';
|
||||
|
||||
@@ -12,14 +12,54 @@ export class VpnConfig {
|
||||
* Validate a client config object. Throws on invalid config.
|
||||
*/
|
||||
public static validateClientConfig(config: IVpnClientConfig): void {
|
||||
if (!config.serverUrl) {
|
||||
throw new Error('VpnConfig: serverUrl is required');
|
||||
}
|
||||
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
|
||||
throw new Error('VpnConfig: serverUrl must start with wss:// or ws://');
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
if (config.transport === 'wireguard') {
|
||||
// WireGuard-specific validation
|
||||
if (!config.wgPrivateKey) {
|
||||
throw new Error('VpnConfig: wgPrivateKey is required for WireGuard transport');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.wgPrivateKey, 'wgPrivateKey');
|
||||
if (!config.wgAddress) {
|
||||
throw new Error('VpnConfig: wgAddress is required for WireGuard transport');
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required for WireGuard transport');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.serverPublicKey, 'serverPublicKey');
|
||||
if (!config.wgEndpoint) {
|
||||
throw new Error('VpnConfig: wgEndpoint is required for WireGuard transport');
|
||||
}
|
||||
if (config.wgPresharedKey) {
|
||||
VpnConfig.validateBase64Key(config.wgPresharedKey, 'wgPresharedKey');
|
||||
}
|
||||
if (config.wgAllowedIps) {
|
||||
for (const cidr of config.wgAllowedIps) {
|
||||
if (!VpnConfig.isValidCidr(cidr)) {
|
||||
throw new Error(`VpnConfig: invalid allowedIp CIDR: ${cidr}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!config.serverUrl) {
|
||||
throw new Error('VpnConfig: serverUrl is required');
|
||||
}
|
||||
// For QUIC-only transport, serverUrl is a host:port address; for WebSocket/auto it must be ws:// or wss://
|
||||
if (config.transport !== 'quic') {
|
||||
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
|
||||
throw new Error('VpnConfig: serverUrl must start with wss:// or ws:// (for WebSocket transport)');
|
||||
}
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
}
|
||||
// Noise IK requires client keypair
|
||||
if (!config.clientPrivateKey) {
|
||||
throw new Error('VpnConfig: clientPrivateKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPrivateKey, 'clientPrivateKey');
|
||||
if (!config.clientPublicKey) {
|
||||
throw new Error('VpnConfig: clientPublicKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPublicKey, 'clientPublicKey');
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -40,20 +80,63 @@ export class VpnConfig {
|
||||
* Validate a server config object. Throws on invalid config.
|
||||
*/
|
||||
public static validateServerConfig(config: IVpnServerConfig): void {
|
||||
if (!config.listenAddr) {
|
||||
throw new Error('VpnConfig: listenAddr is required');
|
||||
}
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
if (!config.publicKey) {
|
||||
throw new Error('VpnConfig: publicKey is required');
|
||||
}
|
||||
if (!config.subnet) {
|
||||
throw new Error('VpnConfig: subnet is required');
|
||||
}
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
if (config.transportMode === 'wireguard') {
|
||||
// WireGuard server validation
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.privateKey, 'privateKey');
|
||||
if (!config.wgPeers || config.wgPeers.length === 0) {
|
||||
throw new Error('VpnConfig: at least one wgPeers entry is required for WireGuard mode');
|
||||
}
|
||||
for (const peer of config.wgPeers) {
|
||||
if (!peer.publicKey) {
|
||||
throw new Error('VpnConfig: peer publicKey is required');
|
||||
}
|
||||
VpnConfig.validateBase64Key(peer.publicKey, 'peer.publicKey');
|
||||
if (!peer.allowedIps || peer.allowedIps.length === 0) {
|
||||
throw new Error('VpnConfig: peer allowedIps is required');
|
||||
}
|
||||
for (const cidr of peer.allowedIps) {
|
||||
if (!VpnConfig.isValidCidr(cidr)) {
|
||||
throw new Error(`VpnConfig: invalid peer allowedIp CIDR: ${cidr}`);
|
||||
}
|
||||
}
|
||||
if (peer.presharedKey) {
|
||||
VpnConfig.validateBase64Key(peer.presharedKey, 'peer.presharedKey');
|
||||
}
|
||||
}
|
||||
if (config.wgListenPort !== undefined && (config.wgListenPort < 1 || config.wgListenPort > 65535)) {
|
||||
throw new Error('VpnConfig: wgListenPort must be between 1 and 65535');
|
||||
}
|
||||
} else {
|
||||
if (!config.listenAddr) {
|
||||
throw new Error('VpnConfig: listenAddr is required');
|
||||
}
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
if (!config.publicKey) {
|
||||
throw new Error('VpnConfig: publicKey is required');
|
||||
}
|
||||
if (!config.subnet) {
|
||||
throw new Error('VpnConfig: subnet is required');
|
||||
}
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
}
|
||||
// Validate client entries if provided
|
||||
if (config.clients) {
|
||||
for (const client of config.clients) {
|
||||
if (!client.clientId) {
|
||||
throw new Error('VpnConfig: client entry must have a clientId');
|
||||
}
|
||||
if (!client.publicKey) {
|
||||
throw new Error(`VpnConfig: client '${client.clientId}' must have a publicKey`);
|
||||
}
|
||||
VpnConfig.validateBase64Key(client.publicKey, `client '${client.clientId}' publicKey`);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -101,4 +184,41 @@ export class VpnConfig {
|
||||
const prefixNum = parseInt(prefix, 10);
|
||||
return !isNaN(prefixNum) && prefixNum >= 0 && prefixNum <= 32;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a CIDR string (IPv4 or IPv6).
|
||||
*/
|
||||
private static isValidCidr(cidr: string): boolean {
|
||||
const parts = cidr.split('/');
|
||||
if (parts.length !== 2) return false;
|
||||
const prefixNum = parseInt(parts[1], 10);
|
||||
if (isNaN(prefixNum) || prefixNum < 0) return false;
|
||||
// IPv4
|
||||
if (VpnConfig.isValidIp(parts[0])) {
|
||||
return prefixNum <= 32;
|
||||
}
|
||||
// IPv6 (basic check)
|
||||
if (parts[0].includes(':')) {
|
||||
return prefixNum <= 128;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a base64-encoded 32-byte key (WireGuard X25519 format).
|
||||
*/
|
||||
private static validateBase64Key(key: string, fieldName: string): void {
|
||||
if (key.length !== 44) {
|
||||
throw new Error(`VpnConfig: ${fieldName} must be 44 characters (base64 of 32 bytes), got ${key.length}`);
|
||||
}
|
||||
try {
|
||||
const buf = Buffer.from(key, 'base64');
|
||||
if (buf.length !== 32) {
|
||||
throw new Error(`VpnConfig: ${fieldName} must decode to 32 bytes, got ${buf.length}`);
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.message.startsWith('VpnConfig:')) throw e;
|
||||
throw new Error(`VpnConfig: ${fieldName} is not valid base64`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import type {
|
||||
IVpnClientInfo,
|
||||
IVpnKeypair,
|
||||
IVpnClientTelemetry,
|
||||
IWgPeerConfig,
|
||||
IWgPeerInfo,
|
||||
IClientEntry,
|
||||
IClientConfigBundle,
|
||||
TVpnServerCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -121,6 +125,110 @@ export class VpnServer extends plugins.events.EventEmitter {
|
||||
return this.bridge.sendCommand('getClientTelemetry', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a WireGuard-compatible X25519 keypair.
|
||||
*/
|
||||
public async generateWgKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateWgKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a WireGuard peer (server must be running in wireguard mode).
|
||||
*/
|
||||
public async addWgPeer(peer: IWgPeerConfig): Promise<void> {
|
||||
await this.bridge.sendCommand('addWgPeer', { peer });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a WireGuard peer by public key.
|
||||
*/
|
||||
public async removeWgPeer(publicKey: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeWgPeer', { publicKey });
|
||||
}
|
||||
|
||||
/**
|
||||
* List WireGuard peers with stats.
|
||||
*/
|
||||
public async listWgPeers(): Promise<IWgPeerInfo[]> {
|
||||
const result = await this.bridge.sendCommand('listWgPeers', {} as Record<string, never>);
|
||||
return result.peers;
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ─────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Create a new client. Generates keypairs, assigns IP, returns full config bundle.
|
||||
* The secrets (private keys) are only returned at creation time.
|
||||
*/
|
||||
public async createClient(opts: Partial<IClientEntry>): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('createClient', { client: opts });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a registered client (also disconnects if connected).
|
||||
*/
|
||||
public async removeClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a registered client by ID.
|
||||
*/
|
||||
public async getClient(clientId: string): Promise<IClientEntry> {
|
||||
return this.bridge.sendCommand('getClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* List all registered clients.
|
||||
*/
|
||||
public async listRegisteredClients(): Promise<IClientEntry[]> {
|
||||
const result = await this.bridge.sendCommand('listRegisteredClients', {} as Record<string, never>);
|
||||
return result.clients;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a registered client's fields (ACLs, tags, description, etc.).
|
||||
*/
|
||||
public async updateClient(clientId: string, update: Partial<IClientEntry>): Promise<void> {
|
||||
await this.bridge.sendCommand('updateClient', { clientId, update });
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a previously disabled client.
|
||||
*/
|
||||
public async enableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('enableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable a client (also disconnects if connected).
|
||||
*/
|
||||
public async disableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('disableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
*/
|
||||
public async rotateClientKey(clientId: string): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('rotateClientKey', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Export a client config (without secrets) in the specified format.
|
||||
*/
|
||||
public async exportClientConfig(clientId: string, format: 'smartvpn' | 'wireguard'): Promise<string> {
|
||||
const result = await this.bridge.sendCommand('exportClientConfig', { clientId, format });
|
||||
return result.config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a standalone Noise IK keypair (not tied to a client).
|
||||
*/
|
||||
public async generateClientKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateClientKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
123
ts/smartvpn.classes.wgconfig.ts
Normal file
123
ts/smartvpn.classes.wgconfig.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
import type { IWgPeerConfig } from './smartvpn.interfaces.js';
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard .conf file generator
|
||||
// ============================================================================
|
||||
|
||||
export interface IWgClientConfOptions {
|
||||
/** Client private key (base64) */
|
||||
privateKey: string;
|
||||
/** Client TUN address with prefix (e.g. 10.8.0.2/24) */
|
||||
address: string;
|
||||
/** DNS servers */
|
||||
dns?: string[];
|
||||
/** TUN MTU */
|
||||
mtu?: number;
|
||||
/** Server peer config */
|
||||
peer: {
|
||||
publicKey: string;
|
||||
presharedKey?: string;
|
||||
endpoint: string;
|
||||
allowedIps: string[];
|
||||
persistentKeepalive?: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface IWgServerConfOptions {
|
||||
/** Server private key (base64) */
|
||||
privateKey: string;
|
||||
/** Server TUN address with prefix (e.g. 10.8.0.1/24) */
|
||||
address: string;
|
||||
/** UDP listen port */
|
||||
listenPort: number;
|
||||
/** DNS servers */
|
||||
dns?: string[];
|
||||
/** TUN MTU */
|
||||
mtu?: number;
|
||||
/** Enable NAT — adds PostUp/PostDown iptables rules */
|
||||
enableNat?: boolean;
|
||||
/** Network interface for NAT (e.g. eth0). Auto-detected if omitted. */
|
||||
natInterface?: string;
|
||||
/** Configured peers */
|
||||
peers: IWgPeerConfig[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates standard WireGuard .conf files compatible with wg-quick,
|
||||
* WireGuard iOS/Android apps, and other standard WireGuard clients.
|
||||
*/
|
||||
export class WgConfigGenerator {
|
||||
/**
|
||||
* Generate a client .conf file content.
|
||||
*/
|
||||
public static generateClientConfig(opts: IWgClientConfOptions): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push('[Interface]');
|
||||
lines.push(`PrivateKey = ${opts.privateKey}`);
|
||||
lines.push(`Address = ${opts.address}`);
|
||||
if (opts.dns && opts.dns.length > 0) {
|
||||
lines.push(`DNS = ${opts.dns.join(', ')}`);
|
||||
}
|
||||
if (opts.mtu) {
|
||||
lines.push(`MTU = ${opts.mtu}`);
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
lines.push('[Peer]');
|
||||
lines.push(`PublicKey = ${opts.peer.publicKey}`);
|
||||
if (opts.peer.presharedKey) {
|
||||
lines.push(`PresharedKey = ${opts.peer.presharedKey}`);
|
||||
}
|
||||
lines.push(`Endpoint = ${opts.peer.endpoint}`);
|
||||
lines.push(`AllowedIPs = ${opts.peer.allowedIps.join(', ')}`);
|
||||
if (opts.peer.persistentKeepalive) {
|
||||
lines.push(`PersistentKeepalive = ${opts.peer.persistentKeepalive}`);
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
return lines.join('\n');
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a server .conf file content.
|
||||
*/
|
||||
public static generateServerConfig(opts: IWgServerConfOptions): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push('[Interface]');
|
||||
lines.push(`PrivateKey = ${opts.privateKey}`);
|
||||
lines.push(`Address = ${opts.address}`);
|
||||
lines.push(`ListenPort = ${opts.listenPort}`);
|
||||
if (opts.dns && opts.dns.length > 0) {
|
||||
lines.push(`DNS = ${opts.dns.join(', ')}`);
|
||||
}
|
||||
if (opts.mtu) {
|
||||
lines.push(`MTU = ${opts.mtu}`);
|
||||
}
|
||||
if (opts.enableNat) {
|
||||
const iface = opts.natInterface || 'eth0';
|
||||
lines.push(`PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ${iface} -j MASQUERADE`);
|
||||
lines.push(`PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ${iface} -j MASQUERADE`);
|
||||
}
|
||||
|
||||
for (const peer of opts.peers) {
|
||||
lines.push('');
|
||||
lines.push('[Peer]');
|
||||
lines.push(`PublicKey = ${peer.publicKey}`);
|
||||
if (peer.presharedKey) {
|
||||
lines.push(`PresharedKey = ${peer.presharedKey}`);
|
||||
}
|
||||
lines.push(`AllowedIPs = ${peer.allowedIps.join(', ')}`);
|
||||
if (peer.endpoint) {
|
||||
lines.push(`Endpoint = ${peer.endpoint}`);
|
||||
}
|
||||
if (peer.persistentKeepalive) {
|
||||
lines.push(`PersistentKeepalive = ${peer.persistentKeepalive}`);
|
||||
}
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
return lines.join('\n');
|
||||
}
|
||||
}
|
||||
@@ -24,14 +24,36 @@ export type TVpnTransportOptions = IVpnTransportStdio | IVpnTransportSocket;
|
||||
export interface IVpnClientConfig {
|
||||
/** Server WebSocket URL, e.g. wss://vpn.example.com/tunnel */
|
||||
serverUrl: string;
|
||||
/** Server's static public key (base64) for Noise NK handshake */
|
||||
/** Server's static public key (base64) for Noise IK handshake */
|
||||
serverPublicKey: string;
|
||||
/** Client's Noise IK private key (base64) — required for SmartVPN native transport */
|
||||
clientPrivateKey: string;
|
||||
/** Client's Noise IK public key (base64) — for reference/display */
|
||||
clientPublicKey: string;
|
||||
/** Optional DNS servers to use while connected */
|
||||
dns?: string[];
|
||||
/** Optional MTU for the TUN device */
|
||||
mtu?: number;
|
||||
/** Keepalive interval in seconds (default: 30) */
|
||||
keepaliveIntervalSecs?: number;
|
||||
/** Transport protocol: 'auto' (default, tries QUIC then WS), 'websocket', 'quic', or 'wireguard' */
|
||||
transport?: 'auto' | 'websocket' | 'quic' | 'wireguard';
|
||||
/** For QUIC: SHA-256 hash of server certificate (base64) for cert pinning */
|
||||
serverCertHash?: string;
|
||||
/** WireGuard: client private key (base64, X25519) */
|
||||
wgPrivateKey?: string;
|
||||
/** WireGuard: client TUN address (e.g. 10.8.0.2) */
|
||||
wgAddress?: string;
|
||||
/** WireGuard: client TUN address prefix length (default: 24) */
|
||||
wgAddressPrefix?: number;
|
||||
/** WireGuard: preshared key (base64, optional) */
|
||||
wgPresharedKey?: string;
|
||||
/** WireGuard: persistent keepalive interval in seconds */
|
||||
wgPersistentKeepalive?: number;
|
||||
/** WireGuard: server endpoint (host:port, e.g. vpn.example.com:51820) */
|
||||
wgEndpoint?: string;
|
||||
/** WireGuard: allowed IPs (CIDR strings, e.g. ['0.0.0.0/0']) */
|
||||
wgAllowedIps?: string[];
|
||||
}
|
||||
|
||||
export interface IVpnClientOptions {
|
||||
@@ -68,6 +90,18 @@ export interface IVpnServerConfig {
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
/** Default burst size for new clients (bytes). Omit for unlimited. */
|
||||
defaultBurstBytes?: number;
|
||||
/** Transport mode: 'both' (default, WS+QUIC), 'websocket', 'quic', or 'wireguard' */
|
||||
transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard';
|
||||
/** QUIC listen address (host:port). Defaults to listenAddr. */
|
||||
quicListenAddr?: string;
|
||||
/** QUIC idle timeout in seconds (default: 30) */
|
||||
quicIdleTimeoutSecs?: number;
|
||||
/** WireGuard: UDP listen port (default: 51820) */
|
||||
wgListenPort?: number;
|
||||
/** WireGuard: configured peers */
|
||||
wgPeers?: IWgPeerConfig[];
|
||||
/** Pre-registered clients for Noise IK authentication */
|
||||
clients?: IClientEntry[];
|
||||
}
|
||||
|
||||
export interface IVpnServerOptions {
|
||||
@@ -118,6 +152,10 @@ export interface IVpnClientInfo {
|
||||
keepalivesReceived: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
/** Client's authenticated Noise IK public key (base64) */
|
||||
authenticatedKey: string;
|
||||
/** Registered client ID from the client registry */
|
||||
registeredClientId: string;
|
||||
}
|
||||
|
||||
export interface IVpnServerStatistics extends IVpnStatistics {
|
||||
@@ -177,6 +215,113 @@ export interface IVpnClientTelemetry {
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client Registry (Hub) types — aligned with SmartProxy IRouteSecurity pattern
|
||||
// ============================================================================
|
||||
|
||||
/** Per-client rate limiting. */
|
||||
export interface IClientRateLimit {
|
||||
/** Max throughput in bytes/sec */
|
||||
bytesPerSec: number;
|
||||
/** Burst allowance in bytes */
|
||||
burstBytes: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Per-client security settings.
|
||||
* Mirrors SmartProxy's IRouteSecurity: ipAllowList/ipBlockList naming + deny-overrides-allow.
|
||||
* Adds VPN-specific destination filtering.
|
||||
*/
|
||||
export interface IClientSecurity {
|
||||
/** Source IPs/CIDRs the client may connect FROM (empty = any).
|
||||
* Supports: exact IP, CIDR, wildcard (192.168.1.*), ranges (1.1.1.1-1.1.1.5). */
|
||||
ipAllowList?: string[];
|
||||
/** Source IPs blocked — overrides ipAllowList (deny wins). */
|
||||
ipBlockList?: string[];
|
||||
/** Destination IPs/CIDRs the client may reach through the VPN (empty = all). */
|
||||
destinationAllowList?: string[];
|
||||
/** Destination IPs blocked — overrides destinationAllowList (deny wins). */
|
||||
destinationBlockList?: string[];
|
||||
/** Max concurrent connections from this client. */
|
||||
maxConnections?: number;
|
||||
/** Per-client rate limiting. */
|
||||
rateLimit?: IClientRateLimit;
|
||||
}
|
||||
|
||||
/**
|
||||
* Server-side client definition — the central config object for the Hub.
|
||||
* Naming and structure aligned with SmartProxy's IRouteConfig / IRouteSecurity.
|
||||
*/
|
||||
export interface IClientEntry {
|
||||
/** Human-readable client ID (e.g. "alice-laptop") */
|
||||
clientId: string;
|
||||
/** Client's Noise IK public key (base64) — for SmartVPN native transport */
|
||||
publicKey: string;
|
||||
/** Client's WireGuard public key (base64) — for WireGuard transport */
|
||||
wgPublicKey?: string;
|
||||
/** Security settings (ACLs, rate limits) */
|
||||
security?: IClientSecurity;
|
||||
/** Traffic priority (lower = higher priority, default: 100) */
|
||||
priority?: number;
|
||||
/** Whether this client is enabled (default: true) */
|
||||
enabled?: boolean;
|
||||
/** Tags for grouping (e.g. ["engineering", "office"]) */
|
||||
tags?: string[];
|
||||
/** Optional description */
|
||||
description?: string;
|
||||
/** Optional expiry (ISO 8601 timestamp, omit = never expires) */
|
||||
expiresAt?: string;
|
||||
/** Assigned VPN IP address (set by server) */
|
||||
assignedIp?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Complete client config bundle — returned by createClient() and rotateClientKey().
|
||||
* Contains everything the client needs to connect.
|
||||
*/
|
||||
export interface IClientConfigBundle {
|
||||
/** The server-side client entry */
|
||||
entry: IClientEntry;
|
||||
/** Ready-to-use SmartVPN client config (typed object) */
|
||||
smartvpnConfig: IVpnClientConfig;
|
||||
/** Ready-to-use WireGuard .conf file content (string) */
|
||||
wireguardConfig: string;
|
||||
/** Client's private keys (ONLY returned at creation time, not stored server-side) */
|
||||
secrets: {
|
||||
noisePrivateKey: string;
|
||||
wgPrivateKey: string;
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard-specific types
|
||||
// ============================================================================
|
||||
|
||||
export interface IWgPeerConfig {
|
||||
/** Peer's public key (base64, X25519) */
|
||||
publicKey: string;
|
||||
/** Optional preshared key (base64) */
|
||||
presharedKey?: string;
|
||||
/** Allowed IP ranges (CIDR strings) */
|
||||
allowedIps: string[];
|
||||
/** Peer endpoint (host:port) — optional for server peers, required for client */
|
||||
endpoint?: string;
|
||||
/** Persistent keepalive interval in seconds */
|
||||
persistentKeepalive?: number;
|
||||
}
|
||||
|
||||
export interface IWgPeerInfo {
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint?: string;
|
||||
persistentKeepalive?: number;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsSent: number;
|
||||
packetsReceived: number;
|
||||
lastHandshakeTime?: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// IPC Command maps (used by smartrust RustBridge<TCommands>)
|
||||
// ============================================================================
|
||||
@@ -201,6 +346,21 @@ export type TVpnServerCommands = {
|
||||
setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void };
|
||||
removeClientRateLimit: { params: { clientId: string }; result: void };
|
||||
getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry };
|
||||
generateWgKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
addWgPeer: { params: { peer: IWgPeerConfig }; result: void };
|
||||
removeWgPeer: { params: { publicKey: string }; result: void };
|
||||
listWgPeers: { params: Record<string, never>; result: { peers: IWgPeerInfo[] } };
|
||||
// Client Registry (Hub) commands
|
||||
createClient: { params: { client: Partial<IClientEntry> }; result: IClientConfigBundle };
|
||||
removeClient: { params: { clientId: string }; result: void };
|
||||
getClient: { params: { clientId: string }; result: IClientEntry };
|
||||
listRegisteredClients: { params: Record<string, never>; result: { clients: IClientEntry[] } };
|
||||
updateClient: { params: { clientId: string; update: Partial<IClientEntry> }; result: void };
|
||||
enableClient: { params: { clientId: string }; result: void };
|
||||
disableClient: { params: { clientId: string }; result: void };
|
||||
rotateClientKey: { params: { clientId: string }; result: IClientConfigBundle };
|
||||
exportClientConfig: { params: { clientId: string; format: 'smartvpn' | 'wireguard' }; result: { config: string } };
|
||||
generateClientKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
"module": "NodeNext",
|
||||
"moduleResolution": "NodeNext",
|
||||
"esModuleInterop": true,
|
||||
"verbatimModuleSyntax": true
|
||||
"verbatimModuleSyntax": true,
|
||||
"types": ["node"]
|
||||
},
|
||||
"exclude": [
|
||||
"dist_ts/**/*.d.ts"
|
||||
|
||||
Reference in New Issue
Block a user