262 lines
9.9 KiB
Rust
262 lines
9.9 KiB
Rust
//! PROXY protocol v2 parser for extracting real client addresses
|
|
//! when SmartVPN sits behind a reverse proxy (HAProxy, SmartProxy, etc.).
|
|
//!
|
|
//! Spec: <https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt>
|
|
|
|
use anyhow::Result;
|
|
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
|
use std::time::Duration;
|
|
use tokio::io::AsyncReadExt;
|
|
use tokio::net::TcpStream;
|
|
|
|
/// Upper bound on how long `read_proxy_header` waits for a complete PROXY
/// protocol header on a freshly accepted connection before giving up.
///
/// Kept at whole seconds because the timeout error reports `as_secs()`.
const PROXY_HEADER_TIMEOUT: Duration = Duration::new(5, 0);
|
|
|
|
/// The 12-byte PP v2 signature that opens every binary PROXY header.
///
/// As raw bytes: `0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A`, i.e. `\r\n\r\n\0\r\nQUIT\n`.
const PP_V2_SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
|
|
|
|
/// Parsed PROXY protocol v2 header.
///
/// Compares by value (`PartialEq`/`Eq`), so parsed results can be asserted
/// on or compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyHeader {
    /// Real client source address, as reported by the proxy.
    pub src_addr: SocketAddr,
    /// Destination address of the original client connection, as reported by
    /// the proxy (typically the address the proxy accepted it on).
    pub dst_addr: SocketAddr,
    /// True if this is a LOCAL command (e.g. a health check probe from the
    /// proxy). In that case [`read_proxy_header`] fills both addresses with the
    /// placeholder `0.0.0.0:0`; they carry no client information.
    pub is_local: bool,
}
|
|
|
|
/// Read and parse a PROXY protocol v2 header from a TCP stream.
|
|
///
|
|
/// Reads exactly the header bytes — the stream is in a clean state for
|
|
/// WebSocket upgrade afterward. Returns an error on timeout, invalid
|
|
/// signature, or malformed header.
|
|
pub async fn read_proxy_header(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
|
tokio::time::timeout(PROXY_HEADER_TIMEOUT, read_proxy_header_inner(stream))
|
|
.await
|
|
.map_err(|_| anyhow::anyhow!("PROXY protocol header read timed out ({}s)", PROXY_HEADER_TIMEOUT.as_secs()))?
|
|
}
|
|
|
|
async fn read_proxy_header_inner(stream: &mut TcpStream) -> Result<ProxyHeader> {
|
|
// Read the 16-byte fixed prefix
|
|
let mut prefix = [0u8; 16];
|
|
stream.read_exact(&mut prefix).await?;
|
|
|
|
// Validate the 12-byte signature
|
|
if prefix[..12] != PP_V2_SIGNATURE {
|
|
anyhow::bail!("Invalid PROXY protocol v2 signature");
|
|
}
|
|
|
|
// Byte 12: version (high nibble) | command (low nibble)
|
|
let version = (prefix[12] & 0xF0) >> 4;
|
|
let command = prefix[12] & 0x0F;
|
|
|
|
if version != 2 {
|
|
anyhow::bail!("Unsupported PROXY protocol version: {}", version);
|
|
}
|
|
|
|
// Byte 13: address family (high nibble) | protocol (low nibble)
|
|
let addr_family = (prefix[13] & 0xF0) >> 4;
|
|
let _protocol = prefix[13] & 0x0F; // 1 = STREAM (TCP)
|
|
|
|
// Bytes 14-15: address data length (big-endian)
|
|
let addr_len = u16::from_be_bytes([prefix[14], prefix[15]]) as usize;
|
|
|
|
// Read the address data
|
|
let mut addr_data = vec![0u8; addr_len];
|
|
if addr_len > 0 {
|
|
stream.read_exact(&mut addr_data).await?;
|
|
}
|
|
|
|
// LOCAL command (0x00) = health check, no real address
|
|
if command == 0x00 {
|
|
return Ok(ProxyHeader {
|
|
src_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
|
dst_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
|
is_local: true,
|
|
});
|
|
}
|
|
|
|
// PROXY command (0x01) — parse address block
|
|
if command != 0x01 {
|
|
anyhow::bail!("Unknown PROXY protocol command: {}", command);
|
|
}
|
|
|
|
match addr_family {
|
|
// AF_INET (IPv4): 4 src + 4 dst + 2 src_port + 2 dst_port = 12 bytes
|
|
1 => {
|
|
if addr_data.len() < 12 {
|
|
anyhow::bail!("IPv4 address block too short: {} bytes", addr_data.len());
|
|
}
|
|
let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]);
|
|
let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]);
|
|
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
|
|
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
|
|
Ok(ProxyHeader {
|
|
src_addr: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)),
|
|
dst_addr: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)),
|
|
is_local: false,
|
|
})
|
|
}
|
|
// AF_INET6 (IPv6): 16 src + 16 dst + 2 src_port + 2 dst_port = 36 bytes
|
|
2 => {
|
|
if addr_data.len() < 36 {
|
|
anyhow::bail!("IPv6 address block too short: {} bytes", addr_data.len());
|
|
}
|
|
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap());
|
|
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap());
|
|
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
|
|
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
|
|
Ok(ProxyHeader {
|
|
src_addr: SocketAddr::V6(SocketAddrV6::new(src_ip, src_port, 0, 0)),
|
|
dst_addr: SocketAddr::V6(SocketAddrV6::new(dst_ip, dst_port, 0, 0)),
|
|
is_local: false,
|
|
})
|
|
}
|
|
// AF_UNSPEC or unknown
|
|
_ => {
|
|
anyhow::bail!("Unsupported address family: {}", addr_family);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Build a PROXY protocol v2 header (for testing / proxy implementations).
|
|
pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
|
|
let mut buf = Vec::new();
|
|
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
|
|
|
match (src, dst) {
|
|
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
|
|
buf.push(0x21); // version 2 | PROXY command
|
|
buf.push(0x11); // AF_INET | STREAM
|
|
buf.extend_from_slice(&12u16.to_be_bytes()); // addr length
|
|
buf.extend_from_slice(&s.ip().octets());
|
|
buf.extend_from_slice(&d.ip().octets());
|
|
buf.extend_from_slice(&s.port().to_be_bytes());
|
|
buf.extend_from_slice(&d.port().to_be_bytes());
|
|
}
|
|
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
|
|
buf.push(0x21); // version 2 | PROXY command
|
|
buf.push(0x21); // AF_INET6 | STREAM
|
|
buf.extend_from_slice(&36u16.to_be_bytes()); // addr length
|
|
buf.extend_from_slice(&s.ip().octets());
|
|
buf.extend_from_slice(&d.ip().octets());
|
|
buf.extend_from_slice(&s.port().to_be_bytes());
|
|
buf.extend_from_slice(&d.port().to_be_bytes());
|
|
}
|
|
_ => panic!("Mismatched address families"),
|
|
}
|
|
buf
|
|
}
|
|
|
|
/// Build a PROXY protocol v2 LOCAL header (health check probe).
|
|
pub fn build_pp_v2_local() -> Vec<u8> {
|
|
let mut buf = Vec::new();
|
|
buf.extend_from_slice(&PP_V2_SIGNATURE);
|
|
buf.push(0x20); // version 2 | LOCAL command
|
|
buf.push(0x00); // AF_UNSPEC
|
|
buf.extend_from_slice(&0u16.to_be_bytes()); // no address data
|
|
buf
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;

    /// Helper: open a loopback TCP connection, send `bytes` from the client
    /// side, then half-close the client's write direction.
    ///
    /// The half-close matters. Without it, a short (truncated) input leaves
    /// the server blocked in `read_exact` until `PROXY_HEADER_TIMEOUT` fires.
    /// Truncation tests would then be slow and would exercise the timeout
    /// path instead of the EOF path.
    ///
    /// The client stream is returned rather than dropped, so the server never
    /// observes a connection reset.
    async fn connected_pair(bytes: &[u8]) -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let data = bytes.to_vec();
        let client_task = tokio::spawn(async move {
            let mut client = TcpStream::connect(addr).await.unwrap();
            client.write_all(&data).await.unwrap();
            client.shutdown().await.unwrap();
            client
        });

        let (server_stream, _) = listener.accept().await.unwrap();
        let client = client_task.await.unwrap();
        (server_stream, client)
    }

    /// Helper: push `header_bytes` through a real socket and parse them on
    /// the server side.
    async fn parse_header_from_bytes(header_bytes: &[u8]) -> Result<ProxyHeader> {
        let (mut server_stream, _client) = connected_pair(header_bytes).await;
        read_proxy_header(&mut server_stream).await
    }

    #[tokio::test]
    async fn parse_valid_ipv4_header() {
        let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
        let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
        let header = build_pp_v2_header(src, dst);

        let parsed = parse_header_from_bytes(&header).await.unwrap();
        assert!(!parsed.is_local);
        assert_eq!(parsed.src_addr, src);
        assert_eq!(parsed.dst_addr, dst);
    }

    #[tokio::test]
    async fn parse_valid_ipv6_header() {
        let src = "[2001:db8::1]:54321".parse::<SocketAddr>().unwrap();
        let dst = "[2001:db8::2]:443".parse::<SocketAddr>().unwrap();
        let header = build_pp_v2_header(src, dst);

        let parsed = parse_header_from_bytes(&header).await.unwrap();
        assert!(!parsed.is_local);
        assert_eq!(parsed.src_addr, src);
        assert_eq!(parsed.dst_addr, dst);
    }

    #[tokio::test]
    async fn parse_local_command() {
        let header = build_pp_v2_local();
        let parsed = parse_header_from_bytes(&header).await.unwrap();
        assert!(parsed.is_local);

        let unspecified = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
        assert_eq!(parsed.src_addr, unspecified);
        assert_eq!(parsed.dst_addr, unspecified);
    }

    #[tokio::test]
    async fn accepts_trailing_tlvs() {
        let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
        let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
        let mut header = build_pp_v2_header(src, dst);

        // Append a PP2_TYPE_NOOP (0x04) TLV with a 2-byte value and grow the
        // length field to cover it.
        let tlv: [u8; 5] = [0x04, 0x00, 0x02, 0xAA, 0xBB];
        header.extend_from_slice(&tlv);
        let block_len = u16::from_be_bytes([header[14], header[15]]) + tlv.len() as u16;
        header[14..16].copy_from_slice(&block_len.to_be_bytes());

        let parsed = parse_header_from_bytes(&header).await.unwrap();
        assert_eq!(parsed.src_addr, src);
        assert_eq!(parsed.dst_addr, dst);
    }

    #[tokio::test]
    async fn stream_is_positioned_after_header() {
        let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
        let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
        let payload = b"GET / HTTP/1.1\r\n";
        let mut bytes = build_pp_v2_header(src, dst);
        bytes.extend_from_slice(payload);

        let (mut server_stream, _client) = connected_pair(&bytes).await;
        let parsed = read_proxy_header(&mut server_stream).await.unwrap();
        assert_eq!(parsed.src_addr, src);

        // Everything after the header must still be unread on the stream.
        let mut rest = Vec::new();
        server_stream.read_to_end(&mut rest).await.unwrap();
        assert_eq!(rest, payload);
    }

    #[tokio::test]
    async fn reject_invalid_signature() {
        let mut header = build_pp_v2_local();
        header[0] = 0xFF; // corrupt signature
        let result = parse_header_from_bytes(&header).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("signature"));
    }

    #[tokio::test]
    async fn reject_wrong_version() {
        let mut header = build_pp_v2_local();
        header[12] = 0x10; // version 1 instead of 2
        let result = parse_header_from_bytes(&header).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("version"));
    }

    #[tokio::test]
    async fn reject_unknown_command() {
        let mut header = build_pp_v2_local();
        header[12] = 0x22; // version 2, command 2 (undefined)
        let err = parse_header_from_bytes(&header).await.unwrap_err();
        assert!(err.to_string().contains("command"));
    }

    #[tokio::test]
    async fn reject_unspec_family_for_proxy_command() {
        let mut header = build_pp_v2_local();
        header[12] = 0x21; // version 2, PROXY command; byte 13 stays AF_UNSPEC
        let err = parse_header_from_bytes(&header).await.unwrap_err();
        assert!(err.to_string().contains("address family"));
    }

    #[tokio::test]
    async fn reject_short_ipv4_block() {
        let src = "1.2.3.4:80".parse::<SocketAddr>().unwrap();
        let dst = "5.6.7.8:443".parse::<SocketAddr>().unwrap();
        let mut header = build_pp_v2_header(src, dst);
        // Announce and send only 8 of the 12 required address bytes.
        header.truncate(16 + 8);
        header[14..16].copy_from_slice(&8u16.to_be_bytes());
        let err = parse_header_from_bytes(&header).await.unwrap_err();
        assert!(err.to_string().contains("too short"));
    }

    #[tokio::test]
    async fn reject_truncated_header() {
        // Only 10 bytes — not even the full signature
        let result = parse_header_from_bytes(&[
            0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49,
        ])
        .await;
        let err = result.unwrap_err();
        // Must fail on EOF right away, not by waiting out the header timeout.
        assert!(
            !err.to_string().contains("timed out"),
            "unexpected timeout: {}",
            err
        );
    }

    #[test]
    fn ipv4_header_is_exactly_28_bytes() {
        let src = "1.2.3.4:80".parse::<SocketAddr>().unwrap();
        let dst = "5.6.7.8:443".parse::<SocketAddr>().unwrap();
        let header = build_pp_v2_header(src, dst);
        // 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 12 addrs = 28
        assert_eq!(header.len(), 28);
    }

    #[test]
    fn ipv6_header_is_exactly_52_bytes() {
        let src = "[::1]:80".parse::<SocketAddr>().unwrap();
        let dst = "[::2]:443".parse::<SocketAddr>().unwrap();
        let header = build_pp_v2_header(src, dst);
        // 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 36 addrs = 52
        assert_eq!(header.len(), 52);
    }
}
|