use anyhow::Result; use bytes::BytesMut; use futures_util::{SinkExt, StreamExt}; use serde::Deserialize; use std::sync::Arc; use std::time::Duration; use tokio::sync::{mpsc, RwLock}; use tokio_tungstenite::tungstenite::Message; use tracing::{info, error, warn}; use crate::codec::{Frame, FrameCodec, PacketType}; use crate::crypto; use crate::transport; /// Client configuration (matches TS IVpnClientConfig). #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ClientConfig { pub server_url: String, pub server_public_key: String, pub dns: Option>, pub mtu: Option, pub keepalive_interval_secs: Option, } /// Client statistics. #[derive(Debug, Clone, Default)] pub struct ClientStatistics { pub bytes_sent: u64, pub bytes_received: u64, pub packets_sent: u64, pub packets_received: u64, pub keepalives_sent: u64, pub keepalives_received: u64, } /// Client connection state. #[derive(Debug, Clone, PartialEq)] pub enum ClientState { Disconnected, Connecting, Handshaking, Connected, Reconnecting, Error(String), } impl std::fmt::Display for ClientState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Disconnected => write!(f, "disconnected"), Self::Connecting => write!(f, "connecting"), Self::Handshaking => write!(f, "handshaking"), Self::Connected => write!(f, "connected"), Self::Reconnecting => write!(f, "reconnecting"), Self::Error(e) => write!(f, "error: {}", e), } } } /// The VPN client. pub struct VpnClient { state: Arc>, stats: Arc>, assigned_ip: Arc>>, shutdown_tx: Option>, connected_since: Arc>>, } impl VpnClient { pub fn new() -> Self { Self { state: Arc::new(RwLock::new(ClientState::Disconnected)), stats: Arc::new(RwLock::new(ClientStatistics::default())), assigned_ip: Arc::new(RwLock::new(None)), shutdown_tx: None, connected_since: Arc::new(RwLock::new(None)), } } /// Connect to the VPN server. pub async fn connect(&mut self, config: ClientConfig) -> Result { if *self.state.read().await != ClientState::Disconnected { anyhow::bail!("Client is not disconnected"); } *self.state.write().await = ClientState::Connecting; let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1); self.shutdown_tx = Some(shutdown_tx); let state = self.state.clone(); let stats = self.stats.clone(); let assigned_ip_ref = self.assigned_ip.clone(); let connected_since = self.connected_since.clone(); // Decode server public key let server_pub_key = base64::Engine::decode( &base64::engine::general_purpose::STANDARD, &config.server_public_key, )?; // Connect to WebSocket server let ws = transport::connect_to_server(&config.server_url).await?; let (mut ws_sink, mut ws_stream) = ws.split(); // Noise NK handshake (client side = initiator) *state.write().await = ClientState::Handshaking; let mut initiator = crypto::create_initiator(&server_pub_key)?; let mut buf = vec![0u8; 65535]; // -> e, es let len = initiator.write_message(&[], &mut buf)?; let init_frame = Frame { packet_type: PacketType::HandshakeInit, payload: buf[..len].to_vec(), }; let mut frame_bytes = BytesMut::new(); >::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?; ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?; // <- e, ee let resp_msg = match ws_stream.next().await { Some(Ok(Message::Binary(data))) => data.to_vec(), Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"), Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e), None => anyhow::bail!("Connection closed during handshake"), }; let mut frame_buf = BytesMut::from(&resp_msg[..]); let frame = ::decode(&mut FrameCodec, &mut frame_buf)? .ok_or_else(|| anyhow::anyhow!("Incomplete handshake response frame"))?; if frame.packet_type != PacketType::HandshakeResp { anyhow::bail!("Expected HandshakeResp, got {:?}", frame.packet_type); } initiator.read_message(&frame.payload, &mut buf)?; let mut noise_transport = initiator.into_transport_mode()?; // Receive assigned IP info (encrypted) let info_msg = match ws_stream.next().await { Some(Ok(Message::Binary(data))) => data.to_vec(), _ => anyhow::bail!("Expected IP info message"), }; let mut frame_buf = BytesMut::from(&info_msg[..]); let frame = ::decode(&mut FrameCodec, &mut frame_buf)? .ok_or_else(|| anyhow::anyhow!("Incomplete IP info frame"))?; let len = noise_transport.read_message(&frame.payload, &mut buf)?; let ip_info: serde_json::Value = serde_json::from_slice(&buf[..len])?; let assigned_ip = ip_info["assignedIp"] .as_str() .ok_or_else(|| anyhow::anyhow!("Missing assignedIp in server response"))? .to_string(); *assigned_ip_ref.write().await = Some(assigned_ip.clone()); *connected_since.write().await = Some(std::time::Instant::now()); *state.write().await = ClientState::Connected; info!("Connected to VPN, assigned IP: {}", assigned_ip); // Spawn packet forwarding loop let assigned_ip_clone = assigned_ip.clone(); tokio::spawn(client_loop( ws_sink, ws_stream, noise_transport, state, stats, shutdown_rx, config.keepalive_interval_secs.unwrap_or(30), )); Ok(assigned_ip_clone) } /// Disconnect from the VPN server. pub async fn disconnect(&mut self) -> Result<()> { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()).await; } *self.assigned_ip.write().await = None; *self.connected_since.write().await = None; *self.state.write().await = ClientState::Disconnected; info!("Disconnected from VPN"); Ok(()) } /// Get current status. pub async fn get_status(&self) -> serde_json::Value { let state = self.state.read().await; let ip = self.assigned_ip.read().await; let since = self.connected_since.read().await; let mut status = serde_json::json!({ "state": format!("{}", *state), }); if let Some(ref ip) = *ip { status["assignedIp"] = serde_json::json!(ip); } if let Some(instant) = *since { status["uptimeSeconds"] = serde_json::json!(instant.elapsed().as_secs()); } status } /// Get traffic statistics. pub async fn get_statistics(&self) -> serde_json::Value { let stats = self.stats.read().await; let since = self.connected_since.read().await; let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0); serde_json::json!({ "bytesSent": stats.bytes_sent, "bytesReceived": stats.bytes_received, "packetsSent": stats.packets_sent, "packetsReceived": stats.packets_received, "keepalivesSent": stats.keepalives_sent, "keepalivesReceived": stats.keepalives_received, "uptimeSeconds": uptime, }) } } /// The main client packet forwarding loop (runs in a spawned task). async fn client_loop( mut ws_sink: futures_util::stream::SplitSink, mut ws_stream: futures_util::stream::SplitStream, mut noise_transport: snow::TransportState, state: Arc>, stats: Arc>, mut shutdown_rx: mpsc::Receiver<()>, keepalive_secs: u64, ) { let mut buf = vec![0u8; 65535]; let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs)); keepalive_ticker.tick().await; // skip first immediate tick loop { tokio::select! { msg = ws_stream.next() => { match msg { Some(Ok(Message::Binary(data))) => { let mut frame_buf = BytesMut::from(&data[..][..]); if let Ok(Some(frame)) = ::decode(&mut FrameCodec, &mut frame_buf) { match frame.packet_type { PacketType::IpPacket => { match noise_transport.read_message(&frame.payload, &mut buf) { Ok(len) => { let mut s = stats.write().await; s.bytes_received += len as u64; s.packets_received += 1; } Err(e) => { warn!("Decrypt error: {}", e); *state.write().await = ClientState::Error(e.to_string()); break; } } } PacketType::KeepaliveAck => { stats.write().await.keepalives_received += 1; } PacketType::Disconnect => { info!("Server sent disconnect"); *state.write().await = ClientState::Disconnected; break; } _ => {} } } } Some(Ok(Message::Close(_))) | None => { info!("Connection closed"); *state.write().await = ClientState::Disconnected; break; } Some(Ok(Message::Ping(data))) => { let _ = ws_sink.send(Message::Pong(data)).await; } Some(Ok(_)) => continue, Some(Err(e)) => { error!("WebSocket error: {}", e); *state.write().await = ClientState::Error(e.to_string()); break; } } } _ = keepalive_ticker.tick() => { let ka_frame = Frame { packet_type: PacketType::Keepalive, payload: vec![], }; let mut frame_bytes = BytesMut::new(); if >::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() { if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() { warn!("Failed to send keepalive"); *state.write().await = ClientState::Disconnected; break; } stats.write().await.keepalives_sent += 1; } } _ = shutdown_rx.recv() => { // Send disconnect frame let dc_frame = Frame { packet_type: PacketType::Disconnect, payload: vec![], }; let mut frame_bytes = BytesMut::new(); if >::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() { let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await; } let _ = ws_sink.close().await; *state.write().await = ClientState::Disconnected; break; } } } }