use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::{Bytes, BytesMut, BufMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use tokio::time::Instant;

// Frame type constants
pub const FRAME_OPEN: u8 = 0x01;
pub const FRAME_DATA: u8 = 0x02;
pub const FRAME_CLOSE: u8 = 0x03;
pub const FRAME_DATA_BACK: u8 = 0x04;
pub const FRAME_CLOSE_BACK: u8 = 0x05;
pub const FRAME_CONFIG: u8 = 0x06; // Hub -> Edge: configuration update
pub const FRAME_PING: u8 = 0x07; // Hub -> Edge: heartbeat probe
pub const FRAME_PONG: u8 = 0x08; // Edge -> Hub: heartbeat response
pub const FRAME_WINDOW_UPDATE: u8 = 0x09; // Edge -> Hub: per-stream flow control
pub const FRAME_WINDOW_UPDATE_BACK: u8 = 0x0A; // Hub -> Edge: per-stream flow control

// UDP tunnel frame types
pub const FRAME_UDP_OPEN: u8 = 0x0B; // Edge -> Hub: open UDP session (payload: PROXY v2 header)
pub const FRAME_UDP_DATA: u8 = 0x0C; // Edge -> Hub: UDP datagram
pub const FRAME_UDP_DATA_BACK: u8 = 0x0D; // Hub -> Edge: UDP datagram
pub const FRAME_UDP_CLOSE: u8 = 0x0E; // Either direction: close UDP session

// Frame header size: 4 (stream_id) + 1 (type) + 4 (length) = 9 bytes
pub const FRAME_HEADER_SIZE: usize = 9;

// Maximum payload size (16 MB)
pub const MAX_PAYLOAD_SIZE: u32 = 16 * 1024 * 1024;

// Per-stream flow control constants
/// Initial (and maximum) per-stream window size (4 MB).
pub const INITIAL_STREAM_WINDOW: u32 = 4 * 1024 * 1024;
/// Send WINDOW_UPDATE after consuming this many bytes (half the initial window).
pub const WINDOW_UPDATE_THRESHOLD: u32 = INITIAL_STREAM_WINDOW / 2;
/// Maximum window size to prevent overflow.
pub const MAX_WINDOW_SIZE: u32 = 4 * 1024 * 1024;

// Sustained stream classification constants
/// Throughput threshold for sustained classification (2.5 MB/s = 20 Mbit/s).
pub const SUSTAINED_THRESHOLD_BPS: u64 = 2_500_000;
/// Minimum duration before a stream can be classified as sustained.
pub const SUSTAINED_MIN_DURATION_SECS: u64 = 10;
/// Fixed window for sustained streams (1 MB — the floor).
pub const SUSTAINED_WINDOW: u32 = 1 * 1024 * 1024;
/// Maximum bytes written from sustained queue per forced drain (1 MB/s guarantee).
pub const SUSTAINED_FORCED_DRAIN_CAP: usize = 1_048_576;

/// Encode a WINDOW_UPDATE frame for a specific stream.
///
/// `frame_type` is FRAME_WINDOW_UPDATE or FRAME_WINDOW_UPDATE_BACK depending
/// on direction; the 4-byte payload is the increment in big-endian.
pub fn encode_window_update(stream_id: u32, frame_type: u8, increment: u32) -> Bytes {
    encode_frame(stream_id, frame_type, &increment.to_be_bytes())
}

/// Compute the target per-stream window size based on the number of active streams.
/// Total memory budget is ~200MB shared across all streams. Up to 50 streams get the
/// full 4MB window; above that the window scales down to a 1MB floor at 200+ streams.
pub fn compute_window_for_stream_count(active: u32) -> u32 {
    // `active.max(1)` makes 0 streams behave like 1 (full window, no div-by-zero).
    let per_stream = (200 * 1024 * 1024u64) / (active.max(1) as u64);
    per_stream.clamp(1 * 1024 * 1024, INITIAL_STREAM_WINDOW as u64) as u32
}

/// Decode a WINDOW_UPDATE payload into a byte increment. Returns None if payload is malformed.
pub fn decode_window_update(payload: &[u8]) -> Option<u32> {
    // Exactly 4 big-endian bytes; anything else is a protocol violation.
    let bytes: [u8; 4] = payload.try_into().ok()?;
    Some(u32::from_be_bytes(bytes))
}

/// A single multiplexed frame.
#[derive(Debug, Clone)]
pub struct Frame {
    pub stream_id: u32,
    pub frame_type: u8,
    pub payload: Bytes,
}

/// Encode a frame into bytes: [stream_id:4][type:1][length:4][payload]
pub fn encode_frame(stream_id: u32, frame_type: u8, payload: &[u8]) -> Bytes {
    let len = payload.len() as u32;
    let mut buf = BytesMut::with_capacity(FRAME_HEADER_SIZE + payload.len());
    buf.put_slice(&stream_id.to_be_bytes());
    buf.put_u8(frame_type);
    buf.put_slice(&len.to_be_bytes());
    buf.put_slice(payload);
    buf.freeze()
}

/// Write a frame header into `buf[0..FRAME_HEADER_SIZE]`.
/// The caller must ensure payload is already at `buf[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len]`.
/// This enables zero-copy encoding: read directly into `buf[FRAME_HEADER_SIZE..]`, then
/// prepend the header without copying the payload.
pub fn encode_frame_header(buf: &mut [u8], stream_id: u32, frame_type: u8, payload_len: usize) {
    buf[0..4].copy_from_slice(&stream_id.to_be_bytes());
    buf[4] = frame_type;
    buf[5..9].copy_from_slice(&(payload_len as u32).to_be_bytes());
}

/// Build a PROXY protocol v1 header line.
/// Format: `PROXY TCP4 <client_ip> <edge_ip> <client_port> <dest_port>\r\n`
pub fn build_proxy_v1_header(
    client_ip: &str,
    edge_ip: &str,
    client_port: u16,
    dest_port: u16,
) -> String {
    format!(
        "PROXY TCP4 {} {} {} {}\r\n",
        client_ip, edge_ip, client_port, dest_port
    )
}

/// PROXY protocol v2 signature (12 bytes).
pub const PROXY_V2_SIGNATURE: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

/// Transport protocol for PROXY v2 header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyV2Transport {
    /// TCP (STREAM) — byte 13 low nibble = 0x1
    Tcp,
    /// UDP (DGRAM) — byte 13 low nibble = 0x2
    Udp,
}

/// Build a PROXY protocol v2 binary header for IPv4.
/// /// Returns a 28-byte header: /// - 12B signature /// - 1B version (0x2) + command (0x1 = PROXY) /// - 1B address family (0x1 = AF_INET) + transport (0x1 = TCP, 0x2 = UDP) /// - 2B address block length (0x000C = 12) /// - 4B source IPv4 address /// - 4B destination IPv4 address /// - 2B source port /// - 2B destination port pub fn build_proxy_v2_header( src_ip: &std::net::Ipv4Addr, dst_ip: &std::net::Ipv4Addr, src_port: u16, dst_port: u16, transport: ProxyV2Transport, ) -> Bytes { let mut buf = BytesMut::with_capacity(28); // Signature (12 bytes) buf.put_slice(&PROXY_V2_SIGNATURE); // Version 2 + PROXY command buf.put_u8(0x21); // AF_INET (0x1) + transport let transport_nibble = match transport { ProxyV2Transport::Tcp => 0x1, ProxyV2Transport::Udp => 0x2, }; buf.put_u8(0x10 | transport_nibble); // Address block length: 12 bytes for IPv4 buf.put_u16(12); // Source address (4 bytes, network byte order) buf.put_slice(&src_ip.octets()); // Destination address (4 bytes, network byte order) buf.put_slice(&dst_ip.octets()); // Source port (2 bytes, network byte order) buf.put_u16(src_port); // Destination port (2 bytes, network byte order) buf.put_u16(dst_port); buf.freeze() } /// Build a PROXY protocol v2 binary header from string IP addresses. /// Falls back to 0.0.0.0 if parsing fails. pub fn build_proxy_v2_header_from_str( src_ip: &str, dst_ip: &str, src_port: u16, dst_port: u16, transport: ProxyV2Transport, ) -> Bytes { let src: std::net::Ipv4Addr = src_ip.parse().unwrap_or(std::net::Ipv4Addr::UNSPECIFIED); let dst: std::net::Ipv4Addr = dst_ip.parse().unwrap_or(std::net::Ipv4Addr::UNSPECIFIED); build_proxy_v2_header(&src, &dst, src_port, dst_port, transport) } /// Stateful async frame reader that yields `Frame` values from an `AsyncRead`. pub struct FrameReader { reader: R, header_buf: [u8; FRAME_HEADER_SIZE], } impl FrameReader { pub fn new(reader: R) -> Self { Self { reader, header_buf: [0u8; FRAME_HEADER_SIZE], } } /// Read the next frame. 
Returns `None` on EOF, `Err` on protocol violation. pub async fn next_frame(&mut self) -> Result, std::io::Error> { // Read header match self.reader.read_exact(&mut self.header_buf).await { Ok(_) => {} Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) => return Err(e), } let stream_id = u32::from_be_bytes([ self.header_buf[0], self.header_buf[1], self.header_buf[2], self.header_buf[3], ]); let frame_type = self.header_buf[4]; let length = u32::from_be_bytes([ self.header_buf[5], self.header_buf[6], self.header_buf[7], self.header_buf[8], ]); if length > MAX_PAYLOAD_SIZE { log::error!( "CORRUPT FRAME HEADER: raw={:02x?} stream_id={} type=0x{:02x} length={}", self.header_buf, stream_id, frame_type, length ); return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!("frame payload too large: {} bytes (header={:02x?})", length, self.header_buf), )); } let mut payload = BytesMut::zeroed(length as usize); if length > 0 { self.reader.read_exact(&mut payload).await?; } Ok(Some(Frame { stream_id, frame_type, payload: payload.freeze(), })) } /// Consume the reader and return the inner stream. pub fn into_inner(self) -> R { self.reader } } // --------------------------------------------------------------------------- // TunnelIo: single-owner I/O multiplexer for the TLS tunnel connection // --------------------------------------------------------------------------- /// Events produced by the TunnelIo event loop. #[derive(Debug)] pub enum TunnelEvent { /// A complete frame was read from the remote side. Frame(Frame), /// The remote side closed the connection (EOF). Eof, /// A read error occurred. ReadError(std::io::Error), /// A write error occurred. WriteError(std::io::Error), /// No frames received for the liveness timeout duration. LivenessTimeout, /// The cancellation token was triggered. 
Cancelled, } /// Write state extracted into a sub-struct so the borrow checker can see /// disjoint field access between `self.write` and `self.stream`. struct WriteState { ctrl_queue: VecDeque, // PONG, WINDOW_UPDATE, CLOSE, OPEN — always first data_queue: VecDeque, // DATA, DATA_BACK — only when ctrl is empty sustained_queue: VecDeque, // DATA, DATA_BACK from sustained streams — lowest priority offset: usize, // progress within current frame being written flush_needed: bool, // Sustained starvation prevention: guaranteed 1 MB/s drain sustained_last_drain: Instant, sustained_bytes_this_period: usize, } impl WriteState { fn has_work(&self) -> bool { !self.ctrl_queue.is_empty() || !self.data_queue.is_empty() || !self.sustained_queue.is_empty() } } /// Single-owner I/O engine for the tunnel TLS connection. /// /// Owns the TLS stream directly — no `tokio::io::split()`, no mutex. /// Uses three priority write queues: /// 1. ctrl (PONG, WINDOW_UPDATE, CLOSE, OPEN) — always first /// 2. data (DATA, DATA_BACK from normal streams) — when ctrl empty /// 3. sustained (DATA, DATA_BACK from sustained streams) — lowest priority, /// drained freely when ctrl+data empty, or forced 1MB/s when they're not pub struct TunnelIo { stream: S, // Read state: accumulate bytes, parse frames incrementally read_buf: Vec, read_pos: usize, parse_pos: usize, // Write state: extracted sub-struct for safe disjoint borrows write: WriteState, } impl TunnelIo { pub fn new(stream: S, initial_data: Vec) -> Self { let read_pos = initial_data.len(); let mut read_buf = initial_data; if read_buf.capacity() < 65536 { read_buf.reserve(65536 - read_buf.len()); } Self { stream, read_buf, read_pos, parse_pos: 0, write: WriteState { ctrl_queue: VecDeque::new(), data_queue: VecDeque::new(), sustained_queue: VecDeque::new(), offset: 0, flush_needed: false, sustained_last_drain: Instant::now(), sustained_bytes_this_period: 0, }, } } /// Queue a high-priority control frame (PONG, WINDOW_UPDATE, CLOSE, OPEN). 
pub fn queue_ctrl(&mut self, frame: Bytes) { self.write.ctrl_queue.push_back(frame); } /// Queue a lower-priority data frame (DATA, DATA_BACK). pub fn queue_data(&mut self, frame: Bytes) { self.write.data_queue.push_back(frame); } /// Queue a lowest-priority sustained data frame. pub fn queue_sustained(&mut self, frame: Bytes) { self.write.sustained_queue.push_back(frame); } /// Try to parse a complete frame from the read buffer. /// Uses a parse_pos cursor to avoid drain() on every frame. pub fn try_parse_frame(&mut self) -> Option> { let available = self.read_pos - self.parse_pos; if available < FRAME_HEADER_SIZE { return None; } let base = self.parse_pos; let stream_id = u32::from_be_bytes([ self.read_buf[base], self.read_buf[base + 1], self.read_buf[base + 2], self.read_buf[base + 3], ]); let frame_type = self.read_buf[base + 4]; let length = u32::from_be_bytes([ self.read_buf[base + 5], self.read_buf[base + 6], self.read_buf[base + 7], self.read_buf[base + 8], ]); if length > MAX_PAYLOAD_SIZE { let header = [ self.read_buf[base], self.read_buf[base + 1], self.read_buf[base + 2], self.read_buf[base + 3], self.read_buf[base + 4], self.read_buf[base + 5], self.read_buf[base + 6], self.read_buf[base + 7], self.read_buf[base + 8], ]; log::error!( "CORRUPT FRAME HEADER: raw={:02x?} stream_id={} type=0x{:02x} length={}", header, stream_id, frame_type, length ); return Some(Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!("frame payload too large: {} bytes (header={:02x?})", length, header), ))); } let total_frame_size = FRAME_HEADER_SIZE + length as usize; if available < total_frame_size { return None; } let payload = Bytes::copy_from_slice( &self.read_buf[base + FRAME_HEADER_SIZE..base + total_frame_size], ); self.parse_pos += total_frame_size; // Compact when parse_pos > half the data to reclaim memory if self.parse_pos > self.read_pos / 2 && self.parse_pos > 0 { self.read_buf.drain(..self.parse_pos); self.read_pos -= self.parse_pos; self.parse_pos 
= 0; } Some(Ok(Frame { stream_id, frame_type, payload })) } /// Poll-based I/O step. Returns Ready on events, Pending when idle. /// /// Order: write(ctrl->data->sustained) -> flush -> read -> channels -> timers pub fn poll_step( &mut self, cx: &mut Context<'_>, ctrl_rx: &mut tokio::sync::mpsc::Receiver, data_rx: &mut tokio::sync::mpsc::Receiver, sustained_rx: &mut tokio::sync::mpsc::Receiver, liveness_deadline: &mut Pin>, cancel_token: &tokio_util::sync::CancellationToken, ) -> Poll { // 1. WRITE: 3-tier priority — ctrl first, then data, then sustained. // Sustained drains freely when ctrl+data are empty. // Write one frame, set flush_needed, then flush must complete before // writing more. This prevents unbounded TLS session buffer growth. // Safe: `self.write` and `self.stream` are disjoint fields. let mut writes = 0; while self.write.has_work() && writes < 16 && !self.write.flush_needed { // Pick queue: ctrl > data > sustained let queue_id = if !self.write.ctrl_queue.is_empty() { 0 // ctrl } else if !self.write.data_queue.is_empty() { 1 // data } else { 2 // sustained }; let frame = match queue_id { 0 => self.write.ctrl_queue.front().unwrap(), 1 => self.write.data_queue.front().unwrap(), _ => self.write.sustained_queue.front().unwrap(), }; let remaining = &frame[self.write.offset..]; match Pin::new(&mut self.stream).poll_write(cx, remaining) { Poll::Ready(Ok(0)) => { log::error!("TunnelIo: poll_write returned 0 (write zero), ctrl_q={} data_q={} sustained_q={}", self.write.ctrl_queue.len(), self.write.data_queue.len(), self.write.sustained_queue.len()); return Poll::Ready(TunnelEvent::WriteError( std::io::Error::new(std::io::ErrorKind::WriteZero, "write zero"), )); } Poll::Ready(Ok(n)) => { self.write.offset += n; self.write.flush_needed = true; if self.write.offset >= frame.len() { match queue_id { 0 => { self.write.ctrl_queue.pop_front(); } 1 => { self.write.data_queue.pop_front(); } _ => { self.write.sustained_queue.pop_front(); 
self.write.sustained_last_drain = Instant::now(); self.write.sustained_bytes_this_period = 0; } } self.write.offset = 0; writes += 1; } } Poll::Ready(Err(e)) => { log::error!("TunnelIo: poll_write error: {} (ctrl_q={} data_q={} sustained_q={})", e, self.write.ctrl_queue.len(), self.write.data_queue.len(), self.write.sustained_queue.len()); return Poll::Ready(TunnelEvent::WriteError(e)); } Poll::Pending => break, } } // 1b. FORCED SUSTAINED DRAIN: when ctrl/data have work but sustained is waiting, // guarantee at least 1 MB/s by draining up to SUSTAINED_FORCED_DRAIN_CAP // once per second. if !self.write.sustained_queue.is_empty() && (!self.write.ctrl_queue.is_empty() || !self.write.data_queue.is_empty()) && !self.write.flush_needed { let now = Instant::now(); if now.duration_since(self.write.sustained_last_drain) >= Duration::from_secs(1) { self.write.sustained_bytes_this_period = 0; self.write.sustained_last_drain = now; while !self.write.sustained_queue.is_empty() && self.write.sustained_bytes_this_period < SUSTAINED_FORCED_DRAIN_CAP && !self.write.flush_needed { let frame = self.write.sustained_queue.front().unwrap(); let remaining = &frame[self.write.offset..]; match Pin::new(&mut self.stream).poll_write(cx, remaining) { Poll::Ready(Ok(0)) => { return Poll::Ready(TunnelEvent::WriteError( std::io::Error::new(std::io::ErrorKind::WriteZero, "write zero"), )); } Poll::Ready(Ok(n)) => { self.write.offset += n; self.write.flush_needed = true; self.write.sustained_bytes_this_period += n; if self.write.offset >= frame.len() { self.write.sustained_queue.pop_front(); self.write.offset = 0; } } Poll::Ready(Err(e)) => { return Poll::Ready(TunnelEvent::WriteError(e)); } Poll::Pending => break, } } } } // 2. FLUSH: push encrypted data from TLS session to TCP. 
if self.write.flush_needed { match Pin::new(&mut self.stream).poll_flush(cx) { Poll::Ready(Ok(())) => { self.write.flush_needed = false; } Poll::Ready(Err(e)) => { log::error!("TunnelIo: poll_flush error: {}", e); return Poll::Ready(TunnelEvent::WriteError(e)); } Poll::Pending => {} // TCP waker will notify us } } // 3. READ: drain stream until Pending to ensure the TCP waker is always registered. // Without this loop, a Ready return with partial frame data would consume // the waker without re-registering it, causing the task to sleep until a // timer or channel wakes it (potentially 15+ seconds of lost reads). loop { // Compact if needed to make room for reads if self.parse_pos > 0 && self.read_buf.len() - self.read_pos < 32768 { self.read_buf.drain(..self.parse_pos); self.read_pos -= self.parse_pos; self.parse_pos = 0; } if self.read_buf.len() < self.read_pos + 32768 { self.read_buf.resize(self.read_pos + 32768, 0); } let mut rbuf = ReadBuf::new(&mut self.read_buf[self.read_pos..]); match Pin::new(&mut self.stream).poll_read(cx, &mut rbuf) { Poll::Ready(Ok(())) => { let n = rbuf.filled().len(); if n == 0 { return Poll::Ready(TunnelEvent::Eof); } self.read_pos += n; if let Some(result) = self.try_parse_frame() { return match result { Ok(frame) => Poll::Ready(TunnelEvent::Frame(frame)), Err(e) => Poll::Ready(TunnelEvent::ReadError(e)), }; } // Partial data — loop to call poll_read again so the TCP // waker is re-registered when it finally returns Pending. } Poll::Ready(Err(e)) => { log::error!("TunnelIo: poll_read error: {}", e); return Poll::Ready(TunnelEvent::ReadError(e)); } Poll::Pending => break, } } // 4. CHANNELS: drain ctrl (always — priority), data (only if queue is small). // Ctrl frames must never be delayed — always drain fully. // Data frames are gated: keep data in the bounded channel for proper // backpressure when TLS writes are slow. Without this gate, the internal // data_queue (unbounded VecDeque) grows to hundreds of MB under throttle -> OOM. 
let mut got_new = false; loop { match ctrl_rx.poll_recv(cx) { Poll::Ready(Some(frame)) => { self.write.ctrl_queue.push_back(frame); got_new = true; } Poll::Ready(None) => { return Poll::Ready(TunnelEvent::WriteError( std::io::Error::new(std::io::ErrorKind::BrokenPipe, "ctrl channel closed"), )); } Poll::Pending => break, } } if self.write.data_queue.len() < 64 { loop { match data_rx.poll_recv(cx) { Poll::Ready(Some(frame)) => { self.write.data_queue.push_back(frame); got_new = true; } Poll::Ready(None) => { return Poll::Ready(TunnelEvent::WriteError( std::io::Error::new(std::io::ErrorKind::BrokenPipe, "data channel closed"), )); } Poll::Pending => break, } } } // Sustained channel: drain when sustained_queue is small (same backpressure pattern). // Channel close is non-fatal — not all connections have sustained streams. if self.write.sustained_queue.len() < 64 { loop { match sustained_rx.poll_recv(cx) { Poll::Ready(Some(frame)) => { self.write.sustained_queue.push_back(frame); got_new = true; } Poll::Ready(None) | Poll::Pending => break, } } } // 5. TIMERS if liveness_deadline.as_mut().poll(cx).is_ready() { return Poll::Ready(TunnelEvent::LivenessTimeout); } if cancel_token.is_cancelled() { return Poll::Ready(TunnelEvent::Cancelled); } // 6. SELF-WAKE: only when flush is complete AND we have work. // When flush is Pending, the TCP write-readiness waker will notify us. // CRITICAL: do NOT self-wake when flush_needed — poll_write always returns // Ready (TLS buffers in-memory), so self-waking causes a tight spin loop // that fills the TLS session buffer unboundedly -> OOM -> ECONNRESET. 
if !self.write.flush_needed && (got_new || self.write.has_work()) { cx.waker().wake_by_ref(); } Poll::Pending } pub fn into_inner(self) -> S { self.stream } } #[cfg(test)] mod tests { use super::*; #[test] fn test_encode_frame_header() { let payload = b"hello"; let mut buf = vec![0u8; FRAME_HEADER_SIZE + payload.len()]; buf[FRAME_HEADER_SIZE..].copy_from_slice(payload); encode_frame_header(&mut buf, 42, FRAME_DATA, payload.len()); assert_eq!(buf, &encode_frame(42, FRAME_DATA, payload)[..]); } #[test] fn test_encode_frame_header_empty_payload() { let mut buf = vec![0u8; FRAME_HEADER_SIZE]; encode_frame_header(&mut buf, 99, FRAME_CLOSE, 0); assert_eq!(buf, &encode_frame(99, FRAME_CLOSE, &[])[..]); } #[test] fn test_encode_frame() { let data = b"hello"; let encoded = encode_frame(42, FRAME_DATA, data); assert_eq!(encoded.len(), FRAME_HEADER_SIZE + data.len()); // stream_id = 42 in BE assert_eq!(&encoded[0..4], &42u32.to_be_bytes()); // frame type assert_eq!(encoded[4], FRAME_DATA); // length assert_eq!(&encoded[5..9], &5u32.to_be_bytes()); // payload assert_eq!(&encoded[9..], b"hello"); } #[test] fn test_encode_empty_frame() { let encoded = encode_frame(1, FRAME_CLOSE, &[]); assert_eq!(encoded.len(), FRAME_HEADER_SIZE); assert_eq!(&encoded[5..9], &0u32.to_be_bytes()); } #[test] fn test_proxy_v1_header() { let header = build_proxy_v1_header("1.2.3.4", "5.6.7.8", 12345, 443); assert_eq!(header, "PROXY TCP4 1.2.3.4 5.6.7.8 12345 443\r\n"); } #[test] fn test_proxy_v2_header_tcp4() { let src = "198.51.100.10".parse().unwrap(); let dst = "203.0.113.25".parse().unwrap(); let header = build_proxy_v2_header(&src, &dst, 54321, 8443, ProxyV2Transport::Tcp); assert_eq!(header.len(), 28); // Signature assert_eq!(&header[0..12], &PROXY_V2_SIGNATURE); // Version 2 + PROXY command assert_eq!(header[12], 0x21); // AF_INET + STREAM (TCP) assert_eq!(header[13], 0x11); // Address length = 12 assert_eq!(u16::from_be_bytes([header[14], header[15]]), 12); // Source IP: 198.51.100.10 
        assert_eq!(&header[16..20], &[198, 51, 100, 10]);
        // Dest IP: 203.0.113.25
        assert_eq!(&header[20..24], &[203, 0, 113, 25]);
        // Source port: 54321
        assert_eq!(u16::from_be_bytes([header[24], header[25]]), 54321);
        // Dest port: 8443
        assert_eq!(u16::from_be_bytes([header[26], header[27]]), 8443);
    }

    #[test]
    fn test_proxy_v2_header_udp4() {
        let src = "10.0.0.1".parse().unwrap();
        let dst = "10.0.0.2".parse().unwrap();
        let header = build_proxy_v2_header(&src, &dst, 12345, 53, ProxyV2Transport::Udp);
        assert_eq!(header.len(), 28);
        assert_eq!(header[12], 0x21); // v2, PROXY
        assert_eq!(header[13], 0x12); // AF_INET + DGRAM (UDP)
        assert_eq!(&header[16..20], &[10, 0, 0, 1]); // src
        assert_eq!(&header[20..24], &[10, 0, 0, 2]); // dst
        assert_eq!(u16::from_be_bytes([header[24], header[25]]), 12345);
        assert_eq!(u16::from_be_bytes([header[26], header[27]]), 53);
    }

    #[test]
    fn test_proxy_v2_header_from_str() {
        let header = build_proxy_v2_header_from_str("1.2.3.4", "5.6.7.8", 1000, 443, ProxyV2Transport::Tcp);
        assert_eq!(header.len(), 28);
        assert_eq!(&header[16..20], &[1, 2, 3, 4]);
        assert_eq!(&header[20..24], &[5, 6, 7, 8]);
    }

    #[test]
    fn test_proxy_v2_header_from_str_invalid_ip() {
        let header = build_proxy_v2_header_from_str("not-an-ip", "also-not", 1000, 443, ProxyV2Transport::Udp);
        assert_eq!(header.len(), 28);
        // Falls back to 0.0.0.0
        assert_eq!(&header[16..20], &[0, 0, 0, 0]);
        assert_eq!(&header[20..24], &[0, 0, 0, 0]);
        assert_eq!(header[13], 0x12); // UDP
    }

    #[tokio::test]
    async fn test_frame_reader() {
        let frame1 = encode_frame(1, FRAME_OPEN, b"PROXY TCP4 1.2.3.4 5.6.7.8 1234 443\r\n");
        let frame2 = encode_frame(1, FRAME_DATA, b"GET / HTTP/1.1\r\n");
        let frame3 = encode_frame(1, FRAME_CLOSE, &[]);
        let mut data = Vec::new();
        data.extend_from_slice(&frame1);
        data.extend_from_slice(&frame2);
        data.extend_from_slice(&frame3);
        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);
        let f1 = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(f1.stream_id, 1);
        assert_eq!(f1.frame_type, FRAME_OPEN);
        assert!(f1.payload.starts_with(b"PROXY"));
        let f2 = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(f2.frame_type, FRAME_DATA);
        let f3 = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(f3.frame_type, FRAME_CLOSE);
        assert!(f3.payload.is_empty());
        // EOF
        assert!(reader.next_frame().await.unwrap().is_none());
    }

    #[test]
    fn test_encode_frame_config_type() {
        let payload = b"{\"listenPorts\":[443]}";
        let encoded = encode_frame(0, FRAME_CONFIG, payload);
        assert_eq!(encoded[4], FRAME_CONFIG);
        assert_eq!(&encoded[0..4], &0u32.to_be_bytes());
        assert_eq!(&encoded[9..], payload.as_slice());
    }

    #[test]
    fn test_encode_frame_data_back_type() {
        let payload = b"response data";
        let encoded = encode_frame(7, FRAME_DATA_BACK, payload);
        assert_eq!(encoded[4], FRAME_DATA_BACK);
        assert_eq!(&encoded[0..4], &7u32.to_be_bytes());
        assert_eq!(&encoded[5..9], &(payload.len() as u32).to_be_bytes());
        assert_eq!(&encoded[9..], payload.as_slice());
    }

    #[test]
    fn test_encode_frame_close_back_type() {
        let encoded = encode_frame(99, FRAME_CLOSE_BACK, &[]);
        assert_eq!(encoded[4], FRAME_CLOSE_BACK);
        assert_eq!(&encoded[0..4], &99u32.to_be_bytes());
        assert_eq!(&encoded[5..9], &0u32.to_be_bytes());
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE);
    }

    #[test]
    fn test_encode_frame_large_stream_id() {
        let encoded = encode_frame(u32::MAX, FRAME_DATA, b"x");
        assert_eq!(&encoded[0..4], &u32::MAX.to_be_bytes());
        assert_eq!(encoded[4], FRAME_DATA);
        assert_eq!(&encoded[5..9], &1u32.to_be_bytes());
        assert_eq!(encoded[9], b'x');
    }

    #[tokio::test]
    async fn test_frame_reader_max_payload_rejection() {
        let mut data = Vec::new();
        data.extend_from_slice(&1u32.to_be_bytes());
        data.push(FRAME_DATA);
        data.extend_from_slice(&(MAX_PAYLOAD_SIZE + 1).to_be_bytes());
        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);
        let result = reader.next_frame().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_frame_reader_eof_mid_header() {
        // Only 5 bytes — not enough for a 9-byte header
        let data = vec![0u8; 5];
        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);
        // Should return Ok(None) on partial header EOF
        let result = reader.next_frame().await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_frame_reader_eof_mid_payload() {
        // Full header claiming 100 bytes of payload, but only 10 bytes present
        let mut data = Vec::new();
        data.extend_from_slice(&1u32.to_be_bytes());
        data.push(FRAME_DATA);
        data.extend_from_slice(&100u32.to_be_bytes());
        data.extend_from_slice(&[0xAB; 10]);
        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);
        let result = reader.next_frame().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_frame_reader_all_frame_types() {
        let types = [
            FRAME_OPEN, FRAME_DATA, FRAME_CLOSE, FRAME_DATA_BACK,
            FRAME_CLOSE_BACK, FRAME_CONFIG, FRAME_PING, FRAME_PONG,
        ];
        let mut data = Vec::new();
        for (i, &ft) in types.iter().enumerate() {
            let payload = format!("payload_{}", i);
            data.extend_from_slice(&encode_frame(i as u32, ft, payload.as_bytes()));
        }
        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);
        for (i, &ft) in types.iter().enumerate() {
            let frame = reader.next_frame().await.unwrap().unwrap();
            assert_eq!(frame.stream_id, i as u32);
            assert_eq!(frame.frame_type, ft);
            assert_eq!(&frame.payload[..], format!("payload_{}", i).as_bytes());
        }
        assert!(reader.next_frame().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_frame_reader_zero_length_payload() {
        let data = encode_frame(42, FRAME_CLOSE, &[]);
        let cursor = std::io::Cursor::new(data.to_vec());
        let mut reader = FrameReader::new(cursor);
        let frame = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(frame.stream_id, 42);
        assert_eq!(frame.frame_type, FRAME_CLOSE);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_encode_frame_ping_pong() {
        // PING: stream_id=0, empty payload (control frame)
        let ping = encode_frame(0, FRAME_PING, &[]);
        assert_eq!(ping[4], FRAME_PING);
        assert_eq!(&ping[0..4], &0u32.to_be_bytes());
        assert_eq!(ping.len(), FRAME_HEADER_SIZE);
        // PONG: stream_id=0, empty payload (control frame)
        let pong = encode_frame(0, FRAME_PONG, &[]);
        assert_eq!(pong[4], FRAME_PONG);
        assert_eq!(&pong[0..4], &0u32.to_be_bytes());
        assert_eq!(pong.len(), FRAME_HEADER_SIZE);
    }

    // --- compute_window_for_stream_count tests ---

    #[test]
    fn test_adaptive_window_zero_streams() {
        // 0 streams treated as 1: 200MB/1 -> clamped to 4MB max
        assert_eq!(compute_window_for_stream_count(0), INITIAL_STREAM_WINDOW);
    }

    #[test]
    fn test_adaptive_window_one_stream() {
        assert_eq!(compute_window_for_stream_count(1), INITIAL_STREAM_WINDOW);
    }

    #[test]
    fn test_adaptive_window_50_streams_full() {
        // 200MB/50 = 4MB = exactly INITIAL_STREAM_WINDOW
        assert_eq!(compute_window_for_stream_count(50), INITIAL_STREAM_WINDOW);
    }

    #[test]
    fn test_adaptive_window_51_streams_starts_scaling() {
        // 200MB/51 < 4MB — first value below max
        let w = compute_window_for_stream_count(51);
        assert!(w < INITIAL_STREAM_WINDOW);
        assert_eq!(w, (200 * 1024 * 1024u64 / 51) as u32);
    }

    #[test]
    fn test_adaptive_window_100_streams() {
        // 200MB/100 = 2MB
        assert_eq!(compute_window_for_stream_count(100), 2 * 1024 * 1024);
    }

    #[test]
    fn test_adaptive_window_200_streams_at_floor() {
        // 200MB/200 = 1MB = exactly the floor
        assert_eq!(compute_window_for_stream_count(200), 1 * 1024 * 1024);
    }

    #[test]
    fn test_adaptive_window_500_streams_clamped() {
        // 200MB/500 = 0.4MB -> clamped up to 1MB floor
        assert_eq!(compute_window_for_stream_count(500), 1 * 1024 * 1024);
    }

    #[test]
    fn test_adaptive_window_max_u32() {
        // Extreme: u32::MAX streams -> tiny value -> clamped to 1MB
        assert_eq!(compute_window_for_stream_count(u32::MAX), 1 * 1024 * 1024);
    }

    #[test]
    fn test_adaptive_window_monotonically_decreasing() {
        let mut prev = compute_window_for_stream_count(1);
        for n in [2, 10, 50, 51, 100, 200, 500, 1000] {
            let w = compute_window_for_stream_count(n);
            assert!(w <= prev, "window increased from {} to {} at n={}", prev, w, n);
            prev = w;
        }
    }

    #[test]
    fn test_adaptive_window_total_budget_bounded() {
        // active x per_stream_window should never exceed 200MB (+ clamp overhead for high N)
        for n in [1, 10, 50, 100, 200] {
            let w = compute_window_for_stream_count(n);
            let total = w as u64 * n as u64;
            assert!(total <= 200 * 1024 * 1024, "total {}MB exceeds budget at n={}", total / (1024*1024), n);
        }
    }

    // --- encode/decode window_update roundtrip ---

    #[test]
    fn test_window_update_roundtrip() {
        for &increment in &[0u32, 1, 64 * 1024, INITIAL_STREAM_WINDOW, MAX_WINDOW_SIZE, u32::MAX] {
            let frame = encode_window_update(42, FRAME_WINDOW_UPDATE, increment);
            assert_eq!(frame[4], FRAME_WINDOW_UPDATE);
            let decoded = decode_window_update(&frame[FRAME_HEADER_SIZE..]);
            assert_eq!(decoded, Some(increment));
        }
    }

    #[test]
    fn test_window_update_back_roundtrip() {
        let frame = encode_window_update(7, FRAME_WINDOW_UPDATE_BACK, 1234567);
        assert_eq!(frame[4], FRAME_WINDOW_UPDATE_BACK);
        assert_eq!(decode_window_update(&frame[FRAME_HEADER_SIZE..]), Some(1234567));
    }

    #[test]
    fn test_decode_window_update_malformed() {
        assert_eq!(decode_window_update(&[]), None);
        assert_eq!(decode_window_update(&[0, 0, 0]), None);
        assert_eq!(decode_window_update(&[0, 0, 0, 0, 0]), None);
    }
}