feat(vpn transport): add QUIC transport support with auto fallback to WebSocket

This commit is contained in:
2026-03-19 21:53:30 +00:00
parent e14c357ba0
commit e81dd377d8
16 changed files with 2952 additions and 1888 deletions

View File

@@ -1,10 +1,8 @@
use anyhow::Result;
use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::{mpsc, watch, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn, debug};
use crate::codec::{Frame, FrameCodec, PacketType};
@@ -12,6 +10,8 @@ use crate::crypto;
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
use crate::telemetry::ConnectionQuality;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;
/// Client configuration (matches TS IVpnClientConfig).
#[derive(Debug, Clone, Deserialize)]
@@ -22,6 +22,10 @@ pub struct ClientConfig {
pub dns: Option<Vec<String>>,
pub mtu: Option<u16>,
pub keepalive_interval_secs: Option<u64>,
/// Transport type: "websocket" (default) or "quic".
pub transport: Option<String>,
/// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
pub server_cert_hash: Option<String>,
}
/// Client statistics.
@@ -106,9 +110,66 @@ impl VpnClient {
&config.server_public_key,
)?;
// Connect to WebSocket server
let ws = transport::connect_to_server(&config.server_url).await?;
let (mut ws_sink, mut ws_stream) = ws.split();
// Create transport based on configuration
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
let transport_type = config.transport.as_deref().unwrap_or("auto");
match transport_type {
"quic" => {
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
let cert_hash = config.server_cert_hash.as_deref();
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
info!("Connected via QUIC");
(Box::new(quic_sink) as Box<dyn TransportSink>,
Box::new(quic_stream) as Box<dyn TransportStream>)
}
"websocket" => {
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Connected via WebSocket");
(Box::new(ws_sink), Box::new(ws_stream))
}
_ => {
// "auto" (default): try QUIC first, fall back to WebSocket
// Extract host:port from the URL for QUIC attempt
let quic_addr = extract_host_port(&config.server_url);
let cert_hash = config.server_cert_hash.as_deref();
if let Some(ref addr) = quic_addr {
match tokio::time::timeout(
std::time::Duration::from_secs(3),
try_quic_connect(addr, cert_hash),
).await {
Ok(Ok((quic_sink, quic_stream))) => {
info!("Auto: connected via QUIC to {}", addr);
(Box::new(quic_sink) as Box<dyn TransportSink>,
Box::new(quic_stream) as Box<dyn TransportStream>)
}
Ok(Err(e)) => {
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Auto: connected via WebSocket (QUIC unavailable)");
(Box::new(ws_sink), Box::new(ws_stream))
}
Err(_) => {
debug!("Auto: QUIC timed out, falling back to WebSocket");
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Auto: connected via WebSocket (QUIC timed out)");
(Box::new(ws_sink), Box::new(ws_stream))
}
}
} else {
// Can't extract host:port for QUIC, use WebSocket directly
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Connected via WebSocket");
(Box::new(ws_sink), Box::new(ws_stream))
}
}
}
};
// Noise NK handshake (client side = initiator)
*state.write().await = ClientState::Handshaking;
@@ -123,13 +184,11 @@ impl VpnClient {
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
// <- e, ee
let resp_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
let resp_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed during handshake"),
};
@@ -145,9 +204,9 @@ impl VpnClient {
let mut noise_transport = initiator.into_transport_mode()?;
// Receive assigned IP info (encrypted)
let info_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
_ => anyhow::bail!("Expected IP info message"),
let info_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before IP info"),
};
let mut frame_buf = BytesMut::from(&info_msg[..]);
@@ -184,8 +243,8 @@ impl VpnClient {
// Spawn packet forwarding loop
let assigned_ip_clone = assigned_ip.clone();
tokio::spawn(client_loop(
ws_sink,
ws_stream,
sink,
stream,
noise_transport,
state,
stats,
@@ -280,8 +339,8 @@ impl VpnClient {
/// The main client packet forwarding loop (runs in a spawned task).
async fn client_loop(
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
mut noise_transport: snow::TransportState,
state: Arc<RwLock<ClientState>>,
stats: Arc<RwLock<ClientStatistics>>,
@@ -294,10 +353,10 @@ async fn client_loop(
loop {
tokio::select! {
msg = ws_stream.next() => {
msg = stream.recv_reliable() => {
match msg {
Some(Ok(Message::Binary(data))) => {
let mut frame_buf = BytesMut::from(&data[..][..]);
Ok(Some(data)) => {
let mut frame_buf = BytesMut::from(&data[..]);
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
match frame.packet_type {
PacketType::IpPacket => {
@@ -328,17 +387,13 @@ async fn client_loop(
}
}
}
Some(Ok(Message::Close(_))) | None => {
Ok(None) => {
info!("Connection closed");
*state.write().await = ClientState::Disconnected;
break;
}
Some(Ok(Message::Ping(data))) => {
let _ = ws_sink.send(Message::Pong(data)).await;
}
Some(Ok(_)) => continue,
Some(Err(e)) => {
error!("WebSocket error: {}", e);
Err(e) => {
error!("Transport error: {}", e);
*state.write().await = ClientState::Error(e.to_string());
break;
}
@@ -354,7 +409,7 @@ async fn client_loop(
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
warn!("Failed to send keepalive");
*state.write().await = ClientState::Disconnected;
break;
@@ -385,12 +440,51 @@ async fn client_loop(
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
}
let _ = ws_sink.close().await;
let _ = sink.close().await;
*state.write().await = ClientState::Disconnected;
break;
}
}
}
}
/// Try to connect via QUIC. Returns transport halves on success.
async fn try_quic_connect(
addr: &str,
cert_hash: Option<&str>,
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
let conn = quic_transport::connect_quic(addr, cert_hash).await?;
let (sink, stream) = quic_transport::open_quic_streams(conn).await?;
Ok((sink, stream))
}
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
/// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080")
/// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443")
/// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port)
fn extract_host_port(url: &str) -> Option<String> {
if url.starts_with("ws://") || url.starts_with("wss://") {
// Parse as URL
let stripped = if url.starts_with("wss://") {
&url[6..]
} else {
&url[5..]
};
// Remove path
let host_port = stripped.split('/').next()?;
if host_port.contains(':') {
Some(host_port.to_string())
} else {
// Default port
let default_port = if url.starts_with("wss://") { 443 } else { 80 };
Some(format!("{}:{}", host_port, default_port))
}
} else if url.contains(':') {
// Already host:port
Some(url.to_string())
} else {
None
}
}