diff --git a/changelog.md b/changelog.md index 28d3af6..e8f94a0 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,13 @@ # Changelog +## 2026-03-31 - 1.16.2 - fix(wireguard) +sync runtime peer management with client registration and derive the correct server public key from the WireGuard private key + +- Register, remove, and rotate WireGuard peers in the running listener when clients are added, deleted, or rekeyed. +- Generate client WireGuard configs with the public key derived from the configured WireGuard private key instead of reusing the generic server public key. +- Handle expired WireGuard sessions by re-initiating handshakes and mark client state as handshaking until the tunnel becomes active. +- Improve allowed IP matching and peer VPN IP extraction for runtime packet routing. + ## 2026-03-30 - 1.16.1 - fix(rust/server) add serde alias for clientAllowedIPs in server config diff --git a/rust/src/server.rs b/rust/src/server.rs index 3274982..1d6e5a3 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; use tokio::sync::{mpsc, Mutex, RwLock}; -use tracing::{info, error, warn}; +use tracing::{debug, info, error, warn}; use crate::acl; use crate::client_registry::{ClientEntry, ClientRegistry}; @@ -594,6 +594,20 @@ impl VpnServer { // Add to registry state.client_registry.write().await.add(entry.clone())?; + // Register WG peer with the running WG listener (if active) + if self.wg_command_tx.is_some() { + let wg_peer_config = crate::wireguard::WgPeerConfig { + public_key: wg_pub.clone(), + preshared_key: None, + allowed_ips: vec![format!("{}/32", assigned_ip)], + endpoint: None, + persistent_keepalive: Some(25), + }; + if let Err(e) = self.add_wg_peer(wg_peer_config).await { + warn!("Failed to register WG peer for client {}: {}", client_id, e); + } + } + // Build SmartVPN client config let smartvpn_server_url = format!("wss://{}", state.config.server_endpoint.as_deref() @@ -610,6 +624,10 @@ impl VpnServer { }); // Build WireGuard config string + let wg_server_pubkey = match &state.config.wg_private_key { + Some(wg_priv_key) => crate::wireguard::wg_public_key_from_private(wg_priv_key)?, + None => state.config.public_key.clone(), + }; let wg_endpoint = state.config.server_endpoint.as_deref() .unwrap_or(&state.config.listen_addr); let wg_allowed_ips = state.config.client_allowed_ips.as_ref() @@ -622,7 +640,7 @@ impl VpnServer { state.config.dns.as_ref() .map(|d| format!("DNS = {}", d.join(", "))) .unwrap_or_default(), - state.config.public_key, + wg_server_pubkey, wg_allowed_ips, wg_endpoint, ); @@ -645,6 +663,14 @@ impl VpnServer { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; let entry = state.client_registry.write().await.remove(client_id)?; + // Remove WG peer from running listener + if self.wg_command_tx.is_some() { + if let Some(ref wg_key) = entry.wg_public_key { + if let Err(e) = self.remove_wg_peer(wg_key).await { + debug!("Failed to remove WG peer for client {}: {}", client_id, e); + } + } + } // Release the IP if assigned if let Some(ref ip_str) = entry.assigned_ip { if let Ok(ip) = ip_str.parse::() { @@ -731,6 +757,14 @@ impl VpnServer { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; + // Capture old WG key before rotation (needed to remove from WG listener) + let old_wg_pub = { + let registry = state.client_registry.read().await; + let entry = registry.get_by_id(client_id) + .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; + entry.wg_public_key.clone() + }; + let (noise_pub, noise_priv) = crypto::generate_keypair()?; let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair(); @@ -749,6 +783,25 @@ impl VpnServer { .and_then(|v| v.as_str()) .unwrap_or("0.0.0.0"); + // Update WG listener: remove old peer, add new peer + if self.wg_command_tx.is_some() { + if let Some(ref old_key) = old_wg_pub { + if let Err(e) = self.remove_wg_peer(old_key).await { + debug!("Failed to remove old WG peer during rotation: {}", e); + } + } + let wg_peer_config = crate::wireguard::WgPeerConfig { + public_key: wg_pub.clone(), + preshared_key: None, + allowed_ips: vec![format!("{}/32", assigned_ip)], + endpoint: None, + persistent_keepalive: Some(25), + }; + if let Err(e) = self.add_wg_peer(wg_peer_config).await { + warn!("Failed to register new WG peer during rotation: {}", e); + } + } + let smartvpn_server_url = format!("wss://{}", state.config.server_endpoint.as_deref() .unwrap_or(&state.config.listen_addr) @@ -763,6 +816,10 @@ impl VpnServer { "keepaliveIntervalSecs": state.config.keepalive_interval_secs, }); + let wg_server_pubkey = match &state.config.wg_private_key { + Some(wg_priv_key) => crate::wireguard::wg_public_key_from_private(wg_priv_key)?, + None => state.config.public_key.clone(), + }; let wg_endpoint = state.config.server_endpoint.as_deref() .unwrap_or(&state.config.listen_addr); let wg_allowed_ips = state.config.client_allowed_ips.as_ref() @@ -774,7 +831,7 @@ impl VpnServer { state.config.dns.as_ref() .map(|d| format!("DNS = {}", d.join(", "))) .unwrap_or_default(), - state.config.public_key, + wg_server_pubkey, wg_allowed_ips, wg_endpoint, ); @@ -816,6 +873,10 @@ impl VpnServer { })) } "wireguard" => { + let wg_server_pubkey = match &state.config.wg_private_key { + Some(wg_priv_key) => crate::wireguard::wg_public_key_from_private(wg_priv_key)?, + None => state.config.public_key.clone(), + }; let assigned_ip = entry.assigned_ip.as_deref().unwrap_or("0.0.0.0"); let wg_endpoint = state.config.server_endpoint.as_deref() .unwrap_or(&state.config.listen_addr); @@ -828,7 +889,7 @@ impl VpnServer { state.config.dns.as_ref() .map(|d| format!("DNS = {}", d.join(", "))) .unwrap_or_default(), - state.config.public_key, + wg_server_pubkey, wg_allowed_ips, wg_endpoint, ); diff --git a/rust/src/wireguard.rs b/rust/src/wireguard.rs index d2b4bbd..d698360 100644 --- a/rust/src/wireguard.rs +++ b/rust/src/wireguard.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use anyhow::{anyhow, Result}; use base64::engine::general_purpose::STANDARD as BASE64; use base64::Engine; +use boringtun::noise::errors::WireGuardError; use boringtun::noise::rate_limiter::RateLimiter; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::{PublicKey, StaticSecret}; @@ -99,6 +100,13 @@ pub fn generate_wg_keypair() -> (String, String) { (pub_b64, priv_b64) } +/// Derive the WireGuard public key (base64) from a private key (base64). +pub fn wg_public_key_from_private(private_key_b64: &str) -> Result { + let private = parse_private_key(private_key_b64)?; + let public = PublicKey::from(&private); + Ok(BASE64.encode(public.to_bytes())) +} + fn parse_private_key(b64: &str) -> Result { let bytes = BASE64.decode(b64)?; if bytes.len() != 32 { @@ -215,8 +223,8 @@ struct PeerState { } impl PeerState { - fn matches_dst(&self, dst_ip: IpAddr) -> bool { - self.allowed_ips.iter().any(|aip| aip.matches(dst_ip)) + fn matches_allowed_ips(&self, ip: IpAddr) -> bool { + self.allowed_ips.iter().any(|aip| aip.matches(ip)) } } @@ -286,9 +294,10 @@ pub struct WgListenerConfig { pub peers: Vec, } -/// Extract the first /32 IPv4 address from a list of AllowedIp entries. -/// This is the peer's VPN IP used for return-packet routing. +/// Extract the peer's VPN IP from AllowedIp entries. +/// Prefers /32 entries (exact match); falls back to any IPv4 address. fn extract_peer_vpn_ip(allowed_ips: &[AllowedIp]) -> Option { + // Prefer /32 entries (exact peer VPN IP) for aip in allowed_ips { if let IpAddr::V4(v4) = aip.addr { if aip.prefix_len == 32 { @@ -296,6 +305,12 @@ fn extract_peer_vpn_ip(allowed_ips: &[AllowedIp]) -> Option { } } } + // Fallback: use the first IPv4 address from any prefix length + for aip in allowed_ips { + if let IpAddr::V4(v4) = aip.addr { + return Some(v4); + } + } None } @@ -495,7 +510,7 @@ pub async fn run_wg_listener( break; } TunnResult::WriteToTunnelV4(packet, addr) => { - if peer.matches_dst(IpAddr::V4(addr)) { + if peer.matches_allowed_ips(IpAddr::V4(addr)) { let pkt_len = packet.len() as u64; // Forward via shared forwarding engine let mut engine = state.forwarding_engine.lock().await; @@ -519,7 +534,7 @@ pub async fn run_wg_listener( break; } TunnResult::WriteToTunnelV6(packet, addr) => { - if peer.matches_dst(IpAddr::V6(addr)) { + if peer.matches_allowed_ips(IpAddr::V6(addr)) { let pkt_len = packet.len() as u64; let mut engine = state.forwarding_engine.lock().await; match &mut *engine { @@ -586,6 +601,9 @@ pub async fn run_wg_listener( udp_socket.send_to(packet, endpoint).await?; } } + TunnResult::Err(WireGuardError::ConnectionExpired) => { + warn!("WG peer {} connection expired", peer.public_key_b64); + } TunnResult::Err(e) => { debug!("Timer error for WG peer {}: {:?}", peer.public_key_b64, e); @@ -796,12 +814,12 @@ impl WgClient { let state = self.state.clone(); let assigned_ip = config.address.clone(); - // Update state + // Update state — handshake hasn't completed yet { let mut s = state.write().await; - s.state = "connected".to_string(); + s.state = "handshaking".to_string(); s.assigned_ip = Some(assigned_ip.clone()); - s.connected_since = Some(chrono_now()); + s.connected_since = None; } // Spawn client loop @@ -868,7 +886,7 @@ async fn wg_client_loop( endpoint: SocketAddr, _allowed_ips: Vec, shared_stats: Arc>, - _state: Arc>, + state: Arc>, mut shutdown_rx: oneshot::Receiver<()>, ) -> Result<()> { let mut udp_buf = vec![0u8; MAX_UDP_PACKET]; @@ -876,6 +894,7 @@ async fn wg_client_loop( let mut dst_buf = vec![0u8; WG_BUFFER_SIZE]; let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS)); let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1)); + let mut handshake_complete = false; let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device); @@ -916,14 +935,37 @@ async fn wg_client_loop( tun_writer.write_all(packet).await?; local_stats.bytes_received += pkt_len; local_stats.packets_received += 1; + if !handshake_complete { + handshake_complete = true; + let mut s = state.write().await; + s.state = "connected".to_string(); + s.connected_since = Some(chrono_now()); + info!("WireGuard handshake completed, tunnel active"); + } } TunnResult::WriteToTunnelV6(packet, _addr) => { let pkt_len = packet.len() as u64; tun_writer.write_all(packet).await?; local_stats.bytes_received += pkt_len; local_stats.packets_received += 1; + if !handshake_complete { + handshake_complete = true; + let mut s = state.write().await; + s.state = "connected".to_string(); + s.connected_since = Some(chrono_now()); + info!("WireGuard handshake completed, tunnel active"); + } } TunnResult::Done => {} + TunnResult::Err(WireGuardError::ConnectionExpired) => { + warn!("WireGuard session expired during decapsulate, re-initiating handshake"); + match tunn.format_handshake_initiation(&mut dst_buf, true) { + TunnResult::WriteToNetwork(packet) => { + udp_socket.send_to(packet, endpoint).await?; + } + _ => {} + } + } TunnResult::Err(e) => { debug!("Client decapsulate error: {:?}", e); } @@ -955,6 +997,19 @@ async fn wg_client_loop( TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, endpoint).await?; } + TunnResult::Err(WireGuardError::ConnectionExpired) => { + warn!("WireGuard connection expired, re-initiating handshake"); + match tunn.format_handshake_initiation(&mut dst_buf, true) { + TunnResult::WriteToNetwork(packet) => { + udp_socket.send_to(packet, endpoint).await?; + debug!("Sent handshake re-initiation after expiry"); + } + TunnResult::Err(e) => { + warn!("Failed to re-initiate handshake: {:?}", e); + } + _ => {} + } + } TunnResult::Err(e) => { debug!("Client timer error: {:?}", e); } @@ -1028,6 +1083,19 @@ mod tests { assert_eq!(public.to_bytes(), derived_public.to_bytes()); } + #[test] + fn test_wg_public_key_from_private() { + let (pub_b64, priv_b64) = generate_wg_keypair(); + let derived = wg_public_key_from_private(&priv_b64).unwrap(); + assert_eq!(derived, pub_b64); + } + + #[test] + fn test_wg_public_key_from_private_invalid() { + assert!(wg_public_key_from_private("not-valid").is_err()); + assert!(wg_public_key_from_private("AAAA").is_err()); + } + #[test] fn test_parse_invalid_key() { assert!(parse_private_key("not-valid-base64!!!").is_err()); diff --git a/ts/00_commitinfo_data.ts b/ts/00_commitinfo_data.ts index dd3bd95..885d91e 100644 --- a/ts/00_commitinfo_data.ts +++ b/ts/00_commitinfo_data.ts @@ -3,6 +3,6 @@ */ export const commitinfo = { name: '@push.rocks/smartvpn', - version: '1.16.1', + version: '1.16.2', description: 'A VPN solution with TypeScript control plane and Rust data plane daemon' }