feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates

This commit is contained in:
2026-02-09 10:55:46 +00:00
parent a31fee41df
commit 1df3b7af4a
151 changed files with 16927 additions and 19432 deletions

View File

@@ -0,0 +1,24 @@
[package]
name = "rustproxy-http"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Hyper-based HTTP proxy service for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-security = { workspace = true }
rustproxy-metrics = { workspace = true }
hyper = { workspace = true }
hyper-util = { workspace = true }
regex = { workspace = true }
http-body-util = { workspace = true }
bytes = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
arc-swap = { workspace = true }
dashmap = { workspace = true }

View File

@@ -0,0 +1,14 @@
//! # rustproxy-http
//!
//! Hyper-based HTTP proxy service for RustProxy.
//! Handles HTTP request parsing, route-based forwarding, and response filtering.
pub mod proxy_service;
pub mod request_filter;
pub mod response_filter;
pub mod template;
pub mod upstream_selector;
pub use proxy_service::*;
pub use template::*;
pub use upstream_selector::*;

View File

@@ -0,0 +1,827 @@
//! Hyper-based HTTP proxy service.
//!
//! Accepts decrypted TCP streams (from TLS termination or plain TCP),
//! parses HTTP requests, matches routes, and forwards to upstream backends.
//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade.
use std::collections::HashMap;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use regex::Regex;
use tokio::net::TcpStream;
use tracing::{debug, error, info, warn};
use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;
use crate::request_filter::RequestFilter;
use crate::response_filter::ResponseFilter;
use crate::upstream_selector::UpstreamSelector;
/// HTTP proxy service that processes HTTP traffic.
pub struct HttpProxyService {
route_manager: Arc<RouteManager>,
metrics: Arc<MetricsCollector>,
upstream_selector: UpstreamSelector,
}
impl HttpProxyService {
pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
Self {
route_manager,
metrics,
upstream_selector: UpstreamSelector::new(),
}
}
/// Handle an incoming HTTP connection on a plain TCP stream.
pub async fn handle_connection(
self: Arc<Self>,
stream: TcpStream,
peer_addr: std::net::SocketAddr,
port: u16,
) {
self.handle_io(stream, peer_addr, port).await;
}
/// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated).
///
/// Uses HTTP/1.1 with upgrade support. For clients that negotiate HTTP/2,
/// use `handle_io_auto` instead.
pub async fn handle_io<I>(
self: Arc<Self>,
stream: I,
peer_addr: std::net::SocketAddr,
port: u16,
)
where
I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
let io = TokioIo::new(stream);
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
let svc = Arc::clone(&self);
let peer = peer_addr;
async move {
svc.handle_request(req, peer, port).await
}
});
// Use http1::Builder with upgrades for WebSocket support
let conn = hyper::server::conn::http1::Builder::new()
.keep_alive(true)
.serve_connection(io, service)
.with_upgrades();
if let Err(e) = conn.await {
debug!("HTTP connection error from {}: {}", peer_addr, e);
}
}
/// Handle a single HTTP request.
async fn handle_request(
&self,
req: Request<Incoming>,
peer_addr: std::net::SocketAddr,
port: u16,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let host = req.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.map(|h| {
// Strip port from host header
h.split(':').next().unwrap_or(h).to_string()
});
let path = req.uri().path().to_string();
let method = req.method().clone();
// Extract headers for matching
let headers: HashMap<String, String> = req.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr);
// Check for CORS preflight
if method == hyper::Method::OPTIONS {
if let Some(response) = RequestFilter::handle_cors_preflight(&req) {
return Ok(response);
}
}
// Match route
let ctx = rustproxy_routing::MatchContext {
port,
domain: host.as_deref(),
path: Some(&path),
client_ip: Some(&peer_addr.ip().to_string()),
tls_version: None,
headers: Some(&headers),
is_tls: false,
};
let route_match = match self.route_manager.find_route(&ctx) {
Some(rm) => rm,
None => {
debug!("No route matched for HTTP request to {:?}{}", host, path);
return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched"));
}
};
let route_id = route_match.route.id.as_deref();
self.metrics.connection_opened(route_id);
// Apply request filters (IP check, rate limiting, auth)
if let Some(ref security) = route_match.route.security {
if let Some(response) = RequestFilter::apply(security, &req, &peer_addr) {
self.metrics.connection_closed(route_id);
return Ok(response);
}
}
// Check for test response (returns immediately, no upstream needed)
if let Some(ref advanced) = route_match.route.action.advanced {
if let Some(ref test_response) = advanced.test_response {
self.metrics.connection_closed(route_id);
return Ok(Self::build_test_response(test_response));
}
}
// Check for static file serving
if let Some(ref advanced) = route_match.route.action.advanced {
if let Some(ref static_files) = advanced.static_files {
self.metrics.connection_closed(route_id);
return Ok(Self::serve_static_file(&path, static_files));
}
}
// Select upstream
let target = match route_match.target {
Some(t) => t,
None => {
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available"));
}
};
let upstream = self.upstream_selector.select(target, &peer_addr, port);
let upstream_key = format!("{}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_started(&upstream_key);
// Check for WebSocket upgrade
let is_websocket = req.headers()
.get("upgrade")
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);
if is_websocket {
let result = self.handle_websocket_upgrade(
req, peer_addr, &upstream, route_match.route, route_id, &upstream_key,
).await;
// Note: for WebSocket, connection_ended is called inside
// the spawned tunnel task when the connection closes.
return result;
}
// Determine backend protocol
let use_h2 = route_match.route.action.options.as_ref()
.and_then(|o| o.backend_protocol.as_ref())
.map(|p| *p == rustproxy_config::BackendProtocol::Http2)
.unwrap_or(false);
// Build the upstream path (path + query), applying URL rewriting if configured
let upstream_path = {
let raw_path = match req.uri().query() {
Some(q) => format!("{}?{}", path, q),
None => path.clone(),
};
Self::apply_url_rewrite(&raw_path, &route_match.route)
};
// Build upstream request - stream body instead of buffering
let (parts, body) = req.into_parts();
// Apply request headers from route config
let mut upstream_headers = parts.headers.clone();
if let Some(ref route_headers) = route_match.route.headers {
if let Some(ref request_headers) = route_headers.request {
for (key, value) in request_headers {
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
upstream_headers.insert(name, val);
}
}
}
}
}
// Connect to upstream
let upstream_stream = match TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)).await {
Ok(s) => s,
Err(e) => {
error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(&upstream_key);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
}
};
upstream_stream.set_nodelay(true).ok();
let io = TokioIo::new(upstream_stream);
let result = if use_h2 {
// HTTP/2 backend
self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
} else {
// HTTP/1.1 backend (default)
self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
};
self.upstream_selector.connection_ended(&upstream_key);
result
}
/// Forward request to backend via HTTP/1.1 with body streaming.
async fn forward_h1(
&self,
io: TokioIo<TcpStream>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
Ok(h) => h,
Err(e) => {
error!("Upstream handshake failed: {}", e);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed"));
}
};
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("Upstream connection error: {}", e);
}
});
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_path)
.version(parts.version);
if let Some(headers) = upstream_req.headers_mut() {
*headers = upstream_headers;
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
&format!("{}:{}", upstream.host, upstream.port)
) {
headers.insert(hyper::header::HOST, host_val);
}
}
// Stream the request body through to upstream
let upstream_req = upstream_req.body(body).unwrap();
let upstream_response = match sender.send_request(upstream_req).await {
Ok(resp) => resp,
Err(e) => {
error!("Upstream request failed: {}", e);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed"));
}
};
self.build_streaming_response(upstream_response, route, route_id).await
}
/// Forward request to backend via HTTP/2 with body streaming.
async fn forward_h2(
&self,
io: TokioIo<TcpStream>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let exec = hyper_util::rt::TokioExecutor::new();
let (mut sender, conn) = match hyper::client::conn::http2::handshake(exec, io).await {
Ok(h) => h,
Err(e) => {
error!("HTTP/2 upstream handshake failed: {}", e);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed"));
}
};
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("HTTP/2 upstream connection error: {}", e);
}
});
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_path);
if let Some(headers) = upstream_req.headers_mut() {
*headers = upstream_headers;
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
&format!("{}:{}", upstream.host, upstream.port)
) {
headers.insert(hyper::header::HOST, host_val);
}
}
// Stream the request body through to upstream
let upstream_req = upstream_req.body(body).unwrap();
let upstream_response = match sender.send_request(upstream_req).await {
Ok(resp) => resp,
Err(e) => {
error!("HTTP/2 upstream request failed: {}", e);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed"));
}
};
self.build_streaming_response(upstream_response, route, route_id).await
}
/// Build the client-facing response from an upstream response, streaming the body.
async fn build_streaming_response(
&self,
upstream_response: Response<Incoming>,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let (resp_parts, resp_body) = upstream_response.into_parts();
let mut response = Response::builder()
.status(resp_parts.status);
if let Some(headers) = response.headers_mut() {
*headers = resp_parts.headers;
ResponseFilter::apply_headers(route, headers, None);
}
self.metrics.connection_closed(route_id);
// Stream the response body directly from upstream to client
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(resp_body);
Ok(response.body(body).unwrap())
}
/// Handle a WebSocket upgrade request.
async fn handle_websocket_upgrade(
&self,
req: Request<Incoming>,
peer_addr: std::net::SocketAddr,
upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
route_id: Option<&str>,
upstream_key: &str,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// Get WebSocket config from route
let ws_config = route.action.websocket.as_ref();
// Check allowed origins if configured
if let Some(ws) = ws_config {
if let Some(ref allowed_origins) = ws.allowed_origins {
let origin = req.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) {
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed"));
}
}
}
info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);
let mut upstream_stream = match TcpStream::connect(
format!("{}:{}", upstream.host, upstream.port)
).await {
Ok(s) => s,
Err(e) => {
error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
}
};
upstream_stream.set_nodelay(true).ok();
let path = req.uri().path().to_string();
let upstream_path = {
let raw = match req.uri().query() {
Some(q) => format!("{}?{}", path, q),
None => path,
};
// Apply rewrite_path if configured
if let Some(ws) = ws_config {
if let Some(ref rewrite_path) = ws.rewrite_path {
rewrite_path.clone()
} else {
raw
}
} else {
raw
}
};
let (parts, _body) = req.into_parts();
let mut raw_request = format!(
"{} {} HTTP/1.1\r\n",
parts.method, upstream_path
);
let upstream_host = format!("{}:{}", upstream.host, upstream.port);
for (name, value) in parts.headers.iter() {
if name == hyper::header::HOST {
raw_request.push_str(&format!("host: {}\r\n", upstream_host));
} else {
raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or("")));
}
}
if let Some(ref route_headers) = route.headers {
if let Some(ref request_headers) = route_headers.request {
for (key, value) in request_headers {
raw_request.push_str(&format!("{}: {}\r\n", key, value));
}
}
}
// Apply WebSocket custom headers
if let Some(ws) = ws_config {
if let Some(ref custom_headers) = ws.custom_headers {
for (key, value) in custom_headers {
raw_request.push_str(&format!("{}: {}\r\n", key, value));
}
}
}
raw_request.push_str("\r\n");
if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await {
error!("WebSocket: failed to send upgrade request to upstream: {}", e);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed"));
}
let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1];
loop {
match upstream_stream.read(&mut temp).await {
Ok(0) => {
error!("WebSocket: upstream closed before completing handshake");
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed"));
}
Ok(_) => {
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len-4..] == *b"\r\n\r\n" {
break;
}
}
if response_buf.len() > 8192 {
error!("WebSocket: upstream response headers too large");
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large"));
}
}
Err(e) => {
error!("WebSocket: failed to read upstream response: {}", e);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id);
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed"));
}
}
}
let response_str = String::from_utf8_lossy(&response_buf);
let status_line = response_str.lines().next().unwrap_or("");
let status_code = status_line
.split_whitespace()
.nth(1)
.and_then(|s| s.parse::<u16>().ok())
.unwrap_or(0);
if status_code != 101 {
debug!("WebSocket: upstream rejected upgrade with status {}", status_code);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id);
return Ok(error_response(
StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY),
"WebSocket upgrade rejected by backend",
));
}
let mut client_resp = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS);
if let Some(resp_headers) = client_resp.headers_mut() {
for line in response_str.lines().skip(1) {
let line = line.trim();
if line.is_empty() {
break;
}
if let Some((name, value)) = line.split_once(':') {
let name = name.trim();
let value = value.trim();
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
resp_headers.insert(header_name, header_value);
}
}
}
}
}
let on_client_upgrade = hyper::upgrade::on(
Request::from_parts(parts, http_body_util::Empty::<Bytes>::new())
);
let metrics = Arc::clone(&self.metrics);
let route_id_owned = route_id.map(|s| s.to_string());
let upstream_selector = self.upstream_selector.clone();
let upstream_key_owned = upstream_key.to_string();
tokio::spawn(async move {
let client_upgraded = match on_client_upgrade.await {
Ok(upgraded) => upgraded,
Err(e) => {
debug!("WebSocket: client upgrade failed: {}", e);
upstream_selector.connection_ended(&upstream_key_owned);
if let Some(ref rid) = route_id_owned {
metrics.connection_closed(Some(rid.as_str()));
}
return;
}
};
let client_io = TokioIo::new(client_upgraded);
let (mut cr, mut cw) = tokio::io::split(client_io);
let (mut ur, mut uw) = tokio::io::split(upstream_stream);
let c2u = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match cr.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if uw.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
}
let _ = uw.shutdown().await;
total
});
let u2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match ur.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if cw.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
}
let _ = cw.shutdown().await;
total
});
let bytes_in = c2u.await.unwrap_or(0);
let bytes_out = u2c.await.unwrap_or(0);
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
upstream_selector.connection_ended(&upstream_key_owned);
if let Some(ref rid) = route_id_owned {
metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()));
metrics.connection_closed(Some(rid.as_str()));
}
});
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
http_body_util::Empty::<Bytes>::new().map_err(|never| match never {})
);
Ok(client_resp.body(body).unwrap())
}
/// Build a test response from config (no upstream connection needed).
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response<BoxBody<Bytes, hyper::Error>> {
let mut response = Response::builder()
.status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK));
if let Some(headers) = response.headers_mut() {
for (key, value) in &config.headers {
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
headers.insert(name, val);
}
}
}
}
let body = Full::new(Bytes::from(config.body.clone()))
.map_err(|never| match never {});
response.body(BoxBody::new(body)).unwrap()
}
/// Apply URL rewriting rules from route config.
fn apply_url_rewrite(path: &str, route: &rustproxy_config::RouteConfig) -> String {
let rewrite = match route.action.advanced.as_ref()
.and_then(|a| a.url_rewrite.as_ref())
{
Some(r) => r,
None => return path.to_string(),
};
// Determine what to rewrite
let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) {
// Only rewrite the path portion (before ?)
match path.split_once('?') {
Some((p, q)) => (p.to_string(), format!("?{}", q)),
None => (path.to_string(), String::new()),
}
} else {
(path.to_string(), String::new())
};
match Regex::new(&rewrite.pattern) {
Ok(re) => {
let result = re.replace_all(&subject, rewrite.target.as_str());
format!("{}{}", result, suffix)
}
Err(e) => {
warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e);
path.to_string()
}
}
}
/// Serve a static file from the configured directory.
fn serve_static_file(
path: &str,
config: &rustproxy_config::RouteStaticFiles,
) -> Response<BoxBody<Bytes, hyper::Error>> {
use std::path::Path;
let root = Path::new(&config.root);
// Sanitize path to prevent directory traversal
let clean_path = path.trim_start_matches('/');
let clean_path = clean_path.replace("..", "");
let mut file_path = root.join(&clean_path);
// If path points to a directory, try index files
if file_path.is_dir() || clean_path.is_empty() {
let index_files = config.index_files.as_deref()
.or(config.index.as_deref())
.unwrap_or(&[]);
let default_index = vec!["index.html".to_string()];
let index_files = if index_files.is_empty() { &default_index } else { index_files };
let mut found = false;
for index in index_files {
let candidate = if clean_path.is_empty() {
root.join(index)
} else {
file_path.join(index)
};
if candidate.is_file() {
file_path = candidate;
found = true;
break;
}
}
if !found {
return error_response(StatusCode::NOT_FOUND, "Not found");
}
}
// Ensure the resolved path is within the root (prevent traversal)
let canonical_root = match root.canonicalize() {
Ok(p) => p,
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
};
let canonical_file = match file_path.canonicalize() {
Ok(p) => p,
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
};
if !canonical_file.starts_with(&canonical_root) {
return error_response(StatusCode::FORBIDDEN, "Forbidden");
}
// Check if symlinks are allowed
if config.follow_symlinks == Some(false) && canonical_file != file_path {
return error_response(StatusCode::FORBIDDEN, "Forbidden");
}
// Read the file
match std::fs::read(&file_path) {
Ok(content) => {
let content_type = guess_content_type(&file_path);
let mut response = Response::builder()
.status(StatusCode::OK)
.header("Content-Type", content_type);
// Apply cache-control if configured
if let Some(ref cache_control) = config.cache_control {
response = response.header("Cache-Control", cache_control.as_str());
}
// Apply custom headers
if let Some(ref headers) = config.headers {
for (key, value) in headers {
response = response.header(key.as_str(), value.as_str());
}
}
let body = Full::new(Bytes::from(content))
.map_err(|never| match never {});
response.body(BoxBody::new(body)).unwrap()
}
Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"),
}
}
}
/// Guess MIME content type from file extension.
fn guess_content_type(path: &std::path::Path) -> &'static str {
match path.extension().and_then(|e| e.to_str()) {
Some("html") | Some("htm") => "text/html; charset=utf-8",
Some("css") => "text/css; charset=utf-8",
Some("js") | Some("mjs") => "application/javascript; charset=utf-8",
Some("json") => "application/json; charset=utf-8",
Some("xml") => "application/xml; charset=utf-8",
Some("txt") => "text/plain; charset=utf-8",
Some("png") => "image/png",
Some("jpg") | Some("jpeg") => "image/jpeg",
Some("gif") => "image/gif",
Some("svg") => "image/svg+xml",
Some("ico") => "image/x-icon",
Some("woff") => "font/woff",
Some("woff2") => "font/woff2",
Some("ttf") => "font/ttf",
Some("pdf") => "application/pdf",
Some("wasm") => "application/wasm",
_ => "application/octet-stream",
}
}
impl Default for HttpProxyService {
fn default() -> Self {
Self {
route_manager: Arc::new(RouteManager::new(vec![])),
metrics: Arc::new(MetricsCollector::new()),
upstream_selector: UpstreamSelector::new(),
}
}
}
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
let body = Full::new(Bytes::from(message.to_string()))
.map_err(|never| match never {});
Response::builder()
.status(status)
.header("Content-Type", "text/plain")
.body(BoxBody::new(body))
.unwrap()
}

View File

@@ -0,0 +1,263 @@
//! Request filtering: security checks, auth, CORS preflight.
use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
pub struct RequestFilter;
impl RequestFilter {
/// Apply security filters. Returns Some(response) if the request should be blocked.
pub fn apply(
security: &RouteSecurity,
req: &Request<Incoming>,
peer_addr: &SocketAddr,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
Self::apply_with_rate_limiter(security, req, peer_addr, None)
}
/// Apply security filters with an optional shared rate limiter.
/// Returns Some(response) if the request should be blocked.
pub fn apply_with_rate_limiter(
security: &RouteSecurity,
req: &Request<Incoming>,
peer_addr: &SocketAddr,
rate_limiter: Option<&Arc<RateLimiter>>,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
let client_ip = peer_addr.ip();
let request_path = req.uri().path();
// IP filter
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(&client_ip);
if !filter.is_allowed(&normalized) {
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
}
}
// Rate limiting
if let Some(ref rate_limit_config) = security.rate_limit {
if rate_limit_config.enabled {
// Use shared rate limiter if provided, otherwise create ephemeral one
let should_block = if let Some(limiter) = rate_limiter {
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key)
} else {
// Create a per-check limiter (less ideal but works for non-shared case)
let limiter = RateLimiter::new(
rate_limit_config.max_requests,
rate_limit_config.window,
);
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key)
};
if should_block {
let message = rate_limit_config.error_message
.as_deref()
.unwrap_or("Rate limit exceeded");
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
}
}
}
// Check exclude paths before auth
let should_skip_auth = Self::path_matches_exclude_list(request_path, security);
if !should_skip_auth {
// Basic auth
if let Some(ref basic_auth) = security.basic_auth {
if basic_auth.enabled {
// Check basic auth exclude paths
let skip_basic = basic_auth.exclude_paths.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
if !skip_basic {
let users: Vec<(String, String)> = basic_auth.users.iter()
.map(|c| (c.username.clone(), c.password.clone()))
.collect();
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
let auth_header = req.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header {
Some(header) => {
if validator.validate(header).is_none() {
return Some(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Invalid credentials"))
.unwrap());
}
}
None => {
return Some(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Authentication required"))
.unwrap());
}
}
}
}
}
// JWT auth
if let Some(ref jwt_auth) = security.jwt_auth {
if jwt_auth.enabled {
// Check JWT auth exclude paths
let skip_jwt = jwt_auth.exclude_paths.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
if !skip_jwt {
let validator = JwtValidator::new(
&jwt_auth.secret,
jwt_auth.algorithm.as_deref(),
jwt_auth.issuer.as_deref(),
jwt_auth.audience.as_deref(),
);
let auth_header = req.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header.and_then(JwtValidator::extract_token) {
Some(token) => {
if validator.validate(token).is_err() {
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
}
}
None => {
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
}
}
}
}
}
}
None
}
/// Check if a request path matches any pattern in the exclude list.
fn path_matches_exclude_list(_path: &str, _security: &RouteSecurity) -> bool {
// No global exclude paths on RouteSecurity currently,
// but we check per-auth exclude paths above.
// This can be extended if a global exclude_paths is added.
false
}
/// Check if a path matches any pattern in the list.
/// Supports simple glob patterns: `/health*` matches `/health`, `/healthz`, `/health/check`
fn path_matches_any(path: &str, patterns: &[String]) -> bool {
for pattern in patterns {
if pattern.ends_with('*') {
let prefix = &pattern[..pattern.len() - 1];
if path.starts_with(prefix) {
return true;
}
} else if path == pattern {
return true;
}
}
false
}
/// Determine the rate limit key based on configuration.
fn rate_limit_key(
config: &rustproxy_config::RouteRateLimit,
req: &Request<Incoming>,
peer_addr: &SocketAddr,
) -> String {
use rustproxy_config::RateLimitKeyBy;
match config.key_by.as_ref().unwrap_or(&RateLimitKeyBy::Ip) {
RateLimitKeyBy::Ip => peer_addr.ip().to_string(),
RateLimitKeyBy::Path => req.uri().path().to_string(),
RateLimitKeyBy::Header => {
if let Some(ref header_name) = config.header_name {
req.headers()
.get(header_name.as_str())
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string()
} else {
peer_addr.ip().to_string()
}
}
}
}
/// Check IP-based security (for use in passthrough / TCP-level connections).
/// Returns true if allowed, false if blocked.
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(client_ip);
filter.is_allowed(&normalized)
} else {
true
}
}
/// Handle CORS preflight (OPTIONS) requests.
/// Returns Some(response) if this is a CORS preflight that should be handled.
pub fn handle_cors_preflight(
req: &Request<Incoming>,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
if req.method() != hyper::Method::OPTIONS {
return None;
}
// Check for CORS preflight indicators
let has_origin = req.headers().contains_key("origin");
let has_request_method = req.headers().contains_key("access-control-request-method");
if !has_origin || !has_request_method {
return None;
}
let origin = req.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*");
Some(Response::builder()
.status(StatusCode::NO_CONTENT)
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
.header("Access-Control-Max-Age", "86400")
.body(boxed_body(""))
.unwrap())
}
}
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Response::builder()
.status(status)
.header("Content-Type", "text/plain")
.body(boxed_body(message))
.unwrap()
}
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
}

View File

@@ -0,0 +1,92 @@
//! Response filtering: CORS headers, custom headers, security headers.
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig;
use crate::template::{RequestContext, expand_template};
pub struct ResponseFilter;
impl ResponseFilter {
/// Apply response headers from route config and CORS settings.
/// If a `RequestContext` is provided, template variables in header values will be expanded.
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
// Apply custom response headers from route config
if let Some(ref route_headers) = route.headers {
if let Some(ref response_headers) = route_headers.response {
for (key, value) in response_headers {
if let Ok(name) = HeaderName::from_bytes(key.as_bytes()) {
let expanded = match req_ctx {
Some(ctx) => expand_template(value, ctx),
None => value.clone(),
};
if let Ok(val) = HeaderValue::from_str(&expanded) {
headers.insert(name, val);
}
}
}
}
// Apply CORS headers if configured
if let Some(ref cors) = route_headers.cors {
if cors.enabled {
Self::apply_cors_headers(cors, headers);
}
}
}
}
fn apply_cors_headers(cors: &rustproxy_config::RouteCors, headers: &mut HeaderMap) {
// Allow-Origin
if let Some(ref origin) = cors.allow_origin {
let origin_str = match origin {
rustproxy_config::AllowOrigin::Single(s) => s.clone(),
rustproxy_config::AllowOrigin::List(list) => list.join(", "),
};
if let Ok(val) = HeaderValue::from_str(&origin_str) {
headers.insert("access-control-allow-origin", val);
}
} else {
headers.insert(
"access-control-allow-origin",
HeaderValue::from_static("*"),
);
}
// Allow-Methods
if let Some(ref methods) = cors.allow_methods {
if let Ok(val) = HeaderValue::from_str(methods) {
headers.insert("access-control-allow-methods", val);
}
}
// Allow-Headers
if let Some(ref allow_headers) = cors.allow_headers {
if let Ok(val) = HeaderValue::from_str(allow_headers) {
headers.insert("access-control-allow-headers", val);
}
}
// Allow-Credentials
if cors.allow_credentials == Some(true) {
headers.insert(
"access-control-allow-credentials",
HeaderValue::from_static("true"),
);
}
// Expose-Headers
if let Some(ref expose) = cors.expose_headers {
if let Ok(val) = HeaderValue::from_str(expose) {
headers.insert("access-control-expose-headers", val);
}
}
// Max-Age
if let Some(max_age) = cors.max_age {
if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) {
headers.insert("access-control-max-age", val);
}
}
}
}

View File

@@ -0,0 +1,162 @@
//! Header template variable expansion.
//!
//! Supports expanding template variables like `{clientIp}`, `{domain}`, etc.
//! in header values before they are applied to requests or responses.
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
/// Context for template variable expansion.
pub struct RequestContext {
pub client_ip: String,
pub domain: String,
pub port: u16,
pub path: String,
pub route_name: String,
pub connection_id: u64,
}
/// Expand template variables in a header value.
/// Supported variables: {clientIp}, {domain}, {port}, {path}, {routeName}, {connectionId}, {timestamp}
pub fn expand_template(template: &str, ctx: &RequestContext) -> String {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
template
.replace("{clientIp}", &ctx.client_ip)
.replace("{domain}", &ctx.domain)
.replace("{port}", &ctx.port.to_string())
.replace("{path}", &ctx.path)
.replace("{routeName}", &ctx.route_name)
.replace("{connectionId}", &ctx.connection_id.to_string())
.replace("{timestamp}", &timestamp.to_string())
}
/// Expand templates in a map of header key-value pairs.
pub fn expand_headers(
headers: &HashMap<String, String>,
ctx: &RequestContext,
) -> HashMap<String, String> {
headers.iter()
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
fn test_context() -> RequestContext {
RequestContext {
client_ip: "192.168.1.100".to_string(),
domain: "example.com".to_string(),
port: 443,
path: "/api/v1/users".to_string(),
route_name: "api-route".to_string(),
connection_id: 42,
}
}
#[test]
fn test_expand_client_ip() {
let ctx = test_context();
assert_eq!(expand_template("{clientIp}", &ctx), "192.168.1.100");
}
#[test]
fn test_expand_domain() {
let ctx = test_context();
assert_eq!(expand_template("{domain}", &ctx), "example.com");
}
#[test]
fn test_expand_port() {
let ctx = test_context();
assert_eq!(expand_template("{port}", &ctx), "443");
}
#[test]
fn test_expand_path() {
let ctx = test_context();
assert_eq!(expand_template("{path}", &ctx), "/api/v1/users");
}
#[test]
fn test_expand_route_name() {
let ctx = test_context();
assert_eq!(expand_template("{routeName}", &ctx), "api-route");
}
#[test]
fn test_expand_connection_id() {
let ctx = test_context();
assert_eq!(expand_template("{connectionId}", &ctx), "42");
}
#[test]
fn test_expand_timestamp() {
let ctx = test_context();
let result = expand_template("{timestamp}", &ctx);
// Timestamp should be a valid number
let ts: u64 = result.parse().expect("timestamp should be a number");
// Should be a reasonable Unix timestamp (after 2020)
assert!(ts > 1_577_836_800);
}
#[test]
fn test_expand_mixed_template() {
let ctx = test_context();
let result = expand_template("client={clientIp}, host={domain}:{port}", &ctx);
assert_eq!(result, "client=192.168.1.100, host=example.com:443");
}
#[test]
fn test_expand_no_variables() {
let ctx = test_context();
assert_eq!(expand_template("plain-value", &ctx), "plain-value");
}
#[test]
fn test_expand_empty_string() {
let ctx = test_context();
assert_eq!(expand_template("", &ctx), "");
}
#[test]
fn test_expand_multiple_same_variable() {
let ctx = test_context();
let result = expand_template("{clientIp}-{clientIp}", &ctx);
assert_eq!(result, "192.168.1.100-192.168.1.100");
}
#[test]
fn test_expand_headers_map() {
let ctx = test_context();
let mut headers = HashMap::new();
headers.insert("X-Forwarded-For".to_string(), "{clientIp}".to_string());
headers.insert("X-Route".to_string(), "{routeName}".to_string());
headers.insert("X-Static".to_string(), "no-template".to_string());
let result = expand_headers(&headers, &ctx);
assert_eq!(result.get("X-Forwarded-For").unwrap(), "192.168.1.100");
assert_eq!(result.get("X-Route").unwrap(), "api-route");
assert_eq!(result.get("X-Static").unwrap(), "no-template");
}
#[test]
fn test_expand_all_variables_in_one() {
let ctx = test_context();
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
let result = expand_template(template, &ctx);
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
}
#[test]
fn test_expand_unknown_variable_left_as_is() {
let ctx = test_context();
let result = expand_template("{unknownVar}", &ctx);
assert_eq!(result, "{unknownVar}");
}
}

View File

@@ -0,0 +1,222 @@
//! Route-aware upstream selection with load balancing.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::Mutex;
use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
/// Upstream selection result.
pub struct UpstreamSelection {
pub host: String,
pub port: u16,
pub use_tls: bool,
}
/// Selects upstream backends with load balancing support.
pub struct UpstreamSelector {
/// Round-robin counters per route (keyed by first target host:port)
round_robin: Mutex<HashMap<String, AtomicUsize>>,
/// Active connection counts per host (keyed by "host:port")
active_connections: Arc<DashMap<String, AtomicU64>>,
}
impl UpstreamSelector {
pub fn new() -> Self {
Self {
round_robin: Mutex::new(HashMap::new()),
active_connections: Arc::new(DashMap::new()),
}
}
/// Select an upstream target based on the route target config and load balancing.
pub fn select(
&self,
target: &RouteTarget,
client_addr: &SocketAddr,
incoming_port: u16,
) -> UpstreamSelection {
let hosts = target.host.to_vec();
let port = target.port.resolve(incoming_port);
if hosts.len() <= 1 {
return UpstreamSelection {
host: hosts.first().map(|s| s.to_string()).unwrap_or_default(),
port,
use_tls: target.tls.is_some(),
};
}
// Determine load balancing algorithm
let algorithm = target.load_balancing.as_ref()
.map(|lb| &lb.algorithm)
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
let idx = match algorithm {
LoadBalancingAlgorithm::RoundRobin => {
self.round_robin_select(&hosts, port)
}
LoadBalancingAlgorithm::IpHash => {
let hash = Self::ip_hash(client_addr);
hash % hosts.len()
}
LoadBalancingAlgorithm::LeastConnections => {
self.least_connections_select(&hosts, port)
}
};
UpstreamSelection {
host: hosts[idx].to_string(),
port,
use_tls: target.tls.is_some(),
}
}
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
let key = format!("{}:{}", hosts[0], port);
let mut counters = self.round_robin.lock().unwrap();
let counter = counters
.entry(key)
.or_insert_with(|| AtomicUsize::new(0));
let idx = counter.fetch_add(1, Ordering::Relaxed);
idx % hosts.len()
}
fn least_connections_select(&self, hosts: &[&str], port: u16) -> usize {
let mut min_conns = u64::MAX;
let mut min_idx = 0;
for (i, host) in hosts.iter().enumerate() {
let key = format!("{}:{}", host, port);
let conns = self.active_connections
.get(&key)
.map(|entry| entry.value().load(Ordering::Relaxed))
.unwrap_or(0);
if conns < min_conns {
min_conns = conns;
min_idx = i;
}
}
min_idx
}
/// Record that a connection to the given host has started.
pub fn connection_started(&self, host: &str) {
self.active_connections
.entry(host.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(1, Ordering::Relaxed);
}
/// Record that a connection to the given host has ended.
pub fn connection_ended(&self, host: &str) {
if let Some(counter) = self.active_connections.get(host) {
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
// Guard against underflow (shouldn't happen, but be safe)
if prev == 0 {
counter.value().store(0, Ordering::Relaxed);
}
}
}
fn ip_hash(addr: &SocketAddr) -> usize {
let ip_str = addr.ip().to_string();
let mut hash: usize = 5381;
for byte in ip_str.bytes() {
hash = hash.wrapping_mul(33).wrapping_add(byte as usize);
}
hash
}
}
impl Default for UpstreamSelector {
fn default() -> Self {
Self::new()
}
}
impl Clone for UpstreamSelector {
fn clone(&self) -> Self {
Self {
round_robin: Mutex::new(HashMap::new()),
active_connections: Arc::clone(&self.active_connections),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use rustproxy_config::*;
fn make_target(hosts: Vec<&str>, port: u16) -> RouteTarget {
RouteTarget {
target_match: None,
host: if hosts.len() == 1 {
HostSpec::Single(hosts[0].to_string())
} else {
HostSpec::List(hosts.iter().map(|s| s.to_string()).collect())
},
port: PortSpec::Fixed(port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}
}
#[test]
fn test_single_host() {
let selector = UpstreamSelector::new();
let target = make_target(vec!["backend"], 8080);
let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
let result = selector.select(&target, &addr, 80);
assert_eq!(result.host, "backend");
assert_eq!(result.port, 8080);
}
#[test]
fn test_round_robin() {
let selector = UpstreamSelector::new();
let mut target = make_target(vec!["a", "b", "c"], 8080);
target.load_balancing = Some(RouteLoadBalancing {
algorithm: LoadBalancingAlgorithm::RoundRobin,
health_check: None,
});
let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
let r1 = selector.select(&target, &addr, 80);
let r2 = selector.select(&target, &addr, 80);
let r3 = selector.select(&target, &addr, 80);
let r4 = selector.select(&target, &addr, 80);
// Should cycle through a, b, c, a
assert_eq!(r1.host, "a");
assert_eq!(r2.host, "b");
assert_eq!(r3.host, "c");
assert_eq!(r4.host, "a");
}
#[test]
fn test_ip_hash_consistent() {
let selector = UpstreamSelector::new();
let mut target = make_target(vec!["a", "b", "c"], 8080);
target.load_balancing = Some(RouteLoadBalancing {
algorithm: LoadBalancingAlgorithm::IpHash,
health_check: None,
});
let addr: SocketAddr = "10.0.0.5:1234".parse().unwrap();
let r1 = selector.select(&target, &addr, 80);
let r2 = selector.select(&target, &addr, 80);
// Same IP should always get same backend
assert_eq!(r1.host, r2.host);
}
}