701 lines
26 KiB
Rust
701 lines
26 KiB
Rust
//! QUIC connection handling.
|
|
//!
|
|
//! Manages QUIC endpoints (via quinn), accepts connections, and either:
|
|
//! - Forwards streams bidirectionally to TCP backends (QUIC termination)
|
|
//! - Dispatches to H3ProxyService for HTTP/3 handling (Phase 5)
|
|
//!
|
|
//! When `proxy_ips` is configured, a UDP relay layer intercepts PROXY protocol v2
|
|
//! headers before they reach quinn, extracting real client IPs for attribution.
|
|
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::UdpSocket;
|
|
use tokio::task::JoinHandle;
|
|
|
|
use arc_swap::ArcSwap;
|
|
use dashmap::DashMap;
|
|
use quinn::{Endpoint, ServerConfig as QuinnServerConfig};
|
|
use rustls::ServerConfig as RustlsServerConfig;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{debug, info, warn};
|
|
|
|
use rustproxy_config::{RouteConfig, TransportProtocol};
|
|
use rustproxy_metrics::MetricsCollector;
|
|
use rustproxy_routing::{MatchContext, RouteManager};
|
|
|
|
use rustproxy_http::h3_service::H3ProxyService;
|
|
|
|
use crate::connection_tracker::ConnectionTracker;
|
|
|
|
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
|
///
|
|
/// The TLS config must have ALPN protocols set (e.g., `h3` for HTTP/3).
|
|
pub fn create_quic_endpoint(
|
|
port: u16,
|
|
tls_config: Arc<RustlsServerConfig>,
|
|
) -> anyhow::Result<Endpoint> {
|
|
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
|
|
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
|
|
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
|
|
|
|
let socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
|
|
let endpoint = Endpoint::new(
|
|
quinn::EndpointConfig::default(),
|
|
Some(server_config),
|
|
socket,
|
|
quinn::default_runtime()
|
|
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
|
)?;
|
|
|
|
info!("QUIC endpoint listening on port {}", port);
|
|
Ok(endpoint)
|
|
}
|
|
|
|
// ===== PROXY protocol relay for QUIC =====
|
|
|
|
/// Result of creating a QUIC endpoint with a PROXY protocol relay layer.
|
|
pub struct QuicProxyRelay {
|
|
/// The quinn endpoint (bound to 127.0.0.1:ephemeral).
|
|
pub endpoint: Endpoint,
|
|
/// The relay recv loop task handle.
|
|
pub relay_task: JoinHandle<()>,
|
|
/// Maps relay socket local addr → real client SocketAddr (from PROXY v2).
|
|
/// Consulted by `quic_accept_loop` to resolve real client IPs.
|
|
pub real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
|
}
|
|
|
|
/// A single relay session for forwarding datagrams between an external source
|
|
/// and the internal quinn endpoint.
|
|
struct RelaySession {
|
|
socket: Arc<UdpSocket>,
|
|
last_activity: AtomicU64,
|
|
return_task: JoinHandle<()>,
|
|
cancel: CancellationToken,
|
|
}
|
|
|
|
/// Create a QUIC endpoint with a PROXY protocol v2 relay layer.
|
|
///
|
|
/// Instead of giving the external socket to quinn, we:
|
|
/// 1. Bind a raw UDP socket on 0.0.0.0:port (external)
|
|
/// 2. Bind quinn on 127.0.0.1:0 (internal, ephemeral)
|
|
/// 3. Run a relay loop that filters PROXY v2 headers and forwards datagrams
|
|
///
|
|
/// Only used when `proxy_ips` is non-empty.
|
|
pub fn create_quic_endpoint_with_proxy_relay(
|
|
port: u16,
|
|
tls_config: Arc<RustlsServerConfig>,
|
|
proxy_ips: Arc<Vec<IpAddr>>,
|
|
cancel: CancellationToken,
|
|
) -> anyhow::Result<QuicProxyRelay> {
|
|
// Bind external socket on the real port
|
|
let external_socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
|
|
external_socket.set_nonblocking(true)?;
|
|
let external_socket = Arc::new(
|
|
UdpSocket::from_std(external_socket)
|
|
.map_err(|e| anyhow::anyhow!("Failed to wrap external socket: {}", e))?,
|
|
);
|
|
|
|
// Bind quinn on localhost ephemeral port
|
|
let internal_socket = std::net::UdpSocket::bind("127.0.0.1:0")?;
|
|
let quinn_internal_addr = internal_socket.local_addr()?;
|
|
|
|
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
|
|
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
|
|
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
|
|
|
|
let endpoint = Endpoint::new(
|
|
quinn::EndpointConfig::default(),
|
|
Some(server_config),
|
|
internal_socket,
|
|
quinn::default_runtime()
|
|
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
|
)?;
|
|
|
|
let real_client_map = Arc::new(DashMap::new());
|
|
|
|
let relay_task = tokio::spawn(quic_proxy_relay_loop(
|
|
external_socket,
|
|
quinn_internal_addr,
|
|
proxy_ips,
|
|
Arc::clone(&real_client_map),
|
|
cancel,
|
|
));
|
|
|
|
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
|
|
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
|
|
}
|
|
|
|
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
|
|
/// headers from trusted proxy IPs, and forwards everything else to quinn via
|
|
/// per-session relay sockets.
|
|
async fn quic_proxy_relay_loop(
|
|
external_socket: Arc<UdpSocket>,
|
|
quinn_internal_addr: SocketAddr,
|
|
proxy_ips: Arc<Vec<IpAddr>>,
|
|
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
|
cancel: CancellationToken,
|
|
) {
|
|
// Maps external source addr → real client addr (from PROXY v2 headers)
|
|
let proxy_addr_map: DashMap<SocketAddr, SocketAddr> = DashMap::new();
|
|
// Maps external source addr → relay session
|
|
let relay_sessions: DashMap<SocketAddr, Arc<RelaySession>> = DashMap::new();
|
|
let epoch = Instant::now();
|
|
let mut buf = vec![0u8; 65535];
|
|
|
|
// Inline cleanup: periodically scan relay_sessions for stale entries
|
|
let mut last_cleanup = Instant::now();
|
|
let cleanup_interval = std::time::Duration::from_secs(30);
|
|
let session_timeout_ms: u64 = 120_000;
|
|
|
|
loop {
|
|
let (len, src_addr) = tokio::select! {
|
|
_ = cancel.cancelled() => {
|
|
debug!("QUIC proxy relay loop cancelled");
|
|
break;
|
|
}
|
|
result = external_socket.recv_from(&mut buf) => {
|
|
match result {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
warn!("QUIC proxy relay recv error: {}", e);
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
let datagram = &buf[..len];
|
|
|
|
// PROXY v2 handling: only on first datagram from a trusted proxy IP
|
|
// (before a relay session exists for this source)
|
|
if proxy_ips.contains(&src_addr.ip()) && relay_sessions.get(&src_addr).is_none() {
|
|
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
|
match crate::proxy_protocol::parse_v2(datagram) {
|
|
Ok((header, _consumed)) => {
|
|
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
|
|
proxy_addr_map.insert(src_addr, header.source_addr);
|
|
continue; // consume the PROXY v2 datagram
|
|
}
|
|
Err(e) => {
|
|
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Determine real client address
|
|
let real_client = proxy_addr_map.get(&src_addr)
|
|
.map(|r| *r)
|
|
.unwrap_or(src_addr);
|
|
|
|
// Get or create relay session for this external source
|
|
let session = match relay_sessions.get(&src_addr) {
|
|
Some(s) => {
|
|
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
|
Arc::clone(s.value())
|
|
}
|
|
None => {
|
|
// Create new relay socket connected to quinn's internal address
|
|
let relay_socket = match UdpSocket::bind("127.0.0.1:0").await {
|
|
Ok(s) => s,
|
|
Err(e) => {
|
|
warn!("QUIC relay: failed to bind relay socket: {}", e);
|
|
continue;
|
|
}
|
|
};
|
|
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
|
|
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
|
|
continue;
|
|
}
|
|
let relay_local_addr = match relay_socket.local_addr() {
|
|
Ok(a) => a,
|
|
Err(e) => {
|
|
warn!("QUIC relay: failed to get relay socket local addr: {}", e);
|
|
continue;
|
|
}
|
|
};
|
|
let relay_socket = Arc::new(relay_socket);
|
|
|
|
// Store the real client mapping for the QUIC accept loop
|
|
real_client_map.insert(relay_local_addr, real_client);
|
|
|
|
// Spawn return-path relay: quinn -> external socket -> original source
|
|
let session_cancel = cancel.child_token();
|
|
let return_task = tokio::spawn(relay_return_path(
|
|
Arc::clone(&relay_socket),
|
|
Arc::clone(&external_socket),
|
|
src_addr,
|
|
session_cancel.child_token(),
|
|
));
|
|
|
|
let session = Arc::new(RelaySession {
|
|
socket: relay_socket,
|
|
last_activity: AtomicU64::new(epoch.elapsed().as_millis() as u64),
|
|
return_task,
|
|
cancel: session_cancel,
|
|
});
|
|
|
|
relay_sessions.insert(src_addr, Arc::clone(&session));
|
|
debug!("QUIC relay: new session for {} (relay {}), real client {}",
|
|
src_addr, relay_local_addr, real_client);
|
|
|
|
session
|
|
}
|
|
};
|
|
|
|
// Forward datagram to quinn via the relay socket
|
|
if let Err(e) = session.socket.send(datagram).await {
|
|
debug!("QUIC relay: forward error to quinn for {}: {}", src_addr, e);
|
|
}
|
|
|
|
// Periodic cleanup of stale relay sessions
|
|
if last_cleanup.elapsed() >= cleanup_interval {
|
|
last_cleanup = Instant::now();
|
|
let now_ms = epoch.elapsed().as_millis() as u64;
|
|
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
|
|
.filter(|entry| {
|
|
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
|
age > session_timeout_ms
|
|
})
|
|
.map(|entry| *entry.key())
|
|
.collect();
|
|
|
|
for key in stale_keys {
|
|
if let Some((_, session)) = relay_sessions.remove(&key) {
|
|
session.cancel.cancel();
|
|
session.return_task.abort();
|
|
// Clean up real_client_map entry
|
|
if let Ok(addr) = session.socket.local_addr() {
|
|
real_client_map.remove(&addr);
|
|
}
|
|
proxy_addr_map.remove(&key);
|
|
debug!("QUIC relay: cleaned up stale session for {}", key);
|
|
}
|
|
}
|
|
|
|
// Also clean orphaned proxy_addr_map entries (PROXY header received
|
|
// but no relay session was ever created, e.g. client never sent data)
|
|
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
|
|
.filter(|entry| relay_sessions.get(entry.key()).is_none())
|
|
.map(|entry| *entry.key())
|
|
.collect();
|
|
for key in orphaned {
|
|
proxy_addr_map.remove(&key);
|
|
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Shutdown: cancel all relay sessions
|
|
for entry in relay_sessions.iter() {
|
|
entry.value().cancel.cancel();
|
|
entry.value().return_task.abort();
|
|
}
|
|
}
|
|
|
|
/// Return-path relay: receives datagrams from quinn (via the relay socket)
|
|
/// and forwards them back to the external client through the external socket.
|
|
async fn relay_return_path(
|
|
relay_socket: Arc<UdpSocket>,
|
|
external_socket: Arc<UdpSocket>,
|
|
external_src_addr: SocketAddr,
|
|
cancel: CancellationToken,
|
|
) {
|
|
let mut buf = vec![0u8; 65535];
|
|
loop {
|
|
let len = tokio::select! {
|
|
_ = cancel.cancelled() => break,
|
|
result = relay_socket.recv(&mut buf) => {
|
|
match result {
|
|
Ok(len) => len,
|
|
Err(e) => {
|
|
debug!("QUIC relay return recv error for {}: {}", external_src_addr, e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
|
|
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// ===== QUIC accept loop =====
|
|
|
|
/// Run the QUIC accept loop for a single endpoint.
|
|
///
|
|
/// Accepts incoming QUIC connections and spawns a task per connection.
|
|
/// When `real_client_map` is provided, it is consulted to resolve real client
|
|
/// IPs from PROXY protocol v2 headers (relay socket addr → real client addr).
|
|
pub async fn quic_accept_loop(
|
|
endpoint: Endpoint,
|
|
port: u16,
|
|
route_manager: Arc<ArcSwap<RouteManager>>,
|
|
metrics: Arc<MetricsCollector>,
|
|
conn_tracker: Arc<ConnectionTracker>,
|
|
cancel: CancellationToken,
|
|
h3_service: Option<Arc<H3ProxyService>>,
|
|
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
|
|
) {
|
|
loop {
|
|
let incoming = tokio::select! {
|
|
_ = cancel.cancelled() => {
|
|
debug!("QUIC accept loop on port {} cancelled", port);
|
|
break;
|
|
}
|
|
incoming = endpoint.accept() => {
|
|
match incoming {
|
|
Some(conn) => conn,
|
|
None => {
|
|
debug!("QUIC endpoint on port {} closed", port);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
let remote_addr = incoming.remote_address();
|
|
|
|
// Resolve real client IP from PROXY protocol map if available
|
|
let real_addr = real_client_map.as_ref()
|
|
.and_then(|map| map.get(&remote_addr).map(|r| *r))
|
|
.unwrap_or(remote_addr);
|
|
let ip = real_addr.ip();
|
|
|
|
// Per-IP rate limiting
|
|
if !conn_tracker.try_accept(&ip) {
|
|
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
|
|
// Drop `incoming` to refuse the connection
|
|
continue;
|
|
}
|
|
|
|
// Route matching (port + client IP, no domain yet — QUIC Initial is encrypted)
|
|
let rm = route_manager.load();
|
|
let ip_str = ip.to_string();
|
|
let ctx = MatchContext {
|
|
port,
|
|
domain: None,
|
|
path: None,
|
|
client_ip: Some(&ip_str),
|
|
tls_version: None,
|
|
headers: None,
|
|
is_tls: true,
|
|
protocol: Some("quic"),
|
|
transport: Some(TransportProtocol::Udp),
|
|
};
|
|
|
|
let route = match rm.find_route(&ctx) {
|
|
Some(m) => m.route.clone(),
|
|
None => {
|
|
debug!("No QUIC route matched for port {} from {}", port, real_addr);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
conn_tracker.connection_opened(&ip);
|
|
let route_id = route.name.clone().or(route.id.clone());
|
|
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
|
|
|
|
let metrics = Arc::clone(&metrics);
|
|
let conn_tracker = Arc::clone(&conn_tracker);
|
|
let cancel = cancel.child_token();
|
|
let h3_svc = h3_service.clone();
|
|
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
|
|
|
|
tokio::spawn(async move {
|
|
// RAII guard: ensures metrics/tracker cleanup even on panic
|
|
struct QuicConnGuard {
|
|
tracker: Arc<ConnectionTracker>,
|
|
metrics: Arc<MetricsCollector>,
|
|
ip: std::net::IpAddr,
|
|
ip_str: String,
|
|
route_id: Option<String>,
|
|
}
|
|
impl Drop for QuicConnGuard {
|
|
fn drop(&mut self) {
|
|
self.tracker.connection_closed(&self.ip);
|
|
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
|
}
|
|
}
|
|
let _guard = QuicConnGuard {
|
|
tracker: conn_tracker,
|
|
metrics: Arc::clone(&metrics),
|
|
ip,
|
|
ip_str,
|
|
route_id,
|
|
};
|
|
|
|
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &cancel, h3_svc, real_client_addr).await {
|
|
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
|
|
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
|
|
}
|
|
});
|
|
}
|
|
|
|
// Graceful shutdown: close endpoint and wait for in-flight connections
|
|
endpoint.close(quinn::VarInt::from_u32(0), b"server shutting down");
|
|
endpoint.wait_idle().await;
|
|
info!("QUIC endpoint on port {} shut down", port);
|
|
}
|
|
|
|
/// Handle a single accepted QUIC connection.
|
|
async fn handle_quic_connection(
|
|
incoming: quinn::Incoming,
|
|
route: RouteConfig,
|
|
port: u16,
|
|
metrics: Arc<MetricsCollector>,
|
|
cancel: &CancellationToken,
|
|
h3_service: Option<Arc<H3ProxyService>>,
|
|
real_client_addr: Option<SocketAddr>,
|
|
) -> anyhow::Result<()> {
|
|
let connection = incoming.await?;
|
|
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
|
debug!("QUIC connection established from {}", effective_addr);
|
|
|
|
// Check if this route has HTTP/3 enabled
|
|
let enable_http3 = route.action.udp.as_ref()
|
|
.and_then(|u| u.quic.as_ref())
|
|
.and_then(|q| q.enable_http3)
|
|
.unwrap_or(false);
|
|
|
|
if enable_http3 {
|
|
if let Some(ref h3_svc) = h3_service {
|
|
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
|
|
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
|
|
} else {
|
|
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
|
|
// Keep connection alive until cancelled
|
|
tokio::select! {
|
|
_ = cancel.cancelled() => {}
|
|
reason = connection.closed() => {
|
|
debug!("HTTP/3 connection closed (no service): {}", reason);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
} else {
|
|
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
|
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
|
|
}
|
|
}
|
|
|
|
/// Forward QUIC streams bidirectionally to a TCP backend.
|
|
///
|
|
/// For each accepted bidirectional QUIC stream, connects to the backend
|
|
/// via TCP and forwards data in both directions. Quinn's RecvStream/SendStream
|
|
/// implement AsyncRead/AsyncWrite, enabling reuse of existing forwarder patterns.
|
|
async fn handle_quic_stream_forwarding(
|
|
connection: quinn::Connection,
|
|
route: RouteConfig,
|
|
port: u16,
|
|
metrics: Arc<MetricsCollector>,
|
|
cancel: &CancellationToken,
|
|
real_client_addr: Option<SocketAddr>,
|
|
) -> anyhow::Result<()> {
|
|
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
|
let route_id = route.name.as_deref().or(route.id.as_deref());
|
|
let metrics_arc = metrics;
|
|
|
|
// Resolve backend target
|
|
let target = route.action.targets.as_ref()
|
|
.and_then(|t| t.first())
|
|
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
|
|
let backend_host = target.host.first();
|
|
let backend_port = target.port.resolve(port);
|
|
let backend_addr = format!("{}:{}", backend_host, backend_port);
|
|
|
|
loop {
|
|
let (send_stream, recv_stream) = tokio::select! {
|
|
_ = cancel.cancelled() => break,
|
|
result = connection.accept_bi() => {
|
|
match result {
|
|
Ok(streams) => streams,
|
|
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
|
|
Err(quinn::ConnectionError::LocallyClosed) => break,
|
|
Err(e) => {
|
|
debug!("QUIC stream accept error from {}: {}", effective_addr, e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
let backend_addr = backend_addr.clone();
|
|
let ip_str = effective_addr.ip().to_string();
|
|
let stream_metrics = Arc::clone(&metrics_arc);
|
|
let stream_route_id = route_id.map(|s| s.to_string());
|
|
let stream_cancel = cancel.child_token();
|
|
|
|
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
|
tokio::spawn(async move {
|
|
match forward_quic_stream_to_tcp(
|
|
send_stream,
|
|
recv_stream,
|
|
&backend_addr,
|
|
stream_cancel,
|
|
).await {
|
|
Ok((bytes_in, bytes_out)) => {
|
|
stream_metrics.record_bytes(
|
|
bytes_in, bytes_out,
|
|
stream_route_id.as_deref(),
|
|
Some(&ip_str),
|
|
);
|
|
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
|
|
}
|
|
Err(e) => {
|
|
debug!("QUIC stream forwarding error: {}", e);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Forward a single QUIC bidirectional stream to a TCP backend connection.
|
|
///
|
|
/// Includes inactivity timeout (60s), max lifetime (10min), and cancellation
|
|
/// to prevent leaked stream tasks when the parent connection closes.
|
|
async fn forward_quic_stream_to_tcp(
|
|
mut quic_send: quinn::SendStream,
|
|
mut quic_recv: quinn::RecvStream,
|
|
backend_addr: &str,
|
|
cancel: CancellationToken,
|
|
) -> anyhow::Result<(u64, u64)> {
|
|
let inactivity_timeout = std::time::Duration::from_secs(60);
|
|
let max_lifetime = std::time::Duration::from_secs(600);
|
|
|
|
// Connect to backend TCP
|
|
let tcp_stream = tokio::net::TcpStream::connect(backend_addr).await?;
|
|
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
|
|
|
|
let last_activity = Arc::new(AtomicU64::new(0));
|
|
let start = std::time::Instant::now();
|
|
let conn_cancel = CancellationToken::new();
|
|
|
|
let la1 = Arc::clone(&last_activity);
|
|
let cc1 = conn_cancel.clone();
|
|
let c2b = tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 65536];
|
|
let mut total = 0u64;
|
|
loop {
|
|
let n = tokio::select! {
|
|
result = quic_recv.read(&mut buf) => match result {
|
|
Ok(Some(0)) | Ok(None) | Err(_) => break,
|
|
Ok(Some(n)) => n,
|
|
},
|
|
_ = cc1.cancelled() => break,
|
|
};
|
|
if tcp_write.write_all(&buf[..n]).await.is_err() {
|
|
break;
|
|
}
|
|
total += n as u64;
|
|
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
|
}
|
|
let _ = tokio::time::timeout(
|
|
std::time::Duration::from_secs(2),
|
|
tcp_write.shutdown(),
|
|
).await;
|
|
total
|
|
});
|
|
|
|
let la2 = Arc::clone(&last_activity);
|
|
let cc2 = conn_cancel.clone();
|
|
let b2c = tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 65536];
|
|
let mut total = 0u64;
|
|
loop {
|
|
let n = tokio::select! {
|
|
result = tcp_read.read(&mut buf) => match result {
|
|
Ok(0) | Err(_) => break,
|
|
Ok(n) => n,
|
|
},
|
|
_ = cc2.cancelled() => break,
|
|
};
|
|
// quinn SendStream implements AsyncWrite
|
|
if quic_send.write_all(&buf[..n]).await.is_err() {
|
|
break;
|
|
}
|
|
total += n as u64;
|
|
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
|
}
|
|
let _ = quic_send.finish();
|
|
total
|
|
});
|
|
|
|
// Watchdog: inactivity, max lifetime, and cancellation
|
|
let la_watch = Arc::clone(&last_activity);
|
|
let c2b_abort = c2b.abort_handle();
|
|
let b2c_abort = b2c.abort_handle();
|
|
tokio::spawn(async move {
|
|
let check_interval = std::time::Duration::from_secs(5);
|
|
let mut last_seen = 0u64;
|
|
loop {
|
|
tokio::select! {
|
|
_ = cancel.cancelled() => break,
|
|
_ = tokio::time::sleep(check_interval) => {
|
|
if start.elapsed() >= max_lifetime {
|
|
debug!("QUIC stream exceeded max lifetime, closing");
|
|
break;
|
|
}
|
|
let current = la_watch.load(Ordering::Relaxed);
|
|
if current == last_seen {
|
|
let elapsed = start.elapsed().as_millis() as u64 - current;
|
|
if elapsed >= inactivity_timeout.as_millis() as u64 {
|
|
debug!("QUIC stream inactive for {}ms, closing", elapsed);
|
|
break;
|
|
}
|
|
}
|
|
last_seen = current;
|
|
}
|
|
}
|
|
}
|
|
conn_cancel.cancel();
|
|
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
|
|
c2b_abort.abort();
|
|
b2c_abort.abort();
|
|
});
|
|
|
|
let bytes_in = c2b.await.unwrap_or(0);
|
|
let bytes_out = b2c.await.unwrap_or(0);
|
|
|
|
Ok((bytes_in, bytes_out))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a valid rustls config with an `h3` ALPN yields a working
    /// QUIC endpoint on an OS-assigned port.
    #[tokio::test]
    async fn test_quic_endpoint_requires_tls_config() {
        // rustls needs a process-wide crypto provider; ignore "already installed".
        let _ = rustls::crypto::ring::default_provider().install_default();

        // One self-signed certificate for "localhost", signed by its own key pair.
        let generated = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        let cert = generated.cert.der().clone();
        let key = rustls::pki_types::PrivateKeyDer::try_from(generated.key_pair.serialize_der())
            .unwrap();

        let mut server_tls = RustlsServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert.into()], key)
            .unwrap();
        server_tls.alpn_protocols = vec![b"h3".to_vec()];

        // Port 0 lets the OS pick a free port, so parallel tests never collide.
        let endpoint = create_quic_endpoint(0, Arc::new(server_tls));
        assert!(endpoint.is_ok(), "QUIC endpoint creation failed: {:?}", endpoint.err());
    }
}
|