Files
containerarchive/rust/src/compression.rs

132 lines
4.1 KiB
Rust

use flate2::Compression;
use flate2::read::{GzDecoder, GzEncoder};
use std::io::Read;
use crate::error::ArchiveError;
/// Supported compression algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionAlgorithm {
    /// gzip via `flate2`; recorded in IDX entries as flag `0x02`.
    Gzip,
    /// Zstandard via the `zstd` crate; recorded in IDX entries as flag `0x04`.
    Zstd,
}

impl CompressionAlgorithm {
    /// IDX entry flag bit marking gzip-compressed data (bit 1).
    pub const FLAG_GZIP: u32 = 0x02;
    /// IDX entry flag bit marking zstd-compressed data (bit 2).
    pub const FLAG_ZSTD: u32 = 0x04;

    /// Parse an algorithm name.
    ///
    /// `"zstd"` selects [`CompressionAlgorithm::Zstd`]. The match ignores ASCII case
    /// and surrounding whitespace. Every other input, including `"gzip"` and
    /// unrecognised names, falls back to [`CompressionAlgorithm::Gzip`].
    // Deliberately infallible (unknown names fall back to gzip), so this cannot be
    // `std::str::FromStr` without changing the signature existing callers rely on.
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub fn from_str(s: &str) -> Self {
        if s.trim().eq_ignore_ascii_case("zstd") {
            Self::Zstd
        } else {
            Self::Gzip
        }
    }

    /// Map to the flag bits stored in IDX entries (bits 1-2).
    #[must_use]
    pub fn to_flags(self) -> u32 {
        match self {
            Self::Gzip => Self::FLAG_GZIP,
            Self::Zstd => Self::FLAG_ZSTD,
        }
    }

    /// Determine the algorithm from IDX entry flags.
    ///
    /// The zstd bit takes precedence. Any flags without it are treated as gzip,
    /// including flags with no compression bit set at all.
    #[must_use]
    pub fn from_flags(flags: u32) -> Self {
        if flags & Self::FLAG_ZSTD != 0 {
            Self::Zstd
        } else {
            Self::Gzip
        }
    }
}
/// Compress data with the specified algorithm.
pub fn compress(data: &[u8], algo: CompressionAlgorithm) -> Result<Vec<u8>, ArchiveError> {
match algo {
CompressionAlgorithm::Gzip => compress_gzip(data),
CompressionAlgorithm::Zstd => compress_zstd(data),
}
}
/// Decompress data with the specified algorithm.
pub fn decompress(data: &[u8], algo: CompressionAlgorithm) -> Result<Vec<u8>, ArchiveError> {
match algo {
CompressionAlgorithm::Gzip => decompress_gzip(data),
CompressionAlgorithm::Zstd => decompress_zstd(data),
}
}
/// Decompress `data`, choosing the algorithm from IDX entry `flags`.
///
/// The flag bits are interpreted by [`CompressionAlgorithm::from_flags`].
///
/// # Errors
///
/// Same as [`decompress`].
pub fn decompress_by_flags(data: &[u8], flags: u32) -> Result<Vec<u8>, ArchiveError> {
    let algo = CompressionAlgorithm::from_flags(flags);
    decompress(data, algo)
}
// ==================== Gzip ====================
/// Gzip-compress `data` at `flate2`'s default compression level.
///
/// I/O errors from the encoder are wrapped in [`ArchiveError::Io`].
fn compress_gzip(data: &[u8]) -> Result<Vec<u8>, ArchiveError> {
    // The read-side encoder pulls from `data` and yields the compressed stream,
    // so the whole output is collected with a single `read_to_end`.
    let mut encoder = GzEncoder::new(data, Compression::default());
    let mut compressed = Vec::new();
    encoder.read_to_end(&mut compressed).map_err(ArchiveError::Io)?;
    Ok(compressed)
}
/// Decompress a gzip stream in `data`.
///
/// I/O errors from the decoder (e.g. an invalid gzip header) are wrapped in
/// [`ArchiveError::Io`].
fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>, ArchiveError> {
    // NOTE(review): `flate2::read::GzDecoder` decodes only the first gzip member;
    // any bytes after it are not read. `compress_gzip` writes a single member, so
    // this matches; switch to `MultiGzDecoder` if concatenated members must be read.
    let mut decoder = GzDecoder::new(data);
    let mut decompressed = Vec::new();
    decoder.read_to_end(&mut decompressed).map_err(ArchiveError::Io)?;
    Ok(decompressed)
}
// ==================== Zstd ====================
/// Zstd-compress `data` at compression level 3.
///
/// I/O errors from the encoder are wrapped in [`ArchiveError::Io`].
fn compress_zstd(data: &[u8]) -> Result<Vec<u8>, ArchiveError> {
    /// Level 3 gives a good balance of speed and compression ratio.
    const LEVEL: i32 = 3;
    zstd::encode_all(data, LEVEL).map_err(ArchiveError::Io)
}
/// Decompress zstd-encoded `data` into a new buffer.
///
/// I/O errors from the decoder (e.g. input that is not zstd data) are wrapped in
/// [`ArchiveError::Io`].
fn decompress_zstd(data: &[u8]) -> Result<Vec<u8>, ArchiveError> {
    zstd::decode_all(data).map_err(ArchiveError::Io)
}
#[cfg(test)]
mod tests {
    use super::*;

    const ALL: [CompressionAlgorithm; 2] = [CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd];

    #[test]
    fn test_gzip_roundtrip() {
        let data = b"Hello, this is test data for compression!";
        let compressed = compress(data, CompressionAlgorithm::Gzip).unwrap();
        let decompressed = decompress(&compressed, CompressionAlgorithm::Gzip).unwrap();
        assert_eq!(data.as_slice(), decompressed.as_slice());
    }

    #[test]
    fn test_zstd_roundtrip() {
        let data = b"Hello, this is test data for zstd compression!";
        let compressed = compress(data, CompressionAlgorithm::Zstd).unwrap();
        let decompressed = decompress(&compressed, CompressionAlgorithm::Zstd).unwrap();
        assert_eq!(data.as_slice(), decompressed.as_slice());
    }

    #[test]
    fn test_empty_input_roundtrip() {
        for algo in ALL {
            let compressed = compress(&[], algo).unwrap();
            let decompressed = decompress(&compressed, algo).unwrap();
            assert!(decompressed.is_empty(), "{:?}: empty input must round-trip", algo);
        }
    }

    #[test]
    fn test_compression_reduces_size() {
        let data = vec![b'A'; 10000];
        let gzip = compress(&data, CompressionAlgorithm::Gzip).unwrap();
        let zstd = compress(&data, CompressionAlgorithm::Zstd).unwrap();
        assert!(gzip.len() < data.len());
        assert!(zstd.len() < data.len());
    }

    #[test]
    fn test_invalid_input_is_error() {
        let garbage = b"definitely not a compressed stream";
        for algo in ALL {
            assert!(decompress(garbage, algo).is_err(), "{:?} accepted garbage", algo);
        }
    }

    #[test]
    fn test_decompress_by_flags() {
        let data = b"flag-based decompression test";
        let gzip_compressed = compress(data, CompressionAlgorithm::Gzip).unwrap();
        let result = decompress_by_flags(&gzip_compressed, 0x02).unwrap();
        assert_eq!(data.as_slice(), result.as_slice());
        let zstd_compressed = compress(data, CompressionAlgorithm::Zstd).unwrap();
        let result = decompress_by_flags(&zstd_compressed, 0x04).unwrap();
        assert_eq!(data.as_slice(), result.as_slice());
    }

    #[test]
    fn test_from_str() {
        assert_eq!(CompressionAlgorithm::from_str("zstd"), CompressionAlgorithm::Zstd);
        assert_eq!(CompressionAlgorithm::from_str("gzip"), CompressionAlgorithm::Gzip);
        // Unknown names fall back to gzip.
        assert_eq!(CompressionAlgorithm::from_str("lz4"), CompressionAlgorithm::Gzip);
        assert_eq!(CompressionAlgorithm::from_str(""), CompressionAlgorithm::Gzip);
    }

    #[test]
    fn test_flags_roundtrip() {
        assert_eq!(CompressionAlgorithm::Gzip.to_flags(), 0x02);
        assert_eq!(CompressionAlgorithm::Zstd.to_flags(), 0x04);
        for algo in ALL {
            assert_eq!(CompressionAlgorithm::from_flags(algo.to_flags()), algo);
        }
        // No compression bit set -> gzip; the zstd bit wins when both are set.
        assert_eq!(CompressionAlgorithm::from_flags(0), CompressionAlgorithm::Gzip);
        assert_eq!(CompressionAlgorithm::from_flags(0x06), CompressionAlgorithm::Zstd);
    }
}