330 lines
11 KiB
Rust
330 lines
11 KiB
Rust
//! Shared connection registry for selective connection recycling.
|
|
//!
|
|
//! Tracks active connections across both TCP and QUIC with metadata
|
|
//! (source IP, SNI domain, route ID, cancel token) so that connections
|
|
//! can be selectively recycled when certificates, security rules, or
|
|
//! route targets change.
|
|
|
|
use std::collections::HashSet;
|
|
use std::net::IpAddr;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
|
|
use dashmap::DashMap;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::info;
|
|
|
|
use rustproxy_config::RouteSecurity;
|
|
use rustproxy_http::request_filter::RequestFilter;
|
|
use rustproxy_routing::matchers::domain_matches;
|
|
|
|
/// Metadata about an active connection.
|
|
pub struct ConnectionEntry {
|
|
/// Per-connection cancel token (child of per-route token).
|
|
pub cancel: CancellationToken,
|
|
/// Client source IP.
|
|
pub source_ip: IpAddr,
|
|
/// SNI domain from TLS handshake (None for non-TLS connections).
|
|
pub domain: Option<String>,
|
|
/// Route ID this connection was matched to (None if route has no ID).
|
|
pub route_id: Option<String>,
|
|
}
|
|
|
|
/// Transport-agnostic registry of active connections.
|
|
///
|
|
/// Used by both `TcpListenerManager` and `UdpListenerManager` to track
|
|
/// connections and enable selective recycling on config changes.
|
|
pub struct ConnectionRegistry {
|
|
connections: DashMap<u64, ConnectionEntry>,
|
|
next_id: AtomicU64,
|
|
}
|
|
|
|
impl ConnectionRegistry {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
connections: DashMap::new(),
|
|
next_id: AtomicU64::new(1),
|
|
}
|
|
}
|
|
|
|
/// Register a connection and return its ID + RAII guard.
|
|
///
|
|
/// The guard automatically removes the connection from the registry on drop.
|
|
pub fn register(self: &Arc<Self>, entry: ConnectionEntry) -> (u64, ConnectionRegistryGuard) {
|
|
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
|
self.connections.insert(id, entry);
|
|
let guard = ConnectionRegistryGuard {
|
|
registry: Arc::clone(self),
|
|
conn_id: id,
|
|
};
|
|
(id, guard)
|
|
}
|
|
|
|
/// Number of tracked connections (for metrics/debugging).
|
|
pub fn len(&self) -> usize {
|
|
self.connections.len()
|
|
}
|
|
|
|
/// Recycle connections whose SNI domain matches a renewed certificate domain.
|
|
///
|
|
/// Uses bidirectional domain matching so that:
|
|
/// - Cert `*.example.com` recycles connections for `sub.example.com`
|
|
/// - Cert `sub.example.com` recycles connections on routes with `*.example.com`
|
|
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
|
|
let mut recycled = 0u64;
|
|
self.connections.retain(|_, entry| {
|
|
let matches = entry.domain.as_deref()
|
|
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
|
|
.unwrap_or(false);
|
|
if matches {
|
|
entry.cancel.cancel();
|
|
recycled += 1;
|
|
false
|
|
} else {
|
|
true
|
|
}
|
|
});
|
|
if recycled > 0 {
|
|
info!(
|
|
"Recycled {} connection(s) for cert change on domain '{}'",
|
|
recycled, cert_domain
|
|
);
|
|
}
|
|
}
|
|
|
|
/// Recycle connections on a route whose security config changed.
|
|
///
|
|
/// Re-evaluates each connection's source IP against the new security rules.
|
|
/// Only connections from now-blocked IPs are terminated; allowed IPs are undisturbed.
|
|
pub fn recycle_for_security_change(&self, route_id: &str, new_security: &RouteSecurity) {
|
|
let mut recycled = 0u64;
|
|
self.connections.retain(|_, entry| {
|
|
if entry.route_id.as_deref() == Some(route_id) {
|
|
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) {
|
|
info!(
|
|
"Terminating connection from {} — IP now blocked on route '{}'",
|
|
entry.source_ip, route_id
|
|
);
|
|
entry.cancel.cancel();
|
|
recycled += 1;
|
|
return false;
|
|
}
|
|
}
|
|
true
|
|
});
|
|
if recycled > 0 {
|
|
info!(
|
|
"Recycled {} connection(s) for security change on route '{}'",
|
|
recycled, route_id
|
|
);
|
|
}
|
|
}
|
|
|
|
/// Recycle all connections on a route (e.g., when targets changed).
|
|
pub fn recycle_for_route_change(&self, route_id: &str) {
|
|
let mut recycled = 0u64;
|
|
self.connections.retain(|_, entry| {
|
|
if entry.route_id.as_deref() == Some(route_id) {
|
|
entry.cancel.cancel();
|
|
recycled += 1;
|
|
false
|
|
} else {
|
|
true
|
|
}
|
|
});
|
|
if recycled > 0 {
|
|
info!(
|
|
"Recycled {} connection(s) for config change on route '{}'",
|
|
recycled, route_id
|
|
);
|
|
}
|
|
}
|
|
|
|
/// Remove connections on routes that no longer exist.
|
|
///
|
|
/// This complements per-route CancellationToken cancellation —
|
|
/// the token cascade handles graceful shutdown, this cleans up the registry.
|
|
pub fn cleanup_removed_routes(&self, active_route_ids: &HashSet<String>) {
|
|
self.connections.retain(|_, entry| {
|
|
match &entry.route_id {
|
|
Some(id) => active_route_ids.contains(id),
|
|
None => true, // keep connections without a route ID
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
/// RAII guard that removes a connection from the registry on drop.
|
|
pub struct ConnectionRegistryGuard {
|
|
registry: Arc<ConnectionRegistry>,
|
|
conn_id: u64,
|
|
}
|
|
|
|
impl Drop for ConnectionRegistryGuard {
|
|
fn drop(&mut self) {
|
|
self.registry.connections.remove(&self.conn_id);
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    fn make_registry() -> Arc<ConnectionRegistry> {
        Arc::new(ConnectionRegistry::new())
    }

    /// Build an entry sharing `cancel`, with the given source IP, SNI domain and route ID.
    fn entry(
        cancel: &CancellationToken,
        ip: &str,
        domain: Option<&str>,
        route_id: Option<&str>,
    ) -> ConnectionEntry {
        ConnectionEntry {
            cancel: cancel.clone(),
            source_ip: ip.parse().unwrap(),
            domain: domain.map(str::to_string),
            route_id: route_id.map(str::to_string),
        }
    }

    #[test]
    fn test_register_and_guard_cleanup() {
        let reg = make_registry();
        let token = CancellationToken::new();
        let (id, guard) =
            reg.register(entry(&token, "10.0.0.1", Some("example.com"), Some("route-1")));
        assert_eq!(reg.len(), 1);
        assert!(reg.connections.contains_key(&id));

        drop(guard);
        assert_eq!(reg.len(), 0);
        // Dropping the guard only unregisters; it must not cancel the connection.
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_recycle_for_cert_change_exact() {
        let reg = make_registry();
        let t1 = CancellationToken::new();
        let t2 = CancellationToken::new();
        let (id1, g1) =
            reg.register(entry(&t1, "10.0.0.1", Some("api.example.com"), Some("r1")));
        let (id2, _g2) = reg.register(entry(&t2, "10.0.0.2", Some("other.com"), Some("r2")));

        reg.recycle_for_cert_change("api.example.com");
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        // The recycled entry is removed right away, while its guard is still alive.
        assert_eq!(reg.len(), 1);
        assert!(!reg.connections.contains_key(&id1));
        assert!(reg.connections.contains_key(&id2));

        // The recycled connection's guard dropping later must not disturb others.
        drop(g1);
        assert_eq!(reg.len(), 1);
        assert!(reg.connections.contains_key(&id2));
    }

    #[test]
    fn test_recycle_for_cert_change_wildcard() {
        let reg = make_registry();
        let t1 = CancellationToken::new();
        let t2 = CancellationToken::new();
        let (_, _g1) =
            reg.register(entry(&t1, "10.0.0.1", Some("sub.example.com"), Some("r1")));
        let (_, _g2) = reg.register(entry(&t2, "10.0.0.2", Some("other.com"), Some("r2")));

        // Wildcard cert should match subdomain
        reg.recycle_for_cert_change("*.example.com");
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_recycle_for_cert_change_skips_connections_without_sni() {
        let reg = make_registry();
        let token = CancellationToken::new();
        let (_, _g) = reg.register(entry(&token, "10.0.0.1", None, Some("r1")));

        reg.recycle_for_cert_change("*.example.com");
        assert!(!token.is_cancelled());
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_recycle_for_security_change() {
        let reg = make_registry();
        let t1 = CancellationToken::new();
        let t2 = CancellationToken::new();
        let t3 = CancellationToken::new();
        let (_, _g1) = reg.register(entry(&t1, "10.0.0.1", None, Some("r1")));
        let (_, _g2) = reg.register(entry(&t2, "10.0.0.2", None, Some("r1")));
        // Same (now-blocked) IP, but on a different route: must be left alone.
        let (_, _g3) = reg.register(entry(&t3, "10.0.0.1", None, Some("r2")));

        // Block 10.0.0.1, allow 10.0.0.2
        let security = RouteSecurity {
            ip_allow_list: None,
            ip_block_list: Some(vec!["10.0.0.1".to_string()]),
            max_connections: None,
            authentication: None,
            rate_limit: None,
            basic_auth: None,
            jwt_auth: None,
        };

        reg.recycle_for_security_change("r1", &security);
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        assert!(!t3.is_cancelled());
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn test_recycle_for_route_change() {
        let reg = make_registry();
        let t1 = CancellationToken::new();
        let t2 = CancellationToken::new();
        let (_, _g1) = reg.register(entry(&t1, "10.0.0.1", None, Some("r1")));
        let (id2, _g2) = reg.register(entry(&t2, "10.0.0.2", None, Some("r2")));

        reg.recycle_for_route_change("r1");
        assert!(t1.is_cancelled());
        assert!(!t2.is_cancelled());
        assert_eq!(reg.len(), 1);
        assert!(reg.connections.contains_key(&id2));
    }

    #[test]
    fn test_cleanup_removed_routes() {
        let reg = make_registry();
        let t1 = CancellationToken::new();
        let t2 = CancellationToken::new();
        let t3 = CancellationToken::new();
        let (id_active, _g1) = reg.register(entry(&t1, "10.0.0.1", None, Some("active")));
        let (id_removed, _g2) = reg.register(entry(&t2, "10.0.0.2", None, Some("removed")));
        let (id_unrouted, _g3) = reg.register(entry(&t3, "10.0.0.3", None, None));

        let mut active = HashSet::new();
        active.insert("active".to_string());
        reg.cleanup_removed_routes(&active);

        // The "removed" route's entry is gone even though its guard is still alive;
        // entries on active routes and entries without a route ID are kept.
        assert_eq!(reg.len(), 2);
        assert!(reg.connections.contains_key(&id_active));
        assert!(!reg.connections.contains_key(&id_removed));
        assert!(reg.connections.contains_key(&id_unrouted));

        // Cleanup only prunes the registry; cancellation comes from the token cascade.
        assert!(!t1.is_cancelled());
        assert!(!t2.is_cancelled());
        assert!(!t3.is_cancelled());
    }
}
|