371 lines
11 KiB
Rust
371 lines
11 KiB
Rust
use std::sync::atomic::{AtomicI64, Ordering};
|
|
|
|
use bson::{doc, Bson, Document};
|
|
use tracing::debug;
|
|
|
|
use rustdb_query::{QueryMatcher, sort_documents, apply_projection, distinct_values};
|
|
|
|
use crate::context::{CommandContext, CursorState};
|
|
use crate::error::{CommandError, CommandResult};
|
|
|
|
/// Process-wide source of cursor IDs.
///
/// It starts at 1 because the replies in this module use a cursor id of 0 to
/// mean "no open cursor", so a live cursor must never be assigned 0.
static CURSOR_ID_COUNTER: AtomicI64 = AtomicI64::new(1);

/// Allocate the next cursor ID.
///
/// IDs are unique within the process and positive in practice; the counter
/// would only wrap (atomic `fetch_add` wraps on overflow) after 2^63 - 1
/// allocations.
fn next_cursor_id() -> i64 {
    // Relaxed is enough: read-modify-write operations on a single atomic are
    // totally ordered, which is all uniqueness needs, and no other memory is
    // published through this counter.
    AtomicI64::fetch_add(&CURSOR_ID_COUNTER, 1, Ordering::Relaxed)
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Helpers to defensively extract values from BSON command documents
|
|
// ---------------------------------------------------------------------------
|
|
|
|
fn get_str<'a>(doc: &'a Document, key: &str) -> Option<&'a str> {
|
|
match doc.get(key)? {
|
|
Bson::String(s) => Some(s.as_str()),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn get_i32(doc: &Document, key: &str) -> Option<i32> {
|
|
match doc.get(key)? {
|
|
Bson::Int32(v) => Some(*v),
|
|
Bson::Int64(v) => Some(*v as i32),
|
|
Bson::Double(v) => Some(*v as i32),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn get_i64(doc: &Document, key: &str) -> Option<i64> {
|
|
match doc.get(key)? {
|
|
Bson::Int64(v) => Some(*v),
|
|
Bson::Int32(v) => Some(*v as i64),
|
|
Bson::Double(v) => Some(*v as i64),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn get_bool(doc: &Document, key: &str) -> Option<bool> {
|
|
match doc.get(key)? {
|
|
Bson::Boolean(v) => Some(*v),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn get_document<'a>(doc: &'a Document, key: &str) -> Option<&'a Document> {
|
|
match doc.get(key)? {
|
|
Bson::Document(d) => Some(d),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// find
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Handle the `find` command.
|
|
pub async fn handle(
|
|
cmd: &Document,
|
|
db: &str,
|
|
ctx: &CommandContext,
|
|
) -> CommandResult<Document> {
|
|
let coll = get_str(cmd, "find").unwrap_or("unknown");
|
|
let ns = format!("{}.{}", db, coll);
|
|
|
|
// Extract optional parameters.
|
|
let filter = get_document(cmd, "filter").cloned().unwrap_or_default();
|
|
let sort_spec = get_document(cmd, "sort").cloned();
|
|
let projection = get_document(cmd, "projection").cloned();
|
|
let skip = get_i64(cmd, "skip").unwrap_or(0).max(0) as usize;
|
|
let limit = get_i64(cmd, "limit").unwrap_or(0).max(0) as usize;
|
|
let batch_size = get_i32(cmd, "batchSize").unwrap_or(101).max(0) as usize;
|
|
let single_batch = get_bool(cmd, "singleBatch").unwrap_or(false);
|
|
|
|
// If the collection does not exist, return an empty cursor.
|
|
let exists = ctx.storage.collection_exists(db, coll).await?;
|
|
if !exists {
|
|
return Ok(doc! {
|
|
"cursor": {
|
|
"firstBatch": [],
|
|
"id": 0_i64,
|
|
"ns": &ns,
|
|
},
|
|
"ok": 1.0,
|
|
});
|
|
}
|
|
|
|
// Try index-accelerated lookup.
|
|
let index_key = format!("{}.{}", db, coll);
|
|
let docs = if let Some(idx_ref) = ctx.indexes.get(&index_key) {
|
|
if let Some(candidate_ids) = idx_ref.find_candidate_ids(&filter) {
|
|
debug!(
|
|
ns = %ns,
|
|
candidates = candidate_ids.len(),
|
|
"using index acceleration"
|
|
);
|
|
ctx.storage.find_by_ids(db, coll, candidate_ids).await?
|
|
} else {
|
|
ctx.storage.find_all(db, coll).await?
|
|
}
|
|
} else {
|
|
ctx.storage.find_all(db, coll).await?
|
|
};
|
|
|
|
// Apply filter.
|
|
let mut docs = QueryMatcher::filter(&docs, &filter);
|
|
|
|
// Apply sort.
|
|
if let Some(ref sort) = sort_spec {
|
|
sort_documents(&mut docs, sort);
|
|
}
|
|
|
|
// Apply skip.
|
|
if skip > 0 {
|
|
if skip >= docs.len() {
|
|
docs = Vec::new();
|
|
} else {
|
|
docs = docs.split_off(skip);
|
|
}
|
|
}
|
|
|
|
// Apply limit.
|
|
if limit > 0 && docs.len() > limit {
|
|
docs.truncate(limit);
|
|
}
|
|
|
|
// Apply projection.
|
|
if let Some(ref proj) = projection {
|
|
docs = docs.iter().map(|d| apply_projection(d, proj)).collect();
|
|
}
|
|
|
|
// Determine first batch.
|
|
if docs.len() <= batch_size || single_batch {
|
|
// Everything fits in a single batch.
|
|
let batch: Vec<Bson> = docs.into_iter().map(Bson::Document).collect();
|
|
Ok(doc! {
|
|
"cursor": {
|
|
"firstBatch": batch,
|
|
"id": 0_i64,
|
|
"ns": &ns,
|
|
},
|
|
"ok": 1.0,
|
|
})
|
|
} else {
|
|
// Split into first batch and remainder, store cursor.
|
|
let remaining = docs.split_off(batch_size);
|
|
let first_batch: Vec<Bson> = docs.into_iter().map(Bson::Document).collect();
|
|
|
|
let cursor_id = next_cursor_id();
|
|
ctx.cursors.insert(cursor_id, CursorState {
|
|
documents: remaining,
|
|
position: 0,
|
|
database: db.to_string(),
|
|
collection: coll.to_string(),
|
|
});
|
|
|
|
Ok(doc! {
|
|
"cursor": {
|
|
"firstBatch": first_batch,
|
|
"id": cursor_id,
|
|
"ns": &ns,
|
|
},
|
|
"ok": 1.0,
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// getMore
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Handle the `getMore` command.
|
|
pub async fn handle_get_more(
|
|
cmd: &Document,
|
|
db: &str,
|
|
ctx: &CommandContext,
|
|
) -> CommandResult<Document> {
|
|
// Defensively extract cursor id.
|
|
let cursor_id = get_i64(cmd, "getMore").ok_or_else(|| {
|
|
CommandError::InvalidArgument("getMore requires a cursor id".into())
|
|
})?;
|
|
|
|
let coll = get_str(cmd, "collection").unwrap_or("unknown");
|
|
let ns = format!("{}.{}", db, coll);
|
|
|
|
let batch_size = get_i64(cmd, "batchSize")
|
|
.or_else(|| get_i32(cmd, "batchSize").map(|v| v as i64))
|
|
.unwrap_or(101)
|
|
.max(0) as usize;
|
|
|
|
// Look up the cursor.
|
|
let mut cursor_entry = ctx.cursors.get_mut(&cursor_id).ok_or_else(|| {
|
|
CommandError::InvalidArgument(format!("cursor id {} not found", cursor_id))
|
|
})?;
|
|
|
|
let cursor = cursor_entry.value_mut();
|
|
let start = cursor.position;
|
|
let end = (start + batch_size).min(cursor.documents.len());
|
|
let batch: Vec<Bson> = cursor.documents[start..end]
|
|
.iter()
|
|
.cloned()
|
|
.map(Bson::Document)
|
|
.collect();
|
|
cursor.position = end;
|
|
|
|
let exhausted = cursor.position >= cursor.documents.len();
|
|
|
|
// Must drop the mutable reference before removing.
|
|
drop(cursor_entry);
|
|
|
|
if exhausted {
|
|
ctx.cursors.remove(&cursor_id);
|
|
Ok(doc! {
|
|
"cursor": {
|
|
"nextBatch": batch,
|
|
"id": 0_i64,
|
|
"ns": &ns,
|
|
},
|
|
"ok": 1.0,
|
|
})
|
|
} else {
|
|
Ok(doc! {
|
|
"cursor": {
|
|
"nextBatch": batch,
|
|
"id": cursor_id,
|
|
"ns": &ns,
|
|
},
|
|
"ok": 1.0,
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// killCursors
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Handle the `killCursors` command.
|
|
pub async fn handle_kill_cursors(
|
|
cmd: &Document,
|
|
_db: &str,
|
|
ctx: &CommandContext,
|
|
) -> CommandResult<Document> {
|
|
let cursor_ids = match cmd.get("cursors") {
|
|
Some(Bson::Array(arr)) => arr,
|
|
_ => {
|
|
return Ok(doc! {
|
|
"cursorsKilled": [],
|
|
"cursorsNotFound": [],
|
|
"cursorsAlive": [],
|
|
"cursorsUnknown": [],
|
|
"ok": 1.0,
|
|
});
|
|
}
|
|
};
|
|
|
|
let mut killed: Vec<Bson> = Vec::new();
|
|
let mut not_found: Vec<Bson> = Vec::new();
|
|
|
|
for id_bson in cursor_ids {
|
|
let id = match id_bson {
|
|
Bson::Int64(v) => *v,
|
|
Bson::Int32(v) => *v as i64,
|
|
_ => continue,
|
|
};
|
|
if ctx.cursors.remove(&id).is_some() {
|
|
killed.push(Bson::Int64(id));
|
|
} else {
|
|
not_found.push(Bson::Int64(id));
|
|
}
|
|
}
|
|
|
|
Ok(doc! {
|
|
"cursorsKilled": killed,
|
|
"cursorsNotFound": not_found,
|
|
"cursorsAlive": [],
|
|
"cursorsUnknown": [],
|
|
"ok": 1.0,
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// count
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Handle the `count` command.
|
|
pub async fn handle_count(
|
|
cmd: &Document,
|
|
db: &str,
|
|
ctx: &CommandContext,
|
|
) -> CommandResult<Document> {
|
|
let coll = get_str(cmd, "count").unwrap_or("unknown");
|
|
|
|
// Check collection existence.
|
|
let exists = ctx.storage.collection_exists(db, coll).await?;
|
|
if !exists {
|
|
return Ok(doc! { "n": 0_i64, "ok": 1.0 });
|
|
}
|
|
|
|
let query = get_document(cmd, "query").cloned().unwrap_or_default();
|
|
let skip = get_i64(cmd, "skip").unwrap_or(0).max(0) as usize;
|
|
let limit = get_i64(cmd, "limit").unwrap_or(0).max(0) as usize;
|
|
|
|
let count: u64 = if query.is_empty() && skip == 0 && limit == 0 {
|
|
// Fast path: use storage-level count.
|
|
ctx.storage.count(db, coll).await?
|
|
} else if query.is_empty() {
|
|
// No filter but skip/limit apply.
|
|
let total = ctx.storage.count(db, coll).await? as usize;
|
|
let after_skip = total.saturating_sub(skip);
|
|
let result = if limit > 0 { after_skip.min(limit) } else { after_skip };
|
|
result as u64
|
|
} else {
|
|
// Need to load and filter.
|
|
let docs = ctx.storage.find_all(db, coll).await?;
|
|
let filtered = QueryMatcher::filter(&docs, &query);
|
|
let mut n = filtered.len();
|
|
// Apply skip.
|
|
n = n.saturating_sub(skip);
|
|
// Apply limit.
|
|
if limit > 0 {
|
|
n = n.min(limit);
|
|
}
|
|
n as u64
|
|
};
|
|
|
|
Ok(doc! {
|
|
"n": count as i64,
|
|
"ok": 1.0,
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// distinct
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Handle the `distinct` command.
|
|
pub async fn handle_distinct(
|
|
cmd: &Document,
|
|
db: &str,
|
|
ctx: &CommandContext,
|
|
) -> CommandResult<Document> {
|
|
let coll = get_str(cmd, "distinct").unwrap_or("unknown");
|
|
let key = get_str(cmd, "key").ok_or_else(|| {
|
|
CommandError::InvalidArgument("distinct requires a 'key' field".into())
|
|
})?;
|
|
|
|
// Check collection existence.
|
|
let exists = ctx.storage.collection_exists(db, coll).await?;
|
|
if !exists {
|
|
return Ok(doc! { "values": [], "ok": 1.0 });
|
|
}
|
|
|
|
let query = get_document(cmd, "query").cloned();
|
|
let docs = ctx.storage.find_all(db, coll).await?;
|
|
let values = distinct_values(&docs, key, query.as_ref());
|
|
|
|
Ok(doc! {
|
|
"values": values,
|
|
"ok": 1.0,
|
|
})
|
|
}
|