use hmac::{Hmac, Mac}; use hyper::body::Incoming; use hyper::Request; use sha2::{Digest, Sha256}; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use tokio::fs; use tokio::sync::RwLock; use crate::config::{AuthConfig, Credential}; use crate::error::StorageError; type HmacSha256 = Hmac; /// The identity of an authenticated caller. #[derive(Debug, Clone)] pub struct AuthenticatedIdentity { pub access_key_id: String, pub bucket_name: Option, } /// Parsed components of an AWS4-HMAC-SHA256 Authorization header. struct SigV4Header { access_key_id: String, date_stamp: String, region: String, signed_headers: Vec, signature: String, } /// Verify the request's SigV4 signature. Returns the caller identity on success. pub fn verify_request( req: &Request, credentials: &[Credential], ) -> Result { let auth_header = req .headers() .get("authorization") .and_then(|v| v.to_str().ok()) .unwrap_or(""); // Reject SigV2 if auth_header.starts_with("AWS ") { return Err(StorageError::authorization_header_malformed()); } if !auth_header.starts_with("AWS4-HMAC-SHA256") { return Err(StorageError::authorization_header_malformed()); } let parsed = parse_auth_header(auth_header)?; // Look up credential let credential = find_credential(&parsed.access_key_id, credentials) .ok_or_else(StorageError::invalid_access_key_id)?; // Get x-amz-date let amz_date = req .headers() .get("x-amz-date") .and_then(|v| v.to_str().ok()) .or_else(|| req.headers().get("date").and_then(|v| v.to_str().ok())) .ok_or_else(|| StorageError::missing_security_header("Missing x-amz-date header"))?; // Enforce 15-min clock skew check_clock_skew(amz_date)?; // Get payload hash let content_sha256 = req .headers() .get("x-amz-content-sha256") .and_then(|v| v.to_str().ok()) .unwrap_or("UNSIGNED-PAYLOAD"); // Build canonical request let canonical_request = build_canonical_request(req, &parsed.signed_headers, content_sha256); // Build string to sign let scope = format!("{}/{}/s3/aws4_request", parsed.date_stamp, parsed.region); let canonical_hash = hex::encode(Sha256::digest(canonical_request.as_bytes())); let string_to_sign = format!( "AWS4-HMAC-SHA256\n{}\n{}\n{}", amz_date, scope, canonical_hash ); // Derive signing key let signing_key = derive_signing_key( &credential.secret_access_key, &parsed.date_stamp, &parsed.region, ); // Compute signature let computed = hmac_sha256(&signing_key, string_to_sign.as_bytes()); let computed_hex = hex::encode(&computed); // Constant-time comparison if !constant_time_eq(computed_hex.as_bytes(), parsed.signature.as_bytes()) { return Err(StorageError::signature_does_not_match()); } Ok(AuthenticatedIdentity { access_key_id: parsed.access_key_id, bucket_name: credential.bucket_name.clone(), }) } /// Parse the Authorization header into its components. fn parse_auth_header(header: &str) -> Result { // Format: AWS4-HMAC-SHA256 Credential=KEY/YYYYMMDD/region/s3/aws4_request, SignedHeaders=h1;h2, Signature=hex let after_algo = header .strip_prefix("AWS4-HMAC-SHA256") .ok_or_else(StorageError::authorization_header_malformed)? .trim(); let mut credential_str = None; let mut signed_headers_str = None; let mut signature_str = None; for part in after_algo.split(',') { let part = part.trim(); if let Some(val) = part.strip_prefix("Credential=") { credential_str = Some(val.trim()); } else if let Some(val) = part.strip_prefix("SignedHeaders=") { signed_headers_str = Some(val.trim()); } else if let Some(val) = part.strip_prefix("Signature=") { signature_str = Some(val.trim()); } } let credential_str = credential_str.ok_or_else(StorageError::authorization_header_malformed)?; let signed_headers_str = signed_headers_str.ok_or_else(StorageError::authorization_header_malformed)?; let signature = signature_str .ok_or_else(StorageError::authorization_header_malformed)? .to_string(); // Parse credential: KEY/YYYYMMDD/region/s3/aws4_request let cred_parts: Vec<&str> = credential_str.splitn(5, '/').collect(); if cred_parts.len() < 5 { return Err(StorageError::authorization_header_malformed()); } let access_key_id = cred_parts[0].to_string(); let date_stamp = cred_parts[1].to_string(); let region = cred_parts[2].to_string(); let signed_headers: Vec = signed_headers_str .split(';') .map(|s| s.trim().to_lowercase()) .collect(); Ok(SigV4Header { access_key_id, date_stamp, region, signed_headers, signature, }) } /// Find a credential by access key ID. fn find_credential<'a>( access_key_id: &str, credentials: &'a [Credential], ) -> Option<&'a Credential> { credentials .iter() .find(|c| c.access_key_id == access_key_id) } #[derive(Debug)] pub struct RuntimeCredentialStore { enabled: bool, credentials: RwLock>, persistence_path: Option, } #[derive(Debug, Clone, serde::Serialize)] #[serde(rename_all = "camelCase")] pub struct CredentialMetadata { pub access_key_id: String, #[serde(skip_serializing_if = "Option::is_none")] pub bucket_name: Option, #[serde(skip_serializing_if = "Option::is_none")] pub region: Option, } #[derive(Debug, Clone, serde::Serialize)] #[serde(rename_all = "camelCase")] pub struct BucketTenantMetadata { pub bucket_name: String, pub access_key_id: String, #[serde(skip_serializing_if = "Option::is_none")] pub region: Option, } impl RuntimeCredentialStore { pub async fn new( config: &AuthConfig, persistence_path: Option, ) -> anyhow::Result { let credentials = match persistence_path.as_ref() { Some(path) if path.exists() => { let content = fs::read_to_string(path).await?; let credentials: Vec = serde_json::from_str(&content)?; validate_credentials(&credentials) .map_err(|error| anyhow::anyhow!(error.message))?; credentials } _ => config.credentials.clone(), }; Ok(Self { enabled: config.enabled, credentials: RwLock::new(credentials), persistence_path, }) } pub fn enabled(&self) -> bool { self.enabled } pub async fn list_credentials(&self) -> Vec { self.credentials .read() .await .iter() .map(|credential| CredentialMetadata { access_key_id: credential.access_key_id.clone(), bucket_name: credential.bucket_name.clone(), region: credential.region.clone(), }) .collect() } pub async fn snapshot_credentials(&self) -> Vec { self.credentials.read().await.clone() } pub async fn replace_credentials( &self, credentials: Vec, ) -> Result<(), StorageError> { validate_credentials(&credentials)?; self.persist_credentials(&credentials).await?; *self.credentials.write().await = credentials; Ok(()) } pub async fn replace_bucket_tenant_credential( &self, bucket_name: &str, mut credential: Credential, ) -> Result { validate_bucket_scope(bucket_name)?; credential.bucket_name = Some(bucket_name.to_string()); let mut credentials = self.credentials.read().await.clone(); if credentials.iter().any(|existing| { existing.access_key_id == credential.access_key_id && existing.bucket_name.as_deref() != Some(bucket_name) }) { return Err(StorageError::invalid_request( "Credential accessKeyId is already assigned to another principal.", )); } credentials.retain(|existing| existing.bucket_name.as_deref() != Some(bucket_name)); credentials.push(credential.clone()); validate_credentials(&credentials)?; self.persist_credentials(&credentials).await?; *self.credentials.write().await = credentials; Ok(credential) } pub async fn remove_bucket_tenant_credentials( &self, bucket_name: &str, access_key_id: Option<&str>, ) -> Result { validate_bucket_scope(bucket_name)?; let mut credentials = self.credentials.read().await.clone(); let before = credentials.len(); credentials.retain(|credential| { if credential.bucket_name.as_deref() != Some(bucket_name) { return true; } if let Some(access_key_id) = access_key_id { credential.access_key_id != access_key_id } else { false } }); let removed = before.saturating_sub(credentials.len()); if credentials.is_empty() { return Err(StorageError::invalid_request( "Cannot remove the last active credential.", )); } self.persist_credentials(&credentials).await?; *self.credentials.write().await = credentials; Ok(removed) } pub async fn list_bucket_tenants(&self) -> Vec { let mut tenants: Vec = self .credentials .read() .await .iter() .filter_map(|credential| { credential .bucket_name .as_ref() .map(|bucket_name| BucketTenantMetadata { bucket_name: bucket_name.clone(), access_key_id: credential.access_key_id.clone(), region: credential.region.clone(), }) }) .collect(); tenants.sort_by(|a, b| { a.bucket_name .cmp(&b.bucket_name) .then_with(|| a.access_key_id.cmp(&b.access_key_id)) }); tenants } pub async fn get_bucket_tenant_credential(&self, bucket_name: &str) -> Option { self.credentials .read() .await .iter() .find(|credential| credential.bucket_name.as_deref() == Some(bucket_name)) .cloned() } async fn persist_credentials(&self, credentials: &[Credential]) -> Result<(), StorageError> { let Some(path) = self.persistence_path.as_ref() else { return Ok(()); }; if let Some(parent) = path.parent() { fs::create_dir_all(parent) .await .map_err(|error| StorageError::internal_error(&error.to_string()))?; } let temp_path = path.with_extension("json.tmp"); let json = serde_json::to_string_pretty(credentials) .map_err(|error| StorageError::internal_error(&error.to_string()))?; fs::write(&temp_path, json) .await .map_err(|error| StorageError::internal_error(&error.to_string()))?; fs::rename(&temp_path, path) .await .map_err(|error| StorageError::internal_error(&error.to_string()))?; Ok(()) } } fn validate_bucket_scope(bucket_name: &str) -> Result<(), StorageError> { if bucket_name.trim().is_empty() { return Err(StorageError::invalid_request( "Bucket tenant bucketName must not be empty.", )); } Ok(()) } fn validate_credentials(credentials: &[Credential]) -> Result<(), StorageError> { if credentials.is_empty() { return Err(StorageError::invalid_request( "Credential replacement requires at least one credential.", )); } let mut seen_access_keys = HashSet::new(); for credential in credentials { if credential.access_key_id.trim().is_empty() { return Err(StorageError::invalid_request( "Credential accessKeyId must not be empty.", )); } if credential.secret_access_key.trim().is_empty() { return Err(StorageError::invalid_request( "Credential secretAccessKey must not be empty.", )); } if !seen_access_keys.insert(credential.access_key_id.as_str()) { return Err(StorageError::invalid_request( "Credential accessKeyId values must be unique.", )); } } Ok(()) } /// Check clock skew (15 minutes max). fn check_clock_skew(amz_date: &str) -> Result<(), StorageError> { // Parse ISO 8601 basic format: YYYYMMDDTHHMMSSZ let parsed = chrono::NaiveDateTime::parse_from_str(amz_date, "%Y%m%dT%H%M%SZ") .map_err(|_| StorageError::authorization_header_malformed())?; let request_time = chrono::DateTime::::from_naive_utc_and_offset(parsed, chrono::Utc); let now = chrono::Utc::now(); let diff = (now - request_time).num_seconds().unsigned_abs(); if diff > 15 * 60 { return Err(StorageError::request_time_too_skewed()); } Ok(()) } /// Build the canonical request string. fn build_canonical_request( req: &Request, signed_headers: &[String], payload_hash: &str, ) -> String { let method = req.method().as_str(); let uri_path = req.uri().path(); // Canonical URI: the path, already percent-encoded by the client let canonical_uri = if uri_path.is_empty() { "/" } else { uri_path }; // Canonical query string: sorted key=value pairs let canonical_query = build_canonical_query(req.uri().query().unwrap_or("")); // Canonical headers: sorted by lowercase header name let canonical_headers = build_canonical_headers(req, signed_headers); // Signed headers string let signed_headers_str = signed_headers.join(";"); // Payload hash — accept UNSIGNED-PAYLOAD and STREAMING-AWS4-HMAC-SHA256-PAYLOAD as-is let effective_payload_hash = if payload_hash == "UNSIGNED-PAYLOAD" || payload_hash == "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" { payload_hash.to_string() } else { payload_hash.to_string() }; format!( "{}\n{}\n{}\n{}\n{}\n{}", method, canonical_uri, canonical_query, canonical_headers, signed_headers_str, effective_payload_hash ) } /// Build canonical query string (sorted key=value pairs). fn build_canonical_query(query: &str) -> String { if query.is_empty() { return String::new(); } let mut pairs: Vec<(String, String)> = Vec::new(); for pair in query.split('&') { let mut parts = pair.splitn(2, '='); let key = parts.next().unwrap_or(""); let value = parts.next().unwrap_or(""); pairs.push((key.to_string(), value.to_string())); } pairs.sort(); pairs .iter() .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join("&") } /// Build canonical headers string. fn build_canonical_headers(req: &Request, signed_headers: &[String]) -> String { let mut header_map: HashMap> = HashMap::new(); for (name, value) in req.headers() { let name_lower = name.as_str().to_lowercase(); if signed_headers.contains(&name_lower) { if let Ok(val) = value.to_str() { header_map .entry(name_lower) .or_default() .push(val.trim().to_string()); } } } let mut result = String::new(); for header_name in signed_headers { let values = header_map .get(header_name) .map(|v| v.join(",")) .unwrap_or_default(); result.push_str(header_name); result.push(':'); result.push_str(&values); result.push('\n'); } result } /// Derive the signing key via 4-step HMAC chain. fn derive_signing_key(secret_key: &str, date_stamp: &str, region: &str) -> Vec { let k_secret = format!("AWS4{}", secret_key); let k_date = hmac_sha256(k_secret.as_bytes(), date_stamp.as_bytes()); let k_region = hmac_sha256(&k_date, region.as_bytes()); let k_service = hmac_sha256(&k_region, b"s3"); hmac_sha256(&k_service, b"aws4_request") } /// Compute HMAC-SHA256. fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec { let mut mac = HmacSha256::new_from_slice(key).expect("HMAC key length is always valid"); mac.update(data); mac.finalize().into_bytes().to_vec() } /// Constant-time byte comparison. fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { if a.len() != b.len() { return false; } let mut diff = 0u8; for (x, y) in a.iter().zip(b.iter()) { diff |= x ^ y; } diff == 0 }