140 lines
4.2 KiB
Rust
140 lines
4.2 KiB
Rust
|
|
use std::collections::HashMap;
|
||
|
|
use std::sync::{Arc, RwLock};
|
||
|
|
|
||
|
|
use crate::cert_store::CertBundle;
|
||
|
|
|
||
|
|
/// Dynamic SNI-based certificate resolver.
|
||
|
|
/// Used by the TLS stack to select the right certificate based on client SNI.
|
||
|
|
pub struct SniResolver {
|
||
|
|
/// Domain -> certificate bundle mapping
|
||
|
|
certs: RwLock<HashMap<String, Arc<CertBundle>>>,
|
||
|
|
/// Fallback certificate (used when no SNI or no match)
|
||
|
|
fallback: RwLock<Option<Arc<CertBundle>>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl SniResolver {
|
||
|
|
pub fn new() -> Self {
|
||
|
|
Self {
|
||
|
|
certs: RwLock::new(HashMap::new()),
|
||
|
|
fallback: RwLock::new(None),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Register a certificate for a domain.
|
||
|
|
pub fn add_cert(&self, domain: String, bundle: CertBundle) {
|
||
|
|
let mut certs = self.certs.write().unwrap();
|
||
|
|
certs.insert(domain, Arc::new(bundle));
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Set the fallback certificate.
|
||
|
|
pub fn set_fallback(&self, bundle: CertBundle) {
|
||
|
|
let mut fallback = self.fallback.write().unwrap();
|
||
|
|
*fallback = Some(Arc::new(bundle));
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Resolve a certificate for the given SNI domain.
|
||
|
|
pub fn resolve(&self, domain: &str) -> Option<Arc<CertBundle>> {
|
||
|
|
let certs = self.certs.read().unwrap();
|
||
|
|
|
||
|
|
// Try exact match
|
||
|
|
if let Some(bundle) = certs.get(domain) {
|
||
|
|
return Some(Arc::clone(bundle));
|
||
|
|
}
|
||
|
|
|
||
|
|
// Try wildcard match (e.g., *.example.com)
|
||
|
|
if let Some(dot_pos) = domain.find('.') {
|
||
|
|
let wildcard = format!("*.{}", &domain[dot_pos + 1..]);
|
||
|
|
if let Some(bundle) = certs.get(&wildcard) {
|
||
|
|
return Some(Arc::clone(bundle));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Fallback
|
||
|
|
let fallback = self.fallback.read().unwrap();
|
||
|
|
fallback.clone()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Remove a certificate for a domain.
|
||
|
|
pub fn remove_cert(&self, domain: &str) {
|
||
|
|
let mut certs = self.certs.write().unwrap();
|
||
|
|
certs.remove(domain);
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get the number of registered certificates.
|
||
|
|
pub fn cert_count(&self) -> usize {
|
||
|
|
self.certs.read().unwrap().len()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for SniResolver {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self::new()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cert_store::{CertBundle, CertMetadata, CertSource};

    /// Build a dummy bundle whose PEM fields are tagged with `domain`, so a
    /// test can tell which bundle a lookup returned.
    fn make_bundle(domain: &str) -> CertBundle {
        CertBundle {
            key_pem: format!("KEY-{}", domain),
            cert_pem: format!("CERT-{}", domain),
            ca_pem: None,
            metadata: CertMetadata {
                domain: domain.to_string(),
                source: CertSource::Static,
                issued_at: 0,
                expires_at: 0,
                renewed_at: None,
            },
        }
    }

    /// The `cert_pem` of whatever `resolve` returned, for compact assertions.
    fn resolved_pem(resolver: &SniResolver, domain: &str) -> Option<String> {
        resolver
            .resolve(domain)
            .map(|bundle| bundle.cert_pem.clone())
    }

    #[test]
    fn test_exact_domain_resolve() {
        let resolver = SniResolver::new();
        resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
        assert_eq!(
            resolved_pem(&resolver, "example.com").as_deref(),
            Some("CERT-example.com")
        );
    }

    #[test]
    fn test_wildcard_resolve() {
        let resolver = SniResolver::new();
        resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com"));
        assert_eq!(
            resolved_pem(&resolver, "sub.example.com").as_deref(),
            Some("CERT-*.example.com")
        );
    }

    #[test]
    fn test_exact_match_beats_wildcard() {
        let resolver = SniResolver::new();
        resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com"));
        resolver.add_cert("www.example.com".to_string(), make_bundle("www.example.com"));
        assert_eq!(
            resolved_pem(&resolver, "www.example.com").as_deref(),
            Some("CERT-www.example.com")
        );
        assert_eq!(
            resolved_pem(&resolver, "api.example.com").as_deref(),
            Some("CERT-*.example.com")
        );
    }

    #[test]
    fn test_wildcard_covers_only_one_label() {
        let resolver = SniResolver::new();
        resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com"));
        // Neither the apex nor a deeper subdomain is covered by `*.example.com`.
        assert!(resolver.resolve("example.com").is_none());
        assert!(resolver.resolve("a.b.example.com").is_none());
    }

    #[test]
    fn test_fallback() {
        let resolver = SniResolver::new();
        resolver.set_fallback(make_bundle("fallback"));
        assert_eq!(
            resolved_pem(&resolver, "unknown.com").as_deref(),
            Some("CERT-fallback")
        );
    }

    #[test]
    fn test_exact_match_beats_fallback() {
        let resolver = SniResolver::new();
        resolver.set_fallback(make_bundle("fallback"));
        resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
        assert_eq!(
            resolved_pem(&resolver, "example.com").as_deref(),
            Some("CERT-example.com")
        );
        assert_eq!(
            resolved_pem(&resolver, "other.com").as_deref(),
            Some("CERT-fallback")
        );
    }

    #[test]
    fn test_no_match_no_fallback() {
        let resolver = SniResolver::new();
        resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
        assert!(resolver.resolve("other.com").is_none());
    }

    #[test]
    fn test_add_cert_replaces_existing() {
        let resolver = SniResolver::new();
        resolver.add_cert("example.com".to_string(), make_bundle("old"));
        resolver.add_cert("example.com".to_string(), make_bundle("new"));
        assert_eq!(resolver.cert_count(), 1);
        assert_eq!(
            resolved_pem(&resolver, "example.com").as_deref(),
            Some("CERT-new")
        );
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniResolver::new();
        resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
        assert_eq!(resolver.cert_count(), 1);
        resolver.remove_cert("example.com");
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.resolve("example.com").is_none());
    }

    #[test]
    fn test_remove_keeps_fallback_and_ignores_unknown() {
        let resolver = SniResolver::new();
        resolver.set_fallback(make_bundle("fallback"));
        resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
        // Removing a domain that was never registered is a no-op.
        resolver.remove_cert("missing.com");
        assert_eq!(resolver.cert_count(), 1);
        resolver.remove_cert("example.com");
        assert_eq!(resolver.cert_count(), 0);
        assert_eq!(
            resolved_pem(&resolver, "example.com").as_deref(),
            Some("CERT-fallback")
        );
    }
}
|