//! Route-aware upstream selection with load balancing. use std::collections::HashMap; use std::net::SocketAddr; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::sync::Mutex; use dashmap::DashMap; use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm}; /// Upstream selection result. pub struct UpstreamSelection { pub host: String, pub port: u16, pub use_tls: bool, } /// Selects upstream backends with load balancing support. pub struct UpstreamSelector { /// Round-robin counters per route (keyed by first target host:port) round_robin: Mutex>, /// Active connection counts per host (keyed by "host:port") active_connections: Arc>, } impl UpstreamSelector { pub fn new() -> Self { Self { round_robin: Mutex::new(HashMap::new()), active_connections: Arc::new(DashMap::new()), } } /// Select an upstream target based on the route target config and load balancing. pub fn select( &self, target: &RouteTarget, client_addr: &SocketAddr, incoming_port: u16, ) -> UpstreamSelection { let hosts = target.host.to_vec(); let port = target.port.resolve(incoming_port); if hosts.len() <= 1 { return UpstreamSelection { host: hosts.first().map(|s| s.to_string()).unwrap_or_default(), port, use_tls: target.tls.is_some(), }; } // Determine load balancing algorithm let algorithm = target.load_balancing.as_ref() .map(|lb| &lb.algorithm) .unwrap_or(&LoadBalancingAlgorithm::RoundRobin); let idx = match algorithm { LoadBalancingAlgorithm::RoundRobin => { self.round_robin_select(&hosts, port) } LoadBalancingAlgorithm::IpHash => { let hash = Self::ip_hash(client_addr); hash % hosts.len() } LoadBalancingAlgorithm::LeastConnections => { self.least_connections_select(&hosts, port) } }; UpstreamSelection { host: hosts[idx].to_string(), port, use_tls: target.tls.is_some(), } } fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize { let key = format!("{}:{}", hosts[0], port); let mut counters = self.round_robin.lock().unwrap(); let counter = counters .entry(key) .or_insert_with(|| AtomicUsize::new(0)); let idx = counter.fetch_add(1, Ordering::Relaxed); idx % hosts.len() } fn least_connections_select(&self, hosts: &[&str], port: u16) -> usize { let mut min_conns = u64::MAX; let mut min_idx = 0; for (i, host) in hosts.iter().enumerate() { let key = format!("{}:{}", host, port); let conns = self.active_connections .get(&key) .map(|entry| entry.value().load(Ordering::Relaxed)) .unwrap_or(0); if conns < min_conns { min_conns = conns; min_idx = i; } } min_idx } /// Record that a connection to the given host has started. pub fn connection_started(&self, host: &str) { self.active_connections .entry(host.to_string()) .or_insert_with(|| AtomicU64::new(0)) .fetch_add(1, Ordering::Relaxed); } /// Record that a connection to the given host has ended. pub fn connection_ended(&self, host: &str) { if let Some(counter) = self.active_connections.get(host) { let prev = counter.value().fetch_sub(1, Ordering::Relaxed); // Guard against underflow (shouldn't happen, but be safe) if prev == 0 { counter.value().store(0, Ordering::Relaxed); } } } fn ip_hash(addr: &SocketAddr) -> usize { let ip_str = addr.ip().to_string(); let mut hash: usize = 5381; for byte in ip_str.bytes() { hash = hash.wrapping_mul(33).wrapping_add(byte as usize); } hash } } impl Default for UpstreamSelector { fn default() -> Self { Self::new() } } impl Clone for UpstreamSelector { fn clone(&self) -> Self { Self { round_robin: Mutex::new(HashMap::new()), active_connections: Arc::clone(&self.active_connections), } } } #[cfg(test)] mod tests { use super::*; use rustproxy_config::*; fn make_target(hosts: Vec<&str>, port: u16) -> RouteTarget { RouteTarget { target_match: None, host: if hosts.len() == 1 { HostSpec::Single(hosts[0].to_string()) } else { HostSpec::List(hosts.iter().map(|s| s.to_string()).collect()) }, port: PortSpec::Fixed(port), tls: None, websocket: None, load_balancing: None, send_proxy_protocol: None, headers: None, advanced: None, priority: None, } } #[test] fn test_single_host() { let selector = UpstreamSelector::new(); let target = make_target(vec!["backend"], 8080); let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap(); let result = selector.select(&target, &addr, 80); assert_eq!(result.host, "backend"); assert_eq!(result.port, 8080); } #[test] fn test_round_robin() { let selector = UpstreamSelector::new(); let mut target = make_target(vec!["a", "b", "c"], 8080); target.load_balancing = Some(RouteLoadBalancing { algorithm: LoadBalancingAlgorithm::RoundRobin, health_check: None, }); let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap(); let r1 = selector.select(&target, &addr, 80); let r2 = selector.select(&target, &addr, 80); let r3 = selector.select(&target, &addr, 80); let r4 = selector.select(&target, &addr, 80); // Should cycle through a, b, c, a assert_eq!(r1.host, "a"); assert_eq!(r2.host, "b"); assert_eq!(r3.host, "c"); assert_eq!(r4.host, "a"); } #[test] fn test_ip_hash_consistent() { let selector = UpstreamSelector::new(); let mut target = make_target(vec!["a", "b", "c"], 8080); target.load_balancing = Some(RouteLoadBalancing { algorithm: LoadBalancingAlgorithm::IpHash, health_check: None, }); let addr: SocketAddr = "10.0.0.5:1234".parse().unwrap(); let r1 = selector.select(&target, &addr, 80); let r2 = selector.select(&target, &addr, 80); // Same IP should always get same backend assert_eq!(r1.host, r2.host); } }