feat(auth,policy): add AWS SigV4 authentication and S3 bucket policy support
This commit is contained in:
310
rust/src/auth.rs
Normal file
310
rust/src/auth.rs
Normal file
@@ -0,0 +1,310 @@
|
||||
use hmac::{Hmac, Mac};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::Request;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::config::{Credential, S3Config};
|
||||
use crate::s3_error::S3Error;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// The identity of an authenticated caller.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthenticatedIdentity {
|
||||
pub access_key_id: String,
|
||||
}
|
||||
|
||||
/// Parsed components of an AWS4-HMAC-SHA256 Authorization header.
|
||||
struct SigV4Header {
|
||||
access_key_id: String,
|
||||
date_stamp: String,
|
||||
region: String,
|
||||
signed_headers: Vec<String>,
|
||||
signature: String,
|
||||
}
|
||||
|
||||
/// Verify the request's SigV4 signature. Returns the caller identity on success.
|
||||
pub fn verify_request(
|
||||
req: &Request<Incoming>,
|
||||
config: &S3Config,
|
||||
) -> Result<AuthenticatedIdentity, S3Error> {
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
// Reject SigV2
|
||||
if auth_header.starts_with("AWS ") {
|
||||
return Err(S3Error::authorization_header_malformed());
|
||||
}
|
||||
|
||||
if !auth_header.starts_with("AWS4-HMAC-SHA256") {
|
||||
return Err(S3Error::authorization_header_malformed());
|
||||
}
|
||||
|
||||
let parsed = parse_auth_header(auth_header)?;
|
||||
|
||||
// Look up credential
|
||||
let credential = find_credential(&parsed.access_key_id, config)
|
||||
.ok_or_else(S3Error::invalid_access_key_id)?;
|
||||
|
||||
// Get x-amz-date
|
||||
let amz_date = req
|
||||
.headers()
|
||||
.get("x-amz-date")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.or_else(|| {
|
||||
req.headers()
|
||||
.get("date")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
})
|
||||
.ok_or_else(|| S3Error::missing_security_header("Missing x-amz-date header"))?;
|
||||
|
||||
// Enforce 15-min clock skew
|
||||
check_clock_skew(amz_date)?;
|
||||
|
||||
// Get payload hash
|
||||
let content_sha256 = req
|
||||
.headers()
|
||||
.get("x-amz-content-sha256")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("UNSIGNED-PAYLOAD");
|
||||
|
||||
// Build canonical request
|
||||
let canonical_request = build_canonical_request(req, &parsed.signed_headers, content_sha256);
|
||||
|
||||
// Build string to sign
|
||||
let scope = format!(
|
||||
"{}/{}/s3/aws4_request",
|
||||
parsed.date_stamp, parsed.region
|
||||
);
|
||||
let canonical_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
|
||||
let string_to_sign = format!(
|
||||
"AWS4-HMAC-SHA256\n{}\n{}\n{}",
|
||||
amz_date, scope, canonical_hash
|
||||
);
|
||||
|
||||
// Derive signing key
|
||||
let signing_key = derive_signing_key(
|
||||
&credential.secret_access_key,
|
||||
&parsed.date_stamp,
|
||||
&parsed.region,
|
||||
);
|
||||
|
||||
// Compute signature
|
||||
let computed = hmac_sha256(&signing_key, string_to_sign.as_bytes());
|
||||
let computed_hex = hex::encode(&computed);
|
||||
|
||||
// Constant-time comparison
|
||||
if !constant_time_eq(computed_hex.as_bytes(), parsed.signature.as_bytes()) {
|
||||
return Err(S3Error::signature_does_not_match());
|
||||
}
|
||||
|
||||
Ok(AuthenticatedIdentity {
|
||||
access_key_id: parsed.access_key_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse the Authorization header into its components.
|
||||
fn parse_auth_header(header: &str) -> Result<SigV4Header, S3Error> {
|
||||
// Format: AWS4-HMAC-SHA256 Credential=KEY/YYYYMMDD/region/s3/aws4_request, SignedHeaders=h1;h2, Signature=hex
|
||||
let after_algo = header
|
||||
.strip_prefix("AWS4-HMAC-SHA256")
|
||||
.ok_or_else(S3Error::authorization_header_malformed)?
|
||||
.trim();
|
||||
|
||||
let mut credential_str = None;
|
||||
let mut signed_headers_str = None;
|
||||
let mut signature_str = None;
|
||||
|
||||
for part in after_algo.split(',') {
|
||||
let part = part.trim();
|
||||
if let Some(val) = part.strip_prefix("Credential=") {
|
||||
credential_str = Some(val.trim());
|
||||
} else if let Some(val) = part.strip_prefix("SignedHeaders=") {
|
||||
signed_headers_str = Some(val.trim());
|
||||
} else if let Some(val) = part.strip_prefix("Signature=") {
|
||||
signature_str = Some(val.trim());
|
||||
}
|
||||
}
|
||||
|
||||
let credential_str = credential_str
|
||||
.ok_or_else(S3Error::authorization_header_malformed)?;
|
||||
let signed_headers_str = signed_headers_str
|
||||
.ok_or_else(S3Error::authorization_header_malformed)?;
|
||||
let signature = signature_str
|
||||
.ok_or_else(S3Error::authorization_header_malformed)?
|
||||
.to_string();
|
||||
|
||||
// Parse credential: KEY/YYYYMMDD/region/s3/aws4_request
|
||||
let cred_parts: Vec<&str> = credential_str.splitn(5, '/').collect();
|
||||
if cred_parts.len() < 5 {
|
||||
return Err(S3Error::authorization_header_malformed());
|
||||
}
|
||||
|
||||
let access_key_id = cred_parts[0].to_string();
|
||||
let date_stamp = cred_parts[1].to_string();
|
||||
let region = cred_parts[2].to_string();
|
||||
|
||||
let signed_headers: Vec<String> = signed_headers_str
|
||||
.split(';')
|
||||
.map(|s| s.trim().to_lowercase())
|
||||
.collect();
|
||||
|
||||
Ok(SigV4Header {
|
||||
access_key_id,
|
||||
date_stamp,
|
||||
region,
|
||||
signed_headers,
|
||||
signature,
|
||||
})
|
||||
}
|
||||
|
||||
/// Find a credential by access key ID.
|
||||
fn find_credential<'a>(access_key_id: &str, config: &'a S3Config) -> Option<&'a Credential> {
|
||||
config
|
||||
.auth
|
||||
.credentials
|
||||
.iter()
|
||||
.find(|c| c.access_key_id == access_key_id)
|
||||
}
|
||||
|
||||
/// Check clock skew (15 minutes max).
|
||||
fn check_clock_skew(amz_date: &str) -> Result<(), S3Error> {
|
||||
// Parse ISO 8601 basic format: YYYYMMDDTHHMMSSZ
|
||||
let parsed = chrono::NaiveDateTime::parse_from_str(amz_date, "%Y%m%dT%H%M%SZ")
|
||||
.map_err(|_| S3Error::authorization_header_malformed())?;
|
||||
|
||||
let request_time = chrono::DateTime::<chrono::Utc>::from_naive_utc_and_offset(parsed, chrono::Utc);
|
||||
let now = chrono::Utc::now();
|
||||
let diff = (now - request_time).num_seconds().unsigned_abs();
|
||||
|
||||
if diff > 15 * 60 {
|
||||
return Err(S3Error::request_time_too_skewed());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the canonical request string.
|
||||
fn build_canonical_request(
|
||||
req: &Request<Incoming>,
|
||||
signed_headers: &[String],
|
||||
payload_hash: &str,
|
||||
) -> String {
|
||||
let method = req.method().as_str();
|
||||
let uri_path = req.uri().path();
|
||||
|
||||
// Canonical URI: the path, already percent-encoded by the client
|
||||
let canonical_uri = if uri_path.is_empty() { "/" } else { uri_path };
|
||||
|
||||
// Canonical query string: sorted key=value pairs
|
||||
let canonical_query = build_canonical_query(req.uri().query().unwrap_or(""));
|
||||
|
||||
// Canonical headers: sorted by lowercase header name
|
||||
let canonical_headers = build_canonical_headers(req, signed_headers);
|
||||
|
||||
// Signed headers string
|
||||
let signed_headers_str = signed_headers.join(";");
|
||||
|
||||
// Payload hash — accept UNSIGNED-PAYLOAD and STREAMING-AWS4-HMAC-SHA256-PAYLOAD as-is
|
||||
let effective_payload_hash = if payload_hash == "UNSIGNED-PAYLOAD"
|
||||
|| payload_hash == "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"
|
||||
{
|
||||
payload_hash.to_string()
|
||||
} else {
|
||||
payload_hash.to_string()
|
||||
};
|
||||
|
||||
format!(
|
||||
"{}\n{}\n{}\n{}\n{}\n{}",
|
||||
method,
|
||||
canonical_uri,
|
||||
canonical_query,
|
||||
canonical_headers,
|
||||
signed_headers_str,
|
||||
effective_payload_hash
|
||||
)
|
||||
}
|
||||
|
||||
/// Build canonical query string (sorted key=value pairs).
|
||||
fn build_canonical_query(query: &str) -> String {
|
||||
if query.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut pairs: Vec<(String, String)> = Vec::new();
|
||||
for pair in query.split('&') {
|
||||
let mut parts = pair.splitn(2, '=');
|
||||
let key = parts.next().unwrap_or("");
|
||||
let value = parts.next().unwrap_or("");
|
||||
pairs.push((key.to_string(), value.to_string()));
|
||||
}
|
||||
pairs.sort();
|
||||
|
||||
pairs
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, v))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&")
|
||||
}
|
||||
|
||||
/// Build canonical headers string.
|
||||
fn build_canonical_headers(req: &Request<Incoming>, signed_headers: &[String]) -> String {
|
||||
let mut header_map: HashMap<String, Vec<String>> = HashMap::new();
|
||||
|
||||
for (name, value) in req.headers() {
|
||||
let name_lower = name.as_str().to_lowercase();
|
||||
if signed_headers.contains(&name_lower) {
|
||||
if let Ok(val) = value.to_str() {
|
||||
header_map
|
||||
.entry(name_lower)
|
||||
.or_default()
|
||||
.push(val.trim().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
for header_name in signed_headers {
|
||||
let values = header_map
|
||||
.get(header_name)
|
||||
.map(|v| v.join(","))
|
||||
.unwrap_or_default();
|
||||
result.push_str(header_name);
|
||||
result.push(':');
|
||||
result.push_str(&values);
|
||||
result.push('\n');
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Derive the signing key via 4-step HMAC chain.
|
||||
fn derive_signing_key(secret_key: &str, date_stamp: &str, region: &str) -> Vec<u8> {
|
||||
let k_secret = format!("AWS4{}", secret_key);
|
||||
let k_date = hmac_sha256(k_secret.as_bytes(), date_stamp.as_bytes());
|
||||
let k_region = hmac_sha256(&k_date, region.as_bytes());
|
||||
let k_service = hmac_sha256(&k_region, b"s3");
|
||||
hmac_sha256(&k_service, b"aws4_request")
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256.
|
||||
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC key length is always valid");
|
||||
mac.update(data);
|
||||
mac.finalize().into_bytes().to_vec()
|
||||
}
|
||||
|
||||
/// Constant-time byte comparison.
|
||||
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
let mut diff = 0u8;
|
||||
for (x, y) in a.iter().zip(b.iter()) {
|
||||
diff |= x ^ y;
|
||||
}
|
||||
diff == 0
|
||||
}
|
||||
Reference in New Issue
Block a user