//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector. use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::task::{Context, Poll}; use bytes::Bytes; use http_body::Frame; use rustproxy_metrics::MetricsCollector; /// Wraps any `http_body::Body` and counts data bytes passing through. /// /// When the body is fully consumed or dropped, accumulated byte counts /// are reported to the `MetricsCollector`. /// /// The inner body is pinned on the heap to support `!Unpin` types like `hyper::body::Incoming`. pub struct CountingBody { inner: Pin>, counted_bytes: AtomicU64, metrics: Arc, route_id: Option, source_ip: Option, /// Whether we count bytes as "in" (request body) or "out" (response body). direction: Direction, /// Whether we've already reported the bytes (to avoid double-reporting on drop). reported: bool, /// Optional connection-level activity tracker. When set, poll_frame updates this /// to keep the idle watchdog alive during active body streaming (uploads/downloads). connection_activity: Option>, /// Start instant for computing elapsed ms for connection_activity. activity_start: Option, } /// Which direction the bytes flow. #[derive(Clone, Copy)] pub enum Direction { /// Request body: bytes flowing from client → upstream (counted as bytes_in) In, /// Response body: bytes flowing from upstream → client (counted as bytes_out) Out, } impl CountingBody { /// Create a new CountingBody wrapping an inner body. pub fn new( inner: B, metrics: Arc, route_id: Option, source_ip: Option, direction: Direction, ) -> Self { Self { inner: Box::pin(inner), counted_bytes: AtomicU64::new(0), metrics, route_id, source_ip, direction, reported: false, connection_activity: None, activity_start: None, } } /// Set the connection-level activity tracker. When set, each data frame /// updates this timestamp to prevent the idle watchdog from killing the /// connection during active body streaming. pub fn with_connection_activity(mut self, activity: Arc, start: std::time::Instant) -> Self { self.connection_activity = Some(activity); self.activity_start = Some(start); self } /// Report accumulated bytes to the metrics collector. fn report(&mut self) { if self.reported { return; } self.reported = true; let bytes = self.counted_bytes.load(Ordering::Relaxed); if bytes == 0 { return; } let route_id = self.route_id.as_deref(); let source_ip = self.source_ip.as_deref(); match self.direction { Direction::In => self.metrics.record_bytes(bytes, 0, route_id, source_ip), Direction::Out => self.metrics.record_bytes(0, bytes, route_id, source_ip), } } } impl Drop for CountingBody { fn drop(&mut self) { self.report(); } } // CountingBody is Unpin because inner is Pin> (always Unpin). impl Unpin for CountingBody {} impl http_body::Body for CountingBody where B: http_body::Body, { type Data = Bytes; type Error = B::Error; fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { let this = self.get_mut(); match this.inner.as_mut().poll_frame(cx) { Poll::Ready(Some(Ok(frame))) => { if let Some(data) = frame.data_ref() { this.counted_bytes.fetch_add(data.len() as u64, Ordering::Relaxed); // Keep the connection-level idle watchdog alive during body streaming if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) { activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); } } Poll::Ready(Some(Ok(frame))) } Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), Poll::Ready(None) => { // Body is fully consumed — report now this.report(); Poll::Ready(None) } Poll::Pending => Poll::Pending, } } fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } fn size_hint(&self) -> http_body::SizeHint { self.inner.size_hint() } }