175 lines
4.8 KiB
Rust
175 lines
4.8 KiB
Rust
|
|
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
|
||
|
|
/// JWT claims (minimal structure).
|
||
|
|
#[derive(Debug, Serialize, Deserialize)]
|
||
|
|
pub struct Claims {
|
||
|
|
pub sub: Option<String>,
|
||
|
|
pub exp: Option<u64>,
|
||
|
|
pub iss: Option<String>,
|
||
|
|
pub aud: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// JWT auth validator.
|
||
|
|
pub struct JwtValidator {
|
||
|
|
decoding_key: DecodingKey,
|
||
|
|
validation: Validation,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl JwtValidator {
|
||
|
|
pub fn new(
|
||
|
|
secret: &str,
|
||
|
|
algorithm: Option<&str>,
|
||
|
|
issuer: Option<&str>,
|
||
|
|
audience: Option<&str>,
|
||
|
|
) -> Self {
|
||
|
|
let algo = match algorithm {
|
||
|
|
Some("HS384") => Algorithm::HS384,
|
||
|
|
Some("HS512") => Algorithm::HS512,
|
||
|
|
Some("RS256") => Algorithm::RS256,
|
||
|
|
_ => Algorithm::HS256,
|
||
|
|
};
|
||
|
|
|
||
|
|
let mut validation = Validation::new(algo);
|
||
|
|
if let Some(iss) = issuer {
|
||
|
|
validation.set_issuer(&[iss]);
|
||
|
|
}
|
||
|
|
if let Some(aud) = audience {
|
||
|
|
validation.set_audience(&[aud]);
|
||
|
|
}
|
||
|
|
|
||
|
|
Self {
|
||
|
|
decoding_key: DecodingKey::from_secret(secret.as_bytes()),
|
||
|
|
validation,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Validate a JWT token string (without "Bearer " prefix).
|
||
|
|
/// Returns the claims if valid.
|
||
|
|
pub fn validate(&self, token: &str) -> Result<Claims, String> {
|
||
|
|
decode::<Claims>(token, &self.decoding_key, &self.validation)
|
||
|
|
.map(|data| data.claims)
|
||
|
|
.map_err(|e| e.to_string())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Extract token from Authorization header.
|
||
|
|
pub fn extract_token(auth_header: &str) -> Option<&str> {
|
||
|
|
let header = auth_header.trim();
|
||
|
|
if header.starts_with("Bearer ") {
|
||
|
|
Some(&header[7..])
|
||
|
|
} else {
|
||
|
|
None
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use std::time::{SystemTime, UNIX_EPOCH};

    const SECRET: &str = "test-secret";

    /// Signs `claims` with `secret` using the default header (HS256).
    fn make_token(secret: &str, claims: &Claims) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .expect("encoding test claims with an HMAC key should not fail")
    }

    /// Current Unix time in whole seconds.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_secs()
    }

    /// An expiry one hour in the future.
    fn future_exp() -> u64 {
        now_secs() + 3600
    }

    /// An expiry one hour in the past, far beyond any small clock-skew leeway.
    fn past_exp() -> u64 {
        now_secs() - 3600
    }

    /// Claims for subject `user123` expiring at `exp`, with optional issuer and audience.
    fn claims(exp: u64, iss: Option<&str>, aud: Option<&str>) -> Claims {
        Claims {
            sub: Some("user123".to_string()),
            exp: Some(exp),
            iss: iss.map(str::to_string),
            aud: aud.map(str::to_string),
        }
    }

    #[test]
    fn test_valid_token() {
        let token = make_token(SECRET, &claims(future_exp(), None, None));
        let validator = JwtValidator::new(SECRET, None, None, None);
        let decoded = validator
            .validate(&token)
            .expect("a freshly signed, unexpired token should validate");
        assert_eq!(decoded.sub, Some("user123".to_string()));
    }

    #[test]
    fn test_expired_token() {
        let token = make_token(SECRET, &claims(past_exp(), None, None));
        let validator = JwtValidator::new(SECRET, None, None, None);
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_wrong_secret() {
        let token = make_token("correct-secret", &claims(future_exp(), None, None));
        let validator = JwtValidator::new("wrong-secret", None, None, None);
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_issuer_validation() {
        let token = make_token(SECRET, &claims(future_exp(), Some("my-issuer"), None));

        // Correct issuer
        let validator = JwtValidator::new(SECRET, None, Some("my-issuer"), None);
        assert!(validator.validate(&token).is_ok());

        // Wrong issuer
        let validator = JwtValidator::new(SECRET, None, Some("other-issuer"), None);
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_audience_validation() {
        let token = make_token(SECRET, &claims(future_exp(), None, Some("my-audience")));

        // Correct audience
        let validator = JwtValidator::new(SECRET, None, None, Some("my-audience"));
        assert!(validator.validate(&token).is_ok());

        // Wrong audience
        let validator = JwtValidator::new(SECRET, None, None, Some("other-audience"));
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_algorithm_mismatch_rejected() {
        // The token is signed with HS256 (default header), but the validator only accepts HS512.
        let token = make_token(SECRET, &claims(future_exp(), None, None));
        let validator = JwtValidator::new(SECRET, Some("HS512"), None, None);
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_malformed_token() {
        let validator = JwtValidator::new(SECRET, None, None, None);
        assert!(validator.validate("not-a-jwt").is_err());
        assert!(validator.validate("").is_err());
    }

    #[test]
    fn test_extract_token_bearer() {
        assert_eq!(
            JwtValidator::extract_token("Bearer abc123"),
            Some("abc123")
        );
        // Surrounding whitespace on the header value is ignored.
        assert_eq!(
            JwtValidator::extract_token("  Bearer abc123  "),
            Some("abc123")
        );
    }

    #[test]
    fn test_extract_token_non_bearer() {
        assert_eq!(JwtValidator::extract_token("Basic abc123"), None);
        assert_eq!(JwtValidator::extract_token("abc123"), None);
        assert_eq!(JwtValidator::extract_token("Bearer"), None);
    }
}
|