fix(rust-core,protocol): eliminate edge stream registration races and reduce frame buffering copies
This commit is contained in:
@@ -6,6 +6,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["io-util", "sync", "time"] }
|
||||
tokio-util = "0.7"
|
||||
bytes = "1"
|
||||
log = "0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::collections::VecDeque;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
|
||||
// Frame type constants
|
||||
@@ -32,7 +33,7 @@ pub const WINDOW_UPDATE_THRESHOLD: u32 = INITIAL_STREAM_WINDOW / 2;
|
||||
pub const MAX_WINDOW_SIZE: u32 = 16 * 1024 * 1024;
|
||||
|
||||
/// Encode a WINDOW_UPDATE frame for a specific stream.
|
||||
pub fn encode_window_update(stream_id: u32, frame_type: u8, increment: u32) -> Vec<u8> {
|
||||
pub fn encode_window_update(stream_id: u32, frame_type: u8, increment: u32) -> Bytes {
|
||||
encode_frame(stream_id, frame_type, &increment.to_be_bytes())
|
||||
}
|
||||
|
||||
@@ -45,6 +46,30 @@ pub fn compute_window_for_stream_count(active: u32) -> u32 {
|
||||
per_stream.clamp(64 * 1024, INITIAL_STREAM_WINDOW as u64) as u32
|
||||
}
|
||||
|
||||
/// Proactively clamp a send window down to at most `target`.
///
/// The clamp is a single atomic `fetch_min`, i.e. one indivisible
/// read-modify-write on `send_window`. A concurrent atomic update of the same
/// value is therefore ordered entirely before or entirely after the clamp and
/// cannot be overwritten by it.
///
/// Returns the value after clamping: `target` if the window was above it,
/// otherwise the (unchanged) current value.
#[inline]
pub fn clamp_send_window(
    send_window: &std::sync::atomic::AtomicU32,
    target: u32,
) -> u32 {
    // `fetch_min` stores min(previous, target) and returns `previous`, so the
    // value now held by the atomic is min(previous, target).
    send_window
        .fetch_min(target, std::sync::atomic::Ordering::AcqRel)
        .min(target)
}
|
||||
|
||||
/// Decode a WINDOW_UPDATE payload into a byte increment. Returns None if payload is malformed.
|
||||
pub fn decode_window_update(payload: &[u8]) -> Option<u32> {
|
||||
if payload.len() != 4 {
|
||||
@@ -58,18 +83,18 @@ pub fn decode_window_update(payload: &[u8]) -> Option<u32> {
|
||||
pub struct Frame {
|
||||
pub stream_id: u32,
|
||||
pub frame_type: u8,
|
||||
pub payload: Vec<u8>,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
/// Encode a frame into bytes: [stream_id:4][type:1][length:4][payload]
|
||||
pub fn encode_frame(stream_id: u32, frame_type: u8, payload: &[u8]) -> Vec<u8> {
|
||||
pub fn encode_frame(stream_id: u32, frame_type: u8, payload: &[u8]) -> Bytes {
|
||||
let len = payload.len() as u32;
|
||||
let mut buf = Vec::with_capacity(FRAME_HEADER_SIZE + payload.len());
|
||||
buf.extend_from_slice(&stream_id.to_be_bytes());
|
||||
buf.push(frame_type);
|
||||
buf.extend_from_slice(&len.to_be_bytes());
|
||||
buf.extend_from_slice(payload);
|
||||
buf
|
||||
Bytes::from(buf)
|
||||
}
|
||||
|
||||
/// Write a frame header into `buf[0..FRAME_HEADER_SIZE]`.
|
||||
@@ -152,7 +177,7 @@ impl<R: AsyncRead + Unpin> FrameReader<R> {
|
||||
Ok(Some(Frame {
|
||||
stream_id,
|
||||
frame_type,
|
||||
payload,
|
||||
payload: Bytes::from(payload),
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -186,9 +211,9 @@ pub enum TunnelEvent {
|
||||
/// Write state extracted into a sub-struct so the borrow checker can see
|
||||
/// disjoint field access between `self.write` and `self.stream`.
|
||||
struct WriteState {
|
||||
ctrl_queue: VecDeque<Vec<u8>>, // PONG, WINDOW_UPDATE, CLOSE, OPEN — always first
|
||||
data_queue: VecDeque<Vec<u8>>, // DATA, DATA_BACK — only when ctrl is empty
|
||||
offset: usize, // progress within current frame being written
|
||||
ctrl_queue: VecDeque<Bytes>, // PONG, WINDOW_UPDATE, CLOSE, OPEN — always first
|
||||
data_queue: VecDeque<Bytes>, // DATA, DATA_BACK — only when ctrl is empty
|
||||
offset: usize, // progress within current frame being written
|
||||
flush_needed: bool,
|
||||
}
|
||||
|
||||
@@ -206,26 +231,21 @@ impl WriteState {
|
||||
/// WINDOW_UPDATE starvation that causes flow control deadlocks.
|
||||
pub struct TunnelIo<S> {
|
||||
stream: S,
|
||||
// Read state: accumulate bytes, parse frames incrementally
|
||||
read_buf: Vec<u8>,
|
||||
read_pos: usize,
|
||||
parse_pos: usize,
|
||||
// Read state: BytesMut accumulates bytes; split_to extracts frames zero-copy.
|
||||
read_buf: BytesMut,
|
||||
// Write state: extracted sub-struct for safe disjoint borrows
|
||||
write: WriteState,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
|
||||
pub fn new(stream: S, initial_data: Vec<u8>) -> Self {
|
||||
let read_pos = initial_data.len();
|
||||
let mut read_buf = initial_data;
|
||||
let mut read_buf = BytesMut::from(&initial_data[..]);
|
||||
if read_buf.capacity() < 65536 {
|
||||
read_buf.reserve(65536 - read_buf.len());
|
||||
}
|
||||
Self {
|
||||
stream,
|
||||
read_buf,
|
||||
read_pos,
|
||||
parse_pos: 0,
|
||||
write: WriteState {
|
||||
ctrl_queue: VecDeque::new(),
|
||||
data_queue: VecDeque::new(),
|
||||
@@ -236,41 +256,39 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
|
||||
}
|
||||
|
||||
/// Queue a high-priority control frame (PONG, WINDOW_UPDATE, CLOSE, OPEN).
|
||||
pub fn queue_ctrl(&mut self, frame: Vec<u8>) {
|
||||
pub fn queue_ctrl(&mut self, frame: Bytes) {
|
||||
self.write.ctrl_queue.push_back(frame);
|
||||
}
|
||||
|
||||
/// Queue a lower-priority data frame (DATA, DATA_BACK).
|
||||
pub fn queue_data(&mut self, frame: Vec<u8>) {
|
||||
pub fn queue_data(&mut self, frame: Bytes) {
|
||||
self.write.data_queue.push_back(frame);
|
||||
}
|
||||
|
||||
/// Try to parse a complete frame from the read buffer.
|
||||
/// Uses a parse_pos cursor to avoid drain() on every frame.
|
||||
/// Zero-copy: uses BytesMut::split_to to extract frames without allocating.
|
||||
pub fn try_parse_frame(&mut self) -> Option<Result<Frame, std::io::Error>> {
|
||||
let available = self.read_pos - self.parse_pos;
|
||||
if available < FRAME_HEADER_SIZE {
|
||||
if self.read_buf.len() < FRAME_HEADER_SIZE {
|
||||
return None;
|
||||
}
|
||||
|
||||
let base = self.parse_pos;
|
||||
let stream_id = u32::from_be_bytes([
|
||||
self.read_buf[base], self.read_buf[base + 1],
|
||||
self.read_buf[base + 2], self.read_buf[base + 3],
|
||||
self.read_buf[0], self.read_buf[1],
|
||||
self.read_buf[2], self.read_buf[3],
|
||||
]);
|
||||
let frame_type = self.read_buf[base + 4];
|
||||
let frame_type = self.read_buf[4];
|
||||
let length = u32::from_be_bytes([
|
||||
self.read_buf[base + 5], self.read_buf[base + 6],
|
||||
self.read_buf[base + 7], self.read_buf[base + 8],
|
||||
self.read_buf[5], self.read_buf[6],
|
||||
self.read_buf[7], self.read_buf[8],
|
||||
]);
|
||||
|
||||
if length > MAX_PAYLOAD_SIZE {
|
||||
let header = [
|
||||
self.read_buf[base], self.read_buf[base + 1],
|
||||
self.read_buf[base + 2], self.read_buf[base + 3],
|
||||
self.read_buf[base + 4], self.read_buf[base + 5],
|
||||
self.read_buf[base + 6], self.read_buf[base + 7],
|
||||
self.read_buf[base + 8],
|
||||
self.read_buf[0], self.read_buf[1],
|
||||
self.read_buf[2], self.read_buf[3],
|
||||
self.read_buf[4], self.read_buf[5],
|
||||
self.read_buf[6], self.read_buf[7],
|
||||
self.read_buf[8],
|
||||
];
|
||||
log::error!(
|
||||
"CORRUPT FRAME HEADER: raw={:02x?} stream_id={} type=0x{:02x} length={}",
|
||||
@@ -283,19 +301,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
|
||||
}
|
||||
|
||||
let total_frame_size = FRAME_HEADER_SIZE + length as usize;
|
||||
if available < total_frame_size {
|
||||
if self.read_buf.len() < total_frame_size {
|
||||
return None;
|
||||
}
|
||||
|
||||
let payload = self.read_buf[base + FRAME_HEADER_SIZE..base + total_frame_size].to_vec();
|
||||
self.parse_pos += total_frame_size;
|
||||
|
||||
// Compact when parse_pos > half the data to reclaim memory
|
||||
if self.parse_pos > self.read_pos / 2 && self.parse_pos > 0 {
|
||||
self.read_buf.drain(..self.parse_pos);
|
||||
self.read_pos -= self.parse_pos;
|
||||
self.parse_pos = 0;
|
||||
}
|
||||
// Zero-copy extraction: split the frame off the read buffer (O(1) pointer adjustment).
|
||||
// split_to removes the first total_frame_size bytes from read_buf.
|
||||
let mut frame_data = self.read_buf.split_to(total_frame_size);
|
||||
// Split off header, keep only payload. freeze() converts BytesMut → Bytes (O(1)).
|
||||
let payload = frame_data.split_off(FRAME_HEADER_SIZE).freeze();
|
||||
|
||||
Some(Ok(Frame { stream_id, frame_type, payload }))
|
||||
}
|
||||
@@ -306,8 +320,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
|
||||
pub fn poll_step(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
ctrl_rx: &mut tokio::sync::mpsc::Receiver<Vec<u8>>,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Vec<u8>>,
|
||||
ctrl_rx: &mut tokio::sync::mpsc::Receiver<Bytes>,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Bytes>,
|
||||
liveness_deadline: &mut Pin<Box<tokio::time::Sleep>>,
|
||||
cancel_token: &tokio_util::sync::CancellationToken,
|
||||
) -> Poll<TunnelEvent> {
|
||||
@@ -371,23 +385,18 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
|
||||
// the waker without re-registering it, causing the task to sleep until a
|
||||
// timer or channel wakes it (potentially 15+ seconds of lost reads).
|
||||
loop {
|
||||
// Compact if needed to make room for reads
|
||||
if self.parse_pos > 0 && self.read_buf.len() - self.read_pos < 32768 {
|
||||
self.read_buf.drain(..self.parse_pos);
|
||||
self.read_pos -= self.parse_pos;
|
||||
self.parse_pos = 0;
|
||||
}
|
||||
if self.read_buf.len() < self.read_pos + 32768 {
|
||||
self.read_buf.resize(self.read_pos + 32768, 0);
|
||||
}
|
||||
let mut rbuf = ReadBuf::new(&mut self.read_buf[self.read_pos..]);
|
||||
// Ensure at least 32KB of writable space
|
||||
let len_before = self.read_buf.len();
|
||||
self.read_buf.resize(len_before + 32768, 0);
|
||||
let mut rbuf = ReadBuf::new(&mut self.read_buf[len_before..]);
|
||||
match Pin::new(&mut self.stream).poll_read(cx, &mut rbuf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let n = rbuf.filled().len();
|
||||
// Trim back to actual data length
|
||||
self.read_buf.truncate(len_before + n);
|
||||
if n == 0 {
|
||||
return Poll::Ready(TunnelEvent::Eof);
|
||||
}
|
||||
self.read_pos += n;
|
||||
if let Some(result) = self.try_parse_frame() {
|
||||
return match result {
|
||||
Ok(frame) => Poll::Ready(TunnelEvent::Frame(frame)),
|
||||
@@ -398,10 +407,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
|
||||
// waker is re-registered when it finally returns Pending.
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
self.read_buf.truncate(len_before);
|
||||
log::error!("TunnelIo: poll_read error: {}", e);
|
||||
return Poll::Ready(TunnelEvent::ReadError(e));
|
||||
}
|
||||
Poll::Pending => break,
|
||||
Poll::Pending => {
|
||||
self.read_buf.truncate(len_before);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -471,14 +484,14 @@ mod tests {
|
||||
let mut buf = vec![0u8; FRAME_HEADER_SIZE + payload.len()];
|
||||
buf[FRAME_HEADER_SIZE..].copy_from_slice(payload);
|
||||
encode_frame_header(&mut buf, 42, FRAME_DATA, payload.len());
|
||||
assert_eq!(buf, encode_frame(42, FRAME_DATA, payload));
|
||||
assert_eq!(buf[..], encode_frame(42, FRAME_DATA, payload)[..]);
|
||||
}
|
||||
|
||||
#[test]
fn test_encode_frame_header_empty_payload() {
    // A header written in place for a zero-length payload must match the
    // output of encode_frame for an empty payload.
    let mut buf = vec![0u8; FRAME_HEADER_SIZE];
    encode_frame_header(&mut buf, 99, FRAME_CLOSE, 0);
    // Compare as byte slices: `buf` is a Vec<u8>, encode_frame returns Bytes.
    assert_eq!(buf[..], encode_frame(99, FRAME_CLOSE, &[])[..]);
}
|
||||
|
||||
#[test]
|
||||
@@ -646,7 +659,7 @@ mod tests {
|
||||
let frame = reader.next_frame().await.unwrap().unwrap();
|
||||
assert_eq!(frame.stream_id, i as u32);
|
||||
assert_eq!(frame.frame_type, ft);
|
||||
assert_eq!(frame.payload, format!("payload_{}", i).as_bytes());
|
||||
assert_eq!(&frame.payload[..], format!("payload_{}", i).as_bytes());
|
||||
}
|
||||
|
||||
assert!(reader.next_frame().await.unwrap().is_none());
|
||||
@@ -655,7 +668,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_frame_reader_zero_length_payload() {
|
||||
let data = encode_frame(42, FRAME_CLOSE, &[]);
|
||||
let cursor = std::io::Cursor::new(data);
|
||||
let cursor = std::io::Cursor::new(data.to_vec());
|
||||
let mut reader = FrameReader::new(cursor);
|
||||
|
||||
let frame = reader.next_frame().await.unwrap().unwrap();
|
||||
@@ -783,6 +796,39 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// --- clamp_send_window tests ---
|
||||
|
||||
#[test]
fn test_clamp_send_window_reduces_above_target() {
    use std::sync::atomic::{AtomicU32, Ordering};

    // A 4 MB window clamped to a 512 KB target must end up at the target.
    let window = AtomicU32::new(4 * 1024 * 1024);
    let clamped = clamp_send_window(&window, 512 * 1024);
    assert_eq!(clamped, 512 * 1024);
    assert_eq!(window.load(Ordering::Relaxed), 512 * 1024);
}
|
||||
|
||||
#[test]
fn test_clamp_send_window_noop_below_target() {
    use std::sync::atomic::{AtomicU32, Ordering};

    // A 256 KB window is already under the 512 KB target and must not change.
    let window = AtomicU32::new(256 * 1024);
    let clamped = clamp_send_window(&window, 512 * 1024);
    assert_eq!(clamped, 256 * 1024);
    assert_eq!(window.load(Ordering::Relaxed), 256 * 1024);
}
|
||||
|
||||
#[test]
fn test_clamp_send_window_noop_at_target() {
    use std::sync::atomic::{AtomicU32, Ordering};

    // A window exactly equal to the target is left as is.
    let window = AtomicU32::new(512 * 1024);
    let clamped = clamp_send_window(&window, 512 * 1024);
    assert_eq!(clamped, 512 * 1024);
    assert_eq!(window.load(Ordering::Relaxed), 512 * 1024);
}
|
||||
|
||||
#[test]
fn test_clamp_send_window_zero_value() {
    use std::sync::atomic::{AtomicU32, Ordering};

    // An exhausted (zero) window is below any target: the call reports 0 and
    // must leave the stored value at 0.
    let window = AtomicU32::new(0);
    let clamped = clamp_send_window(&window, 64 * 1024);
    assert_eq!(clamped, 0);
    assert_eq!(window.load(Ordering::Relaxed), 0);
}
|
||||
|
||||
// --- encode/decode window_update roundtrip ---
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user