481 lines
17 KiB
Rust
481 lines
17 KiB
Rust
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
|
use thiserror::Error;
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum ProxyProtocolError {
|
|
#[error("Invalid PROXY protocol header")]
|
|
InvalidHeader,
|
|
#[error("Unsupported PROXY protocol version")]
|
|
UnsupportedVersion,
|
|
#[error("Parse error: {0}")]
|
|
Parse(String),
|
|
#[error("Incomplete header: need {0} bytes, got {1}")]
|
|
Incomplete(usize, usize),
|
|
}
|
|
|
|
/// Parsed PROXY protocol header (v1 or v2).
|
|
#[derive(Debug, Clone)]
|
|
pub struct ProxyProtocolHeader {
|
|
pub source_addr: SocketAddr,
|
|
pub dest_addr: SocketAddr,
|
|
pub protocol: ProxyProtocol,
|
|
}
|
|
|
|
/// Family/transport combination announced by a PROXY header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocol {
    /// TCP over IPv4: v1 `TCP4`, or v2 AF_INET + STREAM.
    Tcp4,
    /// TCP over IPv6: v1 `TCP6`, or v2 AF_INET6 + STREAM.
    Tcp6,
    /// UDP over IPv4: v2 AF_INET + DGRAM (only produced by the v2 parser).
    Udp4,
    /// UDP over IPv6: v2 AF_INET6 + DGRAM (only produced by the v2 parser).
    Udp6,
    /// No usable endpoint family: v1 `UNKNOWN`, or a v2 LOCAL command / AF_UNSPEC family.
    Unknown,
}
|
|
|
|
/// Transport selector encoded in the low nibble of byte 13 by the v2 header generator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyV2Transport {
    /// Stream transport (TCP), encoded as `0x1`.
    Stream,
    /// Datagram transport (UDP), encoded as `0x2`.
    Datagram,
}
|
|
|
|
/// Fixed 12-byte magic that opens every PROXY protocol v2 header:
/// `\r\n\r\n\0\r\nQUIT\n` (0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A).
const PROXY_V2_SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
|
|
|
|
// ===== v1 (text format) =====
|
|
|
|
/// Parse a PROXY protocol v1 header from data.
|
|
///
|
|
/// Format: `PROXY TCP4 <src_ip> <dst_ip> <src_port> <dst_port>\r\n`
|
|
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
|
|
let line_end = data
|
|
.windows(2)
|
|
.position(|w| w == b"\r\n")
|
|
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
|
|
|
let line = std::str::from_utf8(&data[..line_end])
|
|
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
|
|
|
if !line.starts_with("PROXY ") {
|
|
return Err(ProxyProtocolError::InvalidHeader);
|
|
}
|
|
|
|
let parts: Vec<&str> = line.split(' ').collect();
|
|
if parts.len() != 6 {
|
|
return Err(ProxyProtocolError::InvalidHeader);
|
|
}
|
|
|
|
let protocol = match parts[1] {
|
|
"TCP4" => ProxyProtocol::Tcp4,
|
|
"TCP6" => ProxyProtocol::Tcp6,
|
|
"UNKNOWN" => ProxyProtocol::Unknown,
|
|
_ => return Err(ProxyProtocolError::UnsupportedVersion),
|
|
};
|
|
|
|
let src_ip: IpAddr = parts[2]
|
|
.parse()
|
|
.map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
|
|
let dst_ip: IpAddr = parts[3]
|
|
.parse()
|
|
.map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;
|
|
let src_port: u16 = parts[4]
|
|
.parse()
|
|
.map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?;
|
|
let dst_port: u16 = parts[5]
|
|
.parse()
|
|
.map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?;
|
|
|
|
let header = ProxyProtocolHeader {
|
|
source_addr: SocketAddr::new(src_ip, src_port),
|
|
dest_addr: SocketAddr::new(dst_ip, dst_port),
|
|
protocol,
|
|
};
|
|
|
|
Ok((header, line_end + 2))
|
|
}
|
|
|
|
/// Generate a PROXY protocol v1 header line, terminated by `\r\n`.
///
/// Emits `TCP4` only when both endpoints are IPv4. Otherwise it emits `TCP6` and writes both
/// addresses as IPv6, converting any IPv4 endpoint to its IPv4-mapped form
/// (`::ffff:a.b.c.d`). The v1 format requires both addresses to belong to the family named
/// by the protocol token, so a mixed pair cannot be written as-is.
pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String {
    let to_v6 = |ip: IpAddr| match ip {
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
        IpAddr::V6(v6) => v6,
    };

    let (proto, src_ip, dst_ip) = match (source.ip(), dest.ip()) {
        (src @ IpAddr::V4(_), dst @ IpAddr::V4(_)) => ("TCP4", src, dst),
        (src, dst) => ("TCP6", IpAddr::V6(to_v6(src)), IpAddr::V6(to_v6(dst))),
    };

    format!(
        "PROXY {} {} {} {} {}\r\n",
        proto,
        src_ip,
        dst_ip,
        source.port(),
        dest.port()
    )
}
|
|
|
|
/// Return `true` if `data` begins with the PROXY v1 text prefix `"PROXY "`.
///
/// Only the 6-byte prefix is inspected; the rest of the line is not validated.
pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
    const V1_PREFIX: &[u8] = b"PROXY ";
    data.get(..V1_PREFIX.len()) == Some(V1_PREFIX)
}
|
|
|
|
// ===== v2 (binary format) =====
|
|
|
|
/// Check if data starts with a PROXY protocol v2 header.
|
|
pub fn is_proxy_protocol_v2(data: &[u8]) -> bool {
|
|
data.len() >= 12 && data[..12] == PROXY_V2_SIGNATURE
|
|
}
|
|
|
|
/// Parse a PROXY protocol v2 binary header.
|
|
///
|
|
/// Binary format:
|
|
/// - [0..12] signature (12 bytes)
|
|
/// - [12] version (high nibble) + command (low nibble)
|
|
/// - [13] address family (high nibble) + transport (low nibble)
|
|
/// - [14..16] address block length (big-endian u16)
|
|
/// - [16..] address block (variable, depends on family)
|
|
pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
|
|
if data.len() < 16 {
|
|
return Err(ProxyProtocolError::Incomplete(16, data.len()));
|
|
}
|
|
|
|
// Validate signature
|
|
if data[..12] != PROXY_V2_SIGNATURE {
|
|
return Err(ProxyProtocolError::InvalidHeader);
|
|
}
|
|
|
|
// Version (high nibble of byte 12) must be 0x2
|
|
let version = (data[12] >> 4) & 0x0F;
|
|
if version != 2 {
|
|
return Err(ProxyProtocolError::UnsupportedVersion);
|
|
}
|
|
|
|
// Command (low nibble of byte 12)
|
|
let command = data[12] & 0x0F;
|
|
// 0x0 = LOCAL, 0x1 = PROXY
|
|
if command > 1 {
|
|
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command)));
|
|
}
|
|
|
|
// Address family (high nibble) + transport (low nibble) of byte 13
|
|
let family = (data[13] >> 4) & 0x0F;
|
|
let transport = data[13] & 0x0F;
|
|
|
|
// Address block length
|
|
let addr_len = u16::from_be_bytes([data[14], data[15]]) as usize;
|
|
let total_len = 16 + addr_len;
|
|
|
|
if data.len() < total_len {
|
|
return Err(ProxyProtocolError::Incomplete(total_len, data.len()));
|
|
}
|
|
|
|
// LOCAL command: no real addresses, return unspecified
|
|
if command == 0 {
|
|
return Ok((
|
|
ProxyProtocolHeader {
|
|
source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
|
dest_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
|
protocol: ProxyProtocol::Unknown,
|
|
},
|
|
total_len,
|
|
));
|
|
}
|
|
|
|
// PROXY command: parse addresses based on family + transport
|
|
let addr_block = &data[16..16 + addr_len];
|
|
|
|
match (family, transport) {
|
|
// AF_INET (0x1) + STREAM (0x1) = TCP4
|
|
(0x1, 0x1) => {
|
|
if addr_len < 12 {
|
|
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
|
}
|
|
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
|
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
|
let src_port = u16::from_be_bytes([addr_block[8], addr_block[9]]);
|
|
let dst_port = u16::from_be_bytes([addr_block[10], addr_block[11]]);
|
|
Ok((
|
|
ProxyProtocolHeader {
|
|
source_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
|
|
dest_addr: SocketAddr::new(IpAddr::V4(dst_ip), dst_port),
|
|
protocol: ProxyProtocol::Tcp4,
|
|
},
|
|
total_len,
|
|
))
|
|
}
|
|
// AF_INET (0x1) + DGRAM (0x2) = UDP4
|
|
(0x1, 0x2) => {
|
|
if addr_len < 12 {
|
|
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
|
}
|
|
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
|
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
|
let src_port = u16::from_be_bytes([addr_block[8], addr_block[9]]);
|
|
let dst_port = u16::from_be_bytes([addr_block[10], addr_block[11]]);
|
|
Ok((
|
|
ProxyProtocolHeader {
|
|
source_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
|
|
dest_addr: SocketAddr::new(IpAddr::V4(dst_ip), dst_port),
|
|
protocol: ProxyProtocol::Udp4,
|
|
},
|
|
total_len,
|
|
))
|
|
}
|
|
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
|
|
(0x2, 0x1) => {
|
|
if addr_len < 36 {
|
|
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
|
}
|
|
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
|
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
|
let src_port = u16::from_be_bytes([addr_block[32], addr_block[33]]);
|
|
let dst_port = u16::from_be_bytes([addr_block[34], addr_block[35]]);
|
|
Ok((
|
|
ProxyProtocolHeader {
|
|
source_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
|
|
dest_addr: SocketAddr::new(IpAddr::V6(dst_ip), dst_port),
|
|
protocol: ProxyProtocol::Tcp6,
|
|
},
|
|
total_len,
|
|
))
|
|
}
|
|
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
|
|
(0x2, 0x2) => {
|
|
if addr_len < 36 {
|
|
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
|
}
|
|
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
|
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
|
let src_port = u16::from_be_bytes([addr_block[32], addr_block[33]]);
|
|
let dst_port = u16::from_be_bytes([addr_block[34], addr_block[35]]);
|
|
Ok((
|
|
ProxyProtocolHeader {
|
|
source_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
|
|
dest_addr: SocketAddr::new(IpAddr::V6(dst_ip), dst_port),
|
|
protocol: ProxyProtocol::Udp6,
|
|
},
|
|
total_len,
|
|
))
|
|
}
|
|
// AF_UNSPEC or unknown
|
|
(0x0, _) => Ok((
|
|
ProxyProtocolHeader {
|
|
source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
|
dest_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
|
protocol: ProxyProtocol::Unknown,
|
|
},
|
|
total_len,
|
|
)),
|
|
_ => Err(ProxyProtocolError::Parse(format!(
|
|
"Unsupported family/transport: 0x{:X}{:X}",
|
|
family, transport
|
|
))),
|
|
}
|
|
}
|
|
|
|
/// Generate a PROXY protocol v2 binary header.
|
|
pub fn generate_v2(
|
|
source: &SocketAddr,
|
|
dest: &SocketAddr,
|
|
transport: ProxyV2Transport,
|
|
) -> Vec<u8> {
|
|
let transport_nibble: u8 = match transport {
|
|
ProxyV2Transport::Stream => 0x1,
|
|
ProxyV2Transport::Datagram => 0x2,
|
|
};
|
|
|
|
match (source.ip(), dest.ip()) {
|
|
(IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => {
|
|
let mut buf = Vec::with_capacity(28);
|
|
buf.extend_from_slice(&PROXY_V2_SIGNATURE);
|
|
buf.push(0x21); // version 2, PROXY command
|
|
buf.push(0x10 | transport_nibble); // AF_INET + transport
|
|
buf.extend_from_slice(&12u16.to_be_bytes()); // addr block length
|
|
buf.extend_from_slice(&src_ip.octets());
|
|
buf.extend_from_slice(&dst_ip.octets());
|
|
buf.extend_from_slice(&source.port().to_be_bytes());
|
|
buf.extend_from_slice(&dest.port().to_be_bytes());
|
|
buf
|
|
}
|
|
(IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => {
|
|
let mut buf = Vec::with_capacity(52);
|
|
buf.extend_from_slice(&PROXY_V2_SIGNATURE);
|
|
buf.push(0x21); // version 2, PROXY command
|
|
buf.push(0x20 | transport_nibble); // AF_INET6 + transport
|
|
buf.extend_from_slice(&36u16.to_be_bytes()); // addr block length
|
|
buf.extend_from_slice(&src_ip.octets());
|
|
buf.extend_from_slice(&dst_ip.octets());
|
|
buf.extend_from_slice(&source.port().to_be_bytes());
|
|
buf.extend_from_slice(&dest.port().to_be_bytes());
|
|
buf
|
|
}
|
|
// Mixed IPv4/IPv6: map IPv4 to IPv6-mapped address
|
|
_ => {
|
|
let src_v6 = match source.ip() {
|
|
IpAddr::V4(v4) => v4.to_ipv6_mapped(),
|
|
IpAddr::V6(v6) => v6,
|
|
};
|
|
let dst_v6 = match dest.ip() {
|
|
IpAddr::V4(v4) => v4.to_ipv6_mapped(),
|
|
IpAddr::V6(v6) => v6,
|
|
};
|
|
let src6 = SocketAddr::new(IpAddr::V6(src_v6), source.port());
|
|
let dst6 = SocketAddr::new(IpAddr::V6(dst_v6), dest.port());
|
|
generate_v2(&src6, &dst6, transport)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a socket-address literal used as a test fixture.
    fn sock(text: &str) -> SocketAddr {
        text.parse().expect("fixture address must parse")
    }

    /// Generate a v2 header for `src` -> `dst`, check its length, parse it back, and check
    /// that parsing consumed exactly that length and preserved protocol and endpoints.
    /// Returns the generated bytes for further checks.
    fn roundtrip_v2(
        src: &str,
        dst: &str,
        transport: ProxyV2Transport,
        expected_len: usize,
        expected_proto: ProxyProtocol,
    ) -> Vec<u8> {
        let (src, dst) = (sock(src), sock(dst));
        let bytes = generate_v2(&src, &dst, transport);
        assert_eq!(bytes.len(), expected_len);

        let (hdr, used) = parse_v2(&bytes).unwrap();
        assert_eq!(used, expected_len);
        assert_eq!(hdr.protocol, expected_proto);
        assert_eq!(hdr.source_addr, src);
        assert_eq!(hdr.dest_addr, dst);
        bytes
    }

    // ===== v1 tests =====

    /// A well-formed TCP4 line is fully consumed and every field is decoded.
    #[test]
    fn test_parse_v1_tcp4() {
        let line = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n";
        let (hdr, used) = parse_v1(line).unwrap();

        assert_eq!(used, line.len());
        assert_eq!(hdr.protocol, ProxyProtocol::Tcp4);
        assert_eq!(hdr.source_addr.ip().to_string(), "192.168.1.100");
        assert_eq!(hdr.source_addr.port(), 12345);
        assert_eq!(hdr.dest_addr.ip().to_string(), "10.0.0.1");
        assert_eq!(hdr.dest_addr.port(), 443);
    }

    /// An IPv4 pair is rendered as a TCP4 line.
    #[test]
    fn test_generate_v1() {
        let line = generate_v1(&sock("192.168.1.100:12345"), &sock("10.0.0.1:443"));
        assert_eq!(line, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n");
    }

    /// Detection keys on the "PROXY " prefix only.
    #[test]
    fn test_is_proxy_protocol() {
        assert!(is_proxy_protocol_v1(b"PROXY TCP4 ..."));
        assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1"));
    }

    // ===== v2 tests =====

    /// Detection requires the full binary signature.
    #[test]
    fn test_is_proxy_protocol_v2() {
        assert!(is_proxy_protocol_v2(&PROXY_V2_SIGNATURE));
        assert!(!is_proxy_protocol_v2(b"PROXY TCP4 ..."));
        assert!(!is_proxy_protocol_v2(b"short"));
    }

    #[test]
    fn test_parse_v2_tcp4() {
        roundtrip_v2(
            "198.51.100.10:54321",
            "203.0.113.25:8443",
            ProxyV2Transport::Stream,
            28,
            ProxyProtocol::Tcp4,
        );
    }

    #[test]
    fn test_parse_v2_udp4() {
        let bytes = roundtrip_v2(
            "10.0.0.1:12345",
            "10.0.0.2:53",
            ProxyV2Transport::Datagram,
            28,
            ProxyProtocol::Udp4,
        );
        assert_eq!(bytes[13], 0x12); // AF_INET + DGRAM
    }

    #[test]
    fn test_parse_v2_tcp6() {
        let bytes = roundtrip_v2(
            "[2001:db8::1]:54321",
            "[2001:db8::2]:443",
            ProxyV2Transport::Stream,
            52,
            ProxyProtocol::Tcp6,
        );
        assert_eq!(bytes[13], 0x21); // AF_INET6 + STREAM
    }

    /// Byte-for-byte layout of an IPv4 STREAM header.
    #[test]
    fn test_generate_v2_tcp4_byte_layout() {
        let bytes = generate_v2(&sock("1.2.3.4:1000"), &sock("5.6.7.8:443"), ProxyV2Transport::Stream);

        assert_eq!(&bytes[0..12], &PROXY_V2_SIGNATURE);
        assert_eq!(bytes[12], 0x21); // v2, PROXY
        assert_eq!(bytes[13], 0x11); // AF_INET, STREAM
        assert_eq!(u16::from_be_bytes([bytes[14], bytes[15]]), 12); // addr len
        assert_eq!(&bytes[16..20], &[1, 2, 3, 4]); // src ip
        assert_eq!(&bytes[20..24], &[5, 6, 7, 8]); // dst ip
        assert_eq!(u16::from_be_bytes([bytes[24], bytes[25]]), 1000); // src port
        assert_eq!(u16::from_be_bytes([bytes[26], bytes[27]]), 443); // dst port
    }

    /// Datagram transport sets the DGRAM nibble.
    #[test]
    fn test_generate_v2_udp4_byte_layout() {
        let bytes = generate_v2(&sock("10.0.0.1:5000"), &sock("10.0.0.2:53"), ProxyV2Transport::Datagram);

        assert_eq!(bytes[12], 0x21); // v2, PROXY
        assert_eq!(bytes[13], 0x12); // AF_INET, DGRAM (UDP)
    }

    /// A LOCAL command with an empty address block is accepted with unspecified addresses.
    #[test]
    fn test_parse_v2_local_command() {
        let mut header = PROXY_V2_SIGNATURE.to_vec();
        // v2 LOCAL, AF_UNSPEC, zero-length address block
        header.extend_from_slice(&[0x20, 0x00, 0x00, 0x00]);

        let (hdr, used) = parse_v2(&header).unwrap();
        assert_eq!(used, 16);
        assert_eq!(hdr.protocol, ProxyProtocol::Unknown);
        assert_eq!(hdr.source_addr.port(), 0);
    }

    /// A truncated signature is an error.
    #[test]
    fn test_parse_v2_incomplete() {
        assert!(parse_v2(&PROXY_V2_SIGNATURE[..8]).is_err());
    }

    /// A version nibble other than 2 is rejected.
    #[test]
    fn test_parse_v2_wrong_version() {
        let mut header = PROXY_V2_SIGNATURE.to_vec();
        header.extend_from_slice(&[0x11, 0x11]); // version 1, not 2
        header.extend_from_slice(&12u16.to_be_bytes());
        header.extend_from_slice(&[0u8; 12]);

        assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion)));
    }

    /// Bytes after the header are left untouched for the application.
    #[test]
    fn test_v2_roundtrip_with_trailing_data() {
        let (src, dst) = (sock("192.168.1.1:8080"), sock("10.0.0.1:443"));
        let app_data: &[u8] = b"GET / HTTP/1.1\r\n";

        let mut wire = generate_v2(&src, &dst, ProxyV2Transport::Stream);
        wire.extend_from_slice(app_data);

        let (hdr, used) = parse_v2(&wire).unwrap();
        assert_eq!(used, 28);
        assert_eq!(hdr.source_addr, src);
        assert_eq!(&wire[used..], app_data);
    }
}
|