feat(mailer-smtp): add SCRAM-SHA-256 auth, Ed25519 DKIM, opportunistic TLS, SNI cert selection, pipelining and delivery/bridge improvements

This commit is contained in:
2026-02-11 10:11:43 +00:00
parent 7908cbaefa
commit b10597fd5e
28 changed files with 1849 additions and 153 deletions

View File

@@ -23,3 +23,6 @@ rustls = { version = "0.23", default-features = false, features = ["ring", "logg
rustls-pemfile = "2"
mailparse.workspace = true
webpki-roots = "0.26"
sha2.workspace = true
hmac.workspace = true
pbkdf2.workspace = true

View File

@@ -37,6 +37,12 @@ pub struct SmtpClientConfig {
/// Maximum connections per pool. Default: 10.
#[serde(default = "default_max_pool_connections")]
pub max_pool_connections: usize,
/// Accept invalid TLS certificates (expired, self-signed, wrong hostname).
/// Standard for MTA-to-MTA opportunistic TLS per RFC 7435.
/// Default: false.
#[serde(default)]
pub tls_opportunistic: bool,
}
/// Authentication configuration.
@@ -60,8 +66,15 @@ pub struct DkimSignConfig {
pub domain: String,
/// DKIM selector (e.g. "default" or "mta").
pub selector: String,
/// PEM-encoded RSA private key.
/// PEM-encoded private key (RSA or Ed25519 PKCS#8).
pub private_key: String,
/// Key type: "rsa" (default) or "ed25519".
#[serde(default = "default_key_type")]
pub key_type: String,
}
fn default_key_type() -> String {
"rsa".to_string()
}
impl SmtpClientConfig {

View File

@@ -117,6 +117,7 @@ pub async fn connect_tls(
host: &str,
port: u16,
timeout_secs: u64,
tls_opportunistic: bool,
) -> Result<ClientSmtpStream, SmtpClientError> {
debug!("Connecting to {}:{} (implicit TLS)", host, port);
let addr = format!("{host}:{port}");
@@ -130,7 +131,7 @@ pub async fn connect_tls(
message: format!("Failed to connect to {addr}: {e}"),
})?;
let tls_stream = perform_tls_handshake(tcp_stream, host).await?;
let tls_stream = perform_tls_handshake(tcp_stream, host, tls_opportunistic).await?;
Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream)))
}
@@ -138,24 +139,77 @@ pub async fn connect_tls(
pub async fn upgrade_to_tls(
stream: ClientSmtpStream,
hostname: &str,
tls_opportunistic: bool,
) -> Result<ClientSmtpStream, SmtpClientError> {
debug!("Upgrading connection to TLS (STARTTLS) for {}", hostname);
let tcp_stream = stream.into_tcp_stream()?;
let tls_stream = perform_tls_handshake(tcp_stream, hostname).await?;
let tls_stream = perform_tls_handshake(tcp_stream, hostname, tls_opportunistic).await?;
Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream)))
}
/// A TLS certificate verifier that accepts any certificate.
/// Used for MTA-to-MTA opportunistic TLS per RFC 7435.
#[derive(Debug)]
struct OpportunisticVerifier;
impl rustls::client::danger::ServerCertVerifier for OpportunisticVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer<'_>,
_intermediates: &[rustls_pki_types::CertificateDer<'_>],
_server_name: &rustls_pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// Perform the TLS handshake on a TCP stream using webpki-roots.
/// When `tls_opportunistic` is true, certificate verification is skipped
/// (standard for MTA-to-MTA delivery per RFC 7435).
async fn perform_tls_handshake(
tcp_stream: TcpStream,
hostname: &str,
tls_opportunistic: bool,
) -> Result<TlsStream<TcpStream>, SmtpClientError> {
let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let tls_config = if tls_opportunistic {
debug!("Using opportunistic TLS (no cert verification) for {}", hostname);
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(OpportunisticVerifier))
.with_no_client_auth()
} else {
let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth()
};
let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config));
let server_name = rustls_pki_types::ServerName::try_from(hostname.to_string()).map_err(|e| {
@@ -190,7 +244,7 @@ mod tests {
#[tokio::test]
async fn test_connect_tls_refused() {
let result = connect_tls("127.0.0.1", 19998, 2).await;
let result = connect_tls("127.0.0.1", 19998, 2, false).await;
assert!(result.is_err());
}

View File

@@ -116,6 +116,7 @@ impl ConnectionPool {
&self.config.host,
self.config.port,
self.config.connection_timeout_secs,
self.config.tls_opportunistic,
)
.await?
} else {
@@ -139,7 +140,7 @@ impl ConnectionPool {
if !self.config.secure && capabilities.starttls {
protocol::send_starttls(&mut stream, self.config.socket_timeout_secs).await?;
stream =
super::connection::upgrade_to_tls(stream, &self.config.host).await?;
super::connection::upgrade_to_tls(stream, &self.config.host, self.config.tls_opportunistic).await?;
// Re-EHLO after STARTTLS — use updated capabilities for auth
capabilities = protocol::send_ehlo(
@@ -244,9 +245,10 @@ impl SmtpClientManager {
protocol::send_rset(&mut conn.stream, config.socket_timeout_secs).await?;
}
// Perform the SMTP transaction
// Perform the SMTP transaction (use pipelining if server supports it)
let pipelining = conn.capabilities.pipelining;
let result =
Self::perform_send(&mut conn.stream, sender, recipients, message, config).await;
Self::perform_send(&mut conn.stream, sender, recipients, message, config, pipelining).await;
// Re-acquire the pool lock and release the connection
let mut pool = pool_arc.lock().await;
@@ -268,30 +270,39 @@ impl SmtpClientManager {
recipients: &[String],
message: &[u8],
config: &SmtpClientConfig,
pipelining: bool,
) -> Result<SmtpSendResult, SmtpClientError> {
let timeout_secs = config.socket_timeout_secs;
// MAIL FROM
protocol::send_mail_from(stream, sender, timeout_secs).await?;
let (accepted, rejected) = if pipelining {
// Use pipelined envelope: MAIL FROM + all RCPT TO in one batch
let (_mail_ok, acc, rej) = protocol::send_pipelined_envelope(
stream, sender, recipients, timeout_secs,
).await?;
(acc, rej)
} else {
// Sequential: MAIL FROM, then each RCPT TO
protocol::send_mail_from(stream, sender, timeout_secs).await?;
// RCPT TO for each recipient
let mut accepted = Vec::new();
let mut rejected = Vec::new();
let mut accepted = Vec::new();
let mut rejected = Vec::new();
for rcpt in recipients {
match protocol::send_rcpt_to(stream, rcpt, timeout_secs).await {
Ok(resp) => {
if resp.is_success() {
accepted.push(rcpt.clone());
} else {
for rcpt in recipients {
match protocol::send_rcpt_to(stream, rcpt, timeout_secs).await {
Ok(resp) => {
if resp.is_success() {
accepted.push(rcpt.clone());
} else {
rejected.push(rcpt.clone());
}
}
Err(_) => {
rejected.push(rcpt.clone());
}
}
Err(_) => {
rejected.push(rcpt.clone());
}
}
}
(accepted, rejected)
};
// If no recipients were accepted, fail
if accepted.is_empty() {
@@ -339,6 +350,7 @@ impl SmtpClientManager {
&config.host,
config.port,
config.connection_timeout_secs,
config.tls_opportunistic,
)
.await?
} else {

View File

@@ -318,6 +318,54 @@ pub async fn send_rcpt_to(
Ok(resp)
}
/// Send MAIL FROM + RCPT TO commands in a single pipelined batch.
///
/// Writes all envelope commands at once, then reads responses in order.
/// Returns `(mail_from_ok, accepted_recipients, rejected_recipients)`.
pub async fn send_pipelined_envelope(
stream: &mut ClientSmtpStream,
sender: &str,
recipients: &[String],
timeout_secs: u64,
) -> Result<(bool, Vec<String>, Vec<String>), SmtpClientError> {
// Build the full pipelined command batch
let mut batch = format!("MAIL FROM:<{sender}>\r\n");
for rcpt in recipients {
batch.push_str(&format!("RCPT TO:<{rcpt}>\r\n"));
}
// Send all commands at once
debug!("SMTP C (pipelined): MAIL FROM + {} RCPT TO", recipients.len());
stream.write_all(batch.as_bytes()).await?;
stream.flush().await?;
// Read MAIL FROM response
let mail_resp = read_response(stream, timeout_secs).await?;
if !mail_resp.is_success() {
return Err(mail_resp.to_error());
}
// Read RCPT TO responses
let mut accepted = Vec::new();
let mut rejected = Vec::new();
for rcpt in recipients {
match read_response(stream, timeout_secs).await {
Ok(resp) => {
if resp.is_success() {
accepted.push(rcpt.clone());
} else {
rejected.push(rcpt.clone());
}
}
Err(_) => {
rejected.push(rcpt.clone());
}
}
}
Ok((true, accepted, rejected))
}
/// Send DATA command, followed by the message body with dot-stuffing.
pub async fn send_data(
stream: &mut ClientSmtpStream,

View File

@@ -50,6 +50,7 @@ pub enum SmtpCommand {
pub enum AuthMechanism {
Plain,
Login,
ScramSha256,
}
/// Errors that can occur during command parsing.
@@ -218,6 +219,7 @@ fn parse_auth(rest: &str) -> Result<SmtpCommand, ParseError> {
let mechanism = match mech_str.to_ascii_uppercase().as_str() {
"PLAIN" => AuthMechanism::Plain,
"LOGIN" => AuthMechanism::Login,
"SCRAM-SHA-256" => AuthMechanism::ScramSha256,
other => {
return Err(ParseError::SyntaxError(format!(
"unsupported AUTH mechanism: {other}"

View File

@@ -2,6 +2,17 @@
use serde::{Deserialize, Serialize};
/// Per-domain TLS certificate for SNI-based cert selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsDomainCert {
/// Domain names this certificate covers (matched against SNI hostname).
pub domains: Vec<String>,
/// Certificate chain in PEM format.
pub cert_pem: String,
/// Private key in PEM format.
pub key_pem: String,
}
/// Configuration for an SMTP server instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmtpServerConfig {
@@ -11,10 +22,13 @@ pub struct SmtpServerConfig {
pub ports: Vec<u16>,
/// Port for implicit TLS (e.g. 465). None = no implicit TLS port.
pub secure_port: Option<u16>,
/// TLS certificate chain in PEM format.
/// TLS certificate chain in PEM format (default cert).
pub tls_cert_pem: Option<String>,
/// TLS private key in PEM format.
/// TLS private key in PEM format (default key).
pub tls_key_pem: Option<String>,
/// Additional per-domain TLS certificates for SNI-based selection.
#[serde(default)]
pub additional_tls_certs: Vec<TlsDomainCert>,
/// Maximum message size in bytes.
pub max_message_size: u64,
/// Maximum number of concurrent connections.
@@ -43,6 +57,7 @@ impl Default for SmtpServerConfig {
secure_port: None,
tls_cert_pem: None,
tls_key_pem: None,
additional_tls_certs: Vec::new(),
max_message_size: 10 * 1024 * 1024, // 10 MB
max_connections: 100,
max_recipients: 100,

View File

@@ -9,6 +9,7 @@ use crate::config::SmtpServerConfig;
use crate::data::{DataAccumulator, DataAction};
use crate::rate_limiter::RateLimiter;
use crate::response::{build_capabilities, SmtpResponse};
use crate::scram::{ScramCredentials, ScramServer};
use crate::session::{AuthState, SmtpSession};
use crate::validation;
@@ -52,6 +53,13 @@ pub enum ConnectionEvent {
password: String,
remote_addr: String,
},
/// A SCRAM credential request — Rust needs stored credentials from TS.
ScramCredentialRequest {
correlation_id: String,
session_id: String,
username: String,
remote_addr: String,
},
}
/// How email data is transported from Rust to TS.
@@ -81,6 +89,16 @@ pub struct AuthResult {
pub message: Option<String>,
}
/// Result of TS returning SCRAM credentials for a user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScramCredentialResult {
pub found: bool,
pub salt: Option<Vec<u8>>,
pub iterations: Option<u32>,
pub stored_key: Option<Vec<u8>>,
pub server_key: Option<Vec<u8>>,
}
/// Abstraction over plain and TLS streams.
pub enum SmtpStream {
Plain(BufReader<TcpStream>),
@@ -133,6 +151,14 @@ impl SmtpStream {
}
}
/// Check if the internal buffer has unread data (pipelined commands).
pub fn has_buffered_data(&self) -> bool {
match self {
SmtpStream::Plain(reader) => !reader.buffer().is_empty(),
SmtpStream::Tls(reader) => !reader.buffer().is_empty(),
}
}
/// Unwrap to get the raw TcpStream for STARTTLS upgrade.
/// Only works on Plain streams.
pub fn into_tcp_stream(self) -> Option<TcpStream> {
@@ -212,7 +238,7 @@ pub async fn handle_connection(
break;
}
Ok(Ok(_)) => {
// Process command
// Process the first command
let response = process_line(
&line,
&mut session,
@@ -227,59 +253,123 @@ pub async fn handle_connection(
)
.await;
// Check for pipelined commands in the buffer.
// Collect pipelinable responses into a batch for single write.
let mut response_batch: Vec<u8> = Vec::new();
let mut should_break = false;
let mut starttls_signal = false;
match response {
LineResult::Response(resp) => {
if stream.write_all(&resp.to_bytes()).await.is_err() {
break;
}
if stream.flush().await.is_err() {
break;
}
response_batch.extend_from_slice(&resp.to_bytes());
}
LineResult::Quit(resp) => {
let _ = stream.write_all(&resp.to_bytes()).await;
let _ = stream.flush().await;
break;
should_break = true;
}
LineResult::StartTlsSignal => {
// Send 220 Ready response
let resp = SmtpResponse::new(220, "Ready to start TLS");
if stream.write_all(&resp.to_bytes()).await.is_err() {
break;
}
if stream.flush().await.is_err() {
break;
}
// Extract TCP stream and upgrade
if let Some(tcp_stream) = stream.into_tcp_stream() {
if let Some(acceptor) = &tls_acceptor {
match acceptor.accept(tcp_stream).await {
Ok(tls_stream) => {
stream = SmtpStream::Tls(BufReader::new(tls_stream));
session.secure = true;
// Client must re-EHLO after STARTTLS
session.state = crate::state::SmtpState::Connected;
session.client_hostname = None;
session.esmtp = false;
session.auth_state = AuthState::None;
session.envelope = Default::default();
debug!(session_id = %session.id, "TLS upgrade successful");
}
Err(e) => {
warn!(session_id = %session.id, error = %e, "TLS handshake failed");
break;
}
}
} else {
break;
}
} else {
// Already TLS — shouldn't happen
break;
}
starttls_signal = true;
}
LineResult::NoResponse => {}
LineResult::Disconnect => {
should_break = true;
}
}
if should_break {
break;
}
// Process additional pipelined commands from the buffer
if !starttls_signal {
while stream.has_buffered_data() {
let mut next_line = String::new();
match stream.read_line(&mut next_line, 4096).await {
Ok(0) | Err(_) => break,
Ok(_) => {
let next_response = process_line(
&next_line,
&mut session,
&mut stream,
&config,
&rate_limiter,
&event_tx,
callback_register.as_ref(),
&tls_acceptor,
&authenticator,
&resolver,
)
.await;
match next_response {
LineResult::Response(resp) => {
response_batch.extend_from_slice(&resp.to_bytes());
}
LineResult::Quit(resp) => {
response_batch.extend_from_slice(&resp.to_bytes());
should_break = true;
break;
}
LineResult::StartTlsSignal | LineResult::Disconnect => {
// Non-pipelinable: flush batch and handle
starttls_signal = matches!(next_response, LineResult::StartTlsSignal);
should_break = matches!(next_response, LineResult::Disconnect);
break;
}
LineResult::NoResponse => {}
}
}
}
}
}
// Flush the accumulated response batch in one write
if !response_batch.is_empty() {
if stream.write_all(&response_batch).await.is_err() {
break;
}
if stream.flush().await.is_err() {
break;
}
}
if should_break {
break;
}
if starttls_signal {
// Send 220 Ready response
let resp = SmtpResponse::new(220, "Ready to start TLS");
if stream.write_all(&resp.to_bytes()).await.is_err() {
break;
}
if stream.flush().await.is_err() {
break;
}
// Extract TCP stream and upgrade
if let Some(tcp_stream) = stream.into_tcp_stream() {
if let Some(acceptor) = &tls_acceptor {
match acceptor.accept(tcp_stream).await {
Ok(tls_stream) => {
stream = SmtpStream::Tls(BufReader::new(tls_stream));
session.secure = true;
session.state = crate::state::SmtpState::Connected;
session.client_hostname = None;
session.esmtp = false;
session.auth_state = AuthState::None;
session.envelope = Default::default();
debug!(session_id = %session.id, "TLS upgrade successful");
}
Err(e) => {
warn!(session_id = %session.id, error = %e, "TLS handshake failed");
break;
}
}
} else {
break;
}
} else {
break;
}
}
@@ -322,6 +412,12 @@ pub trait CallbackRegistry: Send + Sync {
&self,
correlation_id: &str,
) -> oneshot::Receiver<AuthResult>;
/// Register a callback for SCRAM credential lookup and return a receiver.
fn register_scram_callback(
&self,
correlation_id: &str,
) -> oneshot::Receiver<ScramCredentialResult>;
}
/// Process a single input line from the client.
@@ -406,16 +502,29 @@ async fn process_line(
mechanism,
initial_response,
} => {
handle_auth(
mechanism,
initial_response,
session,
config,
rate_limiter,
event_tx,
callback_registry,
)
.await
if matches!(mechanism, AuthMechanism::ScramSha256) {
handle_auth_scram(
initial_response,
session,
stream,
config,
rate_limiter,
event_tx,
callback_registry,
)
.await
} else {
handle_auth(
mechanism,
initial_response,
session,
config,
rate_limiter,
event_tx,
callback_registry,
)
.await
}
}
SmtpCommand::Help(_) => {
@@ -832,6 +941,217 @@ async fn handle_auth(
))
}
}
AuthMechanism::ScramSha256 => {
// SCRAM is handled separately in process_line; this should not be reached.
LineResult::Response(SmtpResponse::not_implemented())
}
}
}
/// Handle AUTH SCRAM-SHA-256 — full exchange in a single async function.
///
/// SCRAM is a multi-step challenge-response protocol:
/// 1. Client sends client-first-message (in initial_response or after 334)
/// 2. Server requests SCRAM credentials from TS
/// 3. Server sends server-first-message (334 challenge)
/// 4. Client sends client-final-message (proof)
/// 5. Server verifies proof and responds with 235 or 535
async fn handle_auth_scram(
initial_response: Option<String>,
session: &mut SmtpSession,
stream: &mut SmtpStream,
config: &SmtpServerConfig,
rate_limiter: &RateLimiter,
event_tx: &mpsc::Sender<ConnectionEvent>,
callback_registry: &dyn CallbackRegistry,
) -> LineResult {
if !config.auth_enabled {
return LineResult::Response(SmtpResponse::not_implemented());
}
if session.is_authenticated() {
return LineResult::Response(SmtpResponse::bad_sequence("Already authenticated"));
}
if !session.state.can_auth() {
return LineResult::Response(SmtpResponse::bad_sequence("Send EHLO first"));
}
// Step 1: Get client-first-message
let client_first_b64 = match initial_response {
Some(s) if !s.is_empty() => s,
_ => {
// No initial response — send empty 334 challenge
let resp = SmtpResponse::auth_challenge("");
if stream.write_all(&resp.to_bytes()).await.is_err() {
return LineResult::Disconnect;
}
if stream.flush().await.is_err() {
return LineResult::Disconnect;
}
// Read client-first-message
let mut line = String::new();
let socket_timeout = Duration::from_secs(config.socket_timeout_secs);
match timeout(socket_timeout, stream.read_line(&mut line, 4096)).await {
Err(_) | Ok(Err(_)) | Ok(Ok(0)) => return LineResult::Disconnect,
Ok(Ok(_)) => {}
}
let trimmed = line.trim().to_string();
if trimmed == "*" {
return LineResult::Response(SmtpResponse::new(501, "Authentication cancelled"));
}
trimmed
}
};
// Decode base64 client-first-message
let client_first_bytes = match BASE64.decode(client_first_b64.as_bytes()) {
Ok(b) => b,
Err(_) => {
return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
}
};
let client_first = match String::from_utf8(client_first_bytes) {
Ok(s) => s,
Err(_) => {
return LineResult::Response(SmtpResponse::param_error("Invalid UTF-8 in SCRAM message"));
}
};
// Parse client-first-message
let mut scram = match ScramServer::from_client_first(&client_first) {
Ok(s) => s,
Err(e) => {
debug!(error = %e, "SCRAM client-first-message parse error");
return LineResult::Response(SmtpResponse::param_error(
"Invalid SCRAM client-first-message",
));
}
};
// Step 2: Request SCRAM credentials from TS
let correlation_id = uuid::Uuid::new_v4().to_string();
let rx = callback_registry.register_scram_callback(&correlation_id);
let event = ConnectionEvent::ScramCredentialRequest {
correlation_id: correlation_id.clone(),
session_id: session.id.clone(),
username: scram.username.clone(),
remote_addr: session.remote_addr.clone(),
};
if event_tx.send(event).await.is_err() {
return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
}
// Wait for credentials from TS
let cred_timeout = Duration::from_secs(5);
let cred_result = match timeout(cred_timeout, rx).await {
Ok(Ok(result)) => result,
Ok(Err(_)) => {
warn!(correlation_id = %correlation_id, "SCRAM credential callback dropped");
return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
}
Err(_) => {
warn!(correlation_id = %correlation_id, "SCRAM credential request timed out");
return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
}
};
if !cred_result.found {
// User not found — fail auth (don't reveal that user doesn't exist)
session.auth_state = AuthState::None;
let exceeded = session.record_auth_failure(config.max_auth_failures);
if exceeded {
return LineResult::Quit(SmtpResponse::service_unavailable(
&config.hostname,
"Too many authentication failures",
));
}
return LineResult::Response(SmtpResponse::auth_failed());
}
let creds = ScramCredentials {
salt: cred_result.salt.unwrap_or_default(),
iterations: cred_result.iterations.unwrap_or(4096),
stored_key: cred_result.stored_key.unwrap_or_default(),
server_key: cred_result.server_key.unwrap_or_default(),
};
// Step 3: Generate and send server-first-message
let server_first = scram.server_first_message(creds);
let server_first_b64 = BASE64.encode(server_first.as_bytes());
let challenge = SmtpResponse::auth_challenge(&server_first_b64);
if stream.write_all(&challenge.to_bytes()).await.is_err() {
return LineResult::Disconnect;
}
if stream.flush().await.is_err() {
return LineResult::Disconnect;
}
// Step 4: Read client-final-message
let mut client_final_line = String::new();
let socket_timeout = Duration::from_secs(config.socket_timeout_secs);
match timeout(socket_timeout, stream.read_line(&mut client_final_line, 4096)).await {
Err(_) | Ok(Err(_)) | Ok(Ok(0)) => return LineResult::Disconnect,
Ok(Ok(_)) => {}
}
let client_final_b64 = client_final_line.trim();
// Cancel if *
if client_final_b64 == "*" {
session.auth_state = AuthState::None;
return LineResult::Response(SmtpResponse::new(501, "Authentication cancelled"));
}
// Decode base64 client-final-message
let client_final_bytes = match BASE64.decode(client_final_b64.as_bytes()) {
Ok(b) => b,
Err(_) => {
session.auth_state = AuthState::None;
return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
}
};
let client_final = match String::from_utf8(client_final_bytes) {
Ok(s) => s,
Err(_) => {
session.auth_state = AuthState::None;
return LineResult::Response(SmtpResponse::param_error("Invalid UTF-8 in SCRAM message"));
}
};
// Step 5: Verify proof
match scram.process_client_final(&client_final) {
Ok(server_final) => {
let server_final_b64 = BASE64.encode(server_final.as_bytes());
session.auth_state = AuthState::Authenticated {
username: scram.username.clone(),
};
LineResult::Response(SmtpResponse::new(
235,
format!("2.7.0 Authentication successful {}", server_final_b64),
))
}
Err(e) => {
debug!(error = %e, "SCRAM proof verification failed");
session.auth_state = AuthState::None;
let exceeded = session.record_auth_failure(config.max_auth_failures);
if exceeded {
if !rate_limiter.check_auth_failure(&session.remote_addr) {
return LineResult::Quit(SmtpResponse::service_unavailable(
&config.hostname,
"Too many authentication failures",
));
}
return LineResult::Quit(SmtpResponse::service_unavailable(
&config.hostname,
"Too many authentication failures",
));
}
LineResult::Response(SmtpResponse::auth_failed())
}
}
}

View File

@@ -19,6 +19,7 @@ pub mod connection;
pub mod data;
pub mod rate_limiter;
pub mod response;
pub mod scram;
pub mod server;
pub mod session;
pub mod state;

View File

@@ -196,7 +196,7 @@ pub fn build_capabilities(
caps.push("STARTTLS".to_string());
}
if auth_available {
caps.push("AUTH PLAIN LOGIN".to_string());
caps.push("AUTH PLAIN LOGIN SCRAM-SHA-256".to_string());
}
caps
}
@@ -253,7 +253,7 @@ mod tests {
let caps = build_capabilities(10485760, true, false, true);
assert!(caps.contains(&"SIZE 10485760".to_string()));
assert!(caps.contains(&"STARTTLS".to_string()));
assert!(caps.contains(&"AUTH PLAIN LOGIN".to_string()));
assert!(caps.contains(&"AUTH PLAIN LOGIN SCRAM-SHA-256".to_string()));
assert!(caps.contains(&"PIPELINING".to_string()));
}
@@ -262,7 +262,7 @@ mod tests {
// When already secure, STARTTLS should NOT be advertised
let caps = build_capabilities(10485760, true, true, false);
assert!(!caps.contains(&"STARTTLS".to_string()));
assert!(!caps.contains(&"AUTH PLAIN LOGIN".to_string()));
assert!(!caps.contains(&"AUTH PLAIN LOGIN SCRAM-SHA-256".to_string()));
}
#[test]

View File

@@ -0,0 +1,342 @@
//! SCRAM-SHA-256 server-side implementation (RFC 5802 + RFC 7677).
//!
//! Implements the server side of the SCRAM-SHA-256 SASL mechanism,
//! a challenge-response protocol that avoids transmitting cleartext passwords.
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
type HmacSha256 = Hmac<Sha256>;
/// Pre-computed SCRAM credentials for a user (derived from password).
#[derive(Debug, Clone)]
pub struct ScramCredentials {
pub salt: Vec<u8>,
pub iterations: u32,
pub stored_key: Vec<u8>,
pub server_key: Vec<u8>,
}
/// Server-side SCRAM state machine.
pub struct ScramServer {
/// Username extracted from client-first-message.
pub username: String,
/// Full combined nonce (client + server).
combined_nonce: String,
/// Server nonce portion (used in tests for verification).
#[allow(dead_code)]
server_nonce: String,
/// Stored credentials (set after TS responds).
credentials: Option<ScramCredentials>,
/// The server-first-message (for auth message construction).
server_first: String,
/// The client-first-message-bare (for auth message construction).
client_first_bare: String,
}
impl ScramServer {
/// Process the client-first-message.
///
/// Parses the client nonce and username, generates a server nonce,
/// and returns a partial state that needs credentials to produce the
/// server-first-message.
pub fn from_client_first(client_first: &str) -> Result<Self, String> {
// client-first-message = gs2-header client-first-message-bare
// gs2-header = "n,," (no channel binding)
// client-first-message-bare = "n=username,r=nonce"
let bare = if let Some(rest) = client_first.strip_prefix("n,,") {
rest
} else if let Some(rest) = client_first.strip_prefix("y,,") {
rest
} else {
return Err("Invalid SCRAM gs2-header".into());
};
let mut username = String::new();
let mut client_nonce = String::new();
for part in bare.split(',') {
if let Some(val) = part.strip_prefix("n=") {
username = val.to_string();
} else if let Some(val) = part.strip_prefix("r=") {
client_nonce = val.to_string();
}
}
if username.is_empty() || client_nonce.is_empty() {
return Err("Missing username or nonce in client-first-message".into());
}
// Generate server nonce
let server_nonce: String = (0..24)
.map(|_| {
let idx = (rand_byte() as usize) % 62;
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"[idx] as char
})
.collect();
let combined_nonce = format!("{}{}", client_nonce, server_nonce);
Ok(ScramServer {
username,
combined_nonce,
server_nonce,
credentials: None,
server_first: String::new(),
client_first_bare: bare.to_string(),
})
}
/// Set the credentials and produce the server-first-message.
pub fn server_first_message(&mut self, creds: ScramCredentials) -> String {
let salt_b64 = BASE64.encode(&creds.salt);
let server_first = format!(
"r={},s={},i={}",
self.combined_nonce, salt_b64, creds.iterations
);
self.server_first = server_first.clone();
self.credentials = Some(creds);
server_first
}
/// Process the client-final-message and verify the proof.
///
/// Returns the server-final-message (containing ServerSignature) on success,
/// or an error string on failure.
pub fn process_client_final(&mut self, client_final: &str) -> Result<String, String> {
let creds = self.credentials.as_ref().ok_or("No credentials set")?;
// Parse client-final-message
// Format: c=biws,r=<combined_nonce>,p=<client_proof>
let mut channel_binding = String::new();
let mut nonce = String::new();
let mut proof_b64 = String::new();
for part in client_final.split(',') {
if let Some(val) = part.strip_prefix("c=") {
channel_binding = val.to_string();
} else if let Some(val) = part.strip_prefix("r=") {
nonce = val.to_string();
} else if let Some(val) = part.strip_prefix("p=") {
proof_b64 = val.to_string();
}
}
// Verify nonce matches
if nonce != self.combined_nonce {
return Err("Nonce mismatch".into());
}
// Build the client-final-message-without-proof
let client_final_without_proof = format!("c={},r={}", channel_binding, nonce);
// Complete the auth message
let auth_message = format!(
"{},{},{}",
self.client_first_bare, self.server_first, client_final_without_proof
);
// Verify client proof
let client_proof = BASE64.decode(proof_b64.as_bytes())
.map_err(|_| "Invalid base64 in client proof")?;
// ClientSignature = HMAC(StoredKey, AuthMessage)
let client_signature = hmac_sha256(&creds.stored_key, auth_message.as_bytes());
// ClientKey = ClientProof XOR ClientSignature
if client_proof.len() != client_signature.len() {
return Err("Client proof length mismatch".into());
}
let client_key: Vec<u8> = client_proof
.iter()
.zip(client_signature.iter())
.map(|(a, b)| a ^ b)
.collect();
// StoredKey = H(ClientKey)
let computed_stored_key = sha256(&client_key);
// Verify: computed StoredKey must match the stored StoredKey
if computed_stored_key != creds.stored_key {
return Err("Authentication failed".into());
}
// Generate ServerSignature for mutual authentication
let server_signature = hmac_sha256(&creds.server_key, auth_message.as_bytes());
let server_sig_b64 = BASE64.encode(&server_signature);
Ok(format!("v={}", server_sig_b64))
}
}
/// Compute SCRAM credentials from a plaintext password (for TS to pre-compute).
pub fn compute_scram_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
// SaltedPassword = PBKDF2-HMAC-SHA256(password, salt, iterations)
let mut salted_password = [0u8; 32];
pbkdf2::pbkdf2_hmac::<Sha256>(
password.as_bytes(),
salt,
iterations,
&mut salted_password,
);
// ClientKey = HMAC(SaltedPassword, "Client Key")
let client_key = hmac_sha256(&salted_password, b"Client Key");
// StoredKey = H(ClientKey)
let stored_key = sha256(&client_key);
// ServerKey = HMAC(SaltedPassword, "Server Key")
let server_key = hmac_sha256(&salted_password, b"Server Key");
ScramCredentials {
salt: salt.to_vec(),
iterations,
stored_key,
server_key,
}
}
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
mac.update(data);
mac.finalize().into_bytes().to_vec()
}
fn sha256(data: &[u8]) -> Vec<u8> {
let mut hasher = Sha256::new();
hasher.update(data);
hasher.finalize().to_vec()
}
/// Simple random byte using system randomness.
fn rand_byte() -> u8 {
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hasher};
let state = RandomState::new();
let mut hasher = state.build_hasher();
hasher.write_u64(std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64);
hasher.finish() as u8
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_scram_full_exchange() {
let password = "pencil";
let salt = b"test-salt-1234";
let iterations = 4096;
// Pre-compute server-side credentials from password
let creds = compute_scram_credentials(password, salt, iterations);
// 1. Client sends client-first-message
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let mut server = ScramServer::from_client_first(client_first).unwrap();
assert_eq!(server.username, "user");
// 2. Server responds with server-first-message
let server_first = server.server_first_message(creds.clone());
assert!(server_first.starts_with(&format!("r=rOprNGfwEbeRWgbNEkqO{}", server.server_nonce)));
assert!(server_first.contains("s="));
assert!(server_first.contains("i=4096"));
// 3. Client computes proof
// SaltedPassword
let mut salted_password = [0u8; 32];
pbkdf2::pbkdf2_hmac::<Sha256>(
password.as_bytes(),
salt,
iterations,
&mut salted_password,
);
let client_key = hmac_sha256(&salted_password, b"Client Key");
let stored_key = sha256(&client_key);
let client_first_bare = "n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final_without_proof = format!("c=biws,r={}", server.combined_nonce);
let auth_message = format!("{},{},{}", client_first_bare, server_first, client_final_without_proof);
let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
let client_proof: Vec<u8> = client_key
.iter()
.zip(client_signature.iter())
.map(|(a, b)| a ^ b)
.collect();
let proof_b64 = BASE64.encode(&client_proof);
let client_final = format!("c=biws,r={},p={}", server.combined_nonce, proof_b64);
// 4. Server verifies proof
let result = server.process_client_final(&client_final);
assert!(result.is_ok(), "SCRAM verification failed: {:?}", result.err());
let server_final = result.unwrap();
assert!(server_final.starts_with("v="));
}
#[test]
fn test_scram_wrong_password() {
let password = "pencil";
let wrong_password = "wrong";
let salt = b"test-salt";
let iterations = 4096;
let creds = compute_scram_credentials(password, salt, iterations);
let client_first = "n,,n=user,r=clientnonce123";
let mut server = ScramServer::from_client_first(client_first).unwrap();
let server_first = server.server_first_message(creds);
// Client computes proof with wrong password
let mut salted_password = [0u8; 32];
pbkdf2::pbkdf2_hmac::<Sha256>(
wrong_password.as_bytes(),
salt,
iterations,
&mut salted_password,
);
let client_key = hmac_sha256(&salted_password, b"Client Key");
let stored_key = sha256(&client_key);
let client_first_bare = "n=user,r=clientnonce123";
let client_final_without_proof = format!("c=biws,r={}", server.combined_nonce);
let auth_message = format!("{},{},{}", client_first_bare, server_first, client_final_without_proof);
let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
let client_proof: Vec<u8> = client_key
.iter()
.zip(client_signature.iter())
.map(|(a, b)| a ^ b)
.collect();
let proof_b64 = BASE64.encode(&client_proof);
let client_final = format!("c=biws,r={},p={}", server.combined_nonce, proof_b64);
let result = server.process_client_final(&client_final);
assert!(result.is_err());
}
#[test]
fn test_compute_scram_credentials() {
let creds = compute_scram_credentials("password", b"salt", 4096);
assert_eq!(creds.salt, b"salt");
assert_eq!(creds.iterations, 4096);
assert_eq!(creds.stored_key.len(), 32);
assert_eq!(creds.server_key.len(), 32);
}
#[test]
fn test_invalid_client_first() {
assert!(ScramServer::from_client_first("invalid").is_err());
assert!(ScramServer::from_client_first("n,,").is_err());
}
}

View File

@@ -12,6 +12,7 @@ use crate::rate_limiter::{RateLimitConfig, RateLimiter};
use hickory_resolver::TokioResolver;
use mailer_security::MessageAuthenticator;
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use std::collections::HashMap;
use std::io::BufReader;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
@@ -263,6 +264,69 @@ async fn accept_loop(
}
}
/// SNI-based certificate resolver that selects the appropriate TLS certificate
/// based on the client's requested hostname.
struct SniCertResolver {
/// Domain -> certified key mapping.
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
/// Default certificate for non-matching SNI or missing SNI.
default: Arc<rustls::sign::CertifiedKey>,
}
impl std::fmt::Debug for SniCertResolver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SniCertResolver")
.field("domains", &self.certs.keys().collect::<Vec<_>>())
.finish()
}
}
impl rustls::server::ResolvesServerCert for SniCertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
if let Some(sni) = client_hello.server_name() {
let sni_lower = sni.to_lowercase();
if let Some(key) = self.certs.get(&sni_lower) {
return Some(key.clone());
}
}
Some(self.default.clone())
}
}
/// Parse a PEM cert+key pair into a `CertifiedKey`.
fn parse_certified_key(
cert_pem: &str,
key_pem: &str,
) -> Result<rustls::sign::CertifiedKey, Box<dyn std::error::Error + Send + Sync>> {
let certs: Vec<CertificateDer<'static>> = {
let mut reader = BufReader::new(cert_pem.as_bytes());
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?
};
if certs.is_empty() {
return Err("No certificates found in PEM".into());
}
let key: PrivateKeyDer<'static> = {
let mut reader = BufReader::new(key_pem.as_bytes());
let mut keys = Vec::new();
for item in rustls_pemfile::read_all(&mut reader) {
match item? {
rustls_pemfile::Item::Pkcs8Key(key) => keys.push(PrivateKeyDer::Pkcs8(key)),
rustls_pemfile::Item::Pkcs1Key(key) => keys.push(PrivateKeyDer::Pkcs1(key)),
rustls_pemfile::Item::Sec1Key(key) => keys.push(PrivateKeyDer::Sec1(key)),
_ => {}
}
}
keys.into_iter().next().ok_or("No private key found in PEM")?
};
let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)?;
Ok(rustls::sign::CertifiedKey::new(certs, signing_key))
}
/// Build a TLS acceptor from PEM cert/key strings.
fn build_tls_acceptor(
config: &SmtpServerConfig,
@@ -311,9 +375,42 @@ fn build_tls_acceptor(
.ok_or("No private key found in PEM")?
};
let tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?;
// If additional TLS certs are configured, use SNI-based resolution
let tls_config = if config.additional_tls_certs.is_empty() {
rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?
} else {
// Build default certified key
let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)?;
let default_ck = Arc::new(rustls::sign::CertifiedKey::new(certs, signing_key));
// Build per-domain certs
let mut domain_certs = HashMap::new();
for domain_cert in &config.additional_tls_certs {
match parse_certified_key(&domain_cert.cert_pem, &domain_cert.key_pem) {
Ok(ck) => {
let ck = Arc::new(ck);
for domain in &domain_cert.domains {
domain_certs.insert(domain.to_lowercase(), ck.clone());
}
info!("SNI cert loaded for domains: {:?}", domain_cert.domains);
}
Err(e) => {
warn!("Failed to load SNI cert for domains {:?}: {}", domain_cert.domains, e);
}
}
}
let resolver = SniCertResolver {
certs: domain_certs,
default: default_ck,
};
rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver))
};
Ok(tokio_rustls::TlsAcceptor::from(Arc::new(tls_config)))
}