Files
smartmta/rust/crates/mailer-smtp/src/client/connection.rs

207 lines
7.1 KiB
Rust

//! TCP/TLS connection management for the SMTP client.
use super::error::SmtpClientError;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio::time::{timeout, Duration};
use tokio_rustls::client::TlsStream;
use tracing::debug;
/// Client end of an SMTP connection: either cleartext TCP or TLS over TCP.
///
/// Both variants wrap their transport in a `BufReader` so that server replies
/// can be consumed line by line.
pub enum ClientSmtpStream {
    /// Unencrypted TCP (as produced by `connect_plain`, and the state a
    /// connection is in before a STARTTLS upgrade).
    Plain(BufReader<TcpStream>),
    /// TLS session on top of TCP (implicit TLS or after a STARTTLS upgrade).
    Tls(BufReader<TlsStream<TcpStream>>),
}

impl std::fmt::Debug for ClientSmtpStream {
    /// Prints only which variant this is; the wrapped stream itself is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Plain(_) => "ClientSmtpStream::Plain",
            Self::Tls(_) => "ClientSmtpStream::Tls",
        };
        f.write_str(label)
    }
}
impl ClientSmtpStream {
/// Read a line from the stream (CRLF-terminated).
pub async fn read_line(&mut self, buf: &mut String) -> Result<usize, SmtpClientError> {
match self {
ClientSmtpStream::Plain(reader) => reader.read_line(buf).await.map_err(|e| {
SmtpClientError::ConnectionError {
message: format!("Read error: {e}"),
}
}),
ClientSmtpStream::Tls(reader) => reader.read_line(buf).await.map_err(|e| {
SmtpClientError::ConnectionError {
message: format!("TLS read error: {e}"),
}
}),
}
}
/// Write bytes to the stream.
pub async fn write_all(&mut self, data: &[u8]) -> Result<(), SmtpClientError> {
match self {
ClientSmtpStream::Plain(reader) => {
reader.get_mut().write_all(data).await.map_err(|e| {
SmtpClientError::ConnectionError {
message: format!("Write error: {e}"),
}
})
}
ClientSmtpStream::Tls(reader) => {
reader.get_mut().write_all(data).await.map_err(|e| {
SmtpClientError::ConnectionError {
message: format!("TLS write error: {e}"),
}
})
}
}
}
/// Flush the stream.
pub async fn flush(&mut self) -> Result<(), SmtpClientError> {
match self {
ClientSmtpStream::Plain(reader) => {
reader.get_mut().flush().await.map_err(|e| {
SmtpClientError::ConnectionError {
message: format!("Flush error: {e}"),
}
})
}
ClientSmtpStream::Tls(reader) => {
reader.get_mut().flush().await.map_err(|e| {
SmtpClientError::ConnectionError {
message: format!("TLS flush error: {e}"),
}
})
}
}
}
/// Consume this stream and return the inner TcpStream (for STARTTLS upgrade).
/// Only works on Plain streams; returns an error on TLS streams.
pub fn into_tcp_stream(self) -> Result<TcpStream, SmtpClientError> {
match self {
ClientSmtpStream::Plain(reader) => Ok(reader.into_inner()),
ClientSmtpStream::Tls(_) => Err(SmtpClientError::TlsError {
message: "Cannot extract TcpStream from an already-TLS stream".into(),
}),
}
}
}
/// Connect to an SMTP server via plain TCP.
pub async fn connect_plain(
host: &str,
port: u16,
timeout_secs: u64,
) -> Result<ClientSmtpStream, SmtpClientError> {
debug!("Connecting to {}:{} (plain)", host, port);
let addr = format!("{host}:{port}");
let stream = timeout(Duration::from_secs(timeout_secs), TcpStream::connect(&addr))
.await
.map_err(|_| SmtpClientError::TimeoutError {
message: format!("Connection to {addr} timed out after {timeout_secs}s"),
})?
.map_err(|e| SmtpClientError::ConnectionError {
message: format!("Failed to connect to {addr}: {e}"),
})?;
Ok(ClientSmtpStream::Plain(BufReader::new(stream)))
}
/// Connect to an SMTP server via implicit TLS (port 465).
pub async fn connect_tls(
host: &str,
port: u16,
timeout_secs: u64,
) -> Result<ClientSmtpStream, SmtpClientError> {
debug!("Connecting to {}:{} (implicit TLS)", host, port);
let addr = format!("{host}:{port}");
let tcp_stream = timeout(Duration::from_secs(timeout_secs), TcpStream::connect(&addr))
.await
.map_err(|_| SmtpClientError::TimeoutError {
message: format!("Connection to {addr} timed out after {timeout_secs}s"),
})?
.map_err(|e| SmtpClientError::ConnectionError {
message: format!("Failed to connect to {addr}: {e}"),
})?;
let tls_stream = perform_tls_handshake(tcp_stream, host).await?;
Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream)))
}
/// Upgrade a plain TCP connection to TLS (STARTTLS).
pub async fn upgrade_to_tls(
stream: ClientSmtpStream,
hostname: &str,
) -> Result<ClientSmtpStream, SmtpClientError> {
debug!("Upgrading connection to TLS (STARTTLS) for {}", hostname);
let tcp_stream = stream.into_tcp_stream()?;
let tls_stream = perform_tls_handshake(tcp_stream, hostname).await?;
Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream)))
}
/// Perform the TLS handshake on a TCP stream using webpki-roots.
async fn perform_tls_handshake(
tcp_stream: TcpStream,
hostname: &str,
) -> Result<TlsStream<TcpStream>, SmtpClientError> {
let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config));
let server_name = rustls_pki_types::ServerName::try_from(hostname.to_string()).map_err(|e| {
SmtpClientError::TlsError {
message: format!("Invalid server name '{hostname}': {e}"),
}
})?;
let tls_stream = connector
.connect(server_name, tcp_stream)
.await
.map_err(|e| SmtpClientError::TlsError {
message: format!("TLS handshake with {hostname} failed: {e}"),
})?;
Ok(tls_stream)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Return a loopback port that had no listener a moment ago.
    ///
    /// Binding port 0 lets the OS choose a free port. The listener is dropped
    /// immediately, so a later connect to that port should be refused. This avoids
    /// hard-coding a port that an unrelated local service might be listening on.
    fn unused_local_port() -> u16 {
        std::net::TcpListener::bind("127.0.0.1:0")
            .and_then(|listener| listener.local_addr())
            .map(|addr| addr.port())
            .expect("binding an ephemeral loopback port should succeed")
    }

    #[tokio::test]
    async fn test_connect_plain_refused() {
        // Nothing listens on the port, so the TCP connect must fail.
        let port = unused_local_port();
        let err = connect_plain("127.0.0.1", port, 2).await.unwrap_err();
        assert!(matches!(err, SmtpClientError::ConnectionError { .. }));
        assert!(err.is_retryable());
    }

    #[tokio::test]
    async fn test_connect_tls_refused() {
        // The failure happens at the TCP stage, before any TLS handshake.
        let port = unused_local_port();
        let err = connect_tls("127.0.0.1", port, 2).await.unwrap_err();
        assert!(matches!(err, SmtpClientError::ConnectionError { .. }));
    }

    #[tokio::test]
    async fn test_connect_timeout() {
        // 192.0.2.1 is TEST-NET-1 (documentation range) and should not answer.
        let result = connect_plain("192.0.2.1", 25, 1).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        // May be a timeout or a connection error depending on the local network.
        assert!(err.is_retryable());
    }
}