use hyper::body::Incoming; use hyper::{Request, Response, StatusCode}; use hyper::service::service_fn; use hyper_util::rt::TokioIo; use http_body_util::{BodyExt, Full}; use rustdns_protocol::packet::DnsPacket; use rustls::ServerConfig; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::TlsAcceptor; use tracing::{error, info}; /// Configuration for the HTTPS DoH server. pub struct HttpsServerConfig { pub bind_addr: SocketAddr, pub tls_config: Arc, } /// An HTTPS DNS-over-HTTPS server. pub struct HttpsServer { shutdown: tokio::sync::watch::Sender, local_addr: SocketAddr, } impl HttpsServer { /// Start the HTTPS DoH server. pub async fn start( config: HttpsServerConfig, resolver: F, ) -> Result> where F: Fn(DnsPacket) -> Fut + Send + Sync + 'static, Fut: std::future::Future + Send + 'static, { let listener = TcpListener::bind(config.bind_addr).await?; let local_addr = listener.local_addr()?; let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); let tls_acceptor = TlsAcceptor::from(config.tls_config); let resolver = Arc::new(resolver); info!("HTTPS DoH server listening on {}", local_addr); tokio::spawn(async move { let mut shutdown_rx = shutdown_rx; loop { tokio::select! { result = listener.accept() => { match result { Ok((stream, _peer_addr)) => { let acceptor = tls_acceptor.clone(); let resolver = resolver.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { Ok(tls_stream) => { let io = TokioIo::new(tls_stream); let resolver = resolver.clone(); let service = service_fn(move |req: Request| { let resolver = resolver.clone(); async move { handle_doh_request(req, resolver).await } }); if let Err(e) = hyper::server::conn::http1::Builder::new() .serve_connection(io, service) .await { error!("HTTPS connection error: {}", e); } } Err(e) => { error!("TLS accept error: {}", e); } } }); } Err(e) => { error!("TCP accept error: {}", e); } } } _ = shutdown_rx.changed() => { if *shutdown_rx.borrow() { info!("HTTPS DoH server shutting down"); break; } } } } }); Ok(HttpsServer { shutdown: shutdown_tx, local_addr, }) } /// Stop the HTTPS server. pub fn stop(&self) { let _ = self.shutdown.send(true); } /// Get the bound local address. pub fn local_addr(&self) -> SocketAddr { self.local_addr } } async fn handle_doh_request( req: Request, resolver: Arc, ) -> Result>, hyper::Error> where F: Fn(DnsPacket) -> Fut + Send + Sync, Fut: std::future::Future + Send, { if req.method() == hyper::Method::POST && req.uri().path() == "/dns-query" { let body = req.collect().await?.to_bytes(); match DnsPacket::parse(&body) { Ok(request) => { let response = resolver(request).await; let encoded = response.encode(); Ok(Response::builder() .status(StatusCode::OK) .header("Content-Type", "application/dns-message") .body(Full::new(bytes::Bytes::from(encoded))) .unwrap()) } Err(e) => { error!("Failed to parse DoH request: {}", e); Ok(Response::builder() .status(StatusCode::BAD_REQUEST) .body(Full::new(bytes::Bytes::from(format!("Invalid DNS message: {}", e)))) .unwrap()) } } } else { Ok(Response::builder() .status(StatusCode::NOT_FOUND) .body(Full::new(bytes::Bytes::new())) .unwrap()) } } /// Create a rustls ServerConfig from PEM-encoded certificate and key. pub fn create_tls_config(cert_pem: &str, key_pem: &str) -> Result, Box> { let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes()) .collect::, _>>()?; let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())? .ok_or("no private key found in PEM data")?; let config = ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key)?; Ok(Arc::new(config)) }