use std::sync::atomic::{AtomicI64, Ordering}; use bson::{doc, Bson, Document}; use tracing::debug; use rustdb_query::{QueryMatcher, sort_documents, apply_projection, distinct_values}; use crate::context::{CommandContext, CursorState}; use crate::error::{CommandError, CommandResult}; /// Atomic counter for generating unique cursor IDs. static CURSOR_ID_COUNTER: AtomicI64 = AtomicI64::new(1); /// Generate a new unique, positive cursor ID. fn next_cursor_id() -> i64 { CURSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed) } // --------------------------------------------------------------------------- // Helpers to defensively extract values from BSON command documents // --------------------------------------------------------------------------- fn get_str<'a>(doc: &'a Document, key: &str) -> Option<&'a str> { match doc.get(key)? { Bson::String(s) => Some(s.as_str()), _ => None, } } fn get_i32(doc: &Document, key: &str) -> Option { match doc.get(key)? { Bson::Int32(v) => Some(*v), Bson::Int64(v) => Some(*v as i32), Bson::Double(v) => Some(*v as i32), _ => None, } } fn get_i64(doc: &Document, key: &str) -> Option { match doc.get(key)? { Bson::Int64(v) => Some(*v), Bson::Int32(v) => Some(*v as i64), Bson::Double(v) => Some(*v as i64), _ => None, } } fn get_bool(doc: &Document, key: &str) -> Option { match doc.get(key)? { Bson::Boolean(v) => Some(*v), _ => None, } } fn get_document<'a>(doc: &'a Document, key: &str) -> Option<&'a Document> { match doc.get(key)? { Bson::Document(d) => Some(d), _ => None, } } // --------------------------------------------------------------------------- // find // --------------------------------------------------------------------------- /// Handle the `find` command. pub async fn handle( cmd: &Document, db: &str, ctx: &CommandContext, ) -> CommandResult { let coll = get_str(cmd, "find").unwrap_or("unknown"); let ns = format!("{}.{}", db, coll); // Extract optional parameters. let filter = get_document(cmd, "filter").cloned().unwrap_or_default(); let sort_spec = get_document(cmd, "sort").cloned(); let projection = get_document(cmd, "projection").cloned(); let skip = get_i64(cmd, "skip").unwrap_or(0).max(0) as usize; let limit = get_i64(cmd, "limit").unwrap_or(0).max(0) as usize; let batch_size = get_i32(cmd, "batchSize").unwrap_or(101).max(0) as usize; let single_batch = get_bool(cmd, "singleBatch").unwrap_or(false); // If the collection does not exist, return an empty cursor. let exists = ctx.storage.collection_exists(db, coll).await?; if !exists { return Ok(doc! { "cursor": { "firstBatch": [], "id": 0_i64, "ns": &ns, }, "ok": 1.0, }); } // Try index-accelerated lookup. let index_key = format!("{}.{}", db, coll); let docs = if let Some(idx_ref) = ctx.indexes.get(&index_key) { if let Some(candidate_ids) = idx_ref.find_candidate_ids(&filter) { debug!( ns = %ns, candidates = candidate_ids.len(), "using index acceleration" ); ctx.storage.find_by_ids(db, coll, candidate_ids).await? } else { ctx.storage.find_all(db, coll).await? } } else { ctx.storage.find_all(db, coll).await? }; // Apply filter. let mut docs = QueryMatcher::filter(&docs, &filter); // Apply sort. if let Some(ref sort) = sort_spec { sort_documents(&mut docs, sort); } // Apply skip. if skip > 0 { if skip >= docs.len() { docs = Vec::new(); } else { docs = docs.split_off(skip); } } // Apply limit. if limit > 0 && docs.len() > limit { docs.truncate(limit); } // Apply projection. if let Some(ref proj) = projection { docs = docs.iter().map(|d| apply_projection(d, proj)).collect(); } // Determine first batch. if docs.len() <= batch_size || single_batch { // Everything fits in a single batch. let batch: Vec = docs.into_iter().map(Bson::Document).collect(); Ok(doc! { "cursor": { "firstBatch": batch, "id": 0_i64, "ns": &ns, }, "ok": 1.0, }) } else { // Split into first batch and remainder, store cursor. let remaining = docs.split_off(batch_size); let first_batch: Vec = docs.into_iter().map(Bson::Document).collect(); let cursor_id = next_cursor_id(); ctx.cursors.insert(cursor_id, CursorState { documents: remaining, position: 0, database: db.to_string(), collection: coll.to_string(), }); Ok(doc! { "cursor": { "firstBatch": first_batch, "id": cursor_id, "ns": &ns, }, "ok": 1.0, }) } } // --------------------------------------------------------------------------- // getMore // --------------------------------------------------------------------------- /// Handle the `getMore` command. pub async fn handle_get_more( cmd: &Document, db: &str, ctx: &CommandContext, ) -> CommandResult { // Defensively extract cursor id. let cursor_id = get_i64(cmd, "getMore").ok_or_else(|| { CommandError::InvalidArgument("getMore requires a cursor id".into()) })?; let coll = get_str(cmd, "collection").unwrap_or("unknown"); let ns = format!("{}.{}", db, coll); let batch_size = get_i64(cmd, "batchSize") .or_else(|| get_i32(cmd, "batchSize").map(|v| v as i64)) .unwrap_or(101) .max(0) as usize; // Look up the cursor. let mut cursor_entry = ctx.cursors.get_mut(&cursor_id).ok_or_else(|| { CommandError::InvalidArgument(format!("cursor id {} not found", cursor_id)) })?; let cursor = cursor_entry.value_mut(); let start = cursor.position; let end = (start + batch_size).min(cursor.documents.len()); let batch: Vec = cursor.documents[start..end] .iter() .cloned() .map(Bson::Document) .collect(); cursor.position = end; let exhausted = cursor.position >= cursor.documents.len(); // Must drop the mutable reference before removing. drop(cursor_entry); if exhausted { ctx.cursors.remove(&cursor_id); Ok(doc! { "cursor": { "nextBatch": batch, "id": 0_i64, "ns": &ns, }, "ok": 1.0, }) } else { Ok(doc! { "cursor": { "nextBatch": batch, "id": cursor_id, "ns": &ns, }, "ok": 1.0, }) } } // --------------------------------------------------------------------------- // killCursors // --------------------------------------------------------------------------- /// Handle the `killCursors` command. pub async fn handle_kill_cursors( cmd: &Document, _db: &str, ctx: &CommandContext, ) -> CommandResult { let cursor_ids = match cmd.get("cursors") { Some(Bson::Array(arr)) => arr, _ => { return Ok(doc! { "cursorsKilled": [], "cursorsNotFound": [], "cursorsAlive": [], "cursorsUnknown": [], "ok": 1.0, }); } }; let mut killed: Vec = Vec::new(); let mut not_found: Vec = Vec::new(); for id_bson in cursor_ids { let id = match id_bson { Bson::Int64(v) => *v, Bson::Int32(v) => *v as i64, _ => continue, }; if ctx.cursors.remove(&id).is_some() { killed.push(Bson::Int64(id)); } else { not_found.push(Bson::Int64(id)); } } Ok(doc! { "cursorsKilled": killed, "cursorsNotFound": not_found, "cursorsAlive": [], "cursorsUnknown": [], "ok": 1.0, }) } // --------------------------------------------------------------------------- // count // --------------------------------------------------------------------------- /// Handle the `count` command. pub async fn handle_count( cmd: &Document, db: &str, ctx: &CommandContext, ) -> CommandResult { let coll = get_str(cmd, "count").unwrap_or("unknown"); // Check collection existence. let exists = ctx.storage.collection_exists(db, coll).await?; if !exists { return Ok(doc! { "n": 0_i64, "ok": 1.0 }); } let query = get_document(cmd, "query").cloned().unwrap_or_default(); let skip = get_i64(cmd, "skip").unwrap_or(0).max(0) as usize; let limit = get_i64(cmd, "limit").unwrap_or(0).max(0) as usize; let count: u64 = if query.is_empty() && skip == 0 && limit == 0 { // Fast path: use storage-level count. ctx.storage.count(db, coll).await? } else if query.is_empty() { // No filter but skip/limit apply. let total = ctx.storage.count(db, coll).await? as usize; let after_skip = total.saturating_sub(skip); let result = if limit > 0 { after_skip.min(limit) } else { after_skip }; result as u64 } else { // Need to load and filter. let docs = ctx.storage.find_all(db, coll).await?; let filtered = QueryMatcher::filter(&docs, &query); let mut n = filtered.len(); // Apply skip. n = n.saturating_sub(skip); // Apply limit. if limit > 0 { n = n.min(limit); } n as u64 }; Ok(doc! { "n": count as i64, "ok": 1.0, }) } // --------------------------------------------------------------------------- // distinct // --------------------------------------------------------------------------- /// Handle the `distinct` command. pub async fn handle_distinct( cmd: &Document, db: &str, ctx: &CommandContext, ) -> CommandResult { let coll = get_str(cmd, "distinct").unwrap_or("unknown"); let key = get_str(cmd, "key").ok_or_else(|| { CommandError::InvalidArgument("distinct requires a 'key' field".into()) })?; // Check collection existence. let exists = ctx.storage.collection_exists(db, coll).await?; if !exists { return Ok(doc! { "values": [], "ok": 1.0 }); } let query = get_document(cmd, "query").cloned(); let docs = ctx.storage.find_all(db, coll).await?; let values = distinct_values(&docs, key, query.as_ref()); Ok(doc! { "values": values, "ok": 1.0, }) }