// smartproxy/rust/crates/rustproxy-http/src/protocol_cache.rs
//! Bounded, sliding-TTL protocol detection cache with periodic re-probing and failure suppression.
//!
//! Caches the detected protocol (H1, H2, or H3) per backend endpoint and requested
//! domain (host:port + requested_host). This prevents cache oscillation when multiple
//! frontend domains share the same backend but differ in protocol support.
//!
//! ## Sliding TTL
//!
//! Each cache hit refreshes the entry's expiry timer (`last_accessed_at`), so an
//! entry stays valid as long as it is accessed at least once per day; after a full
//! day without access it expires. If more than 5 minutes have passed since the
//! last probe, the next request triggers an inline ALPN re-probe to verify the
//! cached protocol is still correct.
//!
//! ## Upgrade signals
//!
//! - ALPN (TLS handshake) → detects H2 vs H1
//! - Alt-Svc (response header) → advertises H3
//!
//! ## Protocol transitions
//!
//! All protocol changes are logged at `info!()` level with the old and new
//! protocols and the reason, e.g.:
//! "Protocol transition" with old=H1 new=H2 reason="periodic ALPN re-probe"
//!
//! ## Failure suppression
//!
//! When a protocol fails, `record_failure()` prevents upgrade signals from
//! re-introducing it until an escalating cooldown expires (5s → 10s → ... → 300s).
//! Within-request escalation is allowed via `can_retry()` after a 5s minimum gap.
//!
//! ## Total failure eviction
//!
//! When all protocols (H3, H2, H1) fail for a backend, the cache entry is evicted
//! entirely via `evict()`, forcing a fresh probe on the next request.
//!
//! Cascading: when a lower protocol also fails, higher protocol cooldowns are
//! reduced to 5s remaining (not instant clear), preventing tight retry loops.
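//!
//! ## Usage sketch
//!
//! A minimal sketch of the intended caller flow, assuming a hypothetical async
//! `alpn_probe()` helper that performs the actual TLS handshake (not part of
//! this module):
//!
//! ```ignore
//! let cache = ProtocolCache::new();
//! let key = ProtocolCacheKey {
//!     host: "10.0.0.5".into(),
//!     port: 8443,
//!     requested_host: Some("example.com".into()),
//! };
//! match cache.get(&key) {
//!     Some(cached) if cached.needs_reprobe => {
//!         // Stale probe: verify via ALPN, then record the result.
//!         let probed = alpn_probe(&key).await;
//!         cache.update_probe_result(&key, probed, "periodic ALPN re-probe");
//!     }
//!     Some(cached) => { /* use cached.protocol / cached.h3_port */ }
//!     None => {
//!         // Cache miss: probe fresh; insert may be refused by suppression.
//!         let probed = alpn_probe(&key).await;
//!         cache.insert(key, probed, "initial ALPN probe");
//!     }
//! }
//! ```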
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tracing::{debug, info};
/// Sliding TTL for cached protocol detection results.
/// Entries that haven't been accessed for this duration are evicted.
/// Each `get()` call refreshes the timer (sliding window).
const PROTOCOL_CACHE_TTL: Duration = Duration::from_secs(86400); // 1 day
/// Interval between inline ALPN re-probes for H1/H2 entries.
/// When a cached entry's `last_probed_at` exceeds this, the next request
/// triggers an ALPN re-probe to verify the backend still speaks the same protocol.
const PROTOCOL_REPROBE_INTERVAL: Duration = Duration::from_secs(300); // 5 minutes
/// Maximum number of entries in the protocol cache.
const PROTOCOL_CACHE_MAX_ENTRIES: usize = 4096;
/// Background cleanup interval.
const PROTOCOL_CACHE_CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
/// Minimum cooldown between retry attempts of a failed protocol.
const PROTOCOL_FAILURE_COOLDOWN: Duration = Duration::from_secs(5);
/// Maximum cooldown (escalation ceiling).
const PROTOCOL_FAILURE_MAX_COOLDOWN: Duration = Duration::from_secs(300);
/// Consecutive failure count at which cooldown reaches maximum.
/// 5s × 2^5 = 160s, 5s × 2^6 = 320s → capped at 300s.
const PROTOCOL_FAILURE_ESCALATION_CAP: u32 = 7;
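// Escalation table implied by `escalate_cooldown` (5s × 2^(n-1), capped at 300s):
//   n=1 → 5s   n=2 → 10s   n=3 → 20s   n=4 → 40s
//   n=5 → 80s  n=6 → 160s  n=7 → 300s (320s, capped)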
/// Detected backend protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DetectedProtocol {
    H1,
    H2,
    H3,
}

impl std::fmt::Display for DetectedProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DetectedProtocol::H1 => write!(f, "H1"),
            DetectedProtocol::H2 => write!(f, "H2"),
            DetectedProtocol::H3 => write!(f, "H3"),
        }
    }
}

/// Result of a protocol cache lookup.
#[derive(Debug, Clone, Copy)]
pub struct CachedProtocol {
    pub protocol: DetectedProtocol,
    /// For H3: the port advertised by Alt-Svc (may differ from TCP port).
    pub h3_port: Option<u16>,
    /// True if the entry's `last_probed_at` exceeds `PROTOCOL_REPROBE_INTERVAL`.
    /// Caller should perform an inline ALPN re-probe and call `update_probe_result()`.
    /// Always `false` for H3 entries (H3 is discovered via Alt-Svc, not ALPN).
    pub needs_reprobe: bool,
}

/// Key for the protocol cache: (host, port, requested_host).
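///
/// A sketch of two keys that share a backend but differ in requested host
/// (illustrative values):
///
/// ```ignore
/// let a = ProtocolCacheKey {
///     host: "backend.internal".into(),
///     port: 443,
///     requested_host: Some("h2only.example".into()),
/// };
/// let b = ProtocolCacheKey {
///     host: "backend.internal".into(),
///     port: 443,
///     requested_host: Some("h3.example".into()),
/// };
/// assert_ne!(a, b); // cached protocols are tracked independently
/// ```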
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct ProtocolCacheKey {
    pub host: String,
    pub port: u16,
    /// The incoming request's domain (Host header / :authority).
    /// Distinguishes protocol detection when multiple domains share the same backend.
    pub requested_host: Option<String>,
}

/// A cached protocol detection result with timestamps.
struct CachedEntry {
    protocol: DetectedProtocol,
    /// When this protocol was first detected (or last changed).
    detected_at: Instant,
    /// Last time any request used this entry (sliding-window TTL).
    last_accessed_at: Instant,
    /// Last time an ALPN re-probe was performed for this entry.
    last_probed_at: Instant,
    /// For H3: the port advertised by Alt-Svc (may differ from TCP port).
    h3_port: Option<u16>,
}

/// Failure record for a single protocol level.
#[derive(Debug, Clone)]
struct FailureRecord {
    /// When the failure was last recorded.
    failed_at: Instant,
    /// Current cooldown duration. Escalates on consecutive failures.
    cooldown: Duration,
    /// Number of consecutive failures (for escalation).
    consecutive_failures: u32,
}

/// Per-key failure state. Tracks failures at each upgradeable protocol level.
/// H1 is never tracked (it's the protocol floor — nothing to fall back to).
#[derive(Debug, Clone, Default)]
struct FailureState {
    h2: Option<FailureRecord>,
    h3: Option<FailureRecord>,
}

impl FailureState {
    fn is_empty(&self) -> bool {
        self.h2.is_none() && self.h3.is_none()
    }

    fn all_expired(&self) -> bool {
        let h2_expired = self
            .h2
            .as_ref()
            .map(|r| r.failed_at.elapsed() >= r.cooldown)
            .unwrap_or(true);
        let h3_expired = self
            .h3
            .as_ref()
            .map(|r| r.failed_at.elapsed() >= r.cooldown)
            .unwrap_or(true);
        h2_expired && h3_expired
    }

    fn get(&self, protocol: DetectedProtocol) -> Option<&FailureRecord> {
        match protocol {
            DetectedProtocol::H2 => self.h2.as_ref(),
            DetectedProtocol::H3 => self.h3.as_ref(),
            DetectedProtocol::H1 => None,
        }
    }

    fn get_mut(&mut self, protocol: DetectedProtocol) -> &mut Option<FailureRecord> {
        match protocol {
            DetectedProtocol::H2 => &mut self.h2,
            DetectedProtocol::H3 => &mut self.h3,
            DetectedProtocol::H1 => unreachable!("H1 failures are never recorded"),
        }
    }
}

/// Snapshot of a single protocol cache entry, suitable for metrics/UI display.
#[derive(Debug, Clone)]
pub struct ProtocolCacheEntry {
    pub host: String,
    pub port: u16,
    pub domain: Option<String>,
    pub protocol: String,
    pub h3_port: Option<u16>,
    pub age_secs: u64,
    pub last_accessed_secs: u64,
    pub last_probed_secs: u64,
    pub h2_suppressed: bool,
    pub h3_suppressed: bool,
    pub h2_cooldown_remaining_secs: Option<u64>,
    pub h3_cooldown_remaining_secs: Option<u64>,
    pub h2_consecutive_failures: Option<u32>,
    pub h3_consecutive_failures: Option<u32>,
}

/// Exponential backoff: PROTOCOL_FAILURE_COOLDOWN × 2^(n-1), capped at MAX.
fn escalate_cooldown(consecutive: u32) -> Duration {
    let base = PROTOCOL_FAILURE_COOLDOWN.as_secs();
    let exp = consecutive.saturating_sub(1).min(63);
    let secs = base.saturating_mul(1u64.checked_shl(exp).unwrap_or(u64::MAX));
    Duration::from_secs(secs.min(PROTOCOL_FAILURE_MAX_COOLDOWN.as_secs()))
}

/// Bounded, sliding-TTL protocol detection cache with failure suppression.
///
/// Memory safety guarantees:
/// - Hard cap at `PROTOCOL_CACHE_MAX_ENTRIES` — cannot grow unboundedly.
/// - Sliding TTL expiry — entries age out after 1 day without access.
/// - Background cleanup task — proactively removes expired entries every 60s.
/// - `clear()` — called on route updates to discard stale detections.
/// - `Drop` — aborts the background task to prevent dangling tokio tasks.
pub struct ProtocolCache {
    cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>>,
    /// Generic protocol failure suppression map. Tracks per-protocol failure
    /// records (H2, H3) for each cache key. Used to prevent upgrade signals
    /// (ALPN, Alt-Svc) from re-introducing failed protocols.
    failures: Arc<DashMap<ProtocolCacheKey, FailureState>>,
    cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}

impl ProtocolCache {
    /// Create a new protocol cache and start the background cleanup task.
    pub fn new() -> Self {
        let cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>> = Arc::new(DashMap::new());
        let failures: Arc<DashMap<ProtocolCacheKey, FailureState>> = Arc::new(DashMap::new());
        let cache_clone = Arc::clone(&cache);
        let failures_clone = Arc::clone(&failures);
        let cleanup_handle = tokio::spawn(async move {
            Self::cleanup_loop(cache_clone, failures_clone).await;
        });
        Self {
            cache,
            failures,
            cleanup_handle: Some(cleanup_handle),
        }
    }

    /// Look up the cached protocol for a backend endpoint.
    ///
    /// Returns `None` if not cached or expired (caller should probe via ALPN).
    /// On hit, refreshes `last_accessed_at` (sliding TTL) and sets `needs_reprobe`
    /// if the entry hasn't been probed in over 5 minutes (H1/H2 only).
    pub fn get(&self, key: &ProtocolCacheKey) -> Option<CachedProtocol> {
        let mut entry = self.cache.get_mut(key)?;
        if entry.last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL {
            // Refresh sliding TTL
            entry.last_accessed_at = Instant::now();
            // H3 is the ceiling — can't ALPN-probe for H3 (discovered via Alt-Svc).
            // Only H1/H2 entries trigger periodic re-probing.
            let needs_reprobe = entry.protocol != DetectedProtocol::H3
                && entry.last_probed_at.elapsed() >= PROTOCOL_REPROBE_INTERVAL;
            Some(CachedProtocol {
                protocol: entry.protocol,
                h3_port: entry.h3_port,
                needs_reprobe,
            })
        } else {
            // Expired — remove and return None to trigger re-probe
            drop(entry); // release DashMap ref before remove
            self.cache.remove(key);
            None
        }
    }

    /// Insert a detected protocol into the cache.
    /// Returns `false` if suppressed due to active failure suppression.
    ///
    /// **Key semantic**: only suppresses if the protocol being inserted matches
    /// a suppressed protocol. H1 inserts are NEVER suppressed — downgrades
    /// always succeed.
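    ///
    /// Illustrative sequence (hypothetical `key`):
    ///
    /// ```ignore
    /// cache.record_failure(key.clone(), DetectedProtocol::H2);
    /// // Within the cooldown an H2 insert is refused...
    /// assert!(!cache.insert(key.clone(), DetectedProtocol::H2, "ALPN"));
    /// // ...but an H1 downgrade always succeeds.
    /// assert!(cache.insert(key.clone(), DetectedProtocol::H1, "ALPN"));
    /// ```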
    pub fn insert(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, reason: &str) -> bool {
        if self.is_suppressed(&key, protocol) {
            debug!(
                host = %key.host, port = %key.port, domain = ?key.requested_host,
                protocol = ?protocol,
                "Protocol cache insert suppressed — recent failure"
            );
            return false;
        }
        self.insert_internal(key, protocol, None, reason);
        true
    }

    /// Insert an H3 detection result with the Alt-Svc advertised port.
    /// Returns `false` if H3 is suppressed.
    pub fn insert_h3(&self, key: ProtocolCacheKey, h3_port: u16, reason: &str) -> bool {
        if self.is_suppressed(&key, DetectedProtocol::H3) {
            debug!(
                host = %key.host, port = %key.port, domain = ?key.requested_host,
                "H3 upgrade suppressed — recent failure"
            );
            return false;
        }
        self.insert_internal(key, DetectedProtocol::H3, Some(h3_port), reason);
        true
    }

    /// Update the cache after an inline ALPN re-probe completes.
    ///
    /// Always updates `last_probed_at`. If the protocol changed, logs the transition
    /// and updates the entry. Returns `Some(new_protocol)` if changed, `None` if unchanged.
    pub fn update_probe_result(
        &self,
        key: &ProtocolCacheKey,
        probed_protocol: DetectedProtocol,
        reason: &str,
    ) -> Option<DetectedProtocol> {
        if let Some(mut entry) = self.cache.get_mut(key) {
            let old_protocol = entry.protocol;
            entry.last_probed_at = Instant::now();
            entry.last_accessed_at = Instant::now();
            if old_protocol != probed_protocol {
                info!(
                    host = %key.host, port = %key.port, domain = ?key.requested_host,
                    old = %old_protocol, new = %probed_protocol, reason = %reason,
                    "Protocol transition"
                );
                entry.protocol = probed_protocol;
                entry.detected_at = Instant::now();
                // Clear h3_port if downgrading from H3
                if old_protocol == DetectedProtocol::H3 && probed_protocol != DetectedProtocol::H3 {
                    entry.h3_port = None;
                }
                return Some(probed_protocol);
            }
            debug!(
                host = %key.host, port = %key.port, domain = ?key.requested_host,
                protocol = %old_protocol, reason = %reason,
                "Re-probe confirmed — no protocol change"
            );
            None
        } else {
            // Entry was evicted between the get() and the probe completing.
            // Insert as a fresh entry.
            self.insert_internal(key.clone(), probed_protocol, None, reason);
            Some(probed_protocol)
        }
    }

    /// Record a protocol failure. Future `insert()` calls for this protocol
    /// will be suppressed until the escalating cooldown expires.
    ///
    /// Cooldown escalation: 5s → 10s → 20s → 40s → 80s → 160s → 300s.
    /// Consecutive counter resets if the previous failure is older than 2× its cooldown.
    ///
    /// Cascading: when H2 fails, H3 cooldown is reduced to 5s remaining.
    /// H1 failures are ignored (H1 is the protocol floor).
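    ///
    /// Illustrative cascade: if H3 is mid-cooldown with, say, 80s remaining and
    /// H2 then fails, the H3 record is rewritten to expire about 5s from now.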
    pub fn record_failure(&self, key: ProtocolCacheKey, protocol: DetectedProtocol) {
        if protocol == DetectedProtocol::H1 {
            return; // H1 is the floor — nothing to suppress
        }
        let mut entry = self.failures.entry(key.clone()).or_default();
        let record = entry.get_mut(protocol);
        let (consecutive, new_cooldown) = match record {
            Some(existing)
                if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
            {
                // Still within the "recent" window — escalate
                let c = existing
                    .consecutive_failures
                    .saturating_add(1)
                    .min(PROTOCOL_FAILURE_ESCALATION_CAP);
                (c, escalate_cooldown(c))
            }
            _ => {
                // First failure or old failure that expired long ago — reset
                (1, PROTOCOL_FAILURE_COOLDOWN)
            }
        };
        *record = Some(FailureRecord {
            failed_at: Instant::now(),
            cooldown: new_cooldown,
            consecutive_failures: consecutive,
        });
        // Cascading: when H2 fails, reduce H3 cooldown to 5s remaining
        if protocol == DetectedProtocol::H2 {
            Self::reduce_cooldown_to(entry.h3.as_mut(), PROTOCOL_FAILURE_COOLDOWN);
        }
        info!(
            host = %key.host, port = %key.port, domain = ?key.requested_host,
            protocol = ?protocol,
            consecutive = consecutive,
            cooldown_secs = new_cooldown.as_secs(),
            "Protocol failure recorded — suppressing for {:?}", new_cooldown
        );
    }

    /// Check whether a protocol is currently suppressed for the given key.
    /// Returns `true` if the protocol failed within its cooldown period.
    /// H1 is never suppressed.
    pub fn is_suppressed(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) -> bool {
        if protocol == DetectedProtocol::H1 {
            return false;
        }
        self.failures
            .get(key)
            .and_then(|entry| {
                entry
                    .get(protocol)
                    .map(|r| r.failed_at.elapsed() < r.cooldown)
            })
            .unwrap_or(false)
    }

    /// Check whether a protocol can be retried (for within-request escalation).
    /// Returns `true` if there's no failure record OR if ≥5s have passed since
    /// the last attempt. More permissive than `is_suppressed`.
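    ///
    /// Contrast with `is_suppressed` (illustrative): right after a third
    /// consecutive failure the cooldown is 20s, so `is_suppressed` stays `true`
    /// for the full 20s, while `can_retry` turns `true` again after only 5s.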
    pub fn can_retry(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) -> bool {
        if protocol == DetectedProtocol::H1 {
            return true;
        }
        match self.failures.get(key) {
            Some(entry) => match entry.get(protocol) {
                Some(r) => r.failed_at.elapsed() >= PROTOCOL_FAILURE_COOLDOWN,
                None => true, // no failure record
            },
            None => true,
        }
    }

    /// Record a retry attempt WITHOUT escalating the cooldown.
    /// Resets the `failed_at` timestamp to prevent rapid retries (5s gate).
    /// Called before an escalation attempt. If the attempt fails,
    /// `record_failure` should be called afterward with proper escalation.
    pub fn record_retry_attempt(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) {
        if protocol == DetectedProtocol::H1 {
            return;
        }
        if let Some(mut entry) = self.failures.get_mut(key) {
            if let Some(r) = entry.get_mut(protocol).as_mut() {
                r.failed_at = Instant::now();
            }
        }
    }

    /// Clear the failure record for a protocol (it recovered).
    /// Called when an escalation retry succeeds.
    pub fn clear_failure(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) {
        if protocol == DetectedProtocol::H1 {
            return;
        }
        if let Some(mut entry) = self.failures.get_mut(key) {
            *entry.get_mut(protocol) = None;
            if entry.is_empty() {
                drop(entry);
                self.failures.remove(key);
            }
        }
    }

    /// Evict a cache entry entirely. Called when all protocol probes (H3, H2, H1)
    /// have failed for a backend.
    pub fn evict(&self, key: &ProtocolCacheKey) {
        self.cache.remove(key);
        self.failures.remove(key);
        info!(
            host = %key.host, port = %key.port, domain = ?key.requested_host,
            "Cache entry evicted — all protocols failed"
        );
    }

    /// Clear all entries. Called on route updates to discard stale detections.
    pub fn clear(&self) {
        self.cache.clear();
        self.failures.clear();
    }

    /// Snapshot all non-expired cache entries for metrics/UI display.
    pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
        self.cache
            .iter()
            .filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
            .map(|entry| {
                let key = entry.key();
                let val = entry.value();
                let failure_info = self.failures.get(key);
                let (h2_sup, h2_cd, h2_cons) =
                    Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
                let (h3_sup, h3_cd, h3_cons) =
                    Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
                ProtocolCacheEntry {
                    host: key.host.clone(),
                    port: key.port,
                    domain: key.requested_host.clone(),
                    protocol: match val.protocol {
                        DetectedProtocol::H1 => "h1".to_string(),
                        DetectedProtocol::H2 => "h2".to_string(),
                        DetectedProtocol::H3 => "h3".to_string(),
                    },
                    h3_port: val.h3_port,
                    age_secs: val.detected_at.elapsed().as_secs(),
                    last_accessed_secs: val.last_accessed_at.elapsed().as_secs(),
                    last_probed_secs: val.last_probed_at.elapsed().as_secs(),
                    h2_suppressed: h2_sup,
                    h3_suppressed: h3_sup,
                    h2_cooldown_remaining_secs: h2_cd,
                    h3_cooldown_remaining_secs: h3_cd,
                    h2_consecutive_failures: h2_cons,
                    h3_consecutive_failures: h3_cons,
                }
            })
            .collect()
    }

    // --- Internal helpers ---

    /// Insert a protocol detection result with an optional H3 port.
    /// Logs protocol transitions when overwriting an existing entry.
    /// No suppression check — callers must check before calling.
    fn insert_internal(
        &self,
        key: ProtocolCacheKey,
        protocol: DetectedProtocol,
        h3_port: Option<u16>,
        reason: &str,
    ) {
        // Check for existing entry to log protocol transitions
        if let Some(existing) = self.cache.get(&key) {
            if existing.protocol != protocol {
                info!(
                    host = %key.host, port = %key.port, domain = ?key.requested_host,
                    old = %existing.protocol, new = %protocol, reason = %reason,
                    "Protocol transition"
                );
            }
            drop(existing);
        }
        // Evict oldest entry if at capacity
        if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
            let oldest = self
                .cache
                .iter()
                .min_by_key(|entry| entry.value().last_accessed_at)
                .map(|entry| entry.key().clone());
            if let Some(oldest_key) = oldest {
                self.cache.remove(&oldest_key);
            }
        }
        let now = Instant::now();
        self.cache.insert(
            key,
            CachedEntry {
                protocol,
                detected_at: now,
                last_accessed_at: now,
                last_probed_at: now,
                h3_port,
            },
        );
    }

    /// Reduce a failure record's remaining cooldown to `target`, if it currently
    /// has MORE than `target` remaining. Never increases cooldown.
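    ///
    /// Worked example: cooldown = 160s with 20s elapsed leaves 140s remaining;
    /// with target = 5s the cooldown is rewritten to 25s (elapsed + target), so
    /// the record now expires 5s from now.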
    fn reduce_cooldown_to(record: Option<&mut FailureRecord>, target: Duration) {
        if let Some(r) = record {
            let elapsed = r.failed_at.elapsed();
            if elapsed < r.cooldown {
                let remaining = r.cooldown - elapsed;
                if remaining > target {
                    // Shrink cooldown so it expires in `target` from now
                    r.cooldown = elapsed + target;
                }
            }
        }
    }

    /// Extract suppression info from a failure record for metrics.
    fn suppression_info(record: Option<&FailureRecord>) -> (bool, Option<u64>, Option<u32>) {
        match record {
            Some(r) => {
                let elapsed = r.failed_at.elapsed();
                let suppressed = elapsed < r.cooldown;
                let remaining = if suppressed {
                    Some((r.cooldown - elapsed).as_secs())
                } else {
                    None
                };
                (suppressed, remaining, Some(r.consecutive_failures))
            }
            None => (false, None, None),
        }
    }

    /// Background cleanup loop.
    async fn cleanup_loop(
        cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>>,
        failures: Arc<DashMap<ProtocolCacheKey, FailureState>>,
    ) {
        let mut interval = tokio::time::interval(PROTOCOL_CACHE_CLEANUP_INTERVAL);
        loop {
            interval.tick().await;
            // Clean expired cache entries (sliding TTL based on last_accessed_at)
            let expired: Vec<ProtocolCacheKey> = cache
                .iter()
                .filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
                .map(|entry| entry.key().clone())
                .collect();
            if !expired.is_empty() {
                debug!(
                    "Protocol cache cleanup: removing {} expired entries",
                    expired.len()
                );
                for key in expired {
                    cache.remove(&key);
                }
            }
            // Clean fully-expired failure entries
            let expired_failures: Vec<ProtocolCacheKey> = failures
                .iter()
                .filter(|entry| entry.value().all_expired())
                .map(|entry| entry.key().clone())
                .collect();
            if !expired_failures.is_empty() {
                debug!(
                    "Protocol cache cleanup: removing {} expired failure entries",
                    expired_failures.len()
                );
                for key in expired_failures {
                    failures.remove(&key);
                }
            }
            // Safety net: if the failures map grows past 2× the cache cap,
            // drop entries whose cooldowns have all expired (never active ones).
            if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
                let removable: Vec<ProtocolCacheKey> = failures
                    .iter()
                    .filter(|e| e.value().all_expired())
                    .map(|e| e.key().clone())
                    .take(failures.len().saturating_sub(PROTOCOL_CACHE_MAX_ENTRIES))
                    .collect();
                for key in removable {
                    failures.remove(&key);
                }
            }
        }
    }
}

impl Drop for ProtocolCache {
    fn drop(&mut self) {
        if let Some(handle) = self.cleanup_handle.take() {
            handle.abort();
        }
    }
}
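
// Illustrative unit tests for the pure helpers above — a sketch added for
// clarity, not part of the original module. They exercise `escalate_cooldown`
// and `reduce_cooldown_to` only, so no tokio runtime is required.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn escalation_doubles_until_capped_at_300s() {
        // 5s × 2^(n-1), capped at 300s: 5, 10, 20, 40, 80, 160, 300.
        let expected = [5u64, 10, 20, 40, 80, 160, 300];
        for (i, &secs) in expected.iter().enumerate() {
            assert_eq!(escalate_cooldown(i as u32 + 1), Duration::from_secs(secs));
        }
        // Beyond the cap the cooldown stays at the 300s ceiling.
        assert_eq!(escalate_cooldown(20), Duration::from_secs(300));
    }

    #[test]
    fn reduce_cooldown_shrinks_but_never_grows() {
        // A fresh record with a long cooldown is shrunk to ~`target` remaining.
        let mut long = FailureRecord {
            failed_at: Instant::now(),
            cooldown: Duration::from_secs(160),
            consecutive_failures: 5,
        };
        ProtocolCache::reduce_cooldown_to(Some(&mut long), Duration::from_secs(5));
        assert!(long.cooldown <= Duration::from_secs(6)); // elapsed ≈ 0, so ≈ 5s
        // A record already below `target` remaining is left untouched.
        let mut short = FailureRecord {
            failed_at: Instant::now(),
            cooldown: Duration::from_secs(2),
            consecutive_failures: 1,
        };
        ProtocolCache::reduce_cooldown_to(Some(&mut short), Duration::from_secs(5));
        assert_eq!(short.cooldown, Duration::from_secs(2));
    }
}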