feat(rust): add Rust-based DNS server backend with IPC management and TypeScript bridge
This commit is contained in:
125
rust/crates/rustdns/src/ipc_types.rs
Normal file
125
rust/crates/rustdns/src/ipc_types.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// IPC request from TypeScript to Rust (via stdin).
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct IpcRequest {
|
||||
pub id: String,
|
||||
pub method: String,
|
||||
#[serde(default)]
|
||||
pub params: serde_json::Value,
|
||||
}
|
||||
|
||||
/// IPC response from Rust to TypeScript (via stdout).
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IpcResponse {
|
||||
pub id: String,
|
||||
pub success: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl IpcResponse {
|
||||
pub fn ok(id: String, result: serde_json::Value) -> Self {
|
||||
IpcResponse {
|
||||
id,
|
||||
success: true,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(id: String, error: String) -> Self {
|
||||
IpcResponse {
|
||||
id,
|
||||
success: false,
|
||||
result: None,
|
||||
error: Some(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// IPC event from Rust to TypeScript (unsolicited, no id).
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IpcEvent {
|
||||
pub event: String,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Configuration sent via the "start" command.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RustDnsConfig {
|
||||
pub udp_port: u16,
|
||||
pub https_port: u16,
|
||||
#[serde(default = "default_bind")]
|
||||
pub udp_bind_interface: String,
|
||||
#[serde(default = "default_bind")]
|
||||
pub https_bind_interface: String,
|
||||
#[serde(default)]
|
||||
pub https_key: String,
|
||||
#[serde(default)]
|
||||
pub https_cert: String,
|
||||
pub dnssec_zone: String,
|
||||
#[serde(default = "default_algorithm")]
|
||||
pub dnssec_algorithm: String,
|
||||
#[serde(default)]
|
||||
pub primary_nameserver: String,
|
||||
#[serde(default = "default_true")]
|
||||
pub enable_localhost_handling: bool,
|
||||
#[serde(default)]
|
||||
pub manual_udp_mode: bool,
|
||||
#[serde(default)]
|
||||
pub manual_https_mode: bool,
|
||||
}
|
||||
|
||||
fn default_bind() -> String {
|
||||
"0.0.0.0".to_string()
|
||||
}
|
||||
|
||||
fn default_algorithm() -> String {
|
||||
"ECDSA".to_string()
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// A DNS question as sent over IPC.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct IpcDnsQuestion {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub qtype: String,
|
||||
pub class: String,
|
||||
}
|
||||
|
||||
/// A DNS answer as received from TypeScript over IPC.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct IpcDnsAnswer {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub rtype: String,
|
||||
pub class: String,
|
||||
pub ttl: u32,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
/// The dnsQuery event sent from Rust to TypeScript.
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DnsQueryEvent {
|
||||
pub correlation_id: String,
|
||||
pub questions: Vec<IpcDnsQuestion>,
|
||||
pub dnssec_requested: bool,
|
||||
}
|
||||
|
||||
/// The dnsQueryResult command from TypeScript to Rust.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DnsQueryResult {
|
||||
pub correlation_id: String,
|
||||
pub answers: Vec<IpcDnsAnswer>,
|
||||
pub answered: bool,
|
||||
}
|
||||
3
rust/crates/rustdns/src/lib.rs
Normal file
3
rust/crates/rustdns/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod management;
|
||||
pub mod ipc_types;
|
||||
pub mod resolver;
|
||||
36
rust/crates/rustdns/src/main.rs
Normal file
36
rust/crates/rustdns/src/main.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use clap::Parser;
|
||||
use tracing_subscriber;
|
||||
|
||||
mod management;
|
||||
mod ipc_types;
|
||||
mod resolver;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "rustdns", about = "Rust DNS server with IPC management")]
|
||||
struct Cli {
|
||||
/// Run in management mode (IPC via stdin/stdout)
|
||||
#[arg(long)]
|
||||
management: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Install the default rustls crypto provider (ring) before any TLS operations
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Tracing writes to stderr so stdout is reserved for IPC
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
if cli.management {
|
||||
management::management_loop().await?;
|
||||
} else {
|
||||
eprintln!("rustdns: use --management flag for IPC mode");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
402
rust/crates/rustdns/src/management.rs
Normal file
402
rust/crates/rustdns/src/management.rs
Normal file
@@ -0,0 +1,402 @@
|
||||
use crate::ipc_types::*;
|
||||
use crate::resolver::DnsResolver;
|
||||
use dashmap::DashMap;
|
||||
use rustdns_dnssec::keys::DnssecAlgorithm;
|
||||
use rustdns_protocol::packet::DnsPacket;
|
||||
use rustdns_server::https::{self, HttpsServer};
|
||||
use rustdns_server::udp::{UdpServer, UdpServerConfig};
|
||||
use std::io::{self, BufRead, Write};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Pending DNS query callbacks waiting for TypeScript response.
|
||||
type PendingCallbacks = Arc<DashMap<String, oneshot::Sender<DnsQueryResult>>>;
|
||||
|
||||
/// Active server state.
|
||||
struct ServerState {
|
||||
udp_server: Option<UdpServer>,
|
||||
https_server: Option<HttpsServer>,
|
||||
resolver: Arc<DnsResolver>,
|
||||
}
|
||||
|
||||
/// Emit a JSON event on stdout.
|
||||
fn send_event(event: &str, data: serde_json::Value) {
|
||||
let evt = IpcEvent {
|
||||
event: event.to_string(),
|
||||
data,
|
||||
};
|
||||
let json = serde_json::to_string(&evt).unwrap();
|
||||
let stdout = io::stdout();
|
||||
let mut lock = stdout.lock();
|
||||
let _ = writeln!(lock, "{}", json);
|
||||
let _ = lock.flush();
|
||||
}
|
||||
|
||||
/// Send a JSON response on stdout.
|
||||
fn send_response(response: &IpcResponse) {
|
||||
let json = serde_json::to_string(response).unwrap();
|
||||
let stdout = io::stdout();
|
||||
let mut lock = stdout.lock();
|
||||
let _ = writeln!(lock, "{}", json);
|
||||
let _ = lock.flush();
|
||||
}
|
||||
|
||||
/// Main management loop — reads JSON lines from stdin, dispatches commands.
|
||||
pub async fn management_loop() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Emit ready event
|
||||
send_event("ready", serde_json::json!({
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}));
|
||||
|
||||
let pending: PendingCallbacks = Arc::new(DashMap::new());
|
||||
let mut server_state: Option<ServerState> = None;
|
||||
|
||||
// Channel for stdin commands (read in blocking thread)
|
||||
let (cmd_tx, mut cmd_rx) = mpsc::channel::<String>(256);
|
||||
|
||||
// Channel for DNS query events from the server
|
||||
let (query_tx, mut query_rx) = mpsc::channel::<(String, DnsPacket)>(256);
|
||||
|
||||
// Spawn blocking stdin reader
|
||||
std::thread::spawn(move || {
|
||||
let stdin = io::stdin();
|
||||
let reader = stdin.lock();
|
||||
for line in reader.lines() {
|
||||
match line {
|
||||
Ok(l) => {
|
||||
if cmd_tx.blocking_send(l).is_err() {
|
||||
break; // channel closed
|
||||
}
|
||||
}
|
||||
Err(_) => break, // stdin closed
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
cmd = cmd_rx.recv() => {
|
||||
match cmd {
|
||||
Some(line) => {
|
||||
let request: IpcRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Failed to parse IPC request: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let response = handle_request(
|
||||
&request,
|
||||
&mut server_state,
|
||||
&pending,
|
||||
&query_tx,
|
||||
).await;
|
||||
send_response(&response);
|
||||
}
|
||||
None => {
|
||||
// stdin closed — parent process exited
|
||||
info!("stdin closed, shutting down");
|
||||
if let Some(ref state) = server_state {
|
||||
if let Some(ref udp) = state.udp_server {
|
||||
udp.stop();
|
||||
}
|
||||
if let Some(ref https) = state.https_server {
|
||||
https.stop();
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
query = query_rx.recv() => {
|
||||
if let Some((correlation_id, packet)) = query {
|
||||
let dnssec = packet.is_dnssec_requested();
|
||||
let questions = DnsResolver::questions_to_ipc(&packet.questions);
|
||||
|
||||
send_event("dnsQuery", serde_json::to_value(&DnsQueryEvent {
|
||||
correlation_id,
|
||||
questions,
|
||||
dnssec_requested: dnssec,
|
||||
}).unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
request: &IpcRequest,
|
||||
server_state: &mut Option<ServerState>,
|
||||
pending: &PendingCallbacks,
|
||||
query_tx: &mpsc::Sender<(String, DnsPacket)>,
|
||||
) -> IpcResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"ping" => IpcResponse::ok(id, serde_json::json!({ "pong": true })),
|
||||
|
||||
"start" => {
|
||||
handle_start(id, &request.params, server_state, pending, query_tx).await
|
||||
}
|
||||
|
||||
"stop" => {
|
||||
handle_stop(id, server_state)
|
||||
}
|
||||
|
||||
"dnsQueryResult" => {
|
||||
handle_query_result(id, &request.params, pending)
|
||||
}
|
||||
|
||||
"updateCerts" => {
|
||||
// TODO: hot-swap TLS certs (requires rustls cert resolver)
|
||||
IpcResponse::ok(id, serde_json::json!({}))
|
||||
}
|
||||
|
||||
"processPacket" => {
|
||||
handle_process_packet(id, &request.params, server_state, pending, query_tx).await
|
||||
}
|
||||
|
||||
_ => IpcResponse::err(id, format!("Unknown method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_start(
|
||||
id: String,
|
||||
params: &serde_json::Value,
|
||||
server_state: &mut Option<ServerState>,
|
||||
pending: &PendingCallbacks,
|
||||
query_tx: &mpsc::Sender<(String, DnsPacket)>,
|
||||
) -> IpcResponse {
|
||||
let config: RustDnsConfig = match serde_json::from_value(params.get("config").cloned().unwrap_or_default()) {
|
||||
Ok(c) => c,
|
||||
Err(e) => return IpcResponse::err(id, format!("Invalid config: {}", e)),
|
||||
};
|
||||
|
||||
let algorithm = DnssecAlgorithm::from_str(&config.dnssec_algorithm)
|
||||
.unwrap_or(DnssecAlgorithm::EcdsaP256Sha256);
|
||||
|
||||
let resolver = Arc::new(DnsResolver::new(
|
||||
&config.dnssec_zone,
|
||||
algorithm,
|
||||
&config.primary_nameserver,
|
||||
config.enable_localhost_handling,
|
||||
));
|
||||
|
||||
// Start UDP server if not manual mode
|
||||
let udp_server = if !config.manual_udp_mode {
|
||||
let addr: SocketAddr = format!("{}:{}", config.udp_bind_interface, config.udp_port)
|
||||
.parse()
|
||||
.unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], config.udp_port)));
|
||||
|
||||
let resolver_clone = resolver.clone();
|
||||
let pending_clone = pending.clone();
|
||||
let query_tx_clone = query_tx.clone();
|
||||
|
||||
match UdpServer::start(
|
||||
UdpServerConfig { bind_addr: addr },
|
||||
move |packet| {
|
||||
let resolver = resolver_clone.clone();
|
||||
let pending = pending_clone.clone();
|
||||
let query_tx = query_tx_clone.clone();
|
||||
async move {
|
||||
resolve_with_callback(packet, &resolver, &pending, &query_tx).await
|
||||
}
|
||||
},
|
||||
).await {
|
||||
Ok(server) => {
|
||||
info!("UDP DNS server started on {}", addr);
|
||||
Some(server)
|
||||
}
|
||||
Err(e) => {
|
||||
return IpcResponse::err(id, format!("Failed to start UDP server: {}", e));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Start HTTPS server if not manual mode and certs are provided
|
||||
let https_server = if !config.manual_https_mode && !config.https_cert.is_empty() && !config.https_key.is_empty() {
|
||||
let addr: SocketAddr = format!("{}:{}", config.https_bind_interface, config.https_port)
|
||||
.parse()
|
||||
.unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], config.https_port)));
|
||||
|
||||
match https::create_tls_config(&config.https_cert, &config.https_key) {
|
||||
Ok(tls_config) => {
|
||||
let resolver_clone = resolver.clone();
|
||||
let pending_clone = pending.clone();
|
||||
let query_tx_clone = query_tx.clone();
|
||||
|
||||
match HttpsServer::start(
|
||||
https::HttpsServerConfig {
|
||||
bind_addr: addr,
|
||||
tls_config,
|
||||
},
|
||||
move |packet| {
|
||||
let resolver = resolver_clone.clone();
|
||||
let pending = pending_clone.clone();
|
||||
let query_tx = query_tx_clone.clone();
|
||||
async move {
|
||||
resolve_with_callback(packet, &resolver, &pending, &query_tx).await
|
||||
}
|
||||
},
|
||||
).await {
|
||||
Ok(server) => {
|
||||
info!("HTTPS DoH server started on {}", addr);
|
||||
Some(server)
|
||||
}
|
||||
Err(e) => {
|
||||
return IpcResponse::err(id, format!("Failed to start HTTPS server: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return IpcResponse::err(id, format!("Failed to configure TLS: {}", e));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
*server_state = Some(ServerState {
|
||||
udp_server,
|
||||
https_server,
|
||||
resolver,
|
||||
});
|
||||
|
||||
send_event("started", serde_json::json!({}));
|
||||
IpcResponse::ok(id, serde_json::json!({}))
|
||||
}
|
||||
|
||||
fn handle_stop(id: String, server_state: &mut Option<ServerState>) -> IpcResponse {
|
||||
if let Some(ref state) = server_state {
|
||||
if let Some(ref udp) = state.udp_server {
|
||||
udp.stop();
|
||||
}
|
||||
if let Some(ref https) = state.https_server {
|
||||
https.stop();
|
||||
}
|
||||
}
|
||||
*server_state = None;
|
||||
send_event("stopped", serde_json::json!({}));
|
||||
IpcResponse::ok(id, serde_json::json!({}))
|
||||
}
|
||||
|
||||
fn handle_query_result(
|
||||
id: String,
|
||||
params: &serde_json::Value,
|
||||
pending: &PendingCallbacks,
|
||||
) -> IpcResponse {
|
||||
let result: DnsQueryResult = match serde_json::from_value(params.clone()) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return IpcResponse::err(id, format!("Invalid query result: {}", e)),
|
||||
};
|
||||
|
||||
let correlation_id = result.correlation_id.clone();
|
||||
if let Some((_, sender)) = pending.remove(&correlation_id) {
|
||||
let _ = sender.send(result);
|
||||
IpcResponse::ok(id, serde_json::json!({ "resolved": true }))
|
||||
} else {
|
||||
IpcResponse::err(id, format!("No pending query for correlationId: {}", correlation_id))
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_process_packet(
|
||||
id: String,
|
||||
params: &serde_json::Value,
|
||||
server_state: &mut Option<ServerState>,
|
||||
pending: &PendingCallbacks,
|
||||
query_tx: &mpsc::Sender<(String, DnsPacket)>,
|
||||
) -> IpcResponse {
|
||||
let packet_b64 = match params.get("packet").and_then(|v| v.as_str()) {
|
||||
Some(p) => p,
|
||||
None => return IpcResponse::err(id, "Missing packet parameter".to_string()),
|
||||
};
|
||||
|
||||
let packet_data = match base64_decode(packet_b64) {
|
||||
Ok(d) => d,
|
||||
Err(e) => return IpcResponse::err(id, format!("Invalid base64: {}", e)),
|
||||
};
|
||||
|
||||
let state = match server_state {
|
||||
Some(ref s) => s,
|
||||
None => return IpcResponse::err(id, "Server not started".to_string()),
|
||||
};
|
||||
|
||||
let request = match DnsPacket::parse(&packet_data) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return IpcResponse::err(id, format!("Failed to parse packet: {}", e)),
|
||||
};
|
||||
|
||||
let response = resolve_with_callback(request, &state.resolver, pending, query_tx).await;
|
||||
let encoded = response.encode();
|
||||
|
||||
use base64::Engine;
|
||||
let response_b64 = base64::engine::general_purpose::STANDARD.encode(&encoded);
|
||||
IpcResponse::ok(id, serde_json::json!({ "packet": response_b64 }))
|
||||
}
|
||||
|
||||
/// Core resolution: try local first, then IPC callback to TypeScript.
|
||||
async fn resolve_with_callback(
|
||||
packet: DnsPacket,
|
||||
resolver: &DnsResolver,
|
||||
pending: &PendingCallbacks,
|
||||
query_tx: &mpsc::Sender<(String, DnsPacket)>,
|
||||
) -> DnsPacket {
|
||||
// Try local resolution first (localhost, DNSKEY)
|
||||
if let Some(response) = resolver.try_local_resolution(&packet) {
|
||||
return response;
|
||||
}
|
||||
|
||||
// Need IPC callback to TypeScript
|
||||
let correlation_id = format!("dns_{}", uuid_v4());
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
pending.insert(correlation_id.clone(), tx);
|
||||
|
||||
// Send the query event to the management loop for emission
|
||||
if query_tx.send((correlation_id.clone(), packet.clone())).await.is_err() {
|
||||
pending.remove(&correlation_id);
|
||||
return DnsPacket::new_response(&packet);
|
||||
}
|
||||
|
||||
// Wait for the result with a timeout
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
|
||||
Ok(Ok(result)) => {
|
||||
resolver.build_response_from_answers(&packet, &result.answers, result.answered)
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
// Sender dropped
|
||||
pending.remove(&correlation_id);
|
||||
resolver.build_response_from_answers(&packet, &[], false)
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout
|
||||
pending.remove(&correlation_id);
|
||||
resolver.build_response_from_answers(&packet, &[], false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple UUID v4 generation (no external dep needed).
|
||||
fn uuid_v4() -> String {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let random: u64 = nanos as u64 ^ (std::process::id() as u64 * 0x517cc1b727220a95);
|
||||
format!("{:016x}{:016x}", nanos as u64, random)
|
||||
}
|
||||
|
||||
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::STANDARD
|
||||
.decode(input)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
258
rust/crates/rustdns/src/resolver.rs
Normal file
258
rust/crates/rustdns/src/resolver.rs
Normal file
@@ -0,0 +1,258 @@
|
||||
use crate::ipc_types::{IpcDnsAnswer, IpcDnsQuestion};
|
||||
use rustdns_protocol::packet::*;
|
||||
use rustdns_protocol::types::QType;
|
||||
use rustdns_dnssec::keys::{DnssecAlgorithm, DnssecKeyPair};
|
||||
use rustdns_dnssec::keytag::compute_key_tag;
|
||||
use rustdns_dnssec::signing::generate_rrsig;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// DNS resolver that builds responses from IPC callback answers.
|
||||
pub struct DnsResolver {
|
||||
pub zone: String,
|
||||
pub primary_nameserver: String,
|
||||
pub enable_localhost: bool,
|
||||
pub key_pair: DnssecKeyPair,
|
||||
pub dnskey_rdata: Vec<u8>,
|
||||
pub key_tag: u16,
|
||||
}
|
||||
|
||||
impl DnsResolver {
|
||||
pub fn new(zone: &str, algorithm: DnssecAlgorithm, primary_nameserver: &str, enable_localhost: bool) -> Self {
|
||||
let key_pair = DnssecKeyPair::generate(algorithm);
|
||||
let dnskey_rdata = key_pair.dnskey_rdata();
|
||||
let key_tag = compute_key_tag(&dnskey_rdata);
|
||||
|
||||
let primary_ns = if primary_nameserver.is_empty() {
|
||||
format!("ns1.{}", zone)
|
||||
} else {
|
||||
primary_nameserver.to_string()
|
||||
};
|
||||
|
||||
DnsResolver {
|
||||
zone: zone.to_string(),
|
||||
primary_nameserver: primary_ns,
|
||||
enable_localhost,
|
||||
key_pair,
|
||||
dnskey_rdata,
|
||||
key_tag,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a query can be answered locally (localhost, DNSKEY).
|
||||
/// Returns Some(answers) if handled locally, None if it needs IPC callback.
|
||||
pub fn try_local_resolution(&self, packet: &DnsPacket) -> Option<DnsPacket> {
|
||||
let dnssec = packet.is_dnssec_requested();
|
||||
let mut response = DnsPacket::new_response(packet);
|
||||
let mut all_local = true;
|
||||
|
||||
for q in &packet.questions {
|
||||
if let Some(records) = self.try_local_question(q, dnssec) {
|
||||
for r in records {
|
||||
response.answers.push(r);
|
||||
}
|
||||
} else {
|
||||
all_local = false;
|
||||
}
|
||||
}
|
||||
|
||||
if all_local && !packet.questions.is_empty() {
|
||||
Some(response)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn try_local_question(&self, q: &DnsQuestion, dnssec: bool) -> Option<Vec<DnsRecord>> {
|
||||
let name_lower = q.name.to_lowercase();
|
||||
let name_trimmed = name_lower.strip_suffix('.').unwrap_or(&name_lower);
|
||||
|
||||
// DNSKEY queries for our zone
|
||||
if dnssec && q.qtype == QType::DNSKEY && name_trimmed == self.zone.to_lowercase() {
|
||||
let record = build_record(&q.name, QType::DNSKEY, 3600, self.dnskey_rdata.clone());
|
||||
let mut records = vec![record.clone()];
|
||||
// Sign the DNSKEY record
|
||||
let rrsig = generate_rrsig(&self.key_pair, &self.zone, &[record], &q.name, QType::DNSKEY);
|
||||
records.push(rrsig);
|
||||
return Some(records);
|
||||
}
|
||||
|
||||
// Localhost handling (RFC 6761)
|
||||
if self.enable_localhost {
|
||||
if name_trimmed == "localhost" {
|
||||
match q.qtype {
|
||||
QType::A => {
|
||||
return Some(vec![build_record(&q.name, QType::A, 0, encode_a("127.0.0.1"))]);
|
||||
}
|
||||
QType::AAAA => {
|
||||
return Some(vec![build_record(&q.name, QType::AAAA, 0, encode_aaaa("::1"))]);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse localhost
|
||||
if name_trimmed == "1.0.0.127.in-addr.arpa" && q.qtype == QType::PTR {
|
||||
return Some(vec![build_record(&q.name, QType::PTR, 0, encode_name_rdata("localhost."))]);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Build a response from IPC callback answers.
|
||||
pub fn build_response_from_answers(
|
||||
&self,
|
||||
request: &DnsPacket,
|
||||
answers: &[IpcDnsAnswer],
|
||||
answered: bool,
|
||||
) -> DnsPacket {
|
||||
let dnssec = request.is_dnssec_requested();
|
||||
let mut response = DnsPacket::new_response(request);
|
||||
|
||||
if answered && !answers.is_empty() {
|
||||
// Group answers by (name, type) for DNSSEC RRset signing
|
||||
let mut rrset_map: HashMap<(String, QType), Vec<DnsRecord>> = HashMap::new();
|
||||
|
||||
for answer in answers {
|
||||
let rtype = QType::from_str(&answer.rtype);
|
||||
let rdata = self.encode_answer_rdata(rtype, &answer.data);
|
||||
let record = build_record(&answer.name, rtype, answer.ttl, rdata);
|
||||
response.answers.push(record.clone());
|
||||
|
||||
if dnssec {
|
||||
let key = (answer.name.clone(), rtype);
|
||||
rrset_map.entry(key).or_default().push(record);
|
||||
}
|
||||
}
|
||||
|
||||
// Sign RRsets
|
||||
if dnssec {
|
||||
for ((name, rtype), rrset) in &rrset_map {
|
||||
let rrsig = generate_rrsig(&self.key_pair, &self.zone, rrset, name, *rtype);
|
||||
response.answers.push(rrsig);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No handler matched — return SOA
|
||||
for q in &request.questions {
|
||||
let soa_rdata = encode_soa(
|
||||
&self.primary_nameserver,
|
||||
&format!("hostmaster.{}", self.zone),
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as u32,
|
||||
3600,
|
||||
600,
|
||||
604800,
|
||||
86400,
|
||||
);
|
||||
let soa_record = build_record(&q.name, QType::SOA, 3600, soa_rdata);
|
||||
response.answers.push(soa_record.clone());
|
||||
|
||||
if dnssec {
|
||||
let rrsig = generate_rrsig(&self.key_pair, &self.zone, &[soa_record], &q.name, QType::SOA);
|
||||
response.answers.push(rrsig);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// Process a raw DNS packet (for manual/passthrough mode).
|
||||
/// Returns local answers or None if IPC callback is needed.
|
||||
pub fn process_packet_local(&self, data: &[u8]) -> Result<Option<Vec<u8>>, String> {
|
||||
let packet = DnsPacket::parse(data)?;
|
||||
if let Some(response) = self.try_local_resolution(&packet) {
|
||||
Ok(Some(response.encode()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_answer_rdata(&self, rtype: QType, data: &serde_json::Value) -> Vec<u8> {
|
||||
match rtype {
|
||||
QType::A => {
|
||||
if let Some(ip) = data.as_str() {
|
||||
encode_a(ip)
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
QType::AAAA => {
|
||||
if let Some(ip) = data.as_str() {
|
||||
encode_aaaa(ip)
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
QType::TXT => {
|
||||
if let Some(arr) = data.as_array() {
|
||||
let strings: Vec<String> = arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect();
|
||||
encode_txt(&strings)
|
||||
} else if let Some(s) = data.as_str() {
|
||||
encode_txt(&[s.to_string()])
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
QType::NS | QType::CNAME | QType::PTR => {
|
||||
if let Some(name) = data.as_str() {
|
||||
encode_name_rdata(name)
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
QType::MX => {
|
||||
let preference = data.get("preference").and_then(|v| v.as_u64()).unwrap_or(10) as u16;
|
||||
let exchange = data.get("exchange").and_then(|v| v.as_str()).unwrap_or("");
|
||||
encode_mx(preference, exchange)
|
||||
}
|
||||
QType::SRV => {
|
||||
let priority = data.get("priority").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let weight = data.get("weight").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let port = data.get("port").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let target = data.get("target").and_then(|v| v.as_str()).unwrap_or("");
|
||||
encode_srv(priority, weight, port, target)
|
||||
}
|
||||
QType::SOA => {
|
||||
let mname = data.get("mname").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let rname = data.get("rname").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let serial = data.get("serial").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
||||
let refresh = data.get("refresh").and_then(|v| v.as_u64()).unwrap_or(3600) as u32;
|
||||
let retry = data.get("retry").and_then(|v| v.as_u64()).unwrap_or(600) as u32;
|
||||
let expire = data.get("expire").and_then(|v| v.as_u64()).unwrap_or(604800) as u32;
|
||||
let minimum = data.get("minimum").and_then(|v| v.as_u64()).unwrap_or(86400) as u32;
|
||||
encode_soa(mname, rname, serial, refresh, retry, expire, minimum)
|
||||
}
|
||||
_ => {
|
||||
// For unknown types, try to interpret as raw base64
|
||||
if let Some(b64) = data.as_str() {
|
||||
base64_decode(b64).unwrap_or_default()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert questions to IPC format.
|
||||
pub fn questions_to_ipc(questions: &[DnsQuestion]) -> Vec<IpcDnsQuestion> {
|
||||
questions
|
||||
.iter()
|
||||
.map(|q| IpcDnsQuestion {
|
||||
name: q.name.clone(),
|
||||
qtype: q.qtype.as_str().to_string(),
|
||||
class: "IN".to_string(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::STANDARD
|
||||
.decode(input)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
Reference in New Issue
Block a user