Files
smartdb/rust/crates/rustdb-commands/src/handlers/aggregate_handler.rs

311 lines
9.5 KiB
Rust

use bson::{doc, Bson, Document};
use rustdb_query::AggregationEngine;
use rustdb_query::error::QueryError;
use tracing::debug;
use crate::context::{CommandContext, CursorState};
use crate::error::{CommandError, CommandResult};
/// A CollectionResolver that reads from the storage adapter.
///
/// Used by the aggregation engine for stages that need to read other
/// collections (e.g. `$lookup`). Borrows the storage adapter for the
/// duration of a single `aggregate` command.
struct StorageResolver<'a> {
/// Storage backend that documents are read from.
storage: &'a dyn rustdb_storage::StorageAdapter,
/// We use a tokio runtime handle to call async methods synchronously,
/// since the CollectionResolver trait is synchronous.
handle: tokio::runtime::Handle,
}
impl<'a> rustdb_query::aggregation::CollectionResolver for StorageResolver<'a> {
fn resolve(&self, db: &str, coll: &str) -> Result<Vec<Document>, QueryError> {
self.handle
.block_on(async { self.storage.find_all(db, coll).await })
.map_err(|e| QueryError::AggregationError(format!("Failed to resolve {}.{}: {}", db, coll, e)))
}
}
/// Handle the `aggregate` command.
///
/// Supports both collection-level (`aggregate: "<coll>"`) and database-level
/// (`aggregate: 1`) invocations. Results are returned through a cursor
/// document; when more than `batchSize` documents are produced, the surplus
/// is parked in `ctx.cursors` under a freshly generated cursor id.
///
/// # Errors
/// Returns `CommandError::InvalidArgument` for a malformed command document
/// and `CommandError::InternalError` if the pipeline itself fails.
pub async fn handle(
    cmd: &Document,
    db: &str,
    ctx: &CommandContext,
) -> CommandResult<Document> {
    // The aggregate field is either a collection name or the integer 1
    // (database-level aggregation, e.g. for $currentOp-style stages).
    let (coll, is_db_level) = match cmd.get("aggregate") {
        Some(Bson::String(s)) => (s.clone(), false),
        Some(Bson::Int32(1)) | Some(Bson::Int64(1)) => (String::new(), true),
        _ => {
            return Err(CommandError::InvalidArgument(
                "missing or invalid 'aggregate' field".into(),
            ));
        }
    };

    let pipeline_bson = cmd
        .get_array("pipeline")
        .map_err(|_| CommandError::InvalidArgument("missing 'pipeline' array".into()))?;

    // Every pipeline stage must itself be a document.
    let mut pipeline: Vec<Document> = Vec::with_capacity(pipeline_bson.len());
    for stage in pipeline_bson {
        match stage {
            Bson::Document(d) => pipeline.push(d.clone()),
            _ => {
                return Err(CommandError::InvalidArgument(
                    "pipeline stage must be a document".into(),
                ));
            }
        }
    }

    // $out / $merge must come last; detach the stage and handle it after
    // the rest of the pipeline has produced its results.
    let out_stage = match pipeline.last() {
        Some(last) if last.contains_key("$out") || last.contains_key("$merge") => pipeline.pop(),
        _ => None,
    };

    // cursor.batchSize, defaulting to the wire-protocol default of 101.
    // A negative value is rejected explicitly: the previous `v as usize`
    // cast silently wrapped negatives into a huge batch size.
    let batch_size = match cmd.get_document("cursor").ok().and_then(|c| {
        c.get_i32("batchSize")
            .ok()
            .map(i64::from)
            .or_else(|| c.get_i64("batchSize").ok())
    }) {
        Some(v) if v < 0 => {
            return Err(CommandError::InvalidArgument(
                "batchSize must be non-negative".into(),
            ));
        }
        Some(v) => v as usize,
        None => 101,
    };

    debug!(
        db = db,
        collection = %coll,
        stages = pipeline.len(),
        "aggregate command"
    );

    // Load the source document set.
    let source_docs = if is_db_level {
        // Database-level aggregate: start with empty set (useful for $currentOp, etc.)
        Vec::new()
    } else {
        ctx.storage.find_all(db, &coll).await?
    };

    // Resolver for stages that read other collections ($lookup and similar).
    let resolver = StorageResolver {
        storage: ctx.storage.as_ref(),
        handle: tokio::runtime::Handle::current(),
    };

    // Run the aggregation pipeline.
    let result_docs = AggregationEngine::aggregate(source_docs, &pipeline, Some(&resolver), db)
        .map_err(|e| CommandError::InternalError(e.to_string()))?;

    // Persist the results if the pipeline ended in $out / $merge.
    // NOTE(review): MongoDB returns an empty firstBatch for $out/$merge;
    // this implementation also returns the documents — confirm intended.
    if let Some(out) = out_stage {
        if let Some(out_spec) = out.get("$out") {
            handle_out_stage(db, out_spec, &result_docs, ctx).await?;
        } else if let Some(merge_spec) = out.get("$merge") {
            handle_merge_stage(db, merge_spec, &result_docs, ctx).await?;
        }
    }

    // Build the cursor reply namespace.
    let ns = if is_db_level {
        format!("{}.$cmd.aggregate", db)
    } else {
        format!("{}.{}", db, coll)
    };

    if result_docs.len() <= batch_size {
        // Everything fits in the first batch; cursor id 0 means "exhausted".
        let first_batch: Vec<Bson> = result_docs.into_iter().map(Bson::Document).collect();
        Ok(doc! {
            "cursor": {
                "firstBatch": first_batch,
                "id": 0_i64,
                "ns": &ns,
            },
            "ok": 1.0,
        })
    } else {
        // Park the remainder in a server-side cursor for getMore.
        let first_batch: Vec<Bson> = result_docs[..batch_size]
            .iter()
            .cloned()
            .map(Bson::Document)
            .collect();
        let remaining: Vec<Document> = result_docs[batch_size..].to_vec();
        let cursor_id = generate_cursor_id();
        ctx.cursors.insert(
            cursor_id,
            CursorState {
                documents: remaining,
                position: 0,
                database: db.to_string(),
                collection: coll.to_string(),
            },
        );
        Ok(doc! {
            "cursor": {
                "firstBatch": first_batch,
                "id": cursor_id,
                "ns": &ns,
            },
            "ok": 1.0,
        })
    }
}
/// Handle the `$out` stage: replace the target collection with the pipeline
/// results (drop, recreate, insert).
///
/// Accepts either `"$out": "coll"` or `"$out": { db: ..., coll: ... }`
/// (where `db` defaults to the current database).
///
/// Drop/create failures are tolerated (the target may not exist yet), but
/// insert failures are propagated — swallowing them would silently leave a
/// truncated target collection.
async fn handle_out_stage(
    db: &str,
    out_spec: &Bson,
    docs: &[Document],
    ctx: &CommandContext,
) -> CommandResult<()> {
    let (target_db, target_coll) = match out_spec {
        Bson::String(coll_name) => (db.to_string(), coll_name.clone()),
        Bson::Document(d) => {
            // 'db' is optional; 'coll' is required.
            let tdb = d.get_str("db").unwrap_or(db).to_string();
            let tcoll = d
                .get_str("coll")
                .map_err(|_| CommandError::InvalidArgument("$out requires 'coll'".into()))?
                .to_string();
            (tdb, tcoll)
        }
        _ => {
            return Err(CommandError::InvalidArgument(
                "$out requires a string or document".into(),
            ));
        }
    };

    // Best-effort setup: drop any existing target, then (re)create it.
    let _ = ctx.storage.drop_collection(&target_db, &target_coll).await;
    let _ = ctx.storage.create_database(&target_db).await;
    let _ = ctx.storage.create_collection(&target_db, &target_coll).await;

    // Write every result document; surface the first failure to the client.
    for doc in docs {
        ctx.storage
            .insert_one(&target_db, &target_coll, doc.clone())
            .await?;
    }
    Ok(())
}
/// Handle the `$merge` stage: upsert the pipeline results into the target
/// collection.
///
/// Accepts `"$merge": "coll"` or `"$merge": { into: "coll" | { db, coll } }`
/// (where `db` defaults to the current database). Documents are matched by
/// `_id`; documents without an `_id` are plainly inserted.
async fn handle_merge_stage(
    db: &str,
    merge_spec: &Bson,
    docs: &[Document],
    ctx: &CommandContext,
) -> CommandResult<()> {
    let (target_db, target_coll) = match merge_spec {
        Bson::String(coll_name) => (db.to_string(), coll_name.clone()),
        Bson::Document(d) => match d.get("into") {
            Some(Bson::String(s)) => (db.to_string(), s.clone()),
            Some(Bson::Document(into_doc)) => {
                let tdb = into_doc.get_str("db").unwrap_or(db).to_string();
                let tcoll = into_doc
                    .get_str("coll")
                    .map_err(|_| {
                        CommandError::InvalidArgument("$merge.into requires 'coll'".into())
                    })?
                    .to_string();
                (tdb, tcoll)
            }
            _ => {
                return Err(CommandError::InvalidArgument(
                    "$merge requires 'into' field".into(),
                ));
            }
        },
        _ => {
            return Err(CommandError::InvalidArgument(
                "$merge requires a string or document".into(),
            ));
        }
    };

    // Best-effort: ensure the target database/collection exist.
    let _ = ctx.storage.create_database(&target_db).await;
    let _ = ctx
        .storage
        .create_collection(&target_db, &target_coll)
        .await;

    // Simple merge strategy: upsert by _id.
    for doc in docs {
        let id_str = match doc.get("_id") {
            Some(Bson::ObjectId(oid)) => oid.to_hex(),
            Some(Bson::String(s)) => s.clone(),
            Some(other) => other.to_string(),
            None => {
                // Without an _id there is nothing to match on; insert, and
                // surface any failure rather than dropping the document.
                ctx.storage
                    .insert_one(&target_db, &target_coll, doc.clone())
                    .await?;
                continue;
            }
        };
        // Try an update first; an update failure is treated as "not found"
        // and falls back to insert, whose failure IS a real error.
        if ctx
            .storage
            .update_by_id(&target_db, &target_coll, &id_str, doc.clone())
            .await
            .is_err()
        {
            ctx.storage
                .insert_one(&target_db, &target_coll, doc.clone())
                .await?;
        }
    }
    Ok(())
}
/// Generate a pseudo-random, strictly positive cursor ID.
///
/// Hashes the current wall-clock nanoseconds through a randomly seeded
/// `RandomState`, then masks off the sign bit. `i64::abs` is deliberately
/// avoided: `i64::MIN.abs()` panics in debug builds and wraps (staying
/// negative, i.e. `i64::MIN`) in release builds.
fn generate_cursor_id() -> i64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let state = RandomState::new();
    let mut hasher = state.build_hasher();
    hasher.write_u64(
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64,
    );
    // Clear the sign bit so the id is non-negative, then map 0 to 1 since
    // cursor id 0 is reserved for "exhausted".
    let id = (hasher.finish() & i64::MAX as u64) as i64;
    if id == 0 { 1 } else { id }
}