feat(smart-proxy): add socket-handler relay, fast-path port-only forwarding, metrics and bridge improvements, and various TS/Rust integration fixes
This commit is contained in:
@@ -146,6 +146,41 @@ pub fn is_tls(data: &[u8]) -> bool {
|
||||
data.len() >= 3 && data[0] == 0x16 && data[1] == 0x03
|
||||
}
|
||||
|
||||
/// Extract the HTTP request path from initial data.
|
||||
/// E.g., from "GET /foo/bar HTTP/1.1\r\n..." returns Some("/foo/bar").
|
||||
pub fn extract_http_path(data: &[u8]) -> Option<String> {
|
||||
let text = std::str::from_utf8(data).ok()?;
|
||||
// Find first space (after method)
|
||||
let method_end = text.find(' ')?;
|
||||
let rest = &text[method_end + 1..];
|
||||
// Find end of path (next space before "HTTP/...")
|
||||
let path_end = rest.find(' ').unwrap_or(rest.len());
|
||||
let path = &rest[..path_end];
|
||||
// Strip query string for path matching
|
||||
let path = path.split('?').next().unwrap_or(path);
|
||||
if path.starts_with('/') {
|
||||
Some(path.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the HTTP Host header from initial data.
|
||||
/// E.g., from "GET / HTTP/1.1\r\nHost: example.com\r\n..." returns Some("example.com").
|
||||
pub fn extract_http_host(data: &[u8]) -> Option<String> {
|
||||
let text = std::str::from_utf8(data).ok()?;
|
||||
for line in text.split("\r\n") {
|
||||
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
|
||||
// Strip port if present
|
||||
let host = value.split(':').next().unwrap_or(value).trim();
|
||||
if !host.is_empty() {
|
||||
return Some(host.to_lowercase());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if the initial bytes look like HTTP.
|
||||
pub fn is_http(data: &[u8]) -> bool {
|
||||
if data.len() < 4 {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use arc_swap::ArcSwap;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, error, debug, warn};
|
||||
use thiserror::Error;
|
||||
|
||||
use rustproxy_config::RouteActionType;
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_http::HttpProxyService;
|
||||
@@ -82,8 +84,8 @@ impl Default for ConnectionConfig {
|
||||
pub struct TcpListenerManager {
|
||||
/// Active listeners indexed by port
|
||||
listeners: HashMap<u16, tokio::task::JoinHandle<()>>,
|
||||
/// Shared route manager
|
||||
route_manager: Arc<RouteManager>,
|
||||
/// Shared route manager (ArcSwap for hot-reload visibility in accept loops)
|
||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
||||
/// Shared metrics collector
|
||||
metrics: Arc<MetricsCollector>,
|
||||
/// TLS acceptors indexed by domain
|
||||
@@ -96,6 +98,8 @@ pub struct TcpListenerManager {
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
/// Cancellation token for graceful shutdown
|
||||
cancel_token: CancellationToken,
|
||||
/// Path to Unix domain socket for relaying socket-handler connections to TypeScript.
|
||||
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl TcpListenerManager {
|
||||
@@ -112,13 +116,14 @@ impl TcpListenerManager {
|
||||
));
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager,
|
||||
route_manager: Arc::new(ArcSwap::from(route_manager)),
|
||||
metrics,
|
||||
tls_configs: Arc::new(HashMap::new()),
|
||||
http_proxy,
|
||||
conn_config: Arc::new(conn_config),
|
||||
conn_tracker,
|
||||
cancel_token: CancellationToken::new(),
|
||||
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,13 +140,14 @@ impl TcpListenerManager {
|
||||
));
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager,
|
||||
route_manager: Arc::new(ArcSwap::from(route_manager)),
|
||||
metrics,
|
||||
tls_configs: Arc::new(HashMap::new()),
|
||||
http_proxy,
|
||||
conn_config: Arc::new(conn_config),
|
||||
conn_tracker,
|
||||
cancel_token: CancellationToken::new(),
|
||||
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,6 +165,12 @@ impl TcpListenerManager {
|
||||
self.tls_configs = Arc::new(configs);
|
||||
}
|
||||
|
||||
/// Set the shared socket-handler relay path.
|
||||
/// This allows RustProxy to share the relay path Arc with the listener.
|
||||
pub fn set_socket_handler_relay(&mut self, relay: Arc<std::sync::RwLock<Option<String>>>) {
|
||||
self.socket_handler_relay = relay;
|
||||
}
|
||||
|
||||
/// Start listening on a port.
|
||||
pub async fn add_port(&mut self, port: u16) -> Result<(), ListenerError> {
|
||||
if self.listeners.contains_key(&port) {
|
||||
@@ -172,18 +184,19 @@ impl TcpListenerManager {
|
||||
|
||||
info!("Listening on port {}", port);
|
||||
|
||||
let route_manager = Arc::clone(&self.route_manager);
|
||||
let route_manager_swap = Arc::clone(&self.route_manager);
|
||||
let metrics = Arc::clone(&self.metrics);
|
||||
let tls_configs = Arc::clone(&self.tls_configs);
|
||||
let http_proxy = Arc::clone(&self.http_proxy);
|
||||
let conn_config = Arc::clone(&self.conn_config);
|
||||
let conn_tracker = Arc::clone(&self.conn_tracker);
|
||||
let cancel = self.cancel_token.clone();
|
||||
let relay = Arc::clone(&self.socket_handler_relay);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
Self::accept_loop(
|
||||
listener, port, route_manager, metrics, tls_configs,
|
||||
http_proxy, conn_config, conn_tracker, cancel,
|
||||
listener, port, route_manager_swap, metrics, tls_configs,
|
||||
http_proxy, conn_config, conn_tracker, cancel, relay,
|
||||
).await;
|
||||
});
|
||||
|
||||
@@ -255,8 +268,9 @@ impl TcpListenerManager {
|
||||
}
|
||||
|
||||
/// Update the route manager (for hot-reload).
|
||||
/// Uses ArcSwap so running accept loops immediately see the new routes.
|
||||
pub fn update_route_manager(&mut self, route_manager: Arc<RouteManager>) {
|
||||
self.route_manager = route_manager;
|
||||
self.route_manager.store(route_manager);
|
||||
}
|
||||
|
||||
/// Get a reference to the metrics collector.
|
||||
@@ -268,13 +282,14 @@ impl TcpListenerManager {
|
||||
async fn accept_loop(
|
||||
listener: TcpListener,
|
||||
port: u16,
|
||||
route_manager: Arc<RouteManager>,
|
||||
route_manager_swap: Arc<ArcSwap<RouteManager>>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
|
||||
http_proxy: Arc<HttpProxyService>,
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
cancel: CancellationToken,
|
||||
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -296,18 +311,20 @@ impl TcpListenerManager {
|
||||
|
||||
conn_tracker.connection_opened(&ip);
|
||||
|
||||
let rm = Arc::clone(&route_manager);
|
||||
// Load the latest route manager from ArcSwap on each connection
|
||||
let rm = route_manager_swap.load_full();
|
||||
let m = Arc::clone(&metrics);
|
||||
let tc = Arc::clone(&tls_configs);
|
||||
let hp = Arc::clone(&http_proxy);
|
||||
let cc = Arc::clone(&conn_config);
|
||||
let ct = Arc::clone(&conn_tracker);
|
||||
let cn = cancel.clone();
|
||||
let sr = Arc::clone(&socket_handler_relay);
|
||||
debug!("Accepted connection from {} on port {}", peer_addr, port);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let result = Self::handle_connection(
|
||||
stream, port, peer_addr, rm, m, tc, hp, cc, cn,
|
||||
stream, port, peer_addr, rm, m, tc, hp, cc, cn, sr,
|
||||
).await;
|
||||
if let Err(e) = result {
|
||||
debug!("Connection error from {}: {}", peer_addr, e);
|
||||
@@ -336,11 +353,114 @@ impl TcpListenerManager {
|
||||
http_proxy: Arc<HttpProxyService>,
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
cancel: CancellationToken,
|
||||
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
stream.set_nodelay(true)?;
|
||||
|
||||
// === Fast path: try port-only matching before peeking at data ===
|
||||
// This handles "server-speaks-first" protocols where the client
|
||||
// doesn't send initial data (e.g., SMTP, greeting-based protocols).
|
||||
// If a route matches by port alone and doesn't need domain/path/TLS info,
|
||||
// we can forward immediately without waiting for client data.
|
||||
{
|
||||
let quick_ctx = rustproxy_routing::MatchContext {
|
||||
port,
|
||||
domain: None,
|
||||
path: None,
|
||||
client_ip: Some(&peer_addr.ip().to_string()),
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
if let Some(quick_match) = route_manager.find_route(&quick_ctx) {
|
||||
let rm = &quick_match.route.route_match;
|
||||
let has_no_domain = rm.domains.is_none();
|
||||
let has_no_path = rm.path.is_none();
|
||||
let is_forward = quick_match.route.action.action_type == RouteActionType::Forward;
|
||||
let has_no_tls = quick_match.route.action.tls.is_none();
|
||||
|
||||
// Only use fast path for simple port-only forward routes with no TLS
|
||||
if has_no_domain && has_no_path && is_forward && has_no_tls {
|
||||
if let Some(target) = quick_match.target {
|
||||
let target_host = target.host.first().to_string();
|
||||
let target_port = target.port.resolve(port);
|
||||
let route_id = quick_match.route.id.as_deref();
|
||||
|
||||
// Check route-level IP security
|
||||
if let Some(ref security) = quick_match.route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security, &peer_addr.ip(),
|
||||
) {
|
||||
debug!("Connection from {} blocked by route security", peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
metrics.connection_opened(route_id);
|
||||
|
||||
let connect_timeout = std::time::Duration::from_millis(conn_config.connection_timeout_ms);
|
||||
let inactivity_timeout = std::time::Duration::from_millis(conn_config.socket_timeout_ms);
|
||||
let max_lifetime = std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms);
|
||||
|
||||
debug!(
|
||||
"Fast-path forward (no peek): {} -> {}:{}",
|
||||
peer_addr, target_host, target_port
|
||||
);
|
||||
|
||||
let backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => {
|
||||
metrics.connection_closed(route_id);
|
||||
return Err(e.into());
|
||||
}
|
||||
Err(_) => {
|
||||
metrics.connection_closed(route_id);
|
||||
return Err("Backend connection timeout".into());
|
||||
}
|
||||
};
|
||||
backend.set_nodelay(true)?;
|
||||
|
||||
// Send PROXY protocol header if configured
|
||||
let should_send_proxy = conn_config.send_proxy_protocol
|
||||
|| quick_match.route.action.send_proxy_protocol.unwrap_or(false)
|
||||
|| target.send_proxy_protocol.unwrap_or(false);
|
||||
if should_send_proxy {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
let dest = std::net::SocketAddr::new(
|
||||
target_host.parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)),
|
||||
target_port,
|
||||
);
|
||||
let header = crate::proxy_protocol::generate_v1(&peer_addr, &dest);
|
||||
let mut backend_w = backend;
|
||||
backend_w.write_all(header.as_bytes()).await?;
|
||||
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend_w, None,
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
} else {
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, None,
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
}
|
||||
|
||||
metrics.connection_closed(route_id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// === End fast path ===
|
||||
|
||||
// Handle PROXY protocol if configured
|
||||
let mut effective_peer_addr = peer_addr;
|
||||
if conn_config.accept_proxy_protocol {
|
||||
@@ -412,11 +532,24 @@ impl TcpListenerManager {
|
||||
None
|
||||
};
|
||||
|
||||
// Extract HTTP path and host from initial data for route matching
|
||||
let http_path = if is_http {
|
||||
sni_parser::extract_http_path(initial_data)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let http_host = if is_http && domain.is_none() {
|
||||
sni_parser::extract_http_host(initial_data)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let effective_domain = domain.as_deref().or(http_host.as_deref());
|
||||
|
||||
// Match route
|
||||
let ctx = rustproxy_routing::MatchContext {
|
||||
port,
|
||||
domain: domain.as_deref(),
|
||||
path: None,
|
||||
domain: effective_domain,
|
||||
path: http_path.as_deref(),
|
||||
client_ip: Some(&peer_addr.ip().to_string()),
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
@@ -449,6 +582,28 @@ impl TcpListenerManager {
|
||||
// Track connection in metrics
|
||||
metrics.connection_opened(route_id);
|
||||
|
||||
// Check if this is a socket-handler route that should be relayed to TypeScript
|
||||
if route_match.route.action.action_type == RouteActionType::SocketHandler {
|
||||
let relay_path = {
|
||||
let guard = socket_handler_relay.read().unwrap();
|
||||
guard.clone()
|
||||
};
|
||||
|
||||
if let Some(relay_socket_path) = relay_path {
|
||||
let result = Self::relay_to_socket_handler(
|
||||
stream, n, port, peer_addr,
|
||||
&route_match, domain.as_deref(), is_tls,
|
||||
&relay_socket_path,
|
||||
).await;
|
||||
metrics.connection_closed(route_id);
|
||||
return result;
|
||||
} else {
|
||||
debug!("Socket-handler route matched but no relay path configured");
|
||||
metrics.connection_closed(route_id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let target = match route_match.target {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
@@ -654,6 +809,75 @@ impl TcpListenerManager {
|
||||
result
|
||||
}
|
||||
|
||||
/// Relay a connection to the TypeScript socket-handler via Unix domain socket.
|
||||
///
|
||||
/// Protocol:
|
||||
/// 1. Connect to the Unix socket at `relay_path`
|
||||
/// 2. Send a JSON metadata line (terminated by \n)
|
||||
/// 3. Forward the initial peeked bytes
|
||||
/// 4. Bidirectional relay between the TCP stream and Unix socket
|
||||
async fn relay_to_socket_handler(
|
||||
mut stream: tokio::net::TcpStream,
|
||||
peek_len: usize,
|
||||
port: u16,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
route_match: &rustproxy_routing::RouteMatchResult<'_>,
|
||||
domain: Option<&str>,
|
||||
is_tls: bool,
|
||||
relay_path: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
// Connect to the TypeScript socket handler server
|
||||
let mut unix_stream = match UnixStream::connect(relay_path).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to connect to socket handler relay at {}: {}", relay_path, e);
|
||||
return Err(e.into());
|
||||
}
|
||||
};
|
||||
|
||||
// Build metadata JSON
|
||||
let route_key = route_match.route.name.as_deref()
|
||||
.or(route_match.route.id.as_deref())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let metadata = serde_json::json!({
|
||||
"routeKey": route_key,
|
||||
"remoteIP": peer_addr.ip().to_string(),
|
||||
"remotePort": peer_addr.port(),
|
||||
"localPort": port,
|
||||
"isTLS": is_tls,
|
||||
"domain": domain,
|
||||
});
|
||||
|
||||
// Send metadata line (JSON + newline)
|
||||
let mut metadata_line = serde_json::to_string(&metadata)?;
|
||||
metadata_line.push('\n');
|
||||
unix_stream.write_all(metadata_line.as_bytes()).await?;
|
||||
|
||||
// Read the initial peeked data from the TCP stream (peek doesn't consume)
|
||||
let mut initial_buf = vec![0u8; peek_len];
|
||||
stream.read_exact(&mut initial_buf).await?;
|
||||
|
||||
// Forward initial data to the Unix socket
|
||||
unix_stream.write_all(&initial_buf).await?;
|
||||
|
||||
// Bidirectional relay between TCP client and Unix socket handler
|
||||
match tokio::io::copy_bidirectional(&mut stream, &mut unix_stream).await {
|
||||
Ok((c2s, s2c)) => {
|
||||
debug!("Socket handler relay complete for {}: {} bytes in, {} bytes out",
|
||||
route_key, c2s, s2c);
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Socket handler relay ended for {}: {}", route_key, e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle TLS terminate-and-reencrypt: accept TLS from client, connect TLS to backend.
|
||||
async fn handle_tls_terminate_reencrypt(
|
||||
stream: tokio::net::TcpStream,
|
||||
|
||||
Reference in New Issue
Block a user