feat(rustdns-client): add Rust DNS client binary and TypeScript IPC bridge to enable UDP and DoH resolution, RDATA decoding, and DNSSEC AD/rcode support
This commit is contained in:
752
rust/Cargo.lock
generated
752
rust/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -5,4 +5,5 @@ members = [
|
||||
"crates/rustdns-protocol",
|
||||
"crates/rustdns-server",
|
||||
"crates/rustdns-dnssec",
|
||||
"crates/rustdns-client",
|
||||
]
|
||||
|
||||
20
rust/crates/rustdns-client/Cargo.toml
Normal file
20
rust/crates/rustdns-client/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "rustdns-client"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "rustdns-client"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
rustdns-protocol = { path = "../rustdns-protocol" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
|
||||
rustls = { version = "0.23", features = ["ring"] }
|
||||
rand = "0.9"
|
||||
94
rust/crates/rustdns-client/src/ipc_types.rs
Normal file
94
rust/crates/rustdns-client/src/ipc_types.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// IPC request from TypeScript to Rust (via stdin).
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct IpcRequest {
|
||||
pub id: String,
|
||||
pub method: String,
|
||||
#[serde(default)]
|
||||
pub params: serde_json::Value,
|
||||
}
|
||||
|
||||
/// IPC response from Rust to TypeScript (via stdout).
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IpcResponse {
|
||||
pub id: String,
|
||||
pub success: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl IpcResponse {
|
||||
pub fn ok(id: String, result: serde_json::Value) -> Self {
|
||||
IpcResponse {
|
||||
id,
|
||||
success: true,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(id: String, error: String) -> Self {
|
||||
IpcResponse {
|
||||
id,
|
||||
success: false,
|
||||
result: None,
|
||||
error: Some(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// IPC event from Rust to TypeScript (unsolicited, no id).
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IpcEvent {
|
||||
pub event: String,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Parameters for a DNS resolve request.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResolveParams {
|
||||
pub name: String,
|
||||
pub record_type: String,
|
||||
pub protocol: String,
|
||||
#[serde(default = "default_server_addr")]
|
||||
pub server_addr: String,
|
||||
#[serde(default = "default_doh_url")]
|
||||
pub doh_url: String,
|
||||
#[serde(default = "default_timeout_ms")]
|
||||
pub timeout_ms: u64,
|
||||
}
|
||||
|
||||
fn default_server_addr() -> String {
|
||||
"1.1.1.1:53".to_string()
|
||||
}
|
||||
|
||||
fn default_doh_url() -> String {
|
||||
"https://cloudflare-dns.com/dns-query".to_string()
|
||||
}
|
||||
|
||||
fn default_timeout_ms() -> u64 {
|
||||
5000
|
||||
}
|
||||
|
||||
/// A single DNS answer record sent back to TypeScript.
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct ClientDnsAnswer {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub rtype: String,
|
||||
pub ttl: u32,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
/// Result of a DNS resolve request.
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResolveResult {
|
||||
pub answers: Vec<ClientDnsAnswer>,
|
||||
pub ad_flag: bool,
|
||||
pub rcode: u8,
|
||||
}
|
||||
36
rust/crates/rustdns-client/src/main.rs
Normal file
36
rust/crates/rustdns-client/src/main.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use clap::Parser;
|
||||
|
||||
mod ipc_types;
|
||||
mod management;
|
||||
mod resolver_doh;
|
||||
mod resolver_udp;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "rustdns-client", about = "Rust DNS client with IPC management")]
|
||||
struct Cli {
|
||||
/// Run in management mode (IPC via stdin/stdout)
|
||||
#[arg(long)]
|
||||
management: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Install the default rustls crypto provider (ring) before any TLS operations
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Tracing writes to stderr so stdout is reserved for IPC
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
if cli.management {
|
||||
management::management_loop().await?;
|
||||
} else {
|
||||
eprintln!("rustdns-client: use --management flag for IPC mode");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
130
rust/crates/rustdns-client/src/management.rs
Normal file
130
rust/crates/rustdns-client/src/management.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use crate::ipc_types::*;
|
||||
use crate::resolver_doh;
|
||||
use crate::resolver_udp;
|
||||
use std::io::{self, BufRead, Write};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Emit a JSON event on stdout.
|
||||
fn send_event(event: &str, data: serde_json::Value) {
|
||||
let evt = IpcEvent {
|
||||
event: event.to_string(),
|
||||
data,
|
||||
};
|
||||
let json = serde_json::to_string(&evt).unwrap();
|
||||
let stdout = io::stdout();
|
||||
let mut lock = stdout.lock();
|
||||
let _ = writeln!(lock, "{}", json);
|
||||
let _ = lock.flush();
|
||||
}
|
||||
|
||||
/// Send a JSON response on stdout.
|
||||
fn send_response(response: &IpcResponse) {
|
||||
let json = serde_json::to_string(response).unwrap();
|
||||
let stdout = io::stdout();
|
||||
let mut lock = stdout.lock();
|
||||
let _ = writeln!(lock, "{}", json);
|
||||
let _ = lock.flush();
|
||||
}
|
||||
|
||||
/// Main management loop — reads JSON lines from stdin, dispatches commands.
|
||||
pub async fn management_loop() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Emit ready event
|
||||
send_event(
|
||||
"ready",
|
||||
serde_json::json!({
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}),
|
||||
);
|
||||
|
||||
// Create a shared HTTP client for DoH connection pooling
|
||||
let http_client = reqwest::Client::builder()
|
||||
.use_rustls_tls()
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
// Channel for stdin commands (read in blocking thread)
|
||||
let (cmd_tx, mut cmd_rx) = mpsc::channel::<String>(256);
|
||||
|
||||
// Spawn blocking stdin reader
|
||||
std::thread::spawn(move || {
|
||||
let stdin = io::stdin();
|
||||
let reader = stdin.lock();
|
||||
for line in reader.lines() {
|
||||
match line {
|
||||
Ok(l) => {
|
||||
if cmd_tx.blocking_send(l).is_err() {
|
||||
break; // channel closed
|
||||
}
|
||||
}
|
||||
Err(_) => break, // stdin closed
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
match cmd_rx.recv().await {
|
||||
Some(line) => {
|
||||
let request: IpcRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Failed to parse IPC request: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let response = handle_request(&request, &http_client).await;
|
||||
send_response(&response);
|
||||
}
|
||||
None => {
|
||||
// stdin closed — parent process exited
|
||||
info!("stdin closed, shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_request(request: &IpcRequest, http_client: &reqwest::Client) -> IpcResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"ping" => IpcResponse::ok(id, serde_json::json!({ "pong": true })),
|
||||
|
||||
"resolve" => handle_resolve(id, &request.params, http_client).await,
|
||||
|
||||
_ => IpcResponse::err(id, format!("Unknown method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_resolve(
|
||||
id: String,
|
||||
params: &serde_json::Value,
|
||||
http_client: &reqwest::Client,
|
||||
) -> IpcResponse {
|
||||
let resolve_params: ResolveParams = match serde_json::from_value(params.clone()) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return IpcResponse::err(id, format!("Invalid resolve params: {}", e)),
|
||||
};
|
||||
|
||||
let result = match resolve_params.protocol.as_str() {
|
||||
"udp" => resolver_udp::resolve_udp(&resolve_params).await,
|
||||
"doh" => resolver_doh::resolve_doh(&resolve_params, http_client).await,
|
||||
other => {
|
||||
return IpcResponse::err(
|
||||
id,
|
||||
format!("Unknown protocol '{}'. Use 'udp' or 'doh'.", other),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(resolve_result) => {
|
||||
let result_json = serde_json::to_value(&resolve_result).unwrap();
|
||||
IpcResponse::ok(id, result_json)
|
||||
}
|
||||
Err(e) => IpcResponse::err(id, e),
|
||||
}
|
||||
}
|
||||
75
rust/crates/rustdns-client/src/resolver_doh.rs
Normal file
75
rust/crates/rustdns-client/src/resolver_doh.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use crate::ipc_types::{ResolveParams, ResolveResult};
|
||||
use crate::resolver_udp::decode_answers;
|
||||
use rustdns_protocol::packet::{DnsPacket, DnsQuestion};
|
||||
use rustdns_protocol::types::{QClass, QType, EDNS_DO_BIT, FLAG_RD};
|
||||
use std::time::Duration;
|
||||
use tracing::debug;
|
||||
|
||||
/// Resolve a DNS query via DNS-over-HTTPS (RFC 8484 wire format).
|
||||
pub async fn resolve_doh(
|
||||
params: &ResolveParams,
|
||||
http_client: &reqwest::Client,
|
||||
) -> Result<ResolveResult, String> {
|
||||
let qtype = QType::from_str(¶ms.record_type);
|
||||
let id: u16 = rand::random();
|
||||
|
||||
// Build query packet (same as UDP)
|
||||
let mut query = DnsPacket::new_query(id);
|
||||
query.flags = FLAG_RD;
|
||||
query.questions.push(DnsQuestion {
|
||||
name: params.name.clone(),
|
||||
qtype,
|
||||
qclass: QClass::IN,
|
||||
});
|
||||
|
||||
// Add OPT record with DO bit for DNSSEC
|
||||
query.additionals.push(rustdns_protocol::packet::DnsRecord {
|
||||
name: ".".to_string(),
|
||||
rtype: QType::OPT,
|
||||
rclass: QClass::from_u16(4096),
|
||||
ttl: 0,
|
||||
rdata: vec![],
|
||||
opt_flags: Some(EDNS_DO_BIT),
|
||||
});
|
||||
|
||||
let query_bytes = query.encode();
|
||||
let timeout = Duration::from_millis(params.timeout_ms);
|
||||
|
||||
let response = http_client
|
||||
.post(¶ms.doh_url)
|
||||
.header("Content-Type", "application/dns-message")
|
||||
.header("Accept", "application/dns-message")
|
||||
.body(query_bytes)
|
||||
.timeout(timeout)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("DoH request failed: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("DoH server returned status {}", response.status()));
|
||||
}
|
||||
|
||||
let response_bytes = response
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read DoH response body: {}", e))?;
|
||||
|
||||
let dns_response = DnsPacket::parse(&response_bytes)
|
||||
.map_err(|e| format!("Failed to parse DoH response: {}", e))?;
|
||||
|
||||
debug!(
|
||||
"DoH response: id={}, rcode={}, answers={}, ad={}",
|
||||
dns_response.id,
|
||||
dns_response.rcode(),
|
||||
dns_response.answers.len(),
|
||||
dns_response.has_ad_flag()
|
||||
);
|
||||
|
||||
let answers = decode_answers(&dns_response.answers, &response_bytes);
|
||||
|
||||
Ok(ResolveResult {
|
||||
answers,
|
||||
ad_flag: dns_response.has_ad_flag(),
|
||||
rcode: dns_response.rcode(),
|
||||
})
|
||||
}
|
||||
193
rust/crates/rustdns-client/src/resolver_udp.rs
Normal file
193
rust/crates/rustdns-client/src/resolver_udp.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use crate::ipc_types::{ClientDnsAnswer, ResolveParams, ResolveResult};
|
||||
use rustdns_protocol::packet::{
|
||||
decode_a, decode_aaaa, decode_mx, decode_name_rdata, decode_soa, decode_srv, decode_txt,
|
||||
DnsPacket, DnsQuestion, DnsRecord,
|
||||
};
|
||||
use rustdns_protocol::types::{QClass, QType, EDNS_DO_BIT, FLAG_RD};
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use tokio::net::UdpSocket;
|
||||
use tracing::debug;
|
||||
|
||||
/// Resolve a DNS query via UDP to an upstream server.
|
||||
pub async fn resolve_udp(params: &ResolveParams) -> Result<ResolveResult, String> {
|
||||
let server_addr: SocketAddr = params
|
||||
.server_addr
|
||||
.parse()
|
||||
.map_err(|e| format!("Invalid server address '{}': {}", params.server_addr, e))?;
|
||||
|
||||
let qtype = QType::from_str(¶ms.record_type);
|
||||
let id: u16 = rand::random();
|
||||
|
||||
// Build query packet with RD flag and EDNS0 DO bit
|
||||
let mut query = DnsPacket::new_query(id);
|
||||
query.flags = FLAG_RD;
|
||||
query.questions.push(DnsQuestion {
|
||||
name: params.name.clone(),
|
||||
qtype,
|
||||
qclass: QClass::IN,
|
||||
});
|
||||
|
||||
// Add OPT record with DO bit for DNSSEC
|
||||
query.additionals.push(rustdns_protocol::packet::DnsRecord {
|
||||
name: ".".to_string(),
|
||||
rtype: QType::OPT,
|
||||
rclass: QClass::from_u16(4096), // UDP payload size
|
||||
ttl: 0,
|
||||
rdata: vec![],
|
||||
opt_flags: Some(EDNS_DO_BIT),
|
||||
});
|
||||
|
||||
let query_bytes = query.encode();
|
||||
|
||||
// Bind to an ephemeral port
|
||||
let bind_addr = if server_addr.is_ipv6() {
|
||||
"[::]:0"
|
||||
} else {
|
||||
"0.0.0.0:0"
|
||||
};
|
||||
let socket = UdpSocket::bind(bind_addr)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to bind UDP socket: {}", e))?;
|
||||
|
||||
socket
|
||||
.send_to(&query_bytes, server_addr)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to send UDP query: {}", e))?;
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let timeout = Duration::from_millis(params.timeout_ms);
|
||||
|
||||
let len = tokio::time::timeout(timeout, socket.recv_from(&mut buf))
|
||||
.await
|
||||
.map_err(|_| "UDP query timed out".to_string())?
|
||||
.map_err(|e| format!("Failed to receive UDP response: {}", e))?
|
||||
.0;
|
||||
|
||||
let response_bytes = &buf[..len];
|
||||
let response = DnsPacket::parse(response_bytes)
|
||||
.map_err(|e| format!("Failed to parse UDP response: {}", e))?;
|
||||
|
||||
debug!(
|
||||
"UDP response: id={}, rcode={}, answers={}, ad={}",
|
||||
response.id,
|
||||
response.rcode(),
|
||||
response.answers.len(),
|
||||
response.has_ad_flag()
|
||||
);
|
||||
|
||||
let answers = decode_answers(&response.answers, response_bytes);
|
||||
|
||||
Ok(ResolveResult {
|
||||
answers,
|
||||
ad_flag: response.has_ad_flag(),
|
||||
rcode: response.rcode(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Decode answer records into ClientDnsAnswer values.
|
||||
pub fn decode_answers(records: &[DnsRecord], packet_bytes: &[u8]) -> Vec<ClientDnsAnswer> {
|
||||
let mut answers = Vec::new();
|
||||
|
||||
for record in records {
|
||||
// Skip OPT, RRSIG, DNSKEY records — they're metadata, not answer data
|
||||
match record.rtype {
|
||||
QType::OPT | QType::RRSIG | QType::DNSKEY => continue,
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let value = decode_record_value(record, packet_bytes);
|
||||
let value = match value {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue, // skip records we can't decode
|
||||
};
|
||||
|
||||
// Strip trailing dot from name
|
||||
let name = record.name.strip_suffix('.').unwrap_or(&record.name).to_string();
|
||||
|
||||
answers.push(ClientDnsAnswer {
|
||||
name,
|
||||
rtype: record.rtype.as_str().to_string(),
|
||||
ttl: record.ttl,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
answers
|
||||
}
|
||||
|
||||
/// Decode a single record's RDATA to a string value.
|
||||
fn decode_record_value(record: &DnsRecord, packet_bytes: &[u8]) -> Result<String, String> {
|
||||
// We need the rdata offset within the packet for compression pointer resolution.
|
||||
// Since we have the raw rdata and the full packet, we find the rdata position.
|
||||
let rdata_offset = find_rdata_offset(packet_bytes, &record.rdata);
|
||||
|
||||
match record.rtype {
|
||||
QType::A => decode_a(&record.rdata).map_err(|e| e.to_string()),
|
||||
QType::AAAA => decode_aaaa(&record.rdata).map_err(|e| e.to_string()),
|
||||
QType::TXT => {
|
||||
let chunks = decode_txt(&record.rdata).map_err(|e| e.to_string())?;
|
||||
Ok(chunks.join(""))
|
||||
}
|
||||
QType::MX => {
|
||||
if let Some(offset) = rdata_offset {
|
||||
let (pref, exchange) = decode_mx(&record.rdata, packet_bytes, offset)?;
|
||||
Ok(format!("{} {}", pref, exchange))
|
||||
} else {
|
||||
Err("Cannot find MX rdata in packet".into())
|
||||
}
|
||||
}
|
||||
QType::NS | QType::CNAME | QType::PTR => {
|
||||
if let Some(offset) = rdata_offset {
|
||||
decode_name_rdata(&record.rdata, packet_bytes, offset)
|
||||
} else {
|
||||
Err("Cannot find name rdata in packet".into())
|
||||
}
|
||||
}
|
||||
QType::SOA => {
|
||||
if let Some(offset) = rdata_offset {
|
||||
let soa = decode_soa(&record.rdata, packet_bytes, offset)?;
|
||||
Ok(format!(
|
||||
"{} {} {} {} {} {} {}",
|
||||
soa.mname, soa.rname, soa.serial, soa.refresh, soa.retry, soa.expire, soa.minimum
|
||||
))
|
||||
} else {
|
||||
Err("Cannot find SOA rdata in packet".into())
|
||||
}
|
||||
}
|
||||
QType::SRV => {
|
||||
if let Some(offset) = rdata_offset {
|
||||
let srv = decode_srv(&record.rdata, packet_bytes, offset)?;
|
||||
Ok(format!(
|
||||
"{} {} {} {}",
|
||||
srv.priority, srv.weight, srv.port, srv.target
|
||||
))
|
||||
} else {
|
||||
Err("Cannot find SRV rdata in packet".into())
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Unknown type: return hex encoding
|
||||
Ok(record.rdata.iter().map(|b| format!("{:02x}", b)).collect::<String>())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the offset of the rdata bytes within the full packet buffer.
|
||||
/// This is needed because compression pointers in RDATA reference absolute positions.
|
||||
fn find_rdata_offset(packet: &[u8], rdata: &[u8]) -> Option<usize> {
|
||||
if rdata.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// Search for the rdata slice within the packet
|
||||
let rdata_len = rdata.len();
|
||||
if rdata_len > packet.len() {
|
||||
return None;
|
||||
}
|
||||
for i in 0..=(packet.len() - rdata_len) {
|
||||
if &packet[i..i + rdata_len] == rdata {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::name::{decode_name, encode_name};
|
||||
use crate::types::{QClass, QType, FLAG_QR, FLAG_AA, FLAG_RD, FLAG_RA, EDNS_DO_BIT};
|
||||
use crate::types::{QClass, QType, FLAG_QR, FLAG_AA, FLAG_RD, FLAG_RA, FLAG_AD, EDNS_DO_BIT};
|
||||
|
||||
/// A parsed DNS question.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -61,6 +61,16 @@ impl DnsPacket {
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the response code (lower 4 bits of flags).
|
||||
pub fn rcode(&self) -> u8 {
|
||||
(self.flags & 0x000F) as u8
|
||||
}
|
||||
|
||||
/// Check if the AD (Authenticated Data) flag is set.
|
||||
pub fn has_ad_flag(&self) -> bool {
|
||||
self.flags & FLAG_AD != 0
|
||||
}
|
||||
|
||||
/// Check if DNSSEC (DO bit) is requested in the OPT record.
|
||||
pub fn is_dnssec_requested(&self) -> bool {
|
||||
for additional in &self.additionals {
|
||||
@@ -335,6 +345,181 @@ pub fn encode_rrsig(
|
||||
buf
|
||||
}
|
||||
|
||||
// ── RDATA decoding helpers ─────────────────────────────────────────
|
||||
|
||||
/// Decode an A record (4 bytes -> IPv4 string).
|
||||
pub fn decode_a(rdata: &[u8]) -> Result<String, &'static str> {
|
||||
if rdata.len() < 4 {
|
||||
return Err("A rdata too short");
|
||||
}
|
||||
Ok(format!("{}.{}.{}.{}", rdata[0], rdata[1], rdata[2], rdata[3]))
|
||||
}
|
||||
|
||||
/// Decode an AAAA record (16 bytes -> IPv6 string).
|
||||
pub fn decode_aaaa(rdata: &[u8]) -> Result<String, &'static str> {
|
||||
if rdata.len() < 16 {
|
||||
return Err("AAAA rdata too short");
|
||||
}
|
||||
let groups: Vec<String> = (0..8)
|
||||
.map(|i| {
|
||||
let val = u16::from_be_bytes([rdata[i * 2], rdata[i * 2 + 1]]);
|
||||
format!("{:x}", val)
|
||||
})
|
||||
.collect();
|
||||
// Build full form, then compress :: notation
|
||||
let full = groups.join(":");
|
||||
compress_ipv6(&full)
|
||||
}
|
||||
|
||||
/// Compress a full IPv6 address to shortest form.
|
||||
fn compress_ipv6(full: &str) -> Result<String, &'static str> {
|
||||
let groups: Vec<&str> = full.split(':').collect();
|
||||
if groups.len() != 8 {
|
||||
return Ok(full.to_string());
|
||||
}
|
||||
|
||||
// Find longest run of consecutive "0" groups
|
||||
let mut best_start = None;
|
||||
let mut best_len = 0usize;
|
||||
let mut cur_start = None;
|
||||
let mut cur_len = 0usize;
|
||||
|
||||
for (i, g) in groups.iter().enumerate() {
|
||||
if *g == "0" {
|
||||
if cur_start.is_none() {
|
||||
cur_start = Some(i);
|
||||
cur_len = 1;
|
||||
} else {
|
||||
cur_len += 1;
|
||||
}
|
||||
if cur_len > best_len {
|
||||
best_start = cur_start;
|
||||
best_len = cur_len;
|
||||
}
|
||||
} else {
|
||||
cur_start = None;
|
||||
cur_len = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if best_len >= 2 {
|
||||
let bs = best_start.unwrap();
|
||||
let left: Vec<&str> = groups[..bs].to_vec();
|
||||
let right: Vec<&str> = groups[bs + best_len..].to_vec();
|
||||
let l = left.join(":");
|
||||
let r = right.join(":");
|
||||
if l.is_empty() && r.is_empty() {
|
||||
Ok("::".to_string())
|
||||
} else if l.is_empty() {
|
||||
Ok(format!("::{}", r))
|
||||
} else if r.is_empty() {
|
||||
Ok(format!("{}::", l))
|
||||
} else {
|
||||
Ok(format!("{}::{}", l, r))
|
||||
}
|
||||
} else {
|
||||
Ok(full.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode a TXT record (length-prefixed chunks -> strings).
|
||||
pub fn decode_txt(rdata: &[u8]) -> Result<Vec<String>, &'static str> {
|
||||
let mut strings = Vec::new();
|
||||
let mut pos = 0;
|
||||
while pos < rdata.len() {
|
||||
let len = rdata[pos] as usize;
|
||||
pos += 1;
|
||||
if pos + len > rdata.len() {
|
||||
return Err("TXT chunk extends beyond rdata");
|
||||
}
|
||||
let s = std::str::from_utf8(&rdata[pos..pos + len])
|
||||
.map_err(|_| "invalid UTF-8 in TXT")?;
|
||||
strings.push(s.to_string());
|
||||
pos += len;
|
||||
}
|
||||
Ok(strings)
|
||||
}
|
||||
|
||||
/// Decode an MX record (preference + exchange name with compression).
|
||||
pub fn decode_mx(rdata: &[u8], packet: &[u8], rdata_offset: usize) -> Result<(u16, String), String> {
|
||||
if rdata.len() < 3 {
|
||||
return Err("MX rdata too short".into());
|
||||
}
|
||||
let preference = u16::from_be_bytes([rdata[0], rdata[1]]);
|
||||
let (name, _) = decode_name(packet, rdata_offset + 2).map_err(|e| e.to_string())?;
|
||||
Ok((preference, name))
|
||||
}
|
||||
|
||||
/// Decode a name from RDATA (for NS, CNAME, PTR records with compression).
|
||||
pub fn decode_name_rdata(_rdata: &[u8], packet: &[u8], rdata_offset: usize) -> Result<String, String> {
|
||||
let (name, _) = decode_name(packet, rdata_offset).map_err(|e| e.to_string())?;
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
/// SOA record decoded fields.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SoaData {
|
||||
pub mname: String,
|
||||
pub rname: String,
|
||||
pub serial: u32,
|
||||
pub refresh: u32,
|
||||
pub retry: u32,
|
||||
pub expire: u32,
|
||||
pub minimum: u32,
|
||||
}
|
||||
|
||||
/// Decode a SOA record RDATA.
|
||||
pub fn decode_soa(rdata: &[u8], packet: &[u8], rdata_offset: usize) -> Result<SoaData, String> {
|
||||
let (mname, consumed1) = decode_name(packet, rdata_offset).map_err(|e| e.to_string())?;
|
||||
let (rname, consumed2) = decode_name(packet, rdata_offset + consumed1).map_err(|e| e.to_string())?;
|
||||
let nums_offset = consumed1 + consumed2;
|
||||
if rdata.len() < nums_offset + 20 {
|
||||
return Err("SOA rdata too short for numeric fields".into());
|
||||
}
|
||||
let serial = u32::from_be_bytes([
|
||||
rdata[nums_offset], rdata[nums_offset + 1],
|
||||
rdata[nums_offset + 2], rdata[nums_offset + 3],
|
||||
]);
|
||||
let refresh = u32::from_be_bytes([
|
||||
rdata[nums_offset + 4], rdata[nums_offset + 5],
|
||||
rdata[nums_offset + 6], rdata[nums_offset + 7],
|
||||
]);
|
||||
let retry = u32::from_be_bytes([
|
||||
rdata[nums_offset + 8], rdata[nums_offset + 9],
|
||||
rdata[nums_offset + 10], rdata[nums_offset + 11],
|
||||
]);
|
||||
let expire = u32::from_be_bytes([
|
||||
rdata[nums_offset + 12], rdata[nums_offset + 13],
|
||||
rdata[nums_offset + 14], rdata[nums_offset + 15],
|
||||
]);
|
||||
let minimum = u32::from_be_bytes([
|
||||
rdata[nums_offset + 16], rdata[nums_offset + 17],
|
||||
rdata[nums_offset + 18], rdata[nums_offset + 19],
|
||||
]);
|
||||
Ok(SoaData { mname, rname, serial, refresh, retry, expire, minimum })
|
||||
}
|
||||
|
||||
/// SRV record decoded fields.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SrvData {
|
||||
pub priority: u16,
|
||||
pub weight: u16,
|
||||
pub port: u16,
|
||||
pub target: String,
|
||||
}
|
||||
|
||||
/// Decode a SRV record RDATA.
|
||||
pub fn decode_srv(rdata: &[u8], packet: &[u8], rdata_offset: usize) -> Result<SrvData, String> {
|
||||
if rdata.len() < 7 {
|
||||
return Err("SRV rdata too short".into());
|
||||
}
|
||||
let priority = u16::from_be_bytes([rdata[0], rdata[1]]);
|
||||
let weight = u16::from_be_bytes([rdata[2], rdata[3]]);
|
||||
let port = u16::from_be_bytes([rdata[4], rdata[5]]);
|
||||
let (target, _) = decode_name(packet, rdata_offset + 6).map_err(|e| e.to_string())?;
|
||||
Ok(SrvData { priority, weight, port, target })
|
||||
}
|
||||
|
||||
/// Build a DnsRecord from high-level data.
|
||||
pub fn build_record(name: &str, rtype: QType, ttl: u32, rdata: Vec<u8>) -> DnsRecord {
|
||||
DnsRecord {
|
||||
@@ -416,6 +601,45 @@ mod tests {
|
||||
assert_eq!(&data[7..12], b"world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_a() {
|
||||
let rdata = encode_a("192.168.1.1");
|
||||
let decoded = decode_a(&rdata).unwrap();
|
||||
assert_eq!(decoded, "192.168.1.1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_aaaa() {
|
||||
let rdata = encode_aaaa("::1");
|
||||
let decoded = decode_aaaa(&rdata).unwrap();
|
||||
assert_eq!(decoded, "::1");
|
||||
|
||||
let rdata2 = encode_aaaa("2001:db8::1");
|
||||
let decoded2 = decode_aaaa(&rdata2).unwrap();
|
||||
assert_eq!(decoded2, "2001:db8::1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_txt() {
|
||||
let strings = vec!["hello".to_string(), "world".to_string()];
|
||||
let rdata = encode_txt(&strings);
|
||||
let decoded = decode_txt(&rdata).unwrap();
|
||||
assert_eq!(decoded, strings);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rcode_and_ad_flag() {
|
||||
let mut pkt = DnsPacket::new_query(1);
|
||||
assert_eq!(pkt.rcode(), 0);
|
||||
assert!(!pkt.has_ad_flag());
|
||||
|
||||
pkt.flags |= crate::types::FLAG_AD;
|
||||
assert!(pkt.has_ad_flag());
|
||||
|
||||
pkt.flags |= 0x0003; // NXDOMAIN
|
||||
assert_eq!(pkt.rcode(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dnssec_do_bit() {
|
||||
let mut query = DnsPacket::new_query(1);
|
||||
|
||||
@@ -127,5 +127,8 @@ pub const FLAG_AA: u16 = 0x0400;
|
||||
pub const FLAG_RD: u16 = 0x0100;
|
||||
pub const FLAG_RA: u16 = 0x0080;
|
||||
|
||||
/// Authenticated Data flag
|
||||
pub const FLAG_AD: u16 = 0x0020;
|
||||
|
||||
/// OPT record DO bit (DNSSEC OK)
|
||||
pub const EDNS_DO_BIT: u16 = 0x8000;
|
||||
|
||||
Reference in New Issue
Block a user