fix(rust-edge): refactor tunnel I/O to preserve TLS state and prioritize control frames
This commit is contained in:
@@ -1,4 +1,8 @@
|
||||
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
|
||||
// Frame type constants
|
||||
pub const FRAME_OPEN: u8 = 0x01;
|
||||
@@ -120,9 +124,13 @@ impl<R: AsyncRead + Unpin> FrameReader<R> {
|
||||
]);
|
||||
|
||||
if length > MAX_PAYLOAD_SIZE {
|
||||
log::error!(
|
||||
"CORRUPT FRAME HEADER: raw={:02x?} stream_id={} type=0x{:02x} length={}",
|
||||
self.header_buf, stream_id, frame_type, length
|
||||
);
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("frame payload too large: {} bytes", length),
|
||||
format!("frame payload too large: {} bytes (header={:02x?})", length, self.header_buf),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -144,6 +152,256 @@ impl<R: AsyncRead + Unpin> FrameReader<R> {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TunnelIo: single-owner I/O multiplexer for the TLS tunnel connection
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Events produced by the TunnelIo event loop.
|
||||
#[derive(Debug)]
|
||||
pub enum TunnelEvent {
|
||||
/// A complete frame was read from the remote side.
|
||||
Frame(Frame),
|
||||
/// The remote side closed the connection (EOF).
|
||||
Eof,
|
||||
/// A read error occurred.
|
||||
ReadError(std::io::Error),
|
||||
/// A write error occurred.
|
||||
WriteError(std::io::Error),
|
||||
/// No frames received for the liveness timeout duration.
|
||||
LivenessTimeout,
|
||||
/// The cancellation token was triggered.
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
/// Single-owner I/O engine for the tunnel TLS connection.
|
||||
///
|
||||
/// Owns the TLS stream directly — no `tokio::io::split()`, no mutex.
|
||||
/// Uses two priority write queues: ctrl frames (PONG, WINDOW_UPDATE, CLOSE, OPEN)
|
||||
/// are ALWAYS written before data frames (DATA, DATA_BACK). This prevents
|
||||
/// WINDOW_UPDATE starvation that causes flow control deadlocks.
|
||||
pub struct TunnelIo<S> {
|
||||
stream: S,
|
||||
// Read state: accumulate bytes, parse frames incrementally
|
||||
read_buf: Vec<u8>,
|
||||
read_pos: usize,
|
||||
// Write state: dual priority queues
|
||||
ctrl_queue: VecDeque<Vec<u8>>, // PONG, WINDOW_UPDATE, CLOSE, OPEN — always first
|
||||
data_queue: VecDeque<Vec<u8>>, // DATA, DATA_BACK — only when ctrl is empty
|
||||
write_offset: usize, // progress within current frame being written
|
||||
flush_needed: bool,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> TunnelIo<S> {
|
||||
pub fn new(stream: S, initial_data: Vec<u8>) -> Self {
|
||||
let read_pos = initial_data.len();
|
||||
let mut read_buf = initial_data;
|
||||
if read_buf.capacity() < 65536 {
|
||||
read_buf.reserve(65536 - read_buf.len());
|
||||
}
|
||||
Self {
|
||||
stream,
|
||||
read_buf,
|
||||
read_pos,
|
||||
ctrl_queue: VecDeque::new(),
|
||||
data_queue: VecDeque::new(),
|
||||
write_offset: 0,
|
||||
flush_needed: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Queue a high-priority control frame (PONG, WINDOW_UPDATE, CLOSE, OPEN).
|
||||
pub fn queue_ctrl(&mut self, frame: Vec<u8>) {
|
||||
self.ctrl_queue.push_back(frame);
|
||||
}
|
||||
|
||||
/// Queue a lower-priority data frame (DATA, DATA_BACK).
|
||||
pub fn queue_data(&mut self, frame: Vec<u8>) {
|
||||
self.data_queue.push_back(frame);
|
||||
}
|
||||
|
||||
/// Try to parse a complete frame from the read buffer.
|
||||
pub fn try_parse_frame(&mut self) -> Option<Result<Frame, std::io::Error>> {
|
||||
if self.read_pos < FRAME_HEADER_SIZE {
|
||||
return None;
|
||||
}
|
||||
|
||||
let stream_id = u32::from_be_bytes([
|
||||
self.read_buf[0], self.read_buf[1], self.read_buf[2], self.read_buf[3],
|
||||
]);
|
||||
let frame_type = self.read_buf[4];
|
||||
let length = u32::from_be_bytes([
|
||||
self.read_buf[5], self.read_buf[6], self.read_buf[7], self.read_buf[8],
|
||||
]);
|
||||
|
||||
if length > MAX_PAYLOAD_SIZE {
|
||||
let header = [
|
||||
self.read_buf[0], self.read_buf[1], self.read_buf[2], self.read_buf[3],
|
||||
self.read_buf[4], self.read_buf[5], self.read_buf[6], self.read_buf[7],
|
||||
self.read_buf[8],
|
||||
];
|
||||
log::error!(
|
||||
"CORRUPT FRAME HEADER: raw={:02x?} stream_id={} type=0x{:02x} length={}",
|
||||
header, stream_id, frame_type, length
|
||||
);
|
||||
return Some(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("frame payload too large: {} bytes (header={:02x?})", length, header),
|
||||
)));
|
||||
}
|
||||
|
||||
let total_frame_size = FRAME_HEADER_SIZE + length as usize;
|
||||
if self.read_pos < total_frame_size {
|
||||
return None;
|
||||
}
|
||||
|
||||
let payload = self.read_buf[FRAME_HEADER_SIZE..total_frame_size].to_vec();
|
||||
self.read_buf.drain(..total_frame_size);
|
||||
self.read_pos -= total_frame_size;
|
||||
|
||||
Some(Ok(Frame { stream_id, frame_type, payload }))
|
||||
}
|
||||
|
||||
fn has_write_work(&self) -> bool {
|
||||
!self.ctrl_queue.is_empty() || !self.data_queue.is_empty()
|
||||
}
|
||||
|
||||
/// Poll-based I/O step. Returns Ready on events, Pending when idle.
|
||||
///
|
||||
/// Order: write(ctrl→data) → flush → read → channels → timers
|
||||
pub fn poll_step(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
ctrl_rx: &mut tokio::sync::mpsc::Receiver<Vec<u8>>,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Vec<u8>>,
|
||||
liveness_deadline: &mut Pin<Box<tokio::time::Sleep>>,
|
||||
cancel_token: &tokio_util::sync::CancellationToken,
|
||||
) -> Poll<TunnelEvent> {
|
||||
// 1. WRITE: drain ctrl queue first, then data queue.
|
||||
// TLS poll_write writes plaintext to session buffer (always Ready).
|
||||
// Batch up to 16 frames per poll cycle.
|
||||
let mut writes = 0;
|
||||
while self.has_write_work() && writes < 16 {
|
||||
// Determine which queue to write from and the frame data.
|
||||
// We access the queues via raw pointers to avoid borrow conflicts with self.stream.
|
||||
let from_ctrl = !self.ctrl_queue.is_empty();
|
||||
let frame_ptr: *const Vec<u8> = if from_ctrl {
|
||||
self.ctrl_queue.front().unwrap()
|
||||
} else {
|
||||
self.data_queue.front().unwrap()
|
||||
};
|
||||
// SAFETY: the frame is not modified while we hold the pointer — poll_write
|
||||
// only writes to self.stream, and advance_write only runs after poll_write returns.
|
||||
let frame = unsafe { &*frame_ptr };
|
||||
let remaining = &frame[self.write_offset..];
|
||||
|
||||
match Pin::new(&mut self.stream).poll_write(cx, remaining) {
|
||||
Poll::Ready(Ok(0)) => {
|
||||
return Poll::Ready(TunnelEvent::WriteError(
|
||||
std::io::Error::new(std::io::ErrorKind::WriteZero, "write zero"),
|
||||
));
|
||||
}
|
||||
Poll::Ready(Ok(n)) => {
|
||||
self.write_offset += n;
|
||||
self.flush_needed = true;
|
||||
if self.write_offset >= frame.len() {
|
||||
if from_ctrl { self.ctrl_queue.pop_front(); }
|
||||
else { self.data_queue.pop_front(); }
|
||||
self.write_offset = 0;
|
||||
writes += 1;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(TunnelEvent::WriteError(e)),
|
||||
Poll::Pending => break,
|
||||
}
|
||||
}
|
||||
|
||||
// 2. FLUSH: push encrypted data from TLS session to TCP.
|
||||
if self.flush_needed {
|
||||
match Pin::new(&mut self.stream).poll_flush(cx) {
|
||||
Poll::Ready(Ok(())) => self.flush_needed = false,
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(TunnelEvent::WriteError(e)),
|
||||
Poll::Pending => {} // TCP waker will notify us
|
||||
}
|
||||
}
|
||||
|
||||
// 3. READ: drain stream until Pending to ensure the TCP waker is always registered.
|
||||
// Without this loop, a Ready return with partial frame data would consume
|
||||
// the waker without re-registering it, causing the task to sleep until a
|
||||
// timer or channel wakes it (potentially 15+ seconds of lost reads).
|
||||
loop {
|
||||
if self.read_buf.len() < self.read_pos + 32768 {
|
||||
self.read_buf.resize(self.read_pos + 32768, 0);
|
||||
}
|
||||
let mut rbuf = ReadBuf::new(&mut self.read_buf[self.read_pos..]);
|
||||
match Pin::new(&mut self.stream).poll_read(cx, &mut rbuf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let n = rbuf.filled().len();
|
||||
if n == 0 {
|
||||
return Poll::Ready(TunnelEvent::Eof);
|
||||
}
|
||||
self.read_pos += n;
|
||||
if let Some(result) = self.try_parse_frame() {
|
||||
return match result {
|
||||
Ok(frame) => Poll::Ready(TunnelEvent::Frame(frame)),
|
||||
Err(e) => Poll::Ready(TunnelEvent::ReadError(e)),
|
||||
};
|
||||
}
|
||||
// Partial data — loop to call poll_read again so the TCP
|
||||
// waker is re-registered when it finally returns Pending.
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(TunnelEvent::ReadError(e)),
|
||||
Poll::Pending => break,
|
||||
}
|
||||
}
|
||||
|
||||
// 4. CHANNELS: drain ctrl into ctrl_queue, data into data_queue.
|
||||
let mut got_new = false;
|
||||
loop {
|
||||
match ctrl_rx.poll_recv(cx) {
|
||||
Poll::Ready(Some(frame)) => { self.ctrl_queue.push_back(frame); got_new = true; }
|
||||
Poll::Ready(None) => {
|
||||
return Poll::Ready(TunnelEvent::WriteError(
|
||||
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "ctrl channel closed"),
|
||||
));
|
||||
}
|
||||
Poll::Pending => break,
|
||||
}
|
||||
}
|
||||
loop {
|
||||
match data_rx.poll_recv(cx) {
|
||||
Poll::Ready(Some(frame)) => { self.data_queue.push_back(frame); got_new = true; }
|
||||
Poll::Ready(None) => {
|
||||
return Poll::Ready(TunnelEvent::WriteError(
|
||||
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "data channel closed"),
|
||||
));
|
||||
}
|
||||
Poll::Pending => break,
|
||||
}
|
||||
}
|
||||
|
||||
// 5. TIMERS
|
||||
if liveness_deadline.as_mut().poll(cx).is_ready() {
|
||||
return Poll::Ready(TunnelEvent::LivenessTimeout);
|
||||
}
|
||||
if cancel_token.is_cancelled() {
|
||||
return Poll::Ready(TunnelEvent::Cancelled);
|
||||
}
|
||||
|
||||
// 6. SELF-WAKE: only when we have frames AND flush is done.
|
||||
// If flush is pending, the TCP write-readiness waker will notify us.
|
||||
// If we got new channel frames, wake to write them.
|
||||
if got_new || (!self.flush_needed && self.has_write_work()) {
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> S {
|
||||
self.stream
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user