Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e9cf575271 | |||
| 229db4be38 | |||
| e31086d0c2 | |||
| 01a0d8b9f4 | |||
| 187a69028b | |||
| 64dedd389e | |||
| 13d8cbe3fa | |||
| f46ea70286 | |||
| 26ee3634c8 | |||
| 049fa00563 | |||
| e4e59d72f9 | |||
| 51d33127bf | |||
| a4ba6806e5 | |||
| 6330921160 | |||
| e81dd377d8 |
53
changelog.md
53
changelog.md
@@ -1,5 +1,58 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-03-29 - 1.9.0 - feat(server)
|
||||
add PROXY protocol v2 support for real client IP handling and connection ACLs
|
||||
|
||||
- add PROXY protocol v2 parsing for WebSocket connections, including IPv4/IPv6 support, LOCAL command handling, and header read timeout protection
|
||||
- apply server-level connection IP block lists before the Noise handshake and enforce per-client source IP allow/block lists using the resolved remote address
|
||||
- expose proxy protocol configuration and remote client address fields in Rust and TypeScript interfaces, and document reverse-proxy usage in the README
|
||||
|
||||
## 2026-03-29 - 1.8.0 - feat(auth,client-registry)
|
||||
add Noise IK client authentication with managed client registry and per-client ACL controls
|
||||
|
||||
- switch the native tunnel handshake from Noise NK to Noise IK and require client keypairs in client configuration
|
||||
- add server-side client registry management APIs for creating, updating, disabling, rotating, listing, and exporting client configs
|
||||
- enforce client authorization from the registry during handshake and expose authenticated client metadata in server client info
|
||||
- introduce per-client security policies with source/destination ACLs and per-client rate limit settings
|
||||
- add Rust ACL matching support for exact IPs, CIDR ranges, wildcards, and IP ranges with test coverage
|
||||
|
||||
## 2026-03-29 - 1.7.0 - feat(rust-tests)
|
||||
add end-to-end WireGuard UDP integration tests and align TypeScript build configuration
|
||||
|
||||
- Add userspace Rust end-to-end tests that validate WireGuard handshake, encryption, peer isolation, and preshared-key data exchange over real UDP sockets.
|
||||
- Update the TypeScript build setup by removing the allowimplicitany build flag and explicitly including Node types in tsconfig.
|
||||
- Refresh development toolchain versions to support the updated test and build workflow.
|
||||
|
||||
## 2026-03-29 - 1.6.0 - feat(readme)
|
||||
document WireGuard transport support, configuration, and usage examples
|
||||
|
||||
- Expand the README from dual-transport to triple-transport support by adding WireGuard alongside WebSocket and QUIC
|
||||
- Add client and server WireGuard examples, including live peer management and .conf generation with WgConfigGenerator
|
||||
- Document new WireGuard-related API methods, config fields, transport modes, and security model details
|
||||
|
||||
## 2026-03-29 - 1.5.0 - feat(wireguard)
|
||||
add WireGuard transport support with management APIs and config generation
|
||||
|
||||
- add Rust WireGuard module integration using boringtun and route management through client/server management handlers
|
||||
- extend TypeScript client and server configuration schemas with WireGuard-specific options and validation
|
||||
- add server-side WireGuard peer management commands including keypair generation, peer add/remove, and peer listing
|
||||
- export a WireGuard config generator for producing client and server .conf files
|
||||
- add WireGuard-focused test coverage for config validation and config generation
|
||||
|
||||
## 2026-03-21 - 1.4.1 - fix(readme)
|
||||
preserve markdown line breaks in feature list
|
||||
|
||||
- Adds trailing spaces to the README feature list so each highlighted capability renders on its own line.
|
||||
|
||||
## 2026-03-19 - 1.4.0 - feat(vpn transport)
|
||||
add QUIC transport support with auto fallback to WebSocket
|
||||
|
||||
- introduces a transport abstraction in the Rust daemon so client and server can operate over WebSocket or QUIC
|
||||
- adds dual-mode server configuration with websocket, quic, and both transport modes plus QUIC idle timeout and listen address options
|
||||
- adds client transport selection with auto mode that attempts QUIC first and falls back to WebSocket
|
||||
- adds QUIC certificate hash pinning support and required Rust dependencies for QUIC and TLS
|
||||
- updates TypeScript interfaces, config validation, tests, and documentation to cover the new transport modes
|
||||
|
||||
## 2026-03-17 - 1.3.0 - feat(tests,client)
|
||||
add flow control and load test coverage and honor configured keepalive intervals
|
||||
|
||||
|
||||
19
package.json
19
package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartvpn",
|
||||
"version": "1.3.0",
|
||||
"version": "1.9.0",
|
||||
"private": false,
|
||||
"description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
|
||||
"type": "module",
|
||||
@@ -10,7 +10,8 @@
|
||||
"main": "dist_ts/index.js",
|
||||
"typings": "dist_ts/index.d.ts",
|
||||
"scripts": {
|
||||
"build": "(tsbuild tsfolders --allowimplicitany) && (tsrust)",
|
||||
"build": "(tsbuild tsfolders) && (tsrust)",
|
||||
"test:before": "(tsrust)",
|
||||
"test": "tstest test/ --verbose",
|
||||
"buildDocs": "tsdoc"
|
||||
},
|
||||
@@ -28,15 +29,15 @@
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@push.rocks/smartrust": "^1.3.0",
|
||||
"@push.rocks/smartpath": "^5.0.18"
|
||||
"@push.rocks/smartpath": "^6.0.0",
|
||||
"@push.rocks/smartrust": "^1.3.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@git.zone/tsbuild": "^2.2.12",
|
||||
"@git.zone/tsrun": "^1.3.3",
|
||||
"@git.zone/tstest": "^1.0.96",
|
||||
"@git.zone/tsrust": "^1.0.29",
|
||||
"@types/node": "^22.0.0"
|
||||
"@git.zone/tsbuild": "^4.4.0",
|
||||
"@git.zone/tsrun": "^2.0.2",
|
||||
"@git.zone/tsrust": "^1.3.2",
|
||||
"@git.zone/tstest": "^3.6.3",
|
||||
"@types/node": "^25.5.0"
|
||||
},
|
||||
"files": [
|
||||
"ts/**/*",
|
||||
|
||||
3545
pnpm-lock.yaml
generated
3545
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
753
readme.md
753
readme.md
@@ -1,531 +1,394 @@
|
||||
# @push.rocks/smartvpn
|
||||
|
||||
A high-performance VPN with a **TypeScript control plane** and a **Rust data plane daemon**. Manage VPN connections with clean, fully-typed APIs while all networking heavy lifting — encryption, tunneling, QoS, rate limiting — runs at native speed in Rust.
|
||||
A high-performance VPN solution with a **TypeScript control plane** and a **Rust data plane daemon**. Enterprise-ready client authentication, triple transport support (WebSocket + QUIC + WireGuard), and a typed hub API for managing clients from code.
|
||||
|
||||
🔐 **Noise IK** mutual authentication — per-client X25519 keypairs, server-side registry
|
||||
🚀 **Triple transport**: WebSocket (Cloudflare-friendly), raw **QUIC** (datagrams), and **WireGuard** (standard protocol)
|
||||
🛡️ **ACL engine** — deny-overrides-allow IP filtering, aligned with SmartProxy conventions
|
||||
🔀 **PROXY protocol v2** — real client IPs behind reverse proxies (HAProxy, SmartProxy, Cloudflare Spectrum)
|
||||
📊 **Adaptive QoS**: per-client rate limiting, priority queues, connection quality tracking
|
||||
🔄 **Hub API**: one `createClient()` call generates keys, assigns IP, returns both SmartVPN + WireGuard configs
|
||||
📡 **Real-time telemetry**: RTT, jitter, loss ratio, link health — all via typed APIs
|
||||
|
||||
## Issue Reporting and Security
|
||||
|
||||
For reporting bugs, issues, or security vulnerabilities, please visit [community.foss.global/](https://community.foss.global/). This is the central community hub for all issue reporting. Developers who sign and comply with our contribution agreement and go through identification can also get a [code.foss.global/](https://code.foss.global/) account to submit Pull Requests directly.
|
||||
|
||||
## Install
|
||||
## Install 📦
|
||||
|
||||
```bash
|
||||
pnpm install @push.rocks/smartvpn
|
||||
# or
|
||||
npm install @push.rocks/smartvpn
|
||||
```
|
||||
|
||||
## 🏗️ Architecture
|
||||
The package ships with pre-compiled Rust binaries for **linux/amd64** and **linux/arm64**. No Rust toolchain required at runtime.
|
||||
|
||||
## Architecture 🏗️
|
||||
|
||||
```
|
||||
TypeScript (control plane) Rust (data plane)
|
||||
┌──────────────────────────┐ ┌────────────────────────────────────┐
|
||||
│ VpnClient / VpnServer │ │ smartvpn_daemon │
|
||||
│ └─ VpnBridge │──stdio/──▶ │ ├─ management (JSON IPC) │
|
||||
│ └─ RustBridge │ socket │ ├─ transport (WebSocket/TLS) │
|
||||
│ (smartrust) │ │ ├─ crypto (Noise NK + XCha20) │
|
||||
└──────────────────────────┘ │ ├─ codec (binary framing) │
|
||||
│ ├─ keepalive (adaptive state FSM) │
|
||||
│ ├─ telemetry (RTT/jitter/loss) │
|
||||
│ ├─ qos (classify + priority Q) │
|
||||
│ ├─ ratelimit (token bucket) │
|
||||
│ ├─ mtu (overhead calc + ICMP) │
|
||||
│ ├─ tunnel (TUN device) │
|
||||
│ ├─ network (NAT/IP pool) │
|
||||
│ └─ reconnect (exp. backoff) │
|
||||
└────────────────────────────────────┘
|
||||
┌──────────────────────────────┐ JSON-lines IPC ┌───────────────────────────────┐
|
||||
│ TypeScript Control Plane │ ◄─────────────────────► │ Rust Data Plane Daemon │
|
||||
│ │ stdio or Unix sock │ │
|
||||
│ VpnServer / VpnClient │ │ Noise IK handshake │
|
||||
│ Typed IPC commands │ │ XChaCha20-Poly1305 │
|
||||
│ Config validation │ │ WS + QUIC + WireGuard │
|
||||
│ Hub: client management │ │ TUN device, IP pool, NAT │
|
||||
│ WireGuard .conf generation │ │ Rate limiting, ACLs, QoS │
|
||||
└──────────────────────────────┘ └───────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key design decisions:**
|
||||
**Split-plane design** — TypeScript handles orchestration, config, and DX; Rust handles every hot-path byte with zero-copy async I/O (tokio, mimalloc).
|
||||
|
||||
| Decision | Choice | Why |
|
||||
|----------|--------|-----|
|
||||
| Transport | WebSocket over HTTPS | Works through Cloudflare and other terminating proxies |
|
||||
| Encryption | Noise NK + XChaCha20-Poly1305 | Strong forward secrecy, large nonce space (no counter needed) |
|
||||
| Keepalive | Adaptive app-level pings | Cloudflare drops WS pings; interval adapts to link health (10–60s) |
|
||||
| QoS | Packet classification + priority queues | DNS/SSH/ICMP always drain first; bulk flows get deprioritized |
|
||||
| Rate limiting | Per-client token bucket | Byte-granular, dynamically reconfigurable via IPC |
|
||||
| IPC | JSON lines over stdio / Unix socket | `stdio` for dev, `socket` for production (daemon stays alive) |
|
||||
| Binary protocol | `[type:1B][length:4B][payload:NB]` | Minimal overhead, easy to parse at wire speed |
|
||||
## Quick Start 🚀
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### VPN Client
|
||||
|
||||
```typescript
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
const client = new VpnClient({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
await client.start();
|
||||
|
||||
const { assignedIp } = await client.connect({
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'BASE64_SERVER_PUBLIC_KEY',
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
});
|
||||
|
||||
console.log(`Connected! Assigned IP: ${assignedIp}`);
|
||||
|
||||
// Connection quality (adaptive keepalive + telemetry)
|
||||
const quality = await client.getConnectionQuality();
|
||||
console.log(quality);
|
||||
// {
|
||||
// srttMs: 42.5, jitterMs: 3.2, minRttMs: 38.0, maxRttMs: 67.0,
|
||||
// lossRatio: 0.0, consecutiveTimeouts: 0,
|
||||
// linkHealth: 'healthy', currentKeepaliveIntervalSecs: 60
|
||||
// }
|
||||
|
||||
// MTU info
|
||||
const mtu = await client.getMtuInfo();
|
||||
console.log(mtu);
|
||||
// { tunMtu: 1420, effectiveMtu: 1421, linkMtu: 1500, overheadBytes: 79, ... }
|
||||
|
||||
// Traffic stats (includes quality snapshot)
|
||||
const stats = await client.getStatistics();
|
||||
|
||||
await client.disconnect();
|
||||
client.stop();
|
||||
```
|
||||
|
||||
### VPN Server
|
||||
### 1. Start a VPN Server (Hub)
|
||||
|
||||
```typescript
|
||||
import { VpnServer } from '@push.rocks/smartvpn';
|
||||
|
||||
const server = new VpnServer({
|
||||
transport: { transport: 'stdio' },
|
||||
});
|
||||
|
||||
// Generate a Noise keypair first
|
||||
await server.start();
|
||||
// If you don't have keys yet:
|
||||
const keypair = await server.generateKeypair();
|
||||
|
||||
// Start the VPN listener (or pass config to start() directly)
|
||||
const server = new VpnServer({ transport: { transport: 'stdio' } });
|
||||
await server.start({
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
privateKey: '<server-noise-private-key-base64>',
|
||||
publicKey: '<server-noise-public-key-base64>',
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
transportMode: 'both', // WebSocket + QUIC simultaneously
|
||||
enableNat: true,
|
||||
// Optional: default rate limit for all new clients
|
||||
defaultRateLimitBytesPerSec: 10_000_000, // 10 MB/s
|
||||
defaultBurstBytes: 20_000_000, // 20 MB burst
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
});
|
||||
|
||||
// List connected clients
|
||||
const clients = await server.listClients();
|
||||
|
||||
// Per-client rate limiting (live, no reconnect needed)
|
||||
await server.setClientRateLimit('client-id', 5_000_000, 10_000_000);
|
||||
await server.removeClientRateLimit('client-id'); // unlimited
|
||||
|
||||
// Per-client telemetry
|
||||
const telemetry = await server.getClientTelemetry('client-id');
|
||||
console.log(telemetry);
|
||||
// {
|
||||
// clientId, assignedIp, lastKeepaliveAt, keepalivesReceived,
|
||||
// packetsDropped, bytesDropped, bytesReceived, bytesSent,
|
||||
// rateLimitBytesPerSec, burstBytes
|
||||
// }
|
||||
|
||||
// Kick a client
|
||||
await server.disconnectClient('client-id');
|
||||
|
||||
await server.stopServer();
|
||||
server.stop();
|
||||
```
|
||||
|
||||
### Production: Socket Transport
|
||||
|
||||
In production, the daemon runs as a system service and you connect over a Unix socket:
|
||||
### 2. Create a Client (One Call = Everything)
|
||||
|
||||
```typescript
|
||||
const client = new VpnClient({
|
||||
transport: {
|
||||
transport: 'socket',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
autoReconnect: true,
|
||||
reconnectBaseDelayMs: 100,
|
||||
reconnectMaxDelayMs: 30000,
|
||||
maxReconnectAttempts: 10,
|
||||
const bundle = await server.createClient({
|
||||
clientId: 'alice-laptop',
|
||||
tags: ['engineering'],
|
||||
security: {
|
||||
destinationAllowList: ['10.0.0.0/8'], // can only reach internal network
|
||||
destinationBlockList: ['10.0.0.99'], // except this host
|
||||
rateLimit: { bytesPerSec: 10_000_000, burstBytes: 20_000_000 },
|
||||
},
|
||||
});
|
||||
|
||||
await client.start(); // connects to existing daemon (does not spawn)
|
||||
// bundle.smartvpnConfig → typed IVpnClientConfig, ready to use
|
||||
// bundle.wireguardConfig → standard WireGuard .conf string
|
||||
// bundle.secrets → { noisePrivateKey, wgPrivateKey } — shown ONCE
|
||||
```
|
||||
|
||||
When using socket transport, `client.stop()` closes the socket but **does not kill the daemon** — exactly what you want in production.
|
||||
|
||||
## 📋 API Reference
|
||||
|
||||
### `VpnClient`
|
||||
|
||||
| Method | Returns | Description |
|
||||
|--------|---------|-------------|
|
||||
| `start()` | `Promise<boolean>` | Start the daemon bridge (spawn or connect) |
|
||||
| `connect(config?)` | `Promise<{ assignedIp }>` | Connect to VPN server |
|
||||
| `disconnect()` | `Promise<void>` | Disconnect from VPN |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Current connection state |
|
||||
| `getStatistics()` | `Promise<IVpnStatistics>` | Traffic stats + connection quality |
|
||||
| `getConnectionQuality()` | `Promise<IVpnConnectionQuality>` | RTT, jitter, loss, link health |
|
||||
| `getMtuInfo()` | `Promise<IVpnMtuInfo>` | MTU info and overhead breakdown |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
| `running` | `boolean` | Whether bridge is active |
|
||||
|
||||
### `VpnServer`
|
||||
|
||||
| Method | Returns | Description |
|
||||
|--------|---------|-------------|
|
||||
| `start(config?)` | `Promise<void>` | Start daemon + VPN server |
|
||||
| `stopServer()` | `Promise<void>` | Stop the VPN server |
|
||||
| `getStatus()` | `Promise<IVpnStatus>` | Server connection state |
|
||||
| `getStatistics()` | `Promise<IVpnServerStatistics>` | Server stats (includes client counts) |
|
||||
| `listClients()` | `Promise<IVpnClientInfo[]>` | Connected clients with QoS stats |
|
||||
| `disconnectClient(id)` | `Promise<void>` | Kick a client |
|
||||
| `generateKeypair()` | `Promise<IVpnKeypair>` | Generate Noise NK keypair |
|
||||
| `setClientRateLimit(id, rate, burst)` | `Promise<void>` | Set per-client rate limit (bytes/sec) |
|
||||
| `removeClientRateLimit(id)` | `Promise<void>` | Remove rate limit (unlimited) |
|
||||
| `getClientTelemetry(id)` | `Promise<IVpnClientTelemetry>` | Per-client telemetry + drop stats |
|
||||
| `stop()` | `void` | Kill/close the daemon bridge |
|
||||
|
||||
### `VpnConfig`
|
||||
|
||||
Static utility class for config validation and file I/O:
|
||||
### 3. Connect a Client
|
||||
|
||||
```typescript
|
||||
import { VpnConfig } from '@push.rocks/smartvpn';
|
||||
import { VpnClient } from '@push.rocks/smartvpn';
|
||||
|
||||
// Validate (throws on invalid)
|
||||
VpnConfig.validateClientConfig(config);
|
||||
VpnConfig.validateServerConfig(config);
|
||||
const client = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await client.start();
|
||||
|
||||
// Load/save JSON configs
|
||||
const config = await VpnConfig.loadFromFile<IVpnClientConfig>('/etc/smartvpn/client.json');
|
||||
await VpnConfig.saveToFile('/etc/smartvpn/client.json', config);
|
||||
const { assignedIp } = await client.connect(bundle.smartvpnConfig);
|
||||
console.log(`Connected! VPN IP: ${assignedIp}`);
|
||||
```
|
||||
|
||||
### `VpnInstaller`
|
||||
## Features ✨
|
||||
|
||||
Generate system service units for the daemon:
|
||||
### 🔐 Enterprise Authentication (Noise IK)
|
||||
|
||||
Every client authenticates with a **Noise IK handshake** (`Noise_IK_25519_ChaChaPoly_BLAKE2s`). The server verifies the client's static public key against its registry — unauthorized clients are rejected before any data flows.
|
||||
|
||||
- Per-client X25519 keypair generated server-side
|
||||
- Client registry with enable/disable, expiry, tags
|
||||
- Key rotation with `rotateClientKey()` — generates new keys, returns fresh config bundle, disconnects old session
|
||||
|
||||
### 🌐 Triple Transport
|
||||
|
||||
| Transport | Protocol | Best For |
|
||||
|-----------|----------|----------|
|
||||
| **WebSocket** | TLS over TCP | Firewall-friendly, Cloudflare compatible |
|
||||
| **QUIC** | UDP (via quinn) | Low latency, datagram support for IP packets |
|
||||
| **WireGuard** | UDP (via boringtun) | Standard WG clients (iOS, Android, wg-quick) |
|
||||
|
||||
The server can run **all three simultaneously** with `transportMode: 'both'` (WS + QUIC) or `'wireguard'`. Clients auto-negotiate with `transport: 'auto'` (tries QUIC first, falls back to WS).
|
||||
|
||||
### 🛡️ ACL Engine (SmartProxy-Aligned)
|
||||
|
||||
Security policies per client, using the same `ipAllowList` / `ipBlockList` naming convention as `@push.rocks/smartproxy`:
|
||||
|
||||
```typescript
|
||||
security: {
|
||||
ipAllowList: ['192.168.1.0/24'], // source IPs allowed to connect
|
||||
ipBlockList: ['192.168.1.100'], // deny overrides allow
|
||||
destinationAllowList: ['10.0.0.0/8'], // VPN destinations permitted
|
||||
destinationBlockList: ['10.0.0.99'], // deny overrides allow
|
||||
maxConnections: 5,
|
||||
rateLimit: { bytesPerSec: 1_000_000, burstBytes: 2_000_000 },
|
||||
}
|
||||
```
|
||||
|
||||
Supports exact IPs, CIDR, wildcards (`192.168.1.*`), and ranges (`1.1.1.1-1.1.1.100`).
|
||||
|
||||
### 🔀 PROXY Protocol v2
|
||||
|
||||
When the VPN server sits behind a reverse proxy, enable PROXY protocol v2 to receive the **real client IP** instead of the proxy's address. This makes `ipAllowList` / `ipBlockList` ACLs work correctly through load balancers.
|
||||
|
||||
```typescript
|
||||
await server.start({
|
||||
// ... other config ...
|
||||
proxyProtocol: true, // parse PP v2 headers on WS connections
|
||||
connectionIpBlockList: ['198.51.100.0/24'], // server-wide block list (pre-handshake)
|
||||
});
|
||||
```
|
||||
|
||||
**Two-phase ACL with real IPs:**
|
||||
|
||||
| Phase | When | What Happens |
|
||||
|-------|------|-------------|
|
||||
| **Pre-handshake** | After TCP accept | Server-level `connectionIpBlockList` rejects known-bad IPs — zero crypto cost |
|
||||
| **Post-handshake** | After Noise IK identifies client | Per-client `ipAllowList` / `ipBlockList` checked against real source IP |
|
||||
|
||||
- Parses the PP v2 binary header from raw TCP before WebSocket upgrade
|
||||
- 5-second timeout protects against stalling attacks
|
||||
- LOCAL command (proxy health checks) handled gracefully
|
||||
- IPv4 and IPv6 addresses supported
|
||||
- `remoteAddr` field on `IVpnClientInfo` exposes the real client IP for monitoring
|
||||
- **Security**: must be `false` (default) when accepting direct connections — only enable behind a trusted proxy
|
||||
|
||||
### 📊 Telemetry & QoS
|
||||
|
||||
- **Connection quality**: Smoothed RTT, jitter, min/max RTT, loss ratio, link health (`healthy` / `degraded` / `critical`)
|
||||
- **Adaptive keepalives**: Interval adjusts based on link health (60s → 30s → 10s)
|
||||
- **Per-client rate limiting**: Token bucket with configurable bytes/sec and burst
|
||||
- **Dead-peer detection**: 180s inactivity timeout
|
||||
- **MTU management**: Automatic overhead calculation (IP+TCP+WS+Noise = 79 bytes)
|
||||
|
||||
### 🔄 Hub Client Management
|
||||
|
||||
The server acts as a **hub** — one API to manage all clients:
|
||||
|
||||
```typescript
|
||||
// Create (generates keys, assigns IP, returns config bundle)
|
||||
const bundle = await server.createClient({ clientId: 'bob-phone' });
|
||||
|
||||
// Read
|
||||
const entry = await server.getClient('bob-phone');
|
||||
const all = await server.listRegisteredClients();
|
||||
|
||||
// Update (ACLs, tags, description, rate limits...)
|
||||
await server.updateClient('bob-phone', {
|
||||
security: { destinationAllowList: ['0.0.0.0/0'] },
|
||||
tags: ['mobile', 'field-ops'],
|
||||
});
|
||||
|
||||
// Enable / Disable
|
||||
await server.disableClient('bob-phone'); // disconnects + blocks reconnection
|
||||
await server.enableClient('bob-phone');
|
||||
|
||||
// Key rotation
|
||||
const newBundle = await server.rotateClientKey('bob-phone');
|
||||
|
||||
// Export config (without secrets)
|
||||
const wgConf = await server.exportClientConfig('bob-phone', 'wireguard');
|
||||
|
||||
// Remove
|
||||
await server.removeClient('bob-phone');
|
||||
```
|
||||
|
||||
### 📝 WireGuard Config Generation
|
||||
|
||||
Generate standard `.conf` files for any WireGuard client:
|
||||
|
||||
```typescript
|
||||
import { WgConfigGenerator } from '@push.rocks/smartvpn';
|
||||
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: '<client-wg-private-key>',
|
||||
address: '10.8.0.2/24',
|
||||
dns: ['1.1.1.1'],
|
||||
peer: {
|
||||
publicKey: '<server-wg-public-key>',
|
||||
endpoint: 'vpn.example.com:51820',
|
||||
allowedIps: ['0.0.0.0/0'],
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
});
|
||||
// → standard WireGuard .conf compatible with wg-quick, iOS, Android
|
||||
```
|
||||
|
||||
### 🖥️ System Service Installation
|
||||
|
||||
```typescript
|
||||
import { VpnInstaller } from '@push.rocks/smartvpn';
|
||||
|
||||
const platform = VpnInstaller.detectPlatform(); // 'linux' | 'macos' | 'windows' | 'unknown'
|
||||
|
||||
// Linux (systemd)
|
||||
const unit = VpnInstaller.generateSystemdUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'server',
|
||||
});
|
||||
|
||||
// macOS (launchd)
|
||||
const plist = VpnInstaller.generateLaunchdPlist({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
mode: 'client',
|
||||
});
|
||||
|
||||
// Auto-detect platform
|
||||
const serviceUnit = VpnInstaller.generateServiceUnit({
|
||||
binaryPath: '/usr/local/bin/smartvpn_daemon',
|
||||
socketPath: '/var/run/smartvpn.sock',
|
||||
const unit = VpnInstaller.generateServiceUnit({
|
||||
mode: 'server',
|
||||
configPath: '/etc/smartvpn/server.json',
|
||||
});
|
||||
// unit.platform → 'linux' | 'macos'
|
||||
// unit.content → systemd unit file or launchd plist
|
||||
// unit.installPath → /etc/systemd/system/smartvpn-server.service
|
||||
```
|
||||
|
||||
### Events
|
||||
## API Reference 📖
|
||||
|
||||
Both `VpnClient` and `VpnServer` extend `EventEmitter`:
|
||||
### Classes
|
||||
|
||||
| Class | Description |
|
||||
|-------|-------------|
|
||||
| `VpnServer` | Manages the Rust daemon in server mode. Hub methods for client CRUD. |
|
||||
| `VpnClient` | Manages the Rust daemon in client mode. Connect, disconnect, telemetry. |
|
||||
| `VpnBridge<T>` | Low-level typed IPC bridge (stdio or Unix socket). |
|
||||
| `VpnConfig` | Static config validation and file I/O. |
|
||||
| `VpnInstaller` | Generates systemd/launchd service files. |
|
||||
| `WgConfigGenerator` | Generates standard WireGuard `.conf` files. |
|
||||
|
||||
### Key Interfaces
|
||||
|
||||
| Interface | Purpose |
|
||||
|-----------|---------|
|
||||
| `IVpnServerConfig` | Server configuration (listen addr, keys, subnet, transport mode, clients, proxy protocol) |
|
||||
| `IVpnClientConfig` | Client configuration (server URL, keys, transport, WG options) |
|
||||
| `IClientEntry` | Server-side client definition (ID, keys, security, priority, tags, expiry) |
|
||||
| `IClientSecurity` | Per-client ACLs and rate limits (SmartProxy-aligned naming) |
|
||||
| `IClientRateLimit` | Rate limiting config (bytesPerSec, burstBytes) |
|
||||
| `IClientConfigBundle` | Full config bundle returned by `createClient()` |
|
||||
| `IVpnClientInfo` | Connected client info (IP, stats, authenticated key, remote addr) |
|
||||
| `IVpnConnectionQuality` | RTT, jitter, loss ratio, link health |
|
||||
| `IVpnKeypair` | Base64-encoded public/private key pair |
|
||||
|
||||
### Server IPC Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `start` / `stop` | Start/stop the VPN listener |
|
||||
| `createClient` | Generate keys, assign IP, return config bundle |
|
||||
| `removeClient` / `getClient` / `listRegisteredClients` | Client registry CRUD |
|
||||
| `updateClient` / `enableClient` / `disableClient` | Modify client state |
|
||||
| `rotateClientKey` | Fresh keypairs + new config bundle |
|
||||
| `exportClientConfig` | Re-export as SmartVPN config or WireGuard `.conf` |
|
||||
| `listClients` / `disconnectClient` | Manage live connections |
|
||||
| `setClientRateLimit` / `removeClientRateLimit` | Runtime rate limit adjustments |
|
||||
| `getStatus` / `getStatistics` / `getClientTelemetry` | Monitoring |
|
||||
| `generateKeypair` / `generateWgKeypair` / `generateClientKeypair` | Key generation |
|
||||
| `addWgPeer` / `removeWgPeer` / `listWgPeers` | WireGuard peer management |
|
||||
|
||||
### Client IPC Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `connect` / `disconnect` | Manage the tunnel |
|
||||
| `getStatus` / `getStatistics` | Connection state and traffic stats |
|
||||
| `getConnectionQuality` | RTT, jitter, loss, link health |
|
||||
| `getMtuInfo` | MTU and overhead details |
|
||||
|
||||
## Transport Modes 🔀
|
||||
|
||||
### Server Configuration
|
||||
|
||||
```typescript
|
||||
client.on('exit', ({ code, signal }) => { /* daemon exited */ });
|
||||
client.on('reconnected', () => { /* socket reconnected */ });
|
||||
// WebSocket only
|
||||
{ transportMode: 'websocket', listenAddr: '0.0.0.0:443' }
|
||||
|
||||
server.on('client-connected', (info) => { /* IVpnClientInfo */ });
|
||||
server.on('client-disconnected', ({ clientId, reason }) => { /* ... */ });
|
||||
// QUIC only
|
||||
{ transportMode: 'quic', listenAddr: '0.0.0.0:443' }
|
||||
|
||||
// Both (WS + QUIC on same or different ports)
|
||||
{ transportMode: 'both', listenAddr: '0.0.0.0:443', quicListenAddr: '0.0.0.0:4433' }
|
||||
|
||||
// WireGuard
|
||||
{ transportMode: 'wireguard', wgListenPort: 51820, wgPeers: [...] }
|
||||
```
|
||||
|
||||
## 📊 QoS System
|
||||
|
||||
The Rust daemon includes a full QoS stack that operates on decrypted IP packets:
|
||||
|
||||
### Adaptive Keepalive
|
||||
|
||||
The keepalive system automatically adjusts its interval based on connection quality:
|
||||
|
||||
| Link Health | Keepalive Interval | Triggered When |
|
||||
|-------------|-------------------|----------------|
|
||||
| 🟢 Healthy | 60s | Jitter < 30ms, loss < 2%, no timeouts |
|
||||
| 🟡 Degraded | 30s | Jitter > 50ms, loss > 5%, or 1+ timeout |
|
||||
| 🔴 Critical | 10s | Loss > 20% or 2+ consecutive timeouts |
|
||||
|
||||
State transitions include hysteresis (3 consecutive good checks to upgrade, 2 to recover) to prevent flapping. Dead peer detection fires after 3 consecutive timeouts in Critical state.
|
||||
|
||||
### Packet Classification
|
||||
|
||||
IP packets are classified into three priority levels by inspecting headers (no deep packet inspection):
|
||||
|
||||
| Priority | Traffic |
|
||||
|----------|---------|
|
||||
| **High** | ICMP, DNS (port 53), SSH (port 22), small packets (< 128 bytes) |
|
||||
| **Normal** | Everything else |
|
||||
| **Low** | Bulk flows exceeding 1 MB within a 60s window |
|
||||
|
||||
Priority channels drain with biased `tokio::select!` — high-priority packets always go first.
|
||||
|
||||
### Smart Packet Dropping
|
||||
|
||||
Under backpressure, packets are dropped intelligently:
|
||||
|
||||
1. **Low** queue full → drop silently
|
||||
2. **Normal** queue full → drop
|
||||
3. **High** queue full → wait 5ms, then drop as last resort
|
||||
|
||||
Drop statistics are tracked per priority level and exposed via telemetry.
|
||||
|
||||
### Per-Client Rate Limiting
|
||||
|
||||
Token bucket algorithm with byte granularity:
|
||||
### Client Configuration
|
||||
|
||||
```typescript
|
||||
// Set: 10 MB/s sustained, 20 MB burst
|
||||
await server.setClientRateLimit('client-id', 10_000_000, 20_000_000);
|
||||
// Auto (tries QUIC first, falls back to WS)
|
||||
{ transport: 'auto', serverUrl: 'wss://vpn.example.com' }
|
||||
|
||||
// Check drops via telemetry
|
||||
const t = await server.getClientTelemetry('client-id');
|
||||
console.log(`Dropped: ${t.packetsDropped} packets, ${t.bytesDropped} bytes`);
|
||||
// Explicit QUIC with certificate pinning
|
||||
{ transport: 'quic', serverUrl: '1.2.3.4:4433', serverCertHash: '<sha256-base64>' }
|
||||
|
||||
// Remove limit
|
||||
await server.removeClientRateLimit('client-id');
|
||||
// WireGuard
|
||||
{ transport: 'wireguard', wgPrivateKey: '...', wgEndpoint: 'vpn.example.com:51820', ... }
|
||||
```
|
||||
|
||||
Rate limits can be changed live without disconnecting the client.
|
||||
## Cryptography 🔑
|
||||
|
||||
### Path MTU
|
||||
| Layer | Algorithm | Purpose |
|
||||
|-------|-----------|---------|
|
||||
| **Handshake** | Noise IK (X25519 + ChaChaPoly + BLAKE2s) | Mutual authentication + key exchange |
|
||||
| **Transport** | Noise transport state (ChaChaPoly) | All post-handshake data encryption |
|
||||
| **Additional** | XChaCha20-Poly1305 | Extended nonce space for data-at-rest |
|
||||
| **WireGuard** | X25519 + ChaCha20-Poly1305 (via boringtun) | Standard WireGuard crypto |
|
||||
|
||||
Tunnel overhead is calculated precisely:
|
||||
## Binary Protocol 📡
|
||||
|
||||
| Layer | Bytes |
|
||||
|-------|-------|
|
||||
| IP header | 20 |
|
||||
| TCP header (with timestamps) | 32 |
|
||||
| WebSocket framing | 6 |
|
||||
| VPN frame header | 5 |
|
||||
| Noise AEAD tag | 16 |
|
||||
| **Total overhead** | **79** |
|
||||
All frames use `[type:1B][length:4B][payload:NB]` with a 64KB max payload:
|
||||
|
||||
For a standard 1500-byte Ethernet link, effective TUN MTU = **1421 bytes**. The default TUN MTU of 1420 is conservative and correct. Oversized packets get an ICMP "Fragmentation Needed" (Type 3, Code 4) written back into the TUN, so the source TCP adjusts its MSS automatically.
|
||||
| Type | Hex | Direction | Description |
|
||||
|------|-----|-----------|-------------|
|
||||
| HandshakeInit | `0x01` | Client → Server | Noise IK first message |
|
||||
| HandshakeResp | `0x02` | Server → Client | Noise IK response |
|
||||
| IpPacket | `0x10` | Bidirectional | Encrypted tunnel data |
|
||||
| Keepalive | `0x20` | Client → Server | App-level keepalive (not WS ping) |
|
||||
| KeepaliveAck | `0x21` | Server → Client | Keepalive response with RTT payload |
|
||||
| Disconnect | `0x3F` | Bidirectional | Graceful disconnect |
|
||||
|
||||
## 🔐 Security Model
|
||||
|
||||
The VPN uses a **Noise NK** handshake pattern:
|
||||
|
||||
1. **NK** = client does **N**ot authenticate, but **K**nows the server's static public key
|
||||
2. The client generates an ephemeral keypair, performs `e, es` (DH with server's static key)
|
||||
3. Server responds with `e, ee` (DH with both ephemeral keys)
|
||||
4. Result: forward-secret transport keys derived from both DH operations
|
||||
|
||||
Post-handshake, all IP packets are encrypted with **XChaCha20-Poly1305**:
|
||||
- 24-byte random nonces (no counter synchronization needed)
|
||||
- 16-byte authentication tags
|
||||
- Wire format: `[nonce:24B][ciphertext:var][tag:16B]`
|
||||
|
||||
## 📦 Binary Protocol
|
||||
|
||||
Inside the WebSocket tunnel, packets use a simple binary framing:
|
||||
|
||||
```
|
||||
┌──────────┬──────────┬────────────────────┐
|
||||
│ Type (1B)│ Len (4B) │ Payload (variable) │
|
||||
└──────────┴──────────┴────────────────────┘
|
||||
```
|
||||
|
||||
| Type | Value | Description |
|
||||
|------|-------|-------------|
|
||||
| `HandshakeInit` | `0x01` | Client → Server handshake |
|
||||
| `HandshakeResp` | `0x02` | Server → Client handshake |
|
||||
| `IpPacket` | `0x10` | Encrypted IP packet |
|
||||
| `Keepalive` | `0x20` | App-level ping (8-byte timestamp payload) |
|
||||
| `KeepaliveAck` | `0x21` | App-level pong (echoes timestamp for RTT) |
|
||||
| `SessionResume` | `0x30` | Resume a dropped session |
|
||||
| `SessionResumeOk` | `0x31` | Resume accepted |
|
||||
| `SessionResumeErr` | `0x32` | Resume rejected |
|
||||
| `Disconnect` | `0x3F` | Graceful disconnect |
|
||||
|
||||
## 🛠️ Rust Daemon CLI
|
||||
|
||||
```bash
|
||||
# Development: stdio management (JSON lines on stdin/stdout)
|
||||
smartvpn_daemon --management --mode client
|
||||
smartvpn_daemon --management --mode server
|
||||
|
||||
# Production: Unix socket management
|
||||
smartvpn_daemon --management-socket /var/run/smartvpn.sock --mode server
|
||||
|
||||
# Generate a Noise keypair
|
||||
smartvpn_daemon --generate-keypair
|
||||
```
|
||||
|
||||
## 🔧 Building from Source
|
||||
## Development 🛠️
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pnpm install
|
||||
|
||||
# Build TypeScript + cross-compile Rust (amd64 + arm64)
|
||||
# Build (TypeScript + Rust cross-compile)
|
||||
pnpm build
|
||||
|
||||
# Build Rust only (debug)
|
||||
cd rust && cargo build
|
||||
|
||||
# Run all tests (71 Rust + 32 TypeScript)
|
||||
cd rust && cargo test
|
||||
# Run all tests (79 TS + 129 Rust = 208 tests)
|
||||
pnpm test
|
||||
|
||||
# Run Rust tests directly
|
||||
cd rust && cargo test
|
||||
|
||||
# Run a specific TS test
|
||||
tstest test/test.flowcontrol.node.ts --verbose
|
||||
```
|
||||
|
||||
## TypeScript Interfaces
|
||||
### Project Structure
|
||||
|
||||
<details>
|
||||
<summary>Click to expand full type definitions</summary>
|
||||
|
||||
```typescript
|
||||
// Transport options
|
||||
type TVpnTransportOptions =
|
||||
| { transport: 'stdio' }
|
||||
| {
|
||||
transport: 'socket';
|
||||
socketPath: string;
|
||||
autoReconnect?: boolean;
|
||||
reconnectBaseDelayMs?: number;
|
||||
reconnectMaxDelayMs?: number;
|
||||
maxReconnectAttempts?: number;
|
||||
};
|
||||
|
||||
// Client config
|
||||
interface IVpnClientConfig {
|
||||
serverUrl: string;
|
||||
serverPublicKey: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
}
|
||||
|
||||
// Server config
|
||||
interface IVpnServerConfig {
|
||||
listenAddr: string;
|
||||
privateKey: string;
|
||||
publicKey: string;
|
||||
subnet: string;
|
||||
tlsCert?: string;
|
||||
tlsKey?: string;
|
||||
dns?: string[];
|
||||
mtu?: number;
|
||||
keepaliveIntervalSecs?: number;
|
||||
enableNat?: boolean;
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
defaultBurstBytes?: number;
|
||||
}
|
||||
|
||||
// Status
|
||||
type TVpnConnectionState = 'disconnected' | 'connecting' | 'handshaking'
|
||||
| 'connected' | 'reconnecting' | 'error';
|
||||
|
||||
interface IVpnStatus {
|
||||
state: TVpnConnectionState;
|
||||
assignedIp?: string;
|
||||
serverAddr?: string;
|
||||
connectedSince?: string;
|
||||
lastError?: string;
|
||||
}
|
||||
|
||||
// Statistics
|
||||
interface IVpnStatistics {
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsSent: number;
|
||||
packetsReceived: number;
|
||||
keepalivesSent: number;
|
||||
keepalivesReceived: number;
|
||||
uptimeSeconds: number;
|
||||
quality?: IVpnConnectionQuality;
|
||||
}
|
||||
|
||||
interface IVpnServerStatistics extends IVpnStatistics {
|
||||
activeClients: number;
|
||||
totalConnections: number;
|
||||
}
|
||||
|
||||
// Connection quality (QoS)
|
||||
type TVpnLinkHealth = 'healthy' | 'degraded' | 'critical';
|
||||
|
||||
interface IVpnConnectionQuality {
|
||||
srttMs: number;
|
||||
jitterMs: number;
|
||||
minRttMs: number;
|
||||
maxRttMs: number;
|
||||
lossRatio: number;
|
||||
consecutiveTimeouts: number;
|
||||
linkHealth: TVpnLinkHealth;
|
||||
currentKeepaliveIntervalSecs: number;
|
||||
}
|
||||
|
||||
// MTU info
|
||||
interface IVpnMtuInfo {
|
||||
tunMtu: number;
|
||||
effectiveMtu: number;
|
||||
linkMtu: number;
|
||||
overheadBytes: number;
|
||||
oversizedPacketsDropped: number;
|
||||
icmpTooBigSent: number;
|
||||
}
|
||||
|
||||
// Client info (with QoS fields)
|
||||
interface IVpnClientInfo {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
connectedSince: string;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
// Per-client telemetry
|
||||
interface IVpnClientTelemetry {
|
||||
clientId: string;
|
||||
assignedIp: string;
|
||||
lastKeepaliveAt?: string;
|
||||
keepalivesReceived: number;
|
||||
packetsDropped: number;
|
||||
bytesDropped: number;
|
||||
bytesReceived: number;
|
||||
bytesSent: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
interface IVpnKeypair {
|
||||
publicKey: string;
|
||||
privateKey: string;
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
smartvpn/
|
||||
├── ts/ # TypeScript control plane
|
||||
│ ├── index.ts # All exports
|
||||
│ ├── smartvpn.interfaces.ts # Interfaces, types, IPC command maps
|
||||
│ ├── smartvpn.classes.vpnserver.ts
|
||||
│ ├── smartvpn.classes.vpnclient.ts
|
||||
│ ├── smartvpn.classes.vpnbridge.ts
|
||||
│ ├── smartvpn.classes.vpnconfig.ts
|
||||
│ ├── smartvpn.classes.vpninstaller.ts
|
||||
│ └── smartvpn.classes.wgconfig.ts
|
||||
├── rust/ # Rust data plane daemon
|
||||
│ └── src/
|
||||
│ ├── main.rs # CLI entry point
|
||||
│ ├── server.rs # VPN server + hub methods
|
||||
│ ├── client.rs # VPN client
|
||||
│ ├── crypto.rs # Noise IK + XChaCha20
|
||||
│ ├── client_registry.rs # Client database
|
||||
│ ├── acl.rs # ACL engine
|
||||
│ ├── proxy_protocol.rs # PROXY protocol v2 parser
|
||||
│ ├── management.rs # JSON-lines IPC
|
||||
│ ├── transport.rs # WebSocket transport
|
||||
│ ├── quic_transport.rs # QUIC transport
|
||||
│ ├── wireguard.rs # WireGuard (boringtun)
|
||||
│ ├── codec.rs # Binary frame protocol
|
||||
│ ├── keepalive.rs # Adaptive keepalives
|
||||
│ ├── ratelimit.rs # Token bucket
|
||||
│ └── ... # tunnel, network, telemetry, qos, mtu, reconnect
|
||||
├── test/ # 9 test files (79 tests)
|
||||
├── dist_ts/ # Compiled TypeScript
|
||||
└── dist_rust/ # Cross-compiled binaries (linux amd64 + arm64)
|
||||
```
|
||||
|
||||
## License and Legal Information
|
||||
|
||||
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./LICENSE) file.
|
||||
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [license](./license.md) file.
|
||||
|
||||
**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.
|
||||
|
||||
|
||||
253
readme.plan.md
Normal file
253
readme.plan.md
Normal file
@@ -0,0 +1,253 @@
|
||||
# PROXY Protocol v2 Support for SmartVPN WebSocket Transport
|
||||
|
||||
## Context
|
||||
|
||||
SmartVPN's WebSocket transport is designed to sit behind reverse proxies (Cloudflare, HAProxy, SmartProxy). The recently added ACL engine has `ipAllowList`/`ipBlockList` per client, but without PROXY protocol support the server only sees the proxy's IP — not the real client's. This makes source-IP ACLs useless behind a proxy.
|
||||
|
||||
PROXY protocol v2 solves this by letting the proxy prepend a binary header with the real client IP/port before the WebSocket upgrade.
|
||||
|
||||
---
|
||||
|
||||
## Design
|
||||
|
||||
### Two-Phase ACL with Real Client IP
|
||||
|
||||
```
|
||||
TCP accept → Read PP v2 header → Extract real IP
|
||||
│
|
||||
├─ Phase 1 (pre-handshake): Check server-level connectionIpBlockList → reject early
|
||||
│
|
||||
├─ WebSocket upgrade → Noise IK handshake → Client identity known
|
||||
│
|
||||
└─ Phase 2 (post-handshake): Check per-client ipAllowList/ipBlockList → reject if denied
|
||||
```
|
||||
|
||||
- **Phase 1**: Server-wide block list (`connectionIpBlockList` on `IVpnServerConfig`). Rejects before any crypto work. Protects server resources.
|
||||
- **Phase 2**: Per-client ACL from `IClientSecurity.ipAllowList`/`ipBlockList`. Applied after the Noise IK handshake identifies the client.
|
||||
|
||||
### No New Dependencies
|
||||
|
||||
PROXY protocol v2 is a fixed-format binary header (16-byte signature + variable address block). Manual parsing (~80 lines) follows the same pattern as `codec.rs`. No crate needed.
|
||||
|
||||
### Scope: WebSocket Only
|
||||
|
||||
- **WebSocket**: Needs PP v2 (sits behind reverse proxies)
|
||||
- **QUIC**: Direct UDP, just use `conn.remote_address()`
|
||||
- **WireGuard**: Direct UDP, uses boringtun peer tracking
|
||||
|
||||
---
|
||||
|
||||
## Implementation
|
||||
|
||||
### Phase 1: New Rust module `proxy_protocol.rs`
|
||||
|
||||
**New file: `rust/src/proxy_protocol.rs`**
|
||||
|
||||
PP v2 binary format:
|
||||
```
|
||||
Bytes 0-11: Signature \x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A
|
||||
Byte 12: Version (high nibble = 0x2) | Command (low nibble: 0x0=LOCAL, 0x1=PROXY)
|
||||
Byte 13: Address family | Protocol (0x11 = IPv4/TCP, 0x21 = IPv6/TCP)
|
||||
Bytes 14-15: Address data length (big-endian u16)
|
||||
Bytes 16+: IPv4: 4 src_ip + 4 dst_ip + 2 src_port + 2 dst_port (12 bytes)
|
||||
IPv6: 16 src_ip + 16 dst_ip + 2 src_port + 2 dst_port (36 bytes)
|
||||
```
|
||||
|
||||
```rust
|
||||
pub struct ProxyHeader {
|
||||
pub src_addr: SocketAddr,
|
||||
pub dst_addr: SocketAddr,
|
||||
pub is_local: bool, // LOCAL command = health check probe
|
||||
}
|
||||
|
||||
/// Read and parse a PROXY protocol v2 header from a TCP stream.
|
||||
/// Reads exactly the header bytes — the stream is clean for WS upgrade after.
|
||||
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader>
|
||||
```
|
||||
|
||||
- 5-second timeout on header read (constant `PROXY_HEADER_TIMEOUT`)
|
||||
- Validates 12-byte signature, version nibble, command type
|
||||
- Parses IPv4 and IPv6 address blocks
|
||||
- LOCAL command returns `is_local: true` (caller closes connection gracefully)
|
||||
- Unit tests: valid IPv4/IPv6 headers, LOCAL command, invalid signature, truncated data
|
||||
|
||||
**Modify: `rust/src/lib.rs`** — add `pub mod proxy_protocol;`
|
||||
|
||||
### Phase 2: Server config + client info fields
|
||||
|
||||
**File: `rust/src/server.rs` — `ServerConfig`**
|
||||
|
||||
Add:
|
||||
```rust
|
||||
/// Enable PROXY protocol v2 parsing on WebSocket connections.
|
||||
/// SECURITY: Must be false when accepting direct client connections.
|
||||
pub proxy_protocol: Option<bool>,
|
||||
/// Server-level IP block list — applied at TCP accept time, before Noise handshake.
|
||||
pub connection_ip_block_list: Option<Vec<String>>,
|
||||
```
|
||||
|
||||
**File: `rust/src/server.rs` — `ClientInfo`**
|
||||
|
||||
Add:
|
||||
```rust
|
||||
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
|
||||
pub remote_addr: Option<String>,
|
||||
```
|
||||
|
||||
### Phase 3: ACL helper
|
||||
|
||||
**File: `rust/src/acl.rs`**
|
||||
|
||||
Add a public function for the server-level pre-handshake check:
|
||||
```rust
|
||||
/// Check whether a connection source IP is in a block list.
|
||||
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
|
||||
ip_matches_any(ip, block_list)
|
||||
}
|
||||
```
|
||||
|
||||
(Keeps `ip_matches_any` private; exposes only the specific check needed.)
|
||||
|
||||
### Phase 4: WebSocket listener integration
|
||||
|
||||
**File: `rust/src/server.rs` — `run_ws_listener()`**
|
||||
|
||||
Between `listener.accept()` and `transport::accept_connection()`:
|
||||
|
||||
```rust
|
||||
// Determine real client address
|
||||
let remote_addr = if state.config.proxy_protocol.unwrap_or(false) {
|
||||
match proxy_protocol::read_proxy_header(&mut tcp_stream).await {
|
||||
Ok(header) if header.is_local => {
|
||||
// Health check probe — close gracefully
|
||||
return;
|
||||
}
|
||||
Ok(header) => {
|
||||
info!("PP v2: real client {} -> {}", header.src_addr, header.dst_addr);
|
||||
Some(header.src_addr)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("PP v2 parse failed from {}: {}", tcp_addr, e);
|
||||
return; // Drop connection
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Some(tcp_addr) // Direct connection — use TCP SocketAddr
|
||||
};
|
||||
|
||||
// Pre-handshake server-level block list check
|
||||
if let (Some(ref block_list), Some(ref addr)) = (&state.config.connection_ip_block_list, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, block_list) {
|
||||
warn!("Connection blocked by server IP block list: {}", addr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Then proceed with WS upgrade + handle_client_connection as before
|
||||
```
|
||||
|
||||
Key correctness note: `read_proxy_header` reads *exactly* the PP header bytes via `read_exact`. The `TcpStream` is then in a clean state for the WS HTTP upgrade. No buffered wrapper needed.
|
||||
|
||||
### Phase 5: Update `handle_client_connection` signature
|
||||
|
||||
**File: `rust/src/server.rs`**
|
||||
|
||||
Change signature:
|
||||
```rust
|
||||
async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
remote_addr: Option<std::net::SocketAddr>, // NEW
|
||||
) -> Result<()>
|
||||
```
|
||||
|
||||
After Noise IK handshake + registry lookup (where `client_security` is available), add connection-level per-client ACL:
|
||||
|
||||
```rust
|
||||
if let (Some(ref sec), Some(addr)) = (&client_security, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, sec.ip_block_list.as_deref().unwrap_or(&[])) {
|
||||
anyhow::bail!("Client {} connection denied: source IP {} blocked", registered_client_id, addr);
|
||||
}
|
||||
if let Some(ref allow) = sec.ip_allow_list {
|
||||
if !allow.is_empty() && !acl::is_ip_allowed(v4, allow) {
|
||||
anyhow::bail!("Client {} connection denied: source IP {} not in allow list", registered_client_id, addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Populate `remote_addr` when building `ClientInfo`:
|
||||
```rust
|
||||
remote_addr: remote_addr.map(|a| a.to_string()),
|
||||
```
|
||||
|
||||
### Phase 6: QUIC listener — pass remote addr through
|
||||
|
||||
**File: `rust/src/server.rs` — `run_quic_listener()`**
|
||||
|
||||
QUIC doesn't use PROXY protocol. Just pass `conn.remote_address()` through:
|
||||
```rust
|
||||
let remote = conn.remote_address();
|
||||
// ...
|
||||
handle_client_connection(state, Box::new(sink), Box::new(stream), Some(remote)).await
|
||||
```
|
||||
|
||||
### Phase 7: TypeScript interface updates
|
||||
|
||||
**File: `ts/smartvpn.interfaces.ts`**
|
||||
|
||||
Add to `IVpnServerConfig`:
|
||||
```typescript
|
||||
/** Enable PROXY protocol v2 on incoming WebSocket connections.
|
||||
* Required when behind a reverse proxy that sends PP v2 headers. */
|
||||
proxyProtocol?: boolean;
|
||||
/** Server-level IP block list — applied at TCP accept time, before Noise handshake. */
|
||||
connectionIpBlockList?: string[];
|
||||
```
|
||||
|
||||
Add to `IVpnClientInfo`:
|
||||
```typescript
|
||||
/** Real client IP:port (from PROXY protocol or direct TCP). */
|
||||
remoteAddr?: string;
|
||||
```
|
||||
|
||||
### Phase 8: Tests
|
||||
|
||||
**Rust unit tests in `proxy_protocol.rs`:**
|
||||
- `parse_valid_ipv4_header` — construct a valid PP v2 header with known IPs, verify parsed correctly
|
||||
- `parse_valid_ipv6_header` — same for IPv6
|
||||
- `parse_local_command` — health check probe returns `is_local: true`
|
||||
- `reject_invalid_signature` — random bytes rejected
|
||||
- `reject_truncated_header` — short reads fail gracefully
|
||||
- `reject_v1_header` — PROXY v1 text format rejected (we only support v2)
|
||||
|
||||
**Rust unit tests in `acl.rs`:**
|
||||
- `is_connection_blocked` with various IP patterns
|
||||
|
||||
**TypeScript tests:**
|
||||
- Config validation accepts `proxyProtocol: true` + `connectionIpBlockList`
|
||||
|
||||
---
|
||||
|
||||
## Key Files to Modify
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `rust/src/proxy_protocol.rs` | **NEW** — PP v2 parser + tests |
|
||||
| `rust/src/lib.rs` | Add `pub mod proxy_protocol;` |
|
||||
| `rust/src/server.rs` | `ServerConfig` + `ClientInfo` fields, `run_ws_listener` PP integration, `handle_client_connection` signature + connection ACL, `run_quic_listener` pass-through |
|
||||
| `rust/src/acl.rs` | Add `is_connection_blocked` public function |
|
||||
| `ts/smartvpn.interfaces.ts` | `proxyProtocol`, `connectionIpBlockList`, `remoteAddr` |
|
||||
|
||||
---
|
||||
|
||||
## Verification
|
||||
|
||||
1. `cargo test` — all existing 121 tests + new PP parser tests pass
|
||||
2. `pnpm test` — all 79 TS tests pass (no PP in test setup, just config validation)
|
||||
3. Manual: `socat` or test harness to send a PP v2 header before WS upgrade, verify server logs real IP
|
||||
793
rust/Cargo.lock
generated
793
rust/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,18 @@ tun = { version = "0.7", features = ["async"] }
|
||||
bytes = "1"
|
||||
tokio-util = "0.7"
|
||||
futures-util = "0.3"
|
||||
async-trait = "0.1"
|
||||
quinn = "0.11"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
rcgen = "0.13"
|
||||
ring = "0.17"
|
||||
rustls-pki-types = "1"
|
||||
rustls-pemfile = "2"
|
||||
webpki-roots = "1"
|
||||
mimalloc = "0.1"
|
||||
boringtun = "0.7"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
ipnet = "2"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
|
||||
302
rust/src/acl.rs
Normal file
302
rust/src/acl.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
use std::net::Ipv4Addr;
|
||||
use ipnet::Ipv4Net;
|
||||
|
||||
use crate::client_registry::ClientSecurity;
|
||||
|
||||
/// Result of an ACL check.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AclResult {
|
||||
Allow,
|
||||
DenySrc,
|
||||
DenyDst,
|
||||
}
|
||||
|
||||
/// Check whether a connection source IP is in a server-level block list.
|
||||
/// Used for pre-handshake rejection of known-bad IPs.
|
||||
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
|
||||
ip_matches_any(ip, block_list)
|
||||
}
|
||||
|
||||
/// Check whether a source IP is allowed by allow/block lists.
|
||||
/// Returns true if the IP is permitted (not blocked and passes allow check).
|
||||
pub fn is_source_allowed(ip: Ipv4Addr, allow_list: Option<&[String]>, block_list: Option<&[String]>) -> bool {
|
||||
// Deny overrides allow
|
||||
if let Some(bl) = block_list {
|
||||
if ip_matches_any(ip, bl) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If allow list exists and is non-empty, IP must match
|
||||
if let Some(al) = allow_list {
|
||||
if !al.is_empty() && !ip_matches_any(ip, al) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy.
|
||||
///
|
||||
/// Evaluation order (deny overrides allow):
|
||||
/// 1. If src_ip is in ip_block_list → DenySrc
|
||||
/// 2. If dst_ip is in destination_block_list → DenyDst
|
||||
/// 3. If ip_allow_list is non-empty and src_ip is NOT in it → DenySrc
|
||||
/// 4. If destination_allow_list is non-empty and dst_ip is NOT in it → DenyDst
|
||||
/// 5. Otherwise → Allow
|
||||
pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult {
|
||||
// Step 1: Check source block list (deny overrides)
|
||||
if let Some(ref block_list) = security.ip_block_list {
|
||||
if ip_matches_any(src_ip, block_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Check destination block list (deny overrides)
|
||||
if let Some(ref block_list) = security.destination_block_list {
|
||||
if ip_matches_any(dst_ip, block_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Check source allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.ip_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(src_ip, allow_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Check destination allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.destination_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(dst_ip, allow_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
AclResult::Allow
|
||||
}
|
||||
|
||||
/// Check if `ip` matches any pattern in the list.
|
||||
/// Supports: exact IP, CIDR notation, wildcard patterns (192.168.1.*),
|
||||
/// and IP ranges (192.168.1.1-192.168.1.100).
|
||||
fn ip_matches_any(ip: Ipv4Addr, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if ip_matches(ip, pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if `ip` matches a single pattern.
|
||||
fn ip_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let pattern = pattern.trim();
|
||||
|
||||
// CIDR notation (e.g. 192.168.1.0/24)
|
||||
if pattern.contains('/') {
|
||||
if let Ok(net) = pattern.parse::<Ipv4Net>() {
|
||||
return net.contains(&ip);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// IP range (e.g. 192.168.1.1-192.168.1.100)
|
||||
if pattern.contains('-') {
|
||||
let parts: Vec<&str> = pattern.splitn(2, '-').collect();
|
||||
if parts.len() == 2 {
|
||||
if let (Ok(start), Ok(end)) = (parts[0].trim().parse::<Ipv4Addr>(), parts[1].trim().parse::<Ipv4Addr>()) {
|
||||
let ip_u32 = u32::from(ip);
|
||||
return ip_u32 >= u32::from(start) && ip_u32 <= u32::from(end);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wildcard pattern (e.g. 192.168.1.*)
|
||||
if pattern.contains('*') {
|
||||
return wildcard_matches(ip, pattern);
|
||||
}
|
||||
|
||||
// Exact IP match
|
||||
if let Ok(exact) = pattern.parse::<Ipv4Addr>() {
|
||||
return ip == exact;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Match an IP against a wildcard pattern like "192.168.1.*" or "10.*.*.*".
|
||||
fn wildcard_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let ip_octets = ip.octets();
|
||||
let pattern_parts: Vec<&str> = pattern.split('.').collect();
|
||||
if pattern_parts.len() != 4 {
|
||||
return false;
|
||||
}
|
||||
for (i, part) in pattern_parts.iter().enumerate() {
|
||||
if *part == "*" {
|
||||
continue;
|
||||
}
|
||||
if let Ok(octet) = part.parse::<u8>() {
|
||||
if ip_octets[i] != octet {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client_registry::{ClientRateLimit, ClientSecurity};
|
||||
|
||||
fn security(
|
||||
ip_allow: Option<Vec<&str>>,
|
||||
ip_block: Option<Vec<&str>>,
|
||||
dst_allow: Option<Vec<&str>>,
|
||||
dst_block: Option<Vec<&str>>,
|
||||
) -> ClientSecurity {
|
||||
ClientSecurity {
|
||||
ip_allow_list: ip_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
ip_block_list: ip_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_allow_list: dst_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_block_list: dst_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
max_connections: None,
|
||||
rate_limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ip(s: &str) -> Ipv4Addr {
|
||||
s.parse().unwrap()
|
||||
}
|
||||
|
||||
// ── No restrictions (empty security) ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn empty_security_allows_all() {
|
||||
let sec = security(None, None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_lists_allow_all() {
|
||||
let sec = security(Some(vec![]), Some(vec![]), Some(vec![]), Some(vec![]));
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
// ── Source IP allow list ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_allow_exact_match() {
|
||||
let sec = security(Some(vec!["10.0.0.1"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.2"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_cidr() {
|
||||
let sec = security(Some(vec!["192.168.1.0/24"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.2.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_wildcard() {
|
||||
let sec = security(Some(vec!["10.0.*.*"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.5.3"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.1.0.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_range() {
|
||||
let sec = security(Some(vec!["10.0.0.1-10.0.0.10"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.5"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.11"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Source IP block list (deny overrides) ───────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_block_overrides_allow() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
Some(vec!["192.168.1.100"]),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.100"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Destination allow list ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_allow_exact() {
|
||||
let sec = security(None, None, Some(vec!["8.8.8.8", "8.8.4.4"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dst_allow_cidr() {
|
||||
let sec = security(None, None, Some(vec!["10.0.0.0/8"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.5.3.2")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("172.16.0.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Destination block list (deny overrides) ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_block_overrides_allow() {
|
||||
let sec = security(
|
||||
None,
|
||||
None,
|
||||
Some(vec!["10.0.0.0/8"]),
|
||||
Some(vec!["10.0.0.99"]),
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.1")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.99")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Combined source + destination ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn combined_src_and_dst_filtering() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
None,
|
||||
Some(vec!["8.8.8.8"]),
|
||||
None,
|
||||
);
|
||||
// Valid source, valid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("8.8.8.8")), AclResult::Allow);
|
||||
// Invalid source
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::DenySrc);
|
||||
// Valid source, invalid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── IP matching edge cases ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wildcard_single_octet() {
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.*"));
|
||||
assert!(!ip_matches(ip("10.0.1.5"), "10.0.0.*"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range_boundaries() {
|
||||
assert!(ip_matches(ip("10.0.0.1"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.6"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.0"), "10.0.0.1-10.0.0.5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_pattern_no_match() {
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "not-an-ip"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0.1/99"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0"));
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,8 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn, debug};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
@@ -12,6 +10,8 @@ use crate::crypto;
|
||||
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
|
||||
use crate::telemetry::ConnectionQuality;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
|
||||
/// Client configuration (matches TS IVpnClientConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -19,9 +19,17 @@ use crate::transport;
|
||||
pub struct ClientConfig {
|
||||
pub server_url: String,
|
||||
pub server_public_key: String,
|
||||
/// Client's Noise IK static private key (base64) — required for authentication.
|
||||
pub client_private_key: String,
|
||||
/// Client's Noise IK static public key (base64) — for reference/display.
|
||||
pub client_public_key: String,
|
||||
pub dns: Option<Vec<String>>,
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
/// Transport type: "websocket" (default) or "quic".
|
||||
pub transport: Option<String>,
|
||||
/// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
|
||||
pub server_cert_hash: Option<String>,
|
||||
}
|
||||
|
||||
/// Client statistics.
|
||||
@@ -100,22 +108,83 @@ impl VpnClient {
|
||||
let connected_since = self.connected_since.clone();
|
||||
let link_health = self.link_health.clone();
|
||||
|
||||
// Decode server public key
|
||||
// Decode keys
|
||||
let server_pub_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.server_public_key,
|
||||
)?;
|
||||
let client_priv_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.client_private_key,
|
||||
)?;
|
||||
|
||||
// Connect to WebSocket server
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
// Create transport based on configuration
|
||||
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
|
||||
let transport_type = config.transport.as_deref().unwrap_or("auto");
|
||||
match transport_type {
|
||||
"quic" => {
|
||||
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
|
||||
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
info!("Connected via QUIC");
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
"websocket" => {
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
_ => {
|
||||
// "auto" (default): try QUIC first, fall back to WebSocket
|
||||
// Extract host:port from the URL for QUIC attempt
|
||||
let quic_addr = extract_host_port(&config.server_url);
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
|
||||
// Noise NK handshake (client side = initiator)
|
||||
if let Some(ref addr) = quic_addr {
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(3),
|
||||
try_quic_connect(addr, cert_hash),
|
||||
).await {
|
||||
Ok(Ok((quic_sink, quic_stream))) => {
|
||||
info!("Auto: connected via QUIC to {}", addr);
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC unavailable)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Auto: QUIC timed out, falling back to WebSocket");
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC timed out)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Can't extract host:port for QUIC, use WebSocket directly
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Noise IK handshake (client side = initiator, presents static key)
|
||||
*state.write().await = ClientState::Handshaking;
|
||||
let mut initiator = crypto::create_initiator(&server_pub_key)?;
|
||||
let mut initiator = crypto::create_initiator(&client_priv_key, &server_pub_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es
|
||||
// -> e, es, s, ss
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let init_frame = Frame {
|
||||
packet_type: PacketType::HandshakeInit,
|
||||
@@ -123,13 +192,11 @@ impl VpnClient {
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// <- e, ee
|
||||
let resp_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
|
||||
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
|
||||
// <- e, ee, se
|
||||
let resp_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed during handshake"),
|
||||
};
|
||||
|
||||
@@ -145,9 +212,9 @@ impl VpnClient {
|
||||
let mut noise_transport = initiator.into_transport_mode()?;
|
||||
|
||||
// Receive assigned IP info (encrypted)
|
||||
let info_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected IP info message"),
|
||||
let info_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before IP info"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&info_msg[..]);
|
||||
@@ -184,8 +251,8 @@ impl VpnClient {
|
||||
// Spawn packet forwarding loop
|
||||
let assigned_ip_clone = assigned_ip.clone();
|
||||
tokio::spawn(client_loop(
|
||||
ws_sink,
|
||||
ws_stream,
|
||||
sink,
|
||||
stream,
|
||||
noise_transport,
|
||||
state,
|
||||
stats,
|
||||
@@ -280,8 +347,8 @@ impl VpnClient {
|
||||
|
||||
/// The main client packet forwarding loop (runs in a spawned task).
|
||||
async fn client_loop(
|
||||
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
|
||||
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
mut noise_transport: snow::TransportState,
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
@@ -294,10 +361,10 @@ async fn client_loop(
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
Ok(Some(data)) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
@@ -328,17 +395,13 @@ async fn client_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
Ok(None) => {
|
||||
info!("Connection closed");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
let _ = ws_sink.send(Message::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
error!("WebSocket error: {}", e);
|
||||
Err(e) => {
|
||||
error!("Transport error: {}", e);
|
||||
*state.write().await = ClientState::Error(e.to_string());
|
||||
break;
|
||||
}
|
||||
@@ -354,7 +417,7 @@ async fn client_loop(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
@@ -385,12 +448,51 @@ async fn client_loop(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
|
||||
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
}
|
||||
let _ = ws_sink.close().await;
|
||||
let _ = sink.close().await;
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect via QUIC. Returns transport halves on success.
|
||||
async fn try_quic_connect(
|
||||
addr: &str,
|
||||
cert_hash: Option<&str>,
|
||||
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
|
||||
let conn = quic_transport::connect_quic(addr, cert_hash).await?;
|
||||
let (sink, stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
Ok((sink, stream))
|
||||
}
|
||||
|
||||
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
|
||||
/// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080")
|
||||
/// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443")
|
||||
/// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port)
|
||||
fn extract_host_port(url: &str) -> Option<String> {
|
||||
if url.starts_with("ws://") || url.starts_with("wss://") {
|
||||
// Parse as URL
|
||||
let stripped = if url.starts_with("wss://") {
|
||||
&url[6..]
|
||||
} else {
|
||||
&url[5..]
|
||||
};
|
||||
// Remove path
|
||||
let host_port = stripped.split('/').next()?;
|
||||
if host_port.contains(':') {
|
||||
Some(host_port.to_string())
|
||||
} else {
|
||||
// Default port
|
||||
let default_port = if url.starts_with("wss://") { 443 } else { 80 };
|
||||
Some(format!("{}:{}", host_port, default_port))
|
||||
}
|
||||
} else if url.contains(':') {
|
||||
// Already host:port
|
||||
Some(url.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
362
rust/src/client_registry.rs
Normal file
362
rust/src/client_registry.rs
Normal file
@@ -0,0 +1,362 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Per-client rate limiting configuration.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientRateLimit {
|
||||
pub bytes_per_sec: u64,
|
||||
pub burst_bytes: u64,
|
||||
}
|
||||
|
||||
/// Per-client security settings — aligned with SmartProxy's IRouteSecurity pattern.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientSecurity {
|
||||
/// Source IPs/CIDRs the client may connect FROM (empty/None = any).
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
/// Source IPs blocked — overrides ip_allow_list (deny wins).
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
/// Destination IPs/CIDRs the client may reach (empty/None = all).
|
||||
pub destination_allow_list: Option<Vec<String>>,
|
||||
/// Destination IPs blocked — overrides destination_allow_list (deny wins).
|
||||
pub destination_block_list: Option<Vec<String>>,
|
||||
/// Max concurrent connections from this client.
|
||||
pub max_connections: Option<u32>,
|
||||
/// Per-client rate limiting.
|
||||
pub rate_limit: Option<ClientRateLimit>,
|
||||
}
|
||||
|
||||
/// A registered client entry — the server-side source of truth.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientEntry {
|
||||
/// Human-readable client ID (e.g. "alice-laptop").
|
||||
pub client_id: String,
|
||||
/// Client's Noise IK public key (base64).
|
||||
pub public_key: String,
|
||||
/// Client's WireGuard public key (base64) — optional.
|
||||
pub wg_public_key: Option<String>,
|
||||
/// Security settings (ACLs, rate limits).
|
||||
pub security: Option<ClientSecurity>,
|
||||
/// Traffic priority (lower = higher priority, default: 100).
|
||||
pub priority: Option<u32>,
|
||||
/// Whether this client is enabled (default: true).
|
||||
pub enabled: Option<bool>,
|
||||
/// Tags for grouping.
|
||||
pub tags: Option<Vec<String>>,
|
||||
/// Optional description.
|
||||
pub description: Option<String>,
|
||||
/// Optional expiry (ISO 8601 timestamp).
|
||||
pub expires_at: Option<String>,
|
||||
/// Assigned VPN IP address.
|
||||
pub assigned_ip: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientEntry {
|
||||
/// Whether this client is considered enabled (defaults to true).
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Whether this client has expired based on current time.
|
||||
pub fn is_expired(&self) -> bool {
|
||||
if let Some(ref expires) = self.expires_at {
|
||||
if let Ok(expiry) = chrono::DateTime::parse_from_rfc3339(expires) {
|
||||
return chrono::Utc::now() > expiry;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory client registry with dual-key indexing.
|
||||
pub struct ClientRegistry {
|
||||
/// Primary index: clientId → ClientEntry
|
||||
entries: HashMap<String, ClientEntry>,
|
||||
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
|
||||
key_index: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ClientRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
key_index: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a registry from a list of client entries.
|
||||
pub fn from_entries(entries: Vec<ClientEntry>) -> Result<Self> {
|
||||
let mut registry = Self::new();
|
||||
for entry in entries {
|
||||
registry.add(entry)?;
|
||||
}
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
/// Add a client to the registry.
|
||||
pub fn add(&mut self, entry: ClientEntry) -> Result<()> {
|
||||
if self.entries.contains_key(&entry.client_id) {
|
||||
anyhow::bail!("Client '{}' already exists", entry.client_id);
|
||||
}
|
||||
if self.key_index.contains_key(&entry.public_key) {
|
||||
anyhow::bail!("Public key already registered to another client");
|
||||
}
|
||||
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
|
||||
self.entries.insert(entry.client_id.clone(), entry);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a client by ID.
|
||||
pub fn remove(&mut self, client_id: &str) -> Result<ClientEntry> {
|
||||
let entry = self.entries.remove(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
self.key_index.remove(&entry.public_key);
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
/// Get a client by ID.
|
||||
pub fn get_by_id(&self, client_id: &str) -> Option<&ClientEntry> {
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Get a client by public key (used during IK handshake verification).
|
||||
pub fn get_by_key(&self, public_key: &str) -> Option<&ClientEntry> {
|
||||
let client_id = self.key_index.get(public_key)?;
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Check if a public key is authorized (exists, enabled, not expired).
|
||||
pub fn is_authorized(&self, public_key: &str) -> bool {
|
||||
match self.get_by_key(public_key) {
|
||||
Some(entry) => entry.is_enabled() && !entry.is_expired(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a client entry. The closure receives a mutable reference to the entry.
|
||||
pub fn update<F>(&mut self, client_id: &str, updater: F) -> Result<()>
|
||||
where
|
||||
F: FnOnce(&mut ClientEntry),
|
||||
{
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
let old_key = entry.public_key.clone();
|
||||
updater(entry);
|
||||
// If public key changed, update the index
|
||||
if entry.public_key != old_key {
|
||||
self.key_index.remove(&old_key);
|
||||
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all client entries.
|
||||
pub fn list(&self) -> Vec<&ClientEntry> {
|
||||
self.entries.values().collect()
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns the updated entry.
|
||||
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
// Update key index
|
||||
self.key_index.remove(&entry.public_key);
|
||||
entry.public_key = new_public_key.clone();
|
||||
entry.wg_public_key = new_wg_public_key;
|
||||
self.key_index.insert(new_public_key, client_id.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of registered clients.
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Whether the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(id: &str, key: &str) -> ClientEntry {
|
||||
ClientEntry {
|
||||
client_id: id.to_string(),
|
||||
public_key: key.to_string(),
|
||||
wg_public_key: None,
|
||||
security: None,
|
||||
priority: None,
|
||||
enabled: None,
|
||||
tags: None,
|
||||
description: None,
|
||||
expires_at: None,
|
||||
assigned_ip: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_and_lookup() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
assert!(reg.get_by_id("alice").is_some());
|
||||
assert!(reg.get_by_key("key_alice").is_some());
|
||||
assert_eq!(reg.get_by_key("key_alice").unwrap().client_id, "alice");
|
||||
assert!(reg.get_by_id("bob").is_none());
|
||||
assert!(reg.get_by_key("key_bob").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_id() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key1")).unwrap();
|
||||
assert!(reg.add(make_entry("alice", "key2")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "same_key")).unwrap();
|
||||
assert!(reg.add(make_entry("bob", "same_key")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert_eq!(reg.len(), 1);
|
||||
|
||||
let removed = reg.remove("alice").unwrap();
|
||||
assert_eq!(removed.client_id, "alice");
|
||||
assert_eq!(reg.len(), 0);
|
||||
assert!(reg.get_by_key("key_alice").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.remove("ghost").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_enabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert!(reg.is_authorized("key_alice")); // enabled by default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_disabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.enabled = Some(false);
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_expired() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2020-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_future_expiry() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2099-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_unknown_key() {
|
||||
let reg = ClientRegistry::new();
|
||||
assert!(!reg.is_authorized("nonexistent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
reg.update("alice", |entry| {
|
||||
entry.description = Some("Updated".to_string());
|
||||
entry.enabled = Some(false);
|
||||
}).unwrap();
|
||||
|
||||
let entry = reg.get_by_id("alice").unwrap();
|
||||
assert_eq!(entry.description.as_deref(), Some("Updated"));
|
||||
assert!(!entry.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.update("ghost", |_| {}).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rotate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "old_key")).unwrap();
|
||||
|
||||
reg.rotate_key("alice", "new_key".to_string(), None).unwrap();
|
||||
|
||||
assert!(reg.get_by_key("old_key").is_none());
|
||||
assert!(reg.get_by_key("new_key").is_some());
|
||||
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_entries() {
|
||||
let entries = vec![
|
||||
make_entry("alice", "key_a"),
|
||||
make_entry("bob", "key_b"),
|
||||
];
|
||||
let reg = ClientRegistry::from_entries(entries).unwrap();
|
||||
assert_eq!(reg.len(), 2);
|
||||
assert!(reg.get_by_key("key_a").is_some());
|
||||
assert!(reg.get_by_key("key_b").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_clients() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_a")).unwrap();
|
||||
reg.add(make_entry("bob", "key_b")).unwrap();
|
||||
let list = reg.list();
|
||||
assert_eq!(list.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_with_rate_limit() {
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.security = Some(ClientSecurity {
|
||||
ip_allow_list: Some(vec!["192.168.1.0/24".to_string()]),
|
||||
ip_block_list: Some(vec!["192.168.1.100".to_string()]),
|
||||
destination_allow_list: None,
|
||||
destination_block_list: None,
|
||||
max_connections: Some(5),
|
||||
rate_limit: Some(ClientRateLimit {
|
||||
bytes_per_sec: 1_000_000,
|
||||
burst_bytes: 2_000_000,
|
||||
}),
|
||||
});
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(entry).unwrap();
|
||||
let e = reg.get_by_id("alice").unwrap();
|
||||
let sec = e.security.as_ref().unwrap();
|
||||
assert_eq!(sec.rate_limit.as_ref().unwrap().bytes_per_sec, 1_000_000);
|
||||
assert_eq!(sec.max_connections, Some(5));
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,10 @@ use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use snow::Builder;
|
||||
|
||||
/// Noise protocol pattern: NK (client knows server pubkey, no client auth at Noise level)
|
||||
const NOISE_PATTERN: &str = "Noise_NK_25519_ChaChaPoly_BLAKE2s";
|
||||
/// Noise protocol pattern: IK (client presents static key, server authenticates client)
|
||||
/// IK = Initiator's static key is transmitted; responder's Key is pre-known.
|
||||
/// This provides mutual authentication: server verifies client identity via public key.
|
||||
const NOISE_PATTERN: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2s";
|
||||
|
||||
/// Generate a new Noise static keypair.
|
||||
/// Returns (public_key_base64, private_key_base64).
|
||||
@@ -22,18 +24,23 @@ pub fn generate_keypair_raw() -> Result<snow::Keypair> {
|
||||
Ok(builder.generate_keypair()?)
|
||||
}
|
||||
|
||||
/// Create a Noise NK initiator (client side).
|
||||
/// The client knows the server's static public key.
|
||||
pub fn create_initiator(server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
/// Create a Noise IK initiator (client side).
|
||||
/// The client provides its own static keypair AND the server's public key.
|
||||
/// The client's static key is transmitted (encrypted) during the handshake,
|
||||
/// allowing the server to authenticate the client.
|
||||
pub fn create_initiator(client_private_key: &[u8], server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
.local_private_key(client_private_key)
|
||||
.remote_public_key(server_public_key)
|
||||
.build_initiator()?;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Create a Noise NK responder (server side).
|
||||
/// Create a Noise IK responder (server side).
|
||||
/// The server uses its static private key.
|
||||
/// After the handshake, call `get_remote_static()` on the HandshakeState
|
||||
/// (before `into_transport_mode()`) to retrieve the client's public key.
|
||||
pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
@@ -42,19 +49,20 @@ pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Perform the full Noise NK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport).
|
||||
/// Perform the full Noise IK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport, client_public_key).
|
||||
/// The client_public_key is extracted from the responder before entering transport mode.
|
||||
pub fn perform_handshake(
|
||||
mut initiator: snow::HandshakeState,
|
||||
mut responder: snow::HandshakeState,
|
||||
) -> Result<(snow::TransportState, snow::TransportState)> {
|
||||
) -> Result<(snow::TransportState, snow::TransportState, Vec<u8>)> {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es (initiator sends)
|
||||
// -> e, es, s, ss (initiator sends ephemeral + encrypted static key)
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let msg1 = buf[..len].to_vec();
|
||||
|
||||
// <- e, ee (responder reads and responds)
|
||||
// <- e, ee, se (responder reads and responds)
|
||||
responder.read_message(&msg1, &mut buf)?;
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let msg2 = buf[..len].to_vec();
|
||||
@@ -62,10 +70,16 @@ pub fn perform_handshake(
|
||||
// Initiator reads response
|
||||
initiator.read_message(&msg2, &mut buf)?;
|
||||
|
||||
// Extract client's public key from responder BEFORE entering transport mode
|
||||
let client_public_key = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake did not provide client static key"))?
|
||||
.to_vec();
|
||||
|
||||
let i_transport = initiator.into_transport_mode()?;
|
||||
let r_transport = responder.into_transport_mode()?;
|
||||
|
||||
Ok((i_transport, r_transport))
|
||||
Ok((i_transport, r_transport, client_public_key))
|
||||
}
|
||||
|
||||
/// XChaCha20-Poly1305 encryption for post-handshake data.
|
||||
@@ -135,15 +149,19 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_handshake() {
|
||||
fn noise_ik_handshake() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
let initiator = create_initiator(&server_kp.public).unwrap();
|
||||
let initiator = create_initiator(&client_kp.private, &server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
let (mut i_transport, mut r_transport) =
|
||||
let (mut i_transport, mut r_transport, remote_key) =
|
||||
perform_handshake(initiator, responder).unwrap();
|
||||
|
||||
// Verify the server received the client's public key
|
||||
assert_eq!(remote_key, client_kp.public);
|
||||
|
||||
// Test encrypted communication
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let plaintext = b"hello from client";
|
||||
@@ -159,6 +177,20 @@ mod tests {
|
||||
assert_eq!(&out[..len], plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_ik_wrong_server_key_fails() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let wrong_server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
// Client uses wrong server public key
|
||||
let initiator = create_initiator(&client_kp.private, &wrong_server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
// Handshake should fail because client targeted wrong server
|
||||
assert!(perform_handshake(initiator, responder).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn xchacha_encrypt_decrypt() {
|
||||
let key = [42u8; 32];
|
||||
|
||||
@@ -5,6 +5,8 @@ pub mod management;
|
||||
pub mod codec;
|
||||
pub mod crypto;
|
||||
pub mod transport;
|
||||
pub mod transport_trait;
|
||||
pub mod quic_transport;
|
||||
pub mod keepalive;
|
||||
pub mod tunnel;
|
||||
pub mod network;
|
||||
@@ -15,3 +17,7 @@ pub mod telemetry;
|
||||
pub mod ratelimit;
|
||||
pub mod qos;
|
||||
pub mod mtu;
|
||||
pub mod wireguard;
|
||||
pub mod client_registry;
|
||||
pub mod acl;
|
||||
pub mod proxy_protocol;
|
||||
|
||||
@@ -7,6 +7,7 @@ use tracing::{info, error, warn};
|
||||
use crate::client::{ClientConfig, VpnClient};
|
||||
use crate::crypto;
|
||||
use crate::server::{ServerConfig, VpnServer};
|
||||
use crate::wireguard::{self, WgClient, WgClientConfig, WgPeerConfig, WgServer, WgServerConfig};
|
||||
|
||||
// ============================================================================
|
||||
// IPC protocol types
|
||||
@@ -93,6 +94,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
|
||||
let mut vpn_client = VpnClient::new();
|
||||
let mut vpn_server = VpnServer::new();
|
||||
let mut wg_client = WgClient::new();
|
||||
let mut wg_server = WgServer::new();
|
||||
|
||||
send_event_stdout("ready", serde_json::json!({ "mode": mode }));
|
||||
|
||||
@@ -127,8 +130,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
};
|
||||
|
||||
let response = match mode {
|
||||
"client" => handle_client_request(&request, &mut vpn_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server).await,
|
||||
"client" => handle_client_request(&request, &mut vpn_client, &mut wg_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server, &mut wg_server).await,
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
send_response_stdout(&response);
|
||||
@@ -150,6 +153,8 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
// Shared state behind Mutex for socket mode (multiple connections)
|
||||
let vpn_client = std::sync::Arc::new(Mutex::new(VpnClient::new()));
|
||||
let vpn_server = std::sync::Arc::new(Mutex::new(VpnServer::new()));
|
||||
let wg_client = std::sync::Arc::new(Mutex::new(WgClient::new()));
|
||||
let wg_server = std::sync::Arc::new(Mutex::new(WgServer::new()));
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
@@ -157,9 +162,11 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
let mode = mode.to_string();
|
||||
let client = vpn_client.clone();
|
||||
let server = vpn_server.clone();
|
||||
let wg_c = wg_client.clone();
|
||||
let wg_s = wg_server.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
handle_socket_connection(stream, &mode, client, server).await
|
||||
handle_socket_connection(stream, &mode, client, server, wg_c, wg_s).await
|
||||
{
|
||||
warn!("Socket connection error: {}", e);
|
||||
}
|
||||
@@ -177,6 +184,8 @@ async fn handle_socket_connection(
|
||||
mode: &str,
|
||||
vpn_client: std::sync::Arc<Mutex<VpnClient>>,
|
||||
vpn_server: std::sync::Arc<Mutex<VpnServer>>,
|
||||
wg_client: std::sync::Arc<Mutex<WgClient>>,
|
||||
wg_server: std::sync::Arc<Mutex<WgServer>>,
|
||||
) -> Result<()> {
|
||||
let (reader, mut writer) = stream.into_split();
|
||||
let buf_reader = BufReader::new(reader);
|
||||
@@ -227,11 +236,13 @@ async fn handle_socket_connection(
|
||||
let response = match mode {
|
||||
"client" => {
|
||||
let mut client = vpn_client.lock().await;
|
||||
handle_client_request(&request, &mut client).await
|
||||
let mut wg_c = wg_client.lock().await;
|
||||
handle_client_request(&request, &mut client, &mut wg_c).await
|
||||
}
|
||||
"server" => {
|
||||
let mut server = vpn_server.lock().await;
|
||||
handle_server_request(&request, &mut server).await
|
||||
let mut wg_s = wg_server.lock().await;
|
||||
handle_server_request(&request, &mut server, &mut wg_s).await
|
||||
}
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
@@ -252,38 +263,79 @@ async fn handle_socket_connection(
|
||||
async fn handle_client_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_client: &mut VpnClient,
|
||||
wg_client: &mut WgClient,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"connect" => {
|
||||
let config: ClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
// Check if transport is "wireguard"
|
||||
let transport = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transport"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match vpn_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
if transport == "wireguard" {
|
||||
let config: WgClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("WG connect failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ClientConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_client.connect(config).await {
|
||||
Ok(assigned_ip) => {
|
||||
ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip }))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"disconnect" => {
|
||||
if wg_client.is_running() {
|
||||
match wg_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG disconnect failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disconnect" => match vpn_client.disconnect().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)),
|
||||
},
|
||||
"getStatus" => {
|
||||
let status = vpn_client.get_status().await;
|
||||
ManagementResponse::ok(id, status)
|
||||
if wg_client.is_running() {
|
||||
ManagementResponse::ok(id, wg_client.get_status().await)
|
||||
} else {
|
||||
let status = vpn_client.get_status().await;
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
}
|
||||
"getStatistics" => {
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
if wg_client.is_running() {
|
||||
ManagementResponse::ok(id, wg_client.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_client.get_statistics().await;
|
||||
ManagementResponse::ok(id, stats)
|
||||
}
|
||||
}
|
||||
"getConnectionQuality" => {
|
||||
match vpn_client.get_connection_quality() {
|
||||
@@ -329,45 +381,92 @@ async fn handle_client_request(
|
||||
async fn handle_server_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_server: &mut VpnServer,
|
||||
wg_server: &mut WgServer,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"start" => {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
// Check if transportMode is "wireguard"
|
||||
let transport_mode = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transportMode"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
if transport_mode == "wireguard" {
|
||||
let config: WgServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG start failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => {
|
||||
if wg_server.is_running() {
|
||||
match wg_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG stop failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
},
|
||||
"getStatus" => {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_status())
|
||||
} else {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
}
|
||||
"getStatistics" => {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"listClients" => {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
if wg_server.is_running() {
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
"disconnectClient" => {
|
||||
@@ -436,6 +535,153 @@ async fn handle_server_request(
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
"generateWgKeypair" => {
|
||||
let (public_key, private_key) = wireguard::generate_wg_keypair();
|
||||
ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
)
|
||||
}
|
||||
"addWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let config: WgPeerConfig = match serde_json::from_value(
|
||||
request.params.get("peer").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid peer config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.add_peer(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Add peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let public_key = match request.params.get("publicKey").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing publicKey".to_string()),
|
||||
};
|
||||
match wg_server.remove_peer(&public_key).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listWgPeers" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "peers": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
// ── Client Registry (Hub) Commands ────────────────────────────────
|
||||
"createClient" => {
|
||||
let client_partial = request.params.get("client").cloned().unwrap_or_default();
|
||||
match vpn_server.create_client(client_partial).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Create client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_registered_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.get_registered_client(&client_id).await {
|
||||
Ok(entry) => ManagementResponse::ok(id, entry),
|
||||
Err(e) => ManagementResponse::err(id, format!("Get client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listRegisteredClients" => {
|
||||
let clients = vpn_server.list_registered_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
"updateClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let update = request.params.get("update").cloned().unwrap_or_default();
|
||||
match vpn_server.update_registered_client(&client_id, update).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Update client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"enableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.enable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Enable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.disable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"rotateClientKey" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.rotate_client_key(&client_id).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Key rotation failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"exportClientConfig" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let format = request.params.get("format").and_then(|v| v.as_str()).unwrap_or("smartvpn");
|
||||
match vpn_server.export_client_config(&client_id, format).await {
|
||||
Ok(config) => ManagementResponse::ok(id, config),
|
||||
Err(e) => ManagementResponse::err(id, format!("Export failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"generateClientKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
_ => ManagementResponse::err(id, format!("Unknown server method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
261
rust/src/proxy_protocol.rs
Normal file
261
rust/src/proxy_protocol.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
//! PROXY protocol v2 parser for extracting real client addresses
|
||||
//! when SmartVPN sits behind a reverse proxy (HAProxy, SmartProxy, etc.).
|
||||
//!
|
||||
//! Spec: <https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt>
|
||||
|
||||
use anyhow::Result;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::time::Duration;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// Timeout for reading the PROXY protocol header from a new connection.
|
||||
const PROXY_HEADER_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// The 12-byte PP v2 signature.
|
||||
const PP_V2_SIGNATURE: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
/// Parsed PROXY protocol v2 header.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyHeader {
|
||||
/// Real client source address.
|
||||
pub src_addr: SocketAddr,
|
||||
/// Proxy-to-server destination address.
|
||||
pub dst_addr: SocketAddr,
|
||||
/// True if this is a LOCAL command (health check probe from proxy).
|
||||
pub is_local: bool,
|
||||
}
|
||||
|
||||
/// Read and parse a PROXY protocol v2 header from a TCP stream.
|
||||
///
|
||||
/// Reads exactly the header bytes — the stream is in a clean state for
|
||||
/// WebSocket upgrade afterward. Returns an error on timeout, invalid
|
||||
/// signature, or malformed header.
|
||||
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
||||
tokio::time::timeout(PROXY_HEADER_TIMEOUT, read_proxy_header_inner(stream))
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("PROXY protocol header read timed out ({}s)", PROXY_HEADER_TIMEOUT.as_secs()))?
|
||||
}
|
||||
|
||||
async fn read_proxy_header_inner(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
||||
// Read the 16-byte fixed prefix
|
||||
let mut prefix = [0u8; 16];
|
||||
stream.read_exact(&mut prefix).await?;
|
||||
|
||||
// Validate the 12-byte signature
|
||||
if prefix[..12] != PP_V2_SIGNATURE {
|
||||
anyhow::bail!("Invalid PROXY protocol v2 signature");
|
||||
}
|
||||
|
||||
// Byte 12: version (high nibble) | command (low nibble)
|
||||
let version = (prefix[12] & 0xF0) >> 4;
|
||||
let command = prefix[12] & 0x0F;
|
||||
|
||||
if version != 2 {
|
||||
anyhow::bail!("Unsupported PROXY protocol version: {}", version);
|
||||
}
|
||||
|
||||
// Byte 13: address family (high nibble) | protocol (low nibble)
|
||||
let addr_family = (prefix[13] & 0xF0) >> 4;
|
||||
let _protocol = prefix[13] & 0x0F; // 1 = STREAM (TCP)
|
||||
|
||||
// Bytes 14-15: address data length (big-endian)
|
||||
let addr_len = u16::from_be_bytes([prefix[14], prefix[15]]) as usize;
|
||||
|
||||
// Read the address data
|
||||
let mut addr_data = vec![0u8; addr_len];
|
||||
if addr_len > 0 {
|
||||
stream.read_exact(&mut addr_data).await?;
|
||||
}
|
||||
|
||||
// LOCAL command (0x00) = health check, no real address
|
||||
if command == 0x00 {
|
||||
return Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
dst_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
is_local: true,
|
||||
});
|
||||
}
|
||||
|
||||
// PROXY command (0x01) — parse address block
|
||||
if command != 0x01 {
|
||||
anyhow::bail!("Unknown PROXY protocol command: {}", command);
|
||||
}
|
||||
|
||||
match addr_family {
|
||||
// AF_INET (IPv4): 4 src + 4 dst + 2 src_port + 2 dst_port = 12 bytes
|
||||
1 => {
|
||||
if addr_data.len() < 12 {
|
||||
anyhow::bail!("IPv4 address block too short: {} bytes", addr_data.len());
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]);
|
||||
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
|
||||
Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)),
|
||||
dst_addr: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)),
|
||||
is_local: false,
|
||||
})
|
||||
}
|
||||
// AF_INET6 (IPv6): 16 src + 16 dst + 2 src_port + 2 dst_port = 36 bytes
|
||||
2 => {
|
||||
if addr_data.len() < 36 {
|
||||
anyhow::bail!("IPv6 address block too short: {} bytes", addr_data.len());
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap());
|
||||
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
|
||||
Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V6(SocketAddrV6::new(src_ip, src_port, 0, 0)),
|
||||
dst_addr: SocketAddr::V6(SocketAddrV6::new(dst_ip, dst_port, 0, 0)),
|
||||
is_local: false,
|
||||
})
|
||||
}
|
||||
// AF_UNSPEC or unknown
|
||||
_ => {
|
||||
anyhow::bail!("Unsupported address family: {}", addr_family);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a PROXY protocol v2 header (for testing / proxy implementations).
|
||||
pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
||||
|
||||
match (src, dst) {
|
||||
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
|
||||
buf.push(0x21); // version 2 | PROXY command
|
||||
buf.push(0x11); // AF_INET | STREAM
|
||||
buf.extend_from_slice(&12u16.to_be_bytes()); // addr length
|
||||
buf.extend_from_slice(&s.ip().octets());
|
||||
buf.extend_from_slice(&d.ip().octets());
|
||||
buf.extend_from_slice(&s.port().to_be_bytes());
|
||||
buf.extend_from_slice(&d.port().to_be_bytes());
|
||||
}
|
||||
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
|
||||
buf.push(0x21); // version 2 | PROXY command
|
||||
buf.push(0x21); // AF_INET6 | STREAM
|
||||
buf.extend_from_slice(&36u16.to_be_bytes()); // addr length
|
||||
buf.extend_from_slice(&s.ip().octets());
|
||||
buf.extend_from_slice(&d.ip().octets());
|
||||
buf.extend_from_slice(&s.port().to_be_bytes());
|
||||
buf.extend_from_slice(&d.port().to_be_bytes());
|
||||
}
|
||||
_ => panic!("Mismatched address families"),
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a PROXY protocol v2 LOCAL header (health check probe).
|
||||
pub fn build_pp_v2_local() -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
||||
buf.push(0x20); // version 2 | LOCAL command
|
||||
buf.push(0x00); // AF_UNSPEC
|
||||
buf.extend_from_slice(&0u16.to_be_bytes()); // no address data
|
||||
buf
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Helper: create a TCP pair and write data to the client side, then parse from server side.
|
||||
async fn parse_header_from_bytes(header_bytes: &[u8]) -> Result<ProxyHeader> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let data = header_bytes.to_vec();
|
||||
let client_task = tokio::spawn(async move {
|
||||
let mut client = TcpStream::connect(addr).await.unwrap();
|
||||
client.write_all(&data).await.unwrap();
|
||||
client // keep alive
|
||||
});
|
||||
|
||||
let (mut server_stream, _) = listener.accept().await.unwrap();
|
||||
let result = read_proxy_header(&mut server_stream).await;
|
||||
let _client = client_task.await.unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_valid_ipv4_header() {
|
||||
let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
|
||||
let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(!parsed.is_local);
|
||||
assert_eq!(parsed.src_addr, src);
|
||||
assert_eq!(parsed.dst_addr, dst);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_valid_ipv6_header() {
|
||||
let src = "[2001:db8::1]:54321".parse::<SocketAddr>().unwrap();
|
||||
let dst = "[2001:db8::2]:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(!parsed.is_local);
|
||||
assert_eq!(parsed.src_addr, src);
|
||||
assert_eq!(parsed.dst_addr, dst);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_local_command() {
|
||||
let header = build_pp_v2_local();
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(parsed.is_local);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_invalid_signature() {
|
||||
let mut header = build_pp_v2_local();
|
||||
header[0] = 0xFF; // corrupt signature
|
||||
let result = parse_header_from_bytes(&header).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("signature"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_wrong_version() {
|
||||
let mut header = build_pp_v2_local();
|
||||
header[12] = 0x10; // version 1 instead of 2
|
||||
let result = parse_header_from_bytes(&header).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("version"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_truncated_header() {
|
||||
// Only 10 bytes — not even the full signature
|
||||
let result = parse_header_from_bytes(&[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49]).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv4_header_is_exactly_28_bytes() {
|
||||
let src = "1.2.3.4:80".parse::<SocketAddr>().unwrap();
|
||||
let dst = "5.6.7.8:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 12 addrs = 28
|
||||
assert_eq!(header.len(), 28);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv6_header_is_exactly_52_bytes() {
|
||||
let src = "[::1]:80".parse::<SocketAddr>().unwrap();
|
||||
let dst = "[::2]:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 36 addrs = 52
|
||||
assert_eq!(header.len(), 52);
|
||||
}
|
||||
}
|
||||
546
rust/src/quic_transport.rs
Normal file
546
rust/src/quic_transport.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use quinn::crypto::rustls::QuicClientConfig;
|
||||
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn, debug};
|
||||
|
||||
use crate::transport_trait::{TransportSink, TransportStream};
|
||||
|
||||
// ============================================================================
|
||||
// TLS / Certificate helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Generate a self-signed certificate and private key for QUIC.
|
||||
pub fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["smartvpn".to_string()])?;
|
||||
let cert_der = CertificateDer::from(cert.cert);
|
||||
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()));
|
||||
Ok((vec![cert_der], key_der))
|
||||
}
|
||||
|
||||
/// Compute the SHA-256 hash of a DER-encoded certificate and return it as base64.
|
||||
pub fn cert_hash(cert_der: &CertificateDer<'_>) -> String {
|
||||
use ring::digest;
|
||||
let hash = digest::digest(&digest::SHA256, cert_der.as_ref());
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash.as_ref())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Server-side QUIC endpoint
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for the QUIC server endpoint.
|
||||
pub struct QuicServerConfig {
|
||||
pub listen_addr: String,
|
||||
pub cert_chain: Vec<CertificateDer<'static>>,
|
||||
pub private_key: PrivateKeyDer<'static>,
|
||||
pub idle_timeout_secs: u64,
|
||||
}
|
||||
|
||||
/// Create a QUIC server endpoint bound to the given address.
|
||||
pub fn create_quic_server(config: QuicServerConfig) -> Result<quinn::Endpoint> {
|
||||
let addr: SocketAddr = config.listen_addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let mut tls_config = rustls::ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(config.cert_chain, config.private_key)?;
|
||||
tls_config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
|
||||
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(
|
||||
quinn::IdleTimeout::try_from(Duration::from_secs(config.idle_timeout_secs))?,
|
||||
));
|
||||
// Enable datagrams with a generous max size
|
||||
transport.datagram_receive_buffer_size(Some(65535));
|
||||
transport.datagram_send_buffer_size(65535);
|
||||
server_config.transport_config(Arc::new(transport));
|
||||
|
||||
let endpoint = quinn::Endpoint::server(server_config, addr)?;
|
||||
info!("QUIC server listening on {}", addr);
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client-side QUIC connection
|
||||
// ============================================================================
|
||||
|
||||
/// A certificate verifier that accepts any server certificate.
|
||||
/// Safe when Noise NK provides server authentication at the application layer.
|
||||
#[derive(Debug)]
|
||||
struct AcceptAnyCert;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for AcceptAnyCert {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// A certificate verifier that accepts any certificate matching a given SHA-256 hash.
|
||||
#[derive(Debug)]
|
||||
struct CertHashVerifier {
|
||||
expected_hash: String,
|
||||
}
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for CertHashVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
let actual_hash = cert_hash(end_entity);
|
||||
if actual_hash == self.expected_hash {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
} else {
|
||||
Err(rustls::Error::General(format!(
|
||||
"Certificate hash mismatch: expected {}, got {}",
|
||||
self.expected_hash, actual_hash
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
// QUIC always uses TLS 1.3
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a QUIC server.
|
||||
///
|
||||
/// - If `server_cert_hash` is provided, verifies the server certificate matches
|
||||
/// the given SHA-256 hash (cert pinning).
|
||||
/// - If `server_cert_hash` is `None`, accepts any server certificate. This is
|
||||
/// safe because the Noise NK handshake (which runs over the QUIC stream)
|
||||
/// authenticates the server via its pre-shared public key — the same trust
|
||||
/// model as WireGuard.
|
||||
pub async fn connect_quic(
|
||||
addr: &str,
|
||||
server_cert_hash: Option<&str>,
|
||||
) -> Result<quinn::Connection> {
|
||||
let remote: SocketAddr = addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let tls_config = if let Some(hash) = server_cert_hash {
|
||||
// Pin to a specific certificate hash
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash.to_string(),
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
} else {
|
||||
// Accept any cert — Noise NK provides server authentication
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(AcceptAnyCert))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
};
|
||||
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse()?)?;
|
||||
endpoint.set_default_client_config(client_config);
|
||||
|
||||
info!("Connecting to QUIC server at {}", addr);
|
||||
let connection = endpoint.connect(remote, "smartvpn")?.await?;
|
||||
info!("QUIC connection established");
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QUIC Transport Sink / Stream implementations
|
||||
// ============================================================================
|
||||
|
||||
/// QUIC transport sink — wraps a SendStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportSink {
|
||||
send_stream: quinn::SendStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportSink {
|
||||
pub fn new(send_stream: quinn::SendStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
send_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for QuicTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// Length-prefix framing: [4-byte big-endian length][payload]
|
||||
let len = data.len() as u32;
|
||||
self.send_stream.write_all(&len.to_be_bytes()).await?;
|
||||
self.send_stream.write_all(&data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
let max_size = self.connection.max_datagram_size();
|
||||
match max_size {
|
||||
Some(max) if data.len() <= max => {
|
||||
self.connection.send_datagram(data.into())?;
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
// Datagram too large or datagrams disabled — fall back to reliable
|
||||
debug!("Datagram too large ({}B), falling back to reliable stream", data.len());
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.send_stream.finish()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// QUIC transport stream — wraps a RecvStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportStream {
|
||||
recv_stream: quinn::RecvStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportStream {
|
||||
pub fn new(recv_stream: quinn::RecvStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
recv_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for QuicTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// Read length prefix
|
||||
let mut len_buf = [0u8; 4];
|
||||
match self.recv_stream.read_exact(&mut len_buf).await {
|
||||
Ok(()) => {}
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None),
|
||||
Err(quinn::ReadExactError::ReadError(quinn::ReadError::ConnectionLost(e))) => {
|
||||
warn!("QUIC connection lost: {}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
|
||||
let len = u32::from_be_bytes(len_buf) as usize;
|
||||
if len > 65536 {
|
||||
return Err(anyhow::anyhow!("Frame too large: {} bytes", len));
|
||||
}
|
||||
|
||||
let mut data = vec![0u8; len];
|
||||
match self.recv_stream.read_exact(&mut data).await {
|
||||
Ok(()) => Ok(Some(data)),
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
match self.connection.read_datagram().await {
|
||||
Ok(data) => Ok(Some(data.to_vec())),
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => Ok(None),
|
||||
Err(quinn::ConnectionError::LocallyClosed) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC datagram error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
self.connection.max_datagram_size().is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Accept a QUIC connection and open a bidirectional control stream.
|
||||
/// Returns the transport sink/stream pair ready for the VPN handshake.
|
||||
pub async fn accept_quic_connection(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
// The client opens the bidirectional control stream
|
||||
let (send, recv) = conn.accept_bi().await?;
|
||||
info!("QUIC bidirectional control stream accepted");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
/// Open a QUIC connection's bidirectional control stream (client side).
|
||||
pub async fn open_quic_streams(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
let (send, recv) = conn.open_bi().await?;
|
||||
info!("QUIC bidirectional control stream opened");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cert_generation_and_hash() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
assert_eq!(certs.len(), 1);
|
||||
let hash = cert_hash(&certs[0]);
|
||||
// SHA-256 base64 is 44 characters
|
||||
assert_eq!(hash.len(), 44);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cert_hash_deterministic() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
let hash1 = cert_hash(&certs[0]);
|
||||
let hash2 = cert_hash(&certs[0]);
|
||||
assert_eq!(hash1, hash2);
|
||||
}
|
||||
|
||||
/// Helper: create QUIC server and client endpoints.
|
||||
fn create_quic_endpoints() -> (quinn::Endpoint, quinn::Endpoint, String) {
|
||||
let (certs, key) = generate_self_signed_cert().unwrap();
|
||||
let hash = cert_hash(&certs[0]);
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
|
||||
let mut server_tls = rustls::ServerConfig::builder_with_provider(provider.clone())
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key).unwrap();
|
||||
server_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let server_qcfg = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(server_tls).unwrap(),
|
||||
));
|
||||
let server_ep = quinn::Endpoint::server(server_qcfg, "127.0.0.1:0".parse().unwrap()).unwrap();
|
||||
|
||||
let mut client_tls = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash,
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
client_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(client_tls).unwrap(),
|
||||
));
|
||||
let mut client_ep = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
|
||||
client_ep.set_default_client_config(client_config);
|
||||
|
||||
let server_addr = server_ep.local_addr().unwrap().to_string();
|
||||
(server_ep, client_ep, server_addr)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_server_client_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi, read, echo, finish
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (mut s_send, mut s_recv) = conn.accept_bi().await.unwrap();
|
||||
let data = s_recv.read_to_end(1024).await.unwrap();
|
||||
s_send.write_all(&data).await.unwrap();
|
||||
s_send.finish().unwrap();
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open_bi, write, finish, read
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_send, mut c_recv) = conn.open_bi().await.unwrap();
|
||||
c_send.write_all(b"hello quinn").await.unwrap();
|
||||
c_send.finish().unwrap();
|
||||
let data = c_recv.read_to_end(1024).await.unwrap();
|
||||
assert_eq!(&data[..], b"hello quinn");
|
||||
|
||||
let _ = server_task.await;
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test transport trait wrappers over QUIC.
|
||||
/// Key: client must send data first (QUIC streams are opened implicitly by data).
|
||||
/// The server accept_bi runs concurrently with the client's first send_reliable.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_transport_trait_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server task: accept connection, then accept_bi (blocks until client sends data)
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
(s_sink, s_stream, server_ep)
|
||||
});
|
||||
|
||||
// Client: connect, open_bi via wrapper
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, mut c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
|
||||
// Client sends first — this triggers the QUIC stream to become visible to the server
|
||||
c_sink.send_reliable(b"hello-from-client".to_vec()).await.unwrap();
|
||||
|
||||
// Now server's accept_bi unblocks
|
||||
let (mut s_sink, mut s_stream, _sep) = server_task.await.unwrap();
|
||||
|
||||
// Server reads the message
|
||||
let msg = s_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-client");
|
||||
|
||||
// Server -> Client
|
||||
s_sink.send_reliable(b"hello-from-server".to_vec()).await.unwrap();
|
||||
let msg = c_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-server");
|
||||
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test QUIC datagram support.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_datagram_exchange() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi (opens control stream), then read datagram
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
// Accept bi stream (control channel)
|
||||
let (_s_sink, _s_stream) = accept_quic_connection(conn.clone()).await.unwrap();
|
||||
// Read datagram
|
||||
let dgram = conn.read_datagram().await.unwrap();
|
||||
assert_eq!(&dgram[..], b"dgram-payload");
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open bi stream (triggers server accept_bi), then send datagram
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, _c_stream) = open_quic_streams(conn.clone()).await.unwrap();
|
||||
|
||||
// Send initial data to open the stream (required for QUIC)
|
||||
c_sink.send_reliable(b"init".to_vec()).await.unwrap();
|
||||
|
||||
// Small yield to let the server process the bi stream
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
// Send datagram
|
||||
assert!(conn.max_datagram_size().is_some());
|
||||
conn.send_datagram(bytes::Bytes::from_static(b"dgram-payload")).unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test that supports_datagrams returns true for QUIC transports.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_supports_datagrams() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (_s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
assert!(s_stream.supports_datagrams());
|
||||
server_ep
|
||||
});
|
||||
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
assert!(c_stream.supports_datagrams());
|
||||
|
||||
// Send data to trigger server's accept_bi
|
||||
c_sink.send_reliable(b"ping".to_vec()).await.unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
}
|
||||
@@ -130,10 +130,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn tokens_do_not_exceed_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000);
|
||||
// Use a low rate so refill between consecutive calls is negligible
|
||||
let mut tb = TokenBucket::new(100, 1_000);
|
||||
// Wait to accumulate — but should cap at burst
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
assert!(tb.try_consume(1_000));
|
||||
// At 100 bytes/sec, the few μs between calls add ~0 tokens
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv4Addr;
|
||||
@@ -8,15 +7,18 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
|
||||
use crate::acl;
|
||||
use crate::client_registry::{ClientEntry, ClientRegistry};
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::mtu::{MtuConfig, TunnelOverhead};
|
||||
use crate::network::IpPool;
|
||||
use crate::ratelimit::TokenBucket;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
|
||||
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
|
||||
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
|
||||
@@ -39,6 +41,19 @@ pub struct ServerConfig {
|
||||
pub default_rate_limit_bytes_per_sec: Option<u64>,
|
||||
/// Default burst size for new clients (bytes). None = unlimited.
|
||||
pub default_burst_bytes: Option<u64>,
|
||||
/// Transport mode: "websocket" (default), "quic", or "both".
|
||||
pub transport_mode: Option<String>,
|
||||
/// QUIC listen address (host:port). Defaults to listen_addr.
|
||||
pub quic_listen_addr: Option<String>,
|
||||
/// QUIC idle timeout in seconds (default: 30).
|
||||
pub quic_idle_timeout_secs: Option<u64>,
|
||||
/// Pre-registered clients for IK authentication.
|
||||
pub clients: Option<Vec<ClientEntry>>,
|
||||
/// Enable PROXY protocol v2 parsing on incoming WebSocket connections.
|
||||
/// SECURITY: Must be false when accepting direct client connections.
|
||||
pub proxy_protocol: Option<bool>,
|
||||
/// Server-level IP block list — applied at TCP accept, before Noise handshake.
|
||||
pub connection_ip_block_list: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -56,6 +71,12 @@ pub struct ClientInfo {
|
||||
pub keepalives_received: u64,
|
||||
pub rate_limit_bytes_per_sec: Option<u64>,
|
||||
pub burst_bytes: Option<u64>,
|
||||
/// Client's authenticated Noise IK public key (base64).
|
||||
pub authenticated_key: String,
|
||||
/// Registered client ID from the client registry.
|
||||
pub registered_client_id: String,
|
||||
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
|
||||
pub remote_addr: Option<String>,
|
||||
}
|
||||
|
||||
/// Server statistics.
|
||||
@@ -82,6 +103,7 @@ pub struct ServerState {
|
||||
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
|
||||
pub mtu_config: MtuConfig,
|
||||
pub started_at: std::time::Instant,
|
||||
pub client_registry: RwLock<ClientRegistry>,
|
||||
}
|
||||
|
||||
/// The VPN server.
|
||||
@@ -121,6 +143,12 @@ impl VpnServer {
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
|
||||
|
||||
// Build client registry from config
|
||||
let registry = ClientRegistry::from_entries(
|
||||
config.clients.clone().unwrap_or_default()
|
||||
)?;
|
||||
info!("Client registry loaded with {} entries", registry.len());
|
||||
|
||||
let state = Arc::new(ServerState {
|
||||
config: config.clone(),
|
||||
ip_pool: Mutex::new(ip_pool),
|
||||
@@ -129,20 +157,65 @@ impl VpnServer {
|
||||
rate_limiters: Mutex::new(HashMap::new()),
|
||||
mtu_config,
|
||||
started_at: std::time::Instant::now(),
|
||||
client_registry: RwLock::new(registry),
|
||||
});
|
||||
|
||||
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
||||
self.state = Some(state.clone());
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
|
||||
let listen_addr = config.listen_addr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
info!("VPN server started");
|
||||
match transport_mode {
|
||||
"quic" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
"both" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
let state2 = state.clone();
|
||||
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
|
||||
// Store second shutdown sender so both listeners stop
|
||||
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
|
||||
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
|
||||
self.shutdown_tx = Some(combined_tx);
|
||||
|
||||
// Forward combined shutdown to both listeners
|
||||
tokio::spawn(async move {
|
||||
combined_rx.recv().await;
|
||||
let _ = shutdown_tx_orig.send(()).await;
|
||||
let _ = shutdown_tx2.send(()).await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("WebSocket listener error: {}", e);
|
||||
}
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
// "websocket" (default)
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!("VPN server started (transport: {})", transport_mode);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -237,9 +310,268 @@ impl VpnServer {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ───────────────────────────────────
|
||||
|
||||
/// Create a new client entry. Generates keypairs and assigns an IP.
|
||||
/// Returns a JSON value with the full config bundle including secrets.
|
||||
pub async fn create_client(&self, partial: serde_json::Value) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
|
||||
let client_id = partial.get("clientId")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("clientId is required"))?
|
||||
.to_string();
|
||||
|
||||
// Generate Noise IK keypair for the client
|
||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||
|
||||
// Generate WireGuard keypair for the client
|
||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||
|
||||
// Allocate a VPN IP
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
// Build entry from partial + generated values
|
||||
let entry = ClientEntry {
|
||||
client_id: client_id.clone(),
|
||||
public_key: noise_pub.clone(),
|
||||
wg_public_key: Some(wg_pub.clone()),
|
||||
security: serde_json::from_value(
|
||||
partial.get("security").cloned().unwrap_or(serde_json::Value::Null)
|
||||
).ok(),
|
||||
priority: partial.get("priority").and_then(|v| v.as_u64()).map(|v| v as u32),
|
||||
enabled: partial.get("enabled").and_then(|v| v.as_bool()).or(Some(true)),
|
||||
tags: partial.get("tags").and_then(|v| {
|
||||
v.as_array().map(|a| a.iter().filter_map(|s| s.as_str().map(String::from)).collect())
|
||||
}),
|
||||
description: partial.get("description").and_then(|v| v.as_str()).map(String::from),
|
||||
expires_at: partial.get("expiresAt").and_then(|v| v.as_str()).map(String::from),
|
||||
assigned_ip: Some(assigned_ip.to_string()),
|
||||
};
|
||||
|
||||
// Add to registry
|
||||
state.client_registry.write().await.add(entry.clone())?;
|
||||
|
||||
// Build SmartVPN client config
|
||||
let smartvpn_config = serde_json::json!({
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPrivateKey": noise_priv,
|
||||
"clientPublicKey": noise_pub,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
});
|
||||
|
||||
// Build WireGuard config string
|
||||
let wg_config = format!(
|
||||
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
wg_priv,
|
||||
assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
|
||||
let entry_json = serde_json::to_value(&entry)?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"entry": entry_json,
|
||||
"smartvpnConfig": smartvpn_config,
|
||||
"wireguardConfig": wg_config,
|
||||
"secrets": {
|
||||
"noisePrivateKey": noise_priv,
|
||||
"wgPrivateKey": wg_priv,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Remove a registered client from the registry (and disconnect if connected).
|
||||
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let entry = state.client_registry.write().await.remove(client_id)?;
|
||||
// Release the IP if assigned
|
||||
if let Some(ref ip_str) = entry.assigned_ip {
|
||||
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
|
||||
state.ip_pool.lock().await.release(&ip);
|
||||
}
|
||||
}
|
||||
// Disconnect if currently connected
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a registered client by ID.
|
||||
pub async fn get_registered_client(&self, client_id: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let registry = state.client_registry.read().await;
|
||||
let entry = registry.get_by_id(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
Ok(serde_json::to_value(entry)?)
|
||||
}
|
||||
|
||||
/// List all registered clients.
|
||||
pub async fn list_registered_clients(&self) -> Vec<ClientEntry> {
|
||||
if let Some(ref state) = self.state {
|
||||
state.client_registry.read().await.list().into_iter().cloned().collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a registered client's fields.
|
||||
pub async fn update_registered_client(&self, client_id: &str, update: serde_json::Value) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
if let Some(security) = update.get("security") {
|
||||
entry.security = serde_json::from_value(security.clone()).ok();
|
||||
}
|
||||
if let Some(priority) = update.get("priority").and_then(|v| v.as_u64()) {
|
||||
entry.priority = Some(priority as u32);
|
||||
}
|
||||
if let Some(enabled) = update.get("enabled").and_then(|v| v.as_bool()) {
|
||||
entry.enabled = Some(enabled);
|
||||
}
|
||||
if let Some(tags) = update.get("tags").and_then(|v| v.as_array()) {
|
||||
entry.tags = Some(tags.iter().filter_map(|s| s.as_str().map(String::from)).collect());
|
||||
}
|
||||
if let Some(desc) = update.get("description").and_then(|v| v.as_str()) {
|
||||
entry.description = Some(desc.to_string());
|
||||
}
|
||||
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
|
||||
entry.expires_at = Some(expires.to_string());
|
||||
}
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enable a registered client.
|
||||
pub async fn enable_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
entry.enabled = Some(true);
|
||||
})
|
||||
}
|
||||
|
||||
/// Disable a registered client (also disconnects if connected).
|
||||
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
entry.enabled = Some(false);
|
||||
})?;
|
||||
// Disconnect if currently connected
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
pub async fn rotate_client_key(&self, client_id: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
|
||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||
|
||||
state.client_registry.write().await.rotate_key(
|
||||
client_id,
|
||||
noise_pub.clone(),
|
||||
Some(wg_pub.clone()),
|
||||
)?;
|
||||
|
||||
// Disconnect existing connection (old key is no longer valid)
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
|
||||
// Get updated entry for the config bundle
|
||||
let entry_json = self.get_registered_client(client_id).await?;
|
||||
let assigned_ip = entry_json.get("assignedIp")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("0.0.0.0");
|
||||
|
||||
let smartvpn_config = serde_json::json!({
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPrivateKey": noise_priv,
|
||||
"clientPublicKey": noise_pub,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
});
|
||||
|
||||
let wg_config = format!(
|
||||
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
wg_priv, assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"entry": entry_json,
|
||||
"smartvpnConfig": smartvpn_config,
|
||||
"wireguardConfig": wg_config,
|
||||
"secrets": {
|
||||
"noisePrivateKey": noise_priv,
|
||||
"wgPrivateKey": wg_priv,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Export a client config (without secrets) in the specified format.
|
||||
pub async fn export_client_config(&self, client_id: &str, format: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let registry = state.client_registry.read().await;
|
||||
let entry = registry.get_by_id(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
|
||||
match format {
|
||||
"smartvpn" => {
|
||||
Ok(serde_json::json!({
|
||||
"config": {
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPublicKey": entry.public_key,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
}
|
||||
}))
|
||||
}
|
||||
"wireguard" => {
|
||||
let assigned_ip = entry.assigned_ip.as_deref().unwrap_or("0.0.0.0");
|
||||
let config = format!(
|
||||
"[Interface]\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
Ok(serde_json::json!({ "config": config }))
|
||||
}
|
||||
_ => anyhow::bail!("Unknown format: {}", format),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_listener(
|
||||
/// WebSocket listener — accepts TCP connections, optionally parses PROXY protocol v2,
|
||||
/// upgrades to WS, then hands off to `handle_client_connection`.
|
||||
async fn run_ws_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
shutdown_rx: &mut mpsc::Receiver<()>,
|
||||
@@ -251,12 +583,58 @@ async fn run_listener(
|
||||
tokio::select! {
|
||||
accept = listener.accept() => {
|
||||
match accept {
|
||||
Ok((stream, addr)) => {
|
||||
info!("New connection from {}", addr);
|
||||
Ok((mut tcp_stream, tcp_addr)) => {
|
||||
info!("New connection from {}", tcp_addr);
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_client_connection(state, stream).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
// Phase 0: Parse PROXY protocol v2 header if enabled
|
||||
let remote_addr = if state.config.proxy_protocol.unwrap_or(false) {
|
||||
match crate::proxy_protocol::read_proxy_header(&mut tcp_stream).await {
|
||||
Ok(header) if header.is_local => {
|
||||
info!("PP v2 LOCAL probe from {}", tcp_addr);
|
||||
return; // Health check — close gracefully
|
||||
}
|
||||
Ok(header) => {
|
||||
info!("PP v2: real client {} (via {})", header.src_addr, tcp_addr);
|
||||
Some(header.src_addr)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("PP v2 parse failed from {}: {}", tcp_addr, e);
|
||||
return; // Drop connection
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Some(tcp_addr) // Direct connection — use TCP SocketAddr
|
||||
};
|
||||
|
||||
// Phase 1: Server-level connection IP block list (pre-handshake)
|
||||
if let (Some(ref block_list), Some(ref addr)) = (&state.config.connection_ip_block_list, &remote_addr) {
|
||||
if !block_list.is_empty() {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, block_list) {
|
||||
warn!("Connection blocked by server IP block list: {}", addr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: WebSocket upgrade + VPN handshake
|
||||
match transport::accept_connection(tcp_stream).await {
|
||||
Ok(ws) => {
|
||||
let (sink, stream) = transport_trait::split_ws(ws);
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
remote_addr,
|
||||
).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("WebSocket upgrade failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -275,29 +653,110 @@ async fn run_listener(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic
|
||||
/// `handle_client_connection`.
|
||||
async fn run_quic_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
idle_timeout_secs: u64,
|
||||
shutdown_rx: &mut mpsc::Receiver<()>,
|
||||
) -> Result<()> {
|
||||
// Generate or use configured TLS certificate for QUIC
|
||||
let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) =
|
||||
(&state.config.tls_cert, &state.config.tls_key)
|
||||
{
|
||||
// Parse PEM certificates
|
||||
let certs: Vec<rustls_pki_types::CertificateDer<'static>> =
|
||||
rustls_pemfile::certs(&mut cert_pem.as_bytes())
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
|
||||
.ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
|
||||
(certs, key)
|
||||
} else {
|
||||
// Generate self-signed certificate
|
||||
let (certs, key) = quic_transport::generate_self_signed_cert()?;
|
||||
info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0]));
|
||||
(certs, key)
|
||||
};
|
||||
|
||||
let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig {
|
||||
listen_addr,
|
||||
cert_chain,
|
||||
private_key,
|
||||
idle_timeout_secs,
|
||||
})?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
incoming = endpoint.accept() => {
|
||||
match incoming {
|
||||
Some(incoming) => {
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
match incoming.await {
|
||||
Ok(conn) => {
|
||||
let remote = conn.remote_address();
|
||||
info!("New QUIC connection from {}", remote);
|
||||
match quic_transport::accept_quic_connection(conn).await {
|
||||
Ok((sink, stream)) => {
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
Some(remote),
|
||||
).await {
|
||||
warn!("QUIC client error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("QUIC stream accept failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("QUIC handshake failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
None => {
|
||||
info!("QUIC endpoint closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("QUIC shutdown signal received");
|
||||
endpoint.close(0u32.into(), b"shutdown");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transport-agnostic client handler. Performs the Noise IK handshake, authenticates
|
||||
/// the client against the registry, and runs the main packet forwarding loop.
|
||||
async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
stream: tokio::net::TcpStream,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
remote_addr: Option<std::net::SocketAddr>,
|
||||
) -> Result<()> {
|
||||
let ws = transport::accept_connection(stream).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
|
||||
let client_id = uuid_v4();
|
||||
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
let server_private_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&state.config.private_key,
|
||||
)?;
|
||||
|
||||
// Noise IK handshake (server side = responder)
|
||||
let mut responder = crypto::create_responder(&server_private_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// Receive handshake init
|
||||
let init_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected handshake init message"),
|
||||
// Receive handshake init (-> e, es, s, ss)
|
||||
let init_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before handshake"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&init_msg[..]);
|
||||
@@ -309,6 +768,47 @@ async fn handle_client_connection(
|
||||
}
|
||||
|
||||
responder.read_message(&frame.payload, &mut buf)?;
|
||||
|
||||
// Extract client's static public key BEFORE entering transport mode
|
||||
let client_pub_key_bytes = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake: no client static key received"))?
|
||||
.to_vec();
|
||||
let client_pub_key_b64 = base64::Engine::encode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&client_pub_key_bytes,
|
||||
);
|
||||
|
||||
// Verify client against registry
|
||||
let (registered_client_id, client_security) = {
|
||||
let registry = state.client_registry.read().await;
|
||||
if !registry.is_authorized(&client_pub_key_b64) {
|
||||
warn!("Rejecting unauthorized client with key {}", &client_pub_key_b64[..8]);
|
||||
// Send handshake response but then disconnect
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let response_frame = Frame {
|
||||
packet_type: PacketType::HandshakeResp,
|
||||
payload: buf[..len].to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// Send disconnect frame
|
||||
let disconnect_frame = Frame {
|
||||
packet_type: PacketType::Disconnect,
|
||||
payload: Vec::new(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
anyhow::bail!("Client not authorized");
|
||||
}
|
||||
let entry = registry.get_by_key(&client_pub_key_b64).unwrap();
|
||||
(entry.client_id.clone(), entry.security.clone())
|
||||
};
|
||||
|
||||
// Complete handshake (<- e, ee, se)
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let response_payload = buf[..len].to_vec();
|
||||
|
||||
@@ -318,13 +818,46 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
let mut noise_transport = responder.into_transport_mode()?;
|
||||
|
||||
// Register client
|
||||
let default_rate = state.config.default_rate_limit_bytes_per_sec;
|
||||
let default_burst = state.config.default_burst_bytes;
|
||||
// Connection-level ACL: check real client IP against per-client ipAllowList/ipBlockList
|
||||
if let (Some(ref sec), Some(ref addr)) = (&client_security, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if !acl::is_source_allowed(
|
||||
v4,
|
||||
sec.ip_allow_list.as_deref(),
|
||||
sec.ip_block_list.as_deref(),
|
||||
) {
|
||||
warn!("Connection-level ACL denied client {} from IP {}", registered_client_id, addr);
|
||||
let disconnect_frame = Frame { packet_type: PacketType::Disconnect, payload: Vec::new() };
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
anyhow::bail!("Connection denied: source IP {} not allowed for client {}", addr, registered_client_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the registered client ID as the connection ID
|
||||
let client_id = registered_client_id.clone();
|
||||
|
||||
// Allocate IP
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
// Determine rate limits: per-client security overrides server defaults
|
||||
let (rate_limit, burst) = if let Some(ref sec) = client_security {
|
||||
if let Some(ref rl) = sec.rate_limit {
|
||||
(Some(rl.bytes_per_sec), Some(rl.burst_bytes))
|
||||
} else {
|
||||
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
|
||||
}
|
||||
} else {
|
||||
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
|
||||
};
|
||||
|
||||
// Register connected client
|
||||
let client_info = ClientInfo {
|
||||
client_id: client_id.clone(),
|
||||
assigned_ip: assigned_ip.to_string(),
|
||||
@@ -335,13 +868,16 @@ async fn handle_client_connection(
|
||||
bytes_dropped: 0,
|
||||
last_keepalive_at: None,
|
||||
keepalives_received: 0,
|
||||
rate_limit_bytes_per_sec: default_rate,
|
||||
burst_bytes: default_burst,
|
||||
rate_limit_bytes_per_sec: rate_limit,
|
||||
burst_bytes: burst,
|
||||
authenticated_key: client_pub_key_b64.clone(),
|
||||
registered_client_id: registered_client_id.clone(),
|
||||
remote_addr: remote_addr.map(|a| a.to_string()),
|
||||
};
|
||||
state.clients.write().await.insert(client_id.clone(), client_info);
|
||||
|
||||
// Set up rate limiter if defaults are configured
|
||||
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
|
||||
// Set up rate limiter
|
||||
if let (Some(rate), Some(burst)) = (rate_limit, burst) {
|
||||
state
|
||||
.rate_limiters
|
||||
.lock()
|
||||
@@ -369,25 +905,45 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
info!("Client {} connected with IP {}", client_id, assigned_ip);
|
||||
info!("Client {} ({}) connected with IP {} from {}",
|
||||
registered_client_id, &client_pub_key_b64[..8], assigned_ip,
|
||||
remote_addr.map(|a| a.to_string()).unwrap_or_else(|| "unknown".to_string()));
|
||||
|
||||
// Main packet loop with dead-peer detection
|
||||
let mut last_activity = tokio::time::Instant::now();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
Ok(Some(data)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
// ACL check on decrypted packet
|
||||
if let Some(ref sec) = client_security {
|
||||
if len >= 20 {
|
||||
// Extract src/dst from IPv4 header
|
||||
let src = Ipv4Addr::new(buf[12], buf[13], buf[14], buf[15]);
|
||||
let dst = Ipv4Addr::new(buf[16], buf[17], buf[18], buf[19]);
|
||||
let acl_result = acl::check_acl(sec, src, dst);
|
||||
if acl_result != acl::AclResult::Allow {
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.packets_dropped += 1;
|
||||
info.bytes_dropped += len as u64;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting check
|
||||
let allowed = {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
@@ -432,7 +988,7 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
@@ -463,20 +1019,12 @@ async fn handle_client_connection(
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
Ok(None) => {
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
ws_sink.send(Message::Pong(data)).await?;
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
continue;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("WebSocket error from {}: {}", client_id, e);
|
||||
Err(e) => {
|
||||
warn!("Transport error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -497,20 +1045,6 @@ async fn handle_client_connection(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn uuid_v4() -> String {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let bytes: [u8; 16] = rng.gen();
|
||||
format!(
|
||||
"{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
|
||||
bytes[0], bytes[1], bytes[2], bytes[3],
|
||||
bytes[4], bytes[5],
|
||||
bytes[6], bytes[7],
|
||||
bytes[8], bytes[9],
|
||||
bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
|
||||
)
|
||||
}
|
||||
|
||||
fn timestamp_now() -> String {
|
||||
use std::time::SystemTime;
|
||||
let duration = SystemTime::now()
|
||||
|
||||
116
rust/src/transport_trait.rs
Normal file
116
rust/src/transport_trait.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
use crate::transport::WsStream;
|
||||
|
||||
// ============================================================================
|
||||
// Transport trait abstraction
|
||||
// ============================================================================
|
||||
|
||||
/// Outbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportSink: Send + 'static {
|
||||
/// Send a framed binary message on the reliable channel.
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Send a datagram (unreliable, best-effort).
|
||||
/// Falls back to reliable if the transport does not support datagrams.
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Gracefully close the transport.
|
||||
async fn close(&mut self) -> Result<()>;
|
||||
}
|
||||
|
||||
/// Inbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportStream: Send + 'static {
|
||||
/// Receive the next reliable binary message. Returns `None` on close.
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Receive the next datagram. Returns `None` if datagrams are unsupported
|
||||
/// or the connection is closed.
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Whether this transport supports unreliable datagrams.
|
||||
fn supports_datagrams(&self) -> bool;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket implementation
|
||||
// ============================================================================
|
||||
|
||||
/// WebSocket transport sink (wraps the write half of a split WsStream).
|
||||
pub struct WsTransportSink {
|
||||
inner: futures_util::stream::SplitSink<WsStream, Message>,
|
||||
}
|
||||
|
||||
impl WsTransportSink {
|
||||
pub fn new(inner: futures_util::stream::SplitSink<WsStream, Message>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for WsTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
self.inner.send(Message::Binary(data.into())).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// WebSocket has no datagram support — fall back to reliable.
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.inner.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket transport stream (wraps the read half of a split WsStream).
|
||||
pub struct WsTransportStream {
|
||||
inner: futures_util::stream::SplitStream<WsStream>,
|
||||
}
|
||||
|
||||
impl WsTransportStream {
|
||||
pub fn new(inner: futures_util::stream::SplitStream<WsStream>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for WsTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
loop {
|
||||
match self.inner.next().await {
|
||||
Some(Ok(Message::Binary(data))) => return Ok(Some(data.to_vec())),
|
||||
Some(Ok(Message::Close(_))) | None => return Ok(None),
|
||||
Some(Ok(Message::Ping(_))) => {
|
||||
// Ping handling is done at the tungstenite layer automatically
|
||||
// when the sink side is alive. Just skip here.
|
||||
continue;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => return Err(anyhow::anyhow!("WebSocket error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// WebSocket does not support datagrams.
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a WebSocket stream into transport sink and stream halves.
|
||||
pub fn split_ws(ws: WsStream) -> (WsTransportSink, WsTransportStream) {
|
||||
let (sink, stream) = ws.split();
|
||||
(WsTransportSink::new(sink), WsTransportStream::new(stream))
|
||||
}
|
||||
1329
rust/src/wireguard.rs
Normal file
1329
rust/src/wireguard.rs
Normal file
File diff suppressed because it is too large
Load Diff
320
rust/tests/wg_e2e.rs
Normal file
320
rust/tests/wg_e2e.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! End-to-end WireGuard protocol tests over real UDP sockets.
|
||||
//!
|
||||
//! Entirely userspace — no root, no TUN devices.
|
||||
//! Two boringtun `Tunn` instances exchange real WireGuard packets
|
||||
//! over loopback UDP, validating handshake, encryption, and data flow.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::time;
|
||||
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
|
||||
use smartvpn_daemon::wireguard::generate_wg_keypair;
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
fn parse_key_pair(pub_b64: &str, priv_b64: &str) -> (PublicKey, StaticSecret) {
|
||||
let pub_bytes: [u8; 32] = BASE64.decode(pub_b64).unwrap().try_into().unwrap();
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
(PublicKey::from(pub_bytes), StaticSecret::from(priv_bytes))
|
||||
}
|
||||
|
||||
fn clone_secret(priv_b64: &str) -> StaticSecret {
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
StaticSecret::from(priv_bytes)
|
||||
}
|
||||
|
||||
fn make_ipv4_packet(src: Ipv4Addr, dst: Ipv4Addr, payload: &[u8]) -> Vec<u8> {
|
||||
let total_len = 20 + payload.len();
|
||||
let mut pkt = vec![0u8; total_len];
|
||||
pkt[0] = 0x45;
|
||||
pkt[2] = (total_len >> 8) as u8;
|
||||
pkt[3] = total_len as u8;
|
||||
pkt[9] = 0x11;
|
||||
pkt[12..16].copy_from_slice(&src.octets());
|
||||
pkt[16..20].copy_from_slice(&dst.octets());
|
||||
pkt[20..].copy_from_slice(payload);
|
||||
pkt
|
||||
}
|
||||
|
||||
/// Send any WriteToNetwork result, then drain the tunn for more packets.
|
||||
async fn send_and_drain(
|
||||
tunn: &mut Tunn,
|
||||
pkt: &[u8],
|
||||
socket: &UdpSocket,
|
||||
peer: SocketAddr,
|
||||
) {
|
||||
socket.send_to(pkt, peer).await.unwrap();
|
||||
let mut drain_buf = vec![0u8; 2048];
|
||||
loop {
|
||||
match tunn.decapsulate(None, &[], &mut drain_buf) {
|
||||
TunnResult::WriteToNetwork(p) => { socket.send_to(p, peer).await.unwrap(); }
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to receive a UDP packet and decapsulate it. Returns decrypted IP data if any.
|
||||
async fn try_recv_decap(
|
||||
tunn: &mut Tunn,
|
||||
socket: &UdpSocket,
|
||||
timeout_ms: u64,
|
||||
) -> Option<(Vec<u8>, Ipv4Addr, SocketAddr)> {
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
let (n, src_addr) = match time::timeout(
|
||||
Duration::from_millis(timeout_ms),
|
||||
socket.recv_from(&mut recv_buf),
|
||||
).await {
|
||||
Ok(Ok(r)) => r,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let result = tunn.decapsulate(Some(src_addr.ip()), &recv_buf[..n], &mut dst_buf);
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(tunn, pkt, socket, src_addr).await;
|
||||
None
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(pkt, addr) => Some((pkt.to_vec(), addr, src_addr)),
|
||||
TunnResult::WriteToTunnelV6(_, _) => None,
|
||||
TunnResult::Done => None,
|
||||
TunnResult::Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive the full WireGuard handshake between client and server over real UDP.
|
||||
async fn do_handshake(
|
||||
client_tunn: &mut Tunn,
|
||||
server_tunn: &mut Tunn,
|
||||
client_socket: &UdpSocket,
|
||||
server_socket: &UdpSocket,
|
||||
server_addr: SocketAddr,
|
||||
) {
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
// Step 1: Client initiates handshake
|
||||
match client_tunn.encapsulate(&[], &mut buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => panic!("Expected handshake init"),
|
||||
}
|
||||
|
||||
// Step 2: Server receives init → sends response
|
||||
let (n, client_from) = server_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match server_tunn.decapsulate(Some(client_from.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(server_tunn, pkt, server_socket, client_from).await;
|
||||
}
|
||||
other => panic!("Expected WriteToNetwork from server, got variant {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Step 3: Client receives response
|
||||
let (n, _) = client_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match client_tunn.decapsulate(Some(server_addr.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(client_tunn, pkt, client_socket, server_addr).await;
|
||||
}
|
||||
TunnResult::Done => {}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Step 4: Process any remaining handshake packets
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 200).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 100).await;
|
||||
|
||||
// Step 5: Timer ticks to settle
|
||||
for _ in 0..3 {
|
||||
match server_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
server_socket.send_to(pkt, client_from).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
match client_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 50).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 50).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn variant_name(r: &TunnResult) -> &'static str {
|
||||
match r {
|
||||
TunnResult::Done => "Done",
|
||||
TunnResult::Err(_) => "Err",
|
||||
TunnResult::WriteToNetwork(_) => "WriteToNetwork",
|
||||
TunnResult::WriteToTunnelV4(_, _) => "WriteToTunnelV4",
|
||||
TunnResult::WriteToTunnelV6(_, _) => "WriteToTunnelV6",
|
||||
}
|
||||
}
|
||||
|
||||
/// Encapsulate an IP packet and send it, then loop-receive on the other side until decrypted.
|
||||
async fn send_and_expect_data(
|
||||
sender_tunn: &mut Tunn,
|
||||
receiver_tunn: &mut Tunn,
|
||||
sender_socket: &UdpSocket,
|
||||
receiver_socket: &UdpSocket,
|
||||
dest_addr: SocketAddr,
|
||||
ip_packet: &[u8],
|
||||
) -> (Vec<u8>, Ipv4Addr) {
|
||||
let mut enc_buf = vec![0u8; 65536];
|
||||
|
||||
match sender_tunn.encapsulate(ip_packet, &mut enc_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
sender_socket.send_to(pkt, dest_addr).await.unwrap();
|
||||
}
|
||||
TunnResult::Err(e) => panic!("Encapsulate failed: {:?}", e),
|
||||
other => panic!("Expected WriteToNetwork, got {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Receive — may need a few rounds for control packets
|
||||
for _ in 0..10 {
|
||||
if let Some((data, addr, _)) = try_recv_decap(receiver_tunn, receiver_socket, 1000).await {
|
||||
return (data, addr);
|
||||
}
|
||||
}
|
||||
panic!("Did not receive decrypted IP packet");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 1: Single client ↔ server bidirectional data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_single_client_bidirectional() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
let client_addr = client_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, None, None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, None, None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
// Client → Server
|
||||
let pkt_c2s = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"Hello from client!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt_c2s,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt_c2s.len()], &pkt_c2s[..]);
|
||||
|
||||
// Server → Client
|
||||
let pkt_s2c = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 1), Ipv4Addr::new(10, 0, 0, 2), b"Hello from server!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut server_tunn, &mut client_tunn,
|
||||
&server_socket, &client_socket,
|
||||
client_addr, &pkt_s2c,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 1));
|
||||
assert_eq!(&decrypted[..pkt_s2c.len()], &pkt_s2c[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 2: Two clients ↔ one server (peer routing)
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_two_clients_peer_routing() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client1_pub_b64, client1_priv_b64) = generate_wg_keypair();
|
||||
let (client2_pub_b64, client2_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, _) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client1_public, client1_secret) = parse_key_pair(&client1_pub_b64, &client1_priv_b64);
|
||||
let (client2_public, client2_secret) = parse_key_pair(&client2_pub_b64, &client2_priv_b64);
|
||||
|
||||
// Separate server socket per peer to avoid UDP mux complexity in test
|
||||
let server_socket_1 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_socket_2 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client1_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client2_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr_1 = server_socket_1.local_addr().unwrap();
|
||||
let server_addr_2 = server_socket_2.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn_1 = Tunn::new(clone_secret(&server_priv_b64), client1_public, None, None, 0, None);
|
||||
let mut server_tunn_2 = Tunn::new(clone_secret(&server_priv_b64), client2_public, None, None, 1, None);
|
||||
let mut client1_tunn = Tunn::new(client1_secret, server_public.clone(), None, None, 2, None);
|
||||
let mut client2_tunn = Tunn::new(client2_secret, server_public, None, None, 3, None);
|
||||
|
||||
do_handshake(&mut client1_tunn, &mut server_tunn_1, &client1_socket, &server_socket_1, server_addr_1).await;
|
||||
do_handshake(&mut client2_tunn, &mut server_tunn_2, &client2_socket, &server_socket_2, server_addr_2).await;
|
||||
|
||||
// Client 1 → Server
|
||||
let pkt1 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"From client 1");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client1_tunn, &mut server_tunn_1,
|
||||
&client1_socket, &server_socket_1,
|
||||
server_addr_1, &pkt1,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt1.len()], &pkt1[..]);
|
||||
|
||||
// Client 2 → Server
|
||||
let pkt2 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 3), Ipv4Addr::new(10, 0, 0, 1), b"From client 2");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client2_tunn, &mut server_tunn_2,
|
||||
&client2_socket, &server_socket_2,
|
||||
server_addr_2, &pkt2,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 3));
|
||||
assert_eq!(&decrypted[..pkt2.len()], &pkt2[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 3: Preshared key handshake + data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_preshared_key() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let psk: [u8; 32] = rand::random();
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, Some(psk), None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, Some(psk), None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
let pkt = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"PSK-protected data");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt.len()], &pkt[..]);
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
@@ -40,7 +40,9 @@ let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let client: VpnClient;
|
||||
let clientBundle: IClientConfigBundle;
|
||||
const extraClients: VpnClient[] = [];
|
||||
const extraBundles: IClientConfigBundle[] = [];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
@@ -64,7 +66,7 @@ tap.test('setup: start VPN server', async () => {
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
|
||||
// Phase 3: start the VPN listener
|
||||
// Phase 3: start the VPN listener (empty clients, will use createClient at runtime)
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
@@ -76,6 +78,11 @@ tap.test('setup: start VPN server', async () => {
|
||||
// Verify server is now running
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Phase 4: create the first client via the hub
|
||||
clientBundle = await server.createClient({ clientId: 'test-client-0' });
|
||||
expect(clientBundle.secrets.noisePrivateKey).toBeTypeofString();
|
||||
expect(clientBundle.smartvpnConfig.clientPublicKey).toBeTypeofString();
|
||||
});
|
||||
|
||||
tap.test('single client connects and gets IP', async () => {
|
||||
@@ -89,6 +96,8 @@ tap.test('single client connects and gets IP', async () => {
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: clientBundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: clientBundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
@@ -175,11 +184,15 @@ tap.test('5 concurrent clients', async () => {
|
||||
assignedIps.add(existingClients[0].assignedIp);
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const bundle = await server.createClient({ clientId: `test-client-${i + 1}` });
|
||||
extraBundles.push(bundle);
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
const result = await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
@@ -144,12 +144,17 @@ let keypair: IVpnKeypair;
|
||||
let throttle: ThrottleProxy;
|
||||
const allClients: VpnClient[] = [];
|
||||
|
||||
let clientCounter = 0;
|
||||
async function createConnectedClient(port: number): Promise<VpnClient> {
|
||||
clientCounter++;
|
||||
const bundle = await server.createClient({ clientId: `load-client-${clientCounter}` });
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${port}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
allClients.push(c);
|
||||
|
||||
258
test/test.quic.node.ts
Normal file
258
test/test.quic.node.ts
Normal file
@@ -0,0 +1,258 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as dgram from 'dgram';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async function findFreePort(): Promise<number> {
|
||||
const server = net.createServer();
|
||||
await new Promise<void>((resolve) => server.listen(0, '127.0.0.1', resolve));
|
||||
const port = (server.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => server.close(() => resolve()));
|
||||
return port;
|
||||
}
|
||||
|
||||
async function findFreeUdpPort(): Promise<number> {
|
||||
const sock = dgram.createSocket('udp4');
|
||||
await new Promise<void>((resolve) => sock.bind(0, '127.0.0.1', resolve));
|
||||
const port = (sock.address() as net.AddressInfo).port;
|
||||
await new Promise<void>((resolve) => sock.close(resolve));
|
||||
return port;
|
||||
}
|
||||
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
fn: () => Promise<boolean>,
|
||||
timeoutMs: number = 10000,
|
||||
pollMs: number = 500,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await fn()) return;
|
||||
await delay(pollMs);
|
||||
}
|
||||
throw new Error(`waitFor timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let server: VpnServer;
|
||||
let wsPort: number;
|
||||
let quicPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: QUIC-only server + QUIC client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tap.test('setup: start VPN server in QUIC mode', async () => {
|
||||
quicPort = await findFreeUdpPort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
keypair = await server.generateKeypair();
|
||||
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${quicPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
publicKey: keypair.publicKey,
|
||||
subnet: '10.9.0.0/24',
|
||||
transportMode: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
};
|
||||
await server['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('QUIC client connects and gets IP', async () => {
|
||||
const bundle = await server.createClient({ clientId: 'quic-client-1' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${quicPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.9.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
// Verify server sees the client
|
||||
await waitFor(async () => {
|
||||
const clients = await server.listClients();
|
||||
return clients.length >= 1;
|
||||
});
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('teardown: stop QUIC server', async () => {
|
||||
await server.stop();
|
||||
await delay(500);
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: dual-mode server (both) + auto client
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
let dualServer: VpnServer;
|
||||
let dualWsPort: number;
|
||||
let dualQuicPort: number;
|
||||
let dualKeypair: IVpnKeypair;
|
||||
|
||||
tap.test('setup: start VPN server in both mode', async () => {
|
||||
dualWsPort = await findFreePort();
|
||||
dualQuicPort = await findFreeUdpPort();
|
||||
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
dualServer = new VpnServer(options);
|
||||
|
||||
const started = await dualServer['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
dualKeypair = await dualServer.generateKeypair();
|
||||
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${dualWsPort}`,
|
||||
privateKey: dualKeypair.privateKey,
|
||||
publicKey: dualKeypair.publicKey,
|
||||
subnet: '10.10.0.0/24',
|
||||
transportMode: 'both',
|
||||
quicListenAddr: `127.0.0.1:${dualQuicPort}`,
|
||||
keepaliveIntervalSecs: 3,
|
||||
};
|
||||
await dualServer['bridge'].sendCommand('start', { config: serverConfig });
|
||||
|
||||
const status = await dualServer.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
});
|
||||
|
||||
tap.test('auto client connects to dual-mode server (QUIC preferred)', async () => {
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-auto-client' });
|
||||
|
||||
// "auto" mode (default): tries QUIC first at same host:port, falls back to WS
|
||||
// Since the WS port and QUIC port differ, auto will try QUIC on WS port (fail),
|
||||
// then fall back to WebSocket
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${dualWsPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
// transport defaults to 'auto'
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.10.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
await waitFor(async () => {
|
||||
const clients = await dualServer.listClients();
|
||||
return clients.length >= 1;
|
||||
});
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('explicit QUIC client connects to dual-mode server', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-quic-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
expect(result.assignedIp).toBeTypeofString();
|
||||
expect(result.assignedIp).toStartWith('10.10.0.');
|
||||
|
||||
const clientStatus = await client.getStatus();
|
||||
expect(clientStatus.state).toEqual('connected');
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange over QUIC', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-keepalive-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
const client = new VpnClient(options);
|
||||
await client.start();
|
||||
|
||||
await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
// Wait for keepalive exchange
|
||||
await delay(8000);
|
||||
|
||||
const clientStats = await client.getStatistics();
|
||||
expect(clientStats.keepalivesSent).toBeGreaterThanOrEqual(1);
|
||||
expect(clientStats.keepalivesReceived).toBeGreaterThanOrEqual(1);
|
||||
|
||||
await client.stop();
|
||||
});
|
||||
|
||||
tap.test('teardown: stop dual-mode server', async () => {
|
||||
await dualServer.stop();
|
||||
await delay(500);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -2,10 +2,17 @@ import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { VpnConfig } from '../ts/index.js';
|
||||
import type { IVpnClientConfig, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// Valid 32-byte base64 keys for testing
|
||||
const TEST_KEY_A = 'YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE=';
|
||||
const TEST_KEY_B = 'YmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmI=';
|
||||
const TEST_KEY_C = 'Y2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2M=';
|
||||
|
||||
tap.test('VpnConfig: validate valid client config', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
@@ -16,7 +23,9 @@ tap.test('VpnConfig: validate valid client config', async () => {
|
||||
|
||||
tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
const config = {
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -31,7 +40,9 @@ tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'http://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -43,10 +54,28 @@ tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config without clientPrivateKey', async () => {
|
||||
const config = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('clientPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
mtu: 100,
|
||||
};
|
||||
let threw = false;
|
||||
@@ -62,7 +91,9 @@ tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['not-an-ip'],
|
||||
};
|
||||
let threw = false;
|
||||
@@ -78,12 +109,15 @@ tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
tap.test('VpnConfig: validate valid server config', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
enableNat: true,
|
||||
clients: [
|
||||
{ clientId: 'test-client', publicKey: TEST_KEY_C },
|
||||
],
|
||||
};
|
||||
// Should not throw
|
||||
VpnConfig.validateServerConfig(config);
|
||||
@@ -92,8 +126,8 @@ tap.test('VpnConfig: validate valid server config', async () => {
|
||||
tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: 'invalid',
|
||||
};
|
||||
let threw = false;
|
||||
@@ -109,7 +143,7 @@ tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
const config = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
} as IVpnServerConfig;
|
||||
let threw = false;
|
||||
@@ -122,4 +156,24 @@ tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject server config with invalid client publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
clients: [
|
||||
{ clientId: 'bad-client', publicKey: 'short-key' },
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
|
||||
353
test/test.wireguard.node.ts
Normal file
353
test/test.wireguard.node.ts
Normal file
@@ -0,0 +1,353 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import {
|
||||
VpnConfig,
|
||||
VpnServer,
|
||||
WgConfigGenerator,
|
||||
} from '../ts/index.js';
|
||||
import type {
|
||||
IVpnClientConfig,
|
||||
IVpnServerConfig,
|
||||
IVpnServerOptions,
|
||||
IWgPeerConfig,
|
||||
} from '../ts/index.js';
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config validation — client
|
||||
// ============================================================================
|
||||
|
||||
// A valid 32-byte key in base64 (44 chars)
|
||||
const VALID_KEY = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=';
|
||||
const VALID_KEY_2 = 'BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB=';
|
||||
|
||||
tap.test('WG client config: valid wireguard config passes validation', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '', // not needed for WG
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
wgAllowedIps: ['0.0.0.0/0'],
|
||||
};
|
||||
VpnConfig.validateClientConfig(config);
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgPrivateKey', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgAddress', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgAddress');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects missing wgEndpoint', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgEndpoint');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects invalid key length', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: 'tooshort',
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('44 characters');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG client config: rejects invalid CIDR in allowedIps', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: '',
|
||||
serverPublicKey: VALID_KEY,
|
||||
transport: 'wireguard',
|
||||
wgPrivateKey: VALID_KEY_2,
|
||||
wgAddress: '10.8.0.2',
|
||||
wgEndpoint: 'vpn.example.com:51820',
|
||||
wgAllowedIps: ['not-a-cidr'],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('CIDR');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config validation — server
|
||||
// ============================================================================
|
||||
|
||||
tap.test('WG server config: valid config passes validation', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: VALID_KEY_2,
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
VpnConfig.validateServerConfig(config);
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects empty wgPeers', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgPeers');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects peer without publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: '',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG server config: rejects invalid wgListenPort', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '',
|
||||
privateKey: VALID_KEY,
|
||||
publicKey: VALID_KEY_2,
|
||||
subnet: '10.8.0.0/24',
|
||||
transportMode: 'wireguard',
|
||||
wgListenPort: 0,
|
||||
wgPeers: [
|
||||
{
|
||||
publicKey: VALID_KEY_2,
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('wgListenPort');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard keypair generation via daemon
|
||||
// ============================================================================
|
||||
|
||||
let server: VpnServer;
|
||||
|
||||
tap.test('WG: spawn server daemon for keypair generation', async () => {
|
||||
const options: IVpnServerOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
server = new VpnServer(options);
|
||||
const started = await server['bridge'].start();
|
||||
expect(started).toBeTrue();
|
||||
expect(server.running).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('WG: generateWgKeypair returns valid keypair', async () => {
|
||||
const keypair = await server.generateWgKeypair();
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
// WireGuard keys: base64 of 32 bytes = 44 characters
|
||||
expect(keypair.publicKey.length).toEqual(44);
|
||||
expect(keypair.privateKey.length).toEqual(44);
|
||||
// Verify they decode to 32 bytes
|
||||
const pubBuf = Buffer.from(keypair.publicKey, 'base64');
|
||||
const privBuf = Buffer.from(keypair.privateKey, 'base64');
|
||||
expect(pubBuf.length).toEqual(32);
|
||||
expect(privBuf.length).toEqual(32);
|
||||
});
|
||||
|
||||
tap.test('WG: generateWgKeypair returns unique keys each time', async () => {
|
||||
const kp1 = await server.generateWgKeypair();
|
||||
const kp2 = await server.generateWgKeypair();
|
||||
expect(kp1.publicKey).not.toEqual(kp2.publicKey);
|
||||
expect(kp1.privateKey).not.toEqual(kp2.privateKey);
|
||||
});
|
||||
|
||||
tap.test('WG: stop server daemon', async () => {
|
||||
server.stop();
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
expect(server.running).toBeFalse();
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard config file generation
|
||||
// ============================================================================
|
||||
|
||||
tap.test('WgConfigGenerator: generate client config', async () => {
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: 'clientPrivateKeyBase64====================',
|
||||
address: '10.8.0.2/24',
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
peer: {
|
||||
publicKey: 'serverPublicKeyBase64====================',
|
||||
endpoint: 'vpn.example.com:51820',
|
||||
allowedIps: ['0.0.0.0/0', '::/0'],
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).toContain('PrivateKey = clientPrivateKeyBase64====================');
|
||||
expect(conf).toContain('Address = 10.8.0.2/24');
|
||||
expect(conf).toContain('DNS = 1.1.1.1, 8.8.8.8');
|
||||
expect(conf).toContain('MTU = 1420');
|
||||
expect(conf).toContain('[Peer]');
|
||||
expect(conf).toContain('PublicKey = serverPublicKeyBase64====================');
|
||||
expect(conf).toContain('Endpoint = vpn.example.com:51820');
|
||||
expect(conf).toContain('AllowedIPs = 0.0.0.0/0, ::/0');
|
||||
expect(conf).toContain('PersistentKeepalive = 25');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate client config without optional fields', async () => {
|
||||
const conf = WgConfigGenerator.generateClientConfig({
|
||||
privateKey: 'key1',
|
||||
address: '10.0.0.2/32',
|
||||
peer: {
|
||||
publicKey: 'key2',
|
||||
endpoint: 'server:51820',
|
||||
allowedIps: ['10.0.0.0/24'],
|
||||
},
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).not.toContain('DNS');
|
||||
expect(conf).not.toContain('MTU');
|
||||
expect(conf).not.toContain('PresharedKey');
|
||||
expect(conf).not.toContain('PersistentKeepalive');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate server config with NAT', async () => {
|
||||
const conf = WgConfigGenerator.generateServerConfig({
|
||||
privateKey: 'serverPrivKey',
|
||||
address: '10.8.0.1/24',
|
||||
listenPort: 51820,
|
||||
dns: ['1.1.1.1'],
|
||||
enableNat: true,
|
||||
natInterface: 'ens3',
|
||||
peers: [
|
||||
{
|
||||
publicKey: 'peer1PubKey',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
presharedKey: 'psk1',
|
||||
persistentKeepalive: 25,
|
||||
},
|
||||
{
|
||||
publicKey: 'peer2PubKey',
|
||||
allowedIps: ['10.8.0.3/32'],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(conf).toContain('[Interface]');
|
||||
expect(conf).toContain('ListenPort = 51820');
|
||||
expect(conf).toContain('PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ens3 -j MASQUERADE');
|
||||
expect(conf).toContain('PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ens3 -j MASQUERADE');
|
||||
// Two [Peer] sections
|
||||
const peerCount = (conf.match(/\[Peer\]/g) || []).length;
|
||||
expect(peerCount).toEqual(2);
|
||||
expect(conf).toContain('PresharedKey = psk1');
|
||||
});
|
||||
|
||||
tap.test('WgConfigGenerator: generate server config without NAT', async () => {
|
||||
const conf = WgConfigGenerator.generateServerConfig({
|
||||
privateKey: 'serverPrivKey',
|
||||
address: '10.8.0.1/24',
|
||||
listenPort: 51820,
|
||||
peers: [
|
||||
{
|
||||
publicKey: 'peerKey',
|
||||
allowedIps: ['10.8.0.2/32'],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(conf).not.toContain('PostUp');
|
||||
expect(conf).not.toContain('PostDown');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartvpn',
|
||||
version: '1.3.0',
|
||||
version: '1.9.0',
|
||||
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
|
||||
}
|
||||
|
||||
@@ -4,3 +4,4 @@ export { VpnClient } from './smartvpn.classes.vpnclient.js';
|
||||
export { VpnServer } from './smartvpn.classes.vpnserver.js';
|
||||
export { VpnConfig } from './smartvpn.classes.vpnconfig.js';
|
||||
export { VpnInstaller } from './smartvpn.classes.vpninstaller.js';
|
||||
export { WgConfigGenerator } from './smartvpn.classes.wgconfig.js';
|
||||
|
||||
@@ -12,14 +12,54 @@ export class VpnConfig {
|
||||
* Validate a client config object. Throws on invalid config.
|
||||
*/
|
||||
public static validateClientConfig(config: IVpnClientConfig): void {
|
||||
if (!config.serverUrl) {
|
||||
throw new Error('VpnConfig: serverUrl is required');
|
||||
}
|
||||
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
|
||||
throw new Error('VpnConfig: serverUrl must start with wss:// or ws://');
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
if (config.transport === 'wireguard') {
|
||||
// WireGuard-specific validation
|
||||
if (!config.wgPrivateKey) {
|
||||
throw new Error('VpnConfig: wgPrivateKey is required for WireGuard transport');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.wgPrivateKey, 'wgPrivateKey');
|
||||
if (!config.wgAddress) {
|
||||
throw new Error('VpnConfig: wgAddress is required for WireGuard transport');
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required for WireGuard transport');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.serverPublicKey, 'serverPublicKey');
|
||||
if (!config.wgEndpoint) {
|
||||
throw new Error('VpnConfig: wgEndpoint is required for WireGuard transport');
|
||||
}
|
||||
if (config.wgPresharedKey) {
|
||||
VpnConfig.validateBase64Key(config.wgPresharedKey, 'wgPresharedKey');
|
||||
}
|
||||
if (config.wgAllowedIps) {
|
||||
for (const cidr of config.wgAllowedIps) {
|
||||
if (!VpnConfig.isValidCidr(cidr)) {
|
||||
throw new Error(`VpnConfig: invalid allowedIp CIDR: ${cidr}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!config.serverUrl) {
|
||||
throw new Error('VpnConfig: serverUrl is required');
|
||||
}
|
||||
// For QUIC-only transport, serverUrl is a host:port address; for WebSocket/auto it must be ws:// or wss://
|
||||
if (config.transport !== 'quic') {
|
||||
if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) {
|
||||
throw new Error('VpnConfig: serverUrl must start with wss:// or ws:// (for WebSocket transport)');
|
||||
}
|
||||
}
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
}
|
||||
// Noise IK requires client keypair
|
||||
if (!config.clientPrivateKey) {
|
||||
throw new Error('VpnConfig: clientPrivateKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPrivateKey, 'clientPrivateKey');
|
||||
if (!config.clientPublicKey) {
|
||||
throw new Error('VpnConfig: clientPublicKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPublicKey, 'clientPublicKey');
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -40,20 +80,63 @@ export class VpnConfig {
|
||||
* Validate a server config object. Throws on invalid config.
|
||||
*/
|
||||
public static validateServerConfig(config: IVpnServerConfig): void {
|
||||
if (!config.listenAddr) {
|
||||
throw new Error('VpnConfig: listenAddr is required');
|
||||
}
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
if (!config.publicKey) {
|
||||
throw new Error('VpnConfig: publicKey is required');
|
||||
}
|
||||
if (!config.subnet) {
|
||||
throw new Error('VpnConfig: subnet is required');
|
||||
}
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
if (config.transportMode === 'wireguard') {
|
||||
// WireGuard server validation
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.privateKey, 'privateKey');
|
||||
if (!config.wgPeers || config.wgPeers.length === 0) {
|
||||
throw new Error('VpnConfig: at least one wgPeers entry is required for WireGuard mode');
|
||||
}
|
||||
for (const peer of config.wgPeers) {
|
||||
if (!peer.publicKey) {
|
||||
throw new Error('VpnConfig: peer publicKey is required');
|
||||
}
|
||||
VpnConfig.validateBase64Key(peer.publicKey, 'peer.publicKey');
|
||||
if (!peer.allowedIps || peer.allowedIps.length === 0) {
|
||||
throw new Error('VpnConfig: peer allowedIps is required');
|
||||
}
|
||||
for (const cidr of peer.allowedIps) {
|
||||
if (!VpnConfig.isValidCidr(cidr)) {
|
||||
throw new Error(`VpnConfig: invalid peer allowedIp CIDR: ${cidr}`);
|
||||
}
|
||||
}
|
||||
if (peer.presharedKey) {
|
||||
VpnConfig.validateBase64Key(peer.presharedKey, 'peer.presharedKey');
|
||||
}
|
||||
}
|
||||
if (config.wgListenPort !== undefined && (config.wgListenPort < 1 || config.wgListenPort > 65535)) {
|
||||
throw new Error('VpnConfig: wgListenPort must be between 1 and 65535');
|
||||
}
|
||||
} else {
|
||||
if (!config.listenAddr) {
|
||||
throw new Error('VpnConfig: listenAddr is required');
|
||||
}
|
||||
if (!config.privateKey) {
|
||||
throw new Error('VpnConfig: privateKey is required');
|
||||
}
|
||||
if (!config.publicKey) {
|
||||
throw new Error('VpnConfig: publicKey is required');
|
||||
}
|
||||
if (!config.subnet) {
|
||||
throw new Error('VpnConfig: subnet is required');
|
||||
}
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
}
|
||||
// Validate client entries if provided
|
||||
if (config.clients) {
|
||||
for (const client of config.clients) {
|
||||
if (!client.clientId) {
|
||||
throw new Error('VpnConfig: client entry must have a clientId');
|
||||
}
|
||||
if (!client.publicKey) {
|
||||
throw new Error(`VpnConfig: client '${client.clientId}' must have a publicKey`);
|
||||
}
|
||||
VpnConfig.validateBase64Key(client.publicKey, `client '${client.clientId}' publicKey`);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -101,4 +184,41 @@ export class VpnConfig {
|
||||
const prefixNum = parseInt(prefix, 10);
|
||||
return !isNaN(prefixNum) && prefixNum >= 0 && prefixNum <= 32;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a CIDR string (IPv4 or IPv6).
|
||||
*/
|
||||
private static isValidCidr(cidr: string): boolean {
|
||||
const parts = cidr.split('/');
|
||||
if (parts.length !== 2) return false;
|
||||
const prefixNum = parseInt(parts[1], 10);
|
||||
if (isNaN(prefixNum) || prefixNum < 0) return false;
|
||||
// IPv4
|
||||
if (VpnConfig.isValidIp(parts[0])) {
|
||||
return prefixNum <= 32;
|
||||
}
|
||||
// IPv6 (basic check)
|
||||
if (parts[0].includes(':')) {
|
||||
return prefixNum <= 128;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a base64-encoded 32-byte key (WireGuard X25519 format).
|
||||
*/
|
||||
private static validateBase64Key(key: string, fieldName: string): void {
|
||||
if (key.length !== 44) {
|
||||
throw new Error(`VpnConfig: ${fieldName} must be 44 characters (base64 of 32 bytes), got ${key.length}`);
|
||||
}
|
||||
try {
|
||||
const buf = Buffer.from(key, 'base64');
|
||||
if (buf.length !== 32) {
|
||||
throw new Error(`VpnConfig: ${fieldName} must decode to 32 bytes, got ${buf.length}`);
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.message.startsWith('VpnConfig:')) throw e;
|
||||
throw new Error(`VpnConfig: ${fieldName} is not valid base64`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import type {
|
||||
IVpnClientInfo,
|
||||
IVpnKeypair,
|
||||
IVpnClientTelemetry,
|
||||
IWgPeerConfig,
|
||||
IWgPeerInfo,
|
||||
IClientEntry,
|
||||
IClientConfigBundle,
|
||||
TVpnServerCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -121,6 +125,110 @@ export class VpnServer extends plugins.events.EventEmitter {
|
||||
return this.bridge.sendCommand('getClientTelemetry', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a WireGuard-compatible X25519 keypair.
|
||||
*/
|
||||
public async generateWgKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateWgKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a WireGuard peer (server must be running in wireguard mode).
|
||||
*/
|
||||
public async addWgPeer(peer: IWgPeerConfig): Promise<void> {
|
||||
await this.bridge.sendCommand('addWgPeer', { peer });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a WireGuard peer by public key.
|
||||
*/
|
||||
public async removeWgPeer(publicKey: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeWgPeer', { publicKey });
|
||||
}
|
||||
|
||||
/**
|
||||
* List WireGuard peers with stats.
|
||||
*/
|
||||
public async listWgPeers(): Promise<IWgPeerInfo[]> {
|
||||
const result = await this.bridge.sendCommand('listWgPeers', {} as Record<string, never>);
|
||||
return result.peers;
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ─────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Create a new client. Generates keypairs, assigns IP, returns full config bundle.
|
||||
* The secrets (private keys) are only returned at creation time.
|
||||
*/
|
||||
public async createClient(opts: Partial<IClientEntry>): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('createClient', { client: opts });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a registered client (also disconnects if connected).
|
||||
*/
|
||||
public async removeClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a registered client by ID.
|
||||
*/
|
||||
public async getClient(clientId: string): Promise<IClientEntry> {
|
||||
return this.bridge.sendCommand('getClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* List all registered clients.
|
||||
*/
|
||||
public async listRegisteredClients(): Promise<IClientEntry[]> {
|
||||
const result = await this.bridge.sendCommand('listRegisteredClients', {} as Record<string, never>);
|
||||
return result.clients;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a registered client's fields (ACLs, tags, description, etc.).
|
||||
*/
|
||||
public async updateClient(clientId: string, update: Partial<IClientEntry>): Promise<void> {
|
||||
await this.bridge.sendCommand('updateClient', { clientId, update });
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a previously disabled client.
|
||||
*/
|
||||
public async enableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('enableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable a client (also disconnects if connected).
|
||||
*/
|
||||
public async disableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('disableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
*/
|
||||
public async rotateClientKey(clientId: string): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('rotateClientKey', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Export a client config (without secrets) in the specified format.
|
||||
*/
|
||||
public async exportClientConfig(clientId: string, format: 'smartvpn' | 'wireguard'): Promise<string> {
|
||||
const result = await this.bridge.sendCommand('exportClientConfig', { clientId, format });
|
||||
return result.config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a standalone Noise IK keypair (not tied to a client).
|
||||
*/
|
||||
public async generateClientKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateClientKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
123
ts/smartvpn.classes.wgconfig.ts
Normal file
123
ts/smartvpn.classes.wgconfig.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
import type { IWgPeerConfig } from './smartvpn.interfaces.js';
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard .conf file generator
|
||||
// ============================================================================
|
||||
|
||||
export interface IWgClientConfOptions {
|
||||
/** Client private key (base64) */
|
||||
privateKey: string;
|
||||
/** Client TUN address with prefix (e.g. 10.8.0.2/24) */
|
||||
address: string;
|
||||
/** DNS servers */
|
||||
dns?: string[];
|
||||
/** TUN MTU */
|
||||
mtu?: number;
|
||||
/** Server peer config */
|
||||
peer: {
|
||||
publicKey: string;
|
||||
presharedKey?: string;
|
||||
endpoint: string;
|
||||
allowedIps: string[];
|
||||
persistentKeepalive?: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface IWgServerConfOptions {
|
||||
/** Server private key (base64) */
|
||||
privateKey: string;
|
||||
/** Server TUN address with prefix (e.g. 10.8.0.1/24) */
|
||||
address: string;
|
||||
/** UDP listen port */
|
||||
listenPort: number;
|
||||
/** DNS servers */
|
||||
dns?: string[];
|
||||
/** TUN MTU */
|
||||
mtu?: number;
|
||||
/** Enable NAT — adds PostUp/PostDown iptables rules */
|
||||
enableNat?: boolean;
|
||||
/** Network interface for NAT (e.g. eth0). Auto-detected if omitted. */
|
||||
natInterface?: string;
|
||||
/** Configured peers */
|
||||
peers: IWgPeerConfig[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates standard WireGuard .conf files compatible with wg-quick,
|
||||
* WireGuard iOS/Android apps, and other standard WireGuard clients.
|
||||
*/
|
||||
export class WgConfigGenerator {
|
||||
/**
|
||||
* Generate a client .conf file content.
|
||||
*/
|
||||
public static generateClientConfig(opts: IWgClientConfOptions): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push('[Interface]');
|
||||
lines.push(`PrivateKey = ${opts.privateKey}`);
|
||||
lines.push(`Address = ${opts.address}`);
|
||||
if (opts.dns && opts.dns.length > 0) {
|
||||
lines.push(`DNS = ${opts.dns.join(', ')}`);
|
||||
}
|
||||
if (opts.mtu) {
|
||||
lines.push(`MTU = ${opts.mtu}`);
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
lines.push('[Peer]');
|
||||
lines.push(`PublicKey = ${opts.peer.publicKey}`);
|
||||
if (opts.peer.presharedKey) {
|
||||
lines.push(`PresharedKey = ${opts.peer.presharedKey}`);
|
||||
}
|
||||
lines.push(`Endpoint = ${opts.peer.endpoint}`);
|
||||
lines.push(`AllowedIPs = ${opts.peer.allowedIps.join(', ')}`);
|
||||
if (opts.peer.persistentKeepalive) {
|
||||
lines.push(`PersistentKeepalive = ${opts.peer.persistentKeepalive}`);
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
return lines.join('\n');
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a server .conf file content.
|
||||
*/
|
||||
public static generateServerConfig(opts: IWgServerConfOptions): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push('[Interface]');
|
||||
lines.push(`PrivateKey = ${opts.privateKey}`);
|
||||
lines.push(`Address = ${opts.address}`);
|
||||
lines.push(`ListenPort = ${opts.listenPort}`);
|
||||
if (opts.dns && opts.dns.length > 0) {
|
||||
lines.push(`DNS = ${opts.dns.join(', ')}`);
|
||||
}
|
||||
if (opts.mtu) {
|
||||
lines.push(`MTU = ${opts.mtu}`);
|
||||
}
|
||||
if (opts.enableNat) {
|
||||
const iface = opts.natInterface || 'eth0';
|
||||
lines.push(`PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ${iface} -j MASQUERADE`);
|
||||
lines.push(`PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ${iface} -j MASQUERADE`);
|
||||
}
|
||||
|
||||
for (const peer of opts.peers) {
|
||||
lines.push('');
|
||||
lines.push('[Peer]');
|
||||
lines.push(`PublicKey = ${peer.publicKey}`);
|
||||
if (peer.presharedKey) {
|
||||
lines.push(`PresharedKey = ${peer.presharedKey}`);
|
||||
}
|
||||
lines.push(`AllowedIPs = ${peer.allowedIps.join(', ')}`);
|
||||
if (peer.endpoint) {
|
||||
lines.push(`Endpoint = ${peer.endpoint}`);
|
||||
}
|
||||
if (peer.persistentKeepalive) {
|
||||
lines.push(`PersistentKeepalive = ${peer.persistentKeepalive}`);
|
||||
}
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
return lines.join('\n');
|
||||
}
|
||||
}
|
||||
@@ -24,14 +24,36 @@ export type TVpnTransportOptions = IVpnTransportStdio | IVpnTransportSocket;
|
||||
export interface IVpnClientConfig {
|
||||
/** Server WebSocket URL, e.g. wss://vpn.example.com/tunnel */
|
||||
serverUrl: string;
|
||||
/** Server's static public key (base64) for Noise NK handshake */
|
||||
/** Server's static public key (base64) for Noise IK handshake */
|
||||
serverPublicKey: string;
|
||||
/** Client's Noise IK private key (base64) — required for SmartVPN native transport */
|
||||
clientPrivateKey: string;
|
||||
/** Client's Noise IK public key (base64) — for reference/display */
|
||||
clientPublicKey: string;
|
||||
/** Optional DNS servers to use while connected */
|
||||
dns?: string[];
|
||||
/** Optional MTU for the TUN device */
|
||||
mtu?: number;
|
||||
/** Keepalive interval in seconds (default: 30) */
|
||||
keepaliveIntervalSecs?: number;
|
||||
/** Transport protocol: 'auto' (default, tries QUIC then WS), 'websocket', 'quic', or 'wireguard' */
|
||||
transport?: 'auto' | 'websocket' | 'quic' | 'wireguard';
|
||||
/** For QUIC: SHA-256 hash of server certificate (base64) for cert pinning */
|
||||
serverCertHash?: string;
|
||||
/** WireGuard: client private key (base64, X25519) */
|
||||
wgPrivateKey?: string;
|
||||
/** WireGuard: client TUN address (e.g. 10.8.0.2) */
|
||||
wgAddress?: string;
|
||||
/** WireGuard: client TUN address prefix length (default: 24) */
|
||||
wgAddressPrefix?: number;
|
||||
/** WireGuard: preshared key (base64, optional) */
|
||||
wgPresharedKey?: string;
|
||||
/** WireGuard: persistent keepalive interval in seconds */
|
||||
wgPersistentKeepalive?: number;
|
||||
/** WireGuard: server endpoint (host:port, e.g. vpn.example.com:51820) */
|
||||
wgEndpoint?: string;
|
||||
/** WireGuard: allowed IPs (CIDR strings, e.g. ['0.0.0.0/0']) */
|
||||
wgAllowedIps?: string[];
|
||||
}
|
||||
|
||||
export interface IVpnClientOptions {
|
||||
@@ -68,6 +90,25 @@ export interface IVpnServerConfig {
|
||||
defaultRateLimitBytesPerSec?: number;
|
||||
/** Default burst size for new clients (bytes). Omit for unlimited. */
|
||||
defaultBurstBytes?: number;
|
||||
/** Transport mode: 'both' (default, WS+QUIC), 'websocket', 'quic', or 'wireguard' */
|
||||
transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard';
|
||||
/** QUIC listen address (host:port). Defaults to listenAddr. */
|
||||
quicListenAddr?: string;
|
||||
/** QUIC idle timeout in seconds (default: 30) */
|
||||
quicIdleTimeoutSecs?: number;
|
||||
/** WireGuard: UDP listen port (default: 51820) */
|
||||
wgListenPort?: number;
|
||||
/** WireGuard: configured peers */
|
||||
wgPeers?: IWgPeerConfig[];
|
||||
/** Pre-registered clients for Noise IK authentication */
|
||||
clients?: IClientEntry[];
|
||||
/** Enable PROXY protocol v2 on incoming WebSocket connections.
|
||||
* Required when behind a reverse proxy that sends PP v2 headers (HAProxy, SmartProxy).
|
||||
* SECURITY: Must be false when accepting direct client connections. */
|
||||
proxyProtocol?: boolean;
|
||||
/** Server-level IP block list — applied at TCP accept, before Noise handshake.
|
||||
* Supports exact IPs, CIDR, wildcards, ranges. */
|
||||
connectionIpBlockList?: string[];
|
||||
}
|
||||
|
||||
export interface IVpnServerOptions {
|
||||
@@ -118,6 +159,12 @@ export interface IVpnClientInfo {
|
||||
keepalivesReceived: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
/** Client's authenticated Noise IK public key (base64) */
|
||||
authenticatedKey: string;
|
||||
/** Registered client ID from the client registry */
|
||||
registeredClientId: string;
|
||||
/** Real client IP:port (from PROXY protocol or direct TCP connection) */
|
||||
remoteAddr?: string;
|
||||
}
|
||||
|
||||
export interface IVpnServerStatistics extends IVpnStatistics {
|
||||
@@ -177,6 +224,113 @@ export interface IVpnClientTelemetry {
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client Registry (Hub) types — aligned with SmartProxy IRouteSecurity pattern
|
||||
// ============================================================================
|
||||
|
||||
/** Per-client rate limiting. */
|
||||
export interface IClientRateLimit {
|
||||
/** Max throughput in bytes/sec */
|
||||
bytesPerSec: number;
|
||||
/** Burst allowance in bytes */
|
||||
burstBytes: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Per-client security settings.
|
||||
* Mirrors SmartProxy's IRouteSecurity: ipAllowList/ipBlockList naming + deny-overrides-allow.
|
||||
* Adds VPN-specific destination filtering.
|
||||
*/
|
||||
export interface IClientSecurity {
|
||||
/** Source IPs/CIDRs the client may connect FROM (empty = any).
|
||||
* Supports: exact IP, CIDR, wildcard (192.168.1.*), ranges (1.1.1.1-1.1.1.5). */
|
||||
ipAllowList?: string[];
|
||||
/** Source IPs blocked — overrides ipAllowList (deny wins). */
|
||||
ipBlockList?: string[];
|
||||
/** Destination IPs/CIDRs the client may reach through the VPN (empty = all). */
|
||||
destinationAllowList?: string[];
|
||||
/** Destination IPs blocked — overrides destinationAllowList (deny wins). */
|
||||
destinationBlockList?: string[];
|
||||
/** Max concurrent connections from this client. */
|
||||
maxConnections?: number;
|
||||
/** Per-client rate limiting. */
|
||||
rateLimit?: IClientRateLimit;
|
||||
}
|
||||
|
||||
/**
|
||||
* Server-side client definition — the central config object for the Hub.
|
||||
* Naming and structure aligned with SmartProxy's IRouteConfig / IRouteSecurity.
|
||||
*/
|
||||
export interface IClientEntry {
|
||||
/** Human-readable client ID (e.g. "alice-laptop") */
|
||||
clientId: string;
|
||||
/** Client's Noise IK public key (base64) — for SmartVPN native transport */
|
||||
publicKey: string;
|
||||
/** Client's WireGuard public key (base64) — for WireGuard transport */
|
||||
wgPublicKey?: string;
|
||||
/** Security settings (ACLs, rate limits) */
|
||||
security?: IClientSecurity;
|
||||
/** Traffic priority (lower = higher priority, default: 100) */
|
||||
priority?: number;
|
||||
/** Whether this client is enabled (default: true) */
|
||||
enabled?: boolean;
|
||||
/** Tags for grouping (e.g. ["engineering", "office"]) */
|
||||
tags?: string[];
|
||||
/** Optional description */
|
||||
description?: string;
|
||||
/** Optional expiry (ISO 8601 timestamp, omit = never expires) */
|
||||
expiresAt?: string;
|
||||
/** Assigned VPN IP address (set by server) */
|
||||
assignedIp?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Complete client config bundle — returned by createClient() and rotateClientKey().
|
||||
* Contains everything the client needs to connect.
|
||||
*/
|
||||
export interface IClientConfigBundle {
|
||||
/** The server-side client entry */
|
||||
entry: IClientEntry;
|
||||
/** Ready-to-use SmartVPN client config (typed object) */
|
||||
smartvpnConfig: IVpnClientConfig;
|
||||
/** Ready-to-use WireGuard .conf file content (string) */
|
||||
wireguardConfig: string;
|
||||
/** Client's private keys (ONLY returned at creation time, not stored server-side) */
|
||||
secrets: {
|
||||
noisePrivateKey: string;
|
||||
wgPrivateKey: string;
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard-specific types
|
||||
// ============================================================================
|
||||
|
||||
export interface IWgPeerConfig {
|
||||
/** Peer's public key (base64, X25519) */
|
||||
publicKey: string;
|
||||
/** Optional preshared key (base64) */
|
||||
presharedKey?: string;
|
||||
/** Allowed IP ranges (CIDR strings) */
|
||||
allowedIps: string[];
|
||||
/** Peer endpoint (host:port) — optional for server peers, required for client */
|
||||
endpoint?: string;
|
||||
/** Persistent keepalive interval in seconds */
|
||||
persistentKeepalive?: number;
|
||||
}
|
||||
|
||||
export interface IWgPeerInfo {
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint?: string;
|
||||
persistentKeepalive?: number;
|
||||
bytesSent: number;
|
||||
bytesReceived: number;
|
||||
packetsSent: number;
|
||||
packetsReceived: number;
|
||||
lastHandshakeTime?: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// IPC Command maps (used by smartrust RustBridge<TCommands>)
|
||||
// ============================================================================
|
||||
@@ -201,6 +355,21 @@ export type TVpnServerCommands = {
|
||||
setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void };
|
||||
removeClientRateLimit: { params: { clientId: string }; result: void };
|
||||
getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry };
|
||||
generateWgKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
addWgPeer: { params: { peer: IWgPeerConfig }; result: void };
|
||||
removeWgPeer: { params: { publicKey: string }; result: void };
|
||||
listWgPeers: { params: Record<string, never>; result: { peers: IWgPeerInfo[] } };
|
||||
// Client Registry (Hub) commands
|
||||
createClient: { params: { client: Partial<IClientEntry> }; result: IClientConfigBundle };
|
||||
removeClient: { params: { clientId: string }; result: void };
|
||||
getClient: { params: { clientId: string }; result: IClientEntry };
|
||||
listRegisteredClients: { params: Record<string, never>; result: { clients: IClientEntry[] } };
|
||||
updateClient: { params: { clientId: string; update: Partial<IClientEntry> }; result: void };
|
||||
enableClient: { params: { clientId: string }; result: void };
|
||||
disableClient: { params: { clientId: string }; result: void };
|
||||
rotateClientKey: { params: { clientId: string }; result: IClientConfigBundle };
|
||||
exportClientConfig: { params: { clientId: string; format: 'smartvpn' | 'wireguard' }; result: { config: string } };
|
||||
generateClientKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
"module": "NodeNext",
|
||||
"moduleResolution": "NodeNext",
|
||||
"esModuleInterop": true,
|
||||
"verbatimModuleSyntax": true
|
||||
"verbatimModuleSyntax": true,
|
||||
"types": ["node"]
|
||||
},
|
||||
"exclude": [
|
||||
"dist_ts/**/*.d.ts"
|
||||
|
||||
Reference in New Issue
Block a user