feat(enterprise): add auth TLS and recovery hardening

This commit is contained in:
2026-04-29 22:01:43 +00:00
parent 2f3031cfc7
commit ed2c02bcf9
27 changed files with 2369 additions and 55 deletions
+3
View File
@@ -21,9 +21,12 @@ rustdb-query = { workspace = true }
rustdb-storage = { workspace = true }
rustdb-index = { workspace = true }
rustdb-txn = { workspace = true }
rustdb-auth = { workspace = true }
rustdb-commands = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tokio-rustls = { workspace = true }
rustls-pemfile = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
clap = { workspace = true }
+109 -9
View File
@@ -1,10 +1,12 @@
pub mod management;
use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use anyhow::{Context, Result};
use dashmap::DashMap;
use tokio::net::TcpListener;
#[cfg(unix)]
@@ -12,13 +14,17 @@ use tokio::net::UnixListener;
use tokio_util::codec::Framed;
use tokio_util::sync::CancellationToken;
use rustdb_config::{RustDbOptions, StorageType};
use rustdb_config::{RustDbOptions, StorageType, TlsOptions};
use rustdb_wire::{WireCodec, OP_QUERY};
use rustdb_wire::{encode_op_msg_response, encode_op_reply_response};
use rustdb_storage::{StorageAdapter, MemoryStorageAdapter, FileStorageAdapter, OpLog};
use rustdb_index::{IndexEngine, IndexOptions};
use rustdb_txn::{TransactionEngine, SessionEngine};
use rustdb_commands::{CommandRouter, CommandContext};
use rustdb_auth::AuthEngine;
use rustdb_commands::{CommandRouter, CommandContext, ConnectionState};
use tokio_rustls::rustls::{RootCertStore, ServerConfig};
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::TlsAcceptor;
/// The main RustDb server.
pub struct RustDb {
@@ -150,6 +156,8 @@ impl RustDb {
}
}
let auth = Arc::new(AuthEngine::from_options(&options.auth)?);
let ctx = Arc::new(CommandContext {
storage,
indexes,
@@ -158,6 +166,7 @@ impl RustDb {
cursors: Arc::new(DashMap::new()),
start_time: std::time::Instant::now(),
oplog: Arc::new(OpLog::new()),
auth,
});
let router = Arc::new(CommandRouter::new(ctx.clone()));
@@ -215,7 +224,12 @@ impl RustDb {
} else {
let addr = format!("{}:{}", self.options.host, self.options.port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!("RustDb listening on {}", addr);
let tls_acceptor = if self.options.tls.enabled {
Some(build_tls_acceptor(&self.options.tls)?)
} else {
None
};
tracing::info!(tls = self.options.tls.enabled, "RustDb listening on {}", addr);
let handle = tokio::spawn(async move {
loop {
@@ -226,9 +240,21 @@ impl RustDb {
Ok((stream, _addr)) => {
let _ = stream.set_nodelay(true);
let router = router.clone();
tokio::spawn(async move {
handle_connection(stream, router).await;
});
match tls_acceptor.clone() {
Some(acceptor) => {
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => handle_connection(tls_stream, router).await,
Err(e) => tracing::debug!("TLS handshake failed: {}", e),
}
});
}
None => {
tokio::spawn(async move {
handle_connection(stream, router).await;
});
}
}
}
Err(e) => {
tracing::error!("Accept error: {}", e);
@@ -275,14 +301,88 @@ impl RustDb {
}
}
fn build_tls_acceptor(options: &TlsOptions) -> Result<TlsAcceptor> {
let cert_path = options
.cert_path
.as_deref()
.context("tls.certPath is required when tls.enabled is true")?;
let key_path = options
.key_path
.as_deref()
.context("tls.keyPath is required when tls.enabled is true")?;
let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?;
let config = if options.require_client_cert {
let ca_path = options
.ca_path
.as_deref()
.context("tls.caPath is required when tls.requireClientCert is true")?;
let roots = load_root_store(ca_path)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(roots))
.build()
.context("failed to build TLS client certificate verifier")?;
ServerConfig::builder()
.with_client_cert_verifier(verifier)
.with_single_cert(certs, key)
.context("failed to build TLS server configuration")?
} else {
ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.context("failed to build TLS server configuration")?
};
Ok(TlsAcceptor::from(Arc::new(config)))
}
fn load_certs(path: &str) -> Result<Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>> {
let file = File::open(path).with_context(|| format!("failed to open TLS certificate file '{}'", path))?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader)
.collect::<std::result::Result<Vec<_>, _>>()
.with_context(|| format!("failed to parse TLS certificate file '{}'", path))?;
if certs.is_empty() {
anyhow::bail!("TLS certificate file '{}' did not contain any certificates", path);
}
Ok(certs)
}
fn load_private_key(path: &str) -> Result<tokio_rustls::rustls::pki_types::PrivateKeyDer<'static>> {
let file = File::open(path).with_context(|| format!("failed to open TLS private key file '{}'", path))?;
let mut reader = BufReader::new(file);
rustls_pemfile::private_key(&mut reader)
.with_context(|| format!("failed to parse TLS private key file '{}'", path))?
.with_context(|| format!("TLS private key file '{}' did not contain a private key", path))
}
fn load_root_store(path: &str) -> Result<RootCertStore> {
let mut roots = RootCertStore::empty();
for cert in load_certs(path)? {
roots
.add(cert)
.with_context(|| format!("failed to add TLS client CA certificate from '{}'", path))?;
}
if roots.is_empty() {
anyhow::bail!("TLS client CA file '{}' did not contain usable certificates", path);
}
Ok(roots)
}
/// Handle a single client connection using the wire protocol codec.
async fn handle_connection<S>(stream: S, router: Arc<CommandRouter>)
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
use futures_util::{SinkExt, StreamExt};
let mut framed = Framed::new(stream, WireCodec);
let mut connection = ConnectionState::new();
while let Some(result) = framed.next().await {
match result {
@@ -290,7 +390,7 @@ where
let request_id = parsed_cmd.request_id;
let op_code = parsed_cmd.op_code;
let response_doc = router.route(&parsed_cmd).await;
let response_doc = router.route(&parsed_cmd, &mut connection).await;
let response_id = next_request_id();
+3
View File
@@ -167,6 +167,9 @@ async fn handle_start(
Ok(o) => o,
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
};
if let Err(e) = options.validate() {
return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e));
}
let connection_uri = options.connection_uri();