194 lines
6.4 KiB
Rust
194 lines
6.4 KiB
Rust
use crate::ipc_types::{ClientDnsAnswer, ResolveParams, ResolveResult};
|
|
use rustdns_protocol::packet::{
|
|
decode_a, decode_aaaa, decode_mx, decode_name_rdata, decode_soa, decode_srv, decode_txt,
|
|
DnsPacket, DnsQuestion, DnsRecord,
|
|
};
|
|
use rustdns_protocol::types::{QClass, QType, EDNS_DO_BIT, FLAG_RD};
|
|
use std::net::SocketAddr;
|
|
use std::time::Duration;
|
|
use tokio::net::UdpSocket;
|
|
use tracing::debug;
|
|
|
|
/// Resolve a DNS query via UDP to an upstream server.
|
|
pub async fn resolve_udp(params: &ResolveParams) -> Result<ResolveResult, String> {
|
|
let server_addr: SocketAddr = params
|
|
.server_addr
|
|
.parse()
|
|
.map_err(|e| format!("Invalid server address '{}': {}", params.server_addr, e))?;
|
|
|
|
let qtype = QType::from_str(¶ms.record_type);
|
|
let id: u16 = rand::random();
|
|
|
|
// Build query packet with RD flag and EDNS0 DO bit
|
|
let mut query = DnsPacket::new_query(id);
|
|
query.flags = FLAG_RD;
|
|
query.questions.push(DnsQuestion {
|
|
name: params.name.clone(),
|
|
qtype,
|
|
qclass: QClass::IN,
|
|
});
|
|
|
|
// Add OPT record with DO bit for DNSSEC
|
|
query.additionals.push(rustdns_protocol::packet::DnsRecord {
|
|
name: ".".to_string(),
|
|
rtype: QType::OPT,
|
|
rclass: QClass::from_u16(4096), // UDP payload size
|
|
ttl: 0,
|
|
rdata: vec![],
|
|
opt_flags: Some(EDNS_DO_BIT),
|
|
});
|
|
|
|
let query_bytes = query.encode();
|
|
|
|
// Bind to an ephemeral port
|
|
let bind_addr = if server_addr.is_ipv6() {
|
|
"[::]:0"
|
|
} else {
|
|
"0.0.0.0:0"
|
|
};
|
|
let socket = UdpSocket::bind(bind_addr)
|
|
.await
|
|
.map_err(|e| format!("Failed to bind UDP socket: {}", e))?;
|
|
|
|
socket
|
|
.send_to(&query_bytes, server_addr)
|
|
.await
|
|
.map_err(|e| format!("Failed to send UDP query: {}", e))?;
|
|
|
|
let mut buf = vec![0u8; 4096];
|
|
let timeout = Duration::from_millis(params.timeout_ms);
|
|
|
|
let len = tokio::time::timeout(timeout, socket.recv_from(&mut buf))
|
|
.await
|
|
.map_err(|_| "UDP query timed out".to_string())?
|
|
.map_err(|e| format!("Failed to receive UDP response: {}", e))?
|
|
.0;
|
|
|
|
let response_bytes = &buf[..len];
|
|
let response = DnsPacket::parse(response_bytes)
|
|
.map_err(|e| format!("Failed to parse UDP response: {}", e))?;
|
|
|
|
debug!(
|
|
"UDP response: id={}, rcode={}, answers={}, ad={}",
|
|
response.id,
|
|
response.rcode(),
|
|
response.answers.len(),
|
|
response.has_ad_flag()
|
|
);
|
|
|
|
let answers = decode_answers(&response.answers, response_bytes);
|
|
|
|
Ok(ResolveResult {
|
|
answers,
|
|
ad_flag: response.has_ad_flag(),
|
|
rcode: response.rcode(),
|
|
})
|
|
}
|
|
|
|
/// Decode answer records into ClientDnsAnswer values.
|
|
pub fn decode_answers(records: &[DnsRecord], packet_bytes: &[u8]) -> Vec<ClientDnsAnswer> {
|
|
let mut answers = Vec::new();
|
|
|
|
for record in records {
|
|
// Skip OPT, RRSIG, DNSKEY records — they're metadata, not answer data
|
|
match record.rtype {
|
|
QType::OPT | QType::RRSIG | QType::DNSKEY => continue,
|
|
_ => {}
|
|
}
|
|
|
|
let value = decode_record_value(record, packet_bytes);
|
|
let value = match value {
|
|
Ok(v) => v,
|
|
Err(_) => continue, // skip records we can't decode
|
|
};
|
|
|
|
// Strip trailing dot from name
|
|
let name = record.name.strip_suffix('.').unwrap_or(&record.name).to_string();
|
|
|
|
answers.push(ClientDnsAnswer {
|
|
name,
|
|
rtype: record.rtype.as_str().to_string(),
|
|
ttl: record.ttl,
|
|
value,
|
|
});
|
|
}
|
|
|
|
answers
|
|
}
|
|
|
|
/// Decode a single record's RDATA to a string value.
|
|
fn decode_record_value(record: &DnsRecord, packet_bytes: &[u8]) -> Result<String, String> {
|
|
// We need the rdata offset within the packet for compression pointer resolution.
|
|
// Since we have the raw rdata and the full packet, we find the rdata position.
|
|
let rdata_offset = find_rdata_offset(packet_bytes, &record.rdata);
|
|
|
|
match record.rtype {
|
|
QType::A => decode_a(&record.rdata).map_err(|e| e.to_string()),
|
|
QType::AAAA => decode_aaaa(&record.rdata).map_err(|e| e.to_string()),
|
|
QType::TXT => {
|
|
let chunks = decode_txt(&record.rdata).map_err(|e| e.to_string())?;
|
|
Ok(chunks.join(""))
|
|
}
|
|
QType::MX => {
|
|
if let Some(offset) = rdata_offset {
|
|
let (pref, exchange) = decode_mx(&record.rdata, packet_bytes, offset)?;
|
|
Ok(format!("{} {}", pref, exchange))
|
|
} else {
|
|
Err("Cannot find MX rdata in packet".into())
|
|
}
|
|
}
|
|
QType::NS | QType::CNAME | QType::PTR => {
|
|
if let Some(offset) = rdata_offset {
|
|
decode_name_rdata(&record.rdata, packet_bytes, offset)
|
|
} else {
|
|
Err("Cannot find name rdata in packet".into())
|
|
}
|
|
}
|
|
QType::SOA => {
|
|
if let Some(offset) = rdata_offset {
|
|
let soa = decode_soa(&record.rdata, packet_bytes, offset)?;
|
|
Ok(format!(
|
|
"{} {} {} {} {} {} {}",
|
|
soa.mname, soa.rname, soa.serial, soa.refresh, soa.retry, soa.expire, soa.minimum
|
|
))
|
|
} else {
|
|
Err("Cannot find SOA rdata in packet".into())
|
|
}
|
|
}
|
|
QType::SRV => {
|
|
if let Some(offset) = rdata_offset {
|
|
let srv = decode_srv(&record.rdata, packet_bytes, offset)?;
|
|
Ok(format!(
|
|
"{} {} {} {}",
|
|
srv.priority, srv.weight, srv.port, srv.target
|
|
))
|
|
} else {
|
|
Err("Cannot find SRV rdata in packet".into())
|
|
}
|
|
}
|
|
_ => {
|
|
// Unknown type: return hex encoding
|
|
Ok(record.rdata.iter().map(|b| format!("{:02x}", b)).collect::<String>())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Locate `rdata` inside the full `packet` buffer and return its starting byte offset.
///
/// Name decoders need this position because compression pointers inside RDATA refer
/// to absolute positions in the packet.
///
/// The search is a plain byte-substring scan from the start of `packet`. It returns the
/// FIRST occurrence, which may not be the record's true position if identical bytes
/// appear earlier in the packet.
///
/// Returns `None` in three cases:
/// - `rdata` is empty;
/// - `rdata` is longer than `packet`;
/// - `rdata` does not occur in `packet`.
fn find_rdata_offset(packet: &[u8], rdata: &[u8]) -> Option<usize> {
    // `windows(0)` would panic, so the empty case must be handled up front.
    if rdata.is_empty() {
        return None;
    }
    // When `rdata` is longer than `packet`, `windows` yields nothing and this is `None`.
    packet
        .windows(rdata.len())
        .position(|candidate| candidate == rdata)
}
|