use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex, Notify, RwLock};
use tokio::task::JoinHandle;
use tokio::time::{Instant, sleep_until};
use tokio_rustls::TlsConnector;
use tokio_util::sync::CancellationToken;
use serde::{Deserialize, Serialize};
use remoteingress_protocol::*;

/// Client-side TLS stream over plain TCP, as produced by `TlsConnector::connect`.
type EdgeTlsStream = tokio_rustls::client::TlsStream<TcpStream>;

/// Result of processing a frame (shared with hub.rs pattern).
#[allow(dead_code)]
enum EdgeFrameAction {
    Continue,
    Disconnect(String),
}

/// Per-stream state tracked in the edge's client_writers map.
struct EdgeStreamState {
    /// Unbounded channel to deliver FRAME_DATA_BACK payloads to the hub_to_client task.
    /// Unbounded because flow control (WINDOW_UPDATE) already limits bytes-in-flight.
    back_tx: mpsc::UnboundedSender<Vec<u8>>,
    /// Send window for FRAME_DATA (upload direction).
    /// Decremented by the client reader, incremented by FRAME_WINDOW_UPDATE_BACK from hub.
    send_window: Arc<AtomicU32>,
    /// Notifier to wake the client reader when the window opens.
    window_notify: Arc<Notify>,
}

/// Edge configuration (hub-host + credentials only; ports come from hub).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EdgeConfig {
    pub hub_host: String,
    pub hub_port: u16,
    pub edge_id: String,
    pub secret: String,
    /// Optional bind address for TCP listeners (defaults to "0.0.0.0").
    /// Useful for testing on localhost where edge and upstream share the same machine.
    #[serde(default)]
    pub bind_address: Option<String>,
}

/// Handshake config received from hub after authentication.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HandshakeConfig {
    listen_ports: Vec<u16>,
    #[serde(default = "default_stun_interval")]
    stun_interval_secs: u64,
}

/// Default STUN re-discovery interval (seconds) when the hub omits it.
fn default_stun_interval() -> u64 {
    300
}

/// Runtime config update received from hub via FRAME_CONFIG.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConfigUpdate {
    listen_ports: Vec<u16>,
}

/// Events emitted by the edge.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum EdgeEvent {
    TunnelConnected,
    #[serde(rename_all = "camelCase")]
    TunnelDisconnected { reason: String },
    #[serde(rename_all = "camelCase")]
    PublicIpDiscovered { ip: String },
    #[serde(rename_all = "camelCase")]
    PortsAssigned { listen_ports: Vec<u16> },
    #[serde(rename_all = "camelCase")]
    PortsUpdated { listen_ports: Vec<u16> },
}

/// Edge status response.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EdgeStatus {
    pub running: bool,
    pub connected: bool,
    pub public_ip: Option<String>,
    pub active_streams: usize,
    pub listen_ports: Vec<u16>,
}

/// The tunnel edge that listens for client connections and multiplexes them to the hub.
pub struct TunnelEdge { config: RwLock, event_tx: mpsc::Sender, event_rx: Mutex>>, shutdown_tx: Mutex>>, running: RwLock, connected: Arc>, public_ip: Arc>>, active_streams: Arc, next_stream_id: Arc, listen_ports: Arc>>, cancel_token: CancellationToken, } impl TunnelEdge { pub fn new(config: EdgeConfig) -> Self { let (event_tx, event_rx) = mpsc::channel(1024); Self { config: RwLock::new(config), event_tx, event_rx: Mutex::new(Some(event_rx)), shutdown_tx: Mutex::new(None), running: RwLock::new(false), connected: Arc::new(RwLock::new(false)), public_ip: Arc::new(RwLock::new(None)), active_streams: Arc::new(AtomicU32::new(0)), next_stream_id: Arc::new(AtomicU32::new(1)), listen_ports: Arc::new(RwLock::new(Vec::new())), cancel_token: CancellationToken::new(), } } /// Take the event receiver (can only be called once). pub async fn take_event_rx(&self) -> Option> { self.event_rx.lock().await.take() } /// Get the current edge status. pub async fn get_status(&self) -> EdgeStatus { EdgeStatus { running: *self.running.read().await, connected: *self.connected.read().await, public_ip: self.public_ip.read().await.clone(), active_streams: self.active_streams.load(Ordering::Relaxed) as usize, listen_ports: self.listen_ports.read().await.clone(), } } /// Start the edge: connect to hub, start listeners. 
pub async fn start(&self) -> Result<(), Box> { let config = self.config.read().await.clone(); let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1); *self.shutdown_tx.lock().await = Some(shutdown_tx); *self.running.write().await = true; let connected = self.connected.clone(); let public_ip = self.public_ip.clone(); let active_streams = self.active_streams.clone(); let next_stream_id = self.next_stream_id.clone(); let event_tx = self.event_tx.clone(); let listen_ports = self.listen_ports.clone(); let cancel_token = self.cancel_token.clone(); tokio::spawn(async move { edge_main_loop( config, connected, public_ip, active_streams, next_stream_id, event_tx, listen_ports, shutdown_rx, cancel_token, ) .await; }); Ok(()) } /// Stop the edge. pub async fn stop(&self) { self.cancel_token.cancel(); if let Some(tx) = self.shutdown_tx.lock().await.take() { let _ = tx.send(()).await; } *self.running.write().await = false; *self.connected.write().await = false; self.listen_ports.write().await.clear(); } } impl Drop for TunnelEdge { fn drop(&mut self) { self.cancel_token.cancel(); } } async fn edge_main_loop( config: EdgeConfig, connected: Arc>, public_ip: Arc>>, active_streams: Arc, next_stream_id: Arc, event_tx: mpsc::Sender, listen_ports: Arc>>, mut shutdown_rx: mpsc::Receiver<()>, cancel_token: CancellationToken, ) { let mut backoff_ms: u64 = 1000; let max_backoff_ms: u64 = 30000; // Build TLS config ONCE outside the reconnect loop — preserves session // cache across reconnections for TLS session resumption (saves 1 RTT). 
let tls_config = rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(NoCertVerifier)) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(tls_config)); loop { // Create a per-connection child token let connection_token = cancel_token.child_token(); // Try to connect to hub let result = connect_to_hub_and_run( &config, &connected, &public_ip, &active_streams, &next_stream_id, &event_tx, &listen_ports, &mut shutdown_rx, &connection_token, &connector, ) .await; // Cancel connection token to kill all orphaned tasks from this cycle connection_token.cancel(); // Reset backoff after a successful connection for fast reconnect let was_connected = *connected.read().await; if was_connected { backoff_ms = 1000; log::info!("Was connected; resetting backoff to {}ms for fast reconnect", backoff_ms); } *connected.write().await = false; // Extract reason for disconnect event let reason = match &result { EdgeLoopResult::Reconnect(r) => r.clone(), EdgeLoopResult::Shutdown => "shutdown".to_string(), }; // Only emit disconnect event on actual disconnection, not on failed reconnects. // Failed reconnects never reach line 335 (handshake success), so was_connected is false. if was_connected { let _ = event_tx.try_send(EdgeEvent::TunnelDisconnected { reason: reason.clone() }); } active_streams.store(0, Ordering::Relaxed); // Reset stream ID counter for next connection cycle next_stream_id.store(1, Ordering::Relaxed); listen_ports.write().await.clear(); match result { EdgeLoopResult::Shutdown => break, EdgeLoopResult::Reconnect(_) => { log::info!("Reconnecting in {}ms...", backoff_ms); tokio::select! 
{ _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {} _ = cancel_token.cancelled() => break, _ = shutdown_rx.recv() => break, } backoff_ms = (backoff_ms * 2).min(max_backoff_ms); } } } } enum EdgeLoopResult { Shutdown, Reconnect(String), // reason for disconnection } /// Process a single frame received from the hub side of the tunnel. /// Handles FRAME_DATA_BACK, FRAME_WINDOW_UPDATE_BACK, FRAME_CLOSE_BACK, FRAME_CONFIG, FRAME_PING. async fn handle_edge_frame( frame: Frame, tunnel_io: &mut remoteingress_protocol::TunnelIo, client_writers: &Arc>>, listen_ports: &Arc>>, event_tx: &mpsc::Sender, tunnel_writer_tx: &mpsc::Sender>, tunnel_data_tx: &mpsc::Sender>, port_listeners: &mut HashMap>, active_streams: &Arc, next_stream_id: &Arc, edge_id: &str, connection_token: &CancellationToken, bind_address: &str, ) -> EdgeFrameAction { match frame.frame_type { FRAME_DATA_BACK => { // Dispatch to per-stream unbounded channel. Flow control (WINDOW_UPDATE) // limits bytes-in-flight, so the channel won't grow unbounded. send() only // fails if the receiver is dropped (hub_to_client task already exited). 
let mut writers = client_writers.lock().await; if let Some(state) = writers.get(&frame.stream_id) { if state.back_tx.send(frame.payload).is_err() { // Receiver dropped — hub_to_client task already exited, clean up writers.remove(&frame.stream_id); } } } FRAME_WINDOW_UPDATE_BACK => { if let Some(increment) = decode_window_update(&frame.payload) { if increment > 0 { let writers = client_writers.lock().await; if let Some(state) = writers.get(&frame.stream_id) { let prev = state.send_window.fetch_add(increment, Ordering::Release); if prev + increment > MAX_WINDOW_SIZE { state.send_window.store(MAX_WINDOW_SIZE, Ordering::Release); } state.window_notify.notify_one(); } } } } FRAME_CLOSE_BACK => { let mut writers = client_writers.lock().await; writers.remove(&frame.stream_id); } FRAME_CONFIG => { if let Ok(update) = serde_json::from_slice::(&frame.payload) { log::info!("Config update from hub: ports {:?}", update.listen_ports); *listen_ports.write().await = update.listen_ports.clone(); let _ = event_tx.try_send(EdgeEvent::PortsUpdated { listen_ports: update.listen_ports.clone(), }); apply_port_config( &update.listen_ports, port_listeners, tunnel_writer_tx, tunnel_data_tx, client_writers, active_streams, next_stream_id, edge_id, connection_token, bind_address, ); } } FRAME_PING => { // Queue PONG directly — no channel round-trip, guaranteed delivery tunnel_io.queue_ctrl(encode_frame(0, FRAME_PONG, &[])); } _ => { log::warn!("Unexpected frame type {} from hub", frame.frame_type); } } EdgeFrameAction::Continue } async fn connect_to_hub_and_run( config: &EdgeConfig, connected: &Arc>, public_ip: &Arc>>, active_streams: &Arc, next_stream_id: &Arc, event_tx: &mpsc::Sender, listen_ports: &Arc>>, shutdown_rx: &mut mpsc::Receiver<()>, connection_token: &CancellationToken, connector: &TlsConnector, ) -> EdgeLoopResult { let addr = format!("{}:{}", config.hub_host, config.hub_port); let tcp = match TcpStream::connect(&addr).await { Ok(s) => { // Disable Nagle's algorithm for 
low-latency control frames (PING/PONG, WINDOW_UPDATE) let _ = s.set_nodelay(true); // TCP keepalive detects silent network failures (NAT timeout, path change) // faster than the 45s application-level liveness timeout. let ka = socket2::TcpKeepalive::new() .with_time(Duration::from_secs(30)); #[cfg(target_os = "linux")] let ka = ka.with_interval(Duration::from_secs(10)); let _ = socket2::SockRef::from(&s).set_tcp_keepalive(&ka); s } Err(e) => { log::error!("Failed to connect to hub at {}: {}", addr, e); return EdgeLoopResult::Reconnect(format!("tcp_connect_failed: {}", e)); } }; let server_name = rustls::pki_types::ServerName::try_from(config.hub_host.clone()) .unwrap_or_else(|_| rustls::pki_types::ServerName::try_from("remoteingress-hub".to_string()).unwrap()); let mut tls_stream = match connector.connect(server_name, tcp).await { Ok(s) => s, Err(e) => { log::error!("TLS handshake failed: {}", e); return EdgeLoopResult::Reconnect(format!("tls_handshake_failed: {}", e)); } }; // Send auth line (we own the whole stream — no split) let auth_line = format!("EDGE {} {}\n", config.edge_id, config.secret); if tls_stream.write_all(auth_line.as_bytes()).await.is_err() { return EdgeLoopResult::Reconnect("auth_write_failed".to_string()); } if tls_stream.flush().await.is_err() { return EdgeLoopResult::Reconnect("auth_flush_failed".to_string()); } // Read handshake line byte-by-byte (no BufReader — into_inner corrupts TLS state) let mut handshake_bytes = Vec::with_capacity(512); let mut byte = [0u8; 1]; loop { match tls_stream.read_exact(&mut byte).await { Ok(_) => { handshake_bytes.push(byte[0]); if byte[0] == b'\n' { break; } if handshake_bytes.len() > 8192 { return EdgeLoopResult::Reconnect("handshake_too_long".to_string()); } } Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { log::error!("Hub rejected connection (EOF before handshake)"); return EdgeLoopResult::Reconnect("hub_rejected_eof".to_string()); } Err(e) => { log::error!("Failed to read handshake 
response: {}", e); return EdgeLoopResult::Reconnect(format!("handshake_read_failed: {}", e)); } } } let handshake_line = String::from_utf8_lossy(&handshake_bytes); let handshake: HandshakeConfig = match serde_json::from_str(handshake_line.trim()) { Ok(h) => h, Err(e) => { log::error!("Invalid handshake response: {}", e); return EdgeLoopResult::Reconnect(format!("handshake_invalid: {}", e)); } }; log::info!( "Handshake from hub: ports {:?}, stun_interval {}s", handshake.listen_ports, handshake.stun_interval_secs ); *connected.write().await = true; let _ = event_tx.try_send(EdgeEvent::TunnelConnected); log::info!("Connected to hub at {}", addr); // Store initial ports and emit event *listen_ports.write().await = handshake.listen_ports.clone(); let _ = event_tx.try_send(EdgeEvent::PortsAssigned { listen_ports: handshake.listen_ports.clone(), }); // Start STUN discovery let stun_interval = handshake.stun_interval_secs; let public_ip_clone = public_ip.clone(); let event_tx_clone = event_tx.clone(); let stun_token = connection_token.clone(); let stun_handle = tokio::spawn(async move { loop { tokio::select! { ip_result = crate::stun::discover_public_ip() => { if let Some(ip) = ip_result { let mut pip = public_ip_clone.write().await; let changed = pip.as_ref() != Some(&ip); *pip = Some(ip.clone()); if changed { let _ = event_tx_clone.try_send(EdgeEvent::PublicIpDiscovered { ip }); } } } _ = stun_token.cancelled() => break, } tokio::select! { _ = tokio::time::sleep(Duration::from_secs(stun_interval)) => {} _ = stun_token.cancelled() => break, } } }); // Client socket map: stream_id -> per-stream state (back channel + flow control) let client_writers: Arc>> = Arc::new(Mutex::new(HashMap::new())); // QoS dual-channel: ctrl frames have priority over data frames. // Stream handlers send through these channels → TunnelIo drains them. 
let (tunnel_ctrl_tx, mut tunnel_ctrl_rx) = mpsc::channel::>(256); let (tunnel_data_tx, mut tunnel_data_rx) = mpsc::channel::>(4096); let tunnel_writer_tx = tunnel_ctrl_tx.clone(); // Start TCP listeners for initial ports let mut port_listeners: HashMap> = HashMap::new(); let bind_address = config.bind_address.as_deref().unwrap_or("0.0.0.0"); apply_port_config( &handshake.listen_ports, &mut port_listeners, &tunnel_writer_tx, &tunnel_data_tx, &client_writers, active_streams, next_stream_id, &config.edge_id, connection_token, bind_address, ); // Single-owner I/O engine — no tokio::io::split, no mutex let mut tunnel_io = remoteingress_protocol::TunnelIo::new(tls_stream, Vec::new()); let liveness_timeout_dur = Duration::from_secs(45); let mut last_activity = Instant::now(); let mut liveness_deadline = Box::pin(sleep_until(last_activity + liveness_timeout_dur)); let result = 'io_loop: loop { // Drain any buffered frames loop { let frame = match tunnel_io.try_parse_frame() { Some(Ok(f)) => f, Some(Err(e)) => { log::error!("Hub frame error: {}", e); break 'io_loop EdgeLoopResult::Reconnect(format!("hub_frame_error: {}", e)); } None => break, }; last_activity = Instant::now(); liveness_deadline.as_mut().reset(last_activity + liveness_timeout_dur); if let EdgeFrameAction::Disconnect(reason) = handle_edge_frame( frame, &mut tunnel_io, &client_writers, listen_ports, event_tx, &tunnel_writer_tx, &tunnel_data_tx, &mut port_listeners, active_streams, next_stream_id, &config.edge_id, connection_token, bind_address, ).await { break 'io_loop EdgeLoopResult::Reconnect(reason); } } // Poll I/O: write(ctrl→data), flush, read, channels, timers let event = std::future::poll_fn(|cx| { tunnel_io.poll_step(cx, &mut tunnel_ctrl_rx, &mut tunnel_data_rx, &mut liveness_deadline, connection_token) }).await; match event { remoteingress_protocol::TunnelEvent::Frame(frame) => { last_activity = Instant::now(); liveness_deadline.as_mut().reset(last_activity + liveness_timeout_dur); if let 
EdgeFrameAction::Disconnect(reason) = handle_edge_frame( frame, &mut tunnel_io, &client_writers, listen_ports, event_tx, &tunnel_writer_tx, &tunnel_data_tx, &mut port_listeners, active_streams, next_stream_id, &config.edge_id, connection_token, bind_address, ).await { break EdgeLoopResult::Reconnect(reason); } } remoteingress_protocol::TunnelEvent::Eof => { log::info!("Hub disconnected (EOF)"); break EdgeLoopResult::Reconnect("hub_eof".to_string()); } remoteingress_protocol::TunnelEvent::ReadError(e) => { log::error!("Hub frame read error: {}", e); break EdgeLoopResult::Reconnect(format!("hub_frame_error: {}", e)); } remoteingress_protocol::TunnelEvent::WriteError(e) => { log::error!("Tunnel write error: {}", e); break EdgeLoopResult::Reconnect(format!("tunnel_write_error: {}", e)); } remoteingress_protocol::TunnelEvent::LivenessTimeout => { log::warn!("Hub liveness timeout (no frames for {}s), reconnecting", liveness_timeout_dur.as_secs()); break EdgeLoopResult::Reconnect("liveness_timeout".to_string()); } remoteingress_protocol::TunnelEvent::Cancelled => { if shutdown_rx.try_recv().is_ok() { break EdgeLoopResult::Shutdown; } break EdgeLoopResult::Shutdown; } } }; // Graceful TLS shutdown: send close_notify so the hub sees a clean disconnect // instead of "peer closed connection without sending TLS close_notify". let mut tls_stream = tunnel_io.into_inner(); let _ = tokio::time::timeout( Duration::from_secs(2), tls_stream.shutdown(), ).await; // Cleanup connection_token.cancel(); stun_handle.abort(); for (_, h) in port_listeners.drain() { h.abort(); } result } /// Apply a new port configuration: spawn listeners for added ports, abort removed ports. 
fn apply_port_config( new_ports: &[u16], port_listeners: &mut HashMap>, tunnel_ctrl_tx: &mpsc::Sender>, tunnel_data_tx: &mpsc::Sender>, client_writers: &Arc>>, active_streams: &Arc, next_stream_id: &Arc, edge_id: &str, connection_token: &CancellationToken, bind_address: &str, ) { let new_set: std::collections::HashSet = new_ports.iter().copied().collect(); let old_set: std::collections::HashSet = port_listeners.keys().copied().collect(); // Remove ports no longer needed for &port in old_set.difference(&new_set) { if let Some(handle) = port_listeners.remove(&port) { log::info!("Stopping listener on port {}", port); handle.abort(); } } // Add new ports for &port in new_set.difference(&old_set) { let tunnel_ctrl_tx = tunnel_ctrl_tx.clone(); let tunnel_data_tx = tunnel_data_tx.clone(); let client_writers = client_writers.clone(); let active_streams = active_streams.clone(); let next_stream_id = next_stream_id.clone(); let edge_id = edge_id.to_string(); let port_token = connection_token.child_token(); let bind_addr = bind_address.to_string(); let handle = tokio::spawn(async move { let listener = match TcpListener::bind((bind_addr.as_str(), port)).await { Ok(l) => l, Err(e) => { log::error!("Failed to bind port {}: {}", port, e); return; } }; log::info!("Listening on port {}", port); loop { tokio::select! { accept_result = listener.accept() => { match accept_result { Ok((client_stream, client_addr)) => { // TCP keepalive detects dead clients that disappear without FIN. // Without this, zombie streams accumulate and never get cleaned up. 
let _ = client_stream.set_nodelay(true); let ka = socket2::TcpKeepalive::new() .with_time(Duration::from_secs(60)); #[cfg(target_os = "linux")] let ka = ka.with_interval(Duration::from_secs(60)); let _ = socket2::SockRef::from(&client_stream).set_tcp_keepalive(&ka); let stream_id = next_stream_id.fetch_add(1, Ordering::Relaxed); let tunnel_ctrl_tx = tunnel_ctrl_tx.clone(); let tunnel_data_tx = tunnel_data_tx.clone(); let client_writers = client_writers.clone(); let active_streams = active_streams.clone(); let edge_id = edge_id.clone(); let client_token = port_token.child_token(); active_streams.fetch_add(1, Ordering::Relaxed); tokio::spawn(async move { handle_client_connection( client_stream, client_addr, stream_id, port, &edge_id, tunnel_ctrl_tx, tunnel_data_tx, client_writers, client_token, Arc::clone(&active_streams), ) .await; // Saturating decrement: prevent underflow when // edge_main_loop's store(0) races with task cleanup. loop { let current = active_streams.load(Ordering::Relaxed); if current == 0 { break; } if active_streams.compare_exchange_weak( current, current - 1, Ordering::Relaxed, Ordering::Relaxed, ).is_ok() { break; } } }); } Err(e) => { log::error!("Accept error on port {}: {}", port, e); } } } _ = port_token.cancelled() => { log::info!("Port {} listener cancelled", port); break; } } } }); port_listeners.insert(port, handle); } } async fn handle_client_connection( client_stream: TcpStream, client_addr: std::net::SocketAddr, stream_id: u32, dest_port: u16, edge_id: &str, tunnel_ctrl_tx: mpsc::Sender>, tunnel_data_tx: mpsc::Sender>, client_writers: Arc>>, client_token: CancellationToken, active_streams: Arc, ) { let client_ip = client_addr.ip().to_string(); let client_port = client_addr.port(); // Determine edge IP (use 0.0.0.0 as placeholder — hub doesn't use it for routing) let edge_ip = "0.0.0.0"; // Send OPEN frame with PROXY v1 header via control channel let proxy_header = build_proxy_v1_header(&client_ip, edge_ip, client_port, dest_port); 
let open_frame = encode_frame(stream_id, FRAME_OPEN, proxy_header.as_bytes()); let send_ok = tokio::select! { result = tunnel_ctrl_tx.send(open_frame) => result.is_ok(), _ = client_token.cancelled() => false, }; if !send_ok { return; } // Per-stream unbounded back-channel. Flow control (WINDOW_UPDATE) limits // bytes-in-flight, so this won't grow unbounded. Unbounded avoids killing // streams due to channel overflow — backpressure slows streams, never kills them. let (back_tx, mut back_rx) = mpsc::unbounded_channel::>(); // Adaptive initial window: scale with current stream count to keep total in-flight // data within the 32MB budget. Prevents burst flooding when many streams open. let initial_window = remoteingress_protocol::compute_window_for_stream_count( active_streams.load(Ordering::Relaxed), ); let send_window = Arc::new(AtomicU32::new(initial_window)); let window_notify = Arc::new(Notify::new()); { let mut writers = client_writers.lock().await; writers.insert(stream_id, EdgeStreamState { back_tx, send_window: Arc::clone(&send_window), window_notify: Arc::clone(&window_notify), }); } let (mut client_read, mut client_write) = client_stream.into_split(); // Task: hub -> client (download direction) // After writing to client TCP, send WINDOW_UPDATE to hub so it can send more let hub_to_client_token = client_token.clone(); let wu_tx = tunnel_ctrl_tx.clone(); let active_streams_h2c = Arc::clone(&active_streams); let mut hub_to_client = tokio::spawn(async move { let mut consumed_since_update: u32 = 0; loop { tokio::select! { data = back_rx.recv() => { match data { Some(data) => { let len = data.len() as u32; if client_write.write_all(&data).await.is_err() { break; } // Track consumption for adaptive flow control. // The increment is capped to the adaptive window so the sender's // effective window shrinks to match current demand (fewer streams // = larger window, more streams = smaller window per stream). 
consumed_since_update += len; let adaptive_window = remoteingress_protocol::compute_window_for_stream_count( active_streams_h2c.load(Ordering::Relaxed), ); let threshold = adaptive_window / 2; if consumed_since_update >= threshold { let increment = consumed_since_update.min(adaptive_window); let frame = encode_window_update(stream_id, FRAME_WINDOW_UPDATE, increment); // Use send().await for guaranteed delivery — dropping WINDOW_UPDATEs // causes permanent flow stalls. Safe: runs in per-stream task, not main loop. tokio::select! { result = wu_tx.send(frame) => { if result.is_ok() { consumed_since_update -= increment; } } _ = hub_to_client_token.cancelled() => break, } } } None => break, } } _ = hub_to_client_token.cancelled() => break, } } // Send final window update for any remaining consumed bytes if consumed_since_update > 0 { let frame = encode_window_update(stream_id, FRAME_WINDOW_UPDATE, consumed_since_update); tokio::select! { _ = wu_tx.send(frame) => {} _ = hub_to_client_token.cancelled() => {} } } let _ = client_write.shutdown().await; }); // Task: client -> hub (upload direction) with per-stream flow control. // Zero-copy: read payload directly after the header, then prepend header. let mut buf = vec![0u8; FRAME_HEADER_SIZE + 32768]; loop { // Wait for send window to have capacity (with stall timeout). // Safe pattern: register notified BEFORE checking the condition // to avoid missing a notify_one that fires between load and select. loop { let notified = window_notify.notified(); tokio::pin!(notified); notified.as_mut().enable(); let w = send_window.load(Ordering::Acquire); if w > 0 { break; } tokio::select! { _ = notified => continue, _ = client_token.cancelled() => break, _ = tokio::time::sleep(Duration::from_secs(120)) => { log::warn!("Stream {} upload stalled (window empty for 120s)", stream_id); break; } } } if client_token.is_cancelled() { break; } // Limit read size to available window. 
// IMPORTANT: if window is 0 (stall timeout fired), we must NOT // read into an empty buffer — read(&mut buf[..0]) returns Ok(0) // which would be falsely interpreted as EOF. let w = send_window.load(Ordering::Acquire) as usize; if w == 0 { log::warn!("Stream {} upload: window still 0 after stall timeout, closing", stream_id); break; } // Adaptive: cap read to current per-stream target window let adaptive_cap = remoteingress_protocol::compute_window_for_stream_count( active_streams.load(Ordering::Relaxed), ) as usize; let max_read = w.min(32768).min(adaptive_cap); tokio::select! { read_result = client_read.read(&mut buf[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + max_read]) => { match read_result { Ok(0) => break, Ok(n) => { send_window.fetch_sub(n as u32, Ordering::Release); encode_frame_header(&mut buf, stream_id, FRAME_DATA, n); let data_frame = buf[..FRAME_HEADER_SIZE + n].to_vec(); if tunnel_data_tx.send(data_frame).await.is_err() { log::warn!("Stream {} data channel closed, closing", stream_id); break; } } Err(_) => break, } } _ = client_token.cancelled() => break, } } // Wait for the download task (hub → client) to finish BEFORE sending CLOSE. // Upload EOF (client done sending) does NOT mean the response is done. // For asymmetric transfers like git fetch (small request, large response), // the response is still streaming when the upload finishes. // Sending CLOSE before the response finishes would cause the hub to cancel // the upstream reader mid-response, truncating the data. let _ = tokio::time::timeout( Duration::from_secs(300), // 5 min max wait for download to finish &mut hub_to_client, ).await; // NOW send CLOSE — the response has been fully delivered (or timed out). // select! with cancellation guard prevents indefinite blocking if tunnel dies. if !client_token.is_cancelled() { let close_frame = encode_frame(stream_id, FRAME_CLOSE, &[]); tokio::select! 
{ _ = tunnel_data_tx.send(close_frame) => {} _ = client_token.cancelled() => {} } } // Clean up { let mut writers = client_writers.lock().await; writers.remove(&stream_id); } hub_to_client.abort(); // No-op if already finished; safety net if timeout fired let _ = edge_id; // used for logging context } #[cfg(test)] mod tests { use super::*; // --- Serde tests --- #[test] fn test_edge_config_deserialize_camel_case() { let json = r#"{ "hubHost": "hub.example.com", "hubPort": 8443, "edgeId": "edge-1", "secret": "my-secret" }"#; let config: EdgeConfig = serde_json::from_str(json).unwrap(); assert_eq!(config.hub_host, "hub.example.com"); assert_eq!(config.hub_port, 8443); assert_eq!(config.edge_id, "edge-1"); assert_eq!(config.secret, "my-secret"); } #[test] fn test_edge_config_serialize_roundtrip() { let config = EdgeConfig { hub_host: "host.test".to_string(), hub_port: 9999, edge_id: "e1".to_string(), secret: "sec".to_string(), bind_address: None, }; let json = serde_json::to_string(&config).unwrap(); let back: EdgeConfig = serde_json::from_str(&json).unwrap(); assert_eq!(back.hub_host, config.hub_host); assert_eq!(back.hub_port, config.hub_port); assert_eq!(back.edge_id, config.edge_id); assert_eq!(back.secret, config.secret); } #[test] fn test_handshake_config_deserialize_all_fields() { let json = r#"{"listenPorts": [80, 443], "stunIntervalSecs": 120}"#; let hc: HandshakeConfig = serde_json::from_str(json).unwrap(); assert_eq!(hc.listen_ports, vec![80, 443]); assert_eq!(hc.stun_interval_secs, 120); } #[test] fn test_handshake_config_default_stun_interval() { let json = r#"{"listenPorts": [443]}"#; let hc: HandshakeConfig = serde_json::from_str(json).unwrap(); assert_eq!(hc.listen_ports, vec![443]); assert_eq!(hc.stun_interval_secs, 300); } #[test] fn test_config_update_deserialize() { let json = r#"{"listenPorts": [8080, 9090]}"#; let update: ConfigUpdate = serde_json::from_str(json).unwrap(); assert_eq!(update.listen_ports, vec![8080, 9090]); } #[test] fn 
test_edge_status_serialize() { let status = EdgeStatus { running: true, connected: true, public_ip: Some("1.2.3.4".to_string()), active_streams: 5, listen_ports: vec![443], }; let json = serde_json::to_value(&status).unwrap(); assert_eq!(json["running"], true); assert_eq!(json["connected"], true); assert_eq!(json["publicIp"], "1.2.3.4"); assert_eq!(json["activeStreams"], 5); assert_eq!(json["listenPorts"], serde_json::json!([443])); } #[test] fn test_edge_status_serialize_none_ip() { let status = EdgeStatus { running: false, connected: false, public_ip: None, active_streams: 0, listen_ports: vec![], }; let json = serde_json::to_value(&status).unwrap(); assert!(json["publicIp"].is_null()); } #[test] fn test_edge_event_tunnel_connected() { let event = EdgeEvent::TunnelConnected; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], "tunnelConnected"); } #[test] fn test_edge_event_tunnel_disconnected() { let event = EdgeEvent::TunnelDisconnected { reason: "hub_eof".to_string() }; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], "tunnelDisconnected"); assert_eq!(json["reason"], "hub_eof"); } #[test] fn test_edge_event_public_ip_discovered() { let event = EdgeEvent::PublicIpDiscovered { ip: "203.0.113.1".to_string(), }; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], "publicIpDiscovered"); assert_eq!(json["ip"], "203.0.113.1"); } #[test] fn test_edge_event_ports_assigned() { let event = EdgeEvent::PortsAssigned { listen_ports: vec![443, 8080], }; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], "portsAssigned"); assert_eq!(json["listenPorts"], serde_json::json!([443, 8080])); } #[test] fn test_edge_event_ports_updated() { let event = EdgeEvent::PortsUpdated { listen_ports: vec![9090], }; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], "portsUpdated"); assert_eq!(json["listenPorts"], serde_json::json!([9090])); } // --- Async tests --- 
#[tokio::test] async fn test_tunnel_edge_new_get_status() { let edge = TunnelEdge::new(EdgeConfig { hub_host: "localhost".to_string(), hub_port: 8443, edge_id: "test-edge".to_string(), secret: "test-secret".to_string(), bind_address: None, }); let status = edge.get_status().await; assert!(!status.running); assert!(!status.connected); assert!(status.public_ip.is_none()); assert_eq!(status.active_streams, 0); assert!(status.listen_ports.is_empty()); } #[tokio::test] async fn test_tunnel_edge_take_event_rx() { let edge = TunnelEdge::new(EdgeConfig { hub_host: "localhost".to_string(), hub_port: 8443, edge_id: "e".to_string(), secret: "s".to_string(), bind_address: None, }); let rx1 = edge.take_event_rx().await; assert!(rx1.is_some()); let rx2 = edge.take_event_rx().await; assert!(rx2.is_none()); } #[tokio::test] async fn test_tunnel_edge_stop_without_start() { let edge = TunnelEdge::new(EdgeConfig { hub_host: "localhost".to_string(), hub_port: 8443, edge_id: "e".to_string(), secret: "s".to_string(), bind_address: None, }); edge.stop().await; // should not panic let status = edge.get_status().await; assert!(!status.running); } } /// TLS certificate verifier that accepts any certificate (auth is via shared secret). 
#[derive(Debug)] struct NoCertVerifier; impl rustls::client::danger::ServerCertVerifier for NoCertVerifier { fn verify_server_cert( &self, _end_entity: &rustls::pki_types::CertificateDer<'_>, _intermediates: &[rustls::pki_types::CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::ECDSA_NISTP521_SHA512, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::ED448, ] } }