use std::net::SocketAddr; use thiserror::Error; #[derive(Debug, Error)] pub enum ProxyProtocolError { #[error("Invalid PROXY protocol header")] InvalidHeader, #[error("Unsupported PROXY protocol version")] UnsupportedVersion, #[error("Parse error: {0}")] Parse(String), } /// Parsed PROXY protocol v1 header. #[derive(Debug, Clone)] pub struct ProxyProtocolHeader { pub source_addr: SocketAddr, pub dest_addr: SocketAddr, pub protocol: ProxyProtocol, } /// Protocol in PROXY header. #[derive(Debug, Clone, PartialEq, Eq)] pub enum ProxyProtocol { Tcp4, Tcp6, Unknown, } /// Parse a PROXY protocol v1 header from data. /// /// Format: `PROXY TCP4 \r\n` pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> { // Find the end of the header line let line_end = data .windows(2) .position(|w| w == b"\r\n") .ok_or(ProxyProtocolError::InvalidHeader)?; let line = std::str::from_utf8(&data[..line_end]) .map_err(|_| ProxyProtocolError::InvalidHeader)?; if !line.starts_with("PROXY ") { return Err(ProxyProtocolError::InvalidHeader); } let parts: Vec<&str> = line.split(' ').collect(); if parts.len() != 6 { return Err(ProxyProtocolError::InvalidHeader); } let protocol = match parts[1] { "TCP4" => ProxyProtocol::Tcp4, "TCP6" => ProxyProtocol::Tcp6, "UNKNOWN" => ProxyProtocol::Unknown, _ => return Err(ProxyProtocolError::UnsupportedVersion), }; let src_ip: std::net::IpAddr = parts[2] .parse() .map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?; let dst_ip: std::net::IpAddr = parts[3] .parse() .map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?; let src_port: u16 = parts[4] .parse() .map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?; let dst_port: u16 = parts[5] .parse() .map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?; let header = ProxyProtocolHeader { source_addr: SocketAddr::new(src_ip, src_port), dest_addr: SocketAddr::new(dst_ip, dst_port), protocol, }; // Consumed bytes = line + \r\n Ok((header, line_end + 2)) } /// Generate a PROXY protocol v1 header string. pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String { let proto = if source.is_ipv4() { "TCP4" } else { "TCP6" }; format!( "PROXY {} {} {} {} {}\r\n", proto, source.ip(), dest.ip(), source.port(), dest.port() ) } /// Check if data starts with a PROXY protocol v1 header. pub fn is_proxy_protocol_v1(data: &[u8]) -> bool { data.starts_with(b"PROXY ") } #[cfg(test)] mod tests { use super::*; #[test] fn test_parse_v1_tcp4() { let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n"; let (parsed, consumed) = parse_v1(header).unwrap(); assert_eq!(consumed, header.len()); assert_eq!(parsed.protocol, ProxyProtocol::Tcp4); assert_eq!(parsed.source_addr.ip().to_string(), "192.168.1.100"); assert_eq!(parsed.source_addr.port(), 12345); assert_eq!(parsed.dest_addr.ip().to_string(), "10.0.0.1"); assert_eq!(parsed.dest_addr.port(), 443); } #[test] fn test_generate_v1() { let source: SocketAddr = "192.168.1.100:12345".parse().unwrap(); let dest: SocketAddr = "10.0.0.1:443".parse().unwrap(); let header = generate_v1(&source, &dest); assert_eq!(header, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n"); } #[test] fn test_is_proxy_protocol() { assert!(is_proxy_protocol_v1(b"PROXY TCP4 ...")); assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1")); } }