fix(rustproxy): prune stale per-route metrics, add per-route rate limiter caching and regex cache, and improve connection tracking cleanup to prevent memory growth

This commit is contained in:
2026-02-19 08:48:46 +00:00
parent b4b8bd925d
commit 53d73c7dc6
7 changed files with 219 additions and 12 deletions

View File

@@ -9,6 +9,7 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use dashmap::DashMap;
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
@@ -23,6 +24,7 @@ use std::task::{Context, Poll};
use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;
use rustproxy_security::RateLimiter;
use crate::counting_body::{CountingBody, Direction};
use crate::request_filter::RequestFilter;
@@ -164,6 +166,12 @@ pub struct HttpProxyService {
upstream_selector: UpstreamSelector,
/// Timeout for connecting to upstream backends.
connect_timeout: std::time::Duration,
/// Per-route rate limiters (keyed by route ID).
route_rate_limiters: Arc<DashMap<String, Arc<RateLimiter>>>,
/// Request counter for periodic rate limiter cleanup.
request_counter: AtomicU64,
/// Cache of compiled URL rewrite regexes (keyed by pattern string).
regex_cache: DashMap<String, Regex>,
}
impl HttpProxyService {
@@ -173,6 +181,9 @@ impl HttpProxyService {
metrics,
upstream_selector: UpstreamSelector::new(),
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
route_rate_limiters: Arc::new(DashMap::new()),
request_counter: AtomicU64::new(0),
regex_cache: DashMap::new(),
}
}
@@ -187,6 +198,9 @@ impl HttpProxyService {
metrics,
upstream_selector: UpstreamSelector::new(),
connect_timeout,
route_rate_limiters: Arc::new(DashMap::new()),
request_counter: AtomicU64::new(0),
regex_cache: DashMap::new(),
}
}
@@ -312,11 +326,31 @@ impl HttpProxyService {
// Apply request filters (IP check, rate limiting, auth)
if let Some(ref security) = route_match.route.security {
if let Some(response) = RequestFilter::apply(security, &req, &peer_addr) {
// Look up or create a shared rate limiter for this route
let rate_limiter = security.rate_limit.as_ref()
.filter(|rl| rl.enabled)
.map(|rl| {
let route_key = route_id.unwrap_or("__default__").to_string();
self.route_rate_limiters
.entry(route_key)
.or_insert_with(|| Arc::new(RateLimiter::new(rl.max_requests, rl.window)))
.clone()
});
if let Some(response) = RequestFilter::apply_with_rate_limiter(
security, &req, &peer_addr, rate_limiter.as_ref(),
) {
return Ok(response);
}
}
// Periodic rate limiter cleanup (every 1000 requests)
let count = self.request_counter.fetch_add(1, Ordering::Relaxed);
if count % 1000 == 0 {
for entry in self.route_rate_limiters.iter() {
entry.value().cleanup();
}
}
// Check for test response (returns immediately, no upstream needed)
if let Some(ref advanced) = route_match.route.action.advanced {
if let Some(ref test_response) = advanced.test_response {
@@ -379,7 +413,7 @@ impl HttpProxyService {
Some(q) => format!("{}?{}", path, q),
None => path.clone(),
};
Self::apply_url_rewrite(&raw_path, &route_match.route)
self.apply_url_rewrite(&raw_path, &route_match.route)
};
// Build upstream request - stream body instead of buffering
@@ -1034,8 +1068,8 @@ impl HttpProxyService {
response.body(BoxBody::new(body)).unwrap()
}
/// Apply URL rewriting rules from route config.
fn apply_url_rewrite(path: &str, route: &rustproxy_config::RouteConfig) -> String {
/// Apply URL rewriting rules from route config, using the compiled regex cache.
fn apply_url_rewrite(&self, path: &str, route: &rustproxy_config::RouteConfig) -> String {
let rewrite = match route.action.advanced.as_ref()
.and_then(|a| a.url_rewrite.as_ref())
{
@@ -1054,10 +1088,20 @@ impl HttpProxyService {
(path.to_string(), String::new())
};
// Look up or compile the regex, caching for future requests
let cached = self.regex_cache.get(&rewrite.pattern);
if let Some(re) = cached {
let result = re.replace_all(&subject, rewrite.target.as_str());
return format!("{}{}", result, suffix);
}
// Not cached — compile and insert
match Regex::new(&rewrite.pattern) {
Ok(re) => {
let result = re.replace_all(&subject, rewrite.target.as_str());
format!("{}{}", result, suffix)
let out = format!("{}{}", result, suffix);
self.regex_cache.insert(rewrite.pattern.clone(), re);
out
}
Err(e) => {
warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e);
@@ -1184,6 +1228,9 @@ impl Default for HttpProxyService {
metrics: Arc::new(MetricsCollector::new()),
upstream_selector: UpstreamSelector::new(),
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
route_rate_limiters: Arc::new(DashMap::new()),
request_counter: AtomicU64::new(0),
regex_cache: DashMap::new(),
}
}
}

View File

@@ -115,10 +115,18 @@ impl UpstreamSelector {
/// Record that a connection to the given host has ended.
pub fn connection_ended(&self, host: &str) {
if let Some(counter) = self.active_connections.get(host) {
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
// Guard against underflow (shouldn't happen, but be safe)
let prev = counter.value().load(Ordering::Relaxed);
if prev == 0 {
counter.value().store(0, Ordering::Relaxed);
// Already at zero — just clean up the entry
drop(counter);
self.active_connections.remove(host);
return;
}
counter.value().fetch_sub(1, Ordering::Relaxed);
// Clean up zero-count entries to prevent memory growth
if prev <= 1 {
drop(counter);
self.active_connections.remove(host);
}
}
}
@@ -204,6 +212,31 @@ mod tests {
assert_eq!(r4.host, "a");
}
#[test]
fn test_connection_tracking_cleanup() {
let selector = UpstreamSelector::new();
selector.connection_started("backend:8080");
selector.connection_started("backend:8080");
assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
2
);
selector.connection_ended("backend:8080");
assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
1
);
// Last connection ends — entry should be removed entirely
selector.connection_ended("backend:8080");
assert!(selector.active_connections.get("backend:8080").is_none());
// Ending on a non-existent key should not panic
selector.connection_ended("nonexistent:9999");
}
#[test]
fn test_ip_hash_consistent() {
let selector = UpstreamSelector::new();

View File

@@ -1,5 +1,6 @@
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
@@ -196,6 +197,12 @@ impl MetricsCollector {
if val <= 1 {
drop(counter);
self.ip_connections.remove(ip);
// Evict all per-IP tracking data for this IP
self.ip_total_connections.remove(ip);
self.ip_bytes_in.remove(ip);
self.ip_bytes_out.remove(ip);
self.ip_pending_tp.remove(ip);
self.ip_throughput.remove(ip);
}
}
}
@@ -342,6 +349,17 @@ impl MetricsCollector {
}
}
/// Remove per-route metrics for route IDs that are no longer active.
/// Call this after `update_routes()` to prune stale entries.
pub fn retain_routes(&self, active_route_ids: &HashSet<String>) {
self.route_connections.retain(|k, _| active_route_ids.contains(k));
self.route_total_connections.retain(|k, _| active_route_ids.contains(k));
self.route_bytes_in.retain(|k, _| active_route_ids.contains(k));
self.route_bytes_out.retain(|k, _| active_route_ids.contains(k));
self.route_pending_tp.retain(|k, _| active_route_ids.contains(k));
self.route_throughput.retain(|k, _| active_route_ids.contains(k));
}
/// Get current active connection count.
pub fn active_connections(&self) -> u64 {
self.active_connections.load(Ordering::Relaxed)
@@ -633,6 +651,42 @@ mod tests {
assert!(collector.ip_connections.get("1.2.3.4").is_none());
}
#[test]
fn test_per_ip_full_eviction_on_last_close() {
let collector = MetricsCollector::with_retention(60);
// Open connections from two IPs
collector.connection_opened(Some("route-a"), Some("10.0.0.1"));
collector.connection_opened(Some("route-a"), Some("10.0.0.1"));
collector.connection_opened(Some("route-b"), Some("10.0.0.2"));
// Record bytes to populate per-IP DashMaps
collector.record_bytes(100, 200, Some("route-a"), Some("10.0.0.1"));
collector.record_bytes(300, 400, Some("route-b"), Some("10.0.0.2"));
collector.sample_all();
// Verify per-IP data exists
assert!(collector.ip_total_connections.get("10.0.0.1").is_some());
assert!(collector.ip_bytes_in.get("10.0.0.1").is_some());
assert!(collector.ip_throughput.get("10.0.0.1").is_some());
// Close all connections for 10.0.0.1
collector.connection_closed(Some("route-a"), Some("10.0.0.1"));
collector.connection_closed(Some("route-a"), Some("10.0.0.1"));
// All per-IP data for 10.0.0.1 should be evicted
assert!(collector.ip_connections.get("10.0.0.1").is_none());
assert!(collector.ip_total_connections.get("10.0.0.1").is_none());
assert!(collector.ip_bytes_in.get("10.0.0.1").is_none());
assert!(collector.ip_bytes_out.get("10.0.0.1").is_none());
assert!(collector.ip_pending_tp.get("10.0.0.1").is_none());
assert!(collector.ip_throughput.get("10.0.0.1").is_none());
// 10.0.0.2 should still have data
assert!(collector.ip_connections.get("10.0.0.2").is_some());
assert!(collector.ip_total_connections.get("10.0.0.2").is_some());
}
#[test]
fn test_http_request_tracking() {
let collector = MetricsCollector::with_retention(60);
@@ -650,6 +704,35 @@ mod tests {
assert_eq!(snapshot.http_requests_per_sec, 3);
}
#[test]
fn test_retain_routes_prunes_stale() {
let collector = MetricsCollector::with_retention(60);
// Create metrics for 3 routes
collector.connection_opened(Some("route-a"), None);
collector.connection_opened(Some("route-b"), None);
collector.connection_opened(Some("route-c"), None);
collector.record_bytes(100, 200, Some("route-a"), None);
collector.record_bytes(100, 200, Some("route-b"), None);
collector.record_bytes(100, 200, Some("route-c"), None);
collector.sample_all();
// Now "route-b" is removed from config
let active = HashSet::from(["route-a".to_string(), "route-c".to_string()]);
collector.retain_routes(&active);
// route-b entries should be gone
assert!(collector.route_connections.get("route-b").is_none());
assert!(collector.route_total_connections.get("route-b").is_none());
assert!(collector.route_bytes_in.get("route-b").is_none());
assert!(collector.route_bytes_out.get("route-b").is_none());
assert!(collector.route_throughput.get("route-b").is_none());
// route-a and route-c should still exist
assert!(collector.route_total_connections.get("route-a").is_some());
assert!(collector.route_total_connections.get("route-c").is_some());
}
#[test]
fn test_throughput_history_in_snapshot() {
let collector = MetricsCollector::with_retention(60);

View File

@@ -95,10 +95,11 @@ impl ConnectionTracker {
pub fn connection_closed(&self, ip: &IpAddr) {
if let Some(counter) = self.active.get(ip) {
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
// Clean up zero entries
// Clean up zero entries to prevent memory growth
if prev <= 1 {
drop(counter);
self.active.remove(ip);
self.timestamps.remove(ip);
}
}
}
@@ -205,10 +206,13 @@ impl ConnectionTracker {
let zombies = tracker.scan_zombies();
if !zombies.is_empty() {
warn!(
"Detected {} zombie connection(s): {:?}",
"Cleaning up {} zombie connection(s): {:?}",
zombies.len(),
zombies
);
for id in &zombies {
tracker.unregister_connection(*id);
}
}
}
}
@@ -304,6 +308,30 @@ mod tests {
assert_eq!(tracker.tracked_ips(), 1);
}
#[test]
fn test_timestamps_cleaned_on_last_close() {
let tracker = ConnectionTracker::new(None, Some(100));
let ip: IpAddr = "10.0.0.1".parse().unwrap();
// try_accept populates the timestamps map (when rate limiting is enabled)
assert!(tracker.try_accept(&ip));
tracker.connection_opened(&ip);
assert!(tracker.try_accept(&ip));
tracker.connection_opened(&ip);
// Timestamps should exist
assert!(tracker.timestamps.get(&ip).is_some());
// Close one connection — timestamps should still exist
tracker.connection_closed(&ip);
assert!(tracker.timestamps.get(&ip).is_some());
// Close last connection — timestamps should be cleaned up
tracker.connection_closed(&ip);
assert!(tracker.timestamps.get(&ip).is_none());
assert!(tracker.active.get(&ip).is_none());
}
#[test]
fn test_register_unregister_connection() {
let tracker = ConnectionTracker::new(None, None);

View File

@@ -27,7 +27,7 @@
pub mod challenge_server;
pub mod management;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
@@ -565,6 +565,12 @@ impl RustProxy {
vec![]
};
// Prune per-route metrics for route IDs that no longer exist
let active_route_ids: HashSet<String> = routes.iter()
.filter_map(|r| r.id.clone())
.collect();
self.metrics.retain_routes(&active_route_ids);
// Atomically swap the route table
let new_manager = Arc::new(new_manager);
self.route_table.store(Arc::clone(&new_manager));