feat(server): unify WireGuard into the shared server transport pipeline

This commit is contained in:
2026-03-30 06:52:20 +00:00
parent 70e838c8ff
commit 79d9928485
8 changed files with 608 additions and 634 deletions

View File

@@ -58,6 +58,12 @@ pub struct ServerConfig {
pub proxy_protocol: Option<bool>,
/// Server-level IP block list — applied at TCP accept, before Noise handshake.
pub connection_ip_block_list: Option<Vec<String>>,
/// WireGuard: server X25519 private key (base64). Required when transport includes WG.
pub wg_private_key: Option<String>,
/// WireGuard: UDP listen port (default: 51820).
pub wg_listen_port: Option<u16>,
/// WireGuard: pre-configured peers.
pub wg_peers: Option<Vec<crate::wireguard::WgPeerConfig>>,
}
/// Information about a connected client.
@@ -81,6 +87,8 @@ pub struct ClientInfo {
pub registered_client_id: String,
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
pub remote_addr: Option<String>,
/// Transport used for this connection: "websocket", "quic", or "wireguard".
pub transport_type: String,
}
/// Server statistics.
@@ -130,6 +138,7 @@ pub struct ServerState {
pub struct VpnServer {
state: Option<Arc<ServerState>>,
shutdown_tx: Option<mpsc::Sender<()>>,
wg_command_tx: Option<mpsc::Sender<crate::wireguard::WgCommand>>,
}
impl VpnServer {
@@ -137,6 +146,7 @@ impl VpnServer {
Self {
state: None,
shutdown_tx: None,
wg_command_tx: None,
}
}
@@ -255,59 +265,79 @@ impl VpnServer {
ForwardingSetup::Testing => {}
}
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
self.state = Some(state.clone());
self.shutdown_tx = Some(shutdown_tx);
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
let transport_mode = config.transport_mode.as_deref().unwrap_or("all");
let listen_addr = config.listen_addr.clone();
match transport_mode {
"quic" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
error!("QUIC listener error: {}", e);
}
});
}
"both" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
let state2 = state.clone();
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
// Store second shutdown sender so both listeners stop
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
self.shutdown_tx = Some(combined_tx);
// Determine if WG should be included
let include_wg = config.wg_private_key.is_some()
&& matches!(transport_mode, "all" | "wireguard");
// Forward combined shutdown to both listeners
tokio::spawn(async move {
combined_rx.recv().await;
let _ = shutdown_tx_orig.send(()).await;
let _ = shutdown_tx2.send(()).await;
});
// Collect shutdown senders for all listeners
let mut listener_shutdown_txs: Vec<mpsc::Sender<()>> = Vec::new();
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("WebSocket listener error: {}", e);
}
});
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
error!("QUIC listener error: {}", e);
}
});
}
_ => {
// "websocket" (default)
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("Server listener error: {}", e);
}
});
}
// Spawn transport listeners based on mode
let spawn_ws = matches!(transport_mode, "all" | "both" | "websocket");
let spawn_quic = matches!(transport_mode, "all" | "both" | "quic");
if spawn_ws {
let (tx, mut rx) = mpsc::channel::<()>(1);
listener_shutdown_txs.push(tx);
let ws_state = state.clone();
let ws_addr = listen_addr.clone();
tokio::spawn(async move {
if let Err(e) = run_ws_listener(ws_state, ws_addr, &mut rx).await {
error!("WebSocket listener error: {}", e);
}
});
}
if spawn_quic {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
let (tx, mut rx) = mpsc::channel::<()>(1);
listener_shutdown_txs.push(tx);
let quic_state = state.clone();
tokio::spawn(async move {
if let Err(e) = run_quic_listener(quic_state, quic_addr, idle_timeout, &mut rx).await {
error!("QUIC listener error: {}", e);
}
});
}
if include_wg {
let wg_config = crate::wireguard::WgListenerConfig {
private_key: config.wg_private_key.clone().unwrap(),
listen_port: config.wg_listen_port.unwrap_or(51820),
peers: config.wg_peers.clone().unwrap_or_default(),
};
let (tx, rx) = mpsc::channel::<()>(1);
listener_shutdown_txs.push(tx);
let (cmd_tx, cmd_rx) = mpsc::channel::<crate::wireguard::WgCommand>(32);
self.wg_command_tx = Some(cmd_tx);
let wg_state = state.clone();
tokio::spawn(async move {
if let Err(e) = crate::wireguard::run_wg_listener(wg_state, wg_config, rx, cmd_rx).await {
error!("WireGuard listener error: {}", e);
}
});
}
// Replace self.shutdown_tx with a combined sender that fans out to all listeners
if listener_shutdown_txs.len() > 1 {
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
// Take the original shutdown_tx (from line above)
let _ = self.shutdown_tx.take();
self.shutdown_tx = Some(combined_tx);
tokio::spawn(async move {
combined_rx.recv().await;
for tx in listener_shutdown_txs {
let _ = tx.send(()).await;
}
});
} else if let Some(single_tx) = listener_shutdown_txs.into_iter().next() {
self.shutdown_tx = Some(single_tx);
}
info!("VPN server started (transport: {})", transport_mode);
@@ -346,6 +376,7 @@ impl VpnServer {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
self.wg_command_tx = None;
self.state = None;
info!("VPN server stopped");
Ok(())
@@ -434,6 +465,54 @@ impl VpnServer {
Ok(())
}
// ── WireGuard Peer Management ────────────────────────────────────────
/// Add a WireGuard peer dynamically (delegates to the WG event loop).
pub async fn add_wg_peer(&self, config: crate::wireguard::WgPeerConfig) -> Result<()> {
let tx = self.wg_command_tx.as_ref()
.ok_or_else(|| anyhow::anyhow!("WireGuard listener not running"))?;
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
tx.send(crate::wireguard::WgCommand::AddPeer(config, resp_tx))
.await
.map_err(|_| anyhow::anyhow!("WG event loop closed"))?;
resp_rx.await.map_err(|_| anyhow::anyhow!("No response from WG loop"))?
}
/// Remove a WireGuard peer dynamically (delegates to the WG event loop).
pub async fn remove_wg_peer(&self, public_key: &str) -> Result<()> {
let tx = self.wg_command_tx.as_ref()
.ok_or_else(|| anyhow::anyhow!("WireGuard listener not running"))?;
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
tx.send(crate::wireguard::WgCommand::RemovePeer(public_key.to_string(), resp_tx))
.await
.map_err(|_| anyhow::anyhow!("WG event loop closed"))?;
resp_rx.await.map_err(|_| anyhow::anyhow!("No response from WG loop"))?
}
/// List WireGuard peers from the unified client list.
pub async fn list_wg_peers(&self) -> Vec<crate::wireguard::WgPeerInfo> {
if let Some(ref state) = self.state {
state.clients.read().await.values()
.filter(|c| c.transport_type == "wireguard")
.map(|c| crate::wireguard::WgPeerInfo {
public_key: c.authenticated_key.clone(),
allowed_ips: vec![format!("{}/32", c.assigned_ip)],
endpoint: c.remote_addr.clone(),
persistent_keepalive: None,
stats: crate::wireguard::WgPeerStats {
bytes_sent: c.bytes_sent,
bytes_received: c.bytes_received,
packets_sent: 0,
packets_received: 0,
last_handshake_time: None,
},
})
.collect()
} else {
Vec::new()
}
}
// ── Client Registry (Hub) Methods ───────────────────────────────────
/// Create a new client entry. Generates keypairs and assigns an IP.
@@ -751,6 +830,7 @@ async fn run_ws_listener(
Box::new(sink),
Box::new(stream),
remote_addr,
"websocket",
).await {
warn!("Client connection error: {}", e);
}
@@ -827,6 +907,7 @@ async fn run_quic_listener(
Box::new(sink),
Box::new(stream),
Some(remote),
"quic",
).await {
warn!("QUIC client error: {}", e);
}
@@ -916,6 +997,7 @@ async fn handle_client_connection(
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
remote_addr: Option<std::net::SocketAddr>,
transport_type: &str,
) -> Result<()> {
let server_private_key = base64::Engine::decode(
&base64::engine::general_purpose::STANDARD,
@@ -1054,6 +1136,7 @@ async fn handle_client_connection(
authenticated_key: client_pub_key_b64.clone(),
registered_client_id: registered_client_id.clone(),
remote_addr: remote_addr.map(|a| a.to_string()),
transport_type: transport_type.to_string(),
};
state.clients.write().await.insert(client_id.clone(), client_info);