BREAKING CHANGE(core): replace the TypeScript database engine with a Rust-backed embedded server and bridge

This commit is contained in:
2026-03-26 19:48:27 +00:00
parent 8ec2046908
commit e23a951dbe
106 changed files with 11567 additions and 10678 deletions

2
rust/.cargo/config.toml Normal file
View File

@@ -0,0 +1,2 @@
# Cross-compilation support: when targeting aarch64 Linux (GNU libc),
# link with the cross toolchain's gcc instead of the host linker.
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"

1423
rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

76
rust/Cargo.toml Normal file
View File

@@ -0,0 +1,76 @@
# Workspace manifest for the RustDb embedded server.
# All member crates share version/edition/license metadata and a single
# set of pinned dependency versions via [workspace.dependencies].
[workspace]
resolver = "2"
members = [
    "crates/rustdb",
    "crates/rustdb-config",
    "crates/rustdb-wire",
    "crates/rustdb-query",
    "crates/rustdb-storage",
    "crates/rustdb-index",
    "crates/rustdb-txn",
    "crates/rustdb-commands",
]

# Metadata inherited by member crates via `key.workspace = true`.
[workspace.package]
version = "0.1.0"
edition = "2021"
license = "MIT"
authors = ["Lossless GmbH <hello@lossless.com>"]

[workspace.dependencies]
# Async runtime
tokio = { version = "1", features = ["full"] }
# Serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# BSON serialization (bson crate)
bson = "2"
# Binary buffer manipulation
bytes = "1"
# CLI
clap = { version = "4", features = ["derive"] }
# Structured logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Error handling
thiserror = "2"
anyhow = "1"
# Lock-free atomics
arc-swap = "1"
# Concurrent maps
dashmap = "6"
# Cancellation / utility
tokio-util = { version = "0.7", features = ["codec"] }
# mimalloc allocator
mimalloc = "0.1"
# CRC32 checksums
crc32fast = "1"
# Regex for $regex operator
regex = "1"
# UUID for sessions
uuid = { version = "1", features = ["v4", "serde"] }
# Async traits
async-trait = "0.1"
# Internal crates (path dependencies within this workspace)
rustdb-config = { path = "crates/rustdb-config" }
rustdb-wire = { path = "crates/rustdb-wire" }
rustdb-query = { path = "crates/rustdb-query" }
rustdb-storage = { path = "crates/rustdb-storage" }
rustdb-index = { path = "crates/rustdb-index" }
rustdb-txn = { path = "crates/rustdb-txn" }
rustdb-commands = { path = "crates/rustdb-commands" }

View File

@@ -0,0 +1,24 @@
# Manifest for the command layer: routes and executes MongoDB-compatible
# commands on top of the storage / index / transaction engines.
[package]
name = "rustdb-commands"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "MongoDB-compatible command routing and handlers for RustDb"

[dependencies]
# Versions are pinned centrally in the workspace manifest.
bson = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
dashmap = { workspace = true }
tokio = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
async-trait = { workspace = true }
rustdb-config = { workspace = true }
rustdb-wire = { workspace = true }
rustdb-query = { workspace = true }
rustdb-storage = { workspace = true }
rustdb-index = { workspace = true }
rustdb-txn = { workspace = true }

View File

@@ -0,0 +1,35 @@
use std::sync::Arc;
use bson::Document;
use dashmap::DashMap;
use rustdb_index::IndexEngine;
use rustdb_storage::StorageAdapter;
use rustdb_txn::{SessionEngine, TransactionEngine};
/// Shared command execution context, passed to all handlers.
///
/// All fields are `Arc`-wrapped (or inside `Arc<DashMap>`) so the context
/// can be shared cheaply across concurrently executing command handlers.
pub struct CommandContext {
    /// The storage backend.
    pub storage: Arc<dyn StorageAdapter>,
    /// Index engines per namespace: "db.collection" -> IndexEngine.
    pub indexes: Arc<DashMap<String, IndexEngine>>,
    /// Transaction engine for multi-document transactions.
    pub transactions: Arc<TransactionEngine>,
    /// Session engine for logical sessions.
    pub sessions: Arc<SessionEngine>,
    /// Active cursors for getMore / killCursors, keyed by the cursor id
    /// that was handed to the client in a `find`/`aggregate` response.
    pub cursors: Arc<DashMap<i64, CursorState>>,
    /// Server start time (for uptime reporting).
    pub start_time: std::time::Instant,
}

/// State of an open cursor from a find or aggregate command.
///
/// Holds the already-materialized result set that did not fit into the
/// first batch; `getMore` drains it starting from `position`.
pub struct CursorState {
    /// Documents remaining to be returned.
    pub documents: Vec<Document>,
    /// Current read position within `documents`.
    pub position: usize,
    /// Database the cursor belongs to.
    pub database: String,
    /// Collection the cursor belongs to.
    /// NOTE(review): listCollections cursors store an empty string here —
    /// confirm getMore does not rely on this field for those cursors.
    pub collection: String,
}

View File

@@ -0,0 +1,76 @@
use thiserror::Error;
/// Errors that can occur during command processing.
///
/// Each variant maps to a MongoDB wire-protocol error code / codeName in
/// `to_error_doc`, so handlers can surface failures in the shape drivers
/// expect. The `String` payloads become the `errmsg` text.
#[derive(Debug, Error)]
pub enum CommandError {
    /// Command name is not (yet) supported by this server.
    #[error("command not implemented: {0}")]
    NotImplemented(String),
    /// Malformed or missing command argument.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),
    /// Failure reported by the storage backend.
    #[error("storage error: {0}")]
    StorageError(String),
    /// Failure reported by the index engine.
    #[error("index error: {0}")]
    IndexError(String),
    /// Failure reported by the transaction engine.
    #[error("transaction error: {0}")]
    TransactionError(String),
    /// The referenced database/collection does not exist.
    #[error("namespace not found: {0}")]
    NamespaceNotFound(String),
    /// The target database/collection already exists.
    #[error("namespace already exists: {0}")]
    NamespaceExists(String),
    /// Unique-constraint violation on insert/update.
    #[error("duplicate key: {0}")]
    DuplicateKey(String),
    /// Unexpected internal failure (catch-all).
    #[error("internal error: {0}")]
    InternalError(String),
}
impl CommandError {
/// Convert a CommandError to a BSON error response document.
pub fn to_error_doc(&self) -> bson::Document {
let (code, code_name) = match self {
CommandError::NotImplemented(_) => (59, "CommandNotFound"),
CommandError::InvalidArgument(_) => (14, "TypeMismatch"),
CommandError::StorageError(_) => (1, "InternalError"),
CommandError::IndexError(_) => (27, "IndexNotFound"),
CommandError::TransactionError(_) => (112, "WriteConflict"),
CommandError::NamespaceNotFound(_) => (26, "NamespaceNotFound"),
CommandError::NamespaceExists(_) => (48, "NamespaceExists"),
CommandError::DuplicateKey(_) => (11000, "DuplicateKey"),
CommandError::InternalError(_) => (1, "InternalError"),
};
bson::doc! {
"ok": 0,
"errmsg": self.to_string(),
"code": code,
"codeName": code_name,
}
}
}
// Lossy conversions from the engine-specific error types: only the Display
// text survives, so `?` can be used on storage/txn/index calls inside
// command handlers.

impl From<rustdb_storage::StorageError> for CommandError {
    fn from(e: rustdb_storage::StorageError) -> Self {
        CommandError::StorageError(e.to_string())
    }
}

impl From<rustdb_txn::TransactionError> for CommandError {
    fn from(e: rustdb_txn::TransactionError) -> Self {
        CommandError::TransactionError(e.to_string())
    }
}

impl From<rustdb_index::IndexError> for CommandError {
    fn from(e: rustdb_index::IndexError) -> Self {
        CommandError::IndexError(e.to_string())
    }
}

/// Convenience alias used by all command handlers.
pub type CommandResult<T> = Result<T, CommandError>;

View File

@@ -0,0 +1,653 @@
use bson::{doc, Bson, Document};
use rustdb_index::IndexEngine;
use tracing::debug;
use crate::context::{CommandContext, CursorState};
use crate::error::{CommandError, CommandResult};
/// Handle various admin / diagnostic / session / auth commands.
///
/// `command_name` is the already-extracted first key of `cmd`. Pure
/// diagnostic commands answer with fixed stub documents advertising a
/// MongoDB-7.0-compatible standalone server; storage-touching commands
/// (listDatabases, create, drop, ...) dispatch to dedicated helpers.
/// Unknown commands are acknowledged with `ok: 1` plus a "stub response"
/// note rather than rejected, to keep permissive drivers working.
pub async fn handle(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
    command_name: &str,
) -> CommandResult<Document> {
    match command_name {
        "ping" => Ok(doc! { "ok": 1.0 }),
        // Advertise a 7.0.0 build so drivers negotiate modern wire features.
        "buildInfo" | "buildinfo" => Ok(doc! {
            "version": "7.0.0",
            "gitVersion": "unknown",
            "modules": [],
            "sysInfo": "rustdb",
            "versionArray": [7_i32, 0_i32, 0_i32, 0_i32],
            "ok": 1.0,
        }),
        // Uptime is derived from the Instant captured at server start.
        "serverStatus" => Ok(doc! {
            "host": "localhost",
            "version": "7.0.0",
            "process": "rustdb",
            "uptime": ctx.start_time.elapsed().as_secs() as i64,
            "ok": 1.0,
        }),
        "hostInfo" => Ok(doc! {
            "system": {
                "hostname": "localhost",
            },
            "ok": 1.0,
        }),
        "whatsmyuri" => Ok(doc! {
            "you": "127.0.0.1:0",
            "ok": 1.0,
        }),
        // No server-side log buffer exists; report an empty log.
        "getLog" => {
            let _log_type = cmd.get_str("getLog").unwrap_or("global");
            Ok(doc! {
                "totalLinesWritten": 0_i32,
                "log": [],
                "ok": 1.0,
            })
        }
        "replSetGetStatus" => {
            // Not a replica set.
            Ok(doc! {
                "ok": 0.0,
                "errmsg": "not running with --replSet",
                "code": 76_i32,
                "codeName": "NoReplicationEnabled",
            })
        }
        "getCmdLineOpts" => Ok(doc! {
            "argv": ["rustdb"],
            "parsed": {},
            "ok": 1.0,
        }),
        "getParameter" => Ok(doc! {
            "ok": 1.0,
        }),
        "getFreeMonitoringStatus" | "setFreeMonitoring" => Ok(doc! {
            "state": "disabled",
            "ok": 1.0,
        }),
        "getShardMap" | "shardingState" => Ok(doc! {
            "enabled": false,
            "ok": 1.0,
        }),
        "atlasVersion" => Ok(doc! {
            "ok": 0.0,
            "errmsg": "not supported",
            "code": 59_i32,
            "codeName": "CommandNotFound",
        }),
        // Auth is not enforced; report an unauthenticated connection.
        "connectionStatus" => Ok(doc! {
            "authInfo": {
                "authenticatedUsers": [],
                "authenticatedUserRoles": [],
            },
            "ok": 1.0,
        }),
        // Storage-touching commands dispatch to the helpers below.
        "listDatabases" => handle_list_databases(cmd, ctx).await,
        "listCollections" => handle_list_collections(cmd, db, ctx).await,
        "create" => handle_create(cmd, db, ctx).await,
        "drop" => handle_drop(cmd, db, ctx).await,
        "dropDatabase" => handle_drop_database(db, ctx).await,
        "renameCollection" => handle_rename_collection(cmd, ctx).await,
        "collStats" | "validate" => handle_coll_stats(cmd, db, ctx, command_name).await,
        "dbStats" => handle_db_stats(db, ctx).await,
        "explain" => Ok(doc! {
            "queryPlanner": {},
            "ok": 1.0,
        }),
        // Create a fresh logical session and register it with the engine.
        "startSession" => {
            let session_id = uuid::Uuid::new_v4().to_string();
            ctx.sessions.get_or_create_session(&session_id);
            Ok(doc! {
                "id": { "id": &session_id },
                "timeoutMinutes": 30_i32,
                "ok": 1.0,
            })
        }
        "endSessions" | "killSessions" => {
            // Attempt to end listed sessions.
            if let Ok(sessions) = cmd
                .get_array("endSessions")
                .or_else(|_| cmd.get_array("killSessions"))
            {
                for s in sessions {
                    if let Some(sid) = rustdb_txn::SessionEngine::extract_session_id(s) {
                        ctx.sessions.end_session(&sid);
                    }
                }
            }
            Ok(doc! { "ok": 1.0 })
        }
        "commitTransaction" => {
            // Stub: acknowledge.
            Ok(doc! { "ok": 1.0 })
        }
        "abortTransaction" => {
            // Stub: acknowledge.
            Ok(doc! { "ok": 1.0 })
        }
        // Auth stubs - accept silently.
        "saslStart" => Ok(doc! {
            "conversationId": 1_i32,
            "done": true,
            "payload": bson::Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: vec![] },
            "ok": 1.0,
        }),
        "saslContinue" => Ok(doc! {
            "conversationId": 1_i32,
            "done": true,
            "payload": bson::Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: vec![] },
            "ok": 1.0,
        }),
        "authenticate" | "logout" => Ok(doc! { "ok": 1.0 }),
        "currentOp" => Ok(doc! {
            "inprog": [],
            "ok": 1.0,
        }),
        // Maintenance commands that have no effect here: acknowledge.
        "killOp" | "top" | "profile" | "compact" | "reIndex"
        | "fsync" | "connPoolSync" => Ok(doc! { "ok": 1.0 }),
        other => {
            // Catch-all for any admin command we missed.
            Ok(doc! {
                "ok": 1.0,
                "note": format!("stub response for command: {}", other),
            })
        }
    }
}
/// Handle `listDatabases` command.
///
/// Supports `nameOnly` (skip size estimation) and `filter` (a query
/// document matched against each database's info doc). `totalSize` is the
/// sum of `sizeOnDisk` over the databases actually returned; databases
/// removed by the filter no longer contribute (the previous version
/// accumulated their sizes before applying the filter, over-reporting the
/// total).
async fn handle_list_databases(
    cmd: &Document,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let dbs = ctx.storage.list_databases().await?;
    let name_only = matches!(cmd.get("nameOnly"), Some(Bson::Boolean(true)));
    let filter = match cmd.get("filter") {
        Some(Bson::Document(d)) => Some(d.clone()),
        _ => None,
    };
    let mut db_docs: Vec<Bson> = Vec::new();
    let mut total_size: i64 = 0;
    for db_name in &dbs {
        let mut db_info = doc! { "name": db_name.as_str() };
        let mut db_size: i64 = 0;
        if !name_only {
            // Estimate size by counting documents across collections.
            // Rough heuristic: 200 bytes per document.
            if let Ok(collections) = ctx.storage.list_collections(db_name).await {
                for coll in &collections {
                    if let Ok(count) = ctx.storage.count(db_name, coll).await {
                        db_size += count as i64 * 200;
                    }
                }
            }
            db_info.insert("sizeOnDisk", db_size);
            db_info.insert("empty", db_size == 0);
        }
        // Apply the filter BEFORE accumulating totals so that filtered-out
        // databases do not inflate `totalSize`.
        if let Some(ref f) = filter {
            if !rustdb_query::QueryMatcher::matches(&db_info, f) {
                continue;
            }
        }
        if !name_only {
            total_size += db_size;
        }
        db_docs.push(Bson::Document(db_info));
    }
    let mut response = doc! {
        "databases": db_docs,
        "ok": 1.0,
    };
    if !name_only {
        response.insert("totalSize", total_size);
    }
    Ok(response)
}
/// Handle `listCollections` command.
///
/// Returns collection metadata in cursor form. Supports `filter` (matched
/// against each collection's info doc), `nameOnly` (omit options/idIndex
/// details), and `cursor.batchSize` (paging; when absent all results go
/// in the first batch). Overflow beyond the first batch is parked in a
/// server-side cursor for `getMore`.
async fn handle_list_collections(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let collections = ctx.storage.list_collections(db).await?;
    let filter = match cmd.get("filter") {
        Some(Bson::Document(d)) => Some(d.clone()),
        _ => None,
    };
    let name_only = match cmd.get("nameOnly") {
        Some(Bson::Boolean(true)) => true,
        _ => false,
    };
    // batchSize may arrive as either Int32 or Int64; usize::MAX means
    // "no paging requested".
    let batch_size = cmd
        .get_document("cursor")
        .ok()
        .and_then(|c| {
            c.get_i32("batchSize")
                .ok()
                .map(|v| v as usize)
                .or_else(|| c.get_i64("batchSize").ok().map(|v| v as usize))
        })
        .unwrap_or(usize::MAX);
    // Pseudo-namespace reported in the cursor, mirroring MongoDB's shape.
    let ns = format!("{}.$cmd.listCollections", db);
    let mut coll_docs: Vec<Document> = Vec::new();
    for coll_name in &collections {
        let info_doc = if name_only {
            doc! {
                "name": coll_name.as_str(),
                "type": "collection",
            }
        } else {
            // Full form: fixed options plus the implicit _id_ index spec.
            doc! {
                "name": coll_name.as_str(),
                "type": "collection",
                "options": {},
                "info": {
                    "readOnly": false,
                },
                "idIndex": {
                    "v": 2_i32,
                    "key": { "_id": 1_i32 },
                    "name": "_id_",
                },
            }
        };
        // Apply filter if specified.
        if let Some(ref f) = filter {
            if !rustdb_query::QueryMatcher::matches(&info_doc, f) {
                continue;
            }
        }
        coll_docs.push(info_doc);
    }
    if coll_docs.len() <= batch_size {
        // Everything fits in the first batch: cursor id 0 means exhausted.
        let first_batch: Vec<Bson> = coll_docs.into_iter().map(Bson::Document).collect();
        Ok(doc! {
            "cursor": {
                "id": 0_i64,
                "ns": &ns,
                "firstBatch": first_batch,
            },
            "ok": 1.0,
        })
    } else {
        // Park the remainder in a server-side cursor for getMore.
        let first_batch: Vec<Bson> = coll_docs[..batch_size]
            .iter()
            .cloned()
            .map(Bson::Document)
            .collect();
        let remaining: Vec<Document> = coll_docs[batch_size..].to_vec();
        let cursor_id = generate_cursor_id();
        ctx.cursors.insert(
            cursor_id,
            CursorState {
                documents: remaining,
                position: 0,
                database: db.to_string(),
                // No single collection applies to a listCollections cursor.
                collection: String::new(),
            },
        );
        Ok(doc! {
            "cursor": {
                "id": cursor_id,
                "ns": &ns,
                "firstBatch": first_batch,
            },
            "ok": 1.0,
        })
    }
}
/// Handle `create` command.
///
/// Ensures the database exists, creates the collection, and registers a
/// fresh index engine for the new namespace. Creating an already-existing
/// collection is an error (NamespaceExists); an already-existing database
/// is not.
async fn handle_create(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = cmd
        .get_str("create")
        .map_err(|_| CommandError::InvalidArgument("missing 'create' field".into()))?;
    debug!(db = db, collection = coll, "create command");
    // Create database (ignore AlreadyExists).
    // NOTE(review): "already exists" is detected by substring-matching the
    // error's Display text — fragile; prefer matching a StorageError
    // variant if one exists.
    if let Err(e) = ctx.storage.create_database(db).await {
        let msg = e.to_string();
        if !msg.contains("AlreadyExists") && !msg.contains("already exists") {
            return Err(CommandError::StorageError(msg));
        }
    }
    // Create collection.
    if let Err(e) = ctx.storage.create_collection(db, coll).await {
        let msg = e.to_string();
        if msg.contains("AlreadyExists") || msg.contains("already exists") {
            return Err(CommandError::NamespaceExists(format!("{}.{}", db, coll)));
        }
        return Err(CommandError::StorageError(msg));
    }
    // Initialize index engine for the new collection.
    let ns_key = format!("{}.{}", db, coll);
    ctx.indexes
        .entry(ns_key)
        .or_insert_with(IndexEngine::new);
    Ok(doc! { "ok": 1.0 })
}
/// Handle `drop` command.
///
/// Removes the named collection from storage and discards its in-memory
/// index engine. Dropping a collection that definitely does not exist is
/// a NamespaceNotFound error.
async fn handle_drop(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = cmd
        .get_str("drop")
        .map_err(|_| CommandError::InvalidArgument("missing 'drop' field".into()))?;
    let ns_key = format!("{}.{}", db, coll);
    debug!(db = db, collection = coll, "drop command");
    // Only a definite "does not exist" answer is an error; existence-check
    // failures are ignored and the drop is attempted anyway.
    if let Ok(false) = ctx.storage.collection_exists(db, coll).await {
        return Err(CommandError::NamespaceNotFound(format!(
            "ns not found: {}",
            ns_key
        )));
    }
    // Remove the data, then the index state for this namespace.
    ctx.storage.drop_collection(db, coll).await?;
    ctx.indexes.remove(&ns_key);
    // Every collection carries at least the implicit _id_ index.
    Ok(doc! {
        "ns": &ns_key,
        "nIndexesWas": 1_i32,
        "ok": 1.0,
    })
}
/// Handle `dropDatabase` command.
///
/// Purges every in-memory index engine whose namespace belongs to the
/// database, then drops the database from storage.
async fn handle_drop_database(
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    debug!(db = db, "dropDatabase command");
    // Namespaces are keyed as "db.collection", so a "db." prefix match
    // selects every index engine in this database.
    let prefix = format!("{}.", db);
    ctx.indexes.retain(|ns, _| !ns.starts_with(&prefix));
    // Drop from storage.
    ctx.storage.drop_database(db).await?;
    Ok(doc! {
        "dropped": db,
        "ok": 1.0,
    })
}
/// Handle `renameCollection` command.
///
/// Accepts full "db.collection" namespaces for both source (the command
/// value) and target (`to`). Only same-database renames are supported.
/// With `dropTarget: true` an existing target is silently dropped first;
/// otherwise an existing target is a NamespaceExists error. The in-memory
/// index engine is moved to the new namespace along with the data.
async fn handle_rename_collection(
    cmd: &Document,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let source_ns = cmd
        .get_str("renameCollection")
        .map_err(|_| CommandError::InvalidArgument("missing 'renameCollection' field".into()))?;
    let target_ns = cmd
        .get_str("to")
        .map_err(|_| CommandError::InvalidArgument("missing 'to' field".into()))?;
    let drop_target = match cmd.get("dropTarget") {
        Some(Bson::Boolean(b)) => *b,
        _ => false,
    };
    // Parse "db.collection" format.
    let (source_db, source_coll) = parse_namespace(source_ns)?;
    let (target_db, target_coll) = parse_namespace(target_ns)?;
    debug!(
        source = source_ns,
        target = target_ns,
        drop_target = drop_target,
        "renameCollection command"
    );
    // If cross-database rename, that's more complex. For now, support same-db rename.
    if source_db != target_db {
        return Err(CommandError::InvalidArgument(
            "cross-database renameCollection not yet supported".into(),
        ));
    }
    // If dropTarget, drop the target collection first.
    if drop_target {
        let _ = ctx.storage.drop_collection(target_db, target_coll).await;
        let target_ns_key = format!("{}.{}", target_db, target_coll);
        ctx.indexes.remove(&target_ns_key);
    } else {
        // Check if target already exists.
        if let Ok(true) = ctx.storage.collection_exists(target_db, target_coll).await {
            return Err(CommandError::NamespaceExists(target_ns.to_string()));
        }
    }
    // Rename in storage.
    ctx.storage
        .rename_collection(source_db, source_coll, target_coll)
        .await?;
    // Update index engine: move from old namespace to new.
    let source_ns_key = format!("{}.{}", source_db, source_coll);
    let target_ns_key = format!("{}.{}", target_db, target_coll);
    if let Some((_, engine)) = ctx.indexes.remove(&source_ns_key) {
        ctx.indexes.insert(target_ns_key, engine);
    }
    Ok(doc! { "ok": 1.0 })
}
/// Handle `collStats` command (also used for `validate`, which shares the
/// same response shape here).
///
/// Reports rough per-collection statistics: document count from storage
/// and a size estimate of 200 bytes per document.
async fn handle_coll_stats(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
    command_name: &str,
) -> CommandResult<Document> {
    // The collection name is the value of the command's own key
    // ("collStats" or "validate").
    let coll = cmd.get_str(command_name).unwrap_or("unknown");
    let ns = format!("{}.{}", db, coll);
    let count = ctx.storage.count(db, coll).await.unwrap_or(0);
    // Even without a registered engine, the implicit _id_ index counts.
    let n_indexes = ctx
        .indexes
        .get(&ns)
        .map_or(1_i32, |engine| engine.list_indexes().len() as i32);
    // Rough size estimate.
    let data_size = count as i64 * 200;
    Ok(doc! {
        "ns": &ns,
        "count": count as i64,
        "size": data_size,
        "avgObjSize": if count > 0 { 200_i64 } else { 0_i64 },
        "storageSize": data_size,
        "nindexes": n_indexes,
        "totalIndexSize": 0_i64,
        "ok": 1.0,
    })
}
/// Handle `dbStats` command.
///
/// Aggregates rough statistics across every collection in the database:
/// total object count, a 200-bytes-per-document size estimate, and the
/// number of indexes (at least the implicit _id_ per collection).
async fn handle_db_stats(
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let collections = ctx
        .storage
        .list_collections(db)
        .await
        .unwrap_or_default();
    let num_collections = collections.len() as i32;
    let mut total_objects: i64 = 0;
    let mut total_indexes: i32 = 0;
    for coll in &collections {
        total_objects += ctx
            .storage
            .count(db, coll)
            .await
            .map_or(0, |count| count as i64);
        let ns_key = format!("{}.{}", db, coll);
        // Collections without a registered engine still have _id_.
        total_indexes += ctx
            .indexes
            .get(&ns_key)
            .map_or(1, |engine| engine.list_indexes().len() as i32);
    }
    let data_size = total_objects * 200;
    Ok(doc! {
        "db": db,
        "collections": num_collections,
        "objects": total_objects,
        "avgObjSize": if total_objects > 0 { 200_i64 } else { 0_i64 },
        "dataSize": data_size,
        "storageSize": data_size,
        "indexes": total_indexes,
        "indexSize": 0_i64,
        "ok": 1.0,
    })
}
/// Parse a namespace string "db.collection" into (db, collection).
///
/// Splits on the FIRST dot, so collection names may themselves contain
/// dots. Both halves must be non-empty; anything else is an
/// InvalidArgument error.
fn parse_namespace(ns: &str) -> CommandResult<(&str, &str)> {
    match ns.split_once('.') {
        Some((db, coll)) if !db.is_empty() && !coll.is_empty() => Ok((db, coll)),
        Some(_) => Err(CommandError::InvalidArgument(format!(
            "invalid namespace '{}': db and collection must not be empty",
            ns
        ))),
        None => Err(CommandError::InvalidArgument(format!(
            "invalid namespace '{}': expected 'db.collection' format",
            ns
        ))),
    }
}
/// Generate a pseudo-random, strictly positive cursor ID.
///
/// Hashes the current wall-clock nanosecond timestamp through a randomly
/// seeded hasher. The sign bit of the hash is masked off so the result is
/// always non-negative, avoiding the `i64::MIN.abs()` overflow panic the
/// previous `id.abs()` approach could hit (abs of i64::MIN is not
/// representable). A zero result is remapped to 1 because cursor id 0
/// means "no cursor" on the wire.
fn generate_cursor_id() -> i64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    let s = RandomState::new();
    let mut hasher = s.build_hasher();
    hasher.write_u64(
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64,
    );
    // Keep only the low 63 bits: always >= 0, no signed-overflow edge case.
    let id = (hasher.finish() & (i64::MAX as u64)) as i64;
    if id == 0 {
        1
    } else {
        id
    }
}

View File

@@ -0,0 +1,310 @@
use bson::{doc, Bson, Document};
use rustdb_query::AggregationEngine;
use rustdb_query::error::QueryError;
use tracing::debug;
use crate::context::{CommandContext, CursorState};
use crate::error::{CommandError, CommandResult};
/// A CollectionResolver that reads from the storage adapter.
///
/// The CollectionResolver trait is synchronous while the storage adapter
/// is async, so `resolve` has to bridge the two worlds.
struct StorageResolver<'a> {
    /// Storage backend used to materialize foreign collections ($lookup).
    storage: &'a dyn rustdb_storage::StorageAdapter,
    /// Handle to the tokio runtime that owns the storage futures.
    handle: tokio::runtime::Handle,
}
impl<'a> rustdb_query::aggregation::CollectionResolver for StorageResolver<'a> {
    fn resolve(&self, db: &str, coll: &str) -> Result<Vec<Document>, QueryError> {
        // `Handle::block_on` panics when invoked on a runtime worker thread
        // ("Cannot block the current thread from within a runtime"), and
        // `resolve` IS called from inside the async `handle` fn. Wrap the
        // call in `block_in_place`, which tells the multi-threaded runtime
        // this worker is about to block so it can hand off its tasks.
        // NOTE(review): block_in_place itself panics on a current_thread
        // runtime — confirm the server always runs the multi-threaded one.
        tokio::task::block_in_place(|| {
            self.handle
                .block_on(async { self.storage.find_all(db, coll).await })
        })
        .map_err(|e| {
            QueryError::AggregationError(format!("Failed to resolve {}.{}: {}", db, coll, e))
        })
    }
}
/// Handle the `aggregate` command.
///
/// Loads the source collection fully into memory, runs the pipeline via
/// the query crate's AggregationEngine (with a storage-backed resolver so
/// stages like $lookup can read other collections), then returns results
/// in cursor form. A trailing $out / $merge stage is peeled off before
/// execution and applied to the results afterwards. Results beyond the
/// first batch (default 101 documents, matching MongoDB's default) are
/// parked in a server-side cursor for `getMore`.
pub async fn handle(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    // The aggregate field can be a string (collection name) or integer 1 (db-level).
    let (coll, is_db_level) = match cmd.get("aggregate") {
        Some(Bson::String(s)) => (s.as_str().to_string(), false),
        Some(Bson::Int32(1)) => (String::new(), true),
        Some(Bson::Int64(1)) => (String::new(), true),
        _ => {
            return Err(CommandError::InvalidArgument(
                "missing or invalid 'aggregate' field".into(),
            ));
        }
    };
    let pipeline_bson = cmd
        .get_array("pipeline")
        .map_err(|_| CommandError::InvalidArgument("missing 'pipeline' array".into()))?;
    // Convert pipeline to Vec<Document>.
    let mut pipeline: Vec<Document> = Vec::with_capacity(pipeline_bson.len());
    for stage in pipeline_bson {
        match stage {
            Bson::Document(d) => pipeline.push(d.clone()),
            _ => {
                return Err(CommandError::InvalidArgument(
                    "pipeline stage must be a document".into(),
                ));
            }
        }
    }
    // Check for $out and $merge as the last stage (handle after pipeline execution).
    let out_stage = if let Some(last) = pipeline.last() {
        if last.contains_key("$out") || last.contains_key("$merge") {
            Some(pipeline.pop().unwrap())
        } else {
            None
        }
    } else {
        None
    };
    // batchSize may arrive as either Int32 or Int64; 101 mirrors MongoDB's
    // default first-batch size.
    let batch_size = cmd
        .get_document("cursor")
        .ok()
        .and_then(|c| {
            c.get_i32("batchSize")
                .ok()
                .map(|v| v as usize)
                .or_else(|| c.get_i64("batchSize").ok().map(|v| v as usize))
        })
        .unwrap_or(101);
    debug!(
        db = db,
        collection = %coll,
        stages = pipeline.len(),
        "aggregate command"
    );
    // Load source documents.
    let source_docs = if is_db_level {
        // Database-level aggregate: start with empty set (useful for $currentOp, etc.)
        Vec::new()
    } else {
        ctx.storage.find_all(db, &coll).await?
    };
    // Create a resolver for $lookup and similar stages.
    let handle = tokio::runtime::Handle::current();
    let resolver = StorageResolver {
        storage: ctx.storage.as_ref(),
        handle,
    };
    // Run the aggregation pipeline.
    let result_docs = AggregationEngine::aggregate(
        source_docs,
        &pipeline,
        Some(&resolver),
        db,
    )
    .map_err(|e| CommandError::InternalError(e.to_string()))?;
    // Handle $out stage: write results to target collection.
    if let Some(out) = out_stage {
        if let Some(out_spec) = out.get("$out") {
            handle_out_stage(db, out_spec, &result_docs, ctx).await?;
        } else if let Some(merge_spec) = out.get("$merge") {
            handle_merge_stage(db, merge_spec, &result_docs, ctx).await?;
        }
    }
    // Build cursor response.
    let ns = if is_db_level {
        format!("{}.$cmd.aggregate", db)
    } else {
        format!("{}.{}", db, coll)
    };
    if result_docs.len() <= batch_size {
        // All results fit in first batch.
        let first_batch: Vec<Bson> = result_docs
            .into_iter()
            .map(Bson::Document)
            .collect();
        Ok(doc! {
            "cursor": {
                "firstBatch": first_batch,
                "id": 0_i64,
                "ns": &ns,
            },
            "ok": 1.0,
        })
    } else {
        // Need to create a cursor for remaining results.
        let first_batch: Vec<Bson> = result_docs[..batch_size]
            .iter()
            .cloned()
            .map(Bson::Document)
            .collect();
        let remaining: Vec<Document> = result_docs[batch_size..].to_vec();
        let cursor_id = generate_cursor_id();
        ctx.cursors.insert(
            cursor_id,
            CursorState {
                documents: remaining,
                position: 0,
                database: db.to_string(),
                collection: coll.to_string(),
            },
        );
        Ok(doc! {
            "cursor": {
                "firstBatch": first_batch,
                "id": cursor_id,
                "ns": &ns,
            },
            "ok": 1.0,
        })
    }
}
/// Handle $out stage: drop and replace target collection with pipeline results.
///
/// `$out` accepts either a bare collection name (same database) or a
/// `{ db, coll }` document. The drop/create/insert calls are best-effort:
/// their errors are deliberately ignored (the target may not exist yet).
async fn handle_out_stage(
    db: &str,
    out_spec: &Bson,
    docs: &[Document],
    ctx: &CommandContext,
) -> CommandResult<()> {
    let (target_db, target_coll) = match out_spec {
        Bson::String(coll_name) => (db.to_string(), coll_name.clone()),
        Bson::Document(d) => {
            let tcoll = d
                .get_str("coll")
                .map_err(|_| CommandError::InvalidArgument("$out requires 'coll'".into()))?;
            // 'db' is optional and defaults to the current database.
            (d.get_str("db").unwrap_or(db).to_string(), tcoll.to_string())
        }
        _ => {
            return Err(CommandError::InvalidArgument(
                "$out requires a string or document".into(),
            ));
        }
    };
    // Replace semantics: drop any existing target, then recreate it.
    let _ = ctx.storage.drop_collection(&target_db, &target_coll).await;
    let _ = ctx.storage.create_database(&target_db).await;
    let _ = ctx.storage.create_collection(&target_db, &target_coll).await;
    // Bulk-insert the pipeline results.
    for result_doc in docs {
        let _ = ctx
            .storage
            .insert_one(&target_db, &target_coll, result_doc.clone())
            .await;
    }
    Ok(())
}
/// Handle $merge stage: merge pipeline results into target collection.
///
/// `$merge` accepts a bare collection name or a document whose `into`
/// field is a name or a `{ db, coll }` document. Merging is simplified to
/// an upsert keyed on `_id`: update-by-id first, insert on failure.
/// NOTE(review): `whenMatched` / `whenNotMatched` options are ignored —
/// confirm callers do not depend on them.
async fn handle_merge_stage(
    db: &str,
    merge_spec: &Bson,
    docs: &[Document],
    ctx: &CommandContext,
) -> CommandResult<()> {
    let (target_db, target_coll) = match merge_spec {
        Bson::String(coll_name) => (db.to_string(), coll_name.clone()),
        Bson::Document(d) => {
            let into_val = d.get("into");
            match into_val {
                Some(Bson::String(s)) => (db.to_string(), s.clone()),
                Some(Bson::Document(into_doc)) => {
                    // 'db' defaults to the current database; 'coll' is required.
                    let tdb = into_doc.get_str("db").unwrap_or(db).to_string();
                    let tcoll = into_doc
                        .get_str("coll")
                        .map_err(|_| {
                            CommandError::InvalidArgument("$merge.into requires 'coll'".into())
                        })?
                        .to_string();
                    (tdb, tcoll)
                }
                _ => {
                    return Err(CommandError::InvalidArgument(
                        "$merge requires 'into' field".into(),
                    ));
                }
            }
        }
        _ => {
            return Err(CommandError::InvalidArgument(
                "$merge requires a string or document".into(),
            ));
        }
    };
    // Ensure target collection exists (best-effort; errors ignored).
    let _ = ctx.storage.create_database(&target_db).await;
    let _ = ctx
        .storage
        .create_collection(&target_db, &target_coll)
        .await;
    // Simple merge: upsert by _id.
    for doc in docs {
        // Storage keys documents by a string id: ObjectIds use their hex
        // form, strings pass through, anything else its Display form.
        let id_str = match doc.get("_id") {
            Some(Bson::ObjectId(oid)) => oid.to_hex(),
            Some(Bson::String(s)) => s.clone(),
            Some(other) => format!("{}", other),
            None => {
                // No _id, just insert.
                let _ = ctx
                    .storage
                    .insert_one(&target_db, &target_coll, doc.clone())
                    .await;
                continue;
            }
        };
        // Try update first, insert if it fails.
        match ctx
            .storage
            .update_by_id(&target_db, &target_coll, &id_str, doc.clone())
            .await
        {
            Ok(()) => {}
            Err(_) => {
                let _ = ctx
                    .storage
                    .insert_one(&target_db, &target_coll, doc.clone())
                    .await;
            }
        }
    }
    Ok(())
}
/// Generate a pseudo-random, strictly positive cursor ID.
///
/// Hashes the current wall-clock nanosecond timestamp through a randomly
/// seeded hasher. The sign bit is masked off so the result is always
/// non-negative, avoiding the `i64::MIN.abs()` overflow panic the previous
/// `id.abs()` approach could hit. Zero is remapped to 1 because cursor id
/// 0 means "no cursor" on the wire.
fn generate_cursor_id() -> i64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    let s = RandomState::new();
    let mut hasher = s.build_hasher();
    hasher.write_u64(std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64);
    // Keep only the low 63 bits: always >= 0, no signed-overflow edge case.
    let id = (hasher.finish() & (i64::MAX as u64)) as i64;
    if id == 0 { 1 } else { id }
}

View File

@@ -0,0 +1,196 @@
use std::collections::HashSet;
use bson::{doc, Bson, Document};
use rustdb_query::QueryMatcher;
use tracing::debug;
use crate::context::CommandContext;
use crate::error::{CommandError, CommandResult};
/// Handle the `delete` command.
///
/// Iterates the `deletes` array of `{ q, limit }` specs, deleting matching
/// documents per spec. With `ordered: true` (the default) processing stops
/// at the first failing spec; otherwise failures are recorded in
/// `writeErrors` and processing continues. The response reports the total
/// deleted count in `n` with `ok: 1` even when individual specs failed,
/// matching MongoDB's write-error shape.
pub async fn handle(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = cmd
        .get_str("delete")
        .map_err(|_| CommandError::InvalidArgument("missing 'delete' field".into()))?;
    let deletes = cmd
        .get_array("deletes")
        .map_err(|_| CommandError::InvalidArgument("missing 'deletes' array".into()))?;
    // Ordered flag (default true).
    let ordered = match cmd.get("ordered") {
        Some(Bson::Boolean(b)) => *b,
        _ => true,
    };
    debug!(
        db = db,
        collection = coll,
        count = deletes.len(),
        "delete command"
    );
    let ns_key = format!("{}.{}", db, coll);
    let mut total_deleted: i32 = 0;
    let mut write_errors: Vec<Document> = Vec::new();
    for (idx, del_spec) in deletes.iter().enumerate() {
        let del_doc = match del_spec {
            Bson::Document(d) => d,
            _ => {
                // Non-document spec: record a per-spec TypeMismatch error.
                write_errors.push(doc! {
                    "index": idx as i32,
                    "code": 14_i32,
                    "codeName": "TypeMismatch",
                    "errmsg": "delete spec must be a document",
                });
                if ordered {
                    break;
                }
                continue;
            }
        };
        // Extract filter (q) and limit.
        let filter = match del_doc.get_document("q") {
            Ok(f) => f.clone(),
            Err(_) => Document::new(), // empty filter matches everything
        };
        let limit = match del_doc.get("limit") {
            Some(Bson::Int32(n)) => *n,
            Some(Bson::Int64(n)) => *n as i32,
            Some(Bson::Double(n)) => *n as i32,
            _ => 0, // default: delete all matches
        };
        match delete_matching(db, coll, &ns_key, &filter, limit, ctx).await {
            Ok(count) => {
                total_deleted += count;
            }
            Err(e) => {
                write_errors.push(doc! {
                    "index": idx as i32,
                    "code": 1_i32,
                    "codeName": "InternalError",
                    "errmsg": e.to_string(),
                });
                if ordered {
                    break;
                }
            }
        }
    }
    // Build response.
    let mut response = doc! {
        "n": total_deleted,
        "ok": 1.0,
    };
    if !write_errors.is_empty() {
        response.insert(
            "writeErrors",
            write_errors
                .into_iter()
                .map(Bson::Document)
                .collect::<Vec<_>>(),
        );
    }
    Ok(response)
}
/// Find and delete documents matching a filter, returning the number deleted.
///
/// Uses the namespace's index engine (when present) to narrow the candidate
/// set before filtering; otherwise scans the whole collection. `limit == 1`
/// deletes only the first match, `limit == 0` deletes all matches.
/// NOTE(review): any other limit value also deletes all matches — MongoDB
/// only permits 0 or 1 here, so values > 1 should arguably be rejected
/// upstream.
async fn delete_matching(
    db: &str,
    coll: &str,
    ns_key: &str,
    filter: &Document,
    limit: i32,
    ctx: &CommandContext,
) -> Result<i32, CommandError> {
    // Check if the collection exists; if not, nothing to delete.
    // Existence-check failures are treated the same as "does not exist".
    match ctx.storage.collection_exists(db, coll).await {
        Ok(false) => return Ok(0),
        Err(_) => return Ok(0),
        Ok(true) => {}
    }
    // Try to use index to narrow candidates. None means "no usable index";
    // an empty set means the index proved nothing matches.
    let candidate_ids: Option<HashSet<String>> = {
        if let Some(engine) = ctx.indexes.get(ns_key) {
            engine.find_candidate_ids(filter)
        } else {
            None
        }
    };
    // Load candidate documents.
    let docs = if let Some(ids) = candidate_ids {
        if ids.is_empty() {
            return Ok(0);
        }
        ctx.storage
            .find_by_ids(db, coll, ids)
            .await
            .map_err(|e| CommandError::StorageError(e.to_string()))?
    } else {
        ctx.storage
            .find_all(db, coll)
            .await
            .map_err(|e| CommandError::StorageError(e.to_string()))?
    };
    // Apply filter to get matched documents.
    let matched = QueryMatcher::filter(&docs, filter);
    // Apply limit: 0 means delete all, 1 means delete only the first match.
    let to_delete: &[Document] = if limit == 1 && !matched.is_empty() {
        &matched[..1]
    } else {
        &matched
    };
    if to_delete.is_empty() {
        return Ok(0);
    }
    let mut deleted_count: i32 = 0;
    for doc in to_delete {
        // Extract the _id as a hex string for storage deletion.
        let id_str = extract_id_string(doc)?;
        ctx.storage
            .delete_by_id(db, coll, &id_str)
            .await
            .map_err(|e| CommandError::StorageError(e.to_string()))?;
        // Update index engine so index entries for this document are purged.
        if let Some(mut engine) = ctx.indexes.get_mut(ns_key) {
            engine.on_delete(doc);
        }
        deleted_count += 1;
    }
    Ok(deleted_count)
}
/// Extract the `_id` field from a document as a hex string suitable for the
/// storage adapter.
///
/// ObjectIds become their 24-char hex form, strings pass through verbatim,
/// and any other BSON value falls back to its Display representation. A
/// document without `_id` is an InvalidArgument error.
fn extract_id_string(doc: &Document) -> Result<String, CommandError> {
    let id = doc.get("_id").ok_or_else(|| {
        CommandError::InvalidArgument("document missing _id field".into())
    })?;
    Ok(match id {
        Bson::ObjectId(oid) => oid.to_hex(),
        Bson::String(s) => s.clone(),
        other => format!("{}", other),
    })
}

View File

@@ -0,0 +1,370 @@
use std::sync::atomic::{AtomicI64, Ordering};
use bson::{doc, Bson, Document};
use tracing::debug;
use rustdb_query::{QueryMatcher, sort_documents, apply_projection, distinct_values};
use crate::context::{CommandContext, CursorState};
use crate::error::{CommandError, CommandResult};
/// Atomic counter for generating unique cursor IDs.
/// Starts at 1 because cursor id 0 means "no cursor" on the wire.
static CURSOR_ID_COUNTER: AtomicI64 = AtomicI64::new(1);
/// Generate a new unique, positive cursor ID.
/// Relaxed ordering suffices: only uniqueness matters, not ordering
/// relative to other memory operations.
fn next_cursor_id() -> i64 {
    CURSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
// ---------------------------------------------------------------------------
// Helpers to defensively extract values from BSON command documents
// ---------------------------------------------------------------------------
fn get_str<'a>(doc: &'a Document, key: &str) -> Option<&'a str> {
match doc.get(key)? {
Bson::String(s) => Some(s.as_str()),
_ => None,
}
}
fn get_i32(doc: &Document, key: &str) -> Option<i32> {
match doc.get(key)? {
Bson::Int32(v) => Some(*v),
Bson::Int64(v) => Some(*v as i32),
Bson::Double(v) => Some(*v as i32),
_ => None,
}
}
/// Read `key` as an i64, widening Int32 and truncating Double via `as`.
fn get_i64(doc: &Document, key: &str) -> Option<i64> {
    let value = doc.get(key)?;
    match value {
        Bson::Int64(v) => Some(*v),
        Bson::Int32(v) => Some(i64::from(*v)),
        Bson::Double(v) => Some(*v as i64),
        _ => None,
    }
}
/// Read `key` as a bool; `None` if absent or not a BSON boolean.
fn get_bool(doc: &Document, key: &str) -> Option<bool> {
    if let Some(Bson::Boolean(v)) = doc.get(key) {
        Some(*v)
    } else {
        None
    }
}
/// Borrow `key` from `doc` as an embedded document; `None` if absent or not one.
fn get_document<'a>(doc: &'a Document, key: &str) -> Option<&'a Document> {
    if let Some(Bson::Document(d)) = doc.get(key) {
        Some(d)
    } else {
        None
    }
}
// ---------------------------------------------------------------------------
// find
// ---------------------------------------------------------------------------
/// Handle the `find` command.
///
/// Applies `filter`, `sort`, `skip`, `limit`, and `projection` in that order.
/// Results beyond the first `batchSize` documents are parked in a server-side
/// cursor (unless `singleBatch`) for later `getMore` calls.
pub async fn handle(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = get_str(cmd, "find").unwrap_or("unknown");
    // `ns` doubles as the index-registry key; the original computed the same
    // string twice (`ns` and `index_key`).
    let ns = format!("{}.{}", db, coll);
    // Extract optional parameters.
    let filter = get_document(cmd, "filter").cloned().unwrap_or_default();
    let sort_spec = get_document(cmd, "sort").cloned();
    let projection = get_document(cmd, "projection").cloned();
    let skip = get_i64(cmd, "skip").unwrap_or(0).max(0) as usize;
    let limit = get_i64(cmd, "limit").unwrap_or(0).max(0) as usize;
    let batch_size = get_i32(cmd, "batchSize").unwrap_or(101).max(0) as usize;
    let single_batch = get_bool(cmd, "singleBatch").unwrap_or(false);
    // If the collection does not exist, return an empty cursor.
    let exists = ctx.storage.collection_exists(db, coll).await?;
    if !exists {
        return Ok(doc! {
            "cursor": {
                "firstBatch": [],
                "id": 0_i64,
                "ns": &ns,
            },
            "ok": 1.0,
        });
    }
    // Try index-accelerated lookup. Resolve candidate ids *before* touching
    // storage: the closure confines the DashMap read guard so it is dropped
    // before any await point. (The original held the guard across the storage
    // awaits, which risks deadlock against writers taking entry()/get_mut().)
    let candidate_ids = ctx
        .indexes
        .get(&ns)
        .and_then(|engine| engine.find_candidate_ids(&filter));
    let docs = match candidate_ids {
        Some(ids) => {
            debug!(
                ns = %ns,
                candidates = ids.len(),
                "using index acceleration"
            );
            ctx.storage.find_by_ids(db, coll, ids).await?
        }
        None => ctx.storage.find_all(db, coll).await?,
    };
    // Apply filter.
    let mut docs = QueryMatcher::filter(&docs, &filter);
    // Apply sort.
    if let Some(ref sort) = sort_spec {
        sort_documents(&mut docs, sort);
    }
    // Apply skip.
    if skip > 0 {
        if skip >= docs.len() {
            docs = Vec::new();
        } else {
            docs = docs.split_off(skip);
        }
    }
    // Apply limit.
    if limit > 0 && docs.len() > limit {
        docs.truncate(limit);
    }
    // Apply projection.
    if let Some(ref proj) = projection {
        docs = docs.iter().map(|d| apply_projection(d, proj)).collect();
    }
    // Determine first batch.
    if docs.len() <= batch_size || single_batch {
        // Everything fits in a single batch; cursor id 0 means "closed".
        let batch: Vec<Bson> = docs.into_iter().map(Bson::Document).collect();
        Ok(doc! {
            "cursor": {
                "firstBatch": batch,
                "id": 0_i64,
                "ns": &ns,
            },
            "ok": 1.0,
        })
    } else {
        // Split into first batch and remainder, store cursor.
        let remaining = docs.split_off(batch_size);
        let first_batch: Vec<Bson> = docs.into_iter().map(Bson::Document).collect();
        let cursor_id = next_cursor_id();
        ctx.cursors.insert(cursor_id, CursorState {
            documents: remaining,
            position: 0,
            database: db.to_string(),
            collection: coll.to_string(),
        });
        Ok(doc! {
            "cursor": {
                "firstBatch": first_batch,
                "id": cursor_id,
                "ns": &ns,
            },
            "ok": 1.0,
        })
    }
}
// ---------------------------------------------------------------------------
// getMore
// ---------------------------------------------------------------------------
/// Handle the `getMore` command.
///
/// Streams the next `batchSize` documents from a stored cursor; once the
/// cursor is exhausted it is removed from the registry and reported with id 0.
pub async fn handle_get_more(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    // Defensively extract cursor id.
    let cursor_id = get_i64(cmd, "getMore").ok_or_else(|| {
        CommandError::InvalidArgument("getMore requires a cursor id".into())
    })?;
    let coll = get_str(cmd, "collection").unwrap_or("unknown");
    let ns = format!("{}.{}", db, coll);
    // `get_i64` already coerces Int32 and Double, so the original's extra
    // `get_i32` fallback was redundant and has been removed.
    let batch_size = get_i64(cmd, "batchSize").unwrap_or(101).max(0) as usize;
    // Look up the cursor.
    let mut cursor_entry = ctx.cursors.get_mut(&cursor_id).ok_or_else(|| {
        CommandError::InvalidArgument(format!("cursor id {} not found", cursor_id))
    })?;
    let cursor = cursor_entry.value_mut();
    let start = cursor.position;
    let end = (start + batch_size).min(cursor.documents.len());
    let batch: Vec<Bson> = cursor.documents[start..end]
        .iter()
        .cloned()
        .map(Bson::Document)
        .collect();
    cursor.position = end;
    let exhausted = cursor.position >= cursor.documents.len();
    // Must drop the mutable reference before removing (same-map re-entry
    // through DashMap would deadlock otherwise).
    drop(cursor_entry);
    if exhausted {
        ctx.cursors.remove(&cursor_id);
        Ok(doc! {
            "cursor": {
                "nextBatch": batch,
                "id": 0_i64,
                "ns": &ns,
            },
            "ok": 1.0,
        })
    } else {
        Ok(doc! {
            "cursor": {
                "nextBatch": batch,
                "id": cursor_id,
                "ns": &ns,
            },
            "ok": 1.0,
        })
    }
}
// ---------------------------------------------------------------------------
// killCursors
// ---------------------------------------------------------------------------
/// Handle the `killCursors` command.
///
/// Removes each requested cursor from the registry and reports which ids
/// were killed versus not found. Non-integer ids are silently ignored.
pub async fn handle_kill_cursors(
    cmd: &Document,
    _db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let cursor_ids = match cmd.get("cursors") {
        Some(Bson::Array(arr)) => arr,
        _ => {
            // No well-formed `cursors` array: nothing to kill.
            return Ok(doc! {
                "cursorsKilled": [],
                "cursorsNotFound": [],
                "cursorsAlive": [],
                "cursorsUnknown": [],
                "ok": 1.0,
            });
        }
    };
    let mut killed: Vec<Bson> = Vec::new();
    let mut not_found: Vec<Bson> = Vec::new();
    for id_bson in cursor_ids {
        let id = match id_bson {
            Bson::Int64(v) => *v,
            Bson::Int32(v) => *v as i64,
            _ => continue,
        };
        // Route the id into the matching result bucket.
        let bucket = if ctx.cursors.remove(&id).is_some() {
            &mut killed
        } else {
            &mut not_found
        };
        bucket.push(Bson::Int64(id));
    }
    Ok(doc! {
        "cursorsKilled": killed,
        "cursorsNotFound": not_found,
        "cursorsAlive": [],
        "cursorsUnknown": [],
        "ok": 1.0,
    })
}
// ---------------------------------------------------------------------------
// count
// ---------------------------------------------------------------------------
/// Handle the `count` command.
///
/// Uses the cheap storage-level count when no filter is given; otherwise
/// loads the collection and applies the query. `skip` and `limit` are then
/// applied arithmetically to the resulting total.
pub async fn handle_count(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = get_str(cmd, "count").unwrap_or("unknown");
    // A missing collection simply counts zero documents.
    if !ctx.storage.collection_exists(db, coll).await? {
        return Ok(doc! { "n": 0_i64, "ok": 1.0 });
    }
    let query = get_document(cmd, "query").cloned().unwrap_or_default();
    let skip = get_i64(cmd, "skip").unwrap_or(0).max(0) as usize;
    let limit = get_i64(cmd, "limit").unwrap_or(0).max(0) as usize;
    // Base total: storage count for an empty query, filtered scan otherwise.
    let base = if query.is_empty() {
        ctx.storage.count(db, coll).await? as usize
    } else {
        let docs = ctx.storage.find_all(db, coll).await?;
        QueryMatcher::filter(&docs, &query).len()
    };
    // Apply skip, then cap by limit (0 means "no limit").
    let mut n = base.saturating_sub(skip);
    if limit > 0 {
        n = n.min(limit);
    }
    Ok(doc! {
        "n": n as i64,
        "ok": 1.0,
    })
}
// ---------------------------------------------------------------------------
// distinct
// ---------------------------------------------------------------------------
/// Handle the `distinct` command.
///
/// Returns the distinct values of `key` across the documents matching the
/// optional `query`.
pub async fn handle_distinct(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = get_str(cmd, "distinct").unwrap_or("unknown");
    let key = match get_str(cmd, "key") {
        Some(k) => k,
        None => {
            return Err(CommandError::InvalidArgument(
                "distinct requires a 'key' field".into(),
            ));
        }
    };
    // A missing collection yields an empty value set.
    if !ctx.storage.collection_exists(db, coll).await? {
        return Ok(doc! { "values": [], "ok": 1.0 });
    }
    let query = get_document(cmd, "query").cloned();
    let docs = ctx.storage.find_all(db, coll).await?;
    let values = distinct_values(&docs, key, query.as_ref());
    Ok(doc! {
        "values": values,
        "ok": 1.0,
    })
}

View File

@@ -0,0 +1,28 @@
use bson::{doc, Document};
use crate::context::CommandContext;
use crate::error::CommandResult;
/// Handle the `hello` family of handshake commands (`hello`, `ismaster`,
/// and `isMaster`).
///
/// Advertises this server as a writable primary with static limits matching
/// what the wire-protocol layer expects.
pub async fn handle(
    _cmd: &Document,
    _db: &str,
    _ctx: &CommandContext,
) -> CommandResult<Document> {
    // Static limits advertised to connecting clients.
    const MAX_BSON_OBJECT_SIZE: i32 = 16_777_216; // 16 MiB
    const MAX_MESSAGE_SIZE_BYTES: i32 = 48_000_000;
    const MAX_WRITE_BATCH_SIZE: i32 = 100_000;
    Ok(doc! {
        "ismaster": true,
        "isWritablePrimary": true,
        "maxBsonObjectSize": MAX_BSON_OBJECT_SIZE,
        "maxMessageSizeBytes": MAX_MESSAGE_SIZE_BYTES,
        "maxWriteBatchSize": MAX_WRITE_BATCH_SIZE,
        "localTime": bson::DateTime::now(),
        "logicalSessionTimeoutMinutes": 30_i32,
        "connectionId": 1_i32,
        "minWireVersion": 0_i32,
        "maxWireVersion": 21_i32,
        "readOnly": false,
        "ok": 1.0,
    })
}

View File

@@ -0,0 +1,342 @@
use bson::{doc, Bson, Document};
use rustdb_index::{IndexEngine, IndexOptions};
use tracing::debug;
use crate::context::CommandContext;
use crate::error::{CommandError, CommandResult};
/// Handle `createIndexes`, `dropIndexes`, and `listIndexes` commands.
pub async fn handle(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
    command_name: &str,
) -> CommandResult<Document> {
    if command_name == "createIndexes" {
        handle_create_indexes(cmd, db, ctx).await
    } else if command_name == "dropIndexes" {
        handle_drop_indexes(cmd, db, ctx).await
    } else if command_name == "listIndexes" {
        handle_list_indexes(cmd, db, ctx).await
    } else {
        // Any other name routed here is acknowledged as a no-op.
        Ok(doc! { "ok": 1.0 })
    }
}
/// Handle the `createIndexes` command.
///
/// Creates each spec in the `indexes` array on the target collection,
/// auto-creating the database/collection first. When indexes are added to a
/// pre-existing collection, they are rebuilt from all stored documents.
async fn handle_create_indexes(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = cmd
        .get_str("createIndexes")
        .map_err(|_| CommandError::InvalidArgument("missing 'createIndexes' field".into()))?;
    let indexes = cmd
        .get_array("indexes")
        .map_err(|_| CommandError::InvalidArgument("missing 'indexes' array".into()))?;
    let ns_key = format!("{}.{}", db, coll);
    debug!(
        db = db,
        collection = coll,
        count = indexes.len(),
        "createIndexes command"
    );
    // Auto-create collection if needed.
    let created_automatically = ensure_collection_exists(db, coll, ctx).await?;
    // Get the number of indexes before creating new ones. The `entry` guard
    // is confined to this block so the map lock is released before any await.
    let num_before = {
        let engine = ctx
            .indexes
            .entry(ns_key.clone())
            .or_insert_with(IndexEngine::new);
        engine.list_indexes().len() as i32
    };
    let mut created_count = 0_i32;
    for index_bson in indexes {
        let index_spec = match index_bson {
            Bson::Document(d) => d,
            _ => {
                return Err(CommandError::InvalidArgument(
                    "index spec must be a document".into(),
                ));
            }
        };
        // The `key` document (field -> direction) is mandatory in every spec.
        let key = match index_spec.get("key") {
            Some(Bson::Document(k)) => k.clone(),
            _ => {
                return Err(CommandError::InvalidArgument(
                    "index spec must have a 'key' document".into(),
                ));
            }
        };
        // Optional attributes; absent or wrong-typed values fall back to defaults.
        let name = index_spec.get_str("name").ok().map(|s| s.to_string());
        let unique = match index_spec.get("unique") {
            Some(Bson::Boolean(b)) => *b,
            _ => false,
        };
        let sparse = match index_spec.get("sparse") {
            Some(Bson::Boolean(b)) => *b,
            _ => false,
        };
        let expire_after_seconds = match index_spec.get("expireAfterSeconds") {
            Some(Bson::Int32(n)) => Some(*n as u64),
            Some(Bson::Int64(n)) => Some(*n as u64),
            _ => None,
        };
        let options = IndexOptions {
            name,
            unique,
            sparse,
            expire_after_seconds,
        };
        // Create the index. (No awaits occur while this guard is held.)
        let mut engine = ctx
            .indexes
            .entry(ns_key.clone())
            .or_insert_with(IndexEngine::new);
        match engine.create_index(key, options) {
            Ok(index_name) => {
                debug!(index_name = %index_name, "Created index");
                created_count += 1;
            }
            Err(e) => {
                return Err(CommandError::IndexError(e.to_string()));
            }
        }
    }
    // If we created indexes on an existing collection, rebuild from documents.
    // (A freshly auto-created collection is empty, so no rebuild is needed.)
    if created_count > 0 && !created_automatically {
        // Load all documents and rebuild indexes.
        if let Ok(all_docs) = ctx.storage.find_all(db, coll).await {
            if !all_docs.is_empty() {
                let mut engine = ctx
                    .indexes
                    .entry(ns_key.clone())
                    .or_insert_with(IndexEngine::new);
                engine.rebuild_from_documents(&all_docs);
            }
        }
    }
    let num_after = {
        let engine = ctx
            .indexes
            .entry(ns_key.clone())
            .or_insert_with(IndexEngine::new);
        engine.list_indexes().len() as i32
    };
    Ok(doc! {
        "createdCollectionAutomatically": created_automatically,
        "numIndexesBefore": num_before,
        "numIndexesAfter": num_after,
        "ok": 1.0,
    })
}
/// Handle the `dropIndexes` command.
///
/// The `index` field selects what to drop: the literal string `"*"` drops
/// everything except `_id_`, any other string drops by name, and a document
/// drops the index whose key spec matches exactly.
async fn handle_drop_indexes(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = cmd
        .get_str("dropIndexes")
        .map_err(|_| CommandError::InvalidArgument("missing 'dropIndexes' field".into()))?;
    let ns_key = format!("{}.{}", db, coll);
    // Get current index count (reported back as `nIndexesWas`).
    let n_indexes_was = {
        match ctx.indexes.get(&ns_key) {
            Some(engine) => engine.list_indexes().len() as i32,
            None => 1_i32, // At minimum the _id_ index.
        }
    };
    let index_spec = cmd.get("index");
    debug!(
        db = db,
        collection = coll,
        index_spec = ?index_spec,
        "dropIndexes command"
    );
    match index_spec {
        Some(Bson::String(name)) if name == "*" => {
            // Drop all indexes except _id_.
            if let Some(mut engine) = ctx.indexes.get_mut(&ns_key) {
                engine.drop_all_indexes();
            }
        }
        Some(Bson::String(name)) => {
            // Drop by name.
            if let Some(mut engine) = ctx.indexes.get_mut(&ns_key) {
                engine.drop_index(name).map_err(|e| {
                    CommandError::IndexError(e.to_string())
                })?;
            } else {
                // No engine is registered for this namespace at all.
                return Err(CommandError::IndexError(format!(
                    "index not found: {}",
                    name
                )));
            }
        }
        Some(Bson::Document(key_spec)) => {
            // Drop by key spec: find the index with matching key.
            if let Some(mut engine) = ctx.indexes.get_mut(&ns_key) {
                // Resolve the name first so the borrow from `list_indexes`
                // ends before `drop_index` mutates the engine.
                let index_name = engine
                    .list_indexes()
                    .iter()
                    .find(|info| info.key == *key_spec)
                    .map(|info| info.name.clone());
                if let Some(name) = index_name {
                    engine.drop_index(&name).map_err(|e| {
                        CommandError::IndexError(e.to_string())
                    })?;
                } else {
                    return Err(CommandError::IndexError(
                        "index not found with specified key".into(),
                    ));
                }
            } else {
                return Err(CommandError::IndexError(
                    "no indexes found for collection".into(),
                ));
            }
        }
        _ => {
            return Err(CommandError::InvalidArgument(
                "dropIndexes requires 'index' field (string, document, or \"*\")".into(),
            ));
        }
    }
    Ok(doc! {
        "nIndexesWas": n_indexes_was,
        "ok": 1.0,
    })
}
/// Handle the `listIndexes` command.
///
/// Returns the collection's index specs as a single-batch cursor document.
///
/// # Errors
///
/// `NamespaceNotFound` when the collection is known not to exist; a failed
/// existence check is tolerated and listing proceeds anyway.
async fn handle_list_indexes(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = cmd
        .get_str("listIndexes")
        .map_err(|_| CommandError::InvalidArgument("missing 'listIndexes' field".into()))?;
    // One namespace string serves as both the index-registry key and the
    // cursor's reported `ns` (the original built two identical strings).
    let ns = format!("{}.{}", db, coll);
    // Check if collection exists.
    match ctx.storage.collection_exists(db, coll).await {
        Ok(false) => {
            return Err(CommandError::NamespaceNotFound(format!(
                "ns not found: {}",
                ns
            )));
        }
        Err(_) => {
            // If we can't check, try to proceed anyway.
        }
        _ => {}
    }
    let indexes = match ctx.indexes.get(&ns) {
        Some(engine) => engine.list_indexes(),
        None => {
            // No engine registered yet: report the default index set of a
            // fresh engine (at least `_id_`).
            let engine = IndexEngine::new();
            engine.list_indexes()
        }
    };
    let first_batch: Vec<Bson> = indexes
        .into_iter()
        .map(|info| {
            // `spec` rather than `doc` to avoid shadowing the `doc!` macro name.
            let mut spec = doc! {
                "v": info.v,
                "key": info.key,
                "name": info.name,
            };
            // Optional attributes are emitted only when set.
            if info.unique {
                spec.insert("unique", true);
            }
            if info.sparse {
                spec.insert("sparse", true);
            }
            if let Some(ttl) = info.expire_after_seconds {
                spec.insert("expireAfterSeconds", ttl as i64);
            }
            Bson::Document(spec)
        })
        .collect();
    Ok(doc! {
        "cursor": {
            "id": 0_i64,
            "ns": &ns,
            "firstBatch": first_batch,
        },
        "ok": 1.0,
    })
}
/// Ensure the target database and collection exist. Returns true if the collection
/// was newly created (i.e., `createdCollectionAutomatically`).
///
/// "Already exists" outcomes from the storage backend are treated as success
/// so the operation is idempotent.
async fn ensure_collection_exists(
    db: &str,
    coll: &str,
    ctx: &CommandContext,
) -> CommandResult<bool> {
    // Backends signal duplicates through their error text; those errors are
    // benign for an idempotent ensure. Extracted to avoid triplicating the check.
    fn is_already_exists(msg: &str) -> bool {
        msg.contains("AlreadyExists") || msg.contains("already exists")
    }
    // Create database (ignore AlreadyExists).
    if let Err(e) = ctx.storage.create_database(db).await {
        let msg = e.to_string();
        if !is_already_exists(&msg) {
            return Err(CommandError::StorageError(msg));
        }
    }
    // Existing collection: nothing was created automatically.
    if let Ok(true) = ctx.storage.collection_exists(db, coll).await {
        return Ok(false);
    }
    // Missing collection — or a failed existence check — both fall through to
    // an attempted create, exactly as before.
    if let Err(e) = ctx.storage.create_collection(db, coll).await {
        let msg = e.to_string();
        if !is_already_exists(&msg) {
            return Err(CommandError::StorageError(msg));
        }
    }
    Ok(true)
}

View File

@@ -0,0 +1,185 @@
use std::collections::HashMap;
use bson::{doc, oid::ObjectId, Bson, Document};
use rustdb_index::IndexEngine;
use tracing::{debug, warn};
use crate::context::CommandContext;
use crate::error::{CommandError, CommandResult};
/// Handle the `insert` command.
///
/// Documents come either from the command body's `documents` array or from
/// an OP_MSG `documents` sequence. Missing `_id`s are auto-generated, the
/// target namespace is created on demand, and per-document failures are
/// reported through `writeErrors` (stopping early when `ordered`).
pub async fn handle(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
    document_sequences: Option<&HashMap<String, Vec<Document>>>,
) -> CommandResult<Document> {
    let coll = cmd
        .get_str("insert")
        .map_err(|_| CommandError::InvalidArgument("missing 'insert' field".into()))?;
    // Determine whether writes are ordered (default: true).
    let ordered = match cmd.get("ordered") {
        Some(Bson::Boolean(b)) => *b,
        _ => true,
    };
    // Collect documents from either the command body or OP_MSG document sequences.
    let docs: Vec<Document> = if let Some(seqs) = document_sequences {
        if let Some(seq_docs) = seqs.get("documents") {
            seq_docs.clone()
        } else {
            extract_docs_from_array(cmd)?
        }
    } else {
        extract_docs_from_array(cmd)?
    };
    if docs.is_empty() {
        return Err(CommandError::InvalidArgument(
            "no documents to insert".into(),
        ));
    }
    debug!(
        db = db,
        collection = coll,
        count = docs.len(),
        "insert command"
    );
    // Auto-create database and collection if they don't exist.
    ensure_collection_exists(db, coll, ctx).await?;
    let ns_key = format!("{}.{}", db, coll);
    let mut inserted_count: i32 = 0;
    let mut write_errors: Vec<Document> = Vec::new();
    for (idx, mut doc) in docs.into_iter().enumerate() {
        // Auto-generate _id if not present.
        if !doc.contains_key("_id") {
            doc.insert("_id", ObjectId::new());
        }
        // Attempt storage insert.
        match ctx.storage.insert_one(db, coll, doc.clone()).await {
            Ok(_id_str) => {
                // Update index engine. The DashMap entry guard lives only for
                // this arm; no awaits occur while it is held.
                let mut engine = ctx
                    .indexes
                    .entry(ns_key.clone())
                    .or_insert_with(IndexEngine::new);
                if let Err(e) = engine.on_insert(&doc) {
                    // The document is already persisted, so log rather than
                    // fail the whole insert.
                    warn!(
                        namespace = %ns_key,
                        error = %e,
                        "index update failed after successful insert"
                    );
                }
                inserted_count += 1;
            }
            Err(e) => {
                // Map storage failures onto MongoDB-style write-error codes.
                let err_msg = e.to_string();
                let (code, code_name) = if err_msg.contains("AlreadyExists")
                    || err_msg.contains("duplicate")
                {
                    (11000_i32, "DuplicateKey")
                } else {
                    (1_i32, "InternalError")
                };
                write_errors.push(doc! {
                    "index": idx as i32,
                    "code": code,
                    "codeName": code_name,
                    "errmsg": &err_msg,
                });
                if ordered {
                    // Stop on first error when ordered.
                    break;
                }
            }
        }
    }
    // Build response document.
    let mut response = doc! {
        "n": inserted_count,
        "ok": 1.0,
    };
    if !write_errors.is_empty() {
        response.insert(
            "writeErrors",
            write_errors
                .into_iter()
                .map(Bson::Document)
                .collect::<Vec<_>>(),
        );
    }
    Ok(response)
}
/// Extract documents from the `documents` array field in the command BSON.
///
/// A missing or wrong-typed field yields an empty vector; a non-document
/// array element is an `InvalidArgument` error.
fn extract_docs_from_array(cmd: &Document) -> CommandResult<Vec<Document>> {
    let arr = match cmd.get_array("documents") {
        Ok(arr) => arr,
        // Treat an absent array as "no documents supplied".
        Err(_) => return Ok(Vec::new()),
    };
    // Fallible collect short-circuits on the first non-document element.
    arr.iter()
        .map(|item| match item {
            Bson::Document(d) => Ok(d.clone()),
            _ => Err(CommandError::InvalidArgument(
                "documents array contains non-document element".into(),
            )),
        })
        .collect()
}
/// Ensure the target database and collection exist, creating them if needed.
///
/// "Already exists" outcomes from the storage backend are treated as success
/// so the operation is idempotent.
async fn ensure_collection_exists(
    db: &str,
    coll: &str,
    ctx: &CommandContext,
) -> CommandResult<()> {
    // Backends signal duplicates through their error text; those errors are
    // benign here. Extracted to avoid repeating the check three times.
    fn is_already_exists(msg: &str) -> bool {
        msg.contains("AlreadyExists") || msg.contains("already exists")
    }
    // Create database (no-op if it already exists in most backends).
    if let Err(e) = ctx.storage.create_database(db).await {
        let msg = e.to_string();
        if !is_already_exists(&msg) {
            return Err(CommandError::StorageError(msg));
        }
    }
    // Create collection if it doesn't exist.
    match ctx.storage.collection_exists(db, coll).await {
        Ok(true) => {}
        Ok(false) => {
            if let Err(e) = ctx.storage.create_collection(db, coll).await {
                let msg = e.to_string();
                if !is_already_exists(&msg) {
                    return Err(CommandError::StorageError(msg));
                }
            }
        }
        Err(e) => {
            // Database might not exist yet; try creating collection anyway.
            if let Err(e2) = ctx.storage.create_collection(db, coll).await {
                let msg = e2.to_string();
                if !is_already_exists(&msg) {
                    return Err(CommandError::StorageError(format!(
                        "collection_exists failed: {e}; create_collection failed: {msg}"
                    )));
                }
            }
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,8 @@
//! Per-command handler modules; each submodule implements one family of
//! database commands.
pub mod admin_handler;
pub mod aggregate_handler;
pub mod delete_handler;
pub mod find_handler;
pub mod hello_handler;
pub mod index_handler;
pub mod insert_handler;
pub mod update_handler;

View File

@@ -0,0 +1,617 @@
use std::collections::HashSet;
use bson::{doc, oid::ObjectId, Bson, Document};
use rustdb_index::IndexEngine;
use rustdb_query::{QueryMatcher, UpdateEngine, sort_documents, apply_projection};
use tracing::debug;
use crate::context::CommandContext;
use crate::error::{CommandError, CommandResult};
/// Handle `update` and `findAndModify` commands.
pub async fn handle(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
    command_name: &str,
) -> CommandResult<Document> {
    // `findAndModify` is accepted in either historical capitalization.
    if matches!(command_name, "findAndModify" | "findandmodify") {
        handle_find_and_modify(cmd, db, ctx).await
    } else {
        handle_update(cmd, db, ctx).await
    }
}
/// Handle the `update` command.
///
/// Processes each entry of the `updates` array, supporting `q`, `u`,
/// `multi`, `upsert`, and `arrayFilters`. Write failures are collected in
/// `writeErrors`; with `ordered: true` (the default) the first error stops
/// the entire batch — including errors raised deep inside a multi-update's
/// per-document loop, which previously only aborted that one statement.
async fn handle_update(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    let coll = cmd
        .get_str("update")
        .map_err(|_| CommandError::InvalidArgument("missing 'update' field".into()))?;
    let updates = cmd
        .get_array("updates")
        .map_err(|_| CommandError::InvalidArgument("missing 'updates' array".into()))?;
    let ordered = match cmd.get("ordered") {
        Some(Bson::Boolean(b)) => *b,
        _ => true,
    };
    debug!(db = db, collection = coll, count = updates.len(), "update command");
    // Auto-create database and collection if needed.
    ensure_collection_exists(db, coll, ctx).await?;
    let ns_key = format!("{}.{}", db, coll);
    let mut total_n: i32 = 0;
    let mut total_n_modified: i32 = 0;
    let mut upserted_list: Vec<Document> = Vec::new();
    let mut write_errors: Vec<Document> = Vec::new();
    // Labeled so failures inside the per-document loop below can abort the
    // whole batch when `ordered` is set (a plain `break` there only exited
    // the inner loop, so later update statements still ran in ordered mode).
    'updates: for (idx, update_bson) in updates.iter().enumerate() {
        let update_spec = match update_bson {
            Bson::Document(d) => d,
            _ => {
                write_errors.push(doc! {
                    "index": idx as i32,
                    "code": 14_i32,
                    "codeName": "TypeMismatch",
                    "errmsg": "update spec must be a document",
                });
                if ordered {
                    break;
                }
                continue;
            }
        };
        let filter = match update_spec.get("q") {
            Some(Bson::Document(d)) => d.clone(),
            _ => Document::new(),
        };
        let update = match update_spec.get("u") {
            Some(Bson::Document(d)) => d.clone(),
            Some(Bson::Array(_pipeline)) => {
                // Aggregation pipeline updates are not yet supported; treat as error.
                write_errors.push(doc! {
                    "index": idx as i32,
                    "code": 14_i32,
                    "codeName": "TypeMismatch",
                    "errmsg": "aggregation pipeline updates not yet supported",
                });
                if ordered {
                    break;
                }
                continue;
            }
            _ => {
                write_errors.push(doc! {
                    "index": idx as i32,
                    "code": 14_i32,
                    "codeName": "TypeMismatch",
                    "errmsg": "missing or invalid 'u' field in update spec",
                });
                if ordered {
                    break;
                }
                continue;
            }
        };
        let multi = match update_spec.get("multi") {
            Some(Bson::Boolean(b)) => *b,
            _ => false,
        };
        let upsert = match update_spec.get("upsert") {
            Some(Bson::Boolean(b)) => *b,
            _ => false,
        };
        // Non-document entries in arrayFilters are silently dropped.
        let array_filters: Option<Vec<Document>> =
            update_spec.get_array("arrayFilters").ok().map(|arr| {
                arr.iter()
                    .filter_map(|v| {
                        if let Bson::Document(d) = v {
                            Some(d.clone())
                        } else {
                            None
                        }
                    })
                    .collect()
            });
        // Load all documents and filter.
        let all_docs = load_filtered_docs(db, coll, &filter, &ns_key, ctx).await?;
        if all_docs.is_empty() && upsert {
            // Upsert: create a new document seeded from the filter's equality fields.
            let new_doc = build_upsert_doc(&filter);
            // Apply update operators or replacement.
            match UpdateEngine::apply_update(&new_doc, &update, array_filters.as_deref()) {
                Ok(mut updated) => {
                    // Apply $setOnInsert if present.
                    if let Some(Bson::Document(soi)) = update.get("$setOnInsert") {
                        UpdateEngine::apply_set_on_insert(&mut updated, soi);
                    }
                    // Ensure _id exists.
                    let new_id = if !updated.contains_key("_id") {
                        let oid = ObjectId::new();
                        updated.insert("_id", oid);
                        Bson::ObjectId(oid)
                    } else {
                        updated.get("_id").unwrap().clone()
                    };
                    // Insert the new document.
                    match ctx.storage.insert_one(db, coll, updated.clone()).await {
                        Ok(_) => {
                            // Update index.
                            let mut engine = ctx
                                .indexes
                                .entry(ns_key.clone())
                                .or_insert_with(IndexEngine::new);
                            let _ = engine.on_insert(&updated);
                            total_n += 1;
                            upserted_list.push(doc! {
                                "index": idx as i32,
                                "_id": new_id,
                            });
                        }
                        Err(e) => {
                            write_errors.push(doc! {
                                "index": idx as i32,
                                "code": 1_i32,
                                "codeName": "InternalError",
                                "errmsg": e.to_string(),
                            });
                            if ordered {
                                break;
                            }
                        }
                    }
                }
                Err(e) => {
                    write_errors.push(doc! {
                        "index": idx as i32,
                        "code": 14_i32,
                        "codeName": "TypeMismatch",
                        "errmsg": e.to_string(),
                    });
                    if ordered {
                        break;
                    }
                }
            }
        } else {
            // Update matched documents (only the first unless `multi` is set).
            let docs_to_update = if multi {
                all_docs
            } else {
                all_docs.into_iter().take(1).collect()
            };
            for matched_doc in &docs_to_update {
                match UpdateEngine::apply_update(
                    matched_doc,
                    &update,
                    array_filters.as_deref(),
                ) {
                    Ok(updated_doc) => {
                        let id_str = extract_id_string(matched_doc);
                        match ctx
                            .storage
                            .update_by_id(db, coll, &id_str, updated_doc.clone())
                            .await
                        {
                            Ok(()) => {
                                // Update index.
                                if let Some(mut engine) = ctx.indexes.get_mut(&ns_key) {
                                    let _ = engine.on_update(matched_doc, &updated_doc);
                                }
                                total_n += 1;
                                // Only count documents that actually changed.
                                if matched_doc != &updated_doc {
                                    total_n_modified += 1;
                                }
                            }
                            Err(e) => {
                                write_errors.push(doc! {
                                    "index": idx as i32,
                                    "code": 1_i32,
                                    "codeName": "InternalError",
                                    "errmsg": e.to_string(),
                                });
                                if ordered {
                                    // Abort the whole batch, not just this
                                    // statement's remaining documents.
                                    break 'updates;
                                }
                            }
                        }
                    }
                    Err(e) => {
                        write_errors.push(doc! {
                            "index": idx as i32,
                            "code": 14_i32,
                            "codeName": "TypeMismatch",
                            "errmsg": e.to_string(),
                        });
                        if ordered {
                            break 'updates;
                        }
                    }
                }
            }
        }
    }
    // Build response.
    let mut response = doc! {
        "n": total_n,
        "nModified": total_n_modified,
        "ok": 1.0,
    };
    if !upserted_list.is_empty() {
        response.insert(
            "upserted",
            upserted_list
                .into_iter()
                .map(Bson::Document)
                .collect::<Vec<_>>(),
        );
    }
    if !write_errors.is_empty() {
        response.insert(
            "writeErrors",
            write_errors
                .into_iter()
                .map(Bson::Document)
                .collect::<Vec<_>>(),
        );
    }
    Ok(response)
}
/// Handle the `findAndModify` command.
///
/// Finds at most one document matching `query` (after optional `sort`) and
/// either removes it (`remove: true`) or applies `update` to it, optionally
/// upserting. `new` selects whether the pre- or post-image is returned and
/// `fields` projects the returned document.
async fn handle_find_and_modify(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    // Both historical spellings of the command name are accepted.
    let coll = cmd
        .get_str("findAndModify")
        .or_else(|_| cmd.get_str("findandmodify"))
        .map_err(|_| CommandError::InvalidArgument("missing 'findAndModify' field".into()))?;
    let query = match cmd.get("query") {
        Some(Bson::Document(d)) => d.clone(),
        _ => Document::new(),
    };
    let sort = match cmd.get("sort") {
        Some(Bson::Document(d)) => Some(d.clone()),
        _ => None,
    };
    let update_doc = match cmd.get("update") {
        Some(Bson::Document(d)) => Some(d.clone()),
        _ => None,
    };
    let remove = match cmd.get("remove") {
        Some(Bson::Boolean(b)) => *b,
        _ => false,
    };
    // `new: true` returns the post-update document instead of the original.
    let return_new = match cmd.get("new") {
        Some(Bson::Boolean(b)) => *b,
        _ => false,
    };
    let upsert = match cmd.get("upsert") {
        Some(Bson::Boolean(b)) => *b,
        _ => false,
    };
    let fields = match cmd.get("fields") {
        Some(Bson::Document(d)) => Some(d.clone()),
        _ => None,
    };
    // Non-document entries in arrayFilters are silently dropped.
    let array_filters: Option<Vec<Document>> =
        cmd.get_array("arrayFilters").ok().map(|arr| {
            arr.iter()
                .filter_map(|v| {
                    if let Bson::Document(d) = v {
                        Some(d.clone())
                    } else {
                        None
                    }
                })
                .collect()
        });
    // Auto-create database and collection.
    ensure_collection_exists(db, coll, ctx).await?;
    let ns_key = format!("{}.{}", db, coll);
    // Load and filter documents.
    let mut matched = load_filtered_docs(db, coll, &query, &ns_key, ctx).await?;
    // Sort if specified.
    if let Some(ref sort_spec) = sort {
        sort_documents(&mut matched, sort_spec);
    }
    // Take the first match.
    let target = matched.into_iter().next();
    if remove {
        // Remove operation.
        if let Some(ref doc) = target {
            let id_str = extract_id_string(doc);
            ctx.storage.delete_by_id(db, coll, &id_str).await?;
            // Update index.
            if let Some(mut engine) = ctx.indexes.get_mut(&ns_key) {
                engine.on_delete(doc);
            }
            let value = apply_fields_projection(doc, &fields);
            return Ok(doc! {
                "value": value,
                "lastErrorObject": {
                    "n": 1_i32,
                    "updatedExisting": false,
                },
                "ok": 1.0,
            });
        } else {
            // Nothing matched: report n = 0 with a null value.
            return Ok(doc! {
                "value": Bson::Null,
                "lastErrorObject": {
                    "n": 0_i32,
                    "updatedExisting": false,
                },
                "ok": 1.0,
            });
        }
    }
    // Update operation.
    let update = match update_doc {
        Some(u) => u,
        None => {
            // Neither remove nor update requested: nothing to do.
            return Ok(doc! {
                "value": Bson::Null,
                "lastErrorObject": {
                    "n": 0_i32,
                    "updatedExisting": false,
                },
                "ok": 1.0,
            });
        }
    };
    if let Some(original_doc) = target {
        // Update the matched document.
        let updated_doc = UpdateEngine::apply_update(
            &original_doc,
            &update,
            array_filters.as_deref(),
        )
        .map_err(|e| CommandError::InternalError(e.to_string()))?;
        let id_str = extract_id_string(&original_doc);
        ctx.storage
            .update_by_id(db, coll, &id_str, updated_doc.clone())
            .await?;
        // Update index.
        if let Some(mut engine) = ctx.indexes.get_mut(&ns_key) {
            let _ = engine.on_update(&original_doc, &updated_doc);
        }
        // Choose pre- or post-image depending on `new`.
        let return_doc = if return_new {
            &updated_doc
        } else {
            &original_doc
        };
        let value = apply_fields_projection(return_doc, &fields);
        Ok(doc! {
            "value": value,
            "lastErrorObject": {
                "n": 1_i32,
                "updatedExisting": true,
            },
            "ok": 1.0,
        })
    } else if upsert {
        // Upsert: create a new document seeded from the query's equality fields.
        let new_doc = build_upsert_doc(&query);
        let mut updated_doc = UpdateEngine::apply_update(
            &new_doc,
            &update,
            array_filters.as_deref(),
        )
        .map_err(|e| CommandError::InternalError(e.to_string()))?;
        // Apply $setOnInsert if present.
        if let Some(Bson::Document(soi)) = update.get("$setOnInsert") {
            UpdateEngine::apply_set_on_insert(&mut updated_doc, soi);
        }
        // Ensure _id.
        let upserted_id = if !updated_doc.contains_key("_id") {
            let oid = ObjectId::new();
            updated_doc.insert("_id", oid);
            Bson::ObjectId(oid)
        } else {
            updated_doc.get("_id").unwrap().clone()
        };
        ctx.storage
            .insert_one(db, coll, updated_doc.clone())
            .await?;
        // Update index.
        {
            let mut engine = ctx
                .indexes
                .entry(ns_key.clone())
                .or_insert_with(IndexEngine::new);
            let _ = engine.on_insert(&updated_doc);
        }
        // With `new: false` there is no pre-existing document to return.
        let value = if return_new {
            apply_fields_projection(&updated_doc, &fields)
        } else {
            Bson::Null
        };
        Ok(doc! {
            "value": value,
            "lastErrorObject": {
                "n": 1_i32,
                "updatedExisting": false,
                "upserted": upserted_id,
            },
            "ok": 1.0,
        })
    } else {
        // No match and no upsert requested.
        Ok(doc! {
            "value": Bson::Null,
            "lastErrorObject": {
                "n": 0_i32,
                "updatedExisting": false,
            },
            "ok": 1.0,
        })
    }
}
// ---- Helpers ----
/// Load documents from storage, optionally using index for candidate narrowing, then filter.
///
/// The DashMap read guard is confined to the `and_then` closure, so it is
/// released before any storage await.
async fn load_filtered_docs(
    db: &str,
    coll: &str,
    filter: &Document,
    ns_key: &str,
    ctx: &CommandContext,
) -> CommandResult<Vec<Document>> {
    // Try to use the index engine to narrow candidates.
    let candidate_ids: Option<HashSet<String>> = ctx
        .indexes
        .get(ns_key)
        .and_then(|engine| engine.find_candidate_ids(filter));
    let docs = match candidate_ids {
        // The index proved nothing can match: skip storage entirely.
        Some(ids) if ids.is_empty() => return Ok(Vec::new()),
        Some(ids) => ctx.storage.find_by_ids(db, coll, ids).await?,
        None => ctx.storage.find_all(db, coll).await?,
    };
    // An empty filter matches everything, so skip the matcher pass.
    if filter.is_empty() {
        Ok(docs)
    } else {
        Ok(QueryMatcher::filter(&docs, filter))
    }
}
/// Build a base document for an upsert from the filter's equality conditions.
///
/// Top-level operators (`$and`, `$or`, ...) are ignored. For field values
/// that are operator documents (e.g. `{"$gt": 5}`), only an explicit `$eq`
/// contributes a seed value.
fn build_upsert_doc(filter: &Document) -> Document {
    let mut base = Document::new();
    for (field, value) in filter {
        // Skip top-level operators like $and / $or.
        if field.starts_with('$') {
            continue;
        }
        if let Bson::Document(inner) = value {
            if inner.keys().any(|k| k.starts_with('$')) {
                // Operator document: pull out an $eq value if one exists.
                if let Some(eq_val) = inner.get("$eq") {
                    base.insert(field.clone(), eq_val.clone());
                }
                continue;
            }
        }
        // Plain equality condition: copy it verbatim.
        base.insert(field.clone(), value.clone());
    }
    base
}
/// Extract _id as a string for storage operations.
///
/// ObjectIds become their hex form, strings pass through, and any other BSON
/// value uses its `Display` rendering. A missing `_id` yields an empty
/// string; callers invoke this on documents read back from storage, which
/// are expected to carry `_id` — TODO confirm against the storage adapter.
fn extract_id_string(doc: &Document) -> String {
    match doc.get("_id") {
        Some(Bson::ObjectId(oid)) => oid.to_hex(),
        Some(Bson::String(s)) => s.clone(),
        // `to_string` goes through the same `Display` impl `format!("{}", ..)` used.
        Some(other) => other.to_string(),
        None => String::new(),
    }
}
/// Apply fields projection if specified, returning Bson.
///
/// An absent or empty projection returns the document unchanged.
fn apply_fields_projection(doc: &Document, fields: &Option<Document>) -> Bson {
    let projected = match fields.as_ref() {
        Some(proj) if !proj.is_empty() => apply_projection(doc, proj),
        _ => doc.clone(),
    };
    Bson::Document(projected)
}
/// Return true when a storage error message describes an "already exists"
/// condition, which `ensure_collection_exists` treats as success.
fn is_already_exists_error(msg: &str) -> bool {
    msg.contains("AlreadyExists") || msg.contains("already exists")
}

/// Ensure the target database and collection exist, creating them if needed.
///
/// Creation races are tolerated: an "already exists" failure from storage is
/// treated as success so concurrent callers cannot fail each other. Any other
/// storage failure is surfaced as a `CommandError::StorageError`.
async fn ensure_collection_exists(
    db: &str,
    coll: &str,
    ctx: &CommandContext,
) -> CommandResult<()> {
    // Create the database, ignoring "already exists" races.
    if let Err(e) = ctx.storage.create_database(db).await {
        let msg = e.to_string();
        if !is_already_exists_error(&msg) {
            return Err(CommandError::StorageError(msg));
        }
    }
    match ctx.storage.collection_exists(db, coll).await {
        Ok(true) => {}
        Ok(false) => {
            // Create the collection, again tolerating concurrent creation.
            if let Err(e) = ctx.storage.create_collection(db, coll).await {
                let msg = e.to_string();
                if !is_already_exists_error(&msg) {
                    return Err(CommandError::StorageError(msg));
                }
            }
        }
        // The existence check itself failed: fall back to attempting the
        // create, and only surface an error when that also fails for a
        // reason other than "already exists".
        Err(e) => {
            if let Err(e2) = ctx.storage.create_collection(db, coll).await {
                let msg = e2.to_string();
                if !is_already_exists_error(&msg) {
                    return Err(CommandError::StorageError(format!(
                        "collection_exists failed: {e}; create_collection failed: {msg}"
                    )));
                }
            }
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,8 @@
mod context;
pub mod error;
pub mod handlers;
mod router;
pub use context::{CommandContext, CursorState};
pub use error::{CommandError, CommandResult};
pub use router::CommandRouter;

View File

@@ -0,0 +1,109 @@
use std::sync::Arc;
use bson::Document;
use tracing::{debug, warn};
use rustdb_wire::ParsedCommand;
use crate::context::CommandContext;
use crate::error::CommandError;
use crate::handlers;
/// Routes parsed wire protocol commands to the appropriate handler.
pub struct CommandRouter {
    /// Shared server context handed to every handler.
    ctx: Arc<CommandContext>,
}
impl CommandRouter {
    /// Create a new command router with the given context.
    pub fn new(ctx: Arc<CommandContext>) -> Self {
        Self { ctx }
    }
    /// Route a parsed command to the appropriate handler, returning a BSON response document.
    ///
    /// Handler errors are converted into MongoDB-style error documents rather
    /// than propagated, so this method always produces a response.
    pub async fn route(&self, cmd: &ParsedCommand) -> Document {
        let db = &cmd.database;
        let command_name = cmd.command_name.as_str();
        debug!(command = %command_name, database = %db, "routing command");
        // Touch the session when the command carries a logical session id.
        let session = cmd
            .command
            .get("lsid")
            .and_then(|lsid| rustdb_txn::SessionEngine::extract_session_id(lsid));
        if let Some(session_id) = session {
            self.ctx.sessions.get_or_create_session(&session_id);
        }
        let outcome = match command_name {
            // -- handshake / monitoring --
            "hello" | "ismaster" | "isMaster" => {
                handlers::hello_handler::handle(&cmd.command, db, &self.ctx).await
            }
            // -- query commands --
            "find" => {
                handlers::find_handler::handle(&cmd.command, db, &self.ctx).await
            }
            "getMore" => {
                handlers::find_handler::handle_get_more(&cmd.command, db, &self.ctx).await
            }
            "killCursors" => {
                handlers::find_handler::handle_kill_cursors(&cmd.command, db, &self.ctx).await
            }
            "count" => {
                handlers::find_handler::handle_count(&cmd.command, db, &self.ctx).await
            }
            "distinct" => {
                handlers::find_handler::handle_distinct(&cmd.command, db, &self.ctx).await
            }
            // -- write commands --
            "insert" => {
                handlers::insert_handler::handle(&cmd.command, db, &self.ctx, cmd.document_sequences.as_ref()).await
            }
            "update" | "findAndModify" => {
                handlers::update_handler::handle(&cmd.command, db, &self.ctx, command_name).await
            }
            "delete" => {
                handlers::delete_handler::handle(&cmd.command, db, &self.ctx).await
            }
            // -- aggregation --
            "aggregate" => {
                handlers::aggregate_handler::handle(&cmd.command, db, &self.ctx).await
            }
            // -- index management --
            "createIndexes" | "dropIndexes" | "listIndexes" => {
                handlers::index_handler::handle(&cmd.command, db, &self.ctx, command_name).await
            }
            // -- admin commands --
            "ping" | "buildInfo" | "buildinfo" | "serverStatus" | "hostInfo"
            | "whatsmyuri" | "getLog" | "replSetGetStatus" | "getCmdLineOpts"
            | "getParameter" | "getFreeMonitoringStatus" | "setFreeMonitoring"
            | "getShardMap" | "shardingState" | "atlasVersion"
            | "connectionStatus" | "listDatabases" | "listCollections"
            | "create" | "drop" | "dropDatabase" | "renameCollection"
            | "dbStats" | "collStats" | "validate" | "explain"
            | "startSession" | "endSessions" | "killSessions"
            | "commitTransaction" | "abortTransaction"
            | "saslStart" | "saslContinue" | "authenticate" | "logout"
            | "currentOp" | "killOp" | "top" | "profile"
            | "compact" | "reIndex" | "fsync" | "connPoolSync" => {
                handlers::admin_handler::handle(&cmd.command, db, &self.ctx, command_name).await
            }
            // -- unknown command --
            other => {
                warn!(command = %other, "unknown command");
                Err(CommandError::NotImplemented(other.to_string()))
            }
        };
        // Convert handler errors into wire-level error documents.
        outcome.unwrap_or_else(|e| e.to_error_doc())
    }
}

View File

@@ -0,0 +1,12 @@
[package]
name = "rustdb-config"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Configuration types for RustDb, compatible with SmartDB JSON schema"
[dependencies]
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }

View File

@@ -0,0 +1,181 @@
use serde::{Deserialize, Serialize};
/// Storage backend type.
///
/// Serialized in lowercase ("memory" / "file") to match the JSON schema of
/// the TypeScript SmartDB configuration.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum StorageType {
    /// Keep all data in process memory (the default).
    #[default]
    Memory,
    /// Persist data to files under `storage_path`.
    File,
}
/// Top-level configuration for RustDb server.
/// Field names use camelCase to match the TypeScript SmartdbServer options.
///
/// Unknown JSON keys are ignored during deserialization (serde's default
/// behavior, since `deny_unknown_fields` is not set), which keeps the config
/// forward-compatible with newer schemas.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RustDbOptions {
    /// TCP port to listen on (default: 27017)
    #[serde(default = "default_port")]
    pub port: u16,
    /// Host/IP to bind to (default: "127.0.0.1")
    #[serde(default = "default_host")]
    pub host: String,
    /// Unix socket path (overrides TCP if set)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub socket_path: Option<String>,
    /// Storage backend type
    #[serde(default)]
    pub storage: StorageType,
    /// Base path for file storage (required when storage = "file");
    /// enforced by `validate()`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub storage_path: Option<String>,
    /// Path for periodic persistence of in-memory data
    #[serde(skip_serializing_if = "Option::is_none")]
    pub persist_path: Option<String>,
    /// Interval in ms for periodic persistence (default: 60000)
    #[serde(default = "default_persist_interval")]
    pub persist_interval_ms: u64,
}
/// Default TCP port; matches the conventional mongodb:// port.
fn default_port() -> u16 {
    27017
}
/// Default bind address: loopback only.
fn default_host() -> String {
    String::from("127.0.0.1")
}
/// Default persistence interval: one minute, in milliseconds.
fn default_persist_interval() -> u64 {
    60_000
}
// Programmatic defaults delegate to the same default_* helpers used by the
// serde `default` attributes, so `RustDbOptions::default()` and deserializing
// an empty JSON object `{}` produce identical values.
impl Default for RustDbOptions {
    fn default() -> Self {
        Self {
            port: default_port(),
            host: default_host(),
            socket_path: None,
            storage: StorageType::default(),
            storage_path: None,
            persist_path: None,
            persist_interval_ms: default_persist_interval(),
        }
    }
}
impl RustDbOptions {
    /// Load options from a JSON config file.
    ///
    /// Accepts anything path-like (`&str`, `String`, `&Path`, `PathBuf`);
    /// existing `&str` callers are unaffected. The parsed options are
    /// validated before being returned.
    ///
    /// # Errors
    /// Returns `ConfigError::IoError` if the file cannot be read,
    /// `ConfigError::ParseError` if it is not valid JSON for this schema,
    /// and `ConfigError::ValidationError` if the options are inconsistent.
    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, ConfigError> {
        let content = std::fs::read_to_string(path)
            .map_err(|e| ConfigError::IoError(e.to_string()))?;
        let options: Self = serde_json::from_str(&content)
            .map_err(|e| ConfigError::ParseError(e.to_string()))?;
        options.validate()?;
        Ok(options)
    }
    /// Validate the configuration.
    ///
    /// Currently the only cross-field rule: file storage requires a
    /// `storagePath`.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.storage == StorageType::File && self.storage_path.is_none() {
            return Err(ConfigError::ValidationError(
                "storagePath is required when storage is 'file'".to_string(),
            ));
        }
        Ok(())
    }
    /// Get the connection URI for this server configuration.
    ///
    /// A configured Unix socket path takes precedence over host:port and is
    /// percent-encoded so it can be embedded in the URI host position.
    pub fn connection_uri(&self) -> String {
        if let Some(ref socket_path) = self.socket_path {
            let encoded = urlencoding(socket_path);
            format!("mongodb://{}", encoded)
        } else {
            format!("mongodb://{}:{}", self.host, self.port)
        }
    }
}
/// Percent-encode a socket path for embedding in a mongodb:// URI.
///
/// Escapes '%' itself plus the characters that are meaningful in the URI
/// host position ('/', ':', ' '). Escaping '%' fixes an ambiguity in the
/// original implementation: a literal "%2F" in the input path was
/// indistinguishable from an encoded '/'.
fn urlencoding(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            // '%' must be escaped so pre-existing percent signs round-trip.
            '%' => out.push_str("%25"),
            '/' => out.push_str("%2F"),
            ':' => out.push_str("%3A"),
            ' ' => out.push_str("%20"),
            other => out.push(other),
        }
    }
    out
}
/// Configuration errors.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// Reading the config file from disk failed.
    #[error("IO error: {0}")]
    IoError(String),
    /// The file contents were not valid JSON for `RustDbOptions`.
    #[error("Parse error: {0}")]
    ParseError(String),
    /// The parsed options are inconsistent (see `RustDbOptions::validate`).
    #[error("Validation error: {0}")]
    ValidationError(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    // Defaults produced by Default must match the documented values.
    #[test]
    fn test_default_options() {
        let opts = RustDbOptions::default();
        assert_eq!(opts.port, 27017);
        assert_eq!(opts.host, "127.0.0.1");
        assert!(opts.socket_path.is_none());
        assert_eq!(opts.storage, StorageType::Memory);
    }
    // camelCase JSON keys must map onto the snake_case struct fields.
    #[test]
    fn test_deserialize_from_json() {
        let json = r#"{"port": 27018, "storage": "file", "storagePath": "./data"}"#;
        let opts: RustDbOptions = serde_json::from_str(json).unwrap();
        assert_eq!(opts.port, 27018);
        assert_eq!(opts.storage, StorageType::File);
        assert_eq!(opts.storage_path, Some("./data".to_string()));
    }
    #[test]
    fn test_connection_uri_tcp() {
        let opts = RustDbOptions::default();
        assert_eq!(opts.connection_uri(), "mongodb://127.0.0.1:27017");
    }
    // Socket paths take precedence over host:port and are percent-encoded.
    #[test]
    fn test_connection_uri_socket() {
        let opts = RustDbOptions {
            socket_path: Some("/tmp/smartdb-test.sock".to_string()),
            ..Default::default()
        };
        assert_eq!(
            opts.connection_uri(),
            "mongodb://%2Ftmp%2Fsmartdb-test.sock"
        );
    }
    #[test]
    fn test_validation_file_storage_requires_path() {
        let opts = RustDbOptions {
            storage: StorageType::File,
            storage_path: None,
            ..Default::default()
        };
        assert!(opts.validate().is_err());
    }
}

View File

@@ -0,0 +1,15 @@
[package]
name = "rustdb-index"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "MongoDB-compatible B-tree and hash index engine with query planner for RustDb"
[dependencies]
bson = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
rustdb-query = { workspace = true }

View File

@@ -0,0 +1,691 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use bson::{Bson, Document};
use tracing::{debug, trace};
use rustdb_query::get_nested_value;
use crate::error::IndexError;
/// Options for creating an index.
#[derive(Debug, Clone, Default)]
pub struct IndexOptions {
    /// Custom name for the index. Auto-generated from the key spec
    /// (e.g. `{"name": 1}` -> "name_1") if None.
    pub name: Option<String>,
    /// Whether the index enforces unique values.
    pub unique: bool,
    /// Whether the index skips documents missing the indexed field.
    pub sparse: bool,
    /// TTL in seconds (for date fields). None means no expiry.
    pub expire_after_seconds: Option<u64>,
}
/// Metadata about an existing index.
///
/// This is the read-only view returned by `IndexEngine::list_indexes`;
/// it carries no index data, only the specification.
#[derive(Debug, Clone)]
pub struct IndexInfo {
    /// Index version (always 2).
    pub v: i32,
    /// The key specification document (e.g. {"name": 1}).
    pub key: Document,
    /// The index name.
    pub name: String,
    /// Whether the index enforces uniqueness.
    pub unique: bool,
    /// Whether the index is sparse.
    pub sparse: bool,
    /// TTL expiry in seconds, if set.
    pub expire_after_seconds: Option<u64>,
}
/// Internal data for a single index.
///
/// The same key-bytes -> id mapping is maintained twice: in `btree` for
/// ordered range scans and in `hash` for equality lookups. The engine's
/// on_insert/on_update/on_delete keep both structures in step.
struct IndexData {
    /// The key specification (field -> direction).
    key: Document,
    /// The index name.
    name: String,
    /// Whether uniqueness is enforced.
    unique: bool,
    /// Whether the index is sparse.
    sparse: bool,
    /// TTL in seconds.
    expire_after_seconds: Option<u64>,
    /// B-tree for range queries: serialized key bytes -> set of document _id hex strings.
    btree: BTreeMap<Vec<u8>, BTreeSet<String>>,
    /// Hash map for equality lookups: serialized key bytes -> set of document _id hex strings.
    hash: HashMap<Vec<u8>, HashSet<String>>,
}
impl IndexData {
fn new(key: Document, name: String, unique: bool, sparse: bool, expire_after_seconds: Option<u64>) -> Self {
Self {
key,
name,
unique,
sparse,
expire_after_seconds,
btree: BTreeMap::new(),
hash: HashMap::new(),
}
}
fn to_info(&self) -> IndexInfo {
IndexInfo {
v: 2,
key: self.key.clone(),
name: self.name.clone(),
unique: self.unique,
sparse: self.sparse,
expire_after_seconds: self.expire_after_seconds,
}
}
}
/// Manages indexes for a single collection.
///
/// Owns every index (including the implicit `_id_` index created in `new`)
/// and updates their B-tree and hash structures on insert/update/delete.
pub struct IndexEngine {
    /// All indexes keyed by name.
    indexes: HashMap<String, IndexData>,
}
impl IndexEngine {
    /// Create a new IndexEngine with the default `_id_` index.
    pub fn new() -> Self {
        let mut indexes = HashMap::new();
        let id_key = bson::doc! { "_id": 1 };
        // The _id_ index is always present, unique, and non-sparse.
        let id_index = IndexData::new(id_key, "_id_".to_string(), true, false, None);
        indexes.insert("_id_".to_string(), id_index);
        Self { indexes }
    }
    /// Create a new index. Returns the index name.
    ///
    /// Creating an index whose name already exists is a no-op that returns
    /// the existing name (options of the existing index are NOT compared).
    pub fn create_index(&mut self, key: Document, options: IndexOptions) -> Result<String, IndexError> {
        if key.is_empty() {
            return Err(IndexError::InvalidIndex("Index key must have at least one field".to_string()));
        }
        let name = options.name.unwrap_or_else(|| Self::generate_index_name(&key));
        if self.indexes.contains_key(&name) {
            debug!(index_name = %name, "Index already exists, returning existing");
            return Ok(name);
        }
        debug!(index_name = %name, unique = options.unique, sparse = options.sparse, "Creating index");
        let index_data = IndexData::new(
            key,
            name.clone(),
            options.unique,
            options.sparse,
            options.expire_after_seconds,
        );
        self.indexes.insert(name.clone(), index_data);
        Ok(name)
    }
    /// Drop an index by name. Returns true if the index existed.
    /// Cannot drop the `_id_` index.
    pub fn drop_index(&mut self, name: &str) -> Result<bool, IndexError> {
        if name == "_id_" {
            return Err(IndexError::ProtectedIndex("_id_".to_string()));
        }
        let existed = self.indexes.remove(name).is_some();
        if existed {
            debug!(index_name = %name, "Dropped index");
        }
        Ok(existed)
    }
    /// Drop all indexes except `_id_`.
    pub fn drop_all_indexes(&mut self) {
        self.indexes.retain(|name, _| name == "_id_");
        debug!("Dropped all non-_id indexes");
    }
    /// List all indexes.
    pub fn list_indexes(&self) -> Vec<IndexInfo> {
        self.indexes.values().map(|idx| idx.to_info()).collect()
    }
    /// Check whether an index with the given name exists.
    pub fn index_exists(&self, name: &str) -> bool {
        self.indexes.contains_key(name)
    }
    /// Notify the engine that a document has been inserted.
    /// Checks unique constraints and updates all index structures.
    ///
    /// Uses two passes so that a failed unique check leaves every index
    /// untouched (no partial insertion to roll back).
    pub fn on_insert(&mut self, doc: &Document) -> Result<(), IndexError> {
        let doc_id = Self::extract_id(doc);
        // First pass: check unique constraints
        for idx in self.indexes.values() {
            if idx.unique {
                let key_bytes = Self::extract_key_bytes(doc, &idx.key, idx.sparse);
                if let Some(ref kb) = key_bytes {
                    if let Some(existing_ids) = idx.hash.get(kb) {
                        if !existing_ids.is_empty() {
                            return Err(IndexError::DuplicateKey {
                                index: idx.name.clone(),
                                key: format!("{:?}", kb),
                            });
                        }
                    }
                }
            }
        }
        // Second pass: insert into all indexes
        for idx in self.indexes.values_mut() {
            let key_bytes = Self::extract_key_bytes(doc, &idx.key, idx.sparse);
            if let Some(kb) = key_bytes {
                idx.btree.entry(kb.clone()).or_default().insert(doc_id.clone());
                idx.hash.entry(kb).or_default().insert(doc_id.clone());
            }
        }
        trace!(doc_id = %doc_id, "Indexed document on insert");
        Ok(())
    }
    /// Notify the engine that a document has been updated.
    pub fn on_update(&mut self, old_doc: &Document, new_doc: &Document) -> Result<(), IndexError> {
        let doc_id = Self::extract_id(old_doc);
        // Check unique constraints for the new document (excluding the document itself)
        for idx in self.indexes.values() {
            if idx.unique {
                let new_key_bytes = Self::extract_key_bytes(new_doc, &idx.key, idx.sparse);
                if let Some(ref kb) = new_key_bytes {
                    if let Some(existing_ids) = idx.hash.get(kb) {
                        // If there are existing entries that aren't this document, it's a conflict
                        let other_ids: HashSet<_> = existing_ids.iter()
                            .filter(|id| **id != doc_id)
                            .collect();
                        if !other_ids.is_empty() {
                            return Err(IndexError::DuplicateKey {
                                index: idx.name.clone(),
                                key: format!("{:?}", kb),
                            });
                        }
                    }
                }
            }
        }
        // Remove old entries and insert new ones
        for idx in self.indexes.values_mut() {
            let old_key_bytes = Self::extract_key_bytes(old_doc, &idx.key, idx.sparse);
            if let Some(ref kb) = old_key_bytes {
                if let Some(set) = idx.btree.get_mut(kb) {
                    set.remove(&doc_id);
                    if set.is_empty() {
                        idx.btree.remove(kb);
                    }
                }
                if let Some(set) = idx.hash.get_mut(kb) {
                    set.remove(&doc_id);
                    if set.is_empty() {
                        idx.hash.remove(kb);
                    }
                }
            }
            let new_key_bytes = Self::extract_key_bytes(new_doc, &idx.key, idx.sparse);
            if let Some(kb) = new_key_bytes {
                idx.btree.entry(kb.clone()).or_default().insert(doc_id.clone());
                idx.hash.entry(kb).or_default().insert(doc_id.clone());
            }
        }
        trace!(doc_id = %doc_id, "Re-indexed document on update");
        Ok(())
    }
    /// Notify the engine that a document has been deleted.
    pub fn on_delete(&mut self, doc: &Document) {
        let doc_id = Self::extract_id(doc);
        for idx in self.indexes.values_mut() {
            let key_bytes = Self::extract_key_bytes(doc, &idx.key, idx.sparse);
            if let Some(ref kb) = key_bytes {
                if let Some(set) = idx.btree.get_mut(kb) {
                    set.remove(&doc_id);
                    if set.is_empty() {
                        idx.btree.remove(kb);
                    }
                }
                if let Some(set) = idx.hash.get_mut(kb) {
                    set.remove(&doc_id);
                    if set.is_empty() {
                        idx.hash.remove(kb);
                    }
                }
            }
        }
        trace!(doc_id = %doc_id, "Removed document from indexes");
    }
    /// Attempt to find candidate document IDs using indexes for the given filter.
    /// Returns `None` if no suitable index is found (meaning a COLLSCAN is needed).
    /// Returns `Some(set)` with candidate IDs that should be checked against the full filter.
    pub fn find_candidate_ids(&self, filter: &Document) -> Option<HashSet<String>> {
        if filter.is_empty() {
            return None;
        }
        // Try each index to see which can serve this query
        let mut best_candidates: Option<HashSet<String>> = None;
        let mut best_score: f64 = 0.0;
        for idx in self.indexes.values() {
            if let Some((candidates, score)) = self.try_index_lookup(idx, filter) {
                if score > best_score {
                    best_score = score;
                    best_candidates = Some(candidates);
                }
            }
        }
        best_candidates
    }
    /// Rebuild all indexes from a full set of documents.
    ///
    /// NOTE: unique constraints are not re-checked here; the documents are
    /// assumed to already satisfy them (they came from storage).
    pub fn rebuild_from_documents(&mut self, docs: &[Document]) {
        // Clear all index data
        for idx in self.indexes.values_mut() {
            idx.btree.clear();
            idx.hash.clear();
        }
        // Re-index all documents
        for doc in docs {
            let doc_id = Self::extract_id(doc);
            for idx in self.indexes.values_mut() {
                let key_bytes = Self::extract_key_bytes(doc, &idx.key, idx.sparse);
                if let Some(kb) = key_bytes {
                    idx.btree.entry(kb.clone()).or_default().insert(doc_id.clone());
                    idx.hash.entry(kb).or_default().insert(doc_id.clone());
                }
            }
        }
        debug!(num_docs = docs.len(), num_indexes = self.indexes.len(), "Rebuilt all indexes");
    }
    // ---- Internal helpers ----
    /// Try to use an index for the given filter. Returns candidate IDs and a score.
    ///
    /// BUG FIX: conditions whose lookup scores 0.0 (non-indexable operators
    /// such as $ne/$nin/$exists/$regex) used to contribute an *empty*
    /// candidate set that was then intersected with — or, when a unique-index
    /// bonus pushed the total score above zero, returned directly — causing
    /// queries like `{email: {"$ne": "x"}}` on a unique index to wrongly
    /// produce zero candidates (and thus zero results). Such conditions are
    /// now skipped entirely; they neither restrict candidates nor count as
    /// an index match.
    ///
    /// NOTE(review): the unique bonus is added once per matched field here,
    /// while QueryPlanner::plan adds it once per index — confirm which
    /// weighting is intended and align the two.
    fn try_index_lookup(&self, idx: &IndexData, filter: &Document) -> Option<(HashSet<String>, f64)> {
        let index_fields: Vec<String> = idx.key.keys().map(|k| k.to_string()).collect();
        // Check if the filter uses fields covered by this index
        let mut matched_any = false;
        let mut result_set: Option<HashSet<String>> = None;
        let mut total_score: f64 = 0.0;
        for field in &index_fields {
            if let Some(condition) = filter.get(field) {
                let (candidates, score) = self.lookup_field(idx, field, condition);
                // A zero score means the condition used only non-indexable
                // operators; its empty candidate set is meaningless and must
                // not restrict the result.
                if score == 0.0 {
                    continue;
                }
                matched_any = true;
                total_score += score;
                // Add unique bonus
                if idx.unique {
                    total_score += 0.5;
                }
                result_set = Some(match result_set {
                    Some(existing) => existing.intersection(&candidates).cloned().collect(),
                    None => candidates,
                });
            }
        }
        if !matched_any {
            return None;
        }
        result_set.map(|rs| (rs, total_score))
    }
    /// Look up candidates for a single field condition in an index.
    ///
    /// The returned set may be a superset of the true matches; callers
    /// re-check candidates against the full filter.
    fn lookup_field(&self, idx: &IndexData, field: &str, condition: &Bson) -> (HashSet<String>, f64) {
        match condition {
            // Operator document (e.g. {"$gt": 5}) — delegate per operator.
            Bson::Document(cond_doc) if Self::has_operators(cond_doc) => {
                self.lookup_operator(idx, field, cond_doc)
            }
            // Direct equality
            _ => {
                let key_bytes = Self::bson_to_key_bytes(condition);
                let candidates = idx.hash
                    .get(&key_bytes)
                    .cloned()
                    .unwrap_or_default();
                (candidates, 2.0) // equality score
            }
        }
    }
    /// Handle operator-based lookups ($eq, $in, $gt, $lt, etc.).
    ///
    /// Returns a candidate set and a selectivity score; a score of 0.0 means
    /// no indexable operator was present and the set must be ignored.
    fn lookup_operator(&self, idx: &IndexData, field: &str, operators: &Document) -> (HashSet<String>, f64) {
        let mut candidates = HashSet::new();
        let mut score: f64 = 0.0;
        let mut has_range = false;
        for (op, value) in operators {
            match op.as_str() {
                "$eq" => {
                    let key_bytes = Self::bson_to_key_bytes(value);
                    if let Some(ids) = idx.hash.get(&key_bytes) {
                        // "empty means first" heuristic: may over-approximate
                        // when an earlier operator legitimately matched
                        // nothing; safe because callers re-filter candidates.
                        candidates = if candidates.is_empty() {
                            ids.clone()
                        } else {
                            candidates.intersection(ids).cloned().collect()
                        };
                    }
                    score += 2.0;
                }
                "$in" => {
                    if let Bson::Array(arr) = value {
                        // Union of equality lookups for each array element.
                        let mut in_candidates = HashSet::new();
                        for v in arr {
                            let key_bytes = Self::bson_to_key_bytes(v);
                            if let Some(ids) = idx.hash.get(&key_bytes) {
                                in_candidates.extend(ids.iter().cloned());
                            }
                        }
                        candidates = if candidates.is_empty() {
                            in_candidates
                        } else {
                            candidates.intersection(&in_candidates).cloned().collect()
                        };
                        score += 1.5;
                    }
                }
                "$gt" | "$gte" | "$lt" | "$lte" => {
                    let range_candidates = self.range_scan(idx, field, op.as_str(), value);
                    candidates = if candidates.is_empty() && !has_range {
                        range_candidates
                    } else {
                        candidates.intersection(&range_candidates).cloned().collect()
                    };
                    has_range = true;
                    score += 1.0;
                }
                _ => {
                    // Operators like $ne, $nin, $exists, $regex are not
                    // efficiently indexable: contribute no candidates and no
                    // score. The caller treats a 0.0 score as "index unusable
                    // for this condition".
                }
            }
        }
        // If we only had non-indexable operators, return empty with 0 score
        if score == 0.0 {
            return (HashSet::new(), 0.0);
        }
        (candidates, score)
    }
    /// Perform a range scan on the B-tree index.
    ///
    /// NOTE(review): this scan is only as correct as the ordering of
    /// `bson_to_key_bytes` output (see that method), and for compound
    /// indexes the single-value `bound` bytes do not align with the
    /// length-prefixed compound keys — confirm range queries are only
    /// planned against single-field indexes, or restrict them here.
    fn range_scan(&self, idx: &IndexData, _field: &str, op: &str, bound: &Bson) -> HashSet<String> {
        let bound_bytes = Self::bson_to_key_bytes(bound);
        let mut result = HashSet::new();
        match op {
            "$gt" => {
                use std::ops::Bound;
                for (_key, ids) in idx.btree.range((Bound::Excluded(bound_bytes), Bound::Unbounded)) {
                    result.extend(ids.iter().cloned());
                }
            }
            "$gte" => {
                for (_key, ids) in idx.btree.range(bound_bytes..) {
                    result.extend(ids.iter().cloned());
                }
            }
            "$lt" => {
                for (_key, ids) in idx.btree.range(..bound_bytes) {
                    result.extend(ids.iter().cloned());
                }
            }
            "$lte" => {
                for (_key, ids) in idx.btree.range(..=bound_bytes) {
                    result.extend(ids.iter().cloned());
                }
            }
            _ => {}
        }
        result
    }
    /// Generate an index name from the key spec (e.g. {"name": 1, "age": -1} -> "name_1_age_-1").
    fn generate_index_name(key: &Document) -> String {
        key.iter()
            .map(|(field, dir)| {
                let dir_val = match dir {
                    Bson::Int32(n) => n.to_string(),
                    Bson::Int64(n) => n.to_string(),
                    Bson::String(s) => s.clone(),
                    _ => "1".to_string(),
                };
                format!("{}_{}", field, dir_val)
            })
            .collect::<Vec<_>>()
            .join("_")
    }
    /// Extract the `_id` field from a document as a hex string.
    fn extract_id(doc: &Document) -> String {
        match doc.get("_id") {
            Some(Bson::ObjectId(oid)) => oid.to_hex(),
            Some(Bson::String(s)) => s.clone(),
            Some(other) => format!("{}", other),
            None => String::new(),
        }
    }
    /// Extract the index key bytes from a document for a given key specification.
    /// Returns `None` if the document should be skipped (sparse index with missing fields).
    fn extract_key_bytes(doc: &Document, key_spec: &Document, sparse: bool) -> Option<Vec<u8>> {
        let fields: Vec<(&str, &Bson)> = key_spec.iter().map(|(k, v)| (k.as_str(), v)).collect();
        if fields.len() == 1 {
            // Single-field index: missing fields index as Null unless sparse.
            let field = fields[0].0;
            let value = Self::resolve_field_value(doc, field);
            if sparse && value.is_none() {
                return None;
            }
            let val = value.unwrap_or(Bson::Null);
            Some(Self::bson_to_key_bytes(&val))
        } else {
            // Compound index: concatenate field values.
            // A sparse compound index skips the document only when *every*
            // field is missing.
            let mut all_null = true;
            let mut compound_bytes = Vec::new();
            for (field, _dir) in &fields {
                let value = Self::resolve_field_value(doc, field);
                if value.is_some() {
                    all_null = false;
                }
                let val = value.unwrap_or(Bson::Null);
                let field_bytes = Self::bson_to_key_bytes(&val);
                // Length-prefix each field for unambiguous concatenation
                compound_bytes.extend_from_slice(&(field_bytes.len() as u32).to_be_bytes());
                compound_bytes.extend_from_slice(&field_bytes);
            }
            if sparse && all_null {
                return None;
            }
            Some(compound_bytes)
        }
    }
    /// Resolve a field value from a document, supporting dot notation.
    fn resolve_field_value(doc: &Document, field: &str) -> Option<Bson> {
        if field.contains('.') {
            get_nested_value(doc, field)
        } else {
            doc.get(field).cloned()
        }
    }
    /// Serialize a BSON value to bytes for use as an index key.
    ///
    /// NOTE(review): the byte-lexicographic order of this encoding does NOT
    /// match BSON value order for multi-byte numbers (BSON stores integers
    /// and doubles little-endian), so B-tree range scans can mis-order
    /// values that differ above the low byte (e.g. 255 vs 256). Equality
    /// and $in lookups are unaffected. Consider switching to an
    /// order-preserving key encoding before relying on indexed range scans.
    fn bson_to_key_bytes(value: &Bson) -> Vec<u8> {
        // Use BSON raw serialization for consistent byte representation.
        // We wrap in a document since raw BSON requires a top-level document.
        let wrapper = bson::doc! { "k": value.clone() };
        let raw = bson::to_vec(&wrapper).unwrap_or_default();
        raw
    }
    /// True when any key of `doc` is a `$`-prefixed operator.
    fn has_operators(doc: &Document) -> bool {
        doc.keys().any(|k| k.starts_with('$'))
    }
}
impl Default for IndexEngine {
    /// Equivalent to [`IndexEngine::new`]: an engine containing only the
    /// implicit `_id_` index.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bson::oid::ObjectId;
    /// Build a test document with a fresh ObjectId `_id`.
    fn make_doc(name: &str, age: i32) -> Document {
        bson::doc! {
            "_id": ObjectId::new(),
            "name": name,
            "age": age,
        }
    }
    #[test]
    fn test_default_id_index() {
        let engine = IndexEngine::new();
        assert!(engine.index_exists("_id_"));
        assert_eq!(engine.list_indexes().len(), 1);
    }
    // Name is auto-generated from the key spec when not provided.
    #[test]
    fn test_create_and_drop_index() {
        let mut engine = IndexEngine::new();
        let name = engine.create_index(
            bson::doc! { "name": 1 },
            IndexOptions::default(),
        ).unwrap();
        assert_eq!(name, "name_1");
        assert!(engine.index_exists("name_1"));
        assert!(engine.drop_index("name_1").unwrap());
        assert!(!engine.index_exists("name_1"));
    }
    #[test]
    fn test_cannot_drop_id_index() {
        let mut engine = IndexEngine::new();
        let result = engine.drop_index("_id_");
        assert!(result.is_err());
    }
    // Inserting a second document with the same unique key must fail.
    #[test]
    fn test_unique_constraint() {
        let mut engine = IndexEngine::new();
        engine.create_index(
            bson::doc! { "email": 1 },
            IndexOptions { unique: true, ..Default::default() },
        ).unwrap();
        let doc1 = bson::doc! { "_id": ObjectId::new(), "email": "a@b.com" };
        let doc2 = bson::doc! { "_id": ObjectId::new(), "email": "a@b.com" };
        engine.on_insert(&doc1).unwrap();
        let result = engine.on_insert(&doc2);
        assert!(result.is_err());
    }
    #[test]
    fn test_find_candidates_equality() {
        let mut engine = IndexEngine::new();
        engine.create_index(
            bson::doc! { "name": 1 },
            IndexOptions::default(),
        ).unwrap();
        let doc1 = make_doc("Alice", 30);
        let doc2 = make_doc("Bob", 25);
        let doc3 = make_doc("Alice", 35);
        engine.on_insert(&doc1).unwrap();
        engine.on_insert(&doc2).unwrap();
        engine.on_insert(&doc3).unwrap();
        let filter = bson::doc! { "name": "Alice" };
        let candidates = engine.find_candidate_ids(&filter);
        assert!(candidates.is_some());
        assert_eq!(candidates.unwrap().len(), 2);
    }
    // After deletion the index is still usable but yields no candidates.
    #[test]
    fn test_on_delete() {
        let mut engine = IndexEngine::new();
        engine.create_index(
            bson::doc! { "name": 1 },
            IndexOptions::default(),
        ).unwrap();
        let doc = make_doc("Alice", 30);
        engine.on_insert(&doc).unwrap();
        let filter = bson::doc! { "name": "Alice" };
        assert!(engine.find_candidate_ids(&filter).is_some());
        engine.on_delete(&doc);
        let candidates = engine.find_candidate_ids(&filter);
        assert!(candidates.is_some());
        assert!(candidates.unwrap().is_empty());
    }
    #[test]
    fn test_rebuild_from_documents() {
        let mut engine = IndexEngine::new();
        engine.create_index(
            bson::doc! { "name": 1 },
            IndexOptions::default(),
        ).unwrap();
        let docs = vec![
            make_doc("Alice", 30),
            make_doc("Bob", 25),
        ];
        engine.rebuild_from_documents(&docs);
        let filter = bson::doc! { "name": "Alice" };
        let candidates = engine.find_candidate_ids(&filter);
        assert!(candidates.is_some());
        assert_eq!(candidates.unwrap().len(), 1);
    }
    // drop_all_indexes keeps only the protected _id_ index.
    #[test]
    fn test_drop_all_indexes() {
        let mut engine = IndexEngine::new();
        engine.create_index(bson::doc! { "a": 1 }, IndexOptions::default()).unwrap();
        engine.create_index(bson::doc! { "b": 1 }, IndexOptions::default()).unwrap();
        assert_eq!(engine.list_indexes().len(), 3);
        engine.drop_all_indexes();
        assert_eq!(engine.list_indexes().len(), 1);
        assert!(engine.index_exists("_id_"));
    }
}

View File

@@ -0,0 +1,15 @@
/// Errors from index operations.
#[derive(Debug, thiserror::Error)]
pub enum IndexError {
    /// A unique index rejected an insert/update because the key already
    /// maps to another document.
    #[error("Duplicate key error: index '{index}' has duplicate value for key {key}")]
    DuplicateKey { index: String, key: String },
    /// The named index does not exist.
    #[error("Index not found: {0}")]
    IndexNotFound(String),
    /// The index specification was rejected (e.g. empty key document).
    #[error("Invalid index specification: {0}")]
    InvalidIndex(String),
    /// Attempted to drop an index that must always exist (`_id_`).
    #[error("Cannot drop protected index: {0}")]
    ProtectedIndex(String),
}

View File

@@ -0,0 +1,7 @@
mod engine;
mod planner;
pub mod error;
pub use engine::{IndexEngine, IndexInfo, IndexOptions};
pub use planner::{QueryPlan, QueryPlanner};
pub use error::IndexError;

View File

@@ -0,0 +1,239 @@
use std::collections::HashSet;
use bson::{Bson, Document};
use tracing::debug;
use crate::engine::IndexEngine;
/// The execution plan for a query.
///
/// Both index-scan variants carry the pre-computed candidate id set; the
/// range/equality distinction is informational (both are executed the same
/// way by consumers of the candidate set).
#[derive(Debug, Clone)]
pub enum QueryPlan {
    /// Full collection scan - no suitable index found.
    CollScan,
    /// Index scan with exact/equality matches.
    IxScan {
        /// Name of the index used.
        index_name: String,
        /// Candidate document IDs from the index.
        candidate_ids: HashSet<String>,
    },
    /// Index scan with range-based matches.
    IxScanRange {
        /// Name of the index used.
        index_name: String,
        /// Candidate document IDs from the range scan.
        candidate_ids: HashSet<String>,
    },
}
/// Plans query execution by selecting the best available index.
pub struct QueryPlanner;
impl QueryPlanner {
    /// Analyze a filter and the available indexes to produce a query plan.
    ///
    /// Every index whose fields appear in the filter is scored for
    /// selectivity; the highest-scoring index that actually yields candidates
    /// wins. If no index applies, a full collection scan is planned.
    pub fn plan(filter: &Document, engine: &IndexEngine) -> QueryPlan {
        if filter.is_empty() {
            debug!("Empty filter -> CollScan");
            return QueryPlan::CollScan;
        }
        let indexes = engine.list_indexes();
        let mut best_plan: Option<QueryPlan> = None;
        let mut best_score: f64 = 0.0;
        for idx_info in &indexes {
            let index_fields: Vec<String> = idx_info.key.keys().map(|k| k.to_string()).collect();
            let mut matched = false;
            let mut score: f64 = 0.0;
            let mut is_range = false;
            for field in &index_fields {
                if let Some(condition) = filter.get(field) {
                    matched = true;
                    score += Self::score_condition(condition);
                    if Self::is_range_condition(condition) {
                        is_range = true;
                    }
                }
            }
            if !matched {
                continue;
            }
            // Unique index bonus
            if idx_info.unique {
                score += 0.5;
            }
            if score > best_score {
                // Build a sub-filter with only the fields this index covers
                // and ask the engine for candidates.
                let mut sub_filter = Document::new();
                for field in &index_fields {
                    if let Some(val) = filter.get(field) {
                        sub_filter.insert(field.clone(), val.clone());
                    }
                }
                // BUG FIX: only commit to this index (and raise the score
                // bar) once the engine confirms it can serve the sub-filter.
                // Previously `best_score` was raised before this check, so an
                // unusable index could block a lower-scored but usable one
                // and force a needless CollScan.
                if let Some(candidates) = engine.find_candidate_ids(&sub_filter) {
                    best_score = score;
                    best_plan = Some(if is_range {
                        QueryPlan::IxScanRange {
                            index_name: idx_info.name.clone(),
                            candidate_ids: candidates,
                        }
                    } else {
                        QueryPlan::IxScan {
                            index_name: idx_info.name.clone(),
                            candidate_ids: candidates,
                        }
                    });
                }
            }
        }
        match best_plan {
            Some(plan) => {
                debug!(score = best_score, "Selected index plan");
                plan
            }
            None => {
                debug!("No suitable index found -> CollScan");
                QueryPlan::CollScan
            }
        }
    }
    /// Score a filter condition for index selectivity.
    /// Higher scores indicate more selective (better) index usage.
    fn score_condition(condition: &Bson) -> f64 {
        match condition {
            Bson::Document(doc) if Self::has_operators(doc) => {
                // Sum per-operator selectivity weights; unindexable
                // operators contribute nothing.
                doc.keys()
                    .map(|op| match op.as_str() {
                        "$eq" => 2.0,
                        "$in" => 1.5,
                        "$gt" | "$gte" | "$lt" | "$lte" => 1.0,
                        _ => 0.0,
                    })
                    .sum()
            }
            // Direct equality is the most selective condition.
            _ => 2.0,
        }
    }
    /// Check if a condition involves range operators.
    fn is_range_condition(condition: &Bson) -> bool {
        match condition {
            Bson::Document(doc) => {
                doc.keys().any(|k| matches!(k.as_str(), "$gt" | "$gte" | "$lt" | "$lte"))
            }
            _ => false,
        }
    }
    /// True when any key of `doc` is a `$`-prefixed operator.
    fn has_operators(doc: &Document) -> bool {
        doc.keys().any(|k| k.starts_with('$'))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::IndexOptions;
    use bson::oid::ObjectId;
    #[test]
    fn test_empty_filter_collscan() {
        let engine = IndexEngine::new();
        let plan = QueryPlanner::plan(&bson::doc! {}, &engine);
        assert!(matches!(plan, QueryPlan::CollScan));
    }
    // _id equality is served by the implicit _id_ index.
    #[test]
    fn test_id_equality_ixscan() {
        let mut engine = IndexEngine::new();
        let oid = ObjectId::new();
        let doc = bson::doc! { "_id": oid.clone(), "name": "Alice" };
        engine.on_insert(&doc).unwrap();
        let filter = bson::doc! { "_id": oid };
        let plan = QueryPlanner::plan(&filter, &engine);
        assert!(matches!(plan, QueryPlan::IxScan { .. }));
    }
    #[test]
    fn test_indexed_field_ixscan() {
        let mut engine = IndexEngine::new();
        engine.create_index(
            bson::doc! { "status": 1 },
            IndexOptions::default(),
        ).unwrap();
        let doc = bson::doc! { "_id": ObjectId::new(), "status": "active" };
        engine.on_insert(&doc).unwrap();
        let filter = bson::doc! { "status": "active" };
        let plan = QueryPlanner::plan(&filter, &engine);
        assert!(matches!(plan, QueryPlan::IxScan { .. }));
    }
    #[test]
    fn test_unindexed_field_collscan() {
        let engine = IndexEngine::new();
        let filter = bson::doc! { "unindexed_field": "value" };
        let plan = QueryPlanner::plan(&filter, &engine);
        assert!(matches!(plan, QueryPlan::CollScan));
    }
    // Range operators select the IxScanRange variant.
    #[test]
    fn test_range_query_ixscan_range() {
        let mut engine = IndexEngine::new();
        engine.create_index(
            bson::doc! { "age": 1 },
            IndexOptions::default(),
        ).unwrap();
        let doc = bson::doc! { "_id": ObjectId::new(), "age": 30 };
        engine.on_insert(&doc).unwrap();
        let filter = bson::doc! { "age": { "$gte": 25, "$lt": 35 } };
        let plan = QueryPlanner::plan(&filter, &engine);
        assert!(matches!(plan, QueryPlan::IxScanRange { .. }));
    }
    #[test]
    fn test_unique_index_preferred() {
        let mut engine = IndexEngine::new();
        engine.create_index(
            bson::doc! { "email": 1 },
            IndexOptions { unique: true, ..Default::default() },
        ).unwrap();
        engine.create_index(
            bson::doc! { "email": 1, "name": 1 },
            IndexOptions { name: Some("email_name".to_string()), ..Default::default() },
        ).unwrap();
        let doc = bson::doc! { "_id": ObjectId::new(), "email": "a@b.com", "name": "Alice" };
        engine.on_insert(&doc).unwrap();
        let filter = bson::doc! { "email": "a@b.com" };
        let plan = QueryPlanner::plan(&filter, &engine);
        // The unique index on email should be preferred (higher score)
        match plan {
            QueryPlan::IxScan { index_name, .. } => {
                assert_eq!(index_name, "email_1");
            }
            _ => panic!("Expected IxScan"),
        }
    }
}

View File

View File

@@ -0,0 +1,15 @@
[package]
name = "rustdb-query"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "MongoDB-compatible query matching, update operators, aggregation, sort, and projection engine"
# Versions are pinned centrally in the workspace manifest.
[dependencies]
# BSON document/value model
bson = { workspace = true }
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
# Regex for the $regex operator
regex = { workspace = true }
# Error handling
thiserror = { workspace = true }
# Structured logging
tracing = { workspace = true }

View File

@@ -0,0 +1,614 @@
use bson::{Bson, Document};
use std::collections::HashMap;
use crate::error::QueryError;
use crate::matcher::QueryMatcher;
use crate::sort::sort_documents;
use crate::projection::apply_projection;
use crate::field_path::get_nested_value;
/// Aggregation pipeline engine.
///
/// Stateless: every entry point is an associated function that consumes
/// the input documents and returns the transformed set.
pub struct AggregationEngine;
/// Trait for resolving cross-collection data (for $lookup, $graphLookup, etc.).
pub trait CollectionResolver {
    /// Return all documents of collection `coll` in database `db`.
    fn resolve(&self, db: &str, coll: &str) -> Result<Vec<Document>, QueryError>;
}
impl AggregationEngine {
    /// Execute an aggregation pipeline on a set of documents.
    ///
    /// Stages run in order, each consuming the previous stage's output.
    /// `resolver` is only required by cross-collection stages ($lookup,
    /// $unionWith, and any $facet sub-pipeline that contains them).
    /// Returns an `AggregationError` for empty or unsupported stages.
    pub fn aggregate(
        docs: Vec<Document>,
        pipeline: &[Document],
        resolver: Option<&dyn CollectionResolver>,
        db: &str,
    ) -> Result<Vec<Document>, QueryError> {
        let mut current = docs;
        for stage in pipeline {
            // Each stage document carries exactly one key: the stage name.
            let (stage_name, stage_spec) = stage
                .iter()
                .next()
                .ok_or_else(|| QueryError::AggregationError("Empty pipeline stage".into()))?;
            current = match stage_name.as_str() {
                "$match" => Self::stage_match(current, stage_spec)?,
                "$project" => Self::stage_project(current, stage_spec)?,
                "$sort" => Self::stage_sort(current, stage_spec)?,
                "$limit" => Self::stage_limit(current, stage_spec)?,
                "$skip" => Self::stage_skip(current, stage_spec)?,
                "$group" => Self::stage_group(current, stage_spec)?,
                "$unwind" => Self::stage_unwind(current, stage_spec)?,
                "$count" => Self::stage_count(current, stage_spec)?,
                "$addFields" | "$set" => Self::stage_add_fields(current, stage_spec)?,
                "$replaceRoot" | "$replaceWith" => Self::stage_replace_root(current, stage_spec)?,
                "$lookup" => Self::stage_lookup(current, stage_spec, resolver, db)?,
                "$facet" => Self::stage_facet(current, stage_spec, resolver, db)?,
                "$unionWith" => Self::stage_union_with(current, stage_spec, resolver, db)?,
                other => {
                    return Err(QueryError::AggregationError(format!(
                        "Unsupported aggregation stage: {}",
                        other
                    )));
                }
            };
        }
        Ok(current)
    }
    /// `$match`: keep only the documents that satisfy the filter.
    fn stage_match(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let filter = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$match requires a document".into())),
        };
        Ok(QueryMatcher::filter(&docs, filter))
    }
    /// `$project`: reshape each document per the projection spec.
    fn stage_project(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let projection = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$project requires a document".into())),
        };
        Ok(docs.into_iter().map(|doc| apply_projection(&doc, projection)).collect())
    }
    /// `$sort`: order documents in place by the sort specification.
    fn stage_sort(mut docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let sort_spec = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$sort requires a document".into())),
        };
        sort_documents(&mut docs, sort_spec);
        Ok(docs)
    }
    /// `$limit`: keep at most the first N documents.
    fn stage_limit(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let n = bson_to_usize(spec)
            .ok_or_else(|| QueryError::AggregationError("$limit requires a number".into()))?;
        Ok(docs.into_iter().take(n).collect())
    }
    /// `$skip`: drop the first N documents.
    fn stage_skip(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let n = bson_to_usize(spec)
            .ok_or_else(|| QueryError::AggregationError("$skip requires a number".into()))?;
        Ok(docs.into_iter().skip(n).collect())
    }
    /// `$group`: bucket documents by the resolved `_id` expression and
    /// apply one accumulator operator per output field.
    fn stage_group(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let group_spec = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$group requires a document".into())),
        };
        // A missing _id groups everything into a single null bucket.
        let id_expr = group_spec.get("_id").cloned().unwrap_or(Bson::Null);
        // Group documents by _id
        // Keyed by the Debug rendering of the resolved _id (Bson is not
        // Hash); the original Bson value is kept alongside for the output.
        let mut groups: HashMap<String, (Bson, Vec<Document>)> = HashMap::new();
        for doc in &docs {
            let group_key = resolve_expression(&id_expr, doc);
            let key_str = format!("{:?}", group_key);
            groups
                .entry(key_str)
                .or_insert_with(|| (group_key.clone(), Vec::new()))
                .1
                .push(doc.clone());
        }
        // NOTE: HashMap iteration order is unspecified, so group output
        // order is arbitrary; callers sort explicitly when order matters.
        let mut result = Vec::new();
        for (_key_str, (group_id, group_docs)) in groups {
            let mut output = bson::doc! { "_id": group_id };
            for (field, accumulator) in group_spec {
                if field == "_id" {
                    continue;
                }
                // Non-document accumulator specs are silently skipped.
                let acc_doc = match accumulator {
                    Bson::Document(d) => d,
                    _ => continue,
                };
                // NOTE(review): panics if the accumulator document is empty,
                // e.g. { "total": {} } — consider returning AggregationError.
                let (acc_op, acc_expr) = acc_doc.iter().next().unwrap();
                let value = match acc_op.as_str() {
                    "$sum" => accumulate_sum(&group_docs, acc_expr),
                    "$avg" => accumulate_avg(&group_docs, acc_expr),
                    "$min" => accumulate_min(&group_docs, acc_expr),
                    "$max" => accumulate_max(&group_docs, acc_expr),
                    "$first" => accumulate_first(&group_docs, acc_expr),
                    "$last" => accumulate_last(&group_docs, acc_expr),
                    "$push" => accumulate_push(&group_docs, acc_expr),
                    "$addToSet" => accumulate_add_to_set(&group_docs, acc_expr),
                    "$count" => Bson::Int64(group_docs.len() as i64),
                    // Unknown accumulators yield null rather than erroring.
                    _ => Bson::Null,
                };
                output.insert(field.clone(), value);
            }
            result.push(output);
        }
        Ok(result)
    }
    /// `$unwind`: emit one output document per element of the named array
    /// field. Accepts either "$path" or { path, preserveNullAndEmptyArrays }.
    fn stage_unwind(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let (path, preserve_null) = match spec {
            Bson::String(s) => (s.trim_start_matches('$').to_string(), false),
            Bson::Document(d) => {
                let path = d.get_str("path")
                    .map(|s| s.trim_start_matches('$').to_string())
                    .map_err(|_| QueryError::AggregationError("$unwind requires 'path'".into()))?;
                let preserve = d.get_bool("preserveNullAndEmptyArrays").unwrap_or(false);
                (path, preserve)
            }
            _ => return Err(QueryError::AggregationError("$unwind requires a string or document".into())),
        };
        let mut result = Vec::new();
        for doc in docs {
            // NOTE(review): the path is looked up as a literal top-level key,
            // so dotted paths are not traversed — confirm that's intended.
            let value = doc.get(&path).cloned();
            match value {
                Some(Bson::Array(arr)) => {
                    if arr.is_empty() && preserve_null {
                        // Empty array: keep the document but drop the field.
                        let mut new_doc = doc.clone();
                        new_doc.remove(&path);
                        result.push(new_doc);
                    } else {
                        // One clone of the document per array element.
                        for elem in arr {
                            let mut new_doc = doc.clone();
                            new_doc.insert(path.clone(), elem);
                            result.push(new_doc);
                        }
                    }
                }
                Some(Bson::Null) | None => {
                    // Null/missing fields are dropped unless preservation is on.
                    if preserve_null {
                        result.push(doc);
                    }
                }
                Some(val) => {
                    // Non-array: keep as-is
                    let mut new_doc = doc;
                    new_doc.insert(path.clone(), val);
                    result.push(new_doc);
                }
            }
        }
        Ok(result)
    }
    /// `$count`: replace the stream with a single { <field>: N } document.
    fn stage_count(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let field = match spec {
            Bson::String(s) => s.clone(),
            _ => return Err(QueryError::AggregationError("$count requires a string".into())),
        };
        Ok(vec![bson::doc! { field: docs.len() as i64 }])
    }
    /// `$addFields` / `$set`: evaluate each expression against the document
    /// and insert (or overwrite) the corresponding field.
    fn stage_add_fields(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let fields = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$addFields requires a document".into())),
        };
        Ok(docs
            .into_iter()
            .map(|mut doc| {
                for (key, expr) in fields {
                    let value = resolve_expression(expr, &doc);
                    doc.insert(key.clone(), value);
                }
                doc
            })
            .collect())
    }
    /// `$replaceRoot` / `$replaceWith`: promote an expression result to be
    /// the whole document. Documents whose expression does not resolve to a
    /// BSON document are dropped from the output.
    fn stage_replace_root(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let new_root_expr = match spec {
            // Without an explicit newRoot key, the spec document itself is
            // treated as the replacement expression ($replaceWith form).
            Bson::Document(d) => d.get("newRoot").cloned().unwrap_or(Bson::Document(d.clone())),
            Bson::String(s) => Bson::String(s.clone()),
            _ => return Err(QueryError::AggregationError("$replaceRoot requires a document".into())),
        };
        let mut result = Vec::new();
        for doc in docs {
            let new_root = resolve_expression(&new_root_expr, &doc);
            if let Bson::Document(d) = new_root {
                result.push(d);
            }
        }
        Ok(result)
    }
    /// `$lookup`: left-outer equality join. Matches from the foreign
    /// collection are embedded as an array under the `as` field.
    fn stage_lookup(
        docs: Vec<Document>,
        spec: &Bson,
        resolver: Option<&dyn CollectionResolver>,
        db: &str,
    ) -> Result<Vec<Document>, QueryError> {
        let lookup = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$lookup requires a document".into())),
        };
        let from = lookup.get_str("from")
            .map_err(|_| QueryError::AggregationError("$lookup requires 'from'".into()))?;
        let local_field = lookup.get_str("localField")
            .map_err(|_| QueryError::AggregationError("$lookup requires 'localField'".into()))?;
        let foreign_field = lookup.get_str("foreignField")
            .map_err(|_| QueryError::AggregationError("$lookup requires 'foreignField'".into()))?;
        let as_field = lookup.get_str("as")
            .map_err(|_| QueryError::AggregationError("$lookup requires 'as'".into()))?;
        let resolver = resolver
            .ok_or_else(|| QueryError::AggregationError("$lookup requires a collection resolver".into()))?;
        // The foreign collection is materialized once and scanned per doc.
        let foreign_docs = resolver.resolve(db, from)?;
        Ok(docs
            .into_iter()
            .map(|mut doc| {
                let local_val = get_nested_value(&doc, local_field);
                let matches: Vec<Bson> = foreign_docs
                    .iter()
                    .filter(|fd| {
                        let foreign_val = get_nested_value(fd, foreign_field);
                        // Missing values on either side never join.
                        match (&local_val, &foreign_val) {
                            (Some(a), Some(b)) => bson_loose_eq(a, b),
                            _ => false,
                        }
                    })
                    .map(|fd| Bson::Document(fd.clone()))
                    .collect();
                doc.insert(as_field.to_string(), Bson::Array(matches));
                doc
            })
            .collect())
    }
    /// `$facet`: run several sub-pipelines over the same input and emit a
    /// single document with one array field per facet.
    fn stage_facet(
        docs: Vec<Document>,
        spec: &Bson,
        resolver: Option<&dyn CollectionResolver>,
        db: &str,
    ) -> Result<Vec<Document>, QueryError> {
        let facets = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$facet requires a document".into())),
        };
        let mut result = Document::new();
        for (facet_name, pipeline_bson) in facets {
            // Non-array facet specs are skipped; non-document stages dropped.
            let pipeline = match pipeline_bson {
                Bson::Array(arr) => {
                    let mut stages = Vec::new();
                    for stage in arr {
                        if let Bson::Document(d) = stage {
                            stages.push(d.clone());
                        }
                    }
                    stages
                }
                _ => continue,
            };
            // Each facet gets its own clone of the full input set.
            let facet_result = Self::aggregate(docs.clone(), &pipeline, resolver, db)?;
            result.insert(
                facet_name.clone(),
                Bson::Array(facet_result.into_iter().map(Bson::Document).collect()),
            );
        }
        Ok(vec![result])
    }
    /// `$unionWith`: append another collection's documents (optionally run
    /// through a sub-pipeline first) to the current stream.
    fn stage_union_with(
        mut docs: Vec<Document>,
        spec: &Bson,
        resolver: Option<&dyn CollectionResolver>,
        db: &str,
    ) -> Result<Vec<Document>, QueryError> {
        let (coll, pipeline) = match spec {
            // Shorthand form: just the collection name.
            Bson::String(s) => (s.as_str(), None),
            Bson::Document(d) => {
                let coll = d.get_str("coll")
                    .map_err(|_| QueryError::AggregationError("$unionWith requires 'coll'".into()))?;
                let pipeline = d.get_array("pipeline").ok().map(|arr| {
                    arr.iter()
                        .filter_map(|s| {
                            if let Bson::Document(d) = s { Some(d.clone()) } else { None }
                        })
                        .collect::<Vec<Document>>()
                });
                (coll, pipeline)
            }
            _ => return Err(QueryError::AggregationError("$unionWith requires a string or document".into())),
        };
        let resolver = resolver
            .ok_or_else(|| QueryError::AggregationError("$unionWith requires a collection resolver".into()))?;
        let mut other_docs = resolver.resolve(db, coll)?;
        if let Some(p) = pipeline {
            other_docs = Self::aggregate(other_docs, &p, Some(resolver), db)?;
        }
        docs.extend(other_docs);
        Ok(docs)
    }
}
// --- Helper functions ---
/// Resolve an aggregation expression against a document.
/// A string of the form "$field.path" is dereferenced (missing paths
/// yield null); every other value is returned as a literal.
fn resolve_expression(expr: &Bson, doc: &Document) -> Bson {
    if let Bson::String(raw) = expr {
        if let Some(path) = raw.strip_prefix('$') {
            return get_nested_value(doc, path).unwrap_or(Bson::Null);
        }
    }
    expr.clone()
}
/// Coerce a numeric BSON value to usize (doubles are truncated).
/// Non-numeric values yield None.
fn bson_to_usize(v: &Bson) -> Option<usize> {
    match *v {
        Bson::Int32(n) => Some(n as usize),
        Bson::Int64(n) => Some(n as usize),
        Bson::Double(n) => Some(n as usize),
        _ => None,
    }
}
/// Coerce a numeric BSON value to f64; None for non-numeric types.
fn bson_to_f64(v: &Bson) -> Option<f64> {
    match *v {
        Bson::Int32(n) => Some(f64::from(n)),
        Bson::Int64(n) => Some(n as f64),
        Bson::Double(n) => Some(n),
        _ => None,
    }
}
/// Loose equality used by $lookup joins: numeric values compare equal
/// across Int32/Int64/Double; everything else uses strict BSON equality.
///
/// Fix: the Int64 <-> Double pairs were previously missing, so joins on
/// an Int64 local key against a Double foreign key (or vice versa) never
/// matched — now consistent with `QueryMatcher::bson_equals`.
fn bson_loose_eq(a: &Bson, b: &Bson) -> bool {
    match (a, b) {
        (Bson::Int32(x), Bson::Int64(y)) | (Bson::Int64(y), Bson::Int32(x)) => {
            i64::from(*x) == *y
        }
        (Bson::Int32(x), Bson::Double(y)) | (Bson::Double(y), Bson::Int32(x)) => {
            f64::from(*x) == *y
        }
        (Bson::Int64(x), Bson::Double(y)) | (Bson::Double(y), Bson::Int64(x)) => {
            (*x as f64) == *y
        }
        _ => a == b,
    }
}
// --- Accumulators ---
/// `$sum` accumulator.
///
/// Supports a numeric constant (added once per document — the MongoDB
/// `{ "$sum": 1 }` counting idiom) or a "$field" reference. Field sums
/// stay Int64 while every addend is an integer and widen to Double as
/// soon as a non-integer value appears; non-numeric values are skipped.
/// Any other expression yields 0.
///
/// Fix: a Double constant (e.g. { "$sum": 0.5 }) previously fell through
/// to the fallback and returned 0; it now sums like the integer constants.
fn accumulate_sum(docs: &[Document], expr: &Bson) -> Bson {
    match expr {
        Bson::Int32(n) => Bson::Int64(*n as i64 * docs.len() as i64),
        Bson::Int64(n) => Bson::Int64(*n * docs.len() as i64),
        Bson::Double(n) => Bson::Double(*n * docs.len() as f64),
        Bson::String(s) if s.starts_with('$') => {
            let field = &s[1..];
            let mut sum = 0.0f64;
            // Track an integer running total until a non-integer is seen.
            let mut is_int = true;
            let mut int_sum = 0i64;
            for doc in docs {
                if let Some(val) = get_nested_value(doc, field) {
                    if let Some(n) = bson_to_f64(&val) {
                        sum += n;
                        if is_int {
                            match &val {
                                Bson::Int32(i) => int_sum += *i as i64,
                                Bson::Int64(i) => int_sum += i,
                                _ => is_int = false,
                            }
                        }
                    }
                }
            }
            if is_int {
                Bson::Int64(int_sum)
            } else {
                Bson::Double(sum)
            }
        }
        _ => Bson::Int32(0),
    }
}
/// `$avg` accumulator: arithmetic mean of the numeric values found at a
/// "$field" reference. Returns null when there are no documents, the
/// expression is not a field reference, or no numeric values exist.
fn accumulate_avg(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Null,
    };
    // Collect only the values that are present and numeric.
    let nums: Vec<f64> = docs
        .iter()
        .filter_map(|doc| get_nested_value(doc, field))
        .filter_map(|val| bson_to_f64(&val))
        .collect();
    if nums.is_empty() {
        Bson::Null
    } else {
        Bson::Double(nums.iter().sum::<f64>() / nums.len() as f64)
    }
}
/// `$min` accumulator: smallest value at a "$field" reference, compared
/// numerically. The first value seen is kept when either side of a
/// comparison is non-numeric; null when nothing is found.
fn accumulate_min(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Null,
    };
    docs.iter()
        .filter_map(|doc| get_nested_value(doc, field))
        .fold(None::<Bson>, |acc, val| match acc {
            None => Some(val),
            Some(current) => {
                // Only replace when both values compare as numbers.
                let take_new = match (bson_to_f64(&current), bson_to_f64(&val)) {
                    (Some(cv), Some(vv)) => vv < cv,
                    _ => false,
                };
                Some(if take_new { val } else { current })
            }
        })
        .unwrap_or(Bson::Null)
}
/// `$max` accumulator: largest value at a "$field" reference, compared
/// numerically. The first value seen is kept when either side of a
/// comparison is non-numeric; null when nothing is found.
fn accumulate_max(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Null,
    };
    docs.iter()
        .filter_map(|doc| get_nested_value(doc, field))
        .fold(None::<Bson>, |acc, val| match acc {
            None => Some(val),
            Some(current) => {
                // Only replace when both values compare as numbers.
                let take_new = match (bson_to_f64(&current), bson_to_f64(&val)) {
                    (Some(cv), Some(vv)) => vv > cv,
                    _ => false,
                };
                Some(if take_new { val } else { current })
            }
        })
        .unwrap_or(Bson::Null)
}
/// `$first` accumulator: the field's value from the first document in the
/// group; null when the group is empty, the field is absent, or the
/// expression is not a "$field" reference.
fn accumulate_first(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Null,
    };
    match docs.first() {
        Some(doc) => get_nested_value(doc, field).unwrap_or(Bson::Null),
        None => Bson::Null,
    }
}
/// `$last` accumulator: the field's value from the last document in the
/// group; null when the group is empty, the field is absent, or the
/// expression is not a "$field" reference.
fn accumulate_last(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Null,
    };
    match docs.last() {
        Some(doc) => get_nested_value(doc, field).unwrap_or(Bson::Null),
        None => Bson::Null,
    }
}
/// `$push` accumulator: every present value of the referenced field, in
/// document order. An empty array for non-"$field" expressions.
fn accumulate_push(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Array(vec![]),
    };
    let mut collected = Vec::with_capacity(docs.len());
    for doc in docs {
        if let Some(val) = get_nested_value(doc, field) {
            collected.push(val);
        }
    }
    Bson::Array(collected)
}
/// `$addToSet` accumulator: like `$push` but deduplicated, preserving
/// first-seen order. Uniqueness is keyed on the Debug rendering of each
/// value, since Bson does not implement Hash.
fn accumulate_add_to_set(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Array(vec![]),
    };
    let mut seen = std::collections::HashSet::new();
    let unique: Vec<Bson> = docs
        .iter()
        .filter_map(|doc| get_nested_value(doc, field))
        .filter(|val| seen.insert(format!("{:?}", val)))
        .collect();
    Bson::Array(unique)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// $match keeps only documents satisfying the filter.
    #[test]
    fn test_match_stage() {
        let input = vec![
            bson::doc! { "x": 1 },
            bson::doc! { "x": 2 },
            bson::doc! { "x": 3 },
        ];
        let stages = [bson::doc! { "$match": { "x": { "$gt": 1 } } }];
        let out = AggregationEngine::aggregate(input, &stages, None, "test").unwrap();
        assert_eq!(out.len(), 2);
    }

    /// $group produces one output document per distinct _id.
    #[test]
    fn test_group_stage() {
        let input = vec![
            bson::doc! { "category": "a", "value": 10 },
            bson::doc! { "category": "b", "value": 20 },
            bson::doc! { "category": "a", "value": 30 },
        ];
        let stages = [bson::doc! {
            "$group": {
                "_id": "$category",
                "total": { "$sum": "$value" }
            }
        }];
        let out = AggregationEngine::aggregate(input, &stages, None, "test").unwrap();
        assert_eq!(out.len(), 2);
    }

    /// sort -> skip -> limit composes into a window over ordered input.
    #[test]
    fn test_sort_limit_skip() {
        let input = vec![
            bson::doc! { "x": 3 },
            bson::doc! { "x": 1 },
            bson::doc! { "x": 2 },
            bson::doc! { "x": 4 },
        ];
        let stages = [
            bson::doc! { "$sort": { "x": 1 } },
            bson::doc! { "$skip": 1_i64 },
            bson::doc! { "$limit": 2_i64 },
        ];
        let out = AggregationEngine::aggregate(input, &stages, None, "test").unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].get_i32("x").unwrap(), 2);
        assert_eq!(out[1].get_i32("x").unwrap(), 3);
    }
}

View File

@@ -0,0 +1,80 @@
use bson::{Bson, Document};
use std::collections::HashSet;
use crate::field_path::get_nested_value;
use crate::matcher::QueryMatcher;
/// Get distinct values for a field across documents, with optional filter.
/// Handles array flattening (each array element counted separately).
/// Results keep first-seen order.
pub fn distinct_values(
    docs: &[Document],
    field: &str,
    filter: Option<&Document>,
) -> Vec<Bson> {
    let mut seen = HashSet::new();
    let mut result = Vec::new();
    for doc in docs {
        // Apply the optional pre-filter before extracting values.
        if let Some(f) = filter {
            if !QueryMatcher::matches(doc, f) {
                continue;
            }
        }
        // Dotted paths traverse nested documents; plain names are direct.
        let value = if field.contains('.') {
            get_nested_value(doc, field)
        } else {
            doc.get(field).cloned()
        };
        if let Some(val) = value {
            collect_distinct_values(&val, &mut seen, &mut result);
        }
    }
    result
}
/// Fold one value into the distinct set, recursively flattening arrays so
/// each element counts on its own. Dedup is keyed on the Debug rendering
/// because Bson does not implement Hash.
fn collect_distinct_values(value: &Bson, seen: &mut HashSet<String>, result: &mut Vec<Bson>) {
    if let Bson::Array(items) = value {
        for item in items {
            collect_distinct_values(item, seen, result);
        }
        return;
    }
    if seen.insert(format!("{:?}", value)) {
        result.push(value.clone());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Repeated scalars collapse to a single entry.
    #[test]
    fn test_distinct_simple() {
        let input = vec![
            bson::doc! { "x": 1 },
            bson::doc! { "x": 2 },
            bson::doc! { "x": 1 },
            bson::doc! { "x": 3 },
        ];
        let out = distinct_values(&input, "x", None);
        assert_eq!(out.len(), 3);
    }

    /// Array fields contribute each element individually: a, b, c.
    #[test]
    fn test_distinct_array_flattening() {
        let input = vec![
            bson::doc! { "tags": ["a", "b"] },
            bson::doc! { "tags": ["b", "c"] },
        ];
        let out = distinct_values(&input, "tags", None);
        assert_eq!(out.len(), 3);
    }
}

View File

@@ -0,0 +1,18 @@
/// Errors from query operations.
#[derive(Debug, thiserror::Error)]
pub enum QueryError {
    /// An unrecognized or malformed query operator.
    #[error("Invalid query operator: {0}")]
    InvalidOperator(String),
    /// An operand's BSON type is incompatible with the operation.
    #[error("Type mismatch: {0}")]
    TypeMismatch(String),
    /// A malformed update document or unsupported update operator.
    #[error("Invalid update: {0}")]
    InvalidUpdate(String),
    /// A malformed pipeline or unsupported aggregation stage.
    #[error("Aggregation error: {0}")]
    AggregationError(String),
    /// A `$regex` pattern that failed to parse or compile.
    #[error("Invalid regex: {0}")]
    InvalidRegex(String),
}

View File

@@ -0,0 +1,115 @@
use bson::{Bson, Document};
/// Get a nested value from a document using dot-notation path (e.g., "a.b.c").
/// Handles both nested documents and array traversal; returns None when
/// the path does not resolve.
pub fn get_nested_value(doc: &Document, path: &str) -> Option<Bson> {
    let segments: Vec<&str> = path.split('.').collect();
    let root = Bson::Document(doc.clone());
    get_nested_recursive(&root, &segments)
}
/// Walk `parts` down through `value`. For arrays, a numeric segment is an
/// index; otherwise the remaining path is applied to every element and the
/// hits are gathered (a single hit is unwrapped, multiple become an array).
fn get_nested_recursive(value: &Bson, parts: &[&str]) -> Option<Bson> {
    let (key, rest) = match parts.split_first() {
        Some((k, r)) => (*k, r),
        // Path exhausted: this is the value we were after.
        None => return Some(value.clone()),
    };
    match value {
        Bson::Document(doc) => get_nested_recursive(doc.get(key)?, rest),
        Bson::Array(arr) => {
            // Try numeric index first.
            if let Ok(idx) = key.parse::<usize>() {
                if let Some(elem) = arr.get(idx) {
                    return get_nested_recursive(elem, rest);
                }
            }
            // Otherwise apply the full remaining path to each element.
            let mut collected: Vec<Bson> = Vec::new();
            for elem in arr {
                if let Some(found) = get_nested_recursive(elem, parts) {
                    collected.push(found);
                }
            }
            match collected.len() {
                0 => None,
                1 => collected.pop(),
                _ => Some(Bson::Array(collected)),
            }
        }
        _ => None,
    }
}
/// Set a nested value in a document using dot-notation path.
/// Missing intermediate documents are created on the way down.
pub fn set_nested_value(doc: &mut Document, path: &str, value: Bson) {
    let segments: Vec<&str> = path.split('.').collect();
    set_nested_recursive(doc, &segments, value);
}
/// Recursive worker for `set_nested_value`.
/// NOTE(review): if an intermediate key exists but holds a non-document
/// value, the write is silently dropped — confirm that's intended.
fn set_nested_recursive(doc: &mut Document, parts: &[&str], value: Bson) {
    let key = parts[0];
    let rest = &parts[1..];
    if rest.is_empty() {
        // Leaf segment: write the value here.
        doc.insert(key.to_string(), value);
        return;
    }
    // Create the intermediate document when the key is absent.
    if !doc.contains_key(key) {
        doc.insert(key.to_string(), Bson::Document(Document::new()));
    }
    if let Some(Bson::Document(nested)) = doc.get_mut(key) {
        set_nested_recursive(nested, rest, value);
    }
}
/// Remove a nested value from a document using dot-notation path.
/// Returns the removed value, or None when the path does not resolve.
pub fn remove_nested_value(doc: &mut Document, path: &str) -> Option<Bson> {
    let segments: Vec<&str> = path.split('.').collect();
    remove_nested_recursive(doc, &segments)
}
/// Recursive worker for `remove_nested_value`: descends only through
/// existing intermediate documents.
fn remove_nested_recursive(doc: &mut Document, parts: &[&str]) -> Option<Bson> {
    let key = parts[0];
    if parts.len() == 1 {
        return doc.remove(key);
    }
    match doc.get_mut(key) {
        Some(Bson::Document(nested)) => remove_nested_recursive(nested, &parts[1..]),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deeply nested scalars resolve through dot paths.
    #[test]
    fn test_get_nested_simple() {
        let doc = bson::doc! { "a": { "b": { "c": 42 } } };
        let found = get_nested_value(&doc, "a.b.c");
        assert_eq!(found, Some(Bson::Int32(42)));
    }

    /// A dangling path segment yields None.
    #[test]
    fn test_get_nested_missing() {
        let doc = bson::doc! { "a": { "b": 1 } };
        assert!(get_nested_value(&doc, "a.c").is_none());
    }

    /// Setting creates intermediate documents and round-trips with get.
    #[test]
    fn test_set_nested() {
        let mut doc = bson::doc! {};
        set_nested_value(&mut doc, "a.b.c", Bson::Int32(42));
        assert_eq!(get_nested_value(&doc, "a.b.c"), Some(Bson::Int32(42)));
    }
}

View File

@@ -0,0 +1,16 @@
//! rustdb-query: MongoDB-compatible query evaluation over BSON documents.
//!
//! Filter matching, update operators, sorting, projection, distinct
//! values, and aggregation pipelines.
mod matcher;
mod update;
mod sort;
mod projection;
mod distinct;
pub mod aggregation;
mod field_path;
pub mod error;
// Public API surface of the crate.
pub use matcher::QueryMatcher;
pub use update::UpdateEngine;
pub use sort::sort_documents;
pub use projection::apply_projection;
pub use distinct::distinct_values;
pub use aggregation::AggregationEngine;
pub use field_path::{get_nested_value, set_nested_value};

View File

@@ -0,0 +1,574 @@
use bson::{Bson, Document};
use regex::Regex;
use crate::field_path::get_nested_value;
/// Query matching engine.
/// Evaluates filter documents against BSON documents.
///
/// Stateless: every entry point is an associated function.
pub struct QueryMatcher;
impl QueryMatcher {
/// Test whether a single document matches a filter.
pub fn matches(doc: &Document, filter: &Document) -> bool {
Self::matches_filter(doc, filter)
}
/// Filter a slice of documents, returning those that match.
pub fn filter(docs: &[Document], filter: &Document) -> Vec<Document> {
if filter.is_empty() {
return docs.to_vec();
}
docs.iter()
.filter(|doc| Self::matches_filter(doc, filter))
.cloned()
.collect()
}
/// Find the first document matching a filter.
pub fn find_one(docs: &[Document], filter: &Document) -> Option<Document> {
docs.iter()
.find(|doc| Self::matches_filter(doc, filter))
.cloned()
}
fn matches_filter(doc: &Document, filter: &Document) -> bool {
for (key, value) in filter {
if !Self::matches_condition(doc, key, value) {
return false;
}
}
true
}
fn matches_condition(doc: &Document, key: &str, condition: &Bson) -> bool {
match key {
"$and" => Self::match_logical_and(doc, condition),
"$or" => Self::match_logical_or(doc, condition),
"$nor" => Self::match_logical_nor(doc, condition),
"$not" => Self::match_logical_not(doc, condition),
"$expr" => {
// Basic $expr support - just return true for now
true
}
_ => {
// Field condition
match condition {
Bson::Document(cond_doc) if Self::has_operators(cond_doc) => {
Self::match_field_operators(doc, key, cond_doc)
}
// Implicit equality
_ => Self::match_equality(doc, key, condition),
}
}
}
}
fn has_operators(doc: &Document) -> bool {
doc.keys().any(|k| k.starts_with('$'))
}
/// Public accessor for has_operators (used by update engine).
pub fn has_operators_pub(doc: &Document) -> bool {
Self::has_operators(doc)
}
/// Public accessor for bson_compare (used by update engine).
pub fn bson_compare_pub(a: &Bson, b: &Bson) -> Option<std::cmp::Ordering> {
Self::bson_compare(a, b)
}
fn match_equality(doc: &Document, field: &str, expected: &Bson) -> bool {
let actual = Self::resolve_field(doc, field);
match actual {
Some(val) => Self::bson_equals(&val, expected),
None => matches!(expected, Bson::Null),
}
}
fn match_field_operators(doc: &Document, field: &str, operators: &Document) -> bool {
let actual = Self::resolve_field(doc, field);
for (op, op_value) in operators {
let result = match op.as_str() {
"$eq" => Self::op_eq(&actual, op_value),
"$ne" => Self::op_ne(&actual, op_value),
"$gt" => Self::op_cmp(&actual, op_value, CmpOp::Gt),
"$gte" => Self::op_cmp(&actual, op_value, CmpOp::Gte),
"$lt" => Self::op_cmp(&actual, op_value, CmpOp::Lt),
"$lte" => Self::op_cmp(&actual, op_value, CmpOp::Lte),
"$in" => Self::op_in(&actual, op_value),
"$nin" => Self::op_nin(&actual, op_value),
"$exists" => Self::op_exists(&actual, op_value),
"$type" => Self::op_type(&actual, op_value),
"$regex" => Self::op_regex(&actual, op_value, operators.get("$options")),
"$not" => Self::op_not(doc, field, op_value),
"$elemMatch" => Self::op_elem_match(&actual, op_value),
"$size" => Self::op_size(&actual, op_value),
"$all" => Self::op_all(&actual, op_value),
"$mod" => Self::op_mod(&actual, op_value),
"$options" => continue, // handled by $regex
_ => true, // unknown operator, skip
};
if !result {
return false;
}
}
true
}
fn resolve_field(doc: &Document, field: &str) -> Option<Bson> {
if field.contains('.') {
get_nested_value(doc, field)
} else {
doc.get(field).cloned()
}
}
fn bson_equals(a: &Bson, b: &Bson) -> bool {
match (a, b) {
(Bson::Int32(x), Bson::Int64(y)) => (*x as i64) == *y,
(Bson::Int64(x), Bson::Int32(y)) => *x == (*y as i64),
(Bson::Int32(x), Bson::Double(y)) => (*x as f64) == *y,
(Bson::Double(x), Bson::Int32(y)) => *x == (*y as f64),
(Bson::Int64(x), Bson::Double(y)) => (*x as f64) == *y,
(Bson::Double(x), Bson::Int64(y)) => *x == (*y as f64),
// For arrays, check if any element matches (implicit $elemMatch)
(Bson::Array(arr), _) if !matches!(b, Bson::Array(_)) => {
arr.iter().any(|elem| Self::bson_equals(elem, b))
}
_ => a == b,
}
}
fn bson_compare(a: &Bson, b: &Bson) -> Option<std::cmp::Ordering> {
use std::cmp::Ordering;
match (a, b) {
// Numeric comparisons (cross-type)
(Bson::Int32(x), Bson::Int32(y)) => Some(x.cmp(y)),
(Bson::Int64(x), Bson::Int64(y)) => Some(x.cmp(y)),
(Bson::Double(x), Bson::Double(y)) => x.partial_cmp(y),
(Bson::Int32(x), Bson::Int64(y)) => Some((*x as i64).cmp(y)),
(Bson::Int64(x), Bson::Int32(y)) => Some(x.cmp(&(*y as i64))),
(Bson::Int32(x), Bson::Double(y)) => (*x as f64).partial_cmp(y),
(Bson::Double(x), Bson::Int32(y)) => x.partial_cmp(&(*y as f64)),
(Bson::Int64(x), Bson::Double(y)) => (*x as f64).partial_cmp(y),
(Bson::Double(x), Bson::Int64(y)) => x.partial_cmp(&(*y as f64)),
// String comparisons
(Bson::String(x), Bson::String(y)) => Some(x.cmp(y)),
// DateTime comparisons
(Bson::DateTime(x), Bson::DateTime(y)) => Some(x.cmp(y)),
// Boolean comparisons
(Bson::Boolean(x), Bson::Boolean(y)) => Some(x.cmp(y)),
// ObjectId comparisons
(Bson::ObjectId(x), Bson::ObjectId(y)) => Some(x.cmp(y)),
// Null comparisons
(Bson::Null, Bson::Null) => Some(Ordering::Equal),
_ => None,
}
}
// --- Operator implementations ---
fn op_eq(actual: &Option<Bson>, expected: &Bson) -> bool {
match actual {
Some(val) => Self::bson_equals(val, expected),
None => matches!(expected, Bson::Null),
}
}
fn op_ne(actual: &Option<Bson>, expected: &Bson) -> bool {
!Self::op_eq(actual, expected)
}
fn op_cmp(actual: &Option<Bson>, expected: &Bson, op: CmpOp) -> bool {
let val = match actual {
Some(v) => v,
None => return false,
};
// For arrays, check if any element satisfies the comparison
if let Bson::Array(arr) = val {
return arr.iter().any(|elem| {
if let Some(ord) = Self::bson_compare(elem, expected) {
op.check(ord)
} else {
false
}
});
}
if let Some(ord) = Self::bson_compare(val, expected) {
op.check(ord)
} else {
false
}
}
fn op_in(actual: &Option<Bson>, values: &Bson) -> bool {
let arr = match values {
Bson::Array(a) => a,
_ => return false,
};
match actual {
Some(val) => {
// For array values, check if any element is in the list
if let Bson::Array(actual_arr) = val {
actual_arr.iter().any(|elem| {
arr.iter().any(|v| Self::bson_equals(elem, v))
}) || arr.iter().any(|v| Self::bson_equals(val, v))
} else {
arr.iter().any(|v| Self::bson_equals(val, v))
}
}
None => arr.iter().any(|v| matches!(v, Bson::Null)),
}
}
fn op_nin(actual: &Option<Bson>, values: &Bson) -> bool {
!Self::op_in(actual, values)
}
fn op_exists(actual: &Option<Bson>, expected: &Bson) -> bool {
let should_exist = match expected {
Bson::Boolean(b) => *b,
Bson::Int32(n) => *n != 0,
Bson::Int64(n) => *n != 0,
_ => true,
};
actual.is_some() == should_exist
}
    /// `$type`: match a value's BSON type against a numeric type code or a
    /// string alias (codes follow the BSON spec; "number" is a pseudo-type
    /// covering Int32, Int64, and Double).
    fn op_type(actual: &Option<Bson>, expected: &Bson) -> bool {
        let val = match actual {
            Some(v) => v,
            None => return false,
        };
        // Normalize the operand to a BSON type number.
        let type_num = match expected {
            Bson::Int32(n) => *n,
            Bson::String(s) => match s.as_str() {
                "double" => 1,
                "string" => 2,
                "object" => 3,
                "array" => 4,
                "binData" => 5,
                "objectId" => 7,
                "bool" => 8,
                "date" => 9,
                "null" => 10,
                "regex" => 11,
                "int" => 16,
                "long" => 18,
                "decimal" => 19,
                "number" => -1, // special: any numeric type
                _ => return false,
            },
            _ => return false,
        };
        if type_num == -1 {
            return matches!(val, Bson::Int32(_) | Bson::Int64(_) | Bson::Double(_));
        }
        // Map the actual value to its BSON type number.
        let actual_type = match val {
            Bson::Double(_) => 1,
            Bson::String(_) => 2,
            Bson::Document(_) => 3,
            Bson::Array(_) => 4,
            Bson::Binary(_) => 5,
            Bson::ObjectId(_) => 7,
            Bson::Boolean(_) => 8,
            Bson::DateTime(_) => 9,
            Bson::Null => 10,
            Bson::RegularExpression(_) => 11,
            Bson::Int32(_) => 16,
            Bson::Int64(_) => 18,
            Bson::Decimal128(_) => 19,
            _ => 0, // unmapped types never match
        };
        actual_type == type_num
    }
fn op_regex(actual: &Option<Bson>, pattern: &Bson, options: Option<&Bson>) -> bool {
let val = match actual {
Some(Bson::String(s)) => s.as_str(),
_ => return false,
};
let pattern_str = match pattern {
Bson::String(s) => s.as_str(),
Bson::RegularExpression(re) => re.pattern.as_str(),
_ => return false,
};
let opts = match options {
Some(Bson::String(s)) => s.as_str(),
_ => match pattern {
Bson::RegularExpression(re) => re.options.as_str(),
_ => "",
},
};
let mut regex_pattern = String::new();
if opts.contains('i') {
regex_pattern.push_str("(?i)");
}
if opts.contains('m') {
regex_pattern.push_str("(?m)");
}
if opts.contains('s') {
regex_pattern.push_str("(?s)");
}
regex_pattern.push_str(pattern_str);
match Regex::new(&regex_pattern) {
Ok(re) => re.is_match(val),
Err(_) => false,
}
}
fn op_not(doc: &Document, field: &str, condition: &Bson) -> bool {
match condition {
Bson::Document(cond_doc) => !Self::match_field_operators(doc, field, cond_doc),
_ => true,
}
}
fn op_elem_match(actual: &Option<Bson>, condition: &Bson) -> bool {
let arr = match actual {
Some(Bson::Array(a)) => a,
_ => return false,
};
let cond_doc = match condition {
Bson::Document(d) => d,
_ => return false,
};
arr.iter().any(|elem| {
if let Bson::Document(elem_doc) = elem {
Self::matches_filter(elem_doc, cond_doc)
} else {
false
}
})
}
fn op_size(actual: &Option<Bson>, expected: &Bson) -> bool {
let arr = match actual {
Some(Bson::Array(a)) => a,
_ => return false,
};
let expected_size = match expected {
Bson::Int32(n) => *n as usize,
Bson::Int64(n) => *n as usize,
_ => return false,
};
arr.len() == expected_size
}
fn op_all(actual: &Option<Bson>, expected: &Bson) -> bool {
let arr = match actual {
Some(Bson::Array(a)) => a,
_ => return false,
};
let expected_arr = match expected {
Bson::Array(a) => a,
_ => return false,
};
expected_arr.iter().all(|expected_val| {
arr.iter().any(|elem| Self::bson_equals(elem, expected_val))
})
}
fn op_mod(actual: &Option<Bson>, expected: &Bson) -> bool {
let val = match actual {
Some(v) => match v {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
Bson::Double(n) => *n as i64,
_ => return false,
},
None => return false,
};
let arr = match expected {
Bson::Array(a) if a.len() == 2 => a,
_ => return false,
};
let divisor = match &arr[0] {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => return false,
};
let remainder = match &arr[1] {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => return false,
};
if divisor == 0 {
return false;
}
val % divisor == remainder
}
// --- Logical operators ---
fn match_logical_and(doc: &Document, conditions: &Bson) -> bool {
match conditions {
Bson::Array(arr) => arr.iter().all(|cond| {
if let Bson::Document(cond_doc) = cond {
Self::matches_filter(doc, cond_doc)
} else {
false
}
}),
_ => false,
}
}
fn match_logical_or(doc: &Document, conditions: &Bson) -> bool {
match conditions {
Bson::Array(arr) => arr.iter().any(|cond| {
if let Bson::Document(cond_doc) = cond {
Self::matches_filter(doc, cond_doc)
} else {
false
}
}),
_ => false,
}
}
    /// `$nor`: logical negation of `$or` — matches when none of the
    /// condition documents match.
    /// NOTE(review): a malformed (non-array) operand makes `$or` return
    /// false, so `$nor` returns true here — confirm that leniency is intended.
    fn match_logical_nor(doc: &Document, conditions: &Bson) -> bool {
        !Self::match_logical_or(doc, conditions)
    }
fn match_logical_not(doc: &Document, condition: &Bson) -> bool {
match condition {
Bson::Document(cond_doc) => !Self::matches_filter(doc, cond_doc),
_ => true,
}
}
}
/// Range comparison operator kind backing `$gt` / `$gte` / `$lt` / `$lte`.
#[derive(Debug, Clone, Copy)]
enum CmpOp {
    Gt,
    Gte,
    Lt,
    Lte,
}
impl CmpOp {
    /// True when `ord` (left value compared to right value) satisfies this
    /// operator.
    fn check(self, ord: std::cmp::Ordering) -> bool {
        use std::cmp::Ordering::*;
        matches!(
            (self, ord),
            (CmpOp::Gt, Greater)
                | (CmpOp::Gte, Greater | Equal)
                | (CmpOp::Lt, Less)
                | (CmpOp::Lte, Less | Equal)
        )
    }
}
/// Unit tests for the filter matcher: equality, comparison, set membership,
/// existence, logical composition, dotted paths, and cross-type numerics.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_simple_equality() {
        let doc = bson::doc! { "name": "Alice", "age": 30 };
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "name": "Alice" }));
        assert!(!QueryMatcher::matches(&doc, &bson::doc! { "name": "Bob" }));
    }
    #[test]
    fn test_comparison_operators() {
        let doc = bson::doc! { "age": 30 };
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "age": { "$gt": 25 } }));
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "age": { "$gte": 30 } }));
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "age": { "$lt": 35 } }));
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "age": { "$lte": 30 } }));
        assert!(!QueryMatcher::matches(&doc, &bson::doc! { "age": { "$gt": 30 } }));
    }
    #[test]
    fn test_in_operator() {
        let doc = bson::doc! { "status": "active" };
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "status": { "$in": ["active", "pending"] } }));
        assert!(!QueryMatcher::matches(&doc, &bson::doc! { "status": { "$in": ["closed"] } }));
    }
    #[test]
    fn test_exists_operator() {
        let doc = bson::doc! { "name": "Alice" };
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "name": { "$exists": true } }));
        assert!(!QueryMatcher::matches(&doc, &bson::doc! { "age": { "$exists": true } }));
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "age": { "$exists": false } }));
    }
    #[test]
    fn test_logical_or() {
        let doc = bson::doc! { "age": 30 };
        assert!(QueryMatcher::matches(&doc, &bson::doc! {
            "$or": [{ "age": 30 }, { "age": 40 }]
        }));
        assert!(!QueryMatcher::matches(&doc, &bson::doc! {
            "$or": [{ "age": 20 }, { "age": 40 }]
        }));
    }
    #[test]
    fn test_logical_and() {
        let doc = bson::doc! { "age": 30, "name": "Alice" };
        assert!(QueryMatcher::matches(&doc, &bson::doc! {
            "$and": [{ "age": 30 }, { "name": "Alice" }]
        }));
        assert!(!QueryMatcher::matches(&doc, &bson::doc! {
            "$and": [{ "age": 30 }, { "name": "Bob" }]
        }));
    }
    #[test]
    fn test_dot_notation() {
        let doc = bson::doc! { "address": { "city": "NYC" } };
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "address.city": "NYC" }));
    }
    #[test]
    fn test_ne_operator() {
        let doc = bson::doc! { "status": "active" };
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "status": { "$ne": "closed" } }));
        assert!(!QueryMatcher::matches(&doc, &bson::doc! { "status": { "$ne": "active" } }));
    }
    // i32 fields must compare equal to i64 query literals.
    #[test]
    fn test_cross_type_numeric_equality() {
        let doc = bson::doc! { "count": 5_i32 };
        assert!(QueryMatcher::matches(&doc, &bson::doc! { "count": 5_i64 }));
    }
    #[test]
    fn test_empty_filter_matches_all() {
        let doc = bson::doc! { "x": 1 };
        assert!(QueryMatcher::matches(&doc, &bson::doc! {}));
    }
}

View File

@@ -0,0 +1,168 @@
use bson::{Bson, Document};
use crate::field_path::get_nested_value;
/// Apply a projection to a document.
/// Inclusion mode: only specified fields + _id.
/// Exclusion mode: all fields except specified ones.
/// _id can be explicitly excluded in either mode.
pub fn apply_projection(doc: &Document, projection: &Document) -> Document {
if projection.is_empty() {
return doc.clone();
}
// Determine mode: inclusion or exclusion
let mut has_inclusion = false;
let mut id_explicitly_set = false;
for (key, value) in projection {
if key == "_id" {
id_explicitly_set = true;
continue;
}
match value {
Bson::Int32(0) | Bson::Int64(0) | Bson::Boolean(false) => {}
_ => has_inclusion = true,
}
}
if has_inclusion {
apply_inclusion(doc, projection, id_explicitly_set)
} else {
apply_exclusion(doc, projection)
}
}
/// Build a new document containing only the projected fields, plus `_id`
/// unless it was explicitly excluded.
fn apply_inclusion(doc: &Document, projection: &Document, id_explicitly_set: bool) -> Document {
    let mut result = Document::new();
    // `_id` rides along by default; an explicit `_id` entry decides instead.
    let include_id = !id_explicitly_set || is_truthy(projection.get("_id"));
    if include_id {
        if let Some(id) = doc.get("_id") {
            result.insert("_id", id.clone());
        }
    }
    for (key, value) in projection {
        if key == "_id" || !is_truthy(Some(value)) {
            continue;
        }
        if key.contains('.') {
            // Dotted path: pull the nested value and rebuild the structure.
            if let Some(val) = get_nested_value(doc, key) {
                set_nested_in_result(&mut result, key, val);
            }
        } else if let Some(val) = doc.get(key) {
            result.insert(key.clone(), val.clone());
        }
    }
    result
}
/// Clone the document, then strip every field whose projection value is falsy.
fn apply_exclusion(doc: &Document, projection: &Document) -> Document {
    let mut result = doc.clone();
    for (key, value) in projection {
        if is_truthy(Some(value)) {
            continue;
        }
        if key.contains('.') {
            // Dotted path: remove only the nested leaf.
            remove_nested_from_result(&mut result, key);
        } else {
            result.remove(key);
        }
    }
    result
}
/// Projection truthiness: zero of any numeric type and `false` mean
/// "exclude"; everything else means "include". `None` (field absent from the
/// projection) is falsy.
fn is_truthy(value: Option<&Bson>) -> bool {
    match value {
        None => false,
        Some(Bson::Int32(0)) | Some(Bson::Int64(0)) | Some(Bson::Boolean(false)) => false,
        // `0.0` must also read as "exclude": JSON-sourced projections often
        // carry numbers as doubles (previously Double(0.0) counted as truthy).
        Some(Bson::Double(f)) if *f == 0.0 => false,
        _ => true,
    }
}
/// Insert `value` at a dotted `path` inside `doc`, creating intermediate
/// sub-documents as needed.
fn set_nested_in_result(doc: &mut Document, path: &str, value: Bson) {
    let segments: Vec<&str> = path.split('.').collect();
    set_nested_recursive(doc, &segments, value);
}
/// Recursive helper for `set_nested_in_result`; `parts` is non-empty.
fn set_nested_recursive(doc: &mut Document, parts: &[&str], value: Bson) {
    let head = parts[0];
    let rest = &parts[1..];
    if rest.is_empty() {
        doc.insert(head.to_string(), value);
        return;
    }
    // Materialize an intermediate sub-document when the key is absent.
    if !doc.contains_key(head) {
        doc.insert(head.to_string(), Bson::Document(Document::new()));
    }
    // Descend only if the slot actually holds a document; an existing scalar
    // at an intermediate key is left untouched.
    if let Some(Bson::Document(inner)) = doc.get_mut(head) {
        set_nested_recursive(inner, rest, value);
    }
}
/// Remove the value at a dotted `path` from `doc`, if present.
fn remove_nested_from_result(doc: &mut Document, path: &str) {
    let segments: Vec<&str> = path.split('.').collect();
    remove_nested_recursive(doc, &segments);
}
/// Recursive helper for `remove_nested_from_result`; `parts` is non-empty.
fn remove_nested_recursive(doc: &mut Document, parts: &[&str]) {
    match parts {
        [] => {}
        [leaf] => {
            doc.remove(leaf);
        }
        [head, rest @ ..] => {
            // Only descend through existing sub-documents.
            if let Some(Bson::Document(inner)) = doc.get_mut(*head) {
                remove_nested_recursive(inner, rest);
            }
        }
    }
}
/// Unit tests for projection: inclusion mode, exclusion mode, and explicit
/// `_id` exclusion inside an inclusion projection.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_inclusion_projection() {
        let doc = bson::doc! { "_id": 1, "name": "Alice", "age": 30, "email": "a@b.c" };
        let proj = bson::doc! { "name": 1, "age": 1 };
        let result = apply_projection(&doc, &proj);
        // _id is included by default in inclusion mode.
        assert!(result.contains_key("_id"));
        assert!(result.contains_key("name"));
        assert!(result.contains_key("age"));
        assert!(!result.contains_key("email"));
    }
    #[test]
    fn test_exclusion_projection() {
        let doc = bson::doc! { "_id": 1, "name": "Alice", "age": 30 };
        let proj = bson::doc! { "age": 0 };
        let result = apply_projection(&doc, &proj);
        assert!(result.contains_key("_id"));
        assert!(result.contains_key("name"));
        assert!(!result.contains_key("age"));
    }
    #[test]
    fn test_exclude_id() {
        let doc = bson::doc! { "_id": 1, "name": "Alice" };
        let proj = bson::doc! { "name": 1, "_id": 0 };
        let result = apply_projection(&doc, &proj);
        assert!(!result.contains_key("_id"));
        assert!(result.contains_key("name"));
    }
}

View File

@@ -0,0 +1,137 @@
use bson::{Bson, Document};
use crate::field_path::get_nested_value;
/// Sort documents according to a sort specification.
/// Sort spec: `{ field1: 1, field2: -1 }` where 1 = ascending, -1 =
/// descending; the string forms "desc"/"descending" (case-insensitive) also
/// select descending order.
pub fn sort_documents(docs: &mut [Document], sort_spec: &Document) {
    if sort_spec.is_empty() {
        return;
    }
    // Parse the (field, ascending) pairs once up front: the comparator below
    // runs O(n log n) times and previously re-interpreted the direction value
    // on every single comparison.
    let keys: Vec<(&str, bool)> = sort_spec
        .iter()
        .map(|(field, direction)| {
            let ascending = match direction {
                Bson::Int32(n) => *n > 0,
                Bson::Int64(n) => *n > 0,
                Bson::String(s) => {
                    !s.eq_ignore_ascii_case("desc") && !s.eq_ignore_ascii_case("descending")
                }
                _ => true,
            };
            (field.as_str(), ascending)
        })
        .collect();
    // Stable sort keeps the original order of documents that compare equal.
    docs.sort_by(|a, b| {
        for (field, ascending) in &keys {
            let ord = compare_bson_values(&get_value(a, field), &get_value(b, field));
            let ord = if *ascending { ord } else { ord.reverse() };
            if ord != std::cmp::Ordering::Equal {
                return ord;
            }
        }
        std::cmp::Ordering::Equal
    });
}
/// Fetch a (possibly dotted) field from a document for sorting.
fn get_value(doc: &Document, field: &str) -> Option<Bson> {
    if !field.contains('.') {
        return doc.get(field).cloned();
    }
    get_nested_value(doc, field)
}
/// Compare two BSON values for sorting purposes.
/// Missing fields and explicit `null` sort together, before every other
/// value (BSON type order: null < numbers < strings < objects < arrays <
/// binData < ObjectId < bool < date).
fn compare_bson_values(a: &Option<Bson>, b: &Option<Bson>) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    // Collapse "missing" and "null" into one bucket that sorts first.
    let a_null = matches!(a, None | Some(Bson::Null));
    let b_null = matches!(b, None | Some(Bson::Null));
    match (a_null, b_null) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Less,
        (false, true) => Ordering::Greater,
        (false, false) => match (a, b) {
            (Some(av), Some(bv)) => compare_typed(av, bv),
            // Unreachable: non-null implies Some.
            _ => Ordering::Equal,
        },
    }
}
/// Compare two non-null BSON values, bridging int32/int64/double via f64.
/// Falls back to the BSON type rank for otherwise incomparable pairs.
fn compare_typed(a: &Bson, b: &Bson) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    // Numbers compare across numeric types.
    if let (Some(x), Some(y)) = (to_f64(a), to_f64(b)) {
        return x.partial_cmp(&y).unwrap_or(Ordering::Equal);
    }
    match (a, b) {
        (Bson::String(x), Bson::String(y)) => x.cmp(y),
        (Bson::Boolean(x), Bson::Boolean(y)) => x.cmp(y),
        (Bson::DateTime(x), Bson::DateTime(y)) => x.cmp(y),
        (Bson::ObjectId(x), Bson::ObjectId(y)) => x.cmp(y),
        // Mixed types: order by canonical BSON type rank.
        _ => type_order(a).cmp(&type_order(b)),
    }
}
/// Numeric BSON value widened to f64; `None` for non-numeric types.
fn to_f64(v: &Bson) -> Option<f64> {
    match *v {
        Bson::Int32(n) => Some(f64::from(n)),
        Bson::Int64(n) => Some(n as f64),
        Bson::Double(n) => Some(n),
        _ => None,
    }
}
/// Rank of a BSON value in the canonical sort order
/// (null < numbers < strings < objects < arrays < binData < ObjectId <
/// bool < date; anything else sorts last).
/// NOTE(review): rank 6 is skipped so ObjectId lands on 7 — presumably the
/// gap mirrors the deprecated "undefined" BSON type; confirm intent.
fn type_order(v: &Bson) -> u8 {
    match v {
        Bson::Null => 0,
        Bson::Int32(_) | Bson::Int64(_) | Bson::Double(_) | Bson::Decimal128(_) => 1,
        Bson::String(_) => 2,
        Bson::Document(_) => 3,
        Bson::Array(_) => 4,
        Bson::Binary(_) => 5,
        Bson::ObjectId(_) => 7,
        Bson::Boolean(_) => 8,
        Bson::DateTime(_) => 9,
        _ => 10,
    }
}
/// Unit tests for single-key ascending and descending sorts.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_sort_ascending() {
        let mut docs = vec![
            bson::doc! { "x": 3 },
            bson::doc! { "x": 1 },
            bson::doc! { "x": 2 },
        ];
        sort_documents(&mut docs, &bson::doc! { "x": 1 });
        assert_eq!(docs[0].get_i32("x").unwrap(), 1);
        assert_eq!(docs[1].get_i32("x").unwrap(), 2);
        assert_eq!(docs[2].get_i32("x").unwrap(), 3);
    }
    #[test]
    fn test_sort_descending() {
        let mut docs = vec![
            bson::doc! { "x": 1 },
            bson::doc! { "x": 3 },
            bson::doc! { "x": 2 },
        ];
        sort_documents(&mut docs, &bson::doc! { "x": -1 });
        assert_eq!(docs[0].get_i32("x").unwrap(), 3);
        assert_eq!(docs[1].get_i32("x").unwrap(), 2);
        assert_eq!(docs[2].get_i32("x").unwrap(), 1);
    }
}

View File

@@ -0,0 +1,575 @@
use bson::{Bson, Document, doc};
use crate::error::QueryError;
use crate::field_path::{get_nested_value, set_nested_value, remove_nested_value};
use crate::matcher::QueryMatcher;
/// Update engine — applies update operators to documents.
pub struct UpdateEngine;
impl UpdateEngine {
/// Apply an update specification to a document.
/// Returns the updated document.
pub fn apply_update(
doc: &Document,
update: &Document,
_array_filters: Option<&[Document]>,
) -> Result<Document, QueryError> {
// Check if this is a replacement (no $ operators)
if !update.keys().any(|k| k.starts_with('$')) {
return Self::apply_replacement(doc, update);
}
let mut result = doc.clone();
for (op, value) in update {
let fields = match value {
Bson::Document(d) => d,
_ => continue,
};
match op.as_str() {
"$set" => Self::apply_set(&mut result, fields)?,
"$unset" => Self::apply_unset(&mut result, fields)?,
"$inc" => Self::apply_inc(&mut result, fields)?,
"$mul" => Self::apply_mul(&mut result, fields)?,
"$min" => Self::apply_min(&mut result, fields)?,
"$max" => Self::apply_max(&mut result, fields)?,
"$rename" => Self::apply_rename(&mut result, fields)?,
"$currentDate" => Self::apply_current_date(&mut result, fields)?,
"$setOnInsert" => {} // handled separately during upsert
"$push" => Self::apply_push(&mut result, fields)?,
"$pop" => Self::apply_pop(&mut result, fields)?,
"$pull" => Self::apply_pull(&mut result, fields)?,
"$pullAll" => Self::apply_pull_all(&mut result, fields)?,
"$addToSet" => Self::apply_add_to_set(&mut result, fields)?,
"$bit" => Self::apply_bit(&mut result, fields)?,
other => {
return Err(QueryError::InvalidUpdate(format!(
"Unknown update operator: {}",
other
)));
}
}
}
Ok(result)
}
/// Apply $setOnInsert fields (used during upsert only).
pub fn apply_set_on_insert(doc: &mut Document, fields: &Document) {
for (key, value) in fields {
if key.contains('.') {
set_nested_value(doc, key, value.clone());
} else {
doc.insert(key.clone(), value.clone());
}
}
}
/// Deep clone a BSON document.
pub fn deep_clone(doc: &Document) -> Document {
doc.clone()
}
fn apply_replacement(doc: &Document, replacement: &Document) -> Result<Document, QueryError> {
let mut result = replacement.clone();
// Preserve _id
if let Some(id) = doc.get("_id") {
result.insert("_id", id.clone());
}
Ok(result)
}
fn apply_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, value) in fields {
if key.contains('.') {
set_nested_value(doc, key, value.clone());
} else {
doc.insert(key.clone(), value.clone());
}
}
Ok(())
}
fn apply_unset(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, _) in fields {
if key.contains('.') {
remove_nested_value(doc, key);
} else {
doc.remove(key);
}
}
Ok(())
}
fn apply_inc(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, inc_value) in fields {
let current = if key.contains('.') {
get_nested_value(doc, key)
} else {
doc.get(key).cloned()
};
let new_value = match (&current, inc_value) {
(Some(Bson::Int32(a)), Bson::Int32(b)) => Bson::Int32(a + b),
(Some(Bson::Int64(a)), Bson::Int64(b)) => Bson::Int64(a + b),
(Some(Bson::Int32(a)), Bson::Int64(b)) => Bson::Int64(*a as i64 + b),
(Some(Bson::Int64(a)), Bson::Int32(b)) => Bson::Int64(a + *b as i64),
(Some(Bson::Double(a)), Bson::Double(b)) => Bson::Double(a + b),
(Some(Bson::Int32(a)), Bson::Double(b)) => Bson::Double(*a as f64 + b),
(Some(Bson::Double(a)), Bson::Int32(b)) => Bson::Double(a + *b as f64),
(Some(Bson::Int64(a)), Bson::Double(b)) => Bson::Double(*a as f64 + b),
(Some(Bson::Double(a)), Bson::Int64(b)) => Bson::Double(a + *b as f64),
(None, v) => v.clone(), // treat missing as 0
_ => {
return Err(QueryError::TypeMismatch(format!(
"Cannot apply $inc to non-numeric field: {}",
key
)));
}
};
if key.contains('.') {
set_nested_value(doc, key, new_value);
} else {
doc.insert(key.clone(), new_value);
}
}
Ok(())
}
fn apply_mul(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, mul_value) in fields {
let current = if key.contains('.') {
get_nested_value(doc, key)
} else {
doc.get(key).cloned()
};
let new_value = match (&current, mul_value) {
(Some(Bson::Int32(a)), Bson::Int32(b)) => Bson::Int32(a * b),
(Some(Bson::Int64(a)), Bson::Int64(b)) => Bson::Int64(a * b),
(Some(Bson::Int32(a)), Bson::Int64(b)) => Bson::Int64(*a as i64 * b),
(Some(Bson::Int64(a)), Bson::Int32(b)) => Bson::Int64(a * *b as i64),
(Some(Bson::Double(a)), Bson::Double(b)) => Bson::Double(a * b),
(Some(Bson::Int32(a)), Bson::Double(b)) => Bson::Double(*a as f64 * b),
(Some(Bson::Double(a)), Bson::Int32(b)) => Bson::Double(a * *b as f64),
(None, _) => Bson::Int32(0), // missing field * anything = 0
_ => {
return Err(QueryError::TypeMismatch(format!(
"Cannot apply $mul to non-numeric field: {}",
key
)));
}
};
if key.contains('.') {
set_nested_value(doc, key, new_value);
} else {
doc.insert(key.clone(), new_value);
}
}
Ok(())
}
fn apply_min(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, min_value) in fields {
let current = if key.contains('.') {
get_nested_value(doc, key)
} else {
doc.get(key).cloned()
};
let should_update = match &current {
None => true,
Some(cur) => {
if let Some(ord) = QueryMatcher::bson_compare_pub(min_value, cur) {
ord == std::cmp::Ordering::Less
} else {
false
}
}
};
if should_update {
if key.contains('.') {
set_nested_value(doc, key, min_value.clone());
} else {
doc.insert(key.clone(), min_value.clone());
}
}
}
Ok(())
}
fn apply_max(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, max_value) in fields {
let current = if key.contains('.') {
get_nested_value(doc, key)
} else {
doc.get(key).cloned()
};
let should_update = match &current {
None => true,
Some(cur) => {
if let Some(ord) = QueryMatcher::bson_compare_pub(max_value, cur) {
ord == std::cmp::Ordering::Greater
} else {
false
}
}
};
if should_update {
if key.contains('.') {
set_nested_value(doc, key, max_value.clone());
} else {
doc.insert(key.clone(), max_value.clone());
}
}
}
Ok(())
}
fn apply_rename(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (old_name, new_name_bson) in fields {
let new_name = match new_name_bson {
Bson::String(s) => s.clone(),
_ => continue,
};
if let Some(value) = doc.remove(old_name) {
doc.insert(new_name, value);
}
}
Ok(())
}
fn apply_current_date(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
let now = bson::DateTime::now();
for (key, spec) in fields {
let value = match spec {
Bson::Boolean(true) => Bson::DateTime(now),
Bson::Document(d) => {
match d.get_str("$type").unwrap_or("date") {
"date" => Bson::DateTime(now),
"timestamp" => Bson::Timestamp(bson::Timestamp {
time: (now.timestamp_millis() / 1000) as u32,
increment: 0,
}),
_ => Bson::DateTime(now),
}
}
_ => continue,
};
if key.contains('.') {
set_nested_value(doc, key, value);
} else {
doc.insert(key.clone(), value);
}
}
Ok(())
}
fn apply_push(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, value) in fields {
let arr = Self::get_or_create_array(doc, key);
match value {
Bson::Document(d) if d.contains_key("$each") => {
let each = match d.get("$each") {
Some(Bson::Array(a)) => a.clone(),
_ => return Err(QueryError::InvalidUpdate("$each must be an array".into())),
};
let position = d.get("$position").and_then(|v| match v {
Bson::Int32(n) => Some(*n as usize),
Bson::Int64(n) => Some(*n as usize),
_ => None,
});
if let Some(pos) = position {
let pos = pos.min(arr.len());
for (i, item) in each.into_iter().enumerate() {
arr.insert(pos + i, item);
}
} else {
arr.extend(each);
}
// Apply $sort if present
if let Some(sort_spec) = d.get("$sort") {
Self::sort_array(arr, sort_spec);
}
// Apply $slice if present
if let Some(slice) = d.get("$slice") {
Self::slice_array(arr, slice);
}
}
_ => {
arr.push(value.clone());
}
}
}
Ok(())
}
fn apply_pop(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, direction) in fields {
if let Some(Bson::Array(arr)) = doc.get_mut(key) {
if arr.is_empty() {
continue;
}
match direction {
Bson::Int32(-1) | Bson::Int64(-1) => { arr.remove(0); }
Bson::Int32(1) | Bson::Int64(1) => { arr.pop(); }
Bson::Double(f) if *f == 1.0 => { arr.pop(); }
Bson::Double(f) if *f == -1.0 => { arr.remove(0); }
_ => { arr.pop(); }
}
}
}
Ok(())
}
    /// `$pull`: remove array elements matching a condition.
    ///
    /// * Operator condition (e.g. `{"$gt": 5}`): document elements are
    ///   matched as sub-filters; primitive elements are wrapped in a
    ///   synthetic `{"v": elem}` document so the same matcher can evaluate
    ///   the operators against them.
    /// * Any other condition: plain BSON equality removal.
    ///
    /// NOTE(review): when the condition contains operators and the element is
    /// itself a document, `QueryMatcher::matches(elem_doc, cond_doc)` is
    /// asked to interpret `$`-prefixed keys as field names — confirm the
    /// matcher handles that case as intended.
    fn apply_pull(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, condition) in fields {
            if let Some(Bson::Array(arr)) = doc.get_mut(key) {
                match condition {
                    Bson::Document(cond_doc) if QueryMatcher::has_operators_pub(cond_doc) => {
                        arr.retain(|elem| {
                            if let Bson::Document(elem_doc) = elem {
                                !QueryMatcher::matches(elem_doc, cond_doc)
                            } else {
                                // For primitive matching with operators
                                let wrapper = doc! { "v": elem.clone() };
                                let cond_wrapper = doc! { "v": condition.clone() };
                                !QueryMatcher::matches(&wrapper, &cond_wrapper)
                            }
                        });
                    }
                    _ => {
                        arr.retain(|elem| elem != condition);
                    }
                }
            }
        }
        Ok(())
    }
fn apply_pull_all(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, values) in fields {
if let (Some(Bson::Array(arr)), Bson::Array(to_remove)) = (doc.get_mut(key), values) {
arr.retain(|elem| !to_remove.contains(elem));
}
}
Ok(())
}
fn apply_add_to_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, value) in fields {
let arr = Self::get_or_create_array(doc, key);
match value {
Bson::Document(d) if d.contains_key("$each") => {
if let Some(Bson::Array(each)) = d.get("$each") {
for item in each {
if !arr.contains(item) {
arr.push(item.clone());
}
}
}
}
_ => {
if !arr.contains(value) {
arr.push(value.clone());
}
}
}
}
Ok(())
}
fn apply_bit(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, ops) in fields {
let ops_doc = match ops {
Bson::Document(d) => d,
_ => continue,
};
let current = doc.get(key).cloned().unwrap_or(Bson::Int32(0));
let mut val = match &current {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => continue,
};
for (bit_op, operand) in ops_doc {
let operand_val = match operand {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => continue,
};
match bit_op.as_str() {
"and" => val &= operand_val,
"or" => val |= operand_val,
"xor" => val ^= operand_val,
_ => {}
}
}
let new_value = match &current {
Bson::Int32(_) => Bson::Int32(val as i32),
_ => Bson::Int64(val),
};
doc.insert(key.clone(), new_value);
}
Ok(())
}
// --- Helpers ---
fn get_or_create_array<'a>(doc: &'a mut Document, key: &str) -> &'a mut Vec<Bson> {
// Ensure an array exists at this key
let needs_init = match doc.get(key) {
Some(Bson::Array(_)) => false,
_ => true,
};
if needs_init {
doc.insert(key.to_string(), Bson::Array(Vec::new()));
}
match doc.get_mut(key).unwrap() {
Bson::Array(arr) => arr,
_ => unreachable!(),
}
}
fn sort_array(arr: &mut Vec<Bson>, sort_spec: &Bson) {
match sort_spec {
Bson::Int32(dir) => {
let ascending = *dir > 0;
arr.sort_by(|a, b| {
let ord = partial_cmp_bson(a, b);
if ascending { ord } else { ord.reverse() }
});
}
Bson::Document(spec) => {
arr.sort_by(|a, b| {
for (field, dir) in spec {
let ascending = match dir {
Bson::Int32(n) => *n > 0,
_ => true,
};
let a_val = if let Bson::Document(d) = a { d.get(field) } else { None };
let b_val = if let Bson::Document(d) = b { d.get(field) } else { None };
let ord = match (a_val, b_val) {
(Some(av), Some(bv)) => partial_cmp_bson(av, bv),
(Some(_), None) => std::cmp::Ordering::Greater,
(None, Some(_)) => std::cmp::Ordering::Less,
(None, None) => std::cmp::Ordering::Equal,
};
let ord = if ascending { ord } else { ord.reverse() };
if ord != std::cmp::Ordering::Equal {
return ord;
}
}
std::cmp::Ordering::Equal
});
}
_ => {}
}
}
fn slice_array(arr: &mut Vec<Bson>, slice: &Bson) {
let n = match slice {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => return,
};
if n >= 0 {
arr.truncate(n as usize);
} else {
let keep = (-n) as usize;
if keep < arr.len() {
let start = arr.len() - keep;
*arr = arr[start..].to_vec();
}
}
}
}
/// Best-effort ordering over scalar BSON values, used by `$push`'s `$sort`.
///
/// Same-type comparisons are exact (no f64 precision loss for large i64);
/// mixed numeric types are bridged via f64 — previously e.g. `Int32(2)` vs
/// `Double(1.5)` compared as Equal, inconsistent with the sort module.
/// Non-comparable combinations collapse to Equal.
fn partial_cmp_bson(a: &Bson, b: &Bson) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    // Numeric BSON value widened to f64; None for non-numerics.
    fn numeric(v: &Bson) -> Option<f64> {
        match v {
            Bson::Int32(n) => Some(*n as f64),
            Bson::Int64(n) => Some(*n as f64),
            Bson::Double(n) => Some(*n),
            _ => None,
        }
    }
    match (a, b) {
        (Bson::Int32(x), Bson::Int32(y)) => x.cmp(y),
        (Bson::Int64(x), Bson::Int64(y)) => x.cmp(y),
        (Bson::Double(x), Bson::Double(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
        (Bson::String(x), Bson::String(y)) => x.cmp(y),
        (Bson::Boolean(x), Bson::Boolean(y)) => x.cmp(y),
        _ => match (numeric(a), numeric(b)) {
            (Some(x), Some(y)) => x.partial_cmp(&y).unwrap_or(Ordering::Equal),
            _ => Ordering::Equal,
        },
    }
}
/// Unit tests for the update engine: $set, $inc, $unset, replacement
/// semantics (_id preservation), $push, and $addToSet de-duplication.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_set() {
        let doc = doc! { "_id": 1, "name": "Alice" };
        let update = doc! { "$set": { "name": "Bob", "age": 30 } };
        let result = UpdateEngine::apply_update(&doc, &update, None).unwrap();
        assert_eq!(result.get_str("name").unwrap(), "Bob");
        assert_eq!(result.get_i32("age").unwrap(), 30);
    }
    #[test]
    fn test_inc() {
        let doc = doc! { "_id": 1, "count": 5 };
        let update = doc! { "$inc": { "count": 3 } };
        let result = UpdateEngine::apply_update(&doc, &update, None).unwrap();
        assert_eq!(result.get_i32("count").unwrap(), 8);
    }
    #[test]
    fn test_unset() {
        let doc = doc! { "_id": 1, "name": "Alice", "age": 30 };
        let update = doc! { "$unset": { "age": "" } };
        let result = UpdateEngine::apply_update(&doc, &update, None).unwrap();
        assert!(result.get("age").is_none());
    }
    // A spec with no $ operators replaces the document but keeps _id.
    #[test]
    fn test_replacement() {
        let doc = doc! { "_id": 1, "name": "Alice", "age": 30 };
        let update = doc! { "name": "Bob" };
        let result = UpdateEngine::apply_update(&doc, &update, None).unwrap();
        assert_eq!(result.get_i32("_id").unwrap(), 1); // preserved
        assert_eq!(result.get_str("name").unwrap(), "Bob");
        assert!(result.get("age").is_none()); // removed
    }
    #[test]
    fn test_push() {
        let doc = doc! { "_id": 1, "tags": ["a"] };
        let update = doc! { "$push": { "tags": "b" } };
        let result = UpdateEngine::apply_update(&doc, &update, None).unwrap();
        let tags = result.get_array("tags").unwrap();
        assert_eq!(tags.len(), 2);
    }
    #[test]
    fn test_add_to_set() {
        let doc = doc! { "_id": 1, "tags": ["a", "b"] };
        let update = doc! { "$addToSet": { "tags": "a" } };
        let result = UpdateEngine::apply_update(&doc, &update, None).unwrap();
        let tags = result.get_array("tags").unwrap();
        assert_eq!(tags.len(), 2); // no duplicate
    }
}

View File

@@ -0,0 +1,19 @@
[package]
name = "rustdb-storage"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Storage adapters (memory, file) with WAL and OpLog for RustDb"
[dependencies]
bson = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
dashmap = { workspace = true }
tokio = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
crc32fast = { workspace = true }
uuid = { workspace = true }
async-trait = { workspace = true }

View File

@@ -0,0 +1,185 @@
use std::collections::HashSet;
use async_trait::async_trait;
use bson::Document;
use crate::error::StorageResult;
/// Core storage adapter trait that all backends must implement.
///
/// All operations are async and fallible (`StorageResult`); implementors
/// must be `Send + Sync` so one adapter instance can be shared across
/// connections.
#[async_trait]
pub trait StorageAdapter: Send + Sync {
    // ---- lifecycle ----
    /// Initialize the storage backend (create directories, open files, etc.).
    async fn initialize(&self) -> StorageResult<()>;
    /// Gracefully shut down the storage backend.
    async fn close(&self) -> StorageResult<()>;
    // ---- database operations ----
    /// List all database names.
    async fn list_databases(&self) -> StorageResult<Vec<String>>;
    /// Create a new database.
    async fn create_database(&self, db: &str) -> StorageResult<()>;
    /// Drop a database and all its collections.
    async fn drop_database(&self, db: &str) -> StorageResult<()>;
    /// Check whether a database exists.
    async fn database_exists(&self, db: &str) -> StorageResult<bool>;
    // ---- collection operations ----
    /// List all collection names in a database.
    async fn list_collections(&self, db: &str) -> StorageResult<Vec<String>>;
    /// Create a new collection inside a database.
    async fn create_collection(&self, db: &str, coll: &str) -> StorageResult<()>;
    /// Drop a collection.
    async fn drop_collection(&self, db: &str, coll: &str) -> StorageResult<()>;
    /// Check whether a collection exists.
    async fn collection_exists(&self, db: &str, coll: &str) -> StorageResult<bool>;
    /// Rename a collection within the same database.
    async fn rename_collection(
        &self,
        db: &str,
        old_name: &str,
        new_name: &str,
    ) -> StorageResult<()>;
    // ---- document write operations ----
    /// Insert a single document. Returns the `_id` as hex string.
    async fn insert_one(
        &self,
        db: &str,
        coll: &str,
        doc: Document,
    ) -> StorageResult<String>;
    /// Insert many documents. Returns the `_id` hex strings.
    async fn insert_many(
        &self,
        db: &str,
        coll: &str,
        docs: Vec<Document>,
    ) -> StorageResult<Vec<String>>;
    /// Replace a document by its `_id` hex string.
    async fn update_by_id(
        &self,
        db: &str,
        coll: &str,
        id: &str,
        doc: Document,
    ) -> StorageResult<()>;
    /// Delete a single document by `_id` hex string.
    async fn delete_by_id(
        &self,
        db: &str,
        coll: &str,
        id: &str,
    ) -> StorageResult<()>;
    /// Delete multiple documents by `_id` hex strings.
    async fn delete_by_ids(
        &self,
        db: &str,
        coll: &str,
        ids: &[String],
    ) -> StorageResult<()>;
    // ---- document read operations ----
    /// Return all documents in a collection.
    async fn find_all(
        &self,
        db: &str,
        coll: &str,
    ) -> StorageResult<Vec<Document>>;
    /// Return documents whose `_id` hex is in the given set.
    async fn find_by_ids(
        &self,
        db: &str,
        coll: &str,
        ids: HashSet<String>,
    ) -> StorageResult<Vec<Document>>;
    /// Return a single document by `_id` hex.
    async fn find_by_id(
        &self,
        db: &str,
        coll: &str,
        id: &str,
    ) -> StorageResult<Option<Document>>;
    /// Count documents in a collection.
    async fn count(
        &self,
        db: &str,
        coll: &str,
    ) -> StorageResult<u64>;
    // ---- index operations ----
    /// Persist an index specification for a collection.
    async fn save_index(
        &self,
        db: &str,
        coll: &str,
        name: &str,
        spec: Document,
    ) -> StorageResult<()>;
    /// Return all saved index specs for a collection.
    async fn get_indexes(
        &self,
        db: &str,
        coll: &str,
    ) -> StorageResult<Vec<Document>>;
    /// Drop a named index.
    async fn drop_index(
        &self,
        db: &str,
        coll: &str,
        name: &str,
    ) -> StorageResult<()>;
    // ---- snapshot / conflict detection ----
    /// Create a logical snapshot timestamp for a collection. Returns a timestamp (ms).
    async fn create_snapshot(
        &self,
        db: &str,
        coll: &str,
    ) -> StorageResult<i64>;
    /// Check if any of the given document ids have been modified after `snapshot_time`.
    async fn has_conflicts(
        &self,
        db: &str,
        coll: &str,
        ids: &HashSet<String>,
        snapshot_time: i64,
    ) -> StorageResult<bool>;
    // ---- optional persistence (for in-memory backends) ----
    /// Persist current state to durable storage. Default: no-op.
    async fn persist(&self) -> StorageResult<()> {
        Ok(())
    }
    /// Restore state from durable storage. Default: no-op.
    async fn restore(&self) -> StorageResult<()> {
        Ok(())
    }
}

View File

@@ -0,0 +1,40 @@
use thiserror::Error;
/// Errors that can occur in storage operations.
#[derive(Debug, Error)]
pub enum StorageError {
    /// A database, collection, document, or index was not found.
    #[error("not found: {0}")]
    NotFound(String),
    /// Attempted to create something that already exists.
    #[error("already exists: {0}")]
    AlreadyExists(String),
    /// Underlying filesystem / I/O failure (auto-converted from `std::io::Error`).
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),
    /// JSON or BSON (de)serialization failure; see the `From` impls below.
    #[error("serialization error: {0}")]
    SerializationError(String),
    /// A conflicting concurrent modification was detected.
    #[error("conflict detected: {0}")]
    ConflictError(String),
}
impl From<serde_json::Error> for StorageError {
    /// JSON (de)serialization failures map onto the generic serialization error.
    fn from(e: serde_json::Error) -> Self {
        Self::SerializationError(e.to_string())
    }
}
impl From<bson::de::Error> for StorageError {
    /// BSON deserialization failures map onto the generic serialization error.
    fn from(e: bson::de::Error) -> Self {
        Self::SerializationError(e.to_string())
    }
}
impl From<bson::ser::Error> for StorageError {
    /// BSON serialization failures map onto the generic serialization error.
    fn from(e: bson::ser::Error) -> Self {
        Self::SerializationError(e.to_string())
    }
}
pub type StorageResult<T> = Result<T, StorageError>;

View File

@@ -0,0 +1,476 @@
use std::collections::HashSet;
use std::path::PathBuf;
use async_trait::async_trait;
use bson::{doc, oid::ObjectId, Document};
use tracing::debug;
use crate::adapter::StorageAdapter;
use crate::error::{StorageError, StorageResult};
/// File-based storage adapter. Each collection is stored as a JSON file:
/// `{base_path}/{db}/{coll}.json`
/// Index metadata lives alongside:
/// `{base_path}/{db}/{coll}.indexes.json`
///
/// Every read loads the whole collection file and every write rewrites it,
/// so this adapter favors simplicity over throughput.
pub struct FileStorageAdapter {
    // Root directory; one subdirectory per database is created beneath it.
    base_path: PathBuf,
}
impl FileStorageAdapter {
    /// Create a new adapter rooted at `base_path` (directory is created
    /// lazily by `initialize` / the first write).
    pub fn new(base_path: impl Into<PathBuf>) -> Self {
        Self {
            base_path: base_path.into(),
        }
    }

    /// Directory holding all collection files of database `db`.
    fn db_dir(&self, db: &str) -> PathBuf {
        self.base_path.join(db)
    }

    /// Path of the JSON file holding the documents of `db.coll`.
    fn coll_path(&self, db: &str, coll: &str) -> PathBuf {
        self.db_dir(db).join(format!("{coll}.json"))
    }

    /// Path of the JSON file holding the index specs of `db.coll`.
    fn index_path(&self, db: &str, coll: &str) -> PathBuf {
        self.db_dir(db).join(format!("{coll}.indexes.json"))
    }

    /// Parse a JSON array string into BSON documents.
    ///
    /// Shared by `read_docs` and `read_indexes` (previously duplicated
    /// verbatim) so both files use the same on-disk representation.
    fn parse_doc_array(data: &str) -> StorageResult<Vec<Document>> {
        let json_vals: Vec<serde_json::Value> = serde_json::from_str(data)?;
        let mut docs = Vec::with_capacity(json_vals.len());
        for jv in json_vals {
            let bson_val: bson::Bson = serde_json::from_value(jv)
                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
            let doc = bson_val
                .as_document()
                .ok_or_else(|| StorageError::SerializationError("expected document".into()))?
                .clone();
            docs.push(doc);
        }
        Ok(docs)
    }

    /// Render documents as a pretty-printed JSON array string.
    ///
    /// Shared by `write_docs` and `write_indexes`.
    fn render_doc_array(docs: &[Document]) -> StorageResult<String> {
        let json_vals: Vec<serde_json::Value> = docs
            .iter()
            .map(|d| {
                let b = bson::to_bson(d)
                    .map_err(|e| StorageError::SerializationError(e.to_string()))?;
                serde_json::to_value(&b)
                    .map_err(|e| StorageError::SerializationError(e.to_string()))
            })
            .collect::<StorageResult<Vec<_>>>()?;
        Ok(serde_json::to_string_pretty(&json_vals)?)
    }

    /// Write a JSON array of documents to `path`, creating parent dirs.
    async fn write_doc_array(path: &std::path::Path, docs: &[Document]) -> StorageResult<()> {
        if let Some(parent) = path.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }
        tokio::fs::write(path, Self::render_doc_array(docs)?).await?;
        Ok(())
    }

    /// Read all documents from a collection file.
    ///
    /// A missing collection file is an error (`NotFound`), unlike a missing
    /// indexes file in `read_indexes`.
    async fn read_docs(&self, db: &str, coll: &str) -> StorageResult<Vec<Document>> {
        let path = self.coll_path(db, coll);
        if !path.exists() {
            return Err(StorageError::NotFound(format!(
                "collection '{db}.{coll}'"
            )));
        }
        let data = tokio::fs::read_to_string(&path).await?;
        Self::parse_doc_array(&data)
    }

    /// Write all documents to a collection file (full rewrite).
    async fn write_docs(&self, db: &str, coll: &str, docs: &[Document]) -> StorageResult<()> {
        Self::write_doc_array(&self.coll_path(db, coll), docs).await
    }

    /// Read index specs; a missing indexes file simply means "no indexes".
    async fn read_indexes(&self, db: &str, coll: &str) -> StorageResult<Vec<Document>> {
        let path = self.index_path(db, coll);
        if !path.exists() {
            return Ok(vec![]);
        }
        let data = tokio::fs::read_to_string(&path).await?;
        Self::parse_doc_array(&data)
    }

    /// Write index specs to the indexes file (full rewrite).
    async fn write_indexes(&self, db: &str, coll: &str, specs: &[Document]) -> StorageResult<()> {
        Self::write_doc_array(&self.index_path(db, coll), specs).await
    }

    /// Extract the `_id` of `doc` as an ObjectId hex string.
    ///
    /// Only `ObjectId` ids are supported; a missing or non-ObjectId `_id`
    /// is reported as `NotFound` (matching the original error mapping).
    fn extract_id_hex(doc: &Document) -> StorageResult<String> {
        match doc.get("_id") {
            Some(bson::Bson::ObjectId(oid)) => Ok(oid.to_hex()),
            _ => Err(StorageError::NotFound("document missing _id".into())),
        }
    }
}
#[async_trait]
impl StorageAdapter for FileStorageAdapter {
    /// Create the base directory if it does not exist yet.
    async fn initialize(&self) -> StorageResult<()> {
        tokio::fs::create_dir_all(&self.base_path).await?;
        debug!("FileStorageAdapter initialized at {:?}", self.base_path);
        Ok(())
    }

    /// No file handles are held between operations, so close is a no-op.
    async fn close(&self) -> StorageResult<()> {
        debug!("FileStorageAdapter closed");
        Ok(())
    }

    // ---- database ----

    /// A database is any subdirectory of the base path.
    async fn list_databases(&self) -> StorageResult<Vec<String>> {
        let mut dbs = Vec::new();
        let mut entries = tokio::fs::read_dir(&self.base_path).await?;
        while let Some(entry) = entries.next_entry().await? {
            if entry.file_type().await?.is_dir() {
                if let Some(name) = entry.file_name().to_str() {
                    dbs.push(name.to_string());
                }
            }
        }
        Ok(dbs)
    }

    async fn create_database(&self, db: &str) -> StorageResult<()> {
        let dir = self.db_dir(db);
        if dir.exists() {
            return Err(StorageError::AlreadyExists(format!("database '{db}'")));
        }
        tokio::fs::create_dir_all(&dir).await?;
        Ok(())
    }

    /// Dropping a missing database is not an error (idempotent).
    async fn drop_database(&self, db: &str) -> StorageResult<()> {
        let dir = self.db_dir(db);
        if dir.exists() {
            tokio::fs::remove_dir_all(&dir).await?;
        }
        Ok(())
    }

    async fn database_exists(&self, db: &str) -> StorageResult<bool> {
        Ok(self.db_dir(db).exists())
    }

    // ---- collection ----

    /// A collection is any `*.json` file that is not an `*.indexes.json`
    /// metadata file.
    async fn list_collections(&self, db: &str) -> StorageResult<Vec<String>> {
        let dir = self.db_dir(db);
        if !dir.exists() {
            return Err(StorageError::NotFound(format!("database '{db}'")));
        }
        let mut colls = Vec::new();
        let mut entries = tokio::fs::read_dir(&dir).await?;
        while let Some(entry) = entries.next_entry().await? {
            if let Some(name) = entry.file_name().to_str() {
                if name.ends_with(".indexes.json") {
                    continue;
                }
                // BUGFIX: use `strip_suffix`, which removes exactly one
                // trailing ".json". The previous `trim_end_matches` strips
                // the suffix repeatedly, so a collection named "data.json"
                // (file "data.json.json") was mis-listed as "data".
                if let Some(stem) = name.strip_suffix(".json") {
                    colls.push(stem.to_string());
                }
            }
        }
        Ok(colls)
    }

    /// Create an empty collection plus its default `_id_` index spec.
    async fn create_collection(&self, db: &str, coll: &str) -> StorageResult<()> {
        let path = self.coll_path(db, coll);
        if path.exists() {
            return Err(StorageError::AlreadyExists(format!(
                "collection '{db}.{coll}'"
            )));
        }
        // Ensure db dir exists.
        tokio::fs::create_dir_all(self.db_dir(db)).await?;
        // Write empty array.
        self.write_docs(db, coll, &[]).await?;
        // Write default _id index.
        let idx_spec = doc! { "name": "_id_", "key": { "_id": 1 } };
        self.write_indexes(db, coll, &[idx_spec]).await?;
        Ok(())
    }

    /// Remove a collection's data and index files (idempotent).
    async fn drop_collection(&self, db: &str, coll: &str) -> StorageResult<()> {
        let path = self.coll_path(db, coll);
        if path.exists() {
            tokio::fs::remove_file(&path).await?;
        }
        let idx_path = self.index_path(db, coll);
        if idx_path.exists() {
            tokio::fs::remove_file(&idx_path).await?;
        }
        Ok(())
    }

    async fn collection_exists(&self, db: &str, coll: &str) -> StorageResult<bool> {
        Ok(self.coll_path(db, coll).exists())
    }

    /// Rename both the data file and (if present) the indexes file.
    async fn rename_collection(
        &self,
        db: &str,
        old_name: &str,
        new_name: &str,
    ) -> StorageResult<()> {
        let old_path = self.coll_path(db, old_name);
        let new_path = self.coll_path(db, new_name);
        if !old_path.exists() {
            return Err(StorageError::NotFound(format!(
                "collection '{db}.{old_name}'"
            )));
        }
        if new_path.exists() {
            return Err(StorageError::AlreadyExists(format!(
                "collection '{db}.{new_name}'"
            )));
        }
        tokio::fs::rename(&old_path, &new_path).await?;
        // Rename index file too.
        let old_idx = self.index_path(db, old_name);
        let new_idx = self.index_path(db, new_name);
        if old_idx.exists() {
            tokio::fs::rename(&old_idx, &new_idx).await?;
        }
        Ok(())
    }

    // ---- document writes ----

    /// Insert a single document, assigning an `ObjectId` `_id` if absent.
    /// Fails with `AlreadyExists` on a duplicate id.
    async fn insert_one(
        &self,
        db: &str,
        coll: &str,
        mut doc: Document,
    ) -> StorageResult<String> {
        if !doc.contains_key("_id") {
            doc.insert("_id", ObjectId::new());
        }
        let id = Self::extract_id_hex(&doc)?;
        let mut docs = self.read_docs(db, coll).await?;
        // Check for duplicate.
        for existing in &docs {
            if Self::extract_id_hex(existing)? == id {
                return Err(StorageError::AlreadyExists(format!("document '{id}'")));
            }
        }
        docs.push(doc);
        self.write_docs(db, coll, &docs).await?;
        Ok(id)
    }

    /// Insert a batch of documents.
    ///
    /// BUGFIX: unlike before, ids are now checked for duplicates — both
    /// against the stored documents and within the batch — BEFORE anything
    /// is written, matching `insert_one`'s contract. A duplicate id fails
    /// the whole batch without a partial write.
    async fn insert_many(
        &self,
        db: &str,
        coll: &str,
        mut new_docs: Vec<Document>,
    ) -> StorageResult<Vec<String>> {
        let mut docs = self.read_docs(db, coll).await?;
        let mut seen: HashSet<String> = HashSet::with_capacity(docs.len() + new_docs.len());
        for existing in &docs {
            seen.insert(Self::extract_id_hex(existing)?);
        }
        let mut ids = Vec::with_capacity(new_docs.len());
        for doc in &mut new_docs {
            if !doc.contains_key("_id") {
                doc.insert("_id", ObjectId::new());
            }
            let id = Self::extract_id_hex(doc)?;
            if !seen.insert(id.clone()) {
                return Err(StorageError::AlreadyExists(format!("document '{id}'")));
            }
            ids.push(id);
        }
        docs.extend(new_docs);
        self.write_docs(db, coll, &docs).await?;
        Ok(ids)
    }

    /// Replace the document with the given id; `NotFound` if absent.
    async fn update_by_id(
        &self,
        db: &str,
        coll: &str,
        id: &str,
        doc: Document,
    ) -> StorageResult<()> {
        let mut docs = self.read_docs(db, coll).await?;
        let mut found = false;
        for existing in &mut docs {
            if Self::extract_id_hex(existing)? == id {
                *existing = doc.clone();
                found = true;
                break;
            }
        }
        if !found {
            return Err(StorageError::NotFound(format!("document '{id}'")));
        }
        self.write_docs(db, coll, &docs).await?;
        Ok(())
    }

    /// Delete the document with the given id; `NotFound` if absent.
    async fn delete_by_id(
        &self,
        db: &str,
        coll: &str,
        id: &str,
    ) -> StorageResult<()> {
        let mut docs = self.read_docs(db, coll).await?;
        let len_before = docs.len();
        // Documents whose id cannot be extracted are deliberately kept.
        docs.retain(|d| Self::extract_id_hex(d).map(|i| i != id).unwrap_or(true));
        if docs.len() == len_before {
            return Err(StorageError::NotFound(format!("document '{id}'")));
        }
        self.write_docs(db, coll, &docs).await?;
        Ok(())
    }

    /// Delete any documents whose ids are in `ids`. Missing ids are
    /// silently ignored (best-effort batch delete).
    async fn delete_by_ids(
        &self,
        db: &str,
        coll: &str,
        ids: &[String],
    ) -> StorageResult<()> {
        let id_set: HashSet<&str> = ids.iter().map(|s| s.as_str()).collect();
        let mut docs = self.read_docs(db, coll).await?;
        docs.retain(|d| {
            Self::extract_id_hex(d)
                .map(|i| !id_set.contains(i.as_str()))
                .unwrap_or(true)
        });
        self.write_docs(db, coll, &docs).await?;
        Ok(())
    }

    // ---- document reads ----

    async fn find_all(
        &self,
        db: &str,
        coll: &str,
    ) -> StorageResult<Vec<Document>> {
        self.read_docs(db, coll).await
    }

    /// Return every stored document whose id is in `ids` (missing ids are
    /// simply absent from the result).
    async fn find_by_ids(
        &self,
        db: &str,
        coll: &str,
        ids: HashSet<String>,
    ) -> StorageResult<Vec<Document>> {
        let docs = self.read_docs(db, coll).await?;
        Ok(docs
            .into_iter()
            .filter(|d| {
                Self::extract_id_hex(d)
                    .map(|i| ids.contains(&i))
                    .unwrap_or(false)
            })
            .collect())
    }

    async fn find_by_id(
        &self,
        db: &str,
        coll: &str,
        id: &str,
    ) -> StorageResult<Option<Document>> {
        let docs = self.read_docs(db, coll).await?;
        Ok(docs
            .into_iter()
            .find(|d| Self::extract_id_hex(d).map(|i| i == id).unwrap_or(false)))
    }

    async fn count(
        &self,
        db: &str,
        coll: &str,
    ) -> StorageResult<u64> {
        let docs = self.read_docs(db, coll).await?;
        Ok(docs.len() as u64)
    }

    // ---- indexes ----

    /// Save (or replace) the index spec named `name`.
    async fn save_index(
        &self,
        db: &str,
        coll: &str,
        name: &str,
        spec: Document,
    ) -> StorageResult<()> {
        let mut indexes = self.read_indexes(db, coll).await?;
        indexes.retain(|s| s.get_str("name").unwrap_or("") != name);
        let mut full_spec = spec;
        full_spec.insert("name", name);
        indexes.push(full_spec);
        self.write_indexes(db, coll, &indexes).await
    }

    async fn get_indexes(
        &self,
        db: &str,
        coll: &str,
    ) -> StorageResult<Vec<Document>> {
        self.read_indexes(db, coll).await
    }

    /// Remove the index spec named `name`; `NotFound` if it doesn't exist.
    async fn drop_index(
        &self,
        db: &str,
        coll: &str,
        name: &str,
    ) -> StorageResult<()> {
        let mut indexes = self.read_indexes(db, coll).await?;
        let before = indexes.len();
        indexes.retain(|s| s.get_str("name").unwrap_or("") != name);
        if indexes.len() == before {
            return Err(StorageError::NotFound(format!("index '{name}'")));
        }
        self.write_indexes(db, coll, &indexes).await
    }

    // ---- snapshot / conflict detection ----
    // File adapter doesn't track per-document timestamps, so conflict
    // detection is a no-op (always returns false).

    async fn create_snapshot(
        &self,
        _db: &str,
        _coll: &str,
    ) -> StorageResult<i64> {
        use std::time::{SystemTime, UNIX_EPOCH};
        Ok(SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
            .as_millis() as i64)
    }

    async fn has_conflicts(
        &self,
        _db: &str,
        _coll: &str,
        _ids: &HashSet<String>,
        _snapshot_time: i64,
    ) -> StorageResult<bool> {
        // File adapter does not track modification timestamps per document.
        Ok(false)
    }
}

View File

@@ -0,0 +1,22 @@
//! `rustdb-storage` -- Storage adapters for RustDb.
//!
//! Provides the [`StorageAdapter`] trait and two concrete implementations:
//! - [`MemoryStorageAdapter`] -- fast in-memory store backed by `DashMap`
//! - [`FileStorageAdapter`] -- JSON-file-per-collection persistent store
//!
//! Also includes an [`OpLog`] for operation logging and a [`WriteAheadLog`]
//! for crash recovery.

// Public submodules.
pub mod adapter;
pub mod error;
pub mod file;
pub mod memory;
pub mod oplog;
pub mod wal;

// Flat re-exports so downstream crates can write `rustdb_storage::X`
// without naming the submodule.
pub use adapter::StorageAdapter;
pub use error::{StorageError, StorageResult};
pub use file::FileStorageAdapter;
pub use memory::MemoryStorageAdapter;
pub use oplog::{OpLog, OpLogEntry, OpType};
pub use wal::{WalOp, WalRecord, WriteAheadLog};

View File

@@ -0,0 +1,613 @@
use std::collections::HashSet;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use bson::{doc, oid::ObjectId, Document};
use dashmap::DashMap;
use tracing::{debug, warn};
use crate::adapter::StorageAdapter;
use crate::error::{StorageError, StorageResult};
/// Per-document timestamp tracking for conflict detection:
/// document id (hex) -> last-modified time in ms.
type TimestampMap = DashMap<String, i64>;
/// Document storage: db -> coll -> id_hex -> Document.
type DataStore = DashMap<String, DashMap<String, DashMap<String, Document>>>;
/// Index metadata: db -> coll -> Vec<index spec Document>.
type IndexStore = DashMap<String, DashMap<String, Vec<Document>>>;
/// Modification tracking: db -> coll -> id_hex -> last_modified_ms.
type ModificationStore = DashMap<String, DashMap<String, TimestampMap>>;
/// Current wall-clock time in milliseconds since the UNIX epoch.
///
/// Uses `expect` with the stated invariant instead of a bare `unwrap`:
/// `duration_since(UNIX_EPOCH)` can only fail if the system clock is set
/// before 1970. The `as i64` truncation is safe for hundreds of millennia.
fn now_ms() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the UNIX epoch")
        .as_millis() as i64
}
/// In-memory storage adapter backed by `DashMap`.
///
/// Optionally persists to a JSON file at a configured path.
pub struct MemoryStorageAdapter {
    // Primary document store (db -> coll -> id -> doc).
    data: DataStore,
    // Index specs, kept in lock-step with `data`'s db/coll structure.
    indexes: IndexStore,
    // Per-document last-modified timestamps for conflict detection.
    modifications: ModificationStore,
    // When `Some`, `persist`/`restore` read/write this JSON file;
    // when `None`, both are no-ops.
    persist_path: Option<PathBuf>,
}
impl MemoryStorageAdapter {
    /// Create a new purely in-memory adapter (no persistence).
    pub fn new() -> Self {
        Self {
            data: DashMap::new(),
            indexes: DashMap::new(),
            modifications: DashMap::new(),
            persist_path: None,
        }
    }

    /// Create an adapter that persists to the given JSON file on
    /// `persist()` / `close()` and can `restore()` from it.
    pub fn with_persist_path(path: PathBuf) -> Self {
        Self {
            persist_path: Some(path),
            ..Self::new()
        }
    }

    /// Idempotently create the database entry in all three stores
    /// (`entry(..).or_default()` is the clippy-preferred spelling of
    /// `or_insert_with(DashMap::new)`).
    fn ensure_db(&self, db: &str) {
        self.data.entry(db.to_string()).or_default();
        self.indexes.entry(db.to_string()).or_default();
        self.modifications.entry(db.to_string()).or_default();
    }

    /// Extract the `_id` of `doc` as an ObjectId hex string.
    ///
    /// Only `ObjectId` ids are supported; a missing or non-ObjectId `_id`
    /// is reported as `NotFound` (matching the original error mapping).
    fn extract_id(doc: &Document) -> StorageResult<String> {
        match doc.get("_id") {
            Some(bson::Bson::ObjectId(oid)) => Ok(oid.to_hex()),
            _ => Err(StorageError::NotFound("document missing _id".into())),
        }
    }

    /// Stamp `id` as modified "now" for conflict detection.
    ///
    /// Silently does nothing if the db/collection tracker is missing, so
    /// it is safe to call after any mutation.
    fn record_modification(&self, db: &str, coll: &str, id: &str) {
        if let Some(db_mods) = self.modifications.get(db) {
            if let Some(coll_mods) = db_mods.get(coll) {
                coll_mods.insert(id.to_string(), now_ms());
            }
        }
    }
}
#[async_trait]
impl StorageAdapter for MemoryStorageAdapter {
async fn initialize(&self) -> StorageResult<()> {
debug!("MemoryStorageAdapter initialized");
Ok(())
}
async fn close(&self) -> StorageResult<()> {
// Persist if configured.
self.persist().await?;
debug!("MemoryStorageAdapter closed");
Ok(())
}
// ---- database ----
async fn list_databases(&self) -> StorageResult<Vec<String>> {
Ok(self.data.iter().map(|e| e.key().clone()).collect())
}
async fn create_database(&self, db: &str) -> StorageResult<()> {
if self.data.contains_key(db) {
return Err(StorageError::AlreadyExists(format!("database '{db}'")));
}
self.ensure_db(db);
Ok(())
}
async fn drop_database(&self, db: &str) -> StorageResult<()> {
self.data.remove(db);
self.indexes.remove(db);
self.modifications.remove(db);
Ok(())
}
async fn database_exists(&self, db: &str) -> StorageResult<bool> {
Ok(self.data.contains_key(db))
}
// ---- collection ----
async fn list_collections(&self, db: &str) -> StorageResult<Vec<String>> {
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
Ok(db_ref.iter().map(|e| e.key().clone()).collect())
}
async fn create_collection(&self, db: &str, coll: &str) -> StorageResult<()> {
self.ensure_db(db);
let db_ref = self.data.get(db).unwrap();
if db_ref.contains_key(coll) {
return Err(StorageError::AlreadyExists(format!(
"collection '{db}.{coll}'"
)));
}
db_ref.insert(coll.to_string(), DashMap::new());
drop(db_ref);
// Create modification tracker for this collection.
if let Some(db_mods) = self.modifications.get(db) {
db_mods.insert(coll.to_string(), DashMap::new());
}
// Auto-create _id index spec.
let idx_spec = doc! { "name": "_id_", "key": { "_id": 1 } };
if let Some(db_idx) = self.indexes.get(db) {
db_idx.insert(coll.to_string(), vec![idx_spec]);
}
Ok(())
}
async fn drop_collection(&self, db: &str, coll: &str) -> StorageResult<()> {
if let Some(db_ref) = self.data.get(db) {
db_ref.remove(coll);
}
if let Some(db_idx) = self.indexes.get(db) {
db_idx.remove(coll);
}
if let Some(db_mods) = self.modifications.get(db) {
db_mods.remove(coll);
}
Ok(())
}
async fn collection_exists(&self, db: &str, coll: &str) -> StorageResult<bool> {
Ok(self
.data
.get(db)
.map(|db_ref| db_ref.contains_key(coll))
.unwrap_or(false))
}
async fn rename_collection(
&self,
db: &str,
old_name: &str,
new_name: &str,
) -> StorageResult<()> {
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
if db_ref.contains_key(new_name) {
return Err(StorageError::AlreadyExists(format!(
"collection '{db}.{new_name}'"
)));
}
let (_, coll_data) = db_ref
.remove(old_name)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{old_name}'")))?;
db_ref.insert(new_name.to_string(), coll_data);
drop(db_ref);
// Rename in indexes.
if let Some(db_idx) = self.indexes.get(db) {
if let Some((_, idx_data)) = db_idx.remove(old_name) {
db_idx.insert(new_name.to_string(), idx_data);
}
}
// Rename in modifications.
if let Some(db_mods) = self.modifications.get(db) {
if let Some((_, mod_data)) = db_mods.remove(old_name) {
db_mods.insert(new_name.to_string(), mod_data);
}
}
Ok(())
}
// ---- document writes ----
async fn insert_one(
&self,
db: &str,
coll: &str,
mut doc: Document,
) -> StorageResult<String> {
// Ensure _id exists.
if !doc.contains_key("_id") {
doc.insert("_id", ObjectId::new());
}
let id = Self::extract_id(&doc)?;
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let coll_ref = db_ref
.get(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
if coll_ref.contains_key(&id) {
return Err(StorageError::AlreadyExists(format!("document '{id}'")));
}
coll_ref.insert(id.clone(), doc);
drop(coll_ref);
drop(db_ref);
self.record_modification(db, coll, &id);
Ok(id)
}
async fn insert_many(
&self,
db: &str,
coll: &str,
docs: Vec<Document>,
) -> StorageResult<Vec<String>> {
let mut ids = Vec::with_capacity(docs.len());
for doc in docs {
let id = self.insert_one(db, coll, doc).await?;
ids.push(id);
}
Ok(ids)
}
async fn update_by_id(
&self,
db: &str,
coll: &str,
id: &str,
doc: Document,
) -> StorageResult<()> {
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let coll_ref = db_ref
.get(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
if !coll_ref.contains_key(id) {
return Err(StorageError::NotFound(format!("document '{id}'")));
}
coll_ref.insert(id.to_string(), doc);
drop(coll_ref);
drop(db_ref);
self.record_modification(db, coll, id);
Ok(())
}
async fn delete_by_id(
&self,
db: &str,
coll: &str,
id: &str,
) -> StorageResult<()> {
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let coll_ref = db_ref
.get(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
coll_ref
.remove(id)
.ok_or_else(|| StorageError::NotFound(format!("document '{id}'")))?;
drop(coll_ref);
drop(db_ref);
self.record_modification(db, coll, id);
Ok(())
}
async fn delete_by_ids(
&self,
db: &str,
coll: &str,
ids: &[String],
) -> StorageResult<()> {
for id in ids {
self.delete_by_id(db, coll, id).await?;
}
Ok(())
}
// ---- document reads ----
async fn find_all(
&self,
db: &str,
coll: &str,
) -> StorageResult<Vec<Document>> {
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let coll_ref = db_ref
.get(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
Ok(coll_ref.iter().map(|e| e.value().clone()).collect())
}
async fn find_by_ids(
&self,
db: &str,
coll: &str,
ids: HashSet<String>,
) -> StorageResult<Vec<Document>> {
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let coll_ref = db_ref
.get(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
let mut results = Vec::with_capacity(ids.len());
for id in &ids {
if let Some(doc) = coll_ref.get(id) {
results.push(doc.value().clone());
}
}
Ok(results)
}
async fn find_by_id(
&self,
db: &str,
coll: &str,
id: &str,
) -> StorageResult<Option<Document>> {
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let coll_ref = db_ref
.get(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
Ok(coll_ref.get(id).map(|e| e.value().clone()))
}
async fn count(
&self,
db: &str,
coll: &str,
) -> StorageResult<u64> {
let db_ref = self
.data
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let coll_ref = db_ref
.get(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
Ok(coll_ref.len() as u64)
}
// ---- indexes ----
async fn save_index(
&self,
db: &str,
coll: &str,
name: &str,
spec: Document,
) -> StorageResult<()> {
let db_idx = self
.indexes
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let mut specs = db_idx
.get_mut(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
// Remove existing index with same name, then add.
specs.retain(|s| s.get_str("name").unwrap_or("") != name);
let mut full_spec = spec;
full_spec.insert("name", name);
specs.push(full_spec);
Ok(())
}
async fn get_indexes(
&self,
db: &str,
coll: &str,
) -> StorageResult<Vec<Document>> {
let db_idx = self
.indexes
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let specs = db_idx
.get(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
Ok(specs.clone())
}
async fn drop_index(
&self,
db: &str,
coll: &str,
name: &str,
) -> StorageResult<()> {
let db_idx = self
.indexes
.get(db)
.ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?;
let mut specs = db_idx
.get_mut(coll)
.ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?;
let before = specs.len();
specs.retain(|s| s.get_str("name").unwrap_or("") != name);
if specs.len() == before {
return Err(StorageError::NotFound(format!("index '{name}'")));
}
Ok(())
}
// ---- snapshot / conflict detection ----
async fn create_snapshot(
&self,
_db: &str,
_coll: &str,
) -> StorageResult<i64> {
Ok(now_ms())
}
async fn has_conflicts(
&self,
db: &str,
coll: &str,
ids: &HashSet<String>,
snapshot_time: i64,
) -> StorageResult<bool> {
if let Some(db_mods) = self.modifications.get(db) {
if let Some(coll_mods) = db_mods.get(coll) {
for id in ids {
if let Some(ts) = coll_mods.get(id) {
if *ts.value() > snapshot_time {
return Ok(true);
}
}
}
}
}
Ok(false)
}
// ---- persistence ----
async fn persist(&self) -> StorageResult<()> {
let path = match &self.persist_path {
Some(p) => p,
None => return Ok(()),
};
// Serialize the entire data store to JSON.
let mut db_map = serde_json::Map::new();
for db_entry in self.data.iter() {
let db_name = db_entry.key().clone();
let mut coll_map = serde_json::Map::new();
for coll_entry in db_entry.value().iter() {
let coll_name = coll_entry.key().clone();
let mut docs_map = serde_json::Map::new();
for doc_entry in coll_entry.value().iter() {
let id = doc_entry.key().clone();
// Convert bson::Document -> serde_json::Value via bson's
// built-in extended-JSON serialization.
let json_val: serde_json::Value =
bson::to_bson(doc_entry.value())
.map_err(|e| StorageError::SerializationError(e.to_string()))
.and_then(|b| {
serde_json::to_value(&b)
.map_err(|e| StorageError::SerializationError(e.to_string()))
})?;
docs_map.insert(id, json_val);
}
coll_map.insert(coll_name, serde_json::Value::Object(docs_map));
}
db_map.insert(db_name, serde_json::Value::Object(coll_map));
}
let json = serde_json::to_string_pretty(&serde_json::Value::Object(db_map))?;
if let Some(parent) = path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
tokio::fs::write(path, json).await?;
debug!("MemoryStorageAdapter persisted to {:?}", path);
Ok(())
}
async fn restore(&self) -> StorageResult<()> {
let path = match &self.persist_path {
Some(p) => p,
None => return Ok(()),
};
if !path.exists() {
warn!("persist file not found at {:?}, skipping restore", path);
return Ok(());
}
let json = tokio::fs::read_to_string(path).await?;
let root: serde_json::Value = serde_json::from_str(&json)?;
let root_obj = root
.as_object()
.ok_or_else(|| StorageError::SerializationError("expected object".into()))?;
self.data.clear();
self.indexes.clear();
self.modifications.clear();
for (db_name, colls_val) in root_obj {
self.ensure_db(db_name);
let db_ref = self.data.get(db_name).unwrap();
let colls = colls_val
.as_object()
.ok_or_else(|| StorageError::SerializationError("expected object".into()))?;
for (coll_name, docs_val) in colls {
let coll_map: DashMap<String, Document> = DashMap::new();
let docs = docs_val
.as_object()
.ok_or_else(|| StorageError::SerializationError("expected object".into()))?;
for (id, doc_val) in docs {
let bson_val: bson::Bson = serde_json::from_value(doc_val.clone())
.map_err(|e| StorageError::SerializationError(e.to_string()))?;
let doc = bson_val
.as_document()
.ok_or_else(|| {
StorageError::SerializationError("expected document".into())
})?
.clone();
coll_map.insert(id.clone(), doc);
}
db_ref.insert(coll_name.clone(), coll_map);
// Restore modification tracker and default _id index.
if let Some(db_mods) = self.modifications.get(db_name) {
db_mods.insert(coll_name.clone(), DashMap::new());
}
if let Some(db_idx) = self.indexes.get(db_name) {
let idx_spec = doc! { "name": "_id_", "key": { "_id": 1 } };
db_idx.insert(coll_name.clone(), vec![idx_spec]);
}
}
}
debug!("MemoryStorageAdapter restored from {:?}", path);
Ok(())
}
}
/// `Default` is the non-persisting configuration (same as `new`).
impl Default for MemoryStorageAdapter {
    fn default() -> Self {
        Self::new()
    }
}

View File

@@ -0,0 +1,120 @@
//! Operation log (OpLog) for tracking mutations.
//!
//! The OpLog records every write operation so that changes can be replayed,
//! replicated, or used for change-stream style notifications.
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use bson::Document;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
/// The type of operation recorded in the oplog.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum OpType {
    /// A new document was inserted.
    Insert,
    /// An existing document was replaced/updated.
    Update,
    /// A document was deleted.
    Delete,
}
/// A single oplog entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpLogEntry {
    /// Monotonically increasing sequence number (starts at 1).
    pub seq: u64,
    /// Timestamp in milliseconds since UNIX epoch.
    pub timestamp_ms: i64,
    /// Operation type.
    pub op: OpType,
    /// Database name.
    pub db: String,
    /// Collection name.
    pub collection: String,
    /// Document id (hex string).
    pub document_id: String,
    /// The document snapshot (for insert/update; None for delete).
    pub document: Option<Document>,
}
/// In-memory operation log.
///
/// Entries are held in a concurrent map keyed by sequence number; the log
/// is unbounded until `clear` is called.
pub struct OpLog {
    /// All entries keyed by sequence number.
    entries: DashMap<u64, OpLogEntry>,
    /// Next sequence number to hand out (so `next_seq - 1` is the latest).
    next_seq: AtomicU64,
}
impl OpLog {
    /// Create an empty oplog; sequence numbers start at 1.
    pub fn new() -> Self {
        Self {
            entries: DashMap::new(),
            next_seq: AtomicU64::new(1),
        }
    }

    /// Record one operation and return the sequence number assigned to it.
    pub fn append(
        &self,
        op: OpType,
        db: &str,
        collection: &str,
        document_id: &str,
        document: Option<Document>,
    ) -> u64 {
        let seq = self.next_seq.fetch_add(1, Ordering::SeqCst);
        let timestamp_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis() as i64;
        self.entries.insert(
            seq,
            OpLogEntry {
                seq,
                timestamp_ms,
                op,
                db: db.to_string(),
                collection: collection.to_string(),
                document_id: document_id.to_string(),
                document,
            },
        );
        seq
    }

    /// All entries whose sequence number is `>= since`, ascending by seq.
    pub fn entries_since(&self, since: u64) -> Vec<OpLogEntry> {
        let mut matching: Vec<OpLogEntry> = self
            .entries
            .iter()
            .filter_map(|e| (*e.key() >= since).then(|| e.value().clone()))
            .collect();
        matching.sort_by_key(|e| e.seq);
        matching
    }

    /// Latest assigned sequence number, or 0 when nothing was appended.
    pub fn current_seq(&self) -> u64 {
        self.next_seq.load(Ordering::SeqCst).saturating_sub(1)
    }

    /// Drop every entry and restart sequence numbering at 1.
    pub fn clear(&self) {
        self.entries.clear();
        self.next_seq.store(1, Ordering::SeqCst);
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// `Default` is an empty log (same as `new`).
impl Default for OpLog {
    fn default() -> Self {
        Self::new()
    }
}

View File

@@ -0,0 +1,186 @@
//! Write-Ahead Log (WAL) for crash recovery.
//!
//! Before any mutation is applied to storage, it is first written to the WAL.
//! On recovery, uncommitted WAL entries can be replayed or discarded.
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use bson::Document;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tracing::{debug, warn};
use crate::error::StorageResult;
/// WAL operation kind.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum WalOp {
    /// Insert of a new document.
    Insert,
    /// Replacement/update of an existing document.
    Update,
    /// Deletion of a document.
    Delete,
}
/// A single WAL record (one JSON line in the WAL file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalRecord {
    /// Sequence number.
    pub seq: u64,
    /// Operation kind.
    pub op: WalOp,
    /// Database name.
    pub db: String,
    /// Collection name.
    pub collection: String,
    /// Document id (hex string).
    pub document_id: String,
    /// Document data (for insert/update).
    pub document: Option<Document>,
    /// Whether this record has been committed (applied to storage).
    pub committed: bool,
    /// CRC32 checksum of the serialized payload for integrity verification.
    /// Computed over op/db/collection/document_id only — the `document`
    /// field is NOT covered (see `WriteAheadLog::append`).
    pub checksum: u32,
}
/// Write-ahead log that persists records to a file, one JSON record per
/// line (JSON-lines format).
pub struct WriteAheadLog {
    // Path of the WAL file on disk.
    path: PathBuf,
    // Next sequence number to hand out; resumed from disk by `initialize`.
    next_seq: AtomicU64,
}
impl WriteAheadLog {
    /// Create a new WAL handle for the given file path (the file itself is
    /// created lazily by `initialize` / `append`).
    pub fn new(path: PathBuf) -> Self {
        Self {
            path,
            next_seq: AtomicU64::new(1),
        }
    }

    /// Initialize the WAL: create the parent directory and resume the
    /// sequence counter from the highest sequence found on disk.
    pub async fn initialize(&self) -> StorageResult<()> {
        if let Some(parent) = self.path.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }
        if self.path.exists() {
            // Load existing records to find the max sequence number.
            let records = self.read_all().await?;
            if let Some(max_seq) = records.iter().map(|r| r.seq).max() {
                self.next_seq.store(max_seq + 1, Ordering::SeqCst);
            }
        }
        debug!("WAL initialized at {:?}", self.path);
        Ok(())
    }

    /// Append a record to the WAL. Returns the sequence number.
    ///
    /// The record is synced to stable storage before this returns.
    ///
    /// NOTE(review): the checksum covers only op/db/collection/document_id,
    /// not the `document` payload — corruption of the document itself would
    /// go undetected. Confirm whether that is intentional.
    pub async fn append(
        &self,
        op: WalOp,
        db: &str,
        collection: &str,
        document_id: &str,
        document: Option<Document>,
    ) -> StorageResult<u64> {
        let seq = self.next_seq.fetch_add(1, Ordering::SeqCst);
        // Compute checksum over the payload.
        let payload = serde_json::json!({
            "op": op,
            "db": db,
            "collection": collection,
            "document_id": document_id,
        });
        let payload_bytes = serde_json::to_vec(&payload)?;
        let checksum = crc32fast::hash(&payload_bytes);
        let record = WalRecord {
            seq,
            op,
            db: db.to_string(),
            collection: collection.to_string(),
            document_id: document_id.to_string(),
            document,
            committed: false,
            checksum,
        };
        let line = serde_json::to_string(&record)?;
        let mut file = tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&self.path)
            .await?;
        file.write_all(line.as_bytes()).await?;
        file.write_all(b"\n").await?;
        file.flush().await?;
        // BUGFIX: `flush` only drains user-space buffers; without an fsync
        // the record can be lost in a crash, defeating the WAL's purpose.
        // `sync_data` forces the bytes to stable storage before we report
        // the append as successful.
        file.sync_data().await?;
        Ok(seq)
    }

    /// Mark a WAL record as committed by rewriting the file.
    /// O(file size) — acceptable because the WAL is truncated regularly.
    pub async fn mark_committed(&self, seq: u64) -> StorageResult<()> {
        let mut records = self.read_all().await?;
        for record in &mut records {
            if record.seq == seq {
                record.committed = true;
            }
        }
        self.write_all(&records).await
    }

    /// Read all WAL records. Unparseable lines are logged and skipped so a
    /// torn final write does not poison recovery.
    pub async fn read_all(&self) -> StorageResult<Vec<WalRecord>> {
        if !self.path.exists() {
            return Ok(vec![]);
        }
        let data = tokio::fs::read_to_string(&self.path).await?;
        let mut records = Vec::new();
        for line in data.lines() {
            if line.trim().is_empty() {
                continue;
            }
            match serde_json::from_str::<WalRecord>(line) {
                Ok(record) => records.push(record),
                Err(e) => {
                    warn!("skipping corrupt WAL record: {e}");
                }
            }
        }
        Ok(records)
    }

    /// Get all uncommitted records (for replay during recovery).
    pub async fn uncommitted(&self) -> StorageResult<Vec<WalRecord>> {
        let records = self.read_all().await?;
        Ok(records.into_iter().filter(|r| !r.committed).collect())
    }

    /// Truncate the WAL, removing all committed records.
    pub async fn truncate_committed(&self) -> StorageResult<()> {
        let records = self.read_all().await?;
        let uncommitted: Vec<_> = records.into_iter().filter(|r| !r.committed).collect();
        self.write_all(&uncommitted).await
    }

    /// Clear the entire WAL and restart sequence numbering at 1.
    pub async fn clear(&self) -> StorageResult<()> {
        if self.path.exists() {
            tokio::fs::write(&self.path, "").await?;
        }
        self.next_seq.store(1, Ordering::SeqCst);
        Ok(())
    }

    /// Write all records to the WAL file (overwrites the whole file).
    async fn write_all(&self, records: &[WalRecord]) -> StorageResult<()> {
        let mut content = String::new();
        for record in records {
            let line = serde_json::to_string(record)?;
            content.push_str(&line);
            content.push('\n');
        }
        tokio::fs::write(&self.path, content).await?;
        Ok(())
    }
}

View File

@@ -0,0 +1,17 @@
[package]
name = "rustdb-txn"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "MongoDB-compatible transaction and session management with snapshot isolation for RustDb"
[dependencies]
bson = { workspace = true }
dashmap = { workspace = true }
tokio = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
rustdb-storage = { workspace = true }
async-trait = { workspace = true }

View File

@@ -0,0 +1,35 @@
use thiserror::Error;
/// Errors that can occur during transaction or session operations.
#[derive(Debug, Error)]
pub enum TransactionError {
    /// The referenced session or transaction does not exist.
    #[error("not found: {0}")]
    NotFound(String),
    /// A transaction is already running on this session.
    #[error("transaction already active for session: {0}")]
    AlreadyActive(String),
    /// A conflicting write was detected (server code 112).
    #[error("write conflict detected (code 112): {0}")]
    WriteConflict(String),
    /// The session has expired.
    #[error("session expired: {0}")]
    SessionExpired(String),
    /// The transaction is not in a state that permits the operation.
    #[error("invalid transaction state: {0}")]
    InvalidState(String),
}
impl TransactionError {
/// Returns the error code.
pub fn code(&self) -> i32 {
match self {
TransactionError::NotFound(_) => 251,
TransactionError::AlreadyActive(_) => 256,
TransactionError::WriteConflict(_) => 112,
TransactionError::SessionExpired(_) => 6100,
TransactionError::InvalidState(_) => 263,
}
}
}
/// Result alias for transaction and session operations.
pub type TransactionResult<T> = Result<T, TransactionError>;

View File

@@ -0,0 +1,9 @@
pub mod error;
mod session;
mod transaction;
pub use error::{TransactionError, TransactionResult};
pub use session::{Session, SessionEngine};
pub use transaction::{
TransactionEngine, TransactionState, TransactionStatus, WriteEntry, WriteOp,
};

View File

@@ -0,0 +1,205 @@
use std::time::{Duration, Instant};
use bson::Bson;
use dashmap::DashMap;
use tracing::{debug, warn};
use crate::error::{TransactionError, TransactionResult};
/// Represents a logical session.
#[derive(Debug, Clone)]
pub struct Session {
    /// Session identifier.
    pub id: String,
    /// When the session was created.
    pub created_at: Instant,
    /// Last time the session was used; drives expiry in `cleanup_expired`.
    pub last_activity_at: Instant,
    /// Id of the transaction bound to this session, if any.
    pub txn_id: Option<String>,
    /// Whether a transaction is currently active on this session.
    pub in_transaction: bool,
}
/// Engine that manages logical sessions with timeout and cleanup.
pub struct SessionEngine {
    /// Live sessions keyed by session id.
    sessions: DashMap<String, Session>,
    /// Inactivity window after which a session is eligible for cleanup.
    timeout: Duration,
    // Retained for a future background cleanup task; not read in this file.
    _cleanup_interval: Duration,
}
impl SessionEngine {
    /// Create a new session engine.
    ///
    /// * `timeout_ms` - Session timeout in milliseconds (default: 30 minutes = 1_800_000).
    /// * `cleanup_interval_ms` - How often to run the cleanup task in milliseconds (default: 60_000).
    pub fn new(timeout_ms: u64, cleanup_interval_ms: u64) -> Self {
        Self {
            sessions: DashMap::new(),
            timeout: Duration::from_millis(timeout_ms),
            _cleanup_interval: Duration::from_millis(cleanup_interval_ms),
        }
    }
    /// Get an existing session or create a new one. Returns the session id.
    pub fn get_or_create_session(&self, id: &str) -> String {
        let now = Instant::now();
        // A single entry-API operation avoids the check-then-insert race in
        // which two concurrent callers both miss the lookup and one caller's
        // freshly created session is overwritten by the other's insert.
        self.sessions
            .entry(id.to_string())
            .and_modify(|session| session.last_activity_at = now)
            .or_insert_with(|| {
                debug!(session_id = %id, "created new session");
                Session {
                    id: id.to_string(),
                    created_at: now,
                    last_activity_at: now,
                    txn_id: None,
                    in_transaction: false,
                }
            });
        id.to_string()
    }
    /// Update the last activity timestamp for a session.
    pub fn touch_session(&self, id: &str) {
        if let Some(mut session) = self.sessions.get_mut(id) {
            session.last_activity_at = Instant::now();
        }
    }
    /// End a session. If a transaction is active, it will be marked for abort.
    pub fn end_session(&self, id: &str) {
        if let Some((_, session)) = self.sessions.remove(id) {
            if session.in_transaction {
                // The transaction engine owns the actual abort; here we can
                // only flag that one is still outstanding.
                warn!(
                    session_id = %id,
                    txn_id = ?session.txn_id,
                    "ending session with active transaction, transaction should be aborted"
                );
            }
            debug!(session_id = %id, "session ended");
        }
    }
    /// Associate a transaction with a session.
    ///
    /// Fails with `NotFound` for unknown sessions and `AlreadyActive` when a
    /// transaction is already running on the session.
    pub fn start_transaction(&self, session_id: &str, txn_id: &str) -> TransactionResult<()> {
        let mut session = self
            .sessions
            .get_mut(session_id)
            .ok_or_else(|| TransactionError::NotFound(format!("session {}", session_id)))?;
        if session.in_transaction {
            return Err(TransactionError::AlreadyActive(session_id.to_string()));
        }
        session.txn_id = Some(txn_id.to_string());
        session.in_transaction = true;
        session.last_activity_at = Instant::now();
        Ok(())
    }
    /// Disassociate the transaction from a session (after commit or abort).
    pub fn end_transaction(&self, session_id: &str) {
        if let Some(mut session) = self.sessions.get_mut(session_id) {
            session.txn_id = None;
            session.in_transaction = false;
            session.last_activity_at = Instant::now();
        }
    }
    /// Check whether a session is currently in a transaction.
    pub fn is_in_transaction(&self, session_id: &str) -> bool {
        self.sessions
            .get(session_id)
            .map(|s| s.in_transaction)
            .unwrap_or(false)
    }
    /// Get the active transaction id for a session, if any.
    pub fn get_transaction_id(&self, session_id: &str) -> Option<String> {
        self.sessions
            .get(session_id)
            .and_then(|s| s.txn_id.clone())
    }
    /// Extract a session id from a BSON `lsid` value.
    ///
    /// Handles the following formats:
    /// - `{ "id": UUID }` (standard driver format)
    /// - `{ "id": "string" }` (string shorthand)
    /// - `{ "id": Binary(base64) }` (binary UUID)
    pub fn extract_session_id(lsid: &Bson) -> Option<String> {
        match lsid {
            Bson::Document(doc) => {
                if let Some(id_val) = doc.get("id") {
                    match id_val {
                        Bson::Binary(bin) => {
                            // UUID stored as Binary subtype 4.
                            // NOTE(review): the binary subtype is not checked;
                            // any 16-byte binary is treated as a UUID.
                            let bytes = &bin.bytes;
                            if bytes.len() == 16 {
                                let uuid = uuid::Uuid::from_slice(bytes).ok()?;
                                Some(uuid.to_string())
                            } else {
                                // Fall back to base64 representation.
                                Some(base64_encode(bytes))
                            }
                        }
                        Bson::String(s) => Some(s.clone()),
                        // Any other BSON type: use its display form as the id.
                        _ => Some(format!("{}", id_val)),
                    }
                } else {
                    None
                }
            }
            Bson::String(s) => Some(s.clone()),
            _ => None,
        }
    }
    /// Clean up expired sessions. Returns the number of sessions removed.
    pub fn cleanup_expired(&self) -> usize {
        let now = Instant::now();
        let timeout = self.timeout;
        let expired: Vec<String> = self
            .sessions
            .iter()
            .filter(|entry| now.duration_since(entry.last_activity_at) > timeout)
            .map(|entry| entry.id.clone())
            .collect();
        let mut removed = 0;
        for id in &expired {
            // Re-check expiry at removal time so a session that was touched
            // between the scan above and this point is not dropped.
            if self
                .sessions
                .remove_if(id, |_, s| now.duration_since(s.last_activity_at) > timeout)
                .is_some()
            {
                debug!(session_id = %id, "cleaning up expired session");
                removed += 1;
            }
        }
        removed
    }
}
impl Default for SessionEngine {
fn default() -> Self {
// 30 minutes timeout, 60 seconds cleanup interval.
Self::new(1_800_000, 60_000)
}
}
/// Simple base64 encoding for binary data (no external dependency needed).
///
/// Standard alphabet with `=` padding; processes input three bytes at a time.
fn base64_encode(data: &[u8]) -> String {
    const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Every 3 input bytes (rounded up) become 4 output characters.
    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes into a 24-bit word, zero-padded on the right.
        let mut word = (group[0] as u32) << 16;
        if let Some(&b) = group.get(1) {
            word |= (b as u32) << 8;
        }
        if let Some(&b) = group.get(2) {
            word |= b as u32;
        }
        encoded.push(TABLE[((word >> 18) & 0x3F) as usize] as char);
        encoded.push(TABLE[((word >> 12) & 0x3F) as usize] as char);
        // Third and fourth characters become '=' padding for short groups.
        encoded.push(if group.len() > 1 {
            TABLE[((word >> 6) & 0x3F) as usize] as char
        } else {
            '='
        });
        encoded.push(if group.len() > 2 {
            TABLE[(word & 0x3F) as usize] as char
        } else {
            '='
        });
    }
    encoded
}

View File

@@ -0,0 +1,279 @@
use std::collections::{HashMap, HashSet};
use bson::Document;
use dashmap::DashMap;
use tracing::{debug, warn};
use uuid::Uuid;
use rustdb_storage::StorageAdapter;
use crate::error::{TransactionError, TransactionResult};
/// The status of a transaction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionStatus {
    /// In flight: reads and writes may still be recorded.
    Active,
    /// Buffered writes have been applied to storage.
    Committed,
    /// Buffered writes were discarded.
    Aborted,
}
/// Describes a write operation within a transaction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WriteOp {
    /// A new document is added.
    Insert,
    /// An existing document is replaced by id.
    Update,
    /// A document is removed by id.
    Delete,
}
/// A single write entry recorded within a transaction.
#[derive(Debug, Clone)]
pub struct WriteEntry {
    /// Kind of write (insert/update/delete).
    pub op: WriteOp,
    /// New document contents for inserts and updates; `None` for deletes.
    pub doc: Option<Document>,
    /// Pre-image of the document, when one was captured.
    pub original_doc: Option<Document>,
}
/// Full state of an in-flight transaction.
#[derive(Debug)]
pub struct TransactionState {
    /// Unique transaction id (UUID v4).
    pub id: String,
    /// Session this transaction belongs to.
    pub session_id: String,
    /// Current lifecycle status.
    pub status: TransactionStatus,
    /// Tracks which documents were read: namespace -> set of doc ids.
    pub read_set: HashMap<String, HashSet<String>>,
    /// Tracks writes: namespace -> (doc_id -> WriteEntry).
    pub write_set: HashMap<String, HashMap<String, WriteEntry>>,
    /// Snapshot of collections at transaction start: namespace -> documents.
    pub snapshots: HashMap<String, Vec<Document>>,
}
/// Engine that manages transaction lifecycle and conflict detection.
pub struct TransactionEngine {
    /// In-flight transactions keyed by transaction id.
    transactions: DashMap<String, TransactionState>,
}
impl TransactionEngine {
/// Create a new transaction engine.
pub fn new() -> Self {
Self {
transactions: DashMap::new(),
}
}
/// Start a new transaction for the given session.
/// Returns a unique transaction id (UUID v4).
pub fn start_transaction(&self, session_id: &str) -> TransactionResult<String> {
let txn_id = Uuid::new_v4().to_string();
debug!(txn_id = %txn_id, session_id = %session_id, "starting transaction");
let state = TransactionState {
id: txn_id.clone(),
session_id: session_id.to_string(),
status: TransactionStatus::Active,
read_set: HashMap::new(),
write_set: HashMap::new(),
snapshots: HashMap::new(),
};
self.transactions.insert(txn_id.clone(), state);
Ok(txn_id)
}
/// Commit a transaction: check for conflicts, then apply buffered writes
/// to the underlying storage adapter.
pub async fn commit_transaction(
&self,
txn_id: &str,
storage: &dyn StorageAdapter,
) -> TransactionResult<()> {
// Remove the transaction so we own it exclusively.
let mut state = self
.transactions
.remove(txn_id)
.map(|(_, s)| s)
.ok_or_else(|| TransactionError::NotFound(txn_id.to_string()))?;
if state.status != TransactionStatus::Active {
return Err(TransactionError::InvalidState(format!(
"transaction {} is {:?}, cannot commit",
txn_id, state.status
)));
}
// Conflict detection: check if any documents in the read set have
// been modified since the snapshot was taken.
// (Simplified: we skip real snapshot timestamps for now.)
// Apply buffered writes to storage.
for (ns, writes) in &state.write_set {
let parts: Vec<&str> = ns.splitn(2, '.').collect();
if parts.len() != 2 {
warn!(namespace = %ns, "invalid namespace format, skipping");
continue;
}
let (db, coll) = (parts[0], parts[1]);
for (doc_id, entry) in writes {
match entry.op {
WriteOp::Insert => {
if let Some(ref doc) = entry.doc {
let _ = storage.insert_one(db, coll, doc.clone()).await;
}
}
WriteOp::Update => {
if let Some(ref doc) = entry.doc {
let _ = storage.update_by_id(db, coll, doc_id, doc.clone()).await;
}
}
WriteOp::Delete => {
let _ = storage.delete_by_id(db, coll, doc_id).await;
}
}
}
}
state.status = TransactionStatus::Committed;
debug!(txn_id = %txn_id, "transaction committed");
Ok(())
}
/// Abort a transaction, discarding all buffered writes.
pub fn abort_transaction(&self, txn_id: &str) -> TransactionResult<()> {
let mut state = self
.transactions
.get_mut(txn_id)
.ok_or_else(|| TransactionError::NotFound(txn_id.to_string()))?;
if state.status != TransactionStatus::Active {
return Err(TransactionError::InvalidState(format!(
"transaction {} is {:?}, cannot abort",
txn_id, state.status
)));
}
state.status = TransactionStatus::Aborted;
debug!(txn_id = %txn_id, "transaction aborted");
// Drop the mutable ref before removing.
drop(state);
self.transactions.remove(txn_id);
Ok(())
}
/// Check whether a transaction is currently active.
pub fn is_active(&self, txn_id: &str) -> bool {
self.transactions
.get(txn_id)
.map(|s| s.status == TransactionStatus::Active)
.unwrap_or(false)
}
/// Record a document read within a transaction (for conflict detection).
pub fn record_read(&self, txn_id: &str, ns: &str, doc_id: &str) {
if let Some(mut state) = self.transactions.get_mut(txn_id) {
state
.read_set
.entry(ns.to_string())
.or_default()
.insert(doc_id.to_string());
}
}
/// Record a document write within a transaction (buffered until commit).
pub fn record_write(
&self,
txn_id: &str,
ns: &str,
doc_id: &str,
op: WriteOp,
doc: Option<Document>,
original: Option<Document>,
) {
if let Some(mut state) = self.transactions.get_mut(txn_id) {
let entry = WriteEntry {
op,
doc,
original_doc: original,
};
state
.write_set
.entry(ns.to_string())
.or_default()
.insert(doc_id.to_string(), entry);
}
}
/// Get a snapshot of documents for a namespace within a transaction,
/// applying the write overlay (inserts, updates, deletes) on top.
pub fn get_snapshot(&self, txn_id: &str, ns: &str) -> Option<Vec<Document>> {
let state = self.transactions.get(txn_id)?;
// Start with the base snapshot.
let mut docs: Vec<Document> = state
.snapshots
.get(ns)
.cloned()
.unwrap_or_default();
// Apply write overlay.
if let Some(writes) = state.write_set.get(ns) {
// Collect ids to delete.
let delete_ids: HashSet<&String> = writes
.iter()
.filter(|(_, e)| e.op == WriteOp::Delete)
.map(|(id, _)| id)
.collect();
// Remove deleted docs.
docs.retain(|d| {
if let Some(id) = d.get_object_id("_id").ok().map(|oid| oid.to_hex()) {
!delete_ids.contains(&id)
} else {
true
}
});
// Apply updates.
for (doc_id, entry) in writes {
if entry.op == WriteOp::Update {
if let Some(ref new_doc) = entry.doc {
// Replace existing doc with updated version.
let hex_id = doc_id.clone();
if let Some(pos) = docs.iter().position(|d| {
d.get_object_id("_id")
.ok()
.map(|oid| oid.to_hex()) == Some(hex_id.clone())
}) {
docs[pos] = new_doc.clone();
}
}
}
}
// Apply inserts.
for (_doc_id, entry) in writes {
if entry.op == WriteOp::Insert {
if let Some(ref doc) = entry.doc {
docs.push(doc.clone());
}
}
}
}
Some(docs)
}
/// Store a base snapshot for a namespace within a transaction.
pub fn set_snapshot(&self, txn_id: &str, ns: &str, docs: Vec<Document>) {
if let Some(mut state) = self.transactions.get_mut(txn_id) {
state.snapshots.insert(ns.to_string(), docs);
}
}
}
impl Default for TransactionEngine {
fn default() -> Self {
Self::new()
}
}

View File

@@ -0,0 +1,18 @@
[package]
name = "rustdb-wire"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "MongoDB-compatible wire protocol parser and encoder for RustDb"
[dependencies]
bson = { workspace = true }
bytes = { workspace = true }
tokio-util = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
crc32fast = { workspace = true }
[dev-dependencies]
tokio = { workspace = true }

View File

@@ -0,0 +1,49 @@
use bytes::{Buf, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use crate::error::WireError;
use crate::parser::{parse_message, ParsedCommand};
/// Tokio codec for framing wire protocol messages on a TCP stream.
///
/// The wire protocol is naturally length-prefixed:
/// the first 4 bytes of each message contain the total message length.
///
/// The codec itself is stateless, so a unit struct suffices.
pub struct WireCodec;
impl Decoder for WireCodec {
type Item = ParsedCommand;
type Error = WireError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < 4 {
return Ok(None);
}
// Peek at message length
let msg_len = i32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
if src.len() < msg_len {
// Reserve space for the rest of the message
src.reserve(msg_len - src.len());
return Ok(None);
}
match parse_message(src)? {
Some((cmd, bytes_consumed)) => {
src.advance(bytes_consumed);
Ok(Some(cmd))
}
None => Ok(None),
}
}
}
/// Encoder for raw byte responses (already serialized by the command handlers).
impl Encoder<Vec<u8>> for WireCodec {
    type Error = WireError;
    /// Copy the pre-serialized response frame into the outgoing buffer as-is.
    fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
        dst.reserve(item.len());
        dst.extend_from_slice(item.as_slice());
        Ok(())
    }
}

View File

@@ -0,0 +1,142 @@
use bson::Document;
use crate::opcodes::*;
/// Encode an OP_MSG response frame.
///
/// Layout: 16-byte header, 4-byte flagBits (always 0 here), one section of
/// type 0 carrying the BSON-serialized `response` document.
pub fn encode_op_msg_response(
    response_to: i32,
    response: &Document,
    request_id: i32,
) -> Vec<u8> {
    let body = bson::to_vec(response).expect("failed to serialize BSON response");
    // Header (16) + flagBits (4) + section kind byte (1) + body BSON.
    let total_len = 16 + 4 + 1 + body.len();
    let mut frame = Vec::with_capacity(total_len);
    // Header: messageLength, requestID, responseTo, opCode.
    for field in [total_len as i32, request_id, response_to, OP_MSG] {
        frame.extend_from_slice(&field.to_le_bytes());
    }
    // Flag bits (no flags set).
    frame.extend_from_slice(&0u32.to_le_bytes());
    // Section type 0 (body), followed by the body BSON.
    frame.push(SECTION_BODY);
    frame.extend_from_slice(&body);
    frame
}
/// Encode an OP_REPLY response (legacy, for OP_QUERY responses).
///
/// Layout: 16-byte header, responseFlags, cursorID, startingFrom,
/// numberReturned, then the BSON documents back to back.
pub fn encode_op_reply_response(
    response_to: i32,
    documents: &[Document],
    request_id: i32,
    cursor_id: i64,
) -> Vec<u8> {
    let encoded_docs: Vec<Vec<u8>> = documents
        .iter()
        .map(|doc| bson::to_vec(doc).expect("failed to serialize BSON document"))
        .collect();
    let docs_len: usize = encoded_docs.iter().map(Vec::len).sum();
    // Header (16) + responseFlags (4) + cursorID (8) + startingFrom (4) + numberReturned (4) + docs.
    let total_len = 16 + 4 + 8 + 4 + 4 + docs_len;
    let mut frame = Vec::with_capacity(total_len);
    // Header: messageLength, requestID, responseTo, opCode.
    for field in [total_len as i32, request_id, response_to, OP_REPLY] {
        frame.extend_from_slice(&field.to_le_bytes());
    }
    frame.extend_from_slice(&0i32.to_le_bytes()); // responseFlags
    frame.extend_from_slice(&cursor_id.to_le_bytes()); // cursorID
    frame.extend_from_slice(&0i32.to_le_bytes()); // startingFrom
    frame.extend_from_slice(&(documents.len() as i32).to_le_bytes()); // numberReturned
    for doc in &encoded_docs {
        frame.extend_from_slice(doc);
    }
    frame
}
/// Encode an error response as OP_MSG.
pub fn encode_error_response(
response_to: i32,
error_code: i32,
error_message: &str,
request_id: i32,
) -> Vec<u8> {
let response = bson::doc! {
"ok": 0,
"errmsg": error_message,
"code": error_code,
"codeName": error_code_name(error_code),
};
encode_op_msg_response(response_to, &response, request_id)
}
/// Map error codes to their code names.
///
/// Unknown codes fall back to `"UnknownError"`.
pub fn error_code_name(code: i32) -> &'static str {
    // Kept sorted by code; a linear scan is fine for this handful of entries.
    const CODE_NAMES: &[(i32, &str)] = &[
        (0, "OK"),
        (1, "InternalError"),
        (2, "BadValue"),
        (13, "Unauthorized"),
        (26, "NamespaceNotFound"),
        (27, "IndexNotFound"),
        (48, "NamespaceExists"),
        (59, "CommandNotFound"),
        (66, "ImmutableField"),
        (73, "InvalidNamespace"),
        (85, "IndexOptionsConflict"),
        (112, "WriteConflict"),
        (121, "DocumentValidationFailure"),
        (211, "KeyNotFound"),
        (251, "NoSuchTransaction"),
        (11000, "DuplicateKey"),
        (11001, "DuplicateKeyValue"),
    ];
    CODE_NAMES
        .iter()
        .find(|&&(c, _)| c == code)
        .map(|&(_, name)| name)
        .unwrap_or("UnknownError")
}
#[cfg(test)]
mod tests {
    use super::*;
    /// The encoded frame must declare its own total length in the first four
    /// bytes and carry OP_MSG in the opCode slot of the header.
    #[test]
    fn test_encode_op_msg_roundtrip() {
        let doc = bson::doc! { "ok": 1 };
        let encoded = encode_op_msg_response(1, &doc, 2);
        // Verify header
        let msg_len = i32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
        assert_eq!(msg_len as usize, encoded.len());
        // opCode lives at header bytes 12..16.
        let op_code = i32::from_le_bytes([encoded[12], encoded[13], encoded[14], encoded[15]]);
        assert_eq!(op_code, OP_MSG);
    }
    /// Same header checks for the legacy OP_REPLY encoding.
    #[test]
    fn test_encode_op_reply() {
        let docs = vec![bson::doc! { "ok": 1 }];
        let encoded = encode_op_reply_response(1, &docs, 2, 0);
        let msg_len = i32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
        assert_eq!(msg_len as usize, encoded.len());
        let op_code = i32::from_le_bytes([encoded[12], encoded[13], encoded[14], encoded[15]]);
        assert_eq!(op_code, OP_REPLY);
    }
}

View File

@@ -0,0 +1,27 @@
/// Errors from wire protocol parsing/encoding.
#[derive(Debug, thiserror::Error)]
pub enum WireError {
    /// The buffer does not yet contain the full message or field.
    #[error("Incomplete message: need {needed} bytes, have {have}")]
    Incomplete { needed: usize, have: usize },
    /// The header carried an opCode this parser does not handle.
    #[error("Unsupported opCode: {0}")]
    UnsupportedOpCode(i32),
    /// An OP_MSG arrived without a type-0 body section.
    #[error("Missing command body section in OP_MSG")]
    MissingBody,
    /// An OP_MSG section byte was neither body (0) nor document sequence (1).
    #[error("Unknown section type: {0}")]
    UnknownSectionType(u8),
    /// A BSON payload failed to deserialize.
    #[error("BSON deserialization error: {0}")]
    BsonError(#[from] bson::de::Error),
    /// A BSON document failed to serialize.
    #[error("BSON serialization error: {0}")]
    BsonSerError(#[from] bson::ser::Error),
    /// An underlying I/O operation failed.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// A message checksum did not match its payload.
    #[error("Checksum mismatch: expected {expected}, got {actual}")]
    ChecksumMismatch { expected: u32, actual: u32 },
}

View File

@@ -0,0 +1,11 @@
mod codec;
mod error;
mod opcodes;
mod parser;
mod encoder;
pub use codec::WireCodec;
pub use error::WireError;
pub use opcodes::*;
pub use parser::*;
pub use encoder::*;

View File

@@ -0,0 +1,19 @@
/// Wire protocol op codes
// Legacy reply opcode; used here for responses to OP_QUERY handshakes.
pub const OP_REPLY: i32 = 1;
pub const OP_UPDATE: i32 = 2001;
pub const OP_INSERT: i32 = 2002;
// Legacy query opcode; drivers use it for the initial handshake.
pub const OP_QUERY: i32 = 2004;
pub const OP_GET_MORE: i32 = 2005;
pub const OP_DELETE: i32 = 2006;
pub const OP_KILL_CURSORS: i32 = 2007;
// Compressed wrapper message (not handled by this parser).
pub const OP_COMPRESSED: i32 = 2012;
// Modern extensible message format; the primary opcode for commands.
pub const OP_MSG: i32 = 2013;
/// OP_MSG section types
pub const SECTION_BODY: u8 = 0;
pub const SECTION_DOCUMENT_SEQUENCE: u8 = 1;
/// OP_MSG flag bits
// When set, the final 4 bytes of the message carry a checksum.
pub const MSG_FLAG_CHECKSUM_PRESENT: u32 = 1 << 0;
pub const MSG_FLAG_MORE_TO_COME: u32 = 1 << 1;
pub const MSG_FLAG_EXHAUST_ALLOWED: u32 = 1 << 16;

View File

@@ -0,0 +1,236 @@
use bson::Document;
use std::collections::HashMap;
use crate::error::WireError;
use crate::opcodes::*;
/// Parsed wire protocol message header (16 bytes).
#[derive(Debug, Clone)]
pub struct MessageHeader {
    /// Total message length in bytes, including this 16-byte header.
    pub message_length: i32,
    /// Client-chosen identifier for this message.
    pub request_id: i32,
    /// requestID of the message this one responds to.
    pub response_to: i32,
    /// Operation code (e.g. OP_MSG, OP_QUERY).
    pub op_code: i32,
}
/// A parsed OP_MSG section.
// NOTE(review): this type is not constructed by the parser in this file;
// confirm it is used elsewhere before relying on it.
#[derive(Debug, Clone)]
pub enum OpMsgSection {
    /// Section type 0: single BSON document body.
    Body(Document),
    /// Section type 1: named document sequence for bulk operations.
    DocumentSequence {
        /// Name of the sequence (e.g. "documents").
        identifier: String,
        /// The documents contained in the sequence.
        documents: Vec<Document>,
    },
}
/// A fully parsed command extracted from any message type.
#[derive(Debug, Clone)]
pub struct ParsedCommand {
    /// First key of the command document (command-name convention).
    pub command_name: String,
    /// The full command document.
    pub command: Document,
    /// Target database, taken from `$db` or the collection-name prefix;
    /// defaults to "admin" when absent.
    pub database: String,
    /// requestID from the message header.
    pub request_id: i32,
    /// opCode of the originating message (OP_MSG or OP_QUERY).
    pub op_code: i32,
    /// Document sequences from OP_MSG section type 1 (e.g., "documents" for insert).
    pub document_sequences: Option<HashMap<String, Vec<Document>>>,
}
/// Parse a message header from a byte slice (must be >= 16 bytes).
///
/// Panics if the slice is shorter than 16 bytes; callers check the length first.
pub fn parse_header(buf: &[u8]) -> MessageHeader {
    // Read one little-endian i32 field at the given header offset.
    let word = |at: usize| i32::from_le_bytes([buf[at], buf[at + 1], buf[at + 2], buf[at + 3]]);
    MessageHeader {
        message_length: word(0),
        request_id: word(4),
        response_to: word(8),
        op_code: word(12),
    }
}
/// Parse a complete message from a buffer.
/// Returns the parsed command and bytes consumed, or None if not enough data.
pub fn parse_message(buf: &[u8]) -> Result<Option<(ParsedCommand, usize)>, WireError> {
if buf.len() < 16 {
return Ok(None);
}
let header = parse_header(buf);
let msg_len = header.message_length as usize;
if buf.len() < msg_len {
return Ok(None);
}
let message_buf = &buf[..msg_len];
match header.op_code {
OP_MSG => parse_op_msg(message_buf, &header).map(|cmd| Some((cmd, msg_len))),
OP_QUERY => parse_op_query(message_buf, &header).map(|cmd| Some((cmd, msg_len))),
other => Err(WireError::UnsupportedOpCode(other)),
}
}
/// Parse an OP_MSG message.
///
/// All multi-byte and document reads are bounds-checked: this parses
/// untrusted network bytes, so malformed sizes must surface as a
/// `WireError` instead of an out-of-bounds slice panic.
fn parse_op_msg(buf: &[u8], header: &MessageHeader) -> Result<ParsedCommand, WireError> {
    // Bounds-checked little-endian i32 read.
    let read_i32 = |at: usize| -> Result<i32, WireError> {
        let bytes = buf
            .get(at..at + 4)
            .ok_or(WireError::Incomplete { needed: at + 4, have: buf.len() })?;
        Ok(i32::from_le_bytes(bytes.try_into().expect("slice is 4 bytes")))
    };
    // Bounds-checked BSON document read; returns the document and its size.
    let read_doc = |at: usize, end: usize| -> Result<(Document, usize), WireError> {
        let doc_size = read_i32(at)? as usize;
        // A BSON document is at least 5 bytes (length prefix + trailing NUL)
        // and must fit inside the current section/message.
        if doc_size < 5 || at + doc_size > end {
            return Err(WireError::Incomplete {
                needed: at.saturating_add(doc_size),
                have: end,
            });
        }
        let doc = bson::from_slice(&buf[at..at + doc_size])?;
        Ok((doc, doc_size))
    };
    let mut offset = 16; // skip header
    let flag_bits = read_i32(offset)? as u32;
    offset += 4;
    let mut body: Option<Document> = None;
    let mut document_sequences: HashMap<String, Vec<Document>> = HashMap::new();
    // Sections run until the end of the message, minus the trailing 4-byte
    // checksum when that flag is set.
    // NOTE(review): the checksum value itself is not verified here.
    let message_end = if flag_bits & MSG_FLAG_CHECKSUM_PRESENT != 0 {
        (header.message_length as usize).saturating_sub(4)
    } else {
        header.message_length as usize
    };
    // Never read past the actual buffer, whatever the declared length claims.
    let message_end = message_end.min(buf.len());
    while offset < message_end {
        let section_type = buf[offset];
        offset += 1;
        match section_type {
            SECTION_BODY => {
                let (doc, doc_size) = read_doc(offset, message_end)?;
                body = Some(doc);
                offset += doc_size;
            }
            SECTION_DOCUMENT_SEQUENCE => {
                let section_size = read_i32(offset)? as usize;
                // Clamp the declared section end to the message bounds.
                let section_end = offset.saturating_add(section_size).min(message_end);
                offset += 4;
                // Read identifier (C string, null-terminated)
                let id_start = offset;
                while offset < section_end && buf[offset] != 0 {
                    offset += 1;
                }
                let identifier = std::str::from_utf8(&buf[id_start..offset])
                    .unwrap_or("")
                    .to_string();
                offset += 1; // skip null terminator
                // Read documents
                let mut documents = Vec::new();
                while offset < section_end {
                    let (doc, doc_size) = read_doc(offset, section_end)?;
                    documents.push(doc);
                    offset += doc_size;
                }
                document_sequences.insert(identifier, documents);
            }
            other => return Err(WireError::UnknownSectionType(other)),
        }
    }
    // The type-0 body section is mandatory; it carries the command document.
    let command = body.ok_or(WireError::MissingBody)?;
    // The command name is the first key of the body document.
    let command_name = command
        .keys()
        .next()
        .map(|s| s.to_string())
        .unwrap_or_default();
    let database = command
        .get_str("$db")
        .unwrap_or("admin")
        .to_string();
    Ok(ParsedCommand {
        command_name,
        command,
        database,
        request_id: header.request_id,
        op_code: header.op_code,
        document_sequences: if document_sequences.is_empty() {
            None
        } else {
            Some(document_sequences)
        },
    })
}
/// Parse an OP_QUERY message (legacy, used for initial driver handshake).
fn parse_op_query(buf: &[u8], header: &MessageHeader) -> Result<ParsedCommand, WireError> {
let mut offset = 16; // skip header
let _flags = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]);
offset += 4;
// Read full collection name (C string)
let name_start = offset;
while offset < buf.len() && buf[offset] != 0 {
offset += 1;
}
let full_collection_name = std::str::from_utf8(&buf[name_start..offset])
.unwrap_or("")
.to_string();
offset += 1; // skip null terminator
let _number_to_skip = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]);
offset += 4;
let _number_to_return = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]);
offset += 4;
// Read query document
let doc_size = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]) as usize;
let query: Document = bson::from_slice(&buf[offset..offset + doc_size])?;
// Extract database from collection name (format: "dbname.$cmd")
let parts: Vec<&str> = full_collection_name.splitn(2, '.').collect();
let database = parts.first().unwrap_or(&"admin").to_string();
let mut command_name = query
.keys()
.next()
.map(|s| s.to_string())
.unwrap_or_else(|| "find".to_string());
// Map legacy isMaster/ismaster to hello
if parts.get(1) == Some(&"$cmd") {
if command_name == "isMaster" || command_name == "ismaster" {
command_name = "hello".to_string();
}
} else {
command_name = "find".to_string();
}
Ok(ParsedCommand {
command_name,
command: query,
database,
request_id: header.request_id,
op_code: header.op_code,
document_sequences: None,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    /// A hand-built 16-byte header must round-trip through `parse_header`.
    #[test]
    fn test_parse_header() {
        let mut buf = [0u8; 16];
        buf[0..4].copy_from_slice(&100i32.to_le_bytes()); // messageLength
        buf[4..8].copy_from_slice(&42i32.to_le_bytes()); // requestID
        buf[8..12].copy_from_slice(&0i32.to_le_bytes()); // responseTo
        buf[12..16].copy_from_slice(&OP_MSG.to_le_bytes()); // opCode
        let header = parse_header(&buf);
        assert_eq!(header.message_length, 100);
        assert_eq!(header.request_id, 42);
        assert_eq!(header.response_to, 0);
        assert_eq!(header.op_code, OP_MSG);
    }
}

View File

@@ -0,0 +1,38 @@
[package]
name = "rustdb"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "MongoDB-compatible embedded database server with wire protocol support"
[[bin]]
name = "rustdb"
path = "src/main.rs"
[lib]
name = "rustdb"
path = "src/lib.rs"
[dependencies]
rustdb-config = { workspace = true }
rustdb-wire = { workspace = true }
rustdb-query = { workspace = true }
rustdb-storage = { workspace = true }
rustdb-index = { workspace = true }
rustdb-txn = { workspace = true }
rustdb-commands = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
clap = { workspace = true }
anyhow = { workspace = true }
arc-swap = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
bson = { workspace = true }
bytes = { workspace = true }
dashmap = { workspace = true }
mimalloc = { workspace = true }
futures-util = { version = "0.3", features = ["sink"] }

View File

@@ -0,0 +1,213 @@
pub mod management;
use std::sync::Arc;
use anyhow::Result;
use dashmap::DashMap;
use tokio::net::TcpListener;
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio_util::codec::Framed;
use tokio_util::sync::CancellationToken;
use rustdb_config::{RustDbOptions, StorageType};
use rustdb_wire::{WireCodec, OP_QUERY};
use rustdb_wire::{encode_op_msg_response, encode_op_reply_response};
use rustdb_storage::{StorageAdapter, MemoryStorageAdapter, FileStorageAdapter};
// IndexEngine is used indirectly via CommandContext
use rustdb_txn::{TransactionEngine, SessionEngine};
use rustdb_commands::{CommandRouter, CommandContext};
/// The main RustDb server.
pub struct RustDb {
    /// Server configuration (addresses, storage type, socket path).
    options: RustDbOptions,
    /// Shared state handed to every command handler.
    ctx: Arc<CommandContext>,
    /// Dispatches parsed commands to their handlers.
    router: Arc<CommandRouter>,
    /// Signals the accept loop to shut down.
    cancel_token: CancellationToken,
    /// Handle of the spawned accept-loop task, once started.
    listener_handle: Option<tokio::task::JoinHandle<()>>,
}
impl RustDb {
    /// Create a new RustDb server with the given options.
    ///
    /// Builds the storage adapter (in-memory or file-backed), initializes it,
    /// and wires up the shared command context and router. No sockets are
    /// opened until `start` is called.
    pub async fn new(options: RustDbOptions) -> Result<Self> {
        // Create storage adapter
        let storage: Arc<dyn StorageAdapter> = match options.storage {
            StorageType::Memory => {
                let adapter = MemoryStorageAdapter::new();
                Arc::new(adapter)
            }
            StorageType::File => {
                // Default the on-disk location when none was configured.
                let path = options
                    .storage_path
                    .clone()
                    .unwrap_or_else(|| "./data".to_string());
                let adapter = FileStorageAdapter::new(&path);
                Arc::new(adapter)
            }
        };
        // Initialize storage
        storage.initialize().await?;
        let ctx = Arc::new(CommandContext {
            storage,
            indexes: Arc::new(DashMap::new()),
            transactions: Arc::new(TransactionEngine::new()),
            // 30-minute session timeout, cleanup interval of 60 seconds.
            sessions: Arc::new(SessionEngine::new(30 * 60 * 1000, 60 * 1000)),
            cursors: Arc::new(DashMap::new()),
            start_time: std::time::Instant::now(),
        });
        let router = Arc::new(CommandRouter::new(ctx.clone()));
        Ok(Self {
            options,
            ctx,
            router,
            cancel_token: CancellationToken::new(),
            listener_handle: None,
        })
    }
    /// Start listening for connections.
    ///
    /// Listens on a Unix socket when `socket_path` is set (Unix platforms
    /// only), otherwise on TCP at `host:port`. The accept loop runs on a
    /// spawned task until the cancellation token fires; each accepted
    /// connection is served on its own task.
    pub async fn start(&mut self) -> Result<()> {
        let cancel = self.cancel_token.clone();
        let router = self.router.clone();
        if let Some(ref socket_path) = self.options.socket_path {
            #[cfg(unix)]
            {
                // Remove stale socket file
                let _ = tokio::fs::remove_file(socket_path).await;
                let listener = UnixListener::bind(socket_path)?;
                let socket_path_clone = socket_path.clone();
                tracing::info!("RustDb listening on unix:{}", socket_path_clone);
                let handle = tokio::spawn(async move {
                    loop {
                        tokio::select! {
                            // Shutdown requested: leave the accept loop.
                            _ = cancel.cancelled() => break,
                            result = listener.accept() => {
                                match result {
                                    Ok((stream, _addr)) => {
                                        let router = router.clone();
                                        tokio::spawn(async move {
                                            handle_connection(stream, router).await;
                                        });
                                    }
                                    Err(e) => {
                                        tracing::error!("Accept error: {}", e);
                                    }
                                }
                            }
                        }
                    }
                });
                self.listener_handle = Some(handle);
            }
            #[cfg(not(unix))]
            {
                anyhow::bail!("Unix sockets are not supported on this platform");
            }
        } else {
            let addr = format!("{}:{}", self.options.host, self.options.port);
            let listener = TcpListener::bind(&addr).await?;
            tracing::info!("RustDb listening on {}", addr);
            let handle = tokio::spawn(async move {
                loop {
                    tokio::select! {
                        // Shutdown requested: leave the accept loop.
                        _ = cancel.cancelled() => break,
                        result = listener.accept() => {
                            match result {
                                Ok((stream, _addr)) => {
                                    // Disable Nagle's algorithm; a failure
                                    // here is non-fatal and ignored.
                                    let _ = stream.set_nodelay(true);
                                    let router = router.clone();
                                    tokio::spawn(async move {
                                        handle_connection(stream, router).await;
                                    });
                                }
                                Err(e) => {
                                    tracing::error!("Accept error: {}", e);
                                }
                            }
                        }
                    }
                }
            });
            self.listener_handle = Some(handle);
        }
        Ok(())
    }
    /// Stop the server.
    ///
    /// Cancels and aborts the accept-loop task, closes storage, and removes
    /// any Unix socket file.
    pub async fn stop(&mut self) -> Result<()> {
        self.cancel_token.cancel();
        if let Some(handle) = self.listener_handle.take() {
            handle.abort();
            let _ = handle.await;
        }
        // Close storage (persists if configured)
        self.ctx.storage.close().await?;
        // Clean up Unix socket file
        if let Some(ref socket_path) = self.options.socket_path {
            let _ = tokio::fs::remove_file(socket_path).await;
        }
        Ok(())
    }
    /// Get the connection URI.
    pub fn connection_uri(&self) -> String {
        self.options.connection_uri()
    }
}
/// Handle a single client connection using the wire protocol codec.
///
/// Decodes frames off the stream, routes each parsed command through the
/// router, and writes the encoded reply back, until the peer disconnects or
/// a protocol/IO error ends the session.
async fn handle_connection<S>(stream: S, router: Arc<CommandRouter>)
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    use futures_util::{SinkExt, StreamExt};
    let mut conn = Framed::new(stream, WireCodec);
    while let Some(frame) = conn.next().await {
        // Pull the command out first; a decode failure ends the connection.
        let cmd = match frame {
            Ok(cmd) => cmd,
            Err(e) => {
                tracing::debug!("Wire protocol error: {}", e);
                break;
            }
        };
        let client_request_id = cmd.request_id;
        let is_legacy_query = cmd.op_code == OP_QUERY;
        let reply_doc = router.route(&cmd).await;
        let reply_id = next_request_id();
        // Legacy OP_QUERY clients expect an OP_REPLY frame; everything else
        // is answered with OP_MSG.
        let payload = if is_legacy_query {
            encode_op_reply_response(client_request_id, &[reply_doc], reply_id, 0)
        } else {
            encode_op_msg_response(client_request_id, &reply_doc, reply_id)
        };
        if let Err(e) = conn.send(payload).await {
            tracing::debug!("Failed to send response: {}", e);
            break;
        }
    }
}
/// Produce the next process-wide response request id.
///
/// Starts at 1 and increments atomically; wraps on i32 overflow.
fn next_request_id() -> i32 {
    static NEXT_ID: std::sync::atomic::AtomicI32 = std::sync::atomic::AtomicI32::new(1);
    NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}

View File

@@ -0,0 +1,85 @@
// Install mimalloc as the process-wide allocator: every heap allocation in
// this binary goes through it instead of the system allocator.
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use clap::Parser;
use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustdb::RustDb;
use rustdb::management;
use rustdb_config::RustDbOptions;
/// RustDb - MongoDB-compatible embedded database server
// NOTE: clap renders the `///` doc comments on this struct and its fields as
// the `--help` text, so they are part of the user-facing CLI — edit with care.
#[derive(Parser, Debug)]
#[command(name = "rustdb", version, about)]
struct Cli {
    /// Path to JSON configuration file
    #[arg(short, long, default_value = "config.json")]
    config: String,
    /// Log level (trace, debug, info, warn, error)
    #[arg(short, long, default_value = "info")]
    log_level: String,
    /// Validate configuration without starting
    #[arg(long)]
    validate: bool,
    /// Run in management mode (JSON-over-stdin IPC for TypeScript wrapper)
    #[arg(long)]
    management: bool,
}
/// Process entry point: parse CLI flags, configure logging, then run in one
/// of three modes — management IPC, validate-only, or a normal server that
/// runs until Ctrl+C.
#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();
    // Initialize tracing - write to stderr so stdout is reserved for management IPC
    tracing_subscriber::fmt()
        .with_writer(std::io::stderr)
        .with_env_filter(
            // RUST_LOG in the environment (when set) takes precedence over
            // the --log-level flag.
            EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
        )
        .init();
    // Management mode: JSON IPC over stdin/stdout
    if cli.management {
        tracing::info!("RustDb starting in management mode...");
        // The management loop owns server lifecycle from here on.
        return management::management_loop().await;
    }
    tracing::info!("RustDb starting...");
    // Load configuration
    let options = RustDbOptions::from_file(&cli.config)
        .map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
    // Validate-only mode
    if cli.validate {
        match options.validate() {
            Ok(()) => {
                tracing::info!("Configuration is valid");
                return Ok(());
            }
            Err(e) => {
                tracing::error!("Validation error: {}", e);
                anyhow::bail!("Configuration validation failed: {}", e);
            }
        }
    }
    // Create and start server
    let mut db = RustDb::new(options).await?;
    db.start().await?;
    // Wait for shutdown signal
    tracing::info!("RustDb is running. Press Ctrl+C to stop.");
    tokio::signal::ctrl_c().await?;
    tracing::info!("Shutdown signal received");
    // Graceful shutdown: stops the listener and closes storage.
    db.stop().await?;
    tracing::info!("RustDb shutdown complete");
    Ok(())
}

View File

@@ -0,0 +1,240 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error};
use crate::RustDb;
use rustdb_config::RustDbOptions;
/// A management request from the TypeScript wrapper.
///
/// Arrives as one JSON object per line on stdin.
#[derive(Debug, Deserialize)]
pub struct ManagementRequest {
    /// Correlation id, echoed back in the matching response.
    pub id: String,
    /// Method name: "start", "stop", "getStatus", or "getMetrics".
    pub method: String,
    /// Method-specific parameters; JSON `null` when omitted.
    #[serde(default)]
    pub params: serde_json::Value,
}
/// A management response back to the TypeScript wrapper.
///
/// Exactly one of `result`/`error` is populated by the `ok`/`err`
/// constructors; the absent one is omitted from the serialized JSON.
#[derive(Debug, Serialize)]
pub struct ManagementResponse {
    /// Correlation id copied from the request.
    pub id: String,
    /// Whether the request succeeded.
    pub success: bool,
    /// Success payload; skipped in JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Human-readable error message; skipped in JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
/// An unsolicited event from the server to the TypeScript wrapper.
///
/// Emitted for lifecycle changes: "ready", "started", "stopped", "error".
#[derive(Debug, Serialize)]
pub struct ManagementEvent {
    /// Event name.
    pub event: String,
    /// Event payload (often an empty JSON object).
    pub data: serde_json::Value,
}
impl ManagementResponse {
fn ok(id: String, result: serde_json::Value) -> Self {
Self {
id,
success: true,
result: Some(result),
error: None,
}
}
fn err(id: String, message: String) -> Self {
Self {
id,
success: false,
result: None,
error: Some(message),
}
}
}
/// Write one line to stdout and flush it immediately.
///
/// The IPC peer reads newline-delimited JSON; write/flush errors are
/// deliberately ignored (best effort — there is no one to report them to).
fn send_line(line: &str) {
    use std::io::Write;
    let mut out = std::io::stdout().lock();
    let _ = out.write_all(line.as_bytes());
    let _ = out.write_all(b"\n");
    let _ = out.flush();
}
/// Serialize a management response as JSON and emit it as one stdout line.
fn send_response(response: &ManagementResponse) {
    match serde_json::to_string(response) {
        Err(e) => error!("Failed to serialize management response: {}", e),
        Ok(json) => send_line(&json),
    }
}
/// Emit an unsolicited lifecycle event as one JSON line on stdout.
fn send_event(event: &str, data: serde_json::Value) {
    let payload = ManagementEvent {
        event: event.to_string(),
        data,
    };
    match serde_json::to_string(&payload) {
        Ok(json) => send_line(&json),
        Err(e) => error!("Failed to serialize management event: {}", e),
    }
}
/// Run the management loop, reading JSON commands from stdin and writing responses to stdout.
///
/// Protocol: newline-delimited JSON. A `ready` event is emitted first so the
/// wrapper knows the process accepts requests; thereafter each request line
/// produces exactly one response line. EOF on stdin is the shutdown signal
/// (the parent process exited or closed the pipe).
pub async fn management_loop() -> Result<()> {
    let stdin = BufReader::new(tokio::io::stdin());
    let mut lines = stdin.lines();
    // The single managed server instance; `None` until a `start` request.
    let mut db: Option<RustDb> = None;
    send_event("ready", serde_json::json!({}));
    loop {
        let line = match lines.next_line().await {
            Ok(Some(line)) => line,
            Ok(None) => {
                // stdin closed - parent process exited
                info!("Management stdin closed, shutting down");
                // Best-effort shutdown so storage is closed before exit.
                if let Some(ref mut d) = db {
                    let _ = d.stop().await;
                }
                break;
            }
            Err(e) => {
                // NOTE(review): unlike the EOF branch, this path exits without
                // stopping `db` — confirm that is intentional.
                error!("Error reading management stdin: {}", e);
                break;
            }
        };
        let line = line.trim().to_string();
        if line.is_empty() {
            continue;
        }
        let request: ManagementRequest = match serde_json::from_str(&line) {
            Ok(r) => r,
            Err(e) => {
                error!("Failed to parse management request: {}", e);
                // The real request id is unknown when parsing failed, so
                // reply with a sentinel id the wrapper can at least log.
                send_response(&ManagementResponse::err(
                    "unknown".to_string(),
                    format!("Failed to parse request: {}", e),
                ));
                continue;
            }
        };
        let response = handle_request(&request, &mut db).await;
        send_response(&response);
    }
    Ok(())
}
/// Dispatch one management request to its handler and return the response.
async fn handle_request(
    request: &ManagementRequest,
    db: &mut Option<RustDb>,
) -> ManagementResponse {
    let id = request.id.clone();
    match request.method.as_str() {
        "start" => handle_start(&id, &request.params, db).await,
        "stop" => handle_stop(&id, db).await,
        "getStatus" => handle_get_status(&id, db),
        "getMetrics" => handle_get_metrics(&id, db),
        other => ManagementResponse::err(id, format!("Unknown method: {}", other)),
    }
}
async fn handle_start(
id: &str,
params: &serde_json::Value,
db: &mut Option<RustDb>,
) -> ManagementResponse {
if db.is_some() {
return ManagementResponse::err(id.to_string(), "Server is already running".to_string());
}
let config = match params.get("config") {
Some(config) => config,
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
};
let options: RustDbOptions = match serde_json::from_value(config.clone()) {
Ok(o) => o,
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
};
let connection_uri = options.connection_uri();
match RustDb::new(options).await {
Ok(mut d) => {
match d.start().await {
Ok(()) => {
send_event("started", serde_json::json!({}));
*db = Some(d);
ManagementResponse::ok(
id.to_string(),
serde_json::json!({ "connectionUri": connection_uri }),
)
}
Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
}
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create server: {}", e)),
}
}
/// Handle a `stop` request: shut down the managed server, if any.
///
/// Stopping when nothing is running is treated as success (idempotent).
async fn handle_stop(
    id: &str,
    db: &mut Option<RustDb>,
) -> ManagementResponse {
    // BUGFIX: take the instance out of the slot up front. Once `stop()` has
    // been attempted, the cancel token has fired and the listener task is
    // aborted, so the server is unusable even if `stop()` errored; leaving
    // it registered would make every later `start` fail with
    // "Server is already running".
    match db.take() {
        Some(mut d) => {
            match d.stop().await {
                Ok(()) => {
                    send_event("stopped", serde_json::json!({}));
                    ManagementResponse::ok(id.to_string(), serde_json::json!({}))
                }
                Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
            }
        }
        None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
    }
}
/// Handle a `getStatus` request: report whether a server is registered.
fn handle_get_status(
    id: &str,
    db: &Option<RustDb>,
) -> ManagementResponse {
    // Both states are a success response; only the `running` flag differs.
    let running = db.is_some();
    ManagementResponse::ok(
        id.to_string(),
        serde_json::json!({ "running": running }),
    )
}
/// Handle a `getMetrics` request; errors when no server is running.
fn handle_get_metrics(
    id: &str,
    db: &Option<RustDb>,
) -> ManagementResponse {
    if db.is_none() {
        return ManagementResponse::err(id.to_string(), "Server is not running".to_string());
    }
    // Placeholder zeros until real counters are plumbed through the context.
    ManagementResponse::ok(
        id.to_string(),
        serde_json::json!({
            "connections": 0,
            "databases": 0,
        }),
    )
}