//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector. use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::task::{Context, Poll}; use bytes::Bytes; use http_body::Frame; use rustproxy_metrics::MetricsCollector; /// Wraps any `http_body::Body` and counts data bytes passing through. /// /// When the body is fully consumed or dropped, accumulated byte counts /// are reported to the `MetricsCollector`. /// /// The inner body is pinned on the heap to support `!Unpin` types like `hyper::body::Incoming`. pub struct CountingBody { inner: Pin>, counted_bytes: AtomicU64, metrics: Arc, route_id: Option, /// Whether we count bytes as "in" (request body) or "out" (response body). direction: Direction, /// Whether we've already reported the bytes (to avoid double-reporting on drop). reported: bool, } /// Which direction the bytes flow. #[derive(Clone, Copy)] pub enum Direction { /// Request body: bytes flowing from client → upstream (counted as bytes_in) In, /// Response body: bytes flowing from upstream → client (counted as bytes_out) Out, } impl CountingBody { /// Create a new CountingBody wrapping an inner body. pub fn new( inner: B, metrics: Arc, route_id: Option, direction: Direction, ) -> Self { Self { inner: Box::pin(inner), counted_bytes: AtomicU64::new(0), metrics, route_id, direction, reported: false, } } /// Report accumulated bytes to the metrics collector. fn report(&mut self) { if self.reported { return; } self.reported = true; let bytes = self.counted_bytes.load(Ordering::Relaxed); if bytes == 0 { return; } let route_id = self.route_id.as_deref(); match self.direction { Direction::In => self.metrics.record_bytes(bytes, 0, route_id), Direction::Out => self.metrics.record_bytes(0, bytes, route_id), } } } impl Drop for CountingBody { fn drop(&mut self) { self.report(); } } // CountingBody is Unpin because inner is Pin> (always Unpin). impl Unpin for CountingBody {} impl http_body::Body for CountingBody where B: http_body::Body, { type Data = Bytes; type Error = B::Error; fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { let this = self.get_mut(); match this.inner.as_mut().poll_frame(cx) { Poll::Ready(Some(Ok(frame))) => { if let Some(data) = frame.data_ref() { this.counted_bytes.fetch_add(data.len() as u64, Ordering::Relaxed); } Poll::Ready(Some(Ok(frame))) } Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), Poll::Ready(None) => { // Body is fully consumed — report now this.report(); Poll::Ready(None) } Poll::Pending => Poll::Pending, } } fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } fn size_hint(&self) -> http_body::SizeHint { self.inner.size_hint() } }