From 3c515c7d7fb40c322d257a72b6fe66f0491c5da5 Mon Sep 17 00:00:00 2001 From: Juergen Kunz Date: Tue, 12 May 2026 23:08:11 +0000 Subject: [PATCH] fix(rust-wireguard): keep WireGuard peer registration and client state in sync --- changelog.md | 9 ++ rust/src/client_registry.rs | 72 ++++++++++++++- rust/src/network.rs | 15 +++- rust/src/server.rs | 170 ++++++++++++++++++++++++++++++------ rust/src/wireguard.rs | 99 ++++++++++++++++----- 5 files changed, 314 insertions(+), 51 deletions(-) diff --git a/changelog.md b/changelog.md index baff29c..d404c0f 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,15 @@ ## Pending +### Fixes + +- keep WireGuard peer registration and client state in sync (rust-wireguard) + - index WireGuard public keys in the client registry for duplicate detection and direct lookup + - skip inactive clients when loading or adding WireGuard peers and fail fast when peer registration cannot be completed + - make IP reservation idempotent for the same client and avoid releasing WireGuard-assigned IPs on disconnect + - roll back client registry and peer state when WireGuard peer creation or key rotation fails + - update hybrid routing entries when registered client networking flags change + ## 2026-05-12 - 1.19.3 ### Fixes diff --git a/rust/src/client_registry.rs b/rust/src/client_registry.rs index 91eb3e9..e517107 100644 --- a/rust/src/client_registry.rs +++ b/rust/src/client_registry.rs @@ -98,6 +98,8 @@ pub struct ClientRegistry { entries: HashMap, /// Secondary index: publicKey (base64) → clientId (fast lookup during handshake) key_index: HashMap, + /// WireGuard public key → clientId (fast lookup during WG handshakes) + wg_key_index: HashMap, /// Tertiary index: assignedIp → clientId (fast lookup during NAT destination policy) ip_index: HashMap, } @@ -107,6 +109,7 @@ impl ClientRegistry { Self { entries: HashMap::new(), key_index: HashMap::new(), + wg_key_index: HashMap::new(), ip_index: HashMap::new(), } } @@ -132,6 +135,12 @@ impl ClientRegistry { if self.key_index.contains_key(&entry.public_key) { anyhow::bail!("Public key already registered to another client"); } + if let Some(ref wg_key) = entry.wg_public_key { + if self.wg_key_index.contains_key(wg_key) { + anyhow::bail!("WireGuard public key already registered to another client"); + } + self.wg_key_index.insert(wg_key.clone(), entry.client_id.clone()); + } self.key_index.insert(entry.public_key.clone(), entry.client_id.clone()); if let Some(ref ip) = entry.assigned_ip { self.ip_index.insert(ip.clone(), entry.client_id.clone()); @@ -145,6 +154,9 @@ impl ClientRegistry { let entry = self.entries.remove(client_id) .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; self.key_index.remove(&entry.public_key); + if let Some(ref wg_key) = entry.wg_public_key { + self.wg_key_index.remove(wg_key); + } if let Some(ref ip) = entry.assigned_ip { self.ip_index.remove(ip); } @@ -162,6 +174,12 @@ impl ClientRegistry { self.entries.get(client_id) } + /// Get a client by WireGuard public key. + pub fn get_by_wg_key(&self, public_key: &str) -> Option<&ClientEntry> { + let client_id = self.wg_key_index.get(public_key)?; + self.entries.get(client_id) + } + /// Get a client by assigned IP (used for per-client destination policy in NAT engine). pub fn get_by_assigned_ip(&self, ip: &str) -> Option<&ClientEntry> { let client_id = self.ip_index.get(ip)?; @@ -184,6 +202,7 @@ impl ClientRegistry { let entry = self.entries.get_mut(client_id) .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; let old_key = entry.public_key.clone(); + let old_wg_key = entry.wg_public_key.clone(); let old_ip = entry.assigned_ip.clone(); updater(entry); // If public key changed, update the key index @@ -191,6 +210,15 @@ impl ClientRegistry { self.key_index.remove(&old_key); self.key_index.insert(entry.public_key.clone(), client_id.to_string()); } + // If WireGuard public key changed, update the WG key index. + if entry.wg_public_key != old_wg_key { + if let Some(ref old) = old_wg_key { + self.wg_key_index.remove(old); + } + if let Some(ref new_key) = entry.wg_public_key { + self.wg_key_index.insert(new_key.clone(), client_id.to_string()); + } + } // If assigned IP changed, update the IP index if entry.assigned_ip != old_ip { if let Some(ref old) = old_ip { @@ -210,13 +238,32 @@ impl ClientRegistry { /// Rotate a client's keys. Returns the updated entry. pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option) -> Result<()> { + if let Some(existing_client_id) = self.key_index.get(&new_public_key) { + if existing_client_id != client_id { + anyhow::bail!("Public key already registered to another client"); + } + } + if let Some(ref new_wg_key) = new_wg_public_key { + if let Some(existing_client_id) = self.wg_key_index.get(new_wg_key) { + if existing_client_id != client_id { + anyhow::bail!("WireGuard public key already registered to another client"); + } + } + } + let entry = self.entries.get_mut(client_id) .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; // Update key index self.key_index.remove(&entry.public_key); + if let Some(ref old_wg_key) = entry.wg_public_key { + self.wg_key_index.remove(old_wg_key); + } entry.public_key = new_public_key.clone(); entry.wg_public_key = new_wg_public_key; self.key_index.insert(new_public_key, client_id.to_string()); + if let Some(ref wg_key) = entry.wg_public_key { + self.wg_key_index.insert(wg_key.clone(), client_id.to_string()); + } Ok(()) } @@ -367,13 +414,36 @@ mod tests { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "old_key")).unwrap(); - reg.rotate_key("alice", "new_key".to_string(), None).unwrap(); + reg.rotate_key("alice", "new_key".to_string(), Some("new_wg_key".to_string())).unwrap(); assert!(reg.get_by_key("old_key").is_none()); assert!(reg.get_by_key("new_key").is_some()); + assert!(reg.get_by_wg_key("new_wg_key").is_some()); assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key"); } + #[test] + fn lookup_by_wireguard_key() { + let mut reg = ClientRegistry::new(); + let mut entry = make_entry("alice", "key_alice"); + entry.wg_public_key = Some("wg_key_alice".to_string()); + reg.add(entry).unwrap(); + + assert_eq!(reg.get_by_wg_key("wg_key_alice").unwrap().client_id, "alice"); + } + + #[test] + fn reject_duplicate_wireguard_key() { + let mut reg = ClientRegistry::new(); + let mut alice = make_entry("alice", "key_alice"); + alice.wg_public_key = Some("same_wg_key".to_string()); + let mut bob = make_entry("bob", "key_bob"); + bob.wg_public_key = Some("same_wg_key".to_string()); + + reg.add(alice).unwrap(); + assert!(reg.add(bob).is_err()); + } + #[test] fn from_entries() { let entries = vec![ diff --git a/rust/src/network.rs b/rust/src/network.rs index 53631fd..4155ed8 100644 --- a/rust/src/network.rs +++ b/rust/src/network.rs @@ -124,7 +124,10 @@ impl IpPool { /// Reserve a specific IP for a client (e.g., WireGuard static IP from allowed_ips). pub fn reserve(&mut self, ip: Ipv4Addr, client_id: &str) -> Result<()> { - if self.allocated.contains_key(&ip) { + if let Some(existing_client_id) = self.allocated.get(&ip) { + if existing_client_id == client_id { + return Ok(()); + } anyhow::bail!("IP {} is already allocated", ip); } self.allocated.insert(ip, client_id.to_string()); @@ -233,6 +236,16 @@ mod tests { assert!(pool.allocate("client2").is_err()); } + #[test] + fn ip_pool_reserve_is_idempotent_for_same_client() { + let mut pool = IpPool::new("10.9.0.0/24").unwrap(); + let ip = pool.allocate("alice").unwrap(); + + pool.reserve(ip, "alice").unwrap(); + assert_eq!(pool.allocated_count(), 1); + assert!(pool.reserve(ip, "bob").is_err()); + } + #[test] fn ip_pool_invalid_subnet() { assert!(IpPool::new("invalid").is_err()); diff --git a/rust/src/server.rs b/rust/src/server.rs index fc9c938..10c0aab 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -584,6 +584,9 @@ impl VpnServer { if self.wg_command_tx.is_some() { let registry = state.client_registry.read().await; for entry in registry.list() { + if !entry.is_enabled() || entry.is_expired() { + continue; + } if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) { let peer_config = crate::wireguard::WgPeerConfig { public_key: wg_key.clone(), @@ -725,8 +728,10 @@ impl VpnServer { if let Some(ref state) = self.state { let mut clients = state.clients.write().await; if let Some(client) = clients.remove(client_id) { - let ip: Ipv4Addr = client.assigned_ip.parse()?; - state.ip_pool.lock().await.release(&ip); + if client.transport_type != "wireguard" { + let ip: Ipv4Addr = client.assigned_ip.parse()?; + state.ip_pool.lock().await.release(&ip); + } state.rate_limiters.lock().await.remove(client_id); info!("Client {} disconnected", client_id); } @@ -890,7 +895,9 @@ impl VpnServer { persistent_keepalive: Some(25), }; if let Err(e) = self.add_wg_peer(wg_peer_config).await { - warn!("Failed to register WG peer for client {}: {}", client_id, e); + let _ = state.client_registry.write().await.remove(&client_id); + state.ip_pool.lock().await.release(&assigned_ip); + return Err(anyhow::anyhow!("Failed to register WG peer for client {}: {}", client_id, e)); } } @@ -948,15 +955,28 @@ impl VpnServer { pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; - let entry = state.client_registry.write().await.remove(client_id)?; + + let entry = { + let registry = state.client_registry.read().await; + registry.get_by_id(client_id) + .cloned() + .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))? + }; + // Remove WG peer from running listener if self.wg_command_tx.is_some() { if let Some(ref wg_key) = entry.wg_public_key { if let Err(e) = self.remove_wg_peer(wg_key).await { - debug!("Failed to remove WG peer for client {}: {}", client_id, e); + if entry.is_enabled() && !entry.is_expired() { + return Err(anyhow::anyhow!("Failed to remove WG peer for client {}: {}", client_id, e)); + } + debug!("Failed to remove inactive WG peer for client {}: {}", client_id, e); } } } + + state.client_registry.write().await.remove(client_id)?; + // Release the IP if assigned if let Some(ref ip_str) = entry.assigned_ip { if let Ok(ip) = ip_str.parse::() { @@ -1013,7 +1033,37 @@ impl VpnServer { if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) { entry.expires_at = Some(expires.to_string()); } + if let Some(use_host_ip) = update.get("useHostIp").and_then(|v| v.as_bool()) { + entry.use_host_ip = Some(use_host_ip); + } + if let Some(use_dhcp) = update.get("useDhcp").and_then(|v| v.as_bool()) { + entry.use_dhcp = Some(use_dhcp); + } + if let Some(static_ip) = update.get("staticIp").and_then(|v| v.as_str()) { + entry.static_ip = Some(static_ip.to_string()); + } + if let Some(force_vlan) = update.get("forceVlan").and_then(|v| v.as_bool()) { + entry.force_vlan = Some(force_vlan); + } + if let Some(vlan_id) = update.get("vlanId").and_then(|v| v.as_u64()) { + entry.vlan_id = Some(vlan_id as u16); + } })?; + + let updated_entry = { + let registry = state.client_registry.read().await; + registry.get_by_id(client_id).cloned() + }; + if let Some(entry) = updated_entry { + if let Some(ref ip_str) = entry.assigned_ip { + if let Ok(ip) = ip_str.parse::() { + if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await { + routing_table.write().await.insert(ip, entry.use_host_ip.unwrap_or(false)); + } + } + } + } + Ok(()) } @@ -1023,13 +1073,56 @@ impl VpnServer { .ok_or_else(|| anyhow::anyhow!("Server not running"))?; state.client_registry.write().await.update(client_id, |entry| { entry.enabled = Some(true); - }) + })?; + + let entry = { + let registry = state.client_registry.read().await; + registry.get_by_id(client_id) + .cloned() + .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))? + }; + + if self.wg_command_tx.is_some() { + if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) { + let peer_config = crate::wireguard::WgPeerConfig { + public_key: wg_key.clone(), + preshared_key: None, + allowed_ips: vec![format!("{}/32", ip_str)], + endpoint: None, + persistent_keepalive: Some(25), + }; + if let Err(e) = self.add_wg_peer(peer_config).await { + let _ = state.client_registry.write().await.update(client_id, |entry| { + entry.enabled = Some(false); + }); + return Err(anyhow::anyhow!("Failed to register WG peer for enabled client {}: {}", client_id, e)); + } + } + } + + Ok(()) } /// Disable a registered client (also disconnects if connected). pub async fn disable_client(&self, client_id: &str) -> Result<()> { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; + + let entry = { + let registry = state.client_registry.read().await; + registry.get_by_id(client_id).cloned() + }; + + if self.wg_command_tx.is_some() { + if let Some(ref entry) = entry { + if let Some(ref wg_key) = entry.wg_public_key { + if let Err(e) = self.remove_wg_peer(wg_key).await { + debug!("Failed to remove WG peer for disabled client {}: {}", client_id, e); + } + } + } + } + state.client_registry.write().await.update(client_id, |entry| { entry.enabled = Some(false); })?; @@ -1043,17 +1136,33 @@ impl VpnServer { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; - // Capture old WG key before rotation (needed to remove from WG listener) - let old_wg_pub = { + // Capture old keys before rotation so listener and registry can be rolled back together. + let old_entry = { let registry = state.client_registry.read().await; - let entry = registry.get_by_id(client_id) - .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; - entry.wg_public_key.clone() + registry.get_by_id(client_id) + .cloned() + .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))? }; + let old_noise_pub = old_entry.public_key.clone(); + let old_wg_pub = old_entry.wg_public_key.clone(); + let assigned_ip = old_entry.assigned_ip.clone().unwrap_or_else(|| "0.0.0.0".to_string()); + let should_have_wg_peer = self.wg_command_tx.is_some() + && old_entry.is_enabled() + && !old_entry.is_expired() + && old_wg_pub.is_some() + && assigned_ip != "0.0.0.0"; let (noise_pub, noise_priv) = crypto::generate_keypair()?; let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair(); + if should_have_wg_peer { + if let Some(ref old_key) = old_wg_pub { + if let Err(e) = self.remove_wg_peer(old_key).await { + return Err(anyhow::anyhow!("Failed to remove old WG peer during rotation for {}: {}", client_id, e)); + } + } + } + state.client_registry.write().await.rotate_key( client_id, noise_pub.clone(), @@ -1063,31 +1172,40 @@ impl VpnServer { // Disconnect existing connection (old key is no longer valid) let _ = self.disconnect_client(client_id).await; - // Get updated entry for the config bundle - let entry_json = self.get_registered_client(client_id).await?; - let assigned_ip = entry_json.get("assignedIp") - .and_then(|v| v.as_str()) - .unwrap_or("0.0.0.0"); - - // Update WG listener: remove old peer, add new peer - if self.wg_command_tx.is_some() { - if let Some(ref old_key) = old_wg_pub { - if let Err(e) = self.remove_wg_peer(old_key).await { - debug!("Failed to remove old WG peer during rotation: {}", e); - } - } + // Update WG listener with the new key. Roll back the registry if this fails. + if should_have_wg_peer { let wg_peer_config = crate::wireguard::WgPeerConfig { public_key: wg_pub.clone(), preshared_key: None, - allowed_ips: vec![format!("{}/32", assigned_ip)], + allowed_ips: vec![format!("{}/32", assigned_ip.as_str())], endpoint: None, persistent_keepalive: Some(25), }; if let Err(e) = self.add_wg_peer(wg_peer_config).await { - warn!("Failed to register new WG peer during rotation: {}", e); + let _ = state.client_registry.write().await.rotate_key( + client_id, + old_noise_pub, + old_wg_pub.clone(), + ); + if let Some(ref old_key) = old_wg_pub { + let rollback_peer_config = crate::wireguard::WgPeerConfig { + public_key: old_key.clone(), + preshared_key: None, + allowed_ips: vec![format!("{}/32", assigned_ip.as_str())], + endpoint: None, + persistent_keepalive: Some(25), + }; + if let Err(rollback_err) = self.add_wg_peer(rollback_peer_config).await { + warn!("Failed to restore old WG peer after rotation failure for {}: {}", client_id, rollback_err); + } + } + return Err(anyhow::anyhow!("Failed to register new WG peer during rotation for {}: {}", client_id, e)); } } + // Get updated entry for the config bundle + let entry_json = self.get_registered_client(client_id).await?; + let smartvpn_server_url = format!("wss://{}", state.config.server_endpoint.as_deref() .unwrap_or(&state.config.listen_addr) diff --git a/rust/src/wireguard.rs b/rust/src/wireguard.rs index 4ab4a04..84b1d06 100644 --- a/rust/src/wireguard.rs +++ b/rust/src/wireguard.rs @@ -215,6 +215,8 @@ pub enum WgCommand { struct PeerState { tunn: Tunn, public_key_b64: String, + /// Registered SmartVPN client ID when this peer belongs to a registry entry. + client_id: Option, allowed_ips: Vec, endpoint: Option, #[allow(dead_code)] @@ -237,6 +239,28 @@ impl PeerState { } } +fn synthetic_wg_client_id(pubkey: &str) -> String { + format!("wg-{}", &pubkey[..8.min(pubkey.len())]) +} + +fn peer_client_id(peer: &PeerState) -> String { + peer.client_id + .clone() + .unwrap_or_else(|| synthetic_wg_client_id(&peer.public_key_b64)) +} + +async fn registered_peer_info( + state: &Arc, + pubkey: &str, +) -> Option<(String, bool, bool)> { + let registry = state.client_registry.read().await; + registry.get_by_wg_key(pubkey).map(|entry| ( + entry.client_id.clone(), + entry.use_host_ip.unwrap_or(false), + entry.is_enabled() && !entry.is_expired(), + )) +} + fn add_peer_to_loop( peers: &mut Vec, @@ -281,6 +305,7 @@ fn add_peer_to_loop( peers.push(PeerState { tunn, public_key_b64: config.public_key.clone(), + client_id: None, allowed_ips, endpoint, persistent_keepalive: config.persistent_keepalive, @@ -344,7 +369,7 @@ fn wg_timestamp_now() -> String { /// Returns the VPN IP. async fn register_wg_peer( state: &Arc, - peer: &PeerState, + peer: &mut PeerState, wg_return_tx: &mpsc::Sender<(String, Vec)>, ) -> Result> { let vpn_ip = match extract_peer_vpn_ip(&peer.allowed_ips) { @@ -356,12 +381,24 @@ async fn register_wg_peer( } }; - let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]); + let (client_id, use_host_ip) = match registered_peer_info(state, &peer.public_key_b64).await { + Some((client_id, use_host_ip, true)) => (client_id, use_host_ip), + Some((client_id, _, false)) => { + warn!("WG peer {} maps to disabled or expired client {}, skipping registration", peer.public_key_b64, client_id); + return Ok(None); + } + None => (synthetic_wg_client_id(&peer.public_key_b64), false), + }; + peer.client_id = Some(client_id.clone()); // Reserve IP in the pool if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) { warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e); - return Ok(None); + return Err(e); + } + + if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await { + routing_table.write().await.insert(vpn_ip, use_host_ip); } // Create per-peer return channel and register in tun_routes @@ -393,7 +430,7 @@ async fn connect_wg_peer( peer: &PeerState, vpn_ip: Ipv4Addr, ) { - let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]); + let client_id = peer_client_id(peer); let client_info = ClientInfo { client_id: client_id.clone(), assigned_ip: vpn_ip.to_string(), @@ -426,24 +463,27 @@ async fn connect_wg_peer( /// Remove a WG peer from state.clients (disconnect without unregistering). async fn disconnect_wg_peer( state: &Arc, - pubkey: &str, + peer: &PeerState, ) { - let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]); + let client_id = peer_client_id(peer); if state.clients.write().await.remove(&client_id).is_some() { - info!("WG peer {} disconnected (removed from active clients)", pubkey); + info!("WG peer {} disconnected (removed from active clients)", peer.public_key_b64); } } /// Unregister a WG peer from ServerState. async fn unregister_wg_peer( state: &Arc, - pubkey: &str, + peer: &PeerState, vpn_ip: Option, ) { - let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]); + let client_id = peer_client_id(peer); if let Some(ip) = vpn_ip { state.tun_routes.write().await.remove(&ip); + if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await { + routing_table.write().await.remove(&ip); + } state.ip_pool.lock().await.release(&ip); } state.clients.write().await.remove(&client_id); @@ -501,6 +541,7 @@ pub async fn run_wg_listener( peers.push(PeerState { tunn, public_key_b64: peer_config.public_key.clone(), + client_id: None, allowed_ips, endpoint, persistent_keepalive: peer_config.persistent_keepalive, @@ -523,9 +564,13 @@ pub async fn run_wg_listener( // Register initial peers in ServerState (IP reservation + tun_routes only, NOT state.clients) let mut peer_vpn_ips: HashMap = HashMap::new(); for peer in peers.iter_mut() { - if let Ok(Some(ip)) = register_wg_peer(&state, peer, &wg_return_tx).await { - peer_vpn_ips.insert(peer.public_key_b64.clone(), ip); - peer.vpn_ip = Some(ip); + match register_wg_peer(&state, peer, &wg_return_tx).await { + Ok(Some(ip)) => { + peer_vpn_ips.insert(peer.public_key_b64.clone(), ip); + peer.vpn_ip = Some(ip); + } + Ok(None) => {} + Err(e) => warn!("Failed to register initial WG peer {}: {}", peer.public_key_b64, e), } } @@ -705,7 +750,7 @@ pub async fn run_wg_listener( warn!("WG peer {} connection expired", peer.public_key_b64); if peer.is_connected { peer.is_connected = false; - disconnect_wg_peer(&state, &peer.public_key_b64).await; + disconnect_wg_peer(&state, peer).await; } } TunnResult::Err(e) => { @@ -733,7 +778,7 @@ pub async fn run_wg_listener( } // Only update ClientInfo if peer is connected (in state.clients) - let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]); + let client_id = peer_client_id(peer); if let Some(info) = clients.get_mut(&client_id) { info.bytes_sent = peer.stats.bytes_sent; info.bytes_received = peer.stats.bytes_received; @@ -751,7 +796,7 @@ pub async fn run_wg_listener( if now.duration_since(last) > std::time::Duration::from_secs(180) { info!("WG peer {} idle timeout (180s), disconnecting", peer.public_key_b64); peer.is_connected = false; - disconnect_wg_peer(&state, &peer.public_key_b64).await; + disconnect_wg_peer(&state, peer).await; } } } @@ -762,7 +807,12 @@ pub async fn run_wg_listener( cmd = command_rx.recv() => { match cmd { Some(WgCommand::AddPeer(peer_config, resp_tx)) => { - let result = add_peer_to_loop( + if let Some((client_id, _, false)) = registered_peer_info(&state, &peer_config.public_key).await { + let _ = resp_tx.send(Err(anyhow!("WG peer maps to disabled or expired client: {}", client_id))); + continue; + } + + let mut result = add_peer_to_loop( &mut peers, &peer_config, &peer_index, @@ -777,20 +827,23 @@ pub async fn run_wg_listener( peer_vpn_ips.insert(peer_config.public_key.clone(), ip); peer.vpn_ip = Some(ip); } - Ok(None) => {} + Ok(None) => { + peers.retain(|p| p.public_key_b64 != peer_config.public_key); + result = Err(anyhow!("WG peer was not registered")); + } Err(e) => { - warn!("Failed to register WG peer: {}", e); + peers.retain(|p| p.public_key_b64 != peer_config.public_key); + result = Err(e); } } } let _ = resp_tx.send(result); } Some(WgCommand::RemovePeer(pubkey, resp_tx)) => { - let prev_len = peers.len(); - peers.retain(|p| p.public_key_b64 != pubkey); - if peers.len() < prev_len { + if let Some(index) = peers.iter().position(|p| p.public_key_b64 == pubkey) { + let peer = peers.remove(index); let vpn_ip = peer_vpn_ips.remove(&pubkey); - unregister_wg_peer(&state, &pubkey, vpn_ip).await; + unregister_wg_peer(&state, &peer, vpn_ip).await; let _ = resp_tx.send(Ok(())); } else { let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey))); @@ -814,7 +867,7 @@ pub async fn run_wg_listener( // Cleanup: unregister all peers from ServerState for peer in &peers { let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied(); - unregister_wg_peer(&state, &peer.public_key_b64, vpn_ip).await; + unregister_wg_peer(&state, peer, vpn_ip).await; } info!("WireGuard listener stopped");