feat(enterprise): add auth TLS and recovery hardening

This commit is contained in:
2026-04-29 22:01:43 +00:00
parent 2f3031cfc7
commit ed2c02bcf9
27 changed files with 2369 additions and 55 deletions
+20
View File
@@ -0,0 +1,20 @@
[package]
name = "rustdb-auth"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Authentication primitives for RustDb"
[dependencies]
base64 = { workspace = true }
bson = { workspace = true }
hmac = { workspace = true }
pbkdf2 = { workspace = true }
rand = { workspace = true }
rustdb-config = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sha2 = { workspace = true }
subtle = { workspace = true }
thiserror = { workspace = true }
+565
View File
@@ -0,0 +1,565 @@
use std::collections::HashMap;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::RwLock;
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
use hmac::{Hmac, Mac};
use pbkdf2::pbkdf2_hmac;
use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use rustdb_config::{AuthOptions, AuthUserOptions};
type HmacSha256 = Hmac<Sha256>;
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("authentication is disabled")]
Disabled,
#[error("unsupported authentication mechanism: {0}")]
UnsupportedMechanism(String),
#[error("invalid SCRAM payload: {0}")]
InvalidPayload(String),
#[error("authentication failed")]
AuthenticationFailed,
#[error("unknown SASL conversation")]
UnknownConversation,
#[error("user already exists: {0}")]
UserAlreadyExists(String),
#[error("user not found: {0}")]
UserNotFound(String),
#[error("auth metadata persistence failed: {0}")]
Persistence(String),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthAction {
Read,
Write,
DbAdmin,
UserAdmin,
ClusterMonitor,
}
#[derive(Debug, Clone)]
pub struct AuthenticatedUser {
pub username: String,
pub database: String,
pub roles: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ScramCredential {
salt: Vec<u8>,
iterations: u32,
stored_key: Vec<u8>,
server_key: Vec<u8>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AuthUser {
username: String,
database: String,
roles: Vec<String>,
scram_sha256: ScramCredential,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct PersistedAuthState {
users: Vec<AuthUser>,
}
#[derive(Debug, Clone)]
pub struct ScramConversation {
user: AuthenticatedUser,
client_first_bare: String,
server_first: String,
nonce: String,
stored_key: Vec<u8>,
server_key: Vec<u8>,
}
#[derive(Debug, Clone)]
pub struct ScramStartResult {
pub payload: Vec<u8>,
pub conversation: ScramConversation,
}
#[derive(Debug, Clone)]
pub struct ScramContinueResult {
pub payload: Vec<u8>,
pub user: AuthenticatedUser,
}
#[derive(Debug)]
pub struct AuthEngine {
enabled: bool,
users: RwLock<HashMap<String, AuthUser>>,
users_path: Option<PathBuf>,
scram_iterations: u32,
}
impl AuthEngine {
pub fn from_options(options: &AuthOptions) -> Result<Self, AuthError> {
let users_path = options.users_path.as_ref().map(PathBuf::from);
let mut users = if let Some(ref path) = users_path {
load_users(path)?
} else {
HashMap::new()
};
let mut changed = false;
for user_options in &options.users {
let key = user_key(&user_options.database, &user_options.username);
if !users.contains_key(&key) {
let user = AuthUser::from_options(user_options, options.scram_iterations);
users.insert(key, user);
changed = true;
}
}
if changed {
if let Some(ref path) = users_path {
persist_users(path, &users)?;
}
}
Ok(Self {
enabled: options.enabled,
users: RwLock::new(users),
users_path,
scram_iterations: options.scram_iterations,
})
}
pub fn disabled() -> Self {
Self {
enabled: false,
users: RwLock::new(HashMap::new()),
users_path: None,
scram_iterations: 15000,
}
}
pub fn enabled(&self) -> bool {
self.enabled
}
pub fn supported_mechanisms(&self, namespace_user: &str) -> Vec<String> {
let Some((database, username)) = namespace_user.split_once('.') else {
return Vec::new();
};
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
if users.contains_key(&user_key(database, username)) {
vec![SCRAM_SHA_256.to_string()]
} else {
Vec::new()
}
}
pub fn is_authorized(
&self,
authenticated_users: &[AuthenticatedUser],
target_db: &str,
action: AuthAction,
) -> bool {
authenticated_users
.iter()
.any(|user| user.roles.iter().any(|role| role_allows(role, user, target_db, action)))
}
pub fn create_user(
&self,
database: &str,
username: &str,
password: &str,
roles: Vec<String>,
) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
if users.contains_key(&key) {
return Err(AuthError::UserAlreadyExists(format!("{database}.{username}")));
}
let options = AuthUserOptions {
username: username.to_string(),
password: password.to_string(),
database: database.to_string(),
roles,
};
users.insert(key, AuthUser::from_options(&options, self.scram_iterations));
self.persist_locked(&users)
}
pub fn drop_user(&self, database: &str, username: &str) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
if users.remove(&key).is_none() {
return Err(AuthError::UserNotFound(format!("{database}.{username}")));
}
self.persist_locked(&users)
}
pub fn update_user(
&self,
database: &str,
username: &str,
password: Option<&str>,
roles: Option<Vec<String>>,
) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
let user = users
.get_mut(&key)
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
if let Some(new_roles) = roles {
user.roles = new_roles;
}
if let Some(new_password) = password {
let options = AuthUserOptions {
username: username.to_string(),
password: new_password.to_string(),
database: database.to_string(),
roles: user.roles.clone(),
};
user.scram_sha256 = AuthUser::from_options(&options, self.scram_iterations).scram_sha256;
}
self.persist_locked(&users)
}
pub fn grant_roles(
&self,
database: &str,
username: &str,
roles: Vec<String>,
) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
let user = users
.get_mut(&key)
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
for role in roles {
if !user.roles.contains(&role) {
user.roles.push(role);
}
}
self.persist_locked(&users)
}
pub fn revoke_roles(
&self,
database: &str,
username: &str,
roles: Vec<String>,
) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
let user = users
.get_mut(&key)
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
user.roles.retain(|role| !roles.contains(role));
self.persist_locked(&users)
}
pub fn users_info(&self, database: &str, username: Option<&str>) -> Vec<AuthenticatedUser> {
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
users
.values()
.filter(|user| user.database == database)
.filter(|user| username.map(|name| user.username == name).unwrap_or(true))
.map(AuthUser::to_authenticated_user)
.collect()
}
pub fn start_scram_sha256(
&self,
database: &str,
payload: &[u8],
) -> Result<ScramStartResult, AuthError> {
if !self.enabled {
return Err(AuthError::Disabled);
}
let message = std::str::from_utf8(payload)
.map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?;
let client_first_bare = message
.strip_prefix("n,,")
.ok_or_else(|| AuthError::InvalidPayload("expected SCRAM gs2 header 'n,,'".to_string()))?;
let attrs = parse_scram_attrs(client_first_bare);
let raw_username = attrs
.get("n")
.ok_or_else(|| AuthError::InvalidPayload("missing username".to_string()))?;
let username = decode_scram_name(raw_username);
let client_nonce = attrs
.get("r")
.ok_or_else(|| AuthError::InvalidPayload("missing client nonce".to_string()))?;
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
let user = users
.get(&user_key(database, &username))
.ok_or(AuthError::AuthenticationFailed)?;
let nonce = format!("{}{}", client_nonce, secure_base64(18));
let server_first = format!(
"r={},s={},i={}",
nonce,
BASE64_STANDARD.encode(&user.scram_sha256.salt),
user.scram_sha256.iterations,
);
Ok(ScramStartResult {
payload: server_first.as_bytes().to_vec(),
conversation: ScramConversation {
user: user.to_authenticated_user(),
client_first_bare: client_first_bare.to_string(),
server_first: server_first.clone(),
nonce,
stored_key: user.scram_sha256.stored_key.clone(),
server_key: user.scram_sha256.server_key.clone(),
},
})
}
pub fn continue_scram_sha256(
&self,
conversation: ScramConversation,
payload: &[u8],
) -> Result<ScramContinueResult, AuthError> {
let message = std::str::from_utf8(payload)
.map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?;
let proof_marker = ",p=";
let proof_pos = message
.rfind(proof_marker)
.ok_or_else(|| AuthError::InvalidPayload("missing client proof".to_string()))?;
let client_final_without_proof = &message[..proof_pos];
let proof_b64 = &message[proof_pos + proof_marker.len()..];
let attrs = parse_scram_attrs(client_final_without_proof);
let nonce = attrs
.get("r")
.ok_or_else(|| AuthError::InvalidPayload("missing nonce".to_string()))?;
if nonce != &conversation.nonce {
return Err(AuthError::AuthenticationFailed);
}
let client_proof = BASE64_STANDARD
.decode(proof_b64.as_bytes())
.map_err(|_| AuthError::InvalidPayload("invalid client proof encoding".to_string()))?;
if client_proof.len() != 32 || conversation.stored_key.len() != 32 {
return Err(AuthError::AuthenticationFailed);
}
let auth_message = format!(
"{},{},{}",
conversation.client_first_bare,
conversation.server_first,
client_final_without_proof,
);
let client_signature = hmac_sha256(&conversation.stored_key, auth_message.as_bytes());
let client_key: Vec<u8> = client_proof
.iter()
.zip(client_signature.iter())
.map(|(proof_byte, signature_byte)| proof_byte ^ signature_byte)
.collect();
let computed_stored_key = Sha256::digest(&client_key).to_vec();
if computed_stored_key.ct_eq(&conversation.stored_key).unwrap_u8() != 1 {
return Err(AuthError::AuthenticationFailed);
}
let server_signature = hmac_sha256(&conversation.server_key, auth_message.as_bytes());
let server_final = format!("v={}", BASE64_STANDARD.encode(server_signature));
Ok(ScramContinueResult {
payload: server_final.as_bytes().to_vec(),
user: conversation.user,
})
}
fn persist_locked(&self, users: &HashMap<String, AuthUser>) -> Result<(), AuthError> {
if let Some(ref path) = self.users_path {
persist_users(path, users)?;
}
Ok(())
}
}
impl Default for AuthEngine {
fn default() -> Self {
Self::disabled()
}
}
impl AuthUser {
fn from_options(options: &AuthUserOptions, iterations: u32) -> Self {
let salt = secure_random(24);
let salted_password = salted_password(options.password.as_bytes(), &salt, iterations);
let client_key = hmac_sha256(&salted_password, b"Client Key");
let stored_key = Sha256::digest(&client_key).to_vec();
let server_key = hmac_sha256(&salted_password, b"Server Key");
Self {
username: options.username.clone(),
database: options.database.clone(),
roles: options.roles.clone(),
scram_sha256: ScramCredential {
salt,
iterations,
stored_key,
server_key,
},
}
}
fn to_authenticated_user(&self) -> AuthenticatedUser {
AuthenticatedUser {
username: self.username.clone(),
database: self.database.clone(),
roles: self.roles.clone(),
}
}
}
fn role_allows(role: &str, user: &AuthenticatedUser, target_db: &str, action: AuthAction) -> bool {
let (role_db, role_name) = role.split_once('.').unwrap_or(("", role));
if role_name == "root" {
return true;
}
let any_database = role_name.ends_with("AnyDatabase");
let scoped_db = if role_db.is_empty() { &user.database } else { role_db };
if !any_database && scoped_db != target_db {
return false;
}
match role_name {
"read" | "readAnyDatabase" => action == AuthAction::Read,
"readWrite" | "readWriteAnyDatabase" => {
matches!(action, AuthAction::Read | AuthAction::Write)
}
"dbAdmin" | "dbAdminAnyDatabase" => action == AuthAction::DbAdmin,
"userAdmin" | "userAdminAnyDatabase" => action == AuthAction::UserAdmin,
"clusterMonitor" => action == AuthAction::ClusterMonitor,
_ => false,
}
}
fn load_users(path: &Path) -> Result<HashMap<String, AuthUser>, AuthError> {
if !path.exists() {
return Ok(HashMap::new());
}
let data = std::fs::read_to_string(path).map_err(|e| AuthError::Persistence(e.to_string()))?;
let persisted: PersistedAuthState = serde_json::from_str(&data)
.map_err(|e| AuthError::Persistence(format!("failed to parse users file: {e}")))?;
Ok(persisted
.users
.into_iter()
.map(|user| (user_key(&user.database, &user.username), user))
.collect())
}
fn persist_users(path: &Path, users: &HashMap<String, AuthUser>) -> Result<(), AuthError> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| AuthError::Persistence(e.to_string()))?;
}
let mut user_list: Vec<AuthUser> = users.values().cloned().collect();
user_list.sort_by(|a, b| a.database.cmp(&b.database).then(a.username.cmp(&b.username)));
let payload = serde_json::to_vec_pretty(&PersistedAuthState { users: user_list })
.map_err(|e| AuthError::Persistence(e.to_string()))?;
let tmp_path = path.with_extension("tmp");
{
let mut file = std::fs::File::create(&tmp_path)
.map_err(|e| AuthError::Persistence(e.to_string()))?;
file.write_all(&payload)
.map_err(|e| AuthError::Persistence(e.to_string()))?;
file.sync_all()
.map_err(|e| AuthError::Persistence(e.to_string()))?;
}
std::fs::rename(&tmp_path, path).map_err(|e| AuthError::Persistence(e.to_string()))?;
if let Some(parent) = path.parent() {
if let Ok(dir) = std::fs::File::open(parent) {
let _ = dir.sync_all();
}
}
Ok(())
}
fn user_key(database: &str, username: &str) -> String {
format!("{}\0{}", database, username)
}
fn salted_password(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
let mut output = [0u8; 32];
pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut output);
output.to_vec()
}
fn hmac_sha256(key: &[u8], message: &[u8]) -> Vec<u8> {
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts keys of any size");
mac.update(message);
mac.finalize().into_bytes().to_vec()
}
fn secure_random(len: usize) -> Vec<u8> {
let mut bytes = vec![0u8; len];
OsRng.fill_bytes(&mut bytes);
bytes
}
fn secure_base64(len: usize) -> String {
BASE64_STANDARD.encode(secure_random(len))
}
fn parse_scram_attrs(input: &str) -> HashMap<String, String> {
let mut result = HashMap::new();
for part in input.split(',') {
if let Some((key, value)) = part.split_once('=') {
result.insert(key.to_string(), value.to_string());
}
}
result
}
fn decode_scram_name(input: &str) -> String {
input.replace("=2C", ",").replace("=3D", "=")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mechanism_lookup_returns_scram_sha256() {
let options = AuthOptions {
enabled: true,
users: vec![AuthUserOptions {
username: "root".to_string(),
password: "secret".to_string(),
database: "admin".to_string(),
roles: vec!["root".to_string()],
}],
users_path: None,
scram_iterations: 4096,
};
let engine = AuthEngine::from_options(&options).unwrap();
assert_eq!(engine.supported_mechanisms("admin.root"), vec![SCRAM_SHA_256.to_string()]);
}
#[test]
fn read_write_role_allows_read_and_write_only_on_own_db() {
let user = AuthenticatedUser {
username: "app".to_string(),
database: "appdb".to_string(),
roles: vec!["readWrite".to_string()],
};
assert!(role_allows("readWrite", &user, "appdb", AuthAction::Read));
assert!(role_allows("readWrite", &user, "appdb", AuthAction::Write));
assert!(!role_allows("readWrite", &user, "other", AuthAction::Read));
assert!(!role_allows("readWrite", &user, "appdb", AuthAction::DbAdmin));
}
}