feat(tls): add shared TLS acceptor with SNI resolver and session resumption; prefer shared acceptor and fall back to per-connection when routes specify custom TLS versions

This commit is contained in:
2026-02-16 03:00:39 +00:00
parent fa2a27df6d
commit 455d5bb757
4 changed files with 155 additions and 16 deletions

View File

@@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwap;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tracing::{info, error, debug, warn};
use thiserror::Error;
@@ -122,8 +123,10 @@ pub struct TcpListenerManager {
route_manager: Arc<ArcSwap<RouteManager>>,
/// Shared metrics collector
metrics: Arc<MetricsCollector>,
/// TLS acceptors indexed by domain (ArcSwap for hot-reload visibility in accept loops)
/// Raw PEM TLS configs indexed by domain (kept for fallback with custom TLS versions)
tls_configs: Arc<ArcSwap<HashMap<String, TlsCertConfig>>>,
/// Shared TLS acceptor (pre-parsed certs + session cache). None when no certs configured.
shared_tls_acceptor: Arc<ArcSwap<Option<TlsAcceptor>>>,
/// HTTP proxy service for HTTP-level forwarding
http_proxy: Arc<HttpProxyService>,
/// Connection configuration
@@ -154,6 +157,7 @@ impl TcpListenerManager {
route_manager: Arc::new(ArcSwap::from(route_manager)),
metrics,
tls_configs: Arc::new(ArcSwap::from(Arc::new(HashMap::new()))),
shared_tls_acceptor: Arc::new(ArcSwap::from(Arc::new(None))),
http_proxy,
conn_config: Arc::new(conn_config),
conn_tracker,
@@ -179,6 +183,7 @@ impl TcpListenerManager {
route_manager: Arc::new(ArcSwap::from(route_manager)),
metrics,
tls_configs: Arc::new(ArcSwap::from(Arc::new(HashMap::new()))),
shared_tls_acceptor: Arc::new(ArcSwap::from(Arc::new(None))),
http_proxy,
conn_config: Arc::new(conn_config),
conn_tracker,
@@ -197,8 +202,26 @@ impl TcpListenerManager {
}
/// Set TLS certificate configurations.
/// Builds a shared TLS acceptor with pre-parsed certs and session resumption support.
/// Uses ArcSwap so running accept loops immediately see the new certs.
pub fn set_tls_configs(&self, configs: HashMap<String, TlsCertConfig>) {
if !configs.is_empty() {
match tls_handler::CertResolver::new(&configs)
.and_then(tls_handler::build_shared_tls_acceptor)
{
Ok(acceptor) => {
info!("Built shared TLS acceptor for {} domain(s)", configs.len());
self.shared_tls_acceptor.store(Arc::new(Some(acceptor)));
}
Err(e) => {
warn!("Failed to build shared TLS acceptor: {}, falling back to per-connection", e);
self.shared_tls_acceptor.store(Arc::new(None));
}
}
} else {
self.shared_tls_acceptor.store(Arc::new(None));
}
// Keep raw PEM configs for fallback (routes with custom TLS versions)
self.tls_configs.store(Arc::new(configs));
}
@@ -224,6 +247,7 @@ impl TcpListenerManager {
let route_manager_swap = Arc::clone(&self.route_manager);
let metrics = Arc::clone(&self.metrics);
let tls_configs = Arc::clone(&self.tls_configs);
let shared_tls_acceptor = Arc::clone(&self.shared_tls_acceptor);
let http_proxy = Arc::clone(&self.http_proxy);
let conn_config = Arc::clone(&self.conn_config);
let conn_tracker = Arc::clone(&self.conn_tracker);
@@ -233,7 +257,7 @@ impl TcpListenerManager {
let handle = tokio::spawn(async move {
Self::accept_loop(
listener, port, route_manager_swap, metrics, tls_configs,
http_proxy, conn_config, conn_tracker, cancel, relay,
shared_tls_acceptor, http_proxy, conn_config, conn_tracker, cancel, relay,
).await;
});
@@ -322,6 +346,7 @@ impl TcpListenerManager {
route_manager_swap: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
tls_configs: Arc<ArcSwap<HashMap<String, TlsCertConfig>>>,
shared_tls_acceptor: Arc<ArcSwap<Option<TlsAcceptor>>>,
http_proxy: Arc<HttpProxyService>,
conn_config: Arc<ConnectionConfig>,
conn_tracker: Arc<ConnectionTracker>,
@@ -353,6 +378,8 @@ impl TcpListenerManager {
let m = Arc::clone(&metrics);
// Load the latest TLS configs from ArcSwap on each connection
let tc = tls_configs.load_full();
// Load the latest shared TLS acceptor from ArcSwap
let sa = shared_tls_acceptor.load_full();
let hp = Arc::clone(&http_proxy);
let cc = Arc::clone(&conn_config);
let ct = Arc::clone(&conn_tracker);
@@ -362,7 +389,7 @@ impl TcpListenerManager {
tokio::spawn(async move {
let result = Self::handle_connection(
stream, port, peer_addr, rm, m, tc, hp, cc, cn, sr,
stream, port, peer_addr, rm, m, tc, sa, hp, cc, cn, sr,
).await;
if let Err(e) = result {
debug!("Connection error from {}: {}", peer_addr, e);
@@ -388,6 +415,7 @@ impl TcpListenerManager {
route_manager: Arc<RouteManager>,
metrics: Arc<MetricsCollector>,
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
shared_tls_acceptor: Arc<Option<TlsAcceptor>>,
http_proxy: Arc<HttpProxyService>,
conn_config: Arc<ConnectionConfig>,
cancel: CancellationToken,
@@ -777,13 +805,9 @@ impl TcpListenerManager {
Ok(())
}
Some(rustproxy_config::TlsMode::Terminate) => {
let tls_config = Self::find_tls_config(&domain, &tls_configs)?;
// TLS accept with timeout, applying route-level TLS settings
// Use shared acceptor (session resumption) or fall back to per-connection
let route_tls = route_match.route.action.tls.as_ref();
let acceptor = tls_handler::build_tls_acceptor_with_config(
&tls_config.cert_pem, &tls_config.key_pem, route_tls,
)?;
let acceptor = Self::get_tls_acceptor(&domain, &tls_configs, &*shared_tls_acceptor, route_tls)?;
let tls_stream = match tokio::time::timeout(
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
tls_handler::accept_tls(stream, &acceptor),
@@ -846,7 +870,8 @@ impl TcpListenerManager {
let route_tls = route_match.route.action.tls.as_ref();
Self::handle_tls_terminate_reencrypt(
stream, n, &domain, &target_host, target_port,
peer_addr, &tls_configs, Arc::clone(&metrics), route_id, &conn_config, route_tls,
peer_addr, &tls_configs, &shared_tls_acceptor,
Arc::clone(&metrics), route_id, &conn_config, route_tls,
).await
}
None => {
@@ -991,15 +1016,14 @@ impl TcpListenerManager {
target_port: u16,
peer_addr: std::net::SocketAddr,
tls_configs: &HashMap<String, TlsCertConfig>,
shared_tls_acceptor: &Option<TlsAcceptor>,
metrics: Arc<MetricsCollector>,
route_id: Option<&str>,
conn_config: &ConnectionConfig,
route_tls: Option<&rustproxy_config::RouteTls>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let tls_config = Self::find_tls_config(domain, tls_configs)?;
let acceptor = tls_handler::build_tls_acceptor_with_config(
&tls_config.cert_pem, &tls_config.key_pem, route_tls,
)?;
// Use shared acceptor (session resumption) or fall back to per-connection
let acceptor = Self::get_tls_acceptor(domain, tls_configs, shared_tls_acceptor, route_tls)?;
// Accept TLS from client with timeout
let client_tls = match tokio::time::timeout(
@@ -1069,6 +1093,30 @@ impl TcpListenerManager {
Ok(())
}
/// Get a TLS acceptor, preferring the shared one (with session resumption)
/// and falling back to per-connection when custom TLS versions are configured.
fn get_tls_acceptor(
domain: &Option<String>,
tls_configs: &HashMap<String, TlsCertConfig>,
shared_tls_acceptor: &Option<TlsAcceptor>,
route_tls: Option<&rustproxy_config::RouteTls>,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
let has_custom_versions = route_tls
.and_then(|t| t.versions.as_ref())
.map(|v| !v.is_empty())
.unwrap_or(false);
if !has_custom_versions {
if let Some(shared) = shared_tls_acceptor {
return Ok(shared.clone()); // TlsAcceptor wraps Arc<ServerConfig>, clone is cheap
}
}
// Fallback: per-connection acceptor (custom TLS versions or shared build failed)
let tls_config = Self::find_tls_config(domain, tls_configs)?;
tls_handler::build_tls_acceptor_with_config(&tls_config.cert_pem, &tls_config.key_pem, route_tls)
}
/// Find the TLS config for a given domain.
fn find_tls_config<'a>(
domain: &Option<String>,

View File

@@ -1,17 +1,99 @@
use std::collections::HashMap;
use std::io::BufReader;
use std::sync::Arc;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
use tracing::debug;
use tracing::{debug, info};
use crate::tcp_listener::TlsCertConfig;
/// Ensure the default crypto provider is installed.
fn ensure_crypto_provider() {
let _ = rustls::crypto::ring::default_provider().install_default();
}
/// SNI-based certificate resolver with pre-parsed CertifiedKeys.
/// Enables shared ServerConfig across connections — avoids per-connection PEM parsing
/// and enables TLS session resumption.
#[derive(Debug)]
pub struct CertResolver {
certs: HashMap<String, Arc<CertifiedKey>>,
fallback: Option<Arc<CertifiedKey>>,
}
impl CertResolver {
/// Build a resolver from PEM-encoded cert/key configs.
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let provider = rustls::crypto::ring::default_provider();
let mut certs = HashMap::new();
let mut fallback = None;
for (domain, cfg) in configs {
let cert_chain = load_certs(&cfg.cert_pem)?;
let key = load_private_key(&cfg.key_pem)?;
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
if domain == "*" {
fallback = Some(Arc::clone(&ck));
}
certs.insert(domain.clone(), ck);
}
// If no explicit "*" fallback, use the first available cert
if fallback.is_none() {
fallback = certs.values().next().map(Arc::clone);
}
Ok(Self { certs, fallback })
}
}
impl ResolvesServerCert for CertResolver {
fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
let domain = match client_hello.server_name() {
Some(name) => name,
None => return self.fallback.clone(),
};
// Exact match
if let Some(ck) = self.certs.get(domain) {
return Some(Arc::clone(ck));
}
// Wildcard: sub.example.com → *.example.com
if let Some(dot) = domain.find('.') {
let wc = format!("*.{}", &domain[dot + 1..]);
if let Some(ck) = self.certs.get(&wc) {
return Some(Arc::clone(ck));
}
}
self.fallback.clone()
}
}
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
/// The returned acceptor can be reused across all connections (cheap Arc clone).
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver));
// Shared session cache — enables session ID resumption across connections
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
config.ticketer = rustls::crypto::ring::Ticketer::new()
.map_err(|e| format!("Ticketer: {}", e))?;
info!("Built shared TLS config with session cache (4096) and ticket support");
Ok(TlsAcceptor::from(Arc::new(config)))
}
/// Build a TLS acceptor from PEM-encoded cert and key data.
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
build_tls_acceptor_with_config(cert_pem, key_pem, None)