Files
smartproxy/rust/crates/rustproxy-passthrough/src/tls_handler.rs

273 lines
10 KiB
Rust

use std::collections::HashMap;
use std::io::BufReader;
use std::sync::Arc;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
use tracing::{debug, info};
use crate::tcp_listener::TlsCertConfig;
/// Install the `ring` crypto provider as the rustls process-wide default.
///
/// The outcome of `install_default` is discarded on purpose, so this helper can be
/// called before every config build without ever failing at the call site.
fn ensure_crypto_provider() {
    let ring_provider = rustls::crypto::ring::default_provider();
    // NOTE(review): an `Err` here presumably just means a default provider was
    // already installed earlier in the process — confirm against the pinned rustls.
    let _ = ring_provider.install_default();
}
/// SNI-based certificate resolver with pre-parsed CertifiedKeys.
/// Enables shared ServerConfig across connections — avoids per-connection PEM parsing
/// and enables TLS session resumption.
///
/// Built once via `CertResolver::new`; afterwards it is only read, through the
/// `ResolvesServerCert` implementation.
#[derive(Debug)]
pub struct CertResolver {
/// Certificates keyed by the domain string used in the configuration map.
/// Looked up by the exact SNI name first, then by a `*.<parent>` wildcard key.
certs: HashMap<String, Arc<CertifiedKey>>,
/// Certificate returned when the client sends no SNI or when no key matches.
/// `None` only if no certificates were configured at all.
fallback: Option<Arc<CertifiedKey>>,
}
impl CertResolver {
    /// Build a resolver from PEM-encoded cert/key configs, keyed by domain.
    ///
    /// All PEM data is parsed up front, so resolving a certificate during a
    /// handshake only needs map lookups.
    ///
    /// Fallback selection:
    /// * the entry whose key is exactly `"*"`, if present;
    /// * otherwise the entry with the lexicographically smallest domain key.
    ///   Picking by key (rather than by `HashMap` iteration order) keeps the
    ///   default certificate the same from one process start to the next.
    ///
    /// # Errors
    /// Returns an error if any entry's certificate chain or private key cannot be
    /// loaded, or if `CertifiedKey::from_der` rejects the pair (that message is
    /// prefixed with the offending domain).
    pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        ensure_crypto_provider();
        let provider = rustls::crypto::ring::default_provider();

        let mut certs = HashMap::with_capacity(configs.len());
        let mut fallback = None;
        for (domain, cfg) in configs {
            let cert_chain = load_certs(&cfg.cert_pem)?;
            let key = load_private_key(&cfg.key_pem)?;
            let ck = Arc::new(
                CertifiedKey::from_der(cert_chain, key, &provider)
                    .map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
            );
            if domain == "*" {
                fallback = Some(Arc::clone(&ck));
            }
            certs.insert(domain.clone(), ck);
        }

        // No explicit "*" entry: choose a stable default instead of whatever the
        // hash map happens to yield first.
        if fallback.is_none() {
            fallback = certs
                .iter()
                .min_by(|(a, _), (b, _)| a.cmp(b))
                .map(|(_, ck)| Arc::clone(ck));
        }

        Ok(Self { certs, fallback })
    }
}
impl ResolvesServerCert for CertResolver {
    /// Choose the certificate for a handshake based on the client's SNI value.
    ///
    /// Lookup order:
    /// 1. the SNI name as an exact key;
    /// 2. `*.<rest>` where `<rest>` is everything after the first `.` of the SNI
    ///    name (so `sub.example.com` tries `*.example.com`);
    /// 3. the fallback certificate.
    ///
    /// A client that sends no SNI goes straight to the fallback. The result is
    /// `None` only when there is no fallback either.
    fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let matched = client_hello.server_name().and_then(|sni| {
            self.certs.get(sni).or_else(|| {
                // Names without a dot have no wildcard parent to try.
                let (_label, parent) = sni.split_once('.')?;
                self.certs.get(&format!("*.{}", parent))
            })
        });

        matched.cloned().or_else(|| self.fallback.clone())
    }
}
/// Build one TLS acceptor that serves every configured domain via `resolver`.
///
/// The server config does not request client certificates, stores sessions in a
/// 4096-entry in-memory cache, and issues session tickets through the `ring`
/// ticketer. The resulting acceptor wraps the config in an `Arc`, so it can be
/// cloned and shared across connections.
///
/// # Errors
/// Returns an error (prefixed with `Ticketer:`) if the ticketer cannot be created.
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    ensure_crypto_provider();

    let cert_resolver = Arc::new(resolver);
    let mut server_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(cert_resolver);

    // Session-ID resumption: one cache shared by all connections using this config.
    let session_cache = rustls::server::ServerSessionMemoryCache::new(4096);
    server_config.session_storage = session_cache;

    // Ticket-based resumption. NOTE(review): the original comment states a 12-hour
    // lifetime with ChaCha20-Poly1305 encryption — that is a property of the rustls
    // ticketer, not something set here; confirm against the pinned rustls version.
    let ticketer = rustls::crypto::ring::Ticketer::new()
        .map_err(|e| format!("Ticketer: {}", e))?;
    server_config.ticketer = ticketer;

    info!("Built shared TLS config with session cache (4096) and ticket support");
    Ok(TlsAcceptor::from(Arc::new(server_config)))
}
/// Build a TLS acceptor from PEM-encoded cert and key data.
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
build_tls_acceptor_with_config(cert_pem, key_pem, None)
}
/// Build a single-certificate TLS acceptor, optionally tuned by a `RouteTls`.
///
/// * `cert_pem` / `key_pem` — PEM text for the certificate chain and private key.
/// * `tls_config` — when `Some`, its `versions` list restricts the protocol
///   versions offered (see `resolve_tls_versions`); when `None`, the rustls
///   default builder is used.
///
/// If `tls_config` carries a `session_timeout`, the session store is replaced by a
/// 256-entry in-memory cache and the value is logged.
///
/// NOTE(review): the timeout value itself is only logged — nothing in this function
/// hands it to rustls, so it does not bound session lifetime. No cipher-suite
/// tuning happens here either.
///
/// # Errors
/// Returns an error if the PEM data cannot be loaded or if rustls rejects the
/// certificate/key pair.
pub fn build_tls_acceptor_with_config(
    cert_pem: &str,
    key_pem: &str,
    tls_config: Option<&rustproxy_config::RouteTls>,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    ensure_crypto_provider();
    let chain = load_certs(cert_pem)?;
    let private_key = load_private_key(key_pem)?;

    // Pick the builder first; both arms then share the same finishing steps.
    let builder = match tls_config {
        Some(route_tls) => {
            let allowed = resolve_tls_versions(route_tls.versions.as_deref());
            ServerConfig::builder_with_protocol_versions(&allowed)
        }
        None => ServerConfig::builder(),
    };
    let mut server_config = builder
        .with_no_client_auth()
        .with_single_cert(chain, private_key)?;

    if let Some(timeout_secs) = tls_config.and_then(|route_tls| route_tls.session_timeout) {
        // 256 is the maximum number of cached sessions.
        server_config.session_storage = rustls::server::ServerSessionMemoryCache::new(256);
        debug!("TLS session timeout configured: {}s", timeout_secs);
    }

    Ok(TlsAcceptor::from(Arc::new(server_config)))
}
/// Translate configured TLS version labels into rustls protocol versions.
///
/// Accepted labels (case-sensitive): `TLSv1.2`, `TLS1.2`, `1.2`, `TLSv12` and the
/// matching `1.3` spellings. Unrecognised labels are logged at debug level and
/// skipped; duplicates are collapsed while keeping first-seen order.
///
/// Returns both TLS 1.2 and TLS 1.3 when `versions` is `None`, is empty, or
/// contains no recognised label.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
    let both = || vec![&rustls::version::TLS12, &rustls::version::TLS13];

    let requested = match versions {
        Some(list) if !list.is_empty() => list,
        _ => return both(),
    };

    let mut selected: Vec<&'static rustls::SupportedProtocolVersion> = Vec::new();
    for label in requested {
        let parsed: &'static rustls::SupportedProtocolVersion = match label.as_str() {
            "TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => &rustls::version::TLS12,
            "TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => &rustls::version::TLS13,
            other => {
                debug!("Unknown TLS version '{}', ignoring", other);
                continue;
            }
        };
        if !selected.contains(&parsed) {
            selected.push(parsed);
        }
    }

    // Nothing usable was configured: fall back to offering both versions.
    if selected.is_empty() {
        both()
    } else {
        selected
    }
}
/// Run the server-side TLS handshake on an accepted client TCP stream.
///
/// # Errors
/// Returns the handshake error (boxed) if `acceptor.accept` fails.
pub async fn accept_tls(
    stream: TcpStream,
    acceptor: &TlsAcceptor,
) -> Result<ServerTlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
    match acceptor.accept(stream).await {
        Ok(session) => {
            debug!("TLS handshake completed");
            Ok(session)
        }
        Err(err) => Err(err.into()),
    }
}
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
pub async fn connect_tls(
host: &str,
port: u16,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
stream.set_nodelay(true)?;
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
let tls_stream = connector.connect(server_name, stream).await?;
debug!("Backend TLS connection established to {}:{}", host, port);
Ok(tls_stream)
}
/// Parse every certificate contained in a PEM string.
///
/// # Errors
/// Returns the first parse error reported by `rustls_pemfile::certs`, or
/// `"No certificates found in PEM data"` when the input holds no certificate.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
    let mut pem_reader = BufReader::new(pem.as_bytes());

    let mut chain: Vec<CertificateDer<'static>> = Vec::new();
    for entry in rustls_pemfile::certs(&mut pem_reader) {
        // Stop at the first malformed entry.
        chain.push(entry?);
    }

    if chain.is_empty() {
        Err("No certificates found in PEM data".into())
    } else {
        Ok(chain)
    }
}
/// Parse a private key from a PEM string.
///
/// # Errors
/// Returns the read/parse error from `rustls_pemfile::private_key`, or
/// `"No private key found in PEM data"` when the input holds no private key.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
    let mut pem_reader = BufReader::new(pem.as_bytes());
    // A single call handles key-format detection; there is no per-format retry here.
    // NOTE(review): presumably this yields the first private-key section in the
    // input, whatever its format — confirm in the rustls_pemfile docs.
    match rustls_pemfile::private_key(&mut pem_reader)? {
        Some(key) => Ok(key),
        None => Err("No private key found in PEM data".into()),
    }
}
/// Insecure certificate verifier for backend connections (terminate-and-reencrypt).
/// In internal networks, backends may use self-signed certs.
///
/// SECURITY: its `ServerCertVerifier` implementation approves every certificate
/// and every handshake signature without checking anything, so it provides no
/// protection against an impersonated backend.
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}