130 lines
3.9 KiB
Rust
130 lines
3.9 KiB
Rust
use std::net::SocketAddr;
|
|
use thiserror::Error;
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum ProxyProtocolError {
|
|
#[error("Invalid PROXY protocol header")]
|
|
InvalidHeader,
|
|
#[error("Unsupported PROXY protocol version")]
|
|
UnsupportedVersion,
|
|
#[error("Parse error: {0}")]
|
|
Parse(String),
|
|
}
|
|
|
|
/// Parsed PROXY protocol v1 header.
|
|
#[derive(Debug, Clone)]
|
|
pub struct ProxyProtocolHeader {
|
|
pub source_addr: SocketAddr,
|
|
pub dest_addr: SocketAddr,
|
|
pub protocol: ProxyProtocol,
|
|
}
|
|
|
|
/// Protocol token carried in a PROXY v1 header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProxyProtocol {
    /// `TCP4`: TCP over IPv4.
    Tcp4,
    /// `TCP6`: TCP over IPv6.
    Tcp6,
    /// `UNKNOWN`: the sender did not describe the original connection.
    Unknown,
}
|
|
|
|
/// Parse a PROXY protocol v1 header from data.
|
|
///
|
|
/// Format: `PROXY TCP4 <src_ip> <dst_ip> <src_port> <dst_port>\r\n`
|
|
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
|
|
// Find the end of the header line
|
|
let line_end = data
|
|
.windows(2)
|
|
.position(|w| w == b"\r\n")
|
|
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
|
|
|
let line = std::str::from_utf8(&data[..line_end])
|
|
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
|
|
|
if !line.starts_with("PROXY ") {
|
|
return Err(ProxyProtocolError::InvalidHeader);
|
|
}
|
|
|
|
let parts: Vec<&str> = line.split(' ').collect();
|
|
if parts.len() != 6 {
|
|
return Err(ProxyProtocolError::InvalidHeader);
|
|
}
|
|
|
|
let protocol = match parts[1] {
|
|
"TCP4" => ProxyProtocol::Tcp4,
|
|
"TCP6" => ProxyProtocol::Tcp6,
|
|
"UNKNOWN" => ProxyProtocol::Unknown,
|
|
_ => return Err(ProxyProtocolError::UnsupportedVersion),
|
|
};
|
|
|
|
let src_ip: std::net::IpAddr = parts[2]
|
|
.parse()
|
|
.map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
|
|
let dst_ip: std::net::IpAddr = parts[3]
|
|
.parse()
|
|
.map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;
|
|
let src_port: u16 = parts[4]
|
|
.parse()
|
|
.map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?;
|
|
let dst_port: u16 = parts[5]
|
|
.parse()
|
|
.map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?;
|
|
|
|
let header = ProxyProtocolHeader {
|
|
source_addr: SocketAddr::new(src_ip, src_port),
|
|
dest_addr: SocketAddr::new(dst_ip, dst_port),
|
|
protocol,
|
|
};
|
|
|
|
// Consumed bytes = line + \r\n
|
|
Ok((header, line_end + 2))
|
|
}
|
|
|
|
/// Generate a PROXY protocol v1 header line, including the trailing `\r\n`.
///
/// Emits `TCP4` when both endpoints are IPv4, and `TCP6` otherwise. The v1 format
/// requires both addresses to match the declared family. So when the endpoints
/// have mixed families, the IPv4 one is written in its IPv4-mapped IPv6 form
/// (`::ffff:a.b.c.d`) instead of producing an invalid header.
pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String {
    use std::net::{IpAddr, Ipv6Addr};

    let as_ipv6 = |ip: IpAddr| -> Ipv6Addr {
        match ip {
            IpAddr::V4(v4) => v4.to_ipv6_mapped(),
            IpAddr::V6(v6) => v6,
        }
    };

    let (proto, src_ip, dst_ip) = match (source.ip(), dest.ip()) {
        (src @ IpAddr::V4(_), dst @ IpAddr::V4(_)) => ("TCP4", src, dst),
        (src, dst) => (
            "TCP6",
            IpAddr::V6(as_ipv6(src)),
            IpAddr::V6(as_ipv6(dst)),
        ),
    };

    format!(
        "PROXY {} {} {} {} {}\r\n",
        proto,
        src_ip,
        dst_ip,
        source.port(),
        dest.port()
    )
}
|
|
|
|
/// Returns `true` when `data` begins with the PROXY protocol v1 signature `"PROXY "`.
///
/// Only the six-byte prefix is inspected; the rest of the header is not validated.
pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
    const SIGNATURE: &[u8] = b"PROXY ";
    data.get(..SIGNATURE.len()) == Some(SIGNATURE)
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_v1_tcp4() {
        let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n";
        let (parsed, consumed) = parse_v1(header).unwrap();
        assert_eq!(consumed, header.len());
        assert_eq!(parsed.protocol, ProxyProtocol::Tcp4);
        assert_eq!(parsed.source_addr.ip().to_string(), "192.168.1.100");
        assert_eq!(parsed.source_addr.port(), 12345);
        assert_eq!(parsed.dest_addr.ip().to_string(), "10.0.0.1");
        assert_eq!(parsed.dest_addr.port(), 443);
    }

    #[test]
    fn test_parse_v1_tcp6_roundtrip() {
        let source: SocketAddr = "[2001:db8::1]:5000".parse().unwrap();
        let dest: SocketAddr = "[2001:db8::2]:80".parse().unwrap();
        let header = generate_v1(&source, &dest);
        let (parsed, consumed) = parse_v1(header.as_bytes()).unwrap();
        assert_eq!(consumed, header.len());
        assert_eq!(parsed.protocol, ProxyProtocol::Tcp6);
        assert_eq!(parsed.source_addr, source);
        assert_eq!(parsed.dest_addr, dest);
    }

    #[test]
    fn test_parse_v1_leaves_payload_unconsumed() {
        let data = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\nGET / HTTP/1.1\r\n";
        let (_, consumed) = parse_v1(data).unwrap();
        assert_eq!(&data[consumed..], b"GET / HTTP/1.1\r\n");
    }

    #[test]
    fn test_parse_v1_rejects_missing_crlf() {
        assert!(matches!(
            parse_v1(b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443"),
            Err(ProxyProtocolError::InvalidHeader)
        ));
    }

    #[test]
    fn test_parse_v1_rejects_non_proxy_line() {
        assert!(matches!(
            parse_v1(b"GET / HTTP/1.1\r\n"),
            Err(ProxyProtocolError::InvalidHeader)
        ));
    }

    #[test]
    fn test_parse_v1_rejects_unsupported_protocol() {
        assert!(matches!(
            parse_v1(b"PROXY UDP4 192.168.1.100 10.0.0.1 12345 443\r\n"),
            Err(ProxyProtocolError::UnsupportedVersion)
        ));
    }

    #[test]
    fn test_parse_v1_rejects_bad_port() {
        assert!(matches!(
            parse_v1(b"PROXY TCP4 192.168.1.100 10.0.0.1 99999 443\r\n"),
            Err(ProxyProtocolError::Parse(_))
        ));
    }

    #[test]
    fn test_generate_v1() {
        let source: SocketAddr = "192.168.1.100:12345".parse().unwrap();
        let dest: SocketAddr = "10.0.0.1:443".parse().unwrap();
        let header = generate_v1(&source, &dest);
        assert_eq!(header, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n");
    }

    #[test]
    fn test_is_proxy_protocol() {
        assert!(is_proxy_protocol_v1(b"PROXY TCP4 ..."));
        assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1"));
    }
}
|