feat(auth,client-registry): add Noise IK client authentication with managed client registry and per-client ACL controls

This commit is contained in:
2026-03-29 17:04:27 +00:00
parent 187a69028b
commit 01a0d8b9f4
20 changed files with 1930 additions and 897 deletions

View File

@@ -9,6 +9,8 @@ use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tracing::{info, error, warn};
use crate::acl;
use crate::client_registry::{ClientEntry, ClientRegistry};
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::mtu::{MtuConfig, TunnelOverhead};
@@ -45,6 +47,8 @@ pub struct ServerConfig {
pub quic_listen_addr: Option<String>,
/// QUIC idle timeout in seconds (default: 30).
pub quic_idle_timeout_secs: Option<u64>,
/// Pre-registered clients for IK authentication.
pub clients: Option<Vec<ClientEntry>>,
}
/// Information about a connected client.
@@ -62,6 +66,10 @@ pub struct ClientInfo {
pub keepalives_received: u64,
pub rate_limit_bytes_per_sec: Option<u64>,
pub burst_bytes: Option<u64>,
/// Client's authenticated Noise IK public key (base64).
pub authenticated_key: String,
/// Registered client ID from the client registry.
pub registered_client_id: String,
}
/// Server statistics.
@@ -88,6 +96,7 @@ pub struct ServerState {
pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
pub mtu_config: MtuConfig,
pub started_at: std::time::Instant,
pub client_registry: RwLock<ClientRegistry>,
}
/// The VPN server.
@@ -127,6 +136,12 @@ impl VpnServer {
let overhead = TunnelOverhead::default_overhead();
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
// Build client registry from config
let registry = ClientRegistry::from_entries(
config.clients.clone().unwrap_or_default()
)?;
info!("Client registry loaded with {} entries", registry.len());
let state = Arc::new(ServerState {
config: config.clone(),
ip_pool: Mutex::new(ip_pool),
@@ -135,6 +150,7 @@ impl VpnServer {
rate_limiters: Mutex::new(HashMap::new()),
mtu_config,
started_at: std::time::Instant::now(),
client_registry: RwLock::new(registry),
});
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
@@ -287,6 +303,263 @@ impl VpnServer {
}
Ok(())
}
// ── Client Registry (Hub) Methods ───────────────────────────────────
/// Create a new client entry. Generates keypairs and assigns an IP.
/// Returns a JSON value with the full config bundle including secrets.
pub async fn create_client(&self, partial: serde_json::Value) -> Result<serde_json::Value> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let client_id = partial.get("clientId")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("clientId is required"))?
.to_string();
// Generate Noise IK keypair for the client
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
// Generate WireGuard keypair for the client
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
// Allocate a VPN IP
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
// Build entry from partial + generated values
let entry = ClientEntry {
client_id: client_id.clone(),
public_key: noise_pub.clone(),
wg_public_key: Some(wg_pub.clone()),
security: serde_json::from_value(
partial.get("security").cloned().unwrap_or(serde_json::Value::Null)
).ok(),
priority: partial.get("priority").and_then(|v| v.as_u64()).map(|v| v as u32),
enabled: partial.get("enabled").and_then(|v| v.as_bool()).or(Some(true)),
tags: partial.get("tags").and_then(|v| {
v.as_array().map(|a| a.iter().filter_map(|s| s.as_str().map(String::from)).collect())
}),
description: partial.get("description").and_then(|v| v.as_str()).map(String::from),
expires_at: partial.get("expiresAt").and_then(|v| v.as_str()).map(String::from),
assigned_ip: Some(assigned_ip.to_string()),
};
// Add to registry
state.client_registry.write().await.add(entry.clone())?;
// Build SmartVPN client config
let smartvpn_config = serde_json::json!({
"serverUrl": format!("wss://{}",
state.config.listen_addr.replace("0.0.0.0", "localhost")),
"serverPublicKey": state.config.public_key,
"clientPrivateKey": noise_priv,
"clientPublicKey": noise_pub,
"dns": state.config.dns,
"mtu": state.config.mtu,
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
});
// Build WireGuard config string
let wg_config = format!(
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
wg_priv,
assigned_ip,
state.config.dns.as_ref()
.map(|d| format!("DNS = {}", d.join(", ")))
.unwrap_or_default(),
state.config.public_key,
state.config.listen_addr,
);
let entry_json = serde_json::to_value(&entry)?;
Ok(serde_json::json!({
"entry": entry_json,
"smartvpnConfig": smartvpn_config,
"wireguardConfig": wg_config,
"secrets": {
"noisePrivateKey": noise_priv,
"wgPrivateKey": wg_priv,
}
}))
}
/// Remove a registered client from the registry (and disconnect if connected).
pub async fn remove_registered_client(&self, client_id: &str) -> Result<()> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let entry = state.client_registry.write().await.remove(client_id)?;
// Release the IP if assigned
if let Some(ref ip_str) = entry.assigned_ip {
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
state.ip_pool.lock().await.release(&ip);
}
}
// Disconnect if currently connected
let _ = self.disconnect_client(client_id).await;
Ok(())
}
/// Get a registered client by ID.
pub async fn get_registered_client(&self, client_id: &str) -> Result<serde_json::Value> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let registry = state.client_registry.read().await;
let entry = registry.get_by_id(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
Ok(serde_json::to_value(entry)?)
}
/// List all registered clients.
pub async fn list_registered_clients(&self) -> Vec<ClientEntry> {
if let Some(ref state) = self.state {
state.client_registry.read().await.list().into_iter().cloned().collect()
} else {
Vec::new()
}
}
/// Update a registered client's fields.
pub async fn update_registered_client(&self, client_id: &str, update: serde_json::Value) -> Result<()> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
state.client_registry.write().await.update(client_id, |entry| {
if let Some(security) = update.get("security") {
entry.security = serde_json::from_value(security.clone()).ok();
}
if let Some(priority) = update.get("priority").and_then(|v| v.as_u64()) {
entry.priority = Some(priority as u32);
}
if let Some(enabled) = update.get("enabled").and_then(|v| v.as_bool()) {
entry.enabled = Some(enabled);
}
if let Some(tags) = update.get("tags").and_then(|v| v.as_array()) {
entry.tags = Some(tags.iter().filter_map(|s| s.as_str().map(String::from)).collect());
}
if let Some(desc) = update.get("description").and_then(|v| v.as_str()) {
entry.description = Some(desc.to_string());
}
if let Some(expires) = update.get("expiresAt").and_then(|v| v.as_str()) {
entry.expires_at = Some(expires.to_string());
}
})?;
Ok(())
}
/// Enable a registered client.
pub async fn enable_client(&self, client_id: &str) -> Result<()> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
state.client_registry.write().await.update(client_id, |entry| {
entry.enabled = Some(true);
})
}
/// Disable a registered client (also disconnects if connected).
pub async fn disable_client(&self, client_id: &str) -> Result<()> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
state.client_registry.write().await.update(client_id, |entry| {
entry.enabled = Some(false);
})?;
// Disconnect if currently connected
let _ = self.disconnect_client(client_id).await;
Ok(())
}
/// Rotate a client's keys. Returns a new config bundle with fresh keypairs.
pub async fn rotate_client_key(&self, client_id: &str) -> Result<serde_json::Value> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let (noise_pub, noise_priv) = crypto::generate_keypair()?;
let (wg_pub, wg_priv) = crate::wireguard::generate_wg_keypair();
state.client_registry.write().await.rotate_key(
client_id,
noise_pub.clone(),
Some(wg_pub.clone()),
)?;
// Disconnect existing connection (old key is no longer valid)
let _ = self.disconnect_client(client_id).await;
// Get updated entry for the config bundle
let entry_json = self.get_registered_client(client_id).await?;
let assigned_ip = entry_json.get("assignedIp")
.and_then(|v| v.as_str())
.unwrap_or("0.0.0.0");
let smartvpn_config = serde_json::json!({
"serverUrl": format!("wss://{}",
state.config.listen_addr.replace("0.0.0.0", "localhost")),
"serverPublicKey": state.config.public_key,
"clientPrivateKey": noise_priv,
"clientPublicKey": noise_pub,
"dns": state.config.dns,
"mtu": state.config.mtu,
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
});
let wg_config = format!(
"[Interface]\nPrivateKey = {}\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
wg_priv, assigned_ip,
state.config.dns.as_ref()
.map(|d| format!("DNS = {}", d.join(", ")))
.unwrap_or_default(),
state.config.public_key,
state.config.listen_addr,
);
Ok(serde_json::json!({
"entry": entry_json,
"smartvpnConfig": smartvpn_config,
"wireguardConfig": wg_config,
"secrets": {
"noisePrivateKey": noise_priv,
"wgPrivateKey": wg_priv,
}
}))
}
/// Export a client config (without secrets) in the specified format.
pub async fn export_client_config(&self, client_id: &str, format: &str) -> Result<serde_json::Value> {
let state = self.state.as_ref()
.ok_or_else(|| anyhow::anyhow!("Server not running"))?;
let registry = state.client_registry.read().await;
let entry = registry.get_by_id(client_id)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", client_id))?;
match format {
"smartvpn" => {
Ok(serde_json::json!({
"config": {
"serverUrl": format!("wss://{}",
state.config.listen_addr.replace("0.0.0.0", "localhost")),
"serverPublicKey": state.config.public_key,
"clientPublicKey": entry.public_key,
"dns": state.config.dns,
"mtu": state.config.mtu,
"keepaliveIntervalSecs": state.config.keepalive_interval_secs,
}
}))
}
"wireguard" => {
let assigned_ip = entry.assigned_ip.as_deref().unwrap_or("0.0.0.0");
let config = format!(
"[Interface]\nAddress = {}/24\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\nPersistentKeepalive = 25\n",
assigned_ip,
state.config.dns.as_ref()
.map(|d| format!("DNS = {}", d.join(", ")))
.unwrap_or_default(),
state.config.public_key,
state.config.listen_addr,
);
Ok(serde_json::json!({ "config": config }))
}
_ => anyhow::bail!("Unknown format: {}", format),
}
}
}
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
@@ -421,26 +694,23 @@ async fn run_quic_listener(
Ok(())
}
/// Transport-agnostic client handler. Performs the Noise NK handshake, registers
/// the client, and runs the main packet forwarding loop.
/// Transport-agnostic client handler. Performs the Noise IK handshake, authenticates
/// the client against the registry, and runs the main packet forwarding loop.
async fn handle_client_connection(
state: Arc<ServerState>,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
) -> Result<()> {
let client_id = uuid_v4();
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
let server_private_key = base64::Engine::decode(
&base64::engine::general_purpose::STANDARD,
&state.config.private_key,
)?;
// Noise IK handshake (server side = responder)
let mut responder = crypto::create_responder(&server_private_key)?;
let mut buf = vec![0u8; 65535];
// Receive handshake init
// Receive handshake init (-> e, es, s, ss)
let init_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before handshake"),
@@ -455,6 +725,47 @@ async fn handle_client_connection(
}
responder.read_message(&frame.payload, &mut buf)?;
// Extract client's static public key BEFORE entering transport mode
let client_pub_key_bytes = responder
.get_remote_static()
.ok_or_else(|| anyhow::anyhow!("IK handshake: no client static key received"))?
.to_vec();
let client_pub_key_b64 = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
&client_pub_key_bytes,
);
// Verify client against registry
let (registered_client_id, client_security) = {
let registry = state.client_registry.read().await;
if !registry.is_authorized(&client_pub_key_b64) {
warn!("Rejecting unauthorized client with key {}", &client_pub_key_b64[..8]);
// Send handshake response but then disconnect
let len = responder.write_message(&[], &mut buf)?;
let response_frame = Frame {
packet_type: PacketType::HandshakeResp,
payload: buf[..len].to_vec(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
sink.send_reliable(frame_bytes.to_vec()).await?;
// Send disconnect frame
let disconnect_frame = Frame {
packet_type: PacketType::Disconnect,
payload: Vec::new(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, disconnect_frame, &mut frame_bytes)?;
let _ = sink.send_reliable(frame_bytes.to_vec()).await;
anyhow::bail!("Client not authorized");
}
let entry = registry.get_by_key(&client_pub_key_b64).unwrap();
(entry.client_id.clone(), entry.security.clone())
};
// Complete handshake (<- e, ee, se)
let len = responder.write_message(&[], &mut buf)?;
let response_payload = buf[..len].to_vec();
@@ -468,9 +779,24 @@ async fn handle_client_connection(
let mut noise_transport = responder.into_transport_mode()?;
// Register client
let default_rate = state.config.default_rate_limit_bytes_per_sec;
let default_burst = state.config.default_burst_bytes;
// Use the registered client ID as the connection ID
let client_id = registered_client_id.clone();
// Allocate IP
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
// Determine rate limits: per-client security overrides server defaults
let (rate_limit, burst) = if let Some(ref sec) = client_security {
if let Some(ref rl) = sec.rate_limit {
(Some(rl.bytes_per_sec), Some(rl.burst_bytes))
} else {
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
}
} else {
(state.config.default_rate_limit_bytes_per_sec, state.config.default_burst_bytes)
};
// Register connected client
let client_info = ClientInfo {
client_id: client_id.clone(),
assigned_ip: assigned_ip.to_string(),
@@ -481,13 +807,15 @@ async fn handle_client_connection(
bytes_dropped: 0,
last_keepalive_at: None,
keepalives_received: 0,
rate_limit_bytes_per_sec: default_rate,
burst_bytes: default_burst,
rate_limit_bytes_per_sec: rate_limit,
burst_bytes: burst,
authenticated_key: client_pub_key_b64.clone(),
registered_client_id: registered_client_id.clone(),
};
state.clients.write().await.insert(client_id.clone(), client_info);
// Set up rate limiter if defaults are configured
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
// Set up rate limiter
if let (Some(rate), Some(burst)) = (rate_limit, burst) {
state
.rate_limiters
.lock()
@@ -517,7 +845,7 @@ async fn handle_client_connection(
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
sink.send_reliable(frame_bytes.to_vec()).await?;
info!("Client {} connected with IP {}", client_id, assigned_ip);
info!("Client {} ({}) connected with IP {}", registered_client_id, &client_pub_key_b64[..8], assigned_ip);
// Main packet loop with dead-peer detection
let mut last_activity = tokio::time::Instant::now();
@@ -534,6 +862,24 @@ async fn handle_client_connection(
PacketType::IpPacket => {
match noise_transport.read_message(&frame.payload, &mut buf) {
Ok(len) => {
// ACL check on decrypted packet
if let Some(ref sec) = client_security {
if len >= 20 {
// Extract src/dst from IPv4 header
let src = Ipv4Addr::new(buf[12], buf[13], buf[14], buf[15]);
let dst = Ipv4Addr::new(buf[16], buf[17], buf[18], buf[19]);
let acl_result = acl::check_acl(sec, src, dst);
if acl_result != acl::AclResult::Allow {
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.packets_dropped += 1;
info.bytes_dropped += len as u64;
}
continue;
}
}
}
// Rate limiting check
let allowed = {
let mut limiters = state.rate_limiters.lock().await;
@@ -635,20 +981,6 @@ async fn handle_client_connection(
Ok(())
}
fn uuid_v4() -> String {
use rand::Rng;
let mut rng = rand::thread_rng();
let bytes: [u8; 16] = rng.gen();
format!(
"{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
bytes[0], bytes[1], bytes[2], bytes[3],
bytes[4], bytes[5],
bytes[6], bytes[7],
bytes[8], bytes[9],
bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
)
}
fn timestamp_now() -> String {
use std::time::SystemTime;
let duration = SystemTime::now()