310 lines
11 KiB
Rust
310 lines
11 KiB
Rust
|
|
//! QUIC connection handling.
//!
//! Manages QUIC endpoints (via quinn), accepts connections, and either:
//! - Forwards streams bidirectionally to TCP backends (QUIC termination)
//! - Dispatches to H3ProxyService for HTTP/3 handling (Phase 5; the handler in
//!   this file is currently a placeholder that only keeps the connection open)
|
||
|
|
|
||
|
|
use std::net::SocketAddr;
|
||
|
|
use std::sync::Arc;
|
||
|
|
|
||
|
|
use tokio::io::AsyncWriteExt;
|
||
|
|
|
||
|
|
use arc_swap::ArcSwap;
|
||
|
|
use quinn::{Endpoint, ServerConfig as QuinnServerConfig};
|
||
|
|
use rustls::ServerConfig as RustlsServerConfig;
|
||
|
|
use tokio_util::sync::CancellationToken;
|
||
|
|
use tracing::{debug, info, warn};
|
||
|
|
|
||
|
|
use rustproxy_config::{RouteConfig, TransportProtocol};
|
||
|
|
use rustproxy_metrics::MetricsCollector;
|
||
|
|
use rustproxy_routing::{MatchContext, RouteManager};
|
||
|
|
|
||
|
|
use crate::connection_tracker::ConnectionTracker;
|
||
|
|
use crate::forwarder::ForwardMetricsCtx;
|
||
|
|
|
||
|
|
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
||
|
|
///
|
||
|
|
/// The TLS config must have ALPN protocols set (e.g., `h3` for HTTP/3).
|
||
|
|
pub fn create_quic_endpoint(
|
||
|
|
port: u16,
|
||
|
|
tls_config: Arc<RustlsServerConfig>,
|
||
|
|
) -> anyhow::Result<Endpoint> {
|
||
|
|
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
|
||
|
|
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
|
||
|
|
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
|
||
|
|
|
||
|
|
let socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
|
||
|
|
let endpoint = Endpoint::new(
|
||
|
|
quinn::EndpointConfig::default(),
|
||
|
|
Some(server_config),
|
||
|
|
socket,
|
||
|
|
quinn::default_runtime()
|
||
|
|
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||
|
|
)?;
|
||
|
|
|
||
|
|
info!("QUIC endpoint listening on port {}", port);
|
||
|
|
Ok(endpoint)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Run the QUIC accept loop for a single endpoint.
|
||
|
|
///
|
||
|
|
/// Accepts incoming QUIC connections and spawns a task per connection.
|
||
|
|
pub async fn quic_accept_loop(
|
||
|
|
endpoint: Endpoint,
|
||
|
|
port: u16,
|
||
|
|
route_manager: Arc<ArcSwap<RouteManager>>,
|
||
|
|
metrics: Arc<MetricsCollector>,
|
||
|
|
conn_tracker: Arc<ConnectionTracker>,
|
||
|
|
cancel: CancellationToken,
|
||
|
|
) {
|
||
|
|
loop {
|
||
|
|
let incoming = tokio::select! {
|
||
|
|
_ = cancel.cancelled() => {
|
||
|
|
debug!("QUIC accept loop on port {} cancelled", port);
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
incoming = endpoint.accept() => {
|
||
|
|
match incoming {
|
||
|
|
Some(conn) => conn,
|
||
|
|
None => {
|
||
|
|
debug!("QUIC endpoint on port {} closed", port);
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
let remote_addr = incoming.remote_address();
|
||
|
|
let ip = remote_addr.ip();
|
||
|
|
|
||
|
|
// Per-IP rate limiting
|
||
|
|
if !conn_tracker.try_accept(&ip) {
|
||
|
|
debug!("QUIC connection rejected from {} (rate limit)", remote_addr);
|
||
|
|
// Drop `incoming` to refuse the connection
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Route matching (port + client IP, no domain yet — QUIC Initial is encrypted)
|
||
|
|
let rm = route_manager.load();
|
||
|
|
let ip_str = ip.to_string();
|
||
|
|
let ctx = MatchContext {
|
||
|
|
port,
|
||
|
|
domain: None,
|
||
|
|
path: None,
|
||
|
|
client_ip: Some(&ip_str),
|
||
|
|
tls_version: None,
|
||
|
|
headers: None,
|
||
|
|
is_tls: true,
|
||
|
|
protocol: Some("quic"),
|
||
|
|
transport: Some(TransportProtocol::Udp),
|
||
|
|
};
|
||
|
|
|
||
|
|
let route = match rm.find_route(&ctx) {
|
||
|
|
Some(m) => m.route.clone(),
|
||
|
|
None => {
|
||
|
|
debug!("No QUIC route matched for port {} from {}", port, remote_addr);
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
conn_tracker.connection_opened(&ip);
|
||
|
|
let route_id = route.name.clone().or(route.id.clone());
|
||
|
|
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
|
||
|
|
|
||
|
|
let metrics = Arc::clone(&metrics);
|
||
|
|
let conn_tracker = Arc::clone(&conn_tracker);
|
||
|
|
let cancel = cancel.child_token();
|
||
|
|
|
||
|
|
tokio::spawn(async move {
|
||
|
|
match handle_quic_connection(incoming, route, port, &metrics, &cancel).await {
|
||
|
|
Ok(()) => debug!("QUIC connection from {} completed", remote_addr),
|
||
|
|
Err(e) => debug!("QUIC connection from {} error: {}", remote_addr, e),
|
||
|
|
}
|
||
|
|
|
||
|
|
// Cleanup
|
||
|
|
conn_tracker.connection_closed(&ip);
|
||
|
|
metrics.connection_closed(route_id.as_deref(), Some(&ip_str));
|
||
|
|
});
|
||
|
|
}
|
||
|
|
|
||
|
|
// Graceful shutdown: close endpoint and wait for in-flight connections
|
||
|
|
endpoint.close(quinn::VarInt::from_u32(0), b"server shutting down");
|
||
|
|
endpoint.wait_idle().await;
|
||
|
|
info!("QUIC endpoint on port {} shut down", port);
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Handle a single accepted QUIC connection.
|
||
|
|
async fn handle_quic_connection(
|
||
|
|
incoming: quinn::Incoming,
|
||
|
|
route: RouteConfig,
|
||
|
|
port: u16,
|
||
|
|
metrics: &MetricsCollector,
|
||
|
|
cancel: &CancellationToken,
|
||
|
|
) -> anyhow::Result<()> {
|
||
|
|
let connection = incoming.await?;
|
||
|
|
let remote_addr = connection.remote_address();
|
||
|
|
debug!("QUIC connection established from {}", remote_addr);
|
||
|
|
|
||
|
|
// Check if this route has HTTP/3 enabled
|
||
|
|
let enable_http3 = route.action.udp.as_ref()
|
||
|
|
.and_then(|u| u.quic.as_ref())
|
||
|
|
.and_then(|q| q.enable_http3)
|
||
|
|
.unwrap_or(false);
|
||
|
|
|
||
|
|
if enable_http3 {
|
||
|
|
// Phase 5: dispatch to H3ProxyService
|
||
|
|
// For now, log and accept streams for basic handling
|
||
|
|
debug!("HTTP/3 enabled for route {:?}, dispatching to H3 handler", route.name);
|
||
|
|
handle_h3_connection(connection, route, port, metrics, cancel).await
|
||
|
|
} else {
|
||
|
|
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
||
|
|
handle_quic_stream_forwarding(connection, route, port, metrics, cancel).await
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Forward QUIC streams bidirectionally to a TCP backend.
|
||
|
|
///
|
||
|
|
/// For each accepted bidirectional QUIC stream, connects to the backend
|
||
|
|
/// via TCP and forwards data in both directions. Quinn's RecvStream/SendStream
|
||
|
|
/// implement AsyncRead/AsyncWrite, enabling reuse of existing forwarder patterns.
|
||
|
|
async fn handle_quic_stream_forwarding(
|
||
|
|
connection: quinn::Connection,
|
||
|
|
route: RouteConfig,
|
||
|
|
port: u16,
|
||
|
|
_metrics: &MetricsCollector,
|
||
|
|
cancel: &CancellationToken,
|
||
|
|
) -> anyhow::Result<()> {
|
||
|
|
let remote_addr = connection.remote_address();
|
||
|
|
let route_id = route.name.as_deref().or(route.id.as_deref());
|
||
|
|
|
||
|
|
// Resolve backend target
|
||
|
|
let target = route.action.targets.as_ref()
|
||
|
|
.and_then(|t| t.first())
|
||
|
|
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
|
||
|
|
let backend_host = target.host.first();
|
||
|
|
let backend_port = target.port.resolve(port);
|
||
|
|
let backend_addr = format!("{}:{}", backend_host, backend_port);
|
||
|
|
|
||
|
|
loop {
|
||
|
|
let (send_stream, recv_stream) = tokio::select! {
|
||
|
|
_ = cancel.cancelled() => break,
|
||
|
|
result = connection.accept_bi() => {
|
||
|
|
match result {
|
||
|
|
Ok(streams) => streams,
|
||
|
|
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
|
||
|
|
Err(quinn::ConnectionError::LocallyClosed) => break,
|
||
|
|
Err(e) => {
|
||
|
|
debug!("QUIC stream accept error from {}: {}", remote_addr, e);
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
let backend_addr = backend_addr.clone();
|
||
|
|
let ip_str = remote_addr.ip().to_string();
|
||
|
|
let _fwd_ctx = ForwardMetricsCtx {
|
||
|
|
collector: Arc::new(MetricsCollector::new()), // TODO: share real metrics
|
||
|
|
route_id: route_id.map(|s| s.to_string()),
|
||
|
|
source_ip: Some(ip_str),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
||
|
|
tokio::spawn(async move {
|
||
|
|
match forward_quic_stream_to_tcp(
|
||
|
|
send_stream,
|
||
|
|
recv_stream,
|
||
|
|
&backend_addr,
|
||
|
|
).await {
|
||
|
|
Ok((bytes_in, bytes_out)) => {
|
||
|
|
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
debug!("QUIC stream forwarding error: {}", e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
});
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Forward a single QUIC bidirectional stream to a TCP backend connection.
|
||
|
|
async fn forward_quic_stream_to_tcp(
|
||
|
|
mut quic_send: quinn::SendStream,
|
||
|
|
mut quic_recv: quinn::RecvStream,
|
||
|
|
backend_addr: &str,
|
||
|
|
) -> anyhow::Result<(u64, u64)> {
|
||
|
|
// Connect to backend TCP
|
||
|
|
let tcp_stream = tokio::net::TcpStream::connect(backend_addr).await?;
|
||
|
|
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
|
||
|
|
|
||
|
|
// Bidirectional copy
|
||
|
|
let client_to_backend = tokio::io::copy(&mut quic_recv, &mut tcp_write);
|
||
|
|
let backend_to_client = tokio::io::copy(&mut tcp_read, &mut quic_send);
|
||
|
|
|
||
|
|
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
|
||
|
|
|
||
|
|
let bytes_in = c2b.unwrap_or(0);
|
||
|
|
let bytes_out = b2c.unwrap_or(0);
|
||
|
|
|
||
|
|
// Graceful shutdown
|
||
|
|
let _ = quic_send.finish();
|
||
|
|
let _ = tcp_write.shutdown().await;
|
||
|
|
|
||
|
|
Ok((bytes_in, bytes_out))
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Placeholder for HTTP/3 connection handling (Phase 5).
|
||
|
|
///
|
||
|
|
/// Once h3_service is implemented, this will delegate to it.
|
||
|
|
async fn handle_h3_connection(
|
||
|
|
connection: quinn::Connection,
|
||
|
|
_route: RouteConfig,
|
||
|
|
_port: u16,
|
||
|
|
_metrics: &MetricsCollector,
|
||
|
|
cancel: &CancellationToken,
|
||
|
|
) -> anyhow::Result<()> {
|
||
|
|
warn!("HTTP/3 handling not yet fully implemented — accepting connection but no request processing");
|
||
|
|
|
||
|
|
// Keep the connection alive until cancelled or closed
|
||
|
|
tokio::select! {
|
||
|
|
_ = cancel.cancelled() => {}
|
||
|
|
reason = connection.closed() => {
|
||
|
|
debug!("HTTP/3 connection closed: {}", reason);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Despite its name, this checks the success path: an endpoint can be
    /// created from a valid TLS config that advertises the `h3` ALPN protocol.
    #[tokio::test]
    async fn test_quic_endpoint_requires_tls_config() {
        // rustls needs a process-wide crypto provider; ignore the error
        // returned when one is already installed.
        let _ = rustls::crypto::ring::default_provider().install_default();

        // One self-signed certificate for "localhost", with its own key pair.
        let generated =
            rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        let certificate = generated.cert.der().clone();
        let private_key_bytes = generated.key_pair.serialize_der();
        let private_key = rustls::pki_types::PrivateKeyDer::try_from(private_key_bytes).unwrap();

        let mut tls_config = RustlsServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![certificate.into()], private_key)
            .unwrap();
        tls_config.alpn_protocols = vec![b"h3".to_vec()];

        // Port 0 = OS assigns a free port
        let result = create_quic_endpoint(0, Arc::new(tls_config));
        assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
    }
}
|