feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates
This commit is contained in:
325
rust/crates/rustproxy-passthrough/src/forwarder.rs
Normal file
325
rust/crates/rustproxy-passthrough/src/forwarder.rs
Normal file
@@ -0,0 +1,325 @@
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::debug;
|
||||
|
||||
use super::connection_record::ConnectionRecord;
|
||||
|
||||
/// Statistics for a forwarded connection.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ForwardStats {
|
||||
pub bytes_in: AtomicU64,
|
||||
pub bytes_out: AtomicU64,
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding between client and backend.
|
||||
///
|
||||
/// This is the core data path for passthrough connections.
|
||||
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes.
|
||||
pub async fn forward_bidirectional(
|
||||
mut client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.split();
|
||||
let (mut backend_read, mut backend_write) = backend.split();
|
||||
|
||||
let client_to_backend = async {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
loop {
|
||||
let n = client_read.read(&mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
backend_write.write_all(&buf[..n]).await?;
|
||||
total += n as u64;
|
||||
}
|
||||
backend_write.shutdown().await?;
|
||||
Ok::<u64, std::io::Error>(total)
|
||||
};
|
||||
|
||||
let backend_to_client = async {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = backend_read.read(&mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
client_write.write_all(&buf[..n]).await?;
|
||||
total += n as u64;
|
||||
}
|
||||
client_write.shutdown().await?;
|
||||
Ok::<u64, std::io::Error>(total)
|
||||
};
|
||||
|
||||
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
|
||||
|
||||
Ok((c2b.unwrap_or(0), b2c.unwrap_or(0)))
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
|
||||
///
|
||||
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
|
||||
pub async fn forward_bidirectional_with_timeouts(
|
||||
client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
cancel: CancellationToken,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut backend_read, mut backend_write) = backend.into_split();
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_len;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog: inactivity, max lifetime, and cancellation
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Connection cancelled by shutdown");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(check_interval) => {
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
|
||||
/// Forward bidirectional with a callback for byte counting.
|
||||
pub async fn forward_bidirectional_with_stats(
|
||||
client: TcpStream,
|
||||
backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
stats: Arc<ForwardStats>,
|
||||
) -> std::io::Result<()> {
|
||||
let (bytes_in, bytes_out) = forward_bidirectional(client, backend, initial_data).await?;
|
||||
stats.bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
stats.bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts,
|
||||
/// updating a `ConnectionRecord` with byte counts and activity timestamps
|
||||
/// in real time for zombie detection.
|
||||
///
|
||||
/// When `record` is `None`, this behaves identically to
|
||||
/// `forward_bidirectional_with_timeouts`.
|
||||
///
|
||||
/// The record's `client_closed` / `backend_closed` flags are set when the
|
||||
/// respective copy loop terminates, giving the zombie scanner visibility
|
||||
/// into half-open connections.
|
||||
pub async fn forward_bidirectional_with_record(
|
||||
client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
cancel: CancellationToken,
|
||||
record: Option<Arc<ConnectionRecord>>,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
if let Some(ref r) = record {
|
||||
r.record_bytes_in(data.len() as u64);
|
||||
}
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut backend_read, mut backend_write) = backend.into_split();
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
let rec1 = record.clone();
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_len;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
let now_ms = start.elapsed().as_millis() as u64;
|
||||
la1.store(now_ms, Ordering::Relaxed);
|
||||
if let Some(ref r) = rec1 {
|
||||
r.record_bytes_in(n as u64);
|
||||
}
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
// Mark client side as closed
|
||||
if let Some(ref r) = rec1 {
|
||||
r.client_closed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let rec2 = record.clone();
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
let now_ms = start.elapsed().as_millis() as u64;
|
||||
la2.store(now_ms, Ordering::Relaxed);
|
||||
if let Some(ref r) = rec2 {
|
||||
r.record_bytes_out(n as u64);
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
// Mark backend side as closed
|
||||
if let Some(ref r) = rec2 {
|
||||
r.backend_closed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog: inactivity, max lifetime, and cancellation
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Connection cancelled by shutdown");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(check_interval) => {
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
Reference in New Issue
Block a user