use bson::{Bson, Document, doc}; use crate::error::QueryError; use crate::field_path::{get_nested_value, set_nested_value, remove_nested_value}; use crate::matcher::QueryMatcher; /// Update engine — applies update operators to documents. pub struct UpdateEngine; impl UpdateEngine { /// Apply an update specification to a document. /// Returns the updated document. pub fn apply_update( doc: &Document, update: &Document, _array_filters: Option<&[Document]>, ) -> Result { // Check if this is a replacement (no $ operators) if !update.keys().any(|k| k.starts_with('$')) { return Self::apply_replacement(doc, update); } let mut result = doc.clone(); for (op, value) in update { let fields = match value { Bson::Document(d) => d, _ => continue, }; match op.as_str() { "$set" => Self::apply_set(&mut result, fields)?, "$unset" => Self::apply_unset(&mut result, fields)?, "$inc" => Self::apply_inc(&mut result, fields)?, "$mul" => Self::apply_mul(&mut result, fields)?, "$min" => Self::apply_min(&mut result, fields)?, "$max" => Self::apply_max(&mut result, fields)?, "$rename" => Self::apply_rename(&mut result, fields)?, "$currentDate" => Self::apply_current_date(&mut result, fields)?, "$setOnInsert" => {} // handled separately during upsert "$push" => Self::apply_push(&mut result, fields)?, "$pop" => Self::apply_pop(&mut result, fields)?, "$pull" => Self::apply_pull(&mut result, fields)?, "$pullAll" => Self::apply_pull_all(&mut result, fields)?, "$addToSet" => Self::apply_add_to_set(&mut result, fields)?, "$bit" => Self::apply_bit(&mut result, fields)?, other => { return Err(QueryError::InvalidUpdate(format!( "Unknown update operator: {}", other ))); } } } Ok(result) } /// Apply $setOnInsert fields (used during upsert only). pub fn apply_set_on_insert(doc: &mut Document, fields: &Document) { for (key, value) in fields { if key.contains('.') { set_nested_value(doc, key, value.clone()); } else { doc.insert(key.clone(), value.clone()); } } } /// Deep clone a BSON document. pub fn deep_clone(doc: &Document) -> Document { doc.clone() } fn apply_replacement(doc: &Document, replacement: &Document) -> Result { let mut result = replacement.clone(); // Preserve _id if let Some(id) = doc.get("_id") { result.insert("_id", id.clone()); } Ok(result) } fn apply_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, value) in fields { if key.contains('.') { set_nested_value(doc, key, value.clone()); } else { doc.insert(key.clone(), value.clone()); } } Ok(()) } fn apply_unset(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, _) in fields { if key.contains('.') { remove_nested_value(doc, key); } else { doc.remove(key); } } Ok(()) } fn apply_inc(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, inc_value) in fields { let current = if key.contains('.') { get_nested_value(doc, key) } else { doc.get(key).cloned() }; let new_value = match (¤t, inc_value) { (Some(Bson::Int32(a)), Bson::Int32(b)) => Bson::Int32(a + b), (Some(Bson::Int64(a)), Bson::Int64(b)) => Bson::Int64(a + b), (Some(Bson::Int32(a)), Bson::Int64(b)) => Bson::Int64(*a as i64 + b), (Some(Bson::Int64(a)), Bson::Int32(b)) => Bson::Int64(a + *b as i64), (Some(Bson::Double(a)), Bson::Double(b)) => Bson::Double(a + b), (Some(Bson::Int32(a)), Bson::Double(b)) => Bson::Double(*a as f64 + b), (Some(Bson::Double(a)), Bson::Int32(b)) => Bson::Double(a + *b as f64), (Some(Bson::Int64(a)), Bson::Double(b)) => Bson::Double(*a as f64 + b), (Some(Bson::Double(a)), Bson::Int64(b)) => Bson::Double(a + *b as f64), (None, v) => v.clone(), // treat missing as 0 _ => { return Err(QueryError::TypeMismatch(format!( "Cannot apply $inc to non-numeric field: {}", key ))); } }; if key.contains('.') { set_nested_value(doc, key, new_value); } else { doc.insert(key.clone(), new_value); } } Ok(()) } fn apply_mul(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, mul_value) in fields { let current = if key.contains('.') { get_nested_value(doc, key) } else { doc.get(key).cloned() }; let new_value = match (¤t, mul_value) { (Some(Bson::Int32(a)), Bson::Int32(b)) => Bson::Int32(a * b), (Some(Bson::Int64(a)), Bson::Int64(b)) => Bson::Int64(a * b), (Some(Bson::Int32(a)), Bson::Int64(b)) => Bson::Int64(*a as i64 * b), (Some(Bson::Int64(a)), Bson::Int32(b)) => Bson::Int64(a * *b as i64), (Some(Bson::Double(a)), Bson::Double(b)) => Bson::Double(a * b), (Some(Bson::Int32(a)), Bson::Double(b)) => Bson::Double(*a as f64 * b), (Some(Bson::Double(a)), Bson::Int32(b)) => Bson::Double(a * *b as f64), (None, _) => Bson::Int32(0), // missing field * anything = 0 _ => { return Err(QueryError::TypeMismatch(format!( "Cannot apply $mul to non-numeric field: {}", key ))); } }; if key.contains('.') { set_nested_value(doc, key, new_value); } else { doc.insert(key.clone(), new_value); } } Ok(()) } fn apply_min(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, min_value) in fields { let current = if key.contains('.') { get_nested_value(doc, key) } else { doc.get(key).cloned() }; let should_update = match ¤t { None => true, Some(cur) => { if let Some(ord) = QueryMatcher::bson_compare_pub(min_value, cur) { ord == std::cmp::Ordering::Less } else { false } } }; if should_update { if key.contains('.') { set_nested_value(doc, key, min_value.clone()); } else { doc.insert(key.clone(), min_value.clone()); } } } Ok(()) } fn apply_max(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, max_value) in fields { let current = if key.contains('.') { get_nested_value(doc, key) } else { doc.get(key).cloned() }; let should_update = match ¤t { None => true, Some(cur) => { if let Some(ord) = QueryMatcher::bson_compare_pub(max_value, cur) { ord == std::cmp::Ordering::Greater } else { false } } }; if should_update { if key.contains('.') { set_nested_value(doc, key, max_value.clone()); } else { doc.insert(key.clone(), max_value.clone()); } } } Ok(()) } fn apply_rename(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (old_name, new_name_bson) in fields { let new_name = match new_name_bson { Bson::String(s) => s.clone(), _ => continue, }; if let Some(value) = doc.remove(old_name) { doc.insert(new_name, value); } } Ok(()) } fn apply_current_date(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { let now = bson::DateTime::now(); for (key, spec) in fields { let value = match spec { Bson::Boolean(true) => Bson::DateTime(now), Bson::Document(d) => { match d.get_str("$type").unwrap_or("date") { "date" => Bson::DateTime(now), "timestamp" => Bson::Timestamp(bson::Timestamp { time: (now.timestamp_millis() / 1000) as u32, increment: 0, }), _ => Bson::DateTime(now), } } _ => continue, }; if key.contains('.') { set_nested_value(doc, key, value); } else { doc.insert(key.clone(), value); } } Ok(()) } fn apply_push(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, value) in fields { let arr = Self::get_or_create_array(doc, key); match value { Bson::Document(d) if d.contains_key("$each") => { let each = match d.get("$each") { Some(Bson::Array(a)) => a.clone(), _ => return Err(QueryError::InvalidUpdate("$each must be an array".into())), }; let position = d.get("$position").and_then(|v| match v { Bson::Int32(n) => Some(*n as usize), Bson::Int64(n) => Some(*n as usize), _ => None, }); if let Some(pos) = position { let pos = pos.min(arr.len()); for (i, item) in each.into_iter().enumerate() { arr.insert(pos + i, item); } } else { arr.extend(each); } // Apply $sort if present if let Some(sort_spec) = d.get("$sort") { Self::sort_array(arr, sort_spec); } // Apply $slice if present if let Some(slice) = d.get("$slice") { Self::slice_array(arr, slice); } } _ => { arr.push(value.clone()); } } } Ok(()) } fn apply_pop(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, direction) in fields { if let Some(Bson::Array(arr)) = doc.get_mut(key) { if arr.is_empty() { continue; } match direction { Bson::Int32(-1) | Bson::Int64(-1) => { arr.remove(0); } Bson::Int32(1) | Bson::Int64(1) => { arr.pop(); } Bson::Double(f) if *f == 1.0 => { arr.pop(); } Bson::Double(f) if *f == -1.0 => { arr.remove(0); } _ => { arr.pop(); } } } } Ok(()) } fn apply_pull(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, condition) in fields { if let Some(Bson::Array(arr)) = doc.get_mut(key) { match condition { Bson::Document(cond_doc) if QueryMatcher::has_operators_pub(cond_doc) => { arr.retain(|elem| { if let Bson::Document(elem_doc) = elem { !QueryMatcher::matches(elem_doc, cond_doc) } else { // For primitive matching with operators let wrapper = doc! { "v": elem.clone() }; let cond_wrapper = doc! { "v": condition.clone() }; !QueryMatcher::matches(&wrapper, &cond_wrapper) } }); } _ => { arr.retain(|elem| elem != condition); } } } } Ok(()) } fn apply_pull_all(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, values) in fields { if let (Some(Bson::Array(arr)), Bson::Array(to_remove)) = (doc.get_mut(key), values) { arr.retain(|elem| !to_remove.contains(elem)); } } Ok(()) } fn apply_add_to_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, value) in fields { let arr = Self::get_or_create_array(doc, key); match value { Bson::Document(d) if d.contains_key("$each") => { if let Some(Bson::Array(each)) = d.get("$each") { for item in each { if !arr.contains(item) { arr.push(item.clone()); } } } } _ => { if !arr.contains(value) { arr.push(value.clone()); } } } } Ok(()) } fn apply_bit(doc: &mut Document, fields: &Document) -> Result<(), QueryError> { for (key, ops) in fields { let ops_doc = match ops { Bson::Document(d) => d, _ => continue, }; let current = doc.get(key).cloned().unwrap_or(Bson::Int32(0)); let mut val = match ¤t { Bson::Int32(n) => *n as i64, Bson::Int64(n) => *n, _ => continue, }; for (bit_op, operand) in ops_doc { let operand_val = match operand { Bson::Int32(n) => *n as i64, Bson::Int64(n) => *n, _ => continue, }; match bit_op.as_str() { "and" => val &= operand_val, "or" => val |= operand_val, "xor" => val ^= operand_val, _ => {} } } let new_value = match ¤t { Bson::Int32(_) => Bson::Int32(val as i32), _ => Bson::Int64(val), }; doc.insert(key.clone(), new_value); } Ok(()) } // --- Helpers --- fn get_or_create_array<'a>(doc: &'a mut Document, key: &str) -> &'a mut Vec { // Ensure an array exists at this key let needs_init = match doc.get(key) { Some(Bson::Array(_)) => false, _ => true, }; if needs_init { doc.insert(key.to_string(), Bson::Array(Vec::new())); } match doc.get_mut(key).unwrap() { Bson::Array(arr) => arr, _ => unreachable!(), } } fn sort_array(arr: &mut Vec, sort_spec: &Bson) { match sort_spec { Bson::Int32(dir) => { let ascending = *dir > 0; arr.sort_by(|a, b| { let ord = partial_cmp_bson(a, b); if ascending { ord } else { ord.reverse() } }); } Bson::Document(spec) => { arr.sort_by(|a, b| { for (field, dir) in spec { let ascending = match dir { Bson::Int32(n) => *n > 0, _ => true, }; let a_val = if let Bson::Document(d) = a { d.get(field) } else { None }; let b_val = if let Bson::Document(d) = b { d.get(field) } else { None }; let ord = match (a_val, b_val) { (Some(av), Some(bv)) => partial_cmp_bson(av, bv), (Some(_), None) => std::cmp::Ordering::Greater, (None, Some(_)) => std::cmp::Ordering::Less, (None, None) => std::cmp::Ordering::Equal, }; let ord = if ascending { ord } else { ord.reverse() }; if ord != std::cmp::Ordering::Equal { return ord; } } std::cmp::Ordering::Equal }); } _ => {} } } fn slice_array(arr: &mut Vec, slice: &Bson) { let n = match slice { Bson::Int32(n) => *n as i64, Bson::Int64(n) => *n, _ => return, }; if n >= 0 { arr.truncate(n as usize); } else { let keep = (-n) as usize; if keep < arr.len() { let start = arr.len() - keep; *arr = arr[start..].to_vec(); } } } } fn partial_cmp_bson(a: &Bson, b: &Bson) -> std::cmp::Ordering { use std::cmp::Ordering; match (a, b) { (Bson::Int32(x), Bson::Int32(y)) => x.cmp(y), (Bson::Int64(x), Bson::Int64(y)) => x.cmp(y), (Bson::Double(x), Bson::Double(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal), (Bson::String(x), Bson::String(y)) => x.cmp(y), (Bson::Boolean(x), Bson::Boolean(y)) => x.cmp(y), _ => Ordering::Equal, } } #[cfg(test)] mod tests { use super::*; #[test] fn test_set() { let doc = doc! { "_id": 1, "name": "Alice" }; let update = doc! { "$set": { "name": "Bob", "age": 30 } }; let result = UpdateEngine::apply_update(&doc, &update, None).unwrap(); assert_eq!(result.get_str("name").unwrap(), "Bob"); assert_eq!(result.get_i32("age").unwrap(), 30); } #[test] fn test_inc() { let doc = doc! { "_id": 1, "count": 5 }; let update = doc! { "$inc": { "count": 3 } }; let result = UpdateEngine::apply_update(&doc, &update, None).unwrap(); assert_eq!(result.get_i32("count").unwrap(), 8); } #[test] fn test_unset() { let doc = doc! { "_id": 1, "name": "Alice", "age": 30 }; let update = doc! { "$unset": { "age": "" } }; let result = UpdateEngine::apply_update(&doc, &update, None).unwrap(); assert!(result.get("age").is_none()); } #[test] fn test_replacement() { let doc = doc! { "_id": 1, "name": "Alice", "age": 30 }; let update = doc! { "name": "Bob" }; let result = UpdateEngine::apply_update(&doc, &update, None).unwrap(); assert_eq!(result.get_i32("_id").unwrap(), 1); // preserved assert_eq!(result.get_str("name").unwrap(), "Bob"); assert!(result.get("age").is_none()); // removed } #[test] fn test_push() { let doc = doc! { "_id": 1, "tags": ["a"] }; let update = doc! { "$push": { "tags": "b" } }; let result = UpdateEngine::apply_update(&doc, &update, None).unwrap(); let tags = result.get_array("tags").unwrap(); assert_eq!(tags.len(), 2); } #[test] fn test_add_to_set() { let doc = doc! { "_id": 1, "tags": ["a", "b"] }; let update = doc! { "$addToSet": { "tags": "a" } }; let result = UpdateEngine::apply_update(&doc, &update, None).unwrap(); let tags = result.get_array("tags").unwrap(); assert_eq!(tags.len(), 2); // no duplicate } }