initial
This commit is contained in:
186
rust/src/codec.rs
Normal file
186
rust/src/codec.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
/// Packet types for the VPN binary protocol.
|
||||
#[repr(u8)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PacketType {
|
||||
HandshakeInit = 0x01,
|
||||
HandshakeResp = 0x02,
|
||||
IpPacket = 0x10,
|
||||
Keepalive = 0x20,
|
||||
KeepaliveAck = 0x21,
|
||||
SessionResume = 0x30,
|
||||
SessionResumeOk = 0x31,
|
||||
SessionResumeErr = 0x32,
|
||||
Disconnect = 0x3F,
|
||||
}
|
||||
|
||||
impl PacketType {
|
||||
pub fn from_u8(v: u8) -> Option<Self> {
|
||||
match v {
|
||||
0x01 => Some(Self::HandshakeInit),
|
||||
0x02 => Some(Self::HandshakeResp),
|
||||
0x10 => Some(Self::IpPacket),
|
||||
0x20 => Some(Self::Keepalive),
|
||||
0x21 => Some(Self::KeepaliveAck),
|
||||
0x30 => Some(Self::SessionResume),
|
||||
0x31 => Some(Self::SessionResumeOk),
|
||||
0x32 => Some(Self::SessionResumeErr),
|
||||
0x3F => Some(Self::Disconnect),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A framed packet: [type:1B][length:4B][payload:NB]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Frame {
|
||||
pub packet_type: PacketType,
|
||||
pub payload: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Maximum frame payload size (64 KB).
|
||||
pub const MAX_FRAME_PAYLOAD: usize = 65536;
|
||||
|
||||
/// Header size: 1 byte type + 4 bytes length.
|
||||
pub const HEADER_SIZE: usize = 5;
|
||||
|
||||
/// tokio_util codec for Frame encode/decode over byte streams.
|
||||
pub struct FrameCodec;
|
||||
|
||||
impl Decoder for FrameCodec {
|
||||
type Item = Frame;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, Self::Error> {
|
||||
if src.len() < HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let packet_type_byte = src[0];
|
||||
let length = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
|
||||
|
||||
if length > MAX_FRAME_PAYLOAD {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Frame payload too large: {} bytes", length),
|
||||
));
|
||||
}
|
||||
|
||||
if src.len() < HEADER_SIZE + length {
|
||||
// Reserve capacity for the remaining bytes
|
||||
src.reserve(HEADER_SIZE + length - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let packet_type = PacketType::from_u8(packet_type_byte).ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Unknown packet type: 0x{:02x}", packet_type_byte),
|
||||
)
|
||||
})?;
|
||||
|
||||
src.advance(HEADER_SIZE);
|
||||
let payload = src.split_to(length).to_vec();
|
||||
|
||||
Ok(Some(Frame {
|
||||
packet_type,
|
||||
payload,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Frame> for FrameCodec {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
if item.payload.len() > MAX_FRAME_PAYLOAD {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("Payload too large: {} bytes", item.payload.len()),
|
||||
));
|
||||
}
|
||||
|
||||
dst.reserve(HEADER_SIZE + item.payload.len());
|
||||
dst.put_u8(item.packet_type as u8);
|
||||
dst.put_u32(item.payload.len() as u32);
|
||||
dst.put_slice(&item.payload);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn roundtrip_frame() {
|
||||
let frame = Frame {
|
||||
packet_type: PacketType::IpPacket,
|
||||
payload: vec![1, 2, 3, 4, 5],
|
||||
};
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
let mut codec = FrameCodec;
|
||||
codec.encode(frame.clone(), &mut buf).unwrap();
|
||||
|
||||
let decoded = codec.decode(&mut buf).unwrap().unwrap();
|
||||
assert_eq!(decoded.packet_type, PacketType::IpPacket);
|
||||
assert_eq!(decoded.payload, vec![1, 2, 3, 4, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn partial_frame() {
|
||||
let mut buf = BytesMut::from(&[0x10, 0x00, 0x00][..]);
|
||||
let mut codec = FrameCodec;
|
||||
// Not enough bytes for header
|
||||
assert!(codec.decode(&mut buf).unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_oversized_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u8(0x10); // IpPacket
|
||||
buf.put_u32(MAX_FRAME_PAYLOAD as u32 + 1);
|
||||
let mut codec = FrameCodec;
|
||||
assert!(codec.decode(&mut buf).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_unknown_packet_type() {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u8(0xFF);
|
||||
buf.put_u32(0);
|
||||
let mut codec = FrameCodec;
|
||||
assert!(codec.decode(&mut buf).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_packet_types_roundtrip() {
|
||||
let types = [
|
||||
PacketType::HandshakeInit,
|
||||
PacketType::HandshakeResp,
|
||||
PacketType::IpPacket,
|
||||
PacketType::Keepalive,
|
||||
PacketType::KeepaliveAck,
|
||||
PacketType::SessionResume,
|
||||
PacketType::SessionResumeOk,
|
||||
PacketType::SessionResumeErr,
|
||||
PacketType::Disconnect,
|
||||
];
|
||||
|
||||
for pt in types {
|
||||
let frame = Frame {
|
||||
packet_type: pt,
|
||||
payload: vec![42],
|
||||
};
|
||||
let mut buf = BytesMut::new();
|
||||
let mut codec = FrameCodec;
|
||||
codec.encode(frame, &mut buf).unwrap();
|
||||
let decoded = codec.decode(&mut buf).unwrap().unwrap();
|
||||
assert_eq!(decoded.packet_type, pt);
|
||||
assert_eq!(decoded.payload, vec![42]);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user