Compare commits

..

23 Commits

Author SHA1 Message Date
jkunz 2042b345aa v2.8.0 2026-04-29 22:14:46 +00:00
jkunz b72e8ed5e7 feat(transactions): add single-node transaction support with session-aware reads, commits, aborts, and transaction metrics 2026-04-29 22:14:46 +00:00
jkunz e79fe339aa v2.7.1 2026-04-29 22:02:40 +00:00
jkunz 96ae76e70c fix(repo): no changes to commit 2026-04-29 22:02:39 +00:00
jkunz ed2c02bcf9 feat(enterprise): add auth TLS and recovery hardening 2026-04-29 22:01:43 +00:00
jkunz 2f3031cfc7 v2.7.0 2026-04-14 10:44:45 +00:00
jkunz 459adc077a feat(update): add aggregation pipeline updates and enforce immutable _id handling 2026-04-14 10:44:45 +00:00
jkunz 19f18ef480 v2.6.2 2026-04-05 12:42:54 +00:00
jkunz 6148b28cba fix(readme): align architecture diagram formatting in the documentation 2026-04-05 12:42:54 +00:00
jkunz 012632111e v2.6.1 2026-04-05 12:41:39 +00:00
jkunz b9a59a8649 fix(readme): correct ASCII diagram spacing in architecture overview 2026-04-05 12:41:39 +00:00
jkunz f8a8c9fdff v2.6.0 2026-04-05 12:38:46 +00:00
jkunz d37b444dd5 feat(readme): document index enforcement, storage reliability, and data integrity validation features 2026-04-05 12:38:46 +00:00
jkunz 02ad9a29a7 v2.5.9 2026-04-05 09:54:56 +00:00
jkunz 24c504518d fix(rustdb-storage): run collection compaction during file storage initialization after crashes 2026-04-05 09:54:56 +00:00
jkunz 92f07ef3d7 v2.5.8 2026-04-05 09:48:10 +00:00
jkunz 22e010c554 fix(rustdb-storage): detect stale hint files using data file size metadata and add restart persistence regression tests 2026-04-05 09:48:10 +00:00
jkunz 8ebc1bb9e1 v2.5.7 2026-04-05 03:54:13 +00:00
jkunz 3fc21dcd99 fix(repo): no changes to commit 2026-04-05 03:54:13 +00:00
jkunz ad5e0e8a72 chore: gitignore generated bundled.ts to fix release cycle 2026-04-05 03:54:05 +00:00
jkunz c384df20ce v2.5.6 2026-04-05 03:53:29 +00:00
jkunz 4e944f3d05 fix(repo): no changes to commit 2026-04-05 03:53:29 +00:00
jkunz e0455daa2e chore: rebuild bundled debug server with current version 2026-04-05 03:53:22 +00:00
48 changed files with 4498 additions and 268 deletions
+3
View File
@@ -13,5 +13,8 @@ rust/target/
package-lock.json
yarn.lock
# generated bundle (rebuilt on every build, embeds version)
ts_debugserver/bundled.ts
# playwright
.playwright-mcp/
+58
View File
@@ -1,5 +1,63 @@
# Changelog
## 2026-04-29 - 2.8.0 - feat(transactions)
add single-node transaction support with session-aware reads, commits, aborts, and transaction metrics
- Buffer insert, update, delete, find, count, distinct, and findAndModify operations inside driver sessions and apply them on commit with write-conflict checks
- Return MongoDB-compatible NoSuchTransaction and WriteConflict errors for transaction lifecycle failures
- Expose authenticated users in connectionStatus and add session, transaction, auth, and oplog data to serverStatus and management metrics
- Document transaction support and extend bridge metrics typings and integration tests accordingly
## 2026-04-29 - 2.7.1 - fix(repo)
no changes to commit
## 2026-04-14 - 2.7.0 - feat(update)
add aggregation pipeline updates and enforce immutable _id handling
- support aggregation pipeline syntax in update and findOneAndUpdate operations, including upserts
- add $unset stage support for aggregation-based document transformations
- return an ImmutableField error when updates attempt to change _id and preserve _id when omitted from replacements
## 2026-04-05 - 2.6.2 - fix(readme)
align architecture diagram formatting in the documentation
- Adjusts spacing and box alignment in the README architecture diagram for clearer presentation.
## 2026-04-05 - 2.6.1 - fix(readme)
correct ASCII diagram spacing in architecture overview
- Adjusts alignment in the README architecture diagram for clearer visual formatting.
## 2026-04-05 - 2.6.0 - feat(readme)
document index enforcement, storage reliability, and data integrity validation features
- Add documentation for engine-level unique index enforcement and duplicate key behavior
- Describe storage engine reliability features including WAL, CRC32 checks, compaction, hint file staleness detection, and stale socket cleanup
- Add usage documentation for the offline data integrity validation CLI
## 2026-04-05 - 2.5.9 - fix(rustdb-storage)
run collection compaction during file storage initialization after crashes
- Triggers compaction for all loaded collections before starting the periodic background compaction task.
- Helps clean up dead weight left from before a crash during startup.
## 2026-04-05 - 2.5.8 - fix(rustdb-storage)
detect stale hint files using data file size metadata and add restart persistence regression tests
- Store the current data.rdb size in hint file headers and validate it on load to rebuild KeyDir when hints are stale or written in the old format.
- Persist updated hint metadata after compaction and shutdown to avoid missing appended tombstones after restart.
- Add validation reporting for stale hint files based on recorded versus actual data file size.
- Add regression tests covering delete persistence across restarts, missing hint recovery, stale socket cleanup, and unique index enforcement persistence.
## 2026-04-05 - 2.5.7 - fix(repo)
no changes to commit
## 2026-04-05 - 2.5.6 - fix(repo)
no changes to commit
## 2026-04-05 - 2.5.5 - fix(repo)
no changes to commit
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "@push.rocks/smartdb",
"version": "2.5.5",
"version": "2.8.0",
"private": false,
"description": "A MongoDB-compatible embedded database server with wire protocol support, backed by a high-performance Rust engine.",
"exports": {
+124 -29
View File
@@ -44,38 +44,38 @@ SmartDB uses a **sidecar binary** pattern — TypeScript handles lifecycle, Rust
```
┌──────────────────────────────────────────────────────────────┐
│ Your Application
│ (TypeScript / Node.js)
│ ┌───────────────── ┌───────────────────────────┐
│ │ SmartdbServer │────▶│ RustDbBridge (IPC) │
│ │ or LocalSmartDb │ │ @push.rocks/smartrust │
│ └───────────────── └───────────┬───────────────┘
└─────────────────────────────────────────────────────────────┘
│ spawn + JSON IPC
│ Your Application │
│ (TypeScript / Node.js) │
│ ┌──────────────────┐ ┌───────────────────────────┐ │
│ │ SmartdbServer │────▶│ RustDbBridge (IPC) │ │
│ │ or LocalSmartDb │ │ @push.rocks/smartrust │ │
│ └──────────────────┘ └───────────┬───────────────┘ │
└─────────────────────────────────────────────────────────────┘
│ spawn + JSON IPC
┌──────────────────────────────────────────────────────────────┐
│ rustdb binary 🦀
│ rustdb binary
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌───────────────┐ │
│ │ Wire Protocol│→ │Command Router│→ │ Handlers │ │
│ │ (OP_MSG) │ │ (40+ cmds) │ │ Find,Insert.. │ │
│ └──────────────┘ └──────────────┘ └───────┬───────┘ │
│ │
│ ┌─────────┐ ┌────────┐ ┌───────────┐ ┌──────┴──────┐ │
│ │ Query │ │ Update │ │Aggregation│ │ Index │ │
│ │ Matcher │ │ Engine │ │ Engine │ │ Engine │ │
│ └─────────┘ └────────┘ └───────────┘ └─────────────┘ │
│ ┌──────────────┐ ┌──────────────┐ ┌───────────────┐
│ │ Wire Protocol│→ │Command Router│→ │ Handlers │
│ │ (OP_MSG) │ │ (40+ cmds) │ │ Find,Insert.. │
│ └──────────────┘ └──────────────┘ └───────┬───────┘
│ │
│ ┌─────────┐ ┌────────┐ ┌───────────┐ ┌──────┴──────┐
│ │ Query │ │ Update │ │Aggregation│ │ Index │
│ │ Matcher │ │ Engine │ │ Engine │ │ Engine │
│ └─────────┘ └────────┘ └───────────┘ └─────────────┘
│ │
│ ┌──────────────────┐ ┌──────────────────┐ ┌──────────┐ │
│ │ MemoryStorage │ │ FileStorage │ │ OpLog │ │
│ └──────────────────┘ └──────────────────┘ └──────────┘ │
│ ┌──────────────────┐ ┌──────────────────┐ ┌──────────┐
│ │ MemoryStorage │ │ FileStorage │ │ OpLog │
│ └──────────────────┘ └──────────────────┘ └──────────┘
└──────────────────────────────────────────────────────────────┘
│ TCP / Unix Socket (wire protocol)
┌─────────────┴────────────────────────────────────────────────┐
│ MongoClient (mongodb npm driver)
│ Connects directly to Rust binary
│ MongoClient (mongodb npm driver) │
│ Connects directly to Rust binary │
└──────────────────────────────────────────────────────────────┘
```
@@ -159,7 +159,7 @@ The debug dashboard gives you:
## 📝 Operation Log & Point-in-Time Revert
Every write operation (insert, update, delete) is automatically recorded in an in-memory **operation log (OpLog)** with full before/after document snapshots. This enables:
Every write operation (insert, update, delete) is automatically recorded in an in-memory **operation log (OpLog)** with full before/after document snapshots. The OpLog lives in RAM and resets on restart — it covers the current session only. This enables:
- **Change tracking** — see exactly what changed, when, and in which collection
- **Field-level diffs** — compare previous and new document states
@@ -248,6 +248,64 @@ const server = new SmartdbServer({
persistPath: './data/snapshot.json',
persistIntervalMs: 30000, // Save every 30s
});
// TLS transport for TCP mode
const tlsServer = new SmartdbServer({
port: 27017,
tls: {
enabled: true,
certPath: './certs/server.pem',
keyPath: './certs/server.key',
// caPath: './certs/client-ca.pem',
// requireClientCert: true, // Enables mTLS client certificate checks
},
});
// SCRAM-SHA-256 authentication
const secureServer = new SmartdbServer({
port: 27017,
auth: {
enabled: true,
usersPath: './data/smartdb-users.json', // Optional: persists derived SCRAM credentials
users: [
{
username: 'root',
password: 'change-me',
database: 'admin',
roles: ['root'],
},
],
},
});
```
When `auth.enabled` is true, protected commands require successful SCRAM-SHA-256 authentication through the official MongoDB driver:
```typescript
const client = new MongoClient('mongodb://root:change-me@127.0.0.1:27017/admin?authSource=admin', {
directConnection: true,
});
await client.connect();
```
TLS is available for TCP listeners. `getConnectionUri()` includes `?tls=true` when TLS is enabled; pass the trusted CA to the MongoDB driver with `tlsCAFile`, `ca`, or `secureContext`.
Authentication verifies SCRAM credentials, denies unauthenticated commands, and enforces command-level built-in roles for supported operations. `connectionStatus` reports the authenticated users and roles for the current socket.
Supported built-in role names are `root`, `read`, `readWrite`, `dbAdmin`, `userAdmin`, `clusterMonitor`, plus `readAnyDatabase`, `readWriteAnyDatabase`, `dbAdminAnyDatabase`, and `userAdminAnyDatabase`. When `usersPath` is set, SmartDB persists SCRAM credential material atomically and does not store plaintext passwords.
Single-node transactions are supported through official MongoDB driver sessions. Writes with `startTransaction` and `autocommit: false` are buffered per logical session, reads inside the transaction see the buffered overlay, `commitTransaction` applies the write set with conflict checks, and `abortTransaction` discards it.
Basic user management commands are available for authenticated users with `root` or `userAdmin` privileges:
```typescript
await client.db('admin').command({
createUser: 'reader',
pwd: 'readpass',
roles: [{ role: 'read', db: 'myapp' }],
});
await client.db('admin').command({ usersInfo: 'reader' });
```
#### Methods & Properties
@@ -261,7 +319,7 @@ const server = new SmartdbServer({
| `port` | `number` | Configured port (TCP mode) |
| `host` | `string` | Configured host (TCP mode) |
| `socketPath` | `string \| undefined` | Socket path (socket mode) |
| `getMetrics()` | `Promise<ISmartDbMetrics>` | Server metrics (db/collection counts, uptime) |
| `getMetrics()` | `Promise<ISmartDbMetrics>` | Server metrics (db/collection counts, sessions, transactions, auth, uptime) |
| `getOpLog(params?)` | `Promise<IOpLogResult>` | Query oplog entries with optional filters |
| `getOpLogStats()` | `Promise<IOpLogStats>` | Aggregate oplog statistics |
| `revertToSeq(seq, dryRun?)` | `Promise<IRevertResult>` | Revert to a specific oplog sequence |
@@ -429,6 +487,8 @@ await collection.dropIndex('email_1');
await collection.dropIndexes(); // drop all except _id
```
> 🛡️ **Unique indexes are enforced at the engine level.** Duplicate values are rejected with a `DuplicateKey` error (code 11000) *before* the document is written to disk — on `insertOne`, `updateOne`, `findAndModify`, and upserts. Index definitions are persisted to `indexes.json` and automatically restored on restart.
### Database & Admin
```typescript
@@ -473,7 +533,7 @@ const names = await collection.distinct('name');
| **Aggregation** | `aggregate`, `count`, `distinct` |
| **Indexes** | `createIndexes`, `dropIndexes`, `listIndexes` |
| **Sessions** | `startSession`, `endSessions` |
| **Transactions** | `commitTransaction`, `abortTransaction` |
| **Transactions** | `startTransaction`, `commitTransaction`, `abortTransaction` through driver sessions |
| **Admin** | `ping`, `listDatabases`, `listCollections`, `drop`, `dropDatabase`, `create`, `serverStatus`, `buildInfo`, `dbStats`, `collStats`, `connectionStatus`, `currentOp`, `renameCollection` |
Compatible with wire protocol versions 021 (driver versions 3.6 through 7.0).
@@ -482,7 +542,7 @@ Compatible with wire protocol versions 021 (driver versions 3.6 through 7.0).
## Rust Crate Architecture 🦀
The Rust engine is organized as a Cargo workspace with 8 focused crates:
The Rust engine is organized as a Cargo workspace with 9 focused crates:
| Crate | Purpose |
|---|---|
@@ -493,10 +553,45 @@ The Rust engine is organized as a Cargo workspace with 8 focused crates:
| `rustdb-storage` | Storage backends (memory, file), OpLog with point-in-time replay |
| `rustdb-index` | B-tree/hash indexes, query planner (IXSCAN/COLLSCAN) |
| `rustdb-txn` | Transaction + session management with snapshot isolation |
| `rustdb-auth` | SCRAM-SHA-256 credential handling, user metadata persistence, RBAC checks |
| `rustdb-commands` | 40+ command handlers wiring everything together |
Cross-compiled for `linux_amd64` and `linux_arm64` via [@git.zone/tsrust](https://www.npmjs.com/package/@git.zone/tsrust).
### Storage Engine Reliability 🔒
The Bitcask-style file storage engine includes several reliability features:
- **Write-ahead log (WAL)** — every write is logged before being applied, with crash recovery on restart
- **CRC32 checksums** — every record is integrity-checked on read
- **Automatic compaction** — dead records are reclaimed when they exceed 50% of file size, runs on startup and after every write
- **Hint file staleness detection** — the hint file records the data file size at write time; if data.rdb changed since (e.g. crash after a delete), the engine falls back to a full scan to ensure tombstones are not lost
- **Torn-tail repair** — startup scans `data.rdb` to the last valid record, truncates invalid trailing bytes, and preserves all verified records after interrupted writes
- **Stale socket cleanup** — orphaned `/tmp/smartdb-*.sock` files from crashed instances are automatically cleaned up on startup
### Data Integrity CLI 🔍
The Rust binary includes an offline integrity checker:
```bash
# Check all collections in a data directory
./dist_rust/rustdb_linux_amd64 --validate-data /path/to/data
# Output:
# === SmartDB Data Integrity Report ===
#
# Database: mydb
# Collection: users
# Header: OK
# Records: 1,234 (1,200 live, 34 tombstones)
# Data size: 2.1 MB
# Duplicates: 0
# CRC errors: 0
# Hint file: OK
```
Checks file headers, record CRC32 checksums, duplicate `_id` entries, and hint file consistency. Exit code 1 if any errors are found.
---
## Testing Example
@@ -541,7 +636,7 @@ export default tap.start();
## License and Legal Information
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./LICENSE) file.
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [license](./license) file.
**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.
+307 -13
View File
@@ -60,7 +60,7 @@ version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
@@ -71,7 +71,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
@@ -124,6 +124,15 @@ dependencies = [
"wyz",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "bson"
version = "2.15.0"
@@ -139,7 +148,7 @@ dependencies = [
"indexmap",
"js-sys",
"once_cell",
"rand",
"rand 0.9.2",
"serde",
"serde_bytes",
"serde_json",
@@ -221,6 +230,15 @@ version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
[[package]]
name = "cpufeatures"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
dependencies = [
"libc",
]
[[package]]
name = "crc32fast"
version = "1.5.0"
@@ -236,6 +254,16 @@ version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crypto-common"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "dashmap"
version = "6.1.0"
@@ -259,6 +287,17 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
name = "equivalent"
version = "1.0.2"
@@ -272,7 +311,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
@@ -342,6 +381,16 @@ dependencies = [
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getrandom"
version = "0.2.17"
@@ -415,6 +464,15 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "id-arena"
version = "2.3.0"
@@ -536,7 +594,7 @@ checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc"
dependencies = [
"libc",
"wasi",
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
@@ -545,7 +603,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
@@ -589,6 +647,16 @@ dependencies = [
"windows-link",
]
[[package]]
name = "pbkdf2"
version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2"
dependencies = [
"digest",
"hmac",
]
[[package]]
name = "pin-project-lite"
version = "0.2.17"
@@ -656,14 +724,35 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rand"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a"
dependencies = [
"libc",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha",
"rand_core",
"rand_chacha 0.9.0",
"rand_core 0.9.5",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core 0.6.4",
]
[[package]]
@@ -673,7 +762,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
"rand_core 0.9.5",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom 0.2.17",
]
[[package]]
@@ -723,6 +821,20 @@ version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
[[package]]
name = "ring"
version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
dependencies = [
"cc",
"cfg-if",
"getrandom 0.2.17",
"libc",
"untrusted",
"windows-sys 0.52.0",
]
[[package]]
name = "rustdb"
version = "0.1.0"
@@ -735,6 +847,7 @@ dependencies = [
"dashmap",
"futures-util",
"mimalloc",
"rustdb-auth",
"rustdb-commands",
"rustdb-config",
"rustdb-index",
@@ -742,14 +855,33 @@ dependencies = [
"rustdb-storage",
"rustdb-txn",
"rustdb-wire",
"rustls-pemfile",
"serde",
"serde_json",
"tokio",
"tokio-rustls",
"tokio-util",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "rustdb-auth"
version = "0.1.0"
dependencies = [
"base64",
"bson",
"hmac",
"pbkdf2",
"rand 0.8.6",
"rustdb-config",
"serde",
"serde_json",
"sha2",
"subtle",
"thiserror",
]
[[package]]
name = "rustdb-commands"
version = "0.1.0"
@@ -757,6 +889,7 @@ dependencies = [
"async-trait",
"bson",
"dashmap",
"rustdb-auth",
"rustdb-config",
"rustdb-index",
"rustdb-query",
@@ -858,7 +991,50 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls"
version = "0.23.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b"
dependencies = [
"once_cell",
"ring",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-pemfile"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "rustls-pki-types"
version = "1.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9"
dependencies = [
"zeroize",
]
[[package]]
name = "rustls-webpki"
version = "0.103.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]]
@@ -933,6 +1109,17 @@ dependencies = [
"zmij",
]
[[package]]
name = "sha2"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
@@ -977,7 +1164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [
"libc",
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
@@ -986,6 +1173,12 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "subtle"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
version = "2.0.117"
@@ -1013,7 +1206,7 @@ dependencies = [
"getrandom 0.4.2",
"once_cell",
"rustix",
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
@@ -1090,7 +1283,7 @@ dependencies = [
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys",
"windows-sys 0.61.2",
]
[[package]]
@@ -1104,6 +1297,16 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
dependencies = [
"rustls",
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.18"
@@ -1178,6 +1381,12 @@ dependencies = [
"tracing-log",
]
[[package]]
name = "typenum"
version = "1.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de"
[[package]]
name = "unicode-ident"
version = "1.0.24"
@@ -1190,6 +1399,12 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "utf8parse"
version = "0.2.2"
@@ -1329,6 +1544,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-sys"
version = "0.61.2"
@@ -1338,6 +1562,70 @@ dependencies = [
"windows-link",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "wit-bindgen"
version = "0.51.0"
@@ -1455,6 +1743,12 @@ dependencies = [
"syn",
]
[[package]]
name = "zeroize"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0"
[[package]]
name = "zmij"
version = "1.0.21"
+14
View File
@@ -8,6 +8,7 @@ members = [
"crates/rustdb-storage",
"crates/rustdb-index",
"crates/rustdb-txn",
"crates/rustdb-auth",
"crates/rustdb-commands",
]
@@ -51,6 +52,10 @@ dashmap = "6"
# Cancellation / utility
tokio-util = { version = "0.7", features = ["codec"] }
# TLS transport
tokio-rustls = { version = "0.26", default-features = false, features = ["ring", "tls12"] }
rustls-pemfile = "2"
# mimalloc allocator
mimalloc = "0.1"
@@ -60,6 +65,14 @@ crc32fast = "1"
# Regex for $regex operator
regex = "1"
# Auth crypto
base64 = "0.22"
hmac = "0.12"
pbkdf2 = { version = "0.12", features = ["hmac"] }
rand = "0.8"
sha2 = "0.10"
subtle = "2"
# UUID for sessions
uuid = { version = "1", features = ["v4", "serde"] }
@@ -76,4 +89,5 @@ rustdb-query = { path = "crates/rustdb-query" }
rustdb-storage = { path = "crates/rustdb-storage" }
rustdb-index = { path = "crates/rustdb-index" }
rustdb-txn = { path = "crates/rustdb-txn" }
rustdb-auth = { path = "crates/rustdb-auth" }
rustdb-commands = { path = "crates/rustdb-commands" }
+20
View File
@@ -0,0 +1,20 @@
[package]
name = "rustdb-auth"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Authentication primitives for RustDb"
[dependencies]
base64 = { workspace = true }
bson = { workspace = true }
hmac = { workspace = true }
pbkdf2 = { workspace = true }
rand = { workspace = true }
rustdb-config = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sha2 = { workspace = true }
subtle = { workspace = true }
thiserror = { workspace = true }
+572
View File
@@ -0,0 +1,572 @@
use std::collections::HashMap;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::RwLock;
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
use hmac::{Hmac, Mac};
use pbkdf2::pbkdf2_hmac;
use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use rustdb_config::{AuthOptions, AuthUserOptions};
type HmacSha256 = Hmac<Sha256>;
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("authentication is disabled")]
Disabled,
#[error("unsupported authentication mechanism: {0}")]
UnsupportedMechanism(String),
#[error("invalid SCRAM payload: {0}")]
InvalidPayload(String),
#[error("authentication failed")]
AuthenticationFailed,
#[error("unknown SASL conversation")]
UnknownConversation,
#[error("user already exists: {0}")]
UserAlreadyExists(String),
#[error("user not found: {0}")]
UserNotFound(String),
#[error("auth metadata persistence failed: {0}")]
Persistence(String),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthAction {
Read,
Write,
DbAdmin,
UserAdmin,
ClusterMonitor,
}
#[derive(Debug, Clone)]
pub struct AuthenticatedUser {
pub username: String,
pub database: String,
pub roles: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ScramCredential {
salt: Vec<u8>,
iterations: u32,
stored_key: Vec<u8>,
server_key: Vec<u8>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AuthUser {
username: String,
database: String,
roles: Vec<String>,
scram_sha256: ScramCredential,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct PersistedAuthState {
users: Vec<AuthUser>,
}
#[derive(Debug, Clone)]
pub struct ScramConversation {
user: AuthenticatedUser,
client_first_bare: String,
server_first: String,
nonce: String,
stored_key: Vec<u8>,
server_key: Vec<u8>,
}
#[derive(Debug, Clone)]
pub struct ScramStartResult {
pub payload: Vec<u8>,
pub conversation: ScramConversation,
}
#[derive(Debug, Clone)]
pub struct ScramContinueResult {
pub payload: Vec<u8>,
pub user: AuthenticatedUser,
}
#[derive(Debug)]
pub struct AuthEngine {
enabled: bool,
users: RwLock<HashMap<String, AuthUser>>,
users_path: Option<PathBuf>,
scram_iterations: u32,
}
impl AuthEngine {
pub fn from_options(options: &AuthOptions) -> Result<Self, AuthError> {
let users_path = options.users_path.as_ref().map(PathBuf::from);
let mut users = if let Some(ref path) = users_path {
load_users(path)?
} else {
HashMap::new()
};
let mut changed = false;
for user_options in &options.users {
let key = user_key(&user_options.database, &user_options.username);
if !users.contains_key(&key) {
let user = AuthUser::from_options(user_options, options.scram_iterations);
users.insert(key, user);
changed = true;
}
}
if changed {
if let Some(ref path) = users_path {
persist_users(path, &users)?;
}
}
Ok(Self {
enabled: options.enabled,
users: RwLock::new(users),
users_path,
scram_iterations: options.scram_iterations,
})
}
pub fn disabled() -> Self {
Self {
enabled: false,
users: RwLock::new(HashMap::new()),
users_path: None,
scram_iterations: 15000,
}
}
pub fn enabled(&self) -> bool {
self.enabled
}
pub fn user_count(&self) -> usize {
self.users
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.len()
}
pub fn supported_mechanisms(&self, namespace_user: &str) -> Vec<String> {
let Some((database, username)) = namespace_user.split_once('.') else {
return Vec::new();
};
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
if users.contains_key(&user_key(database, username)) {
vec![SCRAM_SHA_256.to_string()]
} else {
Vec::new()
}
}
pub fn is_authorized(
&self,
authenticated_users: &[AuthenticatedUser],
target_db: &str,
action: AuthAction,
) -> bool {
authenticated_users
.iter()
.any(|user| user.roles.iter().any(|role| role_allows(role, user, target_db, action)))
}
pub fn create_user(
&self,
database: &str,
username: &str,
password: &str,
roles: Vec<String>,
) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
if users.contains_key(&key) {
return Err(AuthError::UserAlreadyExists(format!("{database}.{username}")));
}
let options = AuthUserOptions {
username: username.to_string(),
password: password.to_string(),
database: database.to_string(),
roles,
};
users.insert(key, AuthUser::from_options(&options, self.scram_iterations));
self.persist_locked(&users)
}
pub fn drop_user(&self, database: &str, username: &str) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
if users.remove(&key).is_none() {
return Err(AuthError::UserNotFound(format!("{database}.{username}")));
}
self.persist_locked(&users)
}
pub fn update_user(
&self,
database: &str,
username: &str,
password: Option<&str>,
roles: Option<Vec<String>>,
) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
let user = users
.get_mut(&key)
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
if let Some(new_roles) = roles {
user.roles = new_roles;
}
if let Some(new_password) = password {
let options = AuthUserOptions {
username: username.to_string(),
password: new_password.to_string(),
database: database.to_string(),
roles: user.roles.clone(),
};
user.scram_sha256 = AuthUser::from_options(&options, self.scram_iterations).scram_sha256;
}
self.persist_locked(&users)
}
pub fn grant_roles(
&self,
database: &str,
username: &str,
roles: Vec<String>,
) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
let user = users
.get_mut(&key)
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
for role in roles {
if !user.roles.contains(&role) {
user.roles.push(role);
}
}
self.persist_locked(&users)
}
pub fn revoke_roles(
&self,
database: &str,
username: &str,
roles: Vec<String>,
) -> Result<(), AuthError> {
let key = user_key(database, username);
let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner());
let user = users
.get_mut(&key)
.ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?;
user.roles.retain(|role| !roles.contains(role));
self.persist_locked(&users)
}
pub fn users_info(&self, database: &str, username: Option<&str>) -> Vec<AuthenticatedUser> {
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
users
.values()
.filter(|user| user.database == database)
.filter(|user| username.map(|name| user.username == name).unwrap_or(true))
.map(AuthUser::to_authenticated_user)
.collect()
}
pub fn start_scram_sha256(
&self,
database: &str,
payload: &[u8],
) -> Result<ScramStartResult, AuthError> {
if !self.enabled {
return Err(AuthError::Disabled);
}
let message = std::str::from_utf8(payload)
.map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?;
let client_first_bare = message
.strip_prefix("n,,")
.ok_or_else(|| AuthError::InvalidPayload("expected SCRAM gs2 header 'n,,'".to_string()))?;
let attrs = parse_scram_attrs(client_first_bare);
let raw_username = attrs
.get("n")
.ok_or_else(|| AuthError::InvalidPayload("missing username".to_string()))?;
let username = decode_scram_name(raw_username);
let client_nonce = attrs
.get("r")
.ok_or_else(|| AuthError::InvalidPayload("missing client nonce".to_string()))?;
let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner());
let user = users
.get(&user_key(database, &username))
.ok_or(AuthError::AuthenticationFailed)?;
let nonce = format!("{}{}", client_nonce, secure_base64(18));
let server_first = format!(
"r={},s={},i={}",
nonce,
BASE64_STANDARD.encode(&user.scram_sha256.salt),
user.scram_sha256.iterations,
);
Ok(ScramStartResult {
payload: server_first.as_bytes().to_vec(),
conversation: ScramConversation {
user: user.to_authenticated_user(),
client_first_bare: client_first_bare.to_string(),
server_first: server_first.clone(),
nonce,
stored_key: user.scram_sha256.stored_key.clone(),
server_key: user.scram_sha256.server_key.clone(),
},
})
}
pub fn continue_scram_sha256(
&self,
conversation: ScramConversation,
payload: &[u8],
) -> Result<ScramContinueResult, AuthError> {
let message = std::str::from_utf8(payload)
.map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?;
let proof_marker = ",p=";
let proof_pos = message
.rfind(proof_marker)
.ok_or_else(|| AuthError::InvalidPayload("missing client proof".to_string()))?;
let client_final_without_proof = &message[..proof_pos];
let proof_b64 = &message[proof_pos + proof_marker.len()..];
let attrs = parse_scram_attrs(client_final_without_proof);
let nonce = attrs
.get("r")
.ok_or_else(|| AuthError::InvalidPayload("missing nonce".to_string()))?;
if nonce != &conversation.nonce {
return Err(AuthError::AuthenticationFailed);
}
let client_proof = BASE64_STANDARD
.decode(proof_b64.as_bytes())
.map_err(|_| AuthError::InvalidPayload("invalid client proof encoding".to_string()))?;
if client_proof.len() != 32 || conversation.stored_key.len() != 32 {
return Err(AuthError::AuthenticationFailed);
}
let auth_message = format!(
"{},{},{}",
conversation.client_first_bare,
conversation.server_first,
client_final_without_proof,
);
let client_signature = hmac_sha256(&conversation.stored_key, auth_message.as_bytes());
let client_key: Vec<u8> = client_proof
.iter()
.zip(client_signature.iter())
.map(|(proof_byte, signature_byte)| proof_byte ^ signature_byte)
.collect();
let computed_stored_key = Sha256::digest(&client_key).to_vec();
if computed_stored_key.ct_eq(&conversation.stored_key).unwrap_u8() != 1 {
return Err(AuthError::AuthenticationFailed);
}
let server_signature = hmac_sha256(&conversation.server_key, auth_message.as_bytes());
let server_final = format!("v={}", BASE64_STANDARD.encode(server_signature));
Ok(ScramContinueResult {
payload: server_final.as_bytes().to_vec(),
user: conversation.user,
})
}
fn persist_locked(&self, users: &HashMap<String, AuthUser>) -> Result<(), AuthError> {
if let Some(ref path) = self.users_path {
persist_users(path, users)?;
}
Ok(())
}
}
impl Default for AuthEngine {
fn default() -> Self {
Self::disabled()
}
}
impl AuthUser {
fn from_options(options: &AuthUserOptions, iterations: u32) -> Self {
let salt = secure_random(24);
let salted_password = salted_password(options.password.as_bytes(), &salt, iterations);
let client_key = hmac_sha256(&salted_password, b"Client Key");
let stored_key = Sha256::digest(&client_key).to_vec();
let server_key = hmac_sha256(&salted_password, b"Server Key");
Self {
username: options.username.clone(),
database: options.database.clone(),
roles: options.roles.clone(),
scram_sha256: ScramCredential {
salt,
iterations,
stored_key,
server_key,
},
}
}
fn to_authenticated_user(&self) -> AuthenticatedUser {
AuthenticatedUser {
username: self.username.clone(),
database: self.database.clone(),
roles: self.roles.clone(),
}
}
}
fn role_allows(role: &str, user: &AuthenticatedUser, target_db: &str, action: AuthAction) -> bool {
let (role_db, role_name) = role.split_once('.').unwrap_or(("", role));
if role_name == "root" {
return true;
}
let any_database = role_name.ends_with("AnyDatabase");
let scoped_db = if role_db.is_empty() { &user.database } else { role_db };
if !any_database && scoped_db != target_db {
return false;
}
match role_name {
"read" | "readAnyDatabase" => action == AuthAction::Read,
"readWrite" | "readWriteAnyDatabase" => {
matches!(action, AuthAction::Read | AuthAction::Write)
}
"dbAdmin" | "dbAdminAnyDatabase" => action == AuthAction::DbAdmin,
"userAdmin" | "userAdminAnyDatabase" => action == AuthAction::UserAdmin,
"clusterMonitor" => action == AuthAction::ClusterMonitor,
_ => false,
}
}
fn load_users(path: &Path) -> Result<HashMap<String, AuthUser>, AuthError> {
if !path.exists() {
return Ok(HashMap::new());
}
let data = std::fs::read_to_string(path).map_err(|e| AuthError::Persistence(e.to_string()))?;
let persisted: PersistedAuthState = serde_json::from_str(&data)
.map_err(|e| AuthError::Persistence(format!("failed to parse users file: {e}")))?;
Ok(persisted
.users
.into_iter()
.map(|user| (user_key(&user.database, &user.username), user))
.collect())
}
fn persist_users(path: &Path, users: &HashMap<String, AuthUser>) -> Result<(), AuthError> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| AuthError::Persistence(e.to_string()))?;
}
let mut user_list: Vec<AuthUser> = users.values().cloned().collect();
user_list.sort_by(|a, b| a.database.cmp(&b.database).then(a.username.cmp(&b.username)));
let payload = serde_json::to_vec_pretty(&PersistedAuthState { users: user_list })
.map_err(|e| AuthError::Persistence(e.to_string()))?;
let tmp_path = path.with_extension("tmp");
{
let mut file = std::fs::File::create(&tmp_path)
.map_err(|e| AuthError::Persistence(e.to_string()))?;
file.write_all(&payload)
.map_err(|e| AuthError::Persistence(e.to_string()))?;
file.sync_all()
.map_err(|e| AuthError::Persistence(e.to_string()))?;
}
std::fs::rename(&tmp_path, path).map_err(|e| AuthError::Persistence(e.to_string()))?;
if let Some(parent) = path.parent() {
if let Ok(dir) = std::fs::File::open(parent) {
let _ = dir.sync_all();
}
}
Ok(())
}
fn user_key(database: &str, username: &str) -> String {
format!("{}\0{}", database, username)
}
fn salted_password(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
let mut output = [0u8; 32];
pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut output);
output.to_vec()
}
fn hmac_sha256(key: &[u8], message: &[u8]) -> Vec<u8> {
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts keys of any size");
mac.update(message);
mac.finalize().into_bytes().to_vec()
}
fn secure_random(len: usize) -> Vec<u8> {
let mut bytes = vec![0u8; len];
OsRng.fill_bytes(&mut bytes);
bytes
}
fn secure_base64(len: usize) -> String {
BASE64_STANDARD.encode(secure_random(len))
}
fn parse_scram_attrs(input: &str) -> HashMap<String, String> {
let mut result = HashMap::new();
for part in input.split(',') {
if let Some((key, value)) = part.split_once('=') {
result.insert(key.to_string(), value.to_string());
}
}
result
}
fn decode_scram_name(input: &str) -> String {
input.replace("=2C", ",").replace("=3D", "=")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mechanism_lookup_returns_scram_sha256() {
let options = AuthOptions {
enabled: true,
users: vec![AuthUserOptions {
username: "root".to_string(),
password: "secret".to_string(),
database: "admin".to_string(),
roles: vec!["root".to_string()],
}],
users_path: None,
scram_iterations: 4096,
};
let engine = AuthEngine::from_options(&options).unwrap();
assert_eq!(engine.supported_mechanisms("admin.root"), vec![SCRAM_SHA_256.to_string()]);
}
#[test]
fn read_write_role_allows_read_and_write_only_on_own_db() {
let user = AuthenticatedUser {
username: "app".to_string(),
database: "appdb".to_string(),
roles: vec!["readWrite".to_string()],
};
assert!(role_allows("readWrite", &user, "appdb", AuthAction::Read));
assert!(role_allows("readWrite", &user, "appdb", AuthAction::Write));
assert!(!role_allows("readWrite", &user, "other", AuthAction::Read));
assert!(!role_allows("readWrite", &user, "appdb", AuthAction::DbAdmin));
}
}
+1
View File
@@ -22,3 +22,4 @@ rustdb-query = { workspace = true }
rustdb-storage = { workspace = true }
rustdb-index = { workspace = true }
rustdb-txn = { workspace = true }
rustdb-auth = { workspace = true }
@@ -2,6 +2,7 @@ use std::sync::Arc;
use bson::{Bson, Document};
use dashmap::DashMap;
use rustdb_auth::{AuthEngine, AuthenticatedUser, ScramConversation};
use rustdb_index::{IndexEngine, IndexOptions};
use rustdb_storage::{OpLog, StorageAdapter};
use rustdb_txn::{SessionEngine, TransactionEngine};
@@ -22,6 +23,8 @@ pub struct CommandContext {
pub start_time: std::time::Instant,
/// Operation log for point-in-time replay.
pub oplog: Arc<OpLog>,
/// Authentication engine and user store.
pub auth: Arc<AuthEngine>,
}
impl CommandContext {
@@ -85,6 +88,43 @@ impl CommandContext {
}
}
/// Per-client connection state. Authentication is socket-scoped in MongoDB.
pub struct ConnectionState {
pub authenticated_users: Vec<AuthenticatedUser>,
pub sasl_conversations: std::collections::HashMap<i32, ScramConversation>,
next_conversation_id: i32,
}
impl ConnectionState {
pub fn new() -> Self {
Self {
authenticated_users: Vec::new(),
sasl_conversations: std::collections::HashMap::new(),
next_conversation_id: 1,
}
}
pub fn is_authenticated(&self) -> bool {
!self.authenticated_users.is_empty()
}
pub fn next_conversation_id(&mut self) -> i32 {
let id = self.next_conversation_id;
self.next_conversation_id += 1;
id
}
pub fn authenticate(&mut self, user: AuthenticatedUser) {
self.authenticated_users.push(user);
}
}
impl Default for ConnectionState {
fn default() -> Self {
Self::new()
}
}
/// State of an open cursor from a find or aggregate command.
pub struct CursorState {
/// Documents remaining to be returned.
+33 -1
View File
@@ -18,6 +18,12 @@ pub enum CommandError {
#[error("transaction error: {0}")]
TransactionError(String),
#[error("no such transaction: {0}")]
NoSuchTransaction(String),
#[error("write conflict: {0}")]
WriteConflict(String),
#[error("namespace not found: {0}")]
NamespaceNotFound(String),
@@ -27,6 +33,18 @@ pub enum CommandError {
#[error("duplicate key: {0}")]
DuplicateKey(String),
#[error("immutable field: {0}")]
ImmutableField(String),
#[error("unauthorized: {0}")]
Unauthorized(String),
#[error("authentication failed")]
AuthenticationFailed,
#[error("illegal operation: {0}")]
IllegalOperation(String),
#[error("internal error: {0}")]
InternalError(String),
}
@@ -40,9 +58,15 @@ impl CommandError {
CommandError::StorageError(_) => (1, "InternalError"),
CommandError::IndexError(_) => (27, "IndexNotFound"),
CommandError::TransactionError(_) => (112, "WriteConflict"),
CommandError::NoSuchTransaction(_) => (251, "NoSuchTransaction"),
CommandError::WriteConflict(_) => (112, "WriteConflict"),
CommandError::NamespaceNotFound(_) => (26, "NamespaceNotFound"),
CommandError::NamespaceExists(_) => (48, "NamespaceExists"),
CommandError::DuplicateKey(_) => (11000, "DuplicateKey"),
CommandError::ImmutableField(_) => (66, "ImmutableField"),
CommandError::Unauthorized(_) => (13, "Unauthorized"),
CommandError::AuthenticationFailed => (18, "AuthenticationFailed"),
CommandError::IllegalOperation(_) => (20, "IllegalOperation"),
CommandError::InternalError(_) => (1, "InternalError"),
};
@@ -63,7 +87,15 @@ impl From<rustdb_storage::StorageError> for CommandError {
impl From<rustdb_txn::TransactionError> for CommandError {
fn from(e: rustdb_txn::TransactionError) -> Self {
CommandError::TransactionError(e.to_string())
match e {
rustdb_txn::TransactionError::NotFound(message) => {
CommandError::NoSuchTransaction(message)
}
rustdb_txn::TransactionError::WriteConflict(message) => {
CommandError::WriteConflict(message)
}
other => CommandError::TransactionError(other.to_string()),
}
}
}
@@ -2,8 +2,9 @@ use bson::{doc, Bson, Document};
use rustdb_index::IndexEngine;
use tracing::debug;
use crate::context::{CommandContext, CursorState};
use crate::context::{CommandContext, ConnectionState, CursorState};
use crate::error::{CommandError, CommandResult};
use crate::transactions;
/// Handle various admin / diagnostic / session / auth commands.
pub async fn handle(
@@ -11,6 +12,7 @@ pub async fn handle(
db: &str,
ctx: &CommandContext,
command_name: &str,
connection: &ConnectionState,
) -> CommandResult<Document> {
match command_name {
"ping" => Ok(doc! { "ok": 1.0 }),
@@ -24,13 +26,7 @@ pub async fn handle(
"ok": 1.0,
}),
"serverStatus" => Ok(doc! {
"host": "localhost",
"version": "7.0.0",
"process": "rustdb",
"uptime": ctx.start_time.elapsed().as_secs() as i64,
"ok": 1.0,
}),
"serverStatus" => handle_server_status(ctx),
"hostInfo" => Ok(doc! {
"system": {
@@ -90,13 +86,19 @@ pub async fn handle(
"codeName": "CommandNotFound",
}),
"connectionStatus" => Ok(doc! {
"authInfo": {
"authenticatedUsers": [],
"authenticatedUserRoles": [],
},
"ok": 1.0,
}),
"connectionStatus" => Ok(handle_connection_status(connection)),
"createUser" => handle_create_user(cmd, db, ctx).await,
"updateUser" => handle_update_user(cmd, db, ctx).await,
"dropUser" => handle_drop_user(cmd, db, ctx).await,
"usersInfo" => handle_users_info(cmd, db, ctx).await,
"grantRolesToUser" => handle_grant_roles_to_user(cmd, db, ctx).await,
"revokeRolesFromUser" => handle_revoke_roles_from_user(cmd, db, ctx).await,
"listDatabases" => handle_list_databases(cmd, ctx).await,
@@ -144,15 +146,9 @@ pub async fn handle(
Ok(doc! { "ok": 1.0 })
}
"commitTransaction" => {
// Stub: acknowledge.
Ok(doc! { "ok": 1.0 })
}
"commitTransaction" => transactions::commit_transaction_command(cmd, ctx).await,
"abortTransaction" => {
// Stub: acknowledge.
Ok(doc! { "ok": 1.0 })
}
"abortTransaction" => transactions::abort_transaction_command(cmd, ctx),
// Auth stubs - accept silently.
"saslStart" => Ok(doc! {
@@ -189,6 +185,232 @@ pub async fn handle(
}
}
fn handle_server_status(ctx: &CommandContext) -> CommandResult<Document> {
let oplog_stats = ctx.oplog.stats();
Ok(doc! {
"host": "localhost",
"version": "7.0.0",
"process": "rustdb",
"uptime": ctx.start_time.elapsed().as_secs() as i64,
"connections": {
"current": 0_i32,
"available": i32::MAX,
},
"logicalSessionRecordCache": {
"activeSessionsCount": ctx.sessions.len() as i64,
},
"transactions": {
"currentActive": ctx.transactions.len() as i64,
},
"oplog": {
"currentSeq": oplog_stats.current_seq as i64,
"totalEntries": oplog_stats.total_entries as i64,
"oldestSeq": oplog_stats.oldest_seq as i64,
"entriesByOp": {
"insert": oplog_stats.inserts as i64,
"update": oplog_stats.updates as i64,
"delete": oplog_stats.deletes as i64,
},
},
"security": {
"authentication": ctx.auth.enabled(),
"users": ctx.auth.user_count() as i64,
},
"ok": 1.0,
})
}
fn handle_connection_status(connection: &ConnectionState) -> Document {
let authenticated_users: Vec<Bson> = connection
.authenticated_users
.iter()
.map(|user| {
Bson::Document(doc! {
"user": user.username.clone(),
"db": user.database.clone(),
})
})
.collect();
let authenticated_roles: Vec<Bson> = connection
.authenticated_users
.iter()
.flat_map(|user| {
user.roles
.iter()
.map(|role| Bson::Document(role_to_document(&user.database, role)))
})
.collect();
doc! {
"authInfo": {
"authenticatedUsers": authenticated_users,
"authenticatedUserRoles": authenticated_roles,
},
"ok": 1.0,
}
}
async fn handle_create_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("createUser")
.map_err(|_| CommandError::InvalidArgument("missing 'createUser' field".into()))?;
let password = cmd
.get_str("pwd")
.map_err(|_| CommandError::InvalidArgument("missing 'pwd' field".into()))?;
let roles = parse_roles(cmd, db, "roles")?;
ctx.auth
.create_user(db, username, password, roles)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
async fn handle_update_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("updateUser")
.map_err(|_| CommandError::InvalidArgument("missing 'updateUser' field".into()))?;
let password = cmd.get_str("pwd").ok();
let roles = if cmd.contains_key("roles") {
Some(parse_roles(cmd, db, "roles")?)
} else {
None
};
ctx.auth
.update_user(db, username, password, roles)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
async fn handle_drop_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("dropUser")
.map_err(|_| CommandError::InvalidArgument("missing 'dropUser' field".into()))?;
ctx.auth
.drop_user(db, username)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
async fn handle_users_info(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = match cmd.get("usersInfo") {
Some(Bson::String(name)) => Some(name.as_str()),
Some(Bson::Document(user_doc)) => user_doc.get_str("user").ok(),
_ => None,
};
let users = ctx.auth.users_info(db, username);
let user_docs: Vec<Bson> = users
.into_iter()
.map(|user| {
let roles: Vec<Bson> = user
.roles
.iter()
.map(|role| Bson::Document(role_to_document(&user.database, role)))
.collect();
Bson::Document(doc! {
"user": user.username,
"db": user.database,
"roles": roles,
"mechanisms": ["SCRAM-SHA-256"],
})
})
.collect();
Ok(doc! { "users": user_docs, "ok": 1.0 })
}
async fn handle_grant_roles_to_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("grantRolesToUser")
.map_err(|_| CommandError::InvalidArgument("missing 'grantRolesToUser' field".into()))?;
let roles = parse_roles(cmd, db, "roles")?;
ctx.auth
.grant_roles(db, username, roles)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
async fn handle_revoke_roles_from_user(
cmd: &Document,
db: &str,
ctx: &CommandContext,
) -> CommandResult<Document> {
let username = cmd
.get_str("revokeRolesFromUser")
.map_err(|_| CommandError::InvalidArgument("missing 'revokeRolesFromUser' field".into()))?;
let roles = parse_roles(cmd, db, "roles")?;
ctx.auth
.revoke_roles(db, username, roles)
.map_err(auth_error_to_command_error)?;
Ok(doc! { "ok": 1.0 })
}
fn parse_roles(cmd: &Document, db: &str, key: &str) -> CommandResult<Vec<String>> {
let role_values = cmd
.get_array(key)
.map_err(|_| CommandError::InvalidArgument(format!("missing '{key}' array")))?;
let mut roles = Vec::with_capacity(role_values.len());
for role_value in role_values {
match role_value {
Bson::String(role) => roles.push(role.clone()),
Bson::Document(role_doc) => {
let role = role_doc
.get_str("role")
.map_err(|_| CommandError::InvalidArgument("role document missing 'role'".into()))?;
let role_db = role_doc.get_str("db").unwrap_or(db);
if role_db == db {
roles.push(role.to_string());
} else {
roles.push(format!("{role_db}.{role}"));
}
}
_ => return Err(CommandError::InvalidArgument("roles must be strings or documents".into())),
}
}
Ok(roles)
}
fn role_to_document(default_db: &str, role: &str) -> Document {
if let Some((role_db, role_name)) = role.split_once('.') {
doc! { "role": role_name, "db": role_db }
} else {
doc! { "role": role, "db": default_db }
}
}
fn auth_error_to_command_error(error: rustdb_auth::AuthError) -> CommandError {
match error {
rustdb_auth::AuthError::UserAlreadyExists(message) => CommandError::DuplicateKey(message),
rustdb_auth::AuthError::UserNotFound(message) => CommandError::NamespaceNotFound(message),
rustdb_auth::AuthError::Persistence(message) => CommandError::InternalError(message),
rustdb_auth::AuthError::AuthenticationFailed => CommandError::AuthenticationFailed,
rustdb_auth::AuthError::InvalidPayload(message) => CommandError::InvalidArgument(message),
rustdb_auth::AuthError::UnsupportedMechanism(message) => CommandError::InvalidArgument(message),
rustdb_auth::AuthError::Disabled => CommandError::Unauthorized("authentication is disabled".into()),
rustdb_auth::AuthError::UnknownConversation => {
CommandError::InvalidArgument("unknown SASL conversation".into())
}
}
}
/// Handle `listDatabases` command.
async fn handle_list_databases(
cmd: &Document,
@@ -0,0 +1,87 @@
use bson::{doc, Binary, Bson, Document};
use crate::context::{CommandContext, ConnectionState};
use crate::error::{CommandError, CommandResult};
pub async fn handle_sasl_start(
cmd: &Document,
db: &str,
ctx: &CommandContext,
connection: &mut ConnectionState,
) -> CommandResult<Document> {
let mechanism = cmd
.get_str("mechanism")
.map_err(|_| CommandError::InvalidArgument("missing SASL mechanism".into()))?;
if mechanism != "SCRAM-SHA-256" {
return Err(CommandError::InvalidArgument(format!(
"unsupported SASL mechanism: {mechanism}"
)));
}
let payload = payload_bytes(cmd)?;
let result = ctx
.auth
.start_scram_sha256(db, &payload)
.map_err(map_auth_error)?;
let conversation_id = connection.next_conversation_id();
connection
.sasl_conversations
.insert(conversation_id, result.conversation);
Ok(doc! {
"conversationId": conversation_id,
"done": false,
"payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload },
"ok": 1.0,
})
}
pub async fn handle_sasl_continue(
cmd: &Document,
ctx: &CommandContext,
connection: &mut ConnectionState,
) -> CommandResult<Document> {
let conversation_id = cmd
.get_i32("conversationId")
.map_err(|_| CommandError::InvalidArgument("missing SASL conversationId".into()))?;
let payload = payload_bytes(cmd)?;
let conversation = connection
.sasl_conversations
.remove(&conversation_id)
.ok_or_else(|| CommandError::InvalidArgument("unknown SASL conversation".into()))?;
let result = ctx
.auth
.continue_scram_sha256(conversation, &payload)
.map_err(map_auth_error)?;
connection.authenticate(result.user);
Ok(doc! {
"conversationId": conversation_id,
"done": true,
"payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload },
"ok": 1.0,
})
}
fn payload_bytes(cmd: &Document) -> CommandResult<Vec<u8>> {
match cmd.get("payload") {
Some(Bson::Binary(binary)) => Ok(binary.bytes.clone()),
Some(Bson::String(value)) => Ok(value.as_bytes().to_vec()),
_ => Err(CommandError::InvalidArgument("missing SASL payload".into())),
}
}
fn map_auth_error(error: rustdb_auth::AuthError) -> CommandError {
match error {
rustdb_auth::AuthError::InvalidPayload(message) => CommandError::InvalidArgument(message),
rustdb_auth::AuthError::UnsupportedMechanism(message) => CommandError::InvalidArgument(message),
rustdb_auth::AuthError::Disabled => CommandError::Unauthorized("authentication is disabled".into()),
rustdb_auth::AuthError::UnknownConversation => {
CommandError::InvalidArgument("unknown SASL conversation".into())
}
rustdb_auth::AuthError::AuthenticationFailed => CommandError::AuthenticationFailed,
rustdb_auth::AuthError::UserAlreadyExists(message) => CommandError::DuplicateKey(message),
rustdb_auth::AuthError::UserNotFound(message) => CommandError::NamespaceNotFound(message),
rustdb_auth::AuthError::Persistence(message) => CommandError::InternalError(message),
}
}
@@ -7,6 +7,7 @@ use tracing::debug;
use crate::context::CommandContext;
use crate::error::{CommandError, CommandResult};
use crate::transactions;
/// Handle the `delete` command.
pub async fn handle(
@@ -36,6 +37,7 @@ pub async fn handle(
);
let ns_key = format!("{}.{}", db, coll);
let txn_id = transactions::active_transaction_id(ctx, cmd);
let mut total_deleted: i32 = 0;
let mut write_errors: Vec<Document> = Vec::new();
@@ -69,7 +71,7 @@ pub async fn handle(
_ => 0, // default: delete all matches
};
match delete_matching(db, coll, &ns_key, &filter, limit, ctx).await {
match delete_matching(db, coll, &ns_key, &filter, limit, ctx, txn_id.as_deref()).await {
Ok(count) => {
total_deleted += count;
}
@@ -114,7 +116,24 @@ async fn delete_matching(
filter: &Document,
limit: i32,
ctx: &CommandContext,
txn_id: Option<&str>,
) -> Result<i32, CommandError> {
if let Some(txn_id) = txn_id {
let docs = transactions::load_transaction_docs(ctx, txn_id, db, coll).await?;
let matched = QueryMatcher::filter(&docs, filter);
let to_delete: &[Document] = if limit == 1 && !matched.is_empty() {
&matched[..1]
} else {
&matched
};
for doc in to_delete {
transactions::record_delete(ctx, txn_id, db, coll, doc.clone()).await?;
}
return Ok(to_delete.len() as i32);
}
// Check if the collection exists; if not, nothing to delete.
match ctx.storage.collection_exists(db, coll).await {
Ok(false) => return Ok(0),
@@ -7,6 +7,7 @@ use rustdb_query::{QueryMatcher, sort_documents, apply_projection, distinct_valu
use crate::context::{CommandContext, CursorState};
use crate::error::{CommandError, CommandResult};
use crate::transactions;
/// Atomic counter for generating unique cursor IDs.
static CURSOR_ID_COUNTER: AtomicI64 = AtomicI64::new(1);
@@ -80,9 +81,14 @@ pub async fn handle(
let limit = get_i64(cmd, "limit").unwrap_or(0).max(0) as usize;
let batch_size = get_i32(cmd, "batchSize").unwrap_or(101).max(0) as usize;
let single_batch = get_bool(cmd, "singleBatch").unwrap_or(false);
let txn_id = transactions::active_transaction_id(ctx, cmd);
// If the collection does not exist, return an empty cursor.
let exists = ctx.storage.collection_exists(db, coll).await?;
let exists = if txn_id.is_some() {
true
} else {
ctx.storage.collection_exists(db, coll).await?
};
if !exists {
return Ok(doc! {
"cursor": {
@@ -96,7 +102,9 @@ pub async fn handle(
// Try index-accelerated lookup.
let index_key = format!("{}.{}", db, coll);
let docs = if let Some(idx_ref) = ctx.indexes.get(&index_key) {
let docs = if let Some(ref txn_id) = txn_id {
transactions::load_transaction_docs(ctx, txn_id, db, coll).await?
} else if let Some(idx_ref) = ctx.indexes.get(&index_key) {
if let Some(candidate_ids) = idx_ref.find_candidate_ids(&filter) {
debug!(
ns = %ns,
@@ -298,9 +306,14 @@ pub async fn handle_count(
ctx: &CommandContext,
) -> CommandResult<Document> {
let coll = get_str(cmd, "count").unwrap_or("unknown");
let txn_id = transactions::active_transaction_id(ctx, cmd);
// Check collection existence.
let exists = ctx.storage.collection_exists(db, coll).await?;
let exists = if txn_id.is_some() {
true
} else {
ctx.storage.collection_exists(db, coll).await?
};
if !exists {
return Ok(doc! { "n": 0_i64, "ok": 1.0 });
}
@@ -309,6 +322,23 @@ pub async fn handle_count(
let skip = get_i64(cmd, "skip").unwrap_or(0).max(0) as usize;
let limit = get_i64(cmd, "limit").unwrap_or(0).max(0) as usize;
if let Some(ref txn_id) = txn_id {
let docs = transactions::load_transaction_docs(ctx, txn_id, db, coll).await?;
let filtered = if query.is_empty() {
docs
} else {
QueryMatcher::filter(&docs, &query)
};
let mut n = filtered.len().saturating_sub(skip);
if limit > 0 {
n = n.min(limit);
}
return Ok(doc! {
"n": n as i64,
"ok": 1.0,
});
}
let count: u64 = if query.is_empty() && skip == 0 && limit == 0 {
// Fast path: use storage-level count.
ctx.storage.count(db, coll).await?
@@ -352,15 +382,24 @@ pub async fn handle_distinct(
let key = get_str(cmd, "key").ok_or_else(|| {
CommandError::InvalidArgument("distinct requires a 'key' field".into())
})?;
let txn_id = transactions::active_transaction_id(ctx, cmd);
// Check collection existence.
let exists = ctx.storage.collection_exists(db, coll).await?;
let exists = if txn_id.is_some() {
true
} else {
ctx.storage.collection_exists(db, coll).await?
};
if !exists {
return Ok(doc! { "values": [], "ok": 1.0 });
}
let query = get_document(cmd, "query").cloned();
let docs = ctx.storage.find_all(db, coll).await?;
let docs = if let Some(txn_id) = txn_id {
transactions::load_transaction_docs(ctx, &txn_id, db, coll).await?
} else {
ctx.storage.find_all(db, coll).await?
};
let values = distinct_values(&docs, key, query.as_ref());
Ok(doc! {
@@ -1,4 +1,4 @@
use bson::{doc, Document};
use bson::{doc, Bson, Document};
use crate::context::CommandContext;
use crate::error::CommandResult;
@@ -7,12 +7,13 @@ use crate::error::CommandResult;
///
/// Returns server capabilities matching wire protocol expectations.
pub async fn handle(
_cmd: &Document,
cmd: &Document,
_db: &str,
_ctx: &CommandContext,
ctx: &CommandContext,
) -> CommandResult<Document> {
Ok(doc! {
let mut response = doc! {
"ismaster": true,
"helloOk": true,
"isWritablePrimary": true,
"maxBsonObjectSize": 16_777_216_i32,
"maxMessageSizeBytes": 48_000_000_i32,
@@ -24,5 +25,19 @@ pub async fn handle(
"maxWireVersion": 21_i32,
"readOnly": false,
"ok": 1.0,
})
};
if ctx.auth.enabled() {
if let Ok(namespace_user) = cmd.get_str("saslSupportedMechs") {
let mechanisms: Vec<Bson> = ctx
.auth
.supported_mechanisms(namespace_user)
.into_iter()
.map(Bson::String)
.collect();
response.insert("saslSupportedMechs", Bson::Array(mechanisms));
}
}
Ok(response)
}
@@ -6,6 +6,7 @@ use tracing::debug;
use crate::context::CommandContext;
use crate::error::{CommandError, CommandResult};
use crate::transactions;
/// Handle the `insert` command.
pub async fn handle(
@@ -48,8 +49,13 @@ pub async fn handle(
"insert command"
);
// Auto-create database and collection if they don't exist.
ensure_collection_exists(db, coll, ctx).await?;
let txn_id = transactions::active_transaction_id(ctx, cmd);
// Auto-create database and collection if they don't exist. Transactional
// writes defer collection creation until commit so abort remains clean.
if txn_id.is_none() {
ensure_collection_exists(db, coll, ctx).await?;
}
let ns_key = format!("{}.{}", db, coll);
let mut inserted_count: i32 = 0;
@@ -84,6 +90,24 @@ pub async fn handle(
}
}
if let Some(ref txn_id) = txn_id {
match transactions::record_insert(ctx, txn_id, db, coll, doc.clone()).await {
Ok(_) => inserted_count += 1,
Err(e) => {
write_errors.push(doc! {
"index": idx as i32,
"code": 11000_i32,
"codeName": "DuplicateKey",
"errmsg": e.to_string(),
});
if ordered {
break;
}
}
}
continue;
}
// Attempt storage insert.
match ctx.storage.insert_one(db, coll, doc.clone()).await {
Ok(id_str) => {
@@ -1,5 +1,6 @@
pub mod admin_handler;
pub mod aggregate_handler;
pub mod auth_handler;
pub mod delete_handler;
pub mod find_handler;
pub mod hello_handler;
@@ -7,6 +7,7 @@ use tracing::debug;
use crate::context::CommandContext;
use crate::error::{CommandError, CommandResult};
use crate::transactions;
/// Handle `update` and `findAndModify` commands.
pub async fn handle(
@@ -21,6 +22,11 @@ pub async fn handle(
}
}
enum TUpdateSpec {
Document(Document),
Pipeline(Vec<Document>),
}
/// Handle the `update` command.
async fn handle_update(
cmd: &Document,
@@ -42,8 +48,12 @@ async fn handle_update(
debug!(db = db, collection = coll, count = updates.len(), "update command");
// Auto-create database and collection if needed.
ensure_collection_exists(db, coll, ctx).await?;
let txn_id = transactions::active_transaction_id(ctx, cmd);
// Transactional writes defer namespace creation until commit.
if txn_id.is_none() {
ensure_collection_exists(db, coll, ctx).await?;
}
let ns_key = format!("{}.{}", db, coll);
@@ -78,21 +88,22 @@ async fn handle_update(
};
let update = match update_spec.get("u") {
Some(Bson::Document(d)) => d.clone(),
Some(Bson::Array(_pipeline)) => {
// Aggregation pipeline updates are not yet supported; treat as error.
write_errors.push(doc! {
"index": idx as i32,
"code": 14_i32,
"codeName": "TypeMismatch",
"errmsg": "aggregation pipeline updates not yet supported",
});
if ordered {
break;
Some(update_value) => match parse_update_spec(update_value) {
Ok(parsed) => parsed,
Err(err) => {
write_errors.push(doc! {
"index": idx as i32,
"code": 14_i32,
"codeName": "TypeMismatch",
"errmsg": err,
});
if ordered {
break;
}
continue;
}
continue;
}
_ => {
},
None => {
write_errors.push(doc! {
"index": idx as i32,
"code": 14_i32,
@@ -130,28 +141,19 @@ async fn handle_update(
});
// Load all documents and filter.
let all_docs = load_filtered_docs(db, coll, &filter, &ns_key, ctx).await?;
let all_docs = load_filtered_docs(db, coll, &filter, &ns_key, ctx, txn_id.as_deref()).await?;
if all_docs.is_empty() && upsert {
// Upsert: create a new document.
let new_doc = build_upsert_doc(&filter);
// Apply update operators or replacement.
match UpdateEngine::apply_update(&new_doc, &update, array_filters.as_deref()) {
match apply_update_spec(&new_doc, &update, array_filters.as_deref()) {
Ok(mut updated) => {
// Apply $setOnInsert if present.
if let Some(Bson::Document(soi)) = update.get("$setOnInsert") {
UpdateEngine::apply_set_on_insert(&mut updated, soi);
}
apply_set_on_insert_if_present(&update, &mut updated);
// Ensure _id exists.
let new_id = if !updated.contains_key("_id") {
let oid = ObjectId::new();
updated.insert("_id", oid);
Bson::ObjectId(oid)
} else {
updated.get("_id").unwrap().clone()
};
let new_id = ensure_document_id(&mut updated);
// Pre-check unique index constraints before upsert insert.
if let Some(engine) = ctx.indexes.get(&ns_key) {
@@ -169,6 +171,30 @@ async fn handle_update(
}
}
if let Some(ref txn_id) = txn_id {
match transactions::record_insert(ctx, txn_id, db, coll, updated.clone()).await {
Ok(_) => {
total_n += 1;
upserted_list.push(doc! {
"index": idx as i32,
"_id": new_id,
});
}
Err(e) => {
write_errors.push(doc! {
"index": idx as i32,
"code": 1_i32,
"codeName": "InternalError",
"errmsg": e.to_string(),
});
if ordered {
break;
}
}
}
continue;
}
// Insert the new document.
match ctx.storage.insert_one(db, coll, updated.clone()).await {
Ok(id_str) => {
@@ -229,12 +255,21 @@ async fn handle_update(
};
for matched_doc in &docs_to_update {
match UpdateEngine::apply_update(
matched_doc,
&update,
array_filters.as_deref(),
) {
Ok(updated_doc) => {
match apply_update_spec(matched_doc, &update, array_filters.as_deref()) {
Ok(mut updated_doc) => {
if let Err(e) = ensure_immutable_id(matched_doc, &mut updated_doc) {
write_errors.push(doc! {
"index": idx as i32,
"code": 66_i32,
"codeName": "ImmutableField",
"errmsg": e.to_string(),
});
if ordered {
break;
}
continue;
}
// Pre-check unique index constraints before storage write.
if let Some(engine) = ctx.indexes.get(&ns_key) {
if let Err(e) = engine.check_unique_constraints_for_update(matched_doc, &updated_doc) {
@@ -252,6 +287,38 @@ async fn handle_update(
}
let id_str = extract_id_string(matched_doc);
if let Some(ref txn_id) = txn_id {
match transactions::record_update(
ctx,
txn_id,
db,
coll,
matched_doc.clone(),
updated_doc.clone(),
)
.await
{
Ok(_) => {
total_n += 1;
if matched_doc != &updated_doc {
total_n_modified += 1;
}
}
Err(e) => {
write_errors.push(doc! {
"index": idx as i32,
"code": 1_i32,
"codeName": "InternalError",
"errmsg": e.to_string(),
});
if ordered {
break;
}
}
}
continue;
}
match ctx
.storage
.update_by_id(db, coll, &id_str, updated_doc.clone())
@@ -361,8 +428,11 @@ async fn handle_find_and_modify(
};
let update_doc = match cmd.get("update") {
Some(Bson::Document(d)) => Some(d.clone()),
_ => None,
Some(update_value) => Some(
parse_update_spec(update_value)
.map_err(CommandError::InvalidArgument)?
),
None => None,
};
let remove = match cmd.get("remove") {
@@ -398,8 +468,12 @@ async fn handle_find_and_modify(
.collect()
});
// Auto-create database and collection.
ensure_collection_exists(db, coll, ctx).await?;
let txn_id = transactions::active_transaction_id(ctx, cmd);
// Transactional writes defer namespace creation until commit.
if txn_id.is_none() {
ensure_collection_exists(db, coll, ctx).await?;
}
let ns_key = format!("{}.{}", db, coll);
@@ -407,7 +481,7 @@ async fn handle_find_and_modify(
drop(ctx.get_or_init_index_engine(db, coll).await);
// Load and filter documents.
let mut matched = load_filtered_docs(db, coll, &query, &ns_key, ctx).await?;
let mut matched = load_filtered_docs(db, coll, &query, &ns_key, ctx, txn_id.as_deref()).await?;
// Sort if specified.
if let Some(ref sort_spec) = sort {
@@ -421,6 +495,21 @@ async fn handle_find_and_modify(
// Remove operation.
if let Some(ref doc) = target {
let id_str = extract_id_string(doc);
if let Some(ref txn_id) = txn_id {
transactions::record_delete(ctx, txn_id, db, coll, doc.clone()).await?;
let value = apply_fields_projection(doc, &fields);
return Ok(doc! {
"value": value,
"lastErrorObject": {
"n": 1_i32,
"updatedExisting": false,
},
"ok": 1.0,
});
}
ctx.storage.delete_by_id(db, coll, &id_str).await?;
// Record in oplog.
@@ -477,12 +566,14 @@ async fn handle_find_and_modify(
if let Some(original_doc) = target {
// Update the matched document.
let updated_doc = UpdateEngine::apply_update(
let mut updated_doc = apply_update_spec(
&original_doc,
&update,
array_filters.as_deref(),
)
.map_err(|e| CommandError::InternalError(e.to_string()))?;
.map_err(CommandError::InternalError)?;
ensure_immutable_id(&original_doc, &mut updated_doc)?;
// Pre-check unique index constraints before storage write.
if let Some(engine) = ctx.indexes.get(&ns_key) {
@@ -492,6 +583,35 @@ async fn handle_find_and_modify(
}
let id_str = extract_id_string(&original_doc);
if let Some(ref txn_id) = txn_id {
transactions::record_update(
ctx,
txn_id,
db,
coll,
original_doc.clone(),
updated_doc.clone(),
)
.await?;
let return_doc = if return_new {
&updated_doc
} else {
&original_doc
};
let value = apply_fields_projection(return_doc, &fields);
return Ok(doc! {
"value": value,
"lastErrorObject": {
"n": 1_i32,
"updatedExisting": true,
},
"ok": 1.0,
});
}
ctx.storage
.update_by_id(db, coll, &id_str, updated_doc.clone())
.await?;
@@ -533,26 +653,17 @@ async fn handle_find_and_modify(
// Upsert: create a new document.
let new_doc = build_upsert_doc(&query);
let mut updated_doc = UpdateEngine::apply_update(
let mut updated_doc = apply_update_spec(
&new_doc,
&update,
array_filters.as_deref(),
)
.map_err(|e| CommandError::InternalError(e.to_string()))?;
.map_err(CommandError::InternalError)?;
// Apply $setOnInsert if present.
if let Some(Bson::Document(soi)) = update.get("$setOnInsert") {
UpdateEngine::apply_set_on_insert(&mut updated_doc, soi);
}
apply_set_on_insert_if_present(&update, &mut updated_doc);
// Ensure _id.
let upserted_id = if !updated_doc.contains_key("_id") {
let oid = ObjectId::new();
updated_doc.insert("_id", oid);
Bson::ObjectId(oid)
} else {
updated_doc.get("_id").unwrap().clone()
};
let upserted_id = ensure_document_id(&mut updated_doc);
// Pre-check unique index constraints before upsert insert.
if let Some(engine) = ctx.indexes.get(&ns_key) {
@@ -561,6 +672,26 @@ async fn handle_find_and_modify(
}
}
if let Some(ref txn_id) = txn_id {
transactions::record_insert(ctx, txn_id, db, coll, updated_doc.clone()).await?;
let value = if return_new {
apply_fields_projection(&updated_doc, &fields)
} else {
Bson::Null
};
return Ok(doc! {
"value": value,
"lastErrorObject": {
"n": 1_i32,
"updatedExisting": false,
"upserted": upserted_id,
},
"ok": 1.0,
});
}
let inserted_id_str = ctx.storage
.insert_one(db, coll, updated_doc.clone())
.await?;
@@ -620,7 +751,17 @@ async fn load_filtered_docs(
filter: &Document,
ns_key: &str,
ctx: &CommandContext,
txn_id: Option<&str>,
) -> CommandResult<Vec<Document>> {
if let Some(txn_id) = txn_id {
let docs = transactions::load_transaction_docs(ctx, txn_id, db, coll).await?;
return if filter.is_empty() {
Ok(docs)
} else {
Ok(QueryMatcher::filter(&docs, filter))
};
}
// Try to use index to narrow candidates.
let candidate_ids: Option<HashSet<String>> = ctx
.indexes
@@ -667,6 +808,88 @@ fn build_upsert_doc(filter: &Document) -> Document {
doc
}
fn parse_update_spec(update_value: &Bson) -> Result<TUpdateSpec, String> {
match update_value {
Bson::Document(d) => Ok(TUpdateSpec::Document(d.clone())),
Bson::Array(stages) => {
if stages.is_empty() {
return Err("aggregation pipeline update cannot be empty".into());
}
let mut pipeline = Vec::with_capacity(stages.len());
for stage in stages {
match stage {
Bson::Document(d) => pipeline.push(d.clone()),
_ => {
return Err(
"aggregation pipeline update stages must be documents".into(),
);
}
}
}
Ok(TUpdateSpec::Pipeline(pipeline))
}
_ => Err("missing or invalid 'u' field in update spec".into()),
}
}
fn apply_update_spec(
doc: &Document,
update: &TUpdateSpec,
array_filters: Option<&[Document]>,
) -> Result<Document, String> {
match update {
TUpdateSpec::Document(update_doc) => UpdateEngine::apply_update(doc, update_doc, array_filters)
.map_err(|e| e.to_string()),
TUpdateSpec::Pipeline(pipeline) => {
if array_filters.is_some_and(|filters| !filters.is_empty()) {
return Err(
"arrayFilters are not supported with aggregation pipeline updates"
.into(),
);
}
UpdateEngine::apply_pipeline_update(doc, pipeline).map_err(|e| e.to_string())
}
}
}
fn apply_set_on_insert_if_present(update: &TUpdateSpec, doc: &mut Document) {
if let TUpdateSpec::Document(update_doc) = update {
if let Some(Bson::Document(soi)) = update_doc.get("$setOnInsert") {
UpdateEngine::apply_set_on_insert(doc, soi);
}
}
}
fn ensure_document_id(doc: &mut Document) -> Bson {
if let Some(id) = doc.get("_id") {
id.clone()
} else {
let oid = ObjectId::new();
doc.insert("_id", oid);
Bson::ObjectId(oid)
}
}
fn ensure_immutable_id(original_doc: &Document, updated_doc: &mut Document) -> CommandResult<()> {
if let Some(original_id) = original_doc.get("_id") {
match updated_doc.get("_id") {
Some(updated_id) if updated_id == original_id => Ok(()),
Some(_) => Err(CommandError::ImmutableField(
"cannot modify immutable field '_id'".into(),
)),
None => {
updated_doc.insert("_id", original_id.clone());
Ok(())
}
}
} else {
Ok(())
}
}
/// Extract _id as a string for storage operations.
fn extract_id_string(doc: &Document) -> String {
match doc.get("_id") {
+2 -1
View File
@@ -1,8 +1,9 @@
mod context;
pub mod error;
pub mod handlers;
pub mod transactions;
mod router;
pub use context::{CommandContext, CursorState};
pub use context::{CommandContext, ConnectionState, CursorState};
pub use error::{CommandError, CommandResult};
pub use router::CommandRouter;
+107 -6
View File
@@ -1,13 +1,14 @@
use std::sync::Arc;
use bson::Document;
use bson::{Bson, Document};
use tracing::{debug, warn};
use rustdb_wire::ParsedCommand;
use rustdb_auth::AuthAction;
use crate::context::CommandContext;
use crate::context::{CommandContext, ConnectionState};
use crate::error::CommandError;
use crate::handlers;
use crate::{handlers, transactions};
/// Routes parsed wire protocol commands to the appropriate handler.
pub struct CommandRouter {
@@ -21,12 +22,47 @@ impl CommandRouter {
}
/// Route a parsed command to the appropriate handler, returning a BSON response document.
pub async fn route(&self, cmd: &ParsedCommand) -> Document {
pub async fn route(&self, cmd: &ParsedCommand, connection: &mut ConnectionState) -> Document {
let db = &cmd.database;
let command_name = cmd.command_name.as_str();
debug!(command = %command_name, database = %db, "routing command");
if self.ctx.auth.enabled()
&& !connection.is_authenticated()
&& !allows_unauthenticated(command_name)
{
return CommandError::Unauthorized(format!(
"command '{}' requires authentication",
command_name,
))
.to_error_doc();
}
if self.ctx.auth.enabled() && connection.is_authenticated() {
if let Some(action) = required_action(command_name, &cmd.command) {
if !self
.ctx
.auth
.is_authorized(&connection.authenticated_users, db, action)
{
return CommandError::Unauthorized(format!(
"command '{}' is not authorized for database '{}'",
command_name, db,
))
.to_error_doc();
}
}
}
if let Err(e) = transactions::prepare_transaction_for_command(
&self.ctx,
&cmd.command,
command_name,
) {
return e.to_error_doc();
}
// Extract session id if present, and touch the session.
if let Some(lsid) = cmd.command.get("lsid") {
if let Some(session_id) = rustdb_txn::SessionEngine::extract_session_id(lsid) {
@@ -40,6 +76,14 @@ impl CommandRouter {
handlers::hello_handler::handle(&cmd.command, db, &self.ctx).await
}
// -- authentication --
"saslStart" => {
handlers::auth_handler::handle_sasl_start(&cmd.command, db, &self.ctx, connection).await
}
"saslContinue" => {
handlers::auth_handler::handle_sasl_continue(&cmd.command, &self.ctx, connection).await
}
// -- query commands --
"find" => {
handlers::find_handler::handle(&cmd.command, db, &self.ctx).await
@@ -88,10 +132,12 @@ impl CommandRouter {
| "dbStats" | "collStats" | "validate" | "explain"
| "startSession" | "endSessions" | "killSessions"
| "commitTransaction" | "abortTransaction"
| "saslStart" | "saslContinue" | "authenticate" | "logout"
| "authenticate" | "logout"
| "createUser" | "updateUser" | "dropUser" | "usersInfo"
| "grantRolesToUser" | "revokeRolesFromUser"
| "currentOp" | "killOp" | "top" | "profile"
| "compact" | "reIndex" | "fsync" | "connPoolSync" => {
handlers::admin_handler::handle(&cmd.command, db, &self.ctx, command_name).await
handlers::admin_handler::handle(&cmd.command, db, &self.ctx, command_name, connection).await
}
// -- unknown command --
@@ -107,3 +153,58 @@ impl CommandRouter {
}
}
}
fn allows_unauthenticated(command_name: &str) -> bool {
matches!(
command_name,
"hello" | "ismaster" | "isMaster" | "saslStart" | "saslContinue" | "getnonce"
)
}
fn required_action(command_name: &str, command: &Document) -> Option<AuthAction> {
match command_name {
"hello" | "ismaster" | "isMaster" | "saslStart" | "saslContinue" | "getnonce" => None,
"ping" | "buildInfo" | "buildinfo" | "hostInfo" | "whatsmyuri" | "getLog"
| "getCmdLineOpts" | "getParameter" | "getFreeMonitoringStatus" | "setFreeMonitoring"
| "getShardMap" | "shardingState" | "atlasVersion" | "connectionStatus"
| "startSession" | "endSessions" | "killSessions" | "authenticate" | "logout" => None,
"find" | "getMore" | "killCursors" | "count" | "distinct" | "listIndexes"
| "listCollections" | "collStats" | "dbStats" | "validate" | "explain" => {
Some(AuthAction::Read)
}
"aggregate" => Some(if aggregate_writes(command) {
AuthAction::Write
} else {
AuthAction::Read
}),
"insert" | "update" | "findAndModify" | "delete" | "commitTransaction"
| "abortTransaction" => Some(AuthAction::Write),
"createIndexes" | "dropIndexes" | "create" | "drop" | "dropDatabase"
| "renameCollection" | "compact" | "reIndex" | "fsync" | "profile" => {
Some(AuthAction::DbAdmin)
}
"createUser" | "updateUser" | "dropUser" | "usersInfo" | "grantRolesToUser"
| "revokeRolesFromUser" => Some(AuthAction::UserAdmin),
"serverStatus" | "listDatabases" | "currentOp" | "killOp" | "top" => {
Some(AuthAction::ClusterMonitor)
}
_ => None,
}
}
fn aggregate_writes(command: &Document) -> bool {
let Ok(pipeline) = command.get_array("pipeline") else {
return false;
};
pipeline.last().and_then(|stage| match stage {
Bson::Document(doc) => Some(doc.contains_key("$out") || doc.contains_key("$merge")),
_ => None,
}).unwrap_or(false)
}
@@ -0,0 +1,367 @@
use bson::{doc, Bson, Document};
use rustdb_storage::OpType;
use rustdb_txn::{TransactionState, WriteEntry, WriteOp};
use crate::context::CommandContext;
use crate::error::{CommandError, CommandResult};
pub fn command_starts_transaction(cmd: &Document) -> bool {
matches!(cmd.get("startTransaction"), Some(Bson::Boolean(true)))
}
pub fn command_uses_transaction(cmd: &Document) -> bool {
command_starts_transaction(cmd) || matches!(cmd.get("autocommit"), Some(Bson::Boolean(false)))
}
pub fn active_transaction_id(ctx: &CommandContext, cmd: &Document) -> Option<String> {
if !command_uses_transaction(cmd) {
return None;
}
let session_id = cmd
.get("lsid")
.and_then(rustdb_txn::SessionEngine::extract_session_id)?;
ctx.sessions.get_transaction_id(&session_id)
}
pub fn prepare_transaction_for_command(
ctx: &CommandContext,
cmd: &Document,
command_name: &str,
) -> CommandResult<()> {
if matches!(command_name, "commitTransaction" | "abortTransaction") {
return Ok(());
}
let starts_transaction = command_starts_transaction(cmd);
let uses_transaction = command_uses_transaction(cmd);
if !uses_transaction {
return Ok(());
}
let session_id = session_id_from_command(cmd)?;
require_txn_number(cmd)?;
ctx.sessions.get_or_create_session(&session_id);
if starts_transaction {
let txn_id = ctx.transactions.start_transaction(&session_id)?;
ctx.sessions.start_transaction(&session_id, &txn_id)?;
return Ok(());
}
if ctx.sessions.get_transaction_id(&session_id).is_none() {
return Err(CommandError::NoSuchTransaction(format!(
"session {session_id} has no active transaction"
)));
}
Ok(())
}
pub async fn load_transaction_docs(
ctx: &CommandContext,
txn_id: &str,
db: &str,
coll: &str,
) -> CommandResult<Vec<Document>> {
let ns = namespace(db, coll);
if !ctx.transactions.has_snapshot(txn_id, &ns) {
let docs = match ctx.storage.collection_exists(db, coll).await {
Ok(true) => ctx.storage.find_all(db, coll).await?,
Ok(false) => Vec::new(),
Err(_) => Vec::new(),
};
ctx.transactions.set_snapshot(txn_id, &ns, docs);
}
ctx.transactions
.get_snapshot(txn_id, &ns)
.ok_or_else(|| CommandError::NoSuchTransaction(txn_id.to_string()))
}
pub async fn record_insert(
ctx: &CommandContext,
txn_id: &str,
db: &str,
coll: &str,
doc: Document,
) -> CommandResult<String> {
let id = document_id_string(&doc)?;
let docs = load_transaction_docs(ctx, txn_id, db, coll).await?;
if docs.iter().any(|existing| document_id_string(existing).ok().as_deref() == Some(id.as_str())) {
return Err(CommandError::DuplicateKey(format!(
"duplicate _id '{}' in transaction",
id
)));
}
ctx.transactions.record_write(
txn_id,
&namespace(db, coll),
&id,
WriteOp::Insert,
Some(doc),
None,
);
Ok(id)
}
pub async fn record_update(
ctx: &CommandContext,
txn_id: &str,
db: &str,
coll: &str,
original: Document,
updated: Document,
) -> CommandResult<String> {
let id = document_id_string(&original)?;
ctx.transactions.record_write(
txn_id,
&namespace(db, coll),
&id,
WriteOp::Update,
Some(updated),
Some(original),
);
Ok(id)
}
pub async fn record_delete(
ctx: &CommandContext,
txn_id: &str,
db: &str,
coll: &str,
original: Document,
) -> CommandResult<String> {
let id = document_id_string(&original)?;
ctx.transactions.record_write(
txn_id,
&namespace(db, coll),
&id,
WriteOp::Delete,
None,
Some(original),
);
Ok(id)
}
pub async fn commit_transaction_command(
cmd: &Document,
ctx: &CommandContext,
) -> CommandResult<Document> {
let session_id = session_id_from_command(cmd)?;
let txn_id = ctx
.sessions
.get_transaction_id(&session_id)
.ok_or_else(|| CommandError::NoSuchTransaction(format!(
"session {session_id} has no active transaction"
)))?;
let state = ctx.transactions.take_transaction(&txn_id)?;
preflight_transaction(&state, ctx).await?;
apply_transaction(state, ctx).await?;
ctx.sessions.end_transaction(&session_id);
Ok(doc! { "ok": 1.0 })
}
pub fn abort_transaction_command(cmd: &Document, ctx: &CommandContext) -> CommandResult<Document> {
let session_id = session_id_from_command(cmd)?;
let txn_id = ctx
.sessions
.get_transaction_id(&session_id)
.ok_or_else(|| CommandError::NoSuchTransaction(format!(
"session {session_id} has no active transaction"
)))?;
ctx.transactions.abort_transaction(&txn_id)?;
ctx.sessions.end_transaction(&session_id);
Ok(doc! { "ok": 1.0 })
}
pub fn document_id_string(doc: &Document) -> CommandResult<String> {
match doc.get("_id") {
Some(Bson::ObjectId(oid)) => Ok(oid.to_hex()),
Some(Bson::String(s)) => Ok(s.clone()),
Some(other) => Ok(format!("{}", other)),
None => Err(CommandError::InvalidArgument("document missing _id field".into())),
}
}
fn session_id_from_command(cmd: &Document) -> CommandResult<String> {
cmd.get("lsid")
.and_then(rustdb_txn::SessionEngine::extract_session_id)
.ok_or_else(|| CommandError::InvalidArgument("transaction command requires lsid".into()))
}
fn require_txn_number(cmd: &Document) -> CommandResult<()> {
match cmd.get("txnNumber") {
Some(Bson::Int64(_)) | Some(Bson::Int32(_)) => Ok(()),
_ => Err(CommandError::InvalidArgument(
"transaction command requires txnNumber".into(),
)),
}
}
fn namespace(db: &str, coll: &str) -> String {
format!("{db}.{coll}")
}
async fn preflight_transaction(state: &TransactionState, ctx: &CommandContext) -> CommandResult<()> {
for (ns, writes) in &state.write_set {
let (db, coll) = split_namespace(ns)?;
drop(ctx.get_or_init_index_engine(db, coll).await);
for (doc_id, entry) in writes {
let current = current_doc(ctx, db, coll, doc_id).await?;
match entry.op {
WriteOp::Insert => {
if current.is_some() {
return Err(CommandError::DuplicateKey(format!(
"duplicate _id '{}' on transaction commit",
doc_id
)));
}
if let Some(ref doc) = entry.doc {
if let Some(engine) = ctx.indexes.get(ns) {
engine.check_unique_constraints(doc)?;
}
}
}
WriteOp::Update => {
assert_unchanged(doc_id, current.as_ref(), entry.original_doc.as_ref())?;
if let (Some(current_doc), Some(updated_doc)) = (current.as_ref(), entry.doc.as_ref()) {
if let Some(engine) = ctx.indexes.get(ns) {
engine.check_unique_constraints_for_update(current_doc, updated_doc)?;
}
}
}
WriteOp::Delete => {
assert_unchanged(doc_id, current.as_ref(), entry.original_doc.as_ref())?;
}
}
}
}
Ok(())
}
async fn apply_transaction(state: TransactionState, ctx: &CommandContext) -> CommandResult<()> {
let mut namespaces: Vec<_> = state.write_set.into_iter().collect();
namespaces.sort_by(|a, b| a.0.cmp(&b.0));
for (ns, writes) in namespaces {
let (db, coll) = split_namespace(&ns)?;
ensure_collection_exists(db, coll, ctx).await?;
drop(ctx.get_or_init_index_engine(db, coll).await);
let mut writes: Vec<(String, WriteEntry)> = writes.into_iter().collect();
writes.sort_by(|a, b| a.0.cmp(&b.0));
for (doc_id, entry) in writes {
match entry.op {
WriteOp::Insert => {
let Some(doc) = entry.doc else { continue; };
let inserted_id = ctx.storage.insert_one(db, coll, doc.clone()).await?;
ctx.oplog.append(OpType::Insert, db, coll, &inserted_id, Some(doc.clone()), None);
if let Some(mut engine) = ctx.indexes.get_mut(&ns) {
engine.on_insert(&doc)?;
}
}
WriteOp::Update => {
let Some(doc) = entry.doc else { continue; };
ctx.storage.update_by_id(db, coll, &doc_id, doc.clone()).await?;
ctx.oplog.append(
OpType::Update,
db,
coll,
&doc_id,
Some(doc.clone()),
entry.original_doc.clone(),
);
if let (Some(mut engine), Some(ref original)) =
(ctx.indexes.get_mut(&ns), entry.original_doc.as_ref())
{
engine.on_update(original, &doc)?;
}
}
WriteOp::Delete => {
ctx.storage.delete_by_id(db, coll, &doc_id).await?;
ctx.oplog.append(
OpType::Delete,
db,
coll,
&doc_id,
None,
entry.original_doc.clone(),
);
if let (Some(mut engine), Some(ref original)) =
(ctx.indexes.get_mut(&ns), entry.original_doc.as_ref())
{
engine.on_delete(original);
}
}
}
}
}
Ok(())
}
async fn current_doc(
ctx: &CommandContext,
db: &str,
coll: &str,
doc_id: &str,
) -> CommandResult<Option<Document>> {
match ctx.storage.collection_exists(db, coll).await {
Ok(true) => Ok(ctx.storage.find_by_id(db, coll, doc_id).await?),
Ok(false) => Ok(None),
Err(_) => Ok(None),
}
}
fn assert_unchanged(
doc_id: &str,
current: Option<&Document>,
original: Option<&Document>,
) -> CommandResult<()> {
if current == original {
return Ok(());
}
Err(CommandError::WriteConflict(format!(
"document '{}' changed during transaction",
doc_id
)))
}
async fn ensure_collection_exists(
db: &str,
coll: &str,
ctx: &CommandContext,
) -> CommandResult<()> {
if let Err(e) = ctx.storage.create_database(db).await {
let msg = e.to_string();
if !msg.contains("AlreadyExists") && !msg.contains("already exists") {
return Err(CommandError::StorageError(msg));
}
}
match ctx.storage.collection_exists(db, coll).await {
Ok(true) => Ok(()),
Ok(false) | Err(_) => {
if let Err(e) = ctx.storage.create_collection(db, coll).await {
let msg = e.to_string();
if !msg.contains("AlreadyExists") && !msg.contains("already exists") {
return Err(CommandError::StorageError(msg));
}
}
Ok(())
}
}
}
fn split_namespace(ns: &str) -> CommandResult<(&str, &str)> {
ns.split_once('.')
.ok_or_else(|| CommandError::InvalidArgument(format!("invalid namespace '{ns}'")))
}
+162 -1
View File
@@ -46,6 +46,99 @@ pub struct RustDbOptions {
/// Interval in ms for periodic persistence (default: 60000)
#[serde(default = "default_persist_interval")]
pub persist_interval_ms: u64,
/// Authentication configuration.
#[serde(default)]
pub auth: AuthOptions,
/// TLS transport configuration for TCP listeners.
#[serde(default)]
pub tls: TlsOptions,
}
/// Authentication configuration for the embedded server.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthOptions {
/// Whether clients must authenticate before issuing protected commands.
#[serde(default)]
pub enabled: bool,
/// Bootstrap users loaded at startup. Passwords are converted into SCRAM credentials in memory.
#[serde(default)]
pub users: Vec<AuthUserOptions>,
/// Optional path for persisted SCRAM user metadata. Stores derived credentials, never plaintext passwords.
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub users_path: Option<String>,
/// SCRAM iteration count used for bootstrap credentials.
#[serde(default = "default_scram_iterations")]
pub scram_iterations: u32,
}
impl Default for AuthOptions {
fn default() -> Self {
Self {
enabled: false,
users: Vec::new(),
users_path: None,
scram_iterations: default_scram_iterations(),
}
}
}
/// TLS transport configuration for the embedded server.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TlsOptions {
/// Whether TCP client connections must use TLS.
#[serde(default)]
pub enabled: bool,
/// PEM-encoded server certificate chain.
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub cert_path: Option<String>,
/// PEM-encoded server private key.
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub key_path: Option<String>,
/// PEM-encoded client CA roots for mTLS verification.
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub ca_path: Option<String>,
/// Require clients to present a certificate signed by caPath.
#[serde(default)]
pub require_client_cert: bool,
}
impl Default for TlsOptions {
fn default() -> Self {
Self {
enabled: false,
cert_path: None,
key_path: None,
ca_path: None,
require_client_cert: false,
}
}
}
/// A bootstrap user for SCRAM authentication.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthUserOptions {
pub username: String,
pub password: String,
#[serde(default = "default_auth_database")]
pub database: String,
#[serde(default)]
pub roles: Vec<String>,
}
fn default_port() -> u16 {
@@ -60,6 +153,14 @@ fn default_persist_interval() -> u64 {
60000
}
fn default_scram_iterations() -> u32 {
15000
}
fn default_auth_database() -> String {
"admin".to_string()
}
impl Default for RustDbOptions {
fn default() -> Self {
Self {
@@ -70,6 +171,8 @@ impl Default for RustDbOptions {
storage_path: None,
persist_path: None,
persist_interval_ms: default_persist_interval(),
auth: AuthOptions::default(),
tls: TlsOptions::default(),
}
}
}
@@ -92,6 +195,59 @@ impl RustDbOptions {
"storagePath is required when storage is 'file'".to_string(),
));
}
if self.auth.enabled {
if self.auth.users.is_empty() && self.auth.users_path.is_none() {
return Err(ConfigError::ValidationError(
"auth.users or auth.usersPath must be set when auth.enabled is true".to_string(),
));
}
if self.auth.scram_iterations < 4096 {
return Err(ConfigError::ValidationError(
"auth.scramIterations must be at least 4096".to_string(),
));
}
for user in &self.auth.users {
if user.username.is_empty() {
return Err(ConfigError::ValidationError(
"auth.users[].username must not be empty".to_string(),
));
}
if user.password.is_empty() {
return Err(ConfigError::ValidationError(
format!("auth user '{}' must have a non-empty password", user.username),
));
}
if user.database.is_empty() {
return Err(ConfigError::ValidationError(
format!("auth user '{}' must have a non-empty database", user.username),
));
}
}
}
if self.tls.enabled {
if self.socket_path.is_some() {
return Err(ConfigError::ValidationError(
"tls.enabled is only supported for TCP listeners".to_string(),
));
}
if self.tls.cert_path.as_deref().unwrap_or_default().is_empty() {
return Err(ConfigError::ValidationError(
"tls.certPath is required when tls.enabled is true".to_string(),
));
}
if self.tls.key_path.as_deref().unwrap_or_default().is_empty() {
return Err(ConfigError::ValidationError(
"tls.keyPath is required when tls.enabled is true".to_string(),
));
}
if self.tls.require_client_cert
&& self.tls.ca_path.as_deref().unwrap_or_default().is_empty()
{
return Err(ConfigError::ValidationError(
"tls.caPath is required when tls.requireClientCert is true".to_string(),
));
}
}
Ok(())
}
@@ -101,7 +257,12 @@ impl RustDbOptions {
let encoded = urlencoding(socket_path);
format!("mongodb://{}", encoded)
} else {
format!("mongodb://{}:{}", self.host, self.port)
let base = format!("mongodb://{}:{}", self.host, self.port);
if self.tls.enabled {
format!("{}/?tls=true", base)
} else {
base
}
}
}
}
+134 -29
View File
@@ -2,10 +2,10 @@ use bson::{Bson, Document};
use std::collections::HashMap;
use crate::error::QueryError;
use crate::field_path::{get_nested_value, remove_nested_value};
use crate::matcher::QueryMatcher;
use crate::sort::sort_documents;
use crate::projection::apply_projection;
use crate::field_path::get_nested_value;
use crate::sort::sort_documents;
/// Aggregation pipeline engine.
pub struct AggregationEngine;
@@ -42,6 +42,7 @@ impl AggregationEngine {
"$count" => Self::stage_count(current, stage_spec)?,
"$addFields" | "$set" => Self::stage_add_fields(current, stage_spec)?,
"$replaceRoot" | "$replaceWith" => Self::stage_replace_root(current, stage_spec)?,
"$unset" => Self::stage_unset(current, stage_spec)?,
"$lookup" => Self::stage_lookup(current, stage_spec, resolver, db)?,
"$facet" => Self::stage_facet(current, stage_spec, resolver, db)?,
"$unionWith" => Self::stage_union_with(current, stage_spec, resolver, db)?,
@@ -60,7 +61,11 @@ impl AggregationEngine {
fn stage_match(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
let filter = match spec {
Bson::Document(d) => d,
_ => return Err(QueryError::AggregationError("$match requires a document".into())),
_ => {
return Err(QueryError::AggregationError(
"$match requires a document".into(),
))
}
};
Ok(QueryMatcher::filter(&docs, filter))
}
@@ -68,15 +73,26 @@ impl AggregationEngine {
fn stage_project(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
let projection = match spec {
Bson::Document(d) => d,
_ => return Err(QueryError::AggregationError("$project requires a document".into())),
_ => {
return Err(QueryError::AggregationError(
"$project requires a document".into(),
))
}
};
Ok(docs.into_iter().map(|doc| apply_projection(&doc, projection)).collect())
Ok(docs
.into_iter()
.map(|doc| apply_projection(&doc, projection))
.collect())
}
fn stage_sort(mut docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
let sort_spec = match spec {
Bson::Document(d) => d,
_ => return Err(QueryError::AggregationError("$sort requires a document".into())),
_ => {
return Err(QueryError::AggregationError(
"$sort requires a document".into(),
))
}
};
sort_documents(&mut docs, sort_spec);
Ok(docs)
@@ -97,7 +113,11 @@ impl AggregationEngine {
fn stage_group(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
let group_spec = match spec {
Bson::Document(d) => d,
_ => return Err(QueryError::AggregationError("$group requires a document".into())),
_ => {
return Err(QueryError::AggregationError(
"$group requires a document".into(),
))
}
};
let id_expr = group_spec.get("_id").cloned().unwrap_or(Bson::Null);
@@ -158,13 +178,18 @@ impl AggregationEngine {
let (path, preserve_null) = match spec {
Bson::String(s) => (s.trim_start_matches('$').to_string(), false),
Bson::Document(d) => {
let path = d.get_str("path")
let path = d
.get_str("path")
.map(|s| s.trim_start_matches('$').to_string())
.map_err(|_| QueryError::AggregationError("$unwind requires 'path'".into()))?;
let preserve = d.get_bool("preserveNullAndEmptyArrays").unwrap_or(false);
(path, preserve)
}
_ => return Err(QueryError::AggregationError("$unwind requires a string or document".into())),
_ => {
return Err(QueryError::AggregationError(
"$unwind requires a string or document".into(),
))
}
};
let mut result = Vec::new();
@@ -206,7 +231,11 @@ impl AggregationEngine {
fn stage_count(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
let field = match spec {
Bson::String(s) => s.clone(),
_ => return Err(QueryError::AggregationError("$count requires a string".into())),
_ => {
return Err(QueryError::AggregationError(
"$count requires a string".into(),
))
}
};
Ok(vec![bson::doc! { field: docs.len() as i64 }])
}
@@ -214,7 +243,11 @@ impl AggregationEngine {
fn stage_add_fields(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
let fields = match spec {
Bson::Document(d) => d,
_ => return Err(QueryError::AggregationError("$addFields requires a document".into())),
_ => {
return Err(QueryError::AggregationError(
"$addFields requires a document".into(),
))
}
};
Ok(docs
@@ -231,9 +264,16 @@ impl AggregationEngine {
fn stage_replace_root(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
let new_root_expr = match spec {
Bson::Document(d) => d.get("newRoot").cloned().unwrap_or(Bson::Document(d.clone())),
Bson::Document(d) => d
.get("newRoot")
.cloned()
.unwrap_or(Bson::Document(d.clone())),
Bson::String(s) => Bson::String(s.clone()),
_ => return Err(QueryError::AggregationError("$replaceRoot requires a document".into())),
_ => {
return Err(QueryError::AggregationError(
"$replaceRoot requires a document".into(),
))
}
};
let mut result = Vec::new();
@@ -246,6 +286,40 @@ impl AggregationEngine {
Ok(result)
}
fn stage_unset(docs: Vec<Document>, spec: &Bson) -> Result<Vec<Document>, QueryError> {
let fields: Vec<String> = match spec {
Bson::String(s) => vec![s.clone()],
Bson::Array(arr) => arr
.iter()
.map(|value| match value {
Bson::String(field) => Ok(field.clone()),
_ => Err(QueryError::AggregationError(
"$unset array entries must be strings".into(),
)),
})
.collect::<Result<Vec<_>, _>>()?,
_ => {
return Err(QueryError::AggregationError(
"$unset requires a string or array of strings".into(),
));
}
};
Ok(docs
.into_iter()
.map(|mut doc| {
for field in &fields {
if field.contains('.') {
remove_nested_value(&mut doc, field);
} else {
doc.remove(field);
}
}
doc
})
.collect())
}
fn stage_lookup(
docs: Vec<Document>,
spec: &Bson,
@@ -254,20 +328,29 @@ impl AggregationEngine {
) -> Result<Vec<Document>, QueryError> {
let lookup = match spec {
Bson::Document(d) => d,
_ => return Err(QueryError::AggregationError("$lookup requires a document".into())),
_ => {
return Err(QueryError::AggregationError(
"$lookup requires a document".into(),
))
}
};
let from = lookup.get_str("from")
let from = lookup
.get_str("from")
.map_err(|_| QueryError::AggregationError("$lookup requires 'from'".into()))?;
let local_field = lookup.get_str("localField")
let local_field = lookup
.get_str("localField")
.map_err(|_| QueryError::AggregationError("$lookup requires 'localField'".into()))?;
let foreign_field = lookup.get_str("foreignField")
let foreign_field = lookup
.get_str("foreignField")
.map_err(|_| QueryError::AggregationError("$lookup requires 'foreignField'".into()))?;
let as_field = lookup.get_str("as")
let as_field = lookup
.get_str("as")
.map_err(|_| QueryError::AggregationError("$lookup requires 'as'".into()))?;
let resolver = resolver
.ok_or_else(|| QueryError::AggregationError("$lookup requires a collection resolver".into()))?;
let resolver = resolver.ok_or_else(|| {
QueryError::AggregationError("$lookup requires a collection resolver".into())
})?;
let foreign_docs = resolver.resolve(db, from)?;
Ok(docs
@@ -299,7 +382,11 @@ impl AggregationEngine {
) -> Result<Vec<Document>, QueryError> {
let facets = match spec {
Bson::Document(d) => d,
_ => return Err(QueryError::AggregationError("$facet requires a document".into())),
_ => {
return Err(QueryError::AggregationError(
"$facet requires a document".into(),
))
}
};
let mut result = Document::new();
@@ -337,22 +424,32 @@ impl AggregationEngine {
let (coll, pipeline) = match spec {
Bson::String(s) => (s.as_str(), None),
Bson::Document(d) => {
let coll = d.get_str("coll")
.map_err(|_| QueryError::AggregationError("$unionWith requires 'coll'".into()))?;
let coll = d.get_str("coll").map_err(|_| {
QueryError::AggregationError("$unionWith requires 'coll'".into())
})?;
let pipeline = d.get_array("pipeline").ok().map(|arr| {
arr.iter()
.filter_map(|s| {
if let Bson::Document(d) = s { Some(d.clone()) } else { None }
if let Bson::Document(d) = s {
Some(d.clone())
} else {
None
}
})
.collect::<Vec<Document>>()
});
(coll, pipeline)
}
_ => return Err(QueryError::AggregationError("$unionWith requires a string or document".into())),
_ => {
return Err(QueryError::AggregationError(
"$unionWith requires a string or document".into(),
))
}
};
let resolver = resolver
.ok_or_else(|| QueryError::AggregationError("$unionWith requires a collection resolver".into()))?;
let resolver = resolver.ok_or_else(|| {
QueryError::AggregationError("$unionWith requires a collection resolver".into())
})?;
let mut other_docs = resolver.resolve(db, coll)?;
if let Some(p) = pipeline {
@@ -476,7 +573,11 @@ fn accumulate_min(docs: &[Document], expr: &Bson) -> Bson {
None => val,
Some(current) => {
if let (Some(cv), Some(vv)) = (bson_to_f64(&current), bson_to_f64(&val)) {
if vv < cv { val } else { current }
if vv < cv {
val
} else {
current
}
} else {
current
}
@@ -499,7 +600,11 @@ fn accumulate_max(docs: &[Document], expr: &Bson) -> Bson {
None => val,
Some(current) => {
if let (Some(cv), Some(vv)) = (bson_to_f64(&current), bson_to_f64(&val)) {
if vv > cv { val } else { current }
if vv > cv {
val
} else {
current
}
} else {
current
}
+107 -21
View File
@@ -1,7 +1,8 @@
use bson::{Bson, Document, doc};
use bson::{doc, Bson, Document};
use crate::aggregation::AggregationEngine;
use crate::error::QueryError;
use crate::field_path::{get_nested_value, set_nested_value, remove_nested_value};
use crate::field_path::{get_nested_value, remove_nested_value, set_nested_value};
use crate::matcher::QueryMatcher;
/// Update engine — applies update operators to documents.
@@ -56,6 +57,46 @@ impl UpdateEngine {
Ok(result)
}
/// Apply an aggregation pipeline update specification to a document.
pub fn apply_pipeline_update(
doc: &Document,
pipeline: &[Document],
) -> Result<Document, QueryError> {
if pipeline.is_empty() {
return Err(QueryError::InvalidUpdate(
"aggregation pipeline update cannot be empty".into(),
));
}
for stage in pipeline {
let (stage_name, _) = stage.iter().next().ok_or_else(|| {
QueryError::InvalidUpdate(
"aggregation pipeline update stages must not be empty".into(),
)
})?;
if !matches!(
stage_name.as_str(),
"$addFields" | "$set" | "$project" | "$unset" | "$replaceRoot" | "$replaceWith"
) {
return Err(QueryError::InvalidUpdate(format!(
"Unsupported aggregation pipeline update stage: {}",
stage_name
)));
}
}
let mut results = AggregationEngine::aggregate(vec![doc.clone()], pipeline, None, "")
.map_err(|e| QueryError::InvalidUpdate(e.to_string()))?;
match results.len() {
1 => Ok(results.remove(0)),
_ => Err(QueryError::InvalidUpdate(
"aggregation pipeline update must produce exactly one document".into(),
)),
}
}
/// Apply $setOnInsert fields (used during upsert only).
pub fn apply_set_on_insert(doc: &mut Document, fields: &Document) {
for (key, value) in fields {
@@ -252,16 +293,14 @@ impl UpdateEngine {
for (key, spec) in fields {
let value = match spec {
Bson::Boolean(true) => Bson::DateTime(now),
Bson::Document(d) => {
match d.get_str("$type").unwrap_or("date") {
"date" => Bson::DateTime(now),
"timestamp" => Bson::Timestamp(bson::Timestamp {
time: (now.timestamp_millis() / 1000) as u32,
increment: 0,
}),
_ => Bson::DateTime(now),
}
}
Bson::Document(d) => match d.get_str("$type").unwrap_or("date") {
"date" => Bson::DateTime(now),
"timestamp" => Bson::Timestamp(bson::Timestamp {
time: (now.timestamp_millis() / 1000) as u32,
increment: 0,
}),
_ => Bson::DateTime(now),
},
_ => continue,
};
@@ -282,7 +321,9 @@ impl UpdateEngine {
Bson::Document(d) if d.contains_key("$each") => {
let each = match d.get("$each") {
Some(Bson::Array(a)) => a.clone(),
_ => return Err(QueryError::InvalidUpdate("$each must be an array".into())),
_ => {
return Err(QueryError::InvalidUpdate("$each must be an array".into()))
}
};
let position = d.get("$position").and_then(|v| match v {
@@ -325,11 +366,21 @@ impl UpdateEngine {
continue;
}
match direction {
Bson::Int32(-1) | Bson::Int64(-1) => { arr.remove(0); }
Bson::Int32(1) | Bson::Int64(1) => { arr.pop(); }
Bson::Double(f) if *f == 1.0 => { arr.pop(); }
Bson::Double(f) if *f == -1.0 => { arr.remove(0); }
_ => { arr.pop(); }
Bson::Int32(-1) | Bson::Int64(-1) => {
arr.remove(0);
}
Bson::Int32(1) | Bson::Int64(1) => {
arr.pop();
}
Bson::Double(f) if *f == 1.0 => {
arr.pop();
}
Bson::Double(f) if *f == -1.0 => {
arr.remove(0);
}
_ => {
arr.pop();
}
}
}
}
@@ -455,7 +506,11 @@ impl UpdateEngine {
let ascending = *dir > 0;
arr.sort_by(|a, b| {
let ord = partial_cmp_bson(a, b);
if ascending { ord } else { ord.reverse() }
if ascending {
ord
} else {
ord.reverse()
}
});
}
Bson::Document(spec) => {
@@ -465,8 +520,16 @@ impl UpdateEngine {
Bson::Int32(n) => *n > 0,
_ => true,
};
let a_val = if let Bson::Document(d) = a { d.get(field) } else { None };
let b_val = if let Bson::Document(d) = b { d.get(field) } else { None };
let a_val = if let Bson::Document(d) = a {
d.get(field)
} else {
None
};
let b_val = if let Bson::Document(d) = b {
d.get(field)
} else {
None
};
let ord = match (a_val, b_val) {
(Some(av), Some(bv)) => partial_cmp_bson(av, bv),
(Some(_), None) => std::cmp::Ordering::Greater,
@@ -572,4 +635,27 @@ mod tests {
let tags = result.get_array("tags").unwrap();
assert_eq!(tags.len(), 2); // no duplicate
}
#[test]
fn test_pipeline_update() {
let doc = doc! { "_id": 1, "name": "Alice", "age": 30, "legacy": true };
let pipeline = vec![
doc! { "$set": { "displayName": "$name", "status": "updated" } },
doc! { "$unset": ["legacy"] },
];
let result = UpdateEngine::apply_pipeline_update(&doc, &pipeline).unwrap();
assert_eq!(result.get_str("displayName").unwrap(), "Alice");
assert_eq!(result.get_str("status").unwrap(), "updated");
assert!(result.get("legacy").is_none());
}
#[test]
fn test_pipeline_update_rejects_unsupported_stage() {
let doc = doc! { "_id": 1, "name": "Alice" };
let pipeline = vec![doc! { "$match": { "name": "Alice" } }];
let result = UpdateEngine::apply_pipeline_update(&doc, &pipeline);
assert!(matches!(result, Err(QueryError::InvalidUpdate(_))));
}
}
+79 -24
View File
@@ -178,7 +178,8 @@ impl CollectionState {
tracing::warn!("compaction failed for {:?}: {e}", self.coll_dir);
} else {
// Persist hint file after successful compaction to prevent stale hints
if let Err(e) = self.keydir.persist_to_hint_file(&self.hint_path()) {
let current_size = self.data_file_size.load(Ordering::Relaxed);
if let Err(e) = self.keydir.persist_to_hint_file(&self.hint_path(), current_size) {
tracing::warn!("failed to persist hint after compaction for {:?}: {e}", self.coll_dir);
}
}
@@ -186,6 +187,27 @@ impl CollectionState {
}
}
fn truncate_invalid_tail(
data_path: &PathBuf,
stats: &crate::keydir::BuildStats,
) -> StorageResult<()> {
if stats.invalid_tail_bytes == 0 {
return Ok(());
}
tracing::warn!(
path = %data_path.display(),
valid_data_end = stats.valid_data_end,
invalid_tail_bytes = stats.invalid_tail_bytes,
"truncating invalid data file tail"
);
let file = std::fs::OpenOptions::new().write(true).open(data_path)?;
file.set_len(stats.valid_data_end)?;
file.sync_all()?;
Ok(())
}
// ---------------------------------------------------------------------------
// Collection cache key: "db\0coll"
// ---------------------------------------------------------------------------
@@ -257,36 +279,61 @@ impl FileStorageAdapter {
// Try loading from hint file first, fall back to data file scan
let (keydir, dead_bytes, loaded_from_hint) = if hint_path.exists() && data_path.exists() {
match KeyDir::load_from_hint_file(&hint_path) {
Ok(Some(kd)) => {
// Validate hint against actual data file
let hint_valid = kd.validate_against_data_file(&data_path, 16)
.unwrap_or(false);
if hint_valid {
debug!("loaded KeyDir from hint file: {:?}", hint_path);
let file_size = std::fs::metadata(&data_path)
.map(|m| m.len())
.unwrap_or(FILE_HEADER_SIZE as u64);
let live_bytes: u64 = {
let mut total = 0u64;
kd.for_each(|_, e| total += e.record_len as u64);
total
};
let dead = file_size.saturating_sub(FILE_HEADER_SIZE as u64).saturating_sub(live_bytes);
(kd, dead, true)
} else {
tracing::warn!("hint file {:?} is stale, rebuilding from data file", hint_path);
let (kd, dead, _stats) = KeyDir::build_from_data_file(&data_path)?;
Ok(Some((kd, stored_size))) => {
let actual_size = std::fs::metadata(&data_path)
.map(|m| m.len())
.unwrap_or(0);
// Check if data.rdb changed since the hint was written.
// If stored_size is 0, this is an old-format hint without size tracking.
let size_matches = stored_size > 0 && stored_size == actual_size;
if !size_matches {
// data.rdb size differs from hint snapshot — records were appended
// (inserts, tombstones) after the hint was written. Full scan required
// to pick up tombstones that would otherwise be invisible.
if stored_size == 0 {
debug!("hint file {:?} has no size tracking, rebuilding from data file", hint_path);
} else {
tracing::warn!(
"hint file {:?} is stale: data size changed ({} -> {}), rebuilding",
hint_path, stored_size, actual_size
);
}
let (kd, dead, stats) = KeyDir::build_from_data_file(&data_path)?;
truncate_invalid_tail(&data_path, &stats)?;
(kd, dead, false)
} else {
// Size matches — validate entry integrity with spot-checks
let hint_valid = kd.validate_against_data_file(&data_path, 16)
.unwrap_or(false);
if hint_valid {
debug!("loaded KeyDir from hint file: {:?}", hint_path);
let live_bytes: u64 = {
let mut total = 0u64;
kd.for_each(|_, e| total += e.record_len as u64);
total
};
let dead = actual_size.saturating_sub(FILE_HEADER_SIZE as u64).saturating_sub(live_bytes);
(kd, dead, true)
} else {
tracing::warn!("hint file {:?} failed validation, rebuilding from data file", hint_path);
let (kd, dead, stats) = KeyDir::build_from_data_file(&data_path)?;
truncate_invalid_tail(&data_path, &stats)?;
(kd, dead, false)
}
}
}
_ => {
debug!("hint file invalid, rebuilding KeyDir from data file");
let (kd, dead, _stats) = KeyDir::build_from_data_file(&data_path)?;
let (kd, dead, stats) = KeyDir::build_from_data_file(&data_path)?;
truncate_invalid_tail(&data_path, &stats)?;
(kd, dead, false)
}
}
} else if data_path.exists() {
let (kd, dead, _stats) = KeyDir::build_from_data_file(&data_path)?;
let (kd, dead, stats) = KeyDir::build_from_data_file(&data_path)?;
truncate_invalid_tail(&data_path, &stats)?;
(kd, dead, false)
} else {
(KeyDir::new(), 0, false)
@@ -482,6 +529,13 @@ impl StorageAdapter for FileStorageAdapter {
"FileStorageAdapter initialization complete"
);
// Run compaction on all collections that need it (dead weight from before crash)
for entry in self.collections.iter() {
let state = entry.value();
let _guard = state.write_lock.lock().unwrap();
state.try_compact();
}
// Start periodic compaction task (runs every 24 hours)
{
let collections = self.collections.clone();
@@ -510,10 +564,11 @@ impl StorageAdapter for FileStorageAdapter {
handle.abort();
}
// Persist all KeyDir hint files
// Persist all KeyDir hint files with current data file sizes
for entry in self.collections.iter() {
let state = entry.value();
let _ = state.keydir.persist_to_hint_file(&state.hint_path());
let current_size = state.data_file_size.load(Ordering::Relaxed);
let _ = state.keydir.persist_to_hint_file(&state.hint_path(), current_size);
}
debug!("FileStorageAdapter closed");
Ok(())
+62 -14
View File
@@ -14,7 +14,7 @@ use dashmap::DashMap;
use crate::error::{StorageError, StorageResult};
use crate::record::{
DataRecord, FileHeader, FileType, RecordScanner, FILE_HEADER_SIZE, FORMAT_VERSION,
DataRecord, FileHeader, FileType, FILE_HEADER_SIZE, FORMAT_VERSION,
};
// ---------------------------------------------------------------------------
@@ -49,6 +49,10 @@ pub struct BuildStats {
pub tombstones: u64,
/// Number of records superseded by a later write for the same key.
pub superseded_records: u64,
/// Byte offset immediately after the last valid record.
pub valid_data_end: u64,
/// Number of invalid tail bytes after the last valid record.
pub invalid_tail_bytes: u64,
}
// ---------------------------------------------------------------------------
@@ -137,6 +141,7 @@ impl KeyDir {
/// stale records (superseded by later writes or tombstoned).
pub fn build_from_data_file(path: &Path) -> StorageResult<(Self, u64, BuildStats)> {
let file = std::fs::File::open(path)?;
let file_len = file.metadata()?.len();
let mut reader = BufReader::new(file);
// Read and validate file header
@@ -152,13 +157,49 @@ impl KeyDir {
let keydir = KeyDir::new();
let mut dead_bytes: u64 = 0;
let mut stats = BuildStats::default();
let mut stats = BuildStats {
valid_data_end: FILE_HEADER_SIZE as u64,
..BuildStats::default()
};
let scanner = RecordScanner::new(reader, FILE_HEADER_SIZE as u64);
for result in scanner {
let (offset, record) = result?;
loop {
let record_offset = stats.valid_data_end;
let (record, disk_size) = match DataRecord::decode_from(&mut reader) {
Ok(Some((record, disk_size))) => (record, disk_size),
Ok(None) => {
if file_len > record_offset {
stats.invalid_tail_bytes = file_len - record_offset;
}
break;
}
Err(StorageError::IoError(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
stats.invalid_tail_bytes = file_len.saturating_sub(record_offset);
break;
}
Err(StorageError::ChecksumMismatch { expected, actual }) => {
tracing::warn!(
path = %path.display(),
offset = record_offset,
"stopping data file scan at checksum mismatch: expected 0x{expected:08X}, got 0x{actual:08X}"
);
stats.invalid_tail_bytes = file_len.saturating_sub(record_offset);
break;
}
Err(StorageError::CorruptRecord(message)) => {
tracing::warn!(
path = %path.display(),
offset = record_offset,
"stopping data file scan at corrupt record: {message}"
);
stats.invalid_tail_bytes = file_len.saturating_sub(record_offset);
break;
}
Err(e) => return Err(e),
};
stats.valid_data_end += disk_size as u64;
let is_tombstone = record.is_tombstone();
let disk_size = record.disk_size() as u32;
let disk_size = disk_size as u32;
let value_len = record.value.len() as u32;
let timestamp = record.timestamp;
let key = String::from_utf8(record.key)
@@ -175,7 +216,7 @@ impl KeyDir {
dead_bytes += disk_size as u64;
} else {
let entry = KeyDirEntry {
offset,
offset: record_offset,
record_len: disk_size,
value_len,
timestamp,
@@ -198,14 +239,17 @@ impl KeyDir {
/// Persist the KeyDir to a hint file for fast restart.
///
/// `data_file_size` is the current size of data.rdb — stored in the hint header
/// so that on next load we can detect if data.rdb changed (stale hint).
///
/// Hint file format (after the 64-byte file header):
/// For each entry: [key_len:u32 LE][key bytes][offset:u64 LE][record_len:u32 LE][value_len:u32 LE][timestamp:u64 LE]
pub fn persist_to_hint_file(&self, path: &Path) -> StorageResult<()> {
pub fn persist_to_hint_file(&self, path: &Path, data_file_size: u64) -> StorageResult<()> {
let file = std::fs::File::create(path)?;
let mut writer = BufWriter::new(file);
// Write file header
let hdr = FileHeader::new(FileType::Hint);
// Write file header with data_file_size for staleness detection
let hdr = FileHeader::new_hint(data_file_size);
writer.write_all(&hdr.encode())?;
// Write entries
@@ -225,7 +269,9 @@ impl KeyDir {
}
/// Load a KeyDir from a hint file. Returns None if the file doesn't exist.
pub fn load_from_hint_file(path: &Path) -> StorageResult<Option<Self>> {
/// Returns `(keydir, stored_data_file_size)` where `stored_data_file_size` is the
/// data.rdb size recorded when the hint was written (0 = old format, unknown).
pub fn load_from_hint_file(path: &Path) -> StorageResult<Option<(Self, u64)>> {
if !path.exists() {
return Ok(None);
}
@@ -254,6 +300,7 @@ impl KeyDir {
)));
}
let stored_data_file_size = hdr.data_file_size;
let keydir = KeyDir::new();
loop {
@@ -292,7 +339,7 @@ impl KeyDir {
);
}
Ok(Some(keydir))
Ok(Some((keydir, stored_data_file_size)))
}
// -----------------------------------------------------------------------
@@ -517,9 +564,10 @@ mod tests {
},
);
kd.persist_to_hint_file(&hint_path).unwrap();
let loaded = KeyDir::load_from_hint_file(&hint_path).unwrap().unwrap();
kd.persist_to_hint_file(&hint_path, 12345).unwrap();
let (loaded, stored_size) = KeyDir::load_from_hint_file(&hint_path).unwrap().unwrap();
assert_eq!(stored_size, 12345);
assert_eq!(loaded.len(), 2);
let e1 = loaded.get("doc1").unwrap();
assert_eq!(e1.offset, 64);
+21 -1
View File
@@ -79,6 +79,9 @@ pub struct FileHeader {
pub file_type: FileType,
pub flags: u32,
pub created_ms: u64,
/// For hint files: the data.rdb file size at the time the hint was written.
/// Used to detect stale hints after ungraceful shutdown. 0 = unknown (old format).
pub data_file_size: u64,
}
impl FileHeader {
@@ -89,6 +92,18 @@ impl FileHeader {
file_type,
flags: 0,
created_ms: now_ms(),
data_file_size: 0,
}
}
/// Create a new hint header that records the data file size.
pub fn new_hint(data_file_size: u64) -> Self {
Self {
version: FORMAT_VERSION,
file_type: FileType::Hint,
flags: 0,
created_ms: now_ms(),
data_file_size,
}
}
@@ -100,7 +115,8 @@ impl FileHeader {
buf[10] = self.file_type as u8;
buf[11..15].copy_from_slice(&self.flags.to_le_bytes());
buf[15..23].copy_from_slice(&self.created_ms.to_le_bytes());
// bytes 23..64 are reserved (zeros)
buf[23..31].copy_from_slice(&self.data_file_size.to_le_bytes());
// bytes 31..64 are reserved (zeros)
buf
}
@@ -127,11 +143,15 @@ impl FileHeader {
let created_ms = u64::from_le_bytes([
buf[15], buf[16], buf[17], buf[18], buf[19], buf[20], buf[21], buf[22],
]);
let data_file_size = u64::from_le_bytes([
buf[23], buf[24], buf[25], buf[26], buf[27], buf[28], buf[29], buf[30],
]);
Ok(Self {
version,
file_type,
flags,
created_ms,
data_file_size,
})
}
}
+7 -1
View File
@@ -295,7 +295,13 @@ fn validate_collection(db: &str, coll: &str, coll_dir: &Path) -> CollectionRepor
// Validate hint file if present
if hint_path.exists() {
match KeyDir::load_from_hint_file(&hint_path) {
Ok(Some(hint_kd)) => {
Ok(Some((hint_kd, stored_size))) => {
if stored_size > 0 && stored_size != report.data_file_size {
report.errors.push(format!(
"hint file is stale: recorded data size {} but actual is {}",
stored_size, report.data_file_size
));
}
// Check for orphaned entries: keys in hint but not live in data
hint_kd.for_each(|key, _entry| {
if !live_ids.contains(key) {
+10
View File
@@ -170,6 +170,16 @@ impl SessionEngine {
}
count
}
/// Number of currently tracked logical sessions.
pub fn len(&self) -> usize {
self.sessions.len()
}
/// Whether there are no tracked logical sessions.
pub fn is_empty(&self) -> bool {
self.sessions.is_empty()
}
}
impl Default for SessionEngine {
+104 -11
View File
@@ -18,7 +18,7 @@ pub enum TransactionStatus {
}
/// Describes a write operation within a transaction.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteOp {
Insert,
Update,
@@ -137,6 +137,25 @@ impl TransactionEngine {
Ok(())
}
/// Remove an active transaction and return its buffered state for an
/// external committer that needs to update secondary indexes and oplogs.
pub fn take_transaction(&self, txn_id: &str) -> TransactionResult<TransactionState> {
let state = self
.transactions
.remove(txn_id)
.map(|(_, s)| s)
.ok_or_else(|| TransactionError::NotFound(txn_id.to_string()))?;
if state.status != TransactionStatus::Active {
return Err(TransactionError::InvalidState(format!(
"transaction {} is {:?}, cannot commit",
txn_id, state.status
)));
}
Ok(state)
}
/// Abort a transaction, discarding all buffered writes.
pub fn abort_transaction(&self, txn_id: &str) -> TransactionResult<()> {
let mut state = self
@@ -191,19 +210,32 @@ impl TransactionEngine {
original: Option<Document>,
) {
if let Some(mut state) = self.transactions.get_mut(txn_id) {
let entry = WriteEntry {
op,
doc,
original_doc: original,
};
state
.write_set
.entry(ns.to_string())
.or_default()
.insert(doc_id.to_string(), entry);
let writes = state.write_set.entry(ns.to_string()).or_default();
if let Some(existing) = writes.remove(doc_id) {
if let Some(merged) = merge_write_entry(existing, op, doc, original) {
writes.insert(doc_id.to_string(), merged);
}
} else {
writes.insert(
doc_id.to_string(),
WriteEntry {
op,
doc,
original_doc: original,
},
);
}
}
}
/// Return true if the transaction already has a base snapshot for a namespace.
pub fn has_snapshot(&self, txn_id: &str, ns: &str) -> bool {
self.transactions
.get(txn_id)
.map(|state| state.snapshots.contains_key(ns))
.unwrap_or(false)
}
/// Get a snapshot of documents for a namespace within a transaction,
/// applying the write overlay (inserts, updates, deletes) on top.
pub fn get_snapshot(&self, txn_id: &str, ns: &str) -> Option<Vec<Document>> {
@@ -270,6 +302,67 @@ impl TransactionEngine {
state.snapshots.insert(ns.to_string(), docs);
}
}
/// Number of currently active transactions.
pub fn len(&self) -> usize {
self.transactions.len()
}
/// Whether there are no active transactions.
pub fn is_empty(&self) -> bool {
self.transactions.is_empty()
}
}
fn merge_write_entry(
existing: WriteEntry,
next_op: WriteOp,
next_doc: Option<Document>,
next_original: Option<Document>,
) -> Option<WriteEntry> {
match (existing.op, next_op) {
(WriteOp::Insert, WriteOp::Update) => Some(WriteEntry {
op: WriteOp::Insert,
doc: next_doc,
original_doc: None,
}),
(WriteOp::Insert, WriteOp::Delete) => None,
(WriteOp::Insert, WriteOp::Insert) => Some(WriteEntry {
op: WriteOp::Insert,
doc: next_doc,
original_doc: None,
}),
(WriteOp::Update, WriteOp::Update) => Some(WriteEntry {
op: WriteOp::Update,
doc: next_doc,
original_doc: existing.original_doc,
}),
(WriteOp::Update, WriteOp::Delete) => Some(WriteEntry {
op: WriteOp::Delete,
doc: None,
original_doc: existing.original_doc,
}),
(WriteOp::Update, WriteOp::Insert) => Some(WriteEntry {
op: WriteOp::Update,
doc: next_doc,
original_doc: existing.original_doc,
}),
(WriteOp::Delete, WriteOp::Insert) => Some(WriteEntry {
op: if existing.original_doc.is_some() {
WriteOp::Update
} else {
WriteOp::Insert
},
doc: next_doc,
original_doc: existing.original_doc,
}),
(WriteOp::Delete, WriteOp::Update) => Some(WriteEntry {
op: WriteOp::Update,
doc: next_doc,
original_doc: existing.original_doc.or(next_original),
}),
(WriteOp::Delete, WriteOp::Delete) => Some(existing),
}
}
impl Default for TransactionEngine {
+3
View File
@@ -21,9 +21,12 @@ rustdb-query = { workspace = true }
rustdb-storage = { workspace = true }
rustdb-index = { workspace = true }
rustdb-txn = { workspace = true }
rustdb-auth = { workspace = true }
rustdb-commands = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tokio-rustls = { workspace = true }
rustls-pemfile = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
clap = { workspace = true }
+109 -9
View File
@@ -1,10 +1,12 @@
pub mod management;
use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use anyhow::{Context, Result};
use dashmap::DashMap;
use tokio::net::TcpListener;
#[cfg(unix)]
@@ -12,13 +14,17 @@ use tokio::net::UnixListener;
use tokio_util::codec::Framed;
use tokio_util::sync::CancellationToken;
use rustdb_config::{RustDbOptions, StorageType};
use rustdb_config::{RustDbOptions, StorageType, TlsOptions};
use rustdb_wire::{WireCodec, OP_QUERY};
use rustdb_wire::{encode_op_msg_response, encode_op_reply_response};
use rustdb_storage::{StorageAdapter, MemoryStorageAdapter, FileStorageAdapter, OpLog};
use rustdb_index::{IndexEngine, IndexOptions};
use rustdb_txn::{TransactionEngine, SessionEngine};
use rustdb_commands::{CommandRouter, CommandContext};
use rustdb_auth::AuthEngine;
use rustdb_commands::{CommandRouter, CommandContext, ConnectionState};
use tokio_rustls::rustls::{RootCertStore, ServerConfig};
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::TlsAcceptor;
/// The main RustDb server.
pub struct RustDb {
@@ -150,6 +156,8 @@ impl RustDb {
}
}
let auth = Arc::new(AuthEngine::from_options(&options.auth)?);
let ctx = Arc::new(CommandContext {
storage,
indexes,
@@ -158,6 +166,7 @@ impl RustDb {
cursors: Arc::new(DashMap::new()),
start_time: std::time::Instant::now(),
oplog: Arc::new(OpLog::new()),
auth,
});
let router = Arc::new(CommandRouter::new(ctx.clone()));
@@ -215,7 +224,12 @@ impl RustDb {
} else {
let addr = format!("{}:{}", self.options.host, self.options.port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!("RustDb listening on {}", addr);
let tls_acceptor = if self.options.tls.enabled {
Some(build_tls_acceptor(&self.options.tls)?)
} else {
None
};
tracing::info!(tls = self.options.tls.enabled, "RustDb listening on {}", addr);
let handle = tokio::spawn(async move {
loop {
@@ -226,9 +240,21 @@ impl RustDb {
Ok((stream, _addr)) => {
let _ = stream.set_nodelay(true);
let router = router.clone();
tokio::spawn(async move {
handle_connection(stream, router).await;
});
match tls_acceptor.clone() {
Some(acceptor) => {
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => handle_connection(tls_stream, router).await,
Err(e) => tracing::debug!("TLS handshake failed: {}", e),
}
});
}
None => {
tokio::spawn(async move {
handle_connection(stream, router).await;
});
}
}
}
Err(e) => {
tracing::error!("Accept error: {}", e);
@@ -275,14 +301,88 @@ impl RustDb {
}
}
fn build_tls_acceptor(options: &TlsOptions) -> Result<TlsAcceptor> {
let cert_path = options
.cert_path
.as_deref()
.context("tls.certPath is required when tls.enabled is true")?;
let key_path = options
.key_path
.as_deref()
.context("tls.keyPath is required when tls.enabled is true")?;
let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?;
let config = if options.require_client_cert {
let ca_path = options
.ca_path
.as_deref()
.context("tls.caPath is required when tls.requireClientCert is true")?;
let roots = load_root_store(ca_path)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(roots))
.build()
.context("failed to build TLS client certificate verifier")?;
ServerConfig::builder()
.with_client_cert_verifier(verifier)
.with_single_cert(certs, key)
.context("failed to build TLS server configuration")?
} else {
ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.context("failed to build TLS server configuration")?
};
Ok(TlsAcceptor::from(Arc::new(config)))
}
fn load_certs(path: &str) -> Result<Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>> {
let file = File::open(path).with_context(|| format!("failed to open TLS certificate file '{}'", path))?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader)
.collect::<std::result::Result<Vec<_>, _>>()
.with_context(|| format!("failed to parse TLS certificate file '{}'", path))?;
if certs.is_empty() {
anyhow::bail!("TLS certificate file '{}' did not contain any certificates", path);
}
Ok(certs)
}
fn load_private_key(path: &str) -> Result<tokio_rustls::rustls::pki_types::PrivateKeyDer<'static>> {
let file = File::open(path).with_context(|| format!("failed to open TLS private key file '{}'", path))?;
let mut reader = BufReader::new(file);
rustls_pemfile::private_key(&mut reader)
.with_context(|| format!("failed to parse TLS private key file '{}'", path))?
.with_context(|| format!("TLS private key file '{}' did not contain a private key", path))
}
fn load_root_store(path: &str) -> Result<RootCertStore> {
let mut roots = RootCertStore::empty();
for cert in load_certs(path)? {
roots
.add(cert)
.with_context(|| format!("failed to add TLS client CA certificate from '{}'", path))?;
}
if roots.is_empty() {
anyhow::bail!("TLS client CA file '{}' did not contain usable certificates", path);
}
Ok(roots)
}
/// Handle a single client connection using the wire protocol codec.
async fn handle_connection<S>(stream: S, router: Arc<CommandRouter>)
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
use futures_util::{SinkExt, StreamExt};
let mut framed = Framed::new(stream, WireCodec);
let mut connection = ConnectionState::new();
while let Some(result) = framed.next().await {
match result {
@@ -290,7 +390,7 @@ where
let request_id = parsed_cmd.request_id;
let op_code = parsed_cmd.op_code;
let response_doc = router.route(&parsed_cmd).await;
let response_doc = router.route(&parsed_cmd, &mut connection).await;
let response_id = next_request_id();
+7
View File
@@ -167,6 +167,9 @@ async fn handle_start(
Ok(o) => o,
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
};
if let Err(e) = options.validate() {
return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e));
}
let connection_uri = options.connection_uri();
@@ -252,6 +255,10 @@ async fn handle_get_metrics(
"collections": total_collections,
"oplogEntries": oplog_stats.total_entries,
"oplogCurrentSeq": oplog_stats.current_seq,
"sessions": ctx.sessions.len(),
"activeTransactions": ctx.transactions.len(),
"authEnabled": ctx.auth.enabled(),
"authUsers": ctx.auth.user_count(),
"uptimeSeconds": uptime_secs,
}),
)
+178
View File
@@ -0,0 +1,178 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import { MongoClient } from 'mongodb';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
let server: smartdb.SmartdbServer;
let authedClient: MongoClient;
let openClient: MongoClient;
let readerClient: MongoClient;
let tmpDir: string;
let usersPath: string;
function makeTmpDir(): string {
return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-auth-test-'));
}
function cleanTmpDir(dir: string): void {
if (fs.existsSync(dir)) {
fs.rmSync(dir, { recursive: true, force: true });
}
}
tap.test('auth: should start server with SCRAM-SHA-256 auth enabled', async () => {
tmpDir = makeTmpDir();
usersPath = path.join(tmpDir, 'users.json');
server = new smartdb.SmartdbServer({
port: 27118,
auth: {
enabled: true,
usersPath,
scramIterations: 4096,
users: [
{
username: 'root',
password: 'secret',
database: 'admin',
roles: ['root'],
},
],
},
});
await server.start();
expect(server.running).toBeTrue();
});
tap.test('auth: should reject protected commands before authentication', async () => {
openClient = new MongoClient('mongodb://127.0.0.1:27118', {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await openClient.connect();
let threw = false;
try {
await openClient.db('admin').command({ ping: 1 });
} catch (err: any) {
threw = true;
expect(err.code).toEqual(13);
}
expect(threw).toBeTrue();
});
tap.test('auth: should reject invalid credentials', async () => {
const badClient = new MongoClient('mongodb://root:wrong@127.0.0.1:27118/admin?authSource=admin', {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
let threw = false;
try {
await badClient.connect();
await badClient.db('admin').command({ ping: 1 });
} catch {
threw = true;
} finally {
await badClient.close().catch(() => undefined);
}
expect(threw).toBeTrue();
});
tap.test('auth: should authenticate valid credentials', async () => {
authedClient = new MongoClient('mongodb://root:secret@127.0.0.1:27118/admin?authSource=admin', {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await authedClient.connect();
const result = await authedClient.db('admin').command({ ping: 1 });
expect(result.ok).toEqual(1);
const status = await authedClient.db('admin').command({ connectionStatus: 1 });
expect(status.ok).toEqual(1);
expect(status.authInfo.authenticatedUsers[0]).toEqual({ user: 'root', db: 'admin' });
expect(status.authInfo.authenticatedUserRoles[0]).toEqual({ role: 'root', db: 'admin' });
});
tap.test('auth: should allow CRUD after authentication', async () => {
const coll = authedClient.db('securedb').collection('notes');
const inserted = await coll.insertOne({ title: 'enterprise auth' });
expect(inserted.acknowledged).toBeTrue();
const doc = await coll.findOne({ _id: inserted.insertedId });
expect(doc).toBeTruthy();
expect(doc!.title).toEqual('enterprise auth');
});
tap.test('auth: root should create a read-only user', async () => {
const result = await authedClient.db('admin').command({
createUser: 'reader',
pwd: 'readpass',
roles: [{ role: 'read', db: 'securedb' }],
});
expect(result.ok).toEqual(1);
const usersInfo = await authedClient.db('admin').command({ usersInfo: 'reader' });
expect(usersInfo.ok).toEqual(1);
expect(usersInfo.users.length).toEqual(1);
expect(usersInfo.users[0].user).toEqual('reader');
});
tap.test('auth: read-only user should read but not write', async () => {
readerClient = new MongoClient('mongodb://reader:readpass@127.0.0.1:27118/admin?authSource=admin', {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await readerClient.connect();
const doc = await readerClient.db('securedb').collection('notes').findOne({ title: 'enterprise auth' });
expect(doc).toBeTruthy();
let threw = false;
try {
await readerClient.db('securedb').collection('notes').insertOne({ title: 'denied write' });
} catch (err: any) {
threw = true;
expect(err.code).toEqual(13);
}
expect(threw).toBeTrue();
});
tap.test('auth: persisted users should survive server restart', async () => {
await readerClient.close();
await authedClient.close();
await server.stop();
// Simulates a crash after writing the temporary auth metadata file but before rename.
fs.writeFileSync(path.join(tmpDir, 'users.tmp'), '{ invalid json');
server = new smartdb.SmartdbServer({
port: 27118,
auth: {
enabled: true,
usersPath,
users: [],
scramIterations: 4096,
},
});
await server.start();
readerClient = new MongoClient('mongodb://reader:readpass@127.0.0.1:27118/admin?authSource=admin', {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await readerClient.connect();
const result = await readerClient.db('admin').command({ ping: 1 });
expect(result.ok).toEqual(1);
});
tap.test('auth: cleanup', async () => {
await openClient.close();
await readerClient.close();
await server.stop();
expect(server.running).toBeFalse();
cleanTmpDir(tmpDir);
});
export default tap.start();
+91
View File
@@ -0,0 +1,91 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import { MongoClient, Db } from 'mongodb';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
let tmpDir: string;
let localDb: smartdb.LocalSmartDb;
let client: MongoClient;
let db: Db;
let dataPath: string;
let corruptedSize: number;
function makeTmpDir(): string {
return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-crash-test-'));
}
function cleanTmpDir(dir: string): void {
if (fs.existsSync(dir)) {
fs.rmSync(dir, { recursive: true, force: true });
}
}
async function startAndConnect(): Promise<void> {
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('crashtest');
}
tap.test('crash-recovery: create baseline data', async () => {
tmpDir = makeTmpDir();
await startAndConnect();
await db.collection('docs').insertMany([
{ key: 'a', value: 1 },
{ key: 'b', value: 2 },
{ key: 'c', value: 3 },
]);
await client.close();
await localDb.stop();
dataPath = path.join(tmpDir, 'crashtest', 'docs', 'data.rdb');
expect(fs.existsSync(dataPath)).toBeTrue();
});
tap.test('crash-recovery: append a torn final record', async () => {
const data = fs.readFileSync(dataPath);
const partialRecord = data.subarray(64, 94);
expect(partialRecord.length).toEqual(30);
fs.appendFileSync(dataPath, partialRecord);
corruptedSize = fs.statSync(dataPath).size;
expect(corruptedSize).toEqual(data.length + partialRecord.length);
});
tap.test('crash-recovery: restart truncates invalid tail and preserves valid records', async () => {
await startAndConnect();
const repairedSize = fs.statSync(dataPath).size;
expect(repairedSize < corruptedSize).toBeTrue();
const docs = await db.collection('docs').find({}).sort({ key: 1 }).toArray();
expect(docs.map(doc => doc.key)).toEqual(['a', 'b', 'c']);
});
tap.test('crash-recovery: future writes remain durable after tail repair', async () => {
await db.collection('docs').insertOne({ key: 'd', value: 4 });
expect(await db.collection('docs').countDocuments()).toEqual(4);
await client.close();
await localDb.stop();
await startAndConnect();
const docs = await db.collection('docs').find({}).sort({ key: 1 }).toArray();
expect(docs.map(doc => doc.key)).toEqual(['a', 'b', 'c', 'd']);
});
tap.test('crash-recovery: cleanup', async () => {
await client.close();
await localDb.stop();
cleanTmpDir(tmpDir);
});
export default tap.start();
+191
View File
@@ -0,0 +1,191 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import { MongoClient, Db } from 'mongodb';
import * as fs from 'fs';
import * as path from 'path';
import * as os from 'os';
// ---------------------------------------------------------------------------
// Test: Deletes persist across restart (tombstone + hint staleness detection)
// Covers: append_tombstone to data.rdb, hint file data_file_size tracking,
// stale hint detection on restart
// ---------------------------------------------------------------------------
let tmpDir: string;
let localDb: smartdb.LocalSmartDb;
let client: MongoClient;
let db: Db;
function makeTmpDir(): string {
return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-delete-test-'));
}
function cleanTmpDir(dir: string): void {
if (fs.existsSync(dir)) {
fs.rmSync(dir, { recursive: true, force: true });
}
}
// ============================================================================
// Setup
// ============================================================================
tap.test('setup: start local db and insert documents', async () => {
tmpDir = makeTmpDir();
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('deletetest');
const coll = db.collection('items');
await coll.insertMany([
{ name: 'keep-1', value: 100 },
{ name: 'keep-2', value: 200 },
{ name: 'delete-me', value: 999 },
{ name: 'keep-3', value: 300 },
]);
const count = await coll.countDocuments();
expect(count).toEqual(4);
});
// ============================================================================
// Delete and verify
// ============================================================================
tap.test('delete-persistence: delete a document', async () => {
const coll = db.collection('items');
const result = await coll.deleteOne({ name: 'delete-me' });
expect(result.deletedCount).toEqual(1);
const remaining = await coll.countDocuments();
expect(remaining).toEqual(3);
const deleted = await coll.findOne({ name: 'delete-me' });
expect(deleted).toBeNull();
});
// ============================================================================
// Graceful restart: delete survives
// ============================================================================
tap.test('delete-persistence: graceful stop and restart', async () => {
await client.close();
await localDb.stop(); // graceful — writes hint file
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('deletetest');
});
tap.test('delete-persistence: deleted doc stays deleted after graceful restart', async () => {
const coll = db.collection('items');
const count = await coll.countDocuments();
expect(count).toEqual(3);
const deleted = await coll.findOne({ name: 'delete-me' });
expect(deleted).toBeNull();
// The remaining docs are intact
const keep1 = await coll.findOne({ name: 'keep-1' });
expect(keep1).toBeTruthy();
expect(keep1!.value).toEqual(100);
});
// ============================================================================
// Simulate ungraceful restart: delete after hint write, then restart
// The hint file data_file_size check should detect the stale hint
// ============================================================================
tap.test('delete-persistence: insert and delete more docs, then restart', async () => {
const coll = db.collection('items');
// Insert a new doc
await coll.insertOne({ name: 'temporary', value: 777 });
expect(await coll.countDocuments()).toEqual(4);
// Delete it
await coll.deleteOne({ name: 'temporary' });
expect(await coll.countDocuments()).toEqual(3);
const gone = await coll.findOne({ name: 'temporary' });
expect(gone).toBeNull();
});
tap.test('delete-persistence: stop and restart again', async () => {
await client.close();
await localDb.stop();
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('deletetest');
});
tap.test('delete-persistence: all deletes survived second restart', async () => {
const coll = db.collection('items');
const count = await coll.countDocuments();
expect(count).toEqual(3);
// Both deletes are permanent
expect(await coll.findOne({ name: 'delete-me' })).toBeNull();
expect(await coll.findOne({ name: 'temporary' })).toBeNull();
// Survivors intact
const names = (await coll.find({}).toArray()).map(d => d.name).sort();
expect(names).toEqual(['keep-1', 'keep-2', 'keep-3']);
});
// ============================================================================
// Delete all docs and verify empty after restart
// ============================================================================
tap.test('delete-persistence: delete all remaining docs', async () => {
const coll = db.collection('items');
await coll.deleteMany({});
expect(await coll.countDocuments()).toEqual(0);
});
tap.test('delete-persistence: restart with empty collection', async () => {
await client.close();
await localDb.stop();
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('deletetest');
});
tap.test('delete-persistence: collection is empty after restart', async () => {
const coll = db.collection('items');
const count = await coll.countDocuments();
expect(count).toEqual(0);
});
// ============================================================================
// Cleanup
// ============================================================================
tap.test('delete-persistence: cleanup', async () => {
await client.close();
await localDb.stop();
cleanTmpDir(tmpDir);
});
export default tap.start();
+126
View File
@@ -0,0 +1,126 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import { MongoClient, Db } from 'mongodb';
import * as fs from 'fs';
import * as path from 'path';
import * as os from 'os';
// ---------------------------------------------------------------------------
// Test: Missing data.rdb header recovery + startup logging
// Covers: ensure_data_header, BuildStats, info-level startup logging
// ---------------------------------------------------------------------------
let tmpDir: string;
let localDb: smartdb.LocalSmartDb;
let client: MongoClient;
let db: Db;
function makeTmpDir(): string {
return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-header-test-'));
}
function cleanTmpDir(dir: string): void {
if (fs.existsSync(dir)) {
fs.rmSync(dir, { recursive: true, force: true });
}
}
// ============================================================================
// Setup: create data, then corrupt it
// ============================================================================
tap.test('setup: start, insert data, stop', async () => {
tmpDir = makeTmpDir();
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('headertest');
const coll = db.collection('docs');
await coll.insertMany([
{ key: 'a', val: 1 },
{ key: 'b', val: 2 },
{ key: 'c', val: 3 },
]);
await client.close();
await localDb.stop();
});
// ============================================================================
// Delete hint file and restart: should rebuild from data.rdb scan
// ============================================================================
tap.test('header-recovery: delete hint file and restart', async () => {
// Find and delete hint files
const dbDir = path.join(tmpDir, 'headertest', 'docs');
const hintPath = path.join(dbDir, 'keydir.hint');
if (fs.existsSync(hintPath)) {
fs.unlinkSync(hintPath);
}
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('headertest');
});
tap.test('header-recovery: data intact after hint deletion', async () => {
const coll = db.collection('docs');
const count = await coll.countDocuments();
expect(count).toEqual(3);
const a = await coll.findOne({ key: 'a' });
expect(a!.val).toEqual(1);
});
// ============================================================================
// Write new data after restart, stop, restart again
// ============================================================================
tap.test('header-recovery: write after hint-less restart', async () => {
const coll = db.collection('docs');
await coll.insertOne({ key: 'd', val: 4 });
expect(await coll.countDocuments()).toEqual(4);
});
tap.test('header-recovery: restart and verify all data', async () => {
await client.close();
await localDb.stop();
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('headertest');
const coll = db.collection('docs');
const count = await coll.countDocuments();
expect(count).toEqual(4);
const keys = (await coll.find({}).toArray()).map(d => d.key).sort();
expect(keys).toEqual(['a', 'b', 'c', 'd']);
});
// ============================================================================
// Cleanup
// ============================================================================
tap.test('header-recovery: cleanup', async () => {
await client.close();
await localDb.stop();
cleanTmpDir(tmpDir);
});
export default tap.start();
+83 -1
View File
@@ -1,6 +1,6 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import { MongoClient, Db, Collection } from 'mongodb';
import { MongoClient, Db, Collection, ObjectId } from 'mongodb';
let server: smartdb.SmartdbServer;
let client: MongoClient;
@@ -252,6 +252,71 @@ tap.test('smartdb: update - upsert creates new document', async () => {
expect(inserted!.email).toEqual('new@example.com');
});
tap.test('smartdb: update - aggregation pipeline updateOne', async () => {
const collection = db.collection('users');
await collection.insertOne({ name: 'PipelineUser', source: 'alpha', legacy: true, visits: 2 });
const result = await collection.updateOne(
{ name: 'PipelineUser' },
[
{ $set: { sourceCopy: '$source', pipelineStatus: 'updated' } },
{ $unset: ['legacy'] },
]
);
expect(result.matchedCount).toEqual(1);
expect(result.modifiedCount).toEqual(1);
const updated = await collection.findOne({ name: 'PipelineUser' });
expect(updated).toBeTruthy();
expect(updated!.sourceCopy).toEqual('alpha');
expect(updated!.pipelineStatus).toEqual('updated');
expect(updated!.legacy).toBeUndefined();
});
tap.test('smartdb: update - aggregation pipeline upsert', async () => {
const collection = db.collection('users');
const result = await collection.updateOne(
{ name: 'PipelineUpsert' },
[
{ $set: { email: 'pipeline@example.com', status: 'new', mirroredName: '$name' } },
],
{ upsert: true }
);
expect(result.upsertedCount).toEqual(1);
const inserted = await collection.findOne({ name: 'PipelineUpsert' });
expect(inserted).toBeTruthy();
expect(inserted!.email).toEqual('pipeline@example.com');
expect(inserted!.status).toEqual('new');
expect(inserted!.mirroredName).toEqual('PipelineUpsert');
});
tap.test('smartdb: update - cannot modify immutable _id through pipeline', async () => {
const collection = db.collection('users');
const inserted = await collection.insertOne({ name: 'ImmutableIdUser' });
let threw = false;
try {
await collection.updateOne(
{ _id: inserted.insertedId },
[
{ $set: { _id: new ObjectId() } },
]
);
} catch (err: any) {
threw = true;
expect(err.code).toEqual(66);
}
expect(threw).toBeTrue();
const persisted = await collection.findOne({ _id: inserted.insertedId });
expect(persisted).toBeTruthy();
expect(persisted!.name).toEqual('ImmutableIdUser');
});
// ============================================================================
// Cursor Tests
// ============================================================================
@@ -306,6 +371,23 @@ tap.test('smartdb: findOneAndUpdate - returns updated document', async () => {
expect(result!.status).toEqual('active');
});
tap.test('smartdb: findOneAndUpdate - supports aggregation pipeline updates', async () => {
const collection = db.collection('users');
await collection.insertOne({ name: 'PipelineFindAndModify', sourceName: 'Finder' });
const result = await collection.findOneAndUpdate(
{ name: 'PipelineFindAndModify' },
[
{ $set: { displayName: '$sourceName', mode: 'pipeline' } },
],
{ returnDocument: 'after' }
);
expect(result).toBeTruthy();
expect(result!.displayName).toEqual('Finder');
expect(result!.mode).toEqual('pipeline');
});
tap.test('smartdb: findOneAndDelete - returns deleted document', async () => {
const collection = db.collection('users');
+82
View File
@@ -0,0 +1,82 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import * as fs from 'fs';
import * as net from 'net';
import * as path from 'path';
import * as os from 'os';
// ---------------------------------------------------------------------------
// Test: Stale socket cleanup on startup
// Covers: LocalSmartDb.cleanStaleSockets(), isSocketAlive()
// ---------------------------------------------------------------------------
function makeTmpDir(): string {
return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-socket-test-'));
}
function cleanTmpDir(dir: string): void {
if (fs.existsSync(dir)) {
fs.rmSync(dir, { recursive: true, force: true });
}
}
// ============================================================================
// Stale socket cleanup: active sockets are preserved
// ============================================================================
tap.test('stale-sockets: does not remove active sockets', async () => {
const tmpDir = makeTmpDir();
const activeSocketPath = path.join(os.tmpdir(), `smartdb-active-${Date.now()}.sock`);
// Create an active socket (server still listening)
const activeServer = net.createServer();
await new Promise<void>((resolve) => activeServer.listen(activeSocketPath, resolve));
expect(fs.existsSync(activeSocketPath)).toBeTrue();
// Start LocalSmartDb — should NOT remove the active socket
const localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
await localDb.start();
expect(fs.existsSync(activeSocketPath)).toBeTrue();
// Cleanup
await localDb.stop();
await new Promise<void>((resolve) => activeServer.close(() => resolve()));
try { fs.unlinkSync(activeSocketPath); } catch {}
cleanTmpDir(tmpDir);
});
// ============================================================================
// Stale socket cleanup: startup works with no stale sockets
// ============================================================================
tap.test('stale-sockets: startup works cleanly with no stale sockets', async () => {
const tmpDir = makeTmpDir();
const localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
expect(localDb.running).toBeTrue();
expect(info.socketPath).toBeTruthy();
await localDb.stop();
cleanTmpDir(tmpDir);
});
// ============================================================================
// Stale socket cleanup: the socket file for the current instance is cleaned on stop
// ============================================================================
tap.test('stale-sockets: own socket file is removed on stop', async () => {
const tmpDir = makeTmpDir();
const localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
expect(fs.existsSync(info.socketPath)).toBeTrue();
await localDb.stop();
// Socket file should be gone after graceful stop
expect(fs.existsSync(info.socketPath)).toBeFalse();
cleanTmpDir(tmpDir);
});
export default tap.start();
+171
View File
@@ -0,0 +1,171 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import { MongoClient } from 'mongodb';
import * as fs from 'fs';
import * as net from 'net';
import * as os from 'os';
import * as path from 'path';
// Static test-only CA and server certificate. The private key is intentionally
// non-secret test fixture material and must not be reused outside tests.
const CA_PEM = `-----BEGIN CERTIFICATE-----
MIIDFTCCAf2gAwIBAgIUXQlk6FLuWELDKLw9KXi0UIYmU50wDQYJKoZIhvcNAQEL
BQAwGjEYMBYGA1UEAwwPU21hcnREQiBUZXN0IENBMB4XDTI2MDQyOTIxMjYxNFoX
DTM2MDQyNjIxMjYxNFowGjEYMBYGA1UEAwwPU21hcnREQiBUZXN0IENBMIIBIjAN
BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApnRgZvodreKEKkSodwgDe2JKsA3N
GC4c7dmqmOBRQst0OYRoW0kjHnzCVHoGlMTAnjJWXRayPeJCroSA0WhEZIjgHAjW
FuWIr+MUYdCG7czdbDEqZYGsrBDUwv+ydgsDNhLKtbfVfcJckdmFp+TT+Po3sf8o
u5AfOlcjhM22reBLhZJ2FfM2IbqygRbBxNvU3tH5E1kgu2CpYieXQsmqBwkOPM0S
fgkCjlqFeeqV7Jjdq1P6srIItzg6n8/5KGBTxc7VB11WxVAZMIxnOtwpOCpSjbiy
jymBLKvyZxklWGpG9HT6RzUTdp0WpwnO7FlbYqD97jrbwA7PfhbJVUkTeQIDAQAB
o1MwUTAdBgNVHQ4EFgQUaqFWiFvibBYpJjluNW4XlocmqOQwHwYDVR0jBBgwFoAU
aqFWiFvibBYpJjluNW4XlocmqOQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B
AQsFAAOCAQEAdbmRCxeHwfq6Mw0BRXWYM81xrzDMDBwLkIyaVkBJXCEX4Ybj8QHv
tplNqgQae1Hr1qYyNzkivDI/hPnvv/wDsAnT8Wz0/udPpcASTXC03xhRtFXwBSGq
2GtLa53cZHJLoGu1S2ntM6Xo3gropXSx/+LIfefsQvqRO/5WxRrEE10OiFr19rA7
md0nD6zXdwrMRghu6ACuxX6Ext6QJbTL4r1UGbHg2a9UbdBjcb8sfFPLyEjiLpBK
DYvRjddKOwbOpFPoLwmed59Pa6bcqT9NnkRHL+aXUm3M3HfVhNKae7JJShUmCzdx
rbKNJQAUp/mMHnBOSxYS7aqgwBKCiKtP4A==
-----END CERTIFICATE-----
`;
const SERVER_CERT_PEM = `-----BEGIN CERTIFICATE-----
MIIDPTCCAiWgAwIBAgIUMfuX4VHvVJ8Vo6o1U2+f7MHU7dowDQYJKoZIhvcNAQEL
BQAwGjEYMBYGA1UEAwwPU21hcnREQiBUZXN0IENBMB4XDTI2MDQyOTIxMjYxNFoX
DTM2MDQyNjIxMjYxNFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG
9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5eFz1q4juQsEE7cPN5eFrLvRJW/zOMGBmiet
VTQSqVZ/3j3NBWsgxK2xQnNbEXGMlTEE11ih0cCQacc/JnbuvwOt3QX8X6oy4pmb
LMGQJEk2FgdpP6OtGqqYbt/fT7QBY39nt6z/RzxYZI7t5g/nkHnlzmzD+ila6k9b
TzBSfSmtHHKW/c6az/Dh/xe50zDgrzlBA7e5zoleKqRJFRZlDnDoLyx0EOUbbTbQ
vipMynP5bq8l6Fc0N9DAWmXvV4o2x0ZQjfEx5LTvbxNkVWtv8w9w4t4vAZqXwrXd
5OZETMWdy7ezxL0E9Snwc6sSfatlVenD/8P5hWJ/C0vCiw21RwIDAQABo4GAMH4w
GgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMAsGA1UdDwQEAwIFoDATBgNVHSUE
DDAKBggrBgEFBQcDATAdBgNVHQ4EFgQUK2nSXereMZek6gxLweY1AVt9OaswHwYD
VR0jBBgwFoAUaqFWiFvibBYpJjluNW4XlocmqOQwDQYJKoZIhvcNAQELBQADggEB
AAkC6suxamn+OEmJLMqgaGCvEtFbob5pMijYC32vJNPev+bUHMOB4Oo0FyO59sX3
zfLLwk7jagbWJi37T714aSjyJwUHd4XA7McSabP4+1hOOL0NqfiE4yRnxPhlvf3E
9otoStAAJ86067DwIs5id7jYm+qrxn6bL+P1h+P1tYxnPOoD0v1cHVbtUNV2tH2E
eBhdtTbF+NHrj+oXFGI3jiI7qcwpJ9DFUo/w0sC0POY0T5aWl4ptSXVgEc7nkE91
bbPOPyoMjjZ4WhKAW5UzfOafB0bO7+4E0GHcAkBJmS4V8g5qt56nftr+d58R/odY
0hQjpoIwzl9RCEW0h8xkqMQ=
-----END CERTIFICATE-----
`;
const SERVER_KEY_PEM = `-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDl4XPWriO5CwQT
tw83l4Wsu9Elb/M4wYGaJ61VNBKpVn/ePc0FayDErbFCc1sRcYyVMQTXWKHRwJBp
xz8mdu6/A63dBfxfqjLimZsswZAkSTYWB2k/o60aqphu399PtAFjf2e3rP9HPFhk
ju3mD+eQeeXObMP6KVrqT1tPMFJ9Ka0ccpb9zprP8OH/F7nTMOCvOUEDt7nOiV4q
pEkVFmUOcOgvLHQQ5RttNtC+KkzKc/luryXoVzQ30MBaZe9XijbHRlCN8THktO9v
E2RVa2/zD3Di3i8BmpfCtd3k5kRMxZ3Lt7PEvQT1KfBzqxJ9q2VV6cP/w/mFYn8L
S8KLDbVHAgMBAAECggEAAInWJR8US1cow8kOepFayUxJUZ6hAbWGUa+dGtF757Sh
qQoZBFW7ZmqHu0Gc6X4MF79dJQn6mwyp6e2DCtqFdaITEqz0ad7yrpAwilrLtSIM
w+FxkCoYejMDF2Nj2QJxbGO8gPQhRu/vvxCMoxjPcImwjZq4nMnjAiB8dMOGte9V
av/RoWUOFXqeiJHqAXiE372I4BupwYhGrSUQyuVj3SugDRbzvPepTQNRxaBJQPgy
4ZtZ8FjJdPFvlyxv6fmLFULHwPNcS6PLWPuwpj7oEQzG4/Q9ojYj4EPdpoOW7qoH
h1Y6ag1vk5A/m9DjvMhIDzmUJmq8mlldxqbCBpH0+QKBgQD3Eh7F0ZXdLQe/aG5t
ul9hTv68NZa5M0JzJinB6WjXl2s0bUgIvAE9ZmfUYHs8AMvTu4YwJqsrpMuzFOT9
Ct5wBSyFbPzVOt9MYE1Gipxx8RfEMSq7Sp0MjarX3h0Va8ry83NWzrN1CvyP8BQq
CuXo/IislCDgPg0uXhLD/7GsWQKBgQDuMEptldCKtpW6CdLdYih6xh0j1mdGU4Kb
7mTzo3OU3nDnGXGhqvJt/xpksPl7GPRHYQ1dqRzvLKHDtTJqhkedZBnE6A94LkVl
uNJnR8v4PkR9nKKg0uK2ug9VcfSiXUpl2yyYiDc123WjHdwH2U6BV3smb/7KwEvv
FWaP7PO6nwKBgAE2w5PxPa1ChWE5YCGF4uYVf0bpdH4gdFkgfOAJB4zXn504VDxG
wDLPB/+RIcnfryCxMS2XYwvp2V5d4eokXYdrXxagvHVHvsUfTAHmuHIO3zEFlNIq
wa7IG2jIHJh4WRzseUqZ5WPT0/3ZDiBOwWZtpzZB3A99/o6Vw73WycaxAoGAHTeR
OaYB4bIJ5bskwYEz4/N/SZEYM/k0cTop6fTnzaAHi2GEncchW7rKGwXWZHIoLMVL
5WxEH1aDNUV5vLVh/X1058FrfFt4qcSlEoQtEfNZZWscS8vygWWLUfjbgDsfUCU1
cDRtSU71PCACiHfweE8pzQo539b8uYQPg6IWN5MCgYA6z/kvGiBB9xFBUAJPsj+w
XW/UGbn7svZaCob+N5RA9Rs/0idv/bO2nAauZyHG/nn6HXII6U5pmRyVqWKhI22q
K3J0LCP42Zb6/eYzQPbP1jWHCMaL2QJQGsl4NMZixlnNJV0aG/5CButqzSC/cMbG
DX0n+YqqWmCgHWU2csnlAA==
-----END PRIVATE KEY-----
`;
let server: smartdb.SmartdbServer;
let client: MongoClient;
let tmpDir: string;
let caPath: string;
let certPath: string;
let keyPath: string;
let port: number;
function makeTmpDir(): string {
return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-tls-test-'));
}
function cleanTmpDir(dir: string): void {
if (fs.existsSync(dir)) {
fs.rmSync(dir, { recursive: true, force: true });
}
}
async function getFreePort(): Promise<number> {
return await new Promise((resolve, reject) => {
const probe = net.createServer();
probe.once('error', reject);
probe.listen(0, '127.0.0.1', () => {
const address = probe.address();
if (!address || typeof address === 'string') {
probe.close(() => reject(new Error('Failed to allocate TCP port')));
return;
}
probe.close(() => resolve(address.port));
});
});
}
tap.test('tls: should start server with TLS enabled', async () => {
tmpDir = makeTmpDir();
port = await getFreePort();
caPath = path.join(tmpDir, 'ca.pem');
certPath = path.join(tmpDir, 'server.pem');
keyPath = path.join(tmpDir, 'server.key');
fs.writeFileSync(caPath, CA_PEM);
fs.writeFileSync(certPath, SERVER_CERT_PEM);
fs.writeFileSync(keyPath, SERVER_KEY_PEM, { mode: 0o600 });
server = new smartdb.SmartdbServer({
port,
tls: {
enabled: true,
certPath,
keyPath,
},
});
await server.start();
expect(server.running).toBeTrue();
expect(server.getConnectionUri()).toEqual(`mongodb://127.0.0.1:${port}/?tls=true`);
});
tap.test('tls: should connect with official MongoClient and CA validation', async () => {
client = new MongoClient(server.getConnectionUri(), {
directConnection: true,
serverSelectionTimeoutMS: 5000,
tlsCAFile: caPath,
});
await client.connect();
const ping = await client.db('admin').command({ ping: 1 });
expect(ping.ok).toEqual(1);
});
tap.test('tls: should support CRUD over encrypted transport', async () => {
const collection = client.db('tlsdb').collection('notes');
const inserted = await collection.insertOne({ title: 'encrypted transport' });
expect(inserted.acknowledged).toBeTrue();
const doc = await collection.findOne({ _id: inserted.insertedId });
expect(doc).toBeTruthy();
expect(doc!.title).toEqual('encrypted transport');
});
tap.test('tls: cleanup', async () => {
await client.close();
await server.stop();
expect(server.running).toBeFalse();
cleanTmpDir(tmpDir);
});
export default tap.start();
+160
View File
@@ -0,0 +1,160 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import { MongoClient } from 'mongodb';
import * as net from 'net';
let server: smartdb.SmartdbServer;
let client: MongoClient;
let port: number;
async function getFreePort(): Promise<number> {
return await new Promise((resolve, reject) => {
const probe = net.createServer();
probe.once('error', reject);
probe.listen(0, '127.0.0.1', () => {
const address = probe.address();
if (!address || typeof address === 'string') {
probe.close(() => reject(new Error('Failed to allocate TCP port')));
return;
}
probe.close(() => resolve(address.port));
});
});
}
tap.test('transactions: should start server and connect', async () => {
port = await getFreePort();
server = new smartdb.SmartdbServer({ port });
await server.start();
client = new MongoClient(`mongodb://127.0.0.1:${port}`, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
expect(server.running).toBeTrue();
});
tap.test('transactions: should still support explicit sessions', async () => {
const result = await client.db('admin').command({ startSession: 1 });
expect(result.ok).toEqual(1);
expect(result.id).toBeTruthy();
const end = await client.db('admin').command({ endSessions: [result.id] });
expect(end.ok).toEqual(1);
});
tap.test('transactions: should reject transaction-scoped writes without txnNumber before mutation', async () => {
const db = client.db('txntest');
const coll = db.collection('docs');
await coll.insertOne({ key: 'outside', value: 1 });
let threw = false;
try {
await db.command({
insert: 'docs',
documents: [{ key: 'inside-raw', value: 2 }],
startTransaction: true,
autocommit: false,
});
} catch (err: any) {
threw = true;
expect(err.code).toEqual(14);
expect(err.codeName).toEqual('TypeMismatch');
}
expect(threw).toBeTrue();
expect(await coll.countDocuments({ key: 'inside-raw' })).toEqual(0);
expect(await coll.countDocuments({ key: 'outside' })).toEqual(1);
});
tap.test('transactions: official driver transaction should commit buffered writes', async () => {
const coll = client.db('txntest').collection('driverdocs');
await coll.insertOne({ key: 'outside-driver', value: 0 });
const session = client.startSession();
try {
session.startTransaction();
await coll.insertOne({ key: 'inside-driver', value: 1 }, { session });
const inTxn = await coll.findOne({ key: 'inside-driver' }, { session });
expect(inTxn).toBeTruthy();
expect(await coll.countDocuments({ key: 'inside-driver' })).toEqual(0);
await session.commitTransaction();
} finally {
await session.endSession();
}
expect(await coll.countDocuments({ key: 'inside-driver' })).toEqual(1);
expect(await coll.countDocuments({ key: 'outside-driver' })).toEqual(1);
});
tap.test('transactions: abort should discard buffered writes', async () => {
const coll = client.db('txntest').collection('abortdocs');
const session = client.startSession();
try {
session.startTransaction();
await coll.insertOne({ key: 'abort-me', value: 1 }, { session });
expect(await coll.findOne({ key: 'abort-me' }, { session })).toBeTruthy();
await session.abortTransaction();
} finally {
await session.endSession();
}
expect(await coll.findOne({ key: 'abort-me' })).toBeNull();
});
tap.test('transactions: update and delete should commit atomically', async () => {
const coll = client.db('txntest').collection('mutations');
await coll.insertMany([
{ key: 'update-me', value: 1 },
{ key: 'delete-me', value: 2 },
]);
const session = client.startSession();
try {
session.startTransaction();
await coll.updateOne({ key: 'update-me' }, { $set: { value: 10 } }, { session });
await coll.deleteOne({ key: 'delete-me' }, { session });
expect((await coll.findOne({ key: 'update-me' }, { session }))!.value).toEqual(10);
expect(await coll.findOne({ key: 'delete-me' }, { session })).toBeNull();
expect((await coll.findOne({ key: 'update-me' }))!.value).toEqual(1);
expect(await coll.findOne({ key: 'delete-me' })).toBeTruthy();
await session.commitTransaction();
} finally {
await session.endSession();
}
expect((await coll.findOne({ key: 'update-me' }))!.value).toEqual(10);
expect(await coll.findOne({ key: 'delete-me' })).toBeNull();
});
tap.test('transactions: commit and abort without active transaction should be explicit errors', async () => {
for (const command of [{ commitTransaction: 1 }, { abortTransaction: 1 }]) {
let threw = false;
try {
await client.db('admin').command(command);
} catch (err: any) {
threw = true;
expect(err.code).toEqual(251);
expect(err.codeName).toEqual('NoSuchTransaction');
}
expect(threw).toBeTrue();
}
});
tap.test('transactions: serverStatus should expose transaction and oplog metrics', async () => {
const status = await client.db('admin').command({ serverStatus: 1 });
expect(status.ok).toEqual(1);
expect(status.transactions.currentActive).toEqual(0);
expect(status.logicalSessionRecordCache.activeSessionsCount).toBeGreaterThanOrEqual(0);
expect(status.oplog.totalEntries).toBeGreaterThan(0);
});
tap.test('transactions: cleanup', async () => {
await client.close();
await server.stop();
expect(server.running).toBeFalse();
});
export default tap.start();
+180
View File
@@ -0,0 +1,180 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartdb from '../ts/index.js';
import { MongoClient, Db } from 'mongodb';
import * as fs from 'fs';
import * as path from 'path';
import * as os from 'os';
// ---------------------------------------------------------------------------
// Test: Unique index enforcement via wire protocol
// Covers: unique index pre-check, createIndexes persistence, index restoration
// ---------------------------------------------------------------------------
let tmpDir: string;
let localDb: smartdb.LocalSmartDb;
let client: MongoClient;
let db: Db;
function makeTmpDir(): string {
return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-unique-test-'));
}
function cleanTmpDir(dir: string): void {
if (fs.existsSync(dir)) {
fs.rmSync(dir, { recursive: true, force: true });
}
}
// ============================================================================
// Setup
// ============================================================================
tap.test('setup: start local db', async () => {
tmpDir = makeTmpDir();
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('uniquetest');
});
// ============================================================================
// Unique index enforcement on insert
// ============================================================================
tap.test('unique-index: createIndex with unique: true', async () => {
const coll = db.collection('users');
await coll.insertOne({ email: 'alice@example.com', name: 'Alice' });
const indexName = await coll.createIndex({ email: 1 }, { unique: true });
expect(indexName).toBeTruthy();
});
tap.test('unique-index: reject duplicate on insertOne', async () => {
const coll = db.collection('users');
let threw = false;
try {
await coll.insertOne({ email: 'alice@example.com', name: 'Alice2' });
} catch (err: any) {
threw = true;
expect(err.code).toEqual(11000);
}
expect(threw).toBeTrue();
// Verify only 1 document exists
const count = await coll.countDocuments();
expect(count).toEqual(1);
});
tap.test('unique-index: allow insert with different unique value', async () => {
const coll = db.collection('users');
await coll.insertOne({ email: 'bob@example.com', name: 'Bob' });
const count = await coll.countDocuments();
expect(count).toEqual(2);
});
// ============================================================================
// Unique index enforcement on update
// ============================================================================
tap.test('unique-index: reject duplicate on updateOne that changes unique field', async () => {
const coll = db.collection('users');
let threw = false;
try {
await coll.updateOne(
{ email: 'bob@example.com' },
{ $set: { email: 'alice@example.com' } }
);
} catch (err: any) {
threw = true;
expect(err.code).toEqual(11000);
}
expect(threw).toBeTrue();
// Bob's email should be unchanged
const bob = await coll.findOne({ name: 'Bob' });
expect(bob!.email).toEqual('bob@example.com');
});
tap.test('unique-index: allow update that keeps same unique value', async () => {
const coll = db.collection('users');
await coll.updateOne(
{ email: 'bob@example.com' },
{ $set: { name: 'Robert' } }
);
const bob = await coll.findOne({ email: 'bob@example.com' });
expect(bob!.name).toEqual('Robert');
});
// ============================================================================
// Unique index enforcement on upsert
// ============================================================================
tap.test('unique-index: reject duplicate on upsert insert', async () => {
const coll = db.collection('users');
let threw = false;
try {
await coll.updateOne(
{ email: 'new@example.com' },
{ $set: { email: 'alice@example.com', name: 'Imposter' } },
{ upsert: true }
);
} catch (err: any) {
threw = true;
}
expect(threw).toBeTrue();
});
// ============================================================================
// Unique index survives restart (persistence + restoration)
// ============================================================================
tap.test('unique-index: stop and restart', async () => {
await client.close();
await localDb.stop();
localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir });
const info = await localDb.start();
client = new MongoClient(info.connectionUri, {
directConnection: true,
serverSelectionTimeoutMS: 5000,
});
await client.connect();
db = client.db('uniquetest');
});
tap.test('unique-index: enforcement persists after restart', async () => {
const coll = db.collection('users');
// Data should still be there
const count = await coll.countDocuments();
expect(count).toEqual(2);
// Unique constraint should still be enforced without calling createIndex again
let threw = false;
try {
await coll.insertOne({ email: 'alice@example.com', name: 'Alice3' });
} catch (err: any) {
threw = true;
expect(err.code).toEqual(11000);
}
expect(threw).toBeTrue();
// Count unchanged
const countAfter = await coll.countDocuments();
expect(countAfter).toEqual(2);
});
// ============================================================================
// Cleanup
// ============================================================================
tap.test('unique-index: cleanup', async () => {
await client.close();
await localDb.stop();
cleanTmpDir(tmpDir);
});
export default tap.start();
+1 -1
View File
@@ -3,6 +3,6 @@
*/
export const commitinfo = {
name: '@push.rocks/smartdb',
version: '2.5.5',
version: '2.8.0',
description: 'A MongoDB-compatible embedded database server with wire protocol support, backed by a high-performance Rust engine.'
}
+6 -1
View File
@@ -2,7 +2,12 @@
// Export server (the main entry point for using SmartDB)
export { SmartdbServer } from './server/SmartdbServer.js';
export type { ISmartdbServerOptions } from './server/SmartdbServer.js';
export type {
ISmartdbAuthOptions,
ISmartdbAuthUser,
ISmartdbServerOptions,
ISmartdbTlsOptions,
} from './server/SmartdbServer.js';
// Export bridge for advanced usage
export { RustDbBridge } from './rust-db-bridge.js';
+22
View File
@@ -76,6 +76,10 @@ export interface ISmartDbMetrics {
collections: number;
oplogEntries: number;
oplogCurrentSeq: number;
sessions: number;
activeTransactions: number;
authEnabled: boolean;
authUsers: number;
uptimeSeconds: number;
}
@@ -117,6 +121,24 @@ interface ISmartDbRustConfig {
storagePath?: string;
persistPath?: string;
persistIntervalMs?: number;
auth?: {
enabled?: boolean;
users?: Array<{
username: string;
password: string;
database?: string;
roles?: string[];
}>;
usersPath?: string;
scramIterations?: number;
};
tls?: {
enabled?: boolean;
certPath?: string;
keyPath?: string;
caPath?: string;
requireClientCert?: boolean;
};
}
/**
+32 -1
View File
@@ -28,6 +28,32 @@ export interface ISmartdbServerOptions {
persistPath?: string;
/** Persistence interval in ms (default: 60000) */
persistIntervalMs?: number;
/** Authentication configuration. Disabled by default. */
auth?: ISmartdbAuthOptions;
/** TLS transport configuration for TCP listeners. Disabled by default. */
tls?: ISmartdbTlsOptions;
}
export interface ISmartdbAuthOptions {
enabled?: boolean;
users?: ISmartdbAuthUser[];
usersPath?: string;
scramIterations?: number;
}
export interface ISmartdbAuthUser {
username: string;
password: string;
database?: string;
roles?: string[];
}
export interface ISmartdbTlsOptions {
enabled?: boolean;
certPath?: string;
keyPath?: string;
caPath?: string;
requireClientCert?: boolean;
}
/**
@@ -64,6 +90,8 @@ export class SmartdbServer {
storagePath: options.storagePath ?? './data',
persistPath: options.persistPath,
persistIntervalMs: options.persistIntervalMs ?? 60000,
auth: options.auth,
tls: options.tls,
};
this.bridge = new RustDbBridge();
}
@@ -106,6 +134,8 @@ export class SmartdbServer {
storagePath: this.options.storagePath,
persistPath: this.options.persistPath,
persistIntervalMs: this.options.persistIntervalMs,
auth: this.options.auth,
tls: this.options.tls,
});
this.resolvedConnectionUri = result.connectionUri;
@@ -142,7 +172,8 @@ export class SmartdbServer {
const encodedPath = encodeURIComponent(this.options.socketPath);
return `mongodb://${encodedPath}`;
}
return `mongodb://${this.options.host ?? '127.0.0.1'}:${this.options.port ?? 27017}`;
const baseUri = `mongodb://${this.options.host ?? '127.0.0.1'}:${this.options.port ?? 27017}`;
return this.options.tls?.enabled ? `${baseUri}/?tls=true` : baseUri;
}
/**
File diff suppressed because one or more lines are too long