From ed2c02bcf9cc1e49de4d168c3f63b32138f25715 Mon Sep 17 00:00:00 2001 From: Juergen Kunz Date: Wed, 29 Apr 2026 22:01:43 +0000 Subject: [PATCH] feat(enterprise): add auth TLS and recovery hardening --- readme.md | 56 ++ rust/Cargo.lock | 320 +++++++++- rust/Cargo.toml | 14 + rust/crates/rustdb-auth/Cargo.toml | 20 + rust/crates/rustdb-auth/src/lib.rs | 565 ++++++++++++++++++ rust/crates/rustdb-commands/Cargo.toml | 1 + rust/crates/rustdb-commands/src/context.rs | 40 ++ rust/crates/rustdb-commands/src/error.rs | 12 + .../src/handlers/admin_handler.rs | 184 +++++- .../src/handlers/auth_handler.rs | 87 +++ .../src/handlers/hello_handler.rs | 25 +- .../rustdb-commands/src/handlers/mod.rs | 1 + rust/crates/rustdb-commands/src/lib.rs | 2 +- rust/crates/rustdb-commands/src/router.rs | 114 +++- rust/crates/rustdb-config/src/lib.rs | 163 ++++- rust/crates/rustdb-storage/src/file.rs | 33 +- rust/crates/rustdb-storage/src/keydir.rs | 55 +- rust/crates/rustdb/Cargo.toml | 3 + rust/crates/rustdb/src/lib.rs | 118 +++- rust/crates/rustdb/src/management.rs | 3 + test/test.auth.ts | 173 ++++++ test/test.crash-recovery.ts | 91 +++ test/test.tls.ts | 171 ++++++ test/test.transactions.ts | 115 ++++ ts/ts_smartdb/index.ts | 7 +- ts/ts_smartdb/rust-db-bridge.ts | 18 + ts/ts_smartdb/server/SmartdbServer.ts | 33 +- 27 files changed, 2369 insertions(+), 55 deletions(-) create mode 100644 rust/crates/rustdb-auth/Cargo.toml create mode 100644 rust/crates/rustdb-auth/src/lib.rs create mode 100644 rust/crates/rustdb-commands/src/handlers/auth_handler.rs create mode 100644 test/test.auth.ts create mode 100644 test/test.crash-recovery.ts create mode 100644 test/test.tls.ts create mode 100644 test/test.transactions.ts diff --git a/readme.md b/readme.md index cfc6de3..ffd7812 100644 --- a/readme.md +++ b/readme.md @@ -248,6 +248,62 @@ const server = new SmartdbServer({ persistPath: './data/snapshot.json', persistIntervalMs: 30000, // Save every 30s }); + +// TLS transport for TCP mode +const tlsServer = new SmartdbServer({ + port: 27017, + tls: { + enabled: true, + certPath: './certs/server.pem', + keyPath: './certs/server.key', + // caPath: './certs/client-ca.pem', + // requireClientCert: true, // Enables mTLS client certificate checks + }, +}); + +// SCRAM-SHA-256 authentication +const secureServer = new SmartdbServer({ + port: 27017, + auth: { + enabled: true, + usersPath: './data/smartdb-users.json', // Optional: persists derived SCRAM credentials + users: [ + { + username: 'root', + password: 'change-me', + database: 'admin', + roles: ['root'], + }, + ], + }, +}); +``` + +When `auth.enabled` is true, protected commands require successful SCRAM-SHA-256 authentication through the official MongoDB driver: + +```typescript +const client = new MongoClient('mongodb://root:change-me@127.0.0.1:27017/admin?authSource=admin', { + directConnection: true, +}); +await client.connect(); +``` + +TLS is available for TCP listeners. `getConnectionUri()` includes `?tls=true` when TLS is enabled; pass the trusted CA to the MongoDB driver with `tlsCAFile`, `ca`, or `secureContext`. + +Authentication verifies SCRAM credentials, denies unauthenticated commands, and enforces command-level built-in roles for supported operations. + +Supported built-in role names are `root`, `read`, `readWrite`, `dbAdmin`, `userAdmin`, `clusterMonitor`, plus `readAnyDatabase`, `readWriteAnyDatabase`, `dbAdminAnyDatabase`, and `userAdminAnyDatabase`. When `usersPath` is set, SmartDB persists SCRAM credential material atomically and does not store plaintext passwords. + +Basic user management commands are available for authenticated users with `root` or `userAdmin` privileges: + +```typescript +await client.db('admin').command({ + createUser: 'reader', + pwd: 'readpass', + roles: [{ role: 'read', db: 'myapp' }], +}); + +await client.db('admin').command({ usersInfo: 'reader' }); ``` #### Methods & Properties diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 15cc149..50b3795 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -60,7 +60,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -71,7 +71,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -124,6 +124,15 @@ dependencies = [ "wyz", ] +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bson" version = "2.15.0" @@ -139,7 +148,7 @@ dependencies = [ "indexmap", "js-sys", "once_cell", - "rand", + "rand 0.9.2", "serde", "serde_bytes", "serde_json", @@ -221,6 +230,15 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -236,6 +254,16 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "dashmap" version = "6.1.0" @@ -259,6 +287,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -272,7 +311,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -342,6 +381,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.17" @@ -415,6 +464,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "id-arena" version = "2.3.0" @@ -536,7 +594,7 @@ checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -545,7 +603,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -589,6 +647,16 @@ dependencies = [ "windows-link", ] +[[package]] +name = "pbkdf2" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" +dependencies = [ + "digest", + "hmac", +] + [[package]] name = "pin-project-lite" version = "0.2.17" @@ -656,14 +724,35 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +[[package]] +name = "rand" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + [[package]] name = "rand" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha", - "rand_core", + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", ] [[package]] @@ -673,7 +762,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", ] [[package]] @@ -723,6 +821,20 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustdb" version = "0.1.0" @@ -735,6 +847,7 @@ dependencies = [ "dashmap", "futures-util", "mimalloc", + "rustdb-auth", "rustdb-commands", "rustdb-config", "rustdb-index", @@ -742,14 +855,33 @@ dependencies = [ "rustdb-storage", "rustdb-txn", "rustdb-wire", + "rustls-pemfile", "serde", "serde_json", "tokio", + "tokio-rustls", "tokio-util", "tracing", "tracing-subscriber", ] +[[package]] +name = "rustdb-auth" +version = "0.1.0" +dependencies = [ + "base64", + "bson", + "hmac", + "pbkdf2", + "rand 0.8.6", + "rustdb-config", + "serde", + "serde_json", + "sha2", + "subtle", + "thiserror", +] + [[package]] name = "rustdb-commands" version = "0.1.0" @@ -757,6 +889,7 @@ dependencies = [ "async-trait", "bson", "dashmap", + "rustdb-auth", "rustdb-config", "rustdb-index", "rustdb-query", @@ -858,7 +991,50 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", ] [[package]] @@ -933,6 +1109,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -977,7 +1164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -986,6 +1173,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.117" @@ -1013,7 +1206,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -1090,7 +1283,7 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -1104,6 +1297,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -1178,6 +1381,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "typenum" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" + [[package]] name = "unicode-ident" version = "1.0.24" @@ -1190,6 +1399,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "utf8parse" version = "0.2.2" @@ -1329,6 +1544,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.61.2" @@ -1338,6 +1562,70 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "wit-bindgen" version = "0.51.0" @@ -1455,6 +1743,12 @@ dependencies = [ "syn", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zmij" version = "1.0.21" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index c10f73a..c01e30f 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/rustdb-storage", "crates/rustdb-index", "crates/rustdb-txn", + "crates/rustdb-auth", "crates/rustdb-commands", ] @@ -51,6 +52,10 @@ dashmap = "6" # Cancellation / utility tokio-util = { version = "0.7", features = ["codec"] } +# TLS transport +tokio-rustls = { version = "0.26", default-features = false, features = ["ring", "tls12"] } +rustls-pemfile = "2" + # mimalloc allocator mimalloc = "0.1" @@ -60,6 +65,14 @@ crc32fast = "1" # Regex for $regex operator regex = "1" +# Auth crypto +base64 = "0.22" +hmac = "0.12" +pbkdf2 = { version = "0.12", features = ["hmac"] } +rand = "0.8" +sha2 = "0.10" +subtle = "2" + # UUID for sessions uuid = { version = "1", features = ["v4", "serde"] } @@ -76,4 +89,5 @@ rustdb-query = { path = "crates/rustdb-query" } rustdb-storage = { path = "crates/rustdb-storage" } rustdb-index = { path = "crates/rustdb-index" } rustdb-txn = { path = "crates/rustdb-txn" } +rustdb-auth = { path = "crates/rustdb-auth" } rustdb-commands = { path = "crates/rustdb-commands" } diff --git a/rust/crates/rustdb-auth/Cargo.toml b/rust/crates/rustdb-auth/Cargo.toml new file mode 100644 index 0000000..1e2ba6d --- /dev/null +++ b/rust/crates/rustdb-auth/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "rustdb-auth" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "Authentication primitives for RustDb" + +[dependencies] +base64 = { workspace = true } +bson = { workspace = true } +hmac = { workspace = true } +pbkdf2 = { workspace = true } +rand = { workspace = true } +rustdb-config = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sha2 = { workspace = true } +subtle = { workspace = true } +thiserror = { workspace = true } diff --git a/rust/crates/rustdb-auth/src/lib.rs b/rust/crates/rustdb-auth/src/lib.rs new file mode 100644 index 0000000..1de1da9 --- /dev/null +++ b/rust/crates/rustdb-auth/src/lib.rs @@ -0,0 +1,565 @@ +use std::collections::HashMap; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::sync::RwLock; + +use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _}; +use hmac::{Hmac, Mac}; +use pbkdf2::pbkdf2_hmac; +use rand::{rngs::OsRng, RngCore}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use subtle::ConstantTimeEq; + +use rustdb_config::{AuthOptions, AuthUserOptions}; + +type HmacSha256 = Hmac; + +const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; + +#[derive(Debug, thiserror::Error)] +pub enum AuthError { + #[error("authentication is disabled")] + Disabled, + #[error("unsupported authentication mechanism: {0}")] + UnsupportedMechanism(String), + #[error("invalid SCRAM payload: {0}")] + InvalidPayload(String), + #[error("authentication failed")] + AuthenticationFailed, + #[error("unknown SASL conversation")] + UnknownConversation, + #[error("user already exists: {0}")] + UserAlreadyExists(String), + #[error("user not found: {0}")] + UserNotFound(String), + #[error("auth metadata persistence failed: {0}")] + Persistence(String), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AuthAction { + Read, + Write, + DbAdmin, + UserAdmin, + ClusterMonitor, +} + +#[derive(Debug, Clone)] +pub struct AuthenticatedUser { + pub username: String, + pub database: String, + pub roles: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ScramCredential { + salt: Vec, + iterations: u32, + stored_key: Vec, + server_key: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AuthUser { + username: String, + database: String, + roles: Vec, + scram_sha256: ScramCredential, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +struct PersistedAuthState { + users: Vec, +} + +#[derive(Debug, Clone)] +pub struct ScramConversation { + user: AuthenticatedUser, + client_first_bare: String, + server_first: String, + nonce: String, + stored_key: Vec, + server_key: Vec, +} + +#[derive(Debug, Clone)] +pub struct ScramStartResult { + pub payload: Vec, + pub conversation: ScramConversation, +} + +#[derive(Debug, Clone)] +pub struct ScramContinueResult { + pub payload: Vec, + pub user: AuthenticatedUser, +} + +#[derive(Debug)] +pub struct AuthEngine { + enabled: bool, + users: RwLock>, + users_path: Option, + scram_iterations: u32, +} + +impl AuthEngine { + pub fn from_options(options: &AuthOptions) -> Result { + let users_path = options.users_path.as_ref().map(PathBuf::from); + let mut users = if let Some(ref path) = users_path { + load_users(path)? + } else { + HashMap::new() + }; + + let mut changed = false; + for user_options in &options.users { + let key = user_key(&user_options.database, &user_options.username); + if !users.contains_key(&key) { + let user = AuthUser::from_options(user_options, options.scram_iterations); + users.insert(key, user); + changed = true; + } + } + + if changed { + if let Some(ref path) = users_path { + persist_users(path, &users)?; + } + } + + Ok(Self { + enabled: options.enabled, + users: RwLock::new(users), + users_path, + scram_iterations: options.scram_iterations, + }) + } + + pub fn disabled() -> Self { + Self { + enabled: false, + users: RwLock::new(HashMap::new()), + users_path: None, + scram_iterations: 15000, + } + } + + pub fn enabled(&self) -> bool { + self.enabled + } + + pub fn supported_mechanisms(&self, namespace_user: &str) -> Vec { + let Some((database, username)) = namespace_user.split_once('.') else { + return Vec::new(); + }; + let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner()); + if users.contains_key(&user_key(database, username)) { + vec![SCRAM_SHA_256.to_string()] + } else { + Vec::new() + } + } + + pub fn is_authorized( + &self, + authenticated_users: &[AuthenticatedUser], + target_db: &str, + action: AuthAction, + ) -> bool { + authenticated_users + .iter() + .any(|user| user.roles.iter().any(|role| role_allows(role, user, target_db, action))) + } + + pub fn create_user( + &self, + database: &str, + username: &str, + password: &str, + roles: Vec, + ) -> Result<(), AuthError> { + let key = user_key(database, username); + let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); + if users.contains_key(&key) { + return Err(AuthError::UserAlreadyExists(format!("{database}.{username}"))); + } + let options = AuthUserOptions { + username: username.to_string(), + password: password.to_string(), + database: database.to_string(), + roles, + }; + users.insert(key, AuthUser::from_options(&options, self.scram_iterations)); + self.persist_locked(&users) + } + + pub fn drop_user(&self, database: &str, username: &str) -> Result<(), AuthError> { + let key = user_key(database, username); + let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); + if users.remove(&key).is_none() { + return Err(AuthError::UserNotFound(format!("{database}.{username}"))); + } + self.persist_locked(&users) + } + + pub fn update_user( + &self, + database: &str, + username: &str, + password: Option<&str>, + roles: Option>, + ) -> Result<(), AuthError> { + let key = user_key(database, username); + let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); + let user = users + .get_mut(&key) + .ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?; + if let Some(new_roles) = roles { + user.roles = new_roles; + } + if let Some(new_password) = password { + let options = AuthUserOptions { + username: username.to_string(), + password: new_password.to_string(), + database: database.to_string(), + roles: user.roles.clone(), + }; + user.scram_sha256 = AuthUser::from_options(&options, self.scram_iterations).scram_sha256; + } + self.persist_locked(&users) + } + + pub fn grant_roles( + &self, + database: &str, + username: &str, + roles: Vec, + ) -> Result<(), AuthError> { + let key = user_key(database, username); + let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); + let user = users + .get_mut(&key) + .ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?; + for role in roles { + if !user.roles.contains(&role) { + user.roles.push(role); + } + } + self.persist_locked(&users) + } + + pub fn revoke_roles( + &self, + database: &str, + username: &str, + roles: Vec, + ) -> Result<(), AuthError> { + let key = user_key(database, username); + let mut users = self.users.write().unwrap_or_else(|poisoned| poisoned.into_inner()); + let user = users + .get_mut(&key) + .ok_or_else(|| AuthError::UserNotFound(format!("{database}.{username}")))?; + user.roles.retain(|role| !roles.contains(role)); + self.persist_locked(&users) + } + + pub fn users_info(&self, database: &str, username: Option<&str>) -> Vec { + let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner()); + users + .values() + .filter(|user| user.database == database) + .filter(|user| username.map(|name| user.username == name).unwrap_or(true)) + .map(AuthUser::to_authenticated_user) + .collect() + } + + pub fn start_scram_sha256( + &self, + database: &str, + payload: &[u8], + ) -> Result { + if !self.enabled { + return Err(AuthError::Disabled); + } + + let message = std::str::from_utf8(payload) + .map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?; + let client_first_bare = message + .strip_prefix("n,,") + .ok_or_else(|| AuthError::InvalidPayload("expected SCRAM gs2 header 'n,,'".to_string()))?; + let attrs = parse_scram_attrs(client_first_bare); + let raw_username = attrs + .get("n") + .ok_or_else(|| AuthError::InvalidPayload("missing username".to_string()))?; + let username = decode_scram_name(raw_username); + let client_nonce = attrs + .get("r") + .ok_or_else(|| AuthError::InvalidPayload("missing client nonce".to_string()))?; + + let users = self.users.read().unwrap_or_else(|poisoned| poisoned.into_inner()); + let user = users + .get(&user_key(database, &username)) + .ok_or(AuthError::AuthenticationFailed)?; + + let nonce = format!("{}{}", client_nonce, secure_base64(18)); + let server_first = format!( + "r={},s={},i={}", + nonce, + BASE64_STANDARD.encode(&user.scram_sha256.salt), + user.scram_sha256.iterations, + ); + + Ok(ScramStartResult { + payload: server_first.as_bytes().to_vec(), + conversation: ScramConversation { + user: user.to_authenticated_user(), + client_first_bare: client_first_bare.to_string(), + server_first: server_first.clone(), + nonce, + stored_key: user.scram_sha256.stored_key.clone(), + server_key: user.scram_sha256.server_key.clone(), + }, + }) + } + + pub fn continue_scram_sha256( + &self, + conversation: ScramConversation, + payload: &[u8], + ) -> Result { + let message = std::str::from_utf8(payload) + .map_err(|_| AuthError::InvalidPayload("payload is not valid UTF-8".to_string()))?; + let proof_marker = ",p="; + let proof_pos = message + .rfind(proof_marker) + .ok_or_else(|| AuthError::InvalidPayload("missing client proof".to_string()))?; + let client_final_without_proof = &message[..proof_pos]; + let proof_b64 = &message[proof_pos + proof_marker.len()..]; + let attrs = parse_scram_attrs(client_final_without_proof); + let nonce = attrs + .get("r") + .ok_or_else(|| AuthError::InvalidPayload("missing nonce".to_string()))?; + if nonce != &conversation.nonce { + return Err(AuthError::AuthenticationFailed); + } + + let client_proof = BASE64_STANDARD + .decode(proof_b64.as_bytes()) + .map_err(|_| AuthError::InvalidPayload("invalid client proof encoding".to_string()))?; + if client_proof.len() != 32 || conversation.stored_key.len() != 32 { + return Err(AuthError::AuthenticationFailed); + } + + let auth_message = format!( + "{},{},{}", + conversation.client_first_bare, + conversation.server_first, + client_final_without_proof, + ); + let client_signature = hmac_sha256(&conversation.stored_key, auth_message.as_bytes()); + let client_key: Vec = client_proof + .iter() + .zip(client_signature.iter()) + .map(|(proof_byte, signature_byte)| proof_byte ^ signature_byte) + .collect(); + let computed_stored_key = Sha256::digest(&client_key).to_vec(); + + if computed_stored_key.ct_eq(&conversation.stored_key).unwrap_u8() != 1 { + return Err(AuthError::AuthenticationFailed); + } + + let server_signature = hmac_sha256(&conversation.server_key, auth_message.as_bytes()); + let server_final = format!("v={}", BASE64_STANDARD.encode(server_signature)); + + Ok(ScramContinueResult { + payload: server_final.as_bytes().to_vec(), + user: conversation.user, + }) + } + + fn persist_locked(&self, users: &HashMap) -> Result<(), AuthError> { + if let Some(ref path) = self.users_path { + persist_users(path, users)?; + } + Ok(()) + } +} + +impl Default for AuthEngine { + fn default() -> Self { + Self::disabled() + } +} + +impl AuthUser { + fn from_options(options: &AuthUserOptions, iterations: u32) -> Self { + let salt = secure_random(24); + let salted_password = salted_password(options.password.as_bytes(), &salt, iterations); + let client_key = hmac_sha256(&salted_password, b"Client Key"); + let stored_key = Sha256::digest(&client_key).to_vec(); + let server_key = hmac_sha256(&salted_password, b"Server Key"); + + Self { + username: options.username.clone(), + database: options.database.clone(), + roles: options.roles.clone(), + scram_sha256: ScramCredential { + salt, + iterations, + stored_key, + server_key, + }, + } + } + + fn to_authenticated_user(&self) -> AuthenticatedUser { + AuthenticatedUser { + username: self.username.clone(), + database: self.database.clone(), + roles: self.roles.clone(), + } + } +} + +fn role_allows(role: &str, user: &AuthenticatedUser, target_db: &str, action: AuthAction) -> bool { + let (role_db, role_name) = role.split_once('.').unwrap_or(("", role)); + if role_name == "root" { + return true; + } + + let any_database = role_name.ends_with("AnyDatabase"); + let scoped_db = if role_db.is_empty() { &user.database } else { role_db }; + if !any_database && scoped_db != target_db { + return false; + } + + match role_name { + "read" | "readAnyDatabase" => action == AuthAction::Read, + "readWrite" | "readWriteAnyDatabase" => { + matches!(action, AuthAction::Read | AuthAction::Write) + } + "dbAdmin" | "dbAdminAnyDatabase" => action == AuthAction::DbAdmin, + "userAdmin" | "userAdminAnyDatabase" => action == AuthAction::UserAdmin, + "clusterMonitor" => action == AuthAction::ClusterMonitor, + _ => false, + } +} + +fn load_users(path: &Path) -> Result, AuthError> { + if !path.exists() { + return Ok(HashMap::new()); + } + let data = std::fs::read_to_string(path).map_err(|e| AuthError::Persistence(e.to_string()))?; + let persisted: PersistedAuthState = serde_json::from_str(&data) + .map_err(|e| AuthError::Persistence(format!("failed to parse users file: {e}")))?; + Ok(persisted + .users + .into_iter() + .map(|user| (user_key(&user.database, &user.username), user)) + .collect()) +} + +fn persist_users(path: &Path, users: &HashMap) -> Result<(), AuthError> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|e| AuthError::Persistence(e.to_string()))?; + } + + let mut user_list: Vec = users.values().cloned().collect(); + user_list.sort_by(|a, b| a.database.cmp(&b.database).then(a.username.cmp(&b.username))); + let payload = serde_json::to_vec_pretty(&PersistedAuthState { users: user_list }) + .map_err(|e| AuthError::Persistence(e.to_string()))?; + + let tmp_path = path.with_extension("tmp"); + { + let mut file = std::fs::File::create(&tmp_path) + .map_err(|e| AuthError::Persistence(e.to_string()))?; + file.write_all(&payload) + .map_err(|e| AuthError::Persistence(e.to_string()))?; + file.sync_all() + .map_err(|e| AuthError::Persistence(e.to_string()))?; + } + std::fs::rename(&tmp_path, path).map_err(|e| AuthError::Persistence(e.to_string()))?; + if let Some(parent) = path.parent() { + if let Ok(dir) = std::fs::File::open(parent) { + let _ = dir.sync_all(); + } + } + Ok(()) +} + +fn user_key(database: &str, username: &str) -> String { + format!("{}\0{}", database, username) +} + +fn salted_password(password: &[u8], salt: &[u8], iterations: u32) -> Vec { + let mut output = [0u8; 32]; + pbkdf2_hmac::(password, salt, iterations, &mut output); + output.to_vec() +} + +fn hmac_sha256(key: &[u8], message: &[u8]) -> Vec { + let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts keys of any size"); + mac.update(message); + mac.finalize().into_bytes().to_vec() +} + +fn secure_random(len: usize) -> Vec { + let mut bytes = vec![0u8; len]; + OsRng.fill_bytes(&mut bytes); + bytes +} + +fn secure_base64(len: usize) -> String { + BASE64_STANDARD.encode(secure_random(len)) +} + +fn parse_scram_attrs(input: &str) -> HashMap { + let mut result = HashMap::new(); + for part in input.split(',') { + if let Some((key, value)) = part.split_once('=') { + result.insert(key.to_string(), value.to_string()); + } + } + result +} + +fn decode_scram_name(input: &str) -> String { + input.replace("=2C", ",").replace("=3D", "=") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mechanism_lookup_returns_scram_sha256() { + let options = AuthOptions { + enabled: true, + users: vec![AuthUserOptions { + username: "root".to_string(), + password: "secret".to_string(), + database: "admin".to_string(), + roles: vec!["root".to_string()], + }], + users_path: None, + scram_iterations: 4096, + }; + let engine = AuthEngine::from_options(&options).unwrap(); + assert_eq!(engine.supported_mechanisms("admin.root"), vec![SCRAM_SHA_256.to_string()]); + } + + #[test] + fn read_write_role_allows_read_and_write_only_on_own_db() { + let user = AuthenticatedUser { + username: "app".to_string(), + database: "appdb".to_string(), + roles: vec!["readWrite".to_string()], + }; + assert!(role_allows("readWrite", &user, "appdb", AuthAction::Read)); + assert!(role_allows("readWrite", &user, "appdb", AuthAction::Write)); + assert!(!role_allows("readWrite", &user, "other", AuthAction::Read)); + assert!(!role_allows("readWrite", &user, "appdb", AuthAction::DbAdmin)); + } +} diff --git a/rust/crates/rustdb-commands/Cargo.toml b/rust/crates/rustdb-commands/Cargo.toml index 18ed9a9..faee471 100644 --- a/rust/crates/rustdb-commands/Cargo.toml +++ b/rust/crates/rustdb-commands/Cargo.toml @@ -22,3 +22,4 @@ rustdb-query = { workspace = true } rustdb-storage = { workspace = true } rustdb-index = { workspace = true } rustdb-txn = { workspace = true } +rustdb-auth = { workspace = true } diff --git a/rust/crates/rustdb-commands/src/context.rs b/rust/crates/rustdb-commands/src/context.rs index 4b1bb0b..db87f20 100644 --- a/rust/crates/rustdb-commands/src/context.rs +++ b/rust/crates/rustdb-commands/src/context.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use bson::{Bson, Document}; use dashmap::DashMap; +use rustdb_auth::{AuthEngine, AuthenticatedUser, ScramConversation}; use rustdb_index::{IndexEngine, IndexOptions}; use rustdb_storage::{OpLog, StorageAdapter}; use rustdb_txn::{SessionEngine, TransactionEngine}; @@ -22,6 +23,8 @@ pub struct CommandContext { pub start_time: std::time::Instant, /// Operation log for point-in-time replay. pub oplog: Arc, + /// Authentication engine and user store. + pub auth: Arc, } impl CommandContext { @@ -85,6 +88,43 @@ impl CommandContext { } } +/// Per-client connection state. Authentication is socket-scoped in MongoDB. +pub struct ConnectionState { + pub authenticated_users: Vec, + pub sasl_conversations: std::collections::HashMap, + next_conversation_id: i32, +} + +impl ConnectionState { + pub fn new() -> Self { + Self { + authenticated_users: Vec::new(), + sasl_conversations: std::collections::HashMap::new(), + next_conversation_id: 1, + } + } + + pub fn is_authenticated(&self) -> bool { + !self.authenticated_users.is_empty() + } + + pub fn next_conversation_id(&mut self) -> i32 { + let id = self.next_conversation_id; + self.next_conversation_id += 1; + id + } + + pub fn authenticate(&mut self, user: AuthenticatedUser) { + self.authenticated_users.push(user); + } +} + +impl Default for ConnectionState { + fn default() -> Self { + Self::new() + } +} + /// State of an open cursor from a find or aggregate command. pub struct CursorState { /// Documents remaining to be returned. diff --git a/rust/crates/rustdb-commands/src/error.rs b/rust/crates/rustdb-commands/src/error.rs index b652727..f5306de 100644 --- a/rust/crates/rustdb-commands/src/error.rs +++ b/rust/crates/rustdb-commands/src/error.rs @@ -30,6 +30,15 @@ pub enum CommandError { #[error("immutable field: {0}")] ImmutableField(String), + #[error("unauthorized: {0}")] + Unauthorized(String), + + #[error("authentication failed")] + AuthenticationFailed, + + #[error("illegal operation: {0}")] + IllegalOperation(String), + #[error("internal error: {0}")] InternalError(String), } @@ -47,6 +56,9 @@ impl CommandError { CommandError::NamespaceExists(_) => (48, "NamespaceExists"), CommandError::DuplicateKey(_) => (11000, "DuplicateKey"), CommandError::ImmutableField(_) => (66, "ImmutableField"), + CommandError::Unauthorized(_) => (13, "Unauthorized"), + CommandError::AuthenticationFailed => (18, "AuthenticationFailed"), + CommandError::IllegalOperation(_) => (20, "IllegalOperation"), CommandError::InternalError(_) => (1, "InternalError"), }; diff --git a/rust/crates/rustdb-commands/src/handlers/admin_handler.rs b/rust/crates/rustdb-commands/src/handlers/admin_handler.rs index 6e4ac52..75edba6 100644 --- a/rust/crates/rustdb-commands/src/handlers/admin_handler.rs +++ b/rust/crates/rustdb-commands/src/handlers/admin_handler.rs @@ -98,6 +98,18 @@ pub async fn handle( "ok": 1.0, }), + "createUser" => handle_create_user(cmd, db, ctx).await, + + "updateUser" => handle_update_user(cmd, db, ctx).await, + + "dropUser" => handle_drop_user(cmd, db, ctx).await, + + "usersInfo" => handle_users_info(cmd, db, ctx).await, + + "grantRolesToUser" => handle_grant_roles_to_user(cmd, db, ctx).await, + + "revokeRolesFromUser" => handle_revoke_roles_from_user(cmd, db, ctx).await, + "listDatabases" => handle_list_databases(cmd, ctx).await, "listCollections" => handle_list_collections(cmd, db, ctx).await, @@ -144,15 +156,9 @@ pub async fn handle( Ok(doc! { "ok": 1.0 }) } - "commitTransaction" => { - // Stub: acknowledge. - Ok(doc! { "ok": 1.0 }) - } - - "abortTransaction" => { - // Stub: acknowledge. - Ok(doc! { "ok": 1.0 }) - } + "commitTransaction" | "abortTransaction" => Err(CommandError::IllegalOperation( + "Transaction numbers are only allowed on a replica set member or mongos".into(), + )), // Auth stubs - accept silently. "saslStart" => Ok(doc! { @@ -189,6 +195,166 @@ pub async fn handle( } } +async fn handle_create_user( + cmd: &Document, + db: &str, + ctx: &CommandContext, +) -> CommandResult { + let username = cmd + .get_str("createUser") + .map_err(|_| CommandError::InvalidArgument("missing 'createUser' field".into()))?; + let password = cmd + .get_str("pwd") + .map_err(|_| CommandError::InvalidArgument("missing 'pwd' field".into()))?; + let roles = parse_roles(cmd, db, "roles")?; + ctx.auth + .create_user(db, username, password, roles) + .map_err(auth_error_to_command_error)?; + Ok(doc! { "ok": 1.0 }) +} + +async fn handle_update_user( + cmd: &Document, + db: &str, + ctx: &CommandContext, +) -> CommandResult { + let username = cmd + .get_str("updateUser") + .map_err(|_| CommandError::InvalidArgument("missing 'updateUser' field".into()))?; + let password = cmd.get_str("pwd").ok(); + let roles = if cmd.contains_key("roles") { + Some(parse_roles(cmd, db, "roles")?) + } else { + None + }; + ctx.auth + .update_user(db, username, password, roles) + .map_err(auth_error_to_command_error)?; + Ok(doc! { "ok": 1.0 }) +} + +async fn handle_drop_user( + cmd: &Document, + db: &str, + ctx: &CommandContext, +) -> CommandResult { + let username = cmd + .get_str("dropUser") + .map_err(|_| CommandError::InvalidArgument("missing 'dropUser' field".into()))?; + ctx.auth + .drop_user(db, username) + .map_err(auth_error_to_command_error)?; + Ok(doc! { "ok": 1.0 }) +} + +async fn handle_users_info( + cmd: &Document, + db: &str, + ctx: &CommandContext, +) -> CommandResult { + let username = match cmd.get("usersInfo") { + Some(Bson::String(name)) => Some(name.as_str()), + Some(Bson::Document(user_doc)) => user_doc.get_str("user").ok(), + _ => None, + }; + let users = ctx.auth.users_info(db, username); + let user_docs: Vec = users + .into_iter() + .map(|user| { + let roles: Vec = user + .roles + .iter() + .map(|role| Bson::Document(role_to_document(&user.database, role))) + .collect(); + Bson::Document(doc! { + "user": user.username, + "db": user.database, + "roles": roles, + "mechanisms": ["SCRAM-SHA-256"], + }) + }) + .collect(); + Ok(doc! { "users": user_docs, "ok": 1.0 }) +} + +async fn handle_grant_roles_to_user( + cmd: &Document, + db: &str, + ctx: &CommandContext, +) -> CommandResult { + let username = cmd + .get_str("grantRolesToUser") + .map_err(|_| CommandError::InvalidArgument("missing 'grantRolesToUser' field".into()))?; + let roles = parse_roles(cmd, db, "roles")?; + ctx.auth + .grant_roles(db, username, roles) + .map_err(auth_error_to_command_error)?; + Ok(doc! { "ok": 1.0 }) +} + +async fn handle_revoke_roles_from_user( + cmd: &Document, + db: &str, + ctx: &CommandContext, +) -> CommandResult { + let username = cmd + .get_str("revokeRolesFromUser") + .map_err(|_| CommandError::InvalidArgument("missing 'revokeRolesFromUser' field".into()))?; + let roles = parse_roles(cmd, db, "roles")?; + ctx.auth + .revoke_roles(db, username, roles) + .map_err(auth_error_to_command_error)?; + Ok(doc! { "ok": 1.0 }) +} + +fn parse_roles(cmd: &Document, db: &str, key: &str) -> CommandResult> { + let role_values = cmd + .get_array(key) + .map_err(|_| CommandError::InvalidArgument(format!("missing '{key}' array")))?; + let mut roles = Vec::with_capacity(role_values.len()); + for role_value in role_values { + match role_value { + Bson::String(role) => roles.push(role.clone()), + Bson::Document(role_doc) => { + let role = role_doc + .get_str("role") + .map_err(|_| CommandError::InvalidArgument("role document missing 'role'".into()))?; + let role_db = role_doc.get_str("db").unwrap_or(db); + if role_db == db { + roles.push(role.to_string()); + } else { + roles.push(format!("{role_db}.{role}")); + } + } + _ => return Err(CommandError::InvalidArgument("roles must be strings or documents".into())), + } + } + Ok(roles) +} + +fn role_to_document(default_db: &str, role: &str) -> Document { + if let Some((role_db, role_name)) = role.split_once('.') { + doc! { "role": role_name, "db": role_db } + } else { + doc! { "role": role, "db": default_db } + } +} + +fn auth_error_to_command_error(error: rustdb_auth::AuthError) -> CommandError { + match error { + rustdb_auth::AuthError::UserAlreadyExists(message) => CommandError::DuplicateKey(message), + rustdb_auth::AuthError::UserNotFound(message) => CommandError::NamespaceNotFound(message), + rustdb_auth::AuthError::Persistence(message) => CommandError::InternalError(message), + rustdb_auth::AuthError::AuthenticationFailed => CommandError::AuthenticationFailed, + rustdb_auth::AuthError::InvalidPayload(message) => CommandError::InvalidArgument(message), + rustdb_auth::AuthError::UnsupportedMechanism(message) => CommandError::InvalidArgument(message), + rustdb_auth::AuthError::Disabled => CommandError::Unauthorized("authentication is disabled".into()), + rustdb_auth::AuthError::UnknownConversation => { + CommandError::InvalidArgument("unknown SASL conversation".into()) + } + } +} + /// Handle `listDatabases` command. async fn handle_list_databases( cmd: &Document, diff --git a/rust/crates/rustdb-commands/src/handlers/auth_handler.rs b/rust/crates/rustdb-commands/src/handlers/auth_handler.rs new file mode 100644 index 0000000..9694d6a --- /dev/null +++ b/rust/crates/rustdb-commands/src/handlers/auth_handler.rs @@ -0,0 +1,87 @@ +use bson::{doc, Binary, Bson, Document}; + +use crate::context::{CommandContext, ConnectionState}; +use crate::error::{CommandError, CommandResult}; + +pub async fn handle_sasl_start( + cmd: &Document, + db: &str, + ctx: &CommandContext, + connection: &mut ConnectionState, +) -> CommandResult { + let mechanism = cmd + .get_str("mechanism") + .map_err(|_| CommandError::InvalidArgument("missing SASL mechanism".into()))?; + if mechanism != "SCRAM-SHA-256" { + return Err(CommandError::InvalidArgument(format!( + "unsupported SASL mechanism: {mechanism}" + ))); + } + + let payload = payload_bytes(cmd)?; + let result = ctx + .auth + .start_scram_sha256(db, &payload) + .map_err(map_auth_error)?; + let conversation_id = connection.next_conversation_id(); + connection + .sasl_conversations + .insert(conversation_id, result.conversation); + + Ok(doc! { + "conversationId": conversation_id, + "done": false, + "payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload }, + "ok": 1.0, + }) +} + +pub async fn handle_sasl_continue( + cmd: &Document, + ctx: &CommandContext, + connection: &mut ConnectionState, +) -> CommandResult { + let conversation_id = cmd + .get_i32("conversationId") + .map_err(|_| CommandError::InvalidArgument("missing SASL conversationId".into()))?; + let payload = payload_bytes(cmd)?; + let conversation = connection + .sasl_conversations + .remove(&conversation_id) + .ok_or_else(|| CommandError::InvalidArgument("unknown SASL conversation".into()))?; + let result = ctx + .auth + .continue_scram_sha256(conversation, &payload) + .map_err(map_auth_error)?; + connection.authenticate(result.user); + + Ok(doc! { + "conversationId": conversation_id, + "done": true, + "payload": Binary { subtype: bson::spec::BinarySubtype::Generic, bytes: result.payload }, + "ok": 1.0, + }) +} + +fn payload_bytes(cmd: &Document) -> CommandResult> { + match cmd.get("payload") { + Some(Bson::Binary(binary)) => Ok(binary.bytes.clone()), + Some(Bson::String(value)) => Ok(value.as_bytes().to_vec()), + _ => Err(CommandError::InvalidArgument("missing SASL payload".into())), + } +} + +fn map_auth_error(error: rustdb_auth::AuthError) -> CommandError { + match error { + rustdb_auth::AuthError::InvalidPayload(message) => CommandError::InvalidArgument(message), + rustdb_auth::AuthError::UnsupportedMechanism(message) => CommandError::InvalidArgument(message), + rustdb_auth::AuthError::Disabled => CommandError::Unauthorized("authentication is disabled".into()), + rustdb_auth::AuthError::UnknownConversation => { + CommandError::InvalidArgument("unknown SASL conversation".into()) + } + rustdb_auth::AuthError::AuthenticationFailed => CommandError::AuthenticationFailed, + rustdb_auth::AuthError::UserAlreadyExists(message) => CommandError::DuplicateKey(message), + rustdb_auth::AuthError::UserNotFound(message) => CommandError::NamespaceNotFound(message), + rustdb_auth::AuthError::Persistence(message) => CommandError::InternalError(message), + } +} diff --git a/rust/crates/rustdb-commands/src/handlers/hello_handler.rs b/rust/crates/rustdb-commands/src/handlers/hello_handler.rs index 0e0d75f..01f61b0 100644 --- a/rust/crates/rustdb-commands/src/handlers/hello_handler.rs +++ b/rust/crates/rustdb-commands/src/handlers/hello_handler.rs @@ -1,4 +1,4 @@ -use bson::{doc, Document}; +use bson::{doc, Bson, Document}; use crate::context::CommandContext; use crate::error::CommandResult; @@ -7,12 +7,13 @@ use crate::error::CommandResult; /// /// Returns server capabilities matching wire protocol expectations. pub async fn handle( - _cmd: &Document, + cmd: &Document, _db: &str, - _ctx: &CommandContext, + ctx: &CommandContext, ) -> CommandResult { - Ok(doc! { + let mut response = doc! { "ismaster": true, + "helloOk": true, "isWritablePrimary": true, "maxBsonObjectSize": 16_777_216_i32, "maxMessageSizeBytes": 48_000_000_i32, @@ -24,5 +25,19 @@ pub async fn handle( "maxWireVersion": 21_i32, "readOnly": false, "ok": 1.0, - }) + }; + + if ctx.auth.enabled() { + if let Ok(namespace_user) = cmd.get_str("saslSupportedMechs") { + let mechanisms: Vec = ctx + .auth + .supported_mechanisms(namespace_user) + .into_iter() + .map(Bson::String) + .collect(); + response.insert("saslSupportedMechs", Bson::Array(mechanisms)); + } + } + + Ok(response) } diff --git a/rust/crates/rustdb-commands/src/handlers/mod.rs b/rust/crates/rustdb-commands/src/handlers/mod.rs index c0002b9..1f44ce3 100644 --- a/rust/crates/rustdb-commands/src/handlers/mod.rs +++ b/rust/crates/rustdb-commands/src/handlers/mod.rs @@ -1,5 +1,6 @@ pub mod admin_handler; pub mod aggregate_handler; +pub mod auth_handler; pub mod delete_handler; pub mod find_handler; pub mod hello_handler; diff --git a/rust/crates/rustdb-commands/src/lib.rs b/rust/crates/rustdb-commands/src/lib.rs index 05aad94..6e37635 100644 --- a/rust/crates/rustdb-commands/src/lib.rs +++ b/rust/crates/rustdb-commands/src/lib.rs @@ -3,6 +3,6 @@ pub mod error; pub mod handlers; mod router; -pub use context::{CommandContext, CursorState}; +pub use context::{CommandContext, ConnectionState, CursorState}; pub use error::{CommandError, CommandResult}; pub use router::CommandRouter; diff --git a/rust/crates/rustdb-commands/src/router.rs b/rust/crates/rustdb-commands/src/router.rs index 149c375..c0c45ef 100644 --- a/rust/crates/rustdb-commands/src/router.rs +++ b/rust/crates/rustdb-commands/src/router.rs @@ -1,11 +1,12 @@ use std::sync::Arc; -use bson::Document; +use bson::{Bson, Document}; use tracing::{debug, warn}; use rustdb_wire::ParsedCommand; +use rustdb_auth::AuthAction; -use crate::context::CommandContext; +use crate::context::{CommandContext, ConnectionState}; use crate::error::CommandError; use crate::handlers; @@ -21,12 +22,46 @@ impl CommandRouter { } /// Route a parsed command to the appropriate handler, returning a BSON response document. - pub async fn route(&self, cmd: &ParsedCommand) -> Document { + pub async fn route(&self, cmd: &ParsedCommand, connection: &mut ConnectionState) -> Document { let db = &cmd.database; let command_name = cmd.command_name.as_str(); debug!(command = %command_name, database = %db, "routing command"); + if self.ctx.auth.enabled() + && !connection.is_authenticated() + && !allows_unauthenticated(command_name) + { + return CommandError::Unauthorized(format!( + "command '{}' requires authentication", + command_name, + )) + .to_error_doc(); + } + + if self.ctx.auth.enabled() && connection.is_authenticated() { + if let Some(action) = required_action(command_name, &cmd.command) { + if !self + .ctx + .auth + .is_authorized(&connection.authenticated_users, db, action) + { + return CommandError::Unauthorized(format!( + "command '{}' is not authorized for database '{}'", + command_name, db, + )) + .to_error_doc(); + } + } + } + + if transaction_command_unsupported(command_name, &cmd.command) { + return CommandError::IllegalOperation( + "Transaction numbers are only allowed on a replica set member or mongos".into(), + ) + .to_error_doc(); + } + // Extract session id if present, and touch the session. if let Some(lsid) = cmd.command.get("lsid") { if let Some(session_id) = rustdb_txn::SessionEngine::extract_session_id(lsid) { @@ -40,6 +75,14 @@ impl CommandRouter { handlers::hello_handler::handle(&cmd.command, db, &self.ctx).await } + // -- authentication -- + "saslStart" => { + handlers::auth_handler::handle_sasl_start(&cmd.command, db, &self.ctx, connection).await + } + "saslContinue" => { + handlers::auth_handler::handle_sasl_continue(&cmd.command, &self.ctx, connection).await + } + // -- query commands -- "find" => { handlers::find_handler::handle(&cmd.command, db, &self.ctx).await @@ -88,7 +131,9 @@ impl CommandRouter { | "dbStats" | "collStats" | "validate" | "explain" | "startSession" | "endSessions" | "killSessions" | "commitTransaction" | "abortTransaction" - | "saslStart" | "saslContinue" | "authenticate" | "logout" + | "authenticate" | "logout" + | "createUser" | "updateUser" | "dropUser" | "usersInfo" + | "grantRolesToUser" | "revokeRolesFromUser" | "currentOp" | "killOp" | "top" | "profile" | "compact" | "reIndex" | "fsync" | "connPoolSync" => { handlers::admin_handler::handle(&cmd.command, db, &self.ctx, command_name).await @@ -107,3 +152,64 @@ impl CommandRouter { } } } + +fn allows_unauthenticated(command_name: &str) -> bool { + matches!( + command_name, + "hello" | "ismaster" | "isMaster" | "saslStart" | "saslContinue" | "getnonce" + ) +} + +fn required_action(command_name: &str, command: &Document) -> Option { + match command_name { + "hello" | "ismaster" | "isMaster" | "saslStart" | "saslContinue" | "getnonce" => None, + "ping" | "buildInfo" | "buildinfo" | "hostInfo" | "whatsmyuri" | "getLog" + | "getCmdLineOpts" | "getParameter" | "getFreeMonitoringStatus" | "setFreeMonitoring" + | "getShardMap" | "shardingState" | "atlasVersion" | "connectionStatus" + | "startSession" | "endSessions" | "killSessions" | "authenticate" | "logout" => None, + + "find" | "getMore" | "killCursors" | "count" | "distinct" | "listIndexes" + | "listCollections" | "collStats" | "dbStats" | "validate" | "explain" => { + Some(AuthAction::Read) + } + + "aggregate" => Some(if aggregate_writes(command) { + AuthAction::Write + } else { + AuthAction::Read + }), + + "insert" | "update" | "findAndModify" | "delete" | "commitTransaction" + | "abortTransaction" => Some(AuthAction::Write), + + "createIndexes" | "dropIndexes" | "create" | "drop" | "dropDatabase" + | "renameCollection" | "compact" | "reIndex" | "fsync" | "profile" => { + Some(AuthAction::DbAdmin) + } + + "createUser" | "updateUser" | "dropUser" | "usersInfo" | "grantRolesToUser" + | "revokeRolesFromUser" => Some(AuthAction::UserAdmin), + + "serverStatus" | "listDatabases" | "currentOp" | "killOp" | "top" => { + Some(AuthAction::ClusterMonitor) + } + + _ => None, + } +} + +fn aggregate_writes(command: &Document) -> bool { + let Ok(pipeline) = command.get_array("pipeline") else { + return false; + }; + pipeline.last().and_then(|stage| match stage { + Bson::Document(doc) => Some(doc.contains_key("$out") || doc.contains_key("$merge")), + _ => None, + }).unwrap_or(false) +} + +fn transaction_command_unsupported(command_name: &str, command: &Document) -> bool { + matches!(command_name, "commitTransaction" | "abortTransaction") + || matches!(command.get("startTransaction"), Some(Bson::Boolean(true))) + || matches!(command.get("autocommit"), Some(Bson::Boolean(false))) +} diff --git a/rust/crates/rustdb-config/src/lib.rs b/rust/crates/rustdb-config/src/lib.rs index 32bdc94..d5b1e83 100644 --- a/rust/crates/rustdb-config/src/lib.rs +++ b/rust/crates/rustdb-config/src/lib.rs @@ -46,6 +46,99 @@ pub struct RustDbOptions { /// Interval in ms for periodic persistence (default: 60000) #[serde(default = "default_persist_interval")] pub persist_interval_ms: u64, + + /// Authentication configuration. + #[serde(default)] + pub auth: AuthOptions, + + /// TLS transport configuration for TCP listeners. + #[serde(default)] + pub tls: TlsOptions, +} + +/// Authentication configuration for the embedded server. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AuthOptions { + /// Whether clients must authenticate before issuing protected commands. + #[serde(default)] + pub enabled: bool, + + /// Bootstrap users loaded at startup. Passwords are converted into SCRAM credentials in memory. + #[serde(default)] + pub users: Vec, + + /// Optional path for persisted SCRAM user metadata. Stores derived credentials, never plaintext passwords. + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub users_path: Option, + + /// SCRAM iteration count used for bootstrap credentials. + #[serde(default = "default_scram_iterations")] + pub scram_iterations: u32, +} + +impl Default for AuthOptions { + fn default() -> Self { + Self { + enabled: false, + users: Vec::new(), + users_path: None, + scram_iterations: default_scram_iterations(), + } + } +} + +/// TLS transport configuration for the embedded server. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TlsOptions { + /// Whether TCP client connections must use TLS. + #[serde(default)] + pub enabled: bool, + + /// PEM-encoded server certificate chain. + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub cert_path: Option, + + /// PEM-encoded server private key. + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub key_path: Option, + + /// PEM-encoded client CA roots for mTLS verification. + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub ca_path: Option, + + /// Require clients to present a certificate signed by caPath. + #[serde(default)] + pub require_client_cert: bool, +} + +impl Default for TlsOptions { + fn default() -> Self { + Self { + enabled: false, + cert_path: None, + key_path: None, + ca_path: None, + require_client_cert: false, + } + } +} + +/// A bootstrap user for SCRAM authentication. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AuthUserOptions { + pub username: String, + pub password: String, + #[serde(default = "default_auth_database")] + pub database: String, + #[serde(default)] + pub roles: Vec, } fn default_port() -> u16 { @@ -60,6 +153,14 @@ fn default_persist_interval() -> u64 { 60000 } +fn default_scram_iterations() -> u32 { + 15000 +} + +fn default_auth_database() -> String { + "admin".to_string() +} + impl Default for RustDbOptions { fn default() -> Self { Self { @@ -70,6 +171,8 @@ impl Default for RustDbOptions { storage_path: None, persist_path: None, persist_interval_ms: default_persist_interval(), + auth: AuthOptions::default(), + tls: TlsOptions::default(), } } } @@ -92,6 +195,59 @@ impl RustDbOptions { "storagePath is required when storage is 'file'".to_string(), )); } + if self.auth.enabled { + if self.auth.users.is_empty() && self.auth.users_path.is_none() { + return Err(ConfigError::ValidationError( + "auth.users or auth.usersPath must be set when auth.enabled is true".to_string(), + )); + } + if self.auth.scram_iterations < 4096 { + return Err(ConfigError::ValidationError( + "auth.scramIterations must be at least 4096".to_string(), + )); + } + for user in &self.auth.users { + if user.username.is_empty() { + return Err(ConfigError::ValidationError( + "auth.users[].username must not be empty".to_string(), + )); + } + if user.password.is_empty() { + return Err(ConfigError::ValidationError( + format!("auth user '{}' must have a non-empty password", user.username), + )); + } + if user.database.is_empty() { + return Err(ConfigError::ValidationError( + format!("auth user '{}' must have a non-empty database", user.username), + )); + } + } + } + if self.tls.enabled { + if self.socket_path.is_some() { + return Err(ConfigError::ValidationError( + "tls.enabled is only supported for TCP listeners".to_string(), + )); + } + if self.tls.cert_path.as_deref().unwrap_or_default().is_empty() { + return Err(ConfigError::ValidationError( + "tls.certPath is required when tls.enabled is true".to_string(), + )); + } + if self.tls.key_path.as_deref().unwrap_or_default().is_empty() { + return Err(ConfigError::ValidationError( + "tls.keyPath is required when tls.enabled is true".to_string(), + )); + } + if self.tls.require_client_cert + && self.tls.ca_path.as_deref().unwrap_or_default().is_empty() + { + return Err(ConfigError::ValidationError( + "tls.caPath is required when tls.requireClientCert is true".to_string(), + )); + } + } Ok(()) } @@ -101,7 +257,12 @@ impl RustDbOptions { let encoded = urlencoding(socket_path); format!("mongodb://{}", encoded) } else { - format!("mongodb://{}:{}", self.host, self.port) + let base = format!("mongodb://{}:{}", self.host, self.port); + if self.tls.enabled { + format!("{}/?tls=true", base) + } else { + base + } } } } diff --git a/rust/crates/rustdb-storage/src/file.rs b/rust/crates/rustdb-storage/src/file.rs index a52cbc2..94c3a20 100644 --- a/rust/crates/rustdb-storage/src/file.rs +++ b/rust/crates/rustdb-storage/src/file.rs @@ -187,6 +187,27 @@ impl CollectionState { } } +fn truncate_invalid_tail( + data_path: &PathBuf, + stats: &crate::keydir::BuildStats, +) -> StorageResult<()> { + if stats.invalid_tail_bytes == 0 { + return Ok(()); + } + + tracing::warn!( + path = %data_path.display(), + valid_data_end = stats.valid_data_end, + invalid_tail_bytes = stats.invalid_tail_bytes, + "truncating invalid data file tail" + ); + + let file = std::fs::OpenOptions::new().write(true).open(data_path)?; + file.set_len(stats.valid_data_end)?; + file.sync_all()?; + Ok(()) +} + // --------------------------------------------------------------------------- // Collection cache key: "db\0coll" // --------------------------------------------------------------------------- @@ -279,7 +300,8 @@ impl FileStorageAdapter { hint_path, stored_size, actual_size ); } - let (kd, dead, _stats) = KeyDir::build_from_data_file(&data_path)?; + let (kd, dead, stats) = KeyDir::build_from_data_file(&data_path)?; + truncate_invalid_tail(&data_path, &stats)?; (kd, dead, false) } else { // Size matches — validate entry integrity with spot-checks @@ -296,19 +318,22 @@ impl FileStorageAdapter { (kd, dead, true) } else { tracing::warn!("hint file {:?} failed validation, rebuilding from data file", hint_path); - let (kd, dead, _stats) = KeyDir::build_from_data_file(&data_path)?; + let (kd, dead, stats) = KeyDir::build_from_data_file(&data_path)?; + truncate_invalid_tail(&data_path, &stats)?; (kd, dead, false) } } } _ => { debug!("hint file invalid, rebuilding KeyDir from data file"); - let (kd, dead, _stats) = KeyDir::build_from_data_file(&data_path)?; + let (kd, dead, stats) = KeyDir::build_from_data_file(&data_path)?; + truncate_invalid_tail(&data_path, &stats)?; (kd, dead, false) } } } else if data_path.exists() { - let (kd, dead, _stats) = KeyDir::build_from_data_file(&data_path)?; + let (kd, dead, stats) = KeyDir::build_from_data_file(&data_path)?; + truncate_invalid_tail(&data_path, &stats)?; (kd, dead, false) } else { (KeyDir::new(), 0, false) diff --git a/rust/crates/rustdb-storage/src/keydir.rs b/rust/crates/rustdb-storage/src/keydir.rs index 650460c..3febb57 100644 --- a/rust/crates/rustdb-storage/src/keydir.rs +++ b/rust/crates/rustdb-storage/src/keydir.rs @@ -14,7 +14,7 @@ use dashmap::DashMap; use crate::error::{StorageError, StorageResult}; use crate::record::{ - DataRecord, FileHeader, FileType, RecordScanner, FILE_HEADER_SIZE, FORMAT_VERSION, + DataRecord, FileHeader, FileType, FILE_HEADER_SIZE, FORMAT_VERSION, }; // --------------------------------------------------------------------------- @@ -49,6 +49,10 @@ pub struct BuildStats { pub tombstones: u64, /// Number of records superseded by a later write for the same key. pub superseded_records: u64, + /// Byte offset immediately after the last valid record. + pub valid_data_end: u64, + /// Number of invalid tail bytes after the last valid record. + pub invalid_tail_bytes: u64, } // --------------------------------------------------------------------------- @@ -137,6 +141,7 @@ impl KeyDir { /// stale records (superseded by later writes or tombstoned). pub fn build_from_data_file(path: &Path) -> StorageResult<(Self, u64, BuildStats)> { let file = std::fs::File::open(path)?; + let file_len = file.metadata()?.len(); let mut reader = BufReader::new(file); // Read and validate file header @@ -152,13 +157,49 @@ impl KeyDir { let keydir = KeyDir::new(); let mut dead_bytes: u64 = 0; - let mut stats = BuildStats::default(); + let mut stats = BuildStats { + valid_data_end: FILE_HEADER_SIZE as u64, + ..BuildStats::default() + }; - let scanner = RecordScanner::new(reader, FILE_HEADER_SIZE as u64); - for result in scanner { - let (offset, record) = result?; + loop { + let record_offset = stats.valid_data_end; + let (record, disk_size) = match DataRecord::decode_from(&mut reader) { + Ok(Some((record, disk_size))) => (record, disk_size), + Ok(None) => { + if file_len > record_offset { + stats.invalid_tail_bytes = file_len - record_offset; + } + break; + } + Err(StorageError::IoError(e)) if e.kind() == io::ErrorKind::UnexpectedEof => { + stats.invalid_tail_bytes = file_len.saturating_sub(record_offset); + break; + } + Err(StorageError::ChecksumMismatch { expected, actual }) => { + tracing::warn!( + path = %path.display(), + offset = record_offset, + "stopping data file scan at checksum mismatch: expected 0x{expected:08X}, got 0x{actual:08X}" + ); + stats.invalid_tail_bytes = file_len.saturating_sub(record_offset); + break; + } + Err(StorageError::CorruptRecord(message)) => { + tracing::warn!( + path = %path.display(), + offset = record_offset, + "stopping data file scan at corrupt record: {message}" + ); + stats.invalid_tail_bytes = file_len.saturating_sub(record_offset); + break; + } + Err(e) => return Err(e), + }; + + stats.valid_data_end += disk_size as u64; let is_tombstone = record.is_tombstone(); - let disk_size = record.disk_size() as u32; + let disk_size = disk_size as u32; let value_len = record.value.len() as u32; let timestamp = record.timestamp; let key = String::from_utf8(record.key) @@ -175,7 +216,7 @@ impl KeyDir { dead_bytes += disk_size as u64; } else { let entry = KeyDirEntry { - offset, + offset: record_offset, record_len: disk_size, value_len, timestamp, diff --git a/rust/crates/rustdb/Cargo.toml b/rust/crates/rustdb/Cargo.toml index 92cf58a..b64e4ab 100644 --- a/rust/crates/rustdb/Cargo.toml +++ b/rust/crates/rustdb/Cargo.toml @@ -21,9 +21,12 @@ rustdb-query = { workspace = true } rustdb-storage = { workspace = true } rustdb-index = { workspace = true } rustdb-txn = { workspace = true } +rustdb-auth = { workspace = true } rustdb-commands = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } +tokio-rustls = { workspace = true } +rustls-pemfile = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } clap = { workspace = true } diff --git a/rust/crates/rustdb/src/lib.rs b/rust/crates/rustdb/src/lib.rs index 09fd95d..a7b284c 100644 --- a/rust/crates/rustdb/src/lib.rs +++ b/rust/crates/rustdb/src/lib.rs @@ -1,10 +1,12 @@ pub mod management; +use std::fs::File; +use std::io::BufReader; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; -use anyhow::Result; +use anyhow::{Context, Result}; use dashmap::DashMap; use tokio::net::TcpListener; #[cfg(unix)] @@ -12,13 +14,17 @@ use tokio::net::UnixListener; use tokio_util::codec::Framed; use tokio_util::sync::CancellationToken; -use rustdb_config::{RustDbOptions, StorageType}; +use rustdb_config::{RustDbOptions, StorageType, TlsOptions}; use rustdb_wire::{WireCodec, OP_QUERY}; use rustdb_wire::{encode_op_msg_response, encode_op_reply_response}; use rustdb_storage::{StorageAdapter, MemoryStorageAdapter, FileStorageAdapter, OpLog}; use rustdb_index::{IndexEngine, IndexOptions}; use rustdb_txn::{TransactionEngine, SessionEngine}; -use rustdb_commands::{CommandRouter, CommandContext}; +use rustdb_auth::AuthEngine; +use rustdb_commands::{CommandRouter, CommandContext, ConnectionState}; +use tokio_rustls::rustls::{RootCertStore, ServerConfig}; +use tokio_rustls::rustls::server::WebPkiClientVerifier; +use tokio_rustls::TlsAcceptor; /// The main RustDb server. pub struct RustDb { @@ -150,6 +156,8 @@ impl RustDb { } } + let auth = Arc::new(AuthEngine::from_options(&options.auth)?); + let ctx = Arc::new(CommandContext { storage, indexes, @@ -158,6 +166,7 @@ impl RustDb { cursors: Arc::new(DashMap::new()), start_time: std::time::Instant::now(), oplog: Arc::new(OpLog::new()), + auth, }); let router = Arc::new(CommandRouter::new(ctx.clone())); @@ -215,7 +224,12 @@ impl RustDb { } else { let addr = format!("{}:{}", self.options.host, self.options.port); let listener = TcpListener::bind(&addr).await?; - tracing::info!("RustDb listening on {}", addr); + let tls_acceptor = if self.options.tls.enabled { + Some(build_tls_acceptor(&self.options.tls)?) + } else { + None + }; + tracing::info!(tls = self.options.tls.enabled, "RustDb listening on {}", addr); let handle = tokio::spawn(async move { loop { @@ -226,9 +240,21 @@ impl RustDb { Ok((stream, _addr)) => { let _ = stream.set_nodelay(true); let router = router.clone(); - tokio::spawn(async move { - handle_connection(stream, router).await; - }); + match tls_acceptor.clone() { + Some(acceptor) => { + tokio::spawn(async move { + match acceptor.accept(stream).await { + Ok(tls_stream) => handle_connection(tls_stream, router).await, + Err(e) => tracing::debug!("TLS handshake failed: {}", e), + } + }); + } + None => { + tokio::spawn(async move { + handle_connection(stream, router).await; + }); + } + } } Err(e) => { tracing::error!("Accept error: {}", e); @@ -275,14 +301,88 @@ impl RustDb { } } +fn build_tls_acceptor(options: &TlsOptions) -> Result { + let cert_path = options + .cert_path + .as_deref() + .context("tls.certPath is required when tls.enabled is true")?; + let key_path = options + .key_path + .as_deref() + .context("tls.keyPath is required when tls.enabled is true")?; + + let certs = load_certs(cert_path)?; + let key = load_private_key(key_path)?; + + let config = if options.require_client_cert { + let ca_path = options + .ca_path + .as_deref() + .context("tls.caPath is required when tls.requireClientCert is true")?; + let roots = load_root_store(ca_path)?; + let verifier = WebPkiClientVerifier::builder(Arc::new(roots)) + .build() + .context("failed to build TLS client certificate verifier")?; + ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(certs, key) + .context("failed to build TLS server configuration")? + } else { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .context("failed to build TLS server configuration")? + }; + + Ok(TlsAcceptor::from(Arc::new(config))) +} + +fn load_certs(path: &str) -> Result>> { + let file = File::open(path).with_context(|| format!("failed to open TLS certificate file '{}'", path))?; + let mut reader = BufReader::new(file); + let certs = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .with_context(|| format!("failed to parse TLS certificate file '{}'", path))?; + + if certs.is_empty() { + anyhow::bail!("TLS certificate file '{}' did not contain any certificates", path); + } + + Ok(certs) +} + +fn load_private_key(path: &str) -> Result> { + let file = File::open(path).with_context(|| format!("failed to open TLS private key file '{}'", path))?; + let mut reader = BufReader::new(file); + rustls_pemfile::private_key(&mut reader) + .with_context(|| format!("failed to parse TLS private key file '{}'", path))? + .with_context(|| format!("TLS private key file '{}' did not contain a private key", path)) +} + +fn load_root_store(path: &str) -> Result { + let mut roots = RootCertStore::empty(); + for cert in load_certs(path)? { + roots + .add(cert) + .with_context(|| format!("failed to add TLS client CA certificate from '{}'", path))?; + } + + if roots.is_empty() { + anyhow::bail!("TLS client CA file '{}' did not contain usable certificates", path); + } + + Ok(roots) +} + /// Handle a single client connection using the wire protocol codec. async fn handle_connection(stream: S, router: Arc) where - S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { use futures_util::{SinkExt, StreamExt}; let mut framed = Framed::new(stream, WireCodec); + let mut connection = ConnectionState::new(); while let Some(result) = framed.next().await { match result { @@ -290,7 +390,7 @@ where let request_id = parsed_cmd.request_id; let op_code = parsed_cmd.op_code; - let response_doc = router.route(&parsed_cmd).await; + let response_doc = router.route(&parsed_cmd, &mut connection).await; let response_id = next_request_id(); diff --git a/rust/crates/rustdb/src/management.rs b/rust/crates/rustdb/src/management.rs index b5faf99..cbf51db 100644 --- a/rust/crates/rustdb/src/management.rs +++ b/rust/crates/rustdb/src/management.rs @@ -167,6 +167,9 @@ async fn handle_start( Ok(o) => o, Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)), }; + if let Err(e) = options.validate() { + return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)); + } let connection_uri = options.connection_uri(); diff --git a/test/test.auth.ts b/test/test.auth.ts new file mode 100644 index 0000000..20392be --- /dev/null +++ b/test/test.auth.ts @@ -0,0 +1,173 @@ +import { expect, tap } from '@git.zone/tstest/tapbundle'; +import * as smartdb from '../ts/index.js'; +import { MongoClient } from 'mongodb'; +import * as fs from 'fs'; +import * as os from 'os'; +import * as path from 'path'; + +let server: smartdb.SmartdbServer; +let authedClient: MongoClient; +let openClient: MongoClient; +let readerClient: MongoClient; +let tmpDir: string; +let usersPath: string; + +function makeTmpDir(): string { + return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-auth-test-')); +} + +function cleanTmpDir(dir: string): void { + if (fs.existsSync(dir)) { + fs.rmSync(dir, { recursive: true, force: true }); + } +} + +tap.test('auth: should start server with SCRAM-SHA-256 auth enabled', async () => { + tmpDir = makeTmpDir(); + usersPath = path.join(tmpDir, 'users.json'); + server = new smartdb.SmartdbServer({ + port: 27118, + auth: { + enabled: true, + usersPath, + scramIterations: 4096, + users: [ + { + username: 'root', + password: 'secret', + database: 'admin', + roles: ['root'], + }, + ], + }, + }); + await server.start(); + expect(server.running).toBeTrue(); +}); + +tap.test('auth: should reject protected commands before authentication', async () => { + openClient = new MongoClient('mongodb://127.0.0.1:27118', { + directConnection: true, + serverSelectionTimeoutMS: 5000, + }); + await openClient.connect(); + + let threw = false; + try { + await openClient.db('admin').command({ ping: 1 }); + } catch (err: any) { + threw = true; + expect(err.code).toEqual(13); + } + expect(threw).toBeTrue(); +}); + +tap.test('auth: should reject invalid credentials', async () => { + const badClient = new MongoClient('mongodb://root:wrong@127.0.0.1:27118/admin?authSource=admin', { + directConnection: true, + serverSelectionTimeoutMS: 5000, + }); + + let threw = false; + try { + await badClient.connect(); + await badClient.db('admin').command({ ping: 1 }); + } catch { + threw = true; + } finally { + await badClient.close().catch(() => undefined); + } + expect(threw).toBeTrue(); +}); + +tap.test('auth: should authenticate valid credentials', async () => { + authedClient = new MongoClient('mongodb://root:secret@127.0.0.1:27118/admin?authSource=admin', { + directConnection: true, + serverSelectionTimeoutMS: 5000, + }); + await authedClient.connect(); + const result = await authedClient.db('admin').command({ ping: 1 }); + expect(result.ok).toEqual(1); +}); + +tap.test('auth: should allow CRUD after authentication', async () => { + const coll = authedClient.db('securedb').collection('notes'); + const inserted = await coll.insertOne({ title: 'enterprise auth' }); + expect(inserted.acknowledged).toBeTrue(); + + const doc = await coll.findOne({ _id: inserted.insertedId }); + expect(doc).toBeTruthy(); + expect(doc!.title).toEqual('enterprise auth'); +}); + +tap.test('auth: root should create a read-only user', async () => { + const result = await authedClient.db('admin').command({ + createUser: 'reader', + pwd: 'readpass', + roles: [{ role: 'read', db: 'securedb' }], + }); + expect(result.ok).toEqual(1); + + const usersInfo = await authedClient.db('admin').command({ usersInfo: 'reader' }); + expect(usersInfo.ok).toEqual(1); + expect(usersInfo.users.length).toEqual(1); + expect(usersInfo.users[0].user).toEqual('reader'); +}); + +tap.test('auth: read-only user should read but not write', async () => { + readerClient = new MongoClient('mongodb://reader:readpass@127.0.0.1:27118/admin?authSource=admin', { + directConnection: true, + serverSelectionTimeoutMS: 5000, + }); + await readerClient.connect(); + + const doc = await readerClient.db('securedb').collection('notes').findOne({ title: 'enterprise auth' }); + expect(doc).toBeTruthy(); + + let threw = false; + try { + await readerClient.db('securedb').collection('notes').insertOne({ title: 'denied write' }); + } catch (err: any) { + threw = true; + expect(err.code).toEqual(13); + } + expect(threw).toBeTrue(); +}); + +tap.test('auth: persisted users should survive server restart', async () => { + await readerClient.close(); + await authedClient.close(); + await server.stop(); + + // Simulates a crash after writing the temporary auth metadata file but before rename. + fs.writeFileSync(path.join(tmpDir, 'users.tmp'), '{ invalid json'); + + server = new smartdb.SmartdbServer({ + port: 27118, + auth: { + enabled: true, + usersPath, + users: [], + scramIterations: 4096, + }, + }); + await server.start(); + + readerClient = new MongoClient('mongodb://reader:readpass@127.0.0.1:27118/admin?authSource=admin', { + directConnection: true, + serverSelectionTimeoutMS: 5000, + }); + await readerClient.connect(); + const result = await readerClient.db('admin').command({ ping: 1 }); + expect(result.ok).toEqual(1); +}); + +tap.test('auth: cleanup', async () => { + await openClient.close(); + await readerClient.close(); + await server.stop(); + expect(server.running).toBeFalse(); + cleanTmpDir(tmpDir); +}); + +export default tap.start(); diff --git a/test/test.crash-recovery.ts b/test/test.crash-recovery.ts new file mode 100644 index 0000000..c6de58f --- /dev/null +++ b/test/test.crash-recovery.ts @@ -0,0 +1,91 @@ +import { expect, tap } from '@git.zone/tstest/tapbundle'; +import * as smartdb from '../ts/index.js'; +import { MongoClient, Db } from 'mongodb'; +import * as fs from 'fs'; +import * as os from 'os'; +import * as path from 'path'; + +let tmpDir: string; +let localDb: smartdb.LocalSmartDb; +let client: MongoClient; +let db: Db; +let dataPath: string; +let corruptedSize: number; + +function makeTmpDir(): string { + return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-crash-test-')); +} + +function cleanTmpDir(dir: string): void { + if (fs.existsSync(dir)) { + fs.rmSync(dir, { recursive: true, force: true }); + } +} + +async function startAndConnect(): Promise { + localDb = new smartdb.LocalSmartDb({ folderPath: tmpDir }); + const info = await localDb.start(); + client = new MongoClient(info.connectionUri, { + directConnection: true, + serverSelectionTimeoutMS: 5000, + }); + await client.connect(); + db = client.db('crashtest'); +} + +tap.test('crash-recovery: create baseline data', async () => { + tmpDir = makeTmpDir(); + await startAndConnect(); + + await db.collection('docs').insertMany([ + { key: 'a', value: 1 }, + { key: 'b', value: 2 }, + { key: 'c', value: 3 }, + ]); + + await client.close(); + await localDb.stop(); + + dataPath = path.join(tmpDir, 'crashtest', 'docs', 'data.rdb'); + expect(fs.existsSync(dataPath)).toBeTrue(); +}); + +tap.test('crash-recovery: append a torn final record', async () => { + const data = fs.readFileSync(dataPath); + const partialRecord = data.subarray(64, 94); + expect(partialRecord.length).toEqual(30); + + fs.appendFileSync(dataPath, partialRecord); + corruptedSize = fs.statSync(dataPath).size; + expect(corruptedSize).toEqual(data.length + partialRecord.length); +}); + +tap.test('crash-recovery: restart truncates invalid tail and preserves valid records', async () => { + await startAndConnect(); + + const repairedSize = fs.statSync(dataPath).size; + expect(repairedSize < corruptedSize).toBeTrue(); + + const docs = await db.collection('docs').find({}).sort({ key: 1 }).toArray(); + expect(docs.map(doc => doc.key)).toEqual(['a', 'b', 'c']); +}); + +tap.test('crash-recovery: future writes remain durable after tail repair', async () => { + await db.collection('docs').insertOne({ key: 'd', value: 4 }); + expect(await db.collection('docs').countDocuments()).toEqual(4); + + await client.close(); + await localDb.stop(); + + await startAndConnect(); + const docs = await db.collection('docs').find({}).sort({ key: 1 }).toArray(); + expect(docs.map(doc => doc.key)).toEqual(['a', 'b', 'c', 'd']); +}); + +tap.test('crash-recovery: cleanup', async () => { + await client.close(); + await localDb.stop(); + cleanTmpDir(tmpDir); +}); + +export default tap.start(); diff --git a/test/test.tls.ts b/test/test.tls.ts new file mode 100644 index 0000000..6a11dd6 --- /dev/null +++ b/test/test.tls.ts @@ -0,0 +1,171 @@ +import { expect, tap } from '@git.zone/tstest/tapbundle'; +import * as smartdb from '../ts/index.js'; +import { MongoClient } from 'mongodb'; +import * as fs from 'fs'; +import * as net from 'net'; +import * as os from 'os'; +import * as path from 'path'; + +// Static test-only CA and server certificate. The private key is intentionally +// non-secret test fixture material and must not be reused outside tests. +const CA_PEM = `-----BEGIN CERTIFICATE----- +MIIDFTCCAf2gAwIBAgIUXQlk6FLuWELDKLw9KXi0UIYmU50wDQYJKoZIhvcNAQEL +BQAwGjEYMBYGA1UEAwwPU21hcnREQiBUZXN0IENBMB4XDTI2MDQyOTIxMjYxNFoX +DTM2MDQyNjIxMjYxNFowGjEYMBYGA1UEAwwPU21hcnREQiBUZXN0IENBMIIBIjAN +BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApnRgZvodreKEKkSodwgDe2JKsA3N +GC4c7dmqmOBRQst0OYRoW0kjHnzCVHoGlMTAnjJWXRayPeJCroSA0WhEZIjgHAjW +FuWIr+MUYdCG7czdbDEqZYGsrBDUwv+ydgsDNhLKtbfVfcJckdmFp+TT+Po3sf8o +u5AfOlcjhM22reBLhZJ2FfM2IbqygRbBxNvU3tH5E1kgu2CpYieXQsmqBwkOPM0S +fgkCjlqFeeqV7Jjdq1P6srIItzg6n8/5KGBTxc7VB11WxVAZMIxnOtwpOCpSjbiy +jymBLKvyZxklWGpG9HT6RzUTdp0WpwnO7FlbYqD97jrbwA7PfhbJVUkTeQIDAQAB +o1MwUTAdBgNVHQ4EFgQUaqFWiFvibBYpJjluNW4XlocmqOQwHwYDVR0jBBgwFoAU +aqFWiFvibBYpJjluNW4XlocmqOQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B +AQsFAAOCAQEAdbmRCxeHwfq6Mw0BRXWYM81xrzDMDBwLkIyaVkBJXCEX4Ybj8QHv +tplNqgQae1Hr1qYyNzkivDI/hPnvv/wDsAnT8Wz0/udPpcASTXC03xhRtFXwBSGq +2GtLa53cZHJLoGu1S2ntM6Xo3gropXSx/+LIfefsQvqRO/5WxRrEE10OiFr19rA7 +md0nD6zXdwrMRghu6ACuxX6Ext6QJbTL4r1UGbHg2a9UbdBjcb8sfFPLyEjiLpBK +DYvRjddKOwbOpFPoLwmed59Pa6bcqT9NnkRHL+aXUm3M3HfVhNKae7JJShUmCzdx +rbKNJQAUp/mMHnBOSxYS7aqgwBKCiKtP4A== +-----END CERTIFICATE----- +`; + +const SERVER_CERT_PEM = `-----BEGIN CERTIFICATE----- +MIIDPTCCAiWgAwIBAgIUMfuX4VHvVJ8Vo6o1U2+f7MHU7dowDQYJKoZIhvcNAQEL +BQAwGjEYMBYGA1UEAwwPU21hcnREQiBUZXN0IENBMB4XDTI2MDQyOTIxMjYxNFoX +DTM2MDQyNjIxMjYxNFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5eFz1q4juQsEE7cPN5eFrLvRJW/zOMGBmiet +VTQSqVZ/3j3NBWsgxK2xQnNbEXGMlTEE11ih0cCQacc/JnbuvwOt3QX8X6oy4pmb +LMGQJEk2FgdpP6OtGqqYbt/fT7QBY39nt6z/RzxYZI7t5g/nkHnlzmzD+ila6k9b +TzBSfSmtHHKW/c6az/Dh/xe50zDgrzlBA7e5zoleKqRJFRZlDnDoLyx0EOUbbTbQ +vipMynP5bq8l6Fc0N9DAWmXvV4o2x0ZQjfEx5LTvbxNkVWtv8w9w4t4vAZqXwrXd +5OZETMWdy7ezxL0E9Snwc6sSfatlVenD/8P5hWJ/C0vCiw21RwIDAQABo4GAMH4w +GgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMAsGA1UdDwQEAwIFoDATBgNVHSUE +DDAKBggrBgEFBQcDATAdBgNVHQ4EFgQUK2nSXereMZek6gxLweY1AVt9OaswHwYD +VR0jBBgwFoAUaqFWiFvibBYpJjluNW4XlocmqOQwDQYJKoZIhvcNAQELBQADggEB +AAkC6suxamn+OEmJLMqgaGCvEtFbob5pMijYC32vJNPev+bUHMOB4Oo0FyO59sX3 +zfLLwk7jagbWJi37T714aSjyJwUHd4XA7McSabP4+1hOOL0NqfiE4yRnxPhlvf3E +9otoStAAJ86067DwIs5id7jYm+qrxn6bL+P1h+P1tYxnPOoD0v1cHVbtUNV2tH2E +eBhdtTbF+NHrj+oXFGI3jiI7qcwpJ9DFUo/w0sC0POY0T5aWl4ptSXVgEc7nkE91 +bbPOPyoMjjZ4WhKAW5UzfOafB0bO7+4E0GHcAkBJmS4V8g5qt56nftr+d58R/odY +0hQjpoIwzl9RCEW0h8xkqMQ= +-----END CERTIFICATE----- +`; + +const SERVER_KEY_PEM = `-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDl4XPWriO5CwQT +tw83l4Wsu9Elb/M4wYGaJ61VNBKpVn/ePc0FayDErbFCc1sRcYyVMQTXWKHRwJBp +xz8mdu6/A63dBfxfqjLimZsswZAkSTYWB2k/o60aqphu399PtAFjf2e3rP9HPFhk +ju3mD+eQeeXObMP6KVrqT1tPMFJ9Ka0ccpb9zprP8OH/F7nTMOCvOUEDt7nOiV4q +pEkVFmUOcOgvLHQQ5RttNtC+KkzKc/luryXoVzQ30MBaZe9XijbHRlCN8THktO9v +E2RVa2/zD3Di3i8BmpfCtd3k5kRMxZ3Lt7PEvQT1KfBzqxJ9q2VV6cP/w/mFYn8L +S8KLDbVHAgMBAAECggEAAInWJR8US1cow8kOepFayUxJUZ6hAbWGUa+dGtF757Sh +qQoZBFW7ZmqHu0Gc6X4MF79dJQn6mwyp6e2DCtqFdaITEqz0ad7yrpAwilrLtSIM +w+FxkCoYejMDF2Nj2QJxbGO8gPQhRu/vvxCMoxjPcImwjZq4nMnjAiB8dMOGte9V +av/RoWUOFXqeiJHqAXiE372I4BupwYhGrSUQyuVj3SugDRbzvPepTQNRxaBJQPgy +4ZtZ8FjJdPFvlyxv6fmLFULHwPNcS6PLWPuwpj7oEQzG4/Q9ojYj4EPdpoOW7qoH +h1Y6ag1vk5A/m9DjvMhIDzmUJmq8mlldxqbCBpH0+QKBgQD3Eh7F0ZXdLQe/aG5t +ul9hTv68NZa5M0JzJinB6WjXl2s0bUgIvAE9ZmfUYHs8AMvTu4YwJqsrpMuzFOT9 +Ct5wBSyFbPzVOt9MYE1Gipxx8RfEMSq7Sp0MjarX3h0Va8ry83NWzrN1CvyP8BQq +CuXo/IislCDgPg0uXhLD/7GsWQKBgQDuMEptldCKtpW6CdLdYih6xh0j1mdGU4Kb +7mTzo3OU3nDnGXGhqvJt/xpksPl7GPRHYQ1dqRzvLKHDtTJqhkedZBnE6A94LkVl +uNJnR8v4PkR9nKKg0uK2ug9VcfSiXUpl2yyYiDc123WjHdwH2U6BV3smb/7KwEvv +FWaP7PO6nwKBgAE2w5PxPa1ChWE5YCGF4uYVf0bpdH4gdFkgfOAJB4zXn504VDxG +wDLPB/+RIcnfryCxMS2XYwvp2V5d4eokXYdrXxagvHVHvsUfTAHmuHIO3zEFlNIq +wa7IG2jIHJh4WRzseUqZ5WPT0/3ZDiBOwWZtpzZB3A99/o6Vw73WycaxAoGAHTeR +OaYB4bIJ5bskwYEz4/N/SZEYM/k0cTop6fTnzaAHi2GEncchW7rKGwXWZHIoLMVL +5WxEH1aDNUV5vLVh/X1058FrfFt4qcSlEoQtEfNZZWscS8vygWWLUfjbgDsfUCU1 +cDRtSU71PCACiHfweE8pzQo539b8uYQPg6IWN5MCgYA6z/kvGiBB9xFBUAJPsj+w +XW/UGbn7svZaCob+N5RA9Rs/0idv/bO2nAauZyHG/nn6HXII6U5pmRyVqWKhI22q +K3J0LCP42Zb6/eYzQPbP1jWHCMaL2QJQGsl4NMZixlnNJV0aG/5CButqzSC/cMbG +DX0n+YqqWmCgHWU2csnlAA== +-----END PRIVATE KEY----- +`; + +let server: smartdb.SmartdbServer; +let client: MongoClient; +let tmpDir: string; +let caPath: string; +let certPath: string; +let keyPath: string; +let port: number; + +function makeTmpDir(): string { + return fs.mkdtempSync(path.join(os.tmpdir(), 'smartdb-tls-test-')); +} + +function cleanTmpDir(dir: string): void { + if (fs.existsSync(dir)) { + fs.rmSync(dir, { recursive: true, force: true }); + } +} + +async function getFreePort(): Promise { + return await new Promise((resolve, reject) => { + const probe = net.createServer(); + probe.once('error', reject); + probe.listen(0, '127.0.0.1', () => { + const address = probe.address(); + if (!address || typeof address === 'string') { + probe.close(() => reject(new Error('Failed to allocate TCP port'))); + return; + } + probe.close(() => resolve(address.port)); + }); + }); +} + +tap.test('tls: should start server with TLS enabled', async () => { + tmpDir = makeTmpDir(); + port = await getFreePort(); + caPath = path.join(tmpDir, 'ca.pem'); + certPath = path.join(tmpDir, 'server.pem'); + keyPath = path.join(tmpDir, 'server.key'); + + fs.writeFileSync(caPath, CA_PEM); + fs.writeFileSync(certPath, SERVER_CERT_PEM); + fs.writeFileSync(keyPath, SERVER_KEY_PEM, { mode: 0o600 }); + + server = new smartdb.SmartdbServer({ + port, + tls: { + enabled: true, + certPath, + keyPath, + }, + }); + await server.start(); + + expect(server.running).toBeTrue(); + expect(server.getConnectionUri()).toEqual(`mongodb://127.0.0.1:${port}/?tls=true`); +}); + +tap.test('tls: should connect with official MongoClient and CA validation', async () => { + client = new MongoClient(server.getConnectionUri(), { + directConnection: true, + serverSelectionTimeoutMS: 5000, + tlsCAFile: caPath, + }); + await client.connect(); + + const ping = await client.db('admin').command({ ping: 1 }); + expect(ping.ok).toEqual(1); +}); + +tap.test('tls: should support CRUD over encrypted transport', async () => { + const collection = client.db('tlsdb').collection('notes'); + const inserted = await collection.insertOne({ title: 'encrypted transport' }); + expect(inserted.acknowledged).toBeTrue(); + + const doc = await collection.findOne({ _id: inserted.insertedId }); + expect(doc).toBeTruthy(); + expect(doc!.title).toEqual('encrypted transport'); +}); + +tap.test('tls: cleanup', async () => { + await client.close(); + await server.stop(); + expect(server.running).toBeFalse(); + cleanTmpDir(tmpDir); +}); + +export default tap.start(); diff --git a/test/test.transactions.ts b/test/test.transactions.ts new file mode 100644 index 0000000..a1ef897 --- /dev/null +++ b/test/test.transactions.ts @@ -0,0 +1,115 @@ +import { expect, tap } from '@git.zone/tstest/tapbundle'; +import * as smartdb from '../ts/index.js'; +import { MongoClient } from 'mongodb'; +import * as net from 'net'; + +let server: smartdb.SmartdbServer; +let client: MongoClient; +let port: number; + +async function getFreePort(): Promise { + return await new Promise((resolve, reject) => { + const probe = net.createServer(); + probe.once('error', reject); + probe.listen(0, '127.0.0.1', () => { + const address = probe.address(); + if (!address || typeof address === 'string') { + probe.close(() => reject(new Error('Failed to allocate TCP port'))); + return; + } + probe.close(() => resolve(address.port)); + }); + }); +} + +tap.test('transactions: should start server and connect', async () => { + port = await getFreePort(); + server = new smartdb.SmartdbServer({ port }); + await server.start(); + + client = new MongoClient(`mongodb://127.0.0.1:${port}`, { + directConnection: true, + serverSelectionTimeoutMS: 5000, + }); + await client.connect(); + expect(server.running).toBeTrue(); +}); + +tap.test('transactions: should still support explicit sessions', async () => { + const result = await client.db('admin').command({ startSession: 1 }); + expect(result.ok).toEqual(1); + expect(result.id).toBeTruthy(); + + const end = await client.db('admin').command({ endSessions: [result.id] }); + expect(end.ok).toEqual(1); +}); + +tap.test('transactions: should reject raw transaction-scoped writes before mutation', async () => { + const db = client.db('txntest'); + const coll = db.collection('docs'); + await coll.insertOne({ key: 'outside', value: 1 }); + + let threw = false; + try { + await db.command({ + insert: 'docs', + documents: [{ key: 'inside-raw', value: 2 }], + startTransaction: true, + autocommit: false, + }); + } catch (err: any) { + threw = true; + expect(err.code).toEqual(20); + expect(err.codeName).toEqual('IllegalOperation'); + } + expect(threw).toBeTrue(); + + expect(await coll.countDocuments({ key: 'inside-raw' })).toEqual(0); + expect(await coll.countDocuments({ key: 'outside' })).toEqual(1); +}); + +tap.test('transactions: official driver transaction should fail without committing writes', async () => { + const coll = client.db('txntest').collection('driverdocs'); + await coll.insertOne({ key: 'outside-driver', value: 0 }); + const session = client.startSession(); + + let threw = false; + try { + session.startTransaction(); + await coll.insertOne({ key: 'inside-driver', value: 1 }, { session }); + await session.commitTransaction(); + } catch (err: any) { + threw = true; + expect(err.code).toEqual(20); + expect(err.codeName).toEqual('IllegalOperation'); + await session.abortTransaction().catch(() => undefined); + } finally { + await session.endSession(); + } + + expect(threw).toBeTrue(); + expect(await coll.countDocuments({ key: 'inside-driver' })).toEqual(0); + expect(await coll.countDocuments({ key: 'outside-driver' })).toEqual(1); +}); + +tap.test('transactions: commit and abort commands should be explicit unsupported errors', async () => { + for (const command of [{ commitTransaction: 1 }, { abortTransaction: 1 }]) { + let threw = false; + try { + await client.db('admin').command(command); + } catch (err: any) { + threw = true; + expect(err.code).toEqual(20); + expect(err.codeName).toEqual('IllegalOperation'); + } + expect(threw).toBeTrue(); + } +}); + +tap.test('transactions: cleanup', async () => { + await client.close(); + await server.stop(); + expect(server.running).toBeFalse(); +}); + +export default tap.start(); diff --git a/ts/ts_smartdb/index.ts b/ts/ts_smartdb/index.ts index aaa5e18..5a17387 100644 --- a/ts/ts_smartdb/index.ts +++ b/ts/ts_smartdb/index.ts @@ -2,7 +2,12 @@ // Export server (the main entry point for using SmartDB) export { SmartdbServer } from './server/SmartdbServer.js'; -export type { ISmartdbServerOptions } from './server/SmartdbServer.js'; +export type { + ISmartdbAuthOptions, + ISmartdbAuthUser, + ISmartdbServerOptions, + ISmartdbTlsOptions, +} from './server/SmartdbServer.js'; // Export bridge for advanced usage export { RustDbBridge } from './rust-db-bridge.js'; diff --git a/ts/ts_smartdb/rust-db-bridge.ts b/ts/ts_smartdb/rust-db-bridge.ts index 7b0f2b4..f102014 100644 --- a/ts/ts_smartdb/rust-db-bridge.ts +++ b/ts/ts_smartdb/rust-db-bridge.ts @@ -117,6 +117,24 @@ interface ISmartDbRustConfig { storagePath?: string; persistPath?: string; persistIntervalMs?: number; + auth?: { + enabled?: boolean; + users?: Array<{ + username: string; + password: string; + database?: string; + roles?: string[]; + }>; + usersPath?: string; + scramIterations?: number; + }; + tls?: { + enabled?: boolean; + certPath?: string; + keyPath?: string; + caPath?: string; + requireClientCert?: boolean; + }; } /** diff --git a/ts/ts_smartdb/server/SmartdbServer.ts b/ts/ts_smartdb/server/SmartdbServer.ts index 33effea..4a68b34 100644 --- a/ts/ts_smartdb/server/SmartdbServer.ts +++ b/ts/ts_smartdb/server/SmartdbServer.ts @@ -28,6 +28,32 @@ export interface ISmartdbServerOptions { persistPath?: string; /** Persistence interval in ms (default: 60000) */ persistIntervalMs?: number; + /** Authentication configuration. Disabled by default. */ + auth?: ISmartdbAuthOptions; + /** TLS transport configuration for TCP listeners. Disabled by default. */ + tls?: ISmartdbTlsOptions; +} + +export interface ISmartdbAuthOptions { + enabled?: boolean; + users?: ISmartdbAuthUser[]; + usersPath?: string; + scramIterations?: number; +} + +export interface ISmartdbAuthUser { + username: string; + password: string; + database?: string; + roles?: string[]; +} + +export interface ISmartdbTlsOptions { + enabled?: boolean; + certPath?: string; + keyPath?: string; + caPath?: string; + requireClientCert?: boolean; } /** @@ -64,6 +90,8 @@ export class SmartdbServer { storagePath: options.storagePath ?? './data', persistPath: options.persistPath, persistIntervalMs: options.persistIntervalMs ?? 60000, + auth: options.auth, + tls: options.tls, }; this.bridge = new RustDbBridge(); } @@ -106,6 +134,8 @@ export class SmartdbServer { storagePath: this.options.storagePath, persistPath: this.options.persistPath, persistIntervalMs: this.options.persistIntervalMs, + auth: this.options.auth, + tls: this.options.tls, }); this.resolvedConnectionUri = result.connectionUri; @@ -142,7 +172,8 @@ export class SmartdbServer { const encodedPath = encodeURIComponent(this.options.socketPath); return `mongodb://${encodedPath}`; } - return `mongodb://${this.options.host ?? '127.0.0.1'}:${this.options.port ?? 27017}`; + const baseUri = `mongodb://${this.options.host ?? '127.0.0.1'}:${this.options.port ?? 27017}`; + return this.options.tls?.enabled ? `${baseUri}/?tls=true` : baseUri; } /**