feat(server): unify WireGuard into the shared server transport pipeline

This commit is contained in:
2026-03-30 06:52:20 +00:00
parent 70e838c8ff
commit 79d9928485
8 changed files with 608 additions and 634 deletions

View File

@@ -2,8 +2,6 @@ use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Instant;
use anyhow::{anyhow, Result};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
@@ -17,8 +15,7 @@ use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot, RwLock};
use tracing::{debug, error, info, warn};
use crate::network;
use crate::tunnel::extract_dst_ip;
use crate::server::{ClientInfo, ForwardingEngine, ServerState};
use crate::tunnel::{self, TunConfig};
// ============================================================================
@@ -30,9 +27,6 @@ const WG_BUFFER_SIZE: usize = MAX_UDP_PACKET;
/// Minimum dst buffer size for boringtun encapsulate/decapsulate
const _MIN_DST_BUF: usize = 148;
const TIMER_TICK_MS: u64 = 100;
const DEFAULT_WG_PORT: u16 = 51820;
const DEFAULT_TUN_ADDRESS: &str = "10.8.0.1";
const DEFAULT_TUN_NETMASK: &str = "255.255.255.0";
const DEFAULT_MTU: u16 = 1420;
// ============================================================================
@@ -52,27 +46,6 @@ pub struct WgPeerConfig {
pub persistent_keepalive: Option<u16>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WgServerConfig {
pub private_key: String,
#[serde(default)]
pub listen_port: Option<u16>,
#[serde(default)]
pub tun_address: Option<String>,
#[serde(default)]
pub tun_netmask: Option<String>,
#[serde(default)]
pub mtu: Option<u16>,
pub peers: Vec<WgPeerConfig>,
#[serde(default)]
pub dns: Option<Vec<String>>,
#[serde(default)]
pub enable_nat: Option<bool>,
#[serde(default)]
pub subnet: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WgClientConfig {
@@ -112,17 +85,6 @@ pub struct WgPeerInfo {
pub stats: WgPeerStats,
}
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct WgServerStats {
pub total_bytes_sent: u64,
pub total_bytes_received: u64,
pub total_packets_sent: u64,
pub total_packets_received: u64,
pub active_peers: usize,
pub uptime_seconds: f64,
}
// ============================================================================
// Key generation and parsing
// ============================================================================
@@ -233,7 +195,7 @@ impl AllowedIp {
// Dynamic peer management commands
// ============================================================================
enum WgCommand {
pub enum WgCommand {
AddPeer(WgPeerConfig, oneshot::Sender<Result<()>>),
RemovePeer(String, oneshot::Sender<Result<()>>),
}
@@ -258,451 +220,6 @@ impl PeerState {
}
}
// ============================================================================
// WgServer
// ============================================================================
pub struct WgServer {
shutdown_tx: Option<oneshot::Sender<()>>,
command_tx: Option<mpsc::Sender<WgCommand>>,
shared_stats: Arc<RwLock<HashMap<String, WgPeerStats>>>,
server_stats: Arc<RwLock<WgServerStats>>,
started_at: Option<Instant>,
listen_port: Option<u16>,
}
impl WgServer {
pub fn new() -> Self {
Self {
shutdown_tx: None,
command_tx: None,
shared_stats: Arc::new(RwLock::new(HashMap::new())),
server_stats: Arc::new(RwLock::new(WgServerStats::default())),
started_at: None,
listen_port: None,
}
}
pub fn is_running(&self) -> bool {
self.shutdown_tx.is_some()
}
pub async fn start(&mut self, config: WgServerConfig) -> Result<()> {
if self.is_running() {
return Err(anyhow!("WireGuard server is already running"));
}
let listen_port = config.listen_port.unwrap_or(DEFAULT_WG_PORT);
let tun_address = config
.tun_address
.as_deref()
.unwrap_or(DEFAULT_TUN_ADDRESS);
let tun_netmask = config
.tun_netmask
.as_deref()
.unwrap_or(DEFAULT_TUN_NETMASK);
let mtu = config.mtu.unwrap_or(DEFAULT_MTU);
// Parse server private key
let server_private = parse_private_key(&config.private_key)?;
let server_public = PublicKey::from(&server_private);
// Create rate limiter for DDoS protection
let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64));
// Build peer state
let peer_index = AtomicU32::new(0);
let mut peers: Vec<PeerState> = Vec::with_capacity(config.peers.len());
for peer_config in &config.peers {
let peer_public = parse_public_key(&peer_config.public_key)?;
let psk = match &peer_config.preshared_key {
Some(k) => Some(parse_preshared_key(k)?),
None => None,
};
let idx = peer_index.fetch_add(1, Ordering::Relaxed);
// Clone the private key for each Tunn (StaticSecret doesn't implement Clone,
// so re-parse from config)
let priv_copy = parse_private_key(&config.private_key)?;
let tunn = Tunn::new(
priv_copy,
peer_public,
psk,
peer_config.persistent_keepalive,
idx,
Some(rate_limiter.clone()),
);
let allowed_ips: Vec<AllowedIp> = peer_config
.allowed_ips
.iter()
.map(|cidr| AllowedIp::parse(cidr))
.collect::<Result<Vec<_>>>()?;
let endpoint = match &peer_config.endpoint {
Some(ep) => Some(ep.parse::<SocketAddr>()?),
None => None,
};
peers.push(PeerState {
tunn,
public_key_b64: peer_config.public_key.clone(),
allowed_ips,
endpoint,
persistent_keepalive: peer_config.persistent_keepalive,
stats: WgPeerStats::default(),
});
}
// Create TUN device
let tun_config = TunConfig {
name: "wg0".to_string(),
address: tun_address.parse()?,
netmask: tun_netmask.parse()?,
mtu,
};
let tun_device = tunnel::create_tun(&tun_config)?;
info!("WireGuard TUN device created: {}", tun_config.name);
// Bind UDP socket
let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", listen_port)).await?;
info!("WireGuard server listening on UDP port {}", listen_port);
// Enable IP forwarding and NAT if requested
if config.enable_nat.unwrap_or(false) {
network::enable_ip_forwarding()?;
let subnet = config
.subnet
.as_deref()
.unwrap_or("10.8.0.0/24");
let iface = network::get_default_interface()?;
network::setup_nat(subnet, &iface).await?;
info!("NAT enabled for subnet {} via {}", subnet, iface);
}
// Channels
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let (command_tx, command_rx) = mpsc::channel::<WgCommand>(32);
let shared_stats = self.shared_stats.clone();
let server_stats = self.server_stats.clone();
let started_at = Instant::now();
// Initialize shared stats
{
let mut stats = shared_stats.write().await;
for peer in &peers {
stats.insert(peer.public_key_b64.clone(), WgPeerStats::default());
}
}
// Spawn the event loop
tokio::spawn(async move {
if let Err(e) = wg_server_loop(
udp_socket,
tun_device,
peers,
peer_index,
rate_limiter,
config.private_key.clone(),
shared_stats,
server_stats,
started_at,
shutdown_rx,
command_rx,
)
.await
{
error!("WireGuard server loop error: {}", e);
}
info!("WireGuard server loop exited");
});
self.shutdown_tx = Some(shutdown_tx);
self.command_tx = Some(command_tx);
self.started_at = Some(started_at);
self.listen_port = Some(listen_port);
Ok(())
}
pub async fn stop(&mut self) -> Result<()> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
self.command_tx = None;
self.started_at = None;
self.listen_port = None;
info!("WireGuard server stopped");
Ok(())
}
pub fn get_status(&self) -> serde_json::Value {
if self.is_running() {
serde_json::json!({
"state": "running",
"listenPort": self.listen_port,
"uptimeSeconds": self.started_at.map(|t| t.elapsed().as_secs_f64()).unwrap_or(0.0),
})
} else {
serde_json::json!({ "state": "stopped" })
}
}
pub async fn get_statistics(&self) -> serde_json::Value {
let mut stats = self.server_stats.write().await;
if let Some(started) = self.started_at {
stats.uptime_seconds = started.elapsed().as_secs_f64();
}
// Aggregate from peer stats
let peer_stats = self.shared_stats.read().await;
stats.active_peers = peer_stats.len();
stats.total_bytes_sent = peer_stats.values().map(|s| s.bytes_sent).sum();
stats.total_bytes_received = peer_stats.values().map(|s| s.bytes_received).sum();
stats.total_packets_sent = peer_stats.values().map(|s| s.packets_sent).sum();
stats.total_packets_received = peer_stats.values().map(|s| s.packets_received).sum();
serde_json::to_value(&*stats).unwrap_or_default()
}
pub async fn list_peers(&self) -> Vec<WgPeerInfo> {
let stats = self.shared_stats.read().await;
stats
.iter()
.map(|(key, s)| WgPeerInfo {
public_key: key.clone(),
allowed_ips: vec![], // populated from event loop snapshots
endpoint: None,
persistent_keepalive: None,
stats: s.clone(),
})
.collect()
}
pub async fn add_peer(&self, config: WgPeerConfig) -> Result<()> {
let tx = self
.command_tx
.as_ref()
.ok_or_else(|| anyhow!("Server not running"))?;
let (resp_tx, resp_rx) = oneshot::channel();
tx.send(WgCommand::AddPeer(config, resp_tx))
.await
.map_err(|_| anyhow!("Server event loop closed"))?;
resp_rx.await.map_err(|_| anyhow!("No response"))?
}
pub async fn remove_peer(&self, public_key: &str) -> Result<()> {
let tx = self
.command_tx
.as_ref()
.ok_or_else(|| anyhow!("Server not running"))?;
let (resp_tx, resp_rx) = oneshot::channel();
tx.send(WgCommand::RemovePeer(public_key.to_string(), resp_tx))
.await
.map_err(|_| anyhow!("Server event loop closed"))?;
resp_rx.await.map_err(|_| anyhow!("No response"))?
}
}
// ============================================================================
// Server event loop
// ============================================================================
async fn wg_server_loop(
udp_socket: UdpSocket,
tun_device: tun::AsyncDevice,
mut peers: Vec<PeerState>,
peer_index: AtomicU32,
rate_limiter: Arc<RateLimiter>,
server_private_key_b64: String,
shared_stats: Arc<RwLock<HashMap<String, WgPeerStats>>>,
_server_stats: Arc<RwLock<WgServerStats>>,
_started_at: Instant,
mut shutdown_rx: oneshot::Receiver<()>,
mut command_rx: mpsc::Receiver<WgCommand>,
) -> Result<()> {
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
let mut tun_buf = vec![0u8; MAX_UDP_PACKET];
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
// Split TUN for concurrent read/write in select
let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device);
// Stats sync interval
let mut stats_timer =
tokio::time::interval(std::time::Duration::from_secs(1));
loop {
tokio::select! {
// --- UDP receive ---
result = udp_socket.recv_from(&mut udp_buf) => {
let (n, src_addr) = result?;
if n == 0 { continue; }
// Find which peer this packet belongs to by trying decapsulate
let mut handled = false;
for peer in peers.iter_mut() {
match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, src_addr).await?;
// Drain loop
loop {
match peer.tunn.decapsulate(None, &[], &mut dst_buf) {
TunnResult::WriteToNetwork(pkt) => {
let ep = peer.endpoint.unwrap_or(src_addr);
udp_socket.send_to(pkt, ep).await?;
}
_ => break,
}
}
peer.endpoint = Some(src_addr);
handled = true;
break;
}
TunnResult::WriteToTunnelV4(packet, addr) => {
if peer.matches_dst(IpAddr::V4(addr)) {
let pkt_len = packet.len() as u64;
tun_writer.write_all(packet).await?;
peer.stats.bytes_received += pkt_len;
peer.stats.packets_received += 1;
}
peer.endpoint = Some(src_addr);
handled = true;
break;
}
TunnResult::WriteToTunnelV6(packet, addr) => {
if peer.matches_dst(IpAddr::V6(addr)) {
let pkt_len = packet.len() as u64;
tun_writer.write_all(packet).await?;
peer.stats.bytes_received += pkt_len;
peer.stats.packets_received += 1;
}
peer.endpoint = Some(src_addr);
handled = true;
break;
}
TunnResult::Done => {
// This peer didn't recognize the packet, try next
continue;
}
TunnResult::Err(e) => {
debug!("decapsulate error from {}: {:?}", src_addr, e);
continue;
}
}
}
if !handled {
debug!("No peer matched UDP packet from {}", src_addr);
}
}
// --- TUN read ---
result = tun_reader.read(&mut tun_buf) => {
let n = result?;
if n == 0 { continue; }
let dst_ip = match extract_dst_ip(&tun_buf[..n]) {
Some(ip) => ip,
None => { continue; }
};
// Find peer whose AllowedIPs match the destination
for peer in peers.iter_mut() {
if !peer.matches_dst(dst_ip) {
continue;
}
match peer.tunn.encapsulate(&tun_buf[..n], &mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
if let Some(endpoint) = peer.endpoint {
let pkt_len = n as u64;
udp_socket.send_to(packet, endpoint).await?;
peer.stats.bytes_sent += pkt_len;
peer.stats.packets_sent += 1;
} else {
debug!("No endpoint for peer {}, dropping packet", peer.public_key_b64);
}
}
TunnResult::Err(e) => {
debug!("encapsulate error for peer {}: {:?}", peer.public_key_b64, e);
}
_ => {}
}
break;
}
}
// --- Timer tick (100ms) for WireGuard timers ---
_ = timer.tick() => {
for peer in peers.iter_mut() {
match peer.tunn.update_timers(&mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
if let Some(endpoint) = peer.endpoint {
udp_socket.send_to(packet, endpoint).await?;
}
}
TunnResult::Err(e) => {
debug!("Timer error for peer {}: {:?}", peer.public_key_b64, e);
}
_ => {}
}
}
}
// --- Sync stats to shared state ---
_ = stats_timer.tick() => {
let mut shared = shared_stats.write().await;
for peer in peers.iter() {
shared.insert(peer.public_key_b64.clone(), peer.stats.clone());
}
}
// --- Dynamic peer commands ---
cmd = command_rx.recv() => {
match cmd {
Some(WgCommand::AddPeer(config, resp_tx)) => {
let result = add_peer_to_loop(
&mut peers,
&config,
&peer_index,
&rate_limiter,
&server_private_key_b64,
);
if result.is_ok() {
let mut shared = shared_stats.write().await;
shared.insert(config.public_key.clone(), WgPeerStats::default());
}
let _ = resp_tx.send(result);
}
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
let prev_len = peers.len();
peers.retain(|p| p.public_key_b64 != pubkey);
if peers.len() < prev_len {
let mut shared = shared_stats.write().await;
shared.remove(&pubkey);
let _ = resp_tx.send(Ok(()));
} else {
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
}
}
None => {
info!("Command channel closed");
break;
}
}
}
// --- Shutdown ---
_ = &mut shutdown_rx => {
info!("WireGuard server shutdown signal received");
break;
}
}
}
Ok(())
}
fn add_peer_to_loop(
peers: &mut Vec<PeerState>,
@@ -757,6 +274,410 @@ fn add_peer_to_loop(
Ok(())
}
// ============================================================================
// Integrated WG listener (shares ServerState with WS/QUIC)
// ============================================================================
/// Configuration for the integrated WireGuard listener.
#[derive(Debug, Clone)]
pub struct WgListenerConfig {
pub private_key: String,
pub listen_port: u16,
pub peers: Vec<WgPeerConfig>,
}
/// Extract the first /32 IPv4 address from a list of AllowedIp entries.
/// This is the peer's VPN IP used for return-packet routing.
fn extract_peer_vpn_ip(allowed_ips: &[AllowedIp]) -> Option<Ipv4Addr> {
for aip in allowed_ips {
if let IpAddr::V4(v4) = aip.addr {
if aip.prefix_len == 32 {
return Some(v4);
}
}
}
None
}
/// Timestamp helper (mirrors server.rs timestamp_now).
fn wg_timestamp_now() -> String {
use std::time::SystemTime;
let duration = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default();
format!("{}", duration.as_secs())
}
/// Register a WG peer in ServerState (tun_routes, clients, ip_pool).
/// Returns the VPN IP and the per-peer return-packet receiver.
async fn register_wg_peer(
state: &Arc<ServerState>,
peer: &PeerState,
wg_return_tx: &mpsc::Sender<(String, Vec<u8>)>,
) -> Result<Option<Ipv4Addr>> {
let vpn_ip = match extract_peer_vpn_ip(&peer.allowed_ips) {
Some(ip) => ip,
None => {
warn!("WG peer {} has no /32 IPv4 in allowed_ips, skipping registration",
peer.public_key_b64);
return Ok(None);
}
};
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
// Reserve IP in the pool
if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) {
warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e);
return Ok(None);
}
// Create per-peer return channel and register in tun_routes
let fwd_mode = state.config.forwarding_mode.as_deref().unwrap_or("testing");
let forwarding_active = fwd_mode == "tun" || fwd_mode == "socket";
if forwarding_active {
let (peer_return_tx, mut peer_return_rx) = mpsc::channel::<Vec<u8>>(256);
state.tun_routes.write().await.insert(vpn_ip, peer_return_tx);
// Spawn relay task: per-peer channel → merged channel tagged with pubkey
let relay_tx = wg_return_tx.clone();
let pubkey = peer.public_key_b64.clone();
tokio::spawn(async move {
while let Some(packet) = peer_return_rx.recv().await {
if relay_tx.send((pubkey.clone(), packet)).await.is_err() {
break;
}
}
});
}
// Insert ClientInfo
let client_info = ClientInfo {
client_id: client_id.clone(),
assigned_ip: vpn_ip.to_string(),
connected_since: wg_timestamp_now(),
bytes_sent: 0,
bytes_received: 0,
packets_dropped: 0,
bytes_dropped: 0,
last_keepalive_at: None,
keepalives_received: 0,
rate_limit_bytes_per_sec: None,
burst_bytes: None,
authenticated_key: peer.public_key_b64.clone(),
registered_client_id: client_id,
remote_addr: peer.endpoint.map(|e| e.to_string()),
transport_type: "wireguard".to_string(),
};
state.clients.write().await.insert(client_info.client_id.clone(), client_info);
Ok(Some(vpn_ip))
}
/// Unregister a WG peer from ServerState.
async fn unregister_wg_peer(
state: &Arc<ServerState>,
pubkey: &str,
vpn_ip: Option<Ipv4Addr>,
) {
let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
if let Some(ip) = vpn_ip {
state.tun_routes.write().await.remove(&ip);
state.ip_pool.lock().await.release(&ip);
}
state.clients.write().await.remove(&client_id);
state.rate_limiters.lock().await.remove(&client_id);
}
/// Integrated WireGuard listener that shares ServerState with WS/QUIC listeners.
/// Uses the shared ForwardingEngine for packet routing instead of its own TUN device.
pub async fn run_wg_listener(
state: Arc<ServerState>,
config: WgListenerConfig,
mut shutdown_rx: mpsc::Receiver<()>,
mut command_rx: mpsc::Receiver<WgCommand>,
) -> Result<()> {
// Parse server private key
let server_private = parse_private_key(&config.private_key)?;
let server_public = PublicKey::from(&server_private);
// Create rate limiter for DDoS protection
let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64));
// Build initial peer state
let peer_index = AtomicU32::new(0);
let mut peers: Vec<PeerState> = Vec::with_capacity(config.peers.len());
for peer_config in &config.peers {
let peer_public = parse_public_key(&peer_config.public_key)?;
let psk = match &peer_config.preshared_key {
Some(k) => Some(parse_preshared_key(k)?),
None => None,
};
let idx = peer_index.fetch_add(1, Ordering::Relaxed);
let priv_copy = parse_private_key(&config.private_key)?;
let tunn = Tunn::new(
priv_copy,
peer_public,
psk,
peer_config.persistent_keepalive,
idx,
Some(rate_limiter.clone()),
);
let allowed_ips: Vec<AllowedIp> = peer_config
.allowed_ips
.iter()
.map(|cidr| AllowedIp::parse(cidr))
.collect::<Result<Vec<_>>>()?;
let endpoint = match &peer_config.endpoint {
Some(ep) => Some(ep.parse::<SocketAddr>()?),
None => None,
};
peers.push(PeerState {
tunn,
public_key_b64: peer_config.public_key.clone(),
allowed_ips,
endpoint,
persistent_keepalive: peer_config.persistent_keepalive,
stats: WgPeerStats::default(),
});
}
// Bind UDP socket
let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", config.listen_port)).await?;
info!("WireGuard listener started on UDP port {}", config.listen_port);
// Merged return-packet channel: all per-peer channels feed into this
let (wg_return_tx, mut wg_return_rx) = mpsc::channel::<(String, Vec<u8>)>(1024);
// Register initial peers in ServerState and track their VPN IPs
let mut peer_vpn_ips: HashMap<String, Ipv4Addr> = HashMap::new();
for peer in &peers {
if let Ok(Some(ip)) = register_wg_peer(&state, peer, &wg_return_tx).await {
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
}
}
// Buffers
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
tokio::select! {
// --- UDP receive → decapsulate → ForwardingEngine ---
result = udp_socket.recv_from(&mut udp_buf) => {
let (n, src_addr) = result?;
if n == 0 { continue; }
let mut handled = false;
for peer in peers.iter_mut() {
match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, src_addr).await?;
loop {
match peer.tunn.decapsulate(None, &[], &mut dst_buf) {
TunnResult::WriteToNetwork(pkt) => {
let ep = peer.endpoint.unwrap_or(src_addr);
udp_socket.send_to(pkt, ep).await?;
}
_ => break,
}
}
peer.endpoint = Some(src_addr);
handled = true;
break;
}
TunnResult::WriteToTunnelV4(packet, addr) => {
if peer.matches_dst(IpAddr::V4(addr)) {
let pkt_len = packet.len() as u64;
// Forward via shared forwarding engine
let mut engine = state.forwarding_engine.lock().await;
match &mut *engine {
ForwardingEngine::Tun(writer) => {
use tokio::io::AsyncWriteExt;
if let Err(e) = writer.write_all(packet).await {
warn!("TUN write error for WG peer: {}", e);
}
}
ForwardingEngine::Socket(sender) => {
let _ = sender.try_send(packet.to_vec());
}
ForwardingEngine::Testing => {}
}
peer.stats.bytes_received += pkt_len;
peer.stats.packets_received += 1;
}
peer.endpoint = Some(src_addr);
handled = true;
break;
}
TunnResult::WriteToTunnelV6(packet, addr) => {
if peer.matches_dst(IpAddr::V6(addr)) {
let pkt_len = packet.len() as u64;
let mut engine = state.forwarding_engine.lock().await;
match &mut *engine {
ForwardingEngine::Tun(writer) => {
use tokio::io::AsyncWriteExt;
if let Err(e) = writer.write_all(packet).await {
warn!("TUN write error for WG peer: {}", e);
}
}
ForwardingEngine::Socket(sender) => {
let _ = sender.try_send(packet.to_vec());
}
ForwardingEngine::Testing => {}
}
peer.stats.bytes_received += pkt_len;
peer.stats.packets_received += 1;
}
peer.endpoint = Some(src_addr);
handled = true;
break;
}
TunnResult::Done => { continue; }
TunnResult::Err(e) => {
debug!("decapsulate error from {}: {:?}", src_addr, e);
continue;
}
}
}
if !handled {
debug!("No WG peer matched UDP packet from {}", src_addr);
}
}
// --- Return packets from tun_routes → encapsulate → UDP ---
Some((pubkey, packet)) = wg_return_rx.recv() => {
if let Some(peer) = peers.iter_mut().find(|p| p.public_key_b64 == pubkey) {
match peer.tunn.encapsulate(&packet, &mut dst_buf) {
TunnResult::WriteToNetwork(out) => {
if let Some(endpoint) = peer.endpoint {
let pkt_len = packet.len() as u64;
udp_socket.send_to(out, endpoint).await?;
peer.stats.bytes_sent += pkt_len;
peer.stats.packets_sent += 1;
} else {
debug!("No endpoint for WG peer {}, dropping return packet",
peer.public_key_b64);
}
}
TunnResult::Err(e) => {
debug!("encapsulate error for WG peer {}: {:?}",
peer.public_key_b64, e);
}
_ => {}
}
}
}
// --- WireGuard protocol timers (100ms) ---
_ = timer.tick() => {
for peer in peers.iter_mut() {
match peer.tunn.update_timers(&mut dst_buf) {
TunnResult::WriteToNetwork(packet) => {
if let Some(endpoint) = peer.endpoint {
udp_socket.send_to(packet, endpoint).await?;
}
}
TunnResult::Err(e) => {
debug!("Timer error for WG peer {}: {:?}",
peer.public_key_b64, e);
}
_ => {}
}
}
}
// --- Sync stats to ServerState (every 1s) ---
_ = stats_timer.tick() => {
let mut clients = state.clients.write().await;
let mut stats = state.stats.write().await;
for peer in peers.iter() {
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
if let Some(info) = clients.get_mut(&client_id) {
// Update stats delta
let prev_sent = info.bytes_sent;
let prev_recv = info.bytes_received;
info.bytes_sent = peer.stats.bytes_sent;
info.bytes_received = peer.stats.bytes_received;
info.remote_addr = peer.endpoint.map(|e| e.to_string());
// Update aggregate stats
stats.bytes_sent += peer.stats.bytes_sent.saturating_sub(prev_sent);
stats.bytes_received += peer.stats.bytes_received.saturating_sub(prev_recv);
}
}
}
// --- Dynamic peer commands ---
cmd = command_rx.recv() => {
match cmd {
Some(WgCommand::AddPeer(peer_config, resp_tx)) => {
let result = add_peer_to_loop(
&mut peers,
&peer_config,
&peer_index,
&rate_limiter,
&config.private_key,
);
if result.is_ok() {
// Register new peer in ServerState
let peer = peers.last().unwrap();
match register_wg_peer(&state, peer, &wg_return_tx).await {
Ok(Some(ip)) => {
peer_vpn_ips.insert(peer_config.public_key.clone(), ip);
}
Ok(None) => {}
Err(e) => {
warn!("Failed to register WG peer: {}", e);
}
}
}
let _ = resp_tx.send(result);
}
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
let prev_len = peers.len();
peers.retain(|p| p.public_key_b64 != pubkey);
if peers.len() < prev_len {
let vpn_ip = peer_vpn_ips.remove(&pubkey);
unregister_wg_peer(&state, &pubkey, vpn_ip).await;
let _ = resp_tx.send(Ok(()));
} else {
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
}
}
None => {
info!("WG command channel closed");
break;
}
}
}
// --- Shutdown ---
_ = shutdown_rx.recv() => {
info!("WireGuard listener shutdown signal received");
break;
}
}
}
// Cleanup: unregister all peers from ServerState
for peer in &peers {
let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied();
unregister_wg_peer(&state, &peer.public_key_b64, vpn_ip).await;
}
info!("WireGuard listener stopped");
Ok(())
}
// ============================================================================
// WgClient
// ============================================================================
@@ -1077,6 +998,7 @@ fn chrono_now() -> String {
#[cfg(test)]
mod tests {
use super::*;
use crate::tunnel::extract_dst_ip;
use std::net::Ipv6Addr;
#[test]