feat(rustproxy): add protocol-based routing and backend TLS re-encryption support

This commit is contained in:
2026-02-16 12:02:36 +00:00
parent db932e8acc
commit f0b7c27996
11 changed files with 536 additions and 73 deletions

View File

@@ -18,6 +18,8 @@ http-body = { workspace = true }
http-body-util = { workspace = true }
bytes = { workspace = true }
tokio = { workspace = true }
rustls = { workspace = true }
tokio-rustls = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }

View File

@@ -18,6 +18,9 @@ use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use std::pin::Pin;
use std::task::{Context, Poll};
use rustproxy_routing::RouteManager;
use rustproxy_metrics::MetricsCollector;
@@ -35,6 +38,125 @@ const DEFAULT_WS_INACTIVITY_TIMEOUT: std::time::Duration = std::time::Duration::
/// Default WebSocket max lifetime (24 hours).
const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400);
/// Backend stream that can be either plain TCP or TLS-wrapped.
/// Used for `terminate-and-reencrypt` mode where the backend requires TLS.
pub(crate) enum BackendStream {
Plain(TcpStream),
Tls(tokio_rustls::client::TlsStream<TcpStream>),
}
impl tokio::io::AsyncRead for BackendStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
BackendStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
BackendStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl tokio::io::AsyncWrite for BackendStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
match self.get_mut() {
BackendStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
BackendStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
BackendStream::Plain(s) => Pin::new(s).poll_flush(cx),
BackendStream::Tls(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
BackendStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
BackendStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
/// Connect to a backend over TLS. Uses InsecureVerifier for internal backends
/// with self-signed certs (same pattern as tls_handler::connect_tls).
async fn connect_tls_backend(
host: &str,
port: u16,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
let _ = rustls::crypto::ring::default_provider().install_default();
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureBackendVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
stream.set_nodelay(true)?;
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
let tls_stream = connector.connect(server_name, stream).await?;
debug!("Backend TLS connection established to {}:{}", host, port);
Ok(tls_stream)
}
/// Insecure certificate verifier for backend TLS connections.
/// Internal backends may use self-signed certs.
#[derive(Debug)]
struct InsecureBackendVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureBackendVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}
/// HTTP proxy service that processes HTTP traffic.
pub struct HttpProxyService {
route_manager: Arc<RouteManager>,
@@ -173,6 +295,7 @@ impl HttpProxyService {
tls_version: None,
headers: Some(&headers),
is_tls: false,
protocol: Some("http"),
};
let route_match = match self.route_manager.find_route(&ctx) {
@@ -273,28 +396,51 @@ impl HttpProxyService {
}
}
// Connect to upstream with timeout
let upstream_stream = match tokio::time::timeout(
self.connect_timeout,
TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)),
).await {
Ok(Ok(s)) => s,
Ok(Err(e)) => {
error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(&upstream_key);
self.metrics.connection_closed(route_id, Some(&ip_str));
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
// Connect to upstream with timeout (TLS if upstream.use_tls is set)
let backend = if upstream.use_tls {
match tokio::time::timeout(
self.connect_timeout,
connect_tls_backend(&upstream.host, upstream.port),
).await {
Ok(Ok(tls)) => BackendStream::Tls(tls),
Ok(Err(e)) => {
error!("Failed TLS connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(&upstream_key);
self.metrics.connection_closed(route_id, Some(&ip_str));
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable"));
}
Err(_) => {
error!("Upstream TLS connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(&upstream_key);
self.metrics.connection_closed(route_id, Some(&ip_str));
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout"));
}
}
Err(_) => {
error!("Upstream connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(&upstream_key);
self.metrics.connection_closed(route_id, Some(&ip_str));
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
} else {
match tokio::time::timeout(
self.connect_timeout,
TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)),
).await {
Ok(Ok(s)) => {
s.set_nodelay(true).ok();
BackendStream::Plain(s)
}
Ok(Err(e)) => {
error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(&upstream_key);
self.metrics.connection_closed(route_id, Some(&ip_str));
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
}
Err(_) => {
error!("Upstream connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(&upstream_key);
self.metrics.connection_closed(route_id, Some(&ip_str));
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
}
}
};
upstream_stream.set_nodelay(true).ok();
let io = TokioIo::new(upstream_stream);
let io = TokioIo::new(backend);
let result = if use_h2 {
// HTTP/2 backend
@@ -310,7 +456,7 @@ impl HttpProxyService {
/// Forward request to backend via HTTP/1.1 with body streaming.
async fn forward_h1(
&self,
io: TokioIo<TcpStream>,
io: TokioIo<BackendStream>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
@@ -376,7 +522,7 @@ impl HttpProxyService {
/// Forward request to backend via HTTP/2 with body streaming.
async fn forward_h2(
&self,
io: TokioIo<TcpStream>,
io: TokioIo<BackendStream>,
parts: hyper::http::request::Parts,
body: Incoming,
upstream_headers: hyper::HeaderMap,
@@ -516,26 +662,49 @@ impl HttpProxyService {
info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);
// Connect to upstream with timeout
let mut upstream_stream = match tokio::time::timeout(
self.connect_timeout,
TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)),
).await {
Ok(Ok(s)) => s,
Ok(Err(e)) => {
error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id, Some(source_ip));
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
// Connect to upstream with timeout (TLS if upstream.use_tls is set)
let mut upstream_stream: BackendStream = if upstream.use_tls {
match tokio::time::timeout(
self.connect_timeout,
connect_tls_backend(&upstream.host, upstream.port),
).await {
Ok(Ok(tls)) => BackendStream::Tls(tls),
Ok(Err(e)) => {
error!("WebSocket: failed TLS connect upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id, Some(source_ip));
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable"));
}
Err(_) => {
error!("WebSocket: upstream TLS connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id, Some(source_ip));
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout"));
}
}
Err(_) => {
error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id, Some(source_ip));
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
} else {
match tokio::time::timeout(
self.connect_timeout,
TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)),
).await {
Ok(Ok(s)) => {
s.set_nodelay(true).ok();
BackendStream::Plain(s)
}
Ok(Err(e)) => {
error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id, Some(source_ip));
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
}
Err(_) => {
error!("WebSocket: upstream connect timeout for {}:{}", upstream.host, upstream.port);
self.upstream_selector.connection_ended(upstream_key);
self.metrics.connection_closed(route_id, Some(source_ip));
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
}
}
};
upstream_stream.set_nodelay(true).ok();
let path = req.uri().path().to_string();
let upstream_path = {