Files
smartvpn/rust/src/client.rs

491 lines
20 KiB
Rust

use anyhow::Result;
use bytes::BytesMut;
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::{mpsc, watch, RwLock};
use tracing::{info, error, warn, debug};
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
use crate::telemetry::ConnectionQuality;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;
/// Client configuration (matches TS IVpnClientConfig).
///
/// Deserialized from camelCase keys (e.g. `serverUrl`, `serverPublicKey`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientConfig {
    /// Server address: a `ws://`/`wss://` URL for WebSocket, or `host:port` when the
    /// transport is `"quic"`. In `"auto"` mode the QUIC address is derived from this URL.
    pub server_url: String,
    /// The server's static Noise public key, encoded as standard-alphabet base64.
    pub server_public_key: String,
    /// Optional DNS servers. NOTE(review): not read by `VpnClient::connect` in this file.
    pub dns: Option<Vec<String>>,
    /// Optional tunnel MTU. NOTE(review): not read by `VpnClient::connect` in this file.
    pub mtu: Option<u16>,
    /// Keepalive base interval in seconds. When set, it becomes the degraded-link interval;
    /// the healthy interval is twice it and the critical interval a third of it (at least 1s).
    /// When unset, `None` is handed to `keepalive::create_keepalive`.
    pub keepalive_interval_secs: Option<u64>,
    /// Transport type: `"quic"`, `"websocket"`, or `"auto"`. `"auto"` is the default when
    /// unset and is also used for any unrecognized value. It tries QUIC first (3-second
    /// timeout) and falls back to WebSocket.
    pub transport: Option<String>,
    /// For QUIC (explicit or auto mode): SHA-256 hash of the server certificate (base64),
    /// used for certificate pinning.
    pub server_cert_hash: Option<String>,
}
/// Client statistics.
///
/// Counters for a `VpnClient`. They start at zero and are not reset on connect or
/// disconnect, so they accumulate over the client's lifetime.
#[derive(Debug, Clone, Default)]
pub struct ClientStatistics {
    /// Bytes of tunnelled traffic sent.
    /// NOTE(review): nothing in this file increments this counter yet.
    pub bytes_sent: u64,
    /// Total plaintext length of decrypted `IpPacket` payloads received from the server.
    pub bytes_received: u64,
    /// Tunnelled packets sent.
    /// NOTE(review): nothing in this file increments this counter yet.
    pub packets_sent: u64,
    /// Number of `IpPacket` frames successfully decrypted.
    pub packets_received: u64,
    /// Keepalive frames successfully handed to the transport.
    pub keepalives_sent: u64,
    /// `KeepaliveAck` frames received from the server.
    pub keepalives_received: u64,
}
/// Client connection state.
#[derive(Debug, Clone, PartialEq)]
pub enum ClientState {
    /// No active session; `VpnClient::connect` only proceeds from this state.
    Disconnected,
    /// The transport (QUIC or WebSocket) is being opened.
    Connecting,
    /// The transport is open and the Noise handshake is in progress.
    Handshaking,
    /// The handshake finished and an IP address was assigned.
    Connected,
    /// NOTE(review): not set anywhere in this file; presumably reserved for reconnect logic.
    Reconnecting,
    /// The session ended because of a failure; carries the error text.
    Error(String),
}

impl std::fmt::Display for ClientState {
    /// Renders the lower-case state name; `Error` is rendered as `error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Error(reason) => return write!(f, "error: {}", reason),
            Self::Disconnected => "disconnected",
            Self::Connecting => "connecting",
            Self::Handshaking => "handshaking",
            Self::Connected => "connected",
            Self::Reconnecting => "reconnecting",
        };
        f.write_str(label)
    }
}
/// The VPN client.
///
/// State shared with the spawned forwarding task is kept behind `Arc<RwLock<_>>`, so status
/// queries can read it while that task updates it.
pub struct VpnClient {
    /// Current connection state; written by `connect`/`disconnect` and by the forwarding loop.
    state: Arc<RwLock<ClientState>>,
    /// Traffic and keepalive counters, updated by the forwarding loop.
    stats: Arc<RwLock<ClientStatistics>>,
    /// IP address the server assigned during the handshake, if any.
    assigned_ip: Arc<RwLock<Option<String>>>,
    /// Asks the forwarding loop to send `Disconnect` and stop; set by `connect`, taken by
    /// `disconnect`.
    shutdown_tx: Option<mpsc::Sender<()>>,
    /// Moment the current connection finished its handshake; used for uptime reporting.
    connected_since: Arc<RwLock<Option<std::time::Instant>>>,
    /// Latest connection-quality snapshot published by the keepalive monitor.
    quality_rx: Option<watch::Receiver<ConnectionQuality>>,
    /// Link health last reported by the keepalive monitor (starts as `Degraded`).
    link_health: Arc<RwLock<LinkHealth>>,
}
impl VpnClient {
pub fn new() -> Self {
Self {
state: Arc::new(RwLock::new(ClientState::Disconnected)),
stats: Arc::new(RwLock::new(ClientStatistics::default())),
assigned_ip: Arc::new(RwLock::new(None)),
shutdown_tx: None,
connected_since: Arc::new(RwLock::new(None)),
quality_rx: None,
link_health: Arc::new(RwLock::new(LinkHealth::Degraded)),
}
}
/// Connect to the VPN server.
pub async fn connect(&mut self, config: ClientConfig) -> Result<String> {
if *self.state.read().await != ClientState::Disconnected {
anyhow::bail!("Client is not disconnected");
}
*self.state.write().await = ClientState::Connecting;
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
self.shutdown_tx = Some(shutdown_tx);
let state = self.state.clone();
let stats = self.stats.clone();
let assigned_ip_ref = self.assigned_ip.clone();
let connected_since = self.connected_since.clone();
let link_health = self.link_health.clone();
// Decode server public key
let server_pub_key = base64::Engine::decode(
&base64::engine::general_purpose::STANDARD,
&config.server_public_key,
)?;
// Create transport based on configuration
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
let transport_type = config.transport.as_deref().unwrap_or("auto");
match transport_type {
"quic" => {
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
let cert_hash = config.server_cert_hash.as_deref();
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
info!("Connected via QUIC");
(Box::new(quic_sink) as Box<dyn TransportSink>,
Box::new(quic_stream) as Box<dyn TransportStream>)
}
"websocket" => {
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Connected via WebSocket");
(Box::new(ws_sink), Box::new(ws_stream))
}
_ => {
// "auto" (default): try QUIC first, fall back to WebSocket
// Extract host:port from the URL for QUIC attempt
let quic_addr = extract_host_port(&config.server_url);
let cert_hash = config.server_cert_hash.as_deref();
if let Some(ref addr) = quic_addr {
match tokio::time::timeout(
std::time::Duration::from_secs(3),
try_quic_connect(addr, cert_hash),
).await {
Ok(Ok((quic_sink, quic_stream))) => {
info!("Auto: connected via QUIC to {}", addr);
(Box::new(quic_sink) as Box<dyn TransportSink>,
Box::new(quic_stream) as Box<dyn TransportStream>)
}
Ok(Err(e)) => {
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Auto: connected via WebSocket (QUIC unavailable)");
(Box::new(ws_sink), Box::new(ws_stream))
}
Err(_) => {
debug!("Auto: QUIC timed out, falling back to WebSocket");
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Auto: connected via WebSocket (QUIC timed out)");
(Box::new(ws_sink), Box::new(ws_stream))
}
}
} else {
// Can't extract host:port for QUIC, use WebSocket directly
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Connected via WebSocket");
(Box::new(ws_sink), Box::new(ws_stream))
}
}
}
};
// Noise NK handshake (client side = initiator)
*state.write().await = ClientState::Handshaking;
let mut initiator = crypto::create_initiator(&server_pub_key)?;
let mut buf = vec![0u8; 65535];
// -> e, es
let len = initiator.write_message(&[], &mut buf)?;
let init_frame = Frame {
packet_type: PacketType::HandshakeInit,
payload: buf[..len].to_vec(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
sink.send_reliable(frame_bytes.to_vec()).await?;
// <- e, ee
let resp_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed during handshake"),
};
let mut frame_buf = BytesMut::from(&resp_msg[..]);
let frame = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf)?
.ok_or_else(|| anyhow::anyhow!("Incomplete handshake response frame"))?;
if frame.packet_type != PacketType::HandshakeResp {
anyhow::bail!("Expected HandshakeResp, got {:?}", frame.packet_type);
}
initiator.read_message(&frame.payload, &mut buf)?;
let mut noise_transport = initiator.into_transport_mode()?;
// Receive assigned IP info (encrypted)
let info_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before IP info"),
};
let mut frame_buf = BytesMut::from(&info_msg[..]);
let frame = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf)?
.ok_or_else(|| anyhow::anyhow!("Incomplete IP info frame"))?;
let len = noise_transport.read_message(&frame.payload, &mut buf)?;
let ip_info: serde_json::Value = serde_json::from_slice(&buf[..len])?;
let assigned_ip = ip_info["assignedIp"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing assignedIp in server response"))?
.to_string();
*assigned_ip_ref.write().await = Some(assigned_ip.clone());
*connected_since.write().await = Some(std::time::Instant::now());
*state.write().await = ClientState::Connected;
info!("Connected to VPN, assigned IP: {}", assigned_ip);
// Create adaptive keepalive monitor (use custom interval if configured)
let ka_config = config.keepalive_interval_secs.map(|secs| {
let mut cfg = keepalive::AdaptiveKeepaliveConfig::default();
cfg.degraded_interval = std::time::Duration::from_secs(secs);
cfg.healthy_interval = std::time::Duration::from_secs(secs * 2);
cfg.critical_interval = std::time::Duration::from_secs((secs / 3).max(1));
cfg
});
let (monitor, handle) = keepalive::create_keepalive(ka_config);
self.quality_rx = Some(handle.quality_rx);
// Spawn the keepalive monitor
tokio::spawn(monitor.run());
// Spawn packet forwarding loop
let assigned_ip_clone = assigned_ip.clone();
tokio::spawn(client_loop(
sink,
stream,
noise_transport,
state,
stats,
shutdown_rx,
handle.signal_rx,
handle.ack_tx,
link_health,
));
Ok(assigned_ip_clone)
}
/// Disconnect from the VPN server.
pub async fn disconnect(&mut self) -> Result<()> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
*self.assigned_ip.write().await = None;
*self.connected_since.write().await = None;
*self.state.write().await = ClientState::Disconnected;
self.quality_rx = None;
info!("Disconnected from VPN");
Ok(())
}
/// Get current status.
pub async fn get_status(&self) -> serde_json::Value {
let state = self.state.read().await;
let ip = self.assigned_ip.read().await;
let since = self.connected_since.read().await;
let mut status = serde_json::json!({
"state": format!("{}", *state),
});
if let Some(ref ip) = *ip {
status["assignedIp"] = serde_json::json!(ip);
}
if let Some(instant) = *since {
status["uptimeSeconds"] = serde_json::json!(instant.elapsed().as_secs());
}
status
}
/// Get traffic statistics (includes connection quality).
pub async fn get_statistics(&self) -> serde_json::Value {
let stats = self.stats.read().await;
let since = self.connected_since.read().await;
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
let health = self.link_health.read().await;
let mut result = serde_json::json!({
"bytesSent": stats.bytes_sent,
"bytesReceived": stats.bytes_received,
"packetsSent": stats.packets_sent,
"packetsReceived": stats.packets_received,
"keepalivesSent": stats.keepalives_sent,
"keepalivesReceived": stats.keepalives_received,
"uptimeSeconds": uptime,
});
// Include connection quality if available
if let Some(ref rx) = self.quality_rx {
let quality = rx.borrow().clone();
result["quality"] = serde_json::json!({
"srttMs": quality.srtt_ms,
"jitterMs": quality.jitter_ms,
"minRttMs": quality.min_rtt_ms,
"maxRttMs": quality.max_rtt_ms,
"lossRatio": quality.loss_ratio,
"consecutiveTimeouts": quality.consecutive_timeouts,
"linkHealth": format!("{}", *health),
"keepalivesSent": quality.keepalives_sent,
"keepalivesAcked": quality.keepalives_acked,
});
}
result
}
/// Get connection quality snapshot.
pub fn get_connection_quality(&self) -> Option<ConnectionQuality> {
self.quality_rx.as_ref().map(|rx| rx.borrow().clone())
}
/// Get current link health.
pub async fn get_link_health(&self) -> LinkHealth {
*self.link_health.read().await
}
}
/// The main client packet forwarding loop (runs in a spawned task).
///
/// Multiplexes three event sources until one of them ends the session:
/// - Frames from the server: `IpPacket` is decrypted and counted, `KeepaliveAck` is forwarded
///   to the keepalive monitor, and `Disconnect` ends the loop.
/// - Signals from the keepalive monitor: send a ping, peer dead, or link-health change.
/// - A local shutdown request, which sends `Disconnect` to the server and closes the sink.
///
/// Every exit path leaves `state` as `Disconnected` or `Error(..)`, so the client never keeps
/// reporting `Connected` after this task has stopped.
async fn client_loop(
    mut sink: Box<dyn TransportSink>,
    mut stream: Box<dyn TransportStream>,
    mut noise_transport: snow::TransportState,
    state: Arc<RwLock<ClientState>>,
    stats: Arc<RwLock<ClientStatistics>>,
    mut shutdown_rx: mpsc::Receiver<()>,
    mut signal_rx: mpsc::Receiver<KeepaliveSignal>,
    ack_tx: mpsc::Sender<()>,
    link_health: Arc<RwLock<LinkHealth>>,
) {
    // Scratch buffer for decrypted payloads (65535 bytes, the Noise maximum message length).
    let mut buf = vec![0u8; 65535];
    loop {
        tokio::select! {
            msg = stream.recv_reliable() => {
                match msg {
                    Ok(Some(data)) => {
                        let mut frame_buf = BytesMut::from(&data[..]);
                        let frame = match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
                            Ok(Some(frame)) => frame,
                            Ok(None) => {
                                debug!("Dropping incomplete frame ({} bytes)", data.len());
                                continue;
                            }
                            Err(e) => {
                                debug!("Dropping undecodable frame: {}", e);
                                continue;
                            }
                        };
                        match frame.packet_type {
                            PacketType::IpPacket => {
                                match noise_transport.read_message(&frame.payload, &mut buf) {
                                    Ok(len) => {
                                        let mut s = stats.write().await;
                                        s.bytes_received += len as u64;
                                        s.packets_received += 1;
                                    }
                                    Err(e) => {
                                        // A Noise transport failure is not recoverable.
                                        warn!("Decrypt error: {}", e);
                                        *state.write().await = ClientState::Error(e.to_string());
                                        break;
                                    }
                                }
                            }
                            PacketType::KeepaliveAck => {
                                stats.write().await.keepalives_received += 1;
                                // Signal the keepalive monitor that ACK was received
                                let _ = ack_tx.send(()).await;
                            }
                            PacketType::Disconnect => {
                                info!("Server sent disconnect");
                                *state.write().await = ClientState::Disconnected;
                                break;
                            }
                            _ => {}
                        }
                    }
                    Ok(None) => {
                        info!("Connection closed");
                        *state.write().await = ClientState::Disconnected;
                        break;
                    }
                    Err(e) => {
                        error!("Transport error: {}", e);
                        *state.write().await = ClientState::Error(e.to_string());
                        break;
                    }
                }
            }
            signal = signal_rx.recv() => {
                match signal {
                    Some(KeepaliveSignal::SendPing(timestamp_ms)) => {
                        // Embed the timestamp in the keepalive payload (8 bytes, big-endian)
                        let ka_frame = Frame {
                            packet_type: PacketType::Keepalive,
                            payload: timestamp_ms.to_be_bytes().to_vec(),
                        };
                        let mut frame_bytes = BytesMut::new();
                        if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
                            if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
                                warn!("Failed to send keepalive");
                                *state.write().await = ClientState::Disconnected;
                                break;
                            }
                            stats.write().await.keepalives_sent += 1;
                        }
                    }
                    Some(KeepaliveSignal::PeerDead) => {
                        warn!("Peer declared dead by keepalive monitor");
                        *state.write().await = ClientState::Disconnected;
                        break;
                    }
                    Some(KeepaliveSignal::LinkHealthChanged(health)) => {
                        debug!("Link health changed to: {}", health);
                        *link_health.write().await = health;
                    }
                    None => {
                        // Without the monitor, liveness can no longer be tracked. Previously
                        // the loop exited here and left `state` at `Connected`.
                        warn!("Keepalive monitor stopped; closing connection");
                        *state.write().await = ClientState::Disconnected;
                        break;
                    }
                }
            }
            _ = shutdown_rx.recv() => {
                // Best-effort: tell the server we're leaving, then close the transport.
                let dc_frame = Frame {
                    packet_type: PacketType::Disconnect,
                    payload: vec![],
                };
                let mut frame_bytes = BytesMut::new();
                if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
                    let _ = sink.send_reliable(frame_bytes.to_vec()).await;
                }
                let _ = sink.close().await;
                *state.write().await = ClientState::Disconnected;
                break;
            }
        }
    }
}
/// Attempt a QUIC connection to `addr` and open its stream pair.
///
/// `cert_hash` is forwarded to `quic_transport::connect_quic` for certificate pinning.
/// Errors from either step are propagated (converted into `anyhow::Error` by `?`).
async fn try_quic_connect(
    addr: &str,
    cert_hash: Option<&str>,
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
    let connection = quic_transport::connect_quic(addr, cert_hash).await?;
    let halves = quic_transport::open_quic_streams(connection).await?;
    Ok(halves)
}
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
///
/// e.g. "ws://127.0.0.1:8080"           -> Some("127.0.0.1:8080")
///      "wss://vpn.example.com/tunnel"  -> Some("vpn.example.com:443")
///      "wss://[::1]/tunnel"            -> Some("[::1]:443")
///      "127.0.0.1:8080"                -> Some("127.0.0.1:8080") (already host:port)
///
/// Details:
/// - The scheme match is ASCII case-insensitive.
/// - Path, query, fragment and `user:pass@` userinfo are stripped.
/// - A missing port defaults to 443 (`wss`) or 80 (`ws`).
///
/// Returns `None` in these cases:
/// - URLs with any other scheme.
/// - An empty host.
/// - Scheme-less strings without a `:`.
fn extract_host_port(url: &str) -> Option<String> {
    /// Strip `scheme` from the front of `url`, ignoring ASCII case.
    fn strip_scheme<'a>(url: &'a str, scheme: &str) -> Option<&'a str> {
        let head = url.get(..scheme.len())?;
        if head.eq_ignore_ascii_case(scheme) {
            url.get(scheme.len()..)
        } else {
            None
        }
    }

    let (rest, default_port) = if let Some(rest) = strip_scheme(url, "wss://") {
        (rest, 443)
    } else if let Some(rest) = strip_scheme(url, "ws://") {
        (rest, 80)
    } else if url.contains("://") {
        // Some other scheme (e.g. http://) is not a QUIC address.
        return None;
    } else if url.contains(':') {
        // Already host:port
        return Some(url.to_string());
    } else {
        return None;
    };

    // The authority ends at the first '/', '?' or '#'.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    // Drop userinfo ("user:pass@") so its ':' isn't mistaken for a port separator.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);
    if host_port.is_empty() {
        return None;
    }

    // For bracketed IPv6 literals, the port colon must follow the closing ']'.
    let has_port = match host_port.rfind(']') {
        Some(end) => host_port[end + 1..].starts_with(':'),
        None => host_port.contains(':'),
    };
    if has_port {
        Some(host_port.to_string())
    } else {
        Some(format!("{}:{}", host_port, default_port))
    }
}