874 lines
34 KiB
Rust
874 lines
34 KiB
Rust
use std::collections::VecDeque;
|
|
use std::future::Future;
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
use std::time::Duration;
|
|
use bytes::{Bytes, BytesMut, BufMut};
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
|
use tokio::time::Instant;
|
|
|
|
// Frame type constants (byte 4 of every frame header).

/// Open a new multiplexed stream.
pub const FRAME_OPEN: u8 = 0x01;
/// Stream payload.
pub const FRAME_DATA: u8 = 0x02;
/// Close a stream.
pub const FRAME_CLOSE: u8 = 0x03;
/// Stream payload travelling in the reverse direction.
pub const FRAME_DATA_BACK: u8 = 0x04;
/// Stream close travelling in the reverse direction.
pub const FRAME_CLOSE_BACK: u8 = 0x05;
/// Hub -> Edge: configuration update.
pub const FRAME_CONFIG: u8 = 0x06;
/// Hub -> Edge: heartbeat probe.
pub const FRAME_PING: u8 = 0x07;
/// Edge -> Hub: heartbeat response.
pub const FRAME_PONG: u8 = 0x08;
/// Edge -> Hub: per-stream flow control.
pub const FRAME_WINDOW_UPDATE: u8 = 0x09;
/// Hub -> Edge: per-stream flow control.
pub const FRAME_WINDOW_UPDATE_BACK: u8 = 0x0A;
|
|
|
|
// Frame header size: 4 (stream_id) + 1 (type) + 4 (length) = 9 bytes
pub const FRAME_HEADER_SIZE: usize = 4 + 1 + 4;

// Maximum payload size (16 MB)
pub const MAX_PAYLOAD_SIZE: u32 = 16 * 1024 * 1024;

// Per-stream flow control constants

/// Initial (and maximum) per-stream window size (4 MB).
pub const INITIAL_STREAM_WINDOW: u32 = 4 * 1024 * 1024;

/// Send WINDOW_UPDATE after consuming this many bytes (half the initial window).
pub const WINDOW_UPDATE_THRESHOLD: u32 = INITIAL_STREAM_WINDOW / 2;

/// Maximum window size to prevent overflow.
///
/// Defined in terms of `INITIAL_STREAM_WINDOW` (documented as both the initial and
/// the maximum window) so the two values cannot drift apart.
pub const MAX_WINDOW_SIZE: u32 = INITIAL_STREAM_WINDOW;

// Sustained stream classification constants

/// Throughput threshold for sustained classification (2.5 MB/s = 20 Mbit/s).
pub const SUSTAINED_THRESHOLD_BPS: u64 = 2_500_000;

/// Minimum duration before a stream can be classified as sustained.
pub const SUSTAINED_MIN_DURATION_SECS: u64 = 10;

/// Fixed window for sustained streams (1 MB — the floor).
pub const SUSTAINED_WINDOW: u32 = 1024 * 1024;

/// Byte budget for the once-per-second forced drain of the sustained queue
/// (targets a 1 MB/s floor for sustained streams).
pub const SUSTAINED_FORCED_DRAIN_CAP: usize = 1_048_576;
|
|
|
|
/// Encode a WINDOW_UPDATE frame for a specific stream.
|
|
pub fn encode_window_update(stream_id: u32, frame_type: u8, increment: u32) -> Bytes {
|
|
encode_frame(stream_id, frame_type, &increment.to_be_bytes())
|
|
}
|
|
|
|
/// Compute the target per-stream window size based on the number of active streams.
|
|
/// Total memory budget is ~200MB shared across all streams. Up to 50 streams get the
|
|
/// full 4MB window; above that the window scales down to a 1MB floor at 200+ streams.
|
|
pub fn compute_window_for_stream_count(active: u32) -> u32 {
|
|
let per_stream = (200 * 1024 * 1024u64) / (active.max(1) as u64);
|
|
per_stream.clamp(1 * 1024 * 1024, INITIAL_STREAM_WINDOW as u64) as u32
|
|
}
|
|
|
|
/// Parse the big-endian byte increment carried by a WINDOW_UPDATE payload.
///
/// Returns `None` unless `payload` is exactly 4 bytes long.
pub fn decode_window_update(payload: &[u8]) -> Option<u32> {
    match payload {
        &[b0, b1, b2, b3] => Some(u32::from_be_bytes([b0, b1, b2, b3])),
        _ => None,
    }
}
|
|
|
|
/// A single multiplexed frame.
|
|
#[derive(Debug, Clone)]
|
|
pub struct Frame {
|
|
pub stream_id: u32,
|
|
pub frame_type: u8,
|
|
pub payload: Bytes,
|
|
}
|
|
|
|
/// Encode a frame into bytes: [stream_id:4][type:1][length:4][payload]
|
|
pub fn encode_frame(stream_id: u32, frame_type: u8, payload: &[u8]) -> Bytes {
|
|
let len = payload.len() as u32;
|
|
let mut buf = BytesMut::with_capacity(FRAME_HEADER_SIZE + payload.len());
|
|
buf.put_slice(&stream_id.to_be_bytes());
|
|
buf.put_u8(frame_type);
|
|
buf.put_slice(&len.to_be_bytes());
|
|
buf.put_slice(payload);
|
|
buf.freeze()
|
|
}
|
|
|
|
/// Fill `buf[0..9]` with a frame header; bytes after the header are left untouched.
///
/// Intended for zero-copy encoding: the caller reads the payload directly into
/// `buf[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len]`, then calls this to
/// prepend the header without copying the payload.
///
/// # Panics
/// Panics if `buf` is shorter than 9 bytes.
pub fn encode_frame_header(buf: &mut [u8], stream_id: u32, frame_type: u8, payload_len: usize) {
    let (id_field, rest) = buf.split_at_mut(4);
    id_field.copy_from_slice(&stream_id.to_be_bytes());
    rest[0] = frame_type;
    rest[1..5].copy_from_slice(&(payload_len as u32).to_be_bytes());
}
|
|
|
|
/// Build a PROXY protocol v1 header line.
///
/// Format: `PROXY TCP4 <client_ip> <edge_ip> <client_port> <dest_port>\r\n`
///
/// The address strings are inserted verbatim; no validation is performed.
/// NOTE(review): the protocol token is always `TCP4`. If either address can be IPv6,
/// the line would need `TCP6` to be spec-conformant. Confirm callers only pass IPv4.
pub fn build_proxy_v1_header(
    client_ip: &str,
    edge_ip: &str,
    client_port: u16,
    dest_port: u16,
) -> String {
    format!(
        "PROXY TCP4 {src} {dst} {sport} {dport}\r\n",
        src = client_ip,
        dst = edge_ip,
        sport = client_port,
        dport = dest_port,
    )
}
|
|
|
|
/// Stateful async frame reader that yields `Frame` values from an `AsyncRead`.
|
|
pub struct FrameReader<R> {
|
|
reader: R,
|
|
header_buf: [u8; FRAME_HEADER_SIZE],
|
|
}
|
|
|
|
impl<R: AsyncRead + Unpin> FrameReader<R> {
|
|
pub fn new(reader: R) -> Self {
|
|
Self {
|
|
reader,
|
|
header_buf: [0u8; FRAME_HEADER_SIZE],
|
|
}
|
|
}
|
|
|
|
/// Read the next frame. Returns `None` on EOF, `Err` on protocol violation.
|
|
pub async fn next_frame(&mut self) -> Result<Option<Frame>, std::io::Error> {
|
|
// Read header
|
|
match self.reader.read_exact(&mut self.header_buf).await {
|
|
Ok(_) => {}
|
|
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
|
Err(e) => return Err(e),
|
|
}
|
|
|
|
let stream_id = u32::from_be_bytes([
|
|
self.header_buf[0],
|
|
self.header_buf[1],
|
|
self.header_buf[2],
|
|
self.header_buf[3],
|
|
]);
|
|
let frame_type = self.header_buf[4];
|
|
let length = u32::from_be_bytes([
|
|
self.header_buf[5],
|
|
self.header_buf[6],
|
|
self.header_buf[7],
|
|
self.header_buf[8],
|
|
]);
|
|
|
|
if length > MAX_PAYLOAD_SIZE {
|
|
log::error!(
|
|
"CORRUPT FRAME HEADER: raw={:02x?} stream_id={} type=0x{:02x} length={}",
|
|
self.header_buf, stream_id, frame_type, length
|
|
);
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidData,
|
|
format!("frame payload too large: {} bytes (header={:02x?})", length, self.header_buf),
|
|
));
|
|
}
|
|
|
|
let mut payload = BytesMut::zeroed(length as usize);
|
|
if length > 0 {
|
|
self.reader.read_exact(&mut payload).await?;
|
|
}
|
|
|
|
Ok(Some(Frame {
|
|
stream_id,
|
|
frame_type,
|
|
payload: payload.freeze(),
|
|
}))
|
|
}
|
|
|
|
/// Consume the reader and return the inner stream.
|
|
pub fn into_inner(self) -> R {
|
|
self.reader
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TunnelIo: single-owner I/O multiplexer for the TLS tunnel connection
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Events produced by the TunnelIo event loop.
|
|
#[derive(Debug)]
|
|
pub enum TunnelEvent {
|
|
/// A complete frame was read from the remote side.
|
|
Frame(Frame),
|
|
/// The remote side closed the connection (EOF).
|
|
Eof,
|
|
/// A read error occurred.
|
|
ReadError(std::io::Error),
|
|
/// A write error occurred.
|
|
WriteError(std::io::Error),
|
|
/// No frames received for the liveness timeout duration.
|
|
LivenessTimeout,
|
|
/// The cancellation token was triggered.
|
|
Cancelled,
|
|
}
|
|
|
|
/// Write state extracted into a sub-struct so the borrow checker can see
|
|
/// disjoint field access between `self.write` and `self.stream`.
|
|
struct WriteState {
|
|
ctrl_queue: VecDeque<Bytes>, // PONG, WINDOW_UPDATE, CLOSE, OPEN — always first
|
|
data_queue: VecDeque<Bytes>, // DATA, DATA_BACK — only when ctrl is empty
|
|
sustained_queue: VecDeque<Bytes>, // DATA, DATA_BACK from sustained streams — lowest priority
|
|
offset: usize, // progress within current frame being written
|
|
flush_needed: bool,
|
|
// Sustained starvation prevention: guaranteed 1 MB/s drain
|
|
sustained_last_drain: Instant,
|
|
sustained_bytes_this_period: usize,
|
|
}
|
|
|
|
impl WriteState {
|
|
fn has_work(&self) -> bool {
|
|
!self.ctrl_queue.is_empty() || !self.data_queue.is_empty() || !self.sustained_queue.is_empty()
|
|
}
|
|
}
|
|
|
|
/// Single-owner I/O engine for the tunnel TLS connection.
|
|
///
|
|
/// Owns the TLS stream directly — no `tokio::io::split()`, no mutex.
|
|
/// Uses three priority write queues:
|
|
/// 1. ctrl (PONG, WINDOW_UPDATE, CLOSE, OPEN) — always first
|
|
/// 2. data (DATA, DATA_BACK from normal streams) — when ctrl empty
|
|
/// 3. sustained (DATA, DATA_BACK from sustained streams) — lowest priority,
|
|
/// drained freely when ctrl+data empty, or forced 1MB/s when they're not
|
|
pub struct TunnelIo<S> {
|
|
stream: S,
|
|
// Read state: accumulate bytes, parse frames incrementally
|
|
read_buf: Vec<u8>,
|
|
read_pos: usize,
|
|
parse_pos: usize,
|
|
// Write state: extracted sub-struct for safe disjoint borrows
|
|
write: WriteState,
|
|
}
|
|
|
|
impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
|
|
pub fn new(stream: S, initial_data: Vec<u8>) -> Self {
|
|
let read_pos = initial_data.len();
|
|
let mut read_buf = initial_data;
|
|
if read_buf.capacity() < 65536 {
|
|
read_buf.reserve(65536 - read_buf.len());
|
|
}
|
|
Self {
|
|
stream,
|
|
read_buf,
|
|
read_pos,
|
|
parse_pos: 0,
|
|
write: WriteState {
|
|
ctrl_queue: VecDeque::new(),
|
|
data_queue: VecDeque::new(),
|
|
sustained_queue: VecDeque::new(),
|
|
offset: 0,
|
|
flush_needed: false,
|
|
sustained_last_drain: Instant::now(),
|
|
sustained_bytes_this_period: 0,
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Queue a high-priority control frame (PONG, WINDOW_UPDATE, CLOSE, OPEN).
|
|
pub fn queue_ctrl(&mut self, frame: Bytes) {
|
|
self.write.ctrl_queue.push_back(frame);
|
|
}
|
|
|
|
/// Queue a lower-priority data frame (DATA, DATA_BACK).
|
|
pub fn queue_data(&mut self, frame: Bytes) {
|
|
self.write.data_queue.push_back(frame);
|
|
}
|
|
|
|
/// Queue a lowest-priority sustained data frame.
|
|
pub fn queue_sustained(&mut self, frame: Bytes) {
|
|
self.write.sustained_queue.push_back(frame);
|
|
}
|
|
|
|
/// Try to parse a complete frame from the read buffer.
|
|
/// Uses a parse_pos cursor to avoid drain() on every frame.
|
|
pub fn try_parse_frame(&mut self) -> Option<Result<Frame, std::io::Error>> {
|
|
let available = self.read_pos - self.parse_pos;
|
|
if available < FRAME_HEADER_SIZE {
|
|
return None;
|
|
}
|
|
|
|
let base = self.parse_pos;
|
|
let stream_id = u32::from_be_bytes([
|
|
self.read_buf[base], self.read_buf[base + 1],
|
|
self.read_buf[base + 2], self.read_buf[base + 3],
|
|
]);
|
|
let frame_type = self.read_buf[base + 4];
|
|
let length = u32::from_be_bytes([
|
|
self.read_buf[base + 5], self.read_buf[base + 6],
|
|
self.read_buf[base + 7], self.read_buf[base + 8],
|
|
]);
|
|
|
|
if length > MAX_PAYLOAD_SIZE {
|
|
let header = [
|
|
self.read_buf[base], self.read_buf[base + 1],
|
|
self.read_buf[base + 2], self.read_buf[base + 3],
|
|
self.read_buf[base + 4], self.read_buf[base + 5],
|
|
self.read_buf[base + 6], self.read_buf[base + 7],
|
|
self.read_buf[base + 8],
|
|
];
|
|
log::error!(
|
|
"CORRUPT FRAME HEADER: raw={:02x?} stream_id={} type=0x{:02x} length={}",
|
|
header, stream_id, frame_type, length
|
|
);
|
|
return Some(Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidData,
|
|
format!("frame payload too large: {} bytes (header={:02x?})", length, header),
|
|
)));
|
|
}
|
|
|
|
let total_frame_size = FRAME_HEADER_SIZE + length as usize;
|
|
if available < total_frame_size {
|
|
return None;
|
|
}
|
|
|
|
let payload = Bytes::copy_from_slice(
|
|
&self.read_buf[base + FRAME_HEADER_SIZE..base + total_frame_size],
|
|
);
|
|
self.parse_pos += total_frame_size;
|
|
|
|
// Compact when parse_pos > half the data to reclaim memory
|
|
if self.parse_pos > self.read_pos / 2 && self.parse_pos > 0 {
|
|
self.read_buf.drain(..self.parse_pos);
|
|
self.read_pos -= self.parse_pos;
|
|
self.parse_pos = 0;
|
|
}
|
|
|
|
Some(Ok(Frame { stream_id, frame_type, payload }))
|
|
}
|
|
|
|
/// Poll-based I/O step. Returns Ready on events, Pending when idle.
|
|
///
|
|
/// Order: write(ctrl->data->sustained) -> flush -> read -> channels -> timers
|
|
pub fn poll_step(
|
|
&mut self,
|
|
cx: &mut Context<'_>,
|
|
ctrl_rx: &mut tokio::sync::mpsc::Receiver<Bytes>,
|
|
data_rx: &mut tokio::sync::mpsc::Receiver<Bytes>,
|
|
sustained_rx: &mut tokio::sync::mpsc::Receiver<Bytes>,
|
|
liveness_deadline: &mut Pin<Box<tokio::time::Sleep>>,
|
|
cancel_token: &tokio_util::sync::CancellationToken,
|
|
) -> Poll<TunnelEvent> {
|
|
// 1. WRITE: 3-tier priority — ctrl first, then data, then sustained.
|
|
// Sustained drains freely when ctrl+data are empty.
|
|
// Write one frame, set flush_needed, then flush must complete before
|
|
// writing more. This prevents unbounded TLS session buffer growth.
|
|
// Safe: `self.write` and `self.stream` are disjoint fields.
|
|
let mut writes = 0;
|
|
while self.write.has_work() && writes < 16 && !self.write.flush_needed {
|
|
// Pick queue: ctrl > data > sustained
|
|
let queue_id = if !self.write.ctrl_queue.is_empty() {
|
|
0 // ctrl
|
|
} else if !self.write.data_queue.is_empty() {
|
|
1 // data
|
|
} else {
|
|
2 // sustained
|
|
};
|
|
let frame = match queue_id {
|
|
0 => self.write.ctrl_queue.front().unwrap(),
|
|
1 => self.write.data_queue.front().unwrap(),
|
|
_ => self.write.sustained_queue.front().unwrap(),
|
|
};
|
|
let remaining = &frame[self.write.offset..];
|
|
|
|
match Pin::new(&mut self.stream).poll_write(cx, remaining) {
|
|
Poll::Ready(Ok(0)) => {
|
|
log::error!("TunnelIo: poll_write returned 0 (write zero), ctrl_q={} data_q={} sustained_q={}",
|
|
self.write.ctrl_queue.len(), self.write.data_queue.len(), self.write.sustained_queue.len());
|
|
return Poll::Ready(TunnelEvent::WriteError(
|
|
std::io::Error::new(std::io::ErrorKind::WriteZero, "write zero"),
|
|
));
|
|
}
|
|
Poll::Ready(Ok(n)) => {
|
|
self.write.offset += n;
|
|
self.write.flush_needed = true;
|
|
if self.write.offset >= frame.len() {
|
|
match queue_id {
|
|
0 => { self.write.ctrl_queue.pop_front(); }
|
|
1 => { self.write.data_queue.pop_front(); }
|
|
_ => {
|
|
self.write.sustained_queue.pop_front();
|
|
self.write.sustained_last_drain = Instant::now();
|
|
self.write.sustained_bytes_this_period = 0;
|
|
}
|
|
}
|
|
self.write.offset = 0;
|
|
writes += 1;
|
|
}
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
log::error!("TunnelIo: poll_write error: {} (ctrl_q={} data_q={} sustained_q={})",
|
|
e, self.write.ctrl_queue.len(), self.write.data_queue.len(), self.write.sustained_queue.len());
|
|
return Poll::Ready(TunnelEvent::WriteError(e));
|
|
}
|
|
Poll::Pending => break,
|
|
}
|
|
}
|
|
|
|
// 1b. FORCED SUSTAINED DRAIN: when ctrl/data have work but sustained is waiting,
|
|
// guarantee at least 1 MB/s by draining up to SUSTAINED_FORCED_DRAIN_CAP
|
|
// once per second.
|
|
if !self.write.sustained_queue.is_empty()
|
|
&& (!self.write.ctrl_queue.is_empty() || !self.write.data_queue.is_empty())
|
|
&& !self.write.flush_needed
|
|
{
|
|
let now = Instant::now();
|
|
if now.duration_since(self.write.sustained_last_drain) >= Duration::from_secs(1) {
|
|
self.write.sustained_bytes_this_period = 0;
|
|
self.write.sustained_last_drain = now;
|
|
|
|
while !self.write.sustained_queue.is_empty()
|
|
&& self.write.sustained_bytes_this_period < SUSTAINED_FORCED_DRAIN_CAP
|
|
&& !self.write.flush_needed
|
|
{
|
|
let frame = self.write.sustained_queue.front().unwrap();
|
|
let remaining = &frame[self.write.offset..];
|
|
match Pin::new(&mut self.stream).poll_write(cx, remaining) {
|
|
Poll::Ready(Ok(0)) => {
|
|
return Poll::Ready(TunnelEvent::WriteError(
|
|
std::io::Error::new(std::io::ErrorKind::WriteZero, "write zero"),
|
|
));
|
|
}
|
|
Poll::Ready(Ok(n)) => {
|
|
self.write.offset += n;
|
|
self.write.flush_needed = true;
|
|
self.write.sustained_bytes_this_period += n;
|
|
if self.write.offset >= frame.len() {
|
|
self.write.sustained_queue.pop_front();
|
|
self.write.offset = 0;
|
|
}
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
return Poll::Ready(TunnelEvent::WriteError(e));
|
|
}
|
|
Poll::Pending => break,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 2. FLUSH: push encrypted data from TLS session to TCP.
|
|
if self.write.flush_needed {
|
|
match Pin::new(&mut self.stream).poll_flush(cx) {
|
|
Poll::Ready(Ok(())) => {
|
|
self.write.flush_needed = false;
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
log::error!("TunnelIo: poll_flush error: {}", e);
|
|
return Poll::Ready(TunnelEvent::WriteError(e));
|
|
}
|
|
Poll::Pending => {} // TCP waker will notify us
|
|
}
|
|
}
|
|
|
|
// 3. READ: drain stream until Pending to ensure the TCP waker is always registered.
|
|
// Without this loop, a Ready return with partial frame data would consume
|
|
// the waker without re-registering it, causing the task to sleep until a
|
|
// timer or channel wakes it (potentially 15+ seconds of lost reads).
|
|
loop {
|
|
// Compact if needed to make room for reads
|
|
if self.parse_pos > 0 && self.read_buf.len() - self.read_pos < 32768 {
|
|
self.read_buf.drain(..self.parse_pos);
|
|
self.read_pos -= self.parse_pos;
|
|
self.parse_pos = 0;
|
|
}
|
|
if self.read_buf.len() < self.read_pos + 32768 {
|
|
self.read_buf.resize(self.read_pos + 32768, 0);
|
|
}
|
|
let mut rbuf = ReadBuf::new(&mut self.read_buf[self.read_pos..]);
|
|
match Pin::new(&mut self.stream).poll_read(cx, &mut rbuf) {
|
|
Poll::Ready(Ok(())) => {
|
|
let n = rbuf.filled().len();
|
|
if n == 0 {
|
|
return Poll::Ready(TunnelEvent::Eof);
|
|
}
|
|
self.read_pos += n;
|
|
if let Some(result) = self.try_parse_frame() {
|
|
return match result {
|
|
Ok(frame) => Poll::Ready(TunnelEvent::Frame(frame)),
|
|
Err(e) => Poll::Ready(TunnelEvent::ReadError(e)),
|
|
};
|
|
}
|
|
// Partial data — loop to call poll_read again so the TCP
|
|
// waker is re-registered when it finally returns Pending.
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
log::error!("TunnelIo: poll_read error: {}", e);
|
|
return Poll::Ready(TunnelEvent::ReadError(e));
|
|
}
|
|
Poll::Pending => break,
|
|
}
|
|
}
|
|
|
|
// 4. CHANNELS: drain ctrl (always — priority), data (only if queue is small).
|
|
// Ctrl frames must never be delayed — always drain fully.
|
|
// Data frames are gated: keep data in the bounded channel for proper
|
|
// backpressure when TLS writes are slow. Without this gate, the internal
|
|
// data_queue (unbounded VecDeque) grows to hundreds of MB under throttle -> OOM.
|
|
let mut got_new = false;
|
|
loop {
|
|
match ctrl_rx.poll_recv(cx) {
|
|
Poll::Ready(Some(frame)) => { self.write.ctrl_queue.push_back(frame); got_new = true; }
|
|
Poll::Ready(None) => {
|
|
return Poll::Ready(TunnelEvent::WriteError(
|
|
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "ctrl channel closed"),
|
|
));
|
|
}
|
|
Poll::Pending => break,
|
|
}
|
|
}
|
|
if self.write.data_queue.len() < 64 {
|
|
loop {
|
|
match data_rx.poll_recv(cx) {
|
|
Poll::Ready(Some(frame)) => { self.write.data_queue.push_back(frame); got_new = true; }
|
|
Poll::Ready(None) => {
|
|
return Poll::Ready(TunnelEvent::WriteError(
|
|
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "data channel closed"),
|
|
));
|
|
}
|
|
Poll::Pending => break,
|
|
}
|
|
}
|
|
}
|
|
// Sustained channel: drain when sustained_queue is small (same backpressure pattern).
|
|
// Channel close is non-fatal — not all connections have sustained streams.
|
|
if self.write.sustained_queue.len() < 64 {
|
|
loop {
|
|
match sustained_rx.poll_recv(cx) {
|
|
Poll::Ready(Some(frame)) => { self.write.sustained_queue.push_back(frame); got_new = true; }
|
|
Poll::Ready(None) | Poll::Pending => break,
|
|
}
|
|
}
|
|
}
|
|
|
|
// 5. TIMERS
|
|
if liveness_deadline.as_mut().poll(cx).is_ready() {
|
|
return Poll::Ready(TunnelEvent::LivenessTimeout);
|
|
}
|
|
if cancel_token.is_cancelled() {
|
|
return Poll::Ready(TunnelEvent::Cancelled);
|
|
}
|
|
|
|
// 6. SELF-WAKE: only when flush is complete AND we have work.
|
|
// When flush is Pending, the TCP write-readiness waker will notify us.
|
|
// CRITICAL: do NOT self-wake when flush_needed — poll_write always returns
|
|
// Ready (TLS buffers in-memory), so self-waking causes a tight spin loop
|
|
// that fills the TLS session buffer unboundedly -> OOM -> ECONNRESET.
|
|
if !self.write.flush_needed && (got_new || self.write.has_work()) {
|
|
cx.waker().wake_by_ref();
|
|
}
|
|
|
|
Poll::Pending
|
|
}
|
|
|
|
pub fn into_inner(self) -> S {
|
|
self.stream
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a raw 9-byte header with an arbitrary (possibly invalid) length field.
    fn raw_header(stream_id: u32, frame_type: u8, length: u32) -> Vec<u8> {
        let mut header = Vec::with_capacity(FRAME_HEADER_SIZE);
        header.extend_from_slice(&stream_id.to_be_bytes());
        header.push(frame_type);
        header.extend_from_slice(&length.to_be_bytes());
        header
    }

    #[test]
    fn test_encode_frame_header() {
        let payload = b"hello";
        let mut buf = vec![0u8; FRAME_HEADER_SIZE];
        buf.extend_from_slice(payload);
        encode_frame_header(&mut buf, 42, FRAME_DATA, payload.len());
        let expected = encode_frame(42, FRAME_DATA, payload);
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn test_encode_frame_header_empty_payload() {
        let mut buf = [0u8; FRAME_HEADER_SIZE];
        encode_frame_header(&mut buf, 99, FRAME_CLOSE, 0);
        assert_eq!(&buf[..], &encode_frame(99, FRAME_CLOSE, &[])[..]);
    }

    #[test]
    fn test_encode_frame() {
        let data = b"hello";
        let encoded = encode_frame(42, FRAME_DATA, data);
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE + data.len());
        let (header, body) = encoded.split_at(FRAME_HEADER_SIZE);
        assert_eq!(&header[0..4], &42u32.to_be_bytes()); // stream_id, big-endian
        assert_eq!(header[4], FRAME_DATA); // frame type
        assert_eq!(&header[5..9], &5u32.to_be_bytes()); // payload length
        assert_eq!(body, b"hello"); // payload
    }

    #[test]
    fn test_encode_empty_frame() {
        let encoded = encode_frame(1, FRAME_CLOSE, &[]);
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE);
        assert_eq!(&encoded[5..9], &0u32.to_be_bytes());
    }

    #[test]
    fn test_proxy_v1_header() {
        assert_eq!(
            build_proxy_v1_header("1.2.3.4", "5.6.7.8", 12345, 443),
            "PROXY TCP4 1.2.3.4 5.6.7.8 12345 443\r\n"
        );
    }

    #[tokio::test]
    async fn test_frame_reader() {
        let frames = [
            encode_frame(1, FRAME_OPEN, b"PROXY TCP4 1.2.3.4 5.6.7.8 1234 443\r\n"),
            encode_frame(1, FRAME_DATA, b"GET / HTTP/1.1\r\n"),
            encode_frame(1, FRAME_CLOSE, &[]),
        ];
        let data: Vec<u8> = frames.iter().flat_map(|f| f.iter().copied()).collect();
        let mut reader = FrameReader::new(std::io::Cursor::new(data));

        let open = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(open.stream_id, 1);
        assert_eq!(open.frame_type, FRAME_OPEN);
        assert!(open.payload.starts_with(b"PROXY"));

        let body = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(body.frame_type, FRAME_DATA);

        let close = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(close.frame_type, FRAME_CLOSE);
        assert!(close.payload.is_empty());

        // Clean EOF after the last frame.
        assert!(reader.next_frame().await.unwrap().is_none());
    }

    #[test]
    fn test_encode_frame_config_type() {
        let payload = b"{\"listenPorts\":[443]}";
        let encoded = encode_frame(0, FRAME_CONFIG, payload);
        assert_eq!(encoded[4], FRAME_CONFIG);
        assert_eq!(&encoded[0..4], &0u32.to_be_bytes());
        assert_eq!(&encoded[FRAME_HEADER_SIZE..], payload.as_slice());
    }

    #[test]
    fn test_encode_frame_data_back_type() {
        let payload = b"response data";
        let encoded = encode_frame(7, FRAME_DATA_BACK, payload);
        assert_eq!(encoded[4], FRAME_DATA_BACK);
        assert_eq!(&encoded[0..4], &7u32.to_be_bytes());
        assert_eq!(&encoded[5..9], &(payload.len() as u32).to_be_bytes());
        assert_eq!(&encoded[FRAME_HEADER_SIZE..], payload.as_slice());
    }

    #[test]
    fn test_encode_frame_close_back_type() {
        let encoded = encode_frame(99, FRAME_CLOSE_BACK, &[]);
        assert_eq!(encoded[4], FRAME_CLOSE_BACK);
        assert_eq!(&encoded[0..4], &99u32.to_be_bytes());
        assert_eq!(&encoded[5..9], &0u32.to_be_bytes());
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE);
    }

    #[test]
    fn test_encode_frame_large_stream_id() {
        let encoded = encode_frame(u32::MAX, FRAME_DATA, b"x");
        assert_eq!(&encoded[0..4], &u32::MAX.to_be_bytes());
        assert_eq!(encoded[4], FRAME_DATA);
        assert_eq!(&encoded[5..9], &1u32.to_be_bytes());
        assert_eq!(encoded[9], b'x');
    }

    #[tokio::test]
    async fn test_frame_reader_max_payload_rejection() {
        let data = raw_header(1, FRAME_DATA, MAX_PAYLOAD_SIZE + 1);
        let mut reader = FrameReader::new(std::io::Cursor::new(data));

        let result = reader.next_frame().await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_frame_reader_eof_mid_header() {
        // Only 5 bytes — not enough for a 9-byte header; treated as a clean EOF.
        let mut reader = FrameReader::new(std::io::Cursor::new(vec![0u8; 5]));
        assert!(reader.next_frame().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_frame_reader_eof_mid_payload() {
        // Header claims 100 payload bytes, but only 10 follow.
        let mut data = raw_header(1, FRAME_DATA, 100);
        data.extend_from_slice(&[0xAB; 10]);
        let mut reader = FrameReader::new(std::io::Cursor::new(data));

        assert!(reader.next_frame().await.is_err());
    }

    #[tokio::test]
    async fn test_frame_reader_all_frame_types() {
        let types = [
            FRAME_OPEN,
            FRAME_DATA,
            FRAME_CLOSE,
            FRAME_DATA_BACK,
            FRAME_CLOSE_BACK,
            FRAME_CONFIG,
            FRAME_PING,
            FRAME_PONG,
        ];

        let data: Vec<u8> = types
            .iter()
            .enumerate()
            .flat_map(|(i, &ft)| encode_frame(i as u32, ft, format!("payload_{}", i).as_bytes()).to_vec())
            .collect();
        let mut reader = FrameReader::new(std::io::Cursor::new(data));

        for (i, &ft) in types.iter().enumerate() {
            let frame = reader.next_frame().await.unwrap().unwrap();
            assert_eq!(frame.stream_id, i as u32);
            assert_eq!(frame.frame_type, ft);
            assert_eq!(&frame.payload[..], format!("payload_{}", i).as_bytes());
        }

        assert!(reader.next_frame().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_frame_reader_zero_length_payload() {
        let data = encode_frame(42, FRAME_CLOSE, &[]).to_vec();
        let mut reader = FrameReader::new(std::io::Cursor::new(data));

        let frame = reader.next_frame().await.unwrap().unwrap();
        assert_eq!(frame.stream_id, 42);
        assert_eq!(frame.frame_type, FRAME_CLOSE);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_encode_frame_ping_pong() {
        // PING and PONG are control frames: stream_id=0, empty payload.
        for &ft in &[FRAME_PING, FRAME_PONG] {
            let frame = encode_frame(0, ft, &[]);
            assert_eq!(frame[4], ft);
            assert_eq!(&frame[0..4], &0u32.to_be_bytes());
            assert_eq!(frame.len(), FRAME_HEADER_SIZE);
        }
    }

    // --- compute_window_for_stream_count tests ---

    const MB: u32 = 1024 * 1024;

    #[test]
    fn test_adaptive_window_zero_streams() {
        // 0 streams treated as 1: 200MB/1 -> clamped to 4MB max
        assert_eq!(compute_window_for_stream_count(0), INITIAL_STREAM_WINDOW);
    }

    #[test]
    fn test_adaptive_window_one_stream() {
        assert_eq!(compute_window_for_stream_count(1), INITIAL_STREAM_WINDOW);
    }

    #[test]
    fn test_adaptive_window_50_streams_full() {
        // 200MB/50 = 4MB = exactly INITIAL_STREAM_WINDOW
        assert_eq!(compute_window_for_stream_count(50), INITIAL_STREAM_WINDOW);
    }

    #[test]
    fn test_adaptive_window_51_streams_starts_scaling() {
        // 200MB/51 < 4MB — first value below max
        let window = compute_window_for_stream_count(51);
        assert!(window < INITIAL_STREAM_WINDOW);
        assert_eq!(window, (200 * 1024 * 1024u64 / 51) as u32);
    }

    #[test]
    fn test_adaptive_window_100_streams() {
        // 200MB/100 = 2MB
        assert_eq!(compute_window_for_stream_count(100), 2 * MB);
    }

    #[test]
    fn test_adaptive_window_200_streams_at_floor() {
        // 200MB/200 = 1MB = exactly the floor
        assert_eq!(compute_window_for_stream_count(200), MB);
    }

    #[test]
    fn test_adaptive_window_500_streams_clamped() {
        // 200MB/500 = 0.4MB -> clamped up to 1MB floor
        assert_eq!(compute_window_for_stream_count(500), MB);
    }

    #[test]
    fn test_adaptive_window_max_u32() {
        // Extreme: u32::MAX streams -> tiny value -> clamped to 1MB
        assert_eq!(compute_window_for_stream_count(u32::MAX), MB);
    }

    #[test]
    fn test_adaptive_window_monotonically_decreasing() {
        let counts = [1u32, 2, 10, 50, 51, 100, 200, 500, 1000];
        let windows: Vec<u32> = counts.iter().map(|&n| compute_window_for_stream_count(n)).collect();
        for (pair, n) in windows.windows(2).zip(&counts[1..]) {
            assert!(pair[1] <= pair[0], "window increased from {} to {} at n={}", pair[0], pair[1], n);
        }
    }

    #[test]
    fn test_adaptive_window_total_budget_bounded() {
        // active x per_stream_window should never exceed 200MB (+ clamp overhead for high N)
        for &n in &[1u32, 10, 50, 100, 200] {
            let total = u64::from(compute_window_for_stream_count(n)) * u64::from(n);
            assert!(total <= 200 * 1024 * 1024, "total {}MB exceeds budget at n={}", total / (1024 * 1024), n);
        }
    }

    // --- encode/decode window_update roundtrip ---

    #[test]
    fn test_window_update_roundtrip() {
        for &increment in &[0u32, 1, 64 * 1024, INITIAL_STREAM_WINDOW, MAX_WINDOW_SIZE, u32::MAX] {
            let frame = encode_window_update(42, FRAME_WINDOW_UPDATE, increment);
            assert_eq!(frame[4], FRAME_WINDOW_UPDATE);
            assert_eq!(decode_window_update(&frame[FRAME_HEADER_SIZE..]), Some(increment));
        }
    }

    #[test]
    fn test_window_update_back_roundtrip() {
        let frame = encode_window_update(7, FRAME_WINDOW_UPDATE_BACK, 1234567);
        assert_eq!(frame[4], FRAME_WINDOW_UPDATE_BACK);
        assert_eq!(decode_window_update(&frame[FRAME_HEADER_SIZE..]), Some(1234567));
    }

    #[test]
    fn test_decode_window_update_malformed() {
        for bad in [&[][..], &[0, 0, 0][..], &[0, 0, 0, 0, 0][..]].iter() {
            assert_eq!(decode_window_update(bad), None);
        }
    }
}
|