1330 lines
45 KiB
Rust
1330 lines
45 KiB
Rust
|
|
use std::collections::HashMap;
|
||
|
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||
|
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||
|
|
use std::sync::Arc;
|
||
|
|
use std::time::Instant;
|
||
|
|
|
||
|
|
use anyhow::{anyhow, Result};
|
||
|
|
use base64::engine::general_purpose::STANDARD as BASE64;
|
||
|
|
use base64::Engine;
|
||
|
|
use boringtun::noise::rate_limiter::RateLimiter;
|
||
|
|
use boringtun::noise::{Tunn, TunnResult};
|
||
|
|
use boringtun::x25519::{PublicKey, StaticSecret};
|
||
|
|
use rand::rngs::OsRng;
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||
|
|
use tokio::net::UdpSocket;
|
||
|
|
use tokio::sync::{mpsc, oneshot, RwLock};
|
||
|
|
use tracing::{debug, error, info, warn};
|
||
|
|
|
||
|
|
use crate::network;
|
||
|
|
use crate::tunnel::{self, TunConfig};
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Constants
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
/// Largest datagram read from the UDP socket or the TUN device (64 KiB).
const MAX_UDP_PACKET: usize = 64 * 1024;
/// Size of the scratch buffer boringtun writes encapsulate/decapsulate output into.
const WG_BUFFER_SIZE: usize = MAX_UDP_PACKET;
/// Minimum dst buffer size for boringtun encapsulate/decapsulate.
/// The leading underscore keeps the unused-constant lint quiet.
const _MIN_DST_BUF: usize = 148;
/// Period (milliseconds) of the tick that drives `Tunn::update_timers`.
const TIMER_TICK_MS: u64 = 100;
/// UDP port the server binds when no `listenPort` is configured.
const DEFAULT_WG_PORT: u16 = 51820;
/// Server TUN address used when no `tunAddress` is configured.
const DEFAULT_TUN_ADDRESS: &str = "10.8.0.1";
/// Server TUN netmask used when no `tunNetmask` is configured.
const DEFAULT_TUN_NETMASK: &str = "255.255.255.0";
/// TUN MTU used when no `mtu` is configured.
const DEFAULT_MTU: u16 = 1420;
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Configuration types
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||
|
|
#[serde(rename_all = "camelCase")]
|
||
|
|
pub struct WgPeerConfig {
|
||
|
|
pub public_key: String,
|
||
|
|
#[serde(default)]
|
||
|
|
pub preshared_key: Option<String>,
|
||
|
|
pub allowed_ips: Vec<String>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub endpoint: Option<String>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub persistent_keepalive: Option<u16>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Deserialize)]
|
||
|
|
#[serde(rename_all = "camelCase")]
|
||
|
|
pub struct WgServerConfig {
|
||
|
|
pub private_key: String,
|
||
|
|
#[serde(default)]
|
||
|
|
pub listen_port: Option<u16>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub tun_address: Option<String>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub tun_netmask: Option<String>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub mtu: Option<u16>,
|
||
|
|
pub peers: Vec<WgPeerConfig>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub dns: Option<Vec<String>>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub enable_nat: Option<bool>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub subnet: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Deserialize)]
|
||
|
|
#[serde(rename_all = "camelCase")]
|
||
|
|
pub struct WgClientConfig {
|
||
|
|
pub private_key: String,
|
||
|
|
pub address: String,
|
||
|
|
#[serde(default)]
|
||
|
|
pub address_prefix: Option<u8>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub dns: Option<Vec<String>>,
|
||
|
|
#[serde(default)]
|
||
|
|
pub mtu: Option<u16>,
|
||
|
|
pub peer: WgPeerConfig,
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Stats types
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Default, Serialize)]
|
||
|
|
#[serde(rename_all = "camelCase")]
|
||
|
|
pub struct WgPeerStats {
|
||
|
|
pub bytes_sent: u64,
|
||
|
|
pub bytes_received: u64,
|
||
|
|
pub packets_sent: u64,
|
||
|
|
pub packets_received: u64,
|
||
|
|
pub last_handshake_time: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize)]
|
||
|
|
#[serde(rename_all = "camelCase")]
|
||
|
|
pub struct WgPeerInfo {
|
||
|
|
pub public_key: String,
|
||
|
|
pub allowed_ips: Vec<String>,
|
||
|
|
pub endpoint: Option<String>,
|
||
|
|
pub persistent_keepalive: Option<u16>,
|
||
|
|
#[serde(flatten)]
|
||
|
|
pub stats: WgPeerStats,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Default, Serialize)]
|
||
|
|
#[serde(rename_all = "camelCase")]
|
||
|
|
pub struct WgServerStats {
|
||
|
|
pub total_bytes_sent: u64,
|
||
|
|
pub total_bytes_received: u64,
|
||
|
|
pub total_packets_sent: u64,
|
||
|
|
pub total_packets_received: u64,
|
||
|
|
pub active_peers: usize,
|
||
|
|
pub uptime_seconds: f64,
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Key generation and parsing
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
/// Generate a WireGuard-compatible X25519 keypair.
|
||
|
|
/// Returns (public_key_base64, private_key_base64).
|
||
|
|
pub fn generate_wg_keypair() -> (String, String) {
|
||
|
|
let private = StaticSecret::random_from_rng(OsRng);
|
||
|
|
let public = PublicKey::from(&private);
|
||
|
|
let priv_b64 = BASE64.encode(private.to_bytes());
|
||
|
|
let pub_b64 = BASE64.encode(public.to_bytes());
|
||
|
|
(pub_b64, priv_b64)
|
||
|
|
}
|
||
|
|
|
||
|
|
fn parse_private_key(b64: &str) -> Result<StaticSecret> {
|
||
|
|
let bytes = BASE64.decode(b64)?;
|
||
|
|
if bytes.len() != 32 {
|
||
|
|
return Err(anyhow!("Private key must be 32 bytes, got {}", bytes.len()));
|
||
|
|
}
|
||
|
|
let mut arr = [0u8; 32];
|
||
|
|
arr.copy_from_slice(&bytes);
|
||
|
|
Ok(StaticSecret::from(arr))
|
||
|
|
}
|
||
|
|
|
||
|
|
fn parse_public_key(b64: &str) -> Result<PublicKey> {
|
||
|
|
let bytes = BASE64.decode(b64)?;
|
||
|
|
if bytes.len() != 32 {
|
||
|
|
return Err(anyhow!("Public key must be 32 bytes, got {}", bytes.len()));
|
||
|
|
}
|
||
|
|
let mut arr = [0u8; 32];
|
||
|
|
arr.copy_from_slice(&bytes);
|
||
|
|
Ok(PublicKey::from(arr))
|
||
|
|
}
|
||
|
|
|
||
|
|
fn parse_preshared_key(b64: &str) -> Result<[u8; 32]> {
|
||
|
|
let bytes = BASE64.decode(b64)?;
|
||
|
|
if bytes.len() != 32 {
|
||
|
|
return Err(anyhow!(
|
||
|
|
"Preshared key must be 32 bytes, got {}",
|
||
|
|
bytes.len()
|
||
|
|
));
|
||
|
|
}
|
||
|
|
let mut arr = [0u8; 32];
|
||
|
|
arr.copy_from_slice(&bytes);
|
||
|
|
Ok(arr)
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// AllowedIPs matching
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
/// One parsed AllowedIPs entry: an address plus a prefix length.
#[derive(Clone, Debug)]
struct AllowedIp {
    /// Address as written in the config; host bits are not cleared, because
    /// matching masks both sides.
    addr: IpAddr,
    /// Prefix length (validated to <= 32 for IPv4 and <= 128 for IPv6 by `parse`).
    prefix_len: u8,
}
|
||
|
|
|
||
|
|
impl AllowedIp {
|
||
|
|
fn parse(cidr: &str) -> Result<Self> {
|
||
|
|
let parts: Vec<&str> = cidr.split('/').collect();
|
||
|
|
if parts.len() != 2 {
|
||
|
|
return Err(anyhow!("Invalid CIDR: {}", cidr));
|
||
|
|
}
|
||
|
|
let addr: IpAddr = parts[0].parse()?;
|
||
|
|
let prefix_len: u8 = parts[1].parse()?;
|
||
|
|
match addr {
|
||
|
|
IpAddr::V4(_) if prefix_len > 32 => {
|
||
|
|
return Err(anyhow!("IPv4 prefix length {} > 32", prefix_len))
|
||
|
|
}
|
||
|
|
IpAddr::V6(_) if prefix_len > 128 => {
|
||
|
|
return Err(anyhow!("IPv6 prefix length {} > 128", prefix_len))
|
||
|
|
}
|
||
|
|
_ => {}
|
||
|
|
}
|
||
|
|
Ok(Self { addr, prefix_len })
|
||
|
|
}
|
||
|
|
|
||
|
|
fn matches(&self, ip: IpAddr) -> bool {
|
||
|
|
match (self.addr, ip) {
|
||
|
|
(IpAddr::V4(net), IpAddr::V4(target)) => {
|
||
|
|
if self.prefix_len == 0 {
|
||
|
|
return true;
|
||
|
|
}
|
||
|
|
if self.prefix_len >= 32 {
|
||
|
|
return net == target;
|
||
|
|
}
|
||
|
|
let mask = u32::MAX << (32 - self.prefix_len);
|
||
|
|
(u32::from(net) & mask) == (u32::from(target) & mask)
|
||
|
|
}
|
||
|
|
(IpAddr::V6(net), IpAddr::V6(target)) => {
|
||
|
|
if self.prefix_len == 0 {
|
||
|
|
return true;
|
||
|
|
}
|
||
|
|
if self.prefix_len >= 128 {
|
||
|
|
return net == target;
|
||
|
|
}
|
||
|
|
let net_bits = u128::from(net);
|
||
|
|
let target_bits = u128::from(target);
|
||
|
|
let mask = u128::MAX << (128 - self.prefix_len);
|
||
|
|
(net_bits & mask) == (target_bits & mask)
|
||
|
|
}
|
||
|
|
_ => false,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Read the destination address from a raw IPv4 or IPv6 packet header.
///
/// Returns `None` for an empty packet, an unknown IP version, or a buffer
/// shorter than the fixed header (20 bytes for IPv4, 40 for IPv6).
fn extract_dst_ip(packet: &[u8]) -> Option<IpAddr> {
    let version = *packet.first()? >> 4;
    if version == 4 {
        // IPv4 destination address occupies header bytes 16..20.
        let d = packet.get(16..20)?;
        return Some(IpAddr::V4(Ipv4Addr::new(d[0], d[1], d[2], d[3])));
    }
    if version == 6 {
        // IPv6 destination address occupies header bytes 24..40.
        let d = packet.get(24..40)?;
        let mut octets = [0u8; 16];
        octets.copy_from_slice(d);
        return Some(IpAddr::V6(Ipv6Addr::from(octets)));
    }
    None
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Dynamic peer management commands
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
enum WgCommand {
|
||
|
|
AddPeer(WgPeerConfig, oneshot::Sender<Result<()>>),
|
||
|
|
RemovePeer(String, oneshot::Sender<Result<()>>),
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Internal peer state (owned by event loop)
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
struct PeerState {
|
||
|
|
tunn: Tunn,
|
||
|
|
public_key_b64: String,
|
||
|
|
allowed_ips: Vec<AllowedIp>,
|
||
|
|
endpoint: Option<SocketAddr>,
|
||
|
|
#[allow(dead_code)]
|
||
|
|
persistent_keepalive: Option<u16>,
|
||
|
|
stats: WgPeerStats,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PeerState {
|
||
|
|
fn matches_dst(&self, dst_ip: IpAddr) -> bool {
|
||
|
|
self.allowed_ips.iter().any(|aip| aip.matches(dst_ip))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// WgServer
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
pub struct WgServer {
|
||
|
|
shutdown_tx: Option<oneshot::Sender<()>>,
|
||
|
|
command_tx: Option<mpsc::Sender<WgCommand>>,
|
||
|
|
shared_stats: Arc<RwLock<HashMap<String, WgPeerStats>>>,
|
||
|
|
server_stats: Arc<RwLock<WgServerStats>>,
|
||
|
|
started_at: Option<Instant>,
|
||
|
|
listen_port: Option<u16>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl WgServer {
|
||
|
|
pub fn new() -> Self {
|
||
|
|
Self {
|
||
|
|
shutdown_tx: None,
|
||
|
|
command_tx: None,
|
||
|
|
shared_stats: Arc::new(RwLock::new(HashMap::new())),
|
||
|
|
server_stats: Arc::new(RwLock::new(WgServerStats::default())),
|
||
|
|
started_at: None,
|
||
|
|
listen_port: None,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn is_running(&self) -> bool {
|
||
|
|
self.shutdown_tx.is_some()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn start(&mut self, config: WgServerConfig) -> Result<()> {
|
||
|
|
if self.is_running() {
|
||
|
|
return Err(anyhow!("WireGuard server is already running"));
|
||
|
|
}
|
||
|
|
|
||
|
|
let listen_port = config.listen_port.unwrap_or(DEFAULT_WG_PORT);
|
||
|
|
let tun_address = config
|
||
|
|
.tun_address
|
||
|
|
.as_deref()
|
||
|
|
.unwrap_or(DEFAULT_TUN_ADDRESS);
|
||
|
|
let tun_netmask = config
|
||
|
|
.tun_netmask
|
||
|
|
.as_deref()
|
||
|
|
.unwrap_or(DEFAULT_TUN_NETMASK);
|
||
|
|
let mtu = config.mtu.unwrap_or(DEFAULT_MTU);
|
||
|
|
|
||
|
|
// Parse server private key
|
||
|
|
let server_private = parse_private_key(&config.private_key)?;
|
||
|
|
let server_public = PublicKey::from(&server_private);
|
||
|
|
|
||
|
|
// Create rate limiter for DDoS protection
|
||
|
|
let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64));
|
||
|
|
|
||
|
|
// Build peer state
|
||
|
|
let peer_index = AtomicU32::new(0);
|
||
|
|
let mut peers: Vec<PeerState> = Vec::with_capacity(config.peers.len());
|
||
|
|
|
||
|
|
for peer_config in &config.peers {
|
||
|
|
let peer_public = parse_public_key(&peer_config.public_key)?;
|
||
|
|
let psk = match &peer_config.preshared_key {
|
||
|
|
Some(k) => Some(parse_preshared_key(k)?),
|
||
|
|
None => None,
|
||
|
|
};
|
||
|
|
let idx = peer_index.fetch_add(1, Ordering::Relaxed);
|
||
|
|
|
||
|
|
// Clone the private key for each Tunn (StaticSecret doesn't implement Clone,
|
||
|
|
// so re-parse from config)
|
||
|
|
let priv_copy = parse_private_key(&config.private_key)?;
|
||
|
|
|
||
|
|
let tunn = Tunn::new(
|
||
|
|
priv_copy,
|
||
|
|
peer_public,
|
||
|
|
psk,
|
||
|
|
peer_config.persistent_keepalive,
|
||
|
|
idx,
|
||
|
|
Some(rate_limiter.clone()),
|
||
|
|
);
|
||
|
|
|
||
|
|
let allowed_ips: Vec<AllowedIp> = peer_config
|
||
|
|
.allowed_ips
|
||
|
|
.iter()
|
||
|
|
.map(|cidr| AllowedIp::parse(cidr))
|
||
|
|
.collect::<Result<Vec<_>>>()?;
|
||
|
|
|
||
|
|
let endpoint = match &peer_config.endpoint {
|
||
|
|
Some(ep) => Some(ep.parse::<SocketAddr>()?),
|
||
|
|
None => None,
|
||
|
|
};
|
||
|
|
|
||
|
|
peers.push(PeerState {
|
||
|
|
tunn,
|
||
|
|
public_key_b64: peer_config.public_key.clone(),
|
||
|
|
allowed_ips,
|
||
|
|
endpoint,
|
||
|
|
persistent_keepalive: peer_config.persistent_keepalive,
|
||
|
|
stats: WgPeerStats::default(),
|
||
|
|
});
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create TUN device
|
||
|
|
let tun_config = TunConfig {
|
||
|
|
name: "wg0".to_string(),
|
||
|
|
address: tun_address.parse()?,
|
||
|
|
netmask: tun_netmask.parse()?,
|
||
|
|
mtu,
|
||
|
|
};
|
||
|
|
let tun_device = tunnel::create_tun(&tun_config)?;
|
||
|
|
info!("WireGuard TUN device created: {}", tun_config.name);
|
||
|
|
|
||
|
|
// Bind UDP socket
|
||
|
|
let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", listen_port)).await?;
|
||
|
|
info!("WireGuard server listening on UDP port {}", listen_port);
|
||
|
|
|
||
|
|
// Enable IP forwarding and NAT if requested
|
||
|
|
if config.enable_nat.unwrap_or(false) {
|
||
|
|
network::enable_ip_forwarding()?;
|
||
|
|
let subnet = config
|
||
|
|
.subnet
|
||
|
|
.as_deref()
|
||
|
|
.unwrap_or("10.8.0.0/24");
|
||
|
|
let iface = network::get_default_interface()?;
|
||
|
|
network::setup_nat(subnet, &iface).await?;
|
||
|
|
info!("NAT enabled for subnet {} via {}", subnet, iface);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Channels
|
||
|
|
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
|
||
|
|
let (command_tx, command_rx) = mpsc::channel::<WgCommand>(32);
|
||
|
|
|
||
|
|
let shared_stats = self.shared_stats.clone();
|
||
|
|
let server_stats = self.server_stats.clone();
|
||
|
|
let started_at = Instant::now();
|
||
|
|
|
||
|
|
// Initialize shared stats
|
||
|
|
{
|
||
|
|
let mut stats = shared_stats.write().await;
|
||
|
|
for peer in &peers {
|
||
|
|
stats.insert(peer.public_key_b64.clone(), WgPeerStats::default());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Spawn the event loop
|
||
|
|
tokio::spawn(async move {
|
||
|
|
if let Err(e) = wg_server_loop(
|
||
|
|
udp_socket,
|
||
|
|
tun_device,
|
||
|
|
peers,
|
||
|
|
peer_index,
|
||
|
|
rate_limiter,
|
||
|
|
config.private_key.clone(),
|
||
|
|
shared_stats,
|
||
|
|
server_stats,
|
||
|
|
started_at,
|
||
|
|
shutdown_rx,
|
||
|
|
command_rx,
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
{
|
||
|
|
error!("WireGuard server loop error: {}", e);
|
||
|
|
}
|
||
|
|
info!("WireGuard server loop exited");
|
||
|
|
});
|
||
|
|
|
||
|
|
self.shutdown_tx = Some(shutdown_tx);
|
||
|
|
self.command_tx = Some(command_tx);
|
||
|
|
self.started_at = Some(started_at);
|
||
|
|
self.listen_port = Some(listen_port);
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn stop(&mut self) -> Result<()> {
|
||
|
|
if let Some(tx) = self.shutdown_tx.take() {
|
||
|
|
let _ = tx.send(());
|
||
|
|
}
|
||
|
|
self.command_tx = None;
|
||
|
|
self.started_at = None;
|
||
|
|
self.listen_port = None;
|
||
|
|
info!("WireGuard server stopped");
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn get_status(&self) -> serde_json::Value {
|
||
|
|
if self.is_running() {
|
||
|
|
serde_json::json!({
|
||
|
|
"state": "running",
|
||
|
|
"listenPort": self.listen_port,
|
||
|
|
"uptimeSeconds": self.started_at.map(|t| t.elapsed().as_secs_f64()).unwrap_or(0.0),
|
||
|
|
})
|
||
|
|
} else {
|
||
|
|
serde_json::json!({ "state": "stopped" })
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn get_statistics(&self) -> serde_json::Value {
|
||
|
|
let mut stats = self.server_stats.write().await;
|
||
|
|
if let Some(started) = self.started_at {
|
||
|
|
stats.uptime_seconds = started.elapsed().as_secs_f64();
|
||
|
|
}
|
||
|
|
// Aggregate from peer stats
|
||
|
|
let peer_stats = self.shared_stats.read().await;
|
||
|
|
stats.active_peers = peer_stats.len();
|
||
|
|
stats.total_bytes_sent = peer_stats.values().map(|s| s.bytes_sent).sum();
|
||
|
|
stats.total_bytes_received = peer_stats.values().map(|s| s.bytes_received).sum();
|
||
|
|
stats.total_packets_sent = peer_stats.values().map(|s| s.packets_sent).sum();
|
||
|
|
stats.total_packets_received = peer_stats.values().map(|s| s.packets_received).sum();
|
||
|
|
serde_json::to_value(&*stats).unwrap_or_default()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn list_peers(&self) -> Vec<WgPeerInfo> {
|
||
|
|
let stats = self.shared_stats.read().await;
|
||
|
|
stats
|
||
|
|
.iter()
|
||
|
|
.map(|(key, s)| WgPeerInfo {
|
||
|
|
public_key: key.clone(),
|
||
|
|
allowed_ips: vec![], // populated from event loop snapshots
|
||
|
|
endpoint: None,
|
||
|
|
persistent_keepalive: None,
|
||
|
|
stats: s.clone(),
|
||
|
|
})
|
||
|
|
.collect()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn add_peer(&self, config: WgPeerConfig) -> Result<()> {
|
||
|
|
let tx = self
|
||
|
|
.command_tx
|
||
|
|
.as_ref()
|
||
|
|
.ok_or_else(|| anyhow!("Server not running"))?;
|
||
|
|
let (resp_tx, resp_rx) = oneshot::channel();
|
||
|
|
tx.send(WgCommand::AddPeer(config, resp_tx))
|
||
|
|
.await
|
||
|
|
.map_err(|_| anyhow!("Server event loop closed"))?;
|
||
|
|
resp_rx.await.map_err(|_| anyhow!("No response"))?
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn remove_peer(&self, public_key: &str) -> Result<()> {
|
||
|
|
let tx = self
|
||
|
|
.command_tx
|
||
|
|
.as_ref()
|
||
|
|
.ok_or_else(|| anyhow!("Server not running"))?;
|
||
|
|
let (resp_tx, resp_rx) = oneshot::channel();
|
||
|
|
tx.send(WgCommand::RemovePeer(public_key.to_string(), resp_tx))
|
||
|
|
.await
|
||
|
|
.map_err(|_| anyhow!("Server event loop closed"))?;
|
||
|
|
resp_rx.await.map_err(|_| anyhow!("No response"))?
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Server event loop
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
async fn wg_server_loop(
|
||
|
|
udp_socket: UdpSocket,
|
||
|
|
tun_device: tun::AsyncDevice,
|
||
|
|
mut peers: Vec<PeerState>,
|
||
|
|
peer_index: AtomicU32,
|
||
|
|
rate_limiter: Arc<RateLimiter>,
|
||
|
|
server_private_key_b64: String,
|
||
|
|
shared_stats: Arc<RwLock<HashMap<String, WgPeerStats>>>,
|
||
|
|
_server_stats: Arc<RwLock<WgServerStats>>,
|
||
|
|
_started_at: Instant,
|
||
|
|
mut shutdown_rx: oneshot::Receiver<()>,
|
||
|
|
mut command_rx: mpsc::Receiver<WgCommand>,
|
||
|
|
) -> Result<()> {
|
||
|
|
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
|
||
|
|
let mut tun_buf = vec![0u8; MAX_UDP_PACKET];
|
||
|
|
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
|
||
|
|
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
|
||
|
|
|
||
|
|
// Split TUN for concurrent read/write in select
|
||
|
|
let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device);
|
||
|
|
|
||
|
|
// Stats sync interval
|
||
|
|
let mut stats_timer =
|
||
|
|
tokio::time::interval(std::time::Duration::from_secs(1));
|
||
|
|
|
||
|
|
loop {
|
||
|
|
tokio::select! {
|
||
|
|
// --- UDP receive ---
|
||
|
|
result = udp_socket.recv_from(&mut udp_buf) => {
|
||
|
|
let (n, src_addr) = result?;
|
||
|
|
if n == 0 { continue; }
|
||
|
|
|
||
|
|
// Find which peer this packet belongs to by trying decapsulate
|
||
|
|
let mut handled = false;
|
||
|
|
for peer in peers.iter_mut() {
|
||
|
|
match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(packet) => {
|
||
|
|
udp_socket.send_to(packet, src_addr).await?;
|
||
|
|
// Drain loop
|
||
|
|
loop {
|
||
|
|
match peer.tunn.decapsulate(None, &[], &mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(pkt) => {
|
||
|
|
let ep = peer.endpoint.unwrap_or(src_addr);
|
||
|
|
udp_socket.send_to(pkt, ep).await?;
|
||
|
|
}
|
||
|
|
_ => break,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
peer.endpoint = Some(src_addr);
|
||
|
|
handled = true;
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
TunnResult::WriteToTunnelV4(packet, addr) => {
|
||
|
|
if peer.matches_dst(IpAddr::V4(addr)) {
|
||
|
|
let pkt_len = packet.len() as u64;
|
||
|
|
tun_writer.write_all(packet).await?;
|
||
|
|
peer.stats.bytes_received += pkt_len;
|
||
|
|
peer.stats.packets_received += 1;
|
||
|
|
}
|
||
|
|
peer.endpoint = Some(src_addr);
|
||
|
|
handled = true;
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
TunnResult::WriteToTunnelV6(packet, addr) => {
|
||
|
|
if peer.matches_dst(IpAddr::V6(addr)) {
|
||
|
|
let pkt_len = packet.len() as u64;
|
||
|
|
tun_writer.write_all(packet).await?;
|
||
|
|
peer.stats.bytes_received += pkt_len;
|
||
|
|
peer.stats.packets_received += 1;
|
||
|
|
}
|
||
|
|
peer.endpoint = Some(src_addr);
|
||
|
|
handled = true;
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
TunnResult::Done => {
|
||
|
|
// This peer didn't recognize the packet, try next
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
TunnResult::Err(e) => {
|
||
|
|
debug!("decapsulate error from {}: {:?}", src_addr, e);
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if !handled {
|
||
|
|
debug!("No peer matched UDP packet from {}", src_addr);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- TUN read ---
|
||
|
|
result = tun_reader.read(&mut tun_buf) => {
|
||
|
|
let n = result?;
|
||
|
|
if n == 0 { continue; }
|
||
|
|
|
||
|
|
let dst_ip = match extract_dst_ip(&tun_buf[..n]) {
|
||
|
|
Some(ip) => ip,
|
||
|
|
None => { continue; }
|
||
|
|
};
|
||
|
|
|
||
|
|
// Find peer whose AllowedIPs match the destination
|
||
|
|
for peer in peers.iter_mut() {
|
||
|
|
if !peer.matches_dst(dst_ip) {
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
match peer.tunn.encapsulate(&tun_buf[..n], &mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(packet) => {
|
||
|
|
if let Some(endpoint) = peer.endpoint {
|
||
|
|
let pkt_len = n as u64;
|
||
|
|
udp_socket.send_to(packet, endpoint).await?;
|
||
|
|
peer.stats.bytes_sent += pkt_len;
|
||
|
|
peer.stats.packets_sent += 1;
|
||
|
|
} else {
|
||
|
|
debug!("No endpoint for peer {}, dropping packet", peer.public_key_b64);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
TunnResult::Err(e) => {
|
||
|
|
debug!("encapsulate error for peer {}: {:?}", peer.public_key_b64, e);
|
||
|
|
}
|
||
|
|
_ => {}
|
||
|
|
}
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- Timer tick (100ms) for WireGuard timers ---
|
||
|
|
_ = timer.tick() => {
|
||
|
|
for peer in peers.iter_mut() {
|
||
|
|
match peer.tunn.update_timers(&mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(packet) => {
|
||
|
|
if let Some(endpoint) = peer.endpoint {
|
||
|
|
udp_socket.send_to(packet, endpoint).await?;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
TunnResult::Err(e) => {
|
||
|
|
debug!("Timer error for peer {}: {:?}", peer.public_key_b64, e);
|
||
|
|
}
|
||
|
|
_ => {}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- Sync stats to shared state ---
|
||
|
|
_ = stats_timer.tick() => {
|
||
|
|
let mut shared = shared_stats.write().await;
|
||
|
|
for peer in peers.iter() {
|
||
|
|
shared.insert(peer.public_key_b64.clone(), peer.stats.clone());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- Dynamic peer commands ---
|
||
|
|
cmd = command_rx.recv() => {
|
||
|
|
match cmd {
|
||
|
|
Some(WgCommand::AddPeer(config, resp_tx)) => {
|
||
|
|
let result = add_peer_to_loop(
|
||
|
|
&mut peers,
|
||
|
|
&config,
|
||
|
|
&peer_index,
|
||
|
|
&rate_limiter,
|
||
|
|
&server_private_key_b64,
|
||
|
|
);
|
||
|
|
if result.is_ok() {
|
||
|
|
let mut shared = shared_stats.write().await;
|
||
|
|
shared.insert(config.public_key.clone(), WgPeerStats::default());
|
||
|
|
}
|
||
|
|
let _ = resp_tx.send(result);
|
||
|
|
}
|
||
|
|
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
|
||
|
|
let prev_len = peers.len();
|
||
|
|
peers.retain(|p| p.public_key_b64 != pubkey);
|
||
|
|
if peers.len() < prev_len {
|
||
|
|
let mut shared = shared_stats.write().await;
|
||
|
|
shared.remove(&pubkey);
|
||
|
|
let _ = resp_tx.send(Ok(()));
|
||
|
|
} else {
|
||
|
|
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
None => {
|
||
|
|
info!("Command channel closed");
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- Shutdown ---
|
||
|
|
_ = &mut shutdown_rx => {
|
||
|
|
info!("WireGuard server shutdown signal received");
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
fn add_peer_to_loop(
|
||
|
|
peers: &mut Vec<PeerState>,
|
||
|
|
config: &WgPeerConfig,
|
||
|
|
peer_index: &AtomicU32,
|
||
|
|
rate_limiter: &Arc<RateLimiter>,
|
||
|
|
server_private_key_b64: &str,
|
||
|
|
) -> Result<()> {
|
||
|
|
// Check for duplicate
|
||
|
|
if peers.iter().any(|p| p.public_key_b64 == config.public_key) {
|
||
|
|
return Err(anyhow!("Peer already exists: {}", config.public_key));
|
||
|
|
}
|
||
|
|
|
||
|
|
let peer_public = parse_public_key(&config.public_key)?;
|
||
|
|
let psk = match &config.preshared_key {
|
||
|
|
Some(k) => Some(parse_preshared_key(k)?),
|
||
|
|
None => None,
|
||
|
|
};
|
||
|
|
let idx = peer_index.fetch_add(1, Ordering::Relaxed);
|
||
|
|
let priv_copy = parse_private_key(server_private_key_b64)?;
|
||
|
|
|
||
|
|
let tunn = Tunn::new(
|
||
|
|
priv_copy,
|
||
|
|
peer_public,
|
||
|
|
psk,
|
||
|
|
config.persistent_keepalive,
|
||
|
|
idx,
|
||
|
|
Some(rate_limiter.clone()),
|
||
|
|
);
|
||
|
|
|
||
|
|
let allowed_ips: Vec<AllowedIp> = config
|
||
|
|
.allowed_ips
|
||
|
|
.iter()
|
||
|
|
.map(|cidr| AllowedIp::parse(cidr))
|
||
|
|
.collect::<Result<Vec<_>>>()?;
|
||
|
|
|
||
|
|
let endpoint = match &config.endpoint {
|
||
|
|
Some(ep) => Some(ep.parse::<SocketAddr>()?),
|
||
|
|
None => None,
|
||
|
|
};
|
||
|
|
|
||
|
|
peers.push(PeerState {
|
||
|
|
tunn,
|
||
|
|
public_key_b64: config.public_key.clone(),
|
||
|
|
allowed_ips,
|
||
|
|
endpoint,
|
||
|
|
persistent_keepalive: config.persistent_keepalive,
|
||
|
|
stats: WgPeerStats::default(),
|
||
|
|
});
|
||
|
|
|
||
|
|
info!("Added WireGuard peer: {}", config.public_key);
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// WgClient
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
pub struct WgClient {
|
||
|
|
shutdown_tx: Option<oneshot::Sender<()>>,
|
||
|
|
shared_stats: Arc<RwLock<WgPeerStats>>,
|
||
|
|
state: Arc<RwLock<WgClientState>>,
|
||
|
|
assigned_ip: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize)]
|
||
|
|
#[serde(rename_all = "camelCase")]
|
||
|
|
struct WgClientState {
|
||
|
|
state: String,
|
||
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||
|
|
assigned_ip: Option<String>,
|
||
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||
|
|
connected_since: Option<String>,
|
||
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||
|
|
last_error: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl WgClient {
|
||
|
|
pub fn new() -> Self {
|
||
|
|
Self {
|
||
|
|
shutdown_tx: None,
|
||
|
|
shared_stats: Arc::new(RwLock::new(WgPeerStats::default())),
|
||
|
|
state: Arc::new(RwLock::new(WgClientState {
|
||
|
|
state: "disconnected".to_string(),
|
||
|
|
assigned_ip: None,
|
||
|
|
connected_since: None,
|
||
|
|
last_error: None,
|
||
|
|
})),
|
||
|
|
assigned_ip: None,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn is_running(&self) -> bool {
|
||
|
|
self.shutdown_tx.is_some()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn connect(&mut self, config: WgClientConfig) -> Result<String> {
|
||
|
|
if self.is_running() {
|
||
|
|
return Err(anyhow!("WireGuard client is already connected"));
|
||
|
|
}
|
||
|
|
|
||
|
|
{
|
||
|
|
let mut state = self.state.write().await;
|
||
|
|
state.state = "connecting".to_string();
|
||
|
|
}
|
||
|
|
|
||
|
|
let mtu = config.mtu.unwrap_or(DEFAULT_MTU);
|
||
|
|
let _prefix = config.address_prefix.unwrap_or(24);
|
||
|
|
let address: Ipv4Addr = config.address.parse()?;
|
||
|
|
|
||
|
|
// Parse keys
|
||
|
|
let client_private = parse_private_key(&config.private_key)?;
|
||
|
|
let peer_public = parse_public_key(&config.peer.public_key)?;
|
||
|
|
let psk = match &config.peer.preshared_key {
|
||
|
|
Some(k) => Some(parse_preshared_key(k)?),
|
||
|
|
None => None,
|
||
|
|
};
|
||
|
|
|
||
|
|
let tunn = Tunn::new(
|
||
|
|
client_private,
|
||
|
|
peer_public,
|
||
|
|
psk,
|
||
|
|
config.peer.persistent_keepalive,
|
||
|
|
0, // single peer, index 0
|
||
|
|
None,
|
||
|
|
);
|
||
|
|
|
||
|
|
// Parse server endpoint
|
||
|
|
let endpoint: SocketAddr = config
|
||
|
|
.peer
|
||
|
|
.endpoint
|
||
|
|
.as_ref()
|
||
|
|
.ok_or_else(|| anyhow!("Peer endpoint is required for client mode"))?
|
||
|
|
.parse()?;
|
||
|
|
|
||
|
|
// Parse AllowedIPs
|
||
|
|
let allowed_ips: Vec<AllowedIp> = config
|
||
|
|
.peer
|
||
|
|
.allowed_ips
|
||
|
|
.iter()
|
||
|
|
.map(|cidr| AllowedIp::parse(cidr))
|
||
|
|
.collect::<Result<Vec<_>>>()?;
|
||
|
|
|
||
|
|
// Create TUN device
|
||
|
|
let tun_config = TunConfig {
|
||
|
|
name: "wg-client0".to_string(),
|
||
|
|
address,
|
||
|
|
netmask: Ipv4Addr::new(255, 255, 255, 0),
|
||
|
|
mtu,
|
||
|
|
};
|
||
|
|
let tun_device = tunnel::create_tun(&tun_config)?;
|
||
|
|
info!("WireGuard client TUN device created: {}", tun_config.name);
|
||
|
|
|
||
|
|
// Add routes for AllowedIPs
|
||
|
|
for cidr in &config.peer.allowed_ips {
|
||
|
|
if let Err(e) = tunnel::add_route(cidr, &tun_config.name).await {
|
||
|
|
warn!("Failed to add route for {}: {}", cidr, e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Bind ephemeral UDP socket
|
||
|
|
let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||
|
|
info!(
|
||
|
|
"WireGuard client bound to {}",
|
||
|
|
udp_socket.local_addr()?
|
||
|
|
);
|
||
|
|
|
||
|
|
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
|
||
|
|
let shared_stats = self.shared_stats.clone();
|
||
|
|
let state = self.state.clone();
|
||
|
|
let assigned_ip = config.address.clone();
|
||
|
|
|
||
|
|
// Update state
|
||
|
|
{
|
||
|
|
let mut s = state.write().await;
|
||
|
|
s.state = "connected".to_string();
|
||
|
|
s.assigned_ip = Some(assigned_ip.clone());
|
||
|
|
s.connected_since = Some(chrono_now());
|
||
|
|
}
|
||
|
|
|
||
|
|
// Spawn client loop
|
||
|
|
tokio::spawn(async move {
|
||
|
|
if let Err(e) = wg_client_loop(
|
||
|
|
udp_socket,
|
||
|
|
tun_device,
|
||
|
|
tunn,
|
||
|
|
endpoint,
|
||
|
|
allowed_ips,
|
||
|
|
shared_stats,
|
||
|
|
state.clone(),
|
||
|
|
shutdown_rx,
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
{
|
||
|
|
error!("WireGuard client loop error: {}", e);
|
||
|
|
let mut s = state.write().await;
|
||
|
|
s.state = "error".to_string();
|
||
|
|
s.last_error = Some(format!("{}", e));
|
||
|
|
}
|
||
|
|
});
|
||
|
|
|
||
|
|
self.shutdown_tx = Some(shutdown_tx);
|
||
|
|
self.assigned_ip = Some(config.address.clone());
|
||
|
|
|
||
|
|
Ok(config.address)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn disconnect(&mut self) -> Result<()> {
|
||
|
|
if let Some(tx) = self.shutdown_tx.take() {
|
||
|
|
let _ = tx.send(());
|
||
|
|
}
|
||
|
|
{
|
||
|
|
let mut s = self.state.write().await;
|
||
|
|
s.state = "disconnected".to_string();
|
||
|
|
s.assigned_ip = None;
|
||
|
|
s.connected_since = None;
|
||
|
|
}
|
||
|
|
self.assigned_ip = None;
|
||
|
|
info!("WireGuard client disconnected");
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Snapshot of the client connection state as JSON.
///
/// Serialization failures are not surfaced; they yield `Value::Null`.
pub async fn get_status(&self) -> serde_json::Value {
    let snapshot = self.state.read().await;
    match serde_json::to_value(&*snapshot) {
        Ok(value) => value,
        Err(_) => serde_json::Value::Null,
    }
}
|
||
|
|
|
||
|
|
/// Snapshot of the most recently published peer traffic counters as JSON.
///
/// Serialization failures are not surfaced; they yield `Value::Null`.
pub async fn get_statistics(&self) -> serde_json::Value {
    let guard = self.shared_stats.read().await;
    serde_json::to_value(&*guard).unwrap_or(serde_json::Value::Null)
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Client event loop
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
async fn wg_client_loop(
|
||
|
|
udp_socket: UdpSocket,
|
||
|
|
tun_device: tun::AsyncDevice,
|
||
|
|
mut tunn: Tunn,
|
||
|
|
endpoint: SocketAddr,
|
||
|
|
_allowed_ips: Vec<AllowedIp>,
|
||
|
|
shared_stats: Arc<RwLock<WgPeerStats>>,
|
||
|
|
_state: Arc<RwLock<WgClientState>>,
|
||
|
|
mut shutdown_rx: oneshot::Receiver<()>,
|
||
|
|
) -> Result<()> {
|
||
|
|
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
|
||
|
|
let mut tun_buf = vec![0u8; MAX_UDP_PACKET];
|
||
|
|
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
|
||
|
|
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
|
||
|
|
let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1));
|
||
|
|
|
||
|
|
let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device);
|
||
|
|
|
||
|
|
// Local stats (synced periodically)
|
||
|
|
let mut local_stats = WgPeerStats::default();
|
||
|
|
|
||
|
|
// Initiate handshake
|
||
|
|
match tunn.encapsulate(&[], &mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(packet) => {
|
||
|
|
udp_socket.send_to(packet, endpoint).await?;
|
||
|
|
debug!("Sent WireGuard handshake initiation");
|
||
|
|
}
|
||
|
|
_ => {}
|
||
|
|
}
|
||
|
|
|
||
|
|
loop {
|
||
|
|
tokio::select! {
|
||
|
|
// --- UDP receive ---
|
||
|
|
result = udp_socket.recv_from(&mut udp_buf) => {
|
||
|
|
let (n, src_addr) = result?;
|
||
|
|
if n == 0 { continue; }
|
||
|
|
|
||
|
|
match tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(packet) => {
|
||
|
|
udp_socket.send_to(packet, endpoint).await?;
|
||
|
|
// Drain loop
|
||
|
|
loop {
|
||
|
|
match tunn.decapsulate(None, &[], &mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(pkt) => {
|
||
|
|
udp_socket.send_to(pkt, endpoint).await?;
|
||
|
|
}
|
||
|
|
_ => break,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
TunnResult::WriteToTunnelV4(packet, _addr) => {
|
||
|
|
let pkt_len = packet.len() as u64;
|
||
|
|
tun_writer.write_all(packet).await?;
|
||
|
|
local_stats.bytes_received += pkt_len;
|
||
|
|
local_stats.packets_received += 1;
|
||
|
|
}
|
||
|
|
TunnResult::WriteToTunnelV6(packet, _addr) => {
|
||
|
|
let pkt_len = packet.len() as u64;
|
||
|
|
tun_writer.write_all(packet).await?;
|
||
|
|
local_stats.bytes_received += pkt_len;
|
||
|
|
local_stats.packets_received += 1;
|
||
|
|
}
|
||
|
|
TunnResult::Done => {}
|
||
|
|
TunnResult::Err(e) => {
|
||
|
|
debug!("Client decapsulate error: {:?}", e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- TUN read ---
|
||
|
|
result = tun_reader.read(&mut tun_buf) => {
|
||
|
|
let n = result?;
|
||
|
|
if n == 0 { continue; }
|
||
|
|
|
||
|
|
match tunn.encapsulate(&tun_buf[..n], &mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(packet) => {
|
||
|
|
let pkt_len = n as u64;
|
||
|
|
udp_socket.send_to(packet, endpoint).await?;
|
||
|
|
local_stats.bytes_sent += pkt_len;
|
||
|
|
local_stats.packets_sent += 1;
|
||
|
|
}
|
||
|
|
TunnResult::Err(e) => {
|
||
|
|
debug!("Client encapsulate error: {:?}", e);
|
||
|
|
}
|
||
|
|
_ => {}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- Timer tick ---
|
||
|
|
_ = timer.tick() => {
|
||
|
|
match tunn.update_timers(&mut dst_buf) {
|
||
|
|
TunnResult::WriteToNetwork(packet) => {
|
||
|
|
udp_socket.send_to(packet, endpoint).await?;
|
||
|
|
}
|
||
|
|
TunnResult::Err(e) => {
|
||
|
|
debug!("Client timer error: {:?}", e);
|
||
|
|
}
|
||
|
|
_ => {}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- Sync stats ---
|
||
|
|
_ = stats_timer.tick() => {
|
||
|
|
let mut shared = shared_stats.write().await;
|
||
|
|
*shared = local_stats.clone();
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- Shutdown ---
|
||
|
|
_ = &mut shutdown_rx => {
|
||
|
|
info!("WireGuard client shutdown signal received");
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Helpers
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
/// Current UTC time as an ISO-8601 / RFC 3339 timestamp: `YYYY-MM-DDTHH:MM:SSZ`.
///
/// Implemented with `std` only (no `chrono` dependency). The calendar date is
/// derived from the Unix day count with Howard Hinnant's `civil_from_days`
/// algorithm. A system clock set before the Unix epoch yields
/// `1970-01-01T00:00:00Z`.
fn chrono_now() -> String {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    let days = secs / 86_400;
    let secs_of_day = secs % 86_400;
    let hour = secs_of_day / 3_600;
    let minute = (secs_of_day % 3_600) / 60;
    let second = secs_of_day % 60;

    // civil_from_days: shift the epoch to 0000-03-01 so February (and its leap
    // day) falls at the end of the computational year. `days` is unsigned, so
    // every intermediate stays non-negative.
    let z = days + 719_468;
    let era = z / 146_097; // 400-year eras
    let doe = z - era * 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    let year = era * 400 + yoe + u64::from(month <= 2);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year, month, day, hour, minute, second
    )
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Tests
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed empty datagrams to `tunn` until it stops emitting network packets,
    /// discarding whatever it flushes (queued packets after a handshake).
    fn drain(tunn: &mut Tunn, buf: &mut [u8]) {
        while let TunnResult::WriteToNetwork(_) = tunn.decapsulate(None, &[], buf) {}
    }

    #[test]
    fn test_generate_wg_keypair() {
        let (pub_key, priv_key) = generate_wg_keypair();
        // Base64 of 32 bytes = 44 chars (with padding)
        assert_eq!(pub_key.len(), 44);
        assert_eq!(priv_key.len(), 44);

        // Decode and verify 32 bytes
        let pub_bytes = BASE64.decode(&pub_key).unwrap();
        let priv_bytes = BASE64.decode(&priv_key).unwrap();
        assert_eq!(pub_bytes.len(), 32);
        assert_eq!(priv_bytes.len(), 32);
    }

    #[test]
    fn test_key_roundtrip() {
        let (pub_b64, priv_b64) = generate_wg_keypair();

        // Parse back
        let secret = parse_private_key(&priv_b64).unwrap();
        let public = parse_public_key(&pub_b64).unwrap();

        // Derive public from private and verify match
        let derived_public = PublicKey::from(&secret);
        assert_eq!(public.to_bytes(), derived_public.to_bytes());
    }

    #[test]
    fn test_parse_invalid_key() {
        assert!(parse_private_key("not-valid-base64!!!").is_err());
        assert!(parse_private_key("AAAA").is_err()); // too short (3 bytes)
        assert!(parse_public_key("AAAA").is_err());
    }

    #[test]
    fn test_allowed_ip_v4_match() {
        let aip = AllowedIp::parse("10.0.0.0/24").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 254))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 1, 1))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_v4_catch_all() {
        let aip = AllowedIp::parse("0.0.0.0/0").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))));
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255))));
    }

    #[test]
    fn test_allowed_ip_v4_host() {
        let aip = AllowedIp::parse("10.0.0.5/32").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 6))));
    }

    #[test]
    fn test_allowed_ip_v6_match() {
        let aip = AllowedIp::parse("fd00::/64").unwrap();
        assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd01, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_v6_catch_all() {
        let aip = AllowedIp::parse("::/0").unwrap();
        assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_cross_family_no_match() {
        let v4 = AllowedIp::parse("10.0.0.0/8").unwrap();
        assert!(!v4.matches(IpAddr::V6(Ipv6Addr::LOCALHOST)));

        let v6 = AllowedIp::parse("::/0").unwrap();
        assert!(!v6.matches(IpAddr::V4(Ipv4Addr::LOCALHOST)));
    }

    #[test]
    fn test_extract_dst_ip_v4() {
        // Minimal IPv4 header: version=4, IHL=5, total_length=20, dst at bytes 16-19
        let mut pkt = [0u8; 20];
        pkt[0] = 0x45; // version 4, IHL 5
        pkt[16] = 10;
        pkt[17] = 0;
        pkt[18] = 0;
        pkt[19] = 1;
        assert_eq!(
            extract_dst_ip(&pkt),
            Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
        );
    }

    #[test]
    fn test_extract_dst_ip_v6() {
        // Minimal IPv6 header: version=6, dst at bytes 24-39
        let mut pkt = [0u8; 40];
        pkt[0] = 0x60; // version 6
        pkt[24] = 0xfd;
        pkt[39] = 0x01;
        let expected = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1));
        assert_eq!(extract_dst_ip(&pkt), Some(expected));
    }

    #[test]
    fn test_extract_dst_ip_empty() {
        assert_eq!(extract_dst_ip(&[]), None);
    }

    #[test]
    fn test_loopback_tunnel() {
        // Two Tunn instances (server and client) exchanging packets in memory.
        let (server_pub, server_priv) = generate_wg_keypair();
        let (client_pub, client_priv) = generate_wg_keypair();

        let server_secret = parse_private_key(&server_priv).unwrap();
        let client_secret = parse_private_key(&client_priv).unwrap();
        let server_public = parse_public_key(&server_pub).unwrap();
        let client_public = parse_public_key(&client_pub).unwrap();

        let mut server_tunn = Tunn::new(server_secret, client_public, None, None, 0, None);
        let mut client_tunn = Tunn::new(client_secret, server_public, None, None, 1, None);

        let mut buf_a = vec![0u8; 2048];
        let mut buf_b = vec![0u8; 2048];

        // Client initiates the handshake.
        let handshake_init = match client_tunn.encapsulate(&[], &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for handshake init, got {:?}",
                std::mem::discriminant(&other)
            ),
        };

        // Server processes the initiation and answers with a handshake response.
        let handshake_resp = match server_tunn.decapsulate(None, &handshake_init, &mut buf_b) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for handshake resp, got {:?}",
                std::mem::discriminant(&other)
            ),
        };
        drain(&mut server_tunn, &mut buf_b);

        // Client processes the response. Anything it emits (e.g. a keepalive)
        // is forwarded to the server so the responder sees a confirmed session.
        match client_tunn.decapsulate(None, &handshake_resp, &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => {
                let _ = server_tunn.decapsulate(None, pkt, &mut buf_b);
            }
            TunnResult::Done => {}
            other => panic!(
                "Unexpected result while processing handshake resp: {:?}",
                std::mem::discriminant(&other)
            ),
        }
        drain(&mut client_tunn, &mut buf_a);

        // Minimal 28-byte IPv4 packet: 10.0.0.2 (client) -> 10.0.0.1 (server).
        let mut fake_ip = [0u8; 28];
        fake_ip[0] = 0x45; // version 4, IHL 5
        fake_ip[3] = 28; // total length (big-endian, bytes 2-3)
        fake_ip[12..16].copy_from_slice(&[10, 0, 0, 2]); // source
        fake_ip[16..20].copy_from_slice(&[10, 0, 0, 1]); // destination

        // With the session established, this must be an encrypted transport packet.
        let encrypted = match client_tunn.encapsulate(&fake_ip, &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected encrypted transport packet, got {:?}",
                std::mem::discriminant(&other)
            ),
        };

        // The server must decrypt it back to the original IPv4 packet.
        match server_tunn.decapsulate(None, &encrypted, &mut buf_b) {
            TunnResult::WriteToTunnelV4(decrypted, src_addr) => {
                // src_addr is the source IP from the inner packet (for AllowedIPs check)
                assert_eq!(src_addr, Ipv4Addr::new(10, 0, 0, 2));
                assert_eq!(&decrypted[..fake_ip.len()], &fake_ip[..]);
            }
            other => panic!(
                "Expected WriteToTunnelV4 for transport data, got {:?}",
                std::mem::discriminant(&other)
            ),
        }
    }
}
|