feat(vpn transport): add QUIC transport support with auto fallback to WebSocket

This commit is contained in:
2026-03-19 21:53:30 +00:00
parent e14c357ba0
commit e81dd377d8
16 changed files with 2952 additions and 1888 deletions

View File

@@ -1,10 +1,8 @@
use anyhow::Result;
use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::{mpsc, watch, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn, debug};
use crate::codec::{Frame, FrameCodec, PacketType};
@@ -12,6 +10,8 @@ use crate::crypto;
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
use crate::telemetry::ConnectionQuality;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;
/// Client configuration (matches TS IVpnClientConfig).
#[derive(Debug, Clone, Deserialize)]
@@ -22,6 +22,10 @@ pub struct ClientConfig {
pub dns: Option<Vec<String>>,
pub mtu: Option<u16>,
pub keepalive_interval_secs: Option<u64>,
/// Transport type: "websocket" (default) or "quic".
pub transport: Option<String>,
/// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
pub server_cert_hash: Option<String>,
}
/// Client statistics.
@@ -106,9 +110,66 @@ impl VpnClient {
&config.server_public_key,
)?;
// Connect to WebSocket server
let ws = transport::connect_to_server(&config.server_url).await?;
let (mut ws_sink, mut ws_stream) = ws.split();
// Create transport based on configuration
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
let transport_type = config.transport.as_deref().unwrap_or("auto");
match transport_type {
"quic" => {
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
let cert_hash = config.server_cert_hash.as_deref();
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
info!("Connected via QUIC");
(Box::new(quic_sink) as Box<dyn TransportSink>,
Box::new(quic_stream) as Box<dyn TransportStream>)
}
"websocket" => {
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Connected via WebSocket");
(Box::new(ws_sink), Box::new(ws_stream))
}
_ => {
// "auto" (default): try QUIC first, fall back to WebSocket
// Extract host:port from the URL for QUIC attempt
let quic_addr = extract_host_port(&config.server_url);
let cert_hash = config.server_cert_hash.as_deref();
if let Some(ref addr) = quic_addr {
match tokio::time::timeout(
std::time::Duration::from_secs(3),
try_quic_connect(addr, cert_hash),
).await {
Ok(Ok((quic_sink, quic_stream))) => {
info!("Auto: connected via QUIC to {}", addr);
(Box::new(quic_sink) as Box<dyn TransportSink>,
Box::new(quic_stream) as Box<dyn TransportStream>)
}
Ok(Err(e)) => {
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Auto: connected via WebSocket (QUIC unavailable)");
(Box::new(ws_sink), Box::new(ws_stream))
}
Err(_) => {
debug!("Auto: QUIC timed out, falling back to WebSocket");
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Auto: connected via WebSocket (QUIC timed out)");
(Box::new(ws_sink), Box::new(ws_stream))
}
}
} else {
// Can't extract host:port for QUIC, use WebSocket directly
let ws = transport::connect_to_server(&config.server_url).await?;
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
info!("Connected via WebSocket");
(Box::new(ws_sink), Box::new(ws_stream))
}
}
}
};
// Noise NK handshake (client side = initiator)
*state.write().await = ClientState::Handshaking;
@@ -123,13 +184,11 @@ impl VpnClient {
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
// <- e, ee
let resp_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
let resp_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed during handshake"),
};
@@ -145,9 +204,9 @@ impl VpnClient {
let mut noise_transport = initiator.into_transport_mode()?;
// Receive assigned IP info (encrypted)
let info_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
_ => anyhow::bail!("Expected IP info message"),
let info_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before IP info"),
};
let mut frame_buf = BytesMut::from(&info_msg[..]);
@@ -184,8 +243,8 @@ impl VpnClient {
// Spawn packet forwarding loop
let assigned_ip_clone = assigned_ip.clone();
tokio::spawn(client_loop(
ws_sink,
ws_stream,
sink,
stream,
noise_transport,
state,
stats,
@@ -280,8 +339,8 @@ impl VpnClient {
/// The main client packet forwarding loop (runs in a spawned task).
async fn client_loop(
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
mut noise_transport: snow::TransportState,
state: Arc<RwLock<ClientState>>,
stats: Arc<RwLock<ClientStatistics>>,
@@ -294,10 +353,10 @@ async fn client_loop(
loop {
tokio::select! {
msg = ws_stream.next() => {
msg = stream.recv_reliable() => {
match msg {
Some(Ok(Message::Binary(data))) => {
let mut frame_buf = BytesMut::from(&data[..][..]);
Ok(Some(data)) => {
let mut frame_buf = BytesMut::from(&data[..]);
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
match frame.packet_type {
PacketType::IpPacket => {
@@ -328,17 +387,13 @@ async fn client_loop(
}
}
}
Some(Ok(Message::Close(_))) | None => {
Ok(None) => {
info!("Connection closed");
*state.write().await = ClientState::Disconnected;
break;
}
Some(Ok(Message::Ping(data))) => {
let _ = ws_sink.send(Message::Pong(data)).await;
}
Some(Ok(_)) => continue,
Some(Err(e)) => {
error!("WebSocket error: {}", e);
Err(e) => {
error!("Transport error: {}", e);
*state.write().await = ClientState::Error(e.to_string());
break;
}
@@ -354,7 +409,7 @@ async fn client_loop(
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
warn!("Failed to send keepalive");
*state.write().await = ClientState::Disconnected;
break;
@@ -385,12 +440,51 @@ async fn client_loop(
};
let mut frame_bytes = BytesMut::new();
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
}
let _ = ws_sink.close().await;
let _ = sink.close().await;
*state.write().await = ClientState::Disconnected;
break;
}
}
}
}
/// Try to connect via QUIC. Returns transport halves on success.
async fn try_quic_connect(
addr: &str,
cert_hash: Option<&str>,
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
let conn = quic_transport::connect_quic(addr, cert_hash).await?;
let (sink, stream) = quic_transport::open_quic_streams(conn).await?;
Ok((sink, stream))
}
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
/// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080")
/// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443")
/// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port)
fn extract_host_port(url: &str) -> Option<String> {
if url.starts_with("ws://") || url.starts_with("wss://") {
// Parse as URL
let stripped = if url.starts_with("wss://") {
&url[6..]
} else {
&url[5..]
};
// Remove path
let host_port = stripped.split('/').next()?;
if host_port.contains(':') {
Some(host_port.to_string())
} else {
// Default port
let default_port = if url.starts_with("wss://") { 443 } else { 80 };
Some(format!("{}:{}", host_port, default_port))
}
} else if url.contains(':') {
// Already host:port
Some(url.to_string())
} else {
None
}
}

View File

@@ -5,6 +5,8 @@ pub mod management;
pub mod codec;
pub mod crypto;
pub mod transport;
pub mod transport_trait;
pub mod quic_transport;
pub mod keepalive;
pub mod tunnel;
pub mod network;

546
rust/src/quic_transport.rs Normal file
View File

@@ -0,0 +1,546 @@
use anyhow::Result;
use async_trait::async_trait;
use quinn::crypto::rustls::QuicClientConfig;
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn, debug};
use crate::transport_trait::{TransportSink, TransportStream};
// ============================================================================
// TLS / Certificate helpers
// ============================================================================
/// Generate a self-signed certificate and private key for QUIC.
pub fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
let cert = rcgen::generate_simple_self_signed(vec!["smartvpn".to_string()])?;
let cert_der = CertificateDer::from(cert.cert);
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()));
Ok((vec![cert_der], key_der))
}
/// Compute the SHA-256 hash of a DER-encoded certificate and return it as base64.
pub fn cert_hash(cert_der: &CertificateDer<'_>) -> String {
use ring::digest;
let hash = digest::digest(&digest::SHA256, cert_der.as_ref());
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash.as_ref())
}
// ============================================================================
// Server-side QUIC endpoint
// ============================================================================
/// Configuration for the QUIC server endpoint.
pub struct QuicServerConfig {
pub listen_addr: String,
pub cert_chain: Vec<CertificateDer<'static>>,
pub private_key: PrivateKeyDer<'static>,
pub idle_timeout_secs: u64,
}
/// Create a QUIC server endpoint bound to the given address.
pub fn create_quic_server(config: QuicServerConfig) -> Result<quinn::Endpoint> {
let addr: SocketAddr = config.listen_addr.parse()?;
let provider = Arc::new(rustls::crypto::ring::default_provider());
let mut tls_config = rustls::ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.with_no_client_auth()
.with_single_cert(config.cert_chain, config.private_key)?;
tls_config.alpn_protocols = vec![b"smartvpn".to_vec()];
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)?,
));
let mut transport = quinn::TransportConfig::default();
transport.max_idle_timeout(Some(
quinn::IdleTimeout::try_from(Duration::from_secs(config.idle_timeout_secs))?,
));
// Enable datagrams with a generous max size
transport.datagram_receive_buffer_size(Some(65535));
transport.datagram_send_buffer_size(65535);
server_config.transport_config(Arc::new(transport));
let endpoint = quinn::Endpoint::server(server_config, addr)?;
info!("QUIC server listening on {}", addr);
Ok(endpoint)
}
// ============================================================================
// Client-side QUIC connection
// ============================================================================
/// A certificate verifier that accepts any server certificate.
/// Safe when Noise NK provides server authentication at the application layer.
#[derive(Debug)]
struct AcceptAnyCert;
impl rustls::client::danger::ServerCertVerifier for AcceptAnyCert {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// A certificate verifier that accepts any certificate matching a given SHA-256 hash.
#[derive(Debug)]
struct CertHashVerifier {
expected_hash: String,
}
impl rustls::client::danger::ServerCertVerifier for CertHashVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
let actual_hash = cert_hash(end_entity);
if actual_hash == self.expected_hash {
Ok(rustls::client::danger::ServerCertVerified::assertion())
} else {
Err(rustls::Error::General(format!(
"Certificate hash mismatch: expected {}, got {}",
self.expected_hash, actual_hash
)))
}
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
// QUIC always uses TLS 1.3
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// Connect to a QUIC server.
///
/// - If `server_cert_hash` is provided, verifies the server certificate matches
/// the given SHA-256 hash (cert pinning).
/// - If `server_cert_hash` is `None`, accepts any server certificate. This is
/// safe because the Noise NK handshake (which runs over the QUIC stream)
/// authenticates the server via its pre-shared public key — the same trust
/// model as WireGuard.
pub async fn connect_quic(
addr: &str,
server_cert_hash: Option<&str>,
) -> Result<quinn::Connection> {
let remote: SocketAddr = addr.parse()?;
let provider = Arc::new(rustls::crypto::ring::default_provider());
let tls_config = if let Some(hash) = server_cert_hash {
// Pin to a specific certificate hash
let mut config = rustls::ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.dangerous()
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
expected_hash: hash.to_string(),
}))
.with_no_client_auth();
config.alpn_protocols = vec![b"smartvpn".to_vec()];
config
} else {
// Accept any cert — Noise NK provides server authentication
let mut config = rustls::ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.dangerous()
.with_custom_certificate_verifier(Arc::new(AcceptAnyCert))
.with_no_client_auth();
config.alpn_protocols = vec![b"smartvpn".to_vec()];
config
};
let client_config = quinn::ClientConfig::new(Arc::new(
QuicClientConfig::try_from(tls_config)?,
));
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse()?)?;
endpoint.set_default_client_config(client_config);
info!("Connecting to QUIC server at {}", addr);
let connection = endpoint.connect(remote, "smartvpn")?.await?;
info!("QUIC connection established");
Ok(connection)
}
// ============================================================================
// QUIC Transport Sink / Stream implementations
// ============================================================================
/// QUIC transport sink — wraps a SendStream (reliable) and Connection (datagrams).
pub struct QuicTransportSink {
send_stream: quinn::SendStream,
connection: quinn::Connection,
}
impl QuicTransportSink {
pub fn new(send_stream: quinn::SendStream, connection: quinn::Connection) -> Self {
Self {
send_stream,
connection,
}
}
}
#[async_trait]
impl TransportSink for QuicTransportSink {
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
// Length-prefix framing: [4-byte big-endian length][payload]
let len = data.len() as u32;
self.send_stream.write_all(&len.to_be_bytes()).await?;
self.send_stream.write_all(&data).await?;
Ok(())
}
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
let max_size = self.connection.max_datagram_size();
match max_size {
Some(max) if data.len() <= max => {
self.connection.send_datagram(data.into())?;
Ok(())
}
_ => {
// Datagram too large or datagrams disabled — fall back to reliable
debug!("Datagram too large ({}B), falling back to reliable stream", data.len());
self.send_reliable(data).await
}
}
}
async fn close(&mut self) -> Result<()> {
self.send_stream.finish()?;
Ok(())
}
}
/// QUIC transport stream — wraps a RecvStream (reliable) and Connection (datagrams).
pub struct QuicTransportStream {
recv_stream: quinn::RecvStream,
connection: quinn::Connection,
}
impl QuicTransportStream {
pub fn new(recv_stream: quinn::RecvStream, connection: quinn::Connection) -> Self {
Self {
recv_stream,
connection,
}
}
}
#[async_trait]
impl TransportStream for QuicTransportStream {
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
// Read length prefix
let mut len_buf = [0u8; 4];
match self.recv_stream.read_exact(&mut len_buf).await {
Ok(()) => {}
Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None),
Err(quinn::ReadExactError::ReadError(quinn::ReadError::ConnectionLost(e))) => {
warn!("QUIC connection lost: {}", e);
return Ok(None);
}
Err(e) => return Err(anyhow::anyhow!("QUIC read error: {}", e)),
}
let len = u32::from_be_bytes(len_buf) as usize;
if len > 65536 {
return Err(anyhow::anyhow!("Frame too large: {} bytes", len));
}
let mut data = vec![0u8; len];
match self.recv_stream.read_exact(&mut data).await {
Ok(()) => Ok(Some(data)),
Err(quinn::ReadExactError::FinishedEarly(_)) => Ok(None),
Err(e) => Err(anyhow::anyhow!("QUIC read error: {}", e)),
}
}
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
match self.connection.read_datagram().await {
Ok(data) => Ok(Some(data.to_vec())),
Err(quinn::ConnectionError::ApplicationClosed(_)) => Ok(None),
Err(quinn::ConnectionError::LocallyClosed) => Ok(None),
Err(e) => Err(anyhow::anyhow!("QUIC datagram error: {}", e)),
}
}
fn supports_datagrams(&self) -> bool {
self.connection.max_datagram_size().is_some()
}
}
/// Accept a QUIC connection and open a bidirectional control stream.
/// Returns the transport sink/stream pair ready for the VPN handshake.
pub async fn accept_quic_connection(
conn: quinn::Connection,
) -> Result<(QuicTransportSink, QuicTransportStream)> {
// The client opens the bidirectional control stream
let (send, recv) = conn.accept_bi().await?;
info!("QUIC bidirectional control stream accepted");
Ok((
QuicTransportSink::new(send, conn.clone()),
QuicTransportStream::new(recv, conn),
))
}
/// Open a QUIC connection's bidirectional control stream (client side).
pub async fn open_quic_streams(
conn: quinn::Connection,
) -> Result<(QuicTransportSink, QuicTransportStream)> {
let (send, recv) = conn.open_bi().await?;
info!("QUIC bidirectional control stream opened");
Ok((
QuicTransportSink::new(send, conn.clone()),
QuicTransportStream::new(recv, conn),
))
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cert_generation_and_hash() {
let (certs, _key) = generate_self_signed_cert().unwrap();
assert_eq!(certs.len(), 1);
let hash = cert_hash(&certs[0]);
// SHA-256 base64 is 44 characters
assert_eq!(hash.len(), 44);
}
#[test]
fn test_cert_hash_deterministic() {
let (certs, _key) = generate_self_signed_cert().unwrap();
let hash1 = cert_hash(&certs[0]);
let hash2 = cert_hash(&certs[0]);
assert_eq!(hash1, hash2);
}
/// Helper: create QUIC server and client endpoints.
fn create_quic_endpoints() -> (quinn::Endpoint, quinn::Endpoint, String) {
let (certs, key) = generate_self_signed_cert().unwrap();
let hash = cert_hash(&certs[0]);
let provider = Arc::new(rustls::crypto::ring::default_provider());
let mut server_tls = rustls::ServerConfig::builder_with_provider(provider.clone())
.with_safe_default_protocol_versions().unwrap()
.with_no_client_auth()
.with_single_cert(certs, key).unwrap();
server_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
let server_qcfg = quinn::ServerConfig::with_crypto(Arc::new(
quinn::crypto::rustls::QuicServerConfig::try_from(server_tls).unwrap(),
));
let server_ep = quinn::Endpoint::server(server_qcfg, "127.0.0.1:0".parse().unwrap()).unwrap();
let mut client_tls = rustls::ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions().unwrap()
.dangerous()
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
expected_hash: hash,
}))
.with_no_client_auth();
client_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
let client_config = quinn::ClientConfig::new(Arc::new(
QuicClientConfig::try_from(client_tls).unwrap(),
));
let mut client_ep = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
client_ep.set_default_client_config(client_config);
let server_addr = server_ep.local_addr().unwrap().to_string();
(server_ep, client_ep, server_addr)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_quic_server_client_roundtrip() {
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
// Server: accept, accept_bi, read, echo, finish
let server_task = tokio::spawn(async move {
let conn = server_ep.accept().await.unwrap().await.unwrap();
let (mut s_send, mut s_recv) = conn.accept_bi().await.unwrap();
let data = s_recv.read_to_end(1024).await.unwrap();
s_send.write_all(&data).await.unwrap();
s_send.finish().unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
server_ep
});
// Client: connect, open_bi, write, finish, read
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
let (mut c_send, mut c_recv) = conn.open_bi().await.unwrap();
c_send.write_all(b"hello quinn").await.unwrap();
c_send.finish().unwrap();
let data = c_recv.read_to_end(1024).await.unwrap();
assert_eq!(&data[..], b"hello quinn");
let _ = server_task.await;
drop(client_ep);
}
/// Test transport trait wrappers over QUIC.
/// Key: client must send data first (QUIC streams are opened implicitly by data).
/// The server accept_bi runs concurrently with the client's first send_reliable.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_quic_transport_trait_roundtrip() {
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
// Server task: accept connection, then accept_bi (blocks until client sends data)
let server_task = tokio::spawn(async move {
let conn = server_ep.accept().await.unwrap().await.unwrap();
let (s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
(s_sink, s_stream, server_ep)
});
// Client: connect, open_bi via wrapper
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
let (mut c_sink, mut c_stream) = open_quic_streams(conn).await.unwrap();
// Client sends first — this triggers the QUIC stream to become visible to the server
c_sink.send_reliable(b"hello-from-client".to_vec()).await.unwrap();
// Now server's accept_bi unblocks
let (mut s_sink, mut s_stream, _sep) = server_task.await.unwrap();
// Server reads the message
let msg = s_stream.recv_reliable().await.unwrap().unwrap();
assert_eq!(msg, b"hello-from-client");
// Server -> Client
s_sink.send_reliable(b"hello-from-server".to_vec()).await.unwrap();
let msg = c_stream.recv_reliable().await.unwrap().unwrap();
assert_eq!(msg, b"hello-from-server");
drop(client_ep);
}
/// Test QUIC datagram support.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_quic_datagram_exchange() {
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
// Server: accept, accept_bi (opens control stream), then read datagram
let server_task = tokio::spawn(async move {
let conn = server_ep.accept().await.unwrap().await.unwrap();
// Accept bi stream (control channel)
let (_s_sink, _s_stream) = accept_quic_connection(conn.clone()).await.unwrap();
// Read datagram
let dgram = conn.read_datagram().await.unwrap();
assert_eq!(&dgram[..], b"dgram-payload");
server_ep
});
// Client: connect, open bi stream (triggers server accept_bi), then send datagram
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
let (mut c_sink, _c_stream) = open_quic_streams(conn.clone()).await.unwrap();
// Send initial data to open the stream (required for QUIC)
c_sink.send_reliable(b"init".to_vec()).await.unwrap();
// Small yield to let the server process the bi stream
tokio::task::yield_now().await;
// Send datagram
assert!(conn.max_datagram_size().is_some());
conn.send_datagram(bytes::Bytes::from_static(b"dgram-payload")).unwrap();
let _ = server_task.await.unwrap();
drop(client_ep);
}
/// Test that supports_datagrams returns true for QUIC transports.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_quic_supports_datagrams() {
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
let server_task = tokio::spawn(async move {
let conn = server_ep.accept().await.unwrap().await.unwrap();
let (_s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
assert!(s_stream.supports_datagrams());
server_ep
});
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
let (mut c_sink, c_stream) = open_quic_streams(conn).await.unwrap();
assert!(c_stream.supports_datagrams());
// Send data to trigger server's accept_bi
c_sink.send_reliable(b"ping".to_vec()).await.unwrap();
let _ = server_task.await.unwrap();
drop(client_ep);
}
}

View File

@@ -130,10 +130,12 @@ mod tests {
#[test]
fn tokens_do_not_exceed_burst() {
let mut tb = TokenBucket::new(1_000_000, 1_000);
// Use a low rate so refill between consecutive calls is negligible
let mut tb = TokenBucket::new(100, 1_000);
// Wait to accumulate — but should cap at burst
std::thread::sleep(Duration::from_millis(50));
assert!(tb.try_consume(1_000));
// At 100 bytes/sec, the few μs between calls add ~0 tokens
assert!(!tb.try_consume(1));
}
}

View File

@@ -1,6 +1,5 @@
use anyhow::Result;
use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::Ipv4Addr;
@@ -8,7 +7,6 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::{info, error, warn};
use crate::codec::{Frame, FrameCodec, PacketType};
@@ -17,6 +15,8 @@ use crate::mtu::{MtuConfig, TunnelOverhead};
use crate::network::IpPool;
use crate::ratelimit::TokenBucket;
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
@@ -39,6 +39,12 @@ pub struct ServerConfig {
pub default_rate_limit_bytes_per_sec: Option<u64>,
/// Default burst size for new clients (bytes). None = unlimited.
pub default_burst_bytes: Option<u64>,
/// Transport mode: "websocket" (default), "quic", or "both".
pub transport_mode: Option<String>,
/// QUIC listen address (host:port). Defaults to listen_addr.
pub quic_listen_addr: Option<String>,
/// QUIC idle timeout in seconds (default: 30).
pub quic_idle_timeout_secs: Option<u64>,
}
/// Information about a connected client.
@@ -135,14 +141,58 @@ impl VpnServer {
self.state = Some(state.clone());
self.shutdown_tx = Some(shutdown_tx);
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
let listen_addr = config.listen_addr.clone();
tokio::spawn(async move {
if let Err(e) = run_listener(state, listen_addr, &mut shutdown_rx).await {
error!("Server listener error: {}", e);
}
});
info!("VPN server started");
match transport_mode {
"quic" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
error!("QUIC listener error: {}", e);
}
});
}
"both" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
let state2 = state.clone();
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
// Store second shutdown sender so both listeners stop
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
self.shutdown_tx = Some(combined_tx);
// Forward combined shutdown to both listeners
tokio::spawn(async move {
combined_rx.recv().await;
let _ = shutdown_tx_orig.send(()).await;
let _ = shutdown_tx2.send(()).await;
});
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("WebSocket listener error: {}", e);
}
});
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
error!("QUIC listener error: {}", e);
}
});
}
_ => {
// "websocket" (default)
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("Server listener error: {}", e);
}
});
}
}
info!("VPN server started (transport: {})", transport_mode);
Ok(())
}
@@ -239,7 +289,9 @@ impl VpnServer {
}
}
async fn run_listener(
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
/// to the transport-agnostic `handle_client_connection`.
async fn run_ws_listener(
state: Arc<ServerState>,
listen_addr: String,
shutdown_rx: &mut mpsc::Receiver<()>,
@@ -255,8 +307,20 @@ async fn run_listener(
info!("New connection from {}", addr);
let state = state.clone();
tokio::spawn(async move {
if let Err(e) = handle_client_connection(state, stream).await {
warn!("Client connection error: {}", e);
match transport::accept_connection(stream).await {
Ok(ws) => {
let (sink, stream) = transport_trait::split_ws(ws);
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
).await {
warn!("Client connection error: {}", e);
}
}
Err(e) => {
warn!("WebSocket upgrade failed: {}", e);
}
}
});
}
@@ -275,13 +339,95 @@ async fn run_listener(
Ok(())
}
/// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic
/// `handle_client_connection`.
async fn run_quic_listener(
state: Arc<ServerState>,
listen_addr: String,
idle_timeout_secs: u64,
shutdown_rx: &mut mpsc::Receiver<()>,
) -> Result<()> {
// Generate or use configured TLS certificate for QUIC
let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) =
(&state.config.tls_cert, &state.config.tls_key)
{
// Parse PEM certificates
let certs: Vec<rustls_pki_types::CertificateDer<'static>> =
rustls_pemfile::certs(&mut cert_pem.as_bytes())
.collect::<Result<Vec<_>, _>>()?;
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
.ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
(certs, key)
} else {
// Generate self-signed certificate
let (certs, key) = quic_transport::generate_self_signed_cert()?;
info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0]));
(certs, key)
};
let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig {
listen_addr,
cert_chain,
private_key,
idle_timeout_secs,
})?;
loop {
tokio::select! {
incoming = endpoint.accept() => {
match incoming {
Some(incoming) => {
let state = state.clone();
tokio::spawn(async move {
match incoming.await {
Ok(conn) => {
let remote = conn.remote_address();
info!("New QUIC connection from {}", remote);
match quic_transport::accept_quic_connection(conn).await {
Ok((sink, stream)) => {
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
).await {
warn!("QUIC client error: {}", e);
}
}
Err(e) => {
warn!("QUIC stream accept failed: {}", e);
}
}
}
Err(e) => {
warn!("QUIC handshake failed: {}", e);
}
}
});
}
None => {
info!("QUIC endpoint closed");
break;
}
}
}
_ = shutdown_rx.recv() => {
info!("QUIC shutdown signal received");
endpoint.close(0u32.into(), b"shutdown");
break;
}
}
}
Ok(())
}
/// Transport-agnostic client handler. Performs the Noise NK handshake, registers
/// the client, and runs the main packet forwarding loop.
async fn handle_client_connection(
state: Arc<ServerState>,
stream: tokio::net::TcpStream,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
) -> Result<()> {
let ws = transport::accept_connection(stream).await?;
let (mut ws_sink, mut ws_stream) = ws.split();
let client_id = uuid_v4();
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
@@ -295,9 +441,9 @@ async fn handle_client_connection(
let mut buf = vec![0u8; 65535];
// Receive handshake init
let init_msg = match ws_stream.next().await {
Some(Ok(Message::Binary(data))) => data.to_vec(),
_ => anyhow::bail!("Expected handshake init message"),
let init_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before handshake"),
};
let mut frame_buf = BytesMut::from(&init_msg[..]);
@@ -318,7 +464,7 @@ async fn handle_client_connection(
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
let mut noise_transport = responder.into_transport_mode()?;
@@ -369,7 +515,7 @@ async fn handle_client_connection(
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
info!("Client {} connected with IP {}", client_id, assigned_ip);
@@ -378,11 +524,11 @@ async fn handle_client_connection(
loop {
tokio::select! {
msg = ws_stream.next() => {
msg = stream.recv_reliable() => {
match msg {
Some(Ok(Message::Binary(data))) => {
Ok(Some(data)) => {
last_activity = tokio::time::Instant::now();
let mut frame_buf = BytesMut::from(&data[..][..]);
let mut frame_buf = BytesMut::from(&data[..]);
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
Ok(Some(frame)) => match frame.packet_type {
PacketType::IpPacket => {
@@ -432,7 +578,7 @@ async fn handle_client_connection(
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
sink.send_reliable(frame_bytes.to_vec()).await?;
let mut stats = state.stats.write().await;
stats.keepalives_received += 1;
@@ -463,20 +609,12 @@ async fn handle_client_connection(
}
}
}
Some(Ok(Message::Close(_))) | None => {
Ok(None) => {
info!("Client {} connection closed", client_id);
break;
}
Some(Ok(Message::Ping(data))) => {
last_activity = tokio::time::Instant::now();
ws_sink.send(Message::Pong(data)).await?;
}
Some(Ok(_)) => {
last_activity = tokio::time::Instant::now();
continue;
}
Some(Err(e)) => {
warn!("WebSocket error from {}: {}", client_id, e);
Err(e) => {
warn!("Transport error from {}: {}", client_id, e);
break;
}
}

116
rust/src/transport_trait.rs Normal file
View File

@@ -0,0 +1,116 @@
use anyhow::Result;
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message;
use crate::transport::WsStream;
// ============================================================================
// Transport trait abstraction
// ============================================================================
/// Outbound half of a VPN transport connection.
#[async_trait]
pub trait TransportSink: Send + 'static {
/// Send a framed binary message on the reliable channel.
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()>;
/// Send a datagram (unreliable, best-effort).
/// Falls back to reliable if the transport does not support datagrams.
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()>;
/// Gracefully close the transport.
async fn close(&mut self) -> Result<()>;
}
/// Inbound half of a VPN transport connection.
#[async_trait]
pub trait TransportStream: Send + 'static {
/// Receive the next reliable binary message. Returns `None` on close.
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>>;
/// Receive the next datagram. Returns `None` if datagrams are unsupported
/// or the connection is closed.
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>>;
/// Whether this transport supports unreliable datagrams.
fn supports_datagrams(&self) -> bool;
}
// ============================================================================
// WebSocket implementation
// ============================================================================
/// WebSocket transport sink (wraps the write half of a split WsStream).
pub struct WsTransportSink {
inner: futures_util::stream::SplitSink<WsStream, Message>,
}
impl WsTransportSink {
pub fn new(inner: futures_util::stream::SplitSink<WsStream, Message>) -> Self {
Self { inner }
}
}
#[async_trait]
impl TransportSink for WsTransportSink {
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
self.inner.send(Message::Binary(data.into())).await?;
Ok(())
}
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
// WebSocket has no datagram support — fall back to reliable.
self.send_reliable(data).await
}
async fn close(&mut self) -> Result<()> {
self.inner.close().await?;
Ok(())
}
}
/// WebSocket transport stream (wraps the read half of a split WsStream).
pub struct WsTransportStream {
inner: futures_util::stream::SplitStream<WsStream>,
}
impl WsTransportStream {
pub fn new(inner: futures_util::stream::SplitStream<WsStream>) -> Self {
Self { inner }
}
}
#[async_trait]
impl TransportStream for WsTransportStream {
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
loop {
match self.inner.next().await {
Some(Ok(Message::Binary(data))) => return Ok(Some(data.to_vec())),
Some(Ok(Message::Close(_))) | None => return Ok(None),
Some(Ok(Message::Ping(_))) => {
// Ping handling is done at the tungstenite layer automatically
// when the sink side is alive. Just skip here.
continue;
}
Some(Ok(_)) => continue,
Some(Err(e)) => return Err(anyhow::anyhow!("WebSocket error: {}", e)),
}
}
}
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
// WebSocket does not support datagrams.
Ok(None)
}
fn supports_datagrams(&self) -> bool {
false
}
}
/// Split a WebSocket stream into transport sink and stream halves.
pub fn split_ws(ws: WsStream) -> (WsTransportSink, WsTransportStream) {
let (sink, stream) = ws.split();
(WsTransportSink::new(sink), WsTransportStream::new(stream))
}