fix(rust-wireguard): keep WireGuard peer registration and client state in sync
This commit is contained in:
@@ -3,6 +3,15 @@
|
|||||||
## Pending
|
## Pending
|
||||||
|
|
||||||
|
|
||||||
|
### Fixes
|
||||||
|
|
||||||
|
- keep WireGuard peer registration and client state in sync (rust-wireguard)
|
||||||
|
- index WireGuard public keys in the client registry for duplicate detection and direct lookup
|
||||||
|
- skip inactive clients when loading or adding WireGuard peers and fail fast when peer registration cannot be completed
|
||||||
|
- make IP reservation idempotent for the same client and avoid releasing WireGuard-assigned IPs on disconnect
|
||||||
|
- roll back client registry and peer state when WireGuard peer creation or key rotation fails
|
||||||
|
- update hybrid routing entries when registered client networking flags change
|
||||||
|
|
||||||
## 2026-05-12 - 1.19.3
|
## 2026-05-12 - 1.19.3
|
||||||
|
|
||||||
### Fixes
|
### Fixes
|
||||||
|
|||||||
@@ -98,6 +98,8 @@ pub struct ClientRegistry {
|
|||||||
entries: HashMap<String, ClientEntry>,
|
entries: HashMap<String, ClientEntry>,
|
||||||
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
|
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
|
||||||
key_index: HashMap<String, String>,
|
key_index: HashMap<String, String>,
|
||||||
|
/// WireGuard public key → clientId (fast lookup during WG handshakes)
|
||||||
|
wg_key_index: HashMap<String, String>,
|
||||||
/// Tertiary index: assignedIp → clientId (fast lookup during NAT destination policy)
|
/// Tertiary index: assignedIp → clientId (fast lookup during NAT destination policy)
|
||||||
ip_index: HashMap<String, String>,
|
ip_index: HashMap<String, String>,
|
||||||
}
|
}
|
||||||
@@ -107,6 +109,7 @@ impl ClientRegistry {
|
|||||||
Self {
|
Self {
|
||||||
entries: HashMap::new(),
|
entries: HashMap::new(),
|
||||||
key_index: HashMap::new(),
|
key_index: HashMap::new(),
|
||||||
|
wg_key_index: HashMap::new(),
|
||||||
ip_index: HashMap::new(),
|
ip_index: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -132,6 +135,12 @@ impl ClientRegistry {
|
|||||||
if self.key_index.contains_key(&entry.public_key) {
|
if self.key_index.contains_key(&entry.public_key) {
|
||||||
anyhow::bail!("Public key already registered to another client");
|
anyhow::bail!("Public key already registered to another client");
|
||||||
}
|
}
|
||||||
|
if let Some(ref wg_key) = entry.wg_public_key {
|
||||||
|
if self.wg_key_index.contains_key(wg_key) {
|
||||||
|
anyhow::bail!("WireGuard public key already registered to another client");
|
||||||
|
}
|
||||||
|
self.wg_key_index.insert(wg_key.clone(), entry.client_id.clone());
|
||||||
|
}
|
||||||
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
|
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
|
||||||
if let Some(ref ip) = entry.assigned_ip {
|
if let Some(ref ip) = entry.assigned_ip {
|
||||||
self.ip_index.insert(ip.clone(), entry.client_id.clone());
|
self.ip_index.insert(ip.clone(), entry.client_id.clone());
|
||||||
@@ -145,6 +154,9 @@ impl ClientRegistry {
|
|||||||
let entry = self.entries.remove(client_id)
|
let entry = self.entries.remove(client_id)
|
||||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||||
self.key_index.remove(&entry.public_key);
|
self.key_index.remove(&entry.public_key);
|
||||||
|
if let Some(ref wg_key) = entry.wg_public_key {
|
||||||
|
self.wg_key_index.remove(wg_key);
|
||||||
|
}
|
||||||
if let Some(ref ip) = entry.assigned_ip {
|
if let Some(ref ip) = entry.assigned_ip {
|
||||||
self.ip_index.remove(ip);
|
self.ip_index.remove(ip);
|
||||||
}
|
}
|
||||||
@@ -162,6 +174,12 @@ impl ClientRegistry {
|
|||||||
self.entries.get(client_id)
|
self.entries.get(client_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get a client by WireGuard public key.
|
||||||
|
pub fn get_by_wg_key(&self, public_key: &str) -> Option<&ClientEntry> {
|
||||||
|
let client_id = self.wg_key_index.get(public_key)?;
|
||||||
|
self.entries.get(client_id)
|
||||||
|
}
|
||||||
|
|
||||||
/// Get a client by assigned IP (used for per-client destination policy in NAT engine).
|
/// Get a client by assigned IP (used for per-client destination policy in NAT engine).
|
||||||
pub fn get_by_assigned_ip(&self, ip: &str) -> Option<&ClientEntry> {
|
pub fn get_by_assigned_ip(&self, ip: &str) -> Option<&ClientEntry> {
|
||||||
let client_id = self.ip_index.get(ip)?;
|
let client_id = self.ip_index.get(ip)?;
|
||||||
@@ -184,6 +202,7 @@ impl ClientRegistry {
|
|||||||
let entry = self.entries.get_mut(client_id)
|
let entry = self.entries.get_mut(client_id)
|
||||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||||
let old_key = entry.public_key.clone();
|
let old_key = entry.public_key.clone();
|
||||||
|
let old_wg_key = entry.wg_public_key.clone();
|
||||||
let old_ip = entry.assigned_ip.clone();
|
let old_ip = entry.assigned_ip.clone();
|
||||||
updater(entry);
|
updater(entry);
|
||||||
// If public key changed, update the key index
|
// If public key changed, update the key index
|
||||||
@@ -191,6 +210,15 @@ impl ClientRegistry {
|
|||||||
self.key_index.remove(&old_key);
|
self.key_index.remove(&old_key);
|
||||||
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
|
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
|
||||||
}
|
}
|
||||||
|
// If WireGuard public key changed, update the WG key index.
|
||||||
|
if entry.wg_public_key != old_wg_key {
|
||||||
|
if let Some(ref old) = old_wg_key {
|
||||||
|
self.wg_key_index.remove(old);
|
||||||
|
}
|
||||||
|
if let Some(ref new_key) = entry.wg_public_key {
|
||||||
|
self.wg_key_index.insert(new_key.clone(), client_id.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
// If assigned IP changed, update the IP index
|
// If assigned IP changed, update the IP index
|
||||||
if entry.assigned_ip != old_ip {
|
if entry.assigned_ip != old_ip {
|
||||||
if let Some(ref old) = old_ip {
|
if let Some(ref old) = old_ip {
|
||||||
@@ -210,13 +238,32 @@ impl ClientRegistry {
|
|||||||
|
|
||||||
/// Rotate a client's keys. Returns the updated entry.
|
/// Rotate a client's keys. Returns the updated entry.
|
||||||
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
|
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
|
||||||
|
if let Some(existing_client_id) = self.key_index.get(&new_public_key) {
|
||||||
|
if existing_client_id != client_id {
|
||||||
|
anyhow::bail!("Public key already registered to another client");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(ref new_wg_key) = new_wg_public_key {
|
||||||
|
if let Some(existing_client_id) = self.wg_key_index.get(new_wg_key) {
|
||||||
|
if existing_client_id != client_id {
|
||||||
|
anyhow::bail!("WireGuard public key already registered to another client");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let entry = self.entries.get_mut(client_id)
|
let entry = self.entries.get_mut(client_id)
|
||||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
||||||
// Update key index
|
// Update key index
|
||||||
self.key_index.remove(&entry.public_key);
|
self.key_index.remove(&entry.public_key);
|
||||||
|
if let Some(ref old_wg_key) = entry.wg_public_key {
|
||||||
|
self.wg_key_index.remove(old_wg_key);
|
||||||
|
}
|
||||||
entry.public_key = new_public_key.clone();
|
entry.public_key = new_public_key.clone();
|
||||||
entry.wg_public_key = new_wg_public_key;
|
entry.wg_public_key = new_wg_public_key;
|
||||||
self.key_index.insert(new_public_key, client_id.to_string());
|
self.key_index.insert(new_public_key, client_id.to_string());
|
||||||
|
if let Some(ref wg_key) = entry.wg_public_key {
|
||||||
|
self.wg_key_index.insert(wg_key.clone(), client_id.to_string());
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -367,13 +414,36 @@ mod tests {
|
|||||||
let mut reg = ClientRegistry::new();
|
let mut reg = ClientRegistry::new();
|
||||||
reg.add(make_entry("alice", "old_key")).unwrap();
|
reg.add(make_entry("alice", "old_key")).unwrap();
|
||||||
|
|
||||||
reg.rotate_key("alice", "new_key".to_string(), None).unwrap();
|
reg.rotate_key("alice", "new_key".to_string(), Some("new_wg_key".to_string())).unwrap();
|
||||||
|
|
||||||
assert!(reg.get_by_key("old_key").is_none());
|
assert!(reg.get_by_key("old_key").is_none());
|
||||||
assert!(reg.get_by_key("new_key").is_some());
|
assert!(reg.get_by_key("new_key").is_some());
|
||||||
|
assert!(reg.get_by_wg_key("new_wg_key").is_some());
|
||||||
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
|
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lookup_by_wireguard_key() {
|
||||||
|
let mut reg = ClientRegistry::new();
|
||||||
|
let mut entry = make_entry("alice", "key_alice");
|
||||||
|
entry.wg_public_key = Some("wg_key_alice".to_string());
|
||||||
|
reg.add(entry).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(reg.get_by_wg_key("wg_key_alice").unwrap().client_id, "alice");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reject_duplicate_wireguard_key() {
|
||||||
|
let mut reg = ClientRegistry::new();
|
||||||
|
let mut alice = make_entry("alice", "key_alice");
|
||||||
|
alice.wg_public_key = Some("same_wg_key".to_string());
|
||||||
|
let mut bob = make_entry("bob", "key_bob");
|
||||||
|
bob.wg_public_key = Some("same_wg_key".to_string());
|
||||||
|
|
||||||
|
reg.add(alice).unwrap();
|
||||||
|
assert!(reg.add(bob).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn from_entries() {
|
fn from_entries() {
|
||||||
let entries = vec![
|
let entries = vec![
|
||||||
|
|||||||
+14
-1
@@ -124,7 +124,10 @@ impl IpPool {
|
|||||||
|
|
||||||
/// Reserve a specific IP for a client (e.g., WireGuard static IP from allowed_ips).
|
/// Reserve a specific IP for a client (e.g., WireGuard static IP from allowed_ips).
|
||||||
pub fn reserve(&mut self, ip: Ipv4Addr, client_id: &str) -> Result<()> {
|
pub fn reserve(&mut self, ip: Ipv4Addr, client_id: &str) -> Result<()> {
|
||||||
if self.allocated.contains_key(&ip) {
|
if let Some(existing_client_id) = self.allocated.get(&ip) {
|
||||||
|
if existing_client_id == client_id {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
anyhow::bail!("IP {} is already allocated", ip);
|
anyhow::bail!("IP {} is already allocated", ip);
|
||||||
}
|
}
|
||||||
self.allocated.insert(ip, client_id.to_string());
|
self.allocated.insert(ip, client_id.to_string());
|
||||||
@@ -233,6 +236,16 @@ mod tests {
|
|||||||
assert!(pool.allocate("client2").is_err());
|
assert!(pool.allocate("client2").is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ip_pool_reserve_is_idempotent_for_same_client() {
|
||||||
|
let mut pool = IpPool::new("10.9.0.0/24").unwrap();
|
||||||
|
let ip = pool.allocate("alice").unwrap();
|
||||||
|
|
||||||
|
pool.reserve(ip, "alice").unwrap();
|
||||||
|
assert_eq!(pool.allocated_count(), 1);
|
||||||
|
assert!(pool.reserve(ip, "bob").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ip_pool_invalid_subnet() {
|
fn ip_pool_invalid_subnet() {
|
||||||
assert!(IpPool::new("invalid").is_err());
|
assert!(IpPool::new("invalid").is_err());
|
||||||
|
|||||||
+142
-24
@@ -584,6 +584,9 @@ impl VpnServer {
|
|||||||
if self.wg_command_tx.is_some() {
|
if self.wg_command_tx.is_some() {
|
||||||
let registry = state.client_registry.read().await;
|
let registry = state.client_registry.read().await;
|
||||||
for entry in registry.list() {
|
for entry in registry.list() {
|
||||||
|
if !entry.is_enabled() || entry.is_expired() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) {
|
if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) {
|
||||||
let peer_config = crate::wireguard::WgPeerConfig {
|
let peer_config = crate::wireguard::WgPeerConfig {
|
||||||
public_key: wg_key.clone(),
|
public_key: wg_key.clone(),
|
||||||
@@ -725,8 +728,10 @@ impl VpnServer {
|
|||||||
if let Some(ref state) = self.state {
|
if let Some(ref state) = self.state {
|
||||||
let mut clients = state.clients.write().await;
|
let mut clients = state.clients.write().await;
|
||||||
if let Some(client) = clients.remove(client_id) {
|
if let Some(client) = clients.remove(client_id) {
|
||||||
|
if client.transport_type != "wireguard" {
|
||||||
let ip: Ipv4Addr = client.assigned_ip.parse()?;
|
let ip: Ipv4Addr = client.assigned_ip.parse()?;
|
||||||
state.ip_pool.lock().await.release(&ip);
|
state.ip_pool.lock().await.release(&ip);
|
||||||
|
}
|
||||||
state.rate_limiters.lock().await.remove(client_id);
|
state.rate_limiters.lock().await.remove(client_id);
|
||||||
info!("Client {} disconnected", client_id);
|
info!("Client {} disconnected", client_id);
|
||||||
}
|
}
|
||||||
@@ -890,7 +895,9 @@ impl VpnServer {
|
|||||||
persistent_keepalive: Some(25),
|
persistent_keepalive: Some(25),
|
||||||
};
|
};
|
||||||
if let Err(e) = self.add_wg_peer(wg_peer_config).await {
|
if let Err(e) = self.add_wg_peer(wg_peer_config).await {
|
||||||
warn!("Failed to register WG peer for client {}: {}", client_id, e);
|
let _ = state.client_registry.write().await.remove(&client_id);
|
||||||
|
state.ip_pool.lock().await.release(&assigned_ip);
|
||||||
|
return Err(anyhow::anyhow!("Failed to register WG peer for client {}: {}", client_id, e));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -948,15 +955,28 @@ impl VpnServer {
|
|||||||
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
|
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
|
||||||
let state = self.state.as_ref()
|
let state = self.state.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||||
let entry = state.client_registry.write().await.remove(client_id)?;
|
|
||||||
|
let entry = {
|
||||||
|
let registry = state.client_registry.read().await;
|
||||||
|
registry.get_by_id(client_id)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
|
||||||
|
};
|
||||||
|
|
||||||
// Remove WG peer from running listener
|
// Remove WG peer from running listener
|
||||||
if self.wg_command_tx.is_some() {
|
if self.wg_command_tx.is_some() {
|
||||||
if let Some(ref wg_key) = entry.wg_public_key {
|
if let Some(ref wg_key) = entry.wg_public_key {
|
||||||
if let Err(e) = self.remove_wg_peer(wg_key).await {
|
if let Err(e) = self.remove_wg_peer(wg_key).await {
|
||||||
debug!("Failed to remove WG peer for client {}: {}", client_id, e);
|
if entry.is_enabled() && !entry.is_expired() {
|
||||||
|
return Err(anyhow::anyhow!("Failed to remove WG peer for client {}: {}", client_id, e));
|
||||||
|
}
|
||||||
|
debug!("Failed to remove inactive WG peer for client {}: {}", client_id, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state.client_registry.write().await.remove(client_id)?;
|
||||||
|
|
||||||
// Release the IP if assigned
|
// Release the IP if assigned
|
||||||
if let Some(ref ip_str) = entry.assigned_ip {
|
if let Some(ref ip_str) = entry.assigned_ip {
|
||||||
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
|
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
|
||||||
@@ -1013,7 +1033,37 @@ impl VpnServer {
|
|||||||
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
|
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
|
||||||
entry.expires_at = Some(expires.to_string());
|
entry.expires_at = Some(expires.to_string());
|
||||||
}
|
}
|
||||||
|
if let Some(use_host_ip) = update.get("useHostIp").and_then(|v| v.as_bool()) {
|
||||||
|
entry.use_host_ip = Some(use_host_ip);
|
||||||
|
}
|
||||||
|
if let Some(use_dhcp) = update.get("useDhcp").and_then(|v| v.as_bool()) {
|
||||||
|
entry.use_dhcp = Some(use_dhcp);
|
||||||
|
}
|
||||||
|
if let Some(static_ip) = update.get("staticIp").and_then(|v| v.as_str()) {
|
||||||
|
entry.static_ip = Some(static_ip.to_string());
|
||||||
|
}
|
||||||
|
if let Some(force_vlan) = update.get("forceVlan").and_then(|v| v.as_bool()) {
|
||||||
|
entry.force_vlan = Some(force_vlan);
|
||||||
|
}
|
||||||
|
if let Some(vlan_id) = update.get("vlanId").and_then(|v| v.as_u64()) {
|
||||||
|
entry.vlan_id = Some(vlan_id as u16);
|
||||||
|
}
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let updated_entry = {
|
||||||
|
let registry = state.client_registry.read().await;
|
||||||
|
registry.get_by_id(client_id).cloned()
|
||||||
|
};
|
||||||
|
if let Some(entry) = updated_entry {
|
||||||
|
if let Some(ref ip_str) = entry.assigned_ip {
|
||||||
|
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
|
||||||
|
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
|
||||||
|
routing_table.write().await.insert(ip, entry.use_host_ip.unwrap_or(false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1023,13 +1073,56 @@ impl VpnServer {
|
|||||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||||
state.client_registry.write().await.update(client_id, |entry| {
|
state.client_registry.write().await.update(client_id, |entry| {
|
||||||
entry.enabled = Some(true);
|
entry.enabled = Some(true);
|
||||||
})
|
})?;
|
||||||
|
|
||||||
|
let entry = {
|
||||||
|
let registry = state.client_registry.read().await;
|
||||||
|
registry.get_by_id(client_id)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.wg_command_tx.is_some() {
|
||||||
|
if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) {
|
||||||
|
let peer_config = crate::wireguard::WgPeerConfig {
|
||||||
|
public_key: wg_key.clone(),
|
||||||
|
preshared_key: None,
|
||||||
|
allowed_ips: vec![format!("{}/32", ip_str)],
|
||||||
|
endpoint: None,
|
||||||
|
persistent_keepalive: Some(25),
|
||||||
|
};
|
||||||
|
if let Err(e) = self.add_wg_peer(peer_config).await {
|
||||||
|
let _ = state.client_registry.write().await.update(client_id, |entry| {
|
||||||
|
entry.enabled = Some(false);
|
||||||
|
});
|
||||||
|
return Err(anyhow::anyhow!("Failed to register WG peer for enabled client {}: {}", client_id, e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Disable a registered client (also disconnects if connected).
|
/// Disable a registered client (also disconnects if connected).
|
||||||
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
|
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
|
||||||
let state = self.state.as_ref()
|
let state = self.state.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||||
|
|
||||||
|
let entry = {
|
||||||
|
let registry = state.client_registry.read().await;
|
||||||
|
registry.get_by_id(client_id).cloned()
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.wg_command_tx.is_some() {
|
||||||
|
if let Some(ref entry) = entry {
|
||||||
|
if let Some(ref wg_key) = entry.wg_public_key {
|
||||||
|
if let Err(e) = self.remove_wg_peer(wg_key).await {
|
||||||
|
debug!("Failed to remove WG peer for disabled client {}: {}", client_id, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
state.client_registry.write().await.update(client_id, |entry| {
|
state.client_registry.write().await.update(client_id, |entry| {
|
||||||
entry.enabled = Some(false);
|
entry.enabled = Some(false);
|
||||||
})?;
|
})?;
|
||||||
@@ -1043,17 +1136,33 @@ impl VpnServer {
|
|||||||
let state = self.state.as_ref()
|
let state = self.state.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
|
||||||
|
|
||||||
// Capture old WG key before rotation (needed to remove from WG listener)
|
// Capture old keys before rotation so listener and registry can be rolled back together.
|
||||||
let old_wg_pub = {
|
let old_entry = {
|
||||||
let registry = state.client_registry.read().await;
|
let registry = state.client_registry.read().await;
|
||||||
let entry = registry.get_by_id(client_id)
|
registry.get_by_id(client_id)
|
||||||
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
|
.cloned()
|
||||||
entry.wg_public_key.clone()
|
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
|
||||||
};
|
};
|
||||||
|
let old_noise_pub = old_entry.public_key.clone();
|
||||||
|
let old_wg_pub = old_entry.wg_public_key.clone();
|
||||||
|
let assigned_ip = old_entry.assigned_ip.clone().unwrap_or_else(|| "0.0.0.0".to_string());
|
||||||
|
let should_have_wg_peer = self.wg_command_tx.is_some()
|
||||||
|
&& old_entry.is_enabled()
|
||||||
|
&& !old_entry.is_expired()
|
||||||
|
&& old_wg_pub.is_some()
|
||||||
|
&& assigned_ip != "0.0.0.0";
|
||||||
|
|
||||||
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
|
||||||
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
|
||||||
|
|
||||||
|
if should_have_wg_peer {
|
||||||
|
if let Some(ref old_key) = old_wg_pub {
|
||||||
|
if let Err(e) = self.remove_wg_peer(old_key).await {
|
||||||
|
return Err(anyhow::anyhow!("Failed to remove old WG peer during rotation for {}: {}", client_id, e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
state.client_registry.write().await.rotate_key(
|
state.client_registry.write().await.rotate_key(
|
||||||
client_id,
|
client_id,
|
||||||
noise_pub.clone(),
|
noise_pub.clone(),
|
||||||
@@ -1063,30 +1172,39 @@ impl VpnServer {
|
|||||||
// Disconnect existing connection (old key is no longer valid)
|
// Disconnect existing connection (old key is no longer valid)
|
||||||
let _ = self.disconnect_client(client_id).await;
|
let _ = self.disconnect_client(client_id).await;
|
||||||
|
|
||||||
// Get updated entry for the config bundle
|
// Update WG listener with the new key. Roll back the registry if this fails.
|
||||||
let entry_json = self.get_registered_client(client_id).await?;
|
if should_have_wg_peer {
|
||||||
let assigned_ip = entry_json.get("assignedIp")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.unwrap_or("0.0.0.0");
|
|
||||||
|
|
||||||
// Update WG listener: remove old peer, add new peer
|
|
||||||
if self.wg_command_tx.is_some() {
|
|
||||||
if let Some(ref old_key) = old_wg_pub {
|
|
||||||
if let Err(e) = self.remove_wg_peer(old_key).await {
|
|
||||||
debug!("Failed to remove old WG peer during rotation: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let wg_peer_config = crate::wireguard::WgPeerConfig {
|
let wg_peer_config = crate::wireguard::WgPeerConfig {
|
||||||
public_key: wg_pub.clone(),
|
public_key: wg_pub.clone(),
|
||||||
preshared_key: None,
|
preshared_key: None,
|
||||||
allowed_ips: vec![format!("{}/32", assigned_ip)],
|
allowed_ips: vec![format!("{}/32", assigned_ip.as_str())],
|
||||||
endpoint: None,
|
endpoint: None,
|
||||||
persistent_keepalive: Some(25),
|
persistent_keepalive: Some(25),
|
||||||
};
|
};
|
||||||
if let Err(e) = self.add_wg_peer(wg_peer_config).await {
|
if let Err(e) = self.add_wg_peer(wg_peer_config).await {
|
||||||
warn!("Failed to register new WG peer during rotation: {}", e);
|
let _ = state.client_registry.write().await.rotate_key(
|
||||||
|
client_id,
|
||||||
|
old_noise_pub,
|
||||||
|
old_wg_pub.clone(),
|
||||||
|
);
|
||||||
|
if let Some(ref old_key) = old_wg_pub {
|
||||||
|
let rollback_peer_config = crate::wireguard::WgPeerConfig {
|
||||||
|
public_key: old_key.clone(),
|
||||||
|
preshared_key: None,
|
||||||
|
allowed_ips: vec![format!("{}/32", assigned_ip.as_str())],
|
||||||
|
endpoint: None,
|
||||||
|
persistent_keepalive: Some(25),
|
||||||
|
};
|
||||||
|
if let Err(rollback_err) = self.add_wg_peer(rollback_peer_config).await {
|
||||||
|
warn!("Failed to restore old WG peer after rotation failure for {}: {}", client_id, rollback_err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return Err(anyhow::anyhow!("Failed to register new WG peer during rotation for {}: {}", client_id, e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get updated entry for the config bundle
|
||||||
|
let entry_json = self.get_registered_client(client_id).await?;
|
||||||
|
|
||||||
let smartvpn_server_url = format!("wss://{}",
|
let smartvpn_server_url = format!("wss://{}",
|
||||||
state.config.server_endpoint.as_deref()
|
state.config.server_endpoint.as_deref()
|
||||||
|
|||||||
+74
-21
@@ -215,6 +215,8 @@ pub enum WgCommand {
|
|||||||
struct PeerState {
|
struct PeerState {
|
||||||
tunn: Tunn,
|
tunn: Tunn,
|
||||||
public_key_b64: String,
|
public_key_b64: String,
|
||||||
|
/// Registered SmartVPN client ID when this peer belongs to a registry entry.
|
||||||
|
client_id: Option<String>,
|
||||||
allowed_ips: Vec<AllowedIp>,
|
allowed_ips: Vec<AllowedIp>,
|
||||||
endpoint: Option<SocketAddr>,
|
endpoint: Option<SocketAddr>,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
@@ -237,6 +239,28 @@ impl PeerState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synthetic_wg_client_id(pubkey: &str) -> String {
|
||||||
|
format!("wg-{}", &pubkey[..8.min(pubkey.len())])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn peer_client_id(peer: &PeerState) -> String {
|
||||||
|
peer.client_id
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| synthetic_wg_client_id(&peer.public_key_b64))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn registered_peer_info(
|
||||||
|
state: &Arc<ServerState>,
|
||||||
|
pubkey: &str,
|
||||||
|
) -> Option<(String, bool, bool)> {
|
||||||
|
let registry = state.client_registry.read().await;
|
||||||
|
registry.get_by_wg_key(pubkey).map(|entry| (
|
||||||
|
entry.client_id.clone(),
|
||||||
|
entry.use_host_ip.unwrap_or(false),
|
||||||
|
entry.is_enabled() && !entry.is_expired(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
fn add_peer_to_loop(
|
fn add_peer_to_loop(
|
||||||
peers: &mut Vec<PeerState>,
|
peers: &mut Vec<PeerState>,
|
||||||
@@ -281,6 +305,7 @@ fn add_peer_to_loop(
|
|||||||
peers.push(PeerState {
|
peers.push(PeerState {
|
||||||
tunn,
|
tunn,
|
||||||
public_key_b64: config.public_key.clone(),
|
public_key_b64: config.public_key.clone(),
|
||||||
|
client_id: None,
|
||||||
allowed_ips,
|
allowed_ips,
|
||||||
endpoint,
|
endpoint,
|
||||||
persistent_keepalive: config.persistent_keepalive,
|
persistent_keepalive: config.persistent_keepalive,
|
||||||
@@ -344,7 +369,7 @@ fn wg_timestamp_now() -> String {
|
|||||||
/// Returns the VPN IP.
|
/// Returns the VPN IP.
|
||||||
async fn register_wg_peer(
|
async fn register_wg_peer(
|
||||||
state: &Arc<ServerState>,
|
state: &Arc<ServerState>,
|
||||||
peer: &PeerState,
|
peer: &mut PeerState,
|
||||||
wg_return_tx: &mpsc::Sender<(String, Vec<u8>)>,
|
wg_return_tx: &mpsc::Sender<(String, Vec<u8>)>,
|
||||||
) -> Result<Option<Ipv4Addr>> {
|
) -> Result<Option<Ipv4Addr>> {
|
||||||
let vpn_ip = match extract_peer_vpn_ip(&peer.allowed_ips) {
|
let vpn_ip = match extract_peer_vpn_ip(&peer.allowed_ips) {
|
||||||
@@ -356,12 +381,24 @@ async fn register_wg_peer(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
|
let (client_id, use_host_ip) = match registered_peer_info(state, &peer.public_key_b64).await {
|
||||||
|
Some((client_id, use_host_ip, true)) => (client_id, use_host_ip),
|
||||||
|
Some((client_id, _, false)) => {
|
||||||
|
warn!("WG peer {} maps to disabled or expired client {}, skipping registration", peer.public_key_b64, client_id);
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
None => (synthetic_wg_client_id(&peer.public_key_b64), false),
|
||||||
|
};
|
||||||
|
peer.client_id = Some(client_id.clone());
|
||||||
|
|
||||||
// Reserve IP in the pool
|
// Reserve IP in the pool
|
||||||
if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) {
|
if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) {
|
||||||
warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e);
|
warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e);
|
||||||
return Ok(None);
|
return Err(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
|
||||||
|
routing_table.write().await.insert(vpn_ip, use_host_ip);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create per-peer return channel and register in tun_routes
|
// Create per-peer return channel and register in tun_routes
|
||||||
@@ -393,7 +430,7 @@ async fn connect_wg_peer(
|
|||||||
peer: &PeerState,
|
peer: &PeerState,
|
||||||
vpn_ip: Ipv4Addr,
|
vpn_ip: Ipv4Addr,
|
||||||
) {
|
) {
|
||||||
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
|
let client_id = peer_client_id(peer);
|
||||||
let client_info = ClientInfo {
|
let client_info = ClientInfo {
|
||||||
client_id: client_id.clone(),
|
client_id: client_id.clone(),
|
||||||
assigned_ip: vpn_ip.to_string(),
|
assigned_ip: vpn_ip.to_string(),
|
||||||
@@ -426,24 +463,27 @@ async fn connect_wg_peer(
|
|||||||
/// Remove a WG peer from state.clients (disconnect without unregistering).
|
/// Remove a WG peer from state.clients (disconnect without unregistering).
|
||||||
async fn disconnect_wg_peer(
|
async fn disconnect_wg_peer(
|
||||||
state: &Arc<ServerState>,
|
state: &Arc<ServerState>,
|
||||||
pubkey: &str,
|
peer: &PeerState,
|
||||||
) {
|
) {
|
||||||
let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
|
let client_id = peer_client_id(peer);
|
||||||
if state.clients.write().await.remove(&client_id).is_some() {
|
if state.clients.write().await.remove(&client_id).is_some() {
|
||||||
info!("WG peer {} disconnected (removed from active clients)", pubkey);
|
info!("WG peer {} disconnected (removed from active clients)", peer.public_key_b64);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unregister a WG peer from ServerState.
|
/// Unregister a WG peer from ServerState.
|
||||||
async fn unregister_wg_peer(
|
async fn unregister_wg_peer(
|
||||||
state: &Arc<ServerState>,
|
state: &Arc<ServerState>,
|
||||||
pubkey: &str,
|
peer: &PeerState,
|
||||||
vpn_ip: Option<Ipv4Addr>,
|
vpn_ip: Option<Ipv4Addr>,
|
||||||
) {
|
) {
|
||||||
let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
|
let client_id = peer_client_id(peer);
|
||||||
|
|
||||||
if let Some(ip) = vpn_ip {
|
if let Some(ip) = vpn_ip {
|
||||||
state.tun_routes.write().await.remove(&ip);
|
state.tun_routes.write().await.remove(&ip);
|
||||||
|
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
|
||||||
|
routing_table.write().await.remove(&ip);
|
||||||
|
}
|
||||||
state.ip_pool.lock().await.release(&ip);
|
state.ip_pool.lock().await.release(&ip);
|
||||||
}
|
}
|
||||||
state.clients.write().await.remove(&client_id);
|
state.clients.write().await.remove(&client_id);
|
||||||
@@ -501,6 +541,7 @@ pub async fn run_wg_listener(
|
|||||||
peers.push(PeerState {
|
peers.push(PeerState {
|
||||||
tunn,
|
tunn,
|
||||||
public_key_b64: peer_config.public_key.clone(),
|
public_key_b64: peer_config.public_key.clone(),
|
||||||
|
client_id: None,
|
||||||
allowed_ips,
|
allowed_ips,
|
||||||
endpoint,
|
endpoint,
|
||||||
persistent_keepalive: peer_config.persistent_keepalive,
|
persistent_keepalive: peer_config.persistent_keepalive,
|
||||||
@@ -523,10 +564,14 @@ pub async fn run_wg_listener(
|
|||||||
// Register initial peers in ServerState (IP reservation + tun_routes only, NOT state.clients)
|
// Register initial peers in ServerState (IP reservation + tun_routes only, NOT state.clients)
|
||||||
let mut peer_vpn_ips: HashMap<String, Ipv4Addr> = HashMap::new();
|
let mut peer_vpn_ips: HashMap<String, Ipv4Addr> = HashMap::new();
|
||||||
for peer in peers.iter_mut() {
|
for peer in peers.iter_mut() {
|
||||||
if let Ok(Some(ip)) = register_wg_peer(&state, peer, &wg_return_tx).await {
|
match register_wg_peer(&state, peer, &wg_return_tx).await {
|
||||||
|
Ok(Some(ip)) => {
|
||||||
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
|
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
|
||||||
peer.vpn_ip = Some(ip);
|
peer.vpn_ip = Some(ip);
|
||||||
}
|
}
|
||||||
|
Ok(None) => {}
|
||||||
|
Err(e) => warn!("Failed to register initial WG peer {}: {}", peer.public_key_b64, e),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Buffers
|
// Buffers
|
||||||
@@ -705,7 +750,7 @@ pub async fn run_wg_listener(
|
|||||||
warn!("WG peer {} connection expired", peer.public_key_b64);
|
warn!("WG peer {} connection expired", peer.public_key_b64);
|
||||||
if peer.is_connected {
|
if peer.is_connected {
|
||||||
peer.is_connected = false;
|
peer.is_connected = false;
|
||||||
disconnect_wg_peer(&state, &peer.public_key_b64).await;
|
disconnect_wg_peer(&state, peer).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TunnResult::Err(e) => {
|
TunnResult::Err(e) => {
|
||||||
@@ -733,7 +778,7 @@ pub async fn run_wg_listener(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Only update ClientInfo if peer is connected (in state.clients)
|
// Only update ClientInfo if peer is connected (in state.clients)
|
||||||
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
|
let client_id = peer_client_id(peer);
|
||||||
if let Some(info) = clients.get_mut(&client_id) {
|
if let Some(info) = clients.get_mut(&client_id) {
|
||||||
info.bytes_sent = peer.stats.bytes_sent;
|
info.bytes_sent = peer.stats.bytes_sent;
|
||||||
info.bytes_received = peer.stats.bytes_received;
|
info.bytes_received = peer.stats.bytes_received;
|
||||||
@@ -751,7 +796,7 @@ pub async fn run_wg_listener(
|
|||||||
if now.duration_since(last) > std::time::Duration::from_secs(180) {
|
if now.duration_since(last) > std::time::Duration::from_secs(180) {
|
||||||
info!("WG peer {} idle timeout (180s), disconnecting", peer.public_key_b64);
|
info!("WG peer {} idle timeout (180s), disconnecting", peer.public_key_b64);
|
||||||
peer.is_connected = false;
|
peer.is_connected = false;
|
||||||
disconnect_wg_peer(&state, &peer.public_key_b64).await;
|
disconnect_wg_peer(&state, peer).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -762,7 +807,12 @@ pub async fn run_wg_listener(
|
|||||||
cmd = command_rx.recv() => {
|
cmd = command_rx.recv() => {
|
||||||
match cmd {
|
match cmd {
|
||||||
Some(WgCommand::AddPeer(peer_config, resp_tx)) => {
|
Some(WgCommand::AddPeer(peer_config, resp_tx)) => {
|
||||||
let result = add_peer_to_loop(
|
if let Some((client_id, _, false)) = registered_peer_info(&state, &peer_config.public_key).await {
|
||||||
|
let _ = resp_tx.send(Err(anyhow!("WG peer maps to disabled or expired client: {}", client_id)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = add_peer_to_loop(
|
||||||
&mut peers,
|
&mut peers,
|
||||||
&peer_config,
|
&peer_config,
|
||||||
&peer_index,
|
&peer_index,
|
||||||
@@ -777,20 +827,23 @@ pub async fn run_wg_listener(
|
|||||||
peer_vpn_ips.insert(peer_config.public_key.clone(), ip);
|
peer_vpn_ips.insert(peer_config.public_key.clone(), ip);
|
||||||
peer.vpn_ip = Some(ip);
|
peer.vpn_ip = Some(ip);
|
||||||
}
|
}
|
||||||
Ok(None) => {}
|
Ok(None) => {
|
||||||
|
peers.retain(|p| p.public_key_b64 != peer_config.public_key);
|
||||||
|
result = Err(anyhow!("WG peer was not registered"));
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Failed to register WG peer: {}", e);
|
peers.retain(|p| p.public_key_b64 != peer_config.public_key);
|
||||||
|
result = Err(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let _ = resp_tx.send(result);
|
let _ = resp_tx.send(result);
|
||||||
}
|
}
|
||||||
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
|
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
|
||||||
let prev_len = peers.len();
|
if let Some(index) = peers.iter().position(|p| p.public_key_b64 == pubkey) {
|
||||||
peers.retain(|p| p.public_key_b64 != pubkey);
|
let peer = peers.remove(index);
|
||||||
if peers.len() < prev_len {
|
|
||||||
let vpn_ip = peer_vpn_ips.remove(&pubkey);
|
let vpn_ip = peer_vpn_ips.remove(&pubkey);
|
||||||
unregister_wg_peer(&state, &pubkey, vpn_ip).await;
|
unregister_wg_peer(&state, &peer, vpn_ip).await;
|
||||||
let _ = resp_tx.send(Ok(()));
|
let _ = resp_tx.send(Ok(()));
|
||||||
} else {
|
} else {
|
||||||
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
|
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
|
||||||
@@ -814,7 +867,7 @@ pub async fn run_wg_listener(
|
|||||||
// Cleanup: unregister all peers from ServerState
|
// Cleanup: unregister all peers from ServerState
|
||||||
for peer in &peers {
|
for peer in &peers {
|
||||||
let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied();
|
let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied();
|
||||||
unregister_wg_peer(&state, &peer.public_key_b64, vpn_ip).await;
|
unregister_wg_peer(&state, peer, vpn_ip).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("WireGuard listener stopped");
|
info!("WireGuard listener stopped");
|
||||||
|
|||||||
Reference in New Issue
Block a user