//! Hyper-based HTTP proxy service. //! //! Accepts decrypted TCP streams (from TLS termination or plain TCP), //! parses HTTP requests, matches routes, and forwards to upstream backends. //! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade. use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use bytes::Bytes; use dashmap::DashMap; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use hyper::body::Incoming; use hyper::{Request, Response, StatusCode}; use hyper_util::rt::TokioIo; use regex::Regex; use tokio::net::TcpStream; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; use std::pin::Pin; use std::task::{Context, Poll}; use rustproxy_routing::RouteManager; use rustproxy_metrics::MetricsCollector; use rustproxy_security::RateLimiter; use crate::counting_body::{CountingBody, Direction}; use crate::request_filter::RequestFilter; use crate::response_filter::ResponseFilter; use crate::upstream_selector::UpstreamSelector; /// Default upstream connect timeout (30 seconds). const DEFAULT_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); /// Default HTTP keep-alive idle timeout (60 seconds). /// If no new request arrives within this duration, the connection is closed. const DEFAULT_HTTP_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); /// Default WebSocket inactivity timeout (1 hour). const DEFAULT_WS_INACTIVITY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3600); /// Default WebSocket max lifetime (24 hours). const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400); /// RAII guard that decrements the active request counter on drop. /// Ensures the counter is correct even if the request handler panics. 
struct ActiveRequestGuard {
    /// Shared in-flight counter: bumped in `new`, released in `Drop`.
    counter: Arc<AtomicU64>,
}

impl ActiveRequestGuard {
    /// Increment `counter` and return a guard that undoes the increment when dropped.
    fn new(counter: Arc<AtomicU64>) -> Self {
        let guard = Self { counter };
        guard.counter.fetch_add(1, Ordering::Relaxed);
        guard
    }
}

impl Drop for ActiveRequestGuard {
    fn drop(&mut self) {
        let _previous = self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}

/// Backend stream that can be either plain TCP or TLS-wrapped.
/// Used for `terminate-and-reencrypt` mode where the backend requires TLS.
pub(crate) enum BackendStream {
    /// Unencrypted TCP connection to the backend.
    Plain(TcpStream),
    /// Client-side TLS session layered over a TCP connection.
    Tls(tokio_rustls::client::TlsStream<TcpStream>),
}

// Both I/O impls simply delegate each poll to whichever variant is active.
impl tokio::io::AsyncRead for BackendStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => tokio::io::AsyncRead::poll_read(Pin::new(tcp), cx, buf),
            Self::Tls(tls) => tokio::io::AsyncRead::poll_read(Pin::new(tls), cx, buf),
        }
    }
}

impl tokio::io::AsyncWrite for BackendStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match &mut *self {
            Self::Plain(tcp) => tokio::io::AsyncWrite::poll_write(Pin::new(tcp), cx, buf),
            Self::Tls(tls) => tokio::io::AsyncWrite::poll_write(Pin::new(tls), cx, buf),
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => tokio::io::AsyncWrite::poll_flush(Pin::new(tcp), cx),
            Self::Tls(tls) => tokio::io::AsyncWrite::poll_flush(Pin::new(tls), cx),
        }
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(tcp), cx),
            Self::Tls(tls) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(tls), cx),
        }
    }
}

/// Connect to a backend over TLS using the shared backend TLS config
/// (from tls_handler). Session resumption is automatic.
/// Sets `TCP_NODELAY` and a best-effort 60s TCP keepalive on the socket.
async fn connect_tls_backend(
    backend_tls_config: &Arc<rustls::ClientConfig>,
    host: &str,
    port: u16,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
    // Validate the TLS server name before dialing so an unusable host fails
    // fast instead of opening a TCP connection that is immediately dropped.
    let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
    let connector = tokio_rustls::TlsConnector::from(Arc::clone(backend_tls_config));

    let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
    stream.set_nodelay(true)?;
    // Apply keepalive with 60s default (best effort: a failure here is not fatal)
    let _ = socket2::SockRef::from(&stream).set_tcp_keepalive(
        &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)),
    );

    let tls_stream = connector.connect(server_name, stream).await?;
    debug!("Backend TLS connection established to {}:{}", host, port);
    Ok(tls_stream)
}

/// HTTP proxy service that processes HTTP traffic.
pub struct HttpProxyService {
    route_manager: Arc<RouteManager>,
    metrics: Arc<MetricsCollector>,
    upstream_selector: UpstreamSelector,
    /// Timeout for connecting to upstream backends.
    connect_timeout: std::time::Duration,
    /// Per-route rate limiters (keyed by route ID).
    route_rate_limiters: Arc<DashMap<String, Arc<RateLimiter>>>,
    /// Request counter for periodic rate limiter cleanup.
    request_counter: AtomicU64,
    /// Cache of compiled URL rewrite regexes (keyed by pattern string).
    regex_cache: DashMap<String, Regex>,
    /// Shared backend TLS config for session resumption across connections.
    backend_tls_config: Arc<rustls::ClientConfig>,
    /// Backend TLS config with ALPN h2+http/1.1 for auto-detection mode.
    backend_tls_config_alpn: Arc<rustls::ClientConfig>,
    /// Backend connection pool for reusing keep-alive connections.
    connection_pool: Arc<crate::connection_pool::ConnectionPool>,
    /// Protocol detection cache for auto mode (caches ALPN-detected protocol per backend).
    protocol_cache: Arc<crate::protocol_cache::ProtocolCache>,
    /// HTTP keep-alive idle timeout: close connection if no new request arrives within this duration.
    http_idle_timeout: std::time::Duration,
    /// WebSocket inactivity timeout (no data in either direction).
    ws_inactivity_timeout: std::time::Duration,
    /// WebSocket maximum connection lifetime.
    ws_max_lifetime: std::time::Duration,
}

impl HttpProxyService {
    /// Create a service using the default upstream connect timeout (30 seconds).
    pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
        // Delegate so the field initialization lives in exactly one place.
        Self::with_connect_timeout(route_manager, metrics, DEFAULT_CONNECT_TIMEOUT)
    }

    /// Create with a custom connect timeout.
    ///
    /// Everything else starts from defaults: a fresh upstream selector, empty
    /// rate-limiter/regex/protocol caches, an empty connection pool, the default
    /// backend TLS configs, and the default HTTP idle / WebSocket timeouts.
    pub fn with_connect_timeout(
        route_manager: Arc<RouteManager>,
        metrics: Arc<MetricsCollector>,
        connect_timeout: std::time::Duration,
    ) -> Self {
        Self {
            route_manager,
            metrics,
            upstream_selector: UpstreamSelector::new(),
            connect_timeout,
            route_rate_limiters: Arc::new(DashMap::new()),
            request_counter: AtomicU64::new(0),
            regex_cache: DashMap::new(),
            backend_tls_config: Self::default_backend_tls_config(),
            backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
            connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()),
            protocol_cache: Arc::new(crate::protocol_cache::ProtocolCache::new()),
            http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
            ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
            ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
        }
    }

    /// Set the HTTP keep-alive idle timeout, WebSocket inactivity timeout, and
    /// WebSocket max lifetime from connection config values.
pub fn set_connection_timeouts(
    &mut self,
    http_idle_timeout: std::time::Duration,
    ws_inactivity_timeout: std::time::Duration,
    ws_max_lifetime: std::time::Duration,
) {
    let Self {
        http_idle_timeout: idle_slot,
        ws_inactivity_timeout: ws_idle_slot,
        ws_max_lifetime: ws_lifetime_slot,
        ..
    } = self;
    *idle_slot = http_idle_timeout;
    *ws_idle_slot = ws_inactivity_timeout;
    *ws_lifetime_slot = ws_max_lifetime;
}

/// Replace the backend TLS client config used for upstream TLS connections
/// (every TLS dial except the ALPN auto-detect probe).
/// Inject the shared config from tls_handler after construction so TLS
/// sessions can be resumed across connections.
pub fn set_backend_tls_config(&mut self, config: Arc<rustls::ClientConfig>) {
    self.backend_tls_config = config;
}

/// Replace the backend TLS client config that offers ALPN h2+http/1.1;
/// it is used only when probing a backend's protocol in auto-detection mode.
pub fn set_backend_tls_config_alpn(&mut self, config: Arc<rustls::ClientConfig>) {
    self.backend_tls_config_alpn = config;
}

/// Drop per-route state for routes that no longer exist and reset shared caches.
/// Call after route updates to prevent unbounded growth.
pub fn prune_stale_routes(&self, active_route_ids: &std::collections::HashSet<String>) {
    // Rate limiters are keyed by route ID, so they can be filtered precisely.
    self.route_rate_limiters
        .retain(|route_id, _limiter| active_route_ids.contains(route_id));
    // The remaining caches are reset wholesale rather than filtered per route.
    self.regex_cache.clear();
    self.upstream_selector.reset_round_robin();
    self.protocol_cache.clear();
}

/// Serve an incoming HTTP connection that arrived on a plain TCP stream.
pub async fn handle_connection(
    self: Arc<Self>,
    stream: TcpStream,
    peer_addr: std::net::SocketAddr,
    port: u16,
    cancel: CancellationToken,
) {
    // A raw TcpStream is just one of the IO types the generic handler accepts.
    Self::handle_io(self, stream, peer_addr, port, cancel).await
}

/// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated).
///
/// Uses `hyper_util::server::conn::auto::Builder` to auto-detect h1 vs h2
/// based on ALPN negotiation (TLS) or connection preface (h2c).
/// Supports HTTP/1.1 upgrades (WebSocket) and HTTP/2 CONNECT.
/// Responds to graceful shutdown via the cancel token.
///
/// An idle watchdog closes the connection when no request is in progress and
/// no new HTTP request arrives within `http_idle_timeout` (default 60s). This
/// prevents keep-alive connections from accumulating indefinitely.
pub async fn handle_io( self: Arc, stream: I, peer_addr: std::net::SocketAddr, port: u16, cancel: CancellationToken, ) where I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { let io = TokioIo::new(stream); // Capture timeouts before `self` is moved into the service closure. let idle_timeout = self.http_idle_timeout; // Activity tracker: updated at the START and END of each request. // The idle watchdog checks this to determine if the connection is idle // (no request in progress and none started recently). let last_activity = Arc::new(AtomicU64::new(0)); let active_requests = Arc::new(AtomicU64::new(0)); let start = std::time::Instant::now(); let la_inner = Arc::clone(&last_activity); let ar_inner = Arc::clone(&active_requests); let cancel_inner = cancel.clone(); let service = hyper::service::service_fn(move |req: Request| { // Mark request start — RAII guard decrements on drop (panic-safe) la_inner.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); let req_guard = ActiveRequestGuard::new(Arc::clone(&ar_inner)); let svc = Arc::clone(&self); let peer = peer_addr; let cn = cancel_inner.clone(); let la = Arc::clone(&la_inner); let st = start; async move { let result = svc.handle_request(req, peer, port, cn).await; // Mark request end — update activity timestamp before guard drops la.store(st.elapsed().as_millis() as u64, Ordering::Relaxed); drop(req_guard); // Explicitly drop to decrement active_requests result } }); // Auto-detect h1 vs h2 based on ALPN / connection preface. // serve_connection_with_upgrades supports h1 Upgrade (WebSocket) and h2 CONNECT. let builder = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); let conn = builder.serve_connection_with_upgrades(io, service); // Pin on the heap — auto::UpgradeableConnection is !Unpin let mut conn = Box::pin(conn); // Use select to support graceful shutdown, cancellation, and idle timeout tokio::select! 
{ result = conn.as_mut() => { if let Err(e) = result { debug!("HTTP connection error from {}: {}", peer_addr, e); } } _ = cancel.cancelled() => { // Graceful shutdown: let in-flight request finish, stop accepting new ones conn.as_mut().graceful_shutdown(); if let Err(e) = conn.await { debug!("HTTP connection error during shutdown from {}: {}", peer_addr, e); } } _ = async { // Idle watchdog: check every 5s whether the connection has been idle // (no active requests AND no activity for idle_timeout). // This avoids killing long-running requests or upgraded connections. let check_interval = std::time::Duration::from_secs(5); let mut last_seen = 0u64; loop { tokio::time::sleep(check_interval).await; // Never close while a request is in progress if active_requests.load(Ordering::Relaxed) > 0 { last_seen = last_activity.load(Ordering::Relaxed); continue; } let current = last_activity.load(Ordering::Relaxed); if current == last_seen { // No new activity since last check let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; if elapsed_since_activity >= idle_timeout.as_millis() as u64 { return; } } last_seen = current; } } => { debug!("HTTP connection idle timeout ({}s) from {}", idle_timeout.as_secs(), peer_addr); conn.as_mut().graceful_shutdown(); // Give any in-flight work 5s to drain after graceful shutdown let _ = tokio::time::timeout(std::time::Duration::from_secs(5), conn).await; } } } /// Handle a single HTTP request. 
async fn handle_request( &self, req: Request, peer_addr: std::net::SocketAddr, port: u16, cancel: CancellationToken, ) -> Result>, hyper::Error> { let host = req.headers() .get("host") .and_then(|h| h.to_str().ok()) .map(|h| { // Strip port from host header h.split(':').next().unwrap_or(h).to_string() }) // HTTP/2 uses :authority pseudo-header instead of Host; // hyper maps it to the URI authority component .or_else(|| req.uri().host().map(|h| h.to_string())); let path = req.uri().path().to_string(); let method = req.method().clone(); // Extract headers for matching let headers: HashMap = req.headers() .iter() .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) .collect(); debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr); // Check for CORS preflight if method == hyper::Method::OPTIONS { if let Some(response) = RequestFilter::handle_cors_preflight(&req) { return Ok(response); } } // Match route let ctx = rustproxy_routing::MatchContext { port, domain: host.as_deref(), path: Some(&path), client_ip: Some(&peer_addr.ip().to_string()), tls_version: None, headers: Some(&headers), is_tls: false, protocol: Some("http"), }; let route_match = match self.route_manager.find_route(&ctx) { Some(rm) => rm, None => { debug!("No route matched for HTTP request to {:?}{}", host, path); return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched")); } }; let route_id = route_match.route.id.as_deref(); let ip_str = peer_addr.ip().to_string(); self.metrics.record_http_request(); // Apply request filters (IP check, rate limiting, auth) if let Some(ref security) = route_match.route.security { // Look up or create a shared rate limiter for this route let rate_limiter = security.rate_limit.as_ref() .filter(|rl| rl.enabled) .map(|rl| { let route_key = route_id.unwrap_or("__default__").to_string(); self.route_rate_limiters .entry(route_key) .or_insert_with(|| Arc::new(RateLimiter::new(rl.max_requests, rl.window))) .clone() }); if let 
Some(response) = RequestFilter::apply_with_rate_limiter( security, &req, &peer_addr, rate_limiter.as_ref(), ) { return Ok(response); } } // Periodic rate limiter cleanup (every 1000 requests) let count = self.request_counter.fetch_add(1, Ordering::Relaxed); if count % 1000 == 0 { for entry in self.route_rate_limiters.iter() { entry.value().cleanup(); } } // Check for test response (returns immediately, no upstream needed) if let Some(ref advanced) = route_match.route.action.advanced { if let Some(ref test_response) = advanced.test_response { return Ok(Self::build_test_response(test_response)); } } // Check for static file serving if let Some(ref advanced) = route_match.route.action.advanced { if let Some(ref static_files) = advanced.static_files { return Ok(Self::serve_static_file(&path, static_files)); } } // Select upstream let target = match route_match.target { Some(t) => t, None => { return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available")); } }; let mut upstream = self.upstream_selector.select(target, &peer_addr, port); // If the route uses terminate-and-reencrypt, always re-encrypt to backend if let Some(ref tls) = route_match.route.action.tls { if tls.mode == rustproxy_config::TlsMode::TerminateAndReencrypt { upstream.use_tls = true; } } let upstream_key = format!("{}:{}", upstream.host, upstream.port); self.upstream_selector.connection_started(&upstream_key); // Check for WebSocket upgrade let is_websocket = req.headers() .get("upgrade") .and_then(|v| v.to_str().ok()) .map(|v| v.eq_ignore_ascii_case("websocket")) .unwrap_or(false); if is_websocket { let result = self.handle_websocket_upgrade( req, peer_addr, &upstream, route_match.route, route_id, &upstream_key, cancel, &ip_str, ).await; // Note: for WebSocket, connection_ended is called inside // the spawned tunnel task when the connection closes. 
return result; } // Determine backend protocol mode let backend_protocol_mode = route_match.route.action.options.as_ref() .and_then(|o| o.backend_protocol.as_ref()) .cloned() .unwrap_or(rustproxy_config::BackendProtocol::Auto); // Build the upstream path (path + query), applying URL rewriting if configured let upstream_path = { let raw_path = match req.uri().query() { Some(q) => format!("{}?{}", path, q), None => path.clone(), }; self.apply_url_rewrite(&raw_path, &route_match.route) }; // Build upstream request - stream body instead of buffering let (parts, body) = req.into_parts(); // Apply request headers from route config let mut upstream_headers = parts.headers.clone(); if let Some(ref route_headers) = route_match.route.headers { if let Some(ref request_headers) = route_headers.request { for (key, value) in request_headers { if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) { if let Ok(val) = hyper::header::HeaderValue::from_str(value) { upstream_headers.insert(name, val); } } } } } // Ensure Host header is set (HTTP/2 requests don't have Host; need it for h1 backends) if !upstream_headers.contains_key("host") { if let Some(ref h) = host { if let Ok(val) = hyper::header::HeaderValue::from_str(h) { upstream_headers.insert(hyper::header::HOST, val); } } } // Add standard reverse-proxy headers (X-Forwarded-*) { let original_host = host.as_deref().unwrap_or(""); let forwarded_proto = if route_match.route.action.tls.as_ref() .map(|t| matches!(t.mode, rustproxy_config::TlsMode::Terminate | rustproxy_config::TlsMode::TerminateAndReencrypt)) .unwrap_or(false) { "https" } else { "http" }; // X-Forwarded-For: append client IP to existing chain let client_ip = peer_addr.ip().to_string(); let xff_value = if let Some(existing) = upstream_headers.get("x-forwarded-for") { format!("{}, {}", existing.to_str().unwrap_or(""), client_ip) } else { client_ip }; if let Ok(val) = hyper::header::HeaderValue::from_str(&xff_value) { upstream_headers.insert( 
hyper::header::HeaderName::from_static("x-forwarded-for"), val, ); } // X-Forwarded-Host: original Host header if let Ok(val) = hyper::header::HeaderValue::from_str(original_host) { upstream_headers.insert( hyper::header::HeaderName::from_static("x-forwarded-host"), val, ); } // X-Forwarded-Proto: original client protocol if let Ok(val) = hyper::header::HeaderValue::from_str(forwarded_proto) { upstream_headers.insert( hyper::header::HeaderName::from_static("x-forwarded-proto"), val, ); } } // --- Resolve protocol decision based on backend protocol mode --- let is_auto_detect_mode = matches!(backend_protocol_mode, rustproxy_config::BackendProtocol::Auto); let (use_h2, needs_alpn_probe) = match backend_protocol_mode { rustproxy_config::BackendProtocol::Http1 => (false, false), rustproxy_config::BackendProtocol::Http2 => (true, false), rustproxy_config::BackendProtocol::Auto => { if !upstream.use_tls { // No ALPN without TLS — default to H1 (false, false) } else { let cache_key = crate::protocol_cache::ProtocolCacheKey { host: upstream.host.clone(), port: upstream.port, }; match self.protocol_cache.get(&cache_key) { Some(crate::protocol_cache::DetectedProtocol::H2) => (true, false), Some(crate::protocol_cache::DetectedProtocol::H1) => (false, false), None => (false, true), // needs ALPN probe } } } }; // --- Connection pooling: try reusing an existing connection first --- // For ALPN probe mode, skip pool checkout (we don't know the protocol yet) if !needs_alpn_probe { let pool_key = crate::connection_pool::PoolKey { host: upstream.host.clone(), port: upstream.port, use_tls: upstream.use_tls, h2: use_h2, }; // H2 pool checkout (H2 senders are Clone and multiplexed) if use_h2 { if let Some(sender) = self.connection_pool.checkout_h2(&pool_key) { let result = self.forward_h2_pooled( sender, parts, body, upstream_headers, &upstream_path, route_match.route, route_id, &ip_str, &pool_key, ).await; self.upstream_selector.connection_ended(&upstream_key); return result; } } } 
// --- Fresh connection path --- // Choose TLS config: use ALPN config for auto-detect probe, plain config otherwise let tls_config = if needs_alpn_probe { &self.backend_tls_config_alpn } else { &self.backend_tls_config }; // Establish backend connection let (backend, detected_h2) = if upstream.use_tls { match tokio::time::timeout( self.connect_timeout, connect_tls_backend(tls_config, &upstream.host, upstream.port), ).await { Ok(Ok(tls)) => { let final_h2 = if needs_alpn_probe { // Read the ALPN-negotiated protocol from the TLS connection let alpn = tls.get_ref().1.alpn_protocol(); let is_h2 = alpn.map(|p| p == b"h2").unwrap_or(false); // Cache the result let cache_key = crate::protocol_cache::ProtocolCacheKey { host: upstream.host.clone(), port: upstream.port, }; let detected = if is_h2 { crate::protocol_cache::DetectedProtocol::H2 } else { crate::protocol_cache::DetectedProtocol::H1 }; self.protocol_cache.insert(cache_key, detected); debug!( "Auto-detected {} for backend {}:{}", if is_h2 { "HTTP/2" } else { "HTTP/1.1" }, upstream.host, upstream.port ); is_h2 } else { use_h2 }; (BackendStream::Tls(tls), final_h2) } Ok(Err(e)) => { error!("Failed TLS connect to upstream {}:{}: {}", upstream.host, upstream.port, e); self.upstream_selector.connection_ended(&upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable")); } Err(_) => { error!("Upstream TLS connect timeout for {}:{}", upstream.host, upstream.port); self.upstream_selector.connection_ended(&upstream_key); return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout")); } } } else { match tokio::time::timeout( self.connect_timeout, TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), ).await { Ok(Ok(s)) => { s.set_nodelay(true).ok(); let _ = socket2::SockRef::from(&s).set_tcp_keepalive( &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) ); (BackendStream::Plain(s), use_h2) } Ok(Err(e)) => { error!("Failed to connect 
to upstream {}:{}: {}", upstream.host, upstream.port, e); self.upstream_selector.connection_ended(&upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); } Err(_) => { error!("Upstream connect timeout for {}:{}", upstream.host, upstream.port); self.upstream_selector.connection_ended(&upstream_key); return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); } } }; let final_pool_key = crate::connection_pool::PoolKey { host: upstream.host.clone(), port: upstream.port, use_tls: upstream.use_tls, h2: detected_h2, }; let io = TokioIo::new(backend); let result = if detected_h2 { if is_auto_detect_mode { // Auto-detect mode: use fallback-capable H2 forwarding self.forward_h2_with_fallback( io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, &final_pool_key, ).await } else { // Explicit H2 mode: hard-fail on handshake error (preserved behavior) self.forward_h2( io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, &final_pool_key, ).await } } else { self.forward_h1( io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, &final_pool_key, ).await }; self.upstream_selector.connection_ended(&upstream_key); result } /// Forward request to backend via HTTP/1.1 with body streaming. /// Tries a pooled connection first; if unavailable, uses the fresh IO connection. 
async fn forward_h1( &self, io: TokioIo, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, _upstream: &crate::upstream_selector::UpstreamSelection, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { // Try pooled H1 connection first — avoids TCP+TLS handshake if let Some(pooled_sender) = self.connection_pool.checkout_h1(pool_key) { return self.forward_h1_with_sender( pooled_sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip, pool_key, ).await; } // Fresh connection: explicitly type the handshake with BoxBody for uniform pool type let (sender, conn): ( hyper::client::conn::http1::SendRequest>, hyper::client::conn::http1::Connection, BoxBody>, ) = match hyper::client::conn::http1::handshake(io).await { Ok(h) => h, Err(e) => { error!("Upstream handshake failed: {}", e); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed")); } }; tokio::spawn(async move { if let Err(e) = conn.await { debug!("Upstream connection error: {}", e); } }); self.forward_h1_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip, pool_key).await } /// Common H1 forwarding logic used by both fresh and pooled paths. 
async fn forward_h1_with_sender( &self, mut sender: hyper::client::conn::http1::SendRequest>, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { // Always use HTTP/1.1 for h1 backend connections (h2 incoming requests have version HTTP/2.0) let mut upstream_req = Request::builder() .method(parts.method) .uri(upstream_path) .version(hyper::Version::HTTP_11); if let Some(headers) = upstream_req.headers_mut() { *headers = upstream_headers; } // Wrap the request body in CountingBody then box it for the uniform pool type let counting_req_body = CountingBody::new( body, Arc::clone(&self.metrics), route_id.map(|s| s.to_string()), Some(source_ip.to_string()), Direction::In, ); let boxed_body: BoxBody = BoxBody::new(counting_req_body); let upstream_req = upstream_req.body(boxed_body).unwrap(); let upstream_response = match sender.send_request(upstream_req).await { Ok(resp) => resp, Err(e) => { error!("Upstream request failed: {}", e); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed")); } }; // Return sender to pool (body streams lazily, sender is reusable once response head is received) self.connection_pool.checkin_h1(pool_key.clone(), sender); self.build_streaming_response(upstream_response, route, route_id, source_ip).await } /// Forward request to backend via HTTP/2 with body streaming (fresh connection). /// Registers the h2 sender in the pool for future multiplexed requests. 
async fn forward_h2( &self, io: TokioIo, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, _upstream: &crate::upstream_selector::UpstreamSelection, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { let exec = hyper_util::rt::TokioExecutor::new(); // Explicitly type the handshake with BoxBody for uniform pool type let (sender, conn): ( hyper::client::conn::http2::SendRequest>, hyper::client::conn::http2::Connection, BoxBody, hyper_util::rt::TokioExecutor>, ) = match hyper::client::conn::http2::handshake(exec, io).await { Ok(h) => h, Err(e) => { error!("HTTP/2 upstream handshake failed: {}", e); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed")); } }; tokio::spawn(async move { if let Err(e) = conn.await { debug!("HTTP/2 upstream connection error: {}", e); } }); // Register for multiplexed reuse self.connection_pool.register_h2(pool_key.clone(), sender.clone()); self.forward_h2_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip).await } /// Forward request using an existing (pooled) HTTP/2 sender. async fn forward_h2_pooled( &self, sender: hyper::client::conn::http2::SendRequest>, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, _pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { self.forward_h2_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip).await } /// Forward via HTTP/2 with fallback to HTTP/1.1 (auto-detect mode). /// /// Handles two failure scenarios: /// 1. H2 handshake fails → reconnects and falls back to H1 (body not consumed yet). /// 2. 
H2 handshake "succeeds" but request fails (backend advertises h2 via ALPN but /// doesn't actually speak h2) → updates cache to H1. The request body is consumed /// so this request fails, but all subsequent requests will correctly use H1. async fn forward_h2_with_fallback( &self, io: TokioIo, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, upstream: &crate::upstream_selector::UpstreamSelection, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { let exec = hyper_util::rt::TokioExecutor::new(); let handshake_result: Result<( hyper::client::conn::http2::SendRequest>, hyper::client::conn::http2::Connection, BoxBody, hyper_util::rt::TokioExecutor>, ), hyper::Error> = hyper::client::conn::http2::handshake(exec, io).await; match handshake_result { Ok((mut sender, conn)) => { tokio::spawn(async move { if let Err(e) = conn.await { debug!("HTTP/2 upstream connection error: {}", e); } }); // Build and send the h2 request inline (don't register in pool yet — // we need to verify the request actually succeeds first, because some // backends advertise h2 via ALPN but don't speak the h2 binary protocol). let mut upstream_req = Request::builder() .method(parts.method) .uri(upstream_path); if let Some(headers) = upstream_req.headers_mut() { *headers = upstream_headers; } let counting_req_body = CountingBody::new( body, Arc::clone(&self.metrics), route_id.map(|s| s.to_string()), Some(source_ip.to_string()), Direction::In, ); let boxed_body: BoxBody = BoxBody::new(counting_req_body); let upstream_req = upstream_req.body(boxed_body).unwrap(); match sender.send_request(upstream_req).await { Ok(upstream_response) => { // H2 works! 
Register sender in pool for multiplexed reuse self.connection_pool.register_h2(pool_key.clone(), sender); self.build_streaming_response(upstream_response, route, route_id, source_ip).await } Err(e) => { // H2 request failed — backend advertises h2 via ALPN but doesn't // actually speak it. Update cache so future requests use H1. // The request body is consumed so this request can't be retried, // but all subsequent requests will correctly use H1. warn!( "Auto-detect: H2 request failed for {}:{}, updating cache to H1: {}", upstream.host, upstream.port, e ); let cache_key = crate::protocol_cache::ProtocolCacheKey { host: upstream.host.clone(), port: upstream.port, }; self.protocol_cache.insert(cache_key, crate::protocol_cache::DetectedProtocol::H1); Ok(error_response(StatusCode::BAD_GATEWAY, "Backend protocol mismatch, retrying with H1")) } } } Err(e) => { // H2 handshake truly failed — fall back to H1 // Body is NOT consumed yet, so we can retry the full request. warn!( "H2 handshake failed for {}:{}, falling back to H1: {}", upstream.host, upstream.port, e ); // Update cache to H1 so subsequent requests skip H2 let cache_key = crate::protocol_cache::ProtocolCacheKey { host: upstream.host.clone(), port: upstream.port, }; self.protocol_cache.insert(cache_key, crate::protocol_cache::DetectedProtocol::H1); // Reconnect for H1 (the original io was consumed by the failed h2 handshake) match self.reconnect_backend(upstream).await { Some(fallback_backend) => { let h1_pool_key = crate::connection_pool::PoolKey { host: upstream.host.clone(), port: upstream.port, use_tls: upstream.use_tls, h2: false, }; let fallback_io = TokioIo::new(fallback_backend); self.forward_h1( fallback_io, parts, body, upstream_headers, upstream_path, upstream, route, route_id, source_ip, &h1_pool_key, ).await } None => { Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable after H2 fallback")) } } } } } /// Reconnect to a backend (used for H2→H1 fallback). 
///
/// Opens a brand-new connection to `upstream.host:upstream.port`, bounded by
/// `self.connect_timeout`. With `upstream.use_tls` the handshake goes through
/// `connect_tls_backend` using `self.backend_tls_config`; otherwise a plain TCP
/// socket is dialled and tuned with `TCP_NODELAY` plus a 60 s keepalive (both
/// best-effort). Any connect error or timeout is logged and yields `None`.
async fn reconnect_backend(
    &self,
    upstream: &crate::upstream_selector::UpstreamSelection,
) -> Option<BackendStream> {
    let host = &upstream.host;
    let port = upstream.port;

    if !upstream.use_tls {
        let dial = TcpStream::connect(format!("{}:{}", host, port));
        return match tokio::time::timeout(self.connect_timeout, dial).await {
            Err(_elapsed) => {
                error!("H1 fallback: reconnect timeout for {}:{}", host, port);
                None
            }
            Ok(Err(e)) => {
                error!("H1 fallback: reconnect failed for {}:{}: {}", host, port, e);
                None
            }
            Ok(Ok(stream)) => {
                // Socket tuning is best-effort; failures must not abort the fallback.
                stream.set_nodelay(true).ok();
                let keepalive = socket2::TcpKeepalive::new()
                    .with_time(std::time::Duration::from_secs(60));
                let _ = socket2::SockRef::from(&stream).set_tcp_keepalive(&keepalive);
                Some(BackendStream::Plain(stream))
            }
        };
    }

    let handshake = connect_tls_backend(&self.backend_tls_config, host, port);
    match tokio::time::timeout(self.connect_timeout, handshake).await {
        Ok(Ok(tls)) => Some(BackendStream::Tls(tls)),
        Ok(Err(e)) => {
            error!("H1 fallback: TLS reconnect failed for {}:{}: {}", host, port, e);
            None
        }
        Err(_elapsed) => {
            error!("H1 fallback: TLS reconnect timeout for {}:{}", host, port);
            None
        }
    }
}
/// Common H2 forwarding logic used by both fresh and pooled paths.
async fn forward_h2_with_sender( &self, mut sender: hyper::client::conn::http2::SendRequest>, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, ) -> Result>, hyper::Error> { let mut upstream_req = Request::builder() .method(parts.method) .uri(upstream_path); if let Some(headers) = upstream_req.headers_mut() { *headers = upstream_headers; } // Wrap the request body in CountingBody then box it for the uniform pool type let counting_req_body = CountingBody::new( body, Arc::clone(&self.metrics), route_id.map(|s| s.to_string()), Some(source_ip.to_string()), Direction::In, ); let boxed_body: BoxBody = BoxBody::new(counting_req_body); let upstream_req = upstream_req.body(boxed_body).unwrap(); let upstream_response = match sender.send_request(upstream_req).await { Ok(resp) => resp, Err(e) => { error!("HTTP/2 upstream request failed: {}", e); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed")); } }; self.build_streaming_response(upstream_response, route, route_id, source_ip).await } /// Build the client-facing response from an upstream response, streaming the body. /// /// The response body is wrapped in a `CountingBody` that counts bytes as they /// stream from upstream to client. async fn build_streaming_response( &self, upstream_response: Response, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, ) -> Result>, hyper::Error> { let (resp_parts, resp_body) = upstream_response.into_parts(); let mut response = Response::builder() .status(resp_parts.status); if let Some(headers) = response.headers_mut() { *headers = resp_parts.headers; ResponseFilter::apply_headers(route, headers, None); } // Wrap the response body in CountingBody to track bytes_out. // CountingBody will report bytes and we close the connection metric // after the body stream completes (not before it even starts). 
let counting_body = CountingBody::new( resp_body, Arc::clone(&self.metrics), route_id.map(|s| s.to_string()), Some(source_ip.to_string()), Direction::Out, ); let body: BoxBody = BoxBody::new(counting_body); Ok(response.body(body).unwrap()) } /// Handle a WebSocket upgrade request. async fn handle_websocket_upgrade( &self, req: Request, peer_addr: std::net::SocketAddr, upstream: &crate::upstream_selector::UpstreamSelection, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, upstream_key: &str, cancel: CancellationToken, source_ip: &str, ) -> Result>, hyper::Error> { use tokio::io::{AsyncReadExt, AsyncWriteExt}; // Get WebSocket config from route let ws_config = route.action.websocket.as_ref(); // Check allowed origins if configured if let Some(ws) = ws_config { if let Some(ref allowed_origins) = ws.allowed_origins { let origin = req.headers() .get("origin") .and_then(|v| v.to_str().ok()) .unwrap_or(""); if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) { self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed")); } } } info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port); // Connect to upstream with timeout (TLS if upstream.use_tls is set) let mut upstream_stream: BackendStream = if upstream.use_tls { match tokio::time::timeout( self.connect_timeout, connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port), ).await { Ok(Ok(tls)) => BackendStream::Tls(tls), Ok(Err(e)) => { error!("WebSocket: failed TLS connect upstream {}:{}: {}", upstream.host, upstream.port, e); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable")); } Err(_) => { error!("WebSocket: upstream TLS connect timeout for {}:{}", upstream.host, upstream.port); self.upstream_selector.connection_ended(upstream_key); return 
Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout")); } } } else { match tokio::time::timeout( self.connect_timeout, TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), ).await { Ok(Ok(s)) => { s.set_nodelay(true).ok(); let _ = socket2::SockRef::from(&s).set_tcp_keepalive( &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) ); BackendStream::Plain(s) } Ok(Err(e)) => { error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); } Err(_) => { error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); } } }; let path = req.uri().path().to_string(); let upstream_path = { let raw = match req.uri().query() { Some(q) => format!("{}?{}", path, q), None => path, }; // Apply rewrite_path if configured if let Some(ws) = ws_config { if let Some(ref rewrite_path) = ws.rewrite_path { rewrite_path.clone() } else { raw } } else { raw } }; let (parts, _body) = req.into_parts(); let mut raw_request = format!( "{} {} HTTP/1.1\r\n", parts.method, upstream_path ); // Copy all original headers (preserving the client's Host header). // Skip X-Forwarded-* since we set them ourselves below. 
let mut has_host_header = false; for (name, value) in parts.headers.iter() { let name_str = name.as_str(); if name_str == "x-forwarded-for" || name_str == "x-forwarded-host" || name_str == "x-forwarded-proto" { continue; } if name_str == "host" { has_host_header = true; } raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or(""))); } // HTTP/2 requests don't have Host header; add one from URI authority for h1 backends let ws_host = parts.uri.host().map(|h| h.to_string()); if !has_host_header { if let Some(ref h) = ws_host { raw_request.push_str(&format!("host: {}\r\n", h)); } } // Add standard reverse-proxy headers (X-Forwarded-*) { let original_host = parts.headers.get("host") .and_then(|h| h.to_str().ok()) .or(ws_host.as_deref()) .unwrap_or(""); let forwarded_proto = if route.action.tls.as_ref() .map(|t| matches!(t.mode, rustproxy_config::TlsMode::Terminate | rustproxy_config::TlsMode::TerminateAndReencrypt)) .unwrap_or(false) { "https" } else { "http" }; let client_ip = peer_addr.ip().to_string(); let xff_value = if let Some(existing) = parts.headers.get("x-forwarded-for") { format!("{}, {}", existing.to_str().unwrap_or(""), client_ip) } else { client_ip }; raw_request.push_str(&format!("x-forwarded-for: {}\r\n", xff_value)); raw_request.push_str(&format!("x-forwarded-host: {}\r\n", original_host)); raw_request.push_str(&format!("x-forwarded-proto: {}\r\n", forwarded_proto)); } if let Some(ref route_headers) = route.headers { if let Some(ref request_headers) = route_headers.request { for (key, value) in request_headers { raw_request.push_str(&format!("{}: {}\r\n", key, value)); } } } // Apply WebSocket custom headers if let Some(ws) = ws_config { if let Some(ref custom_headers) = ws.custom_headers { for (key, value) in custom_headers { raw_request.push_str(&format!("{}: {}\r\n", key, value)); } } } raw_request.push_str("\r\n"); if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await { error!("WebSocket: failed to send upgrade 
request to upstream: {}", e); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed")); } let mut response_buf = Vec::with_capacity(4096); let mut temp = [0u8; 1]; loop { match upstream_stream.read(&mut temp).await { Ok(0) => { error!("WebSocket: upstream closed before completing handshake"); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed")); } Ok(_) => { response_buf.push(temp[0]); if response_buf.len() >= 4 { let len = response_buf.len(); if response_buf[len-4..] == *b"\r\n\r\n" { break; } } if response_buf.len() > 8192 { error!("WebSocket: upstream response headers too large"); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large")); } } Err(e) => { error!("WebSocket: failed to read upstream response: {}", e); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed")); } } } let response_str = String::from_utf8_lossy(&response_buf); let status_line = response_str.lines().next().unwrap_or(""); let status_code = status_line .split_whitespace() .nth(1) .and_then(|s| s.parse::().ok()) .unwrap_or(0); if status_code != 101 { debug!("WebSocket: upstream rejected upgrade with status {}", status_code); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response( StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY), "WebSocket upgrade rejected by backend", )); } let mut client_resp = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS); if let Some(resp_headers) = client_resp.headers_mut() { for line in response_str.lines().skip(1) { let line = line.trim(); if line.is_empty() { break; } if let Some((name, value)) = line.split_once(':') { let name = name.trim(); let value = value.trim(); if let Ok(header_name) = 
hyper::header::HeaderName::from_bytes(name.as_bytes()) { if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) { resp_headers.insert(header_name, header_value); } } } } } let on_client_upgrade = hyper::upgrade::on( Request::from_parts(parts, http_body_util::Empty::::new()) ); let metrics = Arc::clone(&self.metrics); let route_id_owned = route_id.map(|s| s.to_string()); let source_ip_owned = source_ip.to_string(); let upstream_selector = self.upstream_selector.clone(); let upstream_key_owned = upstream_key.to_string(); let ws_inactivity_timeout = self.ws_inactivity_timeout; let ws_max_lifetime = self.ws_max_lifetime; tokio::spawn(async move { let client_upgraded = match on_client_upgrade.await { Ok(upgraded) => upgraded, Err(e) => { debug!("WebSocket: client upgrade failed: {}", e); upstream_selector.connection_ended(&upstream_key_owned); return; } }; let client_io = TokioIo::new(client_upgraded); let (mut cr, mut cw) = tokio::io::split(client_io); let (mut ur, mut uw) = tokio::io::split(upstream_stream); // Shared activity tracker for the watchdog let last_activity = Arc::new(AtomicU64::new(0)); let start = std::time::Instant::now(); let la1 = Arc::clone(&last_activity); let c2u = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = 0u64; loop { let n = match cr.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if uw.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); } let _ = uw.shutdown().await; total }); let la2 = Arc::clone(&last_activity); let u2c = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = 0u64; loop { let n = match ur.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if cw.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); } let _ = cw.shutdown().await; total }); // Watchdog: monitors inactivity, max 
lifetime, and cancellation let la_watch = Arc::clone(&last_activity); let c2u_handle = c2u.abort_handle(); let u2c_handle = u2c.abort_handle(); let inactivity_timeout = ws_inactivity_timeout; let max_lifetime = ws_max_lifetime; let watchdog = tokio::spawn(async move { let check_interval = std::time::Duration::from_secs(5); let mut last_seen = 0u64; loop { tokio::select! { _ = tokio::time::sleep(check_interval) => {} _ = cancel.cancelled() => { debug!("WebSocket tunnel cancelled by shutdown"); c2u_handle.abort(); u2c_handle.abort(); break; } } // Check max lifetime if start.elapsed() >= max_lifetime { debug!("WebSocket tunnel exceeded max lifetime, closing"); c2u_handle.abort(); u2c_handle.abort(); break; } // Check inactivity let current = la_watch.load(Ordering::Relaxed); if current == last_seen { let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { debug!("WebSocket tunnel inactive for {}ms, closing", elapsed_since_activity); c2u_handle.abort(); u2c_handle.abort(); break; } } last_seen = current; } }); let bytes_in = c2u.await.unwrap_or(0); let bytes_out = u2c.await.unwrap_or(0); watchdog.abort(); debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out); upstream_selector.connection_ended(&upstream_key_owned); if let Some(ref rid) = route_id_owned { metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()), Some(&source_ip_owned)); } }); let body: BoxBody = BoxBody::new( http_body_util::Empty::::new().map_err(|never| match never {}) ); Ok(client_resp.body(body).unwrap()) } /// Build a test response from config (no upstream connection needed). 
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response> { let mut response = Response::builder() .status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK)); if let Some(headers) = response.headers_mut() { for (key, value) in &config.headers { if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) { if let Ok(val) = hyper::header::HeaderValue::from_str(value) { headers.insert(name, val); } } } } let body = Full::new(Bytes::from(config.body.clone())) .map_err(|never| match never {}); response.body(BoxBody::new(body)).unwrap() } /// Apply URL rewriting rules from route config, using the compiled regex cache. fn apply_url_rewrite(&self, path: &str, route: &rustproxy_config::RouteConfig) -> String { let rewrite = match route.action.advanced.as_ref() .and_then(|a| a.url_rewrite.as_ref()) { Some(r) => r, None => return path.to_string(), }; // Determine what to rewrite let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) { // Only rewrite the path portion (before ?) match path.split_once('?') { Some((p, q)) => (p.to_string(), format!("?{}", q)), None => (path.to_string(), String::new()), } } else { (path.to_string(), String::new()) }; // Look up or compile the regex, caching for future requests let cached = self.regex_cache.get(&rewrite.pattern); if let Some(re) = cached { let result = re.replace_all(&subject, rewrite.target.as_str()); return format!("{}{}", result, suffix); } // Not cached — compile and insert match Regex::new(&rewrite.pattern) { Ok(re) => { let result = re.replace_all(&subject, rewrite.target.as_str()); let out = format!("{}{}", result, suffix); self.regex_cache.insert(rewrite.pattern.clone(), re); out } Err(e) => { warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e); path.to_string() } } } /// Serve a static file from the configured directory. 
fn serve_static_file( path: &str, config: &rustproxy_config::RouteStaticFiles, ) -> Response> { use std::path::Path; let root = Path::new(&config.root); // Sanitize path to prevent directory traversal let clean_path = path.trim_start_matches('/'); let clean_path = clean_path.replace("..", ""); let mut file_path = root.join(&clean_path); // If path points to a directory, try index files if file_path.is_dir() || clean_path.is_empty() { let index_files = config.index_files.as_deref() .or(config.index.as_deref()) .unwrap_or(&[]); let default_index = vec!["index.html".to_string()]; let index_files = if index_files.is_empty() { &default_index } else { index_files }; let mut found = false; for index in index_files { let candidate = if clean_path.is_empty() { root.join(index) } else { file_path.join(index) }; if candidate.is_file() { file_path = candidate; found = true; break; } } if !found { return error_response(StatusCode::NOT_FOUND, "Not found"); } } // Ensure the resolved path is within the root (prevent traversal) let canonical_root = match root.canonicalize() { Ok(p) => p, Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"), }; let canonical_file = match file_path.canonicalize() { Ok(p) => p, Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"), }; if !canonical_file.starts_with(&canonical_root) { return error_response(StatusCode::FORBIDDEN, "Forbidden"); } // Check if symlinks are allowed if config.follow_symlinks == Some(false) && canonical_file != file_path { return error_response(StatusCode::FORBIDDEN, "Forbidden"); } // Read the file match std::fs::read(&file_path) { Ok(content) => { let content_type = guess_content_type(&file_path); let mut response = Response::builder() .status(StatusCode::OK) .header("Content-Type", content_type); // Apply cache-control if configured if let Some(ref cache_control) = config.cache_control { response = response.header("Cache-Control", cache_control.as_str()); } // Apply custom headers if let 
Some(ref headers) = config.headers { for (key, value) in headers { response = response.header(key.as_str(), value.as_str()); } } let body = Full::new(Bytes::from(content)) .map_err(|never| match never {}); response.body(BoxBody::new(body)).unwrap() } Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"), } } } /// Guess MIME content type from file extension. fn guess_content_type(path: &std::path::Path) -> &'static str { match path.extension().and_then(|e| e.to_str()) { Some("html") | Some("htm") => "text/html; charset=utf-8", Some("css") => "text/css; charset=utf-8", Some("js") | Some("mjs") => "application/javascript; charset=utf-8", Some("json") => "application/json; charset=utf-8", Some("xml") => "application/xml; charset=utf-8", Some("txt") => "text/plain; charset=utf-8", Some("png") => "image/png", Some("jpg") | Some("jpeg") => "image/jpeg", Some("gif") => "image/gif", Some("svg") => "image/svg+xml", Some("ico") => "image/x-icon", Some("woff") => "font/woff", Some("woff2") => "font/woff2", Some("ttf") => "font/ttf", Some("pdf") => "application/pdf", Some("wasm") => "application/wasm", _ => "application/octet-stream", } } impl HttpProxyService { /// Build a default backend TLS config with InsecureVerifier. /// Used as fallback when no shared config is injected from tls_handler. fn default_backend_tls_config() -> Arc { let _ = rustls::crypto::ring::default_provider().install_default(); let config = rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(InsecureBackendVerifier)) .with_no_client_auth(); Arc::new(config) } /// Build a default backend TLS config with ALPN h2+http/1.1 for auto-detection. /// Used as fallback when no shared ALPN config is injected from tls_handler. 
fn default_backend_tls_config_with_alpn() -> Arc { let _ = rustls::crypto::ring::default_provider().install_default(); let mut config = rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(InsecureBackendVerifier)) .with_no_client_auth(); config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; Arc::new(config) } } /// Insecure certificate verifier for backend TLS connections (fallback only). /// The production path uses the shared config from tls_handler which has the same /// behavior but with session resumption across all outbound connections. #[derive(Debug)] struct InsecureBackendVerifier; impl rustls::client::danger::ServerCertVerifier for InsecureBackendVerifier { fn verify_server_cert( &self, _end_entity: &rustls::pki_types::CertificateDer<'_>, _intermediates: &[rustls::pki_types::CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, ] } } impl Default for HttpProxyService { fn default() -> Self { Self { 
route_manager: Arc::new(RouteManager::new(vec![])), metrics: Arc::new(MetricsCollector::new()), upstream_selector: UpstreamSelector::new(), connect_timeout: DEFAULT_CONNECT_TIMEOUT, route_rate_limiters: Arc::new(DashMap::new()), request_counter: AtomicU64::new(0), regex_cache: DashMap::new(), backend_tls_config: Self::default_backend_tls_config(), backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(), connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()), protocol_cache: Arc::new(crate::protocol_cache::ProtocolCache::new()), http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT, ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT, ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME, } } } fn error_response(status: StatusCode, message: &str) -> Response> { let body = Full::new(Bytes::from(message.to_string())) .map_err(|never| match never {}); Response::builder() .status(status) .header("Content-Type", "text/plain") .body(BoxBody::new(body)) .unwrap() }