//! PROXY protocol v2 parser for extracting real client addresses //! when SmartVPN sits behind a reverse proxy (HAProxy, SmartProxy, etc.). //! //! Spec: use anyhow::Result; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; use tokio::io::AsyncReadExt; use tokio::net::TcpStream; /// Timeout for reading the PROXY protocol header from a new connection. const PROXY_HEADER_TIMEOUT: Duration = Duration::from_secs(5); /// The 12-byte PP v2 signature. const PP_V2_SIGNATURE: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; /// Parsed PROXY protocol v2 header. #[derive(Debug, Clone)] pub struct ProxyHeader { /// Real client source address. pub src_addr: SocketAddr, /// Proxy-to-server destination address. pub dst_addr: SocketAddr, /// True if this is a LOCAL command (health check probe from proxy). pub is_local: bool, } /// Read and parse a PROXY protocol v2 header from a TCP stream. /// /// Reads exactly the header bytes — the stream is in a clean state for /// WebSocket upgrade afterward. Returns an error on timeout, invalid /// signature, or malformed header. pub async fn read_proxy_header(stream: &mut TcpStream) -> Result { tokio::time::timeout(PROXY_HEADER_TIMEOUT, read_proxy_header_inner(stream)) .await .map_err(|_| anyhow::anyhow!("PROXY protocol header read timed out ({}s)", PROXY_HEADER_TIMEOUT.as_secs()))? } async fn read_proxy_header_inner(stream: &mut TcpStream) -> Result { // Read the 16-byte fixed prefix let mut prefix = [0u8; 16]; stream.read_exact(&mut prefix).await?; // Validate the 12-byte signature if prefix[..12] != PP_V2_SIGNATURE { anyhow::bail!("Invalid PROXY protocol v2 signature"); } // Byte 12: version (high nibble) | command (low nibble) let version = (prefix[12] & 0xF0) >> 4; let command = prefix[12] & 0x0F; if version != 2 { anyhow::bail!("Unsupported PROXY protocol version: {}", version); } // Byte 13: address family (high nibble) | protocol (low nibble) let addr_family = (prefix[13] & 0xF0) >> 4; let _protocol = prefix[13] & 0x0F; // 1 = STREAM (TCP) // Bytes 14-15: address data length (big-endian) let addr_len = u16::from_be_bytes([prefix[14], prefix[15]]) as usize; // Read the address data let mut addr_data = vec![0u8; addr_len]; if addr_len > 0 { stream.read_exact(&mut addr_data).await?; } // LOCAL command (0x00) = health check, no real address if command == 0x00 { return Ok(ProxyHeader { src_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), dst_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), is_local: true, }); } // PROXY command (0x01) — parse address block if command != 0x01 { anyhow::bail!("Unknown PROXY protocol command: {}", command); } match addr_family { // AF_INET (IPv4): 4 src + 4 dst + 2 src_port + 2 dst_port = 12 bytes 1 => { if addr_data.len() < 12 { anyhow::bail!("IPv4 address block too short: {} bytes", addr_data.len()); } let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]); let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]); let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]); let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]); Ok(ProxyHeader { src_addr: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)), dst_addr: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)), is_local: false, }) } // AF_INET6 (IPv6): 16 src + 16 dst + 2 src_port + 2 dst_port = 36 bytes 2 => { if addr_data.len() < 36 { anyhow::bail!("IPv6 address block too short: {} bytes", addr_data.len()); } let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap()); let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap()); let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]); let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]); Ok(ProxyHeader { src_addr: SocketAddr::V6(SocketAddrV6::new(src_ip, src_port, 0, 0)), dst_addr: SocketAddr::V6(SocketAddrV6::new(dst_ip, dst_port, 0, 0)), is_local: false, }) } // AF_UNSPEC or unknown _ => { anyhow::bail!("Unsupported address family: {}", addr_family); } } } /// Build a PROXY protocol v2 header (for testing / proxy implementations). pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec { let mut buf = Vec::new(); buf.extend_from_slice(&PP_V2_SIGNATURE); match (src, dst) { (SocketAddr::V4(s), SocketAddr::V4(d)) => { buf.push(0x21); // version 2 | PROXY command buf.push(0x11); // AF_INET | STREAM buf.extend_from_slice(&12u16.to_be_bytes()); // addr length buf.extend_from_slice(&s.ip().octets()); buf.extend_from_slice(&d.ip().octets()); buf.extend_from_slice(&s.port().to_be_bytes()); buf.extend_from_slice(&d.port().to_be_bytes()); } (SocketAddr::V6(s), SocketAddr::V6(d)) => { buf.push(0x21); // version 2 | PROXY command buf.push(0x21); // AF_INET6 | STREAM buf.extend_from_slice(&36u16.to_be_bytes()); // addr length buf.extend_from_slice(&s.ip().octets()); buf.extend_from_slice(&d.ip().octets()); buf.extend_from_slice(&s.port().to_be_bytes()); buf.extend_from_slice(&d.port().to_be_bytes()); } _ => panic!("Mismatched address families"), } buf } /// Build a PROXY protocol v2 LOCAL header (health check probe). pub fn build_pp_v2_local() -> Vec { let mut buf = Vec::new(); buf.extend_from_slice(&PP_V2_SIGNATURE); buf.push(0x20); // version 2 | LOCAL command buf.push(0x00); // AF_UNSPEC buf.extend_from_slice(&0u16.to_be_bytes()); // no address data buf } #[cfg(test)] mod tests { use super::*; use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; /// Helper: create a TCP pair and write data to the client side, then parse from server side. async fn parse_header_from_bytes(header_bytes: &[u8]) -> Result { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let data = header_bytes.to_vec(); let client_task = tokio::spawn(async move { let mut client = TcpStream::connect(addr).await.unwrap(); client.write_all(&data).await.unwrap(); client // keep alive }); let (mut server_stream, _) = listener.accept().await.unwrap(); let result = read_proxy_header(&mut server_stream).await; let _client = client_task.await.unwrap(); result } #[tokio::test] async fn parse_valid_ipv4_header() { let src = "203.0.113.50:12345".parse::().unwrap(); let dst = "10.0.0.1:443".parse::().unwrap(); let header = build_pp_v2_header(src, dst); let parsed = parse_header_from_bytes(&header).await.unwrap(); assert!(!parsed.is_local); assert_eq!(parsed.src_addr, src); assert_eq!(parsed.dst_addr, dst); } #[tokio::test] async fn parse_valid_ipv6_header() { let src = "[2001:db8::1]:54321".parse::().unwrap(); let dst = "[2001:db8::2]:443".parse::().unwrap(); let header = build_pp_v2_header(src, dst); let parsed = parse_header_from_bytes(&header).await.unwrap(); assert!(!parsed.is_local); assert_eq!(parsed.src_addr, src); assert_eq!(parsed.dst_addr, dst); } #[tokio::test] async fn parse_local_command() { let header = build_pp_v2_local(); let parsed = parse_header_from_bytes(&header).await.unwrap(); assert!(parsed.is_local); } #[tokio::test] async fn reject_invalid_signature() { let mut header = build_pp_v2_local(); header[0] = 0xFF; // corrupt signature let result = parse_header_from_bytes(&header).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("signature")); } #[tokio::test] async fn reject_wrong_version() { let mut header = build_pp_v2_local(); header[12] = 0x10; // version 1 instead of 2 let result = parse_header_from_bytes(&header).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("version")); } #[tokio::test] async fn reject_truncated_header() { // Only 10 bytes — not even the full signature let result = parse_header_from_bytes(&[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49]).await; assert!(result.is_err()); } #[tokio::test] async fn ipv4_header_is_exactly_28_bytes() { let src = "1.2.3.4:80".parse::().unwrap(); let dst = "5.6.7.8:443".parse::().unwrap(); let header = build_pp_v2_header(src, dst); // 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 12 addrs = 28 assert_eq!(header.len(), 28); } #[tokio::test] async fn ipv6_header_is_exactly_52_bytes() { let src = "[::1]:80".parse::().unwrap(); let dst = "[::2]:443".parse::().unwrap(); let header = build_pp_v2_header(src, dst); // 12 sig + 1 ver/cmd + 1 fam/proto + 2 len + 36 addrs = 52 assert_eq!(header.len(), 52); } }