use std::collections::HashMap; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{mpsc, Mutex, RwLock}; use tokio_rustls::TlsConnector; use serde::{Deserialize, Serialize}; use remoteingress_protocol::*; /// Edge configuration. #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct EdgeConfig { pub hub_host: String, pub hub_port: u16, pub edge_id: String, pub secret: String, pub listen_ports: Vec, pub stun_interval_secs: Option, } /// Events emitted by the edge. #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] #[serde(tag = "type")] pub enum EdgeEvent { TunnelConnected, TunnelDisconnected, #[serde(rename_all = "camelCase")] PublicIpDiscovered { ip: String }, } /// Edge status response. #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct EdgeStatus { pub running: bool, pub connected: bool, pub public_ip: Option, pub active_streams: usize, pub listen_ports: Vec, } /// The tunnel edge that listens for client connections and multiplexes them to the hub. pub struct TunnelEdge { config: RwLock, event_tx: mpsc::UnboundedSender, event_rx: Mutex>>, shutdown_tx: Mutex>>, running: RwLock, connected: Arc>, public_ip: Arc>>, active_streams: Arc, next_stream_id: Arc, } impl TunnelEdge { pub fn new(config: EdgeConfig) -> Self { let (event_tx, event_rx) = mpsc::unbounded_channel(); Self { config: RwLock::new(config), event_tx, event_rx: Mutex::new(Some(event_rx)), shutdown_tx: Mutex::new(None), running: RwLock::new(false), connected: Arc::new(RwLock::new(false)), public_ip: Arc::new(RwLock::new(None)), active_streams: Arc::new(AtomicU32::new(0)), next_stream_id: Arc::new(AtomicU32::new(1)), } } /// Take the event receiver (can only be called once). pub async fn take_event_rx(&self) -> Option> { self.event_rx.lock().await.take() } /// Get the current edge status. pub async fn get_status(&self) -> EdgeStatus { EdgeStatus { running: *self.running.read().await, connected: *self.connected.read().await, public_ip: self.public_ip.read().await.clone(), active_streams: self.active_streams.load(Ordering::Relaxed) as usize, listen_ports: self.config.read().await.listen_ports.clone(), } } /// Start the edge: connect to hub, start listeners. pub async fn start(&self) -> Result<(), Box> { let config = self.config.read().await.clone(); let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1); *self.shutdown_tx.lock().await = Some(shutdown_tx); *self.running.write().await = true; let connected = self.connected.clone(); let public_ip = self.public_ip.clone(); let active_streams = self.active_streams.clone(); let next_stream_id = self.next_stream_id.clone(); let event_tx = self.event_tx.clone(); tokio::spawn(async move { edge_main_loop( config, connected, public_ip, active_streams, next_stream_id, event_tx, shutdown_rx, ) .await; }); Ok(()) } /// Stop the edge. pub async fn stop(&self) { if let Some(tx) = self.shutdown_tx.lock().await.take() { let _ = tx.send(()).await; } *self.running.write().await = false; *self.connected.write().await = false; } } async fn edge_main_loop( config: EdgeConfig, connected: Arc>, public_ip: Arc>>, active_streams: Arc, next_stream_id: Arc, event_tx: mpsc::UnboundedSender, mut shutdown_rx: mpsc::Receiver<()>, ) { let mut backoff_ms: u64 = 1000; let max_backoff_ms: u64 = 30000; loop { // Try to connect to hub let result = connect_to_hub_and_run( &config, &connected, &public_ip, &active_streams, &next_stream_id, &event_tx, &mut shutdown_rx, ) .await; *connected.write().await = false; let _ = event_tx.send(EdgeEvent::TunnelDisconnected); active_streams.store(0, Ordering::Relaxed); match result { EdgeLoopResult::Shutdown => break, EdgeLoopResult::Reconnect => { log::info!("Reconnecting in {}ms...", backoff_ms); tokio::select! { _ = tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)) => {} _ = shutdown_rx.recv() => break, } backoff_ms = (backoff_ms * 2).min(max_backoff_ms); } } } } enum EdgeLoopResult { Shutdown, Reconnect, } async fn connect_to_hub_and_run( config: &EdgeConfig, connected: &Arc>, public_ip: &Arc>>, active_streams: &Arc, next_stream_id: &Arc, event_tx: &mpsc::UnboundedSender, shutdown_rx: &mut mpsc::Receiver<()>, ) -> EdgeLoopResult { // Build TLS connector that skips cert verification (auth is via secret) let tls_config = rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(NoCertVerifier)) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(tls_config)); let addr = format!("{}:{}", config.hub_host, config.hub_port); let tcp = match TcpStream::connect(&addr).await { Ok(s) => s, Err(e) => { log::error!("Failed to connect to hub at {}: {}", addr, e); return EdgeLoopResult::Reconnect; } }; let server_name = rustls::pki_types::ServerName::try_from(config.hub_host.clone()) .unwrap_or_else(|_| rustls::pki_types::ServerName::try_from("remoteingress-hub".to_string()).unwrap()); let tls_stream = match connector.connect(server_name, tcp).await { Ok(s) => s, Err(e) => { log::error!("TLS handshake failed: {}", e); return EdgeLoopResult::Reconnect; } }; let (read_half, mut write_half) = tokio::io::split(tls_stream); // Send auth line let auth_line = format!("EDGE {} {}\n", config.edge_id, config.secret); if write_half.write_all(auth_line.as_bytes()).await.is_err() { return EdgeLoopResult::Reconnect; } *connected.write().await = true; let _ = event_tx.send(EdgeEvent::TunnelConnected); log::info!("Connected to hub at {}", addr); // Start STUN discovery let stun_interval = config.stun_interval_secs.unwrap_or(300); let public_ip_clone = public_ip.clone(); let event_tx_clone = event_tx.clone(); let stun_handle = tokio::spawn(async move { loop { if let Some(ip) = crate::stun::discover_public_ip().await { let mut pip = public_ip_clone.write().await; let changed = pip.as_ref() != Some(&ip); *pip = Some(ip.clone()); if changed { let _ = event_tx_clone.send(EdgeEvent::PublicIpDiscovered { ip }); } } tokio::time::sleep(std::time::Duration::from_secs(stun_interval)).await; } }); // Client socket map: stream_id -> sender for writing data back to client let client_writers: Arc>>>> = Arc::new(Mutex::new(HashMap::new())); // Shared tunnel writer let tunnel_writer = Arc::new(Mutex::new(write_half)); // Start TCP listeners for each port let mut listener_handles = Vec::new(); for &port in &config.listen_ports { let tunnel_writer = tunnel_writer.clone(); let client_writers = client_writers.clone(); let active_streams = active_streams.clone(); let next_stream_id = next_stream_id.clone(); let edge_id = config.edge_id.clone(); let handle = tokio::spawn(async move { let listener = match TcpListener::bind(("0.0.0.0", port)).await { Ok(l) => l, Err(e) => { log::error!("Failed to bind port {}: {}", port, e); return; } }; log::info!("Listening on port {}", port); loop { match listener.accept().await { Ok((client_stream, client_addr)) => { let stream_id = next_stream_id.fetch_add(1, Ordering::Relaxed); let tunnel_writer = tunnel_writer.clone(); let client_writers = client_writers.clone(); let active_streams = active_streams.clone(); let edge_id = edge_id.clone(); active_streams.fetch_add(1, Ordering::Relaxed); tokio::spawn(async move { handle_client_connection( client_stream, client_addr, stream_id, port, &edge_id, tunnel_writer, client_writers, ) .await; active_streams.fetch_sub(1, Ordering::Relaxed); }); } Err(e) => { log::error!("Accept error on port {}: {}", port, e); } } } }); listener_handles.push(handle); } // Read frames from hub let mut frame_reader = FrameReader::new(read_half); let result = loop { tokio::select! { frame_result = frame_reader.next_frame() => { match frame_result { Ok(Some(frame)) => { match frame.frame_type { FRAME_DATA_BACK => { let writers = client_writers.lock().await; if let Some(tx) = writers.get(&frame.stream_id) { let _ = tx.send(frame.payload).await; } } FRAME_CLOSE_BACK => { let mut writers = client_writers.lock().await; writers.remove(&frame.stream_id); } _ => { log::warn!("Unexpected frame type {} from hub", frame.frame_type); } } } Ok(None) => { log::info!("Hub disconnected (EOF)"); break EdgeLoopResult::Reconnect; } Err(e) => { log::error!("Hub frame error: {}", e); break EdgeLoopResult::Reconnect; } } } _ = shutdown_rx.recv() => { break EdgeLoopResult::Shutdown; } } }; // Cleanup stun_handle.abort(); for h in listener_handles { h.abort(); } result } async fn handle_client_connection( client_stream: TcpStream, client_addr: std::net::SocketAddr, stream_id: u32, dest_port: u16, edge_id: &str, tunnel_writer: Arc>>>, client_writers: Arc>>>>, ) { let client_ip = client_addr.ip().to_string(); let client_port = client_addr.port(); // Determine edge IP (use 0.0.0.0 as placeholder — hub doesn't use it for routing) let edge_ip = "0.0.0.0"; // Send OPEN frame with PROXY v1 header let proxy_header = build_proxy_v1_header(&client_ip, edge_ip, client_port, dest_port); let open_frame = encode_frame(stream_id, FRAME_OPEN, proxy_header.as_bytes()); { let mut w = tunnel_writer.lock().await; if w.write_all(&open_frame).await.is_err() { return; } } // Set up channel for data coming back from hub let (back_tx, mut back_rx) = mpsc::channel::>(256); { let mut writers = client_writers.lock().await; writers.insert(stream_id, back_tx); } let (mut client_read, mut client_write) = client_stream.into_split(); // Task: hub -> client let hub_to_client = tokio::spawn(async move { while let Some(data) = back_rx.recv().await { if client_write.write_all(&data).await.is_err() { break; } } let _ = client_write.shutdown().await; }); // Task: client -> hub let mut buf = vec![0u8; 32768]; loop { match client_read.read(&mut buf).await { Ok(0) => break, Ok(n) => { let data_frame = encode_frame(stream_id, FRAME_DATA, &buf[..n]); let mut w = tunnel_writer.lock().await; if w.write_all(&data_frame).await.is_err() { break; } } Err(_) => break, } } // Send CLOSE frame let close_frame = encode_frame(stream_id, FRAME_CLOSE, &[]); { let mut w = tunnel_writer.lock().await; let _ = w.write_all(&close_frame).await; } // Cleanup { let mut writers = client_writers.lock().await; writers.remove(&stream_id); } hub_to_client.abort(); let _ = edge_id; // used for logging context } /// TLS certificate verifier that accepts any certificate (auth is via shared secret). #[derive(Debug)] struct NoCertVerifier; impl rustls::client::danger::ServerCertVerifier for NoCertVerifier { fn verify_server_cert( &self, _end_entity: &rustls::pki_types::CertificateDer<'_>, _intermediates: &[rustls::pki_types::CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &rustls::pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::ECDSA_NISTP521_SHA512, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::ED448, ] } }