Files
smartvpn/rust/src/wireguard.rs

1427 lines
54 KiB
Rust

use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use anyhow::{anyhow, Result};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use boringtun::noise::errors::WireGuardError;
use boringtun::noise::rate_limiter::RateLimiter;
use boringtun::noise::{Tunn, TunnResult};
use boringtun::x25519::{PublicKey, StaticSecret};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot, RwLock};
use tracing::{debug, error, info, warn};
use crate::server::{ClientInfo, ForwardingEngine, ServerState};
use crate::tunnel::{self, TunConfig};
// ============================================================================
// Constants
// ============================================================================
/// Largest UDP datagram read by the listener/client; also sizes the client's TUN read buffer.
const MAX_UDP_PACKET: usize = 65536;
/// Size of the scratch (`dst`) buffer handed to boringtun
/// encapsulate / decapsulate / update_timers.
const WG_BUFFER_SIZE: usize = MAX_UDP_PACKET;
/// Minimum dst buffer size for boringtun encapsulate/decapsulate
/// (appears unused — the leading underscore suppresses the dead-code lint).
const _MIN_DST_BUF: usize = 148;
/// Interval, in milliseconds, at which `Tunn::update_timers` is driven.
const TIMER_TICK_MS: u64 = 100;
/// Client TUN MTU used when `WgClientConfig::mtu` is not set.
const DEFAULT_MTU: u16 = 1420;
// ============================================================================
// Configuration types
// ============================================================================
/// Configuration of one WireGuard peer, (de)serialized with camelCase field names.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct WgPeerConfig {
    /// Peer's X25519 public key, base64-encoded; must decode to exactly 32 bytes.
    pub public_key: String,
    /// Optional base64 preshared key (32 bytes), passed to `Tunn::new`.
    #[serde(default)]
    pub preshared_key: Option<String>,
    /// AllowedIPs as CIDR strings. The server checks inner packet source addresses
    /// against them; the client installs them as routes.
    pub allowed_ips: Vec<String>,
    /// Optional initial endpoint, parsed as a `SocketAddr` ("ip:port").
    #[serde(default)]
    pub endpoint: Option<String>,
    /// Keepalive interval handed to boringtun's `Tunn::new` (presumably seconds — confirm).
    #[serde(default)]
    pub persistent_keepalive: Option<u16>,
}
/// Client-mode configuration: one local tunnel interface talking to one server peer.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WgClientConfig {
    /// Client private key, base64-encoded (32 bytes).
    pub private_key: String,
    /// Client tunnel address; must parse as an IPv4 address.
    pub address: String,
    /// Prefix length of `address`; 24 when omitted.
    #[serde(default)]
    pub address_prefix: Option<u8>,
    /// DNS servers.
    /// NOTE(review): not read by `WgClient::connect` — confirm whether a caller applies them.
    #[serde(default)]
    pub dns: Option<Vec<String>>,
    /// TUN MTU; `DEFAULT_MTU` when omitted.
    #[serde(default)]
    pub mtu: Option<u16>,
    /// The server peer. Its `endpoint` is required in client mode.
    pub peer: WgPeerConfig,
}
// ============================================================================
// Stats types
// ============================================================================
/// Traffic counters for one WireGuard peer (or the client's single peer).
///
/// Byte counts are lengths of the plaintext inner IP packets, not of the encrypted
/// UDP datagrams.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct WgPeerStats {
    pub bytes_sent: u64,
    pub bytes_received: u64,
    pub packets_sent: u64,
    pub packets_received: u64,
    /// Set by the server listener (seconds since the Unix epoch, as a string) when data
    /// first arrives from the peer after it (re)connects; the client loop never sets it.
    pub last_handshake_time: Option<String>,
}
/// Serializable description of a peer together with its traffic counters.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct WgPeerInfo {
    pub public_key: String,
    pub allowed_ips: Vec<String>,
    pub endpoint: Option<String>,
    pub persistent_keepalive: Option<u16>,
    /// Counter fields are flattened into the same serialized object as the fields above.
    #[serde(flatten)]
    pub stats: WgPeerStats,
}
// ============================================================================
// Key generation and parsing
// ============================================================================
/// Create a fresh X25519 keypair usable as WireGuard keys, drawing randomness from the OS RNG.
///
/// The result is `(public_key_base64, private_key_base64)`; each key is 32 bytes,
/// standard base64 with padding.
pub fn generate_wg_keypair() -> (String, String) {
    let secret = StaticSecret::random_from_rng(OsRng);
    let encoded_public = BASE64.encode(PublicKey::from(&secret).to_bytes());
    let encoded_private = BASE64.encode(secret.to_bytes());
    (encoded_public, encoded_private)
}
/// Compute the base64 WireGuard public key that belongs to a base64 private key.
///
/// # Errors
/// Propagates base64 decoding errors and the 32-byte length check from `parse_private_key`.
pub fn wg_public_key_from_private(private_key_b64: &str) -> Result<String> {
    parse_private_key(private_key_b64)
        .map(|secret| BASE64.encode(PublicKey::from(&secret).to_bytes()))
}
/// Base64-decode a WireGuard key and insist on exactly 32 bytes.
///
/// `kind` names the key in the length error ("Private key", "Public key", ...), so every
/// caller reports "<kind> must be 32 bytes, got N". Base64 errors propagate unchanged.
fn decode_key_32(b64: &str, kind: &str) -> Result<[u8; 32]> {
    let raw = BASE64.decode(b64)?;
    if raw.len() != 32 {
        return Err(anyhow!("{} must be 32 bytes, got {}", kind, raw.len()));
    }
    let mut key = [0u8; 32];
    key.copy_from_slice(&raw);
    Ok(key)
}

/// Parse a base64 X25519 private key.
fn parse_private_key(b64: &str) -> Result<StaticSecret> {
    decode_key_32(b64, "Private key").map(StaticSecret::from)
}

/// Parse a base64 X25519 public key.
fn parse_public_key(b64: &str) -> Result<PublicKey> {
    decode_key_32(b64, "Public key").map(PublicKey::from)
}

/// Parse a base64 32-byte preshared key into raw bytes.
fn parse_preshared_key(b64: &str) -> Result<[u8; 32]> {
    decode_key_32(b64, "Preshared key")
}
// ============================================================================
// AllowedIPs matching
// ============================================================================
/// One AllowedIPs entry: an address plus a prefix length.
///
/// The server uses it to decide which inner source addresses a peer may use and to
/// pick the peer's VPN address. The client uses it to filter inbound packets.
#[derive(Debug, Clone)]
struct AllowedIp {
    /// Address as written in the config (host bits are kept, not masked off).
    addr: IpAddr,
    /// Prefix length; at most 32 (IPv4) or 128 (IPv6) when built via `parse`.
    prefix_len: u8,
}

impl AllowedIp {
    /// Parse `"addr/prefix"` or a bare address.
    ///
    /// A bare address is a single host (`/32` for IPv4, `/128` for IPv6), which matches
    /// how `wg(8)` treats AllowedIPs entries without a mask. Surrounding whitespace is ignored.
    ///
    /// # Errors
    /// The address or prefix fails to parse, or the prefix exceeds the family's bit width.
    fn parse(cidr: &str) -> Result<Self> {
        let spec = cidr.trim();
        let (addr_text, prefix_text) = match spec.split_once('/') {
            Some((addr, prefix)) => (addr, Some(prefix)),
            None => (spec, None),
        };
        let addr = addr_text
            .parse::<IpAddr>()
            .map_err(|e| anyhow!("Invalid CIDR {}: {}", cidr, e))?;
        let (family, max_len) = if addr.is_ipv4() {
            ("IPv4", 32u8)
        } else {
            ("IPv6", 128u8)
        };
        let prefix_len = match prefix_text {
            Some(text) => text
                .parse::<u8>()
                .map_err(|e| anyhow!("Invalid CIDR {}: {}", cidr, e))?,
            None => max_len,
        };
        if prefix_len > max_len {
            return Err(anyhow!(
                "{} prefix length {} > {}",
                family,
                prefix_len,
                max_len
            ));
        }
        Ok(Self { addr, prefix_len })
    }

    /// Whether `ip` lies inside this network. An address of the other family never matches.
    fn matches(&self, ip: IpAddr) -> bool {
        match (self.addr, ip) {
            (IpAddr::V4(net), IpAddr::V4(target)) => {
                // checked_shl yields None for a 32-bit shift (prefix 0), which gives an empty
                // mask, so everything matches. Oversized prefixes saturate to an exact match.
                let host_bits = 32u32.saturating_sub(u32::from(self.prefix_len));
                let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
                ((u32::from(net) ^ u32::from(target)) & mask) == 0
            }
            (IpAddr::V6(net), IpAddr::V6(target)) => {
                let host_bits = 128u32.saturating_sub(u32::from(self.prefix_len));
                let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);
                ((u128::from(net) ^ u128::from(target)) & mask) == 0
            }
            _ => false,
        }
    }
}
// ============================================================================
// Dynamic peer management commands
// ============================================================================
/// Commands sent to a running `run_wg_listener` loop to change its peer set.
/// Each carries a oneshot sender that receives the outcome.
pub enum WgCommand {
    /// Add a peer; replies `Err` for a duplicate public key or an invalid config.
    AddPeer(WgPeerConfig, oneshot::Sender<Result<()>>),
    /// Remove the peer with this base64 public key; replies `Err` if it is not present.
    RemovePeer(String, oneshot::Sender<Result<()>>),
}
// ============================================================================
// Internal peer state (owned by event loop)
// ============================================================================
/// Per-peer state owned by the listener event loop.
struct PeerState {
    /// boringtun tunnel state for this peer.
    tunn: Tunn,
    /// Peer public key (base64) as given in its config; used as the peer's identity.
    public_key_b64: String,
    /// Parsed AllowedIPs; inbound inner packets are forwarded only if their source matches.
    allowed_ips: Vec<AllowedIp>,
    /// UDP destination for this peer. It starts from the configured endpoint and is
    /// updated from the source of packets this peer's tunnel accepts.
    endpoint: Option<SocketAddr>,
    /// Stored for reference only (the interval was already passed to `Tunn::new`),
    /// hence the dead-code allowance.
    #[allow(dead_code)]
    persistent_keepalive: Option<u16>,
    /// Cumulative traffic counters for this peer.
    stats: WgPeerStats,
    /// Whether this peer has completed a WireGuard handshake and is in state.clients.
    is_connected: bool,
    /// Last time we received data or handshake activity from this peer.
    last_activity_at: Option<tokio::time::Instant>,
    /// VPN IP assigned during registration (used for connect/disconnect).
    vpn_ip: Option<Ipv4Addr>,
    /// Previous synced byte counts for aggregate stats delta tracking.
    prev_synced_bytes_sent: u64,
    prev_synced_bytes_received: u64,
}
impl PeerState {
    /// True if `ip` is covered by any of this peer's AllowedIPs entries.
    fn matches_allowed_ips(&self, ip: IpAddr) -> bool {
        self.allowed_ips.iter().any(|aip| aip.matches(ip))
    }
}
fn add_peer_to_loop(
peers: &mut Vec<PeerState>,
config: &WgPeerConfig,
peer_index: &AtomicU32,
rate_limiter: &Arc<RateLimiter>,
server_private_key_b64: &str,
) -> Result<()> {
// Check for duplicate
if peers.iter().any(|p| p.public_key_b64 == config.public_key) {
return Err(anyhow!("Peer already exists: {}", config.public_key));
}
let peer_public = parse_public_key(&config.public_key)?;
let psk = match &config.preshared_key {
Some(k) => Some(parse_preshared_key(k)?),
None => None,
};
let idx = peer_index.fetch_add(1, Ordering::Relaxed);
let priv_copy = parse_private_key(server_private_key_b64)?;
let tunn = Tunn::new(
priv_copy,
peer_public,
psk,
config.persistent_keepalive,
idx,
Some(rate_limiter.clone()),
);
let allowed_ips: Vec<AllowedIp> = config
.allowed_ips
.iter()
.map(|cidr| AllowedIp::parse(cidr))
.collect::<Result<Vec<_>>>()?;
let endpoint = match &config.endpoint {
Some(ep) => Some(ep.parse::<SocketAddr>()?),
None => None,
};
peers.push(PeerState {
tunn,
public_key_b64: config.public_key.clone(),
allowed_ips,
endpoint,
persistent_keepalive: config.persistent_keepalive,
stats: WgPeerStats::default(),
is_connected: false,
last_activity_at: None,
vpn_ip: None,
prev_synced_bytes_sent: 0,
prev_synced_bytes_received: 0,
});
info!("Added WireGuard peer: {}", config.public_key);
Ok(())
}
// ============================================================================
// Integrated WG listener (shares ServerState with WS/QUIC)
// ============================================================================
/// Configuration for the integrated WireGuard listener.
#[derive(Debug, Clone)]
pub struct WgListenerConfig {
    /// Server private key, base64-encoded (32 bytes).
    pub private_key: String,
    /// UDP port the listener binds on 0.0.0.0.
    pub listen_port: u16,
    /// Peers set up at startup; more can be added at runtime via `WgCommand::AddPeer`.
    pub peers: Vec<WgPeerConfig>,
}
/// Pick the peer's VPN address from its AllowedIPs.
///
/// The first IPv4 `/32` entry wins. Otherwise the first IPv4 entry whose address is not
/// 0.0.0.0 is used, whatever its prefix length. NOTE(review): for an entry like
/// 10.8.0.0/24 that fallback yields the network address itself — confirm this is intended.
/// IPv6 entries are never chosen.
fn extract_peer_vpn_ip(allowed_ips: &[AllowedIp]) -> Option<Ipv4Addr> {
    let ipv4_entries = || {
        allowed_ips.iter().filter_map(|aip| match aip.addr {
            IpAddr::V4(v4) => Some((v4, aip.prefix_len)),
            IpAddr::V6(_) => None,
        })
    };

    ipv4_entries()
        .find_map(|(v4, len)| if len == 32 { Some(v4) } else { None })
        .or_else(|| {
            ipv4_entries()
                .map(|(v4, _)| v4)
                .find(|v4| !v4.is_unspecified())
        })
}
/// Timestamp helper (mirrors server.rs timestamp_now).
///
/// Returns whole seconds since the Unix epoch as a decimal string, or "0" if the
/// system clock reads earlier than the epoch.
fn wg_timestamp_now() -> String {
    let since_epoch = std::time::SystemTime::UNIX_EPOCH
        .elapsed()
        .unwrap_or_default();
    since_epoch.as_secs().to_string()
}
/// Register a WG peer in ServerState (tun_routes + ip_pool only).
/// Does NOT add to state.clients — peers appear there only after handshake.
///
/// Reserves the peer's VPN IP under client id `"wg-"` + first 8 bytes of its public key.
/// When the forwarding mode is "tun" or "socket", it also installs a per-peer return
/// channel in `tun_routes` and spawns a relay task. That task tags each returned packet
/// with the peer's public key and forwards it into `wg_return_tx`.
///
/// Returns `Ok(Some(ip))` on success and `Ok(None)` when no usable IPv4 is found or the
/// reservation fails (both are logged).
async fn register_wg_peer(
    state: &Arc<ServerState>,
    peer: &PeerState,
    wg_return_tx: &mpsc::Sender<(String, Vec<u8>)>,
) -> Result<Option<Ipv4Addr>> {
    let key = &peer.public_key_b64;
    let vpn_ip = if let Some(ip) = extract_peer_vpn_ip(&peer.allowed_ips) {
        ip
    } else {
        warn!("WG peer {} has no /32 IPv4 in allowed_ips, skipping registration",
            peer.public_key_b64);
        return Ok(None);
    };

    let client_id = format!("wg-{}", &key[..key.len().min(8)]);
    if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) {
        warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e);
        return Ok(None);
    }

    // NOTE(review): only "tun" and "socket" install a return route; confirm that other
    // forwarding modes (if any, e.g. bridge/hybrid) are meant to be excluded.
    let forwarding_active = matches!(
        state.config.forwarding_mode.as_deref().unwrap_or("testing"),
        "tun" | "socket"
    );
    if forwarding_active {
        let (peer_return_tx, mut peer_return_rx) = mpsc::channel::<Vec<u8>>(256);
        state.tun_routes.write().await.insert(vpn_ip, peer_return_tx);

        let relay_tx = wg_return_tx.clone();
        let tag = key.clone();
        tokio::spawn(async move {
            // Ends when the route entry (the sender) is dropped or the merged channel closes.
            while let Some(packet) = peer_return_rx.recv().await {
                let tagged = (tag.clone(), packet);
                if relay_tx.send(tagged).await.is_err() {
                    break;
                }
            }
        });
    }

    info!("WG peer {} registered with IP {} (not yet connected)", peer.public_key_b64, vpn_ip);
    Ok(Some(vpn_ip))
}
/// Add a WG peer to state.clients on first successful handshake (data received).
///
/// Inserts a `ClientInfo` (transport "wireguard") keyed by `"wg-"` + first 8 bytes of the
/// public key, seeded with the peer's current byte counters. It then bumps the total and
/// WireGuard-specific connection counters.
async fn connect_wg_peer(
    state: &Arc<ServerState>,
    peer: &PeerState,
    vpn_ip: Ipv4Addr,
) {
    let key = &peer.public_key_b64;
    let client_id = format!("wg-{}", &key[..key.len().min(8)]);
    let client_info = ClientInfo {
        client_id: client_id.clone(),
        assigned_ip: vpn_ip.to_string(),
        connected_since: wg_timestamp_now(),
        bytes_sent: peer.stats.bytes_sent,
        bytes_received: peer.stats.bytes_received,
        packets_dropped: 0,
        bytes_dropped: 0,
        last_keepalive_at: None,
        keepalives_received: 0,
        rate_limit_bytes_per_sec: None,
        burst_bytes: None,
        authenticated_key: key.clone(),
        registered_client_id: client_id.clone(),
        remote_addr: peer.endpoint.map(|e| e.to_string()),
        transport_type: "wireguard".to_string(),
    };
    state.clients.write().await.insert(client_id, client_info);

    let mut stats = state.stats.write().await;
    stats.total_connections += 1;
    stats.total_connections_wireguard += 1;
    drop(stats);

    info!("WG peer {} connected (IP: {})", peer.public_key_b64, vpn_ip);
}
/// Remove a WG peer from state.clients (disconnect without unregistering).
///
/// The IP reservation and return route are left in place, so the peer can reconnect
/// without re-registration. Logs only if an entry was actually removed.
async fn disconnect_wg_peer(
    state: &Arc<ServerState>,
    pubkey: &str,
) {
    let short_len = pubkey.len().min(8);
    let client_id = format!("wg-{}", &pubkey[..short_len]);
    let removed = state.clients.write().await.remove(&client_id);
    if removed.is_some() {
        info!("WG peer {} disconnected (removed from active clients)", pubkey);
    }
}
/// Unregister a WG peer from ServerState.
///
/// When `vpn_ip` is `Some`, the return route is removed before the IP is released to the
/// pool. Afterwards any `state.clients` entry and per-client rate limiter are dropped.
/// Both are keyed by the derived client id (`"wg-"` + first 8 bytes of the public key).
async fn unregister_wg_peer(
    state: &Arc<ServerState>,
    pubkey: &str,
    vpn_ip: Option<Ipv4Addr>,
) {
    let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
    if let Some(ip) = vpn_ip {
        // Route first, then pool: the address is never handed out while still routed here.
        state.tun_routes.write().await.remove(&ip);
        state.ip_pool.lock().await.release(&ip);
    }
    state.clients.write().await.remove(&client_id);
    state.rate_limiters.lock().await.remove(&client_id);
}
/// Integrated WireGuard listener that shares ServerState with WS/QUIC listeners.
/// Uses the shared ForwardingEngine for packet routing instead of its own TUN device.
pub async fn run_wg_listener(
state: Arc<ServerState>,
config: WgListenerConfig,
mut shutdown_rx: mpsc::Receiver<()>,
mut command_rx: mpsc::Receiver<WgCommand>,
) -> Result<()> {
// Parse server private key
let server_private = parse_private_key(&config.private_key)?;
let server_public = PublicKey::from(&server_private);
// Create rate limiter for DDoS protection
let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64));
// Build initial peer state
let peer_index = AtomicU32::new(0);
let mut peers: Vec<PeerState> = Vec::with_capacity(config.peers.len());
for peer_config in &config.peers {
let peer_public = parse_public_key(&peer_config.public_key)?;
let psk = match &peer_config.preshared_key {
Some(k) => Some(parse_preshared_key(k)?),
None => None,
};
let idx = peer_index.fetch_add(1, Ordering::Relaxed);
let priv_copy = parse_private_key(&config.private_key)?;
let tunn = Tunn::new(
priv_copy,
peer_public,
psk,
peer_config.persistent_keepalive,
idx,
Some(rate_limiter.clone()),
);
let allowed_ips: Vec<AllowedIp> = peer_config
.allowed_ips
.iter()
.map(|cidr| AllowedIp::parse(cidr))
.collect::<Result<Vec<_>>>()?;
let endpoint = match &peer_config.endpoint {
Some(ep) => Some(ep.parse::<SocketAddr>()?),
None => None,
};
peers.push(PeerState {
tunn,
public_key_b64: peer_config.public_key.clone(),
allowed_ips,
endpoint,
persistent_keepalive: peer_config.persistent_keepalive,
stats: WgPeerStats::default(),
is_connected: false,
last_activity_at: None,
vpn_ip: None,
prev_synced_bytes_sent: 0,
prev_synced_bytes_received: 0,
});
}
// Bind UDP socket
let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", config.listen_port)).await?;
info!("WireGuard listener started on UDP port {}", config.listen_port);
// Merged return-packet channel: all per-peer channels feed into this
let (wg_return_tx, mut wg_return_rx) = mpsc::channel::<(String, Vec<u8>)>(1024);
// Register initial peers in ServerState (IP reservation + tun_routes only, NOT state.clients)
let mut peer_vpn_ips: HashMap<String, Ipv4Addr> = HashMap::new();
for peer in peers.iter_mut() {
if let Ok(Some(ip)) = register_wg_peer(&state, peer, &wg_return_tx).await {
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
peer.vpn_ip = Some(ip);
}
}
// Buffers
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1));
let mut idle_check_timer = tokio::time::interval(std::time::Duration::from_secs(10));
loop {
tokio::select! {
// --- UDP receive → decapsulate → ForwardingEngine ---
result = udp_socket.recv_from(&mut udp_buf) => {
let (n, src_addr) = result?;
if n == 0 { continue; }
let mut handled = false;
for peer in peers.iter_mut() {
match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, src_addr).await?;
loop {
match peer.tunn.decapsulate(None, &[], &mut dst_buf) {
TunnResult::WriteToNetwork(pkt) => {
let ep = peer.endpoint.unwrap_or(src_addr);
udp_socket.send_to(pkt, ep).await?;
}
_ => break,
}
}
peer.endpoint = Some(src_addr);
// Handshake response counts as activity
peer.last_activity_at = Some(tokio::time::Instant::now());
handled = true;
break;
}
TunnResult::WriteToTunnelV4(packet, addr) => {
if peer.matches_allowed_ips(IpAddr::V4(addr)) {
let pkt_len = packet.len() as u64;
// Forward via shared forwarding engine
let mut engine = state.forwarding_engine.lock().await;
match &mut *engine {
ForwardingEngine::Tun(writer) => {
use tokio::io::AsyncWriteExt;
if let Err(e) = writer.write_all(packet).await {
warn!("TUN write error for WG peer: {}", e);
}
}
ForwardingEngine::Socket(sender) => {
let _ = sender.try_send(packet.to_vec());
}
ForwardingEngine::Bridge(sender) => {
let _ = sender.try_send(packet.to_vec());
}
ForwardingEngine::Hybrid { socket_tx, bridge_tx, routing_table } => {
if packet.len() >= 20 {
let src_ip = Ipv4Addr::new(packet[12], packet[13], packet[14], packet[15]);
let use_bridge = routing_table.read().await.get(&src_ip).copied().unwrap_or(false);
if use_bridge {
let _ = bridge_tx.try_send(packet.to_vec());
} else {
let _ = socket_tx.try_send(packet.to_vec());
}
}
}
ForwardingEngine::Testing => {}
}
peer.stats.bytes_received += pkt_len;
peer.stats.packets_received += 1;
}
peer.endpoint = Some(src_addr);
// Track activity and detect handshake completion
peer.last_activity_at = Some(tokio::time::Instant::now());
if !peer.is_connected {
peer.is_connected = true;
peer.stats.last_handshake_time = Some(wg_timestamp_now());
if let Some(vpn_ip) = peer.vpn_ip {
connect_wg_peer(&state, peer, vpn_ip).await;
}
}
handled = true;
break;
}
TunnResult::WriteToTunnelV6(packet, addr) => {
if peer.matches_allowed_ips(IpAddr::V6(addr)) {
let pkt_len = packet.len() as u64;
let mut engine = state.forwarding_engine.lock().await;
match &mut *engine {
ForwardingEngine::Tun(writer) => {
use tokio::io::AsyncWriteExt;
if let Err(e) = writer.write_all(packet).await {
warn!("TUN write error for WG peer: {}", e);
}
}
ForwardingEngine::Socket(sender) => {
let _ = sender.try_send(packet.to_vec());
}
ForwardingEngine::Bridge(sender) => {
let _ = sender.try_send(packet.to_vec());
}
ForwardingEngine::Hybrid { socket_tx, bridge_tx, routing_table } => {
if packet.len() >= 20 {
let src_ip = Ipv4Addr::new(packet[12], packet[13], packet[14], packet[15]);
let use_bridge = routing_table.read().await.get(&src_ip).copied().unwrap_or(false);
if use_bridge {
let _ = bridge_tx.try_send(packet.to_vec());
} else {
let _ = socket_tx.try_send(packet.to_vec());
}
}
}
ForwardingEngine::Testing => {}
}
peer.stats.bytes_received += pkt_len;
peer.stats.packets_received += 1;
}
peer.endpoint = Some(src_addr);
// Track activity and detect handshake completion
peer.last_activity_at = Some(tokio::time::Instant::now());
if !peer.is_connected {
peer.is_connected = true;
peer.stats.last_handshake_time = Some(wg_timestamp_now());
if let Some(vpn_ip) = peer.vpn_ip {
connect_wg_peer(&state, peer, vpn_ip).await;
}
}
handled = true;
break;
}
TunnResult::Done => { continue; }
TunnResult::Err(e) => {
debug!("decapsulate error from {}: {:?}", src_addr, e);
continue;
}
}
}
if !handled {
debug!("No WG peer matched UDP packet from {}", src_addr);
}
}
// --- Return packets from tun_routes → encapsulate → UDP ---
Some((pubkey, packet)) = wg_return_rx.recv() => {
if let Some(peer) = peers.iter_mut().find(|p| p.public_key_b64 == pubkey) {
match peer.tunn.encapsulate(&packet, &mut dst_buf) {
TunnResult::WriteToNetwork(out) => {
if let Some(endpoint) = peer.endpoint {
let pkt_len = packet.len() as u64;
udp_socket.send_to(out, endpoint).await?;
peer.stats.bytes_sent += pkt_len;
peer.stats.packets_sent += 1;
} else {
debug!("No endpoint for WG peer {}, dropping return packet",
peer.public_key_b64);
}
}
TunnResult::Err(e) => {
debug!("encapsulate error for WG peer {}: {:?}",
peer.public_key_b64, e);
}
_ => {}
}
}
}
// --- WireGuard protocol timers (100ms) ---
_ = timer.tick() => {
for peer in peers.iter_mut() {
match peer.tunn.update_timers(&mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
if let Some(endpoint) = peer.endpoint {
udp_socket.send_to(packet, endpoint).await?;
}
}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
warn!("WG peer {} connection expired", peer.public_key_b64);
if peer.is_connected {
peer.is_connected = false;
disconnect_wg_peer(&state, &peer.public_key_b64).await;
}
}
TunnResult::Err(e) => {
debug!("Timer error for WG peer {}: {:?}",
peer.public_key_b64, e);
}
_ => {}
}
}
}
// --- Sync stats to ServerState (every 1s) ---
_ = stats_timer.tick() => {
let mut clients = state.clients.write().await;
let mut stats = state.stats.write().await;
for peer in peers.iter_mut() {
// Always update aggregate stats (regardless of connection state)
let delta_sent = peer.stats.bytes_sent.saturating_sub(peer.prev_synced_bytes_sent);
let delta_recv = peer.stats.bytes_received.saturating_sub(peer.prev_synced_bytes_received);
if delta_sent > 0 || delta_recv > 0 {
stats.bytes_sent += delta_sent;
stats.bytes_received += delta_recv;
peer.prev_synced_bytes_sent = peer.stats.bytes_sent;
peer.prev_synced_bytes_received = peer.stats.bytes_received;
}
// Only update ClientInfo if peer is connected (in state.clients)
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
if let Some(info) = clients.get_mut(&client_id) {
info.bytes_sent = peer.stats.bytes_sent;
info.bytes_received = peer.stats.bytes_received;
info.remote_addr = peer.endpoint.map(|e| e.to_string());
}
}
}
// --- Idle timeout check (every 10s) ---
_ = idle_check_timer.tick() => {
let now = tokio::time::Instant::now();
for peer in peers.iter_mut() {
if peer.is_connected {
if let Some(last) = peer.last_activity_at {
if now.duration_since(last) > std::time::Duration::from_secs(180) {
info!("WG peer {} idle timeout (180s), disconnecting", peer.public_key_b64);
peer.is_connected = false;
disconnect_wg_peer(&state, &peer.public_key_b64).await;
}
}
}
}
}
// --- Dynamic peer commands ---
cmd = command_rx.recv() => {
match cmd {
Some(WgCommand::AddPeer(peer_config, resp_tx)) => {
let result = add_peer_to_loop(
&mut peers,
&peer_config,
&peer_index,
&rate_limiter,
&config.private_key,
);
if result.is_ok() {
// Register new peer in ServerState (IP + tun_routes only)
let peer = peers.last_mut().unwrap();
match register_wg_peer(&state, peer, &wg_return_tx).await {
Ok(Some(ip)) => {
peer_vpn_ips.insert(peer_config.public_key.clone(), ip);
peer.vpn_ip = Some(ip);
}
Ok(None) => {}
Err(e) => {
warn!("Failed to register WG peer: {}", e);
}
}
}
let _ = resp_tx.send(result);
}
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
let prev_len = peers.len();
peers.retain(|p| p.public_key_b64 != pubkey);
if peers.len() < prev_len {
let vpn_ip = peer_vpn_ips.remove(&pubkey);
unregister_wg_peer(&state, &pubkey, vpn_ip).await;
let _ = resp_tx.send(Ok(()));
} else {
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
}
}
None => {
info!("WG command channel closed");
break;
}
}
}
// --- Shutdown ---
_ = shutdown_rx.recv() => {
info!("WireGuard listener shutdown signal received");
break;
}
}
}
// Cleanup: unregister all peers from ServerState
for peer in &peers {
let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied();
unregister_wg_peer(&state, &peer.public_key_b64, vpn_ip).await;
}
info!("WireGuard listener stopped");
Ok(())
}
// ============================================================================
// WgClient
// ============================================================================
/// Handle to a single-peer WireGuard client whose packet loop runs in a spawned task.
pub struct WgClient {
    /// Stops the running loop; `Some` while a tunnel task is considered running.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Traffic counters, copied from the loop's local counters on its 1 s stats tick.
    shared_stats: Arc<RwLock<WgPeerStats>>,
    /// Status snapshot returned (serialized) by `get_status`.
    state: Arc<RwLock<WgClientState>>,
    /// Tunnel address of the current connection, if any.
    assigned_ip: Option<String>,
}
/// Serializable client status; `None` fields are omitted from the output.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct WgClientState {
    /// Status string; this file writes "disconnected", "connecting", "handshaking",
    /// "connected" and "error".
    state: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    assigned_ip: Option<String>,
    /// Set via `chrono_now()` when the loop first considers the tunnel active.
    #[serde(skip_serializing_if = "Option::is_none")]
    connected_since: Option<String>,
    /// Display text of the error that ended the loop, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    last_error: Option<String>,
}
impl WgClient {
    /// Create a client in the "disconnected" state with zeroed statistics.
    pub fn new() -> Self {
        Self {
            shutdown_tx: None,
            shared_stats: Arc::new(RwLock::new(WgPeerStats::default())),
            state: Arc::new(RwLock::new(WgClientState {
                state: "disconnected".to_string(),
                assigned_ip: None,
                connected_since: None,
                last_error: None,
            })),
            assigned_ip: None,
        }
    }

    /// Whether a tunnel task has been started and `disconnect` has not been called since.
    pub fn is_running(&self) -> bool {
        self.shutdown_tx.is_some()
    }

    /// Bring up the tunnel described by `config` and return the client tunnel address.
    ///
    /// Status moves to "connecting". Once the TUN device and UDP socket exist it moves
    /// to "handshaking", and the background loop later sets "connected". If setup fails,
    /// status becomes "error" with `last_error` set, so it never stays at "connecting".
    ///
    /// # Errors
    /// Returns an error if:
    /// - the client is already running;
    /// - the address, prefix, keys, endpoint or AllowedIPs are invalid;
    /// - creating the TUN device or binding the UDP socket fails.
    pub async fn connect(&mut self, config: WgClientConfig) -> Result<String> {
        if self.is_running() {
            return Err(anyhow!("WireGuard client is already connected"));
        }
        {
            let mut state = self.state.write().await;
            state.state = "connecting".to_string();
            state.last_error = None;
        }
        match self.start_tunnel(config).await {
            Ok(address) => Ok(address),
            Err(e) => {
                let mut s = self.state.write().await;
                s.state = "error".to_string();
                s.last_error = Some(format!("{}", e));
                Err(e)
            }
        }
    }

    /// Parse the config, create the TUN device, routes and socket, then spawn the loop.
    async fn start_tunnel(&mut self, config: WgClientConfig) -> Result<String> {
        let mtu = config.mtu.unwrap_or(DEFAULT_MTU);
        let prefix = config.address_prefix.unwrap_or(24);
        if prefix > 32 {
            return Err(anyhow!("Invalid address prefix /{}: must be at most 32", prefix));
        }
        // Honour `address_prefix` (default /24). checked_shl yields None for /0 (shift by 32).
        let netmask = Ipv4Addr::from(u32::MAX.checked_shl(32 - u32::from(prefix)).unwrap_or(0));
        let address: Ipv4Addr = config.address.parse()?;

        // Parse keys
        let client_private = parse_private_key(&config.private_key)?;
        let peer_public = parse_public_key(&config.peer.public_key)?;
        let psk = match &config.peer.preshared_key {
            Some(k) => Some(parse_preshared_key(k)?),
            None => None,
        };
        let tunn = Tunn::new(
            client_private,
            peer_public,
            psk,
            config.peer.persistent_keepalive,
            0, // single peer, index 0
            None,
        );

        // Parse server endpoint
        let endpoint: SocketAddr = config
            .peer
            .endpoint
            .as_ref()
            .ok_or_else(|| anyhow!("Peer endpoint is required for client mode"))?
            .parse()?;

        // Parse AllowedIPs (enforced on inbound packets by the client loop)
        let allowed_ips: Vec<AllowedIp> = config
            .peer
            .allowed_ips
            .iter()
            .map(|cidr| AllowedIp::parse(cidr))
            .collect::<Result<Vec<_>>>()?;

        // Create TUN device
        let tun_config = TunConfig {
            name: "wg-client0".to_string(),
            address,
            netmask,
            mtu,
        };
        let tun_device = tunnel::create_tun(&tun_config)?;
        info!("WireGuard client TUN device created: {}", tun_config.name);

        // Add routes for AllowedIPs (best-effort)
        for cidr in &config.peer.allowed_ips {
            if let Err(e) = tunnel::add_route(cidr, &tun_config.name).await {
                warn!("Failed to add route for {}: {}", cidr, e);
            }
        }

        // Bind ephemeral UDP socket
        let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
        info!(
            "WireGuard client bound to {}",
            udp_socket.local_addr()?
        );

        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let shared_stats = self.shared_stats.clone();
        let state = self.state.clone();

        // Update state — handshake hasn't completed yet
        {
            let mut s = state.write().await;
            s.state = "handshaking".to_string();
            s.assigned_ip = Some(config.address.clone());
            s.connected_since = None;
        }

        // Spawn client loop; a loop error is surfaced through the status.
        tokio::spawn(async move {
            if let Err(e) = wg_client_loop(
                udp_socket,
                tun_device,
                tunn,
                endpoint,
                allowed_ips,
                shared_stats,
                state.clone(),
                shutdown_rx,
            )
            .await
            {
                error!("WireGuard client loop error: {}", e);
                let mut s = state.write().await;
                s.state = "error".to_string();
                s.last_error = Some(format!("{}", e));
            }
        });

        self.shutdown_tx = Some(shutdown_tx);
        self.assigned_ip = Some(config.address.clone());
        Ok(config.address)
    }

    /// Signal the loop to stop (if running) and reset the status to "disconnected".
    pub async fn disconnect(&mut self) -> Result<()> {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
        {
            let mut s = self.state.write().await;
            s.state = "disconnected".to_string();
            s.assigned_ip = None;
            s.connected_since = None;
        }
        self.assigned_ip = None;
        info!("WireGuard client disconnected");
        Ok(())
    }

    /// Current status as JSON (`Value::Null` if serialization fails).
    pub async fn get_status(&self) -> serde_json::Value {
        let s = self.state.read().await;
        serde_json::to_value(&*s).unwrap_or_default()
    }

    /// Last synced traffic counters as JSON (`Value::Null` if serialization fails).
    pub async fn get_statistics(&self) -> serde_json::Value {
        let stats = self.shared_stats.read().await;
        serde_json::to_value(&*stats).unwrap_or_default()
    }
}
// ============================================================================
// Client event loop
// ============================================================================
async fn wg_client_loop(
udp_socket: UdpSocket,
tun_device: tun::AsyncDevice,
mut tunn: Tunn,
endpoint: SocketAddr,
_allowed_ips: Vec<AllowedIp>,
shared_stats: Arc<RwLock<WgPeerStats>>,
state: Arc<RwLock<WgClientState>>,
mut shutdown_rx: oneshot::Receiver<()>,
) -> Result<()> {
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
let mut tun_buf = vec![0u8; MAX_UDP_PACKET];
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1));
let mut handshake_complete = false;
let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device);
// Local stats (synced periodically)
let mut local_stats = WgPeerStats::default();
// Initiate handshake
match tunn.encapsulate(&[], &mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, endpoint).await?;
debug!("Sent WireGuard handshake initiation");
}
_ => {}
}
loop {
tokio::select! {
// --- UDP receive ---
result = udp_socket.recv_from(&mut udp_buf) => {
let (n, src_addr) = result?;
if n == 0 { continue; }
match tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, endpoint).await?;
// Drain loop
loop {
match tunn.decapsulate(None, &[], &mut dst_buf) {
TunnResult::WriteToNetwork(pkt) => {
udp_socket.send_to(pkt, endpoint).await?;
}
_ => break,
}
}
}
TunnResult::WriteToTunnelV4(packet, _addr) => {
let pkt_len = packet.len() as u64;
tun_writer.write_all(packet).await?;
local_stats.bytes_received += pkt_len;
local_stats.packets_received += 1;
if !handshake_complete {
handshake_complete = true;
let mut s = state.write().await;
s.state = "connected".to_string();
s.connected_since = Some(chrono_now());
info!("WireGuard handshake completed, tunnel active");
}
}
TunnResult::WriteToTunnelV6(packet, _addr) => {
let pkt_len = packet.len() as u64;
tun_writer.write_all(packet).await?;
local_stats.bytes_received += pkt_len;
local_stats.packets_received += 1;
if !handshake_complete {
handshake_complete = true;
let mut s = state.write().await;
s.state = "connected".to_string();
s.connected_since = Some(chrono_now());
info!("WireGuard handshake completed, tunnel active");
}
}
TunnResult::Done => {}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
warn!("WireGuard session expired during decapsulate, re-initiating handshake");
match tunn.format_handshake_initiation(&mut dst_buf, true) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, endpoint).await?;
}
_ => {}
}
}
TunnResult::Err(e) => {
debug!("Client decapsulate error: {:?}", e);
}
}
}
// --- TUN read ---
result = tun_reader.read(&mut tun_buf) => {
let n = result?;
if n == 0 { continue; }
match tunn.encapsulate(&tun_buf[..n], &mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
let pkt_len = n as u64;
udp_socket.send_to(packet, endpoint).await?;
local_stats.bytes_sent += pkt_len;
local_stats.packets_sent += 1;
}
TunnResult::Err(e) => {
debug!("Client encapsulate error: {:?}", e);
}
_ => {}
}
}
// --- Timer tick ---
_ = timer.tick() => {
match tunn.update_timers(&mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, endpoint).await?;
}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
warn!("WireGuard connection expired, re-initiating handshake");
match tunn.format_handshake_initiation(&mut dst_buf, true) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, endpoint).await?;
debug!("Sent handshake re-initiation after expiry");
}
TunnResult::Err(e) => {
warn!("Failed to re-initiate handshake: {:?}", e);
}
_ => {}
}
}
TunnResult::Err(e) => {
debug!("Client timer error: {:?}", e);
}
_ => {}
}
}
// --- Sync stats ---
_ = stats_timer.tick() => {
let mut shared = shared_stats.write().await;
*shared = local_stats.clone();
}
// --- Shutdown ---
_ = &mut shutdown_rx => {
info!("WireGuard client shutdown signal received");
break;
}
}
}
Ok(())
}
// ============================================================================
// Helpers
// ============================================================================
/// Returns the current wall-clock time as an ISO-8601 / RFC 3339 UTC
/// timestamp with second precision, e.g. `2023-11-14T22:13:20Z`.
///
/// Implemented with `std` only so the crate does not need a `chrono`
/// dependency. A system clock set before the Unix epoch is clamped to the
/// epoch (`1970-01-01T00:00:00Z`) instead of panicking.
fn chrono_now() -> String {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or_default();
    iso8601_utc_from_unix_secs(secs)
}

/// Formats `secs` seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC).
fn iso8601_utc_from_unix_secs(secs: u64) -> String {
    const SECS_PER_DAY: u64 = 86_400;
    let (year, month, day) = civil_date_from_unix_days(secs / SECS_PER_DAY);
    let secs_of_day = secs % SECS_PER_DAY;
    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year,
        month,
        day,
        secs_of_day / 3_600,
        (secs_of_day % 3_600) / 60,
        secs_of_day % 60
    )
}

/// Converts a count of days since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple (month and day are 1-based).
///
/// This is Howard Hinnant's `civil_from_days` algorithm, specialised to
/// non-negative inputs. Years are shifted to start on 1 March so the leap
/// day falls at the end of the year. Time is then split into 400-year eras
/// of 146 097 days each.
fn civil_date_from_unix_days(days: u64) -> (u64, u64, u64) {
    // Move the origin from 1970-01-01 to 0000-03-01.
    let z = days + 719_468;
    let era = z / 146_097;
    let doe = z - era * 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    // January and February belong to the following civil year.
    let year = yoe + era * 400 + u64::from(month <= 2);
    (year, month, day)
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tunnel::extract_dst_ip;
    use std::net::Ipv6Addr;

    /// Flushes packets a `Tunn` has queued internally.
    ///
    /// boringtun treats `decapsulate` with an empty datagram as "emit the next
    /// queued packet". We keep calling it until it stops producing network
    /// output, discarding whatever it emits.
    fn drain_queued(tunn: &mut Tunn, buf: &mut [u8]) {
        while let TunnResult::WriteToNetwork(_) = tunn.decapsulate(None, &[], buf) {}
    }

    #[test]
    fn test_generate_wg_keypair() {
        let (pub_key, priv_key) = generate_wg_keypair();
        // Base64 of 32 bytes = 44 chars (with padding)
        assert_eq!(pub_key.len(), 44);
        assert_eq!(priv_key.len(), 44);
        // Decode and verify 32 bytes
        let pub_bytes = BASE64.decode(&pub_key).unwrap();
        let priv_bytes = BASE64.decode(&priv_key).unwrap();
        assert_eq!(pub_bytes.len(), 32);
        assert_eq!(priv_bytes.len(), 32);
    }

    #[test]
    fn test_key_roundtrip() {
        let (pub_b64, priv_b64) = generate_wg_keypair();
        let secret = parse_private_key(&priv_b64).unwrap();
        let public = parse_public_key(&pub_b64).unwrap();
        // The public key derived from the parsed private key must match.
        let derived_public = PublicKey::from(&secret);
        assert_eq!(public.to_bytes(), derived_public.to_bytes());
    }

    #[test]
    fn test_wg_public_key_from_private() {
        let (pub_b64, priv_b64) = generate_wg_keypair();
        let derived = wg_public_key_from_private(&priv_b64).unwrap();
        assert_eq!(derived, pub_b64);
    }

    #[test]
    fn test_wg_public_key_from_private_invalid() {
        assert!(wg_public_key_from_private("not-valid").is_err());
        assert!(wg_public_key_from_private("AAAA").is_err());
    }

    #[test]
    fn test_parse_invalid_key() {
        assert!(parse_private_key("not-valid-base64!!!").is_err());
        assert!(parse_private_key("AAAA").is_err()); // too short (3 bytes)
        assert!(parse_public_key("AAAA").is_err());
    }

    #[test]
    fn test_allowed_ip_v4_match() {
        let aip = AllowedIp::parse("10.0.0.0/24").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 254))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 1, 1))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_v4_catch_all() {
        let aip = AllowedIp::parse("0.0.0.0/0").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))));
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255))));
    }

    #[test]
    fn test_allowed_ip_v4_host() {
        let aip = AllowedIp::parse("10.0.0.5/32").unwrap();
        assert!(aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5))));
        assert!(!aip.matches(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 6))));
    }

    #[test]
    fn test_allowed_ip_v6_match() {
        let aip = AllowedIp::parse("fd00::/64").unwrap();
        assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!aip.matches(IpAddr::V6(Ipv6Addr::new(0xfd01, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_v6_catch_all() {
        let aip = AllowedIp::parse("::/0").unwrap();
        assert!(aip.matches(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_allowed_ip_cross_family_no_match() {
        let v4 = AllowedIp::parse("10.0.0.0/8").unwrap();
        assert!(!v4.matches(IpAddr::V6(Ipv6Addr::LOCALHOST)));
        let v6 = AllowedIp::parse("::/0").unwrap();
        assert!(!v6.matches(IpAddr::V4(Ipv4Addr::LOCALHOST)));
    }

    #[test]
    fn test_extract_dst_ip_v4() {
        // Minimal IPv4 header: version=4, IHL=5, dst at bytes 16-19
        let mut pkt = [0u8; 20];
        pkt[0] = 0x45; // version 4, IHL 5
        pkt[16..20].copy_from_slice(&[10, 0, 0, 1]);
        assert_eq!(
            extract_dst_ip(&pkt),
            Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
        );
    }

    #[test]
    fn test_extract_dst_ip_v6() {
        // Minimal IPv6 header: version=6, dst at bytes 24-39
        let mut pkt = [0u8; 40];
        pkt[0] = 0x60; // version 6
        pkt[24] = 0xfd;
        pkt[39] = 0x01;
        let expected = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1));
        assert_eq!(extract_dst_ip(&pkt), Some(expected));
    }

    #[test]
    fn test_extract_dst_ip_empty() {
        assert_eq!(extract_dst_ip(&[]), None);
    }

    /// Runs a full handshake plus one data packet between two in-memory
    /// `Tunn`s. The exchange is deterministic (no network, no clock-driven
    /// retries), so every step is asserted rather than tolerated.
    #[test]
    fn test_loopback_tunnel() {
        let (server_pub, server_priv) = generate_wg_keypair();
        let (client_pub, client_priv) = generate_wg_keypair();

        let mut server_tunn = Tunn::new(
            parse_private_key(&server_priv).unwrap(),
            parse_public_key(&client_pub).unwrap(),
            None,
            None,
            0,
            None,
        );
        let mut client_tunn = Tunn::new(
            parse_private_key(&client_priv).unwrap(),
            parse_public_key(&server_pub).unwrap(),
            None,
            None,
            1,
            None,
        );

        let mut buf_a = vec![0u8; 2048];
        let mut buf_b = vec![0u8; 2048];

        // 1. With no session yet, encapsulating queues the payload and emits
        //    a handshake initiation instead.
        let handshake_init = match client_tunn.encapsulate(&[], &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for handshake init, got {:?}",
                std::mem::discriminant(&other)
            ),
        };

        // 2. Server answers the initiation with a handshake response.
        let handshake_resp = match server_tunn.decapsulate(None, &handshake_init, &mut buf_b) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for handshake resp, got {:?}",
                std::mem::discriminant(&other)
            ),
        };
        drain_queued(&mut server_tunn, &mut buf_b);

        // 3. Client completes the handshake. Any keepalive it emits must be
        //    accepted by the server.
        match client_tunn.decapsulate(None, &handshake_resp, &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => {
                let keepalive = pkt.to_vec();
                if let TunnResult::Err(e) = server_tunn.decapsulate(None, &keepalive, &mut buf_b) {
                    panic!("Server rejected post-handshake keepalive: {:?}", e);
                }
            }
            TunnResult::Done => {}
            TunnResult::Err(e) => {
                panic!("Client failed to process handshake response: {:?}", e)
            }
            other => panic!(
                "Unexpected result for handshake response: {:?}",
                std::mem::discriminant(&other)
            ),
        }
        // Flush the payload queued in step 1 (sent over the new session).
        drain_queued(&mut client_tunn, &mut buf_a);

        // 4. Send a minimal IPv4 packet 10.0.0.2 -> 10.0.0.1 through the tunnel.
        let mut fake_ip = [0u8; 28];
        fake_ip[0] = 0x45; // IPv4, IHL 5
        fake_ip[2..4].copy_from_slice(&28u16.to_be_bytes()); // total length
        fake_ip[12..16].copy_from_slice(&[10, 0, 0, 2]); // source: client
        fake_ip[16..20].copy_from_slice(&[10, 0, 0, 1]); // destination: server

        let encrypted = match client_tunn.encapsulate(&fake_ip, &mut buf_a) {
            TunnResult::WriteToNetwork(pkt) => pkt.to_vec(),
            other => panic!(
                "Expected WriteToNetwork for data packet, got {:?}",
                std::mem::discriminant(&other)
            ),
        };

        match server_tunn.decapsulate(None, &encrypted, &mut buf_b) {
            TunnResult::WriteToTunnelV4(decrypted, src_addr) => {
                // src_addr is the inner packet's source (used for AllowedIPs checks).
                assert_eq!(src_addr, Ipv4Addr::new(10, 0, 0, 2));
                assert_eq!(&*decrypted, &fake_ip[..]);
            }
            TunnResult::Err(e) => panic!("Server failed to decrypt data packet: {:?}", e),
            other => panic!(
                "Expected WriteToTunnelV4 for data packet, got {:?}",
                std::mem::discriminant(&other)
            ),
        }
    }
}