use crate::name::{decode_name, encode_name}; use crate::types::{QClass, QType, FLAG_QR, FLAG_AA, FLAG_RD, FLAG_RA, FLAG_AD, EDNS_DO_BIT}; /// A parsed DNS question. #[derive(Debug, Clone)] pub struct DnsQuestion { pub name: String, pub qtype: QType, pub qclass: QClass, } /// A parsed DNS resource record. #[derive(Debug, Clone)] pub struct DnsRecord { pub name: String, pub rtype: QType, pub rclass: QClass, pub ttl: u32, pub rdata: Vec, // For OPT records, the flags are stored in the TTL field position pub opt_flags: Option, } /// A complete DNS packet (parsed). #[derive(Debug, Clone)] pub struct DnsPacket { pub id: u16, pub flags: u16, pub questions: Vec, pub answers: Vec, pub authorities: Vec, pub additionals: Vec, } impl DnsPacket { /// Create a new empty query packet. pub fn new_query(id: u16) -> Self { DnsPacket { id, flags: 0, questions: Vec::new(), answers: Vec::new(), authorities: Vec::new(), additionals: Vec::new(), } } /// Create a response packet for a given request. pub fn new_response(request: &DnsPacket) -> Self { let mut flags = FLAG_QR | FLAG_AA | FLAG_RA; if request.flags & FLAG_RD != 0 { flags |= FLAG_RD; } DnsPacket { id: request.id, flags, questions: request.questions.clone(), answers: Vec::new(), authorities: Vec::new(), additionals: Vec::new(), } } /// Extract the response code (lower 4 bits of flags). pub fn rcode(&self) -> u8 { (self.flags & 0x000F) as u8 } /// Check if the AD (Authenticated Data) flag is set. pub fn has_ad_flag(&self) -> bool { self.flags & FLAG_AD != 0 } /// Check if DNSSEC (DO bit) is requested in the OPT record. pub fn is_dnssec_requested(&self) -> bool { for additional in &self.additionals { if additional.rtype == QType::OPT { if let Some(flags) = additional.opt_flags { if flags & EDNS_DO_BIT != 0 { return true; } } } } false } /// Parse a DNS packet from wire format bytes. 
pub fn parse(data: &[u8]) -> Result { if data.len() < 12 { return Err("packet too short for DNS header".into()); } let id = u16::from_be_bytes([data[0], data[1]]); let flags = u16::from_be_bytes([data[2], data[3]]); let qdcount = u16::from_be_bytes([data[4], data[5]]) as usize; let ancount = u16::from_be_bytes([data[6], data[7]]) as usize; let nscount = u16::from_be_bytes([data[8], data[9]]) as usize; let arcount = u16::from_be_bytes([data[10], data[11]]) as usize; let mut offset = 12; // Parse questions let mut questions = Vec::with_capacity(qdcount); for _ in 0..qdcount { let (name, consumed) = decode_name(data, offset).map_err(|e| e.to_string())?; offset += consumed; if offset + 4 > data.len() { return Err("packet too short for question fields".into()); } let qtype = QType::from_u16(u16::from_be_bytes([data[offset], data[offset + 1]])); let qclass = QClass::from_u16(u16::from_be_bytes([data[offset + 2], data[offset + 3]])); offset += 4; questions.push(DnsQuestion { name, qtype, qclass }); } // Parse resource records fn parse_records(data: &[u8], offset: &mut usize, count: usize) -> Result, String> { let mut records = Vec::with_capacity(count); for _ in 0..count { let (name, consumed) = decode_name(data, *offset).map_err(|e| e.to_string())?; *offset += consumed; if *offset + 10 > data.len() { return Err("packet too short for RR fields".into()); } let rtype = QType::from_u16(u16::from_be_bytes([data[*offset], data[*offset + 1]])); let rclass_or_payload = u16::from_be_bytes([data[*offset + 2], data[*offset + 3]]); let ttl_bytes = u32::from_be_bytes([data[*offset + 4], data[*offset + 5], data[*offset + 6], data[*offset + 7]]); let rdlength = u16::from_be_bytes([data[*offset + 8], data[*offset + 9]]) as usize; *offset += 10; if *offset + rdlength > data.len() { return Err("packet too short for RDATA".into()); } let rdata = data[*offset..*offset + rdlength].to_vec(); *offset += rdlength; // For OPT records, extract flags from the TTL position let (rclass, ttl, 
opt_flags) = if rtype == QType::OPT { // OPT: class = UDP payload size, TTL upper 16 = extended RCODE + version, lower 16 = flags let flags = (ttl_bytes & 0xFFFF) as u16; (QClass::from_u16(rclass_or_payload), 0, Some(flags)) } else { (QClass::from_u16(rclass_or_payload), ttl_bytes, None) }; records.push(DnsRecord { name, rtype, rclass, ttl, rdata, opt_flags, }); } Ok(records) } let answers = parse_records(data, &mut offset, ancount)?; let authorities = parse_records(data, &mut offset, nscount)?; let additionals = parse_records(data, &mut offset, arcount)?; Ok(DnsPacket { id, flags, questions, answers, authorities, additionals, }) } /// Encode this DNS packet to wire format bytes. pub fn encode(&self) -> Vec { let mut buf = Vec::with_capacity(512); // Header buf.extend_from_slice(&self.id.to_be_bytes()); buf.extend_from_slice(&self.flags.to_be_bytes()); buf.extend_from_slice(&(self.questions.len() as u16).to_be_bytes()); buf.extend_from_slice(&(self.answers.len() as u16).to_be_bytes()); buf.extend_from_slice(&(self.authorities.len() as u16).to_be_bytes()); buf.extend_from_slice(&(self.additionals.len() as u16).to_be_bytes()); // Questions for q in &self.questions { buf.extend_from_slice(&encode_name(&q.name)); buf.extend_from_slice(&q.qtype.to_u16().to_be_bytes()); buf.extend_from_slice(&q.qclass.to_u16().to_be_bytes()); } // Resource records fn encode_records(buf: &mut Vec, records: &[DnsRecord]) { for rr in records { buf.extend_from_slice(&encode_name(&rr.name)); buf.extend_from_slice(&rr.rtype.to_u16().to_be_bytes()); if rr.rtype == QType::OPT { // OPT: class = UDP payload size (4096), TTL = ext rcode + flags buf.extend_from_slice(&rr.rclass.to_u16().to_be_bytes()); let flags = rr.opt_flags.unwrap_or(0) as u32; buf.extend_from_slice(&flags.to_be_bytes()); } else { buf.extend_from_slice(&rr.rclass.to_u16().to_be_bytes()); buf.extend_from_slice(&rr.ttl.to_be_bytes()); } buf.extend_from_slice(&(rr.rdata.len() as u16).to_be_bytes()); buf.extend_from_slice(&rr.rdata); 
} } encode_records(&mut buf, &self.answers); encode_records(&mut buf, &self.authorities); encode_records(&mut buf, &self.additionals); buf } } // ── RDATA encoding helpers ───────────────────────────────────────── /// Encode an A record (IPv4 address string -> 4 bytes). pub fn encode_a(ip: &str) -> Vec { ip.split('.') .filter_map(|s| s.parse::().ok()) .collect() } /// Encode an AAAA record (IPv6 address string -> 16 bytes). pub fn encode_aaaa(ip: &str) -> Vec { // Handle :: expansion let expanded = expand_ipv6(ip); expanded .split(':') .flat_map(|seg| { let val = u16::from_str_radix(seg, 16).unwrap_or(0); val.to_be_bytes().to_vec() }) .collect() } fn expand_ipv6(ip: &str) -> String { if !ip.contains("::") { return ip.to_string(); } let parts: Vec<&str> = ip.split("::").collect(); let left: Vec<&str> = if parts[0].is_empty() { vec![] } else { parts[0].split(':').collect() }; let right: Vec<&str> = if parts.len() > 1 && !parts[1].is_empty() { parts[1].split(':').collect() } else { vec![] }; let fill_count = 8 - left.len() - right.len(); let mut result: Vec = left.iter().map(|s| s.to_string()).collect(); for _ in 0..fill_count { result.push("0".to_string()); } result.extend(right.iter().map(|s| s.to_string())); result.join(":") } /// Encode a TXT record (array of strings -> length-prefixed chunks). pub fn encode_txt(strings: &[String]) -> Vec { let mut buf = Vec::new(); for s in strings { let bytes = s.as_bytes(); // TXT strings must be <= 255 bytes each let len = bytes.len().min(255); buf.push(len as u8); buf.extend_from_slice(&bytes[..len]); } buf } /// Encode a domain name for use in RDATA (NS, CNAME, PTR, etc.). pub fn encode_name_rdata(name: &str) -> Vec { encode_name(name) } /// Encode a SOA record RDATA. 
///
/// Layout: MNAME and RNAME (via `encode_name`), then SERIAL, REFRESH, RETRY, EXPIRE and
/// MINIMUM, each as a big-endian `u32`.
pub fn encode_soa(mname: &str, rname: &str, serial: u32, refresh: u32, retry: u32, expire: u32, minimum: u32) -> Vec<u8> {
    let mut rdata = Vec::new();
    for name in [mname, rname].iter() {
        rdata.extend_from_slice(&encode_name(*name));
    }
    for counter in [serial, refresh, retry, expire, minimum].iter() {
        rdata.extend_from_slice(&counter.to_be_bytes());
    }
    rdata
}

/// Encode an MX record RDATA: big-endian preference followed by the exchange name.
pub fn encode_mx(preference: u16, exchange: &str) -> Vec<u8> {
    let mut rdata = preference.to_be_bytes().to_vec();
    rdata.extend_from_slice(&encode_name(exchange));
    rdata
}

/// Encode a SRV record RDATA: big-endian priority, weight and port, then the target name.
pub fn encode_srv(priority: u16, weight: u16, port: u16, target: &str) -> Vec<u8> {
    let mut rdata = Vec::new();
    for field in [priority, weight, port].iter() {
        rdata.extend_from_slice(&field.to_be_bytes());
    }
    rdata.extend_from_slice(&encode_name(target));
    rdata
}

/// Encode a DNSKEY record RDATA: big-endian flags, protocol byte, algorithm byte, then
/// the raw public key bytes.
pub fn encode_dnskey(flags: u16, protocol: u8, algorithm: u8, public_key: &[u8]) -> Vec<u8> {
    let mut rdata = Vec::with_capacity(4 + public_key.len());
    rdata.extend_from_slice(&flags.to_be_bytes());
    rdata.extend_from_slice(&[protocol, algorithm]);
    rdata.extend_from_slice(public_key);
    rdata
}

/// Encode an RRSIG record RDATA.
pub fn encode_rrsig( type_covered: u16, algorithm: u8, labels: u8, original_ttl: u32, expiration: u32, inception: u32, key_tag: u16, signers_name: &str, signature: &[u8], ) -> Vec { let mut buf = Vec::new(); buf.extend_from_slice(&type_covered.to_be_bytes()); buf.push(algorithm); buf.push(labels); buf.extend_from_slice(&original_ttl.to_be_bytes()); buf.extend_from_slice(&expiration.to_be_bytes()); buf.extend_from_slice(&inception.to_be_bytes()); buf.extend_from_slice(&key_tag.to_be_bytes()); buf.extend_from_slice(&encode_name(signers_name)); buf.extend_from_slice(signature); buf } // ── RDATA decoding helpers ───────────────────────────────────────── /// Decode an A record (4 bytes -> IPv4 string). pub fn decode_a(rdata: &[u8]) -> Result { if rdata.len() < 4 { return Err("A rdata too short"); } Ok(format!("{}.{}.{}.{}", rdata[0], rdata[1], rdata[2], rdata[3])) } /// Decode an AAAA record (16 bytes -> IPv6 string). pub fn decode_aaaa(rdata: &[u8]) -> Result { if rdata.len() < 16 { return Err("AAAA rdata too short"); } let groups: Vec = (0..8) .map(|i| { let val = u16::from_be_bytes([rdata[i * 2], rdata[i * 2 + 1]]); format!("{:x}", val) }) .collect(); // Build full form, then compress :: notation let full = groups.join(":"); compress_ipv6(&full) } /// Compress a full IPv6 address to shortest form. 
fn compress_ipv6(full: &str) -> Result { let groups: Vec<&str> = full.split(':').collect(); if groups.len() != 8 { return Ok(full.to_string()); } // Find longest run of consecutive "0" groups let mut best_start = None; let mut best_len = 0usize; let mut cur_start = None; let mut cur_len = 0usize; for (i, g) in groups.iter().enumerate() { if *g == "0" { if cur_start.is_none() { cur_start = Some(i); cur_len = 1; } else { cur_len += 1; } if cur_len > best_len { best_start = cur_start; best_len = cur_len; } } else { cur_start = None; cur_len = 0; } } if best_len >= 2 { let bs = best_start.unwrap(); let left: Vec<&str> = groups[..bs].to_vec(); let right: Vec<&str> = groups[bs + best_len..].to_vec(); let l = left.join(":"); let r = right.join(":"); if l.is_empty() && r.is_empty() { Ok("::".to_string()) } else if l.is_empty() { Ok(format!("::{}", r)) } else if r.is_empty() { Ok(format!("{}::", l)) } else { Ok(format!("{}::{}", l, r)) } } else { Ok(full.to_string()) } } /// Decode a TXT record (length-prefixed chunks -> strings). pub fn decode_txt(rdata: &[u8]) -> Result, &'static str> { let mut strings = Vec::new(); let mut pos = 0; while pos < rdata.len() { let len = rdata[pos] as usize; pos += 1; if pos + len > rdata.len() { return Err("TXT chunk extends beyond rdata"); } let s = std::str::from_utf8(&rdata[pos..pos + len]) .map_err(|_| "invalid UTF-8 in TXT")?; strings.push(s.to_string()); pos += len; } Ok(strings) } /// Decode an MX record (preference + exchange name with compression). pub fn decode_mx(rdata: &[u8], packet: &[u8], rdata_offset: usize) -> Result<(u16, String), String> { if rdata.len() < 3 { return Err("MX rdata too short".into()); } let preference = u16::from_be_bytes([rdata[0], rdata[1]]); let (name, _) = decode_name(packet, rdata_offset + 2).map_err(|e| e.to_string())?; Ok((preference, name)) } /// Decode a name from RDATA (for NS, CNAME, PTR records with compression). 
///
/// The name is read from `packet` at `rdata_offset`; `_rdata` itself is not inspected.
pub fn decode_name_rdata(_rdata: &[u8], packet: &[u8], rdata_offset: usize) -> Result<String, String> {
    decode_name(packet, rdata_offset)
        .map(|(name, _consumed)| name)
        .map_err(|e| e.to_string())
}

/// SOA record decoded fields.
#[derive(Debug, Clone)]
pub struct SoaData {
    /// Primary name server (MNAME).
    pub mname: String,
    /// Responsible mailbox (RNAME).
    pub rname: String,
    pub serial: u32,
    pub refresh: u32,
    pub retry: u32,
    pub expire: u32,
    pub minimum: u32,
}

/// Decode a SOA record RDATA.
///
/// MNAME and RNAME are decoded from `packet` starting at `rdata_offset`. The five 32-bit
/// big-endian counters are then read from `rdata`, starting right after the two names'
/// wire lengths.
pub fn decode_soa(rdata: &[u8], packet: &[u8], rdata_offset: usize) -> Result<SoaData, String> {
    let (mname, mname_len) = decode_name(packet, rdata_offset).map_err(|e| e.to_string())?;
    let (rname, rname_len) =
        decode_name(packet, rdata_offset + mname_len).map_err(|e| e.to_string())?;

    let start = mname_len + rname_len;
    let counters = rdata
        .get(start..start + 20)
        .ok_or_else(|| "SOA rdata too short for numeric fields".to_string())?;
    let counter = |index: usize| {
        let at = index * 4;
        u32::from_be_bytes([counters[at], counters[at + 1], counters[at + 2], counters[at + 3]])
    };

    Ok(SoaData {
        mname,
        rname,
        serial: counter(0),
        refresh: counter(1),
        retry: counter(2),
        expire: counter(3),
        minimum: counter(4),
    })
}

/// SRV record decoded fields.
#[derive(Debug, Clone)]
pub struct SrvData {
    pub priority: u16,
    pub weight: u16,
    pub port: u16,
    /// Target host name.
    pub target: String,
}

/// Decode a SRV record RDATA.
pub fn decode_srv(rdata: &[u8], packet: &[u8], rdata_offset: usize) -> Result { if rdata.len() < 7 { return Err("SRV rdata too short".into()); } let priority = u16::from_be_bytes([rdata[0], rdata[1]]); let weight = u16::from_be_bytes([rdata[2], rdata[3]]); let port = u16::from_be_bytes([rdata[4], rdata[5]]); let (target, _) = decode_name(packet, rdata_offset + 6).map_err(|e| e.to_string())?; Ok(SrvData { priority, weight, port, target }) } /// Build a DnsRecord from high-level data. pub fn build_record(name: &str, rtype: QType, ttl: u32, rdata: Vec) -> DnsRecord { DnsRecord { name: name.to_string(), rtype, rclass: QClass::IN, ttl, rdata, opt_flags: None, } } #[cfg(test)] mod tests { use super::*; #[test] fn test_parse_encode_roundtrip() { // Build a simple query let mut query = DnsPacket::new_query(0x1234); query.flags = FLAG_RD; query.questions.push(DnsQuestion { name: "example.com".to_string(), qtype: QType::A, qclass: QClass::IN, }); let encoded = query.encode(); let parsed = DnsPacket::parse(&encoded).unwrap(); assert_eq!(parsed.id, 0x1234); assert_eq!(parsed.questions.len(), 1); assert_eq!(parsed.questions[0].name, "example.com"); assert_eq!(parsed.questions[0].qtype, QType::A); } #[test] fn test_response_with_answer() { let mut query = DnsPacket::new_query(0x5678); query.flags = FLAG_RD; query.questions.push(DnsQuestion { name: "test.example.com".to_string(), qtype: QType::A, qclass: QClass::IN, }); let mut response = DnsPacket::new_response(&query); response.answers.push(build_record( "test.example.com", QType::A, 300, encode_a("127.0.0.1"), )); let encoded = response.encode(); let parsed = DnsPacket::parse(&encoded).unwrap(); assert_eq!(parsed.id, 0x5678); assert!(parsed.flags & FLAG_QR != 0); // Is a response assert!(parsed.flags & FLAG_AA != 0); // Authoritative assert_eq!(parsed.answers.len(), 1); assert_eq!(parsed.answers[0].rdata, vec![127, 0, 0, 1]); } #[test] fn test_encode_aaaa() { let data = encode_aaaa("::1"); assert_eq!(data.len(), 16); 
assert_eq!(data[15], 1); assert!(data[..15].iter().all(|&b| b == 0)); } #[test] fn test_encode_txt() { let data = encode_txt(&["hello".to_string(), "world".to_string()]); assert_eq!(data[0], 5); // length of "hello" assert_eq!(&data[1..6], b"hello"); assert_eq!(data[6], 5); // length of "world" assert_eq!(&data[7..12], b"world"); } #[test] fn test_decode_a() { let rdata = encode_a("192.168.1.1"); let decoded = decode_a(&rdata).unwrap(); assert_eq!(decoded, "192.168.1.1"); } #[test] fn test_decode_aaaa() { let rdata = encode_aaaa("::1"); let decoded = decode_aaaa(&rdata).unwrap(); assert_eq!(decoded, "::1"); let rdata2 = encode_aaaa("2001:db8::1"); let decoded2 = decode_aaaa(&rdata2).unwrap(); assert_eq!(decoded2, "2001:db8::1"); } #[test] fn test_decode_txt() { let strings = vec!["hello".to_string(), "world".to_string()]; let rdata = encode_txt(&strings); let decoded = decode_txt(&rdata).unwrap(); assert_eq!(decoded, strings); } #[test] fn test_rcode_and_ad_flag() { let mut pkt = DnsPacket::new_query(1); assert_eq!(pkt.rcode(), 0); assert!(!pkt.has_ad_flag()); pkt.flags |= crate::types::FLAG_AD; assert!(pkt.has_ad_flag()); pkt.flags |= 0x0003; // NXDOMAIN assert_eq!(pkt.rcode(), 3); } #[test] fn test_dnssec_do_bit() { let mut query = DnsPacket::new_query(1); query.questions.push(DnsQuestion { name: "example.com".to_string(), qtype: QType::A, qclass: QClass::IN, }); // No OPT record = no DNSSEC assert!(!query.is_dnssec_requested()); // Add OPT with DO bit query.additionals.push(DnsRecord { name: ".".to_string(), rtype: QType::OPT, rclass: QClass::from_u16(4096), // UDP payload size ttl: 0, rdata: vec![], opt_flags: Some(EDNS_DO_BIT), }); assert!(query.is_dnssec_requested()); } }