use anyhow::Result; use std::collections::HashMap; use std::net::Ipv4Addr; use tracing::{info, warn}; /// IP pool manager for allocating VPN client addresses from a subnet. pub struct IpPool { /// Network address (e.g., 10.8.0.0) network: Ipv4Addr, /// Prefix length (e.g., 24) prefix_len: u8, /// Allocated IPs: IP -> client_id allocated: HashMap, /// Next candidate offset (skipping .0 network and .1 gateway) next_offset: u32, } impl IpPool { /// Create a new IP pool from a CIDR subnet string (e.g., "10.8.0.0/24"). pub fn new(subnet: &str) -> Result { let parts: Vec<&str> = subnet.split('/').collect(); if parts.len() != 2 { anyhow::bail!("Invalid subnet format: {}", subnet); } let network: Ipv4Addr = parts[0].parse()?; let prefix_len: u8 = parts[1].parse()?; if prefix_len > 30 { anyhow::bail!("Prefix too long for VPN pool: /{}", prefix_len); } Ok(Self { network, prefix_len, allocated: HashMap::new(), next_offset: 2, // Skip .0 (network) and .1 (server/gateway) }) } /// Get the gateway/server address (first usable IP, e.g., 10.8.0.1). pub fn gateway_addr(&self) -> Ipv4Addr { let net_u32 = u32::from(self.network); Ipv4Addr::from(net_u32 + 1) } /// Total number of usable client addresses in the pool. pub fn capacity(&self) -> u32 { let host_bits = 32 - self.prefix_len as u32; let total = 1u32 << host_bits; total.saturating_sub(3) // minus network, gateway, broadcast } /// Allocate an IP for a client. Returns the assigned IP. pub fn allocate(&mut self, client_id: &str) -> Result { let host_bits = 32 - self.prefix_len as u32; let max_offset = (1u32 << host_bits) - 1; // broadcast offset // Try to find a free IP starting from next_offset let start = self.next_offset; let mut offset = start; loop { if offset >= max_offset { offset = 2; // wrap around } let ip = Ipv4Addr::from(u32::from(self.network) + offset); if !self.allocated.contains_key(&ip) { self.allocated.insert(ip, client_id.to_string()); self.next_offset = offset + 1; info!("Allocated IP {} for client {}", ip, client_id); return Ok(ip); } offset += 1; if offset == start { anyhow::bail!("IP pool exhausted"); } } } /// Release an IP back to the pool. pub fn release(&mut self, ip: &Ipv4Addr) -> Option { let client_id = self.allocated.remove(ip); if let Some(ref id) = client_id { info!("Released IP {} from client {}", ip, id); } client_id } /// Reserve a specific IP for a client (e.g., WireGuard static IP from allowed_ips). pub fn reserve(&mut self, ip: Ipv4Addr, client_id: &str) -> Result<()> { if self.allocated.contains_key(&ip) { anyhow::bail!("IP {} is already allocated", ip); } self.allocated.insert(ip, client_id.to_string()); info!("Reserved IP {} for client {}", ip, client_id); Ok(()) } /// Number of currently allocated IPs. pub fn allocated_count(&self) -> usize { self.allocated.len() } } /// Enable IP forwarding on Linux. pub fn enable_ip_forwarding() -> Result<()> { std::fs::write("/proc/sys/net/ipv4/ip_forward", "1")?; info!("Enabled IPv4 forwarding"); Ok(()) } /// Set up NAT/masquerade using iptables for a given subnet and outbound interface. pub async fn setup_nat(subnet: &str, interface: &str) -> Result<()> { let output = tokio::process::Command::new("iptables") .args([ "-t", "nat", "-A", "POSTROUTING", "-s", subnet, "-o", interface, "-j", "MASQUERADE", ]) .output() .await?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); anyhow::bail!("iptables NAT setup failed: {}", stderr); } info!("NAT masquerade set up for {} via {}", subnet, interface); Ok(()) } /// Remove NAT/masquerade rule. pub async fn remove_nat(subnet: &str, interface: &str) -> Result<()> { let output = tokio::process::Command::new("iptables") .args([ "-t", "nat", "-D", "POSTROUTING", "-s", subnet, "-o", interface, "-j", "MASQUERADE", ]) .output() .await?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); warn!("iptables NAT removal failed (may not exist): {}", stderr); } Ok(()) } /// Get the default outbound network interface name. pub fn get_default_interface() -> Result { // Parse /proc/net/route for the default route let content = std::fs::read_to_string("/proc/net/route")?; for line in content.lines().skip(1) { let fields: Vec<&str> = line.split_whitespace().collect(); if fields.len() >= 2 && fields[1] == "00000000" { return Ok(fields[0].to_string()); } } anyhow::bail!("Could not determine default network interface") } #[cfg(test)] mod tests { use super::*; #[test] fn ip_pool_basic() { let mut pool = IpPool::new("10.8.0.0/24").unwrap(); assert_eq!(pool.gateway_addr(), Ipv4Addr::new(10, 8, 0, 1)); assert_eq!(pool.capacity(), 253); // 256 - 3 (net, gw, broadcast) let ip1 = pool.allocate("client1").unwrap(); assert_eq!(ip1, Ipv4Addr::new(10, 8, 0, 2)); let ip2 = pool.allocate("client2").unwrap(); assert_eq!(ip2, Ipv4Addr::new(10, 8, 0, 3)); assert_eq!(pool.allocated_count(), 2); pool.release(&ip1); assert_eq!(pool.allocated_count(), 1); } #[test] fn ip_pool_small_subnet() { let mut pool = IpPool::new("192.168.1.0/30").unwrap(); // /30 = 4 addresses: .0 net, .1 gw, .2 client, .3 broadcast assert_eq!(pool.capacity(), 1); let ip = pool.allocate("client1").unwrap(); assert_eq!(ip, Ipv4Addr::new(192, 168, 1, 2)); // Pool should be exhausted assert!(pool.allocate("client2").is_err()); } #[test] fn ip_pool_invalid_subnet() { assert!(IpPool::new("invalid").is_err()); assert!(IpPool::new("10.8.0.0/31").is_err()); } }