Files
smartproxy/rust/crates/rustproxy-http/src/proxy_service.rs

1407 lines
56 KiB
Rust
Raw Normal View History

//! Hyper-based HTTP proxy service.
//!
//! Accepts decrypted TCP streams (from TLS termination or plain TCP),
//! parses HTTP requests, matches routes, and forwards to upstream backends.
//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use dashmap::DashMap;
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use regex::Regex;
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use std::pin::Pin;
use std::task::{Context, Poll};
use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;
use rustproxy_security::RateLimiter;
use crate::counting_body::{CountingBody, Direction};
use crate::request_filter::RequestFilter;
use crate::response_filter::ResponseFilter;
use crate::upstream_selector::UpstreamSelector;
/// Default upstream connect timeout (30 seconds).
///
/// Used by `HttpProxyService::new`; `with_connect_timeout` lets callers override it.
const DEFAULT_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
/// Default WebSocket inactivity timeout (1 hour).
// NOTE(review): not referenced in the visible part of this file; presumably
// consumed by the WebSocket tunnel watchdog further down — confirm.
const DEFAULT_WS_INACTIVITY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3600);
/// Default WebSocket max lifetime (24 hours).
// NOTE(review): same as above — presumably the upper bound on a tunnel's total
// lifetime; confirm against the WebSocket handling code.
const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400);
/// Backend stream that can be either plain TCP or TLS-wrapped.
/// Used for `terminate-and-reencrypt` mode where the backend requires TLS.
///
/// Both variants are exposed through the `AsyncRead`/`AsyncWrite` impls on this
/// type, which delegate to the wrapped stream, so callers can treat the two
/// transports uniformly.
pub(crate) enum BackendStream {
/// Unencrypted TCP connection to the backend.
Plain(TcpStream),
/// Client-side TLS session layered over a TCP connection to the backend.
Tls(tokio_rustls::client::TlsStream<TcpStream>),
}
// Reads are delegated to whichever transport the active variant wraps.
impl tokio::io::AsyncRead for BackendStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // The enum is `Unpin`, so the pin can be peeled off and each inner
        // stream re-pinned for the delegated call.
        match Pin::into_inner(self) {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
// Writes are delegated to whichever transport the active variant wraps.
//
// Besides the required methods, the vectored-write hooks are forwarded too.
// Without them the trait defaults apply (`is_write_vectored` reports `false`
// and `poll_write_vectored` writes only the first non-empty buffer), which
// hides vectored-write support of the inner stream from callers.
impl tokio::io::AsyncWrite for BackendStream {
    /// Write `buf` to the underlying plain or TLS stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            BackendStream::Plain(s) => tokio::io::AsyncWrite::poll_write(Pin::new(s), cx, buf),
            BackendStream::Tls(s) => tokio::io::AsyncWrite::poll_write(Pin::new(s), cx, buf),
        }
    }

    /// Write several buffers in one call, using the inner stream's own
    /// vectored implementation (or its default fallback).
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            BackendStream::Plain(s) => {
                tokio::io::AsyncWrite::poll_write_vectored(Pin::new(s), cx, bufs)
            }
            BackendStream::Tls(s) => {
                tokio::io::AsyncWrite::poll_write_vectored(Pin::new(s), cx, bufs)
            }
        }
    }

    /// Report whether the wrapped stream has an efficient vectored write.
    fn is_write_vectored(&self) -> bool {
        match self {
            BackendStream::Plain(s) => tokio::io::AsyncWrite::is_write_vectored(s),
            BackendStream::Tls(s) => tokio::io::AsyncWrite::is_write_vectored(s),
        }
    }

    /// Flush buffered data on the underlying stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            BackendStream::Plain(s) => tokio::io::AsyncWrite::poll_flush(Pin::new(s), cx),
            BackendStream::Tls(s) => tokio::io::AsyncWrite::poll_flush(Pin::new(s), cx),
        }
    }

    /// Shut down the write side of the underlying stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            BackendStream::Plain(s) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(s), cx),
            BackendStream::Tls(s) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(s), cx),
        }
    }
}
/// Connect to a backend over TLS using the shared backend TLS config
/// (from tls_handler). Session resumption is automatic.
///
/// Steps: validate `host` as a TLS server name, open the TCP connection,
/// enable `TCP_NODELAY` and a 60-second keepalive, then run the TLS handshake.
/// No timeout is applied here; the callers in this file wrap the call in
/// `tokio::time::timeout`.
///
/// # Errors
/// Returns an error when `host` is not a valid server name (DNS name or IP
/// address), when the TCP connection or `set_nodelay` fails, or when the TLS
/// handshake fails. A failure to set keepalive is ignored on purpose.
async fn connect_tls_backend(
    backend_tls_config: &Arc<rustls::ClientConfig>,
    host: &str,
    port: u16,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
    // Validate the name first so a malformed host fails before any socket is opened.
    let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
    let connector = tokio_rustls::TlsConnector::from(Arc::clone(backend_tls_config));
    // The `(host, port)` form also resolves bare IPv6 literals such as `::1`,
    // which a formatted "host:port" string cannot express unambiguously.
    let stream = TcpStream::connect((host, port)).await?;
    stream.set_nodelay(true)?;
    // Apply keepalive with 60s default (best effort: the result is ignored).
    let _ = socket2::SockRef::from(&stream).set_tcp_keepalive(
        &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60))
    );
    let tls_stream = connector.connect(server_name, stream).await?;
    debug!("Backend TLS connection established to {}:{}", host, port);
    Ok(tls_stream)
}
/// HTTP proxy service that processes HTTP traffic.
///
/// Holds the state shared by all connections handled by the service: the route
/// table, metrics sink, upstream selection, per-route rate limiters, and the
/// backend TLS config and connection pool.
pub struct HttpProxyService {
/// Route table consulted for every request to find the matching route.
route_manager: Arc<RouteManager>,
/// Metrics sink; request counts are recorded here and it is handed to the
/// byte-counting body wrappers.
metrics: Arc<MetricsCollector>,
/// Selects the upstream for a request and is told when connections start and end.
upstream_selector: UpstreamSelector,
/// Timeout for connecting to upstream backends.
connect_timeout: std::time::Duration,
/// Per-route rate limiters (keyed by route ID).
route_rate_limiters: Arc<DashMap<String, Arc<RateLimiter>>>,
/// Request counter for periodic rate limiter cleanup.
request_counter: AtomicU64,
/// Cache of compiled URL rewrite regexes (keyed by pattern string).
regex_cache: DashMap<String, Regex>,
/// Shared backend TLS config for session resumption across connections.
backend_tls_config: Arc<rustls::ClientConfig>,
/// Backend connection pool for reusing keep-alive connections.
connection_pool: Arc<crate::connection_pool::ConnectionPool>,
}
impl HttpProxyService {
/// Create a service that uses `DEFAULT_CONNECT_TIMEOUT` for upstream connects.
///
/// All caches, the rate-limiter map and the connection pool start empty; the
/// backend TLS config is the built-in default until `set_backend_tls_config`
/// replaces it.
pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
    Self::with_connect_timeout(route_manager, metrics, DEFAULT_CONNECT_TIMEOUT)
}
/// Create a service whose upstream connects give up after `connect_timeout`.
///
/// Apart from the timeout, the service starts with fresh state: a new upstream
/// selector, empty rate-limiter and regex caches, a zeroed request counter, the
/// default backend TLS config and an empty connection pool.
pub fn with_connect_timeout(
    route_manager: Arc<RouteManager>,
    metrics: Arc<MetricsCollector>,
    connect_timeout: std::time::Duration,
) -> Self {
    let upstream_selector = UpstreamSelector::new();
    let route_rate_limiters = Arc::new(DashMap::new());
    let request_counter = AtomicU64::new(0);
    let regex_cache = DashMap::new();
    let backend_tls_config = Self::default_backend_tls_config();
    let connection_pool = Arc::new(crate::connection_pool::ConnectionPool::new());

    Self {
        route_manager,
        metrics,
        upstream_selector,
        connect_timeout,
        route_rate_limiters,
        request_counter,
        regex_cache,
        backend_tls_config,
        connection_pool,
    }
}
/// Install the shared backend TLS config (enables session resumption).
///
/// Intended to be called after construction to inject the config owned by
/// tls_handler; the previously held config handle is released.
pub fn set_backend_tls_config(&mut self, config: Arc<rustls::ClientConfig>) {
    // Swap in the new handle and drop the old one right away.
    drop(std::mem::replace(&mut self.backend_tls_config, config));
}
/// Prune caches for route IDs that are no longer active.
/// Call after route updates to prevent unbounded growth.
pub fn prune_stale_routes(&self, active_route_ids: &std::collections::HashSet<String>) {
self.route_rate_limiters.retain(|k, _| active_route_ids.contains(k));
self.regex_cache.clear();
self.upstream_selector.reset_round_robin();
}
/// Handle an incoming HTTP connection on a plain TCP stream.
pub async fn handle_connection(
self: Arc<Self>,
stream: TcpStream,
peer_addr: std::net::SocketAddr,
port: u16,
cancel: CancellationToken,
) {
self.handle_io(stream, peer_addr, port, cancel).await;
}
/// Serve HTTP on any IO type (plain TCP or TLS-terminated).
///
/// The connection is driven by `hyper_util::server::conn::auto::Builder`, which
/// auto-detects h1 vs h2 based on ALPN negotiation (TLS) or connection preface
/// (h2c), with support for HTTP/1.1 upgrades (WebSocket) and HTTP/2 CONNECT.
/// Every request is dispatched to `handle_request`.
///
/// When `cancel` fires, the connection is asked to shut down gracefully and is
/// then driven to completion. Connection errors are only logged at debug level.
pub async fn handle_io<I>(
    self: Arc<Self>,
    stream: I,
    peer_addr: std::net::SocketAddr,
    port: u16,
    cancel: CancellationToken,
)
where
    I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    let io = TokioIo::new(stream);

    // Each request gets its own handle to the service and its own token clone.
    let request_cancel = cancel.clone();
    let service = hyper::service::service_fn(move |req: Request<Incoming>| {
        let this = Arc::clone(&self);
        let token = request_cancel.clone();
        async move { this.handle_request(req, peer_addr, port, token).await }
    });

    // Auto-detecting builder; the `_with_upgrades` variant is required for
    // h1 Upgrade (WebSocket) and h2 CONNECT.
    let builder = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new());
    // The connection future is !Unpin, so it lives in a heap pin.
    let mut conn = Box::pin(builder.serve_connection_with_upgrades(io, service));

    tokio::select! {
        outcome = conn.as_mut() => {
            match outcome {
                Ok(()) => {}
                Err(err) => debug!("HTTP connection error from {}: {}", peer_addr, err),
            }
        }
        _ = cancel.cancelled() => {
            // Graceful shutdown: in-flight work finishes, no new requests are accepted.
            conn.as_mut().graceful_shutdown();
            match conn.await {
                Ok(()) => {}
                Err(err) => debug!("HTTP connection error during shutdown from {}: {}", peer_addr, err),
            }
        }
    }
}
/// Handle a single HTTP request.
async fn handle_request(
&self,
req: Request<Incoming>,
peer_addr: std::net::SocketAddr,
port: u16,
cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let host = req.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.map(|h| {
// Strip port from host header
h.split(':').next().unwrap_or(h).to_string()
})
// HTTP/2 uses :authority pseudo-header instead of Host;
// hyper maps it to the URI authority component
.or_else(|| req.uri().host().map(|h| h.to_string()));
let path = req.uri().path().to_string();
let method = req.method().clone();
// Extract headers for matching
let headers: HashMap<String, String> = req.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr);
// Check for CORS preflight
if method == hyper::Method::OPTIONS {
if let Some(response) = RequestFilter::handle_cors_preflight(&req) {
return Ok(response);
}
}
// Match route
let ctx = rustproxy_routing::MatchContext {
port,
domain: host.as_deref(),
path: Some(&path),
client_ip: Some(&peer_addr.ip().to_string()),
tls_version: None,
headers: Some(&headers),
is_tls: false,
protocol: Some("http"),
};
let route_match = match self.route_manager.find_route(&ctx) {
Some(rm) => rm,
None => {
debug!("No route matched for HTTP request to {:?}{}", host, path);
return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched"));
}
};
let route_id = route_match.route.id.as_deref();
let ip_str = peer_addr.ip().to_string();
self.metrics.record_http_request();
// Apply request filters (IP check, rate limiting, auth)
if let Some(ref security) = route_match.route.security {
// Look up or create a shared rate limiter for this route
let rate_limiter = security.rate_limit.as_ref()
.filter(|rl| rl.enabled)
.map(|rl| {
let route_key = route_id.unwrap_or("__default__").to_string();
self.route_rate_limiters
.entry(route_key)
.or_insert_with(|| Arc::new(RateLimiter::new(rl.max_requests, rl.window)))
.clone()
});
if let Some(response) = RequestFilter::apply_with_rate_limiter(
security, &req, &peer_addr, rate_limiter.as_ref(),
) {
return Ok(response);
}
}
// Periodic rate limiter cleanup (every 1000 requests)
let count = self.request_counter.fetch_add(1, Ordering::Relaxed);
if count % 1000 == 0 {
for entry in self.route_rate_limiters.iter() {
entry.value().cleanup();
}
}
// Check for test response (returns immediately, no upstream needed)
if let Some(ref advanced) = route_match.route.action.advanced {
if let Some(ref test_response) = advanced.test_response {
return Ok(Self::build_test_response(test_response));
}
}
// Check for static file serving
if let Some(ref advanced) = route_match.route.action.advanced {
if let Some(ref static_files) = advanced.static_files {
return Ok(Self::serve_static_file(&path, static_files));
}
}
// Select upstream
let target = match route_match.target {
Some(t) => t,
None => {
return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available"));
}
};
let mut upstream = self.upstream_selector.select(target, &peer_addr, port);
// If the route uses terminate-and-reencrypt, always re-encrypt to backend
if let Some(ref tls) = route_match.route.action.tls {
if tls.mode == rustproxy_config::TlsMode::TerminateAndReencrypt {
upstream.use_tls = true;
}
}
let upstream_key = format!("{}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_started(&upstream_key);
// Check for WebSocket upgrade
let is_websocket = req.headers()
.get("upgrade")
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);
if is_websocket {
let result = self.handle_websocket_upgrade(
req, peer_addr, &upstream, route_match.route, route_id, &upstream_key, cancel, &ip_str,
).await;
// Note: for WebSocket, connection_ended is called inside
// the spawned tunnel task when the connection closes.
return result;
}
// Determine backend protocol
let use_h2 = route_match.route.action.options.as_ref()
.and_then(|o| o.backend_protocol.as_ref())
.map(|p| *p == rustproxy_config::BackendProtocol::Http2)
.unwrap_or(false);
// Build the upstream path (path + query), applying URL rewriting if configured
let upstream_path = {
let raw_path = match req.uri().query() {
Some(q) => format!("{}?{}", path, q),
None => path.clone(),
};
self.apply_url_rewrite(&raw_path, &route_match.route)
};
// Build upstream request - stream body instead of buffering
let (parts, body) = req.into_parts();
// Apply request headers from route config
let mut upstream_headers = parts.headers.clone();
if let Some(ref route_headers) = route_match.route.headers {
if let Some(ref request_headers) = route_headers.request {
for (key, value) in request_headers {
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
upstream_headers.insert(name, val);
}
}
}
}
}
// Ensure Host header is set (HTTP/2 requests don't have Host; need it for h1 backends)
if !upstream_headers.contains_key("host") {
if let Some(ref h) = host {
if let Ok(val) = hyper::header::HeaderValue::from_str(h) {
upstream_headers.insert(hyper::header::HOST, val);
}
}
}
// Add standard reverse-proxy headers (X-Forwarded-*)
{
let original_host = host.as_deref().unwrap_or("");
let forwarded_proto = if route_match.route.action.tls.as_ref()
.map(|t| matches!(t.mode,
rustproxy_config::TlsMode::Terminate
| rustproxy_config::TlsMode::TerminateAndReencrypt))
.unwrap_or(false)
{
"https"
} else {
"http"
};
// X-Forwarded-For: append client IP to existing chain
let client_ip = peer_addr.ip().to_string();
let xff_value = if let Some(existing) = upstream_headers.get("x-forwarded-for") {
format!("{}, {}", existing.to_str().unwrap_or(""), client_ip)
} else {
client_ip
};
if let Ok(val) = hyper::header::HeaderValue::from_str(&xff_value) {
upstream_headers.insert(
hyper::header::HeaderName::from_static("x-forwarded-for"),
val,
);
}
// X-Forwarded-Host: original Host header
if let Ok(val) = hyper::header::HeaderValue::from_str(original_host) {
upstream_headers.insert(
hyper::header::HeaderName::from_static("x-forwarded-host"),
val,
);
}
// X-Forwarded-Proto: original client protocol
if let Ok(val) = hyper::header::HeaderValue::from_str(forwarded_proto) {
upstream_headers.insert(
hyper::header::HeaderName::from_static("x-forwarded-proto"),
val,
);
}
}
// --- Connection pooling: try reusing an existing connection first ---
let pool_key = crate::connection_pool::PoolKey {
host: upstream.host.clone(),
port: upstream.port,
use_tls: upstream.use_tls,
h2: use_h2,
};
// Try pooled connection first (H2 only — H2 senders are Clone and multiplexed,
// so checkout doesn't consume request parts. For H1, we try pool inside forward_h1.)
if use_h2 {
if let Some(sender) = self.connection_pool.checkout_h2(&pool_key) {
let result = self.forward_h2_pooled(
sender, parts, body, upstream_headers, &upstream_path,
route_match.route, route_id, &ip_str, &pool_key,
).await;
self.upstream_selector.connection_ended(&upstream_key);
return result;
}
}
// Fresh connection path
let backend = if upstream.use_tls {
match tokio::time::timeout(
self.connect_timeout,
connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port),
).await {
Ok(Ok(tls)) => BackendStream::Tls(tls),
Ok(Err(e)) => {
error!("Failed TLS connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(&upstream_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable"));
}
Err(_) => {
error!("Upstream TLS connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(&upstream_key);
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout"));
}
}
} else {
match tokio::time::timeout(
self.connect_timeout,
TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)),
).await {
Ok(Ok(s)) => {
s.set_nodelay(true).ok();
let _ = socket2::SockRef::from(&s).set_tcp_keepalive(
&socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60))
);
BackendStream::Plain(s)
}
Ok(Err(e)) => {
error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(&upstream_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
}
Err(_) => {
error!("Upstream connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(&upstream_key);
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
}
}
};
let io = TokioIo::new(backend);
let result = if use_h2 {
self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, &pool_key).await
} else {
self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, &pool_key).await
};
self.upstream_selector.connection_ended(&upstream_key);
result
}
/// Forward request to backend via HTTP/1.1 with body streaming.
/// Tries a pooled connection first; if unavailable, uses the fresh IO connection.
async fn forward_h1(
&self,
io: TokioIo<BackendStream>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
_upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
source_ip: &str,
pool_key: &crate::connection_pool::PoolKey,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Try pooled H1 connection first — avoids TCP+TLS handshake
if let Some(pooled_sender) = self.connection_pool.checkout_h1(pool_key) {
return self.forward_h1_with_sender(
pooled_sender, parts, body, upstream_headers, upstream_path,
route, route_id, source_ip, pool_key,
).await;
}
// Fresh connection: explicitly type the handshake with BoxBody for uniform pool type
let (sender, conn): (
hyper::client::conn::http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
hyper::client::conn::http1::Connection<TokioIo<BackendStream>, BoxBody<Bytes, hyper::Error>>,
) = match hyper::client::conn::http1::handshake(io).await {
Ok(h) => h,
Err(e) => {
error!("Upstream handshake failed: {}", e);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed"));
}
};
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("Upstream connection error: {}", e);
}
});
self.forward_h1_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip, pool_key).await
}
/// Common H1 forwarding logic used by both fresh and pooled paths.
async fn forward_h1_with_sender(
&self,
mut sender: hyper::client::conn::http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
source_ip: &str,
pool_key: &crate::connection_pool::PoolKey,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Always use HTTP/1.1 for h1 backend connections (h2 incoming requests have version HTTP/2.0)
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_path)
.version(hyper::Version::HTTP_11);
if let Some(headers) = upstream_req.headers_mut() {
*headers = upstream_headers;
}
// Wrap the request body in CountingBody then box it for the uniform pool type
let counting_req_body = CountingBody::new(
body,
Arc::clone(&self.metrics),
route_id.map(|s| s.to_string()),
Some(source_ip.to_string()),
Direction::In,
);
let boxed_body: BoxBody<Bytes, hyper::Error> = BoxBody::new(counting_req_body);
let upstream_req = upstream_req.body(boxed_body).unwrap();
let upstream_response = match sender.send_request(upstream_req).await {
Ok(resp) => resp,
Err(e) => {
error!("Upstream request failed: {}", e);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed"));
}
};
// Return sender to pool (body streams lazily, sender is reusable once response head is received)
self.connection_pool.checkin_h1(pool_key.clone(), sender);
self.build_streaming_response(upstream_response, route, route_id, source_ip).await
}
/// Forward request to backend via HTTP/2 with body streaming (fresh connection).
/// Registers the h2 sender in the pool for future multiplexed requests.
async fn forward_h2(
&self,
io: TokioIo<BackendStream>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
_upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
source_ip: &str,
pool_key: &crate::connection_pool::PoolKey,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let exec = hyper_util::rt::TokioExecutor::new();
// Explicitly type the handshake with BoxBody for uniform pool type
let (sender, conn): (
hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
hyper::client::conn::http2::Connection<TokioIo<BackendStream>, BoxBody<Bytes, hyper::Error>, hyper_util::rt::TokioExecutor>,
) = match hyper::client::conn::http2::handshake(exec, io).await {
Ok(h) => h,
Err(e) => {
error!("HTTP/2 upstream handshake failed: {}", e);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed"));
}
};
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("HTTP/2 upstream connection error: {}", e);
}
});
// Register for multiplexed reuse
self.connection_pool.register_h2(pool_key.clone(), sender.clone());
self.forward_h2_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip).await
}
/// Forward a request on an HTTP/2 sender taken from the connection pool.
///
/// Pooled and fresh senders share the same forwarding logic, so this simply
/// hands everything to `forward_h2_with_sender`. `_pool_key` is accepted but
/// not used.
async fn forward_h2_pooled(
    &self,
    sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
    parts: hyper::http::request::Parts,
    body: Incoming,
    upstream_headers: hyper::HeaderMap,
    upstream_path: &str,
    route: &rustproxy_config::RouteConfig,
    route_id: Option<&str>,
    source_ip: &str,
    _pool_key: &crate::connection_pool::PoolKey,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    let forwarded = self.forward_h2_with_sender(
        sender, parts, body, upstream_headers, upstream_path,
        route, route_id, source_ip,
    );
    forwarded.await
}
/// Common H2 forwarding logic used by both fresh and pooled paths.
async fn forward_h2_with_sender(
&self,
mut sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
source_ip: &str,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_path);
if let Some(headers) = upstream_req.headers_mut() {
*headers = upstream_headers;
}
// Wrap the request body in CountingBody then box it for the uniform pool type
let counting_req_body = CountingBody::new(
body,
Arc::clone(&self.metrics),
route_id.map(|s| s.to_string()),
Some(source_ip.to_string()),
Direction::In,
);
let boxed_body: BoxBody<Bytes, hyper::Error> = BoxBody::new(counting_req_body);
let upstream_req = upstream_req.body(boxed_body).unwrap();
let upstream_response = match sender.send_request(upstream_req).await {
Ok(resp) => resp,
Err(e) => {
error!("HTTP/2 upstream request failed: {}", e);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed"));
}
};
self.build_streaming_response(upstream_response, route, route_id, source_ip).await
}
/// Build the client-facing response from an upstream response, streaming the body.
///
/// The response body is wrapped in a `CountingBody` that counts bytes as they
/// stream from upstream to client.
async fn build_streaming_response(
&self,
upstream_response: Response<Incoming>,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
source_ip: &str,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let (resp_parts, resp_body) = upstream_response.into_parts();
let mut response = Response::builder()
.status(resp_parts.status);
if let Some(headers) = response.headers_mut() {
*headers = resp_parts.headers;
ResponseFilter::apply_headers(route, headers, None);
}
// Wrap the response body in CountingBody to track bytes_out.
// CountingBody will report bytes and we close the connection metric
// after the body stream completes (not before it even starts).
let counting_body = CountingBody::new(
resp_body,
Arc::clone(&self.metrics),
route_id.map(|s| s.to_string()),
Some(source_ip.to_string()),
Direction::Out,
);
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(counting_body);
Ok(response.body(body).unwrap())
}
/// Handle a WebSocket upgrade request.
async fn handle_websocket_upgrade(
&self,
req: Request<Incoming>,
peer_addr: std::net::SocketAddr,
upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
upstream_key: &str,
cancel: CancellationToken,
source_ip: &str,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// Get WebSocket config from route
let ws_config = route.action.websocket.as_ref();
// Check allowed origins if configured
if let Some(ws) = ws_config {
if let Some(ref allowed_origins) = ws.allowed_origins {
let origin = req.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) {
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed"));
}
}
}
info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);
// Connect to upstream with timeout (TLS if upstream.use_tls is set)
let mut upstream_stream: BackendStream = if upstream.use_tls {
match tokio::time::timeout(
self.connect_timeout,
connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port),
).await {
Ok(Ok(tls)) => BackendStream::Tls(tls),
Ok(Err(e)) => {
error!("WebSocket: failed TLS connect upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable"));
}
Err(_) => {
error!("WebSocket: upstream TLS connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout"));
}
}
} else {
match tokio::time::timeout(
self.connect_timeout,
TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)),
).await {
Ok(Ok(s)) => {
s.set_nodelay(true).ok();
let _ = socket2::SockRef::from(&s).set_tcp_keepalive(
&socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60))
);
BackendStream::Plain(s)
}
Ok(Err(e)) => {
error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
}
Err(_) => {
error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
}
}
};
let path = req.uri().path().to_string();
let upstream_path = {
let raw = match req.uri().query() {
Some(q) => format!("{}?{}", path, q),
None => path,
};
// Apply rewrite_path if configured
if let Some(ws) = ws_config {
if let Some(ref rewrite_path) = ws.rewrite_path {
rewrite_path.clone()
} else {
raw
}
} else {
raw
}
};
let (parts, _body) = req.into_parts();
let mut raw_request = format!(
"{} {} HTTP/1.1\r\n",
parts.method, upstream_path
);
// Copy all original headers (preserving the client's Host header).
// Skip X-Forwarded-* since we set them ourselves below.
let mut has_host_header = false;
for (name, value) in parts.headers.iter() {
let name_str = name.as_str();
if name_str == "x-forwarded-for"
|| name_str == "x-forwarded-host"
|| name_str == "x-forwarded-proto"
{
continue;
}
if name_str == "host" {
has_host_header = true;
}
raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or("")));
}
// HTTP/2 requests don't have Host header; add one from URI authority for h1 backends
let ws_host = parts.uri.host().map(|h| h.to_string());
if !has_host_header {
if let Some(ref h) = ws_host {
raw_request.push_str(&format!("host: {}\r\n", h));
}
}
// Add standard reverse-proxy headers (X-Forwarded-*)
{
let original_host = parts.headers.get("host")
.and_then(|h| h.to_str().ok())
.or(ws_host.as_deref())
.unwrap_or("");
let forwarded_proto = if route.action.tls.as_ref()
.map(|t| matches!(t.mode,
rustproxy_config::TlsMode::Terminate
| rustproxy_config::TlsMode::TerminateAndReencrypt))
.unwrap_or(false)
{
"https"
} else {
"http"
};
let client_ip = peer_addr.ip().to_string();
let xff_value = if let Some(existing) = parts.headers.get("x-forwarded-for") {
format!("{}, {}", existing.to_str().unwrap_or(""), client_ip)
} else {
client_ip
};
raw_request.push_str(&format!("x-forwarded-for: {}\r\n", xff_value));
raw_request.push_str(&format!("x-forwarded-host: {}\r\n", original_host));
raw_request.push_str(&format!("x-forwarded-proto: {}\r\n", forwarded_proto));
}
if let Some(ref route_headers) = route.headers {
if let Some(ref request_headers) = route_headers.request {
for (key, value) in request_headers {
raw_request.push_str(&format!("{}: {}\r\n", key, value));
}
}
}
// Apply WebSocket custom headers
if let Some(ws) = ws_config {
if let Some(ref custom_headers) = ws.custom_headers {
for (key, value) in custom_headers {
raw_request.push_str(&format!("{}: {}\r\n", key, value));
}
}
}
raw_request.push_str("\r\n");
if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await {
error!("WebSocket: failed to send upgrade request to upstream: {}", e);
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed"));
}
let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1];
loop {
match upstream_stream.read(&mut temp).await {
Ok(0) => {
error!("WebSocket: upstream closed before completing handshake");
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed"));
}
Ok(_) => {
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len-4..] == *b"\r\n\r\n" {
break;
}
}
if response_buf.len() > 8192 {
error!("WebSocket: upstream response headers too large");
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large"));
}
}
Err(e) => {
error!("WebSocket: failed to read upstream response: {}", e);
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed"));
}
}
}
let response_str = String::from_utf8_lossy(&response_buf);
let status_line = response_str.lines().next().unwrap_or("");
let status_code = status_line
.split_whitespace()
.nth(1)
.and_then(|s| s.parse::<u16>().ok())
.unwrap_or(0);
if status_code != 101 {
debug!("WebSocket: upstream rejected upgrade with status {}", status_code);
self.upstream_selector.connection_ended(upstream_key);
return Ok(error_response(
StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY),
"WebSocket upgrade rejected by backend",
));
}
let mut client_resp = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS);
if let Some(resp_headers) = client_resp.headers_mut() {
for line in response_str.lines().skip(1) {
let line = line.trim();
if line.is_empty() {
break;
}
if let Some((name, value)) = line.split_once(':') {
let name = name.trim();
let value = value.trim();
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
resp_headers.insert(header_name, header_value);
}
}
}
}
}
let on_client_upgrade = hyper::upgrade::on(
Request::from_parts(parts, http_body_util::Empty::<Bytes>::new())
);
let metrics = Arc::clone(&self.metrics);
let route_id_owned = route_id.map(|s| s.to_string());
let source_ip_owned = source_ip.to_string();
let upstream_selector = self.upstream_selector.clone();
let upstream_key_owned = upstream_key.to_string();
tokio::spawn(async move {
let client_upgraded = match on_client_upgrade.await {
Ok(upgraded) => upgraded,
Err(e) => {
debug!("WebSocket: client upgrade failed: {}", e);
upstream_selector.connection_ended(&upstream_key_owned);
return;
}
};
let client_io = TokioIo::new(client_upgraded);
let (mut cr, mut cw) = tokio::io::split(client_io);
let (mut ur, mut uw) = tokio::io::split(upstream_stream);
// Shared activity tracker for the watchdog
let last_activity = Arc::new(AtomicU64::new(0));
let start = std::time::Instant::now();
let la1 = Arc::clone(&last_activity);
let c2u = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match cr.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if uw.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = uw.shutdown().await;
total
});
let la2 = Arc::clone(&last_activity);
let u2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match ur.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if cw.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = cw.shutdown().await;
total
});
// Watchdog: monitors inactivity, max lifetime, and cancellation
let la_watch = Arc::clone(&last_activity);
let c2u_handle = c2u.abort_handle();
let u2c_handle = u2c.abort_handle();
let inactivity_timeout = DEFAULT_WS_INACTIVITY_TIMEOUT;
let max_lifetime = DEFAULT_WS_MAX_LIFETIME;
let watchdog = tokio::spawn(async move {
let check_interval = std::time::Duration::from_secs(5);
let mut last_seen = 0u64;
loop {
tokio::select! {
_ = tokio::time::sleep(check_interval) => {}
_ = cancel.cancelled() => {
debug!("WebSocket tunnel cancelled by shutdown");
c2u_handle.abort();
u2c_handle.abort();
break;
}
}
// Check max lifetime
if start.elapsed() >= max_lifetime {
debug!("WebSocket tunnel exceeded max lifetime, closing");
c2u_handle.abort();
u2c_handle.abort();
break;
}
// Check inactivity
let current = la_watch.load(Ordering::Relaxed);
if current == last_seen {
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
debug!("WebSocket tunnel inactive for {}ms, closing", elapsed_since_activity);
c2u_handle.abort();
u2c_handle.abort();
break;
}
}
last_seen = current;
}
});
let bytes_in = c2u.await.unwrap_or(0);
let bytes_out = u2c.await.unwrap_or(0);
watchdog.abort();
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
upstream_selector.connection_ended(&upstream_key_owned);
if let Some(ref rid) = route_id_owned {
metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()), Some(&source_ip_owned));
}
});
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
http_body_util::Empty::<Bytes>::new().map_err(|never| match never {})
);
Ok(client_resp.body(body).unwrap())
}
/// Produce a canned response straight from route configuration.
///
/// No upstream is contacted. The configured status is used when it is a valid
/// HTTP status code, otherwise `200 OK`. Configured headers whose name or value
/// cannot be represented as an HTTP header are skipped; a later header with the
/// same name replaces an earlier one. The configured body is copied verbatim.
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response<BoxBody<Bytes, hyper::Error>> {
    let status = StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK);
    let mut builder = Response::builder().status(status);

    if let Some(header_map) = builder.headers_mut() {
        // Keep only the pairs that parse as a valid name *and* a valid value.
        let parsed = config.headers.iter().filter_map(|(key, value)| {
            let name = hyper::header::HeaderName::from_bytes(key.as_bytes()).ok()?;
            let val = hyper::header::HeaderValue::from_str(value).ok()?;
            Some((name, val))
        });
        for (name, val) in parsed {
            header_map.insert(name, val);
        }
    }

    let payload = Bytes::from(config.body.clone());
    let body = Full::new(payload).map_err(|never| match never {});
    builder.body(BoxBody::new(body)).unwrap()
}
/// Rewrite a request path according to the route's `url_rewrite` settings.
///
/// Returns `path` unchanged when the route has no rewrite rule or when the
/// rule's pattern fails to compile (a warning is logged in that case).
/// When `only_rewrite_path` is set, only the part before the first `?` is
/// rewritten and the query string is re-attached untouched; otherwise the
/// whole string is rewritten. Compiled patterns are stored in `regex_cache`,
/// keyed by the pattern text, and reused on later calls.
fn apply_url_rewrite(&self, path: &str, route: &rustproxy_config::RouteConfig) -> String {
    let configured = route
        .action
        .advanced
        .as_ref()
        .and_then(|advanced| advanced.url_rewrite.as_ref());
    let rule = match configured {
        None => return path.to_string(),
        Some(rule) => rule,
    };

    // Decide which slice is fed to the regex and which query (if any) is kept aside.
    let path_only = rule.only_rewrite_path.unwrap_or(false);
    let (input, query): (&str, Option<&str>) = match path.split_once('?') {
        Some((before, after)) if path_only => (before, Some(after)),
        _ => (path, None),
    };

    // Glue the preserved query string back onto a rewritten path.
    let finish = |rewritten: &str| -> String {
        match query {
            Some(q) => format!("{}?{}", rewritten, q),
            None => rewritten.to_string(),
        }
    };

    // Fast path: the pattern was compiled by an earlier request.
    if let Some(compiled) = self.regex_cache.get(&rule.pattern) {
        let rewritten = compiled.replace_all(input, rule.target.as_str());
        return finish(&*rewritten);
    }

    // Slow path: compile now and remember the result for future requests.
    match Regex::new(&rule.pattern) {
        Err(e) => {
            warn!("Invalid URL rewrite pattern '{}': {}", rule.pattern, e);
            path.to_string()
        }
        Ok(compiled) => {
            let rewritten = finish(&*compiled.replace_all(input, rule.target.as_str()));
            self.regex_cache.insert(rule.pattern.clone(), compiled);
            rewritten
        }
    }
}
/// Serve a static file from the configured directory.
fn serve_static_file(
path: &str,
config: &rustproxy_config::RouteStaticFiles,
) -> Response<BoxBody<Bytes, hyper::Error>> {
use std::path::Path;
let root = Path::new(&config.root);
// Sanitize path to prevent directory traversal
let clean_path = path.trim_start_matches('/');
let clean_path = clean_path.replace("..", "");
let mut file_path = root.join(&clean_path);
// If path points to a directory, try index files
if file_path.is_dir() || clean_path.is_empty() {
let index_files = config.index_files.as_deref()
.or(config.index.as_deref())
.unwrap_or(&[]);
let default_index = vec!["index.html".to_string()];
let index_files = if index_files.is_empty() { &default_index } else { index_files };
let mut found = false;
for index in index_files {
let candidate = if clean_path.is_empty() {
root.join(index)
} else {
file_path.join(index)
};
if candidate.is_file() {
file_path = candidate;
found = true;
break;
}
}
if !found {
return error_response(StatusCode::NOT_FOUND, "Not found");
}
}
// Ensure the resolved path is within the root (prevent traversal)
let canonical_root = match root.canonicalize() {
Ok(p) => p,
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
};
let canonical_file = match file_path.canonicalize() {
Ok(p) => p,
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
};
if !canonical_file.starts_with(&canonical_root) {
return error_response(StatusCode::FORBIDDEN, "Forbidden");
}
// Check if symlinks are allowed
if config.follow_symlinks == Some(false) && canonical_file != file_path {
return error_response(StatusCode::FORBIDDEN, "Forbidden");
}
// Read the file
match std::fs::read(&file_path) {
Ok(content) => {
let content_type = guess_content_type(&file_path);
let mut response = Response::builder()
.status(StatusCode::OK)
.header("Content-Type", content_type);
// Apply cache-control if configured
if let Some(ref cache_control) = config.cache_control {
response = response.header("Cache-Control", cache_control.as_str());
}
// Apply custom headers
if let Some(ref headers) = config.headers {
for (key, value) in headers {
response = response.header(key.as_str(), value.as_str());
}
}
let body = Full::new(Bytes::from(content))
.map_err(|never| match never {});
response.body(BoxBody::new(body)).unwrap()
}
Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"),
}
}
}
/// Guess MIME content type from file extension.
///
/// Only the final extension of `path` is considered, and it is compared
/// ASCII-case-insensitively so that `PHOTO.JPG` and `photo.jpg` resolve to the
/// same type. Paths without an extension, with a non-UTF-8 extension, or with
/// an extension not listed below yield `application/octet-stream`.
fn guess_content_type(path: &std::path::Path) -> &'static str {
    // Normalise once so the table below only needs lowercase entries.
    let extension = path
        .extension()
        .and_then(|e| e.to_str())
        .map(|e| e.to_ascii_lowercase());

    match extension.as_deref() {
        Some("html") | Some("htm") => "text/html; charset=utf-8",
        Some("css") => "text/css; charset=utf-8",
        Some("js") | Some("mjs") => "application/javascript; charset=utf-8",
        Some("json") => "application/json; charset=utf-8",
        Some("xml") => "application/xml; charset=utf-8",
        Some("txt") => "text/plain; charset=utf-8",
        Some("png") => "image/png",
        Some("jpg") | Some("jpeg") => "image/jpeg",
        Some("gif") => "image/gif",
        Some("svg") => "image/svg+xml",
        Some("ico") => "image/x-icon",
        Some("woff") => "font/woff",
        Some("woff2") => "font/woff2",
        Some("ttf") => "font/ttf",
        Some("pdf") => "application/pdf",
        Some("wasm") => "application/wasm",
        _ => "application/octet-stream",
    }
}
impl HttpProxyService {
    /// Construct the fallback TLS client configuration for backend connections.
    ///
    /// The configuration uses [`InsecureBackendVerifier`], i.e. backend
    /// certificates are accepted without validation, and presents no client
    /// certificate. It is used when no shared config is injected from
    /// tls_handler.
    fn default_backend_tls_config() -> Arc<rustls::ClientConfig> {
        // Make sure a process-wide crypto provider is registered; the outcome
        // of the installation attempt is intentionally ignored.
        let _ = rustls::crypto::ring::default_provider().install_default();

        let verifier = Arc::new(InsecureBackendVerifier);
        Arc::new(
            rustls::ClientConfig::builder()
                .dangerous()
                .with_custom_certificate_verifier(verifier)
                .with_no_client_auth(),
        )
    }
}
/// Insecure certificate verifier for backend TLS connections (fallback only).
///
/// Every certificate and handshake signature presented by a backend is
/// accepted without inspection, so connections made with this verifier are
/// encrypted but the backend's identity is NOT authenticated.
///
/// The production path uses the shared config from tls_handler which has the same
/// behavior but with session resumption across all outbound connections.
#[derive(Debug)]
struct InsecureBackendVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureBackendVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}
impl Default for HttpProxyService {
    /// Build a service with no routes, fresh metrics, empty caches and limiter
    /// map, the default connect timeout, the fallback backend TLS config and a
    /// new connection pool.
    fn default() -> Self {
        // Components are created in field order, one per line, then moved into
        // the struct with field-init shorthand.
        let route_manager = Arc::new(RouteManager::new(Vec::new()));
        let metrics = Arc::new(MetricsCollector::new());
        let upstream_selector = UpstreamSelector::new();
        let connect_timeout = DEFAULT_CONNECT_TIMEOUT;
        let route_rate_limiters = Arc::new(DashMap::new());
        let request_counter = AtomicU64::new(0);
        let regex_cache = DashMap::new();
        let backend_tls_config = Self::default_backend_tls_config();
        let connection_pool = Arc::new(crate::connection_pool::ConnectionPool::new());

        Self {
            route_manager,
            metrics,
            upstream_selector,
            connect_timeout,
            route_rate_limiters,
            request_counter,
            regex_cache,
            backend_tls_config,
            connection_pool,
        }
    }
}
/// Build a plain-text response carrying `message` as its body and `status` as
/// its status code. The `Content-Type` header is always `text/plain`.
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
    let payload = Bytes::from(String::from(message));
    // `Full` cannot fail, so its error type is mapped away via an empty match.
    let boxed = BoxBody::new(Full::new(payload).map_err(|never| match never {}));

    let builder = Response::builder()
        .status(status)
        .header("Content-Type", "text/plain");
    builder.body(boxed).unwrap()
}