use anyhow::Result;
use bytes::BytesMut;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tracing::{info, error, warn};

use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::mtu::{MtuConfig, TunnelOverhead};
use crate::network::IpPool;
use crate::ratelimit::TokenBucket;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;

/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
///
/// A client session that receives nothing for this long is torn down.
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);

/// Server configuration (matches TS IVpnServerConfig).
///
/// Deserialized with camelCase keys (e.g. `listenAddr`, `tlsCert`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerConfig {
    /// WebSocket listen address (host:port); also used for QUIC when
    /// `quic_listen_addr` is not set.
    pub listen_addr: String,
    /// PEM certificate chain for the QUIC listener. Used only when `tls_key`
    /// is also set; otherwise a self-signed certificate is generated.
    pub tls_cert: Option<String>,
    /// PEM private key matching `tls_cert`.
    pub tls_key: Option<String>,
    /// Server static private key, base64 (standard alphabet) encoded.
    pub private_key: String,
    pub public_key: String,
    /// Client address pool, parsed by `IpPool::new`.
    pub subnet: String,
    pub dns: Option<Vec<String>>,
    /// Link MTU reported to clients (default: 1420).
    pub mtu: Option<u16>,
    pub keepalive_interval_secs: Option<u64>,
    /// Enable IP forwarding and NAT when the server starts (default: false).
    pub enable_nat: Option<bool>,
    /// Default rate limit for new clients (bytes/sec). None = unlimited.
    /// Only applied when `default_burst_bytes` is also set.
    pub default_rate_limit_bytes_per_sec: Option<u64>,
    /// Default burst size for new clients (bytes). None = unlimited.
    /// Only applied when `default_rate_limit_bytes_per_sec` is also set.
    pub default_burst_bytes: Option<u64>,
    /// Transport mode: "websocket", "quic", or "both" (default when unset).
    pub transport_mode: Option<String>,
    /// QUIC listen address (host:port). Defaults to listen_addr.
    pub quic_listen_addr: Option<String>,
    /// QUIC idle timeout in seconds (default: 30).
    pub quic_idle_timeout_secs: Option<u64>,
}

/// Information about a connected client, as reported by `VpnServer::list_clients`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientInfo {
    pub client_id: String,
    /// Tunnel IPv4 address allocated from the server's pool.
    pub assigned_ip: String,
    /// Connection time as Unix seconds, rendered as a decimal string.
    pub connected_since: String,
    pub bytes_sent: u64,
    /// Decrypted payload bytes accepted from this client.
    pub bytes_received: u64,
    /// Packets dropped by this client's rate limiter.
    pub packets_dropped: u64,
    /// Decrypted payload bytes dropped by this client's rate limiter.
    pub bytes_dropped: u64,
    /// Unix seconds (decimal string) of the most recent keepalive, if any.
    pub last_keepalive_at: Option<String>,
    pub keepalives_received: u64,
    /// Active rate limit in bytes/sec; None = unlimited.
    pub rate_limit_bytes_per_sec: Option<u64>,
    /// Active burst size in bytes; None = unlimited.
    pub burst_bytes: Option<u64>,
}

/// Server statistics.
///
/// `uptime_seconds` and `active_clients` are filled in when the statistics
/// are read via `VpnServer::get_statistics`.
#[derive(Debug, Clone, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ServerStatistics {
    pub bytes_sent: u64,
    pub bytes_received: u64,
    pub packets_sent: u64,
    pub packets_received: u64,
    pub keepalives_sent: u64,
    pub keepalives_received: u64,
    pub uptime_seconds: u64,
    pub active_clients: u64,
    pub total_connections: u64,
}

/// Shared server state, handed to every listener and client task via `Arc`.
pub struct ServerState {
    pub config: ServerConfig,
    /// Tunnel address allocator.
    pub ip_pool: Mutex<IpPool>,
    /// Registered clients keyed by client id.
    pub clients: RwLock<HashMap<String, ClientInfo>>,
    pub stats: RwLock<ServerStatistics>,
    /// Per-client token buckets keyed by client id; absent = unlimited.
    pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
    pub mtu_config: MtuConfig,
    pub started_at: std::time::Instant,
}

/// The VPN server: owns the shared state and the shutdown handle of the
/// listener task(s) while running.
pub struct VpnServer { state: Option>, shutdown_tx: Option>, } impl VpnServer { pub fn new() -> Self { Self { state: None, shutdown_tx: None, } } pub async fn start(&mut self, config: ServerConfig) -> Result<()> { if self.state.is_some() { anyhow::bail!("Server is already running"); } let ip_pool = IpPool::new(&config.subnet)?; if config.enable_nat.unwrap_or(false) { if let Err(e) = crate::network::enable_ip_forwarding() { warn!("Failed to enable IP forwarding: {}", e); } if let Ok(iface) = crate::network::get_default_interface() { if let Err(e) = crate::network::setup_nat(&config.subnet, &iface).await { warn!("Failed to setup NAT: {}", e); } } } let link_mtu = config.mtu.unwrap_or(1420); // Compute effective MTU from overhead let overhead = TunnelOverhead::default_overhead(); let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu)); let state = Arc::new(ServerState { config: config.clone(), ip_pool: Mutex::new(ip_pool), clients: RwLock::new(HashMap::new()), stats: RwLock::new(ServerStatistics::default()), rate_limiters: Mutex::new(HashMap::new()), mtu_config, started_at: std::time::Instant::now(), }); let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); self.state = Some(state.clone()); self.shutdown_tx = Some(shutdown_tx); let transport_mode = config.transport_mode.as_deref().unwrap_or("both"); let listen_addr = config.listen_addr.clone(); match transport_mode { "quic" => { let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone()); let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30); tokio::spawn(async move { if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await { error!("QUIC listener error: {}", e); } }); } "both" => { let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone()); let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30); let state2 = state.clone(); let (shutdown_tx2, mut shutdown_rx2) = 
mpsc::channel::<()>(1); // Store second shutdown sender so both listeners stop let shutdown_tx_orig = self.shutdown_tx.take().unwrap(); let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1); self.shutdown_tx = Some(combined_tx); // Forward combined shutdown to both listeners tokio::spawn(async move { combined_rx.recv().await; let _ = shutdown_tx_orig.send(()).await; let _ = shutdown_tx2.send(()).await; }); tokio::spawn(async move { if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await { error!("WebSocket listener error: {}", e); } }); tokio::spawn(async move { if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await { error!("QUIC listener error: {}", e); } }); } _ => { // "websocket" (default) tokio::spawn(async move { if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await { error!("Server listener error: {}", e); } }); } } info!("VPN server started (transport: {})", transport_mode); Ok(()) } pub async fn stop(&mut self) -> Result<()> { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()).await; } self.state = None; info!("VPN server stopped"); Ok(()) } pub fn get_status(&self) -> serde_json::Value { if let Some(ref state) = self.state { serde_json::json!({ "state": "connected", "connectedSince": format!("{:?}", state.started_at.elapsed()), }) } else { serde_json::json!({ "state": "disconnected" }) } } pub async fn get_statistics(&self) -> ServerStatistics { if let Some(ref state) = self.state { let mut stats = state.stats.read().await.clone(); stats.uptime_seconds = state.started_at.elapsed().as_secs(); stats.active_clients = state.clients.read().await.len() as u64; stats } else { ServerStatistics::default() } } pub async fn list_clients(&self) -> Vec { if let Some(ref state) = self.state { state.clients.read().await.values().cloned().collect() } else { Vec::new() } } pub async fn disconnect_client(&self, client_id: &str) -> Result<()> { if let Some(ref state) = 
self.state { let mut clients = state.clients.write().await; if let Some(client) = clients.remove(client_id) { let ip: Ipv4Addr = client.assigned_ip.parse()?; state.ip_pool.lock().await.release(&ip); state.rate_limiters.lock().await.remove(client_id); info!("Client {} disconnected", client_id); } } Ok(()) } /// Set a rate limit for a specific client. pub async fn set_client_rate_limit( &self, client_id: &str, rate_bytes_per_sec: u64, burst_bytes: u64, ) -> Result<()> { if let Some(ref state) = self.state { let mut limiters = state.rate_limiters.lock().await; if let Some(limiter) = limiters.get_mut(client_id) { limiter.update_limits(rate_bytes_per_sec, burst_bytes); } else { limiters.insert( client_id.to_string(), TokenBucket::new(rate_bytes_per_sec, burst_bytes), ); } // Update client info let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(client_id) { info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec); info.burst_bytes = Some(burst_bytes); } } Ok(()) } /// Remove rate limit for a specific client (unlimited). pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> { if let Some(ref state) = self.state { state.rate_limiters.lock().await.remove(client_id); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(client_id) { info.rate_limit_bytes_per_sec = None; info.burst_bytes = None; } } Ok(()) } } /// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off /// to the transport-agnostic `handle_client_connection`. async fn run_ws_listener( state: Arc, listen_addr: String, shutdown_rx: &mut mpsc::Receiver<()>, ) -> Result<()> { let listener = TcpListener::bind(&listen_addr).await?; info!("WebSocket server listening on {}", listen_addr); loop { tokio::select! 
{ accept = listener.accept() => { match accept { Ok((stream, addr)) => { info!("New connection from {}", addr); let state = state.clone(); tokio::spawn(async move { match transport::accept_connection(stream).await { Ok(ws) => { let (sink, stream) = transport_trait::split_ws(ws); if let Err(e) = handle_client_connection( state, Box::new(sink), Box::new(stream), ).await { warn!("Client connection error: {}", e); } } Err(e) => { warn!("WebSocket upgrade failed: {}", e); } } }); } Err(e) => { error!("Accept error: {}", e); } } } _ = shutdown_rx.recv() => { info!("Shutdown signal received"); break; } } } Ok(()) } /// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic /// `handle_client_connection`. async fn run_quic_listener( state: Arc, listen_addr: String, idle_timeout_secs: u64, shutdown_rx: &mut mpsc::Receiver<()>, ) -> Result<()> { // Generate or use configured TLS certificate for QUIC let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) = (&state.config.tls_cert, &state.config.tls_key) { // Parse PEM certificates let certs: Vec> = rustls_pemfile::certs(&mut cert_pem.as_bytes()) .collect::, _>>()?; let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())? .ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?; (certs, key) } else { // Generate self-signed certificate let (certs, key) = quic_transport::generate_self_signed_cert()?; info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0])); (certs, key) }; let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig { listen_addr, cert_chain, private_key, idle_timeout_secs, })?; loop { tokio::select! 
{ incoming = endpoint.accept() => { match incoming { Some(incoming) => { let state = state.clone(); tokio::spawn(async move { match incoming.await { Ok(conn) => { let remote = conn.remote_address(); info!("New QUIC connection from {}", remote); match quic_transport::accept_quic_connection(conn).await { Ok((sink, stream)) => { if let Err(e) = handle_client_connection( state, Box::new(sink), Box::new(stream), ).await { warn!("QUIC client error: {}", e); } } Err(e) => { warn!("QUIC stream accept failed: {}", e); } } } Err(e) => { warn!("QUIC handshake failed: {}", e); } } }); } None => { info!("QUIC endpoint closed"); break; } } } _ = shutdown_rx.recv() => { info!("QUIC shutdown signal received"); endpoint.close(0u32.into(), b"shutdown"); break; } } } Ok(()) } /// Transport-agnostic client handler. Performs the Noise NK handshake, registers /// the client, and runs the main packet forwarding loop. async fn handle_client_connection( state: Arc, mut sink: Box, mut stream: Box, ) -> Result<()> { let client_id = uuid_v4(); let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?; let server_private_key = base64::Engine::decode( &base64::engine::general_purpose::STANDARD, &state.config.private_key, )?; let mut responder = crypto::create_responder(&server_private_key)?; let mut buf = vec![0u8; 65535]; // Receive handshake init let init_msg = match stream.recv_reliable().await? { Some(data) => data, None => anyhow::bail!("Connection closed before handshake"), }; let mut frame_buf = BytesMut::from(&init_msg[..]); let frame = ::decode(&mut FrameCodec, &mut frame_buf)? 
.ok_or_else(|| anyhow::anyhow!("Incomplete handshake frame"))?; if frame.packet_type != PacketType::HandshakeInit { anyhow::bail!("Expected HandshakeInit, got {:?}", frame.packet_type); } responder.read_message(&frame.payload, &mut buf)?; let len = responder.write_message(&[], &mut buf)?; let response_payload = buf[..len].to_vec(); let response_frame = Frame { packet_type: PacketType::HandshakeResp, payload: response_payload, }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?; sink.send_reliable(frame_bytes.to_vec()).await?; let mut noise_transport = responder.into_transport_mode()?; // Register client let default_rate = state.config.default_rate_limit_bytes_per_sec; let default_burst = state.config.default_burst_bytes; let client_info = ClientInfo { client_id: client_id.clone(), assigned_ip: assigned_ip.to_string(), connected_since: timestamp_now(), bytes_sent: 0, bytes_received: 0, packets_dropped: 0, bytes_dropped: 0, last_keepalive_at: None, keepalives_received: 0, rate_limit_bytes_per_sec: default_rate, burst_bytes: default_burst, }; state.clients.write().await.insert(client_id.clone(), client_info); // Set up rate limiter if defaults are configured if let (Some(rate), Some(burst)) = (default_rate, default_burst) { state .rate_limiters .lock() .await .insert(client_id.clone(), TokenBucket::new(rate, burst)); } { let mut stats = state.stats.write().await; stats.total_connections += 1; } // Send assigned IP info (encrypted), include effective MTU let ip_info = serde_json::json!({ "assignedIp": assigned_ip.to_string(), "gateway": state.ip_pool.lock().await.gateway_addr().to_string(), "mtu": state.config.mtu.unwrap_or(1420), "effectiveMtu": state.mtu_config.effective_mtu, }); let ip_info_bytes = serde_json::to_vec(&ip_info)?; let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?; let encrypted_info = Frame { packet_type: PacketType::IpPacket, payload: buf[..len].to_vec(), }; let mut frame_bytes = 
BytesMut::new(); >::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?; sink.send_reliable(frame_bytes.to_vec()).await?; info!("Client {} connected with IP {}", client_id, assigned_ip); // Main packet loop with dead-peer detection let mut last_activity = tokio::time::Instant::now(); loop { tokio::select! { msg = stream.recv_reliable() => { match msg { Ok(Some(data)) => { last_activity = tokio::time::Instant::now(); let mut frame_buf = BytesMut::from(&data[..]); match ::decode(&mut FrameCodec, &mut frame_buf) { Ok(Some(frame)) => match frame.packet_type { PacketType::IpPacket => { match noise_transport.read_message(&frame.payload, &mut buf) { Ok(len) => { // Rate limiting check let allowed = { let mut limiters = state.rate_limiters.lock().await; if let Some(limiter) = limiters.get_mut(&client_id) { limiter.try_consume(len) } else { true } }; if !allowed { let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.packets_dropped += 1; info.bytes_dropped += len as u64; } continue; } let mut stats = state.stats.write().await; stats.bytes_received += len as u64; stats.packets_received += 1; // Update per-client stats drop(stats); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.bytes_received += len as u64; } } Err(e) => { warn!("Decrypt error from {}: {}", client_id, e); break; } } } PacketType::Keepalive => { // Echo the keepalive payload back in the ACK let ack_frame = Frame { packet_type: PacketType::KeepaliveAck, payload: frame.payload.clone(), }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?; sink.send_reliable(frame_bytes.to_vec()).await?; let mut stats = state.stats.write().await; stats.keepalives_received += 1; stats.keepalives_sent += 1; // Update per-client keepalive tracking drop(stats); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.last_keepalive_at 
= Some(timestamp_now()); info.keepalives_received += 1; } } PacketType::Disconnect => { info!("Client {} sent disconnect", client_id); break; } _ => { warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type); } }, Ok(None) => { warn!("Incomplete frame from {}", client_id); } Err(e) => { warn!("Frame decode error from {}: {}", client_id, e); break; } } } Ok(None) => { info!("Client {} connection closed", client_id); break; } Err(e) => { warn!("Transport error from {}: {}", client_id, e); break; } } } _ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => { warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs()); break; } } } // Cleanup state.clients.write().await.remove(&client_id); state.ip_pool.lock().await.release(&assigned_ip); state.rate_limiters.lock().await.remove(&client_id); info!("Client {} disconnected, released IP {}", client_id, assigned_ip); Ok(()) } fn uuid_v4() -> String { use rand::Rng; let mut rng = rand::thread_rng(); let bytes: [u8; 16] = rng.gen(); format!( "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}", bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], ) } fn timestamp_now() -> String { use std::time::SystemTime; let duration = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or_default(); format!("{}", duration.as_secs()) }