use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use std::time::Instant; use anyhow::{anyhow, Result}; use base64::engine::general_purpose::STANDARD as BASE64; use base64::Engine; use boringtun::noise::rate_limiter::RateLimiter; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::{PublicKey, StaticSecret}; use rand::rngs::OsRng; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, RwLock}; use tracing::{debug, error, info, warn}; use crate::network; use crate::tunnel::extract_dst_ip; use crate::tunnel::{self, TunConfig}; // ============================================================================ // Constants // ============================================================================ const MAX_UDP_PACKET: usize = 65536; const WG_BUFFER_SIZE: usize = MAX_UDP_PACKET; /// Minimum dst buffer size for boringtun encapsulate/decapsulate const _MIN_DST_BUF: usize = 148; const TIMER_TICK_MS: u64 = 100; const DEFAULT_WG_PORT: u16 = 51820; const DEFAULT_TUN_ADDRESS: &str = "10.8.0.1"; const DEFAULT_TUN_NETMASK: &str = "255.255.255.0"; const DEFAULT_MTU: u16 = 1420; // ============================================================================ // Configuration types // ============================================================================ #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct WgPeerConfig { pub public_key: String, #[serde(default)] pub preshared_key: Option, pub allowed_ips: Vec, #[serde(default)] pub endpoint: Option, #[serde(default)] pub persistent_keepalive: Option, } #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] pub struct WgServerConfig { pub private_key: String, #[serde(default)] pub listen_port: Option, #[serde(default)] pub tun_address: Option, #[serde(default)] pub tun_netmask: 
Option, #[serde(default)] pub mtu: Option, pub peers: Vec, #[serde(default)] pub dns: Option>, #[serde(default)] pub enable_nat: Option, #[serde(default)] pub subnet: Option, } #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] pub struct WgClientConfig { pub private_key: String, pub address: String, #[serde(default)] pub address_prefix: Option, #[serde(default)] pub dns: Option>, #[serde(default)] pub mtu: Option, pub peer: WgPeerConfig, } // ============================================================================ // Stats types // ============================================================================ #[derive(Debug, Clone, Default, Serialize)] #[serde(rename_all = "camelCase")] pub struct WgPeerStats { pub bytes_sent: u64, pub bytes_received: u64, pub packets_sent: u64, pub packets_received: u64, pub last_handshake_time: Option, } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct WgPeerInfo { pub public_key: String, pub allowed_ips: Vec, pub endpoint: Option, pub persistent_keepalive: Option, #[serde(flatten)] pub stats: WgPeerStats, } #[derive(Debug, Clone, Default, Serialize)] #[serde(rename_all = "camelCase")] pub struct WgServerStats { pub total_bytes_sent: u64, pub total_bytes_received: u64, pub total_packets_sent: u64, pub total_packets_received: u64, pub active_peers: usize, pub uptime_seconds: f64, } // ============================================================================ // Key generation and parsing // ============================================================================ /// Generate a WireGuard-compatible X25519 keypair. /// Returns (public_key_base64, private_key_base64). 
pub fn generate_wg_keypair() -> (String, String) { let private = StaticSecret::random_from_rng(OsRng); let public = PublicKey::from(&private); let priv_b64 = BASE64.encode(private.to_bytes()); let pub_b64 = BASE64.encode(public.to_bytes()); (pub_b64, priv_b64) } fn parse_private_key(b64: &str) -> Result { let bytes = BASE64.decode(b64)?; if bytes.len() != 32 { return Err(anyhow!("Private key must be 32 bytes, got {}", bytes.len())); } let mut arr = [0u8; 32]; arr.copy_from_slice(&bytes); Ok(StaticSecret::from(arr)) } fn parse_public_key(b64: &str) -> Result { let bytes = BASE64.decode(b64)?; if bytes.len() != 32 { return Err(anyhow!("Public key must be 32 bytes, got {}", bytes.len())); } let mut arr = [0u8; 32]; arr.copy_from_slice(&bytes); Ok(PublicKey::from(arr)) } fn parse_preshared_key(b64: &str) -> Result<[u8; 32]> { let bytes = BASE64.decode(b64)?; if bytes.len() != 32 { return Err(anyhow!( "Preshared key must be 32 bytes, got {}", bytes.len() )); } let mut arr = [0u8; 32]; arr.copy_from_slice(&bytes); Ok(arr) } // ============================================================================ // AllowedIPs matching // ============================================================================ #[derive(Debug, Clone)] struct AllowedIp { addr: IpAddr, prefix_len: u8, } impl AllowedIp { fn parse(cidr: &str) -> Result { let parts: Vec<&str> = cidr.split('/').collect(); if parts.len() != 2 { return Err(anyhow!("Invalid CIDR: {}", cidr)); } let addr: IpAddr = parts[0].parse()?; let prefix_len: u8 = parts[1].parse()?; match addr { IpAddr::V4(_) if prefix_len > 32 => { return Err(anyhow!("IPv4 prefix length {} > 32", prefix_len)) } IpAddr::V6(_) if prefix_len > 128 => { return Err(anyhow!("IPv6 prefix length {} > 128", prefix_len)) } _ => {} } Ok(Self { addr, prefix_len }) } fn matches(&self, ip: IpAddr) -> bool { match (self.addr, ip) { (IpAddr::V4(net), IpAddr::V4(target)) => { if self.prefix_len == 0 { return true; } if self.prefix_len >= 32 { return net == target; 
} let mask = u32::MAX << (32 - self.prefix_len); (u32::from(net) & mask) == (u32::from(target) & mask) } (IpAddr::V6(net), IpAddr::V6(target)) => { if self.prefix_len == 0 { return true; } if self.prefix_len >= 128 { return net == target; } let net_bits = u128::from(net); let target_bits = u128::from(target); let mask = u128::MAX << (128 - self.prefix_len); (net_bits & mask) == (target_bits & mask) } _ => false, } } } // ============================================================================ // Dynamic peer management commands // ============================================================================ enum WgCommand { AddPeer(WgPeerConfig, oneshot::Sender>), RemovePeer(String, oneshot::Sender>), } // ============================================================================ // Internal peer state (owned by event loop) // ============================================================================ struct PeerState { tunn: Tunn, public_key_b64: String, allowed_ips: Vec, endpoint: Option, #[allow(dead_code)] persistent_keepalive: Option, stats: WgPeerStats, } impl PeerState { fn matches_dst(&self, dst_ip: IpAddr) -> bool { self.allowed_ips.iter().any(|aip| aip.matches(dst_ip)) } } // ============================================================================ // WgServer // ============================================================================ pub struct WgServer { shutdown_tx: Option>, command_tx: Option>, shared_stats: Arc>>, server_stats: Arc>, started_at: Option, listen_port: Option, } impl WgServer { pub fn new() -> Self { Self { shutdown_tx: None, command_tx: None, shared_stats: Arc::new(RwLock::new(HashMap::new())), server_stats: Arc::new(RwLock::new(WgServerStats::default())), started_at: None, listen_port: None, } } pub fn is_running(&self) -> bool { self.shutdown_tx.is_some() } pub async fn start(&mut self, config: WgServerConfig) -> Result<()> { if self.is_running() { return Err(anyhow!("WireGuard server is already running")); } let 
listen_port = config.listen_port.unwrap_or(DEFAULT_WG_PORT); let tun_address = config .tun_address .as_deref() .unwrap_or(DEFAULT_TUN_ADDRESS); let tun_netmask = config .tun_netmask .as_deref() .unwrap_or(DEFAULT_TUN_NETMASK); let mtu = config.mtu.unwrap_or(DEFAULT_MTU); // Parse server private key let server_private = parse_private_key(&config.private_key)?; let server_public = PublicKey::from(&server_private); // Create rate limiter for DDoS protection let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64)); // Build peer state let peer_index = AtomicU32::new(0); let mut peers: Vec = Vec::with_capacity(config.peers.len()); for peer_config in &config.peers { let peer_public = parse_public_key(&peer_config.public_key)?; let psk = match &peer_config.preshared_key { Some(k) => Some(parse_preshared_key(k)?), None => None, }; let idx = peer_index.fetch_add(1, Ordering::Relaxed); // Clone the private key for each Tunn (StaticSecret doesn't implement Clone, // so re-parse from config) let priv_copy = parse_private_key(&config.private_key)?; let tunn = Tunn::new( priv_copy, peer_public, psk, peer_config.persistent_keepalive, idx, Some(rate_limiter.clone()), ); let allowed_ips: Vec = peer_config .allowed_ips .iter() .map(|cidr| AllowedIp::parse(cidr)) .collect::>>()?; let endpoint = match &peer_config.endpoint { Some(ep) => Some(ep.parse::()?), None => None, }; peers.push(PeerState { tunn, public_key_b64: peer_config.public_key.clone(), allowed_ips, endpoint, persistent_keepalive: peer_config.persistent_keepalive, stats: WgPeerStats::default(), }); } // Create TUN device let tun_config = TunConfig { name: "wg0".to_string(), address: tun_address.parse()?, netmask: tun_netmask.parse()?, mtu, }; let tun_device = tunnel::create_tun(&tun_config)?; info!("WireGuard TUN device created: {}", tun_config.name); // Bind UDP socket let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", listen_port)).await?; info!("WireGuard server listening on UDP port {}", 
listen_port); // Enable IP forwarding and NAT if requested if config.enable_nat.unwrap_or(false) { network::enable_ip_forwarding()?; let subnet = config .subnet .as_deref() .unwrap_or("10.8.0.0/24"); let iface = network::get_default_interface()?; network::setup_nat(subnet, &iface).await?; info!("NAT enabled for subnet {} via {}", subnet, iface); } // Channels let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); let (command_tx, command_rx) = mpsc::channel::(32); let shared_stats = self.shared_stats.clone(); let server_stats = self.server_stats.clone(); let started_at = Instant::now(); // Initialize shared stats { let mut stats = shared_stats.write().await; for peer in &peers { stats.insert(peer.public_key_b64.clone(), WgPeerStats::default()); } } // Spawn the event loop tokio::spawn(async move { if let Err(e) = wg_server_loop( udp_socket, tun_device, peers, peer_index, rate_limiter, config.private_key.clone(), shared_stats, server_stats, started_at, shutdown_rx, command_rx, ) .await { error!("WireGuard server loop error: {}", e); } info!("WireGuard server loop exited"); }); self.shutdown_tx = Some(shutdown_tx); self.command_tx = Some(command_tx); self.started_at = Some(started_at); self.listen_port = Some(listen_port); Ok(()) } pub async fn stop(&mut self) -> Result<()> { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()); } self.command_tx = None; self.started_at = None; self.listen_port = None; info!("WireGuard server stopped"); Ok(()) } pub fn get_status(&self) -> serde_json::Value { if self.is_running() { serde_json::json!({ "state": "running", "listenPort": self.listen_port, "uptimeSeconds": self.started_at.map(|t| t.elapsed().as_secs_f64()).unwrap_or(0.0), }) } else { serde_json::json!({ "state": "stopped" }) } } pub async fn get_statistics(&self) -> serde_json::Value { let mut stats = self.server_stats.write().await; if let Some(started) = self.started_at { stats.uptime_seconds = started.elapsed().as_secs_f64(); } // Aggregate from peer 
stats let peer_stats = self.shared_stats.read().await; stats.active_peers = peer_stats.len(); stats.total_bytes_sent = peer_stats.values().map(|s| s.bytes_sent).sum(); stats.total_bytes_received = peer_stats.values().map(|s| s.bytes_received).sum(); stats.total_packets_sent = peer_stats.values().map(|s| s.packets_sent).sum(); stats.total_packets_received = peer_stats.values().map(|s| s.packets_received).sum(); serde_json::to_value(&*stats).unwrap_or_default() } pub async fn list_peers(&self) -> Vec { let stats = self.shared_stats.read().await; stats .iter() .map(|(key, s)| WgPeerInfo { public_key: key.clone(), allowed_ips: vec![], // populated from event loop snapshots endpoint: None, persistent_keepalive: None, stats: s.clone(), }) .collect() } pub async fn add_peer(&self, config: WgPeerConfig) -> Result<()> { let tx = self .command_tx .as_ref() .ok_or_else(|| anyhow!("Server not running"))?; let (resp_tx, resp_rx) = oneshot::channel(); tx.send(WgCommand::AddPeer(config, resp_tx)) .await .map_err(|_| anyhow!("Server event loop closed"))?; resp_rx.await.map_err(|_| anyhow!("No response"))? } pub async fn remove_peer(&self, public_key: &str) -> Result<()> { let tx = self .command_tx .as_ref() .ok_or_else(|| anyhow!("Server not running"))?; let (resp_tx, resp_rx) = oneshot::channel(); tx.send(WgCommand::RemovePeer(public_key.to_string(), resp_tx)) .await .map_err(|_| anyhow!("Server event loop closed"))?; resp_rx.await.map_err(|_| anyhow!("No response"))? 
} } // ============================================================================ // Server event loop // ============================================================================ async fn wg_server_loop( udp_socket: UdpSocket, tun_device: tun::AsyncDevice, mut peers: Vec, peer_index: AtomicU32, rate_limiter: Arc, server_private_key_b64: String, shared_stats: Arc>>, _server_stats: Arc>, _started_at: Instant, mut shutdown_rx: oneshot::Receiver<()>, mut command_rx: mpsc::Receiver, ) -> Result<()> { let mut udp_buf = vec![0u8; MAX_UDP_PACKET]; let mut tun_buf = vec![0u8; MAX_UDP_PACKET]; let mut dst_buf = vec![0u8; WG_BUFFER_SIZE]; let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS)); // Split TUN for concurrent read/write in select let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device); // Stats sync interval let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1)); loop { tokio::select! { // --- UDP receive --- result = udp_socket.recv_from(&mut udp_buf) => { let (n, src_addr) = result?; if n == 0 { continue; } // Find which peer this packet belongs to by trying decapsulate let mut handled = false; for peer in peers.iter_mut() { match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, src_addr).await?; // Drain loop loop { match peer.tunn.decapsulate(None, &[], &mut dst_buf) { TunnResult::WriteToNetwork(pkt) => { let ep = peer.endpoint.unwrap_or(src_addr); udp_socket.send_to(pkt, ep).await?; } _ => break, } } peer.endpoint = Some(src_addr); handled = true; break; } TunnResult::WriteToTunnelV4(packet, addr) => { if peer.matches_dst(IpAddr::V4(addr)) { let pkt_len = packet.len() as u64; tun_writer.write_all(packet).await?; peer.stats.bytes_received += pkt_len; peer.stats.packets_received += 1; } peer.endpoint = Some(src_addr); handled = true; break; } TunnResult::WriteToTunnelV6(packet, addr) => { if 
peer.matches_dst(IpAddr::V6(addr)) { let pkt_len = packet.len() as u64; tun_writer.write_all(packet).await?; peer.stats.bytes_received += pkt_len; peer.stats.packets_received += 1; } peer.endpoint = Some(src_addr); handled = true; break; } TunnResult::Done => { // This peer didn't recognize the packet, try next continue; } TunnResult::Err(e) => { debug!("decapsulate error from {}: {:?}", src_addr, e); continue; } } } if !handled { debug!("No peer matched UDP packet from {}", src_addr); } } // --- TUN read --- result = tun_reader.read(&mut tun_buf) => { let n = result?; if n == 0 { continue; } let dst_ip = match extract_dst_ip(&tun_buf[..n]) { Some(ip) => ip, None => { continue; } }; // Find peer whose AllowedIPs match the destination for peer in peers.iter_mut() { if !peer.matches_dst(dst_ip) { continue; } match peer.tunn.encapsulate(&tun_buf[..n], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { if let Some(endpoint) = peer.endpoint { let pkt_len = n as u64; udp_socket.send_to(packet, endpoint).await?; peer.stats.bytes_sent += pkt_len; peer.stats.packets_sent += 1; } else { debug!("No endpoint for peer {}, dropping packet", peer.public_key_b64); } } TunnResult::Err(e) => { debug!("encapsulate error for peer {}: {:?}", peer.public_key_b64, e); } _ => {} } break; } } // --- Timer tick (100ms) for WireGuard timers --- _ = timer.tick() => { for peer in peers.iter_mut() { match peer.tunn.update_timers(&mut dst_buf) { TunnResult::WriteToNetwork(packet) => { if let Some(endpoint) = peer.endpoint { udp_socket.send_to(packet, endpoint).await?; } } TunnResult::Err(e) => { debug!("Timer error for peer {}: {:?}", peer.public_key_b64, e); } _ => {} } } } // --- Sync stats to shared state --- _ = stats_timer.tick() => { let mut shared = shared_stats.write().await; for peer in peers.iter() { shared.insert(peer.public_key_b64.clone(), peer.stats.clone()); } } // --- Dynamic peer commands --- cmd = command_rx.recv() => { match cmd { Some(WgCommand::AddPeer(config, resp_tx)) 
=> { let result = add_peer_to_loop( &mut peers, &config, &peer_index, &rate_limiter, &server_private_key_b64, ); if result.is_ok() { let mut shared = shared_stats.write().await; shared.insert(config.public_key.clone(), WgPeerStats::default()); } let _ = resp_tx.send(result); } Some(WgCommand::RemovePeer(pubkey, resp_tx)) => { let prev_len = peers.len(); peers.retain(|p| p.public_key_b64 != pubkey); if peers.len() < prev_len { let mut shared = shared_stats.write().await; shared.remove(&pubkey); let _ = resp_tx.send(Ok(())); } else { let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey))); } } None => { info!("Command channel closed"); break; } } } // --- Shutdown --- _ = &mut shutdown_rx => { info!("WireGuard server shutdown signal received"); break; } } } Ok(()) } fn add_peer_to_loop( peers: &mut Vec, config: &WgPeerConfig, peer_index: &AtomicU32, rate_limiter: &Arc, server_private_key_b64: &str, ) -> Result<()> { // Check for duplicate if peers.iter().any(|p| p.public_key_b64 == config.public_key) { return Err(anyhow!("Peer already exists: {}", config.public_key)); } let peer_public = parse_public_key(&config.public_key)?; let psk = match &config.preshared_key { Some(k) => Some(parse_preshared_key(k)?), None => None, }; let idx = peer_index.fetch_add(1, Ordering::Relaxed); let priv_copy = parse_private_key(server_private_key_b64)?; let tunn = Tunn::new( priv_copy, peer_public, psk, config.persistent_keepalive, idx, Some(rate_limiter.clone()), ); let allowed_ips: Vec = config .allowed_ips .iter() .map(|cidr| AllowedIp::parse(cidr)) .collect::>>()?; let endpoint = match &config.endpoint { Some(ep) => Some(ep.parse::()?), None => None, }; peers.push(PeerState { tunn, public_key_b64: config.public_key.clone(), allowed_ips, endpoint, persistent_keepalive: config.persistent_keepalive, stats: WgPeerStats::default(), }); info!("Added WireGuard peer: {}", config.public_key); Ok(()) } // ============================================================================ 
// WgClient // ============================================================================ pub struct WgClient { shutdown_tx: Option>, shared_stats: Arc>, state: Arc>, assigned_ip: Option, } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] struct WgClientState { state: String, #[serde(skip_serializing_if = "Option::is_none")] assigned_ip: Option, #[serde(skip_serializing_if = "Option::is_none")] connected_since: Option, #[serde(skip_serializing_if = "Option::is_none")] last_error: Option, } impl WgClient { pub fn new() -> Self { Self { shutdown_tx: None, shared_stats: Arc::new(RwLock::new(WgPeerStats::default())), state: Arc::new(RwLock::new(WgClientState { state: "disconnected".to_string(), assigned_ip: None, connected_since: None, last_error: None, })), assigned_ip: None, } } pub fn is_running(&self) -> bool { self.shutdown_tx.is_some() } pub async fn connect(&mut self, config: WgClientConfig) -> Result { if self.is_running() { return Err(anyhow!("WireGuard client is already connected")); } { let mut state = self.state.write().await; state.state = "connecting".to_string(); } let mtu = config.mtu.unwrap_or(DEFAULT_MTU); let _prefix = config.address_prefix.unwrap_or(24); let address: Ipv4Addr = config.address.parse()?; // Parse keys let client_private = parse_private_key(&config.private_key)?; let peer_public = parse_public_key(&config.peer.public_key)?; let psk = match &config.peer.preshared_key { Some(k) => Some(parse_preshared_key(k)?), None => None, }; let tunn = Tunn::new( client_private, peer_public, psk, config.peer.persistent_keepalive, 0, // single peer, index 0 None, ); // Parse server endpoint let endpoint: SocketAddr = config .peer .endpoint .as_ref() .ok_or_else(|| anyhow!("Peer endpoint is required for client mode"))? 
.parse()?; // Parse AllowedIPs let allowed_ips: Vec = config .peer .allowed_ips .iter() .map(|cidr| AllowedIp::parse(cidr)) .collect::>>()?; // Create TUN device let tun_config = TunConfig { name: "wg-client0".to_string(), address, netmask: Ipv4Addr::new(255, 255, 255, 0), mtu, }; let tun_device = tunnel::create_tun(&tun_config)?; info!("WireGuard client TUN device created: {}", tun_config.name); // Add routes for AllowedIPs for cidr in &config.peer.allowed_ips { if let Err(e) = tunnel::add_route(cidr, &tun_config.name).await { warn!("Failed to add route for {}: {}", cidr, e); } } // Bind ephemeral UDP socket let udp_socket = UdpSocket::bind("0.0.0.0:0").await?; info!( "WireGuard client bound to {}", udp_socket.local_addr()? ); let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); let shared_stats = self.shared_stats.clone(); let state = self.state.clone(); let assigned_ip = config.address.clone(); // Update state { let mut s = state.write().await; s.state = "connected".to_string(); s.assigned_ip = Some(assigned_ip.clone()); s.connected_since = Some(chrono_now()); } // Spawn client loop tokio::spawn(async move { if let Err(e) = wg_client_loop( udp_socket, tun_device, tunn, endpoint, allowed_ips, shared_stats, state.clone(), shutdown_rx, ) .await { error!("WireGuard client loop error: {}", e); let mut s = state.write().await; s.state = "error".to_string(); s.last_error = Some(format!("{}", e)); } }); self.shutdown_tx = Some(shutdown_tx); self.assigned_ip = Some(config.address.clone()); Ok(config.address) } pub async fn disconnect(&mut self) -> Result<()> { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()); } { let mut s = self.state.write().await; s.state = "disconnected".to_string(); s.assigned_ip = None; s.connected_since = None; } self.assigned_ip = None; info!("WireGuard client disconnected"); Ok(()) } pub async fn get_status(&self) -> serde_json::Value { let s = self.state.read().await; serde_json::to_value(&*s).unwrap_or_default() } pub 
async fn get_statistics(&self) -> serde_json::Value { let stats = self.shared_stats.read().await; serde_json::to_value(&*stats).unwrap_or_default() } } // ============================================================================ // Client event loop // ============================================================================ async fn wg_client_loop( udp_socket: UdpSocket, tun_device: tun::AsyncDevice, mut tunn: Tunn, endpoint: SocketAddr, _allowed_ips: Vec, shared_stats: Arc>, _state: Arc>, mut shutdown_rx: oneshot::Receiver<()>, ) -> Result<()> { let mut udp_buf = vec![0u8; MAX_UDP_PACKET]; let mut tun_buf = vec![0u8; MAX_UDP_PACKET]; let mut dst_buf = vec![0u8; WG_BUFFER_SIZE]; let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS)); let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1)); let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device); // Local stats (synced periodically) let mut local_stats = WgPeerStats::default(); // Initiate handshake match tunn.encapsulate(&[], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, endpoint).await?; debug!("Sent WireGuard handshake initiation"); } _ => {} } loop { tokio::select! 
{ // --- UDP receive --- result = udp_socket.recv_from(&mut udp_buf) => { let (n, src_addr) = result?; if n == 0 { continue; } match tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, endpoint).await?; // Drain loop loop { match tunn.decapsulate(None, &[], &mut dst_buf) { TunnResult::WriteToNetwork(pkt) => { udp_socket.send_to(pkt, endpoint).await?; } _ => break, } } } TunnResult::WriteToTunnelV4(packet, _addr) => { let pkt_len = packet.len() as u64; tun_writer.write_all(packet).await?; local_stats.bytes_received += pkt_len; local_stats.packets_received += 1; } TunnResult::WriteToTunnelV6(packet, _addr) => { let pkt_len = packet.len() as u64; tun_writer.write_all(packet).await?; local_stats.bytes_received += pkt_len; local_stats.packets_received += 1; } TunnResult::Done => {} TunnResult::Err(e) => { debug!("Client decapsulate error: {:?}", e); } } } // --- TUN read --- result = tun_reader.read(&mut tun_buf) => { let n = result?; if n == 0 { continue; } match tunn.encapsulate(&tun_buf[..n], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { let pkt_len = n as u64; udp_socket.send_to(packet, endpoint).await?; local_stats.bytes_sent += pkt_len; local_stats.packets_sent += 1; } TunnResult::Err(e) => { debug!("Client encapsulate error: {:?}", e); } _ => {} } } // --- Timer tick --- _ = timer.tick() => { match tunn.update_timers(&mut dst_buf) { TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, endpoint).await?; } TunnResult::Err(e) => { debug!("Client timer error: {:?}", e); } _ => {} } } // --- Sync stats --- _ = stats_timer.tick() => { let mut shared = shared_stats.write().await; *shared = local_stats.clone(); } // --- Shutdown --- _ = &mut shutdown_rx => { info!("WireGuard client shutdown signal received"); break; } } } Ok(()) } // ============================================================================ // Helpers // 
// ============================================================================
// Helpers
// ============================================================================

/// Current UTC time as an ISO-8601 timestamp (`YYYY-MM-DDTHH:MM:SSZ`),
/// computed without the chrono dependency.
///
/// A system clock set before the Unix epoch is reported as the epoch itself.
fn chrono_now() -> String {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    format_iso8601_utc(since_epoch.as_secs())
}

/// Format `secs` (seconds since 1970-01-01T00:00:00Z) as an ISO-8601 UTC
/// timestamp with second precision.
fn format_iso8601_utc(secs: u64) -> String {
    let seconds_of_day = secs % 86_400;
    let hour = seconds_of_day / 3_600;
    let minute = (seconds_of_day % 3_600) / 60;
    let second = seconds_of_day % 60;

    // Days-to-civil-date conversion (proleptic Gregorian calendar) working in
    // 400-year eras that start on March 1st, so the leap day falls at the end
    // of each "year". The cast cannot overflow: u64::MAX / 86_400 fits in i64.
    let shifted_days = (secs / 86_400) as i64 + 719_468;
    let era = shifted_days.div_euclid(146_097);
    let day_of_era = shifted_days.rem_euclid(146_097); // 0..=146_096
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let shifted_month = (5 * day_of_year + 2) / 153; // 0 = March .. 11 = February
    let day = day_of_year - (153 * shifted_month + 2) / 5 + 1;
    let month = if shifted_month < 10 {
        shifted_month + 3
    } else {
        shifted_month - 9
    };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + if month <= 2 { 1 } else { 0 };

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year, month, day, hour, minute, second
    )
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv6Addr;

    #[test]
    fn test_generate_wg_keypair() {
        let (pub_key, priv_key) = generate_wg_keypair();
        // Base64 of 32 bytes = 44 chars (with padding)
        assert_eq!(pub_key.len(), 44);
        assert_eq!(priv_key.len(), 44);

        // Decode and verify 32 bytes
        let pub_bytes = BASE64.decode(&pub_key).unwrap();
        let priv_bytes = BASE64.decode(&priv_key).unwrap();
        assert_eq!(pub_bytes.len(), 32);
        assert_eq!(priv_bytes.len(), 32);
    }

    #[test]
    fn test_key_roundtrip() {
        let (pub_b64, priv_b64) = generate_wg_keypair();

        // Parse back
        let secret = parse_private_key(&priv_b64).unwrap();
        let public = parse_public_key(&pub_b64).unwrap();

        // Derive public from private and verify match
        let derived_public = PublicKey::from(&secret);
        assert_eq!(public.to_bytes(), derived_public.to_bytes());
    }

    #[test]
    fn test_parse_invalid_key() {
        assert!(parse_private_key("not-valid-base64!!!").is_err());
        assert!(parse_private_key("AAAA").is_err()); // too short (3 bytes)
        assert!(parse_public_key("AAAA").is_err());
    }

    #[test]
    fn test_allowed_ip_v4_match() {
        let aip = AllowedIp::parse("10.0.0.0/24").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 254))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 1, 1))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_v4_catch_all() {
        let aip = AllowedIp::parse("0.0.0.0/0").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))));
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255))));
    }

    #[test]
    fn test_allowed_ip_v4_host() {
        let aip = AllowedIp::parse("10.0.0.5/32").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 6))));
    }

    #[test]
    fn test_allowed_ip_v6_match() {
        let aip = AllowedIp::parse("fd00::/64").unwrap();
        assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd01, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_v6_catch_all() {
        let aip = AllowedIp::parse("::/0").unwrap();
        assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_cross_family_no_match() {
        let v4 = AllowedIp::parse("10.0.0.0/8").unwrap();
        assert!(!v4.matches(IpAddr::V6(Ipv6Addr::LOCALHOST)));

        let v6 = AllowedIp::parse("::/0").unwrap();
        assert!(!v6.matches(IpAddr::V4(Ipv4Addr::LOCALHOST)));
    }

    #[test]
    fn test_extract_dst_ip_v4() {
        // Minimal IPv4 header: version=4, IHL=5, total_length=20, dst at bytes 16-19
        let mut pkt = [0u8; 20];
        pkt[0] = 0x45; // version 4, IHL 5
        pkt[16] = 10;
        pkt[17] = 0;
        pkt[18] = 0;
        pkt[19] = 1;
        assert_eq!(
            extract_dst_ip(&pkt),
            Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
        );
    }

    #[test]
    fn test_extract_dst_ip_v6() {
        // Minimal IPv6 header: version=6, dst at bytes 24-39
        let mut pkt = [0u8; 40];
        pkt[0] = 0x60; // version 6
        pkt[24] = 0xfd;
        pkt[39] = 0x01;
        let expected = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1));
        assert_eq!(extract_dst_ip(&pkt), Some(expected));
    }

    #[test]
    fn test_extract_dst_ip_empty() {
        assert_eq!(extract_dst_ip(&[]), None);
    }

    #[test]
    fn test_format_iso8601_utc() {
        assert_eq!(format_iso8601_utc(0), "1970-01-01T00:00:00Z");
        assert_eq!(format_iso8601_utc(86_399), "1970-01-01T23:59:59Z");
        // Leap day in a year divisible by 400.
        assert_eq!(format_iso8601_utc(951_782_400), "2000-02-29T00:00:00Z");
        assert_eq!(format_iso8601_utc(1_700_000_000), "2023-11-14T22:13:20Z");
    }

    #[test]
    fn test_loopback_tunnel() {
        // Two Tunn instances: server and client, exchanging packets in memory
        let (server_pub, server_priv) = generate_wg_keypair();
        let (client_pub, client_priv) = generate_wg_keypair();

        let server_secret = parse_private_key(&server_priv).unwrap();
        let client_secret = parse_private_key(&client_priv).unwrap();
        let server_public = parse_public_key(&server_pub).unwrap();
        let client_public = parse_public_key(&client_pub).unwrap();

        let mut server_tunn = Tunn::new(server_secret, client_public, None, None, 0, None);
        let mut client_tunn = Tunn::new(client_secret, server_public, None, None, 1, None);

        let mut buf_a = vec![0u8; 2048];
        let mut buf_b = vec![0u8; 2048];

        // Client initiates handshake
        let handshake_init = match client_tunn.encapsulate(&[], &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for handshake init, got {:?}",
                std::mem::discriminant(&other)
            ),
        };

        // Server processes handshake init
        let handshake_resp = match server_tunn.decapsulate(None, &handshake_init, &mut buf_b) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for handshake resp, got {:?}",
                std::mem::discriminant(&other)
            ),
        };

        // Drain server
        while let TunnResult::WriteToNetwork(_) = server_tunn.decapsulate(None, &[], &mut buf_b) {}

        // Client processes handshake response
        match client_tunn.decapsulate(None, &handshake_resp, &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => {
                // Client might send a keepalive or transport data
                // Feed it to server
                let pkt_copy = pkt.to_vec();
                let _ = server_tunn.decapsulate(None, &pkt_copy, &mut buf_b);
            }
            TunnResult::Done => {}
            _ => {
                // Drain
                while let TunnResult::WriteToNetwork(_) =
                    client_tunn.decapsulate(None, &[], &mut buf_a)
                {}
            }
        }

        // Drain client
        while let TunnResult::WriteToNetwork(_) = client_tunn.decapsulate(None, &[], &mut buf_a) {}

        // Now try to send a fake IP packet from client to server
        let mut fake_ip = [0u8; 28];
        fake_ip[0] = 0x45; // IPv4
        fake_ip[2] = 0;
        fake_ip[3] = 28; // total length
        // Source IP (bytes 12-15): 10.0.0.2 (client)
        fake_ip[12] = 10;
        fake_ip[13] = 0;
        fake_ip[14] = 0;
        fake_ip[15] = 2;
        // Destination IP (bytes 16-19): 10.0.0.1 (server)
        fake_ip[16] = 10;
        fake_ip[17] = 0;
        fake_ip[18] = 0;
        fake_ip[19] = 1;

        match client_tunn.encapsulate(&fake_ip, &mut buf_a) {
            TunnResult::WriteToNetwork(encrypted) => {
                let encrypted_copy = encrypted.to_vec();
                // Server decapsulates
                match server_tunn.decapsulate(None, &encrypted_copy, &mut buf_b) {
                    TunnResult::WriteToTunnelV4(decrypted, src_addr) => {
                        // src_addr is the source IP from the inner packet (for AllowedIPs check)
                        assert_eq!(src_addr, Ipv4Addr::new(10, 0, 0, 2));
                        assert_eq!(&decrypted[..fake_ip.len()], &fake_ip);
                    }
                    TunnResult::WriteToNetwork(_pkt) => {
                        // Might need another round trip, that's OK
                    }
                    _ => {
                        // Session might not be fully established yet, acceptable
                    }
                }
            }
            TunnResult::Err(_) => {
                // Session not yet established, acceptable in unit test
            }
            _ => {}
        }
    }
}