pub mod management; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use anyhow::Result; use dashmap::DashMap; use tokio::net::TcpListener; #[cfg(unix)] use tokio::net::UnixListener; use tokio_util::codec::Framed; use tokio_util::sync::CancellationToken; use rustdb_config::{RustDbOptions, StorageType}; use rustdb_wire::{WireCodec, OP_QUERY}; use rustdb_wire::{encode_op_msg_response, encode_op_reply_response}; use rustdb_storage::{StorageAdapter, MemoryStorageAdapter, FileStorageAdapter, OpLog}; use rustdb_index::{IndexEngine, IndexOptions}; use rustdb_txn::{TransactionEngine, SessionEngine}; use rustdb_commands::{CommandRouter, CommandContext}; /// The main RustDb server. pub struct RustDb { options: RustDbOptions, ctx: Arc, router: Arc, cancel_token: CancellationToken, listener_handle: Option>, } impl RustDb { /// Create a new RustDb server with the given options. pub async fn new(options: RustDbOptions) -> Result { // Create storage adapter let storage: Arc = match options.storage { StorageType::Memory => { let adapter = if let Some(ref pp) = options.persist_path { tracing::info!("MemoryStorageAdapter with periodic persistence to {}", pp); MemoryStorageAdapter::with_persist_path(PathBuf::from(pp)) } else { tracing::warn!( "SmartDB is using in-memory storage — data will NOT survive a restart. \ Set storage to 'file' for durable persistence." ); MemoryStorageAdapter::new() }; Arc::new(adapter) } StorageType::File => { let path = options .storage_path .clone() .unwrap_or_else(|| "./data".to_string()); let adapter = FileStorageAdapter::new(&path); Arc::new(adapter) } }; // Initialize storage storage.initialize().await?; // Restore any previously persisted state (no-op for file storage and // memory storage without a persist_path). storage.restore().await?; // Spawn periodic persistence task for memory storage with persist_path. if options.storage == StorageType::Memory && options.persist_path.is_some() { let persist_storage = storage.clone(); let interval_ms = options.persist_interval_ms; tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_millis(interval_ms)); interval.tick().await; // skip the immediate first tick loop { interval.tick().await; if let Err(e) = persist_storage.persist().await { tracing::error!("Periodic persist failed: {}", e); } } }); } let indexes: Arc> = Arc::new(DashMap::new()); // Restore persisted indexes from storage. if let Ok(databases) = storage.list_databases().await { for db_name in &databases { if let Ok(collections) = storage.list_collections(db_name).await { for coll_name in &collections { if let Ok(specs) = storage.get_indexes(db_name, coll_name).await { let has_custom = specs.iter().any(|s| { s.get_str("name").unwrap_or("_id_") != "_id_" }); if !has_custom { continue; } let ns_key = format!("{}.{}", db_name, coll_name); let mut engine = IndexEngine::new(); for spec in &specs { let name = spec.get_str("name").unwrap_or("").to_string(); if name == "_id_" { continue; // already created by IndexEngine::new() } let key = match spec.get("key") { Some(bson::Bson::Document(k)) => k.clone(), _ => continue, }; let unique = matches!(spec.get("unique"), Some(bson::Bson::Boolean(true))); let sparse = matches!(spec.get("sparse"), Some(bson::Bson::Boolean(true))); let expire_after_seconds = match spec.get("expireAfterSeconds") { Some(bson::Bson::Int32(n)) => Some(*n as u64), Some(bson::Bson::Int64(n)) => Some(*n as u64), _ => None, }; let options = IndexOptions { name: Some(name.clone()), unique, sparse, expire_after_seconds, }; if let Err(e) = engine.create_index(key, options) { tracing::warn!( namespace = %ns_key, index = %name, error = %e, "failed to restore index" ); } } // Rebuild index data from existing documents. if let Ok(docs) = storage.find_all(db_name, coll_name).await { if !docs.is_empty() { engine.rebuild_from_documents(&docs); } } tracing::info!( namespace = %ns_key, indexes = engine.list_indexes().len(), "restored indexes" ); indexes.insert(ns_key, engine); } } } } } let ctx = Arc::new(CommandContext { storage, indexes, transactions: Arc::new(TransactionEngine::new()), sessions: Arc::new(SessionEngine::new(30 * 60 * 1000, 60 * 1000)), cursors: Arc::new(DashMap::new()), start_time: std::time::Instant::now(), oplog: Arc::new(OpLog::new()), }); let router = Arc::new(CommandRouter::new(ctx.clone())); Ok(Self { options, ctx, router, cancel_token: CancellationToken::new(), listener_handle: None, }) } /// Start listening for connections. pub async fn start(&mut self) -> Result<()> { let cancel = self.cancel_token.clone(); let router = self.router.clone(); if let Some(ref socket_path) = self.options.socket_path { #[cfg(unix)] { // Remove stale socket file let _ = tokio::fs::remove_file(socket_path).await; let listener = UnixListener::bind(socket_path)?; let socket_path_clone = socket_path.clone(); tracing::info!("RustDb listening on unix:{}", socket_path_clone); let handle = tokio::spawn(async move { loop { tokio::select! { _ = cancel.cancelled() => break, result = listener.accept() => { match result { Ok((stream, _addr)) => { let router = router.clone(); tokio::spawn(async move { handle_connection(stream, router).await; }); } Err(e) => { tracing::error!("Accept error: {}", e); } } } } } }); self.listener_handle = Some(handle); } #[cfg(not(unix))] { anyhow::bail!("Unix sockets are not supported on this platform"); } } else { let addr = format!("{}:{}", self.options.host, self.options.port); let listener = TcpListener::bind(&addr).await?; tracing::info!("RustDb listening on {}", addr); let handle = tokio::spawn(async move { loop { tokio::select! { _ = cancel.cancelled() => break, result = listener.accept() => { match result { Ok((stream, _addr)) => { let _ = stream.set_nodelay(true); let router = router.clone(); tokio::spawn(async move { handle_connection(stream, router).await; }); } Err(e) => { tracing::error!("Accept error: {}", e); } } } } } }); self.listener_handle = Some(handle); } Ok(()) } /// Stop the server. pub async fn stop(&mut self) -> Result<()> { self.cancel_token.cancel(); if let Some(handle) = self.listener_handle.take() { handle.abort(); let _ = handle.await; } // Close storage (persists if configured) self.ctx.storage.close().await?; // Clean up Unix socket file if let Some(ref socket_path) = self.options.socket_path { let _ = tokio::fs::remove_file(socket_path).await; } Ok(()) } /// Get the connection URI. pub fn connection_uri(&self) -> String { self.options.connection_uri() } /// Get a reference to the shared command context (for management IPC access to oplog, storage, etc.). pub fn ctx(&self) -> &Arc { &self.ctx } } /// Handle a single client connection using the wire protocol codec. async fn handle_connection(stream: S, router: Arc) where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { use futures_util::{SinkExt, StreamExt}; let mut framed = Framed::new(stream, WireCodec); while let Some(result) = framed.next().await { match result { Ok(parsed_cmd) => { let request_id = parsed_cmd.request_id; let op_code = parsed_cmd.op_code; let response_doc = router.route(&parsed_cmd).await; let response_id = next_request_id(); let response_bytes = if op_code == OP_QUERY { encode_op_reply_response(request_id, &[response_doc], response_id, 0) } else { encode_op_msg_response(request_id, &response_doc, response_id) }; if let Err(e) = framed.send(response_bytes).await { tracing::debug!("Failed to send response: {}", e); break; } } Err(e) => { tracing::debug!("Wire protocol error: {}", e); break; } } } } fn next_request_id() -> i32 { use std::sync::atomic::{AtomicI32, Ordering}; static COUNTER: AtomicI32 = AtomicI32::new(1); COUNTER.fetch_add(1, Ordering::Relaxed) }