use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;

use super::connection_record::ConnectionRecord;

/// Statistics for a forwarded connection.
///
/// The counters are atomics so they can be shared (e.g. behind an `Arc`) and
/// accumulated with `fetch_add` without a lock.
#[derive(Debug, Default)]
pub struct ForwardStats {
    /// Bytes relayed from the client to the backend.
    pub bytes_in: AtomicU64,
    /// Bytes relayed from the backend to the client.
    pub bytes_out: AtomicU64,
}

/// Perform bidirectional TCP forwarding between client and backend.
///
/// This is the core data path for passthrough connections. `initial_data`
/// (peeked bytes already consumed from the client) is written to the backend
/// first and counted towards the client->backend total.
///
/// Each direction copies until EOF or an I/O error. It then shuts down the
/// write side of the opposite socket, so the peer sees the close and the
/// other direction can finish as well.
///
/// Returns `(bytes_from_client, bytes_from_backend)` once both directions have
/// finished. The counts include everything relayed before any error.
///
/// # Errors
///
/// Returns an error only if writing `initial_data` to the backend fails.
/// I/O errors while relaying end that direction early but are not reported.
pub async fn forward_bidirectional(
    mut client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
    }

    let (mut client_read, mut client_write) = client.split();
    let (mut backend_read, mut backend_write) = backend.split();

    let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);

    let client_to_backend = async {
        let mut buf = vec![0u8; 65536];
        let mut total = initial_len;
        loop {
            let n = match client_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if backend_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
        }
        // Always propagate the close (even after an error). Without the FIN the
        // backend may keep the connection open and stall the other direction.
        let _ = backend_write.shutdown().await;
        total
    };

    let backend_to_client = async {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = match backend_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
        }
        let _ = client_write.shutdown().await;
        total
    };

    let (bytes_in, bytes_out) = tokio::join!(client_to_backend, backend_to_client);
    Ok((bytes_in, bytes_out))
}

/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
///
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
pub async fn forward_bidirectional_with_timeouts( client: TcpStream, mut backend: TcpStream, initial_data: Option<&[u8]>, inactivity_timeout: std::time::Duration, max_lifetime: std::time::Duration, cancel: CancellationToken, ) -> std::io::Result<(u64, u64)> { // Send initial data (peeked bytes) to backend if let Some(data) = initial_data { backend.write_all(data).await?; } let (mut client_read, mut client_write) = client.into_split(); let (mut backend_read, mut backend_write) = backend.into_split(); let last_activity = Arc::new(AtomicU64::new(0)); let start = std::time::Instant::now(); let la1 = Arc::clone(&last_activity); let initial_len = initial_data.map_or(0u64, |d| d.len() as u64); let c2b = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = initial_len; loop { let n = match client_read.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if backend_write.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); } let _ = backend_write.shutdown().await; total }); let la2 = Arc::clone(&last_activity); let b2c = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = 0u64; loop { let n = match backend_read.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if client_write.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); } let _ = client_write.shutdown().await; total }); // Watchdog: inactivity, max lifetime, and cancellation let la_watch = Arc::clone(&last_activity); let c2b_handle = c2b.abort_handle(); let b2c_handle = b2c.abort_handle(); let watchdog = tokio::spawn(async move { let check_interval = std::time::Duration::from_secs(5); let mut last_seen = 0u64; loop { tokio::select! 
{ _ = cancel.cancelled() => { debug!("Connection cancelled by shutdown"); c2b_handle.abort(); b2c_handle.abort(); break; } _ = tokio::time::sleep(check_interval) => { // Check max lifetime if start.elapsed() >= max_lifetime { debug!("Connection exceeded max lifetime, closing"); c2b_handle.abort(); b2c_handle.abort(); break; } // Check inactivity let current = la_watch.load(Ordering::Relaxed); if current == last_seen { let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { debug!("Connection inactive for {}ms, closing", elapsed_since_activity); c2b_handle.abort(); b2c_handle.abort(); break; } } last_seen = current; } } } }); let bytes_in = c2b.await.unwrap_or(0); let bytes_out = b2c.await.unwrap_or(0); watchdog.abort(); Ok((bytes_in, bytes_out)) } /// Forward bidirectional with a callback for byte counting. pub async fn forward_bidirectional_with_stats( client: TcpStream, backend: TcpStream, initial_data: Option<&[u8]>, stats: Arc, ) -> std::io::Result<()> { let (bytes_in, bytes_out) = forward_bidirectional(client, backend, initial_data).await?; stats.bytes_in.fetch_add(bytes_in, Ordering::Relaxed); stats.bytes_out.fetch_add(bytes_out, Ordering::Relaxed); Ok(()) } /// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts, /// updating a `ConnectionRecord` with byte counts and activity timestamps /// in real time for zombie detection. /// /// When `record` is `None`, this behaves identically to /// `forward_bidirectional_with_timeouts`. /// /// The record's `client_closed` / `backend_closed` flags are set when the /// respective copy loop terminates, giving the zombie scanner visibility /// into half-open connections. 
pub async fn forward_bidirectional_with_record( client: TcpStream, mut backend: TcpStream, initial_data: Option<&[u8]>, inactivity_timeout: std::time::Duration, max_lifetime: std::time::Duration, cancel: CancellationToken, record: Option>, ) -> std::io::Result<(u64, u64)> { // Send initial data (peeked bytes) to backend if let Some(data) = initial_data { backend.write_all(data).await?; if let Some(ref r) = record { r.record_bytes_in(data.len() as u64); } } let (mut client_read, mut client_write) = client.into_split(); let (mut backend_read, mut backend_write) = backend.into_split(); let last_activity = Arc::new(AtomicU64::new(0)); let start = std::time::Instant::now(); let la1 = Arc::clone(&last_activity); let initial_len = initial_data.map_or(0u64, |d| d.len() as u64); let rec1 = record.clone(); let c2b = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = initial_len; loop { let n = match client_read.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if backend_write.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; let now_ms = start.elapsed().as_millis() as u64; la1.store(now_ms, Ordering::Relaxed); if let Some(ref r) = rec1 { r.record_bytes_in(n as u64); } } let _ = backend_write.shutdown().await; // Mark client side as closed if let Some(ref r) = rec1 { r.client_closed.store(true, Ordering::Relaxed); } total }); let la2 = Arc::clone(&last_activity); let rec2 = record.clone(); let b2c = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = 0u64; loop { let n = match backend_read.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if client_write.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; let now_ms = start.elapsed().as_millis() as u64; la2.store(now_ms, Ordering::Relaxed); if let Some(ref r) = rec2 { r.record_bytes_out(n as u64); } } let _ = client_write.shutdown().await; // Mark backend side as closed if let Some(ref r) = rec2 { r.backend_closed.store(true, 
Ordering::Relaxed); } total }); // Watchdog: inactivity, max lifetime, and cancellation let la_watch = Arc::clone(&last_activity); let c2b_handle = c2b.abort_handle(); let b2c_handle = b2c.abort_handle(); let watchdog = tokio::spawn(async move { let check_interval = std::time::Duration::from_secs(5); let mut last_seen = 0u64; loop { tokio::select! { _ = cancel.cancelled() => { debug!("Connection cancelled by shutdown"); c2b_handle.abort(); b2c_handle.abort(); break; } _ = tokio::time::sleep(check_interval) => { // Check max lifetime if start.elapsed() >= max_lifetime { debug!("Connection exceeded max lifetime, closing"); c2b_handle.abort(); b2c_handle.abort(); break; } // Check inactivity let current = la_watch.load(Ordering::Relaxed); if current == last_seen { let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { debug!("Connection inactive for {}ms, closing", elapsed_since_activity); c2b_handle.abort(); b2c_handle.abort(); break; } } last_seen = current; } } } }); let bytes_in = c2b.await.unwrap_or(0); let bytes_out = b2c.await.unwrap_or(0); watchdog.abort(); Ok((bytes_in, bytes_out)) }