1407 lines
56 KiB
Rust
1407 lines
56 KiB
Rust
//! Hyper-based HTTP proxy service.
|
|
//!
|
|
//! Accepts decrypted TCP streams (from TLS termination or plain TCP),
|
|
//! parses HTTP requests, matches routes, and forwards to upstream backends.
|
|
//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade.
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
|
|
use bytes::Bytes;
|
|
use dashmap::DashMap;
|
|
use http_body_util::{BodyExt, Full, combinators::BoxBody};
|
|
use hyper::body::Incoming;
|
|
use hyper::{Request, Response, StatusCode};
|
|
use hyper_util::rt::TokioIo;
|
|
use regex::Regex;
|
|
use tokio::net::TcpStream;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{debug, error, info, warn};
|
|
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
|
|
use rustproxy_routing::RouteManager;
|
|
use rustproxy_metrics::MetricsCollector;
|
|
use rustproxy_security::RateLimiter;
|
|
|
|
use crate::counting_body::{CountingBody, Direction};
|
|
use crate::request_filter::RequestFilter;
|
|
use crate::response_filter::ResponseFilter;
|
|
use crate::upstream_selector::UpstreamSelector;
|
|
|
|
/// Default upstream connect timeout (30 seconds).
const DEFAULT_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

/// Default WebSocket inactivity timeout (1 hour).
// NOTE(review): not referenced in this chunk — presumably consumed by the
// WebSocket tunnel task later in the file; confirm before removing.
const DEFAULT_WS_INACTIVITY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3600);

/// Default WebSocket max lifetime (24 hours).
// NOTE(review): also not referenced in this chunk — see note above's caveat.
const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400);
|
|
|
|
/// Backend stream that can be either plain TCP or TLS-wrapped.
/// Used for `terminate-and-reencrypt` mode where the backend requires TLS.
pub(crate) enum BackendStream {
    /// Unencrypted TCP connection to the backend.
    Plain(TcpStream),
    /// TLS client connection to the backend (see `connect_tls_backend`).
    Tls(tokio_rustls::client::TlsStream<TcpStream>),
}
|
|
|
|
impl tokio::io::AsyncRead for BackendStream {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut tokio::io::ReadBuf<'_>,
|
|
) -> Poll<std::io::Result<()>> {
|
|
match self.get_mut() {
|
|
BackendStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
|
|
BackendStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl tokio::io::AsyncWrite for BackendStream {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<std::io::Result<usize>> {
|
|
match self.get_mut() {
|
|
BackendStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
|
|
BackendStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
|
|
}
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
match self.get_mut() {
|
|
BackendStream::Plain(s) => Pin::new(s).poll_flush(cx),
|
|
BackendStream::Tls(s) => Pin::new(s).poll_flush(cx),
|
|
}
|
|
}
|
|
|
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
match self.get_mut() {
|
|
BackendStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
|
|
BackendStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Connect to a backend over TLS using the shared backend TLS config
|
|
/// (from tls_handler). Session resumption is automatic.
|
|
async fn connect_tls_backend(
|
|
backend_tls_config: &Arc<rustls::ClientConfig>,
|
|
host: &str,
|
|
port: u16,
|
|
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
|
|
let connector = tokio_rustls::TlsConnector::from(Arc::clone(backend_tls_config));
|
|
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
|
stream.set_nodelay(true)?;
|
|
// Apply keepalive with 60s default
|
|
let _ = socket2::SockRef::from(&stream).set_tcp_keepalive(
|
|
&socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60))
|
|
);
|
|
|
|
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
|
|
let tls_stream = connector.connect(server_name, stream).await?;
|
|
debug!("Backend TLS connection established to {}:{}", host, port);
|
|
Ok(tls_stream)
|
|
}
|
|
|
|
/// HTTP proxy service that processes HTTP traffic.
pub struct HttpProxyService {
    /// Route table used to match incoming requests to backends.
    route_manager: Arc<RouteManager>,
    /// Shared metrics sink (request counts, byte counters via CountingBody).
    metrics: Arc<MetricsCollector>,
    /// Picks a concrete upstream target and tracks per-upstream connection counts.
    upstream_selector: UpstreamSelector,
    /// Timeout for connecting to upstream backends.
    connect_timeout: std::time::Duration,
    /// Per-route rate limiters (keyed by route ID).
    route_rate_limiters: Arc<DashMap<String, Arc<RateLimiter>>>,
    /// Request counter for periodic rate limiter cleanup.
    request_counter: AtomicU64,
    /// Cache of compiled URL rewrite regexes (keyed by pattern string).
    regex_cache: DashMap<String, Regex>,
    /// Shared backend TLS config for session resumption across connections.
    backend_tls_config: Arc<rustls::ClientConfig>,
    /// Backend connection pool for reusing keep-alive connections.
    connection_pool: Arc<crate::connection_pool::ConnectionPool>,
}
|
|
|
|
impl HttpProxyService {
|
|
pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
|
|
Self {
|
|
route_manager,
|
|
metrics,
|
|
upstream_selector: UpstreamSelector::new(),
|
|
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
|
|
route_rate_limiters: Arc::new(DashMap::new()),
|
|
request_counter: AtomicU64::new(0),
|
|
regex_cache: DashMap::new(),
|
|
backend_tls_config: Self::default_backend_tls_config(),
|
|
connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()),
|
|
}
|
|
}
|
|
|
|
/// Create with a custom connect timeout.
|
|
pub fn with_connect_timeout(
|
|
route_manager: Arc<RouteManager>,
|
|
metrics: Arc<MetricsCollector>,
|
|
connect_timeout: std::time::Duration,
|
|
) -> Self {
|
|
Self {
|
|
route_manager,
|
|
metrics,
|
|
upstream_selector: UpstreamSelector::new(),
|
|
connect_timeout,
|
|
route_rate_limiters: Arc::new(DashMap::new()),
|
|
request_counter: AtomicU64::new(0),
|
|
regex_cache: DashMap::new(),
|
|
backend_tls_config: Self::default_backend_tls_config(),
|
|
connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()),
|
|
}
|
|
}
|
|
|
|
    /// Set the shared backend TLS config (enables session resumption).
    /// Call this after construction to inject the shared config from tls_handler.
    // Replaces the default config created in the constructors; only affects
    // connections opened after the call.
    pub fn set_backend_tls_config(&mut self, config: Arc<rustls::ClientConfig>) {
        self.backend_tls_config = config;
    }
|
|
|
|
    /// Prune caches for route IDs that are no longer active.
    /// Call after route updates to prevent unbounded growth.
    pub fn prune_stale_routes(&self, active_route_ids: &std::collections::HashSet<String>) {
        // Rate limiters are keyed by route ID, so they can be pruned precisely.
        self.route_rate_limiters.retain(|k, _| active_route_ids.contains(k));
        // The regex cache is keyed by pattern string, not route ID, so it is
        // cleared wholesale here; entries repopulate lazily on next use.
        self.regex_cache.clear();
        // Drop round-robin positions so stale upstream lists don't skew selection.
        self.upstream_selector.reset_round_robin();
    }
|
|
|
|
    /// Handle an incoming HTTP connection on a plain TCP stream.
    ///
    /// Thin wrapper over [`Self::handle_io`], which accepts any
    /// AsyncRead + AsyncWrite stream (plain TCP or TLS-terminated).
    pub async fn handle_connection(
        self: Arc<Self>,
        stream: TcpStream,
        peer_addr: std::net::SocketAddr,
        port: u16,
        cancel: CancellationToken,
    ) {
        self.handle_io(stream, peer_addr, port, cancel).await;
    }
|
|
|
|
    /// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated).
    ///
    /// Uses `hyper_util::server::conn::auto::Builder` to auto-detect h1 vs h2
    /// based on ALPN negotiation (TLS) or connection preface (h2c).
    /// Supports HTTP/1.1 upgrades (WebSocket) and HTTP/2 CONNECT.
    /// Responds to graceful shutdown via the cancel token.
    pub async fn handle_io<I>(
        self: Arc<Self>,
        stream: I,
        peer_addr: std::net::SocketAddr,
        port: u16,
        cancel: CancellationToken,
    )
    where
        I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
    {
        let io = TokioIo::new(stream);

        // The service closure runs once per request; on a keep-alive
        // connection it is invoked many times, so it clones the Arc handle
        // and cancellation token each time.
        let cancel_inner = cancel.clone();
        let service = hyper::service::service_fn(move |req: Request<Incoming>| {
            let svc = Arc::clone(&self);
            let peer = peer_addr;
            let cn = cancel_inner.clone();
            async move {
                svc.handle_request(req, peer, port, cn).await
            }
        });

        // Auto-detect h1 vs h2 based on ALPN / connection preface.
        // serve_connection_with_upgrades supports h1 Upgrade (WebSocket) and h2 CONNECT.
        let builder = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new());
        let conn = builder.serve_connection_with_upgrades(io, service);
        // Pin on the heap — auto::UpgradeableConnection is !Unpin
        let mut conn = Box::pin(conn);

        // Use select to support graceful shutdown via cancellation token
        tokio::select! {
            result = conn.as_mut() => {
                if let Err(e) = result {
                    debug!("HTTP connection error from {}: {}", peer_addr, e);
                }
            }
            _ = cancel.cancelled() => {
                // Graceful shutdown: let in-flight request finish, stop accepting new ones.
                // After signalling, keep driving the connection to completion.
                conn.as_mut().graceful_shutdown();
                if let Err(e) = conn.await {
                    debug!("HTTP connection error during shutdown from {}: {}", peer_addr, e);
                }
            }
        }
    }
|
|
|
|
    /// Handle a single HTTP request.
    ///
    /// Pipeline: extract host/path/headers → CORS preflight → route match →
    /// security filters (IP, rate limit, auth) → test-response / static-file
    /// short-circuits → upstream selection → WebSocket upgrade or HTTP
    /// forwarding (pooled H2, then fresh H1/H2 connection).
    ///
    /// Invariant: every path that calls `connection_started` must eventually
    /// call `connection_ended` — for HTTP paths that happens in this function;
    /// for WebSocket it happens inside the spawned tunnel task.
    async fn handle_request(
        &self,
        req: Request<Incoming>,
        peer_addr: std::net::SocketAddr,
        port: u16,
        cancel: CancellationToken,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        let host = req.headers()
            .get("host")
            .and_then(|h| h.to_str().ok())
            .map(|h| {
                // Strip port from host header
                h.split(':').next().unwrap_or(h).to_string()
            })
            // HTTP/2 uses :authority pseudo-header instead of Host;
            // hyper maps it to the URI authority component
            .or_else(|| req.uri().host().map(|h| h.to_string()));

        let path = req.uri().path().to_string();
        let method = req.method().clone();

        // Extract headers for matching (non-UTF-8 header values become "")
        let headers: HashMap<String, String> = req.headers()
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
            .collect();

        debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr);

        // Check for CORS preflight
        if method == hyper::Method::OPTIONS {
            if let Some(response) = RequestFilter::handle_cors_preflight(&req) {
                return Ok(response);
            }
        }

        // Match route
        let ctx = rustproxy_routing::MatchContext {
            port,
            domain: host.as_deref(),
            path: Some(&path),
            client_ip: Some(&peer_addr.ip().to_string()),
            tls_version: None,
            headers: Some(&headers),
            is_tls: false,
            protocol: Some("http"),
        };

        let route_match = match self.route_manager.find_route(&ctx) {
            Some(rm) => rm,
            None => {
                debug!("No route matched for HTTP request to {:?}{}", host, path);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched"));
            }
        };

        let route_id = route_match.route.id.as_deref();
        let ip_str = peer_addr.ip().to_string();
        self.metrics.record_http_request();

        // Apply request filters (IP check, rate limiting, auth)
        if let Some(ref security) = route_match.route.security {
            // Look up or create a shared rate limiter for this route.
            // Routes without an ID share the "__default__" limiter.
            let rate_limiter = security.rate_limit.as_ref()
                .filter(|rl| rl.enabled)
                .map(|rl| {
                    let route_key = route_id.unwrap_or("__default__").to_string();
                    self.route_rate_limiters
                        .entry(route_key)
                        .or_insert_with(|| Arc::new(RateLimiter::new(rl.max_requests, rl.window)))
                        .clone()
                });
            if let Some(response) = RequestFilter::apply_with_rate_limiter(
                security, &req, &peer_addr, rate_limiter.as_ref(),
            ) {
                return Ok(response);
            }
        }

        // Periodic rate limiter cleanup (every 1000 requests)
        let count = self.request_counter.fetch_add(1, Ordering::Relaxed);
        if count % 1000 == 0 {
            for entry in self.route_rate_limiters.iter() {
                entry.value().cleanup();
            }
        }

        // Check for test response (returns immediately, no upstream needed)
        if let Some(ref advanced) = route_match.route.action.advanced {
            if let Some(ref test_response) = advanced.test_response {
                return Ok(Self::build_test_response(test_response));
            }
        }

        // Check for static file serving
        if let Some(ref advanced) = route_match.route.action.advanced {
            if let Some(ref static_files) = advanced.static_files {
                return Ok(Self::serve_static_file(&path, static_files));
            }
        }

        // Select upstream
        let target = match route_match.target {
            Some(t) => t,
            None => {
                return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available"));
            }
        };

        let mut upstream = self.upstream_selector.select(target, &peer_addr, port);

        // If the route uses terminate-and-reencrypt, always re-encrypt to backend
        if let Some(ref tls) = route_match.route.action.tls {
            if tls.mode == rustproxy_config::TlsMode::TerminateAndReencrypt {
                upstream.use_tls = true;
            }
        }

        // From this point on every return path must call connection_ended
        // (see invariant in the doc comment above).
        let upstream_key = format!("{}:{}", upstream.host, upstream.port);
        self.upstream_selector.connection_started(&upstream_key);

        // Check for WebSocket upgrade
        let is_websocket = req.headers()
            .get("upgrade")
            .and_then(|v| v.to_str().ok())
            .map(|v| v.eq_ignore_ascii_case("websocket"))
            .unwrap_or(false);

        if is_websocket {
            let result = self.handle_websocket_upgrade(
                req, peer_addr, &upstream, route_match.route, route_id, &upstream_key, cancel, &ip_str,
            ).await;
            // Note: for WebSocket, connection_ended is called inside
            // the spawned tunnel task when the connection closes.
            return result;
        }

        // Determine backend protocol
        let use_h2 = route_match.route.action.options.as_ref()
            .and_then(|o| o.backend_protocol.as_ref())
            .map(|p| *p == rustproxy_config::BackendProtocol::Http2)
            .unwrap_or(false);

        // Build the upstream path (path + query), applying URL rewriting if configured
        let upstream_path = {
            let raw_path = match req.uri().query() {
                Some(q) => format!("{}?{}", path, q),
                None => path.clone(),
            };
            self.apply_url_rewrite(&raw_path, &route_match.route)
        };

        // Build upstream request - stream body instead of buffering
        let (parts, body) = req.into_parts();

        // Apply request headers from route config (invalid names/values are
        // silently skipped rather than failing the request)
        let mut upstream_headers = parts.headers.clone();
        if let Some(ref route_headers) = route_match.route.headers {
            if let Some(ref request_headers) = route_headers.request {
                for (key, value) in request_headers {
                    if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
                        if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
                            upstream_headers.insert(name, val);
                        }
                    }
                }
            }
        }

        // Ensure Host header is set (HTTP/2 requests don't have Host; need it for h1 backends)
        if !upstream_headers.contains_key("host") {
            if let Some(ref h) = host {
                if let Ok(val) = hyper::header::HeaderValue::from_str(h) {
                    upstream_headers.insert(hyper::header::HOST, val);
                }
            }
        }

        // Add standard reverse-proxy headers (X-Forwarded-*)
        {
            let original_host = host.as_deref().unwrap_or("");
            // Client spoke HTTPS iff this route terminates TLS.
            let forwarded_proto = if route_match.route.action.tls.as_ref()
                .map(|t| matches!(t.mode,
                    rustproxy_config::TlsMode::Terminate
                    | rustproxy_config::TlsMode::TerminateAndReencrypt))
                .unwrap_or(false)
            {
                "https"
            } else {
                "http"
            };

            // X-Forwarded-For: append client IP to existing chain
            let client_ip = peer_addr.ip().to_string();
            let xff_value = if let Some(existing) = upstream_headers.get("x-forwarded-for") {
                format!("{}, {}", existing.to_str().unwrap_or(""), client_ip)
            } else {
                client_ip
            };
            if let Ok(val) = hyper::header::HeaderValue::from_str(&xff_value) {
                upstream_headers.insert(
                    hyper::header::HeaderName::from_static("x-forwarded-for"),
                    val,
                );
            }
            // X-Forwarded-Host: original Host header
            if let Ok(val) = hyper::header::HeaderValue::from_str(original_host) {
                upstream_headers.insert(
                    hyper::header::HeaderName::from_static("x-forwarded-host"),
                    val,
                );
            }
            // X-Forwarded-Proto: original client protocol
            if let Ok(val) = hyper::header::HeaderValue::from_str(forwarded_proto) {
                upstream_headers.insert(
                    hyper::header::HeaderName::from_static("x-forwarded-proto"),
                    val,
                );
            }
        }

        // --- Connection pooling: try reusing an existing connection first ---
        let pool_key = crate::connection_pool::PoolKey {
            host: upstream.host.clone(),
            port: upstream.port,
            use_tls: upstream.use_tls,
            h2: use_h2,
        };

        // Try pooled connection first (H2 only — H2 senders are Clone and multiplexed,
        // so checkout doesn't consume request parts. For H1, we try pool inside forward_h1.)
        if use_h2 {
            if let Some(sender) = self.connection_pool.checkout_h2(&pool_key) {
                let result = self.forward_h2_pooled(
                    sender, parts, body, upstream_headers, &upstream_path,
                    route_match.route, route_id, &ip_str, &pool_key,
                ).await;
                self.upstream_selector.connection_ended(&upstream_key);
                return result;
            }
        }

        // Fresh connection path: TLS or plain TCP, both bounded by connect_timeout.
        let backend = if upstream.use_tls {
            match tokio::time::timeout(
                self.connect_timeout,
                connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port),
            ).await {
                Ok(Ok(tls)) => BackendStream::Tls(tls),
                Ok(Err(e)) => {
                    error!("Failed TLS connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
                    self.upstream_selector.connection_ended(&upstream_key);
                    return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable"));
                }
                Err(_) => {
                    error!("Upstream TLS connect timeout for {}:{}", upstream.host, upstream.port);
                    self.upstream_selector.connection_ended(&upstream_key);
                    return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout"));
                }
            }
        } else {
            match tokio::time::timeout(
                self.connect_timeout,
                TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)),
            ).await {
                Ok(Ok(s)) => {
                    s.set_nodelay(true).ok();
                    // Best-effort keepalive, mirrors connect_tls_backend
                    let _ = socket2::SockRef::from(&s).set_tcp_keepalive(
                        &socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60))
                    );
                    BackendStream::Plain(s)
                }
                Ok(Err(e)) => {
                    error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
                    self.upstream_selector.connection_ended(&upstream_key);
                    return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
                }
                Err(_) => {
                    error!("Upstream connect timeout for {}:{}", upstream.host, upstream.port);
                    self.upstream_selector.connection_ended(&upstream_key);
                    return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
                }
            }
        };

        let io = TokioIo::new(backend);

        let result = if use_h2 {
            self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, &pool_key).await
        } else {
            self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id, &ip_str, &pool_key).await
        };
        self.upstream_selector.connection_ended(&upstream_key);
        result
    }
|
|
|
|
    /// Forward request to backend via HTTP/1.1 with body streaming.
    /// Tries a pooled connection first; if unavailable, uses the fresh IO connection.
    async fn forward_h1(
        &self,
        io: TokioIo<BackendStream>,
        parts: hyper::http::request::Parts,
        body: Incoming,
        upstream_headers: hyper::HeaderMap,
        upstream_path: &str,
        _upstream: &crate::upstream_selector::UpstreamSelection,
        route: &rustproxy_config::RouteConfig,
        route_id: Option<&str>,
        source_ip: &str,
        pool_key: &crate::connection_pool::PoolKey,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        // Try pooled H1 connection first — avoids TCP+TLS handshake.
        // NOTE: the freshly-connected `io` is dropped unused in this case.
        if let Some(pooled_sender) = self.connection_pool.checkout_h1(pool_key) {
            return self.forward_h1_with_sender(
                pooled_sender, parts, body, upstream_headers, upstream_path,
                route, route_id, source_ip, pool_key,
            ).await;
        }

        // Fresh connection: explicitly type the handshake with BoxBody for uniform pool type
        let (sender, conn): (
            hyper::client::conn::http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
            hyper::client::conn::http1::Connection<TokioIo<BackendStream>, BoxBody<Bytes, hyper::Error>>,
        ) = match hyper::client::conn::http1::handshake(io).await {
            Ok(h) => h,
            Err(e) => {
                error!("Upstream handshake failed: {}", e);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed"));
            }
        };

        // Drive the connection in the background; it finishes when the
        // backend closes or the sender (and pool entry) is dropped.
        tokio::spawn(async move {
            if let Err(e) = conn.await {
                debug!("Upstream connection error: {}", e);
            }
        });

        self.forward_h1_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip, pool_key).await
    }
|
|
|
|
/// Common H1 forwarding logic used by both fresh and pooled paths.
|
|
async fn forward_h1_with_sender(
|
|
&self,
|
|
mut sender: hyper::client::conn::http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
|
parts: hyper::http::request::Parts,
|
|
body: Incoming,
|
|
upstream_headers: hyper::HeaderMap,
|
|
upstream_path: &str,
|
|
route: &rustproxy_config::RouteConfig,
|
|
route_id: Option<&str>,
|
|
source_ip: &str,
|
|
pool_key: &crate::connection_pool::PoolKey,
|
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
|
// Always use HTTP/1.1 for h1 backend connections (h2 incoming requests have version HTTP/2.0)
|
|
let mut upstream_req = Request::builder()
|
|
.method(parts.method)
|
|
.uri(upstream_path)
|
|
.version(hyper::Version::HTTP_11);
|
|
|
|
if let Some(headers) = upstream_req.headers_mut() {
|
|
*headers = upstream_headers;
|
|
}
|
|
|
|
// Wrap the request body in CountingBody then box it for the uniform pool type
|
|
let counting_req_body = CountingBody::new(
|
|
body,
|
|
Arc::clone(&self.metrics),
|
|
route_id.map(|s| s.to_string()),
|
|
Some(source_ip.to_string()),
|
|
Direction::In,
|
|
);
|
|
let boxed_body: BoxBody<Bytes, hyper::Error> = BoxBody::new(counting_req_body);
|
|
|
|
let upstream_req = upstream_req.body(boxed_body).unwrap();
|
|
|
|
let upstream_response = match sender.send_request(upstream_req).await {
|
|
Ok(resp) => resp,
|
|
Err(e) => {
|
|
error!("Upstream request failed: {}", e);
|
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed"));
|
|
}
|
|
};
|
|
|
|
// Return sender to pool (body streams lazily, sender is reusable once response head is received)
|
|
self.connection_pool.checkin_h1(pool_key.clone(), sender);
|
|
|
|
self.build_streaming_response(upstream_response, route, route_id, source_ip).await
|
|
}
|
|
|
|
    /// Forward request to backend via HTTP/2 with body streaming (fresh connection).
    /// Registers the h2 sender in the pool for future multiplexed requests.
    async fn forward_h2(
        &self,
        io: TokioIo<BackendStream>,
        parts: hyper::http::request::Parts,
        body: Incoming,
        upstream_headers: hyper::HeaderMap,
        upstream_path: &str,
        _upstream: &crate::upstream_selector::UpstreamSelection,
        route: &rustproxy_config::RouteConfig,
        route_id: Option<&str>,
        source_ip: &str,
        pool_key: &crate::connection_pool::PoolKey,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        let exec = hyper_util::rt::TokioExecutor::new();
        // Explicitly type the handshake with BoxBody for uniform pool type
        let (sender, conn): (
            hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
            hyper::client::conn::http2::Connection<TokioIo<BackendStream>, BoxBody<Bytes, hyper::Error>, hyper_util::rt::TokioExecutor>,
        ) = match hyper::client::conn::http2::handshake(exec, io).await {
            Ok(h) => h,
            Err(e) => {
                error!("HTTP/2 upstream handshake failed: {}", e);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed"));
            }
        };

        // Drive the multiplexed connection in the background.
        tokio::spawn(async move {
            if let Err(e) = conn.await {
                debug!("HTTP/2 upstream connection error: {}", e);
            }
        });

        // Register for multiplexed reuse (h2 senders are Clone; many requests
        // can share this one connection concurrently)
        self.connection_pool.register_h2(pool_key.clone(), sender.clone());

        self.forward_h2_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip).await
    }
|
|
|
|
    /// Forward request using an existing (pooled) HTTP/2 sender.
    ///
    /// Thin adapter over [`Self::forward_h2_with_sender`]; kept separate so
    /// the pooled path in `handle_request` mirrors the fresh-connection path.
    async fn forward_h2_pooled(
        &self,
        sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
        parts: hyper::http::request::Parts,
        body: Incoming,
        upstream_headers: hyper::HeaderMap,
        upstream_path: &str,
        route: &rustproxy_config::RouteConfig,
        route_id: Option<&str>,
        source_ip: &str,
        _pool_key: &crate::connection_pool::PoolKey,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        self.forward_h2_with_sender(sender, parts, body, upstream_headers, upstream_path, route, route_id, source_ip).await
    }
|
|
|
|
/// Common H2 forwarding logic used by both fresh and pooled paths.
|
|
async fn forward_h2_with_sender(
|
|
&self,
|
|
mut sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
|
parts: hyper::http::request::Parts,
|
|
body: Incoming,
|
|
upstream_headers: hyper::HeaderMap,
|
|
upstream_path: &str,
|
|
route: &rustproxy_config::RouteConfig,
|
|
route_id: Option<&str>,
|
|
source_ip: &str,
|
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
|
let mut upstream_req = Request::builder()
|
|
.method(parts.method)
|
|
.uri(upstream_path);
|
|
|
|
if let Some(headers) = upstream_req.headers_mut() {
|
|
*headers = upstream_headers;
|
|
}
|
|
|
|
// Wrap the request body in CountingBody then box it for the uniform pool type
|
|
let counting_req_body = CountingBody::new(
|
|
body,
|
|
Arc::clone(&self.metrics),
|
|
route_id.map(|s| s.to_string()),
|
|
Some(source_ip.to_string()),
|
|
Direction::In,
|
|
);
|
|
let boxed_body: BoxBody<Bytes, hyper::Error> = BoxBody::new(counting_req_body);
|
|
|
|
let upstream_req = upstream_req.body(boxed_body).unwrap();
|
|
|
|
let upstream_response = match sender.send_request(upstream_req).await {
|
|
Ok(resp) => resp,
|
|
Err(e) => {
|
|
error!("HTTP/2 upstream request failed: {}", e);
|
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed"));
|
|
}
|
|
};
|
|
|
|
self.build_streaming_response(upstream_response, route, route_id, source_ip).await
|
|
}
|
|
|
|
    /// Build the client-facing response from an upstream response, streaming the body.
    ///
    /// The response body is wrapped in a `CountingBody` that counts bytes as they
    /// stream from upstream to client.
    async fn build_streaming_response(
        &self,
        upstream_response: Response<Incoming>,
        route: &rustproxy_config::RouteConfig,
        route_id: Option<&str>,
        source_ip: &str,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        let (resp_parts, resp_body) = upstream_response.into_parts();

        let mut response = Response::builder()
            .status(resp_parts.status);

        if let Some(headers) = response.headers_mut() {
            // Copy upstream headers wholesale, then apply route-configured
            // response-header overrides.
            *headers = resp_parts.headers;
            ResponseFilter::apply_headers(route, headers, None);
        }

        // Wrap the response body in CountingBody to track bytes_out.
        // CountingBody will report bytes and we close the connection metric
        // after the body stream completes (not before it even starts).
        let counting_body = CountingBody::new(
            resp_body,
            Arc::clone(&self.metrics),
            route_id.map(|s| s.to_string()),
            Some(source_ip.to_string()),
            Direction::Out,
        );

        let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(counting_body);

        // Status and headers both came from an already-parsed upstream
        // response, so the builder should not fail here — presumably safe,
        // but NOTE(review): consider mapping the error instead of unwrap.
        Ok(response.body(body).unwrap())
    }
|
|
|
|
/// Handle a WebSocket upgrade request.
|
|
async fn handle_websocket_upgrade(
|
|
&self,
|
|
req: Request<Incoming>,
|
|
peer_addr: std::net::SocketAddr,
|
|
upstream: &crate::upstream_selector::UpstreamSelection,
|
|
route: &rustproxy_config::RouteConfig,
|
|
route_id: Option<&str>,
|
|
upstream_key: &str,
|
|
cancel: CancellationToken,
|
|
source_ip: &str,
|
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
// Get WebSocket config from route
|
|
let ws_config = route.action.websocket.as_ref();
|
|
|
|
// Check allowed origins if configured
|
|
if let Some(ws) = ws_config {
|
|
if let Some(ref allowed_origins) = ws.allowed_origins {
|
|
let origin = req.headers()
|
|
.get("origin")
|
|
.and_then(|v| v.to_str().ok())
|
|
.unwrap_or("");
|
|
if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) {
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed"));
|
|
}
|
|
}
|
|
}
|
|
|
|
info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);
|
|
|
|
// Connect to upstream with timeout (TLS if upstream.use_tls is set)
|
|
let mut upstream_stream: BackendStream = if upstream.use_tls {
|
|
match tokio::time::timeout(
|
|
self.connect_timeout,
|
|
connect_tls_backend(&self.backend_tls_config, &upstream.host, upstream.port),
|
|
).await {
|
|
Ok(Ok(tls)) => BackendStream::Tls(tls),
|
|
Ok(Err(e)) => {
|
|
error!("WebSocket: failed TLS connect upstream {}:{}: {}", upstream.host, upstream.port, e);
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable"));
|
|
}
|
|
Err(_) => {
|
|
error!("WebSocket: upstream TLS connect timeout for {}:{}", upstream.host, upstream.port);
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout"));
|
|
}
|
|
}
|
|
} else {
|
|
match tokio::time::timeout(
|
|
self.connect_timeout,
|
|
TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)),
|
|
).await {
|
|
Ok(Ok(s)) => {
|
|
s.set_nodelay(true).ok();
|
|
let _ = socket2::SockRef::from(&s).set_tcp_keepalive(
|
|
&socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(60))
|
|
);
|
|
BackendStream::Plain(s)
|
|
}
|
|
Ok(Err(e)) => {
|
|
error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
|
|
}
|
|
Err(_) => {
|
|
error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port);
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
|
|
}
|
|
}
|
|
};
|
|
|
|
let path = req.uri().path().to_string();
|
|
let upstream_path = {
|
|
let raw = match req.uri().query() {
|
|
Some(q) => format!("{}?{}", path, q),
|
|
None => path,
|
|
};
|
|
// Apply rewrite_path if configured
|
|
if let Some(ws) = ws_config {
|
|
if let Some(ref rewrite_path) = ws.rewrite_path {
|
|
rewrite_path.clone()
|
|
} else {
|
|
raw
|
|
}
|
|
} else {
|
|
raw
|
|
}
|
|
};
|
|
|
|
let (parts, _body) = req.into_parts();
|
|
|
|
let mut raw_request = format!(
|
|
"{} {} HTTP/1.1\r\n",
|
|
parts.method, upstream_path
|
|
);
|
|
|
|
// Copy all original headers (preserving the client's Host header).
|
|
// Skip X-Forwarded-* since we set them ourselves below.
|
|
let mut has_host_header = false;
|
|
for (name, value) in parts.headers.iter() {
|
|
let name_str = name.as_str();
|
|
if name_str == "x-forwarded-for"
|
|
|| name_str == "x-forwarded-host"
|
|
|| name_str == "x-forwarded-proto"
|
|
{
|
|
continue;
|
|
}
|
|
if name_str == "host" {
|
|
has_host_header = true;
|
|
}
|
|
raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or("")));
|
|
}
|
|
|
|
// HTTP/2 requests don't have Host header; add one from URI authority for h1 backends
|
|
let ws_host = parts.uri.host().map(|h| h.to_string());
|
|
if !has_host_header {
|
|
if let Some(ref h) = ws_host {
|
|
raw_request.push_str(&format!("host: {}\r\n", h));
|
|
}
|
|
}
|
|
|
|
// Add standard reverse-proxy headers (X-Forwarded-*)
|
|
{
|
|
let original_host = parts.headers.get("host")
|
|
.and_then(|h| h.to_str().ok())
|
|
.or(ws_host.as_deref())
|
|
.unwrap_or("");
|
|
let forwarded_proto = if route.action.tls.as_ref()
|
|
.map(|t| matches!(t.mode,
|
|
rustproxy_config::TlsMode::Terminate
|
|
| rustproxy_config::TlsMode::TerminateAndReencrypt))
|
|
.unwrap_or(false)
|
|
{
|
|
"https"
|
|
} else {
|
|
"http"
|
|
};
|
|
|
|
let client_ip = peer_addr.ip().to_string();
|
|
let xff_value = if let Some(existing) = parts.headers.get("x-forwarded-for") {
|
|
format!("{}, {}", existing.to_str().unwrap_or(""), client_ip)
|
|
} else {
|
|
client_ip
|
|
};
|
|
raw_request.push_str(&format!("x-forwarded-for: {}\r\n", xff_value));
|
|
raw_request.push_str(&format!("x-forwarded-host: {}\r\n", original_host));
|
|
raw_request.push_str(&format!("x-forwarded-proto: {}\r\n", forwarded_proto));
|
|
}
|
|
|
|
if let Some(ref route_headers) = route.headers {
|
|
if let Some(ref request_headers) = route_headers.request {
|
|
for (key, value) in request_headers {
|
|
raw_request.push_str(&format!("{}: {}\r\n", key, value));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Apply WebSocket custom headers
|
|
if let Some(ws) = ws_config {
|
|
if let Some(ref custom_headers) = ws.custom_headers {
|
|
for (key, value) in custom_headers {
|
|
raw_request.push_str(&format!("{}: {}\r\n", key, value));
|
|
}
|
|
}
|
|
}
|
|
|
|
raw_request.push_str("\r\n");
|
|
|
|
if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await {
|
|
error!("WebSocket: failed to send upgrade request to upstream: {}", e);
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed"));
|
|
}
|
|
|
|
let mut response_buf = Vec::with_capacity(4096);
|
|
let mut temp = [0u8; 1];
|
|
loop {
|
|
match upstream_stream.read(&mut temp).await {
|
|
Ok(0) => {
|
|
error!("WebSocket: upstream closed before completing handshake");
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed"));
|
|
}
|
|
Ok(_) => {
|
|
response_buf.push(temp[0]);
|
|
if response_buf.len() >= 4 {
|
|
let len = response_buf.len();
|
|
if response_buf[len-4..] == *b"\r\n\r\n" {
|
|
break;
|
|
}
|
|
}
|
|
if response_buf.len() > 8192 {
|
|
error!("WebSocket: upstream response headers too large");
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large"));
|
|
}
|
|
}
|
|
Err(e) => {
|
|
error!("WebSocket: failed to read upstream response: {}", e);
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed"));
|
|
}
|
|
}
|
|
}
|
|
|
|
let response_str = String::from_utf8_lossy(&response_buf);
|
|
|
|
let status_line = response_str.lines().next().unwrap_or("");
|
|
let status_code = status_line
|
|
.split_whitespace()
|
|
.nth(1)
|
|
.and_then(|s| s.parse::<u16>().ok())
|
|
.unwrap_or(0);
|
|
|
|
if status_code != 101 {
|
|
debug!("WebSocket: upstream rejected upgrade with status {}", status_code);
|
|
self.upstream_selector.connection_ended(upstream_key);
|
|
return Ok(error_response(
|
|
StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY),
|
|
"WebSocket upgrade rejected by backend",
|
|
));
|
|
}
|
|
|
|
let mut client_resp = Response::builder()
|
|
.status(StatusCode::SWITCHING_PROTOCOLS);
|
|
|
|
if let Some(resp_headers) = client_resp.headers_mut() {
|
|
for line in response_str.lines().skip(1) {
|
|
let line = line.trim();
|
|
if line.is_empty() {
|
|
break;
|
|
}
|
|
if let Some((name, value)) = line.split_once(':') {
|
|
let name = name.trim();
|
|
let value = value.trim();
|
|
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
|
|
if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
|
|
resp_headers.insert(header_name, header_value);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let on_client_upgrade = hyper::upgrade::on(
|
|
Request::from_parts(parts, http_body_util::Empty::<Bytes>::new())
|
|
);
|
|
|
|
let metrics = Arc::clone(&self.metrics);
|
|
let route_id_owned = route_id.map(|s| s.to_string());
|
|
let source_ip_owned = source_ip.to_string();
|
|
let upstream_selector = self.upstream_selector.clone();
|
|
let upstream_key_owned = upstream_key.to_string();
|
|
|
|
tokio::spawn(async move {
|
|
let client_upgraded = match on_client_upgrade.await {
|
|
Ok(upgraded) => upgraded,
|
|
Err(e) => {
|
|
debug!("WebSocket: client upgrade failed: {}", e);
|
|
upstream_selector.connection_ended(&upstream_key_owned);
|
|
return;
|
|
}
|
|
};
|
|
|
|
let client_io = TokioIo::new(client_upgraded);
|
|
|
|
let (mut cr, mut cw) = tokio::io::split(client_io);
|
|
let (mut ur, mut uw) = tokio::io::split(upstream_stream);
|
|
|
|
// Shared activity tracker for the watchdog
|
|
let last_activity = Arc::new(AtomicU64::new(0));
|
|
let start = std::time::Instant::now();
|
|
|
|
let la1 = Arc::clone(&last_activity);
|
|
let c2u = tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 65536];
|
|
let mut total = 0u64;
|
|
loop {
|
|
let n = match cr.read(&mut buf).await {
|
|
Ok(0) | Err(_) => break,
|
|
Ok(n) => n,
|
|
};
|
|
if uw.write_all(&buf[..n]).await.is_err() {
|
|
break;
|
|
}
|
|
total += n as u64;
|
|
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
|
}
|
|
let _ = uw.shutdown().await;
|
|
total
|
|
});
|
|
|
|
let la2 = Arc::clone(&last_activity);
|
|
let u2c = tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 65536];
|
|
let mut total = 0u64;
|
|
loop {
|
|
let n = match ur.read(&mut buf).await {
|
|
Ok(0) | Err(_) => break,
|
|
Ok(n) => n,
|
|
};
|
|
if cw.write_all(&buf[..n]).await.is_err() {
|
|
break;
|
|
}
|
|
total += n as u64;
|
|
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
|
}
|
|
let _ = cw.shutdown().await;
|
|
total
|
|
});
|
|
|
|
// Watchdog: monitors inactivity, max lifetime, and cancellation
|
|
let la_watch = Arc::clone(&last_activity);
|
|
let c2u_handle = c2u.abort_handle();
|
|
let u2c_handle = u2c.abort_handle();
|
|
let inactivity_timeout = DEFAULT_WS_INACTIVITY_TIMEOUT;
|
|
let max_lifetime = DEFAULT_WS_MAX_LIFETIME;
|
|
|
|
let watchdog = tokio::spawn(async move {
|
|
let check_interval = std::time::Duration::from_secs(5);
|
|
let mut last_seen = 0u64;
|
|
loop {
|
|
tokio::select! {
|
|
_ = tokio::time::sleep(check_interval) => {}
|
|
_ = cancel.cancelled() => {
|
|
debug!("WebSocket tunnel cancelled by shutdown");
|
|
c2u_handle.abort();
|
|
u2c_handle.abort();
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Check max lifetime
|
|
if start.elapsed() >= max_lifetime {
|
|
debug!("WebSocket tunnel exceeded max lifetime, closing");
|
|
c2u_handle.abort();
|
|
u2c_handle.abort();
|
|
break;
|
|
}
|
|
|
|
// Check inactivity
|
|
let current = la_watch.load(Ordering::Relaxed);
|
|
if current == last_seen {
|
|
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
|
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
|
debug!("WebSocket tunnel inactive for {}ms, closing", elapsed_since_activity);
|
|
c2u_handle.abort();
|
|
u2c_handle.abort();
|
|
break;
|
|
}
|
|
}
|
|
last_seen = current;
|
|
}
|
|
});
|
|
|
|
let bytes_in = c2u.await.unwrap_or(0);
|
|
let bytes_out = u2c.await.unwrap_or(0);
|
|
watchdog.abort();
|
|
|
|
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
|
|
|
|
upstream_selector.connection_ended(&upstream_key_owned);
|
|
if let Some(ref rid) = route_id_owned {
|
|
metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()), Some(&source_ip_owned));
|
|
}
|
|
});
|
|
|
|
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
|
|
http_body_util::Empty::<Bytes>::new().map_err(|never| match never {})
|
|
);
|
|
Ok(client_resp.body(body).unwrap())
|
|
}
|
|
|
|
/// Build a test response from config (no upstream connection needed).
|
|
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response<BoxBody<Bytes, hyper::Error>> {
|
|
let mut response = Response::builder()
|
|
.status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK));
|
|
|
|
if let Some(headers) = response.headers_mut() {
|
|
for (key, value) in &config.headers {
|
|
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
|
|
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
|
|
headers.insert(name, val);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let body = Full::new(Bytes::from(config.body.clone()))
|
|
.map_err(|never| match never {});
|
|
response.body(BoxBody::new(body)).unwrap()
|
|
}
|
|
|
|
/// Apply URL rewriting rules from route config, using the compiled regex cache.
|
|
fn apply_url_rewrite(&self, path: &str, route: &rustproxy_config::RouteConfig) -> String {
|
|
let rewrite = match route.action.advanced.as_ref()
|
|
.and_then(|a| a.url_rewrite.as_ref())
|
|
{
|
|
Some(r) => r,
|
|
None => return path.to_string(),
|
|
};
|
|
|
|
// Determine what to rewrite
|
|
let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) {
|
|
// Only rewrite the path portion (before ?)
|
|
match path.split_once('?') {
|
|
Some((p, q)) => (p.to_string(), format!("?{}", q)),
|
|
None => (path.to_string(), String::new()),
|
|
}
|
|
} else {
|
|
(path.to_string(), String::new())
|
|
};
|
|
|
|
// Look up or compile the regex, caching for future requests
|
|
let cached = self.regex_cache.get(&rewrite.pattern);
|
|
if let Some(re) = cached {
|
|
let result = re.replace_all(&subject, rewrite.target.as_str());
|
|
return format!("{}{}", result, suffix);
|
|
}
|
|
|
|
// Not cached — compile and insert
|
|
match Regex::new(&rewrite.pattern) {
|
|
Ok(re) => {
|
|
let result = re.replace_all(&subject, rewrite.target.as_str());
|
|
let out = format!("{}{}", result, suffix);
|
|
self.regex_cache.insert(rewrite.pattern.clone(), re);
|
|
out
|
|
}
|
|
Err(e) => {
|
|
warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e);
|
|
path.to_string()
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Serve a static file from the configured directory.
///
/// Resolves the request `path` against `config.root`, falling back to the
/// configured index files for directory requests, and returns the file
/// contents with a guessed `Content-Type`. Paths that escape the root after
/// canonicalization get 403; missing or unreadable paths get 404.
fn serve_static_file(
    path: &str,
    config: &rustproxy_config::RouteStaticFiles,
) -> Response<BoxBody<Bytes, hyper::Error>> {
    use std::path::Path;

    let root = Path::new(&config.root);

    // Sanitize path to prevent directory traversal. This is best-effort
    // string scrubbing only; the canonicalize + starts_with check further
    // down is the authoritative guard.
    let clean_path = path.trim_start_matches('/');
    let clean_path = clean_path.replace("..", "");

    let mut file_path = root.join(&clean_path);

    // If path points to a directory (or is the root URL), try index files.
    // `index_files` takes precedence over the legacy `index` field; when
    // neither yields entries, "index.html" is the default.
    if file_path.is_dir() || clean_path.is_empty() {
        let index_files = config.index_files.as_deref()
            .or(config.index.as_deref())
            .unwrap_or(&[]);
        let default_index = vec!["index.html".to_string()];
        let index_files = if index_files.is_empty() { &default_index } else { index_files };

        let mut found = false;
        for index in index_files {
            // For the root URL the directory is `root` itself, not a join
            // with the empty clean_path.
            let candidate = if clean_path.is_empty() {
                root.join(index)
            } else {
                file_path.join(index)
            };
            if candidate.is_file() {
                file_path = candidate;
                found = true;
                break;
            }
        }
        if !found {
            return error_response(StatusCode::NOT_FOUND, "Not found");
        }
    }

    // Ensure the resolved path is within the root (prevent traversal).
    // canonicalize() also proves the path exists; failures map to 404 so
    // callers can't distinguish "missing" from "outside root but missing".
    let canonical_root = match root.canonicalize() {
        Ok(p) => p,
        Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
    };
    let canonical_file = match file_path.canonicalize() {
        Ok(p) => p,
        Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
    };
    if !canonical_file.starts_with(&canonical_root) {
        return error_response(StatusCode::FORBIDDEN, "Forbidden");
    }

    // Check if symlinks are allowed: a canonical path that differs from the
    // joined path implies a symlink somewhere in the chain.
    // NOTE(review): the paths also differ whenever `config.root` is relative
    // (canonicalize always returns an absolute path), which would 403
    // legitimate files — confirm roots are always stored absolute.
    if config.follow_symlinks == Some(false) && canonical_file != file_path {
        return error_response(StatusCode::FORBIDDEN, "Forbidden");
    }

    // Read the file into memory in one shot.
    // NOTE(review): this is a blocking read, presumably invoked from an async
    // context — fine for small assets, but confirm large files are not served
    // through this path.
    match std::fs::read(&file_path) {
        Ok(content) => {
            let content_type = guess_content_type(&file_path);
            let mut response = Response::builder()
                .status(StatusCode::OK)
                .header("Content-Type", content_type);

            // Apply cache-control if configured
            if let Some(ref cache_control) = config.cache_control {
                response = response.header("Cache-Control", cache_control.as_str());
            }

            // Apply custom headers
            if let Some(ref headers) = config.headers {
                for (key, value) in headers {
                    response = response.header(key.as_str(), value.as_str());
                }
            }

            let body = Full::new(Bytes::from(content))
                .map_err(|never| match never {});
            response.body(BoxBody::new(body)).unwrap()
        }
        Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"),
    }
}
|
|
}
|
|
|
|
/// Guess MIME content type from file extension.
|
|
fn guess_content_type(path: &std::path::Path) -> &'static str {
|
|
match path.extension().and_then(|e| e.to_str()) {
|
|
Some("html") | Some("htm") => "text/html; charset=utf-8",
|
|
Some("css") => "text/css; charset=utf-8",
|
|
Some("js") | Some("mjs") => "application/javascript; charset=utf-8",
|
|
Some("json") => "application/json; charset=utf-8",
|
|
Some("xml") => "application/xml; charset=utf-8",
|
|
Some("txt") => "text/plain; charset=utf-8",
|
|
Some("png") => "image/png",
|
|
Some("jpg") | Some("jpeg") => "image/jpeg",
|
|
Some("gif") => "image/gif",
|
|
Some("svg") => "image/svg+xml",
|
|
Some("ico") => "image/x-icon",
|
|
Some("woff") => "font/woff",
|
|
Some("woff2") => "font/woff2",
|
|
Some("ttf") => "font/ttf",
|
|
Some("pdf") => "application/pdf",
|
|
Some("wasm") => "application/wasm",
|
|
_ => "application/octet-stream",
|
|
}
|
|
}
|
|
|
|
impl HttpProxyService {
|
|
/// Build a default backend TLS config with InsecureVerifier.
|
|
/// Used as fallback when no shared config is injected from tls_handler.
|
|
fn default_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
|
let config = rustls::ClientConfig::builder()
|
|
.dangerous()
|
|
.with_custom_certificate_verifier(Arc::new(InsecureBackendVerifier))
|
|
.with_no_client_auth();
|
|
Arc::new(config)
|
|
}
|
|
}
|
|
|
|
/// Insecure certificate verifier for backend TLS connections (fallback only).
/// The production path uses the shared config from tls_handler which has the same
/// behavior but with session resumption across all outbound connections.
///
/// SECURITY: every `verify_*` method of this type unconditionally reports
/// success, so backend certificates and handshake signatures are never
/// actually validated when this fallback is in use. It must never be wired
/// into client-facing TLS.
#[derive(Debug)]
struct InsecureBackendVerifier;
|
|
|
|
impl rustls::client::danger::ServerCertVerifier for InsecureBackendVerifier {
    /// Accept any server certificate without inspecting the chain, server
    /// name, OCSP response, or validity period.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accept any TLS 1.2 handshake signature without verifying it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accept any TLS 1.3 handshake signature without verifying it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Advertise a broad set of signature schemes so handshakes with common
    /// RSA / ECDSA / Ed25519 backends are not rejected for scheme mismatch.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
        ]
    }
}
|
|
|
|
impl Default for HttpProxyService {
    /// Construct a proxy service with no routes and all-default collaborators:
    /// fresh metrics collector, upstream selector and connection pool, the
    /// default 30-second connect timeout, empty rate-limiter and regex caches,
    /// and the insecure fallback TLS config for backend connections.
    fn default() -> Self {
        Self {
            // No routes configured initially; RouteManager starts empty.
            route_manager: Arc::new(RouteManager::new(vec![])),
            metrics: Arc::new(MetricsCollector::new()),
            upstream_selector: UpstreamSelector::new(),
            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
            // Per-route rate limiters, keyed in a concurrent map.
            route_rate_limiters: Arc::new(DashMap::new()),
            request_counter: AtomicU64::new(0),
            // Compiled URL-rewrite regexes keyed by pattern string
            // (populated by apply_url_rewrite).
            regex_cache: DashMap::new(),
            // Fallback config with certificate verification disabled; the
            // production path injects a shared config from tls_handler instead.
            backend_tls_config: Self::default_backend_tls_config(),
            connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()),
        }
    }
}
|
|
|
|
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
|
let body = Full::new(Bytes::from(message.to_string()))
|
|
.map_err(|never| match never {});
|
|
Response::builder()
|
|
.status(status)
|
|
.header("Content-Type", "text/plain")
|
|
.body(BoxBody::new(body))
|
|
.unwrap()
|
|
}
|