//! TCP/TLS connection management for the SMTP client. use super::error::SmtpClientError; use std::sync::Arc; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; use tokio::time::{timeout, Duration}; use tokio_rustls::client::TlsStream; use tracing::debug; /// A client-side SMTP stream that may be plain or TLS. pub enum ClientSmtpStream { Plain(BufReader), Tls(BufReader>), } impl std::fmt::Debug for ClientSmtpStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ClientSmtpStream::Plain(_) => write!(f, "ClientSmtpStream::Plain"), ClientSmtpStream::Tls(_) => write!(f, "ClientSmtpStream::Tls"), } } } impl ClientSmtpStream { /// Read a line from the stream (CRLF-terminated). pub async fn read_line(&mut self, buf: &mut String) -> Result { match self { ClientSmtpStream::Plain(reader) => reader.read_line(buf).await.map_err(|e| { SmtpClientError::ConnectionError { message: format!("Read error: {e}"), } }), ClientSmtpStream::Tls(reader) => reader.read_line(buf).await.map_err(|e| { SmtpClientError::ConnectionError { message: format!("TLS read error: {e}"), } }), } } /// Write bytes to the stream. pub async fn write_all(&mut self, data: &[u8]) -> Result<(), SmtpClientError> { match self { ClientSmtpStream::Plain(reader) => { reader.get_mut().write_all(data).await.map_err(|e| { SmtpClientError::ConnectionError { message: format!("Write error: {e}"), } }) } ClientSmtpStream::Tls(reader) => { reader.get_mut().write_all(data).await.map_err(|e| { SmtpClientError::ConnectionError { message: format!("TLS write error: {e}"), } }) } } } /// Flush the stream. pub async fn flush(&mut self) -> Result<(), SmtpClientError> { match self { ClientSmtpStream::Plain(reader) => { reader.get_mut().flush().await.map_err(|e| { SmtpClientError::ConnectionError { message: format!("Flush error: {e}"), } }) } ClientSmtpStream::Tls(reader) => { reader.get_mut().flush().await.map_err(|e| { SmtpClientError::ConnectionError { message: format!("TLS flush error: {e}"), } }) } } } /// Consume this stream and return the inner TcpStream (for STARTTLS upgrade). /// Only works on Plain streams; returns an error on TLS streams. pub fn into_tcp_stream(self) -> Result { match self { ClientSmtpStream::Plain(reader) => Ok(reader.into_inner()), ClientSmtpStream::Tls(_) => Err(SmtpClientError::TlsError { message: "Cannot extract TcpStream from an already-TLS stream".into(), }), } } } /// Connect to an SMTP server via plain TCP. pub async fn connect_plain( host: &str, port: u16, timeout_secs: u64, ) -> Result { debug!("Connecting to {}:{} (plain)", host, port); let addr = format!("{host}:{port}"); let stream = timeout(Duration::from_secs(timeout_secs), TcpStream::connect(&addr)) .await .map_err(|_| SmtpClientError::TimeoutError { message: format!("Connection to {addr} timed out after {timeout_secs}s"), })? .map_err(|e| SmtpClientError::ConnectionError { message: format!("Failed to connect to {addr}: {e}"), })?; Ok(ClientSmtpStream::Plain(BufReader::new(stream))) } /// Connect to an SMTP server via implicit TLS (port 465). pub async fn connect_tls( host: &str, port: u16, timeout_secs: u64, ) -> Result { debug!("Connecting to {}:{} (implicit TLS)", host, port); let addr = format!("{host}:{port}"); let tcp_stream = timeout(Duration::from_secs(timeout_secs), TcpStream::connect(&addr)) .await .map_err(|_| SmtpClientError::TimeoutError { message: format!("Connection to {addr} timed out after {timeout_secs}s"), })? .map_err(|e| SmtpClientError::ConnectionError { message: format!("Failed to connect to {addr}: {e}"), })?; let tls_stream = perform_tls_handshake(tcp_stream, host).await?; Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream))) } /// Upgrade a plain TCP connection to TLS (STARTTLS). pub async fn upgrade_to_tls( stream: ClientSmtpStream, hostname: &str, ) -> Result { debug!("Upgrading connection to TLS (STARTTLS) for {}", hostname); let tcp_stream = stream.into_tcp_stream()?; let tls_stream = perform_tls_handshake(tcp_stream, hostname).await?; Ok(ClientSmtpStream::Tls(BufReader::new(tls_stream))) } /// Perform the TLS handshake on a TCP stream using webpki-roots. async fn perform_tls_handshake( tcp_stream: TcpStream, hostname: &str, ) -> Result, SmtpClientError> { let mut root_store = rustls::RootCertStore::empty(); root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let tls_config = rustls::ClientConfig::builder() .with_root_certificates(root_store) .with_no_client_auth(); let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config)); let server_name = rustls_pki_types::ServerName::try_from(hostname.to_string()).map_err(|e| { SmtpClientError::TlsError { message: format!("Invalid server name '{hostname}': {e}"), } })?; let tls_stream = connector .connect(server_name, tcp_stream) .await .map_err(|e| SmtpClientError::TlsError { message: format!("TLS handshake with {hostname} failed: {e}"), })?; Ok(tls_stream) } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_connect_plain_refused() { // Connecting to a port that's not listening should fail let result = connect_plain("127.0.0.1", 19999, 2).await; assert!(result.is_err()); let err = result.unwrap_err(); assert!(matches!(err, SmtpClientError::ConnectionError { .. })); assert!(err.is_retryable()); } #[tokio::test] async fn test_connect_tls_refused() { let result = connect_tls("127.0.0.1", 19998, 2).await; assert!(result.is_err()); } #[tokio::test] async fn test_connect_timeout() { // 192.0.2.1 is TEST-NET, should time out let result = connect_plain("192.0.2.1", 25, 1).await; assert!(result.is_err()); let err = result.unwrap_err(); // May be timeout or connection error depending on network assert!(err.is_retryable()); } }