1088 lines
43 KiB
Rust
1088 lines
43 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::atomic::{AtomicU32, Ordering};
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::sync::{mpsc, Mutex, Notify, RwLock};
|
|
use tokio::task::JoinHandle;
|
|
use tokio::time::{Instant, sleep_until};
|
|
use tokio_rustls::TlsConnector;
|
|
use tokio_util::sync::CancellationToken;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use remoteingress_protocol::*;
|
|
|
|
/// Per-stream state tracked in the edge's client_writers map.
///
/// One entry exists per live client TCP connection; the entry is removed on
/// FRAME_CLOSE_BACK, on back-channel overflow, or when the client handler ends.
struct EdgeStreamState {
    /// Channel to deliver FRAME_DATA_BACK payloads to the hub_to_client task.
    /// Sent to with try_send; if full, the stream is dropped (see the frame loop).
    back_tx: mpsc::Sender<Vec<u8>>,
    /// Send window for FRAME_DATA (upload direction), counted in bytes.
    /// Decremented by the client reader, incremented by FRAME_WINDOW_UPDATE_BACK from hub.
    send_window: Arc<AtomicU32>,
    /// Notifier to wake the client reader when the window opens.
    window_notify: Arc<Notify>,
}
|
|
|
|
/// Edge configuration (hub-host + credentials only; ports come from hub).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EdgeConfig {
    /// Hostname or IP of the hub to connect to.
    pub hub_host: String,
    /// TCP port of the hub's tunnel endpoint.
    pub hub_port: u16,
    /// Identifier sent in the `EDGE <id> <secret>` auth line.
    pub edge_id: String,
    /// Shared secret sent in the auth line.
    pub secret: String,
    /// Optional bind address for TCP listeners (defaults to "0.0.0.0").
    /// Useful for testing on localhost where edge and upstream share the same machine.
    #[serde(default)]
    pub bind_address: Option<String>,
}
|
|
|
|
/// Handshake config received from hub after authentication.
///
/// Parsed from the first JSON line the hub sends back after the auth line.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HandshakeConfig {
    /// Ports this edge should listen on for incoming client connections.
    listen_ports: Vec<u16>,
    /// Interval between STUN public-IP discovery rounds; defaults to 300s
    /// when the hub omits the field.
    #[serde(default = "default_stun_interval")]
    stun_interval_secs: u64,
}
|
|
|
|
/// Serde default for `HandshakeConfig::stun_interval_secs`: five minutes.
fn default_stun_interval() -> u64 {
    5 * 60
}
|
|
|
|
/// Runtime config update received from hub via FRAME_CONFIG.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConfigUpdate {
    /// Full replacement set of listener ports (not a delta); the edge diffs
    /// it against the currently bound ports in apply_port_config.
    listen_ports: Vec<u16>,
}
|
|
|
|
/// Events emitted by the edge.
///
/// Serialized with an external `"type"` tag in camelCase, e.g.
/// `{"type":"portsAssigned","listenPorts":[443]}`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum EdgeEvent {
    /// Tunnel handshake with the hub completed.
    TunnelConnected,
    /// A previously established tunnel dropped (not emitted for failed reconnect attempts).
    TunnelDisconnected,
    /// STUN discovered a (new) public IP for this edge.
    #[serde(rename_all = "camelCase")]
    PublicIpDiscovered { ip: String },
    /// Initial listener ports received in the hub handshake.
    #[serde(rename_all = "camelCase")]
    PortsAssigned { listen_ports: Vec<u16> },
    /// Listener ports changed at runtime via FRAME_CONFIG.
    #[serde(rename_all = "camelCase")]
    PortsUpdated { listen_ports: Vec<u16> },
}
|
|
|
|
/// Edge status response (snapshot produced by [`TunnelEdge::get_status`]).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EdgeStatus {
    /// True between start() and stop().
    pub running: bool,
    /// True while the tunnel to the hub is up.
    pub connected: bool,
    /// Last STUN-discovered public IP, if any.
    pub public_ip: Option<String>,
    /// Number of currently open client streams.
    pub active_streams: usize,
    /// Ports currently assigned by the hub.
    pub listen_ports: Vec<u16>,
}
|
|
|
|
/// The tunnel edge that listens for client connections and multiplexes them to the hub.
pub struct TunnelEdge {
    /// Connection settings; read once per start().
    config: RwLock<EdgeConfig>,
    /// Producer side of the event stream handed to the embedder.
    event_tx: mpsc::Sender<EdgeEvent>,
    /// Consumer side; taken at most once via take_event_rx().
    event_rx: Mutex<Option<mpsc::Receiver<EdgeEvent>>>,
    /// Signals the background main loop to exit; created per start().
    shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
    /// Whether start() has been called without a matching stop().
    running: RwLock<bool>,
    /// Whether the hub tunnel is currently established (shared with the main loop).
    connected: Arc<RwLock<bool>>,
    /// Last STUN-discovered public IP (shared with the STUN task).
    public_ip: Arc<RwLock<Option<String>>>,
    /// Live client-stream counter (shared with listener tasks).
    active_streams: Arc<AtomicU32>,
    /// Monotonic per-connection stream-id allocator; reset to 1 on reconnect.
    next_stream_id: Arc<AtomicU32>,
    /// Ports currently bound (shared with the main loop).
    listen_ports: Arc<RwLock<Vec<u16>>>,
    /// Root cancellation token; child tokens are derived per connection/port/client.
    cancel_token: CancellationToken,
}
|
|
|
|
impl TunnelEdge {
    /// Create a new, stopped edge. Nothing connects until [`TunnelEdge::start`].
    pub fn new(config: EdgeConfig) -> Self {
        // 1024-deep event buffer; producers use try_send so a slow consumer
        // drops events rather than blocking the tunnel.
        let (event_tx, event_rx) = mpsc::channel(1024);
        Self {
            config: RwLock::new(config),
            event_tx,
            event_rx: Mutex::new(Some(event_rx)),
            shutdown_tx: Mutex::new(None),
            running: RwLock::new(false),
            connected: Arc::new(RwLock::new(false)),
            public_ip: Arc::new(RwLock::new(None)),
            active_streams: Arc::new(AtomicU32::new(0)),
            next_stream_id: Arc::new(AtomicU32::new(1)),
            listen_ports: Arc::new(RwLock::new(Vec::new())),
            cancel_token: CancellationToken::new(),
        }
    }

    /// Take the event receiver (can only be called once).
    ///
    /// Returns `None` on every call after the first.
    pub async fn take_event_rx(&self) -> Option<mpsc::Receiver<EdgeEvent>> {
        self.event_rx.lock().await.take()
    }

    /// Get the current edge status.
    ///
    /// Fields are read under separate locks, so the snapshot is not atomic
    /// across fields — fine for status reporting.
    pub async fn get_status(&self) -> EdgeStatus {
        EdgeStatus {
            running: *self.running.read().await,
            connected: *self.connected.read().await,
            public_ip: self.public_ip.read().await.clone(),
            active_streams: self.active_streams.load(Ordering::Relaxed) as usize,
            listen_ports: self.listen_ports.read().await.clone(),
        }
    }

    /// Start the edge: connect to hub, start listeners.
    ///
    /// Spawns the background reconnect loop and returns immediately; progress
    /// is reported through the event channel. Always returns `Ok(())` today —
    /// connection errors surface as reconnect attempts, not as a start error.
    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let config = self.config.read().await.clone();
        let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
        *self.shutdown_tx.lock().await = Some(shutdown_tx);
        *self.running.write().await = true;

        // Clone all shared state into the background task.
        let connected = self.connected.clone();
        let public_ip = self.public_ip.clone();
        let active_streams = self.active_streams.clone();
        let next_stream_id = self.next_stream_id.clone();
        let event_tx = self.event_tx.clone();
        let listen_ports = self.listen_ports.clone();
        let cancel_token = self.cancel_token.clone();

        tokio::spawn(async move {
            edge_main_loop(
                config,
                connected,
                public_ip,
                active_streams,
                next_stream_id,
                event_tx,
                listen_ports,
                shutdown_rx,
                cancel_token,
            )
            .await;
        });

        Ok(())
    }

    /// Stop the edge.
    ///
    /// Cancels the root token (propagates to every connection/port/client task),
    /// signals the main loop, and resets the externally visible status fields.
    /// NOTE(review): the root token is never re-created, so a start() after
    /// stop() would inherit an already-cancelled token — confirm the intended
    /// lifecycle is one start/stop per TunnelEdge.
    pub async fn stop(&self) {
        self.cancel_token.cancel();
        if let Some(tx) = self.shutdown_tx.lock().await.take() {
            let _ = tx.send(()).await;
        }
        *self.running.write().await = false;
        *self.connected.write().await = false;
        self.listen_ports.write().await.clear();
    }
}
|
|
|
|
impl Drop for TunnelEdge {
    /// Best-effort cleanup: cancel the root token so background tasks exit
    /// even if stop() was never awaited.
    fn drop(&mut self) {
        self.cancel_token.cancel();
    }
}
|
|
|
|
/// Background reconnect loop: connects to the hub, runs the tunnel until it
/// drops, cleans up shared state, and retries with exponential backoff
/// (1s..30s, reset after any successful connection).
async fn edge_main_loop(
    config: EdgeConfig,
    connected: Arc<RwLock<bool>>,
    public_ip: Arc<RwLock<Option<String>>>,
    active_streams: Arc<AtomicU32>,
    next_stream_id: Arc<AtomicU32>,
    event_tx: mpsc::Sender<EdgeEvent>,
    listen_ports: Arc<RwLock<Vec<u16>>>,
    mut shutdown_rx: mpsc::Receiver<()>,
    cancel_token: CancellationToken,
) {
    let mut backoff_ms: u64 = 1000;
    let max_backoff_ms: u64 = 30000;

    // Build TLS config ONCE outside the reconnect loop — preserves session
    // cache across reconnections for TLS session resumption (saves 1 RTT).
    // NoCertVerifier disables certificate validation (edge authenticates the
    // hub via the shared secret handshake instead).
    let tls_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(NoCertVerifier))
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(tls_config));

    loop {
        // Create a per-connection child token
        let connection_token = cancel_token.child_token();

        // Try to connect to hub
        let result = connect_to_hub_and_run(
            &config,
            &connected,
            &public_ip,
            &active_streams,
            &next_stream_id,
            &event_tx,
            &listen_ports,
            &mut shutdown_rx,
            &connection_token,
            &connector,
        )
        .await;

        // Cancel connection token to kill all orphaned tasks from this cycle
        connection_token.cancel();

        // Reset backoff after a successful connection for fast reconnect
        let was_connected = *connected.read().await;
        if was_connected {
            backoff_ms = 1000;
            log::info!("Was connected; resetting backoff to {}ms for fast reconnect", backoff_ms);
        }

        *connected.write().await = false;
        // Only emit disconnect event on actual disconnection, not on failed reconnects.
        // Failed reconnects never get past the hub handshake, so `connected` is
        // never set and was_connected stays false.
        if was_connected {
            let _ = event_tx.try_send(EdgeEvent::TunnelDisconnected);
        }
        active_streams.store(0, Ordering::Relaxed);
        // Reset stream ID counter for next connection cycle
        next_stream_id.store(1, Ordering::Relaxed);
        listen_ports.write().await.clear();

        match result {
            EdgeLoopResult::Shutdown => break,
            EdgeLoopResult::Reconnect => {
                log::info!("Reconnecting in {}ms...", backoff_ms);
                // Backoff sleep is interruptible by shutdown or cancellation.
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {}
                    _ = cancel_token.cancelled() => break,
                    _ = shutdown_rx.recv() => break,
                }
                backoff_ms = (backoff_ms * 2).min(max_backoff_ms);
            }
        }
    }
}
|
|
|
|
/// Outcome of one tunnel connection cycle, telling `edge_main_loop` what to do
/// next. Derives added eagerly: the type is a trivial two-variant flag, and
/// `Debug`/`PartialEq` make it loggable and testable at zero cost.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EdgeLoopResult {
    /// Stop the edge entirely (shutdown requested or token cancelled).
    Shutdown,
    /// Connection ended abnormally; retry after backoff.
    Reconnect,
}
|
|
|
|
/// Establish one tunnel to the hub and run it to completion.
///
/// Sequence: TCP connect → TLS handshake → auth line → JSON handshake →
/// spawn STUN task + QoS tunnel writer → bind client listeners → frame loop.
/// Returns `Reconnect` on any transport failure, `Shutdown` on cancellation
/// or explicit shutdown. All tasks spawned here are tied to `connection_token`
/// and are cancelled/aborted before returning.
async fn connect_to_hub_and_run(
    config: &EdgeConfig,
    connected: &Arc<RwLock<bool>>,
    public_ip: &Arc<RwLock<Option<String>>>,
    active_streams: &Arc<AtomicU32>,
    next_stream_id: &Arc<AtomicU32>,
    event_tx: &mpsc::Sender<EdgeEvent>,
    listen_ports: &Arc<RwLock<Vec<u16>>>,
    shutdown_rx: &mut mpsc::Receiver<()>,
    connection_token: &CancellationToken,
    connector: &TlsConnector,
) -> EdgeLoopResult {
    let addr = format!("{}:{}", config.hub_host, config.hub_port);
    let tcp = match TcpStream::connect(&addr).await {
        Ok(s) => {
            // Disable Nagle's algorithm for low-latency control frames (PING/PONG, WINDOW_UPDATE)
            let _ = s.set_nodelay(true);
            // TCP keepalive detects silent network failures (NAT timeout, path change)
            // faster than the 45s application-level liveness timeout.
            let ka = socket2::TcpKeepalive::new()
                .with_time(Duration::from_secs(30));
            #[cfg(target_os = "linux")]
            let ka = ka.with_interval(Duration::from_secs(10));
            let _ = socket2::SockRef::from(&s).set_tcp_keepalive(&ka);
            s
        }
        Err(e) => {
            log::error!("Failed to connect to hub at {}: {}", addr, e);
            return EdgeLoopResult::Reconnect;
        }
    };

    // SNI name from config; fall back to a fixed name if the host is an IP
    // literal (cert is not verified anyway — see NoCertVerifier).
    let server_name = rustls::pki_types::ServerName::try_from(config.hub_host.clone())
        .unwrap_or_else(|_| rustls::pki_types::ServerName::try_from("remoteingress-hub".to_string()).unwrap());

    let tls_stream = match connector.connect(server_name, tcp).await {
        Ok(s) => s,
        Err(e) => {
            log::error!("TLS handshake failed: {}", e);
            return EdgeLoopResult::Reconnect;
        }
    };

    let (read_half, mut write_half) = tokio::io::split(tls_stream);

    // Send auth line
    let auth_line = format!("EDGE {} {}\n", config.edge_id, config.secret);
    if write_half.write_all(auth_line.as_bytes()).await.is_err() {
        return EdgeLoopResult::Reconnect;
    }

    // Read handshake response line from hub (JSON with initial config)
    let mut buf_reader = BufReader::new(read_half);
    let mut handshake_line = String::new();
    match buf_reader.read_line(&mut handshake_line).await {
        Ok(0) => {
            log::error!("Hub rejected connection (EOF before handshake)");
            return EdgeLoopResult::Reconnect;
        }
        Ok(_) => {}
        Err(e) => {
            log::error!("Failed to read handshake response: {}", e);
            return EdgeLoopResult::Reconnect;
        }
    }

    let handshake: HandshakeConfig = match serde_json::from_str(handshake_line.trim()) {
        Ok(h) => h,
        Err(e) => {
            log::error!("Invalid handshake response: {}", e);
            return EdgeLoopResult::Reconnect;
        }
    };

    log::info!(
        "Handshake from hub: ports {:?}, stun_interval {}s",
        handshake.listen_ports,
        handshake.stun_interval_secs
    );

    // From this point on we are "connected": edge_main_loop uses this flag to
    // decide whether to emit TunnelDisconnected and reset backoff.
    *connected.write().await = true;
    let _ = event_tx.try_send(EdgeEvent::TunnelConnected);
    log::info!("Connected to hub at {}", addr);

    // Store initial ports and emit event
    *listen_ports.write().await = handshake.listen_ports.clone();
    let _ = event_tx.try_send(EdgeEvent::PortsAssigned {
        listen_ports: handshake.listen_ports.clone(),
    });

    // Start STUN discovery: runs immediately, then every stun_interval seconds.
    // Emits PublicIpDiscovered only when the IP actually changes.
    let stun_interval = handshake.stun_interval_secs;
    let public_ip_clone = public_ip.clone();
    let event_tx_clone = event_tx.clone();
    let stun_token = connection_token.clone();
    let stun_handle = tokio::spawn(async move {
        loop {
            tokio::select! {
                ip_result = crate::stun::discover_public_ip() => {
                    if let Some(ip) = ip_result {
                        let mut pip = public_ip_clone.write().await;
                        let changed = pip.as_ref() != Some(&ip);
                        *pip = Some(ip.clone());
                        if changed {
                            let _ = event_tx_clone.try_send(EdgeEvent::PublicIpDiscovered { ip });
                        }
                    }
                }
                _ = stun_token.cancelled() => break,
            }
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_secs(stun_interval)) => {}
                _ = stun_token.cancelled() => break,
            }
        }
    });

    // Client socket map: stream_id -> per-stream state (back channel + flow control)
    let client_writers: Arc<Mutex<HashMap<u32, EdgeStreamState>>> =
        Arc::new(Mutex::new(HashMap::new()));

    // QoS dual-channel tunnel writer: control frames (PONG/WINDOW_UPDATE/CLOSE/OPEN)
    // have priority over data frames (DATA). Prevents PING starvation under load.
    let (tunnel_ctrl_tx, mut tunnel_ctrl_rx) = mpsc::channel::<Vec<u8>>(256);
    let (tunnel_data_tx, mut tunnel_data_rx) = mpsc::channel::<Vec<u8>>(4096);
    // Legacy alias — control channel for PONG, CLOSE, WINDOW_UPDATE, OPEN
    let tunnel_writer_tx = tunnel_ctrl_tx.clone();
    let tw_token = connection_token.clone();
    // Oneshot to signal the reader loop when the writer dies from a write error.
    // This avoids the 45s liveness timeout delay when the tunnel is already dead.
    let (writer_dead_tx, mut writer_dead_rx) = tokio::sync::oneshot::channel::<()>();
    let tunnel_writer_handle = tokio::spawn(async move {
        // BufWriter coalesces small writes (frame headers, control frames) into fewer
        // TLS records and syscalls. Flushed after each frame to avoid holding data.
        let mut writer = tokio::io::BufWriter::with_capacity(65536, write_half);
        let mut write_error = false;
        loop {
            tokio::select! {
                biased; // control frames always take priority over data
                ctrl = tunnel_ctrl_rx.recv() => {
                    match ctrl {
                        Some(frame_data) => {
                            if writer.write_all(&frame_data).await.is_err() { write_error = true; break; }
                            if writer.flush().await.is_err() { write_error = true; break; }
                        }
                        None => break,
                    }
                }
                data = tunnel_data_rx.recv() => {
                    match data {
                        Some(frame_data) => {
                            if writer.write_all(&frame_data).await.is_err() { write_error = true; break; }
                            if writer.flush().await.is_err() { write_error = true; break; }
                        }
                        None => break,
                    }
                }
                _ = tw_token.cancelled() => break,
            }
        }
        if write_error {
            log::error!("Tunnel writer failed, signalling reader for fast reconnect");
            let _ = writer_dead_tx.send(());
        }
    });

    // Start TCP listeners for initial ports (hot-reloadable)
    let mut port_listeners: HashMap<u16, JoinHandle<()>> = HashMap::new();
    let bind_address = config.bind_address.as_deref().unwrap_or("0.0.0.0");
    apply_port_config(
        &handshake.listen_ports,
        &mut port_listeners,
        &tunnel_writer_tx,
        &tunnel_data_tx,
        &client_writers,
        active_streams,
        next_stream_id,
        &config.edge_id,
        connection_token,
        bind_address,
    );

    // Heartbeat: liveness timeout detects silent hub failures.
    // The deadline is pinned once and reset on each received frame.
    let liveness_timeout_dur = Duration::from_secs(45);
    let mut last_activity = Instant::now();
    let mut liveness_deadline = Box::pin(sleep_until(last_activity + liveness_timeout_dur));

    // Read frames from hub
    let mut frame_reader = FrameReader::new(buf_reader);
    let result = loop {
        tokio::select! {
            frame_result = frame_reader.next_frame() => {
                match frame_result {
                    Ok(Some(frame)) => {
                        // Reset liveness on any received frame
                        last_activity = Instant::now();
                        liveness_deadline.as_mut().reset(last_activity + liveness_timeout_dur);

                        match frame.frame_type {
                            FRAME_DATA_BACK => {
                                // Non-blocking dispatch to per-stream channel.
                                // With flow control, the sender should rarely exceed the channel capacity.
                                let mut writers = client_writers.lock().await;
                                if let Some(state) = writers.get(&frame.stream_id) {
                                    if state.back_tx.try_send(frame.payload).is_err() {
                                        log::warn!("Stream {} back-channel full, closing stream", frame.stream_id);
                                        writers.remove(&frame.stream_id);
                                    }
                                }
                            }
                            FRAME_WINDOW_UPDATE_BACK => {
                                // Hub consumed data — increase our send window for this stream (upload direction)
                                if let Some(increment) = decode_window_update(&frame.payload) {
                                    if increment > 0 {
                                        let writers = client_writers.lock().await;
                                        if let Some(state) = writers.get(&frame.stream_id) {
                                            // Clamp the window at MAX_WINDOW_SIZE.
                                            // NOTE(review): `prev + increment` can wrap on u32
                                            // overflow for a hostile/buggy hub increment, which
                                            // would defeat the clamp — consider saturating_add.
                                            let prev = state.send_window.fetch_add(increment, Ordering::Release);
                                            if prev + increment > MAX_WINDOW_SIZE {
                                                state.send_window.store(MAX_WINDOW_SIZE, Ordering::Release);
                                            }
                                            state.window_notify.notify_one();
                                        }
                                    }
                                }
                            }
                            FRAME_CLOSE_BACK => {
                                // Dropping the state closes back_tx, which ends the
                                // hub_to_client task for that stream.
                                let mut writers = client_writers.lock().await;
                                writers.remove(&frame.stream_id);
                            }
                            FRAME_CONFIG => {
                                if let Ok(update) = serde_json::from_slice::<ConfigUpdate>(&frame.payload) {
                                    log::info!("Config update from hub: ports {:?}", update.listen_ports);
                                    *listen_ports.write().await = update.listen_ports.clone();
                                    let _ = event_tx.try_send(EdgeEvent::PortsUpdated {
                                        listen_ports: update.listen_ports.clone(),
                                    });
                                    apply_port_config(
                                        &update.listen_ports,
                                        &mut port_listeners,
                                        &tunnel_writer_tx,
                                        &tunnel_data_tx,
                                        &client_writers,
                                        active_streams,
                                        next_stream_id,
                                        &config.edge_id,
                                        connection_token,
                                        bind_address,
                                    );
                                }
                            }
                            FRAME_PING => {
                                let pong_frame = encode_frame(0, FRAME_PONG, &[]);
                                if tunnel_writer_tx.try_send(pong_frame).is_err() {
                                    // Control channel full (WINDOW_UPDATE burst from many streams).
                                    // DON'T disconnect — the 45s liveness timeout gives margin
                                    // for the channel to drain and the next PONG to succeed.
                                    log::warn!("PONG send failed, control channel full — skipping this cycle");
                                }
                                log::trace!("Received PING from hub, sent PONG");
                            }
                            _ => {
                                log::warn!("Unexpected frame type {} from hub", frame.frame_type);
                            }
                        }
                    }
                    Ok(None) => {
                        log::info!("Hub disconnected (EOF)");
                        break EdgeLoopResult::Reconnect;
                    }
                    Err(e) => {
                        log::error!("Hub frame error: {}", e);
                        break EdgeLoopResult::Reconnect;
                    }
                }
            }
            _ = &mut liveness_deadline => {
                log::warn!("Hub liveness timeout (no frames for {}s), reconnecting",
                    liveness_timeout_dur.as_secs());
                break EdgeLoopResult::Reconnect;
            }
            _ = &mut writer_dead_rx => {
                log::error!("Tunnel writer died, reconnecting immediately");
                break EdgeLoopResult::Reconnect;
            }
            _ = connection_token.cancelled() => {
                log::info!("Connection cancelled");
                break EdgeLoopResult::Shutdown;
            }
            _ = shutdown_rx.recv() => {
                break EdgeLoopResult::Shutdown;
            }
        }
    };

    // Cancel connection token to propagate to all child tasks BEFORE aborting
    connection_token.cancel();
    stun_handle.abort();
    tunnel_writer_handle.abort();
    for (_, h) in port_listeners.drain() {
        h.abort();
    }

    result
}
|
|
|
|
/// Apply a new port configuration: spawn listeners for added ports, abort removed ports.
///
/// Diffs `new_ports` against the currently running listeners; unchanged ports
/// are left untouched so their accepted connections survive a config update.
/// Each listener task is a child of `connection_token` and dies with the tunnel.
fn apply_port_config(
    new_ports: &[u16],
    port_listeners: &mut HashMap<u16, JoinHandle<()>>,
    tunnel_ctrl_tx: &mpsc::Sender<Vec<u8>>,
    tunnel_data_tx: &mpsc::Sender<Vec<u8>>,
    client_writers: &Arc<Mutex<HashMap<u32, EdgeStreamState>>>,
    active_streams: &Arc<AtomicU32>,
    next_stream_id: &Arc<AtomicU32>,
    edge_id: &str,
    connection_token: &CancellationToken,
    bind_address: &str,
) {
    let new_set: std::collections::HashSet<u16> = new_ports.iter().copied().collect();
    let old_set: std::collections::HashSet<u16> = port_listeners.keys().copied().collect();

    // Remove ports no longer needed
    for &port in old_set.difference(&new_set) {
        if let Some(handle) = port_listeners.remove(&port) {
            log::info!("Stopping listener on port {}", port);
            handle.abort();
        }
    }

    // Add new ports
    for &port in new_set.difference(&old_set) {
        let tunnel_ctrl_tx = tunnel_ctrl_tx.clone();
        let tunnel_data_tx = tunnel_data_tx.clone();
        let client_writers = client_writers.clone();
        let active_streams = active_streams.clone();
        let next_stream_id = next_stream_id.clone();
        let edge_id = edge_id.to_string();
        let port_token = connection_token.child_token();

        let bind_addr = bind_address.to_string();
        let handle = tokio::spawn(async move {
            // Bind failure is logged and the task exits; no retry here.
            // NOTE(review): the dead handle stays in port_listeners, so this
            // port won't be retried until it leaves and re-enters the config.
            let listener = match TcpListener::bind((bind_addr.as_str(), port)).await {
                Ok(l) => l,
                Err(e) => {
                    log::error!("Failed to bind port {}: {}", port, e);
                    return;
                }
            };
            log::info!("Listening on port {}", port);

            loop {
                tokio::select! {
                    accept_result = listener.accept() => {
                        match accept_result {
                            Ok((client_stream, client_addr)) => {
                                // TCP keepalive detects dead clients that disappear without FIN.
                                // Without this, zombie streams accumulate and never get cleaned up.
                                let _ = client_stream.set_nodelay(true);
                                let ka = socket2::TcpKeepalive::new()
                                    .with_time(Duration::from_secs(60));
                                #[cfg(target_os = "linux")]
                                let ka = ka.with_interval(Duration::from_secs(60));
                                let _ = socket2::SockRef::from(&client_stream).set_tcp_keepalive(&ka);

                                let stream_id = next_stream_id.fetch_add(1, Ordering::Relaxed);
                                let tunnel_ctrl_tx = tunnel_ctrl_tx.clone();
                                let tunnel_data_tx = tunnel_data_tx.clone();
                                let client_writers = client_writers.clone();
                                let active_streams = active_streams.clone();
                                let edge_id = edge_id.clone();
                                let client_token = port_token.child_token();

                                active_streams.fetch_add(1, Ordering::Relaxed);

                                tokio::spawn(async move {
                                    handle_client_connection(
                                        client_stream,
                                        client_addr,
                                        stream_id,
                                        port,
                                        &edge_id,
                                        tunnel_ctrl_tx,
                                        tunnel_data_tx,
                                        client_writers,
                                        client_token,
                                        Arc::clone(&active_streams),
                                    )
                                    .await;
                                    // Saturating decrement: prevent underflow when
                                    // edge_main_loop's store(0) races with task cleanup.
                                    loop {
                                        let current = active_streams.load(Ordering::Relaxed);
                                        if current == 0 { break; }
                                        if active_streams.compare_exchange_weak(
                                            current, current - 1,
                                            Ordering::Relaxed, Ordering::Relaxed,
                                        ).is_ok() {
                                            break;
                                        }
                                    }
                                });
                            }
                            Err(e) => {
                                log::error!("Accept error on port {}: {}", port, e);
                            }
                        }
                    }
                    _ = port_token.cancelled() => {
                        log::info!("Port {} listener cancelled", port);
                        break;
                    }
                }
            }
        });
        port_listeners.insert(port, handle);
    }
}
|
|
|
|
/// Proxy one client TCP connection over the tunnel as stream `stream_id`.
///
/// Opens the stream with an OPEN frame (PROXY v1 header), then runs two
/// directions: a spawned hub→client task draining `back_rx`, and this task's
/// client→hub read loop with per-stream flow control. The upload EOF does NOT
/// close the stream; CLOSE is only sent after the download side finishes.
async fn handle_client_connection(
    client_stream: TcpStream,
    client_addr: std::net::SocketAddr,
    stream_id: u32,
    dest_port: u16,
    edge_id: &str,
    tunnel_ctrl_tx: mpsc::Sender<Vec<u8>>,
    tunnel_data_tx: mpsc::Sender<Vec<u8>>,
    client_writers: Arc<Mutex<HashMap<u32, EdgeStreamState>>>,
    client_token: CancellationToken,
    active_streams: Arc<AtomicU32>,
) {
    let client_ip = client_addr.ip().to_string();
    let client_port = client_addr.port();

    // Determine edge IP (use 0.0.0.0 as placeholder — hub doesn't use it for routing)
    let edge_ip = "0.0.0.0";

    // Send OPEN frame with PROXY v1 header via control channel
    let proxy_header = build_proxy_v1_header(&client_ip, edge_ip, client_port, dest_port);
    let open_frame = encode_frame(stream_id, FRAME_OPEN, proxy_header.as_bytes());
    if tunnel_ctrl_tx.send(open_frame).await.is_err() {
        return;
    }

    // Set up channel for data coming back from hub (capacity 256; with flow
    // control the hub should rarely have that many frames in flight)
    let (back_tx, mut back_rx) = mpsc::channel::<Vec<u8>>(256);
    let send_window = Arc::new(AtomicU32::new(INITIAL_STREAM_WINDOW));
    let window_notify = Arc::new(Notify::new());
    {
        // Register the stream so the tunnel frame loop can route to us.
        let mut writers = client_writers.lock().await;
        writers.insert(stream_id, EdgeStreamState {
            back_tx,
            send_window: Arc::clone(&send_window),
            window_notify: Arc::clone(&window_notify),
        });
    }

    let (mut client_read, mut client_write) = client_stream.into_split();

    // Task: hub -> client (download direction)
    // After writing to client TCP, send WINDOW_UPDATE to hub so it can send more
    let hub_to_client_token = client_token.clone();
    let wu_tx = tunnel_ctrl_tx.clone();
    let active_streams_h2c = Arc::clone(&active_streams);
    let mut hub_to_client = tokio::spawn(async move {
        let mut consumed_since_update: u32 = 0;
        loop {
            tokio::select! {
                data = back_rx.recv() => {
                    match data {
                        Some(data) => {
                            let len = data.len() as u32;
                            if client_write.write_all(&data).await.is_err() {
                                break;
                            }
                            // Track consumption for adaptive flow control.
                            // The increment is capped to the adaptive window so the sender's
                            // effective window shrinks to match current demand (fewer streams
                            // = larger window, more streams = smaller window per stream).
                            consumed_since_update += len;
                            let adaptive_window = remoteingress_protocol::compute_window_for_stream_count(
                                active_streams_h2c.load(Ordering::Relaxed),
                            );
                            let threshold = adaptive_window / 2;
                            if consumed_since_update >= threshold {
                                let increment = consumed_since_update.min(adaptive_window);
                                let frame = encode_window_update(stream_id, FRAME_WINDOW_UPDATE, increment);
                                if wu_tx.try_send(frame).is_ok() {
                                    consumed_since_update -= increment;
                                }
                                // If try_send fails, keep accumulating — retry on next threshold
                            }
                        }
                        None => break,
                    }
                }
                _ = hub_to_client_token.cancelled() => break,
            }
        }
        // Send final window update for any remaining consumed bytes
        if consumed_since_update > 0 {
            let frame = encode_window_update(stream_id, FRAME_WINDOW_UPDATE, consumed_since_update);
            let _ = wu_tx.try_send(frame);
        }
        // Half-close toward the client so it sees EOF on the download side.
        let _ = client_write.shutdown().await;
    });

    // Task: client -> hub (upload direction) with per-stream flow control
    let mut buf = vec![0u8; 32768];
    loop {
        // Wait for send window to have capacity (with stall timeout)
        loop {
            let w = send_window.load(Ordering::Acquire);
            if w > 0 { break; }
            tokio::select! {
                _ = window_notify.notified() => continue,
                _ = client_token.cancelled() => break,
                _ = tokio::time::sleep(Duration::from_secs(120)) => {
                    log::warn!("Stream {} upload stalled (window empty for 120s)", stream_id);
                    break;
                }
            }
        }
        if client_token.is_cancelled() { break; }

        // Limit read size to available window.
        // IMPORTANT: if window is 0 (stall timeout fired), we must NOT
        // read into an empty buffer — read(&mut buf[..0]) returns Ok(0)
        // which would be falsely interpreted as EOF.
        let w = send_window.load(Ordering::Acquire) as usize;
        if w == 0 {
            log::warn!("Stream {} upload: window still 0 after stall timeout, closing", stream_id);
            break;
        }
        // Adaptive: cap read to current per-stream target window
        let adaptive_cap = remoteingress_protocol::compute_window_for_stream_count(
            active_streams.load(Ordering::Relaxed),
        ) as usize;
        let max_read = w.min(buf.len()).min(adaptive_cap);

        tokio::select! {
            read_result = client_read.read(&mut buf[..max_read]) => {
                match read_result {
                    Ok(0) => break,
                    Ok(n) => {
                        // Consume window before forwarding; replenished by
                        // FRAME_WINDOW_UPDATE_BACK in the tunnel frame loop.
                        send_window.fetch_sub(n as u32, Ordering::Release);
                        let data_frame = encode_frame(stream_id, FRAME_DATA, &buf[..n]);
                        if tunnel_data_tx.send(data_frame).await.is_err() {
                            log::warn!("Stream {} data channel closed, closing", stream_id);
                            break;
                        }
                    }
                    Err(_) => break,
                }
            }
            _ = client_token.cancelled() => break,
        }
    }

    // Wait for the download task (hub → client) to finish BEFORE sending CLOSE.
    // Upload EOF (client done sending) does NOT mean the response is done.
    // For asymmetric transfers like git fetch (small request, large response),
    // the response is still streaming when the upload finishes.
    // Sending CLOSE before the response finishes would cause the hub to cancel
    // the upstream reader mid-response, truncating the data.
    let _ = tokio::time::timeout(
        Duration::from_secs(300), // 5 min max wait for download to finish
        &mut hub_to_client,
    ).await;

    // NOW send CLOSE — the response has been fully delivered (or timed out).
    if !client_token.is_cancelled() {
        let close_frame = encode_frame(stream_id, FRAME_CLOSE, &[]);
        let _ = tunnel_data_tx.send(close_frame).await;
    }

    // Clean up
    {
        let mut writers = client_writers.lock().await;
        writers.remove(&stream_id);
    }
    hub_to_client.abort(); // No-op if already finished; safety net if timeout fired
    let _ = edge_id; // used for logging context
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
// --- Serde tests ---
|
|
|
|
#[test]
fn test_edge_config_deserialize_camel_case() {
    // camelCase JSON keys must map onto the snake_case struct fields.
    let json = r#"{
        "hubHost": "hub.example.com",
        "hubPort": 8443,
        "edgeId": "edge-1",
        "secret": "my-secret"
    }"#;
    let parsed: EdgeConfig = serde_json::from_str(json).unwrap();
    assert_eq!(parsed.secret, "my-secret");
    assert_eq!(parsed.edge_id, "edge-1");
    assert_eq!(parsed.hub_port, 8443);
    assert_eq!(parsed.hub_host, "hub.example.com");
}
|
|
|
|
#[test]
fn test_edge_config_serialize_roundtrip() {
    // Serialize → deserialize must preserve every field.
    let original = EdgeConfig {
        hub_host: "host.test".to_string(),
        hub_port: 9999,
        edge_id: "e1".to_string(),
        secret: "sec".to_string(),
        bind_address: None,
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: EdgeConfig = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded.hub_host, original.hub_host);
    assert_eq!(decoded.hub_port, original.hub_port);
    assert_eq!(decoded.edge_id, original.edge_id);
    assert_eq!(decoded.secret, original.secret);
}
|
|
|
|
#[test]
fn test_handshake_config_deserialize_all_fields() {
    // When the hub supplies every field, nothing falls back to a default.
    let payload = r#"{"listenPorts": [80, 443], "stunIntervalSecs": 120}"#;
    let parsed: HandshakeConfig = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.stun_interval_secs, 120);
    assert_eq!(parsed.listen_ports, vec![80, 443]);
}
|
|
|
|
#[test]
fn test_handshake_config_default_stun_interval() {
    // Omitting stunIntervalSecs must fall back to the 300s serde default.
    let payload = r#"{"listenPorts": [443]}"#;
    let parsed: HandshakeConfig = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.stun_interval_secs, 300);
    assert_eq!(parsed.listen_ports, vec![443]);
}
|
|
|
|
#[test]
fn test_config_update_deserialize() {
    // FRAME_CONFIG payloads carry only a camelCase port list.
    let payload = r#"{"listenPorts": [8080, 9090]}"#;
    let parsed: ConfigUpdate = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.listen_ports, vec![8080, 9090]);
}
|
|
|
|
#[test]
fn test_edge_status_serialize() {
    // A fully-populated status must serialize with camelCase keys.
    let status = EdgeStatus {
        listen_ports: vec![443],
        active_streams: 5,
        public_ip: Some("1.2.3.4".to_string()),
        connected: true,
        running: true,
    };
    let value = serde_json::to_value(&status).unwrap();
    assert_eq!(value["publicIp"], "1.2.3.4");
    assert_eq!(value["activeStreams"], 5);
    assert_eq!(value["listenPorts"], serde_json::json!([443]));
    assert_eq!(value["running"], true);
    assert_eq!(value["connected"], true);
}
|
|
|
|
#[test]
fn test_edge_status_serialize_none_ip() {
    // A missing public IP must serialize as JSON null, not be omitted or panic.
    let empty_status = EdgeStatus {
        running: false,
        connected: false,
        public_ip: None,
        active_streams: 0,
        listen_ports: vec![],
    };
    let value = serde_json::to_value(&empty_status).unwrap();
    assert!(value["publicIp"].is_null());
}
|
|
|
|
#[test]
fn test_edge_event_tunnel_connected() {
    // The serde tag attribute puts the camelCase variant name under "type".
    let serialized = serde_json::to_value(&EdgeEvent::TunnelConnected).unwrap();
    assert_eq!(serialized["type"], "tunnelConnected");
}
|
|
|
|
#[test]
fn test_edge_event_tunnel_disconnected() {
    // Disconnect events carry only the tag, no payload fields.
    let serialized = serde_json::to_value(&EdgeEvent::TunnelDisconnected).unwrap();
    assert_eq!(serialized["type"], "tunnelDisconnected");
}
|
|
|
|
#[test]
fn test_edge_event_public_ip_discovered() {
    // Struct variants flatten their fields next to the "type" tag.
    let ip = "203.0.113.1".to_string();
    let value = serde_json::to_value(&EdgeEvent::PublicIpDiscovered { ip }).unwrap();
    assert_eq!(value["type"], "publicIpDiscovered");
    assert_eq!(value["ip"], "203.0.113.1");
}
|
|
|
|
#[test]
fn test_edge_event_ports_assigned() {
    // Port lists must serialize under the camelCase "listenPorts" key.
    let event = EdgeEvent::PortsAssigned {
        listen_ports: vec![443, 8080],
    };
    let value = serde_json::to_value(&event).unwrap();
    assert_eq!(value["type"], "portsAssigned");
    assert_eq!(value["listenPorts"], serde_json::json!([443, 8080]));
}
|
|
|
|
#[test]
fn test_edge_event_ports_updated() {
    // Runtime port updates carry the new list under "listenPorts".
    let event = EdgeEvent::PortsUpdated {
        listen_ports: vec![9090],
    };
    let value = serde_json::to_value(&event).unwrap();
    assert_eq!(value["type"], "portsUpdated");
    assert_eq!(value["listenPorts"], serde_json::json!([9090]));
}
|
|
|
|
// --- Async tests ---
|
|
|
|
#[tokio::test]
|
|
async fn test_tunnel_edge_new_get_status() {
|
|
let edge = TunnelEdge::new(EdgeConfig {
|
|
hub_host: "localhost".to_string(),
|
|
hub_port: 8443,
|
|
edge_id: "test-edge".to_string(),
|
|
secret: "test-secret".to_string(),
|
|
bind_address: None,
|
|
});
|
|
let status = edge.get_status().await;
|
|
assert!(!status.running);
|
|
assert!(!status.connected);
|
|
assert!(status.public_ip.is_none());
|
|
assert_eq!(status.active_streams, 0);
|
|
assert!(status.listen_ports.is_empty());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_tunnel_edge_take_event_rx() {
|
|
let edge = TunnelEdge::new(EdgeConfig {
|
|
hub_host: "localhost".to_string(),
|
|
hub_port: 8443,
|
|
edge_id: "e".to_string(),
|
|
secret: "s".to_string(),
|
|
bind_address: None,
|
|
});
|
|
let rx1 = edge.take_event_rx().await;
|
|
assert!(rx1.is_some());
|
|
let rx2 = edge.take_event_rx().await;
|
|
assert!(rx2.is_none());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_tunnel_edge_stop_without_start() {
|
|
let edge = TunnelEdge::new(EdgeConfig {
|
|
hub_host: "localhost".to_string(),
|
|
hub_port: 8443,
|
|
edge_id: "e".to_string(),
|
|
secret: "s".to_string(),
|
|
bind_address: None,
|
|
});
|
|
edge.stop().await; // should not panic
|
|
let status = edge.get_status().await;
|
|
assert!(!status.running);
|
|
}
|
|
}
|
|
|
|
/// TLS certificate verifier that accepts any certificate (auth is via shared secret).
///
/// SECURITY: this deliberately disables PKI validation of the hub's TLS
/// certificate — see the `ServerCertVerifier` impl below, which approves every
/// certificate and signature unconditionally. TLS still provides channel
/// encryption; peer authentication relies entirely on the shared-secret
/// handshake performed after the connection is established.
#[derive(Debug)]
struct NoCertVerifier;
|
|
|
|
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
    /// Approve every server certificate without inspecting it.
    ///
    /// SECURITY: no chain building, hostname check, expiry check, or OCSP
    /// validation is performed. This is intentional — the edge authenticates
    /// the hub via the shared secret, not via PKI.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accept any TLS 1.2 handshake signature without verifying it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accept any TLS 1.3 handshake signature without verifying it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Advertise the full set of common signature schemes so the server can
    /// complete the handshake with whatever certificate type it has; the
    /// schemes are never actually checked (see the verify_* methods above).
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::ED448,
        ]
    }
}
|