//! IP address pool management and Linux IP-forwarding / NAT helpers for a VPN server.
use anyhow::Result;
|
|
use std::collections::HashMap;
|
|
use std::net::Ipv4Addr;
|
|
use tracing::{info, warn};
|
|
|
|
/// IP pool manager for allocating VPN client addresses from a subnet.
///
/// Addresses are identified internally by their offset from `network`:
/// offset 0 is the network address, offset 1 is the gateway/server, and
/// client addresses start at offset 2.
#[derive(Debug)]
pub struct IpPool {
    /// Network address (e.g., 10.8.0.0).
    network: Ipv4Addr,
    /// Prefix length (e.g., 24).
    prefix_len: u8,
    /// Allocated IPs: IP -> client_id.
    allocated: HashMap<Ipv4Addr, String>,
    /// Offset at which the next allocation scan starts
    /// (initially 2, skipping .0 network and .1 gateway).
    next_offset: u32,
}
|
|
|
|
impl IpPool {
|
|
/// Create a new IP pool from a CIDR subnet string (e.g., "10.8.0.0/24").
|
|
pub fn new(subnet: &str) -> Result<Self> {
|
|
let parts: Vec<&str> = subnet.split('/').collect();
|
|
if parts.len() != 2 {
|
|
anyhow::bail!("Invalid subnet format: {}", subnet);
|
|
}
|
|
let network: Ipv4Addr = parts[0].parse()?;
|
|
let prefix_len: u8 = parts[1].parse()?;
|
|
if prefix_len > 30 {
|
|
anyhow::bail!("Prefix too long for VPN pool: /{}", prefix_len);
|
|
}
|
|
|
|
Ok(Self {
|
|
network,
|
|
prefix_len,
|
|
allocated: HashMap::new(),
|
|
next_offset: 2, // Skip .0 (network) and .1 (server/gateway)
|
|
})
|
|
}
|
|
|
|
/// Get the gateway/server address (first usable IP, e.g., 10.8.0.1).
|
|
pub fn gateway_addr(&self) -> Ipv4Addr {
|
|
let net_u32 = u32::from(self.network);
|
|
Ipv4Addr::from(net_u32 + 1)
|
|
}
|
|
|
|
/// Total number of usable client addresses in the pool.
|
|
pub fn capacity(&self) -> u32 {
|
|
let host_bits = 32 - self.prefix_len as u32;
|
|
let total = 1u32 << host_bits;
|
|
total.saturating_sub(3) // minus network, gateway, broadcast
|
|
}
|
|
|
|
/// Allocate an IP for a client. Returns the assigned IP.
|
|
pub fn allocate(&mut self, client_id: &str) -> Result<Ipv4Addr> {
|
|
let host_bits = 32 - self.prefix_len as u32;
|
|
let max_offset = (1u32 << host_bits) - 1; // broadcast offset
|
|
|
|
// Try to find a free IP starting from next_offset
|
|
let start = self.next_offset;
|
|
let mut offset = start;
|
|
loop {
|
|
if offset >= max_offset {
|
|
offset = 2; // wrap around
|
|
}
|
|
|
|
let ip = Ipv4Addr::from(u32::from(self.network) + offset);
|
|
if !self.allocated.contains_key(&ip) {
|
|
self.allocated.insert(ip, client_id.to_string());
|
|
self.next_offset = offset + 1;
|
|
info!("Allocated IP {} for client {}", ip, client_id);
|
|
return Ok(ip);
|
|
}
|
|
|
|
offset += 1;
|
|
if offset == start {
|
|
anyhow::bail!("IP pool exhausted");
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Release an IP back to the pool.
|
|
pub fn release(&mut self, ip: &Ipv4Addr) -> Option<String> {
|
|
let client_id = self.allocated.remove(ip);
|
|
if let Some(ref id) = client_id {
|
|
info!("Released IP {} from client {}", ip, id);
|
|
}
|
|
client_id
|
|
}
|
|
|
|
/// Reserve a specific IP for a client (e.g., WireGuard static IP from allowed_ips).
|
|
pub fn reserve(&mut self, ip: Ipv4Addr, client_id: &str) -> Result<()> {
|
|
if self.allocated.contains_key(&ip) {
|
|
anyhow::bail!("IP {} is already allocated", ip);
|
|
}
|
|
self.allocated.insert(ip, client_id.to_string());
|
|
info!("Reserved IP {} for client {}", ip, client_id);
|
|
Ok(())
|
|
}
|
|
|
|
/// Number of currently allocated IPs.
|
|
pub fn allocated_count(&self) -> usize {
|
|
self.allocated.len()
|
|
}
|
|
}
|
|
|
|
/// Enable IP forwarding on Linux.
|
|
pub fn enable_ip_forwarding() -> Result<()> {
|
|
std::fs::write("/proc/sys/net/ipv4/ip_forward", "1")?;
|
|
info!("Enabled IPv4 forwarding");
|
|
Ok(())
|
|
}
|
|
|
|
/// Set up NAT/masquerade using iptables for a given subnet and outbound interface.
|
|
pub async fn setup_nat(subnet: &str, interface: &str) -> Result<()> {
|
|
let output = tokio::process::Command::new("iptables")
|
|
.args([
|
|
"-t", "nat", "-A", "POSTROUTING",
|
|
"-s", subnet,
|
|
"-o", interface,
|
|
"-j", "MASQUERADE",
|
|
])
|
|
.output()
|
|
.await?;
|
|
|
|
if !output.status.success() {
|
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
anyhow::bail!("iptables NAT setup failed: {}", stderr);
|
|
}
|
|
|
|
info!("NAT masquerade set up for {} via {}", subnet, interface);
|
|
Ok(())
|
|
}
|
|
|
|
/// Remove NAT/masquerade rule.
|
|
pub async fn remove_nat(subnet: &str, interface: &str) -> Result<()> {
|
|
let output = tokio::process::Command::new("iptables")
|
|
.args([
|
|
"-t", "nat", "-D", "POSTROUTING",
|
|
"-s", subnet,
|
|
"-o", interface,
|
|
"-j", "MASQUERADE",
|
|
])
|
|
.output()
|
|
.await?;
|
|
|
|
if !output.status.success() {
|
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
warn!("iptables NAT removal failed (may not exist): {}", stderr);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Get the default outbound network interface name.
|
|
pub fn get_default_interface() -> Result<String> {
|
|
// Parse /proc/net/route for the default route
|
|
let content = std::fs::read_to_string("/proc/net/route")?;
|
|
for line in content.lines().skip(1) {
|
|
let fields: Vec<&str> = line.split_whitespace().collect();
|
|
if fields.len() >= 2 && fields[1] == "00000000" {
|
|
return Ok(fields[0].to_string());
|
|
}
|
|
}
|
|
anyhow::bail!("Could not determine default network interface")
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A /24 pool reports its gateway and capacity, hands out addresses
    /// sequentially starting at `.2`, and shrinks when an address is released.
    #[test]
    fn ip_pool_basic() {
        let mut pool = IpPool::new("10.8.0.0/24").unwrap();
        assert_eq!(pool.gateway_addr(), Ipv4Addr::new(10, 8, 0, 1));
        assert_eq!(pool.capacity(), 253); // 256 - 3 (net, gw, broadcast)

        let expected = [
            ("client1", Ipv4Addr::new(10, 8, 0, 2)),
            ("client2", Ipv4Addr::new(10, 8, 0, 3)),
        ];
        let mut handed_out = Vec::new();
        for &(client, want) in expected.iter() {
            let got = pool.allocate(client).unwrap();
            assert_eq!(got, want);
            handed_out.push(got);
        }
        assert_eq!(pool.allocated_count(), 2);

        pool.release(&handed_out[0]);
        assert_eq!(pool.allocated_count(), 1);
    }

    /// A /30 pool holds exactly one client address and then reports exhaustion.
    #[test]
    fn ip_pool_small_subnet() {
        // /30 = 4 addresses: .0 net, .1 gw, .2 client, .3 broadcast
        let mut pool = IpPool::new("192.168.1.0/30").unwrap();
        assert_eq!(pool.capacity(), 1);

        let only_client = pool.allocate("client1").unwrap();
        assert_eq!(only_client, Ipv4Addr::new(192, 168, 1, 2));

        // Pool should be exhausted
        let second_attempt = pool.allocate("client2");
        assert!(second_attempt.is_err());
    }

    /// Malformed subnets and prefixes longer than /30 are rejected.
    #[test]
    fn ip_pool_invalid_subnet() {
        for bad_subnet in ["invalid", "10.8.0.0/31"].iter() {
            assert!(IpPool::new(bad_subnet).is_err());
        }
    }
}
|