fix(rustproxy-http,rustproxy-metrics): fix domain-scoped request host detection and harden connection metrics cleanup

This commit is contained in:
2026-04-14 00:54:12 +00:00
parent 6ee7237357
commit a53a2c4ca5
15 changed files with 1813 additions and 590 deletions
@@ -3,8 +3,8 @@
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use bytes::Bytes;
@@ -105,13 +105,19 @@ impl ConnectionPool {
/// Try to check out an idle HTTP/1.1 sender for the given key.
/// Returns `None` if no usable idle connection exists.
pub fn checkout_h1(&self, key: &PoolKey) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
pub fn checkout_h1(
&self,
key: &PoolKey,
) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
let mut entry = self.h1_pool.get_mut(key)?;
let idles = entry.value_mut();
while let Some(idle) = idles.pop() {
// Check if the connection is still alive and ready
if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() {
if idle.idle_since.elapsed() < IDLE_TIMEOUT
&& idle.sender.is_ready()
&& !idle.sender.is_closed()
{
// H1 pool hit — no logging on hot path
return Some(idle.sender);
}
@@ -128,7 +134,11 @@ impl ConnectionPool {
/// Return an HTTP/1.1 sender to the pool after the response body has been prepared.
/// The caller should NOT call this if the sender is closed or not ready.
pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>) {
pub fn checkin_h1(
&self,
key: PoolKey,
sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
) {
if sender.is_closed() || !sender.is_ready() {
return; // Don't pool broken connections
}
@@ -145,7 +155,10 @@ impl ConnectionPool {
/// Try to get a cloned HTTP/2 sender for the given key.
/// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove.
pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
pub fn checkout_h2(
&self,
key: &PoolKey,
) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
let entry = self.h2_pool.get(key)?;
let pooled = entry.value();
let age = pooled.created_at.elapsed();
@@ -184,16 +197,23 @@ impl ConnectionPool {
/// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry.
/// The caller should pass this generation to the connection driver so it can use
/// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction.
pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>) -> u64 {
pub fn register_h2(
&self,
key: PoolKey,
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
if sender.is_closed() {
return gen;
}
self.h2_pool.insert(key, PooledH2 {
sender,
created_at: Instant::now(),
generation: gen,
});
self.h2_pool.insert(
key,
PooledH2 {
sender,
created_at: Instant::now(),
generation: gen,
},
);
gen
}
@@ -204,7 +224,11 @@ impl ConnectionPool {
pub fn checkout_h3(
&self,
key: &PoolKey,
) -> Option<(h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, quinn::Connection, Duration)> {
) -> Option<(
h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
quinn::Connection,
Duration,
)> {
let entry = self.h3_pool.get(key)?;
let pooled = entry.value();
let age = pooled.created_at.elapsed();
@@ -234,12 +258,15 @@ impl ConnectionPool {
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
self.h3_pool.insert(key, PooledH3 {
send_request,
connection,
created_at: Instant::now(),
generation: gen,
});
self.h3_pool.insert(
key,
PooledH3 {
send_request,
connection,
created_at: Instant::now(),
generation: gen,
},
);
gen
}
@@ -280,7 +307,9 @@ impl ConnectionPool {
// Evict dead or aged-out H2 connections
let mut dead_h2 = Vec::new();
for entry in h2_pool.iter() {
if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE {
if entry.value().sender.is_closed()
|| entry.value().created_at.elapsed() >= MAX_H2_AGE
{
dead_h2.push(entry.key().clone());
}
}
@@ -1,8 +1,8 @@
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
@@ -76,7 +76,11 @@ impl<B> CountingBody<B> {
/// Set the connection-level activity tracker. When set, each data frame
/// updates this timestamp to prevent the idle watchdog from killing the
/// connection during active body streaming.
pub fn with_connection_activity(mut self, activity: Arc<AtomicU64>, start: std::time::Instant) -> Self {
pub fn with_connection_activity(
mut self,
activity: Arc<AtomicU64>,
start: std::time::Instant,
) -> Self {
self.connection_activity = Some(activity);
self.activity_start = Some(start);
self
@@ -134,7 +138,9 @@ where
}
// Keep the connection-level idle watchdog alive on every frame
// (this is just one atomic store — cheap enough per-frame)
if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) {
if let (Some(activity), Some(start)) =
(&this.connection_activity, &this.activity_start)
{
activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
}
+28 -10
View File
@@ -11,8 +11,8 @@ use std::task::{Context, Poll};
use bytes::{Buf, Bytes};
use http_body::Frame;
use http_body_util::BodyExt;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use tracing::{debug, warn};
use rustproxy_config::RouteConfig;
@@ -49,7 +49,8 @@ impl H3ProxyService {
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
// Track frontend H3 connection for the QUIC connection's lifetime.
let _frontend_h3_guard = ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
let _frontend_h3_guard =
ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
h3::server::builder()
@@ -92,8 +93,15 @@ impl H3ProxyService {
tokio::spawn(async move {
if let Err(e) = handle_h3_request(
request, stream, port, remote_addr, &http_proxy, request_cancel,
).await {
request,
stream,
port,
remote_addr,
&http_proxy,
request_cancel,
)
.await
{
debug!("HTTP/3 request error from {}: {}", remote_addr, e);
}
});
@@ -153,11 +161,14 @@ async fn handle_h3_request(
// Delegate to HttpProxyService — same backend path as TCP/HTTP:
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
let conn_activity = ConnActivity::new_standalone();
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await
let response = http_proxy
.handle_request(req, peer_addr, port, cancel, conn_activity)
.await
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
// Await the body reader to get the H3 stream back
let mut stream = body_reader.await
let mut stream = body_reader
.await
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
// Send response headers over H3 (skip hop-by-hop headers)
@@ -170,10 +181,13 @@ async fn handle_h3_request(
}
h3_response = h3_response.header(name, value);
}
let h3_response = h3_response.body(())
let h3_response = h3_response
.body(())
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
stream.send_response(h3_response).await
stream
.send_response(h3_response)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
// Stream response body back over H3
@@ -182,7 +196,9 @@ async fn handle_h3_request(
match frame {
Ok(frame) => {
if let Ok(data) = frame.into_data() {
stream.send_data(data).await
stream
.send_data(data)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
}
}
@@ -194,7 +210,9 @@ async fn handle_h3_request(
}
// Finish the H3 stream (send QUIC FIN)
stream.finish().await
stream
.finish()
.await
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
Ok(())
+2 -1
View File
@@ -5,14 +5,15 @@
pub mod connection_pool;
pub mod counting_body;
pub mod h3_service;
pub mod protocol_cache;
pub mod proxy_service;
pub mod request_filter;
mod request_host;
pub mod response_filter;
pub mod shutdown_on_drop;
pub mod template;
pub mod upstream_selector;
pub mod h3_service;
pub use connection_pool::*;
pub use counting_body::*;
@@ -144,10 +144,14 @@ impl FailureState {
}
fn all_expired(&self) -> bool {
let h2_expired = self.h2.as_ref()
let h2_expired = self
.h2
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true);
let h3_expired = self.h3.as_ref()
let h3_expired = self
.h3
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true);
h2_expired && h3_expired
@@ -355,9 +359,13 @@ impl ProtocolCache {
let record = entry.get_mut(protocol);
let (consecutive, new_cooldown) = match record {
Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => {
Some(existing)
if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
{
// Still within the "recent" window — escalate
let c = existing.consecutive_failures.saturating_add(1)
let c = existing
.consecutive_failures
.saturating_add(1)
.min(PROTOCOL_FAILURE_ESCALATION_CAP);
(c, escalate_cooldown(c))
}
@@ -394,8 +402,13 @@ impl ProtocolCache {
if protocol == DetectedProtocol::H1 {
return false;
}
self.failures.get(key)
.and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown))
self.failures
.get(key)
.and_then(|entry| {
entry
.get(protocol)
.map(|r| r.failed_at.elapsed() < r.cooldown)
})
.unwrap_or(false)
}
@@ -464,19 +477,18 @@ impl ProtocolCache {
/// Snapshot all non-expired cache entries for metrics/UI display.
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
self.cache.iter()
self.cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
.map(|entry| {
let key = entry.key();
let val = entry.value();
let failure_info = self.failures.get(key);
let (h2_sup, h2_cd, h2_cons) = Self::suppression_info(
failure_info.as_deref().and_then(|f| f.h2.as_ref()),
);
let (h3_sup, h3_cd, h3_cons) = Self::suppression_info(
failure_info.as_deref().and_then(|f| f.h3.as_ref()),
);
let (h2_sup, h2_cd, h2_cons) =
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
let (h3_sup, h3_cd, h3_cons) =
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
ProtocolCacheEntry {
host: key.host.clone(),
@@ -507,7 +519,13 @@ impl ProtocolCache {
/// Insert a protocol detection result with an optional H3 port.
/// Logs protocol transitions when overwriting an existing entry.
/// No suppression check — callers must check before calling.
fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>, reason: &str) {
fn insert_internal(
&self,
key: ProtocolCacheKey,
protocol: DetectedProtocol,
h3_port: Option<u16>,
reason: &str,
) {
// Check for existing entry to log protocol transitions
if let Some(existing) = self.cache.get(&key) {
if existing.protocol != protocol {
@@ -522,7 +540,9 @@ impl ProtocolCache {
// Evict oldest entry if at capacity
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
let oldest = self.cache.iter()
let oldest = self
.cache
.iter()
.min_by_key(|entry| entry.value().last_accessed_at)
.map(|entry| entry.key().clone());
if let Some(oldest_key) = oldest {
@@ -531,13 +551,16 @@ impl ProtocolCache {
}
let now = Instant::now();
self.cache.insert(key, CachedEntry {
protocol,
detected_at: now,
last_accessed_at: now,
last_probed_at: now,
h3_port,
});
self.cache.insert(
key,
CachedEntry {
protocol,
detected_at: now,
last_accessed_at: now,
last_probed_at: now,
h3_port,
},
);
}
/// Reduce a failure record's remaining cooldown to `target`, if it currently
@@ -582,26 +605,34 @@ impl ProtocolCache {
interval.tick().await;
// Clean expired cache entries (sliding TTL based on last_accessed_at)
let expired: Vec<ProtocolCacheKey> = cache.iter()
let expired: Vec<ProtocolCacheKey> = cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
.map(|entry| entry.key().clone())
.collect();
if !expired.is_empty() {
debug!("Protocol cache cleanup: removing {} expired entries", expired.len());
debug!(
"Protocol cache cleanup: removing {} expired entries",
expired.len()
);
for key in expired {
cache.remove(&key);
}
}
// Clean fully-expired failure entries
let expired_failures: Vec<ProtocolCacheKey> = failures.iter()
let expired_failures: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|entry| entry.value().all_expired())
.map(|entry| entry.key().clone())
.collect();
if !expired_failures.is_empty() {
debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len());
debug!(
"Protocol cache cleanup: removing {} expired failure entries",
expired_failures.len()
);
for key in expired_failures {
failures.remove(&key);
}
@@ -609,7 +640,8 @@ impl ProtocolCache {
// Safety net: cap failures map at 2× max entries
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
let oldest: Vec<ProtocolCacheKey> = failures.iter()
let oldest: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|e| e.value().all_expired())
.map(|e| e.key().clone())
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
File diff suppressed because it is too large Load Diff
+141 -41
View File
@@ -4,13 +4,15 @@ use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode};
use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
use crate::request_host::extract_request_host;
pub struct RequestFilter;
@@ -35,16 +37,13 @@ impl RequestFilter {
let client_ip = peer_addr.ip();
let request_path = req.uri().path();
// IP filter (domain-aware: extract Host header for domain-scoped entries)
// IP filter (domain-aware: use the same host extraction as route matching)
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(&client_ip);
let host = req.headers()
.get("host")
.and_then(|v| v.to_str().ok())
.map(|h| h.split(':').next().unwrap_or(h));
let host = extract_request_host(req);
if !filter.is_allowed_for_domain(&normalized, host) {
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
}
@@ -59,16 +58,15 @@ impl RequestFilter {
!limiter.check(&key)
} else {
// Create a per-check limiter (less ideal but works for non-shared case)
let limiter = RateLimiter::new(
rate_limit_config.max_requests,
rate_limit_config.window,
);
let limiter =
RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window);
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key)
};
if should_block {
let message = rate_limit_config.error_message
let message = rate_limit_config
.error_message
.as_deref()
.unwrap_or("Rate limit exceeded");
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
@@ -84,36 +82,48 @@ impl RequestFilter {
if let Some(ref basic_auth) = security.basic_auth {
if basic_auth.enabled {
// Check basic auth exclude paths
let skip_basic = basic_auth.exclude_paths.as_ref()
let skip_basic = basic_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
if !skip_basic {
let users: Vec<(String, String)> = basic_auth.users.iter()
let users: Vec<(String, String)> = basic_auth
.users
.iter()
.map(|c| (c.username.clone(), c.password.clone()))
.collect();
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
let auth_header = req.headers()
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header {
Some(header) => {
if validator.validate(header).is_none() {
return Some(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Invalid credentials"))
.unwrap());
return Some(
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(
"WWW-Authenticate",
validator.www_authenticate(),
)
.body(boxed_body("Invalid credentials"))
.unwrap(),
);
}
}
None => {
return Some(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Authentication required"))
.unwrap());
return Some(
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Authentication required"))
.unwrap(),
);
}
}
}
@@ -124,7 +134,9 @@ impl RequestFilter {
if let Some(ref jwt_auth) = security.jwt_auth {
if jwt_auth.enabled {
// Check JWT auth exclude paths
let skip_jwt = jwt_auth.exclude_paths.as_ref()
let skip_jwt = jwt_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
@@ -136,18 +148,25 @@ impl RequestFilter {
jwt_auth.audience.as_deref(),
);
let auth_header = req.headers()
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header.and_then(JwtValidator::extract_token) {
Some(token) => {
if validator.validate(token).is_err() {
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
return Some(error_response(
StatusCode::UNAUTHORIZED,
"Invalid token",
));
}
}
None => {
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
return Some(error_response(
StatusCode::UNAUTHORIZED,
"Bearer token required",
));
}
}
}
@@ -209,7 +228,11 @@ impl RequestFilter {
/// Check IP-based security (for use in passthrough / TCP-level connections).
/// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering.
/// Returns true if allowed, false if blocked.
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr, domain: Option<&str>) -> bool {
pub fn check_ip_security(
security: &RouteSecurity,
client_ip: &std::net::IpAddr,
domain: Option<&str>,
) -> bool {
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
@@ -238,19 +261,28 @@ impl RequestFilter {
return None;
}
let origin = req.headers()
let origin = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*");
Some(Response::builder()
.status(StatusCode::NO_CONTENT)
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
.header("Access-Control-Max-Age", "86400")
.body(boxed_body(""))
.unwrap())
Some(
Response::builder()
.status(StatusCode::NO_CONTENT)
.header("Access-Control-Allow-Origin", origin)
.header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, PATCH, OPTIONS",
)
.header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, X-Requested-With",
)
.header("Access-Control-Max-Age", "86400")
.body(boxed_body(""))
.unwrap(),
)
}
}
@@ -265,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes,
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http_body_util::Empty;
use hyper::{Request, StatusCode, Version};
use rustproxy_config::{IpAllowEntry, RouteSecurity};
use super::RequestFilter;
fn domain_scoped_security() -> RouteSecurity {
RouteSecurity {
ip_allow_list: Some(vec![IpAllowEntry::DomainScoped {
ip: "10.8.0.2".to_string(),
domains: vec!["*.abc.xyz".to_string()],
}]),
ip_block_list: None,
max_connections: None,
authentication: None,
rate_limit: None,
basic_auth: None,
jwt_auth: None,
}
}
fn peer_addr() -> std::net::SocketAddr {
std::net::SocketAddr::from(([10, 8, 0, 2], 4242))
}
fn request(uri: &str, version: Version, host: Option<&str>) -> Request<Empty<Bytes>> {
let mut builder = Request::builder().uri(uri).version(version);
if let Some(host) = host {
builder = builder.header("host", host);
}
builder.body(Empty::<Bytes>::new()).unwrap()
}
#[test]
fn domain_scoped_acl_allows_uri_authority_without_host_header() {
let security = domain_scoped_security();
let req = request("https://outline.abc.xyz/", Version::HTTP_2, None);
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
}
#[test]
fn domain_scoped_acl_allows_host_header_with_port() {
let security = domain_scoped_security();
let req = request(
"https://unrelated.invalid/",
Version::HTTP_11,
Some("outline.abc.xyz:443"),
);
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
}
#[test]
fn domain_scoped_acl_denies_non_matching_uri_authority() {
let security = domain_scoped_security();
let req = request("https://outline.other.xyz/", Version::HTTP_2, None);
let response = RequestFilter::apply(&security, &req, &peer_addr())
.expect("non-matching domain should be denied");
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
}
@@ -0,0 +1,43 @@
use hyper::Request;
/// Extract the effective request host for routing and scoped ACL checks.
///
/// Prefer the explicit `Host` header when present, otherwise fall back to the
/// URI authority used by HTTP/2 and HTTP/3 requests.
pub(crate) fn extract_request_host<B>(req: &Request<B>) -> Option<&str> {
req.headers()
.get("host")
.and_then(|value| value.to_str().ok())
.map(|host| host.split(':').next().unwrap_or(host))
.or_else(|| req.uri().host())
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http_body_util::Empty;
use hyper::Request;
use super::extract_request_host;
#[test]
fn extracts_host_header_before_uri_authority() {
let req = Request::builder()
.uri("https://uri.abc.xyz/test")
.header("host", "header.abc.xyz:443")
.body(Empty::<Bytes>::new())
.unwrap();
assert_eq!(extract_request_host(&req), Some("header.abc.xyz"));
}
#[test]
fn falls_back_to_uri_authority_when_host_header_missing() {
let req = Request::builder()
.uri("https://outline.abc.xyz/test")
.body(Empty::<Bytes>::new())
.unwrap();
assert_eq!(extract_request_host(&req), Some("outline.abc.xyz"));
}
}
@@ -3,7 +3,7 @@
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig;
use crate::template::{RequestContext, expand_template};
use crate::template::{expand_template, RequestContext};
pub struct ResponseFilter;
@@ -11,12 +11,17 @@ impl ResponseFilter {
/// Apply response headers from route config and CORS settings.
/// If a `RequestContext` is provided, template variables in header values will be expanded.
/// Also injects Alt-Svc header for routes with HTTP/3 enabled.
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
pub fn apply_headers(
route: &RouteConfig,
headers: &mut HeaderMap,
req_ctx: Option<&RequestContext>,
) {
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
if let Some(ref udp) = route.action.udp {
if let Some(ref quic) = udp.quic {
if quic.enable_http3.unwrap_or(false) {
let port = quic.alt_svc_port
let port = quic
.alt_svc_port
.or_else(|| req_ctx.map(|c| c.port))
.unwrap_or(443);
let max_age = quic.alt_svc_max_age.unwrap_or(86400);
@@ -63,10 +68,7 @@ impl ResponseFilter {
headers.insert("access-control-allow-origin", val);
}
} else {
headers.insert(
"access-control-allow-origin",
HeaderValue::from_static("*"),
);
headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
}
// Allow-Methods
@@ -62,17 +62,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for Shutdown
self.inner.as_ref().unwrap().is_write_vectored()
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
if result.is_ready() {
@@ -93,7 +87,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
tokio::io::AsyncWriteExt::shutdown(&mut stream),
).await;
)
.await;
// stream is dropped here — all resources freed
});
}
+6 -2
View File
@@ -39,7 +39,8 @@ pub fn expand_headers(
headers: &HashMap<String, String>,
ctx: &RequestContext,
) -> HashMap<String, String> {
headers.iter()
headers
.iter()
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
.collect()
}
@@ -150,7 +151,10 @@ mod tests {
let ctx = test_context();
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
let result = expand_template(template, &ctx);
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
assert_eq!(
result,
"192.168.1.100|example.com|443|/api/v1/users|api-route|42"
);
}
#[test]
@@ -7,7 +7,7 @@ use std::sync::Arc;
use std::sync::Mutex;
use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget};
/// Upstream selection result.
pub struct UpstreamSelection {
@@ -51,21 +51,19 @@ impl UpstreamSelector {
}
// Determine load balancing algorithm
let algorithm = target.load_balancing.as_ref()
let algorithm = target
.load_balancing
.as_ref()
.map(|lb| &lb.algorithm)
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
let idx = match algorithm {
LoadBalancingAlgorithm::RoundRobin => {
self.round_robin_select(&hosts, port)
}
LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port),
LoadBalancingAlgorithm::IpHash => {
let hash = Self::ip_hash(client_addr);
hash % hosts.len()
}
LoadBalancingAlgorithm::LeastConnections => {
self.least_connections_select(&hosts, port)
}
LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port),
};
UpstreamSelection {
@@ -78,9 +76,7 @@ impl UpstreamSelector {
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
let key = format!("{}:{}", hosts[0], port);
let mut counters = self.round_robin.lock().unwrap();
let counter = counters
.entry(key)
.or_insert_with(|| AtomicUsize::new(0));
let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0));
let idx = counter.fetch_add(1, Ordering::Relaxed);
idx % hosts.len()
}
@@ -91,7 +87,8 @@ impl UpstreamSelector {
for (i, host) in hosts.iter().enumerate() {
let key = format!("{}:{}", host, port);
let conns = self.active_connections
let conns = self
.active_connections
.get(&key)
.map(|entry| entry.value().load(Ordering::Relaxed))
.unwrap_or(0);
@@ -228,13 +225,21 @@ mod tests {
selector.connection_started("backend:8080");
selector.connection_started("backend:8080");
assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
2
);
selector.connection_ended("backend:8080");
assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
1
);