// File: remoteingress/rust/crates/remoteingress-core/src/hub.rs
// (file-viewer listing: 1114 lines, 43 KiB, Rust)

use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex, Notify, RwLock, Semaphore};
use tokio::time::{interval, sleep_until, Instant};
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use serde::{Deserialize, Serialize};
use remoteingress_protocol::*;
/// Server-side TLS stream over a plain TCP socket; the transport type used for
/// the tunnel I/O engine of an edge connection.
type HubTlsStream = tokio_rustls::server::TlsStream<TcpStream>;
/// Capacity (in frames) of the bounded per-stream channel that carries
/// FRAME_DATA payloads from the frame loop to the stream's upstream writer task.
///
/// With a 4MB window and 32KB frames, at most ~128 frames are in flight, so 256
/// provides comfortable headroom. Delivery into this channel is non-blocking
/// (`try_send`); a failed delivery closes the stream rather than stalling the tunnel.
const PER_STREAM_DATA_CAPACITY: usize = 256;
/// Result of processing a frame: tells the edge frame loop whether to keep
/// going or tear the connection down.
// `Disconnect` is matched by the frame loop but is not constructed by the frame
// handler in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
enum FrameAction {
/// Keep processing frames on this connection.
Continue,
/// Stop the frame loop; the string becomes the disconnect reason.
Disconnect(String),
}
/// Per-stream state tracked in the hub's stream map (keyed by stream id).
///
/// The frame loop owns this value; the spawned stream task holds clones of the
/// token, window counter and notifier.
struct HubStreamState {
/// Channel to deliver FRAME_DATA payloads to the upstream writer task.
data_tx: mpsc::Sender<Vec<u8>>,
/// Cancellation token for this stream (child of the edge's token).
cancel_token: CancellationToken,
/// Send window for FRAME_DATA_BACK (download direction).
/// Decremented by the upstream reader, incremented by FRAME_WINDOW_UPDATE from edge.
send_window: Arc<AtomicU32>,
/// Notifier to wake the upstream reader when the window opens.
window_notify: Arc<Notify>,
}
/// Hub configuration. (De)serialized with camelCase keys.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HubConfig {
/// TCP port the hub listens on (bound on all IPv4 interfaces, `0.0.0.0`).
pub tunnel_port: u16,
/// Host that accepted streams are forwarded to; `127.0.0.1` is used when `None`.
pub target_host: Option<String>,
/// PEM-encoded certificate chain. A self-signed certificate is generated
/// unless both this and `tls_key_pem` are set.
#[serde(default)]
pub tls_cert_pem: Option<String>,
/// PEM-encoded private key matching `tls_cert_pem`.
#[serde(default)]
pub tls_key_pem: Option<String>,
}
impl Default for HubConfig {
fn default() -> Self {
Self {
tunnel_port: 8443,
target_host: Some("127.0.0.1".to_string()),
tls_cert_pem: None,
tls_key_pem: None,
}
}
}
/// An allowed edge identity. (De)serialized with camelCase keys.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AllowedEdge {
/// Edge identifier presented in the auth line; also the key of the allow-list map.
pub id: String,
/// Shared secret the edge must present; compared in constant time during auth.
pub secret: String,
/// Ports announced to the edge in the handshake and in later config updates.
/// Defaults to an empty list when absent from the input.
#[serde(default)]
pub listen_ports: Vec<u16>,
/// STUN interval (seconds) sent in the handshake; 300 is used when `None`.
pub stun_interval_secs: Option<u64>,
}
/// Handshake response sent to edge after authentication.
///
/// Written to the TLS stream as a single newline-terminated JSON document
/// (camelCase keys) before the frame protocol begins.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct HandshakeResponse {
/// Initial listen ports taken from the edge's allow-list entry.
listen_ports: Vec<u16>,
/// STUN interval in seconds (allow-list value, or 300 when unset).
stun_interval_secs: u64,
}
/// Configuration update pushed to a connected edge at runtime.
///
/// Serialized as JSON (camelCase keys) and sent as the payload of a
/// FRAME_CONFIG frame on stream id 0.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EdgeConfigUpdate {
/// The complete new list of listen ports for the edge.
pub listen_ports: Vec<u16>,
}
/// Runtime status of a connected edge, as reported by `TunnelHub::get_status`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ConnectedEdgeStatus {
/// Identifier the edge authenticated with.
pub edge_id: String,
/// Registration time as seconds since the Unix epoch.
pub connected_at: u64,
/// Snapshot of the edge's active stream counter.
pub active_streams: usize,
/// Remote IP address of the edge connection (no port).
pub peer_addr: String,
}
/// Events emitted by the hub.
///
/// Serialized as an internally tagged object: the `type` field holds the
/// camelCase variant name and the variant's fields use camelCase keys.
/// Events are pushed with `try_send`, so they are dropped if the event channel is full.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum HubEvent {
/// An edge connected and authenticated.
#[serde(rename_all = "camelCase")]
EdgeConnected { edge_id: String, peer_addr: String },
/// An edge connection ended; `reason` is a short machine-readable tag
/// (e.g. `edge_eof`, `liveness_timeout`), optionally followed by error detail.
#[serde(rename_all = "camelCase")]
EdgeDisconnected { edge_id: String, reason: String },
/// The edge opened a new stream.
#[serde(rename_all = "camelCase")]
StreamOpened { edge_id: String, stream_id: u32 },
/// A stream was removed from the hub's stream map.
#[serde(rename_all = "camelCase")]
StreamClosed { edge_id: String, stream_id: u32 },
}
/// Hub status response (snapshot returned by `TunnelHub::get_status`).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HubStatus {
/// Value of the hub's `running` flag (set by `start`, cleared by `stop`).
pub running: bool,
/// Configured tunnel listen port.
pub tunnel_port: u16,
/// One entry per currently registered edge connection.
pub connected_edges: Vec<ConnectedEdgeStatus>,
}
/// The tunnel hub that accepts edge connections and demuxes streams to SmartProxy.
pub struct TunnelHub {
/// Hub configuration; read by `start` and `get_status`.
config: RwLock<HubConfig>,
/// Allow-list of edges keyed by edge id; consulted during authentication.
allowed_edges: Arc<RwLock<HashMap<String, AllowedEdge>>>,
/// Currently registered edge connections keyed by edge id.
connected_edges: Arc<Mutex<HashMap<String, ConnectedEdgeInfo>>>,
/// Sender half of the event channel; cloned into every connection task.
event_tx: mpsc::Sender<HubEvent>,
/// Receiver half of the event channel; handed out once by `take_event_rx`.
event_rx: Mutex<Option<mpsc::Receiver<HubEvent>>>,
/// Shutdown signal for the accept loop; set by `start`, consumed by `stop`.
shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
/// Whether `start` has been called without a subsequent `stop`.
running: RwLock<bool>,
/// Root cancellation token; each edge connection gets a child token.
/// Cancelled by `stop` and on drop.
cancel_token: CancellationToken,
}
/// Bookkeeping for one registered edge connection (value type of the
/// `connected_edges` map).
struct ConnectedEdgeInfo {
/// Registration time as seconds since the Unix epoch.
connected_at: u64,
/// Remote IP address of the edge (no port).
peer_addr: String,
/// Active stream counter shared with the connection's stream tasks.
edge_stream_count: Arc<AtomicU32>,
/// Channel used to push runtime config updates to this edge.
config_tx: mpsc::Sender<EdgeConfigUpdate>,
// NOTE(review): tokio-util documents cancel-on-drop via `DropGuard`; dropping a
// plain `CancellationToken` presumably does not cancel it — confirm the intent below.
#[allow(dead_code)] // kept alive for Drop — cancels child tokens when edge is removed
cancel_token: CancellationToken,
}
impl TunnelHub {
    /// Create a hub in the stopped state.
    ///
    /// The event channel is created here (capacity 1024); obtain its receiver
    /// with [`TunnelHub::take_event_rx`].
    pub fn new(config: HubConfig) -> Self {
        let (event_tx, event_rx) = mpsc::channel(1024);
        Self {
            config: RwLock::new(config),
            allowed_edges: Arc::new(RwLock::new(HashMap::new())),
            connected_edges: Arc::new(Mutex::new(HashMap::new())),
            event_tx,
            event_rx: Mutex::new(Some(event_rx)),
            shutdown_tx: Mutex::new(None),
            running: RwLock::new(false),
            cancel_token: CancellationToken::new(),
        }
    }

    /// Take the event receiver (can only be called once).
    ///
    /// Returns `None` on every call after the first.
    pub async fn take_event_rx(&self) -> Option<mpsc::Receiver<HubEvent>> {
        self.event_rx.lock().await.take()
    }

    /// Replace the list of allowed edges.
    ///
    /// For any currently-connected edge whose ports changed (or that was not
    /// in the previous list), a config update is queued. Queuing is
    /// non-blocking; a failure is logged and the update is not retried.
    /// If `edges` contains duplicate ids, the last entry wins.
    pub async fn update_allowed_edges(&self, edges: Vec<AllowedEdge>) {
        let mut map = self.allowed_edges.write().await;

        {
            let connected = self.connected_edges.lock().await;
            for edge in &edges {
                let info = match connected.get(&edge.id) {
                    Some(info) => info,
                    None => continue,
                };
                // An edge missing from the old list counts as changed: it is
                // newly allowed but already connected.
                let ports_changed = map
                    .get(&edge.id)
                    .map_or(true, |old| old.listen_ports != edge.listen_ports);
                if !ports_changed {
                    continue;
                }
                let update = EdgeConfigUpdate {
                    listen_ports: edge.listen_ports.clone(),
                };
                if let Err(e) = info.config_tx.try_send(update) {
                    log::warn!("Could not queue config update for edge {}: {}", edge.id, e);
                }
            }
        }

        // Consume the input instead of cloning every entry.
        *map = edges
            .into_iter()
            .map(|edge| (edge.id.clone(), edge))
            .collect();
    }

    /// Get the current hub status.
    pub async fn get_status(&self) -> HubStatus {
        let running = *self.running.read().await;
        let config = self.config.read().await;
        let edges = self.connected_edges.lock().await;
        let connected_edges = edges
            .iter()
            .map(|(id, info)| ConnectedEdgeStatus {
                edge_id: id.clone(),
                connected_at: info.connected_at,
                active_streams: info.edge_stream_count.load(Ordering::Relaxed) as usize,
                peer_addr: info.peer_addr.clone(),
            })
            .collect();
        HubStatus {
            running,
            tunnel_port: config.tunnel_port,
            connected_edges,
        }
    }

    /// Start the hub — listen for TLS connections from edges.
    ///
    /// Binds `0.0.0.0:<tunnel_port>` and spawns the accept loop; each accepted
    /// connection is handled in its own task with a child cancellation token.
    ///
    /// # Errors
    ///
    /// Fails if the hub was already stopped (its root token stays cancelled,
    /// so a new `TunnelHub` is required), if the TLS configuration cannot be
    /// built, or if the listener cannot be bound.
    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // A cancelled root token would end the accept loop immediately while
        // `running` still reported true, so refuse to start instead.
        if self.cancel_token.is_cancelled() {
            return Err(
                "hub has been stopped and cannot be restarted; create a new TunnelHub".into(),
            );
        }

        let config = self.config.read().await.clone();
        let tls_config = build_tls_config(&config)?;
        let acceptor = TlsAcceptor::from(Arc::new(tls_config));
        let listener = TcpListener::bind(("0.0.0.0", config.tunnel_port)).await?;
        log::info!("Hub listening on port {}", config.tunnel_port);

        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
        *self.shutdown_tx.lock().await = Some(shutdown_tx);
        *self.running.write().await = true;

        let allowed = self.allowed_edges.clone();
        let connected = self.connected_edges.clone();
        let event_tx = self.event_tx.clone();
        let target_host = config.target_host.unwrap_or_else(|| "127.0.0.1".to_string());
        let hub_token = self.cancel_token.clone();

        tokio::spawn(async move {
            loop {
                tokio::select! {
                    result = listener.accept() => {
                        match result {
                            Ok((stream, addr)) => {
                                log::info!("Edge connection from {}", addr);
                                let acceptor = acceptor.clone();
                                let allowed = allowed.clone();
                                let connected = connected.clone();
                                let event_tx = event_tx.clone();
                                let target = target_host.clone();
                                let edge_token = hub_token.child_token();
                                let peer_addr = addr.ip().to_string();
                                tokio::spawn(async move {
                                    if let Err(e) = handle_edge_connection(
                                        stream, acceptor, allowed, connected, event_tx, target, edge_token, peer_addr,
                                    ).await {
                                        log::error!("Edge connection error: {}", e);
                                    }
                                });
                            }
                            Err(e) => {
                                log::error!("Accept error: {}", e);
                                // Persistent failures (e.g. fd exhaustion) make
                                // accept() return at once; back off briefly so
                                // the loop does not spin.
                                tokio::time::sleep(Duration::from_millis(100)).await;
                            }
                        }
                    }
                    _ = hub_token.cancelled() => {
                        log::info!("Hub shutting down (token cancelled)");
                        break;
                    }
                    _ = shutdown_rx.recv() => {
                        log::info!("Hub shutting down");
                        break;
                    }
                }
            }
        });
        Ok(())
    }

    /// Stop the hub.
    ///
    /// Cancels the root token (which propagates to every edge connection and
    /// stream), signals the accept loop, clears the `running` flag and empties
    /// the connected-edge map. A stopped hub cannot be started again.
    pub async fn stop(&self) {
        self.cancel_token.cancel();
        if let Some(tx) = self.shutdown_tx.lock().await.take() {
            // The accept loop may already have exited on the token; a failed
            // send is expected then.
            let _ = tx.send(()).await;
        }
        *self.running.write().await = false;
        // Clear connected edges
        self.connected_edges.lock().await.clear();
    }
}
/// Dropping the hub cancels its root token so that tasks holding child tokens
/// (accept loop, edge connections, streams) observe cancellation.
impl Drop for TunnelHub {
fn drop(&mut self) {
self.cancel_token.cancel();
}
}
/// Maximum concurrent streams per edge connection.
///
/// Used as the permit count of the per-edge stream semaphore; a FRAME_OPEN that
/// cannot obtain a permit is answered with FRAME_CLOSE_BACK.
const MAX_STREAMS_PER_EDGE: usize = 1024;
/// Process a single frame received from the edge side of the tunnel.
/// Handles FRAME_OPEN, FRAME_DATA, FRAME_WINDOW_UPDATE, FRAME_CLOSE, and FRAME_PONG.
///
/// * `FRAME_OPEN` — takes a stream permit, registers the stream in `streams`
///   and spawns a task that connects upstream, forwards the PROXY header and
///   pipes data in both directions with per-stream flow control.
/// * `FRAME_DATA` — hands the payload to the stream's writer task without
///   blocking. If that fails, the stream is closed: token cancelled,
///   CLOSE_BACK queued and `StreamClosed` emitted.
/// * `FRAME_WINDOW_UPDATE` — grows the stream's send window, saturating at
///   `MAX_WINDOW_SIZE`, and wakes the upstream reader.
/// * `FRAME_CLOSE` — removes and cancels the stream.
///
/// `ctrl_tx` / `data_tx` are the priority and bulk queues towards the edge;
/// `cleanup_tx` is how stream tasks report completion to the frame loop.
///
/// Currently always returns [`FrameAction::Continue`].
async fn handle_hub_frame(
    frame: Frame,
    tunnel_io: &mut remoteingress_protocol::TunnelIo<HubTlsStream>,
    streams: &mut HashMap<u32, HubStreamState>,
    stream_semaphore: &Arc<Semaphore>,
    edge_stream_count: &Arc<AtomicU32>,
    edge_id: &str,
    event_tx: &mpsc::Sender<HubEvent>,
    ctrl_tx: &mpsc::Sender<Vec<u8>>,
    data_tx: &mpsc::Sender<Vec<u8>>,
    target_host: &str,
    edge_token: &CancellationToken,
    cleanup_tx: &mpsc::Sender<u32>,
) -> FrameAction {
    match frame.frame_type {
        FRAME_OPEN => {
            // A4: Check stream limit before processing
            let permit = match stream_semaphore.clone().try_acquire_owned() {
                Ok(p) => p,
                Err(_) => {
                    log::warn!("Edge {} exceeded max streams ({}), rejecting stream {}",
                        edge_id, MAX_STREAMS_PER_EDGE, frame.stream_id);
                    let close_frame = encode_frame(frame.stream_id, FRAME_CLOSE_BACK, &[]);
                    tunnel_io.queue_ctrl(close_frame);
                    return FrameAction::Continue;
                }
            };

            // Payload is PROXY v1 header line
            let proxy_header = String::from_utf8_lossy(&frame.payload).to_string();
            // Parse destination port from PROXY header (falls back to 443)
            let dest_port = parse_dest_port_from_proxy(&proxy_header).unwrap_or(443);

            let stream_id = frame.stream_id;
            let cleanup = cleanup_tx.clone();
            let writer_tx = ctrl_tx.clone(); // control: CLOSE_BACK, WINDOW_UPDATE_BACK
            let data_writer_tx = data_tx.clone(); // data: DATA_BACK
            let target = target_host.to_string();
            let stream_token = edge_token.child_token();

            let _ = event_tx.try_send(HubEvent::StreamOpened {
                edge_id: edge_id.to_string(),
                stream_id,
            });

            // Create channel for data from edge to this stream
            let (stream_data_tx, mut stream_data_rx) = mpsc::channel::<Vec<u8>>(PER_STREAM_DATA_CAPACITY);

            // Adaptive initial window: scale with current stream count
            // to keep total in-flight data within the 32MB budget.
            let initial_window = compute_window_for_stream_count(
                edge_stream_count.load(Ordering::Relaxed),
            );
            let send_window = Arc::new(AtomicU32::new(initial_window));
            let window_notify = Arc::new(Notify::new());

            streams.insert(stream_id, HubStreamState {
                data_tx: stream_data_tx,
                cancel_token: stream_token.clone(),
                send_window: Arc::clone(&send_window),
                window_notify: Arc::clone(&window_notify),
            });

            // Spawn task: connect to SmartProxy, send PROXY header, pipe data
            let stream_counter = Arc::clone(edge_stream_count);
            tokio::spawn(async move {
                let _permit = permit; // hold semaphore permit until stream completes
                stream_counter.fetch_add(1, Ordering::Relaxed);

                let result = async {
                    // A2: Connect to SmartProxy with timeout
                    let mut upstream = tokio::time::timeout(
                        Duration::from_secs(10),
                        TcpStream::connect((target.as_str(), dest_port)),
                    )
                    .await
                    .map_err(|_| -> Box<dyn std::error::Error + Send + Sync> {
                        format!("connect to SmartProxy {}:{} timed out (10s)", target, dest_port).into()
                    })??;
                    upstream.set_nodelay(true)?;
                    upstream.write_all(proxy_header.as_bytes()).await?;

                    let (mut up_read, mut up_write) =
                        upstream.into_split();

                    // Forward data from edge (via channel) to SmartProxy
                    // After writing to upstream, send WINDOW_UPDATE_BACK to edge
                    let writer_token = stream_token.clone();
                    let wub_tx = writer_tx.clone();
                    let stream_counter_w = Arc::clone(&stream_counter);
                    let writer_for_edge_data = tokio::spawn(async move {
                        let mut consumed_since_update: u32 = 0;
                        loop {
                            tokio::select! {
                                data = stream_data_rx.recv() => {
                                    match data {
                                        Some(data) => {
                                            let len = data.len() as u32;
                                            // Check cancellation alongside the write so we respond
                                            // promptly to FRAME_CLOSE instead of blocking up to 60s.
                                            let write_result = tokio::select! {
                                                r = tokio::time::timeout(
                                                    Duration::from_secs(60),
                                                    up_write.write_all(&data),
                                                ) => r,
                                                _ = writer_token.cancelled() => break,
                                            };
                                            match write_result {
                                                Ok(Ok(())) => {}
                                                Ok(Err(_)) => break,
                                                Err(_) => {
                                                    log::warn!("Stream {} write to upstream timed out (60s)", stream_id);
                                                    break;
                                                }
                                            }
                                            // Track consumption for adaptive flow control.
                                            // Increment capped to adaptive window to limit per-stream in-flight data.
                                            consumed_since_update += len;
                                            let adaptive_window = remoteingress_protocol::compute_window_for_stream_count(
                                                stream_counter_w.load(Ordering::Relaxed),
                                            );
                                            let threshold = adaptive_window / 2;
                                            if consumed_since_update >= threshold {
                                                let increment = consumed_since_update.min(adaptive_window);
                                                let frame = encode_window_update(stream_id, FRAME_WINDOW_UPDATE_BACK, increment);
                                                if wub_tx.try_send(frame).is_ok() {
                                                    consumed_since_update -= increment;
                                                }
                                                // If try_send fails, keep accumulating — retry on next threshold
                                            }
                                        }
                                        None => break,
                                    }
                                }
                                _ = writer_token.cancelled() => break,
                            }
                        }
                        // Send final window update for remaining consumed bytes
                        if consumed_since_update > 0 {
                            let frame = encode_window_update(stream_id, FRAME_WINDOW_UPDATE_BACK, consumed_since_update);
                            let _ = wub_tx.try_send(frame);
                        }
                        let _ = up_write.shutdown().await;
                    });

                    // Forward data from SmartProxy back to edge via writer channel
                    // with per-stream flow control (check send_window before reading).
                    // Zero-copy: read payload directly after the header, then prepend header.
                    let mut buf = vec![0u8; FRAME_HEADER_SIZE + 32768];
                    loop {
                        // Wait for send window to have capacity (with stall timeout).
                        // Safe pattern: register notified BEFORE checking the condition
                        // to avoid missing a notify_one that fires between load and select.
                        loop {
                            let notified = window_notify.notified();
                            tokio::pin!(notified);
                            notified.as_mut().enable();
                            let w = send_window.load(Ordering::Acquire);
                            if w > 0 { break; }
                            tokio::select! {
                                _ = notified => continue,
                                _ = stream_token.cancelled() => break,
                                _ = tokio::time::sleep(Duration::from_secs(120)) => {
                                    log::warn!("Stream {} download stalled (window empty for 120s)", stream_id);
                                    break;
                                }
                            }
                        }
                        if stream_token.is_cancelled() { break; }

                        // Limit read size to available window.
                        // IMPORTANT: if window is 0 (stall timeout fired), we must NOT
                        // read into an empty buffer — read(&mut buf[..0]) returns Ok(0)
                        // which would be falsely interpreted as EOF.
                        let w = send_window.load(Ordering::Acquire) as usize;
                        if w == 0 {
                            log::warn!("Stream {} download: window still 0 after stall timeout, closing", stream_id);
                            break;
                        }
                        // Adaptive: cap read to current per-stream target window
                        let adaptive_cap = remoteingress_protocol::compute_window_for_stream_count(
                            stream_counter.load(Ordering::Relaxed),
                        ) as usize;
                        let max_read = w.min(32768).min(adaptive_cap);

                        tokio::select! {
                            read_result = up_read.read(&mut buf[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + max_read]) => {
                                match read_result {
                                    Ok(0) => break,
                                    Ok(n) => {
                                        send_window.fetch_sub(n as u32, Ordering::Release);
                                        encode_frame_header(&mut buf, stream_id, FRAME_DATA_BACK, n);
                                        let frame = buf[..FRAME_HEADER_SIZE + n].to_vec();
                                        if data_writer_tx.send(frame).await.is_err() {
                                            log::warn!("Stream {} data channel closed, closing", stream_id);
                                            break;
                                        }
                                    }
                                    Err(_) => break,
                                }
                            }
                            _ = stream_token.cancelled() => break,
                        }
                    }

                    // Send CLOSE_BACK via DATA channel (must arrive AFTER last DATA_BACK).
                    // Use send().await to guarantee delivery (try_send silently drops if full).
                    if !stream_token.is_cancelled() {
                        let close_frame = encode_frame(stream_id, FRAME_CLOSE_BACK, &[]);
                        let _ = data_writer_tx.send(close_frame).await;
                    }

                    writer_for_edge_data.abort();
                    Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
                }
                .await;

                if let Err(e) = result {
                    log::error!("Stream {} error: {}", stream_id, e);
                    // Send CLOSE_BACK via DATA channel on error (must arrive after any DATA_BACK).
                    // Use send().await to guarantee delivery.
                    if !stream_token.is_cancelled() {
                        let close_frame = encode_frame(stream_id, FRAME_CLOSE_BACK, &[]);
                        let _ = data_writer_tx.send(close_frame).await;
                    }
                }

                // Signal main loop to remove stream from the map
                let _ = cleanup.send(stream_id).await;
                stream_counter.fetch_sub(1, Ordering::Relaxed);
            });
        }
        FRAME_DATA => {
            // Non-blocking dispatch to per-stream channel.
            // With flow control, the sender should rarely exceed the channel capacity.
            if let Some(state) = streams.get(&frame.stream_id) {
                if let Err(err) = state.data_tx.try_send(frame.payload) {
                    let cause = match err {
                        mpsc::error::TrySendError::Full(_) => "full",
                        mpsc::error::TrySendError::Closed(_) => "closed",
                    };
                    log::warn!("Stream {} data channel {}, closing stream", frame.stream_id, cause);
                    if let Some(state) = streams.remove(&frame.stream_id) {
                        state.cancel_token.cancel();
                        // A cancelled stream task skips its own CLOSE_BACK, so
                        // send it here; otherwise the edge keeps the stream open
                        // and its data is silently dropped.
                        let close_frame = encode_frame(frame.stream_id, FRAME_CLOSE_BACK, &[]);
                        tunnel_io.queue_ctrl(close_frame);
                        // The stream is already gone from the map, so the cleanup
                        // path will not emit this event.
                        let _ = event_tx.try_send(HubEvent::StreamClosed {
                            edge_id: edge_id.to_string(),
                            stream_id: frame.stream_id,
                        });
                    }
                }
            }
        }
        FRAME_WINDOW_UPDATE => {
            // Edge consumed data — increase our send window for this stream
            if let Some(increment) = decode_window_update(&frame.payload) {
                if increment > 0 {
                    if let Some(state) = streams.get(&frame.stream_id) {
                        // One atomic read-modify-write: saturating add clamped to
                        // MAX_WINDOW_SIZE. The increment comes from the peer, so a
                        // plain add could wrap, and a separate clamp store could
                        // overwrite a concurrent decrement by the reader task.
                        let _ = state.send_window.fetch_update(
                            Ordering::AcqRel,
                            Ordering::Acquire,
                            |current| Some(current.saturating_add(increment).min(MAX_WINDOW_SIZE)),
                        );
                        state.window_notify.notify_one();
                    }
                }
            }
        }
        FRAME_CLOSE => {
            if let Some(state) = streams.remove(&frame.stream_id) {
                state.cancel_token.cancel();
                let _ = event_tx.try_send(HubEvent::StreamClosed {
                    edge_id: edge_id.to_string(),
                    stream_id: frame.stream_id,
                });
            }
        }
        FRAME_PONG => {
            log::debug!("Received PONG from edge {}", edge_id);
        }
        _ => {
            log::warn!("Unexpected frame type {} from edge", frame.frame_type);
        }
    }
    FrameAction::Continue
}
/// Handle a single edge connection: authenticate, then enter frame loop.
///
/// Sequence:
/// 1. Configure the socket (no-delay, keepalive) and run the TLS handshake.
/// 2. Read the auth line `EDGE <edgeId> <secret>\n` and verify it against
///    `allowed`. Steps 1–2 together are bounded by a 10s timeout so a silent
///    peer cannot hold the task open indefinitely.
/// 3. Send the JSON handshake response, emit `EdgeConnected`, and register the
///    connection in `connected`.
/// 4. Run the frame loop (frames, PING ticker, liveness timeout, cancellation)
///    until the connection ends.
/// 5. Cancel the edge token, deregister the connection (only if the map entry
///    still belongs to this connection) and emit `EdgeDisconnected`.
///
/// # Errors
///
/// Returns an error if socket setup, the TLS handshake, authentication or the
/// handshake write fails or times out. Once the frame loop has started the
/// function returns `Ok(())`; the reason is reported in `EdgeDisconnected`.
async fn handle_edge_connection(
    stream: TcpStream,
    acceptor: TlsAcceptor,
    allowed: Arc<RwLock<HashMap<String, AllowedEdge>>>,
    connected: Arc<Mutex<HashMap<String, ConnectedEdgeInfo>>>,
    event_tx: mpsc::Sender<HubEvent>,
    target_host: String,
    edge_token: CancellationToken,
    peer_addr: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Upper bound for TLS accept plus reading the auth line.
    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

    // Disable Nagle's algorithm for low-latency control frames (PING/PONG, WINDOW_UPDATE)
    stream.set_nodelay(true)?;
    // TCP keepalive detects silent network failures (NAT timeout, path change)
    // faster than the 45s application-level liveness timeout.
    let ka = socket2::TcpKeepalive::new()
        .with_time(Duration::from_secs(30));
    #[cfg(target_os = "linux")]
    let ka = ka.with_interval(Duration::from_secs(10));
    let _ = socket2::SockRef::from(&stream).set_tcp_keepalive(&ka);

    let (mut tls_stream, auth_line) = tokio::time::timeout(HANDSHAKE_TIMEOUT, async move {
        let mut tls_stream = acceptor.accept(stream).await?;

        // Byte-by-byte auth line reading (no BufReader).
        // Auth line: "EDGE <edgeId> <secret>\n"
        let mut auth_buf = Vec::with_capacity(512);
        loop {
            let mut byte = [0u8; 1];
            tls_stream.read_exact(&mut byte).await?;
            if byte[0] == b'\n' {
                break;
            }
            auth_buf.push(byte[0]);
            if auth_buf.len() > 4096 {
                return Err(Box::<dyn std::error::Error + Send + Sync>::from(
                    "auth line too long",
                ));
            }
        }
        let auth_line = String::from_utf8(auth_buf)
            .map_err(|_| "auth line not valid UTF-8")?;
        Ok::<_, Box<dyn std::error::Error + Send + Sync>>((tls_stream, auth_line))
    })
    .await
    .map_err(|_| -> Box<dyn std::error::Error + Send + Sync> {
        format!(
            "edge handshake from {} timed out ({}s)",
            peer_addr,
            HANDSHAKE_TIMEOUT.as_secs()
        )
        .into()
    })??;

    let auth_line = auth_line.trim();
    let parts: Vec<&str> = auth_line.splitn(3, ' ').collect();
    if parts.len() != 3 || parts[0] != "EDGE" {
        return Err("invalid auth line".into());
    }
    let edge_id = parts[1].to_string();
    let secret = parts[2];

    // Verify credentials and extract edge config
    let (listen_ports, stun_interval_secs) = {
        let edges = allowed.read().await;
        match edges.get(&edge_id) {
            Some(edge) => {
                if !constant_time_eq(secret.as_bytes(), edge.secret.as_bytes()) {
                    return Err(format!("invalid secret for edge {}", edge_id).into());
                }
                (edge.listen_ports.clone(), edge.stun_interval_secs.unwrap_or(300))
            }
            None => {
                return Err(format!("unknown edge {}", edge_id).into());
            }
        }
    };
    log::info!("Edge {} authenticated from {}", edge_id, peer_addr);

    // Send handshake response with initial config before frame protocol begins
    let handshake = HandshakeResponse {
        listen_ports: listen_ports.clone(),
        stun_interval_secs,
    };
    let mut handshake_json = serde_json::to_string(&handshake)?;
    handshake_json.push('\n');
    tls_stream.write_all(handshake_json.as_bytes()).await?;
    tls_stream.flush().await?;

    // Nothing below returns early, so emitting the event here guarantees a
    // matching EdgeDisconnected.
    let _ = event_tx.try_send(HubEvent::EdgeConnected {
        edge_id: edge_id.clone(),
        peer_addr: peer_addr.clone(),
    });

    // Track this edge
    let mut streams: HashMap<u32, HubStreamState> = HashMap::new();
    // Per-edge active stream counter for adaptive flow control
    let edge_stream_count = Arc::new(AtomicU32::new(0));
    // Cleanup channel: spawned stream tasks send stream_id here when done
    let (cleanup_tx, mut cleanup_rx) = mpsc::channel::<u32>(256);
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // Create config update channel
    let (config_tx, mut config_rx) = mpsc::channel::<EdgeConfigUpdate>(16);
    {
        let mut edges = connected.lock().await;
        let previous = edges.insert(
            edge_id.clone(),
            ConnectedEdgeInfo {
                connected_at: now,
                peer_addr,
                edge_stream_count: edge_stream_count.clone(),
                config_tx,
                cancel_token: edge_token.clone(),
            },
        );
        if previous.is_some() {
            log::warn!(
                "Edge {} registered while a previous connection with the same id was still registered",
                edge_id
            );
        }
    }

    // QoS dual-channel: ctrl frames have priority over data frames.
    // Stream handlers send through these channels -> TunnelIo drains them.
    let (ctrl_tx, mut ctrl_rx) = mpsc::channel::<Vec<u8>>(256);
    let (data_tx, mut data_rx) = mpsc::channel::<Vec<u8>>(4096);

    // Spawn task to forward config updates as FRAME_CONFIG frames
    let config_writer_tx = ctrl_tx.clone();
    let config_edge_id = edge_id.clone();
    let config_token = edge_token.clone();
    let config_handle = tokio::spawn(async move {
        loop {
            tokio::select! {
                update = config_rx.recv() => {
                    match update {
                        Some(update) => {
                            if let Ok(payload) = serde_json::to_vec(&update) {
                                let frame = encode_frame(0, FRAME_CONFIG, &payload);
                                if config_writer_tx.send(frame).await.is_err() {
                                    log::error!("Failed to send config update to edge {}", config_edge_id);
                                    break;
                                }
                                log::info!("Sent config update to edge {}: ports {:?}", config_edge_id, update.listen_ports);
                            }
                        }
                        None => break,
                    }
                }
                _ = config_token.cancelled() => break,
            }
        }
    });

    // A4: Semaphore to limit concurrent streams per edge
    let stream_semaphore = Arc::new(Semaphore::new(MAX_STREAMS_PER_EDGE));

    // Heartbeat: periodic PING and liveness timeout
    let ping_interval_dur = Duration::from_secs(15);
    let liveness_timeout_dur = Duration::from_secs(45);
    let mut ping_ticker = interval(ping_interval_dur);
    ping_ticker.tick().await; // consume the immediate first tick
    let mut last_activity = Instant::now();
    let mut liveness_deadline = Box::pin(sleep_until(last_activity + liveness_timeout_dur));

    // Single-owner I/O engine — no tokio::io::split, no mutex
    let mut tunnel_io = remoteingress_protocol::TunnelIo::new(tls_stream, Vec::new());
    let mut disconnect_reason = "unknown".to_string();

    'hub_loop: loop {
        // Drain completed stream cleanups from spawned tasks
        while let Ok(stream_id) = cleanup_rx.try_recv() {
            if streams.remove(&stream_id).is_some() {
                let _ = event_tx.try_send(HubEvent::StreamClosed {
                    edge_id: edge_id.clone(),
                    stream_id,
                });
            }
        }

        // Drain any buffered frames
        loop {
            let frame = match tunnel_io.try_parse_frame() {
                Some(Ok(f)) => f,
                Some(Err(e)) => {
                    log::error!("Edge {} frame error: {}", edge_id, e);
                    disconnect_reason = format!("edge_frame_error: {}", e);
                    break 'hub_loop;
                }
                None => break,
            };
            last_activity = Instant::now();
            liveness_deadline.as_mut().reset(last_activity + liveness_timeout_dur);
            if let FrameAction::Disconnect(reason) = handle_hub_frame(
                frame, &mut tunnel_io, &mut streams, &stream_semaphore, &edge_stream_count,
                &edge_id, &event_tx, &ctrl_tx, &data_tx, &target_host, &edge_token,
                &cleanup_tx,
            ).await {
                disconnect_reason = reason;
                break 'hub_loop;
            }
        }

        // Poll I/O: write(ctrl->data), flush, read, channels, timers
        let event = std::future::poll_fn(|cx| {
            // Queue PING if ticker fires
            if ping_ticker.poll_tick(cx).is_ready() {
                tunnel_io.queue_ctrl(encode_frame(0, FRAME_PING, &[]));
            }
            tunnel_io.poll_step(cx, &mut ctrl_rx, &mut data_rx, &mut liveness_deadline, &edge_token)
        }).await;

        match event {
            remoteingress_protocol::TunnelEvent::Frame(frame) => {
                last_activity = Instant::now();
                liveness_deadline.as_mut().reset(last_activity + liveness_timeout_dur);
                if let FrameAction::Disconnect(reason) = handle_hub_frame(
                    frame, &mut tunnel_io, &mut streams, &stream_semaphore, &edge_stream_count,
                    &edge_id, &event_tx, &ctrl_tx, &data_tx, &target_host, &edge_token,
                    &cleanup_tx,
                ).await {
                    disconnect_reason = reason;
                    break;
                }
            }
            remoteingress_protocol::TunnelEvent::Eof => {
                log::info!("Edge {} disconnected (EOF)", edge_id);
                disconnect_reason = "edge_eof".to_string();
                break;
            }
            remoteingress_protocol::TunnelEvent::ReadError(e) => {
                log::error!("Edge {} frame error: {}", edge_id, e);
                disconnect_reason = format!("edge_frame_error: {}", e);
                break;
            }
            remoteingress_protocol::TunnelEvent::WriteError(e) => {
                log::error!("Tunnel write error to edge {}: {}", edge_id, e);
                disconnect_reason = format!("tunnel_write_error: {}", e);
                break;
            }
            remoteingress_protocol::TunnelEvent::LivenessTimeout => {
                log::warn!("Edge {} liveness timeout (no frames for {}s), disconnecting",
                    edge_id, liveness_timeout_dur.as_secs());
                disconnect_reason = "liveness_timeout".to_string();
                break;
            }
            remoteingress_protocol::TunnelEvent::Cancelled => {
                log::info!("Edge {} cancelled by hub", edge_id);
                disconnect_reason = "cancelled_by_hub".to_string();
                break;
            }
        }
    }

    // Cleanup: cancel edge token to propagate to all child tasks
    edge_token.cancel();
    config_handle.abort();
    {
        let mut edges = connected.lock().await;
        // Remove the entry only if it is still ours. If the same edge id has
        // reconnected meanwhile, the map now holds the newer connection (with a
        // different counter allocation) and must be left in place.
        let still_ours = edges
            .get(&edge_id)
            .map_or(false, |info| Arc::ptr_eq(&info.edge_stream_count, &edge_stream_count));
        if still_ours {
            edges.remove(&edge_id);
        }
    }
    let _ = event_tx.try_send(HubEvent::EdgeDisconnected {
        edge_id: edge_id.clone(),
        reason: disconnect_reason,
    });
    Ok(())
}
/// Extract the destination port from a PROXY v1 header line.
///
/// The header is split on whitespace and the sixth field is parsed as a `u16`.
/// Returns `None` when there are fewer than six fields or the field is not a
/// valid port number.
fn parse_dest_port_from_proxy(header: &str) -> Option<u16> {
    // Fields: PROXY <proto> <src-ip> <dst-ip> <src-port> <dst-port>
    let dest_port_field = header.split_whitespace().nth(5)?;
    dest_port_field.parse::<u16>().ok()
}
/// Build the rustls server configuration for the tunnel listener.
///
/// Uses the PEM certificate and key from `config` when both are present;
/// otherwise a self-signed certificate for the name `remoteingress-hub` is
/// generated. Client certificates are not requested, and the ALPN protocol
/// list is set to `remoteingress`.
///
/// # Errors
///
/// Propagates failures from certificate generation, PEM parsing, or rustls
/// rejecting the certificate/key pair.
fn build_tls_config(
    config: &HubConfig,
) -> Result<rustls::ServerConfig, Box<dyn std::error::Error + Send + Sync>> {
    let (cert_pem, key_pem) =
        if let (Some(cert), Some(key)) = (&config.tls_cert_pem, &config.tls_key_pem) {
            (cert.clone(), key.clone())
        } else {
            // Generate self-signed certificate
            let generated =
                rcgen::generate_simple_self_signed(vec!["remoteingress-hub".to_string()])?;
            (generated.cert.pem(), generated.key_pair.serialize_pem())
        };

    let cert_chain = rustls_pemfile_parse_certs(&cert_pem)?;
    let private_key = rustls_pemfile_parse_key(&key_pem)?;

    let mut server_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, private_key)?;
    server_config.alpn_protocols = vec![b"remoteingress".to_vec()];
    Ok(server_config)
}
/// Parse every certificate found in a PEM string into DER form.
///
/// # Errors
///
/// Returns the first I/O error reported while decoding the PEM input.
fn rustls_pemfile_parse_certs(
    pem: &str,
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>>
{
    let mut cursor = std::io::Cursor::new(pem.as_bytes());
    let mut chain = Vec::new();
    for entry in rustls_pemfile::certs(&mut cursor) {
        // Stop at the first malformed entry, as a fallible collect would.
        chain.push(entry?);
    }
    Ok(chain)
}
/// Parse the first private key found in a PEM string into DER form.
///
/// # Errors
///
/// Returns an error if decoding fails or the input contains no private key.
fn rustls_pemfile_parse_key(
    pem: &str,
) -> Result<rustls::pki_types::PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
    let mut cursor = std::io::Cursor::new(pem.as_bytes());
    match rustls_pemfile::private_key(&mut cursor)? {
        Some(private_key) => Ok(private_key),
        None => Err("no private key found in PEM".into()),
    }
}
/// Compare two byte slices without short-circuiting on the first mismatch.
///
/// Slices of different length compare unequal immediately. For equal lengths,
/// every byte pair is XORed and OR-accumulated, so all bytes are always visited.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let accumulated = a
        .iter()
        .zip(b)
        .fold(0u8, |bits, (lhs, rhs)| bits | (lhs ^ rhs));
    accumulated == 0
}
// Unit tests for the pure helpers, the serde wire formats, and the parts of
// `TunnelHub` that work without opening a listener.
#[cfg(test)]
mod tests {
use super::*;
// --- constant_time_eq tests ---
#[test]
fn test_constant_time_eq_equal() {
assert!(constant_time_eq(b"hello", b"hello"));
}
#[test]
fn test_constant_time_eq_different_content() {
assert!(!constant_time_eq(b"hello", b"world"));
}
#[test]
fn test_constant_time_eq_different_lengths() {
assert!(!constant_time_eq(b"short", b"longer"));
}
#[test]
fn test_constant_time_eq_both_empty() {
assert!(constant_time_eq(b"", b""));
}
#[test]
fn test_constant_time_eq_one_empty() {
assert!(!constant_time_eq(b"", b"notempty"));
}
#[test]
fn test_constant_time_eq_single_bit_difference() {
// 'A' = 0x41, 'a' = 0x61 — differ by one bit
assert!(!constant_time_eq(b"A", b"a"));
}
// --- parse_dest_port_from_proxy tests ---
#[test]
fn test_parse_dest_port_443() {
let header = "PROXY TCP4 1.2.3.4 5.6.7.8 12345 443\r\n";
assert_eq!(parse_dest_port_from_proxy(header), Some(443));
}
#[test]
fn test_parse_dest_port_80() {
let header = "PROXY TCP4 10.0.0.1 10.0.0.2 54321 80\r\n";
assert_eq!(parse_dest_port_from_proxy(header), Some(80));
}
#[test]
fn test_parse_dest_port_65535() {
// Largest value that fits in a u16.
let header = "PROXY TCP4 10.0.0.1 10.0.0.2 1 65535\r\n";
assert_eq!(parse_dest_port_from_proxy(header), Some(65535));
}
#[test]
fn test_parse_dest_port_too_few_fields() {
let header = "PROXY TCP4 1.2.3.4";
assert_eq!(parse_dest_port_from_proxy(header), None);
}
#[test]
fn test_parse_dest_port_empty_string() {
assert_eq!(parse_dest_port_from_proxy(""), None);
}
#[test]
fn test_parse_dest_port_non_numeric() {
let header = "PROXY TCP4 1.2.3.4 5.6.7.8 12345 abc\r\n";
assert_eq!(parse_dest_port_from_proxy(header), None);
}
// --- Serde tests ---
#[test]
fn test_allowed_edge_deserialize_all_fields() {
let json = r#"{
"id": "edge-1",
"secret": "s3cret",
"listenPorts": [443, 8080],
"stunIntervalSecs": 120
}"#;
let edge: AllowedEdge = serde_json::from_str(json).unwrap();
assert_eq!(edge.id, "edge-1");
assert_eq!(edge.secret, "s3cret");
assert_eq!(edge.listen_ports, vec![443, 8080]);
assert_eq!(edge.stun_interval_secs, Some(120));
}
#[test]
fn test_allowed_edge_deserialize_with_defaults() {
// Only the required keys are present; the optional ones fall back to
// an empty port list and `None`.
let json = r#"{"id": "edge-2", "secret": "key"}"#;
let edge: AllowedEdge = serde_json::from_str(json).unwrap();
assert_eq!(edge.id, "edge-2");
assert_eq!(edge.secret, "key");
assert!(edge.listen_ports.is_empty());
assert_eq!(edge.stun_interval_secs, None);
}
#[test]
fn test_handshake_response_serializes_camel_case() {
let resp = HandshakeResponse {
listen_ports: vec![443, 8080],
stun_interval_secs: 300,
};
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["listenPorts"], serde_json::json!([443, 8080]));
assert_eq!(json["stunIntervalSecs"], 300);
// Ensure snake_case keys are NOT present
assert!(json.get("listen_ports").is_none());
assert!(json.get("stun_interval_secs").is_none());
}
#[test]
fn test_edge_config_update_serializes_camel_case() {
let update = EdgeConfigUpdate {
listen_ports: vec![80, 443],
};
let json = serde_json::to_value(&update).unwrap();
assert_eq!(json["listenPorts"], serde_json::json!([80, 443]));
assert!(json.get("listen_ports").is_none());
}
#[test]
fn test_hub_config_default() {
let config = HubConfig::default();
assert_eq!(config.tunnel_port, 8443);
assert_eq!(config.target_host, Some("127.0.0.1".to_string()));
assert!(config.tls_cert_pem.is_none());
assert!(config.tls_key_pem.is_none());
}
// The HubEvent tests below pin the internally tagged wire format:
// a camelCase "type" discriminator plus camelCase field keys.
#[test]
fn test_hub_event_edge_connected_serialize() {
let event = HubEvent::EdgeConnected {
edge_id: "edge-1".to_string(),
peer_addr: "203.0.113.5".to_string(),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "edgeConnected");
assert_eq!(json["edgeId"], "edge-1");
assert_eq!(json["peerAddr"], "203.0.113.5");
}
#[test]
fn test_hub_event_edge_disconnected_serialize() {
let event = HubEvent::EdgeDisconnected {
edge_id: "edge-2".to_string(),
reason: "liveness_timeout".to_string(),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "edgeDisconnected");
assert_eq!(json["edgeId"], "edge-2");
assert_eq!(json["reason"], "liveness_timeout");
}
#[test]
fn test_hub_event_stream_opened_serialize() {
let event = HubEvent::StreamOpened {
edge_id: "e".to_string(),
stream_id: 42,
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "streamOpened");
assert_eq!(json["edgeId"], "e");
assert_eq!(json["streamId"], 42);
}
#[test]
fn test_hub_event_stream_closed_serialize() {
let event = HubEvent::StreamClosed {
edge_id: "e".to_string(),
stream_id: 7,
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "streamClosed");
assert_eq!(json["edgeId"], "e");
assert_eq!(json["streamId"], 7);
}
// --- Async tests ---
#[tokio::test]
async fn test_tunnel_hub_new_get_status() {
// A freshly constructed hub is stopped and has no edges.
let hub = TunnelHub::new(HubConfig::default());
let status = hub.get_status().await;
assert!(!status.running);
assert!(status.connected_edges.is_empty());
assert_eq!(status.tunnel_port, 8443);
}
#[tokio::test]
async fn test_tunnel_hub_take_event_rx() {
// The receiver can be taken exactly once.
let hub = TunnelHub::new(HubConfig::default());
let rx1 = hub.take_event_rx().await;
assert!(rx1.is_some());
let rx2 = hub.take_event_rx().await;
assert!(rx2.is_none());
}
#[tokio::test]
async fn test_tunnel_hub_stop_without_start() {
let hub = TunnelHub::new(HubConfig::default());
hub.stop().await; // should not panic
let status = hub.get_status().await;
assert!(!status.running);
}
}