112 lines
3.0 KiB
Rust
112 lines
3.0 KiB
Rust
|
|
use base64::Engine;
|
||
|
|
use base64::engine::general_purpose::STANDARD as BASE64;
|
||
|
|
|
||
|
|
/// Basic auth validator.
|
||
|
|
pub struct BasicAuthValidator {
|
||
|
|
users: Vec<(String, String)>,
|
||
|
|
realm: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl BasicAuthValidator {
|
||
|
|
pub fn new(users: Vec<(String, String)>, realm: Option<String>) -> Self {
|
||
|
|
Self {
|
||
|
|
users,
|
||
|
|
realm: realm.unwrap_or_else(|| "Restricted".to_string()),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Validate an Authorization header value.
|
||
|
|
/// Returns the username if valid.
|
||
|
|
pub fn validate(&self, auth_header: &str) -> Option<String> {
|
||
|
|
let auth_header = auth_header.trim();
|
||
|
|
if !auth_header.starts_with("Basic ") {
|
||
|
|
return None;
|
||
|
|
}
|
||
|
|
|
||
|
|
let encoded = &auth_header[6..];
|
||
|
|
let decoded = BASE64.decode(encoded).ok()?;
|
||
|
|
let credentials = String::from_utf8(decoded).ok()?;
|
||
|
|
|
||
|
|
let mut parts = credentials.splitn(2, ':');
|
||
|
|
let username = parts.next()?;
|
||
|
|
let password = parts.next()?;
|
||
|
|
|
||
|
|
for (u, p) in &self.users {
|
||
|
|
if u == username && p == password {
|
||
|
|
return Some(username.to_string());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
None
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get the realm for WWW-Authenticate header.
|
||
|
|
pub fn realm(&self) -> &str {
|
||
|
|
&self.realm
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Generate the WWW-Authenticate header value.
|
||
|
|
pub fn www_authenticate(&self) -> String {
|
||
|
|
format!("Basic realm=\"{}\"", self.realm)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    /// Validator with two users and a custom realm, shared by most tests.
    fn make_validator() -> BasicAuthValidator {
        BasicAuthValidator::new(
            vec![
                ("admin".to_string(), "secret".to_string()),
                ("user".to_string(), "pass".to_string()),
            ],
            Some("TestRealm".to_string()),
        )
    }

    /// Builds an `Authorization` header value carrying `user:pass`.
    fn encode_basic(user: &str, pass: &str) -> String {
        let encoded = BASE64.encode(format!("{}:{}", user, pass));
        format!("Basic {}", encoded)
    }

    #[test]
    fn test_valid_credentials() {
        let validator = make_validator();
        let header = encode_basic("admin", "secret");
        assert_eq!(validator.validate(&header), Some("admin".to_string()));
    }

    #[test]
    fn test_second_user_valid() {
        let validator = make_validator();
        let header = encode_basic("user", "pass");
        assert_eq!(validator.validate(&header), Some("user".to_string()));
    }

    #[test]
    fn test_invalid_password() {
        let validator = make_validator();
        let header = encode_basic("admin", "wrong");
        assert_eq!(validator.validate(&header), None);
    }

    #[test]
    fn test_unknown_user() {
        let validator = make_validator();
        let header = encode_basic("nobody", "secret");
        assert_eq!(validator.validate(&header), None);
    }

    #[test]
    fn test_password_swapped_between_users() {
        // A valid password for a *different* user must not be accepted.
        let validator = make_validator();
        let header = encode_basic("admin", "pass");
        assert_eq!(validator.validate(&header), None);
    }

    #[test]
    fn test_password_may_contain_colon() {
        let validator = BasicAuthValidator::new(
            vec![("admin".to_string(), "se:cret".to_string())],
            None,
        );
        let header = encode_basic("admin", "se:cret");
        assert_eq!(validator.validate(&header), Some("admin".to_string()));
    }

    #[test]
    fn test_missing_colon_separator() {
        let validator = make_validator();
        let header = format!("Basic {}", BASE64.encode("adminsecret"));
        assert_eq!(validator.validate(&header), None);
    }

    #[test]
    fn test_non_utf8_payload() {
        let validator = make_validator();
        let header = format!("Basic {}", BASE64.encode([0xff_u8, 0xfe, b':', b'x']));
        assert_eq!(validator.validate(&header), None);
    }

    #[test]
    fn test_surrounding_whitespace_is_ignored() {
        let validator = make_validator();
        let header = format!("  {}  ", encode_basic("admin", "secret"));
        assert_eq!(validator.validate(&header), Some("admin".to_string()));
    }

    #[test]
    fn test_empty_header() {
        let validator = make_validator();
        assert_eq!(validator.validate(""), None);
        assert_eq!(validator.validate("Basic"), None);
    }

    #[test]
    fn test_not_basic_scheme() {
        let validator = make_validator();
        assert_eq!(validator.validate("Bearer sometoken"), None);
    }

    #[test]
    fn test_malformed_base64() {
        let validator = make_validator();
        assert_eq!(validator.validate("Basic !!!not-base64!!!"), None);
    }

    #[test]
    fn test_realm_getter() {
        let validator = make_validator();
        assert_eq!(validator.realm(), "TestRealm");
    }

    #[test]
    fn test_www_authenticate_format() {
        let validator = make_validator();
        assert_eq!(validator.www_authenticate(), "Basic realm=\"TestRealm\"");
    }

    #[test]
    fn test_default_realm() {
        let validator = BasicAuthValidator::new(vec![], None);
        assert_eq!(validator.www_authenticate(), "Basic realm=\"Restricted\"");
    }
}
|