use std::collections::HashMap; use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::RwLock; use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _}; use hmac::{Hmac, Mac}; use pbkdf2::pbkdf2_hmac; use rand::{rngs::OsRng, RngCore}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use subtle::ConstantTimeEq; use rustdb_config::{AuthOptions, AuthUserOptions}; type HmacSha256 = Hmac; const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; #[derive(Debug, thiserror::Error)] pub enum AuthError { #[error("authentication is disabled")] Disabled, #[error("unsupported authentication mechanism: {0}")] UnsupportedMechanism(String), #[error("invalid SCRAM payload: {0}")] InvalidPayload(String), #[error("authentication failed")] AuthenticationFailed, #[error("unknown SASL conversation")] UnknownConversation, #[error("user already exists: {0}")] UserAlreadyExists(String), #[error("user not found: {0}")] UserNotFound(String), #[error("auth metadata persistence failed: {0}")] Persistence(String), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AuthAction { Read, Write, DbAdmin, UserAdmin, ClusterMonitor, } #[derive(Debug, Clone)] pub struct AuthenticatedUser { pub username: String, pub database: String, pub roles: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] struct ScramCredential { salt: Vec, iterations: u32, stored_key: Vec, server_key: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] struct AuthUser { username: String, database: String, roles: Vec, scram_sha256: ScramCredential, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] struct PersistedAuthState { users: Vec, } #[derive(Debug, Clone)] pub struct ScramConversation { user: AuthenticatedUser, client_first_bare: String, server_first: String, nonce: String, stored_key: Vec, server_key: Vec, } #[derive(Debug, Clone)] pub struct ScramStartResult { pub payload: Vec, pub conversation: ScramConversation, } #[derive(Debug, Clone)] pub struct ScramContinueResult { pub payload: Vec, pub user: AuthenticatedUser, } #[derive(Debug)] pub struct AuthEngine { enabled: bool, users: RwLock>, users_path: Option, scram_iterations: u32, } impl AuthEngine { pub fn from_options(options: &AuthOptions) -> Result { let users_path = options.users_path.as_ref().map(PathBuf::from); let mut users = if let Some(ref path) = users_path { load_users(path)? } else { HashMap::new() }; let mut changed = false; for user_options in &options.users { let key = user_key(&user_options.database, &user_options.username); if !users.contains_key(&key) { let user = AuthUser::from_options(user_options, options.scram_iterations); users.insert(key, user); changed = true; } } if changed { if let Some(ref path) = users_path { persist_users(path, &users)?; } } Ok(Self { enabled: options.enabled, users: RwLock::new(users), users_path, scram_iterations: options.scram_iterations, }) } pub fn disabled() -> Self { Self { enabled: false, users: RwLock::new(HashMap::new()), users_path: None, scram_iterations: 15000, } } pub fn enabled(&self) -> bool { self.enabled } pub fn user_count(&self) -> usize { self.users .read() .unwrap_or_else(|poisoned| poisoned.into_inner()) .len() } pub fn supported_mechanisms(&self, namespace_user: &str) -> Vec { let Some((database, username)) = namespace_user.split_once('.') else { return Vec::new(); }; let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner()); if users.contains_key(&user_key(database, username)) { vec![SCRAM_SHA_256.to_string()] } else { Vec::new() } } pub fn is_authorized( &self, authenticated_users: &[AuthenticatedUser], target_db: &str, action: AuthAction, ) -> bool { authenticated_users .iter() .any(|user| user.roles.iter().any(|role| role_allows(role, user, target_db, action))) } pub fn create_user( &self, database: &str, username: &str, password: &str, roles: Vec, ) -> Result<(), AuthError> { let key = user_key(database, username); let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); if users.contains_key(&key) { return Err(AuthError::UserAlreadyExists(format!("{database}.{username}"))); } let options = AuthUserOptions { username: username.to_string(), password: password.to_string(), database: database.to_string(), roles, }; users.insert(key, AuthUser::from_options(&options, self.scram_iterations)); self.persist_locked(&users) } pub fn drop_user(&self, database: &str, username: &str) -> Result<(), AuthError> { let key = user_key(database, username); let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); if users.remove(&key).is_none() { return Err(AuthError::UserNotFound(format!("{database}.{username}"))); } self.persist_locked(&users) } pub fn update_user( &self, database: &str, username: &str, password: Option<&str>, roles: Option>, ) -> Result<(), AuthError> { let key = user_key(database, username); let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); let user = users .get_mut(&key) .ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?; if let Some(new_roles) = roles { user.roles = new_roles; } if let Some(new_password) = password { let options = AuthUserOptions { username: username.to_string(), password: new_password.to_string(), database: database.to_string(), roles: user.roles.clone(), }; user.scram_sha256 = AuthUser::from_options(&options, self.scram_iterations).scram_sha256; } self.persist_locked(&users) } pub fn grant_roles( &self, database: &str, username: &str, roles: Vec, ) -> Result<(), AuthError> { let key = user_key(database, username); let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); let user = users .get_mut(&key) .ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?; for role in roles { if !user.roles.contains(&role) { user.roles.push(role); } } self.persist_locked(&users) } pub fn revoke_roles( &self, database: &str, username: &str, roles: Vec, ) -> Result<(), AuthError> { let key = user_key(database, username); let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); let user = users .get_mut(&key) .ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?; user.roles.retain(|role| !roles.contains(role)); self.persist_locked(&users) } pub fn users_info(&self, database: &str, username: Option<&str>) -> Vec { let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner()); users .values() .filter(|user| user.database == database) .filter(|user| username.map(|name| user.username == name).unwrap_or(true)) .map(AuthUser::to_authenticated_user) .collect() } pub fn list_users(&self) -> Vec { let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner()); let mut result: Vec = users .values() .map(AuthUser::to_authenticated_user) .collect(); result.sort_by(|a, b| a.database.cmp(&b.database).then(a.username.cmp(&b.username))); result } pub fn drop_users_for_database(&self, database: &str) -> Result { let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); let before = users.len(); users.retain(|_, user| user.database != database); let dropped = before.saturating_sub(users.len()); if dropped > 0 { self.persist_locked(&users)?; } Ok(dropped) } pub fn start_scram_sha256( &self, database: &str, payload: &[u8], ) -> Result { if !self.enabled { return Err(AuthError::Disabled); } let message = std::str::from_utf8(payload) .map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?; let client_first_bare = message .strip_prefix("n,,") .ok_or_else(|| AuthError::InvalidPayload("expected SCRAM gs2 header 'n,,'".to_string()))?; let attrs = parse_scram_attrs(client_first_bare); let raw_username = attrs .get("n") .ok_or_else(|| AuthError::InvalidPayload("missing username".to_string()))?; let username = decode_scram_name(raw_username); let client_nonce = attrs .get("r") .ok_or_else(|| AuthError::InvalidPayload("missing client nonce".to_string()))?; let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner()); let user = users .get(&user_key(database, &username)) .ok_or(AuthError::AuthenticationFailed)?; let nonce = format!("{}{}", client_nonce, secure_base64(18)); let server_first = format!( "r={},s={},i={}", nonce, BASE64_STANDARD.encode(&user.scram_sha256.salt), user.scram_sha256.iterations, ); Ok(ScramStartResult { payload: server_first.as_bytes().to_vec(), conversation: ScramConversation { user: user.to_authenticated_user(), client_first_bare: client_first_bare.to_string(), server_first: server_first.clone(), nonce, stored_key: user.scram_sha256.stored_key.clone(), server_key: user.scram_sha256.server_key.clone(), }, }) } pub fn continue_scram_sha256( &self, conversation: ScramConversation, payload: &[u8], ) -> Result { let message = std::str::from_utf8(payload) .map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?; let proof_marker = ",p="; let proof_pos = message .rfind(proof_marker) .ok_or_else(|| AuthError::InvalidPayload("missing client proof".to_string()))?; let client_final_without_proof = &message[..proof_pos]; let proof_b64 = &message[proof_pos + proof_marker.len()..]; let attrs = parse_scram_attrs(client_final_without_proof); let nonce = attrs .get("r") .ok_or_else(|| AuthError::InvalidPayload("missing nonce".to_string()))?; if nonce != &conversation.nonce { return Err(AuthError::AuthenticationFailed); } let client_proof = BASE64_STANDARD .decode(proof_b64.as_bytes()) .map_err(|_| AuthError::InvalidPayload("invalid client proof encoding".to_string()))?; if client_proof.len() != 32 || conversation.stored_key.len() != 32 { return Err(AuthError::AuthenticationFailed); } let auth_message = format!( "{},{},{}", conversation.client_first_bare, conversation.server_first, client_final_without_proof, ); let client_signature = hmac_sha256(&conversation.stored_key, auth_message.as_bytes()); let client_key: Vec = client_proof .iter() .zip(client_signature.iter()) .map(|(proof_byte, signature_byte)| proof_byte ^ signature_byte) .collect(); let computed_stored_key = Sha256::digest(&client_key).to_vec(); if computed_stored_key.ct_eq(&conversation.stored_key).unwrap_u8() != 1 { return Err(AuthError::AuthenticationFailed); } let server_signature = hmac_sha256(&conversation.server_key, auth_message.as_bytes()); let server_final = format!("v={}", BASE64_STANDARD.encode(server_signature)); Ok(ScramContinueResult { payload: server_final.as_bytes().to_vec(), user: conversation.user, }) } fn persist_locked(&self, users: &HashMap) -> Result<(), AuthError> { if let Some(ref path) = self.users_path { persist_users(path, users)?; } Ok(()) } } impl Default for AuthEngine { fn default() -> Self { Self::disabled() } } impl AuthUser { fn from_options(options: &AuthUserOptions, iterations: u32) -> Self { let salt = secure_random(24); let salted_password = salted_password(options.password.as_bytes(), &salt, iterations); let client_key = hmac_sha256(&salted_password, b"Client Key"); let stored_key = Sha256::digest(&client_key).to_vec(); let server_key = hmac_sha256(&salted_password, b"Server Key"); Self { username: options.username.clone(), database: options.database.clone(), roles: options.roles.clone(), scram_sha256: ScramCredential { salt, iterations, stored_key, server_key, }, } } fn to_authenticated_user(&self) -> AuthenticatedUser { AuthenticatedUser { username: self.username.clone(), database: self.database.clone(), roles: self.roles.clone(), } } } fn role_allows(role: &str, user: &AuthenticatedUser, target_db: &str, action: AuthAction) -> bool { let (role_db, role_name) = role.split_once('.').unwrap_or(("", role)); if role_name == "root" { return true; } let any_database = role_name.ends_with("AnyDatabase"); let scoped_db = if role_db.is_empty() { &user.database } else { role_db }; if !any_database && scoped_db != target_db { return false; } match role_name { "read" | "readAnyDatabase" => action == AuthAction::Read, "readWrite" | "readWriteAnyDatabase" => { matches!(action, AuthAction::Read | AuthAction::Write) } "dbAdmin" | "dbAdminAnyDatabase" => action == AuthAction::DbAdmin, "userAdmin" | "userAdminAnyDatabase" => action == AuthAction::UserAdmin, "clusterMonitor" => action == AuthAction::ClusterMonitor, _ => false, } } fn load_users(path: &Path) -> Result, AuthError> { if !path.exists() { return Ok(HashMap::new()); } let data = std::fs::read_to_string(path).map_err(|e| AuthError::Persistence(e.to_string()))?; let persisted: PersistedAuthState = serde_json::from_str(&data) .map_err(|e| AuthError::Persistence(format!("failed to parse users file: {e}")))?; Ok(persisted .users .into_iter() .map(|user| (user_key(&user.database, &user.username), user)) .collect()) } fn persist_users(path: &Path, users: &HashMap) -> Result<(), AuthError> { if let Some(parent) = path.parent() { std::fs::create_dir_all(parent).map_err(|e| AuthError::Persistence(e.to_string()))?; } let mut user_list: Vec = users.values().cloned().collect(); user_list.sort_by(|a, b| a.database.cmp(&b.database).then(a.username.cmp(&b.username))); let payload = serde_json::to_vec_pretty(&PersistedAuthState { users: user_list }) .map_err(|e| AuthError::Persistence(e.to_string()))?; let tmp_path = path.with_extension("tmp"); { let mut file = std::fs::File::create(&tmp_path) .map_err(|e| AuthError::Persistence(e.to_string()))?; file.write_all(&payload) .map_err(|e| AuthError::Persistence(e.to_string()))?; file.sync_all() .map_err(|e| AuthError::Persistence(e.to_string()))?; } std::fs::rename(&tmp_path, path).map_err(|e| AuthError::Persistence(e.to_string()))?; if let Some(parent) = path.parent() { if let Ok(dir) = std::fs::File::open(parent) { let _ = dir.sync_all(); } } Ok(()) } fn user_key(database: &str, username: &str) -> String { format!("{}\0{}", database, username) } fn salted_password(password: &[u8], salt: &[u8], iterations: u32) -> Vec { let mut output = [0u8; 32]; pbkdf2_hmac::(password, salt, iterations, &mut output); output.to_vec() } fn hmac_sha256(key: &[u8], message: &[u8]) -> Vec { let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts keys of any size"); mac.update(message); mac.finalize().into_bytes().to_vec() } fn secure_random(len: usize) -> Vec { let mut bytes = vec![0u8; len]; OsRng.fill_bytes(&mut bytes); bytes } fn secure_base64(len: usize) -> String { BASE64_STANDARD.encode(secure_random(len)) } fn parse_scram_attrs(input: &str) -> HashMap { let mut result = HashMap::new(); for part in input.split(',') { if let Some((key, value)) = part.split_once('=') { result.insert(key.to_string(), value.to_string()); } } result } fn decode_scram_name(input: &str) -> String { input.replace("=2C", ",").replace("=3D", "=") } #[cfg(test)] mod tests { use super::*; #[test] fn mechanism_lookup_returns_scram_sha256() { let options = AuthOptions { enabled: true, users: vec![AuthUserOptions { username: "root".to_string(), password: "secret".to_string(), database: "admin".to_string(), roles: vec!["root".to_string()], }], users_path: None, scram_iterations: 4096, }; let engine = AuthEngine::from_options(&options).unwrap(); assert_eq!(engine.supported_mechanisms("admin.root"), vec![SCRAM_SHA_256.to_string()]); } #[test] fn read_write_role_allows_read_and_write_only_on_own_db() { let user = AuthenticatedUser { username: "app".to_string(), database: "appdb".to_string(), roles: vec!["readWrite".to_string()], }; assert!(role_allows("readWrite", &user, "appdb", AuthAction::Read)); assert!(role_allows("readWrite", &user, "appdb", AuthAction::Write)); assert!(!role_allows("readWrite", &user, "other", AuthAction::Read)); assert!(!role_allows("readWrite", &user, "appdb", AuthAction::DbAdmin)); } }