feat(smart-proxy): add socket-handler relay, fast-path port-only forwarding, metrics and bridge improvements, and various TS/Rust integration fixes

This commit is contained in:
2026-02-09 16:25:33 +00:00
parent 41efdb47f8
commit f7605e042e
17 changed files with 724 additions and 300 deletions

View File

@@ -146,6 +146,41 @@ pub fn is_tls(data: &[u8]) -> bool {
data.len() >= 3 && data[0] == 0x16 && data[1] == 0x03
}
/// Extract the HTTP request path from initial data.
/// E.g., from "GET /foo/bar HTTP/1.1\r\n..." returns Some("/foo/bar").
pub fn extract_http_path(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?;
// Find first space (after method)
let method_end = text.find(' ')?;
let rest = &text[method_end + 1..];
// Find end of path (next space before "HTTP/...")
let path_end = rest.find(' ').unwrap_or(rest.len());
let path = &rest[..path_end];
// Strip query string for path matching
let path = path.split('?').next().unwrap_or(path);
if path.starts_with('/') {
Some(path.to_string())
} else {
None
}
}
/// Extract the HTTP Host header from initial data.
/// E.g., from "GET / HTTP/1.1\r\nHost: example.com\r\n..." returns Some("example.com").
pub fn extract_http_host(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?;
for line in text.split("\r\n") {
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
// Strip port if present
let host = value.split(':').next().unwrap_or(value).trim();
if !host.is_empty() {
return Some(host.to_lowercase());
}
}
}
None
}
/// Check if the initial bytes look like HTTP.
pub fn is_http(data: &[u8]) -> bool {
if data.len() < 4 {

View File

@@ -1,10 +1,12 @@
use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwap;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{info, error, debug, warn};
use thiserror::Error;
use rustproxy_config::RouteActionType;
use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;
use rustproxy_http::HttpProxyService;
@@ -82,8 +84,8 @@ impl Default for ConnectionConfig {
pub struct TcpListenerManager {
/// Active listeners indexed by port
listeners: HashMap<u16, tokio::task::JoinHandle<()>>,
/// Shared route manager
route_manager: Arc<RouteManager>,
/// Shared route manager (ArcSwap for hot-reload visibility in accept loops)
route_manager: Arc<ArcSwap<RouteManager>>,
/// Shared metrics collector
metrics: Arc<MetricsCollector>,
/// TLS acceptors indexed by domain
@@ -96,6 +98,8 @@ pub struct TcpListenerManager {
conn_tracker: Arc<ConnectionTracker>,
/// Cancellation token for graceful shutdown
cancel_token: CancellationToken,
/// Path to Unix domain socket for relaying socket-handler connections to TypeScript.
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
}
impl TcpListenerManager {
@@ -112,13 +116,14 @@ impl TcpListenerManager {
));
Self {
listeners: HashMap::new(),
route_manager,
route_manager: Arc::new(ArcSwap::from(route_manager)),
metrics,
tls_configs: Arc::new(HashMap::new()),
http_proxy,
conn_config: Arc::new(conn_config),
conn_tracker,
cancel_token: CancellationToken::new(),
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
}
}
@@ -135,13 +140,14 @@ impl TcpListenerManager {
));
Self {
listeners: HashMap::new(),
route_manager,
route_manager: Arc::new(ArcSwap::from(route_manager)),
metrics,
tls_configs: Arc::new(HashMap::new()),
http_proxy,
conn_config: Arc::new(conn_config),
conn_tracker,
cancel_token: CancellationToken::new(),
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
}
}
@@ -159,6 +165,12 @@ impl TcpListenerManager {
self.tls_configs = Arc::new(configs);
}
/// Set the shared socket-handler relay path.
/// This allows RustProxy to share the relay path Arc with the listener.
pub fn set_socket_handler_relay(&mut self, relay: Arc<std::sync::RwLock<Option<String>>>) {
self.socket_handler_relay = relay;
}
/// Start listening on a port.
pub async fn add_port(&mut self, port: u16) -> Result<(), ListenerError> {
if self.listeners.contains_key(&port) {
@@ -172,18 +184,19 @@ impl TcpListenerManager {
info!("Listening on port {}", port);
let route_manager = Arc::clone(&self.route_manager);
let route_manager_swap = Arc::clone(&self.route_manager);
let metrics = Arc::clone(&self.metrics);
let tls_configs = Arc::clone(&self.tls_configs);
let http_proxy = Arc::clone(&self.http_proxy);
let conn_config = Arc::clone(&self.conn_config);
let conn_tracker = Arc::clone(&self.conn_tracker);
let cancel = self.cancel_token.clone();
let relay = Arc::clone(&self.socket_handler_relay);
let handle = tokio::spawn(async move {
Self::accept_loop(
listener, port, route_manager, metrics, tls_configs,
http_proxy, conn_config, conn_tracker, cancel,
listener, port, route_manager_swap, metrics, tls_configs,
http_proxy, conn_config, conn_tracker, cancel, relay,
).await;
});
@@ -255,8 +268,9 @@ impl TcpListenerManager {
}
/// Update the route manager (for hot-reload).
/// Uses ArcSwap so running accept loops immediately see the new routes.
pub fn update_route_manager(&mut self, route_manager: Arc<RouteManager>) {
self.route_manager = route_manager;
self.route_manager.store(route_manager);
}
/// Get a reference to the metrics collector.
@@ -268,13 +282,14 @@ impl TcpListenerManager {
async fn accept_loop(
listener: TcpListener,
port: u16,
route_manager: Arc<RouteManager>,
route_manager_swap: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
http_proxy: Arc<HttpProxyService>,
conn_config: Arc<ConnectionConfig>,
conn_tracker: Arc<ConnectionTracker>,
cancel: CancellationToken,
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
) {
loop {
tokio::select! {
@@ -296,18 +311,20 @@ impl TcpListenerManager {
conn_tracker.connection_opened(&ip);
let rm = Arc::clone(&route_manager);
// Load the latest route manager from ArcSwap on each connection
let rm = route_manager_swap.load_full();
let m = Arc::clone(&metrics);
let tc = Arc::clone(&tls_configs);
let hp = Arc::clone(&http_proxy);
let cc = Arc::clone(&conn_config);
let ct = Arc::clone(&conn_tracker);
let cn = cancel.clone();
let sr = Arc::clone(&socket_handler_relay);
debug!("Accepted connection from {} on port {}", peer_addr, port);
tokio::spawn(async move {
let result = Self::handle_connection(
stream, port, peer_addr, rm, m, tc, hp, cc, cn,
stream, port, peer_addr, rm, m, tc, hp, cc, cn, sr,
).await;
if let Err(e) = result {
debug!("Connection error from {}: {}", peer_addr, e);
@@ -336,11 +353,114 @@ impl TcpListenerManager {
http_proxy: Arc<HttpProxyService>,
conn_config: Arc<ConnectionConfig>,
cancel: CancellationToken,
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
use tokio::io::AsyncReadExt;
stream.set_nodelay(true)?;
// === Fast path: try port-only matching before peeking at data ===
// This handles "server-speaks-first" protocols where the client
// doesn't send initial data (e.g., SMTP, greeting-based protocols).
// If a route matches by port alone and doesn't need domain/path/TLS info,
// we can forward immediately without waiting for client data.
{
let quick_ctx = rustproxy_routing::MatchContext {
port,
domain: None,
path: None,
client_ip: Some(&peer_addr.ip().to_string()),
tls_version: None,
headers: None,
is_tls: false,
};
if let Some(quick_match) = route_manager.find_route(&quick_ctx) {
let rm = &quick_match.route.route_match;
let has_no_domain = rm.domains.is_none();
let has_no_path = rm.path.is_none();
let is_forward = quick_match.route.action.action_type == RouteActionType::Forward;
let has_no_tls = quick_match.route.action.tls.is_none();
// Only use fast path for simple port-only forward routes with no TLS
if has_no_domain && has_no_path && is_forward && has_no_tls {
if let Some(target) = quick_match.target {
let target_host = target.host.first().to_string();
let target_port = target.port.resolve(port);
let route_id = quick_match.route.id.as_deref();
// Check route-level IP security
if let Some(ref security) = quick_match.route.security {
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
security, &peer_addr.ip(),
) {
debug!("Connection from {} blocked by route security", peer_addr);
return Ok(());
}
}
metrics.connection_opened(route_id);
let connect_timeout = std::time::Duration::from_millis(conn_config.connection_timeout_ms);
let inactivity_timeout = std::time::Duration::from_millis(conn_config.socket_timeout_ms);
let max_lifetime = std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms);
debug!(
"Fast-path forward (no peek): {} -> {}:{}",
peer_addr, target_host, target_port
);
let backend = match tokio::time::timeout(
connect_timeout,
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
).await {
Ok(Ok(s)) => s,
Ok(Err(e)) => {
metrics.connection_closed(route_id);
return Err(e.into());
}
Err(_) => {
metrics.connection_closed(route_id);
return Err("Backend connection timeout".into());
}
};
backend.set_nodelay(true)?;
// Send PROXY protocol header if configured
let should_send_proxy = conn_config.send_proxy_protocol
|| quick_match.route.action.send_proxy_protocol.unwrap_or(false)
|| target.send_proxy_protocol.unwrap_or(false);
if should_send_proxy {
use tokio::io::AsyncWriteExt;
let dest = std::net::SocketAddr::new(
target_host.parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)),
target_port,
);
let header = crate::proxy_protocol::generate_v1(&peer_addr, &dest);
let mut backend_w = backend;
backend_w.write_all(header.as_bytes()).await?;
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
stream, backend_w, None,
inactivity_timeout, max_lifetime, cancel,
).await?;
metrics.record_bytes(bytes_in, bytes_out, route_id);
} else {
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
stream, backend, None,
inactivity_timeout, max_lifetime, cancel,
).await?;
metrics.record_bytes(bytes_in, bytes_out, route_id);
}
metrics.connection_closed(route_id);
return Ok(());
}
}
}
}
// === End fast path ===
// Handle PROXY protocol if configured
let mut effective_peer_addr = peer_addr;
if conn_config.accept_proxy_protocol {
@@ -412,11 +532,24 @@ impl TcpListenerManager {
None
};
// Extract HTTP path and host from initial data for route matching
let http_path = if is_http {
sni_parser::extract_http_path(initial_data)
} else {
None
};
let http_host = if is_http && domain.is_none() {
sni_parser::extract_http_host(initial_data)
} else {
None
};
let effective_domain = domain.as_deref().or(http_host.as_deref());
// Match route
let ctx = rustproxy_routing::MatchContext {
port,
domain: domain.as_deref(),
path: None,
domain: effective_domain,
path: http_path.as_deref(),
client_ip: Some(&peer_addr.ip().to_string()),
tls_version: None,
headers: None,
@@ -449,6 +582,28 @@ impl TcpListenerManager {
// Track connection in metrics
metrics.connection_opened(route_id);
// Check if this is a socket-handler route that should be relayed to TypeScript
if route_match.route.action.action_type == RouteActionType::SocketHandler {
let relay_path = {
let guard = socket_handler_relay.read().unwrap();
guard.clone()
};
if let Some(relay_socket_path) = relay_path {
let result = Self::relay_to_socket_handler(
stream, n, port, peer_addr,
&route_match, domain.as_deref(), is_tls,
&relay_socket_path,
).await;
metrics.connection_closed(route_id);
return result;
} else {
debug!("Socket-handler route matched but no relay path configured");
metrics.connection_closed(route_id);
return Ok(());
}
}
let target = match route_match.target {
Some(t) => t,
None => {
@@ -654,6 +809,75 @@ impl TcpListenerManager {
result
}
/// Relay a connection to the TypeScript socket-handler via Unix domain socket.
///
/// Protocol:
/// 1. Connect to the Unix socket at `relay_path`
/// 2. Send a JSON metadata line (terminated by \n)
/// 3. Forward the initial peeked bytes
/// 4. Bidirectional relay between the TCP stream and Unix socket
async fn relay_to_socket_handler(
mut stream: tokio::net::TcpStream,
peek_len: usize,
port: u16,
peer_addr: std::net::SocketAddr,
route_match: &rustproxy_routing::RouteMatchResult<'_>,
domain: Option<&str>,
is_tls: bool,
relay_path: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
// Connect to the TypeScript socket handler server
let mut unix_stream = match UnixStream::connect(relay_path).await {
Ok(s) => s,
Err(e) => {
error!("Failed to connect to socket handler relay at {}: {}", relay_path, e);
return Err(e.into());
}
};
// Build metadata JSON
let route_key = route_match.route.name.as_deref()
.or(route_match.route.id.as_deref())
.unwrap_or("unknown");
let metadata = serde_json::json!({
"routeKey": route_key,
"remoteIP": peer_addr.ip().to_string(),
"remotePort": peer_addr.port(),
"localPort": port,
"isTLS": is_tls,
"domain": domain,
});
// Send metadata line (JSON + newline)
let mut metadata_line = serde_json::to_string(&metadata)?;
metadata_line.push('\n');
unix_stream.write_all(metadata_line.as_bytes()).await?;
// Read the initial peeked data from the TCP stream (peek doesn't consume)
let mut initial_buf = vec![0u8; peek_len];
stream.read_exact(&mut initial_buf).await?;
// Forward initial data to the Unix socket
unix_stream.write_all(&initial_buf).await?;
// Bidirectional relay between TCP client and Unix socket handler
match tokio::io::copy_bidirectional(&mut stream, &mut unix_stream).await {
Ok((c2s, s2c)) => {
debug!("Socket handler relay complete for {}: {} bytes in, {} bytes out",
route_key, c2s, s2c);
}
Err(e) => {
debug!("Socket handler relay ended for {}: {}", route_key, e);
}
}
Ok(())
}
/// Handle TLS terminate-and-reencrypt: accept TLS from client, connect TLS to backend.
async fn handle_tls_terminate_reencrypt(
stream: tokio::net::TcpStream,

View File

@@ -74,8 +74,8 @@ pub struct RustProxy {
nft_manager: Option<NftManager>,
started: bool,
started_at: Option<Instant>,
/// Path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
socket_handler_relay_path: Option<String>,
/// Shared path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
}
impl RustProxy {
@@ -111,7 +111,7 @@ impl RustProxy {
nft_manager: None,
started: false,
started_at: None,
socket_handler_relay_path: None,
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
})
}
@@ -259,6 +259,9 @@ impl RustProxy {
);
listener.set_connection_config(conn_config);
// Share the socket-handler relay path with the listener
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
// Extract TLS configurations from routes and cert manager
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
@@ -729,14 +732,16 @@ impl RustProxy {
}
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
/// take effect immediately for all new connections.
pub fn set_socket_handler_relay_path(&mut self, path: Option<String>) {
info!("Socket handler relay path set to: {:?}", path);
self.socket_handler_relay_path = path;
*self.socket_handler_relay.write().unwrap() = path;
}
/// Get the current socket handler relay path.
pub fn get_socket_handler_relay_path(&self) -> Option<&str> {
self.socket_handler_relay_path.as_deref()
pub fn get_socket_handler_relay_path(&self) -> Option<String> {
self.socket_handler_relay.read().unwrap().clone()
}
/// Load a certificate for a domain and hot-swap the TLS configuration.