Files
smartdb/rust/crates/rustdb-query/src/update.rs
T

576 lines
20 KiB
Rust

use bson::{Bson, Document, doc};
use crate::error::QueryError;
use crate::field_path::{get_nested_value, set_nested_value, remove_nested_value};
use crate::matcher::QueryMatcher;
/// Update engine — applies update operators (`$set`, `$inc`, `$push`, …) to documents.
///
/// The engine carries no state: every operation is an associated function such as
/// [`UpdateEngine::apply_update`], so the type exists only as a namespace.
#[derive(Debug, Clone, Copy, Default)]
pub struct UpdateEngine;
impl UpdateEngine {
    /// Apply an update specification to `doc` and return the updated copy.
    ///
    /// When `update` has no `$`-prefixed keys it is a replacement document: the
    /// result is `update` itself with the original `_id` carried over (as the first
    /// field). Otherwise every top-level key must be a supported operator whose
    /// argument is a document of `field: operand` pairs; operators run in the order
    /// they appear. Dotted field paths (`"a.b"`) go through the `field_path` helpers.
    ///
    /// `$setOnInsert` is accepted but ignored here — upserts apply it with
    /// [`UpdateEngine::apply_set_on_insert`]. `_array_filters` is not consulted.
    ///
    /// # Errors
    ///
    /// - [`QueryError::InvalidUpdate`] for an unknown operator (including a plain
    ///   field mixed into an operator update), a non-document operator argument,
    ///   a malformed operand (`$each` not an array, bad `$pop` direction, unknown
    ///   `$bit` operation, …), or an `$inc`/`$mul` result that overflows `Int64`.
    /// - [`QueryError::TypeMismatch`] when an operator meets a field or operand of
    ///   the wrong type (`$inc` on a string, `$push` on a non-array, …).
    pub fn apply_update(
        doc: &Document,
        update: &Document,
        _array_filters: Option<&[Document]>,
    ) -> Result<Document, QueryError> {
        // A document with no `$` operators replaces the stored document wholesale.
        if !update.keys().any(|k| k.starts_with('$')) {
            return Self::apply_replacement(doc, update);
        }

        let mut result = doc.clone();
        for (op, value) in update {
            // Resolve the operator before looking at its argument so a stray
            // non-operator key is reported as unknown whatever its value type is.
            let apply: fn(&mut Document, &Document) -> Result<(), QueryError> =
                match op.as_str() {
                    "$set" => Self::apply_set,
                    "$unset" => Self::apply_unset,
                    "$inc" => Self::apply_inc,
                    "$mul" => Self::apply_mul,
                    "$min" => Self::apply_min,
                    "$max" => Self::apply_max,
                    "$rename" => Self::apply_rename,
                    "$currentDate" => Self::apply_current_date,
                    // Only meaningful for upserts; see `apply_set_on_insert`.
                    "$setOnInsert" => Self::skip_set_on_insert,
                    "$push" => Self::apply_push,
                    "$pop" => Self::apply_pop,
                    "$pull" => Self::apply_pull,
                    "$pullAll" => Self::apply_pull_all,
                    "$addToSet" => Self::apply_add_to_set,
                    "$bit" => Self::apply_bit,
                    other => {
                        return Err(QueryError::InvalidUpdate(format!(
                            "Unknown update operator: {}",
                            other
                        )));
                    }
                };
            let fields = match value {
                Bson::Document(d) => d,
                _ => {
                    return Err(QueryError::InvalidUpdate(format!(
                        "{} requires a document argument",
                        op
                    )));
                }
            };
            apply(&mut result, fields)?;
        }
        Ok(result)
    }

    /// Apply `$setOnInsert` fields (used during upsert only).
    ///
    /// Each `field: value` pair is written exactly like `$set` would write it,
    /// including dotted paths.
    pub fn apply_set_on_insert(doc: &mut Document, fields: &Document) {
        for (key, value) in fields {
            Self::set_field(doc, key, value.clone());
        }
    }

    /// Deep clone a BSON document.
    ///
    /// `Document` owns all of its nested values, so `clone` already copies the
    /// whole tree.
    pub fn deep_clone(doc: &Document) -> Document {
        doc.clone()
    }

    /// Build the document that replaces `doc` in a replacement-style update.
    ///
    /// The original `_id` (if present) wins over any `_id` in `replacement` and is
    /// placed first; the remaining fields keep `replacement`'s order.
    fn apply_replacement(doc: &Document, replacement: &Document) -> Result<Document, QueryError> {
        let id = match doc.get("_id") {
            Some(id) => id,
            None => return Ok(replacement.clone()),
        };
        let mut result = Document::new();
        result.insert("_id", id.clone());
        for (key, value) in replacement {
            if key != "_id" {
                result.insert(key.clone(), value.clone());
            }
        }
        Ok(result)
    }

    /// `$setOnInsert` handler for ordinary updates: intentionally does nothing.
    fn skip_set_on_insert(_doc: &mut Document, _fields: &Document) -> Result<(), QueryError> {
        Ok(())
    }

    /// `$set`: write every `field: value` pair, creating fields as needed.
    fn apply_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, value) in fields {
            Self::set_field(doc, key, value.clone());
        }
        Ok(())
    }

    /// `$unset`: remove every listed field; the operand values are ignored.
    fn apply_unset(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for key in fields.keys() {
            Self::remove_field(doc, key);
        }
        Ok(())
    }

    /// `$inc`: add a numeric operand to each field.
    ///
    /// A missing field is treated as 0, so it is set to the operand. Int32
    /// overflow promotes the result to Int64; Int64 overflow is an error.
    fn apply_inc(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, inc_value) in fields {
            if Self::as_f64(inc_value).is_none() {
                return Err(QueryError::TypeMismatch(format!(
                    "Cannot increment field {} by a non-numeric value",
                    key
                )));
            }
            let new_value = match Self::get_field(doc, key) {
                None => inc_value.clone(),
                Some(current) => Self::arithmetic(
                    "$inc",
                    key,
                    &current,
                    inc_value,
                    i32::checked_add,
                    i64::checked_add,
                    |a, b| a + b,
                )?,
            };
            Self::set_field(doc, key, new_value);
        }
        Ok(())
    }

    /// `$mul`: multiply each field by a numeric operand.
    ///
    /// A missing field is set to zero of the operand's numeric type. Int32
    /// overflow promotes the result to Int64; Int64 overflow is an error.
    fn apply_mul(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, mul_value) in fields {
            let zero = match mul_value {
                Bson::Int32(_) => Bson::Int32(0),
                Bson::Int64(_) => Bson::Int64(0),
                Bson::Double(_) => Bson::Double(0.0),
                _ => {
                    return Err(QueryError::TypeMismatch(format!(
                        "Cannot multiply field {} by a non-numeric value",
                        key
                    )));
                }
            };
            let new_value = match Self::get_field(doc, key) {
                None => zero,
                Some(current) => Self::arithmetic(
                    "$mul",
                    key,
                    &current,
                    mul_value,
                    i32::checked_mul,
                    i64::checked_mul,
                    |a, b| a * b,
                )?,
            };
            Self::set_field(doc, key, new_value);
        }
        Ok(())
    }

    /// Combine two numeric BSON values for `$inc`/`$mul`.
    ///
    /// Int32 op Int32 stays Int32 unless it overflows, in which case it is
    /// recomputed as Int64 (a 32x32-bit sum or product always fits in i64).
    /// Any other integer mix is computed in i64 and rejected on overflow. Any
    /// combination involving a Double is computed in f64.
    fn arithmetic(
        op: &str,
        key: &str,
        current: &Bson,
        operand: &Bson,
        int32_op: fn(i32, i32) -> Option<i32>,
        int64_op: fn(i64, i64) -> Option<i64>,
        float_op: fn(f64, f64) -> f64,
    ) -> Result<Bson, QueryError> {
        let overflow = || {
            QueryError::InvalidUpdate(format!(
                "Integer overflow applying {} to field: {}",
                op, key
            ))
        };
        if let (Bson::Int32(a), Bson::Int32(b)) = (current, operand) {
            if let Some(v) = int32_op(*a, *b) {
                return Ok(Bson::Int32(v));
            }
        }
        if let (Some(a), Some(b)) = (Self::as_i64(current), Self::as_i64(operand)) {
            return int64_op(a, b).map(Bson::Int64).ok_or_else(overflow);
        }
        match (Self::as_f64(current), Self::as_f64(operand)) {
            (Some(a), Some(b)) => Ok(Bson::Double(float_op(a, b))),
            _ => Err(QueryError::TypeMismatch(format!(
                "Cannot apply {} to non-numeric field: {}",
                op, key
            ))),
        }
    }

    /// `$min`: replace a field when the operand is smaller than its current value.
    fn apply_min(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        Self::apply_bound(doc, fields, std::cmp::Ordering::Less)
    }

    /// `$max`: replace a field when the operand is larger than its current value.
    fn apply_max(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        Self::apply_bound(doc, fields, std::cmp::Ordering::Greater)
    }

    /// Shared `$min`/`$max` logic.
    ///
    /// A field is replaced when `QueryMatcher::bson_compare_pub(candidate, current)`
    /// yields `replace_when`. Missing fields are always set; values the matcher
    /// cannot compare are left untouched.
    fn apply_bound(
        doc: &mut Document,
        fields: &Document,
        replace_when: std::cmp::Ordering,
    ) -> Result<(), QueryError> {
        for (key, candidate) in fields {
            let replace = match Self::get_field(doc, key) {
                None => true,
                Some(current) => {
                    QueryMatcher::bson_compare_pub(candidate, &current) == Some(replace_when)
                }
            };
            if replace {
                Self::set_field(doc, key, candidate.clone());
            }
        }
        Ok(())
    }

    /// `$rename`: move each field to the name given as its string operand.
    ///
    /// Both source and target may be dotted paths; a missing source is a no-op.
    fn apply_rename(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (old_name, target) in fields {
            let new_name = match target {
                Bson::String(s) => s.as_str(),
                _ => {
                    return Err(QueryError::InvalidUpdate(format!(
                        "$rename target must be a string for field: {}",
                        old_name
                    )));
                }
            };
            if let Some(value) = Self::get_field(doc, old_name) {
                Self::remove_field(doc, old_name);
                Self::set_field(doc, new_name, value);
            }
        }
        Ok(())
    }

    /// `$currentDate`: set fields to the current time.
    ///
    /// `true` or `{ $type: "date" }` stores a DateTime; `{ $type: "timestamp" }`
    /// stores a Timestamp (seconds, increment 0). Any other `$type` string also
    /// yields a DateTime, and non-boolean/non-document specs (or `false`) are
    /// skipped.
    fn apply_current_date(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        // One clock reading so every field in this operator gets the same instant.
        let now = bson::DateTime::now();
        for (key, spec) in fields {
            let value = match spec {
                Bson::Boolean(true) => Bson::DateTime(now),
                Bson::Document(d) => match d.get_str("$type").unwrap_or("date") {
                    "timestamp" => Bson::Timestamp(bson::Timestamp {
                        time: (now.timestamp_millis() / 1000) as u32,
                        increment: 0,
                    }),
                    _ => Bson::DateTime(now),
                },
                _ => continue,
            };
            Self::set_field(doc, key, value);
        }
        Ok(())
    }

    /// `$push`: append a value, or several values via `$each` with the optional
    /// `$position`, `$sort` and `$slice` modifiers applied in that order.
    ///
    /// A missing field becomes a new array. A negative `$position` counts back from
    /// the end; a non-integer `$position` is ignored (append).
    fn apply_push(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, value) in fields {
            let mut arr = Self::read_array(doc, key, "$push")?.unwrap_or_default();
            match value {
                Bson::Document(spec) if spec.contains_key("$each") => {
                    let each = match spec.get("$each") {
                        Some(Bson::Array(items)) => items.clone(),
                        _ => return Err(QueryError::InvalidUpdate("$each must be an array".into())),
                    };
                    let position = match spec.get("$position") {
                        Some(Bson::Int32(n)) => Some(i64::from(*n)),
                        Some(Bson::Int64(n)) => Some(*n),
                        _ => None,
                    };
                    match position {
                        Some(pos) => {
                            let len = i64::try_from(arr.len()).unwrap_or(i64::MAX);
                            let target = if pos < 0 {
                                len.saturating_add(pos).max(0)
                            } else {
                                pos.min(len)
                            };
                            // `target` lies in 0..=len, so the conversion cannot fail.
                            let index = usize::try_from(target).unwrap_or(arr.len());
                            arr.splice(index..index, each);
                        }
                        None => arr.extend(each),
                    }
                    if let Some(sort_spec) = spec.get("$sort") {
                        Self::sort_array(&mut arr, sort_spec);
                    }
                    if let Some(slice) = spec.get("$slice") {
                        Self::slice_array(&mut arr, slice);
                    }
                }
                _ => arr.push(value.clone()),
            }
            Self::set_field(doc, key, Bson::Array(arr));
        }
        Ok(())
    }

    /// `$pop`: remove the first (`-1`) or last (`1`) element of each array.
    ///
    /// Missing fields and empty arrays are left untouched.
    fn apply_pop(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, direction) in fields {
            let from_front = match direction {
                Bson::Int32(-1) | Bson::Int64(-1) => true,
                Bson::Int32(1) | Bson::Int64(1) => false,
                Bson::Double(f) if *f == -1.0 => true,
                Bson::Double(f) if *f == 1.0 => false,
                _ => {
                    return Err(QueryError::InvalidUpdate(format!(
                        "$pop expects 1 or -1 for field: {}",
                        key
                    )));
                }
            };
            if let Some(mut arr) = Self::read_array(doc, key, "$pop")? {
                if arr.is_empty() {
                    continue;
                }
                if from_front {
                    arr.remove(0);
                } else {
                    arr.pop();
                }
                Self::set_field(doc, key, Bson::Array(arr));
            }
        }
        Ok(())
    }

    /// `$pull`: remove array elements matching a condition.
    ///
    /// A document condition containing operators (per
    /// `QueryMatcher::has_operators_pub`) is evaluated with `QueryMatcher::matches`
    /// against document elements, and against `{ v: elem }` for scalar elements.
    /// Any other condition removes elements equal to it.
    fn apply_pull(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, condition) in fields {
            let mut arr = match Self::read_array(doc, key, "$pull")? {
                Some(arr) => arr,
                None => continue,
            };
            match condition {
                Bson::Document(cond_doc) if QueryMatcher::has_operators_pub(cond_doc) => {
                    // Built once; scalar elements are matched as `{ v: elem }`.
                    let cond_wrapper = doc! { "v": condition.clone() };
                    arr.retain(|elem| match elem {
                        Bson::Document(elem_doc) => !QueryMatcher::matches(elem_doc, cond_doc),
                        other => {
                            let wrapper = doc! { "v": other.clone() };
                            !QueryMatcher::matches(&wrapper, &cond_wrapper)
                        }
                    });
                }
                _ => arr.retain(|elem| elem != condition),
            }
            Self::set_field(doc, key, Bson::Array(arr));
        }
        Ok(())
    }

    /// `$pullAll`: remove every element equal to any value in the operand array.
    fn apply_pull_all(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, values) in fields {
            let to_remove = match values {
                Bson::Array(items) => items,
                _ => {
                    return Err(QueryError::InvalidUpdate(format!(
                        "$pullAll requires an array argument for field: {}",
                        key
                    )));
                }
            };
            if let Some(mut arr) = Self::read_array(doc, key, "$pullAll")? {
                arr.retain(|elem| !to_remove.contains(elem));
                Self::set_field(doc, key, Bson::Array(arr));
            }
        }
        Ok(())
    }

    /// `$addToSet`: append a value (or each `$each` value) unless already present.
    ///
    /// A missing field becomes a new array.
    fn apply_add_to_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, value) in fields {
            let mut arr = Self::read_array(doc, key, "$addToSet")?.unwrap_or_default();
            let candidates: &[Bson] = match value {
                Bson::Document(spec) if spec.contains_key("$each") => match spec.get("$each") {
                    Some(Bson::Array(items)) => items,
                    _ => return Err(QueryError::InvalidUpdate("$each must be an array".into())),
                },
                _ => std::slice::from_ref(value),
            };
            for item in candidates {
                if !arr.contains(item) {
                    arr.push(item.clone());
                }
            }
            Self::set_field(doc, key, Bson::Array(arr));
        }
        Ok(())
    }

    /// `$bit`: apply `and` / `or` / `xor` integer operations to each field.
    ///
    /// A missing field starts as `Int32(0)`. The result stays Int32 only when the
    /// field and every operand are Int32; otherwise it is stored as Int64.
    fn apply_bit(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
        for (key, ops) in fields {
            let ops_doc = match ops {
                Bson::Document(d) => d,
                _ => {
                    return Err(QueryError::InvalidUpdate(format!(
                        "$bit requires a document of operations for field: {}",
                        key
                    )));
                }
            };
            let current = Self::get_field(doc, key).unwrap_or(Bson::Int32(0));
            let (mut val, mut wide) = match current {
                Bson::Int32(n) => (i64::from(n), false),
                Bson::Int64(n) => (n, true),
                _ => {
                    return Err(QueryError::TypeMismatch(format!(
                        "Cannot apply $bit to non-integer field: {}",
                        key
                    )));
                }
            };
            for (bit_op, operand) in ops_doc {
                let operand_val = match operand {
                    Bson::Int32(n) => i64::from(*n),
                    Bson::Int64(n) => {
                        wide = true;
                        *n
                    }
                    _ => {
                        return Err(QueryError::InvalidUpdate(format!(
                            "$bit operand must be an integer for field: {}",
                            key
                        )));
                    }
                };
                match bit_op.as_str() {
                    "and" => val &= operand_val,
                    "or" => val |= operand_val,
                    "xor" => val ^= operand_val,
                    other => {
                        return Err(QueryError::InvalidUpdate(format!(
                            "Unknown $bit operation: {}",
                            other
                        )));
                    }
                }
            }
            // When nothing was 64-bit, every input was a sign-extended i32, so the
            // bitwise result is too and narrowing back is exact.
            let new_value = if wide { Bson::Int64(val) } else { Bson::Int32(val as i32) };
            Self::set_field(doc, key, new_value);
        }
        Ok(())
    }

    // --- Helpers ---

    /// Read a (possibly dotted) field, returning an owned copy.
    fn get_field(doc: &Document, key: &str) -> Option<Bson> {
        if key.contains('.') {
            get_nested_value(doc, key)
        } else {
            doc.get(key).cloned()
        }
    }

    /// Write a (possibly dotted) field.
    fn set_field(doc: &mut Document, key: &str, value: Bson) {
        if key.contains('.') {
            set_nested_value(doc, key, value);
        } else {
            doc.insert(key, value);
        }
    }

    /// Remove a (possibly dotted) field.
    fn remove_field(doc: &mut Document, key: &str) {
        if key.contains('.') {
            remove_nested_value(doc, key);
        } else {
            doc.remove(key);
        }
    }

    /// Fetch a copy of the array at `key` for array operator `op`.
    ///
    /// Returns `Ok(None)` when the field is missing and a `TypeMismatch` when it
    /// holds anything other than an array, so existing data is never overwritten.
    fn read_array(doc: &Document, key: &str, op: &str) -> Result<Option<Vec<Bson>>, QueryError> {
        match Self::get_field(doc, key) {
            None => Ok(None),
            Some(Bson::Array(arr)) => Ok(Some(arr)),
            Some(_) => Err(QueryError::TypeMismatch(format!(
                "Cannot apply {} to non-array field: {}",
                op, key
            ))),
        }
    }

    /// Integer view of a BSON value (`Int32`/`Int64` only).
    fn as_i64(value: &Bson) -> Option<i64> {
        match value {
            Bson::Int32(n) => Some(i64::from(*n)),
            Bson::Int64(n) => Some(*n),
            _ => None,
        }
    }

    /// Floating-point view of a numeric BSON value (`Int32`/`Int64`/`Double`).
    fn as_f64(value: &Bson) -> Option<f64> {
        match value {
            Bson::Int32(n) => Some(f64::from(*n)),
            Bson::Int64(n) => Some(*n as f64),
            Bson::Double(f) => Some(*f),
            _ => None,
        }
    }

    /// Sort `arr` according to a `$push` `$sort` specification.
    ///
    /// A numeric spec sorts the elements themselves (positive = ascending, any
    /// other number = descending). A document spec sorts by the listed fields in
    /// order, with a field missing from an element sorting before a present one;
    /// non-numeric directions inside a document spec mean ascending. Any other
    /// spec type leaves `arr` untouched.
    fn sort_array(arr: &mut [Bson], sort_spec: &Bson) {
        use std::cmp::Ordering;

        match sort_spec {
            Bson::Document(spec) => arr.sort_by(|a, b| {
                for (field, dir) in spec {
                    let a_val = match a {
                        Bson::Document(d) => d.get(field),
                        _ => None,
                    };
                    let b_val = match b {
                        Bson::Document(d) => d.get(field),
                        _ => None,
                    };
                    let ord = match (a_val, b_val) {
                        (Some(av), Some(bv)) => partial_cmp_bson(av, bv),
                        (Some(_), None) => Ordering::Greater,
                        (None, Some(_)) => Ordering::Less,
                        (None, None) => Ordering::Equal,
                    };
                    let ascending = Self::sort_direction(dir).unwrap_or(true);
                    let ord = if ascending { ord } else { ord.reverse() };
                    if ord != Ordering::Equal {
                        return ord;
                    }
                }
                Ordering::Equal
            }),
            other => {
                if let Some(ascending) = Self::sort_direction(other) {
                    arr.sort_by(|a, b| {
                        let ord = partial_cmp_bson(a, b);
                        if ascending { ord } else { ord.reverse() }
                    });
                }
            }
        }
    }

    /// Interpret a numeric sort direction: `Some(true)` when positive (ascending),
    /// `Some(false)` for other numbers, `None` for non-numeric values.
    fn sort_direction(dir: &Bson) -> Option<bool> {
        match dir {
            Bson::Int32(n) => Some(*n > 0),
            Bson::Int64(n) => Some(*n > 0),
            Bson::Double(f) => Some(*f > 0.0),
            _ => None,
        }
    }

    /// Apply a `$push` `$slice`: keep the first `n` elements for `n >= 0`, or the
    /// last `|n|` elements for negative `n`. Non-integer specs are ignored.
    fn slice_array(arr: &mut Vec<Bson>, slice: &Bson) {
        let n = match slice {
            Bson::Int32(n) => i64::from(*n),
            Bson::Int64(n) => *n,
            _ => return,
        };
        if n >= 0 {
            arr.truncate(usize::try_from(n).unwrap_or(usize::MAX));
        } else {
            // `unsigned_abs` avoids the overflow `-n` would hit for i64::MIN.
            let keep = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
            let excess = arr.len().saturating_sub(keep);
            arr.drain(..excess);
        }
    }
}
fn partial_cmp_bson(a: &Bson, b: &Bson) -> std::cmp::Ordering {
use std::cmp::Ordering;
match (a, b) {
(Bson::Int32(x), Bson::Int32(y)) => x.cmp(y),
(Bson::Int64(x), Bson::Int64(y)) => x.cmp(y),
(Bson::Double(x), Bson::Double(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
(Bson::String(x), Bson::String(y)) => x.cmp(y),
(Bson::Boolean(x), Bson::Boolean(y)) => x.cmp(y),
_ => Ordering::Equal,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `update` against `original` with no array filters, panicking on error.
    fn run(original: &Document, update: &Document) -> Document {
        UpdateEngine::apply_update(original, update, None).unwrap()
    }

    #[test]
    fn test_set() {
        let updated = run(
            &doc! { "_id": 1, "name": "Alice" },
            &doc! { "$set": { "name": "Bob", "age": 30 } },
        );
        assert_eq!(updated.get_str("name").unwrap(), "Bob");
        assert_eq!(updated.get_i32("age").unwrap(), 30);
    }

    #[test]
    fn test_inc() {
        let updated = run(&doc! { "_id": 1, "count": 5 }, &doc! { "$inc": { "count": 3 } });
        assert_eq!(updated.get_i32("count").unwrap(), 8);
    }

    #[test]
    fn test_unset() {
        let updated = run(
            &doc! { "_id": 1, "name": "Alice", "age": 30 },
            &doc! { "$unset": { "age": "" } },
        );
        assert!(updated.get("age").is_none());
    }

    #[test]
    fn test_replacement() {
        let updated = run(
            &doc! { "_id": 1, "name": "Alice", "age": 30 },
            &doc! { "name": "Bob" },
        );
        // `_id` survives the replacement; fields absent from the replacement do not.
        assert_eq!(updated.get_i32("_id").unwrap(), 1);
        assert_eq!(updated.get_str("name").unwrap(), "Bob");
        assert!(updated.get("age").is_none());
    }

    #[test]
    fn test_push() {
        let updated = run(&doc! { "_id": 1, "tags": ["a"] }, &doc! { "$push": { "tags": "b" } });
        assert_eq!(updated.get_array("tags").unwrap().len(), 2);
    }

    #[test]
    fn test_add_to_set() {
        let updated = run(
            &doc! { "_id": 1, "tags": ["a", "b"] },
            &doc! { "$addToSet": { "tags": "a" } },
        );
        // "a" is already present, so no duplicate is appended.
        assert_eq!(updated.get_array("tags").unwrap().len(), 2);
    }
}