//! Wire protocol message parsing: the 16-byte message header plus command
//! extraction from OP_MSG and legacy OP_QUERY messages.
use bson::Document;
|
|
use std::collections::HashMap;
|
|
|
|
use crate::error::WireError;
|
|
use crate::opcodes::*;
|
|
|
|
/// Parsed wire protocol message header (16 bytes).
///
/// The header is four consecutive little-endian `i32` values, stored here in
/// wire order. The type is plain data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MessageHeader {
    /// Total message size in bytes as declared by the sender; the parser
    /// counts it from the first header byte, so the header is included.
    pub message_length: i32,
    /// Identifier the sender attached to this message.
    pub request_id: i32,
    /// Raw `responseTo` field; presumably the `request_id` of the message
    /// being answered (the parser only copies the value).
    pub response_to: i32,
    /// Operation code selecting the message layout (e.g. `OP_MSG`, `OP_QUERY`).
    pub op_code: i32,
}
|
|
|
|
/// A parsed OP_MSG section.
|
|
#[derive(Debug, Clone)]
|
|
pub enum OpMsgSection {
|
|
/// Section type 0: single BSON document body.
|
|
Body(Document),
|
|
/// Section type 1: named document sequence for bulk operations.
|
|
DocumentSequence {
|
|
identifier: String,
|
|
documents: Vec<Document>,
|
|
},
|
|
}
|
|
|
|
/// A fully parsed command extracted from any message type.
|
|
#[derive(Debug, Clone)]
|
|
pub struct ParsedCommand {
|
|
pub command_name: String,
|
|
pub command: Document,
|
|
pub database: String,
|
|
pub request_id: i32,
|
|
pub op_code: i32,
|
|
/// Document sequences from OP_MSG section type 1 (e.g., "documents" for insert).
|
|
pub document_sequences: Option<HashMap<String, Vec<Document>>>,
|
|
}
|
|
|
|
/// Parse a message header from a byte slice (must be >= 16 bytes).
|
|
pub fn parse_header(buf: &[u8]) -> MessageHeader {
|
|
MessageHeader {
|
|
message_length: i32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]),
|
|
request_id: i32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]),
|
|
response_to: i32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]),
|
|
op_code: i32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]),
|
|
}
|
|
}
|
|
|
|
/// Parse a complete message from a buffer.
|
|
/// Returns the parsed command and bytes consumed, or None if not enough data.
|
|
pub fn parse_message(buf: &[u8]) -> Result<Option<(ParsedCommand, usize)>, WireError> {
|
|
if buf.len() < 16 {
|
|
return Ok(None);
|
|
}
|
|
|
|
let header = parse_header(buf);
|
|
let msg_len = header.message_length as usize;
|
|
|
|
if buf.len() < msg_len {
|
|
return Ok(None);
|
|
}
|
|
|
|
let message_buf = &buf[..msg_len];
|
|
|
|
match header.op_code {
|
|
OP_MSG => parse_op_msg(message_buf, &header).map(|cmd| Some((cmd, msg_len))),
|
|
OP_QUERY => parse_op_query(message_buf, &header).map(|cmd| Some((cmd, msg_len))),
|
|
other => Err(WireError::UnsupportedOpCode(other)),
|
|
}
|
|
}
|
|
|
|
/// Parse an OP_MSG message.
|
|
fn parse_op_msg(buf: &[u8], header: &MessageHeader) -> Result<ParsedCommand, WireError> {
|
|
let mut offset = 16; // skip header
|
|
|
|
let flag_bits = u32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]);
|
|
offset += 4;
|
|
|
|
let mut body: Option<Document> = None;
|
|
let mut document_sequences: HashMap<String, Vec<Document>> = HashMap::new();
|
|
|
|
// Parse sections until end (or checksum)
|
|
let message_end = if flag_bits & MSG_FLAG_CHECKSUM_PRESENT != 0 {
|
|
header.message_length as usize - 4
|
|
} else {
|
|
header.message_length as usize
|
|
};
|
|
|
|
while offset < message_end {
|
|
let section_type = buf[offset];
|
|
offset += 1;
|
|
|
|
match section_type {
|
|
SECTION_BODY => {
|
|
let doc_size = i32::from_le_bytes([
|
|
buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3],
|
|
]) as usize;
|
|
let doc = bson::from_slice(&buf[offset..offset + doc_size])?;
|
|
body = Some(doc);
|
|
offset += doc_size;
|
|
}
|
|
SECTION_DOCUMENT_SEQUENCE => {
|
|
let section_size = i32::from_le_bytes([
|
|
buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3],
|
|
]) as usize;
|
|
let section_end = offset + section_size;
|
|
offset += 4;
|
|
|
|
// Read identifier (C string, null-terminated)
|
|
let id_start = offset;
|
|
while offset < section_end && buf[offset] != 0 {
|
|
offset += 1;
|
|
}
|
|
let identifier = std::str::from_utf8(&buf[id_start..offset])
|
|
.unwrap_or("")
|
|
.to_string();
|
|
offset += 1; // skip null terminator
|
|
|
|
// Read documents
|
|
let mut documents = Vec::new();
|
|
while offset < section_end {
|
|
let doc_size = i32::from_le_bytes([
|
|
buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3],
|
|
]) as usize;
|
|
let doc = bson::from_slice(&buf[offset..offset + doc_size])?;
|
|
documents.push(doc);
|
|
offset += doc_size;
|
|
}
|
|
|
|
document_sequences.insert(identifier, documents);
|
|
}
|
|
other => return Err(WireError::UnknownSectionType(other)),
|
|
}
|
|
}
|
|
|
|
let command = body.ok_or(WireError::MissingBody)?;
|
|
let command_name = command
|
|
.keys()
|
|
.next()
|
|
.map(|s| s.to_string())
|
|
.unwrap_or_default();
|
|
let database = command
|
|
.get_str("$db")
|
|
.unwrap_or("admin")
|
|
.to_string();
|
|
|
|
Ok(ParsedCommand {
|
|
command_name,
|
|
command,
|
|
database,
|
|
request_id: header.request_id,
|
|
op_code: header.op_code,
|
|
document_sequences: if document_sequences.is_empty() {
|
|
None
|
|
} else {
|
|
Some(document_sequences)
|
|
},
|
|
})
|
|
}
|
|
|
|
/// Parse an OP_QUERY message (legacy, used for initial driver handshake).
|
|
fn parse_op_query(buf: &[u8], header: &MessageHeader) -> Result<ParsedCommand, WireError> {
|
|
let mut offset = 16; // skip header
|
|
|
|
let _flags = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]);
|
|
offset += 4;
|
|
|
|
// Read full collection name (C string)
|
|
let name_start = offset;
|
|
while offset < buf.len() && buf[offset] != 0 {
|
|
offset += 1;
|
|
}
|
|
let full_collection_name = std::str::from_utf8(&buf[name_start..offset])
|
|
.unwrap_or("")
|
|
.to_string();
|
|
offset += 1; // skip null terminator
|
|
|
|
let _number_to_skip = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]);
|
|
offset += 4;
|
|
|
|
let _number_to_return = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]);
|
|
offset += 4;
|
|
|
|
// Read query document
|
|
let doc_size = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]) as usize;
|
|
let query: Document = bson::from_slice(&buf[offset..offset + doc_size])?;
|
|
|
|
// Extract database from collection name (format: "dbname.$cmd")
|
|
let parts: Vec<&str> = full_collection_name.splitn(2, '.').collect();
|
|
let database = parts.first().unwrap_or(&"admin").to_string();
|
|
|
|
let mut command_name = query
|
|
.keys()
|
|
.next()
|
|
.map(|s| s.to_string())
|
|
.unwrap_or_else(|| "find".to_string());
|
|
|
|
// Map legacy isMaster/ismaster to hello
|
|
if parts.get(1) == Some(&"$cmd") {
|
|
if command_name == "isMaster" || command_name == "ismaster" {
|
|
command_name = "hello".to_string();
|
|
}
|
|
} else {
|
|
command_name = "find".to_string();
|
|
}
|
|
|
|
Ok(ParsedCommand {
|
|
command_name,
|
|
command: query,
|
|
database,
|
|
request_id: header.request_id,
|
|
op_code: header.op_code,
|
|
document_sequences: None,
|
|
})
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A header built from four little-endian `i32` fields decodes back to
    /// the same values, in wire order.
    #[test]
    fn test_parse_header() {
        // messageLength, requestID, responseTo, opCode
        let fields: [i32; 4] = [100, 42, 0, OP_MSG];

        let mut raw = [0u8; 16];
        for (slot, value) in raw.chunks_mut(4).zip(fields.iter()) {
            slot.copy_from_slice(&value.to_le_bytes());
        }

        let parsed = parse_header(&raw);
        assert_eq!(parsed.message_length, 100);
        assert_eq!(parsed.request_id, 42);
        assert_eq!(parsed.response_to, 0);
        assert_eq!(parsed.op_code, OP_MSG);
    }
}
|