fix(rustproxy): Cancel connections for routes removed/disabled by adding per-route cancellation tokens and make RouteManager swappable (ArcSwap) for runtime updates

This commit is contained in:
2026-03-03 16:14:16 +00:00
parent bb471a8cc9
commit d51b2c5890
6 changed files with 81 additions and 20 deletions

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
@@ -162,14 +163,18 @@ pub struct TcpListenerManager {
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
/// Global connection semaphore — limits total simultaneous connections.
conn_semaphore: Arc<tokio::sync::Semaphore>,
/// Per-route cancellation tokens (child of cancel_token).
/// When a route is removed, its token is cancelled, terminating all connections on that route.
route_cancels: Arc<DashMap<String, CancellationToken>>,
}
impl TcpListenerManager {
pub fn new(route_manager: Arc<RouteManager>) -> Self {
let metrics = Arc::new(MetricsCollector::new());
let conn_config = ConnectionConfig::default();
let route_manager_swap = Arc::new(ArcSwap::from(route_manager));
let mut http_proxy_svc = HttpProxyService::with_connect_timeout(
Arc::clone(&route_manager),
Arc::clone(&route_manager_swap),
Arc::clone(&metrics),
std::time::Duration::from_millis(conn_config.connection_timeout_ms),
);
@@ -188,7 +193,7 @@ impl TcpListenerManager {
let max_conns = conn_config.max_connections as usize;
Self {
listeners: HashMap::new(),
route_manager: Arc::new(ArcSwap::from(route_manager)),
route_manager: route_manager_swap,
metrics,
tls_configs: Arc::new(ArcSwap::from(Arc::new(HashMap::new()))),
shared_tls_acceptor: Arc::new(ArcSwap::from(Arc::new(None))),
@@ -198,14 +203,16 @@ impl TcpListenerManager {
cancel_token: CancellationToken::new(),
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
conn_semaphore: Arc::new(tokio::sync::Semaphore::new(max_conns)),
route_cancels: Arc::new(DashMap::new()),
}
}
/// Create with a metrics collector.
pub fn with_metrics(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
let conn_config = ConnectionConfig::default();
let route_manager_swap = Arc::new(ArcSwap::from(route_manager));
let mut http_proxy_svc = HttpProxyService::with_connect_timeout(
Arc::clone(&route_manager),
Arc::clone(&route_manager_swap),
Arc::clone(&metrics),
std::time::Duration::from_millis(conn_config.connection_timeout_ms),
);
@@ -224,7 +231,7 @@ impl TcpListenerManager {
let max_conns = conn_config.max_connections as usize;
Self {
listeners: HashMap::new(),
route_manager: Arc::new(ArcSwap::from(route_manager)),
route_manager: route_manager_swap,
metrics,
tls_configs: Arc::new(ArcSwap::from(Arc::new(HashMap::new()))),
shared_tls_acceptor: Arc::new(ArcSwap::from(Arc::new(None))),
@@ -234,6 +241,7 @@ impl TcpListenerManager {
cancel_token: CancellationToken::new(),
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
conn_semaphore: Arc::new(tokio::sync::Semaphore::new(max_conns)),
route_cancels: Arc::new(DashMap::new()),
}
}
@@ -245,10 +253,9 @@ impl TcpListenerManager {
));
self.conn_semaphore = Arc::new(tokio::sync::Semaphore::new(config.max_connections as usize));
// Rebuild http_proxy with updated timeouts
let rm = self.route_manager.load_full();
// Rebuild http_proxy with updated timeouts (shares the same ArcSwap<RouteManager>)
let mut http_proxy_svc = HttpProxyService::with_connect_timeout(
rm,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
std::time::Duration::from_millis(config.connection_timeout_ms),
);
@@ -317,12 +324,13 @@ impl TcpListenerManager {
let cancel = self.cancel_token.clone();
let relay = Arc::clone(&self.socket_handler_relay);
let semaphore = Arc::clone(&self.conn_semaphore);
let route_cancels = Arc::clone(&self.route_cancels);
let handle = tokio::spawn(async move {
Self::accept_loop(
listener, port, route_manager_swap, metrics, tls_configs,
shared_tls_acceptor, http_proxy, conn_config, conn_tracker, cancel, relay,
semaphore,
semaphore, route_cancels,
).await;
});
@@ -401,6 +409,20 @@ impl TcpListenerManager {
self.route_manager.store(route_manager);
}
/// Cancel connections on routes that no longer exist in the active set.
/// Existing connections on removed routes are terminated via their per-route CancellationToken.
pub fn invalidate_removed_routes(&self, active_route_ids: &std::collections::HashSet<String>) {
self.route_cancels.retain(|id, token| {
if active_route_ids.contains(id) {
true
} else {
info!("Cancelling connections for removed route '{}'", id);
token.cancel();
false // remove cancelled token from map
}
});
}
/// Prune HTTP proxy caches for route IDs that are no longer active.
pub fn prune_http_proxy_caches(&self, active_route_ids: &std::collections::HashSet<String>) {
self.http_proxy.prune_stale_routes(active_route_ids);
@@ -430,6 +452,7 @@ impl TcpListenerManager {
cancel: CancellationToken,
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
conn_semaphore: Arc<tokio::sync::Semaphore>,
route_cancels: Arc<DashMap<String, CancellationToken>>,
) {
loop {
tokio::select! {
@@ -484,6 +507,7 @@ impl TcpListenerManager {
let ct = Arc::clone(&conn_tracker);
let cn = cancel.clone();
let sr = Arc::clone(&socket_handler_relay);
let rc = Arc::clone(&route_cancels);
debug!("Accepted connection from {} on port {}", peer_addr, port);
tokio::spawn(async move {
@@ -492,7 +516,7 @@ impl TcpListenerManager {
// RAII guard ensures connection_closed is called on all paths
let _ct_guard = ConnectionTrackerGuard::new(ct, ip);
let result = Self::handle_connection(
stream, port, peer_addr, rm, m, tc, sa, hp, cc, cn, sr,
stream, port, peer_addr, rm, m, tc, sa, hp, cc, cn, sr, rc,
).await;
if let Err(e) = result {
debug!("Connection error from {}: {}", peer_addr, e);
@@ -522,6 +546,7 @@ impl TcpListenerManager {
conn_config: Arc<ConnectionConfig>,
cancel: CancellationToken,
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
route_cancels: Arc<DashMap<String, CancellationToken>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
use tokio::io::AsyncReadExt;
@@ -626,6 +651,14 @@ impl TcpListenerManager {
let target_port = target.port.resolve(port);
let route_id = quick_match.route.id.as_deref();
// Resolve per-route cancel token (child of global cancel)
let conn_cancel = match route_id {
Some(id) => route_cancels.entry(id.to_string())
.or_insert_with(|| cancel.child_token())
.clone(),
None => cancel.clone(),
};
// Check route-level IP security
if let Some(ref security) = quick_match.route.security {
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
@@ -680,7 +713,7 @@ impl TcpListenerManager {
let (_bytes_in, _bytes_out) = forwarder::forward_bidirectional_with_timeouts(
stream, backend_w, None,
inactivity_timeout, max_lifetime, cancel,
inactivity_timeout, max_lifetime, conn_cancel,
Some(forwarder::ForwardMetricsCtx {
collector: Arc::clone(&metrics),
route_id: route_id.map(|s| s.to_string()),
@@ -690,7 +723,7 @@ impl TcpListenerManager {
} else {
let (_bytes_in, _bytes_out) = forwarder::forward_bidirectional_with_timeouts(
stream, backend, None,
inactivity_timeout, max_lifetime, cancel,
inactivity_timeout, max_lifetime, conn_cancel,
Some(forwarder::ForwardMetricsCtx {
collector: Arc::clone(&metrics),
route_id: route_id.map(|s| s.to_string()),
@@ -795,6 +828,16 @@ impl TcpListenerManager {
let route_id = route_match.route.id.as_deref();
// Resolve per-route cancel token (child of global cancel).
// When this route is removed via updateRoutes, the token is cancelled,
// terminating all connections on this route.
let cancel = match route_id {
Some(id) => route_cancels.entry(id.to_string())
.or_insert_with(|| cancel.child_token())
.clone(),
None => cancel,
};
// Check route-level IP security for passthrough connections
if let Some(ref security) = route_match.route.security {
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(