479 lines
16 KiB
Rust
479 lines
16 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::atomic::{AtomicU32, Ordering};
|
|
use std::sync::Arc;
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::sync::{mpsc, Mutex, RwLock};
|
|
use tokio_rustls::TlsConnector;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use remoteingress_protocol::*;
|
|
|
|
/// Edge configuration.
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct EdgeConfig {
|
|
pub hub_host: String,
|
|
pub hub_port: u16,
|
|
pub edge_id: String,
|
|
pub secret: String,
|
|
pub listen_ports: Vec<u16>,
|
|
pub stun_interval_secs: Option<u64>,
|
|
}
|
|
|
|
/// Events emitted by the edge.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[serde(tag = "type")]
|
|
pub enum EdgeEvent {
|
|
TunnelConnected,
|
|
TunnelDisconnected,
|
|
#[serde(rename_all = "camelCase")]
|
|
PublicIpDiscovered { ip: String },
|
|
}
|
|
|
|
/// Edge status response.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct EdgeStatus {
|
|
pub running: bool,
|
|
pub connected: bool,
|
|
pub public_ip: Option<String>,
|
|
pub active_streams: usize,
|
|
pub listen_ports: Vec<u16>,
|
|
}
|
|
|
|
/// The tunnel edge that listens for client connections and multiplexes them to the hub.
|
|
pub struct TunnelEdge {
|
|
config: RwLock<EdgeConfig>,
|
|
event_tx: mpsc::UnboundedSender<EdgeEvent>,
|
|
event_rx: Mutex<Option<mpsc::UnboundedReceiver<EdgeEvent>>>,
|
|
shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
|
|
running: RwLock<bool>,
|
|
connected: Arc<RwLock<bool>>,
|
|
public_ip: Arc<RwLock<Option<String>>>,
|
|
active_streams: Arc<AtomicU32>,
|
|
next_stream_id: Arc<AtomicU32>,
|
|
}
|
|
|
|
impl TunnelEdge {
|
|
pub fn new(config: EdgeConfig) -> Self {
|
|
let (event_tx, event_rx) = mpsc::unbounded_channel();
|
|
Self {
|
|
config: RwLock::new(config),
|
|
event_tx,
|
|
event_rx: Mutex::new(Some(event_rx)),
|
|
shutdown_tx: Mutex::new(None),
|
|
running: RwLock::new(false),
|
|
connected: Arc::new(RwLock::new(false)),
|
|
public_ip: Arc::new(RwLock::new(None)),
|
|
active_streams: Arc::new(AtomicU32::new(0)),
|
|
next_stream_id: Arc::new(AtomicU32::new(1)),
|
|
}
|
|
}
|
|
|
|
/// Take the event receiver (can only be called once).
|
|
pub async fn take_event_rx(&self) -> Option<mpsc::UnboundedReceiver<EdgeEvent>> {
|
|
self.event_rx.lock().await.take()
|
|
}
|
|
|
|
/// Get the current edge status.
|
|
pub async fn get_status(&self) -> EdgeStatus {
|
|
EdgeStatus {
|
|
running: *self.running.read().await,
|
|
connected: *self.connected.read().await,
|
|
public_ip: self.public_ip.read().await.clone(),
|
|
active_streams: self.active_streams.load(Ordering::Relaxed) as usize,
|
|
listen_ports: self.config.read().await.listen_ports.clone(),
|
|
}
|
|
}
|
|
|
|
/// Start the edge: connect to hub, start listeners.
|
|
pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
let config = self.config.read().await.clone();
|
|
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
|
|
*self.shutdown_tx.lock().await = Some(shutdown_tx);
|
|
*self.running.write().await = true;
|
|
|
|
let connected = self.connected.clone();
|
|
let public_ip = self.public_ip.clone();
|
|
let active_streams = self.active_streams.clone();
|
|
let next_stream_id = self.next_stream_id.clone();
|
|
let event_tx = self.event_tx.clone();
|
|
|
|
tokio::spawn(async move {
|
|
edge_main_loop(
|
|
config,
|
|
connected,
|
|
public_ip,
|
|
active_streams,
|
|
next_stream_id,
|
|
event_tx,
|
|
shutdown_rx,
|
|
)
|
|
.await;
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Stop the edge.
|
|
pub async fn stop(&self) {
|
|
if let Some(tx) = self.shutdown_tx.lock().await.take() {
|
|
let _ = tx.send(()).await;
|
|
}
|
|
*self.running.write().await = false;
|
|
*self.connected.write().await = false;
|
|
}
|
|
}
|
|
|
|
async fn edge_main_loop(
|
|
config: EdgeConfig,
|
|
connected: Arc<RwLock<bool>>,
|
|
public_ip: Arc<RwLock<Option<String>>>,
|
|
active_streams: Arc<AtomicU32>,
|
|
next_stream_id: Arc<AtomicU32>,
|
|
event_tx: mpsc::UnboundedSender<EdgeEvent>,
|
|
mut shutdown_rx: mpsc::Receiver<()>,
|
|
) {
|
|
let mut backoff_ms: u64 = 1000;
|
|
let max_backoff_ms: u64 = 30000;
|
|
|
|
loop {
|
|
// Try to connect to hub
|
|
let result = connect_to_hub_and_run(
|
|
&config,
|
|
&connected,
|
|
&public_ip,
|
|
&active_streams,
|
|
&next_stream_id,
|
|
&event_tx,
|
|
&mut shutdown_rx,
|
|
)
|
|
.await;
|
|
|
|
*connected.write().await = false;
|
|
let _ = event_tx.send(EdgeEvent::TunnelDisconnected);
|
|
active_streams.store(0, Ordering::Relaxed);
|
|
|
|
match result {
|
|
EdgeLoopResult::Shutdown => break,
|
|
EdgeLoopResult::Reconnect => {
|
|
log::info!("Reconnecting in {}ms...", backoff_ms);
|
|
tokio::select! {
|
|
_ = tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)) => {}
|
|
_ = shutdown_rx.recv() => break,
|
|
}
|
|
backoff_ms = (backoff_ms * 2).min(max_backoff_ms);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// What the supervisor loop should do after one tunnel session ends.
enum EdgeLoopResult {
    /// Shutdown was requested (or the shutdown sender was dropped): exit the loop.
    Shutdown,
    /// The session failed or the hub went away: wait, then try again.
    Reconnect,
}
|
|
|
|
async fn connect_to_hub_and_run(
|
|
config: &EdgeConfig,
|
|
connected: &Arc<RwLock<bool>>,
|
|
public_ip: &Arc<RwLock<Option<String>>>,
|
|
active_streams: &Arc<AtomicU32>,
|
|
next_stream_id: &Arc<AtomicU32>,
|
|
event_tx: &mpsc::UnboundedSender<EdgeEvent>,
|
|
shutdown_rx: &mut mpsc::Receiver<()>,
|
|
) -> EdgeLoopResult {
|
|
// Build TLS connector that skips cert verification (auth is via secret)
|
|
let tls_config = rustls::ClientConfig::builder()
|
|
.dangerous()
|
|
.with_custom_certificate_verifier(Arc::new(NoCertVerifier))
|
|
.with_no_client_auth();
|
|
|
|
let connector = TlsConnector::from(Arc::new(tls_config));
|
|
|
|
let addr = format!("{}:{}", config.hub_host, config.hub_port);
|
|
let tcp = match TcpStream::connect(&addr).await {
|
|
Ok(s) => s,
|
|
Err(e) => {
|
|
log::error!("Failed to connect to hub at {}: {}", addr, e);
|
|
return EdgeLoopResult::Reconnect;
|
|
}
|
|
};
|
|
|
|
let server_name = rustls::pki_types::ServerName::try_from(config.hub_host.clone())
|
|
.unwrap_or_else(|_| rustls::pki_types::ServerName::try_from("remoteingress-hub".to_string()).unwrap());
|
|
|
|
let tls_stream = match connector.connect(server_name, tcp).await {
|
|
Ok(s) => s,
|
|
Err(e) => {
|
|
log::error!("TLS handshake failed: {}", e);
|
|
return EdgeLoopResult::Reconnect;
|
|
}
|
|
};
|
|
|
|
let (read_half, mut write_half) = tokio::io::split(tls_stream);
|
|
|
|
// Send auth line
|
|
let auth_line = format!("EDGE {} {}\n", config.edge_id, config.secret);
|
|
if write_half.write_all(auth_line.as_bytes()).await.is_err() {
|
|
return EdgeLoopResult::Reconnect;
|
|
}
|
|
|
|
*connected.write().await = true;
|
|
let _ = event_tx.send(EdgeEvent::TunnelConnected);
|
|
log::info!("Connected to hub at {}", addr);
|
|
|
|
// Start STUN discovery
|
|
let stun_interval = config.stun_interval_secs.unwrap_or(300);
|
|
let public_ip_clone = public_ip.clone();
|
|
let event_tx_clone = event_tx.clone();
|
|
let stun_handle = tokio::spawn(async move {
|
|
loop {
|
|
if let Some(ip) = crate::stun::discover_public_ip().await {
|
|
let mut pip = public_ip_clone.write().await;
|
|
let changed = pip.as_ref() != Some(&ip);
|
|
*pip = Some(ip.clone());
|
|
if changed {
|
|
let _ = event_tx_clone.send(EdgeEvent::PublicIpDiscovered { ip });
|
|
}
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_secs(stun_interval)).await;
|
|
}
|
|
});
|
|
|
|
// Client socket map: stream_id -> sender for writing data back to client
|
|
let client_writers: Arc<Mutex<HashMap<u32, mpsc::Sender<Vec<u8>>>>> =
|
|
Arc::new(Mutex::new(HashMap::new()));
|
|
|
|
// Shared tunnel writer
|
|
let tunnel_writer = Arc::new(Mutex::new(write_half));
|
|
|
|
// Start TCP listeners for each port
|
|
let mut listener_handles = Vec::new();
|
|
for &port in &config.listen_ports {
|
|
let tunnel_writer = tunnel_writer.clone();
|
|
let client_writers = client_writers.clone();
|
|
let active_streams = active_streams.clone();
|
|
let next_stream_id = next_stream_id.clone();
|
|
let edge_id = config.edge_id.clone();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
let listener = match TcpListener::bind(("0.0.0.0", port)).await {
|
|
Ok(l) => l,
|
|
Err(e) => {
|
|
log::error!("Failed to bind port {}: {}", port, e);
|
|
return;
|
|
}
|
|
};
|
|
log::info!("Listening on port {}", port);
|
|
|
|
loop {
|
|
match listener.accept().await {
|
|
Ok((client_stream, client_addr)) => {
|
|
let stream_id = next_stream_id.fetch_add(1, Ordering::Relaxed);
|
|
let tunnel_writer = tunnel_writer.clone();
|
|
let client_writers = client_writers.clone();
|
|
let active_streams = active_streams.clone();
|
|
let edge_id = edge_id.clone();
|
|
|
|
active_streams.fetch_add(1, Ordering::Relaxed);
|
|
|
|
tokio::spawn(async move {
|
|
handle_client_connection(
|
|
client_stream,
|
|
client_addr,
|
|
stream_id,
|
|
port,
|
|
&edge_id,
|
|
tunnel_writer,
|
|
client_writers,
|
|
)
|
|
.await;
|
|
active_streams.fetch_sub(1, Ordering::Relaxed);
|
|
});
|
|
}
|
|
Err(e) => {
|
|
log::error!("Accept error on port {}: {}", port, e);
|
|
}
|
|
}
|
|
}
|
|
});
|
|
listener_handles.push(handle);
|
|
}
|
|
|
|
// Read frames from hub
|
|
let mut frame_reader = FrameReader::new(read_half);
|
|
let result = loop {
|
|
tokio::select! {
|
|
frame_result = frame_reader.next_frame() => {
|
|
match frame_result {
|
|
Ok(Some(frame)) => {
|
|
match frame.frame_type {
|
|
FRAME_DATA_BACK => {
|
|
let writers = client_writers.lock().await;
|
|
if let Some(tx) = writers.get(&frame.stream_id) {
|
|
let _ = tx.send(frame.payload).await;
|
|
}
|
|
}
|
|
FRAME_CLOSE_BACK => {
|
|
let mut writers = client_writers.lock().await;
|
|
writers.remove(&frame.stream_id);
|
|
}
|
|
_ => {
|
|
log::warn!("Unexpected frame type {} from hub", frame.frame_type);
|
|
}
|
|
}
|
|
}
|
|
Ok(None) => {
|
|
log::info!("Hub disconnected (EOF)");
|
|
break EdgeLoopResult::Reconnect;
|
|
}
|
|
Err(e) => {
|
|
log::error!("Hub frame error: {}", e);
|
|
break EdgeLoopResult::Reconnect;
|
|
}
|
|
}
|
|
}
|
|
_ = shutdown_rx.recv() => {
|
|
break EdgeLoopResult::Shutdown;
|
|
}
|
|
}
|
|
};
|
|
|
|
// Cleanup
|
|
stun_handle.abort();
|
|
for h in listener_handles {
|
|
h.abort();
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
async fn handle_client_connection(
|
|
client_stream: TcpStream,
|
|
client_addr: std::net::SocketAddr,
|
|
stream_id: u32,
|
|
dest_port: u16,
|
|
edge_id: &str,
|
|
tunnel_writer: Arc<Mutex<tokio::io::WriteHalf<tokio_rustls::client::TlsStream<TcpStream>>>>,
|
|
client_writers: Arc<Mutex<HashMap<u32, mpsc::Sender<Vec<u8>>>>>,
|
|
) {
|
|
let client_ip = client_addr.ip().to_string();
|
|
let client_port = client_addr.port();
|
|
|
|
// Determine edge IP (use 0.0.0.0 as placeholder — hub doesn't use it for routing)
|
|
let edge_ip = "0.0.0.0";
|
|
|
|
// Send OPEN frame with PROXY v1 header
|
|
let proxy_header = build_proxy_v1_header(&client_ip, edge_ip, client_port, dest_port);
|
|
let open_frame = encode_frame(stream_id, FRAME_OPEN, proxy_header.as_bytes());
|
|
{
|
|
let mut w = tunnel_writer.lock().await;
|
|
if w.write_all(&open_frame).await.is_err() {
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Set up channel for data coming back from hub
|
|
let (back_tx, mut back_rx) = mpsc::channel::<Vec<u8>>(256);
|
|
{
|
|
let mut writers = client_writers.lock().await;
|
|
writers.insert(stream_id, back_tx);
|
|
}
|
|
|
|
let (mut client_read, mut client_write) = client_stream.into_split();
|
|
|
|
// Task: hub -> client
|
|
let hub_to_client = tokio::spawn(async move {
|
|
while let Some(data) = back_rx.recv().await {
|
|
if client_write.write_all(&data).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
let _ = client_write.shutdown().await;
|
|
});
|
|
|
|
// Task: client -> hub
|
|
let mut buf = vec![0u8; 32768];
|
|
loop {
|
|
match client_read.read(&mut buf).await {
|
|
Ok(0) => break,
|
|
Ok(n) => {
|
|
let data_frame = encode_frame(stream_id, FRAME_DATA, &buf[..n]);
|
|
let mut w = tunnel_writer.lock().await;
|
|
if w.write_all(&data_frame).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
|
|
// Send CLOSE frame
|
|
let close_frame = encode_frame(stream_id, FRAME_CLOSE, &[]);
|
|
{
|
|
let mut w = tunnel_writer.lock().await;
|
|
let _ = w.write_all(&close_frame).await;
|
|
}
|
|
|
|
// Cleanup
|
|
{
|
|
let mut writers = client_writers.lock().await;
|
|
writers.remove(&stream_id);
|
|
}
|
|
hub_to_client.abort();
|
|
let _ = edge_id; // used for logging context
|
|
}
|
|
|
|
/// TLS certificate verifier that accepts any certificate (auth is via shared secret).
|
|
#[derive(Debug)]
|
|
struct NoCertVerifier;
|
|
|
|
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
|
|
fn verify_server_cert(
|
|
&self,
|
|
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
|
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
|
_server_name: &rustls::pki_types::ServerName<'_>,
|
|
_ocsp_response: &[u8],
|
|
_now: rustls::pki_types::UnixTime,
|
|
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
|
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
|
}
|
|
|
|
fn verify_tls12_signature(
|
|
&self,
|
|
_message: &[u8],
|
|
_cert: &rustls::pki_types::CertificateDer<'_>,
|
|
_dss: &rustls::DigitallySignedStruct,
|
|
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
|
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
|
}
|
|
|
|
fn verify_tls13_signature(
|
|
&self,
|
|
_message: &[u8],
|
|
_cert: &rustls::pki_types::CertificateDer<'_>,
|
|
_dss: &rustls::DigitallySignedStruct,
|
|
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
|
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
|
}
|
|
|
|
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
|
vec![
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
|
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
|
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
|
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
|
|
rustls::SignatureScheme::RSA_PSS_SHA256,
|
|
rustls::SignatureScheme::RSA_PSS_SHA384,
|
|
rustls::SignatureScheme::RSA_PSS_SHA512,
|
|
rustls::SignatureScheme::ED25519,
|
|
rustls::SignatureScheme::ED448,
|
|
]
|
|
}
|
|
}
|