Files

204 lines
6.9 KiB
Rust
Raw Permalink Normal View History

2026-02-27 10:18:23 +00:00
use anyhow::Result;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use snow::Builder;
/// Noise protocol name parsed by every `Builder` constructed in this file.
///
/// Handshake pattern NK: the initiator (client) already knows the responder's
/// (server's) static public key, and the client is not authenticated at the
/// Noise level. The remaining components of the name select Curve25519 for
/// Diffie-Hellman, ChaCha20-Poly1305 as the AEAD cipher and BLAKE2s as the hash.
const NOISE_PATTERN: &str = "Noise_NK_25519_ChaChaPoly_BLAKE2s";
/// Create a fresh Noise static keypair and hand both halves back as
/// standard base64 text.
///
/// The returned tuple is ordered `(public_key_base64, private_key_base64)`.
///
/// # Errors
/// Fails if `NOISE_PATTERN` cannot be parsed or if key generation fails.
pub fn generate_keypair() -> Result<(String, String)> {
    let pair = Builder::new(NOISE_PATTERN.parse()?).generate_keypair()?;
    Ok((BASE64.encode(&pair.public), BASE64.encode(&pair.private)))
}
/// Create a fresh Noise static keypair and return it as the raw
/// `snow::Keypair` (no base64 encoding applied).
///
/// # Errors
/// Fails if `NOISE_PATTERN` cannot be parsed or if key generation fails.
pub fn generate_keypair_raw() -> Result<snow::Keypair> {
    let pair = Builder::new(NOISE_PATTERN.parse()?).generate_keypair()?;
    Ok(pair)
}
/// Build the client-side (initiator) handshake state for the NK pattern.
///
/// `server_public_key` is the server's static public key, which the client
/// must already know before the handshake starts.
///
/// # Errors
/// Fails if `NOISE_PATTERN` cannot be parsed or if the builder rejects the
/// configuration.
pub fn create_initiator(server_public_key: &[u8]) -> Result<snow::HandshakeState> {
    Ok(Builder::new(NOISE_PATTERN.parse()?)
        .remote_public_key(server_public_key)
        .build_initiator()?)
}
/// Build the server-side (responder) handshake state for the NK pattern.
///
/// `private_key` is the server's own static private key.
///
/// # Errors
/// Fails if `NOISE_PATTERN` cannot be parsed or if the builder rejects the
/// configuration.
pub fn create_responder(private_key: &[u8]) -> Result<snow::HandshakeState> {
    Ok(Builder::new(NOISE_PATTERN.parse()?)
        .local_private_key(private_key)
        .build_responder()?)
}
/// Drive a complete two-message Noise NK handshake between an initiator and a
/// responder that both live in this process.
///
/// Returns `(initiator_transport, responder_transport)`, i.e. both ends
/// switched into transport mode.
///
/// # Errors
/// Propagates any error raised while writing or reading a handshake message,
/// or while converting either side into transport mode.
pub fn perform_handshake(
    mut initiator: snow::HandshakeState,
    mut responder: snow::HandshakeState,
) -> Result<(snow::TransportState, snow::TransportState)> {
    // One scratch buffer serves as the output for every write and read; each
    // outgoing message is copied out before the buffer is reused.
    let mut scratch = vec![0u8; 65535];

    // Message 1 (-> e, es): written by the initiator, consumed by the responder.
    let written = initiator.write_message(&[], &mut scratch)?;
    let first_msg = scratch[..written].to_vec();
    responder.read_message(&first_msg, &mut scratch)?;

    // Message 2 (<- e, ee): written by the responder, consumed by the initiator.
    let written = responder.write_message(&[], &mut scratch)?;
    let second_msg = scratch[..written].to_vec();
    initiator.read_message(&second_msg, &mut scratch)?;

    // Initiator is converted first, then the responder.
    Ok((
        initiator.into_transport_mode()?,
        responder.into_transport_mode()?,
    ))
}
/// XChaCha20-Poly1305 helpers for data exchanged after the handshake.
/// Each message is sealed under a freshly drawn random 24-byte nonce
/// (random nonces are considered safe here because the nonce space is so large).
pub mod xchacha {
    use anyhow::Result;
    use chacha20poly1305::{
        XChaCha20Poly1305, XNonce,
        aead::{Aead, KeyInit},
    };
    use rand::RngCore;

    /// Size in bytes of the nonce placed in front of every encrypted message.
    pub const NONCE_SIZE: usize = 24;
    /// Size in bytes of the authentication tag that ends every encrypted message.
    pub const TAG_SIZE: usize = 16;

    /// Seal `plaintext` under `key` with XChaCha20-Poly1305.
    ///
    /// Output layout: nonce (24 bytes) followed by ciphertext and tag (16 bytes).
    ///
    /// # Errors
    /// Returns an error if the underlying AEAD encryption fails.
    pub fn encrypt(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
        // Draw a new random nonce for this message.
        let mut nonce_bytes = [0u8; NONCE_SIZE];
        rand::thread_rng().fill_bytes(&mut nonce_bytes);

        let sealed = XChaCha20Poly1305::new(key.into())
            .encrypt(XNonce::from_slice(&nonce_bytes), plaintext)
            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;

        // Prefix the nonce so `decrypt` can recover it from the message itself.
        Ok([nonce_bytes.as_slice(), sealed.as_slice()].concat())
    }

    /// Open a message produced by `encrypt()`.
    ///
    /// Expected input layout: nonce (24 bytes) followed by ciphertext and
    /// tag (16 bytes).
    ///
    /// # Errors
    /// Returns an error if `data` is shorter than nonce plus tag, or if
    /// authentication/decryption fails.
    pub fn decrypt(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
        let minimum_len = NONCE_SIZE + TAG_SIZE;
        if data.len() < minimum_len {
            anyhow::bail!("Ciphertext too short: {} bytes", data.len());
        }

        let (nonce_part, sealed) = data.split_at(NONCE_SIZE);
        XChaCha20Poly1305::new(key.into())
            .decrypt(XNonce::from_slice(nonce_part), sealed)
            .map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both halves of a generated keypair are 44-character base64 strings
    /// that decode back to 32 raw bytes.
    #[test]
    fn keypair_generation() {
        let (public_b64, private_b64) = generate_keypair().unwrap();

        // 32 raw bytes encode to 44 base64 characters.
        assert_eq!(public_b64.len(), 44);
        assert_eq!(private_b64.len(), 44);

        let public_raw = BASE64.decode(&public_b64).unwrap();
        let private_raw = BASE64.decode(&private_b64).unwrap();
        assert_eq!(public_raw.len(), 32);
        assert_eq!(private_raw.len(), 32);
    }

    /// After the handshake, each side can decrypt what the other side sends.
    #[test]
    fn noise_handshake() {
        let server_keys = generate_keypair_raw().unwrap();
        let client = create_initiator(&server_keys.public).unwrap();
        let server = create_responder(&server_keys.private).unwrap();
        let (mut client_transport, mut server_transport) =
            perform_handshake(client, server).unwrap();

        let mut wire = vec![0u8; 65535];
        let mut recovered = vec![0u8; 65535];

        // Client -> server.
        let request = b"hello from client";
        let sent = client_transport.write_message(request, &mut wire).unwrap();
        let received = server_transport
            .read_message(&wire[..sent], &mut recovered)
            .unwrap();
        assert_eq!(&recovered[..received], request);

        // Server -> client.
        let reply = b"hello from server";
        let sent = server_transport.write_message(reply, &mut wire).unwrap();
        let received = client_transport
            .read_message(&wire[..sent], &mut recovered)
            .unwrap();
        assert_eq!(&recovered[..received], reply);
    }

    /// A round trip through encrypt/decrypt returns the original plaintext,
    /// and the output carries exactly nonce + tag bytes of overhead.
    #[test]
    fn xchacha_encrypt_decrypt() {
        let key = [42u8; 32];
        let plaintext = b"secret VPN payload data";

        let sealed = xchacha::encrypt(&key, plaintext).unwrap();
        // Layout: nonce(24) + ciphertext + tag(16).
        assert_eq!(
            sealed.len(),
            xchacha::NONCE_SIZE + plaintext.len() + xchacha::TAG_SIZE
        );

        let decrypted = xchacha::decrypt(&key, &sealed).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// Decrypting with a different key must be rejected.
    #[test]
    fn xchacha_wrong_key_fails() {
        let right_key = [42u8; 32];
        let wrong_key = [43u8; 32];

        let sealed = xchacha::encrypt(&right_key, b"secret data").unwrap();
        assert!(xchacha::decrypt(&wrong_key, &sealed).is_err());
    }

    /// Input shorter than nonce + tag must be rejected.
    #[test]
    fn xchacha_too_short_fails() {
        let key = [42u8; 32];
        // 30 bytes is below the 40-byte minimum (24-byte nonce + 16-byte tag).
        let truncated = vec![0u8; 30];
        assert!(xchacha::decrypt(&key, &truncated).is_err());
    }

    /// Flipping bits in the final byte of the message must be detected.
    #[test]
    fn xchacha_tampered_fails() {
        let key = [42u8; 32];
        let mut sealed = xchacha::encrypt(&key, b"secret data").unwrap();

        // Corrupt the last byte of the encrypted output.
        let final_byte = sealed
            .last_mut()
            .expect("encrypted output always contains nonce and tag");
        *final_byte ^= 0xFF;

        assert!(xchacha::decrypt(&key, &sealed).is_err());
    }
}