use std::collections::HashMap; use std::sync::{Arc, RwLock}; use crate::cert_store::CertBundle; /// Dynamic SNI-based certificate resolver. /// Used by the TLS stack to select the right certificate based on client SNI. pub struct SniResolver { /// Domain -> certificate bundle mapping certs: RwLock>>, /// Fallback certificate (used when no SNI or no match) fallback: RwLock>>, } impl SniResolver { pub fn new() -> Self { Self { certs: RwLock::new(HashMap::new()), fallback: RwLock::new(None), } } /// Register a certificate for a domain. pub fn add_cert(&self, domain: String, bundle: CertBundle) { let mut certs = self.certs.write().unwrap(); certs.insert(domain, Arc::new(bundle)); } /// Set the fallback certificate. pub fn set_fallback(&self, bundle: CertBundle) { let mut fallback = self.fallback.write().unwrap(); *fallback = Some(Arc::new(bundle)); } /// Resolve a certificate for the given SNI domain. pub fn resolve(&self, domain: &str) -> Option> { let certs = self.certs.read().unwrap(); // Try exact match if let Some(bundle) = certs.get(domain) { return Some(Arc::clone(bundle)); } // Try wildcard match (e.g., *.example.com) if let Some(dot_pos) = domain.find('.') { let wildcard = format!("*.{}", &domain[dot_pos + 1..]); if let Some(bundle) = certs.get(&wildcard) { return Some(Arc::clone(bundle)); } } // Fallback let fallback = self.fallback.read().unwrap(); fallback.clone() } /// Remove a certificate for a domain. pub fn remove_cert(&self, domain: &str) { let mut certs = self.certs.write().unwrap(); certs.remove(domain); } /// Get the number of registered certificates. pub fn cert_count(&self) -> usize { self.certs.read().unwrap().len() } } impl Default for SniResolver { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; use crate::cert_store::{CertBundle, CertMetadata, CertSource}; fn make_bundle(domain: &str) -> CertBundle { CertBundle { key_pem: format!("KEY-{}", domain), cert_pem: format!("CERT-{}", domain), ca_pem: None, metadata: CertMetadata { domain: domain.to_string(), source: CertSource::Static, issued_at: 0, expires_at: 0, renewed_at: None, }, } } #[test] fn test_exact_domain_resolve() { let resolver = SniResolver::new(); resolver.add_cert("example.com".to_string(), make_bundle("example.com")); let result = resolver.resolve("example.com"); assert!(result.is_some()); assert_eq!(result.unwrap().cert_pem, "CERT-example.com"); } #[test] fn test_wildcard_resolve() { let resolver = SniResolver::new(); resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com")); let result = resolver.resolve("sub.example.com"); assert!(result.is_some()); assert_eq!(result.unwrap().cert_pem, "CERT-*.example.com"); } #[test] fn test_fallback() { let resolver = SniResolver::new(); resolver.set_fallback(make_bundle("fallback")); let result = resolver.resolve("unknown.com"); assert!(result.is_some()); assert_eq!(result.unwrap().cert_pem, "CERT-fallback"); } #[test] fn test_no_match_no_fallback() { let resolver = SniResolver::new(); resolver.add_cert("example.com".to_string(), make_bundle("example.com")); let result = resolver.resolve("other.com"); assert!(result.is_none()); } #[test] fn test_remove_cert() { let resolver = SniResolver::new(); resolver.add_cert("example.com".to_string(), make_bundle("example.com")); assert_eq!(resolver.cert_count(), 1); resolver.remove_cert("example.com"); assert_eq!(resolver.cert_count(), 0); assert!(resolver.resolve("example.com").is_none()); } }