Compare commits

..

14 Commits

Author SHA1 Message Date
jkunz 8fa3a51b03 v27.8.2
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-26 11:25:24 +00:00
jkunz 088ef6ab09 fix(rustproxy-metrics): retain inactive per-IP metric buckets briefly to capture final throughput before pruning 2026-04-26 11:25:24 +00:00
jkunz fdb5ec59bc v27.8.1
2026-04-26 09:17:11 +00:00
jkunz 1ea290a085 fix(rustproxy-metrics): preserve high-throughput IPs in metrics snapshots when active-connection rankings are saturated 2026-04-26 09:17:11 +00:00
jkunz cb71f32b90 v27.8.0
2026-04-14 12:43:59 +00:00
jkunz 46155ab12c feat(metrics): add per-domain HTTP request rate metrics 2026-04-14 12:43:59 +00:00
jkunz 490a310b54 v27.7.4
2026-04-14 09:17:55 +00:00
jkunz 6c5180573a fix(rustproxy metrics): use stable route metrics keys across HTTP and passthrough listeners 2026-04-14 09:17:55 +00:00
jkunz 30e5ab308f v27.7.3
2026-04-14 01:14:33 +00:00
jkunz d2a54b3491 fix(repo): no changes detected 2026-04-14 01:14:33 +00:00
jkunz dc922c97df v27.7.2
2026-04-14 00:55:25 +00:00
jkunz 8d1bae7604 fix(docs): clarify metrics documentation for domain normalization and saturating gauges 2026-04-14 00:55:25 +00:00
jkunz 200e86e311 v27.7.1
2026-04-14 00:54:12 +00:00
jkunz a53a2c4ca5 fix(rustproxy-http,rustproxy-metrics): fix domain-scoped request host detection and harden connection metrics cleanup 2026-04-14 00:54:12 +00:00
28 changed files with 2674 additions and 655 deletions
+46
@@ -1,5 +1,51 @@
# Changelog
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
retain inactive per-IP metric buckets briefly to capture final throughput before pruning
- adds a bounded retention window for closed IP buckets so short-lived transfers are still included in per-IP throughput sampling
- prunes expired inactive IP tracking by TTL and hard cap to prevent unbounded metric map growth
- updates Rust and throughput tests to expect zero active connections during the temporary retention period
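The retention behaviour this entry describes can be sketched as a two-pass prune. This is a hypothetical shape only — the names (`RETENTION_TTL`, `MAX_INACTIVE`, `IpBucket`, `IpMetrics`) and constants are illustrative assumptions, not the crate's actual API:

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

const RETENTION_TTL: Duration = Duration::from_secs(5);
const MAX_INACTIVE: usize = 1024;

struct IpBucket {
    bytes: u64,
    // Set when the last connection for this IP closes; the bucket is kept
    // briefly so the final throughput sample still sees its bytes.
    closed_at: Option<Instant>,
}

struct IpMetrics {
    buckets: HashMap<String, IpBucket>,
}

impl IpMetrics {
    fn prune(&mut self, now: Instant) {
        // TTL pass: drop closed buckets whose retention window has expired.
        self.buckets.retain(|_, b| match b.closed_at {
            None => true,
            Some(t) => now.saturating_duration_since(t) < RETENTION_TTL,
        });
        // Hard-cap pass: if many IPs closed at once, TTL alone does not bound
        // the map, so evict the oldest-closed buckets beyond the cap.
        let inactive = self.buckets.values().filter(|b| b.closed_at.is_some()).count();
        if inactive > MAX_INACTIVE {
            let mut closed: Vec<(String, Instant)> = self
                .buckets
                .iter()
                .filter_map(|(k, b)| b.closed_at.map(|t| (k.clone(), t)))
                .collect();
            closed.sort_by_key(|&(_, t)| t);
            for (k, _) in closed.into_iter().take(inactive - MAX_INACTIVE) {
                self.buckets.remove(&k);
            }
        }
    }
}
```

The two passes match the two bullets above: the TTL keeps short-lived transfers visible for one more sampling cycle, while the hard cap prevents unbounded metric-map growth.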
## 2026-04-26 - 27.8.1 - fix(rustproxy-metrics)
preserve high-throughput IPs in metrics snapshots when active-connection rankings are saturated
- Select snapshot IPs using a blend of active-connection and throughput rankings instead of only active connections
- Adds a regression test to ensure a high-bandwidth IP remains included when many other IPs have more active connections
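The blended selection can be sketched as below. The tuple layout and the one-quarter throughput budget are illustrative assumptions, not the actual snapshot code:

```rust
/// Pick up to `cap` IPs for a snapshot, blending two rankings so a
/// high-bandwidth IP is not crowded out by many low-traffic connections.
/// Each stats entry is (ip, active_connections, bytes_per_sec).
fn select_snapshot_ips(stats: &[(String, u32, u64)], cap: usize) -> Vec<String> {
    let mut by_active: Vec<&(String, u32, u64)> = stats.iter().collect();
    by_active.sort_by(|a, b| b.1.cmp(&a.1));
    let mut by_throughput: Vec<&(String, u32, u64)> = stats.iter().collect();
    by_throughput.sort_by(|a, b| b.2.cmp(&a.2));

    let mut out: Vec<String> = Vec::new();
    // Reserve part of the budget for throughput leaders, then fill the
    // remaining slots from the active-connection ranking.
    let tp_slots = cap / 4;
    for e in by_throughput.iter().take(tp_slots) {
        if !out.contains(&e.0) {
            out.push(e.0.clone());
        }
    }
    for e in by_active.iter() {
        if out.len() >= cap {
            break;
        }
        if !out.contains(&e.0) {
            out.push(e.0.clone());
        }
    }
    out
}
```

With a pure active-connection ranking, a single IP pushing most of the bytes over one connection would be dropped once the cap saturates; the reserved throughput slots are what the regression test above guards against.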
## 2026-04-14 - 27.8.0 - feat(metrics)
add per-domain HTTP request rate metrics
- Record canonicalized HTTP request rates per domain in the Rust metrics collector and expose per-second and last-minute values in snapshots.
- Add TypeScript metrics interfaces and adapter support for requests.byDomain().
- Cover HTTP domain rate tracking and ensure TLS passthrough SNI traffic does not affect HTTP request rate metrics.
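One common way to expose both per-second and last-minute values from a single counter is a ring of per-second slots. This is a hedged sketch of that idea only; the real collector's internals may differ:

```rust
/// Minimal per-domain request-rate tracker: a 60-slot ring of per-second
/// counters, yielding a last-second and a last-minute view.
struct DomainRate {
    slots: [u32; 60],
    last_tick: u64, // seconds since tracker start
}

impl DomainRate {
    fn new() -> Self {
        Self { slots: [0; 60], last_tick: 0 }
    }

    fn record(&mut self, now_secs: u64) {
        self.advance(now_secs);
        self.slots[(now_secs % 60) as usize] += 1;
    }

    /// Zero every slot that has rolled past since the last update.
    fn advance(&mut self, now_secs: u64) {
        let gap = now_secs.saturating_sub(self.last_tick).min(60);
        for i in 1..=gap {
            self.slots[((self.last_tick + i) % 60) as usize] = 0;
        }
        self.last_tick = now_secs.max(self.last_tick);
    }

    fn per_second(&mut self, now_secs: u64) -> u32 {
        self.advance(now_secs);
        self.slots[(now_secs % 60) as usize]
    }

    fn last_minute(&mut self, now_secs: u64) -> u32 {
        self.advance(now_secs);
        self.slots.iter().sum()
    }
}
```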
## 2026-04-14 - 27.7.4 - fix(rustproxy metrics)
use stable route metrics keys across HTTP and passthrough listeners
- adds a shared RouteConfig::metrics_key helper that prefers route name and falls back to route id
- updates HTTP, TCP, UDP, and QUIC metrics labeling to use the shared route metrics key consistently
- keeps route cancellation and rate limiter indexing bound to route config ids where required
- adds tests covering metrics key selection behavior
## 2026-04-14 - 27.7.3 - fix(repo)
no changes detected
## 2026-04-14 - 27.7.2 - fix(docs)
clarify metrics documentation for domain normalization and saturating gauges
- Document that per-IP domain keys are normalized to lowercase and have trailing dots stripped before counting.
- Clarify that the saturating close pattern also applies to connection and UDP active gauges.
## 2026-04-14 - 27.7.1 - fix(rustproxy-http,rustproxy-metrics)
fix domain-scoped request host detection and harden connection metrics cleanup
- use a shared request host extractor that falls back to URI authority so domain-scoped IP allow lists work for HTTP/2 and HTTP/3 requests without a Host header
- add request filter and host extraction tests covering domain-scoped ACL behavior
- prevent connection counters from underflowing during close handling and clean up per-IP metrics entries more safely
- normalize tracked domain keys in metrics to reduce duplicate entries caused by case or trailing-dot variations
## 2026-04-13 - 27.7.0 - feat(smart-proxy)
add typed Rust config serialization and regex header contract coverage
+1 -1
@@ -1,6 +1,6 @@
{
"name": "@push.rocks/smartproxy",
- "version": "27.7.0",
+ "version": "27.8.2",
"private": false,
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
"main": "dist_ts/index.js",
+2 -2
@@ -78,7 +78,7 @@ Entries are pruned via `retain_routes()` when routes are removed.
All seven maps for an IP are evicted when its active connection count drops to 0. Safety-net pruning in `sample_all()` catches entries orphaned by races. Snapshots cap at 100 IPs, sorted by active connections descending.
- **Domain request tracking:** Records which domains each frontend IP has requested. Populated from HTTP Host headers (for HTTP/1.1, HTTP/2, HTTP/3 requests) and from SNI (for TLS passthrough connections). Capped at 256 domains per IP (`MAX_DOMAINS_PER_IP`) to prevent subdomain-spray abuse. Inner DashMap uses 2 shards to minimise base memory per IP (~200 bytes). Common case (IP + domain both known) is two DashMap reads + one atomic increment with zero allocation.
+ **Domain request tracking:** Records which domains each frontend IP has requested. Populated from HTTP Host headers (for HTTP/1.1, HTTP/2, HTTP/3 requests) and from SNI (for TLS passthrough connections). Domain keys are normalized to lowercase with any trailing dot stripped so the same hostname does not fragment across multiple counters. Capped at 256 domains per IP (`MAX_DOMAINS_PER_IP`) to prevent subdomain-spray abuse. Inner DashMap uses 2 shards to minimise base memory per IP (~200 bytes). Common case (IP + domain both known) is two DashMap reads + one atomic increment with zero allocation.
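The normalization described above reduces to a one-liner; `normalize_domain` is an illustrative name, not the crate's actual function:

```rust
/// Normalize a domain key as the docs describe: lowercase, with trailing
/// dots stripped, so "Example.COM." and "example.com" share one counter.
fn normalize_domain(raw: &str) -> String {
    raw.trim_end_matches('.').to_ascii_lowercase()
}
```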
### Per-Backend Metrics (keyed by "host:port")
@@ -110,7 +110,7 @@ Tracked via `ProtocolGuard` RAII guards and `FrontendProtocolTracker`. Five prot
| ws | `ProtocolGuard::frontend("ws")` on WebSocket upgrade |
| other | Fallback (TCP passthrough without HTTP) |
- Uses `fetch_update` for saturating decrements to prevent underflow races.
+ Uses `fetch_update` for saturating decrements to prevent underflow races. The same saturating-close pattern is also used for connection and UDP active gauges.
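The saturating-close pattern in plain form — a minimal sketch, not the collector's exact code:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Decrement an active-connection gauge without ever wrapping below zero,
/// even if close events race and one fires twice.
fn saturating_dec(gauge: &AtomicU64) {
    // fetch_update retries the closure on contention; returning None when the
    // value is already 0 leaves the gauge untouched instead of underflowing
    // to u64::MAX.
    let _ = gauge.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1));
}
```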
### Backend Protocol Distribution
@@ -656,6 +656,11 @@ impl RouteConfig {
self.route_match.ports.to_ports()
}
/// Stable key used for frontend route-scoped metrics.
pub fn metrics_key(&self) -> Option<&str> {
self.name.as_deref().or(self.id.as_deref())
}
/// Get the TLS mode for this route (from action-level or first target).
pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
// Check action-level TLS first
@@ -673,3 +678,63 @@ impl RouteConfig {
None
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_route(name: Option<&str>, id: Option<&str>) -> RouteConfig {
RouteConfig {
id: id.map(str::to_string),
route_match: RouteMatch {
ports: PortRange::Single(443),
transport: None,
domains: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: None,
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: name.map(str::to_string),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
#[test]
fn metrics_key_prefers_name() {
let route = test_route(Some("named-route"), Some("route-id"));
assert_eq!(route.metrics_key(), Some("named-route"));
}
#[test]
fn metrics_key_falls_back_to_id() {
let route = test_route(None, Some("route-id"));
assert_eq!(route.metrics_key(), Some("route-id"));
}
#[test]
fn metrics_key_is_absent_without_name_or_id() {
let route = test_route(None, None);
assert_eq!(route.metrics_key(), None);
}
}
@@ -3,8 +3,8 @@
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use bytes::Bytes;
@@ -105,13 +105,19 @@ impl ConnectionPool {
/// Try to check out an idle HTTP/1.1 sender for the given key.
/// Returns `None` if no usable idle connection exists.
pub fn checkout_h1(&self, key: &PoolKey) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
pub fn checkout_h1(
&self,
key: &PoolKey,
) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
let mut entry = self.h1_pool.get_mut(key)?;
let idles = entry.value_mut();
while let Some(idle) = idles.pop() {
// Check if the connection is still alive and ready
if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() {
if idle.idle_since.elapsed() < IDLE_TIMEOUT
&& idle.sender.is_ready()
&& !idle.sender.is_closed()
{
// H1 pool hit — no logging on hot path
return Some(idle.sender);
}
@@ -128,7 +134,11 @@ impl ConnectionPool {
/// Return an HTTP/1.1 sender to the pool after the response body has been prepared.
/// The caller should NOT call this if the sender is closed or not ready.
pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>) {
pub fn checkin_h1(
&self,
key: PoolKey,
sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
) {
if sender.is_closed() || !sender.is_ready() {
return; // Don't pool broken connections
}
@@ -145,7 +155,10 @@ impl ConnectionPool {
/// Try to get a cloned HTTP/2 sender for the given key.
/// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove.
pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
pub fn checkout_h2(
&self,
key: &PoolKey,
) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
let entry = self.h2_pool.get(key)?;
let pooled = entry.value();
let age = pooled.created_at.elapsed();
@@ -184,16 +197,23 @@ impl ConnectionPool {
/// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry.
/// The caller should pass this generation to the connection driver so it can use
/// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction.
pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>) -> u64 {
pub fn register_h2(
&self,
key: PoolKey,
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
if sender.is_closed() {
return gen;
}
self.h2_pool.insert(key, PooledH2 {
self.h2_pool.insert(
key,
PooledH2 {
sender,
created_at: Instant::now(),
generation: gen,
});
},
);
gen
}
@@ -204,7 +224,11 @@ impl ConnectionPool {
pub fn checkout_h3(
&self,
key: &PoolKey,
) -> Option<(h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, quinn::Connection, Duration)> {
) -> Option<(
h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
quinn::Connection,
Duration,
)> {
let entry = self.h3_pool.get(key)?;
let pooled = entry.value();
let age = pooled.created_at.elapsed();
@@ -234,12 +258,15 @@ impl ConnectionPool {
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
self.h3_pool.insert(key, PooledH3 {
self.h3_pool.insert(
key,
PooledH3 {
send_request,
connection,
created_at: Instant::now(),
generation: gen,
});
},
);
gen
}
@@ -280,7 +307,9 @@ impl ConnectionPool {
// Evict dead or aged-out H2 connections
let mut dead_h2 = Vec::new();
for entry in h2_pool.iter() {
if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE {
if entry.value().sender.is_closed()
|| entry.value().created_at.elapsed() >= MAX_H2_AGE
{
dead_h2.push(entry.key().clone());
}
}
@@ -1,8 +1,8 @@
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
@@ -76,7 +76,11 @@ impl<B> CountingBody<B> {
/// Set the connection-level activity tracker. When set, each data frame
/// updates this timestamp to prevent the idle watchdog from killing the
/// connection during active body streaming.
pub fn with_connection_activity(mut self, activity: Arc<AtomicU64>, start: std::time::Instant) -> Self {
pub fn with_connection_activity(
mut self,
activity: Arc<AtomicU64>,
start: std::time::Instant,
) -> Self {
self.connection_activity = Some(activity);
self.activity_start = Some(start);
self
@@ -134,7 +138,9 @@ where
}
// Keep the connection-level idle watchdog alive on every frame
// (this is just one atomic store — cheap enough per-frame)
if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) {
if let (Some(activity), Some(start)) =
(&this.connection_activity, &this.activity_start)
{
activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
}
+28 -10
@@ -11,8 +11,8 @@ use std::task::{Context, Poll};
use bytes::{Buf, Bytes};
use http_body::Frame;
use http_body_util::BodyExt;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use tracing::{debug, warn};
use rustproxy_config::RouteConfig;
@@ -49,7 +49,8 @@ impl H3ProxyService {
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
// Track frontend H3 connection for the QUIC connection's lifetime.
let _frontend_h3_guard = ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
let _frontend_h3_guard =
ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
h3::server::builder()
@@ -92,8 +93,15 @@ impl H3ProxyService {
tokio::spawn(async move {
if let Err(e) = handle_h3_request(
request, stream, port, remote_addr, &http_proxy, request_cancel,
).await {
request,
stream,
port,
remote_addr,
&http_proxy,
request_cancel,
)
.await
{
debug!("HTTP/3 request error from {}: {}", remote_addr, e);
}
});
@@ -153,11 +161,14 @@ async fn handle_h3_request(
// Delegate to HttpProxyService — same backend path as TCP/HTTP:
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
let conn_activity = ConnActivity::new_standalone();
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await
let response = http_proxy
.handle_request(req, peer_addr, port, cancel, conn_activity)
.await
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
// Await the body reader to get the H3 stream back
let mut stream = body_reader.await
let mut stream = body_reader
.await
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
// Send response headers over H3 (skip hop-by-hop headers)
@@ -170,10 +181,13 @@ async fn handle_h3_request(
}
h3_response = h3_response.header(name, value);
}
let h3_response = h3_response.body(())
let h3_response = h3_response
.body(())
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
stream.send_response(h3_response).await
stream
.send_response(h3_response)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
// Stream response body back over H3
@@ -182,7 +196,9 @@ async fn handle_h3_request(
match frame {
Ok(frame) => {
if let Ok(data) = frame.into_data() {
stream.send_data(data).await
stream
.send_data(data)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
}
}
@@ -194,7 +210,9 @@ async fn handle_h3_request(
}
// Finish the H3 stream (send QUIC FIN)
stream.finish().await
stream
.finish()
.await
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
Ok(())
+2 -1
@@ -5,14 +5,15 @@
pub mod connection_pool;
pub mod counting_body;
pub mod h3_service;
pub mod protocol_cache;
pub mod proxy_service;
pub mod request_filter;
mod request_host;
pub mod response_filter;
pub mod shutdown_on_drop;
pub mod template;
pub mod upstream_selector;
pub mod h3_service;
pub use connection_pool::*;
pub use counting_body::*;
@@ -144,10 +144,14 @@ impl FailureState {
}
fn all_expired(&self) -> bool {
let h2_expired = self.h2.as_ref()
let h2_expired = self
.h2
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true);
let h3_expired = self.h3.as_ref()
let h3_expired = self
.h3
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true);
h2_expired && h3_expired
@@ -355,9 +359,13 @@ impl ProtocolCache {
let record = entry.get_mut(protocol);
let (consecutive, new_cooldown) = match record {
Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => {
Some(existing)
if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
{
// Still within the "recent" window — escalate
let c = existing.consecutive_failures.saturating_add(1)
let c = existing
.consecutive_failures
.saturating_add(1)
.min(PROTOCOL_FAILURE_ESCALATION_CAP);
(c, escalate_cooldown(c))
}
@@ -394,8 +402,13 @@ impl ProtocolCache {
if protocol == DetectedProtocol::H1 {
return false;
}
self.failures.get(key)
.and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown))
self.failures
.get(key)
.and_then(|entry| {
entry
.get(protocol)
.map(|r| r.failed_at.elapsed() < r.cooldown)
})
.unwrap_or(false)
}
@@ -464,19 +477,18 @@ impl ProtocolCache {
/// Snapshot all non-expired cache entries for metrics/UI display.
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
self.cache.iter()
self.cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
.map(|entry| {
let key = entry.key();
let val = entry.value();
let failure_info = self.failures.get(key);
let (h2_sup, h2_cd, h2_cons) = Self::suppression_info(
failure_info.as_deref().and_then(|f| f.h2.as_ref()),
);
let (h3_sup, h3_cd, h3_cons) = Self::suppression_info(
failure_info.as_deref().and_then(|f| f.h3.as_ref()),
);
let (h2_sup, h2_cd, h2_cons) =
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
let (h3_sup, h3_cd, h3_cons) =
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
ProtocolCacheEntry {
host: key.host.clone(),
@@ -507,7 +519,13 @@ impl ProtocolCache {
/// Insert a protocol detection result with an optional H3 port.
/// Logs protocol transitions when overwriting an existing entry.
/// No suppression check — callers must check before calling.
fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>, reason: &str) {
fn insert_internal(
&self,
key: ProtocolCacheKey,
protocol: DetectedProtocol,
h3_port: Option<u16>,
reason: &str,
) {
// Check for existing entry to log protocol transitions
if let Some(existing) = self.cache.get(&key) {
if existing.protocol != protocol {
@@ -522,7 +540,9 @@ impl ProtocolCache {
// Evict oldest entry if at capacity
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
let oldest = self.cache.iter()
let oldest = self
.cache
.iter()
.min_by_key(|entry| entry.value().last_accessed_at)
.map(|entry| entry.key().clone());
if let Some(oldest_key) = oldest {
@@ -531,13 +551,16 @@ impl ProtocolCache {
}
let now = Instant::now();
self.cache.insert(key, CachedEntry {
self.cache.insert(
key,
CachedEntry {
protocol,
detected_at: now,
last_accessed_at: now,
last_probed_at: now,
h3_port,
});
},
);
}
/// Reduce a failure record's remaining cooldown to `target`, if it currently
@@ -582,26 +605,34 @@ impl ProtocolCache {
interval.tick().await;
// Clean expired cache entries (sliding TTL based on last_accessed_at)
let expired: Vec<ProtocolCacheKey> = cache.iter()
let expired: Vec<ProtocolCacheKey> = cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
.map(|entry| entry.key().clone())
.collect();
if !expired.is_empty() {
debug!("Protocol cache cleanup: removing {} expired entries", expired.len());
debug!(
"Protocol cache cleanup: removing {} expired entries",
expired.len()
);
for key in expired {
cache.remove(&key);
}
}
// Clean fully-expired failure entries
let expired_failures: Vec<ProtocolCacheKey> = failures.iter()
let expired_failures: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|entry| entry.value().all_expired())
.map(|entry| entry.key().clone())
.collect();
if !expired_failures.is_empty() {
debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len());
debug!(
"Protocol cache cleanup: removing {} expired failure entries",
expired_failures.len()
);
for key in expired_failures {
failures.remove(&key);
}
@@ -609,7 +640,8 @@ impl ProtocolCache {
// Safety net: cap failures map at 2× max entries
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
let oldest: Vec<ProtocolCacheKey> = failures.iter()
let oldest: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|e| e.value().all_expired())
.map(|e| e.key().clone())
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
File diff suppressed because it is too large
+132 -32
@@ -4,13 +4,15 @@ use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode};
use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
use crate::request_host::extract_request_host;
pub struct RequestFilter;
@@ -35,16 +37,13 @@ impl RequestFilter {
let client_ip = peer_addr.ip();
let request_path = req.uri().path();
// IP filter (domain-aware: extract Host header for domain-scoped entries)
// IP filter (domain-aware: use the same host extraction as route matching)
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(&client_ip);
let host = req.headers()
.get("host")
.and_then(|v| v.to_str().ok())
.map(|h| h.split(':').next().unwrap_or(h));
let host = extract_request_host(req);
if !filter.is_allowed_for_domain(&normalized, host) {
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
}
@@ -59,16 +58,15 @@ impl RequestFilter {
!limiter.check(&key)
} else {
// Create a per-check limiter (less ideal but works for non-shared case)
let limiter = RateLimiter::new(
rate_limit_config.max_requests,
rate_limit_config.window,
);
let limiter =
RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window);
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key)
};
if should_block {
let message = rate_limit_config.error_message
let message = rate_limit_config
.error_message
.as_deref()
.unwrap_or("Rate limit exceeded");
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
@@ -84,36 +82,48 @@ impl RequestFilter {
if let Some(ref basic_auth) = security.basic_auth {
if basic_auth.enabled {
// Check basic auth exclude paths
let skip_basic = basic_auth.exclude_paths.as_ref()
let skip_basic = basic_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
if !skip_basic {
let users: Vec<(String, String)> = basic_auth.users.iter()
let users: Vec<(String, String)> = basic_auth
.users
.iter()
.map(|c| (c.username.clone(), c.password.clone()))
.collect();
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
let auth_header = req.headers()
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header {
Some(header) => {
if validator.validate(header).is_none() {
return Some(Response::builder()
return Some(
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.header(
"WWW-Authenticate",
validator.www_authenticate(),
)
.body(boxed_body("Invalid credentials"))
.unwrap());
.unwrap(),
);
}
}
None => {
return Some(Response::builder()
return Some(
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Authentication required"))
.unwrap());
.unwrap(),
);
}
}
}
@@ -124,7 +134,9 @@ impl RequestFilter {
if let Some(ref jwt_auth) = security.jwt_auth {
if jwt_auth.enabled {
// Check JWT auth exclude paths
let skip_jwt = jwt_auth.exclude_paths.as_ref()
let skip_jwt = jwt_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
@@ -136,18 +148,25 @@ impl RequestFilter {
jwt_auth.audience.as_deref(),
);
let auth_header = req.headers()
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header.and_then(JwtValidator::extract_token) {
Some(token) => {
if validator.validate(token).is_err() {
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
return Some(error_response(
StatusCode::UNAUTHORIZED,
"Invalid token",
));
}
}
None => {
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
return Some(error_response(
StatusCode::UNAUTHORIZED,
"Bearer token required",
));
}
}
}
@@ -209,7 +228,11 @@ impl RequestFilter {
/// Check IP-based security (for use in passthrough / TCP-level connections).
/// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering.
/// Returns true if allowed, false if blocked.
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr, domain: Option<&str>) -> bool {
pub fn check_ip_security(
security: &RouteSecurity,
client_ip: &std::net::IpAddr,
domain: Option<&str>,
) -> bool {
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
@@ -238,19 +261,28 @@ impl RequestFilter {
return None;
}
let origin = req.headers()
let origin = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*");
Some(Response::builder()
Some(
Response::builder()
.status(StatusCode::NO_CONTENT)
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
.header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, PATCH, OPTIONS",
)
.header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, X-Requested-With",
)
.header("Access-Control-Max-Age", "86400")
.body(boxed_body(""))
.unwrap())
.unwrap(),
)
}
}
@@ -265,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes,
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http_body_util::Empty;
use hyper::{Request, StatusCode, Version};
use rustproxy_config::{IpAllowEntry, RouteSecurity};
use super::RequestFilter;
fn domain_scoped_security() -> RouteSecurity {
RouteSecurity {
ip_allow_list: Some(vec![IpAllowEntry::DomainScoped {
ip: "10.8.0.2".to_string(),
domains: vec!["*.abc.xyz".to_string()],
}]),
ip_block_list: None,
max_connections: None,
authentication: None,
rate_limit: None,
basic_auth: None,
jwt_auth: None,
}
}
fn peer_addr() -> std::net::SocketAddr {
std::net::SocketAddr::from(([10, 8, 0, 2], 4242))
}
fn request(uri: &str, version: Version, host: Option<&str>) -> Request<Empty<Bytes>> {
let mut builder = Request::builder().uri(uri).version(version);
if let Some(host) = host {
builder = builder.header("host", host);
}
builder.body(Empty::<Bytes>::new()).unwrap()
}
#[test]
fn domain_scoped_acl_allows_uri_authority_without_host_header() {
let security = domain_scoped_security();
let req = request("https://outline.abc.xyz/", Version::HTTP_2, None);
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
}
#[test]
fn domain_scoped_acl_allows_host_header_with_port() {
let security = domain_scoped_security();
let req = request(
"https://unrelated.invalid/",
Version::HTTP_11,
Some("outline.abc.xyz:443"),
);
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
}
#[test]
fn domain_scoped_acl_denies_non_matching_uri_authority() {
let security = domain_scoped_security();
let req = request("https://outline.other.xyz/", Version::HTTP_2, None);
let response = RequestFilter::apply(&security, &req, &peer_addr())
.expect("non-matching domain should be denied");
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
}
@@ -0,0 +1,43 @@
use hyper::Request;
/// Extract the effective request host for routing and scoped ACL checks.
///
/// Prefer the explicit `Host` header when present, otherwise fall back to the
/// URI authority used by HTTP/2 and HTTP/3 requests.
pub(crate) fn extract_request_host<B>(req: &Request<B>) -> Option<&str> {
req.headers()
.get("host")
.and_then(|value| value.to_str().ok())
.map(|host| host.split(':').next().unwrap_or(host))
.or_else(|| req.uri().host())
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http_body_util::Empty;
use hyper::Request;
use super::extract_request_host;
#[test]
fn extracts_host_header_before_uri_authority() {
let req = Request::builder()
.uri("https://uri.abc.xyz/test")
.header("host", "header.abc.xyz:443")
.body(Empty::<Bytes>::new())
.unwrap();
assert_eq!(extract_request_host(&req), Some("header.abc.xyz"));
}
#[test]
fn falls_back_to_uri_authority_when_host_header_missing() {
let req = Request::builder()
.uri("https://outline.abc.xyz/test")
.body(Empty::<Bytes>::new())
.unwrap();
assert_eq!(extract_request_host(&req), Some("outline.abc.xyz"));
}
}
@@ -3,7 +3,7 @@
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig;
use crate::template::{RequestContext, expand_template};
use crate::template::{expand_template, RequestContext};
pub struct ResponseFilter;
@@ -11,12 +11,17 @@ impl ResponseFilter {
/// Apply response headers from route config and CORS settings.
/// If a `RequestContext` is provided, template variables in header values will be expanded.
/// Also injects Alt-Svc header for routes with HTTP/3 enabled.
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
pub fn apply_headers(
route: &RouteConfig,
headers: &mut HeaderMap,
req_ctx: Option<&RequestContext>,
) {
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
if let Some(ref udp) = route.action.udp {
if let Some(ref quic) = udp.quic {
if quic.enable_http3.unwrap_or(false) {
let port = quic.alt_svc_port
let port = quic
.alt_svc_port
.or_else(|| req_ctx.map(|c| c.port))
.unwrap_or(443);
let max_age = quic.alt_svc_max_age.unwrap_or(86400);
@@ -63,10 +68,7 @@ impl ResponseFilter {
headers.insert("access-control-allow-origin", val);
}
} else {
headers.insert(
"access-control-allow-origin",
HeaderValue::from_static("*"),
);
headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
}
// Allow-Methods
@@ -62,17 +62,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for Shutdown
self.inner.as_ref().unwrap().is_write_vectored()
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
if result.is_ready() {
@@ -93,7 +87,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
tokio::io::AsyncWriteExt::shutdown(&mut stream),
).await;
)
.await;
// stream is dropped here — all resources freed
});
}
@@ -39,7 +39,8 @@ pub fn expand_headers(
headers: &HashMap<String, String>,
ctx: &RequestContext,
) -> HashMap<String, String> {
headers.iter()
headers
.iter()
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
.collect()
}
@@ -150,7 +151,10 @@ mod tests {
let ctx = test_context();
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
let result = expand_template(template, &ctx);
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
assert_eq!(
result,
"192.168.1.100|example.com|443|/api/v1/users|api-route|42"
);
}
#[test]
@@ -7,7 +7,7 @@ use std::sync::Arc;
use std::sync::Mutex;
use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget};
/// Upstream selection result.
pub struct UpstreamSelection {
@@ -51,21 +51,19 @@ impl UpstreamSelector {
}
// Determine load balancing algorithm
let algorithm = target.load_balancing.as_ref()
let algorithm = target
.load_balancing
.as_ref()
.map(|lb| &lb.algorithm)
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
let idx = match algorithm {
LoadBalancingAlgorithm::RoundRobin => {
self.round_robin_select(&hosts, port)
}
LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port),
LoadBalancingAlgorithm::IpHash => {
let hash = Self::ip_hash(client_addr);
hash % hosts.len()
}
LoadBalancingAlgorithm::LeastConnections => {
self.least_connections_select(&hosts, port)
}
LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port),
};
UpstreamSelection {
@@ -78,9 +76,7 @@ impl UpstreamSelector {
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
let key = format!("{}:{}", hosts[0], port);
let mut counters = self.round_robin.lock().unwrap();
let counter = counters
.entry(key)
.or_insert_with(|| AtomicUsize::new(0));
let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0));
let idx = counter.fetch_add(1, Ordering::Relaxed);
idx % hosts.len()
}
@@ -91,7 +87,8 @@ impl UpstreamSelector {
for (i, host) in hosts.iter().enumerate() {
let key = format!("{}:{}", host, port);
let conns = self.active_connections
let conns = self
.active_connections
.get(&key)
.map(|entry| entry.value().load(Ordering::Relaxed))
.unwrap_or(0);
@@ -228,13 +225,21 @@ mod tests {
selector.connection_started("backend:8080");
selector.connection_started("backend:8080");
assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
2
);
selector.connection_ended("backend:8080");
assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
1
);
File diff suppressed because it is too large
@@ -29,6 +29,113 @@ pub struct ThroughputTracker {
created_at: Instant,
}
fn unix_timestamp_seconds() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
/// Circular buffer for per-second event counts.
///
/// Unlike `ThroughputTracker`, events are recorded directly into the current
/// second so request counts remain stable even when the collector is sampled
/// more frequently than once per second.
pub(crate) struct RequestRateTracker {
samples: Vec<u64>,
write_index: usize,
count: usize,
capacity: usize,
current_second: Option<u64>,
current_count: u64,
}
impl RequestRateTracker {
pub(crate) fn new(retention_seconds: usize) -> Self {
Self {
samples: Vec::with_capacity(retention_seconds.max(1)),
write_index: 0,
count: 0,
capacity: retention_seconds.max(1),
current_second: None,
current_count: 0,
}
}
fn push_sample(&mut self, count: u64) {
if self.samples.len() < self.capacity {
self.samples.push(count);
} else {
self.samples[self.write_index] = count;
}
self.write_index = (self.write_index + 1) % self.capacity;
self.count = (self.count + 1).min(self.capacity);
}
pub(crate) fn record_event(&mut self) {
self.record_events_at(unix_timestamp_seconds(), 1);
}
pub(crate) fn record_events_at(&mut self, now_sec: u64, count: u64) {
self.advance_to(now_sec);
self.current_count = self.current_count.saturating_add(count);
}
pub(crate) fn advance_to_now(&mut self) {
self.advance_to(unix_timestamp_seconds());
}
pub(crate) fn advance_to(&mut self, now_sec: u64) {
match self.current_second {
Some(current_second) if now_sec > current_second => {
self.push_sample(self.current_count);
for _ in 1..(now_sec - current_second) {
self.push_sample(0);
}
self.current_second = Some(now_sec);
self.current_count = 0;
}
Some(_) => {}
None => {
self.current_second = Some(now_sec);
self.current_count = 0;
}
}
}
fn sum_recent(&self, window_seconds: usize) -> u64 {
let window = window_seconds.min(self.count);
if window == 0 {
return 0;
}
let mut total = 0u64;
for i in 0..window {
let idx = if self.write_index >= i + 1 {
self.write_index - i - 1
} else {
self.capacity - (i + 1 - self.write_index)
};
if idx < self.samples.len() {
total += self.samples[idx];
}
}
total
}
pub(crate) fn last_second(&self) -> u64 {
self.sum_recent(1)
}
pub(crate) fn last_minute(&self) -> u64 {
self.sum_recent(60)
}
pub(crate) fn is_idle(&self) -> bool {
self.current_count == 0 && self.sum_recent(self.capacity) == 0
}
}
impl ThroughputTracker {
/// Create a new tracker with the given capacity (seconds of retention).
pub fn new(retention_seconds: usize) -> Self {
@@ -46,7 +153,8 @@ impl ThroughputTracker {
/// Record bytes (called from data flow callbacks).
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
self.pending_bytes_out
.fetch_add(bytes_out, Ordering::Relaxed);
}
/// Take a sample (called at 1Hz).
@@ -229,4 +337,41 @@ mod tests {
let history = tracker.history(10);
assert!(history.is_empty());
}
#[test]
fn test_request_rate_tracker_counts_last_second_and_last_minute() {
let mut tracker = RequestRateTracker::new(60);
tracker.record_events_at(100, 2);
tracker.record_events_at(100, 3);
tracker.advance_to(101);
assert_eq!(tracker.last_second(), 5);
assert_eq!(tracker.last_minute(), 5);
}
#[test]
fn test_request_rate_tracker_adds_zero_samples_for_gaps() {
let mut tracker = RequestRateTracker::new(60);
tracker.record_events_at(100, 4);
tracker.record_events_at(102, 1);
tracker.advance_to(103);
assert_eq!(tracker.last_second(), 1);
assert_eq!(tracker.last_minute(), 5);
}
#[test]
fn test_request_rate_tracker_decays_to_zero_over_window() {
let mut tracker = RequestRateTracker::new(60);
tracker.record_events_at(100, 7);
tracker.advance_to(101);
tracker.advance_to(161);
assert_eq!(tracker.last_second(), 0);
assert_eq!(tracker.last_minute(), 0);
assert!(tracker.is_idle());
}
}
@@ -420,7 +420,7 @@ pub async fn quic_accept_loop(
}
conn_tracker.connection_opened(&ip);
let route_id = route.name.clone().or(route.id.clone());
let route_id = route.metrics_key().map(str::to_string);
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
// Resolve per-route cancel token (child of global cancel)
@@ -541,7 +541,7 @@ async fn handle_quic_stream_forwarding(
real_client_addr: Option<SocketAddr>,
) -> anyhow::Result<()> {
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
let route_id = route.name.as_deref().or(route.id.as_deref());
let route_id = route.metrics_key();
let metrics_arc = metrics;
// Resolve backend target
@@ -715,10 +715,11 @@ impl TcpListenerManager {
} else if let Some(target) = quick_match.target {
let target_host = target.host.first().to_string();
let target_port = target.port.resolve(port);
let route_id = quick_match.route.id.as_deref();
let route_config_id = quick_match.route.id.as_deref();
let route_id = quick_match.route.metrics_key();
// Resolve per-route cancel token (child of global cancel)
let route_cancel = match route_id {
let route_cancel = match route_config_id {
Some(id) => route_cancels.entry(id.to_string())
.or_insert_with(|| cancel.child_token())
.clone(),
@@ -733,7 +734,7 @@ impl TcpListenerManager {
cancel: conn_cancel.clone(),
source_ip: peer_addr.ip(),
domain: None, // fast path has no domain
route_id: route_id.map(|s| s.to_string()),
route_id: route_config_id.map(|s| s.to_string()),
},
);
@@ -905,12 +906,13 @@ impl TcpListenerManager {
}
};
let route_id = route_match.route.id.as_deref();
let route_config_id = route_match.route.id.as_deref();
let route_id = route_match.route.metrics_key();
// Resolve per-route cancel token (child of global cancel).
// When this route is removed via updateRoutes, the token is cancelled,
// terminating all connections on this route.
let route_cancel = match route_id {
let route_cancel = match route_config_id {
Some(id) => route_cancels.entry(id.to_string())
.or_insert_with(|| cancel.child_token())
.clone(),
@@ -925,7 +927,7 @@ impl TcpListenerManager {
cancel: cancel.clone(),
source_ip: peer_addr.ip(),
domain: domain.clone(),
route_id: route_id.map(|s| s.to_string()),
route_id: route_config_id.map(|s| s.to_string()),
},
);
@@ -1314,9 +1316,7 @@ impl TcpListenerManager {
};
// Build metadata JSON
let route_key = route_match.route.name.as_deref()
.or(route_match.route.id.as_deref())
.unwrap_or("unknown");
let route_key = route_match.route.metrics_key().unwrap_or("unknown");
let metadata = serde_json::json!({
"routeKey": route_key,
@@ -617,7 +617,7 @@ impl UdpListenerManager {
};
let route = route_match.route;
let route_id = route.name.as_deref().or(route.id.as_deref());
let route_id = route.metrics_key();
// Socket handler routes → relay datagram to TS via persistent Unix socket
if route.action.action_type == RouteActionType::SocketHandler {
@@ -0,0 +1,191 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/index.js';
import * as http from 'http';
import * as net from 'net';
import * as tls from 'tls';
import * as fs from 'fs';
import * as path from 'path';
import { fileURLToPath } from 'url';
import { assertPortsFree, findFreePorts } from './helpers/port-allocator.js';
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
const CERT_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'cert.pem'), 'utf8');
const KEY_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'key.pem'), 'utf8');
let httpBackendPort: number;
let tlsBackendPort: number;
let httpProxyPort: number;
let tlsProxyPort: number;
let httpBackend: http.Server;
let tlsBackend: tls.Server;
let proxy: SmartProxy;
async function pollMetrics(proxyToPoll: SmartProxy): Promise<void> {
await (proxyToPoll as any).metricsAdapter.poll();
}
async function waitForCondition(
callback: () => Promise<boolean>,
timeoutMs: number = 5000,
stepMs: number = 100,
): Promise<void> {
const deadline = Date.now() + timeoutMs;
while (Date.now() < deadline) {
if (await callback()) {
return;
}
await new Promise((resolve) => setTimeout(resolve, stepMs));
}
throw new Error(`Condition not met within ${timeoutMs}ms`);
}
function hasIpDomainRequest(domain: string): boolean {
const byIp = proxy.getMetrics().connections.domainRequestsByIP();
for (const domainMap of byIp.values()) {
if (domainMap.has(domain)) {
return true;
}
}
return false;
}
tap.test('setup - backend servers for HTTP domain rate metrics', async () => {
[httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort] = await findFreePorts(4);
httpBackend = http.createServer((req, res) => {
let body = '';
req.on('data', (chunk) => {
body += chunk;
});
req.on('end', () => {
res.writeHead(200, { 'Content-Type': 'text/plain' });
res.end(`ok:${body}`);
});
});
await new Promise<void>((resolve) => {
httpBackend.listen(httpBackendPort, () => resolve());
});
tlsBackend = tls.createServer({ cert: CERT_PEM, key: KEY_PEM }, (socket) => {
socket.on('data', (data) => {
socket.write(data);
});
socket.on('error', () => {});
});
await new Promise<void>((resolve) => {
tlsBackend.listen(tlsBackendPort, () => resolve());
});
});
tap.test('setup - start proxy with HTTP and TLS passthrough routes', async () => {
proxy = new SmartProxy({
routes: [
{
id: 'http-domain-rates',
name: 'http-domain-rates',
match: { ports: httpProxyPort, domains: 'example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: httpBackendPort }],
},
},
{
id: 'tls-passthrough-domain-rates',
name: 'tls-passthrough-domain-rates',
match: { ports: tlsProxyPort, domains: 'passthrough.example.com' },
action: {
type: 'forward',
tls: { mode: 'passthrough' },
targets: [{ host: 'localhost', port: tlsBackendPort }],
},
},
],
metrics: { enabled: true, sampleIntervalMs: 100, retentionSeconds: 60 },
});
await proxy.start();
await new Promise((resolve) => setTimeout(resolve, 300));
});
tap.test('HTTP requests populate per-domain HTTP request rates', async () => {
for (let i = 0; i < 3; i++) {
await new Promise<void>((resolve, reject) => {
const body = `payload-${i}`;
const req = http.request(
{
hostname: 'localhost',
port: httpProxyPort,
path: '/echo',
method: 'POST',
headers: {
Host: 'Example.COM',
'Content-Type': 'text/plain',
'Content-Length': String(body.length),
},
},
(res) => {
res.resume();
res.on('end', () => resolve());
},
);
req.on('error', reject);
req.end(body);
});
}
await waitForCondition(async () => {
await pollMetrics(proxy);
const domainMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
return (domainMetrics?.lastMinute ?? 0) >= 3 && (domainMetrics?.perSecond ?? 0) > 0;
});
const exampleMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
expect(exampleMetrics).toBeTruthy();
expect(exampleMetrics?.lastMinute).toEqual(3);
expect(exampleMetrics?.perSecond).toBeGreaterThan(0);
});
tap.test('TLS passthrough SNI does not inflate HTTP domain request rates', async () => {
const tlsClient = tls.connect({
host: 'localhost',
port: tlsProxyPort,
servername: 'passthrough.example.com',
rejectUnauthorized: false,
});
await new Promise<void>((resolve, reject) => {
tlsClient.once('secureConnect', () => resolve());
tlsClient.once('error', reject);
});
const echoPromise = new Promise<void>((resolve, reject) => {
tlsClient.once('data', () => resolve());
tlsClient.once('error', reject);
});
tlsClient.write(Buffer.from('hello over tls passthrough'));
await echoPromise;
await waitForCondition(async () => {
await pollMetrics(proxy);
return hasIpDomainRequest('passthrough.example.com');
});
const requestRates = proxy.getMetrics().requests.byDomain();
expect(requestRates.has('passthrough.example.com')).toBeFalse();
expect(requestRates.get('example.com')?.lastMinute).toEqual(3);
expect(hasIpDomainRequest('passthrough.example.com')).toBeTrue();
tlsClient.destroy();
});
tap.test('cleanup - stop proxy and close backend servers', async () => {
await proxy.stop();
await new Promise<void>((resolve) => httpBackend.close(() => resolve()));
await new Promise<void>((resolve) => tlsBackend.close(() => resolve()));
await assertPortsFree([httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort]);
});
export default tap.start();
@@ -83,6 +83,9 @@ tap.test('should verify new metrics API structure', async () => {
expect(metrics.throughput).toHaveProperty('history');
expect(metrics.throughput).toHaveProperty('byRoute');
expect(metrics.throughput).toHaveProperty('byIP');
// Check request methods
expect(metrics.requests).toHaveProperty('byDomain');
});
tap.test('should track active connections', async (tools) => {
@@ -188,10 +188,12 @@ tap.test('TCP forward - real-time byte tracking', async (tools) => {
const byRoute = m.throughput.byRoute();
console.log('TCP forward — throughput byRoute:', Array.from(byRoute.entries()));
// After close, per-IP data should be evicted (memory leak fix)
// After close, per-IP buckets are retained briefly for final throughput sampling,
// but active connection counts must already be zero.
const byIPAfter = m.connections.byIP();
console.log('TCP forward — connections byIP after close:', Array.from(byIPAfter.entries()));
expect(byIPAfter.size).toEqual(0);
expect(byIPAfter.size).toBeGreaterThan(0);
expect(Array.from(byIPAfter.values()).every((count) => count === 0)).toEqual(true);
await proxy.stop();
await tools.delayFor(200);
@@ -3,6 +3,6 @@
*/
export const commitinfo = {
name: '@push.rocks/smartproxy',
version: '27.7.0',
version: '27.8.2',
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
}
@@ -29,6 +29,11 @@ export interface IThroughputHistoryPoint {
out: number;
}
export interface IRequestRateMetrics {
perSecond: number;
lastMinute: number;
}
/**
* Main metrics interface with clean, grouped API
*/
@@ -81,6 +86,7 @@ export interface IMetrics {
perSecond(): number;
perMinute(): number;
total(): number;
byDomain(): Map<string, IRequestRateMetrics>;
};
// Cumulative totals
@@ -134,6 +134,11 @@ export interface IRustBackendMetrics {
h2Failures: number;
}
export interface IRustHttpDomainRequestMetrics {
requestsPerSecond: number;
requestsLastMinute: number;
}
export interface IRustMetricsSnapshot {
activeConnections: number;
totalConnections: number;
@@ -150,6 +155,7 @@ export interface IRustMetricsSnapshot {
totalHttpRequests: number;
httpRequestsPerSec: number;
httpRequestsPerSecRecent: number;
httpDomainRequests: Record<string, IRustHttpDomainRequestMetrics>;
activeUdpSessions: number;
totalUdpSessions: number;
totalDatagramsIn: number;
@@ -1,6 +1,6 @@
import type { IMetrics, IBackendMetrics, IProtocolCacheEntry, IProtocolDistribution, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js';
import type { IMetrics, IBackendMetrics, IProtocolCacheEntry, IProtocolDistribution, IRequestRateMetrics, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js';
import type { RustProxyBridge } from './rust-proxy-bridge.js';
import type { IRustBackendMetrics, IRustIpMetrics, IRustMetricsSnapshot, IRustRouteMetrics } from './models/rust-types.js';
import type { IRustBackendMetrics, IRustHttpDomainRequestMetrics, IRustIpMetrics, IRustMetricsSnapshot, IRustRouteMetrics } from './models/rust-types.js';
/**
* Adapts Rust JSON metrics to the IMetrics interface.
@@ -219,6 +219,18 @@ export class RustMetricsAdapter implements IMetrics {
total: (): number => {
return this.cache?.totalHttpRequests ?? this.cache?.totalConnections ?? 0;
},
byDomain: (): Map<string, IRequestRateMetrics> => {
const result = new Map<string, IRequestRateMetrics>();
if (this.cache?.httpDomainRequests) {
for (const [domain, metrics] of Object.entries(this.cache.httpDomainRequests) as Array<[string, IRustHttpDomainRequestMetrics]>) {
result.set(domain, {
perSecond: metrics.requestsPerSecond ?? 0,
lastMinute: metrics.requestsLastMinute ?? 0,
});
}
}
return result;
},
};
public totals = {