use tokio::io::{AsyncRead, AsyncReadExt}; // Frame type constants pub const FRAME_OPEN: u8 = 0x01; pub const FRAME_DATA: u8 = 0x02; pub const FRAME_CLOSE: u8 = 0x03; pub const FRAME_DATA_BACK: u8 = 0x04; pub const FRAME_CLOSE_BACK: u8 = 0x05; pub const FRAME_CONFIG: u8 = 0x06; // Hub -> Edge: configuration update pub const FRAME_PING: u8 = 0x07; // Hub -> Edge: heartbeat probe pub const FRAME_PONG: u8 = 0x08; // Edge -> Hub: heartbeat response pub const FRAME_WINDOW_UPDATE: u8 = 0x09; // Edge -> Hub: per-stream flow control pub const FRAME_WINDOW_UPDATE_BACK: u8 = 0x0A; // Hub -> Edge: per-stream flow control // Frame header size: 4 (stream_id) + 1 (type) + 4 (length) = 9 bytes pub const FRAME_HEADER_SIZE: usize = 9; // Maximum payload size (16 MB) pub const MAX_PAYLOAD_SIZE: u32 = 16 * 1024 * 1024; // Per-stream flow control constants /// Initial per-stream window size (4 MB). Sized for full throughput at high RTT: /// at 100ms RTT, this sustains ~40 MB/s per stream. pub const INITIAL_STREAM_WINDOW: u32 = 4 * 1024 * 1024; /// Send WINDOW_UPDATE after consuming this many bytes (half the initial window). pub const WINDOW_UPDATE_THRESHOLD: u32 = INITIAL_STREAM_WINDOW / 2; /// Maximum window size to prevent overflow. pub const MAX_WINDOW_SIZE: u32 = 16 * 1024 * 1024; /// Encode a WINDOW_UPDATE frame for a specific stream. pub fn encode_window_update(stream_id: u32, frame_type: u8, increment: u32) -> Vec { encode_frame(stream_id, frame_type, &increment.to_be_bytes()) } /// Compute the target per-stream window size based on the number of active streams. /// Total memory budget is ~32MB shared across all streams. As more streams are active, /// each gets a smaller window. This adapts to current demand — few streams get high /// throughput, many streams save memory and reduce control frame pressure. pub fn compute_window_for_stream_count(active: u32) -> u32 { let per_stream = (32 * 1024 * 1024u64) / (active.max(1) as u64); per_stream.clamp(64 * 1024, INITIAL_STREAM_WINDOW as u64) as u32 } /// Decode a WINDOW_UPDATE payload into a byte increment. Returns None if payload is malformed. pub fn decode_window_update(payload: &[u8]) -> Option { if payload.len() != 4 { return None; } Some(u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]])) } /// A single multiplexed frame. #[derive(Debug, Clone)] pub struct Frame { pub stream_id: u32, pub frame_type: u8, pub payload: Vec, } /// Encode a frame into bytes: [stream_id:4][type:1][length:4][payload] pub fn encode_frame(stream_id: u32, frame_type: u8, payload: &[u8]) -> Vec { let len = payload.len() as u32; let mut buf = Vec::with_capacity(FRAME_HEADER_SIZE + payload.len()); buf.extend_from_slice(&stream_id.to_be_bytes()); buf.push(frame_type); buf.extend_from_slice(&len.to_be_bytes()); buf.extend_from_slice(payload); buf } /// Build a PROXY protocol v1 header line. /// Format: `PROXY TCP4 \r\n` pub fn build_proxy_v1_header( client_ip: &str, edge_ip: &str, client_port: u16, dest_port: u16, ) -> String { format!( "PROXY TCP4 {} {} {} {}\r\n", client_ip, edge_ip, client_port, dest_port ) } /// Stateful async frame reader that yields `Frame` values from an `AsyncRead`. pub struct FrameReader { reader: R, header_buf: [u8; FRAME_HEADER_SIZE], } impl FrameReader { pub fn new(reader: R) -> Self { Self { reader, header_buf: [0u8; FRAME_HEADER_SIZE], } } /// Read the next frame. Returns `None` on EOF, `Err` on protocol violation. pub async fn next_frame(&mut self) -> Result, std::io::Error> { // Read header match self.reader.read_exact(&mut self.header_buf).await { Ok(_) => {} Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) => return Err(e), } let stream_id = u32::from_be_bytes([ self.header_buf[0], self.header_buf[1], self.header_buf[2], self.header_buf[3], ]); let frame_type = self.header_buf[4]; let length = u32::from_be_bytes([ self.header_buf[5], self.header_buf[6], self.header_buf[7], self.header_buf[8], ]); if length > MAX_PAYLOAD_SIZE { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!("frame payload too large: {} bytes", length), )); } let mut payload = vec![0u8; length as usize]; if length > 0 { self.reader.read_exact(&mut payload).await?; } Ok(Some(Frame { stream_id, frame_type, payload, })) } /// Consume the reader and return the inner stream. pub fn into_inner(self) -> R { self.reader } } #[cfg(test)] mod tests { use super::*; #[test] fn test_encode_frame() { let data = b"hello"; let encoded = encode_frame(42, FRAME_DATA, data); assert_eq!(encoded.len(), FRAME_HEADER_SIZE + data.len()); // stream_id = 42 in BE assert_eq!(&encoded[0..4], &42u32.to_be_bytes()); // frame type assert_eq!(encoded[4], FRAME_DATA); // length assert_eq!(&encoded[5..9], &5u32.to_be_bytes()); // payload assert_eq!(&encoded[9..], b"hello"); } #[test] fn test_encode_empty_frame() { let encoded = encode_frame(1, FRAME_CLOSE, &[]); assert_eq!(encoded.len(), FRAME_HEADER_SIZE); assert_eq!(&encoded[5..9], &0u32.to_be_bytes()); } #[test] fn test_proxy_v1_header() { let header = build_proxy_v1_header("1.2.3.4", "5.6.7.8", 12345, 443); assert_eq!(header, "PROXY TCP4 1.2.3.4 5.6.7.8 12345 443\r\n"); } #[tokio::test] async fn test_frame_reader() { let frame1 = encode_frame(1, FRAME_OPEN, b"PROXY TCP4 1.2.3.4 5.6.7.8 1234 443\r\n"); let frame2 = encode_frame(1, FRAME_DATA, b"GET / HTTP/1.1\r\n"); let frame3 = encode_frame(1, FRAME_CLOSE, &[]); let mut data = Vec::new(); data.extend_from_slice(&frame1); data.extend_from_slice(&frame2); data.extend_from_slice(&frame3); let cursor = std::io::Cursor::new(data); let mut reader = FrameReader::new(cursor); let f1 = reader.next_frame().await.unwrap().unwrap(); assert_eq!(f1.stream_id, 1); assert_eq!(f1.frame_type, FRAME_OPEN); assert!(f1.payload.starts_with(b"PROXY")); let f2 = reader.next_frame().await.unwrap().unwrap(); assert_eq!(f2.frame_type, FRAME_DATA); let f3 = reader.next_frame().await.unwrap().unwrap(); assert_eq!(f3.frame_type, FRAME_CLOSE); assert!(f3.payload.is_empty()); // EOF assert!(reader.next_frame().await.unwrap().is_none()); } #[test] fn test_encode_frame_config_type() { let payload = b"{\"listenPorts\":[443]}"; let encoded = encode_frame(0, FRAME_CONFIG, payload); assert_eq!(encoded[4], FRAME_CONFIG); assert_eq!(&encoded[0..4], &0u32.to_be_bytes()); assert_eq!(&encoded[9..], payload.as_slice()); } #[test] fn test_encode_frame_data_back_type() { let payload = b"response data"; let encoded = encode_frame(7, FRAME_DATA_BACK, payload); assert_eq!(encoded[4], FRAME_DATA_BACK); assert_eq!(&encoded[0..4], &7u32.to_be_bytes()); assert_eq!(&encoded[5..9], &(payload.len() as u32).to_be_bytes()); assert_eq!(&encoded[9..], payload.as_slice()); } #[test] fn test_encode_frame_close_back_type() { let encoded = encode_frame(99, FRAME_CLOSE_BACK, &[]); assert_eq!(encoded[4], FRAME_CLOSE_BACK); assert_eq!(&encoded[0..4], &99u32.to_be_bytes()); assert_eq!(&encoded[5..9], &0u32.to_be_bytes()); assert_eq!(encoded.len(), FRAME_HEADER_SIZE); } #[test] fn test_encode_frame_large_stream_id() { let encoded = encode_frame(u32::MAX, FRAME_DATA, b"x"); assert_eq!(&encoded[0..4], &u32::MAX.to_be_bytes()); assert_eq!(encoded[4], FRAME_DATA); assert_eq!(&encoded[5..9], &1u32.to_be_bytes()); assert_eq!(encoded[9], b'x'); } #[tokio::test] async fn test_frame_reader_max_payload_rejection() { let mut data = Vec::new(); data.extend_from_slice(&1u32.to_be_bytes()); data.push(FRAME_DATA); data.extend_from_slice(&(MAX_PAYLOAD_SIZE + 1).to_be_bytes()); let cursor = std::io::Cursor::new(data); let mut reader = FrameReader::new(cursor); let result = reader.next_frame().await; assert!(result.is_err()); let err = result.unwrap_err(); assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); } #[tokio::test] async fn test_frame_reader_eof_mid_header() { // Only 5 bytes — not enough for a 9-byte header let data = vec![0u8; 5]; let cursor = std::io::Cursor::new(data); let mut reader = FrameReader::new(cursor); // Should return Ok(None) on partial header EOF let result = reader.next_frame().await; assert!(result.unwrap().is_none()); } #[tokio::test] async fn test_frame_reader_eof_mid_payload() { // Full header claiming 100 bytes of payload, but only 10 bytes present let mut data = Vec::new(); data.extend_from_slice(&1u32.to_be_bytes()); data.push(FRAME_DATA); data.extend_from_slice(&100u32.to_be_bytes()); data.extend_from_slice(&[0xAB; 10]); let cursor = std::io::Cursor::new(data); let mut reader = FrameReader::new(cursor); let result = reader.next_frame().await; assert!(result.is_err()); } #[tokio::test] async fn test_frame_reader_all_frame_types() { let types = [ FRAME_OPEN, FRAME_DATA, FRAME_CLOSE, FRAME_DATA_BACK, FRAME_CLOSE_BACK, FRAME_CONFIG, FRAME_PING, FRAME_PONG, ]; let mut data = Vec::new(); for (i, &ft) in types.iter().enumerate() { let payload = format!("payload_{}", i); data.extend_from_slice(&encode_frame(i as u32, ft, payload.as_bytes())); } let cursor = std::io::Cursor::new(data); let mut reader = FrameReader::new(cursor); for (i, &ft) in types.iter().enumerate() { let frame = reader.next_frame().await.unwrap().unwrap(); assert_eq!(frame.stream_id, i as u32); assert_eq!(frame.frame_type, ft); assert_eq!(frame.payload, format!("payload_{}", i).as_bytes()); } assert!(reader.next_frame().await.unwrap().is_none()); } #[tokio::test] async fn test_frame_reader_zero_length_payload() { let data = encode_frame(42, FRAME_CLOSE, &[]); let cursor = std::io::Cursor::new(data); let mut reader = FrameReader::new(cursor); let frame = reader.next_frame().await.unwrap().unwrap(); assert_eq!(frame.stream_id, 42); assert_eq!(frame.frame_type, FRAME_CLOSE); assert!(frame.payload.is_empty()); } #[test] fn test_encode_frame_ping_pong() { // PING: stream_id=0, empty payload (control frame) let ping = encode_frame(0, FRAME_PING, &[]); assert_eq!(ping[4], FRAME_PING); assert_eq!(&ping[0..4], &0u32.to_be_bytes()); assert_eq!(ping.len(), FRAME_HEADER_SIZE); // PONG: stream_id=0, empty payload (control frame) let pong = encode_frame(0, FRAME_PONG, &[]); assert_eq!(pong[4], FRAME_PONG); assert_eq!(&pong[0..4], &0u32.to_be_bytes()); assert_eq!(pong.len(), FRAME_HEADER_SIZE); } }