feat(enterprise): add auth TLS and recovery hardening
This commit is contained in:
@@ -0,0 +1,565 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::RwLock;
|
||||
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
|
||||
use hmac::{Hmac, Mac};
|
||||
use pbkdf2::pbkdf2_hmac;
|
||||
use rand::{rngs::OsRng, RngCore};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use rustdb_config::{AuthOptions, AuthUserOptions};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AuthError {
|
||||
#[error("authentication is disabled")]
|
||||
Disabled,
|
||||
#[error("unsupported authentication mechanism: {0}")]
|
||||
UnsupportedMechanism(String),
|
||||
#[error("invalid SCRAM payload: {0}")]
|
||||
InvalidPayload(String),
|
||||
#[error("authentication failed")]
|
||||
AuthenticationFailed,
|
||||
#[error("unknown SASL conversation")]
|
||||
UnknownConversation,
|
||||
#[error("user already exists: {0}")]
|
||||
UserAlreadyExists(String),
|
||||
#[error("user not found: {0}")]
|
||||
UserNotFound(String),
|
||||
#[error("auth metadata persistence failed: {0}")]
|
||||
Persistence(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AuthAction {
|
||||
Read,
|
||||
Write,
|
||||
DbAdmin,
|
||||
UserAdmin,
|
||||
ClusterMonitor,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthenticatedUser {
|
||||
pub username: String,
|
||||
pub database: String,
|
||||
pub roles: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ScramCredential {
|
||||
salt: Vec<u8>,
|
||||
iterations: u32,
|
||||
stored_key: Vec<u8>,
|
||||
server_key: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AuthUser {
|
||||
username: String,
|
||||
database: String,
|
||||
roles: Vec<String>,
|
||||
scram_sha256: ScramCredential,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
struct PersistedAuthState {
|
||||
users: Vec<AuthUser>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScramConversation {
|
||||
user: AuthenticatedUser,
|
||||
client_first_bare: String,
|
||||
server_first: String,
|
||||
nonce: String,
|
||||
stored_key: Vec<u8>,
|
||||
server_key: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScramStartResult {
|
||||
pub payload: Vec<u8>,
|
||||
pub conversation: ScramConversation,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScramContinueResult {
|
||||
pub payload: Vec<u8>,
|
||||
pub user: AuthenticatedUser,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthEngine {
|
||||
enabled: bool,
|
||||
users: RwLock<HashMap<String, AuthUser>>,
|
||||
users_path: Option<PathBuf>,
|
||||
scram_iterations: u32,
|
||||
}
|
||||
|
||||
impl AuthEngine {
|
||||
pub fn from_options(options: &AuthOptions) -> Result<Self, AuthError> {
|
||||
let users_path = options.users_path.as_ref().map(PathBuf::from);
|
||||
let mut users = if let Some(ref path) = users_path {
|
||||
load_users(path)?
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
let mut changed = false;
|
||||
for user_options in &options.users {
|
||||
let key = user_key(&user_options.database, &user_options.username);
|
||||
if !users.contains_key(&key) {
|
||||
let user = AuthUser::from_options(user_options, options.scram_iterations);
|
||||
users.insert(key, user);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if changed {
|
||||
if let Some(ref path) = users_path {
|
||||
persist_users(path, &users)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
enabled: options.enabled,
|
||||
users: RwLock::new(users),
|
||||
users_path,
|
||||
scram_iterations: options.scram_iterations,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn disabled() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
users: RwLock::new(HashMap::new()),
|
||||
users_path: None,
|
||||
scram_iterations: 15000,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
|
||||
pub fn supported_mechanisms(&self, namespace_user: &str) -> Vec<String> {
|
||||
let Some((database, username)) = namespace_user.split_once('.') else {
|
||||
return Vec::new();
|
||||
};
|
||||
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
if users.contains_key(&user_key(database, username)) {
|
||||
vec![SCRAM_SHA_256.to_string()]
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_authorized(
|
||||
&self,
|
||||
authenticated_users: &[AuthenticatedUser],
|
||||
target_db: &str,
|
||||
action: AuthAction,
|
||||
) -> bool {
|
||||
authenticated_users
|
||||
.iter()
|
||||
.any(|user| user.roles.iter().any(|role| role_allows(role, user, target_db, action)))
|
||||
}
|
||||
|
||||
pub fn create_user(
|
||||
&self,
|
||||
database: &str,
|
||||
username: &str,
|
||||
password: &str,
|
||||
roles: Vec<String>,
|
||||
) -> Result<(), AuthError> {
|
||||
let key = user_key(database, username);
|
||||
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
if users.contains_key(&key) {
|
||||
return Err(AuthError::UserAlreadyExists(format!("{database}.{username}")));
|
||||
}
|
||||
let options = AuthUserOptions {
|
||||
username: username.to_string(),
|
||||
password: password.to_string(),
|
||||
database: database.to_string(),
|
||||
roles,
|
||||
};
|
||||
users.insert(key, AuthUser::from_options(&options, self.scram_iterations));
|
||||
self.persist_locked(&users)
|
||||
}
|
||||
|
||||
pub fn drop_user(&self, database: &str, username: &str) -> Result<(), AuthError> {
|
||||
let key = user_key(database, username);
|
||||
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
if users.remove(&key).is_none() {
|
||||
return Err(AuthError::UserNotFound(format!("{database}.{username}")));
|
||||
}
|
||||
self.persist_locked(&users)
|
||||
}
|
||||
|
||||
pub fn update_user(
|
||||
&self,
|
||||
database: &str,
|
||||
username: &str,
|
||||
password: Option<&str>,
|
||||
roles: Option<Vec<String>>,
|
||||
) -> Result<(), AuthError> {
|
||||
let key = user_key(database, username);
|
||||
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
let user = users
|
||||
.get_mut(&key)
|
||||
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
|
||||
if let Some(new_roles) = roles {
|
||||
user.roles = new_roles;
|
||||
}
|
||||
if let Some(new_password) = password {
|
||||
let options = AuthUserOptions {
|
||||
username: username.to_string(),
|
||||
password: new_password.to_string(),
|
||||
database: database.to_string(),
|
||||
roles: user.roles.clone(),
|
||||
};
|
||||
user.scram_sha256 = AuthUser::from_options(&options, self.scram_iterations).scram_sha256;
|
||||
}
|
||||
self.persist_locked(&users)
|
||||
}
|
||||
|
||||
pub fn grant_roles(
|
||||
&self,
|
||||
database: &str,
|
||||
username: &str,
|
||||
roles: Vec<String>,
|
||||
) -> Result<(), AuthError> {
|
||||
let key = user_key(database, username);
|
||||
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
let user = users
|
||||
.get_mut(&key)
|
||||
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
|
||||
for role in roles {
|
||||
if !user.roles.contains(&role) {
|
||||
user.roles.push(role);
|
||||
}
|
||||
}
|
||||
self.persist_locked(&users)
|
||||
}
|
||||
|
||||
pub fn revoke_roles(
|
||||
&self,
|
||||
database: &str,
|
||||
username: &str,
|
||||
roles: Vec<String>,
|
||||
) -> Result<(), AuthError> {
|
||||
let key = user_key(database, username);
|
||||
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
let user = users
|
||||
.get_mut(&key)
|
||||
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
|
||||
user.roles.retain(|role| !roles.contains(role));
|
||||
self.persist_locked(&users)
|
||||
}
|
||||
|
||||
pub fn users_info(&self, database: &str, username: Option<&str>) -> Vec<AuthenticatedUser> {
|
||||
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
users
|
||||
.values()
|
||||
.filter(|user| user.database == database)
|
||||
.filter(|user| username.map(|name| user.username == name).unwrap_or(true))
|
||||
.map(AuthUser::to_authenticated_user)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn start_scram_sha256(
|
||||
&self,
|
||||
database: &str,
|
||||
payload: &[u8],
|
||||
) -> Result<ScramStartResult, AuthError> {
|
||||
if !self.enabled {
|
||||
return Err(AuthError::Disabled);
|
||||
}
|
||||
|
||||
let message = std::str::from_utf8(payload)
|
||||
.map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?;
|
||||
let client_first_bare = message
|
||||
.strip_prefix("n,,")
|
||||
.ok_or_else(|| AuthError::InvalidPayload("expected SCRAM gs2 header 'n,,'".to_string()))?;
|
||||
let attrs = parse_scram_attrs(client_first_bare);
|
||||
let raw_username = attrs
|
||||
.get("n")
|
||||
.ok_or_else(|| AuthError::InvalidPayload("missing username".to_string()))?;
|
||||
let username = decode_scram_name(raw_username);
|
||||
let client_nonce = attrs
|
||||
.get("r")
|
||||
.ok_or_else(|| AuthError::InvalidPayload("missing client nonce".to_string()))?;
|
||||
|
||||
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
let user = users
|
||||
.get(&user_key(database, &username))
|
||||
.ok_or(AuthError::AuthenticationFailed)?;
|
||||
|
||||
let nonce = format!("{}{}", client_nonce, secure_base64(18));
|
||||
let server_first = format!(
|
||||
"r={},s={},i={}",
|
||||
nonce,
|
||||
BASE64_STANDARD.encode(&user.scram_sha256.salt),
|
||||
user.scram_sha256.iterations,
|
||||
);
|
||||
|
||||
Ok(ScramStartResult {
|
||||
payload: server_first.as_bytes().to_vec(),
|
||||
conversation: ScramConversation {
|
||||
user: user.to_authenticated_user(),
|
||||
client_first_bare: client_first_bare.to_string(),
|
||||
server_first: server_first.clone(),
|
||||
nonce,
|
||||
stored_key: user.scram_sha256.stored_key.clone(),
|
||||
server_key: user.scram_sha256.server_key.clone(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn continue_scram_sha256(
|
||||
&self,
|
||||
conversation: ScramConversation,
|
||||
payload: &[u8],
|
||||
) -> Result<ScramContinueResult, AuthError> {
|
||||
let message = std::str::from_utf8(payload)
|
||||
.map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?;
|
||||
let proof_marker = ",p=";
|
||||
let proof_pos = message
|
||||
.rfind(proof_marker)
|
||||
.ok_or_else(|| AuthError::InvalidPayload("missing client proof".to_string()))?;
|
||||
let client_final_without_proof = &message[..proof_pos];
|
||||
let proof_b64 = &message[proof_pos + proof_marker.len()..];
|
||||
let attrs = parse_scram_attrs(client_final_without_proof);
|
||||
let nonce = attrs
|
||||
.get("r")
|
||||
.ok_or_else(|| AuthError::InvalidPayload("missing nonce".to_string()))?;
|
||||
if nonce != &conversation.nonce {
|
||||
return Err(AuthError::AuthenticationFailed);
|
||||
}
|
||||
|
||||
let client_proof = BASE64_STANDARD
|
||||
.decode(proof_b64.as_bytes())
|
||||
.map_err(|_| AuthError::InvalidPayload("invalid client proof encoding".to_string()))?;
|
||||
if client_proof.len() != 32 || conversation.stored_key.len() != 32 {
|
||||
return Err(AuthError::AuthenticationFailed);
|
||||
}
|
||||
|
||||
let auth_message = format!(
|
||||
"{},{},{}",
|
||||
conversation.client_first_bare,
|
||||
conversation.server_first,
|
||||
client_final_without_proof,
|
||||
);
|
||||
let client_signature = hmac_sha256(&conversation.stored_key, auth_message.as_bytes());
|
||||
let client_key: Vec<u8> = client_proof
|
||||
.iter()
|
||||
.zip(client_signature.iter())
|
||||
.map(|(proof_byte, signature_byte)| proof_byte ^ signature_byte)
|
||||
.collect();
|
||||
let computed_stored_key = Sha256::digest(&client_key).to_vec();
|
||||
|
||||
if computed_stored_key.ct_eq(&conversation.stored_key).unwrap_u8() != 1 {
|
||||
return Err(AuthError::AuthenticationFailed);
|
||||
}
|
||||
|
||||
let server_signature = hmac_sha256(&conversation.server_key, auth_message.as_bytes());
|
||||
let server_final = format!("v={}", BASE64_STANDARD.encode(server_signature));
|
||||
|
||||
Ok(ScramContinueResult {
|
||||
payload: server_final.as_bytes().to_vec(),
|
||||
user: conversation.user,
|
||||
})
|
||||
}
|
||||
|
||||
fn persist_locked(&self, users: &HashMap<String, AuthUser>) -> Result<(), AuthError> {
|
||||
if let Some(ref path) = self.users_path {
|
||||
persist_users(path, users)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AuthEngine {
|
||||
fn default() -> Self {
|
||||
Self::disabled()
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthUser {
|
||||
fn from_options(options: &AuthUserOptions, iterations: u32) -> Self {
|
||||
let salt = secure_random(24);
|
||||
let salted_password = salted_password(options.password.as_bytes(), &salt, iterations);
|
||||
let client_key = hmac_sha256(&salted_password, b"Client Key");
|
||||
let stored_key = Sha256::digest(&client_key).to_vec();
|
||||
let server_key = hmac_sha256(&salted_password, b"Server Key");
|
||||
|
||||
Self {
|
||||
username: options.username.clone(),
|
||||
database: options.database.clone(),
|
||||
roles: options.roles.clone(),
|
||||
scram_sha256: ScramCredential {
|
||||
salt,
|
||||
iterations,
|
||||
stored_key,
|
||||
server_key,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn to_authenticated_user(&self) -> AuthenticatedUser {
|
||||
AuthenticatedUser {
|
||||
username: self.username.clone(),
|
||||
database: self.database.clone(),
|
||||
roles: self.roles.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn role_allows(role: &str, user: &AuthenticatedUser, target_db: &str, action: AuthAction) -> bool {
|
||||
let (role_db, role_name) = role.split_once('.').unwrap_or(("", role));
|
||||
if role_name == "root" {
|
||||
return true;
|
||||
}
|
||||
|
||||
let any_database = role_name.ends_with("AnyDatabase");
|
||||
let scoped_db = if role_db.is_empty() { &user.database } else { role_db };
|
||||
if !any_database && scoped_db != target_db {
|
||||
return false;
|
||||
}
|
||||
|
||||
match role_name {
|
||||
"read" | "readAnyDatabase" => action == AuthAction::Read,
|
||||
"readWrite" | "readWriteAnyDatabase" => {
|
||||
matches!(action, AuthAction::Read | AuthAction::Write)
|
||||
}
|
||||
"dbAdmin" | "dbAdminAnyDatabase" => action == AuthAction::DbAdmin,
|
||||
"userAdmin" | "userAdminAnyDatabase" => action == AuthAction::UserAdmin,
|
||||
"clusterMonitor" => action == AuthAction::ClusterMonitor,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_users(path: &Path) -> Result<HashMap<String, AuthUser>, AuthError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let data = std::fs::read_to_string(path).map_err(|e| AuthError::Persistence(e.to_string()))?;
|
||||
let persisted: PersistedAuthState = serde_json::from_str(&data)
|
||||
.map_err(|e| AuthError::Persistence(format!("failed to parse users file: {e}")))?;
|
||||
Ok(persisted
|
||||
.users
|
||||
.into_iter()
|
||||
.map(|user| (user_key(&user.database, &user.username), user))
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn persist_users(path: &Path, users: &HashMap<String, AuthUser>) -> Result<(), AuthError> {
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| AuthError::Persistence(e.to_string()))?;
|
||||
}
|
||||
|
||||
let mut user_list: Vec<AuthUser> = users.values().cloned().collect();
|
||||
user_list.sort_by(|a, b| a.database.cmp(&b.database).then(a.username.cmp(&b.username)));
|
||||
let payload = serde_json::to_vec_pretty(&PersistedAuthState { users: user_list })
|
||||
.map_err(|e| AuthError::Persistence(e.to_string()))?;
|
||||
|
||||
let tmp_path = path.with_extension("tmp");
|
||||
{
|
||||
let mut file = std::fs::File::create(&tmp_path)
|
||||
.map_err(|e| AuthError::Persistence(e.to_string()))?;
|
||||
file.write_all(&payload)
|
||||
.map_err(|e| AuthError::Persistence(e.to_string()))?;
|
||||
file.sync_all()
|
||||
.map_err(|e| AuthError::Persistence(e.to_string()))?;
|
||||
}
|
||||
std::fs::rename(&tmp_path, path).map_err(|e| AuthError::Persistence(e.to_string()))?;
|
||||
if let Some(parent) = path.parent() {
|
||||
if let Ok(dir) = std::fs::File::open(parent) {
|
||||
let _ = dir.sync_all();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn user_key(database: &str, username: &str) -> String {
|
||||
format!("{}\0{}", database, username)
|
||||
}
|
||||
|
||||
fn salted_password(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
|
||||
let mut output = [0u8; 32];
|
||||
pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut output);
|
||||
output.to_vec()
|
||||
}
|
||||
|
||||
fn hmac_sha256(key: &[u8], message: &[u8]) -> Vec<u8> {
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts keys of any size");
|
||||
mac.update(message);
|
||||
mac.finalize().into_bytes().to_vec()
|
||||
}
|
||||
|
||||
fn secure_random(len: usize) -> Vec<u8> {
|
||||
let mut bytes = vec![0u8; len];
|
||||
OsRng.fill_bytes(&mut bytes);
|
||||
bytes
|
||||
}
|
||||
|
||||
fn secure_base64(len: usize) -> String {
|
||||
BASE64_STANDARD.encode(secure_random(len))
|
||||
}
|
||||
|
||||
fn parse_scram_attrs(input: &str) -> HashMap<String, String> {
|
||||
let mut result = HashMap::new();
|
||||
for part in input.split(',') {
|
||||
if let Some((key, value)) = part.split_once('=') {
|
||||
result.insert(key.to_string(), value.to_string());
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn decode_scram_name(input: &str) -> String {
|
||||
input.replace("=2C", ",").replace("=3D", "=")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn mechanism_lookup_returns_scram_sha256() {
|
||||
let options = AuthOptions {
|
||||
enabled: true,
|
||||
users: vec![AuthUserOptions {
|
||||
username: "root".to_string(),
|
||||
password: "secret".to_string(),
|
||||
database: "admin".to_string(),
|
||||
roles: vec!["root".to_string()],
|
||||
}],
|
||||
users_path: None,
|
||||
scram_iterations: 4096,
|
||||
};
|
||||
let engine = AuthEngine::from_options(&options).unwrap();
|
||||
assert_eq!(engine.supported_mechanisms("admin.root"), vec![SCRAM_SHA_256.to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_write_role_allows_read_and_write_only_on_own_db() {
|
||||
let user = AuthenticatedUser {
|
||||
username: "app".to_string(),
|
||||
database: "appdb".to_string(),
|
||||
roles: vec!["readWrite".to_string()],
|
||||
};
|
||||
assert!(role_allows("readWrite", &user, "appdb", AuthAction::Read));
|
||||
assert!(role_allows("readWrite", &user, "appdb", AuthAction::Write));
|
||||
assert!(!role_allows("readWrite", &user, "other", AuthAction::Read));
|
||||
assert!(!role_allows("readWrite", &user, "appdb", AuthAction::DbAdmin));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user