use anyhow::Result;
use bytes::BytesMut;
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::{mpsc, watch, RwLock};
use tracing::{info, error, warn, debug};

use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
use crate::telemetry::ConnectionQuality;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;

/// Client configuration (matches TS IVpnClientConfig).
///
/// Deserialized with camelCase field names (`serverUrl`, `serverPublicKey`, ...).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientConfig {
    /// Server address. A `ws://` / `wss://` URL for WebSocket; for the
    /// "quic" transport this is used directly as `host:port`.
    pub server_url: String,
    /// Server public key, base64-encoded (standard alphabet); decoded by
    /// `VpnClient::connect` before the handshake.
    pub server_public_key: String,
    /// DNS servers. Not read anywhere in this file.
    // NOTE(review): element type reconstructed as `Vec<String>` — confirm against the TS interface.
    pub dns: Option<Vec<String>>,
    /// Tunnel MTU. Not read anywhere in this file.
    // NOTE(review): integer width reconstructed as `u16` — confirm against the TS interface.
    pub mtu: Option<u16>,
    /// Base keepalive interval in seconds. When set, `connect` uses it as the
    /// degraded-link interval, twice that for a healthy link and a third of it
    /// (at least one second) for a critical link. When `None`, the keepalive
    /// module's defaults apply.
    pub keepalive_interval_secs: Option<u64>,
    /// Transport type: "quic", "websocket" or "auto". When omitted (or set to
    /// any other value) `connect` behaves as "auto": QUIC is tried first and
    /// WebSocket is used as the fallback.
    pub transport: Option<String>,
    /// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
    pub server_cert_hash: Option<String>,
}

/// Client statistics.
///
/// Plain counters; `Default` yields all zeroes.
#[derive(Debug, Clone, Default)]
pub struct ClientStatistics {
    pub bytes_sent: u64,
    pub bytes_received: u64,
    pub packets_sent: u64,
    pub packets_received: u64,
    pub keepalives_sent: u64,
    pub keepalives_received: u64,
}

/// Client connection state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientState {
    Disconnected,
    Connecting,
    Handshaking,
    Connected,
    Reconnecting,
    /// A failure occurred; the payload is the error's display text.
    Error(String),
}

/// Lower-case, human-readable state name (`"connected"`, `"error: <msg>"`, ...).
impl std::fmt::Display for ClientState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Disconnected => f.write_str("disconnected"),
            Self::Connecting => f.write_str("connecting"),
            Self::Handshaking => f.write_str("handshaking"),
            Self::Connected => f.write_str("connected"),
            Self::Reconnecting => f.write_str("reconnecting"),
            Self::Error(e) => write!(f, "error: {}", e),
        }
    }
}

/// The VPN client.
pub struct VpnClient { state: Arc>, stats: Arc>, assigned_ip: Arc>>, shutdown_tx: Option>, connected_since: Arc>>, quality_rx: Option>, link_health: Arc>, } impl VpnClient { pub fn new() -> Self { Self { state: Arc::new(RwLock::new(ClientState::Disconnected)), stats: Arc::new(RwLock::new(ClientStatistics::default())), assigned_ip: Arc::new(RwLock::new(None)), shutdown_tx: None, connected_since: Arc::new(RwLock::new(None)), quality_rx: None, link_health: Arc::new(RwLock::new(LinkHealth::Degraded)), } } /// Connect to the VPN server. pub async fn connect(&mut self, config: ClientConfig) -> Result { if *self.state.read().await != ClientState::Disconnected { anyhow::bail!("Client is not disconnected"); } *self.state.write().await = ClientState::Connecting; let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1); self.shutdown_tx = Some(shutdown_tx); let state = self.state.clone(); let stats = self.stats.clone(); let assigned_ip_ref = self.assigned_ip.clone(); let connected_since = self.connected_since.clone(); let link_health = self.link_health.clone(); // Decode server public key let server_pub_key = base64::Engine::decode( &base64::engine::general_purpose::STANDARD, &config.server_public_key, )?; // Create transport based on configuration let (mut sink, mut stream): (Box, Box) = { let transport_type = config.transport.as_deref().unwrap_or("auto"); match transport_type { "quic" => { let server_addr = &config.server_url; // For QUIC, serverUrl is host:port let cert_hash = config.server_cert_hash.as_deref(); let conn = quic_transport::connect_quic(server_addr, cert_hash).await?; let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?; info!("Connected via QUIC"); (Box::new(quic_sink) as Box, Box::new(quic_stream) as Box) } "websocket" => { let ws = transport::connect_to_server(&config.server_url).await?; let (ws_sink, ws_stream) = transport_trait::split_ws(ws); info!("Connected via WebSocket"); (Box::new(ws_sink), Box::new(ws_stream)) } _ => { // 
"auto" (default): try QUIC first, fall back to WebSocket // Extract host:port from the URL for QUIC attempt let quic_addr = extract_host_port(&config.server_url); let cert_hash = config.server_cert_hash.as_deref(); if let Some(ref addr) = quic_addr { match tokio::time::timeout( std::time::Duration::from_secs(3), try_quic_connect(addr, cert_hash), ).await { Ok(Ok((quic_sink, quic_stream))) => { info!("Auto: connected via QUIC to {}", addr); (Box::new(quic_sink) as Box, Box::new(quic_stream) as Box) } Ok(Err(e)) => { debug!("Auto: QUIC failed ({}), falling back to WebSocket", e); let ws = transport::connect_to_server(&config.server_url).await?; let (ws_sink, ws_stream) = transport_trait::split_ws(ws); info!("Auto: connected via WebSocket (QUIC unavailable)"); (Box::new(ws_sink), Box::new(ws_stream)) } Err(_) => { debug!("Auto: QUIC timed out, falling back to WebSocket"); let ws = transport::connect_to_server(&config.server_url).await?; let (ws_sink, ws_stream) = transport_trait::split_ws(ws); info!("Auto: connected via WebSocket (QUIC timed out)"); (Box::new(ws_sink), Box::new(ws_stream)) } } } else { // Can't extract host:port for QUIC, use WebSocket directly let ws = transport::connect_to_server(&config.server_url).await?; let (ws_sink, ws_stream) = transport_trait::split_ws(ws); info!("Connected via WebSocket"); (Box::new(ws_sink), Box::new(ws_stream)) } } } }; // Noise NK handshake (client side = initiator) *state.write().await = ClientState::Handshaking; let mut initiator = crypto::create_initiator(&server_pub_key)?; let mut buf = vec![0u8; 65535]; // -> e, es let len = initiator.write_message(&[], &mut buf)?; let init_frame = Frame { packet_type: PacketType::HandshakeInit, payload: buf[..len].to_vec(), }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?; sink.send_reliable(frame_bytes.to_vec()).await?; // <- e, ee let resp_msg = match stream.recv_reliable().await? 
{ Some(data) => data, None => anyhow::bail!("Connection closed during handshake"), }; let mut frame_buf = BytesMut::from(&resp_msg[..]); let frame = ::decode(&mut FrameCodec, &mut frame_buf)? .ok_or_else(|| anyhow::anyhow!("Incomplete handshake response frame"))?; if frame.packet_type != PacketType::HandshakeResp { anyhow::bail!("Expected HandshakeResp, got {:?}", frame.packet_type); } initiator.read_message(&frame.payload, &mut buf)?; let mut noise_transport = initiator.into_transport_mode()?; // Receive assigned IP info (encrypted) let info_msg = match stream.recv_reliable().await? { Some(data) => data, None => anyhow::bail!("Connection closed before IP info"), }; let mut frame_buf = BytesMut::from(&info_msg[..]); let frame = ::decode(&mut FrameCodec, &mut frame_buf)? .ok_or_else(|| anyhow::anyhow!("Incomplete IP info frame"))?; let len = noise_transport.read_message(&frame.payload, &mut buf)?; let ip_info: serde_json::Value = serde_json::from_slice(&buf[..len])?; let assigned_ip = ip_info["assignedIp"] .as_str() .ok_or_else(|| anyhow::anyhow!("Missing assignedIp in server response"))? 
.to_string(); *assigned_ip_ref.write().await = Some(assigned_ip.clone()); *connected_since.write().await = Some(std::time::Instant::now()); *state.write().await = ClientState::Connected; info!("Connected to VPN, assigned IP: {}", assigned_ip); // Create adaptive keepalive monitor (use custom interval if configured) let ka_config = config.keepalive_interval_secs.map(|secs| { let mut cfg = keepalive::AdaptiveKeepaliveConfig::default(); cfg.degraded_interval = std::time::Duration::from_secs(secs); cfg.healthy_interval = std::time::Duration::from_secs(secs * 2); cfg.critical_interval = std::time::Duration::from_secs((secs / 3).max(1)); cfg }); let (monitor, handle) = keepalive::create_keepalive(ka_config); self.quality_rx = Some(handle.quality_rx); // Spawn the keepalive monitor tokio::spawn(monitor.run()); // Spawn packet forwarding loop let assigned_ip_clone = assigned_ip.clone(); tokio::spawn(client_loop( sink, stream, noise_transport, state, stats, shutdown_rx, handle.signal_rx, handle.ack_tx, link_health, )); Ok(assigned_ip_clone) } /// Disconnect from the VPN server. pub async fn disconnect(&mut self) -> Result<()> { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()).await; } *self.assigned_ip.write().await = None; *self.connected_since.write().await = None; *self.state.write().await = ClientState::Disconnected; self.quality_rx = None; info!("Disconnected from VPN"); Ok(()) } /// Get current status. pub async fn get_status(&self) -> serde_json::Value { let state = self.state.read().await; let ip = self.assigned_ip.read().await; let since = self.connected_since.read().await; let mut status = serde_json::json!({ "state": format!("{}", *state), }); if let Some(ref ip) = *ip { status["assignedIp"] = serde_json::json!(ip); } if let Some(instant) = *since { status["uptimeSeconds"] = serde_json::json!(instant.elapsed().as_secs()); } status } /// Get traffic statistics (includes connection quality). 
pub async fn get_statistics(&self) -> serde_json::Value { let stats = self.stats.read().await; let since = self.connected_since.read().await; let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0); let health = self.link_health.read().await; let mut result = serde_json::json!({ "bytesSent": stats.bytes_sent, "bytesReceived": stats.bytes_received, "packetsSent": stats.packets_sent, "packetsReceived": stats.packets_received, "keepalivesSent": stats.keepalives_sent, "keepalivesReceived": stats.keepalives_received, "uptimeSeconds": uptime, }); // Include connection quality if available if let Some(ref rx) = self.quality_rx { let quality = rx.borrow().clone(); result["quality"] = serde_json::json!({ "srttMs": quality.srtt_ms, "jitterMs": quality.jitter_ms, "minRttMs": quality.min_rtt_ms, "maxRttMs": quality.max_rtt_ms, "lossRatio": quality.loss_ratio, "consecutiveTimeouts": quality.consecutive_timeouts, "linkHealth": format!("{}", *health), "keepalivesSent": quality.keepalives_sent, "keepalivesAcked": quality.keepalives_acked, }); } result } /// Get connection quality snapshot. pub fn get_connection_quality(&self) -> Option { self.quality_rx.as_ref().map(|rx| rx.borrow().clone()) } /// Get current link health. pub async fn get_link_health(&self) -> LinkHealth { *self.link_health.read().await } } /// The main client packet forwarding loop (runs in a spawned task). async fn client_loop( mut sink: Box, mut stream: Box, mut noise_transport: snow::TransportState, state: Arc>, stats: Arc>, mut shutdown_rx: mpsc::Receiver<()>, mut signal_rx: mpsc::Receiver, ack_tx: mpsc::Sender<()>, link_health: Arc>, ) { let mut buf = vec![0u8; 65535]; loop { tokio::select! 
{ msg = stream.recv_reliable() => { match msg { Ok(Some(data)) => { let mut frame_buf = BytesMut::from(&data[..]); if let Ok(Some(frame)) = ::decode(&mut FrameCodec, &mut frame_buf) { match frame.packet_type { PacketType::IpPacket => { match noise_transport.read_message(&frame.payload, &mut buf) { Ok(len) => { let mut s = stats.write().await; s.bytes_received += len as u64; s.packets_received += 1; } Err(e) => { warn!("Decrypt error: {}", e); *state.write().await = ClientState::Error(e.to_string()); break; } } } PacketType::KeepaliveAck => { stats.write().await.keepalives_received += 1; // Signal the keepalive monitor that ACK was received let _ = ack_tx.send(()).await; } PacketType::Disconnect => { info!("Server sent disconnect"); *state.write().await = ClientState::Disconnected; break; } _ => {} } } } Ok(None) => { info!("Connection closed"); *state.write().await = ClientState::Disconnected; break; } Err(e) => { error!("Transport error: {}", e); *state.write().await = ClientState::Error(e.to_string()); break; } } } signal = signal_rx.recv() => { match signal { Some(KeepaliveSignal::SendPing(timestamp_ms)) => { // Embed the timestamp in the keepalive payload (8 bytes, big-endian) let ka_frame = Frame { packet_type: PacketType::Keepalive, payload: timestamp_ms.to_be_bytes().to_vec(), }; let mut frame_bytes = BytesMut::new(); if >::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() { if sink.send_reliable(frame_bytes.to_vec()).await.is_err() { warn!("Failed to send keepalive"); *state.write().await = ClientState::Disconnected; break; } stats.write().await.keepalives_sent += 1; } } Some(KeepaliveSignal::PeerDead) => { warn!("Peer declared dead by keepalive monitor"); *state.write().await = ClientState::Disconnected; break; } Some(KeepaliveSignal::LinkHealthChanged(health)) => { debug!("Link health changed to: {}", health); *link_health.write().await = health; } None => { // Keepalive monitor channel closed break; } } } _ = shutdown_rx.recv() => { // Send 
disconnect frame let dc_frame = Frame { packet_type: PacketType::Disconnect, payload: vec![], }; let mut frame_bytes = BytesMut::new(); if >::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() { let _ = sink.send_reliable(frame_bytes.to_vec()).await; } let _ = sink.close().await; *state.write().await = ClientState::Disconnected; break; } } } } /// Try to connect via QUIC. Returns transport halves on success. async fn try_quic_connect( addr: &str, cert_hash: Option<&str>, ) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> { let conn = quic_transport::connect_quic(addr, cert_hash).await?; let (sink, stream) = quic_transport::open_quic_streams(conn).await?; Ok((sink, stream)) } /// Extract host:port from a WebSocket URL for QUIC auto-fallback. /// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080") /// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443") /// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port) fn extract_host_port(url: &str) -> Option { if url.starts_with("ws://") || url.starts_with("wss://") { // Parse as URL let stripped = if url.starts_with("wss://") { &url[6..] } else { &url[5..] }; // Remove path let host_port = stripped.split('/').next()?; if host_port.contains(':') { Some(host_port.to_string()) } else { // Default port let default_port = if url.starts_with("wss://") { 443 } else { 80 }; Some(format!("{}:{}", host_port, default_port)) } } else if url.contains(':') { // Already host:port Some(url.to_string()) } else { None } }