feat(vpn transport): add QUIC transport support with auto fallback to WebSocket

This commit is contained in:
2026-03-19 21:53:30 +00:00
parent e14c357ba0
commit e81dd377d8
16 changed files with 2952 additions and 1888 deletions

View File

@@ -1,6 +1,5 @@
use anyhow::Result;
use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::Ipv4Addr;
@@ -8,7 +7,6 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn};
use crate::codec::{Frame, FrameCodec, PacketType};
@@ -17,6 +15,8 @@ use crate::mtu::{MtuConfig, TunnelOverhead};
use crate::network::IpPool;
use crate::ratelimit::TokenBucket;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
@@ -39,6 +39,12 @@ pub struct ServerConfig {
pub default_rate_limit_bytes_per_sec: Option<u64>,
/// Default burst size for new clients (bytes). None = unlimited.
pub default_burst_bytes: Option<u64>,
/// Transport mode: "websocket" (default), "quic", or "both".
pub transport_mode: Option<String>,
/// QUIC listen address (host:port). Defaults to listen_addr.
pub quic_listen_addr: Option<String>,
/// QUIC idle timeout in seconds (default: 30).
pub quic_idle_timeout_secs: Option<u64>,
}
/// Information about a connected client.
@@ -135,14 +141,58 @@ impl VpnServer {
self.state = Some(state.clone());
self.shutdown_tx = Some(shutdown_tx);
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
let listen_addr = config.listen_addr.clone();
tokio::spawn(async move {
if let Err(e) = run_listener(state, listen_addr, &mut shutdown_rx).await {
error!("Server listener error: {}", e);
}
});
info!("VPN server started");
match transport_mode {
"quic" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
error!("QUIC listener error: {}", e);
}
});
}
"both" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
let state2 = state.clone();
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
// Store second shutdown sender so both listeners stop
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
self.shutdown_tx = Some(combined_tx);
// Forward combined shutdown to both listeners
tokio::spawn(async move {
combined_rx.recv().await;
let _ = shutdown_tx_orig.send(()).await;
let _ = shutdown_tx2.send(()).await;
});
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("WebSocket listener error: {}", e);
}
});
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
error!("QUIC listener error: {}", e);
}
});
}
_ => {
// "websocket" (default)
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("Server listener error: {}", e);
}
});
}
}
info!("VPN server started (transport: {})", transport_mode);
Ok(())
}
@@ -239,7 +289,9 @@ impl VpnServer {
}
}
async fn run_listener(
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
/// to the transport-agnostic `handle_client_connection`.
async fn run_ws_listener(
state: Arc<ServerState>,
listen_addr: String,
shutdown_rx: &mut mpsc::Receiver<()>,
@@ -255,8 +307,20 @@ async fn run_listener(
info!("New connection from {}", addr);
let state = state.clone();
tokio::spawn(async move {
if let Err(e) = handle_client_connection(state, stream).await {
warn!("Client connection error: {}", e);
match transport::accept_connection(stream).await {
Ok(ws) => {
let (sink, stream) = transport_trait::split_ws(ws);
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
).await {
warn!("Client connection error: {}", e);
}
}
Err(e) => {
warn!("WebSocket upgrade failed: {}", e);
}
}
});
}
@@ -275,13 +339,95 @@ async fn run_listener(
Ok(())
}
/// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic
/// `handle_client_connection`.
async fn run_quic_listener(
state: Arc<ServerState>,
listen_addr: String,
idle_timeout_secs: u64,
shutdown_rx: &mut mpsc::Receiver<()>,
) -> Result<()> {
// Generate or use configured TLS certificate for QUIC
let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) =
(&state.config.tls_cert, &state.config.tls_key)
{
// Parse PEM certificates
let certs: Vec<rustls_pki_types::CertificateDer<'static>> =
rustls_pemfile::certs(&mut cert_pem.as_bytes())
.collect::<Result<Vec<_>, _>>()?;
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
.ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
(certs, key)
} else {
// Generate self-signed certificate
let (certs, key) = quic_transport::generate_self_signed_cert()?;
info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0]));
(certs, key)
};
let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig {
listen_addr,
cert_chain,
private_key,
idle_timeout_secs,
})?;
loop {
tokio::select! {
incoming = endpoint.accept() => {
match incoming {
Some(incoming) => {
let state = state.clone();
tokio::spawn(async move {
match incoming.await {
Ok(conn) => {
let remote = conn.remote_address();
info!("New QUIC connection from {}", remote);
match quic_transport::accept_quic_connection(conn).await {
Ok((sink, stream)) => {
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
).await {
warn!("QUIC client error: {}", e);
}
}
Err(e) => {
warn!("QUIC stream accept failed: {}", e);
}
}
}
Err(e) => {
warn!("QUIC handshake failed: {}", e);
}
}
});
}
None => {
info!("QUIC endpoint closed");
break;
}
}
}
_ = shutdown_rx.recv() => {
info!("QUIC shutdown signal received");
endpoint.close(0u32.into(), b"shutdown");
break;
}
}
}
Ok(())
}
/// Transport-agnostic client handler. Performs the Noise NK handshake, registers
/// the client, and runs the main packet forwarding loop.
async fn handle_client_connection(
state: Arc<ServerState>,
stream: tokio::net::TcpStream,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
) -> Result<()> {
let ws = transport::accept_connection(stream).await?;
let (mut ws_sink, mut ws_stream) = ws.split();
let client_id = uuid_v4();
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
@@ -295,9 +441,9 @@ async fn handle_client_connection(
let mut buf = vec![0u8; 65535];
// Receive handshake init
let init_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
_ => anyhow::bail!("Expected handshake init message"),
let init_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before handshake"),
};
let mut frame_buf = BytesMut::from(&init_msg[..]);
@@ -318,7 +464,7 @@ async fn handle_client_connection(
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
let mut noise_transport = responder.into_transport_mode()?;
@@ -369,7 +515,7 @@ async fn handle_client_connection(
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
info!("Client {} connected with IP {}", client_id, assigned_ip);
@@ -378,11 +524,11 @@ async fn handle_client_connection(
loop {
tokio::select! {
msg = ws_stream.next() => {
msg = stream.recv_reliable() => {
match msg {
Some(Ok(Message::Binary(data))) => {
Ok(Some(data)) => {
last_activity = tokio::time::Instant::now();
let mut frame_buf = BytesMut::from(&data[..][..]);
let mut frame_buf = BytesMut::from(&data[..]);
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
Ok(Some(frame)) => match frame.packet_type {
PacketType::IpPacket => {
@@ -432,7 +578,7 @@ async fn handle_client_connection(
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
let mut stats = state.stats.write().await;
stats.keepalives_received += 1;
@@ -463,20 +609,12 @@ async fn handle_client_connection(
}
}
}
Some(Ok(Message::Close(_))) | None => {
Ok(None) => {
info!("Client {} connection closed", client_id);
break;
}
Some(Ok(Message::Ping(data))) => {
last_activity = tokio::time::Instant::now();
ws_sink.send(Message::Pong(data)).await?;
}
Some(Ok(_)) => {
last_activity = tokio::time::Instant::now();
continue;
}
Some(Err(e)) => {
warn!("WebSocket error from {}: {}", client_id, e);
Err(e) => {
warn!("Transport error from {}: {}", client_id, e);
break;
}
}