feat(server): add PROXY protocol v2 support for real client IP handling and connection ACLs

This commit is contained in:
2026-03-29 17:40:55 +00:00
parent e31086d0c2
commit 229db4be38
9 changed files with 592 additions and 404 deletions

View File

@@ -11,6 +11,30 @@ pub enum AclResult {
DenyDst,
}
/// Check whether a connection source IP is in a server-level block list.
/// Used for pre-handshake rejection of known-bad IPs.
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
ip_matches_any(ip, block_list)
}
/// Check whether a source IP is allowed by allow/block lists.
/// Returns true if the IP is permitted (not blocked and passes allow check).
pub fn is_source_allowed(ip: Ipv4Addr, allow_list: Option<&[String]>, block_list: Option<&[String]>) -> bool {
// Deny overrides allow
if let Some(bl) = block_list {
if ip_matches_any(ip, bl) {
return false;
}
}
// If allow list exists and is non-empty, IP must match
if let Some(al) = allow_list {
if !al.is_empty() && !ip_matches_any(ip, al) {
return false;
}
}
true
}
/// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy.
///
/// Evaluation order (deny overrides allow):

View File

@@ -20,3 +20,4 @@ pub mod mtu;
pub mod wireguard;
pub mod client_registry;
pub mod acl;
pub mod proxy_protocol;

261
rust/src/proxy_protocol.rs Normal file
View File

@@ -0,0 +1,261 @@
//! PROXY protocol v2 parser for extracting real client addresses
//! when SmartVPN sits behind a reverse proxy (HAProxy, SmartProxy, etc.).
//!
//! Spec: <https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt>
use anyhow::Result;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
/// Timeout for reading the PROXY protocol header from a new connection.
const PROXY_HEADER_TIMEOUT: Duration = Duration::from_secs(5);
/// The 12-byte PP v2 signature.
const PP_V2_SIGNATURE: [u8; 12] = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
/// Parsed PROXY protocol v2 header.
#[derive(Debug, Clone)]
pub struct ProxyHeader {
/// Real client source address.
pub src_addr: SocketAddr,
/// Proxy-to-server destination address.
pub dst_addr: SocketAddr,
/// True if this is a LOCAL command (health check probe from proxy).
pub is_local: bool,
}
/// Read and parse a PROXY protocol v2 header from a TCP stream.
///
/// Reads exactly the header bytes — the stream is in a clean state for
/// WebSocket upgrade afterward. Returns an error on timeout, invalid
/// signature, or malformed header.
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader> {
tokio::time::timeout(PROXY_HEADER_TIMEOUT, read_proxy_header_inner(stream))
.await
.map_err(|_| anyhow::anyhow!("PROXY protocol header read timed out ({}s)", PROXY_HEADER_TIMEOUT.as_secs()))?
}
async fn read_proxy_header_inner(stream: &mut TcpStream) -> Result<ProxyHeader> {
// Read the 16-byte fixed prefix
let mut prefix = [0u8; 16];
stream.read_exact(&mut prefix).await?;
// Validate the 12-byte signature
if prefix[..12] != PP_V2_SIGNATURE {
anyhow::bail!("Invalid PROXY protocol v2 signature");
}
// Byte 12: version (high nibble) | command (low nibble)
let version = (prefix[12] & 0xF0) >> 4;
let command = prefix[12] & 0x0F;
if version != 2 {
anyhow::bail!("Unsupported PROXY protocol version: {}", version);
}
// Byte 13: address family (high nibble) | protocol (low nibble)
let addr_family = (prefix[13] & 0xF0) >> 4;
let _protocol = prefix[13] & 0x0F; // 1 = STREAM (TCP)
// Bytes 14-15: address data length (big-endian)
let addr_len = u16::from_be_bytes([prefix[14], prefix[15]]) as usize;
// Read the address data
let mut addr_data = vec![0u8; addr_len];
if addr_len > 0 {
stream.read_exact(&mut addr_data).await?;
}
// LOCAL command (0x00) = health check, no real address
if command == 0x00 {
return Ok(ProxyHeader {
src_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
dst_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
is_local: true,
});
}
// PROXY command (0x01) — parse address block
if command != 0x01 {
anyhow::bail!("Unknown PROXY protocol command: {}", command);
}
match addr_family {
// AF_INET (IPv4): 4 src + 4 dst + 2 src_port + 2 dst_port = 12 bytes
1 => {
if addr_data.len() < 12 {
anyhow::bail!("IPv4 address block too short: {} bytes", addr_data.len());
}
let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]);
let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]);
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
Ok(ProxyHeader {
src_addr: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)),
dst_addr: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)),
is_local: false,
})
}
// AF_INET6 (IPv6): 16 src + 16 dst + 2 src_port + 2 dst_port = 36 bytes
2 => {
if addr_data.len() < 36 {
anyhow::bail!("IPv6 address block too short: {} bytes", addr_data.len());
}
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap());
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
Ok(ProxyHeader {
src_addr: SocketAddr::V6(SocketAddrV6::new(src_ip, src_port, 0, 0)),
dst_addr: SocketAddr::V6(SocketAddrV6::new(dst_ip, dst_port, 0, 0)),
is_local: false,
})
}
// AF_UNSPEC or unknown
_ => {
anyhow::bail!("Unsupported address family: {}", addr_family);
}
}
}
/// Build a PROXY protocol v2 header (for testing / proxy implementations).
pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&PP_V2_SIGNATURE);
match (src, dst) {
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
buf.push(0x21); // version 2 | PROXY command
buf.push(0x11); // AF_INET | STREAM
buf.extend_from_slice(&12u16.to_be_bytes()); // addr length
buf.extend_from_slice(&s.ip().octets());
buf.extend_from_slice(&d.ip().octets());
buf.extend_from_slice(&s.port().to_be_bytes());
buf.extend_from_slice(&d.port().to_be_bytes());
}
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
buf.push(0x21); // version 2 | PROXY command
buf.push(0x21); // AF_INET6 | STREAM
buf.extend_from_slice(&36u16.to_be_bytes()); // addr length
buf.extend_from_slice(&s.ip().octets());
buf.extend_from_slice(&d.ip().octets());
buf.extend_from_slice(&s.port().to_be_bytes());
buf.extend_from_slice(&d.port().to_be_bytes());
}
_ => panic!("Mismatched address families"),
}
buf
}
/// Build a PROXY protocol v2 LOCAL header (health check probe).
pub fn build_pp_v2_local() -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&PP_V2_SIGNATURE);
buf.push(0x20); // version 2 | LOCAL command
buf.push(0x00); // AF_UNSPEC
buf.extend_from_slice(&0u16.to_be_bytes()); // no address data
buf
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
/// Helper: create a TCP pair and write data to the client side, then parse from server side.
async fn parse_header_from_bytes(header_bytes: &[u8]) -> Result<ProxyHeader> {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let data = header_bytes.to_vec();
let client_task = tokio::spawn(async move {
let mut client = TcpStream::connect(addr).await.unwrap();
client.write_all(&data).await.unwrap();
client // keep alive
});
let (mut server_stream, _) = listener.accept().await.unwrap();
let result = read_proxy_header(&mut server_stream).await;
let _client = client_task.await.unwrap();
result
}
#[tokio::test]
async fn parse_valid_ipv4_header() {
let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
let header = build_pp_v2_header(src, dst);
let parsed = parse_header_from_bytes(&header).await.unwrap();
assert!(!parsed.is_local);
assert_eq!(parsed.src_addr, src);
assert_eq!(parsed.dst_addr, dst);
}
#[tokio::test]
async fn parse_valid_ipv6_header() {
let src = "[2001:db8::1]:54321".parse::<SocketAddr>().unwrap();
let dst = "[2001:db8::2]:443".parse::<SocketAddr>().unwrap();
let header = build_pp_v2_header(src, dst);
let parsed = parse_header_from_bytes(&header).await.unwrap();
assert!(!parsed.is_local);
assert_eq!(parsed.src_addr, src);
assert_eq!(parsed.dst_addr, dst);
}
#[tokio::test]
async fn parse_local_command() {
let header = build_pp_v2_local();
let parsed = parse_header_from_bytes(&header).await.unwrap();
assert!(parsed.is_local);
}
#[tokio::test]
async fn reject_invalid_signature() {
let mut header = build_pp_v2_local();
header[0] = 0xFF; // corrupt signature
let result = parse_header_from_bytes(&header).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("signature"));
}
#[tokio::test]
async fn reject_wrong_version() {
let mut header = build_pp_v2_local();
header[12] = 0x10; // version 1 instead of 2
let result = parse_header_from_bytes(&header).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("version"));
}
#[tokio::test]
async fn reject_truncated_header() {
// Only 10 bytes — not even the full signature
let result = parse_header_from_bytes(&[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49]).await;
assert!(result.is_err());
}
#[tokio::test]
async fn ipv4_header_is_exactly_28_bytes() {
let src = "1.2.3.4:80".parse::<SocketAddr>().unwrap();
let dst = "5.6.7.8:443".parse::<SocketAddr>().unwrap();
let header = build_pp_v2_header(src, dst);
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 12 addrs = 28
assert_eq!(header.len(), 28);
}
#[tokio::test]
async fn ipv6_header_is_exactly_52_bytes() {
let src = "[::1]:80".parse::<SocketAddr>().unwrap();
let dst = "[::2]:443".parse::<SocketAddr>().unwrap();
let header = build_pp_v2_header(src, dst);
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 36 addrs = 52
assert_eq!(header.len(), 52);
}
}

View File

@@ -49,6 +49,11 @@ pub struct ServerConfig {
pub quic_idle_timeout_secs: Option<u64>,
/// Pre-registered clients for IK authentication.
pub clients: Option<Vec<ClientEntry>>,
/// Enable PROXY protocol v2 parsing on incoming WebSocket connections.
/// SECURITY: Must be false when accepting direct client connections.
pub proxy_protocol: Option<bool>,
/// Server-level IP block list — applied at TCP accept, before Noise handshake.
pub connection_ip_block_list: Option<Vec<String>>,
}
/// Information about a connected client.
@@ -70,6 +75,8 @@ pub struct ClientInfo {
pub authenticated_key: String,
/// Registered client ID from the client registry.
pub registered_client_id: String,
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
pub remote_addr: Option<String>,
}
/// Server statistics.
@@ -562,8 +569,8 @@ impl VpnServer {
}
}
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
/// to the transport-agnostic `handle_client_connection`.
/// WebSocket listener — accepts TCP connections, optionally parses PROXY protocol v2,
/// upgrades to WS, then hands off to `handle_client_connection`.
async fn run_ws_listener(
state: Arc<ServerState>,
listen_addr: String,
@@ -576,17 +583,51 @@ async fn run_ws_listener(
tokio::select! {
accept = listener.accept() => {
match accept {
Ok((stream, addr)) => {
info!("New connection from {}", addr);
Ok((mut tcp_stream, tcp_addr)) => {
info!("New connection from {}", tcp_addr);
let state = state.clone();
tokio::spawn(async move {
match transport::accept_connection(stream).await {
// Phase 0: Parse PROXY protocol v2 header if enabled
let remote_addr = if state.config.proxy_protocol.unwrap_or(false) {
match crate::proxy_protocol::read_proxy_header(&mut tcp_stream).await {
Ok(header) if header.is_local => {
info!("PP v2 LOCAL probe from {}", tcp_addr);
return; // Health check — close gracefully
}
Ok(header) => {
info!("PP v2: real client {} (via {})", header.src_addr, tcp_addr);
Some(header.src_addr)
}
Err(e) => {
warn!("PP v2 parse failed from {}: {}", tcp_addr, e);
return; // Drop connection
}
}
} else {
Some(tcp_addr) // Direct connection — use TCP SocketAddr
};
// Phase 1: Server-level connection IP block list (pre-handshake)
if let (Some(ref block_list), Some(ref addr)) = (&state.config.connection_ip_block_list, &remote_addr) {
if !block_list.is_empty() {
if let std::net::IpAddr::V4(v4) = addr.ip() {
if acl::is_connection_blocked(v4, block_list) {
warn!("Connection blocked by server IP block list: {}", addr);
return;
}
}
}
}
// Phase 2: WebSocket upgrade + VPN handshake
match transport::accept_connection(tcp_stream).await {
Ok(ws) => {
let (sink, stream) = transport_trait::split_ws(ws);
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
remote_addr,
).await {
warn!("Client connection error: {}", e);
}
@@ -662,6 +703,7 @@ async fn run_quic_listener(
state,
Box::new(sink),
Box::new(stream),
Some(remote),
).await {
warn!("QUIC client error: {}", e);
}
@@ -700,6 +742,7 @@ async fn handle_client_connection(
state: Arc<ServerState>,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
remote_addr: Option<std::net::SocketAddr>,
) -> Result<()> {
let server_private_key = base64::Engine::decode(
&base64::engine::general_purpose::STANDARD,
@@ -779,6 +822,24 @@ async fn handle_client_connection(
let mut noise_transport = responder.into_transport_mode()?;
// Connection-level ACL: check real client IP against per-client ipAllowList/ipBlockList
if let (Some(ref sec), Some(ref addr)) = (&client_security, &remote_addr) {
if let std::net::IpAddr::V4(v4) = addr.ip() {
if !acl::is_source_allowed(
v4,
sec.ip_allow_list.as_deref(),
sec.ip_block_list.as_deref(),
) {
warn!("Connection-level ACL denied client {} from IP {}", registered_client_id, addr);
let disconnect_frame = Frame { packet_type: PacketType::Disconnect, payload: Vec::new() };
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
anyhow::bail!("Connection denied: source IP {} not allowed for client {}", addr, registered_client_id);
}
}
}
// Use the registered client ID as the connection ID
let client_id = registered_client_id.clone();
@@ -811,6 +872,7 @@ async fn handle_client_connection(
burst_bytes: burst,
authenticated_key: client_pub_key_b64.clone(),
registered_client_id: registered_client_id.clone(),
remote_addr: remote_addr.map(|a| a.to_string()),
};
state.clients.write().await.insert(client_id.clone(), client_info);
@@ -845,7 +907,9 @@ async fn handle_client_connection(
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
sink.send_reliable(frame_bytes.to_vec()).await?;
info!("Client {} ({}) connected with IP {}", registered_client_id, &client_pub_key_b64[..8], assigned_ip);
info!("Client {} ({}) connected with IP {} from {}",
registered_client_id, &client_pub_key_b64[..8], assigned_ip,
remote_addr.map(|a| a.to_string()).unwrap_or_else(|| "unknown".to_string()));
// Main packet loop with dead-peer detection
let mut last_activity = tokio::time::Instant::now();