fix(rust-core,protocol): eliminate edge stream registration races and reduce frame buffering copies

2026-03-17 16:37:43 +00:00
parent e8d429f117
commit 156b17135f
8 changed files with 283 additions and 174 deletions


@@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
tokio = { version = "1", features = ["io-util", "sync", "time"] }
tokio-util = "0.7"
bytes = "1"
log = "0.4"
[dev-dependencies]


@@ -2,6 +2,7 @@ use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
// Frame type constants
@@ -32,7 +33,7 @@ pub const WINDOW_UPDATE_THRESHOLD: u32 = INITIAL_STREAM_WINDOW / 2;
pub const MAX_WINDOW_SIZE: u32 = 16 * 1024 * 1024;
/// Encode a WINDOW_UPDATE frame for a specific stream.
pub fn encode_window_update(stream_id: u32, frame_type: u8, increment: u32) -> Vec<u8> {
pub fn encode_window_update(stream_id: u32, frame_type: u8, increment: u32) -> Bytes {
encode_frame(stream_id, frame_type, &increment.to_be_bytes())
}
@@ -45,6 +46,30 @@ pub fn compute_window_for_stream_count(active: u32) -> u32 {
per_stream.clamp(64 * 1024, INITIAL_STREAM_WINDOW as u64) as u32
}
/// Proactively clamp a send_window AtomicU32 down to at most `target`.
/// Uses a CAS loop so that concurrent WINDOW_UPDATE additions are not lost.
/// Returns the value after clamping.
#[inline]
pub fn clamp_send_window(
send_window: &std::sync::atomic::AtomicU32,
target: u32,
) -> u32 {
loop {
let current = send_window.load(std::sync::atomic::Ordering::Acquire);
if current <= target {
return current;
}
match send_window.compare_exchange_weak(
current, target,
std::sync::atomic::Ordering::AcqRel,
std::sync::atomic::Ordering::Relaxed,
) {
Ok(_) => return target,
Err(_) => continue,
}
}
}
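For context on how this new helper is meant to be used: a minimal sketch, assuming a hypothetical per-stream registry, of clamping every open stream when a newly opened stream raises the active count. Only clamp_send_window and compute_window_for_stream_count come from this diff; StreamState and on_stream_opened are illustrative names.

use std::sync::atomic::AtomicU32;

// Hypothetical per-stream state; only the send_window field corresponds to this diff.
struct StreamState {
    send_window: AtomicU32,
}

// Sketch: when another stream opens, shrink every existing send window so the total
// stays within budget. Credits added concurrently by WINDOW_UPDATE (fetch_add) are
// preserved by the CAS loop inside clamp_send_window rather than being overwritten.
fn on_stream_opened(streams: &[StreamState], active: u32) {
    let per_stream = compute_window_for_stream_count(active);
    for s in streams {
        let clamped = clamp_send_window(&s.send_window, per_stream);
        log::trace!("send_window clamped to {}", clamped);
    }
}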
/// Decode a WINDOW_UPDATE payload into a byte increment. Returns None if payload is malformed.
pub fn decode_window_update(payload: &[u8]) -> Option<u32> {
if payload.len() != 4 {
@@ -58,18 +83,18 @@ pub fn decode_window_update(payload: &[u8]) -> Option<u32> {
pub struct Frame {
pub stream_id: u32,
pub frame_type: u8,
pub payload: Vec<u8>,
pub payload: Bytes,
}
/// Encode a frame into bytes: [stream_id:4][type:1][length:4][payload]
pub fn encode_frame(stream_id: u32, frame_type: u8, payload: &[u8]) -> Vec<u8> {
pub fn encode_frame(stream_id: u32, frame_type: u8, payload: &[u8]) -> Bytes {
let len = payload.len() as u32;
let mut buf = Vec::with_capacity(FRAME_HEADER_SIZE + payload.len());
buf.extend_from_slice(&stream_id.to_be_bytes());
buf.push(frame_type);
buf.extend_from_slice(&len.to_be_bytes());
buf.extend_from_slice(payload);
buf
Bytes::from(buf)
}
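As a worked example of the wire layout in the comment above (the concrete values are illustrative, not from the diff): a FRAME_DATA frame on stream 7 carrying the two-byte payload "hi" encodes to a 9-byte big-endian header followed by the payload.

#[test]
fn wire_layout_example() {
    // [stream_id:4][type:1][length:4][payload], integers big-endian.
    let frame = encode_frame(7, FRAME_DATA, b"hi");
    assert_eq!(
        &frame[..],
        &[
            0, 0, 0, 7,  // stream_id = 7
            FRAME_DATA,  // frame type
            0, 0, 0, 2,  // payload length = 2
            b'h', b'i',  // payload
        ],
    );
}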
/// Write a frame header into `buf[0..FRAME_HEADER_SIZE]`.
@@ -152,7 +177,7 @@ impl<R: AsyncRead + Unpin> FrameReader<R> {
Ok(Some(Frame {
stream_id,
frame_type,
payload,
payload: Bytes::from(payload),
}))
}
@@ -186,9 +211,9 @@ pub enum TunnelEvent {
/// Write state extracted into a sub-struct so the borrow checker can see
/// disjoint field access between `self.write` and `self.stream`.
struct WriteState {
ctrl_queue: VecDeque<Vec<u8>>, // PONG, WINDOW_UPDATE, CLOSE, OPEN — always first
data_queue: VecDeque<Vec<u8>>, // DATA, DATA_BACK — only when ctrl is empty
offset: usize, // progress within current frame being written
ctrl_queue: VecDeque<Bytes>, // PONG, WINDOW_UPDATE, CLOSE, OPEN — always first
data_queue: VecDeque<Bytes>, // DATA, DATA_BACK — only when ctrl is empty
offset: usize, // progress within current frame being written
flush_needed: bool,
}
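A minimal sketch of the drain order those queue comments imply, assuming the WriteState definition above; next_frame is a hypothetical helper for illustration, and the actual write path is not shown in this hunk.

impl WriteState {
    // Control frames (PONG, WINDOW_UPDATE, CLOSE, OPEN) always drain before data
    // frames, so flow-control and liveness frames cannot be starved by a DATA backlog.
    fn next_frame(&mut self) -> Option<Bytes> {
        if let Some(frame) = self.ctrl_queue.pop_front() {
            return Some(frame);
        }
        self.data_queue.pop_front()
    }
}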
@@ -206,26 +231,21 @@ impl WriteState {
/// WINDOW_UPDATE starvation that causes flow control deadlocks.
pub struct TunnelIo<S> {
stream: S,
// Read state: accumulate bytes, parse frames incrementally
read_buf: Vec<u8>,
read_pos: usize,
parse_pos: usize,
// Read state: BytesMut accumulates incoming bytes; split_to extracts complete frames without copying.
read_buf: BytesMut,
// Write state: extracted sub-struct for safe disjoint borrows
write: WriteState,
}
impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
pub fn new(stream: S, initial_data: Vec<u8>) -> Self {
let read_pos = initial_data.len();
let mut read_buf = initial_data;
let mut read_buf = BytesMut::from(&initial_data[..]);
if read_buf.capacity() < 65536 {
read_buf.reserve(65536 - read_buf.len());
}
Self {
stream,
read_buf,
read_pos,
parse_pos: 0,
write: WriteState {
ctrl_queue: VecDeque::new(),
data_queue: VecDeque::new(),
@@ -236,41 +256,39 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
}
/// Queue a high-priority control frame (PONG, WINDOW_UPDATE, CLOSE, OPEN).
pub fn queue_ctrl(&mut self, frame: Vec<u8>) {
pub fn queue_ctrl(&mut self, frame: Bytes) {
self.write.ctrl_queue.push_back(frame);
}
/// Queue a lower-priority data frame (DATA, DATA_BACK).
pub fn queue_data(&mut self, frame: Vec<u8>) {
pub fn queue_data(&mut self, frame: Bytes) {
self.write.data_queue.push_back(frame);
}
/// Try to parse a complete frame from the read buffer.
/// Uses a parse_pos cursor to avoid drain() on every frame.
/// Zero-copy: uses BytesMut::split_to to extract frames without allocating.
pub fn try_parse_frame(&mut self) -> Option<Result<Frame, std::io::Error>> {
let available = self.read_pos - self.parse_pos;
if available < FRAME_HEADER_SIZE {
if self.read_buf.len() < FRAME_HEADER_SIZE {
return None;
}
let base = self.parse_pos;
let stream_id = u32::from_be_bytes([
self.read_buf[base], self.read_buf[base + 1],
self.read_buf[base + 2], self.read_buf[base + 3],
self.read_buf[0], self.read_buf[1],
self.read_buf[2], self.read_buf[3],
]);
let frame_type = self.read_buf[base + 4];
let frame_type = self.read_buf[4];
let length = u32::from_be_bytes([
self.read_buf[base + 5], self.read_buf[base + 6],
self.read_buf[base + 7], self.read_buf[base + 8],
self.read_buf[5], self.read_buf[6],
self.read_buf[7], self.read_buf[8],
]);
if length > MAX_PAYLOAD_SIZE {
let header = [
self.read_buf[base], self.read_buf[base + 1],
self.read_buf[base + 2], self.read_buf[base + 3],
self.read_buf[base + 4], self.read_buf[base + 5],
self.read_buf[base + 6], self.read_buf[base + 7],
self.read_buf[base + 8],
self.read_buf[0], self.read_buf[1],
self.read_buf[2], self.read_buf[3],
self.read_buf[4], self.read_buf[5],
self.read_buf[6], self.read_buf[7],
self.read_buf[8],
];
log::error!(
"CORRUPT FRAME HEADER: raw={:02x?} stream_id={} type=0x{:02x} length={}",
@@ -283,19 +301,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
}
let total_frame_size = FRAME_HEADER_SIZE + length as usize;
if available < total_frame_size {
if self.read_buf.len() < total_frame_size {
return None;
}
let payload = self.read_buf[base + FRAME_HEADER_SIZE..base + total_frame_size].to_vec();
self.parse_pos += total_frame_size;
// Compact when parse_pos > half the data to reclaim memory
if self.parse_pos > self.read_pos / 2 && self.parse_pos > 0 {
self.read_buf.drain(..self.parse_pos);
self.read_pos -= self.parse_pos;
self.parse_pos = 0;
}
// Zero-copy extraction: split the frame off the read buffer (O(1) pointer adjustment).
// split_to removes the first total_frame_size bytes from read_buf.
let mut frame_data = self.read_buf.split_to(total_frame_size);
// split_off returns the payload that follows the header; the header bytes remain in frame_data and are dropped. freeze() converts BytesMut → Bytes (O(1)).
let payload = frame_data.split_off(FRAME_HEADER_SIZE).freeze();
Some(Ok(Frame { stream_id, frame_type, payload }))
}
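For readers unfamiliar with the bytes crate, a self-contained sketch of the split_to / split_off / freeze calls used above; the buffer contents are made up, and all three calls adjust offsets and reference counts rather than copying payload bytes.

use bytes::{Bytes, BytesMut};

fn split_example() {
    // 9 bytes standing in for a frame header, 13 bytes of payload, 5 trailing bytes.
    let mut buf = BytesMut::from(&b"HHHHHHHHHpayload-bytes-tail"[..]);
    // split_to(n): returns the first n bytes, leaves the rest in `buf`. O(1).
    let mut frame: BytesMut = buf.split_to(9 + 13);
    // split_off(n): `frame` keeps bytes [0, n), the tail is returned.
    // freeze(): BytesMut -> Bytes, also O(1).
    let payload: Bytes = frame.split_off(9).freeze();
    assert_eq!(&payload[..], b"payload-bytes");
    assert_eq!(&buf[..], b"-tail");
}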
@@ -306,8 +320,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
pub fn poll_step(
&mut self,
cx: &mut Context<'_>,
ctrl_rx: &mut tokio::sync::mpsc::Receiver<Vec<u8>>,
data_rx: &mut tokio::sync::mpsc::Receiver<Vec<u8>>,
ctrl_rx: &mut tokio::sync::mpsc::Receiver<Bytes>,
data_rx: &mut tokio::sync::mpsc::Receiver<Bytes>,
liveness_deadline: &mut Pin<Box<tokio::time::Sleep>>,
cancel_token: &tokio_util::sync::CancellationToken,
) -> Poll<TunnelEvent> {
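A sketch of how a tunnel task might drive this poll_step from async code; only the signature and the TunnelEvent variants used in the hunks below (Frame, Eof, ReadError) come from the diff, while run_tunnel, the 30-second liveness duration, and the event handling are illustrative assumptions.

// Hypothetical driver loop (not part of this commit).
async fn run_tunnel<S>(
    mut io: TunnelIo<S>,
    mut ctrl_rx: tokio::sync::mpsc::Receiver<Bytes>,
    mut data_rx: tokio::sync::mpsc::Receiver<Bytes>,
    cancel: tokio_util::sync::CancellationToken,
) where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let mut liveness = Box::pin(tokio::time::sleep(std::time::Duration::from_secs(30)));
    loop {
        // poll_fn adapts the hand-written poll_step state machine back into async/await.
        let event = std::future::poll_fn(|cx| {
            io.poll_step(cx, &mut ctrl_rx, &mut data_rx, &mut liveness, &cancel)
        })
        .await;
        match event {
            TunnelEvent::Frame(_frame) => { /* dispatch by stream_id */ }
            TunnelEvent::Eof | TunnelEvent::ReadError(_) => return,
            _ => { /* remaining variants (liveness, cancellation) elided */ }
        }
    }
}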
@@ -371,23 +385,18 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
// the waker without re-registering it, causing the task to sleep until a
// timer or channel wakes it (potentially 15+ seconds of lost reads).
loop {
// Compact if needed to make room for reads
if self.parse_pos > 0 && self.read_buf.len() - self.read_pos < 32768 {
self.read_buf.drain(..self.parse_pos);
self.read_pos -= self.parse_pos;
self.parse_pos = 0;
}
if self.read_buf.len() < self.read_pos + 32768 {
self.read_buf.resize(self.read_pos + 32768, 0);
}
let mut rbuf = ReadBuf::new(&mut self.read_buf[self.read_pos..]);
// Grow the buffer by 32 KB of zeroed space for this read
let len_before = self.read_buf.len();
self.read_buf.resize(len_before + 32768, 0);
let mut rbuf = ReadBuf::new(&mut self.read_buf[len_before..]);
match Pin::new(&mut self.stream).poll_read(cx, &mut rbuf) {
Poll::Ready(Ok(())) => {
let n = rbuf.filled().len();
// Trim back to actual data length
self.read_buf.truncate(len_before + n);
if n == 0 {
return Poll::Ready(TunnelEvent::Eof);
}
self.read_pos += n;
if let Some(result) = self.try_parse_frame() {
return match result {
Ok(frame) => Poll::Ready(TunnelEvent::Frame(frame)),
@@ -398,10 +407,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
// waker is re-registered when it finally returns Pending.
}
Poll::Ready(Err(e)) => {
self.read_buf.truncate(len_before);
log::error!("TunnelIo: poll_read error: {}", e);
return Poll::Ready(TunnelEvent::ReadError(e));
}
Poll::Pending => break,
Poll::Pending => {
self.read_buf.truncate(len_before);
break;
}
}
}
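The same grow-read-shrink pattern, pulled out of the poll loop into a self-contained async sketch; the 32 KB chunk size mirrors the value above, everything else is illustrative. Note that resize zero-fills the appended region on every iteration; reading into spare capacity via the BufMut API (e.g. tokio_util's poll_read_buf helper) could avoid that memset, but is not what this diff does.

use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt};

// Append zeroed space, read into it, then truncate so the parser never sees
// leftover zero padding as frame bytes, including on the error path.
async fn read_chunk<R: AsyncRead + Unpin>(
    r: &mut R,
    buf: &mut BytesMut,
) -> std::io::Result<usize> {
    let len_before = buf.len();
    buf.resize(len_before + 32 * 1024, 0);
    match r.read(&mut buf[len_before..]).await {
        Ok(n) => {
            buf.truncate(len_before + n);
            Ok(n)
        }
        Err(e) => {
            buf.truncate(len_before);
            Err(e)
        }
    }
}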
@@ -471,14 +484,14 @@ mod tests {
let mut buf = vec![0u8; FRAME_HEADER_SIZE + payload.len()];
buf[FRAME_HEADER_SIZE..].copy_from_slice(payload);
encode_frame_header(&mut buf, 42, FRAME_DATA, payload.len());
assert_eq!(buf, encode_frame(42, FRAME_DATA, payload));
assert_eq!(buf[..], encode_frame(42, FRAME_DATA, payload)[..]);
}
#[test]
fn test_encode_frame_header_empty_payload() {
let mut buf = vec![0u8; FRAME_HEADER_SIZE];
encode_frame_header(&mut buf, 99, FRAME_CLOSE, 0);
assert_eq!(buf, encode_frame(99, FRAME_CLOSE, &[]));
assert_eq!(buf[..], encode_frame(99, FRAME_CLOSE, &[])[..]);
}
#[test]
@@ -646,7 +659,7 @@ mod tests {
let frame = reader.next_frame().await.unwrap().unwrap();
assert_eq!(frame.stream_id, i as u32);
assert_eq!(frame.frame_type, ft);
assert_eq!(frame.payload, format!("payload_{}", i).as_bytes());
assert_eq!(&frame.payload[..], format!("payload_{}", i).as_bytes());
}
assert!(reader.next_frame().await.unwrap().is_none());
@@ -655,7 +668,7 @@ mod tests {
#[tokio::test]
async fn test_frame_reader_zero_length_payload() {
let data = encode_frame(42, FRAME_CLOSE, &[]);
let cursor = std::io::Cursor::new(data);
let cursor = std::io::Cursor::new(data.to_vec());
let mut reader = FrameReader::new(cursor);
let frame = reader.next_frame().await.unwrap().unwrap();
@@ -783,6 +796,39 @@ mod tests {
}
}
// --- clamp_send_window tests ---
#[test]
fn test_clamp_send_window_reduces_above_target() {
let w = std::sync::atomic::AtomicU32::new(4 * 1024 * 1024); // 4 MB
let result = clamp_send_window(&w, 512 * 1024); // target 512 KB
assert_eq!(result, 512 * 1024);
assert_eq!(w.load(std::sync::atomic::Ordering::Relaxed), 512 * 1024);
}
#[test]
fn test_clamp_send_window_noop_below_target() {
let w = std::sync::atomic::AtomicU32::new(256 * 1024); // 256 KB
let result = clamp_send_window(&w, 512 * 1024); // target 512 KB
assert_eq!(result, 256 * 1024);
assert_eq!(w.load(std::sync::atomic::Ordering::Relaxed), 256 * 1024);
}
#[test]
fn test_clamp_send_window_noop_at_target() {
let w = std::sync::atomic::AtomicU32::new(512 * 1024);
let result = clamp_send_window(&w, 512 * 1024);
assert_eq!(result, 512 * 1024);
assert_eq!(w.load(std::sync::atomic::Ordering::Relaxed), 512 * 1024);
}
#[test]
fn test_clamp_send_window_zero_value() {
let w = std::sync::atomic::AtomicU32::new(0);
let result = clamp_send_window(&w, 64 * 1024);
assert_eq!(result, 0);
}
// --- encode/decode window_update roundtrip ---
#[test]