use std::collections::HashMap; use std::io::BufReader; use std::sync::Arc; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::server::ResolvesServerCert; use rustls::sign::CertifiedKey; use rustls::ServerConfig; use tokio::net::TcpStream; use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream}; use tracing::{debug, info}; use crate::tcp_listener::TlsCertConfig; /// Ensure the default crypto provider is installed. fn ensure_crypto_provider() { let _ = rustls::crypto::ring::default_provider().install_default(); } /// SNI-based certificate resolver with pre-parsed CertifiedKeys. /// Enables shared ServerConfig across connections — avoids per-connection PEM parsing /// and enables TLS session resumption. #[derive(Debug)] pub struct CertResolver { certs: HashMap>, fallback: Option>, } impl CertResolver { /// Build a resolver from PEM-encoded cert/key configs. /// Parses all PEM data upfront so connections only do a cheap HashMap lookup. pub fn new(configs: &HashMap) -> Result> { ensure_crypto_provider(); let provider = rustls::crypto::ring::default_provider(); let mut certs = HashMap::new(); let mut fallback = None; for (domain, cfg) in configs { let cert_chain = load_certs(&cfg.cert_pem)?; let key = load_private_key(&cfg.key_pem)?; let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider) .map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?); if domain == "*" { fallback = Some(Arc::clone(&ck)); } certs.insert(domain.clone(), ck); } // If no explicit "*" fallback, use the first available cert if fallback.is_none() { fallback = certs.values().next().map(Arc::clone); } Ok(Self { certs, fallback }) } } impl ResolvesServerCert for CertResolver { fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option> { let domain = match client_hello.server_name() { Some(name) => name, None => return self.fallback.clone(), }; // Exact match if let Some(ck) = self.certs.get(domain) { return Some(Arc::clone(ck)); } // Wildcard: sub.example.com → *.example.com if let Some(dot) = domain.find('.') { let wc = format!("*.{}", &domain[dot + 1..]); if let Some(ck) = self.certs.get(&wc) { return Some(Arc::clone(ck)); } } self.fallback.clone() } } /// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets. /// The returned acceptor can be reused across all connections (cheap Arc clone). pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result> { ensure_crypto_provider(); let mut config = ServerConfig::builder() .with_no_client_auth() .with_cert_resolver(Arc::new(resolver)); // Shared session cache — enables session ID resumption across connections config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096); // Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted) config.ticketer = rustls::crypto::ring::Ticketer::new() .map_err(|e| format!("Ticketer: {}", e))?; info!("Built shared TLS config with session cache (4096) and ticket support"); Ok(TlsAcceptor::from(Arc::new(config))) } /// Build a TLS acceptor from PEM-encoded cert and key data. pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result> { build_tls_acceptor_with_config(cert_pem, key_pem, None) } /// Build a TLS acceptor with optional RouteTls configuration for version/cipher tuning. pub fn build_tls_acceptor_with_config( cert_pem: &str, key_pem: &str, tls_config: Option<&rustproxy_config::RouteTls>, ) -> Result> { ensure_crypto_provider(); let certs = load_certs(cert_pem)?; let key = load_private_key(key_pem)?; let mut config = if let Some(route_tls) = tls_config { // Apply TLS version restrictions let versions = resolve_tls_versions(route_tls.versions.as_deref()); let builder = ServerConfig::builder_with_protocol_versions(&versions); builder .with_no_client_auth() .with_single_cert(certs, key)? } else { ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key)? }; // Apply session timeout if configured if let Some(route_tls) = tls_config { if let Some(timeout_secs) = route_tls.session_timeout { config.session_storage = rustls::server::ServerSessionMemoryCache::new( 256, // max sessions ); debug!("TLS session timeout configured: {}s", timeout_secs); } } Ok(TlsAcceptor::from(Arc::new(config))) } /// Resolve TLS version strings to rustls SupportedProtocolVersion. fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> { let versions = match versions { Some(v) if !v.is_empty() => v, _ => return vec![&rustls::version::TLS12, &rustls::version::TLS13], }; let mut result = Vec::new(); for v in versions { match v.as_str() { "TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => { if !result.contains(&&rustls::version::TLS12) { result.push(&rustls::version::TLS12); } } "TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => { if !result.contains(&&rustls::version::TLS13) { result.push(&rustls::version::TLS13); } } other => { debug!("Unknown TLS version '{}', ignoring", other); } } } if result.is_empty() { // Fallback to both if no valid versions specified vec![&rustls::version::TLS12, &rustls::version::TLS13] } else { result } } /// Accept a TLS connection from a client stream. pub async fn accept_tls( stream: TcpStream, acceptor: &TlsAcceptor, ) -> Result, Box> { let tls_stream = acceptor.accept(stream).await?; debug!("TLS handshake completed"); Ok(tls_stream) } /// Connect to a backend with TLS (for terminate-and-reencrypt mode). pub async fn connect_tls( host: &str, port: u16, ) -> Result, Box> { ensure_crypto_provider(); let config = rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(InsecureVerifier)) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(config)); let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; stream.set_nodelay(true)?; let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?; let tls_stream = connector.connect(server_name, stream).await?; debug!("Backend TLS connection established to {}:{}", host, port); Ok(tls_stream) } /// Load certificates from PEM string. fn load_certs(pem: &str) -> Result>, Box> { let mut reader = BufReader::new(pem.as_bytes()); let certs: Vec> = rustls_pemfile::certs(&mut reader) .collect::, _>>()?; if certs.is_empty() { return Err("No certificates found in PEM data".into()); } Ok(certs) } /// Load private key from PEM string. fn load_private_key(pem: &str) -> Result, Box> { let mut reader = BufReader::new(pem.as_bytes()); // Try PKCS8 first, then RSA, then EC let key = rustls_pemfile::private_key(&mut reader)? .ok_or("No private key found in PEM data")?; Ok(key) } /// Insecure certificate verifier for backend connections (terminate-and-reencrypt). /// In internal networks, backends may use self-signed certs. #[derive(Debug)] struct InsecureVerifier; impl rustls::client::danger::ServerCertVerifier for InsecureVerifier { fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, ] } }