Files
smartdns/rust/crates/rustdns-protocol/src/name.rs

109 lines
3.1 KiB
Rust
Raw Normal View History

/// Encode a domain name into DNS wire format.
/// e.g. "example.com" -> [7, 'e','x','a','m','p','l','e', 3, 'c','o','m', 0]
///
/// A single trailing dot is ignored, so `"example.com."` and `"example.com"`
/// encode identically, and `""` / `"."` encode as the root name `[0]`.
///
/// Empty labels (as in `"a..b"` or a leading dot) are skipped: a zero-length
/// label is the root terminator on the wire, so emitting one mid-name would
/// end the name early and leave the remaining bytes to be misparsed.
///
/// Labels longer than 63 bytes (the RFC 1035 per-label limit) are truncated
/// to at most 63 bytes, backing off to a UTF-8 character boundary so the
/// truncated label is still valid UTF-8. The 255-byte total-name limit is
/// NOT enforced here, because this function cannot report errors.
pub fn encode_name(name: &str) -> Vec<u8> {
    const MAX_LABEL_LEN: usize = 63;

    let trimmed = name.strip_suffix('.').unwrap_or(name);
    // Each label costs its bytes plus one length octet; dots already account
    // for all but the first length octet, and +1 more covers the root byte.
    let mut buf = Vec::with_capacity(trimmed.len() + 2);
    for label in trimmed.split('.').filter(|l| !l.is_empty()) {
        // Never cut a multi-byte UTF-8 character in half when truncating.
        let mut end = label.len().min(MAX_LABEL_LEN);
        while !label.is_char_boundary(end) {
            end -= 1;
        }
        let bytes = &label.as_bytes()[..end];
        buf.push(bytes.len() as u8); // end <= 63, so the cast is lossless
        buf.extend_from_slice(bytes);
    }
    buf.push(0); // root label
    buf
}
/// Decode a domain name from DNS wire format at the given offset.
///
/// Returns `(name, bytes_consumed)`. `name` is the labels joined with `.`
/// (empty for the root name, no trailing dot). `bytes_consumed` is the number
/// of bytes starting at `offset` that belong to this name in the message.
/// If the name contains a compression pointer (0xC0 prefix), counting stops
/// after the first two-byte pointer. Labels reached by following pointers are
/// not counted.
///
/// # Errors
///
/// Returns an error if:
/// - the name runs past the end of `data`;
/// - a label is not valid UTF-8;
/// - a length byte uses the unsupported 0x40/0x80 label-type prefixes;
/// - the decoded name exceeds 255 bytes on the wire;
/// - compression pointers form a cycle.
pub fn decode_name(data: &[u8], offset: usize) -> Result<(String, usize), &'static str> {
    // RFC 1035 §2.3.4: a name is at most 255 octets on the wire, counting
    // every length octet, every label byte and the terminating root label.
    const MAX_NAME_WIRE_LEN: usize = 255;

    let mut labels: Vec<String> = Vec::new();
    let mut pos = offset;
    let mut bytes_consumed = 0;
    let mut jumped = false;
    // Wire length of the name decoded so far; starts at 1 for the root label.
    let mut wire_len = 1;
    // Number of compression pointers followed. Decoding is deterministic in
    // `pos`, so following more pointers than there are bytes in `data` means
    // some pointer was revisited, i.e. the pointers form a cycle.
    let mut hops = 0;
    loop {
        if pos >= data.len() {
            return Err("unexpected end of data in name");
        }
        let len = data[pos] as usize;
        if len == 0 {
            // Root label: it only counts toward `bytes_consumed` when it lies
            // in the original byte run (no pointer has been followed).
            if !jumped {
                bytes_consumed = pos - offset + 1;
            }
            break;
        }
        match len & 0xC0 {
            // Compression pointer: 14-bit offset into `data`.
            0xC0 => {
                if pos + 1 >= data.len() {
                    return Err("unexpected end of data in compression pointer");
                }
                hops += 1;
                if hops > data.len() {
                    return Err("compression pointer loop");
                }
                let pointer = ((len & 0x3F) << 8) | (data[pos + 1] as usize);
                if !jumped {
                    bytes_consumed = pos - offset + 2;
                    jumped = true;
                }
                pos = pointer;
            }
            // Regular label of 1..=63 bytes.
            0x00 => {
                pos += 1;
                if pos + len > data.len() {
                    return Err("label extends beyond data");
                }
                wire_len += len + 1;
                if wire_len > MAX_NAME_WIRE_LEN {
                    return Err("name exceeds 255 bytes");
                }
                let label = std::str::from_utf8(&data[pos..pos + len])
                    .map_err(|_| "invalid UTF-8 in label")?;
                labels.push(label.to_string());
                pos += len;
            }
            // 0x40 (extended label type, deprecated by RFC 6891) and 0x80
            // (reserved) prefixes are not length bytes and must be rejected.
            _ => return Err("unsupported label type"),
        }
    }
    Ok((labels.join("."), bytes_consumed))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding returns the original name and consumes every
    /// encoded byte.
    #[test]
    fn test_encode_decode_roundtrip() {
        let names = ["example.com", "sub.domain.example.com", "localhost", "a.b.c.d.e"];
        for name in names {
            let encoded = encode_name(name);
            let (decoded, consumed) = decode_name(&encoded, 0).unwrap();
            assert_eq!(decoded, name);
            assert_eq!(consumed, encoded.len());
        }
    }

    /// A single trailing dot does not change the encoding.
    #[test]
    fn test_encode_trailing_dot() {
        let a = encode_name("example.com.");
        let b = encode_name("example.com");
        assert_eq!(a, b);
    }

    /// The root name encodes as a lone zero byte and decodes to "" using one byte.
    #[test]
    fn test_root_name() {
        let encoded = encode_name("");
        assert_eq!(encoded, vec![0]);
        let (decoded, consumed) = decode_name(&encoded, 0).unwrap();
        assert_eq!(decoded, "");
        assert_eq!(consumed, 1);
    }

    /// A name ending in a compression pointer follows the pointer for its
    /// labels, but only counts bytes up to and including the 2-byte pointer.
    #[test]
    fn test_decode_compression_pointer() {
        let msg = [
            3, b'c', b'o', b'm', 0, // "com" at offset 0
            7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 0xC0, 0x00, // "example" + ptr->0
        ];
        let (decoded, consumed) = decode_name(&msg, 5).unwrap();
        assert_eq!(decoded, "example.com");
        assert_eq!(consumed, 10);
    }

    /// Decoding at a non-zero offset counts bytes relative to that offset.
    #[test]
    fn test_decode_at_offset() {
        let msg = [0xFF, 0xFF, 3, b'c', b'o', b'm', 0];
        let (decoded, consumed) = decode_name(&msg, 2).unwrap();
        assert_eq!(decoded, "com");
        assert_eq!(consumed, 5);
    }

    /// Truncated input is reported as an error rather than panicking.
    #[test]
    fn test_decode_truncated() {
        assert!(decode_name(&[], 0).is_err());
        assert!(decode_name(&[3, b'c', b'o'], 0).is_err());
        assert!(decode_name(&[3, b'c', b'o', b'm'], 0).is_err());
        assert!(decode_name(&[0xC0], 0).is_err());
        assert!(decode_name(&[0xC0, 0x10], 0).is_err());
    }
}