feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates
This commit is contained in:
931
rust/crates/rustproxy/src/lib.rs
Normal file
931
rust/crates/rustproxy/src/lib.rs
Normal file
@@ -0,0 +1,931 @@
|
||||
//! # RustProxy
|
||||
//!
|
||||
//! High-performance multi-protocol proxy built on Rust,
|
||||
//! compatible with SmartProxy configuration.
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use rustproxy::RustProxy;
|
||||
//! use rustproxy_config::{RustProxyOptions, create_https_passthrough_route};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let options = RustProxyOptions {
|
||||
//! routes: vec![
|
||||
//! create_https_passthrough_route("example.com", "backend", 443),
|
||||
//! ],
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
//!
|
||||
//! let mut proxy = RustProxy::new(options)?;
|
||||
//! proxy.start().await?;
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod challenge_server;
|
||||
pub mod management;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use anyhow::Result;
|
||||
use tracing::{info, warn, debug, error};
|
||||
|
||||
// Re-export key types
|
||||
pub use rustproxy_config;
|
||||
pub use rustproxy_routing;
|
||||
pub use rustproxy_passthrough;
|
||||
pub use rustproxy_tls;
|
||||
pub use rustproxy_http;
|
||||
pub use rustproxy_nftables;
|
||||
pub use rustproxy_metrics;
|
||||
pub use rustproxy_security;
|
||||
|
||||
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine};
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_passthrough::{TcpListenerManager, TlsCertConfig, ConnectionConfig};
|
||||
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
|
||||
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
|
||||
use rustproxy_nftables::{NftManager, rule_builder};
|
||||
|
||||
/// Certificate status.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CertStatus {
|
||||
pub domain: String,
|
||||
pub source: String,
|
||||
pub expires_at: u64,
|
||||
pub is_valid: bool,
|
||||
}
|
||||
|
||||
/// The main RustProxy struct.
|
||||
/// This is the primary public API matching SmartProxy's interface.
|
||||
pub struct RustProxy {
|
||||
options: RustProxyOptions,
|
||||
route_table: ArcSwap<RouteManager>,
|
||||
listener_manager: Option<TcpListenerManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
cert_manager: Option<Arc<tokio::sync::Mutex<CertManager>>>,
|
||||
challenge_server: Option<challenge_server::ChallengeServer>,
|
||||
renewal_handle: Option<tokio::task::JoinHandle<()>>,
|
||||
nft_manager: Option<NftManager>,
|
||||
started: bool,
|
||||
started_at: Option<Instant>,
|
||||
/// Path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
|
||||
socket_handler_relay_path: Option<String>,
|
||||
}
|
||||
|
||||
impl RustProxy {
|
||||
/// Create a new RustProxy instance with the given configuration.
|
||||
pub fn new(mut options: RustProxyOptions) -> Result<Self> {
|
||||
// Apply defaults to routes before validation
|
||||
Self::apply_defaults(&mut options);
|
||||
|
||||
// Validate routes
|
||||
if let Err(errors) = rustproxy_config::validate_routes(&options.routes) {
|
||||
for err in &errors {
|
||||
warn!("Route validation error: {}", err);
|
||||
}
|
||||
if !errors.is_empty() {
|
||||
anyhow::bail!("Route validation failed with {} errors", errors.len());
|
||||
}
|
||||
}
|
||||
|
||||
let route_manager = RouteManager::new(options.routes.clone());
|
||||
|
||||
// Set up certificate manager if ACME is configured
|
||||
let cert_manager = Self::build_cert_manager(&options)
|
||||
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
|
||||
|
||||
Ok(Self {
|
||||
options,
|
||||
route_table: ArcSwap::from(Arc::new(route_manager)),
|
||||
listener_manager: None,
|
||||
metrics: Arc::new(MetricsCollector::new()),
|
||||
cert_manager,
|
||||
challenge_server: None,
|
||||
renewal_handle: None,
|
||||
nft_manager: None,
|
||||
started: false,
|
||||
started_at: None,
|
||||
socket_handler_relay_path: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Apply default configuration to routes that lack targets or security.
|
||||
fn apply_defaults(options: &mut RustProxyOptions) {
|
||||
let defaults = match &options.defaults {
|
||||
Some(d) => d.clone(),
|
||||
None => return,
|
||||
};
|
||||
|
||||
for route in &mut options.routes {
|
||||
// Apply default target if route has no targets
|
||||
if route.action.targets.is_none() {
|
||||
if let Some(ref default_target) = defaults.target {
|
||||
debug!("Applying default target {}:{} to route {:?}",
|
||||
default_target.host, default_target.port,
|
||||
route.name.as_deref().unwrap_or("unnamed"));
|
||||
route.action.targets = Some(vec![
|
||||
rustproxy_config::RouteTarget {
|
||||
target_match: None,
|
||||
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
|
||||
port: rustproxy_config::PortSpec::Fixed(default_target.port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply default security if route has no security
|
||||
if route.security.is_none() {
|
||||
if let Some(ref default_security) = defaults.security {
|
||||
let mut security = rustproxy_config::RouteSecurity {
|
||||
ip_allow_list: None,
|
||||
ip_block_list: None,
|
||||
max_connections: default_security.max_connections,
|
||||
authentication: None,
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
};
|
||||
|
||||
if let Some(ref allow_list) = default_security.ip_allow_list {
|
||||
security.ip_allow_list = Some(allow_list.clone());
|
||||
}
|
||||
if let Some(ref block_list) = default_security.ip_block_list {
|
||||
security.ip_block_list = Some(block_list.clone());
|
||||
}
|
||||
|
||||
// Only apply if there's something meaningful
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
debug!("Applying default security to route {:?}",
|
||||
route.name.as_deref().unwrap_or("unnamed"));
|
||||
route.security = Some(security);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a CertManager from options.
|
||||
fn build_cert_manager(options: &RustProxyOptions) -> Option<CertManager> {
|
||||
let acme = options.acme.as_ref()?;
|
||||
if !acme.enabled.unwrap_or(false) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let store_path = acme.certificate_store
|
||||
.as_deref()
|
||||
.unwrap_or("./certs");
|
||||
let email = acme.email.clone()
|
||||
.or_else(|| acme.account_email.clone());
|
||||
let use_production = acme.use_production.unwrap_or(false);
|
||||
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
|
||||
|
||||
let store = CertStore::new(store_path);
|
||||
Some(CertManager::new(store, email, use_production, renew_before_days))
|
||||
}
|
||||
|
||||
/// Build ConnectionConfig from RustProxyOptions.
|
||||
fn build_connection_config(options: &RustProxyOptions) -> ConnectionConfig {
|
||||
ConnectionConfig {
|
||||
connection_timeout_ms: options.effective_connection_timeout(),
|
||||
initial_data_timeout_ms: options.effective_initial_data_timeout(),
|
||||
socket_timeout_ms: options.effective_socket_timeout(),
|
||||
max_connection_lifetime_ms: options.effective_max_connection_lifetime(),
|
||||
graceful_shutdown_timeout_ms: options.graceful_shutdown_timeout.unwrap_or(30_000),
|
||||
max_connections_per_ip: options.max_connections_per_ip,
|
||||
connection_rate_limit_per_minute: options.connection_rate_limit_per_minute,
|
||||
keep_alive_treatment: options.keep_alive_treatment.clone(),
|
||||
keep_alive_inactivity_multiplier: options.keep_alive_inactivity_multiplier,
|
||||
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
|
||||
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
|
||||
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the proxy, binding to all configured ports.
|
||||
pub async fn start(&mut self) -> Result<()> {
|
||||
if self.started {
|
||||
anyhow::bail!("Proxy is already started");
|
||||
}
|
||||
|
||||
info!("Starting RustProxy...");
|
||||
|
||||
// Load persisted certificates
|
||||
if let Some(ref cm) = self.cert_manager {
|
||||
let mut cm = cm.lock().await;
|
||||
match cm.load_all() {
|
||||
Ok(count) => {
|
||||
if count > 0 {
|
||||
info!("Loaded {} persisted certificates", count);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to load persisted certificates: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-provision certificates for routes with certificate: 'auto'
|
||||
self.auto_provision_certificates().await;
|
||||
|
||||
let route_manager = self.route_table.load();
|
||||
let ports = route_manager.listening_ports();
|
||||
|
||||
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
|
||||
|
||||
// Create TCP listener manager with metrics
|
||||
let mut listener = TcpListenerManager::with_metrics(
|
||||
Arc::clone(&*route_manager),
|
||||
Arc::clone(&self.metrics),
|
||||
);
|
||||
|
||||
// Apply connection config from options
|
||||
let conn_config = Self::build_connection_config(&self.options);
|
||||
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
|
||||
conn_config.connection_timeout_ms,
|
||||
conn_config.initial_data_timeout_ms,
|
||||
conn_config.socket_timeout_ms,
|
||||
conn_config.max_connection_lifetime_ms,
|
||||
);
|
||||
listener.set_connection_config(conn_config);
|
||||
|
||||
// Extract TLS configurations from routes and cert manager
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
|
||||
// Also load certs from cert manager into TLS config
|
||||
if let Some(ref cm) = self.cert_manager {
|
||||
let cm = cm.lock().await;
|
||||
for (domain, bundle) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(domain) {
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !tls_configs.is_empty() {
|
||||
debug!("Loaded TLS certificates for {} domains", tls_configs.len());
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
// Bind all ports
|
||||
for port in &ports {
|
||||
listener.add_port(*port).await?;
|
||||
}
|
||||
|
||||
self.listener_manager = Some(listener);
|
||||
self.started = true;
|
||||
self.started_at = Some(Instant::now());
|
||||
|
||||
// Apply NFTables rules for routes using nftables forwarding engine
|
||||
self.apply_nftables_rules(&self.options.routes.clone()).await;
|
||||
|
||||
// Start renewal timer if ACME is enabled
|
||||
self.start_renewal_timer();
|
||||
|
||||
info!("RustProxy started successfully on ports: {:?}", ports);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Auto-provision certificates for routes that use certificate: 'auto'.
|
||||
async fn auto_provision_certificates(&mut self) {
|
||||
let cm_arc = match self.cert_manager {
|
||||
Some(ref cm) => Arc::clone(cm),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let mut domains_to_provision = Vec::new();
|
||||
|
||||
for route in &self.options.routes {
|
||||
let tls_mode = route.tls_mode();
|
||||
let needs_cert = matches!(
|
||||
tls_mode,
|
||||
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
|
||||
);
|
||||
if !needs_cert {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cert_spec = route.action.tls.as_ref()
|
||||
.and_then(|tls| tls.certificate.as_ref());
|
||||
|
||||
if let Some(CertificateSpec::Auto(_)) = cert_spec {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
for domain in domains.to_vec() {
|
||||
let domain = domain.to_string();
|
||||
// Skip if we already have a valid cert
|
||||
let cm = cm_arc.lock().await;
|
||||
if cm.store().has(&domain) {
|
||||
debug!("Already have cert for {}, skipping auto-provision", domain);
|
||||
continue;
|
||||
}
|
||||
drop(cm);
|
||||
domains_to_provision.push(domain);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if domains_to_provision.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
|
||||
|
||||
// Start challenge server
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let mut challenge_server = challenge_server::ChallengeServer::new();
|
||||
if let Err(e) = challenge_server.start(acme_port).await {
|
||||
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
|
||||
return;
|
||||
}
|
||||
|
||||
for domain in &domains_to_provision {
|
||||
info!("Provisioning certificate for {}", domain);
|
||||
|
||||
let cm = cm_arc.lock().await;
|
||||
let acme_client = cm.acme_client();
|
||||
drop(cm);
|
||||
|
||||
if let Some(acme_client) = acme_client {
|
||||
let challenge_server_ref = &challenge_server;
|
||||
let result = acme_client.provision(domain, |pending| {
|
||||
challenge_server_ref.set_challenge(
|
||||
pending.token.clone(),
|
||||
pending.key_authorization.clone(),
|
||||
);
|
||||
async move { Ok(()) }
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok((cert_pem, key_pem)) => {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let bundle = CertBundle {
|
||||
cert_pem,
|
||||
key_pem,
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.clone(),
|
||||
source: CertSource::Acme,
|
||||
issued_at: now,
|
||||
expires_at: now + 90 * 86400, // 90 days
|
||||
renewed_at: None,
|
||||
},
|
||||
};
|
||||
|
||||
let mut cm = cm_arc.lock().await;
|
||||
if let Err(e) = cm.load_static(domain.clone(), bundle) {
|
||||
error!("Failed to store certificate for {}: {}", domain, e);
|
||||
}
|
||||
|
||||
info!("Certificate provisioned for {}", domain);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to provision certificate for {}: {}", domain, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
challenge_server.stop().await;
|
||||
}
|
||||
|
||||
/// Start the renewal timer background task.
|
||||
/// The background task checks for expiring certificates and renews them.
|
||||
fn start_renewal_timer(&mut self) {
|
||||
let cm_arc = match self.cert_manager {
|
||||
Some(ref cm) => Arc::clone(cm),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let auto_renew = self.options.acme.as_ref()
|
||||
.and_then(|a| a.auto_renew)
|
||||
.unwrap_or(true);
|
||||
|
||||
if !auto_renew {
|
||||
return;
|
||||
}
|
||||
|
||||
let check_interval_hours = self.options.acme.as_ref()
|
||||
.and_then(|a| a.renew_check_interval_hours)
|
||||
.unwrap_or(24);
|
||||
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let interval = std::time::Duration::from_secs(check_interval_hours as u64 * 3600);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
debug!("Certificate renewal check triggered (interval: {}h)", check_interval_hours);
|
||||
|
||||
// Check which domains need renewal
|
||||
let domains = {
|
||||
let cm = cm_arc.lock().await;
|
||||
cm.check_renewals()
|
||||
};
|
||||
|
||||
if domains.is_empty() {
|
||||
debug!("No certificates need renewal");
|
||||
continue;
|
||||
}
|
||||
|
||||
info!("Renewing {} certificate(s)", domains.len());
|
||||
|
||||
// Start challenge server for renewals
|
||||
let mut cs = challenge_server::ChallengeServer::new();
|
||||
if let Err(e) = cs.start(acme_port).await {
|
||||
error!("Failed to start challenge server for renewal: {}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
for domain in &domains {
|
||||
let cs_ref = &cs;
|
||||
let mut cm = cm_arc.lock().await;
|
||||
let result = cm.renew_domain(domain, |token, key_auth| {
|
||||
cs_ref.set_challenge(token, key_auth);
|
||||
async {}
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok(_bundle) => {
|
||||
info!("Successfully renewed certificate for {}", domain);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to renew certificate for {}: {}", domain, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cs.stop().await;
|
||||
}
|
||||
});
|
||||
|
||||
self.renewal_handle = Some(handle);
|
||||
}
|
||||
|
||||
/// Stop the proxy gracefully.
|
||||
pub async fn stop(&mut self) -> Result<()> {
|
||||
if !self.started {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Stopping RustProxy...");
|
||||
|
||||
// Stop renewal timer
|
||||
if let Some(handle) = self.renewal_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
|
||||
// Stop challenge server if running
|
||||
if let Some(ref mut cs) = self.challenge_server {
|
||||
cs.stop().await;
|
||||
}
|
||||
self.challenge_server = None;
|
||||
|
||||
// Clean up NFTables rules
|
||||
if let Some(ref mut nft) = self.nft_manager {
|
||||
if let Err(e) = nft.cleanup().await {
|
||||
warn!("NFTables cleanup failed: {}", e);
|
||||
}
|
||||
}
|
||||
self.nft_manager = None;
|
||||
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.graceful_stop().await;
|
||||
}
|
||||
self.listener_manager = None;
|
||||
self.started = false;
|
||||
|
||||
info!("RustProxy stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update routes atomically (hot-reload).
|
||||
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
|
||||
// Validate new routes
|
||||
rustproxy_config::validate_routes(&routes)
|
||||
.map_err(|errors| {
|
||||
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
|
||||
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
|
||||
})?;
|
||||
|
||||
let new_manager = RouteManager::new(routes.clone());
|
||||
let new_ports = new_manager.listening_ports();
|
||||
|
||||
info!("Updating routes: {} routes on {} ports",
|
||||
new_manager.route_count(), new_ports.len());
|
||||
|
||||
// Get old ports
|
||||
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
|
||||
listener.listening_ports()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Atomically swap the route table
|
||||
let new_manager = Arc::new(new_manager);
|
||||
self.route_table.store(Arc::clone(&new_manager));
|
||||
|
||||
// Update listener manager
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.update_route_manager(Arc::clone(&new_manager));
|
||||
|
||||
// Update TLS configs
|
||||
let mut tls_configs = Self::extract_tls_configs(&routes);
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
for (domain, bundle) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(domain) {
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
listener.set_tls_configs(tls_configs);
|
||||
|
||||
// Add new ports
|
||||
for port in &new_ports {
|
||||
if !old_ports.contains(port) {
|
||||
listener.add_port(*port).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old ports no longer needed
|
||||
for port in &old_ports {
|
||||
if !new_ports.contains(port) {
|
||||
listener.remove_port(*port);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update NFTables rules: remove old, apply new
|
||||
self.update_nftables_rules(&routes).await;
|
||||
|
||||
self.options.routes = routes;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Provision a certificate for a named route.
|
||||
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
|
||||
let cm_arc = self.cert_manager.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
|
||||
|
||||
// Find the route by name
|
||||
let route = self.options.routes.iter()
|
||||
.find(|r| r.name.as_deref() == Some(route_name))
|
||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
|
||||
|
||||
let domain = route.route_match.domains.as_ref()
|
||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
|
||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
|
||||
|
||||
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
|
||||
|
||||
// Start challenge server
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let mut cs = challenge_server::ChallengeServer::new();
|
||||
cs.start(acme_port).await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
|
||||
|
||||
let cs_ref = &cs;
|
||||
let mut cm = cm_arc.lock().await;
|
||||
let result = cm.renew_domain(&domain, |token, key_auth| {
|
||||
cs_ref.set_challenge(token, key_auth);
|
||||
async {}
|
||||
}).await;
|
||||
drop(cm);
|
||||
|
||||
cs.stop().await;
|
||||
|
||||
let bundle = result
|
||||
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
||||
|
||||
// Hot-swap into TLS configs
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
let cm = cm_arc.lock().await;
|
||||
for (d, b) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(d) {
|
||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
info!("Certificate provisioned and loaded for route '{}'", route_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Renew a certificate for a named route.
|
||||
pub async fn renew_certificate(&mut self, route_name: &str) -> Result<()> {
|
||||
// Renewal is just re-provisioning
|
||||
self.provision_certificate(route_name).await
|
||||
}
|
||||
|
||||
/// Get the status of a certificate for a named route.
|
||||
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
|
||||
let route = self.options.routes.iter()
|
||||
.find(|r| r.name.as_deref() == Some(route_name))?;
|
||||
|
||||
let domain = route.route_match.domains.as_ref()
|
||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
|
||||
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
if let Some(bundle) = cm.get_cert(&domain) {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
return Some(CertStatus {
|
||||
domain,
|
||||
source: format!("{:?}", bundle.metadata.source),
|
||||
expires_at: bundle.metadata.expires_at,
|
||||
is_valid: bundle.metadata.expires_at > now,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get current metrics snapshot.
|
||||
pub fn get_metrics(&self) -> Metrics {
|
||||
self.metrics.snapshot()
|
||||
}
|
||||
|
||||
/// Add a listening port at runtime.
|
||||
pub async fn add_listening_port(&mut self, port: u16) -> Result<()> {
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.add_port(port).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a listening port at runtime.
|
||||
pub async fn remove_listening_port(&mut self, port: u16) -> Result<()> {
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.remove_port(port);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get all currently listening ports.
|
||||
pub fn get_listening_ports(&self) -> Vec<u16> {
|
||||
self.listener_manager
|
||||
.as_ref()
|
||||
.map(|l| l.listening_ports())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get statistics snapshot.
|
||||
pub fn get_statistics(&self) -> Statistics {
|
||||
let uptime = self.started_at
|
||||
.map(|t| t.elapsed().as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
Statistics {
|
||||
active_connections: self.metrics.active_connections(),
|
||||
total_connections: self.metrics.total_connections(),
|
||||
routes_count: self.route_table.load().route_count() as u64,
|
||||
listening_ports: self.get_listening_ports(),
|
||||
uptime_seconds: uptime,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
|
||||
pub fn set_socket_handler_relay_path(&mut self, path: Option<String>) {
|
||||
info!("Socket handler relay path set to: {:?}", path);
|
||||
self.socket_handler_relay_path = path;
|
||||
}
|
||||
|
||||
/// Get the current socket handler relay path.
|
||||
pub fn get_socket_handler_relay_path(&self) -> Option<&str> {
|
||||
self.socket_handler_relay_path.as_deref()
|
||||
}
|
||||
|
||||
/// Load a certificate for a domain and hot-swap the TLS configuration.
|
||||
pub async fn load_certificate(
|
||||
&mut self,
|
||||
domain: &str,
|
||||
cert_pem: String,
|
||||
key_pem: String,
|
||||
ca_pem: Option<String>,
|
||||
) -> Result<()> {
|
||||
info!("Loading certificate for domain: {}", domain);
|
||||
|
||||
// Store in cert manager if available
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let bundle = CertBundle {
|
||||
cert_pem: cert_pem.clone(),
|
||||
key_pem: key_pem.clone(),
|
||||
ca_pem: ca_pem.clone(),
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Static,
|
||||
issued_at: now,
|
||||
expires_at: now + 90 * 86400, // assume 90 days
|
||||
renewed_at: None,
|
||||
},
|
||||
};
|
||||
|
||||
let mut cm = cm_arc.lock().await;
|
||||
cm.load_static(domain.to_string(), bundle)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to store certificate: {}", e))?;
|
||||
}
|
||||
|
||||
// Hot-swap TLS config on the listener
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
|
||||
// Add the new cert
|
||||
tls_configs.insert(domain.to_string(), TlsCertConfig {
|
||||
cert_pem: cert_pem.clone(),
|
||||
key_pem: key_pem.clone(),
|
||||
});
|
||||
|
||||
// Also include all existing certs from cert manager
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
for (d, b) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(d) {
|
||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
info!("Certificate loaded and TLS config updated for {}", domain);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get NFTables status.
|
||||
pub async fn get_nftables_status(&self) -> Result<HashMap<String, serde_json::Value>> {
|
||||
match &self.nft_manager {
|
||||
Some(nft) => Ok(nft.status()),
|
||||
None => Ok(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply NFTables rules for routes using the nftables forwarding engine.
|
||||
async fn apply_nftables_rules(&mut self, routes: &[RouteConfig]) {
|
||||
let nft_routes: Vec<&RouteConfig> = routes.iter()
|
||||
.filter(|r| r.action.forwarding_engine.as_ref() == Some(&ForwardingEngine::Nftables))
|
||||
.collect();
|
||||
|
||||
if nft_routes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Applying NFTables rules for {} routes", nft_routes.len());
|
||||
|
||||
let table_name = nft_routes.iter()
|
||||
.find_map(|r| r.action.nftables.as_ref()?.table_name.clone())
|
||||
.unwrap_or_else(|| "rustproxy".to_string());
|
||||
|
||||
let mut nft = NftManager::new(Some(table_name));
|
||||
|
||||
for route in &nft_routes {
|
||||
let route_id = route.id.as_deref()
|
||||
.or(route.name.as_deref())
|
||||
.unwrap_or("unnamed");
|
||||
|
||||
let nft_options = match &route.action.nftables {
|
||||
Some(opts) => opts.clone(),
|
||||
None => rustproxy_config::NfTablesOptions {
|
||||
preserve_source_ip: None,
|
||||
protocol: None,
|
||||
max_rate: None,
|
||||
priority: None,
|
||||
table_name: None,
|
||||
use_ip_sets: None,
|
||||
use_advanced_nat: None,
|
||||
},
|
||||
};
|
||||
|
||||
let targets = match &route.action.targets {
|
||||
Some(targets) => targets,
|
||||
None => {
|
||||
warn!("NFTables route '{}' has no targets, skipping", route_id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let source_ports = route.route_match.ports.to_ports();
|
||||
for target in targets {
|
||||
let target_host = target.host.first().to_string();
|
||||
let target_port_spec = &target.port;
|
||||
|
||||
for &source_port in &source_ports {
|
||||
let resolved_port = target_port_spec.resolve(source_port);
|
||||
let rules = rule_builder::build_dnat_rule(
|
||||
nft.table_name(),
|
||||
"prerouting",
|
||||
source_port,
|
||||
&target_host,
|
||||
resolved_port,
|
||||
&nft_options,
|
||||
);
|
||||
|
||||
let rule_id = format!("{}-{}-{}", route_id, source_port, resolved_port);
|
||||
if let Err(e) = nft.apply_rules(&rule_id, rules).await {
|
||||
error!("Failed to apply NFTables rules for route '{}': {}", route_id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.nft_manager = Some(nft);
|
||||
}
|
||||
|
||||
/// Update NFTables rules when routes change.
|
||||
async fn update_nftables_rules(&mut self, new_routes: &[RouteConfig]) {
|
||||
// Clean up old rules
|
||||
if let Some(ref mut nft) = self.nft_manager {
|
||||
if let Err(e) = nft.cleanup().await {
|
||||
warn!("NFTables cleanup during update failed: {}", e);
|
||||
}
|
||||
}
|
||||
self.nft_manager = None;
|
||||
|
||||
// Apply new rules
|
||||
self.apply_nftables_rules(new_routes).await;
|
||||
}
|
||||
|
||||
/// Extract TLS configurations from route configs.
|
||||
fn extract_tls_configs(routes: &[RouteConfig]) -> HashMap<String, TlsCertConfig> {
|
||||
let mut configs = HashMap::new();
|
||||
|
||||
for route in routes {
|
||||
let tls_mode = route.tls_mode();
|
||||
let needs_cert = matches!(
|
||||
tls_mode,
|
||||
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
|
||||
);
|
||||
if !needs_cert {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cert_spec = route.action.tls.as_ref()
|
||||
.and_then(|tls| tls.certificate.as_ref());
|
||||
|
||||
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
for domain in domains.to_vec() {
|
||||
configs.insert(domain.to_string(), TlsCertConfig {
|
||||
cert_pem: cert_config.cert.clone(),
|
||||
key_pem: cert_config.key.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
configs
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user