use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use serde::{Deserialize, Serialize};
use remoteingress_protocol::*;

/// Hub configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HubConfig {
    /// TCP port the hub listens on for TLS tunnel connections from edges.
    pub tunnel_port: u16,
    /// Host demuxed streams are forwarded to; `None` falls back to 127.0.0.1.
    pub target_host: Option<String>,
    /// PEM-encoded TLS certificate chain (never serialized; a self-signed
    /// cert is generated at startup when absent).
    #[serde(skip)]
    pub tls_cert_pem: Option<String>,
    /// PEM-encoded TLS private key (never serialized; generated together
    /// with the self-signed cert when absent).
    #[serde(skip)]
    pub tls_key_pem: Option<String>,
}

impl Default for HubConfig {
    fn default() -> Self {
        Self {
            tunnel_port: 8443,
            target_host: Some("127.0.0.1".to_string()),
            tls_cert_pem: None,
            tls_key_pem: None,
        }
    }
}

/// An allowed edge identity.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AllowedEdge {
    pub id: String,
    /// Shared secret, compared in constant time during the edge handshake.
    pub secret: String,
    /// Ports the edge should listen on; empty when omitted from JSON.
    #[serde(default)]
    pub listen_ports: Vec<u16>,
    /// STUN keepalive interval; the hub substitutes 300s when `None`.
    pub stun_interval_secs: Option<u64>,
}

/// Handshake response sent to edge after authentication.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct HandshakeResponse {
    listen_ports: Vec<u16>,
    stun_interval_secs: u64,
}

/// Configuration update pushed to a connected edge at runtime.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EdgeConfigUpdate {
    pub listen_ports: Vec<u16>,
}

/// Runtime status of a connected edge.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ConnectedEdgeStatus {
    pub edge_id: String,
    /// Unix timestamp (seconds) of when the edge authenticated.
    pub connected_at: u64,
    pub active_streams: usize,
}

/// Events emitted by the hub.
#[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] #[serde(tag = "type")] pub enum HubEvent { #[serde(rename_all = "camelCase")] EdgeConnected { edge_id: String }, #[serde(rename_all = "camelCase")] EdgeDisconnected { edge_id: String }, #[serde(rename_all = "camelCase")] StreamOpened { edge_id: String, stream_id: u32 }, #[serde(rename_all = "camelCase")] StreamClosed { edge_id: String, stream_id: u32 }, } /// Hub status response. #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct HubStatus { pub running: bool, pub tunnel_port: u16, pub connected_edges: Vec, } /// The tunnel hub that accepts edge connections and demuxes streams to SmartProxy. pub struct TunnelHub { config: RwLock, allowed_edges: Arc>>, connected_edges: Arc>>, event_tx: mpsc::Sender, event_rx: Mutex>>, shutdown_tx: Mutex>>, running: RwLock, cancel_token: CancellationToken, } struct ConnectedEdgeInfo { connected_at: u64, active_streams: Arc>, CancellationToken)>>>, config_tx: mpsc::Sender, #[allow(dead_code)] // kept alive for Drop — cancels child tokens when edge is removed cancel_token: CancellationToken, } impl TunnelHub { pub fn new(config: HubConfig) -> Self { let (event_tx, event_rx) = mpsc::channel(1024); Self { config: RwLock::new(config), allowed_edges: Arc::new(RwLock::new(HashMap::new())), connected_edges: Arc::new(Mutex::new(HashMap::new())), event_tx, event_rx: Mutex::new(Some(event_rx)), shutdown_tx: Mutex::new(None), running: RwLock::new(false), cancel_token: CancellationToken::new(), } } /// Take the event receiver (can only be called once). pub async fn take_event_rx(&self) -> Option> { self.event_rx.lock().await.take() } /// Update the list of allowed edges. /// For any currently-connected edge whose ports changed, push a config update. 
pub async fn update_allowed_edges(&self, edges: Vec) { let mut map = self.allowed_edges.write().await; // Build new map let mut new_map = HashMap::new(); for edge in &edges { new_map.insert(edge.id.clone(), edge.clone()); } // Push config updates to connected edges whose ports changed let connected = self.connected_edges.lock().await; for edge in &edges { if let Some(info) = connected.get(&edge.id) { // Check if ports changed compared to old config let ports_changed = match map.get(&edge.id) { Some(old) => old.listen_ports != edge.listen_ports, None => true, // newly allowed edge that's already connected }; if ports_changed { let update = EdgeConfigUpdate { listen_ports: edge.listen_ports.clone(), }; let _ = info.config_tx.try_send(update); } } } *map = new_map; } /// Get the current hub status. pub async fn get_status(&self) -> HubStatus { let running = *self.running.read().await; let config = self.config.read().await; let edges = self.connected_edges.lock().await; let mut connected = Vec::new(); for (id, info) in edges.iter() { let streams = info.active_streams.lock().await; connected.push(ConnectedEdgeStatus { edge_id: id.clone(), connected_at: info.connected_at, active_streams: streams.len(), }); } HubStatus { running, tunnel_port: config.tunnel_port, connected_edges: connected, } } /// Start the hub — listen for TLS connections from edges. 
pub async fn start(&self) -> Result<(), Box> { let config = self.config.read().await.clone(); let tls_config = build_tls_config(&config)?; let acceptor = TlsAcceptor::from(Arc::new(tls_config)); let listener = TcpListener::bind(("0.0.0.0", config.tunnel_port)).await?; log::info!("Hub listening on port {}", config.tunnel_port); let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); *self.shutdown_tx.lock().await = Some(shutdown_tx); *self.running.write().await = true; let allowed = self.allowed_edges.clone(); let connected = self.connected_edges.clone(); let event_tx = self.event_tx.clone(); let target_host = config.target_host.unwrap_or_else(|| "127.0.0.1".to_string()); let hub_token = self.cancel_token.clone(); tokio::spawn(async move { loop { tokio::select! { result = listener.accept() => { match result { Ok((stream, addr)) => { log::info!("Edge connection from {}", addr); let acceptor = acceptor.clone(); let allowed = allowed.clone(); let connected = connected.clone(); let event_tx = event_tx.clone(); let target = target_host.clone(); let edge_token = hub_token.child_token(); tokio::spawn(async move { if let Err(e) = handle_edge_connection( stream, acceptor, allowed, connected, event_tx, target, edge_token, ).await { log::error!("Edge connection error: {}", e); } }); } Err(e) => { log::error!("Accept error: {}", e); } } } _ = hub_token.cancelled() => { log::info!("Hub shutting down (token cancelled)"); break; } _ = shutdown_rx.recv() => { log::info!("Hub shutting down"); break; } } } }); Ok(()) } /// Stop the hub. pub async fn stop(&self) { self.cancel_token.cancel(); if let Some(tx) = self.shutdown_tx.lock().await.take() { let _ = tx.send(()).await; } *self.running.write().await = false; // Clear connected edges self.connected_edges.lock().await.clear(); } } impl Drop for TunnelHub { fn drop(&mut self) { self.cancel_token.cancel(); } } /// Handle a single edge connection: authenticate, then enter frame loop. 
async fn handle_edge_connection( stream: TcpStream, acceptor: TlsAcceptor, allowed: Arc>>, connected: Arc>>, event_tx: mpsc::Sender, target_host: String, edge_token: CancellationToken, ) -> Result<(), Box> { let tls_stream = acceptor.accept(stream).await?; let (read_half, mut write_half) = tokio::io::split(tls_stream); let mut buf_reader = BufReader::new(read_half); // Read auth line: "EDGE \n" let mut auth_line = String::new(); buf_reader.read_line(&mut auth_line).await?; let auth_line = auth_line.trim(); let parts: Vec<&str> = auth_line.splitn(3, ' ').collect(); if parts.len() != 3 || parts[0] != "EDGE" { return Err("invalid auth line".into()); } let edge_id = parts[1].to_string(); let secret = parts[2]; // Verify credentials and extract edge config let (listen_ports, stun_interval_secs) = { let edges = allowed.read().await; match edges.get(&edge_id) { Some(edge) => { if !constant_time_eq(secret.as_bytes(), edge.secret.as_bytes()) { return Err(format!("invalid secret for edge {}", edge_id).into()); } (edge.listen_ports.clone(), edge.stun_interval_secs.unwrap_or(300)) } None => { return Err(format!("unknown edge {}", edge_id).into()); } } }; log::info!("Edge {} authenticated", edge_id); let _ = event_tx.try_send(HubEvent::EdgeConnected { edge_id: edge_id.clone(), }); // Send handshake response with initial config before frame protocol begins let handshake = HandshakeResponse { listen_ports: listen_ports.clone(), stun_interval_secs, }; let mut handshake_json = serde_json::to_string(&handshake)?; handshake_json.push('\n'); write_half.write_all(handshake_json.as_bytes()).await?; // Track this edge let streams: Arc>, CancellationToken)>>> = Arc::new(Mutex::new(HashMap::new())); let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(); // Create config update channel let (config_tx, mut config_rx) = mpsc::channel::(16); { let mut edges = connected.lock().await; edges.insert( edge_id.clone(), ConnectedEdgeInfo { 
connected_at: now, active_streams: streams.clone(), config_tx, cancel_token: edge_token.clone(), }, ); } // Shared writer for sending frames back to edge let write_half = Arc::new(Mutex::new(write_half)); // Spawn task to forward config updates as FRAME_CONFIG frames let config_writer = write_half.clone(); let config_edge_id = edge_id.clone(); let config_token = edge_token.clone(); let config_handle = tokio::spawn(async move { loop { tokio::select! { update = config_rx.recv() => { match update { Some(update) => { if let Ok(payload) = serde_json::to_vec(&update) { let frame = encode_frame(0, FRAME_CONFIG, &payload); let mut w = config_writer.lock().await; if w.write_all(&frame).await.is_err() { log::error!("Failed to send config update to edge {}", config_edge_id); break; } log::info!("Sent config update to edge {}: ports {:?}", config_edge_id, update.listen_ports); } } None => break, } } _ = config_token.cancelled() => break, } } }); // Frame reading loop let mut frame_reader = FrameReader::new(buf_reader); loop { tokio::select! 
{ frame_result = frame_reader.next_frame() => { match frame_result { Ok(Some(frame)) => { match frame.frame_type { FRAME_OPEN => { // Payload is PROXY v1 header line let proxy_header = String::from_utf8_lossy(&frame.payload).to_string(); // Parse destination port from PROXY header let dest_port = parse_dest_port_from_proxy(&proxy_header).unwrap_or(443); let stream_id = frame.stream_id; let edge_id_clone = edge_id.clone(); let event_tx_clone = event_tx.clone(); let streams_clone = streams.clone(); let writer_clone = write_half.clone(); let target = target_host.clone(); let stream_token = edge_token.child_token(); let _ = event_tx.try_send(HubEvent::StreamOpened { edge_id: edge_id.clone(), stream_id, }); // Create channel for data from edge to this stream let (data_tx, mut data_rx) = mpsc::channel::>(256); { let mut s = streams.lock().await; s.insert(stream_id, (data_tx, stream_token.clone())); } // Spawn task: connect to SmartProxy, send PROXY header, pipe data tokio::spawn(async move { let result = async { let mut upstream = TcpStream::connect((target.as_str(), dest_port)).await?; upstream.write_all(proxy_header.as_bytes()).await?; let (mut up_read, mut up_write) = upstream.into_split(); // Forward data from edge (via channel) to SmartProxy let writer_token = stream_token.clone(); let writer_for_edge_data = tokio::spawn(async move { loop { tokio::select! { data = data_rx.recv() => { match data { Some(data) => { if up_write.write_all(&data).await.is_err() { break; } } None => break, } } _ = writer_token.cancelled() => break, } } let _ = up_write.shutdown().await; }); // Forward data from SmartProxy back to edge let mut buf = vec![0u8; 32768]; loop { tokio::select! 
{ read_result = up_read.read(&mut buf) => { match read_result { Ok(0) => break, Ok(n) => { let frame = encode_frame(stream_id, FRAME_DATA_BACK, &buf[..n]); let mut w = writer_clone.lock().await; if w.write_all(&frame).await.is_err() { break; } } Err(_) => break, } } _ = stream_token.cancelled() => break, } } // Send CLOSE_BACK to edge (only if not cancelled) if !stream_token.is_cancelled() { let close_frame = encode_frame(stream_id, FRAME_CLOSE_BACK, &[]); let mut w = writer_clone.lock().await; let _ = w.write_all(&close_frame).await; } writer_for_edge_data.abort(); Ok::<(), Box>(()) } .await; if let Err(e) = result { log::error!("Stream {} error: {}", stream_id, e); // Send CLOSE_BACK on error (only if not cancelled) if !stream_token.is_cancelled() { let close_frame = encode_frame(stream_id, FRAME_CLOSE_BACK, &[]); let mut w = writer_clone.lock().await; let _ = w.write_all(&close_frame).await; } } // Clean up stream (guard against duplicate if FRAME_CLOSE already removed it) let was_present = { let mut s = streams_clone.lock().await; s.remove(&stream_id).is_some() }; if was_present { let _ = event_tx_clone.try_send(HubEvent::StreamClosed { edge_id: edge_id_clone, stream_id, }); } }); } FRAME_DATA => { let s = streams.lock().await; if let Some((tx, _)) = s.get(&frame.stream_id) { let _ = tx.send(frame.payload).await; } } FRAME_CLOSE => { let mut s = streams.lock().await; if let Some((_, token)) = s.remove(&frame.stream_id) { token.cancel(); let _ = event_tx.try_send(HubEvent::StreamClosed { edge_id: edge_id.clone(), stream_id: frame.stream_id, }); } } _ => { log::warn!("Unexpected frame type {} from edge", frame.frame_type); } } } Ok(None) => { log::info!("Edge {} disconnected (EOF)", edge_id); break; } Err(e) => { log::error!("Edge {} frame error: {}", edge_id, e); break; } } } _ = edge_token.cancelled() => { log::info!("Edge {} cancelled by hub", edge_id); break; } } } // Cleanup: cancel edge token to propagate to all child tasks edge_token.cancel(); 
config_handle.abort(); { let mut edges = connected.lock().await; edges.remove(&edge_id); } let _ = event_tx.try_send(HubEvent::EdgeDisconnected { edge_id: edge_id.clone(), }); Ok(()) } /// Parse destination port from PROXY v1 header. fn parse_dest_port_from_proxy(header: &str) -> Option { let parts: Vec<&str> = header.trim().split_whitespace().collect(); if parts.len() >= 6 { parts[5].parse().ok() } else { None } } /// Build TLS server config from PEM strings, or auto-generate self-signed. fn build_tls_config( config: &HubConfig, ) -> Result> { let (cert_pem, key_pem) = match (&config.tls_cert_pem, &config.tls_key_pem) { (Some(cert), Some(key)) => (cert.clone(), key.clone()), _ => { // Generate self-signed certificate let cert = rcgen::generate_simple_self_signed(vec!["remoteingress-hub".to_string()])?; let cert_pem = cert.cert.pem(); let key_pem = cert.key_pair.serialize_pem(); (cert_pem, key_pem) } }; let certs = rustls_pemfile_parse_certs(&cert_pem)?; let key = rustls_pemfile_parse_key(&key_pem)?; let mut config = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key)?; config.alpn_protocols = vec![b"remoteingress".to_vec()]; Ok(config) } fn rustls_pemfile_parse_certs( pem: &str, ) -> Result>, Box> { let mut reader = std::io::Cursor::new(pem.as_bytes()); let certs = rustls_pemfile::certs(&mut reader).collect::, _>>()?; Ok(certs) } fn rustls_pemfile_parse_key( pem: &str, ) -> Result, Box> { let mut reader = std::io::Cursor::new(pem.as_bytes()); let key = rustls_pemfile::private_key(&mut reader)? .ok_or("no private key found in PEM")?; Ok(key) } /// Constant-time comparison of two byte slices. 
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { if a.len() != b.len() { return false; } let mut diff = 0u8; for (x, y) in a.iter().zip(b.iter()) { diff |= x ^ y; } diff == 0 } #[cfg(test)] mod tests { use super::*; // --- constant_time_eq tests --- #[test] fn test_constant_time_eq_equal() { assert!(constant_time_eq(b"hello", b"hello")); } #[test] fn test_constant_time_eq_different_content() { assert!(!constant_time_eq(b"hello", b"world")); } #[test] fn test_constant_time_eq_different_lengths() { assert!(!constant_time_eq(b"short", b"longer")); } #[test] fn test_constant_time_eq_both_empty() { assert!(constant_time_eq(b"", b"")); } #[test] fn test_constant_time_eq_one_empty() { assert!(!constant_time_eq(b"", b"notempty")); } #[test] fn test_constant_time_eq_single_bit_difference() { // 'A' = 0x41, 'a' = 0x61 — differ by one bit assert!(!constant_time_eq(b"A", b"a")); } // --- parse_dest_port_from_proxy tests --- #[test] fn test_parse_dest_port_443() { let header = "PROXY TCP4 1.2.3.4 5.6.7.8 12345 443\r\n"; assert_eq!(parse_dest_port_from_proxy(header), Some(443)); } #[test] fn test_parse_dest_port_80() { let header = "PROXY TCP4 10.0.0.1 10.0.0.2 54321 80\r\n"; assert_eq!(parse_dest_port_from_proxy(header), Some(80)); } #[test] fn test_parse_dest_port_65535() { let header = "PROXY TCP4 10.0.0.1 10.0.0.2 1 65535\r\n"; assert_eq!(parse_dest_port_from_proxy(header), Some(65535)); } #[test] fn test_parse_dest_port_too_few_fields() { let header = "PROXY TCP4 1.2.3.4"; assert_eq!(parse_dest_port_from_proxy(header), None); } #[test] fn test_parse_dest_port_empty_string() { assert_eq!(parse_dest_port_from_proxy(""), None); } #[test] fn test_parse_dest_port_non_numeric() { let header = "PROXY TCP4 1.2.3.4 5.6.7.8 12345 abc\r\n"; assert_eq!(parse_dest_port_from_proxy(header), None); } // --- Serde tests --- #[test] fn test_allowed_edge_deserialize_all_fields() { let json = r#"{ "id": "edge-1", "secret": "s3cret", "listenPorts": [443, 8080], "stunIntervalSecs": 120 }"#; let 
edge: AllowedEdge = serde_json::from_str(json).unwrap(); assert_eq!(edge.id, "edge-1"); assert_eq!(edge.secret, "s3cret"); assert_eq!(edge.listen_ports, vec![443, 8080]); assert_eq!(edge.stun_interval_secs, Some(120)); } #[test] fn test_allowed_edge_deserialize_with_defaults() { let json = r#"{"id": "edge-2", "secret": "key"}"#; let edge: AllowedEdge = serde_json::from_str(json).unwrap(); assert_eq!(edge.id, "edge-2"); assert_eq!(edge.secret, "key"); assert!(edge.listen_ports.is_empty()); assert_eq!(edge.stun_interval_secs, None); } #[test] fn test_handshake_response_serializes_camel_case() { let resp = HandshakeResponse { listen_ports: vec![443, 8080], stun_interval_secs: 300, }; let json = serde_json::to_value(&resp).unwrap(); assert_eq!(json["listenPorts"], serde_json::json!([443, 8080])); assert_eq!(json["stunIntervalSecs"], 300); // Ensure snake_case keys are NOT present assert!(json.get("listen_ports").is_none()); assert!(json.get("stun_interval_secs").is_none()); } #[test] fn test_edge_config_update_serializes_camel_case() { let update = EdgeConfigUpdate { listen_ports: vec![80, 443], }; let json = serde_json::to_value(&update).unwrap(); assert_eq!(json["listenPorts"], serde_json::json!([80, 443])); assert!(json.get("listen_ports").is_none()); } #[test] fn test_hub_config_default() { let config = HubConfig::default(); assert_eq!(config.tunnel_port, 8443); assert_eq!(config.target_host, Some("127.0.0.1".to_string())); assert!(config.tls_cert_pem.is_none()); assert!(config.tls_key_pem.is_none()); } #[test] fn test_hub_event_edge_connected_serialize() { let event = HubEvent::EdgeConnected { edge_id: "edge-1".to_string(), }; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], "edgeConnected"); assert_eq!(json["edgeId"], "edge-1"); } #[test] fn test_hub_event_edge_disconnected_serialize() { let event = HubEvent::EdgeDisconnected { edge_id: "edge-2".to_string(), }; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], 
"edgeDisconnected"); assert_eq!(json["edgeId"], "edge-2"); } #[test] fn test_hub_event_stream_opened_serialize() { let event = HubEvent::StreamOpened { edge_id: "e".to_string(), stream_id: 42, }; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], "streamOpened"); assert_eq!(json["edgeId"], "e"); assert_eq!(json["streamId"], 42); } #[test] fn test_hub_event_stream_closed_serialize() { let event = HubEvent::StreamClosed { edge_id: "e".to_string(), stream_id: 7, }; let json = serde_json::to_value(&event).unwrap(); assert_eq!(json["type"], "streamClosed"); assert_eq!(json["edgeId"], "e"); assert_eq!(json["streamId"], 7); } // --- Async tests --- #[tokio::test] async fn test_tunnel_hub_new_get_status() { let hub = TunnelHub::new(HubConfig::default()); let status = hub.get_status().await; assert!(!status.running); assert!(status.connected_edges.is_empty()); assert_eq!(status.tunnel_port, 8443); } #[tokio::test] async fn test_tunnel_hub_take_event_rx() { let hub = TunnelHub::new(HubConfig::default()); let rx1 = hub.take_event_rx().await; assert!(rx1.is_some()); let rx2 = hub.take_event_rx().await; assert!(rx2.is_none()); } #[tokio::test] async fn test_tunnel_hub_stop_without_start() { let hub = TunnelHub::new(HubConfig::default()); hub.stop().await; // should not panic let status = hub.get_status().await; assert!(!status.running); } }