fix(wireguard): sync runtime peer management with client registration and derive the correct server public key from the WireGuard private key

This commit is contained in:
2026-03-31 02:11:29 +00:00
parent 42949b1233
commit 6e4cafe3c5
4 changed files with 152 additions and 15 deletions

View File

@@ -5,6 +5,7 @@ use std::sync::Arc;
use anyhow::{anyhow, Result};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use boringtun::noise::errors::WireGuardError;
use boringtun::noise::rate_limiter::RateLimiter;
use boringtun::noise::{Tunn, TunnResult};
use boringtun::x25519::{PublicKey, StaticSecret};
@@ -99,6 +100,13 @@ pub fn generate_wg_keypair() -> (String, String) {
(pub_b64, priv_b64)
}
/// Derive the WireGuard public key (base64) from a private key (base64).
pub fn wg_public_key_from_private(private_key_b64: &str) -> Result<String> {
let private = parse_private_key(private_key_b64)?;
let public = PublicKey::from(&private);
Ok(BASE64.encode(public.to_bytes()))
}
fn parse_private_key(b64: &str) -> Result<StaticSecret> {
let bytes = BASE64.decode(b64)?;
if bytes.len() != 32 {
@@ -215,8 +223,8 @@ struct PeerState {
}
impl PeerState {
fn matches_dst(&self, dst_ip: IpAddr) -> bool {
self.allowed_ips.iter().any(|aip| aip.matches(dst_ip))
fn matches_allowed_ips(&self, ip: IpAddr) -> bool {
self.allowed_ips.iter().any(|aip| aip.matches(ip))
}
}
@@ -286,9 +294,10 @@ pub struct WgListenerConfig {
pub peers: Vec<WgPeerConfig>,
}
/// Extract the first /32 IPv4 address from a list of AllowedIp entries.
/// This is the peer's VPN IP used for return-packet routing.
/// Extract the peer's VPN IP from AllowedIp entries.
/// Prefers /32 entries (exact match); falls back to any IPv4 address.
fn extract_peer_vpn_ip(allowed_ips: &[AllowedIp]) -> Option<Ipv4Addr> {
// Prefer /32 entries (exact peer VPN IP)
for aip in allowed_ips {
if let IpAddr::V4(v4) = aip.addr {
if aip.prefix_len == 32 {
@@ -296,6 +305,12 @@ fn extract_peer_vpn_ip(allowed_ips: &[AllowedIp]) -> Option<Ipv4Addr> {
}
}
}
// Fallback: use the first IPv4 address from any prefix length
for aip in allowed_ips {
if let IpAddr::V4(v4) = aip.addr {
return Some(v4);
}
}
None
}
@@ -495,7 +510,7 @@ pub async fn run_wg_listener(
break;
}
TunnResult::WriteToTunnelV4(packet, addr) => {
if peer.matches_dst(IpAddr::V4(addr)) {
if peer.matches_allowed_ips(IpAddr::V4(addr)) {
let pkt_len = packet.len() as u64;
// Forward via shared forwarding engine
let mut engine = state.forwarding_engine.lock().await;
@@ -519,7 +534,7 @@ pub async fn run_wg_listener(
break;
}
TunnResult::WriteToTunnelV6(packet, addr) => {
if peer.matches_dst(IpAddr::V6(addr)) {
if peer.matches_allowed_ips(IpAddr::V6(addr)) {
let pkt_len = packet.len() as u64;
let mut engine = state.forwarding_engine.lock().await;
match &mut *engine {
@@ -586,6 +601,9 @@ pub async fn run_wg_listener(
udp_socket.send_to(packet, endpoint).await?;
}
}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
warn!("WG peer {} connection expired", peer.public_key_b64);
}
TunnResult::Err(e) => {
debug!("Timer error for WG peer {}: {:?}",
peer.public_key_b64, e);
@@ -796,12 +814,12 @@ impl WgClient {
let state = self.state.clone();
let assigned_ip = config.address.clone();
// Update state
// Update state — handshake hasn't completed yet
{
let mut s = state.write().await;
s.state = "connected".to_string();
s.state = "handshaking".to_string();
s.assigned_ip = Some(assigned_ip.clone());
s.connected_since = Some(chrono_now());
s.connected_since = None;
}
// Spawn client loop
@@ -868,7 +886,7 @@ async fn wg_client_loop(
endpoint: SocketAddr,
_allowed_ips: Vec<AllowedIp>,
shared_stats: Arc<RwLock<WgPeerStats>>,
_state: Arc<RwLock<WgClientState>>,
state: Arc<RwLock<WgClientState>>,
mut shutdown_rx: oneshot::Receiver<()>,
) -> Result<()> {
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
@@ -876,6 +894,7 @@ async fn wg_client_loop(
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1));
let mut handshake_complete = false;
let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device);
@@ -916,14 +935,37 @@ async fn wg_client_loop(
tun_writer.write_all(packet).await?;
local_stats.bytes_received += pkt_len;
local_stats.packets_received += 1;
if !handshake_complete {
handshake_complete = true;
let mut s = state.write().await;
s.state = "connected".to_string();
s.connected_since = Some(chrono_now());
info!("WireGuard handshake completed, tunnel active");
}
}
TunnResult::WriteToTunnelV6(packet, _addr) => {
let pkt_len = packet.len() as u64;
tun_writer.write_all(packet).await?;
local_stats.bytes_received += pkt_len;
local_stats.packets_received += 1;
if !handshake_complete {
handshake_complete = true;
let mut s = state.write().await;
s.state = "connected".to_string();
s.connected_since = Some(chrono_now());
info!("WireGuard handshake completed, tunnel active");
}
}
TunnResult::Done => {}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
warn!("WireGuard session expired during decapsulate, re-initiating handshake");
match tunn.format_handshake_initiation(&mut dst_buf, true) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, endpoint).await?;
}
_ => {}
}
}
TunnResult::Err(e) => {
debug!("Client decapsulate error: {:?}", e);
}
@@ -955,6 +997,19 @@ async fn wg_client_loop(
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, endpoint).await?;
}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
warn!("WireGuard connection expired, re-initiating handshake");
match tunn.format_handshake_initiation(&mut dst_buf, true) {
TunnResult::WriteToNetwork(packet) => {
udp_socket.send_to(packet, endpoint).await?;
debug!("Sent handshake re-initiation after expiry");
}
TunnResult::Err(e) => {
warn!("Failed to re-initiate handshake: {:?}", e);
}
_ => {}
}
}
TunnResult::Err(e) => {
debug!("Client timer error: {:?}", e);
}
@@ -1028,6 +1083,19 @@ mod tests {
assert_eq!(public.to_bytes(), derived_public.to_bytes());
}
#[test]
fn test_wg_public_key_from_private() {
let (pub_b64, priv_b64) = generate_wg_keypair();
let derived = wg_public_key_from_private(&priv_b64).unwrap();
assert_eq!(derived, pub_b64);
}
#[test]
fn test_wg_public_key_from_private_invalid() {
assert!(wg_public_key_from_private("not-valid").is_err());
assert!(wg_public_key_from_private("AAAA").is_err());
}
#[test]
fn test_parse_invalid_key() {
assert!(parse_private_key("not-valid-base64!!!").is_err());