use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use anyhow::{anyhow, Result}; use base64::engine::general_purpose::STANDARD as BASE64; use base64::Engine; use boringtun::noise::errors::WireGuardError; use boringtun::noise::rate_limiter::RateLimiter; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::{PublicKey, StaticSecret}; use rand::rngs::OsRng; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, RwLock}; use tracing::{debug, error, info, warn}; use crate::server::{ClientInfo, ForwardingEngine, ServerState}; use crate::tunnel::{self, TunConfig}; // ============================================================================ // Constants // ============================================================================ const MAX_UDP_PACKET: usize = 65536; const WG_BUFFER_SIZE: usize = MAX_UDP_PACKET; /// Minimum dst buffer size for boringtun encapsulate/decapsulate const _MIN_DST_BUF: usize = 148; const TIMER_TICK_MS: u64 = 100; const DEFAULT_MTU: u16 = 1420; // ============================================================================ // Configuration types // ============================================================================ #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct WgPeerConfig { pub public_key: String, #[serde(default)] pub preshared_key: Option, pub allowed_ips: Vec, #[serde(default)] pub endpoint: Option, #[serde(default)] pub persistent_keepalive: Option, } #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] pub struct WgClientConfig { pub private_key: String, pub address: String, #[serde(default)] pub address_prefix: Option, #[serde(default)] pub dns: Option>, #[serde(default)] pub mtu: Option, pub peer: WgPeerConfig, } // 
============================================================================ // Stats types // ============================================================================ #[derive(Debug, Clone, Default, Serialize)] #[serde(rename_all = "camelCase")] pub struct WgPeerStats { pub bytes_sent: u64, pub bytes_received: u64, pub packets_sent: u64, pub packets_received: u64, pub last_handshake_time: Option, } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct WgPeerInfo { pub public_key: String, pub allowed_ips: Vec, pub endpoint: Option, pub persistent_keepalive: Option, #[serde(flatten)] pub stats: WgPeerStats, } // ============================================================================ // Key generation and parsing // ============================================================================ /// Generate a WireGuard-compatible X25519 keypair. /// Returns (public_key_base64, private_key_base64). pub fn generate_wg_keypair() -> (String, String) { let private = StaticSecret::random_from_rng(OsRng); let public = PublicKey::from(&private); let priv_b64 = BASE64.encode(private.to_bytes()); let pub_b64 = BASE64.encode(public.to_bytes()); (pub_b64, priv_b64) } /// Derive the WireGuard public key (base64) from a private key (base64). 
pub fn wg_public_key_from_private(private_key_b64: &str) -> Result { let private = parse_private_key(private_key_b64)?; let public = PublicKey::from(&private); Ok(BASE64.encode(public.to_bytes())) } fn parse_private_key(b64: &str) -> Result { let bytes = BASE64.decode(b64)?; if bytes.len() != 32 { return Err(anyhow!("Private key must be 32 bytes, got {}", bytes.len())); } let mut arr = [0u8; 32]; arr.copy_from_slice(&bytes); Ok(StaticSecret::from(arr)) } fn parse_public_key(b64: &str) -> Result { let bytes = BASE64.decode(b64)?; if bytes.len() != 32 { return Err(anyhow!("Public key must be 32 bytes, got {}", bytes.len())); } let mut arr = [0u8; 32]; arr.copy_from_slice(&bytes); Ok(PublicKey::from(arr)) } fn parse_preshared_key(b64: &str) -> Result<[u8; 32]> { let bytes = BASE64.decode(b64)?; if bytes.len() != 32 { return Err(anyhow!( "Preshared key must be 32 bytes, got {}", bytes.len() )); } let mut arr = [0u8; 32]; arr.copy_from_slice(&bytes); Ok(arr) } // ============================================================================ // AllowedIPs matching // ============================================================================ #[derive(Debug, Clone)] struct AllowedIp { addr: IpAddr, prefix_len: u8, } impl AllowedIp { fn parse(cidr: &str) -> Result { let parts: Vec<&str> = cidr.split('/').collect(); if parts.len() != 2 { return Err(anyhow!("Invalid CIDR: {}", cidr)); } let addr: IpAddr = parts[0].parse()?; let prefix_len: u8 = parts[1].parse()?; match addr { IpAddr::V4(_) if prefix_len > 32 => { return Err(anyhow!("IPv4 prefix length {} > 32", prefix_len)) } IpAddr::V6(_) if prefix_len > 128 => { return Err(anyhow!("IPv6 prefix length {} > 128", prefix_len)) } _ => {} } Ok(Self { addr, prefix_len }) } fn matches(&self, ip: IpAddr) -> bool { match (self.addr, ip) { (IpAddr::V4(net), IpAddr::V4(target)) => { if self.prefix_len == 0 { return true; } if self.prefix_len >= 32 { return net == target; } let mask = u32::MAX << (32 - self.prefix_len); (u32::from(net) 
& mask) == (u32::from(target) & mask) } (IpAddr::V6(net), IpAddr::V6(target)) => { if self.prefix_len == 0 { return true; } if self.prefix_len >= 128 { return net == target; } let net_bits = u128::from(net); let target_bits = u128::from(target); let mask = u128::MAX << (128 - self.prefix_len); (net_bits & mask) == (target_bits & mask) } _ => false, } } } // ============================================================================ // Dynamic peer management commands // ============================================================================ pub enum WgCommand { AddPeer(WgPeerConfig, oneshot::Sender>), RemovePeer(String, oneshot::Sender>), } // ============================================================================ // Internal peer state (owned by event loop) // ============================================================================ struct PeerState { tunn: Tunn, public_key_b64: String, allowed_ips: Vec, endpoint: Option, #[allow(dead_code)] persistent_keepalive: Option, stats: WgPeerStats, /// Whether this peer has completed a WireGuard handshake and is in state.clients. is_connected: bool, /// Last time we received data or handshake activity from this peer. last_activity_at: Option, /// VPN IP assigned during registration (used for connect/disconnect). vpn_ip: Option, /// Previous synced byte counts for aggregate stats delta tracking. 
prev_synced_bytes_sent: u64, prev_synced_bytes_received: u64, } impl PeerState { fn matches_allowed_ips(&self, ip: IpAddr) -> bool { self.allowed_ips.iter().any(|aip| aip.matches(ip)) } } fn add_peer_to_loop( peers: &mut Vec, config: &WgPeerConfig, peer_index: &AtomicU32, rate_limiter: &Arc, server_private_key_b64: &str, ) -> Result<()> { // Check for duplicate if peers.iter().any(|p| p.public_key_b64 == config.public_key) { return Err(anyhow!("Peer already exists: {}", config.public_key)); } let peer_public = parse_public_key(&config.public_key)?; let psk = match &config.preshared_key { Some(k) => Some(parse_preshared_key(k)?), None => None, }; let idx = peer_index.fetch_add(1, Ordering::Relaxed); let priv_copy = parse_private_key(server_private_key_b64)?; let tunn = Tunn::new( priv_copy, peer_public, psk, config.persistent_keepalive, idx, Some(rate_limiter.clone()), ); let allowed_ips: Vec = config .allowed_ips .iter() .map(|cidr| AllowedIp::parse(cidr)) .collect::>>()?; let endpoint = match &config.endpoint { Some(ep) => Some(ep.parse::()?), None => None, }; peers.push(PeerState { tunn, public_key_b64: config.public_key.clone(), allowed_ips, endpoint, persistent_keepalive: config.persistent_keepalive, stats: WgPeerStats::default(), is_connected: false, last_activity_at: None, vpn_ip: None, prev_synced_bytes_sent: 0, prev_synced_bytes_received: 0, }); info!("Added WireGuard peer: {}", config.public_key); Ok(()) } // ============================================================================ // Integrated WG listener (shares ServerState with WS/QUIC) // ============================================================================ /// Configuration for the integrated WireGuard listener. #[derive(Debug, Clone)] pub struct WgListenerConfig { pub private_key: String, pub listen_port: u16, pub peers: Vec, } /// Extract the peer's VPN IP from AllowedIp entries. /// Prefers /32 entries (exact match); falls back to any IPv4 address. 
fn extract_peer_vpn_ip(allowed_ips: &[AllowedIp]) -> Option { // Prefer /32 entries (exact peer VPN IP) for aip in allowed_ips { if let IpAddr::V4(v4) = aip.addr { if aip.prefix_len == 32 { return Some(v4); } } } // Fallback: use the first non-unspecified IPv4 address from any prefix length for aip in allowed_ips { if let IpAddr::V4(v4) = aip.addr { if !v4.is_unspecified() { return Some(v4); } } } None } /// Timestamp helper (mirrors server.rs timestamp_now). fn wg_timestamp_now() -> String { use std::time::SystemTime; let duration = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or_default(); format!("{}", duration.as_secs()) } /// Register a WG peer in ServerState (tun_routes + ip_pool only). /// Does NOT add to state.clients — peers appear there only after handshake. /// Returns the VPN IP. async fn register_wg_peer( state: &Arc, peer: &PeerState, wg_return_tx: &mpsc::Sender<(String, Vec)>, ) -> Result> { let vpn_ip = match extract_peer_vpn_ip(&peer.allowed_ips) { Some(ip) => ip, None => { warn!("WG peer {} has no /32 IPv4 in allowed_ips, skipping registration", peer.public_key_b64); return Ok(None); } }; let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]); // Reserve IP in the pool if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) { warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e); return Ok(None); } // Create per-peer return channel and register in tun_routes let fwd_mode = state.config.forwarding_mode.as_deref().unwrap_or("testing"); let forwarding_active = fwd_mode == "tun" || fwd_mode == "socket"; if forwarding_active { let (peer_return_tx, mut peer_return_rx) = mpsc::channel::>(256); state.tun_routes.write().await.insert(vpn_ip, peer_return_tx); // Spawn relay task: per-peer channel → merged channel tagged with pubkey let relay_tx = wg_return_tx.clone(); let pubkey = peer.public_key_b64.clone(); tokio::spawn(async move { while let Some(packet) = 
peer_return_rx.recv().await { if relay_tx.send((pubkey.clone(), packet)).await.is_err() { break; } } }); } info!("WG peer {} registered with IP {} (not yet connected)", peer.public_key_b64, vpn_ip); Ok(Some(vpn_ip)) } /// Add a WG peer to state.clients on first successful handshake (data received). async fn connect_wg_peer( state: &Arc, peer: &PeerState, vpn_ip: Ipv4Addr, ) { let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]); let client_info = ClientInfo { client_id: client_id.clone(), assigned_ip: vpn_ip.to_string(), connected_since: wg_timestamp_now(), bytes_sent: peer.stats.bytes_sent, bytes_received: peer.stats.bytes_received, packets_dropped: 0, bytes_dropped: 0, last_keepalive_at: None, keepalives_received: 0, rate_limit_bytes_per_sec: None, burst_bytes: None, authenticated_key: peer.public_key_b64.clone(), registered_client_id: client_id.clone(), remote_addr: peer.endpoint.map(|e| e.to_string()), transport_type: "wireguard".to_string(), }; state.clients.write().await.insert(client_info.client_id.clone(), client_info); // Increment total_connections { let mut stats = state.stats.write().await; stats.total_connections += 1; stats.total_connections_wireguard += 1; } info!("WG peer {} connected (IP: {})", peer.public_key_b64, vpn_ip); } /// Remove a WG peer from state.clients (disconnect without unregistering). async fn disconnect_wg_peer( state: &Arc, pubkey: &str, ) { let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]); if state.clients.write().await.remove(&client_id).is_some() { info!("WG peer {} disconnected (removed from active clients)", pubkey); } } /// Unregister a WG peer from ServerState. 
async fn unregister_wg_peer( state: &Arc, pubkey: &str, vpn_ip: Option, ) { let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]); if let Some(ip) = vpn_ip { state.tun_routes.write().await.remove(&ip); state.ip_pool.lock().await.release(&ip); } state.clients.write().await.remove(&client_id); state.rate_limiters.lock().await.remove(&client_id); } /// Integrated WireGuard listener that shares ServerState with WS/QUIC listeners. /// Uses the shared ForwardingEngine for packet routing instead of its own TUN device. pub async fn run_wg_listener( state: Arc, config: WgListenerConfig, mut shutdown_rx: mpsc::Receiver<()>, mut command_rx: mpsc::Receiver, ) -> Result<()> { // Parse server private key let server_private = parse_private_key(&config.private_key)?; let server_public = PublicKey::from(&server_private); // Create rate limiter for DDoS protection let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64)); // Build initial peer state let peer_index = AtomicU32::new(0); let mut peers: Vec = Vec::with_capacity(config.peers.len()); for peer_config in &config.peers { let peer_public = parse_public_key(&peer_config.public_key)?; let psk = match &peer_config.preshared_key { Some(k) => Some(parse_preshared_key(k)?), None => None, }; let idx = peer_index.fetch_add(1, Ordering::Relaxed); let priv_copy = parse_private_key(&config.private_key)?; let tunn = Tunn::new( priv_copy, peer_public, psk, peer_config.persistent_keepalive, idx, Some(rate_limiter.clone()), ); let allowed_ips: Vec = peer_config .allowed_ips .iter() .map(|cidr| AllowedIp::parse(cidr)) .collect::>>()?; let endpoint = match &peer_config.endpoint { Some(ep) => Some(ep.parse::()?), None => None, }; peers.push(PeerState { tunn, public_key_b64: peer_config.public_key.clone(), allowed_ips, endpoint, persistent_keepalive: peer_config.persistent_keepalive, stats: WgPeerStats::default(), is_connected: false, last_activity_at: None, vpn_ip: None, prev_synced_bytes_sent: 0, 
prev_synced_bytes_received: 0, }); } // Bind UDP socket let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", config.listen_port)).await?; info!("WireGuard listener started on UDP port {}", config.listen_port); // Merged return-packet channel: all per-peer channels feed into this let (wg_return_tx, mut wg_return_rx) = mpsc::channel::<(String, Vec)>(1024); // Register initial peers in ServerState (IP reservation + tun_routes only, NOT state.clients) let mut peer_vpn_ips: HashMap = HashMap::new(); for peer in peers.iter_mut() { if let Ok(Some(ip)) = register_wg_peer(&state, peer, &wg_return_tx).await { peer_vpn_ips.insert(peer.public_key_b64.clone(), ip); peer.vpn_ip = Some(ip); } } // Buffers let mut udp_buf = vec![0u8; MAX_UDP_PACKET]; let mut dst_buf = vec![0u8; WG_BUFFER_SIZE]; let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS)); let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1)); let mut idle_check_timer = tokio::time::interval(std::time::Duration::from_secs(10)); loop { tokio::select! 
{ // --- UDP receive → decapsulate → ForwardingEngine --- result = udp_socket.recv_from(&mut udp_buf) => { let (n, src_addr) = result?; if n == 0 { continue; } let mut handled = false; for peer in peers.iter_mut() { match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, src_addr).await?; loop { match peer.tunn.decapsulate(None, &[], &mut dst_buf) { TunnResult::WriteToNetwork(pkt) => { let ep = peer.endpoint.unwrap_or(src_addr); udp_socket.send_to(pkt, ep).await?; } _ => break, } } peer.endpoint = Some(src_addr); // Handshake response counts as activity peer.last_activity_at = Some(tokio::time::Instant::now()); handled = true; break; } TunnResult::WriteToTunnelV4(packet, addr) => { if peer.matches_allowed_ips(IpAddr::V4(addr)) { let pkt_len = packet.len() as u64; // Forward via shared forwarding engine let mut engine = state.forwarding_engine.lock().await; match &mut *engine { ForwardingEngine::Tun(writer) => { use tokio::io::AsyncWriteExt; if let Err(e) = writer.write_all(packet).await { warn!("TUN write error for WG peer: {}", e); } } ForwardingEngine::Socket(sender) => { let _ = sender.try_send(packet.to_vec()); } ForwardingEngine::Bridge(sender) => { let _ = sender.try_send(packet.to_vec()); } ForwardingEngine::Hybrid { socket_tx, bridge_tx, routing_table } => { if packet.len() >= 20 { let src_ip = Ipv4Addr::new(packet[12], packet[13], packet[14], packet[15]); let use_bridge = routing_table.read().await.get(&src_ip).copied().unwrap_or(false); if use_bridge { let _ = bridge_tx.try_send(packet.to_vec()); } else { let _ = socket_tx.try_send(packet.to_vec()); } } } ForwardingEngine::Testing => {} } peer.stats.bytes_received += pkt_len; peer.stats.packets_received += 1; } peer.endpoint = Some(src_addr); // Track activity and detect handshake completion peer.last_activity_at = Some(tokio::time::Instant::now()); if !peer.is_connected { peer.is_connected = true; 
peer.stats.last_handshake_time = Some(wg_timestamp_now()); if let Some(vpn_ip) = peer.vpn_ip { connect_wg_peer(&state, peer, vpn_ip).await; } } handled = true; break; } TunnResult::WriteToTunnelV6(packet, addr) => { if peer.matches_allowed_ips(IpAddr::V6(addr)) { let pkt_len = packet.len() as u64; let mut engine = state.forwarding_engine.lock().await; match &mut *engine { ForwardingEngine::Tun(writer) => { use tokio::io::AsyncWriteExt; if let Err(e) = writer.write_all(packet).await { warn!("TUN write error for WG peer: {}", e); } } ForwardingEngine::Socket(sender) => { let _ = sender.try_send(packet.to_vec()); } ForwardingEngine::Bridge(sender) => { let _ = sender.try_send(packet.to_vec()); } ForwardingEngine::Hybrid { socket_tx, bridge_tx, routing_table } => { if packet.len() >= 20 { let src_ip = Ipv4Addr::new(packet[12], packet[13], packet[14], packet[15]); let use_bridge = routing_table.read().await.get(&src_ip).copied().unwrap_or(false); if use_bridge { let _ = bridge_tx.try_send(packet.to_vec()); } else { let _ = socket_tx.try_send(packet.to_vec()); } } } ForwardingEngine::Testing => {} } peer.stats.bytes_received += pkt_len; peer.stats.packets_received += 1; } peer.endpoint = Some(src_addr); // Track activity and detect handshake completion peer.last_activity_at = Some(tokio::time::Instant::now()); if !peer.is_connected { peer.is_connected = true; peer.stats.last_handshake_time = Some(wg_timestamp_now()); if let Some(vpn_ip) = peer.vpn_ip { connect_wg_peer(&state, peer, vpn_ip).await; } } handled = true; break; } TunnResult::Done => { continue; } TunnResult::Err(e) => { debug!("decapsulate error from {}: {:?}", src_addr, e); continue; } } } if !handled { debug!("No WG peer matched UDP packet from {}", src_addr); } } // --- Return packets from tun_routes → encapsulate → UDP --- Some((pubkey, packet)) = wg_return_rx.recv() => { if let Some(peer) = peers.iter_mut().find(|p| p.public_key_b64 == pubkey) { match peer.tunn.encapsulate(&packet, &mut dst_buf) { 
TunnResult::WriteToNetwork(out) => { if let Some(endpoint) = peer.endpoint { let pkt_len = packet.len() as u64; udp_socket.send_to(out, endpoint).await?; peer.stats.bytes_sent += pkt_len; peer.stats.packets_sent += 1; } else { debug!("No endpoint for WG peer {}, dropping return packet", peer.public_key_b64); } } TunnResult::Err(e) => { debug!("encapsulate error for WG peer {}: {:?}", peer.public_key_b64, e); } _ => {} } } } // --- WireGuard protocol timers (100ms) --- _ = timer.tick() => { for peer in peers.iter_mut() { match peer.tunn.update_timers(&mut dst_buf) { TunnResult::WriteToNetwork(packet) => { if let Some(endpoint) = peer.endpoint { udp_socket.send_to(packet, endpoint).await?; } } TunnResult::Err(WireGuardError::ConnectionExpired) => { warn!("WG peer {} connection expired", peer.public_key_b64); if peer.is_connected { peer.is_connected = false; disconnect_wg_peer(&state, &peer.public_key_b64).await; } } TunnResult::Err(e) => { debug!("Timer error for WG peer {}: {:?}", peer.public_key_b64, e); } _ => {} } } } // --- Sync stats to ServerState (every 1s) --- _ = stats_timer.tick() => { let mut clients = state.clients.write().await; let mut stats = state.stats.write().await; for peer in peers.iter_mut() { // Always update aggregate stats (regardless of connection state) let delta_sent = peer.stats.bytes_sent.saturating_sub(peer.prev_synced_bytes_sent); let delta_recv = peer.stats.bytes_received.saturating_sub(peer.prev_synced_bytes_received); if delta_sent > 0 || delta_recv > 0 { stats.bytes_sent += delta_sent; stats.bytes_received += delta_recv; peer.prev_synced_bytes_sent = peer.stats.bytes_sent; peer.prev_synced_bytes_received = peer.stats.bytes_received; } // Only update ClientInfo if peer is connected (in state.clients) let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]); if let Some(info) = clients.get_mut(&client_id) { info.bytes_sent = peer.stats.bytes_sent; info.bytes_received = peer.stats.bytes_received; 
info.remote_addr = peer.endpoint.map(|e| e.to_string()); } } } // --- Idle timeout check (every 10s) --- _ = idle_check_timer.tick() => { let now = tokio::time::Instant::now(); for peer in peers.iter_mut() { if peer.is_connected { if let Some(last) = peer.last_activity_at { if now.duration_since(last) > std::time::Duration::from_secs(180) { info!("WG peer {} idle timeout (180s), disconnecting", peer.public_key_b64); peer.is_connected = false; disconnect_wg_peer(&state, &peer.public_key_b64).await; } } } } } // --- Dynamic peer commands --- cmd = command_rx.recv() => { match cmd { Some(WgCommand::AddPeer(peer_config, resp_tx)) => { let result = add_peer_to_loop( &mut peers, &peer_config, &peer_index, &rate_limiter, &config.private_key, ); if result.is_ok() { // Register new peer in ServerState (IP + tun_routes only) let peer = peers.last_mut().unwrap(); match register_wg_peer(&state, peer, &wg_return_tx).await { Ok(Some(ip)) => { peer_vpn_ips.insert(peer_config.public_key.clone(), ip); peer.vpn_ip = Some(ip); } Ok(None) => {} Err(e) => { warn!("Failed to register WG peer: {}", e); } } } let _ = resp_tx.send(result); } Some(WgCommand::RemovePeer(pubkey, resp_tx)) => { let prev_len = peers.len(); peers.retain(|p| p.public_key_b64 != pubkey); if peers.len() < prev_len { let vpn_ip = peer_vpn_ips.remove(&pubkey); unregister_wg_peer(&state, &pubkey, vpn_ip).await; let _ = resp_tx.send(Ok(())); } else { let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey))); } } None => { info!("WG command channel closed"); break; } } } // --- Shutdown --- _ = shutdown_rx.recv() => { info!("WireGuard listener shutdown signal received"); break; } } } // Cleanup: unregister all peers from ServerState for peer in &peers { let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied(); unregister_wg_peer(&state, &peer.public_key_b64, vpn_ip).await; } info!("WireGuard listener stopped"); Ok(()) } // ============================================================================ // 
WgClient // ============================================================================ pub struct WgClient { shutdown_tx: Option>, shared_stats: Arc>, state: Arc>, assigned_ip: Option, } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] struct WgClientState { state: String, #[serde(skip_serializing_if = "Option::is_none")] assigned_ip: Option, #[serde(skip_serializing_if = "Option::is_none")] connected_since: Option, #[serde(skip_serializing_if = "Option::is_none")] last_error: Option, } impl WgClient { pub fn new() -> Self { Self { shutdown_tx: None, shared_stats: Arc::new(RwLock::new(WgPeerStats::default())), state: Arc::new(RwLock::new(WgClientState { state: "disconnected".to_string(), assigned_ip: None, connected_since: None, last_error: None, })), assigned_ip: None, } } pub fn is_running(&self) -> bool { self.shutdown_tx.is_some() } pub async fn connect(&mut self, config: WgClientConfig) -> Result { if self.is_running() { return Err(anyhow!("WireGuard client is already connected")); } { let mut state = self.state.write().await; state.state = "connecting".to_string(); } let mtu = config.mtu.unwrap_or(DEFAULT_MTU); let _prefix = config.address_prefix.unwrap_or(24); let address: Ipv4Addr = config.address.parse()?; // Parse keys let client_private = parse_private_key(&config.private_key)?; let peer_public = parse_public_key(&config.peer.public_key)?; let psk = match &config.peer.preshared_key { Some(k) => Some(parse_preshared_key(k)?), None => None, }; let tunn = Tunn::new( client_private, peer_public, psk, config.peer.persistent_keepalive, 0, // single peer, index 0 None, ); // Parse server endpoint let endpoint: SocketAddr = config .peer .endpoint .as_ref() .ok_or_else(|| anyhow!("Peer endpoint is required for client mode"))? 
.parse()?; // Parse AllowedIPs let allowed_ips: Vec = config .peer .allowed_ips .iter() .map(|cidr| AllowedIp::parse(cidr)) .collect::>>()?; // Create TUN device let tun_config = TunConfig { name: "wg-client0".to_string(), address, netmask: Ipv4Addr::new(255, 255, 255, 0), mtu, }; let tun_device = tunnel::create_tun(&tun_config)?; info!("WireGuard client TUN device created: {}", tun_config.name); // Add routes for AllowedIPs for cidr in &config.peer.allowed_ips { if let Err(e) = tunnel::add_route(cidr, &tun_config.name).await { warn!("Failed to add route for {}: {}", cidr, e); } } // Bind ephemeral UDP socket let udp_socket = UdpSocket::bind("0.0.0.0:0").await?; info!( "WireGuard client bound to {}", udp_socket.local_addr()? ); let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); let shared_stats = self.shared_stats.clone(); let state = self.state.clone(); let assigned_ip = config.address.clone(); // Update state — handshake hasn't completed yet { let mut s = state.write().await; s.state = "handshaking".to_string(); s.assigned_ip = Some(assigned_ip.clone()); s.connected_since = None; } // Spawn client loop tokio::spawn(async move { if let Err(e) = wg_client_loop( udp_socket, tun_device, tunn, endpoint, allowed_ips, shared_stats, state.clone(), shutdown_rx, ) .await { error!("WireGuard client loop error: {}", e); let mut s = state.write().await; s.state = "error".to_string(); s.last_error = Some(format!("{}", e)); } }); self.shutdown_tx = Some(shutdown_tx); self.assigned_ip = Some(config.address.clone()); Ok(config.address) } pub async fn disconnect(&mut self) -> Result<()> { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()); } { let mut s = self.state.write().await; s.state = "disconnected".to_string(); s.assigned_ip = None; s.connected_since = None; } self.assigned_ip = None; info!("WireGuard client disconnected"); Ok(()) } pub async fn get_status(&self) -> serde_json::Value { let s = self.state.read().await; 
serde_json::to_value(&*s).unwrap_or_default() } pub async fn get_statistics(&self) -> serde_json::Value { let stats = self.shared_stats.read().await; serde_json::to_value(&*stats).unwrap_or_default() } } // ============================================================================ // Client event loop // ============================================================================ async fn wg_client_loop( udp_socket: UdpSocket, tun_device: tun::AsyncDevice, mut tunn: Tunn, endpoint: SocketAddr, _allowed_ips: Vec, shared_stats: Arc>, state: Arc>, mut shutdown_rx: oneshot::Receiver<()>, ) -> Result<()> { let mut udp_buf = vec![0u8; MAX_UDP_PACKET]; let mut tun_buf = vec![0u8; MAX_UDP_PACKET]; let mut dst_buf = vec![0u8; WG_BUFFER_SIZE]; let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS)); let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1)); let mut handshake_complete = false; let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device); // Local stats (synced periodically) let mut local_stats = WgPeerStats::default(); // Initiate handshake match tunn.encapsulate(&[], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, endpoint).await?; debug!("Sent WireGuard handshake initiation"); } _ => {} } loop { tokio::select! 
{ // --- UDP receive --- result = udp_socket.recv_from(&mut udp_buf) => { let (n, src_addr) = result?; if n == 0 { continue; } match tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, endpoint).await?; // Drain loop loop { match tunn.decapsulate(None, &[], &mut dst_buf) { TunnResult::WriteToNetwork(pkt) => { udp_socket.send_to(pkt, endpoint).await?; } _ => break, } } } TunnResult::WriteToTunnelV4(packet, _addr) => { let pkt_len = packet.len() as u64; tun_writer.write_all(packet).await?; local_stats.bytes_received += pkt_len; local_stats.packets_received += 1; if !handshake_complete { handshake_complete = true; let mut s = state.write().await; s.state = "connected".to_string(); s.connected_since = Some(chrono_now()); info!("WireGuard handshake completed, tunnel active"); } } TunnResult::WriteToTunnelV6(packet, _addr) => { let pkt_len = packet.len() as u64; tun_writer.write_all(packet).await?; local_stats.bytes_received += pkt_len; local_stats.packets_received += 1; if !handshake_complete { handshake_complete = true; let mut s = state.write().await; s.state = "connected".to_string(); s.connected_since = Some(chrono_now()); info!("WireGuard handshake completed, tunnel active"); } } TunnResult::Done => {} TunnResult::Err(WireGuardError::ConnectionExpired) => { warn!("WireGuard session expired during decapsulate, re-initiating handshake"); match tunn.format_handshake_initiation(&mut dst_buf, true) { TunnResult::WriteToNetwork(packet) => { udp_socket.send_to(packet, endpoint).await?; } _ => {} } } TunnResult::Err(e) => { debug!("Client decapsulate error: {:?}", e); } } } // --- TUN read --- result = tun_reader.read(&mut tun_buf) => { let n = result?; if n == 0 { continue; } match tunn.encapsulate(&tun_buf[..n], &mut dst_buf) { TunnResult::WriteToNetwork(packet) => { let pkt_len = n as u64; udp_socket.send_to(packet, endpoint).await?; local_stats.bytes_sent += pkt_len; 
local_stats.packets_sent += 1;
                        }
                        TunnResult::Err(e) => {
                            debug!("Client encapsulate error: {:?}", e);
                        }
                        _ => {}
                    }
                }
                // --- Timer tick ---
                _ = timer.tick() => {
                    match tunn.update_timers(&mut dst_buf) {
                        TunnResult::WriteToNetwork(packet) => {
                            udp_socket.send_to(packet, endpoint).await?;
                        }
                        TunnResult::Err(WireGuardError::ConnectionExpired) => {
                            warn!("WireGuard connection expired, re-initiating handshake");
                            match tunn.format_handshake_initiation(&mut dst_buf, true) {
                                TunnResult::WriteToNetwork(packet) => {
                                    udp_socket.send_to(packet, endpoint).await?;
                                    debug!("Sent handshake re-initiation after expiry");
                                }
                                TunnResult::Err(e) => {
                                    warn!("Failed to re-initiate handshake: {:?}", e);
                                }
                                _ => {}
                            }
                        }
                        TunnResult::Err(e) => {
                            debug!("Client timer error: {:?}", e);
                        }
                        _ => {}
                    }
                }
                // --- Sync stats ---
                _ = stats_timer.tick() => {
                    let mut shared = shared_stats.write().await;
                    *shared = local_stats.clone();
                }
                // --- Shutdown ---
                _ = &mut shutdown_rx => {
                    info!("WireGuard client shutdown signal received");
                    break;
                }
        }
    }

    Ok(())
}

// ============================================================================
// Helpers
// ============================================================================

/// Current UTC wall-clock time as an ISO-8601 / RFC 3339 timestamp,
/// e.g. `2023-11-14T22:13:20Z`.
///
/// Implemented without the `chrono` crate. If the system clock reports a
/// time before the Unix epoch, the epoch itself (`1970-01-01T00:00:00Z`) is
/// returned instead of failing.
fn chrono_now() -> String {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    format_unix_timestamp(secs)
}

/// Format `secs` seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC,
/// proleptic Gregorian calendar).
///
/// The date part is Howard Hinnant's `civil_from_days` algorithm, specialised
/// to non-negative day counts so that all arithmetic stays in `u64`. Even for
/// `secs == u64::MAX` no intermediate value overflows.
fn format_unix_timestamp(secs: u64) -> String {
    const SECS_PER_DAY: u64 = 86_400;

    let days = secs / SECS_PER_DAY;
    let time_of_day = secs % SECS_PER_DAY;
    let hour = time_of_day / 3_600;
    let minute = (time_of_day % 3_600) / 60;
    let second = time_of_day % 60;

    // Shift the day count so day 0 is 0000-03-01: leap days then fall at the
    // very end of each (March-based) year, which keeps the month math linear.
    let z = days + 719_468;
    let era = z / 146_097; // 400-year Gregorian cycles
    let doe = z - era * 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month index, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    // January and February belong to the following civil year in the
    // March-based count.
    let year = era * 400 + yoe + u64::from(month <= 2);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year, month, day, hour, minute, second
    )
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tunnel::extract_dst_ip;
    use std::net::Ipv6Addr;

    #[test]
    fn test_generate_wg_keypair() {
        let (pub_key, priv_key) = generate_wg_keypair();
        // Base64 of 32 bytes = 44 chars (with padding)
        assert_eq!(pub_key.len(), 44);
        assert_eq!(priv_key.len(), 44);
        // Decode and verify 32 bytes
        let pub_bytes = BASE64.decode(&pub_key).unwrap();
        let priv_bytes = BASE64.decode(&priv_key).unwrap();
        assert_eq!(pub_bytes.len(), 32);
        assert_eq!(priv_bytes.len(), 32);
    }

    #[test]
    fn test_key_roundtrip() {
        let (pub_b64, priv_b64) = generate_wg_keypair();
        // Parse back
        let secret = parse_private_key(&priv_b64).unwrap();
        let public = parse_public_key(&pub_b64).unwrap();
        // Derive public from private and verify match
        let derived_public = PublicKey::from(&secret);
        assert_eq!(public.to_bytes(), derived_public.to_bytes());
    }

    #[test]
    fn test_wg_public_key_from_private() {
        let (pub_b64, priv_b64) = generate_wg_keypair();
        let derived = wg_public_key_from_private(&priv_b64).unwrap();
        assert_eq!(derived, pub_b64);
    }

    #[test]
    fn test_wg_public_key_from_private_invalid() {
        assert!(wg_public_key_from_private("not-valid").is_err());
        assert!(wg_public_key_from_private("AAAA").is_err());
    }

    #[test]
    fn test_parse_invalid_key() {
        assert!(parse_private_key("not-valid-base64!!!").is_err());
        assert!(parse_private_key("AAAA").is_err()); // too short (3 bytes)
        assert!(parse_public_key("AAAA").is_err());
    }

    #[test]
    fn test_allowed_ip_v4_match() {
        let aip = AllowedIp::parse("10.0.0.0/24").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 254))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 1, 1))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_v4_catch_all() {
        let aip = AllowedIp::parse("0.0.0.0/0").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))));
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255))));
    }

    #[test]
    fn test_allowed_ip_v4_host() {
        let aip = AllowedIp::parse("10.0.0.5/32").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 6))));
    }

    #[test]
    fn test_allowed_ip_v6_match() {
        let aip = AllowedIp::parse("fd00::/64").unwrap();
        assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd01, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_v6_catch_all() {
        let aip = AllowedIp::parse("::/0").unwrap();
        assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_cross_family_no_match() {
        let v4 = AllowedIp::parse("10.0.0.0/8").unwrap();
        assert!(!v4.matches(IpAddr::V6(Ipv6Addr::LOCALHOST)));
        let v6 = AllowedIp::parse("::/0").unwrap();
        assert!(!v6.matches(IpAddr::V4(Ipv4Addr::LOCALHOST)));
    }

    #[test]
    fn test_extract_dst_ip_v4() {
        // Minimal IPv4 header: version=4, IHL=5, total_length=20, dst at bytes 16-19
        let mut pkt = [0u8; 20];
        pkt[0] = 0x45; // version 4, IHL 5
        pkt[3] = 20; // total length (bytes 2-3, big-endian)
        pkt[16] = 10;
        pkt[17] = 0;
        pkt[18] = 0;
        pkt[19] = 1;
        assert_eq!(
            extract_dst_ip(&pkt),
            Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
        );
    }

    #[test]
    fn test_extract_dst_ip_v6() {
        // Minimal IPv6 header: version=6, dst at bytes 24-39
        let mut pkt = [0u8; 40];
        pkt[0] = 0x60; // version 6
        pkt[24] = 0xfd;
        pkt[39] = 0x01;
        let expected = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1));
        assert_eq!(extract_dst_ip(&pkt), Some(expected));
    }

    #[test]
    fn test_extract_dst_ip_empty() {
        assert_eq!(extract_dst_ip(&[]), None);
    }

    #[test]
    fn test_format_unix_timestamp() {
        assert_eq!(format_unix_timestamp(0), "1970-01-01T00:00:00Z");
        assert_eq!(format_unix_timestamp(86_399), "1970-01-01T23:59:59Z");
        // Leap day in a year divisible by 400
        assert_eq!(format_unix_timestamp(951_782_400), "2000-02-29T00:00:00Z");
        assert_eq!(format_unix_timestamp(1_700_000_000), "2023-11-14T22:13:20Z");
        assert_eq!(format_unix_timestamp(2_147_483_647), "2038-01-19T03:14:07Z");
        // 2100 is divisible by 100 but not 400, so it has no Feb 29
        assert_eq!(format_unix_timestamp(4_107_456_000), "2100-02-28T00:00:00Z");
        assert_eq!(format_unix_timestamp(4_107_542_400), "2100-03-01T00:00:00Z");
    }

    #[test]
    fn test_chrono_now_is_iso8601() {
        let ts = chrono_now();
        // YYYY-MM-DDTHH:MM:SSZ
        assert_eq!(ts.len(), 20, "unexpected timestamp: {}", ts);
        assert_eq!(ts.as_bytes()[10], b'T');
        assert!(ts.ends_with('Z'));
    }

    #[test]
    fn test_loopback_tunnel() {
        // Two Tunn instances: server and client, exchanging packets in memory
        let (server_pub, server_priv) = generate_wg_keypair();
        let (client_pub, client_priv) = generate_wg_keypair();

        let server_secret = parse_private_key(&server_priv).unwrap();
        let client_secret = parse_private_key(&client_priv).unwrap();
        let server_public = parse_public_key(&server_pub).unwrap();
        let client_public = parse_public_key(&client_pub).unwrap();

        let mut server_tunn = Tunn::new(
            server_secret,
            client_public,
            None,
            None,
            0,
            None,
        );
        let mut client_tunn = Tunn::new(
            client_secret,
            server_public,
            None,
            None,
            1,
            None,
        );

        let mut buf_a = vec![0u8; 2048];
        let mut buf_b = vec![0u8; 2048];

        // Client initiates handshake
        let handshake_init = match client_tunn.encapsulate(&[], &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for handshake init, got {:?}",
                std::mem::discriminant(&other)
            ),
        };

        // Server processes handshake init
        let handshake_resp = match server_tunn.decapsulate(None, &handshake_init, &mut buf_b) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for handshake resp, got {:?}",
                std::mem::discriminant(&other)
            ),
        };

        // Drain server
        loop {
            match server_tunn.decapsulate(None, &[], &mut buf_b) {
                TunnResult::WriteToNetwork(_) => {}
                _ => break,
            }
        }

        // Client processes handshake response
        match client_tunn.decapsulate(None, &handshake_resp, &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => {
                // Client might send a keepalive or transport data
                // Feed it to server
                let pkt_copy = pkt.to_vec();
                let _ = server_tunn.decapsulate(None, &pkt_copy, &mut buf_b);
            }
            TunnResult::Done => {}
            _other => {
                // Drain
                loop {
                    match client_tunn.decapsulate(None, &[], &mut buf_a) {
                        TunnResult::WriteToNetwork(_) => {}
                        _ => break,
                    }
                }
            }
        }

        // Drain client
        loop {
            match client_tunn.decapsulate(None, &[], &mut buf_a) {
                TunnResult::WriteToNetwork(_) => {}
                _ => break,
            }
        }

        // Now try to send a fake IP packet from client to server
        let mut fake_ip = [0u8; 28];
        fake_ip[0] = 0x45; // IPv4
        fake_ip[2] = 0;
        fake_ip[3] = 28; // total length
        // Source IP (bytes 12-15): 10.0.0.2 (client)
        fake_ip[12] = 10;
        fake_ip[13] = 0;
        fake_ip[14] = 0;
        fake_ip[15] = 2;
        // Destination IP (bytes 16-19): 10.0.0.1 (server)
        fake_ip[16] = 10;
        fake_ip[17] = 0;
        fake_ip[18] = 0;
        fake_ip[19] = 1;

        match client_tunn.encapsulate(&fake_ip, &mut buf_a) {
            TunnResult::WriteToNetwork(encrypted) => {
                let encrypted_copy = encrypted.to_vec();
                // Server decapsulates
                match server_tunn.decapsulate(None, &encrypted_copy, &mut buf_b) {
                    TunnResult::WriteToTunnelV4(decrypted, src_addr) => {
                        // src_addr is the source IP from the inner packet (for AllowedIPs check)
                        assert_eq!(src_addr, Ipv4Addr::new(10, 0, 0, 2));
                        assert_eq!(&decrypted[..fake_ip.len()], &fake_ip);
                    }
                    TunnResult::WriteToNetwork(_pkt) => {
                        // Might need another round trip, that's OK
                    }
                    _ => {
                        // Session might not be fully established yet, acceptable
                    }
                }
            }
            TunnResult::Err(_) => {
                // Session not yet established, acceptable in unit test
            }
            _ => {}
        }
    }
}