fix(rustproxy-http,rustproxy-metrics): fix domain-scoped request host detection and harden connection metrics cleanup
This commit is contained in:
@@ -3,8 +3,8 @@
|
||||
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
|
||||
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use bytes::Bytes;
|
||||
@@ -105,13 +105,19 @@ impl ConnectionPool {
|
||||
|
||||
/// Try to check out an idle HTTP/1.1 sender for the given key.
|
||||
/// Returns `None` if no usable idle connection exists.
|
||||
pub fn checkout_h1(&self, key: &PoolKey) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
|
||||
pub fn checkout_h1(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
|
||||
let mut entry = self.h1_pool.get_mut(key)?;
|
||||
let idles = entry.value_mut();
|
||||
|
||||
while let Some(idle) = idles.pop() {
|
||||
// Check if the connection is still alive and ready
|
||||
if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() {
|
||||
if idle.idle_since.elapsed() < IDLE_TIMEOUT
|
||||
&& idle.sender.is_ready()
|
||||
&& !idle.sender.is_closed()
|
||||
{
|
||||
// H1 pool hit — no logging on hot path
|
||||
return Some(idle.sender);
|
||||
}
|
||||
@@ -128,7 +134,11 @@ impl ConnectionPool {
|
||||
|
||||
/// Return an HTTP/1.1 sender to the pool after the response body has been prepared.
|
||||
/// The caller should NOT call this if the sender is closed or not ready.
|
||||
pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>) {
|
||||
pub fn checkin_h1(
|
||||
&self,
|
||||
key: PoolKey,
|
||||
sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||
) {
|
||||
if sender.is_closed() || !sender.is_ready() {
|
||||
return; // Don't pool broken connections
|
||||
}
|
||||
@@ -145,7 +155,10 @@ impl ConnectionPool {
|
||||
|
||||
/// Try to get a cloned HTTP/2 sender for the given key.
|
||||
/// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove.
|
||||
pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
|
||||
pub fn checkout_h2(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
|
||||
let entry = self.h2_pool.get(key)?;
|
||||
let pooled = entry.value();
|
||||
let age = pooled.created_at.elapsed();
|
||||
@@ -184,16 +197,23 @@ impl ConnectionPool {
|
||||
/// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry.
|
||||
/// The caller should pass this generation to the connection driver so it can use
|
||||
/// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction.
|
||||
pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>) -> u64 {
|
||||
pub fn register_h2(
|
||||
&self,
|
||||
key: PoolKey,
|
||||
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||
) -> u64 {
|
||||
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
|
||||
if sender.is_closed() {
|
||||
return gen;
|
||||
}
|
||||
self.h2_pool.insert(key, PooledH2 {
|
||||
sender,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
});
|
||||
self.h2_pool.insert(
|
||||
key,
|
||||
PooledH2 {
|
||||
sender,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
},
|
||||
);
|
||||
gen
|
||||
}
|
||||
|
||||
@@ -204,7 +224,11 @@ impl ConnectionPool {
|
||||
pub fn checkout_h3(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<(h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, quinn::Connection, Duration)> {
|
||||
) -> Option<(
|
||||
h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
|
||||
quinn::Connection,
|
||||
Duration,
|
||||
)> {
|
||||
let entry = self.h3_pool.get(key)?;
|
||||
let pooled = entry.value();
|
||||
let age = pooled.created_at.elapsed();
|
||||
@@ -234,12 +258,15 @@ impl ConnectionPool {
|
||||
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
|
||||
) -> u64 {
|
||||
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
|
||||
self.h3_pool.insert(key, PooledH3 {
|
||||
send_request,
|
||||
connection,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
});
|
||||
self.h3_pool.insert(
|
||||
key,
|
||||
PooledH3 {
|
||||
send_request,
|
||||
connection,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
},
|
||||
);
|
||||
gen
|
||||
}
|
||||
|
||||
@@ -280,7 +307,9 @@ impl ConnectionPool {
|
||||
// Evict dead or aged-out H2 connections
|
||||
let mut dead_h2 = Vec::new();
|
||||
for entry in h2_pool.iter() {
|
||||
if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE {
|
||||
if entry.value().sender.is_closed()
|
||||
|| entry.value().created_at.elapsed() >= MAX_H2_AGE
|
||||
{
|
||||
dead_h2.push(entry.key().clone());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::Bytes;
|
||||
@@ -76,7 +76,11 @@ impl<B> CountingBody<B> {
|
||||
/// Set the connection-level activity tracker. When set, each data frame
|
||||
/// updates this timestamp to prevent the idle watchdog from killing the
|
||||
/// connection during active body streaming.
|
||||
pub fn with_connection_activity(mut self, activity: Arc<AtomicU64>, start: std::time::Instant) -> Self {
|
||||
pub fn with_connection_activity(
|
||||
mut self,
|
||||
activity: Arc<AtomicU64>,
|
||||
start: std::time::Instant,
|
||||
) -> Self {
|
||||
self.connection_activity = Some(activity);
|
||||
self.activity_start = Some(start);
|
||||
self
|
||||
@@ -134,7 +138,9 @@ where
|
||||
}
|
||||
// Keep the connection-level idle watchdog alive on every frame
|
||||
// (this is just one atomic store — cheap enough per-frame)
|
||||
if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) {
|
||||
if let (Some(activity), Some(start)) =
|
||||
(&this.connection_activity, &this.activity_start)
|
||||
{
|
||||
activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ use std::task::{Context, Poll};
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use http_body::Frame;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use rustproxy_config::RouteConfig;
|
||||
@@ -49,7 +49,8 @@ impl H3ProxyService {
|
||||
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
|
||||
|
||||
// Track frontend H3 connection for the QUIC connection's lifetime.
|
||||
let _frontend_h3_guard = ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
|
||||
let _frontend_h3_guard =
|
||||
ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
|
||||
|
||||
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
|
||||
h3::server::builder()
|
||||
@@ -92,8 +93,15 @@ impl H3ProxyService {
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_h3_request(
|
||||
request, stream, port, remote_addr, &http_proxy, request_cancel,
|
||||
).await {
|
||||
request,
|
||||
stream,
|
||||
port,
|
||||
remote_addr,
|
||||
&http_proxy,
|
||||
request_cancel,
|
||||
)
|
||||
.await
|
||||
{
|
||||
debug!("HTTP/3 request error from {}: {}", remote_addr, e);
|
||||
}
|
||||
});
|
||||
@@ -153,11 +161,14 @@ async fn handle_h3_request(
|
||||
// Delegate to HttpProxyService — same backend path as TCP/HTTP:
|
||||
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
|
||||
let conn_activity = ConnActivity::new_standalone();
|
||||
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await
|
||||
let response = http_proxy
|
||||
.handle_request(req, peer_addr, port, cancel, conn_activity)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
|
||||
|
||||
// Await the body reader to get the H3 stream back
|
||||
let mut stream = body_reader.await
|
||||
let mut stream = body_reader
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
|
||||
|
||||
// Send response headers over H3 (skip hop-by-hop headers)
|
||||
@@ -170,10 +181,13 @@ async fn handle_h3_request(
|
||||
}
|
||||
h3_response = h3_response.header(name, value);
|
||||
}
|
||||
let h3_response = h3_response.body(())
|
||||
let h3_response = h3_response
|
||||
.body(())
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
|
||||
|
||||
stream.send_response(h3_response).await
|
||||
stream
|
||||
.send_response(h3_response)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
|
||||
|
||||
// Stream response body back over H3
|
||||
@@ -182,7 +196,9 @@ async fn handle_h3_request(
|
||||
match frame {
|
||||
Ok(frame) => {
|
||||
if let Ok(data) = frame.into_data() {
|
||||
stream.send_data(data).await
|
||||
stream
|
||||
.send_data(data)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
|
||||
}
|
||||
}
|
||||
@@ -194,7 +210,9 @@ async fn handle_h3_request(
|
||||
}
|
||||
|
||||
// Finish the H3 stream (send QUIC FIN)
|
||||
stream.finish().await
|
||||
stream
|
||||
.finish()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -5,14 +5,15 @@
|
||||
|
||||
pub mod connection_pool;
|
||||
pub mod counting_body;
|
||||
pub mod h3_service;
|
||||
pub mod protocol_cache;
|
||||
pub mod proxy_service;
|
||||
pub mod request_filter;
|
||||
mod request_host;
|
||||
pub mod response_filter;
|
||||
pub mod shutdown_on_drop;
|
||||
pub mod template;
|
||||
pub mod upstream_selector;
|
||||
pub mod h3_service;
|
||||
|
||||
pub use connection_pool::*;
|
||||
pub use counting_body::*;
|
||||
|
||||
@@ -144,10 +144,14 @@ impl FailureState {
|
||||
}
|
||||
|
||||
fn all_expired(&self) -> bool {
|
||||
let h2_expired = self.h2.as_ref()
|
||||
let h2_expired = self
|
||||
.h2
|
||||
.as_ref()
|
||||
.map(|r| r.failed_at.elapsed() >= r.cooldown)
|
||||
.unwrap_or(true);
|
||||
let h3_expired = self.h3.as_ref()
|
||||
let h3_expired = self
|
||||
.h3
|
||||
.as_ref()
|
||||
.map(|r| r.failed_at.elapsed() >= r.cooldown)
|
||||
.unwrap_or(true);
|
||||
h2_expired && h3_expired
|
||||
@@ -355,9 +359,13 @@ impl ProtocolCache {
|
||||
|
||||
let record = entry.get_mut(protocol);
|
||||
let (consecutive, new_cooldown) = match record {
|
||||
Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => {
|
||||
Some(existing)
|
||||
if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
|
||||
{
|
||||
// Still within the "recent" window — escalate
|
||||
let c = existing.consecutive_failures.saturating_add(1)
|
||||
let c = existing
|
||||
.consecutive_failures
|
||||
.saturating_add(1)
|
||||
.min(PROTOCOL_FAILURE_ESCALATION_CAP);
|
||||
(c, escalate_cooldown(c))
|
||||
}
|
||||
@@ -394,8 +402,13 @@ impl ProtocolCache {
|
||||
if protocol == DetectedProtocol::H1 {
|
||||
return false;
|
||||
}
|
||||
self.failures.get(key)
|
||||
.and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown))
|
||||
self.failures
|
||||
.get(key)
|
||||
.and_then(|entry| {
|
||||
entry
|
||||
.get(protocol)
|
||||
.map(|r| r.failed_at.elapsed() < r.cooldown)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
@@ -464,19 +477,18 @@ impl ProtocolCache {
|
||||
|
||||
/// Snapshot all non-expired cache entries for metrics/UI display.
|
||||
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
|
||||
self.cache.iter()
|
||||
self.cache
|
||||
.iter()
|
||||
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
|
||||
.map(|entry| {
|
||||
let key = entry.key();
|
||||
let val = entry.value();
|
||||
let failure_info = self.failures.get(key);
|
||||
|
||||
let (h2_sup, h2_cd, h2_cons) = Self::suppression_info(
|
||||
failure_info.as_deref().and_then(|f| f.h2.as_ref()),
|
||||
);
|
||||
let (h3_sup, h3_cd, h3_cons) = Self::suppression_info(
|
||||
failure_info.as_deref().and_then(|f| f.h3.as_ref()),
|
||||
);
|
||||
let (h2_sup, h2_cd, h2_cons) =
|
||||
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
|
||||
let (h3_sup, h3_cd, h3_cons) =
|
||||
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
|
||||
|
||||
ProtocolCacheEntry {
|
||||
host: key.host.clone(),
|
||||
@@ -507,7 +519,13 @@ impl ProtocolCache {
|
||||
/// Insert a protocol detection result with an optional H3 port.
|
||||
/// Logs protocol transitions when overwriting an existing entry.
|
||||
/// No suppression check — callers must check before calling.
|
||||
fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>, reason: &str) {
|
||||
fn insert_internal(
|
||||
&self,
|
||||
key: ProtocolCacheKey,
|
||||
protocol: DetectedProtocol,
|
||||
h3_port: Option<u16>,
|
||||
reason: &str,
|
||||
) {
|
||||
// Check for existing entry to log protocol transitions
|
||||
if let Some(existing) = self.cache.get(&key) {
|
||||
if existing.protocol != protocol {
|
||||
@@ -522,7 +540,9 @@ impl ProtocolCache {
|
||||
|
||||
// Evict oldest entry if at capacity
|
||||
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
|
||||
let oldest = self.cache.iter()
|
||||
let oldest = self
|
||||
.cache
|
||||
.iter()
|
||||
.min_by_key(|entry| entry.value().last_accessed_at)
|
||||
.map(|entry| entry.key().clone());
|
||||
if let Some(oldest_key) = oldest {
|
||||
@@ -531,13 +551,16 @@ impl ProtocolCache {
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
self.cache.insert(key, CachedEntry {
|
||||
protocol,
|
||||
detected_at: now,
|
||||
last_accessed_at: now,
|
||||
last_probed_at: now,
|
||||
h3_port,
|
||||
});
|
||||
self.cache.insert(
|
||||
key,
|
||||
CachedEntry {
|
||||
protocol,
|
||||
detected_at: now,
|
||||
last_accessed_at: now,
|
||||
last_probed_at: now,
|
||||
h3_port,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Reduce a failure record's remaining cooldown to `target`, if it currently
|
||||
@@ -582,26 +605,34 @@ impl ProtocolCache {
|
||||
interval.tick().await;
|
||||
|
||||
// Clean expired cache entries (sliding TTL based on last_accessed_at)
|
||||
let expired: Vec<ProtocolCacheKey> = cache.iter()
|
||||
let expired: Vec<ProtocolCacheKey> = cache
|
||||
.iter()
|
||||
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
|
||||
if !expired.is_empty() {
|
||||
debug!("Protocol cache cleanup: removing {} expired entries", expired.len());
|
||||
debug!(
|
||||
"Protocol cache cleanup: removing {} expired entries",
|
||||
expired.len()
|
||||
);
|
||||
for key in expired {
|
||||
cache.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean fully-expired failure entries
|
||||
let expired_failures: Vec<ProtocolCacheKey> = failures.iter()
|
||||
let expired_failures: Vec<ProtocolCacheKey> = failures
|
||||
.iter()
|
||||
.filter(|entry| entry.value().all_expired())
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
|
||||
if !expired_failures.is_empty() {
|
||||
debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len());
|
||||
debug!(
|
||||
"Protocol cache cleanup: removing {} expired failure entries",
|
||||
expired_failures.len()
|
||||
);
|
||||
for key in expired_failures {
|
||||
failures.remove(&key);
|
||||
}
|
||||
@@ -609,7 +640,8 @@ impl ProtocolCache {
|
||||
|
||||
// Safety net: cap failures map at 2× max entries
|
||||
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
|
||||
let oldest: Vec<ProtocolCacheKey> = failures.iter()
|
||||
let oldest: Vec<ProtocolCacheKey> = failures
|
||||
.iter()
|
||||
.filter(|e| e.value().all_expired())
|
||||
.map(|e| e.key().clone())
|
||||
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,13 +4,15 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
|
||||
use rustproxy_config::RouteSecurity;
|
||||
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
|
||||
use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
|
||||
|
||||
use crate::request_host::extract_request_host;
|
||||
|
||||
pub struct RequestFilter;
|
||||
|
||||
@@ -35,16 +37,13 @@ impl RequestFilter {
|
||||
let client_ip = peer_addr.ip();
|
||||
let request_path = req.uri().path();
|
||||
|
||||
// IP filter (domain-aware: extract Host header for domain-scoped entries)
|
||||
// IP filter (domain-aware: use the same host extraction as route matching)
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(&client_ip);
|
||||
let host = req.headers()
|
||||
.get("host")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|h| h.split(':').next().unwrap_or(h));
|
||||
let host = extract_request_host(req);
|
||||
if !filter.is_allowed_for_domain(&normalized, host) {
|
||||
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
|
||||
}
|
||||
@@ -59,16 +58,15 @@ impl RequestFilter {
|
||||
!limiter.check(&key)
|
||||
} else {
|
||||
// Create a per-check limiter (less ideal but works for non-shared case)
|
||||
let limiter = RateLimiter::new(
|
||||
rate_limit_config.max_requests,
|
||||
rate_limit_config.window,
|
||||
);
|
||||
let limiter =
|
||||
RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window);
|
||||
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
|
||||
!limiter.check(&key)
|
||||
};
|
||||
|
||||
if should_block {
|
||||
let message = rate_limit_config.error_message
|
||||
let message = rate_limit_config
|
||||
.error_message
|
||||
.as_deref()
|
||||
.unwrap_or("Rate limit exceeded");
|
||||
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
|
||||
@@ -84,36 +82,48 @@ impl RequestFilter {
|
||||
if let Some(ref basic_auth) = security.basic_auth {
|
||||
if basic_auth.enabled {
|
||||
// Check basic auth exclude paths
|
||||
let skip_basic = basic_auth.exclude_paths.as_ref()
|
||||
let skip_basic = basic_auth
|
||||
.exclude_paths
|
||||
.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !skip_basic {
|
||||
let users: Vec<(String, String)> = basic_auth.users.iter()
|
||||
let users: Vec<(String, String)> = basic_auth
|
||||
.users
|
||||
.iter()
|
||||
.map(|c| (c.username.clone(), c.password.clone()))
|
||||
.collect();
|
||||
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
|
||||
|
||||
let auth_header = req.headers()
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header {
|
||||
Some(header) => {
|
||||
if validator.validate(header).is_none() {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Invalid credentials"))
|
||||
.unwrap());
|
||||
return Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header(
|
||||
"WWW-Authenticate",
|
||||
validator.www_authenticate(),
|
||||
)
|
||||
.body(boxed_body("Invalid credentials"))
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Authentication required"))
|
||||
.unwrap());
|
||||
return Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Authentication required"))
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -124,7 +134,9 @@ impl RequestFilter {
|
||||
if let Some(ref jwt_auth) = security.jwt_auth {
|
||||
if jwt_auth.enabled {
|
||||
// Check JWT auth exclude paths
|
||||
let skip_jwt = jwt_auth.exclude_paths.as_ref()
|
||||
let skip_jwt = jwt_auth
|
||||
.exclude_paths
|
||||
.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
@@ -136,18 +148,25 @@ impl RequestFilter {
|
||||
jwt_auth.audience.as_deref(),
|
||||
);
|
||||
|
||||
let auth_header = req.headers()
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header.and_then(JwtValidator::extract_token) {
|
||||
Some(token) => {
|
||||
if validator.validate(token).is_err() {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
|
||||
return Some(error_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Invalid token",
|
||||
));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
|
||||
return Some(error_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Bearer token required",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -209,7 +228,11 @@ impl RequestFilter {
|
||||
/// Check IP-based security (for use in passthrough / TCP-level connections).
|
||||
/// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering.
|
||||
/// Returns true if allowed, false if blocked.
|
||||
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr, domain: Option<&str>) -> bool {
|
||||
pub fn check_ip_security(
|
||||
security: &RouteSecurity,
|
||||
client_ip: &std::net::IpAddr,
|
||||
domain: Option<&str>,
|
||||
) -> bool {
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
@@ -238,19 +261,28 @@ impl RequestFilter {
|
||||
return None;
|
||||
}
|
||||
|
||||
let origin = req.headers()
|
||||
let origin = req
|
||||
.headers()
|
||||
.get("origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("*");
|
||||
|
||||
Some(Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
||||
.header("Access-Control-Max-Age", "86400")
|
||||
.body(boxed_body(""))
|
||||
.unwrap())
|
||||
Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, PUT, DELETE, PATCH, OPTIONS",
|
||||
)
|
||||
.header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Content-Type, Authorization, X-Requested-With",
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400")
|
||||
.body(boxed_body(""))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes,
|
||||
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
|
||||
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::Empty;
    use hyper::{Request, StatusCode, Version};
    use rustproxy_config::{IpAllowEntry, RouteSecurity};

    use super::RequestFilter;

    /// Security config with a single domain-scoped allow entry:
    /// IP `10.8.0.2`, restricted to the `*.abc.xyz` domain pattern.
    fn domain_scoped_security() -> RouteSecurity {
        let scoped_entry = IpAllowEntry::DomainScoped {
            ip: "10.8.0.2".to_string(),
            domains: vec!["*.abc.xyz".to_string()],
        };

        RouteSecurity {
            ip_allow_list: Some(vec![scoped_entry]),
            ip_block_list: None,
            max_connections: None,
            authentication: None,
            rate_limit: None,
            basic_auth: None,
            jwt_auth: None,
        }
    }

    /// Client address whose IP matches the allow entry above.
    fn peer_addr() -> std::net::SocketAddr {
        let ip = std::net::Ipv4Addr::new(10, 8, 0, 2);
        std::net::SocketAddr::new(ip.into(), 4242)
    }

    /// Build an empty-bodied request, attaching a `host` header only when given.
    fn request(uri: &str, version: Version, host: Option<&str>) -> Request<Empty<Bytes>> {
        let base = Request::builder().uri(uri).version(version);
        let builder = match host {
            Some(value) => base.header("host", value),
            None => base,
        };

        builder.body(Empty::<Bytes>::new()).unwrap()
    }

    #[test]
    fn domain_scoped_acl_allows_uri_authority_without_host_header() {
        // No Host header: the domain must come from the URI authority.
        let req = request("https://outline.abc.xyz/", Version::HTTP_2, None);
        let security = domain_scoped_security();

        let verdict = RequestFilter::apply(&security, &req, &peer_addr());
        assert!(verdict.is_none());
    }

    #[test]
    fn domain_scoped_acl_allows_host_header_with_port() {
        // The Host header (with a port suffix) wins over the unrelated URI authority.
        let req = request(
            "https://unrelated.invalid/",
            Version::HTTP_11,
            Some("outline.abc.xyz:443"),
        );
        let security = domain_scoped_security();

        let verdict = RequestFilter::apply(&security, &req, &peer_addr());
        assert!(verdict.is_none());
    }

    #[test]
    fn domain_scoped_acl_denies_non_matching_uri_authority() {
        let req = request("https://outline.other.xyz/", Version::HTTP_2, None);
        let security = domain_scoped_security();

        let denied = RequestFilter::apply(&security, &req, &peer_addr())
            .expect("non-matching domain should be denied");
        assert_eq!(denied.status(), StatusCode::FORBIDDEN);
    }
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
use hyper::Request;
|
||||
|
||||
/// Extract the effective request host for routing and scoped ACL checks.
|
||||
///
|
||||
/// Prefer the explicit `Host` header when present, otherwise fall back to the
|
||||
/// URI authority used by HTTP/2 and HTTP/3 requests.
|
||||
pub(crate) fn extract_request_host<B>(req: &Request<B>) -> Option<&str> {
|
||||
req.headers()
|
||||
.get("host")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(|host| host.split(':').next().unwrap_or(host))
|
||||
.or_else(|| req.uri().host())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Empty;
|
||||
use hyper::Request;
|
||||
|
||||
use super::extract_request_host;
|
||||
|
||||
#[test]
|
||||
fn extracts_host_header_before_uri_authority() {
|
||||
let req = Request::builder()
|
||||
.uri("https://uri.abc.xyz/test")
|
||||
.header("host", "header.abc.xyz:443")
|
||||
.body(Empty::<Bytes>::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(extract_request_host(&req), Some("header.abc.xyz"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_uri_authority_when_host_header_missing() {
|
||||
let req = Request::builder()
|
||||
.uri("https://outline.abc.xyz/test")
|
||||
.body(Empty::<Bytes>::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(extract_request_host(&req), Some("outline.abc.xyz"));
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use rustproxy_config::RouteConfig;
|
||||
|
||||
use crate::template::{RequestContext, expand_template};
|
||||
use crate::template::{expand_template, RequestContext};
|
||||
|
||||
pub struct ResponseFilter;
|
||||
|
||||
@@ -11,12 +11,17 @@ impl ResponseFilter {
|
||||
/// Apply response headers from route config and CORS settings.
|
||||
/// If a `RequestContext` is provided, template variables in header values will be expanded.
|
||||
/// Also injects Alt-Svc header for routes with HTTP/3 enabled.
|
||||
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
|
||||
pub fn apply_headers(
|
||||
route: &RouteConfig,
|
||||
headers: &mut HeaderMap,
|
||||
req_ctx: Option<&RequestContext>,
|
||||
) {
|
||||
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
|
||||
if let Some(ref udp) = route.action.udp {
|
||||
if let Some(ref quic) = udp.quic {
|
||||
if quic.enable_http3.unwrap_or(false) {
|
||||
let port = quic.alt_svc_port
|
||||
let port = quic
|
||||
.alt_svc_port
|
||||
.or_else(|| req_ctx.map(|c| c.port))
|
||||
.unwrap_or(443);
|
||||
let max_age = quic.alt_svc_max_age.unwrap_or(86400);
|
||||
@@ -63,10 +68,7 @@ impl ResponseFilter {
|
||||
headers.insert("access-control-allow-origin", val);
|
||||
}
|
||||
} else {
|
||||
headers.insert(
|
||||
"access-control-allow-origin",
|
||||
HeaderValue::from_static("*"),
|
||||
);
|
||||
headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
|
||||
}
|
||||
|
||||
// Allow-Methods
|
||||
|
||||
@@ -62,17 +62,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for Shutdown
|
||||
self.inner.as_ref().unwrap().is_write_vectored()
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
|
||||
if result.is_ready() {
|
||||
@@ -93,7 +87,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
tokio::io::AsyncWriteExt::shutdown(&mut stream),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
// stream is dropped here — all resources freed
|
||||
});
|
||||
}
|
||||
|
||||
@@ -39,7 +39,8 @@ pub fn expand_headers(
|
||||
headers: &HashMap<String, String>,
|
||||
ctx: &RequestContext,
|
||||
) -> HashMap<String, String> {
|
||||
headers.iter()
|
||||
headers
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
|
||||
.collect()
|
||||
}
|
||||
@@ -150,7 +151,10 @@ mod tests {
|
||||
let ctx = test_context();
|
||||
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
|
||||
let result = expand_template(template, &ctx);
|
||||
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
|
||||
assert_eq!(
|
||||
result,
|
||||
"192.168.1.100|example.com|443|/api/v1/users|api-route|42"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
|
||||
use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget};
|
||||
|
||||
/// Upstream selection result.
|
||||
pub struct UpstreamSelection {
|
||||
@@ -51,21 +51,19 @@ impl UpstreamSelector {
|
||||
}
|
||||
|
||||
// Determine load balancing algorithm
|
||||
let algorithm = target.load_balancing.as_ref()
|
||||
let algorithm = target
|
||||
.load_balancing
|
||||
.as_ref()
|
||||
.map(|lb| &lb.algorithm)
|
||||
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
|
||||
|
||||
let idx = match algorithm {
|
||||
LoadBalancingAlgorithm::RoundRobin => {
|
||||
self.round_robin_select(&hosts, port)
|
||||
}
|
||||
LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port),
|
||||
LoadBalancingAlgorithm::IpHash => {
|
||||
let hash = Self::ip_hash(client_addr);
|
||||
hash % hosts.len()
|
||||
}
|
||||
LoadBalancingAlgorithm::LeastConnections => {
|
||||
self.least_connections_select(&hosts, port)
|
||||
}
|
||||
LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port),
|
||||
};
|
||||
|
||||
UpstreamSelection {
|
||||
@@ -78,9 +76,7 @@ impl UpstreamSelector {
|
||||
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
|
||||
let key = format!("{}:{}", hosts[0], port);
|
||||
let mut counters = self.round_robin.lock().unwrap();
|
||||
let counter = counters
|
||||
.entry(key)
|
||||
.or_insert_with(|| AtomicUsize::new(0));
|
||||
let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0));
|
||||
let idx = counter.fetch_add(1, Ordering::Relaxed);
|
||||
idx % hosts.len()
|
||||
}
|
||||
@@ -91,7 +87,8 @@ impl UpstreamSelector {
|
||||
|
||||
for (i, host) in hosts.iter().enumerate() {
|
||||
let key = format!("{}:{}", host, port);
|
||||
let conns = self.active_connections
|
||||
let conns = self
|
||||
.active_connections
|
||||
.get(&key)
|
||||
.map(|entry| entry.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
@@ -228,13 +225,21 @@ mod tests {
|
||||
selector.connection_started("backend:8080");
|
||||
selector.connection_started("backend:8080");
|
||||
assert_eq!(
|
||||
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
|
||||
selector
|
||||
.active_connections
|
||||
.get("backend:8080")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
|
||||
selector.connection_ended("backend:8080");
|
||||
assert_eq!(
|
||||
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
|
||||
selector
|
||||
.active_connections
|
||||
.get("backend:8080")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
|
||||
|
||||
@@ -144,6 +144,15 @@ const MAX_BACKENDS_IN_SNAPSHOT: usize = 100;
|
||||
/// Maximum number of distinct domains tracked per IP (prevents subdomain-spray abuse).
|
||||
const MAX_DOMAINS_PER_IP: usize = 256;
|
||||
|
||||
/// Normalizes a domain name for use as a per-IP request-counter key.
///
/// Surrounding whitespace is trimmed, trailing dots are removed, and ASCII
/// letters are lowercased. Returns `None` when nothing is left afterwards.
fn canonicalize_domain_key(domain: &str) -> Option<String> {
    let key = domain.trim().trim_end_matches('.');
    // Lowercasing never changes emptiness, so the check can run on the slice.
    (!key.is_empty()).then(|| key.to_ascii_lowercase())
}
|
||||
|
||||
/// Metrics collector tracking connections and throughput.
|
||||
///
|
||||
/// Design: The hot path (`record_bytes`) is entirely lock-free — it only touches
|
||||
@@ -334,25 +343,43 @@ impl MetricsCollector {
|
||||
|
||||
/// Record a connection closing.
|
||||
pub fn connection_closed(&self, route_id: Option<&str>, source_ip: Option<&str>) {
|
||||
self.active_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
self.active_connections
|
||||
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
|
||||
if v > 0 {
|
||||
Some(v - 1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
|
||||
if let Some(route_id) = route_id {
|
||||
if let Some(counter) = self.route_connections.get(route_id) {
|
||||
let val = counter.load(Ordering::Relaxed);
|
||||
if val > 0 {
|
||||
counter.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
counter
|
||||
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
|
||||
if v > 0 {
|
||||
Some(v - 1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ip) = source_ip {
|
||||
if let Some(counter) = self.ip_connections.get(ip) {
|
||||
let val = counter.load(Ordering::Relaxed);
|
||||
if val > 0 {
|
||||
counter.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
let prev = counter
|
||||
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
|
||||
if v > 0 {
|
||||
Some(v - 1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
// Clean up zero-count entries to prevent memory growth
|
||||
if val <= 1 {
|
||||
if matches!(prev, Some(v) if v <= 1) {
|
||||
drop(counter);
|
||||
self.ip_connections.remove(ip);
|
||||
// Evict all per-IP tracking data for this IP
|
||||
@@ -371,17 +398,25 @@ impl MetricsCollector {
|
||||
///
|
||||
/// Called per-chunk in the TCP copy loop. Only touches AtomicU64 counters —
|
||||
/// no Mutex is taken. The throughput trackers are fed during `sample_all()`.
|
||||
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64, route_id: Option<&str>, source_ip: Option<&str>) {
|
||||
pub fn record_bytes(
|
||||
&self,
|
||||
bytes_in: u64,
|
||||
bytes_out: u64,
|
||||
route_id: Option<&str>,
|
||||
source_ip: Option<&str>,
|
||||
) {
|
||||
// Short-circuit: only touch counters for the direction that has data.
|
||||
// CountingBody always calls with one direction zero — skipping the zero
|
||||
// direction avoids ~50% of DashMap shard-locked reads per call.
|
||||
if bytes_in > 0 {
|
||||
self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.global_pending_tp_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.global_pending_tp_in
|
||||
.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
}
|
||||
if bytes_out > 0 {
|
||||
self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
self.global_pending_tp_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
self.global_pending_tp_out
|
||||
.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
// Per-route tracking: use get() first (zero-alloc fast path for existing entries),
|
||||
@@ -391,7 +426,8 @@ impl MetricsCollector {
|
||||
if let Some(counter) = self.route_bytes_in.get(route_id) {
|
||||
counter.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
} else {
|
||||
self.route_bytes_in.entry(route_id.to_string())
|
||||
self.route_bytes_in
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
}
|
||||
@@ -400,7 +436,8 @@ impl MetricsCollector {
|
||||
if let Some(counter) = self.route_bytes_out.get(route_id) {
|
||||
counter.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
} else {
|
||||
self.route_bytes_out.entry(route_id.to_string())
|
||||
self.route_bytes_out
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
@@ -408,13 +445,23 @@ impl MetricsCollector {
|
||||
|
||||
// Accumulate into per-route pending throughput counters (lock-free)
|
||||
if let Some(entry) = self.route_pending_tp.get(route_id) {
|
||||
if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); }
|
||||
if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); }
|
||||
if bytes_in > 0 {
|
||||
entry.0.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
}
|
||||
if bytes_out > 0 {
|
||||
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
} else {
|
||||
let entry = self.route_pending_tp.entry(route_id.to_string())
|
||||
let entry = self
|
||||
.route_pending_tp
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0)));
|
||||
if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); }
|
||||
if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); }
|
||||
if bytes_in > 0 {
|
||||
entry.0.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
}
|
||||
if bytes_out > 0 {
|
||||
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -428,7 +475,8 @@ impl MetricsCollector {
|
||||
if let Some(counter) = self.ip_bytes_in.get(ip) {
|
||||
counter.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
} else {
|
||||
self.ip_bytes_in.entry(ip.to_string())
|
||||
self.ip_bytes_in
|
||||
.entry(ip.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
}
|
||||
@@ -437,7 +485,8 @@ impl MetricsCollector {
|
||||
if let Some(counter) = self.ip_bytes_out.get(ip) {
|
||||
counter.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
} else {
|
||||
self.ip_bytes_out.entry(ip.to_string())
|
||||
self.ip_bytes_out
|
||||
.entry(ip.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
@@ -445,13 +494,23 @@ impl MetricsCollector {
|
||||
|
||||
// Accumulate into per-IP pending throughput counters (lock-free)
|
||||
if let Some(entry) = self.ip_pending_tp.get(ip) {
|
||||
if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); }
|
||||
if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); }
|
||||
if bytes_in > 0 {
|
||||
entry.0.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
}
|
||||
if bytes_out > 0 {
|
||||
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
} else {
|
||||
let entry = self.ip_pending_tp.entry(ip.to_string())
|
||||
let entry = self
|
||||
.ip_pending_tp
|
||||
.entry(ip.to_string())
|
||||
.or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0)));
|
||||
if bytes_in > 0 { entry.0.fetch_add(bytes_in, Ordering::Relaxed); }
|
||||
if bytes_out > 0 { entry.1.fetch_add(bytes_out, Ordering::Relaxed); }
|
||||
if bytes_in > 0 {
|
||||
entry.0.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
}
|
||||
if bytes_out > 0 {
|
||||
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -469,9 +528,13 @@ impl MetricsCollector {
|
||||
/// connection (with SNI domain). The common case (IP + domain both already
|
||||
/// tracked) is two DashMap reads + one atomic increment — zero allocation.
|
||||
pub fn record_ip_domain_request(&self, ip: &str, domain: &str) {
|
||||
let Some(domain) = canonicalize_domain_key(domain) else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Fast path: IP already tracked, domain already tracked
|
||||
if let Some(domains) = self.ip_domain_requests.get(ip) {
|
||||
if let Some(counter) = domains.get(domain) {
|
||||
if let Some(counter) = domains.get(domain.as_str()) {
|
||||
counter.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
@@ -480,7 +543,7 @@ impl MetricsCollector {
|
||||
return;
|
||||
}
|
||||
domains
|
||||
.entry(domain.to_string())
|
||||
.entry(domain)
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
@@ -490,7 +553,7 @@ impl MetricsCollector {
|
||||
return;
|
||||
}
|
||||
let inner = DashMap::with_capacity_and_shard_amount(4, 2);
|
||||
inner.insert(domain.to_string(), AtomicU64::new(1));
|
||||
inner.insert(domain, AtomicU64::new(1));
|
||||
self.ip_domain_requests.insert(ip.to_string(), inner);
|
||||
}
|
||||
|
||||
@@ -504,7 +567,15 @@ impl MetricsCollector {
|
||||
|
||||
/// Record a UDP session closed.
|
||||
pub fn udp_session_closed(&self) {
|
||||
self.active_udp_sessions.fetch_sub(1, Ordering::Relaxed);
|
||||
self.active_udp_sessions
|
||||
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
|
||||
if v > 0 {
|
||||
Some(v - 1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
/// Record a UDP datagram (inbound or outbound).
|
||||
@@ -553,9 +624,15 @@ impl MetricsCollector {
|
||||
let (active, _) = self.frontend_proto_counters(proto);
|
||||
// Atomic saturating decrement — avoids TOCTOU race where concurrent
|
||||
// closes could both read val=1, both subtract, wrapping to u64::MAX.
|
||||
active.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
|
||||
if v > 0 { Some(v - 1) } else { None }
|
||||
}).ok();
|
||||
active
|
||||
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
|
||||
if v > 0 {
|
||||
Some(v - 1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
/// Record a backend connection opened with a given protocol.
|
||||
@@ -569,9 +646,15 @@ impl MetricsCollector {
|
||||
pub fn backend_protocol_closed(&self, proto: &str) {
|
||||
let (active, _) = self.backend_proto_counters(proto);
|
||||
// Atomic saturating decrement — see frontend_protocol_closed for rationale.
|
||||
active.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
|
||||
if v > 0 { Some(v - 1) } else { None }
|
||||
}).ok();
|
||||
active
|
||||
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
|
||||
if v > 0 {
|
||||
Some(v - 1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
// ── Per-backend recording methods ──
|
||||
@@ -681,17 +764,28 @@ impl MetricsCollector {
|
||||
|
||||
/// Remove per-backend metrics for backends no longer in any route target.
|
||||
pub fn retain_backends(&self, active_backends: &HashSet<String>) {
|
||||
self.backend_active.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_total.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_protocol.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_connect_errors.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_handshake_errors.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_request_errors.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_connect_time_us.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_connect_count.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_pool_hits.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_pool_misses.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_h2_failures.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_active
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_total
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_protocol
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_connect_errors
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_handshake_errors
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_request_errors
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_connect_time_us
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_connect_count
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_pool_hits
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_pool_misses
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
self.backend_h2_failures
|
||||
.retain(|k, _| active_backends.contains(k));
|
||||
}
|
||||
|
||||
/// Take a throughput sample on all trackers (cold path, call at 1Hz or configured interval).
|
||||
@@ -782,41 +876,64 @@ impl MetricsCollector {
|
||||
// Safety-net: prune orphaned per-IP entries that have no corresponding
|
||||
// ip_connections entry. This catches any entries created by a race between
|
||||
// record_bytes and connection_closed.
|
||||
self.ip_bytes_in.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_bytes_out.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_pending_tp.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_throughput.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_total_connections.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_domain_requests.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_bytes_in
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_bytes_out
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_pending_tp
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_throughput
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_total_connections
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
self.ip_domain_requests
|
||||
.retain(|k, _| self.ip_connections.contains_key(k));
|
||||
|
||||
// Safety-net: prune orphaned backend error/stats entries for backends
|
||||
// that have no active or total connections (error-only backends).
|
||||
// These accumulate when backend_connect_error/backend_handshake_error
|
||||
// create entries but backend_connection_opened is never called.
|
||||
let known_backends: HashSet<String> = self.backend_active.iter()
|
||||
let known_backends: HashSet<String> = self
|
||||
.backend_active
|
||||
.iter()
|
||||
.map(|e| e.key().clone())
|
||||
.chain(self.backend_total.iter().map(|e| e.key().clone()))
|
||||
.collect();
|
||||
self.backend_connect_errors.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_handshake_errors.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_request_errors.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_connect_time_us.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_connect_count.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_pool_hits.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_pool_misses.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_h2_failures.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_protocol.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_connect_errors
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_handshake_errors
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_request_errors
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_connect_time_us
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_connect_count
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_pool_hits
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_pool_misses
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_h2_failures
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
self.backend_protocol
|
||||
.retain(|k, _| known_backends.contains(k));
|
||||
}
|
||||
|
||||
/// Remove per-route metrics for route IDs that are no longer active.
|
||||
/// Call this after `update_routes()` to prune stale entries.
|
||||
pub fn retain_routes(&self, active_route_ids: &HashSet<String>) {
|
||||
self.route_connections.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_total_connections.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_bytes_in.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_bytes_out.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_pending_tp.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_throughput.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_connections
|
||||
.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_total_connections
|
||||
.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_bytes_in
|
||||
.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_bytes_out
|
||||
.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_pending_tp
|
||||
.retain(|k, _| active_route_ids.contains(k));
|
||||
self.route_throughput
|
||||
.retain(|k, _| active_route_ids.contains(k));
|
||||
}
|
||||
|
||||
/// Get current active connection count.
|
||||
@@ -859,72 +976,97 @@ impl MetricsCollector {
|
||||
for entry in self.route_total_connections.iter() {
|
||||
let route_id = entry.key().clone();
|
||||
let total = entry.value().load(Ordering::Relaxed);
|
||||
let active = self.route_connections
|
||||
let active = self
|
||||
.route_connections
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let bytes_in = self.route_bytes_in
|
||||
let bytes_in = self
|
||||
.route_bytes_in
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let bytes_out = self.route_bytes_out
|
||||
let bytes_out = self
|
||||
.route_bytes_out
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
|
||||
let (route_tp_in, route_tp_out, route_recent_in, route_recent_out) = self.route_throughput
|
||||
let (route_tp_in, route_tp_out, route_recent_in, route_recent_out) = self
|
||||
.route_throughput
|
||||
.get(&route_id)
|
||||
.and_then(|entry| entry.value().lock().ok().map(|t| {
|
||||
let (i_in, i_out) = t.instant();
|
||||
let (r_in, r_out) = t.recent();
|
||||
(i_in, i_out, r_in, r_out)
|
||||
}))
|
||||
.and_then(|entry| {
|
||||
entry.value().lock().ok().map(|t| {
|
||||
let (i_in, i_out) = t.instant();
|
||||
let (r_in, r_out) = t.recent();
|
||||
(i_in, i_out, r_in, r_out)
|
||||
})
|
||||
})
|
||||
.unwrap_or((0, 0, 0, 0));
|
||||
|
||||
routes.insert(route_id, RouteMetrics {
|
||||
active_connections: active,
|
||||
total_connections: total,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
throughput_in_bytes_per_sec: route_tp_in,
|
||||
throughput_out_bytes_per_sec: route_tp_out,
|
||||
throughput_recent_in_bytes_per_sec: route_recent_in,
|
||||
throughput_recent_out_bytes_per_sec: route_recent_out,
|
||||
});
|
||||
routes.insert(
|
||||
route_id,
|
||||
RouteMetrics {
|
||||
active_connections: active,
|
||||
total_connections: total,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
throughput_in_bytes_per_sec: route_tp_in,
|
||||
throughput_out_bytes_per_sec: route_tp_out,
|
||||
throughput_recent_in_bytes_per_sec: route_recent_in,
|
||||
throughput_recent_out_bytes_per_sec: route_recent_out,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Collect per-IP metrics — only IPs with active connections or total > 0,
|
||||
// capped at top MAX_IPS_IN_SNAPSHOT sorted by active count
|
||||
let mut ip_entries: Vec<(String, u64, u64, u64, u64, u64, u64, HashMap<String, u64>)> = Vec::new();
|
||||
let mut ip_entries: Vec<(String, u64, u64, u64, u64, u64, u64, HashMap<String, u64>)> =
|
||||
Vec::new();
|
||||
for entry in self.ip_total_connections.iter() {
|
||||
let ip = entry.key().clone();
|
||||
let total = entry.value().load(Ordering::Relaxed);
|
||||
let active = self.ip_connections
|
||||
let active = self
|
||||
.ip_connections
|
||||
.get(&ip)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let bytes_in = self.ip_bytes_in
|
||||
let bytes_in = self
|
||||
.ip_bytes_in
|
||||
.get(&ip)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let bytes_out = self.ip_bytes_out
|
||||
let bytes_out = self
|
||||
.ip_bytes_out
|
||||
.get(&ip)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let (tp_in, tp_out) = self.ip_throughput
|
||||
let (tp_in, tp_out) = self
|
||||
.ip_throughput
|
||||
.get(&ip)
|
||||
.and_then(|entry| entry.value().lock().ok().map(|t| t.instant()))
|
||||
.unwrap_or((0, 0));
|
||||
// Collect per-domain request counts for this IP
|
||||
let domain_requests = self.ip_domain_requests
|
||||
let domain_requests = self
|
||||
.ip_domain_requests
|
||||
.get(&ip)
|
||||
.map(|domains| {
|
||||
domains.iter()
|
||||
domains
|
||||
.iter()
|
||||
.map(|e| (e.key().clone(), e.value().load(Ordering::Relaxed)))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
ip_entries.push((ip, active, total, bytes_in, bytes_out, tp_in, tp_out, domain_requests));
|
||||
ip_entries.push((
|
||||
ip,
|
||||
active,
|
||||
total,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
tp_in,
|
||||
tp_out,
|
||||
domain_requests,
|
||||
));
|
||||
}
|
||||
// Sort by active connections descending, then cap
|
||||
ip_entries.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
@@ -932,15 +1074,18 @@ impl MetricsCollector {
|
||||
|
||||
let mut ips = std::collections::HashMap::new();
|
||||
for (ip, active, total, bytes_in, bytes_out, tp_in, tp_out, domain_requests) in ip_entries {
|
||||
ips.insert(ip, IpMetrics {
|
||||
active_connections: active,
|
||||
total_connections: total,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
throughput_in_bytes_per_sec: tp_in,
|
||||
throughput_out_bytes_per_sec: tp_out,
|
||||
domain_requests,
|
||||
});
|
||||
ips.insert(
|
||||
ip,
|
||||
IpMetrics {
|
||||
active_connections: active,
|
||||
total_connections: total,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
throughput_in_bytes_per_sec: tp_in,
|
||||
throughput_out_bytes_per_sec: tp_out,
|
||||
domain_requests,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Collect per-backend metrics, capped at top MAX_BACKENDS_IN_SNAPSHOT by total connections
|
||||
@@ -948,69 +1093,84 @@ impl MetricsCollector {
|
||||
for entry in self.backend_total.iter() {
|
||||
let key = entry.key().clone();
|
||||
let total = entry.value().load(Ordering::Relaxed);
|
||||
let active = self.backend_active
|
||||
let active = self
|
||||
.backend_active
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let protocol = self.backend_protocol
|
||||
let protocol = self
|
||||
.backend_protocol
|
||||
.get(&key)
|
||||
.map(|v| v.value().clone())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let connect_errors = self.backend_connect_errors
|
||||
let connect_errors = self
|
||||
.backend_connect_errors
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let handshake_errors = self.backend_handshake_errors
|
||||
let handshake_errors = self
|
||||
.backend_handshake_errors
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let request_errors = self.backend_request_errors
|
||||
let request_errors = self
|
||||
.backend_request_errors
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let total_connect_time_us = self.backend_connect_time_us
|
||||
let total_connect_time_us = self
|
||||
.backend_connect_time_us
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let connect_count = self.backend_connect_count
|
||||
let connect_count = self
|
||||
.backend_connect_count
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let pool_hits = self.backend_pool_hits
|
||||
let pool_hits = self
|
||||
.backend_pool_hits
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let pool_misses = self.backend_pool_misses
|
||||
let pool_misses = self
|
||||
.backend_pool_misses
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let h2_failures = self.backend_h2_failures
|
||||
let h2_failures = self
|
||||
.backend_h2_failures
|
||||
.get(&key)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
|
||||
backend_entries.push((key, BackendMetrics {
|
||||
active_connections: active,
|
||||
total_connections: total,
|
||||
protocol,
|
||||
connect_errors,
|
||||
handshake_errors,
|
||||
request_errors,
|
||||
total_connect_time_us,
|
||||
connect_count,
|
||||
pool_hits,
|
||||
pool_misses,
|
||||
h2_failures,
|
||||
}));
|
||||
backend_entries.push((
|
||||
key,
|
||||
BackendMetrics {
|
||||
active_connections: active,
|
||||
total_connections: total,
|
||||
protocol,
|
||||
connect_errors,
|
||||
handshake_errors,
|
||||
request_errors,
|
||||
total_connect_time_us,
|
||||
connect_count,
|
||||
pool_hits,
|
||||
pool_misses,
|
||||
h2_failures,
|
||||
},
|
||||
));
|
||||
}
|
||||
// Sort by total connections descending, then cap
|
||||
backend_entries.sort_by(|a, b| b.1.total_connections.cmp(&a.1.total_connections));
|
||||
backend_entries.truncate(MAX_BACKENDS_IN_SNAPSHOT);
|
||||
|
||||
let backends: std::collections::HashMap<String, BackendMetrics> = backend_entries.into_iter().collect();
|
||||
let backends: std::collections::HashMap<String, BackendMetrics> =
|
||||
backend_entries.into_iter().collect();
|
||||
|
||||
// HTTP request rates
|
||||
let (http_rps, http_rps_recent) = self.http_request_throughput
|
||||
let (http_rps, http_rps_recent) = self
|
||||
.http_request_throughput
|
||||
.lock()
|
||||
.map(|t| {
|
||||
let (instant, _) = t.instant();
|
||||
@@ -1185,11 +1345,19 @@ mod tests {
|
||||
|
||||
// Check IP active connections (drop DashMap refs immediately to avoid deadlock)
|
||||
assert_eq!(
|
||||
collector.ip_connections.get("1.2.3.4").unwrap().load(Ordering::Relaxed),
|
||||
collector
|
||||
.ip_connections
|
||||
.get("1.2.3.4")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
assert_eq!(
|
||||
collector.ip_connections.get("5.6.7.8").unwrap().load(Ordering::Relaxed),
|
||||
collector
|
||||
.ip_connections
|
||||
.get("5.6.7.8")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
|
||||
@@ -1207,7 +1375,11 @@ mod tests {
|
||||
// Close connections
|
||||
collector.connection_closed(Some("route-a"), Some("1.2.3.4"));
|
||||
assert_eq!(
|
||||
collector.ip_connections.get("1.2.3.4").unwrap().load(Ordering::Relaxed),
|
||||
collector
|
||||
.ip_connections
|
||||
.get("1.2.3.4")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
|
||||
@@ -1252,6 +1424,79 @@ mod tests {
|
||||
assert!(collector.ip_total_connections.get("10.0.0.2").is_some());
|
||||
}
|
||||
|
||||
/// Closing more connections than were opened must leave every active gauge
/// at zero and drop the per-IP entry.
#[test]
fn test_connection_closed_saturates_active_gauges() {
    const ROUTE: &str = "route-a";
    const IP: &str = "10.0.0.1";
    let metrics = MetricsCollector::new();

    // A close with no matching open keeps the global gauge at zero.
    metrics.connection_closed(Some(ROUTE), Some(IP));
    assert_eq!(metrics.active_connections(), 0);

    // One open followed by two closes: the surplus close must be absorbed.
    metrics.connection_opened(Some(ROUTE), Some(IP));
    for _ in 0..2 {
        metrics.connection_closed(Some(ROUTE), Some(IP));
    }

    assert_eq!(metrics.active_connections(), 0);
    // A missing route entry counts as zero active connections.
    let route_active = match metrics.route_connections.get(ROUTE) {
        Some(counter) => counter.load(Ordering::Relaxed),
        None => 0,
    };
    assert_eq!(route_active, 0);
    assert!(metrics.ip_connections.get(IP).is_none());
}
|
||||
|
||||
/// Closing more UDP sessions than were opened must leave the active-session
/// gauge reported by the snapshot at zero.
#[test]
fn test_udp_session_closed_saturates() {
    let metrics = MetricsCollector::new();

    // Close with nothing open.
    metrics.udp_session_closed();
    let active_after_stray_close = metrics.snapshot().active_udp_sessions;
    assert_eq!(active_after_stray_close, 0);

    // One open, two closes.
    metrics.udp_session_opened();
    metrics.udp_session_closed();
    metrics.udp_session_closed();
    let active_after_double_close = metrics.snapshot().active_udp_sessions;
    assert_eq!(active_after_double_close, 0);
}
|
||||
|
||||
/// Different spellings of the same host name must be counted under a single
/// lowercase key in the per-IP domain request map.
#[test]
fn test_ip_domain_requests_are_canonicalized() {
    let metrics = MetricsCollector::new();
    metrics.connection_opened(Some("route-a"), Some("10.0.0.1"));

    // Mixed case, trailing dot, and surrounding whitespace, in that order.
    for spelling in ["Example.COM", "example.com.", " example.com "] {
        metrics.record_ip_domain_request("10.0.0.1", spelling);
    }

    let snap = metrics.snapshot();
    let per_ip = snap.ips.get("10.0.0.1").unwrap();
    assert_eq!(per_ip.domain_requests.len(), 1);
    assert_eq!(per_ip.domain_requests.get("example.com"), Some(&3));
}
|
||||
|
||||
/// Frontend and backend protocol open/close events must show up in the
/// snapshot as the expected active and total counts.
#[test]
fn test_protocol_metrics_appear_in_snapshot() {
    let metrics = MetricsCollector::new();

    // Open two protocols on each side, then close one of each.
    metrics.frontend_protocol_opened("h2");
    metrics.frontend_protocol_opened("ws");
    metrics.backend_protocol_opened("h3");
    metrics.backend_protocol_opened("ws");
    metrics.frontend_protocol_closed("h2");
    metrics.backend_protocol_closed("h3");

    let snap = metrics.snapshot();

    let frontend = &snap.frontend_protocols;
    assert_eq!(frontend.h2_active, 0);
    assert_eq!(frontend.h2_total, 1);
    assert_eq!(frontend.ws_active, 1);
    assert_eq!(frontend.ws_total, 1);

    let backend = &snap.backend_protocols;
    assert_eq!(backend.h3_active, 0);
    assert_eq!(backend.h3_total, 1);
    assert_eq!(backend.ws_active, 1);
    assert_eq!(backend.ws_total, 1);
}
|
||||
|
||||
#[test]
|
||||
fn test_http_request_tracking() {
|
||||
let collector = MetricsCollector::with_retention(60);
|
||||
@@ -1326,9 +1571,16 @@ mod tests {
|
||||
let collector = MetricsCollector::with_retention(60);
|
||||
|
||||
// Manually insert orphaned entries (simulates the race before the guard)
|
||||
collector.ip_bytes_in.insert("orphan-ip".to_string(), AtomicU64::new(100));
|
||||
collector.ip_bytes_out.insert("orphan-ip".to_string(), AtomicU64::new(200));
|
||||
collector.ip_pending_tp.insert("orphan-ip".to_string(), (AtomicU64::new(0), AtomicU64::new(0)));
|
||||
collector
|
||||
.ip_bytes_in
|
||||
.insert("orphan-ip".to_string(), AtomicU64::new(100));
|
||||
collector
|
||||
.ip_bytes_out
|
||||
.insert("orphan-ip".to_string(), AtomicU64::new(200));
|
||||
collector.ip_pending_tp.insert(
|
||||
"orphan-ip".to_string(),
|
||||
(AtomicU64::new(0), AtomicU64::new(0)),
|
||||
);
|
||||
|
||||
// No ip_connections entry for "orphan-ip"
|
||||
assert!(collector.ip_connections.get("orphan-ip").is_none());
|
||||
@@ -1366,17 +1618,59 @@ mod tests {
|
||||
collector.backend_connection_opened(key, Duration::from_millis(15));
|
||||
collector.backend_connection_opened(key, Duration::from_millis(25));
|
||||
|
||||
assert_eq!(collector.backend_active.get(key).unwrap().load(Ordering::Relaxed), 2);
|
||||
assert_eq!(collector.backend_total.get(key).unwrap().load(Ordering::Relaxed), 2);
|
||||
assert_eq!(collector.backend_connect_count.get(key).unwrap().load(Ordering::Relaxed), 2);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_active
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_total
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_connect_count
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
// 15ms + 25ms = 40ms = 40_000us
|
||||
assert_eq!(collector.backend_connect_time_us.get(key).unwrap().load(Ordering::Relaxed), 40_000);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_connect_time_us
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
40_000
|
||||
);
|
||||
|
||||
// Close one
|
||||
collector.backend_connection_closed(key);
|
||||
assert_eq!(collector.backend_active.get(key).unwrap().load(Ordering::Relaxed), 1);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_active
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
// total stays
|
||||
assert_eq!(collector.backend_total.get(key).unwrap().load(Ordering::Relaxed), 2);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_total
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
|
||||
// Record errors
|
||||
collector.backend_connect_error(key);
|
||||
@@ -1387,12 +1681,54 @@ mod tests {
|
||||
collector.backend_pool_hit(key);
|
||||
collector.backend_pool_miss(key);
|
||||
|
||||
assert_eq!(collector.backend_connect_errors.get(key).unwrap().load(Ordering::Relaxed), 1);
|
||||
assert_eq!(collector.backend_handshake_errors.get(key).unwrap().load(Ordering::Relaxed), 1);
|
||||
assert_eq!(collector.backend_request_errors.get(key).unwrap().load(Ordering::Relaxed), 1);
|
||||
assert_eq!(collector.backend_h2_failures.get(key).unwrap().load(Ordering::Relaxed), 1);
|
||||
assert_eq!(collector.backend_pool_hits.get(key).unwrap().load(Ordering::Relaxed), 2);
|
||||
assert_eq!(collector.backend_pool_misses.get(key).unwrap().load(Ordering::Relaxed), 1);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_connect_errors
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_handshake_errors
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_request_errors
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_h2_failures
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_pool_hits
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
assert_eq!(
|
||||
collector
|
||||
.backend_pool_misses
|
||||
.get(key)
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
|
||||
// Protocol
|
||||
collector.set_backend_protocol(key, "h1");
|
||||
@@ -1449,7 +1785,10 @@ mod tests {
|
||||
assert!(collector.backend_total.get("stale:8080").is_none());
|
||||
assert!(collector.backend_protocol.get("stale:8080").is_none());
|
||||
assert!(collector.backend_connect_errors.get("stale:8080").is_none());
|
||||
assert!(collector.backend_connect_time_us.get("stale:8080").is_none());
|
||||
assert!(collector
|
||||
.backend_connect_time_us
|
||||
.get("stale:8080")
|
||||
.is_none());
|
||||
assert!(collector.backend_connect_count.get("stale:8080").is_none());
|
||||
assert!(collector.backend_pool_hits.get("stale:8080").is_none());
|
||||
assert!(collector.backend_pool_misses.get("stale:8080").is_none());
|
||||
|
||||
Reference in New Issue
Block a user