547 lines
17 KiB
Rust
547 lines
17 KiB
Rust
use hmac::{Hmac, Mac};
|
|
use hyper::body::Incoming;
|
|
use hyper::Request;
|
|
use sha2::{Digest, Sha256};
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::path::PathBuf;
|
|
use tokio::fs;
|
|
use tokio::sync::RwLock;
|
|
|
|
use crate::config::{AuthConfig, Credential};
|
|
use crate::error::StorageError;
|
|
|
|
/// HMAC instantiated with SHA-256; used for every step of SigV4 signing-key derivation and
/// for the final request signature.
type HmacSha256 = Hmac<Sha256>;
|
|
|
|
/// The identity of an authenticated caller, produced by a successful `verify_request`.
#[derive(Debug, Clone)]
pub struct AuthenticatedIdentity {
    /// Access key id taken from the `Credential=` component of the verified SigV4 header.
    pub access_key_id: String,
    /// Bucket scope copied from the matching credential's `bucket_name`.
    /// NOTE(review): `None` presumably means the credential is not restricted to one bucket —
    /// confirm against the authorization layer that consumes this.
    pub bucket_name: Option<String>,
}
|
|
|
|
/// Parsed components of an AWS4-HMAC-SHA256 Authorization header.
struct SigV4Header {
    /// First segment of the credential scope (`KEY` in `KEY/date/region/...`).
    access_key_id: String,
    /// Second segment of the credential scope; expected to be `YYYYMMDD`.
    date_stamp: String,
    /// Third segment of the credential scope; used verbatim in key derivation.
    region: String,
    /// `SignedHeaders=` entries, split on `;`, trimmed and lowercased, in client order.
    signed_headers: Vec<String>,
    /// `Signature=` value as sent by the client (hex string).
    signature: String,
}
|
|
|
|
/// Verify the request's SigV4 signature. Returns the caller identity on success.
|
|
pub fn verify_request(
|
|
req: &Request<Incoming>,
|
|
credentials: &[Credential],
|
|
) -> Result<AuthenticatedIdentity, StorageError> {
|
|
let auth_header = req
|
|
.headers()
|
|
.get("authorization")
|
|
.and_then(|v| v.to_str().ok())
|
|
.unwrap_or("");
|
|
|
|
// Reject SigV2
|
|
if auth_header.starts_with("AWS ") {
|
|
return Err(StorageError::authorization_header_malformed());
|
|
}
|
|
|
|
if !auth_header.starts_with("AWS4-HMAC-SHA256") {
|
|
return Err(StorageError::authorization_header_malformed());
|
|
}
|
|
|
|
let parsed = parse_auth_header(auth_header)?;
|
|
|
|
// Look up credential
|
|
let credential = find_credential(&parsed.access_key_id, credentials)
|
|
.ok_or_else(StorageError::invalid_access_key_id)?;
|
|
|
|
// Get x-amz-date
|
|
let amz_date = req
|
|
.headers()
|
|
.get("x-amz-date")
|
|
.and_then(|v| v.to_str().ok())
|
|
.or_else(|| req.headers().get("date").and_then(|v| v.to_str().ok()))
|
|
.ok_or_else(|| StorageError::missing_security_header("Missing x-amz-date header"))?;
|
|
|
|
// Enforce 15-min clock skew
|
|
check_clock_skew(amz_date)?;
|
|
|
|
// Get payload hash
|
|
let content_sha256 = req
|
|
.headers()
|
|
.get("x-amz-content-sha256")
|
|
.and_then(|v| v.to_str().ok())
|
|
.unwrap_or("UNSIGNED-PAYLOAD");
|
|
|
|
// Build canonical request
|
|
let canonical_request = build_canonical_request(req, &parsed.signed_headers, content_sha256);
|
|
|
|
// Build string to sign
|
|
let scope = format!("{}/{}/s3/aws4_request", parsed.date_stamp, parsed.region);
|
|
let canonical_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
|
|
let string_to_sign = format!(
|
|
"AWS4-HMAC-SHA256\n{}\n{}\n{}",
|
|
amz_date, scope, canonical_hash
|
|
);
|
|
|
|
// Derive signing key
|
|
let signing_key = derive_signing_key(
|
|
&credential.secret_access_key,
|
|
&parsed.date_stamp,
|
|
&parsed.region,
|
|
);
|
|
|
|
// Compute signature
|
|
let computed = hmac_sha256(&signing_key, string_to_sign.as_bytes());
|
|
let computed_hex = hex::encode(&computed);
|
|
|
|
// Constant-time comparison
|
|
if !constant_time_eq(computed_hex.as_bytes(), parsed.signature.as_bytes()) {
|
|
return Err(StorageError::signature_does_not_match());
|
|
}
|
|
|
|
Ok(AuthenticatedIdentity {
|
|
access_key_id: parsed.access_key_id,
|
|
bucket_name: credential.bucket_name.clone(),
|
|
})
|
|
}
|
|
|
|
/// Parse the Authorization header into its components.
|
|
fn parse_auth_header(header: &str) -> Result<SigV4Header, StorageError> {
|
|
// Format: AWS4-HMAC-SHA256 Credential=KEY/YYYYMMDD/region/s3/aws4_request, SignedHeaders=h1;h2, Signature=hex
|
|
let after_algo = header
|
|
.strip_prefix("AWS4-HMAC-SHA256")
|
|
.ok_or_else(StorageError::authorization_header_malformed)?
|
|
.trim();
|
|
|
|
let mut credential_str = None;
|
|
let mut signed_headers_str = None;
|
|
let mut signature_str = None;
|
|
|
|
for part in after_algo.split(',') {
|
|
let part = part.trim();
|
|
if let Some(val) = part.strip_prefix("Credential=") {
|
|
credential_str = Some(val.trim());
|
|
} else if let Some(val) = part.strip_prefix("SignedHeaders=") {
|
|
signed_headers_str = Some(val.trim());
|
|
} else if let Some(val) = part.strip_prefix("Signature=") {
|
|
signature_str = Some(val.trim());
|
|
}
|
|
}
|
|
|
|
let credential_str = credential_str.ok_or_else(StorageError::authorization_header_malformed)?;
|
|
let signed_headers_str =
|
|
signed_headers_str.ok_or_else(StorageError::authorization_header_malformed)?;
|
|
let signature = signature_str
|
|
.ok_or_else(StorageError::authorization_header_malformed)?
|
|
.to_string();
|
|
|
|
// Parse credential: KEY/YYYYMMDD/region/s3/aws4_request
|
|
let cred_parts: Vec<&str> = credential_str.splitn(5, '/').collect();
|
|
if cred_parts.len() < 5 {
|
|
return Err(StorageError::authorization_header_malformed());
|
|
}
|
|
|
|
let access_key_id = cred_parts[0].to_string();
|
|
let date_stamp = cred_parts[1].to_string();
|
|
let region = cred_parts[2].to_string();
|
|
|
|
let signed_headers: Vec<String> = signed_headers_str
|
|
.split(';')
|
|
.map(|s| s.trim().to_lowercase())
|
|
.collect();
|
|
|
|
Ok(SigV4Header {
|
|
access_key_id,
|
|
date_stamp,
|
|
region,
|
|
signed_headers,
|
|
signature,
|
|
})
|
|
}
|
|
|
|
/// Find a credential by access key ID.
|
|
fn find_credential<'a>(
|
|
access_key_id: &str,
|
|
credentials: &'a [Credential],
|
|
) -> Option<&'a Credential> {
|
|
credentials
|
|
.iter()
|
|
.find(|c| c.access_key_id == access_key_id)
|
|
}
|
|
|
|
/// Runtime-mutable set of SigV4 credentials, optionally mirrored to a JSON file.
///
/// Mutating methods validate the new set, write it to `persistence_path` (when configured),
/// and then install it in memory.
// NOTE(review): the derived `Debug` includes the credential list; if `Credential`'s `Debug`
// prints `secret_access_key`, logging this store would leak secrets — confirm.
#[derive(Debug)]
pub struct RuntimeCredentialStore {
    // Copied from `AuthConfig::enabled` at construction; never changed afterwards.
    enabled: bool,
    // Current credential set; a tokio `RwLock` so it can be replaced from async handlers.
    credentials: RwLock<Vec<Credential>>,
    // JSON file the set is loaded from at startup and written to on every mutation;
    // `None` keeps changes in memory only.
    persistence_path: Option<PathBuf>,
}
|
|
|
|
/// Secret-free view of one credential, used when listing credentials.
///
/// Serialized with camelCase keys; `None` fields are omitted from the output.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CredentialMetadata {
    /// The credential's access key id.
    pub access_key_id: String,
    /// Bucket the credential is scoped to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bucket_name: Option<String>,
    /// Region recorded on the credential, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
}
|
|
|
|
/// Secret-free description of a bucket-scoped ("bucket tenant") credential.
///
/// Serialized with camelCase keys; `region` is omitted when `None`.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BucketTenantMetadata {
    /// Bucket the credential is scoped to.
    pub bucket_name: String,
    /// Access key id of the bucket-scoped credential.
    pub access_key_id: String,
    /// Region recorded on the credential, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
}
|
|
|
|
impl RuntimeCredentialStore {
|
|
pub async fn new(
|
|
config: &AuthConfig,
|
|
persistence_path: Option<PathBuf>,
|
|
) -> anyhow::Result<Self> {
|
|
let credentials = match persistence_path.as_ref() {
|
|
Some(path) if path.exists() => {
|
|
let content = fs::read_to_string(path).await?;
|
|
let credentials: Vec<Credential> = serde_json::from_str(&content)?;
|
|
validate_credentials(&credentials)
|
|
.map_err(|error| anyhow::anyhow!(error.message))?;
|
|
credentials
|
|
}
|
|
_ => config.credentials.clone(),
|
|
};
|
|
|
|
Ok(Self {
|
|
enabled: config.enabled,
|
|
credentials: RwLock::new(credentials),
|
|
persistence_path,
|
|
})
|
|
}
|
|
|
|
pub fn enabled(&self) -> bool {
|
|
self.enabled
|
|
}
|
|
|
|
pub async fn list_credentials(&self) -> Vec<CredentialMetadata> {
|
|
self.credentials
|
|
.read()
|
|
.await
|
|
.iter()
|
|
.map(|credential| CredentialMetadata {
|
|
access_key_id: credential.access_key_id.clone(),
|
|
bucket_name: credential.bucket_name.clone(),
|
|
region: credential.region.clone(),
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
pub async fn snapshot_credentials(&self) -> Vec<Credential> {
|
|
self.credentials.read().await.clone()
|
|
}
|
|
|
|
pub async fn replace_credentials(
|
|
&self,
|
|
credentials: Vec<Credential>,
|
|
) -> Result<(), StorageError> {
|
|
validate_credentials(&credentials)?;
|
|
self.persist_credentials(&credentials).await?;
|
|
*self.credentials.write().await = credentials;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn replace_bucket_tenant_credential(
|
|
&self,
|
|
bucket_name: &str,
|
|
mut credential: Credential,
|
|
) -> Result<Credential, StorageError> {
|
|
validate_bucket_scope(bucket_name)?;
|
|
credential.bucket_name = Some(bucket_name.to_string());
|
|
|
|
let mut credentials = self.credentials.read().await.clone();
|
|
if credentials.iter().any(|existing| {
|
|
existing.access_key_id == credential.access_key_id
|
|
&& existing.bucket_name.as_deref() != Some(bucket_name)
|
|
}) {
|
|
return Err(StorageError::invalid_request(
|
|
"Credential accessKeyId is already assigned to another principal.",
|
|
));
|
|
}
|
|
|
|
credentials.retain(|existing| existing.bucket_name.as_deref() != Some(bucket_name));
|
|
credentials.push(credential.clone());
|
|
validate_credentials(&credentials)?;
|
|
self.persist_credentials(&credentials).await?;
|
|
*self.credentials.write().await = credentials;
|
|
Ok(credential)
|
|
}
|
|
|
|
pub async fn remove_bucket_tenant_credentials(
|
|
&self,
|
|
bucket_name: &str,
|
|
access_key_id: Option<&str>,
|
|
) -> Result<usize, StorageError> {
|
|
validate_bucket_scope(bucket_name)?;
|
|
let mut credentials = self.credentials.read().await.clone();
|
|
let before = credentials.len();
|
|
credentials.retain(|credential| {
|
|
if credential.bucket_name.as_deref() != Some(bucket_name) {
|
|
return true;
|
|
}
|
|
|
|
if let Some(access_key_id) = access_key_id {
|
|
credential.access_key_id != access_key_id
|
|
} else {
|
|
false
|
|
}
|
|
});
|
|
|
|
let removed = before.saturating_sub(credentials.len());
|
|
if credentials.is_empty() {
|
|
return Err(StorageError::invalid_request(
|
|
"Cannot remove the last active credential.",
|
|
));
|
|
}
|
|
self.persist_credentials(&credentials).await?;
|
|
*self.credentials.write().await = credentials;
|
|
Ok(removed)
|
|
}
|
|
|
|
pub async fn list_bucket_tenants(&self) -> Vec<BucketTenantMetadata> {
|
|
let mut tenants: Vec<BucketTenantMetadata> = self
|
|
.credentials
|
|
.read()
|
|
.await
|
|
.iter()
|
|
.filter_map(|credential| {
|
|
credential
|
|
.bucket_name
|
|
.as_ref()
|
|
.map(|bucket_name| BucketTenantMetadata {
|
|
bucket_name: bucket_name.clone(),
|
|
access_key_id: credential.access_key_id.clone(),
|
|
region: credential.region.clone(),
|
|
})
|
|
})
|
|
.collect();
|
|
tenants.sort_by(|a, b| {
|
|
a.bucket_name
|
|
.cmp(&b.bucket_name)
|
|
.then_with(|| a.access_key_id.cmp(&b.access_key_id))
|
|
});
|
|
tenants
|
|
}
|
|
|
|
pub async fn get_bucket_tenant_credential(&self, bucket_name: &str) -> Option<Credential> {
|
|
self.credentials
|
|
.read()
|
|
.await
|
|
.iter()
|
|
.find(|credential| credential.bucket_name.as_deref() == Some(bucket_name))
|
|
.cloned()
|
|
}
|
|
|
|
async fn persist_credentials(&self, credentials: &[Credential]) -> Result<(), StorageError> {
|
|
let Some(path) = self.persistence_path.as_ref() else {
|
|
return Ok(());
|
|
};
|
|
|
|
if let Some(parent) = path.parent() {
|
|
fs::create_dir_all(parent)
|
|
.await
|
|
.map_err(|error| StorageError::internal_error(&error.to_string()))?;
|
|
}
|
|
|
|
let temp_path = path.with_extension("json.tmp");
|
|
let json = serde_json::to_string_pretty(credentials)
|
|
.map_err(|error| StorageError::internal_error(&error.to_string()))?;
|
|
fs::write(&temp_path, json)
|
|
.await
|
|
.map_err(|error| StorageError::internal_error(&error.to_string()))?;
|
|
fs::rename(&temp_path, path)
|
|
.await
|
|
.map_err(|error| StorageError::internal_error(&error.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
fn validate_bucket_scope(bucket_name: &str) -> Result<(), StorageError> {
|
|
if bucket_name.trim().is_empty() {
|
|
return Err(StorageError::invalid_request(
|
|
"Bucket tenant bucketName must not be empty.",
|
|
));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn validate_credentials(credentials: &[Credential]) -> Result<(), StorageError> {
|
|
if credentials.is_empty() {
|
|
return Err(StorageError::invalid_request(
|
|
"Credential replacement requires at least one credential.",
|
|
));
|
|
}
|
|
|
|
let mut seen_access_keys = HashSet::new();
|
|
for credential in credentials {
|
|
if credential.access_key_id.trim().is_empty() {
|
|
return Err(StorageError::invalid_request(
|
|
"Credential accessKeyId must not be empty.",
|
|
));
|
|
}
|
|
|
|
if credential.secret_access_key.trim().is_empty() {
|
|
return Err(StorageError::invalid_request(
|
|
"Credential secretAccessKey must not be empty.",
|
|
));
|
|
}
|
|
|
|
if !seen_access_keys.insert(credential.access_key_id.as_str()) {
|
|
return Err(StorageError::invalid_request(
|
|
"Credential accessKeyId values must be unique.",
|
|
));
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Check clock skew (15 minutes max).
|
|
fn check_clock_skew(amz_date: &str) -> Result<(), StorageError> {
|
|
// Parse ISO 8601 basic format: YYYYMMDDTHHMMSSZ
|
|
let parsed = chrono::NaiveDateTime::parse_from_str(amz_date, "%Y%m%dT%H%M%SZ")
|
|
.map_err(|_| StorageError::authorization_header_malformed())?;
|
|
|
|
let request_time =
|
|
chrono::DateTime::<chrono::Utc>::from_naive_utc_and_offset(parsed, chrono::Utc);
|
|
let now = chrono::Utc::now();
|
|
let diff = (now - request_time).num_seconds().unsigned_abs();
|
|
|
|
if diff > 15 * 60 {
|
|
return Err(StorageError::request_time_too_skewed());
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Build the canonical request string.
|
|
fn build_canonical_request(
|
|
req: &Request<Incoming>,
|
|
signed_headers: &[String],
|
|
payload_hash: &str,
|
|
) -> String {
|
|
let method = req.method().as_str();
|
|
let uri_path = req.uri().path();
|
|
|
|
// Canonical URI: the path, already percent-encoded by the client
|
|
let canonical_uri = if uri_path.is_empty() { "/" } else { uri_path };
|
|
|
|
// Canonical query string: sorted key=value pairs
|
|
let canonical_query = build_canonical_query(req.uri().query().unwrap_or(""));
|
|
|
|
// Canonical headers: sorted by lowercase header name
|
|
let canonical_headers = build_canonical_headers(req, signed_headers);
|
|
|
|
// Signed headers string
|
|
let signed_headers_str = signed_headers.join(";");
|
|
|
|
// Payload hash — accept UNSIGNED-PAYLOAD and STREAMING-AWS4-HMAC-SHA256-PAYLOAD as-is
|
|
let effective_payload_hash = if payload_hash == "UNSIGNED-PAYLOAD"
|
|
|| payload_hash == "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"
|
|
{
|
|
payload_hash.to_string()
|
|
} else {
|
|
payload_hash.to_string()
|
|
};
|
|
|
|
format!(
|
|
"{}\n{}\n{}\n{}\n{}\n{}",
|
|
method,
|
|
canonical_uri,
|
|
canonical_query,
|
|
canonical_headers,
|
|
signed_headers_str,
|
|
effective_payload_hash
|
|
)
|
|
}
|
|
|
|
/// Build the SigV4 canonical query string from a raw query (without the leading `?`).
///
/// Each `&`-separated parameter is split at its first `=`; a parameter without `=` gets an
/// empty value (`uploads` → `uploads=`). Pairs are sorted bytewise by key, then value, and
/// re-joined with `&`. Keys and values keep their original percent-encoding.
fn build_canonical_query(query: &str) -> String {
    if query.is_empty() {
        return String::new();
    }

    let mut params: Vec<(String, String)> = query
        .split('&')
        .map(|param| match param.split_once('=') {
            Some((key, value)) => (key.to_string(), value.to_string()),
            None => (param.to_string(), String::new()),
        })
        .collect();
    params.sort();

    let mut canonical = String::with_capacity(query.len() + params.len());
    for (index, (key, value)) in params.iter().enumerate() {
        if index > 0 {
            canonical.push('&');
        }
        canonical.push_str(key);
        canonical.push('=');
        canonical.push_str(value);
    }
    canonical
}
|
|
|
|
/// Build canonical headers string.
|
|
fn build_canonical_headers(req: &Request<Incoming>, signed_headers: &[String]) -> String {
|
|
let mut header_map: HashMap<String, Vec<String>> = HashMap::new();
|
|
|
|
for (name, value) in req.headers() {
|
|
let name_lower = name.as_str().to_lowercase();
|
|
if signed_headers.contains(&name_lower) {
|
|
if let Ok(val) = value.to_str() {
|
|
header_map
|
|
.entry(name_lower)
|
|
.or_default()
|
|
.push(val.trim().to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut result = String::new();
|
|
for header_name in signed_headers {
|
|
let values = header_map
|
|
.get(header_name)
|
|
.map(|v| v.join(","))
|
|
.unwrap_or_default();
|
|
result.push_str(header_name);
|
|
result.push(':');
|
|
result.push_str(&values);
|
|
result.push('\n');
|
|
}
|
|
result
|
|
}
|
|
|
|
/// Derive the signing key via 4-step HMAC chain.
|
|
fn derive_signing_key(secret_key: &str, date_stamp: &str, region: &str) -> Vec<u8> {
|
|
let k_secret = format!("AWS4{}", secret_key);
|
|
let k_date = hmac_sha256(k_secret.as_bytes(), date_stamp.as_bytes());
|
|
let k_region = hmac_sha256(&k_date, region.as_bytes());
|
|
let k_service = hmac_sha256(&k_region, b"s3");
|
|
hmac_sha256(&k_service, b"aws4_request")
|
|
}
|
|
|
|
/// Compute HMAC-SHA256.
|
|
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
|
|
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC key length is always valid");
|
|
mac.update(data);
|
|
mac.finalize().into_bytes().to_vec()
|
|
}
|
|
|
|
/// Compare two byte slices without stopping at the first difference.
///
/// Slices of different length compare unequal immediately (length is not secret for
/// fixed-size hex signatures). Otherwise every byte pair is XOR-ed into an accumulator, so the
/// work done does not depend on where the slices differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
|