fix(rustproxy-http,rustproxy-metrics): fix domain-scoped request host detection and harden connection metrics cleanup

This commit is contained in:
2026-04-14 00:54:12 +00:00
parent 6ee7237357
commit a53a2c4ca5
15 changed files with 1813 additions and 590 deletions
+8
View File
@@ -1,5 +1,13 @@
# Changelog # Changelog
## 2026-04-14 - 27.7.1 - fix(rustproxy-http,rustproxy-metrics)
fix domain-scoped request host detection and harden connection metrics cleanup
- use a shared request host extractor that falls back to URI authority so domain-scoped IP allow lists work for HTTP/2 and HTTP/3 requests without a Host header
- add request filter and host extraction tests covering domain-scoped ACL behavior
- prevent connection counters from underflowing during close handling and clean up per-IP metrics entries more safely
- normalize tracked domain keys in metrics to reduce duplicate entries caused by case or trailing-dot variations
## 2026-04-13 - 27.7.0 - feat(smart-proxy) ## 2026-04-13 - 27.7.0 - feat(smart-proxy)
add typed Rust config serialization and regex header contract coverage add typed Rust config serialization and regex header contract coverage
@@ -3,8 +3,8 @@
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes. //! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection). //! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use bytes::Bytes; use bytes::Bytes;
@@ -105,13 +105,19 @@ impl ConnectionPool {
/// Try to check out an idle HTTP/1.1 sender for the given key. /// Try to check out an idle HTTP/1.1 sender for the given key.
/// Returns `None` if no usable idle connection exists. /// Returns `None` if no usable idle connection exists.
pub fn checkout_h1(&self, key: &PoolKey) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> { pub fn checkout_h1(
&self,
key: &PoolKey,
) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
let mut entry = self.h1_pool.get_mut(key)?; let mut entry = self.h1_pool.get_mut(key)?;
let idles = entry.value_mut(); let idles = entry.value_mut();
while let Some(idle) = idles.pop() { while let Some(idle) = idles.pop() {
// Check if the connection is still alive and ready // Check if the connection is still alive and ready
if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() { if idle.idle_since.elapsed() < IDLE_TIMEOUT
&& idle.sender.is_ready()
&& !idle.sender.is_closed()
{
// H1 pool hit — no logging on hot path // H1 pool hit — no logging on hot path
return Some(idle.sender); return Some(idle.sender);
} }
@@ -128,7 +134,11 @@ impl ConnectionPool {
/// Return an HTTP/1.1 sender to the pool after the response body has been prepared. /// Return an HTTP/1.1 sender to the pool after the response body has been prepared.
/// The caller should NOT call this if the sender is closed or not ready. /// The caller should NOT call this if the sender is closed or not ready.
pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>) { pub fn checkin_h1(
&self,
key: PoolKey,
sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
) {
if sender.is_closed() || !sender.is_ready() { if sender.is_closed() || !sender.is_ready() {
return; // Don't pool broken connections return; // Don't pool broken connections
} }
@@ -145,7 +155,10 @@ impl ConnectionPool {
/// Try to get a cloned HTTP/2 sender for the given key. /// Try to get a cloned HTTP/2 sender for the given key.
/// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove. /// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove.
pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> { pub fn checkout_h2(
&self,
key: &PoolKey,
) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
let entry = self.h2_pool.get(key)?; let entry = self.h2_pool.get(key)?;
let pooled = entry.value(); let pooled = entry.value();
let age = pooled.created_at.elapsed(); let age = pooled.created_at.elapsed();
@@ -184,16 +197,23 @@ impl ConnectionPool {
/// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry. /// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry.
/// The caller should pass this generation to the connection driver so it can use /// The caller should pass this generation to the connection driver so it can use
/// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction. /// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction.
pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>) -> u64 { pub fn register_h2(
&self,
key: PoolKey,
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed); let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
if sender.is_closed() { if sender.is_closed() {
return gen; return gen;
} }
self.h2_pool.insert(key, PooledH2 { self.h2_pool.insert(
sender, key,
created_at: Instant::now(), PooledH2 {
generation: gen, sender,
}); created_at: Instant::now(),
generation: gen,
},
);
gen gen
} }
@@ -204,7 +224,11 @@ impl ConnectionPool {
pub fn checkout_h3( pub fn checkout_h3(
&self, &self,
key: &PoolKey, key: &PoolKey,
) -> Option<(h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, quinn::Connection, Duration)> { ) -> Option<(
h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
quinn::Connection,
Duration,
)> {
let entry = self.h3_pool.get(key)?; let entry = self.h3_pool.get(key)?;
let pooled = entry.value(); let pooled = entry.value();
let age = pooled.created_at.elapsed(); let age = pooled.created_at.elapsed();
@@ -234,12 +258,15 @@ impl ConnectionPool {
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
) -> u64 { ) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed); let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
self.h3_pool.insert(key, PooledH3 { self.h3_pool.insert(
send_request, key,
connection, PooledH3 {
created_at: Instant::now(), send_request,
generation: gen, connection,
}); created_at: Instant::now(),
generation: gen,
},
);
gen gen
} }
@@ -280,7 +307,9 @@ impl ConnectionPool {
// Evict dead or aged-out H2 connections // Evict dead or aged-out H2 connections
let mut dead_h2 = Vec::new(); let mut dead_h2 = Vec::new();
for entry in h2_pool.iter() { for entry in h2_pool.iter() {
if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE { if entry.value().sender.is_closed()
|| entry.value().created_at.elapsed() >= MAX_H2_AGE
{
dead_h2.push(entry.key().clone()); dead_h2.push(entry.key().clone());
} }
} }
@@ -1,8 +1,8 @@
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector. //! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use bytes::Bytes; use bytes::Bytes;
@@ -76,7 +76,11 @@ impl<B> CountingBody<B> {
/// Set the connection-level activity tracker. When set, each data frame /// Set the connection-level activity tracker. When set, each data frame
/// updates this timestamp to prevent the idle watchdog from killing the /// updates this timestamp to prevent the idle watchdog from killing the
/// connection during active body streaming. /// connection during active body streaming.
pub fn with_connection_activity(mut self, activity: Arc<AtomicU64>, start: std::time::Instant) -> Self { pub fn with_connection_activity(
mut self,
activity: Arc<AtomicU64>,
start: std::time::Instant,
) -> Self {
self.connection_activity = Some(activity); self.connection_activity = Some(activity);
self.activity_start = Some(start); self.activity_start = Some(start);
self self
@@ -134,7 +138,9 @@ where
} }
// Keep the connection-level idle watchdog alive on every frame // Keep the connection-level idle watchdog alive on every frame
// (this is just one atomic store — cheap enough per-frame) // (this is just one atomic store — cheap enough per-frame)
if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) { if let (Some(activity), Some(start)) =
(&this.connection_activity, &this.activity_start)
{
activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
} }
} }
+28 -10
View File
@@ -11,8 +11,8 @@ use std::task::{Context, Poll};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use http_body::Frame; use http_body::Frame;
use http_body_util::BodyExt;
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use tracing::{debug, warn}; use tracing::{debug, warn};
use rustproxy_config::RouteConfig; use rustproxy_config::RouteConfig;
@@ -49,7 +49,8 @@ impl H3ProxyService {
debug!("HTTP/3 connection from {} on port {}", remote_addr, port); debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
// Track frontend H3 connection for the QUIC connection's lifetime. // Track frontend H3 connection for the QUIC connection's lifetime.
let _frontend_h3_guard = ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3"); let _frontend_h3_guard =
ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> = let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
h3::server::builder() h3::server::builder()
@@ -92,8 +93,15 @@ impl H3ProxyService {
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_h3_request( if let Err(e) = handle_h3_request(
request, stream, port, remote_addr, &http_proxy, request_cancel, request,
).await { stream,
port,
remote_addr,
&http_proxy,
request_cancel,
)
.await
{
debug!("HTTP/3 request error from {}: {}", remote_addr, e); debug!("HTTP/3 request error from {}: {}", remote_addr, e);
} }
}); });
@@ -153,11 +161,14 @@ async fn handle_h3_request(
// Delegate to HttpProxyService — same backend path as TCP/HTTP: // Delegate to HttpProxyService — same backend path as TCP/HTTP:
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto. // route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
let conn_activity = ConnActivity::new_standalone(); let conn_activity = ConnActivity::new_standalone();
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await let response = http_proxy
.handle_request(req, peer_addr, port, cancel, conn_activity)
.await
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?; .map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
// Await the body reader to get the H3 stream back // Await the body reader to get the H3 stream back
let mut stream = body_reader.await let mut stream = body_reader
.await
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?; .map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
// Send response headers over H3 (skip hop-by-hop headers) // Send response headers over H3 (skip hop-by-hop headers)
@@ -170,10 +181,13 @@ async fn handle_h3_request(
} }
h3_response = h3_response.header(name, value); h3_response = h3_response.header(name, value);
} }
let h3_response = h3_response.body(()) let h3_response = h3_response
.body(())
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
stream.send_response(h3_response).await stream
.send_response(h3_response)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
// Stream response body back over H3 // Stream response body back over H3
@@ -182,7 +196,9 @@ async fn handle_h3_request(
match frame { match frame {
Ok(frame) => { Ok(frame) => {
if let Ok(data) = frame.into_data() { if let Ok(data) = frame.into_data() {
stream.send_data(data).await stream
.send_data(data)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
} }
} }
@@ -194,7 +210,9 @@ async fn handle_h3_request(
} }
// Finish the H3 stream (send QUIC FIN) // Finish the H3 stream (send QUIC FIN)
stream.finish().await stream
.finish()
.await
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
Ok(()) Ok(())
+2 -1
View File
@@ -5,14 +5,15 @@
pub mod connection_pool; pub mod connection_pool;
pub mod counting_body; pub mod counting_body;
pub mod h3_service;
pub mod protocol_cache; pub mod protocol_cache;
pub mod proxy_service; pub mod proxy_service;
pub mod request_filter; pub mod request_filter;
mod request_host;
pub mod response_filter; pub mod response_filter;
pub mod shutdown_on_drop; pub mod shutdown_on_drop;
pub mod template; pub mod template;
pub mod upstream_selector; pub mod upstream_selector;
pub mod h3_service;
pub use connection_pool::*; pub use connection_pool::*;
pub use counting_body::*; pub use counting_body::*;
@@ -144,10 +144,14 @@ impl FailureState {
} }
fn all_expired(&self) -> bool { fn all_expired(&self) -> bool {
let h2_expired = self.h2.as_ref() let h2_expired = self
.h2
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown) .map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true); .unwrap_or(true);
let h3_expired = self.h3.as_ref() let h3_expired = self
.h3
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown) .map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true); .unwrap_or(true);
h2_expired && h3_expired h2_expired && h3_expired
@@ -355,9 +359,13 @@ impl ProtocolCache {
let record = entry.get_mut(protocol); let record = entry.get_mut(protocol);
let (consecutive, new_cooldown) = match record { let (consecutive, new_cooldown) = match record {
Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => { Some(existing)
if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
{
// Still within the "recent" window — escalate // Still within the "recent" window — escalate
let c = existing.consecutive_failures.saturating_add(1) let c = existing
.consecutive_failures
.saturating_add(1)
.min(PROTOCOL_FAILURE_ESCALATION_CAP); .min(PROTOCOL_FAILURE_ESCALATION_CAP);
(c, escalate_cooldown(c)) (c, escalate_cooldown(c))
} }
@@ -394,8 +402,13 @@ impl ProtocolCache {
if protocol == DetectedProtocol::H1 { if protocol == DetectedProtocol::H1 {
return false; return false;
} }
self.failures.get(key) self.failures
.and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown)) .get(key)
.and_then(|entry| {
entry
.get(protocol)
.map(|r| r.failed_at.elapsed() < r.cooldown)
})
.unwrap_or(false) .unwrap_or(false)
} }
@@ -464,19 +477,18 @@ impl ProtocolCache {
/// Snapshot all non-expired cache entries for metrics/UI display. /// Snapshot all non-expired cache entries for metrics/UI display.
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> { pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
self.cache.iter() self.cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL) .filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
.map(|entry| { .map(|entry| {
let key = entry.key(); let key = entry.key();
let val = entry.value(); let val = entry.value();
let failure_info = self.failures.get(key); let failure_info = self.failures.get(key);
let (h2_sup, h2_cd, h2_cons) = Self::suppression_info( let (h2_sup, h2_cd, h2_cons) =
failure_info.as_deref().and_then(|f| f.h2.as_ref()), Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
); let (h3_sup, h3_cd, h3_cons) =
let (h3_sup, h3_cd, h3_cons) = Self::suppression_info( Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
failure_info.as_deref().and_then(|f| f.h3.as_ref()),
);
ProtocolCacheEntry { ProtocolCacheEntry {
host: key.host.clone(), host: key.host.clone(),
@@ -507,7 +519,13 @@ impl ProtocolCache {
/// Insert a protocol detection result with an optional H3 port. /// Insert a protocol detection result with an optional H3 port.
/// Logs protocol transitions when overwriting an existing entry. /// Logs protocol transitions when overwriting an existing entry.
/// No suppression check — callers must check before calling. /// No suppression check — callers must check before calling.
fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>, reason: &str) { fn insert_internal(
&self,
key: ProtocolCacheKey,
protocol: DetectedProtocol,
h3_port: Option<u16>,
reason: &str,
) {
// Check for existing entry to log protocol transitions // Check for existing entry to log protocol transitions
if let Some(existing) = self.cache.get(&key) { if let Some(existing) = self.cache.get(&key) {
if existing.protocol != protocol { if existing.protocol != protocol {
@@ -522,7 +540,9 @@ impl ProtocolCache {
// Evict oldest entry if at capacity // Evict oldest entry if at capacity
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) { if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
let oldest = self.cache.iter() let oldest = self
.cache
.iter()
.min_by_key(|entry| entry.value().last_accessed_at) .min_by_key(|entry| entry.value().last_accessed_at)
.map(|entry| entry.key().clone()); .map(|entry| entry.key().clone());
if let Some(oldest_key) = oldest { if let Some(oldest_key) = oldest {
@@ -531,13 +551,16 @@ impl ProtocolCache {
} }
let now = Instant::now(); let now = Instant::now();
self.cache.insert(key, CachedEntry { self.cache.insert(
protocol, key,
detected_at: now, CachedEntry {
last_accessed_at: now, protocol,
last_probed_at: now, detected_at: now,
h3_port, last_accessed_at: now,
}); last_probed_at: now,
h3_port,
},
);
} }
/// Reduce a failure record's remaining cooldown to `target`, if it currently /// Reduce a failure record's remaining cooldown to `target`, if it currently
@@ -582,26 +605,34 @@ impl ProtocolCache {
interval.tick().await; interval.tick().await;
// Clean expired cache entries (sliding TTL based on last_accessed_at) // Clean expired cache entries (sliding TTL based on last_accessed_at)
let expired: Vec<ProtocolCacheKey> = cache.iter() let expired: Vec<ProtocolCacheKey> = cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL) .filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
.map(|entry| entry.key().clone()) .map(|entry| entry.key().clone())
.collect(); .collect();
if !expired.is_empty() { if !expired.is_empty() {
debug!("Protocol cache cleanup: removing {} expired entries", expired.len()); debug!(
"Protocol cache cleanup: removing {} expired entries",
expired.len()
);
for key in expired { for key in expired {
cache.remove(&key); cache.remove(&key);
} }
} }
// Clean fully-expired failure entries // Clean fully-expired failure entries
let expired_failures: Vec<ProtocolCacheKey> = failures.iter() let expired_failures: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|entry| entry.value().all_expired()) .filter(|entry| entry.value().all_expired())
.map(|entry| entry.key().clone()) .map(|entry| entry.key().clone())
.collect(); .collect();
if !expired_failures.is_empty() { if !expired_failures.is_empty() {
debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len()); debug!(
"Protocol cache cleanup: removing {} expired failure entries",
expired_failures.len()
);
for key in expired_failures { for key in expired_failures {
failures.remove(&key); failures.remove(&key);
} }
@@ -609,7 +640,8 @@ impl ProtocolCache {
// Safety net: cap failures map at 2× max entries // Safety net: cap failures map at 2× max entries
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 { if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
let oldest: Vec<ProtocolCacheKey> = failures.iter() let oldest: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|e| e.value().all_expired()) .filter(|e| e.value().all_expired())
.map(|e| e.key().clone()) .map(|e| e.key().clone())
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES) .take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
File diff suppressed because it is too large Load Diff
+141 -41
View File
@@ -4,13 +4,15 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use bytes::Bytes; use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode};
use rustproxy_config::RouteSecurity; use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter}; use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
use crate::request_host::extract_request_host;
pub struct RequestFilter; pub struct RequestFilter;
@@ -35,16 +37,13 @@ impl RequestFilter {
let client_ip = peer_addr.ip(); let client_ip = peer_addr.ip();
let request_path = req.uri().path(); let request_path = req.uri().path();
// IP filter (domain-aware: extract Host header for domain-scoped entries) // IP filter (domain-aware: use the same host extraction as route matching)
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]); let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]); let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block); let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(&client_ip); let normalized = IpFilter::normalize_ip(&client_ip);
let host = req.headers() let host = extract_request_host(req);
.get("host")
.and_then(|v| v.to_str().ok())
.map(|h| h.split(':').next().unwrap_or(h));
if !filter.is_allowed_for_domain(&normalized, host) { if !filter.is_allowed_for_domain(&normalized, host) {
return Some(error_response(StatusCode::FORBIDDEN, "Access denied")); return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
} }
@@ -59,16 +58,15 @@ impl RequestFilter {
!limiter.check(&key) !limiter.check(&key)
} else { } else {
// Create a per-check limiter (less ideal but works for non-shared case) // Create a per-check limiter (less ideal but works for non-shared case)
let limiter = RateLimiter::new( let limiter =
rate_limit_config.max_requests, RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window);
rate_limit_config.window,
);
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr); let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key) !limiter.check(&key)
}; };
if should_block { if should_block {
let message = rate_limit_config.error_message let message = rate_limit_config
.error_message
.as_deref() .as_deref()
.unwrap_or("Rate limit exceeded"); .unwrap_or("Rate limit exceeded");
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message)); return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
@@ -84,36 +82,48 @@ impl RequestFilter {
if let Some(ref basic_auth) = security.basic_auth { if let Some(ref basic_auth) = security.basic_auth {
if basic_auth.enabled { if basic_auth.enabled {
// Check basic auth exclude paths // Check basic auth exclude paths
let skip_basic = basic_auth.exclude_paths.as_ref() let skip_basic = basic_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths)) .map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false); .unwrap_or(false);
if !skip_basic { if !skip_basic {
let users: Vec<(String, String)> = basic_auth.users.iter() let users: Vec<(String, String)> = basic_auth
.users
.iter()
.map(|c| (c.username.clone(), c.password.clone())) .map(|c| (c.username.clone(), c.password.clone()))
.collect(); .collect();
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone()); let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
let auth_header = req.headers() let auth_header = req
.headers()
.get("authorization") .get("authorization")
.and_then(|v| v.to_str().ok()); .and_then(|v| v.to_str().ok());
match auth_header { match auth_header {
Some(header) => { Some(header) => {
if validator.validate(header).is_none() { if validator.validate(header).is_none() {
return Some(Response::builder() return Some(
.status(StatusCode::UNAUTHORIZED) Response::builder()
.header("WWW-Authenticate", validator.www_authenticate()) .status(StatusCode::UNAUTHORIZED)
.body(boxed_body("Invalid credentials")) .header(
.unwrap()); "WWW-Authenticate",
validator.www_authenticate(),
)
.body(boxed_body("Invalid credentials"))
.unwrap(),
);
} }
} }
None => { None => {
return Some(Response::builder() return Some(
.status(StatusCode::UNAUTHORIZED) Response::builder()
.header("WWW-Authenticate", validator.www_authenticate()) .status(StatusCode::UNAUTHORIZED)
.body(boxed_body("Authentication required")) .header("WWW-Authenticate", validator.www_authenticate())
.unwrap()); .body(boxed_body("Authentication required"))
.unwrap(),
);
} }
} }
} }
@@ -124,7 +134,9 @@ impl RequestFilter {
if let Some(ref jwt_auth) = security.jwt_auth { if let Some(ref jwt_auth) = security.jwt_auth {
if jwt_auth.enabled { if jwt_auth.enabled {
// Check JWT auth exclude paths // Check JWT auth exclude paths
let skip_jwt = jwt_auth.exclude_paths.as_ref() let skip_jwt = jwt_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths)) .map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false); .unwrap_or(false);
@@ -136,18 +148,25 @@ impl RequestFilter {
jwt_auth.audience.as_deref(), jwt_auth.audience.as_deref(),
); );
let auth_header = req.headers() let auth_header = req
.headers()
.get("authorization") .get("authorization")
.and_then(|v| v.to_str().ok()); .and_then(|v| v.to_str().ok());
match auth_header.and_then(JwtValidator::extract_token) { match auth_header.and_then(JwtValidator::extract_token) {
Some(token) => { Some(token) => {
if validator.validate(token).is_err() { if validator.validate(token).is_err() {
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token")); return Some(error_response(
StatusCode::UNAUTHORIZED,
"Invalid token",
));
} }
} }
None => { None => {
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required")); return Some(error_response(
StatusCode::UNAUTHORIZED,
"Bearer token required",
));
} }
} }
} }
@@ -209,7 +228,11 @@ impl RequestFilter {
/// Check IP-based security (for use in passthrough / TCP-level connections). /// Check IP-based security (for use in passthrough / TCP-level connections).
/// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering. /// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering.
/// Returns true if allowed, false if blocked. /// Returns true if allowed, false if blocked.
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr, domain: Option<&str>) -> bool { pub fn check_ip_security(
security: &RouteSecurity,
client_ip: &std::net::IpAddr,
domain: Option<&str>,
) -> bool {
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]); let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]); let block = security.ip_block_list.as_deref().unwrap_or(&[]);
@@ -238,19 +261,28 @@ impl RequestFilter {
return None; return None;
} }
let origin = req.headers() let origin = req
.headers()
.get("origin") .get("origin")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("*"); .unwrap_or("*");
Some(Response::builder() Some(
.status(StatusCode::NO_CONTENT) Response::builder()
.header("Access-Control-Allow-Origin", origin) .status(StatusCode::NO_CONTENT)
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") .header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") .header(
.header("Access-Control-Max-Age", "86400") "Access-Control-Allow-Methods",
.body(boxed_body("")) "GET, POST, PUT, DELETE, PATCH, OPTIONS",
.unwrap()) )
.header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, X-Requested-With",
)
.header("Access-Control-Max-Age", "86400")
.body(boxed_body(""))
.unwrap(),
)
} }
} }
@@ -265,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes,
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> { fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {})) BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
} }
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::Empty;
    use hyper::{Request, StatusCode, Version};
    use rustproxy_config::{IpAllowEntry, RouteSecurity};

    use super::RequestFilter;

    /// Build a `RouteSecurity` whose only rule is a domain-scoped allow entry:
    /// client `10.8.0.2` is permitted for hosts matching `*.abc.xyz`.
    fn domain_scoped_security() -> RouteSecurity {
        let scoped_entry = IpAllowEntry::DomainScoped {
            ip: "10.8.0.2".to_string(),
            domains: vec!["*.abc.xyz".to_string()],
        };

        RouteSecurity {
            ip_allow_list: Some(vec![scoped_entry]),
            ip_block_list: None,
            max_connections: None,
            authentication: None,
            rate_limit: None,
            basic_auth: None,
            jwt_auth: None,
        }
    }

    /// Peer address whose IP equals the one named in the allow entry above.
    fn peer_addr() -> std::net::SocketAddr {
        std::net::SocketAddr::new(std::net::IpAddr::from([10, 8, 0, 2]), 4242)
    }

    /// Build a body-less request with the given URI and HTTP version,
    /// attaching a `host` header only when one is supplied.
    fn request(uri: &str, version: Version, host: Option<&str>) -> Request<Empty<Bytes>> {
        let base = Request::builder().uri(uri).version(version);
        let builder = match host {
            Some(value) => base.header("host", value),
            None => base,
        };
        builder.body(Empty::<Bytes>::new()).unwrap()
    }

    #[test]
    fn domain_scoped_acl_allows_uri_authority_without_host_header() {
        // No Host header: the filter has to take the domain from the URI authority.
        let security = domain_scoped_security();
        let req = request("https://outline.abc.xyz/", Version::HTTP_2, None);

        let verdict = RequestFilter::apply(&security, &req, &peer_addr());

        assert!(verdict.is_none());
    }

    #[test]
    fn domain_scoped_acl_allows_host_header_with_port() {
        // The Host header (with a port suffix) matches the scoped domain even
        // though the URI authority does not.
        let security = domain_scoped_security();
        let req = request(
            "https://unrelated.invalid/",
            Version::HTTP_11,
            Some("outline.abc.xyz:443"),
        );

        let verdict = RequestFilter::apply(&security, &req, &peer_addr());

        assert!(verdict.is_none());
    }

    #[test]
    fn domain_scoped_acl_denies_non_matching_uri_authority() {
        // The URI authority is outside `*.abc.xyz`, so the request is rejected.
        let security = domain_scoped_security();
        let req = request("https://outline.other.xyz/", Version::HTTP_2, None);

        let denied = RequestFilter::apply(&security, &req, &peer_addr())
            .expect("non-matching domain should be denied");

        assert_eq!(denied.status(), StatusCode::FORBIDDEN);
    }
}
@@ -0,0 +1,43 @@
use hyper::Request;
/// Extract the effective request host for routing and scoped ACL checks.
///
/// Prefer the explicit `Host` header when present, otherwise fall back to the
/// URI authority used by HTTP/2 and HTTP/3 requests.
///
/// A `:port` suffix on the `Host` header is removed. Bracketed IPv6 literals
/// (e.g. `[::1]:8443`) keep their brackets and lose only the port, instead of
/// being cut at the first `:` inside the address.
///
/// Returns `None` when there is no `Host` header that is valid visible ASCII
/// and the URI carries no host either.
pub(crate) fn extract_request_host<B>(req: &Request<B>) -> Option<&str> {
    req.headers()
        .get("host")
        .and_then(|value| value.to_str().ok())
        .map(strip_port)
        .or_else(|| req.uri().host())
}

/// Drop an optional `:port` suffix from a `Host` header value.
fn strip_port(host: &str) -> &str {
    if host.starts_with('[') {
        // IPv6 literal: the address itself contains ':' characters, so the host
        // ends at the closing bracket rather than at the first colon.
        // NOTE(review): keeping the brackets is intended to match what
        // `Uri::host()` yields for the fallback path — confirm against the
        // `http` crate version in use.
        match host.find(']') {
            Some(end) => &host[..=end],
            // Unterminated bracket: malformed input, hand it back untouched.
            None => host,
        }
    } else {
        host.split(':').next().unwrap_or(host)
    }
}

#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::Empty;
    use hyper::Request;

    use super::extract_request_host;

    #[test]
    fn extracts_host_header_before_uri_authority() {
        let req = Request::builder()
            .uri("https://uri.abc.xyz/test")
            .header("host", "header.abc.xyz:443")
            .body(Empty::<Bytes>::new())
            .unwrap();

        assert_eq!(extract_request_host(&req), Some("header.abc.xyz"));
    }

    #[test]
    fn falls_back_to_uri_authority_when_host_header_missing() {
        let req = Request::builder()
            .uri("https://outline.abc.xyz/test")
            .body(Empty::<Bytes>::new())
            .unwrap();

        assert_eq!(extract_request_host(&req), Some("outline.abc.xyz"));
    }

    #[test]
    fn strips_port_from_bracketed_ipv6_host_header() {
        let req = Request::builder()
            .uri("/test")
            .header("host", "[::1]:8443")
            .body(Empty::<Bytes>::new())
            .unwrap();

        assert_eq!(extract_request_host(&req), Some("[::1]"));
    }

    #[test]
    fn keeps_bracketed_ipv6_host_header_without_port() {
        let req = Request::builder()
            .uri("/test")
            .header("host", "[2001:db8::1]")
            .body(Empty::<Bytes>::new())
            .unwrap();

        assert_eq!(extract_request_host(&req), Some("[2001:db8::1]"));
    }
}
@@ -3,7 +3,7 @@
use hyper::header::{HeaderMap, HeaderName, HeaderValue}; use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig; use rustproxy_config::RouteConfig;
use crate::template::{RequestContext, expand_template}; use crate::template::{expand_template, RequestContext};
pub struct ResponseFilter; pub struct ResponseFilter;
@@ -11,12 +11,17 @@ impl ResponseFilter {
/// Apply response headers from route config and CORS settings. /// Apply response headers from route config and CORS settings.
/// If a `RequestContext` is provided, template variables in header values will be expanded. /// If a `RequestContext` is provided, template variables in header values will be expanded.
/// Also injects Alt-Svc header for routes with HTTP/3 enabled. /// Also injects Alt-Svc header for routes with HTTP/3 enabled.
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) { pub fn apply_headers(
route: &RouteConfig,
headers: &mut HeaderMap,
req_ctx: Option<&RequestContext>,
) {
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route // Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
if let Some(ref udp) = route.action.udp { if let Some(ref udp) = route.action.udp {
if let Some(ref quic) = udp.quic { if let Some(ref quic) = udp.quic {
if quic.enable_http3.unwrap_or(false) { if quic.enable_http3.unwrap_or(false) {
let port = quic.alt_svc_port let port = quic
.alt_svc_port
.or_else(|| req_ctx.map(|c| c.port)) .or_else(|| req_ctx.map(|c| c.port))
.unwrap_or(443); .unwrap_or(443);
let max_age = quic.alt_svc_max_age.unwrap_or(86400); let max_age = quic.alt_svc_max_age.unwrap_or(86400);
@@ -63,10 +68,7 @@ impl ResponseFilter {
headers.insert("access-control-allow-origin", val); headers.insert("access-control-allow-origin", val);
} }
} else { } else {
headers.insert( headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
"access-control-allow-origin",
HeaderValue::from_static("*"),
);
} }
// Allow-Methods // Allow-Methods
@@ -62,17 +62,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for Shutdown
self.inner.as_ref().unwrap().is_write_vectored() self.inner.as_ref().unwrap().is_write_vectored()
} }
fn poll_flush( fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx) Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
} }
fn poll_shutdown( fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx); let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
if result.is_ready() { if result.is_ready() {
@@ -93,7 +87,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop
let _ = tokio::time::timeout( let _ = tokio::time::timeout(
std::time::Duration::from_secs(2), std::time::Duration::from_secs(2),
tokio::io::AsyncWriteExt::shutdown(&mut stream), tokio::io::AsyncWriteExt::shutdown(&mut stream),
).await; )
.await;
// stream is dropped here — all resources freed // stream is dropped here — all resources freed
}); });
} }
+6 -2
View File
@@ -39,7 +39,8 @@ pub fn expand_headers(
headers: &HashMap<String, String>, headers: &HashMap<String, String>,
ctx: &RequestContext, ctx: &RequestContext,
) -> HashMap<String, String> { ) -> HashMap<String, String> {
headers.iter() headers
.iter()
.map(|(k, v)| (k.clone(), expand_template(v, ctx))) .map(|(k, v)| (k.clone(), expand_template(v, ctx)))
.collect() .collect()
} }
@@ -150,7 +151,10 @@ mod tests {
let ctx = test_context(); let ctx = test_context();
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}"; let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
let result = expand_template(template, &ctx); let result = expand_template(template, &ctx);
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42"); assert_eq!(
result,
"192.168.1.100|example.com|443|/api/v1/users|api-route|42"
);
} }
#[test] #[test]
@@ -7,7 +7,7 @@ use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use dashmap::DashMap; use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm}; use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget};
/// Upstream selection result. /// Upstream selection result.
pub struct UpstreamSelection { pub struct UpstreamSelection {
@@ -51,21 +51,19 @@ impl UpstreamSelector {
} }
// Determine load balancing algorithm // Determine load balancing algorithm
let algorithm = target.load_balancing.as_ref() let algorithm = target
.load_balancing
.as_ref()
.map(|lb| &lb.algorithm) .map(|lb| &lb.algorithm)
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin); .unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
let idx = match algorithm { let idx = match algorithm {
LoadBalancingAlgorithm::RoundRobin => { LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port),
self.round_robin_select(&hosts, port)
}
LoadBalancingAlgorithm::IpHash => { LoadBalancingAlgorithm::IpHash => {
let hash = Self::ip_hash(client_addr); let hash = Self::ip_hash(client_addr);
hash % hosts.len() hash % hosts.len()
} }
LoadBalancingAlgorithm::LeastConnections => { LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port),
self.least_connections_select(&hosts, port)
}
}; };
UpstreamSelection { UpstreamSelection {
@@ -78,9 +76,7 @@ impl UpstreamSelector {
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize { fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
let key = format!("{}:{}", hosts[0], port); let key = format!("{}:{}", hosts[0], port);
let mut counters = self.round_robin.lock().unwrap(); let mut counters = self.round_robin.lock().unwrap();
let counter = counters let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0));
.entry(key)
.or_insert_with(|| AtomicUsize::new(0));
let idx = counter.fetch_add(1, Ordering::Relaxed); let idx = counter.fetch_add(1, Ordering::Relaxed);
idx % hosts.len() idx % hosts.len()
} }
@@ -91,7 +87,8 @@ impl UpstreamSelector {
for (i, host) in hosts.iter().enumerate() { for (i, host) in hosts.iter().enumerate() {
let key = format!("{}:{}", host, port); let key = format!("{}:{}", host, port);
let conns = self.active_connections let conns = self
.active_connections
.get(&key) .get(&key)
.map(|entry| entry.value().load(Ordering::Relaxed)) .map(|entry| entry.value().load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
@@ -228,13 +225,21 @@ mod tests {
selector.connection_started("backend:8080"); selector.connection_started("backend:8080");
selector.connection_started("backend:8080"); selector.connection_started("backend:8080");
assert_eq!( assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed), selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
2 2
); );
selector.connection_ended("backend:8080"); selector.connection_ended("backend:8080");
assert_eq!( assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed), selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
1 1
); );
+489 -150
View File
@@ -144,6 +144,15 @@ const MAX_BACKENDS_IN_SNAPSHOT: usize = 100;
/// Maximum number of distinct domains tracked per IP (prevents subdomain-spray abuse). /// Maximum number of distinct domains tracked per IP (prevents subdomain-spray abuse).
const MAX_DOMAINS_PER_IP: usize = 256; const MAX_DOMAINS_PER_IP: usize = 256;
/// Normalize a domain name into the key form used for per-domain tracking.
///
/// Surrounding whitespace is removed first, then any trailing dots, and the
/// remainder is lowercased (ASCII only). Returns `None` when nothing is left,
/// so callers can skip blank or dot-only domains.
fn canonicalize_domain_key(domain: &str) -> Option<String> {
    let stripped = domain.trim().trim_end_matches('.');
    if stripped.is_empty() {
        return None;
    }
    Some(stripped.to_ascii_lowercase())
}
/// Metrics collector tracking connections and throughput. /// Metrics collector tracking connections and throughput.
/// ///
/// Design: The hot path (`record_bytes`) is entirely lock-free — it only touches /// Design: The hot path (`record_bytes`) is entirely lock-free — it only touches
@@ -334,25 +343,43 @@ impl MetricsCollector {
/// Record a connection closing. /// Record a connection closing.
pub fn connection_closed(&self, route_id: Option<&str>, source_ip: Option<&str>) { pub fn connection_closed(&self, route_id: Option<&str>, source_ip: Option<&str>) {
self.active_connections.fetch_sub(1, Ordering::Relaxed); self.active_connections
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
if v > 0 {
Some(v - 1)
} else {
None
}
})
.ok();
if let Some(route_id) = route_id { if let Some(route_id) = route_id {
if let Some(counter) = self.route_connections.get(route_id) { if let Some(counter) = self.route_connections.get(route_id) {
let val = counter.load(Ordering::Relaxed); counter
if val > 0 { .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
counter.fetch_sub(1, Ordering::Relaxed); if v > 0 {
} Some(v - 1)
} else {
None
}
})
.ok();
} }
} }
if let Some(ip) = source_ip { if let Some(ip) = source_ip {
if let Some(counter) = self.ip_connections.get(ip) { if let Some(counter) = self.ip_connections.get(ip) {
let val = counter.load(Ordering::Relaxed); let prev = counter
if val > 0 { .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
counter.fetch_sub(1, Ordering::Relaxed); if v > 0 {
} Some(v - 1)
} else {
None
}
})
.ok();
// Clean up zero-count entries to prevent memory growth // Clean up zero-count entries to prevent memory growth
if val <= 1 { if matches!(prev, Some(v) if v <= 1) {
drop(counter); drop(counter);
self.ip_connections.remove(ip); self.ip_connections.remove(ip);
// Evict all per-IP tracking data for this IP // Evict all per-IP tracking data for this IP
@@ -371,17 +398,25 @@ impl MetricsCollector {
/// ///
/// Called per-chunk in the TCP copy loop. Only touches AtomicU64 counters — /// Called per-chunk in the TCP copy loop. Only touches AtomicU64 counters —
/// no Mutex is taken. The throughput trackers are fed during `sample_all()`. /// no Mutex is taken. The throughput trackers are fed during `sample_all()`.
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64, route_id: Option<&str>, source_ip: Option<&str>) { pub fn record_bytes(
&self,
bytes_in: u64,
bytes_out: u64,
route_id: Option<&str>,
source_ip: Option<&str>,
) {
// Short-circuit: only touch counters for the direction that has data. // Short-circuit: only touch counters for the direction that has data.
// CountingBody always calls with one direction zero — skipping the zero // CountingBody always calls with one direction zero — skipping the zero
// direction avoids ~50% of DashMap shard-locked reads per call. // direction avoids ~50% of DashMap shard-locked reads per call.
if bytes_in > 0 { if bytes_in > 0 {
self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed); self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
self.global_pending_tp_in.fetch_add(bytes_in, Ordering::Relaxed); self.global_pending_tp_in
.fetch_add(bytes_in, Ordering::Relaxed);
} }
if bytes_out > 0 { if bytes_out > 0 {
self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed); self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
self.global_pending_tp_out.fetch_add(bytes_out, Ordering::Relaxed); self.global_pending_tp_out
.fetch_add(bytes_out, Ordering::Relaxed);
} }
// Per-route tracking: use get() first (zero-alloc fast path for existing entries), // Per-route tracking: use get() first (zero-alloc fast path for existing entries),
@@ -391,7 +426,8 @@ impl MetricsCollector {
if let Some(counter) = self.route_bytes_in.get(route_id) { if let Some(counter) = self.route_bytes_in.get(route_id) {
counter.fetch_add(bytes_in, Ordering::Relaxed); counter.fetch_add(bytes_in, Ordering::Relaxed);
} else { } else {
self.route_bytes_in.entry(route_id.to_string()) self.route_bytes_in
.entry(route_id.to_string())
.or_insert_with(|| AtomicU64::new(0)) .or_insert_with(|| AtomicU64::new(0))
.fetch_add(bytes_in, Ordering::Relaxed); .fetch_add(bytes_in, Ordering::Relaxed);
} }
@@ -400,7 +436,8 @@ impl MetricsCollector {
if let Some(counter) = self.route_bytes_out.get(route_id) { if let Some(counter) = self.route_bytes_out.get(route_id) {
counter.fetch_add(bytes_out, Ordering::Relaxed); counter.fetch_add(bytes_out, Ordering::Relaxed);
} else { } else {
self.route_bytes_out.entry(route_id.to_string()) self.route_bytes_out
.entry(route_id.to_string())
.or_insert_with(|| AtomicU64::new(0)) .or_insert_with(|| AtomicU64::new(0))
.fetch_add(bytes_out, Ordering::Relaxed); .fetch_add(bytes_out, Ordering::Relaxed);
} }
@@ -408,13 +445,23 @@ impl MetricsCollector {
// Accumulate into per-route pending throughput counters (lock-free) // Accumulate into per-route pending throughput counters (lock-free)
if let Some(entry) = self.route_pending_tp.get(route_id) { if let Some(entry) = self.route_pending_tp.get(route_id) {
if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); } if bytes_in > 0 {
if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); } entry.0.fetch_add(bytes_in, Ordering::Relaxed);
}
if bytes_out > 0 {
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
}
} else { } else {
let entry = self.route_pending_tp.entry(route_id.to_string()) let entry = self
.route_pending_tp
.entry(route_id.to_string())
.or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0))); .or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0)));
if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); } if bytes_in > 0 {
if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); } entry.0.fetch_add(bytes_in, Ordering::Relaxed);
}
if bytes_out > 0 {
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
}
} }
} }
@@ -428,7 +475,8 @@ impl MetricsCollector {
if let Some(counter) = self.ip_bytes_in.get(ip) { if let Some(counter) = self.ip_bytes_in.get(ip) {
counter.fetch_add(bytes_in, Ordering::Relaxed); counter.fetch_add(bytes_in, Ordering::Relaxed);
} else { } else {
self.ip_bytes_in.entry(ip.to_string()) self.ip_bytes_in
.entry(ip.to_string())
.or_insert_with(|| AtomicU64::new(0)) .or_insert_with(|| AtomicU64::new(0))
.fetch_add(bytes_in, Ordering::Relaxed); .fetch_add(bytes_in, Ordering::Relaxed);
} }
@@ -437,7 +485,8 @@ impl MetricsCollector {
if let Some(counter) = self.ip_bytes_out.get(ip) { if let Some(counter) = self.ip_bytes_out.get(ip) {
counter.fetch_add(bytes_out, Ordering::Relaxed); counter.fetch_add(bytes_out, Ordering::Relaxed);
} else { } else {
self.ip_bytes_out.entry(ip.to_string()) self.ip_bytes_out
.entry(ip.to_string())
.or_insert_with(|| AtomicU64::new(0)) .or_insert_with(|| AtomicU64::new(0))
.fetch_add(bytes_out, Ordering::Relaxed); .fetch_add(bytes_out, Ordering::Relaxed);
} }
@@ -445,13 +494,23 @@ impl MetricsCollector {
// Accumulate into per-IP pending throughput counters (lock-free) // Accumulate into per-IP pending throughput counters (lock-free)
if let Some(entry) = self.ip_pending_tp.get(ip) { if let Some(entry) = self.ip_pending_tp.get(ip) {
if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); } if bytes_in > 0 {
if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); } entry.0.fetch_add(bytes_in, Ordering::Relaxed);
}
if bytes_out > 0 {
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
}
} else { } else {
let entry = self.ip_pending_tp.entry(ip.to_string()) let entry = self
.ip_pending_tp
.entry(ip.to_string())
.or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0))); .or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0)));
if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); } if bytes_in > 0 {
if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); } entry.0.fetch_add(bytes_in, Ordering::Relaxed);
}
if bytes_out > 0 {
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
}
} }
} }
} }
@@ -469,9 +528,13 @@ impl MetricsCollector {
/// connection (with SNI domain). The common case (IP + domain both already /// connection (with SNI domain). The common case (IP + domain both already
/// tracked) is two DashMap reads + one atomic increment — zero allocation. /// tracked) is two DashMap reads + one atomic increment — zero allocation.
pub fn record_ip_domain_request(&self, ip: &str, domain: &str) { pub fn record_ip_domain_request(&self, ip: &str, domain: &str) {
let Some(domain) = canonicalize_domain_key(domain) else {
return;
};
// Fast path: IP already tracked, domain already tracked // Fast path: IP already tracked, domain already tracked
if let Some(domains) = self.ip_domain_requests.get(ip) { if let Some(domains) = self.ip_domain_requests.get(ip) {
if let Some(counter) = domains.get(domain) { if let Some(counter) = domains.get(domain.as_str()) {
counter.fetch_add(1, Ordering::Relaxed); counter.fetch_add(1, Ordering::Relaxed);
return; return;
} }
@@ -480,7 +543,7 @@ impl MetricsCollector {
return; return;
} }
domains domains
.entry(domain.to_string()) .entry(domain)
.or_insert_with(|| AtomicU64::new(0)) .or_insert_with(|| AtomicU64::new(0))
.fetch_add(1, Ordering::Relaxed); .fetch_add(1, Ordering::Relaxed);
return; return;
@@ -490,7 +553,7 @@ impl MetricsCollector {
return; return;
} }
let inner = DashMap::with_capacity_and_shard_amount(4, 2); let inner = DashMap::with_capacity_and_shard_amount(4, 2);
inner.insert(domain.to_string(), AtomicU64::new(1)); inner.insert(domain, AtomicU64::new(1));
self.ip_domain_requests.insert(ip.to_string(), inner); self.ip_domain_requests.insert(ip.to_string(), inner);
} }
@@ -504,7 +567,15 @@ impl MetricsCollector {
/// Record a UDP session closed. /// Record a UDP session closed.
pub fn udp_session_closed(&self) { pub fn udp_session_closed(&self) {
self.active_udp_sessions.fetch_sub(1, Ordering::Relaxed); self.active_udp_sessions
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
if v > 0 {
Some(v - 1)
} else {
None
}
})
.ok();
} }
/// Record a UDP datagram (inbound or outbound). /// Record a UDP datagram (inbound or outbound).
@@ -553,9 +624,15 @@ impl MetricsCollector {
let (active, _) = self.frontend_proto_counters(proto); let (active, _) = self.frontend_proto_counters(proto);
// Atomic saturating decrement — avoids TOCTOU race where concurrent // Atomic saturating decrement — avoids TOCTOU race where concurrent
// closes could both read val=1, both subtract, wrapping to u64::MAX. // closes could both read val=1, both subtract, wrapping to u64::MAX.
active.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { active
if v > 0 { Some(v - 1) } else { None } .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
}).ok(); if v > 0 {
Some(v - 1)
} else {
None
}
})
.ok();
} }
/// Record a backend connection opened with a given protocol. /// Record a backend connection opened with a given protocol.
@@ -569,9 +646,15 @@ impl MetricsCollector {
pub fn backend_protocol_closed(&self, proto: &str) { pub fn backend_protocol_closed(&self, proto: &str) {
let (active, _) = self.backend_proto_counters(proto); let (active, _) = self.backend_proto_counters(proto);
// Atomic saturating decrement — see frontend_protocol_closed for rationale. // Atomic saturating decrement — see frontend_protocol_closed for rationale.
active.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| { active
if v > 0 { Some(v - 1) } else { None } .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
}).ok(); if v > 0 {
Some(v - 1)
} else {
None
}
})
.ok();
} }
// ── Per-backend recording methods ── // ── Per-backend recording methods ──
@@ -681,17 +764,28 @@ impl MetricsCollector {
/// Remove per-backend metrics for backends no longer in any route target. /// Remove per-backend metrics for backends no longer in any route target.
pub fn retain_backends(&self, active_backends: &HashSet<String>) { pub fn retain_backends(&self, active_backends: &HashSet<String>) {
self.backend_active.retain(|k, _| active_backends.contains(k)); self.backend_active
self.backend_total.retain(|k, _| active_backends.contains(k)); .retain(|k, _| active_backends.contains(k));
self.backend_protocol.retain(|k, _| active_backends.contains(k)); self.backend_total
self.backend_connect_errors.retain(|k, _| active_backends.contains(k)); .retain(|k, _| active_backends.contains(k));
self.backend_handshake_errors.retain(|k, _| active_backends.contains(k)); self.backend_protocol
self.backend_request_errors.retain(|k, _| active_backends.contains(k)); .retain(|k, _| active_backends.contains(k));
self.backend_connect_time_us.retain(|k, _| active_backends.contains(k)); self.backend_connect_errors
self.backend_connect_count.retain(|k, _| active_backends.contains(k)); .retain(|k, _| active_backends.contains(k));
self.backend_pool_hits.retain(|k, _| active_backends.contains(k)); self.backend_handshake_errors
self.backend_pool_misses.retain(|k, _| active_backends.contains(k)); .retain(|k, _| active_backends.contains(k));
self.backend_h2_failures.retain(|k, _| active_backends.contains(k)); self.backend_request_errors
.retain(|k, _| active_backends.contains(k));
self.backend_connect_time_us
.retain(|k, _| active_backends.contains(k));
self.backend_connect_count
.retain(|k, _| active_backends.contains(k));
self.backend_pool_hits
.retain(|k, _| active_backends.contains(k));
self.backend_pool_misses
.retain(|k, _| active_backends.contains(k));
self.backend_h2_failures
.retain(|k, _| active_backends.contains(k));
} }
/// Take a throughput sample on all trackers (cold path, call at 1Hz or configured interval). /// Take a throughput sample on all trackers (cold path, call at 1Hz or configured interval).
@@ -782,41 +876,64 @@ impl MetricsCollector {
// Safety-net: prune orphaned per-IP entries that have no corresponding // Safety-net: prune orphaned per-IP entries that have no corresponding
// ip_connections entry. This catches any entries created by a race between // ip_connections entry. This catches any entries created by a race between
// record_bytes and connection_closed. // record_bytes and connection_closed.
self.ip_bytes_in.retain(|k, _| self.ip_connections.contains_key(k)); self.ip_bytes_in
self.ip_bytes_out.retain(|k, _| self.ip_connections.contains_key(k)); .retain(|k, _| self.ip_connections.contains_key(k));
self.ip_pending_tp.retain(|k, _| self.ip_connections.contains_key(k)); self.ip_bytes_out
self.ip_throughput.retain(|k, _| self.ip_connections.contains_key(k)); .retain(|k, _| self.ip_connections.contains_key(k));
self.ip_total_connections.retain(|k, _| self.ip_connections.contains_key(k)); self.ip_pending_tp
self.ip_domain_requests.retain(|k, _| self.ip_connections.contains_key(k)); .retain(|k, _| self.ip_connections.contains_key(k));
self.ip_throughput
.retain(|k, _| self.ip_connections.contains_key(k));
self.ip_total_connections
.retain(|k, _| self.ip_connections.contains_key(k));
self.ip_domain_requests
.retain(|k, _| self.ip_connections.contains_key(k));
// Safety-net: prune orphaned backend error/stats entries for backends // Safety-net: prune orphaned backend error/stats entries for backends
// that have no active or total connections (error-only backends). // that have no active or total connections (error-only backends).
// These accumulate when backend_connect_error/backend_handshake_error // These accumulate when backend_connect_error/backend_handshake_error
// create entries but backend_connection_opened is never called. // create entries but backend_connection_opened is never called.
let known_backends: HashSet<String> = self.backend_active.iter() let known_backends: HashSet<String> = self
.backend_active
.iter()
.map(|e| e.key().clone()) .map(|e| e.key().clone())
.chain(self.backend_total.iter().map(|e| e.key().clone())) .chain(self.backend_total.iter().map(|e| e.key().clone()))
.collect(); .collect();
self.backend_connect_errors.retain(|k, _| known_backends.contains(k)); self.backend_connect_errors
self.backend_handshake_errors.retain(|k, _| known_backends.contains(k)); .retain(|k, _| known_backends.contains(k));
self.backend_request_errors.retain(|k, _| known_backends.contains(k)); self.backend_handshake_errors
self.backend_connect_time_us.retain(|k, _| known_backends.contains(k)); .retain(|k, _| known_backends.contains(k));
self.backend_connect_count.retain(|k, _| known_backends.contains(k)); self.backend_request_errors
self.backend_pool_hits.retain(|k, _| known_backends.contains(k)); .retain(|k, _| known_backends.contains(k));
self.backend_pool_misses.retain(|k, _| known_backends.contains(k)); self.backend_connect_time_us
self.backend_h2_failures.retain(|k, _| known_backends.contains(k)); .retain(|k, _| known_backends.contains(k));
self.backend_protocol.retain(|k, _| known_backends.contains(k)); self.backend_connect_count
.retain(|k, _| known_backends.contains(k));
self.backend_pool_hits
.retain(|k, _| known_backends.contains(k));
self.backend_pool_misses
.retain(|k, _| known_backends.contains(k));
self.backend_h2_failures
.retain(|k, _| known_backends.contains(k));
self.backend_protocol
.retain(|k, _| known_backends.contains(k));
} }
/// Remove per-route metrics for route IDs that are no longer active. /// Remove per-route metrics for route IDs that are no longer active.
/// Call this after `update_routes()` to prune stale entries. /// Call this after `update_routes()` to prune stale entries.
pub fn retain_routes(&self, active_route_ids: &HashSet<String>) { pub fn retain_routes(&self, active_route_ids: &HashSet<String>) {
self.route_connections.retain(|k, _| active_route_ids.contains(k)); self.route_connections
self.route_total_connections.retain(|k, _| active_route_ids.contains(k)); .retain(|k, _| active_route_ids.contains(k));
self.route_bytes_in.retain(|k, _| active_route_ids.contains(k)); self.route_total_connections
self.route_bytes_out.retain(|k, _| active_route_ids.contains(k)); .retain(|k, _| active_route_ids.contains(k));
self.route_pending_tp.retain(|k, _| active_route_ids.contains(k)); self.route_bytes_in
self.route_throughput.retain(|k, _| active_route_ids.contains(k)); .retain(|k, _| active_route_ids.contains(k));
self.route_bytes_out
.retain(|k, _| active_route_ids.contains(k));
self.route_pending_tp
.retain(|k, _| active_route_ids.contains(k));
self.route_throughput
.retain(|k, _| active_route_ids.contains(k));
} }
/// Get current active connection count. /// Get current active connection count.
@@ -859,72 +976,97 @@ impl MetricsCollector {
for entry in self.route_total_connections.iter() { for entry in self.route_total_connections.iter() {
let route_id = entry.key().clone(); let route_id = entry.key().clone();
let total = entry.value().load(Ordering::Relaxed); let total = entry.value().load(Ordering::Relaxed);
let active = self.route_connections let active = self
.route_connections
.get(&route_id) .get(&route_id)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let bytes_in = self.route_bytes_in let bytes_in = self
.route_bytes_in
.get(&route_id) .get(&route_id)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let bytes_out = self.route_bytes_out let bytes_out = self
.route_bytes_out
.get(&route_id) .get(&route_id)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let (route_tp_in, route_tp_out, route_recent_in, route_recent_out) = self.route_throughput let (route_tp_in, route_tp_out, route_recent_in, route_recent_out) = self
.route_throughput
.get(&route_id) .get(&route_id)
.and_then(|entry| entry.value().lock().ok().map(|t| { .and_then(|entry| {
let (i_in, i_out) = t.instant(); entry.value().lock().ok().map(|t| {
let (r_in, r_out) = t.recent(); let (i_in, i_out) = t.instant();
(i_in, i_out, r_in, r_out) let (r_in, r_out) = t.recent();
})) (i_in, i_out, r_in, r_out)
})
})
.unwrap_or((0, 0, 0, 0)); .unwrap_or((0, 0, 0, 0));
routes.insert(route_id, RouteMetrics { routes.insert(
active_connections: active, route_id,
total_connections: total, RouteMetrics {
bytes_in, active_connections: active,
bytes_out, total_connections: total,
throughput_in_bytes_per_sec: route_tp_in, bytes_in,
throughput_out_bytes_per_sec: route_tp_out, bytes_out,
throughput_recent_in_bytes_per_sec: route_recent_in, throughput_in_bytes_per_sec: route_tp_in,
throughput_recent_out_bytes_per_sec: route_recent_out, throughput_out_bytes_per_sec: route_tp_out,
}); throughput_recent_in_bytes_per_sec: route_recent_in,
throughput_recent_out_bytes_per_sec: route_recent_out,
},
);
} }
// Collect per-IP metrics — only IPs with active connections or total > 0, // Collect per-IP metrics — only IPs with active connections or total > 0,
// capped at top MAX_IPS_IN_SNAPSHOT sorted by active count // capped at top MAX_IPS_IN_SNAPSHOT sorted by active count
let mut ip_entries: Vec<(String, u64, u64, u64, u64, u64, u64, HashMap<String, u64>)> = Vec::new(); let mut ip_entries: Vec<(String, u64, u64, u64, u64, u64, u64, HashMap<String, u64>)> =
Vec::new();
for entry in self.ip_total_connections.iter() { for entry in self.ip_total_connections.iter() {
let ip = entry.key().clone(); let ip = entry.key().clone();
let total = entry.value().load(Ordering::Relaxed); let total = entry.value().load(Ordering::Relaxed);
let active = self.ip_connections let active = self
.ip_connections
.get(&ip) .get(&ip)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let bytes_in = self.ip_bytes_in let bytes_in = self
.ip_bytes_in
.get(&ip) .get(&ip)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let bytes_out = self.ip_bytes_out let bytes_out = self
.ip_bytes_out
.get(&ip) .get(&ip)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let (tp_in, tp_out) = self.ip_throughput let (tp_in, tp_out) = self
.ip_throughput
.get(&ip) .get(&ip)
.and_then(|entry| entry.value().lock().ok().map(|t| t.instant())) .and_then(|entry| entry.value().lock().ok().map(|t| t.instant()))
.unwrap_or((0, 0)); .unwrap_or((0, 0));
// Collect per-domain request counts for this IP // Collect per-domain request counts for this IP
let domain_requests = self.ip_domain_requests let domain_requests = self
.ip_domain_requests
.get(&ip) .get(&ip)
.map(|domains| { .map(|domains| {
domains.iter() domains
.iter()
.map(|e| (e.key().clone(), e.value().load(Ordering::Relaxed))) .map(|e| (e.key().clone(), e.value().load(Ordering::Relaxed)))
.collect() .collect()
}) })
.unwrap_or_default(); .unwrap_or_default();
ip_entries.push((ip, active, total, bytes_in, bytes_out, tp_in, tp_out, domain_requests)); ip_entries.push((
ip,
active,
total,
bytes_in,
bytes_out,
tp_in,
tp_out,
domain_requests,
));
} }
// Sort by active connections descending, then cap // Sort by active connections descending, then cap
ip_entries.sort_by(|a, b| b.1.cmp(&a.1)); ip_entries.sort_by(|a, b| b.1.cmp(&a.1));
@@ -932,15 +1074,18 @@ impl MetricsCollector {
let mut ips = std::collections::HashMap::new(); let mut ips = std::collections::HashMap::new();
for (ip, active, total, bytes_in, bytes_out, tp_in, tp_out, domain_requests) in ip_entries { for (ip, active, total, bytes_in, bytes_out, tp_in, tp_out, domain_requests) in ip_entries {
ips.insert(ip, IpMetrics { ips.insert(
active_connections: active, ip,
total_connections: total, IpMetrics {
bytes_in, active_connections: active,
bytes_out, total_connections: total,
throughput_in_bytes_per_sec: tp_in, bytes_in,
throughput_out_bytes_per_sec: tp_out, bytes_out,
domain_requests, throughput_in_bytes_per_sec: tp_in,
}); throughput_out_bytes_per_sec: tp_out,
domain_requests,
},
);
} }
// Collect per-backend metrics, capped at top MAX_BACKENDS_IN_SNAPSHOT by total connections // Collect per-backend metrics, capped at top MAX_BACKENDS_IN_SNAPSHOT by total connections
@@ -948,69 +1093,84 @@ impl MetricsCollector {
for entry in self.backend_total.iter() { for entry in self.backend_total.iter() {
let key = entry.key().clone(); let key = entry.key().clone();
let total = entry.value().load(Ordering::Relaxed); let total = entry.value().load(Ordering::Relaxed);
let active = self.backend_active let active = self
.backend_active
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let protocol = self.backend_protocol let protocol = self
.backend_protocol
.get(&key) .get(&key)
.map(|v| v.value().clone()) .map(|v| v.value().clone())
.unwrap_or_else(|| "unknown".to_string()); .unwrap_or_else(|| "unknown".to_string());
let connect_errors = self.backend_connect_errors let connect_errors = self
.backend_connect_errors
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let handshake_errors = self.backend_handshake_errors let handshake_errors = self
.backend_handshake_errors
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let request_errors = self.backend_request_errors let request_errors = self
.backend_request_errors
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let total_connect_time_us = self.backend_connect_time_us let total_connect_time_us = self
.backend_connect_time_us
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let connect_count = self.backend_connect_count let connect_count = self
.backend_connect_count
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let pool_hits = self.backend_pool_hits let pool_hits = self
.backend_pool_hits
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let pool_misses = self.backend_pool_misses let pool_misses = self
.backend_pool_misses
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
let h2_failures = self.backend_h2_failures let h2_failures = self
.backend_h2_failures
.get(&key) .get(&key)
.map(|c| c.load(Ordering::Relaxed)) .map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
backend_entries.push((key, BackendMetrics { backend_entries.push((
active_connections: active, key,
total_connections: total, BackendMetrics {
protocol, active_connections: active,
connect_errors, total_connections: total,
handshake_errors, protocol,
request_errors, connect_errors,
total_connect_time_us, handshake_errors,
connect_count, request_errors,
pool_hits, total_connect_time_us,
pool_misses, connect_count,
h2_failures, pool_hits,
})); pool_misses,
h2_failures,
},
));
} }
// Sort by total connections descending, then cap // Sort by total connections descending, then cap
backend_entries.sort_by(|a, b| b.1.total_connections.cmp(&a.1.total_connections)); backend_entries.sort_by(|a, b| b.1.total_connections.cmp(&a.1.total_connections));
backend_entries.truncate(MAX_BACKENDS_IN_SNAPSHOT); backend_entries.truncate(MAX_BACKENDS_IN_SNAPSHOT);
let backends: std::collections::HashMap<String, BackendMetrics> = backend_entries.into_iter().collect(); let backends: std::collections::HashMap<String, BackendMetrics> =
backend_entries.into_iter().collect();
// HTTP request rates // HTTP request rates
let (http_rps, http_rps_recent) = self.http_request_throughput let (http_rps, http_rps_recent) = self
.http_request_throughput
.lock() .lock()
.map(|t| { .map(|t| {
let (instant, _) = t.instant(); let (instant, _) = t.instant();
@@ -1185,11 +1345,19 @@ mod tests {
// Check IP active connections (drop DashMap refs immediately to avoid deadlock) // Check IP active connections (drop DashMap refs immediately to avoid deadlock)
assert_eq!( assert_eq!(
collector.ip_connections.get("1.2.3.4").unwrap().load(Ordering::Relaxed), collector
.ip_connections
.get("1.2.3.4")
.unwrap()
.load(Ordering::Relaxed),
2 2
); );
assert_eq!( assert_eq!(
collector.ip_connections.get("5.6.7.8").unwrap().load(Ordering::Relaxed), collector
.ip_connections
.get("5.6.7.8")
.unwrap()
.load(Ordering::Relaxed),
1 1
); );
@@ -1207,7 +1375,11 @@ mod tests {
// Close connections // Close connections
collector.connection_closed(Some("route-a"), Some("1.2.3.4")); collector.connection_closed(Some("route-a"), Some("1.2.3.4"));
assert_eq!( assert_eq!(
collector.ip_connections.get("1.2.3.4").unwrap().load(Ordering::Relaxed), collector
.ip_connections
.get("1.2.3.4")
.unwrap()
.load(Ordering::Relaxed),
1 1
); );
@@ -1252,6 +1424,79 @@ mod tests {
assert!(collector.ip_total_connections.get("10.0.0.2").is_some()); assert!(collector.ip_total_connections.get("10.0.0.2").is_some());
} }
#[test]
fn test_connection_closed_saturates_active_gauges() {
    // Checks that surplus `connection_closed` calls leave the global and
    // per-route active gauges reading zero, and that no per-IP entry remains.
    const ROUTE: &str = "route-a";
    const CLIENT_IP: &str = "10.0.0.1";

    let metrics = MetricsCollector::new();

    // Close without any preceding open.
    metrics.connection_closed(Some(ROUTE), Some(CLIENT_IP));
    assert_eq!(metrics.active_connections(), 0);

    // One open, then one close more than was opened.
    metrics.connection_opened(Some(ROUTE), Some(CLIENT_IP));
    for _ in 0..2 {
        metrics.connection_closed(Some(ROUTE), Some(CLIENT_IP));
    }
    assert_eq!(metrics.active_connections(), 0);

    // A missing route entry is treated the same as a gauge reading zero.
    let route_active = metrics
        .route_connections
        .get(ROUTE)
        .map_or(0, |gauge| gauge.load(Ordering::Relaxed));
    assert_eq!(route_active, 0);

    let ip_entry_absent = metrics.ip_connections.get(CLIENT_IP).is_none();
    assert!(ip_entry_absent);
}
#[test]
fn test_udp_session_closed_saturates() {
    // Checks that the snapshot's active UDP session count reads zero both after
    // a close with no open and after closing once more than was opened.
    let metrics = MetricsCollector::new();

    // Close with nothing open.
    metrics.udp_session_closed();
    let after_stray_close = metrics.snapshot();
    assert_eq!(after_stray_close.active_udp_sessions, 0);

    // One open followed by two closes.
    metrics.udp_session_opened();
    for _ in 0..2 {
        metrics.udp_session_closed();
    }
    let after_double_close = metrics.snapshot();
    assert_eq!(after_double_close.active_udp_sessions, 0);
}
#[test]
fn test_ip_domain_requests_are_canonicalized() {
    // Records three spellings of one domain (mixed case, trailing dot,
    // surrounding whitespace) for a single IP and checks that the snapshot
    // reports them as one "example.com" entry with a count of 3.
    const CLIENT_IP: &str = "10.0.0.1";

    let metrics = MetricsCollector::new();
    metrics.connection_opened(Some("route-a"), Some(CLIENT_IP));

    for spelling in ["Example.COM", "example.com.", " example.com "] {
        metrics.record_ip_domain_request(CLIENT_IP, spelling);
    }

    let snap = metrics.snapshot();
    let per_ip = snap.ips.get(CLIENT_IP).unwrap();
    let tracked_domains = per_ip.domain_requests.len();
    assert_eq!(tracked_domains, 1);

    let canonical_hits = per_ip.domain_requests.get("example.com");
    assert_eq!(canonical_hits, Some(&3));
}
#[test]
fn test_protocol_metrics_appear_in_snapshot() {
    // Opens two protocols on each side, closes one per side, and checks the
    // active/total counters the snapshot reports for each of them.
    let metrics = MetricsCollector::new();

    for proto in ["h2", "ws"] {
        metrics.frontend_protocol_opened(proto);
    }
    for proto in ["h3", "ws"] {
        metrics.backend_protocol_opened(proto);
    }
    metrics.frontend_protocol_closed("h2");
    metrics.backend_protocol_closed("h3");

    let snap = metrics.snapshot();
    let frontend = &snap.frontend_protocols;
    let backend = &snap.backend_protocols;

    // Frontend: h2 was opened and closed, ws is still open.
    assert_eq!((frontend.h2_active, frontend.h2_total), (0, 1));
    assert_eq!((frontend.ws_active, frontend.ws_total), (1, 1));

    // Backend: h3 was opened and closed, ws is still open.
    assert_eq!((backend.h3_active, backend.h3_total), (0, 1));
    assert_eq!((backend.ws_active, backend.ws_total), (1, 1));
}
#[test] #[test]
fn test_http_request_tracking() { fn test_http_request_tracking() {
let collector = MetricsCollector::with_retention(60); let collector = MetricsCollector::with_retention(60);
@@ -1326,9 +1571,16 @@ mod tests {
let collector = MetricsCollector::with_retention(60); let collector = MetricsCollector::with_retention(60);
// Manually insert orphaned entries (simulates the race before the guard) // Manually insert orphaned entries (simulates the race before the guard)
collector.ip_bytes_in.insert("orphan-ip".to_string(), AtomicU64::new(100)); collector
collector.ip_bytes_out.insert("orphan-ip".to_string(), AtomicU64::new(200)); .ip_bytes_in
collector.ip_pending_tp.insert("orphan-ip".to_string(), (AtomicU64::new(0), AtomicU64::new(0))); .insert("orphan-ip".to_string(), AtomicU64::new(100));
collector
.ip_bytes_out
.insert("orphan-ip".to_string(), AtomicU64::new(200));
collector.ip_pending_tp.insert(
"orphan-ip".to_string(),
(AtomicU64::new(0), AtomicU64::new(0)),
);
// No ip_connections entry for "orphan-ip" // No ip_connections entry for "orphan-ip"
assert!(collector.ip_connections.get("orphan-ip").is_none()); assert!(collector.ip_connections.get("orphan-ip").is_none());
@@ -1366,17 +1618,59 @@ mod tests {
collector.backend_connection_opened(key, Duration::from_millis(15)); collector.backend_connection_opened(key, Duration::from_millis(15));
collector.backend_connection_opened(key, Duration::from_millis(25)); collector.backend_connection_opened(key, Duration::from_millis(25));
assert_eq!(collector.backend_active.get(key).unwrap().load(Ordering::Relaxed), 2); assert_eq!(
assert_eq!(collector.backend_total.get(key).unwrap().load(Ordering::Relaxed), 2); collector
assert_eq!(collector.backend_connect_count.get(key).unwrap().load(Ordering::Relaxed), 2); .backend_active
.get(key)
.unwrap()
.load(Ordering::Relaxed),
2
);
assert_eq!(
collector
.backend_total
.get(key)
.unwrap()
.load(Ordering::Relaxed),
2
);
assert_eq!(
collector
.backend_connect_count
.get(key)
.unwrap()
.load(Ordering::Relaxed),
2
);
// 15ms + 25ms = 40ms = 40_000us // 15ms + 25ms = 40ms = 40_000us
assert_eq!(collector.backend_connect_time_us.get(key).unwrap().load(Ordering::Relaxed), 40_000); assert_eq!(
collector
.backend_connect_time_us
.get(key)
.unwrap()
.load(Ordering::Relaxed),
40_000
);
// Close one // Close one
collector.backend_connection_closed(key); collector.backend_connection_closed(key);
assert_eq!(collector.backend_active.get(key).unwrap().load(Ordering::Relaxed), 1); assert_eq!(
collector
.backend_active
.get(key)
.unwrap()
.load(Ordering::Relaxed),
1
);
// total stays // total stays
assert_eq!(collector.backend_total.get(key).unwrap().load(Ordering::Relaxed), 2); assert_eq!(
collector
.backend_total
.get(key)
.unwrap()
.load(Ordering::Relaxed),
2
);
// Record errors // Record errors
collector.backend_connect_error(key); collector.backend_connect_error(key);
@@ -1387,12 +1681,54 @@ mod tests {
collector.backend_pool_hit(key); collector.backend_pool_hit(key);
collector.backend_pool_miss(key); collector.backend_pool_miss(key);
assert_eq!(collector.backend_connect_errors.get(key).unwrap().load(Ordering::Relaxed), 1); assert_eq!(
assert_eq!(collector.backend_handshake_errors.get(key).unwrap().load(Ordering::Relaxed), 1); collector
assert_eq!(collector.backend_request_errors.get(key).unwrap().load(Ordering::Relaxed), 1); .backend_connect_errors
assert_eq!(collector.backend_h2_failures.get(key).unwrap().load(Ordering::Relaxed), 1); .get(key)
assert_eq!(collector.backend_pool_hits.get(key).unwrap().load(Ordering::Relaxed), 2); .unwrap()
assert_eq!(collector.backend_pool_misses.get(key).unwrap().load(Ordering::Relaxed), 1); .load(Ordering::Relaxed),
1
);
assert_eq!(
collector
.backend_handshake_errors
.get(key)
.unwrap()
.load(Ordering::Relaxed),
1
);
assert_eq!(
collector
.backend_request_errors
.get(key)
.unwrap()
.load(Ordering::Relaxed),
1
);
assert_eq!(
collector
.backend_h2_failures
.get(key)
.unwrap()
.load(Ordering::Relaxed),
1
);
assert_eq!(
collector
.backend_pool_hits
.get(key)
.unwrap()
.load(Ordering::Relaxed),
2
);
assert_eq!(
collector
.backend_pool_misses
.get(key)
.unwrap()
.load(Ordering::Relaxed),
1
);
// Protocol // Protocol
collector.set_backend_protocol(key, "h1"); collector.set_backend_protocol(key, "h1");
@@ -1449,7 +1785,10 @@ mod tests {
assert!(collector.backend_total.get("stale:8080").is_none()); assert!(collector.backend_total.get("stale:8080").is_none());
assert!(collector.backend_protocol.get("stale:8080").is_none()); assert!(collector.backend_protocol.get("stale:8080").is_none());
assert!(collector.backend_connect_errors.get("stale:8080").is_none()); assert!(collector.backend_connect_errors.get("stale:8080").is_none());
assert!(collector.backend_connect_time_us.get("stale:8080").is_none()); assert!(collector
.backend_connect_time_us
.get("stale:8080")
.is_none());
assert!(collector.backend_connect_count.get("stale:8080").is_none()); assert!(collector.backend_connect_count.get("stale:8080").is_none());
assert!(collector.backend_pool_hits.get("stale:8080").is_none()); assert!(collector.backend_pool_hits.get("stale:8080").is_none());
assert!(collector.backend_pool_misses.get("stale:8080").is_none()); assert!(collector.backend_pool_misses.get("stale:8080").is_none());
+1 -1
View File
@@ -3,6 +3,6 @@
*/ */
export const commitinfo = { export const commitinfo = {
name: '@push.rocks/smartproxy', name: '@push.rocks/smartproxy',
version: '27.7.0', version: '27.7.1',
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.' description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
} }