//! Hyper-based HTTP proxy service. //! //! Accepts decrypted TCP streams (from TLS termination or plain TCP), //! parses HTTP requests, matches routes, and forwards to upstream backends. //! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade. use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use bytes::Bytes; use dashmap::DashMap; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use hyper::body::Incoming; use hyper::{Request, Response, StatusCode}; use hyper_util::rt::TokioIo; use regex::Regex; use tokio::net::TcpStream; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; use std::pin::Pin; use std::task::{Context, Poll}; use rustproxy_routing::RouteManager; use rustproxy_metrics::MetricsCollector; use rustproxy_security::RateLimiter; use crate::counting_body::{CountingBody, Direction}; use crate::request_filter::RequestFilter; use crate::response_filter::ResponseFilter; use crate::upstream_selector::UpstreamSelector; /// Default upstream connect timeout (30 seconds). const DEFAULT_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); /// Default WebSocket inactivity timeout (1 hour). const DEFAULT_WS_INACTIVITY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3600); /// Default WebSocket max lifetime (24 hours). const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400); /// Backend stream that can be either plain TCP or TLS-wrapped. /// Used for `terminate-and-reencrypt` mode where the backend requires TLS. 
pub(crate) enum BackendStream { Plain(TcpStream), Tls(tokio_rustls::client::TlsStream), } impl tokio::io::AsyncRead for BackendStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { match self.get_mut() { BackendStream::Plain(s) => Pin::new(s).poll_read(cx, buf), BackendStream::Tls(s) => Pin::new(s).poll_read(cx, buf), } } } impl tokio::io::AsyncWrite for BackendStream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match self.get_mut() { BackendStream::Plain(s) => Pin::new(s).poll_write(cx, buf), BackendStream::Tls(s) => Pin::new(s).poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { BackendStream::Plain(s) => Pin::new(s).poll_flush(cx), BackendStream::Tls(s) => Pin::new(s).poll_flush(cx), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { BackendStream::Plain(s) => Pin::new(s).poll_shutdown(cx), BackendStream::Tls(s) => Pin::new(s).poll_shutdown(cx), } } } /// Connect to a backend over TLS using the shared backend TLS config /// (from tls_handler). Session resumption is automatic. async fn connect_tls_backend( backend_tls_config: &Arc, host: &str, port: u16, ) -> Result, Box> { let connector = tokio_rustls::TlsConnector::from(Arc::clone(backend_tls_config)); let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; stream.set_nodelay(true)?; // Apply keepalive with 60s default let _ = socket2::SockRef::from(&stream).set_tcp_keepalive( &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) ); let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?; let tls_stream = connector.connect(server_name, stream).await?; debug!("Backend TLS connection established to {}:{}", host, port); Ok(tls_stream) } /// HTTP proxy service that processes HTTP traffic. 
pub struct HttpProxyService {
    /// Route table consulted for every request.
    route_manager: Arc<RouteManager>,
    /// Shared metrics sink for request counts and byte totals.
    metrics: Arc<MetricsCollector>,
    /// Chooses the concrete backend for a route target.
    upstream_selector: UpstreamSelector,
    /// Timeout for connecting to upstream backends.
    connect_timeout: std::time::Duration,
    /// Per-route rate limiters (keyed by route ID).
    route_rate_limiters: Arc<DashMap<String, Arc<RateLimiter>>>,
    /// Request counter for periodic rate limiter cleanup.
    request_counter: AtomicU64,
    /// Cache of compiled URL rewrite regexes (keyed by pattern string).
    regex_cache: DashMap<String, Regex>,
    /// Shared backend TLS config for session resumption across connections.
    backend_tls_config: Arc<rustls::ClientConfig>,
    /// Backend connection pool for reusing keep-alive connections.
    connection_pool: Arc<crate::connection_pool::ConnectionPool>,
}

impl HttpProxyService {
    /// Create a service that uses `DEFAULT_CONNECT_TIMEOUT` for upstream connects.
    pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
        Self::with_connect_timeout(route_manager, metrics, DEFAULT_CONNECT_TIMEOUT)
    }

    /// Create with a custom connect timeout.
    pub fn with_connect_timeout(
        route_manager: Arc<RouteManager>,
        metrics: Arc<MetricsCollector>,
        connect_timeout: std::time::Duration,
    ) -> Self {
        Self {
            route_manager,
            metrics,
            upstream_selector: UpstreamSelector::new(),
            connect_timeout,
            route_rate_limiters: Arc::new(DashMap::new()),
            request_counter: AtomicU64::new(0),
            regex_cache: DashMap::new(),
            backend_tls_config: Self::default_backend_tls_config(),
            connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()),
        }
    }

    /// Replace the backend TLS config (enables session resumption).
    /// Call this after construction to inject the shared config from tls_handler.
    pub fn set_backend_tls_config(&mut self, config: Arc<rustls::ClientConfig>) {
        self.backend_tls_config = config;
    }

    /// Prune caches for route IDs that are no longer active.
    /// Call after route updates to prevent unbounded growth.
    /// Also clears the URL-rewrite regex cache and resets round-robin state.
pub fn prune_stale_routes(&self, active_route_ids: &std::collections::HashSet) { self.route_rate_limiters.retain(|k, _| active_route_ids.contains(k)); self.regex_cache.clear(); self.upstream_selector.reset_round_robin(); } /// Handle an incoming HTTP connection on a plain TCP stream. pub async fn handle_connection( self: Arc, stream: TcpStream, peer_addr: std::net::SocketAddr, port: u16, cancel: CancellationToken, ) { self.handle_io(stream, peer_addr, port, cancel).await; } /// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated). /// /// Uses `hyper_util::server::conn::auto::Builder` to auto-detect h1 vs h2 /// based on ALPN negotiation (TLS) or connection preface (h2c). /// Supports HTTP/1.1 upgrades (WebSocket) and HTTP/2 CONNECT. /// Responds to graceful shutdown via the cancel token. pub async fn handle_io( self: Arc, stream: I, peer_addr: std::net::SocketAddr, port: u16, cancel: CancellationToken, ) where I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { let io = TokioIo::new(stream); let cancel_inner = cancel.clone(); let service = hyper::service::service_fn(move |req: Request| { let svc = Arc::clone(&self); let peer = peer_addr; let cn = cancel_inner.clone(); async move { svc.handle_request(req, peer, port, cn).await } }); // Auto-detect h1 vs h2 based on ALPN / connection preface. // serve_connection_with_upgrades supports h1 Upgrade (WebSocket) and h2 CONNECT. let builder = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); let conn = builder.serve_connection_with_upgrades(io, service); // Pin on the heap — auto::UpgradeableConnection is !Unpin let mut conn = Box::pin(conn); // Use select to support graceful shutdown via cancellation token tokio::select! 
{ result = conn.as_mut() => { if let Err(e) = result { debug!("HTTP connection error from {}: {}", peer_addr, e); } } _ = cancel.cancelled() => { // Graceful shutdown: let in-flight request finish, stop accepting new ones conn.as_mut().graceful_shutdown(); if let Err(e) = conn.await { debug!("HTTP connection error during shutdown from {}: {}", peer_addr, e); } } } } /// Handle a single HTTP request. async fn handle_request( &self, req: Request, peer_addr: std::net::SocketAddr, port: u16, cancel: CancellationToken, ) -> Result>, hyper::Error> { let host = req.headers() .get("host") .and_then(|h| h.to_str().ok()) .map(|h| { // Strip port from host header h.split(':').next().unwrap_or(h).to_string() }) // HTTP/2 uses :authority pseudo-header instead of Host; // hyper maps it to the URI authority component .or_else(|| req.uri().host().map(|h| h.to_string())); let path = req.uri().path().to_string(); let method = req.method().clone(); // Extract headers for matching let headers: HashMap = req.headers() .iter() .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) .collect(); debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr); // Check for CORS preflight if method == hyper::Method::OPTIONS { if let Some(response) = RequestFilter::handle_cors_preflight(&req) { return Ok(response); } } // Match route let ctx = rustproxy_routing::MatchContext { port, domain: host.as_deref(), path: Some(&path), client_ip: Some(&peer_addr.ip().to_string()), tls_version: None, headers: Some(&headers), is_tls: false, protocol: Some("http"), }; let route_match = match self.route_manager.find_route(&ctx) { Some(rm) => rm, None => { debug!("No route matched for HTTP request to {:?}{}", host, path); return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched")); } }; let route_id = route_match.route.id.as_deref(); let ip_str = peer_addr.ip().to_string(); self.metrics.record_http_request(); // Apply request filters (IP check, rate limiting, auth) if let 
Some(ref security) = route_match.route.security { // Look up or create a shared rate limiter for this route let rate_limiter = security.rate_limit.as_ref() .filter(|rl| rl.enabled) .map(|rl| { let route_key = route_id.unwrap_or("__default__").to_string(); self.route_rate_limiters .entry(route_key) .or_insert_with(|| Arc::new(RateLimiter::new(rl.max_requests, rl.window))) .clone() }); if let Some(response) = RequestFilter::apply_with_rate_limiter( security, &req, &peer_addr, rate_limiter.as_ref(), ) { return Ok(response); } } // Periodic rate limiter cleanup (every 1000 requests) let count = self.request_counter.fetch_add(1, Ordering::Relaxed); if count % 1000 == 0 { for entry in self.route_rate_limiters.iter() { entry.value().cleanup(); } } // Check for test response (returns immediately, no upstream needed) if let Some(ref advanced) = route_match.route.action.advanced { if let Some(ref test_response) = advanced.test_response { return Ok(Self::build_test_response(test_response)); } } // Check for static file serving if let Some(ref advanced) = route_match.route.action.advanced { if let Some(ref static_files) = advanced.static_files { return Ok(Self::serve_static_file(&path, static_files)); } } // Select upstream let target = match route_match.target { Some(t) => t, None => { return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available")); } }; let mut upstream = self.upstream_selector.select(target, &peer_addr, port); // If the route uses terminate-and-reencrypt, always re-encrypt to backend if let Some(ref tls) = route_match.route.action.tls { if tls.mode == rustproxy_config::TlsMode::TerminateAndReencrypt { upstream.use_tls = true; } } let upstream_key = format!("{}:{}", upstream.host, upstream.port); self.upstream_selector.connection_started(&upstream_key); // Check for WebSocket upgrade let is_websocket = req.headers() .get("upgrade") .and_then(|v| v.to_str().ok()) .map(|v| v.eq_ignore_ascii_case("websocket")) .unwrap_or(false); if is_websocket { let 
result = self.handle_websocket_upgrade( req, peer_addr, &upstream, route_match.route, route_id, &upstream_key, cancel, &ip_str, ).await; // Note: for WebSocket, connection_ended is called inside // the spawned tunnel task when the connection closes. return result; } // Determine backend protocol let use_h2 = route_match.route.action.options.as_ref() .and_then(|o| o.backend_protocol.as_ref()) .map(|p| *p == rustproxy_config::BackendProtocol::Http2) .unwrap_or(false); // Build the upstream path (path + query), applying URL rewriting if configured let upstream_path = { let raw_path = match req.uri().query() { Some(q) => format!("{}?{}", path, q), None => path.clone(), }; self.apply_url_rewrite(&raw_path, &route_match.route) }; // Build upstream request - stream body instead of buffering let (parts, body) = req.into_parts(); // Apply request headers from route config let mut upstream_headers = parts.headers.clone(); if let Some(ref route_headers) = route_match.route.headers { if let Some(ref request_headers) = route_headers.request { for (key, value) in request_headers { if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) { if let Ok(val) = hyper::header::HeaderValue::from_str(value) { upstream_headers.insert(name, val); } } } } } // Ensure Host header is set (HTTP/2 requests don't have Host; need it for h1 backends) if !upstream_headers.contains_key("host") { if let Some(ref h) = host { if let Ok(val) = hyper::header::HeaderValue::from_str(h) { upstream_headers.insert(hyper::header::HOST, val); } } } // Add standard reverse-proxy headers (X-Forwarded-*) { let original_host = host.as_deref().unwrap_or(""); let forwarded_proto = if route_match.route.action.tls.as_ref() .map(|t| matches!(t.mode, rustproxy_config::TlsMode::Terminate | rustproxy_config::TlsMode::TerminateAndReencrypt)) .unwrap_or(false) { "https" } else { "http" }; // X-Forwarded-For: append client IP to existing chain let client_ip = peer_addr.ip().to_string(); let xff_value = if let 
Some(existing) = upstream_headers.get("x-forwarded-for") { format!("{}, {}", existing.to_str().unwrap_or(""), client_ip) } else { client_ip }; if let Ok(val) = hyper::header::HeaderValue::from_str(&xff_value) { upstream_headers.insert( hyper::header::HeaderName::from_static("x-forwarded-for"), val, ); } // X-Forwarded-Host: original Host header if let Ok(val) = hyper::header::HeaderValue::from_str(original_host) { upstream_headers.insert( hyper::header::HeaderName::from_static("x-forwarded-host"), val, ); } // X-Forwarded-Proto: original client protocol if let Ok(val) = hyper::header::HeaderValue::from_str(forwarded_proto) { upstream_headers.insert( hyper::header::HeaderName::from_static("x-forwarded-proto"), val, ); } } // --- Connection pooling: try reusing an existing connection first --- let pool_key = crate::connection_pool::PoolKey { host: upstream.host.clone(), port: upstream.port, use_tls: upstream.use_tls, h2: use_h2, }; // Try pooled connection first (H2 only — H2 senders are Clone and multiplexed, // so checkout doesn't consume request parts. For H1, we try pool inside forward_h1.) 
if use_h2 { if let Some(sender) = self.connection_pool.checkout_h2(&pool_key) { let result = self.forward_h2_pooled( sender, parts, body, upstream_headers, &upstream_path, route_match.route, route_id, &ip_str, &pool_key, ).await; self.upstream_selector.connection_ended(&upstream_key); return result; } } // Fresh connection path let backend = if upstream.use_tls { match tokio::time::timeout( self.connect_timeout, connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port), ).await { Ok(Ok(tls)) => BackendStream::Tls(tls), Ok(Err(e)) => { error!("Failed TLS connect to upstream {}:{}: {}", upstream.host, upstream.port, e); self.upstream_selector.connection_ended(&upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable")); } Err(_) => { error!("Upstream TLS connect timeout for {}:{}", upstream.host, upstream.port); self.upstream_selector.connection_ended(&upstream_key); return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout")); } } } else { match tokio::time::timeout( self.connect_timeout, TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), ).await { Ok(Ok(s)) => { s.set_nodelay(true).ok(); let _ = socket2::SockRef::from(&s).set_tcp_keepalive( &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) ); BackendStream::Plain(s) } Ok(Err(e)) => { error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e); self.upstream_selector.connection_ended(&upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); } Err(_) => { error!("Upstream connect timeout for {}:{}", upstream.host, upstream.port); self.upstream_selector.connection_ended(&upstream_key); return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); } } }; let io = TokioIo::new(backend); let result = if use_h2 { self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, 
&pool_key).await } else { self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, &pool_key).await }; self.upstream_selector.connection_ended(&upstream_key); result } /// Forward request to backend via HTTP/1.1 with body streaming. /// Tries a pooled connection first; if unavailable, uses the fresh IO connection. async fn forward_h1( &self, io: TokioIo, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, _upstream: &crate::upstream_selector::UpstreamSelection, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { // Try pooled H1 connection first — avoids TCP+TLS handshake if let Some(pooled_sender) = self.connection_pool.checkout_h1(pool_key) { return self.forward_h1_with_sender( pooled_sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip, pool_key, ).await; } // Fresh connection: explicitly type the handshake with BoxBody for uniform pool type let (sender, conn): ( hyper::client::conn::http1::SendRequest>, hyper::client::conn::http1::Connection, BoxBody>, ) = match hyper::client::conn::http1::handshake(io).await { Ok(h) => h, Err(e) => { error!("Upstream handshake failed: {}", e); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed")); } }; tokio::spawn(async move { if let Err(e) = conn.await { debug!("Upstream connection error: {}", e); } }); self.forward_h1_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip, pool_key).await } /// Common H1 forwarding logic used by both fresh and pooled paths. 
async fn forward_h1_with_sender( &self, mut sender: hyper::client::conn::http1::SendRequest>, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { // Always use HTTP/1.1 for h1 backend connections (h2 incoming requests have version HTTP/2.0) let mut upstream_req = Request::builder() .method(parts.method) .uri(upstream_path) .version(hyper::Version::HTTP_11); if let Some(headers) = upstream_req.headers_mut() { *headers = upstream_headers; } // Wrap the request body in CountingBody then box it for the uniform pool type let counting_req_body = CountingBody::new( body, Arc::clone(&self.metrics), route_id.map(|s| s.to_string()), Some(source_ip.to_string()), Direction::In, ); let boxed_body: BoxBody = BoxBody::new(counting_req_body); let upstream_req = upstream_req.body(boxed_body).unwrap(); let upstream_response = match sender.send_request(upstream_req).await { Ok(resp) => resp, Err(e) => { error!("Upstream request failed: {}", e); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed")); } }; // Return sender to pool (body streams lazily, sender is reusable once response head is received) self.connection_pool.checkin_h1(pool_key.clone(), sender); self.build_streaming_response(upstream_response, route, route_id, source_ip).await } /// Forward request to backend via HTTP/2 with body streaming (fresh connection). /// Registers the h2 sender in the pool for future multiplexed requests. 
async fn forward_h2( &self, io: TokioIo, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, _upstream: &crate::upstream_selector::UpstreamSelection, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { let exec = hyper_util::rt::TokioExecutor::new(); // Explicitly type the handshake with BoxBody for uniform pool type let (sender, conn): ( hyper::client::conn::http2::SendRequest>, hyper::client::conn::http2::Connection, BoxBody, hyper_util::rt::TokioExecutor>, ) = match hyper::client::conn::http2::handshake(exec, io).await { Ok(h) => h, Err(e) => { error!("HTTP/2 upstream handshake failed: {}", e); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed")); } }; tokio::spawn(async move { if let Err(e) = conn.await { debug!("HTTP/2 upstream connection error: {}", e); } }); // Register for multiplexed reuse self.connection_pool.register_h2(pool_key.clone(), sender.clone()); self.forward_h2_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip).await } /// Forward request using an existing (pooled) HTTP/2 sender. async fn forward_h2_pooled( &self, sender: hyper::client::conn::http2::SendRequest>, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, _pool_key: &crate::connection_pool::PoolKey, ) -> Result>, hyper::Error> { self.forward_h2_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip).await } /// Common H2 forwarding logic used by both fresh and pooled paths. 
async fn forward_h2_with_sender( &self, mut sender: hyper::client::conn::http2::SendRequest>, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, upstream_path: &str, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, ) -> Result>, hyper::Error> { let mut upstream_req = Request::builder() .method(parts.method) .uri(upstream_path); if let Some(headers) = upstream_req.headers_mut() { *headers = upstream_headers; } // Wrap the request body in CountingBody then box it for the uniform pool type let counting_req_body = CountingBody::new( body, Arc::clone(&self.metrics), route_id.map(|s| s.to_string()), Some(source_ip.to_string()), Direction::In, ); let boxed_body: BoxBody = BoxBody::new(counting_req_body); let upstream_req = upstream_req.body(boxed_body).unwrap(); let upstream_response = match sender.send_request(upstream_req).await { Ok(resp) => resp, Err(e) => { error!("HTTP/2 upstream request failed: {}", e); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed")); } }; self.build_streaming_response(upstream_response, route, route_id, source_ip).await } /// Build the client-facing response from an upstream response, streaming the body. /// /// The response body is wrapped in a `CountingBody` that counts bytes as they /// stream from upstream to client. async fn build_streaming_response( &self, upstream_response: Response, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, source_ip: &str, ) -> Result>, hyper::Error> { let (resp_parts, resp_body) = upstream_response.into_parts(); let mut response = Response::builder() .status(resp_parts.status); if let Some(headers) = response.headers_mut() { *headers = resp_parts.headers; ResponseFilter::apply_headers(route, headers, None); } // Wrap the response body in CountingBody to track bytes_out. // CountingBody will report bytes and we close the connection metric // after the body stream completes (not before it even starts). 
let counting_body = CountingBody::new( resp_body, Arc::clone(&self.metrics), route_id.map(|s| s.to_string()), Some(source_ip.to_string()), Direction::Out, ); let body: BoxBody = BoxBody::new(counting_body); Ok(response.body(body).unwrap()) } /// Handle a WebSocket upgrade request. async fn handle_websocket_upgrade( &self, req: Request, peer_addr: std::net::SocketAddr, upstream: &crate::upstream_selector::UpstreamSelection, route: &rustproxy_config::RouteConfig, route_id: Option<&str>, upstream_key: &str, cancel: CancellationToken, source_ip: &str, ) -> Result>, hyper::Error> { use tokio::io::{AsyncReadExt, AsyncWriteExt}; // Get WebSocket config from route let ws_config = route.action.websocket.as_ref(); // Check allowed origins if configured if let Some(ws) = ws_config { if let Some(ref allowed_origins) = ws.allowed_origins { let origin = req.headers() .get("origin") .and_then(|v| v.to_str().ok()) .unwrap_or(""); if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) { self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed")); } } } info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port); // Connect to upstream with timeout (TLS if upstream.use_tls is set) let mut upstream_stream: BackendStream = if upstream.use_tls { match tokio::time::timeout( self.connect_timeout, connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port), ).await { Ok(Ok(tls)) => BackendStream::Tls(tls), Ok(Err(e)) => { error!("WebSocket: failed TLS connect upstream {}:{}: {}", upstream.host, upstream.port, e); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable")); } Err(_) => { error!("WebSocket: upstream TLS connect timeout for {}:{}", upstream.host, upstream.port); self.upstream_selector.connection_ended(upstream_key); return 
Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout")); } } } else { match tokio::time::timeout( self.connect_timeout, TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), ).await { Ok(Ok(s)) => { s.set_nodelay(true).ok(); let _ = socket2::SockRef::from(&s).set_tcp_keepalive( &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60)) ); BackendStream::Plain(s) } Ok(Err(e)) => { error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); } Err(_) => { error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); } } }; let path = req.uri().path().to_string(); let upstream_path = { let raw = match req.uri().query() { Some(q) => format!("{}?{}", path, q), None => path, }; // Apply rewrite_path if configured if let Some(ws) = ws_config { if let Some(ref rewrite_path) = ws.rewrite_path { rewrite_path.clone() } else { raw } } else { raw } }; let (parts, _body) = req.into_parts(); let mut raw_request = format!( "{} {} HTTP/1.1\r\n", parts.method, upstream_path ); // Copy all original headers (preserving the client's Host header). // Skip X-Forwarded-* since we set them ourselves below. 
let mut has_host_header = false; for (name, value) in parts.headers.iter() { let name_str = name.as_str(); if name_str == "x-forwarded-for" || name_str == "x-forwarded-host" || name_str == "x-forwarded-proto" { continue; } if name_str == "host" { has_host_header = true; } raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or(""))); } // HTTP/2 requests don't have Host header; add one from URI authority for h1 backends let ws_host = parts.uri.host().map(|h| h.to_string()); if !has_host_header { if let Some(ref h) = ws_host { raw_request.push_str(&format!("host: {}\r\n", h)); } } // Add standard reverse-proxy headers (X-Forwarded-*) { let original_host = parts.headers.get("host") .and_then(|h| h.to_str().ok()) .or(ws_host.as_deref()) .unwrap_or(""); let forwarded_proto = if route.action.tls.as_ref() .map(|t| matches!(t.mode, rustproxy_config::TlsMode::Terminate | rustproxy_config::TlsMode::TerminateAndReencrypt)) .unwrap_or(false) { "https" } else { "http" }; let client_ip = peer_addr.ip().to_string(); let xff_value = if let Some(existing) = parts.headers.get("x-forwarded-for") { format!("{}, {}", existing.to_str().unwrap_or(""), client_ip) } else { client_ip }; raw_request.push_str(&format!("x-forwarded-for: {}\r\n", xff_value)); raw_request.push_str(&format!("x-forwarded-host: {}\r\n", original_host)); raw_request.push_str(&format!("x-forwarded-proto: {}\r\n", forwarded_proto)); } if let Some(ref route_headers) = route.headers { if let Some(ref request_headers) = route_headers.request { for (key, value) in request_headers { raw_request.push_str(&format!("{}: {}\r\n", key, value)); } } } // Apply WebSocket custom headers if let Some(ws) = ws_config { if let Some(ref custom_headers) = ws.custom_headers { for (key, value) in custom_headers { raw_request.push_str(&format!("{}: {}\r\n", key, value)); } } } raw_request.push_str("\r\n"); if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await { error!("WebSocket: failed to send upgrade 
request to upstream: {}", e); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed")); } let mut response_buf = Vec::with_capacity(4096); let mut temp = [0u8; 1]; loop { match upstream_stream.read(&mut temp).await { Ok(0) => { error!("WebSocket: upstream closed before completing handshake"); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed")); } Ok(_) => { response_buf.push(temp[0]); if response_buf.len() >= 4 { let len = response_buf.len(); if response_buf[len-4..] == *b"\r\n\r\n" { break; } } if response_buf.len() > 8192 { error!("WebSocket: upstream response headers too large"); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large")); } } Err(e) => { error!("WebSocket: failed to read upstream response: {}", e); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed")); } } } let response_str = String::from_utf8_lossy(&response_buf); let status_line = response_str.lines().next().unwrap_or(""); let status_code = status_line .split_whitespace() .nth(1) .and_then(|s| s.parse::().ok()) .unwrap_or(0); if status_code != 101 { debug!("WebSocket: upstream rejected upgrade with status {}", status_code); self.upstream_selector.connection_ended(upstream_key); return Ok(error_response( StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY), "WebSocket upgrade rejected by backend", )); } let mut client_resp = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS); if let Some(resp_headers) = client_resp.headers_mut() { for line in response_str.lines().skip(1) { let line = line.trim(); if line.is_empty() { break; } if let Some((name, value)) = line.split_once(':') { let name = name.trim(); let value = value.trim(); if let Ok(header_name) = 
hyper::header::HeaderName::from_bytes(name.as_bytes()) { if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) { resp_headers.insert(header_name, header_value); } } } } } let on_client_upgrade = hyper::upgrade::on( Request::from_parts(parts, http_body_util::Empty::::new()) ); let metrics = Arc::clone(&self.metrics); let route_id_owned = route_id.map(|s| s.to_string()); let source_ip_owned = source_ip.to_string(); let upstream_selector = self.upstream_selector.clone(); let upstream_key_owned = upstream_key.to_string(); tokio::spawn(async move { let client_upgraded = match on_client_upgrade.await { Ok(upgraded) => upgraded, Err(e) => { debug!("WebSocket: client upgrade failed: {}", e); upstream_selector.connection_ended(&upstream_key_owned); return; } }; let client_io = TokioIo::new(client_upgraded); let (mut cr, mut cw) = tokio::io::split(client_io); let (mut ur, mut uw) = tokio::io::split(upstream_stream); // Shared activity tracker for the watchdog let last_activity = Arc::new(AtomicU64::new(0)); let start = std::time::Instant::now(); let la1 = Arc::clone(&last_activity); let c2u = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = 0u64; loop { let n = match cr.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if uw.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); } let _ = uw.shutdown().await; total }); let la2 = Arc::clone(&last_activity); let u2c = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = 0u64; loop { let n = match ur.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if cw.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); } let _ = cw.shutdown().await; total }); // Watchdog: monitors inactivity, max lifetime, and cancellation let la_watch = Arc::clone(&last_activity); let c2u_handle = 
c2u.abort_handle(); let u2c_handle = u2c.abort_handle(); let inactivity_timeout = DEFAULT_WS_INACTIVITY_TIMEOUT; let max_lifetime = DEFAULT_WS_MAX_LIFETIME; let watchdog = tokio::spawn(async move { let check_interval = std::time::Duration::from_secs(5); let mut last_seen = 0u64; loop { tokio::select! { _ = tokio::time::sleep(check_interval) => {} _ = cancel.cancelled() => { debug!("WebSocket tunnel cancelled by shutdown"); c2u_handle.abort(); u2c_handle.abort(); break; } } // Check max lifetime if start.elapsed() >= max_lifetime { debug!("WebSocket tunnel exceeded max lifetime, closing"); c2u_handle.abort(); u2c_handle.abort(); break; } // Check inactivity let current = la_watch.load(Ordering::Relaxed); if current == last_seen { let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { debug!("WebSocket tunnel inactive for {}ms, closing", elapsed_since_activity); c2u_handle.abort(); u2c_handle.abort(); break; } } last_seen = current; } }); let bytes_in = c2u.await.unwrap_or(0); let bytes_out = u2c.await.unwrap_or(0); watchdog.abort(); debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out); upstream_selector.connection_ended(&upstream_key_owned); if let Some(ref rid) = route_id_owned { metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()), Some(&source_ip_owned)); } }); let body: BoxBody = BoxBody::new( http_body_util::Empty::::new().map_err(|never| match never {}) ); Ok(client_resp.body(body).unwrap()) } /// Build a test response from config (no upstream connection needed). 
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response> { let mut response = Response::builder() .status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK)); if let Some(headers) = response.headers_mut() { for (key, value) in &config.headers { if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) { if let Ok(val) = hyper::header::HeaderValue::from_str(value) { headers.insert(name, val); } } } } let body = Full::new(Bytes::from(config.body.clone())) .map_err(|never| match never {}); response.body(BoxBody::new(body)).unwrap() } /// Apply URL rewriting rules from route config, using the compiled regex cache. fn apply_url_rewrite(&self, path: &str, route: &rustproxy_config::RouteConfig) -> String { let rewrite = match route.action.advanced.as_ref() .and_then(|a| a.url_rewrite.as_ref()) { Some(r) => r, None => return path.to_string(), }; // Determine what to rewrite let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) { // Only rewrite the path portion (before ?) match path.split_once('?') { Some((p, q)) => (p.to_string(), format!("?{}", q)), None => (path.to_string(), String::new()), } } else { (path.to_string(), String::new()) }; // Look up or compile the regex, caching for future requests let cached = self.regex_cache.get(&rewrite.pattern); if let Some(re) = cached { let result = re.replace_all(&subject, rewrite.target.as_str()); return format!("{}{}", result, suffix); } // Not cached — compile and insert match Regex::new(&rewrite.pattern) { Ok(re) => { let result = re.replace_all(&subject, rewrite.target.as_str()); let out = format!("{}{}", result, suffix); self.regex_cache.insert(rewrite.pattern.clone(), re); out } Err(e) => { warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e); path.to_string() } } } /// Serve a static file from the configured directory. 
fn serve_static_file( path: &str, config: &rustproxy_config::RouteStaticFiles, ) -> Response> { use std::path::Path; let root = Path::new(&config.root); // Sanitize path to prevent directory traversal let clean_path = path.trim_start_matches('/'); let clean_path = clean_path.replace("..", ""); let mut file_path = root.join(&clean_path); // If path points to a directory, try index files if file_path.is_dir() || clean_path.is_empty() { let index_files = config.index_files.as_deref() .or(config.index.as_deref()) .unwrap_or(&[]); let default_index = vec!["index.html".to_string()]; let index_files = if index_files.is_empty() { &default_index } else { index_files }; let mut found = false; for index in index_files { let candidate = if clean_path.is_empty() { root.join(index) } else { file_path.join(index) }; if candidate.is_file() { file_path = candidate; found = true; break; } } if !found { return error_response(StatusCode::NOT_FOUND, "Not found"); } } // Ensure the resolved path is within the root (prevent traversal) let canonical_root = match root.canonicalize() { Ok(p) => p, Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"), }; let canonical_file = match file_path.canonicalize() { Ok(p) => p, Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"), }; if !canonical_file.starts_with(&canonical_root) { return error_response(StatusCode::FORBIDDEN, "Forbidden"); } // Check if symlinks are allowed if config.follow_symlinks == Some(false) && canonical_file != file_path { return error_response(StatusCode::FORBIDDEN, "Forbidden"); } // Read the file match std::fs::read(&file_path) { Ok(content) => { let content_type = guess_content_type(&file_path); let mut response = Response::builder() .status(StatusCode::OK) .header("Content-Type", content_type); // Apply cache-control if configured if let Some(ref cache_control) = config.cache_control { response = response.header("Cache-Control", cache_control.as_str()); } // Apply custom headers if let 
Some(ref headers) = config.headers { for (key, value) in headers { response = response.header(key.as_str(), value.as_str()); } } let body = Full::new(Bytes::from(content)) .map_err(|never| match never {}); response.body(BoxBody::new(body)).unwrap() } Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"), } } } /// Guess MIME content type from file extension. fn guess_content_type(path: &std::path::Path) -> &'static str { match path.extension().and_then(|e| e.to_str()) { Some("html") | Some("htm") => "text/html; charset=utf-8", Some("css") => "text/css; charset=utf-8", Some("js") | Some("mjs") => "application/javascript; charset=utf-8", Some("json") => "application/json; charset=utf-8", Some("xml") => "application/xml; charset=utf-8", Some("txt") => "text/plain; charset=utf-8", Some("png") => "image/png", Some("jpg") | Some("jpeg") => "image/jpeg", Some("gif") => "image/gif", Some("svg") => "image/svg+xml", Some("ico") => "image/x-icon", Some("woff") => "font/woff", Some("woff2") => "font/woff2", Some("ttf") => "font/ttf", Some("pdf") => "application/pdf", Some("wasm") => "application/wasm", _ => "application/octet-stream", } } impl HttpProxyService { /// Build a default backend TLS config with InsecureVerifier. /// Used as fallback when no shared config is injected from tls_handler. fn default_backend_tls_config() -> Arc { let _ = rustls::crypto::ring::default_provider().install_default(); let config = rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(InsecureBackendVerifier)) .with_no_client_auth(); Arc::new(config) } } /// Insecure certificate verifier for backend TLS connections (fallback only). /// The production path uses the shared config from tls_handler which has the same /// behavior but with session resumption across all outbound connections. 
#[derive(Debug)] struct InsecureBackendVerifier; impl rustls::client::danger::ServerCertVerifier for InsecureBackendVerifier { fn verify_server_cert( &self, _end_entity: &rustls::pki_types::CertificateDer<'_>, _intermediates: &[rustls::pki_types::CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, ] } } impl Default for HttpProxyService { fn default() -> Self { Self { route_manager: Arc::new(RouteManager::new(vec![])), metrics: Arc::new(MetricsCollector::new()), upstream_selector: UpstreamSelector::new(), connect_timeout: DEFAULT_CONNECT_TIMEOUT, route_rate_limiters: Arc::new(DashMap::new()), request_counter: AtomicU64::new(0), regex_cache: DashMap::new(), backend_tls_config: Self::default_backend_tls_config(), connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()), } } } fn error_response(status: StatusCode, message: &str) -> Response> { let body = Full::new(Bytes::from(message.to_string())) .map_err(|never| match never {}); Response::builder() 
.status(status) .header("Content-Type", "text/plain") .body(BoxBody::new(body)) .unwrap() }