use crate::ipc_types::{ClientDnsAnswer, ResolveParams, ResolveResult}; use rustdns_protocol::packet::{ decode_a, decode_aaaa, decode_mx, decode_name_rdata, decode_soa, decode_srv, decode_txt, DnsPacket, DnsQuestion, DnsRecord, }; use rustdns_protocol::types::{QClass, QType, EDNS_DO_BIT, FLAG_RD}; use std::net::SocketAddr; use std::time::Duration; use tokio::net::UdpSocket; use tracing::debug; /// Resolve a DNS query via UDP to an upstream server. pub async fn resolve_udp(params: &ResolveParams) -> Result { let server_addr: SocketAddr = params .server_addr .parse() .map_err(|e| format!("Invalid server address '{}': {}", params.server_addr, e))?; let qtype = QType::from_str(¶ms.record_type); let id: u16 = rand::random(); // Build query packet with RD flag and EDNS0 DO bit let mut query = DnsPacket::new_query(id); query.flags = FLAG_RD; query.questions.push(DnsQuestion { name: params.name.clone(), qtype, qclass: QClass::IN, }); // Add OPT record with DO bit for DNSSEC query.additionals.push(rustdns_protocol::packet::DnsRecord { name: ".".to_string(), rtype: QType::OPT, rclass: QClass::from_u16(4096), // UDP payload size ttl: 0, rdata: vec![], opt_flags: Some(EDNS_DO_BIT), }); let query_bytes = query.encode(); // Bind to an ephemeral port let bind_addr = if server_addr.is_ipv6() { "[::]:0" } else { "0.0.0.0:0" }; let socket = UdpSocket::bind(bind_addr) .await .map_err(|e| format!("Failed to bind UDP socket: {}", e))?; socket .send_to(&query_bytes, server_addr) .await .map_err(|e| format!("Failed to send UDP query: {}", e))?; let mut buf = vec![0u8; 4096]; let timeout = Duration::from_millis(params.timeout_ms); let len = tokio::time::timeout(timeout, socket.recv_from(&mut buf)) .await .map_err(|_| "UDP query timed out".to_string())? .map_err(|e| format!("Failed to receive UDP response: {}", e))? .0; let response_bytes = &buf[..len]; let response = DnsPacket::parse(response_bytes) .map_err(|e| format!("Failed to parse UDP response: {}", e))?; debug!( "UDP response: id={}, rcode={}, answers={}, ad={}", response.id, response.rcode(), response.answers.len(), response.has_ad_flag() ); let answers = decode_answers(&response.answers, response_bytes); Ok(ResolveResult { answers, ad_flag: response.has_ad_flag(), rcode: response.rcode(), }) } /// Decode answer records into ClientDnsAnswer values. pub fn decode_answers(records: &[DnsRecord], packet_bytes: &[u8]) -> Vec { let mut answers = Vec::new(); for record in records { // Skip OPT, RRSIG, DNSKEY records — they're metadata, not answer data match record.rtype { QType::OPT | QType::RRSIG | QType::DNSKEY => continue, _ => {} } let value = decode_record_value(record, packet_bytes); let value = match value { Ok(v) => v, Err(_) => continue, // skip records we can't decode }; // Strip trailing dot from name let name = record.name.strip_suffix('.').unwrap_or(&record.name).to_string(); answers.push(ClientDnsAnswer { name, rtype: record.rtype.as_str().to_string(), ttl: record.ttl, value, }); } answers } /// Decode a single record's RDATA to a string value. fn decode_record_value(record: &DnsRecord, packet_bytes: &[u8]) -> Result { // We need the rdata offset within the packet for compression pointer resolution. // Since we have the raw rdata and the full packet, we find the rdata position. let rdata_offset = find_rdata_offset(packet_bytes, &record.rdata); match record.rtype { QType::A => decode_a(&record.rdata).map_err(|e| e.to_string()), QType::AAAA => decode_aaaa(&record.rdata).map_err(|e| e.to_string()), QType::TXT => { let chunks = decode_txt(&record.rdata).map_err(|e| e.to_string())?; Ok(chunks.join("")) } QType::MX => { if let Some(offset) = rdata_offset { let (pref, exchange) = decode_mx(&record.rdata, packet_bytes, offset)?; Ok(format!("{} {}", pref, exchange)) } else { Err("Cannot find MX rdata in packet".into()) } } QType::NS | QType::CNAME | QType::PTR => { if let Some(offset) = rdata_offset { decode_name_rdata(&record.rdata, packet_bytes, offset) } else { Err("Cannot find name rdata in packet".into()) } } QType::SOA => { if let Some(offset) = rdata_offset { let soa = decode_soa(&record.rdata, packet_bytes, offset)?; Ok(format!( "{} {} {} {} {} {} {}", soa.mname, soa.rname, soa.serial, soa.refresh, soa.retry, soa.expire, soa.minimum )) } else { Err("Cannot find SOA rdata in packet".into()) } } QType::SRV => { if let Some(offset) = rdata_offset { let srv = decode_srv(&record.rdata, packet_bytes, offset)?; Ok(format!( "{} {} {} {}", srv.priority, srv.weight, srv.port, srv.target )) } else { Err("Cannot find SRV rdata in packet".into()) } } _ => { // Unknown type: return hex encoding Ok(record.rdata.iter().map(|b| format!("{:02x}", b)).collect::()) } } } /// Find the offset of the rdata bytes within the full packet buffer. /// This is needed because compression pointers in RDATA reference absolute positions. fn find_rdata_offset(packet: &[u8], rdata: &[u8]) -> Option { if rdata.is_empty() { return None; } // Search for the rdata slice within the packet let rdata_len = rdata.len(); if rdata_len > packet.len() { return None; } for i in 0..=(packet.len() - rdata_len) { if &packet[i..i + rdata_len] == rdata { return Some(i); } } None }