From e4e59d72f9e7583189e11804b2da4b8e6f48d333 Mon Sep 17 00:00:00 2001 From: Juergen Kunz Date: Sun, 29 Mar 2026 15:24:41 +0000 Subject: [PATCH] feat(wireguard): add WireGuard transport support with management APIs and config generation --- npmextra.json => .smartconfig.json | 0 changelog.md | 9 + rust/Cargo.lock | 130 ++- rust/Cargo.toml | 1 + rust/src/lib.rs | 1 + rust/src/management.rs | 249 ++++-- rust/src/wireguard.rs | 1329 ++++++++++++++++++++++++++++ test/test.wireguard.node.ts | 353 ++++++++ ts/00_commitinfo_data.ts | 2 +- ts/index.ts | 1 + ts/smartvpn.classes.vpnconfig.ts | 144 ++- ts/smartvpn.classes.vpnserver.ts | 31 + ts/smartvpn.classes.wgconfig.ts | 123 +++ ts/smartvpn.interfaces.ts | 59 +- 14 files changed, 2347 insertions(+), 85 deletions(-) rename npmextra.json => .smartconfig.json (100%) create mode 100644 rust/src/wireguard.rs create mode 100644 test/test.wireguard.node.ts create mode 100644 ts/smartvpn.classes.wgconfig.ts diff --git a/npmextra.json b/.smartconfig.json similarity index 100% rename from npmextra.json rename to .smartconfig.json diff --git a/changelog.md b/changelog.md index 9d3108c..1cbc543 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,14 @@ # Changelog +## 2026-03-29 - 1.5.0 - feat(wireguard) +add WireGuard transport support with management APIs and config generation + +- add Rust WireGuard module integration using boringtun and route management through client/server management handlers +- extend TypeScript client and server configuration schemas with WireGuard-specific options and validation +- add server-side WireGuard peer management commands including keypair generation, peer add/remove, and peer listing +- export a WireGuard config generator for producing client and server .conf files +- add WireGuard-focused test coverage for config validation and config generation + ## 2026-03-21 - 1.4.1 - fix(readme) preserve markdown line breaks in feature list diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c7ce335..77da132 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -137,12 +137,30 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.11.0" @@ -180,6 +198,30 @@ dependencies = [ "piper", ] +[[package]] +name = "boringtun" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dc4267b0c97985d9b089b19ff965b959e61870640d2f0842a97552e030fa43f" +dependencies = [ + "aead", + "base64 0.13.1", + "blake2", + "chacha20poly1305", + "hex", + "hmac", + "ip_network", + "ip_network_table", + "libc", + "nix 0.25.1", + "parking_lot", + "rand_core 0.6.4", + "ring", + "tracing", + "untrusted", + "x25519-dalek", +] + [[package]] name = "bumpalo" version = "3.20.2" @@ -655,6 +697,21 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.4.0" @@ -680,6 +737,28 @@ dependencies = [ "generic-array", ] +[[package]] +name = "ip_network" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2f047c0a98b2f299aa5d6d7088443570faae494e9ae1305e48be000c9e0eb1" + +[[package]] +name = "ip_network_table" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4099b7cfc5c5e2fe8c5edf3f6f7adf7a714c9cc697534f63a5a5da30397cb2c0" +dependencies = [ + "ip_network", + "ip_network_table-deps-treebitmap", +] + +[[package]] +name = "ip_network_table-deps-treebitmap" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e537132deb99c0eb4b752f0346b6a836200eaaa3516dd7e5514b63930a09e5d" + [[package]] name = "ipnet" version = "2.11.0" @@ -824,13 +903,25 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nix" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f346ff70e7dbfd675fe90590b92d59ef2de15a8779ae305ebcbfd3f0caf59be4" +dependencies = [ + "autocfg", + "bitflags 1.3.2", + "cfg-if", + "libc", +] + [[package]] name = "nix" version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags", + "bitflags 2.11.0", "cfg-if", "cfg_aliases", "libc", @@ -910,7 +1001,7 @@ version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" dependencies = [ - "base64", + "base64 0.22.1", "serde_core", ] @@ -1128,7 +1219,7 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags", + "bitflags 2.11.0", ] [[package]] @@ -1296,7 +1387,7 @@ version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ - "bitflags", + "bitflags 2.11.0", "core-foundation", "core-foundation-sys", "libc", @@ -1433,7 +1524,8 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", - "base64", + "base64 0.22.1", + "boringtun", "bytes", "chacha20poly1305", "clap", @@ -1733,7 +1825,7 @@ dependencies = [ "ipnet", "libc", "log", - "nix", + "nix 0.30.1", "thiserror 2.0.18", "tokio", "tokio-util", @@ -2197,6 +2289,18 @@ version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4", + "serde", + "zeroize", +] + [[package]] name = "yasna" version = "0.5.2" @@ -2231,6 +2335,20 @@ name = "zeroize" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "zmij" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 9ce9f7e..f1706b9 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -34,6 +34,7 @@ rustls-pki-types = "1" rustls-pemfile = "2" webpki-roots = "1" mimalloc = "0.1" +boringtun = "0.7" [profile.release] opt-level = 3 diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 22937ce..0ca8a35 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -17,3 +17,4 @@ pub mod telemetry; pub mod ratelimit; pub mod qos; pub mod mtu; +pub mod wireguard; diff --git a/rust/src/management.rs b/rust/src/management.rs index 3cb42a2..ba59bea 100644 --- a/rust/src/management.rs +++ b/rust/src/management.rs @@ -7,6 +7,7 @@ use tracing::{info, error, warn}; use crate::client::{ClientConfig, VpnClient}; use crate::crypto; use crate::server::{ServerConfig, VpnServer}; +use crate::wireguard::{self, WgClient, WgClientConfig, WgPeerConfig, WgServer, WgServerConfig}; // ============================================================================ // IPC protocol types @@ -93,6 +94,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> { let mut vpn_client = VpnClient::new(); let mut vpn_server = VpnServer::new(); + let mut wg_client = WgClient::new(); + let mut wg_server = WgServer::new(); send_event_stdout("ready", serde_json::json!({ "mode": mode })); @@ -127,8 +130,8 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> { }; let response = match mode { - "client" => handle_client_request(&request, &mut vpn_client).await, - "server" => handle_server_request(&request, &mut vpn_server).await, + "client" => handle_client_request(&request, &mut vpn_client, &mut wg_client).await, + "server" => handle_server_request(&request, &mut vpn_server, &mut wg_server).await, _ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)), }; send_response_stdout(&response); @@ -150,6 +153,8 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()> // Shared state behind Mutex for socket mode (multiple connections) let vpn_client = std::sync::Arc::new(Mutex::new(VpnClient::new())); let vpn_server = std::sync::Arc::new(Mutex::new(VpnServer::new())); + let wg_client = std::sync::Arc::new(Mutex::new(WgClient::new())); + let wg_server = std::sync::Arc::new(Mutex::new(WgServer::new())); loop { match listener.accept().await { @@ -157,9 +162,11 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()> let mode = mode.to_string(); let client = vpn_client.clone(); let server = vpn_server.clone(); + let wg_c = wg_client.clone(); + let wg_s = wg_server.clone(); tokio::spawn(async move { if let Err(e) = - handle_socket_connection(stream, &mode, client, server).await + handle_socket_connection(stream, &mode, client, server, wg_c, wg_s).await { warn!("Socket connection error: {}", e); } @@ -177,6 +184,8 @@ async fn handle_socket_connection( mode: &str, vpn_client: std::sync::Arc>, vpn_server: std::sync::Arc>, + wg_client: std::sync::Arc>, + wg_server: std::sync::Arc>, ) -> Result<()> { let (reader, mut writer) = stream.into_split(); let buf_reader = BufReader::new(reader); @@ -227,11 +236,13 @@ async fn handle_socket_connection( let response = match mode { "client" => { let mut client = vpn_client.lock().await; - handle_client_request(&request, &mut client).await + let mut wg_c = wg_client.lock().await; + handle_client_request(&request, &mut client, &mut wg_c).await } "server" => { let mut server = vpn_server.lock().await; - handle_server_request(&request, &mut server).await + let mut wg_s = wg_server.lock().await; + handle_server_request(&request, &mut server, &mut wg_s).await } _ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)), }; @@ -252,38 +263,79 @@ async fn handle_socket_connection( async fn handle_client_request( request: &ManagementRequest, vpn_client: &mut VpnClient, + wg_client: &mut WgClient, ) -> ManagementResponse { let id = request.id.clone(); match request.method.as_str() { "connect" => { - let config: ClientConfig = match serde_json::from_value( - request.params.get("config").cloned().unwrap_or_default(), - ) { - Ok(c) => c, - Err(e) => { - return ManagementResponse::err(id, format!("Invalid config: {}", e)); - } - }; + // Check if transport is "wireguard" + let transport = request.params + .get("config") + .and_then(|c| c.get("transport")) + .and_then(|t| t.as_str()) + .unwrap_or(""); - match vpn_client.connect(config).await { - Ok(assigned_ip) => { - ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip })) + if transport == "wireguard" { + let config: WgClientConfig = match serde_json::from_value( + request.params.get("config").cloned().unwrap_or_default(), + ) { + Ok(c) => c, + Err(e) => { + return ManagementResponse::err(id, format!("Invalid WG config: {}", e)); + } + }; + match wg_client.connect(config).await { + Ok(assigned_ip) => { + ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip })) + } + Err(e) => ManagementResponse::err(id, format!("WG connect failed: {}", e)), + } + } else { + let config: ClientConfig = match serde_json::from_value( + request.params.get("config").cloned().unwrap_or_default(), + ) { + Ok(c) => c, + Err(e) => { + return ManagementResponse::err(id, format!("Invalid config: {}", e)); + } + }; + match vpn_client.connect(config).await { + Ok(assigned_ip) => { + ManagementResponse::ok(id, serde_json::json!({ "assignedIp": assigned_ip })) + } + Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)), + } + } + } + "disconnect" => { + if wg_client.is_running() { + match wg_client.disconnect().await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("WG disconnect failed: {}", e)), + } + } else { + match vpn_client.disconnect().await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)), } - Err(e) => ManagementResponse::err(id, format!("Connect failed: {}", e)), } } - "disconnect" => match vpn_client.disconnect().await { - Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), - Err(e) => ManagementResponse::err(id, format!("Disconnect failed: {}", e)), - }, "getStatus" => { - let status = vpn_client.get_status().await; - ManagementResponse::ok(id, status) + if wg_client.is_running() { + ManagementResponse::ok(id, wg_client.get_status().await) + } else { + let status = vpn_client.get_status().await; + ManagementResponse::ok(id, status) + } } "getStatistics" => { - let stats = vpn_client.get_statistics().await; - ManagementResponse::ok(id, stats) + if wg_client.is_running() { + ManagementResponse::ok(id, wg_client.get_statistics().await) + } else { + let stats = vpn_client.get_statistics().await; + ManagementResponse::ok(id, stats) + } } "getConnectionQuality" => { match vpn_client.get_connection_quality() { @@ -329,45 +381,92 @@ async fn handle_client_request( async fn handle_server_request( request: &ManagementRequest, vpn_server: &mut VpnServer, + wg_server: &mut WgServer, ) -> ManagementResponse { let id = request.id.clone(); match request.method.as_str() { "start" => { - let config: ServerConfig = match serde_json::from_value( - request.params.get("config").cloned().unwrap_or_default(), - ) { - Ok(c) => c, - Err(e) => { - return ManagementResponse::err(id, format!("Invalid config: {}", e)); - } - }; + // Check if transportMode is "wireguard" + let transport_mode = request.params + .get("config") + .and_then(|c| c.get("transportMode")) + .and_then(|t| t.as_str()) + .unwrap_or(""); - match vpn_server.start(config).await { - Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), - Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)), + if transport_mode == "wireguard" { + let config: WgServerConfig = match serde_json::from_value( + request.params.get("config").cloned().unwrap_or_default(), + ) { + Ok(c) => c, + Err(e) => { + return ManagementResponse::err(id, format!("Invalid WG config: {}", e)); + } + }; + match wg_server.start(config).await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("WG start failed: {}", e)), + } + } else { + let config: ServerConfig = match serde_json::from_value( + request.params.get("config").cloned().unwrap_or_default(), + ) { + Ok(c) => c, + Err(e) => { + return ManagementResponse::err(id, format!("Invalid config: {}", e)); + } + }; + match vpn_server.start(config).await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)), + } + } + } + "stop" => { + if wg_server.is_running() { + match wg_server.stop().await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("WG stop failed: {}", e)), + } + } else { + match vpn_server.stop().await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)), + } } } - "stop" => match vpn_server.stop().await { - Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), - Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)), - }, "getStatus" => { - let status = vpn_server.get_status(); - ManagementResponse::ok(id, status) + if wg_server.is_running() { + ManagementResponse::ok(id, wg_server.get_status()) + } else { + let status = vpn_server.get_status(); + ManagementResponse::ok(id, status) + } } "getStatistics" => { - let stats = vpn_server.get_statistics().await; - match serde_json::to_value(&stats) { - Ok(v) => ManagementResponse::ok(id, v), - Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)), + if wg_server.is_running() { + ManagementResponse::ok(id, wg_server.get_statistics().await) + } else { + let stats = vpn_server.get_statistics().await; + match serde_json::to_value(&stats) { + Ok(v) => ManagementResponse::ok(id, v), + Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)), + } } } "listClients" => { - let clients = vpn_server.list_clients().await; - match serde_json::to_value(&clients) { - Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })), - Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)), + if wg_server.is_running() { + let peers = wg_server.list_peers().await; + match serde_json::to_value(&peers) { + Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })), + Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)), + } + } else { + let clients = vpn_server.list_clients().await; + match serde_json::to_value(&clients) { + Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })), + Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)), + } } } "disconnectClient" => { @@ -436,6 +535,56 @@ async fn handle_server_request( ), Err(e) => ManagementResponse::err(id, format!("Keypair generation failed: {}", e)), }, + "generateWgKeypair" => { + let (public_key, private_key) = wireguard::generate_wg_keypair(); + ManagementResponse::ok( + id, + serde_json::json!({ + "publicKey": public_key, + "privateKey": private_key, + }), + ) + } + "addWgPeer" => { + if !wg_server.is_running() { + return ManagementResponse::err(id, "WireGuard server not running".to_string()); + } + let config: WgPeerConfig = match serde_json::from_value( + request.params.get("peer").cloned().unwrap_or_default(), + ) { + Ok(c) => c, + Err(e) => { + return ManagementResponse::err(id, format!("Invalid peer config: {}", e)); + } + }; + match wg_server.add_peer(config).await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("Add peer failed: {}", e)), + } + } + "removeWgPeer" => { + if !wg_server.is_running() { + return ManagementResponse::err(id, "WireGuard server not running".to_string()); + } + let public_key = match request.params.get("publicKey").and_then(|v| v.as_str()) { + Some(k) => k.to_string(), + None => return ManagementResponse::err(id, "Missing publicKey".to_string()), + }; + match wg_server.remove_peer(&public_key).await { + Ok(()) => ManagementResponse::ok(id, serde_json::json!({})), + Err(e) => ManagementResponse::err(id, format!("Remove peer failed: {}", e)), + } + } + "listWgPeers" => { + if !wg_server.is_running() { + return ManagementResponse::err(id, "WireGuard server not running".to_string()); + } + let peers = wg_server.list_peers().await; + match serde_json::to_value(&peers) { + Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "peers": v })), + Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)), + } + } _ => ManagementResponse::err(id, format!("Unknown server method: {}", request.method)), } } diff --git a/rust/src/wireguard.rs b/rust/src/wireguard.rs new file mode 100644 index 0000000..a96987c --- /dev/null +++ b/rust/src/wireguard.rs @@ -0,0 +1,1329 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::Arc; +use std::time::Instant; + +use anyhow::{anyhow, Result}; +use base64::engine::general_purpose::STANDARD as BASE64; +use base64::Engine; +use boringtun::noise::rate_limiter::RateLimiter; +use boringtun::noise::{Tunn, TunnResult}; +use boringtun::x25519::{PublicKey, StaticSecret}; +use rand::rngs::OsRng; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::UdpSocket; +use tokio::sync::{mpsc, oneshot, RwLock}; +use tracing::{debug, error, info, warn}; + +use crate::network; +use crate::tunnel::{self, TunConfig}; + +// ============================================================================ +// Constants +// ============================================================================ + +const MAX_UDP_PACKET: usize = 65536; +const WG_BUFFER_SIZE: usize = MAX_UDP_PACKET; +/// Minimum dst buffer size for boringtun encapsulate/decapsulate +const _MIN_DST_BUF: usize = 148; +const TIMER_TICK_MS: u64 = 100; +const DEFAULT_WG_PORT: u16 = 51820; +const DEFAULT_TUN_ADDRESS: &str = "10.8.0.1"; +const DEFAULT_TUN_NETMASK: &str = "255.255.255.0"; +const DEFAULT_MTU: u16 = 1420; + +// ============================================================================ +// Configuration types +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WgPeerConfig { + pub public_key: String, + #[serde(default)] + pub preshared_key: Option, + pub allowed_ips: Vec, + #[serde(default)] + pub endpoint: Option, + #[serde(default)] + pub persistent_keepalive: Option, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WgServerConfig { + pub private_key: String, + #[serde(default)] + pub listen_port: Option, + #[serde(default)] + pub tun_address: Option, + #[serde(default)] + pub tun_netmask: Option, + #[serde(default)] + pub mtu: Option, + pub peers: Vec, + #[serde(default)] + pub dns: Option>, + #[serde(default)] + pub enable_nat: Option, + #[serde(default)] + pub subnet: Option, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WgClientConfig { + pub private_key: String, + pub address: String, + #[serde(default)] + pub address_prefix: Option, + #[serde(default)] + pub dns: Option>, + #[serde(default)] + pub mtu: Option, + pub peer: WgPeerConfig, +} + +// ============================================================================ +// Stats types +// ============================================================================ + +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WgPeerStats { + pub bytes_sent: u64, + pub bytes_received: u64, + pub packets_sent: u64, + pub packets_received: u64, + pub last_handshake_time: Option, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WgPeerInfo { + pub public_key: String, + pub allowed_ips: Vec, + pub endpoint: Option, + pub persistent_keepalive: Option, + #[serde(flatten)] + pub stats: WgPeerStats, +} + +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WgServerStats { + pub total_bytes_sent: u64, + pub total_bytes_received: u64, + pub total_packets_sent: u64, + pub total_packets_received: u64, + pub active_peers: usize, + pub uptime_seconds: f64, +} + +// ============================================================================ +// Key generation and parsing +// ============================================================================ + +/// Generate a WireGuard-compatible X25519 keypair. +/// Returns (public_key_base64, private_key_base64). +pub fn generate_wg_keypair() -> (String, String) { + let private = StaticSecret::random_from_rng(OsRng); + let public = PublicKey::from(&private); + let priv_b64 = BASE64.encode(private.to_bytes()); + let pub_b64 = BASE64.encode(public.to_bytes()); + (pub_b64, priv_b64) +} + +fn parse_private_key(b64: &str) -> Result { + let bytes = BASE64.decode(b64)?; + if bytes.len() != 32 { + return Err(anyhow!("Private key must be 32 bytes, got {}", bytes.len())); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + Ok(StaticSecret::from(arr)) +} + +fn parse_public_key(b64: &str) -> Result { + let bytes = BASE64.decode(b64)?; + if bytes.len() != 32 { + return Err(anyhow!("Public key must be 32 bytes, got {}", bytes.len())); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + Ok(PublicKey::from(arr)) +} + +fn parse_preshared_key(b64: &str) -> Result<[u8; 32]> { + let bytes = BASE64.decode(b64)?; + if bytes.len() != 32 { + return Err(anyhow!( + "Preshared key must be 32 bytes, got {}", + bytes.len() + )); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + Ok(arr) +} + +// ============================================================================ +// AllowedIPs matching +// ============================================================================ + +#[derive(Debug, Clone)] +struct AllowedIp { + addr: IpAddr, + prefix_len: u8, +} + +impl AllowedIp { + fn parse(cidr: &str) -> Result { + let parts: Vec<&str> = cidr.split('/').collect(); + if parts.len() != 2 { + return Err(anyhow!("Invalid CIDR: {}", cidr)); + } + let addr: IpAddr = parts[0].parse()?; + let prefix_len: u8 = parts[1].parse()?; + match addr { + IpAddr::V4(_) if prefix_len > 32 => { + return Err(anyhow!("IPv4 prefix length {} > 32", prefix_len)) + } + IpAddr::V6(_) if prefix_len > 128 => { + return Err(anyhow!("IPv6 prefix length {} > 128", prefix_len)) + } + _ => {} + } + Ok(Self { addr, prefix_len }) + } + + fn matches(&self, ip: IpAddr) -> bool { + match (self.addr, ip) { + (IpAddr::V4(net), IpAddr::V4(target)) => { + if self.prefix_len == 0 { + return true; + } + if self.prefix_len >= 32 { + return net == target; + } + let mask = u32::MAX << (32 - self.prefix_len); + (u32::from(net) & mask) == (u32::from(target) & mask) + } + (IpAddr::V6(net), IpAddr::V6(target)) => { + if self.prefix_len == 0 { + return true; + } + if self.prefix_len >= 128 { + return net == target; + } + let net_bits = u128::from(net); + let target_bits = u128::from(target); + let mask = u128::MAX << (128 - self.prefix_len); + (net_bits & mask) == (target_bits & mask) + } + _ => false, + } + } +} + +/// Extract destination IP from an IP packet header. +fn extract_dst_ip(packet: &[u8]) -> Option { + if packet.is_empty() { + return None; + } + let version = packet[0] >> 4; + match version { + 4 if packet.len() >= 20 => { + let dst = Ipv4Addr::new(packet[16], packet[17], packet[18], packet[19]); + Some(IpAddr::V4(dst)) + } + 6 if packet.len() >= 40 => { + let mut octets = [0u8; 16]; + octets.copy_from_slice(&packet[24..40]); + Some(IpAddr::V6(Ipv6Addr::from(octets))) + } + _ => None, + } +} + +// ============================================================================ +// Dynamic peer management commands +// ============================================================================ + +enum WgCommand { + AddPeer(WgPeerConfig, oneshot::Sender>), + RemovePeer(String, oneshot::Sender>), +} + +// ============================================================================ +// Internal peer state (owned by event loop) +// ============================================================================ + +struct PeerState { + tunn: Tunn, + public_key_b64: String, + allowed_ips: Vec, + endpoint: Option, + #[allow(dead_code)] + persistent_keepalive: Option, + stats: WgPeerStats, +} + +impl PeerState { + fn matches_dst(&self, dst_ip: IpAddr) -> bool { + self.allowed_ips.iter().any(|aip| aip.matches(dst_ip)) + } +} + +// ============================================================================ +// WgServer +// ============================================================================ + +pub struct WgServer { + shutdown_tx: Option>, + command_tx: Option>, + shared_stats: Arc>>, + server_stats: Arc>, + started_at: Option, + listen_port: Option, +} + +impl WgServer { + pub fn new() -> Self { + Self { + shutdown_tx: None, + command_tx: None, + shared_stats: Arc::new(RwLock::new(HashMap::new())), + server_stats: Arc::new(RwLock::new(WgServerStats::default())), + started_at: None, + listen_port: None, + } + } + + pub fn is_running(&self) -> bool { + self.shutdown_tx.is_some() + } + + pub async fn start(&mut self, config: WgServerConfig) -> Result<()> { + if self.is_running() { + return Err(anyhow!("WireGuard server is already running")); + } + + let listen_port = config.listen_port.unwrap_or(DEFAULT_WG_PORT); + let tun_address = config + .tun_address + .as_deref() + .unwrap_or(DEFAULT_TUN_ADDRESS); + let tun_netmask = config + .tun_netmask + .as_deref() + .unwrap_or(DEFAULT_TUN_NETMASK); + let mtu = config.mtu.unwrap_or(DEFAULT_MTU); + + // Parse server private key + let server_private = parse_private_key(&config.private_key)?; + let server_public = PublicKey::from(&server_private); + + // Create rate limiter for DDoS protection + let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64)); + + // Build peer state + let peer_index = AtomicU32::new(0); + let mut peers: Vec = Vec::with_capacity(config.peers.len()); + + for peer_config in &config.peers { + let peer_public = parse_public_key(&peer_config.public_key)?; + let psk = match &peer_config.preshared_key { + Some(k) => Some(parse_preshared_key(k)?), + None => None, + }; + let idx = peer_index.fetch_add(1, Ordering::Relaxed); + + // Clone the private key for each Tunn (StaticSecret doesn't implement Clone, + // so re-parse from config) + let priv_copy = parse_private_key(&config.private_key)?; + + let tunn = Tunn::new( + priv_copy, + peer_public, + psk, + peer_config.persistent_keepalive, + idx, + Some(rate_limiter.clone()), + ); + + let allowed_ips: Vec = peer_config + .allowed_ips + .iter() + .map(|cidr| AllowedIp::parse(cidr)) + .collect::>>()?; + + let endpoint = match &peer_config.endpoint { + Some(ep) => Some(ep.parse::()?), + None => None, + }; + + peers.push(PeerState { + tunn, + public_key_b64: peer_config.public_key.clone(), + allowed_ips, + endpoint, + persistent_keepalive: peer_config.persistent_keepalive, + stats: WgPeerStats::default(), + }); + } + + // Create TUN device + let tun_config = TunConfig { + name: "wg0".to_string(), + address: tun_address.parse()?, + netmask: tun_netmask.parse()?, + mtu, + }; + let tun_device = tunnel::create_tun(&tun_config)?; + info!("WireGuard TUN device created: {}", tun_config.name); + + // Bind UDP socket + let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", listen_port)).await?; + info!("WireGuard server listening on UDP port {}", listen_port); + + // Enable IP forwarding and NAT if requested + if config.enable_nat.unwrap_or(false) { + network::enable_ip_forwarding()?; + let subnet = config + .subnet + .as_deref() + .unwrap_or("10.8.0.0/24"); + let iface = network::get_default_interface()?; + network::setup_nat(subnet, &iface).await?; + info!("NAT enabled for subnet {} via {}", subnet, iface); + } + + // Channels + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let (command_tx, command_rx) = mpsc::channel::(32); + + let shared_stats = self.shared_stats.clone(); + let server_stats = self.server_stats.clone(); + let started_at = Instant::now(); + + // Initialize shared stats + { + let mut stats = shared_stats.write().await; + for peer in &peers { + stats.insert(peer.public_key_b64.clone(), WgPeerStats::default()); + } + } + + // Spawn the event loop + tokio::spawn(async move { + if let Err(e) = wg_server_loop( + udp_socket, + tun_device, + peers, + peer_index, + rate_limiter, + config.private_key.clone(), + shared_stats, + server_stats, + started_at, + shutdown_rx, + command_rx, + ) + .await + { + error!("WireGuard server loop error: {}", e); + } + info!("WireGuard server loop exited"); + }); + + self.shutdown_tx = Some(shutdown_tx); + self.command_tx = Some(command_tx); + self.started_at = Some(started_at); + self.listen_port = Some(listen_port); + + Ok(()) + } + + pub async fn stop(&mut self) -> Result<()> { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + self.command_tx = None; + self.started_at = None; + self.listen_port = None; + info!("WireGuard server stopped"); + Ok(()) + } + + pub fn get_status(&self) -> serde_json::Value { + if self.is_running() { + serde_json::json!({ + "state": "running", + "listenPort": self.listen_port, + "uptimeSeconds": self.started_at.map(|t| t.elapsed().as_secs_f64()).unwrap_or(0.0), + }) + } else { + serde_json::json!({ "state": "stopped" }) + } + } + + pub async fn get_statistics(&self) -> serde_json::Value { + let mut stats = self.server_stats.write().await; + if let Some(started) = self.started_at { + stats.uptime_seconds = started.elapsed().as_secs_f64(); + } + // Aggregate from peer stats + let peer_stats = self.shared_stats.read().await; + stats.active_peers = peer_stats.len(); + stats.total_bytes_sent = peer_stats.values().map(|s| s.bytes_sent).sum(); + stats.total_bytes_received = peer_stats.values().map(|s| s.bytes_received).sum(); + stats.total_packets_sent = peer_stats.values().map(|s| s.packets_sent).sum(); + stats.total_packets_received = peer_stats.values().map(|s| s.packets_received).sum(); + serde_json::to_value(&*stats).unwrap_or_default() + } + + pub async fn list_peers(&self) -> Vec { + let stats = self.shared_stats.read().await; + stats + .iter() + .map(|(key, s)| WgPeerInfo { + public_key: key.clone(), + allowed_ips: vec![], // populated from event loop snapshots + endpoint: None, + persistent_keepalive: None, + stats: s.clone(), + }) + .collect() + } + + pub async fn add_peer(&self, config: WgPeerConfig) -> Result<()> { + let tx = self + .command_tx + .as_ref() + .ok_or_else(|| anyhow!("Server not running"))?; + let (resp_tx, resp_rx) = oneshot::channel(); + tx.send(WgCommand::AddPeer(config, resp_tx)) + .await + .map_err(|_| anyhow!("Server event loop closed"))?; + resp_rx.await.map_err(|_| anyhow!("No response"))? + } + + pub async fn remove_peer(&self, public_key: &str) -> Result<()> { + let tx = self + .command_tx + .as_ref() + .ok_or_else(|| anyhow!("Server not running"))?; + let (resp_tx, resp_rx) = oneshot::channel(); + tx.send(WgCommand::RemovePeer(public_key.to_string(), resp_tx)) + .await + .map_err(|_| anyhow!("Server event loop closed"))?; + resp_rx.await.map_err(|_| anyhow!("No response"))? + } +} + +// ============================================================================ +// Server event loop +// ============================================================================ + +async fn wg_server_loop( + udp_socket: UdpSocket, + tun_device: tun::AsyncDevice, + mut peers: Vec, + peer_index: AtomicU32, + rate_limiter: Arc, + server_private_key_b64: String, + shared_stats: Arc>>, + _server_stats: Arc>, + _started_at: Instant, + mut shutdown_rx: oneshot::Receiver<()>, + mut command_rx: mpsc::Receiver, +) -> Result<()> { + let mut udp_buf = vec![0u8; MAX_UDP_PACKET]; + let mut tun_buf = vec![0u8; MAX_UDP_PACKET]; + let mut dst_buf = vec![0u8; WG_BUFFER_SIZE]; + let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS)); + + // Split TUN for concurrent read/write in select + let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device); + + // Stats sync interval + let mut stats_timer = + tokio::time::interval(std::time::Duration::from_secs(1)); + + loop { + tokio::select! { + // --- UDP receive --- + result = udp_socket.recv_from(&mut udp_buf) => { + let (n, src_addr) = result?; + if n == 0 { continue; } + + // Find which peer this packet belongs to by trying decapsulate + let mut handled = false; + for peer in peers.iter_mut() { + match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) { + TunnResult::WriteToNetwork(packet) => { + udp_socket.send_to(packet, src_addr).await?; + // Drain loop + loop { + match peer.tunn.decapsulate(None, &[], &mut dst_buf) { + TunnResult::WriteToNetwork(pkt) => { + let ep = peer.endpoint.unwrap_or(src_addr); + udp_socket.send_to(pkt, ep).await?; + } + _ => break, + } + } + peer.endpoint = Some(src_addr); + handled = true; + break; + } + TunnResult::WriteToTunnelV4(packet, addr) => { + if peer.matches_dst(IpAddr::V4(addr)) { + let pkt_len = packet.len() as u64; + tun_writer.write_all(packet).await?; + peer.stats.bytes_received += pkt_len; + peer.stats.packets_received += 1; + } + peer.endpoint = Some(src_addr); + handled = true; + break; + } + TunnResult::WriteToTunnelV6(packet, addr) => { + if peer.matches_dst(IpAddr::V6(addr)) { + let pkt_len = packet.len() as u64; + tun_writer.write_all(packet).await?; + peer.stats.bytes_received += pkt_len; + peer.stats.packets_received += 1; + } + peer.endpoint = Some(src_addr); + handled = true; + break; + } + TunnResult::Done => { + // This peer didn't recognize the packet, try next + continue; + } + TunnResult::Err(e) => { + debug!("decapsulate error from {}: {:?}", src_addr, e); + continue; + } + } + } + if !handled { + debug!("No peer matched UDP packet from {}", src_addr); + } + } + + // --- TUN read --- + result = tun_reader.read(&mut tun_buf) => { + let n = result?; + if n == 0 { continue; } + + let dst_ip = match extract_dst_ip(&tun_buf[..n]) { + Some(ip) => ip, + None => { continue; } + }; + + // Find peer whose AllowedIPs match the destination + for peer in peers.iter_mut() { + if !peer.matches_dst(dst_ip) { + continue; + } + match peer.tunn.encapsulate(&tun_buf[..n], &mut dst_buf) { + TunnResult::WriteToNetwork(packet) => { + if let Some(endpoint) = peer.endpoint { + let pkt_len = n as u64; + udp_socket.send_to(packet, endpoint).await?; + peer.stats.bytes_sent += pkt_len; + peer.stats.packets_sent += 1; + } else { + debug!("No endpoint for peer {}, dropping packet", peer.public_key_b64); + } + } + TunnResult::Err(e) => { + debug!("encapsulate error for peer {}: {:?}", peer.public_key_b64, e); + } + _ => {} + } + break; + } + } + + // --- Timer tick (100ms) for WireGuard timers --- + _ = timer.tick() => { + for peer in peers.iter_mut() { + match peer.tunn.update_timers(&mut dst_buf) { + TunnResult::WriteToNetwork(packet) => { + if let Some(endpoint) = peer.endpoint { + udp_socket.send_to(packet, endpoint).await?; + } + } + TunnResult::Err(e) => { + debug!("Timer error for peer {}: {:?}", peer.public_key_b64, e); + } + _ => {} + } + } + } + + // --- Sync stats to shared state --- + _ = stats_timer.tick() => { + let mut shared = shared_stats.write().await; + for peer in peers.iter() { + shared.insert(peer.public_key_b64.clone(), peer.stats.clone()); + } + } + + // --- Dynamic peer commands --- + cmd = command_rx.recv() => { + match cmd { + Some(WgCommand::AddPeer(config, resp_tx)) => { + let result = add_peer_to_loop( + &mut peers, + &config, + &peer_index, + &rate_limiter, + &server_private_key_b64, + ); + if result.is_ok() { + let mut shared = shared_stats.write().await; + shared.insert(config.public_key.clone(), WgPeerStats::default()); + } + let _ = resp_tx.send(result); + } + Some(WgCommand::RemovePeer(pubkey, resp_tx)) => { + let prev_len = peers.len(); + peers.retain(|p| p.public_key_b64 != pubkey); + if peers.len() < prev_len { + let mut shared = shared_stats.write().await; + shared.remove(&pubkey); + let _ = resp_tx.send(Ok(())); + } else { + let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey))); + } + } + None => { + info!("Command channel closed"); + break; + } + } + } + + // --- Shutdown --- + _ = &mut shutdown_rx => { + info!("WireGuard server shutdown signal received"); + break; + } + } + } + + Ok(()) +} + +fn add_peer_to_loop( + peers: &mut Vec, + config: &WgPeerConfig, + peer_index: &AtomicU32, + rate_limiter: &Arc, + server_private_key_b64: &str, +) -> Result<()> { + // Check for duplicate + if peers.iter().any(|p| p.public_key_b64 == config.public_key) { + return Err(anyhow!("Peer already exists: {}", config.public_key)); + } + + let peer_public = parse_public_key(&config.public_key)?; + let psk = match &config.preshared_key { + Some(k) => Some(parse_preshared_key(k)?), + None => None, + }; + let idx = peer_index.fetch_add(1, Ordering::Relaxed); + let priv_copy = parse_private_key(server_private_key_b64)?; + + let tunn = Tunn::new( + priv_copy, + peer_public, + psk, + config.persistent_keepalive, + idx, + Some(rate_limiter.clone()), + ); + + let allowed_ips: Vec = config + .allowed_ips + .iter() + .map(|cidr| AllowedIp::parse(cidr)) + .collect::>>()?; + + let endpoint = match &config.endpoint { + Some(ep) => Some(ep.parse::()?), + None => None, + }; + + peers.push(PeerState { + tunn, + public_key_b64: config.public_key.clone(), + allowed_ips, + endpoint, + persistent_keepalive: config.persistent_keepalive, + stats: WgPeerStats::default(), + }); + + info!("Added WireGuard peer: {}", config.public_key); + Ok(()) +} + +// ============================================================================ +// WgClient +// ============================================================================ + +pub struct WgClient { + shutdown_tx: Option>, + shared_stats: Arc>, + state: Arc>, + assigned_ip: Option, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +struct WgClientState { + state: String, + #[serde(skip_serializing_if = "Option::is_none")] + assigned_ip: Option, + #[serde(skip_serializing_if = "Option::is_none")] + connected_since: Option, + #[serde(skip_serializing_if = "Option::is_none")] + last_error: Option, +} + +impl WgClient { + pub fn new() -> Self { + Self { + shutdown_tx: None, + shared_stats: Arc::new(RwLock::new(WgPeerStats::default())), + state: Arc::new(RwLock::new(WgClientState { + state: "disconnected".to_string(), + assigned_ip: None, + connected_since: None, + last_error: None, + })), + assigned_ip: None, + } + } + + pub fn is_running(&self) -> bool { + self.shutdown_tx.is_some() + } + + pub async fn connect(&mut self, config: WgClientConfig) -> Result { + if self.is_running() { + return Err(anyhow!("WireGuard client is already connected")); + } + + { + let mut state = self.state.write().await; + state.state = "connecting".to_string(); + } + + let mtu = config.mtu.unwrap_or(DEFAULT_MTU); + let _prefix = config.address_prefix.unwrap_or(24); + let address: Ipv4Addr = config.address.parse()?; + + // Parse keys + let client_private = parse_private_key(&config.private_key)?; + let peer_public = parse_public_key(&config.peer.public_key)?; + let psk = match &config.peer.preshared_key { + Some(k) => Some(parse_preshared_key(k)?), + None => None, + }; + + let tunn = Tunn::new( + client_private, + peer_public, + psk, + config.peer.persistent_keepalive, + 0, // single peer, index 0 + None, + ); + + // Parse server endpoint + let endpoint: SocketAddr = config + .peer + .endpoint + .as_ref() + .ok_or_else(|| anyhow!("Peer endpoint is required for client mode"))? + .parse()?; + + // Parse AllowedIPs + let allowed_ips: Vec = config + .peer + .allowed_ips + .iter() + .map(|cidr| AllowedIp::parse(cidr)) + .collect::>>()?; + + // Create TUN device + let tun_config = TunConfig { + name: "wg-client0".to_string(), + address, + netmask: Ipv4Addr::new(255, 255, 255, 0), + mtu, + }; + let tun_device = tunnel::create_tun(&tun_config)?; + info!("WireGuard client TUN device created: {}", tun_config.name); + + // Add routes for AllowedIPs + for cidr in &config.peer.allowed_ips { + if let Err(e) = tunnel::add_route(cidr, &tun_config.name).await { + warn!("Failed to add route for {}: {}", cidr, e); + } + } + + // Bind ephemeral UDP socket + let udp_socket = UdpSocket::bind("0.0.0.0:0").await?; + info!( + "WireGuard client bound to {}", + udp_socket.local_addr()? + ); + + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let shared_stats = self.shared_stats.clone(); + let state = self.state.clone(); + let assigned_ip = config.address.clone(); + + // Update state + { + let mut s = state.write().await; + s.state = "connected".to_string(); + s.assigned_ip = Some(assigned_ip.clone()); + s.connected_since = Some(chrono_now()); + } + + // Spawn client loop + tokio::spawn(async move { + if let Err(e) = wg_client_loop( + udp_socket, + tun_device, + tunn, + endpoint, + allowed_ips, + shared_stats, + state.clone(), + shutdown_rx, + ) + .await + { + error!("WireGuard client loop error: {}", e); + let mut s = state.write().await; + s.state = "error".to_string(); + s.last_error = Some(format!("{}", e)); + } + }); + + self.shutdown_tx = Some(shutdown_tx); + self.assigned_ip = Some(config.address.clone()); + + Ok(config.address) + } + + pub async fn disconnect(&mut self) -> Result<()> { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + { + let mut s = self.state.write().await; + s.state = "disconnected".to_string(); + s.assigned_ip = None; + s.connected_since = None; + } + self.assigned_ip = None; + info!("WireGuard client disconnected"); + Ok(()) + } + + pub async fn get_status(&self) -> serde_json::Value { + let s = self.state.read().await; + serde_json::to_value(&*s).unwrap_or_default() + } + + pub async fn get_statistics(&self) -> serde_json::Value { + let stats = self.shared_stats.read().await; + serde_json::to_value(&*stats).unwrap_or_default() + } +} + +// ============================================================================ +// Client event loop +// ============================================================================ + +async fn wg_client_loop( + udp_socket: UdpSocket, + tun_device: tun::AsyncDevice, + mut tunn: Tunn, + endpoint: SocketAddr, + _allowed_ips: Vec, + shared_stats: Arc>, + _state: Arc>, + mut shutdown_rx: oneshot::Receiver<()>, +) -> Result<()> { + let mut udp_buf = vec![0u8; MAX_UDP_PACKET]; + let mut tun_buf = vec![0u8; MAX_UDP_PACKET]; + let mut dst_buf = vec![0u8; WG_BUFFER_SIZE]; + let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS)); + let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1)); + + let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device); + + // Local stats (synced periodically) + let mut local_stats = WgPeerStats::default(); + + // Initiate handshake + match tunn.encapsulate(&[], &mut dst_buf) { + TunnResult::WriteToNetwork(packet) => { + udp_socket.send_to(packet, endpoint).await?; + debug!("Sent WireGuard handshake initiation"); + } + _ => {} + } + + loop { + tokio::select! { + // --- UDP receive --- + result = udp_socket.recv_from(&mut udp_buf) => { + let (n, src_addr) = result?; + if n == 0 { continue; } + + match tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) { + TunnResult::WriteToNetwork(packet) => { + udp_socket.send_to(packet, endpoint).await?; + // Drain loop + loop { + match tunn.decapsulate(None, &[], &mut dst_buf) { + TunnResult::WriteToNetwork(pkt) => { + udp_socket.send_to(pkt, endpoint).await?; + } + _ => break, + } + } + } + TunnResult::WriteToTunnelV4(packet, _addr) => { + let pkt_len = packet.len() as u64; + tun_writer.write_all(packet).await?; + local_stats.bytes_received += pkt_len; + local_stats.packets_received += 1; + } + TunnResult::WriteToTunnelV6(packet, _addr) => { + let pkt_len = packet.len() as u64; + tun_writer.write_all(packet).await?; + local_stats.bytes_received += pkt_len; + local_stats.packets_received += 1; + } + TunnResult::Done => {} + TunnResult::Err(e) => { + debug!("Client decapsulate error: {:?}", e); + } + } + } + + // --- TUN read --- + result = tun_reader.read(&mut tun_buf) => { + let n = result?; + if n == 0 { continue; } + + match tunn.encapsulate(&tun_buf[..n], &mut dst_buf) { + TunnResult::WriteToNetwork(packet) => { + let pkt_len = n as u64; + udp_socket.send_to(packet, endpoint).await?; + local_stats.bytes_sent += pkt_len; + local_stats.packets_sent += 1; + } + TunnResult::Err(e) => { + debug!("Client encapsulate error: {:?}", e); + } + _ => {} + } + } + + // --- Timer tick --- + _ = timer.tick() => { + match tunn.update_timers(&mut dst_buf) { + TunnResult::WriteToNetwork(packet) => { + udp_socket.send_to(packet, endpoint).await?; + } + TunnResult::Err(e) => { + debug!("Client timer error: {:?}", e); + } + _ => {} + } + } + + // --- Sync stats --- + _ = stats_timer.tick() => { + let mut shared = shared_stats.write().await; + *shared = local_stats.clone(); + } + + // --- Shutdown --- + _ = &mut shutdown_rx => { + info!("WireGuard client shutdown signal received"); + break; + } + } + } + + Ok(()) +} + +// ============================================================================ +// Helpers +// ============================================================================ + +fn chrono_now() -> String { + // Simple ISO-8601 timestamp without chrono dependency + let dur = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default(); + format!("{}s since epoch", dur.as_secs()) +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_wg_keypair() { + let (pub_key, priv_key) = generate_wg_keypair(); + // Base64 of 32 bytes = 44 chars (with padding) + assert_eq!(pub_key.len(), 44); + assert_eq!(priv_key.len(), 44); + + // Decode and verify 32 bytes + let pub_bytes = BASE64.decode(&pub_key).unwrap(); + let priv_bytes = BASE64.decode(&priv_key).unwrap(); + assert_eq!(pub_bytes.len(), 32); + assert_eq!(priv_bytes.len(), 32); + } + + #[test] + fn test_key_roundtrip() { + let (pub_b64, priv_b64) = generate_wg_keypair(); + + // Parse back + let secret = parse_private_key(&priv_b64).unwrap(); + let public = parse_public_key(&pub_b64).unwrap(); + + // Derive public from private and verify match + let derived_public = PublicKey::from(&secret); + assert_eq!(public.to_bytes(), derived_public.to_bytes()); + } + + #[test] + fn test_parse_invalid_key() { + assert!(parse_private_key("not-valid-base64!!!").is_err()); + assert!(parse_private_key("AAAA").is_err()); // too short (3 bytes) + assert!(parse_public_key("AAAA").is_err()); + } + + #[test] + fn test_allowed_ip_v4_match() { + let aip = AllowedIp::parse("10.0.0.0/24").unwrap(); + assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))); + assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 254)))); + assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 1, 1)))); + assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)))); + } + + #[test] + fn test_allowed_ip_v4_catch_all() { + let aip = AllowedIp::parse("0.0.0.0/0").unwrap(); + assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)))); + assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)))); + } + + #[test] + fn test_allowed_ip_v4_host() { + let aip = AllowedIp::parse("10.0.0.5/32").unwrap(); + assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))); + assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 6)))); + } + + #[test] + fn test_allowed_ip_v6_match() { + let aip = AllowedIp::parse("fd00::/64").unwrap(); + assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1)))); + assert!(!aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd01, 0, 0, 0, 0, 0, 0, 1)))); + } + + #[test] + fn test_allowed_ip_v6_catch_all() { + let aip = AllowedIp::parse("::/0").unwrap(); + assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)))); + } + + #[test] + fn test_allowed_ip_cross_family_no_match() { + let v4 = AllowedIp::parse("10.0.0.0/8").unwrap(); + assert!(!v4.matches(IpAddr::V6(Ipv6Addr::LOCALHOST))); + + let v6 = AllowedIp::parse("::/0").unwrap(); + assert!(!v6.matches(IpAddr::V4(Ipv4Addr::LOCALHOST))); + } + + #[test] + fn test_extract_dst_ip_v4() { + // Minimal IPv4 header: version=4, IHL=5, total_length=20, dst at bytes 16-19 + let mut pkt = [0u8; 20]; + pkt[0] = 0x45; // version 4, IHL 5 + pkt[16] = 10; + pkt[17] = 0; + pkt[18] = 0; + pkt[19] = 1; + assert_eq!( + extract_dst_ip(&pkt), + Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))) + ); + } + + #[test] + fn test_extract_dst_ip_v6() { + // Minimal IPv6 header: version=6, dst at bytes 24-39 + let mut pkt = [0u8; 40]; + pkt[0] = 0x60; // version 6 + pkt[24] = 0xfd; + pkt[39] = 0x01; + let expected = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1)); + assert_eq!(extract_dst_ip(&pkt), Some(expected)); + } + + #[test] + fn test_extract_dst_ip_empty() { + assert_eq!(extract_dst_ip(&[]), None); + } + + #[test] + fn test_loopback_tunnel() { + // Two Tunn instances: server and client, exchanging packets in memory + let (server_pub, server_priv) = generate_wg_keypair(); + let (client_pub, client_priv) = generate_wg_keypair(); + + let server_secret = parse_private_key(&server_priv).unwrap(); + let client_secret = parse_private_key(&client_priv).unwrap(); + let server_public = parse_public_key(&server_pub).unwrap(); + let client_public = parse_public_key(&client_pub).unwrap(); + + let mut server_tunn = Tunn::new( + server_secret, + client_public, + None, + None, + 0, + None, + ); + let mut client_tunn = Tunn::new( + client_secret, + server_public, + None, + None, + 1, + None, + ); + + let mut buf_a = vec![0u8; 2048]; + let mut buf_b = vec![0u8; 2048]; + + // Client initiates handshake + let handshake_init = match client_tunn.encapsulate(&[], &mut buf_a) { + TunnResult::WriteToNetwork(pkt) => pkt.to_vec(), + other => panic!("Expected WriteToNetwork for handshake init, got {:?}", format!("{:?}", std::mem::discriminant(&other))), + }; + + // Server processes handshake init + let handshake_resp = match server_tunn.decapsulate(None, &handshake_init, &mut buf_b) { + TunnResult::WriteToNetwork(pkt) => pkt.to_vec(), + other => panic!("Expected WriteToNetwork for handshake resp, got {:?}", format!("{:?}", std::mem::discriminant(&other))), + }; + + // Drain server + loop { + match server_tunn.decapsulate(None, &[], &mut buf_b) { + TunnResult::WriteToNetwork(_) => {} + _ => break, + } + } + + // Client processes handshake response + match client_tunn.decapsulate(None, &handshake_resp, &mut buf_a) { + TunnResult::WriteToNetwork(pkt) => { + // Client might send a keepalive or transport data + // Feed it to server + let pkt_copy = pkt.to_vec(); + let _ = server_tunn.decapsulate(None, &pkt_copy, &mut buf_b); + } + TunnResult::Done => {} + other => { + // Drain + loop { + match client_tunn.decapsulate(None, &[], &mut buf_a) { + TunnResult::WriteToNetwork(_) => {} + _ => break, + } + } + } + } + + // Drain client + loop { + match client_tunn.decapsulate(None, &[], &mut buf_a) { + TunnResult::WriteToNetwork(_) => {} + _ => break, + } + } + + // Now try to send a fake IP packet from client to server + let mut fake_ip = [0u8; 28]; + fake_ip[0] = 0x45; // IPv4 + fake_ip[2] = 0; + fake_ip[3] = 28; // total length + // Source IP (bytes 12-15): 10.0.0.2 (client) + fake_ip[12] = 10; + fake_ip[13] = 0; + fake_ip[14] = 0; + fake_ip[15] = 2; + // Destination IP (bytes 16-19): 10.0.0.1 (server) + fake_ip[16] = 10; + fake_ip[17] = 0; + fake_ip[18] = 0; + fake_ip[19] = 1; + + match client_tunn.encapsulate(&fake_ip, &mut buf_a) { + TunnResult::WriteToNetwork(encrypted) => { + let encrypted_copy = encrypted.to_vec(); + // Server decapsulates + match server_tunn.decapsulate(None, &encrypted_copy, &mut buf_b) { + TunnResult::WriteToTunnelV4(decrypted, src_addr) => { + // src_addr is the source IP from the inner packet (for AllowedIPs check) + assert_eq!(src_addr, Ipv4Addr::new(10, 0, 0, 2)); + assert_eq!(&decrypted[..fake_ip.len()], &fake_ip); + } + TunnResult::WriteToNetwork(_pkt) => { + // Might need another round trip, that's OK + } + _ => { + // Session might not be fully established yet, acceptable + } + } + } + TunnResult::Err(_) => { + // Session not yet established, acceptable in unit test + } + _ => {} + } + } +} diff --git a/test/test.wireguard.node.ts b/test/test.wireguard.node.ts new file mode 100644 index 0000000..5b146d4 --- /dev/null +++ b/test/test.wireguard.node.ts @@ -0,0 +1,353 @@ +import { tap, expect } from '@git.zone/tstest/tapbundle'; +import { + VpnConfig, + VpnServer, + WgConfigGenerator, +} from '../ts/index.js'; +import type { + IVpnClientConfig, + IVpnServerConfig, + IVpnServerOptions, + IWgPeerConfig, +} from '../ts/index.js'; + +// ============================================================================ +// WireGuard config validation — client +// ============================================================================ + +// A valid 32-byte key in base64 (44 chars) +const VALID_KEY = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='; +const VALID_KEY_2 = 'BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB='; + +tap.test('WG client config: valid wireguard config passes validation', async () => { + const config: IVpnClientConfig = { + serverUrl: '', // not needed for WG + serverPublicKey: VALID_KEY, + transport: 'wireguard', + wgPrivateKey: VALID_KEY_2, + wgAddress: '10.8.0.2', + wgEndpoint: 'vpn.example.com:51820', + wgAllowedIps: ['0.0.0.0/0'], + }; + VpnConfig.validateClientConfig(config); +}); + +tap.test('WG client config: rejects missing wgPrivateKey', async () => { + const config: IVpnClientConfig = { + serverUrl: '', + serverPublicKey: VALID_KEY, + transport: 'wireguard', + wgAddress: '10.8.0.2', + wgEndpoint: 'vpn.example.com:51820', + }; + let threw = false; + try { + VpnConfig.validateClientConfig(config); + } catch (e) { + threw = true; + expect((e as Error).message).toContain('wgPrivateKey'); + } + expect(threw).toBeTrue(); +}); + +tap.test('WG client config: rejects missing wgAddress', async () => { + const config: IVpnClientConfig = { + serverUrl: '', + serverPublicKey: VALID_KEY, + transport: 'wireguard', + wgPrivateKey: VALID_KEY_2, + wgEndpoint: 'vpn.example.com:51820', + }; + let threw = false; + try { + VpnConfig.validateClientConfig(config); + } catch (e) { + threw = true; + expect((e as Error).message).toContain('wgAddress'); + } + expect(threw).toBeTrue(); +}); + +tap.test('WG client config: rejects missing wgEndpoint', async () => { + const config: IVpnClientConfig = { + serverUrl: '', + serverPublicKey: VALID_KEY, + transport: 'wireguard', + wgPrivateKey: VALID_KEY_2, + wgAddress: '10.8.0.2', + }; + let threw = false; + try { + VpnConfig.validateClientConfig(config); + } catch (e) { + threw = true; + expect((e as Error).message).toContain('wgEndpoint'); + } + expect(threw).toBeTrue(); +}); + +tap.test('WG client config: rejects invalid key length', async () => { + const config: IVpnClientConfig = { + serverUrl: '', + serverPublicKey: VALID_KEY, + transport: 'wireguard', + wgPrivateKey: 'tooshort', + wgAddress: '10.8.0.2', + wgEndpoint: 'vpn.example.com:51820', + }; + let threw = false; + try { + VpnConfig.validateClientConfig(config); + } catch (e) { + threw = true; + expect((e as Error).message).toContain('44 characters'); + } + expect(threw).toBeTrue(); +}); + +tap.test('WG client config: rejects invalid CIDR in allowedIps', async () => { + const config: IVpnClientConfig = { + serverUrl: '', + serverPublicKey: VALID_KEY, + transport: 'wireguard', + wgPrivateKey: VALID_KEY_2, + wgAddress: '10.8.0.2', + wgEndpoint: 'vpn.example.com:51820', + wgAllowedIps: ['not-a-cidr'], + }; + let threw = false; + try { + VpnConfig.validateClientConfig(config); + } catch (e) { + threw = true; + expect((e as Error).message).toContain('CIDR'); + } + expect(threw).toBeTrue(); +}); + +// ============================================================================ +// WireGuard config validation — server +// ============================================================================ + +tap.test('WG server config: valid config passes validation', async () => { + const config: IVpnServerConfig = { + listenAddr: '', + privateKey: VALID_KEY, + publicKey: VALID_KEY_2, + subnet: '10.8.0.0/24', + transportMode: 'wireguard', + wgPeers: [ + { + publicKey: VALID_KEY_2, + allowedIps: ['10.8.0.2/32'], + }, + ], + }; + VpnConfig.validateServerConfig(config); +}); + +tap.test('WG server config: rejects empty wgPeers', async () => { + const config: IVpnServerConfig = { + listenAddr: '', + privateKey: VALID_KEY, + publicKey: VALID_KEY_2, + subnet: '10.8.0.0/24', + transportMode: 'wireguard', + wgPeers: [], + }; + let threw = false; + try { + VpnConfig.validateServerConfig(config); + } catch (e) { + threw = true; + expect((e as Error).message).toContain('wgPeers'); + } + expect(threw).toBeTrue(); +}); + +tap.test('WG server config: rejects peer without publicKey', async () => { + const config: IVpnServerConfig = { + listenAddr: '', + privateKey: VALID_KEY, + publicKey: VALID_KEY_2, + subnet: '10.8.0.0/24', + transportMode: 'wireguard', + wgPeers: [ + { + publicKey: '', + allowedIps: ['10.8.0.2/32'], + }, + ], + }; + let threw = false; + try { + VpnConfig.validateServerConfig(config); + } catch (e) { + threw = true; + expect((e as Error).message).toContain('publicKey'); + } + expect(threw).toBeTrue(); +}); + +tap.test('WG server config: rejects invalid wgListenPort', async () => { + const config: IVpnServerConfig = { + listenAddr: '', + privateKey: VALID_KEY, + publicKey: VALID_KEY_2, + subnet: '10.8.0.0/24', + transportMode: 'wireguard', + wgListenPort: 0, + wgPeers: [ + { + publicKey: VALID_KEY_2, + allowedIps: ['10.8.0.2/32'], + }, + ], + }; + let threw = false; + try { + VpnConfig.validateServerConfig(config); + } catch (e) { + threw = true; + expect((e as Error).message).toContain('wgListenPort'); + } + expect(threw).toBeTrue(); +}); + +// ============================================================================ +// WireGuard keypair generation via daemon +// ============================================================================ + +let server: VpnServer; + +tap.test('WG: spawn server daemon for keypair generation', async () => { + const options: IVpnServerOptions = { + transport: { transport: 'stdio' }, + }; + server = new VpnServer(options); + const started = await server['bridge'].start(); + expect(started).toBeTrue(); + expect(server.running).toBeTrue(); +}); + +tap.test('WG: generateWgKeypair returns valid keypair', async () => { + const keypair = await server.generateWgKeypair(); + expect(keypair.publicKey).toBeTypeofString(); + expect(keypair.privateKey).toBeTypeofString(); + // WireGuard keys: base64 of 32 bytes = 44 characters + expect(keypair.publicKey.length).toEqual(44); + expect(keypair.privateKey.length).toEqual(44); + // Verify they decode to 32 bytes + const pubBuf = Buffer.from(keypair.publicKey, 'base64'); + const privBuf = Buffer.from(keypair.privateKey, 'base64'); + expect(pubBuf.length).toEqual(32); + expect(privBuf.length).toEqual(32); +}); + +tap.test('WG: generateWgKeypair returns unique keys each time', async () => { + const kp1 = await server.generateWgKeypair(); + const kp2 = await server.generateWgKeypair(); + expect(kp1.publicKey).not.toEqual(kp2.publicKey); + expect(kp1.privateKey).not.toEqual(kp2.privateKey); +}); + +tap.test('WG: stop server daemon', async () => { + server.stop(); + await new Promise((resolve) => setTimeout(resolve, 500)); + expect(server.running).toBeFalse(); +}); + +// ============================================================================ +// WireGuard config file generation +// ============================================================================ + +tap.test('WgConfigGenerator: generate client config', async () => { + const conf = WgConfigGenerator.generateClientConfig({ + privateKey: 'clientPrivateKeyBase64====================', + address: '10.8.0.2/24', + dns: ['1.1.1.1', '8.8.8.8'], + mtu: 1420, + peer: { + publicKey: 'serverPublicKeyBase64====================', + endpoint: 'vpn.example.com:51820', + allowedIps: ['0.0.0.0/0', '::/0'], + persistentKeepalive: 25, + }, + }); + expect(conf).toContain('[Interface]'); + expect(conf).toContain('PrivateKey = clientPrivateKeyBase64===================='); + expect(conf).toContain('Address = 10.8.0.2/24'); + expect(conf).toContain('DNS = 1.1.1.1, 8.8.8.8'); + expect(conf).toContain('MTU = 1420'); + expect(conf).toContain('[Peer]'); + expect(conf).toContain('PublicKey = serverPublicKeyBase64===================='); + expect(conf).toContain('Endpoint = vpn.example.com:51820'); + expect(conf).toContain('AllowedIPs = 0.0.0.0/0, ::/0'); + expect(conf).toContain('PersistentKeepalive = 25'); +}); + +tap.test('WgConfigGenerator: generate client config without optional fields', async () => { + const conf = WgConfigGenerator.generateClientConfig({ + privateKey: 'key1', + address: '10.0.0.2/32', + peer: { + publicKey: 'key2', + endpoint: 'server:51820', + allowedIps: ['10.0.0.0/24'], + }, + }); + expect(conf).toContain('[Interface]'); + expect(conf).not.toContain('DNS'); + expect(conf).not.toContain('MTU'); + expect(conf).not.toContain('PresharedKey'); + expect(conf).not.toContain('PersistentKeepalive'); +}); + +tap.test('WgConfigGenerator: generate server config with NAT', async () => { + const conf = WgConfigGenerator.generateServerConfig({ + privateKey: 'serverPrivKey', + address: '10.8.0.1/24', + listenPort: 51820, + dns: ['1.1.1.1'], + enableNat: true, + natInterface: 'ens3', + peers: [ + { + publicKey: 'peer1PubKey', + allowedIps: ['10.8.0.2/32'], + presharedKey: 'psk1', + persistentKeepalive: 25, + }, + { + publicKey: 'peer2PubKey', + allowedIps: ['10.8.0.3/32'], + }, + ], + }); + expect(conf).toContain('[Interface]'); + expect(conf).toContain('ListenPort = 51820'); + expect(conf).toContain('PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ens3 -j MASQUERADE'); + expect(conf).toContain('PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ens3 -j MASQUERADE'); + // Two [Peer] sections + const peerCount = (conf.match(/\[Peer\]/g) || []).length; + expect(peerCount).toEqual(2); + expect(conf).toContain('PresharedKey = psk1'); +}); + +tap.test('WgConfigGenerator: generate server config without NAT', async () => { + const conf = WgConfigGenerator.generateServerConfig({ + privateKey: 'serverPrivKey', + address: '10.8.0.1/24', + listenPort: 51820, + peers: [ + { + publicKey: 'peerKey', + allowedIps: ['10.8.0.2/32'], + }, + ], + }); + expect(conf).not.toContain('PostUp'); + expect(conf).not.toContain('PostDown'); +}); + +export default tap.start(); diff --git a/ts/00_commitinfo_data.ts b/ts/00_commitinfo_data.ts index dca69ef..1dc0fee 100644 --- a/ts/00_commitinfo_data.ts +++ b/ts/00_commitinfo_data.ts @@ -3,6 +3,6 @@ */ export const commitinfo = { name: '@push.rocks/smartvpn', - version: '1.4.1', + version: '1.5.0', description: 'A VPN solution with TypeScript control plane and Rust data plane daemon' } diff --git a/ts/index.ts b/ts/index.ts index 9f5ad0c..90224c5 100644 --- a/ts/index.ts +++ b/ts/index.ts @@ -4,3 +4,4 @@ export { VpnClient } from './smartvpn.classes.vpnclient.js'; export { VpnServer } from './smartvpn.classes.vpnserver.js'; export { VpnConfig } from './smartvpn.classes.vpnconfig.js'; export { VpnInstaller } from './smartvpn.classes.vpninstaller.js'; +export { WgConfigGenerator } from './smartvpn.classes.wgconfig.js'; diff --git a/ts/smartvpn.classes.vpnconfig.ts b/ts/smartvpn.classes.vpnconfig.ts index 6740ef3..03936b8 100644 --- a/ts/smartvpn.classes.vpnconfig.ts +++ b/ts/smartvpn.classes.vpnconfig.ts @@ -12,17 +12,45 @@ export class VpnConfig { * Validate a client config object. Throws on invalid config. */ public static validateClientConfig(config: IVpnClientConfig): void { - if (!config.serverUrl) { - throw new Error('VpnConfig: serverUrl is required'); - } - // For QUIC-only transport, serverUrl is a host:port address; for WebSocket/auto it must be ws:// or wss:// - if (config.transport !== 'quic') { - if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) { - throw new Error('VpnConfig: serverUrl must start with wss:// or ws:// (for WebSocket transport)'); + if (config.transport === 'wireguard') { + // WireGuard-specific validation + if (!config.wgPrivateKey) { + throw new Error('VpnConfig: wgPrivateKey is required for WireGuard transport'); + } + VpnConfig.validateBase64Key(config.wgPrivateKey, 'wgPrivateKey'); + if (!config.wgAddress) { + throw new Error('VpnConfig: wgAddress is required for WireGuard transport'); + } + if (!config.serverPublicKey) { + throw new Error('VpnConfig: serverPublicKey is required for WireGuard transport'); + } + VpnConfig.validateBase64Key(config.serverPublicKey, 'serverPublicKey'); + if (!config.wgEndpoint) { + throw new Error('VpnConfig: wgEndpoint is required for WireGuard transport'); + } + if (config.wgPresharedKey) { + VpnConfig.validateBase64Key(config.wgPresharedKey, 'wgPresharedKey'); + } + if (config.wgAllowedIps) { + for (const cidr of config.wgAllowedIps) { + if (!VpnConfig.isValidCidr(cidr)) { + throw new Error(`VpnConfig: invalid allowedIp CIDR: ${cidr}`); + } + } + } + } else { + if (!config.serverUrl) { + throw new Error('VpnConfig: serverUrl is required'); + } + // For QUIC-only transport, serverUrl is a host:port address; for WebSocket/auto it must be ws:// or wss:// + if (config.transport !== 'quic') { + if (!config.serverUrl.startsWith('wss://') && !config.serverUrl.startsWith('ws://')) { + throw new Error('VpnConfig: serverUrl must start with wss:// or ws:// (for WebSocket transport)'); + } + } + if (!config.serverPublicKey) { + throw new Error('VpnConfig: serverPublicKey is required'); } - } - if (!config.serverPublicKey) { - throw new Error('VpnConfig: serverPublicKey is required'); } if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) { throw new Error('VpnConfig: mtu must be between 576 and 65535'); @@ -43,20 +71,51 @@ export class VpnConfig { * Validate a server config object. Throws on invalid config. */ public static validateServerConfig(config: IVpnServerConfig): void { - if (!config.listenAddr) { - throw new Error('VpnConfig: listenAddr is required'); - } - if (!config.privateKey) { - throw new Error('VpnConfig: privateKey is required'); - } - if (!config.publicKey) { - throw new Error('VpnConfig: publicKey is required'); - } - if (!config.subnet) { - throw new Error('VpnConfig: subnet is required'); - } - if (!VpnConfig.isValidSubnet(config.subnet)) { - throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`); + if (config.transportMode === 'wireguard') { + // WireGuard server validation + if (!config.privateKey) { + throw new Error('VpnConfig: privateKey is required'); + } + VpnConfig.validateBase64Key(config.privateKey, 'privateKey'); + if (!config.wgPeers || config.wgPeers.length === 0) { + throw new Error('VpnConfig: at least one wgPeers entry is required for WireGuard mode'); + } + for (const peer of config.wgPeers) { + if (!peer.publicKey) { + throw new Error('VpnConfig: peer publicKey is required'); + } + VpnConfig.validateBase64Key(peer.publicKey, 'peer.publicKey'); + if (!peer.allowedIps || peer.allowedIps.length === 0) { + throw new Error('VpnConfig: peer allowedIps is required'); + } + for (const cidr of peer.allowedIps) { + if (!VpnConfig.isValidCidr(cidr)) { + throw new Error(`VpnConfig: invalid peer allowedIp CIDR: ${cidr}`); + } + } + if (peer.presharedKey) { + VpnConfig.validateBase64Key(peer.presharedKey, 'peer.presharedKey'); + } + } + if (config.wgListenPort !== undefined && (config.wgListenPort < 1 || config.wgListenPort > 65535)) { + throw new Error('VpnConfig: wgListenPort must be between 1 and 65535'); + } + } else { + if (!config.listenAddr) { + throw new Error('VpnConfig: listenAddr is required'); + } + if (!config.privateKey) { + throw new Error('VpnConfig: privateKey is required'); + } + if (!config.publicKey) { + throw new Error('VpnConfig: publicKey is required'); + } + if (!config.subnet) { + throw new Error('VpnConfig: subnet is required'); + } + if (!VpnConfig.isValidSubnet(config.subnet)) { + throw new Error(`VpnConfig: invalid subnet: ${config.subnet}`); + } } if (config.mtu !== undefined && (config.mtu < 576 || config.mtu > 65535)) { throw new Error('VpnConfig: mtu must be between 576 and 65535'); @@ -104,4 +163,41 @@ export class VpnConfig { const prefixNum = parseInt(prefix, 10); return !isNaN(prefixNum) && prefixNum >= 0 && prefixNum <= 32; } + + /** + * Validate a CIDR string (IPv4 or IPv6). + */ + private static isValidCidr(cidr: string): boolean { + const parts = cidr.split('/'); + if (parts.length !== 2) return false; + const prefixNum = parseInt(parts[1], 10); + if (isNaN(prefixNum) || prefixNum < 0) return false; + // IPv4 + if (VpnConfig.isValidIp(parts[0])) { + return prefixNum <= 32; + } + // IPv6 (basic check) + if (parts[0].includes(':')) { + return prefixNum <= 128; + } + return false; + } + + /** + * Validate a base64-encoded 32-byte key (WireGuard X25519 format). + */ + private static validateBase64Key(key: string, fieldName: string): void { + if (key.length !== 44) { + throw new Error(`VpnConfig: ${fieldName} must be 44 characters (base64 of 32 bytes), got ${key.length}`); + } + try { + const buf = Buffer.from(key, 'base64'); + if (buf.length !== 32) { + throw new Error(`VpnConfig: ${fieldName} must decode to 32 bytes, got ${buf.length}`); + } + } catch (e) { + if (e instanceof Error && e.message.startsWith('VpnConfig:')) throw e; + throw new Error(`VpnConfig: ${fieldName} is not valid base64`); + } + } } diff --git a/ts/smartvpn.classes.vpnserver.ts b/ts/smartvpn.classes.vpnserver.ts index 7e36ca2..abcb676 100644 --- a/ts/smartvpn.classes.vpnserver.ts +++ b/ts/smartvpn.classes.vpnserver.ts @@ -8,6 +8,8 @@ import type { IVpnClientInfo, IVpnKeypair, IVpnClientTelemetry, + IWgPeerConfig, + IWgPeerInfo, TVpnServerCommands, } from './smartvpn.interfaces.js'; @@ -121,6 +123,35 @@ export class VpnServer extends plugins.events.EventEmitter { return this.bridge.sendCommand('getClientTelemetry', { clientId }); } + /** + * Generate a WireGuard-compatible X25519 keypair. + */ + public async generateWgKeypair(): Promise { + return this.bridge.sendCommand('generateWgKeypair', {} as Record); + } + + /** + * Add a WireGuard peer (server must be running in wireguard mode). + */ + public async addWgPeer(peer: IWgPeerConfig): Promise { + await this.bridge.sendCommand('addWgPeer', { peer }); + } + + /** + * Remove a WireGuard peer by public key. + */ + public async removeWgPeer(publicKey: string): Promise { + await this.bridge.sendCommand('removeWgPeer', { publicKey }); + } + + /** + * List WireGuard peers with stats. + */ + public async listWgPeers(): Promise { + const result = await this.bridge.sendCommand('listWgPeers', {} as Record); + return result.peers; + } + /** * Stop the daemon bridge. */ diff --git a/ts/smartvpn.classes.wgconfig.ts b/ts/smartvpn.classes.wgconfig.ts new file mode 100644 index 0000000..b4abf2f --- /dev/null +++ b/ts/smartvpn.classes.wgconfig.ts @@ -0,0 +1,123 @@ +import type { IWgPeerConfig } from './smartvpn.interfaces.js'; + +// ============================================================================ +// WireGuard .conf file generator +// ============================================================================ + +export interface IWgClientConfOptions { + /** Client private key (base64) */ + privateKey: string; + /** Client TUN address with prefix (e.g. 10.8.0.2/24) */ + address: string; + /** DNS servers */ + dns?: string[]; + /** TUN MTU */ + mtu?: number; + /** Server peer config */ + peer: { + publicKey: string; + presharedKey?: string; + endpoint: string; + allowedIps: string[]; + persistentKeepalive?: number; + }; +} + +export interface IWgServerConfOptions { + /** Server private key (base64) */ + privateKey: string; + /** Server TUN address with prefix (e.g. 10.8.0.1/24) */ + address: string; + /** UDP listen port */ + listenPort: number; + /** DNS servers */ + dns?: string[]; + /** TUN MTU */ + mtu?: number; + /** Enable NAT — adds PostUp/PostDown iptables rules */ + enableNat?: boolean; + /** Network interface for NAT (e.g. eth0). Auto-detected if omitted. */ + natInterface?: string; + /** Configured peers */ + peers: IWgPeerConfig[]; +} + +/** + * Generates standard WireGuard .conf files compatible with wg-quick, + * WireGuard iOS/Android apps, and other standard WireGuard clients. + */ +export class WgConfigGenerator { + /** + * Generate a client .conf file content. + */ + public static generateClientConfig(opts: IWgClientConfOptions): string { + const lines: string[] = []; + + lines.push('[Interface]'); + lines.push(`PrivateKey = ${opts.privateKey}`); + lines.push(`Address = ${opts.address}`); + if (opts.dns && opts.dns.length > 0) { + lines.push(`DNS = ${opts.dns.join(', ')}`); + } + if (opts.mtu) { + lines.push(`MTU = ${opts.mtu}`); + } + + lines.push(''); + lines.push('[Peer]'); + lines.push(`PublicKey = ${opts.peer.publicKey}`); + if (opts.peer.presharedKey) { + lines.push(`PresharedKey = ${opts.peer.presharedKey}`); + } + lines.push(`Endpoint = ${opts.peer.endpoint}`); + lines.push(`AllowedIPs = ${opts.peer.allowedIps.join(', ')}`); + if (opts.peer.persistentKeepalive) { + lines.push(`PersistentKeepalive = ${opts.peer.persistentKeepalive}`); + } + + lines.push(''); + return lines.join('\n'); + } + + /** + * Generate a server .conf file content. + */ + public static generateServerConfig(opts: IWgServerConfOptions): string { + const lines: string[] = []; + + lines.push('[Interface]'); + lines.push(`PrivateKey = ${opts.privateKey}`); + lines.push(`Address = ${opts.address}`); + lines.push(`ListenPort = ${opts.listenPort}`); + if (opts.dns && opts.dns.length > 0) { + lines.push(`DNS = ${opts.dns.join(', ')}`); + } + if (opts.mtu) { + lines.push(`MTU = ${opts.mtu}`); + } + if (opts.enableNat) { + const iface = opts.natInterface || 'eth0'; + lines.push(`PostUp = iptables -A FORWARD -i %i -j ACCEPT; iptables -t nat -A POSTROUTING -o ${iface} -j MASQUERADE`); + lines.push(`PostDown = iptables -D FORWARD -i %i -j ACCEPT; iptables -t nat -D POSTROUTING -o ${iface} -j MASQUERADE`); + } + + for (const peer of opts.peers) { + lines.push(''); + lines.push('[Peer]'); + lines.push(`PublicKey = ${peer.publicKey}`); + if (peer.presharedKey) { + lines.push(`PresharedKey = ${peer.presharedKey}`); + } + lines.push(`AllowedIPs = ${peer.allowedIps.join(', ')}`); + if (peer.endpoint) { + lines.push(`Endpoint = ${peer.endpoint}`); + } + if (peer.persistentKeepalive) { + lines.push(`PersistentKeepalive = ${peer.persistentKeepalive}`); + } + } + + lines.push(''); + return lines.join('\n'); + } +} diff --git a/ts/smartvpn.interfaces.ts b/ts/smartvpn.interfaces.ts index 0b58496..471fc8a 100644 --- a/ts/smartvpn.interfaces.ts +++ b/ts/smartvpn.interfaces.ts @@ -32,10 +32,24 @@ export interface IVpnClientConfig { mtu?: number; /** Keepalive interval in seconds (default: 30) */ keepaliveIntervalSecs?: number; - /** Transport protocol: 'auto' (default, tries QUIC then WS), 'websocket', or 'quic' */ - transport?: 'auto' | 'websocket' | 'quic'; + /** Transport protocol: 'auto' (default, tries QUIC then WS), 'websocket', 'quic', or 'wireguard' */ + transport?: 'auto' | 'websocket' | 'quic' | 'wireguard'; /** For QUIC: SHA-256 hash of server certificate (base64) for cert pinning */ serverCertHash?: string; + /** WireGuard: client private key (base64, X25519) */ + wgPrivateKey?: string; + /** WireGuard: client TUN address (e.g. 10.8.0.2) */ + wgAddress?: string; + /** WireGuard: client TUN address prefix length (default: 24) */ + wgAddressPrefix?: number; + /** WireGuard: preshared key (base64, optional) */ + wgPresharedKey?: string; + /** WireGuard: persistent keepalive interval in seconds */ + wgPersistentKeepalive?: number; + /** WireGuard: server endpoint (host:port, e.g. vpn.example.com:51820) */ + wgEndpoint?: string; + /** WireGuard: allowed IPs (CIDR strings, e.g. ['0.0.0.0/0']) */ + wgAllowedIps?: string[]; } export interface IVpnClientOptions { @@ -72,12 +86,16 @@ export interface IVpnServerConfig { defaultRateLimitBytesPerSec?: number; /** Default burst size for new clients (bytes). Omit for unlimited. */ defaultBurstBytes?: number; - /** Transport mode: 'both' (default, WS+QUIC), 'websocket', or 'quic' */ - transportMode?: 'websocket' | 'quic' | 'both'; + /** Transport mode: 'both' (default, WS+QUIC), 'websocket', 'quic', or 'wireguard' */ + transportMode?: 'websocket' | 'quic' | 'both' | 'wireguard'; /** QUIC listen address (host:port). Defaults to listenAddr. */ quicListenAddr?: string; /** QUIC idle timeout in seconds (default: 30) */ quicIdleTimeoutSecs?: number; + /** WireGuard: UDP listen port (default: 51820) */ + wgListenPort?: number; + /** WireGuard: configured peers */ + wgPeers?: IWgPeerConfig[]; } export interface IVpnServerOptions { @@ -187,6 +205,35 @@ export interface IVpnClientTelemetry { burstBytes?: number; } +// ============================================================================ +// WireGuard-specific types +// ============================================================================ + +export interface IWgPeerConfig { + /** Peer's public key (base64, X25519) */ + publicKey: string; + /** Optional preshared key (base64) */ + presharedKey?: string; + /** Allowed IP ranges (CIDR strings) */ + allowedIps: string[]; + /** Peer endpoint (host:port) — optional for server peers, required for client */ + endpoint?: string; + /** Persistent keepalive interval in seconds */ + persistentKeepalive?: number; +} + +export interface IWgPeerInfo { + publicKey: string; + allowedIps: string[]; + endpoint?: string; + persistentKeepalive?: number; + bytesSent: number; + bytesReceived: number; + packetsSent: number; + packetsReceived: number; + lastHandshakeTime?: string; +} + // ============================================================================ // IPC Command maps (used by smartrust RustBridge) // ============================================================================ @@ -211,6 +258,10 @@ export type TVpnServerCommands = { setClientRateLimit: { params: { clientId: string; rateBytesPerSec: number; burstBytes: number }; result: void }; removeClientRateLimit: { params: { clientId: string }; result: void }; getClientTelemetry: { params: { clientId: string }; result: IVpnClientTelemetry }; + generateWgKeypair: { params: Record; result: IVpnKeypair }; + addWgPeer: { params: { peer: IWgPeerConfig }; result: void }; + removeWgPeer: { params: { publicKey: string }; result: void }; + listWgPeers: { params: Record; result: { peers: IWgPeerInfo[] } }; }; // ============================================================================