feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates

This commit is contained in:
2026-02-09 10:55:46 +00:00
parent a31fee41df
commit 1df3b7af4a
151 changed files with 16927 additions and 19432 deletions

View File

@@ -0,0 +1,25 @@
# Manifest for the `rustproxy-passthrough` crate.
# Shared package metadata (version/edition/license/authors) is inherited from
# the workspace root via the `.workspace = true` keys; dependency versions are
# likewise pinned centrally in the workspace manifest.
[package]
name = "rustproxy-passthrough"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Raw TCP/SNI passthrough engine for RustProxy"

# Sibling workspace crates plus the async runtime, TLS stack, and serde.
[dependencies]
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-metrics = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
dashmap = { workspace = true }
arc-swap = { workspace = true }
rustproxy-http = { workspace = true }
rustls = { workspace = true }
tokio-rustls = { workspace = true }
rustls-pemfile = { workspace = true }
tokio-util = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@@ -0,0 +1,155 @@
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Per-connection tracking record with atomics for lock-free updates.
///
/// Each field uses atomics so that the forwarding tasks can update
/// bytes_received / bytes_sent / last_activity without holding any lock,
/// while the zombie scanner reads them concurrently.
pub struct ConnectionRecord {
    /// Unique connection ID assigned by the ConnectionTracker.
    pub id: u64,
    /// Wall-clock instant when this connection was created.
    pub created_at: Instant,
    /// Milliseconds since `created_at` at the moment of the most recent
    /// activity; stays 0 until the first `touch`. Updated atomically by
    /// the forwarding loops.
    pub last_activity: AtomicU64,
    /// Total bytes received from the client (inbound).
    pub bytes_received: AtomicU64,
    /// Total bytes sent to the client (outbound / from backend).
    pub bytes_sent: AtomicU64,
    /// True once the client side of the connection has closed.
    pub client_closed: AtomicBool,
    /// True once the backend side of the connection has closed.
    pub backend_closed: AtomicBool,
    /// Whether this connection uses TLS (affects zombie thresholds).
    pub is_tls: AtomicBool,
    /// Whether this connection has keep-alive semantics.
    pub has_keep_alive: AtomicBool,
}

impl ConnectionRecord {
    /// Build a record for connection `id`: counters zeroed, flags cleared,
    /// and the creation instant captured immediately.
    pub fn new(id: u64) -> Self {
        ConnectionRecord {
            id,
            created_at: Instant::now(),
            last_activity: AtomicU64::default(),
            bytes_received: AtomicU64::default(),
            bytes_sent: AtomicU64::default(),
            client_closed: AtomicBool::default(),
            backend_closed: AtomicBool::default(),
            is_tls: AtomicBool::default(),
            has_keep_alive: AtomicBool::default(),
        }
    }

    /// Stamp `last_activity` with the current millisecond offset from
    /// `created_at`.
    pub fn touch(&self) {
        let offset_ms = self.created_at.elapsed().as_millis() as u64;
        self.last_activity.store(offset_ms, Ordering::Relaxed);
    }

    /// Count `n` inbound (client-side) bytes and refresh the activity stamp.
    pub fn record_bytes_in(&self, n: u64) {
        self.bytes_received.fetch_add(n, Ordering::Relaxed);
        self.touch();
    }

    /// Count `n` outbound (backend-to-client) bytes and refresh the
    /// activity stamp.
    pub fn record_bytes_out(&self, n: u64) {
        self.bytes_sent.fetch_add(n, Ordering::Relaxed);
        self.touch();
    }

    /// Time elapsed since the most recent activity — or since creation, if
    /// the record was never touched (`last_activity == 0`).
    pub fn idle_duration(&self) -> Duration {
        let total_ms = self.created_at.elapsed().as_millis() as u64;
        let active_ms = self.last_activity.load(Ordering::Relaxed);
        Duration::from_millis(total_ms.saturating_sub(active_ms))
    }

    /// Total lifetime of this connection so far.
    pub fn age(&self) -> Duration {
        Instant::now().duration_since(self.created_at)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A fresh record carries the given ID with zeroed counters and
    /// cleared flags.
    #[test]
    fn test_new_record() {
        let record = ConnectionRecord::new(42);
        assert_eq!(record.id, 42);
        assert_eq!(record.bytes_received.load(Ordering::Relaxed), 0);
        assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 0);
        assert!(!record.client_closed.load(Ordering::Relaxed));
        assert!(!record.backend_closed.load(Ordering::Relaxed));
        assert!(!record.is_tls.load(Ordering::Relaxed));
        assert!(!record.has_keep_alive.load(Ordering::Relaxed));
    }

    /// Inbound and outbound counters accumulate independently.
    #[test]
    fn test_record_bytes() {
        let record = ConnectionRecord::new(1);
        record.record_bytes_in(100);
        record.record_bytes_in(200);
        assert_eq!(record.bytes_received.load(Ordering::Relaxed), 300);
        record.record_bytes_out(50);
        record.record_bytes_out(75);
        assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 125);
    }

    /// `touch` stores the elapsed-since-creation offset in milliseconds.
    #[test]
    fn test_touch_updates_activity() {
        let record = ConnectionRecord::new(1);
        assert_eq!(record.last_activity.load(Ordering::Relaxed), 0);
        // Sleep briefly so elapsed time is nonzero
        thread::sleep(Duration::from_millis(10));
        record.touch();
        let activity = record.last_activity.load(Ordering::Relaxed);
        assert!(activity >= 10, "last_activity should be at least 10ms, got {}", activity);
    }

    /// Idle time equals age until the first touch, then resets to ~0.
    ///
    /// NOTE(review): the `< 10ms` bound below is timing-sensitive and could
    /// flake on a heavily loaded CI machine.
    #[test]
    fn test_idle_duration() {
        let record = ConnectionRecord::new(1);
        // Initially idle_duration ~ age since last_activity is 0
        thread::sleep(Duration::from_millis(20));
        let idle = record.idle_duration();
        assert!(idle >= Duration::from_millis(20));
        // After touch, idle should be near zero
        record.touch();
        let idle = record.idle_duration();
        assert!(idle < Duration::from_millis(10));
    }

    /// `age` grows monotonically with wall-clock time.
    #[test]
    fn test_age() {
        let record = ConnectionRecord::new(1);
        thread::sleep(Duration::from_millis(20));
        let age = record.age();
        assert!(age >= Duration::from_millis(20));
    }

    /// Flags are independent: setting three leaves the fourth untouched.
    #[test]
    fn test_flags() {
        let record = ConnectionRecord::new(1);
        record.client_closed.store(true, Ordering::Relaxed);
        record.is_tls.store(true, Ordering::Relaxed);
        record.has_keep_alive.store(true, Ordering::Relaxed);
        assert!(record.client_closed.load(Ordering::Relaxed));
        assert!(!record.backend_closed.load(Ordering::Relaxed));
        assert!(record.is_tls.load(Ordering::Relaxed));
        assert!(record.has_keep_alive.load(Ordering::Relaxed));
    }
}

View File

@@ -0,0 +1,402 @@
use dashmap::DashMap;
use std::collections::VecDeque;
use std::net::IpAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use super::connection_record::ConnectionRecord;
/// Idle threshold after which a half-closed non-TLS connection is a zombie.
const HALF_ZOMBIE_TIMEOUT_PLAIN: Duration = Duration::from_secs(30);
/// Idle threshold after which a half-closed TLS connection is a zombie.
const HALF_ZOMBIE_TIMEOUT_TLS: Duration = Duration::from_secs(300);
/// Stuck connection timeout (non-TLS): received data but never sent any.
const STUCK_TIMEOUT_PLAIN: Duration = Duration::from_secs(60);
/// Stuck connection timeout (TLS): received data but never sent any.
const STUCK_TIMEOUT_TLS: Duration = Duration::from_secs(300);

/// Tracks active connections per IP and enforces per-IP limits and rate limiting.
/// Also maintains per-connection records for zombie detection.
///
/// NOTE(review): entries in `timestamps` are pruned lazily, only when the same
/// IP calls `try_accept` again — an IP that connects once and never returns
/// leaves a small entry behind indefinitely. Consider a periodic sweep if the
/// listener is exposed to many distinct IPs.
pub struct ConnectionTracker {
    /// Active connection counts per IP.
    active: DashMap<IpAddr, AtomicU64>,
    /// Connection timestamps per IP for rate limiting (sliding 60s window).
    timestamps: DashMap<IpAddr, VecDeque<Instant>>,
    /// Maximum concurrent connections per IP (None = unlimited).
    max_per_ip: Option<u64>,
    /// Maximum new connections per minute per IP (None = unlimited).
    rate_limit_per_minute: Option<u64>,
    /// Per-connection tracking records for zombie detection.
    connections: DashMap<u64, Arc<ConnectionRecord>>,
    /// Monotonically increasing connection ID counter.
    next_id: AtomicU64,
}

impl ConnectionTracker {
    /// Create a tracker; `None` disables the corresponding limit.
    pub fn new(max_per_ip: Option<u64>, rate_limit_per_minute: Option<u64>) -> Self {
        Self {
            active: DashMap::new(),
            timestamps: DashMap::new(),
            max_per_ip,
            rate_limit_per_minute,
            connections: DashMap::new(),
            next_id: AtomicU64::new(1),
        }
    }

    /// Try to accept a new connection from the given IP.
    /// Returns true if allowed, false if over the concurrent-connection
    /// limit or the per-minute rate limit.
    ///
    /// NOTE(review): the concurrency check is a check, not a reservation —
    /// the caller must follow up with `connection_opened`, so two racing
    /// accepts can both pass. The rate-limit timestamp, by contrast, is
    /// recorded here under the entry lock.
    pub fn try_accept(&self, ip: &IpAddr) -> bool {
        // Check per-IP connection limit.
        if let Some(max) = self.max_per_ip {
            let count = self
                .active
                .get(ip)
                .map(|c| c.value().load(Ordering::Relaxed))
                .unwrap_or(0);
            if count >= max {
                return false;
            }
        }
        // Check sliding-window rate limit.
        if let Some(rate_limit) = self.rate_limit_per_minute {
            let now = Instant::now();
            let one_minute = Duration::from_secs(60);
            let mut entry = self.timestamps.entry(*ip).or_default();
            let timestamps = entry.value_mut();
            // Drop timestamps that have aged out of the 60-second window.
            while timestamps
                .front()
                .is_some_and(|t| now.duration_since(*t) >= one_minute)
            {
                timestamps.pop_front();
            }
            if timestamps.len() as u64 >= rate_limit {
                return false;
            }
            timestamps.push_back(now);
        }
        true
    }

    /// Record that a connection was opened from the given IP.
    pub fn connection_opened(&self, ip: &IpAddr) {
        self.active
            .entry(*ip)
            .or_insert_with(|| AtomicU64::new(0))
            .value()
            .fetch_add(1, Ordering::Relaxed);
    }

    /// Record that a connection was closed from the given IP.
    ///
    /// Fixed: decrements via `fetch_update` + `checked_sub` so that a close
    /// without a matching open (or a double close) is a no-op instead of
    /// wrapping the counter to `u64::MAX`, which would permanently trip the
    /// per-IP limit for that address.
    pub fn connection_closed(&self, ip: &IpAddr) {
        if let Some(counter) = self.active.get(ip) {
            let prev = counter
                .value()
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| c.checked_sub(1));
            // Ok(1): we just closed the last connection for this IP.
            // Err(_): the counter was already 0 (stale entry).
            // Clean up the map entry in both cases.
            if matches!(prev, Ok(1) | Err(_)) {
                // Release the shard guard before removing to avoid deadlock.
                drop(counter);
                self.active.remove(ip);
            }
        }
    }

    /// Get the current number of active connections for an IP.
    pub fn active_connections(&self, ip: &IpAddr) -> u64 {
        self.active
            .get(ip)
            .map(|c| c.value().load(Ordering::Relaxed))
            .unwrap_or(0)
    }

    /// Get the total number of tracked IPs.
    pub fn tracked_ips(&self) -> usize {
        self.active.len()
    }

    /// Register a new connection and return its tracking record.
    ///
    /// The returned `Arc<ConnectionRecord>` should be passed to the forwarding
    /// loop so it can update bytes / activity atomics in real time.
    pub fn register_connection(&self, is_tls: bool) -> Arc<ConnectionRecord> {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let record = Arc::new(ConnectionRecord::new(id));
        record.is_tls.store(is_tls, Ordering::Relaxed);
        self.connections.insert(id, Arc::clone(&record));
        record
    }

    /// Remove a connection record when the connection is fully closed.
    pub fn unregister_connection(&self, id: u64) {
        self.connections.remove(&id);
    }

    /// Scan all tracked connections and return IDs of zombie connections.
    ///
    /// A connection is considered a zombie in any of these cases:
    /// - **Full zombie**: both `client_closed` and `backend_closed` are true.
    /// - **Half zombie**: one side closed for longer than the threshold
    ///   (5 min for TLS, 30s for non-TLS).
    /// - **Stuck**: `bytes_received > 0` but `bytes_sent == 0` for longer
    ///   than the stuck threshold (5 min for TLS, 60s for non-TLS).
    pub fn scan_zombies(&self) -> Vec<u64> {
        let mut zombies = Vec::new();
        for entry in self.connections.iter() {
            let record = entry.value();
            let id = *entry.key();
            let is_tls = record.is_tls.load(Ordering::Relaxed);
            let client_closed = record.client_closed.load(Ordering::Relaxed);
            let backend_closed = record.backend_closed.load(Ordering::Relaxed);
            let idle = record.idle_duration();
            let bytes_in = record.bytes_received.load(Ordering::Relaxed);
            let bytes_out = record.bytes_sent.load(Ordering::Relaxed);
            // Full zombie: both sides closed.
            if client_closed && backend_closed {
                zombies.push(id);
                continue;
            }
            // Half zombie: one side closed for too long.
            let half_timeout = if is_tls {
                HALF_ZOMBIE_TIMEOUT_TLS
            } else {
                HALF_ZOMBIE_TIMEOUT_PLAIN
            };
            if (client_closed || backend_closed) && idle >= half_timeout {
                zombies.push(id);
                continue;
            }
            // Stuck: received data but never sent anything for too long.
            let stuck_timeout = if is_tls {
                STUCK_TIMEOUT_TLS
            } else {
                STUCK_TIMEOUT_PLAIN
            };
            if bytes_in > 0 && bytes_out == 0 && idle >= stuck_timeout {
                zombies.push(id);
            }
        }
        zombies
    }

    /// Start a background task that periodically scans for zombie connections.
    ///
    /// The scanner runs every 10 seconds and logs any zombies it finds.
    /// It stops when the provided `CancellationToken` is cancelled.
    pub fn start_zombie_scanner(self: &Arc<Self>, cancel: CancellationToken) {
        let tracker = Arc::clone(self);
        tokio::spawn(async move {
            let interval = Duration::from_secs(10);
            loop {
                tokio::select! {
                    _ = cancel.cancelled() => {
                        debug!("Zombie scanner shutting down");
                        break;
                    }
                    _ = tokio::time::sleep(interval) => {
                        let zombies = tracker.scan_zombies();
                        if !zombies.is_empty() {
                            warn!(
                                "Detected {} zombie connection(s): {:?}",
                                zombies.len(),
                                zombies
                            );
                        }
                    }
                }
            }
        });
    }

    /// Get the total number of tracked connections (with records).
    pub fn total_connections(&self) -> usize {
        self.connections.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Open/close bookkeeping without any limits configured.
    #[test]
    fn test_basic_tracking() {
        let tracker = ConnectionTracker::new(None, None);
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);
        tracker.connection_opened(&ip);
        assert_eq!(tracker.active_connections(&ip), 2);
        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 1);
        tracker.connection_closed(&ip);
        assert_eq!(tracker.active_connections(&ip), 0);
    }

    /// The per-IP concurrency cap applies independently to each IP.
    #[test]
    fn test_per_ip_limit() {
        let tracker = ConnectionTracker::new(Some(2), None);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);
        assert!(tracker.try_accept(&ip));
        tracker.connection_opened(&ip);
        // Third connection should be rejected
        assert!(!tracker.try_accept(&ip));
        // Different IP should still be allowed
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();
        assert!(tracker.try_accept(&ip2));
    }

    /// Each successful `try_accept` consumes one slot of the 60s window.
    #[test]
    fn test_rate_limit() {
        let tracker = ConnectionTracker::new(None, Some(3));
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        assert!(tracker.try_accept(&ip));
        assert!(tracker.try_accept(&ip));
        assert!(tracker.try_accept(&ip));
        // 4th attempt within the minute should be rejected
        assert!(!tracker.try_accept(&ip));
    }

    /// With both limits disabled, any number of connections is accepted.
    #[test]
    fn test_no_limits() {
        let tracker = ConnectionTracker::new(None, None);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        for _ in 0..1000 {
            assert!(tracker.try_accept(&ip));
            tracker.connection_opened(&ip);
        }
        assert_eq!(tracker.active_connections(&ip), 1000);
    }

    /// Closing the last connection for an IP removes its map entry.
    #[test]
    fn test_tracked_ips() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.tracked_ips(), 0);
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();
        tracker.connection_opened(&ip1);
        tracker.connection_opened(&ip2);
        assert_eq!(tracker.tracked_ips(), 2);
        tracker.connection_closed(&ip1);
        assert_eq!(tracker.tracked_ips(), 1);
    }

    /// Records get unique IDs, carry the TLS flag, and are removable.
    #[test]
    fn test_register_unregister_connection() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.total_connections(), 0);
        let record1 = tracker.register_connection(false);
        assert_eq!(tracker.total_connections(), 1);
        assert!(!record1.is_tls.load(Ordering::Relaxed));
        let record2 = tracker.register_connection(true);
        assert_eq!(tracker.total_connections(), 2);
        assert!(record2.is_tls.load(Ordering::Relaxed));
        // IDs should be unique
        assert_ne!(record1.id, record2.id);
        tracker.unregister_connection(record1.id);
        assert_eq!(tracker.total_connections(), 1);
        tracker.unregister_connection(record2.id);
        assert_eq!(tracker.total_connections(), 0);
    }

    /// Both sides closed => flagged as a zombie immediately.
    #[test]
    fn test_full_zombie_detection() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(false);
        // Not a zombie initially
        assert!(tracker.scan_zombies().is_empty());
        // Set both sides closed -> full zombie
        record.client_closed.store(true, Ordering::Relaxed);
        record.backend_closed.store(true, Ordering::Relaxed);
        let zombies = tracker.scan_zombies();
        assert_eq!(zombies.len(), 1);
        assert_eq!(zombies[0], record.id);
    }

    /// A single closed side only counts after the half-zombie timeout.
    #[test]
    fn test_half_zombie_not_triggered_immediately() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(false);
        record.touch(); // mark activity now
        // Only one side closed, but just now -> not a zombie yet
        record.client_closed.store(true, Ordering::Relaxed);
        assert!(tracker.scan_zombies().is_empty());
    }

    /// Receiving without sending only counts after the stuck timeout.
    #[test]
    fn test_stuck_connection_not_triggered_immediately() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(false);
        record.touch(); // mark activity now
        // Has received data but sent nothing -> but just started, not stuck yet
        record.bytes_received.store(1000, Ordering::Relaxed);
        assert!(tracker.scan_zombies().is_empty());
    }

    /// Unregistered connections disappear from the zombie scan.
    #[test]
    fn test_unregister_removes_from_zombie_scan() {
        let tracker = ConnectionTracker::new(None, None);
        let record = tracker.register_connection(false);
        let id = record.id;
        // Make it a full zombie
        record.client_closed.store(true, Ordering::Relaxed);
        record.backend_closed.store(true, Ordering::Relaxed);
        assert_eq!(tracker.scan_zombies().len(), 1);
        // Unregister should remove it
        tracker.unregister_connection(id);
        assert!(tracker.scan_zombies().is_empty());
    }

    /// `total_connections` mirrors register/unregister calls exactly.
    #[test]
    fn test_total_connections() {
        let tracker = ConnectionTracker::new(None, None);
        assert_eq!(tracker.total_connections(), 0);
        let r1 = tracker.register_connection(false);
        let r2 = tracker.register_connection(true);
        let r3 = tracker.register_connection(false);
        assert_eq!(tracker.total_connections(), 3);
        tracker.unregister_connection(r2.id);
        assert_eq!(tracker.total_connections(), 2);
        tracker.unregister_connection(r1.id);
        tracker.unregister_connection(r3.id);
        assert_eq!(tracker.total_connections(), 0);
    }
}

View File

@@ -0,0 +1,325 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;
use super::connection_record::ConnectionRecord;
/// Statistics for a forwarded connection.
///
/// The totals are atomics so a forwarding task can add to them while other
/// code reads concurrently; typically shared behind an `Arc` (see
/// `forward_bidirectional_with_stats`).
#[derive(Debug, Default)]
pub struct ForwardStats {
    /// Bytes copied client -> backend (first element of the forwarding result).
    pub bytes_in: AtomicU64,
    /// Bytes copied backend -> client (second element of the forwarding result).
    pub bytes_out: AtomicU64,
}
/// Perform bidirectional TCP forwarding between client and backend.
///
/// This is the core data path for passthrough connections. Any bytes the
/// caller already consumed from the client (e.g. a peeked ClientHello) can be
/// handed over via `initial_data`; they are written to the backend first and
/// counted toward the client->backend total.
///
/// Each direction is copied independently; on EOF of one direction the
/// corresponding write half is shut down (half-close) while the other
/// direction keeps running. Both copy futures are awaited via `join!`, so the
/// function returns only after both directions have finished.
///
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes.
///
/// NOTE(review): mid-stream copy errors are swallowed — `unwrap_or(0)` maps a
/// failed direction to a 0 byte count and the function still returns `Ok`.
/// Only a failure to write `initial_data` propagates as `Err`. Confirm this
/// best-effort accounting is intentional before relying on the totals.
pub async fn forward_bidirectional(
    mut client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
    }
    // Borrowed split: both halves live on this task's stack.
    let (mut client_read, mut client_write) = client.split();
    let (mut backend_read, mut backend_write) = backend.split();
    let client_to_backend = async {
        // 64 KiB copy buffer.
        let mut buf = vec![0u8; 65536];
        // Initial (replayed) bytes count toward the client->backend total.
        let mut total = initial_data.map_or(0u64, |d| d.len() as u64);
        loop {
            let n = client_read.read(&mut buf).await?;
            if n == 0 {
                break; // client EOF
            }
            backend_write.write_all(&buf[..n]).await?;
            total += n as u64;
        }
        // Half-close: signal EOF to the backend; the other direction
        // continues until the backend closes too.
        backend_write.shutdown().await?;
        Ok::<u64, std::io::Error>(total)
    };
    let backend_to_client = async {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = backend_read.read(&mut buf).await?;
            if n == 0 {
                break; // backend EOF
            }
            client_write.write_all(&buf[..n]).await?;
            total += n as u64;
        }
        client_write.shutdown().await?;
        Ok::<u64, std::io::Error>(total)
    };
    // Run both directions concurrently; completes when both are done.
    let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
    Ok((c2b.unwrap_or(0), b2c.unwrap_or(0)))
}
/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
///
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
pub async fn forward_bidirectional_with_timeouts(
client: TcpStream,
mut backend: TcpStream,
initial_data: Option<&[u8]>,
inactivity_timeout: std::time::Duration,
max_lifetime: std::time::Duration,
cancel: CancellationToken,
) -> std::io::Result<(u64, u64)> {
// Send initial data (peeked bytes) to backend
if let Some(data) = initial_data {
backend.write_all(data).await?;
}
let (mut client_read, mut client_write) = client.into_split();
let (mut backend_read, mut backend_write) = backend.into_split();
let last_activity = Arc::new(AtomicU64::new(0));
let start = std::time::Instant::now();
let la1 = Arc::clone(&last_activity);
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
let c2b = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = initial_len;
loop {
let n = match client_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if backend_write.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = backend_write.shutdown().await;
total
});
let la2 = Arc::clone(&last_activity);
let b2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match backend_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = client_write.shutdown().await;
total
});
// Watchdog: inactivity, max lifetime, and cancellation
let la_watch = Arc::clone(&last_activity);
let c2b_handle = c2b.abort_handle();
let b2c_handle = b2c.abort_handle();
let watchdog = tokio::spawn(async move {
let check_interval = std::time::Duration::from_secs(5);
let mut last_seen = 0u64;
loop {
tokio::select! {
_ = cancel.cancelled() => {
debug!("Connection cancelled by shutdown");
c2b_handle.abort();
b2c_handle.abort();
break;
}
_ = tokio::time::sleep(check_interval) => {
// Check max lifetime
if start.elapsed() >= max_lifetime {
debug!("Connection exceeded max lifetime, closing");
c2b_handle.abort();
b2c_handle.abort();
break;
}
// Check inactivity
let current = la_watch.load(Ordering::Relaxed);
if current == last_seen {
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
c2b_handle.abort();
b2c_handle.abort();
break;
}
}
last_seen = current;
}
}
}
});
let bytes_in = c2b.await.unwrap_or(0);
let bytes_out = b2c.await.unwrap_or(0);
watchdog.abort();
Ok((bytes_in, bytes_out))
}
/// Run bidirectional forwarding and fold the resulting byte totals into
/// `stats`.
pub async fn forward_bidirectional_with_stats(
    client: TcpStream,
    backend: TcpStream,
    initial_data: Option<&[u8]>,
    stats: Arc<ForwardStats>,
) -> std::io::Result<()> {
    let (from_client, from_backend) =
        forward_bidirectional(client, backend, initial_data).await?;
    stats.bytes_in.fetch_add(from_client, Ordering::Relaxed);
    stats.bytes_out.fetch_add(from_backend, Ordering::Relaxed);
    Ok(())
}
/// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts,
/// updating a `ConnectionRecord` with byte counts and activity timestamps
/// in real time for zombie detection.
///
/// When `record` is `None`, this behaves identically to
/// `forward_bidirectional_with_timeouts`.
///
/// The record's `client_closed` / `backend_closed` flags are set when the
/// respective copy loop terminates, giving the zombie scanner visibility
/// into half-open connections.
///
/// NOTE(review): when the watchdog aborts the copy tasks (timeout, lifetime,
/// or cancellation), `c2b.await` / `b2c.await` return `JoinError`, so the
/// byte totals come back as 0 via `unwrap_or(0)` and the record's closed
/// flags are never set for the aborted side — confirm downstream accounting
/// and the zombie scanner tolerate this.
pub async fn forward_bidirectional_with_record(
    client: TcpStream,
    mut backend: TcpStream,
    initial_data: Option<&[u8]>,
    inactivity_timeout: std::time::Duration,
    max_lifetime: std::time::Duration,
    cancel: CancellationToken,
    record: Option<Arc<ConnectionRecord>>,
) -> std::io::Result<(u64, u64)> {
    // Send initial data (peeked bytes) to backend
    if let Some(data) = initial_data {
        backend.write_all(data).await?;
        if let Some(ref r) = record {
            r.record_bytes_in(data.len() as u64);
        }
    }
    // Owned halves so each direction can run in its own spawned task.
    let (mut client_read, mut client_write) = client.into_split();
    let (mut backend_read, mut backend_write) = backend.into_split();
    // Shared "milliseconds since start" activity stamp read by the watchdog.
    let last_activity = Arc::new(AtomicU64::new(0));
    let start = std::time::Instant::now();
    let la1 = Arc::clone(&last_activity);
    let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
    let rec1 = record.clone();
    // client -> backend pump.
    let c2b = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        // The replayed initial bytes count toward this direction's total.
        let mut total = initial_len;
        loop {
            // EOF and read errors both terminate this direction silently.
            let n = match client_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if backend_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            let now_ms = start.elapsed().as_millis() as u64;
            la1.store(now_ms, Ordering::Relaxed);
            if let Some(ref r) = rec1 {
                r.record_bytes_in(n as u64);
            }
        }
        // Half-close toward the backend.
        let _ = backend_write.shutdown().await;
        // Mark client side as closed
        if let Some(ref r) = rec1 {
            r.client_closed.store(true, Ordering::Relaxed);
        }
        total
    });
    let la2 = Arc::clone(&last_activity);
    let rec2 = record.clone();
    // backend -> client pump.
    let b2c = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        let mut total = 0u64;
        loop {
            let n = match backend_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
            total += n as u64;
            let now_ms = start.elapsed().as_millis() as u64;
            la2.store(now_ms, Ordering::Relaxed);
            if let Some(ref r) = rec2 {
                r.record_bytes_out(n as u64);
            }
        }
        let _ = client_write.shutdown().await;
        // Mark backend side as closed
        if let Some(ref r) = rec2 {
            r.backend_closed.store(true, Ordering::Relaxed);
        }
        total
    });
    // Watchdog: inactivity, max lifetime, and cancellation
    let la_watch = Arc::clone(&last_activity);
    let c2b_handle = c2b.abort_handle();
    let b2c_handle = b2c.abort_handle();
    let watchdog = tokio::spawn(async move {
        // Poll every 5s; timeouts are therefore enforced with up to ~5s slack.
        let check_interval = std::time::Duration::from_secs(5);
        let mut last_seen = 0u64;
        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    debug!("Connection cancelled by shutdown");
                    c2b_handle.abort();
                    b2c_handle.abort();
                    break;
                }
                _ = tokio::time::sleep(check_interval) => {
                    // Check max lifetime
                    if start.elapsed() >= max_lifetime {
                        debug!("Connection exceeded max lifetime, closing");
                        c2b_handle.abort();
                        b2c_handle.abort();
                        break;
                    }
                    // Check inactivity. Only when the stamp did not move since
                    // the previous tick do we compare the idle span against
                    // the timeout; a moving stamp resets the watch.
                    let current = la_watch.load(Ordering::Relaxed);
                    if current == last_seen {
                        // `current` was stored from an earlier `start.elapsed()`,
                        // so this subtraction cannot underflow.
                        let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
                        if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
                            debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
                            c2b_handle.abort();
                            b2c_handle.abort();
                            break;
                        }
                    }
                    last_seen = current;
                }
            }
        }
    });
    // Wait for both pumps; aborted tasks surface as JoinError -> 0 bytes.
    let bytes_in = c2b.await.unwrap_or(0);
    let bytes_out = b2c.await.unwrap_or(0);
    watchdog.abort();
    Ok((bytes_in, bytes_out))
}

View File

@@ -0,0 +1,22 @@
//! # rustproxy-passthrough
//!
//! Raw TCP/SNI passthrough engine for RustProxy.
//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding.

pub mod tcp_listener;
// Manual ClientHello byte parsing to extract the SNI hostname.
pub mod sni_parser;
// Bidirectional copy loops with timeouts and zombie-record integration.
pub mod forwarder;
// PROXY protocol v1 header parsing and generation.
pub mod proxy_protocol;
pub mod tls_handler;
// Per-connection atomic tracking record used by zombie detection.
pub mod connection_record;
// Per-IP limits, rate limiting, and the zombie scanner.
pub mod connection_tracker;
pub mod socket_relay;

// Flat re-exports so callers can use `rustproxy_passthrough::X` directly.
// NOTE(review): glob re-exports from eight modules can silently collide if
// two modules ever export the same public name — consider explicit
// re-export lists.
pub use tcp_listener::*;
pub use sni_parser::*;
pub use forwarder::*;
pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_record::*;
pub use connection_tracker::*;
pub use socket_relay::*;

View File

@@ -0,0 +1,129 @@
use std::net::SocketAddr;
use thiserror::Error;
/// Errors produced while parsing a PROXY protocol header.
#[derive(Debug, Error)]
pub enum ProxyProtocolError {
    /// The data does not form a syntactically valid header line.
    #[error("Invalid PROXY protocol header")]
    InvalidHeader,
    /// The header announced a protocol family this parser does not handle.
    #[error("Unsupported PROXY protocol version")]
    UnsupportedVersion,
    /// A field (IP or port) failed to parse; the message names the field.
    #[error("Parse error: {0}")]
    Parse(String),
}
/// Parsed PROXY protocol v1 header.
#[derive(Debug, Clone)]
pub struct ProxyProtocolHeader {
    /// Original client address as announced by the upstream proxy.
    pub source_addr: SocketAddr,
    /// Address the client originally connected to.
    pub dest_addr: SocketAddr,
    /// Transport/address family announced in the header.
    pub protocol: ProxyProtocol,
}
/// Protocol in PROXY header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocol {
    /// TCP over IPv4 (`TCP4`).
    Tcp4,
    /// TCP over IPv6 (`TCP6`).
    Tcp6,
    /// Unknown or unsupported family (`UNKNOWN`).
    Unknown,
}
/// Parse a PROXY protocol v1 header from data.
///
/// Format: `PROXY TCP4 <src_ip> <dst_ip> <src_port> <dst_port>\r\n`
/// (or `TCP6` with IPv6 addresses).
///
/// Per the PROXY protocol spec a v1 line is at most 107 bytes including the
/// CRLF, and an `UNKNOWN` line may omit the addresses entirely
/// (`PROXY UNKNOWN\r\n`): the receiver must ignore anything between
/// `UNKNOWN` and the CRLF. Such lines now parse successfully (previously
/// they were rejected) and report unspecified addresses (0.0.0.0:0).
///
/// On success returns the parsed header and the number of bytes consumed
/// (line plus CRLF) so the caller can skip past it.
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
    // Spec: a v1 header is at most 107 bytes. Bounding the CRLF search also
    // prevents scanning an arbitrarily large buffer that contains no header.
    const MAX_V1_HEADER: usize = 107;
    let search = &data[..data.len().min(MAX_V1_HEADER)];
    // Find the end of the header line
    let line_end = search
        .windows(2)
        .position(|w| w == b"\r\n")
        .ok_or(ProxyProtocolError::InvalidHeader)?;
    let line = std::str::from_utf8(&data[..line_end])
        .map_err(|_| ProxyProtocolError::InvalidHeader)?;
    if !line.starts_with("PROXY ") {
        return Err(ProxyProtocolError::InvalidHeader);
    }
    let parts: Vec<&str> = line.split(' ').collect();
    // `PROXY UNKNOWN ...` carries no trustworthy addresses; the spec requires
    // ignoring everything after UNKNOWN up to the CRLF.
    if parts.len() >= 2 && parts[1] == "UNKNOWN" {
        let unspecified =
            SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0);
        let header = ProxyProtocolHeader {
            source_addr: unspecified,
            dest_addr: unspecified,
            protocol: ProxyProtocol::Unknown,
        };
        return Ok((header, line_end + 2));
    }
    if parts.len() != 6 {
        return Err(ProxyProtocolError::InvalidHeader);
    }
    let protocol = match parts[1] {
        "TCP4" => ProxyProtocol::Tcp4,
        "TCP6" => ProxyProtocol::Tcp6,
        _ => return Err(ProxyProtocolError::UnsupportedVersion),
    };
    let src_ip: std::net::IpAddr = parts[2]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
    let dst_ip: std::net::IpAddr = parts[3]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;
    let src_port: u16 = parts[4]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?;
    let dst_port: u16 = parts[5]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?;
    let header = ProxyProtocolHeader {
        source_addr: SocketAddr::new(src_ip, src_port),
        dest_addr: SocketAddr::new(dst_ip, dst_port),
        protocol,
    };
    // Consumed bytes = line + \r\n
    Ok((header, line_end + 2))
}
/// Generate a PROXY protocol v1 header string for the given address pair.
///
/// Emits `PROXY TCP4 ...` / `PROXY TCP6 ...` when both addresses belong to
/// the same family. The v1 text format cannot represent a mixed IPv4/IPv6
/// pair, so in that case the spec's `PROXY UNKNOWN\r\n` form is emitted
/// instead of a malformed line (previously the family tag was chosen from
/// `source` alone, which could pair `TCP4` with an IPv6 destination).
pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String {
    let proto = match (source.is_ipv4(), dest.is_ipv4()) {
        (true, true) => "TCP4",
        (false, false) => "TCP6",
        // Mixed families are unrepresentable in v1; fall back to UNKNOWN.
        _ => return "PROXY UNKNOWN\r\n".to_string(),
    };
    format!(
        "PROXY {} {} {} {} {}\r\n",
        proto,
        source.ip(),
        dest.ip(),
        source.port(),
        dest.port()
    )
}
/// Returns true when `data` begins with the PROXY protocol v1 signature
/// (`"PROXY "`, six bytes).
pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
    data.len() >= 6 && &data[..6] == b"PROXY "
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A canonical TCP4 line parses into its four address components and
    /// the whole line (including CRLF) is reported as consumed.
    #[test]
    fn test_parse_v1_tcp4() {
        let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n";
        let (parsed, consumed) = parse_v1(header).unwrap();
        assert_eq!(consumed, header.len());
        assert_eq!(parsed.protocol, ProxyProtocol::Tcp4);
        assert_eq!(parsed.source_addr.ip().to_string(), "192.168.1.100");
        assert_eq!(parsed.source_addr.port(), 12345);
        assert_eq!(parsed.dest_addr.ip().to_string(), "10.0.0.1");
        assert_eq!(parsed.dest_addr.port(), 443);
    }

    /// Generation emits the same canonical line format that parsing accepts.
    #[test]
    fn test_generate_v1() {
        let source: SocketAddr = "192.168.1.100:12345".parse().unwrap();
        let dest: SocketAddr = "10.0.0.1:443".parse().unwrap();
        let header = generate_v1(&source, &dest);
        assert_eq!(header, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n");
    }

    /// The signature check only inspects the leading `"PROXY "` bytes.
    #[test]
    fn test_is_proxy_protocol() {
        assert!(is_proxy_protocol_v1(b"PROXY TCP4 ..."));
        assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1"));
    }
}

View File

@@ -0,0 +1,287 @@
//! ClientHello SNI extraction via manual byte parsing.
//! No TLS stack needed - we just parse enough of the ClientHello to extract the SNI.
/// Result of SNI extraction.
#[derive(Debug)]
pub enum SniResult {
/// Successfully extracted SNI hostname.
Found(String),
/// TLS ClientHello detected but no SNI extension present.
NoSni,
/// Not a TLS ClientHello (plain HTTP or other protocol).
NotTls,
/// Need more data to determine.
NeedMoreData,
}
/// Extract the SNI hostname from a TLS ClientHello message.
///
/// This parses just enough of the TLS record to find the SNI extension,
/// without performing any actual TLS operations.
///
/// Returns `Found(host)` when an SNI entry is present, `NoSni` for a
/// complete ClientHello without one, `NotTls` when the bytes are not a
/// TLS handshake record, and `NeedMoreData` when the buffer is too short
/// to decide (the caller is expected to peek a larger buffer and retry).
pub fn extract_sni(data: &[u8]) -> SniResult {
    // Minimum TLS record header is 5 bytes
    if data.len() < 5 {
        return SniResult::NeedMoreData;
    }
    // Check for TLS record: content_type=22 (Handshake)
    if data[0] != 0x16 {
        return SniResult::NotTls;
    }
    // TLS version (major.minor) at data[1..3] - accept any
    // Record length
    let record_len = ((data[3] as usize) << 8) | (data[4] as usize);
    let _total_len = 5 + record_len;
    // We need at least the handshake header (5 TLS + 4 handshake = 9)
    if data.len() < 9 {
        return SniResult::NeedMoreData;
    }
    // Handshake type = 1 (ClientHello)
    if data[5] != 0x01 {
        return SniResult::NotTls;
    }
    // Handshake length (3 bytes) - informational, we parse incrementally
    let _handshake_len = ((data[6] as usize) << 16)
        | ((data[7] as usize) << 8)
        | (data[8] as usize);
    let hello = &data[9..];
    // ClientHello structure:
    //   2 bytes: client version
    //   32 bytes: random
    //   1 byte: session_id length + session_id
    let mut pos = 2 + 32; // skip version + random
    if pos >= hello.len() {
        return SniResult::NeedMoreData;
    }
    // Session ID
    let session_id_len = hello[pos] as usize;
    pos += 1 + session_id_len;
    if pos + 2 > hello.len() {
        return SniResult::NeedMoreData;
    }
    // Cipher suites
    let cipher_suites_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
    pos += 2 + cipher_suites_len;
    if pos + 1 > hello.len() {
        return SniResult::NeedMoreData;
    }
    // Compression methods
    let compression_len = hello[pos] as usize;
    pos += 1 + compression_len;
    if pos + 2 > hello.len() {
        // No extensions
        return SniResult::NoSni;
    }
    // Extensions length
    let extensions_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
    pos += 2;
    let extensions_end = pos + extensions_len;
    // FIX: remember whether the extensions block is truncated in the
    // buffer we were given. Previously a truncated buffer fell through to
    // `NoSni`, so a ClientHello whose SNI lay beyond the peeked bytes was
    // misreported as having no SNI and the caller never retried with a
    // larger buffer.
    let truncated = extensions_end > hello.len();
    // Parse extensions looking for SNI (type 0x0000)
    while pos + 4 <= hello.len() && pos < extensions_end {
        let ext_type = ((hello[pos] as u16) << 8) | (hello[pos + 1] as u16);
        let ext_len = ((hello[pos + 2] as usize) << 8) | (hello[pos + 3] as usize);
        pos += 4;
        if ext_type == 0x0000 {
            // SNI extension; clamp the slice to what we actually have so a
            // partial extension yields NeedMoreData from the sub-parser.
            return parse_sni_extension(&hello[pos..(pos + ext_len).min(hello.len())], ext_len);
        }
        pos += ext_len;
    }
    if truncated {
        // SNI may still be in the bytes we have not seen yet.
        SniResult::NeedMoreData
    } else {
        SniResult::NoSni
    }
}
/// Parse the SNI extension payload (the server_name_list) and pull out
/// the first hostname entry, lower-cased.
fn parse_sni_extension(data: &[u8], _ext_len: usize) -> SniResult {
    // Fixed header: list_len(2) + name_type(1) + name_len(2) = 5 bytes.
    if data.len() < 5 {
        return SniResult::NeedMoreData;
    }
    // The two-byte server_name_list length at data[0..2] is intentionally
    // ignored: only the first entry is ever inspected.
    // Entry type must be 0 (host_name).
    if data[2] != 0x00 {
        return SniResult::NoSni;
    }
    let name_len = u16::from_be_bytes([data[3], data[4]]) as usize;
    // Checked slice: a short buffer means the hostname is not fully here.
    let name_bytes = match data.get(5..5 + name_len) {
        Some(bytes) => bytes,
        None => return SniResult::NeedMoreData,
    };
    std::str::from_utf8(name_bytes)
        .map(|hostname| SniResult::Found(hostname.to_lowercase()))
        .unwrap_or(SniResult::NoSni)
}
/// Check if the initial bytes look like a TLS ClientHello.
pub fn is_tls(data: &[u8]) -> bool {
    // TLS record header: content_type 0x16 (handshake), version major 0x03,
    // and at least one further byte present.
    matches!(data, [0x16, 0x03, _, ..])
}
/// Check if the initial bytes look like HTTP.
pub fn is_http(data: &[u8]) -> bool {
    // Four-byte prefixes of common HTTP request methods
    // ("DELE"te, "PATC"h, "OPTI"ons, "CONN"ect, ...).
    const METHOD_PREFIXES: [&[u8; 4]; 8] = [
        b"GET ", b"POST", b"PUT ", b"HEAD", b"DELE", b"PATC", b"OPTI", b"CONN",
    ];
    match data.get(..4) {
        Some(head) => METHOD_PREFIXES.iter().any(|prefix| &prefix[..] == head),
        None => false, // fewer than 4 bytes cannot be classified
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Plain HTTP must be classified as NotTls (not merely NoSni).
    #[test]
    fn test_not_tls() {
        let http_data = b"GET / HTTP/1.1\r\n";
        assert!(matches!(extract_sni(http_data), SniResult::NotTls));
    }
    // A buffer shorter than the 5-byte TLS record header is undecidable.
    #[test]
    fn test_too_short() {
        assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
    }
    #[test]
    fn test_is_tls() {
        assert!(is_tls(&[0x16, 0x03, 0x01]));
        assert!(!is_tls(&[0x47, 0x45, 0x54])); // "GET"
    }
    #[test]
    fn test_is_http() {
        assert!(is_http(b"GET /"));
        assert!(is_http(b"POST /api"));
        assert!(!is_http(&[0x16, 0x03, 0x01]));
    }
    // End-to-end: a synthetic ClientHello round-trips through extract_sni.
    #[test]
    fn test_real_client_hello() {
        // A minimal TLS 1.2 ClientHello with SNI "example.com"
        let client_hello: Vec<u8> = build_test_client_hello("example.com");
        match extract_sni(&client_hello) {
            SniResult::Found(sni) => assert_eq!(sni, "example.com"),
            other => panic!("Expected Found, got {:?}", other),
        }
    }
    /// Build a minimal TLS ClientHello for testing.
    ///
    /// Assembles the message inside-out: SNI entry -> SNI extension ->
    /// extensions block -> ClientHello body -> handshake -> TLS record,
    /// computing each length prefix from the layer below.
    fn build_test_client_hello(hostname: &str) -> Vec<u8> {
        let hostname_bytes = hostname.as_bytes();
        // SNI extension
        let sni_ext_data = {
            let mut d = Vec::new();
            // Server name list length
            let name_entry_len = 3 + hostname_bytes.len(); // type(1) + len(2) + name
            d.push(((name_entry_len >> 8) & 0xFF) as u8);
            d.push((name_entry_len & 0xFF) as u8);
            // Host name type = 0
            d.push(0x00);
            // Host name length
            d.push(((hostname_bytes.len() >> 8) & 0xFF) as u8);
            d.push((hostname_bytes.len() & 0xFF) as u8);
            // Host name
            d.extend_from_slice(hostname_bytes);
            d
        };
        // Extension: type=0x0000 (SNI), length, data
        let sni_extension = {
            let mut e = Vec::new();
            e.push(0x00); e.push(0x00); // SNI type
            e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
            e.push((sni_ext_data.len() & 0xFF) as u8);
            e.extend_from_slice(&sni_ext_data);
            e
        };
        // Extensions block
        let extensions = {
            let mut ext = Vec::new();
            ext.push(((sni_extension.len() >> 8) & 0xFF) as u8);
            ext.push((sni_extension.len() & 0xFF) as u8);
            ext.extend_from_slice(&sni_extension);
            ext
        };
        // ClientHello body
        let hello_body = {
            let mut h = Vec::new();
            // Client version TLS 1.2
            h.push(0x03); h.push(0x03);
            // Random (32 bytes)
            h.extend_from_slice(&[0u8; 32]);
            // Session ID length = 0
            h.push(0x00);
            // Cipher suites: length=2, one suite
            h.push(0x00); h.push(0x02);
            h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
            // Compression methods: length=1, null
            h.push(0x01); h.push(0x00);
            // Extensions
            h.extend_from_slice(&extensions);
            h
        };
        // Handshake: type=1 (ClientHello), length
        let handshake = {
            let mut hs = Vec::new();
            hs.push(0x01); // ClientHello
            // 3-byte length
            hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
            hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
            hs.push((hello_body.len() & 0xFF) as u8);
            hs.extend_from_slice(&hello_body);
            hs
        };
        // TLS record: type=0x16, version TLS 1.0, length
        let mut record = Vec::new();
        record.push(0x16); // Handshake
        record.push(0x03); record.push(0x01); // TLS 1.0
        record.push(((handshake.len() >> 8) & 0xFF) as u8);
        record.push((handshake.len() & 0xFF) as u8);
        record.extend_from_slice(&handshake);
        record
    }
}

View File

@@ -0,0 +1,126 @@
//! Socket handler relay for connecting client connections to a TypeScript handler
//! via a Unix domain socket.
//!
//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
use tokio::net::UnixStream;
use tokio::io::{AsyncWriteExt, AsyncReadExt};
use tokio::net::TcpStream;
use serde::Serialize;
use tracing::debug;
/// Metadata header sent to the handler as a single JSON line before raw
/// byte relaying begins. Keys are serialized in camelCase to match the
/// JavaScript side's naming convention.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct RelayMetadata {
    // Connection ID assigned by the caller.
    connection_id: u64,
    // Client's remote IP address, already rendered as a string.
    remote_ip: String,
    // Client's remote TCP port.
    remote_port: u16,
    // Local listener port the client connected to.
    local_port: u16,
    // TLS SNI hostname, if one was extracted.
    sni: Option<String>,
    // Name of the matched route.
    route_name: String,
    // Initial bytes captured before relaying, base64-encoded (if any).
    initial_data_base64: Option<String>,
}
/// Relay a client connection to a TypeScript handler via Unix domain socket.
///
/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
///
/// `initial_data` (bytes already read from the client, if any) is embedded
/// base64-encoded in the metadata line rather than replayed on the wire.
///
/// # Errors
/// Fails if the Unix socket cannot be connected or the metadata line cannot
/// be written/serialized. Errors during the relay phase itself are treated
/// as end-of-stream for that direction and are not reported.
pub async fn relay_to_handler(
    client: TcpStream,
    relay_socket_path: &str,
    connection_id: u64,
    remote_ip: String,
    remote_port: u16,
    local_port: u16,
    sni: Option<String>,
    route_name: String,
    initial_data: Option<&[u8]>,
) -> std::io::Result<()> {
    debug!(
        "Relaying connection {} to handler socket {}",
        connection_id, relay_socket_path
    );
    // Connect to TypeScript handler Unix socket
    let mut handler = UnixStream::connect(relay_socket_path).await?;
    // Build and send metadata header
    let initial_data_base64 = initial_data.map(base64_encode);
    let metadata = RelayMetadata {
        connection_id,
        remote_ip,
        remote_port,
        local_port,
        sni,
        route_name,
        initial_data_base64,
    };
    let metadata_json = serde_json::to_string(&metadata)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
    handler.write_all(metadata_json.as_bytes()).await?;
    handler.write_all(b"\n").await?;
    // Bidirectional relay between client and handler
    let (mut client_read, mut client_write) = client.into_split();
    let (mut handler_read, mut handler_write) = handler.into_split();
    // client -> handler pump. EOF or any error ends this direction, then
    // the handler side is shut down so it observes EOF.
    let c2h = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        loop {
            let n = match client_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if handler_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
        }
        let _ = handler_write.shutdown().await;
    });
    // handler -> client pump, mirror of the above.
    let h2c = tokio::spawn(async move {
        let mut buf = vec![0u8; 65536];
        loop {
            let n = match handler_read.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            if client_write.write_all(&buf[..n]).await.is_err() {
                break;
            }
        }
        let _ = client_write.shutdown().await;
    });
    // Wait for both directions before declaring the relay finished.
    let _ = tokio::join!(c2h, h2c);
    debug!("Relay connection {} completed", connection_id);
    Ok(())
}
/// Simple base64 encoding (standard alphabet, `=` padding) without an
/// external dependency.
fn base64_encode(data: &[u8]) -> String {
    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Every (possibly partial) 3-byte chunk emits exactly 4 output chars,
    // so the final length is known up front; preallocating avoids repeated
    // reallocation as the string grows.
    let mut result = String::with_capacity((data.len() + 2) / 3 * 4);
    for chunk in data.chunks(3) {
        // Pack up to three bytes into one 24-bit group, zero-filling the tail.
        let b0 = chunk[0] as u32;
        let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
        let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
        let n = (b0 << 16) | (b1 << 8) | b2;
        // First two sextets are always present.
        result.push(CHARS[((n >> 18) & 0x3F) as usize] as char);
        result.push(CHARS[((n >> 12) & 0x3F) as usize] as char);
        // Third and fourth characters become '=' padding for short chunks.
        if chunk.len() > 1 {
            result.push(CHARS[((n >> 6) & 0x3F) as usize] as char);
        } else {
            result.push('=');
        }
        if chunk.len() > 2 {
            result.push(CHARS[(n & 0x3F) as usize] as char);
        } else {
            result.push('=');
        }
    }
    result
}

View File

@@ -0,0 +1,874 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{info, error, debug, warn};
use thiserror::Error;
use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;
use rustproxy_http::HttpProxyService;
use crate::sni_parser;
use crate::forwarder;
use crate::tls_handler;
use crate::connection_tracker::ConnectionTracker;
/// Errors produced by the TCP listener manager.
#[derive(Debug, Error)]
pub enum ListenerError {
    /// The OS refused the bind (e.g. address in use, permission denied).
    #[error("Failed to bind port {port}: {source}")]
    BindFailed { port: u16, source: std::io::Error },
    /// This manager already has an active listener on the port.
    #[error("Port {0} already bound")]
    AlreadyBound(u16),
    /// Any other I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
/// TLS configuration for a specific domain.
#[derive(Clone)]
pub struct TlsCertConfig {
    /// PEM-encoded certificate (chain) data.
    pub cert_pem: String,
    /// PEM-encoded private key data.
    pub key_pem: String,
}
/// Timeout and connection management configuration.
///
/// All durations are in milliseconds. The keep-alive fields modify how
/// `socket_timeout_ms` and `max_connection_lifetime_ms` are applied per
/// connection (see the timeout computation in the connection handler).
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    /// Timeout for establishing connection to backend (ms)
    pub connection_timeout_ms: u64,
    /// Timeout for initial data/SNI peek (ms)
    pub initial_data_timeout_ms: u64,
    /// Socket inactivity timeout (ms)
    pub socket_timeout_ms: u64,
    /// Maximum connection lifetime (ms)
    pub max_connection_lifetime_ms: u64,
    /// Graceful shutdown timeout (ms)
    pub graceful_shutdown_timeout_ms: u64,
    /// Maximum connections per IP (None = unlimited)
    pub max_connections_per_ip: Option<u64>,
    /// Connection rate limit per minute per IP (None = unlimited)
    pub connection_rate_limit_per_minute: Option<u64>,
    /// Keep-alive treatment (Standard / Extended / Immortal)
    pub keep_alive_treatment: Option<rustproxy_config::KeepAliveTreatment>,
    /// Inactivity multiplier for keep-alive connections
    /// (scales `socket_timeout_ms` in Extended mode)
    pub keep_alive_inactivity_multiplier: Option<f64>,
    /// Extended keep-alive lifetime (ms) for Extended treatment mode
    pub extended_keep_alive_lifetime_ms: Option<u64>,
    /// Whether to accept PROXY protocol (parse v1 header from clients)
    pub accept_proxy_protocol: bool,
    /// Whether to send PROXY protocol (emit v1 header to backends)
    pub send_proxy_protocol: bool,
}
impl Default for ConnectionConfig {
    /// Conservative defaults: no per-IP limits, no PROXY protocol on
    /// either side, standard keep-alive behaviour.
    fn default() -> Self {
        Self {
            connection_timeout_ms: 30_000,          // 30 s backend connect
            initial_data_timeout_ms: 60_000,        // 60 s first-bytes / SNI peek
            socket_timeout_ms: 3_600_000,           // 1 h inactivity
            max_connection_lifetime_ms: 86_400_000, // 24 h max lifetime
            graceful_shutdown_timeout_ms: 30_000,   // 30 s drain window
            max_connections_per_ip: None,
            connection_rate_limit_per_minute: None,
            keep_alive_treatment: None,
            keep_alive_inactivity_multiplier: None,
            extended_keep_alive_lifetime_ms: None,
            accept_proxy_protocol: false,
            send_proxy_protocol: false,
        }
    }
}
/// Manages TCP listeners for all configured ports.
///
/// Each port gets its own accept-loop task. Shared state (routes, metrics,
/// TLS certs, connection config/tracker) is handed to those tasks as `Arc`
/// clones at spawn time, so replacing a field on this struct only affects
/// listeners added afterwards.
pub struct TcpListenerManager {
    /// Active listeners indexed by port
    listeners: HashMap<u16, tokio::task::JoinHandle<()>>,
    /// Shared route manager
    route_manager: Arc<RouteManager>,
    /// Shared metrics collector
    metrics: Arc<MetricsCollector>,
    /// TLS acceptors indexed by domain (exact name or `*.wildcard`)
    tls_configs: Arc<HashMap<String, TlsCertConfig>>,
    /// HTTP proxy service for HTTP-level forwarding
    http_proxy: Arc<HttpProxyService>,
    /// Connection configuration
    conn_config: Arc<ConnectionConfig>,
    /// Connection tracker for per-IP limits
    conn_tracker: Arc<ConnectionTracker>,
    /// Cancellation token for graceful shutdown
    cancel_token: CancellationToken,
}
impl TcpListenerManager {
    /// Create a manager with a fresh metrics collector and default
    /// connection configuration.
    pub fn new(route_manager: Arc<RouteManager>) -> Self {
        let metrics = Arc::new(MetricsCollector::new());
        let http_proxy = Arc::new(HttpProxyService::new(
            Arc::clone(&route_manager),
            Arc::clone(&metrics),
        ));
        let conn_config = ConnectionConfig::default();
        let conn_tracker = Arc::new(ConnectionTracker::new(
            conn_config.max_connections_per_ip,
            conn_config.connection_rate_limit_per_minute,
        ));
        Self {
            listeners: HashMap::new(),
            route_manager,
            metrics,
            tls_configs: Arc::new(HashMap::new()),
            http_proxy,
            conn_config: Arc::new(conn_config),
            conn_tracker,
            cancel_token: CancellationToken::new(),
        }
    }
    /// Create with a metrics collector.
    /// Mirrors `new` but shares an externally-owned collector.
    pub fn with_metrics(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
        let http_proxy = Arc::new(HttpProxyService::new(
            Arc::clone(&route_manager),
            Arc::clone(&metrics),
        ));
        let conn_config = ConnectionConfig::default();
        let conn_tracker = Arc::new(ConnectionTracker::new(
            conn_config.max_connections_per_ip,
            conn_config.connection_rate_limit_per_minute,
        ));
        Self {
            listeners: HashMap::new(),
            route_manager,
            metrics,
            tls_configs: Arc::new(HashMap::new()),
            http_proxy,
            conn_config: Arc::new(conn_config),
            conn_tracker,
            cancel_token: CancellationToken::new(),
        }
    }
    /// Set connection configuration.
    ///
    /// Rebuilds the connection tracker so new per-IP / rate limits apply.
    /// NOTE(review): a fresh tracker presumably resets existing per-IP
    /// counts — confirm against ConnectionTracker semantics.
    pub fn set_connection_config(&mut self, config: ConnectionConfig) {
        self.conn_tracker = Arc::new(ConnectionTracker::new(
            config.max_connections_per_ip,
            config.connection_rate_limit_per_minute,
        ));
        self.conn_config = Arc::new(config);
    }
    /// Set TLS certificate configurations.
    ///
    /// Keys are domain names (exact or `*.wildcard`); lookup order is in
    /// `find_tls_config`. Listeners already running keep their previously
    /// cloned map; only ports added afterwards see the new one.
    pub fn set_tls_configs(&mut self, configs: HashMap<String, TlsCertConfig>) {
        self.tls_configs = Arc::new(configs);
    }
    /// Start listening on a port.
    ///
    /// Binds `0.0.0.0:<port>` and spawns a dedicated accept-loop task.
    ///
    /// # Errors
    /// `AlreadyBound` if this manager already listens on the port,
    /// `BindFailed` if the OS bind fails.
    pub async fn add_port(&mut self, port: u16) -> Result<(), ListenerError> {
        if self.listeners.contains_key(&port) {
            return Err(ListenerError::AlreadyBound(port));
        }
        let addr = format!("0.0.0.0:{}", port);
        let listener = TcpListener::bind(&addr).await.map_err(|e| {
            ListenerError::BindFailed { port, source: e }
        })?;
        info!("Listening on port {}", port);
        // Snapshot shared state for the accept-loop task.
        let route_manager = Arc::clone(&self.route_manager);
        let metrics = Arc::clone(&self.metrics);
        let tls_configs = Arc::clone(&self.tls_configs);
        let http_proxy = Arc::clone(&self.http_proxy);
        let conn_config = Arc::clone(&self.conn_config);
        let conn_tracker = Arc::clone(&self.conn_tracker);
        let cancel = self.cancel_token.clone();
        let handle = tokio::spawn(async move {
            Self::accept_loop(
                listener, port, route_manager, metrics, tls_configs,
                http_proxy, conn_config, conn_tracker, cancel,
            ).await;
        });
        self.listeners.insert(port, handle);
        Ok(())
    }
    /// Stop listening on a port.
    ///
    /// Aborts only the accept loop; per-connection tasks it spawned are
    /// detached and keep running. Returns whether a listener existed.
    pub fn remove_port(&mut self, port: u16) -> bool {
        if let Some(handle) = self.listeners.remove(&port) {
            handle.abort();
            info!("Stopped listening on port {}", port);
            true
        } else {
            false
        }
    }
    /// Get all currently listening ports, sorted ascending.
    pub fn listening_ports(&self) -> Vec<u16> {
        let mut ports: Vec<u16> = self.listeners.keys().copied().collect();
        ports.sort();
        ports
    }
    /// Stop all listeners gracefully.
    ///
    /// Signals cancellation and waits up to `graceful_shutdown_timeout_ms` for
    /// connections to drain, then aborts remaining tasks.
    ///
    /// NOTE(review): the awaited handles are the accept-loop tasks, which
    /// exit promptly once cancelled; per-connection tasks are spawned
    /// detached in `accept_loop` and are not awaited here — confirm the
    /// intended drain semantics.
    pub async fn graceful_stop(&mut self) {
        let timeout_ms = self.conn_config.graceful_shutdown_timeout_ms;
        info!("Initiating graceful shutdown (timeout: {}ms)", timeout_ms);
        // Signal all accept loops to stop accepting new connections
        self.cancel_token.cancel();
        // Wait for existing connections to drain
        let timeout = std::time::Duration::from_millis(timeout_ms);
        let deadline = tokio::time::Instant::now() + timeout;
        for (port, handle) in self.listeners.drain() {
            // Budget shrinks as earlier listeners consume wall-clock time.
            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
            if remaining.is_zero() {
                handle.abort();
                warn!("Force-stopped listener on port {} (timeout exceeded)", port);
            } else {
                match tokio::time::timeout(remaining, handle).await {
                    Ok(_) => info!("Listener on port {} stopped gracefully", port),
                    Err(_) => {
                        warn!("Listener on port {} did not stop in time, aborting", port);
                    }
                }
            }
        }
        // Reset cancellation token for potential restart
        self.cancel_token = CancellationToken::new();
        info!("Graceful shutdown complete");
    }
    /// Stop all listeners immediately (backward compatibility).
    pub fn stop_all(&mut self) {
        self.cancel_token.cancel();
        for (port, handle) in self.listeners.drain() {
            handle.abort();
            info!("Stopped listening on port {}", port);
        }
        // Fresh token so listeners can be re-added later.
        self.cancel_token = CancellationToken::new();
    }
    /// Update the route manager (for hot-reload).
    ///
    /// NOTE(review): accept loops hold their own `Arc` clone taken in
    /// `add_port`, so already-running listeners keep routing against the
    /// old manager; only ports added after this call use the new one.
    pub fn update_route_manager(&mut self, route_manager: Arc<RouteManager>) {
        self.route_manager = route_manager;
    }
    /// Get a reference to the metrics collector.
    pub fn metrics(&self) -> &Arc<MetricsCollector> {
        &self.metrics
    }
    /// Accept loop for a single port.
    ///
    /// Runs until `cancel` fires. Each accepted connection is checked
    /// against per-IP limits, then handled in its own spawned task.
    async fn accept_loop(
        listener: TcpListener,
        port: u16,
        route_manager: Arc<RouteManager>,
        metrics: Arc<MetricsCollector>,
        tls_configs: Arc<HashMap<String, TlsCertConfig>>,
        http_proxy: Arc<HttpProxyService>,
        conn_config: Arc<ConnectionConfig>,
        conn_tracker: Arc<ConnectionTracker>,
        cancel: CancellationToken,
    ) {
        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    info!("Accept loop on port {} shutting down", port);
                    break;
                }
                result = listener.accept() => {
                    match result {
                        Ok((stream, peer_addr)) => {
                            let ip = peer_addr.ip();
                            // Check per-IP limits and rate limiting
                            if !conn_tracker.try_accept(&ip) {
                                debug!("Rejected connection from {} (per-IP limit or rate limit)", peer_addr);
                                drop(stream);
                                continue;
                            }
                            conn_tracker.connection_opened(&ip);
                            let rm = Arc::clone(&route_manager);
                            let m = Arc::clone(&metrics);
                            let tc = Arc::clone(&tls_configs);
                            let hp = Arc::clone(&http_proxy);
                            let cc = Arc::clone(&conn_config);
                            let ct = Arc::clone(&conn_tracker);
                            let cn = cancel.clone();
                            debug!("Accepted connection from {} on port {}", peer_addr, port);
                            tokio::spawn(async move {
                                let result = Self::handle_connection(
                                    stream, port, peer_addr, rm, m, tc, hp, cc, cn,
                                ).await;
                                if let Err(e) = result {
                                    debug!("Connection error from {}: {}", peer_addr, e);
                                }
                                // Always balance the connection_opened above,
                                // whether the handler succeeded or failed.
                                ct.connection_closed(&ip);
                            });
                        }
                        Err(e) => {
                            error!("Accept error on port {}: {}", port, e);
                            // Brief pause avoids a hot error loop (e.g. fd exhaustion).
                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        }
                    }
                }
            }
        }
    }
    /// Handle a single incoming connection end-to-end: optional PROXY
    /// protocol parsing, protocol sniffing (TLS/HTTP), SNI extraction,
    /// route matching, security check, then forwarding according to the
    /// route's TLS mode.
    ///
    /// NOTE(review): several early `return Err(...)` paths below execute
    /// after `metrics.connection_opened` but bypass the final
    /// `metrics.connection_closed`, so the open-connection gauge can leak
    /// on error paths — confirm whether that is intended.
    async fn handle_connection(
        mut stream: tokio::net::TcpStream,
        port: u16,
        peer_addr: std::net::SocketAddr,
        route_manager: Arc<RouteManager>,
        metrics: Arc<MetricsCollector>,
        tls_configs: Arc<HashMap<String, TlsCertConfig>>,
        http_proxy: Arc<HttpProxyService>,
        conn_config: Arc<ConnectionConfig>,
        cancel: CancellationToken,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        use tokio::io::AsyncReadExt;
        stream.set_nodelay(true)?;
        // Handle PROXY protocol if configured
        let mut effective_peer_addr = peer_addr;
        if conn_config.accept_proxy_protocol {
            // Peek (do not consume) enough bytes for a v1 header line.
            let mut proxy_peek = vec![0u8; 256];
            let pn = match tokio::time::timeout(
                std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
                stream.peek(&mut proxy_peek),
            ).await {
                Ok(Ok(n)) => n,
                Ok(Err(e)) => return Err(e.into()),
                Err(_) => return Err("Initial data timeout (proxy protocol peek)".into()),
            };
            if pn > 0 && crate::proxy_protocol::is_proxy_protocol_v1(&proxy_peek[..pn]) {
                match crate::proxy_protocol::parse_v1(&proxy_peek[..pn]) {
                    Ok((header, consumed)) => {
                        debug!("PROXY protocol: real client {} -> {}", header.source_addr, header.dest_addr);
                        // From here on, use the address the proxy reported.
                        effective_peer_addr = header.source_addr;
                        // Consume the proxy protocol header bytes
                        let mut discard = vec![0u8; consumed];
                        stream.read_exact(&mut discard).await?;
                    }
                    Err(e) => {
                        debug!("Failed to parse PROXY protocol header: {}", e);
                        // Not a PROXY protocol header, continue normally
                    }
                }
            }
        }
        let peer_addr = effective_peer_addr;
        // Peek at initial bytes with timeout
        let mut peek_buf = vec![0u8; 4096];
        let n = match tokio::time::timeout(
            std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
            stream.peek(&mut peek_buf),
        ).await {
            Ok(Ok(n)) => n,
            Ok(Err(e)) => return Err(e.into()),
            Err(_) => return Err("Initial data timeout".into()),
        };
        let initial_data = &peek_buf[..n];
        // Determine connection type and extract SNI if TLS
        let is_tls = sni_parser::is_tls(initial_data);
        let is_http = sni_parser::is_http(initial_data);
        let domain = if is_tls {
            match sni_parser::extract_sni(initial_data) {
                sni_parser::SniResult::Found(sni) => Some(sni),
                sni_parser::SniResult::NoSni => None,
                sni_parser::SniResult::NeedMoreData => {
                    // Retry once with a larger peek window before giving up
                    // on SNI-based routing.
                    let mut bigger_buf = vec![0u8; 16384];
                    let n = match tokio::time::timeout(
                        std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
                        stream.peek(&mut bigger_buf),
                    ).await {
                        Ok(Ok(n)) => n,
                        Ok(Err(e)) => return Err(e.into()),
                        Err(_) => return Err("SNI data timeout".into()),
                    };
                    match sni_parser::extract_sni(&bigger_buf[..n]) {
                        sni_parser::SniResult::Found(sni) => Some(sni),
                        _ => None,
                    }
                }
                sni_parser::SniResult::NotTls => None,
            }
        } else {
            None
        };
        // Match route
        let ctx = rustproxy_routing::MatchContext {
            port,
            domain: domain.as_deref(),
            path: None,
            client_ip: Some(&peer_addr.ip().to_string()),
            tls_version: None,
            headers: None,
            is_tls,
        };
        let route_match = route_manager.find_route(&ctx);
        let route_match = match route_match {
            Some(rm) => rm,
            None => {
                // No route: drop the connection silently.
                debug!("No route matched for port {} domain {:?}", port, domain);
                return Ok(());
            }
        };
        let route_id = route_match.route.id.as_deref();
        // Check route-level IP security for passthrough connections
        if let Some(ref security) = route_match.route.security {
            if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
                security,
                &peer_addr.ip(),
            ) {
                debug!("Connection from {} blocked by route security", peer_addr);
                return Ok(());
            }
        }
        // Track connection in metrics
        metrics.connection_opened(route_id);
        let target = match route_match.target {
            Some(t) => t,
            None => {
                debug!("Route matched but no target available");
                metrics.connection_closed(route_id);
                return Ok(());
            }
        };
        let target_host = target.host.first().to_string();
        let target_port = target.port.resolve(port);
        let tls_mode = route_match.route.tls_mode();
        // Connection timeout for backend connections
        let connect_timeout = std::time::Duration::from_millis(conn_config.connection_timeout_ms);
        let base_inactivity_ms = conn_config.socket_timeout_ms;
        // Derive (inactivity, max lifetime) from the keep-alive treatment.
        let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() {
            Some(rustproxy_config::KeepAliveTreatment::Extended) => {
                let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0);
                let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms
                    .unwrap_or(7 * 24 * 3600 * 1000); // 7 days default
                (
                    std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64),
                    std::time::Duration::from_millis(extended_lifetime),
                )
            }
            Some(rustproxy_config::KeepAliveTreatment::Immortal) => {
                // Effectively unlimited lifetime (halved to avoid overflow
                // in deadline arithmetic).
                (
                    std::time::Duration::from_millis(base_inactivity_ms),
                    std::time::Duration::from_secs(u64::MAX / 2),
                )
            }
            _ => {
                // Standard
                (
                    std::time::Duration::from_millis(base_inactivity_ms),
                    std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
                )
            }
        };
        // Determine if we should send PROXY protocol to backend
        let should_send_proxy = conn_config.send_proxy_protocol
            || route_match.route.action.send_proxy_protocol.unwrap_or(false)
            || target.send_proxy_protocol.unwrap_or(false);
        // Generate PROXY protocol header if needed
        // NOTE(review): the header is only written in the Passthrough and
        // plain-TCP branches below; TLS Terminate paths ignore it — confirm
        // whether that is intentional.
        let proxy_header = if should_send_proxy {
            let dest = std::net::SocketAddr::new(
                target_host.parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)),
                target_port,
            );
            Some(crate::proxy_protocol::generate_v1(&peer_addr, &dest))
        } else {
            None
        };
        let result = match tls_mode {
            Some(rustproxy_config::TlsMode::Passthrough) => {
                // Raw TCP passthrough - connect to backend and forward
                let mut backend = match tokio::time::timeout(
                    connect_timeout,
                    tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
                ).await {
                    Ok(Ok(s)) => s,
                    Ok(Err(e)) => return Err(e.into()),
                    Err(_) => return Err("Backend connection timeout".into()),
                };
                backend.set_nodelay(true)?;
                // Send PROXY protocol header if configured
                if let Some(ref header) = proxy_header {
                    use tokio::io::AsyncWriteExt;
                    backend.write_all(header.as_bytes()).await?;
                }
                debug!(
                    "Passthrough: {} -> {}:{} (SNI: {:?})",
                    peer_addr, target_host, target_port, domain
                );
                // Drain the bytes we only peeked so far; the forwarder
                // replays them to the backend as the initial payload.
                let mut actual_buf = vec![0u8; n];
                stream.read_exact(&mut actual_buf).await?;
                let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
                    stream, backend, Some(&actual_buf),
                    inactivity_timeout, max_lifetime, cancel,
                ).await?;
                metrics.record_bytes(bytes_in, bytes_out, route_id);
                Ok(())
            }
            Some(rustproxy_config::TlsMode::Terminate) => {
                let tls_config = Self::find_tls_config(&domain, &tls_configs)?;
                // TLS accept with timeout, applying route-level TLS settings
                let route_tls = route_match.route.action.tls.as_ref();
                let acceptor = tls_handler::build_tls_acceptor_with_config(
                    &tls_config.cert_pem, &tls_config.key_pem, route_tls,
                )?;
                let tls_stream = match tokio::time::timeout(
                    std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
                    tls_handler::accept_tls(stream, &acceptor),
                ).await {
                    Ok(Ok(s)) => s,
                    Ok(Err(e)) => return Err(e),
                    Err(_) => return Err("TLS handshake timeout".into()),
                };
                // Peek at decrypted data to determine if HTTP
                let mut buf_stream = tokio::io::BufReader::new(tls_stream);
                let peeked = {
                    use tokio::io::AsyncBufReadExt;
                    match buf_stream.fill_buf().await {
                        Ok(data) => sni_parser::is_http(data),
                        Err(_) => false,
                    }
                };
                if peeked {
                    debug!(
                        "TLS Terminate + HTTP: {} -> {}:{} (domain: {:?})",
                        peer_addr, target_host, target_port, domain
                    );
                    // Hand the decrypted stream to the HTTP proxy; it does
                    // its own routing and byte accounting.
                    http_proxy.handle_io(buf_stream, peer_addr, port).await;
                } else {
                    debug!(
                        "TLS Terminate + TCP: {} -> {}:{} (domain: {:?})",
                        peer_addr, target_host, target_port, domain
                    );
                    // Raw TCP forwarding of decrypted stream
                    let backend = match tokio::time::timeout(
                        connect_timeout,
                        tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
                    ).await {
                        Ok(Ok(s)) => s,
                        Ok(Err(e)) => return Err(e.into()),
                        Err(_) => return Err("Backend connection timeout".into()),
                    };
                    backend.set_nodelay(true)?;
                    let (tls_read, tls_write) = tokio::io::split(buf_stream);
                    let (backend_read, backend_write) = tokio::io::split(backend);
                    let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
                        tls_read, tls_write, backend_read, backend_write,
                        inactivity_timeout, max_lifetime,
                    ).await;
                    metrics.record_bytes(bytes_in, bytes_out, route_id);
                }
                Ok(())
            }
            Some(rustproxy_config::TlsMode::TerminateAndReencrypt) => {
                let route_tls = route_match.route.action.tls.as_ref();
                Self::handle_tls_terminate_reencrypt(
                    stream, n, &domain, &target_host, target_port,
                    peer_addr, &tls_configs, &metrics, route_id, &conn_config, route_tls,
                ).await
            }
            None => {
                if is_http {
                    // Plain HTTP - use HTTP proxy for request-level routing
                    debug!("HTTP proxy: {} on port {}", peer_addr, port);
                    http_proxy.handle_connection(stream, peer_addr, port).await;
                    Ok(())
                } else {
                    // Plain TCP forwarding (non-HTTP)
                    let mut backend = match tokio::time::timeout(
                        connect_timeout,
                        tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
                    ).await {
                        Ok(Ok(s)) => s,
                        Ok(Err(e)) => return Err(e.into()),
                        Err(_) => return Err("Backend connection timeout".into()),
                    };
                    backend.set_nodelay(true)?;
                    // Send PROXY protocol header if configured
                    if let Some(ref header) = proxy_header {
                        use tokio::io::AsyncWriteExt;
                        backend.write_all(header.as_bytes()).await?;
                    }
                    debug!(
                        "Forward: {} -> {}:{}",
                        peer_addr, target_host, target_port
                    );
                    // Drain the peeked bytes so the forwarder can replay them.
                    let mut actual_buf = vec![0u8; n];
                    stream.read_exact(&mut actual_buf).await?;
                    let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
                        stream, backend, Some(&actual_buf),
                        inactivity_timeout, max_lifetime, cancel,
                    ).await?;
                    metrics.record_bytes(bytes_in, bytes_out, route_id);
                    Ok(())
                }
            }
        };
        metrics.connection_closed(route_id);
        result
    }
    /// Handle TLS terminate-and-reencrypt: accept TLS from client, connect TLS to backend.
    async fn handle_tls_terminate_reencrypt(
        stream: tokio::net::TcpStream,
        _peek_len: usize,
        domain: &Option<String>,
        target_host: &str,
        target_port: u16,
        peer_addr: std::net::SocketAddr,
        tls_configs: &HashMap<String, TlsCertConfig>,
        metrics: &MetricsCollector,
        route_id: Option<&str>,
        conn_config: &ConnectionConfig,
        route_tls: Option<&rustproxy_config::RouteTls>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let tls_config = Self::find_tls_config(domain, tls_configs)?;
        let acceptor = tls_handler::build_tls_acceptor_with_config(
            &tls_config.cert_pem, &tls_config.key_pem, route_tls,
        )?;
        // Accept TLS from client with timeout
        let client_tls = match tokio::time::timeout(
            std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
            tls_handler::accept_tls(stream, &acceptor),
        ).await {
            Ok(Ok(s)) => s,
            Ok(Err(e)) => return Err(e),
            Err(_) => return Err("TLS handshake timeout".into()),
        };
        debug!(
            "TLS Terminate+Reencrypt: {} -> {}:{} (domain: {:?})",
            peer_addr, target_host, target_port, domain
        );
        // Connect to backend over TLS with timeout
        let backend_tls = match tokio::time::timeout(
            std::time::Duration::from_millis(conn_config.connection_timeout_ms),
            tls_handler::connect_tls(target_host, target_port),
        ).await {
            Ok(Ok(s)) => s,
            Ok(Err(e)) => return Err(e),
            Err(_) => return Err("Backend TLS connection timeout".into()),
        };
        // Forward between two TLS streams
        let (client_read, client_write) = tokio::io::split(client_tls);
        let (backend_read, backend_write) = tokio::io::split(backend_tls);
        let base_inactivity_ms = conn_config.socket_timeout_ms;
        // Same keep-alive-aware timeout derivation as in handle_connection.
        let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() {
            Some(rustproxy_config::KeepAliveTreatment::Extended) => {
                let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0);
                let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms
                    .unwrap_or(7 * 24 * 3600 * 1000); // 7 days default
                (
                    std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64),
                    std::time::Duration::from_millis(extended_lifetime),
                )
            }
            Some(rustproxy_config::KeepAliveTreatment::Immortal) => {
                (
                    std::time::Duration::from_millis(base_inactivity_ms),
                    std::time::Duration::from_secs(u64::MAX / 2),
                )
            }
            _ => {
                // Standard
                (
                    std::time::Duration::from_millis(base_inactivity_ms),
                    std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
                )
            }
        };
        let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
            client_read, client_write, backend_read, backend_write,
            inactivity_timeout, max_lifetime,
        ).await;
        metrics.record_bytes(bytes_in, bytes_out, route_id);
        Ok(())
    }
    /// Find the TLS config for a given domain.
    ///
    /// Lookup order: exact domain, first-label wildcard (`*.rest`),
    /// explicit default entry `"*"`, then any available cert as a last
    /// resort (HashMap iteration order — effectively arbitrary).
    fn find_tls_config<'a>(
        domain: &Option<String>,
        tls_configs: &'a HashMap<String, TlsCertConfig>,
    ) -> Result<&'a TlsCertConfig, Box<dyn std::error::Error + Send + Sync>> {
        if let Some(domain) = domain {
            // Try exact match
            if let Some(config) = tls_configs.get(domain) {
                return Ok(config);
            }
            // Try wildcard
            if let Some(dot_pos) = domain.find('.') {
                let wildcard = format!("*.{}", &domain[dot_pos + 1..]);
                if let Some(config) = tls_configs.get(&wildcard) {
                    return Ok(config);
                }
            }
        }
        // Try default/fallback cert
        if let Some(config) = tls_configs.get("*") {
            return Ok(config);
        }
        // Try first available cert
        if let Some((_key, config)) = tls_configs.iter().next() {
            return Ok(config);
        }
        Err("No TLS certificate available for this domain".into())
    }
    /// Forward bidirectional between two split streams with inactivity and lifetime timeouts.
    ///
    /// Returns `(bytes client->backend, bytes backend->client)`. A watchdog
    /// task aborts both pump tasks when the connection exceeds its maximum
    /// lifetime or has seen no traffic for the inactivity window.
    async fn forward_bidirectional_split_with_timeouts<R1, W1, R2, W2>(
        mut client_read: R1,
        mut client_write: W1,
        mut backend_read: R2,
        mut backend_write: W2,
        inactivity_timeout: std::time::Duration,
        max_lifetime: std::time::Duration,
    ) -> (u64, u64)
    where
        R1: tokio::io::AsyncRead + Unpin + Send + 'static,
        W1: tokio::io::AsyncWrite + Unpin + Send + 'static,
        R2: tokio::io::AsyncRead + Unpin + Send + 'static,
        W2: tokio::io::AsyncWrite + Unpin + Send + 'static,
    {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU64, Ordering};
        // Milliseconds since `start` when either direction last moved data;
        // shared lock-free with the watchdog below.
        let last_activity = Arc::new(AtomicU64::new(0));
        let start = std::time::Instant::now();
        let la1 = Arc::clone(&last_activity);
        // client -> backend pump.
        let c2b = tokio::spawn(async move {
            let mut buf = vec![0u8; 65536];
            let mut total = 0u64;
            loop {
                let n = match client_read.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => n,
                };
                if backend_write.write_all(&buf[..n]).await.is_err() {
                    break;
                }
                total += n as u64;
                la1.store(
                    start.elapsed().as_millis() as u64,
                    Ordering::Relaxed,
                );
            }
            // Half-close so the backend observes EOF.
            let _ = backend_write.shutdown().await;
            total
        });
        let la2 = Arc::clone(&last_activity);
        // backend -> client pump, mirror of the above.
        let b2c = tokio::spawn(async move {
            let mut buf = vec![0u8; 65536];
            let mut total = 0u64;
            loop {
                let n = match backend_read.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => n,
                };
                if client_write.write_all(&buf[..n]).await.is_err() {
                    break;
                }
                total += n as u64;
                la2.store(
                    start.elapsed().as_millis() as u64,
                    Ordering::Relaxed,
                );
            }
            let _ = client_write.shutdown().await;
            total
        });
        // Watchdog task: check for inactivity and max lifetime
        let la_watch = Arc::clone(&last_activity);
        let c2b_handle = c2b.abort_handle();
        let b2c_handle = b2c.abort_handle();
        let watchdog = tokio::spawn(async move {
            let check_interval = std::time::Duration::from_secs(5);
            let mut last_seen = 0u64;
            loop {
                tokio::time::sleep(check_interval).await;
                // Check max lifetime
                if start.elapsed() >= max_lifetime {
                    debug!("Connection exceeded max lifetime, closing");
                    c2b_handle.abort();
                    b2c_handle.abort();
                    break;
                }
                // Check inactivity
                let current = la_watch.load(Ordering::Relaxed);
                if current == last_seen {
                    // No activity since last check
                    let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
                    if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
                        debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
                        c2b_handle.abort();
                        b2c_handle.abort();
                        break;
                    }
                }
                last_seen = current;
            }
        });
        // A watchdog abort surfaces here as a JoinError; count 0 bytes then.
        let bytes_in = c2b.await.unwrap_or(0);
        let bytes_out = b2c.await.unwrap_or(0);
        watchdog.abort();
        (bytes_in, bytes_out)
    }
}

View File

@@ -0,0 +1,190 @@
use std::io::BufReader;
use std::sync::Arc;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
use tracing::debug;
/// Ensure the default crypto provider is installed.
///
/// Installing ring as the process-wide provider fails when another provider
/// was already registered; that outcome is harmless, so it is ignored.
fn ensure_crypto_provider() {
    if rustls::crypto::ring::default_provider().install_default().is_err() {
        // A provider was installed earlier — nothing more to do.
    }
}
/// Build a TLS acceptor from PEM-encoded cert and key data.
///
/// Convenience wrapper around `build_tls_acceptor_with_config` with no
/// route-level TLS tuning applied.
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    build_tls_acceptor_with_config(cert_pem, key_pem, None)
}
/// Build a TLS acceptor with optional RouteTls configuration for version/cipher tuning.
///
/// * `cert_pem` / `key_pem` — PEM-encoded certificate chain and private key.
/// * `tls_config` — optional per-route TLS settings; when present, restricts
///   the offered protocol versions and enables in-memory session resumption.
///
/// Returns an error when the PEM data is unparseable, contains no
/// certificate/key, or the cert/key pair is rejected by rustls.
pub fn build_tls_acceptor_with_config(
    cert_pem: &str,
    key_pem: &str,
    tls_config: Option<&rustproxy_config::RouteTls>,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    ensure_crypto_provider();
    let certs = load_certs(cert_pem)?;
    let key = load_private_key(key_pem)?;
    let mut config = if let Some(route_tls) = tls_config {
        // Apply TLS version restrictions requested by the route; unknown or
        // empty version lists fall back to TLS 1.2 + 1.3 in the resolver.
        let versions = resolve_tls_versions(route_tls.versions.as_deref());
        let builder = ServerConfig::builder_with_protocol_versions(&versions);
        builder
            .with_no_client_auth()
            .with_single_cert(certs, key)?
    } else {
        ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)?
    };
    // A configured `session_timeout` enables in-memory session resumption.
    // rustls' ServerSessionMemoryCache has no per-entry TTL, so the requested
    // duration itself cannot be enforced — we only size the cache and log the
    // actual behavior instead of claiming a timeout was applied.
    if let Some(route_tls) = tls_config {
        if let Some(timeout_secs) = route_tls.session_timeout {
            config.session_storage = rustls::server::ServerSessionMemoryCache::new(
                256, // max sessions
            );
            debug!(
                "TLS session cache enabled (256 entries); requested timeout {}s is not enforced by rustls",
                timeout_secs
            );
        }
    }
    Ok(TlsAcceptor::from(Arc::new(config)))
}
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
///
/// Accepts common spellings ("TLSv1.2", "TLS1.2", "1.2", "TLSv12" and the
/// 1.3 equivalents), deduplicates, and logs-and-skips unknown entries.
/// `None`, an empty list, or a list with no recognized entries all resolve
/// to the default pair: TLS 1.2 and TLS 1.3.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
    let default_pair = || vec![&rustls::version::TLS12, &rustls::version::TLS13];
    let requested = match versions {
        Some(list) if !list.is_empty() => list,
        _ => return default_pair(),
    };
    let mut selected: Vec<&'static rustls::SupportedProtocolVersion> = Vec::with_capacity(2);
    for name in requested {
        // Map each spelling to its protocol version; skip unknown strings.
        let version = match name.as_str() {
            "TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => &rustls::version::TLS12,
            "TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => &rustls::version::TLS13,
            other => {
                debug!("Unknown TLS version '{}', ignoring", other);
                continue;
            }
        };
        if !selected.contains(&version) {
            selected.push(version);
        }
    }
    if selected.is_empty() {
        // Fallback to both if no valid versions specified
        default_pair()
    } else {
        selected
    }
}
/// Accept a TLS connection from a client stream.
///
/// Drives the server-side handshake to completion with the given acceptor
/// and returns the encrypted stream. Handshake failures (protocol errors,
/// client abort, no usable certificate) surface as the boxed error.
pub async fn accept_tls(
    stream: TcpStream,
    acceptor: &TlsAcceptor,
) -> Result<ServerTlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
    let tls_stream = acceptor.accept(stream).await?;
    debug!("TLS handshake completed");
    Ok(tls_stream)
}
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
///
/// Opens a TCP connection to `host:port` (with Nagle disabled) and performs
/// a TLS client handshake on top of it, using `host` as the SNI name.
///
/// SECURITY NOTE: certificate verification is disabled via `InsecureVerifier`
/// because internal backends commonly present self-signed certificates; do
/// not use this for connections that traverse untrusted networks.
pub async fn connect_tls(
    host: &str,
    port: u16,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
    ensure_crypto_provider();
    // Establish the TCP leg first so connection failures surface before any
    // TLS state is built.
    let tcp = TcpStream::connect(format!("{}:{}", host, port)).await?;
    tcp.set_nodelay(true)?;
    let client_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(InsecureVerifier))
        .with_no_client_auth();
    let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
    let tls_stream = TlsConnector::from(Arc::new(client_config))
        .connect(server_name, tcp)
        .await?;
    debug!("Backend TLS connection established to {}:{}", host, port);
    Ok(tls_stream)
}
/// Load certificates from PEM string.
///
/// Parses every certificate block in the PEM data, propagating the first
/// parse error; an input with zero certificates is rejected.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
    let mut reader = BufReader::new(pem.as_bytes());
    let mut parsed = Vec::new();
    for item in rustls_pemfile::certs(&mut reader) {
        parsed.push(item?);
    }
    if parsed.is_empty() {
        return Err("No certificates found in PEM data".into());
    }
    Ok(parsed)
}
/// Load private key from PEM string.
///
/// Delegates format detection to `rustls_pemfile::private_key`, which
/// handles PKCS#8, RSA, and EC (SEC1) encodings.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
    let mut reader = BufReader::new(pem.as_bytes());
    match rustls_pemfile::private_key(&mut reader)? {
        Some(key) => Ok(key),
        None => Err("No private key found in PEM data".into()),
    }
}
/// Insecure certificate verifier for backend connections (terminate-and-reencrypt).
/// In internal networks, backends may use self-signed certs.
///
/// SECURITY: every verification method unconditionally returns success, so
/// ANY certificate and signature is accepted. This must only be used for
/// trusted internal backend hops, never across an untrusted network.
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
    // Accept the presented certificate chain without any inspection.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }
    // Accept any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Accept any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Advertise a broad set of mainstream schemes (RSA PKCS#1/PSS, ECDSA,
    // Ed25519) so backends with any common key type can negotiate.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
        ]
    }
}