Files
remoteingress/rust/crates/remoteingress-core/src/edge.rs

2014 lines
82 KiB
Rust

use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{mpsc, Mutex, Notify, RwLock};
use tokio::task::JoinHandle;
use tokio::time::{Instant, sleep_until};
use tokio_rustls::TlsConnector;
use tokio_util::sync::CancellationToken;
use serde::{Deserialize, Serialize};
use bytes::Bytes;
use remoteingress_protocol::*;
use crate::transport::TransportMode;
use crate::transport::quic as quic_transport;
use crate::udp_session::{UdpSessionKey, UdpSessionManager};
/// Concrete stream type used by the TCP+TLS transport: a rustls client session over TCP.
type EdgeTlsStream = tokio_rustls::client::TlsStream<TcpStream>;
/// Result of processing a frame (shared with hub.rs pattern).
// `handle_edge_frame` in this file always returns `Continue`, so `Disconnect` may never be
// constructed and would trigger a dead-code warning; the allow keeps the variant available.
// NOTE(review): confirm no other transport path constructs `Disconnect` before removing it.
#[allow(dead_code)]
enum EdgeFrameAction {
    /// Keep processing frames on the current tunnel.
    Continue,
    /// Tear down the tunnel; the string is used as the reconnect reason.
    Disconnect(String),
}
/// Per-stream state tracked in the edge's client_writers map.
///
/// An entry is inserted when a client connection opens a stream and is removed on
/// FRAME_CLOSE_BACK, or when a FRAME_DATA_BACK finds that the receiving side of
/// `back_tx` has already gone away.
struct EdgeStreamState {
    /// Unbounded channel to deliver FRAME_DATA_BACK payloads to the hub_to_client task.
    /// Unbounded because flow control (WINDOW_UPDATE) already limits bytes-in-flight.
    back_tx: mpsc::UnboundedSender<Bytes>,
    /// Send window for FRAME_DATA (upload direction).
    /// Decremented by the client reader, incremented by FRAME_WINDOW_UPDATE_BACK from hub.
    send_window: Arc<AtomicU32>,
    /// Notifier to wake the client reader when the window opens.
    window_notify: Arc<Notify>,
}
/// Edge configuration (hub-host + credentials only; ports come from hub).
///
/// Deserialized from / serialized to camelCase JSON (`hubHost`, `hubPort`, `edgeId`,
/// `secret`, `bindAddress`, `transportMode`).
///
/// `Debug` is implemented by hand so that formatting the config (e.g. in a log line)
/// never reveals `secret`.
#[derive(Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EdgeConfig {
    /// Hub hostname or IP. The TCP+TLS transport also uses it as the TLS server name,
    /// falling back to "remoteingress-hub" when it is not a valid server name.
    pub hub_host: String,
    /// Hub port to connect to.
    pub hub_port: u16,
    /// Identifier sent to the hub in the `EDGE <edge_id> <secret>` auth line.
    pub edge_id: String,
    /// Shared secret sent in the auth line. Redacted in `Debug` output.
    pub secret: String,
    /// Optional bind address for TCP listeners (defaults to "0.0.0.0").
    /// Useful for testing on localhost where edge and upstream share the same machine.
    #[serde(default)]
    pub bind_address: Option<String>,
    /// Transport mode for the tunnel connection (defaults to TcpTls).
    #[serde(default)]
    pub transport_mode: Option<TransportMode>,
}

impl std::fmt::Debug for EdgeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EdgeConfig")
            .field("hub_host", &self.hub_host)
            .field("hub_port", &self.hub_port)
            .field("edge_id", &self.edge_id)
            // Credential: never print the real value.
            .field("secret", &"<redacted>")
            .field("bind_address", &self.bind_address)
            .field("transport_mode", &self.transport_mode)
            .finish()
    }
}
/// Handshake config received from hub after authentication.
///
/// On the TCP+TLS transport this is parsed from the single newline-terminated JSON line
/// the hub sends in reply to the `EDGE <edge_id> <secret>` auth line (camelCase keys).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HandshakeConfig {
    /// TCP ports the edge should listen on (`listenPorts`, required).
    listen_ports: Vec<u16>,
    /// UDP ports the edge should listen on (`listenPortsUdp`; empty when absent).
    #[serde(default)]
    listen_ports_udp: Vec<u16>,
    /// Seconds to wait between STUN public-IP discoveries (`stunIntervalSecs`;
    /// `default_stun_interval()` when absent).
    #[serde(default = "default_stun_interval")]
    stun_interval_secs: u64,
}
/// Serde default for `HandshakeConfig::stun_interval_secs`: five minutes between STUN
/// discoveries when the hub's handshake does not specify an interval.
fn default_stun_interval() -> u64 {
    const FIVE_MINUTES_SECS: u64 = 5 * 60;
    FIVE_MINUTES_SECS
}
/// Runtime config update received from hub via FRAME_CONFIG.
///
/// Carries the complete desired port sets (camelCase JSON). Listeners for ports missing
/// from these lists are stopped, and listeners are started for newly listed ports.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConfigUpdate {
    /// Desired TCP listen ports (`listenPorts`, required).
    listen_ports: Vec<u16>,
    /// Desired UDP listen ports (`listenPortsUdp`; empty when absent).
    #[serde(default)]
    listen_ports_udp: Vec<u16>,
}
/// Events emitted by the edge.
///
/// Serialized as internally tagged JSON: `{"type": "<camelCaseVariant>", ...fields}`, with
/// field names in camelCase (e.g. `{"type":"portsAssigned","listenPorts":[80]}`).
/// Events are delivered with `try_send`, so they are dropped when the event channel is full.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum EdgeEvent {
    /// The hub accepted the auth line and sent a valid handshake.
    TunnelConnected,
    /// A tunnel that had reached the connected state ended; `reason` describes why
    /// (e.g. "hub_eof", "liveness_timeout", "shutdown").
    #[serde(rename_all = "camelCase")]
    TunnelDisconnected { reason: String },
    /// STUN discovery reported a public IP different from the previously stored one.
    #[serde(rename_all = "camelCase")]
    PublicIpDiscovered { ip: String },
    /// Initial TCP listen ports received in the hub handshake.
    #[serde(rename_all = "camelCase")]
    PortsAssigned { listen_ports: Vec<u16> },
    /// TCP listen ports changed by a FRAME_CONFIG update from the hub.
    #[serde(rename_all = "camelCase")]
    PortsUpdated { listen_ports: Vec<u16> },
}
/// Edge status response.
///
/// Point-in-time snapshot produced by `TunnelEdge::get_status` (camelCase when serialized).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EdgeStatus {
    /// `start()` has been called and `stop()` has not.
    pub running: bool,
    /// A tunnel to the hub is currently established (handshake completed).
    pub connected: bool,
    /// Most recent public IP reported by STUN discovery, if any.
    pub public_ip: Option<String>,
    /// Client streams currently being served; reset to 0 between tunnel connections.
    pub active_streams: usize,
    /// TCP ports assigned by the hub for the current connection; empty when disconnected.
    pub listen_ports: Vec<u16>,
}
/// The tunnel edge that listens for client connections and multiplexes them to the hub.
///
/// State shared with the background main loop is held behind `Arc`s so it can be read by
/// `get_status()` while the loop updates it.
pub struct TunnelEdge {
    /// Configuration snapshot cloned into the main loop by `start()`.
    config: RwLock<EdgeConfig>,
    /// Producer side of the event channel (capacity 1024).
    event_tx: mpsc::Sender<EdgeEvent>,
    /// Consumer side of the event channel, handed out once by `take_event_rx()`.
    event_rx: Mutex<Option<mpsc::Receiver<EdgeEvent>>>,
    /// Shutdown signal sender created by `start()` and consumed by `stop()`.
    shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
    /// Set by `start()`, cleared by `stop()`.
    running: RwLock<bool>,
    /// True while a hub tunnel is established.
    connected: Arc<RwLock<bool>>,
    /// Last public IP discovered via STUN.
    public_ip: Arc<RwLock<Option<String>>>,
    /// Number of client streams currently being served.
    active_streams: Arc<AtomicU32>,
    /// Next stream id to allocate; starts at 1 and is reset per connection cycle.
    next_stream_id: Arc<AtomicU32>,
    /// TCP ports currently assigned by the hub.
    listen_ports: Arc<RwLock<Vec<u16>>>,
    /// Root cancellation token; per-connection tokens are children of it. Cancelled by
    /// `stop()` and on drop.
    cancel_token: CancellationToken,
}
impl TunnelEdge {
pub fn new(config: EdgeConfig) -> Self {
let (event_tx, event_rx) = mpsc::channel(1024);
Self {
config: RwLock::new(config),
event_tx,
event_rx: Mutex::new(Some(event_rx)),
shutdown_tx: Mutex::new(None),
running: RwLock::new(false),
connected: Arc::new(RwLock::new(false)),
public_ip: Arc::new(RwLock::new(None)),
active_streams: Arc::new(AtomicU32::new(0)),
next_stream_id: Arc::new(AtomicU32::new(1)),
listen_ports: Arc::new(RwLock::new(Vec::new())),
cancel_token: CancellationToken::new(),
}
}
/// Take the event receiver (can only be called once).
pub async fn take_event_rx(&self) -> Option<mpsc::Receiver<EdgeEvent>> {
self.event_rx.lock().await.take()
}
/// Get the current edge status.
pub async fn get_status(&self) -> EdgeStatus {
EdgeStatus {
running: *self.running.read().await,
connected: *self.connected.read().await,
public_ip: self.public_ip.read().await.clone(),
active_streams: self.active_streams.load(Ordering::Relaxed) as usize,
listen_ports: self.listen_ports.read().await.clone(),
}
}
/// Start the edge: connect to hub, start listeners.
pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let config = self.config.read().await.clone();
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
*self.shutdown_tx.lock().await = Some(shutdown_tx);
*self.running.write().await = true;
let connected = self.connected.clone();
let public_ip = self.public_ip.clone();
let active_streams = self.active_streams.clone();
let next_stream_id = self.next_stream_id.clone();
let event_tx = self.event_tx.clone();
let listen_ports = self.listen_ports.clone();
let cancel_token = self.cancel_token.clone();
tokio::spawn(async move {
edge_main_loop(
config,
connected,
public_ip,
active_streams,
next_stream_id,
event_tx,
listen_ports,
shutdown_rx,
cancel_token,
)
.await;
});
Ok(())
}
/// Stop the edge.
pub async fn stop(&self) {
self.cancel_token.cancel();
if let Some(tx) = self.shutdown_tx.lock().await.take() {
let _ = tx.send(()).await;
}
*self.running.write().await = false;
*self.connected.write().await = false;
self.listen_ports.write().await.clear();
}
}
impl Drop for TunnelEdge {
    fn drop(&mut self) {
        // Cancel the root token so the background main loop spawned by `start()`, and every
        // per-connection task whose token is a child of it, is told to stop even when
        // `stop()` was never called.
        self.cancel_token.cancel();
    }
}
/// Supervisor loop that keeps a tunnel to the hub alive until shutdown.
///
/// Each iteration connects with the configured `TransportMode` and runs the tunnel until
/// it ends. It then resets per-connection shared state (`connected`, `active_streams`,
/// `next_stream_id`, `listen_ports`). Finally it either exits (shutdown signal or
/// cancellation) or waits an exponential backoff before reconnecting. The backoff starts
/// at 1s, doubles up to 30s, and resets to 1s after a connection that reached the
/// connected state.
async fn edge_main_loop(
    config: EdgeConfig,
    connected: Arc<RwLock<bool>>,
    public_ip: Arc<RwLock<Option<String>>>,
    active_streams: Arc<AtomicU32>,
    next_stream_id: Arc<AtomicU32>,
    event_tx: mpsc::Sender<EdgeEvent>,
    listen_ports: Arc<RwLock<Vec<u16>>>,
    mut shutdown_rx: mpsc::Receiver<()>,
    cancel_token: CancellationToken,
) {
    const INITIAL_BACKOFF_MS: u64 = 1000;
    const MAX_BACKOFF_MS: u64 = 30_000;
    // How long QuicWithFallback waits for the QUIC handshake before using TCP+TLS.
    const QUIC_FALLBACK_TIMEOUT: Duration = Duration::from_secs(5);

    let mut backoff_ms = INITIAL_BACKOFF_MS;
    let transport_mode = config.transport_mode.unwrap_or(TransportMode::TcpTls);
    // Build TLS config ONCE outside the reconnect loop — preserves session
    // cache across reconnections for TLS session resumption (saves 1 RTT).
    let tls_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(NoCertVerifier))
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(tls_config));
    // Build QUIC client config ONCE (shares session cache across reconnections).
    let quic_client_config = quic_transport::build_quic_client_config();
    let quic_endpoint =
        if matches!(transport_mode, TransportMode::Quic | TransportMode::QuicWithFallback) {
            // Ephemeral local UDP port on all IPv4 interfaces.
            match quinn::Endpoint::client(std::net::SocketAddr::from(([0, 0, 0, 0], 0))) {
                Ok(mut ep) => {
                    ep.set_default_client_config(quic_client_config);
                    Some(ep)
                }
                Err(e) => {
                    log::error!("Failed to create QUIC endpoint: {}", e);
                    None
                }
            }
        } else {
            None
        };
    loop {
        // Per-connection child token: cancelling it tears down every task of this cycle.
        let connection_token = cancel_token.child_token();
        let result = match transport_mode {
            TransportMode::TcpTls => {
                connect_to_hub_and_run(
                    &config, &connected, &public_ip, &active_streams, &next_stream_id,
                    &event_tx, &listen_ports, &mut shutdown_rx, &connection_token, &connector,
                ).await
            }
            TransportMode::Quic => {
                if let Some(ep) = &quic_endpoint {
                    connect_to_hub_and_run_quic(
                        &config, &connected, &public_ip, &active_streams, &next_stream_id,
                        &event_tx, &listen_ports, &mut shutdown_rx, &connection_token, ep,
                    ).await
                } else {
                    EdgeLoopResult::Reconnect("quic_endpoint_unavailable".to_string())
                }
            }
            TransportMode::QuicWithFallback => {
                // Try QUIC first; any failure (or a missing endpoint) falls back to TCP+TLS.
                let quic_conn = match &quic_endpoint {
                    Some(ep) => {
                        match tokio::time::timeout(
                            QUIC_FALLBACK_TIMEOUT,
                            connect_to_hub_quic_handshake(&config, ep, &connection_token),
                        ).await {
                            Ok(Ok(conn)) => Some(conn),
                            Ok(Err(_)) => {
                                log::info!("QUIC handshake failed, falling back to TCP+TLS");
                                None
                            }
                            Err(_) => {
                                log::info!(
                                    "QUIC handshake timed out after {}s, falling back to TCP+TLS",
                                    QUIC_FALLBACK_TIMEOUT.as_secs()
                                );
                                None
                            }
                        }
                    }
                    None => None,
                };
                match quic_conn {
                    Some(quic_conn) => {
                        connect_to_hub_and_run_quic_with_connection(
                            &config, &connected, &public_ip, &active_streams, &next_stream_id,
                            &event_tx, &listen_ports, &mut shutdown_rx, &connection_token,
                            quic_conn,
                        ).await
                    }
                    None => {
                        connect_to_hub_and_run(
                            &config, &connected, &public_ip, &active_streams, &next_stream_id,
                            &event_tx, &listen_ports, &mut shutdown_rx, &connection_token,
                            &connector,
                        ).await
                    }
                }
            }
        };
        // Cancel connection token to kill all orphaned tasks from this cycle
        connection_token.cancel();
        // Reset backoff after a successful connection for fast reconnect
        let was_connected = *connected.read().await;
        if was_connected {
            backoff_ms = INITIAL_BACKOFF_MS;
            log::info!("Was connected; resetting backoff to {}ms for fast reconnect", backoff_ms);
        }
        *connected.write().await = false;
        let reason = match &result {
            EdgeLoopResult::Reconnect(r) => r.clone(),
            EdgeLoopResult::Shutdown => "shutdown".to_string(),
        };
        // `connected` only becomes true after a successful hub handshake, so failed
        // reconnect attempts leave `was_connected` false and emit no disconnect event.
        if was_connected {
            let _ = event_tx.try_send(EdgeEvent::TunnelDisconnected { reason: reason.clone() });
        }
        active_streams.store(0, Ordering::Relaxed);
        // Reset stream ID counter for next connection cycle
        next_stream_id.store(1, Ordering::Relaxed);
        listen_ports.write().await.clear();
        match result {
            EdgeLoopResult::Shutdown => break,
            EdgeLoopResult::Reconnect(_) => {
                log::info!("Reconnecting in {}ms...", backoff_ms);
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {}
                    _ = cancel_token.cancelled() => break,
                    _ = shutdown_rx.recv() => break,
                }
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
        }
    }
}
/// Outcome of one tunnel connection attempt, consumed by `edge_main_loop`.
enum EdgeLoopResult {
    /// Stop the supervisor loop entirely (cancellation or stop signal).
    Shutdown,
    /// Wait for the backoff and connect again; the string is the disconnect reason
    /// reported in `EdgeEvent::TunnelDisconnected`.
    Reconnect(String), // reason for disconnection
}
/// Process a single frame received from the hub side of the tunnel.
///
/// Handles FRAME_DATA_BACK, FRAME_WINDOW_UPDATE_BACK, FRAME_CLOSE_BACK, FRAME_CONFIG,
/// FRAME_PING, and the UDP frames FRAME_UDP_DATA_BACK and FRAME_UDP_CLOSE. Any other frame
/// type is logged and ignored. Currently always returns `EdgeFrameAction::Continue`.
async fn handle_edge_frame(
    frame: Frame,
    tunnel_io: &mut remoteingress_protocol::TunnelIo<EdgeTlsStream>,
    client_writers: &Arc<Mutex<HashMap<u32, EdgeStreamState>>>,
    listen_ports: &Arc<RwLock<Vec<u16>>>,
    event_tx: &mpsc::Sender<EdgeEvent>,
    tunnel_writer_tx: &mpsc::Sender<Bytes>,
    tunnel_data_tx: &mpsc::Sender<Bytes>,
    tunnel_sustained_tx: &mpsc::Sender<Bytes>,
    port_listeners: &mut HashMap<u16, JoinHandle<()>>,
    udp_listeners: &mut HashMap<u16, JoinHandle<()>>,
    active_streams: &Arc<AtomicU32>,
    next_stream_id: &Arc<AtomicU32>,
    edge_id: &str,
    connection_token: &CancellationToken,
    bind_address: &str,
    udp_sessions: &Arc<Mutex<UdpSessionManager>>,
    udp_sockets: &Arc<Mutex<HashMap<u16, Arc<UdpSocket>>>>,
) -> EdgeFrameAction {
    match frame.frame_type {
        FRAME_DATA_BACK => {
            // Dispatch to per-stream unbounded channel. Flow control (WINDOW_UPDATE)
            // limits bytes-in-flight, so the channel won't grow unbounded. send() only
            // fails if the receiver is dropped (hub_to_client task already exited).
            let mut writers = client_writers.lock().await;
            if let Some(state) = writers.get(&frame.stream_id) {
                if state.back_tx.send(frame.payload).is_err() {
                    // Receiver dropped — hub_to_client task already exited, clean up
                    writers.remove(&frame.stream_id);
                }
            }
        }
        FRAME_WINDOW_UPDATE_BACK => {
            if let Some(increment) = decode_window_update(&frame.payload) {
                if increment > 0 {
                    let writers = client_writers.lock().await;
                    if let Some(state) = writers.get(&frame.stream_id) {
                        // Add and clamp in a single atomic read-modify-write. A separate
                        // fetch_add + store could overwrite a concurrent decrement by the
                        // client reader. `prev + increment` could also overflow u32
                        // (panic in debug, wrap-around in release).
                        let _ = state.send_window.fetch_update(
                            Ordering::AcqRel,
                            Ordering::Acquire,
                            |window| Some(window.saturating_add(increment).min(MAX_WINDOW_SIZE)),
                        );
                        state.window_notify.notify_one();
                    }
                }
            }
        }
        FRAME_CLOSE_BACK => {
            let mut writers = client_writers.lock().await;
            writers.remove(&frame.stream_id);
        }
        FRAME_CONFIG => match serde_json::from_slice::<ConfigUpdate>(&frame.payload) {
            Ok(update) => {
                log::info!(
                    "Config update from hub: ports {:?}, udp {:?}",
                    update.listen_ports,
                    update.listen_ports_udp
                );
                *listen_ports.write().await = update.listen_ports.clone();
                let _ = event_tx.try_send(EdgeEvent::PortsUpdated {
                    listen_ports: update.listen_ports.clone(),
                });
                apply_port_config(
                    &update.listen_ports,
                    port_listeners,
                    tunnel_writer_tx,
                    tunnel_data_tx,
                    tunnel_sustained_tx,
                    client_writers,
                    active_streams,
                    next_stream_id,
                    edge_id,
                    connection_token,
                    bind_address,
                );
                apply_udp_port_config(
                    &update.listen_ports_udp,
                    udp_listeners,
                    tunnel_writer_tx,
                    tunnel_data_tx,
                    udp_sessions,
                    udp_sockets,
                    next_stream_id,
                    connection_token,
                    bind_address,
                );
            }
            Err(e) => {
                // Keep the current configuration, but make the rejected update visible.
                log::warn!("Ignoring malformed config update from hub: {}", e);
            }
        },
        FRAME_PING => {
            // Queue PONG directly — no channel round-trip, guaranteed delivery
            tunnel_io.queue_ctrl(encode_frame(0, FRAME_PONG, &[]));
        }
        FRAME_UDP_DATA_BACK => {
            // Look up the original client, then release both locks before the send
            // `.await`. Holding them there would stall the UDP listener tasks, which lock
            // the session map on every received datagram.
            let target = {
                let mut sessions = udp_sessions.lock().await;
                sessions
                    .get_by_stream_id(frame.stream_id)
                    .map(|session| (session.client_addr, session.dest_port))
            };
            if let Some((client_addr, dest_port)) = target {
                let socket = udp_sockets.lock().await.get(&dest_port).cloned();
                if let Some(socket) = socket {
                    let _ = socket.send_to(&frame.payload, client_addr).await;
                }
            }
        }
        FRAME_UDP_CLOSE => {
            let mut sessions = udp_sessions.lock().await;
            sessions.remove_by_stream_id(frame.stream_id);
        }
        _ => {
            log::warn!("Unexpected frame type {} from hub", frame.frame_type);
        }
    }
    EdgeFrameAction::Continue
}
/// Run one TCP+TLS tunnel session against the hub.
///
/// 1. Connect, TLS-handshake, authenticate and read the hub's JSON handshake line
///    (`establish_hub_tls_session`). This phase is bounded by a timeout and abortable
///    through `connection_token`.
/// 2. Mark the edge connected, emit `TunnelConnected` / `PortsAssigned`, and start
///    periodic STUN discovery plus the TCP/UDP listeners for the hub-assigned ports.
/// 3. Drive the single-owner `TunnelIo` loop, dispatching hub frames through
///    `handle_edge_frame`. The loop ends on EOF, an I/O error, a liveness timeout or
///    cancellation.
/// 4. Cancel the per-connection tasks and attempt a graceful TLS close (at most 2s).
///
/// Returns `EdgeLoopResult::Shutdown` when `connection_token` is cancelled. Otherwise it
/// returns `EdgeLoopResult::Reconnect` with the disconnect reason.
async fn connect_to_hub_and_run(
    config: &EdgeConfig,
    connected: &Arc<RwLock<bool>>,
    public_ip: &Arc<RwLock<Option<String>>>,
    active_streams: &Arc<AtomicU32>,
    next_stream_id: &Arc<AtomicU32>,
    event_tx: &mpsc::Sender<EdgeEvent>,
    listen_ports: &Arc<RwLock<Vec<u16>>>,
    shutdown_rx: &mut mpsc::Receiver<()>,
    connection_token: &CancellationToken,
    connector: &TlsConnector,
) -> EdgeLoopResult {
    // Upper bound on TCP connect + TLS handshake + auth exchange. The liveness timer below
    // only starts once the tunnel I/O loop runs. Without this bound, a hub that accepts the
    // connection but never answers the auth line would stall this edge indefinitely.
    const HUB_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);

    let addr = format!("{}:{}", config.hub_host, config.hub_port);
    let (tls_stream, handshake) = tokio::select! {
        outcome = tokio::time::timeout(
            HUB_HANDSHAKE_TIMEOUT,
            establish_hub_tls_session(config, connector, &addr),
        ) => {
            match outcome {
                Ok(Ok(session)) => session,
                Ok(Err(result)) => return result,
                Err(_) => {
                    log::error!(
                        "Timed out connecting to hub at {} after {}s",
                        addr,
                        HUB_HANDSHAKE_TIMEOUT.as_secs()
                    );
                    return EdgeLoopResult::Reconnect("hub_handshake_timeout".to_string());
                }
            }
        }
        _ = connection_token.cancelled() => {
            // Same outcome as the tunnel loop's Cancelled branch: consume a pending stop()
            // signal and report shutdown.
            let _ = shutdown_rx.try_recv();
            return EdgeLoopResult::Shutdown;
        }
    };
    log::info!(
        "Handshake from hub: ports {:?}, stun_interval {}s",
        handshake.listen_ports,
        handshake.stun_interval_secs
    );
    *connected.write().await = true;
    let _ = event_tx.try_send(EdgeEvent::TunnelConnected);
    log::info!("Connected to hub at {}", addr);
    // Store initial ports and emit event
    *listen_ports.write().await = handshake.listen_ports.clone();
    let _ = event_tx.try_send(EdgeEvent::PortsAssigned {
        listen_ports: handshake.listen_ports.clone(),
    });
    // Start STUN discovery
    let stun_interval = handshake.stun_interval_secs;
    let public_ip_clone = public_ip.clone();
    let event_tx_clone = event_tx.clone();
    let stun_token = connection_token.clone();
    let stun_handle = tokio::spawn(async move {
        loop {
            tokio::select! {
                ip_result = crate::stun::discover_public_ip() => {
                    if let Some(ip) = ip_result {
                        let mut pip = public_ip_clone.write().await;
                        let changed = pip.as_ref() != Some(&ip);
                        *pip = Some(ip.clone());
                        if changed {
                            let _ = event_tx_clone.try_send(EdgeEvent::PublicIpDiscovered { ip });
                        }
                    }
                }
                _ = stun_token.cancelled() => break,
            }
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_secs(stun_interval)) => {}
                _ = stun_token.cancelled() => break,
            }
        }
    });
    // Client socket map: stream_id -> per-stream state (back channel + flow control)
    let client_writers: Arc<Mutex<HashMap<u32, EdgeStreamState>>> =
        Arc::new(Mutex::new(HashMap::new()));
    // QoS dual-channel: ctrl frames have priority over data frames.
    // Stream handlers send through these channels → TunnelIo drains them.
    let (tunnel_ctrl_tx, mut tunnel_ctrl_rx) = mpsc::channel::<Bytes>(512);
    let (tunnel_data_tx, mut tunnel_data_rx) = mpsc::channel::<Bytes>(4096);
    let (tunnel_sustained_tx, mut tunnel_sustained_rx) = mpsc::channel::<Bytes>(4096);
    let tunnel_writer_tx = tunnel_ctrl_tx.clone();
    // Start TCP listeners for initial ports
    let mut port_listeners: HashMap<u16, JoinHandle<()>> = HashMap::new();
    let bind_address = config.bind_address.as_deref().unwrap_or("0.0.0.0");
    apply_port_config(
        &handshake.listen_ports,
        &mut port_listeners,
        &tunnel_writer_tx,
        &tunnel_data_tx,
        &tunnel_sustained_tx,
        &client_writers,
        active_streams,
        next_stream_id,
        &config.edge_id,
        connection_token,
        bind_address,
    );
    // UDP session manager + listeners
    let udp_sessions: Arc<Mutex<UdpSessionManager>> =
        Arc::new(Mutex::new(UdpSessionManager::new(Duration::from_secs(60))));
    let udp_sockets: Arc<Mutex<HashMap<u16, Arc<UdpSocket>>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let mut udp_listeners: HashMap<u16, JoinHandle<()>> = HashMap::new();
    apply_udp_port_config(
        &handshake.listen_ports_udp,
        &mut udp_listeners,
        &tunnel_ctrl_tx,
        &tunnel_data_tx,
        &udp_sessions,
        &udp_sockets,
        next_stream_id,
        connection_token,
        bind_address,
    );
    // Single-owner I/O engine — no tokio::io::split, no mutex
    let mut tunnel_io = remoteingress_protocol::TunnelIo::new(tls_stream, Vec::new());
    let liveness_timeout_dur = Duration::from_secs(45);
    let mut last_activity = Instant::now();
    let mut liveness_deadline = Box::pin(sleep_until(last_activity + liveness_timeout_dur));
    let result = 'io_loop: loop {
        // Drain any buffered frames
        loop {
            let frame = match tunnel_io.try_parse_frame() {
                Some(Ok(f)) => f,
                Some(Err(e)) => {
                    log::error!("Hub frame error: {}", e);
                    break 'io_loop EdgeLoopResult::Reconnect(format!("hub_frame_error: {}", e));
                }
                None => break,
            };
            last_activity = Instant::now();
            liveness_deadline.as_mut().reset(last_activity + liveness_timeout_dur);
            if let EdgeFrameAction::Disconnect(reason) = handle_edge_frame(
                frame, &mut tunnel_io, &client_writers, listen_ports, event_tx,
                &tunnel_writer_tx, &tunnel_data_tx, &tunnel_sustained_tx, &mut port_listeners,
                &mut udp_listeners, active_streams, next_stream_id, &config.edge_id,
                connection_token, bind_address, &udp_sessions, &udp_sockets,
            ).await {
                break 'io_loop EdgeLoopResult::Reconnect(reason);
            }
        }
        // Poll I/O: write(ctrl→data), flush, read, channels, timers
        let event = std::future::poll_fn(|cx| {
            tunnel_io.poll_step(
                cx,
                &mut tunnel_ctrl_rx,
                &mut tunnel_data_rx,
                &mut tunnel_sustained_rx,
                &mut liveness_deadline,
                connection_token,
            )
        }).await;
        match event {
            remoteingress_protocol::TunnelEvent::Frame(frame) => {
                last_activity = Instant::now();
                liveness_deadline.as_mut().reset(last_activity + liveness_timeout_dur);
                if let EdgeFrameAction::Disconnect(reason) = handle_edge_frame(
                    frame, &mut tunnel_io, &client_writers, listen_ports, event_tx,
                    &tunnel_writer_tx, &tunnel_data_tx, &tunnel_sustained_tx, &mut port_listeners,
                    &mut udp_listeners, active_streams, next_stream_id, &config.edge_id,
                    connection_token, bind_address, &udp_sessions, &udp_sockets,
                ).await {
                    break EdgeLoopResult::Reconnect(reason);
                }
            }
            remoteingress_protocol::TunnelEvent::Eof => {
                log::info!("Hub disconnected (EOF)");
                break EdgeLoopResult::Reconnect("hub_eof".to_string());
            }
            remoteingress_protocol::TunnelEvent::ReadError(e) => {
                log::error!("Hub frame read error: {}", e);
                break EdgeLoopResult::Reconnect(format!("hub_frame_error: {}", e));
            }
            remoteingress_protocol::TunnelEvent::WriteError(e) => {
                log::error!("Tunnel write error: {}", e);
                break EdgeLoopResult::Reconnect(format!("tunnel_write_error: {}", e));
            }
            remoteingress_protocol::TunnelEvent::LivenessTimeout => {
                log::warn!(
                    "Hub liveness timeout (no frames for {}s), reconnecting",
                    liveness_timeout_dur.as_secs()
                );
                break EdgeLoopResult::Reconnect("liveness_timeout".to_string());
            }
            remoteingress_protocol::TunnelEvent::Cancelled => {
                // Cancellation always means shutdown; consume a pending stop() signal.
                let _ = shutdown_rx.try_recv();
                break EdgeLoopResult::Shutdown;
            }
        }
    };
    // Cancel stream tokens FIRST so stream handlers exit immediately.
    // If we TLS-shutdown first, stream handlers are stuck sending to dead channels
    // for up to 2 seconds while the shutdown times out on a dead connection.
    connection_token.cancel();
    stun_handle.abort();
    for (_, h) in port_listeners.drain() {
        h.abort();
    }
    for (_, h) in udp_listeners.drain() {
        h.abort();
    }
    // Graceful TLS shutdown: send close_notify so the hub sees a clean disconnect.
    // Stream handlers are already cancelled, so no new data is being produced.
    let mut tls_stream = tunnel_io.into_inner();
    let _ = tokio::time::timeout(Duration::from_secs(2), tls_stream.shutdown()).await;
    result
}

/// Set up an authenticated TCP+TLS session with the hub at `addr`.
///
/// This opens the TCP connection (with TCP_NODELAY and keepalive), completes the TLS
/// handshake, and sends the `EDGE <edge_id> <secret>` auth line. It then parses the hub's
/// newline-terminated JSON handshake.
///
/// Failures are returned as the `EdgeLoopResult::Reconnect` the caller should hand back to
/// the reconnect loop.
async fn establish_hub_tls_session(
    config: &EdgeConfig,
    connector: &TlsConnector,
    addr: &str,
) -> Result<(EdgeTlsStream, HandshakeConfig), EdgeLoopResult> {
    let tcp = match TcpStream::connect(addr).await {
        Ok(s) => {
            // Disable Nagle's algorithm for low-latency control frames (PING/PONG, WINDOW_UPDATE)
            let _ = s.set_nodelay(true);
            // TCP keepalive detects silent network failures (NAT timeout, path change)
            // faster than the 45s application-level liveness timeout.
            let ka = socket2::TcpKeepalive::new().with_time(Duration::from_secs(30));
            #[cfg(target_os = "linux")]
            let ka = ka.with_interval(Duration::from_secs(10));
            let _ = socket2::SockRef::from(&s).set_tcp_keepalive(&ka);
            s
        }
        Err(e) => {
            log::error!("Failed to connect to hub at {}: {}", addr, e);
            return Err(EdgeLoopResult::Reconnect(format!("tcp_connect_failed: {}", e)));
        }
    };
    let server_name = rustls::pki_types::ServerName::try_from(config.hub_host.clone())
        .unwrap_or_else(|_| {
            rustls::pki_types::ServerName::try_from("remoteingress-hub".to_string()).unwrap()
        });
    let mut tls_stream = match connector.connect(server_name, tcp).await {
        Ok(s) => s,
        Err(e) => {
            log::error!("TLS handshake failed: {}", e);
            return Err(EdgeLoopResult::Reconnect(format!("tls_handshake_failed: {}", e)));
        }
    };
    // Send auth line (we own the whole stream — no split)
    let auth_line = format!("EDGE {} {}\n", config.edge_id, config.secret);
    if tls_stream.write_all(auth_line.as_bytes()).await.is_err() {
        return Err(EdgeLoopResult::Reconnect("auth_write_failed".to_string()));
    }
    if tls_stream.flush().await.is_err() {
        return Err(EdgeLoopResult::Reconnect("auth_flush_failed".to_string()));
    }
    // Read handshake line byte-by-byte (no BufReader — into_inner corrupts TLS state)
    let mut handshake_bytes = Vec::with_capacity(512);
    let mut byte = [0u8; 1];
    loop {
        match tls_stream.read_exact(&mut byte).await {
            Ok(_) => {
                handshake_bytes.push(byte[0]);
                if byte[0] == b'\n' {
                    break;
                }
                if handshake_bytes.len() > 8192 {
                    return Err(EdgeLoopResult::Reconnect("handshake_too_long".to_string()));
                }
            }
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                log::error!("Hub rejected connection (EOF before handshake)");
                return Err(EdgeLoopResult::Reconnect("hub_rejected_eof".to_string()));
            }
            Err(e) => {
                log::error!("Failed to read handshake response: {}", e);
                return Err(EdgeLoopResult::Reconnect(format!("handshake_read_failed: {}", e)));
            }
        }
    }
    let handshake_line = String::from_utf8_lossy(&handshake_bytes);
    let handshake: HandshakeConfig = match serde_json::from_str(handshake_line.trim()) {
        Ok(h) => h,
        Err(e) => {
            log::error!("Invalid handshake response: {}", e);
            return Err(EdgeLoopResult::Reconnect(format!("handshake_invalid: {}", e)));
        }
    };
    Ok((tls_stream, handshake))
}
/// Apply a new port configuration: spawn listeners for added ports, abort removed ports.
fn apply_port_config(
new_ports: &[u16],
port_listeners: &mut HashMap<u16, JoinHandle<()>>,
tunnel_ctrl_tx: &mpsc::Sender<Bytes>,
tunnel_data_tx: &mpsc::Sender<Bytes>,
tunnel_sustained_tx: &mpsc::Sender<Bytes>,
client_writers: &Arc<Mutex<HashMap<u32, EdgeStreamState>>>,
active_streams: &Arc<AtomicU32>,
next_stream_id: &Arc<AtomicU32>,
edge_id: &str,
connection_token: &CancellationToken,
bind_address: &str,
) {
let new_set: std::collections::HashSet<u16> = new_ports.iter().copied().collect();
let old_set: std::collections::HashSet<u16> = port_listeners.keys().copied().collect();
// Remove ports no longer needed
for &port in old_set.difference(&new_set) {
if let Some(handle) = port_listeners.remove(&port) {
log::info!("Stopping listener on port {}", port);
handle.abort();
}
}
// Add new ports
for &port in new_set.difference(&old_set) {
let tunnel_ctrl_tx = tunnel_ctrl_tx.clone();
let tunnel_data_tx = tunnel_data_tx.clone();
let tunnel_sustained_tx = tunnel_sustained_tx.clone();
let client_writers = client_writers.clone();
let active_streams = active_streams.clone();
let next_stream_id = next_stream_id.clone();
let edge_id = edge_id.to_string();
let port_token = connection_token.child_token();
let bind_addr = bind_address.to_string();
let handle = tokio::spawn(async move {
let listener = match TcpListener::bind((bind_addr.as_str(), port)).await {
Ok(l) => l,
Err(e) => {
log::error!("Failed to bind port {}: {}", port, e);
return;
}
};
log::info!("Listening on port {}", port);
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((client_stream, client_addr)) => {
// TCP keepalive detects dead clients that disappear without FIN.
// Without this, zombie streams accumulate and never get cleaned up.
let _ = client_stream.set_nodelay(true);
let ka = socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(60));
#[cfg(target_os = "linux")]
let ka = ka.with_interval(Duration::from_secs(60));
let _ = socket2::SockRef::from(&client_stream).set_tcp_keepalive(&ka);
let stream_id = next_stream_id.fetch_add(1, Ordering::Relaxed);
let tunnel_ctrl_tx = tunnel_ctrl_tx.clone();
let tunnel_data_tx = tunnel_data_tx.clone();
let tunnel_sustained_tx = tunnel_sustained_tx.clone();
let client_writers = client_writers.clone();
let active_streams = active_streams.clone();
let edge_id = edge_id.clone();
let client_token = port_token.child_token();
active_streams.fetch_add(1, Ordering::Relaxed);
tokio::spawn(async move {
handle_client_connection(
client_stream,
client_addr,
stream_id,
port,
&edge_id,
tunnel_ctrl_tx,
tunnel_data_tx,
tunnel_sustained_tx,
client_writers,
client_token,
Arc::clone(&active_streams),
)
.await;
// Saturating decrement: prevent underflow when
// edge_main_loop's store(0) races with task cleanup.
loop {
let current = active_streams.load(Ordering::Relaxed);
if current == 0 { break; }
if active_streams.compare_exchange_weak(
current, current - 1,
Ordering::Relaxed, Ordering::Relaxed,
).is_ok() {
break;
}
}
});
}
Err(e) => {
log::error!("Accept error on port {}: {}", port, e);
}
}
}
_ = port_token.cancelled() => {
log::info!("Port {} listener cancelled", port);
break;
}
}
}
});
port_listeners.insert(port, handle);
}
}
/// Apply UDP port configuration: bind UdpSockets for added ports, abort removed ports.
///
/// Each listener registers its socket in `udp_sockets` so FRAME_UDP_DATA_BACK can reply
/// from the same port. A datagram from an unseen `(client_addr, port)` pair allocates a
/// stream id, records a session, and sends FRAME_UDP_OPEN (with a PROXY v2 header) on the
/// control channel. Every datagram is then forwarded as FRAME_UDP_DATA on the data channel.
fn apply_udp_port_config(
    new_ports: &[u16],
    udp_listeners: &mut HashMap<u16, JoinHandle<()>>,
    tunnel_ctrl_tx: &mpsc::Sender<Bytes>,
    tunnel_data_tx: &mpsc::Sender<Bytes>,
    udp_sessions: &Arc<Mutex<UdpSessionManager>>,
    udp_sockets: &Arc<Mutex<HashMap<u16, Arc<UdpSocket>>>>,
    next_stream_id: &Arc<AtomicU32>,
    connection_token: &CancellationToken,
    bind_address: &str,
) {
    let new_set: std::collections::HashSet<u16> = new_ports.iter().copied().collect();
    let old_set: std::collections::HashSet<u16> = udp_listeners.keys().copied().collect();
    // Remove ports no longer needed
    for &port in old_set.difference(&new_set) {
        if let Some(handle) = udp_listeners.remove(&port) {
            log::info!("Stopping UDP listener on port {}", port);
            handle.abort();
        }
        // Unregister the socket used for return traffic. In the usual uncontended case it is
        // removed right away. A deferred removal could otherwise run after a later config
        // re-adds the port and delete the new listener's socket. A task is used only when
        // the map is locked right now.
        match udp_sockets.try_lock() {
            Ok(mut sockets) => {
                sockets.remove(&port);
            }
            Err(_) => {
                let sockets = Arc::clone(udp_sockets);
                tokio::spawn(async move {
                    sockets.lock().await.remove(&port);
                });
            }
        }
    }
    // Add new ports
    for &port in new_set.difference(&old_set) {
        let tunnel_ctrl_tx = tunnel_ctrl_tx.clone();
        let tunnel_data_tx = tunnel_data_tx.clone();
        let udp_sessions = udp_sessions.clone();
        let udp_sockets = udp_sockets.clone();
        let next_stream_id = next_stream_id.clone();
        let port_token = connection_token.child_token();
        let bind_addr = bind_address.to_string();
        let handle = tokio::spawn(async move {
            let socket = match UdpSocket::bind((bind_addr.as_str(), port)).await {
                Ok(s) => Arc::new(s),
                Err(e) => {
                    log::error!("Failed to bind UDP port {}: {}", port, e);
                    return;
                }
            };
            log::info!("Listening on UDP port {}", port);
            // Register socket in shared map for return traffic
            udp_sockets.lock().await.insert(port, socket.clone());
            let mut buf = vec![0u8; 65536]; // max UDP datagram size
            loop {
                tokio::select! {
                    recv_result = socket.recv_from(&mut buf) => {
                        match recv_result {
                            Ok((len, client_addr)) => {
                                let key = UdpSessionKey { client_addr, dest_port: port };
                                // Resolve (or create) the session under the lock; the OPEN
                                // frame, if any, is sent after the lock is released.
                                let (stream_id, open_frame) = {
                                    let mut sessions = udp_sessions.lock().await;
                                    if let Some(session) = sessions.get_mut(&key) {
                                        (session.stream_id, None)
                                    } else {
                                        // New session — allocate stream_id and send UDP_OPEN
                                        let sid = next_stream_id.fetch_add(1, Ordering::Relaxed);
                                        sessions.insert(key, sid);
                                        let client_ip = client_addr.ip().to_string();
                                        let client_port = client_addr.port();
                                        let proxy_header = build_proxy_v2_header_from_str(
                                            &client_ip, "0.0.0.0", client_port, port,
                                            ProxyV2Transport::Udp,
                                        );
                                        log::debug!("New UDP session {} from {} -> port {}", sid, client_addr, port);
                                        (sid, Some(encode_frame(sid, FRAME_UDP_OPEN, &proxy_header)))
                                    }
                                };
                                if let Some(open_frame) = open_frame {
                                    // The session is already recorded. If this OPEN were
                                    // dropped (try_send on a full ctrl channel), every later
                                    // datagram would target a stream the hub never opened.
                                    // Wait for capacity instead; meanwhile the kernel
                                    // buffers or drops incoming datagrams, which UDP tolerates.
                                    let delivered = tokio::select! {
                                        sent = tunnel_ctrl_tx.send(open_frame) => sent.is_ok(),
                                        _ = port_token.cancelled() => false,
                                    };
                                    if !delivered {
                                        continue;
                                    }
                                }
                                // Send datagram through tunnel (lossy by design for UDP)
                                let data_frame = encode_frame(stream_id, FRAME_UDP_DATA, &buf[..len]);
                                let _ = tunnel_data_tx.try_send(data_frame);
                            }
                            Err(e) => {
                                log::error!("UDP recv error on port {}: {}", port, e);
                            }
                        }
                    }
                    _ = port_token.cancelled() => {
                        log::info!("UDP port {} listener cancelled", port);
                        break;
                    }
                }
            }
        });
        udp_listeners.insert(port, handle);
    }
}
async fn handle_client_connection(
client_stream: TcpStream,
client_addr: std::net::SocketAddr,
stream_id: u32,
dest_port: u16,
edge_id: &str,
tunnel_ctrl_tx: mpsc::Sender<Bytes>,
tunnel_data_tx: mpsc::Sender<Bytes>,
tunnel_sustained_tx: mpsc::Sender<Bytes>,
client_writers: Arc<Mutex<HashMap<u32, EdgeStreamState>>>,
client_token: CancellationToken,
active_streams: Arc<AtomicU32>,
) {
let client_ip = client_addr.ip().to_string();
let client_port = client_addr.port();
// Determine edge IP (use 0.0.0.0 as placeholder — hub doesn't use it for routing)
let edge_ip = "0.0.0.0";
// Send OPEN frame with PROXY v1 header via control channel
let proxy_header = build_proxy_v1_header(&client_ip, edge_ip, client_port, dest_port);
let open_frame = encode_frame(stream_id, FRAME_OPEN, proxy_header.as_bytes());
let send_ok = tokio::select! {
result = tunnel_ctrl_tx.send(open_frame) => result.is_ok(),
_ = client_token.cancelled() => false,
};
if !send_ok {
return;
}
// Per-stream unbounded back-channel. Flow control (WINDOW_UPDATE) limits
// bytes-in-flight, so this won't grow unbounded. Unbounded avoids killing
// streams due to channel overflow — backpressure slows streams, never kills them.
let (back_tx, mut back_rx) = mpsc::unbounded_channel::<Bytes>();
// Adaptive initial window: scale with current stream count to keep total in-flight
// data within the 200MB budget. Prevents burst flooding when many streams open.
let initial_window = remoteingress_protocol::compute_window_for_stream_count(
active_streams.load(Ordering::Relaxed),
);
let send_window = Arc::new(AtomicU32::new(initial_window));
let window_notify = Arc::new(Notify::new());
{
let mut writers = client_writers.lock().await;
writers.insert(stream_id, EdgeStreamState {
back_tx,
send_window: Arc::clone(&send_window),
window_notify: Arc::clone(&window_notify),
});
}
let (mut client_read, mut client_write) = client_stream.into_split();
// Task: hub -> client (download direction)
// After writing to client TCP, send WINDOW_UPDATE to hub so it can send more
let hub_to_client_token = client_token.clone();
let wu_tx = tunnel_ctrl_tx.clone();
let active_streams_h2c = Arc::clone(&active_streams);
let mut hub_to_client = tokio::spawn(async move {
let mut consumed_since_update: u32 = 0;
loop {
tokio::select! {
data = back_rx.recv() => {
match data {
Some(data) => {
let len = data.len() as u32;
if client_write.write_all(&data).await.is_err() {
break;
}
// Track consumption for adaptive flow control.
// The increment is capped to the adaptive window so the sender's
// effective window shrinks to match current demand (fewer streams
// = larger window, more streams = smaller window per stream).
consumed_since_update += len;
let adaptive_window = remoteingress_protocol::compute_window_for_stream_count(
active_streams_h2c.load(Ordering::Relaxed),
);
let threshold = adaptive_window / 2;
if consumed_since_update >= threshold {
let increment = consumed_since_update.min(adaptive_window);
let frame = encode_window_update(stream_id, FRAME_WINDOW_UPDATE, increment);
// Use send().await for guaranteed delivery — dropping WINDOW_UPDATEs
// causes permanent flow stalls. Safe: runs in per-stream task, not main loop.
tokio::select! {
result = wu_tx.send(frame) => {
if result.is_ok() {
consumed_since_update -= increment;
}
}
_ = hub_to_client_token.cancelled() => break,
}
}
}
None => break,
}
}
_ = hub_to_client_token.cancelled() => break,
}
}
// Send final window update for any remaining consumed bytes
if consumed_since_update > 0 {
let frame = encode_window_update(stream_id, FRAME_WINDOW_UPDATE, consumed_since_update);
tokio::select! {
_ = wu_tx.send(frame) => {}
_ = hub_to_client_token.cancelled() => {}
}
}
let _ = client_write.shutdown().await;
});
// Task: client -> hub (upload direction) with per-stream flow control.
// Zero-copy: read payload directly after the header, then prepend header.
let mut buf = vec![0u8; FRAME_HEADER_SIZE + 32768];
let mut stream_bytes_sent: u64 = 0;
let stream_start = tokio::time::Instant::now();
let mut is_sustained = false;
loop {
// Wait for send window to have capacity (with stall timeout).
// Safe pattern: register notified BEFORE checking the condition
// to avoid missing a notify_one that fires between load and select.
loop {
let notified = window_notify.notified();
tokio::pin!(notified);
notified.as_mut().enable();
let w = send_window.load(Ordering::Acquire);
if w > 0 { break; }
tokio::select! {
_ = notified => continue,
_ = client_token.cancelled() => break,
_ = tokio::time::sleep(Duration::from_secs(120)) => {
log::warn!("Stream {} upload stalled (window empty for 120s)", stream_id);
break;
}
}
}
if client_token.is_cancelled() { break; }
// Limit read size to available window.
// IMPORTANT: if window is 0 (stall timeout fired), we must NOT
// read into an empty buffer — read(&mut buf[..0]) returns Ok(0)
// which would be falsely interpreted as EOF.
let w = send_window.load(Ordering::Acquire) as usize;
if w == 0 {
log::warn!("Stream {} upload: window still 0 after stall timeout, closing", stream_id);
break;
}
let max_read = w.min(32768);
tokio::select! {
read_result = client_read.read(&mut buf[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + max_read]) => {
match read_result {
Ok(0) => break,
Ok(n) => {
send_window.fetch_sub(n as u32, Ordering::Release);
encode_frame_header(&mut buf, stream_id, FRAME_DATA, n);
let data_frame = Bytes::copy_from_slice(&buf[..FRAME_HEADER_SIZE + n]);
// Sustained classification: >2.5 MB/s for >10 seconds
stream_bytes_sent += n as u64;
if !is_sustained {
let elapsed = stream_start.elapsed().as_secs();
if elapsed >= remoteingress_protocol::SUSTAINED_MIN_DURATION_SECS
&& stream_bytes_sent / elapsed >= remoteingress_protocol::SUSTAINED_THRESHOLD_BPS
{
is_sustained = true;
log::debug!("Stream {} classified as sustained (upload, {} bytes in {}s)",
stream_id, stream_bytes_sent, elapsed);
}
}
let tx = if is_sustained { &tunnel_sustained_tx } else { &tunnel_data_tx };
let sent = tokio::select! {
result = tx.send(data_frame) => result.is_ok(),
_ = client_token.cancelled() => false,
};
if !sent { break; }
}
Err(_) => break,
}
}
_ = client_token.cancelled() => break,
}
}
// Wait for the download task (hub → client) to finish BEFORE sending CLOSE.
// Upload EOF (client done sending) does NOT mean the response is done.
// For asymmetric transfers like git fetch (small request, large response),
// the response is still streaming when the upload finishes.
// Sending CLOSE before the response finishes would cause the hub to cancel
// the upstream reader mid-response, truncating the data.
let _ = tokio::time::timeout(
Duration::from_secs(300), // 5 min max wait for download to finish
&mut hub_to_client,
).await;
// NOW send CLOSE — the response has been fully delivered (or timed out).
// select! with cancellation guard prevents indefinite blocking if tunnel dies.
if !client_token.is_cancelled() {
let close_frame = encode_frame(stream_id, FRAME_CLOSE, &[]);
let tx = if is_sustained { &tunnel_sustained_tx } else { &tunnel_data_tx };
tokio::select! {
_ = tx.send(close_frame) => {}
_ = client_token.cancelled() => {}
}
}
// Clean up
{
let mut writers = client_writers.lock().await;
writers.remove(&stream_id);
}
hub_to_client.abort(); // No-op if already finished; safety net if timeout fired
let _ = edge_id; // used for logging context
}
// ===== QUIC transport functions =====
/// Establish a QUIC connection to the hub without authenticating
/// (used by QuicWithFallback to probe QUIC connectivity).
///
/// Resolves `hub_host:hub_port` and connects through `endpoint`. The TLS server
/// name is `hub_host` when it parses as a DNS name, and the fixed name
/// `remoteingress-hub` otherwise (e.g. when `hub_host` is an IP literal), which
/// mirrors the TCP+TLS path. The cancellation token is currently not consulted.
///
/// # Errors
/// Returns an error if the DNS lookup fails, resolves to no address
/// ("DNS resolution failed"), or the QUIC connect/handshake fails.
async fn connect_to_hub_quic_handshake(
    config: &EdgeConfig,
    endpoint: &quinn::Endpoint,
    _connection_token: &CancellationToken,
) -> Result<quinn::Connection, Box<dyn std::error::Error + Send + Sync>> {
    let target = format!("{}:{}", config.hub_host, config.hub_port);
    let server_addr = match tokio::net::lookup_host(&target).await?.next() {
        Some(resolved) => resolved,
        None => return Err("DNS resolution failed".into()),
    };

    // SNI needs a hostname; IP literals fall back to a fixed name.
    let hub_is_dns_name = matches!(
        rustls::pki_types::ServerName::try_from(config.hub_host.as_str()),
        Ok(rustls::pki_types::ServerName::DnsName(_))
    );
    let server_name = if hub_is_dns_name {
        config.hub_host.clone()
    } else {
        "remoteingress-hub".to_string()
    };

    Ok(endpoint.connect(server_addr, &server_name)?.await?)
}
/// QUIC edge: connect to hub, authenticate, and run the stream multiplexer.
async fn connect_to_hub_and_run_quic(
config: &EdgeConfig,
connected: &Arc<RwLock<bool>>,
public_ip: &Arc<RwLock<Option<String>>>,
active_streams: &Arc<AtomicU32>,
next_stream_id: &Arc<AtomicU32>,
event_tx: &mpsc::Sender<EdgeEvent>,
listen_ports: &Arc<RwLock<Vec<u16>>>,
shutdown_rx: &mut mpsc::Receiver<()>,
connection_token: &CancellationToken,
endpoint: &quinn::Endpoint,
) -> EdgeLoopResult {
// Establish QUIC connection
let quic_conn = match connect_to_hub_quic_handshake(config, endpoint, connection_token).await {
Ok(c) => c,
Err(e) => {
log::error!("QUIC connect failed: {}", e);
return EdgeLoopResult::Reconnect(format!("quic_connect_failed: {}", e));
}
};
connect_to_hub_and_run_quic_with_connection(
config, connected, public_ip, active_streams, next_stream_id,
event_tx, listen_ports, shutdown_rx, connection_token, quic_conn,
).await
}
/// QUIC edge: run the tunnel over an already-established QUIC connection.
///
/// Sequence:
/// 1. Open the control stream (first bidirectional stream) and send the
///    `EDGE <edge_id> <secret>\n` auth line.
/// 2. Read the hub's newline-terminated JSON handshake (bounded to ~8 KiB),
///    mark the edge connected and emit `TunnelConnected` / `PortsAssigned`.
/// 3. Spawn the periodic STUN public-IP discovery task and the TCP / UDP port
///    listeners for the assigned ports.
/// 4. Loop over control messages (`CTRL_CONFIG`, `CTRL_PING`), inbound QUIC
///    datagrams (UDP return traffic) and connection-closed / cancellation signals.
///
/// On exit from the loop, `connection_token` is cancelled, the spawned STUN and
/// listener tasks are aborted and the QUIC connection is closed with code 0.
///
/// Returns `EdgeLoopResult::Shutdown` when `connection_token` is cancelled and
/// `EdgeLoopResult::Reconnect(reason)` for every other failure or close.
async fn connect_to_hub_and_run_quic_with_connection(
    config: &EdgeConfig,
    connected: &Arc<RwLock<bool>>,
    public_ip: &Arc<RwLock<Option<String>>>,
    active_streams: &Arc<AtomicU32>,
    next_stream_id: &Arc<AtomicU32>,
    event_tx: &mpsc::Sender<EdgeEvent>,
    listen_ports: &Arc<RwLock<Vec<u16>>>,
    shutdown_rx: &mut mpsc::Receiver<()>,
    connection_token: &CancellationToken,
    quic_conn: quinn::Connection,
) -> EdgeLoopResult {
    log::info!("QUIC connection established to {}", quic_conn.remote_address());

    // Open control stream (first bidirectional stream)
    let (mut ctrl_send, mut ctrl_recv) = match quic_conn.open_bi().await {
        Ok(s) => s,
        Err(e) => {
            log::error!("Failed to open QUIC control stream: {}", e);
            return EdgeLoopResult::Reconnect(format!("quic_ctrl_open_failed: {}", e));
        }
    };

    // Auth handshake on control stream (same protocol as TCP+TLS)
    let auth_line = format!("EDGE {} {}\n", config.edge_id, config.secret);
    if let Err(e) = ctrl_send.write_all(auth_line.as_bytes()).await {
        return EdgeLoopResult::Reconnect(format!("quic_auth_write_failed: {}", e));
    }

    // Read handshake response (newline-delimited JSON). Reading one byte at a
    // time ensures nothing past the newline is consumed from the control stream,
    // which is read again by `read_ctrl_message` below.
    let mut handshake_bytes = Vec::with_capacity(512);
    let mut byte = [0u8; 1];
    loop {
        match ctrl_recv.read_exact(&mut byte).await {
            Ok(()) => {
                handshake_bytes.push(byte[0]);
                if byte[0] == b'\n' {
                    break;
                }
                if handshake_bytes.len() > 8192 {
                    return EdgeLoopResult::Reconnect("quic_handshake_too_long".to_string());
                }
            }
            Err(e) => {
                log::error!("QUIC handshake read failed: {}", e);
                return EdgeLoopResult::Reconnect(format!("quic_handshake_read_failed: {}", e));
            }
        }
    }
    let handshake_line = String::from_utf8_lossy(&handshake_bytes);
    let handshake: HandshakeConfig = match serde_json::from_str(handshake_line.trim()) {
        Ok(h) => h,
        Err(e) => {
            log::error!("Invalid QUIC handshake response: {}", e);
            return EdgeLoopResult::Reconnect(format!("quic_handshake_invalid: {}", e));
        }
    };
    log::info!(
        "QUIC handshake from hub: ports {:?}, stun_interval {}s",
        handshake.listen_ports,
        handshake.stun_interval_secs
    );

    *connected.write().await = true;
    let _ = event_tx.try_send(EdgeEvent::TunnelConnected);
    log::info!("Connected to hub via QUIC at {}", quic_conn.remote_address());

    *listen_ports.write().await = handshake.listen_ports.clone();
    let _ = event_tx.try_send(EdgeEvent::PortsAssigned {
        listen_ports: handshake.listen_ports.clone(),
    });

    // Start STUN discovery: query immediately, then every `stun_interval`
    // seconds; only emit an event when the discovered IP changes.
    let stun_interval = handshake.stun_interval_secs;
    let public_ip_clone = public_ip.clone();
    let event_tx_clone = event_tx.clone();
    let stun_token = connection_token.clone();
    let stun_handle = tokio::spawn(async move {
        loop {
            tokio::select! {
                ip_result = crate::stun::discover_public_ip() => {
                    if let Some(ip) = ip_result {
                        let mut pip = public_ip_clone.write().await;
                        let changed = pip.as_ref() != Some(&ip);
                        *pip = Some(ip.clone());
                        if changed {
                            let _ = event_tx_clone.try_send(EdgeEvent::PublicIpDiscovered { ip });
                        }
                    }
                }
                _ = stun_token.cancelled() => break,
            }
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_secs(stun_interval)) => {}
                _ = stun_token.cancelled() => break,
            }
        }
    });

    // Start TCP listeners for the assigned ports.
    // For QUIC, each client connection opens a new QUIC bidirectional stream.
    let mut port_listeners: HashMap<u16, JoinHandle<()>> = HashMap::new();
    let bind_address = config.bind_address.as_deref().unwrap_or("0.0.0.0");
    apply_port_config_quic(
        &handshake.listen_ports,
        &mut port_listeners,
        &quic_conn,
        active_streams,
        next_stream_id,
        &config.edge_id,
        connection_token,
        bind_address,
    );

    // UDP listeners for QUIC transport — uses QUIC datagrams for low-latency forwarding.
    let udp_sessions_quic: Arc<Mutex<UdpSessionManager>> =
        Arc::new(Mutex::new(UdpSessionManager::new(Duration::from_secs(60))));
    let udp_sockets_quic: Arc<Mutex<HashMap<u16, Arc<UdpSocket>>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let mut udp_listeners_quic: HashMap<u16, JoinHandle<()>> = HashMap::new();
    apply_udp_port_config_quic(
        &handshake.listen_ports_udp,
        &mut udp_listeners_quic,
        &quic_conn,
        &udp_sessions_quic,
        &udp_sockets_quic,
        next_stream_id,
        connection_token,
        bind_address,
    );

    // Monitor control stream for config updates, connection health, and QUIC datagrams.
    // NOTE(review): `read_ctrl_message` is recreated every iteration and raced
    // against `read_datagram` in `select!`. If it is not cancel-safe (e.g. it uses
    // `read_exact` for header and payload), a datagram arriving mid-message would
    // drop bytes already consumed and desynchronize the control stream. Confirm,
    // or move control-stream reads into a dedicated task.
    let result = 'quic_loop: loop {
        tokio::select! {
            // Read control messages from hub
            ctrl_msg = quic_transport::read_ctrl_message(&mut ctrl_recv) => {
                match ctrl_msg {
                    Ok(Some((msg_type, payload))) => {
                        match msg_type {
                            quic_transport::CTRL_CONFIG => {
                                if let Ok(update) = serde_json::from_slice::<ConfigUpdate>(&payload) {
                                    log::info!("QUIC config update from hub: ports {:?}", update.listen_ports);
                                    *listen_ports.write().await = update.listen_ports.clone();
                                    let _ = event_tx.try_send(EdgeEvent::PortsUpdated {
                                        listen_ports: update.listen_ports.clone(),
                                    });
                                    apply_port_config_quic(
                                        &update.listen_ports,
                                        &mut port_listeners,
                                        &quic_conn,
                                        active_streams,
                                        next_stream_id,
                                        &config.edge_id,
                                        connection_token,
                                        bind_address,
                                    );
                                }
                            }
                            quic_transport::CTRL_PING => {
                                // Respond with PONG on control stream
                                if let Err(e) = quic_transport::write_ctrl_message(
                                    &mut ctrl_send, quic_transport::CTRL_PONG, &[],
                                ).await {
                                    log::error!("Failed to send QUIC PONG: {}", e);
                                    break 'quic_loop EdgeLoopResult::Reconnect(
                                        format!("quic_pong_failed: {}", e),
                                    );
                                }
                            }
                            _ => {
                                log::warn!("Unknown QUIC control message type: {}", msg_type);
                            }
                        }
                    }
                    Ok(None) => {
                        log::info!("Hub closed QUIC control stream (EOF)");
                        break 'quic_loop EdgeLoopResult::Reconnect("quic_ctrl_eof".to_string());
                    }
                    Err(e) => {
                        log::error!("QUIC control stream read error: {}", e);
                        break 'quic_loop EdgeLoopResult::Reconnect(
                            format!("quic_ctrl_error: {}", e),
                        );
                    }
                }
            }
            // Receive QUIC datagrams (UDP return traffic from hub)
            datagram = quic_conn.read_datagram() => {
                match datagram {
                    Ok(data) => {
                        // Format: [session_id:4][payload:N]; shorter datagrams are ignored.
                        if data.len() >= 4 {
                            let session_id = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
                            let payload = &data[4..];
                            // Resolve the session's client address and socket while holding
                            // each lock only briefly. The UDP send then runs with no lock held,
                            // so the UDP listener tasks (which lock the session table per
                            // packet) are not blocked behind network I/O.
                            let mut sessions = udp_sessions_quic.lock().await;
                            let target = sessions
                                .get_by_stream_id(session_id)
                                .map(|session| (session.client_addr, session.dest_port));
                            drop(sessions);
                            if let Some((client_addr, dest_port)) = target {
                                let socket = udp_sockets_quic.lock().await.get(&dest_port).cloned();
                                if let Some(socket) = socket {
                                    let _ = socket.send_to(payload, client_addr).await;
                                }
                            }
                        }
                    }
                    Err(e) => {
                        log::debug!("QUIC datagram recv error: {}", e);
                    }
                }
            }
            // QUIC connection closed
            reason = quic_conn.closed() => {
                log::info!("QUIC connection closed: {}", reason);
                break 'quic_loop EdgeLoopResult::Reconnect(format!("quic_closed: {}", reason));
            }
            // Shutdown signal
            _ = connection_token.cancelled() => {
                // Consume a pending shutdown signal, if any. Cancellation always
                // ends this loop with Shutdown, whether or not one was queued.
                let _ = shutdown_rx.try_recv();
                break 'quic_loop EdgeLoopResult::Shutdown;
            }
        }
    };

    // Cleanup
    connection_token.cancel();
    stun_handle.abort();
    for (_, h) in port_listeners.drain() {
        h.abort();
    }
    for (_, h) in udp_listeners_quic.drain() {
        h.abort();
    }
    // Graceful QUIC close
    quic_conn.close(quinn::VarInt::from_u32(0), b"shutdown");
    result
}
/// Apply port config for QUIC transport: spawn TCP listeners that open QUIC streams.
fn apply_port_config_quic(
new_ports: &[u16],
port_listeners: &mut HashMap<u16, JoinHandle<()>>,
quic_conn: &quinn::Connection,
active_streams: &Arc<AtomicU32>,
next_stream_id: &Arc<AtomicU32>,
edge_id: &str,
connection_token: &CancellationToken,
bind_address: &str,
) {
let new_set: std::collections::HashSet<u16> = new_ports.iter().copied().collect();
let old_set: std::collections::HashSet<u16> = port_listeners.keys().copied().collect();
// Remove ports no longer needed
for &port in old_set.difference(&new_set) {
if let Some(handle) = port_listeners.remove(&port) {
log::info!("Stopping QUIC listener on port {}", port);
handle.abort();
}
}
// Add new ports
for &port in new_set.difference(&old_set) {
let quic_conn = quic_conn.clone();
let active_streams = active_streams.clone();
let next_stream_id = next_stream_id.clone();
let _edge_id = edge_id.to_string();
let port_token = connection_token.child_token();
let bind_addr = bind_address.to_string();
let handle = tokio::spawn(async move {
let listener = match TcpListener::bind((bind_addr.as_str(), port)).await {
Ok(l) => l,
Err(e) => {
log::error!("Failed to bind port {} (QUIC): {}", port, e);
return;
}
};
log::info!("Listening on port {} (QUIC transport)", port);
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((client_stream, client_addr)) => {
let _ = client_stream.set_nodelay(true);
let ka = socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(60));
#[cfg(target_os = "linux")]
let ka = ka.with_interval(Duration::from_secs(60));
let _ = socket2::SockRef::from(&client_stream).set_tcp_keepalive(&ka);
let stream_id = next_stream_id.fetch_add(1, Ordering::Relaxed);
let quic_conn = quic_conn.clone();
let active_streams = active_streams.clone();
let client_token = port_token.child_token();
active_streams.fetch_add(1, Ordering::Relaxed);
tokio::spawn(async move {
handle_client_connection_quic(
client_stream,
client_addr,
stream_id,
port,
quic_conn,
client_token,
).await;
// Saturating decrement
loop {
let current = active_streams.load(Ordering::Relaxed);
if current == 0 { break; }
if active_streams.compare_exchange_weak(
current, current - 1,
Ordering::Relaxed, Ordering::Relaxed,
).is_ok() {
break;
}
}
});
}
Err(e) => {
log::error!("Accept error on port {} (QUIC): {}", port, e);
}
}
}
_ = port_token.cancelled() => {
log::info!("Port {} QUIC listener cancelled", port);
break;
}
}
}
});
port_listeners.insert(port, handle);
}
}
/// Apply UDP port config for QUIC transport: bind `UdpSocket`s whose traffic is
/// relayed to the hub as QUIC datagrams.
///
/// Listener tasks for ports missing from `new_ports` are aborted. Their sockets
/// are removed from `udp_sockets` by a spawned task, because this function is
/// synchronous and the map is behind an async mutex. For every newly listed port,
/// a task binds a socket on `bind_address` and registers it in `udp_sockets` so
/// return traffic from the hub can be sent from it. The task then relays each
/// received packet:
///
/// - first packet of a new `(client_addr, port)` session: allocate a session id
///   from `next_stream_id` and send an OPEN datagram
///   `[session_id:4][0xFF][PROXY v2 header]`;
/// - every packet, including the first: send a DATA datagram
///   `[session_id:4][payload:N]`.
///
/// QUIC datagrams are unreliable and size-limited, so either datagram may be
/// lost. Local send failures (e.g. a datagram larger than the connection allows)
/// are logged at debug level and the packet is dropped.
fn apply_udp_port_config_quic(
    new_ports: &[u16],
    udp_listeners: &mut HashMap<u16, JoinHandle<()>>,
    quic_conn: &quinn::Connection,
    udp_sessions: &Arc<Mutex<UdpSessionManager>>,
    udp_sockets: &Arc<Mutex<HashMap<u16, Arc<UdpSocket>>>>,
    next_stream_id: &Arc<AtomicU32>,
    connection_token: &CancellationToken,
    bind_address: &str,
) {
    let new_set: std::collections::HashSet<u16> = new_ports.iter().copied().collect();
    let old_set: std::collections::HashSet<u16> = udp_listeners.keys().copied().collect();

    for &port in old_set.difference(&new_set) {
        if let Some(handle) = udp_listeners.remove(&port) {
            log::info!("Stopping QUIC UDP listener on port {}", port);
            handle.abort();
        }
        let sockets = udp_sockets.clone();
        tokio::spawn(async move {
            sockets.lock().await.remove(&port);
        });
    }

    for &port in new_set.difference(&old_set) {
        let quic_conn = quic_conn.clone();
        let udp_sessions = udp_sessions.clone();
        let udp_sockets = udp_sockets.clone();
        let next_stream_id = next_stream_id.clone();
        let port_token = connection_token.child_token();
        let bind_addr = bind_address.to_string();
        let handle = tokio::spawn(async move {
            let socket = match UdpSocket::bind((bind_addr.as_str(), port)).await {
                Ok(s) => Arc::new(s),
                Err(e) => {
                    log::error!("Failed to bind QUIC UDP port {}: {}", port, e);
                    return;
                }
            };
            log::info!("Listening on UDP port {} (QUIC datagram transport)", port);
            udp_sockets.lock().await.insert(port, socket.clone());
            let mut buf = vec![0u8; 65536];
            loop {
                tokio::select! {
                    recv_result = socket.recv_from(&mut buf) => {
                        match recv_result {
                            Ok((len, client_addr)) => {
                                let key = UdpSessionKey { client_addr, dest_port: port };
                                let mut sessions = udp_sessions.lock().await;
                                let stream_id = if let Some(session) = sessions.get_mut(&key) {
                                    session.stream_id
                                } else {
                                    // New session: announce it to the hub with an OPEN
                                    // datagram carrying a PROXY v2 header.
                                    let sid = next_stream_id.fetch_add(1, Ordering::Relaxed);
                                    sessions.insert(key, sid);
                                    let client_ip = client_addr.ip().to_string();
                                    let client_port = client_addr.port();
                                    let proxy_header = build_proxy_v2_header_from_str(
                                        &client_ip, "0.0.0.0", client_port, port,
                                        ProxyV2Transport::Udp,
                                    );
                                    // OPEN datagram: [session_id:4][0xFF magic:1][proxy_header:N]
                                    // (N depends on the address family of the header).
                                    // NOTE(review): a DATA datagram whose payload starts with
                                    // 0xFF has the same prefix; confirm the hub disambiguates
                                    // (e.g. by treating 0xFF as OPEN only for unknown sessions).
                                    let mut open_buf = Vec::with_capacity(4 + 1 + proxy_header.len());
                                    open_buf.extend_from_slice(&sid.to_be_bytes());
                                    open_buf.push(0xFF); // magic byte to distinguish OPEN from DATA
                                    open_buf.extend_from_slice(&proxy_header);
                                    if let Err(e) = quic_conn.send_datagram(open_buf.into()) {
                                        log::debug!("QUIC UDP session {} OPEN datagram not sent: {}", sid, e);
                                    }
                                    log::debug!("New QUIC UDP session {} from {} -> port {}", sid, client_addr, port);
                                    sid
                                };
                                drop(sessions);
                                // Send datagram: [session_id:4][payload:N]
                                let mut dgram = Vec::with_capacity(4 + len);
                                dgram.extend_from_slice(&stream_id.to_be_bytes());
                                dgram.extend_from_slice(&buf[..len]);
                                if let Err(e) = quic_conn.send_datagram(dgram.into()) {
                                    log::debug!(
                                        "QUIC UDP datagram for session {} on port {} not sent: {}",
                                        stream_id, port, e
                                    );
                                }
                            }
                            Err(e) => {
                                log::error!("QUIC UDP recv error on port {}: {}", port, e);
                            }
                        }
                    }
                    _ = port_token.cancelled() => {
                        log::info!("QUIC UDP port {} listener cancelled", port);
                        break;
                    }
                }
            }
        });
        udp_listeners.insert(port, handle);
    }
}
async fn handle_client_connection_quic(
client_stream: TcpStream,
client_addr: std::net::SocketAddr,
stream_id: u32,
dest_port: u16,
quic_conn: quinn::Connection,
client_token: CancellationToken,
) {
let client_ip = client_addr.ip().to_string();
let client_port = client_addr.port();
let edge_ip = "0.0.0.0";
// Open a new QUIC bidirectional stream for this client connection
let (mut quic_send, mut quic_recv) = match quic_conn.open_bi().await {
Ok(s) => s,
Err(e) => {
log::error!("Stream {} failed to open QUIC bi stream: {}", stream_id, e);
return;
}
};
// Send PROXY header as first bytes on the stream
let proxy_header = build_proxy_v1_header(&client_ip, edge_ip, client_port, dest_port);
if let Err(e) = quic_transport::write_proxy_header(&mut quic_send, &proxy_header).await {
log::error!("Stream {} failed to write PROXY header: {}", stream_id, e);
return;
}
let (mut client_read, mut client_write) = client_stream.into_split();
// Task: QUIC -> client (download direction)
let dl_token = client_token.clone();
let mut dl_task = tokio::spawn(async move {
let mut buf = vec![0u8; 32768];
loop {
tokio::select! {
read_result = quic_recv.read(&mut buf) => {
match read_result {
Ok(Some(n)) => {
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
Ok(None) => break, // QUIC stream finished
Err(_) => break,
}
}
_ = dl_token.cancelled() => break,
}
}
let _ = client_write.shutdown().await;
});
// Task: client -> QUIC (upload direction)
let mut buf = vec![0u8; 32768];
loop {
tokio::select! {
read_result = client_read.read(&mut buf) => {
match read_result {
Ok(0) => break, // client EOF
Ok(n) => {
if quic_send.write_all(&buf[..n]).await.is_err() {
break;
}
}
Err(_) => break,
}
}
_ = client_token.cancelled() => break,
}
}
// Wait for download task to finish before closing the QUIC stream
let _ = tokio::time::timeout(Duration::from_secs(300), &mut dl_task).await;
// Gracefully close the QUIC send stream
let _ = quic_send.finish();
dl_task.abort();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `json` into `T`, panicking on malformed input.
    fn from_json<T: serde::de::DeserializeOwned>(json: &str) -> T {
        serde_json::from_str(json).unwrap()
    }

    /// Serialize `value` into a `serde_json::Value`, panicking on failure.
    fn to_json<T: serde::Serialize>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).unwrap()
    }

    /// An `EdgeConfig` with no bind address and no transport override.
    fn plain_config(hub_host: &str, hub_port: u16, edge_id: &str, secret: &str) -> EdgeConfig {
        EdgeConfig {
            hub_host: hub_host.into(),
            hub_port,
            edge_id: edge_id.into(),
            secret: secret.into(),
            bind_address: None,
            transport_mode: None,
        }
    }

    /// JSON for the `hub.test:8443` edge `e1`/`s`, optionally with a `transportMode`.
    fn hub_test_config_json(transport_mode: Option<&str>) -> String {
        match transport_mode {
            Some(mode) => format!(
                r#"{{"hubHost": "hub.test", "hubPort": 8443, "edgeId": "e1", "secret": "s", "transportMode": "{}"}}"#,
                mode
            ),
            None => r#"{"hubHost": "hub.test", "hubPort": 8443, "edgeId": "e1", "secret": "s"}"#
                .to_string(),
        }
    }

    // --- Serde tests ---

    #[test]
    fn test_edge_config_deserialize_camel_case() {
        let config: EdgeConfig = from_json(
            r#"{"hubHost": "hub.example.com", "hubPort": 8443, "edgeId": "edge-1", "secret": "my-secret"}"#,
        );
        assert_eq!(config.hub_host, "hub.example.com");
        assert_eq!(config.hub_port, 8443);
        assert_eq!(config.edge_id, "edge-1");
        assert_eq!(config.secret, "my-secret");
    }

    #[test]
    fn test_edge_config_serialize_roundtrip() {
        let original = plain_config("host.test", 9999, "e1", "sec");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: EdgeConfig = from_json(&encoded);
        assert_eq!(decoded.hub_host, original.hub_host);
        assert_eq!(decoded.hub_port, original.hub_port);
        assert_eq!(decoded.edge_id, original.edge_id);
        assert_eq!(decoded.secret, original.secret);
    }

    #[test]
    fn test_handshake_config_deserialize_all_fields() {
        let hc: HandshakeConfig = from_json(r#"{"listenPorts": [80, 443], "stunIntervalSecs": 120}"#);
        assert_eq!(hc.listen_ports, vec![80, 443]);
        assert_eq!(hc.stun_interval_secs, 120);
    }

    #[test]
    fn test_edge_config_transport_mode_deserialize() {
        let config: EdgeConfig = from_json(&hub_test_config_json(Some("quic")));
        assert_eq!(config.transport_mode, Some(TransportMode::Quic));
    }

    #[test]
    fn test_edge_config_transport_mode_default() {
        let config: EdgeConfig = from_json(&hub_test_config_json(None));
        assert_eq!(config.transport_mode, None);
    }

    #[test]
    fn test_edge_config_transport_mode_quic_with_fallback() {
        let config: EdgeConfig = from_json(&hub_test_config_json(Some("quicWithFallback")));
        assert_eq!(config.transport_mode, Some(TransportMode::QuicWithFallback));
    }

    #[test]
    fn test_handshake_config_default_stun_interval() {
        let hc: HandshakeConfig = from_json(r#"{"listenPorts": [443]}"#);
        assert_eq!(hc.listen_ports, vec![443]);
        assert_eq!(hc.stun_interval_secs, 300);
    }

    #[test]
    fn test_config_update_deserialize() {
        let update: ConfigUpdate = from_json(r#"{"listenPorts": [8080, 9090]}"#);
        assert_eq!(update.listen_ports, vec![8080, 9090]);
    }

    #[test]
    fn test_edge_status_serialize() {
        let json = to_json(&EdgeStatus {
            running: true,
            connected: true,
            public_ip: Some("1.2.3.4".to_string()),
            active_streams: 5,
            listen_ports: vec![443],
        });
        assert_eq!(json["running"], true);
        assert_eq!(json["connected"], true);
        assert_eq!(json["publicIp"], "1.2.3.4");
        assert_eq!(json["activeStreams"], 5);
        assert_eq!(json["listenPorts"], serde_json::json!([443]));
    }

    #[test]
    fn test_edge_status_serialize_none_ip() {
        let json = to_json(&EdgeStatus {
            running: false,
            connected: false,
            public_ip: None,
            active_streams: 0,
            listen_ports: vec![],
        });
        assert!(json["publicIp"].is_null());
    }

    #[test]
    fn test_edge_event_tunnel_connected() {
        let json = to_json(&EdgeEvent::TunnelConnected);
        assert_eq!(json["type"], "tunnelConnected");
    }

    #[test]
    fn test_edge_event_tunnel_disconnected() {
        let json = to_json(&EdgeEvent::TunnelDisconnected { reason: "hub_eof".to_string() });
        assert_eq!(json["type"], "tunnelDisconnected");
        assert_eq!(json["reason"], "hub_eof");
    }

    #[test]
    fn test_edge_event_public_ip_discovered() {
        let json = to_json(&EdgeEvent::PublicIpDiscovered { ip: "203.0.113.1".to_string() });
        assert_eq!(json["type"], "publicIpDiscovered");
        assert_eq!(json["ip"], "203.0.113.1");
    }

    #[test]
    fn test_edge_event_ports_assigned() {
        let json = to_json(&EdgeEvent::PortsAssigned { listen_ports: vec![443, 8080] });
        assert_eq!(json["type"], "portsAssigned");
        assert_eq!(json["listenPorts"], serde_json::json!([443, 8080]));
    }

    #[test]
    fn test_edge_event_ports_updated() {
        let json = to_json(&EdgeEvent::PortsUpdated { listen_ports: vec![9090] });
        assert_eq!(json["type"], "portsUpdated");
        assert_eq!(json["listenPorts"], serde_json::json!([9090]));
    }

    // --- Async tests ---

    #[tokio::test]
    async fn test_tunnel_edge_new_get_status() {
        let edge = TunnelEdge::new(plain_config("localhost", 8443, "test-edge", "test-secret"));
        let status = edge.get_status().await;
        assert!(!status.running);
        assert!(!status.connected);
        assert!(status.public_ip.is_none());
        assert_eq!(status.active_streams, 0);
        assert!(status.listen_ports.is_empty());
    }

    #[tokio::test]
    async fn test_tunnel_edge_take_event_rx() {
        let edge = TunnelEdge::new(plain_config("localhost", 8443, "e", "s"));
        // The receiver can be taken exactly once.
        assert!(edge.take_event_rx().await.is_some());
        assert!(edge.take_event_rx().await.is_none());
    }

    #[tokio::test]
    async fn test_tunnel_edge_stop_without_start() {
        let edge = TunnelEdge::new(plain_config("localhost", 8443, "e", "s"));
        edge.stop().await; // should not panic
        assert!(!edge.get_status().await.running);
    }
}
/// TLS certificate verifier that accepts any certificate (auth is via shared secret).
#[derive(Debug)]
struct NoCertVerifier;
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::ED448,
]
}
}