Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e806f7257f | |||
| af4908b63f |
@@ -1,5 +1,12 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## 2026-04-26 - 27.9.0 - feat(smart-proxy)
|
||||||
|
add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers
|
||||||
|
|
||||||
|
- adds global securityPolicy config with blocked IP and CIDR support to SmartProxy and RustProxy options
|
||||||
|
- introduces management IPC support to update the security policy at runtime via setSecurityPolicy
|
||||||
|
- enforces the global block list early for TCP, UDP, and QUIC traffic before route selection and backend handling
|
||||||
|
|
||||||
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
|
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
|
||||||
retain inactive per-IP metric buckets briefly to capture final throughput before pruning
|
retain inactive per-IP metric buckets briefly to capture final throughput before pruning
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@push.rocks/smartproxy",
|
"name": "@push.rocks/smartproxy",
|
||||||
"version": "27.8.2",
|
"version": "27.9.0",
|
||||||
"private": false,
|
"private": false,
|
||||||
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
|
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
|
||||||
"main": "dist_ts/index.js",
|
"main": "dist_ts/index.js",
|
||||||
|
|||||||
Generated
+1
@@ -1319,6 +1319,7 @@ dependencies = [
|
|||||||
"rustproxy-http",
|
"rustproxy-http",
|
||||||
"rustproxy-metrics",
|
"rustproxy-metrics",
|
||||||
"rustproxy-routing",
|
"rustproxy-routing",
|
||||||
|
"rustproxy-security",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"socket2 0.5.10",
|
"socket2 0.5.10",
|
||||||
|
|||||||
@@ -3,15 +3,15 @@
|
|||||||
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
|
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
|
||||||
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
|
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
|
||||||
|
|
||||||
pub mod route_types;
|
|
||||||
pub mod proxy_options;
|
pub mod proxy_options;
|
||||||
pub mod tls_types;
|
pub mod route_types;
|
||||||
pub mod security_types;
|
pub mod security_types;
|
||||||
|
pub mod tls_types;
|
||||||
pub mod validation;
|
pub mod validation;
|
||||||
|
|
||||||
// Re-export all primary types
|
// Re-export all primary types
|
||||||
pub use route_types::*;
|
|
||||||
pub use proxy_options::*;
|
pub use proxy_options::*;
|
||||||
pub use tls_types::*;
|
pub use route_types::*;
|
||||||
pub use security_types::*;
|
pub use security_types::*;
|
||||||
|
pub use tls_types::*;
|
||||||
pub use validation::*;
|
pub use validation::*;
|
||||||
|
|||||||
@@ -97,6 +97,16 @@ pub struct MetricsConfig {
|
|||||||
pub retention_seconds: Option<u64>,
|
pub retention_seconds: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Global ingress security policy.
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SecurityPolicy {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub blocked_ips: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub blocked_cidrs: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
/// RustProxy configuration options.
|
/// RustProxy configuration options.
|
||||||
/// Matches TypeScript: `ISmartProxyOptions`
|
/// Matches TypeScript: `ISmartProxyOptions`
|
||||||
///
|
///
|
||||||
@@ -235,6 +245,10 @@ pub struct RustProxyOptions {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub metrics: Option<MetricsConfig>,
|
pub metrics: Option<MetricsConfig>,
|
||||||
|
|
||||||
|
/// Global ingress security policy, enforced before route selection.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub security_policy: Option<SecurityPolicy>,
|
||||||
|
|
||||||
// ─── ACME ────────────────────────────────────────────────────────
|
// ─── ACME ────────────────────────────────────────────────────────
|
||||||
/// Global ACME configuration
|
/// Global ACME configuration
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
@@ -275,6 +289,7 @@ impl Default for RustProxyOptions {
|
|||||||
use_http_proxy: None,
|
use_http_proxy: None,
|
||||||
http_proxy_port: None,
|
http_proxy_port: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
|
security_policy: None,
|
||||||
acme: None,
|
acme: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -111,10 +111,7 @@ pub enum IpAllowEntry {
|
|||||||
/// Plain IP/CIDR — allowed for all domains on this route
|
/// Plain IP/CIDR — allowed for all domains on this route
|
||||||
Plain(String),
|
Plain(String),
|
||||||
/// Domain-scoped — allowed only when the requested domain matches
|
/// Domain-scoped — allowed only when the requested domain matches
|
||||||
DomainScoped {
|
DomainScoped { ip: String, domains: Vec<String> },
|
||||||
ip: String,
|
|
||||||
domains: Vec<String>,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Security options for routes.
|
/// Security options for routes.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::route_types::{RouteConfig, RouteActionType};
|
use crate::route_types::{RouteActionType, RouteConfig};
|
||||||
|
|
||||||
/// Validation errors for route configurations.
|
/// Validation errors for route configurations.
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
@@ -30,9 +30,10 @@ pub enum ValidationError {
|
|||||||
/// Validate a single route configuration.
|
/// Validate a single route configuration.
|
||||||
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
|
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
|
||||||
let mut errors = Vec::new();
|
let mut errors = Vec::new();
|
||||||
let name = route.name.clone().unwrap_or_else(|| {
|
let name = route
|
||||||
route.id.clone().unwrap_or_else(|| "unnamed".to_string())
|
.name
|
||||||
});
|
.clone()
|
||||||
|
.unwrap_or_else(|| route.id.clone().unwrap_or_else(|| "unnamed".to_string()));
|
||||||
|
|
||||||
// Check ports
|
// Check ports
|
||||||
let ports = route.listening_ports();
|
let ports = route.listening_ports();
|
||||||
@@ -160,7 +161,9 @@ mod tests {
|
|||||||
let mut route = make_valid_route();
|
let mut route = make_valid_route();
|
||||||
route.action.targets = None;
|
route.action.targets = None;
|
||||||
let errors = validate_route(&route).unwrap_err();
|
let errors = validate_route(&route).unwrap_err();
|
||||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. })));
|
assert!(errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, ValidationError::MissingTargets { .. })));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -168,7 +171,9 @@ mod tests {
|
|||||||
let mut route = make_valid_route();
|
let mut route = make_valid_route();
|
||||||
route.action.targets = Some(vec![]);
|
route.action.targets = Some(vec![]);
|
||||||
let errors = validate_route(&route).unwrap_err();
|
let errors = validate_route(&route).unwrap_err();
|
||||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
|
assert!(errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -176,7 +181,9 @@ mod tests {
|
|||||||
let mut route = make_valid_route();
|
let mut route = make_valid_route();
|
||||||
route.route_match.ports = PortRange::Single(0);
|
route.route_match.ports = PortRange::Single(0);
|
||||||
let errors = validate_route(&route).unwrap_err();
|
let errors = validate_route(&route).unwrap_err();
|
||||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
|
assert!(errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -186,7 +193,9 @@ mod tests {
|
|||||||
let mut r2 = make_valid_route();
|
let mut r2 = make_valid_route();
|
||||||
r2.id = Some("route-1".to_string());
|
r2.id = Some("route-1".to_string());
|
||||||
let errors = validate_routes(&[r1, r2]).unwrap_err();
|
let errors = validate_routes(&[r1, r2]).unwrap_err();
|
||||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. })));
|
assert!(errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, ValidationError::DuplicateId { .. })));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -2,10 +2,10 @@
|
|||||||
//!
|
//!
|
||||||
//! Metrics and throughput tracking for RustProxy.
|
//! Metrics and throughput tracking for RustProxy.
|
||||||
|
|
||||||
pub mod throughput;
|
|
||||||
pub mod collector;
|
pub mod collector;
|
||||||
pub mod log_dedup;
|
pub mod log_dedup;
|
||||||
|
pub mod throughput;
|
||||||
|
|
||||||
pub use throughput::*;
|
|
||||||
pub use collector::*;
|
pub use collector::*;
|
||||||
pub use log_dedup::*;
|
pub use log_dedup::*;
|
||||||
|
pub use throughput::*;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
@@ -47,7 +47,10 @@ impl LogDeduplicator {
|
|||||||
let map_key = format!("{}:{}", category, key);
|
let map_key = format!("{}:{}", category, key);
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
|
|
||||||
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
|
let entry = self
|
||||||
|
.events
|
||||||
|
.entry(map_key)
|
||||||
|
.or_insert_with(|| AggregatedEvent {
|
||||||
category: category.to_string(),
|
category: category.to_string(),
|
||||||
first_message: message.to_string(),
|
first_message: message.to_string(),
|
||||||
count: AtomicU64::new(0),
|
count: AtomicU64::new(0),
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
|
|||||||
rustproxy-config = { workspace = true }
|
rustproxy-config = { workspace = true }
|
||||||
rustproxy-routing = { workspace = true }
|
rustproxy-routing = { workspace = true }
|
||||||
rustproxy-metrics = { workspace = true }
|
rustproxy-metrics = { workspace = true }
|
||||||
|
rustproxy-security = { workspace = true }
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
|
|||||||
@@ -7,8 +7,8 @@
|
|||||||
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
@@ -73,7 +73,9 @@ impl ConnectionRegistry {
|
|||||||
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
|
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
|
||||||
let mut recycled = 0u64;
|
let mut recycled = 0u64;
|
||||||
self.connections.retain(|_, entry| {
|
self.connections.retain(|_, entry| {
|
||||||
let matches = entry.domain.as_deref()
|
let matches = entry
|
||||||
|
.domain
|
||||||
|
.as_deref()
|
||||||
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
|
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
if matches {
|
if matches {
|
||||||
@@ -100,7 +102,11 @@ impl ConnectionRegistry {
|
|||||||
let mut recycled = 0u64;
|
let mut recycled = 0u64;
|
||||||
self.connections.retain(|_, entry| {
|
self.connections.retain(|_, entry| {
|
||||||
if entry.route_id.as_deref() == Some(route_id) {
|
if entry.route_id.as_deref() == Some(route_id) {
|
||||||
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) {
|
if !RequestFilter::check_ip_security(
|
||||||
|
new_security,
|
||||||
|
&entry.source_ip,
|
||||||
|
entry.domain.as_deref(),
|
||||||
|
) {
|
||||||
info!(
|
info!(
|
||||||
"Terminating connection from {} — IP now blocked on route '{}'",
|
"Terminating connection from {} — IP now blocked on route '{}'",
|
||||||
entry.source_ip, route_id
|
entry.source_ip, route_id
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ impl ConnectionTracker {
|
|||||||
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
||||||
// Check per-IP connection limit
|
// Check per-IP connection limit
|
||||||
if let Some(max) = self.max_per_ip {
|
if let Some(max) = self.max_per_ip {
|
||||||
let count = self.active
|
let count = self
|
||||||
|
.active
|
||||||
.get(ip)
|
.get(ip)
|
||||||
.map(|c| c.value().load(Ordering::Relaxed))
|
.map(|c| c.value().load(Ordering::Relaxed))
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
@@ -48,7 +49,10 @@ impl ConnectionTracker {
|
|||||||
let timestamps = entry.value_mut();
|
let timestamps = entry.value_mut();
|
||||||
|
|
||||||
// Remove timestamps older than 1 minute
|
// Remove timestamps older than 1 minute
|
||||||
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
|
while timestamps
|
||||||
|
.front()
|
||||||
|
.is_some_and(|t| now.duration_since(*t) >= one_minute)
|
||||||
|
{
|
||||||
timestamps.pop_front();
|
timestamps.pop_front();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,7 +115,6 @@ impl ConnectionTracker {
|
|||||||
pub fn tracked_ips(&self) -> usize {
|
pub fn tracked_ips(&self) -> usize {
|
||||||
self.active.len()
|
self.active.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
use rustproxy_metrics::MetricsCollector;
|
use rustproxy_metrics::MetricsCollector;
|
||||||
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
|
|||||||
if let Some(data) = initial_data {
|
if let Some(data) = initial_data {
|
||||||
backend.write_all(data).await?;
|
backend.write_all(data).await?;
|
||||||
if let Some(ref ctx) = metrics {
|
if let Some(ref ctx) = metrics {
|
||||||
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
ctx.collector.record_bytes(
|
||||||
|
data.len() as u64,
|
||||||
|
0,
|
||||||
|
ctx.route_id.as_deref(),
|
||||||
|
ctx.source_ip.as_deref(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
|
|||||||
total += n as u64;
|
total += n as u64;
|
||||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||||
if let Some(ref ctx) = metrics_c2b {
|
if let Some(ref ctx) = metrics_c2b {
|
||||||
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
ctx.collector.record_bytes(
|
||||||
|
n as u64,
|
||||||
|
0,
|
||||||
|
ctx.route_id.as_deref(),
|
||||||
|
ctx.source_ip.as_deref(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
||||||
let _ = tokio::time::timeout(
|
let _ =
|
||||||
std::time::Duration::from_secs(2),
|
tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
|
||||||
backend_write.shutdown(),
|
|
||||||
).await;
|
|
||||||
total
|
total
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
|
|||||||
total += n as u64;
|
total += n as u64;
|
||||||
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||||
if let Some(ref ctx) = metrics_b2c {
|
if let Some(ref ctx) = metrics_b2c {
|
||||||
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
ctx.collector.record_bytes(
|
||||||
|
0,
|
||||||
|
n as u64,
|
||||||
|
ctx.route_id.as_deref(),
|
||||||
|
ctx.source_ip.as_deref(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
||||||
let _ = tokio::time::timeout(
|
let _ =
|
||||||
std::time::Duration::from_secs(2),
|
tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
|
||||||
client_write.shutdown(),
|
|
||||||
).await;
|
|
||||||
total
|
total
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -4,26 +4,26 @@
|
|||||||
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
|
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
|
||||||
//! and UDP datagram session tracking with forwarding.
|
//! and UDP datagram session tracking with forwarding.
|
||||||
|
|
||||||
pub mod tcp_listener;
|
pub mod connection_registry;
|
||||||
pub mod sni_parser;
|
pub mod connection_tracker;
|
||||||
pub mod forwarder;
|
pub mod forwarder;
|
||||||
pub mod proxy_protocol;
|
pub mod proxy_protocol;
|
||||||
pub mod tls_handler;
|
|
||||||
pub mod connection_tracker;
|
|
||||||
pub mod connection_registry;
|
|
||||||
pub mod socket_opts;
|
|
||||||
pub mod udp_session;
|
|
||||||
pub mod udp_listener;
|
|
||||||
pub mod quic_handler;
|
pub mod quic_handler;
|
||||||
|
pub mod sni_parser;
|
||||||
|
pub mod socket_opts;
|
||||||
|
pub mod tcp_listener;
|
||||||
|
pub mod tls_handler;
|
||||||
|
pub mod udp_listener;
|
||||||
|
pub mod udp_session;
|
||||||
|
|
||||||
pub use tcp_listener::*;
|
pub use connection_registry::*;
|
||||||
pub use sni_parser::*;
|
pub use connection_tracker::*;
|
||||||
pub use forwarder::*;
|
pub use forwarder::*;
|
||||||
pub use proxy_protocol::*;
|
pub use proxy_protocol::*;
|
||||||
pub use tls_handler::*;
|
|
||||||
pub use connection_tracker::*;
|
|
||||||
pub use connection_registry::*;
|
|
||||||
pub use socket_opts::*;
|
|
||||||
pub use udp_session::*;
|
|
||||||
pub use udp_listener::*;
|
|
||||||
pub use quic_handler::*;
|
pub use quic_handler::*;
|
||||||
|
pub use sni_parser::*;
|
||||||
|
pub use socket_opts::*;
|
||||||
|
pub use tcp_listener::*;
|
||||||
|
pub use tls_handler::*;
|
||||||
|
pub use udp_listener::*;
|
||||||
|
pub use udp_session::*;
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
|||||||
.position(|w| w == b"\r\n")
|
.position(|w| w == b"\r\n")
|
||||||
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
||||||
|
|
||||||
let line = std::str::from_utf8(&data[..line_end])
|
let line =
|
||||||
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||||
|
|
||||||
if !line.starts_with("PROXY ") {
|
if !line.starts_with("PROXY ") {
|
||||||
return Err(ProxyProtocolError::InvalidHeader);
|
return Err(ProxyProtocolError::InvalidHeader);
|
||||||
@@ -148,7 +148,10 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
|||||||
let command = data[12] & 0x0F;
|
let command = data[12] & 0x0F;
|
||||||
// 0x0 = LOCAL, 0x1 = PROXY
|
// 0x0 = LOCAL, 0x1 = PROXY
|
||||||
if command > 1 {
|
if command > 1 {
|
||||||
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command)));
|
return Err(ProxyProtocolError::Parse(format!(
|
||||||
|
"Unknown command: {}",
|
||||||
|
command
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address family (high nibble) + transport (low nibble) of byte 13
|
// Address family (high nibble) + transport (low nibble) of byte 13
|
||||||
@@ -182,7 +185,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
|||||||
// AF_INET (0x1) + STREAM (0x1) = TCP4
|
// AF_INET (0x1) + STREAM (0x1) = TCP4
|
||||||
(0x1, 0x1) => {
|
(0x1, 0x1) => {
|
||||||
if addr_len < 12 {
|
if addr_len < 12 {
|
||||||
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
return Err(ProxyProtocolError::Parse(
|
||||||
|
"IPv4 address block too short".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
||||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||||
@@ -200,7 +205,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
|||||||
// AF_INET (0x1) + DGRAM (0x2) = UDP4
|
// AF_INET (0x1) + DGRAM (0x2) = UDP4
|
||||||
(0x1, 0x2) => {
|
(0x1, 0x2) => {
|
||||||
if addr_len < 12 {
|
if addr_len < 12 {
|
||||||
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
return Err(ProxyProtocolError::Parse(
|
||||||
|
"IPv4 address block too short".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
||||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||||
@@ -218,7 +225,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
|||||||
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
|
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
|
||||||
(0x2, 0x1) => {
|
(0x2, 0x1) => {
|
||||||
if addr_len < 36 {
|
if addr_len < 36 {
|
||||||
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
return Err(ProxyProtocolError::Parse(
|
||||||
|
"IPv6 address block too short".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
||||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||||
@@ -236,7 +245,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
|||||||
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
|
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
|
||||||
(0x2, 0x2) => {
|
(0x2, 0x2) => {
|
||||||
if addr_len < 36 {
|
if addr_len < 36 {
|
||||||
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
return Err(ProxyProtocolError::Parse(
|
||||||
|
"IPv6 address block too short".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
||||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||||
@@ -268,11 +279,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a PROXY protocol v2 binary header.
|
/// Generate a PROXY protocol v2 binary header.
|
||||||
pub fn generate_v2(
|
pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
|
||||||
source: &SocketAddr,
|
|
||||||
dest: &SocketAddr,
|
|
||||||
transport: ProxyV2Transport,
|
|
||||||
) -> Vec<u8> {
|
|
||||||
let transport_nibble: u8 = match transport {
|
let transport_nibble: u8 = match transport {
|
||||||
ProxyV2Transport::Stream => 0x1,
|
ProxyV2Transport::Stream => 0x1,
|
||||||
ProxyV2Transport::Datagram => 0x2,
|
ProxyV2Transport::Datagram => 0x2,
|
||||||
@@ -462,7 +469,10 @@ mod tests {
|
|||||||
header.push(0x11);
|
header.push(0x11);
|
||||||
header.extend_from_slice(&12u16.to_be_bytes());
|
header.extend_from_slice(&12u16.to_be_bytes());
|
||||||
header.extend_from_slice(&[0u8; 12]);
|
header.extend_from_slice(&[0u8; 12]);
|
||||||
assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion)));
|
assert!(matches!(
|
||||||
|
parse_v2(&header),
|
||||||
|
Err(ProxyProtocolError::UnsupportedVersion)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -26,11 +26,12 @@ use tracing::{debug, info, warn};
|
|||||||
use rustproxy_config::{RouteConfig, TransportProtocol};
|
use rustproxy_config::{RouteConfig, TransportProtocol};
|
||||||
use rustproxy_metrics::MetricsCollector;
|
use rustproxy_metrics::MetricsCollector;
|
||||||
use rustproxy_routing::{MatchContext, RouteManager};
|
use rustproxy_routing::{MatchContext, RouteManager};
|
||||||
|
use rustproxy_security::IpBlockList;
|
||||||
|
|
||||||
use rustproxy_http::h3_service::H3ProxyService;
|
use rustproxy_http::h3_service::H3ProxyService;
|
||||||
|
|
||||||
use crate::connection_tracker::ConnectionTracker;
|
|
||||||
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
|
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
|
||||||
|
use crate::connection_tracker::ConnectionTracker;
|
||||||
|
|
||||||
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
||||||
///
|
///
|
||||||
@@ -48,8 +49,7 @@ pub fn create_quic_endpoint(
|
|||||||
quinn::EndpointConfig::default(),
|
quinn::EndpointConfig::default(),
|
||||||
Some(server_config),
|
Some(server_config),
|
||||||
socket,
|
socket,
|
||||||
quinn::default_runtime()
|
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
info!("QUIC endpoint listening on port {}", port);
|
info!("QUIC endpoint listening on port {}", port);
|
||||||
@@ -97,6 +97,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
|||||||
port: u16,
|
port: u16,
|
||||||
tls_config: Arc<RustlsServerConfig>,
|
tls_config: Arc<RustlsServerConfig>,
|
||||||
proxy_ips: Arc<Vec<IpAddr>>,
|
proxy_ips: Arc<Vec<IpAddr>>,
|
||||||
|
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
) -> anyhow::Result<QuicProxyRelay> {
|
) -> anyhow::Result<QuicProxyRelay> {
|
||||||
// Bind external socket on the real port
|
// Bind external socket on the real port
|
||||||
@@ -119,8 +120,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
|||||||
quinn::EndpointConfig::default(),
|
quinn::EndpointConfig::default(),
|
||||||
Some(server_config),
|
Some(server_config),
|
||||||
internal_socket,
|
internal_socket,
|
||||||
quinn::default_runtime()
|
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let real_client_map = Arc::new(DashMap::new());
|
let real_client_map = Arc::new(DashMap::new());
|
||||||
@@ -129,12 +129,20 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
|||||||
external_socket,
|
external_socket,
|
||||||
quinn_internal_addr,
|
quinn_internal_addr,
|
||||||
proxy_ips,
|
proxy_ips,
|
||||||
|
security_policy,
|
||||||
Arc::clone(&real_client_map),
|
Arc::clone(&real_client_map),
|
||||||
cancel,
|
cancel,
|
||||||
));
|
));
|
||||||
|
|
||||||
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
|
info!(
|
||||||
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
|
"QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
|
||||||
|
port, quinn_internal_addr
|
||||||
|
);
|
||||||
|
Ok(QuicProxyRelay {
|
||||||
|
endpoint,
|
||||||
|
relay_task,
|
||||||
|
real_client_map,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
|
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
|
||||||
@@ -144,6 +152,7 @@ async fn quic_proxy_relay_loop(
|
|||||||
external_socket: Arc<UdpSocket>,
|
external_socket: Arc<UdpSocket>,
|
||||||
quinn_internal_addr: SocketAddr,
|
quinn_internal_addr: SocketAddr,
|
||||||
proxy_ips: Arc<Vec<IpAddr>>,
|
proxy_ips: Arc<Vec<IpAddr>>,
|
||||||
|
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||||
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
) {
|
) {
|
||||||
@@ -184,26 +193,43 @@ async fn quic_proxy_relay_loop(
|
|||||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||||
match crate::proxy_protocol::parse_v2(datagram) {
|
match crate::proxy_protocol::parse_v2(datagram) {
|
||||||
Ok((header, _consumed)) => {
|
Ok((header, _consumed)) => {
|
||||||
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
|
debug!(
|
||||||
|
"QUIC PROXY v2 from {}: real client {}",
|
||||||
|
src_addr, header.source_addr
|
||||||
|
);
|
||||||
proxy_addr_map.insert(src_addr, header.source_addr);
|
proxy_addr_map.insert(src_addr, header.source_addr);
|
||||||
continue; // consume the PROXY v2 datagram
|
continue; // consume the PROXY v2 datagram
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
|
debug!(
|
||||||
|
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
|
||||||
|
src_addr, e
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine real client address
|
// Determine real client address
|
||||||
let real_client = proxy_addr_map.get(&src_addr)
|
let real_client = proxy_addr_map
|
||||||
|
.get(&src_addr)
|
||||||
.map(|r| *r)
|
.map(|r| *r)
|
||||||
.unwrap_or(src_addr);
|
.unwrap_or(src_addr);
|
||||||
|
|
||||||
|
let block_list = security_policy.load();
|
||||||
|
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
|
||||||
|
debug!(
|
||||||
|
"QUIC datagram from {} blocked by global security policy",
|
||||||
|
real_client
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Get or create relay session for this external source
|
// Get or create relay session for this external source
|
||||||
let session = match relay_sessions.get(&src_addr) {
|
let session = match relay_sessions.get(&src_addr) {
|
||||||
Some(s) => {
|
Some(s) => {
|
||||||
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
s.last_activity
|
||||||
|
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||||
Arc::clone(s.value())
|
Arc::clone(s.value())
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
@@ -216,7 +242,10 @@ async fn quic_proxy_relay_loop(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
|
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
|
||||||
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
|
warn!(
|
||||||
|
"QUIC relay: failed to connect relay socket to {}: {}",
|
||||||
|
quinn_internal_addr, e
|
||||||
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let relay_local_addr = match relay_socket.local_addr() {
|
let relay_local_addr = match relay_socket.local_addr() {
|
||||||
@@ -248,8 +277,10 @@ async fn quic_proxy_relay_loop(
|
|||||||
});
|
});
|
||||||
|
|
||||||
relay_sessions.insert(src_addr, Arc::clone(&session));
|
relay_sessions.insert(src_addr, Arc::clone(&session));
|
||||||
debug!("QUIC relay: new session for {} (relay {}), real client {}",
|
debug!(
|
||||||
src_addr, relay_local_addr, real_client);
|
"QUIC relay: new session for {} (relay {}), real client {}",
|
||||||
|
src_addr, relay_local_addr, real_client
|
||||||
|
);
|
||||||
|
|
||||||
session
|
session
|
||||||
}
|
}
|
||||||
@@ -264,9 +295,11 @@ async fn quic_proxy_relay_loop(
|
|||||||
if last_cleanup.elapsed() >= cleanup_interval {
|
if last_cleanup.elapsed() >= cleanup_interval {
|
||||||
last_cleanup = Instant::now();
|
last_cleanup = Instant::now();
|
||||||
let now_ms = epoch.elapsed().as_millis() as u64;
|
let now_ms = epoch.elapsed().as_millis() as u64;
|
||||||
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
|
let stale_keys: Vec<SocketAddr> = relay_sessions
|
||||||
|
.iter()
|
||||||
.filter(|entry| {
|
.filter(|entry| {
|
||||||
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
let age =
|
||||||
|
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
||||||
age > session_timeout_ms
|
age > session_timeout_ms
|
||||||
})
|
})
|
||||||
.map(|entry| *entry.key())
|
.map(|entry| *entry.key())
|
||||||
@@ -287,13 +320,17 @@ async fn quic_proxy_relay_loop(
|
|||||||
|
|
||||||
// Also clean orphaned proxy_addr_map entries (PROXY header received
|
// Also clean orphaned proxy_addr_map entries (PROXY header received
|
||||||
// but no relay session was ever created, e.g. client never sent data)
|
// but no relay session was ever created, e.g. client never sent data)
|
||||||
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
|
let orphaned: Vec<SocketAddr> = proxy_addr_map
|
||||||
|
.iter()
|
||||||
.filter(|entry| relay_sessions.get(entry.key()).is_none())
|
.filter(|entry| relay_sessions.get(entry.key()).is_none())
|
||||||
.map(|entry| *entry.key())
|
.map(|entry| *entry.key())
|
||||||
.collect();
|
.collect();
|
||||||
for key in orphaned {
|
for key in orphaned {
|
||||||
proxy_addr_map.remove(&key);
|
proxy_addr_map.remove(&key);
|
||||||
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
|
debug!(
|
||||||
|
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
|
||||||
|
key
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -328,8 +365,14 @@ async fn relay_return_path(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
|
if let Err(e) = external_socket
|
||||||
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
|
.send_to(&buf[..len], external_src_addr)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
debug!(
|
||||||
|
"QUIC relay return send error to {}: {}",
|
||||||
|
external_src_addr, e
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -353,6 +396,7 @@ pub async fn quic_accept_loop(
|
|||||||
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
|
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
|
||||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||||
connection_registry: Arc<ConnectionRegistry>,
|
connection_registry: Arc<ConnectionRegistry>,
|
||||||
|
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||||
) {
|
) {
|
||||||
loop {
|
loop {
|
||||||
let incoming = tokio::select! {
|
let incoming = tokio::select! {
|
||||||
@@ -374,11 +418,21 @@ pub async fn quic_accept_loop(
|
|||||||
let remote_addr = incoming.remote_address();
|
let remote_addr = incoming.remote_address();
|
||||||
|
|
||||||
// Resolve real client IP from PROXY protocol map if available
|
// Resolve real client IP from PROXY protocol map if available
|
||||||
let real_addr = real_client_map.as_ref()
|
let real_addr = real_client_map
|
||||||
|
.as_ref()
|
||||||
.and_then(|map| map.get(&remote_addr).map(|r| *r))
|
.and_then(|map| map.get(&remote_addr).map(|r| *r))
|
||||||
.unwrap_or(remote_addr);
|
.unwrap_or(remote_addr);
|
||||||
let ip = real_addr.ip();
|
let ip = real_addr.ip();
|
||||||
|
|
||||||
|
let block_list = security_policy.load();
|
||||||
|
if !block_list.is_empty() && block_list.is_blocked(&ip) {
|
||||||
|
debug!(
|
||||||
|
"QUIC connection from {} blocked by global security policy",
|
||||||
|
real_addr
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Per-IP rate limiting
|
// Per-IP rate limiting
|
||||||
if !conn_tracker.try_accept(&ip) {
|
if !conn_tracker.try_accept(&ip) {
|
||||||
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
|
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
|
||||||
@@ -414,7 +468,10 @@ pub async fn quic_accept_loop(
|
|||||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||||
security, &ip, ctx.domain,
|
security, &ip, ctx.domain,
|
||||||
) {
|
) {
|
||||||
debug!("QUIC connection from {} blocked by route security", real_addr);
|
debug!(
|
||||||
|
"QUIC connection from {} blocked by route security",
|
||||||
|
real_addr
|
||||||
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -425,7 +482,8 @@ pub async fn quic_accept_loop(
|
|||||||
|
|
||||||
// Resolve per-route cancel token (child of global cancel)
|
// Resolve per-route cancel token (child of global cancel)
|
||||||
let route_cancel = match route_id.as_deref() {
|
let route_cancel = match route_id.as_deref() {
|
||||||
Some(id) => route_cancels.entry(id.to_string())
|
Some(id) => route_cancels
|
||||||
|
.entry(id.to_string())
|
||||||
.or_insert_with(|| cancel.child_token())
|
.or_insert_with(|| cancel.child_token())
|
||||||
.clone(),
|
.clone(),
|
||||||
None => cancel.child_token(),
|
None => cancel.child_token(),
|
||||||
@@ -445,7 +503,11 @@ pub async fn quic_accept_loop(
|
|||||||
let metrics = Arc::clone(&metrics);
|
let metrics = Arc::clone(&metrics);
|
||||||
let conn_tracker = Arc::clone(&conn_tracker);
|
let conn_tracker = Arc::clone(&conn_tracker);
|
||||||
let h3_svc = h3_service.clone();
|
let h3_svc = h3_service.clone();
|
||||||
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
|
let real_client_addr = if real_addr != remote_addr {
|
||||||
|
Some(real_addr)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// Register in connection registry (RAII guard removes on drop)
|
// Register in connection registry (RAII guard removes on drop)
|
||||||
@@ -462,7 +524,8 @@ pub async fn quic_accept_loop(
|
|||||||
impl Drop for QuicConnGuard {
|
impl Drop for QuicConnGuard {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.tracker.connection_closed(&self.ip);
|
self.tracker.connection_closed(&self.ip);
|
||||||
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
self.metrics
|
||||||
|
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let _guard = QuicConnGuard {
|
let _guard = QuicConnGuard {
|
||||||
@@ -473,7 +536,17 @@ pub async fn quic_accept_loop(
|
|||||||
route_id,
|
route_id,
|
||||||
};
|
};
|
||||||
|
|
||||||
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &conn_cancel, h3_svc, real_client_addr).await {
|
match handle_quic_connection(
|
||||||
|
incoming,
|
||||||
|
route,
|
||||||
|
port,
|
||||||
|
Arc::clone(&metrics),
|
||||||
|
&conn_cancel,
|
||||||
|
h3_svc,
|
||||||
|
real_client_addr,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
|
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
|
||||||
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
|
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
|
||||||
}
|
}
|
||||||
@@ -501,17 +574,28 @@ async fn handle_quic_connection(
|
|||||||
debug!("QUIC connection established from {}", effective_addr);
|
debug!("QUIC connection established from {}", effective_addr);
|
||||||
|
|
||||||
// Check if this route has HTTP/3 enabled
|
// Check if this route has HTTP/3 enabled
|
||||||
let enable_http3 = route.action.udp.as_ref()
|
let enable_http3 = route
|
||||||
|
.action
|
||||||
|
.udp
|
||||||
|
.as_ref()
|
||||||
.and_then(|u| u.quic.as_ref())
|
.and_then(|u| u.quic.as_ref())
|
||||||
.and_then(|q| q.enable_http3)
|
.and_then(|q| q.enable_http3)
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
if enable_http3 {
|
if enable_http3 {
|
||||||
if let Some(ref h3_svc) = h3_service {
|
if let Some(ref h3_svc) = h3_service {
|
||||||
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
|
debug!(
|
||||||
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
|
"HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
|
||||||
|
route.name
|
||||||
|
);
|
||||||
|
h3_svc
|
||||||
|
.handle_connection(connection, &route, port, real_client_addr, cancel)
|
||||||
|
.await
|
||||||
} else {
|
} else {
|
||||||
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
|
warn!(
|
||||||
|
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
|
||||||
|
route.name
|
||||||
|
);
|
||||||
// Keep connection alive until cancelled
|
// Keep connection alive until cancelled
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = cancel.cancelled() => {}
|
_ = cancel.cancelled() => {}
|
||||||
@@ -523,7 +607,8 @@ async fn handle_quic_connection(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
||||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
|
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -545,7 +630,10 @@ async fn handle_quic_stream_forwarding(
|
|||||||
let metrics_arc = metrics;
|
let metrics_arc = metrics;
|
||||||
|
|
||||||
// Resolve backend target
|
// Resolve backend target
|
||||||
let target = route.action.targets.as_ref()
|
let target = route
|
||||||
|
.action
|
||||||
|
.targets
|
||||||
|
.as_ref()
|
||||||
.and_then(|t| t.first())
|
.and_then(|t| t.first())
|
||||||
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
|
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
|
||||||
let backend_host = target.host.first();
|
let backend_host = target.host.first();
|
||||||
@@ -576,19 +664,20 @@ async fn handle_quic_stream_forwarding(
|
|||||||
|
|
||||||
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match forward_quic_stream_to_tcp(
|
match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
|
||||||
send_stream,
|
.await
|
||||||
recv_stream,
|
{
|
||||||
&backend_addr,
|
|
||||||
stream_cancel,
|
|
||||||
).await {
|
|
||||||
Ok((bytes_in, bytes_out)) => {
|
Ok((bytes_in, bytes_out)) => {
|
||||||
stream_metrics.record_bytes(
|
stream_metrics.record_bytes(
|
||||||
bytes_in, bytes_out,
|
bytes_in,
|
||||||
|
bytes_out,
|
||||||
stream_route_id.as_deref(),
|
stream_route_id.as_deref(),
|
||||||
Some(&ip_str),
|
Some(&ip_str),
|
||||||
);
|
);
|
||||||
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
|
debug!(
|
||||||
|
"QUIC stream forwarded: {}B in, {}B out",
|
||||||
|
bytes_in, bytes_out
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("QUIC stream forwarding error: {}", e);
|
debug!("QUIC stream forwarding error: {}", e);
|
||||||
@@ -640,10 +729,7 @@ async fn forward_quic_stream_to_tcp(
|
|||||||
total += n as u64;
|
total += n as u64;
|
||||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
let _ = tokio::time::timeout(
|
let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
|
||||||
std::time::Duration::from_secs(2),
|
|
||||||
tcp_write.shutdown(),
|
|
||||||
).await;
|
|
||||||
total
|
total
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -721,8 +807,8 @@ mod tests {
|
|||||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
|
|
||||||
// Generate a single self-signed cert and use its key pair
|
// Generate a single self-signed cert and use its key pair
|
||||||
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
|
let self_signed =
|
||||||
.unwrap();
|
rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
|
||||||
let cert_der = self_signed.cert.der().clone();
|
let cert_der = self_signed.cert.der().clone();
|
||||||
let key_der = self_signed.key_pair.serialize_der();
|
let key_der = self_signed.key_pair.serialize_der();
|
||||||
|
|
||||||
@@ -737,6 +823,10 @@ mod tests {
|
|||||||
|
|
||||||
// Port 0 = OS assigns a free port
|
// Port 0 = OS assigns a free port
|
||||||
let result = create_quic_endpoint(0, Arc::new(tls_config));
|
let result = create_quic_endpoint(0, Arc::new(tls_config));
|
||||||
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"QUIC endpoint creation failed: {:?}",
|
||||||
|
result.err()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handshake length (3 bytes) - informational, we parse incrementally
|
// Handshake length (3 bytes) - informational, we parse incrementally
|
||||||
let _handshake_len = ((data[6] as usize) << 16)
|
let _handshake_len =
|
||||||
| ((data[7] as usize) << 8)
|
((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
|
||||||
| (data[8] as usize);
|
|
||||||
|
|
||||||
let hello = &data[9..];
|
let hello = &data[9..];
|
||||||
|
|
||||||
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
|
|||||||
pub fn extract_http_host(data: &[u8]) -> Option<String> {
|
pub fn extract_http_host(data: &[u8]) -> Option<String> {
|
||||||
let text = std::str::from_utf8(data).ok()?;
|
let text = std::str::from_utf8(data).ok()?;
|
||||||
for line in text.split("\r\n") {
|
for line in text.split("\r\n") {
|
||||||
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
|
if let Some(value) = line
|
||||||
|
.strip_prefix("Host: ")
|
||||||
|
.or_else(|| line.strip_prefix("host: "))
|
||||||
|
{
|
||||||
// Strip port if present
|
// Strip port if present
|
||||||
let host = value.split(':').next().unwrap_or(value).trim();
|
let host = value.split(':').next().unwrap_or(value).trim();
|
||||||
if !host.is_empty() {
|
if !host.is_empty() {
|
||||||
@@ -213,7 +215,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_too_short() {
|
fn test_too_short() {
|
||||||
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
|
assert!(matches!(
|
||||||
|
extract_sni(&[0x16, 0x03]),
|
||||||
|
SniResult::NeedMoreData
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -263,7 +268,8 @@ mod tests {
|
|||||||
// Extension: type=0x0000 (SNI), length, data
|
// Extension: type=0x0000 (SNI), length, data
|
||||||
let sni_extension = {
|
let sni_extension = {
|
||||||
let mut e = Vec::new();
|
let mut e = Vec::new();
|
||||||
e.push(0x00); e.push(0x00); // SNI type
|
e.push(0x00);
|
||||||
|
e.push(0x00); // SNI type
|
||||||
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
|
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
|
||||||
e.push((sni_ext_data.len() & 0xFF) as u8);
|
e.push((sni_ext_data.len() & 0xFF) as u8);
|
||||||
e.extend_from_slice(&sni_ext_data);
|
e.extend_from_slice(&sni_ext_data);
|
||||||
@@ -283,16 +289,20 @@ mod tests {
|
|||||||
let hello_body = {
|
let hello_body = {
|
||||||
let mut h = Vec::new();
|
let mut h = Vec::new();
|
||||||
// Client version TLS 1.2
|
// Client version TLS 1.2
|
||||||
h.push(0x03); h.push(0x03);
|
h.push(0x03);
|
||||||
|
h.push(0x03);
|
||||||
// Random (32 bytes)
|
// Random (32 bytes)
|
||||||
h.extend_from_slice(&[0u8; 32]);
|
h.extend_from_slice(&[0u8; 32]);
|
||||||
// Session ID length = 0
|
// Session ID length = 0
|
||||||
h.push(0x00);
|
h.push(0x00);
|
||||||
// Cipher suites: length=2, one suite
|
// Cipher suites: length=2, one suite
|
||||||
h.push(0x00); h.push(0x02);
|
h.push(0x00);
|
||||||
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
h.push(0x02);
|
||||||
|
h.push(0x00);
|
||||||
|
h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||||
// Compression methods: length=1, null
|
// Compression methods: length=1, null
|
||||||
h.push(0x01); h.push(0x00);
|
h.push(0x01);
|
||||||
|
h.push(0x00);
|
||||||
// Extensions
|
// Extensions
|
||||||
h.extend_from_slice(&extensions);
|
h.extend_from_slice(&extensions);
|
||||||
h
|
h
|
||||||
@@ -313,7 +323,8 @@ mod tests {
|
|||||||
// TLS record: type=0x16, version TLS 1.0, length
|
// TLS record: type=0x16, version TLS 1.0, length
|
||||||
let mut record = Vec::new();
|
let mut record = Vec::new();
|
||||||
record.push(0x16); // Handshake
|
record.push(0x16); // Handshake
|
||||||
record.push(0x03); record.push(0x01); // TLS 1.0
|
record.push(0x03);
|
||||||
|
record.push(0x01); // TLS 1.0
|
||||||
record.push(((handshake.len() >> 8) & 0xFF) as u8);
|
record.push(((handshake.len() >> 8) & 0xFF) as u8);
|
||||||
record.push((handshake.len() & 0xFF) as u8);
|
record.push((handshake.len() & 0xFF) as u8);
|
||||||
record.extend_from_slice(&handshake);
|
record.extend_from_slice(&handshake);
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
|
|||||||
use rustls::sign::CertifiedKey;
|
use rustls::sign::CertifiedKey;
|
||||||
use rustls::ServerConfig;
|
use rustls::ServerConfig;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
|
use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
|
||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
use crate::tcp_listener::TlsCertConfig;
|
use crate::tcp_listener::TlsCertConfig;
|
||||||
@@ -29,7 +29,9 @@ pub struct CertResolver {
|
|||||||
impl CertResolver {
|
impl CertResolver {
|
||||||
/// Build a resolver from PEM-encoded cert/key configs.
|
/// Build a resolver from PEM-encoded cert/key configs.
|
||||||
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
|
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
|
||||||
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
pub fn new(
|
||||||
|
configs: &HashMap<String, TlsCertConfig>,
|
||||||
|
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
ensure_crypto_provider();
|
ensure_crypto_provider();
|
||||||
let provider = rustls::crypto::ring::default_provider();
|
let provider = rustls::crypto::ring::default_provider();
|
||||||
let mut certs = HashMap::new();
|
let mut certs = HashMap::new();
|
||||||
@@ -38,8 +40,10 @@ impl CertResolver {
|
|||||||
for (domain, cfg) in configs {
|
for (domain, cfg) in configs {
|
||||||
let cert_chain = load_certs(&cfg.cert_pem)?;
|
let cert_chain = load_certs(&cfg.cert_pem)?;
|
||||||
let key = load_private_key(&cfg.key_pem)?;
|
let key = load_private_key(&cfg.key_pem)?;
|
||||||
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
|
let ck = Arc::new(
|
||||||
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
|
CertifiedKey::from_der(cert_chain, key, &provider)
|
||||||
|
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
|
||||||
|
);
|
||||||
if domain == "*" {
|
if domain == "*" {
|
||||||
fallback = Some(Arc::clone(&ck));
|
fallback = Some(Arc::clone(&ck));
|
||||||
}
|
}
|
||||||
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
|
|||||||
|
|
||||||
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
|
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
|
||||||
/// The returned acceptor can be reused across all connections (cheap Arc clone).
|
/// The returned acceptor can be reused across all connections (cheap Arc clone).
|
||||||
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
pub fn build_shared_tls_acceptor(
|
||||||
|
resolver: CertResolver,
|
||||||
|
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
ensure_crypto_provider();
|
ensure_crypto_provider();
|
||||||
let mut config = ServerConfig::builder()
|
let mut config = ServerConfig::builder()
|
||||||
.with_no_client_auth()
|
.with_no_client_auth()
|
||||||
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
|
|||||||
// Shared session cache — enables session ID resumption across connections
|
// Shared session cache — enables session ID resumption across connections
|
||||||
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
|
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
|
||||||
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
|
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
|
||||||
config.ticketer = rustls::crypto::ring::Ticketer::new()
|
config.ticketer =
|
||||||
.map_err(|e| format!("Ticketer: {}", e))?;
|
rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
|
||||||
|
|
||||||
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1");
|
info!(
|
||||||
|
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
|
||||||
|
);
|
||||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a TLS acceptor from PEM-encoded cert and key data.
|
/// Build a TLS acceptor from PEM-encoded cert and key data.
|
||||||
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
|
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
|
||||||
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
pub fn build_tls_acceptor(
|
||||||
|
cert_pem: &str,
|
||||||
|
key_pem: &str,
|
||||||
|
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
build_tls_acceptor_with_config(cert_pem, key_pem, None)
|
build_tls_acceptor_with_config(cert_pem, key_pem, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
|
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
|
||||||
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
|
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
|
||||||
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
pub fn build_tls_acceptor_h1_only(
|
||||||
|
cert_pem: &str,
|
||||||
|
key_pem: &str,
|
||||||
|
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
ensure_crypto_provider();
|
ensure_crypto_provider();
|
||||||
let certs = load_certs(cert_pem)?;
|
let certs = load_certs(cert_pem)?;
|
||||||
let key = load_private_key(key_pem)?;
|
let key = load_private_key(key_pem)?;
|
||||||
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
|
|||||||
// Apply TLS version restrictions
|
// Apply TLS version restrictions
|
||||||
let versions = resolve_tls_versions(route_tls.versions.as_deref());
|
let versions = resolve_tls_versions(route_tls.versions.as_deref());
|
||||||
let builder = ServerConfig::builder_with_protocol_versions(&versions);
|
let builder = ServerConfig::builder_with_protocol_versions(&versions);
|
||||||
builder
|
builder.with_no_client_auth().with_single_cert(certs, key)?
|
||||||
.with_no_client_auth()
|
|
||||||
.with_single_cert(certs, key)?
|
|
||||||
} else {
|
} else {
|
||||||
ServerConfig::builder()
|
ServerConfig::builder()
|
||||||
.with_no_client_auth()
|
.with_no_client_auth()
|
||||||
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
|
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
|
||||||
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
fn resolve_tls_versions(
|
||||||
|
versions: Option<&[String]>,
|
||||||
|
) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||||
let versions = match versions {
|
let versions = match versions {
|
||||||
Some(v) if !v.is_empty() => v,
|
Some(v) if !v.is_empty() => v,
|
||||||
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
|
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
|
||||||
@@ -207,7 +221,8 @@ pub async fn accept_tls(
|
|||||||
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
||||||
|
|
||||||
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
||||||
SHARED_CLIENT_CONFIG.get_or_init(|| {
|
SHARED_CLIENT_CONFIG
|
||||||
|
.get_or_init(|| {
|
||||||
ensure_crypto_provider();
|
ensure_crypto_provider();
|
||||||
let config = rustls::ClientConfig::builder()
|
let config = rustls::ClientConfig::builder()
|
||||||
.dangerous()
|
.dangerous()
|
||||||
@@ -215,7 +230,8 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
|||||||
.with_no_client_auth();
|
.with_no_client_auth();
|
||||||
info!("Built shared backend TLS client config with session resumption");
|
info!("Built shared backend TLS client config with session resumption");
|
||||||
Arc::new(config)
|
Arc::new(config)
|
||||||
}).clone()
|
})
|
||||||
|
.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
|
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
|
||||||
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
|||||||
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
||||||
|
|
||||||
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
|
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
|
||||||
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| {
|
SHARED_CLIENT_CONFIG_ALPN
|
||||||
|
.get_or_init(|| {
|
||||||
ensure_crypto_provider();
|
ensure_crypto_provider();
|
||||||
let mut config = rustls::ClientConfig::builder()
|
let mut config = rustls::ClientConfig::builder()
|
||||||
.dangerous()
|
.dangerous()
|
||||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||||
.with_no_client_auth();
|
.with_no_client_auth();
|
||||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||||
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection");
|
info!(
|
||||||
|
"Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
|
||||||
|
);
|
||||||
Arc::new(config)
|
Arc::new(config)
|
||||||
}).clone()
|
})
|
||||||
|
.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
|
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
|
||||||
@@ -249,7 +269,8 @@ pub async fn connect_tls(
|
|||||||
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
||||||
stream.set_nodelay(true)?;
|
stream.set_nodelay(true)?;
|
||||||
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
|
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
|
||||||
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) {
|
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
|
||||||
|
{
|
||||||
debug!("Failed to set keepalive on backend TLS socket: {}", e);
|
debug!("Failed to set keepalive on backend TLS socket: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,10 +281,12 @@ pub async fn connect_tls(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Load certificates from PEM string.
|
/// Load certificates from PEM string.
|
||||||
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
fn load_certs(
|
||||||
|
pem: &str,
|
||||||
|
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut reader = BufReader::new(pem.as_bytes());
|
let mut reader = BufReader::new(pem.as_bytes());
|
||||||
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
|
let certs: Vec<CertificateDer<'static>> =
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
|
||||||
if certs.is_empty() {
|
if certs.is_empty() {
|
||||||
return Err("No certificates found in PEM data".into());
|
return Err("No certificates found in PEM data".into());
|
||||||
}
|
}
|
||||||
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Load private key from PEM string.
|
/// Load private key from PEM string.
|
||||||
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
fn load_private_key(
|
||||||
|
pem: &str,
|
||||||
|
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut reader = BufReader::new(pem.as_bytes());
|
let mut reader = BufReader::new(pem.as_bytes());
|
||||||
// Try PKCS8 first, then RSA, then EC
|
// Try PKCS8 first, then RSA, then EC
|
||||||
let key = rustls_pemfile::private_key(&mut reader)?
|
let key =
|
||||||
.ok_or("No private key found in PEM data")?;
|
rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
|
||||||
Ok(key)
|
Ok(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,14 +17,15 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|||||||
|
|
||||||
use arc_swap::ArcSwap;
|
use arc_swap::ArcSwap;
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
use tokio::task::JoinHandle;
|
|
||||||
use tokio::sync::{Mutex, RwLock};
|
use tokio::sync::{Mutex, RwLock};
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
use rustproxy_config::{RouteActionType, TransportProtocol};
|
use rustproxy_config::{RouteActionType, TransportProtocol};
|
||||||
use rustproxy_metrics::MetricsCollector;
|
use rustproxy_metrics::MetricsCollector;
|
||||||
use rustproxy_routing::{MatchContext, RouteManager};
|
use rustproxy_routing::{MatchContext, RouteManager};
|
||||||
|
use rustproxy_security::IpBlockList;
|
||||||
|
|
||||||
use rustproxy_http::h3_service::H3ProxyService;
|
use rustproxy_http::h3_service::H3ProxyService;
|
||||||
|
|
||||||
@@ -62,6 +63,8 @@ pub struct UdpListenerManager {
|
|||||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||||
/// Shared connection registry for selective recycling.
|
/// Shared connection registry for selective recycling.
|
||||||
connection_registry: Arc<ConnectionRegistry>,
|
connection_registry: Arc<ConnectionRegistry>,
|
||||||
|
/// Global ingress block policy, hot-reloadable without restarting listeners.
|
||||||
|
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for UdpListenerManager {
|
impl Drop for UdpListenerManager {
|
||||||
@@ -99,17 +102,26 @@ impl UdpListenerManager {
|
|||||||
proxy_ips: Arc::new(Vec::new()),
|
proxy_ips: Arc::new(Vec::new()),
|
||||||
route_cancels,
|
route_cancels,
|
||||||
connection_registry,
|
connection_registry,
|
||||||
|
security_policy: Arc::new(ArcSwap::from(Arc::new(IpBlockList::empty()))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
|
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
|
||||||
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
|
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
|
||||||
if !ips.is_empty() {
|
if !ips.is_empty() {
|
||||||
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len());
|
info!(
|
||||||
|
"UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs",
|
||||||
|
ips.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
self.proxy_ips = Arc::new(ips);
|
self.proxy_ips = Arc::new(ips);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the shared global ingress security policy.
|
||||||
|
pub fn set_security_policy(&mut self, policy: Arc<ArcSwap<IpBlockList>>) {
|
||||||
|
self.security_policy = policy;
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the H3 proxy service for HTTP/3 request handling.
|
/// Set the H3 proxy service for HTTP/3 request handling.
|
||||||
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
|
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
|
||||||
self.h3_service = Some(svc);
|
self.h3_service = Some(svc);
|
||||||
@@ -142,7 +154,9 @@ impl UdpListenerManager {
|
|||||||
// Check if any route on this port uses QUIC
|
// Check if any route on this port uses QUIC
|
||||||
let rm = self.route_manager.load();
|
let rm = self.route_manager.load();
|
||||||
let has_quic = rm.routes_for_port(port).iter().any(|r| {
|
let has_quic = rm.routes_for_port(port).iter().any(|r| {
|
||||||
r.action.udp.as_ref()
|
r.action
|
||||||
|
.udp
|
||||||
|
.as_ref()
|
||||||
.and_then(|u| u.quic.as_ref())
|
.and_then(|u| u.quic.as_ref())
|
||||||
.is_some()
|
.is_some()
|
||||||
});
|
});
|
||||||
@@ -164,8 +178,10 @@ impl UdpListenerManager {
|
|||||||
None,
|
None,
|
||||||
Arc::clone(&self.route_cancels),
|
Arc::clone(&self.route_cancels),
|
||||||
Arc::clone(&self.connection_registry),
|
Arc::clone(&self.connection_registry),
|
||||||
|
Arc::clone(&self.security_policy),
|
||||||
));
|
));
|
||||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
self.listeners
|
||||||
|
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||||
info!("QUIC endpoint started on port {}", port);
|
info!("QUIC endpoint started on port {}", port);
|
||||||
} else {
|
} else {
|
||||||
// Proxy relay path: we own external socket, quinn on localhost
|
// Proxy relay path: we own external socket, quinn on localhost
|
||||||
@@ -173,6 +189,7 @@ impl UdpListenerManager {
|
|||||||
port,
|
port,
|
||||||
tls,
|
tls,
|
||||||
Arc::clone(&self.proxy_ips),
|
Arc::clone(&self.proxy_ips),
|
||||||
|
Arc::clone(&self.security_policy),
|
||||||
self.cancel_token.child_token(),
|
self.cancel_token.child_token(),
|
||||||
)?;
|
)?;
|
||||||
let endpoint_for_updates = relay.endpoint.clone();
|
let endpoint_for_updates = relay.endpoint.clone();
|
||||||
@@ -187,13 +204,18 @@ impl UdpListenerManager {
|
|||||||
Some(relay.real_client_map),
|
Some(relay.real_client_map),
|
||||||
Arc::clone(&self.route_cancels),
|
Arc::clone(&self.route_cancels),
|
||||||
Arc::clone(&self.connection_registry),
|
Arc::clone(&self.connection_registry),
|
||||||
|
Arc::clone(&self.security_policy),
|
||||||
));
|
));
|
||||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
self.listeners
|
||||||
|
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||||
info!("QUIC endpoint with PROXY relay started on port {}", port);
|
info!("QUIC endpoint with PROXY relay started on port {}", port);
|
||||||
}
|
}
|
||||||
return Ok(());
|
return Ok(());
|
||||||
} else {
|
} else {
|
||||||
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
|
warn!(
|
||||||
|
"QUIC routes on port {} but no TLS config provided, falling back to raw UDP",
|
||||||
|
port
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,6 +236,7 @@ impl UdpListenerManager {
|
|||||||
Arc::clone(&self.relay_writer),
|
Arc::clone(&self.relay_writer),
|
||||||
self.cancel_token.child_token(),
|
self.cancel_token.child_token(),
|
||||||
Arc::clone(&self.proxy_ips),
|
Arc::clone(&self.proxy_ips),
|
||||||
|
Arc::clone(&self.security_policy),
|
||||||
));
|
));
|
||||||
|
|
||||||
self.listeners.insert(port, (handle, None));
|
self.listeners.insert(port, (handle, None));
|
||||||
@@ -254,8 +277,10 @@ impl UdpListenerManager {
|
|||||||
}
|
}
|
||||||
debug!("UDP listener stopped on port {}", port);
|
debug!("UDP listener stopped on port {}", port);
|
||||||
}
|
}
|
||||||
info!("All UDP listeners stopped, {} sessions remaining",
|
info!(
|
||||||
self.session_table.session_count());
|
"All UDP listeners stopped, {} sessions remaining",
|
||||||
|
self.session_table.session_count()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Update TLS config on all active QUIC endpoints (cert refresh).
|
/// Update TLS config on all active QUIC endpoints (cert refresh).
|
||||||
@@ -288,11 +313,15 @@ impl UdpListenerManager {
|
|||||||
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
|
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
|
||||||
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
|
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
|
||||||
let rm = self.route_manager.load();
|
let rm = self.route_manager.load();
|
||||||
let upgrade_ports: Vec<u16> = self.listeners.iter()
|
let upgrade_ports: Vec<u16> = self
|
||||||
|
.listeners
|
||||||
|
.iter()
|
||||||
.filter(|(_, (_, endpoint))| endpoint.is_none())
|
.filter(|(_, (_, endpoint))| endpoint.is_none())
|
||||||
.filter(|(port, _)| {
|
.filter(|(port, _)| {
|
||||||
rm.routes_for_port(**port).iter().any(|r| {
|
rm.routes_for_port(**port).iter().any(|r| {
|
||||||
r.action.udp.as_ref()
|
r.action
|
||||||
|
.udp
|
||||||
|
.as_ref()
|
||||||
.and_then(|u| u.quic.as_ref())
|
.and_then(|u| u.quic.as_ref())
|
||||||
.is_some()
|
.is_some()
|
||||||
})
|
})
|
||||||
@@ -301,17 +330,23 @@ impl UdpListenerManager {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
for port in upgrade_ports {
|
for port in upgrade_ports {
|
||||||
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port);
|
info!(
|
||||||
|
"Upgrading raw UDP listener on port {} to QUIC endpoint",
|
||||||
|
port
|
||||||
|
);
|
||||||
|
|
||||||
// Stop the raw UDP listener task and drain sessions to release the socket
|
// Stop the raw UDP listener task and drain sessions to release the socket
|
||||||
if let Some((handle, _)) = self.listeners.remove(&port) {
|
if let Some((handle, _)) = self.listeners.remove(&port) {
|
||||||
handle.abort();
|
handle.abort();
|
||||||
}
|
}
|
||||||
let drained = self.session_table.drain_port(
|
let drained = self
|
||||||
port, &self.metrics, &self.conn_tracker,
|
.session_table
|
||||||
);
|
.drain_port(port, &self.metrics, &self.conn_tracker);
|
||||||
if drained > 0 {
|
if drained > 0 {
|
||||||
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port);
|
debug!(
|
||||||
|
"Drained {} UDP sessions on port {} for QUIC upgrade",
|
||||||
|
drained, port
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Brief yield to let aborted tasks drop their socket references
|
// Brief yield to let aborted tasks drop their socket references
|
||||||
@@ -326,11 +361,17 @@ impl UdpListenerManager {
|
|||||||
|
|
||||||
match create_result {
|
match create_result {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port);
|
info!(
|
||||||
|
"QUIC endpoint started on port {} (upgraded from raw UDP)",
|
||||||
|
port
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Port may still be held — retry once after a brief delay
|
// Port may still be held — retry once after a brief delay
|
||||||
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e);
|
warn!(
|
||||||
|
"QUIC endpoint creation failed on port {}, retrying: {}",
|
||||||
|
port, e
|
||||||
|
);
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
let retry_result = if self.proxy_ips.is_empty() {
|
let retry_result = if self.proxy_ips.is_empty() {
|
||||||
@@ -341,11 +382,17 @@ impl UdpListenerManager {
|
|||||||
|
|
||||||
match retry_result {
|
match retry_result {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port);
|
info!(
|
||||||
|
"QUIC endpoint started on port {} (upgraded from raw UDP, retry)",
|
||||||
|
port
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Err(e2) => {
|
Err(e2) => {
|
||||||
error!("Failed to upgrade port {} to QUIC after retry: {}. \
|
error!(
|
||||||
Rebinding as raw UDP.", port, e2);
|
"Failed to upgrade port {} to QUIC after retry: {}. \
|
||||||
|
Rebinding as raw UDP.",
|
||||||
|
port, e2
|
||||||
|
);
|
||||||
// Fallback: rebind as raw UDP so the port isn't dead
|
// Fallback: rebind as raw UDP so the port isn't dead
|
||||||
if let Ok(()) = self.rebind_raw_udp(port).await {
|
if let Ok(()) = self.rebind_raw_udp(port).await {
|
||||||
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
|
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
|
||||||
@@ -358,7 +405,11 @@ impl UdpListenerManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a direct QUIC endpoint (quinn owns the socket).
|
/// Create a direct QUIC endpoint (quinn owns the socket).
|
||||||
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
fn create_quic_direct(
|
||||||
|
&mut self,
|
||||||
|
port: u16,
|
||||||
|
tls_config: Arc<rustls::ServerConfig>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
|
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
|
||||||
let endpoint_for_updates = endpoint.clone();
|
let endpoint_for_updates = endpoint.clone();
|
||||||
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||||
@@ -372,17 +423,24 @@ impl UdpListenerManager {
|
|||||||
None,
|
None,
|
||||||
Arc::clone(&self.route_cancels),
|
Arc::clone(&self.route_cancels),
|
||||||
Arc::clone(&self.connection_registry),
|
Arc::clone(&self.connection_registry),
|
||||||
|
Arc::clone(&self.security_policy),
|
||||||
));
|
));
|
||||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
self.listeners
|
||||||
|
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a QUIC endpoint with PROXY protocol relay.
|
/// Create a QUIC endpoint with PROXY protocol relay.
|
||||||
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
fn create_quic_with_relay(
|
||||||
|
&mut self,
|
||||||
|
port: u16,
|
||||||
|
tls_config: Arc<rustls::ServerConfig>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
|
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
|
||||||
port,
|
port,
|
||||||
tls_config,
|
tls_config,
|
||||||
Arc::clone(&self.proxy_ips),
|
Arc::clone(&self.proxy_ips),
|
||||||
|
Arc::clone(&self.security_policy),
|
||||||
self.cancel_token.child_token(),
|
self.cancel_token.child_token(),
|
||||||
)?;
|
)?;
|
||||||
let endpoint_for_updates = relay.endpoint.clone();
|
let endpoint_for_updates = relay.endpoint.clone();
|
||||||
@@ -397,8 +455,10 @@ impl UdpListenerManager {
|
|||||||
Some(relay.real_client_map),
|
Some(relay.real_client_map),
|
||||||
Arc::clone(&self.route_cancels),
|
Arc::clone(&self.route_cancels),
|
||||||
Arc::clone(&self.connection_registry),
|
Arc::clone(&self.connection_registry),
|
||||||
|
Arc::clone(&self.security_policy),
|
||||||
));
|
));
|
||||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
self.listeners
|
||||||
|
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -419,6 +479,7 @@ impl UdpListenerManager {
|
|||||||
Arc::clone(&self.relay_writer),
|
Arc::clone(&self.relay_writer),
|
||||||
self.cancel_token.child_token(),
|
self.cancel_token.child_token(),
|
||||||
Arc::clone(&self.proxy_ips),
|
Arc::clone(&self.proxy_ips),
|
||||||
|
Arc::clone(&self.security_policy),
|
||||||
));
|
));
|
||||||
|
|
||||||
self.listeners.insert(port, (handle, None));
|
self.listeners.insert(port, (handle, None));
|
||||||
@@ -458,7 +519,10 @@ impl UdpListenerManager {
|
|||||||
info!("Datagram handler relay connected to {}", path);
|
info!("Datagram handler relay connected to {}", path);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to connect datagram handler relay to {}: {}", path, e);
|
error!(
|
||||||
|
"Failed to connect datagram handler relay to {}: {}",
|
||||||
|
path, e
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -514,6 +578,7 @@ impl UdpListenerManager {
|
|||||||
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
proxy_ips: Arc<Vec<IpAddr>>,
|
proxy_ips: Arc<Vec<IpAddr>>,
|
||||||
|
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||||
) {
|
) {
|
||||||
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
|
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
|
||||||
let mut buf = vec![0u8; 65535];
|
let mut buf = vec![0u8; 65535];
|
||||||
@@ -528,9 +593,11 @@ impl UdpListenerManager {
|
|||||||
|
|
||||||
loop {
|
loop {
|
||||||
// Periodic cleanup: remove proxy_addr_map entries with no active session
|
// Periodic cleanup: remove proxy_addr_map entries with no active session
|
||||||
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval {
|
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval
|
||||||
|
{
|
||||||
last_proxy_cleanup = tokio::time::Instant::now();
|
last_proxy_cleanup = tokio::time::Instant::now();
|
||||||
let stale: Vec<SocketAddr> = proxy_addr_map.iter()
|
let stale: Vec<SocketAddr> = proxy_addr_map
|
||||||
|
.iter()
|
||||||
.filter(|entry| {
|
.filter(|entry| {
|
||||||
let key: SessionKey = (*entry.key(), port);
|
let key: SessionKey = (*entry.key(), port);
|
||||||
session_table.get(&key).is_none()
|
session_table.get(&key).is_none()
|
||||||
@@ -538,7 +605,11 @@ impl UdpListenerManager {
|
|||||||
.map(|entry| *entry.key())
|
.map(|entry| *entry.key())
|
||||||
.collect();
|
.collect();
|
||||||
if !stale.is_empty() {
|
if !stale.is_empty() {
|
||||||
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port);
|
debug!(
|
||||||
|
"UDP proxy_addr_map cleanup: removing {} stale entries on port {}",
|
||||||
|
stale.len(),
|
||||||
|
port
|
||||||
|
);
|
||||||
for addr in stale {
|
for addr in stale {
|
||||||
proxy_addr_map.remove(&addr);
|
proxy_addr_map.remove(&addr);
|
||||||
}
|
}
|
||||||
@@ -564,14 +635,20 @@ impl UdpListenerManager {
|
|||||||
let datagram = &buf[..len];
|
let datagram = &buf[..len];
|
||||||
|
|
||||||
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
|
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
|
||||||
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
let effective_client_ip =
|
||||||
|
if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
||||||
let session_key: SessionKey = (client_addr, port);
|
let session_key: SessionKey = (client_addr, port);
|
||||||
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) {
|
if session_table.get(&session_key).is_none()
|
||||||
|
&& !proxy_addr_map.contains_key(&client_addr)
|
||||||
|
{
|
||||||
// No session and no prior PROXY header — check for PROXY v2
|
// No session and no prior PROXY header — check for PROXY v2
|
||||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||||
match crate::proxy_protocol::parse_v2(datagram) {
|
match crate::proxy_protocol::parse_v2(datagram) {
|
||||||
Ok((header, _consumed)) => {
|
Ok((header, _consumed)) => {
|
||||||
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr);
|
debug!(
|
||||||
|
"UDP PROXY v2 from {}: real client {}",
|
||||||
|
client_addr, header.source_addr
|
||||||
|
);
|
||||||
proxy_addr_map.insert(client_addr, header.source_addr);
|
proxy_addr_map.insert(client_addr, header.source_addr);
|
||||||
continue; // discard the PROXY v2 datagram
|
continue; // discard the PROXY v2 datagram
|
||||||
}
|
}
|
||||||
@@ -585,7 +662,8 @@ impl UdpListenerManager {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Use real client IP if we've previously seen a PROXY v2 header
|
// Use real client IP if we've previously seen a PROXY v2 header
|
||||||
proxy_addr_map.get(&client_addr)
|
proxy_addr_map
|
||||||
|
.get(&client_addr)
|
||||||
.map(|r| r.ip())
|
.map(|r| r.ip())
|
||||||
.unwrap_or_else(|| client_addr.ip())
|
.unwrap_or_else(|| client_addr.ip())
|
||||||
}
|
}
|
||||||
@@ -593,6 +671,15 @@ impl UdpListenerManager {
|
|||||||
client_addr.ip()
|
client_addr.ip()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let block_list = security_policy.load();
|
||||||
|
if !block_list.is_empty() && block_list.is_blocked(&effective_client_ip) {
|
||||||
|
debug!(
|
||||||
|
"UDP datagram from {} blocked by global security policy",
|
||||||
|
effective_client_ip
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Route matching — use effective (real) client IP
|
// Route matching — use effective (real) client IP
|
||||||
let rm = route_manager.load();
|
let rm = route_manager.load();
|
||||||
let ip_str = effective_client_ip.to_string();
|
let ip_str = effective_client_ip.to_string();
|
||||||
@@ -611,7 +698,10 @@ impl UdpListenerManager {
|
|||||||
let route_match = match rm.find_route(&ctx) {
|
let route_match = match rm.find_route(&ctx) {
|
||||||
Some(m) => m,
|
Some(m) => m,
|
||||||
None => {
|
None => {
|
||||||
debug!("No UDP route matched for port {} from {}", port, client_addr);
|
debug!(
|
||||||
|
"No UDP route matched for port {} from {}",
|
||||||
|
port, client_addr
|
||||||
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -627,7 +717,9 @@ impl UdpListenerManager {
|
|||||||
&client_addr,
|
&client_addr,
|
||||||
port,
|
port,
|
||||||
datagram,
|
datagram,
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
debug!("Failed to relay UDP datagram to TS: {}", e);
|
debug!("Failed to relay UDP datagram to TS: {}", e);
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
@@ -638,8 +730,10 @@ impl UdpListenerManager {
|
|||||||
|
|
||||||
// Check datagram size
|
// Check datagram size
|
||||||
if len as u32 > udp_config.max_datagram_size {
|
if len as u32 > udp_config.max_datagram_size {
|
||||||
debug!("UDP datagram too large ({} > {}) from {}, dropping",
|
debug!(
|
||||||
len, udp_config.max_datagram_size, client_addr);
|
"UDP datagram too large ({} > {}) from {}, dropping",
|
||||||
|
len, udp_config.max_datagram_size, client_addr
|
||||||
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -651,21 +745,27 @@ impl UdpListenerManager {
|
|||||||
None => {
|
None => {
|
||||||
// New session — check per-IP limits using the real client IP
|
// New session — check per-IP limits using the real client IP
|
||||||
if !conn_tracker.try_accept(&effective_client_ip) {
|
if !conn_tracker.try_accept(&effective_client_ip) {
|
||||||
debug!("UDP session rejected for {} (rate limit)", effective_client_ip);
|
debug!(
|
||||||
|
"UDP session rejected for {} (rate limit)",
|
||||||
|
effective_client_ip
|
||||||
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if !session_table.can_create_session(
|
if !session_table
|
||||||
&effective_client_ip,
|
.can_create_session(&effective_client_ip, udp_config.max_sessions_per_ip)
|
||||||
udp_config.max_sessions_per_ip,
|
{
|
||||||
) {
|
debug!(
|
||||||
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip);
|
"UDP session rejected for {} (per-IP session limit)",
|
||||||
|
effective_client_ip
|
||||||
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve target
|
// Resolve target
|
||||||
let target = match route_match.target.or_else(|| {
|
let target = match route_match
|
||||||
route.action.targets.as_ref().and_then(|t| t.first())
|
.target
|
||||||
}) {
|
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
|
||||||
|
{
|
||||||
Some(t) => t,
|
Some(t) => t,
|
||||||
None => {
|
None => {
|
||||||
warn!("No target for UDP route {:?}", route_id);
|
warn!("No target for UDP route {:?}", route_id);
|
||||||
@@ -686,13 +786,18 @@ impl UdpListenerManager {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
if let Err(e) = backend_socket.connect(&backend_addr).await {
|
if let Err(e) = backend_socket.connect(&backend_addr).await {
|
||||||
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e);
|
error!(
|
||||||
|
"Failed to connect backend UDP socket to {}: {}",
|
||||||
|
backend_addr, e
|
||||||
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let backend_socket = Arc::new(backend_socket);
|
let backend_socket = Arc::new(backend_socket);
|
||||||
|
|
||||||
debug!("New UDP session: {} -> {} (via port {}, real client {})",
|
debug!(
|
||||||
client_addr, backend_addr, port, effective_client_ip);
|
"New UDP session: {} -> {} (via port {}, real client {})",
|
||||||
|
client_addr, backend_addr, port, effective_client_ip
|
||||||
|
);
|
||||||
|
|
||||||
// Spawn return-path relay task
|
// Spawn return-path relay task
|
||||||
let session_cancel = CancellationToken::new();
|
let session_cancel = CancellationToken::new();
|
||||||
@@ -709,7 +814,9 @@ impl UdpListenerManager {
|
|||||||
|
|
||||||
let session = Arc::new(UdpSession {
|
let session = Arc::new(UdpSession {
|
||||||
backend_socket,
|
backend_socket,
|
||||||
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
|
last_activity: std::sync::atomic::AtomicU64::new(
|
||||||
|
session_table.elapsed_ms(),
|
||||||
|
),
|
||||||
created_at: std::time::Instant::now(),
|
created_at: std::time::Instant::now(),
|
||||||
route_id: route_id.map(|s| s.to_string()),
|
route_id: route_id.map(|s| s.to_string()),
|
||||||
source_ip: effective_client_ip,
|
source_ip: effective_client_ip,
|
||||||
@@ -718,7 +825,11 @@ impl UdpListenerManager {
|
|||||||
cancel: session_cancel,
|
cancel: session_cancel,
|
||||||
});
|
});
|
||||||
|
|
||||||
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) {
|
if !session_table.insert(
|
||||||
|
session_key,
|
||||||
|
Arc::clone(&session),
|
||||||
|
udp_config.max_sessions_per_ip,
|
||||||
|
) {
|
||||||
warn!("Failed to insert UDP session (race condition)");
|
warn!("Failed to insert UDP session (race condition)");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -735,7 +846,9 @@ impl UdpListenerManager {
|
|||||||
// Forward datagram to backend
|
// Forward datagram to backend
|
||||||
match session.backend_socket.send(datagram).await {
|
match session.backend_socket.send(datagram).await {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
session
|
||||||
|
.last_activity
|
||||||
|
.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||||
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
|
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
|
||||||
metrics.record_datagram_in();
|
metrics.record_datagram_in();
|
||||||
}
|
}
|
||||||
@@ -779,7 +892,9 @@ impl UdpListenerManager {
|
|||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
// Update session activity
|
// Update session activity
|
||||||
if let Some(session) = session_table.get(&session_key) {
|
if let Some(session) = session_table.get(&session_key) {
|
||||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
session
|
||||||
|
.last_activity
|
||||||
|
.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
|
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
|
||||||
metrics.record_datagram_out();
|
metrics.record_datagram_out();
|
||||||
@@ -814,7 +929,8 @@ impl UdpListenerManager {
|
|||||||
let json = serde_json::to_vec(&msg)?;
|
let json = serde_json::to_vec(&msg)?;
|
||||||
|
|
||||||
let mut guard = writer.lock().await;
|
let mut guard = writer.lock().await;
|
||||||
let stream = guard.as_mut()
|
let stream = guard
|
||||||
|
.as_mut()
|
||||||
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
|
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
|
||||||
|
|
||||||
// Length-prefixed frame
|
// Length-prefixed frame
|
||||||
@@ -879,9 +995,15 @@ impl UdpListenerManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
|
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
let source_port = reply.get("sourcePort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
let source_port = reply
|
||||||
|
.get("sourcePort")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(0) as u16;
|
||||||
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||||
let payload_b64 = reply.get("payloadBase64").and_then(|v| v.as_str()).unwrap_or("");
|
let payload_b64 = reply
|
||||||
|
.get("payloadBase64")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("");
|
||||||
|
|
||||||
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
|
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
|
|||||||
@@ -111,12 +111,15 @@ impl UdpSessionTable {
|
|||||||
|
|
||||||
/// Look up an existing session.
|
/// Look up an existing session.
|
||||||
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
||||||
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
|
self.sessions
|
||||||
|
.get(key)
|
||||||
|
.map(|entry| Arc::clone(entry.value()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if we can create a new session for this IP (under the per-IP limit).
|
/// Check if we can create a new session for this IP (under the per-IP limit).
|
||||||
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
|
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
|
||||||
let count = self.ip_session_counts
|
let count = self
|
||||||
|
.ip_session_counts
|
||||||
.get(ip)
|
.get(ip)
|
||||||
.map(|c| *c.value())
|
.map(|c| *c.value())
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
@@ -124,12 +127,7 @@ impl UdpSessionTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Insert a new session. Returns false if per-IP limit exceeded.
|
/// Insert a new session. Returns false if per-IP limit exceeded.
|
||||||
pub fn insert(
|
pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
|
||||||
&self,
|
|
||||||
key: SessionKey,
|
|
||||||
session: Arc<UdpSession>,
|
|
||||||
max_per_ip: u32,
|
|
||||||
) -> bool {
|
|
||||||
let ip = session.source_ip;
|
let ip = session.source_ip;
|
||||||
|
|
||||||
// Atomically check and increment per-IP count
|
// Atomically check and increment per-IP count
|
||||||
@@ -173,7 +171,9 @@ impl UdpSessionTable {
|
|||||||
let mut removed = 0;
|
let mut removed = 0;
|
||||||
|
|
||||||
// Collect keys to remove (avoid holding DashMap refs during removal)
|
// Collect keys to remove (avoid holding DashMap refs during removal)
|
||||||
let stale_keys: Vec<SessionKey> = self.sessions.iter()
|
let stale_keys: Vec<SessionKey> = self
|
||||||
|
.sessions
|
||||||
|
.iter()
|
||||||
.filter(|entry| {
|
.filter(|entry| {
|
||||||
let last = entry.value().last_activity.load(Ordering::Relaxed);
|
let last = entry.value().last_activity.load(Ordering::Relaxed);
|
||||||
now_ms.saturating_sub(last) >= timeout_ms
|
now_ms.saturating_sub(last) >= timeout_ms
|
||||||
@@ -185,7 +185,8 @@ impl UdpSessionTable {
|
|||||||
if let Some(session) = self.remove(&key) {
|
if let Some(session) = self.remove(&key) {
|
||||||
debug!(
|
debug!(
|
||||||
"UDP session expired: {} -> port {} (idle {}ms)",
|
"UDP session expired: {} -> port {} (idle {}ms)",
|
||||||
session.client_addr, key.1,
|
session.client_addr,
|
||||||
|
key.1,
|
||||||
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
|
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
|
||||||
);
|
);
|
||||||
conn_tracker.connection_closed(&session.source_ip);
|
conn_tracker.connection_closed(&session.source_ip);
|
||||||
@@ -210,7 +211,9 @@ impl UdpSessionTable {
|
|||||||
metrics: &MetricsCollector,
|
metrics: &MetricsCollector,
|
||||||
conn_tracker: &ConnectionTracker,
|
conn_tracker: &ConnectionTracker,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
let keys: Vec<SessionKey> = self.sessions.iter()
|
let keys: Vec<SessionKey> = self
|
||||||
|
.sessions
|
||||||
|
.iter()
|
||||||
.filter(|entry| entry.key().1 == port)
|
.filter(|entry| entry.key().1 == port)
|
||||||
.map(|entry| *entry.key())
|
.map(|entry| *entry.key())
|
||||||
.collect();
|
.collect();
|
||||||
@@ -257,9 +260,8 @@ mod tests {
|
|||||||
.enable_all()
|
.enable_all()
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let backend_socket = rt.block_on(async {
|
let backend_socket =
|
||||||
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
|
rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
|
||||||
});
|
|
||||||
|
|
||||||
let child_cancel = cancel.child_token();
|
let child_cancel = cancel.child_token();
|
||||||
let return_task = rt.spawn(async move {
|
let return_task = rt.spawn(async move {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
//! Route matching engine for RustProxy.
|
//! Route matching engine for RustProxy.
|
||||||
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
|
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
|
||||||
|
|
||||||
pub mod route_manager;
|
|
||||||
pub mod matchers;
|
pub mod matchers;
|
||||||
|
pub mod route_manager;
|
||||||
|
|
||||||
pub use route_manager::*;
|
pub use route_manager::*;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
use ipnet::IpNet;
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use ipnet::IpNet;
|
|
||||||
|
|
||||||
/// Match an IP address against a pattern.
|
/// Match an IP address against a pattern.
|
||||||
///
|
///
|
||||||
@@ -85,7 +85,10 @@ fn wildcard_to_cidr(pattern: &str) -> Option<String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len))
|
Some(format!(
|
||||||
|
"{}.{}.{}.{}/{}",
|
||||||
|
octets[0], octets[1], octets[2], octets[3], prefix_len
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
pub mod domain;
|
pub mod domain;
|
||||||
pub mod path;
|
|
||||||
pub mod ip;
|
|
||||||
pub mod header;
|
pub mod header;
|
||||||
|
pub mod ip;
|
||||||
|
pub mod path;
|
||||||
|
|
||||||
pub use domain::*;
|
pub use domain::*;
|
||||||
pub use path::*;
|
|
||||||
pub use ip::*;
|
|
||||||
pub use header::*;
|
pub use header::*;
|
||||||
|
pub use ip::*;
|
||||||
|
pub use path::*;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use rustproxy_config::{RouteConfig, RouteTarget, TransportProtocol, TlsMode};
|
|
||||||
use crate::matchers;
|
use crate::matchers;
|
||||||
|
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode, TransportProtocol};
|
||||||
|
|
||||||
/// Context for route matching (subset of connection info).
|
/// Context for route matching (subset of connection info).
|
||||||
pub struct MatchContext<'a> {
|
pub struct MatchContext<'a> {
|
||||||
@@ -42,19 +42,14 @@ impl RouteManager {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Filter enabled routes and sort by priority
|
// Filter enabled routes and sort by priority
|
||||||
let mut enabled_routes: Vec<RouteConfig> = routes
|
let mut enabled_routes: Vec<RouteConfig> =
|
||||||
.into_iter()
|
routes.into_iter().filter(|r| r.is_enabled()).collect();
|
||||||
.filter(|r| r.is_enabled())
|
|
||||||
.collect();
|
|
||||||
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
|
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
|
||||||
|
|
||||||
// Build port index
|
// Build port index
|
||||||
for (idx, route) in enabled_routes.iter().enumerate() {
|
for (idx, route) in enabled_routes.iter().enumerate() {
|
||||||
for port in route.listening_ports() {
|
for port in route.listening_ports() {
|
||||||
manager.port_index
|
manager.port_index.entry(port).or_default().push(idx);
|
||||||
.entry(port)
|
|
||||||
.or_default()
|
|
||||||
.push(idx);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,7 +61,9 @@ impl RouteManager {
|
|||||||
/// Used to skip expensive header HashMap construction when no route needs it.
|
/// Used to skip expensive header HashMap construction when no route needs it.
|
||||||
pub fn any_route_has_headers(&self, port: u16) -> bool {
|
pub fn any_route_has_headers(&self, port: u16) -> bool {
|
||||||
if let Some(indices) = self.port_index.get(&port) {
|
if let Some(indices) = self.port_index.get(&port) {
|
||||||
indices.iter().any(|&idx| self.routes[idx].route_match.headers.is_some())
|
indices
|
||||||
|
.iter()
|
||||||
|
.any(|&idx| self.routes[idx].route_match.headers.is_some())
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
@@ -99,8 +96,8 @@ impl RouteManager {
|
|||||||
let ctx_transport = ctx.transport.as_ref();
|
let ctx_transport = ctx.transport.as_ref();
|
||||||
match (route_transport, ctx_transport) {
|
match (route_transport, ctx_transport) {
|
||||||
// Route requires UDP only — reject non-UDP contexts
|
// Route requires UDP only — reject non-UDP contexts
|
||||||
(Some(TransportProtocol::Udp), None) |
|
(Some(TransportProtocol::Udp), None)
|
||||||
(Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
|
| (Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
|
||||||
// Route requires TCP only — reject UDP contexts
|
// Route requires TCP only — reject UDP contexts
|
||||||
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
|
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
|
||||||
// Route has no transport (default = TCP) — reject UDP contexts
|
// Route has no transport (default = TCP) — reject UDP contexts
|
||||||
@@ -196,7 +193,11 @@ impl RouteManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Find the best matching target within a route.
|
/// Find the best matching target within a route.
|
||||||
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
|
fn find_target<'a>(
|
||||||
|
&self,
|
||||||
|
route: &'a RouteConfig,
|
||||||
|
ctx: &MatchContext<'_>,
|
||||||
|
) -> Option<&'a RouteTarget> {
|
||||||
let targets = route.action.targets.as_ref()?;
|
let targets = route.action.targets.as_ref()?;
|
||||||
|
|
||||||
if targets.len() == 1 && targets[0].target_match.is_none() {
|
if targets.len() == 1 && targets[0].target_match.is_none() {
|
||||||
@@ -223,17 +224,11 @@ impl RouteManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to first target without match criteria
|
// Fall back to first target without match criteria
|
||||||
best.or_else(|| {
|
best.or_else(|| targets.iter().find(|t| t.target_match.is_none()))
|
||||||
targets.iter().find(|t| t.target_match.is_none())
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if a target match criteria matches the context.
|
/// Check if a target match criteria matches the context.
|
||||||
fn matches_target(
|
fn matches_target(&self, tm: &rustproxy_config::TargetMatch, ctx: &MatchContext<'_>) -> bool {
|
||||||
&self,
|
|
||||||
tm: &rustproxy_config::TargetMatch,
|
|
||||||
ctx: &MatchContext<'_>,
|
|
||||||
) -> bool {
|
|
||||||
// Port matching
|
// Port matching
|
||||||
if let Some(ref ports) = tm.ports {
|
if let Some(ref ports) = tm.ports {
|
||||||
if !ports.contains(&ctx.port) {
|
if !ports.contains(&ctx.port) {
|
||||||
@@ -298,9 +293,7 @@ impl RouteManager {
|
|||||||
// If multiple passthrough routes on same port, SNI is needed
|
// If multiple passthrough routes on same port, SNI is needed
|
||||||
let passthrough_routes: Vec<_> = routes
|
let passthrough_routes: Vec<_> = routes
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|r| {
|
.filter(|r| r.tls_mode() == Some(&TlsMode::Passthrough))
|
||||||
r.tls_mode() == Some(&TlsMode::Passthrough)
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if passthrough_routes.len() > 1 {
|
if passthrough_routes.len() > 1 {
|
||||||
@@ -419,7 +412,11 @@ mod tests {
|
|||||||
|
|
||||||
let result = manager.find_route(&ctx).unwrap();
|
let result = manager.find_route(&ctx).unwrap();
|
||||||
// Should match the higher-priority specific route
|
// Should match the higher-priority specific route
|
||||||
assert!(result.route.route_match.domains.as_ref()
|
assert!(result
|
||||||
|
.route
|
||||||
|
.route_match
|
||||||
|
.domains
|
||||||
|
.as_ref()
|
||||||
.map(|d| d.to_vec())
|
.map(|d| d.to_vec())
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.contains(&"api.example.com"));
|
.contains(&"api.example.com"));
|
||||||
@@ -619,8 +616,14 @@ mod tests {
|
|||||||
|
|
||||||
let result = manager.find_route(&ctx);
|
let result = manager.find_route(&ctx);
|
||||||
assert!(result.is_some());
|
assert!(result.is_some());
|
||||||
let matched_domains = result.unwrap().route.route_match.domains.as_ref()
|
let matched_domains = result
|
||||||
.map(|d| d.to_vec()).unwrap();
|
.unwrap()
|
||||||
|
.route
|
||||||
|
.route_match
|
||||||
|
.domains
|
||||||
|
.as_ref()
|
||||||
|
.map(|d| d.to_vec())
|
||||||
|
.unwrap();
|
||||||
assert!(matched_domains.contains(&"*"));
|
assert!(matched_domains.contains(&"*"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -735,7 +738,11 @@ mod tests {
|
|||||||
assert_eq!(result.target.unwrap().host.first(), "default-backend");
|
assert_eq!(result.target.unwrap().host.first(), "default-backend");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig {
|
fn make_route_with_protocol(
|
||||||
|
port: u16,
|
||||||
|
domain: Option<&str>,
|
||||||
|
protocol: Option<&str>,
|
||||||
|
) -> RouteConfig {
|
||||||
let mut route = make_route(port, domain, 0);
|
let mut route = make_route(port, domain, 0);
|
||||||
route.route_match.protocol = protocol.map(|s| s.to_string());
|
route.route_match.protocol = protocol.map(|s| s.to_string());
|
||||||
route
|
route
|
||||||
@@ -1029,8 +1036,10 @@ mod tests {
|
|||||||
transport: Some(TransportProtocol::Udp),
|
transport: Some(TransportProtocol::Udp),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(manager.find_route(&ctx).is_some(),
|
assert!(
|
||||||
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes");
|
manager.find_route(&ctx).is_some(),
|
||||||
|
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -1051,7 +1060,9 @@ mod tests {
|
|||||||
transport: None, // TCP (default)
|
transport: None, // TCP (default)
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(manager.find_route(&ctx).is_none(),
|
assert!(
|
||||||
"TCP TLS without SNI should NOT match domain-restricted routes");
|
manager.find_route(&ctx).is_none(),
|
||||||
|
"TCP TLS without SNI should NOT match domain-restricted routes"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use base64::Engine;
|
|
||||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||||
|
use base64::Engine;
|
||||||
|
|
||||||
/// Basic auth validator.
|
/// Basic auth validator.
|
||||||
pub struct BasicAuthValidator {
|
pub struct BasicAuthValidator {
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ struct DomainScopedEntry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Represents an IP pattern for matching.
|
/// Represents an IP pattern for matching.
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
enum IpPattern {
|
enum IpPattern {
|
||||||
/// Exact IP match
|
/// Exact IP match
|
||||||
Exact(IpAddr),
|
Exact(IpAddr),
|
||||||
@@ -31,6 +31,37 @@ enum IpPattern {
|
|||||||
Wildcard,
|
Wildcard,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compiled block list for early ingress filtering.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct IpBlockList {
|
||||||
|
block_list: Vec<IpPattern>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IpBlockList {
|
||||||
|
pub fn new(block_list: &[String]) -> Self {
|
||||||
|
Self {
|
||||||
|
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
Self {
|
||||||
|
block_list: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.block_list.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
|
||||||
|
let normalized = IpFilter::normalize_ip(ip);
|
||||||
|
self.block_list
|
||||||
|
.iter()
|
||||||
|
.any(|pattern| pattern.matches(&normalized))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl IpPattern {
|
impl IpPattern {
|
||||||
fn parse(s: &str) -> Self {
|
fn parse(s: &str) -> Self {
|
||||||
let s = s.trim();
|
let s = s.trim();
|
||||||
@@ -68,8 +99,7 @@ fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
|
|||||||
}
|
}
|
||||||
if p.starts_with("*.") {
|
if p.starts_with("*.") {
|
||||||
let suffix = &p[1..]; // e.g., ".abc.xyz"
|
let suffix = &p[1..]; // e.g., ".abc.xyz"
|
||||||
d.len() > suffix.len()
|
d.len() > suffix.len() && d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
|
||||||
&& d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
|
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
@@ -127,7 +157,11 @@ impl IpFilter {
|
|||||||
if let Some(req_domain) = domain {
|
if let Some(req_domain) = domain {
|
||||||
for entry in &self.domain_scoped {
|
for entry in &self.domain_scoped {
|
||||||
if entry.pattern.matches(ip) {
|
if entry.pattern.matches(ip) {
|
||||||
if entry.domains.iter().any(|d| domain_matches_pattern(d, req_domain)) {
|
if entry
|
||||||
|
.domains
|
||||||
|
.iter()
|
||||||
|
.any(|d| domain_matches_pattern(d, req_domain))
|
||||||
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -212,10 +246,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_block_trumps_allow() {
|
fn test_block_trumps_allow() {
|
||||||
let filter = IpFilter::new(
|
let filter = IpFilter::new(&[plain("10.0.0.0/8")], &["10.0.0.5".to_string()]);
|
||||||
&[plain("10.0.0.0/8")],
|
|
||||||
&["10.0.0.5".to_string()],
|
|
||||||
);
|
|
||||||
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
|
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
|
||||||
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
|
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
|
||||||
assert!(!filter.is_allowed(&blocked));
|
assert!(!filter.is_allowed(&blocked));
|
||||||
@@ -255,30 +286,21 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_domain_scoped_allows_matching_domain() {
|
fn test_domain_scoped_allows_matching_domain() {
|
||||||
let filter = IpFilter::new(
|
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
|
||||||
&[],
|
|
||||||
);
|
|
||||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||||
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_domain_scoped_denies_non_matching_domain() {
|
fn test_domain_scoped_denies_non_matching_domain() {
|
||||||
let filter = IpFilter::new(
|
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
|
||||||
&[],
|
|
||||||
);
|
|
||||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||||
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_domain_scoped_denies_without_domain() {
|
fn test_domain_scoped_denies_without_domain() {
|
||||||
let filter = IpFilter::new(
|
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
|
||||||
&[],
|
|
||||||
);
|
|
||||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||||
// Without domain context, domain-scoped entries cannot match
|
// Without domain context, domain-scoped entries cannot match
|
||||||
assert!(!filter.is_allowed_for_domain(&ip, None));
|
assert!(!filter.is_allowed_for_domain(&ip, None));
|
||||||
@@ -286,10 +308,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_domain_scoped_wildcard_domain() {
|
fn test_domain_scoped_wildcard_domain() {
|
||||||
let filter = IpFilter::new(
|
let filter = IpFilter::new(&[scoped("10.8.0.2", &["*.abc.xyz"])], &[]);
|
||||||
&[scoped("10.8.0.2", &["*.abc.xyz"])],
|
|
||||||
&[],
|
|
||||||
);
|
|
||||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||||
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||||
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
|
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// JWT claims (minimal structure).
|
/// JWT claims (minimal structure).
|
||||||
@@ -160,10 +160,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_extract_token_bearer() {
|
fn test_extract_token_bearer() {
|
||||||
assert_eq!(
|
assert_eq!(JwtValidator::extract_token("Bearer abc123"), Some("abc123"));
|
||||||
JwtValidator::extract_token("Bearer abc123"),
|
|
||||||
Some("abc123")
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -2,12 +2,12 @@
|
|||||||
//!
|
//!
|
||||||
//! IP filtering, rate limiting, and authentication for RustProxy.
|
//! IP filtering, rate limiting, and authentication for RustProxy.
|
||||||
|
|
||||||
pub mod ip_filter;
|
|
||||||
pub mod rate_limiter;
|
|
||||||
pub mod basic_auth;
|
pub mod basic_auth;
|
||||||
|
pub mod ip_filter;
|
||||||
pub mod jwt_auth;
|
pub mod jwt_auth;
|
||||||
|
pub mod rate_limiter;
|
||||||
|
|
||||||
pub use ip_filter::*;
|
|
||||||
pub use rate_limiter::*;
|
|
||||||
pub use basic_auth::*;
|
pub use basic_auth::*;
|
||||||
|
pub use ip_filter::*;
|
||||||
pub use jwt_auth::*;
|
pub use jwt_auth::*;
|
||||||
|
pub use rate_limiter::*;
|
||||||
|
|||||||
@@ -4,8 +4,7 @@
|
|||||||
//! Account credentials are ephemeral — the consumer owns all persistence.
|
//! Account credentials are ephemeral — the consumer owns all persistence.
|
||||||
|
|
||||||
use instant_acme::{
|
use instant_acme::{
|
||||||
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
|
Account, AccountCredentials, ChallengeType, Identifier, NewAccount, NewOrder, OrderStatus,
|
||||||
AccountCredentials,
|
|
||||||
};
|
};
|
||||||
use rcgen::{CertificateParams, KeyPair};
|
use rcgen::{CertificateParams, KeyPair};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@@ -89,7 +88,11 @@ impl AcmeClient {
|
|||||||
F: FnOnce(PendingChallenge) -> Fut,
|
F: FnOnce(PendingChallenge) -> Fut,
|
||||||
Fut: std::future::Future<Output = Result<(), AcmeError>>,
|
Fut: std::future::Future<Output = Result<(), AcmeError>>,
|
||||||
{
|
{
|
||||||
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
|
info!(
|
||||||
|
"Starting ACME provisioning for {} via {}",
|
||||||
|
domain,
|
||||||
|
self.directory_url()
|
||||||
|
);
|
||||||
|
|
||||||
// 1. Get or create ACME account
|
// 1. Get or create ACME account
|
||||||
let account = self.get_or_create_account().await?;
|
let account = self.get_or_create_account().await?;
|
||||||
@@ -170,14 +173,14 @@ impl AcmeClient {
|
|||||||
debug!("Order ready, finalizing...");
|
debug!("Order ready, finalizing...");
|
||||||
|
|
||||||
// 6. Generate CSR and finalize
|
// 6. Generate CSR and finalize
|
||||||
let key_pair = KeyPair::generate().map_err(|e| {
|
let key_pair = KeyPair::generate()
|
||||||
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
|
.map_err(|e| AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)))?;
|
||||||
})?;
|
|
||||||
|
|
||||||
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
|
let mut params = CertificateParams::new(vec![domain.to_string()])
|
||||||
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
|
.map_err(|e| AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)))?;
|
||||||
})?;
|
params
|
||||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
.distinguished_name
|
||||||
|
.push(rcgen::DnType::CommonName, domain);
|
||||||
|
|
||||||
let csr = params.serialize_request(&key_pair).map_err(|e| {
|
let csr = params.serialize_request(&key_pair).map_err(|e| {
|
||||||
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
|
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
|
||||||
@@ -219,9 +222,7 @@ impl AcmeClient {
|
|||||||
.certificate()
|
.certificate()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
|
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| AcmeError::FinalizationFailed("No certificate returned".to_string()))?;
|
||||||
AcmeError::FinalizationFailed("No certificate returned".to_string())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let private_key_pem = key_pair.serialize_pem();
|
let private_key_pem = key_pair.serialize_pem();
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
|
|
||||||
use crate::acme::AcmeClient;
|
use crate::acme::AcmeClient;
|
||||||
|
use crate::cert_store::{CertBundle, CertMetadata, CertSource, CertStore};
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum CertManagerError {
|
pub enum CertManagerError {
|
||||||
@@ -45,17 +45,13 @@ impl CertManager {
|
|||||||
/// Create an ACME client using this manager's configuration.
|
/// Create an ACME client using this manager's configuration.
|
||||||
/// Returns None if no ACME email is configured.
|
/// Returns None if no ACME email is configured.
|
||||||
pub fn acme_client(&self) -> Option<AcmeClient> {
|
pub fn acme_client(&self) -> Option<AcmeClient> {
|
||||||
self.acme_email.as_ref().map(|email| {
|
self.acme_email
|
||||||
AcmeClient::new(email.clone(), self.use_production)
|
.as_ref()
|
||||||
})
|
.map(|email| AcmeClient::new(email.clone(), self.use_production))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load a static certificate into the store (infallible — pure cache insert).
|
/// Load a static certificate into the store (infallible — pure cache insert).
|
||||||
pub fn load_static(
|
pub fn load_static(&mut self, domain: String, bundle: CertBundle) {
|
||||||
&mut self,
|
|
||||||
domain: String,
|
|
||||||
bundle: CertBundle,
|
|
||||||
) {
|
|
||||||
self.store.store(domain, bundle);
|
self.store.store(domain, bundle);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,20 +104,22 @@ impl CertManager {
|
|||||||
F: FnOnce(String, String) -> Fut,
|
F: FnOnce(String, String) -> Fut,
|
||||||
Fut: std::future::Future<Output = ()>,
|
Fut: std::future::Future<Output = ()>,
|
||||||
{
|
{
|
||||||
let acme_client = self.acme_client()
|
let acme_client = self.acme_client().ok_or(CertManagerError::NoEmail)?;
|
||||||
.ok_or(CertManagerError::NoEmail)?;
|
|
||||||
|
|
||||||
info!("Renewing certificate for {}", domain);
|
info!("Renewing certificate for {}", domain);
|
||||||
|
|
||||||
let domain_owned = domain.to_string();
|
let domain_owned = domain.to_string();
|
||||||
let result = acme_client.provision(&domain_owned, |pending| {
|
let result = acme_client
|
||||||
|
.provision(&domain_owned, |pending| {
|
||||||
let token = pending.token.clone();
|
let token = pending.token.clone();
|
||||||
let key_auth = pending.key_authorization.clone();
|
let key_auth = pending.key_authorization.clone();
|
||||||
async move {
|
async move {
|
||||||
challenge_setup(token, key_auth).await;
|
challenge_setup(token, key_auth).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}).await.map_err(|e| CertManagerError::AcmeFailure {
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| CertManagerError::AcmeFailure {
|
||||||
domain: domain.to_string(),
|
domain: domain.to_string(),
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
})?;
|
})?;
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
/// Certificate metadata stored alongside certs.
|
/// Certificate metadata stored alongside certs.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -90,8 +90,10 @@ mod tests {
|
|||||||
|
|
||||||
fn make_test_bundle(domain: &str) -> CertBundle {
|
fn make_test_bundle(domain: &str) -> CertBundle {
|
||||||
CertBundle {
|
CertBundle {
|
||||||
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
|
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n"
|
||||||
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
|
.to_string(),
|
||||||
|
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n"
|
||||||
|
.to_string(),
|
||||||
ca_pem: None,
|
ca_pem: None,
|
||||||
metadata: CertMetadata {
|
metadata: CertMetadata {
|
||||||
domain: domain.to_string(),
|
domain: domain.to_string(),
|
||||||
@@ -122,7 +124,8 @@ mod tests {
|
|||||||
let mut store = CertStore::new();
|
let mut store = CertStore::new();
|
||||||
|
|
||||||
let mut bundle = make_test_bundle("secure.com");
|
let mut bundle = make_test_bundle("secure.com");
|
||||||
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
|
bundle.ca_pem =
|
||||||
|
Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
|
||||||
store.store("secure.com".to_string(), bundle);
|
store.store("secure.com".to_string(), bundle);
|
||||||
|
|
||||||
let loaded = store.get("secure.com").unwrap();
|
let loaded = store.get("secure.com").unwrap();
|
||||||
@@ -147,7 +150,10 @@ mod tests {
|
|||||||
fn test_remove_cert() {
|
fn test_remove_cert() {
|
||||||
let mut store = CertStore::new();
|
let mut store = CertStore::new();
|
||||||
|
|
||||||
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com"));
|
store.store(
|
||||||
|
"remove-me.com".to_string(),
|
||||||
|
make_test_bundle("remove-me.com"),
|
||||||
|
);
|
||||||
assert!(store.has("remove-me.com"));
|
assert!(store.has("remove-me.com"));
|
||||||
|
|
||||||
let removed = store.remove("remove-me.com");
|
let removed = store.remove("remove-me.com");
|
||||||
@@ -165,7 +171,10 @@ mod tests {
|
|||||||
fn test_wildcard_domain() {
|
fn test_wildcard_domain() {
|
||||||
let mut store = CertStore::new();
|
let mut store = CertStore::new();
|
||||||
|
|
||||||
store.store("*.example.com".to_string(), make_test_bundle("*.example.com"));
|
store.store(
|
||||||
|
"*.example.com".to_string(),
|
||||||
|
make_test_bundle("*.example.com"),
|
||||||
|
);
|
||||||
assert!(store.has("*.example.com"));
|
assert!(store.has("*.example.com"));
|
||||||
|
|
||||||
let loaded = store.get("*.example.com").unwrap();
|
let loaded = store.get("*.example.com").unwrap();
|
||||||
|
|||||||
@@ -3,11 +3,11 @@
|
|||||||
//! TLS certificate management for RustProxy.
|
//! TLS certificate management for RustProxy.
|
||||||
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
|
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
|
||||||
|
|
||||||
pub mod cert_store;
|
|
||||||
pub mod cert_manager;
|
|
||||||
pub mod acme;
|
pub mod acme;
|
||||||
|
pub mod cert_manager;
|
||||||
|
pub mod cert_store;
|
||||||
pub mod sni_resolver;
|
pub mod sni_resolver;
|
||||||
|
|
||||||
pub use cert_store::*;
|
|
||||||
pub use cert_manager::*;
|
pub use cert_manager::*;
|
||||||
|
pub use cert_store::*;
|
||||||
pub use sni_resolver::*;
|
pub use sni_resolver::*;
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use hyper::{Request, Response, StatusCode};
|
|||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{debug, info, error};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
/// ACME HTTP-01 challenge server.
|
/// ACME HTTP-01 challenge server.
|
||||||
pub struct ChallengeServer {
|
pub struct ChallengeServer {
|
||||||
@@ -47,7 +47,10 @@ impl ChallengeServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Start the challenge server on the given port.
|
/// Start the challenge server on the given port.
|
||||||
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
pub async fn start(
|
||||||
|
&mut self,
|
||||||
|
port: u16,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let addr = format!("0.0.0.0:{}", port);
|
let addr = format!("0.0.0.0:{}", port);
|
||||||
let listener = TcpListener::bind(&addr).await?;
|
let listener = TcpListener::bind(&addr).await?;
|
||||||
info!("ACME challenge server listening on port {}", port);
|
info!("ACME challenge server listening on port {}", port);
|
||||||
@@ -101,10 +104,7 @@ impl ChallengeServer {
|
|||||||
pub async fn stop(&mut self) {
|
pub async fn stop(&mut self) {
|
||||||
self.cancel.cancel();
|
self.cancel.cancel();
|
||||||
if let Some(handle) = self.handle.take() {
|
if let Some(handle) = self.handle.take() {
|
||||||
let _ = tokio::time::timeout(
|
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
|
||||||
std::time::Duration::from_secs(5),
|
|
||||||
handle,
|
|
||||||
).await;
|
|
||||||
}
|
}
|
||||||
self.challenges.clear();
|
self.challenges.clear();
|
||||||
self.cancel = CancellationToken::new();
|
self.cancel = CancellationToken::new();
|
||||||
@@ -154,10 +154,14 @@ mod tests {
|
|||||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
// Fetch the challenge
|
// Fetch the challenge
|
||||||
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
|
let client = tokio::net::TcpStream::connect("127.0.0.1:19900")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
let io = TokioIo::new(client);
|
let io = TokioIo::new(client);
|
||||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
|
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
|
||||||
tokio::spawn(async move { let _ = conn.await; });
|
tokio::spawn(async move {
|
||||||
|
let _ = conn.await;
|
||||||
|
});
|
||||||
|
|
||||||
let req = Request::get("/.well-known/acme-challenge/test-token")
|
let req = Request::get("/.well-known/acme-challenge/test-token")
|
||||||
.body(Full::new(Bytes::new()))
|
.body(Full::new(Bytes::new()))
|
||||||
|
|||||||
+262
-105
@@ -57,24 +57,27 @@ use std::collections::{HashMap, HashSet};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use arc_swap::ArcSwap;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use tracing::{info, warn, debug, error};
|
use arc_swap::ArcSwap;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
// Re-export key types
|
// Re-export key types
|
||||||
pub use rustproxy_config;
|
pub use rustproxy_config;
|
||||||
pub use rustproxy_routing;
|
|
||||||
pub use rustproxy_passthrough;
|
|
||||||
pub use rustproxy_tls;
|
|
||||||
pub use rustproxy_http;
|
pub use rustproxy_http;
|
||||||
pub use rustproxy_metrics;
|
pub use rustproxy_metrics;
|
||||||
|
pub use rustproxy_passthrough;
|
||||||
|
pub use rustproxy_routing;
|
||||||
pub use rustproxy_security;
|
pub use rustproxy_security;
|
||||||
|
pub use rustproxy_tls;
|
||||||
|
|
||||||
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec};
|
use rustproxy_config::{CertificateSpec, RouteConfig, RustProxyOptions, TlsMode};
|
||||||
|
use rustproxy_metrics::{Metrics, MetricsCollector, Statistics};
|
||||||
|
use rustproxy_passthrough::{
|
||||||
|
ConnectionConfig, TcpListenerManager, TlsCertConfig, UdpListenerManager,
|
||||||
|
};
|
||||||
use rustproxy_routing::RouteManager;
|
use rustproxy_routing::RouteManager;
|
||||||
use rustproxy_passthrough::{TcpListenerManager, UdpListenerManager, TlsCertConfig, ConnectionConfig};
|
use rustproxy_security::IpBlockList;
|
||||||
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
|
use rustproxy_tls::{CertBundle, CertManager, CertMetadata, CertSource, CertStore};
|
||||||
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
|
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
/// Certificate status.
|
/// Certificate status.
|
||||||
@@ -106,6 +109,8 @@ pub struct RustProxy {
|
|||||||
loaded_certs: HashMap<String, TlsCertConfig>,
|
loaded_certs: HashMap<String, TlsCertConfig>,
|
||||||
/// Cancellation token for cooperative shutdown of background tasks.
|
/// Cancellation token for cooperative shutdown of background tasks.
|
||||||
cancel_token: CancellationToken,
|
cancel_token: CancellationToken,
|
||||||
|
/// Shared global ingress blocklist, hot-reloadable across TCP/UDP listeners.
|
||||||
|
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RustProxy {
|
impl RustProxy {
|
||||||
@@ -127,13 +132,19 @@ impl RustProxy {
|
|||||||
let route_manager = RouteManager::new(options.routes.clone());
|
let route_manager = RouteManager::new(options.routes.clone());
|
||||||
|
|
||||||
// Set up certificate manager if ACME is configured
|
// Set up certificate manager if ACME is configured
|
||||||
let cert_manager = Self::build_cert_manager(&options)
|
let cert_manager =
|
||||||
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
|
Self::build_cert_manager(&options).map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
|
||||||
|
|
||||||
let retention = options.metrics.as_ref()
|
let retention = options
|
||||||
|
.metrics
|
||||||
|
.as_ref()
|
||||||
.and_then(|m| m.retention_seconds)
|
.and_then(|m| m.retention_seconds)
|
||||||
.unwrap_or(3600) as usize;
|
.unwrap_or(3600) as usize;
|
||||||
|
|
||||||
|
let security_policy = Arc::new(ArcSwap::from(Arc::new(Self::build_ip_block_list(
|
||||||
|
options.security_policy.as_ref(),
|
||||||
|
))));
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
options,
|
options,
|
||||||
route_table: ArcSwap::from(Arc::new(route_manager)),
|
route_table: ArcSwap::from(Arc::new(route_manager)),
|
||||||
@@ -149,6 +160,7 @@ impl RustProxy {
|
|||||||
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
||||||
loaded_certs: HashMap::new(),
|
loaded_certs: HashMap::new(),
|
||||||
cancel_token: CancellationToken::new(),
|
cancel_token: CancellationToken::new(),
|
||||||
|
security_policy,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,11 +175,13 @@ impl RustProxy {
|
|||||||
// Apply default target if route has no targets
|
// Apply default target if route has no targets
|
||||||
if route.action.targets.is_none() {
|
if route.action.targets.is_none() {
|
||||||
if let Some(ref default_target) = defaults.target {
|
if let Some(ref default_target) = defaults.target {
|
||||||
debug!("Applying default target {}:{} to route {:?}",
|
debug!(
|
||||||
default_target.host, default_target.port,
|
"Applying default target {}:{} to route {:?}",
|
||||||
route.name.as_deref().unwrap_or("unnamed"));
|
default_target.host,
|
||||||
route.action.targets = Some(vec![
|
default_target.port,
|
||||||
rustproxy_config::RouteTarget {
|
route.name.as_deref().unwrap_or("unnamed")
|
||||||
|
);
|
||||||
|
route.action.targets = Some(vec![rustproxy_config::RouteTarget {
|
||||||
target_match: None,
|
target_match: None,
|
||||||
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
|
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
|
||||||
port: rustproxy_config::PortSpec::Fixed(default_target.port),
|
port: rustproxy_config::PortSpec::Fixed(default_target.port),
|
||||||
@@ -179,8 +193,7 @@ impl RustProxy {
|
|||||||
advanced: None,
|
advanced: None,
|
||||||
backend_transport: None,
|
backend_transport: None,
|
||||||
priority: None,
|
priority: None,
|
||||||
}
|
}]);
|
||||||
]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,7 +212,10 @@ impl RustProxy {
|
|||||||
|
|
||||||
if let Some(ref allow_list) = default_security.ip_allow_list {
|
if let Some(ref allow_list) = default_security.ip_allow_list {
|
||||||
security.ip_allow_list = Some(
|
security.ip_allow_list = Some(
|
||||||
allow_list.iter().map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone())).collect()
|
allow_list
|
||||||
|
.iter()
|
||||||
|
.map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone()))
|
||||||
|
.collect(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if let Some(ref block_list) = default_security.ip_block_list {
|
if let Some(ref block_list) = default_security.ip_block_list {
|
||||||
@@ -208,8 +224,10 @@ impl RustProxy {
|
|||||||
|
|
||||||
// Only apply if there's something meaningful
|
// Only apply if there's something meaningful
|
||||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||||
debug!("Applying default security to route {:?}",
|
debug!(
|
||||||
route.name.as_deref().unwrap_or("unnamed"));
|
"Applying default security to route {:?}",
|
||||||
|
route.name.as_deref().unwrap_or("unnamed")
|
||||||
|
);
|
||||||
route.security = Some(security);
|
route.security = Some(security);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -224,13 +242,17 @@ impl RustProxy {
|
|||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let email = acme.email.clone()
|
let email = acme.email.clone().or_else(|| acme.account_email.clone());
|
||||||
.or_else(|| acme.account_email.clone());
|
|
||||||
let use_production = acme.use_production.unwrap_or(false);
|
let use_production = acme.use_production.unwrap_or(false);
|
||||||
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
|
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
|
||||||
|
|
||||||
let store = CertStore::new();
|
let store = CertStore::new();
|
||||||
Some(CertManager::new(store, email, use_production, renew_before_days))
|
Some(CertManager::new(
|
||||||
|
store,
|
||||||
|
email,
|
||||||
|
use_production,
|
||||||
|
renew_before_days,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build ConnectionConfig from RustProxyOptions.
|
/// Build ConnectionConfig from RustProxyOptions.
|
||||||
@@ -248,7 +270,10 @@ impl RustProxy {
|
|||||||
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
|
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
|
||||||
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
|
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
|
||||||
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
|
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
|
||||||
proxy_ips: options.proxy_ips.as_deref().unwrap_or(&[])
|
proxy_ips: options
|
||||||
|
.proxy_ips
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(&[])
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|s| s.parse::<std::net::IpAddr>().ok())
|
.filter_map(|s| s.parse::<std::net::IpAddr>().ok())
|
||||||
.collect(),
|
.collect(),
|
||||||
@@ -258,6 +283,22 @@ impl RustProxy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_ip_block_list(policy: Option<&rustproxy_config::SecurityPolicy>) -> IpBlockList {
|
||||||
|
let Some(policy) = policy else {
|
||||||
|
return IpBlockList::empty();
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut entries = Vec::new();
|
||||||
|
if let Some(blocked_ips) = &policy.blocked_ips {
|
||||||
|
entries.extend(blocked_ips.iter().cloned());
|
||||||
|
}
|
||||||
|
if let Some(blocked_cidrs) = &policy.blocked_cidrs {
|
||||||
|
entries.extend(blocked_cidrs.iter().cloned());
|
||||||
|
}
|
||||||
|
|
||||||
|
IpBlockList::new(&entries)
|
||||||
|
}
|
||||||
|
|
||||||
/// Start the proxy, binding to all configured ports.
|
/// Start the proxy, binding to all configured ports.
|
||||||
pub async fn start(&mut self) -> Result<()> {
|
pub async fn start(&mut self) -> Result<()> {
|
||||||
if self.started {
|
if self.started {
|
||||||
@@ -272,7 +313,11 @@ impl RustProxy {
|
|||||||
let route_manager = self.route_table.load();
|
let route_manager = self.route_table.load();
|
||||||
let ports = route_manager.listening_ports();
|
let ports = route_manager.listening_ports();
|
||||||
|
|
||||||
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
|
info!(
|
||||||
|
"Configured {} routes on {} ports",
|
||||||
|
route_manager.route_count(),
|
||||||
|
ports.len()
|
||||||
|
);
|
||||||
|
|
||||||
// Create TCP listener manager with metrics
|
// Create TCP listener manager with metrics
|
||||||
let mut listener = TcpListenerManager::with_metrics(
|
let mut listener = TcpListenerManager::with_metrics(
|
||||||
@@ -282,7 +327,8 @@ impl RustProxy {
|
|||||||
|
|
||||||
// Apply connection config from options
|
// Apply connection config from options
|
||||||
let conn_config = Self::build_connection_config(&self.options);
|
let conn_config = Self::build_connection_config(&self.options);
|
||||||
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
|
debug!(
|
||||||
|
"Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
|
||||||
conn_config.connection_timeout_ms,
|
conn_config.connection_timeout_ms,
|
||||||
conn_config.initial_data_timeout_ms,
|
conn_config.initial_data_timeout_ms,
|
||||||
conn_config.socket_timeout_ms,
|
conn_config.socket_timeout_ms,
|
||||||
@@ -291,6 +337,7 @@ impl RustProxy {
|
|||||||
// Clone proxy_ips before conn_config is moved into the TCP listener
|
// Clone proxy_ips before conn_config is moved into the TCP listener
|
||||||
let udp_proxy_ips = conn_config.proxy_ips.clone();
|
let udp_proxy_ips = conn_config.proxy_ips.clone();
|
||||||
listener.set_connection_config(conn_config);
|
listener.set_connection_config(conn_config);
|
||||||
|
listener.set_security_policy(Arc::clone(&self.security_policy));
|
||||||
|
|
||||||
// Share the socket-handler relay path with the listener
|
// Share the socket-handler relay path with the listener
|
||||||
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
|
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
|
||||||
@@ -303,10 +350,13 @@ impl RustProxy {
|
|||||||
let cm = cm.lock().await;
|
let cm = cm.lock().await;
|
||||||
for (domain, bundle) in cm.store().iter() {
|
for (domain, bundle) in cm.store().iter() {
|
||||||
if !tls_configs.contains_key(domain) {
|
if !tls_configs.contains_key(domain) {
|
||||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
tls_configs.insert(
|
||||||
|
domain.clone(),
|
||||||
|
TlsCertConfig {
|
||||||
cert_pem: bundle.cert_pem.clone(),
|
cert_pem: bundle.cert_pem.clone(),
|
||||||
key_pem: bundle.key_pem.clone(),
|
key_pem: bundle.key_pem.clone(),
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -330,7 +380,9 @@ impl RustProxy {
|
|||||||
let mut tcp_ports = std::collections::HashSet::new();
|
let mut tcp_ports = std::collections::HashSet::new();
|
||||||
let mut udp_ports = std::collections::HashSet::new();
|
let mut udp_ports = std::collections::HashSet::new();
|
||||||
for route in &self.options.routes {
|
for route in &self.options.routes {
|
||||||
if !route.is_enabled() { continue; }
|
if !route.is_enabled() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let transport = route.route_match.transport.as_ref();
|
let transport = route.route_match.transport.as_ref();
|
||||||
let route_ports = route.route_match.ports.to_ports();
|
let route_ports = route.route_match.ports.to_ports();
|
||||||
for port in route_ports {
|
for port in route_ports {
|
||||||
@@ -371,6 +423,7 @@ impl RustProxy {
|
|||||||
connection_registry,
|
connection_registry,
|
||||||
);
|
);
|
||||||
udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
|
udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
|
||||||
|
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
|
||||||
|
|
||||||
// Share HttpProxyService with H3 — same route matching, connection
|
// Share HttpProxyService with H3 — same route matching, connection
|
||||||
// pool, and ALPN protocol detection as the TCP/HTTP path.
|
// pool, and ALPN protocol detection as the TCP/HTTP path.
|
||||||
@@ -379,10 +432,15 @@ impl RustProxy {
|
|||||||
udp_mgr.set_h3_service(Arc::new(h3_svc));
|
udp_mgr.set_h3_service(Arc::new(h3_svc));
|
||||||
|
|
||||||
for port in &udp_ports {
|
for port in &udp_ports {
|
||||||
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?;
|
udp_mgr
|
||||||
|
.add_port_with_tls(*port, quic_tls_config.clone())
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
info!("UDP listeners started on {} ports: {:?}",
|
info!(
|
||||||
udp_ports.len(), udp_mgr.listening_ports());
|
"UDP listeners started on {} ports: {:?}",
|
||||||
|
udp_ports.len(),
|
||||||
|
udp_mgr.listening_ports()
|
||||||
|
);
|
||||||
self.udp_listener_manager = Some(udp_mgr);
|
self.udp_listener_manager = Some(udp_mgr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -391,16 +449,22 @@ impl RustProxy {
|
|||||||
|
|
||||||
// Start the throughput sampling task with cooperative cancellation
|
// Start the throughput sampling task with cooperative cancellation
|
||||||
let metrics = Arc::clone(&self.metrics);
|
let metrics = Arc::clone(&self.metrics);
|
||||||
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone();
|
let conn_tracker = self
|
||||||
|
.listener_manager
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.conn_tracker()
|
||||||
|
.clone();
|
||||||
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
|
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
|
||||||
let interval_ms = self.options.metrics.as_ref()
|
let interval_ms = self
|
||||||
|
.options
|
||||||
|
.metrics
|
||||||
|
.as_ref()
|
||||||
.and_then(|m| m.sample_interval_ms)
|
.and_then(|m| m.sample_interval_ms)
|
||||||
.unwrap_or(1000);
|
.unwrap_or(1000);
|
||||||
let sampling_cancel = self.cancel_token.clone();
|
let sampling_cancel = self.cancel_token.clone();
|
||||||
self.sampling_handle = Some(tokio::spawn(async move {
|
self.sampling_handle = Some(tokio::spawn(async move {
|
||||||
let mut interval = tokio::time::interval(
|
let mut interval = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
|
||||||
std::time::Duration::from_millis(interval_ms)
|
|
||||||
);
|
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = sampling_cancel.cancelled() => break,
|
_ = sampling_cancel.cancelled() => break,
|
||||||
@@ -442,7 +506,10 @@ impl RustProxy {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let cert_spec = route.action.tls.as_ref()
|
let cert_spec = route
|
||||||
|
.action
|
||||||
|
.tls
|
||||||
|
.as_ref()
|
||||||
.and_then(|tls| tls.certificate.as_ref());
|
.and_then(|tls| tls.certificate.as_ref());
|
||||||
|
|
||||||
if let Some(CertificateSpec::Auto(_)) = cert_spec {
|
if let Some(CertificateSpec::Auto(_)) = cert_spec {
|
||||||
@@ -466,16 +533,25 @@ impl RustProxy {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
|
info!(
|
||||||
|
"Auto-provisioning certificates for {} domains",
|
||||||
|
domains_to_provision.len()
|
||||||
|
);
|
||||||
|
|
||||||
// Start challenge server
|
// Start challenge server
|
||||||
let acme_port = self.options.acme.as_ref()
|
let acme_port = self
|
||||||
|
.options
|
||||||
|
.acme
|
||||||
|
.as_ref()
|
||||||
.and_then(|a| a.port)
|
.and_then(|a| a.port)
|
||||||
.unwrap_or(80);
|
.unwrap_or(80);
|
||||||
|
|
||||||
let mut challenge_server = challenge_server::ChallengeServer::new();
|
let mut challenge_server = challenge_server::ChallengeServer::new();
|
||||||
if let Err(e) = challenge_server.start(acme_port).await {
|
if let Err(e) = challenge_server.start(acme_port).await {
|
||||||
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
|
error!(
|
||||||
|
"Failed to start ACME challenge server on port {}: {}",
|
||||||
|
acme_port, e
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -488,13 +564,15 @@ impl RustProxy {
|
|||||||
|
|
||||||
if let Some(acme_client) = acme_client {
|
if let Some(acme_client) = acme_client {
|
||||||
let challenge_server_ref = &challenge_server;
|
let challenge_server_ref = &challenge_server;
|
||||||
let result = acme_client.provision(domain, |pending| {
|
let result = acme_client
|
||||||
|
.provision(domain, |pending| {
|
||||||
challenge_server_ref.set_challenge(
|
challenge_server_ref.set_challenge(
|
||||||
pending.token.clone(),
|
pending.token.clone(),
|
||||||
pending.key_authorization.clone(),
|
pending.key_authorization.clone(),
|
||||||
);
|
);
|
||||||
async move { Ok(()) }
|
async move { Ok(()) }
|
||||||
}).await;
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok((cert_pem, key_pem)) => {
|
Ok((cert_pem, key_pem)) => {
|
||||||
@@ -539,7 +617,10 @@ impl RustProxy {
|
|||||||
None => return,
|
None => return,
|
||||||
};
|
};
|
||||||
|
|
||||||
let auto_renew = self.options.acme.as_ref()
|
let auto_renew = self
|
||||||
|
.options
|
||||||
|
.acme
|
||||||
|
.as_ref()
|
||||||
.and_then(|a| a.auto_renew)
|
.and_then(|a| a.auto_renew)
|
||||||
.unwrap_or(true);
|
.unwrap_or(true);
|
||||||
|
|
||||||
@@ -547,11 +628,17 @@ impl RustProxy {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let check_interval_hours = self.options.acme.as_ref()
|
let check_interval_hours = self
|
||||||
|
.options
|
||||||
|
.acme
|
||||||
|
.as_ref()
|
||||||
.and_then(|a| a.renew_check_interval_hours)
|
.and_then(|a| a.renew_check_interval_hours)
|
||||||
.unwrap_or(24);
|
.unwrap_or(24);
|
||||||
|
|
||||||
let acme_port = self.options.acme.as_ref()
|
let acme_port = self
|
||||||
|
.options
|
||||||
|
.acme
|
||||||
|
.as_ref()
|
||||||
.and_then(|a| a.port)
|
.and_then(|a| a.port)
|
||||||
.unwrap_or(80);
|
.unwrap_or(80);
|
||||||
|
|
||||||
@@ -664,8 +751,7 @@ impl RustProxy {
|
|||||||
/// Update routes atomically (hot-reload).
|
/// Update routes atomically (hot-reload).
|
||||||
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
|
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
|
||||||
// Validate new routes
|
// Validate new routes
|
||||||
rustproxy_config::validate_routes(&routes)
|
rustproxy_config::validate_routes(&routes).map_err(|errors| {
|
||||||
.map_err(|errors| {
|
|
||||||
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
|
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
|
||||||
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
|
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
|
||||||
})?;
|
})?;
|
||||||
@@ -673,8 +759,11 @@ impl RustProxy {
|
|||||||
let new_manager = RouteManager::new(routes.clone());
|
let new_manager = RouteManager::new(routes.clone());
|
||||||
let new_ports = new_manager.listening_ports();
|
let new_ports = new_manager.listening_ports();
|
||||||
|
|
||||||
info!("Updating routes: {} routes on {} ports",
|
info!(
|
||||||
new_manager.route_count(), new_ports.len());
|
"Updating routes: {} routes on {} ports",
|
||||||
|
new_manager.route_count(),
|
||||||
|
new_ports.len()
|
||||||
|
);
|
||||||
|
|
||||||
// Get old ports
|
// Get old ports
|
||||||
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
|
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
|
||||||
@@ -684,28 +773,35 @@ impl RustProxy {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Prune per-route metrics for route IDs that no longer exist
|
// Prune per-route metrics for route IDs that no longer exist
|
||||||
let active_route_ids: HashSet<String> = routes.iter()
|
let active_route_ids: HashSet<String> =
|
||||||
.filter_map(|r| r.id.clone())
|
routes.iter().filter_map(|r| r.id.clone()).collect();
|
||||||
.collect();
|
|
||||||
self.metrics.retain_routes(&active_route_ids);
|
self.metrics.retain_routes(&active_route_ids);
|
||||||
|
|
||||||
// Prune per-backend metrics for backends no longer in any route target.
|
// Prune per-backend metrics for backends no longer in any route target.
|
||||||
// For PortSpec::Preserve routes, expand across all listening ports since
|
// For PortSpec::Preserve routes, expand across all listening ports since
|
||||||
// the actual runtime port depends on the incoming connection.
|
// the actual runtime port depends on the incoming connection.
|
||||||
let listening_ports = self.get_listening_ports();
|
let listening_ports = self.get_listening_ports();
|
||||||
let active_backends: HashSet<String> = routes.iter()
|
let active_backends: HashSet<String> = routes
|
||||||
|
.iter()
|
||||||
.filter_map(|r| r.action.targets.as_ref())
|
.filter_map(|r| r.action.targets.as_ref())
|
||||||
.flat_map(|targets| targets.iter())
|
.flat_map(|targets| targets.iter())
|
||||||
.flat_map(|target| {
|
.flat_map(|target| {
|
||||||
let hosts: Vec<String> = target.host.to_vec().into_iter().map(|s| s.to_string()).collect();
|
let hosts: Vec<String> = target
|
||||||
|
.host
|
||||||
|
.to_vec()
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect();
|
||||||
match &target.port {
|
match &target.port {
|
||||||
rustproxy_config::PortSpec::Fixed(p) => {
|
rustproxy_config::PortSpec::Fixed(p) => hosts
|
||||||
hosts.into_iter().map(|h| format!("{}:{}", h, p)).collect::<Vec<_>>()
|
.into_iter()
|
||||||
}
|
.map(|h| format!("{}:{}", h, p))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
_ => {
|
_ => {
|
||||||
// Preserve/special: expand across all listening ports
|
// Preserve/special: expand across all listening ports
|
||||||
let lp = &listening_ports;
|
let lp = &listening_ports;
|
||||||
hosts.into_iter()
|
hosts
|
||||||
|
.into_iter()
|
||||||
.flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p)))
|
.flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p)))
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
@@ -733,10 +829,13 @@ impl RustProxy {
|
|||||||
let cm = cm_arc.lock().await;
|
let cm = cm_arc.lock().await;
|
||||||
for (domain, bundle) in cm.store().iter() {
|
for (domain, bundle) in cm.store().iter() {
|
||||||
if !tls_configs.contains_key(domain) {
|
if !tls_configs.contains_key(domain) {
|
||||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
tls_configs.insert(
|
||||||
|
domain.clone(),
|
||||||
|
TlsCertConfig {
|
||||||
cert_pem: bundle.cert_pem.clone(),
|
cert_pem: bundle.cert_pem.clone(),
|
||||||
key_pem: bundle.key_pem.clone(),
|
key_pem: bundle.key_pem.clone(),
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -753,7 +852,9 @@ impl RustProxy {
|
|||||||
// Cancel connections on routes that were removed or disabled
|
// Cancel connections on routes that were removed or disabled
|
||||||
listener.invalidate_removed_routes(&active_route_ids);
|
listener.invalidate_removed_routes(&active_route_ids);
|
||||||
// Clean up registry entries for removed routes
|
// Clean up registry entries for removed routes
|
||||||
listener.connection_registry().cleanup_removed_routes(&active_route_ids);
|
listener
|
||||||
|
.connection_registry()
|
||||||
|
.cleanup_removed_routes(&active_route_ids);
|
||||||
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
|
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
|
||||||
listener.prune_http_proxy_caches(&active_route_ids);
|
listener.prune_http_proxy_caches(&active_route_ids);
|
||||||
|
|
||||||
@@ -766,9 +867,10 @@ impl RustProxy {
|
|||||||
None => continue,
|
None => continue,
|
||||||
};
|
};
|
||||||
// Find corresponding old route
|
// Find corresponding old route
|
||||||
let old_route = old_manager.routes().iter().find(|r| {
|
let old_route = old_manager
|
||||||
r.id.as_deref() == Some(new_id)
|
.routes()
|
||||||
});
|
.iter()
|
||||||
|
.find(|r| r.id.as_deref() == Some(new_id));
|
||||||
let old_route = match old_route {
|
let old_route = match old_route {
|
||||||
Some(r) => r,
|
Some(r) => r,
|
||||||
None => continue, // new route, no existing connections to recycle
|
None => continue, // new route, no existing connections to recycle
|
||||||
@@ -812,11 +914,13 @@ impl RustProxy {
|
|||||||
{
|
{
|
||||||
let mut new_udp_ports = HashSet::new();
|
let mut new_udp_ports = HashSet::new();
|
||||||
for route in &routes {
|
for route in &routes {
|
||||||
if !route.is_enabled() { continue; }
|
if !route.is_enabled() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let transport = route.route_match.transport.as_ref();
|
let transport = route.route_match.transport.as_ref();
|
||||||
match transport {
|
match transport {
|
||||||
Some(rustproxy_config::TransportProtocol::Udp) |
|
Some(rustproxy_config::TransportProtocol::Udp)
|
||||||
Some(rustproxy_config::TransportProtocol::All) => {
|
| Some(rustproxy_config::TransportProtocol::All) => {
|
||||||
for port in route.route_match.ports.to_ports() {
|
for port in route.route_match.ports.to_ports() {
|
||||||
new_udp_ports.insert(port);
|
new_udp_ports.insert(port);
|
||||||
}
|
}
|
||||||
@@ -825,7 +929,8 @@ impl RustProxy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let old_udp_ports: HashSet<u16> = self.udp_listener_manager
|
let old_udp_ports: HashSet<u16> = self
|
||||||
|
.udp_listener_manager
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|u| u.listening_ports().into_iter().collect())
|
.map(|u| u.listening_ports().into_iter().collect())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
@@ -847,6 +952,7 @@ impl RustProxy {
|
|||||||
connection_registry,
|
connection_registry,
|
||||||
);
|
);
|
||||||
udp_mgr.set_proxy_ips(conn_config.proxy_ips);
|
udp_mgr.set_proxy_ips(conn_config.proxy_ips);
|
||||||
|
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
|
||||||
// Wire up H3ProxyService so QUIC connections can serve HTTP/3
|
// Wire up H3ProxyService so QUIC connections can serve HTTP/3
|
||||||
let http_proxy = listener.http_proxy().clone();
|
let http_proxy = listener.http_proxy().clone();
|
||||||
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
|
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
|
||||||
@@ -898,56 +1004,77 @@ impl RustProxy {
|
|||||||
|
|
||||||
/// Provision a certificate for a named route.
|
/// Provision a certificate for a named route.
|
||||||
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
|
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
|
||||||
let cm_arc = self.cert_manager.as_ref()
|
let cm_arc = self.cert_manager.as_ref().ok_or_else(|| {
|
||||||
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
|
anyhow::anyhow!("No certificate manager configured (ACME not enabled)")
|
||||||
|
})?;
|
||||||
|
|
||||||
// Find the route by name
|
// Find the route by name
|
||||||
let route = self.options.routes.iter()
|
let route = self
|
||||||
|
.options
|
||||||
|
.routes
|
||||||
|
.iter()
|
||||||
.find(|r| r.name.as_deref() == Some(route_name))
|
.find(|r| r.name.as_deref() == Some(route_name))
|
||||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
|
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
|
||||||
|
|
||||||
let domain = route.route_match.domains.as_ref()
|
let domain = route
|
||||||
|
.route_match
|
||||||
|
.domains
|
||||||
|
.as_ref()
|
||||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
|
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
|
||||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
|
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
|
||||||
|
|
||||||
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
|
info!(
|
||||||
|
"Provisioning certificate for route '{}' (domain: {})",
|
||||||
|
route_name, domain
|
||||||
|
);
|
||||||
|
|
||||||
// Start challenge server
|
// Start challenge server
|
||||||
let acme_port = self.options.acme.as_ref()
|
let acme_port = self
|
||||||
|
.options
|
||||||
|
.acme
|
||||||
|
.as_ref()
|
||||||
.and_then(|a| a.port)
|
.and_then(|a| a.port)
|
||||||
.unwrap_or(80);
|
.unwrap_or(80);
|
||||||
|
|
||||||
let mut cs = challenge_server::ChallengeServer::new();
|
let mut cs = challenge_server::ChallengeServer::new();
|
||||||
cs.start(acme_port).await
|
cs.start(acme_port)
|
||||||
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
|
||||||
|
|
||||||
let cs_ref = &cs;
|
let cs_ref = &cs;
|
||||||
let mut cm = cm_arc.lock().await;
|
let mut cm = cm_arc.lock().await;
|
||||||
let result = cm.renew_domain(&domain, |token, key_auth| {
|
let result = cm
|
||||||
|
.renew_domain(&domain, |token, key_auth| {
|
||||||
cs_ref.set_challenge(token, key_auth);
|
cs_ref.set_challenge(token, key_auth);
|
||||||
async {}
|
async {}
|
||||||
}).await;
|
})
|
||||||
|
.await;
|
||||||
drop(cm);
|
drop(cm);
|
||||||
|
|
||||||
cs.stop().await;
|
cs.stop().await;
|
||||||
|
|
||||||
let bundle = result
|
let bundle = result.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
||||||
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
|
||||||
|
|
||||||
// Hot-swap into TLS configs
|
// Hot-swap into TLS configs
|
||||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
tls_configs.insert(
|
||||||
|
domain.clone(),
|
||||||
|
TlsCertConfig {
|
||||||
cert_pem: bundle.cert_pem.clone(),
|
cert_pem: bundle.cert_pem.clone(),
|
||||||
key_pem: bundle.key_pem.clone(),
|
key_pem: bundle.key_pem.clone(),
|
||||||
});
|
},
|
||||||
|
);
|
||||||
{
|
{
|
||||||
let cm = cm_arc.lock().await;
|
let cm = cm_arc.lock().await;
|
||||||
for (d, b) in cm.store().iter() {
|
for (d, b) in cm.store().iter() {
|
||||||
if !tls_configs.contains_key(d) {
|
if !tls_configs.contains_key(d) {
|
||||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
tls_configs.insert(
|
||||||
|
d.clone(),
|
||||||
|
TlsCertConfig {
|
||||||
cert_pem: b.cert_pem.clone(),
|
cert_pem: b.cert_pem.clone(),
|
||||||
key_pem: b.key_pem.clone(),
|
key_pem: b.key_pem.clone(),
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -966,7 +1093,10 @@ impl RustProxy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Certificate provisioned and loaded for route '{}'", route_name);
|
info!(
|
||||||
|
"Certificate provisioned and loaded for route '{}'",
|
||||||
|
route_name
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -978,10 +1108,16 @@ impl RustProxy {
|
|||||||
|
|
||||||
/// Get the status of a certificate for a named route.
|
/// Get the status of a certificate for a named route.
|
||||||
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
|
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
|
||||||
let route = self.options.routes.iter()
|
let route = self
|
||||||
|
.options
|
||||||
|
.routes
|
||||||
|
.iter()
|
||||||
.find(|r| r.name.as_deref() == Some(route_name))?;
|
.find(|r| r.name.as_deref() == Some(route_name))?;
|
||||||
|
|
||||||
let domain = route.route_match.domains.as_ref()
|
let domain = route
|
||||||
|
.route_match
|
||||||
|
.domains
|
||||||
|
.as_ref()
|
||||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
|
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
|
||||||
|
|
||||||
if let Some(ref cm_arc) = self.cert_manager {
|
if let Some(ref cm_arc) = self.cert_manager {
|
||||||
@@ -1010,8 +1146,9 @@ impl RustProxy {
|
|||||||
let mut metrics = self.metrics.snapshot();
|
let mut metrics = self.metrics.snapshot();
|
||||||
if let Some(ref lm) = self.listener_manager {
|
if let Some(ref lm) = self.listener_manager {
|
||||||
let entries = lm.http_proxy().protocol_cache_snapshot();
|
let entries = lm.http_proxy().protocol_cache_snapshot();
|
||||||
metrics.detected_protocols = entries.into_iter().map(|e| {
|
metrics.detected_protocols = entries
|
||||||
rustproxy_metrics::ProtocolCacheEntryMetric {
|
.into_iter()
|
||||||
|
.map(|e| rustproxy_metrics::ProtocolCacheEntryMetric {
|
||||||
host: e.host,
|
host: e.host,
|
||||||
port: e.port,
|
port: e.port,
|
||||||
domain: e.domain,
|
domain: e.domain,
|
||||||
@@ -1026,8 +1163,8 @@ impl RustProxy {
|
|||||||
h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs,
|
h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs,
|
||||||
h2_consecutive_failures: e.h2_consecutive_failures,
|
h2_consecutive_failures: e.h2_consecutive_failures,
|
||||||
h3_consecutive_failures: e.h3_consecutive_failures,
|
h3_consecutive_failures: e.h3_consecutive_failures,
|
||||||
}
|
})
|
||||||
}).collect();
|
.collect();
|
||||||
}
|
}
|
||||||
metrics
|
metrics
|
||||||
}
|
}
|
||||||
@@ -1058,9 +1195,7 @@ impl RustProxy {
|
|||||||
|
|
||||||
/// Get statistics snapshot.
|
/// Get statistics snapshot.
|
||||||
pub fn get_statistics(&self) -> Statistics {
|
pub fn get_statistics(&self) -> Statistics {
|
||||||
let uptime = self.started_at
|
let uptime = self.started_at.map(|t| t.elapsed().as_secs()).unwrap_or(0);
|
||||||
.map(|t| t.elapsed().as_secs())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
Statistics {
|
Statistics {
|
||||||
active_connections: self.metrics.active_connections(),
|
active_connections: self.metrics.active_connections(),
|
||||||
@@ -1071,6 +1206,13 @@ impl RustProxy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Update the global ingress security policy.
|
||||||
|
pub fn set_security_policy(&mut self, policy: rustproxy_config::SecurityPolicy) {
|
||||||
|
self.security_policy
|
||||||
|
.store(Arc::new(Self::build_ip_block_list(Some(&policy))));
|
||||||
|
self.options.security_policy = Some(policy);
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
|
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
|
||||||
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
|
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
|
||||||
/// take effect immediately for all new connections.
|
/// take effect immediately for all new connections.
|
||||||
@@ -1130,10 +1272,13 @@ impl RustProxy {
|
|||||||
let cm = cm_arc.lock().await;
|
let cm = cm_arc.lock().await;
|
||||||
for (d, b) in cm.store().iter() {
|
for (d, b) in cm.store().iter() {
|
||||||
if !configs.contains_key(d) {
|
if !configs.contains_key(d) {
|
||||||
configs.insert(d.clone(), TlsCertConfig {
|
configs.insert(
|
||||||
|
d.clone(),
|
||||||
|
TlsCertConfig {
|
||||||
cert_pem: b.cert_pem.clone(),
|
cert_pem: b.cert_pem.clone(),
|
||||||
key_pem: b.key_pem.clone(),
|
key_pem: b.key_pem.clone(),
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1166,7 +1311,8 @@ impl RustProxy {
|
|||||||
info!("Loading certificate for domain: {}", domain);
|
info!("Loading certificate for domain: {}", domain);
|
||||||
|
|
||||||
// Check if the cert actually changed (for selective connection recycling)
|
// Check if the cert actually changed (for selective connection recycling)
|
||||||
let cert_changed = self.loaded_certs
|
let cert_changed = self
|
||||||
|
.loaded_certs
|
||||||
.get(domain)
|
.get(domain)
|
||||||
.map(|existing| existing.cert_pem != cert_pem)
|
.map(|existing| existing.cert_pem != cert_pem)
|
||||||
.unwrap_or(false); // new domain = no existing connections to recycle
|
.unwrap_or(false); // new domain = no existing connections to recycle
|
||||||
@@ -1196,10 +1342,13 @@ impl RustProxy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Persist in loaded_certs so future rebuild calls include this cert
|
// Persist in loaded_certs so future rebuild calls include this cert
|
||||||
self.loaded_certs.insert(domain.to_string(), TlsCertConfig {
|
self.loaded_certs.insert(
|
||||||
|
domain.to_string(),
|
||||||
|
TlsCertConfig {
|
||||||
cert_pem: cert_pem.clone(),
|
cert_pem: cert_pem.clone(),
|
||||||
key_pem: key_pem.clone(),
|
key_pem: key_pem.clone(),
|
||||||
});
|
},
|
||||||
|
);
|
||||||
|
|
||||||
// Hot-swap TLS config on TCP and QUIC listeners
|
// Hot-swap TLS config on TCP and QUIC listeners
|
||||||
let tls_configs = self.current_tls_configs().await;
|
let tls_configs = self.current_tls_configs().await;
|
||||||
@@ -1222,7 +1371,9 @@ impl RustProxy {
|
|||||||
// Recycle existing connections if cert actually changed
|
// Recycle existing connections if cert actually changed
|
||||||
if cert_changed {
|
if cert_changed {
|
||||||
if let Some(ref listener) = self.listener_manager {
|
if let Some(ref listener) = self.listener_manager {
|
||||||
listener.connection_registry().recycle_for_cert_change(domain);
|
listener
|
||||||
|
.connection_registry()
|
||||||
|
.recycle_for_cert_change(domain);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1244,16 +1395,22 @@ impl RustProxy {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let cert_spec = route.action.tls.as_ref()
|
let cert_spec = route
|
||||||
|
.action
|
||||||
|
.tls
|
||||||
|
.as_ref()
|
||||||
.and_then(|tls| tls.certificate.as_ref());
|
.and_then(|tls| tls.certificate.as_ref());
|
||||||
|
|
||||||
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
|
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
|
||||||
if let Some(ref domains) = route.route_match.domains {
|
if let Some(ref domains) = route.route_match.domains {
|
||||||
for domain in domains.to_vec() {
|
for domain in domains.to_vec() {
|
||||||
configs.insert(domain.to_string(), TlsCertConfig {
|
configs.insert(
|
||||||
|
domain.to_string(),
|
||||||
|
TlsCertConfig {
|
||||||
cert_pem: cert_config.cert.clone(),
|
cert_pem: cert_config.cert.clone(),
|
||||||
key_pem: cert_config.key.clone(),
|
key_pem: cert_config.key.clone(),
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
#[global_allocator]
|
#[global_allocator]
|
||||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
use anyhow::Result;
|
|
||||||
|
|
||||||
use rustproxy::RustProxy;
|
|
||||||
use rustproxy::management;
|
use rustproxy::management;
|
||||||
|
use rustproxy::RustProxy;
|
||||||
use rustproxy_config::RustProxyOptions;
|
use rustproxy_config::RustProxyOptions;
|
||||||
|
|
||||||
/// RustProxy - High-performance multi-protocol proxy
|
/// RustProxy - High-performance multi-protocol proxy
|
||||||
@@ -43,8 +43,7 @@ async fn main() -> Result<()> {
|
|||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_writer(std::io::stderr)
|
.with_writer(std::io::stderr)
|
||||||
.with_env_filter(
|
.with_env_filter(
|
||||||
EnvFilter::try_from_default_env()
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
|
||||||
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
|
|
||||||
)
|
)
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
@@ -60,11 +59,7 @@ async fn main() -> Result<()> {
|
|||||||
let options = RustProxyOptions::from_file(&cli.config)
|
let options = RustProxyOptions::from_file(&cli.config)
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
|
||||||
|
|
||||||
tracing::info!(
|
tracing::info!("Loaded {} routes from {}", options.routes.len(), cli.config);
|
||||||
"Loaded {} routes from {}",
|
|
||||||
options.routes.len(),
|
|
||||||
cli.config
|
|
||||||
);
|
|
||||||
|
|
||||||
// Validate-only mode
|
// Validate-only mode
|
||||||
if cli.validate {
|
if cli.validate {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||||
use tracing::{info, error};
|
use tracing::{error, info};
|
||||||
|
|
||||||
use crate::RustProxy;
|
use crate::RustProxy;
|
||||||
use rustproxy_config::RustProxyOptions;
|
use rustproxy_config::RustProxyOptions;
|
||||||
@@ -141,14 +141,19 @@ async fn handle_request(
|
|||||||
"start" => handle_start(&id, &request.params, proxy).await,
|
"start" => handle_start(&id, &request.params, proxy).await,
|
||||||
"stop" => handle_stop(&id, proxy).await,
|
"stop" => handle_stop(&id, proxy).await,
|
||||||
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
|
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
|
||||||
|
"setSecurityPolicy" => handle_set_security_policy(&id, &request.params, proxy),
|
||||||
"getMetrics" => handle_get_metrics(&id, proxy),
|
"getMetrics" => handle_get_metrics(&id, proxy),
|
||||||
"getStatistics" => handle_get_statistics(&id, proxy),
|
"getStatistics" => handle_get_statistics(&id, proxy),
|
||||||
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
|
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
|
||||||
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
|
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
|
||||||
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
|
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
|
||||||
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
|
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
|
||||||
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
|
"setSocketHandlerRelay" => {
|
||||||
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await,
|
handle_set_socket_handler_relay(&id, &request.params, proxy).await
|
||||||
|
}
|
||||||
|
"setDatagramHandlerRelay" => {
|
||||||
|
handle_set_datagram_handler_relay(&id, &request.params, proxy).await
|
||||||
|
}
|
||||||
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
|
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
|
||||||
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
|
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
|
||||||
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
|
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
|
||||||
@@ -167,7 +172,12 @@ async fn handle_start(
|
|||||||
|
|
||||||
let config = match params.get("config") {
|
let config = match params.get("config") {
|
||||||
Some(config) => config,
|
Some(config) => config,
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
"Missing 'config' parameter".to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
|
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
|
||||||
@@ -176,8 +186,7 @@ async fn handle_start(
|
|||||||
};
|
};
|
||||||
|
|
||||||
match RustProxy::new(options) {
|
match RustProxy::new(options) {
|
||||||
Ok(mut p) => {
|
Ok(mut p) => match p.start().await {
|
||||||
match p.start().await {
|
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
send_event("started", serde_json::json!({}));
|
send_event("started", serde_json::json!({}));
|
||||||
*proxy = Some(p);
|
*proxy = Some(p);
|
||||||
@@ -187,27 +196,21 @@ async fn handle_start(
|
|||||||
send_event("error", serde_json::json!({"message": format!("{}", e)}));
|
send_event("error", serde_json::json!({"message": format!("{}", e)}));
|
||||||
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
|
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
|
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_stop(
|
async fn handle_stop(id: &str, proxy: &mut Option<RustProxy>) -> ManagementResponse {
|
||||||
id: &str,
|
|
||||||
proxy: &mut Option<RustProxy>,
|
|
||||||
) -> ManagementResponse {
|
|
||||||
match proxy.as_mut() {
|
match proxy.as_mut() {
|
||||||
Some(p) => {
|
Some(p) => match p.stop().await {
|
||||||
match p.stop().await {
|
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
*proxy = None;
|
*proxy = None;
|
||||||
send_event("stopped", serde_json::json!({}));
|
send_event("stopped", serde_json::json!({}));
|
||||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||||
}
|
}
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
|
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
|
||||||
}
|
},
|
||||||
}
|
|
||||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -224,7 +227,12 @@ async fn handle_update_routes(
|
|||||||
|
|
||||||
let routes = match params.get("routes") {
|
let routes = match params.get("routes") {
|
||||||
Some(routes) => routes,
|
Some(routes) => routes,
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
"Missing 'routes' parameter".to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
|
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
|
||||||
@@ -234,36 +242,72 @@ async fn handle_update_routes(
|
|||||||
|
|
||||||
match p.update_routes(routes).await {
|
match p.update_routes(routes).await {
|
||||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
|
Err(e) => {
|
||||||
|
ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_get_metrics(
|
fn handle_set_security_policy(
|
||||||
id: &str,
|
id: &str,
|
||||||
proxy: &Option<RustProxy>,
|
params: &serde_json::Value,
|
||||||
|
proxy: &mut Option<RustProxy>,
|
||||||
) -> ManagementResponse {
|
) -> ManagementResponse {
|
||||||
|
let p = match proxy.as_mut() {
|
||||||
|
Some(p) => p,
|
||||||
|
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let policy = match params.get("policy") {
|
||||||
|
Some(policy) => policy,
|
||||||
|
None => {
|
||||||
|
return ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
"Missing 'policy' parameter".to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let policy: rustproxy_config::SecurityPolicy = match serde_json::from_value(policy.clone()) {
|
||||||
|
Ok(policy) => policy,
|
||||||
|
Err(e) => {
|
||||||
|
return ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
format!("Invalid security policy: {}", e),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
p.set_security_policy(policy);
|
||||||
|
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_get_metrics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||||
match proxy.as_ref() {
|
match proxy.as_ref() {
|
||||||
Some(p) => {
|
Some(p) => {
|
||||||
let metrics = p.get_metrics();
|
let metrics = p.get_metrics();
|
||||||
match serde_json::to_value(&metrics) {
|
match serde_json::to_value(&metrics) {
|
||||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
|
Err(e) => ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
format!("Failed to serialize metrics: {}", e),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_get_statistics(
|
fn handle_get_statistics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||||
id: &str,
|
|
||||||
proxy: &Option<RustProxy>,
|
|
||||||
) -> ManagementResponse {
|
|
||||||
match proxy.as_ref() {
|
match proxy.as_ref() {
|
||||||
Some(p) => {
|
Some(p) => {
|
||||||
let stats = p.get_statistics();
|
let stats = p.get_statistics();
|
||||||
match serde_json::to_value(&stats) {
|
match serde_json::to_value(&stats) {
|
||||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
|
Err(e) => ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
format!("Failed to serialize statistics: {}", e),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||||
@@ -282,12 +326,20 @@ async fn handle_provision_certificate(
|
|||||||
|
|
||||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||||
Some(name) => name.to_string(),
|
Some(name) => name.to_string(),
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
"Missing 'routeName' parameter".to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match p.provision_certificate(&route_name).await {
|
match p.provision_certificate(&route_name).await {
|
||||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
|
Err(e) => ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
format!("Failed to provision certificate: {}", e),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,12 +355,20 @@ async fn handle_renew_certificate(
|
|||||||
|
|
||||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||||
Some(name) => name.to_string(),
|
Some(name) => name.to_string(),
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
"Missing 'routeName' parameter".to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match p.renew_certificate(&route_name).await {
|
match p.renew_certificate(&route_name).await {
|
||||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
|
Err(e) => ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
format!("Failed to renew certificate: {}", e),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,24 +384,29 @@ async fn handle_get_certificate_status(
|
|||||||
|
|
||||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||||
Some(name) => name,
|
Some(name) => name,
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
"Missing 'routeName' parameter".to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match p.get_certificate_status(route_name).await {
|
match p.get_certificate_status(route_name).await {
|
||||||
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
|
Some(status) => ManagementResponse::ok(
|
||||||
|
id.to_string(),
|
||||||
|
serde_json::json!({
|
||||||
"domain": status.domain,
|
"domain": status.domain,
|
||||||
"source": status.source,
|
"source": status.source,
|
||||||
"expiresAt": status.expires_at,
|
"expiresAt": status.expires_at,
|
||||||
"isValid": status.is_valid,
|
"isValid": status.is_valid,
|
||||||
})),
|
}),
|
||||||
|
),
|
||||||
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
|
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_get_listening_ports(
|
fn handle_get_listening_ports(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||||
id: &str,
|
|
||||||
proxy: &Option<RustProxy>,
|
|
||||||
) -> ManagementResponse {
|
|
||||||
match proxy.as_ref() {
|
match proxy.as_ref() {
|
||||||
Some(p) => {
|
Some(p) => {
|
||||||
let ports = p.get_listening_ports();
|
let ports = p.get_listening_ports();
|
||||||
@@ -361,7 +426,8 @@ async fn handle_set_socket_handler_relay(
|
|||||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let socket_path = params.get("socketPath")
|
let socket_path = params
|
||||||
|
.get("socketPath")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.map(|s| s.to_string());
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
@@ -381,7 +447,8 @@ async fn handle_set_datagram_handler_relay(
|
|||||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let socket_path = params.get("socketPath")
|
let socket_path = params
|
||||||
|
.get("socketPath")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.map(|s| s.to_string());
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
@@ -403,12 +470,17 @@ async fn handle_add_listening_port(
|
|||||||
|
|
||||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||||
Some(port) => port as u16,
|
Some(port) => port as u16,
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match p.add_listening_port(port).await {
|
match p.add_listening_port(port).await {
|
||||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
|
Err(e) => ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
format!("Failed to add port {}: {}", port, e),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,12 +496,17 @@ async fn handle_remove_listening_port(
|
|||||||
|
|
||||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||||
Some(port) => port as u16,
|
Some(port) => port as u16,
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match p.remove_listening_port(port).await {
|
match p.remove_listening_port(port).await {
|
||||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
|
Err(e) => ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
format!("Failed to remove port {}: {}", port, e),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -445,26 +522,41 @@ async fn handle_load_certificate(
|
|||||||
|
|
||||||
let domain = match params.get("domain").and_then(|v| v.as_str()) {
|
let domain = match params.get("domain").and_then(|v| v.as_str()) {
|
||||||
Some(d) => d.to_string(),
|
Some(d) => d.to_string(),
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
"Missing 'domain' parameter".to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let cert = match params.get("cert").and_then(|v| v.as_str()) {
|
let cert = match params.get("cert").and_then(|v| v.as_str()) {
|
||||||
Some(c) => c.to_string(),
|
Some(c) => c.to_string(),
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let key = match params.get("key").and_then(|v| v.as_str()) {
|
let key = match params.get("key").and_then(|v| v.as_str()) {
|
||||||
Some(k) => k.to_string(),
|
Some(k) => k.to_string(),
|
||||||
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
|
None => {
|
||||||
|
return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
|
let ca = params
|
||||||
|
.get("ca")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
info!("loadCertificate: domain={}", domain);
|
info!("loadCertificate: domain={}", domain);
|
||||||
|
|
||||||
// Load cert into cert manager and hot-swap TLS config
|
// Load cert into cert manager and hot-swap TLS config
|
||||||
match p.load_certificate(&domain, cert, key, ca).await {
|
match p.load_certificate(&domain, cert, key, ca).await {
|
||||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
|
Err(e) => ManagementResponse::err(
|
||||||
|
id.to_string(),
|
||||||
|
format!("Failed to load certificate for {}: {}", domain, e),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,7 +136,8 @@ pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandl
|
|||||||
let path = parts.get(1).copied().unwrap_or("/");
|
let path = parts.get(1).copied().unwrap_or("/");
|
||||||
|
|
||||||
// Extract Host header
|
// Extract Host header
|
||||||
let host = req_str.lines()
|
let host = req_str
|
||||||
|
.lines()
|
||||||
.find(|l| l.to_lowercase().starts_with("host:"))
|
.find(|l| l.to_lowercase().starts_with("host:"))
|
||||||
.map(|l| l[5..].trim())
|
.map(|l| l[5..].trim())
|
||||||
.unwrap_or("unknown");
|
.unwrap_or("unknown");
|
||||||
@@ -336,7 +337,8 @@ pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
|
|||||||
let req_str = String::from_utf8_lossy(&buf[..n]);
|
let req_str = String::from_utf8_lossy(&buf[..n]);
|
||||||
|
|
||||||
// Extract Sec-WebSocket-Key for proper handshake
|
// Extract Sec-WebSocket-Key for proper handshake
|
||||||
let ws_key = req_str.lines()
|
let ws_key = req_str
|
||||||
|
.lines()
|
||||||
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
|
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
|
||||||
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
@@ -378,7 +380,9 @@ pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
|
|||||||
use rcgen::{CertificateParams, KeyPair};
|
use rcgen::{CertificateParams, KeyPair};
|
||||||
|
|
||||||
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
|
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
|
||||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
params
|
||||||
|
.distinguished_name
|
||||||
|
.push(rcgen::DnType::CommonName, domain);
|
||||||
|
|
||||||
let key_pair = KeyPair::generate().unwrap();
|
let key_pair = KeyPair::generate().unwrap();
|
||||||
let cert = params.self_signed(&key_pair).unwrap();
|
let cert = params.self_signed(&key_pair).unwrap();
|
||||||
@@ -458,11 +462,7 @@ pub fn make_tls_terminate_route(
|
|||||||
|
|
||||||
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
|
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
|
||||||
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
|
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
|
||||||
pub async fn start_tls_ws_echo_backend(
|
pub async fn start_tls_ws_echo_backend(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
|
||||||
port: u16,
|
|
||||||
cert_pem: &str,
|
|
||||||
key_pem: &str,
|
|
||||||
) -> JoinHandle<()> {
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
|
use bytes::Buf;
|
||||||
use common::*;
|
use common::*;
|
||||||
use rustproxy::RustProxy;
|
use rustproxy::RustProxy;
|
||||||
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic};
|
use rustproxy_config::{RouteQuic, RouteUdp, RustProxyOptions, TransportProtocol};
|
||||||
use bytes::Buf;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
|
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
|
||||||
@@ -14,7 +14,14 @@ fn make_h3_route(
|
|||||||
cert_pem: &str,
|
cert_pem: &str,
|
||||||
key_pem: &str,
|
key_pem: &str,
|
||||||
) -> rustproxy_config::RouteConfig {
|
) -> rustproxy_config::RouteConfig {
|
||||||
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem);
|
let mut route = make_tls_terminate_route(
|
||||||
|
port,
|
||||||
|
"localhost",
|
||||||
|
target_host,
|
||||||
|
target_port,
|
||||||
|
cert_pem,
|
||||||
|
key_pem,
|
||||||
|
);
|
||||||
route.route_match.transport = Some(TransportProtocol::All);
|
route.route_match.transport = Some(TransportProtocol::All);
|
||||||
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
|
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
|
||||||
route.action.udp = Some(RouteUdp {
|
route.action.udp = Some(RouteUdp {
|
||||||
@@ -89,9 +96,7 @@ async fn test_h3_response_stream_finishes() {
|
|||||||
.await
|
.await
|
||||||
.expect("QUIC handshake failed");
|
.expect("QUIC handshake failed");
|
||||||
|
|
||||||
let (mut driver, mut send_request) = h3::client::new(
|
let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(connection))
|
||||||
h3_quinn::Connection::new(connection),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.expect("H3 connection setup failed");
|
.expect("H3 connection setup failed");
|
||||||
|
|
||||||
@@ -108,33 +113,46 @@ async fn test_h3_response_stream_finishes() {
|
|||||||
.body(())
|
.body(())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut stream = send_request.send_request(req).await
|
let mut stream = send_request
|
||||||
|
.send_request(req)
|
||||||
|
.await
|
||||||
.expect("Failed to send H3 request");
|
.expect("Failed to send H3 request");
|
||||||
stream.finish().await
|
stream
|
||||||
|
.finish()
|
||||||
|
.await
|
||||||
.expect("Failed to finish sending H3 request body");
|
.expect("Failed to finish sending H3 request body");
|
||||||
|
|
||||||
// 6. Read response headers
|
// 6. Read response headers
|
||||||
let resp = stream.recv_response().await
|
let resp = stream
|
||||||
|
.recv_response()
|
||||||
|
.await
|
||||||
.expect("Failed to receive H3 response");
|
.expect("Failed to receive H3 response");
|
||||||
assert_eq!(resp.status(), http::StatusCode::OK,
|
assert_eq!(
|
||||||
"Expected 200 OK, got {}", resp.status());
|
resp.status(),
|
||||||
|
http::StatusCode::OK,
|
||||||
|
"Expected 200 OK, got {}",
|
||||||
|
resp.status()
|
||||||
|
);
|
||||||
|
|
||||||
// 7. Read body and verify stream ends (FIN received)
|
// 7. Read body and verify stream ends (FIN received)
|
||||||
// This is the critical assertion: recv_data() must return None (stream ended)
|
// This is the critical assertion: recv_data() must return None (stream ended)
|
||||||
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
|
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut total = 0usize;
|
let mut total = 0usize;
|
||||||
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
|
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
|
||||||
total += chunk.remaining();
|
total += chunk.remaining();
|
||||||
}
|
}
|
||||||
// recv_data() returned None => stream ended (FIN received)
|
// recv_data() returned None => stream ended (FIN received)
|
||||||
total
|
total
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let bytes_received = result.expect(
|
let bytes_received = result.expect(
|
||||||
"TIMEOUT: H3 stream never ended (FIN not received by client). \
|
"TIMEOUT: H3 stream never ended (FIN not received by client). \
|
||||||
The proxy sent all response data but failed to send the QUIC stream FIN."
|
The proxy sent all response data but failed to send the QUIC stream FIN.",
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
bytes_received,
|
bytes_received,
|
||||||
|
|||||||
@@ -43,17 +43,32 @@ async fn test_http_forward_basic() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
|
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
|
||||||
let body = extract_body(&response);
|
let body = extract_body(&response);
|
||||||
body.to_string()
|
body.to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
|
assert!(
|
||||||
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
|
result.contains(r#""method":"GET"#),
|
||||||
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
|
"Expected GET method, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
result.contains(r#""path":"/hello"#),
|
||||||
|
"Expected /hello path, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
result.contains(r#""backend":"main"#),
|
||||||
|
"Expected main backend, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -69,8 +84,18 @@ async fn test_http_forward_host_routing() {
|
|||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![
|
routes: vec![
|
||||||
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
|
make_test_route(
|
||||||
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
|
proxy_port,
|
||||||
|
Some("alpha.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
backend1_port,
|
||||||
|
),
|
||||||
|
make_test_route(
|
||||||
|
proxy_port,
|
||||||
|
Some("beta.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
backend2_port,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -80,24 +105,38 @@ async fn test_http_forward_host_routing() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Test alpha domain
|
// Test alpha domain
|
||||||
let alpha_result = with_timeout(async {
|
let alpha_result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
|
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
|
||||||
extract_body(&response).to_string()
|
extract_body(&response).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
|
assert!(
|
||||||
|
alpha_result.contains(r#""backend":"alpha"#),
|
||||||
|
"Expected alpha backend, got: {}",
|
||||||
|
alpha_result
|
||||||
|
);
|
||||||
|
|
||||||
// Test beta domain
|
// Test beta domain
|
||||||
let beta_result = with_timeout(async {
|
let beta_result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
|
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
|
||||||
extract_body(&response).to_string()
|
extract_body(&response).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
|
assert!(
|
||||||
|
beta_result.contains(r#""backend":"beta"#),
|
||||||
|
"Expected beta backend, got: {}",
|
||||||
|
beta_result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -127,24 +166,38 @@ async fn test_http_forward_path_routing() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Test API path
|
// Test API path
|
||||||
let api_result = with_timeout(async {
|
let api_result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
|
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
|
||||||
extract_body(&response).to_string()
|
extract_body(&response).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
|
assert!(
|
||||||
|
api_result.contains(r#""backend":"api"#),
|
||||||
|
"Expected api backend, got: {}",
|
||||||
|
api_result
|
||||||
|
);
|
||||||
|
|
||||||
// Test web path (no /api prefix)
|
// Test web path (no /api prefix)
|
||||||
let web_result = with_timeout(async {
|
let web_result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
|
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
|
||||||
extract_body(&response).to_string()
|
extract_body(&response).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
|
assert!(
|
||||||
|
web_result.contains(r#""backend":"web"#),
|
||||||
|
"Expected web backend, got: {}",
|
||||||
|
web_result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -184,9 +237,18 @@ async fn test_http_forward_cors_preflight() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Should get 204 No Content with CORS headers
|
// Should get 204 No Content with CORS headers
|
||||||
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
|
assert!(
|
||||||
assert!(result.to_lowercase().contains("access-control-allow-origin"),
|
result.contains("204"),
|
||||||
"Expected CORS header, got: {}", result);
|
"Expected 204 status, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.to_lowercase()
|
||||||
|
.contains("access-control-allow-origin"),
|
||||||
|
"Expected CORS header, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -208,15 +270,22 @@ async fn test_http_forward_backend_error() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
|
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
|
||||||
response
|
response
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Proxy should relay the 500 from backend
|
// Proxy should relay the 500 from backend
|
||||||
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
|
assert!(
|
||||||
|
result.contains("500"),
|
||||||
|
"Expected 500 status, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -227,7 +296,12 @@ async fn test_http_forward_no_route_matched() {
|
|||||||
|
|
||||||
// Create a route only for a specific domain
|
// Create a route only for a specific domain
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
|
routes: vec![make_test_route(
|
||||||
|
proxy_port,
|
||||||
|
Some("known.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
9999,
|
||||||
|
)],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -235,15 +309,22 @@ async fn test_http_forward_no_route_matched() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
|
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
|
||||||
response
|
response
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Should get 502 Bad Gateway (no route matched)
|
// Should get 502 Bad Gateway (no route matched)
|
||||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
assert!(
|
||||||
|
result.contains("502"),
|
||||||
|
"Expected 502 status, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -262,15 +343,22 @@ async fn test_http_forward_backend_unavailable() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
|
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
|
||||||
response
|
response
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Should get 502 Bad Gateway (backend unavailable)
|
// Should get 502 Bad Gateway (backend unavailable)
|
||||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
assert!(
|
||||||
|
result.contains("502"),
|
||||||
|
"Expected 502 status, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -286,7 +374,12 @@ async fn test_https_terminate_http_forward() {
|
|||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![make_tls_terminate_route(
|
routes: vec![make_tls_terminate_route(
|
||||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
proxy_port,
|
||||||
|
domain,
|
||||||
|
"127.0.0.1",
|
||||||
|
backend_port,
|
||||||
|
&cert_pem,
|
||||||
|
&key_pem,
|
||||||
)],
|
)],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -295,7 +388,8 @@ async fn test_https_terminate_http_forward() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
let tls_config = rustls::ClientConfig::builder()
|
let tls_config = rustls::ClientConfig::builder()
|
||||||
.dangerous()
|
.dangerous()
|
||||||
@@ -319,14 +413,28 @@ async fn test_https_terminate_http_forward() {
|
|||||||
let mut response = Vec::new();
|
let mut response = Vec::new();
|
||||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||||
String::from_utf8_lossy(&response).to_string()
|
String::from_utf8_lossy(&response).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let body = extract_body(&result);
|
let body = extract_body(&result);
|
||||||
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
|
assert!(
|
||||||
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
|
body.contains(r#""method":"GET"#),
|
||||||
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
|
"Expected GET, got: {}",
|
||||||
|
body
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
body.contains(r#""path":"/api/data"#),
|
||||||
|
"Expected /api/data, got: {}",
|
||||||
|
body
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
body.contains(r#""backend":"tls-backend"#),
|
||||||
|
"Expected tls-backend, got: {}",
|
||||||
|
body
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -347,7 +455,8 @@ async fn test_websocket_through_proxy() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -369,7 +478,9 @@ async fn test_websocket_through_proxy() {
|
|||||||
let mut temp = [0u8; 1];
|
let mut temp = [0u8; 1];
|
||||||
loop {
|
loop {
|
||||||
let n = stream.read(&mut temp).await.unwrap();
|
let n = stream.read(&mut temp).await.unwrap();
|
||||||
if n == 0 { break; }
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
response_buf.push(temp[0]);
|
response_buf.push(temp[0]);
|
||||||
if response_buf.len() >= 4 {
|
if response_buf.len() >= 4 {
|
||||||
let len = response_buf.len();
|
let len = response_buf.len();
|
||||||
@@ -380,7 +491,11 @@ async fn test_websocket_through_proxy() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let response_str = String::from_utf8_lossy(&response_buf).to_string();
|
let response_str = String::from_utf8_lossy(&response_buf).to_string();
|
||||||
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
|
assert!(
|
||||||
|
response_str.contains("101"),
|
||||||
|
"Expected 101 Switching Protocols, got: {}",
|
||||||
|
response_str
|
||||||
|
);
|
||||||
assert!(
|
assert!(
|
||||||
response_str.to_lowercase().contains("upgrade: websocket"),
|
response_str.to_lowercase().contains("upgrade: websocket"),
|
||||||
"Expected Upgrade header, got: {}",
|
"Expected Upgrade header, got: {}",
|
||||||
@@ -399,7 +514,9 @@ async fn test_websocket_through_proxy() {
|
|||||||
assert_eq!(echoed, test_data, "Expected echo of sent data");
|
assert_eq!(echoed, test_data, "Expected echo of sent data");
|
||||||
|
|
||||||
"ok".to_string()
|
"ok".to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -431,12 +548,22 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
|||||||
|
|
||||||
// Create terminate-and-reencrypt routes
|
// Create terminate-and-reencrypt routes
|
||||||
let mut route1 = make_tls_terminate_route(
|
let mut route1 = make_tls_terminate_route(
|
||||||
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1,
|
proxy_port,
|
||||||
|
"alpha.example.com",
|
||||||
|
"127.0.0.1",
|
||||||
|
backend1_port,
|
||||||
|
&cert1,
|
||||||
|
&key1,
|
||||||
);
|
);
|
||||||
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||||
|
|
||||||
let mut route2 = make_tls_terminate_route(
|
let mut route2 = make_tls_terminate_route(
|
||||||
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2,
|
proxy_port,
|
||||||
|
"beta.example.com",
|
||||||
|
"127.0.0.1",
|
||||||
|
backend2_port,
|
||||||
|
&cert2,
|
||||||
|
&key2,
|
||||||
);
|
);
|
||||||
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||||
|
|
||||||
@@ -450,7 +577,8 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt
|
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt
|
||||||
let alpha_result = with_timeout(async {
|
let alpha_result = with_timeout(
|
||||||
|
async {
|
||||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
let tls_config = rustls::ClientConfig::builder()
|
let tls_config = rustls::ClientConfig::builder()
|
||||||
.dangerous()
|
.dangerous()
|
||||||
@@ -461,16 +589,20 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
|||||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
let server_name =
|
||||||
|
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||||
|
|
||||||
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
|
let request =
|
||||||
|
"GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
|
||||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||||
|
|
||||||
let mut response = Vec::new();
|
let mut response = Vec::new();
|
||||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||||
String::from_utf8_lossy(&response).to_string()
|
String::from_utf8_lossy(&response).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -498,7 +630,8 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Test beta domain - different host goes to different backend
|
// Test beta domain - different host goes to different backend
|
||||||
let beta_result = with_timeout(async {
|
let beta_result = with_timeout(
|
||||||
|
async {
|
||||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
let tls_config = rustls::ClientConfig::builder()
|
let tls_config = rustls::ClientConfig::builder()
|
||||||
.dangerous()
|
.dangerous()
|
||||||
@@ -509,16 +642,20 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
|||||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
|
let server_name =
|
||||||
|
rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
|
||||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||||
|
|
||||||
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
|
let request =
|
||||||
|
"GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
|
||||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||||
|
|
||||||
let mut response = Vec::new();
|
let mut response = Vec::new();
|
||||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||||
String::from_utf8_lossy(&response).to_string()
|
String::from_utf8_lossy(&response).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -589,14 +726,12 @@ async fn test_terminate_and_reencrypt_websocket() {
|
|||||||
.dangerous()
|
.dangerous()
|
||||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||||
.with_no_client_auth();
|
.with_no_client_auth();
|
||||||
let connector =
|
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||||
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
|
||||||
|
|
||||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let server_name =
|
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||||
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
|
||||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||||
|
|
||||||
// Send WebSocket upgrade request through TLS
|
// Send WebSocket upgrade request through TLS
|
||||||
@@ -685,10 +820,13 @@ async fn test_protocol_field_in_route_config() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// HTTP request should match the route and get proxied
|
// HTTP request should match the route and get proxied
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
|
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
|
||||||
extract_body(&response).to_string()
|
extract_body(&response).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -20,13 +20,19 @@ async fn test_start_and_stop() {
|
|||||||
assert!(!wait_for_port(port, 200).await);
|
assert!(!wait_for_port(port, 200).await);
|
||||||
|
|
||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
|
assert!(
|
||||||
|
wait_for_port(port, 2000).await,
|
||||||
|
"Port should be listening after start"
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
|
|
||||||
// Give the OS a moment to release the port
|
// Give the OS a moment to release the port
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||||
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
|
assert!(
|
||||||
|
!wait_for_port(port, 200).await,
|
||||||
|
"Port should not be listening after stop"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -54,7 +60,12 @@ async fn test_update_routes_hot_reload() {
|
|||||||
let port = next_port();
|
let port = next_port();
|
||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
|
routes: vec![make_test_route(
|
||||||
|
port,
|
||||||
|
Some("old.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
8080,
|
||||||
|
)],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -62,9 +73,12 @@ async fn test_update_routes_hot_reload() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
|
|
||||||
// Update routes atomically
|
// Update routes atomically
|
||||||
let new_routes = vec![
|
let new_routes = vec![make_test_route(
|
||||||
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
|
port,
|
||||||
];
|
Some("new.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
9090,
|
||||||
|
)];
|
||||||
let result = proxy.update_routes(new_routes).await;
|
let result = proxy.update_routes(new_routes).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
@@ -87,15 +101,24 @@ async fn test_add_remove_listening_port() {
|
|||||||
|
|
||||||
// Add a new port
|
// Add a new port
|
||||||
proxy.add_listening_port(port2).await.unwrap();
|
proxy.add_listening_port(port2).await.unwrap();
|
||||||
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
|
assert!(
|
||||||
|
wait_for_port(port2, 2000).await,
|
||||||
|
"New port should be listening"
|
||||||
|
);
|
||||||
|
|
||||||
// Remove the port
|
// Remove the port
|
||||||
proxy.remove_listening_port(port2).await.unwrap();
|
proxy.remove_listening_port(port2).await.unwrap();
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||||
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
|
assert!(
|
||||||
|
!wait_for_port(port2, 200).await,
|
||||||
|
"Removed port should not be listening"
|
||||||
|
);
|
||||||
|
|
||||||
// Original port should still be listening
|
// Original port should still be listening
|
||||||
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
|
assert!(
|
||||||
|
wait_for_port(port1, 200).await,
|
||||||
|
"Original port should still be listening"
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -168,7 +191,11 @@ async fn test_metrics_track_connections() {
|
|||||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||||
|
|
||||||
let stats = proxy.get_statistics();
|
let stats = proxy.get_statistics();
|
||||||
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
|
assert!(
|
||||||
|
stats.total_connections > 0,
|
||||||
|
"Expected total_connections > 0, got {}",
|
||||||
|
stats.total_connections
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -205,8 +232,11 @@ async fn test_metrics_track_bytes() {
|
|||||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||||
|
|
||||||
let stats = proxy.get_statistics();
|
let stats = proxy.get_statistics();
|
||||||
assert!(stats.total_connections > 0,
|
assert!(
|
||||||
"Expected some connections tracked, got {}", stats.total_connections);
|
stats.total_connections > 0,
|
||||||
|
"Expected some connections tracked, got {}",
|
||||||
|
stats.total_connections
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -228,23 +258,38 @@ async fn test_hot_reload_port_changes() {
|
|||||||
let mut proxy = RustProxy::new(options).unwrap();
|
let mut proxy = RustProxy::new(options).unwrap();
|
||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(port1, 2000).await);
|
assert!(wait_for_port(port1, 2000).await);
|
||||||
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
|
assert!(
|
||||||
|
!wait_for_port(port2, 200).await,
|
||||||
|
"port2 should not be listening yet"
|
||||||
|
);
|
||||||
|
|
||||||
// Update routes to use port2 instead
|
// Update routes to use port2 instead
|
||||||
let new_routes = vec![
|
let new_routes = vec![make_test_route(port2, None, "127.0.0.1", backend_port)];
|
||||||
make_test_route(port2, None, "127.0.0.1", backend_port),
|
|
||||||
];
|
|
||||||
proxy.update_routes(new_routes).await.unwrap();
|
proxy.update_routes(new_routes).await.unwrap();
|
||||||
|
|
||||||
// Port2 should now be listening, port1 should be closed
|
// Port2 should now be listening, port1 should be closed
|
||||||
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
|
assert!(
|
||||||
|
wait_for_port(port2, 2000).await,
|
||||||
|
"port2 should be listening after reload"
|
||||||
|
);
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||||
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
|
assert!(
|
||||||
|
!wait_for_port(port1, 200).await,
|
||||||
|
"port1 should be closed after reload"
|
||||||
|
);
|
||||||
|
|
||||||
// Verify port2 works
|
// Verify port2 works
|
||||||
let ports = proxy.get_listening_ports();
|
let ports = proxy.get_listening_ports();
|
||||||
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
|
assert!(
|
||||||
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
|
ports.contains(&port2),
|
||||||
|
"Expected port2 in listening ports: {:?}",
|
||||||
|
ports
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!ports.contains(&port1),
|
||||||
|
"port1 should not be in listening ports: {:?}",
|
||||||
|
ports
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,10 +24,14 @@ async fn test_tcp_forward_echo() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
|
|
||||||
// Wait for proxy to be ready
|
// Wait for proxy to be ready
|
||||||
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
|
assert!(
|
||||||
|
wait_for_port(proxy_port, 2000).await,
|
||||||
|
"Proxy port not ready"
|
||||||
|
);
|
||||||
|
|
||||||
// Connect and send data
|
// Connect and send data
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -36,7 +40,9 @@ async fn test_tcp_forward_echo() {
|
|||||||
let mut buf = vec![0u8; 1024];
|
let mut buf = vec![0u8; 1024];
|
||||||
let n = stream.read(&mut buf).await.unwrap();
|
let n = stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 5)
|
},
|
||||||
|
5,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -61,7 +67,8 @@ async fn test_tcp_forward_large_payload() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -75,7 +82,9 @@ async fn test_tcp_forward_large_payload() {
|
|||||||
let mut received = Vec::new();
|
let mut received = Vec::new();
|
||||||
stream.read_to_end(&mut received).await.unwrap();
|
stream.read_to_end(&mut received).await.unwrap();
|
||||||
received.len()
|
received.len()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -100,7 +109,8 @@ async fn test_tcp_forward_multiple_connections() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut handles = Vec::new();
|
let mut handles = Vec::new();
|
||||||
for i in 0..10 {
|
for i in 0..10 {
|
||||||
let port = proxy_port;
|
let port = proxy_port;
|
||||||
@@ -122,7 +132,9 @@ async fn test_tcp_forward_multiple_connections() {
|
|||||||
results.push(handle.await.unwrap());
|
results.push(handle.await.unwrap());
|
||||||
}
|
}
|
||||||
results
|
results
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -149,14 +161,20 @@ async fn test_tcp_forward_backend_unreachable() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Connection should complete (proxy accepts it) but data should not flow
|
// Connection should complete (proxy accepts it) but data should not flow
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
|
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
|
||||||
stream.is_ok()
|
stream.is_ok()
|
||||||
}, 5)
|
},
|
||||||
|
5,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(result, "Should be able to connect to proxy even if backend is down");
|
assert!(
|
||||||
|
result,
|
||||||
|
"Should be able to connect to proxy even if backend is down"
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -178,7 +196,8 @@ async fn test_tcp_forward_bidirectional() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -187,7 +206,9 @@ async fn test_tcp_forward_bidirectional() {
|
|||||||
let mut buf = vec![0u8; 1024];
|
let mut buf = vec![0u8; 1024];
|
||||||
let n = stream.read(&mut buf).await.unwrap();
|
let n = stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 5)
|
},
|
||||||
|
5,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -65,8 +65,18 @@ async fn test_tls_passthrough_sni_routing() {
|
|||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![
|
routes: vec![
|
||||||
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
|
make_tls_passthrough_route(
|
||||||
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
|
proxy_port,
|
||||||
|
Some("one.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
backend1_port,
|
||||||
|
),
|
||||||
|
make_tls_passthrough_route(
|
||||||
|
proxy_port,
|
||||||
|
Some("two.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
backend2_port,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -76,7 +86,8 @@ async fn test_tls_passthrough_sni_routing() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Send a fake ClientHello with SNI "one.example.com"
|
// Send a fake ClientHello with SNI "one.example.com"
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -86,15 +97,22 @@ async fn test_tls_passthrough_sni_routing() {
|
|||||||
let mut buf = vec![0u8; 4096];
|
let mut buf = vec![0u8; 4096];
|
||||||
let n = stream.read(&mut buf).await.unwrap();
|
let n = stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 5)
|
},
|
||||||
|
5,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Backend1 should have received the ClientHello and prefixed its response
|
// Backend1 should have received the ClientHello and prefixed its response
|
||||||
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
|
assert!(
|
||||||
|
result.starts_with("BACKEND1:"),
|
||||||
|
"Expected BACKEND1 prefix, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
|
||||||
// Now test routing to backend2
|
// Now test routing to backend2
|
||||||
let result2 = with_timeout(async {
|
let result2 = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -104,11 +122,17 @@ async fn test_tls_passthrough_sni_routing() {
|
|||||||
let mut buf = vec![0u8; 4096];
|
let mut buf = vec![0u8; 4096];
|
||||||
let n = stream.read(&mut buf).await.unwrap();
|
let n = stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 5)
|
},
|
||||||
|
5,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
|
assert!(
|
||||||
|
result2.starts_with("BACKEND2:"),
|
||||||
|
"Expected BACKEND2 prefix, got: {}",
|
||||||
|
result2
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -121,9 +145,12 @@ async fn test_tls_passthrough_unknown_sni() {
|
|||||||
let _backend = start_echo_server(backend_port).await;
|
let _backend = start_echo_server(backend_port).await;
|
||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![
|
routes: vec![make_tls_passthrough_route(
|
||||||
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
|
proxy_port,
|
||||||
],
|
Some("known.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
backend_port,
|
||||||
|
)],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -132,7 +159,8 @@ async fn test_tls_passthrough_unknown_sni() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Send ClientHello with unknown SNI - should get no response (connection dropped)
|
// Send ClientHello with unknown SNI - should get no response (connection dropped)
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -146,7 +174,9 @@ async fn test_tls_passthrough_unknown_sni() {
|
|||||||
Ok(_) => false, // Got data = route shouldn't have matched
|
Ok(_) => false, // Got data = route shouldn't have matched
|
||||||
Err(_) => true, // Error = connection dropped
|
Err(_) => true, // Error = connection dropped
|
||||||
}
|
}
|
||||||
}, 5)
|
},
|
||||||
|
5,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -163,9 +193,12 @@ async fn test_tls_passthrough_wildcard_domain() {
|
|||||||
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
|
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
|
||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![
|
routes: vec![make_tls_passthrough_route(
|
||||||
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
|
proxy_port,
|
||||||
],
|
Some("*.example.com"),
|
||||||
|
"127.0.0.1",
|
||||||
|
backend_port,
|
||||||
|
)],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -174,7 +207,8 @@ async fn test_tls_passthrough_wildcard_domain() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Should match any subdomain of example.com
|
// Should match any subdomain of example.com
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -184,11 +218,17 @@ async fn test_tls_passthrough_wildcard_domain() {
|
|||||||
let mut buf = vec![0u8; 4096];
|
let mut buf = vec![0u8; 4096];
|
||||||
let n = stream.read(&mut buf).await.unwrap();
|
let n = stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 5)
|
},
|
||||||
|
5,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
|
assert!(
|
||||||
|
result.starts_with("WILDCARD:"),
|
||||||
|
"Expected WILDCARD prefix, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -222,7 +262,8 @@ async fn test_tls_passthrough_multiple_domains() {
|
|||||||
("beta.example.com", "B2:"),
|
("beta.example.com", "B2:"),
|
||||||
("gamma.example.com", "B3:"),
|
("gamma.example.com", "B3:"),
|
||||||
] {
|
] {
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -232,14 +273,18 @@ async fn test_tls_passthrough_multiple_domains() {
|
|||||||
let mut buf = vec![0u8; 4096];
|
let mut buf = vec![0u8; 4096];
|
||||||
let n = stream.read(&mut buf).await.unwrap();
|
let n = stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 5)
|
},
|
||||||
|
5,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
result.starts_with(expected_prefix),
|
result.starts_with(expected_prefix),
|
||||||
"Domain {} should route to {}, got: {}",
|
"Domain {} should route to {}, got: {}",
|
||||||
domain, expected_prefix, result
|
domain,
|
||||||
|
expected_prefix,
|
||||||
|
result
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,12 @@ async fn test_tls_terminate_basic() {
|
|||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![make_tls_terminate_route(
|
routes: vec![make_tls_terminate_route(
|
||||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
proxy_port,
|
||||||
|
domain,
|
||||||
|
"127.0.0.1",
|
||||||
|
backend_port,
|
||||||
|
&cert_pem,
|
||||||
|
&key_pem,
|
||||||
)],
|
)],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -84,7 +89,8 @@ async fn test_tls_terminate_basic() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Connect with TLS client
|
// Connect with TLS client
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let tls_config = make_insecure_tls_client_config();
|
let tls_config = make_insecure_tls_client_config();
|
||||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||||
|
|
||||||
@@ -100,7 +106,9 @@ async fn test_tls_terminate_basic() {
|
|||||||
let mut buf = vec![0u8; 1024];
|
let mut buf = vec![0u8; 1024];
|
||||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -125,7 +133,12 @@ async fn test_tls_terminate_and_reencrypt() {
|
|||||||
|
|
||||||
// Create terminate-and-reencrypt route
|
// Create terminate-and-reencrypt route
|
||||||
let mut route = make_tls_terminate_route(
|
let mut route = make_tls_terminate_route(
|
||||||
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
|
proxy_port,
|
||||||
|
domain,
|
||||||
|
"127.0.0.1",
|
||||||
|
backend_port,
|
||||||
|
&proxy_cert,
|
||||||
|
&proxy_key,
|
||||||
);
|
);
|
||||||
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||||
|
|
||||||
@@ -138,7 +151,8 @@ async fn test_tls_terminate_and_reencrypt() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let tls_config = make_insecure_tls_client_config();
|
let tls_config = make_insecure_tls_client_config();
|
||||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||||
|
|
||||||
@@ -154,7 +168,9 @@ async fn test_tls_terminate_and_reencrypt() {
|
|||||||
let mut buf = vec![0u8; 1024];
|
let mut buf = vec![0u8; 1024];
|
||||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -177,8 +193,22 @@ async fn test_tls_terminate_sni_cert_selection() {
|
|||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![
|
routes: vec![
|
||||||
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
|
make_tls_terminate_route(
|
||||||
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
|
proxy_port,
|
||||||
|
"alpha.example.com",
|
||||||
|
"127.0.0.1",
|
||||||
|
backend1_port,
|
||||||
|
&cert1,
|
||||||
|
&key1,
|
||||||
|
),
|
||||||
|
make_tls_terminate_route(
|
||||||
|
proxy_port,
|
||||||
|
"beta.example.com",
|
||||||
|
"127.0.0.1",
|
||||||
|
backend2_port,
|
||||||
|
&cert2,
|
||||||
|
&key2,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -188,7 +218,8 @@ async fn test_tls_terminate_sni_cert_selection() {
|
|||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
// Test alpha domain
|
// Test alpha domain
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let tls_config = make_insecure_tls_client_config();
|
let tls_config = make_insecure_tls_client_config();
|
||||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||||
|
|
||||||
@@ -196,7 +227,8 @@ async fn test_tls_terminate_sni_cert_selection() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
let server_name =
|
||||||
|
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||||
|
|
||||||
tls_stream.write_all(b"test").await.unwrap();
|
tls_stream.write_all(b"test").await.unwrap();
|
||||||
@@ -204,11 +236,17 @@ async fn test_tls_terminate_sni_cert_selection() {
|
|||||||
let mut buf = vec![0u8; 1024];
|
let mut buf = vec![0u8; 1024];
|
||||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||||
}, 10)
|
},
|
||||||
|
10,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
|
assert!(
|
||||||
|
result.starts_with("ALPHA:"),
|
||||||
|
"Expected ALPHA prefix, got: {}",
|
||||||
|
result
|
||||||
|
);
|
||||||
|
|
||||||
proxy.stop().await.unwrap();
|
proxy.stop().await.unwrap();
|
||||||
}
|
}
|
||||||
@@ -224,7 +262,12 @@ async fn test_tls_terminate_large_payload() {
|
|||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![make_tls_terminate_route(
|
routes: vec![make_tls_terminate_route(
|
||||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
proxy_port,
|
||||||
|
domain,
|
||||||
|
"127.0.0.1",
|
||||||
|
backend_port,
|
||||||
|
&cert_pem,
|
||||||
|
&key_pem,
|
||||||
)],
|
)],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -233,7 +276,8 @@ async fn test_tls_terminate_large_payload() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let tls_config = make_insecure_tls_client_config();
|
let tls_config = make_insecure_tls_client_config();
|
||||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||||
|
|
||||||
@@ -252,7 +296,9 @@ async fn test_tls_terminate_large_payload() {
|
|||||||
let mut received = Vec::new();
|
let mut received = Vec::new();
|
||||||
tls_stream.read_to_end(&mut received).await.unwrap();
|
tls_stream.read_to_end(&mut received).await.unwrap();
|
||||||
received.len()
|
received.len()
|
||||||
}, 15)
|
},
|
||||||
|
15,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -272,7 +318,12 @@ async fn test_tls_terminate_concurrent() {
|
|||||||
|
|
||||||
let options = RustProxyOptions {
|
let options = RustProxyOptions {
|
||||||
routes: vec![make_tls_terminate_route(
|
routes: vec![make_tls_terminate_route(
|
||||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
proxy_port,
|
||||||
|
domain,
|
||||||
|
"127.0.0.1",
|
||||||
|
backend_port,
|
||||||
|
&cert_pem,
|
||||||
|
&key_pem,
|
||||||
)],
|
)],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -281,7 +332,8 @@ async fn test_tls_terminate_concurrent() {
|
|||||||
proxy.start().await.unwrap();
|
proxy.start().await.unwrap();
|
||||||
assert!(wait_for_port(proxy_port, 2000).await);
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||||||
|
|
||||||
let result = with_timeout(async {
|
let result = with_timeout(
|
||||||
|
async {
|
||||||
let mut handles = Vec::new();
|
let mut handles = Vec::new();
|
||||||
for i in 0..10 {
|
for i in 0..10 {
|
||||||
let port = proxy_port;
|
let port = proxy_port;
|
||||||
@@ -311,7 +363,9 @@ async fn test_tls_terminate_concurrent() {
|
|||||||
results.push(handle.await.unwrap());
|
results.push(handle.await.unwrap());
|
||||||
}
|
}
|
||||||
results
|
results
|
||||||
}, 15)
|
},
|
||||||
|
15,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,6 @@
|
|||||||
*/
|
*/
|
||||||
export const commitinfo = {
|
export const commitinfo = {
|
||||||
name: '@push.rocks/smartproxy',
|
name: '@push.rocks/smartproxy',
|
||||||
version: '27.8.2',
|
version: '27.9.0',
|
||||||
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
|
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -7,7 +7,7 @@ export { SmartProxy } from './proxies/smart-proxy/index.js';
|
|||||||
export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js';
|
export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js';
|
||||||
|
|
||||||
// Export smart-proxy models
|
// Export smart-proxy models
|
||||||
export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
|
export type { ISmartProxyOptions, ISmartProxySecurityPolicy, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
|
||||||
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
|
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
|
||||||
export * from './proxies/smart-proxy/utils/index.js';
|
export * from './proxies/smart-proxy/utils/index.js';
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,6 @@
|
|||||||
* SmartProxy models
|
* SmartProxy models
|
||||||
*/
|
*/
|
||||||
// Export everything except IAcmeOptions from interfaces
|
// Export everything except IAcmeOptions from interfaces
|
||||||
export type { ISmartProxyOptions, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
|
export type { ISmartProxyOptions, ISmartProxySecurityPolicy, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
|
||||||
export * from './route-types.js';
|
export * from './route-types.js';
|
||||||
export * from './metrics-types.js';
|
export * from './metrics-types.js';
|
||||||
|
|||||||
@@ -29,6 +29,11 @@ export interface ISmartProxyCertStore {
|
|||||||
}
|
}
|
||||||
import type { IRouteConfig } from './route-types.js';
|
import type { IRouteConfig } from './route-types.js';
|
||||||
|
|
||||||
|
export interface ISmartProxySecurityPolicy {
|
||||||
|
blockedIps?: string[];
|
||||||
|
blockedCidrs?: string[];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provision object for static or HTTP-01 certificate
|
* Provision object for static or HTTP-01 certificate
|
||||||
*/
|
*/
|
||||||
@@ -137,6 +142,7 @@ export interface ISmartProxyOptions {
|
|||||||
// Rate limiting and security
|
// Rate limiting and security
|
||||||
maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP
|
maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP
|
||||||
connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP
|
connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP
|
||||||
|
securityPolicy?: ISmartProxySecurityPolicy; // Global ingress block policy, enforced before routing
|
||||||
|
|
||||||
// Enhanced keep-alive settings
|
// Enhanced keep-alive settings
|
||||||
keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections
|
keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js';
|
import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js';
|
||||||
import type { IAcmeOptions, ISmartProxyOptions } from './interfaces.js';
|
import type { IAcmeOptions, ISmartProxyOptions, ISmartProxySecurityPolicy } from './interfaces.js';
|
||||||
import type {
|
import type {
|
||||||
IRouteAction,
|
IRouteAction,
|
||||||
IRouteConfig,
|
IRouteConfig,
|
||||||
@@ -75,6 +75,7 @@ export interface IRustProxyOptions {
|
|||||||
keepAliveInactivityMultiplier?: number;
|
keepAliveInactivityMultiplier?: number;
|
||||||
extendedKeepAliveLifetime?: number;
|
extendedKeepAliveLifetime?: number;
|
||||||
metrics?: ISmartProxyOptions['metrics'];
|
metrics?: ISmartProxyOptions['metrics'];
|
||||||
|
securityPolicy?: ISmartProxySecurityPolicy;
|
||||||
acme?: IRustAcmeOptions;
|
acme?: IRustAcmeOptions;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import type {
|
|||||||
IRustRouteConfig,
|
IRustRouteConfig,
|
||||||
IRustStatistics,
|
IRustStatistics,
|
||||||
} from './models/rust-types.js';
|
} from './models/rust-types.js';
|
||||||
|
import type { ISmartProxySecurityPolicy } from './models/interfaces.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Type-safe command definitions for the Rust proxy IPC protocol.
|
* Type-safe command definitions for the Rust proxy IPC protocol.
|
||||||
@@ -15,6 +16,7 @@ type TSmartProxyCommands = {
|
|||||||
start: { params: { config: IRustProxyOptions }; result: void };
|
start: { params: { config: IRustProxyOptions }; result: void };
|
||||||
stop: { params: Record<string, never>; result: void };
|
stop: { params: Record<string, never>; result: void };
|
||||||
updateRoutes: { params: { routes: IRustRouteConfig[] }; result: void };
|
updateRoutes: { params: { routes: IRustRouteConfig[] }; result: void };
|
||||||
|
setSecurityPolicy: { params: { policy: ISmartProxySecurityPolicy }; result: void };
|
||||||
getMetrics: { params: Record<string, never>; result: IRustMetricsSnapshot };
|
getMetrics: { params: Record<string, never>; result: IRustMetricsSnapshot };
|
||||||
getStatistics: { params: Record<string, never>; result: IRustStatistics };
|
getStatistics: { params: Record<string, never>; result: IRustStatistics };
|
||||||
provisionCertificate: { params: { routeName: string }; result: void };
|
provisionCertificate: { params: { routeName: string }; result: void };
|
||||||
@@ -139,6 +141,10 @@ export class RustProxyBridge extends plugins.EventEmitter {
|
|||||||
await this.bridge.sendCommand('updateRoutes', { routes });
|
await this.bridge.sendCommand('updateRoutes', { routes });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async setSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
|
||||||
|
await this.bridge.sendCommand('setSecurityPolicy', { policy });
|
||||||
|
}
|
||||||
|
|
||||||
public async getMetrics(): Promise<IRustMetricsSnapshot> {
|
public async getMetrics(): Promise<IRustMetricsSnapshot> {
|
||||||
return this.bridge.sendCommand('getMetrics', {} as Record<string, never>);
|
return this.bridge.sendCommand('getMetrics', {} as Record<string, never>);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import { Mutex } from './utils/mutex.js';
|
|||||||
import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js';
|
import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js';
|
||||||
|
|
||||||
// Types
|
// Types
|
||||||
import type { ISmartProxyOptions, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
|
import type { ISmartProxyOptions, ISmartProxySecurityPolicy, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
|
||||||
import type { IRouteConfig } from './models/route-types.js';
|
import type { IRouteConfig } from './models/route-types.js';
|
||||||
import type { IMetrics } from './models/metrics-types.js';
|
import type { IMetrics } from './models/metrics-types.js';
|
||||||
import type { IRustCertificateStatus, IRustProxyOptions, IRustStatistics } from './models/rust-types.js';
|
import type { IRustCertificateStatus, IRustProxyOptions, IRustStatistics } from './models/rust-types.js';
|
||||||
@@ -350,6 +350,15 @@ export class SmartProxy extends plugins.EventEmitter {
|
|||||||
.catch((err) => logger.log('error', `Unexpected error in cert provisioning after route update: ${err.message}`, { component: 'smart-proxy' }));
|
.catch((err) => logger.log('error', `Unexpected error in cert provisioning after route update: ${err.message}`, { component: 'smart-proxy' }));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the global ingress security policy without changing routes.
|
||||||
|
* The Rust engine applies this before route selection and backend connection.
|
||||||
|
*/
|
||||||
|
public async updateSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
|
||||||
|
this.settings.securityPolicy = policy;
|
||||||
|
await this.bridge.setSecurityPolicy(policy);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provision a certificate for a named route.
|
* Provision a certificate for a named route.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -182,6 +182,7 @@ export function buildRustProxyOptions(
|
|||||||
keepAliveInactivityMultiplier: settings.keepAliveInactivityMultiplier,
|
keepAliveInactivityMultiplier: settings.keepAliveInactivityMultiplier,
|
||||||
extendedKeepAliveLifetime: settings.extendedKeepAliveLifetime,
|
extendedKeepAliveLifetime: settings.extendedKeepAliveLifetime,
|
||||||
metrics: settings.metrics,
|
metrics: settings.metrics,
|
||||||
|
securityPolicy: settings.securityPolicy,
|
||||||
acme: serializeAcmeForRust(acme),
|
acme: serializeAcmeForRust(acme),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user