use bson::Document;
use std::collections::HashMap;

use crate::error::WireError;
use crate::opcodes::*;

/// Size in bytes of the fixed message header.
const HEADER_LEN: usize = 16;

/// Parsed wire protocol message header (16 bytes).
#[derive(Debug, Clone)]
pub struct MessageHeader {
    /// Total length of the message in bytes, including this header.
    pub message_length: i32,
    pub request_id: i32,
    pub response_to: i32,
    pub op_code: i32,
}

/// A parsed OP_MSG section.
#[derive(Debug, Clone)]
pub enum OpMsgSection {
    /// Section type 0: single BSON document body.
    Body(Document),
    /// Section type 1: named document sequence for bulk operations.
    DocumentSequence {
        identifier: String,
        documents: Vec<Document>,
    },
}

/// A fully parsed command extracted from any message type.
#[derive(Debug, Clone)]
pub struct ParsedCommand {
    /// Command name, normally the first key of `command` (the per-op-code
    /// parsers may normalize it).
    pub command_name: String,
    /// The command document: the OP_MSG body or the OP_QUERY query document.
    pub command: Document,
    /// Name of the database the command targets.
    pub database: String,
    /// `request_id` copied from the message header.
    pub request_id: i32,
    /// `op_code` copied from the message header.
    pub op_code: i32,
    /// Document sequences from OP_MSG section type 1 (e.g., "documents" for insert).
    /// `None` when the message carried no sequences.
    pub document_sequences: Option<HashMap<String, Vec<Document>>>,
}

/// Parse a message header from a byte slice.
///
/// The header is four little-endian `i32`s, in order: message length,
/// request ID, response-to, op code.
///
/// # Panics
/// Panics if `buf` is shorter than 16 bytes; callers must check the length first.
pub fn parse_header(buf: &[u8]) -> MessageHeader {
    let field = |at: usize| i32::from_le_bytes([buf[at], buf[at + 1], buf[at + 2], buf[at + 3]]);
    MessageHeader {
        message_length: field(0),
        request_id: field(4),
        response_to: field(8),
        op_code: field(12),
    }
}

/// Parse a complete message from the front of `buf`.
///
/// Returns `Ok(None)` if `buf` does not yet hold the whole message. Otherwise
/// returns the parsed command and the number of bytes the message occupies
/// (its declared length). Bytes after that length are not examined.
///
/// # Errors
/// - `WireError::MissingBody` if the header declares a length shorter than the
///   16-byte header itself, including negative lengths.
/// - `WireError::UnsupportedOpCode` for op codes other than `OP_MSG` and `OP_QUERY`.
/// - Any error returned by the OP_MSG / OP_QUERY parser.
pub fn parse_message(buf: &[u8]) -> Result<Option<(ParsedCommand, usize)>, WireError> {
    if buf.len() < HEADER_LEN {
        return Ok(None);
    }

    let header = parse_header(buf);
    // A declared length shorter than the header (negative included) cannot
    // describe a real message. Left unchecked, a negative value casts to a huge
    // `usize` and makes every call return `Ok(None)` however many bytes arrive.
    // A small value hands the op-code parsers a slice shorter than the 16-byte
    // header they skip.
    // NOTE(review): no dedicated "invalid length" variant of `WireError` is
    // visible from here, so `MissingBody` is used as the closest fit. Consider
    // adding a dedicated variant.
    if header.message_length < HEADER_LEN as i32 {
        return Err(WireError::MissingBody);
    }
    let msg_len = header.message_length as usize;
    if buf.len() < msg_len {
        return Ok(None);
    }

    let message_buf = &buf[..msg_len];
    let command = match header.op_code {
        OP_MSG => parse_op_msg(message_buf, &header)?,
        OP_QUERY => parse_op_query(message_buf, &header)?,
        other => return Err(WireError::UnsupportedOpCode(other)),
    };
    Ok(Some((command, msg_len)))
}

/// Parse an OP_MSG message.
fn parse_op_msg(buf: &[u8], header: &MessageHeader) -> Result { let mut offset = 16; // skip header let flag_bits = u32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]); offset += 4; let mut body: Option = None; let mut document_sequences: HashMap> = HashMap::new(); // Parse sections until end (or checksum) let message_end = if flag_bits & MSG_FLAG_CHECKSUM_PRESENT != 0 { header.message_length as usize - 4 } else { header.message_length as usize }; while offset < message_end { let section_type = buf[offset]; offset += 1; match section_type { SECTION_BODY => { let doc_size = i32::from_le_bytes([ buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3], ]) as usize; let doc = bson::from_slice(&buf[offset..offset + doc_size])?; body = Some(doc); offset += doc_size; } SECTION_DOCUMENT_SEQUENCE => { let section_size = i32::from_le_bytes([ buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3], ]) as usize; let section_end = offset + section_size; offset += 4; // Read identifier (C string, null-terminated) let id_start = offset; while offset < section_end && buf[offset] != 0 { offset += 1; } let identifier = std::str::from_utf8(&buf[id_start..offset]) .unwrap_or("") .to_string(); offset += 1; // skip null terminator // Read documents let mut documents = Vec::new(); while offset < section_end { let doc_size = i32::from_le_bytes([ buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3], ]) as usize; let doc = bson::from_slice(&buf[offset..offset + doc_size])?; documents.push(doc); offset += doc_size; } document_sequences.insert(identifier, documents); } other => return Err(WireError::UnknownSectionType(other)), } } let command = body.ok_or(WireError::MissingBody)?; let command_name = command .keys() .next() .map(|s| s.to_string()) .unwrap_or_default(); let database = command .get_str("$db") .unwrap_or("admin") .to_string(); Ok(ParsedCommand { command_name, command, database, request_id: header.request_id, op_code: 
header.op_code, document_sequences: if document_sequences.is_empty() { None } else { Some(document_sequences) }, }) } /// Parse an OP_QUERY message (legacy, used for initial driver handshake). fn parse_op_query(buf: &[u8], header: &MessageHeader) -> Result { let mut offset = 16; // skip header let _flags = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]); offset += 4; // Read full collection name (C string) let name_start = offset; while offset < buf.len() && buf[offset] != 0 { offset += 1; } let full_collection_name = std::str::from_utf8(&buf[name_start..offset]) .unwrap_or("") .to_string(); offset += 1; // skip null terminator let _number_to_skip = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]); offset += 4; let _number_to_return = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]); offset += 4; // Read query document let doc_size = i32::from_le_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]) as usize; let query: Document = bson::from_slice(&buf[offset..offset + doc_size])?; // Extract database from collection name (format: "dbname.$cmd") let parts: Vec<&str> = full_collection_name.splitn(2, '.').collect(); let database = parts.first().unwrap_or(&"admin").to_string(); let mut command_name = query .keys() .next() .map(|s| s.to_string()) .unwrap_or_else(|| "find".to_string()); // Map legacy isMaster/ismaster to hello if parts.get(1) == Some(&"$cmd") { if command_name == "isMaster" || command_name == "ismaster" { command_name = "hello".to_string(); } } else { command_name = "find".to_string(); } Ok(ParsedCommand { command_name, command: query, database, request_id: header.request_id, op_code: header.op_code, document_sequences: None, }) } #[cfg(test)] mod tests { use super::*; #[test] fn test_parse_header() { let mut buf = [0u8; 16]; buf[0..4].copy_from_slice(&100i32.to_le_bytes()); // messageLength 
buf[4..8].copy_from_slice(&42i32.to_le_bytes()); // requestID buf[8..12].copy_from_slice(&0i32.to_le_bytes()); // responseTo buf[12..16].copy_from_slice(&OP_MSG.to_le_bytes()); // opCode let header = parse_header(&buf); assert_eq!(header.message_length, 100); assert_eq!(header.request_id, 42); assert_eq!(header.response_to, 0); assert_eq!(header.op_code, OP_MSG); } }