pub mod management; use std::sync::Arc; use anyhow::Result; use dashmap::DashMap; use tokio::net::TcpListener; #[cfg(unix)] use tokio::net::UnixListener; use tokio_util::codec::Framed; use tokio_util::sync::CancellationToken; use rustdb_config::{RustDbOptions, StorageType}; use rustdb_wire::{WireCodec, OP_QUERY}; use rustdb_wire::{encode_op_msg_response, encode_op_reply_response}; use rustdb_storage::{StorageAdapter, MemoryStorageAdapter, FileStorageAdapter, OpLog}; // IndexEngine is used indirectly via CommandContext use rustdb_txn::{TransactionEngine, SessionEngine}; use rustdb_commands::{CommandRouter, CommandContext}; /// The main RustDb server. pub struct RustDb { options: RustDbOptions, ctx: Arc, router: Arc, cancel_token: CancellationToken, listener_handle: Option>, } impl RustDb { /// Create a new RustDb server with the given options. pub async fn new(options: RustDbOptions) -> Result { // Create storage adapter let storage: Arc = match options.storage { StorageType::Memory => { let adapter = MemoryStorageAdapter::new(); Arc::new(adapter) } StorageType::File => { let path = options .storage_path .clone() .unwrap_or_else(|| "./data".to_string()); let adapter = FileStorageAdapter::new(&path); Arc::new(adapter) } }; // Initialize storage storage.initialize().await?; let ctx = Arc::new(CommandContext { storage, indexes: Arc::new(DashMap::new()), transactions: Arc::new(TransactionEngine::new()), sessions: Arc::new(SessionEngine::new(30 * 60 * 1000, 60 * 1000)), cursors: Arc::new(DashMap::new()), start_time: std::time::Instant::now(), oplog: Arc::new(OpLog::new()), }); let router = Arc::new(CommandRouter::new(ctx.clone())); Ok(Self { options, ctx, router, cancel_token: CancellationToken::new(), listener_handle: None, }) } /// Start listening for connections. pub async fn start(&mut self) -> Result<()> { let cancel = self.cancel_token.clone(); let router = self.router.clone(); if let Some(ref socket_path) = self.options.socket_path { #[cfg(unix)] { // Remove stale socket file let _ = tokio::fs::remove_file(socket_path).await; let listener = UnixListener::bind(socket_path)?; let socket_path_clone = socket_path.clone(); tracing::info!("RustDb listening on unix:{}", socket_path_clone); let handle = tokio::spawn(async move { loop { tokio::select! { _ = cancel.cancelled() => break, result = listener.accept() => { match result { Ok((stream, _addr)) => { let router = router.clone(); tokio::spawn(async move { handle_connection(stream, router).await; }); } Err(e) => { tracing::error!("Accept error: {}", e); } } } } } }); self.listener_handle = Some(handle); } #[cfg(not(unix))] { anyhow::bail!("Unix sockets are not supported on this platform"); } } else { let addr = format!("{}:{}", self.options.host, self.options.port); let listener = TcpListener::bind(&addr).await?; tracing::info!("RustDb listening on {}", addr); let handle = tokio::spawn(async move { loop { tokio::select! { _ = cancel.cancelled() => break, result = listener.accept() => { match result { Ok((stream, _addr)) => { let _ = stream.set_nodelay(true); let router = router.clone(); tokio::spawn(async move { handle_connection(stream, router).await; }); } Err(e) => { tracing::error!("Accept error: {}", e); } } } } } }); self.listener_handle = Some(handle); } Ok(()) } /// Stop the server. pub async fn stop(&mut self) -> Result<()> { self.cancel_token.cancel(); if let Some(handle) = self.listener_handle.take() { handle.abort(); let _ = handle.await; } // Close storage (persists if configured) self.ctx.storage.close().await?; // Clean up Unix socket file if let Some(ref socket_path) = self.options.socket_path { let _ = tokio::fs::remove_file(socket_path).await; } Ok(()) } /// Get the connection URI. pub fn connection_uri(&self) -> String { self.options.connection_uri() } /// Get a reference to the shared command context (for management IPC access to oplog, storage, etc.). pub fn ctx(&self) -> &Arc { &self.ctx } } /// Handle a single client connection using the wire protocol codec. async fn handle_connection(stream: S, router: Arc) where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { use futures_util::{SinkExt, StreamExt}; let mut framed = Framed::new(stream, WireCodec); while let Some(result) = framed.next().await { match result { Ok(parsed_cmd) => { let request_id = parsed_cmd.request_id; let op_code = parsed_cmd.op_code; let response_doc = router.route(&parsed_cmd).await; let response_id = next_request_id(); let response_bytes = if op_code == OP_QUERY { encode_op_reply_response(request_id, &[response_doc], response_id, 0) } else { encode_op_msg_response(request_id, &response_doc, response_id) }; if let Err(e) = framed.send(response_bytes).await { tracing::debug!("Failed to send response: {}", e); break; } } Err(e) => { tracing::debug!("Wire protocol error: {}", e); break; } } } } fn next_request_id() -> i32 { use std::sync::atomic::{AtomicI32, Ordering}; static COUNTER: AtomicI32 = AtomicI32::new(1); COUNTER.fetch_add(1, Ordering::Relaxed) }