Files
smartvpn/rust/tests/wg_e2e.rs

321 lines
13 KiB
Rust

//! End-to-end WireGuard protocol tests over real UDP sockets.
//!
//! Entirely userspace — no root, no TUN devices.
//! Two boringtun `Tunn` instances exchange real WireGuard packets
//! over loopback UDP, validating handshake, encryption, and data flow.
use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration;
use boringtun::noise::{Tunn, TunnResult};
use boringtun::x25519::{PublicKey, StaticSecret};
use tokio::net::UdpSocket;
use tokio::time;
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use smartvpn_daemon::wireguard::generate_wg_keypair;
// ============================================================================
// Helpers
// ============================================================================
/// Decode a base64-encoded WireGuard key pair into boringtun key types.
///
/// Panics if either string is not valid base64 or does not decode to exactly
/// 32 bytes.
fn parse_key_pair(pub_b64: &str, priv_b64: &str) -> (PublicKey, StaticSecret) {
    let decode_32 = |encoded: &str| -> [u8; 32] {
        let raw = BASE64.decode(encoded).unwrap();
        raw.try_into().unwrap()
    };
    let public = PublicKey::from(decode_32(pub_b64));
    let secret = StaticSecret::from(decode_32(priv_b64));
    (public, secret)
}
/// Decode a base64 private key into a brand-new `StaticSecret`.
///
/// Lets the same server key be handed to several `Tunn` instances, each of
/// which takes its secret by value. Panics on invalid base64 or a decoded
/// length other than 32 bytes.
fn clone_secret(priv_b64: &str) -> StaticSecret {
    let decoded = BASE64.decode(priv_b64).unwrap();
    let key: [u8; 32] = decoded.try_into().unwrap();
    key.into()
}
/// Build a minimal IPv4/UDP-protocol packet (20-byte header, no options)
/// carrying `payload` verbatim after the header.
///
/// The header has version/IHL, total length, TTL 64, protocol 17 (UDP),
/// source/destination addresses and a valid header checksum. No UDP header is
/// added; the payload bytes follow the IP header directly.
///
/// # Panics
///
/// Panics if `20 + payload.len()` does not fit in the 16-bit IPv4
/// total-length field (i.e. payload larger than 65515 bytes).
fn make_ipv4_packet(src: Ipv4Addr, dst: Ipv4Addr, payload: &[u8]) -> Vec<u8> {
    const HEADER_LEN: usize = 20;
    // Checked conversion: a silent `as u8` truncation would produce a header
    // whose length field disagrees with the actual buffer.
    let total_len = u16::try_from(HEADER_LEN + payload.len())
        .expect("IPv4 packet length must fit in the 16-bit total-length field");
    let mut pkt = vec![0u8; usize::from(total_len)];
    pkt[0] = 0x45; // version 4, IHL 5 (20-byte header)
    pkt[2..4].copy_from_slice(&total_len.to_be_bytes());
    pkt[8] = 64; // TTL
    pkt[9] = 0x11; // protocol 17 = UDP
    pkt[12..16].copy_from_slice(&src.octets());
    pkt[16..20].copy_from_slice(&dst.octets());
    // Checksum is computed with its own field still zero, then written in.
    let checksum = ipv4_header_checksum(&pkt[..HEADER_LEN]);
    pkt[10..12].copy_from_slice(&checksum.to_be_bytes());
    pkt[HEADER_LEN..].copy_from_slice(payload);
    pkt
}

/// RFC 1071 one's-complement checksum over an even-length IPv4 header whose
/// checksum field is zero.
fn ipv4_header_checksum(header: &[u8]) -> u16 {
    let mut sum: u32 = header
        .chunks_exact(2)
        .map(|word| u32::from(u16::from_be_bytes([word[0], word[1]])))
        .sum();
    // Fold carries back into the low 16 bits until none remain.
    while sum >> 16 != 0 {
        sum = (sum & 0xffff) + (sum >> 16);
    }
    // After folding, `sum <= 0xffff`, so the cast is lossless.
    !(sum as u16)
}
/// Send `pkt` to `peer`, then flush anything else `tunn` wants on the wire.
///
/// Following boringtun's repeated-call convention, `decapsulate` is called with
/// an empty datagram until it stops returning `WriteToNetwork`; every packet
/// it yields is sent to `peer`.
async fn send_and_drain(
    tunn: &mut Tunn,
    pkt: &[u8],
    socket: &UdpSocket,
    peer: SocketAddr,
) {
    socket.send_to(pkt, peer).await.unwrap();
    // Sized like the other packet buffers in this file (65536) rather than
    // 2048, so a larger queued packet still fits once encapsulated.
    let mut drain_buf = vec![0u8; 65536];
    while let TunnResult::WriteToNetwork(queued) = tunn.decapsulate(None, &[], &mut drain_buf) {
        socket.send_to(queued, peer).await.unwrap();
    }
}
/// Wait up to `timeout_ms` for one datagram on `socket` and feed it to `tunn`.
///
/// * A reply the tunnel wants sent (`WriteToNetwork`) goes back to the datagram's
///   sender, followed by anything else drained from the tunnel; returns `None`.
/// * A decrypted IPv4 packet is returned as `(bytes, source IP, UDP peer)`.
/// * Timeout, socket error, IPv6 output, `Done`, or a tunnel error all yield `None`.
async fn try_recv_decap(
    tunn: &mut Tunn,
    socket: &UdpSocket,
    timeout_ms: u64,
) -> Option<(Vec<u8>, Ipv4Addr, SocketAddr)> {
    let mut datagram = vec![0u8; 65536];
    let mut plaintext = vec![0u8; 65536];
    let wait = Duration::from_millis(timeout_ms);
    // First `ok()?` discards a timeout, the second an I/O error.
    let (len, peer) = time::timeout(wait, socket.recv_from(&mut datagram))
        .await
        .ok()?
        .ok()?;
    match tunn.decapsulate(Some(peer.ip()), &datagram[..len], &mut plaintext) {
        TunnResult::WriteToTunnelV4(ip_pkt, src_ip) => Some((ip_pkt.to_vec(), src_ip, peer)),
        TunnResult::WriteToNetwork(reply) => {
            send_and_drain(tunn, reply, socket, peer).await;
            None
        }
        TunnResult::WriteToTunnelV6(..) | TunnResult::Done | TunnResult::Err(_) => None,
    }
}
/// Drive the full WireGuard handshake between client and server over real UDP.
async fn do_handshake(
client_tunn: &mut Tunn,
server_tunn: &mut Tunn,
client_socket: &UdpSocket,
server_socket: &UdpSocket,
server_addr: SocketAddr,
) {
let mut buf = vec![0u8; 2048];
let mut recv_buf = vec![0u8; 65536];
let mut dst_buf = vec![0u8; 65536];
// Step 1: Client initiates handshake
match client_tunn.encapsulate(&[], &mut buf) {
TunnResult::WriteToNetwork(pkt) => {
client_socket.send_to(pkt, server_addr).await.unwrap();
}
_ => panic!("Expected handshake init"),
}
// Step 2: Server receives init → sends response
let (n, client_from) = server_socket.recv_from(&mut recv_buf).await.unwrap();
match server_tunn.decapsulate(Some(client_from.ip()), &recv_buf[..n], &mut dst_buf) {
TunnResult::WriteToNetwork(pkt) => {
send_and_drain(server_tunn, pkt, server_socket, client_from).await;
}
other => panic!("Expected WriteToNetwork from server, got variant {}", variant_name(&other)),
}
// Step 3: Client receives response
let (n, _) = client_socket.recv_from(&mut recv_buf).await.unwrap();
match client_tunn.decapsulate(Some(server_addr.ip()), &recv_buf[..n], &mut dst_buf) {
TunnResult::WriteToNetwork(pkt) => {
send_and_drain(client_tunn, pkt, client_socket, server_addr).await;
}
TunnResult::Done => {}
_ => {}
}
// Step 4: Process any remaining handshake packets
let _ = try_recv_decap(server_tunn, server_socket, 200).await;
let _ = try_recv_decap(client_tunn, client_socket, 100).await;
// Step 5: Timer ticks to settle
for _ in 0..3 {
match server_tunn.update_timers(&mut dst_buf) {
TunnResult::WriteToNetwork(pkt) => {
server_socket.send_to(pkt, client_from).await.unwrap();
}
_ => {}
}
match client_tunn.update_timers(&mut dst_buf) {
TunnResult::WriteToNetwork(pkt) => {
client_socket.send_to(pkt, server_addr).await.unwrap();
}
_ => {}
}
let _ = try_recv_decap(server_tunn, server_socket, 50).await;
let _ = try_recv_decap(client_tunn, client_socket, 50).await;
}
}
/// Short, human-readable name of a `TunnResult` variant, used in panic messages.
fn variant_name(r: &TunnResult) -> &'static str {
    match r {
        TunnResult::WriteToNetwork(..) => "WriteToNetwork",
        TunnResult::WriteToTunnelV4(..) => "WriteToTunnelV4",
        TunnResult::WriteToTunnelV6(..) => "WriteToTunnelV6",
        TunnResult::Err(..) => "Err",
        TunnResult::Done => "Done",
    }
}
/// Encrypt `ip_packet` with `sender_tunn`, send it to `dest_addr`, then poll the
/// receiving side until it yields a decrypted IPv4 packet.
///
/// Returns the decrypted bytes and the source IPv4 address boringtun reports.
///
/// Panics if encapsulation does not produce a datagram. Also panics if ten
/// receive attempts, each bounded by a 1000 ms timeout, all end without plaintext.
async fn send_and_expect_data(
    sender_tunn: &mut Tunn,
    receiver_tunn: &mut Tunn,
    sender_socket: &UdpSocket,
    receiver_socket: &UdpSocket,
    dest_addr: SocketAddr,
    ip_packet: &[u8],
) -> (Vec<u8>, Ipv4Addr) {
    let mut wire_buf = vec![0u8; 65536];
    let datagram = match sender_tunn.encapsulate(ip_packet, &mut wire_buf) {
        TunnResult::WriteToNetwork(bytes) => bytes,
        TunnResult::Err(e) => panic!("Encapsulate failed: {:?}", e),
        other => panic!("Expected WriteToNetwork, got {}", variant_name(&other)),
    };
    sender_socket.send_to(datagram, dest_addr).await.unwrap();

    // Control packets may arrive ahead of the data packet, so give the
    // receiver several chances before giving up.
    let mut attempts_left = 10;
    while attempts_left > 0 {
        attempts_left -= 1;
        if let Some((plaintext, src_ip, _peer)) =
            try_recv_decap(receiver_tunn, receiver_socket, 1000).await
        {
            return (plaintext, src_ip);
        }
    }
    panic!("Did not receive decrypted IP packet");
}
// ============================================================================
// Test 1: Single client ↔ server bidirectional data exchange
// ============================================================================
#[tokio::test]
async fn wg_e2e_single_client_bidirectional() {
    // Fresh key material for both ends of the tunnel.
    let (srv_pub_b64, srv_priv_b64) = generate_wg_keypair();
    let (cli_pub_b64, cli_priv_b64) = generate_wg_keypair();
    let (srv_public, srv_secret) = parse_key_pair(&srv_pub_b64, &srv_priv_b64);
    let (cli_public, cli_secret) = parse_key_pair(&cli_pub_b64, &cli_priv_b64);

    // Loopback sockets on OS-assigned ports.
    let srv_sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let cli_sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let srv_addr = srv_sock.local_addr().unwrap();
    let cli_addr = cli_sock.local_addr().unwrap();

    // Each side holds its own secret and the other side's public key.
    let mut srv = Tunn::new(srv_secret, cli_public, None, None, 0, None);
    let mut cli = Tunn::new(cli_secret, srv_public, None, None, 1, None);
    do_handshake(&mut cli, &mut srv, &cli_sock, &srv_sock, srv_addr).await;

    let srv_ip = Ipv4Addr::new(10, 0, 0, 1);
    let cli_ip = Ipv4Addr::new(10, 0, 0, 2);

    // Client → Server
    let upstream = make_ipv4_packet(cli_ip, srv_ip, b"Hello from client!");
    let (plain, from_ip) =
        send_and_expect_data(&mut cli, &mut srv, &cli_sock, &srv_sock, srv_addr, &upstream).await;
    assert_eq!(from_ip, cli_ip);
    assert_eq!(&plain[..upstream.len()], &upstream[..]);

    // Server → Client
    let downstream = make_ipv4_packet(srv_ip, cli_ip, b"Hello from server!");
    let (plain, from_ip) =
        send_and_expect_data(&mut srv, &mut cli, &srv_sock, &cli_sock, cli_addr, &downstream).await;
    assert_eq!(from_ip, srv_ip);
    assert_eq!(&plain[..downstream.len()], &downstream[..]);
}
// ============================================================================
// Test 2: Two clients ↔ one server (peer routing)
// ============================================================================
#[tokio::test]
async fn wg_e2e_two_clients_peer_routing() {
    let (srv_pub_b64, srv_priv_b64) = generate_wg_keypair();
    let (c1_pub_b64, c1_priv_b64) = generate_wg_keypair();
    let (c2_pub_b64, c2_priv_b64) = generate_wg_keypair();
    // Only the server's public half is kept here; each server-side Tunn
    // receives its own decoded copy of the secret via `clone_secret`.
    let (srv_public, _) = parse_key_pair(&srv_pub_b64, &srv_priv_b64);
    let (c1_public, c1_secret) = parse_key_pair(&c1_pub_b64, &c1_priv_b64);
    let (c2_public, c2_secret) = parse_key_pair(&c2_pub_b64, &c2_priv_b64);

    // Separate server socket per peer to avoid UDP mux complexity in test
    let srv_sock_1 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let srv_sock_2 = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let c1_sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let c2_sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let srv_addr_1 = srv_sock_1.local_addr().unwrap();
    let srv_addr_2 = srv_sock_2.local_addr().unwrap();

    let mut srv_tunn_1 = Tunn::new(clone_secret(&srv_priv_b64), c1_public, None, None, 0, None);
    let mut srv_tunn_2 = Tunn::new(clone_secret(&srv_priv_b64), c2_public, None, None, 1, None);
    let mut c1_tunn = Tunn::new(c1_secret, srv_public.clone(), None, None, 2, None);
    let mut c2_tunn = Tunn::new(c2_secret, srv_public, None, None, 3, None);

    do_handshake(&mut c1_tunn, &mut srv_tunn_1, &c1_sock, &srv_sock_1, srv_addr_1).await;
    do_handshake(&mut c2_tunn, &mut srv_tunn_2, &c2_sock, &srv_sock_2, srv_addr_2).await;

    // Client 1 → Server, then Client 2 → Server; each must arrive with its
    // own client's tunnel address as the source.
    let srv_ip = Ipv4Addr::new(10, 0, 0, 1);
    let cases = [
        (&mut c1_tunn, &mut srv_tunn_1, &c1_sock, &srv_sock_1, srv_addr_1, Ipv4Addr::new(10, 0, 0, 2), &b"From client 1"[..]),
        (&mut c2_tunn, &mut srv_tunn_2, &c2_sock, &srv_sock_2, srv_addr_2, Ipv4Addr::new(10, 0, 0, 3), &b"From client 2"[..]),
    ];
    for (client, server, client_sock, server_sock, server_addr, client_ip, payload) in cases {
        let packet = make_ipv4_packet(client_ip, srv_ip, payload);
        let (plain, from_ip) =
            send_and_expect_data(client, server, client_sock, server_sock, server_addr, &packet).await;
        assert_eq!(from_ip, client_ip);
        assert_eq!(&plain[..packet.len()], &packet[..]);
    }
}
// ============================================================================
// Test 3: Preshared key handshake + data exchange
// ============================================================================
#[tokio::test]
async fn wg_e2e_preshared_key() {
    let (srv_pub_b64, srv_priv_b64) = generate_wg_keypair();
    let (cli_pub_b64, cli_priv_b64) = generate_wg_keypair();
    let (srv_public, srv_secret) = parse_key_pair(&srv_pub_b64, &srv_priv_b64);
    let (cli_public, cli_secret) = parse_key_pair(&cli_pub_b64, &cli_priv_b64);
    // Random 32-byte preshared key, configured identically on both sides.
    let shared_psk: [u8; 32] = rand::random();

    let srv_sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let cli_sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let srv_addr = srv_sock.local_addr().unwrap();

    let mut srv = Tunn::new(srv_secret, cli_public, Some(shared_psk), None, 0, None);
    let mut cli = Tunn::new(cli_secret, srv_public, Some(shared_psk), None, 1, None);
    do_handshake(&mut cli, &mut srv, &cli_sock, &srv_sock, srv_addr).await;

    let cli_ip = Ipv4Addr::new(10, 0, 0, 2);
    let packet = make_ipv4_packet(cli_ip, Ipv4Addr::new(10, 0, 0, 1), b"PSK-protected data");
    let (plain, from_ip) =
        send_and_expect_data(&mut cli, &mut srv, &cli_sock, &srv_sock, srv_addr, &packet).await;
    assert_eq!(from_ip, cli_ip);
    assert_eq!(&plain[..packet.len()], &packet[..]);
}