BREAKING CHANGE(core): replace the TypeScript database engine with a Rust-backed embedded server and bridge

This commit is contained in:
2026-03-26 19:48:27 +00:00
parent 8ec2046908
commit e23a951dbe
106 changed files with 11567 additions and 10678 deletions

View File

@@ -0,0 +1,15 @@
[package]
name = "rustdb-query"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "MongoDB-compatible query matching, update operators, aggregation, sort, and projection engine"
[dependencies]
bson = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
regex = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }

View File

@@ -0,0 +1,614 @@
use bson::{Bson, Document};
use std::collections::HashMap;
use crate::error::QueryError;
use crate::matcher::QueryMatcher;
use crate::sort::sort_documents;
use crate::projection::apply_projection;
use crate::field_path::get_nested_value;
/// Aggregation pipeline engine.
///
/// Stateless: all entry points are associated functions operating on owned
/// `Vec<Document>` snapshots of a collection.
pub struct AggregationEngine;
/// Trait for resolving cross-collection data (for $lookup, $graphLookup, etc.).
///
/// Implementations return a full snapshot of the named collection in the
/// named database, or a `QueryError` if it cannot be read.
pub trait CollectionResolver {
    fn resolve(&self, db: &str, coll: &str) -> Result<Vec<Document>, QueryError>;
}
impl AggregationEngine {
    /// Execute an aggregation pipeline on a snapshot of documents.
    ///
    /// Each pipeline entry must be a single-key document: the key is the
    /// stage name, the value is the stage spec. Stages run in order, each
    /// consuming the previous stage's output. `resolver` is needed only by
    /// cross-collection stages ($lookup, $unionWith, and $facet pipelines
    /// containing them); `db` is the database name handed to the resolver.
    ///
    /// # Errors
    /// Returns `QueryError::AggregationError` for empty or malformed stage
    /// specs and for unsupported stage names.
    pub fn aggregate(
        docs: Vec<Document>,
        pipeline: &[Document],
        resolver: Option<&dyn CollectionResolver>,
        db: &str,
    ) -> Result<Vec<Document>, QueryError> {
        let mut current = docs;
        for stage in pipeline {
            let (stage_name, stage_spec) = stage
                .iter()
                .next()
                .ok_or_else(|| QueryError::AggregationError("Empty pipeline stage".into()))?;
            current = match stage_name.as_str() {
                "$match" => Self::stage_match(current, stage_spec)?,
                "$project" => Self::stage_project(current, stage_spec)?,
                "$sort" => Self::stage_sort(current, stage_spec)?,
                "$limit" => Self::stage_limit(current, stage_spec)?,
                "$skip" => Self::stage_skip(current, stage_spec)?,
                "$group" => Self::stage_group(current, stage_spec)?,
                "$unwind" => Self::stage_unwind(current, stage_spec)?,
                "$count" => Self::stage_count(current, stage_spec)?,
                "$addFields" | "$set" => Self::stage_add_fields(current, stage_spec)?,
                "$replaceRoot" | "$replaceWith" => Self::stage_replace_root(current, stage_spec)?,
                "$lookup" => Self::stage_lookup(current, stage_spec, resolver, db)?,
                "$facet" => Self::stage_facet(current, stage_spec, resolver, db)?,
                "$unionWith" => Self::stage_union_with(current, stage_spec, resolver, db)?,
                other => {
                    return Err(QueryError::AggregationError(format!(
                        "Unsupported aggregation stage: {}",
                        other
                    )));
                }
            };
        }
        Ok(current)
    }

    /// $match: keep only documents satisfying the filter.
    fn stage_match(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let filter = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$match requires a document".into())),
        };
        Ok(QueryMatcher::filter(&docs, filter))
    }

    /// $project: apply an inclusion/exclusion projection to every document.
    fn stage_project(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let projection = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$project requires a document".into())),
        };
        Ok(docs.into_iter().map(|doc| apply_projection(&doc, projection)).collect())
    }

    /// $sort: order documents by the given sort spec (in place).
    fn stage_sort(mut docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let sort_spec = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$sort requires a document".into())),
        };
        sort_documents(&mut docs, sort_spec);
        Ok(docs)
    }

    /// $limit: keep at most the first n documents.
    fn stage_limit(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let n = bson_to_usize(spec)
            .ok_or_else(|| QueryError::AggregationError("$limit requires a number".into()))?;
        Ok(docs.into_iter().take(n).collect())
    }

    /// $skip: drop the first n documents.
    fn stage_skip(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let n = bson_to_usize(spec)
            .ok_or_else(|| QueryError::AggregationError("$skip requires a number".into()))?;
        Ok(docs.into_iter().skip(n).collect())
    }

    /// $group: bucket documents by the `_id` expression and apply the
    /// accumulators to each bucket.
    ///
    /// Bucket keys are the `Debug` rendering of the resolved `_id` value, so
    /// output order across groups is unspecified (HashMap iteration order).
    fn stage_group(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let group_spec = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$group requires a document".into())),
        };
        // A missing `_id` groups everything into a single Null bucket.
        let id_expr = group_spec.get("_id").cloned().unwrap_or(Bson::Null);
        // Group documents by the resolved _id value.
        let mut groups: HashMap<String, (Bson, Vec<Document>)> = HashMap::new();
        for doc in &docs {
            let group_key = resolve_expression(&id_expr, doc);
            let key_str = format!("{:?}", group_key);
            groups
                .entry(key_str)
                .or_insert_with(|| (group_key.clone(), Vec::new()))
                .1
                .push(doc.clone());
        }
        let mut result = Vec::new();
        for (_key_str, (group_id, group_docs)) in groups {
            let mut output = bson::doc! { "_id": group_id };
            for (field, accumulator) in group_spec {
                if field == "_id" {
                    continue;
                }
                let acc_doc = match accumulator {
                    Bson::Document(d) => d,
                    _ => continue,
                };
                // Fixed: an empty accumulator document (`{"total": {}}`) used
                // to panic via `.next().unwrap()`; report it as an error.
                let (acc_op, acc_expr) = acc_doc.iter().next().ok_or_else(|| {
                    QueryError::AggregationError(format!(
                        "The field '{}' must specify one accumulator",
                        field
                    ))
                })?;
                let value = match acc_op.as_str() {
                    "$sum" => accumulate_sum(&group_docs, acc_expr),
                    "$avg" => accumulate_avg(&group_docs, acc_expr),
                    "$min" => accumulate_min(&group_docs, acc_expr),
                    "$max" => accumulate_max(&group_docs, acc_expr),
                    "$first" => accumulate_first(&group_docs, acc_expr),
                    "$last" => accumulate_last(&group_docs, acc_expr),
                    "$push" => accumulate_push(&group_docs, acc_expr),
                    "$addToSet" => accumulate_add_to_set(&group_docs, acc_expr),
                    "$count" => Bson::Int64(group_docs.len() as i64),
                    // Unknown accumulators resolve to Null rather than error.
                    _ => Bson::Null,
                };
                output.insert(field.clone(), value);
            }
            result.push(output);
        }
        Ok(result)
    }

    /// $unwind: emit one output document per element of the named array
    /// field. Accepts either "$path" or { path, preserveNullAndEmptyArrays }.
    fn stage_unwind(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let (path, preserve_null) = match spec {
            Bson::String(s) => (s.trim_start_matches('$').to_string(), false),
            Bson::Document(d) => {
                let path = d.get_str("path")
                    .map(|s| s.trim_start_matches('$').to_string())
                    .map_err(|_| QueryError::AggregationError("$unwind requires 'path'".into()))?;
                let preserve = d.get_bool("preserveNullAndEmptyArrays").unwrap_or(false);
                (path, preserve)
            }
            _ => return Err(QueryError::AggregationError("$unwind requires a string or document".into())),
        };
        let mut result = Vec::new();
        for doc in docs {
            match doc.get(&path).cloned() {
                Some(Bson::Array(arr)) => {
                    if arr.is_empty() {
                        // Empty array: emit only when preserving, with the
                        // unwound field removed.
                        if preserve_null {
                            let mut new_doc = doc.clone();
                            new_doc.remove(&path);
                            result.push(new_doc);
                        }
                    } else {
                        for elem in arr {
                            let mut new_doc = doc.clone();
                            new_doc.insert(path.clone(), elem);
                            result.push(new_doc);
                        }
                    }
                }
                Some(Bson::Null) | None => {
                    if preserve_null {
                        result.push(doc);
                    }
                }
                // Non-array value: pass the document through unchanged
                // (re-inserting the same value was a no-op).
                Some(_) => result.push(doc),
            }
        }
        Ok(result)
    }

    /// $count: replace the stream with a single { <field>: n } document.
    fn stage_count(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let field = match spec {
            Bson::String(s) => s.clone(),
            _ => return Err(QueryError::AggregationError("$count requires a string".into())),
        };
        Ok(vec![bson::doc! { field: docs.len() as i64 }])
    }

    /// $addFields / $set: evaluate each expression and insert (or overwrite)
    /// the named top-level field on every document.
    fn stage_add_fields(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let fields = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$addFields requires a document".into())),
        };
        Ok(docs
            .into_iter()
            .map(|mut doc| {
                for (key, expr) in fields {
                    let value = resolve_expression(expr, &doc);
                    doc.insert(key.clone(), value);
                }
                doc
            })
            .collect())
    }

    /// $replaceRoot / $replaceWith: substitute each document with the result
    /// of the root expression. Documents whose expression does not resolve to
    /// a document are dropped.
    fn stage_replace_root(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
        let new_root_expr = match spec {
            // $replaceRoot uses { newRoot: <expr> }; the fallback to the spec
            // itself serves the $replaceWith literal-document form.
            Bson::Document(d) => d.get("newRoot").cloned().unwrap_or(Bson::Document(d.clone())),
            Bson::String(s) => Bson::String(s.clone()),
            _ => return Err(QueryError::AggregationError("$replaceRoot requires a document".into())),
        };
        let mut result = Vec::new();
        for doc in docs {
            let new_root = resolve_expression(&new_root_expr, &doc);
            if let Bson::Document(d) = new_root {
                result.push(d);
            }
        }
        Ok(result)
    }

    /// $lookup: left outer equality join against another collection; the
    /// matching foreign documents are attached as an array under `as`.
    fn stage_lookup(
        docs: Vec<Document>,
        spec: &Bson,
        resolver: Option<&dyn CollectionResolver>,
        db: &str,
    ) -> Result<Vec<Document>, QueryError> {
        let lookup = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$lookup requires a document".into())),
        };
        let from = lookup.get_str("from")
            .map_err(|_| QueryError::AggregationError("$lookup requires 'from'".into()))?;
        let local_field = lookup.get_str("localField")
            .map_err(|_| QueryError::AggregationError("$lookup requires 'localField'".into()))?;
        let foreign_field = lookup.get_str("foreignField")
            .map_err(|_| QueryError::AggregationError("$lookup requires 'foreignField'".into()))?;
        let as_field = lookup.get_str("as")
            .map_err(|_| QueryError::AggregationError("$lookup requires 'as'".into()))?;
        let resolver = resolver
            .ok_or_else(|| QueryError::AggregationError("$lookup requires a collection resolver".into()))?;
        let foreign_docs = resolver.resolve(db, from)?;
        Ok(docs
            .into_iter()
            .map(|mut doc| {
                let local_val = get_nested_value(&doc, local_field);
                let matches: Vec<Bson> = foreign_docs
                    .iter()
                    .filter(|fd| {
                        let foreign_val = get_nested_value(fd, foreign_field);
                        match (&local_val, &foreign_val) {
                            // Numeric types compare cross-type (i32 == i64 etc.).
                            (Some(a), Some(b)) => bson_loose_eq(a, b),
                            _ => false,
                        }
                    })
                    .map(|fd| Bson::Document(fd.clone()))
                    .collect();
                doc.insert(as_field.to_string(), Bson::Array(matches));
                doc
            })
            .collect())
    }

    /// $facet: run several sub-pipelines over the same input and emit one
    /// document holding each facet's results as an array field.
    fn stage_facet(
        docs: Vec<Document>,
        spec: &Bson,
        resolver: Option<&dyn CollectionResolver>,
        db: &str,
    ) -> Result<Vec<Document>, QueryError> {
        let facets = match spec {
            Bson::Document(d) => d,
            _ => return Err(QueryError::AggregationError("$facet requires a document".into())),
        };
        let mut result = Document::new();
        for (facet_name, pipeline_bson) in facets {
            let pipeline = match pipeline_bson {
                Bson::Array(arr) => {
                    // Non-document entries in the pipeline array are ignored.
                    let mut stages = Vec::new();
                    for stage in arr {
                        if let Bson::Document(d) = stage {
                            stages.push(d.clone());
                        }
                    }
                    stages
                }
                _ => continue,
            };
            let facet_result = Self::aggregate(docs.clone(), &pipeline, resolver, db)?;
            result.insert(
                facet_name.clone(),
                Bson::Array(facet_result.into_iter().map(Bson::Document).collect()),
            );
        }
        Ok(vec![result])
    }

    /// $unionWith: append another collection's documents (optionally run
    /// through a sub-pipeline first) to the current stream.
    fn stage_union_with(
        mut docs: Vec<Document>,
        spec: &Bson,
        resolver: Option<&dyn CollectionResolver>,
        db: &str,
    ) -> Result<Vec<Document>, QueryError> {
        let (coll, pipeline) = match spec {
            Bson::String(s) => (s.as_str(), None),
            Bson::Document(d) => {
                let coll = d.get_str("coll")
                    .map_err(|_| QueryError::AggregationError("$unionWith requires 'coll'".into()))?;
                let pipeline = d.get_array("pipeline").ok().map(|arr| {
                    arr.iter()
                        .filter_map(|s| {
                            if let Bson::Document(d) = s { Some(d.clone()) } else { None }
                        })
                        .collect::<Vec<Document>>()
                });
                (coll, pipeline)
            }
            _ => return Err(QueryError::AggregationError("$unionWith requires a string or document".into())),
        };
        let resolver = resolver
            .ok_or_else(|| QueryError::AggregationError("$unionWith requires a collection resolver".into()))?;
        let mut other_docs = resolver.resolve(db, coll)?;
        if let Some(p) = pipeline {
            other_docs = Self::aggregate(other_docs, &p, Some(resolver), db)?;
        }
        docs.extend(other_docs);
        Ok(docs)
    }
}
// --- Helper functions ---
/// Resolve a (deliberately limited) aggregation expression against a document.
/// Only "$field" path references are evaluated — an unresolved path yields
/// Null; every other value is treated as a literal and returned unchanged.
fn resolve_expression(expr: &Bson, doc: &Document) -> Bson {
    if let Bson::String(s) = expr {
        if let Some(field) = s.strip_prefix('$') {
            return get_nested_value(doc, field).unwrap_or(Bson::Null);
        }
    }
    expr.clone()
}
/// Convert a numeric BSON value to `usize` for $limit/$skip counts.
///
/// Returns `None` for non-numeric values and for negative numbers, which
/// previously wrapped via `as usize` into enormous counts (e.g. `$limit: -1`
/// behaved like "no limit"). A NaN double also yields `None` now. Fractional
/// doubles are still truncated, as before.
fn bson_to_usize(v: &Bson) -> Option<usize> {
    match v {
        Bson::Int32(n) if *n >= 0 => Some(*n as usize),
        Bson::Int64(n) if *n >= 0 => Some(*n as usize),
        Bson::Double(n) if *n >= 0.0 => Some(*n as usize),
        _ => None,
    }
}
/// Widen any numeric BSON value to f64; non-numeric values yield `None`.
fn bson_to_f64(v: &Bson) -> Option<f64> {
    match *v {
        Bson::Int32(n) => Some(f64::from(n)),
        Bson::Int64(n) => Some(n as f64),
        Bson::Double(n) => Some(n),
        _ => None,
    }
}
/// Loose equality for $lookup joins: numeric values compare across BSON
/// numeric types, everything else uses strict `PartialEq`.
///
/// Fixed: the Int64↔Double cross-type cases were missing, so a join between
/// an i64 key and a double key silently matched nothing. Now consistent with
/// `QueryMatcher`'s equality semantics.
fn bson_loose_eq(a: &Bson, b: &Bson) -> bool {
    match (a, b) {
        (Bson::Int32(x), Bson::Int64(y)) => i64::from(*x) == *y,
        (Bson::Int64(x), Bson::Int32(y)) => *x == i64::from(*y),
        (Bson::Int32(x), Bson::Double(y)) => f64::from(*x) == *y,
        (Bson::Double(x), Bson::Int32(y)) => *x == f64::from(*y),
        (Bson::Int64(x), Bson::Double(y)) => (*x as f64) == *y,
        (Bson::Double(x), Bson::Int64(y)) => *x == (*y as f64),
        _ => a == b,
    }
}
// --- Accumulators ---
/// $sum accumulator.
///
/// With a constant operand, adds the constant once per document (the common
/// `{$sum: 1}` counter). Fixed: a `Bson::Double` constant (e.g. `{$sum: 2.5}`)
/// previously fell through to the catch-all and produced 0.
///
/// With a "$field" operand, sums the numeric values of that field, skipping
/// missing/non-numeric values; the result stays Int64 as long as every
/// contributing value was an integer, otherwise it is a Double.
fn accumulate_sum(docs: &[Document], expr: &Bson) -> Bson {
    match expr {
        Bson::Int32(n) => Bson::Int64(*n as i64 * docs.len() as i64),
        Bson::Int64(n) => Bson::Int64(*n * docs.len() as i64),
        Bson::Double(n) => Bson::Double(*n * docs.len() as f64),
        Bson::String(s) if s.starts_with('$') => {
            let field = &s[1..];
            let mut sum = 0.0f64;
            let mut is_int = true;
            let mut int_sum = 0i64;
            for doc in docs {
                if let Some(val) = get_nested_value(doc, field) {
                    if let Some(n) = bson_to_f64(&val) {
                        sum += n;
                        if is_int {
                            match &val {
                                Bson::Int32(i) => int_sum += *i as i64,
                                Bson::Int64(i) => int_sum += i,
                                _ => is_int = false,
                            }
                        }
                    }
                }
            }
            if is_int {
                Bson::Int64(int_sum)
            } else {
                Bson::Double(sum)
            }
        }
        _ => Bson::Int32(0),
    }
}
/// $avg accumulator: arithmetic mean of the numeric values of a "$field"
/// operand. Missing/non-numeric values are skipped; if nothing contributes,
/// the result is Null. Non-path operands also yield Null.
fn accumulate_avg(docs: &[Document], expr: &Bson) -> Bson {
    if docs.is_empty() {
        return Bson::Null;
    }
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Null,
    };
    let nums: Vec<f64> = docs
        .iter()
        .filter_map(|doc| get_nested_value(doc, field))
        .filter_map(|val| bson_to_f64(&val))
        .collect();
    if nums.is_empty() {
        Bson::Null
    } else {
        Bson::Double(nums.iter().sum::<f64>() / nums.len() as f64)
    }
}
/// $min accumulator: smallest numeric value of a "$field" operand.
/// A candidate only replaces the running minimum when both compare as
/// numbers; otherwise the earlier value is kept. Null when no value exists.
fn accumulate_min(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Null,
    };
    let mut best: Option<Bson> = None;
    for val in docs.iter().filter_map(|doc| get_nested_value(doc, field)) {
        best = Some(match best.take() {
            None => val,
            Some(cur) => match (bson_to_f64(&cur), bson_to_f64(&val)) {
                (Some(c), Some(v)) if v < c => val,
                _ => cur,
            },
        });
    }
    best.unwrap_or(Bson::Null)
}
/// $max accumulator: largest numeric value of a "$field" operand.
/// A candidate only replaces the running maximum when both compare as
/// numbers; otherwise the earlier value is kept. Null when no value exists.
fn accumulate_max(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Null,
    };
    let mut best: Option<Bson> = None;
    for val in docs.iter().filter_map(|doc| get_nested_value(doc, field)) {
        best = Some(match best.take() {
            None => val,
            Some(cur) => match (bson_to_f64(&cur), bson_to_f64(&val)) {
                (Some(c), Some(v)) if v > c => val,
                _ => cur,
            },
        });
    }
    best.unwrap_or(Bson::Null)
}
/// $first accumulator: the field value from the first document in the group
/// (group order is the input order). Null for non-path operands or when the
/// field is absent.
fn accumulate_first(docs: &[Document], expr: &Bson) -> Bson {
    match expr {
        Bson::String(s) if s.starts_with('$') => docs
            .first()
            .and_then(|doc| get_nested_value(doc, &s[1..]))
            .unwrap_or(Bson::Null),
        _ => Bson::Null,
    }
}
/// $last accumulator: the field value from the last document in the group
/// (group order is the input order). Null for non-path operands or when the
/// field is absent.
fn accumulate_last(docs: &[Document], expr: &Bson) -> Bson {
    match expr {
        Bson::String(s) if s.starts_with('$') => docs
            .last()
            .and_then(|doc| get_nested_value(doc, &s[1..]))
            .unwrap_or(Bson::Null),
        _ => Bson::Null,
    }
}
/// $push accumulator: collect every present value of a "$field" operand into
/// an array, preserving input order. Non-path operands yield an empty array.
fn accumulate_push(docs: &[Document], expr: &Bson) -> Bson {
    match expr {
        Bson::String(s) if s.starts_with('$') => {
            let field = &s[1..];
            let mut values = Vec::with_capacity(docs.len());
            for doc in docs {
                if let Some(val) = get_nested_value(doc, field) {
                    values.push(val);
                }
            }
            Bson::Array(values)
        }
        _ => Bson::Array(Vec::new()),
    }
}
/// $addToSet accumulator: like $push, but each distinct value appears once
/// (first occurrence wins). Distinctness is keyed on the value's `Debug`
/// rendering, matching the rest of this module.
fn accumulate_add_to_set(docs: &[Document], expr: &Bson) -> Bson {
    let field = match expr {
        Bson::String(s) if s.starts_with('$') => &s[1..],
        _ => return Bson::Array(Vec::new()),
    };
    let mut seen = std::collections::HashSet::new();
    let values: Vec<Bson> = docs
        .iter()
        .filter_map(|doc| get_nested_value(doc, field))
        .filter(|val| seen.insert(format!("{:?}", val)))
        .collect();
    Bson::Array(values)
}
#[cfg(test)]
mod tests {
    use super::*;

    // $match keeps only documents satisfying the filter.
    #[test]
    fn test_match_stage() {
        let docs = vec![
            bson::doc! { "x": 1 },
            bson::doc! { "x": 2 },
            bson::doc! { "x": 3 },
        ];
        let pipeline = vec![bson::doc! { "$match": { "x": { "$gt": 1 } } }];
        let result = AggregationEngine::aggregate(docs, &pipeline, None, "test").unwrap();
        assert_eq!(result.len(), 2);
    }

    // $group produces one output document per distinct _id value.
    #[test]
    fn test_group_stage() {
        let docs = vec![
            bson::doc! { "category": "a", "value": 10 },
            bson::doc! { "category": "b", "value": 20 },
            bson::doc! { "category": "a", "value": 30 },
        ];
        let pipeline = vec![bson::doc! {
            "$group": {
                "_id": "$category",
                "total": { "$sum": "$value" }
            }
        }];
        let result = AggregationEngine::aggregate(docs, &pipeline, None, "test").unwrap();
        assert_eq!(result.len(), 2);
    }

    // Stages compose in order: sort ascending, drop 1, keep 2 -> [2, 3].
    #[test]
    fn test_sort_limit_skip() {
        let docs = vec![
            bson::doc! { "x": 3 },
            bson::doc! { "x": 1 },
            bson::doc! { "x": 2 },
            bson::doc! { "x": 4 },
        ];
        let pipeline = vec![
            bson::doc! { "$sort": { "x": 1 } },
            bson::doc! { "$skip": 1_i64 },
            bson::doc! { "$limit": 2_i64 },
        ];
        let result = AggregationEngine::aggregate(docs, &pipeline, None, "test").unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].get_i32("x").unwrap(), 2);
        assert_eq!(result[1].get_i32("x").unwrap(), 3);
    }
}

View File

@@ -0,0 +1,80 @@
use bson::{Bson, Document};
use std::collections::HashSet;
use crate::field_path::get_nested_value;
use crate::matcher::QueryMatcher;
/// Get distinct values for a field across documents, with optional filter.
/// Arrays are flattened: each element counts as a separate value. First
/// occurrence order is preserved.
pub fn distinct_values(
    docs: &[Document],
    field: &str,
    filter: Option<&Document>,
) -> Vec<Bson> {
    let mut seen = HashSet::new();
    let mut result = Vec::new();
    for doc in docs {
        // Apply the optional filter inline instead of materializing an
        // intermediate filtered list.
        if let Some(f) = filter {
            if !QueryMatcher::matches(doc, f) {
                continue;
            }
        }
        let value = if field.contains('.') {
            get_nested_value(doc, field)
        } else {
            doc.get(field).cloned()
        };
        if let Some(val) = value {
            collect_distinct_values(&val, &mut seen, &mut result);
        }
    }
    result
}
/// Fold one value into the distinct set. Arrays recurse element-by-element
/// (flattening); scalars are deduplicated via their `Debug` rendering.
fn collect_distinct_values(value: &Bson, seen: &mut HashSet<String>, result: &mut Vec<Bson>) {
    if let Bson::Array(arr) = value {
        for elem in arr {
            collect_distinct_values(elem, seen, result);
        }
    } else if seen.insert(format!("{:?}", value)) {
        result.push(value.clone());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Duplicate scalar values collapse to one entry each.
    #[test]
    fn test_distinct_simple() {
        let docs = vec![
            bson::doc! { "x": 1 },
            bson::doc! { "x": 2 },
            bson::doc! { "x": 1 },
            bson::doc! { "x": 3 },
        ];
        let result = distinct_values(&docs, "x", None);
        assert_eq!(result.len(), 3);
    }

    // Array fields are flattened before deduplication.
    #[test]
    fn test_distinct_array_flattening() {
        let docs = vec![
            bson::doc! { "tags": ["a", "b"] },
            bson::doc! { "tags": ["b", "c"] },
        ];
        let result = distinct_values(&docs, "tags", None);
        assert_eq!(result.len(), 3); // a, b, c
    }
}

View File

@@ -0,0 +1,18 @@
/// Errors from query operations.
///
/// All variants carry a human-readable detail string; the `#[error]`
/// attribute provides the `Display` prefix via `thiserror`.
#[derive(Debug, thiserror::Error)]
pub enum QueryError {
    /// A filter contained an operator rejected as invalid.
    #[error("Invalid query operator: {0}")]
    InvalidOperator(String),
    /// An operand's BSON type was incompatible with the operation.
    #[error("Type mismatch: {0}")]
    TypeMismatch(String),
    /// An update document was malformed or unsupported.
    #[error("Invalid update: {0}")]
    InvalidUpdate(String),
    /// An aggregation pipeline stage was malformed or unsupported.
    #[error("Aggregation error: {0}")]
    AggregationError(String),
    /// A regular-expression pattern was rejected.
    #[error("Invalid regex: {0}")]
    InvalidRegex(String),
}

View File

@@ -0,0 +1,115 @@
use bson::{Bson, Document};
/// Get a nested value from a document using dot-notation path (e.g., "a.b.c").
/// Handles both nested documents and array traversal.
///
/// Fixed a hidden cost: the old implementation cloned the ENTIRE document
/// into a `Bson::Document` on every lookup just to kick off the recursion —
/// this function sits on the hot path of matching, sorting, and grouping.
/// Now only the first path segment is resolved by borrow and just the
/// matched subtree is visited.
pub fn get_nested_value(doc: &Document, path: &str) -> Option<Bson> {
    let parts: Vec<&str> = path.split('.').collect();
    // `split` always yields at least one segment, so split_first succeeds.
    let (first, rest) = parts.split_first()?;
    get_nested_recursive(doc.get(*first)?, rest)
}
/// Recursive worker for `get_nested_value`; `parts` holds the remaining
/// path segments.
///
/// Documents descend by key. Arrays first try the segment as a numeric
/// index; otherwise the segment is applied to every element and the hits
/// are collected — a single hit is returned unwrapped, multiple hits come
/// back as an array (mirroring MongoDB's array path fan-out).
fn get_nested_recursive(value: &Bson, parts: &[&str]) -> Option<Bson> {
    if parts.is_empty() {
        // Path exhausted: this is the value being asked for.
        return Some(value.clone());
    }
    let key = parts[0];
    let rest = &parts[1..];
    match value {
        Bson::Document(doc) => {
            let child = doc.get(key)?;
            get_nested_recursive(child, rest)
        }
        Bson::Array(arr) => {
            // Try numeric index first
            if let Ok(idx) = key.parse::<usize>() {
                if let Some(elem) = arr.get(idx) {
                    return get_nested_recursive(elem, rest);
                }
            }
            // Otherwise, collect from all elements
            let results: Vec<Bson> = arr
                .iter()
                .filter_map(|elem| get_nested_recursive(elem, parts))
                .collect();
            if results.is_empty() {
                None
            } else if results.len() == 1 {
                Some(results.into_iter().next().unwrap())
            } else {
                Some(Bson::Array(results))
            }
        }
        // Scalars cannot be descended into: the path misses.
        _ => None,
    }
}
/// Set a nested value in a document using a dot-notation path, creating
/// intermediate documents for missing segments.
pub fn set_nested_value(doc: &mut Document, path: &str, value: Bson) {
    let segments: Vec<&str> = path.split('.').collect();
    set_nested_recursive(doc, &segments, value);
}
/// Recursive worker for `set_nested_value`.
///
/// Creates intermediate documents for missing keys. NOTE(review): if an
/// intermediate key exists but holds a non-document value, the write is
/// silently dropped (the `if let` does not match) — confirm this is
/// intended; MongoDB raises an error in this situation.
fn set_nested_recursive(doc: &mut Document, parts: &[&str], value: Bson) {
    if parts.len() == 1 {
        // Final segment: write the value here.
        doc.insert(parts[0].to_string(), value);
        return;
    }
    let key = parts[0];
    let rest = &parts[1..];
    // Get or create nested document
    if !doc.contains_key(key) {
        doc.insert(key.to_string(), Bson::Document(Document::new()));
    }
    if let Some(Bson::Document(ref mut nested)) = doc.get_mut(key) {
        set_nested_recursive(nested, rest, value);
    }
}
/// Remove a nested value addressed by a dot-notation path, returning the
/// removed value if it existed.
pub fn remove_nested_value(doc: &mut Document, path: &str) -> Option<Bson> {
    let segments: Vec<&str> = path.split('.').collect();
    remove_nested_recursive(doc, &segments)
}
/// Recursive worker for `remove_nested_value`: descend through document
/// segments and remove the final key. A non-document intermediate (or a
/// missing key) yields `None`.
fn remove_nested_recursive(doc: &mut Document, parts: &[&str]) -> Option<Bson> {
    match parts {
        [] => None,
        [last] => doc.remove(last),
        [first, rest @ ..] => match doc.get_mut(*first) {
            Some(Bson::Document(nested)) => remove_nested_recursive(nested, rest),
            _ => None,
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: three-level nested read.
    #[test]
    fn test_get_nested_simple() {
        let doc = bson::doc! { "a": { "b": { "c": 42 } } };
        assert_eq!(get_nested_value(&doc, "a.b.c"), Some(Bson::Int32(42)));
    }

    // A missing leaf yields None rather than Null.
    #[test]
    fn test_get_nested_missing() {
        let doc = bson::doc! { "a": { "b": 1 } };
        assert_eq!(get_nested_value(&doc, "a.c"), None);
    }

    // set creates the intermediate documents; get reads them back.
    #[test]
    fn test_set_nested() {
        let mut doc = bson::doc! {};
        set_nested_value(&mut doc, "a.b.c", Bson::Int32(42));
        assert_eq!(get_nested_value(&doc, "a.b.c"), Some(Bson::Int32(42)));
    }
}

View File

@@ -0,0 +1,16 @@
//! MongoDB-compatible query matching, update operators, aggregation,
//! sorting, projection, and distinct over in-memory BSON documents.

mod matcher;
mod update;
mod sort;
mod projection;
mod distinct;
pub mod aggregation;
mod field_path;
pub mod error;

// Flat public surface: callers use `rustdb_query::QueryMatcher` etc. rather
// than reaching into the (mostly private) modules.
pub use matcher::QueryMatcher;
pub use update::UpdateEngine;
pub use sort::sort_documents;
pub use projection::apply_projection;
pub use distinct::distinct_values;
pub use aggregation::AggregationEngine;
pub use field_path::{get_nested_value, set_nested_value};

View File

@@ -0,0 +1,574 @@
use bson::{Bson, Document};
use regex::Regex;
use crate::field_path::get_nested_value;
/// Query matching engine.
/// Evaluates filter documents against BSON documents.
///
/// Stateless: all entry points are associated functions.
pub struct QueryMatcher;
impl QueryMatcher {
/// Test whether a single document matches a filter.
pub fn matches(doc: &Document, filter: &Document) -> bool {
Self::matches_filter(doc, filter)
}
/// Filter a slice of documents, returning those that match.
pub fn filter(docs: &[Document], filter: &Document) -> Vec<Document> {
if filter.is_empty() {
return docs.to_vec();
}
docs.iter()
.filter(|doc| Self::matches_filter(doc, filter))
.cloned()
.collect()
}
/// Find the first document matching a filter.
pub fn find_one(docs: &[Document], filter: &Document) -> Option<Document> {
docs.iter()
.find(|doc| Self::matches_filter(doc, filter))
.cloned()
}
fn matches_filter(doc: &Document, filter: &Document) -> bool {
for (key, value) in filter {
if !Self::matches_condition(doc, key, value) {
return false;
}
}
true
}
fn matches_condition(doc: &Document, key: &str, condition: &Bson) -> bool {
match key {
"$and" => Self::match_logical_and(doc, condition),
"$or" => Self::match_logical_or(doc, condition),
"$nor" => Self::match_logical_nor(doc, condition),
"$not" => Self::match_logical_not(doc, condition),
"$expr" => {
// Basic $expr support - just return true for now
true
}
_ => {
// Field condition
match condition {
Bson::Document(cond_doc) if Self::has_operators(cond_doc) => {
Self::match_field_operators(doc, key, cond_doc)
}
// Implicit equality
_ => Self::match_equality(doc, key, condition),
}
}
}
}
fn has_operators(doc: &Document) -> bool {
doc.keys().any(|k| k.starts_with('$'))
}
/// Public accessor for has_operators (used by update engine).
pub fn has_operators_pub(doc: &Document) -> bool {
Self::has_operators(doc)
}
/// Public accessor for bson_compare (used by update engine).
pub fn bson_compare_pub(a: &Bson, b: &Bson) -> Option<std::cmp::Ordering> {
Self::bson_compare(a, b)
}
fn match_equality(doc: &Document, field: &str, expected: &Bson) -> bool {
let actual = Self::resolve_field(doc, field);
match actual {
Some(val) => Self::bson_equals(&val, expected),
None => matches!(expected, Bson::Null),
}
}
fn match_field_operators(doc: &Document, field: &str, operators: &Document) -> bool {
let actual = Self::resolve_field(doc, field);
for (op, op_value) in operators {
let result = match op.as_str() {
"$eq" => Self::op_eq(&actual, op_value),
"$ne" => Self::op_ne(&actual, op_value),
"$gt" => Self::op_cmp(&actual, op_value, CmpOp::Gt),
"$gte" => Self::op_cmp(&actual, op_value, CmpOp::Gte),
"$lt" => Self::op_cmp(&actual, op_value, CmpOp::Lt),
"$lte" => Self::op_cmp(&actual, op_value, CmpOp::Lte),
"$in" => Self::op_in(&actual, op_value),
"$nin" => Self::op_nin(&actual, op_value),
"$exists" => Self::op_exists(&actual, op_value),
"$type" => Self::op_type(&actual, op_value),
"$regex" => Self::op_regex(&actual, op_value, operators.get("$options")),
"$not" => Self::op_not(doc, field, op_value),
"$elemMatch" => Self::op_elem_match(&actual, op_value),
"$size" => Self::op_size(&actual, op_value),
"$all" => Self::op_all(&actual, op_value),
"$mod" => Self::op_mod(&actual, op_value),
"$options" => continue, // handled by $regex
_ => true, // unknown operator, skip
};
if !result {
return false;
}
}
true
}
fn resolve_field(doc: &Document, field: &str) -> Option<Bson> {
if field.contains('.') {
get_nested_value(doc, field)
} else {
doc.get(field).cloned()
}
}
fn bson_equals(a: &Bson, b: &Bson) -> bool {
match (a, b) {
(Bson::Int32(x), Bson::Int64(y)) => (*x as i64) == *y,
(Bson::Int64(x), Bson::Int32(y)) => *x == (*y as i64),
(Bson::Int32(x), Bson::Double(y)) => (*x as f64) == *y,
(Bson::Double(x), Bson::Int32(y)) => *x == (*y as f64),
(Bson::Int64(x), Bson::Double(y)) => (*x as f64) == *y,
(Bson::Double(x), Bson::Int64(y)) => *x == (*y as f64),
// For arrays, check if any element matches (implicit $elemMatch)
(Bson::Array(arr), _) if !matches!(b, Bson::Array(_)) => {
arr.iter().any(|elem| Self::bson_equals(elem, b))
}
_ => a == b,
}
}
fn bson_compare(a: &Bson, b: &Bson) -> Option<std::cmp::Ordering> {
use std::cmp::Ordering;
match (a, b) {
// Numeric comparisons (cross-type)
(Bson::Int32(x), Bson::Int32(y)) => Some(x.cmp(y)),
(Bson::Int64(x), Bson::Int64(y)) => Some(x.cmp(y)),
(Bson::Double(x), Bson::Double(y)) => x.partial_cmp(y),
(Bson::Int32(x), Bson::Int64(y)) => Some((*x as i64).cmp(y)),
(Bson::Int64(x), Bson::Int32(y)) => Some(x.cmp(&(*y as i64))),
(Bson::Int32(x), Bson::Double(y)) => (*x as f64).partial_cmp(y),
(Bson::Double(x), Bson::Int32(y)) => x.partial_cmp(&(*y as f64)),
(Bson::Int64(x), Bson::Double(y)) => (*x as f64).partial_cmp(y),
(Bson::Double(x), Bson::Int64(y)) => x.partial_cmp(&(*y as f64)),
// String comparisons
(Bson::String(x), Bson::String(y)) => Some(x.cmp(y)),
// DateTime comparisons
(Bson::DateTime(x), Bson::DateTime(y)) => Some(x.cmp(y)),
// Boolean comparisons
(Bson::Boolean(x), Bson::Boolean(y)) => Some(x.cmp(y)),
// ObjectId comparisons
(Bson::ObjectId(x), Bson::ObjectId(y)) => Some(x.cmp(y)),
// Null comparisons
(Bson::Null, Bson::Null) => Some(Ordering::Equal),
_ => None,
}
}
// --- Operator implementations ---
fn op_eq(actual: &Option<Bson>, expected: &Bson) -> bool {
match actual {
Some(val) => Self::bson_equals(val, expected),
None => matches!(expected, Bson::Null),
}
}
fn op_ne(actual: &Option<Bson>, expected: &Bson) -> bool {
!Self::op_eq(actual, expected)
}
fn op_cmp(actual: &Option<Bson>, expected: &Bson, op: CmpOp) -> bool {
let val = match actual {
Some(v) => v,
None => return false,
};
// For arrays, check if any element satisfies the comparison
if let Bson::Array(arr) = val {
return arr.iter().any(|elem| {
if let Some(ord) = Self::bson_compare(elem, expected) {
op.check(ord)
} else {
false
}
});
}
if let Some(ord) = Self::bson_compare(val, expected) {
op.check(ord)
} else {
false
}
}
fn op_in(actual: &Option<Bson>, values: &Bson) -> bool {
let arr = match values {
Bson::Array(a) => a,
_ => return false,
};
match actual {
Some(val) => {
// For array values, check if any element is in the list
if let Bson::Array(actual_arr) = val {
actual_arr.iter().any(|elem| {
arr.iter().any(|v| Self::bson_equals(elem, v))
}) || arr.iter().any(|v| Self::bson_equals(val, v))
} else {
arr.iter().any(|v| Self::bson_equals(val, v))
}
}
None => arr.iter().any(|v| matches!(v, Bson::Null)),
}
}
fn op_nin(actual: &Option<Bson>, values: &Bson) -> bool {
!Self::op_in(actual, values)
}
fn op_exists(actual: &Option<Bson>, expected: &Bson) -> bool {
let should_exist = match expected {
Bson::Boolean(b) => *b,
Bson::Int32(n) => *n != 0,
Bson::Int64(n) => *n != 0,
_ => true,
};
actual.is_some() == should_exist
}
fn op_type(actual: &Option<Bson>, expected: &Bson) -> bool {
let val = match actual {
Some(v) => v,
None => return false,
};
let type_num = match expected {
Bson::Int32(n) => *n,
Bson::String(s) => match s.as_str() {
"double" => 1,
"string" => 2,
"object" => 3,
"array" => 4,
"binData" => 5,
"objectId" => 7,
"bool" => 8,
"date" => 9,
"null" => 10,
"regex" => 11,
"int" => 16,
"long" => 18,
"decimal" => 19,
"number" => -1, // special: any numeric type
_ => return false,
},
_ => return false,
};
if type_num == -1 {
return matches!(val, Bson::Int32(_) | Bson::Int64(_) | Bson::Double(_));
}
let actual_type = match val {
Bson::Double(_) => 1,
Bson::String(_) => 2,
Bson::Document(_) => 3,
Bson::Array(_) => 4,
Bson::Binary(_) => 5,
Bson::ObjectId(_) => 7,
Bson::Boolean(_) => 8,
Bson::DateTime(_) => 9,
Bson::Null => 10,
Bson::RegularExpression(_) => 11,
Bson::Int32(_) => 16,
Bson::Int64(_) => 18,
Bson::Decimal128(_) => 19,
_ => 0,
};
actual_type == type_num
}
fn op_regex(actual: &Option<Bson>, pattern: &Bson, options: Option<&Bson>) -> bool {
let val = match actual {
Some(Bson::String(s)) => s.as_str(),
_ => return false,
};
let pattern_str = match pattern {
Bson::String(s) => s.as_str(),
Bson::RegularExpression(re) => re.pattern.as_str(),
_ => return false,
};
let opts = match options {
Some(Bson::String(s)) => s.as_str(),
_ => match pattern {
Bson::RegularExpression(re) => re.options.as_str(),
_ => "",
},
};
let mut regex_pattern = String::new();
if opts.contains('i') {
regex_pattern.push_str("(?i)");
}
if opts.contains('m') {
regex_pattern.push_str("(?m)");
}
if opts.contains('s') {
regex_pattern.push_str("(?s)");
}
regex_pattern.push_str(pattern_str);
match Regex::new(&regex_pattern) {
Ok(re) => re.is_match(val),
Err(_) => false,
}
}
fn op_not(doc: &Document, field: &str, condition: &Bson) -> bool {
match condition {
Bson::Document(cond_doc) => !Self::match_field_operators(doc, field, cond_doc),
_ => true,
}
}
fn op_elem_match(actual: &Option<Bson>, condition: &Bson) -> bool {
let arr = match actual {
Some(Bson::Array(a)) => a,
_ => return false,
};
let cond_doc = match condition {
Bson::Document(d) => d,
_ => return false,
};
arr.iter().any(|elem| {
if let Bson::Document(elem_doc) = elem {
Self::matches_filter(elem_doc, cond_doc)
} else {
false
}
})
}
fn op_size(actual: &Option<Bson>, expected: &Bson) -> bool {
let arr = match actual {
Some(Bson::Array(a)) => a,
_ => return false,
};
let expected_size = match expected {
Bson::Int32(n) => *n as usize,
Bson::Int64(n) => *n as usize,
_ => return false,
};
arr.len() == expected_size
}
fn op_all(actual: &Option<Bson>, expected: &Bson) -> bool {
let arr = match actual {
Some(Bson::Array(a)) => a,
_ => return false,
};
let expected_arr = match expected {
Bson::Array(a) => a,
_ => return false,
};
expected_arr.iter().all(|expected_val| {
arr.iter().any(|elem| Self::bson_equals(elem, expected_val))
})
}
fn op_mod(actual: &Option<Bson>, expected: &Bson) -> bool {
let val = match actual {
Some(v) => match v {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
Bson::Double(n) => *n as i64,
_ => return false,
},
None => return false,
};
let arr = match expected {
Bson::Array(a) if a.len() == 2 => a,
_ => return false,
};
let divisor = match &arr[0] {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => return false,
};
let remainder = match &arr[1] {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => return false,
};
if divisor == 0 {
return false;
}
val % divisor == remainder
}
// --- Logical operators ---
/// `$and`: every clause in the operand array must match. A non-array
/// operand, or any non-document clause, fails the whole conjunction.
fn match_logical_and(doc: &Document, conditions: &Bson) -> bool {
    let clauses = match conditions {
        Bson::Array(arr) => arr,
        _ => return false,
    };
    clauses.iter().all(|clause| match clause {
        Bson::Document(d) => Self::matches_filter(doc, d),
        _ => false,
    })
}
/// `$or`: at least one clause in the operand array must match. A non-array
/// operand fails; non-document clauses are treated as non-matching.
fn match_logical_or(doc: &Document, conditions: &Bson) -> bool {
    let clauses = match conditions {
        Bson::Array(arr) => arr,
        _ => return false,
    };
    clauses.iter().any(|clause| match clause {
        Bson::Document(d) => Self::matches_filter(doc, d),
        _ => false,
    })
}
/// `$nor`: the logical negation of `$or` — no clause may match.
fn match_logical_nor(doc: &Document, conditions: &Bson) -> bool {
    !Self::match_logical_or(doc, conditions)
}
/// Top-level `$not`: matches when the wrapped filter does NOT match.
/// A non-document condition is vacuously true.
fn match_logical_not(doc: &Document, condition: &Bson) -> bool {
    if let Bson::Document(cond_doc) = condition {
        !Self::matches_filter(doc, cond_doc)
    } else {
        true
    }
}
}
/// The four relational query operators (`$gt`, `$gte`, `$lt`, `$lte`),
/// reduced to a predicate over a `std::cmp::Ordering`.
#[derive(Debug, Clone, Copy)]
enum CmpOp {
    Gt,
    Gte,
    Lt,
    Lte,
}
impl CmpOp {
    /// True when `ord` — the ordering of the actual value relative to the
    /// expected one — satisfies this operator.
    fn check(self, ord: std::cmp::Ordering) -> bool {
        use std::cmp::Ordering;
        match self {
            CmpOp::Gt => matches!(ord, Ordering::Greater),
            CmpOp::Gte => !matches!(ord, Ordering::Less),
            CmpOp::Lt => matches!(ord, Ordering::Less),
            CmpOp::Lte => !matches!(ord, Ordering::Greater),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: does `filter` accept `doc`?
    fn accepts(doc: &Document, filter: Document) -> bool {
        QueryMatcher::matches(doc, &filter)
    }

    #[test]
    fn test_simple_equality() {
        let doc = bson::doc! { "name": "Alice", "age": 30 };
        assert!(accepts(&doc, bson::doc! { "name": "Alice" }));
        assert!(!accepts(&doc, bson::doc! { "name": "Bob" }));
    }

    #[test]
    fn test_comparison_operators() {
        let doc = bson::doc! { "age": 30 };
        // (filter, expected match result)
        for (filter, expected) in [
            (bson::doc! { "age": { "$gt": 25 } }, true),
            (bson::doc! { "age": { "$gte": 30 } }, true),
            (bson::doc! { "age": { "$lt": 35 } }, true),
            (bson::doc! { "age": { "$lte": 30 } }, true),
            (bson::doc! { "age": { "$gt": 30 } }, false),
        ] {
            assert_eq!(accepts(&doc, filter), expected);
        }
    }

    #[test]
    fn test_in_operator() {
        let doc = bson::doc! { "status": "active" };
        assert!(accepts(&doc, bson::doc! { "status": { "$in": ["active", "pending"] } }));
        assert!(!accepts(&doc, bson::doc! { "status": { "$in": ["closed"] } }));
    }

    #[test]
    fn test_exists_operator() {
        let doc = bson::doc! { "name": "Alice" };
        assert!(accepts(&doc, bson::doc! { "name": { "$exists": true } }));
        assert!(!accepts(&doc, bson::doc! { "age": { "$exists": true } }));
        assert!(accepts(&doc, bson::doc! { "age": { "$exists": false } }));
    }

    #[test]
    fn test_logical_or() {
        let doc = bson::doc! { "age": 30 };
        assert!(accepts(&doc, bson::doc! { "$or": [{ "age": 30 }, { "age": 40 }] }));
        assert!(!accepts(&doc, bson::doc! { "$or": [{ "age": 20 }, { "age": 40 }] }));
    }

    #[test]
    fn test_logical_and() {
        let doc = bson::doc! { "age": 30, "name": "Alice" };
        assert!(accepts(&doc, bson::doc! { "$and": [{ "age": 30 }, { "name": "Alice" }] }));
        assert!(!accepts(&doc, bson::doc! { "$and": [{ "age": 30 }, { "name": "Bob" }] }));
    }

    #[test]
    fn test_dot_notation() {
        let doc = bson::doc! { "address": { "city": "NYC" } };
        assert!(accepts(&doc, bson::doc! { "address.city": "NYC" }));
    }

    #[test]
    fn test_ne_operator() {
        let doc = bson::doc! { "status": "active" };
        assert!(accepts(&doc, bson::doc! { "status": { "$ne": "closed" } }));
        assert!(!accepts(&doc, bson::doc! { "status": { "$ne": "active" } }));
    }

    #[test]
    fn test_cross_type_numeric_equality() {
        // An Int32 stored value must compare equal to an Int64 filter value.
        let doc = bson::doc! { "count": 5_i32 };
        assert!(accepts(&doc, bson::doc! { "count": 5_i64 }));
    }

    #[test]
    fn test_empty_filter_matches_all() {
        assert!(accepts(&bson::doc! { "x": 1 }, bson::doc! {}));
    }
}

View File

@@ -0,0 +1,168 @@
use bson::{Bson, Document};
use crate::field_path::get_nested_value;
/// Apply a projection to a document.
/// Inclusion mode: only specified fields + _id.
/// Exclusion mode: all fields except specified ones.
/// _id can be explicitly excluded in either mode.
///
/// A spec value of `0` (`Int32`/`Int64`/`Double`) or `false` marks a field
/// as excluded; anything else marks the projection as an inclusion. The
/// `Double` case matters because JSON clients encode every number as a
/// double, so `{ field: 0.0 }` must select exclusion mode (the previous
/// implementation mis-classified it as inclusion). `_id` never decides the
/// mode.
pub fn apply_projection(doc: &Document, projection: &Document) -> Document {
    if projection.is_empty() {
        return doc.clone();
    }
    let mut has_inclusion = false;
    let mut id_explicitly_set = false;
    for (key, value) in projection {
        if key == "_id" {
            id_explicitly_set = true;
            continue;
        }
        let is_exclusion = matches!(value, Bson::Int32(0) | Bson::Int64(0) | Bson::Boolean(false))
            || matches!(value, Bson::Double(f) if *f == 0.0);
        if !is_exclusion {
            has_inclusion = true;
        }
    }
    if has_inclusion {
        apply_inclusion(doc, projection, id_explicitly_set)
    } else {
        apply_exclusion(doc, projection)
    }
}
/// Inclusion-mode projection: copy only the listed (truthy) fields.
/// `_id` rides along by default unless the caller explicitly excluded it.
/// Dotted paths pull the nested value and rebuild the hierarchy in the
/// result. NOTE(review): dotted paths are resolved through documents only —
/// projecting inside array elements is not supported here; confirm callers
/// don't rely on that.
fn apply_inclusion(doc: &Document, projection: &Document, id_explicitly_set: bool) -> Document {
    let mut out = Document::new();
    let keep_id = !id_explicitly_set || is_truthy(projection.get("_id"));
    if keep_id {
        if let Some(id) = doc.get("_id") {
            out.insert("_id", id.clone());
        }
    }
    for (key, spec) in projection {
        if key == "_id" || !is_truthy(Some(spec)) {
            continue;
        }
        if key.contains('.') {
            if let Some(v) = get_nested_value(doc, key) {
                set_nested_in_result(&mut out, key, v);
            }
        } else if let Some(v) = doc.get(key) {
            out.insert(key.clone(), v.clone());
        }
    }
    out
}
/// Exclusion-mode projection: start from a full copy and strip every field
/// whose spec value is falsy (0 / false). Dotted paths remove nested fields.
fn apply_exclusion(doc: &Document, projection: &Document) -> Document {
    let mut out = doc.clone();
    for (key, spec) in projection {
        if is_truthy(Some(spec)) {
            continue; // only falsy entries remove fields in exclusion mode
        }
        if key.contains('.') {
            remove_nested_from_result(&mut out, key);
        } else {
            out.remove(key);
        }
    }
    out
}
/// A projection spec value counts as "include" when it is anything other
/// than numeric zero or `false`. `Double(0.0)` must count as exclusion too,
/// because JSON-sourced projections deliver all numbers as doubles (the
/// previous implementation treated `0.0` as truthy).
fn is_truthy(value: Option<&Bson>) -> bool {
    match value {
        None => false,
        Some(Bson::Int32(0)) | Some(Bson::Int64(0)) | Some(Bson::Boolean(false)) => false,
        Some(Bson::Double(f)) if *f == 0.0 => false,
        _ => true,
    }
}
/// Insert `value` into `doc` at the dotted `path`, materialising any
/// intermediate sub-documents.
fn set_nested_in_result(doc: &mut Document, path: &str, value: Bson) {
    let segments: Vec<&str> = path.split('.').collect();
    set_nested_recursive(doc, &segments, value);
}
/// Recursive worker for [`set_nested_in_result`]: walk/create one path
/// segment per call and insert `value` at the leaf.
///
/// If an intermediate segment already exists but holds a non-document value,
/// it is replaced with an empty document so the value is still written —
/// the previous implementation skipped the insert and silently dropped the
/// value in that case.
fn set_nested_recursive(doc: &mut Document, parts: &[&str], value: Bson) {
    let (head, rest) = match parts.split_first() {
        Some(split) => split,
        None => return, // empty path: nothing to write
    };
    if rest.is_empty() {
        doc.insert(head.to_string(), value);
        return;
    }
    // Make sure the intermediate slot is a document before descending.
    if !matches!(doc.get(*head), Some(Bson::Document(_))) {
        doc.insert(head.to_string(), Bson::Document(Document::new()));
    }
    if let Some(Bson::Document(nested)) = doc.get_mut(*head) {
        set_nested_recursive(nested, rest, value);
    }
}
/// Remove the field addressed by the dotted `path` from `doc`, if present.
fn remove_nested_from_result(doc: &mut Document, path: &str) {
    let segments: Vec<&str> = path.split('.').collect();
    remove_nested_recursive(doc, &segments);
}
/// Recursive worker for [`remove_nested_from_result`]: descend through
/// document segments and drop the leaf key. Paths crossing a non-document
/// value are silently ignored.
fn remove_nested_recursive(doc: &mut Document, parts: &[&str]) {
    match parts {
        [] => {}
        [leaf] => {
            doc.remove(*leaf);
        }
        [head, rest @ ..] => {
            if let Some(Bson::Document(nested)) = doc.get_mut(*head) {
                remove_nested_recursive(nested, rest);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect the keys of a projected document for compact assertions.
    fn keys(doc: &Document) -> Vec<&str> {
        doc.keys().map(|k| k.as_str()).collect()
    }

    #[test]
    fn test_inclusion_projection() {
        let doc = bson::doc! { "_id": 1, "name": "Alice", "age": 30, "email": "a@b.c" };
        let out = apply_projection(&doc, &bson::doc! { "name": 1, "age": 1 });
        // _id is included by default; email was not requested.
        assert_eq!(keys(&out), vec!["_id", "name", "age"]);
    }

    #[test]
    fn test_exclusion_projection() {
        let doc = bson::doc! { "_id": 1, "name": "Alice", "age": 30 };
        let out = apply_projection(&doc, &bson::doc! { "age": 0 });
        assert_eq!(keys(&out), vec!["_id", "name"]);
    }

    #[test]
    fn test_exclude_id() {
        let doc = bson::doc! { "_id": 1, "name": "Alice" };
        let out = apply_projection(&doc, &bson::doc! { "name": 1, "_id": 0 });
        // _id can be dropped even in inclusion mode.
        assert_eq!(keys(&out), vec!["name"]);
    }
}

View File

@@ -0,0 +1,137 @@
use bson::{Bson, Document};
use crate::field_path::get_nested_value;
/// Sort documents in place according to a MongoDB-style specification:
/// `{ field1: 1, field2: -1 }` (positive = ascending, non-positive =
/// descending). String directions ("desc"/"descending", case-insensitive)
/// are also honoured; any other direction value sorts ascending.
pub fn sort_documents(docs: &mut [Document], sort_spec: &Document) {
    if sort_spec.is_empty() {
        return;
    }
    // Stable sort: documents comparing equal keep their relative order.
    docs.sort_by(|left, right| {
        sort_spec
            .iter()
            .map(|(field, direction)| {
                let descending = match direction {
                    Bson::Int32(n) => *n <= 0,
                    Bson::Int64(n) => *n <= 0,
                    Bson::String(s) => {
                        s.eq_ignore_ascii_case("desc") || s.eq_ignore_ascii_case("descending")
                    }
                    _ => false,
                };
                let ord = compare_bson_values(&get_value(left, field), &get_value(right, field));
                if descending { ord.reverse() } else { ord }
            })
            .find(|ord| *ord != std::cmp::Ordering::Equal)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
}
/// Fetch a (possibly dotted) field from a document for sort-key purposes.
fn get_value(doc: &Document, field: &str) -> Option<Bson> {
    if field.contains('.') {
        get_nested_value(doc, field)
    } else {
        doc.get(field).cloned()
    }
}
/// Compare two optional BSON values for sorting.
/// Missing fields and explicit `Null`s are treated alike: equal to each
/// other and smaller than every present value, matching MongoDB's sort
/// order where null/missing sorts first.
fn compare_bson_values(a: &Option<Bson>, b: &Option<Bson>) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    let a_null = matches!(a, None | Some(Bson::Null));
    let b_null = matches!(b, None | Some(Bson::Null));
    match (a_null, b_null) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Less,
        (false, true) => Ordering::Greater,
        (false, false) => match (a, b) {
            (Some(av), Some(bv)) => compare_typed(av, bv),
            // unreachable: non-null implies Some(_)
            _ => Ordering::Equal,
        },
    }
}
/// Compare two concrete BSON values. Int32/Int64/Double compare numerically
/// across types; same-type strings, booleans, datetimes and ObjectIds use
/// their natural order; any other mix falls back to the canonical BSON type
/// rank. NOTE(review): `Decimal128` is ranked with numbers by `type_order`
/// but is not converted by `to_f64`, so Decimal128 vs. other numerics
/// compares equal — confirm whether that is acceptable.
fn compare_typed(a: &Bson, b: &Bson) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    if let (Some(x), Some(y)) = (to_f64(a), to_f64(b)) {
        return x.partial_cmp(&y).unwrap_or(Ordering::Equal);
    }
    match (a, b) {
        (Bson::String(x), Bson::String(y)) => x.cmp(y),
        (Bson::Boolean(x), Bson::Boolean(y)) => x.cmp(y),
        (Bson::DateTime(x), Bson::DateTime(y)) => x.cmp(y),
        (Bson::ObjectId(x), Bson::ObjectId(y)) => x.cmp(y),
        _ => type_order(a).cmp(&type_order(b)),
    }
}
/// Widen a numeric BSON value to f64 for cross-type comparison; non-numeric
/// values (including `Decimal128`) yield `None`.
fn to_f64(v: &Bson) -> Option<f64> {
    match v {
        Bson::Int32(n) => Some(f64::from(*n)),
        Bson::Int64(n) => Some(*n as f64),
        Bson::Double(n) => Some(*n),
        _ => None,
    }
}
/// Canonical BSON type rank used when comparing values of different,
/// non-numeric types: null < numbers < strings < objects < arrays <
/// binData < ObjectId < bool < date. Only the *relative* order of the
/// returned ranks matters.
fn type_order(v: &Bson) -> u8 {
    match v {
        Bson::Null => 0,
        // All numeric types share one bracket. Int32/Int64/Double cross-type
        // ordering is resolved numerically in `compare_typed` before this is
        // consulted; Decimal128 is not converted there, so it compares equal
        // to other numerics — TODO confirm that is intended.
        Bson::Int32(_) | Bson::Int64(_) | Bson::Double(_) | Bson::Decimal128(_) => 1,
        Bson::String(_) => 2,
        Bson::Document(_) => 3,
        Bson::Array(_) => 4,
        Bson::Binary(_) => 5,
        // rank 6 is intentionally unused
        Bson::ObjectId(_) => 7,
        Bson::Boolean(_) => 8,
        Bson::DateTime(_) => 9,
        // Any other type (timestamp, regex, ...) sorts after everything known.
        _ => 10,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `{ "x": v }` docs, sort on "x" with `dir`, return the values.
    fn sorted_x(values: &[i32], dir: i32) -> Vec<i32> {
        let mut docs: Vec<Document> = values.iter().map(|v| bson::doc! { "x": *v }).collect();
        sort_documents(&mut docs, &bson::doc! { "x": dir });
        docs.iter().map(|d| d.get_i32("x").unwrap()).collect()
    }

    #[test]
    fn test_sort_ascending() {
        assert_eq!(sorted_x(&[3, 1, 2], 1), vec![1, 2, 3]);
    }

    #[test]
    fn test_sort_descending() {
        assert_eq!(sorted_x(&[1, 3, 2], -1), vec![3, 2, 1]);
    }
}

View File

@@ -0,0 +1,575 @@
use bson::{Bson, Document, doc};
use crate::error::QueryError;
use crate::field_path::{get_nested_value, set_nested_value, remove_nested_value};
use crate::matcher::QueryMatcher;
/// Update engine — applies update operators to documents.
///
/// Stateless marker type: all behaviour lives in associated functions, so
/// it is never instantiated.
pub struct UpdateEngine;
impl UpdateEngine {
/// Apply an update specification to a document.
/// Returns the updated document.
pub fn apply_update(
doc: &Document,
update: &Document,
_array_filters: Option<&[Document]>,
) -> Result<Document, QueryError> {
// Check if this is a replacement (no $ operators)
if !update.keys().any(|k| k.starts_with('$')) {
return Self::apply_replacement(doc, update);
}
let mut result = doc.clone();
for (op, value) in update {
let fields = match value {
Bson::Document(d) => d,
_ => continue,
};
match op.as_str() {
"$set" => Self::apply_set(&mut result, fields)?,
"$unset" => Self::apply_unset(&mut result, fields)?,
"$inc" => Self::apply_inc(&mut result, fields)?,
"$mul" => Self::apply_mul(&mut result, fields)?,
"$min" => Self::apply_min(&mut result, fields)?,
"$max" => Self::apply_max(&mut result, fields)?,
"$rename" => Self::apply_rename(&mut result, fields)?,
"$currentDate" => Self::apply_current_date(&mut result, fields)?,
"$setOnInsert" => {} // handled separately during upsert
"$push" => Self::apply_push(&mut result, fields)?,
"$pop" => Self::apply_pop(&mut result, fields)?,
"$pull" => Self::apply_pull(&mut result, fields)?,
"$pullAll" => Self::apply_pull_all(&mut result, fields)?,
"$addToSet" => Self::apply_add_to_set(&mut result, fields)?,
"$bit" => Self::apply_bit(&mut result, fields)?,
other => {
return Err(QueryError::InvalidUpdate(format!(
"Unknown update operator: {}",
other
)));
}
}
}
Ok(result)
}
/// Apply $setOnInsert fields (used during upsert only).
pub fn apply_set_on_insert(doc: &mut Document, fields: &Document) {
for (key, value) in fields {
if key.contains('.') {
set_nested_value(doc, key, value.clone());
} else {
doc.insert(key.clone(), value.clone());
}
}
}
/// Deep clone a BSON document.
pub fn deep_clone(doc: &Document) -> Document {
doc.clone()
}
fn apply_replacement(doc: &Document, replacement: &Document) -> Result<Document, QueryError> {
let mut result = replacement.clone();
// Preserve _id
if let Some(id) = doc.get("_id") {
result.insert("_id", id.clone());
}
Ok(result)
}
fn apply_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, value) in fields {
if key.contains('.') {
set_nested_value(doc, key, value.clone());
} else {
doc.insert(key.clone(), value.clone());
}
}
Ok(())
}
fn apply_unset(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, _) in fields {
if key.contains('.') {
remove_nested_value(doc, key);
} else {
doc.remove(key);
}
}
Ok(())
}
fn apply_inc(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, inc_value) in fields {
let current = if key.contains('.') {
get_nested_value(doc, key)
} else {
doc.get(key).cloned()
};
let new_value = match (&current, inc_value) {
(Some(Bson::Int32(a)), Bson::Int32(b)) => Bson::Int32(a + b),
(Some(Bson::Int64(a)), Bson::Int64(b)) => Bson::Int64(a + b),
(Some(Bson::Int32(a)), Bson::Int64(b)) => Bson::Int64(*a as i64 + b),
(Some(Bson::Int64(a)), Bson::Int32(b)) => Bson::Int64(a + *b as i64),
(Some(Bson::Double(a)), Bson::Double(b)) => Bson::Double(a + b),
(Some(Bson::Int32(a)), Bson::Double(b)) => Bson::Double(*a as f64 + b),
(Some(Bson::Double(a)), Bson::Int32(b)) => Bson::Double(a + *b as f64),
(Some(Bson::Int64(a)), Bson::Double(b)) => Bson::Double(*a as f64 + b),
(Some(Bson::Double(a)), Bson::Int64(b)) => Bson::Double(a + *b as f64),
(None, v) => v.clone(), // treat missing as 0
_ => {
return Err(QueryError::TypeMismatch(format!(
"Cannot apply $inc to non-numeric field: {}",
key
)));
}
};
if key.contains('.') {
set_nested_value(doc, key, new_value);
} else {
doc.insert(key.clone(), new_value);
}
}
Ok(())
}
fn apply_mul(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, mul_value) in fields {
let current = if key.contains('.') {
get_nested_value(doc, key)
} else {
doc.get(key).cloned()
};
let new_value = match (&current, mul_value) {
(Some(Bson::Int32(a)), Bson::Int32(b)) => Bson::Int32(a * b),
(Some(Bson::Int64(a)), Bson::Int64(b)) => Bson::Int64(a * b),
(Some(Bson::Int32(a)), Bson::Int64(b)) => Bson::Int64(*a as i64 * b),
(Some(Bson::Int64(a)), Bson::Int32(b)) => Bson::Int64(a * *b as i64),
(Some(Bson::Double(a)), Bson::Double(b)) => Bson::Double(a * b),
(Some(Bson::Int32(a)), Bson::Double(b)) => Bson::Double(*a as f64 * b),
(Some(Bson::Double(a)), Bson::Int32(b)) => Bson::Double(a * *b as f64),
(None, _) => Bson::Int32(0), // missing field * anything = 0
_ => {
return Err(QueryError::TypeMismatch(format!(
"Cannot apply $mul to non-numeric field: {}",
key
)));
}
};
if key.contains('.') {
set_nested_value(doc, key, new_value);
} else {
doc.insert(key.clone(), new_value);
}
}
Ok(())
}
fn apply_min(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, min_value) in fields {
let current = if key.contains('.') {
get_nested_value(doc, key)
} else {
doc.get(key).cloned()
};
let should_update = match &current {
None => true,
Some(cur) => {
if let Some(ord) = QueryMatcher::bson_compare_pub(min_value, cur) {
ord == std::cmp::Ordering::Less
} else {
false
}
}
};
if should_update {
if key.contains('.') {
set_nested_value(doc, key, min_value.clone());
} else {
doc.insert(key.clone(), min_value.clone());
}
}
}
Ok(())
}
fn apply_max(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, max_value) in fields {
let current = if key.contains('.') {
get_nested_value(doc, key)
} else {
doc.get(key).cloned()
};
let should_update = match &current {
None => true,
Some(cur) => {
if let Some(ord) = QueryMatcher::bson_compare_pub(max_value, cur) {
ord == std::cmp::Ordering::Greater
} else {
false
}
}
};
if should_update {
if key.contains('.') {
set_nested_value(doc, key, max_value.clone());
} else {
doc.insert(key.clone(), max_value.clone());
}
}
}
Ok(())
}
fn apply_rename(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (old_name, new_name_bson) in fields {
let new_name = match new_name_bson {
Bson::String(s) => s.clone(),
_ => continue,
};
if let Some(value) = doc.remove(old_name) {
doc.insert(new_name, value);
}
}
Ok(())
}
fn apply_current_date(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
let now = bson::DateTime::now();
for (key, spec) in fields {
let value = match spec {
Bson::Boolean(true) => Bson::DateTime(now),
Bson::Document(d) => {
match d.get_str("$type").unwrap_or("date") {
"date" => Bson::DateTime(now),
"timestamp" => Bson::Timestamp(bson::Timestamp {
time: (now.timestamp_millis() / 1000) as u32,
increment: 0,
}),
_ => Bson::DateTime(now),
}
}
_ => continue,
};
if key.contains('.') {
set_nested_value(doc, key, value);
} else {
doc.insert(key.clone(), value);
}
}
Ok(())
}
fn apply_push(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, value) in fields {
let arr = Self::get_or_create_array(doc, key);
match value {
Bson::Document(d) if d.contains_key("$each") => {
let each = match d.get("$each") {
Some(Bson::Array(a)) => a.clone(),
_ => return Err(QueryError::InvalidUpdate("$each must be an array".into())),
};
let position = d.get("$position").and_then(|v| match v {
Bson::Int32(n) => Some(*n as usize),
Bson::Int64(n) => Some(*n as usize),
_ => None,
});
if let Some(pos) = position {
let pos = pos.min(arr.len());
for (i, item) in each.into_iter().enumerate() {
arr.insert(pos + i, item);
}
} else {
arr.extend(each);
}
// Apply $sort if present
if let Some(sort_spec) = d.get("$sort") {
Self::sort_array(arr, sort_spec);
}
// Apply $slice if present
if let Some(slice) = d.get("$slice") {
Self::slice_array(arr, slice);
}
}
_ => {
arr.push(value.clone());
}
}
}
Ok(())
}
fn apply_pop(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, direction) in fields {
if let Some(Bson::Array(arr)) = doc.get_mut(key) {
if arr.is_empty() {
continue;
}
match direction {
Bson::Int32(-1) | Bson::Int64(-1) => { arr.remove(0); }
Bson::Int32(1) | Bson::Int64(1) => { arr.pop(); }
Bson::Double(f) if *f == 1.0 => { arr.pop(); }
Bson::Double(f) if *f == -1.0 => { arr.remove(0); }
_ => { arr.pop(); }
}
}
}
Ok(())
}
fn apply_pull(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, condition) in fields {
if let Some(Bson::Array(arr)) = doc.get_mut(key) {
match condition {
Bson::Document(cond_doc) if QueryMatcher::has_operators_pub(cond_doc) => {
arr.retain(|elem| {
if let Bson::Document(elem_doc) = elem {
!QueryMatcher::matches(elem_doc, cond_doc)
} else {
// For primitive matching with operators
let wrapper = doc! { "v": elem.clone() };
let cond_wrapper = doc! { "v": condition.clone() };
!QueryMatcher::matches(&wrapper, &cond_wrapper)
}
});
}
_ => {
arr.retain(|elem| elem != condition);
}
}
}
}
Ok(())
}
fn apply_pull_all(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, values) in fields {
if let (Some(Bson::Array(arr)), Bson::Array(to_remove)) = (doc.get_mut(key), values) {
arr.retain(|elem| !to_remove.contains(elem));
}
}
Ok(())
}
fn apply_add_to_set(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, value) in fields {
let arr = Self::get_or_create_array(doc, key);
match value {
Bson::Document(d) if d.contains_key("$each") => {
if let Some(Bson::Array(each)) = d.get("$each") {
for item in each {
if !arr.contains(item) {
arr.push(item.clone());
}
}
}
}
_ => {
if !arr.contains(value) {
arr.push(value.clone());
}
}
}
}
Ok(())
}
fn apply_bit(doc: &mut Document, fields: &Document) -> Result<(), QueryError> {
for (key, ops) in fields {
let ops_doc = match ops {
Bson::Document(d) => d,
_ => continue,
};
let current = doc.get(key).cloned().unwrap_or(Bson::Int32(0));
let mut val = match &current {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => continue,
};
for (bit_op, operand) in ops_doc {
let operand_val = match operand {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => continue,
};
match bit_op.as_str() {
"and" => val &= operand_val,
"or" => val |= operand_val,
"xor" => val ^= operand_val,
_ => {}
}
}
let new_value = match &current {
Bson::Int32(_) => Bson::Int32(val as i32),
_ => Bson::Int64(val),
};
doc.insert(key.clone(), new_value);
}
Ok(())
}
// --- Helpers ---
fn get_or_create_array<'a>(doc: &'a mut Document, key: &str) -> &'a mut Vec<Bson> {
// Ensure an array exists at this key
let needs_init = match doc.get(key) {
Some(Bson::Array(_)) => false,
_ => true,
};
if needs_init {
doc.insert(key.to_string(), Bson::Array(Vec::new()));
}
match doc.get_mut(key).unwrap() {
Bson::Array(arr) => arr,
_ => unreachable!(),
}
}
fn sort_array(arr: &mut Vec<Bson>, sort_spec: &Bson) {
match sort_spec {
Bson::Int32(dir) => {
let ascending = *dir > 0;
arr.sort_by(|a, b| {
let ord = partial_cmp_bson(a, b);
if ascending { ord } else { ord.reverse() }
});
}
Bson::Document(spec) => {
arr.sort_by(|a, b| {
for (field, dir) in spec {
let ascending = match dir {
Bson::Int32(n) => *n > 0,
_ => true,
};
let a_val = if let Bson::Document(d) = a { d.get(field) } else { None };
let b_val = if let Bson::Document(d) = b { d.get(field) } else { None };
let ord = match (a_val, b_val) {
(Some(av), Some(bv)) => partial_cmp_bson(av, bv),
(Some(_), None) => std::cmp::Ordering::Greater,
(None, Some(_)) => std::cmp::Ordering::Less,
(None, None) => std::cmp::Ordering::Equal,
};
let ord = if ascending { ord } else { ord.reverse() };
if ord != std::cmp::Ordering::Equal {
return ord;
}
}
std::cmp::Ordering::Equal
});
}
_ => {}
}
}
fn slice_array(arr: &mut Vec<Bson>, slice: &Bson) {
let n = match slice {
Bson::Int32(n) => *n as i64,
Bson::Int64(n) => *n,
_ => return,
};
if n >= 0 {
arr.truncate(n as usize);
} else {
let keep = (-n) as usize;
if keep < arr.len() {
let start = arr.len() - keep;
*arr = arr[start..].to_vec();
}
}
}
}
/// Element ordering used by `$push`'s `$sort` modifier.
///
/// Same-type values use their natural order (full-precision for i64);
/// mixed Int32/Int64/Double pairs now compare numerically via f64 — the
/// previous implementation returned `Equal` for any type mix, so `$sort`
/// silently did nothing on arrays with mixed numeric types (inconsistent
/// with the cross-type comparison done in sort.rs). All other type
/// combinations still compare equal.
fn partial_cmp_bson(a: &Bson, b: &Bson) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    // Numeric widening for the cross-type fallback.
    fn num(v: &Bson) -> Option<f64> {
        match v {
            Bson::Int32(n) => Some(f64::from(*n)),
            Bson::Int64(n) => Some(*n as f64),
            Bson::Double(n) => Some(*n),
            _ => None,
        }
    }
    match (a, b) {
        (Bson::Int32(x), Bson::Int32(y)) => x.cmp(y),
        (Bson::Int64(x), Bson::Int64(y)) => x.cmp(y),
        (Bson::Double(x), Bson::Double(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
        (Bson::String(x), Bson::String(y)) => x.cmp(y),
        (Bson::Boolean(x), Bson::Boolean(y)) => x.cmp(y),
        _ => match (num(a), num(b)) {
            (Some(x), Some(y)) => x.partial_cmp(&y).unwrap_or(Ordering::Equal),
            _ => Ordering::Equal,
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: apply `update` to `base` without array filters.
    fn run(base: Document, update: Document) -> Document {
        UpdateEngine::apply_update(&base, &update, None).unwrap()
    }

    #[test]
    fn test_set() {
        let out = run(
            doc! { "_id": 1, "name": "Alice" },
            doc! { "$set": { "name": "Bob", "age": 30 } },
        );
        assert_eq!(out.get_str("name").unwrap(), "Bob");
        assert_eq!(out.get_i32("age").unwrap(), 30);
    }

    #[test]
    fn test_inc() {
        let out = run(doc! { "_id": 1, "count": 5 }, doc! { "$inc": { "count": 3 } });
        assert_eq!(out.get_i32("count").unwrap(), 8);
    }

    #[test]
    fn test_unset() {
        let out = run(
            doc! { "_id": 1, "name": "Alice", "age": 30 },
            doc! { "$unset": { "age": "" } },
        );
        assert!(out.get("age").is_none());
    }

    #[test]
    fn test_replacement() {
        let out = run(
            doc! { "_id": 1, "name": "Alice", "age": 30 },
            doc! { "name": "Bob" },
        );
        assert_eq!(out.get_i32("_id").unwrap(), 1); // _id survives replacement
        assert_eq!(out.get_str("name").unwrap(), "Bob");
        assert!(out.get("age").is_none()); // replaced away
    }

    #[test]
    fn test_push() {
        let out = run(doc! { "_id": 1, "tags": ["a"] }, doc! { "$push": { "tags": "b" } });
        assert_eq!(out.get_array("tags").unwrap().len(), 2);
    }

    #[test]
    fn test_add_to_set() {
        let out = run(
            doc! { "_id": 1, "tags": ["a", "b"] },
            doc! { "$addToSet": { "tags": "a" } },
        );
        assert_eq!(out.get_array("tags").unwrap().len(), 2); // duplicate ignored
    }
}