828 lines
31 KiB
Rust
828 lines
31 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::sync::{mpsc, Mutex, RwLock};
|
|
use tokio_rustls::TlsAcceptor;
|
|
use tokio_util::sync::CancellationToken;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use remoteingress_protocol::*;
|
|
|
|
/// Hub configuration.
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct HubConfig {
|
|
pub tunnel_port: u16,
|
|
pub target_host: Option<String>,
|
|
#[serde(skip)]
|
|
pub tls_cert_pem: Option<String>,
|
|
#[serde(skip)]
|
|
pub tls_key_pem: Option<String>,
|
|
}
|
|
|
|
impl Default for HubConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
tunnel_port: 8443,
|
|
target_host: Some("127.0.0.1".to_string()),
|
|
tls_cert_pem: None,
|
|
tls_key_pem: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// An allowed edge identity.
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct AllowedEdge {
|
|
pub id: String,
|
|
pub secret: String,
|
|
#[serde(default)]
|
|
pub listen_ports: Vec<u16>,
|
|
pub stun_interval_secs: Option<u64>,
|
|
}
|
|
|
|
/// Handshake response sent to edge after authentication.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct HandshakeResponse {
|
|
listen_ports: Vec<u16>,
|
|
stun_interval_secs: u64,
|
|
}
|
|
|
|
/// Configuration update pushed to a connected edge at runtime.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct EdgeConfigUpdate {
|
|
pub listen_ports: Vec<u16>,
|
|
}
|
|
|
|
/// Runtime status of a connected edge.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct ConnectedEdgeStatus {
|
|
pub edge_id: String,
|
|
pub connected_at: u64,
|
|
pub active_streams: usize,
|
|
}
|
|
|
|
/// Events emitted by the hub.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[serde(tag = "type")]
|
|
pub enum HubEvent {
|
|
#[serde(rename_all = "camelCase")]
|
|
EdgeConnected { edge_id: String },
|
|
#[serde(rename_all = "camelCase")]
|
|
EdgeDisconnected { edge_id: String },
|
|
#[serde(rename_all = "camelCase")]
|
|
StreamOpened { edge_id: String, stream_id: u32 },
|
|
#[serde(rename_all = "camelCase")]
|
|
StreamClosed { edge_id: String, stream_id: u32 },
|
|
}
|
|
|
|
/// Hub status response.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct HubStatus {
|
|
pub running: bool,
|
|
pub tunnel_port: u16,
|
|
pub connected_edges: Vec<ConnectedEdgeStatus>,
|
|
}
|
|
|
|
/// The tunnel hub that accepts edge connections and demuxes streams to SmartProxy.
|
|
pub struct TunnelHub {
|
|
config: RwLock<HubConfig>,
|
|
allowed_edges: Arc<RwLock<HashMap<String, AllowedEdge>>>,
|
|
connected_edges: Arc<Mutex<HashMap<String, ConnectedEdgeInfo>>>,
|
|
event_tx: mpsc::Sender<HubEvent>,
|
|
event_rx: Mutex<Option<mpsc::Receiver<HubEvent>>>,
|
|
shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
|
|
running: RwLock<bool>,
|
|
cancel_token: CancellationToken,
|
|
}
|
|
|
|
struct ConnectedEdgeInfo {
|
|
connected_at: u64,
|
|
active_streams: Arc<Mutex<HashMap<u32, (mpsc::Sender<Vec<u8>>, CancellationToken)>>>,
|
|
config_tx: mpsc::Sender<EdgeConfigUpdate>,
|
|
#[allow(dead_code)] // kept alive for Drop — cancels child tokens when edge is removed
|
|
cancel_token: CancellationToken,
|
|
}
|
|
|
|
impl TunnelHub {
|
|
pub fn new(config: HubConfig) -> Self {
|
|
let (event_tx, event_rx) = mpsc::channel(1024);
|
|
Self {
|
|
config: RwLock::new(config),
|
|
allowed_edges: Arc::new(RwLock::new(HashMap::new())),
|
|
connected_edges: Arc::new(Mutex::new(HashMap::new())),
|
|
event_tx,
|
|
event_rx: Mutex::new(Some(event_rx)),
|
|
shutdown_tx: Mutex::new(None),
|
|
running: RwLock::new(false),
|
|
cancel_token: CancellationToken::new(),
|
|
}
|
|
}
|
|
|
|
/// Take the event receiver (can only be called once).
|
|
pub async fn take_event_rx(&self) -> Option<mpsc::Receiver<HubEvent>> {
|
|
self.event_rx.lock().await.take()
|
|
}
|
|
|
|
/// Update the list of allowed edges.
|
|
/// For any currently-connected edge whose ports changed, push a config update.
|
|
pub async fn update_allowed_edges(&self, edges: Vec<AllowedEdge>) {
|
|
let mut map = self.allowed_edges.write().await;
|
|
|
|
// Build new map
|
|
let mut new_map = HashMap::new();
|
|
for edge in &edges {
|
|
new_map.insert(edge.id.clone(), edge.clone());
|
|
}
|
|
|
|
// Push config updates to connected edges whose ports changed
|
|
let connected = self.connected_edges.lock().await;
|
|
for edge in &edges {
|
|
if let Some(info) = connected.get(&edge.id) {
|
|
// Check if ports changed compared to old config
|
|
let ports_changed = match map.get(&edge.id) {
|
|
Some(old) => old.listen_ports != edge.listen_ports,
|
|
None => true, // newly allowed edge that's already connected
|
|
};
|
|
if ports_changed {
|
|
let update = EdgeConfigUpdate {
|
|
listen_ports: edge.listen_ports.clone(),
|
|
};
|
|
let _ = info.config_tx.try_send(update);
|
|
}
|
|
}
|
|
}
|
|
|
|
*map = new_map;
|
|
}
|
|
|
|
/// Get the current hub status.
|
|
pub async fn get_status(&self) -> HubStatus {
|
|
let running = *self.running.read().await;
|
|
let config = self.config.read().await;
|
|
let edges = self.connected_edges.lock().await;
|
|
|
|
let mut connected = Vec::new();
|
|
for (id, info) in edges.iter() {
|
|
let streams = info.active_streams.lock().await;
|
|
connected.push(ConnectedEdgeStatus {
|
|
edge_id: id.clone(),
|
|
connected_at: info.connected_at,
|
|
active_streams: streams.len(),
|
|
});
|
|
}
|
|
|
|
HubStatus {
|
|
running,
|
|
tunnel_port: config.tunnel_port,
|
|
connected_edges: connected,
|
|
}
|
|
}
|
|
|
|
/// Start the hub — listen for TLS connections from edges.
|
|
pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
let config = self.config.read().await.clone();
|
|
let tls_config = build_tls_config(&config)?;
|
|
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
|
|
|
|
let listener = TcpListener::bind(("0.0.0.0", config.tunnel_port)).await?;
|
|
log::info!("Hub listening on port {}", config.tunnel_port);
|
|
|
|
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
|
*self.shutdown_tx.lock().await = Some(shutdown_tx);
|
|
*self.running.write().await = true;
|
|
|
|
let allowed = self.allowed_edges.clone();
|
|
let connected = self.connected_edges.clone();
|
|
let event_tx = self.event_tx.clone();
|
|
let target_host = config.target_host.unwrap_or_else(|| "127.0.0.1".to_string());
|
|
let hub_token = self.cancel_token.clone();
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
result = listener.accept() => {
|
|
match result {
|
|
Ok((stream, addr)) => {
|
|
log::info!("Edge connection from {}", addr);
|
|
let acceptor = acceptor.clone();
|
|
let allowed = allowed.clone();
|
|
let connected = connected.clone();
|
|
let event_tx = event_tx.clone();
|
|
let target = target_host.clone();
|
|
let edge_token = hub_token.child_token();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = handle_edge_connection(
|
|
stream, acceptor, allowed, connected, event_tx, target, edge_token,
|
|
).await {
|
|
log::error!("Edge connection error: {}", e);
|
|
}
|
|
});
|
|
}
|
|
Err(e) => {
|
|
log::error!("Accept error: {}", e);
|
|
}
|
|
}
|
|
}
|
|
_ = hub_token.cancelled() => {
|
|
log::info!("Hub shutting down (token cancelled)");
|
|
break;
|
|
}
|
|
_ = shutdown_rx.recv() => {
|
|
log::info!("Hub shutting down");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Stop the hub.
|
|
pub async fn stop(&self) {
|
|
self.cancel_token.cancel();
|
|
if let Some(tx) = self.shutdown_tx.lock().await.take() {
|
|
let _ = tx.send(()).await;
|
|
}
|
|
*self.running.write().await = false;
|
|
// Clear connected edges
|
|
self.connected_edges.lock().await.clear();
|
|
}
|
|
}
|
|
|
|
impl Drop for TunnelHub {
|
|
fn drop(&mut self) {
|
|
self.cancel_token.cancel();
|
|
}
|
|
}
|
|
|
|
/// Handle a single edge connection: authenticate, then enter frame loop.
|
|
async fn handle_edge_connection(
|
|
stream: TcpStream,
|
|
acceptor: TlsAcceptor,
|
|
allowed: Arc<RwLock<HashMap<String, AllowedEdge>>>,
|
|
connected: Arc<Mutex<HashMap<String, ConnectedEdgeInfo>>>,
|
|
event_tx: mpsc::Sender<HubEvent>,
|
|
target_host: String,
|
|
edge_token: CancellationToken,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
let tls_stream = acceptor.accept(stream).await?;
|
|
let (read_half, mut write_half) = tokio::io::split(tls_stream);
|
|
let mut buf_reader = BufReader::new(read_half);
|
|
|
|
// Read auth line: "EDGE <edgeId> <secret>\n"
|
|
let mut auth_line = String::new();
|
|
buf_reader.read_line(&mut auth_line).await?;
|
|
let auth_line = auth_line.trim();
|
|
|
|
let parts: Vec<&str> = auth_line.splitn(3, ' ').collect();
|
|
if parts.len() != 3 || parts[0] != "EDGE" {
|
|
return Err("invalid auth line".into());
|
|
}
|
|
|
|
let edge_id = parts[1].to_string();
|
|
let secret = parts[2];
|
|
|
|
// Verify credentials and extract edge config
|
|
let (listen_ports, stun_interval_secs) = {
|
|
let edges = allowed.read().await;
|
|
match edges.get(&edge_id) {
|
|
Some(edge) => {
|
|
if !constant_time_eq(secret.as_bytes(), edge.secret.as_bytes()) {
|
|
return Err(format!("invalid secret for edge {}", edge_id).into());
|
|
}
|
|
(edge.listen_ports.clone(), edge.stun_interval_secs.unwrap_or(300))
|
|
}
|
|
None => {
|
|
return Err(format!("unknown edge {}", edge_id).into());
|
|
}
|
|
}
|
|
};
|
|
|
|
log::info!("Edge {} authenticated", edge_id);
|
|
let _ = event_tx.try_send(HubEvent::EdgeConnected {
|
|
edge_id: edge_id.clone(),
|
|
});
|
|
|
|
// Send handshake response with initial config before frame protocol begins
|
|
let handshake = HandshakeResponse {
|
|
listen_ports: listen_ports.clone(),
|
|
stun_interval_secs,
|
|
};
|
|
let mut handshake_json = serde_json::to_string(&handshake)?;
|
|
handshake_json.push('\n');
|
|
write_half.write_all(handshake_json.as_bytes()).await?;
|
|
|
|
// Track this edge
|
|
let streams: Arc<Mutex<HashMap<u32, (mpsc::Sender<Vec<u8>>, CancellationToken)>>> =
|
|
Arc::new(Mutex::new(HashMap::new()));
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs();
|
|
|
|
// Create config update channel
|
|
let (config_tx, mut config_rx) = mpsc::channel::<EdgeConfigUpdate>(16);
|
|
|
|
{
|
|
let mut edges = connected.lock().await;
|
|
edges.insert(
|
|
edge_id.clone(),
|
|
ConnectedEdgeInfo {
|
|
connected_at: now,
|
|
active_streams: streams.clone(),
|
|
config_tx,
|
|
cancel_token: edge_token.clone(),
|
|
},
|
|
);
|
|
}
|
|
|
|
// Shared writer for sending frames back to edge
|
|
let write_half = Arc::new(Mutex::new(write_half));
|
|
|
|
// Spawn task to forward config updates as FRAME_CONFIG frames
|
|
let config_writer = write_half.clone();
|
|
let config_edge_id = edge_id.clone();
|
|
let config_token = edge_token.clone();
|
|
let config_handle = tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
update = config_rx.recv() => {
|
|
match update {
|
|
Some(update) => {
|
|
if let Ok(payload) = serde_json::to_vec(&update) {
|
|
let frame = encode_frame(0, FRAME_CONFIG, &payload);
|
|
let mut w = config_writer.lock().await;
|
|
if w.write_all(&frame).await.is_err() {
|
|
log::error!("Failed to send config update to edge {}", config_edge_id);
|
|
break;
|
|
}
|
|
log::info!("Sent config update to edge {}: ports {:?}", config_edge_id, update.listen_ports);
|
|
}
|
|
}
|
|
None => break,
|
|
}
|
|
}
|
|
_ = config_token.cancelled() => break,
|
|
}
|
|
}
|
|
});
|
|
|
|
// Frame reading loop
|
|
let mut frame_reader = FrameReader::new(buf_reader);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
frame_result = frame_reader.next_frame() => {
|
|
match frame_result {
|
|
Ok(Some(frame)) => {
|
|
match frame.frame_type {
|
|
FRAME_OPEN => {
|
|
// Payload is PROXY v1 header line
|
|
let proxy_header = String::from_utf8_lossy(&frame.payload).to_string();
|
|
|
|
// Parse destination port from PROXY header
|
|
let dest_port = parse_dest_port_from_proxy(&proxy_header).unwrap_or(443);
|
|
|
|
let stream_id = frame.stream_id;
|
|
let edge_id_clone = edge_id.clone();
|
|
let event_tx_clone = event_tx.clone();
|
|
let streams_clone = streams.clone();
|
|
let writer_clone = write_half.clone();
|
|
let target = target_host.clone();
|
|
let stream_token = edge_token.child_token();
|
|
|
|
let _ = event_tx.try_send(HubEvent::StreamOpened {
|
|
edge_id: edge_id.clone(),
|
|
stream_id,
|
|
});
|
|
|
|
// Create channel for data from edge to this stream
|
|
let (data_tx, mut data_rx) = mpsc::channel::<Vec<u8>>(256);
|
|
{
|
|
let mut s = streams.lock().await;
|
|
s.insert(stream_id, (data_tx, stream_token.clone()));
|
|
}
|
|
|
|
// Spawn task: connect to SmartProxy, send PROXY header, pipe data
|
|
tokio::spawn(async move {
|
|
let result = async {
|
|
let mut upstream =
|
|
TcpStream::connect((target.as_str(), dest_port)).await?;
|
|
upstream.write_all(proxy_header.as_bytes()).await?;
|
|
|
|
let (mut up_read, mut up_write) =
|
|
upstream.into_split();
|
|
|
|
// Forward data from edge (via channel) to SmartProxy
|
|
let writer_token = stream_token.clone();
|
|
let writer_for_edge_data = tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
data = data_rx.recv() => {
|
|
match data {
|
|
Some(data) => {
|
|
if up_write.write_all(&data).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
None => break,
|
|
}
|
|
}
|
|
_ = writer_token.cancelled() => break,
|
|
}
|
|
}
|
|
let _ = up_write.shutdown().await;
|
|
});
|
|
|
|
// Forward data from SmartProxy back to edge
|
|
let mut buf = vec![0u8; 32768];
|
|
loop {
|
|
tokio::select! {
|
|
read_result = up_read.read(&mut buf) => {
|
|
match read_result {
|
|
Ok(0) => break,
|
|
Ok(n) => {
|
|
let frame =
|
|
encode_frame(stream_id, FRAME_DATA_BACK, &buf[..n]);
|
|
let mut w = writer_clone.lock().await;
|
|
if w.write_all(&frame).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
_ = stream_token.cancelled() => break,
|
|
}
|
|
}
|
|
|
|
// Send CLOSE_BACK to edge (only if not cancelled)
|
|
if !stream_token.is_cancelled() {
|
|
let close_frame = encode_frame(stream_id, FRAME_CLOSE_BACK, &[]);
|
|
let mut w = writer_clone.lock().await;
|
|
let _ = w.write_all(&close_frame).await;
|
|
}
|
|
|
|
writer_for_edge_data.abort();
|
|
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
|
|
}
|
|
.await;
|
|
|
|
if let Err(e) = result {
|
|
log::error!("Stream {} error: {}", stream_id, e);
|
|
// Send CLOSE_BACK on error (only if not cancelled)
|
|
if !stream_token.is_cancelled() {
|
|
let close_frame = encode_frame(stream_id, FRAME_CLOSE_BACK, &[]);
|
|
let mut w = writer_clone.lock().await;
|
|
let _ = w.write_all(&close_frame).await;
|
|
}
|
|
}
|
|
|
|
// Clean up stream (guard against duplicate if FRAME_CLOSE already removed it)
|
|
let was_present = {
|
|
let mut s = streams_clone.lock().await;
|
|
s.remove(&stream_id).is_some()
|
|
};
|
|
if was_present {
|
|
let _ = event_tx_clone.try_send(HubEvent::StreamClosed {
|
|
edge_id: edge_id_clone,
|
|
stream_id,
|
|
});
|
|
}
|
|
});
|
|
}
|
|
FRAME_DATA => {
|
|
let s = streams.lock().await;
|
|
if let Some((tx, _)) = s.get(&frame.stream_id) {
|
|
let _ = tx.send(frame.payload).await;
|
|
}
|
|
}
|
|
FRAME_CLOSE => {
|
|
let mut s = streams.lock().await;
|
|
if let Some((_, token)) = s.remove(&frame.stream_id) {
|
|
token.cancel();
|
|
let _ = event_tx.try_send(HubEvent::StreamClosed {
|
|
edge_id: edge_id.clone(),
|
|
stream_id: frame.stream_id,
|
|
});
|
|
}
|
|
}
|
|
_ => {
|
|
log::warn!("Unexpected frame type {} from edge", frame.frame_type);
|
|
}
|
|
}
|
|
}
|
|
Ok(None) => {
|
|
log::info!("Edge {} disconnected (EOF)", edge_id);
|
|
break;
|
|
}
|
|
Err(e) => {
|
|
log::error!("Edge {} frame error: {}", edge_id, e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
_ = edge_token.cancelled() => {
|
|
log::info!("Edge {} cancelled by hub", edge_id);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Cleanup: cancel edge token to propagate to all child tasks
|
|
edge_token.cancel();
|
|
config_handle.abort();
|
|
{
|
|
let mut edges = connected.lock().await;
|
|
edges.remove(&edge_id);
|
|
}
|
|
let _ = event_tx.try_send(HubEvent::EdgeDisconnected {
|
|
edge_id: edge_id.clone(),
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Extract the destination port from a PROXY v1 header such as
/// `PROXY TCP4 1.2.3.4 5.6.7.8 12345 443\r\n`.
///
/// The port is the sixth whitespace-separated field. Returns `None` when there are
/// fewer than six fields or that field is not a valid `u16`.
fn parse_dest_port_from_proxy(header: &str) -> Option<u16> {
    const DEST_PORT_FIELD: usize = 5;
    header
        .split_whitespace()
        .nth(DEST_PORT_FIELD)
        .and_then(|field| field.parse::<u16>().ok())
}
|
|
|
|
/// Build TLS server config from PEM strings, or auto-generate self-signed.
|
|
fn build_tls_config(
|
|
config: &HubConfig,
|
|
) -> Result<rustls::ServerConfig, Box<dyn std::error::Error + Send + Sync>> {
|
|
let (cert_pem, key_pem) = match (&config.tls_cert_pem, &config.tls_key_pem) {
|
|
(Some(cert), Some(key)) => (cert.clone(), key.clone()),
|
|
_ => {
|
|
// Generate self-signed certificate
|
|
let cert = rcgen::generate_simple_self_signed(vec!["remoteingress-hub".to_string()])?;
|
|
let cert_pem = cert.cert.pem();
|
|
let key_pem = cert.key_pair.serialize_pem();
|
|
(cert_pem, key_pem)
|
|
}
|
|
};
|
|
|
|
let certs = rustls_pemfile_parse_certs(&cert_pem)?;
|
|
let key = rustls_pemfile_parse_key(&key_pem)?;
|
|
|
|
let mut config = rustls::ServerConfig::builder()
|
|
.with_no_client_auth()
|
|
.with_single_cert(certs, key)?;
|
|
|
|
config.alpn_protocols = vec![b"remoteingress".to_vec()];
|
|
|
|
Ok(config)
|
|
}
|
|
|
|
fn rustls_pemfile_parse_certs(
|
|
pem: &str,
|
|
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>>
|
|
{
|
|
let mut reader = std::io::Cursor::new(pem.as_bytes());
|
|
let certs = rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
|
|
Ok(certs)
|
|
}
|
|
|
|
fn rustls_pemfile_parse_key(
|
|
pem: &str,
|
|
) -> Result<rustls::pki_types::PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
|
let mut reader = std::io::Cursor::new(pem.as_bytes());
|
|
let key = rustls_pemfile::private_key(&mut reader)?
|
|
.ok_or("no private key found in PEM")?;
|
|
Ok(key)
|
|
}
|
|
|
|
/// Compare two byte slices without short-circuiting on the first differing byte.
///
/// Only a length mismatch returns early. For equal lengths, every byte pair is folded
/// into one accumulator. The intent is that timing does not reveal where the contents
/// differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
        accumulated == 0
    } else {
        false
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an allowed edge with the given id and listen ports.
    fn allowed_edge(id: &str, listen_ports: Vec<u16>) -> AllowedEdge {
        AllowedEdge {
            id: id.to_string(),
            secret: "secret".to_string(),
            listen_ports,
            stun_interval_secs: None,
        }
    }

    // --- constant_time_eq tests ---

    #[test]
    fn test_constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
    }

    #[test]
    fn test_constant_time_eq_different_content() {
        assert!(!constant_time_eq(b"hello", b"world"));
    }

    #[test]
    fn test_constant_time_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer"));
    }

    #[test]
    fn test_constant_time_eq_both_empty() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn test_constant_time_eq_one_empty() {
        assert!(!constant_time_eq(b"", b"notempty"));
    }

    #[test]
    fn test_constant_time_eq_single_bit_difference() {
        // 'A' = 0x41, 'a' = 0x61 — differ by one bit
        assert!(!constant_time_eq(b"A", b"a"));
    }

    // --- parse_dest_port_from_proxy tests ---

    #[test]
    fn test_parse_dest_port_443() {
        let header = "PROXY TCP4 1.2.3.4 5.6.7.8 12345 443\r\n";
        assert_eq!(parse_dest_port_from_proxy(header), Some(443));
    }

    #[test]
    fn test_parse_dest_port_80() {
        let header = "PROXY TCP4 10.0.0.1 10.0.0.2 54321 80\r\n";
        assert_eq!(parse_dest_port_from_proxy(header), Some(80));
    }

    #[test]
    fn test_parse_dest_port_65535() {
        let header = "PROXY TCP4 10.0.0.1 10.0.0.2 1 65535\r\n";
        assert_eq!(parse_dest_port_from_proxy(header), Some(65535));
    }

    #[test]
    fn test_parse_dest_port_tcp6() {
        let header = "PROXY TCP6 2001:db8::1 2001:db8::2 40000 8443\r\n";
        assert_eq!(parse_dest_port_from_proxy(header), Some(8443));
    }

    #[test]
    fn test_parse_dest_port_out_of_range() {
        let header = "PROXY TCP4 1.2.3.4 5.6.7.8 12345 70000\r\n";
        assert_eq!(parse_dest_port_from_proxy(header), None);
    }

    #[test]
    fn test_parse_dest_port_unknown_protocol() {
        assert_eq!(parse_dest_port_from_proxy("PROXY UNKNOWN\r\n"), None);
    }

    #[test]
    fn test_parse_dest_port_too_few_fields() {
        let header = "PROXY TCP4 1.2.3.4";
        assert_eq!(parse_dest_port_from_proxy(header), None);
    }

    #[test]
    fn test_parse_dest_port_empty_string() {
        assert_eq!(parse_dest_port_from_proxy(""), None);
    }

    #[test]
    fn test_parse_dest_port_non_numeric() {
        let header = "PROXY TCP4 1.2.3.4 5.6.7.8 12345 abc\r\n";
        assert_eq!(parse_dest_port_from_proxy(header), None);
    }

    // --- Serde tests ---

    #[test]
    fn test_allowed_edge_deserialize_all_fields() {
        let json = r#"{
            "id": "edge-1",
            "secret": "s3cret",
            "listenPorts": [443, 8080],
            "stunIntervalSecs": 120
        }"#;
        let edge: AllowedEdge = serde_json::from_str(json).unwrap();
        assert_eq!(edge.id, "edge-1");
        assert_eq!(edge.secret, "s3cret");
        assert_eq!(edge.listen_ports, vec![443, 8080]);
        assert_eq!(edge.stun_interval_secs, Some(120));
    }

    #[test]
    fn test_allowed_edge_deserialize_with_defaults() {
        let json = r#"{"id": "edge-2", "secret": "key"}"#;
        let edge: AllowedEdge = serde_json::from_str(json).unwrap();
        assert_eq!(edge.id, "edge-2");
        assert_eq!(edge.secret, "key");
        assert!(edge.listen_ports.is_empty());
        assert_eq!(edge.stun_interval_secs, None);
    }

    #[test]
    fn test_handshake_response_serializes_camel_case() {
        let resp = HandshakeResponse {
            listen_ports: vec![443, 8080],
            stun_interval_secs: 300,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["listenPorts"], serde_json::json!([443, 8080]));
        assert_eq!(json["stunIntervalSecs"], 300);
        // Ensure snake_case keys are NOT present
        assert!(json.get("listen_ports").is_none());
        assert!(json.get("stun_interval_secs").is_none());
    }

    #[test]
    fn test_edge_config_update_serializes_camel_case() {
        let update = EdgeConfigUpdate {
            listen_ports: vec![80, 443],
        };
        let json = serde_json::to_value(&update).unwrap();
        assert_eq!(json["listenPorts"], serde_json::json!([80, 443]));
        assert!(json.get("listen_ports").is_none());
    }

    #[test]
    fn test_hub_config_default() {
        let config = HubConfig::default();
        assert_eq!(config.tunnel_port, 8443);
        assert_eq!(config.target_host, Some("127.0.0.1".to_string()));
        assert!(config.tls_cert_pem.is_none());
        assert!(config.tls_key_pem.is_none());
    }

    #[test]
    fn test_hub_event_edge_connected_serialize() {
        let event = HubEvent::EdgeConnected {
            edge_id: "edge-1".to_string(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "edgeConnected");
        assert_eq!(json["edgeId"], "edge-1");
    }

    #[test]
    fn test_hub_event_edge_disconnected_serialize() {
        let event = HubEvent::EdgeDisconnected {
            edge_id: "edge-2".to_string(),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "edgeDisconnected");
        assert_eq!(json["edgeId"], "edge-2");
    }

    #[test]
    fn test_hub_event_stream_opened_serialize() {
        let event = HubEvent::StreamOpened {
            edge_id: "e".to_string(),
            stream_id: 42,
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "streamOpened");
        assert_eq!(json["edgeId"], "e");
        assert_eq!(json["streamId"], 42);
    }

    #[test]
    fn test_hub_event_stream_closed_serialize() {
        let event = HubEvent::StreamClosed {
            edge_id: "e".to_string(),
            stream_id: 7,
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "streamClosed");
        assert_eq!(json["edgeId"], "e");
        assert_eq!(json["streamId"], 7);
    }

    // --- Async tests ---

    #[tokio::test]
    async fn test_tunnel_hub_new_get_status() {
        let hub = TunnelHub::new(HubConfig::default());
        let status = hub.get_status().await;
        assert!(!status.running);
        assert!(status.connected_edges.is_empty());
        assert_eq!(status.tunnel_port, 8443);
    }

    #[tokio::test]
    async fn test_tunnel_hub_status_reports_configured_port() {
        let hub = TunnelHub::new(HubConfig {
            tunnel_port: 9000,
            ..HubConfig::default()
        });
        assert_eq!(hub.get_status().await.tunnel_port, 9000);
    }

    #[tokio::test]
    async fn test_tunnel_hub_take_event_rx() {
        let hub = TunnelHub::new(HubConfig::default());
        let rx1 = hub.take_event_rx().await;
        assert!(rx1.is_some());
        let rx2 = hub.take_event_rx().await;
        assert!(rx2.is_none());
    }

    #[tokio::test]
    async fn test_tunnel_hub_stop_without_start() {
        let hub = TunnelHub::new(HubConfig::default());
        hub.stop().await; // should not panic
        let status = hub.get_status().await;
        assert!(!status.running);
    }

    #[tokio::test]
    async fn test_tunnel_hub_stop_is_idempotent() {
        let hub = TunnelHub::new(HubConfig::default());
        hub.stop().await;
        hub.stop().await;
        assert!(!hub.get_status().await.running);
    }

    #[tokio::test]
    async fn test_update_allowed_edges_replaces_previous_list() {
        let hub = TunnelHub::new(HubConfig::default());
        hub.update_allowed_edges(vec![
            allowed_edge("a", vec![443]),
            allowed_edge("b", vec![443]),
        ])
        .await;
        hub.update_allowed_edges(vec![allowed_edge("b", vec![8080])]).await;

        let allowed = hub.allowed_edges.read().await;
        assert_eq!(allowed.len(), 1);
        assert!(!allowed.contains_key("a"));
        assert_eq!(allowed["b"].listen_ports, vec![8080]);
    }

    #[tokio::test]
    async fn test_update_allowed_edges_last_duplicate_wins() {
        let hub = TunnelHub::new(HubConfig::default());
        hub.update_allowed_edges(vec![
            allowed_edge("a", vec![1]),
            allowed_edge("a", vec![2]),
        ])
        .await;

        let allowed = hub.allowed_edges.read().await;
        assert_eq!(allowed.len(), 1);
        assert_eq!(allowed["a"].listen_ports, vec![2]);
    }
}
|