88 lines
3.2 KiB
Rust
88 lines
3.2 KiB
Rust
|
|
use bson::{doc, Binary, Bson, Document};
|
||
|
|
|
||
|
|
use crate::context::{CommandContext, ConnectionState};
|
||
|
|
use crate::error::{CommandError, CommandResult};
|
||
|
|
|
||
|
|
pub async fn handle_sasl_start(
|
||
|
|
cmd: &Document,
|
||
|
|
db: &str,
|
||
|
|
ctx: &CommandContext,
|
||
|
|
connection: &mut ConnectionState,
|
||
|
|
) -> CommandResult<Document> {
|
||
|
|
let mechanism = cmd
|
||
|
|
.get_str("mechanism")
|
||
|
|
.map_err(|_| CommandError::InvalidArgument("missing SASL mechanism".into()))?;
|
||
|
|
if mechanism != "SCRAM-SHA-256" {
|
||
|
|
return Err(CommandError::InvalidArgument(format!(
|
||
|
|
"unsupported SASL mechanism: {mechanism}"
|
||
|
|
)));
|
||
|
|
}
|
||
|
|
|
||
|
|
let payload = payload_bytes(cmd)?;
|
||
|
|
let result = ctx
|
||
|
|
.auth
|
||
|
|
.start_scram_sha256(db, &payload)
|
||
|
|
.map_err(map_auth_error)?;
|
||
|
|
let conversation_id = connection.next_conversation_id();
|
||
|
|
connection
|
||
|
|
.sasl_conversations
|
||
|
|
.insert(conversation_id, result.conversation);
|
||
|
|
|
||
|
|
Ok(doc! {
|
||
|
|
"conversationId": conversation_id,
|
||
|
|
"done": false,
|
||
|
|
"payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload },
|
||
|
|
"ok": 1.0,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn handle_sasl_continue(
|
||
|
|
cmd: &Document,
|
||
|
|
ctx: &CommandContext,
|
||
|
|
connection: &mut ConnectionState,
|
||
|
|
) -> CommandResult<Document> {
|
||
|
|
let conversation_id = cmd
|
||
|
|
.get_i32("conversationId")
|
||
|
|
.map_err(|_| CommandError::InvalidArgument("missing SASL conversationId".into()))?;
|
||
|
|
let payload = payload_bytes(cmd)?;
|
||
|
|
let conversation = connection
|
||
|
|
.sasl_conversations
|
||
|
|
.remove(&conversation_id)
|
||
|
|
.ok_or_else(|| CommandError::InvalidArgument("unknown SASL conversation".into()))?;
|
||
|
|
let result = ctx
|
||
|
|
.auth
|
||
|
|
.continue_scram_sha256(conversation, &payload)
|
||
|
|
.map_err(map_auth_error)?;
|
||
|
|
connection.authenticate(result.user);
|
||
|
|
|
||
|
|
Ok(doc! {
|
||
|
|
"conversationId": conversation_id,
|
||
|
|
"done": true,
|
||
|
|
"payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload },
|
||
|
|
"ok": 1.0,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
fn payload_bytes(cmd: &Document) -> CommandResult<Vec<u8>> {
|
||
|
|
match cmd.get("payload") {
|
||
|
|
Some(Bson::Binary(binary)) => Ok(binary.bytes.clone()),
|
||
|
|
Some(Bson::String(value)) => Ok(value.as_bytes().to_vec()),
|
||
|
|
_ => Err(CommandError::InvalidArgument("missing SASL payload".into())),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn map_auth_error(error: rustdb_auth::AuthError) -> CommandError {
|
||
|
|
match error {
|
||
|
|
rustdb_auth::AuthError::InvalidPayload(message) => CommandError::InvalidArgument(message),
|
||
|
|
rustdb_auth::AuthError::UnsupportedMechanism(message) => CommandError::InvalidArgument(message),
|
||
|
|
rustdb_auth::AuthError::Disabled => CommandError::Unauthorized("authentication is disabled".into()),
|
||
|
|
rustdb_auth::AuthError::UnknownConversation => {
|
||
|
|
CommandError::InvalidArgument("unknown SASL conversation".into())
|
||
|
|
}
|
||
|
|
rustdb_auth::AuthError::AuthenticationFailed => CommandError::AuthenticationFailed,
|
||
|
|
rustdb_auth::AuthError::UserAlreadyExists(message) => CommandError::DuplicateKey(message),
|
||
|
|
rustdb_auth::AuthError::UserNotFound(message) => CommandError::NamespaceNotFound(message),
|
||
|
|
rustdb_auth::AuthError::Persistence(message) => CommandError::InternalError(message),
|
||
|
|
}
|
||
|
|
}
|