feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates
rust/crates/rustproxy-http/src/lib.rs (new file, 14 lines)
@@ -0,0 +1,14 @@
//! # rustproxy-http
//!
//! Hyper-based HTTP proxy service for RustProxy.
//! Handles HTTP request parsing, route-based forwarding, and response filtering.

pub mod proxy_service;
pub mod request_filter;
pub mod response_filter;
pub mod template;
pub mod upstream_selector;

pub use proxy_service::*;
pub use template::*;
pub use upstream_selector::*;
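For orientation, a minimal sketch (not part of this commit) of wiring the crate into a TCP accept loop. It assumes the RouteManager::new(vec![]) and MetricsCollector::new() constructors used by the Default impl in proxy_service.rs below; the empty route list is purely illustrative, and route configs would normally come from rustproxy-config.

// Hypothetical caller of the rustproxy-http crate; constructors assumed from the Default impl below.
use std::sync::Arc;
use rustproxy_http::HttpProxyService;
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::RouteManager;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let service = Arc::new(HttpProxyService::new(
        Arc::new(RouteManager::new(vec![])),
        Arc::new(MetricsCollector::new()),
    ));

    let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await?;
    loop {
        let (stream, peer_addr) = listener.accept().await?;
        // handle_connection drives HTTP/1.1 with upgrade support on the plain TCP stream.
        tokio::spawn(Arc::clone(&service).handle_connection(stream, peer_addr, 8080));
    }
}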
rust/crates/rustproxy-http/src/proxy_service.rs (new file, 827 lines)
@@ -0,0 +1,827 @@
//! Hyper-based HTTP proxy service.
//!
//! Accepts decrypted TCP streams (from TLS termination or plain TCP),
//! parses HTTP requests, matches routes, and forwards to upstream backends.
//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade.

use std::collections::HashMap;
use std::sync::Arc;

use bytes::Bytes;
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use regex::Regex;
use tokio::net::TcpStream;
use tracing::{debug, error, info, warn};

use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;

use crate::request_filter::RequestFilter;
use crate::response_filter::ResponseFilter;
use crate::upstream_selector::UpstreamSelector;

/// HTTP proxy service that processes HTTP traffic.
pub struct HttpProxyService {
    route_manager: Arc<RouteManager>,
    metrics: Arc<MetricsCollector>,
    upstream_selector: UpstreamSelector,
}

impl HttpProxyService {
    pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
        Self {
            route_manager,
            metrics,
            upstream_selector: UpstreamSelector::new(),
        }
    }

    /// Handle an incoming HTTP connection on a plain TCP stream.
    pub async fn handle_connection(
        self: Arc<Self>,
        stream: TcpStream,
        peer_addr: std::net::SocketAddr,
        port: u16,
    ) {
        self.handle_io(stream, peer_addr, port).await;
    }

    /// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated).
    ///
    /// Uses HTTP/1.1 with upgrade support. For clients that negotiate HTTP/2,
    /// use `handle_io_auto` instead.
    pub async fn handle_io<I>(
        self: Arc<Self>,
        stream: I,
        peer_addr: std::net::SocketAddr,
        port: u16,
    )
    where
        I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
    {
        let io = TokioIo::new(stream);

        let service = hyper::service::service_fn(move |req: Request<Incoming>| {
            let svc = Arc::clone(&self);
            let peer = peer_addr;
            async move {
                svc.handle_request(req, peer, port).await
            }
        });

        // Use http1::Builder with upgrades for WebSocket support
        let conn = hyper::server::conn::http1::Builder::new()
            .keep_alive(true)
            .serve_connection(io, service)
            .with_upgrades();

        if let Err(e) = conn.await {
            debug!("HTTP connection error from {}: {}", peer_addr, e);
        }
    }

    /// Handle a single HTTP request.
    async fn handle_request(
        &self,
        req: Request<Incoming>,
        peer_addr: std::net::SocketAddr,
        port: u16,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        let host = req.headers()
            .get("host")
            .and_then(|h| h.to_str().ok())
            .map(|h| {
                // Strip port from host header
                h.split(':').next().unwrap_or(h).to_string()
            });

        let path = req.uri().path().to_string();
        let method = req.method().clone();

        // Extract headers for matching
        let headers: HashMap<String, String> = req.headers()
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
            .collect();

        debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr);

        // Check for CORS preflight
        if method == hyper::Method::OPTIONS {
            if let Some(response) = RequestFilter::handle_cors_preflight(&req) {
                return Ok(response);
            }
        }

        // Match route
        let ctx = rustproxy_routing::MatchContext {
            port,
            domain: host.as_deref(),
            path: Some(&path),
            client_ip: Some(&peer_addr.ip().to_string()),
            tls_version: None,
            headers: Some(&headers),
            is_tls: false,
        };

        let route_match = match self.route_manager.find_route(&ctx) {
            Some(rm) => rm,
            None => {
                debug!("No route matched for HTTP request to {:?}{}", host, path);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched"));
            }
        };

        let route_id = route_match.route.id.as_deref();
        self.metrics.connection_opened(route_id);

        // Apply request filters (IP check, rate limiting, auth)
        if let Some(ref security) = route_match.route.security {
            if let Some(response) = RequestFilter::apply(security, &req, &peer_addr) {
                self.metrics.connection_closed(route_id);
                return Ok(response);
            }
        }

        // Check for test response (returns immediately, no upstream needed)
        if let Some(ref advanced) = route_match.route.action.advanced {
            if let Some(ref test_response) = advanced.test_response {
                self.metrics.connection_closed(route_id);
                return Ok(Self::build_test_response(test_response));
            }
        }

        // Check for static file serving
        if let Some(ref advanced) = route_match.route.action.advanced {
            if let Some(ref static_files) = advanced.static_files {
                self.metrics.connection_closed(route_id);
                return Ok(Self::serve_static_file(&path, static_files));
            }
        }

        // Select upstream
        let target = match route_match.target {
            Some(t) => t,
            None => {
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available"));
            }
        };

        let upstream = self.upstream_selector.select(target, &peer_addr, port);
        let upstream_key = format!("{}:{}", upstream.host, upstream.port);
        self.upstream_selector.connection_started(&upstream_key);

        // Check for WebSocket upgrade
        let is_websocket = req.headers()
            .get("upgrade")
            .and_then(|v| v.to_str().ok())
            .map(|v| v.eq_ignore_ascii_case("websocket"))
            .unwrap_or(false);

        if is_websocket {
            let result = self.handle_websocket_upgrade(
                req, peer_addr, &upstream, route_match.route, route_id, &upstream_key,
            ).await;
            // Note: for WebSocket, connection_ended is called inside
            // the spawned tunnel task when the connection closes.
            return result;
        }

        // Determine backend protocol
        let use_h2 = route_match.route.action.options.as_ref()
            .and_then(|o| o.backend_protocol.as_ref())
            .map(|p| *p == rustproxy_config::BackendProtocol::Http2)
            .unwrap_or(false);

        // Build the upstream path (path + query), applying URL rewriting if configured
        let upstream_path = {
            let raw_path = match req.uri().query() {
                Some(q) => format!("{}?{}", path, q),
                None => path.clone(),
            };
            Self::apply_url_rewrite(&raw_path, &route_match.route)
        };

        // Build upstream request - stream body instead of buffering
        let (parts, body) = req.into_parts();

        // Apply request headers from route config
        let mut upstream_headers = parts.headers.clone();
        if let Some(ref route_headers) = route_match.route.headers {
            if let Some(ref request_headers) = route_headers.request {
                for (key, value) in request_headers {
                    if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
                        if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
                            upstream_headers.insert(name, val);
                        }
                    }
                }
            }
        }

        // Connect to upstream
        let upstream_stream = match TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)).await {
            Ok(s) => s,
            Err(e) => {
                error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
                self.upstream_selector.connection_ended(&upstream_key);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
            }
        };
        upstream_stream.set_nodelay(true).ok();

        let io = TokioIo::new(upstream_stream);

        let result = if use_h2 {
            // HTTP/2 backend
            self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
        } else {
            // HTTP/1.1 backend (default)
            self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
        };
        self.upstream_selector.connection_ended(&upstream_key);
        result
    }

    /// Forward request to backend via HTTP/1.1 with body streaming.
    async fn forward_h1(
        &self,
        io: TokioIo<TcpStream>,
        parts: hyper::http::request::Parts,
        body: Incoming,
        upstream_headers: hyper::HeaderMap,
        upstream_path: &str,
        upstream: &crate::upstream_selector::UpstreamSelection,
        route: &rustproxy_config::RouteConfig,
        route_id: Option<&str>,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
            Ok(h) => h,
            Err(e) => {
                error!("Upstream handshake failed: {}", e);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed"));
            }
        };

        tokio::spawn(async move {
            if let Err(e) = conn.await {
                debug!("Upstream connection error: {}", e);
            }
        });

        let mut upstream_req = Request::builder()
            .method(parts.method)
            .uri(upstream_path)
            .version(parts.version);

        if let Some(headers) = upstream_req.headers_mut() {
            *headers = upstream_headers;
            if let Ok(host_val) = hyper::header::HeaderValue::from_str(
                &format!("{}:{}", upstream.host, upstream.port)
            ) {
                headers.insert(hyper::header::HOST, host_val);
            }
        }

        // Stream the request body through to upstream
        let upstream_req = upstream_req.body(body).unwrap();

        let upstream_response = match sender.send_request(upstream_req).await {
            Ok(resp) => resp,
            Err(e) => {
                error!("Upstream request failed: {}", e);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed"));
            }
        };

        self.build_streaming_response(upstream_response, route, route_id).await
    }

    /// Forward request to backend via HTTP/2 with body streaming.
    async fn forward_h2(
        &self,
        io: TokioIo<TcpStream>,
        parts: hyper::http::request::Parts,
        body: Incoming,
        upstream_headers: hyper::HeaderMap,
        upstream_path: &str,
        upstream: &crate::upstream_selector::UpstreamSelection,
        route: &rustproxy_config::RouteConfig,
        route_id: Option<&str>,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        let exec = hyper_util::rt::TokioExecutor::new();
        let (mut sender, conn) = match hyper::client::conn::http2::handshake(exec, io).await {
            Ok(h) => h,
            Err(e) => {
                error!("HTTP/2 upstream handshake failed: {}", e);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed"));
            }
        };

        tokio::spawn(async move {
            if let Err(e) = conn.await {
                debug!("HTTP/2 upstream connection error: {}", e);
            }
        });

        let mut upstream_req = Request::builder()
            .method(parts.method)
            .uri(upstream_path);

        if let Some(headers) = upstream_req.headers_mut() {
            *headers = upstream_headers;
            if let Ok(host_val) = hyper::header::HeaderValue::from_str(
                &format!("{}:{}", upstream.host, upstream.port)
            ) {
                headers.insert(hyper::header::HOST, host_val);
            }
        }

        // Stream the request body through to upstream
        let upstream_req = upstream_req.body(body).unwrap();

        let upstream_response = match sender.send_request(upstream_req).await {
            Ok(resp) => resp,
            Err(e) => {
                error!("HTTP/2 upstream request failed: {}", e);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed"));
            }
        };

        self.build_streaming_response(upstream_response, route, route_id).await
    }

    /// Build the client-facing response from an upstream response, streaming the body.
    async fn build_streaming_response(
        &self,
        upstream_response: Response<Incoming>,
        route: &rustproxy_config::RouteConfig,
        route_id: Option<&str>,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        let (resp_parts, resp_body) = upstream_response.into_parts();

        let mut response = Response::builder()
            .status(resp_parts.status);

        if let Some(headers) = response.headers_mut() {
            *headers = resp_parts.headers;
            ResponseFilter::apply_headers(route, headers, None);
        }

        self.metrics.connection_closed(route_id);

        // Stream the response body directly from upstream to client
        let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(resp_body);

        Ok(response.body(body).unwrap())
    }

    /// Handle a WebSocket upgrade request.
    async fn handle_websocket_upgrade(
        &self,
        req: Request<Incoming>,
        peer_addr: std::net::SocketAddr,
        upstream: &crate::upstream_selector::UpstreamSelection,
        route: &rustproxy_config::RouteConfig,
        route_id: Option<&str>,
        upstream_key: &str,
    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        // Get WebSocket config from route
        let ws_config = route.action.websocket.as_ref();

        // Check allowed origins if configured
        if let Some(ws) = ws_config {
            if let Some(ref allowed_origins) = ws.allowed_origins {
                let origin = req.headers()
                    .get("origin")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("");
                if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) {
                    self.upstream_selector.connection_ended(upstream_key);
                    self.metrics.connection_closed(route_id);
                    return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed"));
                }
            }
        }

        info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);

        let mut upstream_stream = match TcpStream::connect(
            format!("{}:{}", upstream.host, upstream.port)
        ).await {
            Ok(s) => s,
            Err(e) => {
                error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
                self.upstream_selector.connection_ended(upstream_key);
                self.metrics.connection_closed(route_id);
                return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
            }
        };
        upstream_stream.set_nodelay(true).ok();

        let path = req.uri().path().to_string();
        let upstream_path = {
            let raw = match req.uri().query() {
                Some(q) => format!("{}?{}", path, q),
                None => path,
            };
            // Apply rewrite_path if configured
            if let Some(ws) = ws_config {
                if let Some(ref rewrite_path) = ws.rewrite_path {
                    rewrite_path.clone()
                } else {
                    raw
                }
            } else {
                raw
            }
        };

        let (parts, _body) = req.into_parts();

        let mut raw_request = format!(
            "{} {} HTTP/1.1\r\n",
            parts.method, upstream_path
        );

        let upstream_host = format!("{}:{}", upstream.host, upstream.port);
        for (name, value) in parts.headers.iter() {
            if name == hyper::header::HOST {
                raw_request.push_str(&format!("host: {}\r\n", upstream_host));
            } else {
                raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or("")));
            }
        }

        if let Some(ref route_headers) = route.headers {
            if let Some(ref request_headers) = route_headers.request {
                for (key, value) in request_headers {
                    raw_request.push_str(&format!("{}: {}\r\n", key, value));
                }
            }
        }

        // Apply WebSocket custom headers
        if let Some(ws) = ws_config {
            if let Some(ref custom_headers) = ws.custom_headers {
                for (key, value) in custom_headers {
                    raw_request.push_str(&format!("{}: {}\r\n", key, value));
                }
            }
        }

        raw_request.push_str("\r\n");

        if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await {
            error!("WebSocket: failed to send upgrade request to upstream: {}", e);
            self.upstream_selector.connection_ended(upstream_key);
            self.metrics.connection_closed(route_id);
            return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed"));
        }

        let mut response_buf = Vec::with_capacity(4096);
        let mut temp = [0u8; 1];
        loop {
            match upstream_stream.read(&mut temp).await {
                Ok(0) => {
                    error!("WebSocket: upstream closed before completing handshake");
                    self.upstream_selector.connection_ended(upstream_key);
                    self.metrics.connection_closed(route_id);
                    return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed"));
                }
                Ok(_) => {
                    response_buf.push(temp[0]);
                    if response_buf.len() >= 4 {
                        let len = response_buf.len();
                        if response_buf[len-4..] == *b"\r\n\r\n" {
                            break;
                        }
                    }
                    if response_buf.len() > 8192 {
                        error!("WebSocket: upstream response headers too large");
                        self.upstream_selector.connection_ended(upstream_key);
                        self.metrics.connection_closed(route_id);
                        return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large"));
                    }
                }
                Err(e) => {
                    error!("WebSocket: failed to read upstream response: {}", e);
                    self.upstream_selector.connection_ended(upstream_key);
                    self.metrics.connection_closed(route_id);
                    return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed"));
                }
            }
        }

        let response_str = String::from_utf8_lossy(&response_buf);

        let status_line = response_str.lines().next().unwrap_or("");
        let status_code = status_line
            .split_whitespace()
            .nth(1)
            .and_then(|s| s.parse::<u16>().ok())
            .unwrap_or(0);

        if status_code != 101 {
            debug!("WebSocket: upstream rejected upgrade with status {}", status_code);
            self.upstream_selector.connection_ended(upstream_key);
            self.metrics.connection_closed(route_id);
            return Ok(error_response(
                StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY),
                "WebSocket upgrade rejected by backend",
            ));
        }

        let mut client_resp = Response::builder()
            .status(StatusCode::SWITCHING_PROTOCOLS);

        if let Some(resp_headers) = client_resp.headers_mut() {
            for line in response_str.lines().skip(1) {
                let line = line.trim();
                if line.is_empty() {
                    break;
                }
                if let Some((name, value)) = line.split_once(':') {
                    let name = name.trim();
                    let value = value.trim();
                    if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
                            resp_headers.insert(header_name, header_value);
                        }
                    }
                }
            }
        }

        let on_client_upgrade = hyper::upgrade::on(
            Request::from_parts(parts, http_body_util::Empty::<Bytes>::new())
        );

        let metrics = Arc::clone(&self.metrics);
        let route_id_owned = route_id.map(|s| s.to_string());
        let upstream_selector = self.upstream_selector.clone();
        let upstream_key_owned = upstream_key.to_string();

        tokio::spawn(async move {
            let client_upgraded = match on_client_upgrade.await {
                Ok(upgraded) => upgraded,
                Err(e) => {
                    debug!("WebSocket: client upgrade failed: {}", e);
                    upstream_selector.connection_ended(&upstream_key_owned);
                    if let Some(ref rid) = route_id_owned {
                        metrics.connection_closed(Some(rid.as_str()));
                    }
                    return;
                }
            };

            let client_io = TokioIo::new(client_upgraded);

            let (mut cr, mut cw) = tokio::io::split(client_io);
            let (mut ur, mut uw) = tokio::io::split(upstream_stream);

            let c2u = tokio::spawn(async move {
                let mut buf = vec![0u8; 65536];
                let mut total = 0u64;
                loop {
                    let n = match cr.read(&mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => n,
                    };
                    if uw.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                    total += n as u64;
                }
                let _ = uw.shutdown().await;
                total
            });

            let u2c = tokio::spawn(async move {
                let mut buf = vec![0u8; 65536];
                let mut total = 0u64;
                loop {
                    let n = match ur.read(&mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => n,
                    };
                    if cw.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                    total += n as u64;
                }
                let _ = cw.shutdown().await;
                total
            });

            let bytes_in = c2u.await.unwrap_or(0);
            let bytes_out = u2c.await.unwrap_or(0);

            debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);

            upstream_selector.connection_ended(&upstream_key_owned);
            if let Some(ref rid) = route_id_owned {
                metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()));
                metrics.connection_closed(Some(rid.as_str()));
            }
        });

        let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
            http_body_util::Empty::<Bytes>::new().map_err(|never| match never {})
        );
        Ok(client_resp.body(body).unwrap())
    }

    /// Build a test response from config (no upstream connection needed).
    fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response<BoxBody<Bytes, hyper::Error>> {
        let mut response = Response::builder()
            .status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK));

        if let Some(headers) = response.headers_mut() {
            for (key, value) in &config.headers {
                if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
                    if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
                        headers.insert(name, val);
                    }
                }
            }
        }

        let body = Full::new(Bytes::from(config.body.clone()))
            .map_err(|never| match never {});
        response.body(BoxBody::new(body)).unwrap()
    }

    /// Apply URL rewriting rules from route config.
    fn apply_url_rewrite(path: &str, route: &rustproxy_config::RouteConfig) -> String {
        let rewrite = match route.action.advanced.as_ref()
            .and_then(|a| a.url_rewrite.as_ref())
        {
            Some(r) => r,
            None => return path.to_string(),
        };

        // Determine what to rewrite
        let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) {
            // Only rewrite the path portion (before ?)
            match path.split_once('?') {
                Some((p, q)) => (p.to_string(), format!("?{}", q)),
                None => (path.to_string(), String::new()),
            }
        } else {
            (path.to_string(), String::new())
        };

        match Regex::new(&rewrite.pattern) {
            Ok(re) => {
                let result = re.replace_all(&subject, rewrite.target.as_str());
                format!("{}{}", result, suffix)
            }
            Err(e) => {
                warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e);
                path.to_string()
            }
        }
    }

    /// Serve a static file from the configured directory.
    fn serve_static_file(
        path: &str,
        config: &rustproxy_config::RouteStaticFiles,
    ) -> Response<BoxBody<Bytes, hyper::Error>> {
        use std::path::Path;

        let root = Path::new(&config.root);

        // Sanitize path to prevent directory traversal
        let clean_path = path.trim_start_matches('/');
        let clean_path = clean_path.replace("..", "");

        let mut file_path = root.join(&clean_path);

        // If path points to a directory, try index files
        if file_path.is_dir() || clean_path.is_empty() {
            let index_files = config.index_files.as_deref()
                .or(config.index.as_deref())
                .unwrap_or(&[]);
            let default_index = vec!["index.html".to_string()];
            let index_files = if index_files.is_empty() { &default_index } else { index_files };

            let mut found = false;
            for index in index_files {
                let candidate = if clean_path.is_empty() {
                    root.join(index)
                } else {
                    file_path.join(index)
                };
                if candidate.is_file() {
                    file_path = candidate;
                    found = true;
                    break;
                }
            }
            if !found {
                return error_response(StatusCode::NOT_FOUND, "Not found");
            }
        }

        // Ensure the resolved path is within the root (prevent traversal)
        let canonical_root = match root.canonicalize() {
            Ok(p) => p,
            Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
        };
        let canonical_file = match file_path.canonicalize() {
            Ok(p) => p,
            Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
        };
        if !canonical_file.starts_with(&canonical_root) {
            return error_response(StatusCode::FORBIDDEN, "Forbidden");
        }

        // Check if symlinks are allowed
        if config.follow_symlinks == Some(false) && canonical_file != file_path {
            return error_response(StatusCode::FORBIDDEN, "Forbidden");
        }

        // Read the file
        match std::fs::read(&file_path) {
            Ok(content) => {
                let content_type = guess_content_type(&file_path);
                let mut response = Response::builder()
                    .status(StatusCode::OK)
                    .header("Content-Type", content_type);

                // Apply cache-control if configured
                if let Some(ref cache_control) = config.cache_control {
                    response = response.header("Cache-Control", cache_control.as_str());
                }

                // Apply custom headers
                if let Some(ref headers) = config.headers {
                    for (key, value) in headers {
                        response = response.header(key.as_str(), value.as_str());
                    }
                }

                let body = Full::new(Bytes::from(content))
                    .map_err(|never| match never {});
                response.body(BoxBody::new(body)).unwrap()
            }
            Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"),
        }
    }
}

/// Guess MIME content type from file extension.
fn guess_content_type(path: &std::path::Path) -> &'static str {
    match path.extension().and_then(|e| e.to_str()) {
        Some("html") | Some("htm") => "text/html; charset=utf-8",
        Some("css") => "text/css; charset=utf-8",
        Some("js") | Some("mjs") => "application/javascript; charset=utf-8",
        Some("json") => "application/json; charset=utf-8",
        Some("xml") => "application/xml; charset=utf-8",
        Some("txt") => "text/plain; charset=utf-8",
        Some("png") => "image/png",
        Some("jpg") | Some("jpeg") => "image/jpeg",
        Some("gif") => "image/gif",
        Some("svg") => "image/svg+xml",
        Some("ico") => "image/x-icon",
        Some("woff") => "font/woff",
        Some("woff2") => "font/woff2",
        Some("ttf") => "font/ttf",
        Some("pdf") => "application/pdf",
        Some("wasm") => "application/wasm",
        _ => "application/octet-stream",
    }
}

impl Default for HttpProxyService {
    fn default() -> Self {
        Self {
            route_manager: Arc::new(RouteManager::new(vec![])),
            metrics: Arc::new(MetricsCollector::new()),
            upstream_selector: UpstreamSelector::new(),
        }
    }
}

fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
    let body = Full::new(Bytes::from(message.to_string()))
        .map_err(|never| match never {});
    Response::builder()
        .status(status)
        .header("Content-Type", "text/plain")
        .body(BoxBody::new(body))
        .unwrap()
}
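The rewrite in apply_url_rewrite above is a plain regex::Regex::replace_all over either the full path+query or only the path portion. A small sketch of the expected semantics, using hypothetical pattern/target values (the concrete rewrite config type lives in rustproxy-config and is not shown in this file):

// Illustrative only: mirrors the replace_all call in apply_url_rewrite with made-up values.
use regex::Regex;

fn main() {
    // With only_rewrite_path = true, the query string suffix is preserved untouched.
    let (subject, suffix) = ("/old/api/users", "?page=2");
    let re = Regex::new("^/old/(.*)").unwrap();        // hypothetical `pattern`
    let rewritten = re.replace_all(subject, "/v2/$1");  // hypothetical `target`
    assert_eq!(format!("{}{}", rewritten, suffix), "/v2/api/users?page=2");
}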
rust/crates/rustproxy-http/src/request_filter.rs (new file, 263 lines)
@@ -0,0 +1,263 @@
//! Request filtering: security checks, auth, CORS preflight.

use std::net::SocketAddr;
use std::sync::Arc;

use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody;

use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};

pub struct RequestFilter;

impl RequestFilter {
    /// Apply security filters. Returns Some(response) if the request should be blocked.
    pub fn apply(
        security: &RouteSecurity,
        req: &Request<Incoming>,
        peer_addr: &SocketAddr,
    ) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
        Self::apply_with_rate_limiter(security, req, peer_addr, None)
    }

    /// Apply security filters with an optional shared rate limiter.
    /// Returns Some(response) if the request should be blocked.
    pub fn apply_with_rate_limiter(
        security: &RouteSecurity,
        req: &Request<Incoming>,
        peer_addr: &SocketAddr,
        rate_limiter: Option<&Arc<RateLimiter>>,
    ) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
        let client_ip = peer_addr.ip();
        let request_path = req.uri().path();

        // IP filter
        if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
            let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
            let block = security.ip_block_list.as_deref().unwrap_or(&[]);
            let filter = IpFilter::new(allow, block);
            let normalized = IpFilter::normalize_ip(&client_ip);
            if !filter.is_allowed(&normalized) {
                return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
            }
        }

        // Rate limiting
        if let Some(ref rate_limit_config) = security.rate_limit {
            if rate_limit_config.enabled {
                // Use shared rate limiter if provided, otherwise create ephemeral one
                let should_block = if let Some(limiter) = rate_limiter {
                    let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
                    !limiter.check(&key)
                } else {
                    // Create a per-check limiter (less ideal but works for non-shared case)
                    let limiter = RateLimiter::new(
                        rate_limit_config.max_requests,
                        rate_limit_config.window,
                    );
                    let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
                    !limiter.check(&key)
                };

                if should_block {
                    let message = rate_limit_config.error_message
                        .as_deref()
                        .unwrap_or("Rate limit exceeded");
                    return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
                }
            }
        }

        // Check exclude paths before auth
        let should_skip_auth = Self::path_matches_exclude_list(request_path, security);

        if !should_skip_auth {
            // Basic auth
            if let Some(ref basic_auth) = security.basic_auth {
                if basic_auth.enabled {
                    // Check basic auth exclude paths
                    let skip_basic = basic_auth.exclude_paths.as_ref()
                        .map(|paths| Self::path_matches_any(request_path, paths))
                        .unwrap_or(false);

                    if !skip_basic {
                        let users: Vec<(String, String)> = basic_auth.users.iter()
                            .map(|c| (c.username.clone(), c.password.clone()))
                            .collect();
                        let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());

                        let auth_header = req.headers()
                            .get("authorization")
                            .and_then(|v| v.to_str().ok());

                        match auth_header {
                            Some(header) => {
                                if validator.validate(header).is_none() {
                                    return Some(Response::builder()
                                        .status(StatusCode::UNAUTHORIZED)
                                        .header("WWW-Authenticate", validator.www_authenticate())
                                        .body(boxed_body("Invalid credentials"))
                                        .unwrap());
                                }
                            }
                            None => {
                                return Some(Response::builder()
                                    .status(StatusCode::UNAUTHORIZED)
                                    .header("WWW-Authenticate", validator.www_authenticate())
                                    .body(boxed_body("Authentication required"))
                                    .unwrap());
                            }
                        }
                    }
                }
            }

            // JWT auth
            if let Some(ref jwt_auth) = security.jwt_auth {
                if jwt_auth.enabled {
                    // Check JWT auth exclude paths
                    let skip_jwt = jwt_auth.exclude_paths.as_ref()
                        .map(|paths| Self::path_matches_any(request_path, paths))
                        .unwrap_or(false);

                    if !skip_jwt {
                        let validator = JwtValidator::new(
                            &jwt_auth.secret,
                            jwt_auth.algorithm.as_deref(),
                            jwt_auth.issuer.as_deref(),
                            jwt_auth.audience.as_deref(),
                        );

                        let auth_header = req.headers()
                            .get("authorization")
                            .and_then(|v| v.to_str().ok());

                        match auth_header.and_then(JwtValidator::extract_token) {
                            Some(token) => {
                                if validator.validate(token).is_err() {
                                    return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
                                }
                            }
                            None => {
                                return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
                            }
                        }
                    }
                }
            }
        }

        None
    }

    /// Check if a request path matches any pattern in the exclude list.
    fn path_matches_exclude_list(_path: &str, _security: &RouteSecurity) -> bool {
        // No global exclude paths on RouteSecurity currently,
        // but we check per-auth exclude paths above.
        // This can be extended if a global exclude_paths is added.
        false
    }

    /// Check if a path matches any pattern in the list.
    /// Supports simple glob patterns: `/health*` matches `/health`, `/healthz`, `/health/check`
    fn path_matches_any(path: &str, patterns: &[String]) -> bool {
        for pattern in patterns {
            if pattern.ends_with('*') {
                let prefix = &pattern[..pattern.len() - 1];
                if path.starts_with(prefix) {
                    return true;
                }
            } else if path == pattern {
                return true;
            }
        }
        false
    }

    /// Determine the rate limit key based on configuration.
    fn rate_limit_key(
        config: &rustproxy_config::RouteRateLimit,
        req: &Request<Incoming>,
        peer_addr: &SocketAddr,
    ) -> String {
        use rustproxy_config::RateLimitKeyBy;
        match config.key_by.as_ref().unwrap_or(&RateLimitKeyBy::Ip) {
            RateLimitKeyBy::Ip => peer_addr.ip().to_string(),
            RateLimitKeyBy::Path => req.uri().path().to_string(),
            RateLimitKeyBy::Header => {
                if let Some(ref header_name) = config.header_name {
                    req.headers()
                        .get(header_name.as_str())
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("unknown")
                        .to_string()
                } else {
                    peer_addr.ip().to_string()
                }
            }
        }
    }

    /// Check IP-based security (for use in passthrough / TCP-level connections).
    /// Returns true if allowed, false if blocked.
    pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
        if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
            let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
            let block = security.ip_block_list.as_deref().unwrap_or(&[]);
            let filter = IpFilter::new(allow, block);
            let normalized = IpFilter::normalize_ip(client_ip);
            filter.is_allowed(&normalized)
        } else {
            true
        }
    }

    /// Handle CORS preflight (OPTIONS) requests.
    /// Returns Some(response) if this is a CORS preflight that should be handled.
    pub fn handle_cors_preflight(
        req: &Request<Incoming>,
    ) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
        if req.method() != hyper::Method::OPTIONS {
            return None;
        }

        // Check for CORS preflight indicators
        let has_origin = req.headers().contains_key("origin");
        let has_request_method = req.headers().contains_key("access-control-request-method");

        if !has_origin || !has_request_method {
            return None;
        }

        let origin = req.headers()
            .get("origin")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("*");

        Some(Response::builder()
            .status(StatusCode::NO_CONTENT)
            .header("Access-Control-Allow-Origin", origin)
            .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
            .header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
            .header("Access-Control-Max-Age", "86400")
            .body(boxed_body(""))
            .unwrap())
    }
}

fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
    Response::builder()
        .status(status)
        .header("Content-Type", "text/plain")
        .body(boxed_body(message))
        .unwrap()
}

fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
    BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
}
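The glob handling in path_matches_any is prefix-only: a trailing `*` matches any suffix and everything else is an exact comparison. A test sketch (not included in the commit) that could sit in this module to pin down the behaviour documented on the function:

#[cfg(test)]
mod tests {
    use super::RequestFilter;

    #[test]
    fn health_glob_matches_documented_examples() {
        let patterns = vec!["/health*".to_string(), "/login".to_string()];
        // `/health*` covers the documented examples.
        assert!(RequestFilter::path_matches_any("/health", &patterns));
        assert!(RequestFilter::path_matches_any("/healthz", &patterns));
        assert!(RequestFilter::path_matches_any("/health/check", &patterns));
        // Non-glob patterns are exact matches only.
        assert!(RequestFilter::path_matches_any("/login", &patterns));
        assert!(!RequestFilter::path_matches_any("/login/extra", &patterns));
    }
}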
rust/crates/rustproxy-http/src/response_filter.rs (new file, 92 lines)
@@ -0,0 +1,92 @@
//! Response filtering: CORS headers, custom headers, security headers.

use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig;

use crate::template::{RequestContext, expand_template};

pub struct ResponseFilter;

impl ResponseFilter {
    /// Apply response headers from route config and CORS settings.
    /// If a `RequestContext` is provided, template variables in header values will be expanded.
    pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
        // Apply custom response headers from route config
        if let Some(ref route_headers) = route.headers {
            if let Some(ref response_headers) = route_headers.response {
                for (key, value) in response_headers {
                    if let Ok(name) = HeaderName::from_bytes(key.as_bytes()) {
                        let expanded = match req_ctx {
                            Some(ctx) => expand_template(value, ctx),
                            None => value.clone(),
                        };
                        if let Ok(val) = HeaderValue::from_str(&expanded) {
                            headers.insert(name, val);
                        }
                    }
                }
            }

            // Apply CORS headers if configured
            if let Some(ref cors) = route_headers.cors {
                if cors.enabled {
                    Self::apply_cors_headers(cors, headers);
                }
            }
        }
    }

    fn apply_cors_headers(cors: &rustproxy_config::RouteCors, headers: &mut HeaderMap) {
        // Allow-Origin
        if let Some(ref origin) = cors.allow_origin {
            let origin_str = match origin {
                rustproxy_config::AllowOrigin::Single(s) => s.clone(),
                rustproxy_config::AllowOrigin::List(list) => list.join(", "),
            };
            if let Ok(val) = HeaderValue::from_str(&origin_str) {
                headers.insert("access-control-allow-origin", val);
            }
        } else {
            headers.insert(
                "access-control-allow-origin",
                HeaderValue::from_static("*"),
            );
        }

        // Allow-Methods
        if let Some(ref methods) = cors.allow_methods {
            if let Ok(val) = HeaderValue::from_str(methods) {
                headers.insert("access-control-allow-methods", val);
            }
        }

        // Allow-Headers
        if let Some(ref allow_headers) = cors.allow_headers {
            if let Ok(val) = HeaderValue::from_str(allow_headers) {
                headers.insert("access-control-allow-headers", val);
            }
        }

        // Allow-Credentials
        if cors.allow_credentials == Some(true) {
            headers.insert(
                "access-control-allow-credentials",
                HeaderValue::from_static("true"),
            );
        }

        // Expose-Headers
        if let Some(ref expose) = cors.expose_headers {
            if let Ok(val) = HeaderValue::from_str(expose) {
                headers.insert("access-control-expose-headers", val);
            }
        }

        // Max-Age
        if let Some(max_age) = cors.max_age {
            if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) {
                headers.insert("access-control-max-age", val);
            }
        }
    }
}
rust/crates/rustproxy-http/src/template.rs (new file, 162 lines)
@@ -0,0 +1,162 @@
//! Header template variable expansion.
//!
//! Supports expanding template variables like `{clientIp}`, `{domain}`, etc.
//! in header values before they are applied to requests or responses.

use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};

/// Context for template variable expansion.
pub struct RequestContext {
    pub client_ip: String,
    pub domain: String,
    pub port: u16,
    pub path: String,
    pub route_name: String,
    pub connection_id: u64,
}

/// Expand template variables in a header value.
/// Supported variables: {clientIp}, {domain}, {port}, {path}, {routeName}, {connectionId}, {timestamp}
pub fn expand_template(template: &str, ctx: &RequestContext) -> String {
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    template
        .replace("{clientIp}", &ctx.client_ip)
        .replace("{domain}", &ctx.domain)
        .replace("{port}", &ctx.port.to_string())
        .replace("{path}", &ctx.path)
        .replace("{routeName}", &ctx.route_name)
        .replace("{connectionId}", &ctx.connection_id.to_string())
        .replace("{timestamp}", &timestamp.to_string())
}

/// Expand templates in a map of header key-value pairs.
pub fn expand_headers(
    headers: &HashMap<String, String>,
    ctx: &RequestContext,
) -> HashMap<String, String> {
    headers.iter()
        .map(|(k, v)| (k.clone(), expand_template(v, ctx)))
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_context() -> RequestContext {
        RequestContext {
            client_ip: "192.168.1.100".to_string(),
            domain: "example.com".to_string(),
            port: 443,
            path: "/api/v1/users".to_string(),
            route_name: "api-route".to_string(),
            connection_id: 42,
        }
    }

    #[test]
    fn test_expand_client_ip() {
        let ctx = test_context();
        assert_eq!(expand_template("{clientIp}", &ctx), "192.168.1.100");
    }

    #[test]
    fn test_expand_domain() {
        let ctx = test_context();
        assert_eq!(expand_template("{domain}", &ctx), "example.com");
    }

    #[test]
    fn test_expand_port() {
        let ctx = test_context();
        assert_eq!(expand_template("{port}", &ctx), "443");
    }

    #[test]
    fn test_expand_path() {
        let ctx = test_context();
        assert_eq!(expand_template("{path}", &ctx), "/api/v1/users");
    }

    #[test]
    fn test_expand_route_name() {
        let ctx = test_context();
        assert_eq!(expand_template("{routeName}", &ctx), "api-route");
    }

    #[test]
    fn test_expand_connection_id() {
        let ctx = test_context();
        assert_eq!(expand_template("{connectionId}", &ctx), "42");
    }

    #[test]
    fn test_expand_timestamp() {
        let ctx = test_context();
        let result = expand_template("{timestamp}", &ctx);
        // Timestamp should be a valid number
        let ts: u64 = result.parse().expect("timestamp should be a number");
        // Should be a reasonable Unix timestamp (after 2020)
        assert!(ts > 1_577_836_800);
    }

    #[test]
    fn test_expand_mixed_template() {
        let ctx = test_context();
        let result = expand_template("client={clientIp}, host={domain}:{port}", &ctx);
        assert_eq!(result, "client=192.168.1.100, host=example.com:443");
    }

    #[test]
    fn test_expand_no_variables() {
        let ctx = test_context();
        assert_eq!(expand_template("plain-value", &ctx), "plain-value");
    }

    #[test]
    fn test_expand_empty_string() {
        let ctx = test_context();
        assert_eq!(expand_template("", &ctx), "");
    }

    #[test]
    fn test_expand_multiple_same_variable() {
        let ctx = test_context();
        let result = expand_template("{clientIp}-{clientIp}", &ctx);
        assert_eq!(result, "192.168.1.100-192.168.1.100");
    }

    #[test]
    fn test_expand_headers_map() {
        let ctx = test_context();
        let mut headers = HashMap::new();
        headers.insert("X-Forwarded-For".to_string(), "{clientIp}".to_string());
        headers.insert("X-Route".to_string(), "{routeName}".to_string());
        headers.insert("X-Static".to_string(), "no-template".to_string());

        let result = expand_headers(&headers, &ctx);
        assert_eq!(result.get("X-Forwarded-For").unwrap(), "192.168.1.100");
        assert_eq!(result.get("X-Route").unwrap(), "api-route");
        assert_eq!(result.get("X-Static").unwrap(), "no-template");
    }

    #[test]
    fn test_expand_all_variables_in_one() {
        let ctx = test_context();
        let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
        let result = expand_template(template, &ctx);
        assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
    }

    #[test]
    fn test_expand_unknown_variable_left_as_is() {
        let ctx = test_context();
        let result = expand_template("{unknownVar}", &ctx);
        assert_eq!(result, "{unknownVar}");
    }
}
rust/crates/rustproxy-http/src/upstream_selector.rs (new file, 222 lines)
@@ -0,0 +1,222 @@
//! Route-aware upstream selection with load balancing.

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::Mutex;

use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};

/// Upstream selection result.
pub struct UpstreamSelection {
    pub host: String,
    pub port: u16,
    pub use_tls: bool,
}

/// Selects upstream backends with load balancing support.
pub struct UpstreamSelector {
    /// Round-robin counters per route (keyed by first target host:port)
    round_robin: Mutex<HashMap<String, AtomicUsize>>,
    /// Active connection counts per host (keyed by "host:port")
    active_connections: Arc<DashMap<String, AtomicU64>>,
}

impl UpstreamSelector {
    pub fn new() -> Self {
        Self {
            round_robin: Mutex::new(HashMap::new()),
            active_connections: Arc::new(DashMap::new()),
        }
    }

    /// Select an upstream target based on the route target config and load balancing.
    pub fn select(
        &self,
        target: &RouteTarget,
        client_addr: &SocketAddr,
        incoming_port: u16,
    ) -> UpstreamSelection {
        let hosts = target.host.to_vec();
        let port = target.port.resolve(incoming_port);

        if hosts.len() <= 1 {
            return UpstreamSelection {
                host: hosts.first().map(|s| s.to_string()).unwrap_or_default(),
                port,
                use_tls: target.tls.is_some(),
            };
        }

        // Determine load balancing algorithm
        let algorithm = target.load_balancing.as_ref()
            .map(|lb| &lb.algorithm)
            .unwrap_or(&LoadBalancingAlgorithm::RoundRobin);

        let idx = match algorithm {
            LoadBalancingAlgorithm::RoundRobin => {
                self.round_robin_select(&hosts, port)
            }
            LoadBalancingAlgorithm::IpHash => {
                let hash = Self::ip_hash(client_addr);
                hash % hosts.len()
            }
            LoadBalancingAlgorithm::LeastConnections => {
                self.least_connections_select(&hosts, port)
            }
        };

        UpstreamSelection {
            host: hosts[idx].to_string(),
            port,
            use_tls: target.tls.is_some(),
        }
    }

    fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
        let key = format!("{}:{}", hosts[0], port);
        let mut counters = self.round_robin.lock().unwrap();
        let counter = counters
            .entry(key)
            .or_insert_with(|| AtomicUsize::new(0));
        let idx = counter.fetch_add(1, Ordering::Relaxed);
        idx % hosts.len()
    }

    fn least_connections_select(&self, hosts: &[&str], port: u16) -> usize {
        let mut min_conns = u64::MAX;
        let mut min_idx = 0;

        for (i, host) in hosts.iter().enumerate() {
            let key = format!("{}:{}", host, port);
            let conns = self.active_connections
                .get(&key)
                .map(|entry| entry.value().load(Ordering::Relaxed))
                .unwrap_or(0);
            if conns < min_conns {
                min_conns = conns;
                min_idx = i;
            }
        }

        min_idx
    }

    /// Record that a connection to the given host has started.
    pub fn connection_started(&self, host: &str) {
        self.active_connections
            .entry(host.to_string())
            .or_insert_with(|| AtomicU64::new(0))
            .fetch_add(1, Ordering::Relaxed);
    }

    /// Record that a connection to the given host has ended.
    pub fn connection_ended(&self, host: &str) {
        if let Some(counter) = self.active_connections.get(host) {
            let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
            // Guard against underflow (shouldn't happen, but be safe)
            if prev == 0 {
                counter.value().store(0, Ordering::Relaxed);
            }
        }
    }

    fn ip_hash(addr: &SocketAddr) -> usize {
        let ip_str = addr.ip().to_string();
        let mut hash: usize = 5381;
        for byte in ip_str.bytes() {
            hash = hash.wrapping_mul(33).wrapping_add(byte as usize);
        }
        hash
    }
}

impl Default for UpstreamSelector {
    fn default() -> Self {
        Self::new()
    }
}

impl Clone for UpstreamSelector {
    fn clone(&self) -> Self {
        Self {
            round_robin: Mutex::new(HashMap::new()),
            active_connections: Arc::clone(&self.active_connections),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rustproxy_config::*;

    fn make_target(hosts: Vec<&str>, port: u16) -> RouteTarget {
        RouteTarget {
            target_match: None,
            host: if hosts.len() == 1 {
                HostSpec::Single(hosts[0].to_string())
            } else {
                HostSpec::List(hosts.iter().map(|s| s.to_string()).collect())
            },
            port: PortSpec::Fixed(port),
            tls: None,
            websocket: None,
            load_balancing: None,
            send_proxy_protocol: None,
            headers: None,
            advanced: None,
            priority: None,
        }
    }

    #[test]
    fn test_single_host() {
        let selector = UpstreamSelector::new();
        let target = make_target(vec!["backend"], 8080);
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
        let result = selector.select(&target, &addr, 80);
        assert_eq!(result.host, "backend");
        assert_eq!(result.port, 8080);
    }

    #[test]
    fn test_round_robin() {
        let selector = UpstreamSelector::new();
        let mut target = make_target(vec!["a", "b", "c"], 8080);
        target.load_balancing = Some(RouteLoadBalancing {
            algorithm: LoadBalancingAlgorithm::RoundRobin,
            health_check: None,
        });
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let r1 = selector.select(&target, &addr, 80);
        let r2 = selector.select(&target, &addr, 80);
        let r3 = selector.select(&target, &addr, 80);
        let r4 = selector.select(&target, &addr, 80);

        // Should cycle through a, b, c, a
        assert_eq!(r1.host, "a");
        assert_eq!(r2.host, "b");
        assert_eq!(r3.host, "c");
        assert_eq!(r4.host, "a");
    }

    #[test]
    fn test_ip_hash_consistent() {
        let selector = UpstreamSelector::new();
        let mut target = make_target(vec!["a", "b", "c"], 8080);
        target.load_balancing = Some(RouteLoadBalancing {
            algorithm: LoadBalancingAlgorithm::IpHash,
            health_check: None,
        });
        let addr: SocketAddr = "10.0.0.5:1234".parse().unwrap();

        let r1 = selector.select(&target, &addr, 80);
        let r2 = selector.select(&target, &addr, 80);
        // Same IP should always get same backend
        assert_eq!(r1.host, r2.host);
    }
}
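One behaviour the tests above do not cover is LeastConnections together with the connection_started/connection_ended counters. A sketch of an additional test in the same style as the module above (not part of this commit), reusing its make_target helper:

#[test]
fn test_least_connections_prefers_idle_backend() {
    let selector = UpstreamSelector::new();
    let mut target = make_target(vec!["a", "b"], 8080);
    target.load_balancing = Some(RouteLoadBalancing {
        algorithm: LoadBalancingAlgorithm::LeastConnections,
        health_check: None,
    });
    let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

    // Mark "a" as busy; selection should then prefer the idle backend "b".
    selector.connection_started("a:8080");
    let result = selector.select(&target, &addr, 80);
    assert_eq!(result.host, "b");
}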