576 lines
20 KiB
Rust
576 lines
20 KiB
Rust
use bson::{Bson, Document, doc};
|
|
|
|
use crate::error::QueryError;
|
|
use crate::field_path::{get_nested_value, set_nested_value, remove_nested_value};
|
|
use crate::matcher::QueryMatcher;
|
|
|
|
/// Update engine — applies update operators to documents.
|
|
// Unit struct with no fields; the engine's operations are associated functions on this type.
pub struct UpdateEngine;
|
|
|
|
impl UpdateEngine {
|
|
/// Apply an update specification to a document.
|
|
/// Returns the updated document.
|
|
pub fn apply_update(
|
|
doc: &Document,
|
|
update: &Document,
|
|
_array_filters: Option<&[Document]>,
|
|
) -> Result<Document, QueryError> {
|
|
// Check if this is a replacement (no $ operators)
|
|
if !update.keys().any(|k| k.starts_with('$')) {
|
|
return Self::apply_replacement(doc, update);
|
|
}
|
|
|
|
let mut result = doc.clone();
|
|
|
|
for (op, value) in update {
|
|
let fields = match value {
|
|
Bson::Document(d) => d,
|
|
_ => continue,
|
|
};
|
|
|
|
match op.as_str() {
|
|
"$set" => Self::apply_set(&mut result, fields)?,
|
|
"$unset" => Self::apply_unset(&mut result, fields)?,
|
|
"$inc" => Self::apply_inc(&mut result, fields)?,
|
|
"$mul" => Self::apply_mul(&mut result, fields)?,
|
|
"$min" => Self::apply_min(&mut result, fields)?,
|
|
"$max" => Self::apply_max(&mut result, fields)?,
|
|
"$rename" => Self::apply_rename(&mut result, fields)?,
|
|
"$currentDate" => Self::apply_current_date(&mut result, fields)?,
|
|
"$setOnInsert" => {} // handled separately during upsert
|
|
"$push" => Self::apply_push(&mut result, fields)?,
|
|
"$pop" => Self::apply_pop(&mut result, fields)?,
|
|
"$pull" => Self::apply_pull(&mut result, fields)?,
|
|
"$pullAll" => Self::apply_pull_all(&mut result, fields)?,
|
|
"$addToSet" => Self::apply_add_to_set(&mut result, fields)?,
|
|
"$bit" => Self::apply_bit(&mut result, fields)?,
|
|
other => {
|
|
return Err(QueryError::InvalidUpdate(format!(
|
|
"Unknown update operator: {}",
|
|
other
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Apply $setOnInsert fields (used during upsert only).
|
|
pub fn apply_set_on_insert(doc: &mut Document, fields: &Document) {
|
|
for (key, value) in fields {
|
|
if key.contains('.') {
|
|
set_nested_value(doc, key, value.clone());
|
|
} else {
|
|
doc.insert(key.clone(), value.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Deep clone a BSON document.
|
|
pub fn deep_clone(doc: &Document) -> Document {
|
|
doc.clone()
|
|
}
|
|
|
|
/// Build the result of a whole-document replacement.
///
/// The result is a copy of `replacement`; when `doc` carries an `_id`, that
/// value is written into the copy (overriding any `_id` in `replacement`).
/// This function always returns `Ok`.
fn apply_replacement(doc: &Document, replacement: &Document) -> Result<Document, QueryError> {
    let mut replaced = replacement.clone();
    // Preserve _id
    if let Some(original_id) = doc.get("_id").cloned() {
        replaced.insert("_id", original_id);
    }
    Ok(replaced)
}
|
|
|
|
fn apply_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, value) in fields {
|
|
if key.contains('.') {
|
|
set_nested_value(doc, key, value.clone());
|
|
} else {
|
|
doc.insert(key.clone(), value.clone());
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_unset(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, _) in fields {
|
|
if key.contains('.') {
|
|
remove_nested_value(doc, key);
|
|
} else {
|
|
doc.remove(key);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_inc(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, inc_value) in fields {
|
|
let current = if key.contains('.') {
|
|
get_nested_value(doc, key)
|
|
} else {
|
|
doc.get(key).cloned()
|
|
};
|
|
|
|
let new_value = match (¤t, inc_value) {
|
|
(Some(Bson::Int32(a)), Bson::Int32(b)) => Bson::Int32(a + b),
|
|
(Some(Bson::Int64(a)), Bson::Int64(b)) => Bson::Int64(a + b),
|
|
(Some(Bson::Int32(a)), Bson::Int64(b)) => Bson::Int64(*a as i64 + b),
|
|
(Some(Bson::Int64(a)), Bson::Int32(b)) => Bson::Int64(a + *b as i64),
|
|
(Some(Bson::Double(a)), Bson::Double(b)) => Bson::Double(a + b),
|
|
(Some(Bson::Int32(a)), Bson::Double(b)) => Bson::Double(*a as f64 + b),
|
|
(Some(Bson::Double(a)), Bson::Int32(b)) => Bson::Double(a + *b as f64),
|
|
(Some(Bson::Int64(a)), Bson::Double(b)) => Bson::Double(*a as f64 + b),
|
|
(Some(Bson::Double(a)), Bson::Int64(b)) => Bson::Double(a + *b as f64),
|
|
(None, v) => v.clone(), // treat missing as 0
|
|
_ => {
|
|
return Err(QueryError::TypeMismatch(format!(
|
|
"Cannot apply $inc to non-numeric field: {}",
|
|
key
|
|
)));
|
|
}
|
|
};
|
|
|
|
if key.contains('.') {
|
|
set_nested_value(doc, key, new_value);
|
|
} else {
|
|
doc.insert(key.clone(), new_value);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_mul(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, mul_value) in fields {
|
|
let current = if key.contains('.') {
|
|
get_nested_value(doc, key)
|
|
} else {
|
|
doc.get(key).cloned()
|
|
};
|
|
|
|
let new_value = match (¤t, mul_value) {
|
|
(Some(Bson::Int32(a)), Bson::Int32(b)) => Bson::Int32(a * b),
|
|
(Some(Bson::Int64(a)), Bson::Int64(b)) => Bson::Int64(a * b),
|
|
(Some(Bson::Int32(a)), Bson::Int64(b)) => Bson::Int64(*a as i64 * b),
|
|
(Some(Bson::Int64(a)), Bson::Int32(b)) => Bson::Int64(a * *b as i64),
|
|
(Some(Bson::Double(a)), Bson::Double(b)) => Bson::Double(a * b),
|
|
(Some(Bson::Int32(a)), Bson::Double(b)) => Bson::Double(*a as f64 * b),
|
|
(Some(Bson::Double(a)), Bson::Int32(b)) => Bson::Double(a * *b as f64),
|
|
(None, _) => Bson::Int32(0), // missing field * anything = 0
|
|
_ => {
|
|
return Err(QueryError::TypeMismatch(format!(
|
|
"Cannot apply $mul to non-numeric field: {}",
|
|
key
|
|
)));
|
|
}
|
|
};
|
|
|
|
if key.contains('.') {
|
|
set_nested_value(doc, key, new_value);
|
|
} else {
|
|
doc.insert(key.clone(), new_value);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_min(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, min_value) in fields {
|
|
let current = if key.contains('.') {
|
|
get_nested_value(doc, key)
|
|
} else {
|
|
doc.get(key).cloned()
|
|
};
|
|
|
|
let should_update = match ¤t {
|
|
None => true,
|
|
Some(cur) => {
|
|
if let Some(ord) = QueryMatcher::bson_compare_pub(min_value, cur) {
|
|
ord == std::cmp::Ordering::Less
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
};
|
|
|
|
if should_update {
|
|
if key.contains('.') {
|
|
set_nested_value(doc, key, min_value.clone());
|
|
} else {
|
|
doc.insert(key.clone(), min_value.clone());
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_max(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, max_value) in fields {
|
|
let current = if key.contains('.') {
|
|
get_nested_value(doc, key)
|
|
} else {
|
|
doc.get(key).cloned()
|
|
};
|
|
|
|
let should_update = match ¤t {
|
|
None => true,
|
|
Some(cur) => {
|
|
if let Some(ord) = QueryMatcher::bson_compare_pub(max_value, cur) {
|
|
ord == std::cmp::Ordering::Greater
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
};
|
|
|
|
if should_update {
|
|
if key.contains('.') {
|
|
set_nested_value(doc, key, max_value.clone());
|
|
} else {
|
|
doc.insert(key.clone(), max_value.clone());
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_rename(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (old_name, new_name_bson) in fields {
|
|
let new_name = match new_name_bson {
|
|
Bson::String(s) => s.clone(),
|
|
_ => continue,
|
|
};
|
|
|
|
if let Some(value) = doc.remove(old_name) {
|
|
doc.insert(new_name, value);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_current_date(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
let now = bson::DateTime::now();
|
|
for (key, spec) in fields {
|
|
let value = match spec {
|
|
Bson::Boolean(true) => Bson::DateTime(now),
|
|
Bson::Document(d) => {
|
|
match d.get_str("$type").unwrap_or("date") {
|
|
"date" => Bson::DateTime(now),
|
|
"timestamp" => Bson::Timestamp(bson::Timestamp {
|
|
time: (now.timestamp_millis() / 1000) as u32,
|
|
increment: 0,
|
|
}),
|
|
_ => Bson::DateTime(now),
|
|
}
|
|
}
|
|
_ => continue,
|
|
};
|
|
|
|
if key.contains('.') {
|
|
set_nested_value(doc, key, value);
|
|
} else {
|
|
doc.insert(key.clone(), value);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_push(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, value) in fields {
|
|
let arr = Self::get_or_create_array(doc, key);
|
|
|
|
match value {
|
|
Bson::Document(d) if d.contains_key("$each") => {
|
|
let each = match d.get("$each") {
|
|
Some(Bson::Array(a)) => a.clone(),
|
|
_ => return Err(QueryError::InvalidUpdate("$each must be an array".into())),
|
|
};
|
|
|
|
let position = d.get("$position").and_then(|v| match v {
|
|
Bson::Int32(n) => Some(*n as usize),
|
|
Bson::Int64(n) => Some(*n as usize),
|
|
_ => None,
|
|
});
|
|
|
|
if let Some(pos) = position {
|
|
let pos = pos.min(arr.len());
|
|
for (i, item) in each.into_iter().enumerate() {
|
|
arr.insert(pos + i, item);
|
|
}
|
|
} else {
|
|
arr.extend(each);
|
|
}
|
|
|
|
// Apply $sort if present
|
|
if let Some(sort_spec) = d.get("$sort") {
|
|
Self::sort_array(arr, sort_spec);
|
|
}
|
|
|
|
// Apply $slice if present
|
|
if let Some(slice) = d.get("$slice") {
|
|
Self::slice_array(arr, slice);
|
|
}
|
|
}
|
|
_ => {
|
|
arr.push(value.clone());
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_pop(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, direction) in fields {
|
|
if let Some(Bson::Array(arr)) = doc.get_mut(key) {
|
|
if arr.is_empty() {
|
|
continue;
|
|
}
|
|
match direction {
|
|
Bson::Int32(-1) | Bson::Int64(-1) => { arr.remove(0); }
|
|
Bson::Int32(1) | Bson::Int64(1) => { arr.pop(); }
|
|
Bson::Double(f) if *f == 1.0 => { arr.pop(); }
|
|
Bson::Double(f) if *f == -1.0 => { arr.remove(0); }
|
|
_ => { arr.pop(); }
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_pull(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, condition) in fields {
|
|
if let Some(Bson::Array(arr)) = doc.get_mut(key) {
|
|
match condition {
|
|
Bson::Document(cond_doc) if QueryMatcher::has_operators_pub(cond_doc) => {
|
|
arr.retain(|elem| {
|
|
if let Bson::Document(elem_doc) = elem {
|
|
!QueryMatcher::matches(elem_doc, cond_doc)
|
|
} else {
|
|
// For primitive matching with operators
|
|
let wrapper = doc! { "v": elem.clone() };
|
|
let cond_wrapper = doc! { "v": condition.clone() };
|
|
!QueryMatcher::matches(&wrapper, &cond_wrapper)
|
|
}
|
|
});
|
|
}
|
|
_ => {
|
|
arr.retain(|elem| elem != condition);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_pull_all(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, values) in fields {
|
|
if let (Some(Bson::Array(arr)), Bson::Array(to_remove)) = (doc.get_mut(key), values) {
|
|
arr.retain(|elem| !to_remove.contains(elem));
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_add_to_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, value) in fields {
|
|
let arr = Self::get_or_create_array(doc, key);
|
|
|
|
match value {
|
|
Bson::Document(d) if d.contains_key("$each") => {
|
|
if let Some(Bson::Array(each)) = d.get("$each") {
|
|
for item in each {
|
|
if !arr.contains(item) {
|
|
arr.push(item.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
_ => {
|
|
if !arr.contains(value) {
|
|
arr.push(value.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_bit(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
|
|
for (key, ops) in fields {
|
|
let ops_doc = match ops {
|
|
Bson::Document(d) => d,
|
|
_ => continue,
|
|
};
|
|
|
|
let current = doc.get(key).cloned().unwrap_or(Bson::Int32(0));
|
|
let mut val = match ¤t {
|
|
Bson::Int32(n) => *n as i64,
|
|
Bson::Int64(n) => *n,
|
|
_ => continue,
|
|
};
|
|
|
|
for (bit_op, operand) in ops_doc {
|
|
let operand_val = match operand {
|
|
Bson::Int32(n) => *n as i64,
|
|
Bson::Int64(n) => *n,
|
|
_ => continue,
|
|
};
|
|
|
|
match bit_op.as_str() {
|
|
"and" => val &= operand_val,
|
|
"or" => val |= operand_val,
|
|
"xor" => val ^= operand_val,
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let new_value = match ¤t {
|
|
Bson::Int32(_) => Bson::Int32(val as i32),
|
|
_ => Bson::Int64(val),
|
|
};
|
|
doc.insert(key.clone(), new_value);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// --- Helpers ---
|
|
|
|
fn get_or_create_array<'a>(doc: &'a mut Document, key: &str) -> &'a mut Vec<Bson> {
|
|
// Ensure an array exists at this key
|
|
let needs_init = match doc.get(key) {
|
|
Some(Bson::Array(_)) => false,
|
|
_ => true,
|
|
};
|
|
if needs_init {
|
|
doc.insert(key.to_string(), Bson::Array(Vec::new()));
|
|
}
|
|
match doc.get_mut(key).unwrap() {
|
|
Bson::Array(arr) => arr,
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
|
|
fn sort_array(arr: &mut Vec<Bson>, sort_spec: &Bson) {
|
|
match sort_spec {
|
|
Bson::Int32(dir) => {
|
|
let ascending = *dir > 0;
|
|
arr.sort_by(|a, b| {
|
|
let ord = partial_cmp_bson(a, b);
|
|
if ascending { ord } else { ord.reverse() }
|
|
});
|
|
}
|
|
Bson::Document(spec) => {
|
|
arr.sort_by(|a, b| {
|
|
for (field, dir) in spec {
|
|
let ascending = match dir {
|
|
Bson::Int32(n) => *n > 0,
|
|
_ => true,
|
|
};
|
|
let a_val = if let Bson::Document(d) = a { d.get(field) } else { None };
|
|
let b_val = if let Bson::Document(d) = b { d.get(field) } else { None };
|
|
let ord = match (a_val, b_val) {
|
|
(Some(av), Some(bv)) => partial_cmp_bson(av, bv),
|
|
(Some(_), None) => std::cmp::Ordering::Greater,
|
|
(None, Some(_)) => std::cmp::Ordering::Less,
|
|
(None, None) => std::cmp::Ordering::Equal,
|
|
};
|
|
let ord = if ascending { ord } else { ord.reverse() };
|
|
if ord != std::cmp::Ordering::Equal {
|
|
return ord;
|
|
}
|
|
}
|
|
std::cmp::Ordering::Equal
|
|
});
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
fn slice_array(arr: &mut Vec<Bson>, slice: &Bson) {
|
|
let n = match slice {
|
|
Bson::Int32(n) => *n as i64,
|
|
Bson::Int64(n) => *n,
|
|
_ => return,
|
|
};
|
|
|
|
if n >= 0 {
|
|
arr.truncate(n as usize);
|
|
} else {
|
|
let keep = (-n) as usize;
|
|
if keep < arr.len() {
|
|
let start = arr.len() - keep;
|
|
*arr = arr[start..].to_vec();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn partial_cmp_bson(a: &Bson, b: &Bson) -> std::cmp::Ordering {
|
|
use std::cmp::Ordering;
|
|
match (a, b) {
|
|
(Bson::Int32(x), Bson::Int32(y)) => x.cmp(y),
|
|
(Bson::Int64(x), Bson::Int64(y)) => x.cmp(y),
|
|
(Bson::Double(x), Bson::Double(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
|
|
(Bson::String(x), Bson::String(y)) => x.cmp(y),
|
|
(Bson::Boolean(x), Bson::Boolean(y)) => x.cmp(y),
|
|
_ => Ordering::Equal,
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Apply `update` to `original` with no array filters, panicking if the
    /// engine reports an error.
    fn updated(original: &Document, update: &Document) -> Document {
        UpdateEngine::apply_update(original, update, None).unwrap()
    }

    /// `$set` overwrites an existing field and creates a missing one.
    #[test]
    fn test_set() {
        let before = doc! { "_id": 1, "name": "Alice" };
        let after = updated(&before, &doc! { "$set": { "name": "Bob", "age": 30 } });
        assert_eq!(after.get_str("name").unwrap(), "Bob");
        assert_eq!(after.get_i32("age").unwrap(), 30);
    }

    /// `$inc` adds the operand to an existing Int32 field.
    #[test]
    fn test_inc() {
        let before = doc! { "_id": 1, "count": 5 };
        let after = updated(&before, &doc! { "$inc": { "count": 3 } });
        assert_eq!(after.get_i32("count").unwrap(), 8);
    }

    /// `$unset` removes the named field.
    #[test]
    fn test_unset() {
        let before = doc! { "_id": 1, "name": "Alice", "age": 30 };
        let after = updated(&before, &doc! { "$unset": { "age": "" } });
        assert!(after.get("age").is_none());
    }

    /// An update without `$` operators replaces the document but keeps `_id`.
    #[test]
    fn test_replacement() {
        let before = doc! { "_id": 1, "name": "Alice", "age": 30 };
        let after = updated(&before, &doc! { "name": "Bob" });
        assert_eq!(after.get_i32("_id").unwrap(), 1); // preserved
        assert_eq!(after.get_str("name").unwrap(), "Bob");
        assert!(after.get("age").is_none()); // removed
    }

    /// `$push` appends one element to an existing array.
    #[test]
    fn test_push() {
        let before = doc! { "_id": 1, "tags": ["a"] };
        let after = updated(&before, &doc! { "$push": { "tags": "b" } });
        let tags = after.get_array("tags").unwrap();
        assert_eq!(tags.len(), 2);
    }

    /// `$addToSet` does not add a value that is already present.
    #[test]
    fn test_add_to_set() {
        let before = doc! { "_id": 1, "tags": ["a", "b"] };
        let after = updated(&before, &doc! { "$addToSet": { "tags": "a" } });
        let tags = after.get_array("tags").unwrap();
        assert_eq!(tags.len(), 2); // no duplicate
    }
}
|