//! Cryptographic primitives for smartvpn (`smartvpn/rust/src/crypto.rs`):
//! Noise IK handshake helpers and XChaCha20-Poly1305 payload encryption.
use anyhow::Result;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use snow::Builder;
/// Noise protocol name used by every handshake builder in this file:
/// IK handshake pattern, X25519 (`25519`) Diffie-Hellman, ChaCha20-Poly1305
/// (`ChaChaPoly`) AEAD, and BLAKE2s hash.
///
/// In IK the initiator (client) already knows the responder's (server's) static
/// public key, and sends its own static public key, encrypted, in the first
/// handshake message. After the handshake, each side knows the other's static key.
/// NOTE(review): the handshake only proves the client holds the private half of the
/// static key it presents. Nothing in this file checks that key against an allowlist,
/// so callers must authorize the key returned by `perform_handshake` themselves.
const NOISE_PATTERN: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2s";
2026-02-27 10:18:23 +00:00
/// Generate a new Noise static keypair.
/// Returns (public_key_base64, private_key_base64).
pub fn generate_keypair() -> Result<(String, String)> {
let builder = Builder::new(NOISE_PATTERN.parse()?);
let keypair = builder.generate_keypair()?;
let public_key = BASE64.encode(&keypair.public);
let private_key = BASE64.encode(&keypair.private);
Ok((public_key, private_key))
}
/// Generate a raw Noise static keypair (not base64 encoded).
pub fn generate_keypair_raw() -> Result<snow::Keypair> {
let builder = Builder::new(NOISE_PATTERN.parse()?);
Ok(builder.generate_keypair()?)
}
/// Create a Noise IK initiator (client side).
/// The client provides its own static keypair AND the server's public key.
/// The client's static key is transmitted (encrypted) during the handshake,
/// allowing the server to authenticate the client.
pub fn create_initiator(client_private_key: &[u8], server_public_key: &[u8]) -> Result<snow::HandshakeState> {
2026-02-27 10:18:23 +00:00
let builder = Builder::new(NOISE_PATTERN.parse()?);
let state = builder
.local_private_key(client_private_key)
2026-02-27 10:18:23 +00:00
.remote_public_key(server_public_key)
.build_initiator()?;
Ok(state)
}
/// Create a Noise IK responder (server side).
2026-02-27 10:18:23 +00:00
/// The server uses its static private key.
/// After the handshake, call `get_remote_static()` on the HandshakeState
/// (before `into_transport_mode()`) to retrieve the client's public key.
2026-02-27 10:18:23 +00:00
pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
let builder = Builder::new(NOISE_PATTERN.parse()?);
let state = builder
.local_private_key(private_key)
.build_responder()?;
Ok(state)
}
/// Perform the full Noise IK handshake between initiator and responder.
/// Returns (initiator_transport, responder_transport, client_public_key).
/// The client_public_key is extracted from the responder before entering transport mode.
2026-02-27 10:18:23 +00:00
pub fn perform_handshake(
mut initiator: snow::HandshakeState,
mut responder: snow::HandshakeState,
) -> Result<(snow::TransportState, snow::TransportState, Vec<u8>)> {
2026-02-27 10:18:23 +00:00
let mut buf = vec![0u8; 65535];
// -> e, es, s, ss (initiator sends ephemeral + encrypted static key)
2026-02-27 10:18:23 +00:00
let len = initiator.write_message(&[], &mut buf)?;
let msg1 = buf[..len].to_vec();
// <- e, ee, se (responder reads and responds)
2026-02-27 10:18:23 +00:00
responder.read_message(&msg1, &mut buf)?;
let len = responder.write_message(&[], &mut buf)?;
let msg2 = buf[..len].to_vec();
// Initiator reads response
initiator.read_message(&msg2, &mut buf)?;
// Extract client's public key from responder BEFORE entering transport mode
let client_public_key = responder
.get_remote_static()
.ok_or_else(|| anyhow::anyhow!("IK handshake did not provide client static key"))?
.to_vec();
2026-02-27 10:18:23 +00:00
let i_transport = initiator.into_transport_mode()?;
let r_transport = responder.into_transport_mode()?;
Ok((i_transport, r_transport, client_public_key))
2026-02-27 10:18:23 +00:00
}
/// XChaCha20-Poly1305 authenticated encryption for post-handshake data.
///
/// Every message gets its own randomly drawn 24-byte nonce. XChaCha20's 192-bit nonce
/// space makes an accidental collision between random nonces negligible, so no nonce
/// counter or other state is kept between calls.
pub mod xchacha {
    use anyhow::Result;
    use chacha20poly1305::{
        XChaCha20Poly1305, XNonce,
        aead::{Aead, KeyInit},
    };
    use rand::RngCore;

    /// Size in bytes of the random nonce placed at the front of every message.
    pub const NONCE_SIZE: usize = 24;
    /// Size in bytes of the Poly1305 tag the AEAD appends to the ciphertext.
    pub const TAG_SIZE: usize = 16;

    /// Seal `plaintext` under `key` with a freshly generated random nonce.
    ///
    /// The output is laid out as `nonce (NONCE_SIZE) || ciphertext || tag (TAG_SIZE)`,
    /// so it is always `plaintext.len() + NONCE_SIZE + TAG_SIZE` bytes long.
    ///
    /// # Errors
    ///
    /// Returns an error if the AEAD implementation reports an encryption failure.
    pub fn encrypt(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
        let mut nonce = [0u8; NONCE_SIZE];
        rand::thread_rng().fill_bytes(&mut nonce);

        let sealed = XChaCha20Poly1305::new(key.into())
            .encrypt(XNonce::from_slice(&nonce), plaintext)
            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;

        // Prepend the nonce so `decrypt` can recover it from the message itself.
        let mut message = Vec::with_capacity(nonce.len() + sealed.len());
        message.extend_from_slice(&nonce);
        message.extend(sealed);
        Ok(message)
    }

    /// Open a message produced by [`encrypt`] and return the recovered plaintext.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - `data` is shorter than `NONCE_SIZE + TAG_SIZE` bytes; or
    /// - authentication fails (wrong key, or any byte of the message was altered).
    pub fn decrypt(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
        let minimum = NONCE_SIZE + TAG_SIZE;
        if data.len() < minimum {
            anyhow::bail!("Ciphertext too short: {} bytes", data.len());
        }

        let (nonce, sealed) = data.split_at(NONCE_SIZE);
        XChaCha20Poly1305::new(key.into())
            .decrypt(XNonce::from_slice(nonce), sealed)
            .map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn keypair_generation() {
        let (pub_key, priv_key) = generate_keypair().unwrap();
        // Base64-encoded 32-byte keys = 44 chars
        assert_eq!(pub_key.len(), 44);
        assert_eq!(priv_key.len(), 44);
        // Verify they decode back to 32 bytes
        let pub_bytes = BASE64.decode(&pub_key).unwrap();
        let priv_bytes = BASE64.decode(&priv_key).unwrap();
        assert_eq!(pub_bytes.len(), 32);
        assert_eq!(priv_bytes.len(), 32);
    }

    #[test]
    fn keypair_generation_is_random() {
        // Two independently generated keypairs must never coincide.
        let (pub_a, priv_a) = generate_keypair().unwrap();
        let (pub_b, priv_b) = generate_keypair().unwrap();
        assert_ne!(pub_a, pub_b);
        assert_ne!(priv_a, priv_b);
    }

    #[test]
    fn noise_ik_handshake() {
        let server_kp = generate_keypair_raw().unwrap();
        let client_kp = generate_keypair_raw().unwrap();
        let initiator = create_initiator(&client_kp.private, &server_kp.public).unwrap();
        let responder = create_responder(&server_kp.private).unwrap();
        let (mut i_transport, mut r_transport, remote_key) =
            perform_handshake(initiator, responder).unwrap();
        // Verify the server received the client's public key
        assert_eq!(remote_key, client_kp.public);

        // Test encrypted communication
        let mut buf = vec![0u8; 65535];
        let plaintext = b"hello from client";
        let len = i_transport.write_message(plaintext, &mut buf).unwrap();
        let mut out = vec![0u8; 65535];
        let len = r_transport.read_message(&buf[..len], &mut out).unwrap();
        assert_eq!(&out[..len], plaintext);
        // Reverse direction
        let plaintext = b"hello from server";
        let len = r_transport.write_message(plaintext, &mut buf).unwrap();
        let len = i_transport.read_message(&buf[..len], &mut out).unwrap();
        assert_eq!(&out[..len], plaintext);
    }

    #[test]
    fn noise_ik_handshake_with_base64_keys() {
        // Round-trip every key through its base64 text form before handshaking, so the
        // string-based keypair API is exercised end to end.
        let (server_pub, server_priv) = generate_keypair().unwrap();
        let (client_pub, client_priv) = generate_keypair().unwrap();
        let decode = |s: &str| BASE64.decode(s).unwrap();
        let initiator = create_initiator(&decode(&client_priv), &decode(&server_pub)).unwrap();
        let responder = create_responder(&decode(&server_priv)).unwrap();
        let (_, _, remote_key) = perform_handshake(initiator, responder).unwrap();
        assert_eq!(BASE64.encode(remote_key), client_pub);
    }

    #[test]
    fn noise_ik_wrong_server_key_fails() {
        let server_kp = generate_keypair_raw().unwrap();
        let wrong_server_kp = generate_keypair_raw().unwrap();
        let client_kp = generate_keypair_raw().unwrap();
        // Client uses wrong server public key
        let initiator = create_initiator(&client_kp.private, &wrong_server_kp.public).unwrap();
        let responder = create_responder(&server_kp.private).unwrap();
        // Handshake should fail because client targeted wrong server
        assert!(perform_handshake(initiator, responder).is_err());
    }

    #[test]
    fn xchacha_encrypt_decrypt() {
        let key = [42u8; 32];
        let plaintext = b"secret VPN payload data";
        let encrypted = xchacha::encrypt(&key, plaintext).unwrap();
        // encrypted = nonce(24) + ciphertext + tag(16)
        assert_eq!(encrypted.len(), 24 + plaintext.len() + 16);
        let decrypted = xchacha::decrypt(&key, &encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn xchacha_empty_plaintext_roundtrip() {
        let key = [7u8; 32];
        let encrypted = xchacha::encrypt(&key, b"").unwrap();
        // The smallest valid message is exactly nonce + tag; decrypt must accept it,
        // which pins the `<` (not `<=`) in decrypt's length check.
        assert_eq!(encrypted.len(), xchacha::NONCE_SIZE + xchacha::TAG_SIZE);
        assert!(xchacha::decrypt(&key, &encrypted).unwrap().is_empty());
    }

    #[test]
    fn xchacha_wrong_key_fails() {
        let key = [42u8; 32];
        let wrong_key = [43u8; 32];
        let plaintext = b"secret data";
        let encrypted = xchacha::encrypt(&key, plaintext).unwrap();
        assert!(xchacha::decrypt(&wrong_key, &encrypted).is_err());
    }

    #[test]
    fn xchacha_too_short_fails() {
        let key = [42u8; 32];
        let short = vec![0u8; 30]; // less than nonce + tag
        assert!(xchacha::decrypt(&key, &short).is_err());
    }

    #[test]
    fn xchacha_one_byte_below_minimum_fails() {
        let key = [42u8; 32];
        let short = vec![0u8; xchacha::NONCE_SIZE + xchacha::TAG_SIZE - 1];
        assert!(xchacha::decrypt(&key, &short).is_err());
    }

    #[test]
    fn xchacha_uses_fresh_nonce_per_message() {
        let key = [42u8; 32];
        let first = xchacha::encrypt(&key, b"same payload").unwrap();
        let second = xchacha::encrypt(&key, b"same payload").unwrap();
        // Identical inputs must still yield different nonces and ciphertexts.
        assert_ne!(&first[..xchacha::NONCE_SIZE], &second[..xchacha::NONCE_SIZE]);
        assert_ne!(first, second);
    }

    #[test]
    fn xchacha_tampered_fails() {
        let key = [42u8; 32];
        let plaintext = b"secret data";
        let mut encrypted = xchacha::encrypt(&key, plaintext).unwrap();
        // Tamper with ciphertext
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0xFF;
        assert!(xchacha::decrypt(&key, &encrypted).is_err());
    }

    #[test]
    fn xchacha_tampered_nonce_fails() {
        let key = [42u8; 32];
        let mut encrypted = xchacha::encrypt(&key, b"secret data").unwrap();
        // Flipping a nonce bit changes the keystream and MAC key, so authentication fails.
        encrypted[0] ^= 0x01;
        assert!(xchacha::decrypt(&key, &encrypted).is_err());
    }
}