//! HTTP/3 proxy service. //! //! Accepts QUIC connections via quinn, runs h3 server to handle HTTP/3 requests, //! and forwards them to backends using the same routing and pool infrastructure //! as the HTTP/1+2 proxy. use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use arc_swap::ArcSwap; use bytes::{Buf, Bytes}; use http_body::Frame; use tracing::{debug, warn}; use rustproxy_config::{RouteConfig, TransportProtocol}; use rustproxy_metrics::MetricsCollector; use rustproxy_routing::{MatchContext, RouteManager}; use crate::connection_pool::ConnectionPool; use crate::protocol_cache::ProtocolCache; use crate::upstream_selector::UpstreamSelector; /// HTTP/3 proxy service. /// /// Handles QUIC connections with the h3 crate, parses HTTP/3 requests, /// and forwards them to backends using per-request route matching and /// shared connection pooling. pub struct H3ProxyService { route_manager: Arc>, metrics: Arc, connection_pool: Arc, #[allow(dead_code)] protocol_cache: Arc, #[allow(dead_code)] upstream_selector: UpstreamSelector, backend_tls_config: Arc, connect_timeout: Duration, } impl H3ProxyService { pub fn new( route_manager: Arc>, metrics: Arc, connection_pool: Arc, protocol_cache: Arc, backend_tls_config: Arc, connect_timeout: Duration, ) -> Self { Self { route_manager: Arc::clone(&route_manager), metrics: Arc::clone(&metrics), connection_pool, protocol_cache, upstream_selector: UpstreamSelector::new(), backend_tls_config, connect_timeout, } } /// Handle an accepted QUIC connection as HTTP/3. /// /// If `real_client_addr` is provided (from PROXY protocol), it overrides /// `connection.remote_address()` for client IP attribution. pub async fn handle_connection( &self, connection: quinn::Connection, _fallback_route: &RouteConfig, port: u16, real_client_addr: Option, ) -> anyhow::Result<()> { let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address()); debug!("HTTP/3 connection from {} on port {}", remote_addr, port); let mut h3_conn: h3::server::Connection = h3::server::builder() .send_grease(false) .build(h3_quinn::Connection::new(connection)) .await .map_err(|e| anyhow::anyhow!("H3 connection setup failed: {}", e))?; let client_ip = remote_addr.ip().to_string(); loop { match h3_conn.accept().await { Ok(Some(resolver)) => { let (request, stream) = match resolver.resolve_request().await { Ok(pair) => pair, Err(e) => { debug!("HTTP/3 request resolve error: {}", e); continue; } }; self.metrics.record_http_request(); let rm = self.route_manager.load(); let pool = Arc::clone(&self.connection_pool); let metrics = Arc::clone(&self.metrics); let backend_tls = Arc::clone(&self.backend_tls_config); let connect_timeout = self.connect_timeout; let client_ip = client_ip.clone(); tokio::spawn(async move { if let Err(e) = handle_h3_request( request, stream, port, &client_ip, &rm, &pool, &metrics, &backend_tls, connect_timeout, ).await { debug!("HTTP/3 request error from {}: {}", client_ip, e); } }); } Ok(None) => { debug!("HTTP/3 connection from {} closed", remote_addr); break; } Err(e) => { debug!("HTTP/3 accept error from {}: {}", remote_addr, e); break; } } } Ok(()) } } /// Handle a single HTTP/3 request with per-request route matching. async fn handle_h3_request( request: hyper::Request<()>, mut stream: h3::server::RequestStream, Bytes>, port: u16, client_ip: &str, route_manager: &RouteManager, _connection_pool: &ConnectionPool, metrics: &MetricsCollector, backend_tls_config: &Arc, connect_timeout: Duration, ) -> anyhow::Result<()> { let method = request.method().clone(); let uri = request.uri().clone(); let path = uri.path().to_string(); // Extract host from :authority or Host header let host = request.uri().authority() .map(|a| a.as_str().to_string()) .or_else(|| request.headers().get("host").and_then(|v| v.to_str().ok()).map(|s| s.to_string())) .unwrap_or_default(); debug!("HTTP/3 {} {} (host: {}, client: {})", method, path, host, client_ip); // Per-request route matching let ctx = MatchContext { port, domain: if host.is_empty() { None } else { Some(&host) }, path: Some(&path), client_ip: Some(client_ip), tls_version: Some("TLSv1.3"), headers: None, is_tls: true, protocol: Some("http"), transport: Some(TransportProtocol::Udp), }; let route_match = route_manager.find_route(&ctx) .ok_or_else(|| anyhow::anyhow!("No route matched for HTTP/3 request to {}{}", host, path))?; let route = route_match.route; // Resolve backend target (use matched target or first target) let target = route_match.target .or_else(|| route.action.targets.as_ref().and_then(|t| t.first())) .ok_or_else(|| anyhow::anyhow!("No target for HTTP/3 route"))?; let backend_host = target.host.first(); let backend_port = target.port.resolve(port); let backend_addr = format!("{}:{}", backend_host, backend_port); // Determine if backend requires TLS (same logic as proxy_service.rs) let mut use_tls = target.tls.is_some(); if let Some(ref tls) = route.action.tls { if tls.mode == rustproxy_config::TlsMode::TerminateAndReencrypt { use_tls = true; } } // Connect to backend via TCP with timeout let tcp_stream = tokio::time::timeout( connect_timeout, tokio::net::TcpStream::connect(&backend_addr), ).await .map_err(|_| anyhow::anyhow!("Backend connect timeout to {}", backend_addr))? .map_err(|e| anyhow::anyhow!("Backend connect to {} failed: {}", backend_addr, e))?; let _ = tcp_stream.set_nodelay(true); // Branch: wrap in TLS if backend requires it, then HTTP/1.1 handshake. // hyper's SendRequest is NOT generic over the IO type, so both branches // produce the same type and can be unified. let mut sender = if use_tls { let connector = tokio_rustls::TlsConnector::from(Arc::clone(backend_tls_config)); let server_name = rustls::pki_types::ServerName::try_from(backend_host.to_string()) .map_err(|e| anyhow::anyhow!("Invalid backend SNI '{}': {}", backend_host, e))?; let tls_stream = connector.connect(server_name, tcp_stream).await .map_err(|e| anyhow::anyhow!("Backend TLS handshake to {} failed: {}", backend_addr, e))?; let io = hyper_util::rt::TokioIo::new(tls_stream); let (sender, conn) = hyper::client::conn::http1::handshake(io).await .map_err(|e| anyhow::anyhow!("Backend handshake failed: {}", e))?; tokio::spawn(async move { let _ = conn.await; }); sender } else { let io = hyper_util::rt::TokioIo::new(tcp_stream); let (sender, conn) = hyper::client::conn::http1::handshake(io).await .map_err(|e| anyhow::anyhow!("Backend handshake failed: {}", e))?; tokio::spawn(async move { let _ = conn.await; }); sender }; // Stream request body from H3 client to backend via an mpsc channel. // This avoids buffering the entire request body in memory. let (body_tx, body_rx) = tokio::sync::mpsc::channel::(4); let total_bytes_in = Arc::new(std::sync::atomic::AtomicU64::new(0)); let total_bytes_in_writer = Arc::clone(&total_bytes_in); // Spawn the H3 body reader task let body_reader = tokio::spawn(async move { while let Ok(Some(mut chunk)) = stream.recv_data().await { let data = Bytes::copy_from_slice(chunk.chunk()); total_bytes_in_writer.fetch_add(data.len() as u64, std::sync::atomic::Ordering::Relaxed); chunk.advance(chunk.remaining()); if body_tx.send(data).await.is_err() { break; } } stream }); // Create a body that polls from the mpsc receiver let body = H3RequestBody { receiver: body_rx }; let backend_req = build_backend_request(&method, &backend_addr, &path, &host, &request, body, use_tls)?; let response = sender.send_request(backend_req).await .map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?; // Await the body reader to get the stream back let mut stream = body_reader.await .map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?; let total_bytes_in = total_bytes_in.load(std::sync::atomic::Ordering::Relaxed); // Build H3 response let status = response.status(); let mut h3_response = hyper::Response::builder().status(status); // Copy response headers (skip hop-by-hop) for (name, value) in response.headers() { let n = name.as_str().to_lowercase(); if n == "transfer-encoding" || n == "connection" || n == "keep-alive" || n == "upgrade" { continue; } h3_response = h3_response.header(name, value); } // Extract content-length for body loop termination (must be before into_body()) let content_length: Option = response.headers() .get(hyper::header::CONTENT_LENGTH) .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse().ok()); // Add Alt-Svc for HTTP/3 advertisement let alt_svc = route.action.udp.as_ref() .and_then(|u| u.quic.as_ref()) .map(|q| { let p = q.alt_svc_port.unwrap_or(port); let ma = q.alt_svc_max_age.unwrap_or(86400); format!("h3=\":{}\"; ma={}", p, ma) }) .unwrap_or_else(|| format!("h3=\":{}\"; ma=86400", port)); h3_response = h3_response.header("alt-svc", alt_svc); let h3_response = h3_response.body(()) .map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?; // Send response headers stream.send_response(h3_response).await .map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?; // Stream response body back use http_body_util::BodyExt; use http_body::Body as _; let mut body = response.into_body(); let mut total_bytes_out: u64 = 0; // Per-frame idle timeout: if no frame arrives within this duration, assume // the body is complete (or the backend has stalled). This prevents indefinite // hangs on close-delimited bodies or when hyper's internal trailers oneshot // never resolves after all data has been received. const FRAME_IDLE_TIMEOUT: Duration = Duration::from_secs(30); loop { // Layer 1: If the body already knows it is finished (Content-Length // bodies track remaining bytes internally), break immediately to // avoid blocking on hyper's internal trailers oneshot. if body.is_end_stream() { break; } // Layer 3: Per-frame idle timeout safety net match tokio::time::timeout(FRAME_IDLE_TIMEOUT, body.frame()).await { Ok(Some(Ok(frame))) => { if let Some(data) = frame.data_ref() { total_bytes_out += data.len() as u64; stream.send_data(Bytes::copy_from_slice(data)).await .map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?; // Layer 2: Content-Length byte count check if let Some(cl) = content_length { if total_bytes_out >= cl { break; } } } } Ok(Some(Err(e))) => { warn!("Backend body read error: {}", e); break; } Ok(None) => break, // Body ended naturally Err(_) => { debug!( "H3 body frame idle timeout ({:?}) after {} bytes; finishing stream", FRAME_IDLE_TIMEOUT, total_bytes_out ); break; } } } // Record metrics let route_id = route.name.as_deref().or(route.id.as_deref()); metrics.record_bytes(total_bytes_in, total_bytes_out, route_id, Some(client_ip)); // Finish the stream stream.finish().await .map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?; Ok(()) } /// Build an HTTP/1.1 backend request from the H3 frontend request. fn build_backend_request( method: &hyper::Method, backend_addr: &str, path: &str, host: &str, original_request: &hyper::Request<()>, body: B, use_tls: bool, ) -> anyhow::Result> { let scheme = if use_tls { "https" } else { "http" }; let mut req = hyper::Request::builder() .method(method) .uri(format!("{}://{}{}", scheme, backend_addr, path)) .header("host", host); // Forward non-pseudo headers for (name, value) in original_request.headers() { let n = name.as_str(); if !n.starts_with(':') && n != "host" { req = req.header(name, value); } } req.body(body) .map_err(|e| anyhow::anyhow!("Failed to build backend request: {}", e)) } /// A streaming request body backed by an mpsc channel receiver. /// /// Implements `http_body::Body` so hyper can poll chunks as they arrive /// from the H3 client, avoiding buffering the entire request body in memory. struct H3RequestBody { receiver: tokio::sync::mpsc::Receiver, } impl http_body::Body for H3RequestBody { type Data = Bytes; type Error = hyper::Error; fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { match self.receiver.poll_recv(cx) { Poll::Ready(Some(data)) => Poll::Ready(Some(Ok(Frame::data(data)))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } }