feat(enterprise): add auth TLS and recovery hardening

This commit is contained in:
2026-04-29 22:01:43 +00:00
parent 2f3031cfc7
commit ed2c02bcf9
27 changed files with 2369 additions and 55 deletions
+1
View File
@@ -22,3 +22,4 @@ rustdb-query = { workspace = true }
rustdb-storage = { workspace = true }
rustdb-index = { workspace = true }
rustdb-txn = { workspace = true }
rustdb-auth = { workspace = true }
@@ -2,6 +2,7 @@ use std::sync::Arc;
use bson::{Bson, Document};
use dashmap::DashMap;
use rustdb_auth::{AuthEngine, AuthenticatedUser, ScramConversation};
use rustdb_index::{IndexEngine, IndexOptions};
use rustdb_storage::{OpLog, StorageAdapter};
use rustdb_txn::{SessionEngine, TransactionEngine};
@@ -22,6 +23,8 @@ pub struct CommandContext {
pub start_time: std::time::Instant,
/// Operation log for point-in-time replay.
pub oplog: Arc<OpLog>,
/// Authentication engine and user store.
pub auth: Arc<AuthEngine>,
}
impl CommandContext {
@@ -85,6 +88,43 @@ impl CommandContext {
}
}
/// Per-client connection state. Authentication is socket-scoped in MongoDB.
pub struct ConnectionState {
pub authenticated_users: Vec<AuthenticatedUser>,
pub sasl_conversations: std::collections::HashMap<i32, ScramConversation>,
next_conversation_id: i32,
}
impl ConnectionState {
pub fn new() -> Self {
Self {
authenticated_users: Vec::new(),
sasl_conversations: std::collections::HashMap::new(),
next_conversation_id: 1,
}
}
pub fn is_authenticated(&self) -> bool {
!self.authenticated_users.is_empty()
}
pub fn next_conversation_id(&mut self) -> i32 {
let id = self.next_conversation_id;
self.next_conversation_id += 1;
id
}
pub fn authenticate(&mut self, user: AuthenticatedUser) {
self.authenticated_users.push(user);
}
}
impl Default for ConnectionState {
fn default() -> Self {
Self::new()
}
}
/// State of an open cursor from a find or aggregate command.
pub struct CursorState {
/// Documents remaining to be returned.
+12
View File
@@ -30,6 +30,15 @@ pub enum CommandError {
#[error("immutable field: {0}")]
ImmutableField(String),
#[error("unauthorized: {0}")]
Unauthorized(String),
#[error("authentication failed")]
AuthenticationFailed,
#[error("illegal operation: {0}")]
IllegalOperation(String),
#[error("internal error: {0}")]
InternalError(String),
}
@@ -47,6 +56,9 @@ impl CommandError {
CommandError::NamespaceExists(_) => (48, "NamespaceExists"),
CommandError::DuplicateKey(_) => (11000, "DuplicateKey"),
CommandError::ImmutableField(_) => (66, "ImmutableField"),
CommandError::Unauthorized(_) => (13, "Unauthorized"),
CommandError::AuthenticationFailed => (18, "AuthenticationFailed"),
CommandError::IllegalOperation(_) => (20, "IllegalOperation"),
CommandError::InternalError(_) => (1, "InternalError"),
};
@@ -98,6 +98,18 @@ pub async fn handle(
"ok": 1.0,
}),
"createUser" => handle_create_user(cmd, db, ctx).await,
"updateUser" => handle_update_user(cmd, db, ctx).await,
"dropUser" => handle_drop_user(cmd, db, ctx).await,
"usersInfo" => handle_users_info(cmd, db, ctx).await,
"grantRolesToUser" => handle_grant_roles_to_user(cmd, db, ctx).await,
"revokeRolesFromUser" => handle_revoke_roles_from_user(cmd, db, ctx).await,
"listDatabases" => handle_list_databases(cmd, ctx).await,
"listCollections" => handle_list_collections(cmd, db, ctx).await,
@@ -144,15 +156,9 @@ pub async fn handle(
Ok(doc! { "ok": 1.0 })
}
"commitTransaction" => {
// Stub: acknowledge.
Ok(doc! { "ok": 1.0 })
}
"abortTransaction" => {
// Stub: acknowledge.
Ok(doc! { "ok": 1.0 })
}
"commitTransaction" | "abortTransaction" => Err(CommandError::IllegalOperation(
"Transaction numbers are only allowed on a replica set member or mongos".into(),
)),
// Auth stubs - accept silently.
"saslStart" => Ok(doc! {
@@ -189,6 +195,166 @@ pub async fn handle(
}
}
async fn handle_create_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("createUser")
.map_err(|_| CommandError::InvalidArgument("missing 'createUser' field".into()))?;
let password = cmd
.get_str("pwd")
.map_err(|_| CommandError::InvalidArgument("missing 'pwd' field".into()))?;
let roles = parse_roles(cmd, db, "roles")?;
ctx.auth
.create_user(db, username, password, roles)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
async fn handle_update_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("updateUser")
.map_err(|_| CommandError::InvalidArgument("missing 'updateUser' field".into()))?;
let password = cmd.get_str("pwd").ok();
let roles = if cmd.contains_key("roles") {
Some(parse_roles(cmd, db, "roles")?)
} else {
None
};
ctx.auth
.update_user(db, username, password, roles)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
async fn handle_drop_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("dropUser")
.map_err(|_| CommandError::InvalidArgument("missing 'dropUser' field".into()))?;
ctx.auth
.drop_user(db, username)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
async fn handle_users_info(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = match cmd.get("usersInfo") {
Some(Bson::String(name)) => Some(name.as_str()),
Some(Bson::Document(user_doc)) => user_doc.get_str("user").ok(),
_ => None,
};
let users = ctx.auth.users_info(db, username);
let user_docs: Vec<Bson> = users
.into_iter()
.map(|user| {
let roles: Vec<Bson> = user
.roles
.iter()
.map(|role| Bson::Document(role_to_document(&user.database, role)))
.collect();
Bson::Document(doc! {
"user": user.username,
"db": user.database,
"roles": roles,
"mechanisms": ["SCRAM-SHA-256"],
})
})
.collect();
Ok(doc! { "users": user_docs, "ok": 1.0 })
}
async fn handle_grant_roles_to_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("grantRolesToUser")
.map_err(|_| CommandError::InvalidArgument("missing 'grantRolesToUser' field".into()))?;
let roles = parse_roles(cmd, db, "roles")?;
ctx.auth
.grant_roles(db, username, roles)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
async fn handle_revoke_roles_from_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("revokeRolesFromUser")
.map_err(|_| CommandError::InvalidArgument("missing 'revokeRolesFromUser' field".into()))?;
let roles = parse_roles(cmd, db, "roles")?;
ctx.auth
.revoke_roles(db, username, roles)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
fn parse_roles(cmd: &Document, db: &str, key: &str) -> CommandResult<Vec<String>> {
let role_values = cmd
.get_array(key)
.map_err(|_| CommandError::InvalidArgument(format!("missing '{key}' array")))?;
let mut roles = Vec::with_capacity(role_values.len());
for role_value in role_values {
match role_value {
Bson::String(role) => roles.push(role.clone()),
Bson::Document(role_doc) => {
let role = role_doc
.get_str("role")
.map_err(|_| CommandError::InvalidArgument("role document missing 'role'".into()))?;
let role_db = role_doc.get_str("db").unwrap_or(db);
if role_db == db {
roles.push(role.to_string());
} else {
roles.push(format!("{role_db}.{role}"));
}
}
_ => return Err(CommandError::InvalidArgument("roles must be strings or documents".into())),
}
}
Ok(roles)
}
fn role_to_document(default_db: &str, role: &str) -> Document {
if let Some((role_db, role_name)) = role.split_once('.') {
doc! { "role": role_name, "db": role_db }
} else {
doc! { "role": role, "db": default_db }
}
}
fn auth_error_to_command_error(error: rustdb_auth::AuthError) -> CommandError {
match error {
rustdb_auth::AuthError::UserAlreadyExists(message) => CommandError::DuplicateKey(message),
rustdb_auth::AuthError::UserNotFound(message) => CommandError::NamespaceNotFound(message),
rustdb_auth::AuthError::Persistence(message) => CommandError::InternalError(message),
rustdb_auth::AuthError::AuthenticationFailed => CommandError::AuthenticationFailed,
rustdb_auth::AuthError::InvalidPayload(message) => CommandError::InvalidArgument(message),
rustdb_auth::AuthError::UnsupportedMechanism(message) => CommandError::InvalidArgument(message),
rustdb_auth::AuthError::Disabled => CommandError::Unauthorized("authentication is disabled".into()),
rustdb_auth::AuthError::UnknownConversation => {
CommandError::InvalidArgument("unknown SASL conversation".into())
}
}
}
/// Handle `listDatabases` command.
async fn handle_list_databases(
cmd: &Document,
@@ -0,0 +1,87 @@
use bson::{doc, Binary, Bson, Document};
use crate::context::{CommandContext, ConnectionState};
use crate::error::{CommandError, CommandResult};
pub async fn handle_sasl_start(
cmd: &Document,
db: &str,
ctx: &CommandContext,
connection: &mut ConnectionState,
) -> CommandResult<Document> {
let mechanism = cmd
.get_str("mechanism")
.map_err(|_| CommandError::InvalidArgument("missing SASL mechanism".into()))?;
if mechanism != "SCRAM-SHA-256" {
return Err(CommandError::InvalidArgument(format!(
"unsupported SASL mechanism: {mechanism}"
)));
}
let payload = payload_bytes(cmd)?;
let result = ctx
.auth
.start_scram_sha256(db, &payload)
.map_err(map_auth_error)?;
let conversation_id = connection.next_conversation_id();
connection
.sasl_conversations
.insert(conversation_id, result.conversation);
Ok(doc! {
"conversationId": conversation_id,
"done": false,
"payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload },
"ok": 1.0,
})
}
pub async fn handle_sasl_continue(
cmd: &Document,
ctx: &CommandContext,
connection: &mut ConnectionState,
) -> CommandResult<Document> {
let conversation_id = cmd
.get_i32("conversationId")
.map_err(|_| CommandError::InvalidArgument("missing SASL conversationId".into()))?;
let payload = payload_bytes(cmd)?;
let conversation = connection
.sasl_conversations
.remove(&conversation_id)
.ok_or_else(|| CommandError::InvalidArgument("unknown SASL conversation".into()))?;
let result = ctx
.auth
.continue_scram_sha256(conversation, &payload)
.map_err(map_auth_error)?;
connection.authenticate(result.user);
Ok(doc! {
"conversationId": conversation_id,
"done": true,
"payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload },
"ok": 1.0,
})
}
fn payload_bytes(cmd: &Document) -> CommandResult<Vec<u8>> {
match cmd.get("payload") {
Some(Bson::Binary(binary)) => Ok(binary.bytes.clone()),
Some(Bson::String(value)) => Ok(value.as_bytes().to_vec()),
_ => Err(CommandError::InvalidArgument("missing SASL payload".into())),
}
}
fn map_auth_error(error: rustdb_auth::AuthError) -> CommandError {
match error {
rustdb_auth::AuthError::InvalidPayload(message) => CommandError::InvalidArgument(message),
rustdb_auth::AuthError::UnsupportedMechanism(message) => CommandError::InvalidArgument(message),
rustdb_auth::AuthError::Disabled => CommandError::Unauthorized("authentication is disabled".into()),
rustdb_auth::AuthError::UnknownConversation => {
CommandError::InvalidArgument("unknown SASL conversation".into())
}
rustdb_auth::AuthError::AuthenticationFailed => CommandError::AuthenticationFailed,
rustdb_auth::AuthError::UserAlreadyExists(message) => CommandError::DuplicateKey(message),
rustdb_auth::AuthError::UserNotFound(message) => CommandError::NamespaceNotFound(message),
rustdb_auth::AuthError::Persistence(message) => CommandError::InternalError(message),
}
}
@@ -1,4 +1,4 @@
use bson::{doc, Document};
use bson::{doc, Bson, Document};
use crate::context::CommandContext;
use crate::error::CommandResult;
@@ -7,12 +7,13 @@ use crate::error::CommandResult;
///
/// Returns server capabilities matching wire protocol expectations.
pub async fn handle(
_cmd: &Document,
cmd: &Document,
_db: &str,
_ctx: &CommandContext,
ctx: &CommandContext,
) -> CommandResult<Document> {
Ok(doc! {
let mut response = doc! {
"ismaster": true,
"helloOk": true,
"isWritablePrimary": true,
"maxBsonObjectSize": 16_777_216_i32,
"maxMessageSizeBytes": 48_000_000_i32,
@@ -24,5 +25,19 @@ pub async fn handle(
"maxWireVersion": 21_i32,
"readOnly": false,
"ok": 1.0,
})
};
if ctx.auth.enabled() {
if let Ok(namespace_user) = cmd.get_str("saslSupportedMechs") {
let mechanisms: Vec<Bson> = ctx
.auth
.supported_mechanisms(namespace_user)
.into_iter()
.map(Bson::String)
.collect();
response.insert("saslSupportedMechs", Bson::Array(mechanisms));
}
}
Ok(response)
}
@@ -1,5 +1,6 @@
pub mod admin_handler;
pub mod aggregate_handler;
pub mod auth_handler;
pub mod delete_handler;
pub mod find_handler;
pub mod hello_handler;
+1 -1
View File
@@ -3,6 +3,6 @@ pub mod error;
pub mod handlers;
mod router;
pub use context::{CommandContext, CursorState};
pub use context::{CommandContext, ConnectionState, CursorState};
pub use error::{CommandError, CommandResult};
pub use router::CommandRouter;
+110 -4
View File
@@ -1,11 +1,12 @@
use std::sync::Arc;
use bson::Document;
use bson::{Bson, Document};
use tracing::{debug, warn};
use rustdb_wire::ParsedCommand;
use rustdb_auth::AuthAction;
use crate::context::CommandContext;
use crate::context::{CommandContext, ConnectionState};
use crate::error::CommandError;
use crate::handlers;
@@ -21,12 +22,46 @@ impl CommandRouter {
}
/// Route a parsed command to the appropriate handler, returning a BSON response document.
pub async fn route(&self, cmd: &ParsedCommand) -> Document {
pub async fn route(&self, cmd: &ParsedCommand, connection: &mut ConnectionState) -> Document {
let db = &cmd.database;
let command_name = cmd.command_name.as_str();
debug!(command = %command_name, database = %db, "routing command");
if self.ctx.auth.enabled()
&& !connection.is_authenticated()
&& !allows_unauthenticated(command_name)
{
return CommandError::Unauthorized(format!(
"command '{}' requires authentication",
command_name,
))
.to_error_doc();
}
if self.ctx.auth.enabled() && connection.is_authenticated() {
if let Some(action) = required_action(command_name, &cmd.command) {
if !self
.ctx
.auth
.is_authorized(&connection.authenticated_users, db, action)
{
return CommandError::Unauthorized(format!(
"command '{}' is not authorized for database '{}'",
command_name, db,
))
.to_error_doc();
}
}
}
if transaction_command_unsupported(command_name, &cmd.command) {
return CommandError::IllegalOperation(
"Transaction numbers are only allowed on a replica set member or mongos".into(),
)
.to_error_doc();
}
// Extract session id if present, and touch the session.
if let Some(lsid) = cmd.command.get("lsid") {
if let Some(session_id) = rustdb_txn::SessionEngine::extract_session_id(lsid) {
@@ -40,6 +75,14 @@ impl CommandRouter {
handlers::hello_handler::handle(&cmd.command, db, &self.ctx).await
}
// -- authentication --
"saslStart" => {
handlers::auth_handler::handle_sasl_start(&cmd.command, db, &self.ctx, connection).await
}
"saslContinue" => {
handlers::auth_handler::handle_sasl_continue(&cmd.command, &self.ctx, connection).await
}
// -- query commands --
"find" => {
handlers::find_handler::handle(&cmd.command, db, &self.ctx).await
@@ -88,7 +131,9 @@ impl CommandRouter {
| "dbStats" | "collStats" | "validate" | "explain"
| "startSession" | "endSessions" | "killSessions"
| "commitTransaction" | "abortTransaction"
| "saslStart" | "saslContinue" | "authenticate" | "logout"
| "authenticate" | "logout"
| "createUser" | "updateUser" | "dropUser" | "usersInfo"
| "grantRolesToUser" | "revokeRolesFromUser"
| "currentOp" | "killOp" | "top" | "profile"
| "compact" | "reIndex" | "fsync" | "connPoolSync" => {
handlers::admin_handler::handle(&cmd.command, db, &self.ctx, command_name).await
@@ -107,3 +152,64 @@ impl CommandRouter {
}
}
}
fn allows_unauthenticated(command_name: &str) -> bool {
matches!(
command_name,
"hello" | "ismaster" | "isMaster" | "saslStart" | "saslContinue" | "getnonce"
)
}
fn required_action(command_name: &str, command: &Document) -> Option<AuthAction> {
match command_name {
"hello" | "ismaster" | "isMaster" | "saslStart" | "saslContinue" | "getnonce" => None,
"ping" | "buildInfo" | "buildinfo" | "hostInfo" | "whatsmyuri" | "getLog"
| "getCmdLineOpts" | "getParameter" | "getFreeMonitoringStatus" | "setFreeMonitoring"
| "getShardMap" | "shardingState" | "atlasVersion" | "connectionStatus"
| "startSession" | "endSessions" | "killSessions" | "authenticate" | "logout" => None,
"find" | "getMore" | "killCursors" | "count" | "distinct" | "listIndexes"
| "listCollections" | "collStats" | "dbStats" | "validate" | "explain" => {
Some(AuthAction::Read)
}
"aggregate" => Some(if aggregate_writes(command) {
AuthAction::Write
} else {
AuthAction::Read
}),
"insert" | "update" | "findAndModify" | "delete" | "commitTransaction"
| "abortTransaction" => Some(AuthAction::Write),
"createIndexes" | "dropIndexes" | "create" | "drop" | "dropDatabase"
| "renameCollection" | "compact" | "reIndex" | "fsync" | "profile" => {
Some(AuthAction::DbAdmin)
}
"createUser" | "updateUser" | "dropUser" | "usersInfo" | "grantRolesToUser"
| "revokeRolesFromUser" => Some(AuthAction::UserAdmin),
"serverStatus" | "listDatabases" | "currentOp" | "killOp" | "top" => {
Some(AuthAction::ClusterMonitor)
}
_ => None,
}
}
fn aggregate_writes(command: &Document) -> bool {
let Ok(pipeline) = command.get_array("pipeline") else {
return false;
};
pipeline.last().and_then(|stage| match stage {
Bson::Document(doc) => Some(doc.contains_key("$out") || doc.contains_key("$merge")),
_ => None,
}).unwrap_or(false)
}
fn transaction_command_unsupported(command_name: &str, command: &Document) -> bool {
matches!(command_name, "commitTransaction" | "abortTransaction")
|| matches!(command.get("startTransaction"), Some(Bson::Boolean(true)))
|| matches!(command.get("autocommit"), Some(Bson::Boolean(false)))
}