use std::collections::HashMap; use rustproxy_config::{RouteConfig, RouteTarget, TlsMode}; use crate::matchers; /// Context for route matching (subset of connection info). pub struct MatchContext<'a> { pub port: u16, pub domain: Option<&'a str>, pub path: Option<&'a str>, pub client_ip: Option<&'a str>, pub tls_version: Option<&'a str>, pub headers: Option<&'a HashMap>, pub is_tls: bool, /// Detected protocol: "http" or "tcp". None when unknown (e.g. pre-TLS-termination). pub protocol: Option<&'a str>, } /// Result of a route match. pub struct RouteMatchResult<'a> { pub route: &'a RouteConfig, pub target: Option<&'a RouteTarget>, } /// Port-indexed route lookup with priority-based matching. /// This is the core routing engine. pub struct RouteManager { /// Routes indexed by port for O(1) port lookup. port_index: HashMap>, /// All routes, sorted by priority (highest first). routes: Vec, } impl RouteManager { /// Create a new RouteManager from a list of routes. pub fn new(routes: Vec) -> Self { let mut manager = Self { port_index: HashMap::new(), routes: Vec::new(), }; // Filter enabled routes and sort by priority let mut enabled_routes: Vec = routes .into_iter() .filter(|r| r.is_enabled()) .collect(); enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority())); // Build port index for (idx, route) in enabled_routes.iter().enumerate() { for port in route.listening_ports() { manager.port_index .entry(port) .or_default() .push(idx); } } manager.routes = enabled_routes; manager } /// Find the best matching route for the given context. pub fn find_route<'a>(&'a self, ctx: &MatchContext<'_>) -> Option> { // Get routes for this port let route_indices = self.port_index.get(&ctx.port)?; for &idx in route_indices { let route = &self.routes[idx]; if self.matches_route(route, ctx) { // Find the best matching target within the route let target = self.find_target(route, ctx); return Some(RouteMatchResult { route, target }); } } None } /// Check if a route matches the given context. fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool { let rm = &route.route_match; // Domain matching if let Some(ref domains) = rm.domains { if let Some(domain) = ctx.domain { let patterns = domains.to_vec(); if !matchers::domain_matches_any(&patterns, domain) { return false; } } else if ctx.is_tls { // TLS connection without SNI cannot match a domain-restricted route. // This prevents session-ticket resumption from misrouting when clients // omit SNI (RFC 8446 recommends but doesn't mandate SNI on resumption). // Wildcard-only routes (domains: ["*"]) still match since they accept all. let patterns = domains.to_vec(); let is_wildcard_only = patterns.iter().all(|d| *d == "*"); if !is_wildcard_only { return false; } } } // Path matching if let Some(ref pattern) = rm.path { if let Some(path) = ctx.path { if !matchers::path_matches(pattern, path) { return false; } } else { // Route requires path but none provided return false; } } // Client IP matching if let Some(ref client_ips) = rm.client_ip { if let Some(ip) = ctx.client_ip { if !matchers::ip_matches_any(client_ips, ip) { return false; } } else { return false; } } // TLS version matching if let Some(ref tls_versions) = rm.tls_version { if let Some(version) = ctx.tls_version { if !tls_versions.iter().any(|v| v == version) { return false; } } else { return false; } } // Header matching if let Some(ref patterns) = rm.headers { if let Some(headers) = ctx.headers { if !matchers::headers_match(patterns, headers) { return false; } } else { return false; } } // Protocol matching if let Some(ref required_protocol) = rm.protocol { if let Some(protocol) = ctx.protocol { if required_protocol != protocol { return false; } } // If protocol not yet known (None), allow match — protocol will be // validated after detection (post-TLS-termination peek) } true } /// Find the best matching target within a route. fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> { let targets = route.action.targets.as_ref()?; if targets.len() == 1 && targets[0].target_match.is_none() { return Some(&targets[0]); } // Sort candidates by priority (already in order from config) let mut best: Option<&RouteTarget> = None; let mut best_priority = i32::MIN; for target in targets { let priority = target.priority.unwrap_or(0); if let Some(ref tm) = target.target_match { if !self.matches_target(tm, ctx) { continue; } } if priority > best_priority || best.is_none() { best = Some(target); best_priority = priority; } } // Fall back to first target without match criteria best.or_else(|| { targets.iter().find(|t| t.target_match.is_none()) }) } /// Check if a target match criteria matches the context. fn matches_target( &self, tm: &rustproxy_config::TargetMatch, ctx: &MatchContext<'_>, ) -> bool { // Port matching if let Some(ref ports) = tm.ports { if !ports.contains(&ctx.port) { return false; } } // Path matching if let Some(ref pattern) = tm.path { if let Some(path) = ctx.path { if !matchers::path_matches(pattern, path) { return false; } } else { return false; } } // Header matching if let Some(ref patterns) = tm.headers { if let Some(headers) = ctx.headers { if !matchers::headers_match(patterns, headers) { return false; } } else { return false; } } true } /// Get all unique listening ports. pub fn listening_ports(&self) -> Vec { let mut ports: Vec = self.port_index.keys().copied().collect(); ports.sort(); ports } /// Get all routes for a specific port. pub fn routes_for_port(&self, port: u16) -> Vec<&RouteConfig> { self.port_index .get(&port) .map(|indices| indices.iter().map(|&i| &self.routes[i]).collect()) .unwrap_or_default() } /// Get the total number of enabled routes. pub fn route_count(&self) -> usize { self.routes.len() } /// Check if any route on the given port requires SNI. pub fn port_requires_sni(&self, port: u16) -> bool { let routes = self.routes_for_port(port); // If multiple passthrough routes on same port, SNI is needed let passthrough_routes: Vec<_> = routes .iter() .filter(|r| { r.tls_mode() == Some(&TlsMode::Passthrough) }) .collect(); if passthrough_routes.len() > 1 { return true; } // Single passthrough route with specific domain restriction needs SNI if let Some(route) = passthrough_routes.first() { if let Some(ref domains) = route.route_match.domains { let domain_list = domains.to_vec(); // If it's not just a wildcard, SNI is needed if !domain_list.iter().all(|d| *d == "*") { return true; } } } false } } #[cfg(test)] mod tests { use super::*; use rustproxy_config::*; fn make_route(port: u16, domain: Option<&str>, priority: i32) -> RouteConfig { RouteConfig { id: None, route_match: RouteMatch { ports: PortRange::Single(port), domains: domain.map(|d| DomainSpec::Single(d.to_string())), path: None, client_ip: None, tls_version: None, headers: None, protocol: None, }, action: RouteAction { action_type: RouteActionType::Forward, targets: Some(vec![RouteTarget { target_match: None, host: HostSpec::Single("localhost".to_string()), port: PortSpec::Fixed(8080), tls: None, websocket: None, load_balancing: None, send_proxy_protocol: None, headers: None, advanced: None, priority: None, }]), tls: None, websocket: None, load_balancing: None, advanced: None, options: None, forwarding_engine: None, nftables: None, send_proxy_protocol: None, }, headers: None, security: None, name: None, description: None, priority: Some(priority), tags: None, enabled: None, } } #[test] fn test_basic_routing() { let routes = vec![ make_route(80, Some("example.com"), 0), make_route(80, Some("other.com"), 0), ]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 80, domain: Some("example.com"), path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: None, }; let result = manager.find_route(&ctx); assert!(result.is_some()); } #[test] fn test_priority_ordering() { let routes = vec![ make_route(80, Some("*.example.com"), 0), make_route(80, Some("api.example.com"), 10), // Higher priority ]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 80, domain: Some("api.example.com"), path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: None, }; let result = manager.find_route(&ctx).unwrap(); // Should match the higher-priority specific route assert!(result.route.route_match.domains.as_ref() .map(|d| d.to_vec()) .unwrap() .contains(&"api.example.com")); } #[test] fn test_no_match() { let routes = vec![make_route(80, Some("example.com"), 0)]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 443, // Different port domain: Some("example.com"), path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: None, }; assert!(manager.find_route(&ctx).is_none()); } #[test] fn test_disabled_routes_excluded() { let mut route = make_route(80, Some("example.com"), 0); route.enabled = Some(false); let manager = RouteManager::new(vec![route]); assert_eq!(manager.route_count(), 0); } #[test] fn test_listening_ports() { let routes = vec![ make_route(80, Some("a.com"), 0), make_route(443, Some("b.com"), 0), make_route(80, Some("c.com"), 0), // duplicate port ]; let manager = RouteManager::new(routes); let ports = manager.listening_ports(); assert_eq!(ports, vec![80, 443]); } #[test] fn test_port_requires_sni_single_passthrough() { let mut route = make_route(443, Some("example.com"), 0); route.action.tls = Some(RouteTls { mode: TlsMode::Passthrough, certificate: None, acme: None, versions: None, ciphers: None, honor_cipher_order: None, session_timeout: None, }); let manager = RouteManager::new(vec![route]); // Single passthrough route with specific domain needs SNI assert!(manager.port_requires_sni(443)); } #[test] fn test_port_requires_sni_wildcard_only() { let mut route = make_route(443, Some("*"), 0); route.action.tls = Some(RouteTls { mode: TlsMode::Passthrough, certificate: None, acme: None, versions: None, ciphers: None, honor_cipher_order: None, session_timeout: None, }); let manager = RouteManager::new(vec![route]); // Single passthrough route with wildcard doesn't need SNI assert!(!manager.port_requires_sni(443)); } #[test] fn test_routes_for_port() { let routes = vec![ make_route(80, Some("a.com"), 0), make_route(80, Some("b.com"), 0), make_route(443, Some("c.com"), 0), ]; let manager = RouteManager::new(routes); assert_eq!(manager.routes_for_port(80).len(), 2); assert_eq!(manager.routes_for_port(443).len(), 1); assert_eq!(manager.routes_for_port(8080).len(), 0); } #[test] fn test_wildcard_domain_matches_any() { let routes = vec![make_route(80, Some("*"), 0)]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 80, domain: Some("anything.example.com"), path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: None, }; assert!(manager.find_route(&ctx).is_some()); } #[test] fn test_tls_no_sni_rejects_domain_restricted_route() { let routes = vec![make_route(443, Some("example.com"), 0)]; let manager = RouteManager::new(routes); // TLS connection without SNI should NOT match a domain-restricted route let ctx = MatchContext { port: 443, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: true, protocol: None, }; assert!(manager.find_route(&ctx).is_none()); } #[test] fn test_tls_no_sni_rejects_wildcard_subdomain_route() { let routes = vec![make_route(443, Some("*.example.com"), 0)]; let manager = RouteManager::new(routes); // TLS connection without SNI should NOT match *.example.com let ctx = MatchContext { port: 443, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: true, protocol: None, }; assert!(manager.find_route(&ctx).is_none()); } #[test] fn test_tls_no_sni_matches_wildcard_only_route() { let routes = vec![make_route(443, Some("*"), 0)]; let manager = RouteManager::new(routes); // TLS connection without SNI SHOULD match a wildcard-only route let ctx = MatchContext { port: 443, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: true, protocol: None, }; assert!(manager.find_route(&ctx).is_some()); } #[test] fn test_tls_no_sni_skips_domain_restricted_matches_fallback() { // Two routes: first is domain-restricted, second is wildcard catch-all let routes = vec![ make_route(443, Some("specific.com"), 10), make_route(443, Some("*"), 0), ]; let manager = RouteManager::new(routes); // TLS without SNI should skip specific.com and fall through to wildcard let ctx = MatchContext { port: 443, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: true, protocol: None, }; let result = manager.find_route(&ctx); assert!(result.is_some()); let matched_domains = result.unwrap().route.route_match.domains.as_ref() .map(|d| d.to_vec()).unwrap(); assert!(matched_domains.contains(&"*")); } #[test] fn test_non_tls_no_domain_still_matches_domain_restricted() { // Non-TLS (plain HTTP) without domain should still match domain-restricted routes // (the HTTP proxy layer handles Host-based routing) let routes = vec![make_route(80, Some("example.com"), 0)]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 80, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: None, }; assert!(manager.find_route(&ctx).is_some()); } #[test] fn test_no_domain_route_matches_any_domain() { let routes = vec![make_route(80, None, 0)]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 80, domain: Some("example.com"), path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: None, }; assert!(manager.find_route(&ctx).is_some()); } #[test] fn test_target_sub_matching() { let mut route = make_route(80, Some("example.com"), 0); route.action.targets = Some(vec![ RouteTarget { target_match: Some(rustproxy_config::TargetMatch { ports: None, path: Some("/api/*".to_string()), headers: None, method: None, }), host: HostSpec::Single("api-backend".to_string()), port: PortSpec::Fixed(3000), tls: None, websocket: None, load_balancing: None, send_proxy_protocol: None, headers: None, advanced: None, priority: Some(10), }, RouteTarget { target_match: None, host: HostSpec::Single("default-backend".to_string()), port: PortSpec::Fixed(8080), tls: None, websocket: None, load_balancing: None, send_proxy_protocol: None, headers: None, advanced: None, priority: None, }, ]); let manager = RouteManager::new(vec![route]); // Should match the API target let ctx = MatchContext { port: 80, domain: Some("example.com"), path: Some("/api/users"), client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: None, }; let result = manager.find_route(&ctx).unwrap(); assert_eq!(result.target.unwrap().host.first(), "api-backend"); // Should fall back to default target let ctx = MatchContext { port: 80, domain: Some("example.com"), path: Some("/home"), client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: None, }; let result = manager.find_route(&ctx).unwrap(); assert_eq!(result.target.unwrap().host.first(), "default-backend"); } fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig { let mut route = make_route(port, domain, 0); route.route_match.protocol = protocol.map(|s| s.to_string()); route } #[test] fn test_protocol_http_matches_http() { let routes = vec![make_route_with_protocol(80, None, Some("http"))]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 80, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: Some("http"), }; assert!(manager.find_route(&ctx).is_some()); } #[test] fn test_protocol_http_rejects_tcp() { let routes = vec![make_route_with_protocol(80, None, Some("http"))]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 80, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: Some("tcp"), }; assert!(manager.find_route(&ctx).is_none()); } #[test] fn test_protocol_none_matches_any() { // Route with no protocol restriction matches any protocol let routes = vec![make_route_with_protocol(80, None, None)]; let manager = RouteManager::new(routes); let ctx_http = MatchContext { port: 80, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: Some("http"), }; assert!(manager.find_route(&ctx_http).is_some()); let ctx_tcp = MatchContext { port: 80, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: false, protocol: Some("tcp"), }; assert!(manager.find_route(&ctx_tcp).is_some()); } #[test] fn test_protocol_http_matches_when_unknown() { // Route with protocol: "http" should match when ctx.protocol is None // (pre-TLS-termination, protocol not yet known) let routes = vec![make_route_with_protocol(443, None, Some("http"))]; let manager = RouteManager::new(routes); let ctx = MatchContext { port: 443, domain: None, path: None, client_ip: None, tls_version: None, headers: None, is_tls: true, protocol: None, }; assert!(manager.find_route(&ctx).is_some()); } }