// DNS domain-name helpers: encoding to and decoding from DNS wire format.

/// Encode a domain name into DNS wire format.
///
/// Every dot-separated label is written as one length octet followed by the
/// label's bytes, and the name is terminated by the zero-length root label,
/// e.g. "example.com" -> [7, 'e','x','a','m','p','l','e', 3, 'c','o','m', 0]
///
/// Input handling:
/// * One trailing dot (the fully-qualified form, "example.com.") is ignored.
/// * An empty name, or ".", encodes as the bare root name `[0]`.
/// * Empty labels (as in "a..b" or ".com") are skipped. Writing them would
///   emit a zero length octet, which is the root terminator, and so would cut
///   the name short in the middle of the output.
/// * Labels longer than 63 bytes are truncated to their first 63 bytes, the
///   largest length a label's length octet can carry. The cut is made on a
///   byte boundary, so a multi-byte UTF-8 character may be split.
pub fn encode_name(name: &str) -> Vec<u8> {
    // A length octet's two high bits are reserved, leaving 6 bits: 0..=63.
    const MAX_LABEL_LEN: usize = 63;

    let trimmed = name.strip_suffix('.').unwrap_or(name);

    // Each '.' separator is replaced by a length octet; add one octet for the
    // first label's length and one for the root terminator.
    let mut buf = Vec::with_capacity(trimmed.len() + 2);

    for label in trimmed.split('.').filter(|label| !label.is_empty()) {
        let bytes = label.as_bytes();
        let bytes = &bytes[..bytes.len().min(MAX_LABEL_LEN)];
        // Cannot truncate: the slice was just clamped to at most 63 bytes.
        buf.push(bytes.len() as u8);
        buf.extend_from_slice(bytes);
    }

    buf.push(0); // root label
    buf
}

/// Decode a domain name from DNS wire format at the given offset.
///
/// `data` is the buffer that compression pointers are resolved against
/// (pointer targets are absolute indices into `data`), and `offset` is the
/// index of the name's first length octet.
///
/// Returns `(name, bytes_consumed)`. `name` is the labels joined with `.` and
/// has no trailing dot; the root name decodes to the empty string.
/// `bytes_consumed` is the number of bytes the name occupies at `offset`: up
/// to and including the root label, or up to and including the first
/// compression pointer if the name is compressed. Bytes reached by following
/// a pointer are not counted.
///
/// # Errors
///
/// Returns a static message when:
/// * the data ends before the name, a label, or a pointer is complete;
/// * a label is not valid UTF-8;
/// * a length octet uses a reserved label type (top two bits `01` or `10`);
/// * the name follows more compression pointers than any well-formed name
///   needs, which happens when pointers form a loop.
pub fn decode_name(data: &[u8], offset: usize) -> Result<(String, usize), &'static str> {
    // Cap on pointers followed for a single name. Without a cap, a pointer
    // that targets itself (or any pointer cycle) would never terminate.
    const MAX_POINTER_JUMPS: usize = 128;
    // Largest length a plain label's length octet can express.
    const MAX_LABEL_LEN: usize = 63;

    let mut labels: Vec<&str> = Vec::new();
    let mut pos = offset;
    // Size of the name at `offset`, fixed when the first pointer is seen.
    // After a jump `pos` may lie before `offset`, so it must not be
    // recomputed from `pos` once set.
    let mut consumed_at_first_jump: Option<usize> = None;
    let mut jumps = 0usize;

    loop {
        let len = usize::from(*data.get(pos).ok_or("unexpected end of data in name")?);

        if len == 0 {
            // Root label: the name is complete.
            let bytes_consumed = match consumed_at_first_jump {
                Some(consumed) => consumed,
                None => pos - offset + 1,
            };
            return Ok((labels.join("."), bytes_consumed));
        }

        // Check for compression pointer: top two bits set, remaining 14 bits
        // are the target index.
        if len & 0xC0 == 0xC0 {
            let low = *data
                .get(pos + 1)
                .ok_or("unexpected end of data in compression pointer")?;

            jumps += 1;
            if jumps > MAX_POINTER_JUMPS {
                return Err("too many compression pointers in name");
            }

            if consumed_at_first_jump.is_none() {
                consumed_at_first_jump = Some(pos - offset + 2);
            }
            pos = ((len & 0x3F) << 8) | usize::from(low);
            continue;
        }

        // Top bits `01` / `10` are reserved label types, not 64..=191-byte labels.
        if len > MAX_LABEL_LEN {
            return Err("unsupported label type");
        }

        // Regular label
        let start = pos + 1;
        let end = start + len;
        let raw = data.get(start..end).ok_or("label extends beyond data")?;
        let label = std::str::from_utf8(raw).map_err(|_| "invalid UTF-8 in label")?;
        labels.push(label);
        pos = end;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding a name and decoding the result yields the same name, and the
    /// decoder reports that it consumed the entire encoded buffer.
    #[test]
    fn test_encode_decode_roundtrip() {
        let cases = ["example.com", "sub.domain.example.com", "localhost", "a.b.c.d.e"];
        for &original in &cases {
            let wire = encode_name(original);
            let (name, used) = decode_name(&wire, 0).unwrap();
            assert_eq!(name, original);
            assert_eq!(used, wire.len());
        }
    }

    /// A trailing dot does not change the encoded bytes.
    #[test]
    fn test_encode_trailing_dot() {
        let with_dot = encode_name("example.com.");
        let without_dot = encode_name("example.com");
        assert_eq!(with_dot, without_dot);
    }

    /// The empty name encodes to the single root octet and decodes back to "".
    #[test]
    fn test_root_name() {
        let wire = encode_name("");
        assert_eq!(wire, [0]);

        let (name, _) = decode_name(&wire, 0).unwrap();
        assert_eq!(name, "");
    }
}