594 lines
19 KiB
Rust
594 lines
19 KiB
Rust
use std::collections::HashMap;
|
|
use std::io::Write;
|
|
use std::path::{Path, PathBuf};
|
|
use std::sync::RwLock;
|
|
|
|
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
|
|
use hmac::{Hmac, Mac};
|
|
use pbkdf2::pbkdf2_hmac;
|
|
use rand::{rngs::OsRng, RngCore};
|
|
use serde::{Deserialize, Serialize};
|
|
use sha2::{Digest, Sha256};
|
|
use subtle::ConstantTimeEq;
|
|
|
|
use rustdb_config::{AuthOptions, AuthUserOptions};
|
|
|
|
type HmacSha256 = Hmac<Sha256>;
|
|
|
|
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum AuthError {
|
|
#[error("authentication is disabled")]
|
|
Disabled,
|
|
#[error("unsupported authentication mechanism: {0}")]
|
|
UnsupportedMechanism(String),
|
|
#[error("invalid SCRAM payload: {0}")]
|
|
InvalidPayload(String),
|
|
#[error("authentication failed")]
|
|
AuthenticationFailed,
|
|
#[error("unknown SASL conversation")]
|
|
UnknownConversation,
|
|
#[error("user already exists: {0}")]
|
|
UserAlreadyExists(String),
|
|
#[error("user not found: {0}")]
|
|
UserNotFound(String),
|
|
#[error("auth metadata persistence failed: {0}")]
|
|
Persistence(String),
|
|
}
|
|
|
|
/// Category of operation being authorized; see `role_allows` for which roles grant which action.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthAction {
    /// Reading data.
    Read,
    /// Modifying data.
    Write,
    /// Database administration, granted by `dbAdmin`-style roles.
    DbAdmin,
    /// User management, granted by `userAdmin`-style roles.
    UserAdmin,
    /// Cluster monitoring, granted by `clusterMonitor`.
    ClusterMonitor,
}
|
|
|
|
/// Identity and roles of a user who completed authentication; contains no credential material.
#[derive(Clone, Debug)]
pub struct AuthenticatedUser {
    /// User name, unique within `database`.
    pub username: String,
    /// Database the user is defined in; unqualified roles are scoped to it.
    pub database: String,
    /// Role strings, either `name` or `db.name`.
    pub roles: Vec<String>,
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct ScramCredential {
|
|
salt: Vec<u8>,
|
|
iterations: u32,
|
|
stored_key: Vec<u8>,
|
|
server_key: Vec<u8>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct AuthUser {
|
|
username: String,
|
|
database: String,
|
|
roles: Vec<String>,
|
|
scram_sha256: ScramCredential,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
|
struct PersistedAuthState {
|
|
users: Vec<AuthUser>,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ScramConversation {
|
|
user: AuthenticatedUser,
|
|
client_first_bare: String,
|
|
server_first: String,
|
|
nonce: String,
|
|
stored_key: Vec<u8>,
|
|
server_key: Vec<u8>,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ScramStartResult {
|
|
pub payload: Vec<u8>,
|
|
pub conversation: ScramConversation,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ScramContinueResult {
|
|
pub payload: Vec<u8>,
|
|
pub user: AuthenticatedUser,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct AuthEngine {
|
|
enabled: bool,
|
|
users: RwLock<HashMap<String, AuthUser>>,
|
|
users_path: Option<PathBuf>,
|
|
scram_iterations: u32,
|
|
}
|
|
|
|
impl AuthEngine {
|
|
pub fn from_options(options: &AuthOptions) -> Result<Self, AuthError> {
|
|
let users_path = options.users_path.as_ref().map(PathBuf::from);
|
|
let mut users = if let Some(ref path) = users_path {
|
|
load_users(path)?
|
|
} else {
|
|
HashMap::new()
|
|
};
|
|
|
|
let mut changed = false;
|
|
for user_options in &options.users {
|
|
let key = user_key(&user_options.database, &user_options.username);
|
|
if !users.contains_key(&key) {
|
|
let user = AuthUser::from_options(user_options, options.scram_iterations);
|
|
users.insert(key, user);
|
|
changed = true;
|
|
}
|
|
}
|
|
|
|
if changed {
|
|
if let Some(ref path) = users_path {
|
|
persist_users(path, &users)?;
|
|
}
|
|
}
|
|
|
|
Ok(Self {
|
|
enabled: options.enabled,
|
|
users: RwLock::new(users),
|
|
users_path,
|
|
scram_iterations: options.scram_iterations,
|
|
})
|
|
}
|
|
|
|
pub fn disabled() -> Self {
|
|
Self {
|
|
enabled: false,
|
|
users: RwLock::new(HashMap::new()),
|
|
users_path: None,
|
|
scram_iterations: 15000,
|
|
}
|
|
}
|
|
|
|
pub fn enabled(&self) -> bool {
|
|
self.enabled
|
|
}
|
|
|
|
pub fn user_count(&self) -> usize {
|
|
self.users
|
|
.read()
|
|
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
|
.len()
|
|
}
|
|
|
|
pub fn supported_mechanisms(&self, namespace_user: &str) -> Vec<String> {
|
|
let Some((database, username)) = namespace_user.split_once('.') else {
|
|
return Vec::new();
|
|
};
|
|
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
if users.contains_key(&user_key(database, username)) {
|
|
vec![SCRAM_SHA_256.to_string()]
|
|
} else {
|
|
Vec::new()
|
|
}
|
|
}
|
|
|
|
pub fn is_authorized(
|
|
&self,
|
|
authenticated_users: &[AuthenticatedUser],
|
|
target_db: &str,
|
|
action: AuthAction,
|
|
) -> bool {
|
|
authenticated_users
|
|
.iter()
|
|
.any(|user| user.roles.iter().any(|role| role_allows(role, user, target_db, action)))
|
|
}
|
|
|
|
pub fn create_user(
|
|
&self,
|
|
database: &str,
|
|
username: &str,
|
|
password: &str,
|
|
roles: Vec<String>,
|
|
) -> Result<(), AuthError> {
|
|
let key = user_key(database, username);
|
|
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
if users.contains_key(&key) {
|
|
return Err(AuthError::UserAlreadyExists(format!("{database}.{username}")));
|
|
}
|
|
let options = AuthUserOptions {
|
|
username: username.to_string(),
|
|
password: password.to_string(),
|
|
database: database.to_string(),
|
|
roles,
|
|
};
|
|
users.insert(key, AuthUser::from_options(&options, self.scram_iterations));
|
|
self.persist_locked(&users)
|
|
}
|
|
|
|
pub fn drop_user(&self, database: &str, username: &str) -> Result<(), AuthError> {
|
|
let key = user_key(database, username);
|
|
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
if users.remove(&key).is_none() {
|
|
return Err(AuthError::UserNotFound(format!("{database}.{username}")));
|
|
}
|
|
self.persist_locked(&users)
|
|
}
|
|
|
|
pub fn update_user(
|
|
&self,
|
|
database: &str,
|
|
username: &str,
|
|
password: Option<&str>,
|
|
roles: Option<Vec<String>>,
|
|
) -> Result<(), AuthError> {
|
|
let key = user_key(database, username);
|
|
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
let user = users
|
|
.get_mut(&key)
|
|
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
|
|
if let Some(new_roles) = roles {
|
|
user.roles = new_roles;
|
|
}
|
|
if let Some(new_password) = password {
|
|
let options = AuthUserOptions {
|
|
username: username.to_string(),
|
|
password: new_password.to_string(),
|
|
database: database.to_string(),
|
|
roles: user.roles.clone(),
|
|
};
|
|
user.scram_sha256 = AuthUser::from_options(&options, self.scram_iterations).scram_sha256;
|
|
}
|
|
self.persist_locked(&users)
|
|
}
|
|
|
|
pub fn grant_roles(
|
|
&self,
|
|
database: &str,
|
|
username: &str,
|
|
roles: Vec<String>,
|
|
) -> Result<(), AuthError> {
|
|
let key = user_key(database, username);
|
|
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
let user = users
|
|
.get_mut(&key)
|
|
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
|
|
for role in roles {
|
|
if !user.roles.contains(&role) {
|
|
user.roles.push(role);
|
|
}
|
|
}
|
|
self.persist_locked(&users)
|
|
}
|
|
|
|
pub fn revoke_roles(
|
|
&self,
|
|
database: &str,
|
|
username: &str,
|
|
roles: Vec<String>,
|
|
) -> Result<(), AuthError> {
|
|
let key = user_key(database, username);
|
|
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
let user = users
|
|
.get_mut(&key)
|
|
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
|
|
user.roles.retain(|role| !roles.contains(role));
|
|
self.persist_locked(&users)
|
|
}
|
|
|
|
pub fn users_info(&self, database: &str, username: Option<&str>) -> Vec<AuthenticatedUser> {
|
|
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
users
|
|
.values()
|
|
.filter(|user| user.database == database)
|
|
.filter(|user| username.map(|name| user.username == name).unwrap_or(true))
|
|
.map(AuthUser::to_authenticated_user)
|
|
.collect()
|
|
}
|
|
|
|
pub fn list_users(&self) -> Vec<AuthenticatedUser> {
|
|
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
let mut result: Vec<AuthenticatedUser> = users
|
|
.values()
|
|
.map(AuthUser::to_authenticated_user)
|
|
.collect();
|
|
result.sort_by(|a, b| a.database.cmp(&b.database).then(a.username.cmp(&b.username)));
|
|
result
|
|
}
|
|
|
|
pub fn drop_users_for_database(&self, database: &str) -> Result<usize, AuthError> {
|
|
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
let before = users.len();
|
|
users.retain(|_, user| user.database != database);
|
|
let dropped = before.saturating_sub(users.len());
|
|
if dropped > 0 {
|
|
self.persist_locked(&users)?;
|
|
}
|
|
Ok(dropped)
|
|
}
|
|
|
|
pub fn start_scram_sha256(
|
|
&self,
|
|
database: &str,
|
|
payload: &[u8],
|
|
) -> Result<ScramStartResult, AuthError> {
|
|
if !self.enabled {
|
|
return Err(AuthError::Disabled);
|
|
}
|
|
|
|
let message = std::str::from_utf8(payload)
|
|
.map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?;
|
|
let client_first_bare = message
|
|
.strip_prefix("n,,")
|
|
.ok_or_else(|| AuthError::InvalidPayload("expected SCRAM gs2 header 'n,,'".to_string()))?;
|
|
let attrs = parse_scram_attrs(client_first_bare);
|
|
let raw_username = attrs
|
|
.get("n")
|
|
.ok_or_else(|| AuthError::InvalidPayload("missing username".to_string()))?;
|
|
let username = decode_scram_name(raw_username);
|
|
let client_nonce = attrs
|
|
.get("r")
|
|
.ok_or_else(|| AuthError::InvalidPayload("missing client nonce".to_string()))?;
|
|
|
|
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
|
|
let user = users
|
|
.get(&user_key(database, &username))
|
|
.ok_or(AuthError::AuthenticationFailed)?;
|
|
|
|
let nonce = format!("{}{}", client_nonce, secure_base64(18));
|
|
let server_first = format!(
|
|
"r={},s={},i={}",
|
|
nonce,
|
|
BASE64_STANDARD.encode(&user.scram_sha256.salt),
|
|
user.scram_sha256.iterations,
|
|
);
|
|
|
|
Ok(ScramStartResult {
|
|
payload: server_first.as_bytes().to_vec(),
|
|
conversation: ScramConversation {
|
|
user: user.to_authenticated_user(),
|
|
client_first_bare: client_first_bare.to_string(),
|
|
server_first: server_first.clone(),
|
|
nonce,
|
|
stored_key: user.scram_sha256.stored_key.clone(),
|
|
server_key: user.scram_sha256.server_key.clone(),
|
|
},
|
|
})
|
|
}
|
|
|
|
pub fn continue_scram_sha256(
|
|
&self,
|
|
conversation: ScramConversation,
|
|
payload: &[u8],
|
|
) -> Result<ScramContinueResult, AuthError> {
|
|
let message = std::str::from_utf8(payload)
|
|
.map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?;
|
|
let proof_marker = ",p=";
|
|
let proof_pos = message
|
|
.rfind(proof_marker)
|
|
.ok_or_else(|| AuthError::InvalidPayload("missing client proof".to_string()))?;
|
|
let client_final_without_proof = &message[..proof_pos];
|
|
let proof_b64 = &message[proof_pos + proof_marker.len()..];
|
|
let attrs = parse_scram_attrs(client_final_without_proof);
|
|
let nonce = attrs
|
|
.get("r")
|
|
.ok_or_else(|| AuthError::InvalidPayload("missing nonce".to_string()))?;
|
|
if nonce != &conversation.nonce {
|
|
return Err(AuthError::AuthenticationFailed);
|
|
}
|
|
|
|
let client_proof = BASE64_STANDARD
|
|
.decode(proof_b64.as_bytes())
|
|
.map_err(|_| AuthError::InvalidPayload("invalid client proof encoding".to_string()))?;
|
|
if client_proof.len() != 32 || conversation.stored_key.len() != 32 {
|
|
return Err(AuthError::AuthenticationFailed);
|
|
}
|
|
|
|
let auth_message = format!(
|
|
"{},{},{}",
|
|
conversation.client_first_bare,
|
|
conversation.server_first,
|
|
client_final_without_proof,
|
|
);
|
|
let client_signature = hmac_sha256(&conversation.stored_key, auth_message.as_bytes());
|
|
let client_key: Vec<u8> = client_proof
|
|
.iter()
|
|
.zip(client_signature.iter())
|
|
.map(|(proof_byte, signature_byte)| proof_byte ^ signature_byte)
|
|
.collect();
|
|
let computed_stored_key = Sha256::digest(&client_key).to_vec();
|
|
|
|
if computed_stored_key.ct_eq(&conversation.stored_key).unwrap_u8() != 1 {
|
|
return Err(AuthError::AuthenticationFailed);
|
|
}
|
|
|
|
let server_signature = hmac_sha256(&conversation.server_key, auth_message.as_bytes());
|
|
let server_final = format!("v={}", BASE64_STANDARD.encode(server_signature));
|
|
|
|
Ok(ScramContinueResult {
|
|
payload: server_final.as_bytes().to_vec(),
|
|
user: conversation.user,
|
|
})
|
|
}
|
|
|
|
fn persist_locked(&self, users: &HashMap<String, AuthUser>) -> Result<(), AuthError> {
|
|
if let Some(ref path) = self.users_path {
|
|
persist_users(path, users)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl Default for AuthEngine {
|
|
fn default() -> Self {
|
|
Self::disabled()
|
|
}
|
|
}
|
|
|
|
impl AuthUser {
|
|
fn from_options(options: &AuthUserOptions, iterations: u32) -> Self {
|
|
let salt = secure_random(24);
|
|
let salted_password = salted_password(options.password.as_bytes(), &salt, iterations);
|
|
let client_key = hmac_sha256(&salted_password, b"Client Key");
|
|
let stored_key = Sha256::digest(&client_key).to_vec();
|
|
let server_key = hmac_sha256(&salted_password, b"Server Key");
|
|
|
|
Self {
|
|
username: options.username.clone(),
|
|
database: options.database.clone(),
|
|
roles: options.roles.clone(),
|
|
scram_sha256: ScramCredential {
|
|
salt,
|
|
iterations,
|
|
stored_key,
|
|
server_key,
|
|
},
|
|
}
|
|
}
|
|
|
|
fn to_authenticated_user(&self) -> AuthenticatedUser {
|
|
AuthenticatedUser {
|
|
username: self.username.clone(),
|
|
database: self.database.clone(),
|
|
roles: self.roles.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn role_allows(role: &str, user: &AuthenticatedUser, target_db: &str, action: AuthAction) -> bool {
|
|
let (role_db, role_name) = role.split_once('.').unwrap_or(("", role));
|
|
if role_name == "root" {
|
|
return true;
|
|
}
|
|
|
|
let any_database = role_name.ends_with("AnyDatabase");
|
|
let scoped_db = if role_db.is_empty() { &user.database } else { role_db };
|
|
if !any_database && scoped_db != target_db {
|
|
return false;
|
|
}
|
|
|
|
match role_name {
|
|
"read" | "readAnyDatabase" => action == AuthAction::Read,
|
|
"readWrite" | "readWriteAnyDatabase" => {
|
|
matches!(action, AuthAction::Read | AuthAction::Write)
|
|
}
|
|
"dbAdmin" | "dbAdminAnyDatabase" => action == AuthAction::DbAdmin,
|
|
"userAdmin" | "userAdminAnyDatabase" => action == AuthAction::UserAdmin,
|
|
"clusterMonitor" => action == AuthAction::ClusterMonitor,
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
fn load_users(path: &Path) -> Result<HashMap<String, AuthUser>, AuthError> {
|
|
if !path.exists() {
|
|
return Ok(HashMap::new());
|
|
}
|
|
let data = std::fs::read_to_string(path).map_err(|e| AuthError::Persistence(e.to_string()))?;
|
|
let persisted: PersistedAuthState = serde_json::from_str(&data)
|
|
.map_err(|e| AuthError::Persistence(format!("failed to parse users file: {e}")))?;
|
|
Ok(persisted
|
|
.users
|
|
.into_iter()
|
|
.map(|user| (user_key(&user.database, &user.username), user))
|
|
.collect())
|
|
}
|
|
|
|
fn persist_users(path: &Path, users: &HashMap<String, AuthUser>) -> Result<(), AuthError> {
|
|
if let Some(parent) = path.parent() {
|
|
std::fs::create_dir_all(parent).map_err(|e| AuthError::Persistence(e.to_string()))?;
|
|
}
|
|
|
|
let mut user_list: Vec<AuthUser> = users.values().cloned().collect();
|
|
user_list.sort_by(|a, b| a.database.cmp(&b.database).then(a.username.cmp(&b.username)));
|
|
let payload = serde_json::to_vec_pretty(&PersistedAuthState { users: user_list })
|
|
.map_err(|e| AuthError::Persistence(e.to_string()))?;
|
|
|
|
let tmp_path = path.with_extension("tmp");
|
|
{
|
|
let mut file = std::fs::File::create(&tmp_path)
|
|
.map_err(|e| AuthError::Persistence(e.to_string()))?;
|
|
file.write_all(&payload)
|
|
.map_err(|e| AuthError::Persistence(e.to_string()))?;
|
|
file.sync_all()
|
|
.map_err(|e| AuthError::Persistence(e.to_string()))?;
|
|
}
|
|
std::fs::rename(&tmp_path, path).map_err(|e| AuthError::Persistence(e.to_string()))?;
|
|
if let Some(parent) = path.parent() {
|
|
if let Ok(dir) = std::fs::File::open(parent) {
|
|
let _ = dir.sync_all();
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Builds the map key for a user. A NUL separator is used instead of `.` so that
/// `("a.b", "c")` and `("a", "b.c")` cannot collide.
fn user_key(database: &str, username: &str) -> String {
    let mut key = String::with_capacity(database.len() + 1 + username.len());
    key.push_str(database);
    key.push('\0');
    key.push_str(username);
    key
}
|
|
|
|
fn salted_password(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
|
|
let mut output = [0u8; 32];
|
|
pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut output);
|
|
output.to_vec()
|
|
}
|
|
|
|
fn hmac_sha256(key: &[u8], message: &[u8]) -> Vec<u8> {
|
|
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts keys of any size");
|
|
mac.update(message);
|
|
mac.finalize().into_bytes().to_vec()
|
|
}
|
|
|
|
fn secure_random(len: usize) -> Vec<u8> {
|
|
let mut bytes = vec![0u8; len];
|
|
OsRng.fill_bytes(&mut bytes);
|
|
bytes
|
|
}
|
|
|
|
fn secure_base64(len: usize) -> String {
|
|
BASE64_STANDARD.encode(secure_random(len))
|
|
}
|
|
|
|
/// Parses a comma-separated SCRAM attribute list (`k=v,k=v,…`) into a map.
///
/// Each part is split at its first `=`, so values may contain `=` (e.g. base64 padding).
/// Parts without `=` are skipped, and a repeated key keeps its last value.
fn parse_scram_attrs(input: &str) -> HashMap<String, String> {
    input
        .split(',')
        .filter_map(|part| part.split_once('='))
        .map(|(key, value)| (key.to_owned(), value.to_owned()))
        .collect()
}
|
|
|
|
/// Decodes a SCRAM `saslname`: `=2C` becomes `,` and `=3D` becomes `=`.
///
/// Any other `=` is copied through unchanged (lenient: no error for invalid escapes).
fn decode_scram_name(input: &str) -> String {
    let mut decoded = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(pos) = rest.find('=') {
        decoded.push_str(&rest[..pos]);
        let tail = &rest[pos..];
        if let Some(after) = tail.strip_prefix("=2C") {
            decoded.push(',');
            rest = after;
        } else if let Some(after) = tail.strip_prefix("=3D") {
            decoded.push('=');
            rest = after;
        } else {
            // Unrecognised escape: keep the '=' (ASCII, so slicing one byte is safe).
            decoded.push('=');
            rest = &tail[1..];
        }
    }
    decoded.push_str(rest);
    decoded
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;

    /// Builds an enabled, non-persistent engine holding the single user `admin.root`.
    fn root_engine(password: &str) -> AuthEngine {
        let options = AuthOptions {
            enabled: true,
            users: vec![AuthUserOptions {
                username: "root".to_string(),
                password: password.to_string(),
                database: "admin".to_string(),
                roles: vec!["root".to_string()],
            }],
            users_path: None,
            scram_iterations: 4096,
        };
        AuthEngine::from_options(&options).unwrap()
    }

    /// Plays the client side of SCRAM-SHA-256 for `admin.root` using `password`.
    /// Returns the engine's verdict and the ServerSignature the client expects.
    fn run_exchange(
        engine: &AuthEngine,
        password: &str,
    ) -> (Result<ScramContinueResult, AuthError>, Vec<u8>) {
        let client_first_bare = "n=root,r=clientnonce";
        let start = engine
            .start_scram_sha256("admin", format!("n,,{client_first_bare}").as_bytes())
            .unwrap();
        let server_first = String::from_utf8(start.payload.clone()).unwrap();
        let attrs = parse_scram_attrs(&server_first);
        let salt = BASE64_STANDARD.decode(attrs["s"].as_bytes()).unwrap();
        let iterations: u32 = attrs["i"].parse().unwrap();

        let salted = salted_password(password.as_bytes(), &salt, iterations);
        let client_key = hmac_sha256(&salted, b"Client Key");
        let stored_key = Sha256::digest(&client_key).to_vec();
        let server_key = hmac_sha256(&salted, b"Server Key");

        let without_proof = format!("c=biws,r={}", attrs["r"]);
        let auth_message = format!("{client_first_bare},{server_first},{without_proof}");
        let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
        let proof: Vec<u8> = client_key
            .iter()
            .zip(&client_signature)
            .map(|(key_byte, sig_byte)| key_byte ^ sig_byte)
            .collect();
        let client_final = format!("{without_proof},p={}", BASE64_STANDARD.encode(proof));

        let expected_server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
        (
            engine.continue_scram_sha256(start.conversation, client_final.as_bytes()),
            expected_server_signature,
        )
    }

    #[test]
    fn mechanism_lookup_returns_scram_sha256() {
        let engine = root_engine("secret");
        assert_eq!(engine.supported_mechanisms("admin.root"), vec![SCRAM_SHA_256.to_string()]);
        assert!(engine.supported_mechanisms("admin.nobody").is_empty());
    }

    #[test]
    fn read_write_role_allows_read_and_write_only_on_own_db() {
        let user = AuthenticatedUser {
            username: "app".to_string(),
            database: "appdb".to_string(),
            roles: vec!["readWrite".to_string()],
        };
        assert!(role_allows("readWrite", &user, "appdb", AuthAction::Read));
        assert!(role_allows("readWrite", &user, "appdb", AuthAction::Write));
        assert!(!role_allows("readWrite", &user, "other", AuthAction::Read));
        assert!(!role_allows("readWrite", &user, "appdb", AuthAction::DbAdmin));
    }

    #[test]
    fn scram_exchange_succeeds_with_correct_password() {
        let engine = root_engine("secret");
        let (result, expected_signature) = run_exchange(&engine, "secret");
        let result = result.unwrap();
        assert_eq!(result.user.username, "root");
        assert_eq!(result.user.database, "admin");
        let expected = format!("v={}", BASE64_STANDARD.encode(expected_signature));
        assert_eq!(result.payload, expected.into_bytes());
    }

    #[test]
    fn scram_exchange_rejects_wrong_password() {
        let engine = root_engine("secret");
        let (result, _) = run_exchange(&engine, "not-the-password");
        assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
    }

    #[test]
    fn scram_start_rejects_unknown_user() {
        let engine = root_engine("secret");
        let result = engine.start_scram_sha256("admin", b"n,,n=nobody,r=abc");
        assert!(matches!(result, Err(AuthError::AuthenticationFailed)));
    }

    #[test]
    fn scram_names_decode_escaped_characters() {
        assert_eq!(decode_scram_name("a=2Cb=3Dc"), "a,b=c");
    }
}
|