349 lines
12 KiB
Rust
349 lines
12 KiB
Rust
use tokio::io::{AsyncRead, AsyncReadExt};
|
|
|
|
// Frame type constants: values of the one-byte `type` field in every frame header.

/// Frame type `0x01`: stream open.
pub const FRAME_OPEN: u8 = 0x01;
/// Frame type `0x02`: stream data.
pub const FRAME_DATA: u8 = 0x02;
/// Frame type `0x03`: stream close.
pub const FRAME_CLOSE: u8 = 0x03;
/// Frame type `0x04`: stream data, the `_BACK` counterpart of [`FRAME_DATA`].
pub const FRAME_DATA_BACK: u8 = 0x04;
/// Frame type `0x05`: stream close, the `_BACK` counterpart of [`FRAME_CLOSE`].
pub const FRAME_CLOSE_BACK: u8 = 0x05;
/// Frame type `0x06` (Hub -> Edge): configuration update.
pub const FRAME_CONFIG: u8 = 0x06;
/// Frame type `0x07` (Hub -> Edge): heartbeat probe.
pub const FRAME_PING: u8 = 0x07;
/// Frame type `0x08` (Edge -> Hub): heartbeat response.
pub const FRAME_PONG: u8 = 0x08;
/// Frame type `0x09` (Edge -> Hub): per-stream flow control.
pub const FRAME_WINDOW_UPDATE: u8 = 0x09;
/// Frame type `0x0A` (Hub -> Edge): per-stream flow control.
pub const FRAME_WINDOW_UPDATE_BACK: u8 = 0x0A;
|
|
|
|
/// Size in bytes of the fixed frame header:
/// 4 (`stream_id`, big-endian) + 1 (type) + 4 (payload length, big-endian).
pub const FRAME_HEADER_SIZE: usize = 4 + 1 + 4;

/// Largest payload length (16 MiB) accepted by [`FrameReader::next_frame`];
/// a header declaring more than this is rejected with `InvalidData`.
pub const MAX_PAYLOAD_SIZE: u32 = 16 << 20;
|
|
|
|
// Per-stream flow control constants.

/// Initial per-stream window size (4 MiB).
///
/// Sized for throughput at high RTT: with a 100 ms round trip, a 4 MiB window
/// sustains roughly 40 MB/s per stream (window / RTT).
pub const INITIAL_STREAM_WINDOW: u32 = 4 << 20;

/// Number of consumed bytes after which a WINDOW_UPDATE is sent:
/// half of [`INITIAL_STREAM_WINDOW`].
pub const WINDOW_UPDATE_THRESHOLD: u32 = INITIAL_STREAM_WINDOW >> 1;

/// Upper bound on a stream's window size (16 MiB), to prevent overflow.
pub const MAX_WINDOW_SIZE: u32 = 16 << 20;
|
|
|
|
/// Encode a WINDOW_UPDATE frame for a specific stream.
|
|
pub fn encode_window_update(stream_id: u32, frame_type: u8, increment: u32) -> Vec<u8> {
|
|
encode_frame(stream_id, frame_type, &increment.to_be_bytes())
|
|
}
|
|
|
|
/// Compute the target per-stream window size based on the number of active streams.
|
|
/// Total memory budget is ~32MB shared across all streams. As more streams are active,
|
|
/// each gets a smaller window. This adapts to current demand — few streams get high
|
|
/// throughput, many streams save memory and reduce control frame pressure.
|
|
pub fn compute_window_for_stream_count(active: u32) -> u32 {
|
|
let per_stream = (32 * 1024 * 1024u64) / (active.max(1) as u64);
|
|
per_stream.clamp(64 * 1024, INITIAL_STREAM_WINDOW as u64) as u32
|
|
}
|
|
|
|
/// Decode a WINDOW_UPDATE payload into a byte increment.
///
/// The payload must be exactly four bytes holding a big-endian `u32`. Any other
/// length is malformed and yields `None`.
pub fn decode_window_update(payload: &[u8]) -> Option<u32> {
    match payload {
        &[b0, b1, b2, b3] => Some(u32::from_be_bytes([b0, b1, b2, b3])),
        _ => None,
    }
}
|
|
|
|
/// A single multiplexed frame, as yielded by [`FrameReader::next_frame`].
///
/// On the wire it is `[stream_id:4][type:1][length:4][payload]`, with both
/// integer fields big-endian (see [`encode_frame`]). `PartialEq`/`Eq` are
/// derived so frames can be compared directly, e.g. in `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
    /// Stream identifier taken from the header.
    pub stream_id: u32,
    /// Raw type byte; compare against the `FRAME_*` constants. The reader does
    /// not validate it.
    pub frame_type: u8,
    /// Frame body; its length is the value carried in the header's length field.
    pub payload: Vec<u8>,
}
|
|
|
|
/// Encode a frame into bytes: [stream_id:4][type:1][length:4][payload]
|
|
pub fn encode_frame(stream_id: u32, frame_type: u8, payload: &[u8]) -> Vec<u8> {
|
|
let len = payload.len() as u32;
|
|
let mut buf = Vec::with_capacity(FRAME_HEADER_SIZE + payload.len());
|
|
buf.extend_from_slice(&stream_id.to_be_bytes());
|
|
buf.push(frame_type);
|
|
buf.extend_from_slice(&len.to_be_bytes());
|
|
buf.extend_from_slice(payload);
|
|
buf
|
|
}
|
|
|
|
/// Build a PROXY protocol v1 header line.
///
/// Produces `PROXY TCP4 <client_ip> <edge_ip> <client_port> <dest_port>\r\n`.
/// The address strings are inserted verbatim, without parsing or validation,
/// and the protocol token is always `TCP4`.
///
/// NOTE(review): the PROXY v1 spec requires `TCP6` when the addresses are IPv6.
/// Passing IPv6 strings here yields a non-conformant line. Confirm that callers
/// only pass IPv4 addresses, or that the receiving side tolerates this.
pub fn build_proxy_v1_header(
    client_ip: &str,
    edge_ip: &str,
    client_port: u16,
    dest_port: u16,
) -> String {
    let line = format!(
        "PROXY TCP4 {} {} {} {}\r\n",
        client_ip,   // source address
        edge_ip,     // destination address
        client_port, // source port
        dest_port,   // destination port
    );
    line
}
|
|
|
|
/// Stateful async frame reader that yields `Frame` values from an `AsyncRead`.
|
|
pub struct FrameReader<R> {
|
|
reader: R,
|
|
header_buf: [u8; FRAME_HEADER_SIZE],
|
|
}
|
|
|
|
impl<R: AsyncRead + Unpin> FrameReader<R> {
|
|
pub fn new(reader: R) -> Self {
|
|
Self {
|
|
reader,
|
|
header_buf: [0u8; FRAME_HEADER_SIZE],
|
|
}
|
|
}
|
|
|
|
/// Read the next frame. Returns `None` on EOF, `Err` on protocol violation.
|
|
pub async fn next_frame(&mut self) -> Result<Option<Frame>, std::io::Error> {
|
|
// Read header
|
|
match self.reader.read_exact(&mut self.header_buf).await {
|
|
Ok(_) => {}
|
|
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
|
Err(e) => return Err(e),
|
|
}
|
|
|
|
let stream_id = u32::from_be_bytes([
|
|
self.header_buf[0],
|
|
self.header_buf[1],
|
|
self.header_buf[2],
|
|
self.header_buf[3],
|
|
]);
|
|
let frame_type = self.header_buf[4];
|
|
let length = u32::from_be_bytes([
|
|
self.header_buf[5],
|
|
self.header_buf[6],
|
|
self.header_buf[7],
|
|
self.header_buf[8],
|
|
]);
|
|
|
|
if length > MAX_PAYLOAD_SIZE {
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidData,
|
|
format!("frame payload too large: {} bytes", length),
|
|
));
|
|
}
|
|
|
|
let mut payload = vec![0u8; length as usize];
|
|
if length > 0 {
|
|
self.reader.read_exact(&mut payload).await?;
|
|
}
|
|
|
|
Ok(Some(Frame {
|
|
stream_id,
|
|
frame_type,
|
|
payload,
|
|
}))
|
|
}
|
|
|
|
/// Consume the reader and return the inner stream.
|
|
pub fn into_inner(self) -> R {
|
|
self.reader
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_frame() {
        let data = b"hello";
        let encoded = encode_frame(42, FRAME_DATA, data);
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE + data.len());
        // stream_id = 42 in BE
        assert_eq!(&encoded[0..4], &42u32.to_be_bytes());
        // frame type
        assert_eq!(encoded[4], FRAME_DATA);
        // length
        assert_eq!(&encoded[5..9], &5u32.to_be_bytes());
        // payload
        assert_eq!(&encoded[9..], b"hello");
    }

    #[test]
    fn test_encode_empty_frame() {
        let encoded = encode_frame(1, FRAME_CLOSE, &[]);
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE);
        assert_eq!(&encoded[5..9], &0u32.to_be_bytes());
    }

    #[test]
    fn test_proxy_v1_header() {
        let header = build_proxy_v1_header("1.2.3.4", "5.6.7.8", 12345, 443);
        assert_eq!(header, "PROXY TCP4 1.2.3.4 5.6.7.8 12345 443\r\n");
    }

    #[tokio::test]
    async fn test_frame_reader() {
        let frame1 = encode_frame(1, FRAME_OPEN, b"PROXY TCP4 1.2.3.4 5.6.7.8 1234 443\r\n");
        let frame2 = encode_frame(1, FRAME_DATA, b"GET / HTTP/1.1\r\n");
        let frame3 = encode_frame(1, FRAME_CLOSE, &[]);

        let mut data = Vec::new();
        data.extend_from_slice(&frame1);
        data.extend_from_slice(&frame2);
        data.extend_from_slice(&frame3);

        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);

        let f1 = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(f1.stream_id, 1);
        assert_eq!(f1.frame_type, FRAME_OPEN);
        assert!(f1.payload.starts_with(b"PROXY"));

        let f2 = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(f2.frame_type, FRAME_DATA);

        let f3 = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(f3.frame_type, FRAME_CLOSE);
        assert!(f3.payload.is_empty());

        // EOF
        assert!(reader.next_frame().await.unwrap().is_none());
    }

    #[test]
    fn test_encode_frame_config_type() {
        let payload = b"{\"listenPorts\":[443]}";
        let encoded = encode_frame(0, FRAME_CONFIG, payload);
        assert_eq!(encoded[4], FRAME_CONFIG);
        assert_eq!(&encoded[0..4], &0u32.to_be_bytes());
        assert_eq!(&encoded[9..], payload.as_slice());
    }

    #[test]
    fn test_encode_frame_data_back_type() {
        let payload = b"response data";
        let encoded = encode_frame(7, FRAME_DATA_BACK, payload);
        assert_eq!(encoded[4], FRAME_DATA_BACK);
        assert_eq!(&encoded[0..4], &7u32.to_be_bytes());
        assert_eq!(&encoded[5..9], &(payload.len() as u32).to_be_bytes());
        assert_eq!(&encoded[9..], payload.as_slice());
    }

    #[test]
    fn test_encode_frame_close_back_type() {
        let encoded = encode_frame(99, FRAME_CLOSE_BACK, &[]);
        assert_eq!(encoded[4], FRAME_CLOSE_BACK);
        assert_eq!(&encoded[0..4], &99u32.to_be_bytes());
        assert_eq!(&encoded[5..9], &0u32.to_be_bytes());
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE);
    }

    #[test]
    fn test_encode_frame_large_stream_id() {
        let encoded = encode_frame(u32::MAX, FRAME_DATA, b"x");
        assert_eq!(&encoded[0..4], &u32::MAX.to_be_bytes());
        assert_eq!(encoded[4], FRAME_DATA);
        assert_eq!(&encoded[5..9], &1u32.to_be_bytes());
        assert_eq!(encoded[9], b'x');
    }

    #[tokio::test]
    async fn test_frame_reader_max_payload_rejection() {
        let mut data = Vec::new();
        data.extend_from_slice(&1u32.to_be_bytes());
        data.push(FRAME_DATA);
        data.extend_from_slice(&(MAX_PAYLOAD_SIZE + 1).to_be_bytes());

        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);

        let result = reader.next_frame().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_frame_reader_eof_mid_header() {
        // Only 5 bytes — not enough for a 9-byte header
        let data = vec![0u8; 5];
        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);

        // Should return Ok(None) on partial header EOF
        let result = reader.next_frame().await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_frame_reader_eof_mid_payload() {
        // Full header claiming 100 bytes of payload, but only 10 bytes present
        let mut data = Vec::new();
        data.extend_from_slice(&1u32.to_be_bytes());
        data.push(FRAME_DATA);
        data.extend_from_slice(&100u32.to_be_bytes());
        data.extend_from_slice(&[0xAB; 10]);

        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);

        let result = reader.next_frame().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_frame_reader_all_frame_types() {
        // Every FRAME_* constant defined in this module, including the
        // flow-control types.
        let types = [
            FRAME_OPEN,
            FRAME_DATA,
            FRAME_CLOSE,
            FRAME_DATA_BACK,
            FRAME_CLOSE_BACK,
            FRAME_CONFIG,
            FRAME_PING,
            FRAME_PONG,
            FRAME_WINDOW_UPDATE,
            FRAME_WINDOW_UPDATE_BACK,
        ];

        let mut data = Vec::new();
        for (i, &ft) in types.iter().enumerate() {
            let payload = format!("payload_{}", i);
            data.extend_from_slice(&encode_frame(i as u32, ft, payload.as_bytes()));
        }

        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);

        for (i, &ft) in types.iter().enumerate() {
            let frame = reader.next_frame().await.unwrap().unwrap();
            assert_eq!(frame.stream_id, i as u32);
            assert_eq!(frame.frame_type, ft);
            assert_eq!(frame.payload, format!("payload_{}", i).as_bytes());
        }

        assert!(reader.next_frame().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_frame_reader_zero_length_payload() {
        let data = encode_frame(42, FRAME_CLOSE, &[]);
        let cursor = std::io::Cursor::new(data);
        let mut reader = FrameReader::new(cursor);

        let frame = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(frame.stream_id, 42);
        assert_eq!(frame.frame_type, FRAME_CLOSE);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_encode_frame_ping_pong() {
        // PING: stream_id=0, empty payload (control frame)
        let ping = encode_frame(0, FRAME_PING, &[]);
        assert_eq!(ping[4], FRAME_PING);
        assert_eq!(&ping[0..4], &0u32.to_be_bytes());
        assert_eq!(ping.len(), FRAME_HEADER_SIZE);

        // PONG: stream_id=0, empty payload (control frame)
        let pong = encode_frame(0, FRAME_PONG, &[]);
        assert_eq!(pong[4], FRAME_PONG);
        assert_eq!(&pong[0..4], &0u32.to_be_bytes());
        assert_eq!(pong.len(), FRAME_HEADER_SIZE);
    }

    #[test]
    fn test_window_update_encode_decode_roundtrip() {
        let encoded = encode_window_update(3, FRAME_WINDOW_UPDATE, 65_536);
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE + 4);
        assert_eq!(&encoded[0..4], &3u32.to_be_bytes());
        assert_eq!(encoded[4], FRAME_WINDOW_UPDATE);
        assert_eq!(&encoded[5..9], &4u32.to_be_bytes());
        assert_eq!(
            decode_window_update(&encoded[FRAME_HEADER_SIZE..]),
            Some(65_536)
        );
    }

    #[test]
    fn test_decode_window_update_rejects_bad_length() {
        assert_eq!(decode_window_update(&[]), None);
        assert_eq!(decode_window_update(&[0, 0, 1]), None);
        assert_eq!(decode_window_update(&[0, 0, 0, 1, 0]), None);
        assert_eq!(decode_window_update(&[0xFF; 4]), Some(u32::MAX));
    }

    #[test]
    fn test_compute_window_for_stream_count() {
        // Zero or few streams are capped at the initial window.
        assert_eq!(compute_window_for_stream_count(0), INITIAL_STREAM_WINDOW);
        assert_eq!(compute_window_for_stream_count(1), INITIAL_STREAM_WINDOW);
        assert_eq!(compute_window_for_stream_count(8), INITIAL_STREAM_WINDOW);
        // 32 MiB shared by 16 streams = 2 MiB each.
        assert_eq!(compute_window_for_stream_count(16), 2 * 1024 * 1024);
        // Many streams hit the 64 KiB floor.
        assert_eq!(compute_window_for_stream_count(512), 64 * 1024);
        assert_eq!(compute_window_for_stream_count(u32::MAX), 64 * 1024);
    }

    #[tokio::test]
    async fn test_frame_reader_window_update_roundtrip() {
        let data = encode_window_update(5, FRAME_WINDOW_UPDATE_BACK, 1024);
        let mut reader = FrameReader::new(std::io::Cursor::new(data));

        let frame = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(frame.stream_id, 5);
        assert_eq!(frame.frame_type, FRAME_WINDOW_UPDATE_BACK);
        assert_eq!(decode_window_update(&frame.payload), Some(1024));

        assert!(reader.next_frame().await.unwrap().is_none());
    }
}
|