use std::collections::HashSet; use std::path::PathBuf; use std::time::{SystemTime, UNIX_EPOCH}; use async_trait::async_trait; use bson::{doc, oid::ObjectId, Document}; use dashmap::DashMap; use tracing::{debug, warn}; use crate::adapter::StorageAdapter; use crate::error::{StorageError, StorageResult}; /// Per-document timestamp tracking for conflict detection. type TimestampMap = DashMap; /// db -> coll -> id_hex -> Document type DataStore = DashMap>>; /// db -> coll -> Vec type IndexStore = DashMap>>; /// db -> coll -> id_hex -> last_modified_ms type ModificationStore = DashMap>; fn now_ms() -> i64 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_millis() as i64 } /// In-memory storage adapter backed by `DashMap`. /// /// Optionally persists to a JSON file at a configured path. pub struct MemoryStorageAdapter { data: DataStore, indexes: IndexStore, modifications: ModificationStore, persist_path: Option, } impl MemoryStorageAdapter { /// Create a new purely in-memory adapter. pub fn new() -> Self { Self { data: DashMap::new(), indexes: DashMap::new(), modifications: DashMap::new(), persist_path: None, } } /// Create a new adapter that will persist state to the given JSON file. pub fn with_persist_path(path: PathBuf) -> Self { Self { data: DashMap::new(), indexes: DashMap::new(), modifications: DashMap::new(), persist_path: Some(path), } } /// Get or create the database entry in the data store. fn ensure_db(&self, db: &str) { self.data.entry(db.to_string()).or_insert_with(DashMap::new); self.indexes .entry(db.to_string()) .or_insert_with(DashMap::new); self.modifications .entry(db.to_string()) .or_insert_with(DashMap::new); } fn extract_id(doc: &Document) -> StorageResult { match doc.get("_id") { Some(bson::Bson::ObjectId(oid)) => Ok(oid.to_hex()), _ => Err(StorageError::NotFound("document missing _id".into())), } } fn record_modification(&self, db: &str, coll: &str, id: &str) { if let Some(db_mods) = self.modifications.get(db) { if let Some(coll_mods) = db_mods.get(coll) { coll_mods.insert(id.to_string(), now_ms()); } } } } #[async_trait] impl StorageAdapter for MemoryStorageAdapter { async fn initialize(&self) -> StorageResult<()> { debug!("MemoryStorageAdapter initialized"); Ok(()) } async fn close(&self) -> StorageResult<()> { // Persist if configured. self.persist().await?; debug!("MemoryStorageAdapter closed"); Ok(()) } // ---- database ---- async fn list_databases(&self) -> StorageResult> { Ok(self.data.iter().map(|e| e.key().clone()).collect()) } async fn create_database(&self, db: &str) -> StorageResult<()> { if self.data.contains_key(db) { return Err(StorageError::AlreadyExists(format!("database '{db}'"))); } self.ensure_db(db); Ok(()) } async fn drop_database(&self, db: &str) -> StorageResult<()> { self.data.remove(db); self.indexes.remove(db); self.modifications.remove(db); Ok(()) } async fn database_exists(&self, db: &str) -> StorageResult { Ok(self.data.contains_key(db)) } // ---- collection ---- async fn list_collections(&self, db: &str) -> StorageResult> { let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; Ok(db_ref.iter().map(|e| e.key().clone()).collect()) } async fn create_collection(&self, db: &str, coll: &str) -> StorageResult<()> { self.ensure_db(db); let db_ref = self.data.get(db).unwrap(); if db_ref.contains_key(coll) { return Err(StorageError::AlreadyExists(format!( "collection '{db}.{coll}'" ))); } db_ref.insert(coll.to_string(), DashMap::new()); drop(db_ref); // Create modification tracker for this collection. if let Some(db_mods) = self.modifications.get(db) { db_mods.insert(coll.to_string(), DashMap::new()); } // Auto-create _id index spec. let idx_spec = doc! { "name": "_id_", "key": { "_id": 1 } }; if let Some(db_idx) = self.indexes.get(db) { db_idx.insert(coll.to_string(), vec![idx_spec]); } Ok(()) } async fn drop_collection(&self, db: &str, coll: &str) -> StorageResult<()> { if let Some(db_ref) = self.data.get(db) { db_ref.remove(coll); } if let Some(db_idx) = self.indexes.get(db) { db_idx.remove(coll); } if let Some(db_mods) = self.modifications.get(db) { db_mods.remove(coll); } Ok(()) } async fn collection_exists(&self, db: &str, coll: &str) -> StorageResult { Ok(self .data .get(db) .map(|db_ref| db_ref.contains_key(coll)) .unwrap_or(false)) } async fn rename_collection( &self, db: &str, old_name: &str, new_name: &str, ) -> StorageResult<()> { let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; if db_ref.contains_key(new_name) { return Err(StorageError::AlreadyExists(format!( "collection '{db}.{new_name}'" ))); } let (_, coll_data) = db_ref .remove(old_name) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{old_name}'")))?; db_ref.insert(new_name.to_string(), coll_data); drop(db_ref); // Rename in indexes. if let Some(db_idx) = self.indexes.get(db) { if let Some((_, idx_data)) = db_idx.remove(old_name) { db_idx.insert(new_name.to_string(), idx_data); } } // Rename in modifications. if let Some(db_mods) = self.modifications.get(db) { if let Some((_, mod_data)) = db_mods.remove(old_name) { db_mods.insert(new_name.to_string(), mod_data); } } Ok(()) } // ---- document writes ---- async fn insert_one( &self, db: &str, coll: &str, mut doc: Document, ) -> StorageResult { // Ensure _id exists. if !doc.contains_key("_id") { doc.insert("_id", ObjectId::new()); } let id = Self::extract_id(&doc)?; let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let coll_ref = db_ref .get(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; if coll_ref.contains_key(&id) { return Err(StorageError::AlreadyExists(format!("document '{id}'"))); } coll_ref.insert(id.clone(), doc); drop(coll_ref); drop(db_ref); self.record_modification(db, coll, &id); Ok(id) } async fn insert_many( &self, db: &str, coll: &str, docs: Vec, ) -> StorageResult> { let mut ids = Vec::with_capacity(docs.len()); for doc in docs { let id = self.insert_one(db, coll, doc).await?; ids.push(id); } Ok(ids) } async fn update_by_id( &self, db: &str, coll: &str, id: &str, doc: Document, ) -> StorageResult<()> { let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let coll_ref = db_ref .get(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; if !coll_ref.contains_key(id) { return Err(StorageError::NotFound(format!("document '{id}'"))); } coll_ref.insert(id.to_string(), doc); drop(coll_ref); drop(db_ref); self.record_modification(db, coll, id); Ok(()) } async fn delete_by_id( &self, db: &str, coll: &str, id: &str, ) -> StorageResult<()> { let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let coll_ref = db_ref .get(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; coll_ref .remove(id) .ok_or_else(|| StorageError::NotFound(format!("document '{id}'")))?; drop(coll_ref); drop(db_ref); self.record_modification(db, coll, id); Ok(()) } async fn delete_by_ids( &self, db: &str, coll: &str, ids: &[String], ) -> StorageResult<()> { for id in ids { self.delete_by_id(db, coll, id).await?; } Ok(()) } // ---- document reads ---- async fn find_all( &self, db: &str, coll: &str, ) -> StorageResult> { let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let coll_ref = db_ref .get(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; Ok(coll_ref.iter().map(|e| e.value().clone()).collect()) } async fn find_by_ids( &self, db: &str, coll: &str, ids: HashSet, ) -> StorageResult> { let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let coll_ref = db_ref .get(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; let mut results = Vec::with_capacity(ids.len()); for id in &ids { if let Some(doc) = coll_ref.get(id) { results.push(doc.value().clone()); } } Ok(results) } async fn find_by_id( &self, db: &str, coll: &str, id: &str, ) -> StorageResult> { let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let coll_ref = db_ref .get(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; Ok(coll_ref.get(id).map(|e| e.value().clone())) } async fn count( &self, db: &str, coll: &str, ) -> StorageResult { let db_ref = self .data .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let coll_ref = db_ref .get(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; Ok(coll_ref.len() as u64) } // ---- indexes ---- async fn save_index( &self, db: &str, coll: &str, name: &str, spec: Document, ) -> StorageResult<()> { let db_idx = self .indexes .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let mut specs = db_idx .get_mut(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; // Remove existing index with same name, then add. specs.retain(|s| s.get_str("name").unwrap_or("") != name); let mut full_spec = spec; full_spec.insert("name", name); specs.push(full_spec); Ok(()) } async fn get_indexes( &self, db: &str, coll: &str, ) -> StorageResult> { let db_idx = self .indexes .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let specs = db_idx .get(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; Ok(specs.clone()) } async fn drop_index( &self, db: &str, coll: &str, name: &str, ) -> StorageResult<()> { let db_idx = self .indexes .get(db) .ok_or_else(|| StorageError::NotFound(format!("database '{db}'")))?; let mut specs = db_idx .get_mut(coll) .ok_or_else(|| StorageError::NotFound(format!("collection '{db}.{coll}'")))?; let before = specs.len(); specs.retain(|s| s.get_str("name").unwrap_or("") != name); if specs.len() == before { return Err(StorageError::NotFound(format!("index '{name}'"))); } Ok(()) } // ---- snapshot / conflict detection ---- async fn create_snapshot( &self, _db: &str, _coll: &str, ) -> StorageResult { Ok(now_ms()) } async fn has_conflicts( &self, db: &str, coll: &str, ids: &HashSet, snapshot_time: i64, ) -> StorageResult { if let Some(db_mods) = self.modifications.get(db) { if let Some(coll_mods) = db_mods.get(coll) { for id in ids { if let Some(ts) = coll_mods.get(id) { if *ts.value() > snapshot_time { return Ok(true); } } } } } Ok(false) } // ---- persistence ---- async fn persist(&self) -> StorageResult<()> { let path = match &self.persist_path { Some(p) => p, None => return Ok(()), }; // Serialize the entire data store to JSON. let mut db_map = serde_json::Map::new(); for db_entry in self.data.iter() { let db_name = db_entry.key().clone(); let mut coll_map = serde_json::Map::new(); for coll_entry in db_entry.value().iter() { let coll_name = coll_entry.key().clone(); let mut docs_map = serde_json::Map::new(); for doc_entry in coll_entry.value().iter() { let id = doc_entry.key().clone(); // Convert bson::Document -> serde_json::Value via bson's // built-in extended-JSON serialization. let json_val: serde_json::Value = bson::to_bson(doc_entry.value()) .map_err(|e| StorageError::SerializationError(e.to_string())) .and_then(|b| { serde_json::to_value(&b) .map_err(|e| StorageError::SerializationError(e.to_string())) })?; docs_map.insert(id, json_val); } coll_map.insert(coll_name, serde_json::Value::Object(docs_map)); } db_map.insert(db_name, serde_json::Value::Object(coll_map)); } let json = serde_json::to_string_pretty(&serde_json::Value::Object(db_map))?; if let Some(parent) = path.parent() { tokio::fs::create_dir_all(parent).await?; } tokio::fs::write(path, json).await?; debug!("MemoryStorageAdapter persisted to {:?}", path); Ok(()) } async fn restore(&self) -> StorageResult<()> { let path = match &self.persist_path { Some(p) => p, None => return Ok(()), }; if !path.exists() { warn!("persist file not found at {:?}, skipping restore", path); return Ok(()); } let json = tokio::fs::read_to_string(path).await?; let root: serde_json::Value = serde_json::from_str(&json)?; let root_obj = root .as_object() .ok_or_else(|| StorageError::SerializationError("expected object".into()))?; self.data.clear(); self.indexes.clear(); self.modifications.clear(); for (db_name, colls_val) in root_obj { self.ensure_db(db_name); let db_ref = self.data.get(db_name).unwrap(); let colls = colls_val .as_object() .ok_or_else(|| StorageError::SerializationError("expected object".into()))?; for (coll_name, docs_val) in colls { let coll_map: DashMap = DashMap::new(); let docs = docs_val .as_object() .ok_or_else(|| StorageError::SerializationError("expected object".into()))?; for (id, doc_val) in docs { let bson_val: bson::Bson = serde_json::from_value(doc_val.clone()) .map_err(|e| StorageError::SerializationError(e.to_string()))?; let doc = bson_val .as_document() .ok_or_else(|| { StorageError::SerializationError("expected document".into()) })? .clone(); coll_map.insert(id.clone(), doc); } db_ref.insert(coll_name.clone(), coll_map); // Restore modification tracker and default _id index. if let Some(db_mods) = self.modifications.get(db_name) { db_mods.insert(coll_name.clone(), DashMap::new()); } if let Some(db_idx) = self.indexes.get(db_name) { let idx_spec = doc! { "name": "_id_", "key": { "_id": 1 } }; db_idx.insert(coll_name.clone(), vec![idx_spec]); } } } debug!("MemoryStorageAdapter restored from {:?}", path); Ok(()) } } impl Default for MemoryStorageAdapter { fn default() -> Self { Self::new() } }