Files
smartvpn/rust/src/client.rs

390 lines
15 KiB
Rust
Raw Normal View History

2026-02-27 10:18:23 +00:00
use anyhow::Result;
use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::{mpsc, watch, RwLock};
2026-02-27 10:18:23 +00:00
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn, debug};
2026-02-27 10:18:23 +00:00
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
use crate::telemetry::ConnectionQuality;
2026-02-27 10:18:23 +00:00
use crate::transport;
/// Connection parameters supplied by the host application.
///
/// Keys are deserialized from camelCase JSON so the shape matches the
/// TypeScript `IVpnClientConfig` interface.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientConfig {
    /// Server URL handed to `transport::connect_to_server`.
    pub server_url: String,
    /// Server's static public key, base64-encoded with the standard alphabet;
    /// used as the responder key for the Noise handshake.
    pub server_public_key: String,
    /// Optional list of DNS servers.
    /// NOTE(review): not read by `VpnClient::connect`.
    pub dns: Option<Vec<String>>,
    /// Optional MTU override.
    /// NOTE(review): not read by `VpnClient::connect`.
    pub mtu: Option<u16>,
    /// Optional keepalive interval in seconds.
    /// NOTE(review): `connect` calls `keepalive::create_keepalive(None)` and does
    /// not forward this value.
    pub keepalive_interval_secs: Option<u64>,
}
/// Cumulative traffic and keepalive counters for a client.
///
/// Every counter starts at zero via `Default`.
#[derive(Clone, Debug, Default)]
pub struct ClientStatistics {
    /// Bytes sent to the server.
    pub bytes_sent: u64,
    /// Bytes received from the server (length of the decrypted payload).
    pub bytes_received: u64,
    /// Packets sent to the server.
    pub packets_sent: u64,
    /// Packets received from the server.
    pub packets_received: u64,
    /// Keepalive pings sent.
    pub keepalives_sent: u64,
    /// Keepalive acknowledgements received.
    pub keepalives_received: u64,
}
/// Lifecycle state of a VPN client session.
///
/// `Display` yields the lowercase state name, or `error: <reason>` for
/// [`ClientState::Error`].
#[derive(Clone, Debug, PartialEq)]
pub enum ClientState {
    /// No active session.
    Disconnected,
    /// Opening the connection to the server.
    Connecting,
    /// Performing the Noise handshake.
    Handshaking,
    /// Handshake finished and a tunnel IP has been assigned.
    Connected,
    /// NOTE(review): not assigned by the code in this module; presumably
    /// reserved for automatic reconnection.
    Reconnecting,
    /// Session ended because of an error; carries the error message.
    Error(String),
}

impl std::fmt::Display for ClientState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Unit variants map to fixed labels; only `Error` needs formatting.
        let label = match self {
            ClientState::Disconnected => "disconnected",
            ClientState::Connecting => "connecting",
            ClientState::Handshaking => "handshaking",
            ClientState::Connected => "connected",
            ClientState::Reconnecting => "reconnecting",
            ClientState::Error(reason) => return write!(f, "error: {}", reason),
        };
        f.write_str(label)
    }
}
/// Handle to a VPN client session.
///
/// Mutable session data sits behind `Arc<RwLock<..>>` so the spawned
/// forwarding task can update it while callers query status.
pub struct VpnClient {
    /// Current lifecycle state; shared with the forwarding task.
    state: Arc<RwLock<ClientState>>,
    /// Traffic counters; shared with the forwarding task.
    stats: Arc<RwLock<ClientStatistics>>,
    /// Tunnel IP assigned by the server, when one has been received.
    assigned_ip: Arc<RwLock<Option<String>>>,
    /// Sender used to ask the forwarding task to shut down; `None` when no
    /// session has been started or after `disconnect`.
    shutdown_tx: Option<mpsc::Sender<()>>,
    /// Moment the session reached `Connected`; used for uptime reporting.
    connected_since: Arc<RwLock<Option<std::time::Instant>>>,
    /// Watch receiver for connection-quality snapshots from the keepalive handle.
    quality_rx: Option<watch::Receiver<ConnectionQuality>>,
    /// Link health as last reported by the keepalive monitor.
    link_health: Arc<RwLock<LinkHealth>>,
}
impl VpnClient {
pub fn new() -> Self {
Self {
state: Arc::new(RwLock::new(ClientState::Disconnected)),
stats: Arc::new(RwLock::new(ClientStatistics::default())),
assigned_ip: Arc::new(RwLock::new(None)),
shutdown_tx: None,
connected_since: Arc::new(RwLock::new(None)),
quality_rx: None,
link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
2026-02-27 10:18:23 +00:00
}
}
/// Connect to the VPN server.
pub async fn connect(&mut self, config: ClientConfig) -> Result<String> {
if *self.state.read().await != ClientState::Disconnected {
anyhow::bail!("Client is not disconnected");
}
*self.state.write().await = ClientState::Connecting;
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
self.shutdown_tx = Some(shutdown_tx);
let state = self.state.clone();
let stats = self.stats.clone();
let assigned_ip_ref = self.assigned_ip.clone();
let connected_since = self.connected_since.clone();
let link_health = self.link_health.clone();
2026-02-27 10:18:23 +00:00
// Decode server public key
let server_pub_key = base64::Engine::decode(
&base64::engine::general_purpose::STANDARD,
&config.server_public_key,
)?;
// Connect to WebSocket server
let ws = transport::connect_to_server(&config.server_url).await?;
let (mut ws_sink, mut ws_stream) = ws.split();
// Noise NK handshake (client side = initiator)
*state.write().await = ClientState::Handshaking;
let mut initiator = crypto::create_initiator(&server_pub_key)?;
let mut buf = vec![0u8; 65535];
// -> e, es
let len = initiator.write_message(&[], &mut buf)?;
let init_frame = Frame {
packet_type: PacketType::HandshakeInit,
payload: buf[..len].to_vec(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
// <- e, ee
let resp_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
None => anyhow::bail!("Connection closed during handshake"),
};
let mut frame_buf = BytesMut::from(&resp_msg[..]);
let frame = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf)?
.ok_or_else(|| anyhow::anyhow!("Incomplete handshake response frame"))?;
if frame.packet_type != PacketType::HandshakeResp {
anyhow::bail!("Expected HandshakeResp, got {:?}", frame.packet_type);
}
initiator.read_message(&frame.payload, &mut buf)?;
let mut noise_transport = initiator.into_transport_mode()?;
// Receive assigned IP info (encrypted)
let info_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
_ => anyhow::bail!("Expected IP info message"),
};
let mut frame_buf = BytesMut::from(&info_msg[..]);
let frame = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf)?
.ok_or_else(|| anyhow::anyhow!("Incomplete IP info frame"))?;
let len = noise_transport.read_message(&frame.payload, &mut buf)?;
let ip_info: serde_json::Value = serde_json::from_slice(&buf[..len])?;
let assigned_ip = ip_info["assignedIp"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing assignedIp in server response"))?
.to_string();
*assigned_ip_ref.write().await = Some(assigned_ip.clone());
*connected_since.write().await = Some(std::time::Instant::now());
*state.write().await = ClientState::Connected;
info!("Connected to VPN, assigned IP: {}", assigned_ip);
// Create adaptive keepalive monitor
let (monitor, handle) = keepalive::create_keepalive(None);
self.quality_rx = Some(handle.quality_rx);
// Spawn the keepalive monitor
tokio::spawn(monitor.run());
2026-02-27 10:18:23 +00:00
// Spawn packet forwarding loop
let assigned_ip_clone = assigned_ip.clone();
tokio::spawn(client_loop(
ws_sink,
ws_stream,
noise_transport,
state,
stats,
shutdown_rx,
handle.signal_rx,
handle.ack_tx,
link_health,
2026-02-27 10:18:23 +00:00
));
Ok(assigned_ip_clone)
}
/// Disconnect from the VPN server.
pub async fn disconnect(&mut self) -> Result<()> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
*self.assigned_ip.write().await = None;
*self.connected_since.write().await = None;
*self.state.write().await = ClientState::Disconnected;
self.quality_rx = None;
2026-02-27 10:18:23 +00:00
info!("Disconnected from VPN");
Ok(())
}
/// Get current status.
pub async fn get_status(&self) -> serde_json::Value {
let state = self.state.read().await;
let ip = self.assigned_ip.read().await;
let since = self.connected_since.read().await;
let mut status = serde_json::json!({
"state": format!("{}", *state),
});
if let Some(ref ip) = *ip {
status["assignedIp"] = serde_json::json!(ip);
}
if let Some(instant) = *since {
status["uptimeSeconds"] = serde_json::json!(instant.elapsed().as_secs());
}
status
}
/// Get traffic statistics (includes connection quality).
2026-02-27 10:18:23 +00:00
pub async fn get_statistics(&self) -> serde_json::Value {
let stats = self.stats.read().await;
let since = self.connected_since.read().await;
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
let health = self.link_health.read().await;
2026-02-27 10:18:23 +00:00
let mut result = serde_json::json!({
2026-02-27 10:18:23 +00:00
"bytesSent": stats.bytes_sent,
"bytesReceived": stats.bytes_received,
"packetsSent": stats.packets_sent,
"packetsReceived": stats.packets_received,
"keepalivesSent": stats.keepalives_sent,
"keepalivesReceived": stats.keepalives_received,
"uptimeSeconds": uptime,
});
// Include connection quality if available
if let Some(ref rx) = self.quality_rx {
let quality = rx.borrow().clone();
result["quality"] = serde_json::json!({
"srttMs": quality.srtt_ms,
"jitterMs": quality.jitter_ms,
"minRttMs": quality.min_rtt_ms,
"maxRttMs": quality.max_rtt_ms,
"lossRatio": quality.loss_ratio,
"consecutiveTimeouts": quality.consecutive_timeouts,
"linkHealth": format!("{}", *health),
"keepalivesSent": quality.keepalives_sent,
"keepalivesAcked": quality.keepalives_acked,
});
}
result
}
/// Get connection quality snapshot.
pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
}
/// Get current link health.
pub async fn get_link_health(&self) -> LinkHealth {
*self.link_health.read().await
2026-02-27 10:18:23 +00:00
}
}
/// The main client packet forwarding loop (runs in a spawned task).
///
/// Multiplexes three event sources until the session ends:
/// - Inbound WebSocket messages:
///   - decrypts `IpPacket` frames and updates statistics;
///   - forwards `KeepaliveAck`s to the keepalive monitor;
///   - stops on a server `Disconnect`, a socket close, or a socket error.
/// - Keepalive monitor signals:
///   - sends timestamped pings;
///   - records link-health changes;
///   - stops when the peer is declared dead or the monitor goes away.
/// - A shutdown request: sends a `Disconnect` frame, closes the socket, and stops.
///
/// Every exit path leaves `state` as either `Disconnected` or `Error(..)`, so
/// status never reports `Connected` for a session whose loop has ended.
async fn client_loop(
    mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
    mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
    mut noise_transport: snow::TransportState,
    state: Arc<RwLock<ClientState>>,
    stats: Arc<RwLock<ClientStatistics>>,
    mut shutdown_rx: mpsc::Receiver<()>,
    mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
    ack_tx: mpsc::Sender<()>,
    link_health: Arc<RwLock<LinkHealth>>,
) {
    // Scratch buffer for decrypted packets, sized to the Noise maximum message length.
    let mut buf = vec![0u8; 65535];
    loop {
        tokio::select! {
            msg = ws_stream.next() => {
                match msg {
                    Some(Ok(Message::Binary(data))) => {
                        let mut frame_buf = BytesMut::from(&data[..][..]);
                        // Each WebSocket message is decoded as a single frame; bad
                        // frames are skipped but no longer dropped silently.
                        let frame = match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
                            Ok(Some(frame)) => frame,
                            Ok(None) => {
                                debug!("Ignoring incomplete frame ({} bytes)", data.len());
                                continue;
                            }
                            Err(e) => {
                                warn!("Ignoring undecodable frame: {}", e);
                                continue;
                            }
                        };
                        match frame.packet_type {
                            PacketType::IpPacket => {
                                match noise_transport.read_message(&frame.payload, &mut buf) {
                                    Ok(len) => {
                                        let mut s = stats.write().await;
                                        s.bytes_received += len as u64;
                                        s.packets_received += 1;
                                    }
                                    Err(e) => {
                                        warn!("Decrypt error: {}", e);
                                        *state.write().await = ClientState::Error(e.to_string());
                                        break;
                                    }
                                }
                            }
                            PacketType::KeepaliveAck => {
                                stats.write().await.keepalives_received += 1;
                                // Signal the keepalive monitor that ACK was received
                                let _ = ack_tx.send(()).await;
                            }
                            PacketType::Disconnect => {
                                info!("Server sent disconnect");
                                *state.write().await = ClientState::Disconnected;
                                break;
                            }
                            _ => {}
                        }
                    }
                    Some(Ok(Message::Close(_))) | None => {
                        info!("Connection closed");
                        *state.write().await = ClientState::Disconnected;
                        break;
                    }
                    Some(Ok(Message::Ping(data))) => {
                        let _ = ws_sink.send(Message::Pong(data)).await;
                    }
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => {
                        error!("WebSocket error: {}", e);
                        *state.write().await = ClientState::Error(e.to_string());
                        break;
                    }
                }
            }
            signal = signal_rx.recv() => {
                match signal {
                    Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
                        // Embed the timestamp in the keepalive payload (8 bytes, big-endian)
                        let ka_frame = Frame {
                            packet_type: PacketType::Keepalive,
                            payload: timestamp_ms.to_be_bytes().to_vec(),
                        };
                        let mut frame_bytes = BytesMut::new();
                        if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
                            if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
                                warn!("Failed to send keepalive");
                                *state.write().await = ClientState::Disconnected;
                                break;
                            }
                            stats.write().await.keepalives_sent += 1;
                        } else {
                            warn!("Failed to encode keepalive frame");
                        }
                    }
                    Some(KeepaliveSignal::PeerDead) => {
                        warn!("Peer declared dead by keepalive monitor");
                        *state.write().await = ClientState::Disconnected;
                        break;
                    }
                    Some(KeepaliveSignal::LinkHealthChanged(health)) => {
                        debug!("Link health changed to: {}", health);
                        *link_health.write().await = health;
                    }
                    None => {
                        // The keepalive monitor went away, so liveness can no longer
                        // be tracked. Previously the loop exited here while leaving
                        // the state at `Connected`.
                        warn!("Keepalive monitor channel closed; ending session");
                        *state.write().await = ClientState::Disconnected;
                        break;
                    }
                }
            }
            _ = shutdown_rx.recv() => {
                // Send disconnect frame
                let dc_frame = Frame {
                    packet_type: PacketType::Disconnect,
                    payload: vec![],
                };
                let mut frame_bytes = BytesMut::new();
                if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
                    let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
                }
                let _ = ws_sink.close().await;
                *state.write().await = ClientState::Disconnected;
                break;
            }
        }
    }
}