use std::net::Ipv4Addr; use ipnet::Ipv4Net; use crate::client_registry::ClientSecurity; /// Result of an ACL check. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AclResult { Allow, DenySrc, DenyDst, } /// Check whether a connection source IP is in a server-level block list. /// Used for pre-handshake rejection of known-bad IPs. pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool { ip_matches_any(ip, block_list) } /// Check whether a source IP is allowed by allow/block lists. /// Returns true if the IP is permitted (not blocked and passes allow check). pub fn is_source_allowed(ip: Ipv4Addr, allow_list: Option<&[String]>, block_list: Option<&[String]>) -> bool { // Deny overrides allow if let Some(bl) = block_list { if ip_matches_any(ip, bl) { return false; } } // If allow list exists and is non-empty, IP must match if let Some(al) = allow_list { if !al.is_empty() && !ip_matches_any(ip, al) { return false; } } true } /// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy. /// /// Evaluation order (deny overrides allow): /// 1. If src_ip is in ip_block_list → DenySrc /// 2. If dst_ip is in destination_block_list → DenyDst /// 3. If ip_allow_list is non-empty and src_ip is NOT in it → DenySrc /// 4. If destination_allow_list is non-empty and dst_ip is NOT in it → DenyDst /// 5. Otherwise → Allow pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult { // Step 1: Check source block list (deny overrides) if let Some(ref block_list) = security.ip_block_list { if ip_matches_any(src_ip, block_list) { return AclResult::DenySrc; } } // Step 2: Check destination block list (deny overrides) if let Some(ref block_list) = security.destination_block_list { if ip_matches_any(dst_ip, block_list) { return AclResult::DenyDst; } } // Step 3: Check source allow list (if non-empty, must match) if let Some(ref allow_list) = security.ip_allow_list { if !allow_list.is_empty() && !ip_matches_any(src_ip, allow_list) { return AclResult::DenySrc; } } // Step 4: Check destination allow list (if non-empty, must match) if let Some(ref allow_list) = security.destination_allow_list { if !allow_list.is_empty() && !ip_matches_any(dst_ip, allow_list) { return AclResult::DenyDst; } } AclResult::Allow } /// Check if `ip` matches any pattern in the list. /// Supports: exact IP, CIDR notation, wildcard patterns (192.168.1.*), /// and IP ranges (192.168.1.1-192.168.1.100). pub fn ip_matches_any(ip: Ipv4Addr, patterns: &[String]) -> bool { for pattern in patterns { if ip_matches(ip, pattern) { return true; } } false } /// Check if `ip` matches a single pattern. fn ip_matches(ip: Ipv4Addr, pattern: &str) -> bool { let pattern = pattern.trim(); // CIDR notation (e.g. 192.168.1.0/24) if pattern.contains('/') { if let Ok(net) = pattern.parse::() { return net.contains(&ip); } return false; } // IP range (e.g. 192.168.1.1-192.168.1.100) if pattern.contains('-') { let parts: Vec<&str> = pattern.splitn(2, '-').collect(); if parts.len() == 2 { if let (Ok(start), Ok(end)) = (parts[0].trim().parse::(), parts[1].trim().parse::()) { let ip_u32 = u32::from(ip); return ip_u32 >= u32::from(start) && ip_u32 <= u32::from(end); } } return false; } // Wildcard pattern (e.g. 192.168.1.*) if pattern.contains('*') { return wildcard_matches(ip, pattern); } // Exact IP match if let Ok(exact) = pattern.parse::() { return ip == exact; } false } /// Match an IP against a wildcard pattern like "192.168.1.*" or "10.*.*.*". fn wildcard_matches(ip: Ipv4Addr, pattern: &str) -> bool { let ip_octets = ip.octets(); let pattern_parts: Vec<&str> = pattern.split('.').collect(); if pattern_parts.len() != 4 { return false; } for (i, part) in pattern_parts.iter().enumerate() { if *part == "*" { continue; } if let Ok(octet) = part.parse::() { if ip_octets[i] != octet { return false; } } else { return false; } } true } #[cfg(test)] mod tests { use super::*; use crate::client_registry::{ClientRateLimit, ClientSecurity}; fn security( ip_allow: Option>, ip_block: Option>, dst_allow: Option>, dst_block: Option>, ) -> ClientSecurity { ClientSecurity { ip_allow_list: ip_allow.map(|v| v.into_iter().map(String::from).collect()), ip_block_list: ip_block.map(|v| v.into_iter().map(String::from).collect()), destination_allow_list: dst_allow.map(|v| v.into_iter().map(String::from).collect()), destination_block_list: dst_block.map(|v| v.into_iter().map(String::from).collect()), max_connections: None, rate_limit: None, } } fn ip(s: &str) -> Ipv4Addr { s.parse().unwrap() } // ── No restrictions (empty security) ──────────────────────────────── #[test] fn empty_security_allows_all() { let sec = security(None, None, None, None); assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow); } #[test] fn empty_lists_allow_all() { let sec = security(Some(vec![]), Some(vec![]), Some(vec![]), Some(vec![])); assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow); } // ── Source IP allow list ──────────────────────────────────────────── #[test] fn src_allow_exact_match() { let sec = security(Some(vec!["10.0.0.1"]), None, None, None); assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("5.6.7.8")), AclResult::Allow); assert_eq!(check_acl(&sec, ip("10.0.0.2"), ip("5.6.7.8")), AclResult::DenySrc); } #[test] fn src_allow_cidr() { let sec = security(Some(vec!["192.168.1.0/24"]), None, None, None); assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow); assert_eq!(check_acl(&sec, ip("192.168.2.1"), ip("5.6.7.8")), AclResult::DenySrc); } #[test] fn src_allow_wildcard() { let sec = security(Some(vec!["10.0.*.*"]), None, None, None); assert_eq!(check_acl(&sec, ip("10.0.5.3"), ip("5.6.7.8")), AclResult::Allow); assert_eq!(check_acl(&sec, ip("10.1.0.1"), ip("5.6.7.8")), AclResult::DenySrc); } #[test] fn src_allow_range() { let sec = security(Some(vec!["10.0.0.1-10.0.0.10"]), None, None, None); assert_eq!(check_acl(&sec, ip("10.0.0.5"), ip("5.6.7.8")), AclResult::Allow); assert_eq!(check_acl(&sec, ip("10.0.0.11"), ip("5.6.7.8")), AclResult::DenySrc); } // ── Source IP block list (deny overrides) ─────────────────────────── #[test] fn src_block_overrides_allow() { let sec = security( Some(vec!["192.168.1.0/24"]), Some(vec!["192.168.1.100"]), None, None, ); assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow); assert_eq!(check_acl(&sec, ip("192.168.1.100"), ip("5.6.7.8")), AclResult::DenySrc); } // ── Destination allow list ────────────────────────────────────────── #[test] fn dst_allow_exact() { let sec = security(None, None, Some(vec!["8.8.8.8", "8.8.4.4"]), None); assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::Allow); assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("1.1.1.1")), AclResult::DenyDst); } #[test] fn dst_allow_cidr() { let sec = security(None, None, Some(vec!["10.0.0.0/8"]), None); assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.5.3.2")), AclResult::Allow); assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("172.16.0.1")), AclResult::DenyDst); } // ── Destination block list (deny overrides) ───────────────────────── #[test] fn dst_block_overrides_allow() { let sec = security( None, None, Some(vec!["10.0.0.0/8"]), Some(vec!["10.0.0.99"]), ); assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.1")), AclResult::Allow); assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.99")), AclResult::DenyDst); } // ── Combined source + destination ─────────────────────────────────── #[test] fn combined_src_and_dst_filtering() { let sec = security( Some(vec!["192.168.1.0/24"]), None, Some(vec!["8.8.8.8"]), None, ); // Valid source, valid dest assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("8.8.8.8")), AclResult::Allow); // Invalid source assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::DenySrc); // Valid source, invalid dest assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("1.1.1.1")), AclResult::DenyDst); } // ── IP matching edge cases ────────────────────────────────────────── #[test] fn wildcard_single_octet() { assert!(ip_matches(ip("10.0.0.5"), "10.0.0.*")); assert!(!ip_matches(ip("10.0.1.5"), "10.0.0.*")); } #[test] fn range_boundaries() { assert!(ip_matches(ip("10.0.0.1"), "10.0.0.1-10.0.0.5")); assert!(ip_matches(ip("10.0.0.5"), "10.0.0.1-10.0.0.5")); assert!(!ip_matches(ip("10.0.0.6"), "10.0.0.1-10.0.0.5")); assert!(!ip_matches(ip("10.0.0.0"), "10.0.0.1-10.0.0.5")); } #[test] fn invalid_pattern_no_match() { assert!(!ip_matches(ip("10.0.0.1"), "not-an-ip")); assert!(!ip_matches(ip("10.0.0.1"), "10.0.0.1/99")); assert!(!ip_matches(ip("10.0.0.1"), "10.0.0")); } }