Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e9cf575271 | |||
| 229db4be38 | |||
| e31086d0c2 | |||
| 01a0d8b9f4 | |||
| 187a69028b | |||
| 64dedd389e | |||
| 13d8cbe3fa |
23
changelog.md
23
changelog.md
@@ -1,5 +1,28 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-03-29 - 1.9.0 - feat(server)
|
||||
add PROXY protocol v2 support for real client IP handling and connection ACLs
|
||||
|
||||
- add PROXY protocol v2 parsing for WebSocket connections, including IPv4/IPv6 support, LOCAL command handling, and header read timeout protection
|
||||
- apply server-level connection IP block lists before the Noise handshake and enforce per-client source IP allow/block lists using the resolved remote address
|
||||
- expose proxy protocol configuration and remote client address fields in Rust and TypeScript interfaces, and document reverse-proxy usage in the README
|
||||
|
||||
## 2026-03-29 - 1.8.0 - feat(auth,client-registry)
|
||||
add Noise IK client authentication with managed client registry and per-client ACL controls
|
||||
|
||||
- switch the native tunnel handshake from Noise NK to Noise IK and require client keypairs in client configuration
|
||||
- add server-side client registry management APIs for creating, updating, disabling, rotating, listing, and exporting client configs
|
||||
- enforce client authorization from the registry during handshake and expose authenticated client metadata in server client info
|
||||
- introduce per-client security policies with source/destination ACLs and per-client rate limit settings
|
||||
- add Rust ACL matching support for exact IPs, CIDR ranges, wildcards, and IP ranges with test coverage
|
||||
|
||||
## 2026-03-29 - 1.7.0 - feat(rust-tests)
|
||||
add end-to-end WireGuard UDP integration tests and align TypeScript build configuration
|
||||
|
||||
- Add userspace Rust end-to-end tests that validate WireGuard handshake, encryption, peer isolation, and preshared-key data exchange over real UDP sockets.
|
||||
- Update the TypeScript build setup by removing the allowimplicitany build flag and explicitly including Node types in tsconfig.
|
||||
- Refresh development toolchain versions to support the updated test and build workflow.
|
||||
|
||||
## 2026-03-29 - 1.6.0 - feat(readme)
|
||||
document WireGuard transport support, configuration, and usage examples
|
||||
|
||||
|
||||
12
package.json
12
package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartvpn",
|
||||
"version": "1.6.0",
|
||||
"version": "1.9.0",
|
||||
"private": false,
|
||||
"description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
|
||||
"type": "module",
|
||||
@@ -10,7 +10,7 @@
|
||||
"main": "dist_ts/index.js",
|
||||
"typings": "dist_ts/index.d.ts",
|
||||
"scripts": {
|
||||
"build": "(tsbuild tsfolders --allowimplicitany) && (tsrust)",
|
||||
"build": "(tsbuild tsfolders) && (tsrust)",
|
||||
"test:before": "(tsrust)",
|
||||
"test": "tstest test/ --verbose",
|
||||
"buildDocs": "tsdoc"
|
||||
@@ -33,10 +33,10 @@
|
||||
"@push.rocks/smartrust": "^1.3.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@git.zone/tsbuild": "^4.3.0",
|
||||
"@git.zone/tsrun": "^2.0.1",
|
||||
"@git.zone/tsrust": "^1.3.0",
|
||||
"@git.zone/tstest": "^3.5.0",
|
||||
"@git.zone/tsbuild": "^4.4.0",
|
||||
"@git.zone/tsrun": "^2.0.2",
|
||||
"@git.zone/tsrust": "^1.3.2",
|
||||
"@git.zone/tstest": "^3.6.3",
|
||||
"@types/node": "^25.5.0"
|
||||
},
|
||||
"files": [
|
||||
|
||||
865
pnpm-lock.yaml
generated
865
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
253
readme.plan.md
Normal file
253
readme.plan.md
Normal file
@@ -0,0 +1,253 @@
|
||||
# PROXY Protocol v2 Support for SmartVPN WebSocket Transport
|
||||
|
||||
## Context
|
||||
|
||||
SmartVPN's WebSocket transport is designed to sit behind reverse proxies (Cloudflare, HAProxy, SmartProxy). The recently added ACL engine has `ipAllowList`/`ipBlockList` per client, but without PROXY protocol support the server only sees the proxy's IP — not the real client's. This makes source-IP ACLs useless behind a proxy.
|
||||
|
||||
PROXY protocol v2 solves this by letting the proxy prepend a binary header with the real client IP/port before the WebSocket upgrade.
|
||||
|
||||
---
|
||||
|
||||
## Design
|
||||
|
||||
### Two-Phase ACL with Real Client IP
|
||||
|
||||
```
|
||||
TCP accept → Read PP v2 header → Extract real IP
|
||||
│
|
||||
├─ Phase 1 (pre-handshake): Check server-level connectionIpBlockList → reject early
|
||||
│
|
||||
├─ WebSocket upgrade → Noise IK handshake → Client identity known
|
||||
│
|
||||
└─ Phase 2 (post-handshake): Check per-client ipAllowList/ipBlockList → reject if denied
|
||||
```
|
||||
|
||||
- **Phase 1**: Server-wide block list (`connectionIpBlockList` on `IVpnServerConfig`). Rejects before any crypto work. Protects server resources.
|
||||
- **Phase 2**: Per-client ACL from `IClientSecurity.ipAllowList`/`ipBlockList`. Applied after the Noise IK handshake identifies the client.
|
||||
|
||||
### No New Dependencies
|
||||
|
||||
PROXY protocol v2 is a fixed-format binary header (16-byte signature + variable address block). Manual parsing (~80 lines) follows the same pattern as `codec.rs`. No crate needed.
|
||||
|
||||
### Scope: WebSocket Only
|
||||
|
||||
- **WebSocket**: Needs PP v2 (sits behind reverse proxies)
|
||||
- **QUIC**: Direct UDP, just use `conn.remote_address()`
|
||||
- **WireGuard**: Direct UDP, uses boringtun peer tracking
|
||||
|
||||
---
|
||||
|
||||
## Implementation
|
||||
|
||||
### Phase 1: New Rust module `proxy_protocol.rs`
|
||||
|
||||
**New file: `rust/src/proxy_protocol.rs`**
|
||||
|
||||
PP v2 binary format:
|
||||
```
|
||||
Bytes 0-11: Signature \x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A
|
||||
Byte 12: Version (high nibble = 0x2) | Command (low nibble: 0x0=LOCAL, 0x1=PROXY)
|
||||
Byte 13: Address family | Protocol (0x11 = IPv4/TCP, 0x21 = IPv6/TCP)
|
||||
Bytes 14-15: Address data length (big-endian u16)
|
||||
Bytes 16+: IPv4: 4 src_ip + 4 dst_ip + 2 src_port + 2 dst_port (12 bytes)
|
||||
IPv6: 16 src_ip + 16 dst_ip + 2 src_port + 2 dst_port (36 bytes)
|
||||
```
|
||||
|
||||
```rust
|
||||
pub struct ProxyHeader {
|
||||
pub src_addr: SocketAddr,
|
||||
pub dst_addr: SocketAddr,
|
||||
pub is_local: bool, // LOCAL command = health check probe
|
||||
}
|
||||
|
||||
/// Read and parse a PROXY protocol v2 header from a TCP stream.
|
||||
/// Reads exactly the header bytes — the stream is clean for WS upgrade after.
|
||||
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader>
|
||||
```
|
||||
|
||||
- 5-second timeout on header read (constant `PROXY_HEADER_TIMEOUT`)
|
||||
- Validates 12-byte signature, version nibble, command type
|
||||
- Parses IPv4 and IPv6 address blocks
|
||||
- LOCAL command returns `is_local: true` (caller closes connection gracefully)
|
||||
- Unit tests: valid IPv4/IPv6 headers, LOCAL command, invalid signature, truncated data
|
||||
|
||||
**Modify: `rust/src/lib.rs`** — add `pub mod proxy_protocol;`
|
||||
|
||||
### Phase 2: Server config + client info fields
|
||||
|
||||
**File: `rust/src/server.rs` — `ServerConfig`**
|
||||
|
||||
Add:
|
||||
```rust
|
||||
/// Enable PROXY protocol v2 parsing on WebSocket connections.
|
||||
/// SECURITY: Must be false when accepting direct client connections.
|
||||
pub proxy_protocol: Option<bool>,
|
||||
/// Server-level IP block list — applied at TCP accept time, before Noise handshake.
|
||||
pub connection_ip_block_list: Option<Vec<String>>,
|
||||
```
|
||||
|
||||
**File: `rust/src/server.rs` — `ClientInfo`**
|
||||
|
||||
Add:
|
||||
```rust
|
||||
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
|
||||
pub remote_addr: Option<String>,
|
||||
```
|
||||
|
||||
### Phase 3: ACL helper
|
||||
|
||||
**File: `rust/src/acl.rs`**
|
||||
|
||||
Add a public function for the server-level pre-handshake check:
|
||||
```rust
|
||||
/// Check whether a connection source IP is in a block list.
|
||||
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
|
||||
ip_matches_any(ip, block_list)
|
||||
}
|
||||
```
|
||||
|
||||
(Keeps `ip_matches_any` private; exposes only the specific check needed.)
|
||||
|
||||
### Phase 4: WebSocket listener integration
|
||||
|
||||
**File: `rust/src/server.rs` — `run_ws_listener()`**
|
||||
|
||||
Between `listener.accept()` and `transport::accept_connection()`:
|
||||
|
||||
```rust
|
||||
// Determine real client address
|
||||
let remote_addr = if state.config.proxy_protocol.unwrap_or(false) {
|
||||
match proxy_protocol::read_proxy_header(&mut tcp_stream).await {
|
||||
Ok(header) if header.is_local => {
|
||||
// Health check probe — close gracefully
|
||||
return;
|
||||
}
|
||||
Ok(header) => {
|
||||
info!("PP v2: real client {} -> {}", header.src_addr, header.dst_addr);
|
||||
Some(header.src_addr)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("PP v2 parse failed from {}: {}", tcp_addr, e);
|
||||
return; // Drop connection
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Some(tcp_addr) // Direct connection — use TCP SocketAddr
|
||||
};
|
||||
|
||||
// Pre-handshake server-level block list check
|
||||
if let (Some(ref block_list), Some(ref addr)) = (&state.config.connection_ip_block_list, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, block_list) {
|
||||
warn!("Connection blocked by server IP block list: {}", addr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Then proceed with WS upgrade + handle_client_connection as before
|
||||
```
|
||||
|
||||
Key correctness note: `read_proxy_header` reads *exactly* the PP header bytes via `read_exact`. The `TcpStream` is then in a clean state for the WS HTTP upgrade. No buffered wrapper needed.
|
||||
|
||||
### Phase 5: Update `handle_client_connection` signature
|
||||
|
||||
**File: `rust/src/server.rs`**
|
||||
|
||||
Change signature:
|
||||
```rust
|
||||
async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
remote_addr: Option<std::net::SocketAddr>, // NEW
|
||||
) -> Result<()>
|
||||
```
|
||||
|
||||
After Noise IK handshake + registry lookup (where `client_security` is available), add connection-level per-client ACL:
|
||||
|
||||
```rust
|
||||
if let (Some(ref sec), Some(addr)) = (&client_security, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, sec.ip_block_list.as_deref().unwrap_or(&[])) {
|
||||
anyhow::bail!("Client {} connection denied: source IP {} blocked", registered_client_id, addr);
|
||||
}
|
||||
if let Some(ref allow) = sec.ip_allow_list {
|
||||
if !allow.is_empty() && !acl::is_ip_allowed(v4, allow) {
|
||||
anyhow::bail!("Client {} connection denied: source IP {} not in allow list", registered_client_id, addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Populate `remote_addr` when building `ClientInfo`:
|
||||
```rust
|
||||
remote_addr: remote_addr.map(|a| a.to_string()),
|
||||
```
|
||||
|
||||
### Phase 6: QUIC listener — pass remote addr through
|
||||
|
||||
**File: `rust/src/server.rs` — `run_quic_listener()`**
|
||||
|
||||
QUIC doesn't use PROXY protocol. Just pass `conn.remote_address()` through:
|
||||
```rust
|
||||
let remote = conn.remote_address();
|
||||
// ...
|
||||
handle_client_connection(state, Box::new(sink), Box::new(stream), Some(remote)).await
|
||||
```
|
||||
|
||||
### Phase 7: TypeScript interface updates
|
||||
|
||||
**File: `ts/smartvpn.interfaces.ts`**
|
||||
|
||||
Add to `IVpnServerConfig`:
|
||||
```typescript
|
||||
/** Enable PROXY protocol v2 on incoming WebSocket connections.
|
||||
* Required when behind a reverse proxy that sends PP v2 headers. */
|
||||
proxyProtocol?: boolean;
|
||||
/** Server-level IP block list — applied at TCP accept time, before Noise handshake. */
|
||||
connectionIpBlockList?: string[];
|
||||
```
|
||||
|
||||
Add to `IVpnClientInfo`:
|
||||
```typescript
|
||||
/** Real client IP:port (from PROXY protocol or direct TCP). */
|
||||
remoteAddr?: string;
|
||||
```
|
||||
|
||||
### Phase 8: Tests
|
||||
|
||||
**Rust unit tests in `proxy_protocol.rs`:**
|
||||
- `parse_valid_ipv4_header` — construct a valid PP v2 header with known IPs, verify parsed correctly
|
||||
- `parse_valid_ipv6_header` — same for IPv6
|
||||
- `parse_local_command` — health check probe returns `is_local: true`
|
||||
- `reject_invalid_signature` — random bytes rejected
|
||||
- `reject_truncated_header` — short reads fail gracefully
|
||||
- `reject_v1_header` — PROXY v1 text format rejected (we only support v2)
|
||||
|
||||
**Rust unit tests in `acl.rs`:**
|
||||
- `is_connection_blocked` with various IP patterns
|
||||
|
||||
**TypeScript tests:**
|
||||
- Config validation accepts `proxyProtocol: true` + `connectionIpBlockList`
|
||||
|
||||
---
|
||||
|
||||
## Key Files to Modify
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `rust/src/proxy_protocol.rs` | **NEW** — PP v2 parser + tests |
|
||||
| `rust/src/lib.rs` | Add `pub mod proxy_protocol;` |
|
||||
| `rust/src/server.rs` | `ServerConfig` + `ClientInfo` fields, `run_ws_listener` PP integration, `handle_client_connection` signature + connection ACL, `run_quic_listener` pass-through |
|
||||
| `rust/src/acl.rs` | Add `is_connection_blocked` public function |
|
||||
| `ts/smartvpn.interfaces.ts` | `proxyProtocol`, `connectionIpBlockList`, `remoteAddr` |
|
||||
|
||||
---
|
||||
|
||||
## Verification
|
||||
|
||||
1. `cargo test` — all existing 121 tests + new PP parser tests pass
|
||||
2. `pnpm test` — all 79 TS tests pass (no PP in test setup, just config validation)
|
||||
3. Manual: `socat` or test harness to send a PP v2 header before WS upgrade, verify server logs real IP
|
||||
111
rust/Cargo.lock
generated
111
rust/Cargo.lock
generated
@@ -46,6 +46,15 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.21"
|
||||
@@ -306,6 +315,20 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
|
||||
dependencies = [
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cipher"
|
||||
version = "0.4.4"
|
||||
@@ -728,6 +751,30 @@ version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.65"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"log",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inout"
|
||||
version = "0.1.4"
|
||||
@@ -942,6 +989,15 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
@@ -1528,8 +1584,10 @@ dependencies = [
|
||||
"boringtun",
|
||||
"bytes",
|
||||
"chacha20poly1305",
|
||||
"chrono",
|
||||
"clap",
|
||||
"futures-util",
|
||||
"ipnet",
|
||||
"mimalloc",
|
||||
"quinn",
|
||||
"rand 0.8.5",
|
||||
@@ -2020,12 +2078,65 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.62.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb"
|
||||
dependencies = [
|
||||
"windows-implement",
|
||||
"windows-interface",
|
||||
"windows-link",
|
||||
"windows-result",
|
||||
"windows-strings",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-implement"
|
||||
version = "0.60.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-interface"
|
||||
version = "0.59.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
||||
|
||||
[[package]]
|
||||
name = "windows-result"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-strings"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.45.0"
|
||||
|
||||
@@ -35,6 +35,8 @@ rustls-pemfile = "2"
|
||||
webpki-roots = "1"
|
||||
mimalloc = "0.1"
|
||||
boringtun = "0.7"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
ipnet = "2"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
|
||||
302
rust/src/acl.rs
Normal file
302
rust/src/acl.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
use std::net::Ipv4Addr;
|
||||
use ipnet::Ipv4Net;
|
||||
|
||||
use crate::client_registry::ClientSecurity;
|
||||
|
||||
/// Result of an ACL check.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AclResult {
|
||||
Allow,
|
||||
DenySrc,
|
||||
DenyDst,
|
||||
}
|
||||
|
||||
/// Check whether a connection source IP is in a server-level block list.
|
||||
/// Used for pre-handshake rejection of known-bad IPs.
|
||||
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
|
||||
ip_matches_any(ip, block_list)
|
||||
}
|
||||
|
||||
/// Check whether a source IP is allowed by allow/block lists.
|
||||
/// Returns true if the IP is permitted (not blocked and passes allow check).
|
||||
pub fn is_source_allowed(ip: Ipv4Addr, allow_list: Option<&[String]>, block_list: Option<&[String]>) -> bool {
|
||||
// Deny overrides allow
|
||||
if let Some(bl) = block_list {
|
||||
if ip_matches_any(ip, bl) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If allow list exists and is non-empty, IP must match
|
||||
if let Some(al) = allow_list {
|
||||
if !al.is_empty() && !ip_matches_any(ip, al) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy.
|
||||
///
|
||||
/// Evaluation order (deny overrides allow):
|
||||
/// 1. If src_ip is in ip_block_list → DenySrc
|
||||
/// 2. If dst_ip is in destination_block_list → DenyDst
|
||||
/// 3. If ip_allow_list is non-empty and src_ip is NOT in it → DenySrc
|
||||
/// 4. If destination_allow_list is non-empty and dst_ip is NOT in it → DenyDst
|
||||
/// 5. Otherwise → Allow
|
||||
pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult {
|
||||
// Step 1: Check source block list (deny overrides)
|
||||
if let Some(ref block_list) = security.ip_block_list {
|
||||
if ip_matches_any(src_ip, block_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Check destination block list (deny overrides)
|
||||
if let Some(ref block_list) = security.destination_block_list {
|
||||
if ip_matches_any(dst_ip, block_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Check source allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.ip_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(src_ip, allow_list) {
|
||||
return AclResult::DenySrc;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Check destination allow list (if non-empty, must match)
|
||||
if let Some(ref allow_list) = security.destination_allow_list {
|
||||
if !allow_list.is_empty() && !ip_matches_any(dst_ip, allow_list) {
|
||||
return AclResult::DenyDst;
|
||||
}
|
||||
}
|
||||
|
||||
AclResult::Allow
|
||||
}
|
||||
|
||||
/// Check if `ip` matches any pattern in the list.
|
||||
/// Supports: exact IP, CIDR notation, wildcard patterns (192.168.1.*),
|
||||
/// and IP ranges (192.168.1.1-192.168.1.100).
|
||||
fn ip_matches_any(ip: Ipv4Addr, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if ip_matches(ip, pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if `ip` matches a single pattern.
|
||||
fn ip_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let pattern = pattern.trim();
|
||||
|
||||
// CIDR notation (e.g. 192.168.1.0/24)
|
||||
if pattern.contains('/') {
|
||||
if let Ok(net) = pattern.parse::<Ipv4Net>() {
|
||||
return net.contains(&ip);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// IP range (e.g. 192.168.1.1-192.168.1.100)
|
||||
if pattern.contains('-') {
|
||||
let parts: Vec<&str> = pattern.splitn(2, '-').collect();
|
||||
if parts.len() == 2 {
|
||||
if let (Ok(start), Ok(end)) = (parts[0].trim().parse::<Ipv4Addr>(), parts[1].trim().parse::<Ipv4Addr>()) {
|
||||
let ip_u32 = u32::from(ip);
|
||||
return ip_u32 >= u32::from(start) && ip_u32 <= u32::from(end);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wildcard pattern (e.g. 192.168.1.*)
|
||||
if pattern.contains('*') {
|
||||
return wildcard_matches(ip, pattern);
|
||||
}
|
||||
|
||||
// Exact IP match
|
||||
if let Ok(exact) = pattern.parse::<Ipv4Addr>() {
|
||||
return ip == exact;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Match an IP against a wildcard pattern like "192.168.1.*" or "10.*.*.*".
|
||||
fn wildcard_matches(ip: Ipv4Addr, pattern: &str) -> bool {
|
||||
let ip_octets = ip.octets();
|
||||
let pattern_parts: Vec<&str> = pattern.split('.').collect();
|
||||
if pattern_parts.len() != 4 {
|
||||
return false;
|
||||
}
|
||||
for (i, part) in pattern_parts.iter().enumerate() {
|
||||
if *part == "*" {
|
||||
continue;
|
||||
}
|
||||
if let Ok(octet) = part.parse::<u8>() {
|
||||
if ip_octets[i] != octet {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client_registry::{ClientRateLimit, ClientSecurity};
|
||||
|
||||
fn security(
|
||||
ip_allow: Option<Vec<&str>>,
|
||||
ip_block: Option<Vec<&str>>,
|
||||
dst_allow: Option<Vec<&str>>,
|
||||
dst_block: Option<Vec<&str>>,
|
||||
) -> ClientSecurity {
|
||||
ClientSecurity {
|
||||
ip_allow_list: ip_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
ip_block_list: ip_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_allow_list: dst_allow.map(|v| v.into_iter().map(String::from).collect()),
|
||||
destination_block_list: dst_block.map(|v| v.into_iter().map(String::from).collect()),
|
||||
max_connections: None,
|
||||
rate_limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ip(s: &str) -> Ipv4Addr {
|
||||
s.parse().unwrap()
|
||||
}
|
||||
|
||||
// ── No restrictions (empty security) ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn empty_security_allows_all() {
|
||||
let sec = security(None, None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_lists_allow_all() {
|
||||
let sec = security(Some(vec![]), Some(vec![]), Some(vec![]), Some(vec![]));
|
||||
assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
|
||||
}
|
||||
|
||||
// ── Source IP allow list ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_allow_exact_match() {
|
||||
let sec = security(Some(vec!["10.0.0.1"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.2"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_cidr() {
|
||||
let sec = security(Some(vec!["192.168.1.0/24"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.2.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_wildcard() {
|
||||
let sec = security(Some(vec!["10.0.*.*"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.5.3"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.1.0.1"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn src_allow_range() {
|
||||
let sec = security(Some(vec!["10.0.0.1-10.0.0.10"]), None, None, None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.5"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.11"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Source IP block list (deny overrides) ───────────────────────────
|
||||
|
||||
#[test]
|
||||
fn src_block_overrides_allow() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
Some(vec!["192.168.1.100"]),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.100"), ip("5.6.7.8")), AclResult::DenySrc);
|
||||
}
|
||||
|
||||
// ── Destination allow list ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_allow_exact() {
|
||||
let sec = security(None, None, Some(vec!["8.8.8.8", "8.8.4.4"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dst_allow_cidr() {
|
||||
let sec = security(None, None, Some(vec!["10.0.0.0/8"]), None);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.5.3.2")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("172.16.0.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Destination block list (deny overrides) ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dst_block_overrides_allow() {
|
||||
let sec = security(
|
||||
None,
|
||||
None,
|
||||
Some(vec!["10.0.0.0/8"]),
|
||||
Some(vec!["10.0.0.99"]),
|
||||
);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.1")), AclResult::Allow);
|
||||
assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.99")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── Combined source + destination ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn combined_src_and_dst_filtering() {
|
||||
let sec = security(
|
||||
Some(vec!["192.168.1.0/24"]),
|
||||
None,
|
||||
Some(vec!["8.8.8.8"]),
|
||||
None,
|
||||
);
|
||||
// Valid source, valid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("8.8.8.8")), AclResult::Allow);
|
||||
// Invalid source
|
||||
assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::DenySrc);
|
||||
// Valid source, invalid dest
|
||||
assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("1.1.1.1")), AclResult::DenyDst);
|
||||
}
|
||||
|
||||
// ── IP matching edge cases ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wildcard_single_octet() {
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.*"));
|
||||
assert!(!ip_matches(ip("10.0.1.5"), "10.0.0.*"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range_boundaries() {
|
||||
assert!(ip_matches(ip("10.0.0.1"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(ip_matches(ip("10.0.0.5"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.6"), "10.0.0.1-10.0.0.5"));
|
||||
assert!(!ip_matches(ip("10.0.0.0"), "10.0.0.1-10.0.0.5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_pattern_no_match() {
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "not-an-ip"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0.1/99"));
|
||||
assert!(!ip_matches(ip("10.0.0.1"), "10.0.0"));
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,10 @@ use crate::quic_transport;
|
||||
pub struct ClientConfig {
|
||||
pub server_url: String,
|
||||
pub server_public_key: String,
|
||||
/// Client's Noise IK static private key (base64) — required for authentication.
|
||||
pub client_private_key: String,
|
||||
/// Client's Noise IK static public key (base64) — for reference/display.
|
||||
pub client_public_key: String,
|
||||
pub dns: Option<Vec<String>>,
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
@@ -104,11 +108,15 @@ impl VpnClient {
|
||||
let connected_since = self.connected_since.clone();
|
||||
let link_health = self.link_health.clone();
|
||||
|
||||
// Decode server public key
|
||||
// Decode keys
|
||||
let server_pub_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.server_public_key,
|
||||
)?;
|
||||
let client_priv_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.client_private_key,
|
||||
)?;
|
||||
|
||||
// Create transport based on configuration
|
||||
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
|
||||
@@ -171,12 +179,12 @@ impl VpnClient {
|
||||
}
|
||||
};
|
||||
|
||||
// Noise NK handshake (client side = initiator)
|
||||
// Noise IK handshake (client side = initiator, presents static key)
|
||||
*state.write().await = ClientState::Handshaking;
|
||||
let mut initiator = crypto::create_initiator(&server_pub_key)?;
|
||||
let mut initiator = crypto::create_initiator(&client_priv_key, &server_pub_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es
|
||||
// -> e, es, s, ss
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let init_frame = Frame {
|
||||
packet_type: PacketType::HandshakeInit,
|
||||
@@ -186,7 +194,7 @@ impl VpnClient {
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// <- e, ee
|
||||
// <- e, ee, se
|
||||
let resp_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed during handshake"),
|
||||
|
||||
362
rust/src/client_registry.rs
Normal file
362
rust/src/client_registry.rs
Normal file
@@ -0,0 +1,362 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Per-client rate limiting configuration.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientRateLimit {
|
||||
pub bytes_per_sec: u64,
|
||||
pub burst_bytes: u64,
|
||||
}
|
||||
|
||||
/// Per-client security settings — aligned with SmartProxy's IRouteSecurity pattern.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientSecurity {
|
||||
/// Source IPs/CIDRs the client may connect FROM (empty/None = any).
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
/// Source IPs blocked — overrides ip_allow_list (deny wins).
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
/// Destination IPs/CIDRs the client may reach (empty/None = all).
|
||||
pub destination_allow_list: Option<Vec<String>>,
|
||||
/// Destination IPs blocked — overrides destination_allow_list (deny wins).
|
||||
pub destination_block_list: Option<Vec<String>>,
|
||||
/// Max concurrent connections from this client.
|
||||
pub max_connections: Option<u32>,
|
||||
/// Per-client rate limiting.
|
||||
pub rate_limit: Option<ClientRateLimit>,
|
||||
}
|
||||
|
||||
/// A registered client entry — the server-side source of truth.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientEntry {
|
||||
/// Human-readable client ID (e.g. "alice-laptop").
|
||||
pub client_id: String,
|
||||
/// Client's Noise IK public key (base64).
|
||||
pub public_key: String,
|
||||
/// Client's WireGuard public key (base64) — optional.
|
||||
pub wg_public_key: Option<String>,
|
||||
/// Security settings (ACLs, rate limits).
|
||||
pub security: Option<ClientSecurity>,
|
||||
/// Traffic priority (lower = higher priority, default: 100).
|
||||
pub priority: Option<u32>,
|
||||
/// Whether this client is enabled (default: true).
|
||||
pub enabled: Option<bool>,
|
||||
/// Tags for grouping.
|
||||
pub tags: Option<Vec<String>>,
|
||||
/// Optional description.
|
||||
pub description: Option<String>,
|
||||
/// Optional expiry (ISO 8601 timestamp).
|
||||
pub expires_at: Option<String>,
|
||||
/// Assigned VPN IP address.
|
||||
pub assigned_ip: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientEntry {
|
||||
/// Whether this client is considered enabled (defaults to true).
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Whether this client has expired based on current time.
|
||||
pub fn is_expired(&self) -> bool {
|
||||
if let Some(ref expires) = self.expires_at {
|
||||
if let Ok(expiry) = chrono::DateTime::parse_from_rfc3339(expires) {
|
||||
return chrono::Utc::now() > expiry;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory client registry with dual-key indexing.
|
||||
pub struct ClientRegistry {
|
||||
/// Primary index: clientId → ClientEntry
|
||||
entries: HashMap<String, ClientEntry>,
|
||||
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
|
||||
key_index: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ClientRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
key_index: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a registry from a list of client entries.
|
||||
pub fn from_entries(entries: Vec<ClientEntry>) -> Result<Self> {
|
||||
let mut registry = Self::new();
|
||||
for entry in entries {
|
||||
registry.add(entry)?;
|
||||
}
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
/// Add a client to the registry.
|
||||
pub fn add(&mut self, entry: ClientEntry) -> Result<()> {
|
||||
if self.entries.contains_key(&entry.client_id) {
|
||||
anyhow::bail!("Client '{}' already exists", entry.client_id);
|
||||
}
|
||||
if self.key_index.contains_key(&entry.public_key) {
|
||||
anyhow::bail!("Public key already registered to another client");
|
||||
}
|
||||
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
|
||||
self.entries.insert(entry.client_id.clone(), entry);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a client by ID.
|
||||
pub fn remove(&mut self, client_id: &str) -> Result<ClientEntry> {
|
||||
let entry = self.entries.remove(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
self.key_index.remove(&entry.public_key);
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
/// Get a client by ID.
|
||||
pub fn get_by_id(&self, client_id: &str) -> Option<&ClientEntry> {
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Get a client by public key (used during IK handshake verification).
|
||||
pub fn get_by_key(&self, public_key: &str) -> Option<&ClientEntry> {
|
||||
let client_id = self.key_index.get(public_key)?;
|
||||
self.entries.get(client_id)
|
||||
}
|
||||
|
||||
/// Check if a public key is authorized (exists, enabled, not expired).
|
||||
pub fn is_authorized(&self, public_key: &str) -> bool {
|
||||
match self.get_by_key(public_key) {
|
||||
Some(entry) => entry.is_enabled() && !entry.is_expired(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a client entry. The closure receives a mutable reference to the entry.
|
||||
pub fn update<F>(&mut self, client_id: &str, updater: F) -> Result<()>
|
||||
where
|
||||
F: FnOnce(&mut ClientEntry),
|
||||
{
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
let old_key = entry.public_key.clone();
|
||||
updater(entry);
|
||||
// If public key changed, update the index
|
||||
if entry.public_key != old_key {
|
||||
self.key_index.remove(&old_key);
|
||||
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all client entries.
|
||||
pub fn list(&self) -> Vec<&ClientEntry> {
|
||||
self.entries.values().collect()
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns the updated entry.
|
||||
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
|
||||
let entry = self.entries.get_mut(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
// Update key index
|
||||
self.key_index.remove(&entry.public_key);
|
||||
entry.public_key = new_public_key.clone();
|
||||
entry.wg_public_key = new_wg_public_key;
|
||||
self.key_index.insert(new_public_key, client_id.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of registered clients.
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Whether the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(id: &str, key: &str) -> ClientEntry {
|
||||
ClientEntry {
|
||||
client_id: id.to_string(),
|
||||
public_key: key.to_string(),
|
||||
wg_public_key: None,
|
||||
security: None,
|
||||
priority: None,
|
||||
enabled: None,
|
||||
tags: None,
|
||||
description: None,
|
||||
expires_at: None,
|
||||
assigned_ip: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_and_lookup() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
assert!(reg.get_by_id("alice").is_some());
|
||||
assert!(reg.get_by_key("key_alice").is_some());
|
||||
assert_eq!(reg.get_by_key("key_alice").unwrap().client_id, "alice");
|
||||
assert!(reg.get_by_id("bob").is_none());
|
||||
assert!(reg.get_by_key("key_bob").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_id() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key1")).unwrap();
|
||||
assert!(reg.add(make_entry("alice", "key2")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_duplicate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "same_key")).unwrap();
|
||||
assert!(reg.add(make_entry("bob", "same_key")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert_eq!(reg.len(), 1);
|
||||
|
||||
let removed = reg.remove("alice").unwrap();
|
||||
assert_eq!(removed.client_id, "alice");
|
||||
assert_eq!(reg.len(), 0);
|
||||
assert!(reg.get_by_key("key_alice").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.remove("ghost").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_enabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
assert!(reg.is_authorized("key_alice")); // enabled by default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_disabled() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.enabled = Some(false);
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_expired() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2020-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(!reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_future_expiry() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.expires_at = Some("2099-01-01T00:00:00Z".to_string());
|
||||
reg.add(entry).unwrap();
|
||||
assert!(reg.is_authorized("key_alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authorized_unknown_key() {
|
||||
let reg = ClientRegistry::new();
|
||||
assert!(!reg.is_authorized("nonexistent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_client() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_alice")).unwrap();
|
||||
|
||||
reg.update("alice", |entry| {
|
||||
entry.description = Some("Updated".to_string());
|
||||
entry.enabled = Some(false);
|
||||
}).unwrap();
|
||||
|
||||
let entry = reg.get_by_id("alice").unwrap();
|
||||
assert_eq!(entry.description.as_deref(), Some("Updated"));
|
||||
assert!(!entry.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_nonexistent_fails() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
assert!(reg.update("ghost", |_| {}).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rotate_key() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "old_key")).unwrap();
|
||||
|
||||
reg.rotate_key("alice", "new_key".to_string(), None).unwrap();
|
||||
|
||||
assert!(reg.get_by_key("old_key").is_none());
|
||||
assert!(reg.get_by_key("new_key").is_some());
|
||||
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_entries() {
|
||||
let entries = vec![
|
||||
make_entry("alice", "key_a"),
|
||||
make_entry("bob", "key_b"),
|
||||
];
|
||||
let reg = ClientRegistry::from_entries(entries).unwrap();
|
||||
assert_eq!(reg.len(), 2);
|
||||
assert!(reg.get_by_key("key_a").is_some());
|
||||
assert!(reg.get_by_key("key_b").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_clients() {
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(make_entry("alice", "key_a")).unwrap();
|
||||
reg.add(make_entry("bob", "key_b")).unwrap();
|
||||
let list = reg.list();
|
||||
assert_eq!(list.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_with_rate_limit() {
|
||||
let mut entry = make_entry("alice", "key_alice");
|
||||
entry.security = Some(ClientSecurity {
|
||||
ip_allow_list: Some(vec!["192.168.1.0/24".to_string()]),
|
||||
ip_block_list: Some(vec!["192.168.1.100".to_string()]),
|
||||
destination_allow_list: None,
|
||||
destination_block_list: None,
|
||||
max_connections: Some(5),
|
||||
rate_limit: Some(ClientRateLimit {
|
||||
bytes_per_sec: 1_000_000,
|
||||
burst_bytes: 2_000_000,
|
||||
}),
|
||||
});
|
||||
let mut reg = ClientRegistry::new();
|
||||
reg.add(entry).unwrap();
|
||||
let e = reg.get_by_id("alice").unwrap();
|
||||
let sec = e.security.as_ref().unwrap();
|
||||
assert_eq!(sec.rate_limit.as_ref().unwrap().bytes_per_sec, 1_000_000);
|
||||
assert_eq!(sec.max_connections, Some(5));
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,10 @@ use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use snow::Builder;
|
||||
|
||||
/// Noise protocol pattern: NK (client knows server pubkey, no client auth at Noise level)
|
||||
const NOISE_PATTERN: &str = "Noise_NK_25519_ChaChaPoly_BLAKE2s";
|
||||
/// Noise protocol pattern: IK (client presents static key, server authenticates client)
|
||||
/// IK = Initiator's static key is transmitted; responder's Key is pre-known.
|
||||
/// This provides mutual authentication: server verifies client identity via public key.
|
||||
const NOISE_PATTERN: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2s";
|
||||
|
||||
/// Generate a new Noise static keypair.
|
||||
/// Returns (public_key_base64, private_key_base64).
|
||||
@@ -22,18 +24,23 @@ pub fn generate_keypair_raw() -> Result<snow::Keypair> {
|
||||
Ok(builder.generate_keypair()?)
|
||||
}
|
||||
|
||||
/// Create a Noise NK initiator (client side).
|
||||
/// The client knows the server's static public key.
|
||||
pub fn create_initiator(server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
/// Create a Noise IK initiator (client side).
|
||||
/// The client provides its own static keypair AND the server's public key.
|
||||
/// The client's static key is transmitted (encrypted) during the handshake,
|
||||
/// allowing the server to authenticate the client.
|
||||
pub fn create_initiator(client_private_key: &[u8], server_public_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
.local_private_key(client_private_key)
|
||||
.remote_public_key(server_public_key)
|
||||
.build_initiator()?;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Create a Noise NK responder (server side).
|
||||
/// Create a Noise IK responder (server side).
|
||||
/// The server uses its static private key.
|
||||
/// After the handshake, call `get_remote_static()` on the HandshakeState
|
||||
/// (before `into_transport_mode()`) to retrieve the client's public key.
|
||||
pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
let builder = Builder::new(NOISE_PATTERN.parse()?);
|
||||
let state = builder
|
||||
@@ -42,19 +49,20 @@ pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Perform the full Noise NK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport).
|
||||
/// Perform the full Noise IK handshake between initiator and responder.
|
||||
/// Returns (initiator_transport, responder_transport, client_public_key).
|
||||
/// The client_public_key is extracted from the responder before entering transport mode.
|
||||
pub fn perform_handshake(
|
||||
mut initiator: snow::HandshakeState,
|
||||
mut responder: snow::HandshakeState,
|
||||
) -> Result<(snow::TransportState, snow::TransportState)> {
|
||||
) -> Result<(snow::TransportState, snow::TransportState, Vec<u8>)> {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es (initiator sends)
|
||||
// -> e, es, s, ss (initiator sends ephemeral + encrypted static key)
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let msg1 = buf[..len].to_vec();
|
||||
|
||||
// <- e, ee (responder reads and responds)
|
||||
// <- e, ee, se (responder reads and responds)
|
||||
responder.read_message(&msg1, &mut buf)?;
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let msg2 = buf[..len].to_vec();
|
||||
@@ -62,10 +70,16 @@ pub fn perform_handshake(
|
||||
// Initiator reads response
|
||||
initiator.read_message(&msg2, &mut buf)?;
|
||||
|
||||
// Extract client's public key from responder BEFORE entering transport mode
|
||||
let client_public_key = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake did not provide client static key"))?
|
||||
.to_vec();
|
||||
|
||||
let i_transport = initiator.into_transport_mode()?;
|
||||
let r_transport = responder.into_transport_mode()?;
|
||||
|
||||
Ok((i_transport, r_transport))
|
||||
Ok((i_transport, r_transport, client_public_key))
|
||||
}
|
||||
|
||||
/// XChaCha20-Poly1305 encryption for post-handshake data.
|
||||
@@ -135,15 +149,19 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_handshake() {
|
||||
fn noise_ik_handshake() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
let initiator = create_initiator(&server_kp.public).unwrap();
|
||||
let initiator = create_initiator(&client_kp.private, &server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
let (mut i_transport, mut r_transport) =
|
||||
let (mut i_transport, mut r_transport, remote_key) =
|
||||
perform_handshake(initiator, responder).unwrap();
|
||||
|
||||
// Verify the server received the client's public key
|
||||
assert_eq!(remote_key, client_kp.public);
|
||||
|
||||
// Test encrypted communication
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let plaintext = b"hello from client";
|
||||
@@ -159,6 +177,20 @@ mod tests {
|
||||
assert_eq!(&out[..len], plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_ik_wrong_server_key_fails() {
|
||||
let server_kp = generate_keypair_raw().unwrap();
|
||||
let wrong_server_kp = generate_keypair_raw().unwrap();
|
||||
let client_kp = generate_keypair_raw().unwrap();
|
||||
|
||||
// Client uses wrong server public key
|
||||
let initiator = create_initiator(&client_kp.private, &wrong_server_kp.public).unwrap();
|
||||
let responder = create_responder(&server_kp.private).unwrap();
|
||||
|
||||
// Handshake should fail because client targeted wrong server
|
||||
assert!(perform_handshake(initiator, responder).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn xchacha_encrypt_decrypt() {
|
||||
let key = [42u8; 32];
|
||||
|
||||
@@ -18,3 +18,6 @@ pub mod ratelimit;
|
||||
pub mod qos;
|
||||
pub mod mtu;
|
||||
pub mod wireguard;
|
||||
pub mod client_registry;
|
||||
pub mod acl;
|
||||
pub mod proxy_protocol;
|
||||
|
||||
@@ -585,6 +585,103 @@ async fn handle_server_request(
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
// ── Client Registry (Hub) Commands ────────────────────────────────
|
||||
"createClient" => {
|
||||
let client_partial = request.params.get("client").cloned().unwrap_or_default();
|
||||
match vpn_server.create_client(client_partial).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Create client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.remove_registered_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.get_registered_client(&client_id).await {
|
||||
Ok(entry) => ManagementResponse::ok(id, entry),
|
||||
Err(e) => ManagementResponse::err(id, format!("Get client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listRegisteredClients" => {
|
||||
let clients = vpn_server.list_registered_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
"updateClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let update = request.params.get("update").cloned().unwrap_or_default();
|
||||
match vpn_server.update_registered_client(&client_id, update).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Update client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"enableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.enable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Enable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"disableClient" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.disable_client(&client_id).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Disable client failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"rotateClientKey" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
match vpn_server.rotate_client_key(&client_id).await {
|
||||
Ok(bundle) => ManagementResponse::ok(id, bundle),
|
||||
Err(e) => ManagementResponse::err(id, format!("Key rotation failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"exportClientConfig" => {
|
||||
let client_id = match request.params.get("clientId").and_then(|v| v.as_str()) {
|
||||
Some(cid) => cid.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing clientId".to_string()),
|
||||
};
|
||||
let format = request.params.get("format").and_then(|v| v.as_str()).unwrap_or("smartvpn");
|
||||
match vpn_server.export_client_config(&client_id, format).await {
|
||||
Ok(config) => ManagementResponse::ok(id, config),
|
||||
Err(e) => ManagementResponse::err(id, format!("Export failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"generateClientKeypair" => match crypto::generate_keypair() {
|
||||
Ok((public_key, private_key)) => ManagementResponse::ok(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"publicKey": public_key,
|
||||
"privateKey": private_key,
|
||||
}),
|
||||
),
|
||||
Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)),
|
||||
},
|
||||
_ => ManagementResponse::err(id, format!("Unknown server method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
261
rust/src/proxy_protocol.rs
Normal file
261
rust/src/proxy_protocol.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
//! PROXY protocol v2 parser for extracting real client addresses
|
||||
//! when SmartVPN sits behind a reverse proxy (HAProxy, SmartProxy, etc.).
|
||||
//!
|
||||
//! Spec: <https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt>
|
||||
|
||||
use anyhow::Result;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::time::Duration;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// Timeout for reading the PROXY protocol header from a new connection.
|
||||
const PROXY_HEADER_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// The 12-byte PP v2 signature.
|
||||
const PP_V2_SIGNATURE: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
/// Parsed PROXY protocol v2 header.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyHeader {
|
||||
/// Real client source address.
|
||||
pub src_addr: SocketAddr,
|
||||
/// Proxy-to-server destination address.
|
||||
pub dst_addr: SocketAddr,
|
||||
/// True if this is a LOCAL command (health check probe from proxy).
|
||||
pub is_local: bool,
|
||||
}
|
||||
|
||||
/// Read and parse a PROXY protocol v2 header from a TCP stream.
|
||||
///
|
||||
/// Reads exactly the header bytes — the stream is in a clean state for
|
||||
/// WebSocket upgrade afterward. Returns an error on timeout, invalid
|
||||
/// signature, or malformed header.
|
||||
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
||||
tokio::time::timeout(PROXY_HEADER_TIMEOUT, read_proxy_header_inner(stream))
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("PROXY protocol header read timed out ({}s)", PROXY_HEADER_TIMEOUT.as_secs()))?
|
||||
}
|
||||
|
||||
async fn read_proxy_header_inner(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
||||
// Read the 16-byte fixed prefix
|
||||
let mut prefix = [0u8; 16];
|
||||
stream.read_exact(&mut prefix).await?;
|
||||
|
||||
// Validate the 12-byte signature
|
||||
if prefix[..12] != PP_V2_SIGNATURE {
|
||||
anyhow::bail!("Invalid PROXY protocol v2 signature");
|
||||
}
|
||||
|
||||
// Byte 12: version (high nibble) | command (low nibble)
|
||||
let version = (prefix[12] & 0xF0) >> 4;
|
||||
let command = prefix[12] & 0x0F;
|
||||
|
||||
if version != 2 {
|
||||
anyhow::bail!("Unsupported PROXY protocol version: {}", version);
|
||||
}
|
||||
|
||||
// Byte 13: address family (high nibble) | protocol (low nibble)
|
||||
let addr_family = (prefix[13] & 0xF0) >> 4;
|
||||
let _protocol = prefix[13] & 0x0F; // 1 = STREAM (TCP)
|
||||
|
||||
// Bytes 14-15: address data length (big-endian)
|
||||
let addr_len = u16::from_be_bytes([prefix[14], prefix[15]]) as usize;
|
||||
|
||||
// Read the address data
|
||||
let mut addr_data = vec![0u8; addr_len];
|
||||
if addr_len > 0 {
|
||||
stream.read_exact(&mut addr_data).await?;
|
||||
}
|
||||
|
||||
// LOCAL command (0x00) = health check, no real address
|
||||
if command == 0x00 {
|
||||
return Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
dst_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
is_local: true,
|
||||
});
|
||||
}
|
||||
|
||||
// PROXY command (0x01) — parse address block
|
||||
if command != 0x01 {
|
||||
anyhow::bail!("Unknown PROXY protocol command: {}", command);
|
||||
}
|
||||
|
||||
match addr_family {
|
||||
// AF_INET (IPv4): 4 src + 4 dst + 2 src_port + 2 dst_port = 12 bytes
|
||||
1 => {
|
||||
if addr_data.len() < 12 {
|
||||
anyhow::bail!("IPv4 address block too short: {} bytes", addr_data.len());
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]);
|
||||
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
|
||||
Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)),
|
||||
dst_addr: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)),
|
||||
is_local: false,
|
||||
})
|
||||
}
|
||||
// AF_INET6 (IPv6): 16 src + 16 dst + 2 src_port + 2 dst_port = 36 bytes
|
||||
2 => {
|
||||
if addr_data.len() < 36 {
|
||||
anyhow::bail!("IPv6 address block too short: {} bytes", addr_data.len());
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap());
|
||||
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
|
||||
Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V6(SocketAddrV6::new(src_ip, src_port, 0, 0)),
|
||||
dst_addr: SocketAddr::V6(SocketAddrV6::new(dst_ip, dst_port, 0, 0)),
|
||||
is_local: false,
|
||||
})
|
||||
}
|
||||
// AF_UNSPEC or unknown
|
||||
_ => {
|
||||
anyhow::bail!("Unsupported address family: {}", addr_family);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a PROXY protocol v2 header (for testing / proxy implementations).
|
||||
pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
||||
|
||||
match (src, dst) {
|
||||
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
|
||||
buf.push(0x21); // version 2 | PROXY command
|
||||
buf.push(0x11); // AF_INET | STREAM
|
||||
buf.extend_from_slice(&12u16.to_be_bytes()); // addr length
|
||||
buf.extend_from_slice(&s.ip().octets());
|
||||
buf.extend_from_slice(&d.ip().octets());
|
||||
buf.extend_from_slice(&s.port().to_be_bytes());
|
||||
buf.extend_from_slice(&d.port().to_be_bytes());
|
||||
}
|
||||
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
|
||||
buf.push(0x21); // version 2 | PROXY command
|
||||
buf.push(0x21); // AF_INET6 | STREAM
|
||||
buf.extend_from_slice(&36u16.to_be_bytes()); // addr length
|
||||
buf.extend_from_slice(&s.ip().octets());
|
||||
buf.extend_from_slice(&d.ip().octets());
|
||||
buf.extend_from_slice(&s.port().to_be_bytes());
|
||||
buf.extend_from_slice(&d.port().to_be_bytes());
|
||||
}
|
||||
_ => panic!("Mismatched address families"),
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a PROXY protocol v2 LOCAL header (health check probe).
|
||||
pub fn build_pp_v2_local() -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
||||
buf.push(0x20); // version 2 | LOCAL command
|
||||
buf.push(0x00); // AF_UNSPEC
|
||||
buf.extend_from_slice(&0u16.to_be_bytes()); // no address data
|
||||
buf
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Helper: create a TCP pair and write data to the client side, then parse from server side.
|
||||
async fn parse_header_from_bytes(header_bytes: &[u8]) -> Result<ProxyHeader> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let data = header_bytes.to_vec();
|
||||
let client_task = tokio::spawn(async move {
|
||||
let mut client = TcpStream::connect(addr).await.unwrap();
|
||||
client.write_all(&data).await.unwrap();
|
||||
client // keep alive
|
||||
});
|
||||
|
||||
let (mut server_stream, _) = listener.accept().await.unwrap();
|
||||
let result = read_proxy_header(&mut server_stream).await;
|
||||
let _client = client_task.await.unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_valid_ipv4_header() {
|
||||
let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
|
||||
let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(!parsed.is_local);
|
||||
assert_eq!(parsed.src_addr, src);
|
||||
assert_eq!(parsed.dst_addr, dst);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_valid_ipv6_header() {
|
||||
let src = "[2001:db8::1]:54321".parse::<SocketAddr>().unwrap();
|
||||
let dst = "[2001:db8::2]:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(!parsed.is_local);
|
||||
assert_eq!(parsed.src_addr, src);
|
||||
assert_eq!(parsed.dst_addr, dst);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_local_command() {
|
||||
let header = build_pp_v2_local();
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(parsed.is_local);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_invalid_signature() {
|
||||
let mut header = build_pp_v2_local();
|
||||
header[0] = 0xFF; // corrupt signature
|
||||
let result = parse_header_from_bytes(&header).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("signature"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_wrong_version() {
|
||||
let mut header = build_pp_v2_local();
|
||||
header[12] = 0x10; // version 1 instead of 2
|
||||
let result = parse_header_from_bytes(&header).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("version"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_truncated_header() {
|
||||
// Only 10 bytes — not even the full signature
|
||||
let result = parse_header_from_bytes(&[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49]).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv4_header_is_exactly_28_bytes() {
|
||||
let src = "1.2.3.4:80".parse::<SocketAddr>().unwrap();
|
||||
let dst = "5.6.7.8:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 12 addrs = 28
|
||||
assert_eq!(header.len(), 28);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv6_header_is_exactly_52_bytes() {
|
||||
let src = "[::1]:80".parse::<SocketAddr>().unwrap();
|
||||
let dst = "[::2]:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 36 addrs = 52
|
||||
assert_eq!(header.len(), 52);
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,8 @@ use tokio::net::TcpListener;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tracing::{info, error, warn};
|
||||
|
||||
use crate::acl;
|
||||
use crate::client_registry::{ClientEntry, ClientRegistry};
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::mtu::{MtuConfig, TunnelOverhead};
|
||||
@@ -45,6 +47,13 @@ pub struct ServerConfig {
|
||||
pub quic_listen_addr: Option<String>,
|
||||
/// QUIC idle timeout in seconds (default: 30).
|
||||
pub quic_idle_timeout_secs: Option<u64>,
|
||||
/// Pre-registered clients for IK authentication.
|
||||
pub clients: Option<Vec<ClientEntry>>,
|
||||
/// Enable PROXY protocol v2 parsing on incoming WebSocket connections.
|
||||
/// SECURITY: Must be false when accepting direct client connections.
|
||||
pub proxy_protocol: Option<bool>,
|
||||
/// Server-level IP block list — applied at TCP accept, before Noise handshake.
|
||||
pub connection_ip_block_list: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -62,6 +71,12 @@ pub struct ClientInfo {
|
||||
pub keepalives_received: u64,
|
||||
pub rate_limit_bytes_per_sec: Option<u64>,
|
||||
pub burst_bytes: Option<u64>,
|
||||
/// Client's authenticated Noise IK public key (base64).
|
||||
pub authenticated_key: String,
|
||||
/// Registered client ID from the client registry.
|
||||
pub registered_client_id: String,
|
||||
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
|
||||
pub remote_addr: Option<String>,
|
||||
}
|
||||
|
||||
/// Server statistics.
|
||||
@@ -88,6 +103,7 @@ pub struct ServerState {
|
||||
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
|
||||
pub mtu_config: MtuConfig,
|
||||
pub started_at: std::time::Instant,
|
||||
pub client_registry: RwLock<ClientRegistry>,
|
||||
}
|
||||
|
||||
/// The VPN server.
|
||||
@@ -127,6 +143,12 @@ impl VpnServer {
|
||||
let overhead = TunnelOverhead::default_overhead();
|
||||
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
|
||||
|
||||
// Build client registry from config
|
||||
let registry = ClientRegistry::from_entries(
|
||||
config.clients.clone().unwrap_or_default()
|
||||
)?;
|
||||
info!("Client registry loaded with {} entries", registry.len());
|
||||
|
||||
let state = Arc::new(ServerState {
|
||||
config: config.clone(),
|
||||
ip_pool: Mutex::new(ip_pool),
|
||||
@@ -135,6 +157,7 @@ impl VpnServer {
|
||||
rate_limiters: Mutex::new(HashMap::new()),
|
||||
mtu_config,
|
||||
started_at: std::time::Instant::now(),
|
||||
client_registry: RwLock::new(registry),
|
||||
});
|
||||
|
||||
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
||||
@@ -287,10 +310,267 @@ impl VpnServer {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ───────────────────────────────────
|
||||
|
||||
/// Create a new client entry. Generates keypairs and assigns an IP.
|
||||
/// Returns a JSON value with the full config bundle including secrets.
|
||||
pub async fn create_client(&self, partial: serde_json::Value) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
|
||||
let client_id = partial.get("clientId")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("clientId is required"))?
|
||||
.to_string();
|
||||
|
||||
// Generate Noise IK keypair for the client
|
||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||
|
||||
// Generate WireGuard keypair for the client
|
||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||
|
||||
// Allocate a VPN IP
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
// Build entry from partial + generated values
|
||||
let entry = ClientEntry {
|
||||
client_id: client_id.clone(),
|
||||
public_key: noise_pub.clone(),
|
||||
wg_public_key: Some(wg_pub.clone()),
|
||||
security: serde_json::from_value(
|
||||
partial.get("security").cloned().unwrap_or(serde_json::Value::Null)
|
||||
).ok(),
|
||||
priority: partial.get("priority").and_then(|v| v.as_u64()).map(|v| v as u32),
|
||||
enabled: partial.get("enabled").and_then(|v| v.as_bool()).or(Some(true)),
|
||||
tags: partial.get("tags").and_then(|v| {
|
||||
v.as_array().map(|a| a.iter().filter_map(|s| s.as_str().map(String::from)).collect())
|
||||
}),
|
||||
description: partial.get("description").and_then(|v| v.as_str()).map(String::from),
|
||||
expires_at: partial.get("expiresAt").and_then(|v| v.as_str()).map(String::from),
|
||||
assigned_ip: Some(assigned_ip.to_string()),
|
||||
};
|
||||
|
||||
// Add to registry
|
||||
state.client_registry.write().await.add(entry.clone())?;
|
||||
|
||||
// Build SmartVPN client config
|
||||
let smartvpn_config = serde_json::json!({
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPrivateKey": noise_priv,
|
||||
"clientPublicKey": noise_pub,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
});
|
||||
|
||||
// Build WireGuard config string
|
||||
let wg_config = format!(
|
||||
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
wg_priv,
|
||||
assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
|
||||
let entry_json = serde_json::to_value(&entry)?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"entry": entry_json,
|
||||
"smartvpnConfig": smartvpn_config,
|
||||
"wireguardConfig": wg_config,
|
||||
"secrets": {
|
||||
"noisePrivateKey": noise_priv,
|
||||
"wgPrivateKey": wg_priv,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Remove a registered client from the registry (and disconnect if connected).
|
||||
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let entry = state.client_registry.write().await.remove(client_id)?;
|
||||
// Release the IP if assigned
|
||||
if let Some(ref ip_str) = entry.assigned_ip {
|
||||
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
|
||||
state.ip_pool.lock().await.release(&ip);
|
||||
}
|
||||
}
|
||||
// Disconnect if currently connected
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a registered client by ID.
|
||||
pub async fn get_registered_client(&self, client_id: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let registry = state.client_registry.read().await;
|
||||
let entry = registry.get_by_id(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
Ok(serde_json::to_value(entry)?)
|
||||
}
|
||||
|
||||
/// List all registered clients.
|
||||
pub async fn list_registered_clients(&self) -> Vec<ClientEntry> {
|
||||
if let Some(ref state) = self.state {
|
||||
state.client_registry.read().await.list().into_iter().cloned().collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a registered client's fields.
|
||||
pub async fn update_registered_client(&self, client_id: &str, update: serde_json::Value) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
if let Some(security) = update.get("security") {
|
||||
entry.security = serde_json::from_value(security.clone()).ok();
|
||||
}
|
||||
if let Some(priority) = update.get("priority").and_then(|v| v.as_u64()) {
|
||||
entry.priority = Some(priority as u32);
|
||||
}
|
||||
if let Some(enabled) = update.get("enabled").and_then(|v| v.as_bool()) {
|
||||
entry.enabled = Some(enabled);
|
||||
}
|
||||
if let Some(tags) = update.get("tags").and_then(|v| v.as_array()) {
|
||||
entry.tags = Some(tags.iter().filter_map(|s| s.as_str().map(String::from)).collect());
|
||||
}
|
||||
if let Some(desc) = update.get("description").and_then(|v| v.as_str()) {
|
||||
entry.description = Some(desc.to_string());
|
||||
}
|
||||
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
|
||||
entry.expires_at = Some(expires.to_string());
|
||||
}
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enable a registered client.
|
||||
pub async fn enable_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
entry.enabled = Some(true);
|
||||
})
|
||||
}
|
||||
|
||||
/// Disable a registered client (also disconnects if connected).
|
||||
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
state.client_registry.write().await.update(client_id, |entry| {
|
||||
entry.enabled = Some(false);
|
||||
})?;
|
||||
// Disconnect if currently connected
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
pub async fn rotate_client_key(&self, client_id: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
|
||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||
|
||||
state.client_registry.write().await.rotate_key(
|
||||
client_id,
|
||||
noise_pub.clone(),
|
||||
Some(wg_pub.clone()),
|
||||
)?;
|
||||
|
||||
// Disconnect existing connection (old key is no longer valid)
|
||||
let _ = self.disconnect_client(client_id).await;
|
||||
|
||||
// Get updated entry for the config bundle
|
||||
let entry_json = self.get_registered_client(client_id).await?;
|
||||
let assigned_ip = entry_json.get("assignedIp")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("0.0.0.0");
|
||||
|
||||
let smartvpn_config = serde_json::json!({
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPrivateKey": noise_priv,
|
||||
"clientPublicKey": noise_pub,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
});
|
||||
|
||||
let wg_config = format!(
|
||||
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
wg_priv, assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"entry": entry_json,
|
||||
"smartvpnConfig": smartvpn_config,
|
||||
"wireguardConfig": wg_config,
|
||||
"secrets": {
|
||||
"noisePrivateKey": noise_priv,
|
||||
"wgPrivateKey": wg_priv,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Export a client config (without secrets) in the specified format.
|
||||
pub async fn export_client_config(&self, client_id: &str, format: &str) -> Result<serde_json::Value> {
|
||||
let state = self.state.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||
let registry = state.client_registry.read().await;
|
||||
let entry = registry.get_by_id(client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||
|
||||
match format {
|
||||
"smartvpn" => {
|
||||
Ok(serde_json::json!({
|
||||
"config": {
|
||||
"serverUrl": format!("wss://{}",
|
||||
state.config.listen_addr.replace("0.0.0.0", "localhost")),
|
||||
"serverPublicKey": state.config.public_key,
|
||||
"clientPublicKey": entry.public_key,
|
||||
"dns": state.config.dns,
|
||||
"mtu": state.config.mtu,
|
||||
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
|
||||
}
|
||||
}))
|
||||
}
|
||||
"wireguard" => {
|
||||
let assigned_ip = entry.assigned_ip.as_deref().unwrap_or("0.0.0.0");
|
||||
let config = format!(
|
||||
"[Interface]\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
|
||||
assigned_ip,
|
||||
state.config.dns.as_ref()
|
||||
.map(|d| format!("DNS = {}", d.join(", ")))
|
||||
.unwrap_or_default(),
|
||||
state.config.public_key,
|
||||
state.config.listen_addr,
|
||||
);
|
||||
Ok(serde_json::json!({ "config": config }))
|
||||
}
|
||||
_ => anyhow::bail!("Unknown format: {}", format),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
|
||||
/// to the transport-agnostic `handle_client_connection`.
|
||||
/// WebSocket listener — accepts TCP connections, optionally parses PROXY protocol v2,
|
||||
/// upgrades to WS, then hands off to `handle_client_connection`.
|
||||
async fn run_ws_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
@@ -303,17 +583,51 @@ async fn run_ws_listener(
|
||||
tokio::select! {
|
||||
accept = listener.accept() => {
|
||||
match accept {
|
||||
Ok((stream, addr)) => {
|
||||
info!("New connection from {}", addr);
|
||||
Ok((mut tcp_stream, tcp_addr)) => {
|
||||
info!("New connection from {}", tcp_addr);
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
match transport::accept_connection(stream).await {
|
||||
// Phase 0: Parse PROXY protocol v2 header if enabled
|
||||
let remote_addr = if state.config.proxy_protocol.unwrap_or(false) {
|
||||
match crate::proxy_protocol::read_proxy_header(&mut tcp_stream).await {
|
||||
Ok(header) if header.is_local => {
|
||||
info!("PP v2 LOCAL probe from {}", tcp_addr);
|
||||
return; // Health check — close gracefully
|
||||
}
|
||||
Ok(header) => {
|
||||
info!("PP v2: real client {} (via {})", header.src_addr, tcp_addr);
|
||||
Some(header.src_addr)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("PP v2 parse failed from {}: {}", tcp_addr, e);
|
||||
return; // Drop connection
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Some(tcp_addr) // Direct connection — use TCP SocketAddr
|
||||
};
|
||||
|
||||
// Phase 1: Server-level connection IP block list (pre-handshake)
|
||||
if let (Some(ref block_list), Some(ref addr)) = (&state.config.connection_ip_block_list, &remote_addr) {
|
||||
if !block_list.is_empty() {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, block_list) {
|
||||
warn!("Connection blocked by server IP block list: {}", addr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: WebSocket upgrade + VPN handshake
|
||||
match transport::accept_connection(tcp_stream).await {
|
||||
Ok(ws) => {
|
||||
let (sink, stream) = transport_trait::split_ws(ws);
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
remote_addr,
|
||||
).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
}
|
||||
@@ -389,6 +703,7 @@ async fn run_quic_listener(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
Some(remote),
|
||||
).await {
|
||||
warn!("QUIC client error: {}", e);
|
||||
}
|
||||
@@ -421,26 +736,24 @@ async fn run_quic_listener(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transport-agnostic client handler. Performs the Noise NK handshake, registers
|
||||
/// the client, and runs the main packet forwarding loop.
|
||||
/// Transport-agnostic client handler. Performs the Noise IK handshake, authenticates
|
||||
/// the client against the registry, and runs the main packet forwarding loop.
|
||||
async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
remote_addr: Option<std::net::SocketAddr>,
|
||||
) -> Result<()> {
|
||||
let client_id = uuid_v4();
|
||||
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
let server_private_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&state.config.private_key,
|
||||
)?;
|
||||
|
||||
// Noise IK handshake (server side = responder)
|
||||
let mut responder = crypto::create_responder(&server_private_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// Receive handshake init
|
||||
// Receive handshake init (-> e, es, s, ss)
|
||||
let init_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before handshake"),
|
||||
@@ -455,6 +768,47 @@ async fn handle_client_connection(
|
||||
}
|
||||
|
||||
responder.read_message(&frame.payload, &mut buf)?;
|
||||
|
||||
// Extract client's static public key BEFORE entering transport mode
|
||||
let client_pub_key_bytes = responder
|
||||
.get_remote_static()
|
||||
.ok_or_else(|| anyhow::anyhow!("IK handshake: no client static key received"))?
|
||||
.to_vec();
|
||||
let client_pub_key_b64 = base64::Engine::encode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&client_pub_key_bytes,
|
||||
);
|
||||
|
||||
// Verify client against registry
|
||||
let (registered_client_id, client_security) = {
|
||||
let registry = state.client_registry.read().await;
|
||||
if !registry.is_authorized(&client_pub_key_b64) {
|
||||
warn!("Rejecting unauthorized client with key {}", &client_pub_key_b64[..8]);
|
||||
// Send handshake response but then disconnect
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let response_frame = Frame {
|
||||
packet_type: PacketType::HandshakeResp,
|
||||
payload: buf[..len].to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// Send disconnect frame
|
||||
let disconnect_frame = Frame {
|
||||
packet_type: PacketType::Disconnect,
|
||||
payload: Vec::new(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
anyhow::bail!("Client not authorized");
|
||||
}
|
||||
let entry = registry.get_by_key(&client_pub_key_b64).unwrap();
|
||||
(entry.client_id.clone(), entry.security.clone())
|
||||
};
|
||||
|
||||
// Complete handshake (<- e, ee, se)
|
||||
let len = responder.write_message(&[], &mut buf)?;
|
||||
let response_payload = buf[..len].to_vec();
|
||||
|
||||
@@ -468,9 +822,42 @@ async fn handle_client_connection(
|
||||
|
||||
let mut noise_transport = responder.into_transport_mode()?;
|
||||
|
||||
// Register client
|
||||
let default_rate = state.config.default_rate_limit_bytes_per_sec;
|
||||
let default_burst = state.config.default_burst_bytes;
|
||||
// Connection-level ACL: check real client IP against per-client ipAllowList/ipBlockList
|
||||
if let (Some(ref sec), Some(ref addr)) = (&client_security, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if !acl::is_source_allowed(
|
||||
v4,
|
||||
sec.ip_allow_list.as_deref(),
|
||||
sec.ip_block_list.as_deref(),
|
||||
) {
|
||||
warn!("Connection-level ACL denied client {} from IP {}", registered_client_id, addr);
|
||||
let disconnect_frame = Frame { packet_type: PacketType::Disconnect, payload: Vec::new() };
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
anyhow::bail!("Connection denied: source IP {} not allowed for client {}", addr, registered_client_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the registered client ID as the connection ID
|
||||
let client_id = registered_client_id.clone();
|
||||
|
||||
// Allocate IP
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
|
||||
// Determine rate limits: per-client security overrides server defaults
|
||||
let (rate_limit, burst) = if let Some(ref sec) = client_security {
|
||||
if let Some(ref rl) = sec.rate_limit {
|
||||
(Some(rl.bytes_per_sec), Some(rl.burst_bytes))
|
||||
} else {
|
||||
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
|
||||
}
|
||||
} else {
|
||||
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
|
||||
};
|
||||
|
||||
// Register connected client
|
||||
let client_info = ClientInfo {
|
||||
client_id: client_id.clone(),
|
||||
assigned_ip: assigned_ip.to_string(),
|
||||
@@ -481,13 +868,16 @@ async fn handle_client_connection(
|
||||
bytes_dropped: 0,
|
||||
last_keepalive_at: None,
|
||||
keepalives_received: 0,
|
||||
rate_limit_bytes_per_sec: default_rate,
|
||||
burst_bytes: default_burst,
|
||||
rate_limit_bytes_per_sec: rate_limit,
|
||||
burst_bytes: burst,
|
||||
authenticated_key: client_pub_key_b64.clone(),
|
||||
registered_client_id: registered_client_id.clone(),
|
||||
remote_addr: remote_addr.map(|a| a.to_string()),
|
||||
};
|
||||
state.clients.write().await.insert(client_id.clone(), client_info);
|
||||
|
||||
// Set up rate limiter if defaults are configured
|
||||
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
|
||||
// Set up rate limiter
|
||||
if let (Some(rate), Some(burst)) = (rate_limit, burst) {
|
||||
state
|
||||
.rate_limiters
|
||||
.lock()
|
||||
@@ -517,7 +907,9 @@ async fn handle_client_connection(
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
info!("Client {} connected with IP {}", client_id, assigned_ip);
|
||||
info!("Client {} ({}) connected with IP {} from {}",
|
||||
registered_client_id, &client_pub_key_b64[..8], assigned_ip,
|
||||
remote_addr.map(|a| a.to_string()).unwrap_or_else(|| "unknown".to_string()));
|
||||
|
||||
// Main packet loop with dead-peer detection
|
||||
let mut last_activity = tokio::time::Instant::now();
|
||||
@@ -534,6 +926,24 @@ async fn handle_client_connection(
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
// ACL check on decrypted packet
|
||||
if let Some(ref sec) = client_security {
|
||||
if len >= 20 {
|
||||
// Extract src/dst from IPv4 header
|
||||
let src = Ipv4Addr::new(buf[12], buf[13], buf[14], buf[15]);
|
||||
let dst = Ipv4Addr::new(buf[16], buf[17], buf[18], buf[19]);
|
||||
let acl_result = acl::check_acl(sec, src, dst);
|
||||
if acl_result != acl::AclResult::Allow {
|
||||
let mut clients = state.clients.write().await;
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
info.packets_dropped += 1;
|
||||
info.bytes_dropped += len as u64;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting check
|
||||
let allowed = {
|
||||
let mut limiters = state.rate_limiters.lock().await;
|
||||
@@ -635,20 +1045,6 @@ async fn handle_client_connection(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn uuid_v4() -> String {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let bytes: [u8; 16] = rng.gen();
|
||||
format!(
|
||||
"{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
|
||||
bytes[0], bytes[1], bytes[2], bytes[3],
|
||||
bytes[4], bytes[5],
|
||||
bytes[6], bytes[7],
|
||||
bytes[8], bytes[9],
|
||||
bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
|
||||
)
|
||||
}
|
||||
|
||||
fn timestamp_now() -> String {
|
||||
use std::time::SystemTime;
|
||||
let duration = SystemTime::now()
|
||||
|
||||
320
rust/tests/wg_e2e.rs
Normal file
320
rust/tests/wg_e2e.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! End-to-end WireGuard protocol tests over real UDP sockets.
|
||||
//!
|
||||
//! Entirely userspace — no root, no TUN devices.
|
||||
//! Two boringtun `Tunn` instances exchange real WireGuard packets
|
||||
//! over loopback UDP, validating handshake, encryption, and data flow.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::time;
|
||||
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
|
||||
use smartvpn_daemon::wireguard::generate_wg_keypair;
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
fn parse_key_pair(pub_b64: &str, priv_b64: &str) -> (PublicKey, StaticSecret) {
|
||||
let pub_bytes: [u8; 32] = BASE64.decode(pub_b64).unwrap().try_into().unwrap();
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
(PublicKey::from(pub_bytes), StaticSecret::from(priv_bytes))
|
||||
}
|
||||
|
||||
fn clone_secret(priv_b64: &str) -> StaticSecret {
|
||||
let priv_bytes: [u8; 32] = BASE64.decode(priv_b64).unwrap().try_into().unwrap();
|
||||
StaticSecret::from(priv_bytes)
|
||||
}
|
||||
|
||||
fn make_ipv4_packet(src: Ipv4Addr, dst: Ipv4Addr, payload: &[u8]) -> Vec<u8> {
|
||||
let total_len = 20 + payload.len();
|
||||
let mut pkt = vec![0u8; total_len];
|
||||
pkt[0] = 0x45;
|
||||
pkt[2] = (total_len >> 8) as u8;
|
||||
pkt[3] = total_len as u8;
|
||||
pkt[9] = 0x11;
|
||||
pkt[12..16].copy_from_slice(&src.octets());
|
||||
pkt[16..20].copy_from_slice(&dst.octets());
|
||||
pkt[20..].copy_from_slice(payload);
|
||||
pkt
|
||||
}
|
||||
|
||||
/// Send any WriteToNetwork result, then drain the tunn for more packets.
|
||||
async fn send_and_drain(
|
||||
tunn: &mut Tunn,
|
||||
pkt: &[u8],
|
||||
socket: &UdpSocket,
|
||||
peer: SocketAddr,
|
||||
) {
|
||||
socket.send_to(pkt, peer).await.unwrap();
|
||||
let mut drain_buf = vec![0u8; 2048];
|
||||
loop {
|
||||
match tunn.decapsulate(None, &[], &mut drain_buf) {
|
||||
TunnResult::WriteToNetwork(p) => { socket.send_to(p, peer).await.unwrap(); }
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to receive a UDP packet and decapsulate it. Returns decrypted IP data if any.
|
||||
async fn try_recv_decap(
|
||||
tunn: &mut Tunn,
|
||||
socket: &UdpSocket,
|
||||
timeout_ms: u64,
|
||||
) -> Option<(Vec<u8>, Ipv4Addr, SocketAddr)> {
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
let (n, src_addr) = match time::timeout(
|
||||
Duration::from_millis(timeout_ms),
|
||||
socket.recv_from(&mut recv_buf),
|
||||
).await {
|
||||
Ok(Ok(r)) => r,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let result = tunn.decapsulate(Some(src_addr.ip()), &recv_buf[..n], &mut dst_buf);
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(tunn, pkt, socket, src_addr).await;
|
||||
None
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(pkt, addr) => Some((pkt.to_vec(), addr, src_addr)),
|
||||
TunnResult::WriteToTunnelV6(_, _) => None,
|
||||
TunnResult::Done => None,
|
||||
TunnResult::Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive the full WireGuard handshake between client and server over real UDP.
|
||||
async fn do_handshake(
|
||||
client_tunn: &mut Tunn,
|
||||
server_tunn: &mut Tunn,
|
||||
client_socket: &UdpSocket,
|
||||
server_socket: &UdpSocket,
|
||||
server_addr: SocketAddr,
|
||||
) {
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let mut recv_buf = vec![0u8; 65536];
|
||||
let mut dst_buf = vec![0u8; 65536];
|
||||
|
||||
// Step 1: Client initiates handshake
|
||||
match client_tunn.encapsulate(&[], &mut buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => panic!("Expected handshake init"),
|
||||
}
|
||||
|
||||
// Step 2: Server receives init → sends response
|
||||
let (n, client_from) = server_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match server_tunn.decapsulate(Some(client_from.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(server_tunn, pkt, server_socket, client_from).await;
|
||||
}
|
||||
other => panic!("Expected WriteToNetwork from server, got variant {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Step 3: Client receives response
|
||||
let (n, _) = client_socket.recv_from(&mut recv_buf).await.unwrap();
|
||||
match client_tunn.decapsulate(Some(server_addr.ip()), &recv_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
send_and_drain(client_tunn, pkt, client_socket, server_addr).await;
|
||||
}
|
||||
TunnResult::Done => {}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Step 4: Process any remaining handshake packets
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 200).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 100).await;
|
||||
|
||||
// Step 5: Timer ticks to settle
|
||||
for _ in 0..3 {
|
||||
match server_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
server_socket.send_to(pkt, client_from).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
match client_tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
client_socket.send_to(pkt, server_addr).await.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let _ = try_recv_decap(server_tunn, server_socket, 50).await;
|
||||
let _ = try_recv_decap(client_tunn, client_socket, 50).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn variant_name(r: &TunnResult) -> &'static str {
|
||||
match r {
|
||||
TunnResult::Done => "Done",
|
||||
TunnResult::Err(_) => "Err",
|
||||
TunnResult::WriteToNetwork(_) => "WriteToNetwork",
|
||||
TunnResult::WriteToTunnelV4(_, _) => "WriteToTunnelV4",
|
||||
TunnResult::WriteToTunnelV6(_, _) => "WriteToTunnelV6",
|
||||
}
|
||||
}
|
||||
|
||||
/// Encapsulate an IP packet and send it, then loop-receive on the other side until decrypted.
|
||||
async fn send_and_expect_data(
|
||||
sender_tunn: &mut Tunn,
|
||||
receiver_tunn: &mut Tunn,
|
||||
sender_socket: &UdpSocket,
|
||||
receiver_socket: &UdpSocket,
|
||||
dest_addr: SocketAddr,
|
||||
ip_packet: &[u8],
|
||||
) -> (Vec<u8>, Ipv4Addr) {
|
||||
let mut enc_buf = vec![0u8; 65536];
|
||||
|
||||
match sender_tunn.encapsulate(ip_packet, &mut enc_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
sender_socket.send_to(pkt, dest_addr).await.unwrap();
|
||||
}
|
||||
TunnResult::Err(e) => panic!("Encapsulate failed: {:?}", e),
|
||||
other => panic!("Expected WriteToNetwork, got {}", variant_name(&other)),
|
||||
}
|
||||
|
||||
// Receive — may need a few rounds for control packets
|
||||
for _ in 0..10 {
|
||||
if let Some((data, addr, _)) = try_recv_decap(receiver_tunn, receiver_socket, 1000).await {
|
||||
return (data, addr);
|
||||
}
|
||||
}
|
||||
panic!("Did not receive decrypted IP packet");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 1: Single client ↔ server bidirectional data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_single_client_bidirectional() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
let client_addr = client_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, None, None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, None, None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
// Client → Server
|
||||
let pkt_c2s = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"Hello from client!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt_c2s,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt_c2s.len()], &pkt_c2s[..]);
|
||||
|
||||
// Server → Client
|
||||
let pkt_s2c = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 1), Ipv4Addr::new(10, 0, 0, 2), b"Hello from server!");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut server_tunn, &mut client_tunn,
|
||||
&server_socket, &client_socket,
|
||||
client_addr, &pkt_s2c,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 1));
|
||||
assert_eq!(&decrypted[..pkt_s2c.len()], &pkt_s2c[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 2: Two clients ↔ one server (peer routing)
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_two_clients_peer_routing() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client1_pub_b64, client1_priv_b64) = generate_wg_keypair();
|
||||
let (client2_pub_b64, client2_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, _) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client1_public, client1_secret) = parse_key_pair(&client1_pub_b64, &client1_priv_b64);
|
||||
let (client2_public, client2_secret) = parse_key_pair(&client2_pub_b64, &client2_priv_b64);
|
||||
|
||||
// Separate server socket per peer to avoid UDP mux complexity in test
|
||||
let server_socket_1 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_socket_2 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client1_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client2_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr_1 = server_socket_1.local_addr().unwrap();
|
||||
let server_addr_2 = server_socket_2.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn_1 = Tunn::new(clone_secret(&server_priv_b64), client1_public, None, None, 0, None);
|
||||
let mut server_tunn_2 = Tunn::new(clone_secret(&server_priv_b64), client2_public, None, None, 1, None);
|
||||
let mut client1_tunn = Tunn::new(client1_secret, server_public.clone(), None, None, 2, None);
|
||||
let mut client2_tunn = Tunn::new(client2_secret, server_public, None, None, 3, None);
|
||||
|
||||
do_handshake(&mut client1_tunn, &mut server_tunn_1, &client1_socket, &server_socket_1, server_addr_1).await;
|
||||
do_handshake(&mut client2_tunn, &mut server_tunn_2, &client2_socket, &server_socket_2, server_addr_2).await;
|
||||
|
||||
// Client 1 → Server
|
||||
let pkt1 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"From client 1");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client1_tunn, &mut server_tunn_1,
|
||||
&client1_socket, &server_socket_1,
|
||||
server_addr_1, &pkt1,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt1.len()], &pkt1[..]);
|
||||
|
||||
// Client 2 → Server
|
||||
let pkt2 = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 3), Ipv4Addr::new(10, 0, 0, 1), b"From client 2");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client2_tunn, &mut server_tunn_2,
|
||||
&client2_socket, &server_socket_2,
|
||||
server_addr_2, &pkt2,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 3));
|
||||
assert_eq!(&decrypted[..pkt2.len()], &pkt2[..]);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 3: Preshared key handshake + data exchange
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_e2e_preshared_key() {
|
||||
let (server_pub_b64, server_priv_b64) = generate_wg_keypair();
|
||||
let (client_pub_b64, client_priv_b64) = generate_wg_keypair();
|
||||
|
||||
let (server_public, server_secret) = parse_key_pair(&server_pub_b64, &server_priv_b64);
|
||||
let (client_public, client_secret) = parse_key_pair(&client_pub_b64, &client_priv_b64);
|
||||
|
||||
let psk: [u8; 32] = rand::random();
|
||||
|
||||
let server_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server_socket.local_addr().unwrap();
|
||||
|
||||
let mut server_tunn = Tunn::new(server_secret, client_public, Some(psk), None, 0, None);
|
||||
let mut client_tunn = Tunn::new(client_secret, server_public, Some(psk), None, 1, None);
|
||||
|
||||
do_handshake(&mut client_tunn, &mut server_tunn, &client_socket, &server_socket, server_addr).await;
|
||||
|
||||
let pkt = make_ipv4_packet(Ipv4Addr::new(10, 0, 0, 2), Ipv4Addr::new(10, 0, 0, 1), b"PSK-protected data");
|
||||
let (decrypted, src_ip) = send_and_expect_data(
|
||||
&mut client_tunn, &mut server_tunn,
|
||||
&client_socket, &server_socket,
|
||||
server_addr, &pkt,
|
||||
).await;
|
||||
assert_eq!(src_ip, Ipv4Addr::new(10, 0, 0, 2));
|
||||
assert_eq!(&decrypted[..pkt.len()], &pkt[..]);
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
@@ -40,7 +40,9 @@ let server: VpnServer;
|
||||
let serverPort: number;
|
||||
let keypair: IVpnKeypair;
|
||||
let client: VpnClient;
|
||||
let clientBundle: IClientConfigBundle;
|
||||
const extraClients: VpnClient[] = [];
|
||||
const extraBundles: IClientConfigBundle[] = [];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
@@ -64,7 +66,7 @@ tap.test('setup: start VPN server', async () => {
|
||||
expect(keypair.publicKey).toBeTypeofString();
|
||||
expect(keypair.privateKey).toBeTypeofString();
|
||||
|
||||
// Phase 3: start the VPN listener
|
||||
// Phase 3: start the VPN listener (empty clients, will use createClient at runtime)
|
||||
const serverConfig: IVpnServerConfig = {
|
||||
listenAddr: `127.0.0.1:${serverPort}`,
|
||||
privateKey: keypair.privateKey,
|
||||
@@ -76,6 +78,11 @@ tap.test('setup: start VPN server', async () => {
|
||||
// Verify server is now running
|
||||
const status = await server.getStatus();
|
||||
expect(status.state).toEqual('connected');
|
||||
|
||||
// Phase 4: create the first client via the hub
|
||||
clientBundle = await server.createClient({ clientId: 'test-client-0' });
|
||||
expect(clientBundle.secrets.noisePrivateKey).toBeTypeofString();
|
||||
expect(clientBundle.smartvpnConfig.clientPublicKey).toBeTypeofString();
|
||||
});
|
||||
|
||||
tap.test('single client connects and gets IP', async () => {
|
||||
@@ -89,6 +96,8 @@ tap.test('single client connects and gets IP', async () => {
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: clientBundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: clientBundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
@@ -175,11 +184,15 @@ tap.test('5 concurrent clients', async () => {
|
||||
assignedIps.add(existingClients[0].assignedIp);
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const bundle = await server.createClient({ clientId: `test-client-${i + 1}` });
|
||||
extraBundles.push(bundle);
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
const result = await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${serverPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
expect(result.assignedIp).toStartWith('10.8.0.');
|
||||
|
||||
@@ -144,12 +144,17 @@ let keypair: IVpnKeypair;
|
||||
let throttle: ThrottleProxy;
|
||||
const allClients: VpnClient[] = [];
|
||||
|
||||
let clientCounter = 0;
|
||||
async function createConnectedClient(port: number): Promise<VpnClient> {
|
||||
clientCounter++;
|
||||
const bundle = await server.createClient({ clientId: `load-client-${clientCounter}` });
|
||||
const c = new VpnClient({ transport: { transport: 'stdio' } });
|
||||
await c.start();
|
||||
await c.connect({
|
||||
serverUrl: `ws://127.0.0.1:${port}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
allClients.push(c);
|
||||
|
||||
@@ -2,7 +2,7 @@ import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as dgram from 'dgram';
|
||||
import { VpnClient, VpnServer } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig } from '../ts/index.js';
|
||||
import type { IVpnClientOptions, IVpnServerOptions, IVpnKeypair, IVpnServerConfig, IClientConfigBundle } from '../ts/index.js';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
@@ -82,6 +82,8 @@ tap.test('setup: start VPN server in QUIC mode', async () => {
|
||||
});
|
||||
|
||||
tap.test('QUIC client connects and gets IP', async () => {
|
||||
const bundle = await server.createClient({ clientId: 'quic-client-1' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
@@ -92,6 +94,8 @@ tap.test('QUIC client connects and gets IP', async () => {
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${quicPort}`,
|
||||
serverPublicKey: keypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
@@ -162,12 +166,16 @@ tap.test('auto client connects to dual-mode server (QUIC preferred)', async () =
|
||||
const started = await client.start();
|
||||
expect(started).toBeTrue();
|
||||
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-auto-client' });
|
||||
|
||||
// "auto" mode (default): tries QUIC first at same host:port, falls back to WS
|
||||
// Since the WS port and QUIC port differ, auto will try QUIC on WS port (fail),
|
||||
// then fall back to WebSocket
|
||||
const result = await client.connect({
|
||||
serverUrl: `ws://127.0.0.1:${dualWsPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
// transport defaults to 'auto'
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
@@ -187,6 +195,8 @@ tap.test('auto client connects to dual-mode server (QUIC preferred)', async () =
|
||||
});
|
||||
|
||||
tap.test('explicit QUIC client connects to dual-mode server', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-quic-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
@@ -197,6 +207,8 @@ tap.test('explicit QUIC client connects to dual-mode server', async () => {
|
||||
const result = await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
@@ -211,6 +223,8 @@ tap.test('explicit QUIC client connects to dual-mode server', async () => {
|
||||
});
|
||||
|
||||
tap.test('keepalive exchange over QUIC', async () => {
|
||||
const bundle = await dualServer.createClient({ clientId: 'dual-keepalive-client' });
|
||||
|
||||
const options: IVpnClientOptions = {
|
||||
transport: { transport: 'stdio' },
|
||||
};
|
||||
@@ -220,6 +234,8 @@ tap.test('keepalive exchange over QUIC', async () => {
|
||||
await client.connect({
|
||||
serverUrl: `127.0.0.1:${dualQuicPort}`,
|
||||
serverPublicKey: dualKeypair.publicKey,
|
||||
clientPrivateKey: bundle.secrets.noisePrivateKey,
|
||||
clientPublicKey: bundle.smartvpnConfig.clientPublicKey,
|
||||
transport: 'quic',
|
||||
keepaliveIntervalSecs: 3,
|
||||
});
|
||||
|
||||
@@ -2,10 +2,17 @@ import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { VpnConfig } from '../ts/index.js';
|
||||
import type { IVpnClientConfig, IVpnServerConfig } from '../ts/index.js';
|
||||
|
||||
// Valid 32-byte base64 keys for testing
|
||||
const TEST_KEY_A = 'YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE=';
|
||||
const TEST_KEY_B = 'YmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmJiYmI=';
|
||||
const TEST_KEY_C = 'Y2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2NjY2M=';
|
||||
|
||||
tap.test('VpnConfig: validate valid client config', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['1.1.1.1', '8.8.8.8'],
|
||||
mtu: 1420,
|
||||
keepaliveIntervalSecs: 30,
|
||||
@@ -16,7 +23,9 @@ tap.test('VpnConfig: validate valid client config', async () => {
|
||||
|
||||
tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
const config = {
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -31,7 +40,9 @@ tap.test('VpnConfig: reject client config without serverUrl', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'http://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
@@ -43,10 +54,28 @@ tap.test('VpnConfig: reject client config with invalid serverUrl scheme', async
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config without clientPrivateKey', async () => {
|
||||
const config = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
} as IVpnClientConfig;
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateClientConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('clientPrivateKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
mtu: 100,
|
||||
};
|
||||
let threw = false;
|
||||
@@ -62,7 +91,9 @@ tap.test('VpnConfig: reject client config with invalid MTU', async () => {
|
||||
tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
const config: IVpnClientConfig = {
|
||||
serverUrl: 'wss://vpn.example.com/tunnel',
|
||||
serverPublicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
serverPublicKey: TEST_KEY_A,
|
||||
clientPrivateKey: TEST_KEY_B,
|
||||
clientPublicKey: TEST_KEY_C,
|
||||
dns: ['not-an-ip'],
|
||||
};
|
||||
let threw = false;
|
||||
@@ -78,12 +109,15 @@ tap.test('VpnConfig: reject client config with invalid DNS', async () => {
|
||||
tap.test('VpnConfig: validate valid server config', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
dns: ['1.1.1.1'],
|
||||
mtu: 1420,
|
||||
enableNat: true,
|
||||
clients: [
|
||||
{ clientId: 'test-client', publicKey: TEST_KEY_C },
|
||||
],
|
||||
};
|
||||
// Should not throw
|
||||
VpnConfig.validateServerConfig(config);
|
||||
@@ -92,8 +126,8 @@ tap.test('VpnConfig: validate valid server config', async () => {
|
||||
tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: 'dGVzdHByaXZhdGVrZXk=',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: 'invalid',
|
||||
};
|
||||
let threw = false;
|
||||
@@ -109,7 +143,7 @@ tap.test('VpnConfig: reject server config with invalid subnet', async () => {
|
||||
tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
const config = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
publicKey: 'dGVzdHB1YmxpY2tleQ==',
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
} as IVpnServerConfig;
|
||||
let threw = false;
|
||||
@@ -122,4 +156,24 @@ tap.test('VpnConfig: reject server config without privateKey', async () => {
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('VpnConfig: reject server config with invalid client publicKey', async () => {
|
||||
const config: IVpnServerConfig = {
|
||||
listenAddr: '0.0.0.0:443',
|
||||
privateKey: TEST_KEY_A,
|
||||
publicKey: TEST_KEY_B,
|
||||
subnet: '10.8.0.0/24',
|
||||
clients: [
|
||||
{ clientId: 'bad-client', publicKey: 'short-key' },
|
||||
],
|
||||
};
|
||||
let threw = false;
|
||||
try {
|
||||
VpnConfig.validateServerConfig(config);
|
||||
} catch (e) {
|
||||
threw = true;
|
||||
expect((e as Error).message).toContain('publicKey');
|
||||
}
|
||||
expect(threw).toBeTrue();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartvpn',
|
||||
version: '1.6.0',
|
||||
version: '1.9.0',
|
||||
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
|
||||
}
|
||||
|
||||
@@ -51,6 +51,15 @@ export class VpnConfig {
|
||||
if (!config.serverPublicKey) {
|
||||
throw new Error('VpnConfig: serverPublicKey is required');
|
||||
}
|
||||
// Noise IK requires client keypair
|
||||
if (!config.clientPrivateKey) {
|
||||
throw new Error('VpnConfig: clientPrivateKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPrivateKey, 'clientPrivateKey');
|
||||
if (!config.clientPublicKey) {
|
||||
throw new Error('VpnConfig: clientPublicKey is required for Noise IK authentication');
|
||||
}
|
||||
VpnConfig.validateBase64Key(config.clientPublicKey, 'clientPublicKey');
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
@@ -116,6 +125,18 @@ export class VpnConfig {
|
||||
if (!VpnConfig.isValidSubnet(config.subnet)) {
|
||||
throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`);
|
||||
}
|
||||
// Validate client entries if provided
|
||||
if (config.clients) {
|
||||
for (const client of config.clients) {
|
||||
if (!client.clientId) {
|
||||
throw new Error('VpnConfig: client entry must have a clientId');
|
||||
}
|
||||
if (!client.publicKey) {
|
||||
throw new Error(`VpnConfig: client '${client.clientId}' must have a publicKey`);
|
||||
}
|
||||
VpnConfig.validateBase64Key(client.publicKey, `client '${client.clientId}' publicKey`);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) {
|
||||
throw new Error('VpnConfig: mtu must be between 576 and 65535');
|
||||
|
||||
@@ -10,6 +10,8 @@ import type {
|
||||
IVpnClientTelemetry,
|
||||
IWgPeerConfig,
|
||||
IWgPeerInfo,
|
||||
IClientEntry,
|
||||
IClientConfigBundle,
|
||||
TVpnServerCommands,
|
||||
} from './smartvpn.interfaces.js';
|
||||
|
||||
@@ -152,6 +154,81 @@ export class VpnServer extends plugins.events.EventEmitter {
|
||||
return result.peers;
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ─────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Create a new client. Generates keypairs, assigns IP, returns full config bundle.
|
||||
* The secrets (private keys) are only returned at creation time.
|
||||
*/
|
||||
public async createClient(opts: Partial<IClientEntry>): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('createClient', { client: opts });
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a registered client (also disconnects if connected).
|
||||
*/
|
||||
public async removeClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('removeClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a registered client by ID.
|
||||
*/
|
||||
public async getClient(clientId: string): Promise<IClientEntry> {
|
||||
return this.bridge.sendCommand('getClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* List all registered clients.
|
||||
*/
|
||||
public async listRegisteredClients(): Promise<IClientEntry[]> {
|
||||
const result = await this.bridge.sendCommand('listRegisteredClients', {} as Record<string, never>);
|
||||
return result.clients;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a registered client's fields (ACLs, tags, description, etc.).
|
||||
*/
|
||||
public async updateClient(clientId: string, update: Partial<IClientEntry>): Promise<void> {
|
||||
await this.bridge.sendCommand('updateClient', { clientId, update });
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a previously disabled client.
|
||||
*/
|
||||
public async enableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('enableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable a client (also disconnects if connected).
|
||||
*/
|
||||
public async disableClient(clientId: string): Promise<void> {
|
||||
await this.bridge.sendCommand('disableClient', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Rotate a client's keys. Returns a new config bundle with fresh keypairs.
|
||||
*/
|
||||
public async rotateClientKey(clientId: string): Promise<IClientConfigBundle> {
|
||||
return this.bridge.sendCommand('rotateClientKey', { clientId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Export a client config (without secrets) in the specified format.
|
||||
*/
|
||||
public async exportClientConfig(clientId: string, format: 'smartvpn' | 'wireguard'): Promise<string> {
|
||||
const result = await this.bridge.sendCommand('exportClientConfig', { clientId, format });
|
||||
return result.config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a standalone Noise IK keypair (not tied to a client).
|
||||
*/
|
||||
public async generateClientKeypair(): Promise<IVpnKeypair> {
|
||||
return this.bridge.sendCommand('generateClientKeypair', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the daemon bridge.
|
||||
*/
|
||||
|
||||
@@ -24,8 +24,12 @@ export type TVpnTransportOptions = IVpnTransportStdio | IVpnTransportSocket;
|
||||
export interface IVpnClientConfig {
|
||||
/** Server WebSocket URL, e.g. wss://vpn.example.com/tunnel */
|
||||
serverUrl: string;
|
||||
/** Server's static public key (base64) for Noise NK handshake */
|
||||
/** Server's static public key (base64) for Noise IK handshake */
|
||||
serverPublicKey: string;
|
||||
/** Client's Noise IK private key (base64) — required for SmartVPN native transport */
|
||||
clientPrivateKey: string;
|
||||
/** Client's Noise IK public key (base64) — for reference/display */
|
||||
clientPublicKey: string;
|
||||
/** Optional DNS servers to use while connected */
|
||||
dns?: string[];
|
||||
/** Optional MTU for the TUN device */
|
||||
@@ -96,6 +100,15 @@ export interface IVpnServerConfig {
|
||||
wgListenPort?: number;
|
||||
/** WireGuard: configured peers */
|
||||
wgPeers?: IWgPeerConfig[];
|
||||
/** Pre-registered clients for Noise IK authentication */
|
||||
clients?: IClientEntry[];
|
||||
/** Enable PROXY protocol v2 on incoming WebSocket connections.
|
||||
* Required when behind a reverse proxy that sends PP v2 headers (HAProxy, SmartProxy).
|
||||
* SECURITY: Must be false when accepting direct client connections. */
|
||||
proxyProtocol?: boolean;
|
||||
/** Server-level IP block list — applied at TCP accept, before Noise handshake.
|
||||
* Supports exact IPs, CIDR, wildcards, ranges. */
|
||||
connectionIpBlockList?: string[];
|
||||
}
|
||||
|
||||
export interface IVpnServerOptions {
|
||||
@@ -146,6 +159,12 @@ export interface IVpnClientInfo {
|
||||
keepalivesReceived: number;
|
||||
rateLimitBytesPerSec?: number;
|
||||
burstBytes?: number;
|
||||
/** Client's authenticated Noise IK public key (base64) */
|
||||
authenticatedKey: string;
|
||||
/** Registered client ID from the client registry */
|
||||
registeredClientId: string;
|
||||
/** Real client IP:port (from PROXY protocol or direct TCP connection) */
|
||||
remoteAddr?: string;
|
||||
}
|
||||
|
||||
export interface IVpnServerStatistics extends IVpnStatistics {
|
||||
@@ -205,6 +224,84 @@ export interface IVpnClientTelemetry {
|
||||
burstBytes?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client Registry (Hub) types — aligned with SmartProxy IRouteSecurity pattern
|
||||
// ============================================================================
|
||||
|
||||
/** Per-client rate limiting. */
|
||||
export interface IClientRateLimit {
|
||||
/** Max throughput in bytes/sec */
|
||||
bytesPerSec: number;
|
||||
/** Burst allowance in bytes */
|
||||
burstBytes: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Per-client security settings.
|
||||
* Mirrors SmartProxy's IRouteSecurity: ipAllowList/ipBlockList naming + deny-overrides-allow.
|
||||
* Adds VPN-specific destination filtering.
|
||||
*/
|
||||
export interface IClientSecurity {
|
||||
/** Source IPs/CIDRs the client may connect FROM (empty = any).
|
||||
* Supports: exact IP, CIDR, wildcard (192.168.1.*), ranges (1.1.1.1-1.1.1.5). */
|
||||
ipAllowList?: string[];
|
||||
/** Source IPs blocked — overrides ipAllowList (deny wins). */
|
||||
ipBlockList?: string[];
|
||||
/** Destination IPs/CIDRs the client may reach through the VPN (empty = all). */
|
||||
destinationAllowList?: string[];
|
||||
/** Destination IPs blocked — overrides destinationAllowList (deny wins). */
|
||||
destinationBlockList?: string[];
|
||||
/** Max concurrent connections from this client. */
|
||||
maxConnections?: number;
|
||||
/** Per-client rate limiting. */
|
||||
rateLimit?: IClientRateLimit;
|
||||
}
|
||||
|
||||
/**
|
||||
* Server-side client definition — the central config object for the Hub.
|
||||
* Naming and structure aligned with SmartProxy's IRouteConfig / IRouteSecurity.
|
||||
*/
|
||||
export interface IClientEntry {
|
||||
/** Human-readable client ID (e.g. "alice-laptop") */
|
||||
clientId: string;
|
||||
/** Client's Noise IK public key (base64) — for SmartVPN native transport */
|
||||
publicKey: string;
|
||||
/** Client's WireGuard public key (base64) — for WireGuard transport */
|
||||
wgPublicKey?: string;
|
||||
/** Security settings (ACLs, rate limits) */
|
||||
security?: IClientSecurity;
|
||||
/** Traffic priority (lower = higher priority, default: 100) */
|
||||
priority?: number;
|
||||
/** Whether this client is enabled (default: true) */
|
||||
enabled?: boolean;
|
||||
/** Tags for grouping (e.g. ["engineering", "office"]) */
|
||||
tags?: string[];
|
||||
/** Optional description */
|
||||
description?: string;
|
||||
/** Optional expiry (ISO 8601 timestamp, omit = never expires) */
|
||||
expiresAt?: string;
|
||||
/** Assigned VPN IP address (set by server) */
|
||||
assignedIp?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Complete client config bundle — returned by createClient() and rotateClientKey().
|
||||
* Contains everything the client needs to connect.
|
||||
*/
|
||||
export interface IClientConfigBundle {
|
||||
/** The server-side client entry */
|
||||
entry: IClientEntry;
|
||||
/** Ready-to-use SmartVPN client config (typed object) */
|
||||
smartvpnConfig: IVpnClientConfig;
|
||||
/** Ready-to-use WireGuard .conf file content (string) */
|
||||
wireguardConfig: string;
|
||||
/** Client's private keys (ONLY returned at creation time, not stored server-side) */
|
||||
secrets: {
|
||||
noisePrivateKey: string;
|
||||
wgPrivateKey: string;
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WireGuard-specific types
|
||||
// ============================================================================
|
||||
@@ -262,6 +359,17 @@ export type TVpnServerCommands = {
|
||||
addWgPeer: { params: { peer: IWgPeerConfig }; result: void };
|
||||
removeWgPeer: { params: { publicKey: string }; result: void };
|
||||
listWgPeers: { params: Record<string, never>; result: { peers: IWgPeerInfo[] } };
|
||||
// Client Registry (Hub) commands
|
||||
createClient: { params: { client: Partial<IClientEntry> }; result: IClientConfigBundle };
|
||||
removeClient: { params: { clientId: string }; result: void };
|
||||
getClient: { params: { clientId: string }; result: IClientEntry };
|
||||
listRegisteredClients: { params: Record<string, never>; result: { clients: IClientEntry[] } };
|
||||
updateClient: { params: { clientId: string; update: Partial<IClientEntry> }; result: void };
|
||||
enableClient: { params: { clientId: string }; result: void };
|
||||
disableClient: { params: { clientId: string }; result: void };
|
||||
rotateClientKey: { params: { clientId: string }; result: IClientConfigBundle };
|
||||
exportClientConfig: { params: { clientId: string; format: 'smartvpn' | 'wireguard' }; result: { config: string } };
|
||||
generateClientKeypair: { params: Record<string, never>; result: IVpnKeypair };
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
"module": "NodeNext",
|
||||
"moduleResolution": "NodeNext",
|
||||
"esModuleInterop": true,
|
||||
"verbatimModuleSyntax": true
|
||||
"verbatimModuleSyntax": true,
|
||||
"types": ["node"]
|
||||
},
|
||||
"exclude": [
|
||||
"dist_ts/**/*.d.ts"
|
||||
|
||||
Reference in New Issue
Block a user