From a53a2c4ca54edfea7686b12886fb04dee4c2d004 Mon Sep 17 00:00:00 2001 From: Juergen Kunz Date: Tue, 14 Apr 2026 00:54:12 +0000 Subject: [PATCH] fix(rustproxy-http,rustproxy-metrics): fix domain-scoped request host detection and harden connection metrics cleanup --- changelog.md | 8 + .../rustproxy-http/src/connection_pool.rs | 67 +- .../rustproxy-http/src/counting_body.rs | 12 +- rust/crates/rustproxy-http/src/h3_service.rs | 38 +- rust/crates/rustproxy-http/src/lib.rs | 3 +- .../rustproxy-http/src/protocol_cache.rs | 86 +- .../rustproxy-http/src/proxy_service.rs | 1253 +++++++++++++---- .../rustproxy-http/src/request_filter.rs | 182 ++- .../crates/rustproxy-http/src/request_host.rs | 43 + .../rustproxy-http/src/response_filter.rs | 16 +- .../rustproxy-http/src/shutdown_on_drop.rs | 13 +- rust/crates/rustproxy-http/src/template.rs | 8 +- .../rustproxy-http/src/upstream_selector.rs | 33 +- .../crates/rustproxy-metrics/src/collector.rs | 639 +++++++-- ts/00_commitinfo_data.ts | 2 +- 15 files changed, 1813 insertions(+), 590 deletions(-) create mode 100644 rust/crates/rustproxy-http/src/request_host.rs diff --git a/changelog.md b/changelog.md index 8d1d618..ac8d522 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,13 @@ # Changelog +## 2026-04-14 - 27.7.1 - fix(rustproxy-http,rustproxy-metrics) +fix domain-scoped request host detection and harden connection metrics cleanup + +- use a shared request host extractor that falls back to URI authority so domain-scoped IP allow lists work for HTTP/2 and HTTP/3 requests without a Host header +- add request filter and host extraction tests covering domain-scoped ACL behavior +- prevent connection counters from underflowing during close handling and clean up per-IP metrics entries more safely +- normalize tracked domain keys in metrics to reduce duplicate entries caused by case or trailing-dot variations + ## 2026-04-13 - 27.7.0 - feat(smart-proxy) add typed Rust config serialization and regex header contract coverage diff --git a/rust/crates/rustproxy-http/src/connection_pool.rs b/rust/crates/rustproxy-http/src/connection_pool.rs index afa4033..84a3f4d 100644 --- a/rust/crates/rustproxy-http/src/connection_pool.rs +++ b/rust/crates/rustproxy-http/src/connection_pool.rs @@ -3,8 +3,8 @@ //! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes. //! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection). -use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use std::time::{Duration, Instant}; use bytes::Bytes; @@ -105,13 +105,19 @@ impl ConnectionPool { /// Try to check out an idle HTTP/1.1 sender for the given key. /// Returns `None` if no usable idle connection exists. - pub fn checkout_h1(&self, key: &PoolKey) -> Option>> { + pub fn checkout_h1( + &self, + key: &PoolKey, + ) -> Option>> { let mut entry = self.h1_pool.get_mut(key)?; let idles = entry.value_mut(); while let Some(idle) = idles.pop() { // Check if the connection is still alive and ready - if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() { + if idle.idle_since.elapsed() < IDLE_TIMEOUT + && idle.sender.is_ready() + && !idle.sender.is_closed() + { // H1 pool hit — no logging on hot path return Some(idle.sender); } @@ -128,7 +134,11 @@ impl ConnectionPool { /// Return an HTTP/1.1 sender to the pool after the response body has been prepared. /// The caller should NOT call this if the sender is closed or not ready. - pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest>) { + pub fn checkin_h1( + &self, + key: PoolKey, + sender: http1::SendRequest>, + ) { if sender.is_closed() || !sender.is_ready() { return; // Don't pool broken connections } @@ -145,7 +155,10 @@ impl ConnectionPool { /// Try to get a cloned HTTP/2 sender for the given key. /// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove. - pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest>, Duration)> { + pub fn checkout_h2( + &self, + key: &PoolKey, + ) -> Option<(http2::SendRequest>, Duration)> { let entry = self.h2_pool.get(key)?; let pooled = entry.value(); let age = pooled.created_at.elapsed(); @@ -184,16 +197,23 @@ impl ConnectionPool { /// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry. /// The caller should pass this generation to the connection driver so it can use /// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction. - pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest>) -> u64 { + pub fn register_h2( + &self, + key: PoolKey, + sender: http2::SendRequest>, + ) -> u64 { let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed); if sender.is_closed() { return gen; } - self.h2_pool.insert(key, PooledH2 { - sender, - created_at: Instant::now(), - generation: gen, - }); + self.h2_pool.insert( + key, + PooledH2 { + sender, + created_at: Instant::now(), + generation: gen, + }, + ); gen } @@ -204,7 +224,11 @@ impl ConnectionPool { pub fn checkout_h3( &self, key: &PoolKey, - ) -> Option<(h3::client::SendRequest, quinn::Connection, Duration)> { + ) -> Option<( + h3::client::SendRequest, + quinn::Connection, + Duration, + )> { let entry = self.h3_pool.get(key)?; let pooled = entry.value(); let age = pooled.created_at.elapsed(); @@ -234,12 +258,15 @@ impl ConnectionPool { send_request: h3::client::SendRequest, ) -> u64 { let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed); - self.h3_pool.insert(key, PooledH3 { - send_request, - connection, - created_at: Instant::now(), - generation: gen, - }); + self.h3_pool.insert( + key, + PooledH3 { + send_request, + connection, + created_at: Instant::now(), + generation: gen, + }, + ); gen } @@ -280,7 +307,9 @@ impl ConnectionPool { // Evict dead or aged-out H2 connections let mut dead_h2 = Vec::new(); for entry in h2_pool.iter() { - if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE { + if entry.value().sender.is_closed() + || entry.value().created_at.elapsed() >= MAX_H2_AGE + { dead_h2.push(entry.key().clone()); } } diff --git a/rust/crates/rustproxy-http/src/counting_body.rs b/rust/crates/rustproxy-http/src/counting_body.rs index 23b82b8..2de70c4 100644 --- a/rust/crates/rustproxy-http/src/counting_body.rs +++ b/rust/crates/rustproxy-http/src/counting_body.rs @@ -1,8 +1,8 @@ //! A body wrapper that counts bytes flowing through and reports them to MetricsCollector. use std::pin::Pin; -use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use std::task::{Context, Poll}; use bytes::Bytes; @@ -76,7 +76,11 @@ impl CountingBody { /// Set the connection-level activity tracker. When set, each data frame /// updates this timestamp to prevent the idle watchdog from killing the /// connection during active body streaming. - pub fn with_connection_activity(mut self, activity: Arc, start: std::time::Instant) -> Self { + pub fn with_connection_activity( + mut self, + activity: Arc, + start: std::time::Instant, + ) -> Self { self.connection_activity = Some(activity); self.activity_start = Some(start); self @@ -134,7 +138,9 @@ where } // Keep the connection-level idle watchdog alive on every frame // (this is just one atomic store — cheap enough per-frame) - if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) { + if let (Some(activity), Some(start)) = + (&this.connection_activity, &this.activity_start) + { activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); } } diff --git a/rust/crates/rustproxy-http/src/h3_service.rs b/rust/crates/rustproxy-http/src/h3_service.rs index 92371af..579e53e 100644 --- a/rust/crates/rustproxy-http/src/h3_service.rs +++ b/rust/crates/rustproxy-http/src/h3_service.rs @@ -11,8 +11,8 @@ use std::task::{Context, Poll}; use bytes::{Buf, Bytes}; use http_body::Frame; -use http_body_util::BodyExt; use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; use tracing::{debug, warn}; use rustproxy_config::RouteConfig; @@ -49,7 +49,8 @@ impl H3ProxyService { debug!("HTTP/3 connection from {} on port {}", remote_addr, port); // Track frontend H3 connection for the QUIC connection's lifetime. - let _frontend_h3_guard = ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3"); + let _frontend_h3_guard = + ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3"); let mut h3_conn: h3::server::Connection = h3::server::builder() @@ -92,8 +93,15 @@ impl H3ProxyService { tokio::spawn(async move { if let Err(e) = handle_h3_request( - request, stream, port, remote_addr, &http_proxy, request_cancel, - ).await { + request, + stream, + port, + remote_addr, + &http_proxy, + request_cancel, + ) + .await + { debug!("HTTP/3 request error from {}: {}", remote_addr, e); } }); @@ -153,11 +161,14 @@ async fn handle_h3_request( // Delegate to HttpProxyService — same backend path as TCP/HTTP: // route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto. let conn_activity = ConnActivity::new_standalone(); - let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await + let response = http_proxy + .handle_request(req, peer_addr, port, cancel, conn_activity) + .await .map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?; // Await the body reader to get the H3 stream back - let mut stream = body_reader.await + let mut stream = body_reader + .await .map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?; // Send response headers over H3 (skip hop-by-hop headers) @@ -170,10 +181,13 @@ async fn handle_h3_request( } h3_response = h3_response.header(name, value); } - let h3_response = h3_response.body(()) + let h3_response = h3_response + .body(()) .map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?; - stream.send_response(h3_response).await + stream + .send_response(h3_response) + .await .map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?; // Stream response body back over H3 @@ -182,7 +196,9 @@ async fn handle_h3_request( match frame { Ok(frame) => { if let Ok(data) = frame.into_data() { - stream.send_data(data).await + stream + .send_data(data) + .await .map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?; } } @@ -194,7 +210,9 @@ async fn handle_h3_request( } // Finish the H3 stream (send QUIC FIN) - stream.finish().await + stream + .finish() + .await .map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?; Ok(()) diff --git a/rust/crates/rustproxy-http/src/lib.rs b/rust/crates/rustproxy-http/src/lib.rs index 0398a70..a30ed70 100644 --- a/rust/crates/rustproxy-http/src/lib.rs +++ b/rust/crates/rustproxy-http/src/lib.rs @@ -5,14 +5,15 @@ pub mod connection_pool; pub mod counting_body; +pub mod h3_service; pub mod protocol_cache; pub mod proxy_service; pub mod request_filter; +mod request_host; pub mod response_filter; pub mod shutdown_on_drop; pub mod template; pub mod upstream_selector; -pub mod h3_service; pub use connection_pool::*; pub use counting_body::*; diff --git a/rust/crates/rustproxy-http/src/protocol_cache.rs b/rust/crates/rustproxy-http/src/protocol_cache.rs index 9838090..db74c49 100644 --- a/rust/crates/rustproxy-http/src/protocol_cache.rs +++ b/rust/crates/rustproxy-http/src/protocol_cache.rs @@ -144,10 +144,14 @@ impl FailureState { } fn all_expired(&self) -> bool { - let h2_expired = self.h2.as_ref() + let h2_expired = self + .h2 + .as_ref() .map(|r| r.failed_at.elapsed() >= r.cooldown) .unwrap_or(true); - let h3_expired = self.h3.as_ref() + let h3_expired = self + .h3 + .as_ref() .map(|r| r.failed_at.elapsed() >= r.cooldown) .unwrap_or(true); h2_expired && h3_expired @@ -355,9 +359,13 @@ impl ProtocolCache { let record = entry.get_mut(protocol); let (consecutive, new_cooldown) = match record { - Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => { + Some(existing) + if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => + { // Still within the "recent" window — escalate - let c = existing.consecutive_failures.saturating_add(1) + let c = existing + .consecutive_failures + .saturating_add(1) .min(PROTOCOL_FAILURE_ESCALATION_CAP); (c, escalate_cooldown(c)) } @@ -394,8 +402,13 @@ impl ProtocolCache { if protocol == DetectedProtocol::H1 { return false; } - self.failures.get(key) - .and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown)) + self.failures + .get(key) + .and_then(|entry| { + entry + .get(protocol) + .map(|r| r.failed_at.elapsed() < r.cooldown) + }) .unwrap_or(false) } @@ -464,19 +477,18 @@ impl ProtocolCache { /// Snapshot all non-expired cache entries for metrics/UI display. pub fn snapshot(&self) -> Vec { - self.cache.iter() + self.cache + .iter() .filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL) .map(|entry| { let key = entry.key(); let val = entry.value(); let failure_info = self.failures.get(key); - let (h2_sup, h2_cd, h2_cons) = Self::suppression_info( - failure_info.as_deref().and_then(|f| f.h2.as_ref()), - ); - let (h3_sup, h3_cd, h3_cons) = Self::suppression_info( - failure_info.as_deref().and_then(|f| f.h3.as_ref()), - ); + let (h2_sup, h2_cd, h2_cons) = + Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref())); + let (h3_sup, h3_cd, h3_cons) = + Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref())); ProtocolCacheEntry { host: key.host.clone(), @@ -507,7 +519,13 @@ impl ProtocolCache { /// Insert a protocol detection result with an optional H3 port. /// Logs protocol transitions when overwriting an existing entry. /// No suppression check — callers must check before calling. - fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option, reason: &str) { + fn insert_internal( + &self, + key: ProtocolCacheKey, + protocol: DetectedProtocol, + h3_port: Option, + reason: &str, + ) { // Check for existing entry to log protocol transitions if let Some(existing) = self.cache.get(&key) { if existing.protocol != protocol { @@ -522,7 +540,9 @@ impl ProtocolCache { // Evict oldest entry if at capacity if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) { - let oldest = self.cache.iter() + let oldest = self + .cache + .iter() .min_by_key(|entry| entry.value().last_accessed_at) .map(|entry| entry.key().clone()); if let Some(oldest_key) = oldest { @@ -531,13 +551,16 @@ impl ProtocolCache { } let now = Instant::now(); - self.cache.insert(key, CachedEntry { - protocol, - detected_at: now, - last_accessed_at: now, - last_probed_at: now, - h3_port, - }); + self.cache.insert( + key, + CachedEntry { + protocol, + detected_at: now, + last_accessed_at: now, + last_probed_at: now, + h3_port, + }, + ); } /// Reduce a failure record's remaining cooldown to `target`, if it currently @@ -582,26 +605,34 @@ impl ProtocolCache { interval.tick().await; // Clean expired cache entries (sliding TTL based on last_accessed_at) - let expired: Vec = cache.iter() + let expired: Vec = cache + .iter() .filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL) .map(|entry| entry.key().clone()) .collect(); if !expired.is_empty() { - debug!("Protocol cache cleanup: removing {} expired entries", expired.len()); + debug!( + "Protocol cache cleanup: removing {} expired entries", + expired.len() + ); for key in expired { cache.remove(&key); } } // Clean fully-expired failure entries - let expired_failures: Vec = failures.iter() + let expired_failures: Vec = failures + .iter() .filter(|entry| entry.value().all_expired()) .map(|entry| entry.key().clone()) .collect(); if !expired_failures.is_empty() { - debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len()); + debug!( + "Protocol cache cleanup: removing {} expired failure entries", + expired_failures.len() + ); for key in expired_failures { failures.remove(&key); } @@ -609,7 +640,8 @@ impl ProtocolCache { // Safety net: cap failures map at 2× max entries if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 { - let oldest: Vec = failures.iter() + let oldest: Vec = failures + .iter() .filter(|e| e.value().all_expired()) .map(|e| e.key().clone()) .take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES) diff --git a/rust/crates/rustproxy-http/src/proxy_service.rs b/rust/crates/rustproxy-http/src/proxy_service.rs index be305ef..3bf79e1 100644 --- a/rust/crates/rustproxy-http/src/proxy_service.rs +++ b/rust/crates/rustproxy-http/src/proxy_service.rs @@ -5,14 +5,14 @@ //! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade. use std::collections::HashMap; -use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use arc_swap::ArcSwap; use bytes::Bytes; use dashmap::DashMap; use http_body::Body as HttpBody; -use http_body_util::{BodyExt, Full, combinators::BoxBody}; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::body::Incoming; use hyper::{Request, Response, StatusCode}; use hyper_util::rt::TokioIo; @@ -24,12 +24,13 @@ use tracing::{debug, error, info, warn}; use std::pin::Pin; use std::task::{Context, Poll}; -use rustproxy_routing::RouteManager; use rustproxy_metrics::MetricsCollector; +use rustproxy_routing::RouteManager; use rustproxy_security::RateLimiter; use crate::counting_body::{CountingBody, Direction}; use crate::request_filter::RequestFilter; +use crate::request_host::extract_request_host; use crate::response_filter::ResponseFilter; use crate::upstream_selector::UpstreamSelector; @@ -121,12 +122,20 @@ pub(crate) struct ProtocolGuard { impl ProtocolGuard { pub fn frontend(metrics: Arc, version: &'static str) -> Self { metrics.frontend_protocol_opened(version); - Self { metrics, version, is_frontend: true } + Self { + metrics, + version, + is_frontend: true, + } } pub fn backend(metrics: Arc, version: &'static str) -> Self { metrics.backend_protocol_opened(version); - Self { metrics, version, is_frontend: false } + Self { + metrics, + version, + is_frontend: false, + } } } @@ -153,7 +162,10 @@ pub(crate) struct FrontendProtocolTracker { impl FrontendProtocolTracker { fn new(metrics: Arc) -> Self { - Self { metrics, proto: std::sync::OnceLock::new() } + Self { + metrics, + proto: std::sync::OnceLock::new(), + } } /// Set the frontend protocol. Only the first call opens the counter. @@ -249,7 +261,7 @@ async fn connect_tls_backend( stream.set_nodelay(true)?; // Apply keepalive with 60s default let _ = socket2::SockRef::from(&stream).set_tcp_keepalive( - &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) + &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)), ); let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?; @@ -377,7 +389,8 @@ impl HttpProxyService { /// Prune caches for route IDs that are no longer active. /// Call after route updates to prevent unbounded growth. pub fn prune_stale_routes(&self, active_route_ids: &std::collections::HashSet) { - self.route_rate_limiters.retain(|k, _| active_route_ids.contains(k)); + self.route_rate_limiters + .retain(|k, _| active_route_ids.contains(k)); self.regex_cache.clear(); self.upstream_selector.reset_round_robin(); self.protocol_cache.clear(); @@ -429,8 +442,7 @@ impl HttpProxyService { peer_addr: std::net::SocketAddr, port: u16, cancel: CancellationToken, - ) - where + ) where I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { let io = TokioIo::new(stream); @@ -472,7 +484,13 @@ impl HttpProxyService { let cn = cancel_inner.clone(); let la = Arc::clone(&la_inner); let st = start; - let ca = ConnActivity { last_activity: Arc::clone(&la_inner), start, active_requests: Some(Arc::clone(&ar_inner)), alt_svc_cache_key: None, alt_svc_request_url: None }; + let ca = ConnActivity { + last_activity: Arc::clone(&la_inner), + start, + active_requests: Some(Arc::clone(&ar_inner)), + alt_svc_cache_key: None, + alt_svc_request_url: None, + }; async move { let req = req.map(|body| BoxBody::new(body)); let result = svc.handle_request(req, peer, port, cn, ca).await; @@ -485,12 +503,14 @@ impl HttpProxyService { // Auto-detect h1 vs h2 based on ALPN / connection preface. // serve_connection_with_upgrades supports h1 Upgrade (WebSocket) and h2 Extended CONNECT (RFC 8441). - let mut builder = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + let mut builder = + hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); // Configure H2 server settings: Extended CONNECT for WebSocket + flow control tuning - builder.http2() + builder + .http2() .enable_connect_protocol() - .initial_stream_window_size(2 * 1024 * 1024) // 2MB per stream (vs default 64KB) - .initial_connection_window_size(8 * 1024 * 1024); // 8MB per client connection + .initial_stream_window_size(2 * 1024 * 1024) // 2MB per stream (vs default 64KB) + .initial_connection_window_size(8 * 1024 * 1024); // 8MB per client connection let conn = builder.serve_connection_with_upgrades(io, service); // Pin on the heap — auto::UpgradeableConnection is !Unpin let mut conn = Box::pin(conn); @@ -564,16 +584,7 @@ impl HttpProxyService { cancel: CancellationToken, mut conn_activity: ConnActivity, ) -> Result>, hyper::Error> { - let host = req.headers() - .get("host") - .and_then(|h| h.to_str().ok()) - .map(|h| { - // Strip port from host header - h.split(':').next().unwrap_or(h).to_string() - }) - // HTTP/2 uses :authority pseudo-header instead of Host; - // hyper maps it to the URI authority component - .or_else(|| req.uri().host().map(|h| h.to_string())); + let host = extract_request_host(&req).map(str::to_string); let path = req.uri().path().to_string(); let method = req.method().clone(); @@ -584,15 +595,20 @@ impl HttpProxyService { let current_rm = self.route_manager.load(); let needs_headers = current_rm.any_route_has_headers(port); let headers: Option> = if needs_headers { - Some(req.headers() - .iter() - .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) - .collect()) + Some( + req.headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(), + ) } else { None }; - debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr); + debug!( + "HTTP {} {} (host: {:?}) from {}", + method, path, host, peer_addr + ); // Check for CORS preflight if method == hyper::Method::OPTIONS { @@ -633,7 +649,9 @@ impl HttpProxyService { // Apply request filters (IP check, rate limiting, auth) if let Some(ref security) = route_match.route.security { // Look up or create a shared rate limiter for this route - let rate_limiter = security.rate_limit.as_ref() + let rate_limiter = security + .rate_limit + .as_ref() .filter(|rl| rl.enabled) .map(|rl| { let route_key = route_id.unwrap_or("__default__").to_string(); @@ -643,7 +661,10 @@ impl HttpProxyService { .clone() }); if let Some(response) = RequestFilter::apply_with_rate_limiter( - security, &req, &peer_addr, rate_limiter.as_ref(), + security, + &req, + &peer_addr, + rate_limiter.as_ref(), ) { return Ok(response); } @@ -655,7 +676,8 @@ impl HttpProxyService { let last_cleanup = self.last_rate_limiter_cleanup_ms.load(Ordering::Relaxed); let time_triggered = now_ms.saturating_sub(last_cleanup) >= 60_000; if count % 1000 == 0 || time_triggered { - self.last_rate_limiter_cleanup_ms.store(now_ms, Ordering::Relaxed); + self.last_rate_limiter_cleanup_ms + .store(now_ms, Ordering::Relaxed); for entry in self.route_rate_limiters.iter() { entry.value().cleanup(); } @@ -679,7 +701,10 @@ impl HttpProxyService { let target = match route_match.target { Some(t) => t, None => { - return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "No target available", + )); } }; @@ -697,14 +722,16 @@ impl HttpProxyService { self.upstream_selector.connection_started(&upstream_key); // Check for WebSocket upgrade: H1 (Upgrade header) or H2 Extended CONNECT (RFC 8441) - let is_h1_websocket = req.headers() + let is_h1_websocket = req + .headers() .get("upgrade") .and_then(|v| v.to_str().ok()) .map(|v| v.eq_ignore_ascii_case("websocket")) .unwrap_or(false); let is_h2_websocket = req.method() == hyper::Method::CONNECT - && req.extensions() + && req + .extensions() .get::() .map(|p| p.as_str().eq_ignore_ascii_case("websocket")) .unwrap_or(false); @@ -713,17 +740,35 @@ impl HttpProxyService { // WebSocket tunnels additionally get their own "ws" guards in the spawned task. if is_h1_websocket || is_h2_websocket { - let result = self.handle_websocket_upgrade( - req, peer_addr, &upstream, route_match.route, route_id, &upstream_key, cancel, &ip_str, is_h2_websocket, - if is_h2_websocket { Some(conn_activity.clone()) } else { None }, - ).await; + let result = self + .handle_websocket_upgrade( + req, + peer_addr, + &upstream, + route_match.route, + route_id, + &upstream_key, + cancel, + &ip_str, + is_h2_websocket, + if is_h2_websocket { + Some(conn_activity.clone()) + } else { + None + }, + ) + .await; // Note: for WebSocket, connection_ended is called inside // the spawned tunnel task when the connection closes. return result; } // Determine backend protocol mode - let backend_protocol_mode = route_match.route.action.options.as_ref() + let backend_protocol_mode = route_match + .route + .action + .options + .as_ref() .and_then(|o| o.backend_protocol.as_ref()) .cloned() .unwrap_or(rustproxy_config::BackendProtocol::Auto); @@ -775,10 +820,18 @@ impl HttpProxyService { // Add standard reverse-proxy headers (X-Forwarded-*) { let original_host = host.as_deref().unwrap_or(""); - let forwarded_proto = if route_match.route.action.tls.as_ref() - .map(|t| matches!(t.mode, - rustproxy_config::TlsMode::Terminate - | rustproxy_config::TlsMode::TerminateAndReencrypt)) + let forwarded_proto = if route_match + .route + .action + .tls + .as_ref() + .map(|t| { + matches!( + t.mode, + rustproxy_config::TlsMode::Terminate + | rustproxy_config::TlsMode::TerminateAndReencrypt + ) + }) .unwrap_or(false) { "https" @@ -815,7 +868,10 @@ impl HttpProxyService { } // --- Resolve protocol decision based on backend protocol mode --- - let is_auto_detect_mode = matches!(backend_protocol_mode, rustproxy_config::BackendProtocol::Auto); + let is_auto_detect_mode = matches!( + backend_protocol_mode, + rustproxy_config::BackendProtocol::Auto + ); let protocol_cache_key = crate::protocol_cache::ProtocolCacheKey { host: upstream.host.clone(), port: upstream.port, @@ -823,7 +879,9 @@ impl HttpProxyService { }; // Save cached H3 port for within-request escalation (may be needed later // if TCP connect fails and we escalate to H3 as a last resort) - let cached_h3_port = self.protocol_cache.get(&protocol_cache_key) + let cached_h3_port = self + .protocol_cache + .get(&protocol_cache_key) .and_then(|c| c.h3_port); // Track whether this ALPN probe is a periodic re-probe (vs first-time detection) @@ -832,7 +890,9 @@ impl HttpProxyService { let protocol_decision = match backend_protocol_mode { rustproxy_config::BackendProtocol::Http1 => ProtocolDecision::H1, rustproxy_config::BackendProtocol::Http2 => ProtocolDecision::H2, - rustproxy_config::BackendProtocol::Http3 => ProtocolDecision::H3 { port: upstream.port }, + rustproxy_config::BackendProtocol::Http3 => ProtocolDecision::H3 { + port: upstream.port, + }, rustproxy_config::BackendProtocol::Auto => { if !upstream.use_tls { // No ALPN without TLS, no QUIC without TLS — default to H1 @@ -847,7 +907,10 @@ impl HttpProxyService { } Some(cached) => match cached.protocol { crate::protocol_cache::DetectedProtocol::H3 => { - if self.protocol_cache.is_suppressed(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) { + if self.protocol_cache.is_suppressed( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ) { // H3 cached but suppressed — fall back to ALPN probe ProtocolDecision::AlpnProbe } else if let Some(h3_port) = cached.h3_port { @@ -857,7 +920,10 @@ impl HttpProxyService { } } crate::protocol_cache::DetectedProtocol::H2 => { - if self.protocol_cache.is_suppressed(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H2) { + if self.protocol_cache.is_suppressed( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H2, + ) { ProtocolDecision::H1 } else { ProtocolDecision::H2 @@ -867,7 +933,10 @@ impl HttpProxyService { }, None => { // Cache miss — skip ALPN probe if H2 is suppressed - if self.protocol_cache.is_suppressed(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H2) { + if self.protocol_cache.is_suppressed( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H2, + ) { ProtocolDecision::H1 } else { ProtocolDecision::AlpnProbe @@ -903,12 +972,27 @@ impl HttpProxyService { }; // Try H3 pool checkout first - if let Some((pooled_sr, quic_conn, _age)) = self.connection_pool.checkout_h3(&h3_pool_key) { + if let Some((pooled_sr, quic_conn, _age)) = + self.connection_pool.checkout_h3(&h3_pool_key) + { self.metrics.backend_pool_hit(&upstream_key); - let result = self.forward_h3( - quic_conn, Some(pooled_sr), parts, body, upstream_headers, &upstream_path, - route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key, - ).await; + let result = self + .forward_h3( + quic_conn, + Some(pooled_sr), + parts, + body, + upstream_headers, + &upstream_path, + route_match.route, + route_id, + &ip_str, + &h3_pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await; self.upstream_selector.connection_ended(&upstream_key); return result; } @@ -917,11 +1001,27 @@ impl HttpProxyService { match self.connect_quic_backend(&upstream.host, h3_port).await { Ok(quic_conn) => { self.metrics.backend_pool_miss(&upstream_key); - self.metrics.backend_connection_opened(&upstream_key, std::time::Instant::now().elapsed()); - let result = self.forward_h3( - quic_conn, None, parts, body, upstream_headers, &upstream_path, - route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key, - ).await; + self.metrics.backend_connection_opened( + &upstream_key, + std::time::Instant::now().elapsed(), + ); + let result = self + .forward_h3( + quic_conn, + None, + parts, + body, + upstream_headers, + &upstream_path, + route_match.route, + route_id, + &ip_str, + &h3_pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await; self.upstream_selector.connection_ended(&upstream_key); return result; } @@ -956,7 +1056,11 @@ impl HttpProxyService { host: upstream.host.clone(), port: upstream.port, use_tls: upstream.use_tls, - protocol: if use_h2 { crate::connection_pool::PoolProtocol::H2 } else { crate::connection_pool::PoolProtocol::H1 }, + protocol: if use_h2 { + crate::connection_pool::PoolProtocol::H2 + } else { + crate::connection_pool::PoolProtocol::H1 + }, }; // H2 pool checkout — reuse pooled connections for all requests. @@ -967,14 +1071,28 @@ impl HttpProxyService { match tokio::time::timeout( std::time::Duration::from_millis(500), sender.ready(), - ).await { + ) + .await + { Ok(Ok(())) => { self.metrics.backend_pool_hit(&upstream_key); self.metrics.set_backend_protocol(&upstream_key, "h2"); - let result = self.forward_h2_pooled( - sender, parts, body, upstream_headers, &upstream_path, - route_match.route, route_id, &ip_str, &pool_key, domain_str, &conn_activity, &upstream_key, - ).await; + let result = self + .forward_h2_pooled( + sender, + parts, + body, + upstream_headers, + &upstream_path, + route_match.route, + route_id, + &ip_str, + &pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await; self.upstream_selector.connection_ended(&upstream_key); return result; } @@ -1011,7 +1129,9 @@ impl HttpProxyService { match tokio::time::timeout( self.connect_timeout, connect_tls_backend(tls_config, &upstream.host, upstream.port), - ).await { + ) + .await + { Ok(Ok(tls)) => { let final_h2 = if needs_alpn_probe { // Read the ALPN-negotiated protocol from the TLS connection @@ -1030,9 +1150,17 @@ impl HttpProxyService { crate::protocol_cache::DetectedProtocol::H1 }; if is_reprobe { - self.protocol_cache.update_probe_result(&cache_key, detected, "periodic ALPN re-probe"); + self.protocol_cache.update_probe_result( + &cache_key, + detected, + "periodic ALPN re-probe", + ); } else { - self.protocol_cache.insert(cache_key, detected, "initial ALPN detection"); + self.protocol_cache.insert( + cache_key, + detected, + "initial ALPN detection", + ); } info!( @@ -1048,8 +1176,10 @@ impl HttpProxyService { } else { use_h2 }; - self.metrics.backend_connection_opened(&upstream_key, connect_start.elapsed()); - self.metrics.set_backend_protocol(&upstream_key, if final_h2 { "h2" } else { "h1" }); + self.metrics + .backend_connection_opened(&upstream_key, connect_start.elapsed()); + self.metrics + .set_backend_protocol(&upstream_key, if final_h2 { "h2" } else { "h1" }); (BackendStream::Tls(tls), final_h2) } Ok(Err(e)) => { @@ -1066,27 +1196,58 @@ impl HttpProxyService { // --- Within-request escalation: try H3 via QUIC if retryable --- if is_auto_detect_mode { if let Some(h3_port) = cached_h3_port { - if self.protocol_cache.can_retry(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) { - self.protocol_cache.record_retry_attempt(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3); + if self.protocol_cache.can_retry( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ) { + self.protocol_cache.record_retry_attempt( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ); debug!(backend = %upstream_key, domain = %domain_str, "TLS connect failed — escalating to H3"); match self.connect_quic_backend(&upstream.host, h3_port).await { Ok(quic_conn) => { - self.protocol_cache.clear_failure(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3); - self.protocol_cache.insert_h3(protocol_cache_key.clone(), h3_port, "recovery — TLS failed, H3 succeeded"); + self.protocol_cache.clear_failure( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ); + self.protocol_cache.insert_h3( + protocol_cache_key.clone(), + h3_port, + "recovery — TLS failed, H3 succeeded", + ); let h3_pool_key = crate::connection_pool::PoolKey { - host: upstream.host.clone(), port: h3_port, use_tls: true, + host: upstream.host.clone(), + port: h3_port, + use_tls: true, protocol: crate::connection_pool::PoolProtocol::H3, }; - let result = self.forward_h3( - quic_conn, None, parts, body, upstream_headers, &upstream_path, - route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key, - ).await; + let result = self + .forward_h3( + quic_conn, + None, + parts, + body, + upstream_headers, + &upstream_path, + route_match.route, + route_id, + &ip_str, + &h3_pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await; self.upstream_selector.connection_ended(&upstream_key); return result; } Err(e3) => { debug!(backend = %upstream_key, error = %e3, "H3 escalation also failed"); - self.protocol_cache.record_failure(protocol_cache_key.clone(), crate::protocol_cache::DetectedProtocol::H3); + self.protocol_cache.record_failure( + protocol_cache_key.clone(), + crate::protocol_cache::DetectedProtocol::H3, + ); } } } @@ -1094,7 +1255,10 @@ impl HttpProxyService { // All protocols failed — evict cache entry self.protocol_cache.evict(&protocol_cache_key); } - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend TLS unavailable", + )); } Err(_) => { error!( @@ -1109,27 +1273,58 @@ impl HttpProxyService { // --- Within-request escalation: try H3 via QUIC if retryable --- if is_auto_detect_mode { if let Some(h3_port) = cached_h3_port { - if self.protocol_cache.can_retry(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) { - self.protocol_cache.record_retry_attempt(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3); + if self.protocol_cache.can_retry( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ) { + self.protocol_cache.record_retry_attempt( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ); debug!(backend = %upstream_key, domain = %domain_str, "TLS connect timeout — escalating to H3"); match self.connect_quic_backend(&upstream.host, h3_port).await { Ok(quic_conn) => { - self.protocol_cache.clear_failure(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3); - self.protocol_cache.insert_h3(protocol_cache_key.clone(), h3_port, "recovery — TLS timeout, H3 succeeded"); + self.protocol_cache.clear_failure( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ); + self.protocol_cache.insert_h3( + protocol_cache_key.clone(), + h3_port, + "recovery — TLS timeout, H3 succeeded", + ); let h3_pool_key = crate::connection_pool::PoolKey { - host: upstream.host.clone(), port: h3_port, use_tls: true, + host: upstream.host.clone(), + port: h3_port, + use_tls: true, protocol: crate::connection_pool::PoolProtocol::H3, }; - let result = self.forward_h3( - quic_conn, None, parts, body, upstream_headers, &upstream_path, - route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key, - ).await; + let result = self + .forward_h3( + quic_conn, + None, + parts, + body, + upstream_headers, + &upstream_path, + route_match.route, + route_id, + &ip_str, + &h3_pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await; self.upstream_selector.connection_ended(&upstream_key); return result; } Err(e3) => { debug!(backend = %upstream_key, error = %e3, "H3 escalation also failed"); - self.protocol_cache.record_failure(protocol_cache_key.clone(), crate::protocol_cache::DetectedProtocol::H3); + self.protocol_cache.record_failure( + protocol_cache_key.clone(), + crate::protocol_cache::DetectedProtocol::H3, + ); } } } @@ -1137,21 +1332,28 @@ impl HttpProxyService { // All protocols failed — evict cache entry self.protocol_cache.evict(&protocol_cache_key); } - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout")); + return Ok(error_response( + StatusCode::GATEWAY_TIMEOUT, + "Backend TLS connect timeout", + )); } } } else { match tokio::time::timeout( self.connect_timeout, TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), - ).await { + ) + .await + { Ok(Ok(s)) => { s.set_nodelay(true).ok(); let _ = socket2::SockRef::from(&s).set_tcp_keepalive( - &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) + &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)), ); - self.metrics.backend_connection_opened(&upstream_key, connect_start.elapsed()); - self.metrics.set_backend_protocol(&upstream_key, if use_h2 { "h2" } else { "h1" }); + self.metrics + .backend_connection_opened(&upstream_key, connect_start.elapsed()); + self.metrics + .set_backend_protocol(&upstream_key, if use_h2 { "h2" } else { "h1" }); (BackendStream::Plain(s), use_h2) } Ok(Err(e)) => { @@ -1168,27 +1370,58 @@ impl HttpProxyService { // --- Within-request escalation: try H3 via QUIC if retryable --- if is_auto_detect_mode { if let Some(h3_port) = cached_h3_port { - if self.protocol_cache.can_retry(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) { - self.protocol_cache.record_retry_attempt(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3); + if self.protocol_cache.can_retry( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ) { + self.protocol_cache.record_retry_attempt( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ); debug!(backend = %upstream_key, domain = %domain_str, "TCP connect failed — escalating to H3"); match self.connect_quic_backend(&upstream.host, h3_port).await { Ok(quic_conn) => { - self.protocol_cache.clear_failure(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3); - self.protocol_cache.insert_h3(protocol_cache_key.clone(), h3_port, "recovery — TCP failed, H3 succeeded"); + self.protocol_cache.clear_failure( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ); + self.protocol_cache.insert_h3( + protocol_cache_key.clone(), + h3_port, + "recovery — TCP failed, H3 succeeded", + ); let h3_pool_key = crate::connection_pool::PoolKey { - host: upstream.host.clone(), port: h3_port, use_tls: true, + host: upstream.host.clone(), + port: h3_port, + use_tls: true, protocol: crate::connection_pool::PoolProtocol::H3, }; - let result = self.forward_h3( - quic_conn, None, parts, body, upstream_headers, &upstream_path, - route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key, - ).await; + let result = self + .forward_h3( + quic_conn, + None, + parts, + body, + upstream_headers, + &upstream_path, + route_match.route, + route_id, + &ip_str, + &h3_pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await; self.upstream_selector.connection_ended(&upstream_key); return result; } Err(e3) => { debug!(backend = %upstream_key, error = %e3, "H3 escalation also failed"); - self.protocol_cache.record_failure(protocol_cache_key.clone(), crate::protocol_cache::DetectedProtocol::H3); + self.protocol_cache.record_failure( + protocol_cache_key.clone(), + crate::protocol_cache::DetectedProtocol::H3, + ); } } } @@ -1196,7 +1429,10 @@ impl HttpProxyService { // All protocols failed — evict cache entry self.protocol_cache.evict(&protocol_cache_key); } - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend unavailable", + )); } Err(_) => { error!( @@ -1211,27 +1447,58 @@ impl HttpProxyService { // --- Within-request escalation: try H3 via QUIC if retryable --- if is_auto_detect_mode { if let Some(h3_port) = cached_h3_port { - if self.protocol_cache.can_retry(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) { - self.protocol_cache.record_retry_attempt(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3); + if self.protocol_cache.can_retry( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ) { + self.protocol_cache.record_retry_attempt( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ); debug!(backend = %upstream_key, domain = %domain_str, "TCP connect timeout — escalating to H3"); match self.connect_quic_backend(&upstream.host, h3_port).await { Ok(quic_conn) => { - self.protocol_cache.clear_failure(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3); - self.protocol_cache.insert_h3(protocol_cache_key.clone(), h3_port, "recovery — TCP timeout, H3 succeeded"); + self.protocol_cache.clear_failure( + &protocol_cache_key, + crate::protocol_cache::DetectedProtocol::H3, + ); + self.protocol_cache.insert_h3( + protocol_cache_key.clone(), + h3_port, + "recovery — TCP timeout, H3 succeeded", + ); let h3_pool_key = crate::connection_pool::PoolKey { - host: upstream.host.clone(), port: h3_port, use_tls: true, + host: upstream.host.clone(), + port: h3_port, + use_tls: true, protocol: crate::connection_pool::PoolProtocol::H3, }; - let result = self.forward_h3( - quic_conn, None, parts, body, upstream_headers, &upstream_path, - route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key, - ).await; + let result = self + .forward_h3( + quic_conn, + None, + parts, + body, + upstream_headers, + &upstream_path, + route_match.route, + route_id, + &ip_str, + &h3_pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await; self.upstream_selector.connection_ended(&upstream_key); return result; } Err(e3) => { debug!(backend = %upstream_key, error = %e3, "H3 escalation also failed"); - self.protocol_cache.record_failure(protocol_cache_key.clone(), crate::protocol_cache::DetectedProtocol::H3); + self.protocol_cache.record_failure( + protocol_cache_key.clone(), + crate::protocol_cache::DetectedProtocol::H3, + ); } } } @@ -1239,7 +1506,10 @@ impl HttpProxyService { // All protocols failed — evict cache entry self.protocol_cache.evict(&protocol_cache_key); } - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); + return Ok(error_response( + StatusCode::GATEWAY_TIMEOUT, + "Backend connect timeout", + )); } } }; @@ -1248,7 +1518,11 @@ impl HttpProxyService { host: upstream.host.clone(), port: upstream.port, use_tls: upstream.use_tls, - protocol: if detected_h2 { crate::connection_pool::PoolProtocol::H2 } else { crate::connection_pool::PoolProtocol::H1 }, + protocol: if detected_h2 { + crate::connection_pool::PoolProtocol::H2 + } else { + crate::connection_pool::PoolProtocol::H1 + }, }; let io = TokioIo::new(backend); @@ -1257,22 +1531,58 @@ impl HttpProxyService { if is_auto_detect_mode { // Auto-detect mode: use fallback-capable H2 forwarding self.forward_h2_with_fallback( - io, parts, body, upstream_headers, &upstream_path, - &upstream, route_match.route, route_id, &ip_str, &final_pool_key, - host.clone(), domain_str, &conn_activity, &upstream_key, - ).await + io, + parts, + body, + upstream_headers, + &upstream_path, + &upstream, + route_match.route, + route_id, + &ip_str, + &final_pool_key, + host.clone(), + domain_str, + &conn_activity, + &upstream_key, + ) + .await } else { // Explicit H2 mode: hard-fail on handshake error (preserved behavior) self.forward_h2( - io, parts, body, upstream_headers, &upstream_path, - &upstream, route_match.route, route_id, &ip_str, &final_pool_key, domain_str, &conn_activity, &upstream_key, - ).await + io, + parts, + body, + upstream_headers, + &upstream_path, + &upstream, + route_match.route, + route_id, + &ip_str, + &final_pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await } } else { self.forward_h1( - io, parts, body, upstream_headers, &upstream_path, - &upstream, route_match.route, route_id, &ip_str, &final_pool_key, domain_str, &conn_activity, &upstream_key, - ).await + io, + parts, + body, + upstream_headers, + &upstream_path, + &upstream, + route_match.route, + route_id, + &ip_str, + &final_pool_key, + domain_str, + &conn_activity, + &upstream_key, + ) + .await }; self.upstream_selector.connection_ended(&upstream_key); self.metrics.backend_connection_closed(&upstream_key); @@ -1301,22 +1611,39 @@ impl HttpProxyService { // Try pooled H1 connection first — avoids TCP+TLS handshake if let Some(pooled_sender) = self.connection_pool.checkout_h1(pool_key) { self.metrics.backend_pool_hit(backend_key); - return self.forward_h1_with_sender( - pooled_sender, parts, body, upstream_headers, upstream_path, - route, route_id, source_ip, domain, conn_activity, backend_key, - ).await; + return self + .forward_h1_with_sender( + pooled_sender, + parts, + body, + upstream_headers, + upstream_path, + route, + route_id, + source_ip, + domain, + conn_activity, + backend_key, + ) + .await; } // Fresh connection: explicitly type the handshake with BoxBody for uniform pool type let (sender, conn): ( hyper::client::conn::http1::SendRequest>, - hyper::client::conn::http1::Connection, BoxBody>, + hyper::client::conn::http1::Connection< + TokioIo, + BoxBody, + >, ) = match hyper::client::conn::http1::handshake(io).await { Ok(h) => h, Err(e) => { error!(backend = %backend_key, domain = %domain, error = %e, "Backend H1 handshake failed"); self.metrics.backend_handshake_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend handshake failed", + )); } }; @@ -1333,7 +1660,20 @@ impl HttpProxyService { }); } - self.forward_h1_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip, domain, conn_activity, backend_key).await + self.forward_h1_with_sender( + sender, + parts, + body, + upstream_headers, + upstream_path, + route, + route_id, + source_ip, + domain, + conn_activity, + backend_key, + ) + .await } /// Common H1 forwarding logic used by both fresh and pooled paths. @@ -1372,7 +1712,11 @@ impl HttpProxyService { rid.clone(), Some(Arc::clone(&sip)), Direction::In, - ).with_connection_activity(Arc::clone(&conn_activity.last_activity), conn_activity.start); + ) + .with_connection_activity( + Arc::clone(&conn_activity.last_activity), + conn_activity.start, + ); let boxed_body: BoxBody = BoxBody::new(counting_req_body); let upstream_req = upstream_req.body(boxed_body).unwrap(); @@ -1382,7 +1726,10 @@ impl HttpProxyService { Err(e) => { error!(backend = %backend_key, domain = %domain, error = %e, "Backend H1 request failed"); self.metrics.backend_request_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend request failed", + )); } }; @@ -1396,7 +1743,8 @@ impl HttpProxyService { // of large streaming responses (e.g. 352MB Docker layers) takes priority. drop(sender); - self.build_streaming_response(upstream_response, route, rid, sip, conn_activity).await + self.build_streaming_response(upstream_response, route, rid, sip, conn_activity) + .await } /// Forward request to backend via HTTP/2 with body streaming (fresh connection). @@ -1427,18 +1775,28 @@ impl HttpProxyService { .initial_connection_window_size(16 * 1024 * 1024); let (sender, conn): ( hyper::client::conn::http2::SendRequest>, - hyper::client::conn::http2::Connection, BoxBody, hyper_util::rt::TokioExecutor>, + hyper::client::conn::http2::Connection< + TokioIo, + BoxBody, + hyper_util::rt::TokioExecutor, + >, ) = match tokio::time::timeout(self.connect_timeout, h2_builder.handshake(io)).await { Ok(Ok(h)) => h, Ok(Err(e)) => { error!(backend = %backend_key, domain = %domain, error = %e, error_debug = ?e, "Backend H2 handshake failed"); self.metrics.backend_handshake_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend H2 handshake failed", + )); } Err(_) => { error!(backend = %backend_key, domain = %domain, "Backend H2 handshake timeout"); self.metrics.backend_handshake_error(backend_key); - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend H2 handshake timeout")); + return Ok(error_response( + StatusCode::GATEWAY_TIMEOUT, + "Backend H2 handshake timeout", + )); } }; @@ -1467,9 +1825,26 @@ impl HttpProxyService { } let sender_for_pool = sender.clone(); - let result = self.forward_h2_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip, Some(pool_key), domain, conn_activity, backend_key).await; + let result = self + .forward_h2_with_sender( + sender, + parts, + body, + upstream_headers, + upstream_path, + route, + route_id, + source_ip, + Some(pool_key), + domain, + conn_activity, + backend_key, + ) + .await; if matches!(&result, Ok(ref resp) if resp.status() != StatusCode::BAD_GATEWAY) { - let g = self.connection_pool.register_h2(pool_key.clone(), sender_for_pool); + let g = self + .connection_pool + .register_h2(pool_key.clone(), sender_for_pool); gen_holder.store(g, std::sync::atomic::Ordering::Relaxed); } result @@ -1500,10 +1875,22 @@ impl HttpProxyService { None }; - let result = self.forward_h2_with_sender( - sender, parts, body, upstream_headers, upstream_path, - route, route_id, source_ip, Some(pool_key), domain, conn_activity, backend_key, - ).await; + let result = self + .forward_h2_with_sender( + sender, + parts, + body, + upstream_headers, + upstream_path, + route, + route_id, + source_ip, + Some(pool_key), + domain, + conn_activity, + backend_key, + ) + .await; // If the request failed (502) and we can retry with an empty body, do so let is_502 = matches!(&result, Ok(resp) if resp.status() == StatusCode::BAD_GATEWAY); @@ -1511,10 +1898,20 @@ impl HttpProxyService { if let Some((method, headers)) = retry_state { warn!(backend = %backend_key, domain = %domain, "Stale pooled H2 sender, retrying with fresh connection"); - return self.retry_h2_with_fresh_connection( - method, headers, upstream_path, - pool_key, route, route_id, source_ip, domain, conn_activity, backend_key, - ).await; + return self + .retry_h2_with_fresh_connection( + method, + headers, + upstream_path, + pool_key, + route, + route_id, + source_ip, + domain, + conn_activity, + backend_key, + ) + .await; } } result @@ -1535,31 +1932,40 @@ impl HttpProxyService { conn_activity: &ConnActivity, backend_key: &str, ) -> Result>, hyper::Error> { - // Establish fresh backend connection let retry_connect_start = std::time::Instant::now(); let backend = if pool_key.use_tls { match tokio::time::timeout( self.connect_timeout, connect_tls_backend(&self.backend_tls_config_alpn, &pool_key.host, pool_key.port), - ).await { + ) + .await + { Ok(Ok(tls)) => BackendStream::Tls(tls), Ok(Err(e)) => { error!(backend = %backend_key, domain = %domain, error = %e, "H2 retry: TLS connect failed"); self.metrics.backend_connect_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable on H2 retry")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend unavailable on H2 retry", + )); } Err(_) => { error!(backend = %backend_key, domain = %domain, "H2 retry: TLS connect timeout"); self.metrics.backend_connect_error(backend_key); - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend timeout on H2 retry")); + return Ok(error_response( + StatusCode::GATEWAY_TIMEOUT, + "Backend timeout on H2 retry", + )); } } } else { match tokio::time::timeout( self.connect_timeout, TcpStream::connect(format!("{}:{}", pool_key.host, pool_key.port)), - ).await { + ) + .await + { Ok(Ok(s)) => { s.set_nodelay(true).ok(); BackendStream::Plain(s) @@ -1567,16 +1973,23 @@ impl HttpProxyService { Ok(Err(e)) => { error!(backend = %backend_key, domain = %domain, error = %e, "H2 retry: TCP connect failed"); self.metrics.backend_connect_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable on H2 retry")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend unavailable on H2 retry", + )); } Err(_) => { error!(backend = %backend_key, domain = %domain, "H2 retry: TCP connect timeout"); self.metrics.backend_connect_error(backend_key); - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend timeout on H2 retry")); + return Ok(error_response( + StatusCode::GATEWAY_TIMEOUT, + "Backend timeout on H2 retry", + )); } } }; - self.metrics.backend_connection_opened(backend_key, retry_connect_start.elapsed()); + self.metrics + .backend_connection_opened(backend_key, retry_connect_start.elapsed()); let io = TokioIo::new(backend); let exec = hyper_util::rt::TokioExecutor::new(); @@ -1589,20 +2002,30 @@ impl HttpProxyService { .initial_connection_window_size(16 * 1024 * 1024); let (mut sender, conn): ( hyper::client::conn::http2::SendRequest>, - hyper::client::conn::http2::Connection, BoxBody, hyper_util::rt::TokioExecutor>, + hyper::client::conn::http2::Connection< + TokioIo, + BoxBody, + hyper_util::rt::TokioExecutor, + >, ) = match tokio::time::timeout(self.connect_timeout, h2_builder.handshake(io)).await { Ok(Ok(h)) => h, Ok(Err(e)) => { error!(backend = %backend_key, domain = %domain, error = %e, error_debug = ?e, "H2 retry: handshake failed"); self.metrics.backend_handshake_error(backend_key); self.metrics.backend_connection_closed(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 retry handshake failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend H2 retry handshake failed", + )); } Err(_) => { error!(backend = %backend_key, domain = %domain, "H2 retry: handshake timeout"); self.metrics.backend_handshake_error(backend_key); self.metrics.backend_connection_closed(backend_key); - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend H2 retry handshake timeout")); + return Ok(error_response( + StatusCode::GATEWAY_TIMEOUT, + "Backend H2 retry handshake timeout", + )); } }; @@ -1625,11 +2048,13 @@ impl HttpProxyService { // Build request with empty body using absolute URI for H2 pseudo-headers let scheme = if pool_key.use_tls { "https" } else { "http" }; - let authority = if domain != "-" { domain } else { pool_key.host.as_str() }; + let authority = if domain != "-" { + domain + } else { + pool_key.host.as_str() + }; let h2_uri = format!("{}://{}{}", scheme, authority, upstream_path); - let mut upstream_req = Request::builder() - .method(method) - .uri(&h2_uri); + let mut upstream_req = Request::builder().method(method).uri(&h2_uri); // Remove Host header for H2 — :authority pseudo-header (from URI) is sufficient let mut upstream_headers = upstream_headers; @@ -1639,9 +2064,8 @@ impl HttpProxyService { *headers = upstream_headers; } - let empty_body: BoxBody = BoxBody::new( - http_body_util::Empty::new().map_err(|never| match never {}) - ); + let empty_body: BoxBody = + BoxBody::new(http_body_util::Empty::new().map_err(|never| match never {})); let upstream_req = upstream_req.body(empty_body).unwrap(); match sender.send_request(upstream_req).await { @@ -1649,7 +2073,15 @@ impl HttpProxyService { // Register in pool only after request succeeds let g = self.connection_pool.register_h2(pool_key.clone(), sender); gen_holder.store(g, std::sync::atomic::Ordering::Relaxed); - let result = self.build_streaming_response(resp, route, route_id.map(Arc::from), Arc::from(source_ip), conn_activity).await; + let result = self + .build_streaming_response( + resp, + route, + route_id.map(Arc::from), + Arc::from(source_ip), + conn_activity, + ) + .await; // Close the fresh backend connection (opened above) self.metrics.backend_connection_closed(backend_key); result @@ -1659,7 +2091,10 @@ impl HttpProxyService { self.metrics.backend_request_error(backend_key); // Close the fresh backend connection (opened above) self.metrics.backend_connection_closed(backend_key); - Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed on retry")) + Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend H2 request failed on retry", + )) } } } @@ -1696,10 +2131,8 @@ impl HttpProxyService { .keep_alive_timeout(std::time::Duration::from_secs(30)) .initial_stream_window_size(2 * 1024 * 1024) .initial_connection_window_size(16 * 1024 * 1024); - let handshake_result = tokio::time::timeout( - self.connect_timeout, - h2_builder.handshake(io), - ).await; + let handshake_result = + tokio::time::timeout(self.connect_timeout, h2_builder.handshake(io)).await; match handshake_result { Err(_) => { @@ -1722,7 +2155,11 @@ impl HttpProxyService { cache_key.clone(), crate::protocol_cache::DetectedProtocol::H2, ); - self.protocol_cache.insert(cache_key.clone(), crate::protocol_cache::DetectedProtocol::H1, "H2 handshake timeout — downgrade"); + self.protocol_cache.insert( + cache_key.clone(), + crate::protocol_cache::DetectedProtocol::H1, + "H2 handshake timeout — downgrade", + ); match self.reconnect_backend(upstream, domain, backend_key).await { Some(fallback_backend) => { @@ -1733,17 +2170,33 @@ impl HttpProxyService { protocol: crate::connection_pool::PoolProtocol::H1, }; let fallback_io = TokioIo::new(fallback_backend); - let result = self.forward_h1( - fallback_io, parts, body, upstream_headers, upstream_path, - upstream, route, route_id, source_ip, &h1_pool_key, domain, conn_activity, backend_key, - ).await; + let result = self + .forward_h1( + fallback_io, + parts, + body, + upstream_headers, + upstream_path, + upstream, + route, + route_id, + source_ip, + &h1_pool_key, + domain, + conn_activity, + backend_key, + ) + .await; self.metrics.backend_connection_closed(backend_key); result } None => { // H2 failed and H1 reconnect also failed — evict cache self.protocol_cache.evict(&cache_key); - Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable after H2 timeout fallback")) + Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend unavailable after H2 timeout fallback", + )) } } } @@ -1783,11 +2236,13 @@ impl HttpProxyService { // we need to verify the request actually succeeds first, because some // backends advertise h2 via ALPN but don't speak the h2 binary protocol). let scheme = if upstream.use_tls { "https" } else { "http" }; - let authority = if domain != "-" { domain } else { upstream.host.as_str() }; + let authority = if domain != "-" { + domain + } else { + upstream.host.as_str() + }; let h2_uri = format!("{}://{}{}", scheme, authority, upstream_path); - let mut upstream_req = Request::builder() - .method(parts.method) - .uri(&h2_uri); + let mut upstream_req = Request::builder().method(parts.method).uri(&h2_uri); if let Some(headers) = upstream_req.headers_mut() { *headers = upstream_headers; @@ -1801,7 +2256,11 @@ impl HttpProxyService { rid.clone(), Some(Arc::clone(&sip)), Direction::In, - ).with_connection_activity(Arc::clone(&conn_activity.last_activity), conn_activity.start); + ) + .with_connection_activity( + Arc::clone(&conn_activity.last_activity), + conn_activity.start, + ); let boxed_body: BoxBody = BoxBody::new(counting_req_body); let upstream_req = upstream_req.body(boxed_body).unwrap(); @@ -1809,7 +2268,14 @@ impl HttpProxyService { Ok(upstream_response) => { let g = self.connection_pool.register_h2(pool_key.clone(), sender); gen_holder.store(g, std::sync::atomic::Ordering::Relaxed); - self.build_streaming_response(upstream_response, route, rid, sip, conn_activity).await + self.build_streaming_response( + upstream_response, + route, + rid, + sip, + conn_activity, + ) + .await } Err(e) => { // H2 request failed on a stream level (e.g. RST_STREAM PROTOCOL_ERROR). @@ -1830,20 +2296,34 @@ impl HttpProxyService { match self.reconnect_backend(upstream, domain, backend_key).await { Some(fallback_backend) => { let fallback_io = TokioIo::new(fallback_backend); - let result = self.forward_h1_empty_body( - fallback_io, method, headers, upstream_path, - route, route_id, source_ip, domain, conn_activity, backend_key, - ).await; + let result = self + .forward_h1_empty_body( + fallback_io, + method, + headers, + upstream_path, + route, + route_id, + source_ip, + domain, + conn_activity, + backend_key, + ) + .await; // Close the reconnected backend connection (opened in reconnect_backend) self.metrics.backend_connection_closed(backend_key); result } - None => { - Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable after H2 fallback")) - } + None => Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend unavailable after H2 fallback", + )), } } else { - Ok(error_response(StatusCode::BAD_GATEWAY, "Backend protocol mismatch")) + Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend protocol mismatch", + )) } } } @@ -1870,7 +2350,11 @@ impl HttpProxyService { cache_key.clone(), crate::protocol_cache::DetectedProtocol::H2, ); - self.protocol_cache.insert(cache_key.clone(), crate::protocol_cache::DetectedProtocol::H1, "H2 handshake error — downgrade"); + self.protocol_cache.insert( + cache_key.clone(), + crate::protocol_cache::DetectedProtocol::H1, + "H2 handshake error — downgrade", + ); // Reconnect for H1 (the original io was consumed by the failed h2 handshake) match self.reconnect_backend(upstream, domain, backend_key).await { @@ -1882,10 +2366,23 @@ impl HttpProxyService { protocol: crate::connection_pool::PoolProtocol::H1, }; let fallback_io = TokioIo::new(fallback_backend); - let result = self.forward_h1( - fallback_io, parts, body, upstream_headers, upstream_path, - upstream, route, route_id, source_ip, &h1_pool_key, domain, conn_activity, backend_key, - ).await; + let result = self + .forward_h1( + fallback_io, + parts, + body, + upstream_headers, + upstream_path, + upstream, + route, + route_id, + source_ip, + &h1_pool_key, + domain, + conn_activity, + backend_key, + ) + .await; // Close the reconnected backend connection (opened in reconnect_backend) self.metrics.backend_connection_closed(backend_key); result @@ -1893,7 +2390,10 @@ impl HttpProxyService { None => { // H2 failed and H1 reconnect also failed — evict cache self.protocol_cache.evict(&cache_key); - Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable after H2 fallback")) + Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend unavailable after H2 fallback", + )) } } } @@ -1917,13 +2417,19 @@ impl HttpProxyService { ) -> Result>, hyper::Error> { let (mut sender, conn): ( hyper::client::conn::http1::SendRequest>, - hyper::client::conn::http1::Connection, BoxBody>, + hyper::client::conn::http1::Connection< + TokioIo, + BoxBody, + >, ) = match hyper::client::conn::http1::handshake(io).await { Ok(h) => h, Err(e) => { error!(backend = %backend_key, domain = %domain, error = %e, "H1 fallback: handshake failed"); self.metrics.backend_handshake_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H1 fallback handshake failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend H1 fallback handshake failed", + )); } }; @@ -1949,9 +2455,8 @@ impl HttpProxyService { *headers = upstream_headers; } - let empty_body: BoxBody = BoxBody::new( - http_body_util::Empty::new().map_err(|never| match never {}) - ); + let empty_body: BoxBody = + BoxBody::new(http_body_util::Empty::new().map_err(|never| match never {})); let upstream_req = upstream_req.body(empty_body).unwrap(); let upstream_response = match sender.send_request(upstream_req).await { @@ -1959,14 +2464,24 @@ impl HttpProxyService { Err(e) => { error!(backend = %backend_key, domain = %domain, error = %e, "H1 fallback: request failed"); self.metrics.backend_request_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H1 fallback request failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend H1 fallback request failed", + )); } }; // Don't pool the sender while response body is still streaming (same safety as forward_h1_with_sender) drop(sender); - self.build_streaming_response(upstream_response, route, route_id.map(Arc::from), Arc::from(source_ip), conn_activity).await + self.build_streaming_response( + upstream_response, + route, + route_id.map(Arc::from), + Arc::from(source_ip), + conn_activity, + ) + .await } /// Reconnect to a backend (used for H2→H1 fallback). @@ -1981,9 +2496,12 @@ impl HttpProxyService { match tokio::time::timeout( self.connect_timeout, connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port), - ).await { + ) + .await + { Ok(Ok(tls)) => { - self.metrics.backend_connection_opened(backend_key, reconnect_start.elapsed()); + self.metrics + .backend_connection_opened(backend_key, reconnect_start.elapsed()); Some(BackendStream::Tls(tls)) } Ok(Err(e)) => { @@ -2001,13 +2519,16 @@ impl HttpProxyService { match tokio::time::timeout( self.connect_timeout, TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), - ).await { + ) + .await + { Ok(Ok(s)) => { s.set_nodelay(true).ok(); let _ = socket2::SockRef::from(&s).set_tcp_keepalive( - &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) + &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)), ); - self.metrics.backend_connection_opened(backend_key, reconnect_start.elapsed()); + self.metrics + .backend_connection_opened(backend_key, reconnect_start.elapsed()); Some(BackendStream::Plain(s)) } Ok(Err(e)) => { @@ -2042,14 +2563,18 @@ impl HttpProxyService { ) -> Result>, hyper::Error> { // Build absolute URI for H2 pseudo-headers (:scheme, :authority) // Use the requested domain as authority (not backend address) so :authority matches Host header - let scheme = if pool_key.map(|pk| pk.use_tls).unwrap_or(false) { "https" } else { "http" }; - let authority = if domain != "-" { domain } else { + let scheme = if pool_key.map(|pk| pk.use_tls).unwrap_or(false) { + "https" + } else { + "http" + }; + let authority = if domain != "-" { + domain + } else { pool_key.map(|pk| pk.host.as_str()).unwrap_or("localhost") }; let h2_uri = format!("{}://{}{}", scheme, authority, upstream_path); - let mut upstream_req = Request::builder() - .method(parts.method) - .uri(&h2_uri); + let mut upstream_req = Request::builder().method(parts.method).uri(&h2_uri); // Remove Host header for H2 — :authority pseudo-header (from URI) is sufficient // Having both Host and :authority causes nginx to return 400 @@ -2071,7 +2596,11 @@ impl HttpProxyService { rid.clone(), Some(Arc::clone(&sip)), Direction::In, - ).with_connection_activity(Arc::clone(&conn_activity.last_activity), conn_activity.start); + ) + .with_connection_activity( + Arc::clone(&conn_activity.last_activity), + conn_activity.start, + ); let boxed_body: BoxBody = BoxBody::new(counting_req_body); let upstream_req = upstream_req.body(boxed_body).unwrap(); @@ -2087,11 +2616,15 @@ impl HttpProxyService { } else { error!(domain = %domain, error = %e, error_debug = ?e, "Backend H2 request failed"); } - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend H2 request failed", + )); } }; - self.build_streaming_response(upstream_response, route, rid, sip, conn_activity).await + self.build_streaming_response(upstream_response, route, rid, sip, conn_activity) + .await } /// Build the client-facing response from an upstream response, streaming the body. @@ -2113,18 +2646,22 @@ impl HttpProxyService { // for client-facing HTTP/3 advertisement, which must not be confused with // backend-originated Alt-Svc. if let Some(ref cache_key) = conn_activity.alt_svc_cache_key { - if let Some(alt_svc) = resp_parts.headers.get("alt-svc").and_then(|v| v.to_str().ok()) { + if let Some(alt_svc) = resp_parts + .headers + .get("alt-svc") + .and_then(|v| v.to_str().ok()) + { if let Some(h3_port) = parse_alt_svc_h3_port(alt_svc) { let url = conn_activity.alt_svc_request_url.as_deref().unwrap_or("-"); debug!(h3_port, url, "Backend advertises H3 via Alt-Svc"); let reason = format!("Alt-Svc response header ({})", url); - self.protocol_cache.insert_h3(cache_key.clone(), h3_port, &reason); + self.protocol_cache + .insert_h3(cache_key.clone(), h3_port, &reason); } } } - let mut response = Response::builder() - .status(resp_parts.status); + let mut response = Response::builder().status(resp_parts.status); if let Some(headers) = response.headers_mut() { *headers = resp_parts.headers; @@ -2153,7 +2690,11 @@ impl HttpProxyService { route_id, Some(source_ip), Direction::Out, - ).with_connection_activity(Arc::clone(&conn_activity.last_activity), conn_activity.start); + ) + .with_connection_activity( + Arc::clone(&conn_activity.last_activity), + conn_activity.start, + ); // Keep active_requests > 0 while the response body streams, so the idle // watchdog doesn't kill the connection mid-transfer (e.g. during git fetch). @@ -2191,58 +2732,92 @@ impl HttpProxyService { // Check allowed origins if configured if let Some(ws) = ws_config { if let Some(ref allowed_origins) = ws.allowed_origins { - let origin = req.headers() + let origin = req + .headers() .get("origin") .and_then(|v| v.to_str().ok()) .unwrap_or(""); - if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) { + if !allowed_origins.is_empty() + && !allowed_origins.iter().any(|o| o == "*" || o == origin) + { self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed")); } } } - info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port); + info!( + "WebSocket upgrade from {} -> {}:{}", + peer_addr, upstream.host, upstream.port + ); // Connect to upstream with timeout (TLS if upstream.use_tls is set) let mut upstream_stream: BackendStream = if upstream.use_tls { match tokio::time::timeout( self.connect_timeout, connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port), - ).await { + ) + .await + { Ok(Ok(tls)) => BackendStream::Tls(tls), Ok(Err(e)) => { - error!("WebSocket: failed TLS connect upstream {}:{}: {}", upstream.host, upstream.port, e); + error!( + "WebSocket: failed TLS connect upstream {}:{}: {}", + upstream.host, upstream.port, e + ); self.upstream_selector.connection_ended(upstream_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend TLS unavailable", + )); } Err(_) => { - error!("WebSocket: upstream TLS connect timeout for {}:{}", upstream.host, upstream.port); + error!( + "WebSocket: upstream TLS connect timeout for {}:{}", + upstream.host, upstream.port + ); self.upstream_selector.connection_ended(upstream_key); - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout")); + return Ok(error_response( + StatusCode::GATEWAY_TIMEOUT, + "Backend TLS connect timeout", + )); } } } else { match tokio::time::timeout( self.connect_timeout, TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), - ).await { + ) + .await + { Ok(Ok(s)) => { s.set_nodelay(true).ok(); let _ = socket2::SockRef::from(&s).set_tcp_keepalive( - &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) + &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)), ); BackendStream::Plain(s) } Ok(Err(e)) => { - error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e); + error!( + "WebSocket: failed to connect upstream {}:{}: {}", + upstream.host, upstream.port, e + ); self.upstream_selector.connection_ended(upstream_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend unavailable", + )); } Err(_) => { - error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port); + error!( + "WebSocket: upstream connect timeout for {}:{}", + upstream.host, upstream.port + ); self.upstream_selector.connection_ended(upstream_key); - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); + return Ok(error_response( + StatusCode::GATEWAY_TIMEOUT, + "Backend connect timeout", + )); } } }; @@ -2269,10 +2844,7 @@ impl HttpProxyService { // H2 Extended CONNECT uses method=CONNECT, but the H1.1 backend expects GET let backend_method = if is_h2 { "GET" } else { parts.method.as_str() }; - let mut raw_request = format!( - "{} {} HTTP/1.1\r\n", - backend_method, upstream_path - ); + let mut raw_request = format!("{} {} HTTP/1.1\r\n", backend_method, upstream_path); // Copy all original headers (preserving the client's Host header). // Skip X-Forwarded-* since we set them ourselves below. @@ -2318,14 +2890,23 @@ impl HttpProxyService { // Add standard reverse-proxy headers (X-Forwarded-*) { - let original_host = parts.headers.get("host") + let original_host = parts + .headers + .get("host") .and_then(|h| h.to_str().ok()) .or(ws_host.as_deref()) .unwrap_or(""); - let forwarded_proto = if route.action.tls.as_ref() - .map(|t| matches!(t.mode, - rustproxy_config::TlsMode::Terminate - | rustproxy_config::TlsMode::TerminateAndReencrypt)) + let forwarded_proto = if route + .action + .tls + .as_ref() + .map(|t| { + matches!( + t.mode, + rustproxy_config::TlsMode::Terminate + | rustproxy_config::TlsMode::TerminateAndReencrypt + ) + }) .unwrap_or(false) { "https" @@ -2364,9 +2945,15 @@ impl HttpProxyService { raw_request.push_str("\r\n"); if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await { - error!("WebSocket: failed to send upgrade request to upstream: {}", e); + error!( + "WebSocket: failed to send upgrade request to upstream: {}", + e + ); self.upstream_selector.connection_ended(upstream_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend write failed", + )); } let mut response_buf = Vec::with_capacity(4096); @@ -2384,7 +2971,8 @@ impl HttpProxyService { response_buf.extend_from_slice(&read_buf[..n]); // Scan for \r\n\r\n, backing up 3 bytes to handle split across reads let search_start = prev_len.saturating_sub(3); - if let Some(pos) = response_buf[search_start..].windows(4) + if let Some(pos) = response_buf[search_start..] + .windows(4) .position(|w| w == b"\r\n\r\n") { let header_end = search_start + pos + 4; @@ -2394,13 +2982,19 @@ impl HttpProxyService { if response_buf.len() > 8192 { error!("WebSocket: upstream response headers too large"); self.upstream_selector.connection_ended(upstream_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend response too large", + )); } } Err(e) => { error!("WebSocket: failed to read upstream response: {}", e); self.upstream_selector.connection_ended(upstream_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "Backend read failed", + )); } } } @@ -2415,7 +3009,10 @@ impl HttpProxyService { .unwrap_or(0); if status_code != 101 { - debug!("WebSocket: upstream rejected upgrade with status {}", status_code); + debug!( + "WebSocket: upstream rejected upgrade with status {}", + status_code + ); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response( StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY), @@ -2442,7 +3039,8 @@ impl HttpProxyService { // Skip hop-by-hop headers for H2 (forbidden by RFC 9113 §8.2.2) if is_h2 { let name_lower = name.to_lowercase(); - if name_lower == "upgrade" || name_lower == "connection" + if name_lower == "upgrade" + || name_lower == "connection" || name_lower == "sec-websocket-accept" || name_lower == "transfer-encoding" || name_lower == "keep-alive" @@ -2450,7 +3048,8 @@ impl HttpProxyService { continue; } } - if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) { + if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) + { if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) { resp_headers.insert(header_name, header_value); } @@ -2459,9 +3058,10 @@ impl HttpProxyService { } } - let on_client_upgrade = hyper::upgrade::on( - Request::from_parts(parts, http_body_util::Empty::::new()) - ); + let on_client_upgrade = hyper::upgrade::on(Request::from_parts( + parts, + http_body_util::Empty::::new(), + )); let metrics = Arc::clone(&self.metrics); let route_id_owned: Option> = route_id.map(Arc::from); @@ -2515,8 +3115,12 @@ impl HttpProxyService { // For H2 WebSocket: also update the connection-level activity tracker // to prevent the idle watchdog from killing the H2 connection - let conn_act_c2u = conn_activity.as_ref().map(|ca| (Arc::clone(&ca.last_activity), ca.start)); - let conn_act_u2c = conn_activity.as_ref().map(|ca| (Arc::clone(&ca.last_activity), ca.start)); + let conn_act_c2u = conn_activity + .as_ref() + .map(|ca| (Arc::clone(&ca.last_activity), ca.start)); + let conn_act_u2c = conn_activity + .as_ref() + .map(|ca| (Arc::clone(&ca.last_activity), ca.start)); let la1 = Arc::clone(&last_activity); let metrics_c2u = Arc::clone(&metrics); @@ -2545,10 +3149,8 @@ impl HttpProxyService { } } // Graceful shutdown with timeout (sends TLS close_notify / TCP FIN) - let _ = tokio::time::timeout( - std::time::Duration::from_secs(2), - uw.shutdown(), - ).await; + let _ = + tokio::time::timeout(std::time::Duration::from_secs(2), uw.shutdown()).await; total }); @@ -2564,10 +3166,9 @@ impl HttpProxyService { if !extra_bytes.is_empty() { let n = extra_bytes.len(); if cw.write_all(&extra_bytes).await.is_err() { - let _ = tokio::time::timeout( - std::time::Duration::from_secs(2), - cw.shutdown(), - ).await; + let _ = + tokio::time::timeout(std::time::Duration::from_secs(2), cw.shutdown()) + .await; return 0u64; } total += n as u64; @@ -2596,10 +3197,8 @@ impl HttpProxyService { } } // Graceful shutdown with timeout (sends TLS close_notify / TCP FIN) - let _ = tokio::time::timeout( - std::time::Duration::from_secs(2), - cw.shutdown(), - ).await; + let _ = + tokio::time::timeout(std::time::Duration::from_secs(2), cw.shutdown()).await; total }); @@ -2635,7 +3234,10 @@ impl HttpProxyService { if current == last_seen { let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { - debug!("WebSocket tunnel inactive for {}ms, closing", elapsed_since_activity); + debug!( + "WebSocket tunnel inactive for {}ms, closing", + elapsed_since_activity + ); break; } } @@ -2654,18 +3256,22 @@ impl HttpProxyService { let bytes_out = u2c.await.unwrap_or(0); watchdog.abort(); - debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out); + debug!( + "WebSocket tunnel closed: {} bytes in, {} bytes out", + bytes_in, bytes_out + ); // _upstream_guard Drop handles connection_ended on all paths including panic }); - let body: BoxBody = BoxBody::new( - http_body_util::Empty::::new().map_err(|never| match never {}) - ); + let body: BoxBody = + BoxBody::new(http_body_util::Empty::::new().map_err(|never| match never {})); Ok(client_resp.body(body).unwrap()) } /// Build a test response from config (no upstream connection needed). - fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response> { + fn build_test_response( + config: &rustproxy_config::RouteTestResponse, + ) -> Response> { let mut response = Response::builder() .status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK)); @@ -2679,14 +3285,16 @@ impl HttpProxyService { } } - let body = Full::new(Bytes::from(config.body.clone())) - .map_err(|never| match never {}); + let body = Full::new(Bytes::from(config.body.clone())).map_err(|never| match never {}); response.body(BoxBody::new(body)).unwrap() } /// Apply URL rewriting rules from route config, using the compiled regex cache. fn apply_url_rewrite(&self, path: &str, route: &rustproxy_config::RouteConfig) -> String { - let rewrite = match route.action.advanced.as_ref() + let rewrite = match route + .action + .advanced + .as_ref() .and_then(|a| a.url_rewrite.as_ref()) { Some(r) => r, @@ -2745,14 +3353,23 @@ impl HttpProxyService { let is_dir = if clean_path.is_empty() { true } else { - tokio::fs::metadata(&file_path).await.map(|m| m.is_dir()).unwrap_or(false) + tokio::fs::metadata(&file_path) + .await + .map(|m| m.is_dir()) + .unwrap_or(false) }; if is_dir { - let index_files = config.index_files.as_deref() + let index_files = config + .index_files + .as_deref() .or(config.index.as_deref()) .unwrap_or(&[]); let default_index = vec!["index.html".to_string()]; - let index_files = if index_files.is_empty() { &default_index } else { index_files }; + let index_files = if index_files.is_empty() { + &default_index + } else { + index_files + }; let mut found = false; for index in index_files { @@ -2761,7 +3378,11 @@ impl HttpProxyService { } else { file_path.join(index) }; - if tokio::fs::metadata(&candidate).await.map(|m| m.is_file()).unwrap_or(false) { + if tokio::fs::metadata(&candidate) + .await + .map(|m| m.is_file()) + .unwrap_or(false) + { file_path = candidate; found = true; break; @@ -2810,8 +3431,7 @@ impl HttpProxyService { } } - let body = Full::new(Bytes::from(content)) - .map_err(|never| match never {}); + let body = Full::new(Bytes::from(content)).map_err(|never| match never {}); response.body(BoxBody::new(body)).unwrap() } Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"), @@ -2904,8 +3524,14 @@ impl HttpProxyService { let server_name = host.to_string(); let connecting = self.quinn_client_endpoint.connect(addr, &server_name)?; - let connection = tokio::time::timeout(self.connect_timeout, connecting).await - .map_err(|_| format!("QUIC connect timeout ({:?}) for {}", self.connect_timeout, host))??; + let connection = tokio::time::timeout(self.connect_timeout, connecting) + .await + .map_err(|_| { + format!( + "QUIC connect timeout ({:?}) for {}", + self.connect_timeout, host + ) + })??; debug!("QUIC backend connection established to {}:{}", host, port); Ok(connection) @@ -2944,7 +3570,10 @@ impl HttpProxyService { Err(e) => { error!(backend = %backend_key, domain = %domain, error = %e, "H3 client handshake failed"); self.metrics.backend_handshake_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 handshake failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "H3 handshake failed", + )); } }; @@ -3007,10 +3636,18 @@ impl HttpProxyService { match frame { Ok(frame) => { if let Ok(data) = frame.into_data() { - self.metrics.record_bytes(data.len() as u64, 0, rid.as_deref(), Some(&sip)); + self.metrics.record_bytes( + data.len() as u64, + 0, + rid.as_deref(), + Some(&sip), + ); if let Err(e) = stream.send_data(data).await { error!(backend = %backend_key, error = %e, "H3 send_data failed"); - return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 body send failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "H3 body send failed", + )); } } } @@ -3030,7 +3667,10 @@ impl HttpProxyService { Err(e) => { error!(backend = %backend_key, domain = %domain, error = %e, "H3 recv_response failed"); self.metrics.backend_request_error(backend_key); - return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 response failed")); + return Ok(error_response( + StatusCode::BAD_GATEWAY, + "H3 response failed", + )); } }; @@ -3073,7 +3713,11 @@ impl HttpProxyService { rid, Some(sip), Direction::Out, - ).with_connection_activity(Arc::clone(&conn_activity.last_activity), conn_activity.start); + ) + .with_connection_activity( + Arc::clone(&conn_activity.last_activity), + conn_activity.start, + ); let counting_body = if let Some(ref ar) = conn_activity.active_requests { counting_body.with_active_requests(Arc::clone(ar)) @@ -3086,11 +3730,9 @@ impl HttpProxyService { // Register connection in pool on success (fresh connections only) if status != StatusCode::BAD_GATEWAY { if let Some(gh) = gen_holder { - let g = self.connection_pool.register_h3( - pool_key.clone(), - quic_conn, - send_request, - ); + let g = self + .connection_pool + .register_h3(pool_key.clone(), quic_conn, send_request); gh.store(g, std::sync::atomic::Ordering::Relaxed); } } @@ -3198,8 +3840,7 @@ impl Default for HttpProxyService { } fn error_response(status: StatusCode, message: &str) -> Response> { - let body = Full::new(Bytes::from(message.to_string())) - .map_err(|never| match never {}); + let body = Full::new(Bytes::from(message.to_string())).map_err(|never| match never {}); Response::builder() .status(status) .header("Content-Type", "text/plain") diff --git a/rust/crates/rustproxy-http/src/request_filter.rs b/rust/crates/rustproxy-http/src/request_filter.rs index 163f2b7..3cd204c 100644 --- a/rust/crates/rustproxy-http/src/request_filter.rs +++ b/rust/crates/rustproxy-http/src/request_filter.rs @@ -4,13 +4,15 @@ use std::net::SocketAddr; use std::sync::Arc; use bytes::Bytes; -use http_body_util::Full; -use http_body_util::BodyExt; -use hyper::{Request, Response, StatusCode}; use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; +use http_body_util::Full; +use hyper::{Request, Response, StatusCode}; use rustproxy_config::RouteSecurity; -use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter}; +use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter}; + +use crate::request_host::extract_request_host; pub struct RequestFilter; @@ -35,16 +37,13 @@ impl RequestFilter { let client_ip = peer_addr.ip(); let request_path = req.uri().path(); - // IP filter (domain-aware: extract Host header for domain-scoped entries) + // IP filter (domain-aware: use the same host extraction as route matching) if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { let allow = security.ip_allow_list.as_deref().unwrap_or(&[]); let block = security.ip_block_list.as_deref().unwrap_or(&[]); let filter = IpFilter::new(allow, block); let normalized = IpFilter::normalize_ip(&client_ip); - let host = req.headers() - .get("host") - .and_then(|v| v.to_str().ok()) - .map(|h| h.split(':').next().unwrap_or(h)); + let host = extract_request_host(req); if !filter.is_allowed_for_domain(&normalized, host) { return Some(error_response(StatusCode::FORBIDDEN, "Access denied")); } @@ -59,16 +58,15 @@ impl RequestFilter { !limiter.check(&key) } else { // Create a per-check limiter (less ideal but works for non-shared case) - let limiter = RateLimiter::new( - rate_limit_config.max_requests, - rate_limit_config.window, - ); + let limiter = + RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window); let key = Self::rate_limit_key(rate_limit_config, req, peer_addr); !limiter.check(&key) }; if should_block { - let message = rate_limit_config.error_message + let message = rate_limit_config + .error_message .as_deref() .unwrap_or("Rate limit exceeded"); return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message)); @@ -84,36 +82,48 @@ impl RequestFilter { if let Some(ref basic_auth) = security.basic_auth { if basic_auth.enabled { // Check basic auth exclude paths - let skip_basic = basic_auth.exclude_paths.as_ref() + let skip_basic = basic_auth + .exclude_paths + .as_ref() .map(|paths| Self::path_matches_any(request_path, paths)) .unwrap_or(false); if !skip_basic { - let users: Vec<(String, String)> = basic_auth.users.iter() + let users: Vec<(String, String)> = basic_auth + .users + .iter() .map(|c| (c.username.clone(), c.password.clone())) .collect(); let validator = BasicAuthValidator::new(users, basic_auth.realm.clone()); - let auth_header = req.headers() + let auth_header = req + .headers() .get("authorization") .and_then(|v| v.to_str().ok()); match auth_header { Some(header) => { if validator.validate(header).is_none() { - return Some(Response::builder() - .status(StatusCode::UNAUTHORIZED) - .header("WWW-Authenticate", validator.www_authenticate()) - .body(boxed_body("Invalid credentials")) - .unwrap()); + return Some( + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header( + "WWW-Authenticate", + validator.www_authenticate(), + ) + .body(boxed_body("Invalid credentials")) + .unwrap(), + ); } } None => { - return Some(Response::builder() - .status(StatusCode::UNAUTHORIZED) - .header("WWW-Authenticate", validator.www_authenticate()) - .body(boxed_body("Authentication required")) - .unwrap()); + return Some( + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("WWW-Authenticate", validator.www_authenticate()) + .body(boxed_body("Authentication required")) + .unwrap(), + ); } } } @@ -124,7 +134,9 @@ impl RequestFilter { if let Some(ref jwt_auth) = security.jwt_auth { if jwt_auth.enabled { // Check JWT auth exclude paths - let skip_jwt = jwt_auth.exclude_paths.as_ref() + let skip_jwt = jwt_auth + .exclude_paths + .as_ref() .map(|paths| Self::path_matches_any(request_path, paths)) .unwrap_or(false); @@ -136,18 +148,25 @@ impl RequestFilter { jwt_auth.audience.as_deref(), ); - let auth_header = req.headers() + let auth_header = req + .headers() .get("authorization") .and_then(|v| v.to_str().ok()); match auth_header.and_then(JwtValidator::extract_token) { Some(token) => { if validator.validate(token).is_err() { - return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token")); + return Some(error_response( + StatusCode::UNAUTHORIZED, + "Invalid token", + )); } } None => { - return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required")); + return Some(error_response( + StatusCode::UNAUTHORIZED, + "Bearer token required", + )); } } } @@ -209,7 +228,11 @@ impl RequestFilter { /// Check IP-based security (for use in passthrough / TCP-level connections). /// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering. /// Returns true if allowed, false if blocked. - pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr, domain: Option<&str>) -> bool { + pub fn check_ip_security( + security: &RouteSecurity, + client_ip: &std::net::IpAddr, + domain: Option<&str>, + ) -> bool { if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { let allow = security.ip_allow_list.as_deref().unwrap_or(&[]); let block = security.ip_block_list.as_deref().unwrap_or(&[]); @@ -238,19 +261,28 @@ impl RequestFilter { return None; } - let origin = req.headers() + let origin = req + .headers() .get("origin") .and_then(|v| v.to_str().ok()) .unwrap_or("*"); - Some(Response::builder() - .status(StatusCode::NO_CONTENT) - .header("Access-Control-Allow-Origin", origin) - .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") - .header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") - .header("Access-Control-Max-Age", "86400") - .body(boxed_body("")) - .unwrap()) + Some( + Response::builder() + .status(StatusCode::NO_CONTENT) + .header("Access-Control-Allow-Origin", origin) + .header( + "Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, PATCH, OPTIONS", + ) + .header( + "Access-Control-Allow-Headers", + "Content-Type, Authorization, X-Requested-With", + ) + .header("Access-Control-Max-Age", "86400") + .body(boxed_body("")) + .unwrap(), + ) } } @@ -265,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response BoxBody { BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {})) } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use http_body_util::Empty; + use hyper::{Request, StatusCode, Version}; + use rustproxy_config::{IpAllowEntry, RouteSecurity}; + + use super::RequestFilter; + + fn domain_scoped_security() -> RouteSecurity { + RouteSecurity { + ip_allow_list: Some(vec![IpAllowEntry::DomainScoped { + ip: "10.8.0.2".to_string(), + domains: vec!["*.abc.xyz".to_string()], + }]), + ip_block_list: None, + max_connections: None, + authentication: None, + rate_limit: None, + basic_auth: None, + jwt_auth: None, + } + } + + fn peer_addr() -> std::net::SocketAddr { + std::net::SocketAddr::from(([10, 8, 0, 2], 4242)) + } + + fn request(uri: &str, version: Version, host: Option<&str>) -> Request> { + let mut builder = Request::builder().uri(uri).version(version); + if let Some(host) = host { + builder = builder.header("host", host); + } + + builder.body(Empty::::new()).unwrap() + } + + #[test] + fn domain_scoped_acl_allows_uri_authority_without_host_header() { + let security = domain_scoped_security(); + let req = request("https://outline.abc.xyz/", Version::HTTP_2, None); + + assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none()); + } + + #[test] + fn domain_scoped_acl_allows_host_header_with_port() { + let security = domain_scoped_security(); + let req = request( + "https://unrelated.invalid/", + Version::HTTP_11, + Some("outline.abc.xyz:443"), + ); + + assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none()); + } + + #[test] + fn domain_scoped_acl_denies_non_matching_uri_authority() { + let security = domain_scoped_security(); + let req = request("https://outline.other.xyz/", Version::HTTP_2, None); + + let response = RequestFilter::apply(&security, &req, &peer_addr()) + .expect("non-matching domain should be denied"); + assert_eq!(response.status(), StatusCode::FORBIDDEN); + } +} diff --git a/rust/crates/rustproxy-http/src/request_host.rs b/rust/crates/rustproxy-http/src/request_host.rs new file mode 100644 index 0000000..b0eb97f --- /dev/null +++ b/rust/crates/rustproxy-http/src/request_host.rs @@ -0,0 +1,43 @@ +use hyper::Request; + +/// Extract the effective request host for routing and scoped ACL checks. +/// +/// Prefer the explicit `Host` header when present, otherwise fall back to the +/// URI authority used by HTTP/2 and HTTP/3 requests. +pub(crate) fn extract_request_host(req: &Request) -> Option<&str> { + req.headers() + .get("host") + .and_then(|value| value.to_str().ok()) + .map(|host| host.split(':').next().unwrap_or(host)) + .or_else(|| req.uri().host()) +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use http_body_util::Empty; + use hyper::Request; + + use super::extract_request_host; + + #[test] + fn extracts_host_header_before_uri_authority() { + let req = Request::builder() + .uri("https://uri.abc.xyz/test") + .header("host", "header.abc.xyz:443") + .body(Empty::::new()) + .unwrap(); + + assert_eq!(extract_request_host(&req), Some("header.abc.xyz")); + } + + #[test] + fn falls_back_to_uri_authority_when_host_header_missing() { + let req = Request::builder() + .uri("https://outline.abc.xyz/test") + .body(Empty::::new()) + .unwrap(); + + assert_eq!(extract_request_host(&req), Some("outline.abc.xyz")); + } +} diff --git a/rust/crates/rustproxy-http/src/response_filter.rs b/rust/crates/rustproxy-http/src/response_filter.rs index c76a44a..ab8756d 100644 --- a/rust/crates/rustproxy-http/src/response_filter.rs +++ b/rust/crates/rustproxy-http/src/response_filter.rs @@ -3,7 +3,7 @@ use hyper::header::{HeaderMap, HeaderName, HeaderValue}; use rustproxy_config::RouteConfig; -use crate::template::{RequestContext, expand_template}; +use crate::template::{expand_template, RequestContext}; pub struct ResponseFilter; @@ -11,12 +11,17 @@ impl ResponseFilter { /// Apply response headers from route config and CORS settings. /// If a `RequestContext` is provided, template variables in header values will be expanded. /// Also injects Alt-Svc header for routes with HTTP/3 enabled. - pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) { + pub fn apply_headers( + route: &RouteConfig, + headers: &mut HeaderMap, + req_ctx: Option<&RequestContext>, + ) { // Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route if let Some(ref udp) = route.action.udp { if let Some(ref quic) = udp.quic { if quic.enable_http3.unwrap_or(false) { - let port = quic.alt_svc_port + let port = quic + .alt_svc_port .or_else(|| req_ctx.map(|c| c.port)) .unwrap_or(443); let max_age = quic.alt_svc_max_age.unwrap_or(86400); @@ -63,10 +68,7 @@ impl ResponseFilter { headers.insert("access-control-allow-origin", val); } } else { - headers.insert( - "access-control-allow-origin", - HeaderValue::from_static("*"), - ); + headers.insert("access-control-allow-origin", HeaderValue::from_static("*")); } // Allow-Methods diff --git a/rust/crates/rustproxy-http/src/shutdown_on_drop.rs b/rust/crates/rustproxy-http/src/shutdown_on_drop.rs index 1f02ba4..071d2a3 100644 --- a/rust/crates/rustproxy-http/src/shutdown_on_drop.rs +++ b/rust/crates/rustproxy-http/src/shutdown_on_drop.rs @@ -62,17 +62,11 @@ impl AsyncWrite for Shutdown self.inner.as_ref().unwrap().is_write_vectored() } - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx) } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx); if result.is_ready() { @@ -93,7 +87,8 @@ impl Drop for ShutdownOnDrop let _ = tokio::time::timeout( std::time::Duration::from_secs(2), tokio::io::AsyncWriteExt::shutdown(&mut stream), - ).await; + ) + .await; // stream is dropped here — all resources freed }); } diff --git a/rust/crates/rustproxy-http/src/template.rs b/rust/crates/rustproxy-http/src/template.rs index a6333bc..f690742 100644 --- a/rust/crates/rustproxy-http/src/template.rs +++ b/rust/crates/rustproxy-http/src/template.rs @@ -39,7 +39,8 @@ pub fn expand_headers( headers: &HashMap, ctx: &RequestContext, ) -> HashMap { - headers.iter() + headers + .iter() .map(|(k, v)| (k.clone(), expand_template(v, ctx))) .collect() } @@ -150,7 +151,10 @@ mod tests { let ctx = test_context(); let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}"; let result = expand_template(template, &ctx); - assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42"); + assert_eq!( + result, + "192.168.1.100|example.com|443|/api/v1/users|api-route|42" + ); } #[test] diff --git a/rust/crates/rustproxy-http/src/upstream_selector.rs b/rust/crates/rustproxy-http/src/upstream_selector.rs index f5c99d5..3e9520f 100644 --- a/rust/crates/rustproxy-http/src/upstream_selector.rs +++ b/rust/crates/rustproxy-http/src/upstream_selector.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use std::sync::Mutex; use dashmap::DashMap; -use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm}; +use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget}; /// Upstream selection result. pub struct UpstreamSelection { @@ -51,21 +51,19 @@ impl UpstreamSelector { } // Determine load balancing algorithm - let algorithm = target.load_balancing.as_ref() + let algorithm = target + .load_balancing + .as_ref() .map(|lb| &lb.algorithm) .unwrap_or(&LoadBalancingAlgorithm::RoundRobin); let idx = match algorithm { - LoadBalancingAlgorithm::RoundRobin => { - self.round_robin_select(&hosts, port) - } + LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port), LoadBalancingAlgorithm::IpHash => { let hash = Self::ip_hash(client_addr); hash % hosts.len() } - LoadBalancingAlgorithm::LeastConnections => { - self.least_connections_select(&hosts, port) - } + LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port), }; UpstreamSelection { @@ -78,9 +76,7 @@ impl UpstreamSelector { fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize { let key = format!("{}:{}", hosts[0], port); let mut counters = self.round_robin.lock().unwrap(); - let counter = counters - .entry(key) - .or_insert_with(|| AtomicUsize::new(0)); + let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0)); let idx = counter.fetch_add(1, Ordering::Relaxed); idx % hosts.len() } @@ -91,7 +87,8 @@ impl UpstreamSelector { for (i, host) in hosts.iter().enumerate() { let key = format!("{}:{}", host, port); - let conns = self.active_connections + let conns = self + .active_connections .get(&key) .map(|entry| entry.value().load(Ordering::Relaxed)) .unwrap_or(0); @@ -228,13 +225,21 @@ mod tests { selector.connection_started("backend:8080"); selector.connection_started("backend:8080"); assert_eq!( - selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed), + selector + .active_connections + .get("backend:8080") + .unwrap() + .load(Ordering::Relaxed), 2 ); selector.connection_ended("backend:8080"); assert_eq!( - selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed), + selector + .active_connections + .get("backend:8080") + .unwrap() + .load(Ordering::Relaxed), 1 ); diff --git a/rust/crates/rustproxy-metrics/src/collector.rs b/rust/crates/rustproxy-metrics/src/collector.rs index 8392b96..e026f3e 100644 --- a/rust/crates/rustproxy-metrics/src/collector.rs +++ b/rust/crates/rustproxy-metrics/src/collector.rs @@ -144,6 +144,15 @@ const MAX_BACKENDS_IN_SNAPSHOT: usize = 100; /// Maximum number of distinct domains tracked per IP (prevents subdomain-spray abuse). const MAX_DOMAINS_PER_IP: usize = 256; +fn canonicalize_domain_key(domain: &str) -> Option { + let normalized = domain.trim().trim_end_matches('.').to_ascii_lowercase(); + if normalized.is_empty() { + None + } else { + Some(normalized) + } +} + /// Metrics collector tracking connections and throughput. /// /// Design: The hot path (`record_bytes`) is entirely lock-free — it only touches @@ -334,25 +343,43 @@ impl MetricsCollector { /// Record a connection closing. pub fn connection_closed(&self, route_id: Option<&str>, source_ip: Option<&str>) { - self.active_connections.fetch_sub(1, Ordering::Relaxed); + self.active_connections + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { + if v > 0 { + Some(v - 1) + } else { + None + } + }) + .ok(); if let Some(route_id) = route_id { if let Some(counter) = self.route_connections.get(route_id) { - let val = counter.load(Ordering::Relaxed); - if val > 0 { - counter.fetch_sub(1, Ordering::Relaxed); - } + counter + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { + if v > 0 { + Some(v - 1) + } else { + None + } + }) + .ok(); } } if let Some(ip) = source_ip { if let Some(counter) = self.ip_connections.get(ip) { - let val = counter.load(Ordering::Relaxed); - if val > 0 { - counter.fetch_sub(1, Ordering::Relaxed); - } + let prev = counter + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { + if v > 0 { + Some(v - 1) + } else { + None + } + }) + .ok(); // Clean up zero-count entries to prevent memory growth - if val <= 1 { + if matches!(prev, Some(v) if v <= 1) { drop(counter); self.ip_connections.remove(ip); // Evict all per-IP tracking data for this IP @@ -371,17 +398,25 @@ impl MetricsCollector { /// /// Called per-chunk in the TCP copy loop. Only touches AtomicU64 counters — /// no Mutex is taken. The throughput trackers are fed during `sample_all()`. - pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64, route_id: Option<&str>, source_ip: Option<&str>) { + pub fn record_bytes( + &self, + bytes_in: u64, + bytes_out: u64, + route_id: Option<&str>, + source_ip: Option<&str>, + ) { // Short-circuit: only touch counters for the direction that has data. // CountingBody always calls with one direction zero — skipping the zero // direction avoids ~50% of DashMap shard-locked reads per call. if bytes_in > 0 { self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed); - self.global_pending_tp_in.fetch_add(bytes_in, Ordering::Relaxed); + self.global_pending_tp_in + .fetch_add(bytes_in, Ordering::Relaxed); } if bytes_out > 0 { self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed); - self.global_pending_tp_out.fetch_add(bytes_out, Ordering::Relaxed); + self.global_pending_tp_out + .fetch_add(bytes_out, Ordering::Relaxed); } // Per-route tracking: use get() first (zero-alloc fast path for existing entries), @@ -391,7 +426,8 @@ impl MetricsCollector { if let Some(counter) = self.route_bytes_in.get(route_id) { counter.fetch_add(bytes_in, Ordering::Relaxed); } else { - self.route_bytes_in.entry(route_id.to_string()) + self.route_bytes_in + .entry(route_id.to_string()) .or_insert_with(|| AtomicU64::new(0)) .fetch_add(bytes_in, Ordering::Relaxed); } @@ -400,7 +436,8 @@ impl MetricsCollector { if let Some(counter) = self.route_bytes_out.get(route_id) { counter.fetch_add(bytes_out, Ordering::Relaxed); } else { - self.route_bytes_out.entry(route_id.to_string()) + self.route_bytes_out + .entry(route_id.to_string()) .or_insert_with(|| AtomicU64::new(0)) .fetch_add(bytes_out, Ordering::Relaxed); } @@ -408,13 +445,23 @@ impl MetricsCollector { // Accumulate into per-route pending throughput counters (lock-free) if let Some(entry) = self.route_pending_tp.get(route_id) { - if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); } - if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); } + if bytes_in > 0 { + entry.0.fetch_add(bytes_in, Ordering::Relaxed); + } + if bytes_out > 0 { + entry.1.fetch_add(bytes_out, Ordering::Relaxed); + } } else { - let entry = self.route_pending_tp.entry(route_id.to_string()) + let entry = self + .route_pending_tp + .entry(route_id.to_string()) .or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0))); - if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); } - if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); } + if bytes_in > 0 { + entry.0.fetch_add(bytes_in, Ordering::Relaxed); + } + if bytes_out > 0 { + entry.1.fetch_add(bytes_out, Ordering::Relaxed); + } } } @@ -428,7 +475,8 @@ impl MetricsCollector { if let Some(counter) = self.ip_bytes_in.get(ip) { counter.fetch_add(bytes_in, Ordering::Relaxed); } else { - self.ip_bytes_in.entry(ip.to_string()) + self.ip_bytes_in + .entry(ip.to_string()) .or_insert_with(|| AtomicU64::new(0)) .fetch_add(bytes_in, Ordering::Relaxed); } @@ -437,7 +485,8 @@ impl MetricsCollector { if let Some(counter) = self.ip_bytes_out.get(ip) { counter.fetch_add(bytes_out, Ordering::Relaxed); } else { - self.ip_bytes_out.entry(ip.to_string()) + self.ip_bytes_out + .entry(ip.to_string()) .or_insert_with(|| AtomicU64::new(0)) .fetch_add(bytes_out, Ordering::Relaxed); } @@ -445,13 +494,23 @@ impl MetricsCollector { // Accumulate into per-IP pending throughput counters (lock-free) if let Some(entry) = self.ip_pending_tp.get(ip) { - if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); } - if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); } + if bytes_in > 0 { + entry.0.fetch_add(bytes_in, Ordering::Relaxed); + } + if bytes_out > 0 { + entry.1.fetch_add(bytes_out, Ordering::Relaxed); + } } else { - let entry = self.ip_pending_tp.entry(ip.to_string()) + let entry = self + .ip_pending_tp + .entry(ip.to_string()) .or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0))); - if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); } - if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); } + if bytes_in > 0 { + entry.0.fetch_add(bytes_in, Ordering::Relaxed); + } + if bytes_out > 0 { + entry.1.fetch_add(bytes_out, Ordering::Relaxed); + } } } } @@ -469,9 +528,13 @@ impl MetricsCollector { /// connection (with SNI domain). The common case (IP + domain both already /// tracked) is two DashMap reads + one atomic increment — zero allocation. pub fn record_ip_domain_request(&self, ip: &str, domain: &str) { + let Some(domain) = canonicalize_domain_key(domain) else { + return; + }; + // Fast path: IP already tracked, domain already tracked if let Some(domains) = self.ip_domain_requests.get(ip) { - if let Some(counter) = domains.get(domain) { + if let Some(counter) = domains.get(domain.as_str()) { counter.fetch_add(1, Ordering::Relaxed); return; } @@ -480,7 +543,7 @@ impl MetricsCollector { return; } domains - .entry(domain.to_string()) + .entry(domain) .or_insert_with(|| AtomicU64::new(0)) .fetch_add(1, Ordering::Relaxed); return; @@ -490,7 +553,7 @@ impl MetricsCollector { return; } let inner = DashMap::with_capacity_and_shard_amount(4, 2); - inner.insert(domain.to_string(), AtomicU64::new(1)); + inner.insert(domain, AtomicU64::new(1)); self.ip_domain_requests.insert(ip.to_string(), inner); } @@ -504,7 +567,15 @@ impl MetricsCollector { /// Record a UDP session closed. pub fn udp_session_closed(&self) { - self.active_udp_sessions.fetch_sub(1, Ordering::Relaxed); + self.active_udp_sessions + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { + if v > 0 { + Some(v - 1) + } else { + None + } + }) + .ok(); } /// Record a UDP datagram (inbound or outbound). @@ -553,9 +624,15 @@ impl MetricsCollector { let (active, _) = self.frontend_proto_counters(proto); // Atomic saturating decrement — avoids TOCTOU race where concurrent // closes could both read val=1, both subtract, wrapping to u64::MAX. - active.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { - if v > 0 { Some(v - 1) } else { None } - }).ok(); + active + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { + if v > 0 { + Some(v - 1) + } else { + None + } + }) + .ok(); } /// Record a backend connection opened with a given protocol. @@ -569,9 +646,15 @@ impl MetricsCollector { pub fn backend_protocol_closed(&self, proto: &str) { let (active, _) = self.backend_proto_counters(proto); // Atomic saturating decrement — see frontend_protocol_closed for rationale. - active.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { - if v > 0 { Some(v - 1) } else { None } - }).ok(); + active + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { + if v > 0 { + Some(v - 1) + } else { + None + } + }) + .ok(); } // ── Per-backend recording methods ── @@ -681,17 +764,28 @@ impl MetricsCollector { /// Remove per-backend metrics for backends no longer in any route target. pub fn retain_backends(&self, active_backends: &HashSet) { - self.backend_active.retain(|k, _| active_backends.contains(k)); - self.backend_total.retain(|k, _| active_backends.contains(k)); - self.backend_protocol.retain(|k, _| active_backends.contains(k)); - self.backend_connect_errors.retain(|k, _| active_backends.contains(k)); - self.backend_handshake_errors.retain(|k, _| active_backends.contains(k)); - self.backend_request_errors.retain(|k, _| active_backends.contains(k)); - self.backend_connect_time_us.retain(|k, _| active_backends.contains(k)); - self.backend_connect_count.retain(|k, _| active_backends.contains(k)); - self.backend_pool_hits.retain(|k, _| active_backends.contains(k)); - self.backend_pool_misses.retain(|k, _| active_backends.contains(k)); - self.backend_h2_failures.retain(|k, _| active_backends.contains(k)); + self.backend_active + .retain(|k, _| active_backends.contains(k)); + self.backend_total + .retain(|k, _| active_backends.contains(k)); + self.backend_protocol + .retain(|k, _| active_backends.contains(k)); + self.backend_connect_errors + .retain(|k, _| active_backends.contains(k)); + self.backend_handshake_errors + .retain(|k, _| active_backends.contains(k)); + self.backend_request_errors + .retain(|k, _| active_backends.contains(k)); + self.backend_connect_time_us + .retain(|k, _| active_backends.contains(k)); + self.backend_connect_count + .retain(|k, _| active_backends.contains(k)); + self.backend_pool_hits + .retain(|k, _| active_backends.contains(k)); + self.backend_pool_misses + .retain(|k, _| active_backends.contains(k)); + self.backend_h2_failures + .retain(|k, _| active_backends.contains(k)); } /// Take a throughput sample on all trackers (cold path, call at 1Hz or configured interval). @@ -782,41 +876,64 @@ impl MetricsCollector { // Safety-net: prune orphaned per-IP entries that have no corresponding // ip_connections entry. This catches any entries created by a race between // record_bytes and connection_closed. - self.ip_bytes_in.retain(|k, _| self.ip_connections.contains_key(k)); - self.ip_bytes_out.retain(|k, _| self.ip_connections.contains_key(k)); - self.ip_pending_tp.retain(|k, _| self.ip_connections.contains_key(k)); - self.ip_throughput.retain(|k, _| self.ip_connections.contains_key(k)); - self.ip_total_connections.retain(|k, _| self.ip_connections.contains_key(k)); - self.ip_domain_requests.retain(|k, _| self.ip_connections.contains_key(k)); + self.ip_bytes_in + .retain(|k, _| self.ip_connections.contains_key(k)); + self.ip_bytes_out + .retain(|k, _| self.ip_connections.contains_key(k)); + self.ip_pending_tp + .retain(|k, _| self.ip_connections.contains_key(k)); + self.ip_throughput + .retain(|k, _| self.ip_connections.contains_key(k)); + self.ip_total_connections + .retain(|k, _| self.ip_connections.contains_key(k)); + self.ip_domain_requests + .retain(|k, _| self.ip_connections.contains_key(k)); // Safety-net: prune orphaned backend error/stats entries for backends // that have no active or total connections (error-only backends). // These accumulate when backend_connect_error/backend_handshake_error // create entries but backend_connection_opened is never called. - let known_backends: HashSet = self.backend_active.iter() + let known_backends: HashSet = self + .backend_active + .iter() .map(|e| e.key().clone()) .chain(self.backend_total.iter().map(|e| e.key().clone())) .collect(); - self.backend_connect_errors.retain(|k, _| known_backends.contains(k)); - self.backend_handshake_errors.retain(|k, _| known_backends.contains(k)); - self.backend_request_errors.retain(|k, _| known_backends.contains(k)); - self.backend_connect_time_us.retain(|k, _| known_backends.contains(k)); - self.backend_connect_count.retain(|k, _| known_backends.contains(k)); - self.backend_pool_hits.retain(|k, _| known_backends.contains(k)); - self.backend_pool_misses.retain(|k, _| known_backends.contains(k)); - self.backend_h2_failures.retain(|k, _| known_backends.contains(k)); - self.backend_protocol.retain(|k, _| known_backends.contains(k)); + self.backend_connect_errors + .retain(|k, _| known_backends.contains(k)); + self.backend_handshake_errors + .retain(|k, _| known_backends.contains(k)); + self.backend_request_errors + .retain(|k, _| known_backends.contains(k)); + self.backend_connect_time_us + .retain(|k, _| known_backends.contains(k)); + self.backend_connect_count + .retain(|k, _| known_backends.contains(k)); + self.backend_pool_hits + .retain(|k, _| known_backends.contains(k)); + self.backend_pool_misses + .retain(|k, _| known_backends.contains(k)); + self.backend_h2_failures + .retain(|k, _| known_backends.contains(k)); + self.backend_protocol + .retain(|k, _| known_backends.contains(k)); } /// Remove per-route metrics for route IDs that are no longer active. /// Call this after `update_routes()` to prune stale entries. pub fn retain_routes(&self, active_route_ids: &HashSet) { - self.route_connections.retain(|k, _| active_route_ids.contains(k)); - self.route_total_connections.retain(|k, _| active_route_ids.contains(k)); - self.route_bytes_in.retain(|k, _| active_route_ids.contains(k)); - self.route_bytes_out.retain(|k, _| active_route_ids.contains(k)); - self.route_pending_tp.retain(|k, _| active_route_ids.contains(k)); - self.route_throughput.retain(|k, _| active_route_ids.contains(k)); + self.route_connections + .retain(|k, _| active_route_ids.contains(k)); + self.route_total_connections + .retain(|k, _| active_route_ids.contains(k)); + self.route_bytes_in + .retain(|k, _| active_route_ids.contains(k)); + self.route_bytes_out + .retain(|k, _| active_route_ids.contains(k)); + self.route_pending_tp + .retain(|k, _| active_route_ids.contains(k)); + self.route_throughput + .retain(|k, _| active_route_ids.contains(k)); } /// Get current active connection count. @@ -859,72 +976,97 @@ impl MetricsCollector { for entry in self.route_total_connections.iter() { let route_id = entry.key().clone(); let total = entry.value().load(Ordering::Relaxed); - let active = self.route_connections + let active = self + .route_connections .get(&route_id) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let bytes_in = self.route_bytes_in + let bytes_in = self + .route_bytes_in .get(&route_id) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let bytes_out = self.route_bytes_out + let bytes_out = self + .route_bytes_out .get(&route_id) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let (route_tp_in, route_tp_out, route_recent_in, route_recent_out) = self.route_throughput + let (route_tp_in, route_tp_out, route_recent_in, route_recent_out) = self + .route_throughput .get(&route_id) - .and_then(|entry| entry.value().lock().ok().map(|t| { - let (i_in, i_out) = t.instant(); - let (r_in, r_out) = t.recent(); - (i_in, i_out, r_in, r_out) - })) + .and_then(|entry| { + entry.value().lock().ok().map(|t| { + let (i_in, i_out) = t.instant(); + let (r_in, r_out) = t.recent(); + (i_in, i_out, r_in, r_out) + }) + }) .unwrap_or((0, 0, 0, 0)); - routes.insert(route_id, RouteMetrics { - active_connections: active, - total_connections: total, - bytes_in, - bytes_out, - throughput_in_bytes_per_sec: route_tp_in, - throughput_out_bytes_per_sec: route_tp_out, - throughput_recent_in_bytes_per_sec: route_recent_in, - throughput_recent_out_bytes_per_sec: route_recent_out, - }); + routes.insert( + route_id, + RouteMetrics { + active_connections: active, + total_connections: total, + bytes_in, + bytes_out, + throughput_in_bytes_per_sec: route_tp_in, + throughput_out_bytes_per_sec: route_tp_out, + throughput_recent_in_bytes_per_sec: route_recent_in, + throughput_recent_out_bytes_per_sec: route_recent_out, + }, + ); } // Collect per-IP metrics — only IPs with active connections or total > 0, // capped at top MAX_IPS_IN_SNAPSHOT sorted by active count - let mut ip_entries: Vec<(String, u64, u64, u64, u64, u64, u64, HashMap)> = Vec::new(); + let mut ip_entries: Vec<(String, u64, u64, u64, u64, u64, u64, HashMap)> = + Vec::new(); for entry in self.ip_total_connections.iter() { let ip = entry.key().clone(); let total = entry.value().load(Ordering::Relaxed); - let active = self.ip_connections + let active = self + .ip_connections .get(&ip) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let bytes_in = self.ip_bytes_in + let bytes_in = self + .ip_bytes_in .get(&ip) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let bytes_out = self.ip_bytes_out + let bytes_out = self + .ip_bytes_out .get(&ip) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let (tp_in, tp_out) = self.ip_throughput + let (tp_in, tp_out) = self + .ip_throughput .get(&ip) .and_then(|entry| entry.value().lock().ok().map(|t| t.instant())) .unwrap_or((0, 0)); // Collect per-domain request counts for this IP - let domain_requests = self.ip_domain_requests + let domain_requests = self + .ip_domain_requests .get(&ip) .map(|domains| { - domains.iter() + domains + .iter() .map(|e| (e.key().clone(), e.value().load(Ordering::Relaxed))) .collect() }) .unwrap_or_default(); - ip_entries.push((ip, active, total, bytes_in, bytes_out, tp_in, tp_out, domain_requests)); + ip_entries.push(( + ip, + active, + total, + bytes_in, + bytes_out, + tp_in, + tp_out, + domain_requests, + )); } // Sort by active connections descending, then cap ip_entries.sort_by(|a, b| b.1.cmp(&a.1)); @@ -932,15 +1074,18 @@ impl MetricsCollector { let mut ips = std::collections::HashMap::new(); for (ip, active, total, bytes_in, bytes_out, tp_in, tp_out, domain_requests) in ip_entries { - ips.insert(ip, IpMetrics { - active_connections: active, - total_connections: total, - bytes_in, - bytes_out, - throughput_in_bytes_per_sec: tp_in, - throughput_out_bytes_per_sec: tp_out, - domain_requests, - }); + ips.insert( + ip, + IpMetrics { + active_connections: active, + total_connections: total, + bytes_in, + bytes_out, + throughput_in_bytes_per_sec: tp_in, + throughput_out_bytes_per_sec: tp_out, + domain_requests, + }, + ); } // Collect per-backend metrics, capped at top MAX_BACKENDS_IN_SNAPSHOT by total connections @@ -948,69 +1093,84 @@ impl MetricsCollector { for entry in self.backend_total.iter() { let key = entry.key().clone(); let total = entry.value().load(Ordering::Relaxed); - let active = self.backend_active + let active = self + .backend_active .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let protocol = self.backend_protocol + let protocol = self + .backend_protocol .get(&key) .map(|v| v.value().clone()) .unwrap_or_else(|| "unknown".to_string()); - let connect_errors = self.backend_connect_errors + let connect_errors = self + .backend_connect_errors .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let handshake_errors = self.backend_handshake_errors + let handshake_errors = self + .backend_handshake_errors .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let request_errors = self.backend_request_errors + let request_errors = self + .backend_request_errors .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let total_connect_time_us = self.backend_connect_time_us + let total_connect_time_us = self + .backend_connect_time_us .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let connect_count = self.backend_connect_count + let connect_count = self + .backend_connect_count .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let pool_hits = self.backend_pool_hits + let pool_hits = self + .backend_pool_hits .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let pool_misses = self.backend_pool_misses + let pool_misses = self + .backend_pool_misses .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - let h2_failures = self.backend_h2_failures + let h2_failures = self + .backend_h2_failures .get(&key) .map(|c| c.load(Ordering::Relaxed)) .unwrap_or(0); - backend_entries.push((key, BackendMetrics { - active_connections: active, - total_connections: total, - protocol, - connect_errors, - handshake_errors, - request_errors, - total_connect_time_us, - connect_count, - pool_hits, - pool_misses, - h2_failures, - })); + backend_entries.push(( + key, + BackendMetrics { + active_connections: active, + total_connections: total, + protocol, + connect_errors, + handshake_errors, + request_errors, + total_connect_time_us, + connect_count, + pool_hits, + pool_misses, + h2_failures, + }, + )); } // Sort by total connections descending, then cap backend_entries.sort_by(|a, b| b.1.total_connections.cmp(&a.1.total_connections)); backend_entries.truncate(MAX_BACKENDS_IN_SNAPSHOT); - let backends: std::collections::HashMap = backend_entries.into_iter().collect(); + let backends: std::collections::HashMap = + backend_entries.into_iter().collect(); // HTTP request rates - let (http_rps, http_rps_recent) = self.http_request_throughput + let (http_rps, http_rps_recent) = self + .http_request_throughput .lock() .map(|t| { let (instant, _) = t.instant(); @@ -1185,11 +1345,19 @@ mod tests { // Check IP active connections (drop DashMap refs immediately to avoid deadlock) assert_eq!( - collector.ip_connections.get("1.2.3.4").unwrap().load(Ordering::Relaxed), + collector + .ip_connections + .get("1.2.3.4") + .unwrap() + .load(Ordering::Relaxed), 2 ); assert_eq!( - collector.ip_connections.get("5.6.7.8").unwrap().load(Ordering::Relaxed), + collector + .ip_connections + .get("5.6.7.8") + .unwrap() + .load(Ordering::Relaxed), 1 ); @@ -1207,7 +1375,11 @@ mod tests { // Close connections collector.connection_closed(Some("route-a"), Some("1.2.3.4")); assert_eq!( - collector.ip_connections.get("1.2.3.4").unwrap().load(Ordering::Relaxed), + collector + .ip_connections + .get("1.2.3.4") + .unwrap() + .load(Ordering::Relaxed), 1 ); @@ -1252,6 +1424,79 @@ mod tests { assert!(collector.ip_total_connections.get("10.0.0.2").is_some()); } + #[test] + fn test_connection_closed_saturates_active_gauges() { + let collector = MetricsCollector::new(); + + collector.connection_closed(Some("route-a"), Some("10.0.0.1")); + assert_eq!(collector.active_connections(), 0); + + collector.connection_opened(Some("route-a"), Some("10.0.0.1")); + collector.connection_closed(Some("route-a"), Some("10.0.0.1")); + collector.connection_closed(Some("route-a"), Some("10.0.0.1")); + + assert_eq!(collector.active_connections(), 0); + assert_eq!( + collector + .route_connections + .get("route-a") + .map(|c| c.load(Ordering::Relaxed)) + .unwrap_or(0), + 0 + ); + assert!(collector.ip_connections.get("10.0.0.1").is_none()); + } + + #[test] + fn test_udp_session_closed_saturates() { + let collector = MetricsCollector::new(); + + collector.udp_session_closed(); + assert_eq!(collector.snapshot().active_udp_sessions, 0); + + collector.udp_session_opened(); + collector.udp_session_closed(); + collector.udp_session_closed(); + assert_eq!(collector.snapshot().active_udp_sessions, 0); + } + + #[test] + fn test_ip_domain_requests_are_canonicalized() { + let collector = MetricsCollector::new(); + + collector.connection_opened(Some("route-a"), Some("10.0.0.1")); + collector.record_ip_domain_request("10.0.0.1", "Example.COM"); + collector.record_ip_domain_request("10.0.0.1", "example.com."); + collector.record_ip_domain_request("10.0.0.1", " example.com "); + + let snapshot = collector.snapshot(); + let ip_metrics = snapshot.ips.get("10.0.0.1").unwrap(); + assert_eq!(ip_metrics.domain_requests.len(), 1); + assert_eq!(ip_metrics.domain_requests.get("example.com"), Some(&3)); + } + + #[test] + fn test_protocol_metrics_appear_in_snapshot() { + let collector = MetricsCollector::new(); + + collector.frontend_protocol_opened("h2"); + collector.frontend_protocol_opened("ws"); + collector.backend_protocol_opened("h3"); + collector.backend_protocol_opened("ws"); + collector.frontend_protocol_closed("h2"); + collector.backend_protocol_closed("h3"); + + let snapshot = collector.snapshot(); + assert_eq!(snapshot.frontend_protocols.h2_active, 0); + assert_eq!(snapshot.frontend_protocols.h2_total, 1); + assert_eq!(snapshot.frontend_protocols.ws_active, 1); + assert_eq!(snapshot.frontend_protocols.ws_total, 1); + assert_eq!(snapshot.backend_protocols.h3_active, 0); + assert_eq!(snapshot.backend_protocols.h3_total, 1); + assert_eq!(snapshot.backend_protocols.ws_active, 1); + assert_eq!(snapshot.backend_protocols.ws_total, 1); + } + #[test] fn test_http_request_tracking() { let collector = MetricsCollector::with_retention(60); @@ -1326,9 +1571,16 @@ mod tests { let collector = MetricsCollector::with_retention(60); // Manually insert orphaned entries (simulates the race before the guard) - collector.ip_bytes_in.insert("orphan-ip".to_string(), AtomicU64::new(100)); - collector.ip_bytes_out.insert("orphan-ip".to_string(), AtomicU64::new(200)); - collector.ip_pending_tp.insert("orphan-ip".to_string(), (AtomicU64::new(0), AtomicU64::new(0))); + collector + .ip_bytes_in + .insert("orphan-ip".to_string(), AtomicU64::new(100)); + collector + .ip_bytes_out + .insert("orphan-ip".to_string(), AtomicU64::new(200)); + collector.ip_pending_tp.insert( + "orphan-ip".to_string(), + (AtomicU64::new(0), AtomicU64::new(0)), + ); // No ip_connections entry for "orphan-ip" assert!(collector.ip_connections.get("orphan-ip").is_none()); @@ -1366,17 +1618,59 @@ mod tests { collector.backend_connection_opened(key, Duration::from_millis(15)); collector.backend_connection_opened(key, Duration::from_millis(25)); - assert_eq!(collector.backend_active.get(key).unwrap().load(Ordering::Relaxed), 2); - assert_eq!(collector.backend_total.get(key).unwrap().load(Ordering::Relaxed), 2); - assert_eq!(collector.backend_connect_count.get(key).unwrap().load(Ordering::Relaxed), 2); + assert_eq!( + collector + .backend_active + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 2 + ); + assert_eq!( + collector + .backend_total + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 2 + ); + assert_eq!( + collector + .backend_connect_count + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 2 + ); // 15ms + 25ms = 40ms = 40_000us - assert_eq!(collector.backend_connect_time_us.get(key).unwrap().load(Ordering::Relaxed), 40_000); + assert_eq!( + collector + .backend_connect_time_us + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 40_000 + ); // Close one collector.backend_connection_closed(key); - assert_eq!(collector.backend_active.get(key).unwrap().load(Ordering::Relaxed), 1); + assert_eq!( + collector + .backend_active + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 1 + ); // total stays - assert_eq!(collector.backend_total.get(key).unwrap().load(Ordering::Relaxed), 2); + assert_eq!( + collector + .backend_total + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 2 + ); // Record errors collector.backend_connect_error(key); @@ -1387,12 +1681,54 @@ mod tests { collector.backend_pool_hit(key); collector.backend_pool_miss(key); - assert_eq!(collector.backend_connect_errors.get(key).unwrap().load(Ordering::Relaxed), 1); - assert_eq!(collector.backend_handshake_errors.get(key).unwrap().load(Ordering::Relaxed), 1); - assert_eq!(collector.backend_request_errors.get(key).unwrap().load(Ordering::Relaxed), 1); - assert_eq!(collector.backend_h2_failures.get(key).unwrap().load(Ordering::Relaxed), 1); - assert_eq!(collector.backend_pool_hits.get(key).unwrap().load(Ordering::Relaxed), 2); - assert_eq!(collector.backend_pool_misses.get(key).unwrap().load(Ordering::Relaxed), 1); + assert_eq!( + collector + .backend_connect_errors + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 1 + ); + assert_eq!( + collector + .backend_handshake_errors + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 1 + ); + assert_eq!( + collector + .backend_request_errors + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 1 + ); + assert_eq!( + collector + .backend_h2_failures + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 1 + ); + assert_eq!( + collector + .backend_pool_hits + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 2 + ); + assert_eq!( + collector + .backend_pool_misses + .get(key) + .unwrap() + .load(Ordering::Relaxed), + 1 + ); // Protocol collector.set_backend_protocol(key, "h1"); @@ -1449,7 +1785,10 @@ mod tests { assert!(collector.backend_total.get("stale:8080").is_none()); assert!(collector.backend_protocol.get("stale:8080").is_none()); assert!(collector.backend_connect_errors.get("stale:8080").is_none()); - assert!(collector.backend_connect_time_us.get("stale:8080").is_none()); + assert!(collector + .backend_connect_time_us + .get("stale:8080") + .is_none()); assert!(collector.backend_connect_count.get("stale:8080").is_none()); assert!(collector.backend_pool_hits.get("stale:8080").is_none()); assert!(collector.backend_pool_misses.get("stale:8080").is_none()); diff --git a/ts/00_commitinfo_data.ts b/ts/00_commitinfo_data.ts index d6b362a..e1e4933 100644 --- a/ts/00_commitinfo_data.ts +++ b/ts/00_commitinfo_data.ts @@ -3,6 +3,6 @@ */ export const commitinfo = { name: '@push.rocks/smartproxy', - version: '27.7.0', + version: '27.7.1', description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.' }