feat(rustproxy): add authenticated VPN route security
This commit is contained in:
@@ -114,6 +114,43 @@ pub enum IpAllowEntry {
|
||||
DomainScoped { ip: String, domains: Vec<String> },
|
||||
}
|
||||
|
||||
/// Authenticated VPN metadata received from trusted PROXY protocol TLVs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct VpnConnectionInfo {
|
||||
pub client_id: String,
|
||||
pub assigned_ip: String,
|
||||
pub transport_type: Option<String>,
|
||||
pub remote_addr: Option<String>,
|
||||
}
|
||||
|
||||
/// A VPN client allow entry: full-route client ID or domain-scoped client ID.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum VpnClientAllowEntry {
|
||||
Plain(String),
|
||||
DomainScoped {
|
||||
#[serde(rename = "clientId")]
|
||||
client_id: String,
|
||||
domains: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// VPN-specific route access control.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteVpnSecurity {
|
||||
/// Require authenticated VPN metadata for this route.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
/// Allowed VPN client IDs. Empty/None means any authenticated VPN client when required=true.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_clients: Option<Vec<VpnClientAllowEntry>>,
|
||||
/// Allowed VPN assigned tunnel IPs. Mainly for compatibility; client IDs are preferred.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_assigned_ips: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Security options for routes.
|
||||
/// Matches TypeScript: `IRouteSecurity`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -142,4 +179,7 @@ pub struct RouteSecurity {
|
||||
/// JWT auth
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwt_auth: Option<JwtAuthConfig>,
|
||||
/// Authenticated VPN client requirement/allow list.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub vpn: Option<RouteVpnSecurity>,
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ async fn handle_h3_request(
|
||||
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
|
||||
let conn_activity = ConnActivity::new_standalone();
|
||||
let response = http_proxy
|
||||
.handle_request(req, peer_addr, port, cancel, conn_activity)
|
||||
.handle_request(req, peer_addr, port, cancel, conn_activity, None)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_config::VpnConnectionInfo;
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_security::RateLimiter;
|
||||
|
||||
@@ -461,6 +462,20 @@ impl HttpProxyService {
|
||||
cancel: CancellationToken,
|
||||
) where
|
||||
I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
self.handle_io_with_vpn(stream, peer_addr, port, cancel, None).await;
|
||||
}
|
||||
|
||||
/// Handle an incoming HTTP connection with optional authenticated VPN metadata.
|
||||
pub async fn handle_io_with_vpn<I>(
|
||||
self: Arc<Self>,
|
||||
stream: I,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
cancel: CancellationToken,
|
||||
vpn_info: Option<VpnConnectionInfo>,
|
||||
) where
|
||||
I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
@@ -484,6 +499,7 @@ impl HttpProxyService {
|
||||
let la_inner = Arc::clone(&last_activity);
|
||||
let ar_inner = Arc::clone(&active_requests);
|
||||
let cancel_inner = cancel.clone();
|
||||
let vpn_info = Arc::new(vpn_info);
|
||||
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
|
||||
// Detect frontend protocol from the first request on this connection.
|
||||
// OnceLock ensures only the first call opens the counter.
|
||||
@@ -499,6 +515,7 @@ impl HttpProxyService {
|
||||
let svc = Arc::clone(&self);
|
||||
let peer = peer_addr;
|
||||
let cn = cancel_inner.clone();
|
||||
let vpn = Arc::clone(&vpn_info);
|
||||
let la = Arc::clone(&la_inner);
|
||||
let st = start;
|
||||
let ca = ConnActivity {
|
||||
@@ -510,7 +527,7 @@ impl HttpProxyService {
|
||||
};
|
||||
async move {
|
||||
let req = req.map(|body| BoxBody::new(body));
|
||||
let result = svc.handle_request(req, peer, port, cn, ca).await;
|
||||
let result = svc.handle_request(req, peer, port, cn, ca, vpn.as_ref().as_ref()).await;
|
||||
// Mark request end — update activity timestamp before guard drops
|
||||
la.store(st.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
drop(req_guard); // Explicitly drop to decrement active_requests
|
||||
@@ -600,6 +617,7 @@ impl HttpProxyService {
|
||||
port: u16,
|
||||
cancel: CancellationToken,
|
||||
mut conn_activity: ConnActivity,
|
||||
vpn_info: Option<&VpnConnectionInfo>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let host = extract_request_host(&req).map(str::to_string);
|
||||
|
||||
@@ -679,11 +697,12 @@ impl HttpProxyService {
|
||||
.or_insert_with(|| Arc::new(RateLimiter::new(rl.max_requests, rl.window)))
|
||||
.clone()
|
||||
});
|
||||
if let Some(response) = RequestFilter::apply_with_rate_limiter(
|
||||
if let Some(response) = RequestFilter::apply_with_rate_limiter_and_vpn(
|
||||
security,
|
||||
&req,
|
||||
&peer_addr,
|
||||
rate_limiter.as_ref(),
|
||||
vpn_info,
|
||||
) {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
|
||||
use rustproxy_config::RouteSecurity;
|
||||
use rustproxy_config::{RouteSecurity, VpnClientAllowEntry, VpnConnectionInfo};
|
||||
use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
|
||||
|
||||
use crate::request_host::extract_request_host;
|
||||
@@ -33,9 +33,22 @@ impl RequestFilter {
|
||||
req: &Request<impl hyper::body::Body>,
|
||||
peer_addr: &SocketAddr,
|
||||
rate_limiter: Option<&Arc<RateLimiter>>,
|
||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
Self::apply_with_rate_limiter_and_vpn(security, req, peer_addr, rate_limiter, None)
|
||||
}
|
||||
|
||||
/// Apply security filters with optional rate limiter and authenticated VPN metadata.
|
||||
/// Returns Some(response) if the request should be blocked.
|
||||
pub fn apply_with_rate_limiter_and_vpn(
|
||||
security: &RouteSecurity,
|
||||
req: &Request<impl hyper::body::Body>,
|
||||
peer_addr: &SocketAddr,
|
||||
rate_limiter: Option<&Arc<RateLimiter>>,
|
||||
vpn_info: Option<&VpnConnectionInfo>,
|
||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
let client_ip = peer_addr.ip();
|
||||
let request_path = req.uri().path();
|
||||
let host = extract_request_host(req);
|
||||
|
||||
// IP filter (domain-aware: use the same host extraction as route matching)
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
@@ -43,12 +56,15 @@ impl RequestFilter {
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(&client_ip);
|
||||
let host = extract_request_host(req);
|
||||
if !filter.is_allowed_for_domain(&normalized, host) {
|
||||
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
|
||||
}
|
||||
}
|
||||
|
||||
if !Self::check_vpn_security(security, vpn_info, host) {
|
||||
return Some(error_response(StatusCode::FORBIDDEN, "VPN access denied"));
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if let Some(ref rate_limit_config) = security.rate_limit {
|
||||
if rate_limit_config.enabled {
|
||||
@@ -177,6 +193,49 @@ impl RequestFilter {
|
||||
None
|
||||
}
|
||||
|
||||
/// Check VPN-specific route access control.
|
||||
pub fn check_vpn_security(
|
||||
security: &RouteSecurity,
|
||||
vpn_info: Option<&VpnConnectionInfo>,
|
||||
domain: Option<&str>,
|
||||
) -> bool {
|
||||
let Some(vpn_security) = security.vpn.as_ref() else {
|
||||
return true;
|
||||
};
|
||||
|
||||
let has_client_policy = vpn_security.allowed_clients.is_some()
|
||||
|| vpn_security.allowed_assigned_ips.is_some();
|
||||
let allowed_clients = vpn_security.allowed_clients.as_deref().unwrap_or(&[]);
|
||||
let allowed_assigned_ips = vpn_security.allowed_assigned_ips.as_deref().unwrap_or(&[]);
|
||||
let requires_vpn = vpn_security.required.unwrap_or(false);
|
||||
|
||||
let Some(vpn_info) = vpn_info else {
|
||||
return !requires_vpn;
|
||||
};
|
||||
|
||||
if !has_client_policy {
|
||||
return true;
|
||||
}
|
||||
|
||||
if allowed_clients.is_empty() && allowed_assigned_ips.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if allowed_assigned_ips.iter().any(|ip| ip == &vpn_info.assigned_ip) {
|
||||
return true;
|
||||
}
|
||||
|
||||
allowed_clients.iter().any(|entry| match entry {
|
||||
VpnClientAllowEntry::Plain(client_id) => client_id == &vpn_info.client_id,
|
||||
VpnClientAllowEntry::DomainScoped { client_id, domains } => {
|
||||
client_id == &vpn_info.client_id
|
||||
&& domain
|
||||
.map(|d| domains.iter().any(|pattern| domain_matches_pattern(pattern, d)))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a request path matches any pattern in the exclude list.
|
||||
fn path_matches_exclude_list(_path: &str, _security: &RouteSecurity) -> bool {
|
||||
// No global exclude paths on RouteSecurity currently,
|
||||
@@ -286,6 +345,23 @@ impl RequestFilter {
|
||||
}
|
||||
}
|
||||
|
||||
fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
|
||||
let p = pattern.trim();
|
||||
let d = domain.trim();
|
||||
if p == "*" {
|
||||
return true;
|
||||
}
|
||||
if p.eq_ignore_ascii_case(d) {
|
||||
return true;
|
||||
}
|
||||
if p.starts_with("*.") {
|
||||
let suffix = &p[1..];
|
||||
d.len() > suffix.len() && d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
@@ -303,7 +379,7 @@ mod tests {
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Empty;
|
||||
use hyper::{Request, StatusCode, Version};
|
||||
use rustproxy_config::{IpAllowEntry, RouteSecurity};
|
||||
use rustproxy_config::{IpAllowEntry, RouteSecurity, RouteVpnSecurity, VpnClientAllowEntry, VpnConnectionInfo};
|
||||
|
||||
use super::RequestFilter;
|
||||
|
||||
@@ -319,6 +395,7 @@ mod tests {
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
vpn: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,4 +441,55 @@ mod tests {
|
||||
.expect("non-matching domain should be denied");
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vpn_policy_with_allow_list_preserves_direct_traffic() {
|
||||
let mut security = domain_scoped_security();
|
||||
security.ip_allow_list = None;
|
||||
security.vpn = Some(RouteVpnSecurity {
|
||||
required: Some(false),
|
||||
allowed_clients: Some(vec![VpnClientAllowEntry::Plain("client-1".to_string())]),
|
||||
allowed_assigned_ips: None,
|
||||
});
|
||||
|
||||
assert!(RequestFilter::check_vpn_security(&security, None, Some("app.example.com")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vpn_policy_denies_unlisted_vpn_client() {
|
||||
let mut security = domain_scoped_security();
|
||||
security.ip_allow_list = None;
|
||||
security.vpn = Some(RouteVpnSecurity {
|
||||
required: Some(false),
|
||||
allowed_clients: Some(vec![VpnClientAllowEntry::Plain("client-1".to_string())]),
|
||||
allowed_assigned_ips: None,
|
||||
});
|
||||
let vpn_info = VpnConnectionInfo {
|
||||
client_id: "client-2".to_string(),
|
||||
assigned_ip: "10.8.0.3".to_string(),
|
||||
transport_type: Some("wireguard".to_string()),
|
||||
remote_addr: Some("198.51.100.10:51820".to_string()),
|
||||
};
|
||||
|
||||
assert!(!RequestFilter::check_vpn_security(&security, Some(&vpn_info), Some("app.example.com")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vpn_required_with_empty_policy_denies_all_vpn_clients() {
|
||||
let mut security = domain_scoped_security();
|
||||
security.ip_allow_list = None;
|
||||
security.vpn = Some(RouteVpnSecurity {
|
||||
required: Some(true),
|
||||
allowed_clients: Some(vec![]),
|
||||
allowed_assigned_ips: None,
|
||||
});
|
||||
let vpn_info = VpnConnectionInfo {
|
||||
client_id: "client-1".to_string(),
|
||||
assigned_ip: "10.8.0.2".to_string(),
|
||||
transport_type: Some("wireguard".to_string()),
|
||||
remote_addr: Some("198.51.100.10:51820".to_string()),
|
||||
};
|
||||
|
||||
assert!(!RequestFilter::check_vpn_security(&security, Some(&vpn_info), Some("app.example.com")));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,6 +275,7 @@ mod tests {
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
vpn: None,
|
||||
};
|
||||
|
||||
reg.recycle_for_security_change("r1", &security);
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use thiserror::Error;
|
||||
|
||||
use rustproxy_config::VpnConnectionInfo;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProxyProtocolError {
|
||||
#[error("Invalid PROXY protocol header")]
|
||||
@@ -19,6 +21,7 @@ pub struct ProxyProtocolHeader {
|
||||
pub source_addr: SocketAddr,
|
||||
pub dest_addr: SocketAddr,
|
||||
pub protocol: ProxyProtocol,
|
||||
pub vpn: Option<VpnConnectionInfo>,
|
||||
}
|
||||
|
||||
/// Protocol in PROXY header.
|
||||
@@ -43,6 +46,9 @@ const PROXY_V2_SIGNATURE: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
/// Custom SmartVPN metadata TLV. 0xEA sits in the PP2_TYPE_MIN_CUSTOM range.
|
||||
pub const PP2_TYPE_SMARTVPN_METADATA: u8 = 0xEA;
|
||||
|
||||
// ===== v1 (text format) =====
|
||||
|
||||
/// Parse a PROXY protocol v1 header from data.
|
||||
@@ -90,6 +96,7 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
source_addr: SocketAddr::new(src_ip, src_port),
|
||||
dest_addr: SocketAddr::new(dst_ip, dst_port),
|
||||
protocol,
|
||||
vpn: None,
|
||||
};
|
||||
|
||||
Ok((header, line_end + 2))
|
||||
@@ -173,6 +180,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
||||
dest_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
||||
protocol: ProxyProtocol::Unknown,
|
||||
vpn: parse_v2_tlvs(&data[16..total_len], 0),
|
||||
},
|
||||
total_len,
|
||||
));
|
||||
@@ -193,11 +201,13 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||
let src_port = u16::from_be_bytes([addr_block[8], addr_block[9]]);
|
||||
let dst_port = u16::from_be_bytes([addr_block[10], addr_block[11]]);
|
||||
let vpn = parse_v2_tlvs(addr_block, 12);
|
||||
Ok((
|
||||
ProxyProtocolHeader {
|
||||
source_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
|
||||
dest_addr: SocketAddr::new(IpAddr::V4(dst_ip), dst_port),
|
||||
protocol: ProxyProtocol::Tcp4,
|
||||
vpn,
|
||||
},
|
||||
total_len,
|
||||
))
|
||||
@@ -213,11 +223,13 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||
let src_port = u16::from_be_bytes([addr_block[8], addr_block[9]]);
|
||||
let dst_port = u16::from_be_bytes([addr_block[10], addr_block[11]]);
|
||||
let vpn = parse_v2_tlvs(addr_block, 12);
|
||||
Ok((
|
||||
ProxyProtocolHeader {
|
||||
source_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
|
||||
dest_addr: SocketAddr::new(IpAddr::V4(dst_ip), dst_port),
|
||||
protocol: ProxyProtocol::Udp4,
|
||||
vpn,
|
||||
},
|
||||
total_len,
|
||||
))
|
||||
@@ -233,11 +245,13 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||
let src_port = u16::from_be_bytes([addr_block[32], addr_block[33]]);
|
||||
let dst_port = u16::from_be_bytes([addr_block[34], addr_block[35]]);
|
||||
let vpn = parse_v2_tlvs(addr_block, 36);
|
||||
Ok((
|
||||
ProxyProtocolHeader {
|
||||
source_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
|
||||
dest_addr: SocketAddr::new(IpAddr::V6(dst_ip), dst_port),
|
||||
protocol: ProxyProtocol::Tcp6,
|
||||
vpn,
|
||||
},
|
||||
total_len,
|
||||
))
|
||||
@@ -253,11 +267,13 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||
let src_port = u16::from_be_bytes([addr_block[32], addr_block[33]]);
|
||||
let dst_port = u16::from_be_bytes([addr_block[34], addr_block[35]]);
|
||||
let vpn = parse_v2_tlvs(addr_block, 36);
|
||||
Ok((
|
||||
ProxyProtocolHeader {
|
||||
source_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
|
||||
dest_addr: SocketAddr::new(IpAddr::V6(dst_ip), dst_port),
|
||||
protocol: ProxyProtocol::Udp6,
|
||||
vpn,
|
||||
},
|
||||
total_len,
|
||||
))
|
||||
@@ -268,6 +284,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
||||
dest_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
||||
protocol: ProxyProtocol::Unknown,
|
||||
vpn: parse_v2_tlvs(addr_block, 0),
|
||||
},
|
||||
total_len,
|
||||
)),
|
||||
@@ -278,6 +295,32 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_v2_tlvs(addr_block: &[u8], offset: usize) -> Option<VpnConnectionInfo> {
|
||||
if addr_block.len() <= offset {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut pos = offset;
|
||||
while pos + 3 <= addr_block.len() {
|
||||
let tlv_type = addr_block[pos];
|
||||
let len = u16::from_be_bytes([addr_block[pos + 1], addr_block[pos + 2]]) as usize;
|
||||
pos += 3;
|
||||
if pos + len > addr_block.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = &addr_block[pos..pos + len];
|
||||
if tlv_type == PP2_TYPE_SMARTVPN_METADATA {
|
||||
if let Ok(metadata) = serde_json::from_slice::<VpnConnectionInfo>(value) {
|
||||
return Some(metadata);
|
||||
}
|
||||
}
|
||||
pos += len;
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Generate a PROXY protocol v2 binary header.
|
||||
pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
|
||||
let transport_nibble: u8 = match transport {
|
||||
@@ -382,6 +425,27 @@ mod tests {
|
||||
assert_eq!(parsed.dest_addr, dest);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_v2_smartvpn_metadata_tlv() {
|
||||
let source: SocketAddr = "198.51.100.10:54321".parse().unwrap();
|
||||
let dest: SocketAddr = "203.0.113.25:8443".parse().unwrap();
|
||||
let mut header = generate_v2(&source, &dest, ProxyV2Transport::Stream);
|
||||
let metadata = br#"{"clientId":"alice","assignedIp":"10.8.0.2","transportType":"wireguard","remoteAddr":"198.51.100.10:51820"}"#;
|
||||
header.push(PP2_TYPE_SMARTVPN_METADATA);
|
||||
header.extend_from_slice(&(metadata.len() as u16).to_be_bytes());
|
||||
header.extend_from_slice(metadata);
|
||||
let addr_len = 12 + 3 + metadata.len();
|
||||
header[14..16].copy_from_slice(&(addr_len as u16).to_be_bytes());
|
||||
|
||||
let (parsed, consumed) = parse_v2(&header).unwrap();
|
||||
assert_eq!(consumed, header.len());
|
||||
assert_eq!(parsed.source_addr, source);
|
||||
let vpn = parsed.vpn.unwrap();
|
||||
assert_eq!(vpn.client_id, "alice");
|
||||
assert_eq!(vpn.assigned_ip, "10.8.0.2");
|
||||
assert_eq!(vpn.transport_type.as_deref(), Some("wireguard"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_v2_udp4() {
|
||||
let source: SocketAddr = "10.0.0.1:12345".parse().unwrap();
|
||||
|
||||
@@ -14,7 +14,7 @@ use crate::forwarder;
|
||||
use crate::sni_parser;
|
||||
use crate::socket_opts;
|
||||
use crate::tls_handler;
|
||||
use rustproxy_config::RouteActionType;
|
||||
use rustproxy_config::{RouteActionType, VpnConnectionInfo};
|
||||
use rustproxy_http::HttpProxyService;
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::RouteManager;
|
||||
@@ -654,10 +654,11 @@ impl TcpListenerManager {
|
||||
// Only parse PROXY headers from trusted proxy IPs (security).
|
||||
// Non-proxy connections skip the peek entirely (no latency cost).
|
||||
let mut effective_peer_addr = peer_addr;
|
||||
let mut vpn_info: Option<VpnConnectionInfo> = None;
|
||||
if !conn_config.proxy_ips.is_empty() && conn_config.proxy_ips.contains(&peer_addr.ip()) {
|
||||
// Trusted proxy IP — peek for PROXY protocol header.
|
||||
// Use stack-allocated buffers (PROXY v1 headers are max ~108 bytes).
|
||||
let mut proxy_peek = [0u8; 256];
|
||||
let mut proxy_peek = [0u8; 4096];
|
||||
let pn = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
stream.peek(&mut proxy_peek),
|
||||
@@ -693,7 +694,8 @@ impl TcpListenerManager {
|
||||
header.source_addr, header.dest_addr, header.protocol
|
||||
);
|
||||
effective_peer_addr = header.source_addr;
|
||||
let mut discard = [0u8; 256];
|
||||
vpn_info = header.vpn;
|
||||
let mut discard = vec![0u8; consumed];
|
||||
stream.read_exact(&mut discard[..consumed]).await?;
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -812,6 +814,14 @@ impl TcpListenerManager {
|
||||
warn!("Connection from {} blocked by route security", peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_vpn_security(
|
||||
security,
|
||||
vpn_info.as_ref(),
|
||||
None,
|
||||
) {
|
||||
warn!("Connection from {} blocked by VPN route security", peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
metrics.connection_opened(route_id, Some(&ip_str));
|
||||
@@ -1049,6 +1059,14 @@ impl TcpListenerManager {
|
||||
warn!("Connection from {} blocked by route security", peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_vpn_security(
|
||||
security,
|
||||
vpn_info.as_ref(),
|
||||
domain.as_deref(),
|
||||
) {
|
||||
warn!("Connection from {} blocked by VPN route security", peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Track connection in metrics — guard ensures connection_closed on all exit paths
|
||||
@@ -1079,6 +1097,7 @@ impl TcpListenerManager {
|
||||
route_id,
|
||||
&conn_config,
|
||||
cancel.clone(),
|
||||
vpn_info.clone(),
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
@@ -1264,7 +1283,7 @@ impl TcpListenerManager {
|
||||
// (e.g. H2 close, backend error, idle timeout drain).
|
||||
let wrapped = rustproxy_http::shutdown_on_drop::ShutdownOnDrop::new(buf_stream);
|
||||
http_proxy
|
||||
.handle_io(wrapped, peer_addr, port, cancel.clone())
|
||||
.handle_io_with_vpn(wrapped, peer_addr, port, cancel.clone(), vpn_info.clone())
|
||||
.await;
|
||||
} else {
|
||||
debug!(
|
||||
@@ -1375,7 +1394,7 @@ impl TcpListenerManager {
|
||||
// even if hyper drops the connection without calling shutdown.
|
||||
let wrapped = rustproxy_http::shutdown_on_drop::ShutdownOnDrop::new(buf_stream);
|
||||
http_proxy
|
||||
.handle_io(wrapped, peer_addr, port, cancel.clone())
|
||||
.handle_io_with_vpn(wrapped, peer_addr, port, cancel.clone(), vpn_info.clone())
|
||||
.await;
|
||||
} else {
|
||||
// Non-HTTP: TLS-to-TLS tunnel (existing behavior for raw TCP protocols)
|
||||
@@ -1404,7 +1423,7 @@ impl TcpListenerManager {
|
||||
// Plain HTTP - use HTTP proxy for request-level routing
|
||||
debug!("HTTP proxy: {} on port {}", peer_addr, port);
|
||||
http_proxy
|
||||
.handle_connection(stream, peer_addr, port, cancel.clone())
|
||||
.handle_io_with_vpn(stream, peer_addr, port, cancel.clone(), vpn_info.clone())
|
||||
.await;
|
||||
Ok(())
|
||||
} else {
|
||||
@@ -1485,6 +1504,7 @@ impl TcpListenerManager {
|
||||
route_id: Option<&str>,
|
||||
conn_config: &ConnectionConfig,
|
||||
cancel: CancellationToken,
|
||||
vpn_info: Option<VpnConnectionInfo>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::UnixStream;
|
||||
@@ -1511,6 +1531,7 @@ impl TcpListenerManager {
|
||||
"localPort": port,
|
||||
"isTLS": is_tls,
|
||||
"domain": domain,
|
||||
"vpn": vpn_info,
|
||||
});
|
||||
|
||||
// Send metadata line (JSON + newline)
|
||||
|
||||
@@ -208,6 +208,7 @@ impl RustProxy {
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
vpn: None,
|
||||
};
|
||||
|
||||
if let Some(ref allow_list) = default_security.ip_allow_list {
|
||||
|
||||
Reference in New Issue
Block a user