// Source: smartvpn/rust/src/server.rs
use anyhow::Result;
use bytes::BytesMut;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::time::Duration;
2026-02-27 10:18:23 +00:00
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex, RwLock};
use tracing::{info, error, warn};
use crate::codec::{Frame, FrameCodec, PacketType};
use crate::crypto;
use crate::mtu::{MtuConfig, TunnelOverhead};
2026-02-27 10:18:23 +00:00
use crate::network::IpPool;
use crate::ratelimit::TokenBucket;
2026-02-27 10:18:23 +00:00
use crate::transport;
use crate::transport_trait::{self, TransportSink, TransportStream};
use crate::quic_transport;
2026-02-27 10:18:23 +00:00
/// Dead-peer timeout: 3x max keepalive interval (Healthy=60s).
///
/// A client session is torn down when no message has been received from the
/// peer for this long. The deadline is re-armed every time
/// `handle_client_connection` receives a message from the transport.
const DEAD_PEER_TIMEOUT: Duration = Duration::from_secs(180);
2026-02-27 10:18:23 +00:00
/// Server configuration (matches TS IVpnServerConfig).
///
/// Deserialized with camelCase field names (e.g. `listenAddr`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerConfig {
    /// Address the WebSocket TCP listener binds to.
    pub listen_addr: String,
    /// PEM certificate chain for the QUIC listener. A self-signed certificate
    /// is generated unless both `tls_cert` and `tls_key` are set.
    pub tls_cert: Option<String>,
    /// PEM private key for the QUIC listener (see `tls_cert`).
    pub tls_key: Option<String>,
    /// Server static private key, base64-encoded (standard alphabet). Used to
    /// create the Noise handshake responder for every client.
    pub private_key: String,
    /// Server public key. Not read by the server code in this file.
    pub public_key: String,
    /// Client subnet. Used to build the IP pool and, when NAT is enabled,
    /// passed to the NAT setup.
    pub subnet: String,
    /// DNS servers. Not read by the server code in this file.
    pub dns: Option<Vec<String>>,
    /// Tunnel MTU reported to clients; 1420 is used when unset.
    pub mtu: Option<u16>,
    /// Keepalive interval in seconds. Not read by the server code in this file.
    pub keepalive_interval_secs: Option<u64>,
    /// Enable IP forwarding and NAT on start-up. Treated as `false` when unset.
    pub enable_nat: Option<bool>,
    /// Default rate limit for new clients (bytes/sec). None = unlimited.
    ///
    /// A limiter is only installed for a new client when both this and
    /// `default_burst_bytes` are set.
    pub default_rate_limit_bytes_per_sec: Option<u64>,
    /// Default burst size for new clients (bytes). None = unlimited.
    pub default_burst_bytes: Option<u64>,
    /// Transport mode: "websocket", "quic", or "both".
    ///
    /// When unset, `VpnServer::start` uses "both". Any other value runs the
    /// WebSocket listener only.
    pub transport_mode: Option<String>,
    /// QUIC listen address (host:port). Defaults to listen_addr.
    pub quic_listen_addr: Option<String>,
    /// QUIC idle timeout in seconds (default: 30).
    pub quic_idle_timeout_secs: Option<u64>,
}
/// Information about a connected client.
///
/// Serialized with camelCase field names.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientInfo {
    /// Identifier generated by the server for this connection.
    pub client_id: String,
    /// Tunnel IP address allocated from the pool, in string form.
    pub assigned_ip: String,
    /// Unix time (seconds, as a decimal string) at which the client was
    /// registered after the handshake.
    pub connected_since: String,
    /// Bytes sent to the client. Never updated by the code in this file.
    pub bytes_sent: u64,
    /// Decrypted payload bytes of IP packets accepted from the client.
    pub bytes_received: u64,
    /// Number of IP packets rejected by the rate limiter.
    pub packets_dropped: u64,
    /// Decrypted payload bytes of the packets rejected by the rate limiter.
    pub bytes_dropped: u64,
    /// Unix time (seconds, as a decimal string) of the most recent keepalive.
    pub last_keepalive_at: Option<String>,
    /// Number of keepalive frames received from the client.
    pub keepalives_received: u64,
    /// Rate limit shown for this client (bytes/sec). Initialised from the
    /// server default and changed by the set/remove rate-limit calls.
    pub rate_limit_bytes_per_sec: Option<u64>,
    /// Burst size shown for this client (bytes). Initialised from the server
    /// default and changed by the set/remove rate-limit calls.
    pub burst_bytes: Option<u64>,
}
/// Server statistics.
///
/// Aggregate counters shared by all client sessions. Serialized with
/// camelCase field names; `Default` yields all-zero counters.
#[derive(Debug, Clone, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ServerStatistics {
/// Bytes sent to clients. Never updated by the code in this file.
pub bytes_sent: u64,
/// Decrypted payload bytes of IP packets accepted from clients.
pub bytes_received: u64,
/// Packets sent to clients. Never updated by the code in this file.
pub packets_sent: u64,
/// Number of IP packets accepted from clients.
pub packets_received: u64,
/// Keepalive ACKs sent (incremented together with `keepalives_received`).
pub keepalives_sent: u64,
/// Keepalive frames received from clients.
pub keepalives_received: u64,
/// Seconds since the server started; filled in by `VpnServer::get_statistics`.
pub uptime_seconds: u64,
/// Number of registered clients; filled in by `VpnServer::get_statistics`.
pub active_clients: u64,
/// Number of clients that completed the handshake since start.
pub total_connections: u64,
}
/// Shared server state.
///
/// One instance is created by `VpnServer::start` and shared (via `Arc`) with
/// the listener tasks and every client session.
pub struct ServerState {
    /// Configuration the server was started with.
    pub config: ServerConfig,
    /// Pool of tunnel IP addresses handed out to clients.
    pub ip_pool: Mutex<IpPool>,
    /// Registered clients, keyed by client id.
    pub clients: RwLock<HashMap<String, ClientInfo>>,
    /// Aggregate traffic and connection counters.
    pub stats: RwLock<ServerStatistics>,
    /// Per-client rate limiters, keyed by client id. Clients without an entry
    /// are not rate limited.
    pub rate_limiters: Mutex<HashMap<String, TokenBucket>>,
    /// MTU settings computed at start-up.
    pub mtu_config: MtuConfig,
    /// Monotonic instant at which the server was started.
    pub started_at: std::time::Instant,
}
/// The VPN server.
///
/// Both fields are `None` while the server is stopped, so the derived
/// `Default` is equivalent to `VpnServer::new()`.
#[derive(Default)]
pub struct VpnServer {
    /// Shared state of the running server; `None` when stopped.
    state: Option<Arc<ServerState>>,
    /// Sender used by `stop` to signal the listener task(s); `None` when stopped.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
impl VpnServer {
/// Creates a server handle in the stopped state (no shared state, no
/// shutdown channel). Call `start` to begin listening.
pub fn new() -> Self {
    let state = None;
    let shutdown_tx = None;
    Self { state, shutdown_tx }
}
pub async fn start(&mut self, config: ServerConfig) -> Result<()> {
if self.state.is_some() {
anyhow::bail!("Server is already running");
}
let ip_pool = IpPool::new(&config.subnet)?;
if config.enable_nat.unwrap_or(false) {
if let Err(e) = crate::network::enable_ip_forwarding() {
warn!("Failed to enable IP forwarding: {}", e);
}
if let Ok(iface) = crate::network::get_default_interface() {
if let Err(e) = crate::network::setup_nat(&config.subnet, &iface).await {
warn!("Failed to setup NAT: {}", e);
}
}
}
let link_mtu = config.mtu.unwrap_or(1420);
// Compute effective MTU from overhead
let overhead = TunnelOverhead::default_overhead();
let mtu_config = MtuConfig::new(overhead.effective_tun_mtu(1500).max(link_mtu));
2026-02-27 10:18:23 +00:00
let state = Arc::new(ServerState {
config: config.clone(),
ip_pool: Mutex::new(ip_pool),
clients: RwLock::new(HashMap::new()),
stats: RwLock::new(ServerStatistics::default()),
rate_limiters: Mutex::new(HashMap::new()),
mtu_config,
2026-02-27 10:18:23 +00:00
started_at: std::time::Instant::now(),
});
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
self.state = Some(state.clone());
self.shutdown_tx = Some(shutdown_tx);
let transport_mode = config.transport_mode.as_deref().unwrap_or("both");
2026-02-27 10:18:23 +00:00
let listen_addr = config.listen_addr.clone();
match transport_mode {
"quic" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state, quic_addr, idle_timeout, &mut shutdown_rx).await {
error!("QUIC listener error: {}", e);
}
});
2026-02-27 10:18:23 +00:00
}
"both" => {
let quic_addr = config.quic_listen_addr.clone().unwrap_or_else(|| listen_addr.clone());
let idle_timeout = config.quic_idle_timeout_secs.unwrap_or(30);
let state2 = state.clone();
let (shutdown_tx2, mut shutdown_rx2) = mpsc::channel::<()>(1);
// Store second shutdown sender so both listeners stop
let shutdown_tx_orig = self.shutdown_tx.take().unwrap();
let (combined_tx, mut combined_rx) = mpsc::channel::<()>(1);
self.shutdown_tx = Some(combined_tx);
// Forward combined shutdown to both listeners
tokio::spawn(async move {
combined_rx.recv().await;
let _ = shutdown_tx_orig.send(()).await;
let _ = shutdown_tx2.send(()).await;
});
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("WebSocket listener error: {}", e);
}
});
tokio::spawn(async move {
if let Err(e) = run_quic_listener(state2, quic_addr, idle_timeout, &mut shutdown_rx2).await {
error!("QUIC listener error: {}", e);
}
});
}
_ => {
// "websocket" (default)
tokio::spawn(async move {
if let Err(e) = run_ws_listener(state, listen_addr, &mut shutdown_rx).await {
error!("Server listener error: {}", e);
}
});
}
}
2026-02-27 10:18:23 +00:00
info!("VPN server started (transport: {})", transport_mode);
2026-02-27 10:18:23 +00:00
Ok(())
}
/// Signals the listener task(s) to shut down and marks the server as stopped.
///
/// Calling this on a server that is not running is a no-op apart from the
/// log line. Always returns `Ok`.
pub async fn stop(&mut self) -> Result<()> {
    let pending = self.shutdown_tx.take();
    if let Some(sender) = pending {
        // The send result is deliberately ignored (best-effort signal).
        let _ = sender.send(()).await;
    }
    self.state = None;
    info!("VPN server stopped");
    Ok(())
}
/// Returns a small JSON status object.
///
/// While running: `state` is `"connected"` and `connectedSince` holds the
/// `Debug` rendering of the time elapsed since start (a duration, not a
/// timestamp). While stopped: `state` is `"disconnected"`.
pub fn get_status(&self) -> serde_json::Value {
    match self.state.as_ref() {
        None => serde_json::json!({ "state": "disconnected" }),
        Some(running) => {
            let uptime = running.started_at.elapsed();
            serde_json::json!({
                "state": "connected",
                "connectedSince": format!("{:?}", uptime),
            })
        }
    }
}
/// Returns a snapshot of the server counters.
///
/// `uptime_seconds` and `active_clients` are computed at call time; the other
/// fields are copied from the shared counters. A stopped server yields
/// all-zero statistics.
pub async fn get_statistics(&self) -> ServerStatistics {
    let state = match &self.state {
        Some(state) => state,
        None => return ServerStatistics::default(),
    };

    let mut snapshot = state.stats.read().await.clone();
    snapshot.uptime_seconds = state.started_at.elapsed().as_secs();
    snapshot.active_clients = state.clients.read().await.len() as u64;
    snapshot
}
pub async fn list_clients(&self) -> Vec<ClientInfo> {
if let Some(ref state) = self.state {
state.clients.read().await.values().cloned().collect()
} else {
Vec::new()
}
}
pub async fn disconnect_client(&self, client_id: &str) -> Result<()> {
if let Some(ref state) = self.state {
let mut clients = state.clients.write().await;
if let Some(client) = clients.remove(client_id) {
let ip: Ipv4Addr = client.assigned_ip.parse()?;
state.ip_pool.lock().await.release(&ip);
state.rate_limiters.lock().await.remove(client_id);
2026-02-27 10:18:23 +00:00
info!("Client {} disconnected", client_id);
}
}
Ok(())
}
/// Set a rate limit for a specific client.
///
/// Updates the client's existing limiter or installs a new one, then mirrors
/// the values into the client's info record if the client is registered.
/// Does nothing when the server is stopped. Always returns `Ok`.
pub async fn set_client_rate_limit(
    &self,
    client_id: &str,
    rate_bytes_per_sec: u64,
    burst_bytes: u64,
) -> Result<()> {
    let state = match &self.state {
        Some(state) => state,
        None => return Ok(()),
    };

    // Scope the limiter guard so it is released before `clients` is locked.
    // `disconnect_client` acquires the two locks in the opposite order, and
    // holding both here could deadlock against it.
    {
        let mut limiters = state.rate_limiters.lock().await;
        if let Some(limiter) = limiters.get_mut(client_id) {
            limiter.update_limits(rate_bytes_per_sec, burst_bytes);
        } else {
            limiters.insert(
                client_id.to_string(),
                TokenBucket::new(rate_bytes_per_sec, burst_bytes),
            );
        }
    }

    // Update client info
    let mut clients = state.clients.write().await;
    if let Some(info) = clients.get_mut(client_id) {
        info.rate_limit_bytes_per_sec = Some(rate_bytes_per_sec);
        info.burst_bytes = Some(burst_bytes);
    }
    Ok(())
}
/// Remove rate limit for a specific client (unlimited).
pub async fn remove_client_rate_limit(&self, client_id: &str) -> Result<()> {
if let Some(ref state) = self.state {
state.rate_limiters.lock().await.remove(client_id);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(client_id) {
info.rate_limit_bytes_per_sec = None;
info.burst_bytes = None;
}
}
Ok(())
}
2026-02-27 10:18:23 +00:00
}
/// WebSocket listener — accepts TCP connections, upgrades to WS, then hands off
/// to the transport-agnostic `handle_client_connection`.
async fn run_ws_listener(
2026-02-27 10:18:23 +00:00
state: Arc<ServerState>,
listen_addr: String,
shutdown_rx: &mut mpsc::Receiver<()>,
) -> Result<()> {
let listener = TcpListener::bind(&listen_addr).await?;
info!("WebSocket server listening on {}", listen_addr);
loop {
tokio::select! {
accept = listener.accept() => {
match accept {
Ok((stream, addr)) => {
info!("New connection from {}", addr);
let state = state.clone();
tokio::spawn(async move {
match transport::accept_connection(stream).await {
Ok(ws) => {
let (sink, stream) = transport_trait::split_ws(ws);
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
).await {
warn!("Client connection error: {}", e);
}
}
Err(e) => {
warn!("WebSocket upgrade failed: {}", e);
}
2026-02-27 10:18:23 +00:00
}
});
}
Err(e) => {
error!("Accept error: {}", e);
}
}
}
_ = shutdown_rx.recv() => {
info!("Shutdown signal received");
break;
}
}
}
Ok(())
}
/// QUIC listener — accepts QUIC connections and hands off to the transport-agnostic
/// `handle_client_connection`.
async fn run_quic_listener(
2026-02-27 10:18:23 +00:00
state: Arc<ServerState>,
listen_addr: String,
idle_timeout_secs: u64,
shutdown_rx: &mut mpsc::Receiver<()>,
2026-02-27 10:18:23 +00:00
) -> Result<()> {
// Generate or use configured TLS certificate for QUIC
let (cert_chain, private_key) = if let (Some(ref cert_pem), Some(ref key_pem)) =
(&state.config.tls_cert, &state.config.tls_key)
{
// Parse PEM certificates
let certs: Vec<rustls_pki_types::CertificateDer<'static>> =
rustls_pemfile::certs(&mut cert_pem.as_bytes())
.collect::<Result<Vec<_>, _>>()?;
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?
.ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
(certs, key)
} else {
// Generate self-signed certificate
let (certs, key) = quic_transport::generate_self_signed_cert()?;
info!("QUIC using self-signed certificate (hash: {})", quic_transport::cert_hash(&certs[0]));
(certs, key)
};
let endpoint = quic_transport::create_quic_server(quic_transport::QuicServerConfig {
listen_addr,
cert_chain,
private_key,
idle_timeout_secs,
})?;
2026-02-27 10:18:23 +00:00
loop {
tokio::select! {
incoming = endpoint.accept() => {
match incoming {
Some(incoming) => {
let state = state.clone();
tokio::spawn(async move {
match incoming.await {
Ok(conn) => {
let remote = conn.remote_address();
info!("New QUIC connection from {}", remote);
match quic_transport::accept_quic_connection(conn).await {
Ok((sink, stream)) => {
if let Err(e) = handle_client_connection(
state,
Box::new(sink),
Box::new(stream),
).await {
warn!("QUIC client error: {}", e);
}
}
Err(e) => {
warn!("QUIC stream accept failed: {}", e);
}
}
}
Err(e) => {
warn!("QUIC handshake failed: {}", e);
}
}
});
}
None => {
info!("QUIC endpoint closed");
break;
}
}
}
_ = shutdown_rx.recv() => {
info!("QUIC shutdown signal received");
endpoint.close(0u32.into(), b"shutdown");
break;
}
}
}
Ok(())
}
/// Transport-agnostic client handler. Performs the Noise NK handshake, registers
/// the client, and runs the main packet forwarding loop.
async fn handle_client_connection(
state: Arc<ServerState>,
mut sink: Box<dyn TransportSink>,
mut stream: Box<dyn TransportStream>,
) -> Result<()> {
2026-02-27 10:18:23 +00:00
let client_id = uuid_v4();
let assigned_ip = state.ip_pool.lock().await.allocate(&client_id)?;
let server_private_key = base64::Engine::decode(
&base64::engine::general_purpose::STANDARD,
&state.config.private_key,
)?;
let mut responder = crypto::create_responder(&server_private_key)?;
let mut buf = vec![0u8; 65535];
// Receive handshake init
let init_msg = match stream.recv_reliable().await? {
Some(data) => data,
None => anyhow::bail!("Connection closed before handshake"),
2026-02-27 10:18:23 +00:00
};
let mut frame_buf = BytesMut::from(&init_msg[..]);
let frame = <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf)?
.ok_or_else(|| anyhow::anyhow!("Incomplete handshake frame"))?;
if frame.packet_type != PacketType::HandshakeInit {
anyhow::bail!("Expected HandshakeInit, got {:?}", frame.packet_type);
}
responder.read_message(&frame.payload, &mut buf)?;
let len = responder.write_message(&[], &mut buf)?;
let response_payload = buf[..len].to_vec();
let response_frame = Frame {
packet_type: PacketType::HandshakeResp,
payload: response_payload,
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, response_frame, &mut frame_bytes)?;
sink.send_reliable(frame_bytes.to_vec()).await?;
2026-02-27 10:18:23 +00:00
let mut noise_transport = responder.into_transport_mode()?;
// Register client
let default_rate = state.config.default_rate_limit_bytes_per_sec;
let default_burst = state.config.default_burst_bytes;
2026-02-27 10:18:23 +00:00
let client_info = ClientInfo {
client_id: client_id.clone(),
assigned_ip: assigned_ip.to_string(),
connected_since: timestamp_now(),
bytes_sent: 0,
bytes_received: 0,
packets_dropped: 0,
bytes_dropped: 0,
last_keepalive_at: None,
keepalives_received: 0,
rate_limit_bytes_per_sec: default_rate,
burst_bytes: default_burst,
2026-02-27 10:18:23 +00:00
};
state.clients.write().await.insert(client_id.clone(), client_info);
// Set up rate limiter if defaults are configured
if let (Some(rate), Some(burst)) = (default_rate, default_burst) {
state
.rate_limiters
.lock()
.await
.insert(client_id.clone(), TokenBucket::new(rate, burst));
}
2026-02-27 10:18:23 +00:00
{
let mut stats = state.stats.write().await;
stats.total_connections += 1;
}
// Send assigned IP info (encrypted), include effective MTU
2026-02-27 10:18:23 +00:00
let ip_info = serde_json::json!({
"assignedIp": assigned_ip.to_string(),
"gateway": state.ip_pool.lock().await.gateway_addr().to_string(),
"mtu": state.config.mtu.unwrap_or(1420),
"effectiveMtu": state.mtu_config.effective_mtu,
2026-02-27 10:18:23 +00:00
});
let ip_info_bytes = serde_json::to_vec(&ip_info)?;
let len = noise_transport.write_message(&ip_info_bytes, &mut buf)?;
let encrypted_info = Frame {
packet_type: PacketType::IpPacket,
payload: buf[..len].to_vec(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, encrypted_info, &mut frame_bytes)?;
sink.send_reliable(frame_bytes.to_vec()).await?;
2026-02-27 10:18:23 +00:00
info!("Client {} connected with IP {}", client_id, assigned_ip);
// Main packet loop with dead-peer detection
let mut last_activity = tokio::time::Instant::now();
2026-02-27 10:18:23 +00:00
loop {
tokio::select! {
msg = stream.recv_reliable() => {
match msg {
Ok(Some(data)) => {
last_activity = tokio::time::Instant::now();
let mut frame_buf = BytesMut::from(&data[..]);
match <FrameCodec as tokio_util::codec::Decoder>::decode(&mut FrameCodec, &mut frame_buf) {
Ok(Some(frame)) => match frame.packet_type {
PacketType::IpPacket => {
match noise_transport.read_message(&frame.payload, &mut buf) {
Ok(len) => {
// Rate limiting check
let allowed = {
let mut limiters = state.rate_limiters.lock().await;
if let Some(limiter) = limiters.get_mut(&client_id) {
limiter.try_consume(len)
} else {
true
}
};
if !allowed {
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.packets_dropped += 1;
info.bytes_dropped += len as u64;
}
continue;
}
let mut stats = state.stats.write().await;
stats.bytes_received += len as u64;
stats.packets_received += 1;
// Update per-client stats
drop(stats);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.bytes_received += len as u64;
}
}
Err(e) => {
warn!("Decrypt error from {}: {}", client_id, e);
break;
}
}
}
PacketType::Keepalive => {
// Echo the keepalive payload back in the ACK
let ack_frame = Frame {
packet_type: PacketType::KeepaliveAck,
payload: frame.payload.clone(),
};
let mut frame_bytes = BytesMut::new();
<FrameCodec as tokio_util::codec::Encoder<Frame>>::encode(&mut FrameCodec, ack_frame, &mut frame_bytes)?;
sink.send_reliable(frame_bytes.to_vec()).await?;
2026-02-27 10:18:23 +00:00
let mut stats = state.stats.write().await;
stats.keepalives_received += 1;
stats.keepalives_sent += 1;
// Update per-client keepalive tracking
drop(stats);
let mut clients = state.clients.write().await;
if let Some(info) = clients.get_mut(&client_id) {
info.last_keepalive_at = Some(timestamp_now());
info.keepalives_received += 1;
}
2026-02-27 10:18:23 +00:00
}
PacketType::Disconnect => {
info!("Client {} sent disconnect", client_id);
2026-02-27 10:18:23 +00:00
break;
}
_ => {
warn!("Unexpected packet type from {}: {:?}", client_id, frame.packet_type);
}
},
Ok(None) => {
warn!("Incomplete frame from {}", client_id);
}
Err(e) => {
warn!("Frame decode error from {}: {}", client_id, e);
break;
2026-02-27 10:18:23 +00:00
}
}
}
Ok(None) => {
info!("Client {} connection closed", client_id);
break;
}
Err(e) => {
warn!("Transport error from {}: {}", client_id, e);
2026-02-27 10:18:23 +00:00
break;
}
}
}
_ = tokio::time::sleep_until(last_activity + DEAD_PEER_TIMEOUT) => {
warn!("Client {} dead-peer timeout ({}s inactivity)", client_id, DEAD_PEER_TIMEOUT.as_secs());
2026-02-27 10:18:23 +00:00
break;
}
}
}
// Cleanup
state.clients.write().await.remove(&client_id);
state.ip_pool.lock().await.release(&assigned_ip);
state.rate_limiters.lock().await.remove(&client_id);
2026-02-27 10:18:23 +00:00
info!("Client {} disconnected, released IP {}", client_id, assigned_ip);
Ok(())
}
/// Generates a random (version 4) UUID in the canonical lowercase
/// `8-4-4-4-12` hexadecimal form.
///
/// The version nibble and the variant bits are set explicitly; without that,
/// the result is just 128 random bits in UUID-shaped formatting and may be
/// rejected by strict UUID parsers.
fn uuid_v4() -> String {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let mut bytes: [u8; 16] = rng.gen();
    // Version 4: the high nibble of byte 6 is 0b0100.
    bytes[6] = (bytes[6] & 0x0f) | 0x40;
    // Variant: the two most significant bits of byte 8 are 0b10.
    bytes[8] = (bytes[8] & 0x3f) | 0x80;
    format!(
        "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
        bytes[0], bytes[1], bytes[2], bytes[3],
        bytes[4], bytes[5],
        bytes[6], bytes[7],
        bytes[8], bytes[9],
        bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
    )
}
/// Returns the current Unix time in whole seconds as a decimal string.
///
/// If the system clock reports a time before the Unix epoch, the result is `"0"`.
fn timestamp_now() -> String {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => std::time::Duration::default(),
    };
    since_epoch.as_secs().to_string()
}