use anyhow::Result;
use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn};

use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::mtu::{MtuConfig, TunnelOverhead};
use crate::network::IpPool;
use crate::ratelimit::TokenBucket;
use crate::transport;

/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
///
/// A connection handler drops its client after this long without receiving any
/// WebSocket message (data, keepalive, ping, or other frame).
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);

/// Server configuration (matches TS IVpnServerConfig).
///
/// Deserialized from camelCase keys; every `Option` field may be omitted.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerConfig {
    /// Address the WebSocket listener binds to (passed to `TcpListener::bind`).
    pub listen_addr: String,
    /// NOTE(review): not read anywhere in this file.
    pub tls_cert: Option<String>,
    /// NOTE(review): not read anywhere in this file.
    pub tls_key: Option<String>,
    /// Server private key, base64 (standard alphabet); used to create the handshake
    /// responder for every connection.
    pub private_key: String,
    pub public_key: String,
    /// Subnet the client address pool is built from (passed to `IpPool::new` and to NAT
    /// setup).
    pub subnet: String,
    /// NOTE(review): not read anywhere in this file.
    pub dns: Option<Vec<String>>,
    /// MTU announced to clients as `mtu`; 1420 when unset. Also feeds the effective
    /// MTU computation at start.
    pub mtu: Option<u16>,
    /// NOTE(review): not read anywhere in this file.
    pub keepalive_interval_secs: Option<u64>,
    /// When true, enable IP forwarding and NAT on the default interface at start
    /// (failures are only logged). Treated as false when unset.
    pub enable_nat: Option<bool>,
    /// Default rate limit for new clients (bytes/sec). None = unlimited.
    pub default_rate_limit_bytes_per_sec: Option<u64>,
    /// Default burst size for new clients (bytes). None = unlimited.
    pub default_burst_bytes: Option<u64>,
}

/// Information about a connected client.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientInfo {
    /// Server-generated random id for this connection.
    pub client_id: String,
    /// IPv4 tunnel address allocated from the pool, in string form.
    pub assigned_ip: String,
    /// Unix timestamp in whole seconds, as a decimal string.
    pub connected_since: String,
    pub bytes_sent: u64,
    /// Decrypted bytes accepted from this client.
    pub bytes_received: u64,
    /// Packets discarded by this client's rate limiter.
    pub packets_dropped: u64,
    /// Bytes discarded by this client's rate limiter.
    pub bytes_dropped: u64,
    /// Unix timestamp (seconds, decimal string) of the last keepalive; `None` until
    /// the first one arrives.
    pub last_keepalive_at: Option<String>,
    pub keepalives_received: u64,
    /// Current rate limit in bytes/sec; `None` = unlimited.
    pub rate_limit_bytes_per_sec: Option<u64>,
    /// Current burst size in bytes; `None` = unlimited.
    pub burst_bytes: Option<u64>,
}

/// Server statistics.
#[derive(Debug, Clone, Serialize, Default)] #[serde(rename_all = "camelCase")] pub struct ServerStatistics { pub bytes_sent: u64, pub bytes_received: u64, pub packets_sent: u64, pub packets_received: u64, pub keepalives_sent: u64, pub keepalives_received: u64, pub uptime_seconds: u64, pub active_clients: u64, pub total_connections: u64, } /// Shared server state. pub struct ServerState { pub config: ServerConfig, pub ip_pool: Mutex, pub clients: RwLock>, pub stats: RwLock, pub rate_limiters: Mutex>, pub mtu_config: MtuConfig, pub started_at: std::time::Instant, } /// The VPN server. pub struct VpnServer { state: Option>, shutdown_tx: Option>, } impl VpnServer { pub fn new() -> Self { Self { state: None, shutdown_tx: None, } } pub async fn start(&mut self, config: ServerConfig) -> Result<()> { if self.state.is_some() { anyhow::bail!("Server is already running"); } let ip_pool = IpPool::new(&config.subnet)?; if config.enable_nat.unwrap_or(false) { if let Err(e) = crate::network::enable_ip_forwarding() { warn!("Failed to enable IP forwarding: {}", e); } if let Ok(iface) = crate::network::get_default_interface() { if let Err(e) = crate::network::setup_nat(&config.subnet, &iface).await { warn!("Failed to setup NAT: {}", e); } } } let link_mtu = config.mtu.unwrap_or(1420); // Compute effective MTU from overhead let overhead = TunnelOverhead::default_overhead(); let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu)); let state = Arc::new(ServerState { config: config.clone(), ip_pool: Mutex::new(ip_pool), clients: RwLock::new(HashMap::new()), stats: RwLock::new(ServerStatistics::default()), rate_limiters: Mutex::new(HashMap::new()), mtu_config, started_at: std::time::Instant::now(), }); let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); self.state = Some(state.clone()); self.shutdown_tx = Some(shutdown_tx); let listen_addr = config.listen_addr.clone(); tokio::spawn(async move { if let Err(e) = run_listener(state, listen_addr, &mut 
shutdown_rx).await { error!("Server listener error: {}", e); } }); info!("VPN server started"); Ok(()) } pub async fn stop(&mut self) -> Result<()> { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()).await; } self.state = None; info!("VPN server stopped"); Ok(()) } pub fn get_status(&self) -> serde_json::Value { if let Some(ref state) = self.state { serde_json::json!({ "state": "connected", "connectedSince": format!("{:?}", state.started_at.elapsed()), }) } else { serde_json::json!({ "state": "disconnected" }) } } pub async fn get_statistics(&self) -> ServerStatistics { if let Some(ref state) = self.state { let mut stats = state.stats.read().await.clone(); stats.uptime_seconds = state.started_at.elapsed().as_secs(); stats.active_clients = state.clients.read().await.len() as u64; stats } else { ServerStatistics::default() } } pub async fn list_clients(&self) -> Vec { if let Some(ref state) = self.state { state.clients.read().await.values().cloned().collect() } else { Vec::new() } } pub async fn disconnect_client(&self, client_id: &str) -> Result<()> { if let Some(ref state) = self.state { let mut clients = state.clients.write().await; if let Some(client) = clients.remove(client_id) { let ip: Ipv4Addr = client.assigned_ip.parse()?; state.ip_pool.lock().await.release(&ip); state.rate_limiters.lock().await.remove(client_id); info!("Client {} disconnected", client_id); } } Ok(()) } /// Set a rate limit for a specific client. 
pub async fn set_client_rate_limit( &self, client_id: &str, rate_bytes_per_sec: u64, burst_bytes: u64, ) -> Result<()> { if let Some(ref state) = self.state { let mut limiters = state.rate_limiters.lock().await; if let Some(limiter) = limiters.get_mut(client_id) { limiter.update_limits(rate_bytes_per_sec, burst_bytes); } else { limiters.insert( client_id.to_string(), TokenBucket::new(rate_bytes_per_sec, burst_bytes), ); } // Update client info let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(client_id) { info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec); info.burst_bytes = Some(burst_bytes); } } Ok(()) } /// Remove rate limit for a specific client (unlimited). pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> { if let Some(ref state) = self.state { state.rate_limiters.lock().await.remove(client_id); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(client_id) { info.rate_limit_bytes_per_sec = None; info.burst_bytes = None; } } Ok(()) } } async fn run_listener( state: Arc, listen_addr: String, shutdown_rx: &mut mpsc::Receiver<()>, ) -> Result<()> { let listener = TcpListener::bind(&listen_addr).await?; info!("WebSocket server listening on {}", listen_addr); loop { tokio::select! 
{ accept = listener.accept() => { match accept { Ok((stream, addr)) => { info!("New connection from {}", addr); let state = state.clone(); tokio::spawn(async move { if let Err(e) = handle_client_connection(state, stream).await { warn!("Client connection error: {}", e); } }); } Err(e) => { error!("Accept error: {}", e); } } } _ = shutdown_rx.recv() => { info!("Shutdown signal received"); break; } } } Ok(()) } async fn handle_client_connection( state: Arc, stream: tokio::net::TcpStream, ) -> Result<()> { let ws = transport::accept_connection(stream).await?; let (mut ws_sink, mut ws_stream) = ws.split(); let client_id = uuid_v4(); let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?; let server_private_key = base64::Engine::decode( &base64::engine::general_purpose::STANDARD, &state.config.private_key, )?; let mut responder = crypto::create_responder(&server_private_key)?; let mut buf = vec![0u8; 65535]; // Receive handshake init let init_msg = match ws_stream.next().await { Some(Ok(Message::Binary(data))) => data.to_vec(), _ => anyhow::bail!("Expected handshake init message"), }; let mut frame_buf = BytesMut::from(&init_msg[..]); let frame = ::decode(&mut FrameCodec, &mut frame_buf)? 
.ok_or_else(|| anyhow::anyhow!("Incomplete handshake frame"))?; if frame.packet_type != PacketType::HandshakeInit { anyhow::bail!("Expected HandshakeInit, got {:?}", frame.packet_type); } responder.read_message(&frame.payload, &mut buf)?; let len = responder.write_message(&[], &mut buf)?; let response_payload = buf[..len].to_vec(); let response_frame = Frame { packet_type: PacketType::HandshakeResp, payload: response_payload, }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?; ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?; let mut noise_transport = responder.into_transport_mode()?; // Register client let default_rate = state.config.default_rate_limit_bytes_per_sec; let default_burst = state.config.default_burst_bytes; let client_info = ClientInfo { client_id: client_id.clone(), assigned_ip: assigned_ip.to_string(), connected_since: timestamp_now(), bytes_sent: 0, bytes_received: 0, packets_dropped: 0, bytes_dropped: 0, last_keepalive_at: None, keepalives_received: 0, rate_limit_bytes_per_sec: default_rate, burst_bytes: default_burst, }; state.clients.write().await.insert(client_id.clone(), client_info); // Set up rate limiter if defaults are configured if let (Some(rate), Some(burst)) = (default_rate, default_burst) { state .rate_limiters .lock() .await .insert(client_id.clone(), TokenBucket::new(rate, burst)); } { let mut stats = state.stats.write().await; stats.total_connections += 1; } // Send assigned IP info (encrypted), include effective MTU let ip_info = serde_json::json!({ "assignedIp": assigned_ip.to_string(), "gateway": state.ip_pool.lock().await.gateway_addr().to_string(), "mtu": state.config.mtu.unwrap_or(1420), "effectiveMtu": state.mtu_config.effective_mtu, }); let ip_info_bytes = serde_json::to_vec(&ip_info)?; let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?; let encrypted_info = Frame { packet_type: PacketType::IpPacket, payload: buf[..len].to_vec(), }; let 
mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?; ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?; info!("Client {} connected with IP {}", client_id, assigned_ip); // Main packet loop with dead-peer detection let mut last_activity = tokio::time::Instant::now(); loop { tokio::select! { msg = ws_stream.next() => { match msg { Some(Ok(Message::Binary(data))) => { last_activity = tokio::time::Instant::now(); let mut frame_buf = BytesMut::from(&data[..][..]); match ::decode(&mut FrameCodec, &mut frame_buf) { Ok(Some(frame)) => match frame.packet_type { PacketType::IpPacket => { match noise_transport.read_message(&frame.payload, &mut buf) { Ok(len) => { // Rate limiting check let allowed = { let mut limiters = state.rate_limiters.lock().await; if let Some(limiter) = limiters.get_mut(&client_id) { limiter.try_consume(len) } else { true } }; if !allowed { let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.packets_dropped += 1; info.bytes_dropped += len as u64; } continue; } let mut stats = state.stats.write().await; stats.bytes_received += len as u64; stats.packets_received += 1; // Update per-client stats drop(stats); let mut clients = state.clients.write().await; if let Some(info) = clients.get_mut(&client_id) { info.bytes_received += len as u64; } } Err(e) => { warn!("Decrypt error from {}: {}", client_id, e); break; } } } PacketType::Keepalive => { // Echo the keepalive payload back in the ACK let ack_frame = Frame { packet_type: PacketType::KeepaliveAck, payload: frame.payload.clone(), }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?; ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?; let mut stats = state.stats.write().await; stats.keepalives_received += 1; stats.keepalives_sent += 1; // Update per-client keepalive tracking drop(stats); let mut clients = state.clients.write().await; if 
let Some(info) = clients.get_mut(&client_id) { info.last_keepalive_at = Some(timestamp_now()); info.keepalives_received += 1; } } PacketType::Disconnect => { info!("Client {} sent disconnect", client_id); break; } _ => { warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type); } }, Ok(None) => { warn!("Incomplete frame from {}", client_id); } Err(e) => { warn!("Frame decode error from {}: {}", client_id, e); break; } } } Some(Ok(Message::Close(_))) | None => { info!("Client {} connection closed", client_id); break; } Some(Ok(Message::Ping(data))) => { last_activity = tokio::time::Instant::now(); ws_sink.send(Message::Pong(data)).await?; } Some(Ok(_)) => { last_activity = tokio::time::Instant::now(); continue; } Some(Err(e)) => { warn!("WebSocket error from {}: {}", client_id, e); break; } } } _ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => { warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs()); break; } } } // Cleanup state.clients.write().await.remove(&client_id); state.ip_pool.lock().await.release(&assigned_ip); state.rate_limiters.lock().await.remove(&client_id); info!("Client {} disconnected, released IP {}", client_id, assigned_ip); Ok(()) } fn uuid_v4() -> String { use rand::Rng; let mut rng = rand::thread_rng(); let bytes: [u8; 16] = rng.gen(); format!( "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}", bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], ) } fn timestamp_now() -> String { use std::time::SystemTime; let duration = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or_default(); format!("{}", duration.as_secs()) }