feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates
This commit is contained in:
@@ -0,0 +1,827 @@
|
||||
//! Hyper-based HTTP proxy service.
|
||||
//!
|
||||
//! Accepts decrypted TCP streams (from TLS termination or plain TCP),
|
||||
//! parses HTTP requests, matches routes, and forwards to upstream backends.
|
||||
//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt, Full, combinators::BoxBody};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use regex::Regex;
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
|
||||
use crate::request_filter::RequestFilter;
|
||||
use crate::response_filter::ResponseFilter;
|
||||
use crate::upstream_selector::UpstreamSelector;
|
||||
|
||||
/// HTTP proxy service that processes HTTP traffic.
|
||||
pub struct HttpProxyService {
|
||||
route_manager: Arc<RouteManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
upstream_selector: UpstreamSelector,
|
||||
}
|
||||
|
||||
impl HttpProxyService {
|
||||
pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
|
||||
Self {
|
||||
route_manager,
|
||||
metrics,
|
||||
upstream_selector: UpstreamSelector::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an incoming HTTP connection on a plain TCP stream.
|
||||
pub async fn handle_connection(
|
||||
self: Arc<Self>,
|
||||
stream: TcpStream,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
) {
|
||||
self.handle_io(stream, peer_addr, port).await;
|
||||
}
|
||||
|
||||
/// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated).
|
||||
///
|
||||
/// Uses HTTP/1.1 with upgrade support. For clients that negotiate HTTP/2,
|
||||
/// use `handle_io_auto` instead.
|
||||
pub async fn handle_io<I>(
|
||||
self: Arc<Self>,
|
||||
stream: I,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
)
|
||||
where
|
||||
I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
|
||||
let svc = Arc::clone(&self);
|
||||
let peer = peer_addr;
|
||||
async move {
|
||||
svc.handle_request(req, peer, port).await
|
||||
}
|
||||
});
|
||||
|
||||
// Use http1::Builder with upgrades for WebSocket support
|
||||
let conn = hyper::server::conn::http1::Builder::new()
|
||||
.keep_alive(true)
|
||||
.serve_connection(io, service)
|
||||
.with_upgrades();
|
||||
|
||||
if let Err(e) = conn.await {
|
||||
debug!("HTTP connection error from {}: {}", peer_addr, e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single HTTP request.
|
||||
async fn handle_request(
|
||||
&self,
|
||||
req: Request<Incoming>,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let host = req.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|h| {
|
||||
// Strip port from host header
|
||||
h.split(':').next().unwrap_or(h).to_string()
|
||||
});
|
||||
|
||||
let path = req.uri().path().to_string();
|
||||
let method = req.method().clone();
|
||||
|
||||
// Extract headers for matching
|
||||
let headers: HashMap<String, String> = req.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr);
|
||||
|
||||
// Check for CORS preflight
|
||||
if method == hyper::Method::OPTIONS {
|
||||
if let Some(response) = RequestFilter::handle_cors_preflight(&req) {
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Match route
|
||||
let ctx = rustproxy_routing::MatchContext {
|
||||
port,
|
||||
domain: host.as_deref(),
|
||||
path: Some(&path),
|
||||
client_ip: Some(&peer_addr.ip().to_string()),
|
||||
tls_version: None,
|
||||
headers: Some(&headers),
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let route_match = match self.route_manager.find_route(&ctx) {
|
||||
Some(rm) => rm,
|
||||
None => {
|
||||
debug!("No route matched for HTTP request to {:?}{}", host, path);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched"));
|
||||
}
|
||||
};
|
||||
|
||||
let route_id = route_match.route.id.as_deref();
|
||||
self.metrics.connection_opened(route_id);
|
||||
|
||||
// Apply request filters (IP check, rate limiting, auth)
|
||||
if let Some(ref security) = route_match.route.security {
|
||||
if let Some(response) = RequestFilter::apply(security, &req, &peer_addr) {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for test response (returns immediately, no upstream needed)
|
||||
if let Some(ref advanced) = route_match.route.action.advanced {
|
||||
if let Some(ref test_response) = advanced.test_response {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(Self::build_test_response(test_response));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for static file serving
|
||||
if let Some(ref advanced) = route_match.route.action.advanced {
|
||||
if let Some(ref static_files) = advanced.static_files {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(Self::serve_static_file(&path, static_files));
|
||||
}
|
||||
}
|
||||
|
||||
// Select upstream
|
||||
let target = match route_match.target {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available"));
|
||||
}
|
||||
};
|
||||
|
||||
let upstream = self.upstream_selector.select(target, &peer_addr, port);
|
||||
let upstream_key = format!("{}:{}", upstream.host, upstream.port);
|
||||
self.upstream_selector.connection_started(&upstream_key);
|
||||
|
||||
// Check for WebSocket upgrade
|
||||
let is_websocket = req.headers()
|
||||
.get("upgrade")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.eq_ignore_ascii_case("websocket"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_websocket {
|
||||
let result = self.handle_websocket_upgrade(
|
||||
req, peer_addr, &upstream, route_match.route, route_id, &upstream_key,
|
||||
).await;
|
||||
// Note: for WebSocket, connection_ended is called inside
|
||||
// the spawned tunnel task when the connection closes.
|
||||
return result;
|
||||
}
|
||||
|
||||
// Determine backend protocol
|
||||
let use_h2 = route_match.route.action.options.as_ref()
|
||||
.and_then(|o| o.backend_protocol.as_ref())
|
||||
.map(|p| *p == rustproxy_config::BackendProtocol::Http2)
|
||||
.unwrap_or(false);
|
||||
|
||||
// Build the upstream path (path + query), applying URL rewriting if configured
|
||||
let upstream_path = {
|
||||
let raw_path = match req.uri().query() {
|
||||
Some(q) => format!("{}?{}", path, q),
|
||||
None => path.clone(),
|
||||
};
|
||||
Self::apply_url_rewrite(&raw_path, &route_match.route)
|
||||
};
|
||||
|
||||
// Build upstream request - stream body instead of buffering
|
||||
let (parts, body) = req.into_parts();
|
||||
|
||||
// Apply request headers from route config
|
||||
let mut upstream_headers = parts.headers.clone();
|
||||
if let Some(ref route_headers) = route_match.route.headers {
|
||||
if let Some(ref request_headers) = route_headers.request {
|
||||
for (key, value) in request_headers {
|
||||
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
|
||||
upstream_headers.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to upstream
|
||||
let upstream_stream = match TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
|
||||
self.upstream_selector.connection_ended(&upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
|
||||
}
|
||||
};
|
||||
upstream_stream.set_nodelay(true).ok();
|
||||
|
||||
let io = TokioIo::new(upstream_stream);
|
||||
|
||||
let result = if use_h2 {
|
||||
// HTTP/2 backend
|
||||
self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
|
||||
} else {
|
||||
// HTTP/1.1 backend (default)
|
||||
self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
|
||||
};
|
||||
self.upstream_selector.connection_ended(&upstream_key);
|
||||
result
|
||||
}
|
||||
|
||||
/// Forward request to backend via HTTP/1.1 with body streaming.
|
||||
async fn forward_h1(
|
||||
&self,
|
||||
io: TokioIo<TcpStream>,
|
||||
parts: hyper::http::request::Parts,
|
||||
body: Incoming,
|
||||
upstream_headers: hyper::HeaderMap,
|
||||
upstream_path: &str,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
error!("Upstream handshake failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed"));
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
debug!("Upstream connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_path)
|
||||
.version(parts.version);
|
||||
|
||||
if let Some(headers) = upstream_req.headers_mut() {
|
||||
*headers = upstream_headers;
|
||||
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
|
||||
&format!("{}:{}", upstream.host, upstream.port)
|
||||
) {
|
||||
headers.insert(hyper::header::HOST, host_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the request body through to upstream
|
||||
let upstream_req = upstream_req.body(body).unwrap();
|
||||
|
||||
let upstream_response = match sender.send_request(upstream_req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
error!("Upstream request failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed"));
|
||||
}
|
||||
};
|
||||
|
||||
self.build_streaming_response(upstream_response, route, route_id).await
|
||||
}
|
||||
|
||||
/// Forward request to backend via HTTP/2 with body streaming.
|
||||
async fn forward_h2(
|
||||
&self,
|
||||
io: TokioIo<TcpStream>,
|
||||
parts: hyper::http::request::Parts,
|
||||
body: Incoming,
|
||||
upstream_headers: hyper::HeaderMap,
|
||||
upstream_path: &str,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let exec = hyper_util::rt::TokioExecutor::new();
|
||||
let (mut sender, conn) = match hyper::client::conn::http2::handshake(exec, io).await {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
error!("HTTP/2 upstream handshake failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed"));
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
debug!("HTTP/2 upstream connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_path);
|
||||
|
||||
if let Some(headers) = upstream_req.headers_mut() {
|
||||
*headers = upstream_headers;
|
||||
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
|
||||
&format!("{}:{}", upstream.host, upstream.port)
|
||||
) {
|
||||
headers.insert(hyper::header::HOST, host_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the request body through to upstream
|
||||
let upstream_req = upstream_req.body(body).unwrap();
|
||||
|
||||
let upstream_response = match sender.send_request(upstream_req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
error!("HTTP/2 upstream request failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed"));
|
||||
}
|
||||
};
|
||||
|
||||
self.build_streaming_response(upstream_response, route, route_id).await
|
||||
}
|
||||
|
||||
/// Build the client-facing response from an upstream response, streaming the body.
|
||||
async fn build_streaming_response(
|
||||
&self,
|
||||
upstream_response: Response<Incoming>,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let (resp_parts, resp_body) = upstream_response.into_parts();
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(resp_parts.status);
|
||||
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
*headers = resp_parts.headers;
|
||||
ResponseFilter::apply_headers(route, headers, None);
|
||||
}
|
||||
|
||||
self.metrics.connection_closed(route_id);
|
||||
|
||||
// Stream the response body directly from upstream to client
|
||||
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(resp_body);
|
||||
|
||||
Ok(response.body(body).unwrap())
|
||||
}
|
||||
|
||||
/// Handle a WebSocket upgrade request.
|
||||
async fn handle_websocket_upgrade(
|
||||
&self,
|
||||
req: Request<Incoming>,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
upstream_key: &str,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// Get WebSocket config from route
|
||||
let ws_config = route.action.websocket.as_ref();
|
||||
|
||||
// Check allowed origins if configured
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref allowed_origins) = ws.allowed_origins {
|
||||
let origin = req.headers()
|
||||
.get("origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) {
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);
|
||||
|
||||
let mut upstream_stream = match TcpStream::connect(
|
||||
format!("{}:{}", upstream.host, upstream.port)
|
||||
).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
|
||||
}
|
||||
};
|
||||
upstream_stream.set_nodelay(true).ok();
|
||||
|
||||
let path = req.uri().path().to_string();
|
||||
let upstream_path = {
|
||||
let raw = match req.uri().query() {
|
||||
Some(q) => format!("{}?{}", path, q),
|
||||
None => path,
|
||||
};
|
||||
// Apply rewrite_path if configured
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref rewrite_path) = ws.rewrite_path {
|
||||
rewrite_path.clone()
|
||||
} else {
|
||||
raw
|
||||
}
|
||||
} else {
|
||||
raw
|
||||
}
|
||||
};
|
||||
|
||||
let (parts, _body) = req.into_parts();
|
||||
|
||||
let mut raw_request = format!(
|
||||
"{} {} HTTP/1.1\r\n",
|
||||
parts.method, upstream_path
|
||||
);
|
||||
|
||||
let upstream_host = format!("{}:{}", upstream.host, upstream.port);
|
||||
for (name, value) in parts.headers.iter() {
|
||||
if name == hyper::header::HOST {
|
||||
raw_request.push_str(&format!("host: {}\r\n", upstream_host));
|
||||
} else {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or("")));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref route_headers) = route.headers {
|
||||
if let Some(ref request_headers) = route_headers.request {
|
||||
for (key, value) in request_headers {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", key, value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply WebSocket custom headers
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref custom_headers) = ws.custom_headers {
|
||||
for (key, value) in custom_headers {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", key, value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
raw_request.push_str("\r\n");
|
||||
|
||||
if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await {
|
||||
error!("WebSocket: failed to send upgrade request to upstream: {}", e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed"));
|
||||
}
|
||||
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
match upstream_stream.read(&mut temp).await {
|
||||
Ok(0) => {
|
||||
error!("WebSocket: upstream closed before completing handshake");
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed"));
|
||||
}
|
||||
Ok(_) => {
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len-4..] == *b"\r\n\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if response_buf.len() > 8192 {
|
||||
error!("WebSocket: upstream response headers too large");
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large"));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("WebSocket: failed to read upstream response: {}", e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response_buf);
|
||||
|
||||
let status_line = response_str.lines().next().unwrap_or("");
|
||||
let status_code = status_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.and_then(|s| s.parse::<u16>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
if status_code != 101 {
|
||||
debug!("WebSocket: upstream rejected upgrade with status {}", status_code);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(
|
||||
StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY),
|
||||
"WebSocket upgrade rejected by backend",
|
||||
));
|
||||
}
|
||||
|
||||
let mut client_resp = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS);
|
||||
|
||||
if let Some(resp_headers) = client_resp.headers_mut() {
|
||||
for line in response_str.lines().skip(1) {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
break;
|
||||
}
|
||||
if let Some((name, value)) = line.split_once(':') {
|
||||
let name = name.trim();
|
||||
let value = value.trim();
|
||||
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
|
||||
if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
|
||||
resp_headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let on_client_upgrade = hyper::upgrade::on(
|
||||
Request::from_parts(parts, http_body_util::Empty::<Bytes>::new())
|
||||
);
|
||||
|
||||
let metrics = Arc::clone(&self.metrics);
|
||||
let route_id_owned = route_id.map(|s| s.to_string());
|
||||
let upstream_selector = self.upstream_selector.clone();
|
||||
let upstream_key_owned = upstream_key.to_string();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let client_upgraded = match on_client_upgrade.await {
|
||||
Ok(upgraded) => upgraded,
|
||||
Err(e) => {
|
||||
debug!("WebSocket: client upgrade failed: {}", e);
|
||||
upstream_selector.connection_ended(&upstream_key_owned);
|
||||
if let Some(ref rid) = route_id_owned {
|
||||
metrics.connection_closed(Some(rid.as_str()));
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let client_io = TokioIo::new(client_upgraded);
|
||||
|
||||
let (mut cr, mut cw) = tokio::io::split(client_io);
|
||||
let (mut ur, mut uw) = tokio::io::split(upstream_stream);
|
||||
|
||||
let c2u = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match cr.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if uw.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
}
|
||||
let _ = uw.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let u2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match ur.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if cw.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
}
|
||||
let _ = cw.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let bytes_in = c2u.await.unwrap_or(0);
|
||||
let bytes_out = u2c.await.unwrap_or(0);
|
||||
|
||||
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
|
||||
|
||||
upstream_selector.connection_ended(&upstream_key_owned);
|
||||
if let Some(ref rid) = route_id_owned {
|
||||
metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()));
|
||||
metrics.connection_closed(Some(rid.as_str()));
|
||||
}
|
||||
});
|
||||
|
||||
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
|
||||
http_body_util::Empty::<Bytes>::new().map_err(|never| match never {})
|
||||
);
|
||||
Ok(client_resp.body(body).unwrap())
|
||||
}
|
||||
|
||||
/// Build a test response from config (no upstream connection needed).
|
||||
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK));
|
||||
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
for (key, value) in &config.headers {
|
||||
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
|
||||
headers.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let body = Full::new(Bytes::from(config.body.clone()))
|
||||
.map_err(|never| match never {});
|
||||
response.body(BoxBody::new(body)).unwrap()
|
||||
}
|
||||
|
||||
/// Apply URL rewriting rules from route config.
|
||||
fn apply_url_rewrite(path: &str, route: &rustproxy_config::RouteConfig) -> String {
|
||||
let rewrite = match route.action.advanced.as_ref()
|
||||
.and_then(|a| a.url_rewrite.as_ref())
|
||||
{
|
||||
Some(r) => r,
|
||||
None => return path.to_string(),
|
||||
};
|
||||
|
||||
// Determine what to rewrite
|
||||
let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) {
|
||||
// Only rewrite the path portion (before ?)
|
||||
match path.split_once('?') {
|
||||
Some((p, q)) => (p.to_string(), format!("?{}", q)),
|
||||
None => (path.to_string(), String::new()),
|
||||
}
|
||||
} else {
|
||||
(path.to_string(), String::new())
|
||||
};
|
||||
|
||||
match Regex::new(&rewrite.pattern) {
|
||||
Ok(re) => {
|
||||
let result = re.replace_all(&subject, rewrite.target.as_str());
|
||||
format!("{}{}", result, suffix)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e);
|
||||
path.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serve a static file from the configured directory.
|
||||
fn serve_static_file(
|
||||
path: &str,
|
||||
config: &rustproxy_config::RouteStaticFiles,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
use std::path::Path;
|
||||
|
||||
let root = Path::new(&config.root);
|
||||
|
||||
// Sanitize path to prevent directory traversal
|
||||
let clean_path = path.trim_start_matches('/');
|
||||
let clean_path = clean_path.replace("..", "");
|
||||
|
||||
let mut file_path = root.join(&clean_path);
|
||||
|
||||
// If path points to a directory, try index files
|
||||
if file_path.is_dir() || clean_path.is_empty() {
|
||||
let index_files = config.index_files.as_deref()
|
||||
.or(config.index.as_deref())
|
||||
.unwrap_or(&[]);
|
||||
let default_index = vec!["index.html".to_string()];
|
||||
let index_files = if index_files.is_empty() { &default_index } else { index_files };
|
||||
|
||||
let mut found = false;
|
||||
for index in index_files {
|
||||
let candidate = if clean_path.is_empty() {
|
||||
root.join(index)
|
||||
} else {
|
||||
file_path.join(index)
|
||||
};
|
||||
if candidate.is_file() {
|
||||
file_path = candidate;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return error_response(StatusCode::NOT_FOUND, "Not found");
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the resolved path is within the root (prevent traversal)
|
||||
let canonical_root = match root.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
};
|
||||
let canonical_file = match file_path.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
};
|
||||
if !canonical_file.starts_with(&canonical_root) {
|
||||
return error_response(StatusCode::FORBIDDEN, "Forbidden");
|
||||
}
|
||||
|
||||
// Check if symlinks are allowed
|
||||
if config.follow_symlinks == Some(false) && canonical_file != file_path {
|
||||
return error_response(StatusCode::FORBIDDEN, "Forbidden");
|
||||
}
|
||||
|
||||
// Read the file
|
||||
match std::fs::read(&file_path) {
|
||||
Ok(content) => {
|
||||
let content_type = guess_content_type(&file_path);
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", content_type);
|
||||
|
||||
// Apply cache-control if configured
|
||||
if let Some(ref cache_control) = config.cache_control {
|
||||
response = response.header("Cache-Control", cache_control.as_str());
|
||||
}
|
||||
|
||||
// Apply custom headers
|
||||
if let Some(ref headers) = config.headers {
|
||||
for (key, value) in headers {
|
||||
response = response.header(key.as_str(), value.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
let body = Full::new(Bytes::from(content))
|
||||
.map_err(|never| match never {});
|
||||
response.body(BoxBody::new(body)).unwrap()
|
||||
}
|
||||
Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Guess MIME content type from file extension.
|
||||
fn guess_content_type(path: &std::path::Path) -> &'static str {
|
||||
match path.extension().and_then(|e| e.to_str()) {
|
||||
Some("html") | Some("htm") => "text/html; charset=utf-8",
|
||||
Some("css") => "text/css; charset=utf-8",
|
||||
Some("js") | Some("mjs") => "application/javascript; charset=utf-8",
|
||||
Some("json") => "application/json; charset=utf-8",
|
||||
Some("xml") => "application/xml; charset=utf-8",
|
||||
Some("txt") => "text/plain; charset=utf-8",
|
||||
Some("png") => "image/png",
|
||||
Some("jpg") | Some("jpeg") => "image/jpeg",
|
||||
Some("gif") => "image/gif",
|
||||
Some("svg") => "image/svg+xml",
|
||||
Some("ico") => "image/x-icon",
|
||||
Some("woff") => "font/woff",
|
||||
Some("woff2") => "font/woff2",
|
||||
Some("ttf") => "font/ttf",
|
||||
Some("pdf") => "application/pdf",
|
||||
Some("wasm") => "application/wasm",
|
||||
_ => "application/octet-stream",
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HttpProxyService {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
route_manager: Arc::new(RouteManager::new(vec![])),
|
||||
metrics: Arc::new(MetricsCollector::new()),
|
||||
upstream_selector: UpstreamSelector::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let body = Full::new(Bytes::from(message.to_string()))
|
||||
.map_err(|never| match never {});
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(BoxBody::new(body))
|
||||
.unwrap()
|
||||
}
|
||||
Reference in New Issue
Block a user