feat(rust): add Rust-based DNS server backend with IPC management and TypeScript bridge
This commit is contained in:
2
rust/.cargo/config.toml
Normal file
2
rust/.cargo/config.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
linker = "aarch64-linux-gnu-gcc"
|
||||
1446
rust/Cargo.lock
generated
Normal file
1446
rust/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
8
rust/Cargo.toml
Normal file
8
rust/Cargo.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/rustdns",
|
||||
"crates/rustdns-protocol",
|
||||
"crates/rustdns-server",
|
||||
"crates/rustdns-dnssec",
|
||||
]
|
||||
11
rust/crates/rustdns-dnssec/Cargo.toml
Normal file
11
rust/crates/rustdns-dnssec/Cargo.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
[package]
|
||||
name = "rustdns-dnssec"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
rustdns-protocol = { path = "../rustdns-protocol" }
|
||||
p256 = { version = "0.13", features = ["ecdsa", "ecdsa-core"] }
|
||||
ed25519-dalek = { version = "2", features = ["rand_core"] }
|
||||
sha2 = "0.10"
|
||||
rand = "0.8"
|
||||
157
rust/crates/rustdns-dnssec/src/keys.rs
Normal file
157
rust/crates/rustdns-dnssec/src/keys.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use p256::ecdsa::SigningKey as EcdsaSigningKey;
|
||||
use ed25519_dalek::SigningKey as Ed25519SigningKey;
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
/// Supported DNSSEC algorithms.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum DnssecAlgorithm {
|
||||
/// ECDSA P-256 with SHA-256 (algorithm 13)
|
||||
EcdsaP256Sha256,
|
||||
/// ED25519 (algorithm 15)
|
||||
Ed25519,
|
||||
}
|
||||
|
||||
impl DnssecAlgorithm {
|
||||
pub fn number(&self) -> u8 {
|
||||
match self {
|
||||
DnssecAlgorithm::EcdsaP256Sha256 => 13,
|
||||
DnssecAlgorithm::Ed25519 => 15,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s.to_uppercase().as_str() {
|
||||
"ECDSA" | "ECDSAP256SHA256" => Some(DnssecAlgorithm::EcdsaP256Sha256),
|
||||
"ED25519" => Some(DnssecAlgorithm::Ed25519),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A DNSSEC key pair with material for signing and DNSKEY generation.
|
||||
pub enum DnssecKeyPair {
|
||||
EcdsaP256 {
|
||||
signing_key: EcdsaSigningKey,
|
||||
},
|
||||
Ed25519 {
|
||||
signing_key: Ed25519SigningKey,
|
||||
},
|
||||
}
|
||||
|
||||
impl DnssecKeyPair {
|
||||
/// Generate a new key pair for the given algorithm.
|
||||
pub fn generate(algorithm: DnssecAlgorithm) -> Self {
|
||||
match algorithm {
|
||||
DnssecAlgorithm::EcdsaP256Sha256 => {
|
||||
let signing_key = EcdsaSigningKey::random(&mut OsRng);
|
||||
DnssecKeyPair::EcdsaP256 { signing_key }
|
||||
}
|
||||
DnssecAlgorithm::Ed25519 => {
|
||||
let signing_key = Ed25519SigningKey::generate(&mut OsRng);
|
||||
DnssecKeyPair::Ed25519 { signing_key }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the algorithm.
|
||||
pub fn algorithm(&self) -> DnssecAlgorithm {
|
||||
match self {
|
||||
DnssecKeyPair::EcdsaP256 { .. } => DnssecAlgorithm::EcdsaP256Sha256,
|
||||
DnssecKeyPair::Ed25519 { .. } => DnssecAlgorithm::Ed25519,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the public key bytes for the DNSKEY record.
|
||||
/// For ECDSA P-256: 64 bytes (uncompressed x || y, without 0x04 prefix).
|
||||
/// For ED25519: 32 bytes.
|
||||
pub fn public_key_bytes(&self) -> Vec<u8> {
|
||||
match self {
|
||||
DnssecKeyPair::EcdsaP256 { signing_key } => {
|
||||
use p256::ecdsa::VerifyingKey;
|
||||
let verifying_key = VerifyingKey::from(signing_key);
|
||||
let point = verifying_key.to_encoded_point(false); // uncompressed
|
||||
let bytes = point.as_bytes();
|
||||
// Remove 0x04 prefix for DNS format
|
||||
bytes[1..].to_vec()
|
||||
}
|
||||
DnssecKeyPair::Ed25519 { signing_key } => {
|
||||
let verifying_key = signing_key.verifying_key();
|
||||
verifying_key.as_bytes().to_vec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the DNSKEY RDATA (flags=256/ZSK, protocol=3, algorithm, public key).
|
||||
pub fn dnskey_rdata(&self) -> Vec<u8> {
|
||||
let flags: u16 = 256; // Zone Signing Key
|
||||
let protocol: u8 = 3;
|
||||
let algorithm = self.algorithm().number();
|
||||
let pubkey = self.public_key_bytes();
|
||||
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&flags.to_be_bytes());
|
||||
buf.push(protocol);
|
||||
buf.push(algorithm);
|
||||
buf.extend_from_slice(&pubkey);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Sign data with this key pair.
|
||||
pub fn sign(&self, data: &[u8]) -> Vec<u8> {
|
||||
match self {
|
||||
DnssecKeyPair::EcdsaP256 { signing_key } => {
|
||||
use p256::ecdsa::{signature::Signer, Signature};
|
||||
let sig: Signature = signing_key.sign(data);
|
||||
sig.to_der().as_bytes().to_vec()
|
||||
}
|
||||
DnssecKeyPair::Ed25519 { signing_key } => {
|
||||
use ed25519_dalek::Signer;
|
||||
let sig = signing_key.sign(data);
|
||||
sig.to_bytes().to_vec()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ecdsa_key_generation() {
|
||||
let kp = DnssecKeyPair::generate(DnssecAlgorithm::EcdsaP256Sha256);
|
||||
assert_eq!(kp.algorithm(), DnssecAlgorithm::EcdsaP256Sha256);
|
||||
assert_eq!(kp.public_key_bytes().len(), 64); // x(32) + y(32)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ed25519_key_generation() {
|
||||
let kp = DnssecKeyPair::generate(DnssecAlgorithm::Ed25519);
|
||||
assert_eq!(kp.algorithm(), DnssecAlgorithm::Ed25519);
|
||||
assert_eq!(kp.public_key_bytes().len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dnskey_rdata() {
|
||||
let kp = DnssecKeyPair::generate(DnssecAlgorithm::EcdsaP256Sha256);
|
||||
let rdata = kp.dnskey_rdata();
|
||||
// flags(2) + protocol(1) + algorithm(1) + pubkey(64) = 68
|
||||
assert_eq!(rdata.len(), 68);
|
||||
assert_eq!(rdata[0], 1); // flags high byte (256 >> 8)
|
||||
assert_eq!(rdata[1], 0); // flags low byte
|
||||
assert_eq!(rdata[2], 3); // protocol
|
||||
assert_eq!(rdata[3], 13); // algorithm 13 = ECDSA P-256
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sign_and_verify() {
|
||||
let kp = DnssecKeyPair::generate(DnssecAlgorithm::EcdsaP256Sha256);
|
||||
let data = b"test data to sign";
|
||||
let sig = kp.sign(data);
|
||||
assert!(!sig.is_empty());
|
||||
|
||||
let kp2 = DnssecKeyPair::generate(DnssecAlgorithm::Ed25519);
|
||||
let sig2 = kp2.sign(data);
|
||||
assert!(!sig2.is_empty());
|
||||
}
|
||||
}
|
||||
38
rust/crates/rustdns-dnssec/src/keytag.rs
Normal file
38
rust/crates/rustdns-dnssec/src/keytag.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
/// Compute the DNSSEC key tag as per RFC 4034 Appendix B.
|
||||
/// Input is the full DNSKEY RDATA (flags + protocol + algorithm + public key).
|
||||
pub fn compute_key_tag(dnskey_rdata: &[u8]) -> u16 {
|
||||
let mut acc: u32 = 0;
|
||||
for (i, &byte) in dnskey_rdata.iter().enumerate() {
|
||||
if i & 1 == 0 {
|
||||
acc += (byte as u32) << 8;
|
||||
} else {
|
||||
acc += byte as u32;
|
||||
}
|
||||
}
|
||||
acc += (acc >> 16) & 0xFFFF;
|
||||
(acc & 0xFFFF) as u16
|
||||
}
|
||||
|
||||
/// Compute a DS record digest (SHA-256) from owner name + DNSKEY RDATA.
|
||||
pub fn compute_ds_digest(owner_name_wire: &[u8], dnskey_rdata: &[u8]) -> Vec<u8> {
|
||||
use sha2::{Sha256, Digest};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(owner_name_wire);
|
||||
hasher.update(dnskey_rdata);
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_key_tag_computation() {
|
||||
// A known DNSKEY RDATA: flags=256, protocol=3, algorithm=13, plus some key bytes
|
||||
let mut rdata = vec![1u8, 0, 3, 13]; // flags=256, protocol=3, algorithm=13
|
||||
rdata.extend_from_slice(&[0u8; 64]); // dummy 64-byte key
|
||||
let tag = compute_key_tag(&rdata);
|
||||
// Just verify it produces a reasonable value
|
||||
assert!(tag > 0);
|
||||
}
|
||||
}
|
||||
3
rust/crates/rustdns-dnssec/src/lib.rs
Normal file
3
rust/crates/rustdns-dnssec/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod keys;
|
||||
pub mod signing;
|
||||
pub mod keytag;
|
||||
147
rust/crates/rustdns-dnssec/src/signing.rs
Normal file
147
rust/crates/rustdns-dnssec/src/signing.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
use crate::keys::DnssecKeyPair;
|
||||
use crate::keytag::compute_key_tag;
|
||||
use rustdns_protocol::name::encode_name;
|
||||
use rustdns_protocol::packet::{encode_rrsig, DnsRecord};
|
||||
use rustdns_protocol::types::QType;
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
/// Canonical RRset serialization for DNSSEC signing (RFC 4034 Section 6).
|
||||
/// Each record: name(wire) + type(2) + class(2) + ttl(4) + rdlength(2) + rdata
|
||||
pub fn serialize_rrset_canonical(records: &[DnsRecord]) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
for rr in records {
|
||||
if rr.rtype == QType::OPT {
|
||||
continue;
|
||||
}
|
||||
let name = if rr.name.ends_with('.') {
|
||||
rr.name.to_lowercase()
|
||||
} else {
|
||||
format!("{}.", rr.name).to_lowercase()
|
||||
};
|
||||
buf.extend_from_slice(&encode_name(&name));
|
||||
buf.extend_from_slice(&rr.rtype.to_u16().to_be_bytes());
|
||||
buf.extend_from_slice(&rr.rclass.to_u16().to_be_bytes());
|
||||
buf.extend_from_slice(&rr.ttl.to_be_bytes());
|
||||
buf.extend_from_slice(&(rr.rdata.len() as u16).to_be_bytes());
|
||||
buf.extend_from_slice(&rr.rdata);
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Generate an RRSIG record for a given RRset.
|
||||
pub fn generate_rrsig(
|
||||
key_pair: &DnssecKeyPair,
|
||||
zone: &str,
|
||||
rrset: &[DnsRecord],
|
||||
name: &str,
|
||||
rtype: QType,
|
||||
) -> DnsRecord {
|
||||
let algorithm = key_pair.algorithm().number();
|
||||
let dnskey_rdata = key_pair.dnskey_rdata();
|
||||
let key_tag = compute_key_tag(&dnskey_rdata);
|
||||
|
||||
let signers_name = if zone.ends_with('.') {
|
||||
zone.to_string()
|
||||
} else {
|
||||
format!("{}.", zone)
|
||||
};
|
||||
|
||||
let ttl = if rrset.is_empty() { 300 } else { rrset[0].ttl };
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as u32;
|
||||
let inception = now.wrapping_sub(3600); // 1 hour ago
|
||||
let expiration = inception.wrapping_add(86400); // +1 day
|
||||
|
||||
let labels = name
|
||||
.strip_suffix('.')
|
||||
.unwrap_or(name)
|
||||
.split('.')
|
||||
.filter(|l| !l.is_empty())
|
||||
.count() as u8;
|
||||
|
||||
// Build the RRSIG RDATA preamble (everything before the signature)
|
||||
let type_covered = rtype.to_u16();
|
||||
let mut sig_data = Vec::new();
|
||||
sig_data.extend_from_slice(&type_covered.to_be_bytes());
|
||||
sig_data.push(algorithm);
|
||||
sig_data.push(labels);
|
||||
sig_data.extend_from_slice(&ttl.to_be_bytes());
|
||||
sig_data.extend_from_slice(&expiration.to_be_bytes());
|
||||
sig_data.extend_from_slice(&inception.to_be_bytes());
|
||||
sig_data.extend_from_slice(&key_tag.to_be_bytes());
|
||||
sig_data.extend_from_slice(&encode_name(&signers_name));
|
||||
|
||||
// Append the canonical RRset
|
||||
sig_data.extend_from_slice(&serialize_rrset_canonical(rrset));
|
||||
|
||||
// Sign: ECDSA uses SHA-256 internally via the p256 crate, ED25519 does its own hashing
|
||||
let signature = match key_pair {
|
||||
DnssecKeyPair::EcdsaP256 { .. } => {
|
||||
// For ECDSA, we hash first then sign
|
||||
let hash = Sha256::digest(&sig_data);
|
||||
key_pair.sign(&hash)
|
||||
}
|
||||
DnssecKeyPair::Ed25519 { .. } => {
|
||||
// ED25519 includes hashing internally
|
||||
key_pair.sign(&sig_data)
|
||||
}
|
||||
};
|
||||
|
||||
let rrsig_rdata = encode_rrsig(
|
||||
type_covered,
|
||||
algorithm,
|
||||
labels,
|
||||
ttl,
|
||||
expiration,
|
||||
inception,
|
||||
key_tag,
|
||||
&signers_name,
|
||||
&signature,
|
||||
);
|
||||
|
||||
DnsRecord {
|
||||
name: name.to_string(),
|
||||
rtype: QType::RRSIG,
|
||||
rclass: rustdns_protocol::types::QClass::IN,
|
||||
ttl,
|
||||
rdata: rrsig_rdata,
|
||||
opt_flags: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::keys::{DnssecAlgorithm, DnssecKeyPair};
|
||||
use rustdns_protocol::packet::{build_record, encode_a};
|
||||
|
||||
#[test]
|
||||
fn test_generate_rrsig_ecdsa() {
|
||||
let kp = DnssecKeyPair::generate(DnssecAlgorithm::EcdsaP256Sha256);
|
||||
let record = build_record("test.example.com", QType::A, 300, encode_a("127.0.0.1"));
|
||||
let rrsig = generate_rrsig(&kp, "example.com", &[record], "test.example.com", QType::A);
|
||||
|
||||
assert_eq!(rrsig.rtype, QType::RRSIG);
|
||||
assert!(!rrsig.rdata.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_rrsig_ed25519() {
|
||||
let kp = DnssecKeyPair::generate(DnssecAlgorithm::Ed25519);
|
||||
let record = build_record("test.example.com", QType::A, 300, encode_a("10.0.0.1"));
|
||||
let rrsig = generate_rrsig(&kp, "example.com", &[record], "test.example.com", QType::A);
|
||||
|
||||
assert_eq!(rrsig.rtype, QType::RRSIG);
|
||||
assert!(!rrsig.rdata.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize_rrset_canonical() {
|
||||
let r1 = build_record("example.com", QType::A, 300, encode_a("1.2.3.4"));
|
||||
let r2 = build_record("example.com", QType::A, 300, encode_a("5.6.7.8"));
|
||||
let serialized = serialize_rrset_canonical(&[r1, r2]);
|
||||
assert!(!serialized.is_empty());
|
||||
}
|
||||
}
|
||||
6
rust/crates/rustdns-protocol/Cargo.toml
Normal file
6
rust/crates/rustdns-protocol/Cargo.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[package]
|
||||
name = "rustdns-protocol"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
3
rust/crates/rustdns-protocol/src/lib.rs
Normal file
3
rust/crates/rustdns-protocol/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod types;
|
||||
pub mod name;
|
||||
pub mod packet;
|
||||
108
rust/crates/rustdns-protocol/src/name.rs
Normal file
108
rust/crates/rustdns-protocol/src/name.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
/// Encode a domain name into DNS wire format.
|
||||
/// e.g. "example.com" -> [7, 'e','x','a','m','p','l','e', 3, 'c','o','m', 0]
|
||||
pub fn encode_name(name: &str) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
let trimmed = name.strip_suffix('.').unwrap_or(name);
|
||||
if trimmed.is_empty() {
|
||||
buf.push(0);
|
||||
return buf;
|
||||
}
|
||||
for label in trimmed.split('.') {
|
||||
let len = label.len();
|
||||
if len > 63 {
|
||||
// Truncate to 63 per DNS spec
|
||||
buf.push(63);
|
||||
buf.extend_from_slice(&label.as_bytes()[..63]);
|
||||
} else {
|
||||
buf.push(len as u8);
|
||||
buf.extend_from_slice(label.as_bytes());
|
||||
}
|
||||
}
|
||||
buf.push(0); // root label
|
||||
buf
|
||||
}
|
||||
|
||||
/// Decode a domain name from DNS wire format at the given offset.
|
||||
/// Returns (name, bytes_consumed).
|
||||
/// Handles compression pointers (0xC0 prefix).
|
||||
pub fn decode_name(data: &[u8], offset: usize) -> Result<(String, usize), &'static str> {
|
||||
let mut labels: Vec<String> = Vec::new();
|
||||
let mut pos = offset;
|
||||
let mut bytes_consumed = 0;
|
||||
let mut jumped = false;
|
||||
|
||||
loop {
|
||||
if pos >= data.len() {
|
||||
return Err("unexpected end of data in name");
|
||||
}
|
||||
let len = data[pos] as usize;
|
||||
|
||||
if len == 0 {
|
||||
// Root label
|
||||
if !jumped {
|
||||
bytes_consumed = pos - offset + 1;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Check for compression pointer
|
||||
if len & 0xC0 == 0xC0 {
|
||||
if pos + 1 >= data.len() {
|
||||
return Err("unexpected end of data in compression pointer");
|
||||
}
|
||||
let pointer = ((len & 0x3F) << 8) | (data[pos + 1] as usize);
|
||||
if !jumped {
|
||||
bytes_consumed = pos - offset + 2;
|
||||
jumped = true;
|
||||
}
|
||||
pos = pointer;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Regular label
|
||||
pos += 1;
|
||||
if pos + len > data.len() {
|
||||
return Err("label extends beyond data");
|
||||
}
|
||||
let label = std::str::from_utf8(&data[pos..pos + len]).map_err(|_| "invalid UTF-8 in label")?;
|
||||
labels.push(label.to_string());
|
||||
pos += len;
|
||||
}
|
||||
|
||||
if bytes_consumed == 0 && !jumped {
|
||||
bytes_consumed = 1; // just the root label
|
||||
}
|
||||
|
||||
Ok((labels.join("."), bytes_consumed))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode_roundtrip() {
|
||||
let names = vec!["example.com", "sub.domain.example.com", "localhost", "a.b.c.d.e"];
|
||||
for name in names {
|
||||
let encoded = encode_name(name);
|
||||
let (decoded, consumed) = decode_name(&encoded, 0).unwrap();
|
||||
assert_eq!(decoded, name);
|
||||
assert_eq!(consumed, encoded.len());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_trailing_dot() {
|
||||
let a = encode_name("example.com.");
|
||||
let b = encode_name("example.com");
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_root_name() {
|
||||
let encoded = encode_name("");
|
||||
assert_eq!(encoded, vec![0]);
|
||||
let (decoded, _) = decode_name(&encoded, 0).unwrap();
|
||||
assert_eq!(decoded, "");
|
||||
}
|
||||
}
|
||||
442
rust/crates/rustdns-protocol/src/packet.rs
Normal file
442
rust/crates/rustdns-protocol/src/packet.rs
Normal file
@@ -0,0 +1,442 @@
|
||||
use crate::name::{decode_name, encode_name};
|
||||
use crate::types::{QClass, QType, FLAG_QR, FLAG_AA, FLAG_RD, FLAG_RA, EDNS_DO_BIT};
|
||||
|
||||
/// A parsed DNS question.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DnsQuestion {
|
||||
pub name: String,
|
||||
pub qtype: QType,
|
||||
pub qclass: QClass,
|
||||
}
|
||||
|
||||
/// A parsed DNS resource record.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DnsRecord {
|
||||
pub name: String,
|
||||
pub rtype: QType,
|
||||
pub rclass: QClass,
|
||||
pub ttl: u32,
|
||||
pub rdata: Vec<u8>,
|
||||
// For OPT records, the flags are stored in the TTL field position
|
||||
pub opt_flags: Option<u16>,
|
||||
}
|
||||
|
||||
/// A complete DNS packet (parsed).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DnsPacket {
|
||||
pub id: u16,
|
||||
pub flags: u16,
|
||||
pub questions: Vec<DnsQuestion>,
|
||||
pub answers: Vec<DnsRecord>,
|
||||
pub authorities: Vec<DnsRecord>,
|
||||
pub additionals: Vec<DnsRecord>,
|
||||
}
|
||||
|
||||
impl DnsPacket {
|
||||
/// Create a new empty query packet.
|
||||
pub fn new_query(id: u16) -> Self {
|
||||
DnsPacket {
|
||||
id,
|
||||
flags: 0,
|
||||
questions: Vec::new(),
|
||||
answers: Vec::new(),
|
||||
authorities: Vec::new(),
|
||||
additionals: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a response packet for a given request.
|
||||
pub fn new_response(request: &DnsPacket) -> Self {
|
||||
let mut flags = FLAG_QR | FLAG_AA | FLAG_RA;
|
||||
if request.flags & FLAG_RD != 0 {
|
||||
flags |= FLAG_RD;
|
||||
}
|
||||
DnsPacket {
|
||||
id: request.id,
|
||||
flags,
|
||||
questions: request.questions.clone(),
|
||||
answers: Vec::new(),
|
||||
authorities: Vec::new(),
|
||||
additionals: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if DNSSEC (DO bit) is requested in the OPT record.
|
||||
pub fn is_dnssec_requested(&self) -> bool {
|
||||
for additional in &self.additionals {
|
||||
if additional.rtype == QType::OPT {
|
||||
if let Some(flags) = additional.opt_flags {
|
||||
if flags & EDNS_DO_BIT != 0 {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Parse a DNS packet from wire format bytes.
|
||||
pub fn parse(data: &[u8]) -> Result<Self, String> {
|
||||
if data.len() < 12 {
|
||||
return Err("packet too short for DNS header".into());
|
||||
}
|
||||
|
||||
let id = u16::from_be_bytes([data[0], data[1]]);
|
||||
let flags = u16::from_be_bytes([data[2], data[3]]);
|
||||
let qdcount = u16::from_be_bytes([data[4], data[5]]) as usize;
|
||||
let ancount = u16::from_be_bytes([data[6], data[7]]) as usize;
|
||||
let nscount = u16::from_be_bytes([data[8], data[9]]) as usize;
|
||||
let arcount = u16::from_be_bytes([data[10], data[11]]) as usize;
|
||||
|
||||
let mut offset = 12;
|
||||
|
||||
// Parse questions
|
||||
let mut questions = Vec::with_capacity(qdcount);
|
||||
for _ in 0..qdcount {
|
||||
let (name, consumed) = decode_name(data, offset).map_err(|e| e.to_string())?;
|
||||
offset += consumed;
|
||||
if offset + 4 > data.len() {
|
||||
return Err("packet too short for question fields".into());
|
||||
}
|
||||
let qtype = QType::from_u16(u16::from_be_bytes([data[offset], data[offset + 1]]));
|
||||
let qclass = QClass::from_u16(u16::from_be_bytes([data[offset + 2], data[offset + 3]]));
|
||||
offset += 4;
|
||||
questions.push(DnsQuestion { name, qtype, qclass });
|
||||
}
|
||||
|
||||
// Parse resource records
|
||||
fn parse_records(data: &[u8], offset: &mut usize, count: usize) -> Result<Vec<DnsRecord>, String> {
|
||||
let mut records = Vec::with_capacity(count);
|
||||
for _ in 0..count {
|
||||
let (name, consumed) = decode_name(data, *offset).map_err(|e| e.to_string())?;
|
||||
*offset += consumed;
|
||||
if *offset + 10 > data.len() {
|
||||
return Err("packet too short for RR fields".into());
|
||||
}
|
||||
let rtype = QType::from_u16(u16::from_be_bytes([data[*offset], data[*offset + 1]]));
|
||||
let rclass_or_payload = u16::from_be_bytes([data[*offset + 2], data[*offset + 3]]);
|
||||
let ttl_bytes = u32::from_be_bytes([data[*offset + 4], data[*offset + 5], data[*offset + 6], data[*offset + 7]]);
|
||||
let rdlength = u16::from_be_bytes([data[*offset + 8], data[*offset + 9]]) as usize;
|
||||
*offset += 10;
|
||||
if *offset + rdlength > data.len() {
|
||||
return Err("packet too short for RDATA".into());
|
||||
}
|
||||
let rdata = data[*offset..*offset + rdlength].to_vec();
|
||||
*offset += rdlength;
|
||||
|
||||
// For OPT records, extract flags from the TTL position
|
||||
let (rclass, ttl, opt_flags) = if rtype == QType::OPT {
|
||||
// OPT: class = UDP payload size, TTL upper 16 = extended RCODE + version, lower 16 = flags
|
||||
let flags = (ttl_bytes & 0xFFFF) as u16;
|
||||
(QClass::from_u16(rclass_or_payload), 0, Some(flags))
|
||||
} else {
|
||||
(QClass::from_u16(rclass_or_payload), ttl_bytes, None)
|
||||
};
|
||||
|
||||
records.push(DnsRecord {
|
||||
name,
|
||||
rtype,
|
||||
rclass,
|
||||
ttl,
|
||||
rdata,
|
||||
opt_flags,
|
||||
});
|
||||
}
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
let answers = parse_records(data, &mut offset, ancount)?;
|
||||
let authorities = parse_records(data, &mut offset, nscount)?;
|
||||
let additionals = parse_records(data, &mut offset, arcount)?;
|
||||
|
||||
Ok(DnsPacket {
|
||||
id,
|
||||
flags,
|
||||
questions,
|
||||
answers,
|
||||
authorities,
|
||||
additionals,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode this DNS packet to wire format bytes.
|
||||
pub fn encode(&self) -> Vec<u8> {
|
||||
let mut buf = Vec::with_capacity(512);
|
||||
|
||||
// Header
|
||||
buf.extend_from_slice(&self.id.to_be_bytes());
|
||||
buf.extend_from_slice(&self.flags.to_be_bytes());
|
||||
buf.extend_from_slice(&(self.questions.len() as u16).to_be_bytes());
|
||||
buf.extend_from_slice(&(self.answers.len() as u16).to_be_bytes());
|
||||
buf.extend_from_slice(&(self.authorities.len() as u16).to_be_bytes());
|
||||
buf.extend_from_slice(&(self.additionals.len() as u16).to_be_bytes());
|
||||
|
||||
// Questions
|
||||
for q in &self.questions {
|
||||
buf.extend_from_slice(&encode_name(&q.name));
|
||||
buf.extend_from_slice(&q.qtype.to_u16().to_be_bytes());
|
||||
buf.extend_from_slice(&q.qclass.to_u16().to_be_bytes());
|
||||
}
|
||||
|
||||
// Resource records
|
||||
fn encode_records(buf: &mut Vec<u8>, records: &[DnsRecord]) {
|
||||
for rr in records {
|
||||
buf.extend_from_slice(&encode_name(&rr.name));
|
||||
buf.extend_from_slice(&rr.rtype.to_u16().to_be_bytes());
|
||||
if rr.rtype == QType::OPT {
|
||||
// OPT: class = UDP payload size (4096), TTL = ext rcode + flags
|
||||
buf.extend_from_slice(&rr.rclass.to_u16().to_be_bytes());
|
||||
let flags = rr.opt_flags.unwrap_or(0) as u32;
|
||||
buf.extend_from_slice(&flags.to_be_bytes());
|
||||
} else {
|
||||
buf.extend_from_slice(&rr.rclass.to_u16().to_be_bytes());
|
||||
buf.extend_from_slice(&rr.ttl.to_be_bytes());
|
||||
}
|
||||
buf.extend_from_slice(&(rr.rdata.len() as u16).to_be_bytes());
|
||||
buf.extend_from_slice(&rr.rdata);
|
||||
}
|
||||
}
|
||||
|
||||
encode_records(&mut buf, &self.answers);
|
||||
encode_records(&mut buf, &self.authorities);
|
||||
encode_records(&mut buf, &self.additionals);
|
||||
|
||||
buf
|
||||
}
|
||||
}
|
||||
|
||||
// ── RDATA encoding helpers ─────────────────────────────────────────
|
||||
|
||||
/// Encode an A record (IPv4 address string -> 4 bytes).
|
||||
pub fn encode_a(ip: &str) -> Vec<u8> {
|
||||
ip.split('.')
|
||||
.filter_map(|s| s.parse::<u8>().ok())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Encode an AAAA record (IPv6 address string -> 16 bytes).
|
||||
pub fn encode_aaaa(ip: &str) -> Vec<u8> {
|
||||
// Handle :: expansion
|
||||
let expanded = expand_ipv6(ip);
|
||||
expanded
|
||||
.split(':')
|
||||
.flat_map(|seg| {
|
||||
let val = u16::from_str_radix(seg, 16).unwrap_or(0);
|
||||
val.to_be_bytes().to_vec()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn expand_ipv6(ip: &str) -> String {
|
||||
if !ip.contains("::") {
|
||||
return ip.to_string();
|
||||
}
|
||||
let parts: Vec<&str> = ip.split("::").collect();
|
||||
let left: Vec<&str> = if parts[0].is_empty() {
|
||||
vec![]
|
||||
} else {
|
||||
parts[0].split(':').collect()
|
||||
};
|
||||
let right: Vec<&str> = if parts.len() > 1 && !parts[1].is_empty() {
|
||||
parts[1].split(':').collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let fill_count = 8 - left.len() - right.len();
|
||||
let mut result: Vec<String> = left.iter().map(|s| s.to_string()).collect();
|
||||
for _ in 0..fill_count {
|
||||
result.push("0".to_string());
|
||||
}
|
||||
result.extend(right.iter().map(|s| s.to_string()));
|
||||
result.join(":")
|
||||
}
|
||||
|
||||
/// Encode a TXT record (array of strings -> length-prefixed chunks).
|
||||
pub fn encode_txt(strings: &[String]) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
for s in strings {
|
||||
let bytes = s.as_bytes();
|
||||
// TXT strings must be <= 255 bytes each
|
||||
let len = bytes.len().min(255);
|
||||
buf.push(len as u8);
|
||||
buf.extend_from_slice(&bytes[..len]);
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode a domain name for use in RDATA (NS, CNAME, PTR, etc.).
|
||||
pub fn encode_name_rdata(name: &str) -> Vec<u8> {
|
||||
encode_name(name)
|
||||
}
|
||||
|
||||
/// Encode a SOA record RDATA.
|
||||
pub fn encode_soa(mname: &str, rname: &str, serial: u32, refresh: u32, retry: u32, expire: u32, minimum: u32) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&encode_name(mname));
|
||||
buf.extend_from_slice(&encode_name(rname));
|
||||
buf.extend_from_slice(&serial.to_be_bytes());
|
||||
buf.extend_from_slice(&refresh.to_be_bytes());
|
||||
buf.extend_from_slice(&retry.to_be_bytes());
|
||||
buf.extend_from_slice(&expire.to_be_bytes());
|
||||
buf.extend_from_slice(&minimum.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode an MX record RDATA.
|
||||
pub fn encode_mx(preference: u16, exchange: &str) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&preference.to_be_bytes());
|
||||
buf.extend_from_slice(&encode_name(exchange));
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode a SRV record RDATA.
|
||||
pub fn encode_srv(priority: u16, weight: u16, port: u16, target: &str) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&priority.to_be_bytes());
|
||||
buf.extend_from_slice(&weight.to_be_bytes());
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf.extend_from_slice(&encode_name(target));
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode a DNSKEY record RDATA.
|
||||
pub fn encode_dnskey(flags: u16, protocol: u8, algorithm: u8, public_key: &[u8]) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&flags.to_be_bytes());
|
||||
buf.push(protocol);
|
||||
buf.push(algorithm);
|
||||
buf.extend_from_slice(public_key);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode an RRSIG record RDATA.
|
||||
pub fn encode_rrsig(
|
||||
type_covered: u16,
|
||||
algorithm: u8,
|
||||
labels: u8,
|
||||
original_ttl: u32,
|
||||
expiration: u32,
|
||||
inception: u32,
|
||||
key_tag: u16,
|
||||
signers_name: &str,
|
||||
signature: &[u8],
|
||||
) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&type_covered.to_be_bytes());
|
||||
buf.push(algorithm);
|
||||
buf.push(labels);
|
||||
buf.extend_from_slice(&original_ttl.to_be_bytes());
|
||||
buf.extend_from_slice(&expiration.to_be_bytes());
|
||||
buf.extend_from_slice(&inception.to_be_bytes());
|
||||
buf.extend_from_slice(&key_tag.to_be_bytes());
|
||||
buf.extend_from_slice(&encode_name(signers_name));
|
||||
buf.extend_from_slice(signature);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a DnsRecord from high-level data.
|
||||
pub fn build_record(name: &str, rtype: QType, ttl: u32, rdata: Vec<u8>) -> DnsRecord {
|
||||
DnsRecord {
|
||||
name: name.to_string(),
|
||||
rtype,
|
||||
rclass: QClass::IN,
|
||||
ttl,
|
||||
rdata,
|
||||
opt_flags: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_encode_roundtrip() {
|
||||
// Build a simple query
|
||||
let mut query = DnsPacket::new_query(0x1234);
|
||||
query.flags = FLAG_RD;
|
||||
query.questions.push(DnsQuestion {
|
||||
name: "example.com".to_string(),
|
||||
qtype: QType::A,
|
||||
qclass: QClass::IN,
|
||||
});
|
||||
|
||||
let encoded = query.encode();
|
||||
let parsed = DnsPacket::parse(&encoded).unwrap();
|
||||
|
||||
assert_eq!(parsed.id, 0x1234);
|
||||
assert_eq!(parsed.questions.len(), 1);
|
||||
assert_eq!(parsed.questions[0].name, "example.com");
|
||||
assert_eq!(parsed.questions[0].qtype, QType::A);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_with_answer() {
|
||||
let mut query = DnsPacket::new_query(0x5678);
|
||||
query.flags = FLAG_RD;
|
||||
query.questions.push(DnsQuestion {
|
||||
name: "test.example.com".to_string(),
|
||||
qtype: QType::A,
|
||||
qclass: QClass::IN,
|
||||
});
|
||||
|
||||
let mut response = DnsPacket::new_response(&query);
|
||||
response.answers.push(build_record(
|
||||
"test.example.com",
|
||||
QType::A,
|
||||
300,
|
||||
encode_a("127.0.0.1"),
|
||||
));
|
||||
|
||||
let encoded = response.encode();
|
||||
let parsed = DnsPacket::parse(&encoded).unwrap();
|
||||
|
||||
assert_eq!(parsed.id, 0x5678);
|
||||
assert!(parsed.flags & FLAG_QR != 0); // Is a response
|
||||
assert!(parsed.flags & FLAG_AA != 0); // Authoritative
|
||||
assert_eq!(parsed.answers.len(), 1);
|
||||
assert_eq!(parsed.answers[0].rdata, vec![127, 0, 0, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_aaaa() {
|
||||
let data = encode_aaaa("::1");
|
||||
assert_eq!(data.len(), 16);
|
||||
assert_eq!(data[15], 1);
|
||||
assert!(data[..15].iter().all(|&b| b == 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_txt() {
|
||||
let data = encode_txt(&["hello".to_string(), "world".to_string()]);
|
||||
assert_eq!(data[0], 5); // length of "hello"
|
||||
assert_eq!(&data[1..6], b"hello");
|
||||
assert_eq!(data[6], 5); // length of "world"
|
||||
assert_eq!(&data[7..12], b"world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dnssec_do_bit() {
|
||||
let mut query = DnsPacket::new_query(1);
|
||||
query.questions.push(DnsQuestion {
|
||||
name: "example.com".to_string(),
|
||||
qtype: QType::A,
|
||||
qclass: QClass::IN,
|
||||
});
|
||||
|
||||
// No OPT record = no DNSSEC
|
||||
assert!(!query.is_dnssec_requested());
|
||||
|
||||
// Add OPT with DO bit
|
||||
query.additionals.push(DnsRecord {
|
||||
name: ".".to_string(),
|
||||
rtype: QType::OPT,
|
||||
rclass: QClass::from_u16(4096), // UDP payload size
|
||||
ttl: 0,
|
||||
rdata: vec![],
|
||||
opt_flags: Some(EDNS_DO_BIT),
|
||||
});
|
||||
assert!(query.is_dnssec_requested());
|
||||
}
|
||||
}
|
||||
131
rust/crates/rustdns-protocol/src/types.rs
Normal file
131
rust/crates/rustdns-protocol/src/types.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
/// DNS record types
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[repr(u16)]
|
||||
pub enum QType {
|
||||
A = 1,
|
||||
NS = 2,
|
||||
CNAME = 5,
|
||||
SOA = 6,
|
||||
PTR = 12,
|
||||
MX = 15,
|
||||
TXT = 16,
|
||||
AAAA = 28,
|
||||
SRV = 33,
|
||||
OPT = 41,
|
||||
RRSIG = 46,
|
||||
DNSKEY = 48,
|
||||
Unknown(u16),
|
||||
}
|
||||
|
||||
impl QType {
|
||||
pub fn from_u16(val: u16) -> Self {
|
||||
match val {
|
||||
1 => QType::A,
|
||||
2 => QType::NS,
|
||||
5 => QType::CNAME,
|
||||
6 => QType::SOA,
|
||||
12 => QType::PTR,
|
||||
15 => QType::MX,
|
||||
16 => QType::TXT,
|
||||
28 => QType::AAAA,
|
||||
33 => QType::SRV,
|
||||
41 => QType::OPT,
|
||||
46 => QType::RRSIG,
|
||||
48 => QType::DNSKEY,
|
||||
v => QType::Unknown(v),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u16(self) -> u16 {
|
||||
match self {
|
||||
QType::A => 1,
|
||||
QType::NS => 2,
|
||||
QType::CNAME => 5,
|
||||
QType::SOA => 6,
|
||||
QType::PTR => 12,
|
||||
QType::MX => 15,
|
||||
QType::TXT => 16,
|
||||
QType::AAAA => 28,
|
||||
QType::SRV => 33,
|
||||
QType::OPT => 41,
|
||||
QType::RRSIG => 46,
|
||||
QType::DNSKEY => 48,
|
||||
QType::Unknown(v) => v,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_uppercase().as_str() {
|
||||
"A" => QType::A,
|
||||
"NS" => QType::NS,
|
||||
"CNAME" => QType::CNAME,
|
||||
"SOA" => QType::SOA,
|
||||
"PTR" => QType::PTR,
|
||||
"MX" => QType::MX,
|
||||
"TXT" => QType::TXT,
|
||||
"AAAA" => QType::AAAA,
|
||||
"SRV" => QType::SRV,
|
||||
"OPT" => QType::OPT,
|
||||
"RRSIG" => QType::RRSIG,
|
||||
"DNSKEY" => QType::DNSKEY,
|
||||
_ => QType::Unknown(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
QType::A => "A",
|
||||
QType::NS => "NS",
|
||||
QType::CNAME => "CNAME",
|
||||
QType::SOA => "SOA",
|
||||
QType::PTR => "PTR",
|
||||
QType::MX => "MX",
|
||||
QType::TXT => "TXT",
|
||||
QType::AAAA => "AAAA",
|
||||
QType::SRV => "SRV",
|
||||
QType::OPT => "OPT",
|
||||
QType::RRSIG => "RRSIG",
|
||||
QType::DNSKEY => "DNSKEY",
|
||||
QType::Unknown(_) => "UNKNOWN",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DNS record classes
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u16)]
|
||||
pub enum QClass {
|
||||
IN = 1,
|
||||
CH = 3,
|
||||
HS = 4,
|
||||
Unknown(u16),
|
||||
}
|
||||
|
||||
impl QClass {
|
||||
pub fn from_u16(val: u16) -> Self {
|
||||
match val {
|
||||
1 => QClass::IN,
|
||||
3 => QClass::CH,
|
||||
4 => QClass::HS,
|
||||
v => QClass::Unknown(v),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u16(self) -> u16 {
|
||||
match self {
|
||||
QClass::IN => 1,
|
||||
QClass::CH => 3,
|
||||
QClass::HS => 4,
|
||||
QClass::Unknown(v) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DNS header flags
|
||||
pub const FLAG_QR: u16 = 0x8000;
|
||||
pub const FLAG_AA: u16 = 0x0400;
|
||||
pub const FLAG_RD: u16 = 0x0100;
|
||||
pub const FLAG_RA: u16 = 0x0080;
|
||||
|
||||
/// OPT record DO bit (DNSSEC OK)
|
||||
pub const EDNS_DO_BIT: u16 = 0x8000;
|
||||
17
rust/crates/rustdns-server/Cargo.toml
Normal file
17
rust/crates/rustdns-server/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "rustdns-server"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
rustdns-protocol = { path = "../rustdns-protocol" }
|
||||
rustdns-dnssec = { path = "../rustdns-dnssec" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
hyper = { version = "1", features = ["http1", "server"] }
|
||||
hyper-util = { version = "0.1", features = ["tokio"] }
|
||||
http-body-util = "0.1"
|
||||
rustls = { version = "0.23", features = ["ring"] }
|
||||
tokio-rustls = "0.26"
|
||||
rustls-pemfile = "2"
|
||||
tracing = "0.1"
|
||||
bytes = "1"
|
||||
164
rust/crates/rustdns-server/src/https.rs
Normal file
164
rust/crates/rustdns-server/src/https.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper::service::service_fn;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use rustdns_protocol::packet::DnsPacket;
|
||||
use rustls::ServerConfig;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Configuration for the HTTPS DoH server.
|
||||
pub struct HttpsServerConfig {
|
||||
pub bind_addr: SocketAddr,
|
||||
pub tls_config: Arc<ServerConfig>,
|
||||
}
|
||||
|
||||
/// An HTTPS DNS-over-HTTPS server.
|
||||
pub struct HttpsServer {
|
||||
shutdown: tokio::sync::watch::Sender<bool>,
|
||||
local_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl HttpsServer {
|
||||
/// Start the HTTPS DoH server.
|
||||
pub async fn start<F, Fut>(
|
||||
config: HttpsServerConfig,
|
||||
resolver: F,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
|
||||
where
|
||||
F: Fn(DnsPacket) -> Fut + Send + Sync + 'static,
|
||||
Fut: std::future::Future<Output = DnsPacket> + Send + 'static,
|
||||
{
|
||||
let listener = TcpListener::bind(config.bind_addr).await?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||
|
||||
let tls_acceptor = TlsAcceptor::from(config.tls_config);
|
||||
let resolver = Arc::new(resolver);
|
||||
|
||||
info!("HTTPS DoH server listening on {}", local_addr);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut shutdown_rx = shutdown_rx;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok((stream, _peer_addr)) => {
|
||||
let acceptor = tls_acceptor.clone();
|
||||
let resolver = resolver.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
match acceptor.accept(stream).await {
|
||||
Ok(tls_stream) => {
|
||||
let io = TokioIo::new(tls_stream);
|
||||
let resolver = resolver.clone();
|
||||
|
||||
let service = service_fn(move |req: Request<Incoming>| {
|
||||
let resolver = resolver.clone();
|
||||
async move {
|
||||
handle_doh_request(req, resolver).await
|
||||
}
|
||||
});
|
||||
|
||||
if let Err(e) = hyper::server::conn::http1::Builder::new()
|
||||
.serve_connection(io, service)
|
||||
.await
|
||||
{
|
||||
error!("HTTPS connection error: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("TLS accept error: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("TCP accept error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("HTTPS DoH server shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(HttpsServer {
|
||||
shutdown: shutdown_tx,
|
||||
local_addr,
|
||||
})
|
||||
}
|
||||
|
||||
/// Stop the HTTPS server.
|
||||
pub fn stop(&self) {
|
||||
let _ = self.shutdown.send(true);
|
||||
}
|
||||
|
||||
/// Get the bound local address.
|
||||
pub fn local_addr(&self) -> SocketAddr {
|
||||
self.local_addr
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_doh_request<F, Fut>(
|
||||
req: Request<Incoming>,
|
||||
resolver: Arc<F>,
|
||||
) -> Result<Response<Full<bytes::Bytes>>, hyper::Error>
|
||||
where
|
||||
F: Fn(DnsPacket) -> Fut + Send + Sync,
|
||||
Fut: std::future::Future<Output = DnsPacket> + Send,
|
||||
{
|
||||
if req.method() == hyper::Method::POST && req.uri().path() == "/dns-query" {
|
||||
let body = req.collect().await?.to_bytes();
|
||||
|
||||
match DnsPacket::parse(&body) {
|
||||
Ok(request) => {
|
||||
let response = resolver(request).await;
|
||||
let encoded = response.encode();
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/dns-message")
|
||||
.body(Full::new(bytes::Bytes::from(encoded)))
|
||||
.unwrap())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse DoH request: {}", e);
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Full::new(bytes::Bytes::from(format!("Invalid DNS message: {}", e))))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Full::new(bytes::Bytes::new()))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a rustls ServerConfig from PEM-encoded certificate and key.
|
||||
pub fn create_tls_config(cert_pem: &str, key_pem: &str) -> Result<Arc<ServerConfig>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes())
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
|
||||
.ok_or("no private key found in PEM data")?;
|
||||
|
||||
let config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?;
|
||||
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
12
rust/crates/rustdns-server/src/lib.rs
Normal file
12
rust/crates/rustdns-server/src/lib.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
pub mod udp;
|
||||
pub mod https;
|
||||
|
||||
use rustdns_protocol::packet::DnsPacket;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Trait for DNS query resolution.
|
||||
/// The resolver receives a parsed DNS packet and returns a response packet.
|
||||
pub type DnsResolverFn = Box<
|
||||
dyn Fn(DnsPacket) -> Pin<Box<dyn Future<Output = DnsPacket> + Send>> + Send + Sync,
|
||||
>;
|
||||
95
rust/crates/rustdns-server/src/udp.rs
Normal file
95
rust/crates/rustdns-server/src/udp.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use rustdns_protocol::packet::DnsPacket;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::UdpSocket;
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Configuration for the UDP DNS server.
|
||||
pub struct UdpServerConfig {
|
||||
pub bind_addr: SocketAddr,
|
||||
}
|
||||
|
||||
/// A UDP DNS server that delegates resolution to a callback.
|
||||
pub struct UdpServer {
|
||||
socket: Arc<UdpSocket>,
|
||||
shutdown: tokio::sync::watch::Sender<bool>,
|
||||
}
|
||||
|
||||
impl UdpServer {
|
||||
/// Bind and start the UDP server. The resolver function is called for each query.
|
||||
pub async fn start<F, Fut>(
|
||||
config: UdpServerConfig,
|
||||
resolver: F,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
|
||||
where
|
||||
F: Fn(DnsPacket) -> Fut + Send + Sync + 'static,
|
||||
Fut: std::future::Future<Output = DnsPacket> + Send + 'static,
|
||||
{
|
||||
let socket = UdpSocket::bind(config.bind_addr).await?;
|
||||
let socket = Arc::new(socket);
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||
|
||||
info!("UDP DNS server listening on {}", config.bind_addr);
|
||||
|
||||
let recv_socket = socket.clone();
|
||||
let resolver = Arc::new(resolver);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let mut shutdown_rx = shutdown_rx;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = recv_socket.recv_from(&mut buf) => {
|
||||
match result {
|
||||
Ok((len, src)) => {
|
||||
let data = buf[..len].to_vec();
|
||||
let sock = recv_socket.clone();
|
||||
let resolver = resolver.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
match DnsPacket::parse(&data) {
|
||||
Ok(request) => {
|
||||
let response = resolver(request).await;
|
||||
let encoded = response.encode();
|
||||
if let Err(e) = sock.send_to(&encoded, src).await {
|
||||
error!("Failed to send UDP response: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse DNS packet from {}: {}", src, e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("UDP recv error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("UDP DNS server shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(UdpServer {
|
||||
socket,
|
||||
shutdown: shutdown_tx,
|
||||
})
|
||||
}
|
||||
|
||||
/// Stop the UDP server.
|
||||
pub fn stop(&self) {
|
||||
let _ = self.shutdown.send(true);
|
||||
}
|
||||
|
||||
/// Get the bound local address.
|
||||
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
|
||||
self.socket.local_addr()
|
||||
}
|
||||
}
|
||||
26
rust/crates/rustdns/Cargo.toml
Normal file
26
rust/crates/rustdns/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[package]
|
||||
name = "rustdns"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "rustdns"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lib]
|
||||
name = "rustdns"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
rustdns-protocol = { path = "../rustdns-protocol" }
|
||||
rustdns-dnssec = { path = "../rustdns-dnssec" }
|
||||
rustdns-server = { path = "../rustdns-server" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
dashmap = "6"
|
||||
base64 = "0.22"
|
||||
rustls = { version = "0.23", features = ["ring"] }
|
||||
125
rust/crates/rustdns/src/ipc_types.rs
Normal file
125
rust/crates/rustdns/src/ipc_types.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// IPC request from TypeScript to Rust (via stdin).
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct IpcRequest {
|
||||
pub id: String,
|
||||
pub method: String,
|
||||
#[serde(default)]
|
||||
pub params: serde_json::Value,
|
||||
}
|
||||
|
||||
/// IPC response from Rust to TypeScript (via stdout).
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IpcResponse {
|
||||
pub id: String,
|
||||
pub success: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl IpcResponse {
|
||||
pub fn ok(id: String, result: serde_json::Value) -> Self {
|
||||
IpcResponse {
|
||||
id,
|
||||
success: true,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(id: String, error: String) -> Self {
|
||||
IpcResponse {
|
||||
id,
|
||||
success: false,
|
||||
result: None,
|
||||
error: Some(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// IPC event from Rust to TypeScript (unsolicited, no id).
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IpcEvent {
|
||||
pub event: String,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Configuration sent via the "start" command.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RustDnsConfig {
|
||||
pub udp_port: u16,
|
||||
pub https_port: u16,
|
||||
#[serde(default = "default_bind")]
|
||||
pub udp_bind_interface: String,
|
||||
#[serde(default = "default_bind")]
|
||||
pub https_bind_interface: String,
|
||||
#[serde(default)]
|
||||
pub https_key: String,
|
||||
#[serde(default)]
|
||||
pub https_cert: String,
|
||||
pub dnssec_zone: String,
|
||||
#[serde(default = "default_algorithm")]
|
||||
pub dnssec_algorithm: String,
|
||||
#[serde(default)]
|
||||
pub primary_nameserver: String,
|
||||
#[serde(default = "default_true")]
|
||||
pub enable_localhost_handling: bool,
|
||||
#[serde(default)]
|
||||
pub manual_udp_mode: bool,
|
||||
#[serde(default)]
|
||||
pub manual_https_mode: bool,
|
||||
}
|
||||
|
||||
fn default_bind() -> String {
|
||||
"0.0.0.0".to_string()
|
||||
}
|
||||
|
||||
fn default_algorithm() -> String {
|
||||
"ECDSA".to_string()
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// A DNS question as sent over IPC.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct IpcDnsQuestion {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub qtype: String,
|
||||
pub class: String,
|
||||
}
|
||||
|
||||
/// A DNS answer as received from TypeScript over IPC.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct IpcDnsAnswer {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub rtype: String,
|
||||
pub class: String,
|
||||
pub ttl: u32,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
/// The dnsQuery event sent from Rust to TypeScript.
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DnsQueryEvent {
|
||||
pub correlation_id: String,
|
||||
pub questions: Vec<IpcDnsQuestion>,
|
||||
pub dnssec_requested: bool,
|
||||
}
|
||||
|
||||
/// The dnsQueryResult command from TypeScript to Rust.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DnsQueryResult {
|
||||
pub correlation_id: String,
|
||||
pub answers: Vec<IpcDnsAnswer>,
|
||||
pub answered: bool,
|
||||
}
|
||||
3
rust/crates/rustdns/src/lib.rs
Normal file
3
rust/crates/rustdns/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod management;
|
||||
pub mod ipc_types;
|
||||
pub mod resolver;
|
||||
36
rust/crates/rustdns/src/main.rs
Normal file
36
rust/crates/rustdns/src/main.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use clap::Parser;
|
||||
use tracing_subscriber;
|
||||
|
||||
mod management;
|
||||
mod ipc_types;
|
||||
mod resolver;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "rustdns", about = "Rust DNS server with IPC management")]
|
||||
struct Cli {
|
||||
/// Run in management mode (IPC via stdin/stdout)
|
||||
#[arg(long)]
|
||||
management: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Install the default rustls crypto provider (ring) before any TLS operations
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Tracing writes to stderr so stdout is reserved for IPC
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
if cli.management {
|
||||
management::management_loop().await?;
|
||||
} else {
|
||||
eprintln!("rustdns: use --management flag for IPC mode");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
402
rust/crates/rustdns/src/management.rs
Normal file
402
rust/crates/rustdns/src/management.rs
Normal file
@@ -0,0 +1,402 @@
|
||||
use crate::ipc_types::*;
|
||||
use crate::resolver::DnsResolver;
|
||||
use dashmap::DashMap;
|
||||
use rustdns_dnssec::keys::DnssecAlgorithm;
|
||||
use rustdns_protocol::packet::DnsPacket;
|
||||
use rustdns_server::https::{self, HttpsServer};
|
||||
use rustdns_server::udp::{UdpServer, UdpServerConfig};
|
||||
use std::io::{self, BufRead, Write};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Pending DNS query callbacks waiting for TypeScript response.
|
||||
type PendingCallbacks = Arc<DashMap<String, oneshot::Sender<DnsQueryResult>>>;
|
||||
|
||||
/// Active server state.
|
||||
struct ServerState {
|
||||
udp_server: Option<UdpServer>,
|
||||
https_server: Option<HttpsServer>,
|
||||
resolver: Arc<DnsResolver>,
|
||||
}
|
||||
|
||||
/// Emit a JSON event on stdout.
|
||||
fn send_event(event: &str, data: serde_json::Value) {
|
||||
let evt = IpcEvent {
|
||||
event: event.to_string(),
|
||||
data,
|
||||
};
|
||||
let json = serde_json::to_string(&evt).unwrap();
|
||||
let stdout = io::stdout();
|
||||
let mut lock = stdout.lock();
|
||||
let _ = writeln!(lock, "{}", json);
|
||||
let _ = lock.flush();
|
||||
}
|
||||
|
||||
/// Send a JSON response on stdout.
|
||||
fn send_response(response: &IpcResponse) {
|
||||
let json = serde_json::to_string(response).unwrap();
|
||||
let stdout = io::stdout();
|
||||
let mut lock = stdout.lock();
|
||||
let _ = writeln!(lock, "{}", json);
|
||||
let _ = lock.flush();
|
||||
}
|
||||
|
||||
/// Main management loop — reads JSON lines from stdin, dispatches commands.
|
||||
pub async fn management_loop() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Emit ready event
|
||||
send_event("ready", serde_json::json!({
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}));
|
||||
|
||||
let pending: PendingCallbacks = Arc::new(DashMap::new());
|
||||
let mut server_state: Option<ServerState> = None;
|
||||
|
||||
// Channel for stdin commands (read in blocking thread)
|
||||
let (cmd_tx, mut cmd_rx) = mpsc::channel::<String>(256);
|
||||
|
||||
// Channel for DNS query events from the server
|
||||
let (query_tx, mut query_rx) = mpsc::channel::<(String, DnsPacket)>(256);
|
||||
|
||||
// Spawn blocking stdin reader
|
||||
std::thread::spawn(move || {
|
||||
let stdin = io::stdin();
|
||||
let reader = stdin.lock();
|
||||
for line in reader.lines() {
|
||||
match line {
|
||||
Ok(l) => {
|
||||
if cmd_tx.blocking_send(l).is_err() {
|
||||
break; // channel closed
|
||||
}
|
||||
}
|
||||
Err(_) => break, // stdin closed
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
cmd = cmd_rx.recv() => {
|
||||
match cmd {
|
||||
Some(line) => {
|
||||
let request: IpcRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Failed to parse IPC request: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let response = handle_request(
|
||||
&request,
|
||||
&mut server_state,
|
||||
&pending,
|
||||
&query_tx,
|
||||
).await;
|
||||
send_response(&response);
|
||||
}
|
||||
None => {
|
||||
// stdin closed — parent process exited
|
||||
info!("stdin closed, shutting down");
|
||||
if let Some(ref state) = server_state {
|
||||
if let Some(ref udp) = state.udp_server {
|
||||
udp.stop();
|
||||
}
|
||||
if let Some(ref https) = state.https_server {
|
||||
https.stop();
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
query = query_rx.recv() => {
|
||||
if let Some((correlation_id, packet)) = query {
|
||||
let dnssec = packet.is_dnssec_requested();
|
||||
let questions = DnsResolver::questions_to_ipc(&packet.questions);
|
||||
|
||||
send_event("dnsQuery", serde_json::to_value(&DnsQueryEvent {
|
||||
correlation_id,
|
||||
questions,
|
||||
dnssec_requested: dnssec,
|
||||
}).unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
request: &IpcRequest,
|
||||
server_state: &mut Option<ServerState>,
|
||||
pending: &PendingCallbacks,
|
||||
query_tx: &mpsc::Sender<(String, DnsPacket)>,
|
||||
) -> IpcResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"ping" => IpcResponse::ok(id, serde_json::json!({ "pong": true })),
|
||||
|
||||
"start" => {
|
||||
handle_start(id, &request.params, server_state, pending, query_tx).await
|
||||
}
|
||||
|
||||
"stop" => {
|
||||
handle_stop(id, server_state)
|
||||
}
|
||||
|
||||
"dnsQueryResult" => {
|
||||
handle_query_result(id, &request.params, pending)
|
||||
}
|
||||
|
||||
"updateCerts" => {
|
||||
// TODO: hot-swap TLS certs (requires rustls cert resolver)
|
||||
IpcResponse::ok(id, serde_json::json!({}))
|
||||
}
|
||||
|
||||
"processPacket" => {
|
||||
handle_process_packet(id, &request.params, server_state, pending, query_tx).await
|
||||
}
|
||||
|
||||
_ => IpcResponse::err(id, format!("Unknown method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_start(
|
||||
id: String,
|
||||
params: &serde_json::Value,
|
||||
server_state: &mut Option<ServerState>,
|
||||
pending: &PendingCallbacks,
|
||||
query_tx: &mpsc::Sender<(String, DnsPacket)>,
|
||||
) -> IpcResponse {
|
||||
let config: RustDnsConfig = match serde_json::from_value(params.get("config").cloned().unwrap_or_default()) {
|
||||
Ok(c) => c,
|
||||
Err(e) => return IpcResponse::err(id, format!("Invalid config: {}", e)),
|
||||
};
|
||||
|
||||
let algorithm = DnssecAlgorithm::from_str(&config.dnssec_algorithm)
|
||||
.unwrap_or(DnssecAlgorithm::EcdsaP256Sha256);
|
||||
|
||||
let resolver = Arc::new(DnsResolver::new(
|
||||
&config.dnssec_zone,
|
||||
algorithm,
|
||||
&config.primary_nameserver,
|
||||
config.enable_localhost_handling,
|
||||
));
|
||||
|
||||
// Start UDP server if not manual mode
|
||||
let udp_server = if !config.manual_udp_mode {
|
||||
let addr: SocketAddr = format!("{}:{}", config.udp_bind_interface, config.udp_port)
|
||||
.parse()
|
||||
.unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], config.udp_port)));
|
||||
|
||||
let resolver_clone = resolver.clone();
|
||||
let pending_clone = pending.clone();
|
||||
let query_tx_clone = query_tx.clone();
|
||||
|
||||
match UdpServer::start(
|
||||
UdpServerConfig { bind_addr: addr },
|
||||
move |packet| {
|
||||
let resolver = resolver_clone.clone();
|
||||
let pending = pending_clone.clone();
|
||||
let query_tx = query_tx_clone.clone();
|
||||
async move {
|
||||
resolve_with_callback(packet, &resolver, &pending, &query_tx).await
|
||||
}
|
||||
},
|
||||
).await {
|
||||
Ok(server) => {
|
||||
info!("UDP DNS server started on {}", addr);
|
||||
Some(server)
|
||||
}
|
||||
Err(e) => {
|
||||
return IpcResponse::err(id, format!("Failed to start UDP server: {}", e));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Start HTTPS server if not manual mode and certs are provided
|
||||
let https_server = if !config.manual_https_mode && !config.https_cert.is_empty() && !config.https_key.is_empty() {
|
||||
let addr: SocketAddr = format!("{}:{}", config.https_bind_interface, config.https_port)
|
||||
.parse()
|
||||
.unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], config.https_port)));
|
||||
|
||||
match https::create_tls_config(&config.https_cert, &config.https_key) {
|
||||
Ok(tls_config) => {
|
||||
let resolver_clone = resolver.clone();
|
||||
let pending_clone = pending.clone();
|
||||
let query_tx_clone = query_tx.clone();
|
||||
|
||||
match HttpsServer::start(
|
||||
https::HttpsServerConfig {
|
||||
bind_addr: addr,
|
||||
tls_config,
|
||||
},
|
||||
move |packet| {
|
||||
let resolver = resolver_clone.clone();
|
||||
let pending = pending_clone.clone();
|
||||
let query_tx = query_tx_clone.clone();
|
||||
async move {
|
||||
resolve_with_callback(packet, &resolver, &pending, &query_tx).await
|
||||
}
|
||||
},
|
||||
).await {
|
||||
Ok(server) => {
|
||||
info!("HTTPS DoH server started on {}", addr);
|
||||
Some(server)
|
||||
}
|
||||
Err(e) => {
|
||||
return IpcResponse::err(id, format!("Failed to start HTTPS server: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return IpcResponse::err(id, format!("Failed to configure TLS: {}", e));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
*server_state = Some(ServerState {
|
||||
udp_server,
|
||||
https_server,
|
||||
resolver,
|
||||
});
|
||||
|
||||
send_event("started", serde_json::json!({}));
|
||||
IpcResponse::ok(id, serde_json::json!({}))
|
||||
}
|
||||
|
||||
fn handle_stop(id: String, server_state: &mut Option<ServerState>) -> IpcResponse {
|
||||
if let Some(ref state) = server_state {
|
||||
if let Some(ref udp) = state.udp_server {
|
||||
udp.stop();
|
||||
}
|
||||
if let Some(ref https) = state.https_server {
|
||||
https.stop();
|
||||
}
|
||||
}
|
||||
*server_state = None;
|
||||
send_event("stopped", serde_json::json!({}));
|
||||
IpcResponse::ok(id, serde_json::json!({}))
|
||||
}
|
||||
|
||||
fn handle_query_result(
|
||||
id: String,
|
||||
params: &serde_json::Value,
|
||||
pending: &PendingCallbacks,
|
||||
) -> IpcResponse {
|
||||
let result: DnsQueryResult = match serde_json::from_value(params.clone()) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return IpcResponse::err(id, format!("Invalid query result: {}", e)),
|
||||
};
|
||||
|
||||
let correlation_id = result.correlation_id.clone();
|
||||
if let Some((_, sender)) = pending.remove(&correlation_id) {
|
||||
let _ = sender.send(result);
|
||||
IpcResponse::ok(id, serde_json::json!({ "resolved": true }))
|
||||
} else {
|
||||
IpcResponse::err(id, format!("No pending query for correlationId: {}", correlation_id))
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_process_packet(
|
||||
id: String,
|
||||
params: &serde_json::Value,
|
||||
server_state: &mut Option<ServerState>,
|
||||
pending: &PendingCallbacks,
|
||||
query_tx: &mpsc::Sender<(String, DnsPacket)>,
|
||||
) -> IpcResponse {
|
||||
let packet_b64 = match params.get("packet").and_then(|v| v.as_str()) {
|
||||
Some(p) => p,
|
||||
None => return IpcResponse::err(id, "Missing packet parameter".to_string()),
|
||||
};
|
||||
|
||||
let packet_data = match base64_decode(packet_b64) {
|
||||
Ok(d) => d,
|
||||
Err(e) => return IpcResponse::err(id, format!("Invalid base64: {}", e)),
|
||||
};
|
||||
|
||||
let state = match server_state {
|
||||
Some(ref s) => s,
|
||||
None => return IpcResponse::err(id, "Server not started".to_string()),
|
||||
};
|
||||
|
||||
let request = match DnsPacket::parse(&packet_data) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return IpcResponse::err(id, format!("Failed to parse packet: {}", e)),
|
||||
};
|
||||
|
||||
let response = resolve_with_callback(request, &state.resolver, pending, query_tx).await;
|
||||
let encoded = response.encode();
|
||||
|
||||
use base64::Engine;
|
||||
let response_b64 = base64::engine::general_purpose::STANDARD.encode(&encoded);
|
||||
IpcResponse::ok(id, serde_json::json!({ "packet": response_b64 }))
|
||||
}
|
||||
|
||||
/// Core resolution: try local first, then IPC callback to TypeScript.
|
||||
async fn resolve_with_callback(
|
||||
packet: DnsPacket,
|
||||
resolver: &DnsResolver,
|
||||
pending: &PendingCallbacks,
|
||||
query_tx: &mpsc::Sender<(String, DnsPacket)>,
|
||||
) -> DnsPacket {
|
||||
// Try local resolution first (localhost, DNSKEY)
|
||||
if let Some(response) = resolver.try_local_resolution(&packet) {
|
||||
return response;
|
||||
}
|
||||
|
||||
// Need IPC callback to TypeScript
|
||||
let correlation_id = format!("dns_{}", uuid_v4());
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
pending.insert(correlation_id.clone(), tx);
|
||||
|
||||
// Send the query event to the management loop for emission
|
||||
if query_tx.send((correlation_id.clone(), packet.clone())).await.is_err() {
|
||||
pending.remove(&correlation_id);
|
||||
return DnsPacket::new_response(&packet);
|
||||
}
|
||||
|
||||
// Wait for the result with a timeout
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
|
||||
Ok(Ok(result)) => {
|
||||
resolver.build_response_from_answers(&packet, &result.answers, result.answered)
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
// Sender dropped
|
||||
pending.remove(&correlation_id);
|
||||
resolver.build_response_from_answers(&packet, &[], false)
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout
|
||||
pending.remove(&correlation_id);
|
||||
resolver.build_response_from_answers(&packet, &[], false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple UUID v4 generation (no external dep needed).
|
||||
fn uuid_v4() -> String {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let random: u64 = nanos as u64 ^ (std::process::id() as u64 * 0x517cc1b727220a95);
|
||||
format!("{:016x}{:016x}", nanos as u64, random)
|
||||
}
|
||||
|
||||
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::STANDARD
|
||||
.decode(input)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
258
rust/crates/rustdns/src/resolver.rs
Normal file
258
rust/crates/rustdns/src/resolver.rs
Normal file
@@ -0,0 +1,258 @@
|
||||
use crate::ipc_types::{IpcDnsAnswer, IpcDnsQuestion};
|
||||
use rustdns_protocol::packet::*;
|
||||
use rustdns_protocol::types::QType;
|
||||
use rustdns_dnssec::keys::{DnssecAlgorithm, DnssecKeyPair};
|
||||
use rustdns_dnssec::keytag::compute_key_tag;
|
||||
use rustdns_dnssec::signing::generate_rrsig;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// DNS resolver that builds responses from IPC callback answers.
|
||||
pub struct DnsResolver {
|
||||
pub zone: String,
|
||||
pub primary_nameserver: String,
|
||||
pub enable_localhost: bool,
|
||||
pub key_pair: DnssecKeyPair,
|
||||
pub dnskey_rdata: Vec<u8>,
|
||||
pub key_tag: u16,
|
||||
}
|
||||
|
||||
impl DnsResolver {
|
||||
pub fn new(zone: &str, algorithm: DnssecAlgorithm, primary_nameserver: &str, enable_localhost: bool) -> Self {
|
||||
let key_pair = DnssecKeyPair::generate(algorithm);
|
||||
let dnskey_rdata = key_pair.dnskey_rdata();
|
||||
let key_tag = compute_key_tag(&dnskey_rdata);
|
||||
|
||||
let primary_ns = if primary_nameserver.is_empty() {
|
||||
format!("ns1.{}", zone)
|
||||
} else {
|
||||
primary_nameserver.to_string()
|
||||
};
|
||||
|
||||
DnsResolver {
|
||||
zone: zone.to_string(),
|
||||
primary_nameserver: primary_ns,
|
||||
enable_localhost,
|
||||
key_pair,
|
||||
dnskey_rdata,
|
||||
key_tag,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a query can be answered locally (localhost, DNSKEY).
|
||||
/// Returns Some(answers) if handled locally, None if it needs IPC callback.
|
||||
pub fn try_local_resolution(&self, packet: &DnsPacket) -> Option<DnsPacket> {
|
||||
let dnssec = packet.is_dnssec_requested();
|
||||
let mut response = DnsPacket::new_response(packet);
|
||||
let mut all_local = true;
|
||||
|
||||
for q in &packet.questions {
|
||||
if let Some(records) = self.try_local_question(q, dnssec) {
|
||||
for r in records {
|
||||
response.answers.push(r);
|
||||
}
|
||||
} else {
|
||||
all_local = false;
|
||||
}
|
||||
}
|
||||
|
||||
if all_local && !packet.questions.is_empty() {
|
||||
Some(response)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn try_local_question(&self, q: &DnsQuestion, dnssec: bool) -> Option<Vec<DnsRecord>> {
|
||||
let name_lower = q.name.to_lowercase();
|
||||
let name_trimmed = name_lower.strip_suffix('.').unwrap_or(&name_lower);
|
||||
|
||||
// DNSKEY queries for our zone
|
||||
if dnssec && q.qtype == QType::DNSKEY && name_trimmed == self.zone.to_lowercase() {
|
||||
let record = build_record(&q.name, QType::DNSKEY, 3600, self.dnskey_rdata.clone());
|
||||
let mut records = vec![record.clone()];
|
||||
// Sign the DNSKEY record
|
||||
let rrsig = generate_rrsig(&self.key_pair, &self.zone, &[record], &q.name, QType::DNSKEY);
|
||||
records.push(rrsig);
|
||||
return Some(records);
|
||||
}
|
||||
|
||||
// Localhost handling (RFC 6761)
|
||||
if self.enable_localhost {
|
||||
if name_trimmed == "localhost" {
|
||||
match q.qtype {
|
||||
QType::A => {
|
||||
return Some(vec![build_record(&q.name, QType::A, 0, encode_a("127.0.0.1"))]);
|
||||
}
|
||||
QType::AAAA => {
|
||||
return Some(vec![build_record(&q.name, QType::AAAA, 0, encode_aaaa("::1"))]);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse localhost
|
||||
if name_trimmed == "1.0.0.127.in-addr.arpa" && q.qtype == QType::PTR {
|
||||
return Some(vec![build_record(&q.name, QType::PTR, 0, encode_name_rdata("localhost."))]);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Build a response from IPC callback answers.
|
||||
pub fn build_response_from_answers(
|
||||
&self,
|
||||
request: &DnsPacket,
|
||||
answers: &[IpcDnsAnswer],
|
||||
answered: bool,
|
||||
) -> DnsPacket {
|
||||
let dnssec = request.is_dnssec_requested();
|
||||
let mut response = DnsPacket::new_response(request);
|
||||
|
||||
if answered && !answers.is_empty() {
|
||||
// Group answers by (name, type) for DNSSEC RRset signing
|
||||
let mut rrset_map: HashMap<(String, QType), Vec<DnsRecord>> = HashMap::new();
|
||||
|
||||
for answer in answers {
|
||||
let rtype = QType::from_str(&answer.rtype);
|
||||
let rdata = self.encode_answer_rdata(rtype, &answer.data);
|
||||
let record = build_record(&answer.name, rtype, answer.ttl, rdata);
|
||||
response.answers.push(record.clone());
|
||||
|
||||
if dnssec {
|
||||
let key = (answer.name.clone(), rtype);
|
||||
rrset_map.entry(key).or_default().push(record);
|
||||
}
|
||||
}
|
||||
|
||||
// Sign RRsets
|
||||
if dnssec {
|
||||
for ((name, rtype), rrset) in &rrset_map {
|
||||
let rrsig = generate_rrsig(&self.key_pair, &self.zone, rrset, name, *rtype);
|
||||
response.answers.push(rrsig);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No handler matched — return SOA
|
||||
for q in &request.questions {
|
||||
let soa_rdata = encode_soa(
|
||||
&self.primary_nameserver,
|
||||
&format!("hostmaster.{}", self.zone),
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as u32,
|
||||
3600,
|
||||
600,
|
||||
604800,
|
||||
86400,
|
||||
);
|
||||
let soa_record = build_record(&q.name, QType::SOA, 3600, soa_rdata);
|
||||
response.answers.push(soa_record.clone());
|
||||
|
||||
if dnssec {
|
||||
let rrsig = generate_rrsig(&self.key_pair, &self.zone, &[soa_record], &q.name, QType::SOA);
|
||||
response.answers.push(rrsig);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// Process a raw DNS packet (for manual/passthrough mode).
|
||||
/// Returns local answers or None if IPC callback is needed.
|
||||
pub fn process_packet_local(&self, data: &[u8]) -> Result<Option<Vec<u8>>, String> {
|
||||
let packet = DnsPacket::parse(data)?;
|
||||
if let Some(response) = self.try_local_resolution(&packet) {
|
||||
Ok(Some(response.encode()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_answer_rdata(&self, rtype: QType, data: &serde_json::Value) -> Vec<u8> {
|
||||
match rtype {
|
||||
QType::A => {
|
||||
if let Some(ip) = data.as_str() {
|
||||
encode_a(ip)
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
QType::AAAA => {
|
||||
if let Some(ip) = data.as_str() {
|
||||
encode_aaaa(ip)
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
QType::TXT => {
|
||||
if let Some(arr) = data.as_array() {
|
||||
let strings: Vec<String> = arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect();
|
||||
encode_txt(&strings)
|
||||
} else if let Some(s) = data.as_str() {
|
||||
encode_txt(&[s.to_string()])
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
QType::NS | QType::CNAME | QType::PTR => {
|
||||
if let Some(name) = data.as_str() {
|
||||
encode_name_rdata(name)
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
QType::MX => {
|
||||
let preference = data.get("preference").and_then(|v| v.as_u64()).unwrap_or(10) as u16;
|
||||
let exchange = data.get("exchange").and_then(|v| v.as_str()).unwrap_or("");
|
||||
encode_mx(preference, exchange)
|
||||
}
|
||||
QType::SRV => {
|
||||
let priority = data.get("priority").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let weight = data.get("weight").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let port = data.get("port").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let target = data.get("target").and_then(|v| v.as_str()).unwrap_or("");
|
||||
encode_srv(priority, weight, port, target)
|
||||
}
|
||||
QType::SOA => {
|
||||
let mname = data.get("mname").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let rname = data.get("rname").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let serial = data.get("serial").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
||||
let refresh = data.get("refresh").and_then(|v| v.as_u64()).unwrap_or(3600) as u32;
|
||||
let retry = data.get("retry").and_then(|v| v.as_u64()).unwrap_or(600) as u32;
|
||||
let expire = data.get("expire").and_then(|v| v.as_u64()).unwrap_or(604800) as u32;
|
||||
let minimum = data.get("minimum").and_then(|v| v.as_u64()).unwrap_or(86400) as u32;
|
||||
encode_soa(mname, rname, serial, refresh, retry, expire, minimum)
|
||||
}
|
||||
_ => {
|
||||
// For unknown types, try to interpret as raw base64
|
||||
if let Some(b64) = data.as_str() {
|
||||
base64_decode(b64).unwrap_or_default()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert questions to IPC format.
|
||||
pub fn questions_to_ipc(questions: &[DnsQuestion]) -> Vec<IpcDnsQuestion> {
|
||||
questions
|
||||
.iter()
|
||||
.map(|q| IpcDnsQuestion {
|
||||
name: q.name.clone(),
|
||||
qtype: q.qtype.as_str().to_string(),
|
||||
class: "IN".to_string(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::STANDARD
|
||||
.decode(input)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
Reference in New Issue
Block a user