use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use tracing::debug;

use super::connection_record::ConnectionRecord;

/// Statistics for a forwarded connection.
#[derive(Debug, Default)]
pub struct ForwardStats {
    pub bytes_in: AtomicU64,
    pub bytes_out: AtomicU64,
}
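
// Hedged usage sketch (not part of the module's API surface): a monitoring
// task might snapshot the counters like this. Relaxed loads match the
// `fetch_add` updates in `forward_bidirectional_with_stats`; the two reads
// are independent, which is fine for reporting.
//
//     let stats = Arc::new(ForwardStats::default());
//     let bytes_in = stats.bytes_in.load(Ordering::Relaxed);
//     let bytes_out = stats.bytes_out.load(Ordering::Relaxed);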

/// Perform bidirectional TCP forwarding between client and backend.
///
/// This is the core data path for passthrough connections.
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes.
pub async fn forward_bidirectional(
    mut client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
    }

    let (mut client_read, mut client_write) = client.split();
    let (mut backend_read, mut backend_write) = backend.split();

    let client_to_backend = async {
        let mut buf = vec![0u8; 65536];
        let mut total = initial_data.map_or(0u64, |d| d.len() as u64);
        loop {
            let n = client_read.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            backend_write.write_all(&buf[..n]).await?;
            total += n as u64;
        }
        backend_write.shutdown().await?;
        Ok::<u64, std::io::Error>(total)
    };

    let backend_to_client = async {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = backend_read.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            client_write.write_all(&buf[..n]).await?;
            total += n as u64;
        }
        client_write.shutdown().await?;
        Ok::<u64, std::io::Error>(total)
    };

    let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);

    Ok((c2b.unwrap_or(0), b2c.unwrap_or(0)))
}
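
// Hedged call-site sketch: `backend_addr` and `peeked` are illustrative names,
// not defined in this module. A router that sniffed the client's first bytes
// replays them via `initial_data` so the backend sees an unbroken stream:
//
//     let backend = TcpStream::connect(backend_addr).await?;
//     let (from_client, from_backend) =
//         forward_bidirectional(client, backend, Some(&peeked)).await?;
//     debug!(from_client, from_backend, "passthrough connection closed");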

/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
///
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
pub async fn forward_bidirectional_with_timeouts(
    client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
    inactivity_timeout: std::time::Duration,
    max_lifetime: std::time::Duration,
    cancel: CancellationToken,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
    }

    // `into_split` yields owned halves so each copy direction can move into
    // its own spawned task.
    let (mut client_read, mut client_write) = client.into_split();
    let (mut backend_read, mut backend_write) = backend.into_split();

    // Milliseconds since `start` at the moment of the last successful
    // transfer; updated by both directions, polled by the watchdog.
    let last_activity = Arc::new(AtomicU64::new(0));
    let start = std::time::Instant::now();

    let la1 = Arc::clone(&last_activity);
    let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
    let c2b = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = initial_len;
        loop {
            let n = match client_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if backend_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
        }
        let _ = backend_write.shutdown().await;
        total
    });

    let la2 = Arc::clone(&last_activity);
    let b2c = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = match backend_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
        }
        let _ = client_write.shutdown().await;
        total
    });

    // Watchdog: inactivity, max lifetime, and cancellation
    let la_watch = Arc::clone(&last_activity);
    let c2b_handle = c2b.abort_handle();
    let b2c_handle = b2c.abort_handle();
    let watchdog = tokio::spawn(async move {
        let check_interval = std::time::Duration::from_secs(5);
        let mut last_seen = 0u64;
        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    debug!("Connection cancelled by shutdown");
                    c2b_handle.abort();
                    b2c_handle.abort();
                    break;
                }
                _ = tokio::time::sleep(check_interval) => {
                    // Check max lifetime
                    if start.elapsed() >= max_lifetime {
                        debug!("Connection exceeded max lifetime, closing");
                        c2b_handle.abort();
                        b2c_handle.abort();
                        break;
                    }

                    // Check inactivity. `current` was stored from an earlier
                    // `start.elapsed()` reading, so the subtraction below
                    // cannot underflow.
                    let current = la_watch.load(Ordering::Relaxed);
                    if current == last_seen {
                        let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
                        if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
                            debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
                            c2b_handle.abort();
                            b2c_handle.abort();
                            break;
                        }
                    }
                    last_seen = current;
                }
            }
        }
    });

    // An aborted task surfaces as `JoinError`, so a direction cut off by the
    // watchdog reports 0 bytes here.
    let bytes_in = c2b.await.unwrap_or(0);
    let bytes_out = b2c.await.unwrap_or(0);
    watchdog.abort();
    Ok((bytes_in, bytes_out))
}
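
// Hedged call-site sketch: `shutdown_token` and the durations are assumptions.
// Giving each connection a child token lets a process-wide shutdown cancel
// every in-flight transfer through the watchdog:
//
//     let (bytes_in, bytes_out) = forward_bidirectional_with_timeouts(
//         client,
//         backend,
//         None,
//         std::time::Duration::from_secs(300),  // inactivity
//         std::time::Duration::from_secs(3600), // max lifetime
//         shutdown_token.child_token(),
//     )
//     .await?;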

/// Forward bidirectionally, accumulating byte counts into a shared `ForwardStats`.
///
/// The counters are updated once, when the transfer completes.
pub async fn forward_bidirectional_with_stats(
    client: TcpStream,
    backend: TcpStream,
    initial_data: Option<&[u8]>,
    stats: Arc<ForwardStats>,
) -> std::io::Result<()> {
    let (bytes_in, bytes_out) = forward_bidirectional(client, backend, initial_data).await?;
    stats.bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
    stats.bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
    Ok(())
}
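
// Hedged sketch: one `ForwardStats` shared across connections yields aggregate
// listener-level totals (the surrounding accept loop is assumed, not shown):
//
//     let per_conn = Arc::clone(&listener_stats);
//     tokio::spawn(async move {
//         let _ = forward_bidirectional_with_stats(client, backend, None, per_conn).await;
//     });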

/// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts,
/// updating a `ConnectionRecord` with byte counts and activity timestamps
/// in real time for zombie detection.
///
/// When `record` is `None`, this behaves identically to
/// `forward_bidirectional_with_timeouts`.
///
/// The record's `client_closed` / `backend_closed` flags are set when the
/// respective copy loop terminates, giving the zombie scanner visibility
/// into half-open connections.
pub async fn forward_bidirectional_with_record(
    client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
    inactivity_timeout: std::time::Duration,
    max_lifetime: std::time::Duration,
    cancel: CancellationToken,
    record: Option<Arc<ConnectionRecord>>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
        if let Some(ref r) = record {
            r.record_bytes_in(data.len() as u64);
        }
    }

    let (mut client_read, mut client_write) = client.into_split();
    let (mut backend_read, mut backend_write) = backend.into_split();

    let last_activity = Arc::new(AtomicU64::new(0));
    let start = std::time::Instant::now();

    let la1 = Arc::clone(&last_activity);
    let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
    let rec1 = record.clone();
    let c2b = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = initial_len;
        loop {
            let n = match client_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if backend_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            let now_ms = start.elapsed().as_millis() as u64;
            la1.store(now_ms, Ordering::Relaxed);
            if let Some(ref r) = rec1 {
                r.record_bytes_in(n as u64);
            }
        }
        let _ = backend_write.shutdown().await;
        // Mark client side as closed
        if let Some(ref r) = rec1 {
            r.client_closed.store(true, Ordering::Relaxed);
        }
        total
    });

    let la2 = Arc::clone(&last_activity);
    let rec2 = record.clone();
    let b2c = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = match backend_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            let now_ms = start.elapsed().as_millis() as u64;
            la2.store(now_ms, Ordering::Relaxed);
            if let Some(ref r) = rec2 {
                r.record_bytes_out(n as u64);
            }
        }
        let _ = client_write.shutdown().await;
        // Mark backend side as closed
        if let Some(ref r) = rec2 {
            r.backend_closed.store(true, Ordering::Relaxed);
        }
        total
    });

    // Watchdog: inactivity, max lifetime, and cancellation
    let la_watch = Arc::clone(&last_activity);
    let c2b_handle = c2b.abort_handle();
    let b2c_handle = b2c.abort_handle();
    let watchdog = tokio::spawn(async move {
        let check_interval = std::time::Duration::from_secs(5);
        let mut last_seen = 0u64;
        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    debug!("Connection cancelled by shutdown");
                    c2b_handle.abort();
                    b2c_handle.abort();
                    break;
                }
                _ = tokio::time::sleep(check_interval) => {
                    // Check max lifetime
                    if start.elapsed() >= max_lifetime {
                        debug!("Connection exceeded max lifetime, closing");
                        c2b_handle.abort();
                        b2c_handle.abort();
                        break;
                    }

                    // Check inactivity
                    let current = la_watch.load(Ordering::Relaxed);
                    if current == last_seen {
                        let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
                        if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
                            debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
                            c2b_handle.abort();
                            b2c_handle.abort();
                            break;
                        }
                    }
                    last_seen = current;
                }
            }
        }
    });

    // As above, a direction aborted by the watchdog reports 0 bytes.
    let bytes_in = c2b.await.unwrap_or(0);
    let bytes_out = b2c.await.unwrap_or(0);
    watchdog.abort();
    Ok((bytes_in, bytes_out))
}
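
// A minimal smoke test for the plain forwarder, offered as a hedged sketch: it
// exercises only the happy path over loopback sockets, with an echo server
// standing in for the backend. The payloads and names are illustrative.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    #[tokio::test]
    async fn forwards_initial_data_and_echo() -> std::io::Result<()> {
        // "Backend" that echoes everything it receives back to the peer.
        let backend_listener = TcpListener::bind("127.0.0.1:0").await?;
        let backend_addr = backend_listener.local_addr()?;
        tokio::spawn(async move {
            let (mut sock, _) = backend_listener.accept().await.unwrap();
            let (mut r, mut w) = sock.split();
            let _ = tokio::io::copy(&mut r, &mut w).await;
        });

        // "Proxy" listener the client connects to.
        let proxy_listener = TcpListener::bind("127.0.0.1:0").await?;
        let proxy_addr = proxy_listener.local_addr()?;

        // Client writes a payload, half-closes, and reads the full echo.
        let client_task = tokio::spawn(async move {
            let mut c = TcpStream::connect(proxy_addr).await.unwrap();
            c.write_all(b"ping").await.unwrap();
            c.shutdown().await.unwrap();
            let mut out = Vec::new();
            c.read_to_end(&mut out).await.unwrap();
            out
        });

        let (client_side, _) = proxy_listener.accept().await?;
        let backend_side = TcpStream::connect(backend_addr).await?;
        // "hi " stands in for bytes a router might have peeked off the client.
        let (from_client, from_backend) =
            forward_bidirectional(client_side, backend_side, Some(b"hi ".as_slice())).await?;

        assert_eq!(client_task.await.unwrap(), b"hi ping");
        assert_eq!(from_client, 7); // 3 peeked + 4 read from the socket
        assert_eq!(from_backend, 7); // the full echo
        Ok(())
    }
}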