368 lines
12 KiB
Rust
368 lines
12 KiB
Rust
|
|
use bson::{doc, Bson, Document};
|
||
|
|
use rustdb_storage::OpType;
|
||
|
|
use rustdb_txn::{TransactionState, WriteEntry, WriteOp};
|
||
|
|
|
||
|
|
use crate::context::CommandContext;
|
||
|
|
use crate::error::{CommandError, CommandResult};
|
||
|
|
|
||
|
|
pub fn command_starts_transaction(cmd: &Document) -> bool {
|
||
|
|
matches!(cmd.get("startTransaction"), Some(Bson::Boolean(true)))
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn command_uses_transaction(cmd: &Document) -> bool {
|
||
|
|
command_starts_transaction(cmd) || matches!(cmd.get("autocommit"), Some(Bson::Boolean(false)))
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn active_transaction_id(ctx: &CommandContext, cmd: &Document) -> Option<String> {
|
||
|
|
if !command_uses_transaction(cmd) {
|
||
|
|
return None;
|
||
|
|
}
|
||
|
|
|
||
|
|
let session_id = cmd
|
||
|
|
.get("lsid")
|
||
|
|
.and_then(rustdb_txn::SessionEngine::extract_session_id)?;
|
||
|
|
ctx.sessions.get_transaction_id(&session_id)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn prepare_transaction_for_command(
|
||
|
|
ctx: &CommandContext,
|
||
|
|
cmd: &Document,
|
||
|
|
command_name: &str,
|
||
|
|
) -> CommandResult<()> {
|
||
|
|
if matches!(command_name, "commitTransaction" | "abortTransaction") {
|
||
|
|
return Ok(());
|
||
|
|
}
|
||
|
|
|
||
|
|
let starts_transaction = command_starts_transaction(cmd);
|
||
|
|
let uses_transaction = command_uses_transaction(cmd);
|
||
|
|
if !uses_transaction {
|
||
|
|
return Ok(());
|
||
|
|
}
|
||
|
|
|
||
|
|
let session_id = session_id_from_command(cmd)?;
|
||
|
|
require_txn_number(cmd)?;
|
||
|
|
ctx.sessions.get_or_create_session(&session_id);
|
||
|
|
|
||
|
|
if starts_transaction {
|
||
|
|
let txn_id = ctx.transactions.start_transaction(&session_id)?;
|
||
|
|
ctx.sessions.start_transaction(&session_id, &txn_id)?;
|
||
|
|
return Ok(());
|
||
|
|
}
|
||
|
|
|
||
|
|
if ctx.sessions.get_transaction_id(&session_id).is_none() {
|
||
|
|
return Err(CommandError::NoSuchTransaction(format!(
|
||
|
|
"session {session_id} has no active transaction"
|
||
|
|
)));
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn load_transaction_docs(
|
||
|
|
ctx: &CommandContext,
|
||
|
|
txn_id: &str,
|
||
|
|
db: &str,
|
||
|
|
coll: &str,
|
||
|
|
) -> CommandResult<Vec<Document>> {
|
||
|
|
let ns = namespace(db, coll);
|
||
|
|
if !ctx.transactions.has_snapshot(txn_id, &ns) {
|
||
|
|
let docs = match ctx.storage.collection_exists(db, coll).await {
|
||
|
|
Ok(true) => ctx.storage.find_all(db, coll).await?,
|
||
|
|
Ok(false) => Vec::new(),
|
||
|
|
Err(_) => Vec::new(),
|
||
|
|
};
|
||
|
|
ctx.transactions.set_snapshot(txn_id, &ns, docs);
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx.transactions
|
||
|
|
.get_snapshot(txn_id, &ns)
|
||
|
|
.ok_or_else(|| CommandError::NoSuchTransaction(txn_id.to_string()))
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn record_insert(
|
||
|
|
ctx: &CommandContext,
|
||
|
|
txn_id: &str,
|
||
|
|
db: &str,
|
||
|
|
coll: &str,
|
||
|
|
doc: Document,
|
||
|
|
) -> CommandResult<String> {
|
||
|
|
let id = document_id_string(&doc)?;
|
||
|
|
let docs = load_transaction_docs(ctx, txn_id, db, coll).await?;
|
||
|
|
if docs.iter().any(|existing| document_id_string(existing).ok().as_deref() == Some(id.as_str())) {
|
||
|
|
return Err(CommandError::DuplicateKey(format!(
|
||
|
|
"duplicate _id '{}' in transaction",
|
||
|
|
id
|
||
|
|
)));
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx.transactions.record_write(
|
||
|
|
txn_id,
|
||
|
|
&namespace(db, coll),
|
||
|
|
&id,
|
||
|
|
WriteOp::Insert,
|
||
|
|
Some(doc),
|
||
|
|
None,
|
||
|
|
);
|
||
|
|
Ok(id)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn record_update(
|
||
|
|
ctx: &CommandContext,
|
||
|
|
txn_id: &str,
|
||
|
|
db: &str,
|
||
|
|
coll: &str,
|
||
|
|
original: Document,
|
||
|
|
updated: Document,
|
||
|
|
) -> CommandResult<String> {
|
||
|
|
let id = document_id_string(&original)?;
|
||
|
|
ctx.transactions.record_write(
|
||
|
|
txn_id,
|
||
|
|
&namespace(db, coll),
|
||
|
|
&id,
|
||
|
|
WriteOp::Update,
|
||
|
|
Some(updated),
|
||
|
|
Some(original),
|
||
|
|
);
|
||
|
|
Ok(id)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn record_delete(
|
||
|
|
ctx: &CommandContext,
|
||
|
|
txn_id: &str,
|
||
|
|
db: &str,
|
||
|
|
coll: &str,
|
||
|
|
original: Document,
|
||
|
|
) -> CommandResult<String> {
|
||
|
|
let id = document_id_string(&original)?;
|
||
|
|
ctx.transactions.record_write(
|
||
|
|
txn_id,
|
||
|
|
&namespace(db, coll),
|
||
|
|
&id,
|
||
|
|
WriteOp::Delete,
|
||
|
|
None,
|
||
|
|
Some(original),
|
||
|
|
);
|
||
|
|
Ok(id)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn commit_transaction_command(
|
||
|
|
cmd: &Document,
|
||
|
|
ctx: &CommandContext,
|
||
|
|
) -> CommandResult<Document> {
|
||
|
|
let session_id = session_id_from_command(cmd)?;
|
||
|
|
let txn_id = ctx
|
||
|
|
.sessions
|
||
|
|
.get_transaction_id(&session_id)
|
||
|
|
.ok_or_else(|| CommandError::NoSuchTransaction(format!(
|
||
|
|
"session {session_id} has no active transaction"
|
||
|
|
)))?;
|
||
|
|
let state = ctx.transactions.take_transaction(&txn_id)?;
|
||
|
|
|
||
|
|
preflight_transaction(&state, ctx).await?;
|
||
|
|
apply_transaction(state, ctx).await?;
|
||
|
|
ctx.sessions.end_transaction(&session_id);
|
||
|
|
|
||
|
|
Ok(doc! { "ok": 1.0 })
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn abort_transaction_command(cmd: &Document, ctx: &CommandContext) -> CommandResult<Document> {
|
||
|
|
let session_id = session_id_from_command(cmd)?;
|
||
|
|
let txn_id = ctx
|
||
|
|
.sessions
|
||
|
|
.get_transaction_id(&session_id)
|
||
|
|
.ok_or_else(|| CommandError::NoSuchTransaction(format!(
|
||
|
|
"session {session_id} has no active transaction"
|
||
|
|
)))?;
|
||
|
|
ctx.transactions.abort_transaction(&txn_id)?;
|
||
|
|
ctx.sessions.end_transaction(&session_id);
|
||
|
|
Ok(doc! { "ok": 1.0 })
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn document_id_string(doc: &Document) -> CommandResult<String> {
|
||
|
|
match doc.get("_id") {
|
||
|
|
Some(Bson::ObjectId(oid)) => Ok(oid.to_hex()),
|
||
|
|
Some(Bson::String(s)) => Ok(s.clone()),
|
||
|
|
Some(other) => Ok(format!("{}", other)),
|
||
|
|
None => Err(CommandError::InvalidArgument("document missing _id field".into())),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn session_id_from_command(cmd: &Document) -> CommandResult<String> {
|
||
|
|
cmd.get("lsid")
|
||
|
|
.and_then(rustdb_txn::SessionEngine::extract_session_id)
|
||
|
|
.ok_or_else(|| CommandError::InvalidArgument("transaction command requires lsid".into()))
|
||
|
|
}
|
||
|
|
|
||
|
|
fn require_txn_number(cmd: &Document) -> CommandResult<()> {
|
||
|
|
match cmd.get("txnNumber") {
|
||
|
|
Some(Bson::Int64(_)) | Some(Bson::Int32(_)) => Ok(()),
|
||
|
|
_ => Err(CommandError::InvalidArgument(
|
||
|
|
"transaction command requires txnNumber".into(),
|
||
|
|
)),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Builds the `"<db>.<coll>"` namespace string that keys per-collection state.
fn namespace(db: &str, coll: &str) -> String {
    [db, coll].join(".")
}
|
||
|
|
|
||
|
|
async fn preflight_transaction(state: &TransactionState, ctx: &CommandContext) -> CommandResult<()> {
|
||
|
|
for (ns, writes) in &state.write_set {
|
||
|
|
let (db, coll) = split_namespace(ns)?;
|
||
|
|
drop(ctx.get_or_init_index_engine(db, coll).await);
|
||
|
|
|
||
|
|
for (doc_id, entry) in writes {
|
||
|
|
let current = current_doc(ctx, db, coll, doc_id).await?;
|
||
|
|
match entry.op {
|
||
|
|
WriteOp::Insert => {
|
||
|
|
if current.is_some() {
|
||
|
|
return Err(CommandError::DuplicateKey(format!(
|
||
|
|
"duplicate _id '{}' on transaction commit",
|
||
|
|
doc_id
|
||
|
|
)));
|
||
|
|
}
|
||
|
|
if let Some(ref doc) = entry.doc {
|
||
|
|
if let Some(engine) = ctx.indexes.get(ns) {
|
||
|
|
engine.check_unique_constraints(doc)?;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
WriteOp::Update => {
|
||
|
|
assert_unchanged(doc_id, current.as_ref(), entry.original_doc.as_ref())?;
|
||
|
|
if let (Some(current_doc), Some(updated_doc)) = (current.as_ref(), entry.doc.as_ref()) {
|
||
|
|
if let Some(engine) = ctx.indexes.get(ns) {
|
||
|
|
engine.check_unique_constraints_for_update(current_doc, updated_doc)?;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
WriteOp::Delete => {
|
||
|
|
assert_unchanged(doc_id, current.as_ref(), entry.original_doc.as_ref())?;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn apply_transaction(state: TransactionState, ctx: &CommandContext) -> CommandResult<()> {
|
||
|
|
let mut namespaces: Vec<_> = state.write_set.into_iter().collect();
|
||
|
|
namespaces.sort_by(|a, b| a.0.cmp(&b.0));
|
||
|
|
|
||
|
|
for (ns, writes) in namespaces {
|
||
|
|
let (db, coll) = split_namespace(&ns)?;
|
||
|
|
ensure_collection_exists(db, coll, ctx).await?;
|
||
|
|
drop(ctx.get_or_init_index_engine(db, coll).await);
|
||
|
|
|
||
|
|
let mut writes: Vec<(String, WriteEntry)> = writes.into_iter().collect();
|
||
|
|
writes.sort_by(|a, b| a.0.cmp(&b.0));
|
||
|
|
|
||
|
|
for (doc_id, entry) in writes {
|
||
|
|
match entry.op {
|
||
|
|
WriteOp::Insert => {
|
||
|
|
let Some(doc) = entry.doc else { continue; };
|
||
|
|
let inserted_id = ctx.storage.insert_one(db, coll, doc.clone()).await?;
|
||
|
|
ctx.oplog.append(OpType::Insert, db, coll, &inserted_id, Some(doc.clone()), None);
|
||
|
|
if let Some(mut engine) = ctx.indexes.get_mut(&ns) {
|
||
|
|
engine.on_insert(&doc)?;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
WriteOp::Update => {
|
||
|
|
let Some(doc) = entry.doc else { continue; };
|
||
|
|
ctx.storage.update_by_id(db, coll, &doc_id, doc.clone()).await?;
|
||
|
|
ctx.oplog.append(
|
||
|
|
OpType::Update,
|
||
|
|
db,
|
||
|
|
coll,
|
||
|
|
&doc_id,
|
||
|
|
Some(doc.clone()),
|
||
|
|
entry.original_doc.clone(),
|
||
|
|
);
|
||
|
|
if let (Some(mut engine), Some(ref original)) =
|
||
|
|
(ctx.indexes.get_mut(&ns), entry.original_doc.as_ref())
|
||
|
|
{
|
||
|
|
engine.on_update(original, &doc)?;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
WriteOp::Delete => {
|
||
|
|
ctx.storage.delete_by_id(db, coll, &doc_id).await?;
|
||
|
|
ctx.oplog.append(
|
||
|
|
OpType::Delete,
|
||
|
|
db,
|
||
|
|
coll,
|
||
|
|
&doc_id,
|
||
|
|
None,
|
||
|
|
entry.original_doc.clone(),
|
||
|
|
);
|
||
|
|
if let (Some(mut engine), Some(ref original)) =
|
||
|
|
(ctx.indexes.get_mut(&ns), entry.original_doc.as_ref())
|
||
|
|
{
|
||
|
|
engine.on_delete(original);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn current_doc(
|
||
|
|
ctx: &CommandContext,
|
||
|
|
db: &str,
|
||
|
|
coll: &str,
|
||
|
|
doc_id: &str,
|
||
|
|
) -> CommandResult<Option<Document>> {
|
||
|
|
match ctx.storage.collection_exists(db, coll).await {
|
||
|
|
Ok(true) => Ok(ctx.storage.find_by_id(db, coll, doc_id).await?),
|
||
|
|
Ok(false) => Ok(None),
|
||
|
|
Err(_) => Ok(None),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn assert_unchanged(
|
||
|
|
doc_id: &str,
|
||
|
|
current: Option<&Document>,
|
||
|
|
original: Option<&Document>,
|
||
|
|
) -> CommandResult<()> {
|
||
|
|
if current == original {
|
||
|
|
return Ok(());
|
||
|
|
}
|
||
|
|
|
||
|
|
Err(CommandError::WriteConflict(format!(
|
||
|
|
"document '{}' changed during transaction",
|
||
|
|
doc_id
|
||
|
|
)))
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn ensure_collection_exists(
|
||
|
|
db: &str,
|
||
|
|
coll: &str,
|
||
|
|
ctx: &CommandContext,
|
||
|
|
) -> CommandResult<()> {
|
||
|
|
if let Err(e) = ctx.storage.create_database(db).await {
|
||
|
|
let msg = e.to_string();
|
||
|
|
if !msg.contains("AlreadyExists") && !msg.contains("already exists") {
|
||
|
|
return Err(CommandError::StorageError(msg));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
match ctx.storage.collection_exists(db, coll).await {
|
||
|
|
Ok(true) => Ok(()),
|
||
|
|
Ok(false) | Err(_) => {
|
||
|
|
if let Err(e) = ctx.storage.create_collection(db, coll).await {
|
||
|
|
let msg = e.to_string();
|
||
|
|
if !msg.contains("AlreadyExists") && !msg.contains("already exists") {
|
||
|
|
return Err(CommandError::StorageError(msg));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn split_namespace(ns: &str) -> CommandResult<(&str, &str)> {
|
||
|
|
ns.split_once('.')
|
||
|
|
.ok_or_else(|| CommandError::InvalidArgument(format!("invalid namespace '{ns}'")))
|
||
|
|
}
|