use bson::{doc, Bson, Document}; use rustdb_storage::OpType; use rustdb_txn::{TransactionState, WriteEntry, WriteOp}; use crate::context::CommandContext; use crate::error::{CommandError, CommandResult}; pub fn command_starts_transaction(cmd: &Document) -> bool { matches!(cmd.get("startTransaction"), Some(Bson::Boolean(true))) } pub fn command_uses_transaction(cmd: &Document) -> bool { command_starts_transaction(cmd) || matches!(cmd.get("autocommit"), Some(Bson::Boolean(false))) } pub fn active_transaction_id(ctx: &CommandContext, cmd: &Document) -> Option { if !command_uses_transaction(cmd) { return None; } let session_id = cmd .get("lsid") .and_then(rustdb_txn::SessionEngine::extract_session_id)?; ctx.sessions.get_transaction_id(&session_id) } pub fn prepare_transaction_for_command( ctx: &CommandContext, cmd: &Document, command_name: &str, ) -> CommandResult<()> { if matches!(command_name, "commitTransaction" | "abortTransaction") { return Ok(()); } let starts_transaction = command_starts_transaction(cmd); let uses_transaction = command_uses_transaction(cmd); if !uses_transaction { return Ok(()); } let session_id = session_id_from_command(cmd)?; require_txn_number(cmd)?; ctx.sessions.get_or_create_session(&session_id); if starts_transaction { let txn_id = ctx.transactions.start_transaction(&session_id)?; ctx.sessions.start_transaction(&session_id, &txn_id)?; return Ok(()); } if ctx.sessions.get_transaction_id(&session_id).is_none() { return Err(CommandError::NoSuchTransaction(format!( "session {session_id} has no active transaction" ))); } Ok(()) } pub async fn load_transaction_docs( ctx: &CommandContext, txn_id: &str, db: &str, coll: &str, ) -> CommandResult> { let ns = namespace(db, coll); if !ctx.transactions.has_snapshot(txn_id, &ns) { let docs = match ctx.storage.collection_exists(db, coll).await { Ok(true) => ctx.storage.find_all(db, coll).await?, Ok(false) => Vec::new(), Err(_) => Vec::new(), }; ctx.transactions.set_snapshot(txn_id, &ns, docs); } ctx.transactions .get_snapshot(txn_id, &ns) .ok_or_else(|| CommandError::NoSuchTransaction(txn_id.to_string())) } pub async fn record_insert( ctx: &CommandContext, txn_id: &str, db: &str, coll: &str, doc: Document, ) -> CommandResult { let id = document_id_string(&doc)?; let docs = load_transaction_docs(ctx, txn_id, db, coll).await?; if docs.iter().any(|existing| document_id_string(existing).ok().as_deref() == Some(id.as_str())) { return Err(CommandError::DuplicateKey(format!( "duplicate _id '{}' in transaction", id ))); } ctx.transactions.record_write( txn_id, &namespace(db, coll), &id, WriteOp::Insert, Some(doc), None, ); Ok(id) } pub async fn record_update( ctx: &CommandContext, txn_id: &str, db: &str, coll: &str, original: Document, updated: Document, ) -> CommandResult { let id = document_id_string(&original)?; ctx.transactions.record_write( txn_id, &namespace(db, coll), &id, WriteOp::Update, Some(updated), Some(original), ); Ok(id) } pub async fn record_delete( ctx: &CommandContext, txn_id: &str, db: &str, coll: &str, original: Document, ) -> CommandResult { let id = document_id_string(&original)?; ctx.transactions.record_write( txn_id, &namespace(db, coll), &id, WriteOp::Delete, None, Some(original), ); Ok(id) } pub async fn commit_transaction_command( cmd: &Document, ctx: &CommandContext, ) -> CommandResult { let session_id = session_id_from_command(cmd)?; let txn_id = ctx .sessions .get_transaction_id(&session_id) .ok_or_else(|| CommandError::NoSuchTransaction(format!( "session {session_id} has no active transaction" )))?; let state = ctx.transactions.take_transaction(&txn_id)?; preflight_transaction(&state, ctx).await?; apply_transaction(state, ctx).await?; ctx.sessions.end_transaction(&session_id); Ok(doc! { "ok": 1.0 }) } pub fn abort_transaction_command(cmd: &Document, ctx: &CommandContext) -> CommandResult { let session_id = session_id_from_command(cmd)?; let txn_id = ctx .sessions .get_transaction_id(&session_id) .ok_or_else(|| CommandError::NoSuchTransaction(format!( "session {session_id} has no active transaction" )))?; ctx.transactions.abort_transaction(&txn_id)?; ctx.sessions.end_transaction(&session_id); Ok(doc! { "ok": 1.0 }) } pub fn document_id_string(doc: &Document) -> CommandResult { match doc.get("_id") { Some(Bson::ObjectId(oid)) => Ok(oid.to_hex()), Some(Bson::String(s)) => Ok(s.clone()), Some(other) => Ok(format!("{}", other)), None => Err(CommandError::InvalidArgument("document missing _id field".into())), } } fn session_id_from_command(cmd: &Document) -> CommandResult { cmd.get("lsid") .and_then(rustdb_txn::SessionEngine::extract_session_id) .ok_or_else(|| CommandError::InvalidArgument("transaction command requires lsid".into())) } fn require_txn_number(cmd: &Document) -> CommandResult<()> { match cmd.get("txnNumber") { Some(Bson::Int64(_)) | Some(Bson::Int32(_)) => Ok(()), _ => Err(CommandError::InvalidArgument( "transaction command requires txnNumber".into(), )), } } fn namespace(db: &str, coll: &str) -> String { format!("{db}.{coll}") } async fn preflight_transaction(state: &TransactionState, ctx: &CommandContext) -> CommandResult<()> { for (ns, writes) in &state.write_set { let (db, coll) = split_namespace(ns)?; drop(ctx.get_or_init_index_engine(db, coll).await); for (doc_id, entry) in writes { let current = current_doc(ctx, db, coll, doc_id).await?; match entry.op { WriteOp::Insert => { if current.is_some() { return Err(CommandError::DuplicateKey(format!( "duplicate _id '{}' on transaction commit", doc_id ))); } if let Some(ref doc) = entry.doc { if let Some(engine) = ctx.indexes.get(ns) { engine.check_unique_constraints(doc)?; } } } WriteOp::Update => { assert_unchanged(doc_id, current.as_ref(), entry.original_doc.as_ref())?; if let (Some(current_doc), Some(updated_doc)) = (current.as_ref(), entry.doc.as_ref()) { if let Some(engine) = ctx.indexes.get(ns) { engine.check_unique_constraints_for_update(current_doc, updated_doc)?; } } } WriteOp::Delete => { assert_unchanged(doc_id, current.as_ref(), entry.original_doc.as_ref())?; } } } } Ok(()) } async fn apply_transaction(state: TransactionState, ctx: &CommandContext) -> CommandResult<()> { let mut namespaces: Vec<_> = state.write_set.into_iter().collect(); namespaces.sort_by(|a, b| a.0.cmp(&b.0)); for (ns, writes) in namespaces { let (db, coll) = split_namespace(&ns)?; ensure_collection_exists(db, coll, ctx).await?; drop(ctx.get_or_init_index_engine(db, coll).await); let mut writes: Vec<(String, WriteEntry)> = writes.into_iter().collect(); writes.sort_by(|a, b| a.0.cmp(&b.0)); for (doc_id, entry) in writes { match entry.op { WriteOp::Insert => { let Some(doc) = entry.doc else { continue; }; let inserted_id = ctx.storage.insert_one(db, coll, doc.clone()).await?; ctx.oplog.append(OpType::Insert, db, coll, &inserted_id, Some(doc.clone()), None); if let Some(mut engine) = ctx.indexes.get_mut(&ns) { engine.on_insert(&doc)?; } } WriteOp::Update => { let Some(doc) = entry.doc else { continue; }; ctx.storage.update_by_id(db, coll, &doc_id, doc.clone()).await?; ctx.oplog.append( OpType::Update, db, coll, &doc_id, Some(doc.clone()), entry.original_doc.clone(), ); if let (Some(mut engine), Some(ref original)) = (ctx.indexes.get_mut(&ns), entry.original_doc.as_ref()) { engine.on_update(original, &doc)?; } } WriteOp::Delete => { ctx.storage.delete_by_id(db, coll, &doc_id).await?; ctx.oplog.append( OpType::Delete, db, coll, &doc_id, None, entry.original_doc.clone(), ); if let (Some(mut engine), Some(ref original)) = (ctx.indexes.get_mut(&ns), entry.original_doc.as_ref()) { engine.on_delete(original); } } } } } Ok(()) } async fn current_doc( ctx: &CommandContext, db: &str, coll: &str, doc_id: &str, ) -> CommandResult> { match ctx.storage.collection_exists(db, coll).await { Ok(true) => Ok(ctx.storage.find_by_id(db, coll, doc_id).await?), Ok(false) => Ok(None), Err(_) => Ok(None), } } fn assert_unchanged( doc_id: &str, current: Option<&Document>, original: Option<&Document>, ) -> CommandResult<()> { if current == original { return Ok(()); } Err(CommandError::WriteConflict(format!( "document '{}' changed during transaction", doc_id ))) } async fn ensure_collection_exists( db: &str, coll: &str, ctx: &CommandContext, ) -> CommandResult<()> { if let Err(e) = ctx.storage.create_database(db).await { let msg = e.to_string(); if !msg.contains("AlreadyExists") && !msg.contains("already exists") { return Err(CommandError::StorageError(msg)); } } match ctx.storage.collection_exists(db, coll).await { Ok(true) => Ok(()), Ok(false) | Err(_) => { if let Err(e) = ctx.storage.create_collection(db, coll).await { let msg = e.to_string(); if !msg.contains("AlreadyExists") && !msg.contains("already exists") { return Err(CommandError::StorageError(msg)); } } Ok(()) } } } fn split_namespace(ns: &str) -> CommandResult<(&str, &str)> { ns.split_once('.') .ok_or_else(|| CommandError::InvalidArgument(format!("invalid namespace '{ns}'"))) }