feat(server): add PROXY protocol v2 support for real client IP handling and connection ACLs
This commit is contained in:
@@ -11,6 +11,30 @@ pub enum AclResult {
|
||||
DenyDst,
|
||||
}
|
||||
|
||||
/// Check whether a connection source IP is in a server-level block list.
|
||||
/// Used for pre-handshake rejection of known-bad IPs.
|
||||
pub fn is_connection_blocked(ip: Ipv4Addr, block_list: &[String]) -> bool {
|
||||
ip_matches_any(ip, block_list)
|
||||
}
|
||||
|
||||
/// Check whether a source IP is allowed by allow/block lists.
|
||||
/// Returns true if the IP is permitted (not blocked and passes allow check).
|
||||
pub fn is_source_allowed(ip: Ipv4Addr, allow_list: Option<&[String]>, block_list: Option<&[String]>) -> bool {
|
||||
// Deny overrides allow
|
||||
if let Some(bl) = block_list {
|
||||
if ip_matches_any(ip, bl) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If allow list exists and is non-empty, IP must match
|
||||
if let Some(al) = allow_list {
|
||||
if !al.is_empty() && !ip_matches_any(ip, al) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Check whether a packet from `src_ip` to `dst_ip` is allowed by the client's security policy.
|
||||
///
|
||||
/// Evaluation order (deny overrides allow):
|
||||
|
||||
@@ -20,3 +20,4 @@ pub mod mtu;
|
||||
pub mod wireguard;
|
||||
pub mod client_registry;
|
||||
pub mod acl;
|
||||
pub mod proxy_protocol;
|
||||
|
||||
261
rust/src/proxy_protocol.rs
Normal file
261
rust/src/proxy_protocol.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
//! PROXY protocol v2 parser for extracting real client addresses
|
||||
//! when SmartVPN sits behind a reverse proxy (HAProxy, SmartProxy, etc.).
|
||||
//!
|
||||
//! Spec: <https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt>
|
||||
|
||||
use anyhow::Result;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::time::Duration;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// Timeout for reading the PROXY protocol header from a new connection.
|
||||
const PROXY_HEADER_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// The 12-byte PP v2 signature.
|
||||
const PP_V2_SIGNATURE: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
/// Parsed PROXY protocol v2 header.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyHeader {
|
||||
/// Real client source address.
|
||||
pub src_addr: SocketAddr,
|
||||
/// Proxy-to-server destination address.
|
||||
pub dst_addr: SocketAddr,
|
||||
/// True if this is a LOCAL command (health check probe from proxy).
|
||||
pub is_local: bool,
|
||||
}
|
||||
|
||||
/// Read and parse a PROXY protocol v2 header from a TCP stream.
|
||||
///
|
||||
/// Reads exactly the header bytes — the stream is in a clean state for
|
||||
/// WebSocket upgrade afterward. Returns an error on timeout, invalid
|
||||
/// signature, or malformed header.
|
||||
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
||||
tokio::time::timeout(PROXY_HEADER_TIMEOUT, read_proxy_header_inner(stream))
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("PROXY protocol header read timed out ({}s)", PROXY_HEADER_TIMEOUT.as_secs()))?
|
||||
}
|
||||
|
||||
async fn read_proxy_header_inner(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
||||
// Read the 16-byte fixed prefix
|
||||
let mut prefix = [0u8; 16];
|
||||
stream.read_exact(&mut prefix).await?;
|
||||
|
||||
// Validate the 12-byte signature
|
||||
if prefix[..12] != PP_V2_SIGNATURE {
|
||||
anyhow::bail!("Invalid PROXY protocol v2 signature");
|
||||
}
|
||||
|
||||
// Byte 12: version (high nibble) | command (low nibble)
|
||||
let version = (prefix[12] & 0xF0) >> 4;
|
||||
let command = prefix[12] & 0x0F;
|
||||
|
||||
if version != 2 {
|
||||
anyhow::bail!("Unsupported PROXY protocol version: {}", version);
|
||||
}
|
||||
|
||||
// Byte 13: address family (high nibble) | protocol (low nibble)
|
||||
let addr_family = (prefix[13] & 0xF0) >> 4;
|
||||
let _protocol = prefix[13] & 0x0F; // 1 = STREAM (TCP)
|
||||
|
||||
// Bytes 14-15: address data length (big-endian)
|
||||
let addr_len = u16::from_be_bytes([prefix[14], prefix[15]]) as usize;
|
||||
|
||||
// Read the address data
|
||||
let mut addr_data = vec![0u8; addr_len];
|
||||
if addr_len > 0 {
|
||||
stream.read_exact(&mut addr_data).await?;
|
||||
}
|
||||
|
||||
// LOCAL command (0x00) = health check, no real address
|
||||
if command == 0x00 {
|
||||
return Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
dst_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
is_local: true,
|
||||
});
|
||||
}
|
||||
|
||||
// PROXY command (0x01) — parse address block
|
||||
if command != 0x01 {
|
||||
anyhow::bail!("Unknown PROXY protocol command: {}", command);
|
||||
}
|
||||
|
||||
match addr_family {
|
||||
// AF_INET (IPv4): 4 src + 4 dst + 2 src_port + 2 dst_port = 12 bytes
|
||||
1 => {
|
||||
if addr_data.len() < 12 {
|
||||
anyhow::bail!("IPv4 address block too short: {} bytes", addr_data.len());
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]);
|
||||
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
|
||||
Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)),
|
||||
dst_addr: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)),
|
||||
is_local: false,
|
||||
})
|
||||
}
|
||||
// AF_INET6 (IPv6): 16 src + 16 dst + 2 src_port + 2 dst_port = 36 bytes
|
||||
2 => {
|
||||
if addr_data.len() < 36 {
|
||||
anyhow::bail!("IPv6 address block too short: {} bytes", addr_data.len());
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap());
|
||||
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
|
||||
Ok(ProxyHeader {
|
||||
src_addr: SocketAddr::V6(SocketAddrV6::new(src_ip, src_port, 0, 0)),
|
||||
dst_addr: SocketAddr::V6(SocketAddrV6::new(dst_ip, dst_port, 0, 0)),
|
||||
is_local: false,
|
||||
})
|
||||
}
|
||||
// AF_UNSPEC or unknown
|
||||
_ => {
|
||||
anyhow::bail!("Unsupported address family: {}", addr_family);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a PROXY protocol v2 header (for testing / proxy implementations).
|
||||
pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
||||
|
||||
match (src, dst) {
|
||||
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
|
||||
buf.push(0x21); // version 2 | PROXY command
|
||||
buf.push(0x11); // AF_INET | STREAM
|
||||
buf.extend_from_slice(&12u16.to_be_bytes()); // addr length
|
||||
buf.extend_from_slice(&s.ip().octets());
|
||||
buf.extend_from_slice(&d.ip().octets());
|
||||
buf.extend_from_slice(&s.port().to_be_bytes());
|
||||
buf.extend_from_slice(&d.port().to_be_bytes());
|
||||
}
|
||||
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
|
||||
buf.push(0x21); // version 2 | PROXY command
|
||||
buf.push(0x21); // AF_INET6 | STREAM
|
||||
buf.extend_from_slice(&36u16.to_be_bytes()); // addr length
|
||||
buf.extend_from_slice(&s.ip().octets());
|
||||
buf.extend_from_slice(&d.ip().octets());
|
||||
buf.extend_from_slice(&s.port().to_be_bytes());
|
||||
buf.extend_from_slice(&d.port().to_be_bytes());
|
||||
}
|
||||
_ => panic!("Mismatched address families"),
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a PROXY protocol v2 LOCAL header (health check probe).
|
||||
pub fn build_pp_v2_local() -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
||||
buf.push(0x20); // version 2 | LOCAL command
|
||||
buf.push(0x00); // AF_UNSPEC
|
||||
buf.extend_from_slice(&0u16.to_be_bytes()); // no address data
|
||||
buf
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Helper: create a TCP pair and write data to the client side, then parse from server side.
|
||||
async fn parse_header_from_bytes(header_bytes: &[u8]) -> Result<ProxyHeader> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let data = header_bytes.to_vec();
|
||||
let client_task = tokio::spawn(async move {
|
||||
let mut client = TcpStream::connect(addr).await.unwrap();
|
||||
client.write_all(&data).await.unwrap();
|
||||
client // keep alive
|
||||
});
|
||||
|
||||
let (mut server_stream, _) = listener.accept().await.unwrap();
|
||||
let result = read_proxy_header(&mut server_stream).await;
|
||||
let _client = client_task.await.unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_valid_ipv4_header() {
|
||||
let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
|
||||
let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(!parsed.is_local);
|
||||
assert_eq!(parsed.src_addr, src);
|
||||
assert_eq!(parsed.dst_addr, dst);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_valid_ipv6_header() {
|
||||
let src = "[2001:db8::1]:54321".parse::<SocketAddr>().unwrap();
|
||||
let dst = "[2001:db8::2]:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(!parsed.is_local);
|
||||
assert_eq!(parsed.src_addr, src);
|
||||
assert_eq!(parsed.dst_addr, dst);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_local_command() {
|
||||
let header = build_pp_v2_local();
|
||||
let parsed = parse_header_from_bytes(&header).await.unwrap();
|
||||
assert!(parsed.is_local);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_invalid_signature() {
|
||||
let mut header = build_pp_v2_local();
|
||||
header[0] = 0xFF; // corrupt signature
|
||||
let result = parse_header_from_bytes(&header).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("signature"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_wrong_version() {
|
||||
let mut header = build_pp_v2_local();
|
||||
header[12] = 0x10; // version 1 instead of 2
|
||||
let result = parse_header_from_bytes(&header).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("version"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_truncated_header() {
|
||||
// Only 10 bytes — not even the full signature
|
||||
let result = parse_header_from_bytes(&[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49]).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv4_header_is_exactly_28_bytes() {
|
||||
let src = "1.2.3.4:80".parse::<SocketAddr>().unwrap();
|
||||
let dst = "5.6.7.8:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 12 addrs = 28
|
||||
assert_eq!(header.len(), 28);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv6_header_is_exactly_52_bytes() {
|
||||
let src = "[::1]:80".parse::<SocketAddr>().unwrap();
|
||||
let dst = "[::2]:443".parse::<SocketAddr>().unwrap();
|
||||
let header = build_pp_v2_header(src, dst);
|
||||
// 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 36 addrs = 52
|
||||
assert_eq!(header.len(), 52);
|
||||
}
|
||||
}
|
||||
@@ -49,6 +49,11 @@ pub struct ServerConfig {
|
||||
pub quic_idle_timeout_secs: Option<u64>,
|
||||
/// Pre-registered clients for IK authentication.
|
||||
pub clients: Option<Vec<ClientEntry>>,
|
||||
/// Enable PROXY protocol v2 parsing on incoming WebSocket connections.
|
||||
/// SECURITY: Must be false when accepting direct client connections.
|
||||
pub proxy_protocol: Option<bool>,
|
||||
/// Server-level IP block list — applied at TCP accept, before Noise handshake.
|
||||
pub connection_ip_block_list: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -70,6 +75,8 @@ pub struct ClientInfo {
|
||||
pub authenticated_key: String,
|
||||
/// Registered client ID from the client registry.
|
||||
pub registered_client_id: String,
|
||||
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
|
||||
pub remote_addr: Option<String>,
|
||||
}
|
||||
|
||||
/// Server statistics.
|
||||
@@ -562,8 +569,8 @@ impl VpnServer {
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
|
||||
/// to the transport-agnostic `handle_client_connection`.
|
||||
/// WebSocket listener — accepts TCP connections, optionally parses PROXY protocol v2,
|
||||
/// upgrades to WS, then hands off to `handle_client_connection`.
|
||||
async fn run_ws_listener(
|
||||
state: Arc<ServerState>,
|
||||
listen_addr: String,
|
||||
@@ -576,17 +583,51 @@ async fn run_ws_listener(
|
||||
tokio::select! {
|
||||
accept = listener.accept() => {
|
||||
match accept {
|
||||
Ok((stream, addr)) => {
|
||||
info!("New connection from {}", addr);
|
||||
Ok((mut tcp_stream, tcp_addr)) => {
|
||||
info!("New connection from {}", tcp_addr);
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
match transport::accept_connection(stream).await {
|
||||
// Phase 0: Parse PROXY protocol v2 header if enabled
|
||||
let remote_addr = if state.config.proxy_protocol.unwrap_or(false) {
|
||||
match crate::proxy_protocol::read_proxy_header(&mut tcp_stream).await {
|
||||
Ok(header) if header.is_local => {
|
||||
info!("PP v2 LOCAL probe from {}", tcp_addr);
|
||||
return; // Health check — close gracefully
|
||||
}
|
||||
Ok(header) => {
|
||||
info!("PP v2: real client {} (via {})", header.src_addr, tcp_addr);
|
||||
Some(header.src_addr)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("PP v2 parse failed from {}: {}", tcp_addr, e);
|
||||
return; // Drop connection
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Some(tcp_addr) // Direct connection — use TCP SocketAddr
|
||||
};
|
||||
|
||||
// Phase 1: Server-level connection IP block list (pre-handshake)
|
||||
if let (Some(ref block_list), Some(ref addr)) = (&state.config.connection_ip_block_list, &remote_addr) {
|
||||
if !block_list.is_empty() {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if acl::is_connection_blocked(v4, block_list) {
|
||||
warn!("Connection blocked by server IP block list: {}", addr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: WebSocket upgrade + VPN handshake
|
||||
match transport::accept_connection(tcp_stream).await {
|
||||
Ok(ws) => {
|
||||
let (sink, stream) = transport_trait::split_ws(ws);
|
||||
if let Err(e) = handle_client_connection(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
remote_addr,
|
||||
).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
}
|
||||
@@ -662,6 +703,7 @@ async fn run_quic_listener(
|
||||
state,
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
Some(remote),
|
||||
).await {
|
||||
warn!("QUIC client error: {}", e);
|
||||
}
|
||||
@@ -700,6 +742,7 @@ async fn handle_client_connection(
|
||||
state: Arc<ServerState>,
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
remote_addr: Option<std::net::SocketAddr>,
|
||||
) -> Result<()> {
|
||||
let server_private_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
@@ -779,6 +822,24 @@ async fn handle_client_connection(
|
||||
|
||||
let mut noise_transport = responder.into_transport_mode()?;
|
||||
|
||||
// Connection-level ACL: check real client IP against per-client ipAllowList/ipBlockList
|
||||
if let (Some(ref sec), Some(ref addr)) = (&client_security, &remote_addr) {
|
||||
if let std::net::IpAddr::V4(v4) = addr.ip() {
|
||||
if !acl::is_source_allowed(
|
||||
v4,
|
||||
sec.ip_allow_list.as_deref(),
|
||||
sec.ip_block_list.as_deref(),
|
||||
) {
|
||||
warn!("Connection-level ACL denied client {} from IP {}", registered_client_id, addr);
|
||||
let disconnect_frame = Frame { packet_type: PacketType::Disconnect, payload: Vec::new() };
|
||||
let mut frame_bytes = BytesMut::new();
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
|
||||
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
|
||||
anyhow::bail!("Connection denied: source IP {} not allowed for client {}", addr, registered_client_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the registered client ID as the connection ID
|
||||
let client_id = registered_client_id.clone();
|
||||
|
||||
@@ -811,6 +872,7 @@ async fn handle_client_connection(
|
||||
burst_bytes: burst,
|
||||
authenticated_key: client_pub_key_b64.clone(),
|
||||
registered_client_id: registered_client_id.clone(),
|
||||
remote_addr: remote_addr.map(|a| a.to_string()),
|
||||
};
|
||||
state.clients.write().await.insert(client_id.clone(), client_info);
|
||||
|
||||
@@ -845,7 +907,9 @@ async fn handle_client_connection(
|
||||
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
|
||||
sink.send_reliable(frame_bytes.to_vec()).await?;
|
||||
|
||||
info!("Client {} ({}) connected with IP {}", registered_client_id, &client_pub_key_b64[..8], assigned_ip);
|
||||
info!("Client {} ({}) connected with IP {} from {}",
|
||||
registered_client_id, &client_pub_key_b64[..8], assigned_ip,
|
||||
remote_addr.map(|a| a.to_string()).unwrap_or_else(|| "unknown".to_string()));
|
||||
|
||||
// Main packet loop with dead-peer detection
|
||||
let mut last_activity = tokio::time::Instant::now();
|
||||
|
||||
Reference in New Issue
Block a user