use bson::{doc, Binary, Bson, Document}; use crate::context::{CommandContext, ConnectionState}; use crate::error::{CommandError, CommandResult}; pub async fn handle_sasl_start( cmd: &Document, db: &str, ctx: &CommandContext, connection: &mut ConnectionState, ) -> CommandResult { let mechanism = cmd .get_str("mechanism") .map_err(|_| CommandError::InvalidArgument("missing SASL mechanism".into()))?; if mechanism != "SCRAM-SHA-256" { return Err(CommandError::InvalidArgument(format!( "unsupported SASL mechanism: {mechanism}" ))); } let payload = payload_bytes(cmd)?; let result = ctx .auth .start_scram_sha256(db, &payload) .map_err(map_auth_error)?; let conversation_id = connection.next_conversation_id(); connection .sasl_conversations .insert(conversation_id, result.conversation); Ok(doc! { "conversationId": conversation_id, "done": false, "payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload }, "ok": 1.0, }) } pub async fn handle_sasl_continue( cmd: &Document, ctx: &CommandContext, connection: &mut ConnectionState, ) -> CommandResult { let conversation_id = cmd .get_i32("conversationId") .map_err(|_| CommandError::InvalidArgument("missing SASL conversationId".into()))?; let payload = payload_bytes(cmd)?; let conversation = connection .sasl_conversations .remove(&conversation_id) .ok_or_else(|| CommandError::InvalidArgument("unknown SASL conversation".into()))?; let result = ctx .auth .continue_scram_sha256(conversation, &payload) .map_err(map_auth_error)?; connection.authenticate(result.user); Ok(doc! { "conversationId": conversation_id, "done": true, "payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload }, "ok": 1.0, }) } fn payload_bytes(cmd: &Document) -> CommandResult> { match cmd.get("payload") { Some(Bson::Binary(binary)) => Ok(binary.bytes.clone()), Some(Bson::String(value)) => Ok(value.as_bytes().to_vec()), _ => Err(CommandError::InvalidArgument("missing SASL payload".into())), } } fn map_auth_error(error: rustdb_auth::AuthError) -> CommandError { match error { rustdb_auth::AuthError::InvalidPayload(message) => CommandError::InvalidArgument(message), rustdb_auth::AuthError::UnsupportedMechanism(message) => CommandError::InvalidArgument(message), rustdb_auth::AuthError::Disabled => CommandError::Unauthorized("authentication is disabled".into()), rustdb_auth::AuthError::UnknownConversation => { CommandError::InvalidArgument("unknown SASL conversation".into()) } rustdb_auth::AuthError::AuthenticationFailed => CommandError::AuthenticationFailed, rustdb_auth::AuthError::UserAlreadyExists(message) => CommandError::DuplicateKey(message), rustdb_auth::AuthError::UserNotFound(message) => CommandError::NamespaceNotFound(message), rustdb_auth::AuthError::Persistence(message) => CommandError::InternalError(message), } }