Files
smartmta/rust/crates/mailer-smtp/src/connection.rs

1629 lines
55 KiB
Rust

//! Per-connection SMTP handler.
//!
//! Manages the read/write loop for a single SMTP connection.
//! Dispatches parsed commands, handles DATA mode, and manages
//! authentication flow.
use crate::command::{parse_command, AuthMechanism, ParseError, SmtpCommand};
use crate::config::SmtpServerConfig;
use crate::data::{DataAccumulator, DataAction};
use crate::rate_limiter::RateLimiter;
use crate::response::{build_capabilities, SmtpResponse};
use crate::scram::{ScramCredentials, ScramServer};
use crate::session::{AuthState, SmtpSession};
use crate::validation;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use hickory_resolver::TokioResolver;
use mailer_security::MessageAuthenticator;
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tokio::time::{timeout, Duration};
use tokio_rustls::server::TlsStream;
use tracing::{debug, info, warn};
/// Events emitted by a connection handler to the server.
///
/// `Debug` is implemented by hand instead of derived so that the plaintext
/// password carried by [`ConnectionEvent::AuthRequest`] is never written to
/// logs when an event is formatted with `{:?}`.
#[derive(Serialize, Deserialize)]
pub enum ConnectionEvent {
    /// A complete email has been received and needs processing.
    EmailReceived {
        correlation_id: String,
        session_id: String,
        mail_from: String,
        rcpt_to: Vec<String>,
        /// Base64-encoded raw message for inline, or file path for large messages.
        data: EmailData,
        remote_addr: String,
        client_hostname: Option<String>,
        secure: bool,
        authenticated_user: Option<String>,
        /// In-process security results (DKIM, SPF, DMARC, content scan).
        security_results: Option<serde_json::Value>,
    },
    /// An authentication request that needs TS validation.
    AuthRequest {
        correlation_id: String,
        session_id: String,
        username: String,
        /// Plaintext password supplied by the client; redacted in `Debug` output.
        password: String,
        remote_addr: String,
    },
    /// A SCRAM credential request — Rust needs stored credentials from TS.
    ScramCredentialRequest {
        correlation_id: String,
        session_id: String,
        username: String,
        remote_addr: String,
    },
}

impl std::fmt::Debug for ConnectionEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConnectionEvent::EmailReceived {
                correlation_id,
                session_id,
                mail_from,
                rcpt_to,
                data,
                remote_addr,
                client_hostname,
                secure,
                authenticated_user,
                security_results,
            } => f
                .debug_struct("EmailReceived")
                .field("correlation_id", correlation_id)
                .field("session_id", session_id)
                .field("mail_from", mail_from)
                .field("rcpt_to", rcpt_to)
                .field("data", data)
                .field("remote_addr", remote_addr)
                .field("client_hostname", client_hostname)
                .field("secure", secure)
                .field("authenticated_user", authenticated_user)
                .field("security_results", security_results)
                .finish(),
            ConnectionEvent::AuthRequest {
                correlation_id,
                session_id,
                username,
                password: _,
                remote_addr,
            } => f
                .debug_struct("AuthRequest")
                .field("correlation_id", correlation_id)
                .field("session_id", session_id)
                .field("username", username)
                // Never print the credential itself.
                .field("password", &"<redacted>")
                .field("remote_addr", remote_addr)
                .finish(),
            ConnectionEvent::ScramCredentialRequest {
                correlation_id,
                session_id,
                username,
                remote_addr,
            } => f
                .debug_struct("ScramCredentialRequest")
                .field("correlation_id", correlation_id)
                .field("session_id", session_id)
                .field("username", username)
                .field("remote_addr", remote_addr)
                .finish(),
        }
    }
}
/// How email data is transported from Rust to TS.
///
/// Serialized with an internal `"type"` tag (`"inline"` or `"file"`) placed
/// in the same object as the variant's field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum EmailData {
    /// Inline base64-encoded data (for messages <= 256KB).
    ///
    /// `handle_data` also falls back to this variant for larger messages when
    /// writing the temp file fails.
    #[serde(rename = "inline")]
    Inline { base64: String },
    /// Path to a temp file containing the raw message (for large messages).
    #[serde(rename = "file")]
    File { path: String },
}
/// Result of TS processing an email.
///
/// Delivered through the receiver returned by
/// `CallbackRegistry::register_email_callback`. `handle_data` also
/// synthesizes one (451, not accepted) when the channel is dropped or the
/// processing timeout expires.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailProcessingResult {
    /// Whether the message was accepted; selects a 2xx `ok` reply vs. `smtp_code`.
    pub accepted: bool,
    /// Reply code used only when rejected; `handle_data` falls back to 550 if `None`.
    pub smtp_code: Option<u16>,
    /// Reply text; `handle_data` substitutes "Message accepted"/"Message rejected" if `None`.
    pub smtp_message: Option<String>,
}
/// Result of TS processing an auth request.
///
/// Delivered through the receiver returned by
/// `CallbackRegistry::register_auth_callback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthResult {
    /// Whether the supplied credentials were accepted.
    pub success: bool,
    /// Optional explanatory text (e.g. "Auth processing error" when the
    /// connection handler synthesizes a failure itself).
    pub message: Option<String>,
}
/// Result of TS returning SCRAM credentials for a user.
///
/// When `found` is true, `handle_auth_scram` substitutes defaults for any
/// missing field: empty byte vectors for the salt/keys and 4096 iterations.
/// NOTE(review): `stored_key`/`server_key` are secret-equivalent and the
/// derived `Debug` prints them; avoid logging this type with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScramCredentialResult {
    /// Whether the user exists; when false the other fields are ignored.
    pub found: bool,
    /// Salt used to derive the stored SCRAM keys.
    pub salt: Option<Vec<u8>>,
    /// PBKDF2 iteration count (defaults to 4096 when absent).
    pub iterations: Option<u32>,
    /// SCRAM `StoredKey` for proof verification.
    pub stored_key: Option<Vec<u8>>,
    /// SCRAM `ServerKey` for the server signature.
    pub server_key: Option<Vec<u8>>,
}
/// Abstraction over plain and TLS streams.
///
/// Both variants wrap the socket in a `BufReader`, so reads are buffered and
/// pipelined input can be detected. A `Plain` stream can be turned back into
/// its raw `TcpStream` for a STARTTLS upgrade.
pub enum SmtpStream {
    /// Unencrypted TCP connection.
    Plain(BufReader<TcpStream>),
    /// TLS connection (implicit TLS or after STARTTLS).
    Tls(BufReader<TlsStream<TcpStream>>),
}
impl SmtpStream {
/// Read a line from the stream (up to max_line_length bytes).
pub async fn read_line(&mut self, buf: &mut String, max_len: usize) -> std::io::Result<usize> {
match self {
SmtpStream::Plain(reader) => {
let result = reader.read_line(buf).await?;
if buf.len() > max_len {
buf.truncate(max_len);
}
Ok(result)
}
SmtpStream::Tls(reader) => {
let result = reader.read_line(buf).await?;
if buf.len() > max_len {
buf.truncate(max_len);
}
Ok(result)
}
}
}
/// Read bytes from the stream into a buffer.
pub async fn read_chunk(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
use tokio::io::AsyncReadExt;
match self {
SmtpStream::Plain(reader) => reader.read(buf).await,
SmtpStream::Tls(reader) => reader.read(buf).await,
}
}
/// Write bytes to the stream.
pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
match self {
SmtpStream::Plain(reader) => reader.get_mut().write_all(buf).await,
SmtpStream::Tls(reader) => reader.get_mut().write_all(buf).await,
}
}
/// Flush the write buffer.
pub async fn flush(&mut self) -> std::io::Result<()> {
match self {
SmtpStream::Plain(reader) => reader.get_mut().flush().await,
SmtpStream::Tls(reader) => reader.get_mut().flush().await,
}
}
/// Check if the internal buffer has unread data (pipelined commands).
pub fn has_buffered_data(&self) -> bool {
match self {
SmtpStream::Plain(reader) => !reader.buffer().is_empty(),
SmtpStream::Tls(reader) => !reader.buffer().is_empty(),
}
}
/// Unwrap to get the raw TcpStream for STARTTLS upgrade.
/// Only works on Plain streams.
pub fn into_tcp_stream(self) -> Option<TcpStream> {
match self {
SmtpStream::Plain(reader) => Some(reader.into_inner()),
SmtpStream::Tls(_) => None,
}
}
}
/// Handle a single SMTP connection.
///
/// This is the main entry point spawned for each incoming connection.
pub async fn handle_connection(
mut stream: SmtpStream,
config: Arc<SmtpServerConfig>,
rate_limiter: Arc<RateLimiter>,
event_tx: mpsc::Sender<ConnectionEvent>,
callback_register: Arc<dyn CallbackRegistry + Send + Sync>,
tls_acceptor: Option<Arc<tokio_rustls::TlsAcceptor>>,
remote_addr: String,
is_secure: bool,
authenticator: Arc<MessageAuthenticator>,
resolver: Arc<TokioResolver>,
) {
let mut session = SmtpSession::new(remote_addr.clone(), is_secure);
// Check IP rate limit
if !rate_limiter.check_connection(&remote_addr) {
let resp = SmtpResponse::service_unavailable(
&config.hostname,
"Too many connections from your IP",
);
let _ = stream.write_all(&resp.to_bytes()).await;
let _ = stream.flush().await;
info!(
session_id = %session.id,
remote_addr = %remote_addr,
"Connection rejected: rate limit exceeded"
);
return;
}
// Send greeting
let greeting = SmtpResponse::greeting(&config.hostname);
if stream.write_all(&greeting.to_bytes()).await.is_err() {
return;
}
if stream.flush().await.is_err() {
return;
}
let socket_timeout = Duration::from_secs(config.socket_timeout_secs);
loop {
let mut line = String::new();
let read_result = timeout(socket_timeout, stream.read_line(&mut line, 4096)).await;
match read_result {
Err(_) => {
// Timeout
let resp = SmtpResponse::service_unavailable(
&config.hostname,
"Connection timed out",
);
let _ = stream.write_all(&resp.to_bytes()).await;
let _ = stream.flush().await;
debug!(session_id = %session.id, "Connection timed out");
break;
}
Ok(Err(e)) => {
debug!(session_id = %session.id, error = %e, "Read error");
break;
}
Ok(Ok(0)) => {
debug!(session_id = %session.id, "Client disconnected");
break;
}
Ok(Ok(_)) => {
// Process the first command
let response = process_line(
&line,
&mut session,
&mut stream,
&config,
&rate_limiter,
&event_tx,
callback_register.as_ref(),
&tls_acceptor,
&authenticator,
&resolver,
)
.await;
// Check for pipelined commands in the buffer.
// Collect pipelinable responses into a batch for single write.
let mut response_batch: Vec<u8> = Vec::new();
let mut should_break = false;
let mut starttls_signal = false;
match response {
LineResult::Response(resp) => {
response_batch.extend_from_slice(&resp.to_bytes());
}
LineResult::Quit(resp) => {
let _ = stream.write_all(&resp.to_bytes()).await;
let _ = stream.flush().await;
should_break = true;
}
LineResult::StartTlsSignal => {
starttls_signal = true;
}
LineResult::NoResponse => {}
LineResult::Disconnect => {
should_break = true;
}
}
if should_break {
break;
}
// Process additional pipelined commands from the buffer
if !starttls_signal {
while stream.has_buffered_data() {
let mut next_line = String::new();
match stream.read_line(&mut next_line, 4096).await {
Ok(0) | Err(_) => break,
Ok(_) => {
let next_response = process_line(
&next_line,
&mut session,
&mut stream,
&config,
&rate_limiter,
&event_tx,
callback_register.as_ref(),
&tls_acceptor,
&authenticator,
&resolver,
)
.await;
match next_response {
LineResult::Response(resp) => {
response_batch.extend_from_slice(&resp.to_bytes());
}
LineResult::Quit(resp) => {
response_batch.extend_from_slice(&resp.to_bytes());
should_break = true;
break;
}
LineResult::StartTlsSignal | LineResult::Disconnect => {
// Non-pipelinable: flush batch and handle
starttls_signal = matches!(next_response, LineResult::StartTlsSignal);
should_break = matches!(next_response, LineResult::Disconnect);
break;
}
LineResult::NoResponse => {}
}
}
}
}
}
// Flush the accumulated response batch in one write
if !response_batch.is_empty() {
if stream.write_all(&response_batch).await.is_err() {
break;
}
if stream.flush().await.is_err() {
break;
}
}
if should_break {
break;
}
if starttls_signal {
// Send 220 Ready response
let resp = SmtpResponse::new(220, "Ready to start TLS");
if stream.write_all(&resp.to_bytes()).await.is_err() {
break;
}
if stream.flush().await.is_err() {
break;
}
// Extract TCP stream and upgrade
if let Some(tcp_stream) = stream.into_tcp_stream() {
if let Some(acceptor) = &tls_acceptor {
match acceptor.accept(tcp_stream).await {
Ok(tls_stream) => {
stream = SmtpStream::Tls(BufReader::new(tls_stream));
session.secure = true;
session.state = crate::state::SmtpState::Connected;
session.client_hostname = None;
session.esmtp = false;
session.auth_state = AuthState::None;
session.envelope = Default::default();
debug!(session_id = %session.id, "TLS upgrade successful");
}
Err(e) => {
warn!(session_id = %session.id, error = %e, "TLS handshake failed");
break;
}
}
} else {
break;
}
} else {
break;
}
}
}
}
}
info!(
session_id = %session.id,
remote_addr = %remote_addr,
messages = session.message_count,
"Connection closed"
);
}
/// Result of processing a single input line.
///
/// Returned by `process_line` and the `handle_*` helpers; the connection loop
/// in `handle_connection` decides how and when each variant is written.
enum LineResult {
    /// Send this response to the client (may be batched with other pipelined replies).
    Response(SmtpResponse),
    /// Send this response then close the connection.
    Quit(SmtpResponse),
    /// Signal that STARTTLS should be performed (main loop sends 220 and upgrades).
    StartTlsSignal,
    /// No response needed (handled internally).
    NoResponse,
    /// Disconnect immediately.
    Disconnect,
}
/// Trait for registering and resolving correlation-ID callbacks.
///
/// Connection handlers register a correlation ID *before* sending the
/// matching [`ConnectionEvent`], then await the returned receiver under a
/// timeout. How the sending half gets completed is up to the implementation
/// (presumably when TS replies with that correlation ID; not visible here).
pub trait CallbackRegistry: Send + Sync {
    /// Register a callback for email processing and return a receiver.
    fn register_email_callback(
        &self,
        correlation_id: &str,
    ) -> oneshot::Receiver<EmailProcessingResult>;
    /// Register a callback for auth and return a receiver.
    fn register_auth_callback(
        &self,
        correlation_id: &str,
    ) -> oneshot::Receiver<AuthResult>;
    /// Register a callback for SCRAM credential lookup and return a receiver.
    fn register_scram_callback(
        &self,
        correlation_id: &str,
    ) -> oneshot::Receiver<ScramCredentialResult>;
}
/// Process a single input line from the client.
async fn process_line(
line: &str,
session: &mut SmtpSession,
stream: &mut SmtpStream,
config: &SmtpServerConfig,
rate_limiter: &RateLimiter,
event_tx: &mpsc::Sender<ConnectionEvent>,
callback_registry: &dyn CallbackRegistry,
tls_acceptor: &Option<Arc<tokio_rustls::TlsAcceptor>>,
authenticator: &Arc<MessageAuthenticator>,
resolver: &Arc<TokioResolver>,
) -> LineResult {
// Handle AUTH intermediate states (waiting for username/password)
match &session.auth_state {
AuthState::WaitingForUsername => {
return handle_auth_username(line.trim(), session);
}
AuthState::WaitingForPassword { .. } => {
return handle_auth_password(
line.trim(),
session,
config,
rate_limiter,
event_tx,
callback_registry,
)
.await;
}
_ => {}
}
// Parse SMTP command
let cmd = match parse_command(line) {
Ok(cmd) => cmd,
Err(ParseError::Empty) => return LineResult::NoResponse,
Err(_) => {
if session.record_invalid_command() {
return LineResult::Quit(SmtpResponse::service_unavailable(
&config.hostname,
"Too many invalid commands",
));
}
return LineResult::Response(SmtpResponse::syntax_error());
}
};
match cmd {
SmtpCommand::Ehlo(hostname) => handle_ehlo(hostname, true, session, config),
SmtpCommand::Helo(hostname) => handle_ehlo(hostname, false, session, config),
SmtpCommand::MailFrom { address, params } => {
handle_mail_from(address, params, session, config, rate_limiter)
}
SmtpCommand::RcptTo { address, params } => {
handle_rcpt_to(address, params, session, config)
}
SmtpCommand::Data => {
handle_data(session, stream, config, event_tx, callback_registry, authenticator, resolver).await
}
SmtpCommand::Rset => {
session.reset_transaction();
LineResult::Response(SmtpResponse::ok("OK"))
}
SmtpCommand::Noop => LineResult::Response(SmtpResponse::ok("OK")),
SmtpCommand::Quit => {
LineResult::Quit(SmtpResponse::closing(&config.hostname))
}
SmtpCommand::StartTls => {
handle_starttls(session, config, tls_acceptor)
}
SmtpCommand::Auth {
mechanism,
initial_response,
} => {
if matches!(mechanism, AuthMechanism::ScramSha256) {
handle_auth_scram(
initial_response,
session,
stream,
config,
rate_limiter,
event_tx,
callback_registry,
)
.await
} else {
handle_auth(
mechanism,
initial_response,
session,
config,
rate_limiter,
event_tx,
callback_registry,
)
.await
}
}
SmtpCommand::Help(_) => {
LineResult::Response(SmtpResponse::new(
214,
"EHLO HELO MAIL RCPT DATA RSET NOOP QUIT STARTTLS AUTH HELP VRFY",
))
}
SmtpCommand::Vrfy(_) => {
LineResult::Response(SmtpResponse::new(252, "Cannot VRFY user"))
}
SmtpCommand::Expn(_) => {
LineResult::Response(SmtpResponse::not_implemented())
}
}
}
/// Handle EHLO/HELO command.
///
/// Rejects a syntactically invalid client hostname. Otherwise it resets the
/// session for a fresh greeting and replies:
/// - EHLO (`esmtp == true`): the capability list.
/// - HELO: a plain "<hostname> Hello" line.
fn handle_ehlo(
    hostname: String,
    esmtp: bool,
    session: &mut SmtpSession,
    config: &SmtpServerConfig,
) -> LineResult {
    let hostname_ok = validation::is_valid_ehlo_hostname(&hostname);
    if !hostname_ok {
        return LineResult::Response(SmtpResponse::param_error("Invalid hostname"));
    }

    session.reset_for_ehlo(hostname, esmtp);

    let reply = if !esmtp {
        SmtpResponse::ok(format!("{} Hello", config.hostname))
    } else {
        // Capabilities depend on TLS state, so read `secure` after the reset.
        let capabilities = build_capabilities(
            config.max_message_size,
            config.has_tls(),
            session.secure,
            config.auth_enabled,
        );
        SmtpResponse::ehlo_response(&config.hostname, &capabilities)
    };
    LineResult::Response(reply)
}
/// Handle MAIL FROM command.
///
/// Checks, in order:
/// 1. The session may start a transaction.
/// 2. The sender address is valid.
/// 3. Any `SIZE=` value is acceptable for `max_message_size`.
/// 4. A non-empty sender is within its message rate limit.
///
/// On success the sender, declared size and `BODY=` value are stored on the
/// envelope, and the state machine advances.
fn handle_mail_from(
    address: String,
    params: std::collections::HashMap<String, Option<String>>,
    session: &mut SmtpSession,
    config: &SmtpServerConfig,
    rate_limiter: &RateLimiter,
) -> LineResult {
    if !session.state.can_mail_from() {
        return LineResult::Response(SmtpResponse::bad_sequence("Send EHLO/HELO first"));
    }
    if !validation::is_valid_smtp_address(&address) {
        return LineResult::Response(SmtpResponse::param_error("Invalid sender address"));
    }

    // `SIZE` may be absent, present without a value, or carry a value.
    let size_value = params.get("SIZE").and_then(|value| value.as_ref());
    if let Some(size_str) = size_value {
        if let Err(msg) = validation::validate_size_param(size_str, config.max_message_size) {
            return LineResult::Response(SmtpResponse::new(552, msg));
        }
    }

    // An empty sender (presumably the null reverse-path) skips the per-sender limit.
    let sender_limited = !address.is_empty() && !rate_limiter.check_message(&address);
    if sender_limited {
        return LineResult::Response(SmtpResponse::temp_failure(
            "Too many messages from this sender, try again later",
        ));
    }

    session.envelope.mail_from = address;
    session.envelope.declared_size = size_value.and_then(|s| s.parse().ok());
    session.envelope.body_type = params.get("BODY").cloned().flatten();

    if let Ok(next_state) = session.state.transition_mail_from() {
        session.state = next_state;
        LineResult::Response(SmtpResponse::ok("OK"))
    } else {
        LineResult::Response(SmtpResponse::bad_sequence("Bad sequence of commands"))
    }
}
/// Handle RCPT TO command.
///
/// Requires a transaction started by MAIL FROM and a valid, non-empty
/// address. A reply of 452 is given once `max_recipients` addresses have
/// been accepted. Accepted recipients are appended to the envelope and the
/// state machine advances. RCPT parameters are currently ignored.
fn handle_rcpt_to(
    address: String,
    _params: std::collections::HashMap<String, Option<String>>,
    session: &mut SmtpSession,
    config: &SmtpServerConfig,
) -> LineResult {
    if !session.state.can_rcpt_to() {
        return LineResult::Response(SmtpResponse::bad_sequence("Send MAIL FROM first"));
    }

    let address_ok = validation::is_valid_smtp_address(&address) && !address.is_empty();
    if !address_ok {
        return LineResult::Response(SmtpResponse::param_error("Invalid recipient address"));
    }

    let at_capacity = session.envelope.rcpt_to.len() >= config.max_recipients as usize;
    if at_capacity {
        return LineResult::Response(SmtpResponse::new(452, "Too many recipients"));
    }

    session.envelope.rcpt_to.push(address);

    if let Ok(next_state) = session.state.transition_rcpt_to() {
        session.state = next_state;
        LineResult::Response(SmtpResponse::ok("OK"))
    } else {
        LineResult::Response(SmtpResponse::bad_sequence("Bad sequence of commands"))
    }
}
/// Handle DATA command: switch to data mode, accumulate, then emit event.
async fn handle_data(
session: &mut SmtpSession,
stream: &mut SmtpStream,
config: &SmtpServerConfig,
event_tx: &mpsc::Sender<ConnectionEvent>,
callback_registry: &dyn CallbackRegistry,
authenticator: &Arc<MessageAuthenticator>,
resolver: &Arc<TokioResolver>,
) -> LineResult {
if !session.state.can_data() {
return LineResult::Response(SmtpResponse::bad_sequence(
"Send RCPT TO first",
));
}
// Transition to DATA state
session.state = match session.state.transition_data() {
Ok(s) => s,
Err(_) => {
return LineResult::Response(SmtpResponse::bad_sequence(
"Bad sequence of commands",
));
}
};
// Send 354
let start_resp = SmtpResponse::start_data();
if stream.write_all(&start_resp.to_bytes()).await.is_err() {
return LineResult::Disconnect;
}
if stream.flush().await.is_err() {
return LineResult::Disconnect;
}
// Accumulate data
let mut accumulator = DataAccumulator::new(config.max_message_size);
let data_timeout = Duration::from_secs(config.data_timeout_secs);
let mut buf = [0u8; 8192];
loop {
let read_result = timeout(data_timeout, stream.read_chunk(&mut buf)).await;
match read_result {
Err(_) => {
// Data timeout
return LineResult::Quit(SmtpResponse::service_unavailable(
&config.hostname,
"Data timeout",
));
}
Ok(Err(_)) => return LineResult::Disconnect,
Ok(Ok(0)) => return LineResult::Disconnect,
Ok(Ok(n)) => {
match accumulator.process_chunk(&buf[..n]) {
DataAction::Continue => continue,
DataAction::SizeExceeded => {
// Must still read until end-of-data to stay in sync
session.state = crate::state::SmtpState::Greeted;
session.envelope = Default::default();
return LineResult::Response(SmtpResponse::size_exceeded(
config.max_message_size,
));
}
DataAction::Complete => break,
}
}
}
}
// Data complete — prepare for delivery
let raw_message = accumulator.into_message().unwrap_or_default();
let correlation_id = uuid::Uuid::new_v4().to_string();
// --- In-process security pipeline (30s timeout) ---
let security_results = run_security_pipeline(
&raw_message,
&session.remote_addr,
session.client_hostname.as_deref().unwrap_or("unknown"),
&config.hostname,
&session.envelope.mail_from,
authenticator,
resolver,
)
.await;
// Determine transport: inline base64 or temp file
let email_data = if raw_message.len() <= 256 * 1024 {
EmailData::Inline {
base64: BASE64.encode(&raw_message),
}
} else {
// Write to temp file
let tmp_path = format!("/tmp/mailer-smtp-{}.eml", &correlation_id);
match tokio::fs::write(&tmp_path, &raw_message).await {
Ok(_) => EmailData::File { path: tmp_path },
Err(e) => {
warn!(error = %e, "Failed to write temp file for large email");
// Fall back to inline
EmailData::Inline {
base64: BASE64.encode(&raw_message),
}
}
}
};
// Register callback before sending event
let rx = callback_registry.register_email_callback(&correlation_id);
// Send event to TS
let event = ConnectionEvent::EmailReceived {
correlation_id: correlation_id.clone(),
session_id: session.id.clone(),
mail_from: session.envelope.mail_from.clone(),
rcpt_to: session.envelope.rcpt_to.clone(),
data: email_data,
remote_addr: session.remote_addr.clone(),
client_hostname: session.client_hostname.clone(),
secure: session.secure,
authenticated_user: session.authenticated_user().map(|s| s.to_string()),
security_results
};
if event_tx.send(event).await.is_err() {
warn!("Failed to send emailReceived event");
return LineResult::Response(SmtpResponse::local_error(
"Internal processing error",
));
}
// Wait for TS response with timeout
let processing_timeout = Duration::from_secs(config.processing_timeout_secs);
let result = match timeout(processing_timeout, rx).await {
Ok(Ok(result)) => result,
Ok(Err(_)) => {
warn!(correlation_id = %correlation_id, "Callback channel dropped");
EmailProcessingResult {
accepted: false,
smtp_code: Some(451),
smtp_message: Some("Processing error".into()),
}
}
Err(_) => {
warn!(correlation_id = %correlation_id, "Processing timeout");
EmailProcessingResult {
accepted: false,
smtp_code: Some(451),
smtp_message: Some("Processing timeout".into()),
}
}
};
// Reset transaction state
session.envelope = Default::default();
let _ = session.state.transition_finished();
session.state = crate::state::SmtpState::Finished;
session.record_message();
if result.accepted {
LineResult::Response(SmtpResponse::ok(
result.smtp_message.unwrap_or_else(|| "Message accepted".into()),
))
} else {
let code = result.smtp_code.unwrap_or(550);
let msg = result
.smtp_message
.unwrap_or_else(|| "Message rejected".into());
LineResult::Response(SmtpResponse::new(code, msg))
}
}
/// Handle STARTTLS command.
///
/// Refuses the upgrade in three cases:
/// - The session is already encrypted.
/// - The state machine forbids STARTTLS.
/// - No TLS acceptor/configuration is available (454).
///
/// Otherwise returns `StartTlsSignal`. The connection loop then sends 220,
/// swaps the stream and performs the handshake; nothing is written here.
fn handle_starttls(
    session: &SmtpSession,
    config: &SmtpServerConfig,
    tls_acceptor: &Option<Arc<tokio_rustls::TlsAcceptor>>,
) -> LineResult {
    let refusal = if session.secure {
        Some(SmtpResponse::bad_sequence("Already using TLS"))
    } else if !session.state.can_starttls() {
        Some(SmtpResponse::bad_sequence("STARTTLS not allowed in current state"))
    } else if tls_acceptor.is_none() || !config.has_tls() {
        Some(SmtpResponse::new(454, "TLS not available"))
    } else {
        None
    };

    match refusal {
        Some(reply) => LineResult::Response(reply),
        None => LineResult::StartTlsSignal,
    }
}
/// Handle AUTH command for the PLAIN and LOGIN mechanisms.
///
/// The command is refused when AUTH is disabled, the session is already
/// authenticated, or the state does not allow AUTH.
///
/// - PLAIN with an initial response is verified immediately.
/// - PLAIN without one replies with an empty 334 challenge and moves to
///   `WaitingForUsername`.
/// - LOGIN prompts for the username and/or password with base64-encoded
///   334 challenges. A username supplied as the initial response skips
///   straight to the password prompt.
///
/// SCRAM-SHA-256 is dispatched elsewhere; reaching that arm here yields
/// `not_implemented`.
async fn handle_auth(
    mechanism: AuthMechanism,
    initial_response: Option<String>,
    session: &mut SmtpSession,
    config: &SmtpServerConfig,
    rate_limiter: &RateLimiter,
    event_tx: &mpsc::Sender<ConnectionEvent>,
    callback_registry: &dyn CallbackRegistry,
) -> LineResult {
    if !config.auth_enabled {
        return LineResult::Response(SmtpResponse::not_implemented());
    }
    if session.is_authenticated() {
        return LineResult::Response(SmtpResponse::bad_sequence("Already authenticated"));
    }
    if !session.state.can_auth() {
        return LineResult::Response(SmtpResponse::bad_sequence("Send EHLO first"));
    }

    match (mechanism, initial_response) {
        (AuthMechanism::Plain, Some(response)) => {
            process_auth_plain(
                &response,
                session,
                config,
                rate_limiter,
                event_tx,
                callback_registry,
            )
            .await
        }
        (AuthMechanism::Plain, None) => {
            // Empty challenge: the credentials come on the next line.
            session.auth_state = AuthState::WaitingForUsername;
            LineResult::Response(SmtpResponse::auth_challenge(""))
        }
        (AuthMechanism::Login, Some(response)) => {
            let Ok(raw_username) = BASE64.decode(response.as_bytes()) else {
                return LineResult::Response(SmtpResponse::param_error(
                    "Invalid base64 encoding",
                ));
            };
            let username = String::from_utf8_lossy(&raw_username).into_owned();
            session.auth_state = AuthState::WaitingForPassword { username };
            let password_prompt = BASE64.encode(b"Password:");
            LineResult::Response(SmtpResponse::auth_challenge(&password_prompt))
        }
        (AuthMechanism::Login, None) => {
            session.auth_state = AuthState::WaitingForUsername;
            let username_prompt = BASE64.encode(b"Username:");
            LineResult::Response(SmtpResponse::auth_challenge(&username_prompt))
        }
        (AuthMechanism::ScramSha256, _) => LineResult::Response(SmtpResponse::not_implemented()),
    }
}
/// Handle AUTH SCRAM-SHA-256 — full exchange in a single async function.
///
/// SCRAM is a multi-step challenge-response protocol:
/// 1. Client sends client-first-message (in initial_response or after 334)
/// 2. Server requests SCRAM credentials from TS
/// 3. Server sends server-first-message (334 challenge)
/// 4. Client sends client-final-message (proof)
/// 5. Server verifies proof and responds with 235 or 535
///
/// Continuation lines are read directly from `stream`, each bounded by
/// `socket_timeout_secs`. A client may cancel at either prompt with `*`.
async fn handle_auth_scram(
    initial_response: Option<String>,
    session: &mut SmtpSession,
    stream: &mut SmtpStream,
    config: &SmtpServerConfig,
    rate_limiter: &RateLimiter,
    event_tx: &mpsc::Sender<ConnectionEvent>,
    callback_registry: &dyn CallbackRegistry,
) -> LineResult {
    if !config.auth_enabled {
        return LineResult::Response(SmtpResponse::not_implemented());
    }
    if session.is_authenticated() {
        return LineResult::Response(SmtpResponse::bad_sequence("Already authenticated"));
    }
    if !session.state.can_auth() {
        return LineResult::Response(SmtpResponse::bad_sequence("Send EHLO first"));
    }
    // Step 1: Get client-first-message, prompting with an empty 334 if it was
    // not supplied as the initial response.
    let client_first_b64 = match initial_response {
        Some(s) if !s.is_empty() => s,
        _ => {
            let Some(reply) = scram_round_trip(stream, config, "").await else {
                return LineResult::Disconnect;
            };
            if reply == "*" {
                return LineResult::Response(SmtpResponse::new(501, "Authentication cancelled"));
            }
            reply
        }
    };
    // Decode base64 client-first-message
    let client_first_bytes = match BASE64.decode(client_first_b64.as_bytes()) {
        Ok(b) => b,
        Err(_) => {
            return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
        }
    };
    let client_first = match String::from_utf8(client_first_bytes) {
        Ok(s) => s,
        Err(_) => {
            return LineResult::Response(SmtpResponse::param_error("Invalid UTF-8 in SCRAM message"));
        }
    };
    // Parse client-first-message
    let mut scram = match ScramServer::from_client_first(&client_first) {
        Ok(s) => s,
        Err(e) => {
            debug!(error = %e, "SCRAM client-first-message parse error");
            return LineResult::Response(SmtpResponse::param_error(
                "Invalid SCRAM client-first-message",
            ));
        }
    };
    // Step 2: Request SCRAM credentials from TS (register before sending).
    let correlation_id = uuid::Uuid::new_v4().to_string();
    let rx = callback_registry.register_scram_callback(&correlation_id);
    let event = ConnectionEvent::ScramCredentialRequest {
        correlation_id: correlation_id.clone(),
        session_id: session.id.clone(),
        username: scram.username.clone(),
        remote_addr: session.remote_addr.clone(),
    };
    if event_tx.send(event).await.is_err() {
        return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
    }
    // Wait for credentials from TS
    let cred_timeout = Duration::from_secs(5);
    let cred_result = match timeout(cred_timeout, rx).await {
        Ok(Ok(result)) => result,
        Ok(Err(_)) => {
            warn!(correlation_id = %correlation_id, "SCRAM credential callback dropped");
            return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
        }
        Err(_) => {
            warn!(correlation_id = %correlation_id, "SCRAM credential request timed out");
            return LineResult::Response(SmtpResponse::local_error("Internal processing error"));
        }
    };
    if !cred_result.found {
        // NOTE(review): an unknown user is rejected here, before any
        // server-first-message is sent, while a wrong password fails only
        // after the proof (Step 5). A client can observe that difference and
        // probe for valid usernames. Consider continuing the exchange with
        // dummy credentials instead.
        return scram_auth_failure(session, config, rate_limiter);
    }
    let creds = ScramCredentials {
        salt: cred_result.salt.unwrap_or_default(),
        iterations: cred_result.iterations.unwrap_or(4096),
        stored_key: cred_result.stored_key.unwrap_or_default(),
        server_key: cred_result.server_key.unwrap_or_default(),
    };
    // Steps 3 + 4: send server-first-message, read client-final-message.
    let server_first = scram.server_first_message(creds);
    let server_first_b64 = BASE64.encode(server_first.as_bytes());
    let Some(client_final_b64) = scram_round_trip(stream, config, &server_first_b64).await else {
        return LineResult::Disconnect;
    };
    // Cancel if *
    if client_final_b64 == "*" {
        session.auth_state = AuthState::None;
        return LineResult::Response(SmtpResponse::new(501, "Authentication cancelled"));
    }
    // Decode base64 client-final-message
    let client_final_bytes = match BASE64.decode(client_final_b64.as_bytes()) {
        Ok(b) => b,
        Err(_) => {
            session.auth_state = AuthState::None;
            return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
        }
    };
    let client_final = match String::from_utf8(client_final_bytes) {
        Ok(s) => s,
        Err(_) => {
            session.auth_state = AuthState::None;
            return LineResult::Response(SmtpResponse::param_error("Invalid UTF-8 in SCRAM message"));
        }
    };
    // Step 5: Verify proof
    match scram.process_client_final(&client_final) {
        Ok(server_final) => {
            let server_final_b64 = BASE64.encode(server_final.as_bytes());
            session.auth_state = AuthState::Authenticated {
                username: scram.username.clone(),
            };
            // NOTE(review): the server-final-message is carried in the 235
            // text; confirm against RFC 4954 whether clients expect it in a
            // final 334 challenge instead.
            LineResult::Response(SmtpResponse::new(
                235,
                format!("2.7.0 Authentication successful {}", server_final_b64),
            ))
        }
        Err(e) => {
            debug!(error = %e, "SCRAM proof verification failed");
            scram_auth_failure(session, config, rate_limiter)
        }
    }
}

/// Send a 334 SASL challenge and wait for the client's reply line.
///
/// The reply is returned with surrounding whitespace trimmed. Returns `None`
/// if writing fails, or if the read (bounded by `socket_timeout_secs`) times
/// out, errors, or hits EOF; the caller treats all of these as a disconnect.
async fn scram_round_trip(
    stream: &mut SmtpStream,
    config: &SmtpServerConfig,
    challenge: &str,
) -> Option<String> {
    let resp = SmtpResponse::auth_challenge(challenge);
    stream.write_all(&resp.to_bytes()).await.ok()?;
    stream.flush().await.ok()?;
    let mut line = String::new();
    let socket_timeout = Duration::from_secs(config.socket_timeout_secs);
    match timeout(socket_timeout, stream.read_line(&mut line, 4096)).await {
        Ok(Ok(n)) if n > 0 => Some(line.trim().to_string()),
        _ => None,
    }
}

/// Record a failed SCRAM attempt and build the reply.
///
/// Used for both the unknown-user and bad-proof paths, so both update the
/// session and IP limits the same way. Once the per-session failure limit is
/// exceeded, the connection is closed. At that point the IP-level limiter is
/// also notified via `check_auth_failure`; its return value does not change
/// the outcome because the connection closes either way.
fn scram_auth_failure(
    session: &mut SmtpSession,
    config: &SmtpServerConfig,
    rate_limiter: &RateLimiter,
) -> LineResult {
    session.auth_state = AuthState::None;
    if session.record_auth_failure(config.max_auth_failures) {
        // NOTE(review): the IP-level limiter is only consulted once the
        // per-session limit is hit; confirm whether every failure should be
        // recorded there.
        let _ = rate_limiter.check_auth_failure(&session.remote_addr);
        return LineResult::Quit(SmtpResponse::service_unavailable(
            &config.hostname,
            "Too many authentication failures",
        ));
    }
    LineResult::Response(SmtpResponse::auth_failed())
}
/// Handle username input during LOGIN auth flow.
///
/// `*` cancels the exchange (501). A base64 username advances the state to
/// `WaitingForPassword` and sends the base64 "Password:" prompt. Undecodable
/// input aborts the exchange with a parameter error.
fn handle_auth_username(line: &str, session: &mut SmtpSession) -> LineResult {
    if line == "*" {
        session.auth_state = AuthState::None;
        return LineResult::Response(SmtpResponse::new(501, "Authentication cancelled"));
    }

    let Ok(raw_username) = BASE64.decode(line.as_bytes()) else {
        session.auth_state = AuthState::None;
        return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
    };

    // Non-UTF-8 bytes are replaced rather than rejected.
    let username = String::from_utf8_lossy(&raw_username).into_owned();
    session.auth_state = AuthState::WaitingForPassword { username };
    let password_prompt = BASE64.encode(b"Password:");
    LineResult::Response(SmtpResponse::auth_challenge(&password_prompt))
}
/// Handle password input during LOGIN auth flow.
///
/// `*` cancels the exchange (501). The username remembered in the
/// `WaitingForPassword` state is paired with the base64-decoded password and
/// passed to `validate_credentials`. An unexpected state or undecodable input
/// aborts the exchange.
async fn handle_auth_password(
    line: &str,
    session: &mut SmtpSession,
    config: &SmtpServerConfig,
    rate_limiter: &RateLimiter,
    event_tx: &mpsc::Sender<ConnectionEvent>,
    callback_registry: &dyn CallbackRegistry,
) -> LineResult {
    if line == "*" {
        session.auth_state = AuthState::None;
        return LineResult::Response(SmtpResponse::new(501, "Authentication cancelled"));
    }

    let username = if let AuthState::WaitingForPassword { username } = &session.auth_state {
        username.clone()
    } else {
        session.auth_state = AuthState::None;
        return LineResult::Response(SmtpResponse::bad_sequence("Unexpected auth state"));
    };

    let Ok(raw_password) = BASE64.decode(line.as_bytes()) else {
        session.auth_state = AuthState::None;
        return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
    };
    let password = String::from_utf8_lossy(&raw_password).into_owned();

    validate_credentials(
        &username,
        &password,
        session,
        config,
        rate_limiter,
        event_tx,
        callback_registry,
    )
    .await
}
/// Process AUTH PLAIN credentials (base64-encoded `authzid\0authcid\0passwd`).
///
/// The authorization identity (first field) is ignored. The authentication
/// identity and password are decoded lossily as UTF-8 and passed to
/// `validate_credentials`. Replies with a parameter error for bad base64 or
/// fewer than three NUL-separated fields.
async fn process_auth_plain(
    base64_data: &str,
    session: &mut SmtpSession,
    config: &SmtpServerConfig,
    rate_limiter: &RateLimiter,
    event_tx: &mpsc::Sender<ConnectionEvent>,
    callback_registry: &dyn CallbackRegistry,
) -> LineResult {
    let Ok(decoded) = BASE64.decode(base64_data.as_bytes()) else {
        return LineResult::Response(SmtpResponse::param_error("Invalid base64 encoding"));
    };

    // At most three fields; any further NULs stay inside the password.
    let mut fields = decoded.splitn(3, |&byte| byte == 0);
    let (Some(_authzid), Some(authcid), Some(passwd)) = (fields.next(), fields.next(), fields.next())
    else {
        return LineResult::Response(SmtpResponse::param_error("Invalid PLAIN auth format"));
    };

    let username = String::from_utf8_lossy(authcid).into_owned();
    let password = String::from_utf8_lossy(passwd).into_owned();
    validate_credentials(
        &username,
        &password,
        session,
        config,
        rate_limiter,
        event_tx,
        callback_registry,
    )
    .await
}
/// Validate credentials by sending authRequest to TS and waiting for response.
async fn validate_credentials(
username: &str,
password: &str,
session: &mut SmtpSession,
config: &SmtpServerConfig,
rate_limiter: &RateLimiter,
event_tx: &mpsc::Sender<ConnectionEvent>,
callback_registry: &dyn CallbackRegistry,
) -> LineResult {
let correlation_id = uuid::Uuid::new_v4().to_string();
// Register callback before sending event
let rx = callback_registry.register_auth_callback(&correlation_id);
let event = ConnectionEvent::AuthRequest {
correlation_id: correlation_id.clone(),
session_id: session.id.clone(),
username: username.to_string(),
password: password.to_string(),
remote_addr: session.remote_addr.clone(),
};
if event_tx.send(event).await.is_err() {
session.auth_state = AuthState::None;
return LineResult::Response(SmtpResponse::local_error(
"Internal processing error",
));
}
// Wait for TS response
let auth_timeout = Duration::from_secs(5);
let result = match timeout(auth_timeout, rx).await {
Ok(Ok(result)) => result,
Ok(Err(_)) => AuthResult {
success: false,
message: Some("Auth processing error".into()),
},
Err(_) => AuthResult {
success: false,
message: Some("Auth timeout".into()),
},
};
if result.success {
session.auth_state = AuthState::Authenticated {
username: username.to_string(),
};
LineResult::Response(SmtpResponse::auth_success())
} else {
session.auth_state = AuthState::None;
let exceeded = session.record_auth_failure(config.max_auth_failures);
if exceeded {
if !rate_limiter.check_auth_failure(&session.remote_addr) {
return LineResult::Quit(SmtpResponse::service_unavailable(
&config.hostname,
"Too many authentication failures",
));
}
return LineResult::Quit(SmtpResponse::service_unavailable(
&config.hostname,
"Too many authentication failures",
));
}
LineResult::Response(SmtpResponse::auth_failed())
}
}
/// Extract MIME parts from a raw email message for content scanning.
///
/// Returns `(subject, text_body, html_body, attachment_filenames)`:
/// - `subject`: value of the first `Subject` header (as returned by
///   `MailHeader::get_value`), if present.
/// - `text_body` / `html_body`: body of the first non-attachment `text/plain` /
///   `text/html` part met in a pre-order walk of the MIME tree.
/// - `attachment_filenames`: names of parts whose disposition is `attachment`.
///   The name is taken from the `Content-Disposition` `filename` parameter or,
///   failing that, the `Content-Type` `name` parameter, which some clients set
///   instead. Attachments with neither are not listed.
///
/// If the message cannot be parsed at all, a debug message is logged and every
/// field is empty (`None` / empty `Vec`).
fn extract_mime_parts(
    raw_message: &[u8],
) -> (Option<String>, Option<String>, Option<String>, Vec<String>) {
    /// Visit `part` and then its subparts in order, filling the output slots.
    fn walk_parts(
        part: &mailparse::ParsedMail<'_>,
        text_body: &mut Option<String>,
        html_body: &mut Option<String>,
        attachments: &mut Vec<String>,
    ) {
        let content_type = part.ctype.mimetype.to_lowercase();
        let disposition = part.get_content_disposition();

        if disposition.disposition == mailparse::DispositionType::Attachment {
            let filename = disposition
                .params
                .get("filename")
                .or_else(|| part.ctype.params.get("name"));
            if let Some(filename) = filename {
                attachments.push(filename.clone());
            }
        } else if content_type == "text/plain" && text_body.is_none() {
            if let Ok(body) = part.get_body() {
                *text_body = Some(body);
            }
        } else if content_type == "text/html" && html_body.is_none() {
            if let Ok(body) = part.get_body() {
                *html_body = Some(body);
            }
        }

        for sub in &part.subparts {
            walk_parts(sub, text_body, html_body, attachments);
        }
    }

    let parsed = match mailparse::parse_mail(raw_message) {
        Ok(p) => p,
        Err(e) => {
            debug!(error = %e, "Failed to parse MIME for content scanning");
            return (None, None, None, Vec::new());
        }
    };

    let subject = parsed
        .headers
        .iter()
        .find(|h| h.get_key().eq_ignore_ascii_case("subject"))
        .map(|h| h.get_value());

    let mut text_body: Option<String> = None;
    let mut html_body: Option<String> = None;
    let mut attachments: Vec<String> = Vec::new();
    walk_parts(&parsed, &mut text_body, &mut html_body, &mut attachments);

    (subject, text_body, html_body, attachments)
}
/// Run the full security pipeline (DKIM/SPF/DMARC, content scan, IP reputation)
/// under a 30-second deadline.
///
/// Returns `Some(results)` when the pipeline completes. If the inner pipeline
/// returns an error or the deadline passes, a warning is logged and `None` is
/// returned, so the caller can still emit the email event without security results.
async fn run_security_pipeline(
    raw_message: &[u8],
    remote_addr: &str,
    helo_domain: &str,
    hostname: &str,
    mail_from: &str,
    authenticator: &Arc<MessageAuthenticator>,
    resolver: &Arc<TokioResolver>,
) -> Option<serde_json::Value> {
    let deadline = Duration::from_secs(30);
    let pipeline = run_security_pipeline_inner(
        raw_message,
        remote_addr,
        helo_domain,
        hostname,
        mail_from,
        authenticator,
        resolver,
    );

    let Ok(outcome) = timeout(deadline, pipeline).await else {
        warn!("Security pipeline timed out (30s) — emitting event without results");
        return None;
    };

    match outcome {
        Ok(results) => {
            debug!("In-process security pipeline completed");
            Some(results)
        }
        Err(e) => {
            warn!(error = %e, "Security pipeline error — emitting event without results");
            None
        }
    }
}
/// Inner implementation of the security pipeline (no timeout wrapper).
///
/// Steps:
/// 1. Derive the client IP from `remote_addr` with [`parse_peer_ip_for_security`].
///    If that fails, a warning is logged and `127.0.0.1` is used so the rest of
///    the pipeline (notably the content scan) still runs.
/// 2. Run DKIM/SPF/DMARC verification and the DNSBL reputation check concurrently.
/// 3. Extract MIME parts and run the synchronous content scan.
/// 4. Combine everything into a JSON object with the keys `dkim`, `spf`, `dmarc`,
///    `contentScan` and `ipReputation`. A failed verification or reputation check
///    is logged and stored as an empty array / `null` instead of aborting.
///
/// # Errors
/// Returns an error only if a result fails to convert via `serde_json::to_value`.
async fn run_security_pipeline_inner(
    raw_message: &[u8],
    remote_addr: &str,
    helo_domain: &str,
    hostname: &str,
    mail_from: &str,
    authenticator: &Arc<MessageAuthenticator>,
    resolver: &Arc<TokioResolver>,
) -> std::result::Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
    let ip = match parse_peer_ip_for_security(remote_addr) {
        Some(ip) => ip,
        None => {
            // Keep the best-effort fallback, but make it visible: SPF/DNSBL results
            // computed for 127.0.0.1 do not describe the real client.
            warn!(
                remote_addr = %remote_addr,
                "Could not parse remote address as an IP; using 127.0.0.1 for security checks"
            );
            IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
        }
    };

    // DKIM/SPF/DMARC and IP reputation are independent; run them concurrently.
    let (email_security, reputation) = tokio::join!(
        mailer_security::verify_email_security(
            raw_message, ip, helo_domain, hostname, mail_from, authenticator,
        ),
        mailer_security::check_reputation(
            ip, mailer_security::DEFAULT_DNSBL_SERVERS, resolver,
        ),
    );

    // MIME extraction and content scanning are synchronous.
    // NOTE(review): for very large messages this CPU work runs on the async task;
    // consider `spawn_blocking` if it shows up in profiles.
    let (subject, text_body, html_body, attachment_names) = extract_mime_parts(raw_message);
    let content_scan = mailer_security::content_scanner::scan_content(
        subject.as_deref(),
        text_body.as_deref(),
        html_body.as_deref(),
        &attachment_names,
    );

    let mut results = serde_json::Map::new();

    match email_security {
        Ok(sec) => {
            results.insert("dkim".into(), serde_json::to_value(&sec.dkim)?);
            results.insert("spf".into(), serde_json::to_value(&sec.spf)?);
            results.insert("dmarc".into(), serde_json::to_value(&sec.dmarc)?);
        }
        Err(e) => {
            warn!(error = %e, "Email security verification failed");
            results.insert("dkim".into(), serde_json::Value::Array(vec![]));
            results.insert("spf".into(), serde_json::Value::Null);
            results.insert("dmarc".into(), serde_json::Value::Null);
        }
    }

    results.insert("contentScan".into(), serde_json::to_value(&content_scan)?);

    match reputation {
        Ok(rep) => {
            results.insert("ipReputation".into(), serde_json::to_value(&rep)?);
        }
        Err(e) => {
            warn!(error = %e, "IP reputation check failed");
            results.insert("ipReputation".into(), serde_json::Value::Null);
        }
    }

    Ok(serde_json::Value::Object(results))
}

/// Parse a peer address string into the IP used for SPF and DNSBL checks.
///
/// Accepts a bare IP (`203.0.113.7`, `2001:db8::1`) or a socket address
/// (`203.0.113.7:25`, `[2001:db8::1]:25`). IPv4-mapped IPv6 addresses
/// (`::ffff:203.0.113.7`, typical of dual-stack listeners) are unwrapped to plain
/// IPv4, so address-family-specific checks see the client's real IPv4 address.
/// Returns `None` if `remote_addr` is neither form.
fn parse_peer_ip_for_security(remote_addr: &str) -> Option<IpAddr> {
    let ip = remote_addr
        .parse::<IpAddr>()
        .ok()
        .or_else(|| {
            remote_addr
                .parse::<std::net::SocketAddr>()
                .ok()
                .map(|sa| sa.ip())
        })?;
    Some(match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
        IpAddr::V4(_) => ip,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_email_data_serialization() {
        let inline = EmailData::Inline {
            base64: "dGVzdA==".into(),
        };
        let inline_json = serde_json::to_string(&inline).unwrap();
        assert!(inline_json.contains("inline"));

        let on_disk = EmailData::File {
            path: "/tmp/test.eml".into(),
        };
        let on_disk_json = serde_json::to_string(&on_disk).unwrap();
        assert!(on_disk_json.contains("file"));
    }

    #[test]
    fn test_processing_result_serialization() {
        let outcome = EmailProcessingResult {
            accepted: true,
            smtp_code: Some(250),
            smtp_message: Some("OK".into()),
        };
        let encoded = serde_json::to_string(&outcome).unwrap();
        assert!(encoded.contains("accepted"));
    }

    #[test]
    fn test_extract_mime_parts_simple() {
        let message = b"From: sender@example.com\r\n\
            To: rcpt@example.com\r\n\
            Subject: Test Subject\r\n\
            Content-Type: text/plain\r\n\
            \r\n\
            Hello, this is a test body.\r\n";
        let (subject, text, html, attachments) = extract_mime_parts(message);
        assert_eq!(subject.as_deref(), Some("Test Subject"));
        let text = text.expect("text/plain body should be extracted");
        assert!(text.contains("Hello, this is a test body."));
        assert!(html.is_none());
        assert!(attachments.is_empty());
    }

    #[test]
    fn test_extract_mime_parts_multipart() {
        let message = b"From: sender@example.com\r\n\
            To: rcpt@example.com\r\n\
            Subject: Multipart Test\r\n\
            Content-Type: multipart/mixed; boundary=\"boundary123\"\r\n\
            \r\n\
            --boundary123\r\n\
            Content-Type: text/plain\r\n\
            \r\n\
            Plain text body\r\n\
            --boundary123\r\n\
            Content-Type: text/html\r\n\
            \r\n\
            <html><body>HTML body</body></html>\r\n\
            --boundary123\r\n\
            Content-Type: application/octet-stream\r\n\
            Content-Disposition: attachment; filename=\"report.pdf\"\r\n\
            \r\n\
            binary data here\r\n\
            --boundary123--\r\n";
        let (subject, text, html, attachments) = extract_mime_parts(message);
        assert_eq!(subject.as_deref(), Some("Multipart Test"));
        let text = text.expect("text/plain part should be extracted");
        assert!(text.contains("Plain text body"));
        let html = html.expect("text/html part should be extracted");
        assert!(html.contains("HTML body"));
        assert_eq!(attachments, ["report.pdf"]);
    }

    #[test]
    fn test_extract_mime_parts_no_subject() {
        let message = b"From: sender@example.com\r\n\
            To: rcpt@example.com\r\n\
            Content-Type: text/plain\r\n\
            \r\n\
            Body without subject\r\n";
        let (subject, text, _, _) = extract_mime_parts(message);
        assert!(subject.is_none());
        assert!(text.is_some());
    }

    #[test]
    fn test_extract_mime_parts_invalid() {
        // Garbage input must not panic; whether anything is extracted is unspecified.
        let _ = extract_mime_parts(b"this is not a valid email");
    }

    #[test]
    fn test_extract_mime_parts_multiple_attachments() {
        let message = b"From: sender@example.com\r\n\
            To: rcpt@example.com\r\n\
            Subject: Attachments\r\n\
            Content-Type: multipart/mixed; boundary=\"bound\"\r\n\
            \r\n\
            --bound\r\n\
            Content-Type: text/plain\r\n\
            \r\n\
            See attached\r\n\
            --bound\r\n\
            Content-Type: application/pdf\r\n\
            Content-Disposition: attachment; filename=\"doc1.pdf\"\r\n\
            \r\n\
            pdf data\r\n\
            --bound\r\n\
            Content-Type: application/vnd.ms-excel\r\n\
            Content-Disposition: attachment; filename=\"data.xlsx\"\r\n\
            \r\n\
            excel data\r\n\
            --bound--\r\n";
        let (subject, text, _, attachments) = extract_mime_parts(message);
        assert_eq!(subject.as_deref(), Some("Attachments"));
        assert!(text.is_some());
        assert_eq!(attachments.len(), 2);
        for expected in ["doc1.pdf", "data.xlsx"] {
            assert!(attachments.iter().any(|name| name == expected));
        }
    }
}