Files
smartproxy/rust/crates/rustproxy-passthrough/src/connection_registry.rs
T

330 lines
11 KiB
Rust

//! Shared connection registry for selective connection recycling.
//!
//! Tracks active connections across both TCP and QUIC with metadata
//! (source IP, SNI domain, route ID, cancel token) so that connections
//! can be selectively recycled when certificates, security rules, or
//! route targets change.
use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
use tracing::info;
use rustproxy_config::RouteSecurity;
use rustproxy_http::request_filter::RequestFilter;
use rustproxy_routing::matchers::domain_matches;
/// Metadata about an active connection.
///
/// The registry only ever acts on `cancel`; the remaining fields are the
/// criteria the `recycle_*` methods and `cleanup_removed_routes` match on when
/// deciding which connections to terminate or drop from tracking.
#[derive(Debug)]
pub struct ConnectionEntry {
    /// Per-connection cancel token (child of per-route token).
    pub cancel: CancellationToken,
    /// Client source IP; re-checked against new security rules on config change.
    pub source_ip: IpAddr,
    /// SNI domain from TLS handshake (None for non-TLS connections).
    pub domain: Option<String>,
    /// Route ID this connection was matched to (None if route has no ID).
    pub route_id: Option<String>,
}
/// Registry of live connections, shared across transports.
///
/// Both `TcpListenerManager` and `UdpListenerManager` record their connections
/// here, which lets a configuration change cancel only the connections it
/// actually affects instead of all of them.
pub struct ConnectionRegistry {
    /// Live connections, keyed by the ID handed out at registration.
    connections: DashMap<u64, ConnectionEntry>,
    /// Counter from which the next registration ID is drawn.
    next_id: AtomicU64,
}
impl ConnectionRegistry {
pub fn new() -> Self {
Self {
connections: DashMap::new(),
next_id: AtomicU64::new(1),
}
}
/// Register a connection and return its ID + RAII guard.
///
/// The guard automatically removes the connection from the registry on drop.
pub fn register(self: &Arc<Self>, entry: ConnectionEntry) -> (u64, ConnectionRegistryGuard) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
self.connections.insert(id, entry);
let guard = ConnectionRegistryGuard {
registry: Arc::clone(self),
conn_id: id,
};
(id, guard)
}
/// Number of tracked connections (for metrics/debugging).
pub fn len(&self) -> usize {
self.connections.len()
}
/// Recycle connections whose SNI domain matches a renewed certificate domain.
///
/// Uses bidirectional domain matching so that:
/// - Cert `*.example.com` recycles connections for `sub.example.com`
/// - Cert `sub.example.com` recycles connections on routes with `*.example.com`
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
let matches = entry.domain.as_deref()
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
.unwrap_or(false);
if matches {
entry.cancel.cancel();
recycled += 1;
false
} else {
true
}
});
if recycled > 0 {
info!(
"Recycled {} connection(s) for cert change on domain '{}'",
recycled, cert_domain
);
}
}
/// Recycle connections on a route whose security config changed.
///
/// Re-evaluates each connection's source IP against the new security rules.
/// Only connections from now-blocked IPs are terminated; allowed IPs are undisturbed.
pub fn recycle_for_security_change(&self, route_id: &str, new_security: &RouteSecurity) {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
if entry.route_id.as_deref() == Some(route_id) {
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) {
info!(
"Terminating connection from {} — IP now blocked on route '{}'",
entry.source_ip, route_id
);
entry.cancel.cancel();
recycled += 1;
return false;
}
}
true
});
if recycled > 0 {
info!(
"Recycled {} connection(s) for security change on route '{}'",
recycled, route_id
);
}
}
/// Recycle all connections on a route (e.g., when targets changed).
pub fn recycle_for_route_change(&self, route_id: &str) {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
if entry.route_id.as_deref() == Some(route_id) {
entry.cancel.cancel();
recycled += 1;
false
} else {
true
}
});
if recycled > 0 {
info!(
"Recycled {} connection(s) for config change on route '{}'",
recycled, route_id
);
}
}
/// Remove connections on routes that no longer exist.
///
/// This complements per-route CancellationToken cancellation —
/// the token cascade handles graceful shutdown, this cleans up the registry.
pub fn cleanup_removed_routes(&self, active_route_ids: &HashSet<String>) {
self.connections.retain(|_, entry| {
match &entry.route_id {
Some(id) => active_route_ids.contains(id),
None => true, // keep connections without a route ID
}
});
}
}
/// Unregisters its connection from the owning `ConnectionRegistry` when dropped.
///
/// Handed out by `ConnectionRegistry::register`; keep it alive for as long as
/// the connection is open.
pub struct ConnectionRegistryGuard {
    /// Registry the connection was inserted into (kept alive by this handle).
    registry: Arc<ConnectionRegistry>,
    /// ID under which the connection was inserted.
    conn_id: u64,
}
impl Drop for ConnectionRegistryGuard {
    /// Unregister the connection.
    ///
    /// If a recycle or cleanup pass already removed the entry, this removal
    /// finds nothing and does nothing; IDs are never reused, so no other entry
    /// can be affected.
    fn drop(&mut self) {
        let Self { registry, conn_id } = &*self;
        registry.connections.remove(conn_id);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn make_registry() -> Arc<ConnectionRegistry> {
        Arc::new(ConnectionRegistry::new())
    }

    /// Build an entry plus a clone of its cancel token so tests can observe cancellation.
    fn make_entry(
        ip: &str,
        domain: Option<&str>,
        route_id: Option<&str>,
    ) -> (ConnectionEntry, CancellationToken) {
        let token = CancellationToken::new();
        let entry = ConnectionEntry {
            cancel: token.clone(),
            source_ip: ip.parse().unwrap(),
            domain: domain.map(str::to_string),
            route_id: route_id.map(str::to_string),
        };
        (entry, token)
    }

    /// Security config that blocks exactly `ip` and sets nothing else.
    fn block_only(ip: &str) -> RouteSecurity {
        RouteSecurity {
            ip_allow_list: None,
            ip_block_list: Some(vec![ip.to_string()]),
            max_connections: None,
            authentication: None,
            rate_limit: None,
            basic_auth: None,
            jwt_auth: None,
        }
    }

    #[test]
    fn test_register_and_guard_cleanup() {
        let reg = make_registry();
        let (entry, token) = make_entry("10.0.0.1", Some("example.com"), Some("route-1"));
        let (id, guard) = reg.register(entry);
        assert_eq!(reg.len(), 1);
        assert!(reg.connections.contains_key(&id));
        drop(guard);
        assert_eq!(reg.len(), 0);
        // Dropping the guard only unregisters; it must not cancel the connection.
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_register_assigns_unique_ids() {
        let reg = make_registry();
        let (e1, _) = make_entry("10.0.0.1", None, None);
        let (e2, _) = make_entry("10.0.0.2", None, None);
        let (id1, _g1) = reg.register(e1);
        let (id2, _g2) = reg.register(e2);
        assert_ne!(id1, id2);
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn test_recycle_for_cert_change_exact() {
        let reg = make_registry();
        let (e1, t1) = make_entry("10.0.0.1", Some("api.example.com"), Some("r1"));
        let (e2, t2) = make_entry("10.0.0.2", Some("other.com"), Some("r2"));
        let (id1, _g1) = reg.register(e1);
        let (id2, _g2) = reg.register(e2);
        reg.recycle_for_cert_change("api.example.com");
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        // The recycle pass removes the matched entry itself, even though its
        // guard is still alive; the unmatched entry stays registered.
        assert!(!reg.connections.contains_key(&id1));
        assert!(reg.connections.contains_key(&id2));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_recycle_for_cert_change_wildcard() {
        let reg = make_registry();
        let (e1, t1) = make_entry("10.0.0.1", Some("sub.example.com"), Some("r1"));
        let (e2, t2) = make_entry("10.0.0.2", Some("other.com"), Some("r2"));
        let (_, _g1) = reg.register(e1);
        let (id2, _g2) = reg.register(e2);
        // Wildcard cert should match subdomain
        reg.recycle_for_cert_change("*.example.com");
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        assert!(reg.connections.contains_key(&id2));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_recycle_for_cert_change_ignores_non_tls() {
        let reg = make_registry();
        let (entry, token) = make_entry("10.0.0.1", None, Some("r1"));
        let (id, _g) = reg.register(entry);
        reg.recycle_for_cert_change("*.example.com");
        assert!(!token.is_cancelled());
        assert!(reg.connections.contains_key(&id));
    }

    #[test]
    fn test_recycle_for_security_change() {
        let reg = make_registry();
        let (e1, t1) = make_entry("10.0.0.1", None, Some("r1"));
        let (e2, t2) = make_entry("10.0.0.2", None, Some("r1"));
        let (_, _g1) = reg.register(e1);
        let (id2, _g2) = reg.register(e2);
        // Block 10.0.0.1, allow 10.0.0.2
        reg.recycle_for_security_change("r1", &block_only("10.0.0.1"));
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        assert!(reg.connections.contains_key(&id2));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_recycle_for_security_change_scoped_to_route() {
        let reg = make_registry();
        // Same blocked IP, but on a different route: must be left alone.
        let (e1, t1) = make_entry("10.0.0.1", None, Some("r1"));
        let (e2, t2) = make_entry("10.0.0.1", None, Some("r2"));
        let (_, _g1) = reg.register(e1);
        let (id2, _g2) = reg.register(e2);
        reg.recycle_for_security_change("r1", &block_only("10.0.0.1"));
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        assert!(reg.connections.contains_key(&id2));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_recycle_for_route_change() {
        let reg = make_registry();
        let (e1, t1) = make_entry("10.0.0.1", None, Some("r1"));
        let (e2, t2) = make_entry("10.0.0.2", None, Some("r2"));
        let (_, _g1) = reg.register(e1);
        let (id2, _g2) = reg.register(e2);
        reg.recycle_for_route_change("r1");
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        assert!(reg.connections.contains_key(&id2));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_guard_drop_after_recycle_is_harmless() {
        let reg = make_registry();
        let (e1, t1) = make_entry("10.0.0.1", None, Some("r1"));
        let (e2, _t2) = make_entry("10.0.0.2", None, Some("r2"));
        let (_, g1) = reg.register(e1);
        let (id2, g2) = reg.register(e2);
        reg.recycle_for_route_change("r1");
        assert!(t1.is_cancelled());
        // The recycled entry is already gone; dropping its guard must not
        // disturb the remaining entry.
        drop(g1);
        assert_eq!(reg.len(), 1);
        assert!(reg.connections.contains_key(&id2));
        drop(g2);
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn test_cleanup_removed_routes() {
        let reg = make_registry();
        let (e1, t1) = make_entry("10.0.0.1", None, Some("active"));
        let (e2, t2) = make_entry("10.0.0.2", None, Some("removed"));
        let (id_active, _g1) = reg.register(e1);
        let (id_removed, _g2) = reg.register(e2);
        let mut active = HashSet::new();
        active.insert("active".to_string());
        reg.cleanup_removed_routes(&active);
        // The "removed" route entry is pruned even though its guard is still alive.
        assert!(reg.connections.contains_key(&id_active));
        assert!(!reg.connections.contains_key(&id_removed));
        assert_eq!(reg.len(), 1);
        // Cleanup only prunes the registry; cancellation is left to the token cascade.
        assert!(!t1.is_cancelled());
        assert!(!t2.is_cancelled());
    }

    #[test]
    fn test_cleanup_keeps_connections_without_route_id() {
        let reg = make_registry();
        let (entry, token) = make_entry("10.0.0.1", None, None);
        let (id, _g) = reg.register(entry);
        reg.cleanup_removed_routes(&HashSet::new());
        assert!(reg.connections.contains_key(&id));
        assert!(!token.is_cancelled());
    }
}