use anyhow::Result; use serde::{Deserialize, Serialize}; use std::collections::HashMap; /// Per-client rate limiting configuration. #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct ClientRateLimit { pub bytes_per_sec: u64, pub burst_bytes: u64, } /// Per-client security settings — aligned with SmartProxy's IRouteSecurity pattern. #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct ClientSecurity { /// Source IPs/CIDRs the client may connect FROM (empty/None = any). pub ip_allow_list: Option>, /// Source IPs blocked — overrides ip_allow_list (deny wins). pub ip_block_list: Option>, /// Destination IPs/CIDRs the client may reach (empty/None = all). pub destination_allow_list: Option>, /// Destination IPs blocked — overrides destination_allow_list (deny wins). pub destination_block_list: Option>, /// Max concurrent connections from this client. pub max_connections: Option, /// Per-client rate limiting. pub rate_limit: Option, } /// A registered client entry — the server-side source of truth. #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct ClientEntry { /// Human-readable client ID (e.g. "alice-laptop"). pub client_id: String, /// Client's Noise IK public key (base64). pub public_key: String, /// Client's WireGuard public key (base64) — optional. pub wg_public_key: Option, /// Security settings (ACLs, rate limits). pub security: Option, /// Traffic priority (lower = higher priority, default: 100). pub priority: Option, /// Whether this client is enabled (default: true). pub enabled: Option, /// Tags assigned by the server admin — trusted, used for access control. pub server_defined_client_tags: Option>, /// Tags reported by the connecting client — informational only. pub client_defined_client_tags: Option>, /// Legacy tags field — treated as serverDefinedClientTags during deserialization. #[serde(default)] pub tags: Option>, /// Optional description. pub description: Option, /// Optional expiry (ISO 8601 timestamp). pub expires_at: Option, /// Assigned VPN IP address. pub assigned_ip: Option, } impl ClientEntry { /// Whether this client is considered enabled (defaults to true). pub fn is_enabled(&self) -> bool { self.enabled.unwrap_or(true) } /// Whether this client has expired based on current time. pub fn is_expired(&self) -> bool { if let Some(ref expires) = self.expires_at { if let Ok(expiry) = chrono::DateTime::parse_from_rfc3339(expires) { return chrono::Utc::now() > expiry; } } false } } /// In-memory client registry with dual-key indexing. pub struct ClientRegistry { /// Primary index: clientId → ClientEntry entries: HashMap, /// Secondary index: publicKey (base64) → clientId (fast lookup during handshake) key_index: HashMap, } impl ClientRegistry { pub fn new() -> Self { Self { entries: HashMap::new(), key_index: HashMap::new(), } } /// Build a registry from a list of client entries. pub fn from_entries(entries: Vec) -> Result { let mut registry = Self::new(); for mut entry in entries { // Migrate legacy `tags` → `serverDefinedClientTags` if entry.server_defined_client_tags.is_none() && entry.tags.is_some() { entry.server_defined_client_tags = entry.tags.take(); } registry.add(entry)?; } Ok(registry) } /// Add a client to the registry. pub fn add(&mut self, entry: ClientEntry) -> Result<()> { if self.entries.contains_key(&entry.client_id) { anyhow::bail!("Client '{}' already exists", entry.client_id); } if self.key_index.contains_key(&entry.public_key) { anyhow::bail!("Public key already registered to another client"); } self.key_index.insert(entry.public_key.clone(), entry.client_id.clone()); self.entries.insert(entry.client_id.clone(), entry); Ok(()) } /// Remove a client by ID. pub fn remove(&mut self, client_id: &str) -> Result { let entry = self.entries.remove(client_id) .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; self.key_index.remove(&entry.public_key); Ok(entry) } /// Get a client by ID. pub fn get_by_id(&self, client_id: &str) -> Option<&ClientEntry> { self.entries.get(client_id) } /// Get a client by public key (used during IK handshake verification). pub fn get_by_key(&self, public_key: &str) -> Option<&ClientEntry> { let client_id = self.key_index.get(public_key)?; self.entries.get(client_id) } /// Check if a public key is authorized (exists, enabled, not expired). pub fn is_authorized(&self, public_key: &str) -> bool { match self.get_by_key(public_key) { Some(entry) => entry.is_enabled() && !entry.is_expired(), None => false, } } /// Update a client entry. The closure receives a mutable reference to the entry. pub fn update(&mut self, client_id: &str, updater: F) -> Result<()> where F: FnOnce(&mut ClientEntry), { let entry = self.entries.get_mut(client_id) .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; let old_key = entry.public_key.clone(); updater(entry); // If public key changed, update the index if entry.public_key != old_key { self.key_index.remove(&old_key); self.key_index.insert(entry.public_key.clone(), client_id.to_string()); } Ok(()) } /// List all client entries. pub fn list(&self) -> Vec<&ClientEntry> { self.entries.values().collect() } /// Rotate a client's keys. Returns the updated entry. pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option) -> Result<()> { let entry = self.entries.get_mut(client_id) .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; // Update key index self.key_index.remove(&entry.public_key); entry.public_key = new_public_key.clone(); entry.wg_public_key = new_wg_public_key; self.key_index.insert(new_public_key, client_id.to_string()); Ok(()) } /// Number of registered clients. pub fn len(&self) -> usize { self.entries.len() } /// Whether the registry is empty. pub fn is_empty(&self) -> bool { self.entries.is_empty() } } #[cfg(test)] mod tests { use super::*; fn make_entry(id: &str, key: &str) -> ClientEntry { ClientEntry { client_id: id.to_string(), public_key: key.to_string(), wg_public_key: None, security: None, priority: None, enabled: None, server_defined_client_tags: None, client_defined_client_tags: None, tags: None, description: None, expires_at: None, assigned_ip: None, } } #[test] fn add_and_lookup() { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "key_alice")).unwrap(); assert!(reg.get_by_id("alice").is_some()); assert!(reg.get_by_key("key_alice").is_some()); assert_eq!(reg.get_by_key("key_alice").unwrap().client_id, "alice"); assert!(reg.get_by_id("bob").is_none()); assert!(reg.get_by_key("key_bob").is_none()); } #[test] fn reject_duplicate_id() { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "key1")).unwrap(); assert!(reg.add(make_entry("alice", "key2")).is_err()); } #[test] fn reject_duplicate_key() { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "same_key")).unwrap(); assert!(reg.add(make_entry("bob", "same_key")).is_err()); } #[test] fn remove_client() { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "key_alice")).unwrap(); assert_eq!(reg.len(), 1); let removed = reg.remove("alice").unwrap(); assert_eq!(removed.client_id, "alice"); assert_eq!(reg.len(), 0); assert!(reg.get_by_key("key_alice").is_none()); } #[test] fn remove_nonexistent_fails() { let mut reg = ClientRegistry::new(); assert!(reg.remove("ghost").is_err()); } #[test] fn is_authorized_enabled() { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "key_alice")).unwrap(); assert!(reg.is_authorized("key_alice")); // enabled by default } #[test] fn is_authorized_disabled() { let mut reg = ClientRegistry::new(); let mut entry = make_entry("alice", "key_alice"); entry.enabled = Some(false); reg.add(entry).unwrap(); assert!(!reg.is_authorized("key_alice")); } #[test] fn is_authorized_expired() { let mut reg = ClientRegistry::new(); let mut entry = make_entry("alice", "key_alice"); entry.expires_at = Some("2020-01-01T00:00:00Z".to_string()); reg.add(entry).unwrap(); assert!(!reg.is_authorized("key_alice")); } #[test] fn is_authorized_future_expiry() { let mut reg = ClientRegistry::new(); let mut entry = make_entry("alice", "key_alice"); entry.expires_at = Some("2099-01-01T00:00:00Z".to_string()); reg.add(entry).unwrap(); assert!(reg.is_authorized("key_alice")); } #[test] fn is_authorized_unknown_key() { let reg = ClientRegistry::new(); assert!(!reg.is_authorized("nonexistent")); } #[test] fn update_client() { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "key_alice")).unwrap(); reg.update("alice", |entry| { entry.description = Some("Updated".to_string()); entry.enabled = Some(false); }).unwrap(); let entry = reg.get_by_id("alice").unwrap(); assert_eq!(entry.description.as_deref(), Some("Updated")); assert!(!entry.is_enabled()); } #[test] fn update_nonexistent_fails() { let mut reg = ClientRegistry::new(); assert!(reg.update("ghost", |_| {}).is_err()); } #[test] fn rotate_key() { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "old_key")).unwrap(); reg.rotate_key("alice", "new_key".to_string(), None).unwrap(); assert!(reg.get_by_key("old_key").is_none()); assert!(reg.get_by_key("new_key").is_some()); assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key"); } #[test] fn from_entries() { let entries = vec![ make_entry("alice", "key_a"), make_entry("bob", "key_b"), ]; let reg = ClientRegistry::from_entries(entries).unwrap(); assert_eq!(reg.len(), 2); assert!(reg.get_by_key("key_a").is_some()); assert!(reg.get_by_key("key_b").is_some()); } #[test] fn list_clients() { let mut reg = ClientRegistry::new(); reg.add(make_entry("alice", "key_a")).unwrap(); reg.add(make_entry("bob", "key_b")).unwrap(); let list = reg.list(); assert_eq!(list.len(), 2); } #[test] fn security_with_rate_limit() { let mut entry = make_entry("alice", "key_alice"); entry.security = Some(ClientSecurity { ip_allow_list: Some(vec!["192.168.1.0/24".to_string()]), ip_block_list: Some(vec!["192.168.1.100".to_string()]), destination_allow_list: None, destination_block_list: None, max_connections: Some(5), rate_limit: Some(ClientRateLimit { bytes_per_sec: 1_000_000, burst_bytes: 2_000_000, }), }); let mut reg = ClientRegistry::new(); reg.add(entry).unwrap(); let e = reg.get_by_id("alice").unwrap(); let sec = e.security.as_ref().unwrap(); assert_eq!(sec.rate_limit.as_ref().unwrap().bytes_per_sec, 1_000_000); assert_eq!(sec.max_connections, Some(5)); } }