feat(mailer-smtp): add SCRAM-SHA-256 auth, Ed25519 DKIM, opportunistic TLS, SNI cert selection, pipelining and delivery/bridge improvements
This commit is contained in:
@@ -23,3 +23,6 @@ rustls = { version = "0.23", default-features = false, features = ["ring", "logg
|
||||
rustls-pemfile = "2"
|
||||
mailparse.workspace = true
|
||||
webpki-roots = "0.26"
|
||||
sha2.workspace = true
|
||||
hmac.workspace = true
|
||||
pbkdf2.workspace = true
|
||||
|
||||
@@ -37,6 +37,12 @@ pub struct SmtpClientConfig {
|
||||
/// Maximum connections per pool. Default: 10.
|
||||
#[serde(default = "default_max_pool_connections")]
|
||||
pub max_pool_connections: usize,
|
||||
|
||||
/// Accept invalid TLS certificates (expired, self-signed, wrong hostname).
|
||||
/// Standard for MTA-to-MTA opportunistic TLS per RFC 7435.
|
||||
/// Default: false.
|
||||
#[serde(default)]
|
||||
pub tls_opportunistic: bool,
|
||||
}
|
||||
|
||||
/// Authentication configuration.
|
||||
@@ -60,8 +66,15 @@ pub struct DkimSignConfig {
|
||||
pub domain: String,
|
||||
/// DKIM selector (e.g. "default" or "mta").
|
||||
pub selector: String,
|
||||
/// PEM-encoded RSA private key.
|
||||
/// PEM-encoded private key (RSA or Ed25519 PKCS#8).
|
||||
pub private_key: String,
|
||||
/// Key type: "rsa" (default) or "ed25519".
|
||||
#[serde(default = "default_key_type")]
|
||||
pub key_type: String,
|
||||
}
|
||||
|
||||
fn default_key_type() -> String {
|
||||
"rsa".to_string()
|
||||
}
|
||||
|
||||
impl SmtpClientConfig {
|
||||
|
||||
@@ -117,6 +117,7 @@ pub async fn connect_tls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
timeout_secs: u64,
|
||||
tls_opportunistic: bool,
|
||||
) -> Result<ClientSmtpStream, SmtpClientError> {
|
||||
debug!("Connecting to {}:{} (implicit TLS)", host, port);
|
||||
let addr = format!("{host}:{port}");
|
||||
@@ -130,7 +131,7 @@ pub async fn connect_tls(
|
||||
message: format!("Failed to connect to {addr}: {e}"),
|
||||
})?;
|
||||
|
||||
let tls_stream = perform_tls_handshake(tcp_stream, host).await?;
|
||||
let tls_stream = perform_tls_handshake(tcp_stream, host, tls_opportunistic).await?;
|
||||
Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream)))
|
||||
}
|
||||
|
||||
@@ -138,24 +139,77 @@ pub async fn connect_tls(
|
||||
pub async fn upgrade_to_tls(
|
||||
stream: ClientSmtpStream,
|
||||
hostname: &str,
|
||||
tls_opportunistic: bool,
|
||||
) -> Result<ClientSmtpStream, SmtpClientError> {
|
||||
debug!("Upgrading connection to TLS (STARTTLS) for {}", hostname);
|
||||
let tcp_stream = stream.into_tcp_stream()?;
|
||||
let tls_stream = perform_tls_handshake(tcp_stream, hostname).await?;
|
||||
let tls_stream = perform_tls_handshake(tcp_stream, hostname, tls_opportunistic).await?;
|
||||
Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream)))
|
||||
}
|
||||
|
||||
/// A TLS certificate verifier that accepts any certificate.
|
||||
/// Used for MTA-to-MTA opportunistic TLS per RFC 7435.
|
||||
#[derive(Debug)]
|
||||
struct OpportunisticVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for OpportunisticVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls_pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls_pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls_pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls_pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.signature_verification_algorithms
|
||||
.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the TLS handshake on a TCP stream using webpki-roots.
|
||||
/// When `tls_opportunistic` is true, certificate verification is skipped
|
||||
/// (standard for MTA-to-MTA delivery per RFC 7435).
|
||||
async fn perform_tls_handshake(
|
||||
tcp_stream: TcpStream,
|
||||
hostname: &str,
|
||||
tls_opportunistic: bool,
|
||||
) -> Result<TlsStream<TcpStream>, SmtpClientError> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
let tls_config = if tls_opportunistic {
|
||||
debug!("Using opportunistic TLS (no cert verification) for {}", hostname);
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(OpportunisticVerifier))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config));
|
||||
let server_name = rustls_pki_types::ServerName::try_from(hostname.to_string()).map_err(|e| {
|
||||
@@ -190,7 +244,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connect_tls_refused() {
|
||||
let result = connect_tls("127.0.0.1", 19998, 2).await;
|
||||
let result = connect_tls("127.0.0.1", 19998, 2, false).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
|
||||
@@ -116,6 +116,7 @@ impl ConnectionPool {
|
||||
&self.config.host,
|
||||
self.config.port,
|
||||
self.config.connection_timeout_secs,
|
||||
self.config.tls_opportunistic,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
@@ -139,7 +140,7 @@ impl ConnectionPool {
|
||||
if !self.config.secure && capabilities.starttls {
|
||||
protocol::send_starttls(&mut stream, self.config.socket_timeout_secs).await?;
|
||||
stream =
|
||||
super::connection::upgrade_to_tls(stream, &self.config.host).await?;
|
||||
super::connection::upgrade_to_tls(stream, &self.config.host, self.config.tls_opportunistic).await?;
|
||||
|
||||
// Re-EHLO after STARTTLS — use updated capabilities for auth
|
||||
capabilities = protocol::send_ehlo(
|
||||
@@ -244,9 +245,10 @@ impl SmtpClientManager {
|
||||
protocol::send_rset(&mut conn.stream, config.socket_timeout_secs).await?;
|
||||
}
|
||||
|
||||
// Perform the SMTP transaction
|
||||
// Perform the SMTP transaction (use pipelining if server supports it)
|
||||
let pipelining = conn.capabilities.pipelining;
|
||||
let result =
|
||||
Self::perform_send(&mut conn.stream, sender, recipients, message, config).await;
|
||||
Self::perform_send(&mut conn.stream, sender, recipients, message, config, pipelining).await;
|
||||
|
||||
// Re-acquire the pool lock and release the connection
|
||||
let mut pool = pool_arc.lock().await;
|
||||
@@ -268,30 +270,39 @@ impl SmtpClientManager {
|
||||
recipients: &[String],
|
||||
message: &[u8],
|
||||
config: &SmtpClientConfig,
|
||||
pipelining: bool,
|
||||
) -> Result<SmtpSendResult, SmtpClientError> {
|
||||
let timeout_secs = config.socket_timeout_secs;
|
||||
|
||||
// MAIL FROM
|
||||
protocol::send_mail_from(stream, sender, timeout_secs).await?;
|
||||
let (accepted, rejected) = if pipelining {
|
||||
// Use pipelined envelope: MAIL FROM + all RCPT TO in one batch
|
||||
let (_mail_ok, acc, rej) = protocol::send_pipelined_envelope(
|
||||
stream, sender, recipients, timeout_secs,
|
||||
).await?;
|
||||
(acc, rej)
|
||||
} else {
|
||||
// Sequential: MAIL FROM, then each RCPT TO
|
||||
protocol::send_mail_from(stream, sender, timeout_secs).await?;
|
||||
|
||||
// RCPT TO for each recipient
|
||||
let mut accepted = Vec::new();
|
||||
let mut rejected = Vec::new();
|
||||
let mut accepted = Vec::new();
|
||||
let mut rejected = Vec::new();
|
||||
|
||||
for rcpt in recipients {
|
||||
match protocol::send_rcpt_to(stream, rcpt, timeout_secs).await {
|
||||
Ok(resp) => {
|
||||
if resp.is_success() {
|
||||
accepted.push(rcpt.clone());
|
||||
} else {
|
||||
for rcpt in recipients {
|
||||
match protocol::send_rcpt_to(stream, rcpt, timeout_secs).await {
|
||||
Ok(resp) => {
|
||||
if resp.is_success() {
|
||||
accepted.push(rcpt.clone());
|
||||
} else {
|
||||
rejected.push(rcpt.clone());
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
rejected.push(rcpt.clone());
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
rejected.push(rcpt.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
(accepted, rejected)
|
||||
};
|
||||
|
||||
// If no recipients were accepted, fail
|
||||
if accepted.is_empty() {
|
||||
@@ -339,6 +350,7 @@ impl SmtpClientManager {
|
||||
&config.host,
|
||||
config.port,
|
||||
config.connection_timeout_secs,
|
||||
config.tls_opportunistic,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
|
||||
@@ -318,6 +318,54 @@ pub async fn send_rcpt_to(
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
/// Send MAIL FROM + RCPT TO commands in a single pipelined batch.
|
||||
///
|
||||
/// Writes all envelope commands at once, then reads responses in order.
|
||||
/// Returns `(mail_from_ok, accepted_recipients, rejected_recipients)`.
|
||||
pub async fn send_pipelined_envelope(
|
||||
stream: &mut ClientSmtpStream,
|
||||
sender: &str,
|
||||
recipients: &[String],
|
||||
timeout_secs: u64,
|
||||
) -> Result<(bool, Vec<String>, Vec<String>), SmtpClientError> {
|
||||
// Build the full pipelined command batch
|
||||
let mut batch = format!("MAIL FROM:<{sender}>\r\n");
|
||||
for rcpt in recipients {
|
||||
batch.push_str(&format!("RCPT TO:<{rcpt}>\r\n"));
|
||||
}
|
||||
|
||||
// Send all commands at once
|
||||
debug!("SMTP C (pipelined): MAIL FROM + {} RCPT TO", recipients.len());
|
||||
stream.write_all(batch.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
// Read MAIL FROM response
|
||||
let mail_resp = read_response(stream, timeout_secs).await?;
|
||||
if !mail_resp.is_success() {
|
||||
return Err(mail_resp.to_error());
|
||||
}
|
||||
|
||||
// Read RCPT TO responses
|
||||
let mut accepted = Vec::new();
|
||||
let mut rejected = Vec::new();
|
||||
for rcpt in recipients {
|
||||
match read_response(stream, timeout_secs).await {
|
||||
Ok(resp) => {
|
||||
if resp.is_success() {
|
||||
accepted.push(rcpt.clone());
|
||||
} else {
|
||||
rejected.push(rcpt.clone());
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
rejected.push(rcpt.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((true, accepted, rejected))
|
||||
}
|
||||
|
||||
/// Send DATA command, followed by the message body with dot-stuffing.
|
||||
pub async fn send_data(
|
||||
stream: &mut ClientSmtpStream,
|
||||
|
||||
@@ -50,6 +50,7 @@ pub enum SmtpCommand {
|
||||
pub enum AuthMechanism {
|
||||
Plain,
|
||||
Login,
|
||||
ScramSha256,
|
||||
}
|
||||
|
||||
/// Errors that can occur during command parsing.
|
||||
@@ -218,6 +219,7 @@ fn parse_auth(rest: &str) -> Result<SmtpCommand, ParseError> {
|
||||
let mechanism = match mech_str.to_ascii_uppercase().as_str() {
|
||||
"PLAIN" => AuthMechanism::Plain,
|
||||
"LOGIN" => AuthMechanism::Login,
|
||||
"SCRAM-SHA-256" => AuthMechanism::ScramSha256,
|
||||
other => {
|
||||
return Err(ParseError::SyntaxError(format!(
|
||||
"unsupported AUTH mechanism: {other}"
|
||||
|
||||
@@ -2,6 +2,17 @@
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Per-domain TLS certificate for SNI-based cert selection.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsDomainCert {
|
||||
/// Domain names this certificate covers (matched against SNI hostname).
|
||||
pub domains: Vec<String>,
|
||||
/// Certificate chain in PEM format.
|
||||
pub cert_pem: String,
|
||||
/// Private key in PEM format.
|
||||
pub key_pem: String,
|
||||
}
|
||||
|
||||
/// Configuration for an SMTP server instance.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SmtpServerConfig {
|
||||
@@ -11,10 +22,13 @@ pub struct SmtpServerConfig {
|
||||
pub ports: Vec<u16>,
|
||||
/// Port for implicit TLS (e.g. 465). None = no implicit TLS port.
|
||||
pub secure_port: Option<u16>,
|
||||
/// TLS certificate chain in PEM format.
|
||||
/// TLS certificate chain in PEM format (default cert).
|
||||
pub tls_cert_pem: Option<String>,
|
||||
/// TLS private key in PEM format.
|
||||
/// TLS private key in PEM format (default key).
|
||||
pub tls_key_pem: Option<String>,
|
||||
/// Additional per-domain TLS certificates for SNI-based selection.
|
||||
#[serde(default)]
|
||||
pub additional_tls_certs: Vec<TlsDomainCert>,
|
||||
/// Maximum message size in bytes.
|
||||
pub max_message_size: u64,
|
||||
/// Maximum number of concurrent connections.
|
||||
@@ -43,6 +57,7 @@ impl Default for SmtpServerConfig {
|
||||
secure_port: None,
|
||||
tls_cert_pem: None,
|
||||
tls_key_pem: None,
|
||||
additional_tls_certs: Vec::new(),
|
||||
max_message_size: 10 * 1024 * 1024, // 10 MB
|
||||
max_connections: 100,
|
||||
max_recipients: 100,
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::config::SmtpServerConfig;
|
||||
use crate::data::{DataAccumulator, DataAction};
|
||||
use crate::rate_limiter::RateLimiter;
|
||||
use crate::response::{build_capabilities, SmtpResponse};
|
||||
use crate::scram::{ScramCredentials, ScramServer};
|
||||
use crate::session::{AuthState, SmtpSession};
|
||||
use crate::validation;
|
||||
|
||||
@@ -52,6 +53,13 @@ pub enum ConnectionEvent {
|
||||
password: String,
|
||||
remote_addr: String,
|
||||
},
|
||||
/// A SCRAM credential request — Rust needs stored credentials from TS.
|
||||
ScramCredentialRequest {
|
||||
correlation_id: String,
|
||||
session_id: String,
|
||||
username: String,
|
||||
remote_addr: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// How email data is transported from Rust to TS.
|
||||
@@ -81,6 +89,16 @@ pub struct AuthResult {
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of TS returning SCRAM credentials for a user.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ScramCredentialResult {
|
||||
pub found: bool,
|
||||
pub salt: Option<Vec<u8>>,
|
||||
pub iterations: Option<u32>,
|
||||
pub stored_key: Option<Vec<u8>>,
|
||||
pub server_key: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
/// Abstraction over plain and TLS streams.
|
||||
pub enum SmtpStream {
|
||||
Plain(BufReader<TcpStream>),
|
||||
@@ -133,6 +151,14 @@ impl SmtpStream {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the internal buffer has unread data (pipelined commands).
|
||||
pub fn has_buffered_data(&self) -> bool {
|
||||
match self {
|
||||
SmtpStream::Plain(reader) => !reader.buffer().is_empty(),
|
||||
SmtpStream::Tls(reader) => !reader.buffer().is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Unwrap to get the raw TcpStream for STARTTLS upgrade.
|
||||
/// Only works on Plain streams.
|
||||
pub fn into_tcp_stream(self) -> Option<TcpStream> {
|
||||
@@ -212,7 +238,7 @@ pub async fn handle_connection(
|
||||
break;
|
||||
}
|
||||
Ok(Ok(_)) => {
|
||||
// Process command
|
||||
// Process the first command
|
||||
let response = process_line(
|
||||
&line,
|
||||
&mut session,
|
||||
@@ -227,59 +253,123 @@ pub async fn handle_connection(
|
||||
)
|
||||
.await;
|
||||
|
||||
// Check for pipelined commands in the buffer.
|
||||
// Collect pipelinable responses into a batch for single write.
|
||||
let mut response_batch: Vec<u8> = Vec::new();
|
||||
let mut should_break = false;
|
||||
let mut starttls_signal = false;
|
||||
|
||||
match response {
|
||||
LineResult::Response(resp) => {
|
||||
if stream.write_all(&resp.to_bytes()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if stream.flush().await.is_err() {
|
||||
break;
|
||||
}
|
||||
response_batch.extend_from_slice(&resp.to_bytes());
|
||||
}
|
||||
LineResult::Quit(resp) => {
|
||||
let _ = stream.write_all(&resp.to_bytes()).await;
|
||||
let _ = stream.flush().await;
|
||||
break;
|
||||
should_break = true;
|
||||
}
|
||||
LineResult::StartTlsSignal => {
|
||||
// Send 220 Ready response
|
||||
let resp = SmtpResponse::new(220, "Ready to start TLS");
|
||||
if stream.write_all(&resp.to_bytes()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if stream.flush().await.is_err() {
|
||||
break;
|
||||
}
|
||||
// Extract TCP stream and upgrade
|
||||
if let Some(tcp_stream) = stream.into_tcp_stream() {
|
||||
if let Some(acceptor) = &tls_acceptor {
|
||||
match acceptor.accept(tcp_stream).await {
|
||||
Ok(tls_stream) => {
|
||||
stream = SmtpStream::Tls(BufReader::new(tls_stream));
|
||||
session.secure = true;
|
||||
// Client must re-EHLO after STARTTLS
|
||||
session.state = crate::state::SmtpState::Connected;
|
||||
session.client_hostname = None;
|
||||
session.esmtp = false;
|
||||
session.auth_state = AuthState::None;
|
||||
session.envelope = Default::default();
|
||||
debug!(session_id = %session.id, "TLS upgrade successful");
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(session_id = %session.id, error = %e, "TLS handshake failed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// Already TLS — shouldn't happen
|
||||
break;
|
||||
}
|
||||
starttls_signal = true;
|
||||
}
|
||||
LineResult::NoResponse => {}
|
||||
LineResult::Disconnect => {
|
||||
should_break = true;
|
||||
}
|
||||
}
|
||||
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
|
||||
// Process additional pipelined commands from the buffer
|
||||
if !starttls_signal {
|
||||
while stream.has_buffered_data() {
|
||||
let mut next_line = String::new();
|
||||
match stream.read_line(&mut next_line, 4096).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(_) => {
|
||||
let next_response = process_line(
|
||||
&next_line,
|
||||
&mut session,
|
||||
&mut stream,
|
||||
&config,
|
||||
&rate_limiter,
|
||||
&event_tx,
|
||||
callback_register.as_ref(),
|
||||
&tls_acceptor,
|
||||
&authenticator,
|
||||
&resolver,
|
||||
)
|
||||
.await;
|
||||
|
||||
match next_response {
|
||||
LineResult::Response(resp) => {
|
||||
response_batch.extend_from_slice(&resp.to_bytes());
|
||||
}
|
||||
LineResult::Quit(resp) => {
|
||||
response_batch.extend_from_slice(&resp.to_bytes());
|
||||
should_break = true;
|
||||
break;
|
||||
}
|
||||
LineResult::StartTlsSignal | LineResult::Disconnect => {
|
||||
// Non-pipelinable: flush batch and handle
|
||||
starttls_signal = matches!(next_response, LineResult::StartTlsSignal);
|
||||
should_break = matches!(next_response, LineResult::Disconnect);
|
||||
break;
|
||||
}
|
||||
LineResult::NoResponse => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the accumulated response batch in one write
|
||||
if !response_batch.is_empty() {
|
||||
if stream.write_all(&response_batch).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if stream.flush().await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
|
||||
if starttls_signal {
|
||||
// Send 220 Ready response
|
||||
let resp = SmtpResponse::new(220, "Ready to start TLS");
|
||||
if stream.write_all(&resp.to_bytes()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if stream.flush().await.is_err() {
|
||||
break;
|
||||
}
|
||||
// Extract TCP stream and upgrade
|
||||
if let Some(tcp_stream) = stream.into_tcp_stream() {
|
||||
if let Some(acceptor) = &tls_acceptor {
|
||||
match acceptor.accept(tcp_stream).await {
|
||||
Ok(tls_stream) => {
|
||||
stream = SmtpStream::Tls(BufReader::new(tls_stream));
|
||||
session.secure = true;
|
||||
session.state = crate::state::SmtpState::Connected;
|
||||
session.client_hostname = None;
|
||||
session.esmtp = false;
|
||||
session.auth_state = AuthState::None;
|
||||
session.envelope = Default::default();
|
||||
debug!(session_id = %session.id, "TLS upgrade successful");
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(session_id = %session.id, error = %e, "TLS handshake failed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -322,6 +412,12 @@ pub trait CallbackRegistry: Send + Sync {
|
||||
&self,
|
||||
correlation_id: &str,
|
||||
) -> oneshot::Receiver<AuthResult>;
|
||||
|
||||
/// Register a callback for SCRAM credential lookup and return a receiver.
|
||||
fn register_scram_callback(
|
||||
&self,
|
||||
correlation_id: &str,
|
||||
) -> oneshot::Receiver<ScramCredentialResult>;
|
||||
}
|
||||
|
||||
/// Process a single input line from the client.
|
||||
@@ -406,16 +502,29 @@ async fn process_line(
|
||||
mechanism,
|
||||
initial_response,
|
||||
} => {
|
||||
handle_auth(
|
||||
mechanism,
|
||||
initial_response,
|
||||
session,
|
||||
config,
|
||||
rate_limiter,
|
||||
event_tx,
|
||||
callback_registry,
|
||||
)
|
||||
.await
|
||||
if matches!(mechanism, AuthMechanism::ScramSha256) {
|
||||
handle_auth_scram(
|
||||
initial_response,
|
||||
session,
|
||||
stream,
|
||||
config,
|
||||
rate_limiter,
|
||||
event_tx,
|
||||
callback_registry,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
handle_auth(
|
||||
mechanism,
|
||||
initial_response,
|
||||
session,
|
||||
config,
|
||||
rate_limiter,
|
||||
event_tx,
|
||||
callback_registry,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
SmtpCommand::Help(_) => {
|
||||
@@ -832,6 +941,217 @@ async fn handle_auth(
|
||||
))
|
||||
}
|
||||
}
|
||||
AuthMechanism::ScramSha256 => {
|
||||
// SCRAM is handled separately in process_line; this should not be reached.
|
||||
LineResult::Response(SmtpResponse::not_implemented())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle AUTH SCRAM-SHA-256 — full exchange in a single async function.
|
||||
///
|
||||
/// SCRAM is a multi-step challenge-response protocol:
|
||||
/// 1. Client sends client-first-message (in initial_response or after 334)
|
||||
/// 2. Server requests SCRAM credentials from TS
|
||||
/// 3. Server sends server-first-message (334 challenge)
|
||||
/// 4. Client sends client-final-message (proof)
|
||||
/// 5. Server verifies proof and responds with 235 or 535
|
||||
async fn handle_auth_scram(
|
||||
initial_response: Option<String>,
|
||||
session: &mut SmtpSession,
|
||||
stream: &mut SmtpStream,
|
||||
config: &SmtpServerConfig,
|
||||
rate_limiter: &RateLimiter,
|
||||
event_tx: &mpsc::Sender<ConnectionEvent>,
|
||||
callback_registry: &dyn CallbackRegistry,
|
||||
) -> LineResult {
|
||||
if !config.auth_enabled {
|
||||
return LineResult::Response(SmtpResponse::not_implemented());
|
||||
}
|
||||
|
||||
if session.is_authenticated() {
|
||||
return LineResult::Response(SmtpResponse::bad_sequence("Already authenticated"));
|
||||
}
|
||||
|
||||
if !session.state.can_auth() {
|
||||
return LineResult::Response(SmtpResponse::bad_sequence("Send EHLO first"));
|
||||
}
|
||||
|
||||
// Step 1: Get client-first-message
|
||||
let client_first_b64 = match initial_response {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => {
|
||||
// No initial response — send empty 334 challenge
|
||||
let resp = SmtpResponse::auth_challenge("");
|
||||
if stream.write_all(&resp.to_bytes()).await.is_err() {
|
||||
return LineResult::Disconnect;
|
||||
}
|
||||
if stream.flush().await.is_err() {
|
||||
return LineResult::Disconnect;
|
||||
}
|
||||
// Read client-first-message
|
||||
let mut line = String::new();
|
||||
let socket_timeout = Duration::from_secs(config.socket_timeout_secs);
|
||||
match timeout(socket_timeout, stream.read_line(&mut line, 4096)).await {
|
||||
Err(_) | Ok(Err(_)) | Ok(Ok(0)) => return LineResult::Disconnect,
|
||||
Ok(Ok(_)) => {}
|
||||
}
|
||||
let trimmed = line.trim().to_string();
|
||||
if trimmed == "*" {
|
||||
return LineResult::Response(SmtpResponse::new(501, "Authentication cancelled"));
|
||||
}
|
||||
trimmed
|
||||
}
|
||||
};
|
||||
|
||||
// Decode base64 client-first-message
|
||||
let client_first_bytes = match BASE64.decode(client_first_b64.as_bytes()) {
|
||||
Ok(b) => b,
|
||||
Err(_) => {
|
||||
return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
|
||||
}
|
||||
};
|
||||
let client_first = match String::from_utf8(client_first_bytes) {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
return LineResult::Response(SmtpResponse::param_error("Invalid UTF-8 in SCRAM message"));
|
||||
}
|
||||
};
|
||||
|
||||
// Parse client-first-message
|
||||
let mut scram = match ScramServer::from_client_first(&client_first) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
debug!(error = %e, "SCRAM client-first-message parse error");
|
||||
return LineResult::Response(SmtpResponse::param_error(
|
||||
"Invalid SCRAM client-first-message",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Step 2: Request SCRAM credentials from TS
|
||||
let correlation_id = uuid::Uuid::new_v4().to_string();
|
||||
let rx = callback_registry.register_scram_callback(&correlation_id);
|
||||
|
||||
let event = ConnectionEvent::ScramCredentialRequest {
|
||||
correlation_id: correlation_id.clone(),
|
||||
session_id: session.id.clone(),
|
||||
username: scram.username.clone(),
|
||||
remote_addr: session.remote_addr.clone(),
|
||||
};
|
||||
|
||||
if event_tx.send(event).await.is_err() {
|
||||
return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
|
||||
}
|
||||
|
||||
// Wait for credentials from TS
|
||||
let cred_timeout = Duration::from_secs(5);
|
||||
let cred_result = match timeout(cred_timeout, rx).await {
|
||||
Ok(Ok(result)) => result,
|
||||
Ok(Err(_)) => {
|
||||
warn!(correlation_id = %correlation_id, "SCRAM credential callback dropped");
|
||||
return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(correlation_id = %correlation_id, "SCRAM credential request timed out");
|
||||
return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
|
||||
}
|
||||
};
|
||||
|
||||
if !cred_result.found {
|
||||
// User not found — fail auth (don't reveal that user doesn't exist)
|
||||
session.auth_state = AuthState::None;
|
||||
let exceeded = session.record_auth_failure(config.max_auth_failures);
|
||||
if exceeded {
|
||||
return LineResult::Quit(SmtpResponse::service_unavailable(
|
||||
&config.hostname,
|
||||
"Too many authentication failures",
|
||||
));
|
||||
}
|
||||
return LineResult::Response(SmtpResponse::auth_failed());
|
||||
}
|
||||
|
||||
let creds = ScramCredentials {
|
||||
salt: cred_result.salt.unwrap_or_default(),
|
||||
iterations: cred_result.iterations.unwrap_or(4096),
|
||||
stored_key: cred_result.stored_key.unwrap_or_default(),
|
||||
server_key: cred_result.server_key.unwrap_or_default(),
|
||||
};
|
||||
|
||||
// Step 3: Generate and send server-first-message
|
||||
let server_first = scram.server_first_message(creds);
|
||||
let server_first_b64 = BASE64.encode(server_first.as_bytes());
|
||||
|
||||
let challenge = SmtpResponse::auth_challenge(&server_first_b64);
|
||||
if stream.write_all(&challenge.to_bytes()).await.is_err() {
|
||||
return LineResult::Disconnect;
|
||||
}
|
||||
if stream.flush().await.is_err() {
|
||||
return LineResult::Disconnect;
|
||||
}
|
||||
|
||||
// Step 4: Read client-final-message
|
||||
let mut client_final_line = String::new();
|
||||
let socket_timeout = Duration::from_secs(config.socket_timeout_secs);
|
||||
match timeout(socket_timeout, stream.read_line(&mut client_final_line, 4096)).await {
|
||||
Err(_) | Ok(Err(_)) | Ok(Ok(0)) => return LineResult::Disconnect,
|
||||
Ok(Ok(_)) => {}
|
||||
}
|
||||
|
||||
let client_final_b64 = client_final_line.trim();
|
||||
|
||||
// Cancel if *
|
||||
if client_final_b64 == "*" {
|
||||
session.auth_state = AuthState::None;
|
||||
return LineResult::Response(SmtpResponse::new(501, "Authentication cancelled"));
|
||||
}
|
||||
|
||||
// Decode base64 client-final-message
|
||||
let client_final_bytes = match BASE64.decode(client_final_b64.as_bytes()) {
|
||||
Ok(b) => b,
|
||||
Err(_) => {
|
||||
session.auth_state = AuthState::None;
|
||||
return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
|
||||
}
|
||||
};
|
||||
let client_final = match String::from_utf8(client_final_bytes) {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
session.auth_state = AuthState::None;
|
||||
return LineResult::Response(SmtpResponse::param_error("Invalid UTF-8 in SCRAM message"));
|
||||
}
|
||||
};
|
||||
|
||||
// Step 5: Verify proof
|
||||
match scram.process_client_final(&client_final) {
|
||||
Ok(server_final) => {
|
||||
let server_final_b64 = BASE64.encode(server_final.as_bytes());
|
||||
session.auth_state = AuthState::Authenticated {
|
||||
username: scram.username.clone(),
|
||||
};
|
||||
LineResult::Response(SmtpResponse::new(
|
||||
235,
|
||||
format!("2.7.0 Authentication successful {}", server_final_b64),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(error = %e, "SCRAM proof verification failed");
|
||||
session.auth_state = AuthState::None;
|
||||
let exceeded = session.record_auth_failure(config.max_auth_failures);
|
||||
if exceeded {
|
||||
if !rate_limiter.check_auth_failure(&session.remote_addr) {
|
||||
return LineResult::Quit(SmtpResponse::service_unavailable(
|
||||
&config.hostname,
|
||||
"Too many authentication failures",
|
||||
));
|
||||
}
|
||||
return LineResult::Quit(SmtpResponse::service_unavailable(
|
||||
&config.hostname,
|
||||
"Too many authentication failures",
|
||||
));
|
||||
}
|
||||
LineResult::Response(SmtpResponse::auth_failed())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ pub mod connection;
|
||||
pub mod data;
|
||||
pub mod rate_limiter;
|
||||
pub mod response;
|
||||
pub mod scram;
|
||||
pub mod server;
|
||||
pub mod session;
|
||||
pub mod state;
|
||||
|
||||
@@ -196,7 +196,7 @@ pub fn build_capabilities(
|
||||
caps.push("STARTTLS".to_string());
|
||||
}
|
||||
if auth_available {
|
||||
caps.push("AUTH PLAIN LOGIN".to_string());
|
||||
caps.push("AUTH PLAIN LOGIN SCRAM-SHA-256".to_string());
|
||||
}
|
||||
caps
|
||||
}
|
||||
@@ -253,7 +253,7 @@ mod tests {
|
||||
let caps = build_capabilities(10485760, true, false, true);
|
||||
assert!(caps.contains(&"SIZE 10485760".to_string()));
|
||||
assert!(caps.contains(&"STARTTLS".to_string()));
|
||||
assert!(caps.contains(&"AUTH PLAIN LOGIN".to_string()));
|
||||
assert!(caps.contains(&"AUTH PLAIN LOGIN SCRAM-SHA-256".to_string()));
|
||||
assert!(caps.contains(&"PIPELINING".to_string()));
|
||||
}
|
||||
|
||||
@@ -262,7 +262,7 @@ mod tests {
|
||||
// When already secure, STARTTLS should NOT be advertised
|
||||
let caps = build_capabilities(10485760, true, true, false);
|
||||
assert!(!caps.contains(&"STARTTLS".to_string()));
|
||||
assert!(!caps.contains(&"AUTH PLAIN LOGIN".to_string()));
|
||||
assert!(!caps.contains(&"AUTH PLAIN LOGIN SCRAM-SHA-256".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
342
rust/crates/mailer-smtp/src/scram.rs
Normal file
342
rust/crates/mailer-smtp/src/scram.rs
Normal file
@@ -0,0 +1,342 @@
|
||||
//! SCRAM-SHA-256 server-side implementation (RFC 5802 + RFC 7677).
|
||||
//!
|
||||
//! Implements the server side of the SCRAM-SHA-256 SASL mechanism,
|
||||
//! a challenge-response protocol that avoids transmitting cleartext passwords.
|
||||
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Pre-computed SCRAM credentials for a user (derived from password).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScramCredentials {
|
||||
pub salt: Vec<u8>,
|
||||
pub iterations: u32,
|
||||
pub stored_key: Vec<u8>,
|
||||
pub server_key: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Server-side SCRAM state machine.
|
||||
pub struct ScramServer {
|
||||
/// Username extracted from client-first-message.
|
||||
pub username: String,
|
||||
/// Full combined nonce (client + server).
|
||||
combined_nonce: String,
|
||||
/// Server nonce portion (used in tests for verification).
|
||||
#[allow(dead_code)]
|
||||
server_nonce: String,
|
||||
/// Stored credentials (set after TS responds).
|
||||
credentials: Option<ScramCredentials>,
|
||||
/// The server-first-message (for auth message construction).
|
||||
server_first: String,
|
||||
/// The client-first-message-bare (for auth message construction).
|
||||
client_first_bare: String,
|
||||
}
|
||||
|
||||
impl ScramServer {
|
||||
/// Process the client-first-message.
|
||||
///
|
||||
/// Parses the client nonce and username, generates a server nonce,
|
||||
/// and returns a partial state that needs credentials to produce the
|
||||
/// server-first-message.
|
||||
pub fn from_client_first(client_first: &str) -> Result<Self, String> {
|
||||
// client-first-message = gs2-header client-first-message-bare
|
||||
// gs2-header = "n,," (no channel binding)
|
||||
// client-first-message-bare = "n=username,r=nonce"
|
||||
let bare = if let Some(rest) = client_first.strip_prefix("n,,") {
|
||||
rest
|
||||
} else if let Some(rest) = client_first.strip_prefix("y,,") {
|
||||
rest
|
||||
} else {
|
||||
return Err("Invalid SCRAM gs2-header".into());
|
||||
};
|
||||
|
||||
let mut username = String::new();
|
||||
let mut client_nonce = String::new();
|
||||
|
||||
for part in bare.split(',') {
|
||||
if let Some(val) = part.strip_prefix("n=") {
|
||||
username = val.to_string();
|
||||
} else if let Some(val) = part.strip_prefix("r=") {
|
||||
client_nonce = val.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
if username.is_empty() || client_nonce.is_empty() {
|
||||
return Err("Missing username or nonce in client-first-message".into());
|
||||
}
|
||||
|
||||
// Generate server nonce
|
||||
let server_nonce: String = (0..24)
|
||||
.map(|_| {
|
||||
let idx = (rand_byte() as usize) % 62;
|
||||
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"[idx] as char
|
||||
})
|
||||
.collect();
|
||||
|
||||
let combined_nonce = format!("{}{}", client_nonce, server_nonce);
|
||||
|
||||
Ok(ScramServer {
|
||||
username,
|
||||
combined_nonce,
|
||||
server_nonce,
|
||||
credentials: None,
|
||||
server_first: String::new(),
|
||||
client_first_bare: bare.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Set the credentials and produce the server-first-message.
|
||||
pub fn server_first_message(&mut self, creds: ScramCredentials) -> String {
|
||||
let salt_b64 = BASE64.encode(&creds.salt);
|
||||
let server_first = format!(
|
||||
"r={},s={},i={}",
|
||||
self.combined_nonce, salt_b64, creds.iterations
|
||||
);
|
||||
|
||||
self.server_first = server_first.clone();
|
||||
self.credentials = Some(creds);
|
||||
server_first
|
||||
}
|
||||
|
||||
/// Process the client-final-message and verify the proof.
|
||||
///
|
||||
/// Returns the server-final-message (containing ServerSignature) on success,
|
||||
/// or an error string on failure.
|
||||
pub fn process_client_final(&mut self, client_final: &str) -> Result<String, String> {
|
||||
let creds = self.credentials.as_ref().ok_or("No credentials set")?;
|
||||
|
||||
// Parse client-final-message
|
||||
// Format: c=biws,r=<combined_nonce>,p=<client_proof>
|
||||
let mut channel_binding = String::new();
|
||||
let mut nonce = String::new();
|
||||
let mut proof_b64 = String::new();
|
||||
|
||||
for part in client_final.split(',') {
|
||||
if let Some(val) = part.strip_prefix("c=") {
|
||||
channel_binding = val.to_string();
|
||||
} else if let Some(val) = part.strip_prefix("r=") {
|
||||
nonce = val.to_string();
|
||||
} else if let Some(val) = part.strip_prefix("p=") {
|
||||
proof_b64 = val.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Verify nonce matches
|
||||
if nonce != self.combined_nonce {
|
||||
return Err("Nonce mismatch".into());
|
||||
}
|
||||
|
||||
// Build the client-final-message-without-proof
|
||||
let client_final_without_proof = format!("c={},r={}", channel_binding, nonce);
|
||||
|
||||
// Complete the auth message
|
||||
let auth_message = format!(
|
||||
"{},{},{}",
|
||||
self.client_first_bare, self.server_first, client_final_without_proof
|
||||
);
|
||||
|
||||
// Verify client proof
|
||||
let client_proof = BASE64.decode(proof_b64.as_bytes())
|
||||
.map_err(|_| "Invalid base64 in client proof")?;
|
||||
|
||||
// ClientSignature = HMAC(StoredKey, AuthMessage)
|
||||
let client_signature = hmac_sha256(&creds.stored_key, auth_message.as_bytes());
|
||||
|
||||
// ClientKey = ClientProof XOR ClientSignature
|
||||
if client_proof.len() != client_signature.len() {
|
||||
return Err("Client proof length mismatch".into());
|
||||
}
|
||||
let client_key: Vec<u8> = client_proof
|
||||
.iter()
|
||||
.zip(client_signature.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect();
|
||||
|
||||
// StoredKey = H(ClientKey)
|
||||
let computed_stored_key = sha256(&client_key);
|
||||
|
||||
// Verify: computed StoredKey must match the stored StoredKey
|
||||
if computed_stored_key != creds.stored_key {
|
||||
return Err("Authentication failed".into());
|
||||
}
|
||||
|
||||
// Generate ServerSignature for mutual authentication
|
||||
let server_signature = hmac_sha256(&creds.server_key, auth_message.as_bytes());
|
||||
let server_sig_b64 = BASE64.encode(&server_signature);
|
||||
|
||||
Ok(format!("v={}", server_sig_b64))
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute SCRAM credentials from a plaintext password (for TS to pre-compute).
|
||||
pub fn compute_scram_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
|
||||
// SaltedPassword = PBKDF2-HMAC-SHA256(password, salt, iterations)
|
||||
let mut salted_password = [0u8; 32];
|
||||
pbkdf2::pbkdf2_hmac::<Sha256>(
|
||||
password.as_bytes(),
|
||||
salt,
|
||||
iterations,
|
||||
&mut salted_password,
|
||||
);
|
||||
|
||||
// ClientKey = HMAC(SaltedPassword, "Client Key")
|
||||
let client_key = hmac_sha256(&salted_password, b"Client Key");
|
||||
|
||||
// StoredKey = H(ClientKey)
|
||||
let stored_key = sha256(&client_key);
|
||||
|
||||
// ServerKey = HMAC(SaltedPassword, "Server Key")
|
||||
let server_key = hmac_sha256(&salted_password, b"Server Key");
|
||||
|
||||
ScramCredentials {
|
||||
salt: salt.to_vec(),
|
||||
iterations,
|
||||
stored_key,
|
||||
server_key,
|
||||
}
|
||||
}
|
||||
|
||||
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
|
||||
mac.update(data);
|
||||
mac.finalize().into_bytes().to_vec()
|
||||
}
|
||||
|
||||
fn sha256(data: &[u8]) -> Vec<u8> {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(data);
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
/// Simple random byte using system randomness.
|
||||
fn rand_byte() -> u8 {
|
||||
use std::collections::hash_map::RandomState;
|
||||
use std::hash::{BuildHasher, Hasher};
|
||||
let state = RandomState::new();
|
||||
let mut hasher = state.build_hasher();
|
||||
hasher.write_u64(std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos() as u64);
|
||||
hasher.finish() as u8
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_scram_full_exchange() {
|
||||
let password = "pencil";
|
||||
let salt = b"test-salt-1234";
|
||||
let iterations = 4096;
|
||||
|
||||
// Pre-compute server-side credentials from password
|
||||
let creds = compute_scram_credentials(password, salt, iterations);
|
||||
|
||||
// 1. Client sends client-first-message
|
||||
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
|
||||
let mut server = ScramServer::from_client_first(client_first).unwrap();
|
||||
assert_eq!(server.username, "user");
|
||||
|
||||
// 2. Server responds with server-first-message
|
||||
let server_first = server.server_first_message(creds.clone());
|
||||
assert!(server_first.starts_with(&format!("r=rOprNGfwEbeRWgbNEkqO{}", server.server_nonce)));
|
||||
assert!(server_first.contains("s="));
|
||||
assert!(server_first.contains("i=4096"));
|
||||
|
||||
// 3. Client computes proof
|
||||
// SaltedPassword
|
||||
let mut salted_password = [0u8; 32];
|
||||
pbkdf2::pbkdf2_hmac::<Sha256>(
|
||||
password.as_bytes(),
|
||||
salt,
|
||||
iterations,
|
||||
&mut salted_password,
|
||||
);
|
||||
|
||||
let client_key = hmac_sha256(&salted_password, b"Client Key");
|
||||
let stored_key = sha256(&client_key);
|
||||
|
||||
let client_first_bare = "n=user,r=rOprNGfwEbeRWgbNEkqO";
|
||||
let client_final_without_proof = format!("c=biws,r={}", server.combined_nonce);
|
||||
let auth_message = format!("{},{},{}", client_first_bare, server_first, client_final_without_proof);
|
||||
|
||||
let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
|
||||
let client_proof: Vec<u8> = client_key
|
||||
.iter()
|
||||
.zip(client_signature.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect();
|
||||
let proof_b64 = BASE64.encode(&client_proof);
|
||||
|
||||
let client_final = format!("c=biws,r={},p={}", server.combined_nonce, proof_b64);
|
||||
|
||||
// 4. Server verifies proof
|
||||
let result = server.process_client_final(&client_final);
|
||||
assert!(result.is_ok(), "SCRAM verification failed: {:?}", result.err());
|
||||
let server_final = result.unwrap();
|
||||
assert!(server_final.starts_with("v="));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scram_wrong_password() {
|
||||
let password = "pencil";
|
||||
let wrong_password = "wrong";
|
||||
let salt = b"test-salt";
|
||||
let iterations = 4096;
|
||||
|
||||
let creds = compute_scram_credentials(password, salt, iterations);
|
||||
|
||||
let client_first = "n,,n=user,r=clientnonce123";
|
||||
let mut server = ScramServer::from_client_first(client_first).unwrap();
|
||||
let server_first = server.server_first_message(creds);
|
||||
|
||||
// Client computes proof with wrong password
|
||||
let mut salted_password = [0u8; 32];
|
||||
pbkdf2::pbkdf2_hmac::<Sha256>(
|
||||
wrong_password.as_bytes(),
|
||||
salt,
|
||||
iterations,
|
||||
&mut salted_password,
|
||||
);
|
||||
|
||||
let client_key = hmac_sha256(&salted_password, b"Client Key");
|
||||
let stored_key = sha256(&client_key);
|
||||
|
||||
let client_first_bare = "n=user,r=clientnonce123";
|
||||
let client_final_without_proof = format!("c=biws,r={}", server.combined_nonce);
|
||||
let auth_message = format!("{},{},{}", client_first_bare, server_first, client_final_without_proof);
|
||||
|
||||
let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
|
||||
let client_proof: Vec<u8> = client_key
|
||||
.iter()
|
||||
.zip(client_signature.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect();
|
||||
let proof_b64 = BASE64.encode(&client_proof);
|
||||
|
||||
let client_final = format!("c=biws,r={},p={}", server.combined_nonce, proof_b64);
|
||||
let result = server.process_client_final(&client_final);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_scram_credentials() {
|
||||
let creds = compute_scram_credentials("password", b"salt", 4096);
|
||||
assert_eq!(creds.salt, b"salt");
|
||||
assert_eq!(creds.iterations, 4096);
|
||||
assert_eq!(creds.stored_key.len(), 32);
|
||||
assert_eq!(creds.server_key.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_client_first() {
|
||||
assert!(ScramServer::from_client_first("invalid").is_err());
|
||||
assert!(ScramServer::from_client_first("n,,").is_err());
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ use crate::rate_limiter::{RateLimitConfig, RateLimiter};
|
||||
use hickory_resolver::TokioResolver;
|
||||
use mailer_security::MessageAuthenticator;
|
||||
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use std::collections::HashMap;
|
||||
use std::io::BufReader;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
@@ -263,6 +264,69 @@ async fn accept_loop(
|
||||
}
|
||||
}
|
||||
|
||||
/// SNI-based certificate resolver that selects the appropriate TLS certificate
|
||||
/// based on the client's requested hostname.
|
||||
struct SniCertResolver {
|
||||
/// Domain -> certified key mapping.
|
||||
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
|
||||
/// Default certificate for non-matching SNI or missing SNI.
|
||||
default: Arc<rustls::sign::CertifiedKey>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SniCertResolver {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SniCertResolver")
|
||||
.field("domains", &self.certs.keys().collect::<Vec<_>>())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::server::ResolvesServerCert for SniCertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
client_hello: rustls::server::ClientHello<'_>,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
if let Some(sni) = client_hello.server_name() {
|
||||
let sni_lower = sni.to_lowercase();
|
||||
if let Some(key) = self.certs.get(&sni_lower) {
|
||||
return Some(key.clone());
|
||||
}
|
||||
}
|
||||
Some(self.default.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a PEM cert+key pair into a `CertifiedKey`.
|
||||
fn parse_certified_key(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> Result<rustls::sign::CertifiedKey, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let certs: Vec<CertificateDer<'static>> = {
|
||||
let mut reader = BufReader::new(cert_pem.as_bytes());
|
||||
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?
|
||||
};
|
||||
if certs.is_empty() {
|
||||
return Err("No certificates found in PEM".into());
|
||||
}
|
||||
|
||||
let key: PrivateKeyDer<'static> = {
|
||||
let mut reader = BufReader::new(key_pem.as_bytes());
|
||||
let mut keys = Vec::new();
|
||||
for item in rustls_pemfile::read_all(&mut reader) {
|
||||
match item? {
|
||||
rustls_pemfile::Item::Pkcs8Key(key) => keys.push(PrivateKeyDer::Pkcs8(key)),
|
||||
rustls_pemfile::Item::Pkcs1Key(key) => keys.push(PrivateKeyDer::Pkcs1(key)),
|
||||
rustls_pemfile::Item::Sec1Key(key) => keys.push(PrivateKeyDer::Sec1(key)),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
keys.into_iter().next().ok_or("No private key found in PEM")?
|
||||
};
|
||||
|
||||
let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)?;
|
||||
Ok(rustls::sign::CertifiedKey::new(certs, signing_key))
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor from PEM cert/key strings.
|
||||
fn build_tls_acceptor(
|
||||
config: &SmtpServerConfig,
|
||||
@@ -311,9 +375,42 @@ fn build_tls_acceptor(
|
||||
.ok_or("No private key found in PEM")?
|
||||
};
|
||||
|
||||
let tls_config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?;
|
||||
// If additional TLS certs are configured, use SNI-based resolution
|
||||
let tls_config = if config.additional_tls_certs.is_empty() {
|
||||
rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?
|
||||
} else {
|
||||
// Build default certified key
|
||||
let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)?;
|
||||
let default_ck = Arc::new(rustls::sign::CertifiedKey::new(certs, signing_key));
|
||||
|
||||
// Build per-domain certs
|
||||
let mut domain_certs = HashMap::new();
|
||||
for domain_cert in &config.additional_tls_certs {
|
||||
match parse_certified_key(&domain_cert.cert_pem, &domain_cert.key_pem) {
|
||||
Ok(ck) => {
|
||||
let ck = Arc::new(ck);
|
||||
for domain in &domain_cert.domains {
|
||||
domain_certs.insert(domain.to_lowercase(), ck.clone());
|
||||
}
|
||||
info!("SNI cert loaded for domains: {:?}", domain_cert.domains);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to load SNI cert for domains {:?}: {}", domain_cert.domains, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let resolver = SniCertResolver {
|
||||
certs: domain_certs,
|
||||
default: default_ck,
|
||||
};
|
||||
|
||||
rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(Arc::new(resolver))
|
||||
};
|
||||
|
||||
Ok(tokio_rustls::TlsAcceptor::from(Arc::new(tls_config)))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user