use bson::{doc, Bson, Document}; use rustdb_query::AggregationEngine; use rustdb_query::error::QueryError; use tracing::debug; use crate::context::{CommandContext, CursorState}; use crate::error::{CommandError, CommandResult}; /// A CollectionResolver that reads from the storage adapter. struct StorageResolver<'a> { storage: &'a dyn rustdb_storage::StorageAdapter, /// We use a tokio runtime handle to call async methods synchronously, /// since the CollectionResolver trait is synchronous. handle: tokio::runtime::Handle, } impl<'a> rustdb_query::aggregation::CollectionResolver for StorageResolver<'a> { fn resolve(&self, db: &str, coll: &str) -> Result, QueryError> { self.handle .block_on(async { self.storage.find_all(db, coll).await }) .map_err(|e| QueryError::AggregationError(format!("Failed to resolve {}.{}: {}", db, coll, e))) } } /// Handle the `aggregate` command. pub async fn handle( cmd: &Document, db: &str, ctx: &CommandContext, ) -> CommandResult { // The aggregate field can be a string (collection name) or integer 1 (db-level). let (coll, is_db_level) = match cmd.get("aggregate") { Some(Bson::String(s)) => (s.as_str().to_string(), false), Some(Bson::Int32(1)) => (String::new(), true), Some(Bson::Int64(1)) => (String::new(), true), _ => { return Err(CommandError::InvalidArgument( "missing or invalid 'aggregate' field".into(), )); } }; let pipeline_bson = cmd .get_array("pipeline") .map_err(|_| CommandError::InvalidArgument("missing 'pipeline' array".into()))?; // Convert pipeline to Vec. let mut pipeline: Vec = Vec::with_capacity(pipeline_bson.len()); for stage in pipeline_bson { match stage { Bson::Document(d) => pipeline.push(d.clone()), _ => { return Err(CommandError::InvalidArgument( "pipeline stage must be a document".into(), )); } } } // Check for $out and $merge as the last stage (handle after pipeline execution). 
let out_stage = if let Some(last) = pipeline.last() { if last.contains_key("$out") || last.contains_key("$merge") { Some(pipeline.pop().unwrap()) } else { None } } else { None }; let batch_size = cmd .get_document("cursor") .ok() .and_then(|c| { c.get_i32("batchSize") .ok() .map(|v| v as usize) .or_else(|| c.get_i64("batchSize").ok().map(|v| v as usize)) }) .unwrap_or(101); debug!( db = db, collection = %coll, stages = pipeline.len(), "aggregate command" ); // Load source documents. let source_docs = if is_db_level { // Database-level aggregate: start with empty set (useful for $currentOp, etc.) Vec::new() } else { ctx.storage.find_all(db, &coll).await? }; // Create a resolver for $lookup and similar stages. let handle = tokio::runtime::Handle::current(); let resolver = StorageResolver { storage: ctx.storage.as_ref(), handle, }; // Run the aggregation pipeline. let result_docs = AggregationEngine::aggregate( source_docs, &pipeline, Some(&resolver), db, ) .map_err(|e| CommandError::InternalError(e.to_string()))?; // Handle $out stage: write results to target collection. if let Some(out) = out_stage { if let Some(out_spec) = out.get("$out") { handle_out_stage(db, out_spec, &result_docs, ctx).await?; } else if let Some(merge_spec) = out.get("$merge") { handle_merge_stage(db, merge_spec, &result_docs, ctx).await?; } } // Build cursor response. let ns = if is_db_level { format!("{}.$cmd.aggregate", db) } else { format!("{}.{}", db, coll) }; if result_docs.len() <= batch_size { // All results fit in first batch. let first_batch: Vec = result_docs .into_iter() .map(Bson::Document) .collect(); Ok(doc! { "cursor": { "firstBatch": first_batch, "id": 0_i64, "ns": &ns, }, "ok": 1.0, }) } else { // Need to create a cursor for remaining results. 
let first_batch: Vec = result_docs[..batch_size] .iter() .cloned() .map(Bson::Document) .collect(); let remaining: Vec = result_docs[batch_size..].to_vec(); let cursor_id = generate_cursor_id(); ctx.cursors.insert( cursor_id, CursorState { documents: remaining, position: 0, database: db.to_string(), collection: coll.to_string(), }, ); Ok(doc! { "cursor": { "firstBatch": first_batch, "id": cursor_id, "ns": &ns, }, "ok": 1.0, }) } } /// Handle $out stage: drop and replace target collection with pipeline results. async fn handle_out_stage( db: &str, out_spec: &Bson, docs: &[Document], ctx: &CommandContext, ) -> CommandResult<()> { let (target_db, target_coll) = match out_spec { Bson::String(coll_name) => (db.to_string(), coll_name.clone()), Bson::Document(d) => { let tdb = d.get_str("db").unwrap_or(db).to_string(); let tcoll = d .get_str("coll") .map_err(|_| CommandError::InvalidArgument("$out requires 'coll'".into()))? .to_string(); (tdb, tcoll) } _ => { return Err(CommandError::InvalidArgument( "$out requires a string or document".into(), )); } }; // Drop existing target collection (ignore errors). let _ = ctx.storage.drop_collection(&target_db, &target_coll).await; // Create target collection. let _ = ctx.storage.create_database(&target_db).await; let _ = ctx.storage.create_collection(&target_db, &target_coll).await; // Insert all result documents. for doc in docs { let _ = ctx .storage .insert_one(&target_db, &target_coll, doc.clone()) .await; } Ok(()) } /// Handle $merge stage: merge pipeline results into target collection. 
async fn handle_merge_stage(
    db: &str,
    merge_spec: &Bson,
    docs: &[Document],
    ctx: &CommandContext,
) -> CommandResult<()> {
    // $merge accepts either a collection-name string or a document whose
    // `into` field is itself either a string or `{ db, coll }`.
    let (target_db, target_coll) = match merge_spec {
        Bson::String(coll_name) => (db.to_string(), coll_name.clone()),
        Bson::Document(d) => match d.get("into") {
            Some(Bson::String(s)) => (db.to_string(), s.clone()),
            Some(Bson::Document(into_doc)) => {
                // `db` defaults to the current database when not specified.
                let tdb = into_doc.get_str("db").unwrap_or(db).to_string();
                let tcoll = into_doc
                    .get_str("coll")
                    .map_err(|_| {
                        CommandError::InvalidArgument("$merge.into requires 'coll'".into())
                    })?
                    .to_string();
                (tdb, tcoll)
            }
            _ => {
                return Err(CommandError::InvalidArgument(
                    "$merge requires 'into' field".into(),
                ));
            }
        },
        _ => {
            return Err(CommandError::InvalidArgument(
                "$merge requires a string or document".into(),
            ));
        }
    };

    // Ensure the target collection exists (best-effort).
    let _ = ctx.storage.create_database(&target_db).await;
    let _ = ctx
        .storage
        .create_collection(&target_db, &target_coll)
        .await;

    // Simple merge: upsert by _id.
    for doc in docs {
        let id_str = match doc.get("_id") {
            Some(Bson::ObjectId(oid)) => oid.to_hex(),
            Some(Bson::String(s)) => s.clone(),
            // Any other _id type is keyed by its Display form.
            Some(other) => other.to_string(),
            None => {
                // No _id: nothing to match on, so just insert.
                let _ = ctx
                    .storage
                    .insert_one(&target_db, &target_coll, doc.clone())
                    .await;
                continue;
            }
        };

        // Try update first; fall back to insert when the update fails
        // (e.g. no document with this id exists yet).
        if ctx
            .storage
            .update_by_id(&target_db, &target_coll, &id_str, doc.clone())
            .await
            .is_err()
        {
            let _ = ctx
                .storage
                .insert_one(&target_db, &target_coll, doc.clone())
                .await;
        }
    }
    Ok(())
}

/// Generate a pseudo-random, strictly positive cursor ID.
///
/// Hashes the current time with a randomly-seeded `RandomState` hasher. The
/// original implementation called `id.abs()`, which overflows for `i64::MIN`
/// (panics in debug builds, stays negative via wrapping in release). Clearing
/// the sign bit instead guarantees a non-negative value, and `max(1)` keeps it
/// non-zero (cursor id 0 is reserved for "no cursor" on the wire).
fn generate_cursor_id() -> i64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let s = RandomState::new();
    let mut hasher = s.build_hasher();
    hasher.write_u64(
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64,
    );
    // Mask off the sign bit, then clamp away from zero.
    ((hasher.finish() & (i64::MAX as u64)) as i64).max(1)
}