// smartproxy/rust/crates/rustproxy-routing/src/route_manager.rs

use std::collections::HashMap;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode};
use crate::matchers;
/// Context for route matching (subset of connection info).
///
/// Every field is a borrow or a plain value, so the context is cheap to copy.
/// `Default` yields port 0 with every optional criterion unset and `is_tls == false`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatchContext<'a> {
    /// Port the connection arrived on; looked up against routes' listening ports.
    pub port: u16,
    /// Requested hostname (e.g. TLS SNI), if known.
    pub domain: Option<&'a str>,
    /// Request path, if known.
    pub path: Option<&'a str>,
    /// Client IP address in textual form, if known.
    pub client_ip: Option<&'a str>,
    /// Negotiated TLS version string, if known.
    pub tls_version: Option<&'a str>,
    /// Request headers, if available.
    pub headers: Option<&'a HashMap<String, String>>,
    /// Whether the connection is TLS. A TLS context with `domain == None`
    /// is treated as "no SNI presented" by the route matcher.
    pub is_tls: bool,
    /// Detected protocol: "http" or "tcp". None when unknown (e.g. pre-TLS-termination).
    pub protocol: Option<&'a str>,
}
/// Result of a route match.
///
/// Holds only borrows into the owning `RouteManager`, so it is freely copyable.
#[derive(Clone, Copy)]
pub struct RouteMatchResult<'a> {
    /// The highest-priority route whose match criteria accepted the context.
    pub route: &'a RouteConfig,
    /// The selected target within `route`; `None` when the route has no targets
    /// or none of its targets' sub-match criteria accept the context.
    pub target: Option<&'a RouteTarget>,
}
/// Priority-ordered route table with a per-port index.
///
/// This is the core routing engine: `routes` holds every enabled route in
/// descending priority order, and `port_index` maps each listening port to the
/// positions in `routes` that listen on it, so a lookup only scans the routes
/// bound to the connection's port.
pub struct RouteManager {
    /// Every enabled route, highest priority first.
    routes: Vec<RouteConfig>,
    /// Port -> indices into `routes`, stored in ascending (i.e. priority) order.
    port_index: HashMap<u16, Vec<usize>>,
}
impl RouteManager {
/// Create a new RouteManager from a list of routes.
pub fn new(routes: Vec<RouteConfig>) -> Self {
let mut manager = Self {
port_index: HashMap::new(),
routes: Vec::new(),
};
// Filter enabled routes and sort by priority
let mut enabled_routes: Vec<RouteConfig> = routes
.into_iter()
.filter(|r| r.is_enabled())
.collect();
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
// Build port index
for (idx, route) in enabled_routes.iter().enumerate() {
for port in route.listening_ports() {
manager.port_index
.entry(port)
.or_default()
.push(idx);
}
}
manager.routes = enabled_routes;
manager
}
/// Find the best matching route for the given context.
pub fn find_route<'a>(&'a self, ctx: &MatchContext<'_>) -> Option<RouteMatchResult<'a>> {
// Get routes for this port
let route_indices = self.port_index.get(&ctx.port)?;
for &idx in route_indices {
let route = &self.routes[idx];
if self.matches_route(route, ctx) {
// Find the best matching target within the route
let target = self.find_target(route, ctx);
return Some(RouteMatchResult { route, target });
}
}
None
}
/// Check if a route matches the given context.
fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool {
let rm = &route.route_match;
// Domain matching
if let Some(ref domains) = rm.domains {
if let Some(domain) = ctx.domain {
let patterns = domains.to_vec();
if !matchers::domain_matches_any(&patterns, domain) {
return false;
}
} else if ctx.is_tls {
// TLS connection without SNI cannot match a domain-restricted route.
// This prevents session-ticket resumption from misrouting when clients
// omit SNI (RFC 8446 recommends but doesn't mandate SNI on resumption).
// Wildcard-only routes (domains: ["*"]) still match since they accept all.
let patterns = domains.to_vec();
let is_wildcard_only = patterns.iter().all(|d| *d == "*");
if !is_wildcard_only {
return false;
}
}
}
// Path matching
if let Some(ref pattern) = rm.path {
if let Some(path) = ctx.path {
if !matchers::path_matches(pattern, path) {
return false;
}
} else {
// Route requires path but none provided
return false;
}
}
// Client IP matching
if let Some(ref client_ips) = rm.client_ip {
if let Some(ip) = ctx.client_ip {
if !matchers::ip_matches_any(client_ips, ip) {
return false;
}
} else {
return false;
}
}
// TLS version matching
if let Some(ref tls_versions) = rm.tls_version {
if let Some(version) = ctx.tls_version {
if !tls_versions.iter().any(|v| v == version) {
return false;
}
} else {
return false;
}
}
// Header matching
if let Some(ref patterns) = rm.headers {
if let Some(headers) = ctx.headers {
if !matchers::headers_match(patterns, headers) {
return false;
}
} else {
return false;
}
}
// Protocol matching
if let Some(ref required_protocol) = rm.protocol {
if let Some(protocol) = ctx.protocol {
if required_protocol != protocol {
return false;
}
}
// If protocol not yet known (None), allow match — protocol will be
// validated after detection (post-TLS-termination peek)
}
true
}
/// Find the best matching target within a route.
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
let targets = route.action.targets.as_ref()?;
if targets.len() == 1 && targets[0].target_match.is_none() {
return Some(&targets[0]);
}
// Sort candidates by priority (already in order from config)
let mut best: Option<&RouteTarget> = None;
let mut best_priority = i32::MIN;
for target in targets {
let priority = target.priority.unwrap_or(0);
if let Some(ref tm) = target.target_match {
if !self.matches_target(tm, ctx) {
continue;
}
}
if priority > best_priority || best.is_none() {
best = Some(target);
best_priority = priority;
}
}
// Fall back to first target without match criteria
best.or_else(|| {
targets.iter().find(|t| t.target_match.is_none())
})
}
/// Check if a target match criteria matches the context.
fn matches_target(
&self,
tm: &rustproxy_config::TargetMatch,
ctx: &MatchContext<'_>,
) -> bool {
// Port matching
if let Some(ref ports) = tm.ports {
if !ports.contains(&ctx.port) {
return false;
}
}
// Path matching
if let Some(ref pattern) = tm.path {
if let Some(path) = ctx.path {
if !matchers::path_matches(pattern, path) {
return false;
}
} else {
return false;
}
}
// Header matching
if let Some(ref patterns) = tm.headers {
if let Some(headers) = ctx.headers {
if !matchers::headers_match(patterns, headers) {
return false;
}
} else {
return false;
}
}
true
}
/// Get all unique listening ports.
pub fn listening_ports(&self) -> Vec<u16> {
let mut ports: Vec<u16> = self.port_index.keys().copied().collect();
ports.sort();
ports
}
/// Get all routes for a specific port.
pub fn routes_for_port(&self, port: u16) -> Vec<&RouteConfig> {
self.port_index
.get(&port)
.map(|indices| indices.iter().map(|&i| &self.routes[i]).collect())
.unwrap_or_default()
}
/// Get the total number of enabled routes.
pub fn route_count(&self) -> usize {
self.routes.len()
}
/// Check if any route on the given port requires SNI.
pub fn port_requires_sni(&self, port: u16) -> bool {
let routes = self.routes_for_port(port);
// If multiple passthrough routes on same port, SNI is needed
let passthrough_routes: Vec<_> = routes
.iter()
.filter(|r| {
r.tls_mode() == Some(&TlsMode::Passthrough)
})
.collect();
if passthrough_routes.len() > 1 {
return true;
}
// Single passthrough route with specific domain restriction needs SNI
if let Some(route) = passthrough_routes.first() {
if let Some(ref domains) = route.route_match.domains {
let domain_list = domains.to_vec();
// If it's not just a wildcard, SNI is needed
if !domain_list.iter().all(|d| *d == "*") {
return true;
}
}
}
false
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustproxy_config::*;

    /// Build an enabled forward route on `port`, optionally restricted to a single
    /// `domain`, with one unconditional target (`localhost`, fixed port 8080).
    fn make_route(port: u16, domain: Option<&str>, priority: i32) -> RouteConfig {
        RouteConfig {
            id: None,
            route_match: RouteMatch {
                ports: PortRange::Single(port),
                domains: domain.map(|d| DomainSpec::Single(d.to_string())),
                path: None,
                client_ip: None,
                tls_version: None,
                headers: None,
                protocol: None,
            },
            action: RouteAction {
                action_type: RouteActionType::Forward,
                targets: Some(vec![make_target("localhost", PortSpec::Fixed(8080), None, None)]),
                tls: None,
                websocket: None,
                load_balancing: None,
                advanced: None,
                options: None,
                forwarding_engine: None,
                nftables: None,
                send_proxy_protocol: None,
            },
            headers: None,
            security: None,
            name: None,
            description: None,
            priority: Some(priority),
            tags: None,
            enabled: None,
        }
    }

    /// Build a target for `host`/`port` with optional sub-match criteria and priority.
    fn make_target(
        host: &str,
        port: PortSpec,
        target_match: Option<TargetMatch>,
        priority: Option<i32>,
    ) -> RouteTarget {
        RouteTarget {
            target_match,
            host: HostSpec::Single(host.to_string()),
            port,
            tls: None,
            websocket: None,
            load_balancing: None,
            send_proxy_protocol: None,
            headers: None,
            advanced: None,
            priority,
        }
    }

    /// Like `make_route`, but with a protocol restriction.
    fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig {
        let mut route = make_route(port, domain, 0);
        route.route_match.protocol = protocol.map(|s| s.to_string());
        route
    }

    /// TLS settings selecting passthrough mode with every other option unset.
    fn passthrough_tls() -> RouteTls {
        RouteTls {
            mode: TlsMode::Passthrough,
            certificate: None,
            acme: None,
            versions: None,
            ciphers: None,
            honor_cipher_order: None,
            session_timeout: None,
        }
    }

    /// Target sub-match criteria constraining only the request path.
    fn path_match(pattern: &str) -> TargetMatch {
        TargetMatch {
            ports: None,
            path: Some(pattern.to_string()),
            headers: None,
            method: None,
        }
    }

    /// A non-TLS context with only `port` and `domain` set; tweak the rest via
    /// struct-update syntax (`MatchContext { is_tls: true, ..match_ctx(..) }`).
    fn match_ctx<'a>(port: u16, domain: Option<&'a str>) -> MatchContext<'a> {
        MatchContext {
            port,
            domain,
            path: None,
            client_ip: None,
            tls_version: None,
            headers: None,
            is_tls: false,
            protocol: None,
        }
    }

    #[test]
    fn test_basic_routing() {
        let manager = RouteManager::new(vec![
            make_route(80, Some("example.com"), 0),
            make_route(80, Some("other.com"), 0),
        ]);
        let result = manager
            .find_route(&match_ctx(80, Some("example.com")))
            .expect("example.com should route");
        let domains = result.route.route_match.domains.as_ref()
            .map(|d| d.to_vec())
            .unwrap();
        assert!(domains.contains(&"example.com"));
    }

    #[test]
    fn test_priority_ordering() {
        let manager = RouteManager::new(vec![
            make_route(80, Some("*.example.com"), 0),
            make_route(80, Some("api.example.com"), 10), // Higher priority
        ]);
        let result = manager.find_route(&match_ctx(80, Some("api.example.com"))).unwrap();
        // Should match the higher-priority specific route
        assert!(result.route.route_match.domains.as_ref()
            .map(|d| d.to_vec())
            .unwrap()
            .contains(&"api.example.com"));
    }

    #[test]
    fn test_no_match() {
        let manager = RouteManager::new(vec![make_route(80, Some("example.com"), 0)]);
        // Different port
        assert!(manager.find_route(&match_ctx(443, Some("example.com"))).is_none());
    }

    #[test]
    fn test_disabled_routes_excluded() {
        let mut route = make_route(80, Some("example.com"), 0);
        route.enabled = Some(false);
        let manager = RouteManager::new(vec![route]);
        assert_eq!(manager.route_count(), 0);
    }

    #[test]
    fn test_listening_ports() {
        let manager = RouteManager::new(vec![
            make_route(80, Some("a.com"), 0),
            make_route(443, Some("b.com"), 0),
            make_route(80, Some("c.com"), 0), // duplicate port
        ]);
        assert_eq!(manager.listening_ports(), vec![80, 443]);
    }

    #[test]
    fn test_port_requires_sni_single_passthrough() {
        let mut route = make_route(443, Some("example.com"), 0);
        route.action.tls = Some(passthrough_tls());
        let manager = RouteManager::new(vec![route]);
        // Single passthrough route with specific domain needs SNI
        assert!(manager.port_requires_sni(443));
    }

    #[test]
    fn test_port_requires_sni_wildcard_only() {
        let mut route = make_route(443, Some("*"), 0);
        route.action.tls = Some(passthrough_tls());
        let manager = RouteManager::new(vec![route]);
        // Single passthrough route with wildcard doesn't need SNI
        assert!(!manager.port_requires_sni(443));
    }

    #[test]
    fn test_port_requires_sni_multiple_passthrough() {
        let mut a = make_route(443, Some("a.com"), 0);
        a.action.tls = Some(passthrough_tls());
        let mut b = make_route(443, Some("b.com"), 0);
        b.action.tls = Some(passthrough_tls());
        let manager = RouteManager::new(vec![a, b]);
        // Two passthrough routes can only be distinguished by SNI
        assert!(manager.port_requires_sni(443));
    }

    #[test]
    fn test_port_requires_sni_without_passthrough() {
        let manager = RouteManager::new(vec![make_route(443, Some("example.com"), 0)]);
        // No passthrough routes on the port, and an unknown port, never require SNI
        assert!(!manager.port_requires_sni(443));
        assert!(!manager.port_requires_sni(8443));
    }

    #[test]
    fn test_routes_for_port() {
        let manager = RouteManager::new(vec![
            make_route(80, Some("a.com"), 0),
            make_route(80, Some("b.com"), 0),
            make_route(443, Some("c.com"), 0),
        ]);
        assert_eq!(manager.routes_for_port(80).len(), 2);
        assert_eq!(manager.routes_for_port(443).len(), 1);
        assert_eq!(manager.routes_for_port(8080).len(), 0);
    }

    #[test]
    fn test_wildcard_domain_matches_any() {
        let manager = RouteManager::new(vec![make_route(80, Some("*"), 0)]);
        assert!(manager.find_route(&match_ctx(80, Some("anything.example.com"))).is_some());
    }

    #[test]
    fn test_tls_no_sni_rejects_domain_restricted_route() {
        let manager = RouteManager::new(vec![make_route(443, Some("example.com"), 0)]);
        // TLS connection without SNI should NOT match a domain-restricted route
        let ctx = MatchContext { is_tls: true, ..match_ctx(443, None) };
        assert!(manager.find_route(&ctx).is_none());
    }

    #[test]
    fn test_tls_no_sni_rejects_wildcard_subdomain_route() {
        let manager = RouteManager::new(vec![make_route(443, Some("*.example.com"), 0)]);
        // TLS connection without SNI should NOT match *.example.com
        let ctx = MatchContext { is_tls: true, ..match_ctx(443, None) };
        assert!(manager.find_route(&ctx).is_none());
    }

    #[test]
    fn test_tls_no_sni_matches_wildcard_only_route() {
        let manager = RouteManager::new(vec![make_route(443, Some("*"), 0)]);
        // TLS connection without SNI SHOULD match a wildcard-only route
        let ctx = MatchContext { is_tls: true, ..match_ctx(443, None) };
        assert!(manager.find_route(&ctx).is_some());
    }

    #[test]
    fn test_tls_no_sni_skips_domain_restricted_matches_fallback() {
        // Two routes: first is domain-restricted, second is wildcard catch-all
        let manager = RouteManager::new(vec![
            make_route(443, Some("specific.com"), 10),
            make_route(443, Some("*"), 0),
        ]);
        // TLS without SNI should skip specific.com and fall through to wildcard
        let ctx = MatchContext { is_tls: true, ..match_ctx(443, None) };
        let result = manager.find_route(&ctx);
        assert!(result.is_some());
        let matched_domains = result.unwrap().route.route_match.domains.as_ref()
            .map(|d| d.to_vec()).unwrap();
        assert!(matched_domains.contains(&"*"));
    }

    #[test]
    fn test_non_tls_no_domain_still_matches_domain_restricted() {
        // Non-TLS (plain HTTP) without domain should still match domain-restricted routes
        // (the HTTP proxy layer handles Host-based routing)
        let manager = RouteManager::new(vec![make_route(80, Some("example.com"), 0)]);
        assert!(manager.find_route(&match_ctx(80, None)).is_some());
    }

    #[test]
    fn test_no_domain_route_matches_any_domain() {
        let manager = RouteManager::new(vec![make_route(80, None, 0)]);
        assert!(manager.find_route(&match_ctx(80, Some("example.com"))).is_some());
    }

    #[test]
    fn test_route_path_required_but_missing() {
        let mut route = make_route(80, None, 0);
        route.route_match.path = Some("/api/*".to_string());
        let manager = RouteManager::new(vec![route]);
        // A path-restricted route cannot match when no path is known
        assert!(manager.find_route(&match_ctx(80, None)).is_none());
        let with_path = MatchContext { path: Some("/api/users"), ..match_ctx(80, None) };
        assert!(manager.find_route(&with_path).is_some());
    }

    #[test]
    fn test_target_sub_matching() {
        let mut route = make_route(80, Some("example.com"), 0);
        route.action.targets = Some(vec![
            make_target("api-backend", PortSpec::Fixed(3000), Some(path_match("/api/*")), Some(10)),
            make_target("default-backend", PortSpec::Fixed(8080), None, None),
        ]);
        let manager = RouteManager::new(vec![route]);
        // Should match the API target
        let api = MatchContext { path: Some("/api/users"), ..match_ctx(80, Some("example.com")) };
        let result = manager.find_route(&api).unwrap();
        assert_eq!(result.target.unwrap().host.first(), "api-backend");
        // Should fall back to default target
        let home = MatchContext { path: Some("/home"), ..match_ctx(80, Some("example.com")) };
        let result = manager.find_route(&home).unwrap();
        assert_eq!(result.target.unwrap().host.first(), "default-backend");
    }

    #[test]
    fn test_target_priority_tie_prefers_first() {
        let mut route = make_route(80, None, 0);
        route.action.targets = Some(vec![
            make_target("first", PortSpec::Fixed(3000), Some(path_match("/api/*")), Some(5)),
            make_target("second", PortSpec::Fixed(3001), Some(path_match("/api/*")), Some(5)),
        ]);
        let manager = RouteManager::new(vec![route]);
        let ctx = MatchContext { path: Some("/api/x"), ..match_ctx(80, None) };
        let result = manager.find_route(&ctx).unwrap();
        // Equal priorities: the earliest configured target wins
        assert_eq!(result.target.unwrap().host.first(), "first");
    }

    #[test]
    fn test_target_no_matching_target_returns_none() {
        let mut route = make_route(80, None, 0);
        route.action.targets = Some(vec![
            make_target("api-backend", PortSpec::Fixed(3000), Some(path_match("/api/*")), None),
        ]);
        let manager = RouteManager::new(vec![route]);
        let ctx = MatchContext { path: Some("/home"), ..match_ctx(80, None) };
        // The route itself matches, but no target accepts the request
        let result = manager.find_route(&ctx).unwrap();
        assert!(result.target.is_none());
    }

    #[test]
    fn test_protocol_http_matches_http() {
        let manager = RouteManager::new(vec![make_route_with_protocol(80, None, Some("http"))]);
        let ctx = MatchContext { protocol: Some("http"), ..match_ctx(80, None) };
        assert!(manager.find_route(&ctx).is_some());
    }

    #[test]
    fn test_protocol_http_rejects_tcp() {
        let manager = RouteManager::new(vec![make_route_with_protocol(80, None, Some("http"))]);
        let ctx = MatchContext { protocol: Some("tcp"), ..match_ctx(80, None) };
        assert!(manager.find_route(&ctx).is_none());
    }

    #[test]
    fn test_protocol_none_matches_any() {
        // Route with no protocol restriction matches any protocol
        let manager = RouteManager::new(vec![make_route_with_protocol(80, None, None)]);
        let ctx_http = MatchContext { protocol: Some("http"), ..match_ctx(80, None) };
        assert!(manager.find_route(&ctx_http).is_some());
        let ctx_tcp = MatchContext { protocol: Some("tcp"), ..match_ctx(80, None) };
        assert!(manager.find_route(&ctx_tcp).is_some());
    }

    #[test]
    fn test_protocol_http_matches_when_unknown() {
        // Route with protocol: "http" should match when ctx.protocol is None
        // (pre-TLS-termination, protocol not yet known)
        let manager = RouteManager::new(vec![make_route_with_protocol(443, None, Some("http"))]);
        let ctx = MatchContext { is_tls: true, ..match_ctx(443, None) };
        assert!(manager.find_route(&ctx).is_some());
    }
}