//! SMTP TCP/TLS server. //! //! Listens on configured ports, accepts connections, and dispatches //! them to per-connection handlers. use crate::config::SmtpServerConfig; use crate::connection::{ self, CallbackRegistry, ConnectionEvent, SmtpStream, }; use crate::rate_limiter::{RateLimitConfig, RateLimiter}; use hickory_resolver::TokioResolver; use mailer_security::MessageAuthenticator; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use std::collections::HashMap; use std::io::BufReader; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::Arc; use tokio::io::BufReader as TokioBufReader; use tokio::net::TcpListener; use tokio::sync::mpsc; use tracing::{error, info, warn}; /// Handle for a running SMTP server. pub struct SmtpServerHandle { /// Shutdown signal. shutdown: Arc, /// Join handles for the listener tasks. handles: Vec>, /// Active connection count. pub active_connections: Arc, } impl SmtpServerHandle { /// Signal shutdown and wait for all listeners to stop. pub async fn shutdown(self) { self.shutdown.store(true, Ordering::SeqCst); for handle in self.handles { let _ = handle.await; } info!("SMTP server shut down"); } /// Check if the server is running. pub fn is_running(&self) -> bool { !self.shutdown.load(Ordering::SeqCst) } } /// Start the SMTP server with the given configuration. /// /// Returns a handle that can be used to shut down the server, /// and an event receiver for connection events (emailReceived, authRequest). pub async fn start_server( config: SmtpServerConfig, callback_registry: Arc, rate_limit_config: Option, ) -> Result<(SmtpServerHandle, mpsc::Receiver), Box> { let config = Arc::new(config); let shutdown = Arc::new(AtomicBool::new(false)); let active_connections = Arc::new(AtomicU32::new(0)); let rate_limiter = Arc::new(RateLimiter::new( rate_limit_config.unwrap_or_default(), )); let (event_tx, event_rx) = mpsc::channel::(1024); // Create shared security resources for in-process email verification let authenticator: Arc = Arc::new( mailer_security::default_authenticator() .map_err(|e| format!("Failed to create MessageAuthenticator: {e}"))? ); let resolver: Arc = Arc::new( TokioResolver::builder_tokio() .map(|b| b.build()) .map_err(|e| format!("Failed to create TokioResolver: {e}"))? ); // Build TLS acceptor if configured let tls_acceptor = if config.has_tls() { Some(Arc::new(build_tls_acceptor(&config)?)) } else { None }; let mut handles = Vec::new(); // Start listeners on each port for &port in &config.ports { let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?; info!(port = port, "SMTP server listening (STARTTLS)"); let handle = tokio::spawn(accept_loop( listener, config.clone(), shutdown.clone(), active_connections.clone(), rate_limiter.clone(), event_tx.clone(), callback_registry.clone(), tls_acceptor.clone(), false, // not implicit TLS authenticator.clone(), resolver.clone(), )); handles.push(handle); } // Start implicit TLS listener if configured if let Some(secure_port) = config.secure_port { if tls_acceptor.is_some() { let listener = TcpListener::bind(format!("0.0.0.0:{secure_port}")).await?; info!(port = secure_port, "SMTP server listening (implicit TLS)"); let handle = tokio::spawn(accept_loop( listener, config.clone(), shutdown.clone(), active_connections.clone(), rate_limiter.clone(), event_tx.clone(), callback_registry.clone(), tls_acceptor.clone(), true, // implicit TLS authenticator.clone(), resolver.clone(), )); handles.push(handle); } else { warn!("Secure port configured but TLS certificates not provided"); } } // Spawn periodic rate limiter cleanup { let rate_limiter = rate_limiter.clone(); let shutdown = shutdown.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(60)); loop { interval.tick().await; if shutdown.load(Ordering::SeqCst) { break; } rate_limiter.cleanup(); } }); } Ok(( SmtpServerHandle { shutdown, handles, active_connections, }, event_rx, )) } /// Accept loop for a single listener. async fn accept_loop( listener: TcpListener, config: Arc, shutdown: Arc, active_connections: Arc, rate_limiter: Arc, event_tx: mpsc::Sender, callback_registry: Arc, tls_acceptor: Option>, implicit_tls: bool, authenticator: Arc, resolver: Arc, ) { loop { if shutdown.load(Ordering::SeqCst) { break; } // Use a short timeout to check shutdown periodically let accept_result = tokio::time::timeout( tokio::time::Duration::from_secs(1), listener.accept(), ) .await; let (tcp_stream, peer_addr) = match accept_result { Ok(Ok((stream, addr))) => (stream, addr), Ok(Err(e)) => { error!(error = %e, "Accept error"); continue; } Err(_) => continue, // timeout, check shutdown }; // Check max connections let current = active_connections.load(Ordering::SeqCst); if current >= config.max_connections { warn!( current = current, max = config.max_connections, "Max connections reached, rejecting" ); drop(tcp_stream); continue; } let remote_addr = peer_addr.ip().to_string(); let config = config.clone(); let rate_limiter = rate_limiter.clone(); let event_tx = event_tx.clone(); let callback_registry = callback_registry.clone(); let tls_acceptor = tls_acceptor.clone(); let active_connections = active_connections.clone(); let authenticator = authenticator.clone(); let resolver = resolver.clone(); active_connections.fetch_add(1, Ordering::SeqCst); tokio::spawn(async move { let stream = if implicit_tls { // Implicit TLS: wrap immediately if let Some(acceptor) = &tls_acceptor { match acceptor.accept(tcp_stream).await { Ok(tls_stream) => { SmtpStream::Tls(TokioBufReader::new(tls_stream)) } Err(e) => { warn!( remote_addr = %remote_addr, error = %e, "Implicit TLS handshake failed" ); active_connections.fetch_sub(1, Ordering::SeqCst); return; } } } else { active_connections.fetch_sub(1, Ordering::SeqCst); return; } } else { SmtpStream::Plain(TokioBufReader::new(tcp_stream)) }; connection::handle_connection( stream, config, rate_limiter, event_tx, callback_registry, tls_acceptor, remote_addr, implicit_tls, authenticator, resolver, ) .await; active_connections.fetch_sub(1, Ordering::SeqCst); }); } } /// SNI-based certificate resolver that selects the appropriate TLS certificate /// based on the client's requested hostname. struct SniCertResolver { /// Domain -> certified key mapping. certs: HashMap>, /// Default certificate for non-matching SNI or missing SNI. default: Arc, } impl std::fmt::Debug for SniCertResolver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SniCertResolver") .field("domains", &self.certs.keys().collect::>()) .finish() } } impl rustls::server::ResolvesServerCert for SniCertResolver { fn resolve( &self, client_hello: rustls::server::ClientHello<'_>, ) -> Option> { if let Some(sni) = client_hello.server_name() { let sni_lower = sni.to_lowercase(); if let Some(key) = self.certs.get(&sni_lower) { return Some(key.clone()); } } Some(self.default.clone()) } } /// Parse a PEM cert+key pair into a `CertifiedKey`. fn parse_certified_key( cert_pem: &str, key_pem: &str, ) -> Result> { let certs: Vec> = { let mut reader = BufReader::new(cert_pem.as_bytes()); rustls_pemfile::certs(&mut reader).collect::, _>>()? }; if certs.is_empty() { return Err("No certificates found in PEM".into()); } let key: PrivateKeyDer<'static> = { let mut reader = BufReader::new(key_pem.as_bytes()); let mut keys = Vec::new(); for item in rustls_pemfile::read_all(&mut reader) { match item? { rustls_pemfile::Item::Pkcs8Key(key) => keys.push(PrivateKeyDer::Pkcs8(key)), rustls_pemfile::Item::Pkcs1Key(key) => keys.push(PrivateKeyDer::Pkcs1(key)), rustls_pemfile::Item::Sec1Key(key) => keys.push(PrivateKeyDer::Sec1(key)), _ => {} } } keys.into_iter().next().ok_or("No private key found in PEM")? }; let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)?; Ok(rustls::sign::CertifiedKey::new(certs, signing_key)) } /// Build a TLS acceptor from PEM cert/key strings. fn build_tls_acceptor( config: &SmtpServerConfig, ) -> Result> { let cert_pem = config .tls_cert_pem .as_ref() .ok_or("TLS cert not configured")?; let key_pem = config .tls_key_pem .as_ref() .ok_or("TLS key not configured")?; // Parse certificates let certs: Vec> = { let mut reader = BufReader::new(cert_pem.as_bytes()); rustls_pemfile::certs(&mut reader) .collect::, _>>()? }; if certs.is_empty() { return Err("No certificates found in PEM".into()); } // Parse private key let key: PrivateKeyDer<'static> = { let mut reader = BufReader::new(key_pem.as_bytes()); // Try PKCS8 first, then RSA, then EC let mut keys = Vec::new(); for item in rustls_pemfile::read_all(&mut reader) { match item? { rustls_pemfile::Item::Pkcs8Key(key) => { keys.push(PrivateKeyDer::Pkcs8(key)); } rustls_pemfile::Item::Pkcs1Key(key) => { keys.push(PrivateKeyDer::Pkcs1(key)); } rustls_pemfile::Item::Sec1Key(key) => { keys.push(PrivateKeyDer::Sec1(key)); } _ => {} } } keys.into_iter() .next() .ok_or("No private key found in PEM")? }; // If additional TLS certs are configured, use SNI-based resolution let tls_config = if config.additional_tls_certs.is_empty() { rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key)? } else { // Build default certified key let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)?; let default_ck = Arc::new(rustls::sign::CertifiedKey::new(certs, signing_key)); // Build per-domain certs let mut domain_certs = HashMap::new(); for domain_cert in &config.additional_tls_certs { match parse_certified_key(&domain_cert.cert_pem, &domain_cert.key_pem) { Ok(ck) => { let ck = Arc::new(ck); for domain in &domain_cert.domains { domain_certs.insert(domain.to_lowercase(), ck.clone()); } info!("SNI cert loaded for domains: {:?}", domain_cert.domains); } Err(e) => { warn!("Failed to load SNI cert for domains {:?}: {}", domain_cert.domains, e); } } } let resolver = SniCertResolver { certs: domain_certs, default: default_ck, }; rustls::ServerConfig::builder() .with_no_client_auth() .with_cert_resolver(Arc::new(resolver)) }; Ok(tokio_rustls::TlsAcceptor::from(Arc::new(tls_config))) } #[cfg(test)] mod tests { use super::*; #[test] fn test_server_config_defaults() { let config = SmtpServerConfig::default(); assert!(!config.has_tls()); assert_eq!(config.ports, vec![25]); } }