2026-02-09 10:55:46 +00:00
|
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
|
use tokio_util::sync::CancellationToken;
|
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
|
|
|
use tracing::debug;
|
|
|
|
|
|
2026-02-13 23:18:22 +00:00
|
|
|
use rustproxy_metrics::MetricsCollector;
|
2026-02-09 10:55:46 +00:00
|
|
|
|
2026-02-14 11:15:17 +00:00
|
|
|
/// Context for forwarding metrics, replacing the growing tuple pattern.
#[derive(Clone)]
pub struct ForwardMetricsCtx {
    /// Shared metrics sink; the copy loops report per-chunk byte counts to it
    /// via `record_bytes` (documented below as lock-free).
    pub collector: Arc<MetricsCollector>,
    /// Optional route identifier, passed to the collector as `Option<&str>`.
    /// Presumably labels which configured route this connection matched —
    /// TODO(review): confirm against the caller that populates it.
    pub route_id: Option<String>,
    /// Optional client source IP, passed to the collector as `Option<&str>`.
    pub source_ip: Option<String>,
}
|
|
|
|
|
|
2026-02-09 10:55:46 +00:00
|
|
|
/// Perform bidirectional TCP forwarding between client and backend.
|
|
|
|
|
///
|
|
|
|
|
/// This is the core data path for passthrough connections.
|
|
|
|
|
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes.
|
|
|
|
|
pub async fn forward_bidirectional(
|
|
|
|
|
mut client: TcpStream,
|
|
|
|
|
mut backend: TcpStream,
|
|
|
|
|
initial_data: Option<&[u8]>,
|
|
|
|
|
) -> std::io::Result<(u64, u64)> {
|
|
|
|
|
// Send initial data (peeked bytes) to backend
|
|
|
|
|
if let Some(data) = initial_data {
|
|
|
|
|
backend.write_all(data).await?;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let (mut client_read, mut client_write) = client.split();
|
|
|
|
|
let (mut backend_read, mut backend_write) = backend.split();
|
|
|
|
|
|
|
|
|
|
let client_to_backend = async {
|
|
|
|
|
let mut buf = vec![0u8; 65536];
|
|
|
|
|
let mut total = initial_data.map_or(0u64, |d| d.len() as u64);
|
|
|
|
|
loop {
|
|
|
|
|
let n = client_read.read(&mut buf).await?;
|
|
|
|
|
if n == 0 {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
backend_write.write_all(&buf[..n]).await?;
|
|
|
|
|
total += n as u64;
|
|
|
|
|
}
|
|
|
|
|
backend_write.shutdown().await?;
|
|
|
|
|
Ok::<u64, std::io::Error>(total)
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let backend_to_client = async {
|
|
|
|
|
let mut buf = vec![0u8; 65536];
|
|
|
|
|
let mut total = 0u64;
|
|
|
|
|
loop {
|
|
|
|
|
let n = backend_read.read(&mut buf).await?;
|
|
|
|
|
if n == 0 {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
client_write.write_all(&buf[..n]).await?;
|
|
|
|
|
total += n as u64;
|
|
|
|
|
}
|
|
|
|
|
client_write.shutdown().await?;
|
|
|
|
|
Ok::<u64, std::io::Error>(total)
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
|
|
|
|
|
|
|
|
|
|
Ok((c2b.unwrap_or(0), b2c.unwrap_or(0)))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
///
/// When `metrics` is provided, bytes are reported to the MetricsCollector
/// per-chunk (lock-free) as they flow through the copy loops, enabling
/// real-time throughput sampling for long-lived connections.
///
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
///
/// NOTE(review): when the watchdog aborts the copy tasks (timeout,
/// lifetime, or cancellation), `c2b.await`/`b2c.await` yield a JoinError
/// and the `unwrap_or(0)` below reports 0 bytes for that direction even
/// though data may have flowed — confirm this is acceptable for metrics.
pub async fn forward_bidirectional_with_timeouts(
    client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
    inactivity_timeout: std::time::Duration,
    max_lifetime: std::time::Duration,
    cancel: CancellationToken,
    metrics: Option<ForwardMetricsCtx>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend before starting the loops;
    // these bytes are also reported to the collector as client->backend.
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
        if let Some(ref ctx) = metrics {
            ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
        }
    }

    // Owned halves (into_split) are required because each half moves into a
    // separately spawned task below.
    let (mut client_read, mut client_write) = client.into_split();
    let (mut backend_read, mut backend_write) = backend.into_split();

    // Activity clock: milliseconds since `start` at the moment of the last
    // successful transfer in either direction. 0 = no activity yet.
    let last_activity = Arc::new(AtomicU64::new(0));
    let start = std::time::Instant::now();

    let la1 = Arc::clone(&last_activity);
    let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
    // Clone here so the original `metrics` can still move into the b2c task.
    let metrics_c2b = metrics.clone();
    // Client -> backend copy task. Errors are treated as connection close
    // (break), never propagated; the peeked bytes seed the total.
    let c2b = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = initial_len;
        loop {
            let n = match client_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if backend_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            // Stamp the activity clock after each successful chunk.
            la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
            if let Some(ref ctx) = metrics_c2b {
                ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
            }
        }
        // Best-effort half-close so the backend sees EOF.
        let _ = backend_write.shutdown().await;
        total
    });

    let la2 = Arc::clone(&last_activity);
    let metrics_b2c = metrics;
    // Backend -> client copy task, mirror of c2b (no initial-data seed).
    let b2c = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = match backend_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
            if let Some(ref ctx) = metrics_b2c {
                ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
            }
        }
        let _ = client_write.shutdown().await;
        total
    });

    // Watchdog: inactivity, max lifetime, and cancellation
    let la_watch = Arc::clone(&last_activity);
    let c2b_handle = c2b.abort_handle();
    let b2c_handle = b2c.abort_handle();
    let watchdog = tokio::spawn(async move {
        // Polling granularity: effective timeouts are rounded up to the next
        // 5-second tick.
        let check_interval = std::time::Duration::from_secs(5);
        let mut last_seen = 0u64;
        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    debug!("Connection cancelled by shutdown");
                    c2b_handle.abort();
                    b2c_handle.abort();
                    break;
                }
                _ = tokio::time::sleep(check_interval) => {
                    // Check max lifetime
                    if start.elapsed() >= max_lifetime {
                        debug!("Connection exceeded max lifetime, closing");
                        c2b_handle.abort();
                        b2c_handle.abort();
                        break;
                    }

                    // Check inactivity: only measure once the clock has been
                    // unchanged for a full tick (current == last_seen). Since
                    // `last_activity` starts at 0 and last_seen starts at 0,
                    // a connection that never transfers anything is also
                    // measured from `start`.
                    let current = la_watch.load(Ordering::Relaxed);
                    if current == last_seen {
                        // Safe subtraction: `current` was sampled from the same
                        // monotonic `start.elapsed()` at an earlier instant.
                        let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
                        if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
                            debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
                            c2b_handle.abort();
                            b2c_handle.abort();
                            break;
                        }
                    }
                    last_seen = current;
                }
            }
        }
    });

    // Join both directions; an aborted task yields JoinError -> 0 bytes.
    let bytes_in = c2b.await.unwrap_or(0);
    let bytes_out = b2c.await.unwrap_or(0);
    // Both copy tasks are done, so the watchdog has nothing left to guard.
    watchdog.abort();
    Ok((bytes_in, bytes_out))
}
|