165 lines
5.9 KiB
Rust
165 lines
5.9 KiB
Rust
|
|
use hyper::body::Incoming;
|
||
|
|
use hyper::{Request, Response, StatusCode};
|
||
|
|
use hyper::service::service_fn;
|
||
|
|
use hyper_util::rt::TokioIo;
|
||
|
|
use http_body_util::{BodyExt, Full};
|
||
|
|
use rustdns_protocol::packet::DnsPacket;
|
||
|
|
use rustls::ServerConfig;
|
||
|
|
use std::net::SocketAddr;
|
||
|
|
use std::sync::Arc;
|
||
|
|
use tokio::net::TcpListener;
|
||
|
|
use tokio_rustls::TlsAcceptor;
|
||
|
|
use tracing::{error, info};
|
||
|
|
|
||
|
|
/// Configuration for the HTTPS DoH server.
|
||
|
|
pub struct HttpsServerConfig {
|
||
|
|
pub bind_addr: SocketAddr,
|
||
|
|
pub tls_config: Arc<ServerConfig>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// An HTTPS DNS-over-HTTPS server.
|
||
|
|
pub struct HttpsServer {
|
||
|
|
shutdown: tokio::sync::watch::Sender<bool>,
|
||
|
|
local_addr: SocketAddr,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl HttpsServer {
|
||
|
|
/// Start the HTTPS DoH server.
|
||
|
|
pub async fn start<F, Fut>(
|
||
|
|
config: HttpsServerConfig,
|
||
|
|
resolver: F,
|
||
|
|
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
|
||
|
|
where
|
||
|
|
F: Fn(DnsPacket) -> Fut + Send + Sync + 'static,
|
||
|
|
Fut: std::future::Future<Output = DnsPacket> + Send + 'static,
|
||
|
|
{
|
||
|
|
let listener = TcpListener::bind(config.bind_addr).await?;
|
||
|
|
let local_addr = listener.local_addr()?;
|
||
|
|
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||
|
|
|
||
|
|
let tls_acceptor = TlsAcceptor::from(config.tls_config);
|
||
|
|
let resolver = Arc::new(resolver);
|
||
|
|
|
||
|
|
info!("HTTPS DoH server listening on {}", local_addr);
|
||
|
|
|
||
|
|
tokio::spawn(async move {
|
||
|
|
let mut shutdown_rx = shutdown_rx;
|
||
|
|
|
||
|
|
loop {
|
||
|
|
tokio::select! {
|
||
|
|
result = listener.accept() => {
|
||
|
|
match result {
|
||
|
|
Ok((stream, _peer_addr)) => {
|
||
|
|
let acceptor = tls_acceptor.clone();
|
||
|
|
let resolver = resolver.clone();
|
||
|
|
|
||
|
|
tokio::spawn(async move {
|
||
|
|
match acceptor.accept(stream).await {
|
||
|
|
Ok(tls_stream) => {
|
||
|
|
let io = TokioIo::new(tls_stream);
|
||
|
|
let resolver = resolver.clone();
|
||
|
|
|
||
|
|
let service = service_fn(move |req: Request<Incoming>| {
|
||
|
|
let resolver = resolver.clone();
|
||
|
|
async move {
|
||
|
|
handle_doh_request(req, resolver).await
|
||
|
|
}
|
||
|
|
});
|
||
|
|
|
||
|
|
if let Err(e) = hyper::server::conn::http1::Builder::new()
|
||
|
|
.serve_connection(io, service)
|
||
|
|
.await
|
||
|
|
{
|
||
|
|
error!("HTTPS connection error: {}", e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
error!("TLS accept error: {}", e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
});
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
error!("TCP accept error: {}", e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
_ = shutdown_rx.changed() => {
|
||
|
|
if *shutdown_rx.borrow() {
|
||
|
|
info!("HTTPS DoH server shutting down");
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
});
|
||
|
|
|
||
|
|
Ok(HttpsServer {
|
||
|
|
shutdown: shutdown_tx,
|
||
|
|
local_addr,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Stop the HTTPS server.
|
||
|
|
pub fn stop(&self) {
|
||
|
|
let _ = self.shutdown.send(true);
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get the bound local address.
|
||
|
|
pub fn local_addr(&self) -> SocketAddr {
|
||
|
|
self.local_addr
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn handle_doh_request<F, Fut>(
|
||
|
|
req: Request<Incoming>,
|
||
|
|
resolver: Arc<F>,
|
||
|
|
) -> Result<Response<Full<bytes::Bytes>>, hyper::Error>
|
||
|
|
where
|
||
|
|
F: Fn(DnsPacket) -> Fut + Send + Sync,
|
||
|
|
Fut: std::future::Future<Output = DnsPacket> + Send,
|
||
|
|
{
|
||
|
|
if req.method() == hyper::Method::POST && req.uri().path() == "/dns-query" {
|
||
|
|
let body = req.collect().await?.to_bytes();
|
||
|
|
|
||
|
|
match DnsPacket::parse(&body) {
|
||
|
|
Ok(request) => {
|
||
|
|
let response = resolver(request).await;
|
||
|
|
let encoded = response.encode();
|
||
|
|
|
||
|
|
Ok(Response::builder()
|
||
|
|
.status(StatusCode::OK)
|
||
|
|
.header("Content-Type", "application/dns-message")
|
||
|
|
.body(Full::new(bytes::Bytes::from(encoded)))
|
||
|
|
.unwrap())
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
error!("Failed to parse DoH request: {}", e);
|
||
|
|
Ok(Response::builder()
|
||
|
|
.status(StatusCode::BAD_REQUEST)
|
||
|
|
.body(Full::new(bytes::Bytes::from(format!("Invalid DNS message: {}", e))))
|
||
|
|
.unwrap())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
Ok(Response::builder()
|
||
|
|
.status(StatusCode::NOT_FOUND)
|
||
|
|
.body(Full::new(bytes::Bytes::new()))
|
||
|
|
.unwrap())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Create a rustls ServerConfig from PEM-encoded certificate and key.
|
||
|
|
pub fn create_tls_config(cert_pem: &str, key_pem: &str) -> Result<Arc<ServerConfig>, Box<dyn std::error::Error + Send + Sync>> {
|
||
|
|
let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes())
|
||
|
|
.collect::<Result<Vec<_>, _>>()?;
|
||
|
|
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
|
||
|
|
.ok_or("no private key found in PEM data")?;
|
||
|
|
|
||
|
|
let config = ServerConfig::builder()
|
||
|
|
.with_no_client_auth()
|
||
|
|
.with_single_cert(certs, key)?;
|
||
|
|
|
||
|
|
Ok(Arc::new(config))
|
||
|
|
}
|