399 lines
15 KiB
Rust
399 lines
15 KiB
Rust
//! HTTP/3 proxy service.
//!
//! Accepts QUIC connections via quinn, runs an h3 server to handle HTTP/3 requests,
//! and forwards them to HTTP/1.1 backends using the same routing infrastructure
//! as the HTTP/1+2 proxy. Backend connections are currently opened per request
//! rather than taken from the shared connection pool.
|
|
|
|
use std::net::SocketAddr;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::task::{Context, Poll};
|
|
use std::time::Duration;
|
|
|
|
use arc_swap::ArcSwap;
|
|
use bytes::{Buf, Bytes};
|
|
use http_body::Frame;
|
|
use tracing::{debug, warn};
|
|
|
|
use rustproxy_config::{RouteConfig, TransportProtocol};
|
|
use rustproxy_metrics::MetricsCollector;
|
|
use rustproxy_routing::{MatchContext, RouteManager};
|
|
|
|
use crate::connection_pool::ConnectionPool;
|
|
use crate::protocol_cache::ProtocolCache;
|
|
use crate::upstream_selector::UpstreamSelector;
|
|
|
|
/// HTTP/3 proxy service.
|
|
///
|
|
/// Handles QUIC connections with the h3 crate, parses HTTP/3 requests,
|
|
/// and forwards them to backends using per-request route matching and
|
|
/// shared connection pooling.
|
|
pub struct H3ProxyService {
|
|
route_manager: Arc<ArcSwap<RouteManager>>,
|
|
metrics: Arc<MetricsCollector>,
|
|
connection_pool: Arc<ConnectionPool>,
|
|
#[allow(dead_code)]
|
|
protocol_cache: Arc<ProtocolCache>,
|
|
#[allow(dead_code)]
|
|
upstream_selector: UpstreamSelector,
|
|
backend_tls_config: Arc<rustls::ClientConfig>,
|
|
connect_timeout: Duration,
|
|
}
|
|
|
|
impl H3ProxyService {
|
|
pub fn new(
|
|
route_manager: Arc<ArcSwap<RouteManager>>,
|
|
metrics: Arc<MetricsCollector>,
|
|
connection_pool: Arc<ConnectionPool>,
|
|
protocol_cache: Arc<ProtocolCache>,
|
|
backend_tls_config: Arc<rustls::ClientConfig>,
|
|
connect_timeout: Duration,
|
|
) -> Self {
|
|
Self {
|
|
route_manager: Arc::clone(&route_manager),
|
|
metrics: Arc::clone(&metrics),
|
|
connection_pool,
|
|
protocol_cache,
|
|
upstream_selector: UpstreamSelector::new(),
|
|
backend_tls_config,
|
|
connect_timeout,
|
|
}
|
|
}
|
|
|
|
/// Handle an accepted QUIC connection as HTTP/3.
|
|
///
|
|
/// If `real_client_addr` is provided (from PROXY protocol), it overrides
|
|
/// `connection.remote_address()` for client IP attribution.
|
|
pub async fn handle_connection(
|
|
&self,
|
|
connection: quinn::Connection,
|
|
_fallback_route: &RouteConfig,
|
|
port: u16,
|
|
real_client_addr: Option<SocketAddr>,
|
|
) -> anyhow::Result<()> {
|
|
let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
|
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
|
|
|
|
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
|
|
h3::server::Connection::new(h3_quinn::Connection::new(connection))
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("H3 connection setup failed: {}", e))?;
|
|
|
|
let client_ip = remote_addr.ip().to_string();
|
|
|
|
loop {
|
|
match h3_conn.accept().await {
|
|
Ok(Some(resolver)) => {
|
|
let (request, stream) = match resolver.resolve_request().await {
|
|
Ok(pair) => pair,
|
|
Err(e) => {
|
|
debug!("HTTP/3 request resolve error: {}", e);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
self.metrics.record_http_request();
|
|
|
|
let rm = self.route_manager.load();
|
|
let pool = Arc::clone(&self.connection_pool);
|
|
let metrics = Arc::clone(&self.metrics);
|
|
let backend_tls = Arc::clone(&self.backend_tls_config);
|
|
let connect_timeout = self.connect_timeout;
|
|
let client_ip = client_ip.clone();
|
|
|
|
tokio::spawn(async move {
|
|
if let Err(e) = handle_h3_request(
|
|
request, stream, port, &client_ip, &rm, &pool, &metrics,
|
|
&backend_tls, connect_timeout,
|
|
).await {
|
|
debug!("HTTP/3 request error from {}: {}", client_ip, e);
|
|
}
|
|
});
|
|
}
|
|
Ok(None) => {
|
|
debug!("HTTP/3 connection from {} closed", remote_addr);
|
|
break;
|
|
}
|
|
Err(e) => {
|
|
debug!("HTTP/3 accept error from {}: {}", remote_addr, e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Handle a single HTTP/3 request with per-request route matching.
|
|
async fn handle_h3_request(
|
|
request: hyper::Request<()>,
|
|
mut stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
|
|
port: u16,
|
|
client_ip: &str,
|
|
route_manager: &RouteManager,
|
|
_connection_pool: &ConnectionPool,
|
|
metrics: &MetricsCollector,
|
|
backend_tls_config: &Arc<rustls::ClientConfig>,
|
|
connect_timeout: Duration,
|
|
) -> anyhow::Result<()> {
|
|
let method = request.method().clone();
|
|
let uri = request.uri().clone();
|
|
let path = uri.path().to_string();
|
|
|
|
// Extract host from :authority or Host header
|
|
let host = request.uri().authority()
|
|
.map(|a| a.as_str().to_string())
|
|
.or_else(|| request.headers().get("host").and_then(|v| v.to_str().ok()).map(|s| s.to_string()))
|
|
.unwrap_or_default();
|
|
|
|
debug!("HTTP/3 {} {} (host: {}, client: {})", method, path, host, client_ip);
|
|
|
|
// Per-request route matching
|
|
let ctx = MatchContext {
|
|
port,
|
|
domain: if host.is_empty() { None } else { Some(&host) },
|
|
path: Some(&path),
|
|
client_ip: Some(client_ip),
|
|
tls_version: Some("TLSv1.3"),
|
|
headers: None,
|
|
is_tls: true,
|
|
protocol: Some("http"),
|
|
transport: Some(TransportProtocol::Udp),
|
|
};
|
|
|
|
let route_match = route_manager.find_route(&ctx)
|
|
.ok_or_else(|| anyhow::anyhow!("No route matched for HTTP/3 request to {}{}", host, path))?;
|
|
let route = route_match.route;
|
|
|
|
// Resolve backend target (use matched target or first target)
|
|
let target = route_match.target
|
|
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
|
|
.ok_or_else(|| anyhow::anyhow!("No target for HTTP/3 route"))?;
|
|
|
|
let backend_host = target.host.first();
|
|
let backend_port = target.port.resolve(port);
|
|
let backend_addr = format!("{}:{}", backend_host, backend_port);
|
|
|
|
// Determine if backend requires TLS (same logic as proxy_service.rs)
|
|
let mut use_tls = target.tls.is_some();
|
|
if let Some(ref tls) = route.action.tls {
|
|
if tls.mode == rustproxy_config::TlsMode::TerminateAndReencrypt {
|
|
use_tls = true;
|
|
}
|
|
}
|
|
|
|
// Connect to backend via TCP with timeout
|
|
let tcp_stream = tokio::time::timeout(
|
|
connect_timeout,
|
|
tokio::net::TcpStream::connect(&backend_addr),
|
|
).await
|
|
.map_err(|_| anyhow::anyhow!("Backend connect timeout to {}", backend_addr))?
|
|
.map_err(|e| anyhow::anyhow!("Backend connect to {} failed: {}", backend_addr, e))?;
|
|
|
|
let _ = tcp_stream.set_nodelay(true);
|
|
|
|
// Branch: wrap in TLS if backend requires it, then HTTP/1.1 handshake.
|
|
// hyper's SendRequest<B> is NOT generic over the IO type, so both branches
|
|
// produce the same type and can be unified.
|
|
let mut sender = if use_tls {
|
|
let connector = tokio_rustls::TlsConnector::from(Arc::clone(backend_tls_config));
|
|
let server_name = rustls::pki_types::ServerName::try_from(backend_host.to_string())
|
|
.map_err(|e| anyhow::anyhow!("Invalid backend SNI '{}': {}", backend_host, e))?;
|
|
let tls_stream = connector.connect(server_name, tcp_stream).await
|
|
.map_err(|e| anyhow::anyhow!("Backend TLS handshake to {} failed: {}", backend_addr, e))?;
|
|
let io = hyper_util::rt::TokioIo::new(tls_stream);
|
|
let (sender, conn) = hyper::client::conn::http1::handshake(io).await
|
|
.map_err(|e| anyhow::anyhow!("Backend handshake failed: {}", e))?;
|
|
tokio::spawn(async move { let _ = conn.await; });
|
|
sender
|
|
} else {
|
|
let io = hyper_util::rt::TokioIo::new(tcp_stream);
|
|
let (sender, conn) = hyper::client::conn::http1::handshake(io).await
|
|
.map_err(|e| anyhow::anyhow!("Backend handshake failed: {}", e))?;
|
|
tokio::spawn(async move { let _ = conn.await; });
|
|
sender
|
|
};
|
|
|
|
// Stream request body from H3 client to backend via an mpsc channel.
|
|
// This avoids buffering the entire request body in memory.
|
|
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Bytes>(4);
|
|
let total_bytes_in = Arc::new(std::sync::atomic::AtomicU64::new(0));
|
|
let total_bytes_in_writer = Arc::clone(&total_bytes_in);
|
|
|
|
// Spawn the H3 body reader task
|
|
let body_reader = tokio::spawn(async move {
|
|
while let Ok(Some(mut chunk)) = stream.recv_data().await {
|
|
let data = Bytes::copy_from_slice(chunk.chunk());
|
|
total_bytes_in_writer.fetch_add(data.len() as u64, std::sync::atomic::Ordering::Relaxed);
|
|
chunk.advance(chunk.remaining());
|
|
if body_tx.send(data).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
stream
|
|
});
|
|
|
|
// Create a body that polls from the mpsc receiver
|
|
let body = H3RequestBody { receiver: body_rx };
|
|
let backend_req = build_backend_request(&method, &backend_addr, &path, &host, &request, body, use_tls)?;
|
|
|
|
let response = sender.send_request(backend_req).await
|
|
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
|
|
|
|
// Await the body reader to get the stream back
|
|
let mut stream = body_reader.await
|
|
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
|
|
let total_bytes_in = total_bytes_in.load(std::sync::atomic::Ordering::Relaxed);
|
|
|
|
// Build H3 response
|
|
let status = response.status();
|
|
let mut h3_response = hyper::Response::builder().status(status);
|
|
|
|
// Copy response headers (skip hop-by-hop)
|
|
for (name, value) in response.headers() {
|
|
let n = name.as_str().to_lowercase();
|
|
if n == "transfer-encoding" || n == "connection" || n == "keep-alive" || n == "upgrade" {
|
|
continue;
|
|
}
|
|
h3_response = h3_response.header(name, value);
|
|
}
|
|
|
|
// Extract content-length for body loop termination (must be before into_body())
|
|
let content_length: Option<u64> = response.headers()
|
|
.get(hyper::header::CONTENT_LENGTH)
|
|
.and_then(|v| v.to_str().ok())
|
|
.and_then(|s| s.parse().ok());
|
|
|
|
// Add Alt-Svc for HTTP/3 advertisement
|
|
let alt_svc = route.action.udp.as_ref()
|
|
.and_then(|u| u.quic.as_ref())
|
|
.map(|q| {
|
|
let p = q.alt_svc_port.unwrap_or(port);
|
|
let ma = q.alt_svc_max_age.unwrap_or(86400);
|
|
format!("h3=\":{}\"; ma={}", p, ma)
|
|
})
|
|
.unwrap_or_else(|| format!("h3=\":{}\"; ma=86400", port));
|
|
h3_response = h3_response.header("alt-svc", alt_svc);
|
|
|
|
let h3_response = h3_response.body(())
|
|
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
|
|
|
|
// Send response headers
|
|
stream.send_response(h3_response).await
|
|
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
|
|
|
|
// Stream response body back
|
|
use http_body_util::BodyExt;
|
|
use http_body::Body as _;
|
|
let mut body = response.into_body();
|
|
let mut total_bytes_out: u64 = 0;
|
|
|
|
// Per-frame idle timeout: if no frame arrives within this duration, assume
|
|
// the body is complete (or the backend has stalled). This prevents indefinite
|
|
// hangs on close-delimited bodies or when hyper's internal trailers oneshot
|
|
// never resolves after all data has been received.
|
|
const FRAME_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
|
|
|
|
loop {
|
|
// Layer 1: If the body already knows it is finished (Content-Length
|
|
// bodies track remaining bytes internally), break immediately to
|
|
// avoid blocking on hyper's internal trailers oneshot.
|
|
if body.is_end_stream() {
|
|
break;
|
|
}
|
|
|
|
// Layer 3: Per-frame idle timeout safety net
|
|
match tokio::time::timeout(FRAME_IDLE_TIMEOUT, body.frame()).await {
|
|
Ok(Some(Ok(frame))) => {
|
|
if let Some(data) = frame.data_ref() {
|
|
total_bytes_out += data.len() as u64;
|
|
stream.send_data(Bytes::copy_from_slice(data)).await
|
|
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
|
|
|
|
// Layer 2: Content-Length byte count check
|
|
if let Some(cl) = content_length {
|
|
if total_bytes_out >= cl {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(Some(Err(e))) => {
|
|
warn!("Backend body read error: {}", e);
|
|
break;
|
|
}
|
|
Ok(None) => break, // Body ended naturally
|
|
Err(_) => {
|
|
debug!(
|
|
"H3 body frame idle timeout ({:?}) after {} bytes; finishing stream",
|
|
FRAME_IDLE_TIMEOUT, total_bytes_out
|
|
);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Record metrics
|
|
let route_id = route.name.as_deref().or(route.id.as_deref());
|
|
metrics.record_bytes(total_bytes_in, total_bytes_out, route_id, Some(client_ip));
|
|
|
|
// Finish the stream
|
|
stream.finish().await
|
|
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Build an HTTP/1.1 backend request from the H3 frontend request.
|
|
fn build_backend_request<B>(
|
|
method: &hyper::Method,
|
|
backend_addr: &str,
|
|
path: &str,
|
|
host: &str,
|
|
original_request: &hyper::Request<()>,
|
|
body: B,
|
|
use_tls: bool,
|
|
) -> anyhow::Result<hyper::Request<B>> {
|
|
let scheme = if use_tls { "https" } else { "http" };
|
|
let mut req = hyper::Request::builder()
|
|
.method(method)
|
|
.uri(format!("{}://{}{}", scheme, backend_addr, path))
|
|
.header("host", host);
|
|
|
|
// Forward non-pseudo headers
|
|
for (name, value) in original_request.headers() {
|
|
let n = name.as_str();
|
|
if !n.starts_with(':') && n != "host" {
|
|
req = req.header(name, value);
|
|
}
|
|
}
|
|
|
|
req.body(body)
|
|
.map_err(|e| anyhow::anyhow!("Failed to build backend request: {}", e))
|
|
}
|
|
|
|
/// A streaming request body backed by an mpsc channel receiver.
|
|
///
|
|
/// Implements `http_body::Body` so hyper can poll chunks as they arrive
|
|
/// from the H3 client, avoiding buffering the entire request body in memory.
|
|
struct H3RequestBody {
|
|
receiver: tokio::sync::mpsc::Receiver<Bytes>,
|
|
}
|
|
|
|
impl http_body::Body for H3RequestBody {
|
|
type Data = Bytes;
|
|
type Error = hyper::Error;
|
|
|
|
fn poll_frame(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
|
|
match self.receiver.poll_recv(cx) {
|
|
Poll::Ready(Some(data)) => Poll::Ready(Some(Ok(Frame::data(data)))),
|
|
Poll::Ready(None) => Poll::Ready(None),
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
}
|
|
}
|