diff --git a/changelog.md b/changelog.md index 87fbd25..e7ffacf 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,16 @@ # Changelog +## 2026-02-16 - 25.6.0 - feat(rustproxy) +add protocol-based routing and backend TLS re-encryption support + +- Introduce optional route_match.protocol ("http" | "tcp") in Rust and TypeScript route types to allow protocol-restricted routing. +- RouteManager: respect protocol field during matching and treat TLS connections without SNI as not matching domain-restricted routes (except wildcard-only routes). +- HTTP proxy: add BackendStream abstraction to unify plain TCP and tokio-rustls TLS backend streams, and support connecting to upstreams over TLS (upstream.use_tls) with an InsecureBackendVerifier for internal/self-signed backends. +- WebSocket and HTTP forwarding updated to use BackendStream so upstream TLS is handled transparently. +- Passthrough listener: perform post-termination protocol detection for TerminateAndReencrypt; route HTTP flows into HttpProxyService and handle non-HTTP as TLS-to-TLS tunnel. +- Add tests for protocol matching, TLS/no-SNI behavior, and other routing edge cases. +- Add rustls and tokio-rustls dependencies (Cargo.toml/Cargo.lock updates). + ## 2026-02-16 - 25.5.0 - feat(tls) add shared TLS acceptor with SNI resolver and session resumption; prefer shared acceptor and fall back to per-connection when routes specify custom TLS versions diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 1109089..3659db7 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -966,12 +966,14 @@ dependencies = [ "hyper", "hyper-util", "regex", + "rustls", "rustproxy-config", "rustproxy-metrics", "rustproxy-routing", "rustproxy-security", "thiserror 2.0.18", "tokio", + "tokio-rustls", "tokio-util", "tracing", ] diff --git a/rust/crates/rustproxy-config/src/helpers.rs b/rust/crates/rustproxy-config/src/helpers.rs index d10c4b2..e29d417 100644 --- a/rust/crates/rustproxy-config/src/helpers.rs +++ b/rust/crates/rustproxy-config/src/helpers.rs @@ -17,6 +17,7 @@ pub fn create_http_route( client_ip: None, tls_version: None, headers: None, + protocol: None, }, action: RouteAction { action_type: RouteActionType::Forward, @@ -108,6 +109,7 @@ pub fn create_http_to_https_redirect( client_ip: None, tls_version: None, headers: None, + protocol: None, }, action: RouteAction { action_type: RouteActionType::Forward, @@ -200,6 +202,7 @@ pub fn create_load_balancer_route( client_ip: None, tls_version: None, headers: None, + protocol: None, }, action: RouteAction { action_type: RouteActionType::Forward, diff --git a/rust/crates/rustproxy-config/src/route_types.rs b/rust/crates/rustproxy-config/src/route_types.rs index 066e289..9125d30 100644 --- a/rust/crates/rustproxy-config/src/route_types.rs +++ b/rust/crates/rustproxy-config/src/route_types.rs @@ -114,6 +114,10 @@ pub struct RouteMatch { /// Match specific HTTP headers #[serde(skip_serializing_if = "Option::is_none")] pub headers: Option>, + + /// Match specific protocol: "http" (includes h2 + websocket) or "tcp" + #[serde(skip_serializing_if = "Option::is_none")] + pub protocol: Option, } // ─── Target Match ──────────────────────────────────────────────────── diff --git a/rust/crates/rustproxy-http/Cargo.toml b/rust/crates/rustproxy-http/Cargo.toml index 5ce5e54..7179798 100644 --- a/rust/crates/rustproxy-http/Cargo.toml +++ b/rust/crates/rustproxy-http/Cargo.toml @@ -18,6 +18,8 @@ http-body = { workspace = true } http-body-util = { workspace = true } bytes = { workspace = true } tokio = { workspace = true } +rustls = { workspace = true } +tokio-rustls = { workspace = true } tracing = { workspace = true } thiserror = { workspace = true } anyhow = { workspace = true } diff --git a/rust/crates/rustproxy-http/src/proxy_service.rs b/rust/crates/rustproxy-http/src/proxy_service.rs index 3080d0f..5cb6609 100644 --- a/rust/crates/rustproxy-http/src/proxy_service.rs +++ b/rust/crates/rustproxy-http/src/proxy_service.rs @@ -18,6 +18,9 @@ use tokio::net::TcpStream; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; +use std::pin::Pin; +use std::task::{Context, Poll}; + use rustproxy_routing::RouteManager; use rustproxy_metrics::MetricsCollector; @@ -35,6 +38,125 @@ const DEFAULT_WS_INACTIVITY_TIMEOUT: std::time::Duration = std::time::Duration:: /// Default WebSocket max lifetime (24 hours). const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400); +/// Backend stream that can be either plain TCP or TLS-wrapped. +/// Used for `terminate-and-reencrypt` mode where the backend requires TLS. +pub(crate) enum BackendStream { + Plain(TcpStream), + Tls(tokio_rustls::client::TlsStream), +} + +impl tokio::io::AsyncRead for BackendStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + BackendStream::Plain(s) => Pin::new(s).poll_read(cx, buf), + BackendStream::Tls(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl tokio::io::AsyncWrite for BackendStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + BackendStream::Plain(s) => Pin::new(s).poll_write(cx, buf), + BackendStream::Tls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + BackendStream::Plain(s) => Pin::new(s).poll_flush(cx), + BackendStream::Tls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + BackendStream::Plain(s) => Pin::new(s).poll_shutdown(cx), + BackendStream::Tls(s) => Pin::new(s).poll_shutdown(cx), + } + } +} + +/// Connect to a backend over TLS. Uses InsecureVerifier for internal backends +/// with self-signed certs (same pattern as tls_handler::connect_tls). +async fn connect_tls_backend( + host: &str, + port: u16, +) -> Result, Box> { + let _ = rustls::crypto::ring::default_provider().install_default(); + let config = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(InsecureBackendVerifier)) + .with_no_client_auth(); + + let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); + let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; + stream.set_nodelay(true)?; + + let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?; + let tls_stream = connector.connect(server_name, stream).await?; + debug!("Backend TLS connection established to {}:{}", host, port); + Ok(tls_stream) +} + +/// Insecure certificate verifier for backend TLS connections. +/// Internal backends may use self-signed certs. +#[derive(Debug)] +struct InsecureBackendVerifier; + +impl rustls::client::danger::ServerCertVerifier for InsecureBackendVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::RSA_PKCS1_SHA384, + rustls::SignatureScheme::RSA_PKCS1_SHA512, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::ED25519, + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PSS_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA512, + ] + } +} + /// HTTP proxy service that processes HTTP traffic. pub struct HttpProxyService { route_manager: Arc, @@ -173,6 +295,7 @@ impl HttpProxyService { tls_version: None, headers: Some(&headers), is_tls: false, + protocol: Some("http"), }; let route_match = match self.route_manager.find_route(&ctx) { @@ -273,28 +396,51 @@ impl HttpProxyService { } } - // Connect to upstream with timeout - let upstream_stream = match tokio::time::timeout( - self.connect_timeout, - TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), - ).await { - Ok(Ok(s)) => s, - Ok(Err(e)) => { - error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e); - self.upstream_selector.connection_ended(&upstream_key); - self.metrics.connection_closed(route_id, Some(&ip_str)); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); + // Connect to upstream with timeout (TLS if upstream.use_tls is set) + let backend = if upstream.use_tls { + match tokio::time::timeout( + self.connect_timeout, + connect_tls_backend(&upstream.host, upstream.port), + ).await { + Ok(Ok(tls)) => BackendStream::Tls(tls), + Ok(Err(e)) => { + error!("Failed TLS connect to upstream {}:{}: {}", upstream.host, upstream.port, e); + self.upstream_selector.connection_ended(&upstream_key); + self.metrics.connection_closed(route_id, Some(&ip_str)); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable")); + } + Err(_) => { + error!("Upstream TLS connect timeout for {}:{}", upstream.host, upstream.port); + self.upstream_selector.connection_ended(&upstream_key); + self.metrics.connection_closed(route_id, Some(&ip_str)); + return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout")); + } } - Err(_) => { - error!("Upstream connect timeout for {}:{}", upstream.host, upstream.port); - self.upstream_selector.connection_ended(&upstream_key); - self.metrics.connection_closed(route_id, Some(&ip_str)); - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); + } else { + match tokio::time::timeout( + self.connect_timeout, + TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), + ).await { + Ok(Ok(s)) => { + s.set_nodelay(true).ok(); + BackendStream::Plain(s) + } + Ok(Err(e)) => { + error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e); + self.upstream_selector.connection_ended(&upstream_key); + self.metrics.connection_closed(route_id, Some(&ip_str)); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); + } + Err(_) => { + error!("Upstream connect timeout for {}:{}", upstream.host, upstream.port); + self.upstream_selector.connection_ended(&upstream_key); + self.metrics.connection_closed(route_id, Some(&ip_str)); + return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); + } } }; - upstream_stream.set_nodelay(true).ok(); - let io = TokioIo::new(upstream_stream); + let io = TokioIo::new(backend); let result = if use_h2 { // HTTP/2 backend @@ -310,7 +456,7 @@ impl HttpProxyService { /// Forward request to backend via HTTP/1.1 with body streaming. async fn forward_h1( &self, - io: TokioIo, + io: TokioIo, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, @@ -376,7 +522,7 @@ impl HttpProxyService { /// Forward request to backend via HTTP/2 with body streaming. async fn forward_h2( &self, - io: TokioIo, + io: TokioIo, parts: hyper::http::request::Parts, body: Incoming, upstream_headers: hyper::HeaderMap, @@ -516,26 +662,49 @@ impl HttpProxyService { info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port); - // Connect to upstream with timeout - let mut upstream_stream = match tokio::time::timeout( - self.connect_timeout, - TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), - ).await { - Ok(Ok(s)) => s, - Ok(Err(e)) => { - error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e); - self.upstream_selector.connection_ended(upstream_key); - self.metrics.connection_closed(route_id, Some(source_ip)); - return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); + // Connect to upstream with timeout (TLS if upstream.use_tls is set) + let mut upstream_stream: BackendStream = if upstream.use_tls { + match tokio::time::timeout( + self.connect_timeout, + connect_tls_backend(&upstream.host, upstream.port), + ).await { + Ok(Ok(tls)) => BackendStream::Tls(tls), + Ok(Err(e)) => { + error!("WebSocket: failed TLS connect upstream {}:{}: {}", upstream.host, upstream.port, e); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id, Some(source_ip)); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable")); + } + Err(_) => { + error!("WebSocket: upstream TLS connect timeout for {}:{}", upstream.host, upstream.port); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id, Some(source_ip)); + return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout")); + } } - Err(_) => { - error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port); - self.upstream_selector.connection_ended(upstream_key); - self.metrics.connection_closed(route_id, Some(source_ip)); - return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); + } else { + match tokio::time::timeout( + self.connect_timeout, + TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)), + ).await { + Ok(Ok(s)) => { + s.set_nodelay(true).ok(); + BackendStream::Plain(s) + } + Ok(Err(e)) => { + error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id, Some(source_ip)); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); + } + Err(_) => { + error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id, Some(source_ip)); + return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout")); + } } }; - upstream_stream.set_nodelay(true).ok(); let path = req.uri().path().to_string(); let upstream_path = { diff --git a/rust/crates/rustproxy-passthrough/src/tcp_listener.rs b/rust/crates/rustproxy-passthrough/src/tcp_listener.rs index 47dc291..f168a97 100644 --- a/rust/crates/rustproxy-passthrough/src/tcp_listener.rs +++ b/rust/crates/rustproxy-passthrough/src/tcp_listener.rs @@ -446,6 +446,7 @@ impl TcpListenerManager { tls_version: None, headers: None, is_tls: false, + protocol: None, }; if let Some(quick_match) = route_manager.find_route(&quick_ctx) { @@ -650,6 +651,8 @@ impl TcpListenerManager { tls_version: None, headers: None, is_tls, + // For TLS connections, protocol is unknown until after termination + protocol: if is_http { Some("http") } else if !is_tls { Some("tcp") } else { None }, }; let route_match = route_manager.find_route(&ctx); @@ -827,6 +830,15 @@ impl TcpListenerManager { } }; + // Check protocol restriction from route config + if let Some(ref required_protocol) = route_match.route.route_match.protocol { + let detected = if peeked { "http" } else { "tcp" }; + if required_protocol != detected { + debug!("Protocol mismatch: route requires '{}', got '{}'", required_protocol, detected); + return Err("Protocol mismatch".into()); + } + } + if peeked { debug!( "TLS Terminate + HTTP: {} -> {}:{} (domain: {:?})", @@ -867,12 +879,59 @@ impl TcpListenerManager { Ok(()) } Some(rustproxy_config::TlsMode::TerminateAndReencrypt) => { + // Inline TLS accept + HTTP detection (same pattern as Terminate mode) let route_tls = route_match.route.action.tls.as_ref(); - Self::handle_tls_terminate_reencrypt( - stream, n, &domain, &target_host, target_port, - peer_addr, &tls_configs, &shared_tls_acceptor, - Arc::clone(&metrics), route_id, &conn_config, route_tls, - ).await + let acceptor = Self::get_tls_acceptor(&domain, &tls_configs, &*shared_tls_acceptor, route_tls)?; + let tls_stream = match tokio::time::timeout( + std::time::Duration::from_millis(conn_config.initial_data_timeout_ms), + tls_handler::accept_tls(stream, &acceptor), + ).await { + Ok(Ok(s)) => s, + Ok(Err(e)) => return Err(e), + Err(_) => return Err("TLS handshake timeout".into()), + }; + + // Peek at decrypted data to detect protocol + let mut buf_stream = tokio::io::BufReader::new(tls_stream); + let is_http_data = { + use tokio::io::AsyncBufReadExt; + match buf_stream.fill_buf().await { + Ok(data) => sni_parser::is_http(data), + Err(_) => false, + } + }; + + // Check protocol restriction from route config + if let Some(ref required_protocol) = route_match.route.route_match.protocol { + let detected = if is_http_data { "http" } else { "tcp" }; + if required_protocol != detected { + debug!("Protocol mismatch: route requires '{}', got '{}'", required_protocol, detected); + return Err("Protocol mismatch".into()); + } + } + + if is_http_data { + // HTTP: full per-request routing via HttpProxyService + // (backend TLS handled by HttpProxyService when upstream.use_tls is set) + debug!( + "TLS Terminate+Reencrypt + HTTP: {} (domain: {:?})", + peer_addr, domain + ); + _conn_guard.disarm(); + http_proxy.handle_io(buf_stream, peer_addr, port, cancel.clone()).await; + } else { + // Non-HTTP: TLS-to-TLS tunnel (existing behavior for raw TCP protocols) + debug!( + "TLS Terminate+Reencrypt + TCP: {} -> {}:{}", + peer_addr, target_host, target_port + ); + Self::handle_tls_reencrypt_tunnel( + buf_stream, &target_host, target_port, + peer_addr, Arc::clone(&metrics), route_id, + &conn_config, + ).await?; + } + Ok(()) } None => { if is_http { @@ -1007,39 +1066,18 @@ impl TcpListenerManager { Ok(()) } - /// Handle TLS terminate-and-reencrypt: accept TLS from client, connect TLS to backend. - async fn handle_tls_terminate_reencrypt( - stream: tokio::net::TcpStream, - _peek_len: usize, - domain: &Option, + /// Handle non-HTTP TLS-to-TLS tunnel for terminate-and-reencrypt mode. + /// TLS accept has already been done by the caller; this only connects to the + /// backend over TLS and forwards bidirectionally. + async fn handle_tls_reencrypt_tunnel( + buf_stream: tokio::io::BufReader>, target_host: &str, target_port: u16, peer_addr: std::net::SocketAddr, - tls_configs: &HashMap, - shared_tls_acceptor: &Option, metrics: Arc, route_id: Option<&str>, conn_config: &ConnectionConfig, - route_tls: Option<&rustproxy_config::RouteTls>, ) -> Result<(), Box> { - // Use shared acceptor (session resumption) or fall back to per-connection - let acceptor = Self::get_tls_acceptor(domain, tls_configs, shared_tls_acceptor, route_tls)?; - - // Accept TLS from client with timeout - let client_tls = match tokio::time::timeout( - std::time::Duration::from_millis(conn_config.initial_data_timeout_ms), - tls_handler::accept_tls(stream, &acceptor), - ).await { - Ok(Ok(s)) => s, - Ok(Err(e)) => return Err(e), - Err(_) => return Err("TLS handshake timeout".into()), - }; - - debug!( - "TLS Terminate+Reencrypt: {} -> {}:{} (domain: {:?})", - peer_addr, target_host, target_port, domain - ); - // Connect to backend over TLS with timeout let backend_tls = match tokio::time::timeout( std::time::Duration::from_millis(conn_config.connection_timeout_ms), @@ -1050,8 +1088,9 @@ impl TcpListenerManager { Err(_) => return Err("Backend TLS connection timeout".into()), }; - // Forward between two TLS streams - let (client_read, client_write) = tokio::io::split(client_tls); + // Forward between decrypted client stream and backend TLS stream + // (BufReader preserves any already-buffered data from the peek) + let (client_read, client_write) = tokio::io::split(buf_stream); let (backend_read, backend_write) = tokio::io::split(backend_tls); let base_inactivity_ms = conn_config.socket_timeout_ms; diff --git a/rust/crates/rustproxy-routing/src/route_manager.rs b/rust/crates/rustproxy-routing/src/route_manager.rs index 5f67f0e..5b8807c 100644 --- a/rust/crates/rustproxy-routing/src/route_manager.rs +++ b/rust/crates/rustproxy-routing/src/route_manager.rs @@ -12,6 +12,8 @@ pub struct MatchContext<'a> { pub tls_version: Option<&'a str>, pub headers: Option<&'a HashMap>, pub is_tls: bool, + /// Detected protocol: "http" or "tcp". None when unknown (e.g. pre-TLS-termination). + pub protocol: Option<&'a str>, } /// Result of a route match. @@ -87,9 +89,17 @@ impl RouteManager { if !matchers::domain_matches_any(&patterns, domain) { return false; } + } else if ctx.is_tls { + // TLS connection without SNI cannot match a domain-restricted route. + // This prevents session-ticket resumption from misrouting when clients + // omit SNI (RFC 8446 recommends but doesn't mandate SNI on resumption). + // Wildcard-only routes (domains: ["*"]) still match since they accept all. + let patterns = domains.to_vec(); + let is_wildcard_only = patterns.iter().all(|d| *d == "*"); + if !is_wildcard_only { + return false; + } } - // If no domain provided but route requires domain, it depends on context - // For TLS passthrough, we need SNI; for other cases we may still match } // Path matching @@ -137,6 +147,17 @@ impl RouteManager { } } + // Protocol matching + if let Some(ref required_protocol) = rm.protocol { + if let Some(protocol) = ctx.protocol { + if required_protocol != protocol { + return false; + } + } + // If protocol not yet known (None), allow match — protocol will be + // validated after detection (post-TLS-termination peek) + } + true } @@ -277,6 +298,7 @@ mod tests { client_ip: None, tls_version: None, headers: None, + protocol: None, }, action: RouteAction { action_type: RouteActionType::Forward, @@ -327,6 +349,7 @@ mod tests { tls_version: None, headers: None, is_tls: false, + protocol: None, }; let result = manager.find_route(&ctx); @@ -349,6 +372,7 @@ mod tests { tls_version: None, headers: None, is_tls: false, + protocol: None, }; let result = manager.find_route(&ctx).unwrap(); @@ -372,6 +396,7 @@ mod tests { tls_version: None, headers: None, is_tls: false, + protocol: None, }; assert!(manager.find_route(&ctx).is_none()); @@ -457,6 +482,116 @@ mod tests { tls_version: None, headers: None, is_tls: false, + protocol: None, + }; + + assert!(manager.find_route(&ctx).is_some()); + } + + #[test] + fn test_tls_no_sni_rejects_domain_restricted_route() { + let routes = vec![make_route(443, Some("example.com"), 0)]; + let manager = RouteManager::new(routes); + + // TLS connection without SNI should NOT match a domain-restricted route + let ctx = MatchContext { + port: 443, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: true, + protocol: None, + }; + + assert!(manager.find_route(&ctx).is_none()); + } + + #[test] + fn test_tls_no_sni_rejects_wildcard_subdomain_route() { + let routes = vec![make_route(443, Some("*.example.com"), 0)]; + let manager = RouteManager::new(routes); + + // TLS connection without SNI should NOT match *.example.com + let ctx = MatchContext { + port: 443, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: true, + protocol: None, + }; + + assert!(manager.find_route(&ctx).is_none()); + } + + #[test] + fn test_tls_no_sni_matches_wildcard_only_route() { + let routes = vec![make_route(443, Some("*"), 0)]; + let manager = RouteManager::new(routes); + + // TLS connection without SNI SHOULD match a wildcard-only route + let ctx = MatchContext { + port: 443, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: true, + protocol: None, + }; + + assert!(manager.find_route(&ctx).is_some()); + } + + #[test] + fn test_tls_no_sni_skips_domain_restricted_matches_fallback() { + // Two routes: first is domain-restricted, second is wildcard catch-all + let routes = vec![ + make_route(443, Some("specific.com"), 10), + make_route(443, Some("*"), 0), + ]; + let manager = RouteManager::new(routes); + + // TLS without SNI should skip specific.com and fall through to wildcard + let ctx = MatchContext { + port: 443, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: true, + protocol: None, + }; + + let result = manager.find_route(&ctx); + assert!(result.is_some()); + let matched_domains = result.unwrap().route.route_match.domains.as_ref() + .map(|d| d.to_vec()).unwrap(); + assert!(matched_domains.contains(&"*")); + } + + #[test] + fn test_non_tls_no_domain_still_matches_domain_restricted() { + // Non-TLS (plain HTTP) without domain should still match domain-restricted routes + // (the HTTP proxy layer handles Host-based routing) + let routes = vec![make_route(80, Some("example.com"), 0)]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 80, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + protocol: None, }; assert!(manager.find_route(&ctx).is_some()); @@ -475,6 +610,7 @@ mod tests { tls_version: None, headers: None, is_tls: false, + protocol: None, }; assert!(manager.find_route(&ctx).is_some()); @@ -525,6 +661,7 @@ mod tests { tls_version: None, headers: None, is_tls: false, + protocol: None, }; let result = manager.find_route(&ctx).unwrap(); assert_eq!(result.target.unwrap().host.first(), "api-backend"); @@ -538,8 +675,102 @@ mod tests { tls_version: None, headers: None, is_tls: false, + protocol: None, }; let result = manager.find_route(&ctx).unwrap(); assert_eq!(result.target.unwrap().host.first(), "default-backend"); } + + fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig { + let mut route = make_route(port, domain, 0); + route.route_match.protocol = protocol.map(|s| s.to_string()); + route + } + + #[test] + fn test_protocol_http_matches_http() { + let routes = vec![make_route_with_protocol(80, None, Some("http"))]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 80, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + protocol: Some("http"), + }; + assert!(manager.find_route(&ctx).is_some()); + } + + #[test] + fn test_protocol_http_rejects_tcp() { + let routes = vec![make_route_with_protocol(80, None, Some("http"))]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 80, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + protocol: Some("tcp"), + }; + assert!(manager.find_route(&ctx).is_none()); + } + + #[test] + fn test_protocol_none_matches_any() { + // Route with no protocol restriction matches any protocol + let routes = vec![make_route_with_protocol(80, None, None)]; + let manager = RouteManager::new(routes); + + let ctx_http = MatchContext { + port: 80, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + protocol: Some("http"), + }; + assert!(manager.find_route(&ctx_http).is_some()); + + let ctx_tcp = MatchContext { + port: 80, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + protocol: Some("tcp"), + }; + assert!(manager.find_route(&ctx_tcp).is_some()); + } + + #[test] + fn test_protocol_http_matches_when_unknown() { + // Route with protocol: "http" should match when ctx.protocol is None + // (pre-TLS-termination, protocol not yet known) + let routes = vec![make_route_with_protocol(443, None, Some("http"))]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 443, + domain: None, + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: true, + protocol: None, + }; + assert!(manager.find_route(&ctx).is_some()); + } } diff --git a/rust/crates/rustproxy/tests/common/mod.rs b/rust/crates/rustproxy/tests/common/mod.rs index 578e3f0..80842ae 100644 --- a/rust/crates/rustproxy/tests/common/mod.rs +++ b/rust/crates/rustproxy/tests/common/mod.rs @@ -201,6 +201,7 @@ pub fn make_test_route( client_ip: None, tls_version: None, headers: None, + protocol: None, }, action: rustproxy_config::RouteAction { action_type: rustproxy_config::RouteActionType::Forward, diff --git a/ts/00_commitinfo_data.ts b/ts/00_commitinfo_data.ts index 00577a2..4146f9d 100644 --- a/ts/00_commitinfo_data.ts +++ b/ts/00_commitinfo_data.ts @@ -3,6 +3,6 @@ */ export const commitinfo = { name: '@push.rocks/smartproxy', - version: '25.5.0', + version: '25.6.0', description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.' } diff --git a/ts/proxies/smart-proxy/models/route-types.ts b/ts/proxies/smart-proxy/models/route-types.ts index 61ec957..837f29b 100644 --- a/ts/proxies/smart-proxy/models/route-types.ts +++ b/ts/proxies/smart-proxy/models/route-types.ts @@ -39,6 +39,7 @@ export interface IRouteMatch { clientIp?: string[]; // Match specific client IPs tlsVersion?: string[]; // Match specific TLS versions headers?: Record; // Match specific HTTP headers + protocol?: 'http' | 'tcp'; // Match specific protocol (http includes h2 + websocket upgrades) }