use anyhow::Result; use bytes::BytesMut; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::Ipv4Addr; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; use tokio::sync::{mpsc, Mutex, RwLock}; use tracing::{info, error, warn}; use crate::acl; use crate::client_registry::{ClientEntry, ClientRegistry}; use crate::codec::{Frame, FrameCodec, PacketType}; use crate::crypto; use crate::mtu::{MtuConfig, TunnelOverhead}; use crate::network::IpPool; use crate::ratelimit::TokenBucket; use crate::transport; use crate::transport_trait::{self, TransportSink, TransportStream}; use crate::quic_transport; use crate::tunnel::{self, TunConfig}; /// Dead-peer timeout: 3x max keepalive interval (Healthy=60s). const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180); /// Server configuration (matches TS IVpnServerConfig). #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ServerConfig { pub listen_addr: String, pub tls_cert: Option, pub tls_key: Option, pub private_key: String, pub public_key: String, pub subnet: String, pub dns: Option>, pub mtu: Option, pub keepalive_interval_secs: Option, pub enable_nat: Option, /// Forwarding mode: "tun" (kernel TUN, requires root), "socket" (userspace NAT), /// or "testing" (monitoring only, no forwarding). Default: "testing". pub forwarding_mode: Option, /// Default rate limit for new clients (bytes/sec). None = unlimited. pub default_rate_limit_bytes_per_sec: Option, /// Default burst size for new clients (bytes). None = unlimited. pub default_burst_bytes: Option, /// Transport mode: "websocket" (default), "quic", or "both". pub transport_mode: Option, /// QUIC listen address (host:port). Defaults to listen_addr. pub quic_listen_addr: Option, /// QUIC idle timeout in seconds (default: 30). pub quic_idle_timeout_secs: Option, /// Pre-registered clients for IK authentication. 
pub clients: Option>, /// Enable PROXY protocol v2 parsing on incoming WebSocket connections. /// SECURITY: Must be false when accepting direct client connections. pub proxy_protocol: Option, /// Server-level IP block list — applied at TCP accept, before Noise handshake. pub connection_ip_block_list: Option>, } /// Information about a connected client. #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct ClientInfo { pub client_id: String, pub assigned_ip: String, pub connected_since: String, pub bytes_sent: u64, pub bytes_received: u64, pub packets_dropped: u64, pub bytes_dropped: u64, pub last_keepalive_at: Option, pub keepalives_received: u64, pub rate_limit_bytes_per_sec: Option, pub burst_bytes: Option, /// Client's authenticated Noise IK public key (base64). pub authenticated_key: String, /// Registered client ID from the client registry. pub registered_client_id: String, /// Real client IP:port (from PROXY protocol header or direct TCP connection). pub remote_addr: Option, } /// Server statistics. #[derive(Debug, Clone, Serialize, Default)] #[serde(rename_all = "camelCase")] pub struct ServerStatistics { pub bytes_sent: u64, pub bytes_received: u64, pub packets_sent: u64, pub packets_received: u64, pub keepalives_sent: u64, pub keepalives_received: u64, pub uptime_seconds: u64, pub active_clients: u64, pub total_connections: u64, } /// The forwarding engine determines how decrypted IP packets are routed. pub enum ForwardingEngine { /// Kernel TUN device — packets written to the TUN, kernel handles routing. Tun(tokio::io::WriteHalf), /// Userspace NAT — packets sent to smoltcp-based NAT engine via channel. Socket(mpsc::Sender>), /// Testing/monitoring — packets are counted but not forwarded. Testing, } /// Shared server state. 
pub struct ServerState { pub config: ServerConfig, pub ip_pool: Mutex, pub clients: RwLock>, pub stats: RwLock, pub rate_limiters: Mutex>, pub mtu_config: MtuConfig, pub started_at: std::time::Instant, pub client_registry: RwLock, /// The forwarding engine for decrypted IP packets. pub forwarding_engine: Mutex, /// Routing table: assigned VPN IP → channel sender for return packets. pub tun_routes: RwLock>>>, /// Shutdown signal for the forwarding background task (TUN reader or NAT engine). pub tun_shutdown: mpsc::Sender<()>, } /// The VPN server. pub struct VpnServer { state: Option>, shutdown_tx: Option>, } impl VpnServer { pub fn new() -> Self { Self { state: None, shutdown_tx: None, } } pub async fn start(&mut self, config: ServerConfig) -> Result<()> { if self.state.is_some() { anyhow::bail!("Server is already running"); } let ip_pool = IpPool::new(&config.subnet)?; if config.enable_nat.unwrap_or(false) { if let Err(e) = crate::network::enable_ip_forwarding() { warn!("Failed to enable IP forwarding: {}", e); } if let Ok(iface) = crate::network::get_default_interface() { if let Err(e) = crate::network::setup_nat(&config.subnet, &iface).await { warn!("Failed to setup NAT: {}", e); } } } let link_mtu = config.mtu.unwrap_or(1420); let mode = config.forwarding_mode.as_deref().unwrap_or("testing"); let gateway_ip = ip_pool.gateway_addr(); // Create forwarding engine based on mode enum ForwardingSetup { Tun { writer: tokio::io::WriteHalf, reader: tokio::io::ReadHalf, shutdown_rx: mpsc::Receiver<()>, }, Socket { packet_tx: mpsc::Sender>, packet_rx: mpsc::Receiver>, shutdown_rx: mpsc::Receiver<()>, }, Testing, } let (setup, fwd_shutdown_tx) = match mode { "tun" => { let tun_config = TunConfig { name: "svpn0".to_string(), address: gateway_ip, netmask: Ipv4Addr::new(255, 255, 255, 0), mtu: link_mtu, }; let tun_device = tunnel::create_tun(&tun_config)?; tunnel::add_route(&config.subnet, &tun_config.name).await?; let (reader, writer) = tokio::io::split(tun_device); let (tx, 
rx) = mpsc::channel::<()>(1); (ForwardingSetup::Tun { writer, reader, shutdown_rx: rx }, tx) } "socket" => { info!("Starting userspace NAT forwarding (no root required)"); let (packet_tx, packet_rx) = mpsc::channel::>(4096); let (tx, rx) = mpsc::channel::<()>(1); (ForwardingSetup::Socket { packet_tx, packet_rx, shutdown_rx: rx }, tx) } _ => { info!("Forwarding disabled (testing/monitoring mode)"); let (tx, _rx) = mpsc::channel::<()>(1); (ForwardingSetup::Testing, tx) } }; // Compute effective MTU from overhead let overhead = TunnelOverhead::default_overhead(); let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu)); // Build client registry from config let registry = ClientRegistry::from_entries( config.clients.clone().unwrap_or_default() )?; info!("Client registry loaded with {} entries", registry.len()); let state = Arc::new(ServerState { config: config.clone(), ip_pool: Mutex::new(ip_pool), clients: RwLock::new(HashMap::new()), stats: RwLock::new(ServerStatistics::default()), rate_limiters: Mutex::new(HashMap::new()), mtu_config, started_at: std::time::Instant::now(), client_registry: RwLock::new(registry), forwarding_engine: Mutex::new(ForwardingEngine::Testing), tun_routes: RwLock::new(HashMap::new()), tun_shutdown: fwd_shutdown_tx, }); // Spawn the forwarding background task and set the engine match setup { ForwardingSetup::Tun { writer, reader, shutdown_rx } => { *state.forwarding_engine.lock().await = ForwardingEngine::Tun(writer); let tun_state = state.clone(); tokio::spawn(async move { if let Err(e) = run_tun_reader(tun_state, reader, shutdown_rx).await { error!("TUN reader error: {}", e); } }); } ForwardingSetup::Socket { packet_tx, packet_rx, shutdown_rx } => { *state.forwarding_engine.lock().await = ForwardingEngine::Socket(packet_tx); let nat_engine = crate::userspace_nat::NatEngine::new( gateway_ip, link_mtu as usize, state.clone(), ); tokio::spawn(async move { if let Err(e) = nat_engine.run(packet_rx, shutdown_rx).await { 
error!("NAT engine error: {}", e); } }); } ForwardingSetup::Testing => {} } let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); self.state = Some(state.clone()); self.shutdown_tx = Some(shutdown_tx); let transport_mode = config.transport_mode.as_deref().unwrap_or("both"); let listen_addr = config.listen_addr.clone(); match transport_mode { "quic" => { let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone()); let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30); tokio::spawn(async move { if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await { error!("QUIC listener error: {}", e); } }); } "both" => { let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone()); let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30); let state2 = state.clone(); let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1); // Store second shutdown sender so both listeners stop let shutdown_tx_orig = self.shutdown_tx.take().unwrap(); let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1); self.shutdown_tx = Some(combined_tx); // Forward combined shutdown to both listeners tokio::spawn(async move { combined_rx.recv().await; let _ = shutdown_tx_orig.send(()).await; let _ = shutdown_tx2.send(()).await; }); tokio::spawn(async move { if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await { error!("WebSocket listener error: {}", e); } }); tokio::spawn(async move { if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await { error!("QUIC listener error: {}", e); } }); } _ => { // "websocket" (default) tokio::spawn(async move { if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await { error!("Server listener error: {}", e); } }); } } info!("VPN server started (transport: {})", transport_mode); Ok(()) } pub async fn stop(&mut self) -> Result<()> { if let Some(ref state) = self.state { let mode = 
state.config.forwarding_mode.as_deref().unwrap_or("testing"); match mode { "tun" => { let _ = state.tun_shutdown.send(()).await; *state.forwarding_engine.lock().await = ForwardingEngine::Testing; if let Err(e) = tunnel::remove_route(&state.config.subnet, "svpn0").await { warn!("Failed to remove TUN route: {}", e); } } "socket" => { let _ = state.tun_shutdown.send(()).await; *state.forwarding_engine.lock().await = ForwardingEngine::Testing; } _ => {} } // Clean up NAT rules if state.config.enable_nat.unwrap_or(false) { if let Ok(iface) = crate::network::get_default_interface() { if let Err(e) = crate::network::remove_nat(&state.config.subnet, &iface).await { warn!("Failed to remove NAT rules: {}", e); } } } } if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()).await; } self.state = None; info!("VPN server stopped"); Ok(()) } pub fn get_status(&self) -> serde_json::Value { if let Some(ref state) = self.state { serde_json::json!({ "state": "connected", "connectedSince": format!("{:?}", state.started_at.elapsed()), }) } else { serde_json::json!({ "state": "disconnected" }) } } pub async fn get_statistics(&self) -> ServerStatistics { if let Some(ref state) = self.state { let mut stats = state.stats.read().await.clone(); stats.uptime_seconds = state.started_at.elapsed().as_secs(); stats.active_clients = state.clients.read().await.len() as u64; stats } else { ServerStatistics::default() } } pub async fn list_clients(&self) -> Vec { if let Some(ref state) = self.state { state.clients.read().await.values().cloned().collect() } else { Vec::new() } } pub async fn disconnect_client(&self, client_id: &str) -> Result<()> { if let Some(ref state) = self.state { let mut clients = state.clients.write().await; if let Some(client) = clients.remove(client_id) { let ip: Ipv4Addr = client.assigned_ip.parse()?; state.ip_pool.lock().await.release(&ip); state.rate_limiters.lock().await.remove(client_id); info!("Client {} disconnected", client_id); } } Ok(()) } /// Set a rate 
limit for a specific client. pub async fn set_client_rate_limit( &self, client_id: &str, rate_bytes_per_sec: u64, burst_bytes: u64, ) -> Result<()> { if let Some(ref state) = self.state { let mut limiters = state.rate_limiters.lock().await; if let Some(limiter) = limiters.get_mut(client_id) { limiter.update_limits(rate_bytes_per_sec, burst_bytes); } else { limiters.insert( client_id.to_string(), TokenBucket::new(rate_bytes_per_sec, burst_bytes), ); } // Update client info let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(client_id) { info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec); info.burst_bytes = Some(burst_bytes); } } Ok(()) } /// Remove rate limit for a specific client (unlimited). pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> { if let Some(ref state) = self.state { state.rate_limiters.lock().await.remove(client_id); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(client_id) { info.rate_limit_bytes_per_sec = None; info.burst_bytes = None; } } Ok(()) } // ── Client Registry (Hub) Methods ─────────────────────────────────── /// Create a new client entry. Generates keypairs and assigns an IP. /// Returns a JSON value with the full config bundle including secrets. pub async fn create_client(&self, partial: serde_json::Value) -> Result { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; let client_id = partial.get("clientId") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("clientId is required"))? 
.to_string(); // Generate Noise IK keypair for the client let (noise_pub, noise_priv) = crypto::generate_keypair()?; // Generate WireGuard keypair for the client let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair(); // Allocate a VPN IP let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?; // Build entry from partial + generated values let entry = ClientEntry { client_id: client_id.clone(), public_key: noise_pub.clone(), wg_public_key: Some(wg_pub.clone()), security: serde_json::from_value( partial.get("security").cloned().unwrap_or(serde_json::Value::Null) ).ok(), priority: partial.get("priority").and_then(|v| v.as_u64()).map(|v| v as u32), enabled: partial.get("enabled").and_then(|v| v.as_bool()).or(Some(true)), tags: partial.get("tags").and_then(|v| { v.as_array().map(|a| a.iter().filter_map(|s| s.as_str().map(String::from)).collect()) }), description: partial.get("description").and_then(|v| v.as_str()).map(String::from), expires_at: partial.get("expiresAt").and_then(|v| v.as_str()).map(String::from), assigned_ip: Some(assigned_ip.to_string()), }; // Add to registry state.client_registry.write().await.add(entry.clone())?; // Build SmartVPN client config let smartvpn_config = serde_json::json!({ "serverUrl": format!("wss://{}", state.config.listen_addr.replace("0.0.0.0", "localhost")), "serverPublicKey": state.config.public_key, "clientPrivateKey": noise_priv, "clientPublicKey": noise_pub, "dns": state.config.dns, "mtu": state.config.mtu, "keepaliveIntervalSecs": state.config.keepalive_interval_secs, }); // Build WireGuard config string let wg_config = format!( "[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n", wg_priv, assigned_ip, state.config.dns.as_ref() .map(|d| format!("DNS = {}", d.join(", "))) .unwrap_or_default(), state.config.public_key, state.config.listen_addr, ); let entry_json = serde_json::to_value(&entry)?; Ok(serde_json::json!({ 
"entry": entry_json, "smartvpnConfig": smartvpn_config, "wireguardConfig": wg_config, "secrets": { "noisePrivateKey": noise_priv, "wgPrivateKey": wg_priv, } })) } /// Remove a registered client from the registry (and disconnect if connected). pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; let entry = state.client_registry.write().await.remove(client_id)?; // Release the IP if assigned if let Some(ref ip_str) = entry.assigned_ip { if let Ok(ip) = ip_str.parse::() { state.ip_pool.lock().await.release(&ip); } } // Disconnect if currently connected let _ = self.disconnect_client(client_id).await; Ok(()) } /// Get a registered client by ID. pub async fn get_registered_client(&self, client_id: &str) -> Result { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; let registry = state.client_registry.read().await; let entry = registry.get_by_id(client_id) .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; Ok(serde_json::to_value(entry)?) } /// List all registered clients. pub async fn list_registered_clients(&self) -> Vec { if let Some(ref state) = self.state { state.client_registry.read().await.list().into_iter().cloned().collect() } else { Vec::new() } } /// Update a registered client's fields. 
pub async fn update_registered_client(&self, client_id: &str, update: serde_json::Value) -> Result<()> { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; state.client_registry.write().await.update(client_id, |entry| { if let Some(security) = update.get("security") { entry.security = serde_json::from_value(security.clone()).ok(); } if let Some(priority) = update.get("priority").and_then(|v| v.as_u64()) { entry.priority = Some(priority as u32); } if let Some(enabled) = update.get("enabled").and_then(|v| v.as_bool()) { entry.enabled = Some(enabled); } if let Some(tags) = update.get("tags").and_then(|v| v.as_array()) { entry.tags = Some(tags.iter().filter_map(|s| s.as_str().map(String::from)).collect()); } if let Some(desc) = update.get("description").and_then(|v| v.as_str()) { entry.description = Some(desc.to_string()); } if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) { entry.expires_at = Some(expires.to_string()); } })?; Ok(()) } /// Enable a registered client. pub async fn enable_client(&self, client_id: &str) -> Result<()> { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; state.client_registry.write().await.update(client_id, |entry| { entry.enabled = Some(true); }) } /// Disable a registered client (also disconnects if connected). pub async fn disable_client(&self, client_id: &str) -> Result<()> { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; state.client_registry.write().await.update(client_id, |entry| { entry.enabled = Some(false); })?; // Disconnect if currently connected let _ = self.disconnect_client(client_id).await; Ok(()) } /// Rotate a client's keys. Returns a new config bundle with fresh keypairs. 
pub async fn rotate_client_key(&self, client_id: &str) -> Result { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; let (noise_pub, noise_priv) = crypto::generate_keypair()?; let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair(); state.client_registry.write().await.rotate_key( client_id, noise_pub.clone(), Some(wg_pub.clone()), )?; // Disconnect existing connection (old key is no longer valid) let _ = self.disconnect_client(client_id).await; // Get updated entry for the config bundle let entry_json = self.get_registered_client(client_id).await?; let assigned_ip = entry_json.get("assignedIp") .and_then(|v| v.as_str()) .unwrap_or("0.0.0.0"); let smartvpn_config = serde_json::json!({ "serverUrl": format!("wss://{}", state.config.listen_addr.replace("0.0.0.0", "localhost")), "serverPublicKey": state.config.public_key, "clientPrivateKey": noise_priv, "clientPublicKey": noise_pub, "dns": state.config.dns, "mtu": state.config.mtu, "keepaliveIntervalSecs": state.config.keepalive_interval_secs, }); let wg_config = format!( "[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n", wg_priv, assigned_ip, state.config.dns.as_ref() .map(|d| format!("DNS = {}", d.join(", "))) .unwrap_or_default(), state.config.public_key, state.config.listen_addr, ); Ok(serde_json::json!({ "entry": entry_json, "smartvpnConfig": smartvpn_config, "wireguardConfig": wg_config, "secrets": { "noisePrivateKey": noise_priv, "wgPrivateKey": wg_priv, } })) } /// Export a client config (without secrets) in the specified format. 
pub async fn export_client_config(&self, client_id: &str, format: &str) -> Result { let state = self.state.as_ref() .ok_or_else(|| anyhow::anyhow!("Server not running"))?; let registry = state.client_registry.read().await; let entry = registry.get_by_id(client_id) .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?; match format { "smartvpn" => { Ok(serde_json::json!({ "config": { "serverUrl": format!("wss://{}", state.config.listen_addr.replace("0.0.0.0", "localhost")), "serverPublicKey": state.config.public_key, "clientPublicKey": entry.public_key, "dns": state.config.dns, "mtu": state.config.mtu, "keepaliveIntervalSecs": state.config.keepalive_interval_secs, } })) } "wireguard" => { let assigned_ip = entry.assigned_ip.as_deref().unwrap_or("0.0.0.0"); let config = format!( "[Interface]\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n", assigned_ip, state.config.dns.as_ref() .map(|d| format!("DNS = {}", d.join(", "))) .unwrap_or_default(), state.config.public_key, state.config.listen_addr, ); Ok(serde_json::json!({ "config": config })) } _ => anyhow::bail!("Unknown format: {}", format), } } } /// WebSocket listener — accepts TCP connections, optionally parses PROXY protocol v2, /// upgrades to WS, then hands off to `handle_client_connection`. async fn run_ws_listener( state: Arc, listen_addr: String, shutdown_rx: &mut mpsc::Receiver<()>, ) -> Result<()> { let listener = TcpListener::bind(&listen_addr).await?; info!("WebSocket server listening on {}", listen_addr); loop { tokio::select! 
{ accept = listener.accept() => { match accept { Ok((mut tcp_stream, tcp_addr)) => { info!("New connection from {}", tcp_addr); let state = state.clone(); tokio::spawn(async move { // Phase 0: Parse PROXY protocol v2 header if enabled let remote_addr = if state.config.proxy_protocol.unwrap_or(false) { match crate::proxy_protocol::read_proxy_header(&mut tcp_stream).await { Ok(header) if header.is_local => { info!("PP v2 LOCAL probe from {}", tcp_addr); return; // Health check — close gracefully } Ok(header) => { info!("PP v2: real client {} (via {})", header.src_addr, tcp_addr); Some(header.src_addr) } Err(e) => { warn!("PP v2 parse failed from {}: {}", tcp_addr, e); return; // Drop connection } } } else { Some(tcp_addr) // Direct connection — use TCP SocketAddr }; // Phase 1: Server-level connection IP block list (pre-handshake) if let (Some(ref block_list), Some(ref addr)) = (&state.config.connection_ip_block_list, &remote_addr) { if !block_list.is_empty() { if let std::net::IpAddr::V4(v4) = addr.ip() { if acl::is_connection_blocked(v4, block_list) { warn!("Connection blocked by server IP block list: {}", addr); return; } } } } // Phase 2: WebSocket upgrade + VPN handshake match transport::accept_connection(tcp_stream).await { Ok(ws) => { let (sink, stream) = transport_trait::split_ws(ws); if let Err(e) = handle_client_connection( state, Box::new(sink), Box::new(stream), remote_addr, ).await { warn!("Client connection error: {}", e); } } Err(e) => { warn!("WebSocket upgrade failed: {}", e); } } }); } Err(e) => { error!("Accept error: {}", e); } } } _ = shutdown_rx.recv() => { info!("Shutdown signal received"); break; } } } Ok(()) } /// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic /// `handle_client_connection`. 
async fn run_quic_listener( state: Arc, listen_addr: String, idle_timeout_secs: u64, shutdown_rx: &mut mpsc::Receiver<()>, ) -> Result<()> { // Generate or use configured TLS certificate for QUIC let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) = (&state.config.tls_cert, &state.config.tls_key) { // Parse PEM certificates let certs: Vec> = rustls_pemfile::certs(&mut cert_pem.as_bytes()) .collect::, _>>()?; let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())? .ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?; (certs, key) } else { // Generate self-signed certificate let (certs, key) = quic_transport::generate_self_signed_cert()?; info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0])); (certs, key) }; let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig { listen_addr, cert_chain, private_key, idle_timeout_secs, })?; loop { tokio::select! { incoming = endpoint.accept() => { match incoming { Some(incoming) => { let state = state.clone(); tokio::spawn(async move { match incoming.await { Ok(conn) => { let remote = conn.remote_address(); info!("New QUIC connection from {}", remote); match quic_transport::accept_quic_connection(conn).await { Ok((sink, stream)) => { if let Err(e) = handle_client_connection( state, Box::new(sink), Box::new(stream), Some(remote), ).await { warn!("QUIC client error: {}", e); } } Err(e) => { warn!("QUIC stream accept failed: {}", e); } } } Err(e) => { warn!("QUIC handshake failed: {}", e); } } }); } None => { info!("QUIC endpoint closed"); break; } } } _ = shutdown_rx.recv() => { info!("QUIC shutdown signal received"); endpoint.close(0u32.into(), b"shutdown"); break; } } } Ok(()) } /// TUN reader task: reads IP packets from the TUN device and dispatches them /// to the correct client via the routing table. 
async fn run_tun_reader( state: Arc, mut tun_reader: tokio::io::ReadHalf, mut shutdown_rx: mpsc::Receiver<()>, ) -> Result<()> { use tokio::io::AsyncReadExt; let mut buf = vec![0u8; 65536]; loop { tokio::select! { result = tun_reader.read(&mut buf) => { let n = match result { Ok(0) => { info!("TUN reader: device closed"); break; } Ok(n) => n, Err(e) => { error!("TUN reader error: {}", e); break; } }; // Extract destination IP from the raw IP packet let dst_ip = match tunnel::extract_dst_ip(&buf[..n]) { Some(std::net::IpAddr::V4(v4)) => v4, _ => continue, // IPv6 or malformed — skip }; // Look up client by destination IP let routes = state.tun_routes.read().await; if let Some(sender) = routes.get(&dst_ip) { if sender.try_send(buf[..n].to_vec()).is_err() { // Channel full or closed — drop packet (correct for IP best-effort) } } } _ = shutdown_rx.recv() => { info!("TUN reader shutting down"); break; } } } Ok(()) } /// Transport-agnostic client handler. Performs the Noise IK handshake, authenticates /// the client against the registry, and runs the main packet forwarding loop. async fn handle_client_connection( state: Arc, mut sink: Box, mut stream: Box, remote_addr: Option, ) -> Result<()> { let server_private_key = base64::Engine::decode( &base64::engine::general_purpose::STANDARD, &state.config.private_key, )?; // Noise IK handshake (server side = responder) let mut responder = crypto::create_responder(&server_private_key)?; let mut buf = vec![0u8; 65535]; // Receive handshake init (-> e, es, s, ss) let init_msg = match stream.recv_reliable().await? { Some(data) => data, None => anyhow::bail!("Connection closed before handshake"), }; let mut frame_buf = BytesMut::from(&init_msg[..]); let frame = ::decode(&mut FrameCodec, &mut frame_buf)? 
.ok_or_else(|| anyhow::anyhow!("Incomplete handshake frame"))?; if frame.packet_type != PacketType::HandshakeInit { anyhow::bail!("Expected HandshakeInit, got {:?}", frame.packet_type); } responder.read_message(&frame.payload, &mut buf)?; // Extract client's static public key BEFORE entering transport mode let client_pub_key_bytes = responder .get_remote_static() .ok_or_else(|| anyhow::anyhow!("IK handshake: no client static key received"))? .to_vec(); let client_pub_key_b64 = base64::Engine::encode( &base64::engine::general_purpose::STANDARD, &client_pub_key_bytes, ); // Verify client against registry let (registered_client_id, client_security) = { let registry = state.client_registry.read().await; if !registry.is_authorized(&client_pub_key_b64) { warn!("Rejecting unauthorized client with key {}", &client_pub_key_b64[..8]); // Send handshake response but then disconnect let len = responder.write_message(&[], &mut buf)?; let response_frame = Frame { packet_type: PacketType::HandshakeResp, payload: buf[..len].to_vec(), }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?; sink.send_reliable(frame_bytes.to_vec()).await?; // Send disconnect frame let disconnect_frame = Frame { packet_type: PacketType::Disconnect, payload: Vec::new(), }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?; let _ = sink.send_reliable(frame_bytes.to_vec()).await; anyhow::bail!("Client not authorized"); } let entry = registry.get_by_key(&client_pub_key_b64).unwrap(); (entry.client_id.clone(), entry.security.clone()) }; // Complete handshake (<- e, ee, se) let len = responder.write_message(&[], &mut buf)?; let response_payload = buf[..len].to_vec(); let response_frame = Frame { packet_type: PacketType::HandshakeResp, payload: response_payload, }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?; 
sink.send_reliable(frame_bytes.to_vec()).await?; let mut noise_transport = responder.into_transport_mode()?; // Connection-level ACL: check real client IP against per-client ipAllowList/ipBlockList if let (Some(ref sec), Some(ref addr)) = (&client_security, &remote_addr) { if let std::net::IpAddr::V4(v4) = addr.ip() { if !acl::is_source_allowed( v4, sec.ip_allow_list.as_deref(), sec.ip_block_list.as_deref(), ) { warn!("Connection-level ACL denied client {} from IP {}", registered_client_id, addr); let disconnect_frame = Frame { packet_type: PacketType::Disconnect, payload: Vec::new() }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?; let _ = sink.send_reliable(frame_bytes.to_vec()).await; anyhow::bail!("Connection denied: source IP {} not allowed for client {}", addr, registered_client_id); } } } // Use the registered client ID as the connection ID let client_id = registered_client_id.clone(); // Allocate IP let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?; // Create return-packet channel for forwarding engine -> client let (tun_return_tx, mut tun_return_rx) = mpsc::channel::>(256); let fwd_mode = state.config.forwarding_mode.as_deref().unwrap_or("testing"); let forwarding_active = fwd_mode == "tun" || fwd_mode == "socket"; if forwarding_active { state.tun_routes.write().await.insert(assigned_ip, tun_return_tx); } // Determine rate limits: per-client security overrides server defaults let (rate_limit, burst) = if let Some(ref sec) = client_security { if let Some(ref rl) = sec.rate_limit { (Some(rl.bytes_per_sec), Some(rl.burst_bytes)) } else { (state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes) } } else { (state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes) }; // Register connected client let client_info = ClientInfo { client_id: client_id.clone(), assigned_ip: assigned_ip.to_string(), connected_since: timestamp_now(), bytes_sent: 0, 
bytes_received: 0, packets_dropped: 0, bytes_dropped: 0, last_keepalive_at: None, keepalives_received: 0, rate_limit_bytes_per_sec: rate_limit, burst_bytes: burst, authenticated_key: client_pub_key_b64.clone(), registered_client_id: registered_client_id.clone(), remote_addr: remote_addr.map(|a| a.to_string()), }; state.clients.write().await.insert(client_id.clone(), client_info); // Set up rate limiter if let (Some(rate), Some(burst)) = (rate_limit, burst) { state .rate_limiters .lock() .await .insert(client_id.clone(), TokenBucket::new(rate, burst)); } { let mut stats = state.stats.write().await; stats.total_connections += 1; } // Send assigned IP info (encrypted), include effective MTU let ip_info = serde_json::json!({ "assignedIp": assigned_ip.to_string(), "gateway": state.ip_pool.lock().await.gateway_addr().to_string(), "mtu": state.config.mtu.unwrap_or(1420), "effectiveMtu": state.mtu_config.effective_mtu, }); let ip_info_bytes = serde_json::to_vec(&ip_info)?; let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?; let encrypted_info = Frame { packet_type: PacketType::IpPacket, payload: buf[..len].to_vec(), }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?; sink.send_reliable(frame_bytes.to_vec()).await?; info!("Client {} ({}) connected with IP {} from {}", registered_client_id, &client_pub_key_b64[..8], assigned_ip, remote_addr.map(|a| a.to_string()).unwrap_or_else(|| "unknown".to_string())); // Main packet loop with dead-peer detection let mut last_activity = tokio::time::Instant::now(); loop { tokio::select! 
{ msg = stream.recv_reliable() => { match msg { Ok(Some(data)) => { last_activity = tokio::time::Instant::now(); let mut frame_buf = BytesMut::from(&data[..]); match ::decode(&mut FrameCodec, &mut frame_buf) { Ok(Some(frame)) => match frame.packet_type { PacketType::IpPacket => { match noise_transport.read_message(&frame.payload, &mut buf) { Ok(len) => { // ACL check on decrypted packet if let Some(ref sec) = client_security { if len >= 20 { // Extract src/dst from IPv4 header let src = Ipv4Addr::new(buf[12], buf[13], buf[14], buf[15]); let dst = Ipv4Addr::new(buf[16], buf[17], buf[18], buf[19]); let acl_result = acl::check_acl(sec, src, dst); if acl_result != acl::AclResult::Allow { let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.packets_dropped += 1; info.bytes_dropped += len as u64; } continue; } } } // Rate limiting check let allowed = { let mut limiters = state.rate_limiters.lock().await; if let Some(limiter) = limiters.get_mut(&client_id) { limiter.try_consume(len) } else { true } }; if !allowed { let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.packets_dropped += 1; info.bytes_dropped += len as u64; } continue; } let mut stats = state.stats.write().await; stats.bytes_received += len as u64; stats.packets_received += 1; // Update per-client stats drop(stats); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.bytes_received += len as u64; } drop(clients); // Forward decrypted packet via the active engine { let mut engine = state.forwarding_engine.lock().await; match &mut *engine { ForwardingEngine::Tun(writer) => { use tokio::io::AsyncWriteExt; if let Err(e) = writer.write_all(&buf[..len]).await { warn!("TUN write error for client {}: {}", client_id, e); } } ForwardingEngine::Socket(sender) => { let _ = sender.try_send(buf[..len].to_vec()); } ForwardingEngine::Testing => {} } } } Err(e) => { 
warn!("Decrypt error from {}: {}", client_id, e); break; } } } PacketType::Keepalive => { // Echo the keepalive payload back in the ACK let ack_frame = Frame { packet_type: PacketType::KeepaliveAck, payload: frame.payload.clone(), }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?; sink.send_reliable(frame_bytes.to_vec()).await?; let mut stats = state.stats.write().await; stats.keepalives_received += 1; stats.keepalives_sent += 1; // Update per-client keepalive tracking drop(stats); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.last_keepalive_at = Some(timestamp_now()); info.keepalives_received += 1; } } PacketType::Disconnect => { info!("Client {} sent disconnect", client_id); break; } _ => { warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type); } }, Ok(None) => { warn!("Incomplete frame from {}", client_id); } Err(e) => { warn!("Frame decode error from {}: {}", client_id, e); break; } } } Ok(None) => { info!("Client {} connection closed", client_id); break; } Err(e) => { warn!("Transport error from {}: {}", client_id, e); break; } } } // Return packets from TUN device destined for this client Some(packet) = tun_return_rx.recv() => { let pkt_len = packet.len(); match noise_transport.write_message(&packet, &mut buf) { Ok(len) => { let frame = Frame { packet_type: PacketType::IpPacket, payload: buf[..len].to_vec(), }; let mut frame_bytes = BytesMut::new(); >::encode( &mut FrameCodec, frame, &mut frame_bytes )?; sink.send_reliable(frame_bytes.to_vec()).await?; // Update stats let mut stats = state.stats.write().await; stats.bytes_sent += pkt_len as u64; stats.packets_sent += 1; drop(stats); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.bytes_sent += pkt_len as u64; } } Err(e) => { warn!("Noise encrypt error for return packet to {}: {}", client_id, e); break; } } } _ = 
tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => {
                warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs());
                break;
            }
        }
    }

    // Cleanup
    // NOTE(review): this teardown only runs when the packet loop above exits via
    // `break`. The `?` operators between the IP allocation and this point (sending
    // the encrypted IP info, keepalive ACKs, and return packets) return early and
    // skip it, leaving the pool address, the `tun_routes` / `clients` entries and
    // the rate limiter behind. Consider a drop guard so cleanup always runs.
    if forwarding_active {
        state.tun_routes.write().await.remove(&assigned_ip);
    }
    state.clients.write().await.remove(&client_id);
    state.ip_pool.lock().await.release(&assigned_ip);
    state.rate_limiters.lock().await.remove(&client_id);
    info!("Client {} disconnected, released IP {}", client_id, assigned_ip);
    Ok(())
}

/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// rendered as a decimal string (e.g. `"1700000000"`).
///
/// If the system clock reports a time before the epoch, the elapsed duration
/// falls back to zero and `"0"` is returned instead of an error. In this file
/// it populates `ClientInfo::connected_since` and `last_keepalive_at`.
fn timestamp_now() -> String {
    use std::time::SystemTime;
    SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}