Files
smartdns/rust/crates/rustdns-server/src/https.rs

165 lines
5.9 KiB
Rust

use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use http_body_util::{BodyExt, Full};
use rustdns_protocol::packet::DnsPacket;
use rustls::ServerConfig;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use tracing::{error, info};
/// Settings needed to bring up the DNS-over-HTTPS listener.
pub struct HttpsServerConfig {
    /// Socket address the TCP listener is bound to.
    pub bind_addr: SocketAddr,
    /// rustls server configuration handed to the TLS acceptor for every
    /// incoming connection.
    pub tls_config: Arc<ServerConfig>,
}
/// Handle to a running DNS-over-HTTPS server.
pub struct HttpsServer {
    /// Address of the bound listener, as reported right after binding.
    local_addr: SocketAddr,
    /// Sending `true` here tells the background accept loop to exit.
    shutdown: tokio::sync::watch::Sender<bool>,
}
impl HttpsServer {
/// Start the HTTPS DoH server.
pub async fn start<F, Fut>(
config: HttpsServerConfig,
resolver: F,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
where
F: Fn(DnsPacket) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = DnsPacket> + Send + 'static,
{
let listener = TcpListener::bind(config.bind_addr).await?;
let local_addr = listener.local_addr()?;
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
let tls_acceptor = TlsAcceptor::from(config.tls_config);
let resolver = Arc::new(resolver);
info!("HTTPS DoH server listening on {}", local_addr);
tokio::spawn(async move {
let mut shutdown_rx = shutdown_rx;
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, _peer_addr)) => {
let acceptor = tls_acceptor.clone();
let resolver = resolver.clone();
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
let io = TokioIo::new(tls_stream);
let resolver = resolver.clone();
let service = service_fn(move |req: Request<Incoming>| {
let resolver = resolver.clone();
async move {
handle_doh_request(req, resolver).await
}
});
if let Err(e) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
{
error!("HTTPS connection error: {}", e);
}
}
Err(e) => {
error!("TLS accept error: {}", e);
}
}
});
}
Err(e) => {
error!("TCP accept error: {}", e);
}
}
}
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
info!("HTTPS DoH server shutting down");
break;
}
}
}
}
});
Ok(HttpsServer {
shutdown: shutdown_tx,
local_addr,
})
}
/// Stop the HTTPS server.
pub fn stop(&self) {
let _ = self.shutdown.send(true);
}
/// Get the bound local address.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
async fn handle_doh_request<F, Fut>(
req: Request<Incoming>,
resolver: Arc<F>,
) -> Result<Response<Full<bytes::Bytes>>, hyper::Error>
where
F: Fn(DnsPacket) -> Fut + Send + Sync,
Fut: std::future::Future<Output = DnsPacket> + Send,
{
if req.method() == hyper::Method::POST && req.uri().path() == "/dns-query" {
let body = req.collect().await?.to_bytes();
match DnsPacket::parse(&body) {
Ok(request) => {
let response = resolver(request).await;
let encoded = response.encode();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/dns-message")
.body(Full::new(bytes::Bytes::from(encoded)))
.unwrap())
}
Err(e) => {
error!("Failed to parse DoH request: {}", e);
Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(bytes::Bytes::from(format!("Invalid DNS message: {}", e))))
.unwrap())
}
}
} else {
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(bytes::Bytes::new()))
.unwrap())
}
}
/// Create a rustls ServerConfig from PEM-encoded certificate and key.
///
/// Every certificate found in `cert_pem` is passed to rustls as the chain, and
/// the first private key found in `key_pem` is used. Client authentication is
/// disabled.
///
/// # Errors
///
/// Fails if either PEM input cannot be read, if `cert_pem` contains no
/// certificate, if `key_pem` contains no private key, or if rustls rejects the
/// certificate/key pair.
pub fn create_tls_config(cert_pem: &str, key_pem: &str) -> Result<Arc<ServerConfig>, Box<dyn std::error::Error + Send + Sync>> {
    let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes())
        .collect::<Result<Vec<_>, _>>()?;
    // Reject an empty chain up front with a clear message instead of leaving
    // it to rustls (or to the first failing handshake).
    if certs.is_empty() {
        return Err("no certificates found in PEM data".into());
    }
    let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
        .ok_or("no private key found in PEM data")?;
    let config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)?;
    Ok(Arc::new(config))
}