feat(metrics): add real-time throughput sampling and byte-counting metrics
This commit is contained in:
@@ -5,14 +5,7 @@ use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::debug;
|
||||
|
||||
use super::connection_record::ConnectionRecord;
|
||||
|
||||
/// Statistics for a forwarded connection.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ForwardStats {
|
||||
pub bytes_in: AtomicU64,
|
||||
pub bytes_out: AtomicU64,
|
||||
}
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
|
||||
/// Perform bidirectional TCP forwarding between client and backend.
|
||||
///
|
||||
@@ -68,6 +61,10 @@ pub async fn forward_bidirectional(
|
||||
|
||||
/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
|
||||
///
|
||||
/// When `metrics` is provided, bytes are reported to the MetricsCollector
|
||||
/// per-chunk (lock-free) as they flow through the copy loops, enabling
|
||||
/// real-time throughput sampling for long-lived connections.
|
||||
///
|
||||
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
|
||||
pub async fn forward_bidirectional_with_timeouts(
|
||||
client: TcpStream,
|
||||
@@ -76,10 +73,14 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
cancel: CancellationToken,
|
||||
metrics: Option<(Arc<MetricsCollector>, Option<String>)>,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
if let Some((ref m, ref rid)) = metrics {
|
||||
m.record_bytes(data.len() as u64, 0, rid.as_deref());
|
||||
}
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
@@ -90,6 +91,7 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
let metrics_c2b = metrics.clone();
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_len;
|
||||
@@ -103,12 +105,16 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
}
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
if let Some((ref m, ref rid)) = metrics_c2b {
|
||||
m.record_bytes(n as u64, 0, rid.as_deref());
|
||||
}
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let metrics_b2c = metrics;
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
@@ -122,6 +128,9 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
}
|
||||
total += n as u64;
|
||||
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
if let Some((ref m, ref rid)) = metrics_b2c {
|
||||
m.record_bytes(0, n as u64, rid.as_deref());
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
total
|
||||
@@ -174,152 +183,3 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
|
||||
/// Forward bidirectional with a callback for byte counting.
|
||||
pub async fn forward_bidirectional_with_stats(
|
||||
client: TcpStream,
|
||||
backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
stats: Arc<ForwardStats>,
|
||||
) -> std::io::Result<()> {
|
||||
let (bytes_in, bytes_out) = forward_bidirectional(client, backend, initial_data).await?;
|
||||
stats.bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
stats.bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts,
|
||||
/// updating a `ConnectionRecord` with byte counts and activity timestamps
|
||||
/// in real time for zombie detection.
|
||||
///
|
||||
/// When `record` is `None`, this behaves identically to
|
||||
/// `forward_bidirectional_with_timeouts`.
|
||||
///
|
||||
/// The record's `client_closed` / `backend_closed` flags are set when the
|
||||
/// respective copy loop terminates, giving the zombie scanner visibility
|
||||
/// into half-open connections.
|
||||
pub async fn forward_bidirectional_with_record(
|
||||
client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
cancel: CancellationToken,
|
||||
record: Option<Arc<ConnectionRecord>>,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
if let Some(ref r) = record {
|
||||
r.record_bytes_in(data.len() as u64);
|
||||
}
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut backend_read, mut backend_write) = backend.into_split();
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
let rec1 = record.clone();
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_len;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
let now_ms = start.elapsed().as_millis() as u64;
|
||||
la1.store(now_ms, Ordering::Relaxed);
|
||||
if let Some(ref r) = rec1 {
|
||||
r.record_bytes_in(n as u64);
|
||||
}
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
// Mark client side as closed
|
||||
if let Some(ref r) = rec1 {
|
||||
r.client_closed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let rec2 = record.clone();
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
let now_ms = start.elapsed().as_millis() as u64;
|
||||
la2.store(now_ms, Ordering::Relaxed);
|
||||
if let Some(ref r) = rec2 {
|
||||
r.record_bytes_out(n as u64);
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
// Mark backend side as closed
|
||||
if let Some(ref r) = rec2 {
|
||||
r.backend_closed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog: inactivity, max lifetime, and cancellation
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Connection cancelled by shutdown");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(check_interval) => {
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
|
||||
@@ -496,17 +496,17 @@ impl TcpListenerManager {
|
||||
let mut backend_w = backend;
|
||||
backend_w.write_all(header.as_bytes()).await?;
|
||||
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
let (_bytes_in, _bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend_w, None,
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
Some((Arc::clone(&metrics), route_id.map(|s| s.to_string()))),
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
} else {
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
let (_bytes_in, _bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, None,
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
Some((Arc::clone(&metrics), route_id.map(|s| s.to_string()))),
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
@@ -661,6 +661,7 @@ impl TcpListenerManager {
|
||||
stream, n, port, peer_addr,
|
||||
&route_match, domain.as_deref(), is_tls,
|
||||
&relay_socket_path,
|
||||
&metrics, route_id,
|
||||
).await;
|
||||
} else {
|
||||
debug!("Socket-handler route matched but no relay path configured");
|
||||
@@ -751,11 +752,11 @@ impl TcpListenerManager {
|
||||
let mut actual_buf = vec![0u8; n];
|
||||
stream.read_exact(&mut actual_buf).await?;
|
||||
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
let (_bytes_in, _bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, Some(&actual_buf),
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
Some((Arc::clone(&metrics), route_id.map(|s| s.to_string()))),
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
Some(rustproxy_config::TlsMode::Terminate) => {
|
||||
@@ -812,12 +813,11 @@ impl TcpListenerManager {
|
||||
let (tls_read, tls_write) = tokio::io::split(buf_stream);
|
||||
let (backend_read, backend_write) = tokio::io::split(backend);
|
||||
|
||||
let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
|
||||
let (_bytes_in, _bytes_out) = Self::forward_bidirectional_split_with_timeouts(
|
||||
tls_read, tls_write, backend_read, backend_write,
|
||||
inactivity_timeout, max_lifetime,
|
||||
Some((Arc::clone(&metrics), route_id.map(|s| s.to_string()))),
|
||||
).await;
|
||||
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -825,7 +825,7 @@ impl TcpListenerManager {
|
||||
let route_tls = route_match.route.action.tls.as_ref();
|
||||
Self::handle_tls_terminate_reencrypt(
|
||||
stream, n, &domain, &target_host, target_port,
|
||||
peer_addr, &tls_configs, &metrics, route_id, &conn_config, route_tls,
|
||||
peer_addr, &tls_configs, Arc::clone(&metrics), route_id, &conn_config, route_tls,
|
||||
).await
|
||||
}
|
||||
None => {
|
||||
@@ -862,11 +862,11 @@ impl TcpListenerManager {
|
||||
let mut actual_buf = vec![0u8; n];
|
||||
stream.read_exact(&mut actual_buf).await?;
|
||||
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
let (_bytes_in, _bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, Some(&actual_buf),
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
Some((Arc::clone(&metrics), route_id.map(|s| s.to_string()))),
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -892,6 +892,8 @@ impl TcpListenerManager {
|
||||
domain: Option<&str>,
|
||||
is_tls: bool,
|
||||
relay_path: &str,
|
||||
metrics: &MetricsCollector,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::UnixStream;
|
||||
@@ -932,12 +934,20 @@ impl TcpListenerManager {
|
||||
unix_stream.write_all(&initial_buf).await?;
|
||||
|
||||
// Bidirectional relay between TCP client and Unix socket handler
|
||||
let initial_len = initial_buf.len() as u64;
|
||||
match tokio::io::copy_bidirectional(&mut stream, &mut unix_stream).await {
|
||||
Ok((c2s, s2c)) => {
|
||||
// Include initial data bytes that were forwarded before copy_bidirectional
|
||||
let total_in = c2s + initial_len;
|
||||
debug!("Socket handler relay complete for {}: {} bytes in, {} bytes out",
|
||||
route_key, c2s, s2c);
|
||||
route_key, total_in, s2c);
|
||||
metrics.record_bytes(total_in, s2c, route_id);
|
||||
}
|
||||
Err(e) => {
|
||||
// Still record the initial data even on error
|
||||
if initial_len > 0 {
|
||||
metrics.record_bytes(initial_len, 0, route_id);
|
||||
}
|
||||
debug!("Socket handler relay ended for {}: {}", route_key, e);
|
||||
}
|
||||
}
|
||||
@@ -954,7 +964,7 @@ impl TcpListenerManager {
|
||||
target_port: u16,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
tls_configs: &HashMap<String, TlsCertConfig>,
|
||||
metrics: &MetricsCollector,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
route_id: Option<&str>,
|
||||
conn_config: &ConnectionConfig,
|
||||
route_tls: Option<&rustproxy_config::RouteTls>,
|
||||
@@ -1019,12 +1029,12 @@ impl TcpListenerManager {
|
||||
}
|
||||
};
|
||||
|
||||
let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
|
||||
let (_bytes_in, _bytes_out) = Self::forward_bidirectional_split_with_timeouts(
|
||||
client_read, client_write, backend_read, backend_write,
|
||||
inactivity_timeout, max_lifetime,
|
||||
Some((metrics, route_id.map(|s| s.to_string()))),
|
||||
).await;
|
||||
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1058,6 +1068,9 @@ impl TcpListenerManager {
|
||||
}
|
||||
|
||||
/// Forward bidirectional between two split streams with inactivity and lifetime timeouts.
|
||||
///
|
||||
/// When `metrics` is provided, bytes are reported per-chunk (lock-free) for
|
||||
/// real-time throughput measurement.
|
||||
async fn forward_bidirectional_split_with_timeouts<R1, W1, R2, W2>(
|
||||
mut client_read: R1,
|
||||
mut client_write: W1,
|
||||
@@ -1065,6 +1078,7 @@ impl TcpListenerManager {
|
||||
mut backend_write: W2,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
metrics: Option<(Arc<MetricsCollector>, Option<String>)>,
|
||||
) -> (u64, u64)
|
||||
where
|
||||
R1: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
@@ -1080,6 +1094,7 @@ impl TcpListenerManager {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let metrics_c2b = metrics.clone();
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
@@ -1096,12 +1111,16 @@ impl TcpListenerManager {
|
||||
start.elapsed().as_millis() as u64,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
if let Some((ref m, ref rid)) = metrics_c2b {
|
||||
m.record_bytes(n as u64, 0, rid.as_deref());
|
||||
}
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let metrics_b2c = metrics;
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
@@ -1118,6 +1137,9 @@ impl TcpListenerManager {
|
||||
start.elapsed().as_millis() as u64,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
if let Some((ref m, ref rid)) = metrics_b2c {
|
||||
m.record_bytes(0, n as u64, rid.as_deref());
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
total
|
||||
|
||||
Reference in New Issue
Block a user