//! TLS helpers: SNI-based certificate resolution, TLS acceptor construction,
//! and TLS connections to backends.
use std::collections::HashMap;
|
|
use std::io::BufReader;
|
|
use std::sync::Arc;
|
|
|
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
|
use rustls::server::ResolvesServerCert;
|
|
use rustls::sign::CertifiedKey;
|
|
use rustls::ServerConfig;
|
|
use tokio::net::TcpStream;
|
|
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
|
|
use tracing::{debug, info};
|
|
|
|
use crate::tcp_listener::TlsCertConfig;
|
|
|
|
/// Ensure the default crypto provider is installed.
|
|
fn ensure_crypto_provider() {
|
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
|
}
|
|
|
|
/// SNI-based certificate resolver with pre-parsed CertifiedKeys.
|
|
/// Enables shared ServerConfig across connections — avoids per-connection PEM parsing
|
|
/// and enables TLS session resumption.
|
|
#[derive(Debug)]
|
|
pub struct CertResolver {
|
|
certs: HashMap<String, Arc<CertifiedKey>>,
|
|
fallback: Option<Arc<CertifiedKey>>,
|
|
}
|
|
|
|
impl CertResolver {
|
|
/// Build a resolver from PEM-encoded cert/key configs.
|
|
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
|
|
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
|
ensure_crypto_provider();
|
|
let provider = rustls::crypto::ring::default_provider();
|
|
let mut certs = HashMap::new();
|
|
let mut fallback = None;
|
|
|
|
for (domain, cfg) in configs {
|
|
let cert_chain = load_certs(&cfg.cert_pem)?;
|
|
let key = load_private_key(&cfg.key_pem)?;
|
|
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
|
|
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
|
|
if domain == "*" {
|
|
fallback = Some(Arc::clone(&ck));
|
|
}
|
|
certs.insert(domain.clone(), ck);
|
|
}
|
|
|
|
// If no explicit "*" fallback, use the first available cert
|
|
if fallback.is_none() {
|
|
fallback = certs.values().next().map(Arc::clone);
|
|
}
|
|
|
|
Ok(Self { certs, fallback })
|
|
}
|
|
}
|
|
|
|
impl ResolvesServerCert for CertResolver {
|
|
fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
|
|
let domain = match client_hello.server_name() {
|
|
Some(name) => name,
|
|
None => return self.fallback.clone(),
|
|
};
|
|
// Exact match
|
|
if let Some(ck) = self.certs.get(domain) {
|
|
return Some(Arc::clone(ck));
|
|
}
|
|
// Wildcard: sub.example.com → *.example.com
|
|
if let Some(dot) = domain.find('.') {
|
|
let wc = format!("*.{}", &domain[dot + 1..]);
|
|
if let Some(ck) = self.certs.get(&wc) {
|
|
return Some(Arc::clone(ck));
|
|
}
|
|
}
|
|
self.fallback.clone()
|
|
}
|
|
}
|
|
|
|
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
|
|
/// The returned acceptor can be reused across all connections (cheap Arc clone).
|
|
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
|
ensure_crypto_provider();
|
|
let mut config = ServerConfig::builder()
|
|
.with_no_client_auth()
|
|
.with_cert_resolver(Arc::new(resolver));
|
|
|
|
// Shared session cache — enables session ID resumption across connections
|
|
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
|
|
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
|
|
config.ticketer = rustls::crypto::ring::Ticketer::new()
|
|
.map_err(|e| format!("Ticketer: {}", e))?;
|
|
|
|
info!("Built shared TLS config with session cache (4096) and ticket support");
|
|
Ok(TlsAcceptor::from(Arc::new(config)))
|
|
}
|
|
|
|
/// Build a TLS acceptor from PEM-encoded cert and key data.
|
|
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
|
build_tls_acceptor_with_config(cert_pem, key_pem, None)
|
|
}
|
|
|
|
/// Build a TLS acceptor with optional RouteTls configuration for version/cipher tuning.
|
|
pub fn build_tls_acceptor_with_config(
|
|
cert_pem: &str,
|
|
key_pem: &str,
|
|
tls_config: Option<&rustproxy_config::RouteTls>,
|
|
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
|
ensure_crypto_provider();
|
|
let certs = load_certs(cert_pem)?;
|
|
let key = load_private_key(key_pem)?;
|
|
|
|
let mut config = if let Some(route_tls) = tls_config {
|
|
// Apply TLS version restrictions
|
|
let versions = resolve_tls_versions(route_tls.versions.as_deref());
|
|
let builder = ServerConfig::builder_with_protocol_versions(&versions);
|
|
builder
|
|
.with_no_client_auth()
|
|
.with_single_cert(certs, key)?
|
|
} else {
|
|
ServerConfig::builder()
|
|
.with_no_client_auth()
|
|
.with_single_cert(certs, key)?
|
|
};
|
|
|
|
// Apply session timeout if configured
|
|
if let Some(route_tls) = tls_config {
|
|
if let Some(timeout_secs) = route_tls.session_timeout {
|
|
config.session_storage = rustls::server::ServerSessionMemoryCache::new(
|
|
256, // max sessions
|
|
);
|
|
debug!("TLS session timeout configured: {}s", timeout_secs);
|
|
}
|
|
}
|
|
|
|
Ok(TlsAcceptor::from(Arc::new(config)))
|
|
}
|
|
|
|
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
|
|
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
|
let versions = match versions {
|
|
Some(v) if !v.is_empty() => v,
|
|
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
|
|
};
|
|
|
|
let mut result = Vec::new();
|
|
for v in versions {
|
|
match v.as_str() {
|
|
"TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => {
|
|
if !result.contains(&&rustls::version::TLS12) {
|
|
result.push(&rustls::version::TLS12);
|
|
}
|
|
}
|
|
"TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => {
|
|
if !result.contains(&&rustls::version::TLS13) {
|
|
result.push(&rustls::version::TLS13);
|
|
}
|
|
}
|
|
other => {
|
|
debug!("Unknown TLS version '{}', ignoring", other);
|
|
}
|
|
}
|
|
}
|
|
|
|
if result.is_empty() {
|
|
// Fallback to both if no valid versions specified
|
|
vec![&rustls::version::TLS12, &rustls::version::TLS13]
|
|
} else {
|
|
result
|
|
}
|
|
}
|
|
|
|
/// Accept a TLS connection from a client stream.
|
|
pub async fn accept_tls(
|
|
stream: TcpStream,
|
|
acceptor: &TlsAcceptor,
|
|
) -> Result<ServerTlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
|
|
let tls_stream = acceptor.accept(stream).await?;
|
|
debug!("TLS handshake completed");
|
|
Ok(tls_stream)
|
|
}
|
|
|
|
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
|
|
pub async fn connect_tls(
|
|
host: &str,
|
|
port: u16,
|
|
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
|
|
ensure_crypto_provider();
|
|
let config = rustls::ClientConfig::builder()
|
|
.dangerous()
|
|
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
|
.with_no_client_auth();
|
|
|
|
let connector = TlsConnector::from(Arc::new(config));
|
|
|
|
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
|
stream.set_nodelay(true)?;
|
|
|
|
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
|
|
let tls_stream = connector.connect(server_name, stream).await?;
|
|
debug!("Backend TLS connection established to {}:{}", host, port);
|
|
Ok(tls_stream)
|
|
}
|
|
|
|
/// Load certificates from PEM string.
|
|
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
|
let mut reader = BufReader::new(pem.as_bytes());
|
|
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
if certs.is_empty() {
|
|
return Err("No certificates found in PEM data".into());
|
|
}
|
|
Ok(certs)
|
|
}
|
|
|
|
/// Load private key from PEM string.
|
|
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
|
let mut reader = BufReader::new(pem.as_bytes());
|
|
// Try PKCS8 first, then RSA, then EC
|
|
let key = rustls_pemfile::private_key(&mut reader)?
|
|
.ok_or("No private key found in PEM data")?;
|
|
Ok(key)
|
|
}
|
|
|
|
/// Insecure certificate verifier for backend connections (terminate-and-reencrypt).
|
|
/// In internal networks, backends may use self-signed certs.
|
|
#[derive(Debug)]
|
|
struct InsecureVerifier;
|
|
|
|
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
|
|
fn verify_server_cert(
|
|
&self,
|
|
_end_entity: &CertificateDer<'_>,
|
|
_intermediates: &[CertificateDer<'_>],
|
|
_server_name: &rustls::pki_types::ServerName<'_>,
|
|
_ocsp_response: &[u8],
|
|
_now: rustls::pki_types::UnixTime,
|
|
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
|
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
|
}
|
|
|
|
fn verify_tls12_signature(
|
|
&self,
|
|
_message: &[u8],
|
|
_cert: &CertificateDer<'_>,
|
|
_dss: &rustls::DigitallySignedStruct,
|
|
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
|
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
|
}
|
|
|
|
fn verify_tls13_signature(
|
|
&self,
|
|
_message: &[u8],
|
|
_cert: &CertificateDer<'_>,
|
|
_dss: &rustls::DigitallySignedStruct,
|
|
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
|
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
|
}
|
|
|
|
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
|
vec![
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
|
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
|
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
|
rustls::SignatureScheme::ED25519,
|
|
rustls::SignatureScheme::RSA_PSS_SHA256,
|
|
rustls::SignatureScheme::RSA_PSS_SHA384,
|
|
rustls::SignatureScheme::RSA_PSS_SHA512,
|
|
]
|
|
}
|
|
}
|