fix(rustproxy): Cancel connections for routes removed/disabled by adding per-route cancellation tokens and make RouteManager swappable (ArcSwap) for runtime updates
This commit is contained in:
@@ -8,6 +8,7 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use http_body_util::{BodyExt, Full, combinators::BoxBody};
|
||||
@@ -133,7 +134,7 @@ async fn connect_tls_backend(
|
||||
|
||||
/// HTTP proxy service that processes HTTP traffic.
|
||||
pub struct HttpProxyService {
|
||||
route_manager: Arc<RouteManager>,
|
||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
upstream_selector: UpstreamSelector,
|
||||
/// Timeout for connecting to upstream backends.
|
||||
@@ -161,7 +162,7 @@ pub struct HttpProxyService {
|
||||
}
|
||||
|
||||
impl HttpProxyService {
|
||||
pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
|
||||
pub fn new(route_manager: Arc<ArcSwap<RouteManager>>, metrics: Arc<MetricsCollector>) -> Self {
|
||||
Self {
|
||||
route_manager,
|
||||
metrics,
|
||||
@@ -182,7 +183,7 @@ impl HttpProxyService {
|
||||
|
||||
/// Create with a custom connect timeout.
|
||||
pub fn with_connect_timeout(
|
||||
route_manager: Arc<RouteManager>,
|
||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
connect_timeout: std::time::Duration,
|
||||
) -> Self {
|
||||
@@ -405,7 +406,8 @@ impl HttpProxyService {
|
||||
protocol: Some("http"),
|
||||
};
|
||||
|
||||
let route_match = match self.route_manager.find_route(&ctx) {
|
||||
let current_rm = self.route_manager.load();
|
||||
let route_match = match current_rm.find_route(&ctx) {
|
||||
Some(rm) => rm,
|
||||
None => {
|
||||
debug!("No route matched for HTTP request to {:?}{}", host, path);
|
||||
@@ -1759,7 +1761,7 @@ impl rustls::client::danger::ServerCertVerifier for InsecureBackendVerifier {
|
||||
impl Default for HttpProxyService {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
route_manager: Arc::new(RouteManager::new(vec![])),
|
||||
route_manager: Arc::new(ArcSwap::from(Arc::new(RouteManager::new(vec![])))),
|
||||
metrics: Arc::new(MetricsCollector::new()),
|
||||
upstream_selector: UpstreamSelector::new(),
|
||||
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use arc_swap::ArcSwap;
|
||||
use dashmap::DashMap;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -162,14 +163,18 @@ pub struct TcpListenerManager {
|
||||
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
|
||||
/// Global connection semaphore — limits total simultaneous connections.
|
||||
conn_semaphore: Arc<tokio::sync::Semaphore>,
|
||||
/// Per-route cancellation tokens (child of cancel_token).
|
||||
/// When a route is removed, its token is cancelled, terminating all connections on that route.
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
}
|
||||
|
||||
impl TcpListenerManager {
|
||||
pub fn new(route_manager: Arc<RouteManager>) -> Self {
|
||||
let metrics = Arc::new(MetricsCollector::new());
|
||||
let conn_config = ConnectionConfig::default();
|
||||
let route_manager_swap = Arc::new(ArcSwap::from(route_manager));
|
||||
let mut http_proxy_svc = HttpProxyService::with_connect_timeout(
|
||||
Arc::clone(&route_manager),
|
||||
Arc::clone(&route_manager_swap),
|
||||
Arc::clone(&metrics),
|
||||
std::time::Duration::from_millis(conn_config.connection_timeout_ms),
|
||||
);
|
||||
@@ -188,7 +193,7 @@ impl TcpListenerManager {
|
||||
let max_conns = conn_config.max_connections as usize;
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager: Arc::new(ArcSwap::from(route_manager)),
|
||||
route_manager: route_manager_swap,
|
||||
metrics,
|
||||
tls_configs: Arc::new(ArcSwap::from(Arc::new(HashMap::new()))),
|
||||
shared_tls_acceptor: Arc::new(ArcSwap::from(Arc::new(None))),
|
||||
@@ -198,14 +203,16 @@ impl TcpListenerManager {
|
||||
cancel_token: CancellationToken::new(),
|
||||
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
||||
conn_semaphore: Arc::new(tokio::sync::Semaphore::new(max_conns)),
|
||||
route_cancels: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a metrics collector.
|
||||
pub fn with_metrics(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
|
||||
let conn_config = ConnectionConfig::default();
|
||||
let route_manager_swap = Arc::new(ArcSwap::from(route_manager));
|
||||
let mut http_proxy_svc = HttpProxyService::with_connect_timeout(
|
||||
Arc::clone(&route_manager),
|
||||
Arc::clone(&route_manager_swap),
|
||||
Arc::clone(&metrics),
|
||||
std::time::Duration::from_millis(conn_config.connection_timeout_ms),
|
||||
);
|
||||
@@ -224,7 +231,7 @@ impl TcpListenerManager {
|
||||
let max_conns = conn_config.max_connections as usize;
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager: Arc::new(ArcSwap::from(route_manager)),
|
||||
route_manager: route_manager_swap,
|
||||
metrics,
|
||||
tls_configs: Arc::new(ArcSwap::from(Arc::new(HashMap::new()))),
|
||||
shared_tls_acceptor: Arc::new(ArcSwap::from(Arc::new(None))),
|
||||
@@ -234,6 +241,7 @@ impl TcpListenerManager {
|
||||
cancel_token: CancellationToken::new(),
|
||||
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
||||
conn_semaphore: Arc::new(tokio::sync::Semaphore::new(max_conns)),
|
||||
route_cancels: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,10 +253,9 @@ impl TcpListenerManager {
|
||||
));
|
||||
self.conn_semaphore = Arc::new(tokio::sync::Semaphore::new(config.max_connections as usize));
|
||||
|
||||
// Rebuild http_proxy with updated timeouts
|
||||
let rm = self.route_manager.load_full();
|
||||
// Rebuild http_proxy with updated timeouts (shares the same ArcSwap<RouteManager>)
|
||||
let mut http_proxy_svc = HttpProxyService::with_connect_timeout(
|
||||
rm,
|
||||
Arc::clone(&self.route_manager),
|
||||
Arc::clone(&self.metrics),
|
||||
std::time::Duration::from_millis(config.connection_timeout_ms),
|
||||
);
|
||||
@@ -317,12 +324,13 @@ impl TcpListenerManager {
|
||||
let cancel = self.cancel_token.clone();
|
||||
let relay = Arc::clone(&self.socket_handler_relay);
|
||||
let semaphore = Arc::clone(&self.conn_semaphore);
|
||||
let route_cancels = Arc::clone(&self.route_cancels);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
Self::accept_loop(
|
||||
listener, port, route_manager_swap, metrics, tls_configs,
|
||||
shared_tls_acceptor, http_proxy, conn_config, conn_tracker, cancel, relay,
|
||||
semaphore,
|
||||
semaphore, route_cancels,
|
||||
).await;
|
||||
});
|
||||
|
||||
@@ -401,6 +409,20 @@ impl TcpListenerManager {
|
||||
self.route_manager.store(route_manager);
|
||||
}
|
||||
|
||||
/// Cancel connections on routes that no longer exist in the active set.
|
||||
/// Existing connections on removed routes are terminated via their per-route CancellationToken.
|
||||
pub fn invalidate_removed_routes(&self, active_route_ids: &std::collections::HashSet<String>) {
|
||||
self.route_cancels.retain(|id, token| {
|
||||
if active_route_ids.contains(id) {
|
||||
true
|
||||
} else {
|
||||
info!("Cancelling connections for removed route '{}'", id);
|
||||
token.cancel();
|
||||
false // remove cancelled token from map
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Prune HTTP proxy caches for route IDs that are no longer active.
|
||||
pub fn prune_http_proxy_caches(&self, active_route_ids: &std::collections::HashSet<String>) {
|
||||
self.http_proxy.prune_stale_routes(active_route_ids);
|
||||
@@ -430,6 +452,7 @@ impl TcpListenerManager {
|
||||
cancel: CancellationToken,
|
||||
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
|
||||
conn_semaphore: Arc<tokio::sync::Semaphore>,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -484,6 +507,7 @@ impl TcpListenerManager {
|
||||
let ct = Arc::clone(&conn_tracker);
|
||||
let cn = cancel.clone();
|
||||
let sr = Arc::clone(&socket_handler_relay);
|
||||
let rc = Arc::clone(&route_cancels);
|
||||
debug!("Accepted connection from {} on port {}", peer_addr, port);
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -492,7 +516,7 @@ impl TcpListenerManager {
|
||||
// RAII guard ensures connection_closed is called on all paths
|
||||
let _ct_guard = ConnectionTrackerGuard::new(ct, ip);
|
||||
let result = Self::handle_connection(
|
||||
stream, port, peer_addr, rm, m, tc, sa, hp, cc, cn, sr,
|
||||
stream, port, peer_addr, rm, m, tc, sa, hp, cc, cn, sr, rc,
|
||||
).await;
|
||||
if let Err(e) = result {
|
||||
debug!("Connection error from {}: {}", peer_addr, e);
|
||||
@@ -522,6 +546,7 @@ impl TcpListenerManager {
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
cancel: CancellationToken,
|
||||
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
@@ -626,6 +651,14 @@ impl TcpListenerManager {
|
||||
let target_port = target.port.resolve(port);
|
||||
let route_id = quick_match.route.id.as_deref();
|
||||
|
||||
// Resolve per-route cancel token (child of global cancel)
|
||||
let conn_cancel = match route_id {
|
||||
Some(id) => route_cancels.entry(id.to_string())
|
||||
.or_insert_with(|| cancel.child_token())
|
||||
.clone(),
|
||||
None => cancel.clone(),
|
||||
};
|
||||
|
||||
// Check route-level IP security
|
||||
if let Some(ref security) = quick_match.route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
@@ -680,7 +713,7 @@ impl TcpListenerManager {
|
||||
|
||||
let (_bytes_in, _bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend_w, None,
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
inactivity_timeout, max_lifetime, conn_cancel,
|
||||
Some(forwarder::ForwardMetricsCtx {
|
||||
collector: Arc::clone(&metrics),
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
@@ -690,7 +723,7 @@ impl TcpListenerManager {
|
||||
} else {
|
||||
let (_bytes_in, _bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, None,
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
inactivity_timeout, max_lifetime, conn_cancel,
|
||||
Some(forwarder::ForwardMetricsCtx {
|
||||
collector: Arc::clone(&metrics),
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
@@ -795,6 +828,16 @@ impl TcpListenerManager {
|
||||
|
||||
let route_id = route_match.route.id.as_deref();
|
||||
|
||||
// Resolve per-route cancel token (child of global cancel).
|
||||
// When this route is removed via updateRoutes, the token is cancelled,
|
||||
// terminating all connections on this route.
|
||||
let cancel = match route_id {
|
||||
Some(id) => route_cancels.entry(id.to_string())
|
||||
.or_insert_with(|| cancel.child_token())
|
||||
.clone(),
|
||||
None => cancel,
|
||||
};
|
||||
|
||||
// Check route-level IP security for passthrough connections
|
||||
if let Some(ref security) = route_match.route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
|
||||
@@ -610,6 +610,8 @@ impl RustProxy {
|
||||
// Update listener manager
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.update_route_manager(Arc::clone(&new_manager));
|
||||
// Cancel connections on routes that were removed or disabled
|
||||
listener.invalidate_removed_routes(&active_route_ids);
|
||||
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
|
||||
listener.prune_http_proxy_caches(&active_route_ids);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user