Compare commits

..

4 Commits

Author SHA1 Message Date
jkunz 004c9ed252 v1.20.0 2026-05-24 01:24:11 +00:00
jkunz 90d7f0903b feat(userspace-nat): add VPN metadata to PROXY protocol forwarding 2026-05-24 01:23:53 +00:00
jkunz 10f9c2e609 v1.19.4 2026-05-12 23:08:13 +00:00
jkunz 3c515c7d7f fix(rust-wireguard): keep WireGuard peer registration and client state in sync 2026-05-12 23:08:11 +00:00
10 changed files with 526 additions and 67 deletions
+19
View File
@@ -3,6 +3,25 @@
## Pending
## 2026-05-24 - 1.20.0
### Features
- add PROXY v2 real-source forwarding with authenticated VPN metadata TLVs (userspace-nat)
- allows socket forwarding to emit the client remote IP instead of the tunnel IP when configured
- adds SmartVPN client metadata to outbound PROXY v2 headers for trusted downstream authorization
## 2026-05-12 - 1.19.4
### Fixes
- keep WireGuard peer registration and client state in sync (rust-wireguard)
- index WireGuard public keys in the client registry for duplicate detection and direct lookup
- skip inactive clients when loading or adding WireGuard peers and fail fast when peer registration cannot be completed
- make IP reservation idempotent for the same client and avoid releasing WireGuard-assigned IPs on disconnect
- roll back client registry and peer state when WireGuard peer creation or key rotation fails
- update hybrid routing entries when registered client networking flags change
## 2026-05-12 - 1.19.3
### Fixes
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "@push.rocks/smartvpn",
"version": "1.19.3",
"version": "1.20.0",
"private": false,
"description": "A VPN solution with TypeScript control plane and Rust data plane daemon",
"type": "module",
+71 -1
View File
@@ -98,6 +98,8 @@ pub struct ClientRegistry {
entries: HashMap<String, ClientEntry>,
/// Secondary index: publicKey (base64) → clientId (fast lookup during handshake)
key_index: HashMap<String, String>,
/// WireGuard public key → clientId (fast lookup during WG handshakes)
wg_key_index: HashMap<String, String>,
/// Tertiary index: assignedIp → clientId (fast lookup during NAT destination policy)
ip_index: HashMap<String, String>,
}
@@ -107,6 +109,7 @@ impl ClientRegistry {
Self {
entries: HashMap::new(),
key_index: HashMap::new(),
wg_key_index: HashMap::new(),
ip_index: HashMap::new(),
}
}
@@ -132,6 +135,12 @@ impl ClientRegistry {
if self.key_index.contains_key(&entry.public_key) {
anyhow::bail!("Public key already registered to another client");
}
if let Some(ref wg_key) = entry.wg_public_key {
if self.wg_key_index.contains_key(wg_key) {
anyhow::bail!("WireGuard public key already registered to another client");
}
self.wg_key_index.insert(wg_key.clone(), entry.client_id.clone());
}
self.key_index.insert(entry.public_key.clone(), entry.client_id.clone());
if let Some(ref ip) = entry.assigned_ip {
self.ip_index.insert(ip.clone(), entry.client_id.clone());
@@ -145,6 +154,9 @@ impl ClientRegistry {
let entry = self.entries.remove(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
self.key_index.remove(&entry.public_key);
if let Some(ref wg_key) = entry.wg_public_key {
self.wg_key_index.remove(wg_key);
}
if let Some(ref ip) = entry.assigned_ip {
self.ip_index.remove(ip);
}
@@ -162,6 +174,12 @@ impl ClientRegistry {
self.entries.get(client_id)
}
/// Get a client by WireGuard public key.
pub fn get_by_wg_key(&self, public_key: &str) -> Option<&ClientEntry> {
let client_id = self.wg_key_index.get(public_key)?;
self.entries.get(client_id)
}
/// Get a client by assigned IP (used for per-client destination policy in NAT engine).
pub fn get_by_assigned_ip(&self, ip: &str) -> Option<&ClientEntry> {
let client_id = self.ip_index.get(ip)?;
@@ -184,6 +202,7 @@ impl ClientRegistry {
let entry = self.entries.get_mut(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
let old_key = entry.public_key.clone();
let old_wg_key = entry.wg_public_key.clone();
let old_ip = entry.assigned_ip.clone();
updater(entry);
// If public key changed, update the key index
@@ -191,6 +210,15 @@ impl ClientRegistry {
self.key_index.remove(&old_key);
self.key_index.insert(entry.public_key.clone(), client_id.to_string());
}
// If WireGuard public key changed, update the WG key index.
if entry.wg_public_key != old_wg_key {
if let Some(ref old) = old_wg_key {
self.wg_key_index.remove(old);
}
if let Some(ref new_key) = entry.wg_public_key {
self.wg_key_index.insert(new_key.clone(), client_id.to_string());
}
}
// If assigned IP changed, update the IP index
if entry.assigned_ip != old_ip {
if let Some(ref old) = old_ip {
@@ -210,13 +238,32 @@ impl ClientRegistry {
/// Rotate a client's keys. Returns the updated entry.
pub fn rotate_key(&mut self, client_id: &str, new_public_key: String, new_wg_public_key: Option<String>) -> Result<()> {
if let Some(existing_client_id) = self.key_index.get(&new_public_key) {
if existing_client_id != client_id {
anyhow::bail!("Public key already registered to another client");
}
}
if let Some(ref new_wg_key) = new_wg_public_key {
if let Some(existing_client_id) = self.wg_key_index.get(new_wg_key) {
if existing_client_id != client_id {
anyhow::bail!("WireGuard public key already registered to another client");
}
}
}
let entry = self.entries.get_mut(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
// Update key index
self.key_index.remove(&entry.public_key);
if let Some(ref old_wg_key) = entry.wg_public_key {
self.wg_key_index.remove(old_wg_key);
}
entry.public_key = new_public_key.clone();
entry.wg_public_key = new_wg_public_key;
self.key_index.insert(new_public_key, client_id.to_string());
if let Some(ref wg_key) = entry.wg_public_key {
self.wg_key_index.insert(wg_key.clone(), client_id.to_string());
}
Ok(())
}
@@ -367,13 +414,36 @@ mod tests {
let mut reg = ClientRegistry::new();
reg.add(make_entry("alice", "old_key")).unwrap();
reg.rotate_key("alice", "new_key".to_string(), None).unwrap();
reg.rotate_key("alice", "new_key".to_string(), Some("new_wg_key".to_string())).unwrap();
assert!(reg.get_by_key("old_key").is_none());
assert!(reg.get_by_key("new_key").is_some());
assert!(reg.get_by_wg_key("new_wg_key").is_some());
assert_eq!(reg.get_by_id("alice").unwrap().public_key, "new_key");
}
#[test]
fn lookup_by_wireguard_key() {
let mut reg = ClientRegistry::new();
let mut entry = make_entry("alice", "key_alice");
entry.wg_public_key = Some("wg_key_alice".to_string());
reg.add(entry).unwrap();
assert_eq!(reg.get_by_wg_key("wg_key_alice").unwrap().client_id, "alice");
}
#[test]
fn reject_duplicate_wireguard_key() {
let mut reg = ClientRegistry::new();
let mut alice = make_entry("alice", "key_alice");
alice.wg_public_key = Some("same_wg_key".to_string());
let mut bob = make_entry("bob", "key_bob");
bob.wg_public_key = Some("same_wg_key".to_string());
reg.add(alice).unwrap();
assert!(reg.add(bob).is_err());
}
#[test]
fn from_entries() {
let entries = vec![
+14 -1
View File
@@ -124,7 +124,10 @@ impl IpPool {
/// Reserve a specific IP for a client (e.g., WireGuard static IP from allowed_ips).
pub fn reserve(&mut self, ip: Ipv4Addr, client_id: &str) -> Result<()> {
if self.allocated.contains_key(&ip) {
if let Some(existing_client_id) = self.allocated.get(&ip) {
if existing_client_id == client_id {
return Ok(());
}
anyhow::bail!("IP {} is already allocated", ip);
}
self.allocated.insert(ip, client_id.to_string());
@@ -233,6 +236,16 @@ mod tests {
assert!(pool.allocate("client2").is_err());
}
#[test]
fn ip_pool_reserve_is_idempotent_for_same_client() {
let mut pool = IpPool::new("10.9.0.0/24").unwrap();
let ip = pool.allocate("alice").unwrap();
pool.reserve(ip, "alice").unwrap();
assert_eq!(pool.allocated_count(), 1);
assert!(pool.reserve(ip, "bob").is_err());
}
#[test]
fn ip_pool_invalid_subnet() {
assert!(IpPool::new("invalid").is_err());
+78 -3
View File
@@ -4,6 +4,7 @@
//! Spec: <https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt>
use anyhow::Result;
use serde::Serialize;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::Duration;
use tokio::io::AsyncReadExt;
@@ -17,6 +18,20 @@ const PP_V2_SIGNATURE: [u8; 12] = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
/// Custom PROXY v2 TLV type for authenticated SmartVPN connection metadata.
/// 0xEA is in the PP2_TYPE_MIN_CUSTOM range (0xE0-0xEF).
pub const PP2_TYPE_SMARTVPN_METADATA: u8 = 0xEA;
/// Authenticated VPN metadata sent to trusted downstream proxies.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct VpnProxyMetadata {
pub client_id: String,
pub assigned_ip: String,
pub transport_type: String,
pub remote_addr: Option<String>,
}
/// Parsed PROXY protocol v2 header.
#[derive(Debug, Clone)]
pub struct ProxyHeader {
@@ -124,6 +139,26 @@ async fn read_proxy_header_inner(stream: &mut TcpStream) -> Result<ProxyHeader>
/// Build a PROXY protocol v2 header (for testing / proxy implementations).
pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
build_pp_v2_header_with_vpn_metadata(src, dst, None)
}
/// Build a PROXY protocol v2 header with optional SmartVPN metadata TLV.
pub fn build_pp_v2_header_with_vpn_metadata(
src: SocketAddr,
dst: SocketAddr,
vpn_metadata: Option<&VpnProxyMetadata>,
) -> Vec<u8> {
let mut tlv_bytes = Vec::new();
if let Some(metadata) = vpn_metadata {
if let Ok(json) = serde_json::to_vec(metadata) {
if json.len() <= u16::MAX as usize {
tlv_bytes.push(PP2_TYPE_SMARTVPN_METADATA);
tlv_bytes.extend_from_slice(&(json.len() as u16).to_be_bytes());
tlv_bytes.extend_from_slice(&json);
}
}
}
let mut buf = Vec::new();
buf.extend_from_slice(&PP_V2_SIGNATURE);
@@ -131,22 +166,38 @@ pub fn build_pp_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
buf.push(0x21); // version 2 | PROXY command
buf.push(0x11); // AF_INET | STREAM
buf.extend_from_slice(&12u16.to_be_bytes()); // addr length
buf.extend_from_slice(&((12 + tlv_bytes.len()) as u16).to_be_bytes()); // addr length
buf.extend_from_slice(&s.ip().octets());
buf.extend_from_slice(&d.ip().octets());
buf.extend_from_slice(&s.port().to_be_bytes());
buf.extend_from_slice(&d.port().to_be_bytes());
buf.extend_from_slice(&tlv_bytes);
}
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
buf.push(0x21); // version 2 | PROXY command
buf.push(0x21); // AF_INET6 | STREAM
buf.extend_from_slice(&36u16.to_be_bytes()); // addr length
buf.extend_from_slice(&((36 + tlv_bytes.len()) as u16).to_be_bytes()); // addr length
buf.extend_from_slice(&s.ip().octets());
buf.extend_from_slice(&d.ip().octets());
buf.extend_from_slice(&s.port().to_be_bytes());
buf.extend_from_slice(&d.port().to_be_bytes());
buf.extend_from_slice(&tlv_bytes);
}
_ => {
let src_v6 = match src.ip() {
std::net::IpAddr::V4(v4) => v4.to_ipv6_mapped(),
std::net::IpAddr::V6(v6) => v6,
};
let dst_v6 = match dst.ip() {
std::net::IpAddr::V4(v4) => v4.to_ipv6_mapped(),
std::net::IpAddr::V6(v6) => v6,
};
return build_pp_v2_header_with_vpn_metadata(
SocketAddr::V6(SocketAddrV6::new(src_v6, src.port(), 0, 0)),
SocketAddr::V6(SocketAddrV6::new(dst_v6, dst.port(), 0, 0)),
vpn_metadata,
);
}
_ => panic!("Mismatched address families"),
}
buf
}
@@ -197,6 +248,30 @@ mod tests {
assert_eq!(parsed.dst_addr, dst);
}
#[tokio::test]
async fn build_ipv4_header_with_smartvpn_metadata_tlv() {
let src = "203.0.113.50:12345".parse::<SocketAddr>().unwrap();
let dst = "10.0.0.1:443".parse::<SocketAddr>().unwrap();
let metadata = VpnProxyMetadata {
client_id: "alice".to_string(),
assigned_ip: "10.8.0.2".to_string(),
transport_type: "wireguard".to_string(),
remote_addr: Some("203.0.113.50:51820".to_string()),
};
let header = build_pp_v2_header_with_vpn_metadata(src, dst, Some(&metadata));
let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize;
assert!(addr_len > 12);
assert_eq!(header[28], PP2_TYPE_SMARTVPN_METADATA);
let tlv_len = u16::from_be_bytes([header[29], header[30]]) as usize;
let json = std::str::from_utf8(&header[31..31 + tlv_len]).unwrap();
assert!(json.contains("\"clientId\":\"alice\""));
let parsed = parse_header_from_bytes(&header).await.unwrap();
assert_eq!(parsed.src_addr, src);
assert_eq!(parsed.dst_addr, dst);
}
#[tokio::test]
async fn parse_valid_ipv6_header() {
let src = "[2001:db8::1]:54321".parse::<SocketAddr>().unwrap();
+161 -28
View File
@@ -73,9 +73,12 @@ pub struct ServerConfig {
/// Server-level IP block list — applied at TCP accept, before Noise handshake.
pub connection_ip_block_list: Option<Vec<String>>,
/// When true and forwarding_mode is "socket", the userspace NAT engine prepends
/// PROXY protocol v2 headers on outbound TCP connections, conveying the VPN client's
/// tunnel IP as the source address.
/// PROXY protocol v2 headers on outbound TCP connections.
pub socket_forward_proxy_protocol: Option<bool>,
/// Source address for outbound PROXY protocol headers: "tunnelIp" (legacy) or "remoteIp".
pub socket_forward_proxy_protocol_source: Option<String>,
/// Include authenticated SmartVPN metadata as a custom PROXY v2 TLV.
pub socket_forward_proxy_protocol_vpn_metadata: Option<bool>,
/// Destination routing policy for VPN client traffic (socket mode).
pub destination_policy: Option<DestinationPolicyConfig>,
/// WireGuard: server X25519 private key (base64). Required when transport includes WG.
@@ -431,11 +434,17 @@ impl VpnServer {
ForwardingSetup::Socket { packet_tx, packet_rx, shutdown_rx } => {
*state.forwarding_engine.lock().await = ForwardingEngine::Socket(packet_tx);
let proxy_protocol = config.socket_forward_proxy_protocol.unwrap_or(false);
let proxy_protocol_source = crate::userspace_nat::ProxyProtocolSource::from_config(
config.socket_forward_proxy_protocol_source.as_deref(),
);
let proxy_protocol_vpn_metadata = config.socket_forward_proxy_protocol_vpn_metadata.unwrap_or(false);
let nat_engine = crate::userspace_nat::NatEngine::new(
gateway_ip,
link_mtu as usize,
state.clone(),
proxy_protocol,
proxy_protocol_source,
proxy_protocol_vpn_metadata,
config.destination_policy.clone(),
);
tokio::spawn(async move {
@@ -473,11 +482,17 @@ impl VpnServer {
// Start socket (NAT) engine
let proxy_protocol = config.socket_forward_proxy_protocol.unwrap_or(false);
let proxy_protocol_source = crate::userspace_nat::ProxyProtocolSource::from_config(
config.socket_forward_proxy_protocol_source.as_deref(),
);
let proxy_protocol_vpn_metadata = config.socket_forward_proxy_protocol_vpn_metadata.unwrap_or(false);
let nat_engine = crate::userspace_nat::NatEngine::new(
gateway_ip,
link_mtu as usize,
state.clone(),
proxy_protocol,
proxy_protocol_source,
proxy_protocol_vpn_metadata,
config.destination_policy.clone(),
);
tokio::spawn(async move {
@@ -584,6 +599,9 @@ impl VpnServer {
if self.wg_command_tx.is_some() {
let registry = state.client_registry.read().await;
for entry in registry.list() {
if !entry.is_enabled() || entry.is_expired() {
continue;
}
if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) {
let peer_config = crate::wireguard::WgPeerConfig {
public_key: wg_key.clone(),
@@ -725,8 +743,10 @@ impl VpnServer {
if let Some(ref state) = self.state {
let mut clients = state.clients.write().await;
if let Some(client) = clients.remove(client_id) {
let ip: Ipv4Addr = client.assigned_ip.parse()?;
state.ip_pool.lock().await.release(&ip);
if client.transport_type != "wireguard" {
let ip: Ipv4Addr = client.assigned_ip.parse()?;
state.ip_pool.lock().await.release(&ip);
}
state.rate_limiters.lock().await.remove(client_id);
info!("Client {} disconnected", client_id);
}
@@ -890,7 +910,9 @@ impl VpnServer {
persistent_keepalive: Some(25),
};
if let Err(e) = self.add_wg_peer(wg_peer_config).await {
warn!("Failed to register WG peer for client {}: {}", client_id, e);
let _ = state.client_registry.write().await.remove(&client_id);
state.ip_pool.lock().await.release(&assigned_ip);
return Err(anyhow::anyhow!("Failed to register WG peer for client {}: {}", client_id, e));
}
}
@@ -948,15 +970,28 @@ impl VpnServer {
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let entry = state.client_registry.write().await.remove(client_id)?;
let entry = {
let registry = state.client_registry.read().await;
registry.get_by_id(client_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
};
// Remove WG peer from running listener
if self.wg_command_tx.is_some() {
if let Some(ref wg_key) = entry.wg_public_key {
if let Err(e) = self.remove_wg_peer(wg_key).await {
debug!("Failed to remove WG peer for client {}: {}", client_id, e);
if entry.is_enabled() && !entry.is_expired() {
return Err(anyhow::anyhow!("Failed to remove WG peer for client {}: {}", client_id, e));
}
debug!("Failed to remove inactive WG peer for client {}: {}", client_id, e);
}
}
}
state.client_registry.write().await.remove(client_id)?;
// Release the IP if assigned
if let Some(ref ip_str) = entry.assigned_ip {
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
@@ -1013,7 +1048,37 @@ impl VpnServer {
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
entry.expires_at = Some(expires.to_string());
}
if let Some(use_host_ip) = update.get("useHostIp").and_then(|v| v.as_bool()) {
entry.use_host_ip = Some(use_host_ip);
}
if let Some(use_dhcp) = update.get("useDhcp").and_then(|v| v.as_bool()) {
entry.use_dhcp = Some(use_dhcp);
}
if let Some(static_ip) = update.get("staticIp").and_then(|v| v.as_str()) {
entry.static_ip = Some(static_ip.to_string());
}
if let Some(force_vlan) = update.get("forceVlan").and_then(|v| v.as_bool()) {
entry.force_vlan = Some(force_vlan);
}
if let Some(vlan_id) = update.get("vlanId").and_then(|v| v.as_u64()) {
entry.vlan_id = Some(vlan_id as u16);
}
})?;
let updated_entry = {
let registry = state.client_registry.read().await;
registry.get_by_id(client_id).cloned()
};
if let Some(entry) = updated_entry {
if let Some(ref ip_str) = entry.assigned_ip {
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
routing_table.write().await.insert(ip, entry.use_host_ip.unwrap_or(false));
}
}
}
}
Ok(())
}
@@ -1023,13 +1088,56 @@ impl VpnServer {
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
state.client_registry.write().await.update(client_id, |entry| {
entry.enabled = Some(true);
})
})?;
let entry = {
let registry = state.client_registry.read().await;
registry.get_by_id(client_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
};
if self.wg_command_tx.is_some() {
if let (Some(ref wg_key), Some(ref ip_str)) = (&entry.wg_public_key, &entry.assigned_ip) {
let peer_config = crate::wireguard::WgPeerConfig {
public_key: wg_key.clone(),
preshared_key: None,
allowed_ips: vec![format!("{}/32", ip_str)],
endpoint: None,
persistent_keepalive: Some(25),
};
if let Err(e) = self.add_wg_peer(peer_config).await {
let _ = state.client_registry.write().await.update(client_id, |entry| {
entry.enabled = Some(false);
});
return Err(anyhow::anyhow!("Failed to register WG peer for enabled client {}: {}", client_id, e));
}
}
}
Ok(())
}
/// Disable a registered client (also disconnects if connected).
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let entry = {
let registry = state.client_registry.read().await;
registry.get_by_id(client_id).cloned()
};
if self.wg_command_tx.is_some() {
if let Some(ref entry) = entry {
if let Some(ref wg_key) = entry.wg_public_key {
if let Err(e) = self.remove_wg_peer(wg_key).await {
debug!("Failed to remove WG peer for disabled client {}: {}", client_id, e);
}
}
}
}
state.client_registry.write().await.update(client_id, |entry| {
entry.enabled = Some(false);
})?;
@@ -1043,17 +1151,33 @@ impl VpnServer {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
// Capture old WG key before rotation (needed to remove from WG listener)
let old_wg_pub = {
// Capture old keys before rotation so listener and registry can be rolled back together.
let old_entry = {
let registry = state.client_registry.read().await;
let entry = registry.get_by_id(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
entry.wg_public_key.clone()
registry.get_by_id(client_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?
};
let old_noise_pub = old_entry.public_key.clone();
let old_wg_pub = old_entry.wg_public_key.clone();
let assigned_ip = old_entry.assigned_ip.clone().unwrap_or_else(|| "0.0.0.0".to_string());
let should_have_wg_peer = self.wg_command_tx.is_some()
&& old_entry.is_enabled()
&& !old_entry.is_expired()
&& old_wg_pub.is_some()
&& assigned_ip != "0.0.0.0";
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
if should_have_wg_peer {
if let Some(ref old_key) = old_wg_pub {
if let Err(e) = self.remove_wg_peer(old_key).await {
return Err(anyhow::anyhow!("Failed to remove old WG peer during rotation for {}: {}", client_id, e));
}
}
}
state.client_registry.write().await.rotate_key(
client_id,
noise_pub.clone(),
@@ -1063,31 +1187,40 @@ impl VpnServer {
// Disconnect existing connection (old key is no longer valid)
let _ = self.disconnect_client(client_id).await;
// Get updated entry for the config bundle
let entry_json = self.get_registered_client(client_id).await?;
let assigned_ip = entry_json.get("assignedIp")
.and_then(|v| v.as_str())
.unwrap_or("0.0.0.0");
// Update WG listener: remove old peer, add new peer
if self.wg_command_tx.is_some() {
if let Some(ref old_key) = old_wg_pub {
if let Err(e) = self.remove_wg_peer(old_key).await {
debug!("Failed to remove old WG peer during rotation: {}", e);
}
}
// Update WG listener with the new key. Roll back the registry if this fails.
if should_have_wg_peer {
let wg_peer_config = crate::wireguard::WgPeerConfig {
public_key: wg_pub.clone(),
preshared_key: None,
allowed_ips: vec![format!("{}/32", assigned_ip)],
allowed_ips: vec![format!("{}/32", assigned_ip.as_str())],
endpoint: None,
persistent_keepalive: Some(25),
};
if let Err(e) = self.add_wg_peer(wg_peer_config).await {
warn!("Failed to register new WG peer during rotation: {}", e);
let _ = state.client_registry.write().await.rotate_key(
client_id,
old_noise_pub,
old_wg_pub.clone(),
);
if let Some(ref old_key) = old_wg_pub {
let rollback_peer_config = crate::wireguard::WgPeerConfig {
public_key: old_key.clone(),
preshared_key: None,
allowed_ips: vec![format!("{}/32", assigned_ip.as_str())],
endpoint: None,
persistent_keepalive: Some(25),
};
if let Err(rollback_err) = self.add_wg_peer(rollback_peer_config).await {
warn!("Failed to restore old WG peer after rotation failure for {}: {}", client_id, rollback_err);
}
}
return Err(anyhow::anyhow!("Failed to register new WG peer during rotation for {}: {}", client_id, e));
}
}
// Get updated entry for the config bundle
let entry_json = self.get_registered_client(client_id).await?;
let smartvpn_server_url = format!("wss://{}",
state.config.server_endpoint.as_deref()
.unwrap_or(&state.config.listen_addr)
+97 -6
View File
@@ -21,6 +21,21 @@ use crate::tunnel;
/// Sessions exceeding this are aborted — the client cannot keep up.
const TCP_PENDING_SEND_MAX: usize = 512 * 1024;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyProtocolSource {
TunnelIp,
RemoteIp,
}
impl ProxyProtocolSource {
pub fn from_config(value: Option<&str>) -> Self {
match value {
Some("remoteIp") => ProxyProtocolSource::RemoteIp,
_ => ProxyProtocolSource::TunnelIp,
}
}
}
// ============================================================================
// Virtual IP device for smoltcp
// ============================================================================
@@ -208,8 +223,10 @@ pub struct NatEngine {
bridge_tx: mpsc::Sender<BridgeMessage>,
start_time: std::time::Instant,
/// When true, outbound TCP connections prepend PROXY protocol v2 headers
/// with the VPN client's tunnel IP as source address.
/// with VPN source identity.
proxy_protocol: bool,
proxy_protocol_source: ProxyProtocolSource,
proxy_protocol_vpn_metadata: bool,
/// Destination routing policy: forceTarget, block, or allow.
destination_policy: Option<DestinationPolicyConfig>,
}
@@ -225,7 +242,15 @@ enum DestinationAction {
}
impl NatEngine {
pub fn new(gateway_ip: Ipv4Addr, mtu: usize, state: Arc<ServerState>, proxy_protocol: bool, destination_policy: Option<DestinationPolicyConfig>) -> Self {
pub fn new(
gateway_ip: Ipv4Addr,
mtu: usize,
state: Arc<ServerState>,
proxy_protocol: bool,
proxy_protocol_source: ProxyProtocolSource,
proxy_protocol_vpn_metadata: bool,
destination_policy: Option<DestinationPolicyConfig>,
) -> Self {
let mut device = VirtualIpDevice::new(mtu);
let config = Config::new(HardwareAddress::Ip);
let now = smoltcp::time::Instant::from_millis(0);
@@ -258,6 +283,8 @@ impl NatEngine {
bridge_tx,
start_time: std::time::Instant::now(),
proxy_protocol,
proxy_protocol_source,
proxy_protocol_vpn_metadata,
destination_policy,
}
}
@@ -481,6 +508,9 @@ impl NatEngine {
// Start bridge tasks for sessions whose handshake just completed
let bridge_tx_clone = self.bridge_tx.clone();
let proxy_protocol = self.proxy_protocol;
let proxy_protocol_source = self.proxy_protocol_source;
let proxy_protocol_vpn_metadata = self.proxy_protocol_vpn_metadata;
let state = Arc::clone(&self.state);
for (key, session) in self.tcp_sessions.iter_mut() {
if !session.bridge_started && !session.closing {
let socket = self.sockets.get_mut::<tcp::Socket>(session.smoltcp_handle);
@@ -492,8 +522,11 @@ impl NatEngine {
let k = key.clone();
let addr = session.connect_addr;
let pp = proxy_protocol;
let pp_source = proxy_protocol_source;
let pp_metadata = proxy_protocol_vpn_metadata;
let state = Arc::clone(&state);
tokio::spawn(async move {
tcp_bridge_task(k, data_rx, btx, pp, addr).await;
tcp_bridge_task(k, data_rx, btx, pp, pp_source, pp_metadata, state, addr).await;
});
debug!("NAT: TCP handshake complete, starting bridge for {}:{} -> {}:{}",
key.src_ip, key.src_port, key.dst_ip, key.dst_port);
@@ -748,6 +781,9 @@ async fn tcp_bridge_task(
mut data_rx: mpsc::Receiver<Vec<u8>>,
bridge_tx: mpsc::Sender<BridgeMessage>,
proxy_protocol: bool,
proxy_protocol_source: ProxyProtocolSource,
proxy_protocol_vpn_metadata: bool,
state: Arc<ServerState>,
connect_addr: SocketAddr,
) {
// Connect to resolved destination (may differ from key.dst_ip if policy rewrote it)
@@ -768,11 +804,21 @@ async fn tcp_bridge_task(
let (mut reader, mut writer) = stream.into_split();
// Send PROXY protocol v2 header with VPN client's tunnel IP as source
// Send PROXY protocol v2 header with configured client source identity.
if proxy_protocol {
let src = SocketAddr::new(key.src_ip.into(), key.src_port);
let (src, metadata) = build_proxy_protocol_identity(
&state,
key.src_ip,
key.src_port,
proxy_protocol_source,
proxy_protocol_vpn_metadata,
).await;
let dst = SocketAddr::new(key.dst_ip.into(), key.dst_port);
let pp_header = crate::proxy_protocol::build_pp_v2_header(src, dst);
let pp_header = crate::proxy_protocol::build_pp_v2_header_with_vpn_metadata(
src,
dst,
metadata.as_ref(),
);
if let Err(e) = writer.write_all(&pp_header).await {
debug!("NAT: failed to send PP v2 header to {}: {}", connect_addr, e);
let _ = bridge_tx.send(BridgeMessage::TcpClosed { key }).await;
@@ -818,6 +864,51 @@ async fn tcp_bridge_task(
read_task.abort();
}
async fn build_proxy_protocol_identity(
state: &Arc<ServerState>,
tunnel_ip: Ipv4Addr,
tunnel_port: u16,
proxy_protocol_source: ProxyProtocolSource,
include_metadata: bool,
) -> (SocketAddr, Option<crate::proxy_protocol::VpnProxyMetadata>) {
let tunnel_addr = SocketAddr::new(tunnel_ip.into(), tunnel_port);
let client_id = state
.client_registry
.read()
.await
.get_by_assigned_ip(&tunnel_ip.to_string())
.map(|entry| entry.client_id.clone());
let client_info = if let Some(ref client_id) = client_id {
state.clients.read().await.get(client_id).cloned()
} else {
None
};
let remote_addr = client_info
.as_ref()
.and_then(|info| info.remote_addr.as_ref())
.and_then(|addr| addr.parse::<SocketAddr>().ok());
let source_addr = match proxy_protocol_source {
ProxyProtocolSource::RemoteIp => remote_addr.unwrap_or(tunnel_addr),
ProxyProtocolSource::TunnelIp => tunnel_addr,
};
let metadata = if include_metadata {
client_info.map(|info| crate::proxy_protocol::VpnProxyMetadata {
client_id: info.registered_client_id,
assigned_ip: info.assigned_ip,
transport_type: info.transport_type,
remote_addr: info.remote_addr,
})
} else {
None
};
(source_addr, metadata)
}
async fn udp_bridge_task(
key: SessionKey,
mut data_rx: mpsc::Receiver<Vec<u8>>,
+76 -23
View File
@@ -215,6 +215,8 @@ pub enum WgCommand {
struct PeerState {
tunn: Tunn,
public_key_b64: String,
/// Registered SmartVPN client ID when this peer belongs to a registry entry.
client_id: Option<String>,
allowed_ips: Vec<AllowedIp>,
endpoint: Option<SocketAddr>,
#[allow(dead_code)]
@@ -237,6 +239,28 @@ impl PeerState {
}
}
fn synthetic_wg_client_id(pubkey: &str) -> String {
format!("wg-{}", &pubkey[..8.min(pubkey.len())])
}
fn peer_client_id(peer: &PeerState) -> String {
peer.client_id
.clone()
.unwrap_or_else(|| synthetic_wg_client_id(&peer.public_key_b64))
}
async fn registered_peer_info(
state: &Arc<ServerState>,
pubkey: &str,
) -> Option<(String, bool, bool)> {
let registry = state.client_registry.read().await;
registry.get_by_wg_key(pubkey).map(|entry| (
entry.client_id.clone(),
entry.use_host_ip.unwrap_or(false),
entry.is_enabled() && !entry.is_expired(),
))
}
fn add_peer_to_loop(
peers: &mut Vec<PeerState>,
@@ -281,6 +305,7 @@ fn add_peer_to_loop(
peers.push(PeerState {
tunn,
public_key_b64: config.public_key.clone(),
client_id: None,
allowed_ips,
endpoint,
persistent_keepalive: config.persistent_keepalive,
@@ -344,7 +369,7 @@ fn wg_timestamp_now() -> String {
/// Returns the VPN IP.
async fn register_wg_peer(
state: &Arc<ServerState>,
peer: &PeerState,
peer: &mut PeerState,
wg_return_tx: &mpsc::Sender<(String, Vec<u8>)>,
) -> Result<Option<Ipv4Addr>> {
let vpn_ip = match extract_peer_vpn_ip(&peer.allowed_ips) {
@@ -356,12 +381,24 @@ async fn register_wg_peer(
}
};
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
let (client_id, use_host_ip) = match registered_peer_info(state, &peer.public_key_b64).await {
Some((client_id, use_host_ip, true)) => (client_id, use_host_ip),
Some((client_id, _, false)) => {
warn!("WG peer {} maps to disabled or expired client {}, skipping registration", peer.public_key_b64, client_id);
return Ok(None);
}
None => (synthetic_wg_client_id(&peer.public_key_b64), false),
};
peer.client_id = Some(client_id.clone());
// Reserve IP in the pool
if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) {
warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e);
return Ok(None);
return Err(e);
}
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
routing_table.write().await.insert(vpn_ip, use_host_ip);
}
// Create per-peer return channel and register in tun_routes
@@ -393,7 +430,7 @@ async fn connect_wg_peer(
peer: &PeerState,
vpn_ip: Ipv4Addr,
) {
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
let client_id = peer_client_id(peer);
let client_info = ClientInfo {
client_id: client_id.clone(),
assigned_ip: vpn_ip.to_string(),
@@ -426,24 +463,27 @@ async fn connect_wg_peer(
/// Remove a WG peer from state.clients (disconnect without unregistering).
async fn disconnect_wg_peer(
state: &Arc<ServerState>,
pubkey: &str,
peer: &PeerState,
) {
let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
let client_id = peer_client_id(peer);
if state.clients.write().await.remove(&client_id).is_some() {
info!("WG peer {} disconnected (removed from active clients)", pubkey);
info!("WG peer {} disconnected (removed from active clients)", peer.public_key_b64);
}
}
/// Unregister a WG peer from ServerState.
async fn unregister_wg_peer(
state: &Arc<ServerState>,
pubkey: &str,
peer: &PeerState,
vpn_ip: Option<Ipv4Addr>,
) {
let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
let client_id = peer_client_id(peer);
if let Some(ip) = vpn_ip {
state.tun_routes.write().await.remove(&ip);
if let ForwardingEngine::Hybrid { routing_table, .. } = &*state.forwarding_engine.lock().await {
routing_table.write().await.remove(&ip);
}
state.ip_pool.lock().await.release(&ip);
}
state.clients.write().await.remove(&client_id);
@@ -501,6 +541,7 @@ pub async fn run_wg_listener(
peers.push(PeerState {
tunn,
public_key_b64: peer_config.public_key.clone(),
client_id: None,
allowed_ips,
endpoint,
persistent_keepalive: peer_config.persistent_keepalive,
@@ -523,9 +564,13 @@ pub async fn run_wg_listener(
// Register initial peers in ServerState (IP reservation + tun_routes only, NOT state.clients)
let mut peer_vpn_ips: HashMap<String, Ipv4Addr> = HashMap::new();
for peer in peers.iter_mut() {
if let Ok(Some(ip)) = register_wg_peer(&state, peer, &wg_return_tx).await {
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
peer.vpn_ip = Some(ip);
match register_wg_peer(&state, peer, &wg_return_tx).await {
Ok(Some(ip)) => {
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
peer.vpn_ip = Some(ip);
}
Ok(None) => {}
Err(e) => warn!("Failed to register initial WG peer {}: {}", peer.public_key_b64, e),
}
}
@@ -705,7 +750,7 @@ pub async fn run_wg_listener(
warn!("WG peer {} connection expired", peer.public_key_b64);
if peer.is_connected {
peer.is_connected = false;
disconnect_wg_peer(&state, &peer.public_key_b64).await;
disconnect_wg_peer(&state, peer).await;
}
}
TunnResult::Err(e) => {
@@ -733,7 +778,7 @@ pub async fn run_wg_listener(
}
// Only update ClientInfo if peer is connected (in state.clients)
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
let client_id = peer_client_id(peer);
if let Some(info) = clients.get_mut(&client_id) {
info.bytes_sent = peer.stats.bytes_sent;
info.bytes_received = peer.stats.bytes_received;
@@ -751,7 +796,7 @@ pub async fn run_wg_listener(
if now.duration_since(last) > std::time::Duration::from_secs(180) {
info!("WG peer {} idle timeout (180s), disconnecting", peer.public_key_b64);
peer.is_connected = false;
disconnect_wg_peer(&state, &peer.public_key_b64).await;
disconnect_wg_peer(&state, peer).await;
}
}
}
@@ -762,7 +807,12 @@ pub async fn run_wg_listener(
cmd = command_rx.recv() => {
match cmd {
Some(WgCommand::AddPeer(peer_config, resp_tx)) => {
let result = add_peer_to_loop(
if let Some((client_id, _, false)) = registered_peer_info(&state, &peer_config.public_key).await {
let _ = resp_tx.send(Err(anyhow!("WG peer maps to disabled or expired client: {}", client_id)));
continue;
}
let mut result = add_peer_to_loop(
&mut peers,
&peer_config,
&peer_index,
@@ -777,20 +827,23 @@ pub async fn run_wg_listener(
peer_vpn_ips.insert(peer_config.public_key.clone(), ip);
peer.vpn_ip = Some(ip);
}
Ok(None) => {}
Ok(None) => {
peers.retain(|p| p.public_key_b64 != peer_config.public_key);
result = Err(anyhow!("WG peer was not registered"));
}
Err(e) => {
warn!("Failed to register WG peer: {}", e);
peers.retain(|p| p.public_key_b64 != peer_config.public_key);
result = Err(e);
}
}
}
let _ = resp_tx.send(result);
}
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
let prev_len = peers.len();
peers.retain(|p| p.public_key_b64 != pubkey);
if peers.len() < prev_len {
if let Some(index) = peers.iter().position(|p| p.public_key_b64 == pubkey) {
let peer = peers.remove(index);
let vpn_ip = peer_vpn_ips.remove(&pubkey);
unregister_wg_peer(&state, &pubkey, vpn_ip).await;
unregister_wg_peer(&state, &peer, vpn_ip).await;
let _ = resp_tx.send(Ok(()));
} else {
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
@@ -814,7 +867,7 @@ pub async fn run_wg_listener(
// Cleanup: unregister all peers from ServerState
for peer in &peers {
let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied();
unregister_wg_peer(&state, &peer.public_key_b64, vpn_ip).await;
unregister_wg_peer(&state, peer, vpn_ip).await;
}
info!("WireGuard listener stopped");
+1 -1
View File
@@ -3,6 +3,6 @@
*/
export const commitinfo = {
name: '@push.rocks/smartvpn',
version: '1.19.3',
version: '1.20.0',
description: 'A VPN solution with TypeScript control plane and Rust data plane daemon'
}
+8 -3
View File
@@ -122,10 +122,15 @@ export interface IVpnServerConfig {
* Supports exact IPs, CIDR, wildcards, ranges. */
connectionIpBlockList?: string[];
/** When true and forwardingMode is 'socket', the userspace NAT engine prepends
* PROXY protocol v2 headers on outbound TCP connections, conveying the VPN client's
* tunnel IP as the source address. This allows downstream services (e.g. SmartProxy)
* to see the real VPN client identity instead of 127.0.0.1. */
* PROXY protocol v2 headers on outbound TCP connections. */
socketForwardProxyProtocol?: boolean;
/** Source address to place into outbound PROXY v2 headers.
* 'tunnelIp' preserves legacy behavior. 'remoteIp' exposes the VPN client's
* real connecting IP when known, with tunnel IP fallback. */
socketForwardProxyProtocolSource?: 'tunnelIp' | 'remoteIp';
/** When true, outbound PROXY v2 headers include authenticated SmartVPN metadata
* in a vendor TLV: clientId, assignedIp, transportType, and remoteAddr. */
socketForwardProxyProtocolVpnMetadata?: boolean;
/** Destination routing policy for VPN client traffic (socket mode).
* Controls where decrypted traffic goes: allow through, block, or redirect to a target.
* Default: all traffic passes through (backward compatible). */