use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;

use rustproxy_metrics::MetricsCollector;

/// Context for forwarding metrics, replacing the growing tuple pattern.
///
/// Cloned into each copy task so both directions report to the same collector.
#[derive(Clone)]
pub struct ForwardMetricsCtx {
    /// Collector that receives per-chunk byte counts.
    pub collector: Arc<MetricsCollector>,
    /// Route identifier forwarded to the collector, if known.
    pub route_id: Option<String>,
    /// Client source IP forwarded to the collector, if known.
    pub source_ip: Option<String>,
}

/// Perform bidirectional TCP forwarding between client and backend.
///
/// This is the core data path for passthrough connections. `initial_data`
/// (bytes already read from the client, e.g. while peeking) is written to the
/// backend before copying starts and is counted as client bytes.
///
/// Each direction copies until EOF or its first I/O error, then half-closes
/// the opposite side so that peer also observes EOF. This happens on the error
/// path too, so one failing direction cannot leave the other waiting forever.
///
/// Returns `(bytes_from_client, bytes_from_backend)` once both directions have
/// finished; the counts reflect bytes actually forwarded, even if a direction
/// ended with an error.
///
/// # Errors
///
/// Returns an error only if writing `initial_data` to the backend fails.
/// I/O errors during the copy phase end that direction and are not propagated.
pub async fn forward_bidirectional(
    mut client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
    }
    let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);

    let (client_read, client_write) = client.split();
    let (backend_read, backend_write) = backend.split();

    let (from_client, from_backend) = tokio::join!(
        copy_and_half_close(client_read, backend_write),
        copy_and_half_close(backend_read, client_write),
    );
    Ok((initial_len + from_client, from_backend))
}

/// Copy `reader` into `writer` until EOF or the first I/O error, then shut
/// down `writer` (best effort) and return the number of bytes written.
async fn copy_and_half_close<R, W>(mut reader: R, mut writer: W) -> u64
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    let mut buf = vec![0u8; 65536];
    let mut total = 0u64;
    loop {
        let n = match reader.read(&mut buf).await {
            Ok(0) | Err(_) => break,
            Ok(n) => n,
        };
        if writer.write_all(&buf[..n]).await.is_err() {
            break;
        }
        total += n as u64;
    }
    // Always half-close, even after an error, so the other side sees EOF.
    // Failure is ignored: the peer may already be gone.
    let _ = writer.shutdown().await;
    total
}

/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
/// /// When `metrics` is provided, bytes are reported to the MetricsCollector /// per-chunk (lock-free) as they flow through the copy loops, enabling /// real-time throughput sampling for long-lived connections. /// /// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out. pub async fn forward_bidirectional_with_timeouts( client: TcpStream, mut backend: TcpStream, initial_data: Option<&[u8]>, inactivity_timeout: std::time::Duration, max_lifetime: std::time::Duration, cancel: CancellationToken, metrics: Option, ) -> std::io::Result<(u64, u64)> { // Send initial data (peeked bytes) to backend if let Some(data) = initial_data { backend.write_all(data).await?; if let Some(ref ctx) = metrics { ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); } } let (mut client_read, mut client_write) = client.into_split(); let (mut backend_read, mut backend_write) = backend.into_split(); let last_activity = Arc::new(AtomicU64::new(0)); let start = std::time::Instant::now(); let la1 = Arc::clone(&last_activity); let initial_len = initial_data.map_or(0u64, |d| d.len() as u64); let metrics_c2b = metrics.clone(); let c2b = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = initial_len; loop { let n = match client_read.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if backend_write.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); if let Some(ref ctx) = metrics_c2b { ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); } } let _ = backend_write.shutdown().await; total }); let la2 = Arc::clone(&last_activity); let metrics_b2c = metrics; let b2c = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; let mut total = 0u64; loop { let n = match backend_read.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => n, }; if 
client_write.write_all(&buf[..n]).await.is_err() { break; } total += n as u64; la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); if let Some(ref ctx) = metrics_b2c { ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); } } let _ = client_write.shutdown().await; total }); // Watchdog: inactivity, max lifetime, and cancellation let la_watch = Arc::clone(&last_activity); let c2b_handle = c2b.abort_handle(); let b2c_handle = b2c.abort_handle(); let watchdog = tokio::spawn(async move { let check_interval = std::time::Duration::from_secs(5); let mut last_seen = 0u64; loop { tokio::select! { _ = cancel.cancelled() => { debug!("Connection cancelled by shutdown"); c2b_handle.abort(); b2c_handle.abort(); break; } _ = tokio::time::sleep(check_interval) => { // Check max lifetime if start.elapsed() >= max_lifetime { debug!("Connection exceeded max lifetime, closing"); c2b_handle.abort(); b2c_handle.abort(); break; } // Check inactivity let current = la_watch.load(Ordering::Relaxed); if current == last_seen { let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { debug!("Connection inactive for {}ms, closing", elapsed_since_activity); c2b_handle.abort(); b2c_handle.abort(); break; } } last_seen = current; } } } }); let bytes_in = c2b.await.unwrap_or(0); let bytes_out = b2c.await.unwrap_or(0); watchdog.abort(); Ok((bytes_in, bytes_out)) }