initial
This commit is contained in:
324
rust/src/client.rs
Normal file
324
rust/src/client.rs
Normal file
@@ -0,0 +1,324 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
use crate::crypto;
|
||||
use crate::transport;
|
||||
|
||||
/// Client configuration (matches TS IVpnClientConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientConfig {
|
||||
pub server_url: String,
|
||||
pub server_public_key: String,
|
||||
pub dns: Option<Vec<String>>,
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
}
|
||||
|
||||
/// Client statistics.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ClientStatistics {
|
||||
pub bytes_sent: u64,
|
||||
pub bytes_received: u64,
|
||||
pub packets_sent: u64,
|
||||
pub packets_received: u64,
|
||||
pub keepalives_sent: u64,
|
||||
pub keepalives_received: u64,
|
||||
}
|
||||
|
||||
/// Client connection state.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ClientState {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Handshaking,
|
||||
Connected,
|
||||
Reconnecting,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ClientState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Disconnected => write!(f, "disconnected"),
|
||||
Self::Connecting => write!(f, "connecting"),
|
||||
Self::Handshaking => write!(f, "handshaking"),
|
||||
Self::Connected => write!(f, "connected"),
|
||||
Self::Reconnecting => write!(f, "reconnecting"),
|
||||
Self::Error(e) => write!(f, "error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The VPN client.
|
||||
pub struct VpnClient {
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
assigned_ip: Arc<RwLock<Option<String>>>,
|
||||
shutdown_tx: Option<mpsc::Sender<()>>,
|
||||
connected_since: Arc<RwLock<Option<std::time::Instant>>>,
|
||||
}
|
||||
|
||||
impl VpnClient {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: Arc::new(RwLock::new(ClientState::Disconnected)),
|
||||
stats: Arc::new(RwLock::new(ClientStatistics::default())),
|
||||
assigned_ip: Arc::new(RwLock::new(None)),
|
||||
shutdown_tx: None,
|
||||
connected_since: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to the VPN server.
|
||||
pub async fn connect(&mut self, config: ClientConfig) -> Result<String> {
|
||||
if *self.state.read().await != ClientState::Disconnected {
|
||||
anyhow::bail!("Client is not disconnected");
|
||||
}
|
||||
|
||||
*self.state.write().await = ClientState::Connecting;
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
let state = self.state.clone();
|
||||
let stats = self.stats.clone();
|
||||
let assigned_ip_ref = self.assigned_ip.clone();
|
||||
let connected_since = self.connected_since.clone();
|
||||
|
||||
// Decode server public key
|
||||
let server_pub_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&config.server_public_key,
|
||||
)?;
|
||||
|
||||
// Connect to WebSocket server
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
|
||||
// Noise NK handshake (client side = initiator)
|
||||
*state.write().await = ClientState::Handshaking;
|
||||
let mut initiator = crypto::create_initiator(&server_pub_key)?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// -> e, es
|
||||
let len = initiator.write_message(&[], &mut buf)?;
|
||||
let init_frame = Frame {
|
||||
packet_type: PacketType::HandshakeInit,
|
||||
payload: buf[..len].to_vec(),
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
|
||||
// <- e, ee
|
||||
let resp_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
|
||||
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
|
||||
None => anyhow::bail!("Connection closed during handshake"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&resp_msg[..]);
|
||||
let frame = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf)?
|
||||
.ok_or_else(|| anyhow::anyhow!("Incomplete handshake response frame"))?;
|
||||
|
||||
if frame.packet_type != PacketType::HandshakeResp {
|
||||
anyhow::bail!("Expected HandshakeResp, got {:?}", frame.packet_type);
|
||||
}
|
||||
|
||||
initiator.read_message(&frame.payload, &mut buf)?;
|
||||
let mut noise_transport = initiator.into_transport_mode()?;
|
||||
|
||||
// Receive assigned IP info (encrypted)
|
||||
let info_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected IP info message"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&info_msg[..]);
|
||||
let frame = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf)?
|
||||
.ok_or_else(|| anyhow::anyhow!("Incomplete IP info frame"))?;
|
||||
|
||||
let len = noise_transport.read_message(&frame.payload, &mut buf)?;
|
||||
let ip_info: serde_json::Value = serde_json::from_slice(&buf[..len])?;
|
||||
let assigned_ip = ip_info["assignedIp"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing assignedIp in server response"))?
|
||||
.to_string();
|
||||
|
||||
*assigned_ip_ref.write().await = Some(assigned_ip.clone());
|
||||
*connected_since.write().await = Some(std::time::Instant::now());
|
||||
*state.write().await = ClientState::Connected;
|
||||
|
||||
info!("Connected to VPN, assigned IP: {}", assigned_ip);
|
||||
|
||||
// Spawn packet forwarding loop
|
||||
let assigned_ip_clone = assigned_ip.clone();
|
||||
tokio::spawn(client_loop(
|
||||
ws_sink,
|
||||
ws_stream,
|
||||
noise_transport,
|
||||
state,
|
||||
stats,
|
||||
shutdown_rx,
|
||||
config.keepalive_interval_secs.unwrap_or(30),
|
||||
));
|
||||
|
||||
Ok(assigned_ip_clone)
|
||||
}
|
||||
|
||||
/// Disconnect from the VPN server.
|
||||
pub async fn disconnect(&mut self) -> Result<()> {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(()).await;
|
||||
}
|
||||
*self.assigned_ip.write().await = None;
|
||||
*self.connected_since.write().await = None;
|
||||
*self.state.write().await = ClientState::Disconnected;
|
||||
info!("Disconnected from VPN");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current status.
|
||||
pub async fn get_status(&self) -> serde_json::Value {
|
||||
let state = self.state.read().await;
|
||||
let ip = self.assigned_ip.read().await;
|
||||
let since = self.connected_since.read().await;
|
||||
|
||||
let mut status = serde_json::json!({
|
||||
"state": format!("{}", *state),
|
||||
});
|
||||
|
||||
if let Some(ref ip) = *ip {
|
||||
status["assignedIp"] = serde_json::json!(ip);
|
||||
}
|
||||
if let Some(instant) = *since {
|
||||
status["uptimeSeconds"] = serde_json::json!(instant.elapsed().as_secs());
|
||||
}
|
||||
|
||||
status
|
||||
}
|
||||
|
||||
/// Get traffic statistics.
|
||||
pub async fn get_statistics(&self) -> serde_json::Value {
|
||||
let stats = self.stats.read().await;
|
||||
let since = self.connected_since.read().await;
|
||||
let uptime = since.map(|s| s.elapsed().as_secs()).unwrap_or(0);
|
||||
|
||||
serde_json::json!({
|
||||
"bytesSent": stats.bytes_sent,
|
||||
"bytesReceived": stats.bytes_received,
|
||||
"packetsSent": stats.packets_sent,
|
||||
"packetsReceived": stats.packets_received,
|
||||
"keepalivesSent": stats.keepalives_sent,
|
||||
"keepalivesReceived": stats.keepalives_received,
|
||||
"uptimeSeconds": uptime,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// The main client packet forwarding loop (runs in a spawned task).
|
||||
async fn client_loop(
|
||||
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
|
||||
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
|
||||
mut noise_transport: snow::TransportState,
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
mut shutdown_rx: mpsc::Receiver<()>,
|
||||
keepalive_secs: u64,
|
||||
) {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(keepalive_secs));
|
||||
keepalive_ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
match noise_transport.read_message(&frame.payload, &mut buf) {
|
||||
Ok(len) => {
|
||||
let mut s = stats.write().await;
|
||||
s.bytes_received += len as u64;
|
||||
s.packets_received += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Decrypt error: {}", e);
|
||||
*state.write().await = ClientState::Error(e.to_string());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
PacketType::KeepaliveAck => {
|
||||
stats.write().await.keepalives_received += 1;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
info!("Server sent disconnect");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
info!("Connection closed");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
let _ = ws_sink.send(Message::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
error!("WebSocket error: {}", e);
|
||||
*state.write().await = ClientState::Error(e.to_string());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = keepalive_ticker.tick() => {
|
||||
let ka_frame = Frame {
|
||||
packet_type: PacketType::Keepalive,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
stats.write().await.keepalives_sent += 1;
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
// Send disconnect frame
|
||||
let dc_frame = Frame {
|
||||
packet_type: PacketType::Disconnect,
|
||||
payload: vec![],
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
|
||||
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
|
||||
}
|
||||
let _ = ws_sink.close().await;
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user