feat(enterprise): add auth TLS and recovery hardening
This commit is contained in:
@@ -21,9 +21,12 @@ rustdb-query = { workspace = true }
|
||||
rustdb-storage = { workspace = true }
|
||||
rustdb-index = { workspace = true }
|
||||
rustdb-txn = { workspace = true }
|
||||
rustdb-auth = { workspace = true }
|
||||
rustdb-commands = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
tokio-rustls = { workspace = true }
|
||||
rustls-pemfile = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
pub mod management;
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use dashmap::DashMap;
|
||||
use tokio::net::TcpListener;
|
||||
#[cfg(unix)]
|
||||
@@ -12,13 +14,17 @@ use tokio::net::UnixListener;
|
||||
use tokio_util::codec::Framed;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use rustdb_config::{RustDbOptions, StorageType};
|
||||
use rustdb_config::{RustDbOptions, StorageType, TlsOptions};
|
||||
use rustdb_wire::{WireCodec, OP_QUERY};
|
||||
use rustdb_wire::{encode_op_msg_response, encode_op_reply_response};
|
||||
use rustdb_storage::{StorageAdapter, MemoryStorageAdapter, FileStorageAdapter, OpLog};
|
||||
use rustdb_index::{IndexEngine, IndexOptions};
|
||||
use rustdb_txn::{TransactionEngine, SessionEngine};
|
||||
use rustdb_commands::{CommandRouter, CommandContext};
|
||||
use rustdb_auth::AuthEngine;
|
||||
use rustdb_commands::{CommandRouter, CommandContext, ConnectionState};
|
||||
use tokio_rustls::rustls::{RootCertStore, ServerConfig};
|
||||
use tokio_rustls::rustls::server::WebPkiClientVerifier;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
/// The main RustDb server.
|
||||
pub struct RustDb {
|
||||
@@ -150,6 +156,8 @@ impl RustDb {
|
||||
}
|
||||
}
|
||||
|
||||
let auth = Arc::new(AuthEngine::from_options(&options.auth)?);
|
||||
|
||||
let ctx = Arc::new(CommandContext {
|
||||
storage,
|
||||
indexes,
|
||||
@@ -158,6 +166,7 @@ impl RustDb {
|
||||
cursors: Arc::new(DashMap::new()),
|
||||
start_time: std::time::Instant::now(),
|
||||
oplog: Arc::new(OpLog::new()),
|
||||
auth,
|
||||
});
|
||||
|
||||
let router = Arc::new(CommandRouter::new(ctx.clone()));
|
||||
@@ -215,7 +224,12 @@ impl RustDb {
|
||||
} else {
|
||||
let addr = format!("{}:{}", self.options.host, self.options.port);
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
tracing::info!("RustDb listening on {}", addr);
|
||||
let tls_acceptor = if self.options.tls.enabled {
|
||||
Some(build_tls_acceptor(&self.options.tls)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
tracing::info!(tls = self.options.tls.enabled, "RustDb listening on {}", addr);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
@@ -226,9 +240,21 @@ impl RustDb {
|
||||
Ok((stream, _addr)) => {
|
||||
let _ = stream.set_nodelay(true);
|
||||
let router = router.clone();
|
||||
tokio::spawn(async move {
|
||||
handle_connection(stream, router).await;
|
||||
});
|
||||
match tls_acceptor.clone() {
|
||||
Some(acceptor) => {
|
||||
tokio::spawn(async move {
|
||||
match acceptor.accept(stream).await {
|
||||
Ok(tls_stream) => handle_connection(tls_stream, router).await,
|
||||
Err(e) => tracing::debug!("TLS handshake failed: {}", e),
|
||||
}
|
||||
});
|
||||
}
|
||||
None => {
|
||||
tokio::spawn(async move {
|
||||
handle_connection(stream, router).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Accept error: {}", e);
|
||||
@@ -275,14 +301,88 @@ impl RustDb {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_tls_acceptor(options: &TlsOptions) -> Result<TlsAcceptor> {
|
||||
let cert_path = options
|
||||
.cert_path
|
||||
.as_deref()
|
||||
.context("tls.certPath is required when tls.enabled is true")?;
|
||||
let key_path = options
|
||||
.key_path
|
||||
.as_deref()
|
||||
.context("tls.keyPath is required when tls.enabled is true")?;
|
||||
|
||||
let certs = load_certs(cert_path)?;
|
||||
let key = load_private_key(key_path)?;
|
||||
|
||||
let config = if options.require_client_cert {
|
||||
let ca_path = options
|
||||
.ca_path
|
||||
.as_deref()
|
||||
.context("tls.caPath is required when tls.requireClientCert is true")?;
|
||||
let roots = load_root_store(ca_path)?;
|
||||
let verifier = WebPkiClientVerifier::builder(Arc::new(roots))
|
||||
.build()
|
||||
.context("failed to build TLS client certificate verifier")?;
|
||||
ServerConfig::builder()
|
||||
.with_client_cert_verifier(verifier)
|
||||
.with_single_cert(certs, key)
|
||||
.context("failed to build TLS server configuration")?
|
||||
} else {
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)
|
||||
.context("failed to build TLS server configuration")?
|
||||
};
|
||||
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
fn load_certs(path: &str) -> Result<Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>> {
|
||||
let file = File::open(path).with_context(|| format!("failed to open TLS certificate file '{}'", path))?;
|
||||
let mut reader = BufReader::new(file);
|
||||
let certs = rustls_pemfile::certs(&mut reader)
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.with_context(|| format!("failed to parse TLS certificate file '{}'", path))?;
|
||||
|
||||
if certs.is_empty() {
|
||||
anyhow::bail!("TLS certificate file '{}' did not contain any certificates", path);
|
||||
}
|
||||
|
||||
Ok(certs)
|
||||
}
|
||||
|
||||
fn load_private_key(path: &str) -> Result<tokio_rustls::rustls::pki_types::PrivateKeyDer<'static>> {
|
||||
let file = File::open(path).with_context(|| format!("failed to open TLS private key file '{}'", path))?;
|
||||
let mut reader = BufReader::new(file);
|
||||
rustls_pemfile::private_key(&mut reader)
|
||||
.with_context(|| format!("failed to parse TLS private key file '{}'", path))?
|
||||
.with_context(|| format!("TLS private key file '{}' did not contain a private key", path))
|
||||
}
|
||||
|
||||
fn load_root_store(path: &str) -> Result<RootCertStore> {
|
||||
let mut roots = RootCertStore::empty();
|
||||
for cert in load_certs(path)? {
|
||||
roots
|
||||
.add(cert)
|
||||
.with_context(|| format!("failed to add TLS client CA certificate from '{}'", path))?;
|
||||
}
|
||||
|
||||
if roots.is_empty() {
|
||||
anyhow::bail!("TLS client CA file '{}' did not contain usable certificates", path);
|
||||
}
|
||||
|
||||
Ok(roots)
|
||||
}
|
||||
|
||||
/// Handle a single client connection using the wire protocol codec.
|
||||
async fn handle_connection<S>(stream: S, router: Arc<CommandRouter>)
|
||||
where
|
||||
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
|
||||
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
|
||||
let mut framed = Framed::new(stream, WireCodec);
|
||||
let mut connection = ConnectionState::new();
|
||||
|
||||
while let Some(result) = framed.next().await {
|
||||
match result {
|
||||
@@ -290,7 +390,7 @@ where
|
||||
let request_id = parsed_cmd.request_id;
|
||||
let op_code = parsed_cmd.op_code;
|
||||
|
||||
let response_doc = router.route(&parsed_cmd).await;
|
||||
let response_doc = router.route(&parsed_cmd, &mut connection).await;
|
||||
|
||||
let response_id = next_request_id();
|
||||
|
||||
|
||||
@@ -167,6 +167,9 @@ async fn handle_start(
|
||||
Ok(o) => o,
|
||||
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
|
||||
};
|
||||
if let Err(e) = options.validate() {
|
||||
return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e));
|
||||
}
|
||||
|
||||
let connection_uri = options.connection_uri();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user