feat(server): unify WireGuard into the shared server transport pipeline
This commit is contained in:
@@ -7,7 +7,7 @@ use tracing::{info, error, warn};
|
||||
use crate::client::{ClientConfig, VpnClient};
|
||||
use crate::crypto;
|
||||
use crate::server::{ServerConfig, VpnServer};
|
||||
use crate::wireguard::{self, WgClient, WgClientConfig, WgPeerConfig, WgServer, WgServerConfig};
|
||||
use crate::wireguard::{self, WgClient, WgClientConfig, WgPeerConfig};
|
||||
|
||||
// ============================================================================
|
||||
// IPC protocol types
|
||||
@@ -95,7 +95,6 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
let mut vpn_client = VpnClient::new();
|
||||
let mut vpn_server = VpnServer::new();
|
||||
let mut wg_client = WgClient::new();
|
||||
let mut wg_server = WgServer::new();
|
||||
|
||||
send_event_stdout("ready", serde_json::json!({ "mode": mode }));
|
||||
|
||||
@@ -131,7 +130,7 @@ pub async fn management_loop_stdio(mode: &str) -> Result<()> {
|
||||
|
||||
let response = match mode {
|
||||
"client" => handle_client_request(&request, &mut vpn_client, &mut wg_client).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server, &mut wg_server).await,
|
||||
"server" => handle_server_request(&request, &mut vpn_server).await,
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
send_response_stdout(&response);
|
||||
@@ -154,7 +153,6 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
let vpn_client = std::sync::Arc::new(Mutex::new(VpnClient::new()));
|
||||
let vpn_server = std::sync::Arc::new(Mutex::new(VpnServer::new()));
|
||||
let wg_client = std::sync::Arc::new(Mutex::new(WgClient::new()));
|
||||
let wg_server = std::sync::Arc::new(Mutex::new(WgServer::new()));
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
@@ -163,10 +161,9 @@ pub async fn management_loop_socket(socket_path: &str, mode: &str) -> Result<()>
|
||||
let client = vpn_client.clone();
|
||||
let server = vpn_server.clone();
|
||||
let wg_c = wg_client.clone();
|
||||
let wg_s = wg_server.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
handle_socket_connection(stream, &mode, client, server, wg_c, wg_s).await
|
||||
handle_socket_connection(stream, &mode, client, server, wg_c).await
|
||||
{
|
||||
warn!("Socket connection error: {}", e);
|
||||
}
|
||||
@@ -185,7 +182,6 @@ async fn handle_socket_connection(
|
||||
vpn_client: std::sync::Arc<Mutex<VpnClient>>,
|
||||
vpn_server: std::sync::Arc<Mutex<VpnServer>>,
|
||||
wg_client: std::sync::Arc<Mutex<WgClient>>,
|
||||
wg_server: std::sync::Arc<Mutex<WgServer>>,
|
||||
) -> Result<()> {
|
||||
let (reader, mut writer) = stream.into_split();
|
||||
let buf_reader = BufReader::new(reader);
|
||||
@@ -241,8 +237,7 @@ async fn handle_socket_connection(
|
||||
}
|
||||
"server" => {
|
||||
let mut server = vpn_server.lock().await;
|
||||
let mut wg_s = wg_server.lock().await;
|
||||
handle_server_request(&request, &mut server, &mut wg_s).await
|
||||
handle_server_request(&request, &mut server).await
|
||||
}
|
||||
_ => ManagementResponse::err(request.id.clone(), format!("Unknown mode: {}", mode)),
|
||||
};
|
||||
@@ -381,92 +376,46 @@ async fn handle_client_request(
|
||||
async fn handle_server_request(
|
||||
request: &ManagementRequest,
|
||||
vpn_server: &mut VpnServer,
|
||||
wg_server: &mut WgServer,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"start" => {
|
||||
// Check if transportMode is "wireguard"
|
||||
let transport_mode = request.params
|
||||
.get("config")
|
||||
.and_then(|c| c.get("transportMode"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if transport_mode == "wireguard" {
|
||||
let config: WgServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid WG config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG start failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
let config: ServerConfig = match serde_json::from_value(
|
||||
request.params.get("config").cloned().unwrap_or_default(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(id, format!("Invalid config: {}", e));
|
||||
}
|
||||
};
|
||||
match vpn_server.start(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Start failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"stop" => {
|
||||
if wg_server.is_running() {
|
||||
match wg_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("WG stop failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
}
|
||||
match vpn_server.stop().await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Stop failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"getStatus" => {
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_status())
|
||||
} else {
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
let status = vpn_server.get_status();
|
||||
ManagementResponse::ok(id, status)
|
||||
}
|
||||
"getStatistics" => {
|
||||
if wg_server.is_running() {
|
||||
ManagementResponse::ok(id, wg_server.get_statistics().await)
|
||||
} else {
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
let stats = vpn_server.get_statistics().await;
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id, v),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
"listClients" => {
|
||||
if wg_server.is_running() {
|
||||
let peers = wg_server.list_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
} else {
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
let clients = vpn_server.list_clients().await;
|
||||
match serde_json::to_value(&clients) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "clients": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
}
|
||||
}
|
||||
"disconnectClient" => {
|
||||
@@ -546,9 +495,6 @@ async fn handle_server_request(
|
||||
)
|
||||
}
|
||||
"addWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let config: WgPeerConfig = match serde_json::from_value(
|
||||
request.params.get("peer").cloned().unwrap_or_default(),
|
||||
) {
|
||||
@@ -557,29 +503,23 @@ async fn handle_server_request(
|
||||
return ManagementResponse::err(id, format!("Invalid peer config: {}", e));
|
||||
}
|
||||
};
|
||||
match wg_server.add_peer(config).await {
|
||||
match vpn_server.add_wg_peer(config).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Add peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"removeWgPeer" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let public_key = match request.params.get("publicKey").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id, "Missing publicKey".to_string()),
|
||||
};
|
||||
match wg_server.remove_peer(&public_key).await {
|
||||
match vpn_server.remove_wg_peer(&public_key).await {
|
||||
Ok(()) => ManagementResponse::ok(id, serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id, format!("Remove peer failed: {}", e)),
|
||||
}
|
||||
}
|
||||
"listWgPeers" => {
|
||||
if !wg_server.is_running() {
|
||||
return ManagementResponse::err(id, "WireGuard server not running".to_string());
|
||||
}
|
||||
let peers = wg_server.list_peers().await;
|
||||
let peers = vpn_server.list_wg_peers().await;
|
||||
match serde_json::to_value(&peers) {
|
||||
Ok(v) => ManagementResponse::ok(id, serde_json::json!({ "peers": v })),
|
||||
Err(e) => ManagementResponse::err(id, format!("Serialize error: {}", e)),
|
||||
|
||||
@@ -86,6 +86,16 @@ impl IpPool {
|
||||
client_id
|
||||
}
|
||||
|
||||
/// Reserve a specific IP for a client (e.g., WireGuard static IP from allowed_ips).
|
||||
pub fn reserve(&mut self, ip: Ipv4Addr, client_id: &str) -> Result<()> {
|
||||
if self.allocated.contains_key(&ip) {
|
||||
anyhow::bail!("IP {} is already allocated", ip);
|
||||
}
|
||||
self.allocated.insert(ip, client_id.to_string());
|
||||
info!("Reserved IP {} for client {}", ip, client_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of currently allocated IPs.
|
||||
pub fn allocated_count(&self) -> usize {
|
||||
self.allocated.len()
|
||||
|
||||
@@ -58,6 +58,12 @@ pub struct ServerConfig {
|
||||
pub proxy_protocol: Option<bool>,
|
||||
/// Server-level IP block list — applied at TCP accept, before Noise handshake.
|
||||
pub connection_ip_block_list: Option<Vec<String>>,
|
||||
/// WireGuard: server X25519 private key (base64). Required when transport includes WG.
|
||||
pub wg_private_key: Option<String>,
|
||||
/// WireGuard: UDP listen port (default: 51820).
|
||||
pub wg_listen_port: Option<u16>,
|
||||
/// WireGuard: pre-configured peers.
|
||||
pub wg_peers: Option<Vec<crate::wireguard::WgPeerConfig>>,
|
||||
}
|
||||
|
||||
/// Information about a connected client.
|
||||
@@ -81,6 +87,8 @@ pub struct ClientInfo {
|
||||
pub registered_client_id: String,
|
||||
/// Real client IP:port (from PROXY protocol header or direct TCP connection).
|
||||
pub remote_addr: Option<String>,
|
||||
/// Transport used for this connection: "websocket", "quic", or "wireguard".
|
||||
pub transport_type: String,
|
||||
}
|
||||
|
||||
/// Server statistics.
|
||||
@@ -130,6 +138,7 @@ pub struct ServerState {
|
||||
pub struct VpnServer {
|
||||
state: Option<Arc<ServerState>>,
|
||||
shutdown_tx: Option<mpsc::Sender<()>>,
|
||||
wg_command_tx: Option<mpsc::Sender<crate::wireguard::WgCommand>>,
|
||||
}
|
||||
|
||||
impl VpnServer {
|
||||
@@ -137,6 +146,7 @@ impl VpnServer {
|
||||
Self {
|
||||
state: None,
|
||||
shutdown_tx: None,
|
||||
wg_command_tx: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,59 +265,79 @@ impl VpnServer {
|
||||
ForwardingSetup::Testing => {}
|
||||
}
|
||||
|
||||
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
||||
self.state = Some(state.clone());
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
|
||||
let transport_mode = config.transport_mode.as_deref().unwrap_or("all");
|
||||
let listen_addr = config.listen_addr.clone();
|
||||
|
||||
match transport_mode {
|
||||
"quic" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
"both" => {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
let state2 = state.clone();
|
||||
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
|
||||
// Store second shutdown sender so both listeners stop
|
||||
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
|
||||
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
|
||||
self.shutdown_tx = Some(combined_tx);
|
||||
// Determine if WG should be included
|
||||
let include_wg = config.wg_private_key.is_some()
|
||||
&& matches!(transport_mode, "all" | "wireguard");
|
||||
|
||||
// Forward combined shutdown to both listeners
|
||||
tokio::spawn(async move {
|
||||
combined_rx.recv().await;
|
||||
let _ = shutdown_tx_orig.send(()).await;
|
||||
let _ = shutdown_tx2.send(()).await;
|
||||
});
|
||||
// Collect shutdown senders for all listeners
|
||||
let mut listener_shutdown_txs: Vec<mpsc::Sender<()>> = Vec::new();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("WebSocket listener error: {}", e);
|
||||
}
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
// "websocket" (default)
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
|
||||
error!("Server listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
// Spawn transport listeners based on mode
|
||||
let spawn_ws = matches!(transport_mode, "all" | "both" | "websocket");
|
||||
let spawn_quic = matches!(transport_mode, "all" | "both" | "quic");
|
||||
|
||||
if spawn_ws {
|
||||
let (tx, mut rx) = mpsc::channel::<()>(1);
|
||||
listener_shutdown_txs.push(tx);
|
||||
let ws_state = state.clone();
|
||||
let ws_addr = listen_addr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_ws_listener(ws_state, ws_addr, &mut rx).await {
|
||||
error!("WebSocket listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if spawn_quic {
|
||||
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
|
||||
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
|
||||
let (tx, mut rx) = mpsc::channel::<()>(1);
|
||||
listener_shutdown_txs.push(tx);
|
||||
let quic_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_quic_listener(quic_state, quic_addr, idle_timeout, &mut rx).await {
|
||||
error!("QUIC listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if include_wg {
|
||||
let wg_config = crate::wireguard::WgListenerConfig {
|
||||
private_key: config.wg_private_key.clone().unwrap(),
|
||||
listen_port: config.wg_listen_port.unwrap_or(51820),
|
||||
peers: config.wg_peers.clone().unwrap_or_default(),
|
||||
};
|
||||
let (tx, rx) = mpsc::channel::<()>(1);
|
||||
listener_shutdown_txs.push(tx);
|
||||
let (cmd_tx, cmd_rx) = mpsc::channel::<crate::wireguard::WgCommand>(32);
|
||||
self.wg_command_tx = Some(cmd_tx);
|
||||
let wg_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = crate::wireguard::run_wg_listener(wg_state, wg_config, rx, cmd_rx).await {
|
||||
error!("WireGuard listener error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Replace self.shutdown_tx with a combined sender that fans out to all listeners
|
||||
if listener_shutdown_txs.len() > 1 {
|
||||
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
|
||||
// Take the original shutdown_tx (from line above)
|
||||
let _ = self.shutdown_tx.take();
|
||||
self.shutdown_tx = Some(combined_tx);
|
||||
tokio::spawn(async move {
|
||||
combined_rx.recv().await;
|
||||
for tx in listener_shutdown_txs {
|
||||
let _ = tx.send(()).await;
|
||||
}
|
||||
});
|
||||
} else if let Some(single_tx) = listener_shutdown_txs.into_iter().next() {
|
||||
self.shutdown_tx = Some(single_tx);
|
||||
}
|
||||
|
||||
info!("VPN server started (transport: {})", transport_mode);
|
||||
@@ -346,6 +376,7 @@ impl VpnServer {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(()).await;
|
||||
}
|
||||
self.wg_command_tx = None;
|
||||
self.state = None;
|
||||
info!("VPN server stopped");
|
||||
Ok(())
|
||||
@@ -434,6 +465,54 @@ impl VpnServer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── WireGuard Peer Management ────────────────────────────────────────
|
||||
|
||||
/// Add a WireGuard peer dynamically (delegates to the WG event loop).
|
||||
pub async fn add_wg_peer(&self, config: crate::wireguard::WgPeerConfig) -> Result<()> {
|
||||
let tx = self.wg_command_tx.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("WireGuard listener not running"))?;
|
||||
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
|
||||
tx.send(crate::wireguard::WgCommand::AddPeer(config, resp_tx))
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("WG event loop closed"))?;
|
||||
resp_rx.await.map_err(|_| anyhow::anyhow!("No response from WG loop"))?
|
||||
}
|
||||
|
||||
/// Remove a WireGuard peer dynamically (delegates to the WG event loop).
|
||||
pub async fn remove_wg_peer(&self, public_key: &str) -> Result<()> {
|
||||
let tx = self.wg_command_tx.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("WireGuard listener not running"))?;
|
||||
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
|
||||
tx.send(crate::wireguard::WgCommand::RemovePeer(public_key.to_string(), resp_tx))
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("WG event loop closed"))?;
|
||||
resp_rx.await.map_err(|_| anyhow::anyhow!("No response from WG loop"))?
|
||||
}
|
||||
|
||||
/// List WireGuard peers from the unified client list.
|
||||
pub async fn list_wg_peers(&self) -> Vec<crate::wireguard::WgPeerInfo> {
|
||||
if let Some(ref state) = self.state {
|
||||
state.clients.read().await.values()
|
||||
.filter(|c| c.transport_type == "wireguard")
|
||||
.map(|c| crate::wireguard::WgPeerInfo {
|
||||
public_key: c.authenticated_key.clone(),
|
||||
allowed_ips: vec![format!("{}/32", c.assigned_ip)],
|
||||
endpoint: c.remote_addr.clone(),
|
||||
persistent_keepalive: None,
|
||||
stats: crate::wireguard::WgPeerStats {
|
||||
bytes_sent: c.bytes_sent,
|
||||
bytes_received: c.bytes_received,
|
||||
packets_sent: 0,
|
||||
packets_received: 0,
|
||||
last_handshake_time: None,
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Client Registry (Hub) Methods ───────────────────────────────────
|
||||
|
||||
/// Create a new client entry. Generates keypairs and assigns an IP.
|
||||
@@ -751,6 +830,7 @@ async fn run_ws_listener(
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
remote_addr,
|
||||
"websocket",
|
||||
).await {
|
||||
warn!("Client connection error: {}", e);
|
||||
}
|
||||
@@ -827,6 +907,7 @@ async fn run_quic_listener(
|
||||
Box::new(sink),
|
||||
Box::new(stream),
|
||||
Some(remote),
|
||||
"quic",
|
||||
).await {
|
||||
warn!("QUIC client error: {}", e);
|
||||
}
|
||||
@@ -916,6 +997,7 @@ async fn handle_client_connection(
|
||||
mut sink: Box<dyn TransportSink>,
|
||||
mut stream: Box<dyn TransportStream>,
|
||||
remote_addr: Option<std::net::SocketAddr>,
|
||||
transport_type: &str,
|
||||
) -> Result<()> {
|
||||
let server_private_key = base64::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
@@ -1054,6 +1136,7 @@ async fn handle_client_connection(
|
||||
authenticated_key: client_pub_key_b64.clone(),
|
||||
registered_client_id: registered_client_id.clone(),
|
||||
remote_addr: remote_addr.map(|a| a.to_string()),
|
||||
transport_type: transport_type.to_string(),
|
||||
};
|
||||
state.clients.write().await.insert(client_id.clone(), client_info);
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@ use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
@@ -17,8 +15,7 @@ use tokio::net::UdpSocket;
|
||||
use tokio::sync::{mpsc, oneshot, RwLock};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::network;
|
||||
use crate::tunnel::extract_dst_ip;
|
||||
use crate::server::{ClientInfo, ForwardingEngine, ServerState};
|
||||
use crate::tunnel::{self, TunConfig};
|
||||
|
||||
// ============================================================================
|
||||
@@ -30,9 +27,6 @@ const WG_BUFFER_SIZE: usize = MAX_UDP_PACKET;
|
||||
/// Minimum dst buffer size for boringtun encapsulate/decapsulate
|
||||
const _MIN_DST_BUF: usize = 148;
|
||||
const TIMER_TICK_MS: u64 = 100;
|
||||
const DEFAULT_WG_PORT: u16 = 51820;
|
||||
const DEFAULT_TUN_ADDRESS: &str = "10.8.0.1";
|
||||
const DEFAULT_TUN_NETMASK: &str = "255.255.255.0";
|
||||
const DEFAULT_MTU: u16 = 1420;
|
||||
|
||||
// ============================================================================
|
||||
@@ -52,27 +46,6 @@ pub struct WgPeerConfig {
|
||||
pub persistent_keepalive: Option<u16>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct WgServerConfig {
|
||||
pub private_key: String,
|
||||
#[serde(default)]
|
||||
pub listen_port: Option<u16>,
|
||||
#[serde(default)]
|
||||
pub tun_address: Option<String>,
|
||||
#[serde(default)]
|
||||
pub tun_netmask: Option<String>,
|
||||
#[serde(default)]
|
||||
pub mtu: Option<u16>,
|
||||
pub peers: Vec<WgPeerConfig>,
|
||||
#[serde(default)]
|
||||
pub dns: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
pub enable_nat: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub subnet: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct WgClientConfig {
|
||||
@@ -112,17 +85,6 @@ pub struct WgPeerInfo {
|
||||
pub stats: WgPeerStats,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct WgServerStats {
|
||||
pub total_bytes_sent: u64,
|
||||
pub total_bytes_received: u64,
|
||||
pub total_packets_sent: u64,
|
||||
pub total_packets_received: u64,
|
||||
pub active_peers: usize,
|
||||
pub uptime_seconds: f64,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Key generation and parsing
|
||||
// ============================================================================
|
||||
@@ -233,7 +195,7 @@ impl AllowedIp {
|
||||
// Dynamic peer management commands
|
||||
// ============================================================================
|
||||
|
||||
enum WgCommand {
|
||||
pub enum WgCommand {
|
||||
AddPeer(WgPeerConfig, oneshot::Sender<Result<()>>),
|
||||
RemovePeer(String, oneshot::Sender<Result<()>>),
|
||||
}
|
||||
@@ -258,451 +220,6 @@ impl PeerState {
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WgServer
|
||||
// ============================================================================
|
||||
|
||||
pub struct WgServer {
|
||||
shutdown_tx: Option<oneshot::Sender<()>>,
|
||||
command_tx: Option<mpsc::Sender<WgCommand>>,
|
||||
shared_stats: Arc<RwLock<HashMap<String, WgPeerStats>>>,
|
||||
server_stats: Arc<RwLock<WgServerStats>>,
|
||||
started_at: Option<Instant>,
|
||||
listen_port: Option<u16>,
|
||||
}
|
||||
|
||||
impl WgServer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
shutdown_tx: None,
|
||||
command_tx: None,
|
||||
shared_stats: Arc::new(RwLock::new(HashMap::new())),
|
||||
server_stats: Arc::new(RwLock::new(WgServerStats::default())),
|
||||
started_at: None,
|
||||
listen_port: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.shutdown_tx.is_some()
|
||||
}
|
||||
|
||||
pub async fn start(&mut self, config: WgServerConfig) -> Result<()> {
|
||||
if self.is_running() {
|
||||
return Err(anyhow!("WireGuard server is already running"));
|
||||
}
|
||||
|
||||
let listen_port = config.listen_port.unwrap_or(DEFAULT_WG_PORT);
|
||||
let tun_address = config
|
||||
.tun_address
|
||||
.as_deref()
|
||||
.unwrap_or(DEFAULT_TUN_ADDRESS);
|
||||
let tun_netmask = config
|
||||
.tun_netmask
|
||||
.as_deref()
|
||||
.unwrap_or(DEFAULT_TUN_NETMASK);
|
||||
let mtu = config.mtu.unwrap_or(DEFAULT_MTU);
|
||||
|
||||
// Parse server private key
|
||||
let server_private = parse_private_key(&config.private_key)?;
|
||||
let server_public = PublicKey::from(&server_private);
|
||||
|
||||
// Create rate limiter for DDoS protection
|
||||
let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64));
|
||||
|
||||
// Build peer state
|
||||
let peer_index = AtomicU32::new(0);
|
||||
let mut peers: Vec<PeerState> = Vec::with_capacity(config.peers.len());
|
||||
|
||||
for peer_config in &config.peers {
|
||||
let peer_public = parse_public_key(&peer_config.public_key)?;
|
||||
let psk = match &peer_config.preshared_key {
|
||||
Some(k) => Some(parse_preshared_key(k)?),
|
||||
None => None,
|
||||
};
|
||||
let idx = peer_index.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
// Clone the private key for each Tunn (StaticSecret doesn't implement Clone,
|
||||
// so re-parse from config)
|
||||
let priv_copy = parse_private_key(&config.private_key)?;
|
||||
|
||||
let tunn = Tunn::new(
|
||||
priv_copy,
|
||||
peer_public,
|
||||
psk,
|
||||
peer_config.persistent_keepalive,
|
||||
idx,
|
||||
Some(rate_limiter.clone()),
|
||||
);
|
||||
|
||||
let allowed_ips: Vec<AllowedIp> = peer_config
|
||||
.allowed_ips
|
||||
.iter()
|
||||
.map(|cidr| AllowedIp::parse(cidr))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let endpoint = match &peer_config.endpoint {
|
||||
Some(ep) => Some(ep.parse::<SocketAddr>()?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
peers.push(PeerState {
|
||||
tunn,
|
||||
public_key_b64: peer_config.public_key.clone(),
|
||||
allowed_ips,
|
||||
endpoint,
|
||||
persistent_keepalive: peer_config.persistent_keepalive,
|
||||
stats: WgPeerStats::default(),
|
||||
});
|
||||
}
|
||||
|
||||
// Create TUN device
|
||||
let tun_config = TunConfig {
|
||||
name: "wg0".to_string(),
|
||||
address: tun_address.parse()?,
|
||||
netmask: tun_netmask.parse()?,
|
||||
mtu,
|
||||
};
|
||||
let tun_device = tunnel::create_tun(&tun_config)?;
|
||||
info!("WireGuard TUN device created: {}", tun_config.name);
|
||||
|
||||
// Bind UDP socket
|
||||
let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", listen_port)).await?;
|
||||
info!("WireGuard server listening on UDP port {}", listen_port);
|
||||
|
||||
// Enable IP forwarding and NAT if requested
|
||||
if config.enable_nat.unwrap_or(false) {
|
||||
network::enable_ip_forwarding()?;
|
||||
let subnet = config
|
||||
.subnet
|
||||
.as_deref()
|
||||
.unwrap_or("10.8.0.0/24");
|
||||
let iface = network::get_default_interface()?;
|
||||
network::setup_nat(subnet, &iface).await?;
|
||||
info!("NAT enabled for subnet {} via {}", subnet, iface);
|
||||
}
|
||||
|
||||
// Channels
|
||||
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
|
||||
let (command_tx, command_rx) = mpsc::channel::<WgCommand>(32);
|
||||
|
||||
let shared_stats = self.shared_stats.clone();
|
||||
let server_stats = self.server_stats.clone();
|
||||
let started_at = Instant::now();
|
||||
|
||||
// Initialize shared stats
|
||||
{
|
||||
let mut stats = shared_stats.write().await;
|
||||
for peer in &peers {
|
||||
stats.insert(peer.public_key_b64.clone(), WgPeerStats::default());
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn the event loop
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = wg_server_loop(
|
||||
udp_socket,
|
||||
tun_device,
|
||||
peers,
|
||||
peer_index,
|
||||
rate_limiter,
|
||||
config.private_key.clone(),
|
||||
shared_stats,
|
||||
server_stats,
|
||||
started_at,
|
||||
shutdown_rx,
|
||||
command_rx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("WireGuard server loop error: {}", e);
|
||||
}
|
||||
info!("WireGuard server loop exited");
|
||||
});
|
||||
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
self.command_tx = Some(command_tx);
|
||||
self.started_at = Some(started_at);
|
||||
self.listen_port = Some(listen_port);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn stop(&mut self) -> Result<()> {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
self.command_tx = None;
|
||||
self.started_at = None;
|
||||
self.listen_port = None;
|
||||
info!("WireGuard server stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_status(&self) -> serde_json::Value {
|
||||
if self.is_running() {
|
||||
serde_json::json!({
|
||||
"state": "running",
|
||||
"listenPort": self.listen_port,
|
||||
"uptimeSeconds": self.started_at.map(|t| t.elapsed().as_secs_f64()).unwrap_or(0.0),
|
||||
})
|
||||
} else {
|
||||
serde_json::json!({ "state": "stopped" })
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_statistics(&self) -> serde_json::Value {
|
||||
let mut stats = self.server_stats.write().await;
|
||||
if let Some(started) = self.started_at {
|
||||
stats.uptime_seconds = started.elapsed().as_secs_f64();
|
||||
}
|
||||
// Aggregate from peer stats
|
||||
let peer_stats = self.shared_stats.read().await;
|
||||
stats.active_peers = peer_stats.len();
|
||||
stats.total_bytes_sent = peer_stats.values().map(|s| s.bytes_sent).sum();
|
||||
stats.total_bytes_received = peer_stats.values().map(|s| s.bytes_received).sum();
|
||||
stats.total_packets_sent = peer_stats.values().map(|s| s.packets_sent).sum();
|
||||
stats.total_packets_received = peer_stats.values().map(|s| s.packets_received).sum();
|
||||
serde_json::to_value(&*stats).unwrap_or_default()
|
||||
}
|
||||
|
||||
pub async fn list_peers(&self) -> Vec<WgPeerInfo> {
|
||||
let stats = self.shared_stats.read().await;
|
||||
stats
|
||||
.iter()
|
||||
.map(|(key, s)| WgPeerInfo {
|
||||
public_key: key.clone(),
|
||||
allowed_ips: vec![], // populated from event loop snapshots
|
||||
endpoint: None,
|
||||
persistent_keepalive: None,
|
||||
stats: s.clone(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn add_peer(&self, config: WgPeerConfig) -> Result<()> {
|
||||
let tx = self
|
||||
.command_tx
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("Server not running"))?;
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
tx.send(WgCommand::AddPeer(config, resp_tx))
|
||||
.await
|
||||
.map_err(|_| anyhow!("Server event loop closed"))?;
|
||||
resp_rx.await.map_err(|_| anyhow!("No response"))?
|
||||
}
|
||||
|
||||
pub async fn remove_peer(&self, public_key: &str) -> Result<()> {
|
||||
let tx = self
|
||||
.command_tx
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("Server not running"))?;
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
tx.send(WgCommand::RemovePeer(public_key.to_string(), resp_tx))
|
||||
.await
|
||||
.map_err(|_| anyhow!("Server event loop closed"))?;
|
||||
resp_rx.await.map_err(|_| anyhow!("No response"))?
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Server event loop
|
||||
// ============================================================================
|
||||
|
||||
async fn wg_server_loop(
|
||||
udp_socket: UdpSocket,
|
||||
tun_device: tun::AsyncDevice,
|
||||
mut peers: Vec<PeerState>,
|
||||
peer_index: AtomicU32,
|
||||
rate_limiter: Arc<RateLimiter>,
|
||||
server_private_key_b64: String,
|
||||
shared_stats: Arc<RwLock<HashMap<String, WgPeerStats>>>,
|
||||
_server_stats: Arc<RwLock<WgServerStats>>,
|
||||
_started_at: Instant,
|
||||
mut shutdown_rx: oneshot::Receiver<()>,
|
||||
mut command_rx: mpsc::Receiver<WgCommand>,
|
||||
) -> Result<()> {
|
||||
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
|
||||
let mut tun_buf = vec![0u8; MAX_UDP_PACKET];
|
||||
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
|
||||
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
|
||||
|
||||
// Split TUN for concurrent read/write in select
|
||||
let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_device);
|
||||
|
||||
// Stats sync interval
|
||||
let mut stats_timer =
|
||||
tokio::time::interval(std::time::Duration::from_secs(1));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// --- UDP receive ---
|
||||
result = udp_socket.recv_from(&mut udp_buf) => {
|
||||
let (n, src_addr) = result?;
|
||||
if n == 0 { continue; }
|
||||
|
||||
// Find which peer this packet belongs to by trying decapsulate
|
||||
let mut handled = false;
|
||||
for peer in peers.iter_mut() {
|
||||
match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
udp_socket.send_to(packet, src_addr).await?;
|
||||
// Drain loop
|
||||
loop {
|
||||
match peer.tunn.decapsulate(None, &[], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
let ep = peer.endpoint.unwrap_or(src_addr);
|
||||
udp_socket.send_to(pkt, ep).await?;
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
peer.endpoint = Some(src_addr);
|
||||
handled = true;
|
||||
break;
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(packet, addr) => {
|
||||
if peer.matches_dst(IpAddr::V4(addr)) {
|
||||
let pkt_len = packet.len() as u64;
|
||||
tun_writer.write_all(packet).await?;
|
||||
peer.stats.bytes_received += pkt_len;
|
||||
peer.stats.packets_received += 1;
|
||||
}
|
||||
peer.endpoint = Some(src_addr);
|
||||
handled = true;
|
||||
break;
|
||||
}
|
||||
TunnResult::WriteToTunnelV6(packet, addr) => {
|
||||
if peer.matches_dst(IpAddr::V6(addr)) {
|
||||
let pkt_len = packet.len() as u64;
|
||||
tun_writer.write_all(packet).await?;
|
||||
peer.stats.bytes_received += pkt_len;
|
||||
peer.stats.packets_received += 1;
|
||||
}
|
||||
peer.endpoint = Some(src_addr);
|
||||
handled = true;
|
||||
break;
|
||||
}
|
||||
TunnResult::Done => {
|
||||
// This peer didn't recognize the packet, try next
|
||||
continue;
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
debug!("decapsulate error from {}: {:?}", src_addr, e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
if !handled {
|
||||
debug!("No peer matched UDP packet from {}", src_addr);
|
||||
}
|
||||
}
|
||||
|
||||
// --- TUN read ---
|
||||
result = tun_reader.read(&mut tun_buf) => {
|
||||
let n = result?;
|
||||
if n == 0 { continue; }
|
||||
|
||||
let dst_ip = match extract_dst_ip(&tun_buf[..n]) {
|
||||
Some(ip) => ip,
|
||||
None => { continue; }
|
||||
};
|
||||
|
||||
// Find peer whose AllowedIPs match the destination
|
||||
for peer in peers.iter_mut() {
|
||||
if !peer.matches_dst(dst_ip) {
|
||||
continue;
|
||||
}
|
||||
match peer.tunn.encapsulate(&tun_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
if let Some(endpoint) = peer.endpoint {
|
||||
let pkt_len = n as u64;
|
||||
udp_socket.send_to(packet, endpoint).await?;
|
||||
peer.stats.bytes_sent += pkt_len;
|
||||
peer.stats.packets_sent += 1;
|
||||
} else {
|
||||
debug!("No endpoint for peer {}, dropping packet", peer.public_key_b64);
|
||||
}
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
debug!("encapsulate error for peer {}: {:?}", peer.public_key_b64, e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Timer tick (100ms) for WireGuard timers ---
|
||||
_ = timer.tick() => {
|
||||
for peer in peers.iter_mut() {
|
||||
match peer.tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
if let Some(endpoint) = peer.endpoint {
|
||||
udp_socket.send_to(packet, endpoint).await?;
|
||||
}
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
debug!("Timer error for peer {}: {:?}", peer.public_key_b64, e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Sync stats to shared state ---
|
||||
_ = stats_timer.tick() => {
|
||||
let mut shared = shared_stats.write().await;
|
||||
for peer in peers.iter() {
|
||||
shared.insert(peer.public_key_b64.clone(), peer.stats.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// --- Dynamic peer commands ---
|
||||
cmd = command_rx.recv() => {
|
||||
match cmd {
|
||||
Some(WgCommand::AddPeer(config, resp_tx)) => {
|
||||
let result = add_peer_to_loop(
|
||||
&mut peers,
|
||||
&config,
|
||||
&peer_index,
|
||||
&rate_limiter,
|
||||
&server_private_key_b64,
|
||||
);
|
||||
if result.is_ok() {
|
||||
let mut shared = shared_stats.write().await;
|
||||
shared.insert(config.public_key.clone(), WgPeerStats::default());
|
||||
}
|
||||
let _ = resp_tx.send(result);
|
||||
}
|
||||
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
|
||||
let prev_len = peers.len();
|
||||
peers.retain(|p| p.public_key_b64 != pubkey);
|
||||
if peers.len() < prev_len {
|
||||
let mut shared = shared_stats.write().await;
|
||||
shared.remove(&pubkey);
|
||||
let _ = resp_tx.send(Ok(()));
|
||||
} else {
|
||||
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
info!("Command channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shutdown ---
|
||||
_ = &mut shutdown_rx => {
|
||||
info!("WireGuard server shutdown signal received");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add_peer_to_loop(
|
||||
peers: &mut Vec<PeerState>,
|
||||
@@ -757,6 +274,410 @@ fn add_peer_to_loop(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Integrated WG listener (shares ServerState with WS/QUIC)
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for the integrated WireGuard listener.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WgListenerConfig {
|
||||
pub private_key: String,
|
||||
pub listen_port: u16,
|
||||
pub peers: Vec<WgPeerConfig>,
|
||||
}
|
||||
|
||||
/// Extract the first /32 IPv4 address from a list of AllowedIp entries.
|
||||
/// This is the peer's VPN IP used for return-packet routing.
|
||||
fn extract_peer_vpn_ip(allowed_ips: &[AllowedIp]) -> Option<Ipv4Addr> {
|
||||
for aip in allowed_ips {
|
||||
if let IpAddr::V4(v4) = aip.addr {
|
||||
if aip.prefix_len == 32 {
|
||||
return Some(v4);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Timestamp helper (mirrors server.rs timestamp_now).
|
||||
fn wg_timestamp_now() -> String {
|
||||
use std::time::SystemTime;
|
||||
let duration = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default();
|
||||
format!("{}", duration.as_secs())
|
||||
}
|
||||
|
||||
/// Register a WG peer in ServerState (tun_routes, clients, ip_pool).
|
||||
/// Returns the VPN IP and the per-peer return-packet receiver.
|
||||
async fn register_wg_peer(
|
||||
state: &Arc<ServerState>,
|
||||
peer: &PeerState,
|
||||
wg_return_tx: &mpsc::Sender<(String, Vec<u8>)>,
|
||||
) -> Result<Option<Ipv4Addr>> {
|
||||
let vpn_ip = match extract_peer_vpn_ip(&peer.allowed_ips) {
|
||||
Some(ip) => ip,
|
||||
None => {
|
||||
warn!("WG peer {} has no /32 IPv4 in allowed_ips, skipping registration",
|
||||
peer.public_key_b64);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
|
||||
|
||||
// Reserve IP in the pool
|
||||
if let Err(e) = state.ip_pool.lock().await.reserve(vpn_ip, &client_id) {
|
||||
warn!("Failed to reserve IP {} for WG peer {}: {}", vpn_ip, client_id, e);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Create per-peer return channel and register in tun_routes
|
||||
let fwd_mode = state.config.forwarding_mode.as_deref().unwrap_or("testing");
|
||||
let forwarding_active = fwd_mode == "tun" || fwd_mode == "socket";
|
||||
if forwarding_active {
|
||||
let (peer_return_tx, mut peer_return_rx) = mpsc::channel::<Vec<u8>>(256);
|
||||
state.tun_routes.write().await.insert(vpn_ip, peer_return_tx);
|
||||
|
||||
// Spawn relay task: per-peer channel → merged channel tagged with pubkey
|
||||
let relay_tx = wg_return_tx.clone();
|
||||
let pubkey = peer.public_key_b64.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(packet) = peer_return_rx.recv().await {
|
||||
if relay_tx.send((pubkey.clone(), packet)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Insert ClientInfo
|
||||
let client_info = ClientInfo {
|
||||
client_id: client_id.clone(),
|
||||
assigned_ip: vpn_ip.to_string(),
|
||||
connected_since: wg_timestamp_now(),
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
packets_dropped: 0,
|
||||
bytes_dropped: 0,
|
||||
last_keepalive_at: None,
|
||||
keepalives_received: 0,
|
||||
rate_limit_bytes_per_sec: None,
|
||||
burst_bytes: None,
|
||||
authenticated_key: peer.public_key_b64.clone(),
|
||||
registered_client_id: client_id,
|
||||
remote_addr: peer.endpoint.map(|e| e.to_string()),
|
||||
transport_type: "wireguard".to_string(),
|
||||
};
|
||||
state.clients.write().await.insert(client_info.client_id.clone(), client_info);
|
||||
|
||||
Ok(Some(vpn_ip))
|
||||
}
|
||||
|
||||
/// Unregister a WG peer from ServerState.
|
||||
async fn unregister_wg_peer(
|
||||
state: &Arc<ServerState>,
|
||||
pubkey: &str,
|
||||
vpn_ip: Option<Ipv4Addr>,
|
||||
) {
|
||||
let client_id = format!("wg-{}", &pubkey[..8.min(pubkey.len())]);
|
||||
|
||||
if let Some(ip) = vpn_ip {
|
||||
state.tun_routes.write().await.remove(&ip);
|
||||
state.ip_pool.lock().await.release(&ip);
|
||||
}
|
||||
state.clients.write().await.remove(&client_id);
|
||||
state.rate_limiters.lock().await.remove(&client_id);
|
||||
}
|
||||
|
||||
/// Integrated WireGuard listener that shares ServerState with WS/QUIC listeners.
|
||||
/// Uses the shared ForwardingEngine for packet routing instead of its own TUN device.
|
||||
pub async fn run_wg_listener(
|
||||
state: Arc<ServerState>,
|
||||
config: WgListenerConfig,
|
||||
mut shutdown_rx: mpsc::Receiver<()>,
|
||||
mut command_rx: mpsc::Receiver<WgCommand>,
|
||||
) -> Result<()> {
|
||||
// Parse server private key
|
||||
let server_private = parse_private_key(&config.private_key)?;
|
||||
let server_public = PublicKey::from(&server_private);
|
||||
|
||||
// Create rate limiter for DDoS protection
|
||||
let rate_limiter = Arc::new(RateLimiter::new(&server_public, TIMER_TICK_MS as u64));
|
||||
|
||||
// Build initial peer state
|
||||
let peer_index = AtomicU32::new(0);
|
||||
let mut peers: Vec<PeerState> = Vec::with_capacity(config.peers.len());
|
||||
|
||||
for peer_config in &config.peers {
|
||||
let peer_public = parse_public_key(&peer_config.public_key)?;
|
||||
let psk = match &peer_config.preshared_key {
|
||||
Some(k) => Some(parse_preshared_key(k)?),
|
||||
None => None,
|
||||
};
|
||||
let idx = peer_index.fetch_add(1, Ordering::Relaxed);
|
||||
let priv_copy = parse_private_key(&config.private_key)?;
|
||||
|
||||
let tunn = Tunn::new(
|
||||
priv_copy,
|
||||
peer_public,
|
||||
psk,
|
||||
peer_config.persistent_keepalive,
|
||||
idx,
|
||||
Some(rate_limiter.clone()),
|
||||
);
|
||||
|
||||
let allowed_ips: Vec<AllowedIp> = peer_config
|
||||
.allowed_ips
|
||||
.iter()
|
||||
.map(|cidr| AllowedIp::parse(cidr))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let endpoint = match &peer_config.endpoint {
|
||||
Some(ep) => Some(ep.parse::<SocketAddr>()?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
peers.push(PeerState {
|
||||
tunn,
|
||||
public_key_b64: peer_config.public_key.clone(),
|
||||
allowed_ips,
|
||||
endpoint,
|
||||
persistent_keepalive: peer_config.persistent_keepalive,
|
||||
stats: WgPeerStats::default(),
|
||||
});
|
||||
}
|
||||
|
||||
// Bind UDP socket
|
||||
let udp_socket = UdpSocket::bind(format!("0.0.0.0:{}", config.listen_port)).await?;
|
||||
info!("WireGuard listener started on UDP port {}", config.listen_port);
|
||||
|
||||
// Merged return-packet channel: all per-peer channels feed into this
|
||||
let (wg_return_tx, mut wg_return_rx) = mpsc::channel::<(String, Vec<u8>)>(1024);
|
||||
|
||||
// Register initial peers in ServerState and track their VPN IPs
|
||||
let mut peer_vpn_ips: HashMap<String, Ipv4Addr> = HashMap::new();
|
||||
for peer in &peers {
|
||||
if let Ok(Some(ip)) = register_wg_peer(&state, peer, &wg_return_tx).await {
|
||||
peer_vpn_ips.insert(peer.public_key_b64.clone(), ip);
|
||||
}
|
||||
}
|
||||
|
||||
// Buffers
|
||||
let mut udp_buf = vec![0u8; MAX_UDP_PACKET];
|
||||
let mut dst_buf = vec![0u8; WG_BUFFER_SIZE];
|
||||
let mut timer = tokio::time::interval(std::time::Duration::from_millis(TIMER_TICK_MS));
|
||||
let mut stats_timer = tokio::time::interval(std::time::Duration::from_secs(1));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// --- UDP receive → decapsulate → ForwardingEngine ---
|
||||
result = udp_socket.recv_from(&mut udp_buf) => {
|
||||
let (n, src_addr) = result?;
|
||||
if n == 0 { continue; }
|
||||
|
||||
let mut handled = false;
|
||||
for peer in peers.iter_mut() {
|
||||
match peer.tunn.decapsulate(Some(src_addr.ip()), &udp_buf[..n], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
udp_socket.send_to(packet, src_addr).await?;
|
||||
loop {
|
||||
match peer.tunn.decapsulate(None, &[], &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(pkt) => {
|
||||
let ep = peer.endpoint.unwrap_or(src_addr);
|
||||
udp_socket.send_to(pkt, ep).await?;
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
peer.endpoint = Some(src_addr);
|
||||
handled = true;
|
||||
break;
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(packet, addr) => {
|
||||
if peer.matches_dst(IpAddr::V4(addr)) {
|
||||
let pkt_len = packet.len() as u64;
|
||||
// Forward via shared forwarding engine
|
||||
let mut engine = state.forwarding_engine.lock().await;
|
||||
match &mut *engine {
|
||||
ForwardingEngine::Tun(writer) => {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
if let Err(e) = writer.write_all(packet).await {
|
||||
warn!("TUN write error for WG peer: {}", e);
|
||||
}
|
||||
}
|
||||
ForwardingEngine::Socket(sender) => {
|
||||
let _ = sender.try_send(packet.to_vec());
|
||||
}
|
||||
ForwardingEngine::Testing => {}
|
||||
}
|
||||
peer.stats.bytes_received += pkt_len;
|
||||
peer.stats.packets_received += 1;
|
||||
}
|
||||
peer.endpoint = Some(src_addr);
|
||||
handled = true;
|
||||
break;
|
||||
}
|
||||
TunnResult::WriteToTunnelV6(packet, addr) => {
|
||||
if peer.matches_dst(IpAddr::V6(addr)) {
|
||||
let pkt_len = packet.len() as u64;
|
||||
let mut engine = state.forwarding_engine.lock().await;
|
||||
match &mut *engine {
|
||||
ForwardingEngine::Tun(writer) => {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
if let Err(e) = writer.write_all(packet).await {
|
||||
warn!("TUN write error for WG peer: {}", e);
|
||||
}
|
||||
}
|
||||
ForwardingEngine::Socket(sender) => {
|
||||
let _ = sender.try_send(packet.to_vec());
|
||||
}
|
||||
ForwardingEngine::Testing => {}
|
||||
}
|
||||
peer.stats.bytes_received += pkt_len;
|
||||
peer.stats.packets_received += 1;
|
||||
}
|
||||
peer.endpoint = Some(src_addr);
|
||||
handled = true;
|
||||
break;
|
||||
}
|
||||
TunnResult::Done => { continue; }
|
||||
TunnResult::Err(e) => {
|
||||
debug!("decapsulate error from {}: {:?}", src_addr, e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
if !handled {
|
||||
debug!("No WG peer matched UDP packet from {}", src_addr);
|
||||
}
|
||||
}
|
||||
|
||||
// --- Return packets from tun_routes → encapsulate → UDP ---
|
||||
Some((pubkey, packet)) = wg_return_rx.recv() => {
|
||||
if let Some(peer) = peers.iter_mut().find(|p| p.public_key_b64 == pubkey) {
|
||||
match peer.tunn.encapsulate(&packet, &mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(out) => {
|
||||
if let Some(endpoint) = peer.endpoint {
|
||||
let pkt_len = packet.len() as u64;
|
||||
udp_socket.send_to(out, endpoint).await?;
|
||||
peer.stats.bytes_sent += pkt_len;
|
||||
peer.stats.packets_sent += 1;
|
||||
} else {
|
||||
debug!("No endpoint for WG peer {}, dropping return packet",
|
||||
peer.public_key_b64);
|
||||
}
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
debug!("encapsulate error for WG peer {}: {:?}",
|
||||
peer.public_key_b64, e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- WireGuard protocol timers (100ms) ---
|
||||
_ = timer.tick() => {
|
||||
for peer in peers.iter_mut() {
|
||||
match peer.tunn.update_timers(&mut dst_buf) {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
if let Some(endpoint) = peer.endpoint {
|
||||
udp_socket.send_to(packet, endpoint).await?;
|
||||
}
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
debug!("Timer error for WG peer {}: {:?}",
|
||||
peer.public_key_b64, e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Sync stats to ServerState (every 1s) ---
|
||||
_ = stats_timer.tick() => {
|
||||
let mut clients = state.clients.write().await;
|
||||
let mut stats = state.stats.write().await;
|
||||
for peer in peers.iter() {
|
||||
let client_id = format!("wg-{}", &peer.public_key_b64[..8.min(peer.public_key_b64.len())]);
|
||||
if let Some(info) = clients.get_mut(&client_id) {
|
||||
// Update stats delta
|
||||
let prev_sent = info.bytes_sent;
|
||||
let prev_recv = info.bytes_received;
|
||||
info.bytes_sent = peer.stats.bytes_sent;
|
||||
info.bytes_received = peer.stats.bytes_received;
|
||||
info.remote_addr = peer.endpoint.map(|e| e.to_string());
|
||||
|
||||
// Update aggregate stats
|
||||
stats.bytes_sent += peer.stats.bytes_sent.saturating_sub(prev_sent);
|
||||
stats.bytes_received += peer.stats.bytes_received.saturating_sub(prev_recv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Dynamic peer commands ---
|
||||
cmd = command_rx.recv() => {
|
||||
match cmd {
|
||||
Some(WgCommand::AddPeer(peer_config, resp_tx)) => {
|
||||
let result = add_peer_to_loop(
|
||||
&mut peers,
|
||||
&peer_config,
|
||||
&peer_index,
|
||||
&rate_limiter,
|
||||
&config.private_key,
|
||||
);
|
||||
if result.is_ok() {
|
||||
// Register new peer in ServerState
|
||||
let peer = peers.last().unwrap();
|
||||
match register_wg_peer(&state, peer, &wg_return_tx).await {
|
||||
Ok(Some(ip)) => {
|
||||
peer_vpn_ips.insert(peer_config.public_key.clone(), ip);
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => {
|
||||
warn!("Failed to register WG peer: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = resp_tx.send(result);
|
||||
}
|
||||
Some(WgCommand::RemovePeer(pubkey, resp_tx)) => {
|
||||
let prev_len = peers.len();
|
||||
peers.retain(|p| p.public_key_b64 != pubkey);
|
||||
if peers.len() < prev_len {
|
||||
let vpn_ip = peer_vpn_ips.remove(&pubkey);
|
||||
unregister_wg_peer(&state, &pubkey, vpn_ip).await;
|
||||
let _ = resp_tx.send(Ok(()));
|
||||
} else {
|
||||
let _ = resp_tx.send(Err(anyhow!("Peer not found: {}", pubkey)));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
info!("WG command channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shutdown ---
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("WireGuard listener shutdown signal received");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup: unregister all peers from ServerState
|
||||
for peer in &peers {
|
||||
let vpn_ip = peer_vpn_ips.get(&peer.public_key_b64).copied();
|
||||
unregister_wg_peer(&state, &peer.public_key_b64, vpn_ip).await;
|
||||
}
|
||||
|
||||
info!("WireGuard listener stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WgClient
|
||||
// ============================================================================
|
||||
@@ -1077,6 +998,7 @@ fn chrono_now() -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tunnel::extract_dst_ip;
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user