use flate2::Compression;
use flate2::read::{GzDecoder, GzEncoder};
use std::io::Read;

use crate::error::ArchiveError;

/// IDX entry flag bit marking a gzip-compressed payload (bit 1).
const FLAG_GZIP: u32 = 0x02;
/// IDX entry flag bit marking a zstd-compressed payload (bit 2).
const FLAG_ZSTD: u32 = 0x04;

/// Supported compression algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionAlgorithm {
    Gzip,
    Zstd,
}

impl CompressionAlgorithm {
    /// Parse an algorithm name.
    ///
    /// Only the exact (case-sensitive) string `"zstd"` selects
    /// [`CompressionAlgorithm::Zstd`]. Every other input — including `"gzip"`, an empty
    /// string, or an unrecognised name — falls back to [`CompressionAlgorithm::Gzip`].
    /// This never fails.
    // Kept as an infallible inherent method so existing callers keep receiving `Self`;
    // implementing `std::str::FromStr` would force a `Result` return type.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        match s {
            "zstd" => Self::Zstd,
            _ => Self::Gzip,
        }
    }

    /// Map to the flag bits stored in IDX entries (bits 1-2).
    pub fn to_flags(self) -> u32 {
        match self {
            Self::Gzip => FLAG_GZIP,
            Self::Zstd => FLAG_ZSTD,
        }
    }

    /// Determine the algorithm from IDX entry flags.
    ///
    /// The zstd bit takes precedence. If it is set, the result is
    /// [`CompressionAlgorithm::Zstd`] regardless of any other bits. Otherwise the result is
    /// [`CompressionAlgorithm::Gzip`], including when neither compression bit is set.
    pub fn from_flags(flags: u32) -> Self {
        if flags & FLAG_ZSTD != 0 {
            Self::Zstd
        } else {
            Self::Gzip
        }
    }
}

/// Compress `data` with the specified algorithm.
///
/// # Errors
///
/// Returns [`ArchiveError::Io`] if the underlying encoder reports an I/O error.
pub fn compress(data: &[u8], algo: CompressionAlgorithm) -> Result<Vec<u8>, ArchiveError> {
    match algo {
        CompressionAlgorithm::Gzip => compress_gzip(data),
        CompressionAlgorithm::Zstd => compress_zstd(data),
    }
}

/// Decompress `data` that was compressed with the specified algorithm.
///
/// # Errors
///
/// Returns [`ArchiveError::Io`] if the input is not a valid stream for `algo`, or if the
/// decoder otherwise reports an I/O error.
pub fn decompress(data: &[u8], algo: CompressionAlgorithm) -> Result<Vec<u8>, ArchiveError> {
    match algo {
        CompressionAlgorithm::Gzip => decompress_gzip(data),
        CompressionAlgorithm::Zstd => decompress_zstd(data),
    }
}

/// Decompress data, choosing the algorithm from IDX entry flags
/// (see [`CompressionAlgorithm::from_flags`]).
pub fn decompress_by_flags(data: &[u8], flags: u32) -> Result, ArchiveError> { decompress(data, CompressionAlgorithm::from_flags(flags)) } // ==================== Gzip ==================== fn compress_gzip(data: &[u8]) -> Result, ArchiveError> { let mut encoder = GzEncoder::new(data, Compression::default()); let mut compressed = Vec::new(); encoder.read_to_end(&mut compressed) .map_err(|e| ArchiveError::Io(e))?; Ok(compressed) } fn decompress_gzip(data: &[u8]) -> Result, ArchiveError> { let mut decoder = GzDecoder::new(data); let mut decompressed = Vec::new(); decoder.read_to_end(&mut decompressed) .map_err(|e| ArchiveError::Io(e))?; Ok(decompressed) } // ==================== Zstd ==================== fn compress_zstd(data: &[u8]) -> Result, ArchiveError> { zstd::encode_all(data, 3) // level 3 = good balance of speed/ratio .map_err(|e| ArchiveError::Io(e)) } fn decompress_zstd(data: &[u8]) -> Result, ArchiveError> { zstd::decode_all(data) .map_err(|e| ArchiveError::Io(e)) } #[cfg(test)] mod tests { use super::*; #[test] fn test_gzip_roundtrip() { let data = b"Hello, this is test data for compression!"; let compressed = compress(data, CompressionAlgorithm::Gzip).unwrap(); let decompressed = decompress(&compressed, CompressionAlgorithm::Gzip).unwrap(); assert_eq!(data.as_slice(), decompressed.as_slice()); } #[test] fn test_zstd_roundtrip() { let data = b"Hello, this is test data for zstd compression!"; let compressed = compress(data, CompressionAlgorithm::Zstd).unwrap(); let decompressed = decompress(&compressed, CompressionAlgorithm::Zstd).unwrap(); assert_eq!(data.as_slice(), decompressed.as_slice()); } #[test] fn test_compression_reduces_size() { let data = vec![b'A'; 10000]; let gzip = compress(&data, CompressionAlgorithm::Gzip).unwrap(); let zstd = compress(&data, CompressionAlgorithm::Zstd).unwrap(); assert!(gzip.len() < data.len()); assert!(zstd.len() < data.len()); } #[test] fn test_decompress_by_flags() { let data = b"flag-based decompression test"; let 
gzip_compressed = compress(data, CompressionAlgorithm::Gzip).unwrap(); let result = decompress_by_flags(&gzip_compressed, 0x02).unwrap(); assert_eq!(data.as_slice(), result.as_slice()); let zstd_compressed = compress(data, CompressionAlgorithm::Zstd).unwrap(); let result = decompress_by_flags(&zstd_compressed, 0x04).unwrap(); assert_eq!(data.as_slice(), result.as_slice()); } }