Files
smartproxy/rust/crates/rustproxy-security/src/jwt_auth.rs

175 lines
4.8 KiB
Rust
Raw Normal View History

use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
use serde::{Deserialize, Serialize};
/// JWT claims (minimal structure).
///
/// Only the registered claims this crate looks at are modelled, and every
/// field is optional. The struct is (de)serialized with serde and is the
/// payload type passed to `jsonwebtoken::decode` by `JwtValidator::validate`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
/// Subject (`sub`) claim.
pub sub: Option<String>,
/// Expiration time (`exp`) claim, as whole seconds since the Unix epoch.
pub exp: Option<u64>,
/// Issuer (`iss`) claim.
pub iss: Option<String>,
/// Audience (`aud`) claim, modelled as a single string.
// NOTE(review): a token whose `aud` is a JSON array presumably fails to
// deserialize into `Option<String>` — confirm whether such tokens must be accepted.
pub aud: Option<String>,
}
/// JWT auth validator.
///
/// Bundles the verification key with the validation settings so that both
/// are prepared once in `new` and reused by every `validate` call.
pub struct JwtValidator {
/// Key used to verify token signatures; built from the `secret` passed to `new`.
decoding_key: DecodingKey,
/// Validation settings: the accepted algorithm plus the optional issuer and
/// audience constraints configured in `new`.
validation: Validation,
}
impl JwtValidator {
    /// Create a validator for the given key material and constraints.
    ///
    /// * `secret` - for the HMAC algorithms (`HS256`, `HS384`, `HS512`) the
    ///   shared secret; for `RS256` the PEM-encoded RSA public key.
    /// * `algorithm` - one of `"HS256"`, `"HS384"`, `"HS512"`, `"RS256"`.
    ///   `None` or any unrecognised name selects `HS256`.
    /// * `issuer` - when set, tokens must carry this `iss` value.
    /// * `audience` - when set, tokens must carry this `aud` value.
    ///
    /// This constructor cannot report errors. If `RS256` is requested and
    /// `secret` is not a valid RSA public-key PEM, the raw bytes are kept as
    /// the key; such a validator rejects every token, which is the behaviour
    /// the RS256 configuration had before PEM parsing was added.
    pub fn new(
        secret: &str,
        algorithm: Option<&str>,
        issuer: Option<&str>,
        audience: Option<&str>,
    ) -> Self {
        let algo = match algorithm {
            Some("HS384") => Algorithm::HS384,
            Some("HS512") => Algorithm::HS512,
            Some("RS256") => Algorithm::RS256,
            _ => Algorithm::HS256,
        };

        let mut validation = Validation::new(algo);
        if let Some(iss) = issuer {
            validation.set_issuer(&[iss]);
        }
        if let Some(aud) = audience {
            validation.set_audience(&[aud]);
        }

        let key_bytes = secret.as_bytes();
        let decoding_key = match algo {
            // An HMAC-style key can never verify an RSA signature, so RS256
            // needs the key material parsed as an RSA public key.
            Algorithm::RS256 => DecodingKey::from_rsa_pem(key_bytes)
                .unwrap_or_else(|_| DecodingKey::from_secret(key_bytes)),
            _ => DecodingKey::from_secret(key_bytes),
        };

        Self {
            decoding_key,
            validation,
        }
    }

    /// Validate a JWT token string (without "Bearer " prefix).
    ///
    /// Returns the decoded claims if the token passes signature and claim
    /// validation.
    ///
    /// # Errors
    ///
    /// Returns the `jsonwebtoken` error rendered as a string when decoding
    /// or validation fails.
    pub fn validate(&self, token: &str) -> Result<Claims, String> {
        decode::<Claims>(token, &self.decoding_key, &self.validation)
            .map(|data| data.claims)
            .map_err(|e| e.to_string())
    }

    /// Extract the bearer token from an `Authorization` header value.
    ///
    /// Surrounding whitespace is ignored, the `Bearer` scheme is matched
    /// case-insensitively (HTTP auth schemes are case-insensitive), and any
    /// run of spaces between the scheme and the token is skipped. Returns
    /// `None` for other schemes or when no token follows the scheme.
    pub fn extract_token(auth_header: &str) -> Option<&str> {
        let header = auth_header.trim();
        // `header` is trimmed, so anything after the first space ends in a
        // non-whitespace character and the token can never be empty.
        let (scheme, rest) = header.split_once(' ')?;
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return None;
        }
        Some(rest.trim_start())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Current wall-clock time as whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Sign `claims` with `secret` using the library's default header.
    fn make_token(secret: &str, claims: &Claims) -> String {
        let signing_key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), claims, &signing_key).unwrap()
    }

    /// Expiry one hour ahead of now.
    fn future_exp() -> u64 {
        unix_now() + 3600
    }

    /// Expiry one hour before now.
    fn past_exp() -> u64 {
        unix_now() - 3600
    }

    /// Claims for subject "user123" with the given expiry and optional issuer.
    fn user_claims(exp: u64, iss: Option<&str>) -> Claims {
        Claims {
            sub: Some("user123".to_string()),
            exp: Some(exp),
            iss: iss.map(str::to_string),
            aud: None,
        }
    }

    #[test]
    fn test_valid_token() {
        let secret = "test-secret";
        let token = make_token(secret, &user_claims(future_exp(), None));

        let outcome = JwtValidator::new(secret, None, None, None).validate(&token);

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().sub, Some("user123".to_string()));
    }

    #[test]
    fn test_expired_token() {
        let secret = "test-secret";
        let token = make_token(secret, &user_claims(past_exp(), None));

        let outcome = JwtValidator::new(secret, None, None, None).validate(&token);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_wrong_secret() {
        let token = make_token("correct-secret", &user_claims(future_exp(), None));

        let outcome = JwtValidator::new("wrong-secret", None, None, None).validate(&token);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_issuer_validation() {
        let secret = "test-secret";
        let token = make_token(secret, &user_claims(future_exp(), Some("my-issuer")));

        // A validator expecting the token's issuer accepts it.
        let matching = JwtValidator::new(secret, None, Some("my-issuer"), None);
        assert!(matching.validate(&token).is_ok());

        // A validator expecting a different issuer rejects it.
        let mismatched = JwtValidator::new(secret, None, Some("other-issuer"), None);
        assert!(mismatched.validate(&token).is_err());
    }

    #[test]
    fn test_extract_token_bearer() {
        let extracted = JwtValidator::extract_token("Bearer abc123");
        assert_eq!(extracted, Some("abc123"));
    }

    #[test]
    fn test_extract_token_non_bearer() {
        for header in ["Basic abc123", "abc123"] {
            assert_eq!(JwtValidator::extract_token(header), None);
        }
    }
}