207 lines
7.1 KiB
Rust
207 lines
7.1 KiB
Rust
//! TCP/TLS connection management for the SMTP client.
|
|
|
|
use super::error::SmtpClientError;
|
|
use std::sync::Arc;
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::TcpStream;
|
|
use tokio::time::{timeout, Duration};
|
|
use tokio_rustls::client::TlsStream;
|
|
use tracing::debug;
|
|
|
|
/// A client-side SMTP stream that may be plain or TLS.
|
|
pub enum ClientSmtpStream {
|
|
Plain(BufReader<TcpStream>),
|
|
Tls(BufReader<TlsStream<TcpStream>>),
|
|
}
|
|
|
|
impl std::fmt::Debug for ClientSmtpStream {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
ClientSmtpStream::Plain(_) => write!(f, "ClientSmtpStream::Plain"),
|
|
ClientSmtpStream::Tls(_) => write!(f, "ClientSmtpStream::Tls"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl ClientSmtpStream {
|
|
/// Read a line from the stream (CRLF-terminated).
|
|
pub async fn read_line(&mut self, buf: &mut String) -> Result<usize, SmtpClientError> {
|
|
match self {
|
|
ClientSmtpStream::Plain(reader) => reader.read_line(buf).await.map_err(|e| {
|
|
SmtpClientError::ConnectionError {
|
|
message: format!("Read error: {e}"),
|
|
}
|
|
}),
|
|
ClientSmtpStream::Tls(reader) => reader.read_line(buf).await.map_err(|e| {
|
|
SmtpClientError::ConnectionError {
|
|
message: format!("TLS read error: {e}"),
|
|
}
|
|
}),
|
|
}
|
|
}
|
|
|
|
/// Write bytes to the stream.
|
|
pub async fn write_all(&mut self, data: &[u8]) -> Result<(), SmtpClientError> {
|
|
match self {
|
|
ClientSmtpStream::Plain(reader) => {
|
|
reader.get_mut().write_all(data).await.map_err(|e| {
|
|
SmtpClientError::ConnectionError {
|
|
message: format!("Write error: {e}"),
|
|
}
|
|
})
|
|
}
|
|
ClientSmtpStream::Tls(reader) => {
|
|
reader.get_mut().write_all(data).await.map_err(|e| {
|
|
SmtpClientError::ConnectionError {
|
|
message: format!("TLS write error: {e}"),
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Flush the stream.
|
|
pub async fn flush(&mut self) -> Result<(), SmtpClientError> {
|
|
match self {
|
|
ClientSmtpStream::Plain(reader) => {
|
|
reader.get_mut().flush().await.map_err(|e| {
|
|
SmtpClientError::ConnectionError {
|
|
message: format!("Flush error: {e}"),
|
|
}
|
|
})
|
|
}
|
|
ClientSmtpStream::Tls(reader) => {
|
|
reader.get_mut().flush().await.map_err(|e| {
|
|
SmtpClientError::ConnectionError {
|
|
message: format!("TLS flush error: {e}"),
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Consume this stream and return the inner TcpStream (for STARTTLS upgrade).
|
|
/// Only works on Plain streams; returns an error on TLS streams.
|
|
pub fn into_tcp_stream(self) -> Result<TcpStream, SmtpClientError> {
|
|
match self {
|
|
ClientSmtpStream::Plain(reader) => Ok(reader.into_inner()),
|
|
ClientSmtpStream::Tls(_) => Err(SmtpClientError::TlsError {
|
|
message: "Cannot extract TcpStream from an already-TLS stream".into(),
|
|
}),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Connect to an SMTP server via plain TCP.
|
|
pub async fn connect_plain(
|
|
host: &str,
|
|
port: u16,
|
|
timeout_secs: u64,
|
|
) -> Result<ClientSmtpStream, SmtpClientError> {
|
|
debug!("Connecting to {}:{} (plain)", host, port);
|
|
let addr = format!("{host}:{port}");
|
|
let stream = timeout(Duration::from_secs(timeout_secs), TcpStream::connect(&addr))
|
|
.await
|
|
.map_err(|_| SmtpClientError::TimeoutError {
|
|
message: format!("Connection to {addr} timed out after {timeout_secs}s"),
|
|
})?
|
|
.map_err(|e| SmtpClientError::ConnectionError {
|
|
message: format!("Failed to connect to {addr}: {e}"),
|
|
})?;
|
|
|
|
Ok(ClientSmtpStream::Plain(BufReader::new(stream)))
|
|
}
|
|
|
|
/// Connect to an SMTP server via implicit TLS (port 465).
|
|
pub async fn connect_tls(
|
|
host: &str,
|
|
port: u16,
|
|
timeout_secs: u64,
|
|
) -> Result<ClientSmtpStream, SmtpClientError> {
|
|
debug!("Connecting to {}:{} (implicit TLS)", host, port);
|
|
let addr = format!("{host}:{port}");
|
|
|
|
let tcp_stream = timeout(Duration::from_secs(timeout_secs), TcpStream::connect(&addr))
|
|
.await
|
|
.map_err(|_| SmtpClientError::TimeoutError {
|
|
message: format!("Connection to {addr} timed out after {timeout_secs}s"),
|
|
})?
|
|
.map_err(|e| SmtpClientError::ConnectionError {
|
|
message: format!("Failed to connect to {addr}: {e}"),
|
|
})?;
|
|
|
|
let tls_stream = perform_tls_handshake(tcp_stream, host).await?;
|
|
Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream)))
|
|
}
|
|
|
|
/// Upgrade a plain TCP connection to TLS (STARTTLS).
|
|
pub async fn upgrade_to_tls(
|
|
stream: ClientSmtpStream,
|
|
hostname: &str,
|
|
) -> Result<ClientSmtpStream, SmtpClientError> {
|
|
debug!("Upgrading connection to TLS (STARTTLS) for {}", hostname);
|
|
let tcp_stream = stream.into_tcp_stream()?;
|
|
let tls_stream = perform_tls_handshake(tcp_stream, hostname).await?;
|
|
Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream)))
|
|
}
|
|
|
|
/// Perform the TLS handshake on a TCP stream using webpki-roots.
|
|
async fn perform_tls_handshake(
|
|
tcp_stream: TcpStream,
|
|
hostname: &str,
|
|
) -> Result<TlsStream<TcpStream>, SmtpClientError> {
|
|
let mut root_store = rustls::RootCertStore::empty();
|
|
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
|
|
|
let tls_config = rustls::ClientConfig::builder()
|
|
.with_root_certificates(root_store)
|
|
.with_no_client_auth();
|
|
|
|
let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config));
|
|
let server_name = rustls_pki_types::ServerName::try_from(hostname.to_string()).map_err(|e| {
|
|
SmtpClientError::TlsError {
|
|
message: format!("Invalid server name '{hostname}': {e}"),
|
|
}
|
|
})?;
|
|
|
|
let tls_stream = connector
|
|
.connect(server_name, tcp_stream)
|
|
.await
|
|
.map_err(|e| SmtpClientError::TlsError {
|
|
message: format!("TLS handshake with {hostname} failed: {e}"),
|
|
})?;
|
|
|
|
Ok(tls_stream)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Bind an ephemeral localhost port and release it at once.
    ///
    /// A connect to the returned port should then be refused. This avoids hard-coded
    /// ports that might happen to be in use on the machine running the tests.
    fn refused_local_port() -> u16 {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind an ephemeral localhost port");
        listener
            .local_addr()
            .expect("read the bound local address")
            .port()
    }

    #[tokio::test]
    async fn test_connect_plain_refused() {
        let result = connect_plain("127.0.0.1", refused_local_port(), 2).await;
        let err = result.expect_err("connecting to a closed port should fail");
        assert!(matches!(err, SmtpClientError::ConnectionError { .. }));
        assert!(err.is_retryable());
    }

    #[tokio::test]
    async fn test_connect_tls_refused() {
        // The TCP connect fails before any TLS work starts, so this is a connection error.
        let result = connect_tls("127.0.0.1", refused_local_port(), 2).await;
        let err = result.expect_err("connecting to a closed port should fail");
        assert!(matches!(err, SmtpClientError::ConnectionError { .. }));
        assert!(err.is_retryable());
    }

    #[tokio::test]
    async fn test_connect_timeout() {
        // 192.0.2.1 is in TEST-NET-1 (RFC 5737), so it should not answer.
        let result = connect_plain("192.0.2.1", 25, 1).await;
        let err = result.expect_err("TEST-NET address should not accept connections");
        // Depending on the network this surfaces as a timeout or as "unreachable".
        assert!(matches!(
            err,
            SmtpClientError::TimeoutError { .. } | SmtpClientError::ConnectionError { .. }
        ));
        assert!(err.is_retryable());
    }
}
|