fix(rust-wireguard): keep WireGuard peer registration and client state in sync

This commit is contained in:
2026-05-12 23:08:11 +00:00
parent 773eb6426e
commit 3c515c7d7f
5 changed files with 314 additions and 51 deletions
+9
View File
@@ -3,6 +3,15 @@
## Pending
### Fixes
- keep WireGuard peer registration and client state in sync (rust-wireguard)
- index WireGuard public keys in the client registry for duplicate detection and direct lookup
- skip inactive clients when loading or adding WireGuard peers and fail fast when peer registration cannot be completed
- make IP reservation idempotent for the same client and avoid releasing WireGuard-assigned IPs on disconnect
- roll back client registry and peer state when WireGuard peer creation or key rotation fails
- update hybrid routing entries when registered client networking flags change
## 2026-05-12 - 1.19.3
### Fixes
+71 -1
View File
@@ -98,6 +98,8 @@ pub struct ClientRegistry {
entries: HashMap<String, ClientEntry>,
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
key_index: HashMap<String, String>,
/// WireGuard public key → clientId (fast lookup during WG handshakes)
wg_key_index: HashMap<String, String>,
/// Tertiary index: assignedIp → clientId (fast lookup during NAT destination policy)
ip_index: HashMap<String, String>,
}
@@ -107,6 +109,7 @@ impl ClientRegistry {
Self {
entries: HashMap::new(),
key_index: HashMap::new(),
wg_key_index: HashMap::new(),
ip_index: HashMap::new(),
}
}
@@ -132,6 +135,12 @@ impl ClientRegistry {
if self.key_index.contains_key(&entry.public_key) {
anyhow::bail!("Public key already registered to another client");
}
if let Some(ref wg_key) = entry.wg_public_key {
if self.wg_key_index.contains_key(wg_key) {
anyhow::bail!("WireGuard public key already registered to another client");
}
self.wg_key_index.insert(wg_key.clone(), entry.client_id.clone());
}
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
if let Some(ref ip) = entry.assigned_ip {
self.ip_index.insert(ip.clone(), entry.client_id.clone());
@@ -145,6 +154,9 @@ impl ClientRegistry {
let entry = self.entries.remove(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
self.key_index.remove(&entry.public_key);
if let Some(ref wg_key) = entry.wg_public_key {
self.wg_key_index.remove(wg_key);
}
if let Some(ref ip) = entry.assigned_ip {
self.ip_index.remove(ip);
}
@@ -162,6 +174,12 @@ impl ClientRegistry {
self.entries.get(client_id)
}
/// Get a client by WireGuard public key.
pub fn get_by_wg_key(&self, public_key: &str) -> Option<&ClientEntry> {
let client_id = self.wg_key_index.get(public_key)?;
self.entries.get(client_id)
}
/// Get a client by assigned IP (used for per-client destination policy in NAT engine).
pub fn get_by_assigned_ip(&self, ip: &str) -> Option<&ClientEntry> {
let client_id = self.ip_index.get(ip)?;
@@ -184,6 +202,7 @@ impl ClientRegistry {
let entry = self.entries.get_mut(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
let old_key = entry.public_key.clone();
let old_wg_key = entry.wg_public_key.clone();
let old_ip = entry.assigned_ip.clone();
updater(entry);
// If public key changed, update the key index
@@ -191,6 +210,15 @@ impl ClientRegistry {
self.key_index.remove(&old_key);
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
}
// If WireGuard public key changed, update the WG key index.
if entry.wg_public_key != old_wg_key {
if let Some(ref old) = old_wg_key {
self.wg_key_index.remove(old);
}
if let Some(ref new_key) = entry.wg_public_key {
self.wg_key_index.insert(new_key.clone(), client_id.to_string());
}
}
// If assigned IP changed, update the IP index
if entry.assigned_ip != old_ip {
if let Some(ref old) = old_ip {
@@ -210,13 +238,32 @@ impl ClientRegistry {
/// Rotate a client's keys. Returns the updated entry.
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
if let Some(existing_client_id) = self.key_index.get(&new_public_key) {
if existing_client_id != client_id {
anyhow::bail!("Public key already registered to another client");
}
}
if let Some(ref new_wg_key) = new_wg_public_key {
if let Some(existing_client_id) = self.wg_key_index.get(new_wg_key) {
if existing_client_id != client_id {
anyhow::bail!("WireGuard public key already registered to another client");
}
}
}
let entry = self.entries.get_mut(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
// Update key index
self.key_index.remove(&entry.public_key);
if let Some(ref old_wg_key) = entry.wg_public_key {
self.wg_key_index.remove(old_wg_key);
}
entry.public_key = new_public_key.clone();
entry.wg_public_key = new_wg_public_key;
self.key_index.insert(new_public_key, client_id.to_string());
if let Some(ref wg_key) = entry.wg_public_key {
self.wg_key_index.insert(wg_key.clone(), client_id.to_string());
}
Ok(())
}
@@ -367,13 +414,36 @@ mod tests {
let mut reg = ClientRegistry::new();
reg.add(make_entry("alice", "old_key")).unwrap();
reg.rotate_key("alice", "new_key".to_string(), None).unwrap();
reg.rotate_key("alice", "new_key".to_string(), Some("new_wg_key".to_string())).unwrap();
assert!(reg.get_by_key("old_key").is_none());
assert!(reg.get_by_key("new_key").is_some());
assert!(reg.get_by_wg_key("new_wg_key").is_some());
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
}
#[test]
fn lookup_by_wireguard_key() {
let mut reg = ClientRegistry::new();
let mut entry = make_entry("alice", "key_alice");
entry.wg_public_key = Some("wg_key_alice".to_string());
reg.add(entry).unwrap();
assert_eq!(reg.get_by_wg_key("wg_key_alice").unwrap().client_id, "alice");
}
#[test]
fn reject_duplicate_wireguard_key() {
let mut reg = ClientRegistry::new();
let mut alice = make_entry("alice", "key_alice");
alice.wg_public_key = Some("same_wg_key".to_string());
let mut bob = make_entry("bob", "key_bob");
bob.wg_public_key = Some("same_wg_key".to_string());
reg.add(alice).unwrap();
assert!(reg.add(bob).is_err());
}
#[test]
fn from_entries() {
let entries = vec![
+14 -1
View File
@@ -124,7 +124,10 @@ impl IpPool {
/// Reserve a specific IP for a client (e.g., WireGuard static IP from allowed_ips).
pub fn reserve(&mut self, ip: Ipv4Addr, client_id: &str) -> Result<()> {
if self.allocated.contains_key(&ip) {
if let Some(existing_client_id) = self.allocated.get(&ip) {
if existing_client_id == client_id {
return Ok(());
}
anyhow::bail!("IP {} is already allocated", ip);
}
self.allocated.insert(ip, client_id.to_string());
@@ -233,6 +236,16 @@ mod tests {
assert!(pool.allocate("client2").is_err());
}
#[test]
fn ip_pool_reserve_is_idempotent_for_same_client() {
let mut pool = IpPool::new("10.9.0.0/24").unwrap();
let ip = pool.allocate("alice").unwrap();
pool.reserve(ip, "alice").unwrap();
assert_eq!(pool.allocated_count(), 1);
assert!(pool.reserve(ip, "bob").is_err());
}
#[test]
fn ip_pool_invalid_subnet() {
assert!(IpPool::new("invalid").is_err());
+144 -26
View File
@@ -584,6 +584,9 @@ impl VpnServer {
if self.wg_command_tx.is_some() {
let registry = state.client_registry.read().await;
for entry in registry.list() {
if !entry.is_enabled() || entry.is_expired() {
continue;
}
if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) {
let peer_config = crate::wireguard::WgPeerConfig {
public_key: wg_key.clone(),
@@ -725,8 +728,10 @@ impl VpnServer {
if let Some(ref state) = self.state {
let mut clients = state.clients.write().await;
if let Some(client) = clients.remove(client_id) {
let ip: Ipv4Addr = client.assigned_ip.parse()?;
state.ip_pool.lock().await.release(&ip);
if client.transport_type != "wireguard" {
let ip: Ipv4Addr = client.assigned_ip.parse()?;
state.ip_pool.lock().await.release(&ip);
}
state.rate_limiters.lock().await.remove(client_id);
info!("Client {} disconnected", client_id);
}
@@ -890,7 +895,9 @@ impl VpnServer {
persistent_keepalive: Some(25),
};
if let Err(e) = self.add_wg_peer(wg_peer_config).await {
warn!("Failed to register WG peer for client {}: {}", client_id, e);
let _ = state.client_registry.write().await.remove(&client_id);
state.ip_pool.lock().await.release(&assigned_ip);
return Err(anyhow::anyhow!("Failed to register WG peer for client {}: {}", client_id, e));
}
}
@@ -948,15 +955,28 @@ impl VpnServer {
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let entry = state.client_registry.write().await.remove(client_id)?;
let entry = {
let registry = state.client_registry.read().await;
registry.get_by_id(client_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
};
// Remove WG peer from running listener
if self.wg_command_tx.is_some() {
if let Some(ref wg_key) = entry.wg_public_key {
if let Err(e) = self.remove_wg_peer(wg_key).await {
debug!("Failed to remove WG peer for client {}: {}", client_id, e);
if entry.is_enabled() && !entry.is_expired() {
return Err(anyhow::anyhow!("Failed to remove WG peer for client {}: {}", client_id, e));
}
debug!("Failed to remove inactive WG peer for client {}: {}", client_id, e);
}
}
}
state.client_registry.write().await.remove(client_id)?;
// Release the IP if assigned
if let Some(ref ip_str) = entry.assigned_ip {
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
@@ -1013,7 +1033,37 @@ impl VpnServer {
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
entry.expires_at = Some(expires.to_string());
}
if let Some(use_host_ip) = update.get("useHostIp").and_then(|v| v.as_bool()) {
entry.use_host_ip = Some(use_host_ip);
}
if let Some(use_dhcp) = update.get("useDhcp").and_then(|v| v.as_bool()) {
entry.use_dhcp = Some(use_dhcp);
}
if let Some(static_ip) = update.get("staticIp").and_then(|v| v.as_str()) {
entry.static_ip = Some(static_ip.to_string());
}
if let Some(force_vlan) = update.get("forceVlan").and_then(|v| v.as_bool()) {
entry.force_vlan = Some(force_vlan);
}
if let Some(vlan_id) = update.get("vlanId").and_then(|v| v.as_u64()) {
entry.vlan_id = Some(vlan_id as u16);
}
})?;
let updated_entry = {
let registry = state.client_registry.read().await;
registry.get_by_id(client_id).cloned()
};
if let Some(entry) = updated_entry {
if let Some(ref ip_str) = entry.assigned_ip {
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
routing_table.write().await.insert(ip, entry.use_host_ip.unwrap_or(false));
}
}
}
}
Ok(())
}
@@ -1023,13 +1073,56 @@ impl VpnServer {
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
state.client_registry.write().await.update(client_id, |entry| {
entry.enabled = Some(true);
})
})?;
let entry = {
let registry = state.client_registry.read().await;
registry.get_by_id(client_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
};
if self.wg_command_tx.is_some() {
if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) {
let peer_config = crate::wireguard::WgPeerConfig {
public_key: wg_key.clone(),
preshared_key: None,
allowed_ips: vec![format!("{}/32", ip_str)],
endpoint: None,
persistent_keepalive: Some(25),
};
if let Err(e) = self.add_wg_peer(peer_config).await {
let _ = state.client_registry.write().await.update(client_id, |entry| {
entry.enabled = Some(false);
});
return Err(anyhow::anyhow!("Failed to register WG peer for enabled client {}: {}", client_id, e));
}
}
}
Ok(())
}
/// Disable a registered client (also disconnects if connected).
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let entry = {
let registry = state.client_registry.read().await;
registry.get_by_id(client_id).cloned()
};
if self.wg_command_tx.is_some() {
if let Some(ref entry) = entry {
if let Some(ref wg_key) = entry.wg_public_key {
if let Err(e) = self.remove_wg_peer(wg_key).await {
debug!("Failed to remove WG peer for disabled client {}: {}", client_id, e);
}
}
}
}
state.client_registry.write().await.update(client_id, |entry| {
entry.enabled = Some(false);
})?;
@@ -1043,17 +1136,33 @@ impl VpnServer {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
// Capture old WG key before rotation (needed to remove from WG listener)
let old_wg_pub = {
// Capture old keys before rotation so listener and registry can be rolled back together.
let old_entry = {
let registry = state.client_registry.read().await;
let entry = registry.get_by_id(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
entry.wg_public_key.clone()
registry.get_by_id(client_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
};
let old_noise_pub = old_entry.public_key.clone();
let old_wg_pub = old_entry.wg_public_key.clone();
let assigned_ip = old_entry.assigned_ip.clone().unwrap_or_else(|| "0.0.0.0".to_string());
let should_have_wg_peer = self.wg_command_tx.is_some()
&& old_entry.is_enabled()
&& !old_entry.is_expired()
&& old_wg_pub.is_some()
&& assigned_ip != "0.0.0.0";
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
if should_have_wg_peer {
if let Some(ref old_key) = old_wg_pub {
if let Err(e) = self.remove_wg_peer(old_key).await {
return Err(anyhow::anyhow!("Failed to remove old WG peer during rotation for {}: {}", client_id, e));
}
}
}
state.client_registry.write().await.rotate_key(
client_id,
noise_pub.clone(),
@@ -1063,31 +1172,40 @@ impl VpnServer {
// Disconnect existing connection (old key is no longer valid)
let _ = self.disconnect_client(client_id).await;
// Get updated entry for the config bundle
let entry_json = self.get_registered_client(client_id).await?;
let assigned_ip = entry_json.get("assignedIp")
.and_then(|v| v.as_str())
.unwrap_or("0.0.0.0");
// Update WG listener: remove old peer, add new peer
if self.wg_command_tx.is_some() {
if let Some(ref old_key) = old_wg_pub {
if let Err(e) = self.remove_wg_peer(old_key).await {
debug!("Failed to remove old WG peer during rotation: {}", e);
}
}
// Update WG listener with the new key. Roll back the registry if this fails.
if should_have_wg_peer {
let wg_peer_config = crate::wireguard::WgPeerConfig {
public_key: wg_pub.clone(),
preshared_key: None,
allowed_ips: vec![format!("{}/32", assigned_ip)],
allowed_ips: vec![format!("{}/32", assigned_ip.as_str())],
endpoint: None,
persistent_keepalive: Some(25),
};
if let Err(e) = self.add_wg_peer(wg_peer_config).await {
warn!("Failed to register new WG peer during rotation: {}", e);
let _ = state.client_registry.write().await.rotate_key(
client_id,
old_noise_pub,
old_wg_pub.clone(),
);
if let Some(ref old_key) = old_wg_pub {
let rollback_peer_config = crate::wireguard::WgPeerConfig {
public_key: old_key.clone(),
preshared_key: None,
allowed_ips: vec![format!("{}/32", assigned_ip.as_str())],
endpoint: None,
persistent_keepalive: Some(25),
};
if let Err(rollback_err) = self.add_wg_peer(rollback_peer_config).await {
warn!("Failed to restore old WG peer after rotation failure for {}: {}", client_id, rollback_err);
}
}
return Err(anyhow::anyhow!("Failed to register new WG peer during rotation for {}: {}", client_id, e));
}
}
// Get updated entry for the config bundle
let entry_json = self.get_registered_client(client_id).await?;
let smartvpn_server_url = format!("wss://{}",
state.config.server_endpoint.as_deref()
.unwrap_or(&state.config.listen_addr)
+76 -23
View File
@@ -215,6 +215,8 @@ pub enum WgCommand {
struct PeerState {
tunn: Tunn,
public_key_b64: String,
/// Registered SmartVPN client ID when this peer belongs to a registry entry.
client_id: Option<String>,
allowed_ips: Vec<AllowedIp>,
endpoint: Option<SocketAddr>,
#[allow(dead_code)]
@@ -237,6 +239,28 @@ impl PeerState {
}
}
fn synthetic_wg_client_id(pubkey: &str) -> String {
format!("wg-{}", &pubkey[..8.min(pubkey.len())])
}
fn peer_client_id(peer: &PeerState) -> String {
peer.client_id
.clone()
.unwrap_or_else(|| synthetic_wg_client_id(&peer.public_key_b64))
}
async fn registered_peer_info(
state: &Arc<ServerState>,
pubkey: &str,
) -> Option<(String, bool, bool)> {
let registry = state.client_registry.read().await;
registry.get_by_wg_key(pubkey).map(|entry| (
entry.client_id.clone(),
entry.use_host_ip.unwrap_or(false),
entry.is_enabled() && !entry.is_expired(),
))
}
fn add_peer_to_loop(
peers: &mut Vec<PeerState>,
@@ -281,6 +305,7 @@ fn add_peer_to_loop(
peers.push(PeerState {
tunn,
public_key_b64: config.public_key.clone(),
client_id: None,
allowed_ips,
endpoint,
persistent_keepalive: config.persistent_keepalive,
@@ -344,7 +369,7 @@ fn wg_timestamp_now() -> String {
/// Returns the VPN IP.
async fn register_wg_peer(
state: &Arc<ServerState>,
peer: &PeerState,
peer: &mut PeerState,
wg_return_tx: &mpsc::Sender<(String, Vec<u8>)>,
) -> Result<Option<Ipv4Addr>> {
let vpn_ip = match extract_peer_vpn_ip(&peer.allowed_ips) {
@@ -356,12 +381,24 @@ async fn register_wg_peer(
}
};
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
let (client_id, use_host_ip) = match registered_peer_info(state, &peer.public_key_b64).await {
Some((client_id, use_host_ip, true)) => (client_id, use_host_ip),
Some((client_id, _, false)) => {
warn!("WG peer {} maps to disabled or expired client {}, skipping registration", peer.public_key_b64, client_id);
return Ok(None);
}
None => (synthetic_wg_client_id(&peer.public_key_b64), false),
};
peer.client_id = Some(client_id.clone());
// Reserve IP in the pool
if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) {
warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e);
return Ok(None);
return Err(e);
}
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
routing_table.write().await.insert(vpn_ip, use_host_ip);
}
// Create per-peer return channel and register in tun_routes
@@ -393,7 +430,7 @@ async fn connect_wg_peer(
peer: &PeerState,
vpn_ip: Ipv4Addr,
) {
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
let client_id = peer_client_id(peer);
let client_info = ClientInfo {
client_id: client_id.clone(),
assigned_ip: vpn_ip.to_string(),
@@ -426,24 +463,27 @@ async fn connect_wg_peer(
/// Remove a WG peer from state.clients (disconnect without unregistering).
async fn disconnect_wg_peer(
state: &Arc<ServerState>,
pubkey: &str,
peer: &PeerState,
) {
let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
let client_id = peer_client_id(peer);
if state.clients.write().await.remove(&client_id).is_some() {
info!("WG peer {} disconnected (removed from active clients)", pubkey);
info!("WG peer {} disconnected (removed from active clients)", peer.public_key_b64);
}
}
/// Unregister a WG peer from ServerState.
async fn unregister_wg_peer(
state: &Arc<ServerState>,
pubkey: &str,
peer: &PeerState,
vpn_ip: Option<Ipv4Addr>,
) {
let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
let client_id = peer_client_id(peer);
if let Some(ip) = vpn_ip {
state.tun_routes.write().await.remove(&ip);
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
routing_table.write().await.remove(&ip);
}
state.ip_pool.lock().await.release(&ip);
}
state.clients.write().await.remove(&client_id);
@@ -501,6 +541,7 @@ pub async fn run_wg_listener(
peers.push(PeerState {
tunn,
public_key_b64: peer_config.public_key.clone(),
client_id: None,
allowed_ips,
endpoint,
persistent_keepalive: peer_config.persistent_keepalive,
@@ -523,9 +564,13 @@ pub async fn run_wg_listener(
// Register initial peers in ServerState (IP reservation + tun_routes only, NOT state.clients)
let mut peer_vpn_ips: HashMap<String, Ipv4Addr> = HashMap::new();
for peer in peers.iter_mut() {
if let Ok(Some(ip)) = register_wg_peer(&state, peer, &wg_return_tx).await {
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
peer.vpn_ip = Some(ip);
match register_wg_peer(&state, peer, &wg_return_tx).await {
Ok(Some(ip)) => {
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
peer.vpn_ip = Some(ip);
}
Ok(None) => {}
Err(e) => warn!("Failed to register initial WG peer {}: {}", peer.public_key_b64, e),
}
}
@@ -705,7 +750,7 @@ pub async fn run_wg_listener(
warn!("WG peer {} connection expired", peer.public_key_b64);
if peer.is_connected {
peer.is_connected = false;
disconnect_wg_peer(&state, &peer.public_key_b64).await;
disconnect_wg_peer(&state, peer).await;
}
}
TunnResult::Err(e) => {
@@ -733,7 +778,7 @@ pub async fn run_wg_listener(
}
// Only update ClientInfo if peer is connected (in state.clients)
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
let client_id = peer_client_id(peer);
if let Some(info) = clients.get_mut(&client_id) {
info.bytes_sent = peer.stats.bytes_sent;
info.bytes_received = peer.stats.bytes_received;
@@ -751,7 +796,7 @@ pub async fn run_wg_listener(
if now.duration_since(last) > std::time::Duration::from_secs(180) {
info!("WG peer {} idle timeout (180s), disconnecting", peer.public_key_b64);
peer.is_connected = false;
disconnect_wg_peer(&state, &peer.public_key_b64).await;
disconnect_wg_peer(&state, peer).await;
}
}
}
@@ -762,7 +807,12 @@ pub async fn run_wg_listener(
cmd = command_rx.recv() => {
match cmd {
Some(WgCommand::AddPeer(peer_config, resp_tx)) => {
let result = add_peer_to_loop(
if let Some((client_id, _, false)) = registered_peer_info(&state, &peer_config.public_key).await {
let _ = resp_tx.send(Err(anyhow!("WG peer maps to disabled or expired client: {}", client_id)));
continue;
}
let mut result = add_peer_to_loop(
&mut peers,
&peer_config,
&peer_index,
@@ -777,20 +827,23 @@ pub async fn run_wg_listener(
peer_vpn_ips.insert(peer_config.public_key.clone(), ip);
peer.vpn_ip = Some(ip);
}
Ok(None) => {}
Ok(None) => {
peers.retain(|p| p.public_key_b64 != peer_config.public_key);
result = Err(anyhow!("WG peer was not registered"));
}
Err(e) => {
warn!("Failed to register WG peer: {}", e);
peers.retain(|p| p.public_key_b64 != peer_config.public_key);
result = Err(e);
}
}
}
let _ = resp_tx.send(result);
}
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
let prev_len = peers.len();
peers.retain(|p| p.public_key_b64 != pubkey);
if peers.len() < prev_len {
if let Some(index) = peers.iter().position(|p| p.public_key_b64 == pubkey) {
let peer = peers.remove(index);
let vpn_ip = peer_vpn_ips.remove(&pubkey);
unregister_wg_peer(&state, &pubkey, vpn_ip).await;
unregister_wg_peer(&state, &peer, vpn_ip).await;
let _ = resp_tx.send(Ok(()));
} else {
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
@@ -814,7 +867,7 @@ pub async fn run_wg_listener(
// Cleanup: unregister all peers from ServerState
for peer in &peers {
let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied();
unregister_wg_peer(&state, &peer.public_key_b64, vpn_ip).await;
unregister_wg_peer(&state, peer, vpn_ip).await;
}
info!("WireGuard listener stopped");