use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};

/// Packet types for the VPN binary protocol.
///
/// Each variant's discriminant is the byte written as the first byte of a
/// frame header (see [`Frame`]).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PacketType {
    HandshakeInit = 0x01,
    HandshakeResp = 0x02,
    IpPacket = 0x10,
    Keepalive = 0x20,
    KeepaliveAck = 0x21,
    SessionResume = 0x30,
    SessionResumeOk = 0x31,
    SessionResumeErr = 0x32,
    Disconnect = 0x3F,
}

impl PacketType {
    /// Decodes a wire type byte into a [`PacketType`].
    ///
    /// Returns `None` for any byte that is not one of the discriminants above.
    pub fn from_u8(v: u8) -> Option<Self> {
        match v {
            0x01 => Some(Self::HandshakeInit),
            0x02 => Some(Self::HandshakeResp),
            0x10 => Some(Self::IpPacket),
            0x20 => Some(Self::Keepalive),
            0x21 => Some(Self::KeepaliveAck),
            0x30 => Some(Self::SessionResume),
            0x31 => Some(Self::SessionResumeOk),
            0x32 => Some(Self::SessionResumeErr),
            0x3F => Some(Self::Disconnect),
            _ => None,
        }
    }
}

/// Fallible conversion from a wire byte; the error carries the unrecognized byte.
impl std::convert::TryFrom<u8> for PacketType {
    type Error = u8;

    fn try_from(v: u8) -> Result<Self, Self::Error> {
        Self::from_u8(v).ok_or(v)
    }
}

/// Infallible conversion to the wire byte (the `#[repr(u8)]` discriminant).
impl From<PacketType> for u8 {
    fn from(pt: PacketType) -> Self {
        pt as u8
    }
}

/// A framed packet: [type:1B][length:4B][payload:NB]
///
/// `length` is the payload size in bytes, encoded big-endian on the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
    pub packet_type: PacketType,
    pub payload: Vec<u8>,
}

/// Maximum frame payload size in bytes (64 KiB). Longer payloads are rejected
/// by both the encoder and the decoder.
pub const MAX_FRAME_PAYLOAD: usize = 65536;

/// Header size: 1 byte type + 4 bytes length.
pub const HEADER_SIZE: usize = 5;

/// tokio_util codec for Frame encode/decode over byte streams.
pub struct FrameCodec; impl Decoder for FrameCodec { type Item = Frame; type Error = std::io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { if src.len() < HEADER_SIZE { return Ok(None); } let packet_type_byte = src[0]; let length = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize; if length > MAX_FRAME_PAYLOAD { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!("Frame payload too large: {} bytes", length), )); } if src.len() < HEADER_SIZE + length { // Reserve capacity for the remaining bytes src.reserve(HEADER_SIZE + length - src.len()); return Ok(None); } let packet_type = PacketType::from_u8(packet_type_byte).ok_or_else(|| { std::io::Error::new( std::io::ErrorKind::InvalidData, format!("Unknown packet type: 0x{:02x}", packet_type_byte), ) })?; src.advance(HEADER_SIZE); let payload = src.split_to(length).to_vec(); Ok(Some(Frame { packet_type, payload, })) } } impl Encoder for FrameCodec { type Error = std::io::Error; fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { if item.payload.len() > MAX_FRAME_PAYLOAD { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("Payload too large: {} bytes", item.payload.len()), )); } dst.reserve(HEADER_SIZE + item.payload.len()); dst.put_u8(item.packet_type as u8); dst.put_u32(item.payload.len() as u32); dst.put_slice(&item.payload); Ok(()) } } #[cfg(test)] mod tests { use super::*; #[test] fn roundtrip_frame() { let frame = Frame { packet_type: PacketType::IpPacket, payload: vec![1, 2, 3, 4, 5], }; let mut buf = BytesMut::new(); let mut codec = FrameCodec; codec.encode(frame.clone(), &mut buf).unwrap(); let decoded = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(decoded.packet_type, PacketType::IpPacket); assert_eq!(decoded.payload, vec![1, 2, 3, 4, 5]); } #[test] fn partial_frame() { let mut buf = BytesMut::from(&[0x10, 0x00, 0x00][..]); let mut codec = FrameCodec; // Not enough bytes for header 
assert!(codec.decode(&mut buf).unwrap().is_none()); } #[test] fn reject_oversized_frame() { let mut buf = BytesMut::new(); buf.put_u8(0x10); // IpPacket buf.put_u32(MAX_FRAME_PAYLOAD as u32 + 1); let mut codec = FrameCodec; assert!(codec.decode(&mut buf).is_err()); } #[test] fn reject_unknown_packet_type() { let mut buf = BytesMut::new(); buf.put_u8(0xFF); buf.put_u32(0); let mut codec = FrameCodec; assert!(codec.decode(&mut buf).is_err()); } #[test] fn all_packet_types_roundtrip() { let types = [ PacketType::HandshakeInit, PacketType::HandshakeResp, PacketType::IpPacket, PacketType::Keepalive, PacketType::KeepaliveAck, PacketType::SessionResume, PacketType::SessionResumeOk, PacketType::SessionResumeErr, PacketType::Disconnect, ]; for pt in types { let frame = Frame { packet_type: pt, payload: vec![42], }; let mut buf = BytesMut::new(); let mut codec = FrameCodec; codec.encode(frame, &mut buf).unwrap(); let decoded = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(decoded.packet_type, pt); assert_eq!(decoded.payload, vec![42]); } } }