feat(vpn transport): add QUIC transport support with auto fallback to WebSocket
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn, debug};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
@@ -12,6 +10,8 @@ use crate::crypto;
|
||||
use crate::keepalive::{self, KeepaliveSignal, LinkHealth};
|
||||
use crate::telemetry::ConnectionQuality;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
|
||||
/// Client configuration (matches TS IVpnClientConfig).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -22,6 +22,10 @@ pub struct ClientConfig {
|
||||
pub dns: Option<Vec<String>>,
|
||||
pub mtu: Option<u16>,
|
||||
pub keepalive_interval_secs: Option<u64>,
|
||||
/// Transport type: "websocket" (default) or "quic".
|
||||
pub transport: Option<String>,
|
||||
/// For QUIC: SHA-256 hash of server certificate (base64) for cert pinning.
|
||||
pub server_cert_hash: Option<String>,
|
||||
}
|
||||
|
||||
/// Client statistics.
|
||||
@@ -106,9 +110,66 @@ impl VpnClient {
|
||||
&config.server_public_key,
|
||||
)?;
|
||||
|
||||
// Connect to WebSocket server
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
// Create transport based on configuration
|
||||
let (mut sink, mut stream): (Box<dyn TransportSink>, Box<dyn TransportStream>) = {
|
||||
let transport_type = config.transport.as_deref().unwrap_or("auto");
|
||||
match transport_type {
|
||||
"quic" => {
|
||||
let server_addr = &config.server_url; // For QUIC, serverUrl is host:port
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
let conn = quic_transport::connect_quic(server_addr, cert_hash).await?;
|
||||
let (quic_sink, quic_stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
info!("Connected via QUIC");
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
"websocket" => {
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
_ => {
|
||||
// "auto" (default): try QUIC first, fall back to WebSocket
|
||||
// Extract host:port from the URL for QUIC attempt
|
||||
let quic_addr = extract_host_port(&config.server_url);
|
||||
let cert_hash = config.server_cert_hash.as_deref();
|
||||
|
||||
if let Some(ref addr) = quic_addr {
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(3),
|
||||
try_quic_connect(addr, cert_hash),
|
||||
).await {
|
||||
Ok(Ok((quic_sink, quic_stream))) => {
|
||||
info!("Auto: connected via QUIC to {}", addr);
|
||||
(Box::new(quic_sink) as Box<dyn TransportSink>,
|
||||
Box::new(quic_stream) as Box<dyn TransportStream>)
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!("Auto: QUIC failed ({}), falling back to WebSocket", e);
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC unavailable)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Auto: QUIC timed out, falling back to WebSocket");
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Auto: connected via WebSocket (QUIC timed out)");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Can't extract host:port for QUIC, use WebSocket directly
|
||||
let ws = transport::connect_to_server(&config.server_url).await?;
|
||||
let (ws_sink, ws_stream) = transport_trait::split_ws(ws);
|
||||
info!("Connected via WebSocket");
|
||||
(Box::new(ws_sink), Box::new(ws_stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Noise NK handshake (client side = initiator)
|
||||
*state.write().await = ClientState::Handshaking;
|
||||
@@ -123,13 +184,11 @@ impl VpnClient {
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, init_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
// <- e, ee
|
||||
let resp_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
Some(Ok(_)) => anyhow::bail!("Expected binary handshake response"),
|
||||
Some(Err(e)) => anyhow::bail!("WebSocket error during handshake: {}", e),
|
||||
let resp_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed during handshake"),
|
||||
};
|
||||
|
||||
@@ -145,9 +204,9 @@ impl VpnClient {
|
||||
let mut noise_transport = initiator.into_transport_mode()?;
|
||||
|
||||
// Receive assigned IP info (encrypted)
|
||||
let info_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected IP info message"),
|
||||
let info_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before IP info"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&info_msg[..]);
|
||||
@@ -184,8 +243,8 @@ impl VpnClient {
|
||||
// Spawn packet forwarding loop
|
||||
let assigned_ip_clone = assigned_ip.clone();
|
||||
tokio::spawn(client_loop(
|
||||
ws_sink,
|
||||
ws_stream,
|
||||
sink,
|
||||
stream,
|
||||
noise_transport,
|
||||
state,
|
||||
stats,
|
||||
@@ -280,8 +339,8 @@ impl VpnClient {
|
||||
|
||||
/// The main client packet forwarding loop (runs in a spawned task).
|
||||
async fn client_loop(
|
||||
mut ws_sink: futures_util::stream::SplitSink<transport::WsStream, Message>,
|
||||
mut ws_stream: futures_util::stream::SplitStream<transport::WsStream>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
mut noise_transport: snow::TransportState,
|
||||
state: Arc<RwLock<ClientState>>,
|
||||
stats: Arc<RwLock<ClientStatistics>>,
|
||||
@@ -294,10 +353,10 @@ async fn client_loop(
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
Ok(Some(data)) => {
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
if let Ok(Some(frame)) = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
@@ -328,17 +387,13 @@ async fn client_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
Ok(None) => {
|
||||
info!("Connection closed");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
let _ = ws_sink.send(Message::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => {
|
||||
error!("WebSocket error: {}", e);
|
||||
Err(e) => {
|
||||
error!("Transport error: {}", e);
|
||||
*state.write().await = ClientState::Error(e.to_string());
|
||||
break;
|
||||
}
|
||||
@@ -354,7 +409,7 @@ async fn client_loop(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ka_frame, &mut frame_bytes).is_ok() {
|
||||
if ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await.is_err() {
|
||||
if sink.send_reliable(frame_bytes.to_vec()).await.is_err() {
|
||||
warn!("Failed to send keepalive");
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
@@ -385,12 +440,51 @@ async fn client_loop(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
if <FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, dc_frame, &mut frame_bytes).is_ok() {
|
||||
let _ = ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
}
|
||||
let _ = ws_sink.close().await;
|
||||
let _ = sink.close().await;
|
||||
*state.write().await = ClientState::Disconnected;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect via QUIC. Returns transport halves on success.
|
||||
async fn try_quic_connect(
|
||||
addr: &str,
|
||||
cert_hash: Option<&str>,
|
||||
) -> Result<(quic_transport::QuicTransportSink, quic_transport::QuicTransportStream)> {
|
||||
let conn = quic_transport::connect_quic(addr, cert_hash).await?;
|
||||
let (sink, stream) = quic_transport::open_quic_streams(conn).await?;
|
||||
Ok((sink, stream))
|
||||
}
|
||||
|
||||
/// Extract host:port from a WebSocket URL for QUIC auto-fallback.
|
||||
/// e.g. "ws://127.0.0.1:8080" -> Some("127.0.0.1:8080")
|
||||
/// "wss://vpn.example.com/tunnel" -> Some("vpn.example.com:443")
|
||||
/// "127.0.0.1:8080" -> Some("127.0.0.1:8080") (already host:port)
|
||||
fn extract_host_port(url: &str) -> Option<String> {
|
||||
if url.starts_with("ws://") || url.starts_with("wss://") {
|
||||
// Parse as URL
|
||||
let stripped = if url.starts_with("wss://") {
|
||||
&url[6..]
|
||||
} else {
|
||||
&url[5..]
|
||||
};
|
||||
// Remove path
|
||||
let host_port = stripped.split('/').next()?;
|
||||
if host_port.contains(':') {
|
||||
Some(host_port.to_string())
|
||||
} else {
|
||||
// Default port
|
||||
let default_port = if url.starts_with("wss://") { 443 } else { 80 };
|
||||
Some(format!("{}:{}", host_port, default_port))
|
||||
}
|
||||
} else if url.contains(':') {
|
||||
// Already host:port
|
||||
Some(url.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ pub mod management;
|
||||
pub mod codec;
|
||||
pub mod crypto;
|
||||
pub mod transport;
|
||||
pub mod transport_trait;
|
||||
pub mod quic_transport;
|
||||
pub mod keepalive;
|
||||
pub mod tunnel;
|
||||
pub mod network;
|
||||
|
||||
546
rust/src/quic_transport.rs
Normal file
546
rust/src/quic_transport.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use quinn::crypto::rustls::QuicClientConfig;
|
||||
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn, debug};
|
||||
|
||||
use crate::transport_trait::{TransportSink, TransportStream};
|
||||
|
||||
// ============================================================================
|
||||
// TLS / Certificate helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Generate a self-signed certificate and private key for QUIC.
|
||||
pub fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["smartvpn".to_string()])?;
|
||||
let cert_der = CertificateDer::from(cert.cert);
|
||||
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()));
|
||||
Ok((vec![cert_der], key_der))
|
||||
}
|
||||
|
||||
/// Compute the SHA-256 hash of a DER-encoded certificate and return it as base64.
|
||||
pub fn cert_hash(cert_der: &CertificateDer<'_>) -> String {
|
||||
use ring::digest;
|
||||
let hash = digest::digest(&digest::SHA256, cert_der.as_ref());
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash.as_ref())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Server-side QUIC endpoint
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for the QUIC server endpoint.
|
||||
pub struct QuicServerConfig {
|
||||
pub listen_addr: String,
|
||||
pub cert_chain: Vec<CertificateDer<'static>>,
|
||||
pub private_key: PrivateKeyDer<'static>,
|
||||
pub idle_timeout_secs: u64,
|
||||
}
|
||||
|
||||
/// Create a QUIC server endpoint bound to the given address.
|
||||
pub fn create_quic_server(config: QuicServerConfig) -> Result<quinn::Endpoint> {
|
||||
let addr: SocketAddr = config.listen_addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let mut tls_config = rustls::ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(config.cert_chain, config.private_key)?;
|
||||
tls_config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
|
||||
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(
|
||||
quinn::IdleTimeout::try_from(Duration::from_secs(config.idle_timeout_secs))?,
|
||||
));
|
||||
// Enable datagrams with a generous max size
|
||||
transport.datagram_receive_buffer_size(Some(65535));
|
||||
transport.datagram_send_buffer_size(65535);
|
||||
server_config.transport_config(Arc::new(transport));
|
||||
|
||||
let endpoint = quinn::Endpoint::server(server_config, addr)?;
|
||||
info!("QUIC server listening on {}", addr);
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Client-side QUIC connection
|
||||
// ============================================================================
|
||||
|
||||
/// A certificate verifier that accepts any server certificate.
|
||||
/// Safe when Noise NK provides server authentication at the application layer.
|
||||
#[derive(Debug)]
|
||||
struct AcceptAnyCert;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for AcceptAnyCert {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// A certificate verifier that accepts any certificate matching a given SHA-256 hash.
|
||||
#[derive(Debug)]
|
||||
struct CertHashVerifier {
|
||||
expected_hash: String,
|
||||
}
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for CertHashVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
let actual_hash = cert_hash(end_entity);
|
||||
if actual_hash == self.expected_hash {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
} else {
|
||||
Err(rustls::Error::General(format!(
|
||||
"Certificate hash mismatch: expected {}, got {}",
|
||||
self.expected_hash, actual_hash
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
// QUIC always uses TLS 1.3
|
||||
Err(rustls::Error::General("TLS 1.2 not supported".to_string()))
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a QUIC server.
|
||||
///
|
||||
/// - If `server_cert_hash` is provided, verifies the server certificate matches
|
||||
/// the given SHA-256 hash (cert pinning).
|
||||
/// - If `server_cert_hash` is `None`, accepts any server certificate. This is
|
||||
/// safe because the Noise NK handshake (which runs over the QUIC stream)
|
||||
/// authenticates the server via its pre-shared public key — the same trust
|
||||
/// model as WireGuard.
|
||||
pub async fn connect_quic(
|
||||
addr: &str,
|
||||
server_cert_hash: Option<&str>,
|
||||
) -> Result<quinn::Connection> {
|
||||
let remote: SocketAddr = addr.parse()?;
|
||||
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
let tls_config = if let Some(hash) = server_cert_hash {
|
||||
// Pin to a specific certificate hash
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash.to_string(),
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
} else {
|
||||
// Accept any cert — Noise NK provides server authentication
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(AcceptAnyCert))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
config
|
||||
};
|
||||
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(tls_config)?,
|
||||
));
|
||||
|
||||
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse()?)?;
|
||||
endpoint.set_default_client_config(client_config);
|
||||
|
||||
info!("Connecting to QUIC server at {}", addr);
|
||||
let connection = endpoint.connect(remote, "smartvpn")?.await?;
|
||||
info!("QUIC connection established");
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// QUIC Transport Sink / Stream implementations
|
||||
// ============================================================================
|
||||
|
||||
/// QUIC transport sink — wraps a SendStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportSink {
|
||||
send_stream: quinn::SendStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportSink {
|
||||
pub fn new(send_stream: quinn::SendStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
send_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for QuicTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// Length-prefix framing: [4-byte big-endian length][payload]
|
||||
let len = data.len() as u32;
|
||||
self.send_stream.write_all(&len.to_be_bytes()).await?;
|
||||
self.send_stream.write_all(&data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
let max_size = self.connection.max_datagram_size();
|
||||
match max_size {
|
||||
Some(max) if data.len() <= max => {
|
||||
self.connection.send_datagram(data.into())?;
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
// Datagram too large or datagrams disabled — fall back to reliable
|
||||
debug!("Datagram too large ({}B), falling back to reliable stream", data.len());
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.send_stream.finish()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// QUIC transport stream — wraps a RecvStream (reliable) and Connection (datagrams).
|
||||
pub struct QuicTransportStream {
|
||||
recv_stream: quinn::RecvStream,
|
||||
connection: quinn::Connection,
|
||||
}
|
||||
|
||||
impl QuicTransportStream {
|
||||
pub fn new(recv_stream: quinn::RecvStream, connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
recv_stream,
|
||||
connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for QuicTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// Read length prefix
|
||||
let mut len_buf = [0u8; 4];
|
||||
match self.recv_stream.read_exact(&mut len_buf).await {
|
||||
Ok(()) => {}
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None),
|
||||
Err(quinn::ReadExactError::ReadError(quinn::ReadError::ConnectionLost(e))) => {
|
||||
warn!("QUIC connection lost: {}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
|
||||
let len = u32::from_be_bytes(len_buf) as usize;
|
||||
if len > 65536 {
|
||||
return Err(anyhow::anyhow!("Frame too large: {} bytes", len));
|
||||
}
|
||||
|
||||
let mut data = vec![0u8; len];
|
||||
match self.recv_stream.read_exact(&mut data).await {
|
||||
Ok(()) => Ok(Some(data)),
|
||||
Err(quinn::ReadExactError::FinishedEarly(_)) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC read error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
match self.connection.read_datagram().await {
|
||||
Ok(data) => Ok(Some(data.to_vec())),
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => Ok(None),
|
||||
Err(quinn::ConnectionError::LocallyClosed) => Ok(None),
|
||||
Err(e) => Err(anyhow::anyhow!("QUIC datagram error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
self.connection.max_datagram_size().is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Accept a QUIC connection and open a bidirectional control stream.
|
||||
/// Returns the transport sink/stream pair ready for the VPN handshake.
|
||||
pub async fn accept_quic_connection(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
// The client opens the bidirectional control stream
|
||||
let (send, recv) = conn.accept_bi().await?;
|
||||
info!("QUIC bidirectional control stream accepted");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
/// Open a QUIC connection's bidirectional control stream (client side).
|
||||
pub async fn open_quic_streams(
|
||||
conn: quinn::Connection,
|
||||
) -> Result<(QuicTransportSink, QuicTransportStream)> {
|
||||
let (send, recv) = conn.open_bi().await?;
|
||||
info!("QUIC bidirectional control stream opened");
|
||||
Ok((
|
||||
QuicTransportSink::new(send, conn.clone()),
|
||||
QuicTransportStream::new(recv, conn),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cert_generation_and_hash() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
assert_eq!(certs.len(), 1);
|
||||
let hash = cert_hash(&certs[0]);
|
||||
// SHA-256 base64 is 44 characters
|
||||
assert_eq!(hash.len(), 44);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cert_hash_deterministic() {
|
||||
let (certs, _key) = generate_self_signed_cert().unwrap();
|
||||
let hash1 = cert_hash(&certs[0]);
|
||||
let hash2 = cert_hash(&certs[0]);
|
||||
assert_eq!(hash1, hash2);
|
||||
}
|
||||
|
||||
/// Helper: create QUIC server and client endpoints.
|
||||
fn create_quic_endpoints() -> (quinn::Endpoint, quinn::Endpoint, String) {
|
||||
let (certs, key) = generate_self_signed_cert().unwrap();
|
||||
let hash = cert_hash(&certs[0]);
|
||||
let provider = Arc::new(rustls::crypto::ring::default_provider());
|
||||
|
||||
let mut server_tls = rustls::ServerConfig::builder_with_provider(provider.clone())
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key).unwrap();
|
||||
server_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let server_qcfg = quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(server_tls).unwrap(),
|
||||
));
|
||||
let server_ep = quinn::Endpoint::server(server_qcfg, "127.0.0.1:0".parse().unwrap()).unwrap();
|
||||
|
||||
let mut client_tls = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions().unwrap()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(CertHashVerifier {
|
||||
expected_hash: hash,
|
||||
}))
|
||||
.with_no_client_auth();
|
||||
client_tls.alpn_protocols = vec![b"smartvpn".to_vec()];
|
||||
let client_config = quinn::ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(client_tls).unwrap(),
|
||||
));
|
||||
let mut client_ep = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
|
||||
client_ep.set_default_client_config(client_config);
|
||||
|
||||
let server_addr = server_ep.local_addr().unwrap().to_string();
|
||||
(server_ep, client_ep, server_addr)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_server_client_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi, read, echo, finish
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (mut s_send, mut s_recv) = conn.accept_bi().await.unwrap();
|
||||
let data = s_recv.read_to_end(1024).await.unwrap();
|
||||
s_send.write_all(&data).await.unwrap();
|
||||
s_send.finish().unwrap();
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open_bi, write, finish, read
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_send, mut c_recv) = conn.open_bi().await.unwrap();
|
||||
c_send.write_all(b"hello quinn").await.unwrap();
|
||||
c_send.finish().unwrap();
|
||||
let data = c_recv.read_to_end(1024).await.unwrap();
|
||||
assert_eq!(&data[..], b"hello quinn");
|
||||
|
||||
let _ = server_task.await;
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test transport trait wrappers over QUIC.
|
||||
/// Key: client must send data first (QUIC streams are opened implicitly by data).
|
||||
/// The server accept_bi runs concurrently with the client's first send_reliable.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_transport_trait_roundtrip() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server task: accept connection, then accept_bi (blocks until client sends data)
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
(s_sink, s_stream, server_ep)
|
||||
});
|
||||
|
||||
// Client: connect, open_bi via wrapper
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, mut c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
|
||||
// Client sends first — this triggers the QUIC stream to become visible to the server
|
||||
c_sink.send_reliable(b"hello-from-client".to_vec()).await.unwrap();
|
||||
|
||||
// Now server's accept_bi unblocks
|
||||
let (mut s_sink, mut s_stream, _sep) = server_task.await.unwrap();
|
||||
|
||||
// Server reads the message
|
||||
let msg = s_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-client");
|
||||
|
||||
// Server -> Client
|
||||
s_sink.send_reliable(b"hello-from-server".to_vec()).await.unwrap();
|
||||
let msg = c_stream.recv_reliable().await.unwrap().unwrap();
|
||||
assert_eq!(msg, b"hello-from-server");
|
||||
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test QUIC datagram support.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_datagram_exchange() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
// Server: accept, accept_bi (opens control stream), then read datagram
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
// Accept bi stream (control channel)
|
||||
let (_s_sink, _s_stream) = accept_quic_connection(conn.clone()).await.unwrap();
|
||||
// Read datagram
|
||||
let dgram = conn.read_datagram().await.unwrap();
|
||||
assert_eq!(&dgram[..], b"dgram-payload");
|
||||
server_ep
|
||||
});
|
||||
|
||||
// Client: connect, open bi stream (triggers server accept_bi), then send datagram
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, _c_stream) = open_quic_streams(conn.clone()).await.unwrap();
|
||||
|
||||
// Send initial data to open the stream (required for QUIC)
|
||||
c_sink.send_reliable(b"init".to_vec()).await.unwrap();
|
||||
|
||||
// Small yield to let the server process the bi stream
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
// Send datagram
|
||||
assert!(conn.max_datagram_size().is_some());
|
||||
conn.send_datagram(bytes::Bytes::from_static(b"dgram-payload")).unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
|
||||
/// Test that supports_datagrams returns true for QUIC transports.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_quic_supports_datagrams() {
|
||||
let (server_ep, client_ep, server_addr) = create_quic_endpoints();
|
||||
let addr: std::net::SocketAddr = server_addr.parse().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let conn = server_ep.accept().await.unwrap().await.unwrap();
|
||||
let (_s_sink, s_stream) = accept_quic_connection(conn).await.unwrap();
|
||||
assert!(s_stream.supports_datagrams());
|
||||
server_ep
|
||||
});
|
||||
|
||||
let conn = client_ep.connect(addr, "smartvpn").unwrap().await.unwrap();
|
||||
let (mut c_sink, c_stream) = open_quic_streams(conn).await.unwrap();
|
||||
assert!(c_stream.supports_datagrams());
|
||||
|
||||
// Send data to trigger server's accept_bi
|
||||
c_sink.send_reliable(b"ping".to_vec()).await.unwrap();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
drop(client_ep);
|
||||
}
|
||||
}
|
||||
@@ -130,10 +130,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn tokens_do_not_exceed_burst() {
|
||||
let mut tb = TokenBucket::new(1_000_000, 1_000);
|
||||
// Use a low rate so refill between consecutive calls is negligible
|
||||
let mut tb = TokenBucket::new(100, 1_000);
|
||||
// Wait to accumulate — but should cap at burst
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
assert!(tb.try_consume(1_000));
|
||||
// At 100 bytes/sec, the few μs between calls add ~0 tokens
|
||||
assert!(!tb.try_consume(1));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv4Addr;
|
||||
@@ -8,7 +7,6 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{info, error, warn};
|
||||
|
||||
use crate::codec::{Frame, FrameCodec, PacketType};
|
||||
@@ -17,6 +15,8 @@ use crate::mtu::{MtuConfig, TunnelOverhead};
|
||||
use crate::network::IpPool;
|
||||
use crate::ratelimit::TokenBucket;
|
||||
use crate::transport;
|
||||
use crate::transport_trait::{self, TransportSink, TransportStream};
|
||||
use crate::quic_transport;
|
||||
|
||||
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
|
||||
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
|
||||
@@ -39,6 +39,12 @@ pub struct ServerConfig {
|
||||
pub default_rate_limit_bytes_per_sec: Option<u64>,
|
||||
/// Default burst size for new clients (bytes). None = unlimited.
|
||||
pub default_burst_bytes: Option<u64>,
|
||||
/// Transport mode: "websocket" (default), "quic", or "both".
|
||||
pub transport_mode: Option<String>,
|
||||
/// QUIC listen address (host:port). Defaults to listen_addr.
|
||||
pub quic_listen_addr: Option<String>,
|
||||
/// QUIC idle timeout in seconds (default: 30).
|
||||
pub quic_idle_timeout_secs: Option<u64>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -135,14 +141,58 @@ impl VpnServer {
|
||||
self.state = Some(state.clone());
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
|
||||
let listen_addr = config.listen_addr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
info!("VPN server started");
|
||||
match transport_mode {
|
||||
"quic" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
"both" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
let state2 = state.clone();
|
||||
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
|
||||
// Store second shutdown sender so both listeners stop
|
||||
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
|
||||
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
|
||||
self.shutdown_tx = Some(combined_tx);
|
||||
|
||||
// Forward combined shutdown to both listeners
|
||||
tokio::spawn(async move {
|
||||
combined_rx.recv().await;
|
||||
let _ = shutdown_tx_orig.send(()).await;
|
||||
let _ = shutdown_tx2.send(()).await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("WebSocket listener error: {}", e);
|
||||
}
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
// "websocket" (default)
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!("VPN server started (transport: {})", transport_mode);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -239,7 +289,9 @@ impl VpnServer {
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_listener(
|
||||
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
|
||||
/// to the transport-agnostic `handle_client_connection`.
|
||||
async fn run_ws_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
shutdown_rx: &mut mpsc::Receiver<()>,
|
||||
@@ -255,8 +307,20 @@ async fn run_listener(
|
||||
info!("New connection from {}", addr);
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_client_connection(state, stream).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
match transport::accept_connection(stream).await {
|
||||
Ok(ws) => {
|
||||
let (sink, stream) = transport_trait::split_ws(ws);
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("WebSocket upgrade failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -275,13 +339,95 @@ async fn run_listener(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic
|
||||
/// `handle_client_connection`.
|
||||
async fn run_quic_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
idle_timeout_secs: u64,
|
||||
shutdown_rx: &mut mpsc::Receiver<()>,
|
||||
) -> Result<()> {
|
||||
// Generate or use configured TLS certificate for QUIC
|
||||
let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) =
|
||||
(&state.config.tls_cert, &state.config.tls_key)
|
||||
{
|
||||
// Parse PEM certificates
|
||||
let certs: Vec<rustls_pki_types::CertificateDer<'static>> =
|
||||
rustls_pemfile::certs(&mut cert_pem.as_bytes())
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
|
||||
.ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
|
||||
(certs, key)
|
||||
} else {
|
||||
// Generate self-signed certificate
|
||||
let (certs, key) = quic_transport::generate_self_signed_cert()?;
|
||||
info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0]));
|
||||
(certs, key)
|
||||
};
|
||||
|
||||
let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig {
|
||||
listen_addr,
|
||||
cert_chain,
|
||||
private_key,
|
||||
idle_timeout_secs,
|
||||
})?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
incoming = endpoint.accept() => {
|
||||
match incoming {
|
||||
Some(incoming) => {
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
match incoming.await {
|
||||
Ok(conn) => {
|
||||
let remote = conn.remote_address();
|
||||
info!("New QUIC connection from {}", remote);
|
||||
match quic_transport::accept_quic_connection(conn).await {
|
||||
Ok((sink, stream)) => {
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
).await {
|
||||
warn!("QUIC client error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("QUIC stream accept failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("QUIC handshake failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
None => {
|
||||
info!("QUIC endpoint closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("QUIC shutdown signal received");
|
||||
endpoint.close(0u32.into(), b"shutdown");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transport-agnostic client handler. Performs the Noise NK handshake, registers
|
||||
/// the client, and runs the main packet forwarding loop.
|
||||
async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
stream: tokio::net::TcpStream,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
) -> Result<()> {
|
||||
let ws = transport::accept_connection(stream).await?;
|
||||
let (mut ws_sink, mut ws_stream) = ws.split();
|
||||
|
||||
let client_id = uuid_v4();
|
||||
|
||||
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
|
||||
@@ -295,9 +441,9 @@ async fn handle_client_connection(
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
// Receive handshake init
|
||||
let init_msg = match ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(data))) => data.to_vec(),
|
||||
_ => anyhow::bail!("Expected handshake init message"),
|
||||
let init_msg = match stream.recv_reliable().await? {
|
||||
Some(data) => data,
|
||||
None => anyhow::bail!("Connection closed before handshake"),
|
||||
};
|
||||
|
||||
let mut frame_buf = BytesMut::from(&init_msg[..]);
|
||||
@@ -318,7 +464,7 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
let mut noise_transport = responder.into_transport_mode()?;
|
||||
|
||||
@@ -369,7 +515,7 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
info!("Client {} connected with IP {}", client_id, assigned_ip);
|
||||
|
||||
@@ -378,11 +524,11 @@ async fn handle_client_connection(
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = ws_stream.next() => {
|
||||
msg = stream.recv_reliable() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
Ok(Some(data)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
let mut frame_buf = BytesMut::from(&data[..][..]);
|
||||
let mut frame_buf = BytesMut::from(&data[..]);
|
||||
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
|
||||
Ok(Some(frame)) => match frame.packet_type {
|
||||
PacketType::IpPacket => {
|
||||
@@ -432,7 +578,7 @@ async fn handle_client_connection(
|
||||
};
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
|
||||
ws_sink.send(Message::Binary(frame_bytes.to_vec().into())).await?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
let mut stats = state.stats.write().await;
|
||||
stats.keepalives_received += 1;
|
||||
@@ -463,20 +609,12 @@ async fn handle_client_connection(
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
Ok(None) => {
|
||||
info!("Client {} connection closed", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
ws_sink.send(Message::Pong(data)).await?;
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
last_activity = tokio::time::Instant::now();
|
||||
continue;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("WebSocket error from {}: {}", client_id, e);
|
||||
Err(e) => {
|
||||
warn!("Transport error from {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
116
rust/src/transport_trait.rs
Normal file
116
rust/src/transport_trait.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
use crate::transport::WsStream;
|
||||
|
||||
// ============================================================================
|
||||
// Transport trait abstraction
|
||||
// ============================================================================
|
||||
|
||||
/// Outbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportSink: Send + 'static {
|
||||
/// Send a framed binary message on the reliable channel.
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Send a datagram (unreliable, best-effort).
|
||||
/// Falls back to reliable if the transport does not support datagrams.
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()>;
|
||||
|
||||
/// Gracefully close the transport.
|
||||
async fn close(&mut self) -> Result<()>;
|
||||
}
|
||||
|
||||
/// Inbound half of a VPN transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportStream: Send + 'static {
|
||||
/// Receive the next reliable binary message. Returns `None` on close.
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Receive the next datagram. Returns `None` if datagrams are unsupported
|
||||
/// or the connection is closed.
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
/// Whether this transport supports unreliable datagrams.
|
||||
fn supports_datagrams(&self) -> bool;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket implementation
|
||||
// ============================================================================
|
||||
|
||||
/// WebSocket transport sink (wraps the write half of a split WsStream).
|
||||
pub struct WsTransportSink {
|
||||
inner: futures_util::stream::SplitSink<WsStream, Message>,
|
||||
}
|
||||
|
||||
impl WsTransportSink {
|
||||
pub fn new(inner: futures_util::stream::SplitSink<WsStream, Message>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSink for WsTransportSink {
|
||||
async fn send_reliable(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
self.inner.send(Message::Binary(data.into())).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_datagram(&mut self, data: Vec<u8>) -> Result<()> {
|
||||
// WebSocket has no datagram support — fall back to reliable.
|
||||
self.send_reliable(data).await
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
self.inner.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket transport stream (wraps the read half of a split WsStream).
|
||||
pub struct WsTransportStream {
|
||||
inner: futures_util::stream::SplitStream<WsStream>,
|
||||
}
|
||||
|
||||
impl WsTransportStream {
|
||||
pub fn new(inner: futures_util::stream::SplitStream<WsStream>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportStream for WsTransportStream {
|
||||
async fn recv_reliable(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
loop {
|
||||
match self.inner.next().await {
|
||||
Some(Ok(Message::Binary(data))) => return Ok(Some(data.to_vec())),
|
||||
Some(Ok(Message::Close(_))) | None => return Ok(None),
|
||||
Some(Ok(Message::Ping(_))) => {
|
||||
// Ping handling is done at the tungstenite layer automatically
|
||||
// when the sink side is alive. Just skip here.
|
||||
continue;
|
||||
}
|
||||
Some(Ok(_)) => continue,
|
||||
Some(Err(e)) => return Err(anyhow::anyhow!("WebSocket error: {}", e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_datagram(&mut self) -> Result<Option<Vec<u8>>> {
|
||||
// WebSocket does not support datagrams.
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn supports_datagrams(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a WebSocket stream into transport sink and stream halves.
|
||||
pub fn split_ws(ws: WsStream) -> (WsTransportSink, WsTransportStream) {
|
||||
let (sink, stream) = ws.split();
|
||||
(WsTransportSink::new(sink), WsTransportStream::new(stream))
|
||||
}
|
||||
Reference in New Issue
Block a user