feat(rust): add Rust-based DNS server backend with IPC management and TypeScript bridge

This commit is contained in:
2026-02-11 11:24:10 +00:00
parent abbb971d6a
commit 60371e1ad5
37 changed files with 4509 additions and 1272 deletions

2
rust/.cargo/config.toml Normal file
View File

@@ -0,0 +1,2 @@
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"

1446
rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

8
rust/Cargo.toml Normal file
View File

@@ -0,0 +1,8 @@
[workspace]
resolver = "2"
members = [
"crates/rustdns",
"crates/rustdns-protocol",
"crates/rustdns-server",
"crates/rustdns-dnssec",
]

View File

@@ -0,0 +1,11 @@
[package]
name = "rustdns-dnssec"
version = "0.1.0"
edition = "2021"
[dependencies]
rustdns-protocol = { path = "../rustdns-protocol" }
p256 = { version = "0.13", features = ["ecdsa", "ecdsa-core"] }
ed25519-dalek = { version = "2", features = ["rand_core"] }
sha2 = "0.10"
rand = "0.8"

View File

@@ -0,0 +1,157 @@
use p256::ecdsa::SigningKey as EcdsaSigningKey;
use ed25519_dalek::SigningKey as Ed25519SigningKey;
use rand::rngs::OsRng;
/// Supported DNSSEC algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnssecAlgorithm {
/// ECDSA P-256 with SHA-256 (algorithm 13)
EcdsaP256Sha256,
/// ED25519 (algorithm 15)
Ed25519,
}
impl DnssecAlgorithm {
pub fn number(&self) -> u8 {
match self {
DnssecAlgorithm::EcdsaP256Sha256 => 13,
DnssecAlgorithm::Ed25519 => 15,
}
}
pub fn from_str(s: &str) -> Option<Self> {
match s.to_uppercase().as_str() {
"ECDSA" | "ECDSAP256SHA256" => Some(DnssecAlgorithm::EcdsaP256Sha256),
"ED25519" => Some(DnssecAlgorithm::Ed25519),
_ => None,
}
}
}
/// A DNSSEC key pair with material for signing and DNSKEY generation.
pub enum DnssecKeyPair {
EcdsaP256 {
signing_key: EcdsaSigningKey,
},
Ed25519 {
signing_key: Ed25519SigningKey,
},
}
impl DnssecKeyPair {
/// Generate a new key pair for the given algorithm.
pub fn generate(algorithm: DnssecAlgorithm) -> Self {
match algorithm {
DnssecAlgorithm::EcdsaP256Sha256 => {
let signing_key = EcdsaSigningKey::random(&mut OsRng);
DnssecKeyPair::EcdsaP256 { signing_key }
}
DnssecAlgorithm::Ed25519 => {
let signing_key = Ed25519SigningKey::generate(&mut OsRng);
DnssecKeyPair::Ed25519 { signing_key }
}
}
}
/// Get the algorithm.
pub fn algorithm(&self) -> DnssecAlgorithm {
match self {
DnssecKeyPair::EcdsaP256 { .. } => DnssecAlgorithm::EcdsaP256Sha256,
DnssecKeyPair::Ed25519 { .. } => DnssecAlgorithm::Ed25519,
}
}
/// Get the public key bytes for the DNSKEY record.
/// For ECDSA P-256: 64 bytes (uncompressed x || y, without 0x04 prefix).
/// For ED25519: 32 bytes.
pub fn public_key_bytes(&self) -> Vec<u8> {
match self {
DnssecKeyPair::EcdsaP256 { signing_key } => {
use p256::ecdsa::VerifyingKey;
let verifying_key = VerifyingKey::from(signing_key);
let point = verifying_key.to_encoded_point(false); // uncompressed
let bytes = point.as_bytes();
// Remove 0x04 prefix for DNS format
bytes[1..].to_vec()
}
DnssecKeyPair::Ed25519 { signing_key } => {
let verifying_key = signing_key.verifying_key();
verifying_key.as_bytes().to_vec()
}
}
}
/// Get the DNSKEY RDATA (flags=256/ZSK, protocol=3, algorithm, public key).
pub fn dnskey_rdata(&self) -> Vec<u8> {
let flags: u16 = 256; // Zone Signing Key
let protocol: u8 = 3;
let algorithm = self.algorithm().number();
let pubkey = self.public_key_bytes();
let mut buf = Vec::new();
buf.extend_from_slice(&flags.to_be_bytes());
buf.push(protocol);
buf.push(algorithm);
buf.extend_from_slice(&pubkey);
buf
}
/// Sign data with this key pair.
pub fn sign(&self, data: &[u8]) -> Vec<u8> {
match self {
DnssecKeyPair::EcdsaP256 { signing_key } => {
use p256::ecdsa::{signature::Signer, Signature};
let sig: Signature = signing_key.sign(data);
sig.to_der().as_bytes().to_vec()
}
DnssecKeyPair::Ed25519 { signing_key } => {
use ed25519_dalek::Signer;
let sig = signing_key.sign(data);
sig.to_bytes().to_vec()
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ecdsa_key_generation() {
let kp = DnssecKeyPair::generate(DnssecAlgorithm::EcdsaP256Sha256);
assert_eq!(kp.algorithm(), DnssecAlgorithm::EcdsaP256Sha256);
assert_eq!(kp.public_key_bytes().len(), 64); // x(32) + y(32)
}
#[test]
fn test_ed25519_key_generation() {
let kp = DnssecKeyPair::generate(DnssecAlgorithm::Ed25519);
assert_eq!(kp.algorithm(), DnssecAlgorithm::Ed25519);
assert_eq!(kp.public_key_bytes().len(), 32);
}
#[test]
fn test_dnskey_rdata() {
let kp = DnssecKeyPair::generate(DnssecAlgorithm::EcdsaP256Sha256);
let rdata = kp.dnskey_rdata();
// flags(2) + protocol(1) + algorithm(1) + pubkey(64) = 68
assert_eq!(rdata.len(), 68);
assert_eq!(rdata[0], 1); // flags high byte (256 >> 8)
assert_eq!(rdata[1], 0); // flags low byte
assert_eq!(rdata[2], 3); // protocol
assert_eq!(rdata[3], 13); // algorithm 13 = ECDSA P-256
}
#[test]
fn test_sign_and_verify() {
let kp = DnssecKeyPair::generate(DnssecAlgorithm::EcdsaP256Sha256);
let data = b"test data to sign";
let sig = kp.sign(data);
assert!(!sig.is_empty());
let kp2 = DnssecKeyPair::generate(DnssecAlgorithm::Ed25519);
let sig2 = kp2.sign(data);
assert!(!sig2.is_empty());
}
}

View File

@@ -0,0 +1,38 @@
/// Compute the DNSSEC key tag as per RFC 4034 Appendix B.
/// Input is the full DNSKEY RDATA (flags + protocol + algorithm + public key).
pub fn compute_key_tag(dnskey_rdata: &[u8]) -> u16 {
let mut acc: u32 = 0;
for (i, &byte) in dnskey_rdata.iter().enumerate() {
if i & 1 == 0 {
acc += (byte as u32) << 8;
} else {
acc += byte as u32;
}
}
acc += (acc >> 16) & 0xFFFF;
(acc & 0xFFFF) as u16
}
/// Compute a DS record digest (SHA-256) from owner name + DNSKEY RDATA.
pub fn compute_ds_digest(owner_name_wire: &[u8], dnskey_rdata: &[u8]) -> Vec<u8> {
use sha2::{Sha256, Digest};
let mut hasher = Sha256::new();
hasher.update(owner_name_wire);
hasher.update(dnskey_rdata);
hasher.finalize().to_vec()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_key_tag_computation() {
// A known DNSKEY RDATA: flags=256, protocol=3, algorithm=13, plus some key bytes
let mut rdata = vec![1u8, 0, 3, 13]; // flags=256, protocol=3, algorithm=13
rdata.extend_from_slice(&[0u8; 64]); // dummy 64-byte key
let tag = compute_key_tag(&rdata);
// Just verify it produces a reasonable value
assert!(tag > 0);
}
}

View File

@@ -0,0 +1,3 @@
pub mod keys;
pub mod signing;
pub mod keytag;

View File

@@ -0,0 +1,147 @@
use crate::keys::DnssecKeyPair;
use crate::keytag::compute_key_tag;
use rustdns_protocol::name::encode_name;
use rustdns_protocol::packet::{encode_rrsig, DnsRecord};
use rustdns_protocol::types::QType;
use sha2::{Sha256, Digest};
/// Canonical RRset serialization for DNSSEC signing (RFC 4034 Section 6).
/// Each record: name(wire) + type(2) + class(2) + ttl(4) + rdlength(2) + rdata
pub fn serialize_rrset_canonical(records: &[DnsRecord]) -> Vec<u8> {
let mut buf = Vec::new();
for rr in records {
if rr.rtype == QType::OPT {
continue;
}
let name = if rr.name.ends_with('.') {
rr.name.to_lowercase()
} else {
format!("{}.", rr.name).to_lowercase()
};
buf.extend_from_slice(&encode_name(&name));
buf.extend_from_slice(&rr.rtype.to_u16().to_be_bytes());
buf.extend_from_slice(&rr.rclass.to_u16().to_be_bytes());
buf.extend_from_slice(&rr.ttl.to_be_bytes());
buf.extend_from_slice(&(rr.rdata.len() as u16).to_be_bytes());
buf.extend_from_slice(&rr.rdata);
}
buf
}
/// Generate an RRSIG record for a given RRset.
pub fn generate_rrsig(
key_pair: &DnssecKeyPair,
zone: &str,
rrset: &[DnsRecord],
name: &str,
rtype: QType,
) -> DnsRecord {
let algorithm = key_pair.algorithm().number();
let dnskey_rdata = key_pair.dnskey_rdata();
let key_tag = compute_key_tag(&dnskey_rdata);
let signers_name = if zone.ends_with('.') {
zone.to_string()
} else {
format!("{}.", zone)
};
let ttl = if rrset.is_empty() { 300 } else { rrset[0].ttl };
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as u32;
let inception = now.wrapping_sub(3600); // 1 hour ago
let expiration = inception.wrapping_add(86400); // +1 day
let labels = name
.strip_suffix('.')
.unwrap_or(name)
.split('.')
.filter(|l| !l.is_empty())
.count() as u8;
// Build the RRSIG RDATA preamble (everything before the signature)
let type_covered = rtype.to_u16();
let mut sig_data = Vec::new();
sig_data.extend_from_slice(&type_covered.to_be_bytes());
sig_data.push(algorithm);
sig_data.push(labels);
sig_data.extend_from_slice(&ttl.to_be_bytes());
sig_data.extend_from_slice(&expiration.to_be_bytes());
sig_data.extend_from_slice(&inception.to_be_bytes());
sig_data.extend_from_slice(&key_tag.to_be_bytes());
sig_data.extend_from_slice(&encode_name(&signers_name));
// Append the canonical RRset
sig_data.extend_from_slice(&serialize_rrset_canonical(rrset));
// Sign: ECDSA uses SHA-256 internally via the p256 crate, ED25519 does its own hashing
let signature = match key_pair {
DnssecKeyPair::EcdsaP256 { .. } => {
// For ECDSA, we hash first then sign
let hash = Sha256::digest(&sig_data);
key_pair.sign(&hash)
}
DnssecKeyPair::Ed25519 { .. } => {
// ED25519 includes hashing internally
key_pair.sign(&sig_data)
}
};
let rrsig_rdata = encode_rrsig(
type_covered,
algorithm,
labels,
ttl,
expiration,
inception,
key_tag,
&signers_name,
&signature,
);
DnsRecord {
name: name.to_string(),
rtype: QType::RRSIG,
rclass: rustdns_protocol::types::QClass::IN,
ttl,
rdata: rrsig_rdata,
opt_flags: None,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::keys::{DnssecAlgorithm, DnssecKeyPair};
use rustdns_protocol::packet::{build_record, encode_a};
#[test]
fn test_generate_rrsig_ecdsa() {
let kp = DnssecKeyPair::generate(DnssecAlgorithm::EcdsaP256Sha256);
let record = build_record("test.example.com", QType::A, 300, encode_a("127.0.0.1"));
let rrsig = generate_rrsig(&kp, "example.com", &[record], "test.example.com", QType::A);
assert_eq!(rrsig.rtype, QType::RRSIG);
assert!(!rrsig.rdata.is_empty());
}
#[test]
fn test_generate_rrsig_ed25519() {
let kp = DnssecKeyPair::generate(DnssecAlgorithm::Ed25519);
let record = build_record("test.example.com", QType::A, 300, encode_a("10.0.0.1"));
let rrsig = generate_rrsig(&kp, "example.com", &[record], "test.example.com", QType::A);
assert_eq!(rrsig.rtype, QType::RRSIG);
assert!(!rrsig.rdata.is_empty());
}
#[test]
fn test_serialize_rrset_canonical() {
let r1 = build_record("example.com", QType::A, 300, encode_a("1.2.3.4"));
let r2 = build_record("example.com", QType::A, 300, encode_a("5.6.7.8"));
let serialized = serialize_rrset_canonical(&[r1, r2]);
assert!(!serialized.is_empty());
}
}

View File

@@ -0,0 +1,6 @@
[package]
name = "rustdns-protocol"
version = "0.1.0"
edition = "2021"
[dependencies]

View File

@@ -0,0 +1,3 @@
pub mod types;
pub mod name;
pub mod packet;

View File

@@ -0,0 +1,108 @@
/// Encode a domain name into DNS wire format.
/// e.g. "example.com" -> [7, 'e','x','a','m','p','l','e', 3, 'c','o','m', 0]
pub fn encode_name(name: &str) -> Vec<u8> {
let mut buf = Vec::new();
let trimmed = name.strip_suffix('.').unwrap_or(name);
if trimmed.is_empty() {
buf.push(0);
return buf;
}
for label in trimmed.split('.') {
let len = label.len();
if len > 63 {
// Truncate to 63 per DNS spec
buf.push(63);
buf.extend_from_slice(&label.as_bytes()[..63]);
} else {
buf.push(len as u8);
buf.extend_from_slice(label.as_bytes());
}
}
buf.push(0); // root label
buf
}
/// Decode a domain name from DNS wire format at the given offset.
/// Returns (name, bytes_consumed).
/// Handles compression pointers (0xC0 prefix).
pub fn decode_name(data: &[u8], offset: usize) -> Result<(String, usize), &'static str> {
let mut labels: Vec<String> = Vec::new();
let mut pos = offset;
let mut bytes_consumed = 0;
let mut jumped = false;
loop {
if pos >= data.len() {
return Err("unexpected end of data in name");
}
let len = data[pos] as usize;
if len == 0 {
// Root label
if !jumped {
bytes_consumed = pos - offset + 1;
}
break;
}
// Check for compression pointer
if len & 0xC0 == 0xC0 {
if pos + 1 >= data.len() {
return Err("unexpected end of data in compression pointer");
}
let pointer = ((len & 0x3F) << 8) | (data[pos + 1] as usize);
if !jumped {
bytes_consumed = pos - offset + 2;
jumped = true;
}
pos = pointer;
continue;
}
// Regular label
pos += 1;
if pos + len > data.len() {
return Err("label extends beyond data");
}
let label = std::str::from_utf8(&data[pos..pos + len]).map_err(|_| "invalid UTF-8 in label")?;
labels.push(label.to_string());
pos += len;
}
if bytes_consumed == 0 && !jumped {
bytes_consumed = 1; // just the root label
}
Ok((labels.join("."), bytes_consumed))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encode_decode_roundtrip() {
let names = vec!["example.com", "sub.domain.example.com", "localhost", "a.b.c.d.e"];
for name in names {
let encoded = encode_name(name);
let (decoded, consumed) = decode_name(&encoded, 0).unwrap();
assert_eq!(decoded, name);
assert_eq!(consumed, encoded.len());
}
}
#[test]
fn test_encode_trailing_dot() {
let a = encode_name("example.com.");
let b = encode_name("example.com");
assert_eq!(a, b);
}
#[test]
fn test_root_name() {
let encoded = encode_name("");
assert_eq!(encoded, vec![0]);
let (decoded, _) = decode_name(&encoded, 0).unwrap();
assert_eq!(decoded, "");
}
}

View File

@@ -0,0 +1,442 @@
use crate::name::{decode_name, encode_name};
use crate::types::{QClass, QType, FLAG_QR, FLAG_AA, FLAG_RD, FLAG_RA, EDNS_DO_BIT};
/// A parsed DNS question.
#[derive(Debug, Clone)]
pub struct DnsQuestion {
pub name: String,
pub qtype: QType,
pub qclass: QClass,
}
/// A parsed DNS resource record.
#[derive(Debug, Clone)]
pub struct DnsRecord {
pub name: String,
pub rtype: QType,
pub rclass: QClass,
pub ttl: u32,
pub rdata: Vec<u8>,
// For OPT records, the flags are stored in the TTL field position
pub opt_flags: Option<u16>,
}
/// A complete DNS packet (parsed).
#[derive(Debug, Clone)]
pub struct DnsPacket {
pub id: u16,
pub flags: u16,
pub questions: Vec<DnsQuestion>,
pub answers: Vec<DnsRecord>,
pub authorities: Vec<DnsRecord>,
pub additionals: Vec<DnsRecord>,
}
impl DnsPacket {
/// Create a new empty query packet.
pub fn new_query(id: u16) -> Self {
DnsPacket {
id,
flags: 0,
questions: Vec::new(),
answers: Vec::new(),
authorities: Vec::new(),
additionals: Vec::new(),
}
}
/// Create a response packet for a given request.
pub fn new_response(request: &DnsPacket) -> Self {
let mut flags = FLAG_QR | FLAG_AA | FLAG_RA;
if request.flags & FLAG_RD != 0 {
flags |= FLAG_RD;
}
DnsPacket {
id: request.id,
flags,
questions: request.questions.clone(),
answers: Vec::new(),
authorities: Vec::new(),
additionals: Vec::new(),
}
}
/// Check if DNSSEC (DO bit) is requested in the OPT record.
pub fn is_dnssec_requested(&self) -> bool {
for additional in &self.additionals {
if additional.rtype == QType::OPT {
if let Some(flags) = additional.opt_flags {
if flags & EDNS_DO_BIT != 0 {
return true;
}
}
}
}
false
}
/// Parse a DNS packet from wire format bytes.
pub fn parse(data: &[u8]) -> Result<Self, String> {
if data.len() < 12 {
return Err("packet too short for DNS header".into());
}
let id = u16::from_be_bytes([data[0], data[1]]);
let flags = u16::from_be_bytes([data[2], data[3]]);
let qdcount = u16::from_be_bytes([data[4], data[5]]) as usize;
let ancount = u16::from_be_bytes([data[6], data[7]]) as usize;
let nscount = u16::from_be_bytes([data[8], data[9]]) as usize;
let arcount = u16::from_be_bytes([data[10], data[11]]) as usize;
let mut offset = 12;
// Parse questions
let mut questions = Vec::with_capacity(qdcount);
for _ in 0..qdcount {
let (name, consumed) = decode_name(data, offset).map_err(|e| e.to_string())?;
offset += consumed;
if offset + 4 > data.len() {
return Err("packet too short for question fields".into());
}
let qtype = QType::from_u16(u16::from_be_bytes([data[offset], data[offset + 1]]));
let qclass = QClass::from_u16(u16::from_be_bytes([data[offset + 2], data[offset + 3]]));
offset += 4;
questions.push(DnsQuestion { name, qtype, qclass });
}
// Parse resource records
fn parse_records(data: &[u8], offset: &mut usize, count: usize) -> Result<Vec<DnsRecord>, String> {
let mut records = Vec::with_capacity(count);
for _ in 0..count {
let (name, consumed) = decode_name(data, *offset).map_err(|e| e.to_string())?;
*offset += consumed;
if *offset + 10 > data.len() {
return Err("packet too short for RR fields".into());
}
let rtype = QType::from_u16(u16::from_be_bytes([data[*offset], data[*offset + 1]]));
let rclass_or_payload = u16::from_be_bytes([data[*offset + 2], data[*offset + 3]]);
let ttl_bytes = u32::from_be_bytes([data[*offset + 4], data[*offset + 5], data[*offset + 6], data[*offset + 7]]);
let rdlength = u16::from_be_bytes([data[*offset + 8], data[*offset + 9]]) as usize;
*offset += 10;
if *offset + rdlength > data.len() {
return Err("packet too short for RDATA".into());
}
let rdata = data[*offset..*offset + rdlength].to_vec();
*offset += rdlength;
// For OPT records, extract flags from the TTL position
let (rclass, ttl, opt_flags) = if rtype == QType::OPT {
// OPT: class = UDP payload size, TTL upper 16 = extended RCODE + version, lower 16 = flags
let flags = (ttl_bytes & 0xFFFF) as u16;
(QClass::from_u16(rclass_or_payload), 0, Some(flags))
} else {
(QClass::from_u16(rclass_or_payload), ttl_bytes, None)
};
records.push(DnsRecord {
name,
rtype,
rclass,
ttl,
rdata,
opt_flags,
});
}
Ok(records)
}
let answers = parse_records(data, &mut offset, ancount)?;
let authorities = parse_records(data, &mut offset, nscount)?;
let additionals = parse_records(data, &mut offset, arcount)?;
Ok(DnsPacket {
id,
flags,
questions,
answers,
authorities,
additionals,
})
}
/// Encode this DNS packet to wire format bytes.
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(512);
// Header
buf.extend_from_slice(&self.id.to_be_bytes());
buf.extend_from_slice(&self.flags.to_be_bytes());
buf.extend_from_slice(&(self.questions.len() as u16).to_be_bytes());
buf.extend_from_slice(&(self.answers.len() as u16).to_be_bytes());
buf.extend_from_slice(&(self.authorities.len() as u16).to_be_bytes());
buf.extend_from_slice(&(self.additionals.len() as u16).to_be_bytes());
// Questions
for q in &self.questions {
buf.extend_from_slice(&encode_name(&q.name));
buf.extend_from_slice(&q.qtype.to_u16().to_be_bytes());
buf.extend_from_slice(&q.qclass.to_u16().to_be_bytes());
}
// Resource records
fn encode_records(buf: &mut Vec<u8>, records: &[DnsRecord]) {
for rr in records {
buf.extend_from_slice(&encode_name(&rr.name));
buf.extend_from_slice(&rr.rtype.to_u16().to_be_bytes());
if rr.rtype == QType::OPT {
// OPT: class = UDP payload size (4096), TTL = ext rcode + flags
buf.extend_from_slice(&rr.rclass.to_u16().to_be_bytes());
let flags = rr.opt_flags.unwrap_or(0) as u32;
buf.extend_from_slice(&flags.to_be_bytes());
} else {
buf.extend_from_slice(&rr.rclass.to_u16().to_be_bytes());
buf.extend_from_slice(&rr.ttl.to_be_bytes());
}
buf.extend_from_slice(&(rr.rdata.len() as u16).to_be_bytes());
buf.extend_from_slice(&rr.rdata);
}
}
encode_records(&mut buf, &self.answers);
encode_records(&mut buf, &self.authorities);
encode_records(&mut buf, &self.additionals);
buf
}
}
// ── RDATA encoding helpers ─────────────────────────────────────────
/// Encode an A record (IPv4 address string -> 4 bytes).
pub fn encode_a(ip: &str) -> Vec<u8> {
ip.split('.')
.filter_map(|s| s.parse::<u8>().ok())
.collect()
}
/// Encode an AAAA record (IPv6 address string -> 16 bytes).
pub fn encode_aaaa(ip: &str) -> Vec<u8> {
// Handle :: expansion
let expanded = expand_ipv6(ip);
expanded
.split(':')
.flat_map(|seg| {
let val = u16::from_str_radix(seg, 16).unwrap_or(0);
val.to_be_bytes().to_vec()
})
.collect()
}
fn expand_ipv6(ip: &str) -> String {
if !ip.contains("::") {
return ip.to_string();
}
let parts: Vec<&str> = ip.split("::").collect();
let left: Vec<&str> = if parts[0].is_empty() {
vec![]
} else {
parts[0].split(':').collect()
};
let right: Vec<&str> = if parts.len() > 1 && !parts[1].is_empty() {
parts[1].split(':').collect()
} else {
vec![]
};
let fill_count = 8 - left.len() - right.len();
let mut result: Vec<String> = left.iter().map(|s| s.to_string()).collect();
for _ in 0..fill_count {
result.push("0".to_string());
}
result.extend(right.iter().map(|s| s.to_string()));
result.join(":")
}
/// Encode a TXT record (array of strings -> length-prefixed chunks).
pub fn encode_txt(strings: &[String]) -> Vec<u8> {
let mut buf = Vec::new();
for s in strings {
let bytes = s.as_bytes();
// TXT strings must be <= 255 bytes each
let len = bytes.len().min(255);
buf.push(len as u8);
buf.extend_from_slice(&bytes[..len]);
}
buf
}
/// Encode a domain name for use in RDATA (NS, CNAME, PTR, etc.).
pub fn encode_name_rdata(name: &str) -> Vec<u8> {
encode_name(name)
}
/// Encode a SOA record RDATA.
pub fn encode_soa(mname: &str, rname: &str, serial: u32, refresh: u32, retry: u32, expire: u32, minimum: u32) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&encode_name(mname));
buf.extend_from_slice(&encode_name(rname));
buf.extend_from_slice(&serial.to_be_bytes());
buf.extend_from_slice(&refresh.to_be_bytes());
buf.extend_from_slice(&retry.to_be_bytes());
buf.extend_from_slice(&expire.to_be_bytes());
buf.extend_from_slice(&minimum.to_be_bytes());
buf
}
/// Encode an MX record RDATA.
pub fn encode_mx(preference: u16, exchange: &str) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&preference.to_be_bytes());
buf.extend_from_slice(&encode_name(exchange));
buf
}
/// Encode a SRV record RDATA.
pub fn encode_srv(priority: u16, weight: u16, port: u16, target: &str) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&priority.to_be_bytes());
buf.extend_from_slice(&weight.to_be_bytes());
buf.extend_from_slice(&port.to_be_bytes());
buf.extend_from_slice(&encode_name(target));
buf
}
/// Encode a DNSKEY record RDATA.
pub fn encode_dnskey(flags: u16, protocol: u8, algorithm: u8, public_key: &[u8]) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&flags.to_be_bytes());
buf.push(protocol);
buf.push(algorithm);
buf.extend_from_slice(public_key);
buf
}
/// Encode an RRSIG record RDATA.
pub fn encode_rrsig(
type_covered: u16,
algorithm: u8,
labels: u8,
original_ttl: u32,
expiration: u32,
inception: u32,
key_tag: u16,
signers_name: &str,
signature: &[u8],
) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&type_covered.to_be_bytes());
buf.push(algorithm);
buf.push(labels);
buf.extend_from_slice(&original_ttl.to_be_bytes());
buf.extend_from_slice(&expiration.to_be_bytes());
buf.extend_from_slice(&inception.to_be_bytes());
buf.extend_from_slice(&key_tag.to_be_bytes());
buf.extend_from_slice(&encode_name(signers_name));
buf.extend_from_slice(signature);
buf
}
/// Build a DnsRecord from high-level data.
pub fn build_record(name: &str, rtype: QType, ttl: u32, rdata: Vec<u8>) -> DnsRecord {
DnsRecord {
name: name.to_string(),
rtype,
rclass: QClass::IN,
ttl,
rdata,
opt_flags: None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_encode_roundtrip() {
// Build a simple query
let mut query = DnsPacket::new_query(0x1234);
query.flags = FLAG_RD;
query.questions.push(DnsQuestion {
name: "example.com".to_string(),
qtype: QType::A,
qclass: QClass::IN,
});
let encoded = query.encode();
let parsed = DnsPacket::parse(&encoded).unwrap();
assert_eq!(parsed.id, 0x1234);
assert_eq!(parsed.questions.len(), 1);
assert_eq!(parsed.questions[0].name, "example.com");
assert_eq!(parsed.questions[0].qtype, QType::A);
}
#[test]
fn test_response_with_answer() {
let mut query = DnsPacket::new_query(0x5678);
query.flags = FLAG_RD;
query.questions.push(DnsQuestion {
name: "test.example.com".to_string(),
qtype: QType::A,
qclass: QClass::IN,
});
let mut response = DnsPacket::new_response(&query);
response.answers.push(build_record(
"test.example.com",
QType::A,
300,
encode_a("127.0.0.1"),
));
let encoded = response.encode();
let parsed = DnsPacket::parse(&encoded).unwrap();
assert_eq!(parsed.id, 0x5678);
assert!(parsed.flags & FLAG_QR != 0); // Is a response
assert!(parsed.flags & FLAG_AA != 0); // Authoritative
assert_eq!(parsed.answers.len(), 1);
assert_eq!(parsed.answers[0].rdata, vec![127, 0, 0, 1]);
}
#[test]
fn test_encode_aaaa() {
let data = encode_aaaa("::1");
assert_eq!(data.len(), 16);
assert_eq!(data[15], 1);
assert!(data[..15].iter().all(|&b| b == 0));
}
#[test]
fn test_encode_txt() {
let data = encode_txt(&["hello".to_string(), "world".to_string()]);
assert_eq!(data[0], 5); // length of "hello"
assert_eq!(&data[1..6], b"hello");
assert_eq!(data[6], 5); // length of "world"
assert_eq!(&data[7..12], b"world");
}
#[test]
fn test_dnssec_do_bit() {
let mut query = DnsPacket::new_query(1);
query.questions.push(DnsQuestion {
name: "example.com".to_string(),
qtype: QType::A,
qclass: QClass::IN,
});
// No OPT record = no DNSSEC
assert!(!query.is_dnssec_requested());
// Add OPT with DO bit
query.additionals.push(DnsRecord {
name: ".".to_string(),
rtype: QType::OPT,
rclass: QClass::from_u16(4096), // UDP payload size
ttl: 0,
rdata: vec![],
opt_flags: Some(EDNS_DO_BIT),
});
assert!(query.is_dnssec_requested());
}
}

View File

@@ -0,0 +1,131 @@
/// DNS record types
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum QType {
A = 1,
NS = 2,
CNAME = 5,
SOA = 6,
PTR = 12,
MX = 15,
TXT = 16,
AAAA = 28,
SRV = 33,
OPT = 41,
RRSIG = 46,
DNSKEY = 48,
Unknown(u16),
}
impl QType {
pub fn from_u16(val: u16) -> Self {
match val {
1 => QType::A,
2 => QType::NS,
5 => QType::CNAME,
6 => QType::SOA,
12 => QType::PTR,
15 => QType::MX,
16 => QType::TXT,
28 => QType::AAAA,
33 => QType::SRV,
41 => QType::OPT,
46 => QType::RRSIG,
48 => QType::DNSKEY,
v => QType::Unknown(v),
}
}
pub fn to_u16(self) -> u16 {
match self {
QType::A => 1,
QType::NS => 2,
QType::CNAME => 5,
QType::SOA => 6,
QType::PTR => 12,
QType::MX => 15,
QType::TXT => 16,
QType::AAAA => 28,
QType::SRV => 33,
QType::OPT => 41,
QType::RRSIG => 46,
QType::DNSKEY => 48,
QType::Unknown(v) => v,
}
}
pub fn from_str(s: &str) -> Self {
match s.to_uppercase().as_str() {
"A" => QType::A,
"NS" => QType::NS,
"CNAME" => QType::CNAME,
"SOA" => QType::SOA,
"PTR" => QType::PTR,
"MX" => QType::MX,
"TXT" => QType::TXT,
"AAAA" => QType::AAAA,
"SRV" => QType::SRV,
"OPT" => QType::OPT,
"RRSIG" => QType::RRSIG,
"DNSKEY" => QType::DNSKEY,
_ => QType::Unknown(0),
}
}
pub fn as_str(&self) -> &'static str {
match self {
QType::A => "A",
QType::NS => "NS",
QType::CNAME => "CNAME",
QType::SOA => "SOA",
QType::PTR => "PTR",
QType::MX => "MX",
QType::TXT => "TXT",
QType::AAAA => "AAAA",
QType::SRV => "SRV",
QType::OPT => "OPT",
QType::RRSIG => "RRSIG",
QType::DNSKEY => "DNSKEY",
QType::Unknown(_) => "UNKNOWN",
}
}
}
/// DNS record classes
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum QClass {
IN = 1,
CH = 3,
HS = 4,
Unknown(u16),
}
impl QClass {
pub fn from_u16(val: u16) -> Self {
match val {
1 => QClass::IN,
3 => QClass::CH,
4 => QClass::HS,
v => QClass::Unknown(v),
}
}
pub fn to_u16(self) -> u16 {
match self {
QClass::IN => 1,
QClass::CH => 3,
QClass::HS => 4,
QClass::Unknown(v) => v,
}
}
}
/// DNS header flags
pub const FLAG_QR: u16 = 0x8000;
pub const FLAG_AA: u16 = 0x0400;
pub const FLAG_RD: u16 = 0x0100;
pub const FLAG_RA: u16 = 0x0080;
/// OPT record DO bit (DNSSEC OK)
pub const EDNS_DO_BIT: u16 = 0x8000;

View File

@@ -0,0 +1,17 @@
[package]
name = "rustdns-server"
version = "0.1.0"
edition = "2021"
[dependencies]
rustdns-protocol = { path = "../rustdns-protocol" }
rustdns-dnssec = { path = "../rustdns-dnssec" }
tokio = { version = "1", features = ["full"] }
hyper = { version = "1", features = ["http1", "server"] }
hyper-util = { version = "0.1", features = ["tokio"] }
http-body-util = "0.1"
rustls = { version = "0.23", features = ["ring"] }
tokio-rustls = "0.26"
rustls-pemfile = "2"
tracing = "0.1"
bytes = "1"

View File

@@ -0,0 +1,164 @@
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use http_body_util::{BodyExt, Full};
use rustdns_protocol::packet::DnsPacket;
use rustls::ServerConfig;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use tracing::{error, info};
/// Configuration for the HTTPS DoH server.
pub struct HttpsServerConfig {
pub bind_addr: SocketAddr,
pub tls_config: Arc<ServerConfig>,
}
/// An HTTPS DNS-over-HTTPS server.
pub struct HttpsServer {
shutdown: tokio::sync::watch::Sender<bool>,
local_addr: SocketAddr,
}
impl HttpsServer {
/// Start the HTTPS DoH server.
pub async fn start<F, Fut>(
config: HttpsServerConfig,
resolver: F,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
where
F: Fn(DnsPacket) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = DnsPacket> + Send + 'static,
{
let listener = TcpListener::bind(config.bind_addr).await?;
let local_addr = listener.local_addr()?;
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
let tls_acceptor = TlsAcceptor::from(config.tls_config);
let resolver = Arc::new(resolver);
info!("HTTPS DoH server listening on {}", local_addr);
tokio::spawn(async move {
let mut shutdown_rx = shutdown_rx;
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, _peer_addr)) => {
let acceptor = tls_acceptor.clone();
let resolver = resolver.clone();
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
let io = TokioIo::new(tls_stream);
let resolver = resolver.clone();
let service = service_fn(move |req: Request<Incoming>| {
let resolver = resolver.clone();
async move {
handle_doh_request(req, resolver).await
}
});
if let Err(e) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
{
error!("HTTPS connection error: {}", e);
}
}
Err(e) => {
error!("TLS accept error: {}", e);
}
}
});
}
Err(e) => {
error!("TCP accept error: {}", e);
}
}
}
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
info!("HTTPS DoH server shutting down");
break;
}
}
}
}
});
Ok(HttpsServer {
shutdown: shutdown_tx,
local_addr,
})
}
/// Stop the HTTPS server.
pub fn stop(&self) {
let _ = self.shutdown.send(true);
}
/// Get the bound local address.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
async fn handle_doh_request<F, Fut>(
req: Request<Incoming>,
resolver: Arc<F>,
) -> Result<Response<Full<bytes::Bytes>>, hyper::Error>
where
F: Fn(DnsPacket) -> Fut + Send + Sync,
Fut: std::future::Future<Output = DnsPacket> + Send,
{
if req.method() == hyper::Method::POST && req.uri().path() == "/dns-query" {
let body = req.collect().await?.to_bytes();
match DnsPacket::parse(&body) {
Ok(request) => {
let response = resolver(request).await;
let encoded = response.encode();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/dns-message")
.body(Full::new(bytes::Bytes::from(encoded)))
.unwrap())
}
Err(e) => {
error!("Failed to parse DoH request: {}", e);
Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(bytes::Bytes::from(format!("Invalid DNS message: {}", e))))
.unwrap())
}
}
} else {
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(bytes::Bytes::new()))
.unwrap())
}
}
/// Create a rustls ServerConfig from PEM-encoded certificate and key.
pub fn create_tls_config(cert_pem: &str, key_pem: &str) -> Result<Arc<ServerConfig>, Box<dyn std::error::Error + Send + Sync>> {
let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes())
.collect::<Result<Vec<_>, _>>()?;
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
.ok_or("no private key found in PEM data")?;
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?;
Ok(Arc::new(config))
}

View File

@@ -0,0 +1,12 @@
pub mod udp;
pub mod https;
use rustdns_protocol::packet::DnsPacket;
use std::future::Future;
use std::pin::Pin;
/// Trait for DNS query resolution.
/// The resolver receives a parsed DNS packet and returns a response packet.
pub type DnsResolverFn = Box<
dyn Fn(DnsPacket) -> Pin<Box<dyn Future<Output = DnsPacket> + Send>> + Send + Sync,
>;

View File

@@ -0,0 +1,95 @@
use rustdns_protocol::packet::DnsPacket;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tracing::{error, info};
/// Configuration for the UDP DNS server.
pub struct UdpServerConfig {
pub bind_addr: SocketAddr,
}
/// A UDP DNS server that delegates resolution to a callback.
pub struct UdpServer {
socket: Arc<UdpSocket>,
shutdown: tokio::sync::watch::Sender<bool>,
}
impl UdpServer {
/// Bind and start the UDP server. The resolver function is called for each query.
pub async fn start<F, Fut>(
config: UdpServerConfig,
resolver: F,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
where
F: Fn(DnsPacket) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = DnsPacket> + Send + 'static,
{
let socket = UdpSocket::bind(config.bind_addr).await?;
let socket = Arc::new(socket);
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
info!("UDP DNS server listening on {}", config.bind_addr);
let recv_socket = socket.clone();
let resolver = Arc::new(resolver);
tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let mut shutdown_rx = shutdown_rx;
loop {
tokio::select! {
result = recv_socket.recv_from(&mut buf) => {
match result {
Ok((len, src)) => {
let data = buf[..len].to_vec();
let sock = recv_socket.clone();
let resolver = resolver.clone();
tokio::spawn(async move {
match DnsPacket::parse(&data) {
Ok(request) => {
let response = resolver(request).await;
let encoded = response.encode();
if let Err(e) = sock.send_to(&encoded, src).await {
error!("Failed to send UDP response: {}", e);
}
}
Err(e) => {
error!("Failed to parse DNS packet from {}: {}", src, e);
}
}
});
}
Err(e) => {
error!("UDP recv error: {}", e);
}
}
}
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
info!("UDP DNS server shutting down");
break;
}
}
}
}
});
Ok(UdpServer {
socket,
shutdown: shutdown_tx,
})
}
/// Stop the UDP server.
pub fn stop(&self) {
let _ = self.shutdown.send(true);
}
/// Get the bound local address.
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.socket.local_addr()
}
}

View File

@@ -0,0 +1,26 @@
[package]
name = "rustdns"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "rustdns"
path = "src/main.rs"
[lib]
name = "rustdns"
path = "src/lib.rs"
[dependencies]
rustdns-protocol = { path = "../rustdns-protocol" }
rustdns-dnssec = { path = "../rustdns-dnssec" }
rustdns-server = { path = "../rustdns-server" }
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
clap = { version = "4", features = ["derive"] }
tracing = "0.1"
tracing-subscriber = "0.3"
dashmap = "6"
base64 = "0.22"
rustls = { version = "0.23", features = ["ring"] }

View File

@@ -0,0 +1,125 @@
use serde::{Deserialize, Serialize};
/// IPC request from TypeScript to Rust (via stdin).
#[derive(Debug, Deserialize)]
pub struct IpcRequest {
pub id: String,
pub method: String,
#[serde(default)]
pub params: serde_json::Value,
}
/// IPC response from Rust to TypeScript (via stdout).
#[derive(Debug, Serialize)]
pub struct IpcResponse {
pub id: String,
pub success: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
impl IpcResponse {
pub fn ok(id: String, result: serde_json::Value) -> Self {
IpcResponse {
id,
success: true,
result: Some(result),
error: None,
}
}
pub fn err(id: String, error: String) -> Self {
IpcResponse {
id,
success: false,
result: None,
error: Some(error),
}
}
}
/// IPC event from Rust to TypeScript (unsolicited, no id).
#[derive(Debug, Serialize)]
pub struct IpcEvent {
pub event: String,
pub data: serde_json::Value,
}
/// Configuration sent via the "start" command.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RustDnsConfig {
pub udp_port: u16,
pub https_port: u16,
#[serde(default = "default_bind")]
pub udp_bind_interface: String,
#[serde(default = "default_bind")]
pub https_bind_interface: String,
#[serde(default)]
pub https_key: String,
#[serde(default)]
pub https_cert: String,
pub dnssec_zone: String,
#[serde(default = "default_algorithm")]
pub dnssec_algorithm: String,
#[serde(default)]
pub primary_nameserver: String,
#[serde(default = "default_true")]
pub enable_localhost_handling: bool,
#[serde(default)]
pub manual_udp_mode: bool,
#[serde(default)]
pub manual_https_mode: bool,
}
fn default_bind() -> String {
"0.0.0.0".to_string()
}
fn default_algorithm() -> String {
"ECDSA".to_string()
}
fn default_true() -> bool {
true
}
/// A DNS question as sent over IPC.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IpcDnsQuestion {
pub name: String,
#[serde(rename = "type")]
pub qtype: String,
pub class: String,
}
/// A DNS answer as received from TypeScript over IPC.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IpcDnsAnswer {
pub name: String,
#[serde(rename = "type")]
pub rtype: String,
pub class: String,
pub ttl: u32,
pub data: serde_json::Value,
}
/// The dnsQuery event sent from Rust to TypeScript.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DnsQueryEvent {
pub correlation_id: String,
pub questions: Vec<IpcDnsQuestion>,
pub dnssec_requested: bool,
}
/// The dnsQueryResult command from TypeScript to Rust.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DnsQueryResult {
pub correlation_id: String,
pub answers: Vec<IpcDnsAnswer>,
pub answered: bool,
}

View File

@@ -0,0 +1,3 @@
pub mod management;
pub mod ipc_types;
pub mod resolver;

View File

@@ -0,0 +1,36 @@
use clap::Parser;
use tracing_subscriber;
mod management;
mod ipc_types;
mod resolver;
#[derive(Parser, Debug)]
#[command(name = "rustdns", about = "Rust DNS server with IPC management")]
struct Cli {
/// Run in management mode (IPC via stdin/stdout)
#[arg(long)]
management: bool,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Install the default rustls crypto provider (ring) before any TLS operations
let _ = rustls::crypto::ring::default_provider().install_default();
let cli = Cli::parse();
// Tracing writes to stderr so stdout is reserved for IPC
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.init();
if cli.management {
management::management_loop().await?;
} else {
eprintln!("rustdns: use --management flag for IPC mode");
std::process::exit(1);
}
Ok(())
}

View File

@@ -0,0 +1,402 @@
use crate::ipc_types::*;
use crate::resolver::DnsResolver;
use dashmap::DashMap;
use rustdns_dnssec::keys::DnssecAlgorithm;
use rustdns_protocol::packet::DnsPacket;
use rustdns_server::https::{self, HttpsServer};
use rustdns_server::udp::{UdpServer, UdpServerConfig};
use std::io::{self, BufRead, Write};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tracing::{error, info};
/// Pending DNS query callbacks waiting for TypeScript response.
type PendingCallbacks = Arc<DashMap<String, oneshot::Sender<DnsQueryResult>>>;
/// Active server state.
struct ServerState {
udp_server: Option<UdpServer>,
https_server: Option<HttpsServer>,
resolver: Arc<DnsResolver>,
}
/// Emit a JSON event on stdout.
fn send_event(event: &str, data: serde_json::Value) {
let evt = IpcEvent {
event: event.to_string(),
data,
};
let json = serde_json::to_string(&evt).unwrap();
let stdout = io::stdout();
let mut lock = stdout.lock();
let _ = writeln!(lock, "{}", json);
let _ = lock.flush();
}
/// Send a JSON response on stdout.
fn send_response(response: &IpcResponse) {
let json = serde_json::to_string(response).unwrap();
let stdout = io::stdout();
let mut lock = stdout.lock();
let _ = writeln!(lock, "{}", json);
let _ = lock.flush();
}
/// Main management loop — reads JSON lines from stdin, dispatches commands.
pub async fn management_loop() -> Result<(), Box<dyn std::error::Error>> {
// Emit ready event
send_event("ready", serde_json::json!({
"version": env!("CARGO_PKG_VERSION")
}));
let pending: PendingCallbacks = Arc::new(DashMap::new());
let mut server_state: Option<ServerState> = None;
// Channel for stdin commands (read in blocking thread)
let (cmd_tx, mut cmd_rx) = mpsc::channel::<String>(256);
// Channel for DNS query events from the server
let (query_tx, mut query_rx) = mpsc::channel::<(String, DnsPacket)>(256);
// Spawn blocking stdin reader
std::thread::spawn(move || {
let stdin = io::stdin();
let reader = stdin.lock();
for line in reader.lines() {
match line {
Ok(l) => {
if cmd_tx.blocking_send(l).is_err() {
break; // channel closed
}
}
Err(_) => break, // stdin closed
}
}
});
loop {
tokio::select! {
cmd = cmd_rx.recv() => {
match cmd {
Some(line) => {
let request: IpcRequest = match serde_json::from_str(&line) {
Ok(r) => r,
Err(e) => {
error!("Failed to parse IPC request: {}", e);
continue;
}
};
let response = handle_request(
&request,
&mut server_state,
&pending,
&query_tx,
).await;
send_response(&response);
}
None => {
// stdin closed — parent process exited
info!("stdin closed, shutting down");
if let Some(ref state) = server_state {
if let Some(ref udp) = state.udp_server {
udp.stop();
}
if let Some(ref https) = state.https_server {
https.stop();
}
}
break;
}
}
}
query = query_rx.recv() => {
if let Some((correlation_id, packet)) = query {
let dnssec = packet.is_dnssec_requested();
let questions = DnsResolver::questions_to_ipc(&packet.questions);
send_event("dnsQuery", serde_json::to_value(&DnsQueryEvent {
correlation_id,
questions,
dnssec_requested: dnssec,
}).unwrap());
}
}
}
}
Ok(())
}
async fn handle_request(
request: &IpcRequest,
server_state: &mut Option<ServerState>,
pending: &PendingCallbacks,
query_tx: &mpsc::Sender<(String, DnsPacket)>,
) -> IpcResponse {
let id = request.id.clone();
match request.method.as_str() {
"ping" => IpcResponse::ok(id, serde_json::json!({ "pong": true })),
"start" => {
handle_start(id, &request.params, server_state, pending, query_tx).await
}
"stop" => {
handle_stop(id, server_state)
}
"dnsQueryResult" => {
handle_query_result(id, &request.params, pending)
}
"updateCerts" => {
// TODO: hot-swap TLS certs (requires rustls cert resolver)
IpcResponse::ok(id, serde_json::json!({}))
}
"processPacket" => {
handle_process_packet(id, &request.params, server_state, pending, query_tx).await
}
_ => IpcResponse::err(id, format!("Unknown method: {}", request.method)),
}
}
async fn handle_start(
id: String,
params: &serde_json::Value,
server_state: &mut Option<ServerState>,
pending: &PendingCallbacks,
query_tx: &mpsc::Sender<(String, DnsPacket)>,
) -> IpcResponse {
let config: RustDnsConfig = match serde_json::from_value(params.get("config").cloned().unwrap_or_default()) {
Ok(c) => c,
Err(e) => return IpcResponse::err(id, format!("Invalid config: {}", e)),
};
let algorithm = DnssecAlgorithm::from_str(&config.dnssec_algorithm)
.unwrap_or(DnssecAlgorithm::EcdsaP256Sha256);
let resolver = Arc::new(DnsResolver::new(
&config.dnssec_zone,
algorithm,
&config.primary_nameserver,
config.enable_localhost_handling,
));
// Start UDP server if not manual mode
let udp_server = if !config.manual_udp_mode {
let addr: SocketAddr = format!("{}:{}", config.udp_bind_interface, config.udp_port)
.parse()
.unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], config.udp_port)));
let resolver_clone = resolver.clone();
let pending_clone = pending.clone();
let query_tx_clone = query_tx.clone();
match UdpServer::start(
UdpServerConfig { bind_addr: addr },
move |packet| {
let resolver = resolver_clone.clone();
let pending = pending_clone.clone();
let query_tx = query_tx_clone.clone();
async move {
resolve_with_callback(packet, &resolver, &pending, &query_tx).await
}
},
).await {
Ok(server) => {
info!("UDP DNS server started on {}", addr);
Some(server)
}
Err(e) => {
return IpcResponse::err(id, format!("Failed to start UDP server: {}", e));
}
}
} else {
None
};
// Start HTTPS server if not manual mode and certs are provided
let https_server = if !config.manual_https_mode && !config.https_cert.is_empty() && !config.https_key.is_empty() {
let addr: SocketAddr = format!("{}:{}", config.https_bind_interface, config.https_port)
.parse()
.unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], config.https_port)));
match https::create_tls_config(&config.https_cert, &config.https_key) {
Ok(tls_config) => {
let resolver_clone = resolver.clone();
let pending_clone = pending.clone();
let query_tx_clone = query_tx.clone();
match HttpsServer::start(
https::HttpsServerConfig {
bind_addr: addr,
tls_config,
},
move |packet| {
let resolver = resolver_clone.clone();
let pending = pending_clone.clone();
let query_tx = query_tx_clone.clone();
async move {
resolve_with_callback(packet, &resolver, &pending, &query_tx).await
}
},
).await {
Ok(server) => {
info!("HTTPS DoH server started on {}", addr);
Some(server)
}
Err(e) => {
return IpcResponse::err(id, format!("Failed to start HTTPS server: {}", e));
}
}
}
Err(e) => {
return IpcResponse::err(id, format!("Failed to configure TLS: {}", e));
}
}
} else {
None
};
*server_state = Some(ServerState {
udp_server,
https_server,
resolver,
});
send_event("started", serde_json::json!({}));
IpcResponse::ok(id, serde_json::json!({}))
}
fn handle_stop(id: String, server_state: &mut Option<ServerState>) -> IpcResponse {
if let Some(ref state) = server_state {
if let Some(ref udp) = state.udp_server {
udp.stop();
}
if let Some(ref https) = state.https_server {
https.stop();
}
}
*server_state = None;
send_event("stopped", serde_json::json!({}));
IpcResponse::ok(id, serde_json::json!({}))
}
fn handle_query_result(
id: String,
params: &serde_json::Value,
pending: &PendingCallbacks,
) -> IpcResponse {
let result: DnsQueryResult = match serde_json::from_value(params.clone()) {
Ok(r) => r,
Err(e) => return IpcResponse::err(id, format!("Invalid query result: {}", e)),
};
let correlation_id = result.correlation_id.clone();
if let Some((_, sender)) = pending.remove(&correlation_id) {
let _ = sender.send(result);
IpcResponse::ok(id, serde_json::json!({ "resolved": true }))
} else {
IpcResponse::err(id, format!("No pending query for correlationId: {}", correlation_id))
}
}
async fn handle_process_packet(
id: String,
params: &serde_json::Value,
server_state: &mut Option<ServerState>,
pending: &PendingCallbacks,
query_tx: &mpsc::Sender<(String, DnsPacket)>,
) -> IpcResponse {
let packet_b64 = match params.get("packet").and_then(|v| v.as_str()) {
Some(p) => p,
None => return IpcResponse::err(id, "Missing packet parameter".to_string()),
};
let packet_data = match base64_decode(packet_b64) {
Ok(d) => d,
Err(e) => return IpcResponse::err(id, format!("Invalid base64: {}", e)),
};
let state = match server_state {
Some(ref s) => s,
None => return IpcResponse::err(id, "Server not started".to_string()),
};
let request = match DnsPacket::parse(&packet_data) {
Ok(p) => p,
Err(e) => return IpcResponse::err(id, format!("Failed to parse packet: {}", e)),
};
let response = resolve_with_callback(request, &state.resolver, pending, query_tx).await;
let encoded = response.encode();
use base64::Engine;
let response_b64 = base64::engine::general_purpose::STANDARD.encode(&encoded);
IpcResponse::ok(id, serde_json::json!({ "packet": response_b64 }))
}
/// Core resolution: try local first, then IPC callback to TypeScript.
async fn resolve_with_callback(
packet: DnsPacket,
resolver: &DnsResolver,
pending: &PendingCallbacks,
query_tx: &mpsc::Sender<(String, DnsPacket)>,
) -> DnsPacket {
// Try local resolution first (localhost, DNSKEY)
if let Some(response) = resolver.try_local_resolution(&packet) {
return response;
}
// Need IPC callback to TypeScript
let correlation_id = format!("dns_{}", uuid_v4());
let (tx, rx) = oneshot::channel();
pending.insert(correlation_id.clone(), tx);
// Send the query event to the management loop for emission
if query_tx.send((correlation_id.clone(), packet.clone())).await.is_err() {
pending.remove(&correlation_id);
return DnsPacket::new_response(&packet);
}
// Wait for the result with a timeout
match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
Ok(Ok(result)) => {
resolver.build_response_from_answers(&packet, &result.answers, result.answered)
}
Ok(Err(_)) => {
// Sender dropped
pending.remove(&correlation_id);
resolver.build_response_from_answers(&packet, &[], false)
}
Err(_) => {
// Timeout
pending.remove(&correlation_id);
resolver.build_response_from_answers(&packet, &[], false)
}
}
}
/// Simple UUID v4 generation (no external dep needed).
fn uuid_v4() -> String {
use std::time::{SystemTime, UNIX_EPOCH};
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let random: u64 = nanos as u64 ^ (std::process::id() as u64 * 0x517cc1b727220a95);
format!("{:016x}{:016x}", nanos as u64, random)
}
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
use base64::Engine;
base64::engine::general_purpose::STANDARD
.decode(input)
.map_err(|e| e.to_string())
}

View File

@@ -0,0 +1,258 @@
use crate::ipc_types::{IpcDnsAnswer, IpcDnsQuestion};
use rustdns_protocol::packet::*;
use rustdns_protocol::types::QType;
use rustdns_dnssec::keys::{DnssecAlgorithm, DnssecKeyPair};
use rustdns_dnssec::keytag::compute_key_tag;
use rustdns_dnssec::signing::generate_rrsig;
use std::collections::HashMap;
/// DNS resolver that builds responses from IPC callback answers.
pub struct DnsResolver {
pub zone: String,
pub primary_nameserver: String,
pub enable_localhost: bool,
pub key_pair: DnssecKeyPair,
pub dnskey_rdata: Vec<u8>,
pub key_tag: u16,
}
impl DnsResolver {
pub fn new(zone: &str, algorithm: DnssecAlgorithm, primary_nameserver: &str, enable_localhost: bool) -> Self {
let key_pair = DnssecKeyPair::generate(algorithm);
let dnskey_rdata = key_pair.dnskey_rdata();
let key_tag = compute_key_tag(&dnskey_rdata);
let primary_ns = if primary_nameserver.is_empty() {
format!("ns1.{}", zone)
} else {
primary_nameserver.to_string()
};
DnsResolver {
zone: zone.to_string(),
primary_nameserver: primary_ns,
enable_localhost,
key_pair,
dnskey_rdata,
key_tag,
}
}
/// Check if a query can be answered locally (localhost, DNSKEY).
/// Returns Some(answers) if handled locally, None if it needs IPC callback.
pub fn try_local_resolution(&self, packet: &DnsPacket) -> Option<DnsPacket> {
let dnssec = packet.is_dnssec_requested();
let mut response = DnsPacket::new_response(packet);
let mut all_local = true;
for q in &packet.questions {
if let Some(records) = self.try_local_question(q, dnssec) {
for r in records {
response.answers.push(r);
}
} else {
all_local = false;
}
}
if all_local && !packet.questions.is_empty() {
Some(response)
} else {
None
}
}
fn try_local_question(&self, q: &DnsQuestion, dnssec: bool) -> Option<Vec<DnsRecord>> {
let name_lower = q.name.to_lowercase();
let name_trimmed = name_lower.strip_suffix('.').unwrap_or(&name_lower);
// DNSKEY queries for our zone
if dnssec && q.qtype == QType::DNSKEY && name_trimmed == self.zone.to_lowercase() {
let record = build_record(&q.name, QType::DNSKEY, 3600, self.dnskey_rdata.clone());
let mut records = vec![record.clone()];
// Sign the DNSKEY record
let rrsig = generate_rrsig(&self.key_pair, &self.zone, &[record], &q.name, QType::DNSKEY);
records.push(rrsig);
return Some(records);
}
// Localhost handling (RFC 6761)
if self.enable_localhost {
if name_trimmed == "localhost" {
match q.qtype {
QType::A => {
return Some(vec![build_record(&q.name, QType::A, 0, encode_a("127.0.0.1"))]);
}
QType::AAAA => {
return Some(vec![build_record(&q.name, QType::AAAA, 0, encode_aaaa("::1"))]);
}
_ => {}
}
}
// Reverse localhost
if name_trimmed == "1.0.0.127.in-addr.arpa" && q.qtype == QType::PTR {
return Some(vec![build_record(&q.name, QType::PTR, 0, encode_name_rdata("localhost."))]);
}
}
None
}
/// Build a response from IPC callback answers.
pub fn build_response_from_answers(
&self,
request: &DnsPacket,
answers: &[IpcDnsAnswer],
answered: bool,
) -> DnsPacket {
let dnssec = request.is_dnssec_requested();
let mut response = DnsPacket::new_response(request);
if answered && !answers.is_empty() {
// Group answers by (name, type) for DNSSEC RRset signing
let mut rrset_map: HashMap<(String, QType), Vec<DnsRecord>> = HashMap::new();
for answer in answers {
let rtype = QType::from_str(&answer.rtype);
let rdata = self.encode_answer_rdata(rtype, &answer.data);
let record = build_record(&answer.name, rtype, answer.ttl, rdata);
response.answers.push(record.clone());
if dnssec {
let key = (answer.name.clone(), rtype);
rrset_map.entry(key).or_default().push(record);
}
}
// Sign RRsets
if dnssec {
for ((name, rtype), rrset) in &rrset_map {
let rrsig = generate_rrsig(&self.key_pair, &self.zone, rrset, name, *rtype);
response.answers.push(rrsig);
}
}
} else {
// No handler matched — return SOA
for q in &request.questions {
let soa_rdata = encode_soa(
&self.primary_nameserver,
&format!("hostmaster.{}", self.zone),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as u32,
3600,
600,
604800,
86400,
);
let soa_record = build_record(&q.name, QType::SOA, 3600, soa_rdata);
response.answers.push(soa_record.clone());
if dnssec {
let rrsig = generate_rrsig(&self.key_pair, &self.zone, &[soa_record], &q.name, QType::SOA);
response.answers.push(rrsig);
}
}
}
response
}
/// Process a raw DNS packet (for manual/passthrough mode).
/// Returns local answers or None if IPC callback is needed.
pub fn process_packet_local(&self, data: &[u8]) -> Result<Option<Vec<u8>>, String> {
let packet = DnsPacket::parse(data)?;
if let Some(response) = self.try_local_resolution(&packet) {
Ok(Some(response.encode()))
} else {
Ok(None)
}
}
fn encode_answer_rdata(&self, rtype: QType, data: &serde_json::Value) -> Vec<u8> {
match rtype {
QType::A => {
if let Some(ip) = data.as_str() {
encode_a(ip)
} else {
vec![]
}
}
QType::AAAA => {
if let Some(ip) = data.as_str() {
encode_aaaa(ip)
} else {
vec![]
}
}
QType::TXT => {
if let Some(arr) = data.as_array() {
let strings: Vec<String> = arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect();
encode_txt(&strings)
} else if let Some(s) = data.as_str() {
encode_txt(&[s.to_string()])
} else {
vec![]
}
}
QType::NS | QType::CNAME | QType::PTR => {
if let Some(name) = data.as_str() {
encode_name_rdata(name)
} else {
vec![]
}
}
QType::MX => {
let preference = data.get("preference").and_then(|v| v.as_u64()).unwrap_or(10) as u16;
let exchange = data.get("exchange").and_then(|v| v.as_str()).unwrap_or("");
encode_mx(preference, exchange)
}
QType::SRV => {
let priority = data.get("priority").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let weight = data.get("weight").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let port = data.get("port").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let target = data.get("target").and_then(|v| v.as_str()).unwrap_or("");
encode_srv(priority, weight, port, target)
}
QType::SOA => {
let mname = data.get("mname").and_then(|v| v.as_str()).unwrap_or("");
let rname = data.get("rname").and_then(|v| v.as_str()).unwrap_or("");
let serial = data.get("serial").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let refresh = data.get("refresh").and_then(|v| v.as_u64()).unwrap_or(3600) as u32;
let retry = data.get("retry").and_then(|v| v.as_u64()).unwrap_or(600) as u32;
let expire = data.get("expire").and_then(|v| v.as_u64()).unwrap_or(604800) as u32;
let minimum = data.get("minimum").and_then(|v| v.as_u64()).unwrap_or(86400) as u32;
encode_soa(mname, rname, serial, refresh, retry, expire, minimum)
}
_ => {
// For unknown types, try to interpret as raw base64
if let Some(b64) = data.as_str() {
base64_decode(b64).unwrap_or_default()
} else {
vec![]
}
}
}
}
/// Convert questions to IPC format.
pub fn questions_to_ipc(questions: &[DnsQuestion]) -> Vec<IpcDnsQuestion> {
questions
.iter()
.map(|q| IpcDnsQuestion {
name: q.name.clone(),
qtype: q.qtype.as_str().to_string(),
class: "IN".to_string(),
})
.collect()
}
}
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
use base64::Engine;
base64::engine::general_purpose::STANDARD
.decode(input)
.map_err(|e| e.to_string())
}