206 lines
7.0 KiB
Rust
206 lines
7.0 KiB
Rust
|
|
use std::time::{Duration, Instant};
|
||
|
|
|
||
|
|
use bson::Bson;
|
||
|
|
use dashmap::DashMap;
|
||
|
|
use tracing::{debug, warn};
|
||
|
|
|
||
|
|
use crate::error::{TransactionError, TransactionResult};
|
||
|
|
|
||
|
|
/// A logical client session tracked by `SessionEngine`.
///
/// `txn_id` and `in_transaction` are updated together by the engine's
/// transaction methods.
#[derive(Clone, Debug)]
pub struct Session {
    /// Session identifier; the engine also uses it as the map key.
    pub id: String,
    /// Moment the session was first created.
    pub created_at: Instant,
    /// Last time the session was used; compared against the engine timeout during cleanup.
    pub last_activity_at: Instant,
    /// Identifier of the associated transaction, if any.
    pub txn_id: Option<String>,
    /// Whether a transaction is currently associated with this session.
    pub in_transaction: bool,
}
|
||
|
|
|
||
|
|
/// Engine that manages logical sessions with timeout and cleanup.
|
||
|
|
pub struct SessionEngine {
|
||
|
|
sessions: DashMap<String, Session>,
|
||
|
|
timeout: Duration,
|
||
|
|
_cleanup_interval: Duration,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl SessionEngine {
|
||
|
|
/// Create a new session engine.
|
||
|
|
///
|
||
|
|
/// * `timeout_ms` - Session timeout in milliseconds (default: 30 minutes = 1_800_000).
|
||
|
|
/// * `cleanup_interval_ms` - How often to run the cleanup task in milliseconds (default: 60_000).
|
||
|
|
pub fn new(timeout_ms: u64, cleanup_interval_ms: u64) -> Self {
|
||
|
|
Self {
|
||
|
|
sessions: DashMap::new(),
|
||
|
|
timeout: Duration::from_millis(timeout_ms),
|
||
|
|
_cleanup_interval: Duration::from_millis(cleanup_interval_ms),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get an existing session or create a new one. Returns the session id.
|
||
|
|
pub fn get_or_create_session(&self, id: &str) -> String {
|
||
|
|
if let Some(mut session) = self.sessions.get_mut(id) {
|
||
|
|
session.last_activity_at = Instant::now();
|
||
|
|
return session.id.clone();
|
||
|
|
}
|
||
|
|
|
||
|
|
let now = Instant::now();
|
||
|
|
let session = Session {
|
||
|
|
id: id.to_string(),
|
||
|
|
created_at: now,
|
||
|
|
last_activity_at: now,
|
||
|
|
txn_id: None,
|
||
|
|
in_transaction: false,
|
||
|
|
};
|
||
|
|
self.sessions.insert(id.to_string(), session);
|
||
|
|
debug!(session_id = %id, "created new session");
|
||
|
|
id.to_string()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Update the last activity timestamp for a session.
|
||
|
|
pub fn touch_session(&self, id: &str) {
|
||
|
|
if let Some(mut session) = self.sessions.get_mut(id) {
|
||
|
|
session.last_activity_at = Instant::now();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// End a session. If a transaction is active, it will be marked for abort.
|
||
|
|
pub fn end_session(&self, id: &str) {
|
||
|
|
if let Some((_, session)) = self.sessions.remove(id) {
|
||
|
|
if session.in_transaction {
|
||
|
|
warn!(
|
||
|
|
session_id = %id,
|
||
|
|
txn_id = ?session.txn_id,
|
||
|
|
"ending session with active transaction, transaction should be aborted"
|
||
|
|
);
|
||
|
|
}
|
||
|
|
debug!(session_id = %id, "session ended");
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Associate a transaction with a session.
|
||
|
|
pub fn start_transaction(&self, session_id: &str, txn_id: &str) -> TransactionResult<()> {
|
||
|
|
let mut session = self
|
||
|
|
.sessions
|
||
|
|
.get_mut(session_id)
|
||
|
|
.ok_or_else(|| TransactionError::NotFound(format!("session {}", session_id)))?;
|
||
|
|
|
||
|
|
if session.in_transaction {
|
||
|
|
return Err(TransactionError::AlreadyActive(session_id.to_string()));
|
||
|
|
}
|
||
|
|
|
||
|
|
session.txn_id = Some(txn_id.to_string());
|
||
|
|
session.in_transaction = true;
|
||
|
|
session.last_activity_at = Instant::now();
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Disassociate the transaction from a session (after commit or abort).
|
||
|
|
pub fn end_transaction(&self, session_id: &str) {
|
||
|
|
if let Some(mut session) = self.sessions.get_mut(session_id) {
|
||
|
|
session.txn_id = None;
|
||
|
|
session.in_transaction = false;
|
||
|
|
session.last_activity_at = Instant::now();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check whether a session is currently in a transaction.
|
||
|
|
pub fn is_in_transaction(&self, session_id: &str) -> bool {
|
||
|
|
self.sessions
|
||
|
|
.get(session_id)
|
||
|
|
.map(|s| s.in_transaction)
|
||
|
|
.unwrap_or(false)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get the active transaction id for a session, if any.
|
||
|
|
pub fn get_transaction_id(&self, session_id: &str) -> Option<String> {
|
||
|
|
self.sessions
|
||
|
|
.get(session_id)
|
||
|
|
.and_then(|s| s.txn_id.clone())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Extract a session id from a BSON `lsid` value.
|
||
|
|
///
|
||
|
|
/// Handles the following formats:
|
||
|
|
/// - `{ "id": UUID }` (standard driver format)
|
||
|
|
/// - `{ "id": "string" }` (string shorthand)
|
||
|
|
/// - `{ "id": Binary(base64) }` (binary UUID)
|
||
|
|
pub fn extract_session_id(lsid: &Bson) -> Option<String> {
|
||
|
|
match lsid {
|
||
|
|
Bson::Document(doc) => {
|
||
|
|
if let Some(id_val) = doc.get("id") {
|
||
|
|
match id_val {
|
||
|
|
Bson::Binary(bin) => {
|
||
|
|
// UUID stored as Binary subtype 4.
|
||
|
|
let bytes = &bin.bytes;
|
||
|
|
if bytes.len() == 16 {
|
||
|
|
let uuid = uuid::Uuid::from_slice(bytes).ok()?;
|
||
|
|
Some(uuid.to_string())
|
||
|
|
} else {
|
||
|
|
// Fall back to base64 representation.
|
||
|
|
Some(base64_encode(bytes))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Bson::String(s) => Some(s.clone()),
|
||
|
|
_ => Some(format!("{}", id_val)),
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
None
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Bson::String(s) => Some(s.clone()),
|
||
|
|
_ => None,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Clean up expired sessions. Returns the number of sessions removed.
|
||
|
|
pub fn cleanup_expired(&self) -> usize {
|
||
|
|
let now = Instant::now();
|
||
|
|
let timeout = self.timeout;
|
||
|
|
let expired: Vec<String> = self
|
||
|
|
.sessions
|
||
|
|
.iter()
|
||
|
|
.filter(|entry| now.duration_since(entry.last_activity_at) > timeout)
|
||
|
|
.map(|entry| entry.id.clone())
|
||
|
|
.collect();
|
||
|
|
|
||
|
|
let count = expired.len();
|
||
|
|
for id in &expired {
|
||
|
|
debug!(session_id = %id, "cleaning up expired session");
|
||
|
|
self.sessions.remove(id);
|
||
|
|
}
|
||
|
|
count
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for SessionEngine {
|
||
|
|
fn default() -> Self {
|
||
|
|
// 30 minutes timeout, 60 seconds cleanup interval.
|
||
|
|
Self::new(1_800_000, 60_000)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Encode `data` as standard base64 (RFC 4648 alphabet, `=` padding).
///
/// Implemented locally so session-id formatting needs no extra dependency.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Map the 6-bit group starting at bit `shift` of a 24-bit word to its character.
    let sextet = |word: u32, shift: u32| -> char { ALPHABET[((word >> shift) & 0x3F) as usize] as char };

    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes big-endian into a 24-bit word; missing bytes count as zero.
        let word = group
            .iter()
            .chain(std::iter::repeat(&0u8))
            .take(3)
            .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));

        encoded.push(sextet(word, 18));
        encoded.push(sextet(word, 12));
        // A group of n input bytes carries n + 1 meaningful characters; pad the rest.
        encoded.push(if group.len() >= 2 { sextet(word, 6) } else { '=' });
        encoded.push(if group.len() == 3 { sextet(word, 0) } else { '=' });
    }
    encoded
}
|