use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

use thiserror::Error;

/// Errors returned by the PROXY protocol parsers in this module.
#[derive(Debug, Error)]
pub enum ProxyProtocolError {
    /// The input is not a well-formed PROXY header (bad signature, missing
    /// CRLF, wrong prefix, wrong number of fields, non-UTF-8 text, ...).
    #[error("Invalid PROXY protocol header")]
    InvalidHeader,
    /// The header announces a version (v2) or protocol token (v1) that this
    /// parser does not handle.
    #[error("Unsupported PROXY protocol version")]
    UnsupportedVersion,
    /// A specific field could not be decoded; the message names the field.
    #[error("Parse error: {0}")]
    Parse(String),
    /// More input is required: `.0` is the number of bytes needed, `.1` the
    /// number of bytes currently available.
    #[error("Incomplete header: need {0} bytes, got {1}")]
    Incomplete(usize, usize),
}

/// Parsed PROXY protocol header (v1 or v2).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ProxyProtocolHeader {
    /// Original client endpoint as announced by the proxy.
    pub source_addr: SocketAddr,
    /// Original destination endpoint as announced by the proxy.
    pub dest_addr: SocketAddr,
    /// Transport/family announced in the header.
    pub protocol: ProxyProtocol,
}

/// Protocol in PROXY header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProxyProtocol {
    Tcp4,
    Tcp6,
    Udp4,
    Udp6,
    /// No usable transport was announced; the address fields may be
    /// placeholders (`0.0.0.0:0`) and should not be trusted.
    Unknown,
}

/// Transport type for PROXY v2 header generation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProxyV2Transport {
    /// Stream transport (TCP).
    Stream,
    /// Datagram transport (UDP).
    Datagram,
}

/// PROXY protocol v2 signature (12 bytes): `"\r\n\r\n\0\r\nQUIT\n"`.
const PROXY_V2_SIGNATURE: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

// ===== v1 (text format) =====

/// Parse a PROXY protocol v1 header from data.
///
/// Format: `PROXY <TCP4|TCP6> <src_ip> <dst_ip> <src_port> <dst_port>\r\n`,
/// or `PROXY UNKNOWN\r\n`, optionally with arbitrary text before the CRLF.
///
/// Returns the parsed header together with the number of bytes consumed (the
/// line plus its CRLF), so `&data[consumed..]` is where application data starts.
///
/// The PROXY protocol specification limits the whole line, CRLF included, to
/// 107 bytes, so the terminator is only searched for within that window.
///
/// For `UNKNOWN` the specification requires the rest of the line to be
/// ignored. The returned header then carries unspecified addresses
/// (`0.0.0.0:0`) and [`ProxyProtocol::Unknown`]. Per the spec, the receiver
/// should use the real connection endpoints in that case.
///
/// # Errors
///
/// - [`ProxyProtocolError::InvalidHeader`] if no CRLF appears within the first
///   107 bytes, the line is not UTF-8, it does not start with `PROXY `, or a
///   non-`UNKNOWN` line does not have exactly six space-separated fields.
/// - [`ProxyProtocolError::UnsupportedVersion`] if a six-field line's protocol
///   token is not `TCP4` or `TCP6`.
/// - [`ProxyProtocolError::Parse`] if an address or port is malformed, or an
///   address's family does not match the announced protocol.
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
    // Longest line the spec allows (CRLF included); never scan beyond it.
    const MAX_V1_LINE: usize = 107;

    let window = &data[..data.len().min(MAX_V1_LINE)];
    let line_end = window
        .windows(2)
        .position(|w| w == b"\r\n")
        .ok_or(ProxyProtocolError::InvalidHeader)?;
    let consumed = line_end + 2;

    let line = std::str::from_utf8(&data[..line_end])
        .map_err(|_| ProxyProtocolError::InvalidHeader)?;
    if !line.starts_with("PROXY ") {
        return Err(ProxyProtocolError::InvalidHeader);
    }

    let parts: Vec<&str> = line.split(' ').collect();

    // `PROXY UNKNOWN ...`: the sender speaks the protocol but conveys no
    // usable addresses; everything after the token must be ignored.
    if parts.get(1) == Some(&"UNKNOWN") {
        let unspecified = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
        let header = ProxyProtocolHeader {
            source_addr: unspecified,
            dest_addr: unspecified,
            protocol: ProxyProtocol::Unknown,
        };
        return Ok((header, consumed));
    }

    if parts.len() != 6 {
        return Err(ProxyProtocolError::InvalidHeader);
    }

    let protocol = match parts[1] {
        "TCP4" => ProxyProtocol::Tcp4,
        "TCP6" => ProxyProtocol::Tcp6,
        _ => return Err(ProxyProtocolError::UnsupportedVersion),
    };
    let expect_ipv4 = protocol == ProxyProtocol::Tcp4;

    let src_ip: IpAddr = parts[2]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
    let dst_ip: IpAddr = parts[3]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;

    // TCP4 must carry IPv4 addresses and TCP6 must carry IPv6 addresses.
    if src_ip.is_ipv4() != expect_ipv4 {
        return Err(ProxyProtocolError::Parse(
            "Source IP does not match protocol family".to_string(),
        ));
    }
    if dst_ip.is_ipv4() != expect_ipv4 {
        return Err(ProxyProtocolError::Parse(
            "Destination IP does not match protocol family".to_string(),
        ));
    }

    let src_port: u16 = parts[4]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?;
    let dst_port: u16 = parts[5]
        .parse()
        .map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?;

    let header = ProxyProtocolHeader {
        source_addr: SocketAddr::new(src_ip, src_port),
        dest_addr: SocketAddr::new(dst_ip, dst_port),
        protocol,
    };
    Ok((header, consumed))
}

/// Generate a CRLF-terminated PROXY protocol v1 header line.
///
/// `TCP4` is used only when both endpoints are IPv4; otherwise `TCP6` is used.
/// When the families differ, the IPv4 endpoint is written as an IPv4-mapped
/// IPv6 address (`::ffff:a.b.c.d`) so that both addresses match the announced
/// protocol, the same mapping `generate_v2` applies.
pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String {
    match (source.ip(), dest.ip()) {
        (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => format!(
            "PROXY TCP4 {} {} {} {}\r\n",
            src_ip,
            dst_ip,
            source.port(),
            dest.port()
        ),
        (src_ip, dst_ip) => {
            let as_v6 = |ip: IpAddr| match ip {
                IpAddr::V4(v4) => v4.to_ipv6_mapped(),
                IpAddr::V6(v6) => v6,
            };
            format!(
                "PROXY TCP6 {} {} {} {}\r\n",
                as_v6(src_ip),
                as_v6(dst_ip),
                source.port(),
                dest.port()
            )
        }
    }
}

/// Check if data starts with a PROXY protocol v1 header.
///
/// Only the `PROXY ` prefix is checked; use [`parse_v1`] to validate the line.
pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
    data.starts_with(b"PROXY ")
}

// ===== v2 (binary format) =====

/// Check if data starts with a PROXY protocol v2 header.
pub fn is_proxy_protocol_v2(data: &[u8]) -> bool { data.len() >= 12 && data[..12] == PROXY_V2_SIGNATURE } /// Parse a PROXY protocol v2 binary header. /// /// Binary format: /// - [0..12] signature (12 bytes) /// - [12] version (high nibble) + command (low nibble) /// - [13] address family (high nibble) + transport (low nibble) /// - [14..16] address block length (big-endian u16) /// - [16..] address block (variable, depends on family) pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> { if data.len() < 16 { return Err(ProxyProtocolError::Incomplete(16, data.len())); } // Validate signature if data[..12] != PROXY_V2_SIGNATURE { return Err(ProxyProtocolError::InvalidHeader); } // Version (high nibble of byte 12) must be 0x2 let version = (data[12] >> 4) & 0x0F; if version != 2 { return Err(ProxyProtocolError::UnsupportedVersion); } // Command (low nibble of byte 12) let command = data[12] & 0x0F; // 0x0 = LOCAL, 0x1 = PROXY if command > 1 { return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command))); } // Address family (high nibble) + transport (low nibble) of byte 13 let family = (data[13] >> 4) & 0x0F; let transport = data[13] & 0x0F; // Address block length let addr_len = u16::from_be_bytes([data[14], data[15]]) as usize; let total_len = 16 + addr_len; if data.len() < total_len { return Err(ProxyProtocolError::Incomplete(total_len, data.len())); } // LOCAL command: no real addresses, return unspecified if command == 0 { return Ok(( ProxyProtocolHeader { source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), dest_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), protocol: ProxyProtocol::Unknown, }, total_len, )); } // PROXY command: parse addresses based on family + transport let addr_block = &data[16..16 + addr_len]; match (family, transport) { // AF_INET (0x1) + STREAM (0x1) = TCP4 (0x1, 0x1) => { if addr_len < 12 { return Err(ProxyProtocolError::Parse("IPv4 address block too 
short".to_string())); } let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]); let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]); let src_port = u16::from_be_bytes([addr_block[8], addr_block[9]]); let dst_port = u16::from_be_bytes([addr_block[10], addr_block[11]]); Ok(( ProxyProtocolHeader { source_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port), dest_addr: SocketAddr::new(IpAddr::V4(dst_ip), dst_port), protocol: ProxyProtocol::Tcp4, }, total_len, )) } // AF_INET (0x1) + DGRAM (0x2) = UDP4 (0x1, 0x2) => { if addr_len < 12 { return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string())); } let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]); let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]); let src_port = u16::from_be_bytes([addr_block[8], addr_block[9]]); let dst_port = u16::from_be_bytes([addr_block[10], addr_block[11]]); Ok(( ProxyProtocolHeader { source_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port), dest_addr: SocketAddr::new(IpAddr::V4(dst_ip), dst_port), protocol: ProxyProtocol::Udp4, }, total_len, )) } // AF_INET6 (0x2) + STREAM (0x1) = TCP6 (0x2, 0x1) => { if addr_len < 36 { return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string())); } let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap()); let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap()); let src_port = u16::from_be_bytes([addr_block[32], addr_block[33]]); let dst_port = u16::from_be_bytes([addr_block[34], addr_block[35]]); Ok(( ProxyProtocolHeader { source_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port), dest_addr: SocketAddr::new(IpAddr::V6(dst_ip), dst_port), protocol: ProxyProtocol::Tcp6, }, total_len, )) } // AF_INET6 (0x2) + DGRAM (0x2) = UDP6 (0x2, 0x2) => { if addr_len < 36 { return Err(ProxyProtocolError::Parse("IPv6 address block too 
short".to_string())); } let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap()); let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap()); let src_port = u16::from_be_bytes([addr_block[32], addr_block[33]]); let dst_port = u16::from_be_bytes([addr_block[34], addr_block[35]]); Ok(( ProxyProtocolHeader { source_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port), dest_addr: SocketAddr::new(IpAddr::V6(dst_ip), dst_port), protocol: ProxyProtocol::Udp6, }, total_len, )) } // AF_UNSPEC or unknown (0x0, _) => Ok(( ProxyProtocolHeader { source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), dest_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), protocol: ProxyProtocol::Unknown, }, total_len, )), _ => Err(ProxyProtocolError::Parse(format!( "Unsupported family/transport: 0x{:X}{:X}", family, transport ))), } } /// Generate a PROXY protocol v2 binary header. pub fn generate_v2( source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport, ) -> Vec { let transport_nibble: u8 = match transport { ProxyV2Transport::Stream => 0x1, ProxyV2Transport::Datagram => 0x2, }; match (source.ip(), dest.ip()) { (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => { let mut buf = Vec::with_capacity(28); buf.extend_from_slice(&PROXY_V2_SIGNATURE); buf.push(0x21); // version 2, PROXY command buf.push(0x10 | transport_nibble); // AF_INET + transport buf.extend_from_slice(&12u16.to_be_bytes()); // addr block length buf.extend_from_slice(&src_ip.octets()); buf.extend_from_slice(&dst_ip.octets()); buf.extend_from_slice(&source.port().to_be_bytes()); buf.extend_from_slice(&dest.port().to_be_bytes()); buf } (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => { let mut buf = Vec::with_capacity(52); buf.extend_from_slice(&PROXY_V2_SIGNATURE); buf.push(0x21); // version 2, PROXY command buf.push(0x20 | transport_nibble); // AF_INET6 + transport buf.extend_from_slice(&36u16.to_be_bytes()); // addr block length 
buf.extend_from_slice(&src_ip.octets()); buf.extend_from_slice(&dst_ip.octets()); buf.extend_from_slice(&source.port().to_be_bytes()); buf.extend_from_slice(&dest.port().to_be_bytes()); buf } // Mixed IPv4/IPv6: map IPv4 to IPv6-mapped address _ => { let src_v6 = match source.ip() { IpAddr::V4(v4) => v4.to_ipv6_mapped(), IpAddr::V6(v6) => v6, }; let dst_v6 = match dest.ip() { IpAddr::V4(v4) => v4.to_ipv6_mapped(), IpAddr::V6(v6) => v6, }; let src6 = SocketAddr::new(IpAddr::V6(src_v6), source.port()); let dst6 = SocketAddr::new(IpAddr::V6(dst_v6), dest.port()); generate_v2(&src6, &dst6, transport) } } } #[cfg(test)] mod tests { use super::*; // ===== v1 tests ===== #[test] fn test_parse_v1_tcp4() { let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n"; let (parsed, consumed) = parse_v1(header).unwrap(); assert_eq!(consumed, header.len()); assert_eq!(parsed.protocol, ProxyProtocol::Tcp4); assert_eq!(parsed.source_addr.ip().to_string(), "192.168.1.100"); assert_eq!(parsed.source_addr.port(), 12345); assert_eq!(parsed.dest_addr.ip().to_string(), "10.0.0.1"); assert_eq!(parsed.dest_addr.port(), 443); } #[test] fn test_generate_v1() { let source: SocketAddr = "192.168.1.100:12345".parse().unwrap(); let dest: SocketAddr = "10.0.0.1:443".parse().unwrap(); let header = generate_v1(&source, &dest); assert_eq!(header, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n"); } #[test] fn test_is_proxy_protocol() { assert!(is_proxy_protocol_v1(b"PROXY TCP4 ...")); assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1")); } // ===== v2 tests ===== #[test] fn test_is_proxy_protocol_v2() { assert!(is_proxy_protocol_v2(&PROXY_V2_SIGNATURE)); assert!(!is_proxy_protocol_v2(b"PROXY TCP4 ...")); assert!(!is_proxy_protocol_v2(b"short")); } #[test] fn test_parse_v2_tcp4() { let source: SocketAddr = "198.51.100.10:54321".parse().unwrap(); let dest: SocketAddr = "203.0.113.25:8443".parse().unwrap(); let header = generate_v2(&source, &dest, ProxyV2Transport::Stream); assert_eq!(header.len(), 
28); let (parsed, consumed) = parse_v2(&header).unwrap(); assert_eq!(consumed, 28); assert_eq!(parsed.protocol, ProxyProtocol::Tcp4); assert_eq!(parsed.source_addr, source); assert_eq!(parsed.dest_addr, dest); } #[test] fn test_parse_v2_udp4() { let source: SocketAddr = "10.0.0.1:12345".parse().unwrap(); let dest: SocketAddr = "10.0.0.2:53".parse().unwrap(); let header = generate_v2(&source, &dest, ProxyV2Transport::Datagram); assert_eq!(header.len(), 28); assert_eq!(header[13], 0x12); // AF_INET + DGRAM let (parsed, consumed) = parse_v2(&header).unwrap(); assert_eq!(consumed, 28); assert_eq!(parsed.protocol, ProxyProtocol::Udp4); assert_eq!(parsed.source_addr, source); assert_eq!(parsed.dest_addr, dest); } #[test] fn test_parse_v2_tcp6() { let source: SocketAddr = "[2001:db8::1]:54321".parse().unwrap(); let dest: SocketAddr = "[2001:db8::2]:443".parse().unwrap(); let header = generate_v2(&source, &dest, ProxyV2Transport::Stream); assert_eq!(header.len(), 52); assert_eq!(header[13], 0x21); // AF_INET6 + STREAM let (parsed, consumed) = parse_v2(&header).unwrap(); assert_eq!(consumed, 52); assert_eq!(parsed.protocol, ProxyProtocol::Tcp6); assert_eq!(parsed.source_addr, source); assert_eq!(parsed.dest_addr, dest); } #[test] fn test_generate_v2_tcp4_byte_layout() { let source: SocketAddr = "1.2.3.4:1000".parse().unwrap(); let dest: SocketAddr = "5.6.7.8:443".parse().unwrap(); let header = generate_v2(&source, &dest, ProxyV2Transport::Stream); assert_eq!(&header[0..12], &PROXY_V2_SIGNATURE); assert_eq!(header[12], 0x21); // v2, PROXY assert_eq!(header[13], 0x11); // AF_INET, STREAM assert_eq!(u16::from_be_bytes([header[14], header[15]]), 12); // addr len assert_eq!(&header[16..20], &[1, 2, 3, 4]); // src ip assert_eq!(&header[20..24], &[5, 6, 7, 8]); // dst ip assert_eq!(u16::from_be_bytes([header[24], header[25]]), 1000); // src port assert_eq!(u16::from_be_bytes([header[26], header[27]]), 443); // dst port } #[test] fn test_generate_v2_udp4_byte_layout() { let source: 
SocketAddr = "10.0.0.1:5000".parse().unwrap(); let dest: SocketAddr = "10.0.0.2:53".parse().unwrap(); let header = generate_v2(&source, &dest, ProxyV2Transport::Datagram); assert_eq!(header[12], 0x21); // v2, PROXY assert_eq!(header[13], 0x12); // AF_INET, DGRAM (UDP) } #[test] fn test_parse_v2_local_command() { // Build a LOCAL command header (no addresses) let mut header = Vec::new(); header.extend_from_slice(&PROXY_V2_SIGNATURE); header.push(0x20); // v2, LOCAL header.push(0x00); // AF_UNSPEC header.extend_from_slice(&0u16.to_be_bytes()); // 0-length address block let (parsed, consumed) = parse_v2(&header).unwrap(); assert_eq!(consumed, 16); assert_eq!(parsed.protocol, ProxyProtocol::Unknown); assert_eq!(parsed.source_addr.port(), 0); } #[test] fn test_parse_v2_incomplete() { let data = &PROXY_V2_SIGNATURE[..8]; // only 8 bytes assert!(parse_v2(data).is_err()); } #[test] fn test_parse_v2_wrong_version() { let mut header = Vec::new(); header.extend_from_slice(&PROXY_V2_SIGNATURE); header.push(0x11); // version 1, not 2 header.push(0x11); header.extend_from_slice(&12u16.to_be_bytes()); header.extend_from_slice(&[0u8; 12]); assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion))); } #[test] fn test_v2_roundtrip_with_trailing_data() { let source: SocketAddr = "192.168.1.1:8080".parse().unwrap(); let dest: SocketAddr = "10.0.0.1:443".parse().unwrap(); let mut data = generate_v2(&source, &dest, ProxyV2Transport::Stream); data.extend_from_slice(b"GET / HTTP/1.1\r\n"); // trailing app data let (parsed, consumed) = parse_v2(&data).unwrap(); assert_eq!(consumed, 28); assert_eq!(parsed.source_addr, source); assert_eq!(&data[consumed..], b"GET / HTTP/1.1\r\n"); } }