use thiserror::Error; use std::collections::HashMap; use tracing::{debug, info, warn}; #[derive(Debug, Error)] pub enum NftError { #[error("nft command failed: {0}")] CommandFailed(String), #[error("IO error: {0}")] Io(#[from] std::io::Error), #[error("Not running as root")] NotRoot, } /// Manager for nftables rules. /// /// Executes `nft` CLI commands to manage kernel-level packet forwarding. /// Requires root privileges; operations are skipped gracefully if not root. pub struct NftManager { table_name: String, /// Active rules indexed by route ID active_rules: HashMap>, /// Whether the table has been initialized table_initialized: bool, } impl NftManager { pub fn new(table_name: Option) -> Self { Self { table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()), active_rules: HashMap::new(), table_initialized: false, } } /// Check if we are running as root. fn is_root() -> bool { unsafe { libc::geteuid() == 0 } } /// Execute a single nft command via the CLI. async fn exec_nft(command: &str) -> Result { // The command starts with "nft ", strip it to get the args let args = if command.starts_with("nft ") { &command[4..] } else { command }; let output = tokio::process::Command::new("nft") .args(args.split_whitespace()) .output() .await .map_err(NftError::Io)?; if output.status.success() { Ok(String::from_utf8_lossy(&output.stdout).to_string()) } else { let stderr = String::from_utf8_lossy(&output.stderr); Err(NftError::CommandFailed(format!( "Command '{}' failed: {}", command, stderr ))) } } /// Ensure the nftables table and chains are set up. async fn ensure_table(&mut self) -> Result<(), NftError> { if self.table_initialized { return Ok(()); } let setup_commands = crate::rule_builder::build_table_setup(&self.table_name); for cmd in &setup_commands { Self::exec_nft(cmd).await?; } self.table_initialized = true; info!("NFTables table '{}' initialized", self.table_name); Ok(()) } /// Apply rules for a route. /// /// Executes the nft commands via the CLI. If not running as root, /// the rules are stored locally but not applied to the kernel. pub async fn apply_rules(&mut self, route_id: &str, rules: Vec) -> Result<(), NftError> { if !Self::is_root() { warn!("Not running as root, nftables rules will not be applied to kernel"); self.active_rules.insert(route_id.to_string(), rules); return Ok(()); } self.ensure_table().await?; for cmd in &rules { Self::exec_nft(cmd).await?; debug!("Applied nft rule: {}", cmd); } info!("Applied {} nftables rules for route '{}'", rules.len(), route_id); self.active_rules.insert(route_id.to_string(), rules); Ok(()) } /// Remove rules for a route. /// /// Currently removes the route from tracking. To fully remove specific /// rules would require handle-based tracking; for now, cleanup() removes /// the entire table. pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> { if let Some(rules) = self.active_rules.remove(route_id) { info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id); } Ok(()) } /// Clean up all managed rules by deleting the entire nftables table. pub async fn cleanup(&mut self) -> Result<(), NftError> { if !Self::is_root() { warn!("Not running as root, skipping nftables cleanup"); self.active_rules.clear(); self.table_initialized = false; return Ok(()); } if self.table_initialized { let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name); for cmd in &cleanup_commands { match Self::exec_nft(cmd).await { Ok(_) => debug!("Cleanup: {}", cmd), Err(e) => warn!("Cleanup command failed (may be ok): {}", e), } } info!("NFTables table '{}' cleaned up", self.table_name); } self.active_rules.clear(); self.table_initialized = false; Ok(()) } /// Get the table name. pub fn table_name(&self) -> &str { &self.table_name } /// Whether the table has been initialized in the kernel. pub fn is_initialized(&self) -> bool { self.table_initialized } /// Get the number of active route rule sets. pub fn active_route_count(&self) -> usize { self.active_rules.len() } /// Get the status of all active rules. pub fn status(&self) -> HashMap { let mut status = HashMap::new(); for (route_id, rules) in &self.active_rules { status.insert( route_id.clone(), serde_json::json!({ "ruleCount": rules.len(), "rules": rules, }), ); } status } } #[cfg(test)] mod tests { use super::*; #[test] fn test_new_default_table_name() { let mgr = NftManager::new(None); assert_eq!(mgr.table_name(), "rustproxy"); assert!(!mgr.is_initialized()); } #[test] fn test_new_custom_table_name() { let mgr = NftManager::new(Some("custom".to_string())); assert_eq!(mgr.table_name(), "custom"); } #[tokio::test] async fn test_apply_rules_non_root() { let mut mgr = NftManager::new(None); // When not root, rules are stored but not applied to kernel let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()]; mgr.apply_rules("route-1", rules).await.unwrap(); assert_eq!(mgr.active_route_count(), 1); let status = mgr.status(); assert!(status.contains_key("route-1")); assert_eq!(status["route-1"]["ruleCount"], 1); } #[tokio::test] async fn test_remove_rules() { let mut mgr = NftManager::new(None); let rules = vec!["nft add rule test".to_string()]; mgr.apply_rules("route-1", rules).await.unwrap(); assert_eq!(mgr.active_route_count(), 1); mgr.remove_rules("route-1").await.unwrap(); assert_eq!(mgr.active_route_count(), 0); } #[tokio::test] async fn test_cleanup_non_root() { let mut mgr = NftManager::new(None); let rules = vec!["nft add rule test".to_string()]; mgr.apply_rules("route-1", rules).await.unwrap(); mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap(); mgr.cleanup().await.unwrap(); assert_eq!(mgr.active_route_count(), 0); assert!(!mgr.is_initialized()); } #[tokio::test] async fn test_status_multiple_routes() { let mut mgr = NftManager::new(None); mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap(); mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap(); let status = mgr.status(); assert_eq!(status.len(), 2); assert_eq!(status["web"]["ruleCount"], 2); assert_eq!(status["api"]["ruleCount"], 1); } }