321 lines
12 KiB
Rust
321 lines
12 KiB
Rust
// Management submodule. NOTE(review): presumably the management IPC layer that
// reaches the oplog/storage through `RustDb::ctx()` — confirm in management.rs.
pub mod management;
|
|
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::Result;
|
|
use dashmap::DashMap;
|
|
use tokio::net::TcpListener;
|
|
#[cfg(unix)]
|
|
use tokio::net::UnixListener;
|
|
use tokio_util::codec::Framed;
|
|
use tokio_util::sync::CancellationToken;
|
|
|
|
use rustdb_config::{RustDbOptions, StorageType};
|
|
use rustdb_wire::{WireCodec, OP_QUERY};
|
|
use rustdb_wire::{encode_op_msg_response, encode_op_reply_response};
|
|
use rustdb_storage::{StorageAdapter, MemoryStorageAdapter, FileStorageAdapter, OpLog};
|
|
use rustdb_index::{IndexEngine, IndexOptions};
|
|
use rustdb_txn::{TransactionEngine, SessionEngine};
|
|
use rustdb_commands::{CommandRouter, CommandContext};
|
|
|
|
/// The main RustDb server.
|
|
pub struct RustDb {
|
|
options: RustDbOptions,
|
|
ctx: Arc<CommandContext>,
|
|
router: Arc<CommandRouter>,
|
|
cancel_token: CancellationToken,
|
|
listener_handle: Option<tokio::task::JoinHandle<()>>,
|
|
}
|
|
|
|
impl RustDb {
|
|
/// Create a new RustDb server with the given options.
|
|
pub async fn new(options: RustDbOptions) -> Result<Self> {
|
|
// Create storage adapter
|
|
let storage: Arc<dyn StorageAdapter> = match options.storage {
|
|
StorageType::Memory => {
|
|
let adapter = if let Some(ref pp) = options.persist_path {
|
|
tracing::info!("MemoryStorageAdapter with periodic persistence to {}", pp);
|
|
MemoryStorageAdapter::with_persist_path(PathBuf::from(pp))
|
|
} else {
|
|
tracing::warn!(
|
|
"SmartDB is using in-memory storage — data will NOT survive a restart. \
|
|
Set storage to 'file' for durable persistence."
|
|
);
|
|
MemoryStorageAdapter::new()
|
|
};
|
|
Arc::new(adapter)
|
|
}
|
|
StorageType::File => {
|
|
let path = options
|
|
.storage_path
|
|
.clone()
|
|
.unwrap_or_else(|| "./data".to_string());
|
|
let adapter = FileStorageAdapter::new(&path);
|
|
Arc::new(adapter)
|
|
}
|
|
};
|
|
|
|
// Initialize storage
|
|
storage.initialize().await?;
|
|
|
|
// Restore any previously persisted state (no-op for file storage and
|
|
// memory storage without a persist_path).
|
|
storage.restore().await?;
|
|
|
|
// Spawn periodic persistence task for memory storage with persist_path.
|
|
if options.storage == StorageType::Memory && options.persist_path.is_some() {
|
|
let persist_storage = storage.clone();
|
|
let interval_ms = options.persist_interval_ms;
|
|
tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(Duration::from_millis(interval_ms));
|
|
interval.tick().await; // skip the immediate first tick
|
|
loop {
|
|
interval.tick().await;
|
|
if let Err(e) = persist_storage.persist().await {
|
|
tracing::error!("Periodic persist failed: {}", e);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
let indexes: Arc<DashMap<String, IndexEngine>> = Arc::new(DashMap::new());
|
|
|
|
// Restore persisted indexes from storage.
|
|
if let Ok(databases) = storage.list_databases().await {
|
|
for db_name in &databases {
|
|
if let Ok(collections) = storage.list_collections(db_name).await {
|
|
for coll_name in &collections {
|
|
if let Ok(specs) = storage.get_indexes(db_name, coll_name).await {
|
|
let has_custom = specs.iter().any(|s| {
|
|
s.get_str("name").unwrap_or("_id_") != "_id_"
|
|
});
|
|
if !has_custom {
|
|
continue;
|
|
}
|
|
|
|
let ns_key = format!("{}.{}", db_name, coll_name);
|
|
let mut engine = IndexEngine::new();
|
|
|
|
for spec in &specs {
|
|
let name = spec.get_str("name").unwrap_or("").to_string();
|
|
if name == "_id_" {
|
|
continue; // already created by IndexEngine::new()
|
|
}
|
|
let key = match spec.get("key") {
|
|
Some(bson::Bson::Document(k)) => k.clone(),
|
|
_ => continue,
|
|
};
|
|
let unique = matches!(spec.get("unique"), Some(bson::Bson::Boolean(true)));
|
|
let sparse = matches!(spec.get("sparse"), Some(bson::Bson::Boolean(true)));
|
|
let expire_after_seconds = match spec.get("expireAfterSeconds") {
|
|
Some(bson::Bson::Int32(n)) => Some(*n as u64),
|
|
Some(bson::Bson::Int64(n)) => Some(*n as u64),
|
|
_ => None,
|
|
};
|
|
|
|
let options = IndexOptions {
|
|
name: Some(name.clone()),
|
|
unique,
|
|
sparse,
|
|
expire_after_seconds,
|
|
};
|
|
if let Err(e) = engine.create_index(key, options) {
|
|
tracing::warn!(
|
|
namespace = %ns_key,
|
|
index = %name,
|
|
error = %e,
|
|
"failed to restore index"
|
|
);
|
|
}
|
|
}
|
|
|
|
// Rebuild index data from existing documents.
|
|
if let Ok(docs) = storage.find_all(db_name, coll_name).await {
|
|
if !docs.is_empty() {
|
|
engine.rebuild_from_documents(&docs);
|
|
}
|
|
}
|
|
|
|
tracing::info!(
|
|
namespace = %ns_key,
|
|
indexes = engine.list_indexes().len(),
|
|
"restored indexes"
|
|
);
|
|
indexes.insert(ns_key, engine);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let ctx = Arc::new(CommandContext {
|
|
storage,
|
|
indexes,
|
|
transactions: Arc::new(TransactionEngine::new()),
|
|
sessions: Arc::new(SessionEngine::new(30 * 60 * 1000, 60 * 1000)),
|
|
cursors: Arc::new(DashMap::new()),
|
|
start_time: std::time::Instant::now(),
|
|
oplog: Arc::new(OpLog::new()),
|
|
});
|
|
|
|
let router = Arc::new(CommandRouter::new(ctx.clone()));
|
|
|
|
Ok(Self {
|
|
options,
|
|
ctx,
|
|
router,
|
|
cancel_token: CancellationToken::new(),
|
|
listener_handle: None,
|
|
})
|
|
}
|
|
|
|
/// Start listening for connections.
|
|
pub async fn start(&mut self) -> Result<()> {
|
|
let cancel = self.cancel_token.clone();
|
|
let router = self.router.clone();
|
|
|
|
if let Some(ref socket_path) = self.options.socket_path {
|
|
#[cfg(unix)]
|
|
{
|
|
// Remove stale socket file
|
|
let _ = tokio::fs::remove_file(socket_path).await;
|
|
|
|
let listener = UnixListener::bind(socket_path)?;
|
|
let socket_path_clone = socket_path.clone();
|
|
tracing::info!("RustDb listening on unix:{}", socket_path_clone);
|
|
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
_ = cancel.cancelled() => break,
|
|
result = listener.accept() => {
|
|
match result {
|
|
Ok((stream, _addr)) => {
|
|
let router = router.clone();
|
|
tokio::spawn(async move {
|
|
handle_connection(stream, router).await;
|
|
});
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Accept error: {}", e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
self.listener_handle = Some(handle);
|
|
}
|
|
#[cfg(not(unix))]
|
|
{
|
|
anyhow::bail!("Unix sockets are not supported on this platform");
|
|
}
|
|
} else {
|
|
let addr = format!("{}:{}", self.options.host, self.options.port);
|
|
let listener = TcpListener::bind(&addr).await?;
|
|
tracing::info!("RustDb listening on {}", addr);
|
|
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
_ = cancel.cancelled() => break,
|
|
result = listener.accept() => {
|
|
match result {
|
|
Ok((stream, _addr)) => {
|
|
let _ = stream.set_nodelay(true);
|
|
let router = router.clone();
|
|
tokio::spawn(async move {
|
|
handle_connection(stream, router).await;
|
|
});
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Accept error: {}", e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
self.listener_handle = Some(handle);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Stop the server.
|
|
pub async fn stop(&mut self) -> Result<()> {
|
|
self.cancel_token.cancel();
|
|
|
|
if let Some(handle) = self.listener_handle.take() {
|
|
handle.abort();
|
|
let _ = handle.await;
|
|
}
|
|
|
|
// Close storage (persists if configured)
|
|
self.ctx.storage.close().await?;
|
|
|
|
// Clean up Unix socket file
|
|
if let Some(ref socket_path) = self.options.socket_path {
|
|
let _ = tokio::fs::remove_file(socket_path).await;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Get the connection URI.
|
|
pub fn connection_uri(&self) -> String {
|
|
self.options.connection_uri()
|
|
}
|
|
|
|
/// Get a reference to the shared command context (for management IPC access to oplog, storage, etc.).
|
|
pub fn ctx(&self) -> &Arc<CommandContext> {
|
|
&self.ctx
|
|
}
|
|
}
|
|
|
|
/// Handle a single client connection using the wire protocol codec.
|
|
async fn handle_connection<S>(stream: S, router: Arc<CommandRouter>)
|
|
where
|
|
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
|
|
{
|
|
use futures_util::{SinkExt, StreamExt};
|
|
|
|
let mut framed = Framed::new(stream, WireCodec);
|
|
|
|
while let Some(result) = framed.next().await {
|
|
match result {
|
|
Ok(parsed_cmd) => {
|
|
let request_id = parsed_cmd.request_id;
|
|
let op_code = parsed_cmd.op_code;
|
|
|
|
let response_doc = router.route(&parsed_cmd).await;
|
|
|
|
let response_id = next_request_id();
|
|
|
|
let response_bytes = if op_code == OP_QUERY {
|
|
encode_op_reply_response(request_id, &[response_doc], response_id, 0)
|
|
} else {
|
|
encode_op_msg_response(request_id, &response_doc, response_id)
|
|
};
|
|
|
|
if let Err(e) = framed.send(response_bytes).await {
|
|
tracing::debug!("Failed to send response: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::debug!("Wire protocol error: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Return a process-wide id for an outgoing response, starting at 1.
///
/// Ids increase by one per call. The atomic add wraps on overflow, so after
/// `i32::MAX` ids the sequence continues from `i32::MIN`. Relaxed ordering is
/// enough because only the uniqueness of the value matters, not its ordering
/// relative to other memory operations.
fn next_request_id() -> i32 {
    use std::sync::atomic::{AtomicI32, Ordering::Relaxed};

    static NEXT_ID: AtomicI32 = AtomicI32::new(1);

    NEXT_ID.fetch_add(1, Relaxed)
}
|