// smartvpn/rust/src/acl.rs (Rust, 303 lines, 11 KiB) -- file-viewer header, kept as a comment.

use std::net::Ipv4Addr;
use ipnet::Ipv4Net;
use crate::client_registry::ClientSecurity;
/// Outcome of evaluating traffic against an access-control policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AclResult {
    /// The traffic is permitted.
    Allow,
    /// The traffic is rejected because of its source address.
    DenySrc,
    /// The traffic is rejected because of its destination address.
    DenyDst,
}
/// Reports whether a connecting peer's address appears in the server-level block list.
///
/// Meant for rejecting known-bad addresses before the handshake begins. Returns `true`
/// when `ip` matches at least one pattern in `block_list`; an empty list blocks nothing.
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
    ip_matches_any(ip, block_list)
}
/// Decides whether a source address passes an optional allow/block list pair.
///
/// A match in `block_list` always rejects the address. Otherwise the address is accepted
/// unless `allow_list` is present, non-empty, and contains no matching pattern; a missing
/// or empty allow list places no restriction on the source.
///
/// Returns `true` when `ip` is permitted.
pub fn is_source_allowed(
    ip: Ipv4Addr,
    allow_list: Option<&[String]>,
    block_list: Option<&[String]>,
) -> bool {
    // Block entries win over allow entries, so they are evaluated first.
    let blocked = match block_list {
        Some(deny) => ip_matches_any(ip, deny),
        None => false,
    };
    if blocked {
        return false;
    }

    // Only a non-empty allow list restricts the source.
    match allow_list {
        Some(permit) if !permit.is_empty() => ip_matches_any(ip, permit),
        _ => true,
    }
}
/// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy.
///
/// Evaluation order (deny overrides allow):
/// 1. If src_ip is in ip_block_list → DenySrc
/// 2. If dst_ip is in destination_block_list → DenyDst
/// 3. If ip_allow_list is non-empty and src_ip is NOT in it → DenySrc
/// 4. If destination_allow_list is non-empty and dst_ip is NOT in it → DenyDst
/// 5. Otherwise → Allow
pub fn check_acl(security: &ClientSecurity, src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> AclResult {
// Step 1: Check source block list (deny overrides)
if let Some(ref block_list) = security.ip_block_list {
if ip_matches_any(src_ip, block_list) {
return AclResult::DenySrc;
}
}
// Step 2: Check destination block list (deny overrides)
if let Some(ref block_list) = security.destination_block_list {
if ip_matches_any(dst_ip, block_list) {
return AclResult::DenyDst;
}
}
// Step 3: Check source allow list (if non-empty, must match)
if let Some(ref allow_list) = security.ip_allow_list {
if !allow_list.is_empty() && !ip_matches_any(src_ip, allow_list) {
return AclResult::DenySrc;
}
}
// Step 4: Check destination allow list (if non-empty, must match)
if let Some(ref allow_list) = security.destination_allow_list {
if !allow_list.is_empty() && !ip_matches_any(dst_ip, allow_list) {
return AclResult::DenyDst;
}
}
AclResult::Allow
}
/// Returns `true` when `ip` matches at least one entry of `patterns`.
///
/// Each entry may be an exact address, a CIDR block, a wildcard pattern such as
/// `192.168.1.*`, or an inclusive range such as `192.168.1.1-192.168.1.100`.
/// An empty list never matches.
pub fn ip_matches_any(ip: Ipv4Addr, patterns: &[String]) -> bool {
    patterns.iter().any(|candidate| ip_matches(ip, candidate))
}
/// Tests `ip` against one pattern string.
///
/// Surrounding whitespace is ignored. The pattern form is picked by the first marker
/// found, in this order: `/` (CIDR), `-` (inclusive range), `*` (wildcard), otherwise an
/// exact address. A pattern that fails to parse in its selected form never matches.
fn ip_matches(ip: Ipv4Addr, pattern: &str) -> bool {
    let spec = pattern.trim();

    // CIDR block, e.g. 192.168.1.0/24
    if spec.contains('/') {
        return match spec.parse::<Ipv4Net>() {
            Ok(net) => net.contains(&ip),
            Err(_) => false,
        };
    }

    // Inclusive range, e.g. 192.168.1.1-192.168.1.100 (split at the first dash)
    if let Some((low, high)) = spec.split_once('-') {
        return match (low.trim().parse::<Ipv4Addr>(), high.trim().parse::<Ipv4Addr>()) {
            // Compared numerically; a range whose start exceeds its end is empty.
            (Ok(first), Ok(last)) => {
                (u32::from(first)..=u32::from(last)).contains(&u32::from(ip))
            }
            _ => false,
        };
    }

    // Wildcard, e.g. 192.168.1.*
    if spec.contains('*') {
        return wildcard_matches(ip, spec);
    }

    // Exact address
    spec.parse::<Ipv4Addr>().ok() == Some(ip)
}
/// Compares `ip` with a dotted wildcard pattern such as `192.168.1.*` or `10.*.*.*`.
///
/// The pattern must consist of exactly four dot-separated segments. A segment equal to
/// `*` accepts any octet; every other segment must parse as a `u8` and equal the octet in
/// the same position. Any other segment (including partial wildcards like `1*`) makes the
/// whole pattern fail to match.
fn wildcard_matches(ip: Ipv4Addr, pattern: &str) -> bool {
    // Reject anything that is not shaped like a four-part dotted address.
    if pattern.split('.').count() != 4 {
        return false;
    }

    let octets = ip.octets();
    pattern
        .split('.')
        .zip(octets.iter())
        .all(|(segment, &octet)| segment == "*" || segment.parse::<u8>().ok() == Some(octet))
}
#[cfg(test)]
mod tests {
    use super::*;
    // `ClientRateLimit` was imported here but never referenced, which raised an
    // `unused_imports` warning; only the type the tests actually construct is imported.
    use crate::client_registry::ClientSecurity;

    /// Builds a `ClientSecurity` from string-literal pattern lists, leaving the
    /// connection and rate limits unset.
    fn security(
        ip_allow: Option<Vec<&str>>,
        ip_block: Option<Vec<&str>>,
        dst_allow: Option<Vec<&str>>,
        dst_block: Option<Vec<&str>>,
    ) -> ClientSecurity {
        ClientSecurity {
            ip_allow_list: ip_allow.map(|v| v.into_iter().map(String::from).collect()),
            ip_block_list: ip_block.map(|v| v.into_iter().map(String::from).collect()),
            destination_allow_list: dst_allow.map(|v| v.into_iter().map(String::from).collect()),
            destination_block_list: dst_block.map(|v| v.into_iter().map(String::from).collect()),
            max_connections: None,
            rate_limit: None,
        }
    }

    /// Parses a literal IPv4 address; panics on malformed input, which is acceptable in tests.
    fn ip(s: &str) -> Ipv4Addr {
        s.parse().unwrap()
    }

    // ── No restrictions (empty security) ────────────────────────────────

    #[test]
    fn empty_security_allows_all() {
        let sec = security(None, None, None, None);
        assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
    }

    #[test]
    fn empty_lists_allow_all() {
        let sec = security(Some(vec![]), Some(vec![]), Some(vec![]), Some(vec![]));
        assert_eq!(check_acl(&sec, ip("1.2.3.4"), ip("5.6.7.8")), AclResult::Allow);
    }

    // ── Source IP allow list ────────────────────────────────────────────

    #[test]
    fn src_allow_exact_match() {
        let sec = security(Some(vec!["10.0.0.1"]), None, None, None);
        assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("5.6.7.8")), AclResult::Allow);
        assert_eq!(check_acl(&sec, ip("10.0.0.2"), ip("5.6.7.8")), AclResult::DenySrc);
    }

    #[test]
    fn src_allow_cidr() {
        let sec = security(Some(vec!["192.168.1.0/24"]), None, None, None);
        assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
        assert_eq!(check_acl(&sec, ip("192.168.2.1"), ip("5.6.7.8")), AclResult::DenySrc);
    }

    #[test]
    fn src_allow_wildcard() {
        let sec = security(Some(vec!["10.0.*.*"]), None, None, None);
        assert_eq!(check_acl(&sec, ip("10.0.5.3"), ip("5.6.7.8")), AclResult::Allow);
        assert_eq!(check_acl(&sec, ip("10.1.0.1"), ip("5.6.7.8")), AclResult::DenySrc);
    }

    #[test]
    fn src_allow_range() {
        let sec = security(Some(vec!["10.0.0.1-10.0.0.10"]), None, None, None);
        assert_eq!(check_acl(&sec, ip("10.0.0.5"), ip("5.6.7.8")), AclResult::Allow);
        assert_eq!(check_acl(&sec, ip("10.0.0.11"), ip("5.6.7.8")), AclResult::DenySrc);
    }

    // ── Source IP block list (deny overrides) ───────────────────────────

    #[test]
    fn src_block_overrides_allow() {
        let sec = security(
            Some(vec!["192.168.1.0/24"]),
            Some(vec!["192.168.1.100"]),
            None,
            None,
        );
        assert_eq!(check_acl(&sec, ip("192.168.1.50"), ip("5.6.7.8")), AclResult::Allow);
        assert_eq!(check_acl(&sec, ip("192.168.1.100"), ip("5.6.7.8")), AclResult::DenySrc);
    }

    // ── Destination allow list ──────────────────────────────────────────

    #[test]
    fn dst_allow_exact() {
        let sec = security(None, None, Some(vec!["8.8.8.8", "8.8.4.4"]), None);
        assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::Allow);
        assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("1.1.1.1")), AclResult::DenyDst);
    }

    #[test]
    fn dst_allow_cidr() {
        let sec = security(None, None, Some(vec!["10.0.0.0/8"]), None);
        assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.5.3.2")), AclResult::Allow);
        assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("172.16.0.1")), AclResult::DenyDst);
    }

    // ── Destination block list (deny overrides) ─────────────────────────

    #[test]
    fn dst_block_overrides_allow() {
        let sec = security(
            None,
            None,
            Some(vec!["10.0.0.0/8"]),
            Some(vec!["10.0.0.99"]),
        );
        assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.1")), AclResult::Allow);
        assert_eq!(check_acl(&sec, ip("1.1.1.1"), ip("10.0.0.99")), AclResult::DenyDst);
    }

    // ── Combined source + destination ───────────────────────────────────

    #[test]
    fn combined_src_and_dst_filtering() {
        let sec = security(
            Some(vec!["192.168.1.0/24"]),
            None,
            Some(vec!["8.8.8.8"]),
            None,
        );
        // Valid source, valid dest
        assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("8.8.8.8")), AclResult::Allow);
        // Invalid source
        assert_eq!(check_acl(&sec, ip("10.0.0.1"), ip("8.8.8.8")), AclResult::DenySrc);
        // Valid source, invalid dest
        assert_eq!(check_acl(&sec, ip("192.168.1.10"), ip("1.1.1.1")), AclResult::DenyDst);
    }

    // ── IP matching edge cases ──────────────────────────────────────────

    #[test]
    fn wildcard_single_octet() {
        assert!(ip_matches(ip("10.0.0.5"), "10.0.0.*"));
        assert!(!ip_matches(ip("10.0.1.5"), "10.0.0.*"));
    }

    #[test]
    fn range_boundaries() {
        assert!(ip_matches(ip("10.0.0.1"), "10.0.0.1-10.0.0.5"));
        assert!(ip_matches(ip("10.0.0.5"), "10.0.0.1-10.0.0.5"));
        assert!(!ip_matches(ip("10.0.0.6"), "10.0.0.1-10.0.0.5"));
        assert!(!ip_matches(ip("10.0.0.0"), "10.0.0.1-10.0.0.5"));
    }

    #[test]
    fn invalid_pattern_no_match() {
        assert!(!ip_matches(ip("10.0.0.1"), "not-an-ip"));
        assert!(!ip_matches(ip("10.0.0.1"), "10.0.0.1/99"));
        assert!(!ip_matches(ip("10.0.0.1"), "10.0.0"));
    }
}