diff --git a/Changelog.md b/Changelog.md index 38bbc122..82e753df 100644 --- a/Changelog.md +++ b/Changelog.md @@ -33,4 +33,4 @@ This document records the changes made between versions, starting with version 0 * Added convenience functions to FrameDecoder to decode multiple frames from a buffer (https://github.com/philipc) # After 0.7.3 - +* Add initial compression support diff --git a/Readme.md b/Readme.md index 3a96963b..92d18138 100644 --- a/Readme.md +++ b/Readme.md @@ -9,8 +9,7 @@ A pure Rust implementation of the Zstandard compression algorithm, as defined in [this document](https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md). This crate contains a fully operational implementation of the decompression portion of the standard. - -*Work has started on a compressor, but it has not reached a point where the compressor provides any real function.* (CONTRIBUTORS WELCOME) +It also provides a compressor which is usable, but it does not yet reach the speed, ratio or configurability of the original zstd library. This crate is currently actively maintained. @@ -19,9 +18,14 @@ This crate is currently actively maintained. Feature complete on the decoder side. In terms of speed it is still behind the original C implementation which has a rust binding located [here](https://github.com/gyscos/zstd-rs). 
On the compression side: -- [x] Support for generating raw, uncompressed frames -- [ ] Support for generating RLE compressed blocks -- [ ] Support for generating compressed blocks at any compression level +- Support for generating compressed blocks at any compression level + - [x] Uncompressed + - [x] Fastest (roughly level 1) + - [ ] Default (roughly level 3) + - [ ] Better (roughly level 7) + - [ ] Best (roughly level 11) +- [ ] Checksums +- [ ] Dictionaries ## Speed diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 0d40fcff..634c1981 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -31,3 +31,11 @@ path = "fuzz_targets/encode.rs" [[bin]] name = "interop" path = "fuzz_targets/interop.rs" + +[[bin]] +name = "huff0" +path = "fuzz_targets/huff0.rs" + +[[bin]] +name = "fse" +path = "fuzz_targets/fse.rs" diff --git a/fuzz/artifacts/fse/crash-16fdc285684fe17e4a84ff6605c7f0e362af3dfa b/fuzz/artifacts/fse/crash-16fdc285684fe17e4a84ff6605c7f0e362af3dfa new file mode 100644 index 00000000..ff6325a6 Binary files /dev/null and b/fuzz/artifacts/fse/crash-16fdc285684fe17e4a84ff6605c7f0e362af3dfa differ diff --git a/fuzz/artifacts/fse/crash-da39a3ee5e6b4b0d3255bfef95601890afd80709 b/fuzz/artifacts/fse/crash-da39a3ee5e6b4b0d3255bfef95601890afd80709 new file mode 100644 index 00000000..e69de29b diff --git a/fuzz/artifacts/fse/crash-e587fc04ebe1b7e97d0aa916ef8d3f2cc92fb4b1 b/fuzz/artifacts/fse/crash-e587fc04ebe1b7e97d0aa916ef8d3f2cc92fb4b1 new file mode 100644 index 00000000..7bb4d16e Binary files /dev/null and b/fuzz/artifacts/fse/crash-e587fc04ebe1b7e97d0aa916ef8d3f2cc92fb4b1 differ diff --git a/fuzz/artifacts/huff0/crash-da39a3ee5e6b4b0d3255bfef95601890afd80709 b/fuzz/artifacts/huff0/crash-da39a3ee5e6b4b0d3255bfef95601890afd80709 new file mode 100644 index 00000000..e69de29b diff --git a/fuzz/artifacts/huff0/crash-e7d75b9bfbab3e8e4df53bb28b87a1a01ee99d3d b/fuzz/artifacts/huff0/crash-e7d75b9bfbab3e8e4df53bb28b87a1a01ee99d3d new file mode 100644 index 00000000..dcf90392 
Binary files /dev/null and b/fuzz/artifacts/huff0/crash-e7d75b9bfbab3e8e4df53bb28b87a1a01ee99d3d differ diff --git a/fuzz/artifacts/interop/crash-5ba93c9db0cff93f52b521d7420e43f6eda2784f b/fuzz/artifacts/interop/crash-5ba93c9db0cff93f52b521d7420e43f6eda2784f new file mode 100644 index 00000000..f76dd238 Binary files /dev/null and b/fuzz/artifacts/interop/crash-5ba93c9db0cff93f52b521d7420e43f6eda2784f differ diff --git a/fuzz/artifacts/interop/crash-a9f55c479d7c420764bde5bd6c666a7997d79d26 b/fuzz/artifacts/interop/crash-a9f55c479d7c420764bde5bd6c666a7997d79d26 new file mode 100644 index 00000000..c3342383 Binary files /dev/null and b/fuzz/artifacts/interop/crash-a9f55c479d7c420764bde5bd6c666a7997d79d26 differ diff --git a/fuzz/fuzz_targets/encode.rs b/fuzz/fuzz_targets/encode.rs index f6ce1f36..cbb4186d 100644 --- a/fuzz/fuzz_targets/encode.rs +++ b/fuzz/fuzz_targets/encode.rs @@ -4,8 +4,21 @@ extern crate ruzstd; use ruzstd::encoding::{FrameCompressor, CompressionLevel}; fuzz_target!(|data: &[u8]| { - let mut content = data; - let mut compressor = FrameCompressor::new(data, CompressionLevel::Uncompressed); let mut output = Vec::new(); - compressor.compress(&mut output); + let mut compressor = FrameCompressor::new(data, &mut output, CompressionLevel::Uncompressed); + compressor.compress(); + + let mut decoded = Vec::with_capacity(data.len()); + let mut decoder = ruzstd::FrameDecoder::new(); + decoder.decode_all_to_vec(&output, &mut decoded).unwrap(); + assert_eq!(data, &decoded); + + let mut output = Vec::new(); + let mut compressor = FrameCompressor::new(data, &mut output, CompressionLevel::Fastest); + compressor.compress(); + + let mut decoded = Vec::with_capacity(data.len()); + let mut decoder = ruzstd::FrameDecoder::new(); + decoder.decode_all_to_vec(&output, &mut decoded).unwrap(); + assert_eq!(data, &decoded); }); \ No newline at end of file diff --git a/fuzz/fuzz_targets/fse.rs b/fuzz/fuzz_targets/fse.rs new file mode 100644 index 00000000..43efe777 --- 
/dev/null +++ b/fuzz/fuzz_targets/fse.rs @@ -0,0 +1,8 @@ +#![no_main] +#[macro_use] extern crate libfuzzer_sys; +extern crate ruzstd; +use ruzstd::fse::round_trip; + +fuzz_target!(|data: &[u8]| { + round_trip(data); +}); \ No newline at end of file diff --git a/fuzz/fuzz_targets/huff0.rs b/fuzz/fuzz_targets/huff0.rs new file mode 100644 index 00000000..b2dc806c --- /dev/null +++ b/fuzz/fuzz_targets/huff0.rs @@ -0,0 +1,8 @@ +#![no_main] +#[macro_use] extern crate libfuzzer_sys; +extern crate ruzstd; +use ruzstd::huff0::round_trip; + +fuzz_target!(|data: &[u8]| { + round_trip(data); +}); \ No newline at end of file diff --git a/fuzz/fuzz_targets/interop.rs b/fuzz/fuzz_targets/interop.rs index c16e84fe..00d2c90f 100644 --- a/fuzz/fuzz_targets/interop.rs +++ b/fuzz/fuzz_targets/interop.rs @@ -33,10 +33,19 @@ fn encode_zstd(data: &[u8]) -> Result, std::io::Error> { fn encode_ruzstd_uncompressed(data: &mut dyn std::io::Read) -> Vec { let mut input = Vec::new(); + let mut output = Vec::new(); data.read_to_end(&mut input).unwrap(); - let mut compressor = ruzstd::encoding::FrameCompressor::new(&input, ruzstd::encoding::CompressionLevel::Uncompressed); + let mut compressor = ruzstd::encoding::FrameCompressor::new(input.as_slice(), &mut output, ruzstd::encoding::CompressionLevel::Uncompressed); + compressor.compress(); + output +} + +fn encode_ruzstd_compressed(data: &mut dyn std::io::Read) -> Vec { + let mut input = Vec::new(); let mut output = Vec::new(); - compressor.compress(&mut output); + data.read_to_end(&mut input).unwrap(); + let mut compressor = ruzstd::encoding::FrameCompressor::new(input.as_slice(), &mut output, ruzstd::encoding::CompressionLevel::Fastest); + compressor.compress(); output } @@ -69,4 +78,12 @@ fuzz_target!(|data: &[u8]| { decoded, data, "Decoded data did not match the original input during compression" ); + // Compressed encoding + let mut input = data; + let compressed = encode_ruzstd_compressed(&mut input); + let decoded = 
decode_zstd(&compressed).unwrap(); + assert_eq!( + decoded, data, + "Decoded data did not match the original input during compression" + ); }); diff --git a/src/bin/zstd.rs b/src/bin/zstd.rs index af5760f9..d6aebb7c 100644 --- a/src/bin/zstd.rs +++ b/src/bin/zstd.rs @@ -1,10 +1,14 @@ extern crate ruzstd; use std::fs::File; +use std::io::BufReader; use std::io::Read; use std::io::Seek; use std::io::SeekFrom; use std::io::Write; +use std::time::Instant; +use ruzstd::encoding::CompressionLevel; +use ruzstd::encoding::FrameCompressor; use ruzstd::frame::ReadFrameHeaderError; use ruzstd::frame_decoder::FrameDecoderError; @@ -18,11 +22,7 @@ struct StateTracker { old_percentage: i8, } -fn main() { - let mut file_paths: Vec<_> = std::env::args().filter(|f| !f.starts_with('-')).collect(); - let flags: Vec<_> = std::env::args().filter(|f| f.starts_with('-')).collect(); - file_paths.remove(0); - +fn decompress(flags: &[String], file_paths: &[String]) { if !flags.contains(&"-d".to_owned()) { eprintln!("This zstd implementation only supports decompression. 
Please add a \"-d\" flag"); return; @@ -128,6 +128,63 @@ fn main() { } } +struct PercentPrintReader { + total: usize, + counter: usize, + last_percent: usize, + reader: R, +} + +impl Read for PercentPrintReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let new_bytes = self.reader.read(buf)?; + self.counter += new_bytes; + let progress = self.counter * 100 / self.total; + if progress > self.last_percent { + self.last_percent = progress; + eprint!("\r"); + eprint!("{} % done", progress); + } + Ok(new_bytes) + } +} + +fn main() { + let mut file_paths: Vec<_> = std::env::args().filter(|f| !f.starts_with('-')).collect(); + let flags: Vec<_> = std::env::args().filter(|f| f.starts_with('-')).collect(); + file_paths.remove(0); + + if flags.is_empty() { + for path in file_paths { + let start_instant = Instant::now(); + let file = std::fs::File::open(&path).unwrap(); + let input_len = file.metadata().unwrap().len() as usize; + let file = PercentPrintReader { + reader: BufReader::new(file), + total: input_len, + counter: 0, + last_percent: 0, + }; + let mut output = Vec::new(); + let mut encoder = FrameCompressor::new(file, &mut output, CompressionLevel::Fastest); + encoder.compress(); + println!( + "Compressed {path:} from {} to {} ({}%) took {}ms", + input_len, + output.len(), + if input_len == 0 { + 0 + } else { + output.len() * 100 / input_len + }, + start_instant.elapsed().as_millis() + ); + } + } else { + decompress(&flags, &file_paths); + } +} + fn do_something(data: &[u8], s: &mut StateTracker) { //Do something. Like writing it to a file or to stdout... 
std::io::stdout().write_all(data).unwrap(); diff --git a/src/decoding/block_decoder.rs b/src/decoding/block_decoder.rs index ed4962e2..e40d0613 100644 --- a/src/decoding/block_decoder.rs +++ b/src/decoding/block_decoder.rs @@ -447,6 +447,13 @@ impl BlockDecoder { vprintln!("Executing sequences"); execute_sequences(workspace)?; } else { + if !raw.is_empty() { + return Err(DecompressBlockError::DecodeSequenceError( + DecodeSequenceError::ExtraBits { + bits_remaining: raw.len() as isize * 8, + }, + )); + } workspace.buffer.push(&workspace.literals_buffer); workspace.sequences.clear(); } diff --git a/src/decoding/decodebuffer.rs b/src/decoding/decodebuffer.rs index 5805e265..44d4a595 100644 --- a/src/decoding/decodebuffer.rs +++ b/src/decoding/decodebuffer.rs @@ -285,7 +285,7 @@ impl DecodeBuffer { amount: usize, } - impl<'a> Drop for DrainGuard<'a> { + impl Drop for DrainGuard<'_> { fn drop(&mut self) { if self.amount != 0 { self.buffer.drop_first_n(self.amount); diff --git a/src/decoding/ringbuffer.rs b/src/decoding/ringbuffer.rs index 9af2a69f..adf1439b 100644 --- a/src/decoding/ringbuffer.rs +++ b/src/decoding/ringbuffer.rs @@ -362,6 +362,7 @@ impl RingBuffer { unsafe { copy_bytes_overshooting(src, dst, len - after_tail) } } } else { + #[allow(clippy::collapsible_else_if)] if self.head + start > self.cap { // Continuous read section and destination section: // diff --git a/src/encoding/bit_writer.rs b/src/encoding/bit_writer.rs index 122b9dc9..bbb4e92f 100644 --- a/src/encoding/bit_writer.rs +++ b/src/encoding/bit_writer.rs @@ -1,223 +1,230 @@ //! Use [BitWriter] to write an arbitrary amount of bits into a buffer. -use alloc::vec; use alloc::vec::Vec; /// An interface for writing an arbitrary number of bits into a buffer. Write new bits into the buffer with `write_bits`, and /// obtain the output using `dump`. 
-pub(crate) struct BitWriter { +#[derive(Debug)] +pub(crate) struct BitWriter>> { /// The buffer that's filled with bits - output: Vec, + output: V, + /// holds a partially filled byte which gets put in outpu when it's fill with a write_bits call + partial: u64, + bits_in_partial: usize, /// The index pointing to the next unoccupied bit. Effectively just /// the number of bits that have been written into the buffer so far. bit_idx: usize, } -impl BitWriter { +impl BitWriter> { /// Initialize a new writer. pub fn new() -> Self { Self { output: Vec::new(), + partial: 0, + bits_in_partial: 0, bit_idx: 0, } } +} - /// Wrap a writer around an existing vec. - /// - /// Currently unused, but will almost certainly be used later upon further optimizing - #[allow(unused)] - pub fn from(buf: Vec) -> Self { - Self { - bit_idx: buf.len() * 8, - output: buf, +impl BitWriter<&mut Vec> {} + +impl>> BitWriter { + /// Initialize a new writer. + pub fn from(mut output: V) -> BitWriter { + BitWriter { + bit_idx: output.as_mut().len() * 8, + output, + partial: 0, + bits_in_partial: 0, } } - /// Write `num_bits` from `bits` into the writer, returning the number of bits - /// read. - /// - /// `num_bits` refers to how many bits starting from the *least significant position*, - /// but the bits will be written starting from the *most significant position*, continuing - /// to the least significant position. - /// - /// It's up to the caller to ensure that any in the cursor beyond `num_bits` is always zero. - /// If it's not, the output buffer will be corrupt. - /// - /// Refer to tests for example usage. - // TODO: Because bitwriter isn't directly public, any errors would be caused by internal library bugs, - // and so this function should just panic if it encounters issues. 
- pub fn write_bits(&mut self, bits: &[u8], num_bits: usize) -> usize { - if bits.len() * 8 < num_bits { - panic!("asked to write more bits into buffer ({}) than were provided by the `bits` buffer ({})", num_bits, bits.len() * 8); + pub fn index(&self) -> usize { + self.bit_idx + self.bits_in_partial + } + + pub fn reset_to(&mut self, index: usize) { + assert!(index % 8 == 0); + self.partial = 0; + self.bits_in_partial = 0; + self.bit_idx = index; + self.output.as_mut().resize(index / 8, 0); + } + + pub fn change_bits(&mut self, idx: usize, bits: impl Into, num_bits: usize) { + self.change_bits_64(idx, bits.into(), num_bits); + } + + pub fn change_bits_64(&mut self, mut idx: usize, mut bits: u64, mut num_bits: usize) { + self.flush(); + assert!(idx + num_bits < self.index()); + assert!(self.index() - (idx + num_bits) > self.bits_in_partial); + + if idx % 8 != 0 { + let bits_in_first_byte = 8 - (idx % 8); + assert!(bits_in_first_byte <= num_bits); + self.output.as_mut()[idx / 8] &= 0xFFu8 >> bits_in_first_byte; + let new_bits = (bits << (8 - bits_in_first_byte)) as u8; + self.output.as_mut()[idx / 8] |= new_bits; + num_bits -= bits_in_first_byte; + bits >>= bits_in_first_byte; + idx += bits_in_first_byte; } - // Special handling for if both the input and output are byte aligned - if self.bit_idx % 8 == 0 && num_bits / 8 == bits.len() { - self.output.extend_from_slice(bits); - self.bit_idx += num_bits; - return num_bits; + let mut idx = idx / 8; + + while num_bits >= 8 { + self.output.as_mut()[idx] = bits as u8; + num_bits -= 8; + bits >>= 8; + idx += 1; + } + + if num_bits > 0 { + self.output.as_mut()[idx] &= 0xFFu8 << num_bits; + self.output.as_mut()[idx] |= bits as u8; } + } - // Make sure there's space for the new bits by finding the total size of the buffer in bits, then round up to the nearest multiple of 8 - // to find how many *bytes* that would occupy. After that, expand the vec to the new size. 
- let new_size_of_output = (self.bit_idx + num_bits + 7) >> 3; - let size_of_extension = new_size_of_output - self.output.len(); - let new_chunk: Vec = vec![0; size_of_extension]; - self.output.extend(new_chunk); - - // We will never need to operate across a byte boundary in a single iteration of the loop. - let mut num_bits_written: usize = 0; - while num_bits_written < num_bits { - // The number of unoccupied bits in the output buffer - // byte that the cursor is currently indexed into - let num_bits_left_in_output_byte = 8 - (self.bit_idx % 8); - // The number of bits left to write in the currently selected input buffer byte - let mut num_bits_left_in_input_byte = (num_bits - num_bits_written) % 8; - if num_bits_left_in_input_byte == 0 { - num_bits_left_in_input_byte = 8; - } - // The byte that we're currently reading from in the input - let input_byte_index: usize = num_bits_written / 8; - let byte_index_to_update = self.bit_idx / 8; - if num_bits_left_in_output_byte >= num_bits_left_in_input_byte { - // Case 1: We read from the input until the next input byte boundary (or end of data), because - // there's more free space in the output byte then there are bits to read in the input byte. - - // In the below example, we're adding - // 0b111 to a buffer, then adding 0b000. - // Because we start from the left, to position - // 0b111 in the correct position, we want the - // leftmost bit to be at index 7, and the rightmost - // bit to be in position 5. To achieve this, you can - // shift 0b111 over 5 times. - // - // 76543210 ◄─── Bit Index - // 111◄──── Move 0b111 to the left 5 slots so that it - // occupies the leftmost space - // The formula for this would look like (8 - num_bits_added). - // Then, to write 0b000 into the buffer, we can use the same - // formula again, but we need to account for the number of bits - // already written into the buffer. This means our new formula looks - // like (8 - num_bits_added - num_bits_already_in_buffer). 
In this case - // there are 3 bits already in the buffer, and we're writing in 3 bits, - // so (8 - 3 - 3) = 2. - // - // 111◄──── Data already in buffer - // 000◄─ New data being added into the buffer - // Then to determine what the final buffer looks like, we can simply OR - // the two buffers together. - // 111───── ◄── The lines mark "Unoccupied space", so they'd just be zeros - // ───000── - // - // 111000── ◄── The final buffer - - let num_bits_already_in_byte = 8 - num_bits_left_in_output_byte; - let num_bits_being_added = (num_bits - num_bits_written) % 8; - if num_bits_left_in_input_byte == 8 && num_bits_left_in_output_byte == 8 { - // In this case, we're trying to shift all the way over to the next byte, so just update that next byte - self.output[byte_index_to_update] = bits[input_byte_index]; - num_bits_written += 8; - self.bit_idx += 8; - continue; - } - // Shift the bits left - let num_spots_to_move_left = 8 - num_bits_being_added - num_bits_already_in_byte; - // Combine it with the existing data - let aligned_byte = bits[input_byte_index] << num_spots_to_move_left; - let merged_byte = self.output[byte_index_to_update] | aligned_byte; - // Write changes to the output buffer - self.output[byte_index_to_update] = merged_byte; - - // Advance the bit cursor forwards and update - // the number of bits being added - num_bits_written += num_bits_being_added; - self.bit_idx += num_bits_being_added; - } else { - // Case 2: There's not enough free space in the output byte to read till the next input byte boundary, so we - // read to the next output byte boundary. - - // This looks like reading from input bit index onwards N bits, where N is the number of free bits available in the output byte - // - // In the below example, we've already written 3 0s into the buffer, but we want to write - // 6 1s into the buffer. - // - // 76543210◄─── Bit Index - // 111 ◄─────── Data already in buffer - // 000000◄── Data we want to add to the buffer (not yet aligned). 
- // - // You'll note that we can't do the same thing we did last time, because we have more data - // than will fit into the byte, so we need do this in multiple passes, writing data up to the boundary, - // then writing data into the next byte. Getting that final bit can happen on the next pass, using the first case, where - // we read until an input byte boundary. - // Broken down into steps, this looks something like this: - // - // ◄──00000X Because there may be arbitrary data behind the cursor in the - // input data, we need to shift left, then right, to mask out that data - // and ensure it's all zeros (so that when we OR with the output, we don't corrupt it). - // Here, I've replaced that last 0 with an X because it's in the next byte, so it's ignored - // until the next pass. The amount we shift left will depend on how far into the input byte - // the input cursor is. - // - // ──►00000X Next we move that data to the right N spaces, where N is the number of bits already occupied - // in the current byte. In the example, that would be 3. - // Our value is now masked and aligned, so we can merge it with the currently selected output byte - // and update it, then advance the output and input cursors 8 - N bits, again, where N is the amount - // of bits already occupied in the buffer. 
- - // Shift the bits left to zero out any data behind the read cursor - let num_spots_to_move_left = (8 - num_bits_left_in_input_byte) % 8; - let masked_byte = bits[input_byte_index] << num_spots_to_move_left; - // Shift the bits right so that the data is inserted into the next free spot - let aligned_byte = masked_byte >> (self.bit_idx % 8); - // // Combine our newly aligned byte with the output byte - let merged_byte = self.output[byte_index_to_update] | aligned_byte; - // Write changes to the output buffer - self.output[byte_index_to_update] = merged_byte; - // Advance the bit cursor forwards and update - // the number of bits being added - num_bits_written += num_bits_left_in_output_byte; - self.bit_idx += num_bits_left_in_output_byte; - } + pub fn append_bytes(&mut self, data: &[u8]) { + if self.misaligned() != 0 { + panic!("Don't append bytes when writer is misaligned") + } + self.flush(); + self.output.as_mut().extend_from_slice(data); + self.bit_idx += data.len() * 8; + } + + pub fn flush(&mut self) { + let full_bytes = self.bits_in_partial / 8; + self.output + .as_mut() + .extend_from_slice(&self.partial.to_le_bytes()[..full_bytes]); + self.partial >>= full_bytes * 8; + self.bits_in_partial -= full_bytes * 8; + self.bit_idx += full_bytes * 8; + } + + /// Write the lower `num_bits` from `bits` into the writer + pub fn write_bits(&mut self, bits: impl Into, num_bits: usize) { + self.write_bits_64(bits.into(), num_bits); + } + + #[cold] + fn write_bits_64_cold(&mut self, bits: u64, num_bits: usize) { + let bits_free_in_partial = 64 - self.bits_in_partial; + let part = bits << (64 - bits_free_in_partial); + let merged = self.partial | part; + self.output + .as_mut() + .extend_from_slice(&merged.to_le_bytes()); + self.bit_idx += 64; + self.partial = 0; + self.bits_in_partial = 0; + + let mut num_bits = num_bits - bits_free_in_partial; + let mut bits = bits >> bits_free_in_partial; + + while num_bits / 8 > 0 { + let byte = bits as u8; + 
self.output.as_mut().push(byte); + num_bits -= 8; + self.bit_idx += 8; + bits >>= 8; + } + + debug_assert!(num_bits < 8); + if num_bits > 0 { + let mask = (1 << num_bits) - 1; + self.partial = bits & mask; + self.bits_in_partial = num_bits; + } + } + + pub fn write_bits_64(&mut self, bits: u64, num_bits: usize) { + if num_bits == 0 { + return; + } + + if bits > 0 { + debug_assert!(bits.ilog2() <= num_bits as u32); + } + + // fill partial byte first + if num_bits + self.bits_in_partial < 64 { + let part = bits << self.bits_in_partial; + let merged = self.partial | part; + self.partial = merged; + self.bits_in_partial += num_bits; + } else { + self.write_bits_64_cold(bits, num_bits); } - num_bits_written } /// Returns the populated buffer that you've been writing bits into. /// /// This function consumes the writer, so it cannot be used after /// dumping - pub fn dump(self) -> Vec { - if self.bit_idx % 8 != 0 { - panic!("`dump` was called on a bit writer but an even number of bytes weren't written into the buffer") + pub fn dump(mut self) -> V { + if self.misaligned() != 0 { + panic!("`dump` was called on a bit writer but an even number of bytes weren't written into the buffer. 
Was: {}", self.index()) } + self.flush(); + debug_assert_eq!(self.partial, 0); self.output } + + /// Returns how many bits are missing for an even byte + pub fn misaligned(&self) -> usize { + let idx = self.index(); + if idx % 8 == 0 { + 0 + } else { + 8 - (idx % 8) + } + } } #[cfg(test)] mod tests { use super::BitWriter; - use std::vec; + use alloc::vec; #[test] fn from_existing() { // Define an existing vec, write some bits into it - let existing_vec = vec![255_u8]; - let mut bw = BitWriter::from(existing_vec); - bw.write_bits(&[0], 8); - assert_eq!(vec![255, 0], bw.dump()); + let mut existing_vec = vec![255_u8]; + let mut bw = BitWriter::from(&mut existing_vec); + bw.write_bits(0u8, 8); + bw.flush(); + assert_eq!(vec![255, 0], existing_vec); + } + + #[test] + fn change_bits() { + let mut writer = BitWriter::new(); + writer.write_bits(0u32, 24); + writer.change_bits(8, 0xFFu8, 8); + assert_eq!(vec![0, 0xFF, 0], writer.dump()); + + let mut writer = BitWriter::new(); + writer.write_bits(0u32, 24); + writer.change_bits(6, 0x0FFFu16, 12); + assert_eq!(vec![0b11000000, 0xFF, 0b00000011], writer.dump()); } #[test] fn single_byte_written_4_4() { // Write the first 4 bits as 1s and the last 4 bits as 0s // 1010 is used where values should never be read from. 
- let mut bw: BitWriter = BitWriter::new(); - bw.write_bits(&[0b010_1111], 4); - bw.write_bits(&[0b1010_0000], 4); + let mut bw = BitWriter::new(); + bw.write_bits(0b1111u8, 4); + bw.write_bits(0b0000u8, 4); let output = bw.dump(); assert!(output.len() == 1, "Single byte written into writer returned a vec that wasn't one byte, vec was {} elements long", output.len()); assert_eq!( - 0b1111_0000, output[0], + 0b0000_1111, output[0], "4 bits and 4 bits written into buffer" ); } @@ -225,30 +232,30 @@ mod tests { #[test] fn single_byte_written_3_5() { // Write the first 3 bits as 1s and the last 5 bits as 0s - let mut bw: BitWriter = BitWriter::new(); - bw.write_bits(&[0b0101_0111], 3); - bw.write_bits(&[0b1010_0000], 5); + let mut bw = BitWriter::new(); + bw.write_bits(0b111u8, 3); + bw.write_bits(0b0_0000u8, 5); let output = bw.dump(); assert!(output.len() == 1, "Single byte written into writer return a vec that wasn't one byte, vec was {} elements long", output.len()); - assert_eq!(0b1110_0000, output[0], "3 and 5 bits written into buffer"); + assert_eq!(0b0000_0111, output[0], "3 and 5 bits written into buffer"); } #[test] fn single_byte_written_1_7() { // Write the first bit as a 1 and the last 7 bits as 0s - let mut bw: BitWriter = BitWriter::new(); - bw.write_bits(&[0b1], 1); - bw.write_bits(&[0], 7); + let mut bw = BitWriter::new(); + bw.write_bits(0b1u8, 1); + bw.write_bits(0u8, 7); let output = bw.dump(); assert!(output.len() == 1, "Single byte written into writer return a vec that wasn't one byte, vec was {} elements long", output.len()); - assert_eq!(0b1000_0000, output[0], "1 and 7 bits written into buffer"); + assert_eq!(0b0000_0001, output[0], "1 and 7 bits written into buffer"); } #[test] fn single_byte_written_8() { // Write an entire byte - let mut bw: BitWriter = BitWriter::new(); - bw.write_bits(&[1], 8); + let mut bw = BitWriter::new(); + bw.write_bits(1u8, 8); let output = bw.dump(); assert!(output.len() == 1, "Single byte written into writer 
return a vec that wasn't one byte, vec was {} elements long", output.len()); assert_eq!(1, output[0], "1 and 7 bits written into buffer"); @@ -258,49 +265,49 @@ mod tests { fn multi_byte_clean_boundary_4_4_4_4() { // Writing 4 bits at a time for 2 bytes let mut bw = BitWriter::new(); - bw.write_bits(&[0], 4); - bw.write_bits(&[0b1111], 4); - bw.write_bits(&[0b1111], 4); - bw.write_bits(&[0], 4); - assert_eq!(vec![0b0000_1111, 0b1111_0000], bw.dump()); + bw.write_bits(0u8, 4); + bw.write_bits(0b1111u8, 4); + bw.write_bits(0b1111u8, 4); + bw.write_bits(0u8, 4); + assert_eq!(vec![0b1111_0000, 0b0000_1111], bw.dump()); } #[test] fn multi_byte_clean_boundary_16_8() { // Writing 16 bits at once let mut bw = BitWriter::new(); - bw.write_bits(&[1, 0], 16); - bw.write_bits(&[69], 8); - assert_eq!(vec![1, 0, 69], bw.dump()) + bw.write_bits(0x0100u16, 16); + bw.write_bits(69u8, 8); + assert_eq!(vec![0, 1, 69], bw.dump()) } #[test] fn multi_byte_boundary_crossed_4_12() { // Writing 4 1s and then 12 zeros let mut bw = BitWriter::new(); - bw.write_bits(&[0b0000_1111], 4); - bw.write_bits(&[0b0000_0000, 0b1010_0000], 12); - assert_eq!(vec![0b1111_0000, 0b0000_0000], bw.dump()); + bw.write_bits(0b1111u8, 4); + bw.write_bits(0b0000_0011_0100_0010u16, 12); + assert_eq!(vec![0b0010_1111, 0b0011_0100], bw.dump()); } #[test] fn multi_byte_boundary_crossed_4_5_7() { // Writing 4 1s and then 5 zeros then 7 1s let mut bw = BitWriter::new(); - bw.write_bits(&[0b1010_1111], 4); - bw.write_bits(&[0b1010_0000], 5); - bw.write_bits(&[0b0111_1111], 7); - assert_eq!(vec![0b1111_0000, 0b0111_1111], bw.dump()); + bw.write_bits(0b1111u8, 4); + bw.write_bits(0b0_0000u8, 5); + bw.write_bits(0b111_1111u8, 7); + assert_eq!(vec![0b0000_1111, 0b1111_1110], bw.dump()); } #[test] fn multi_byte_boundary_crossed_1_9_6() { // Writing 1 1 and then 9 zeros then 6 1s let mut bw = BitWriter::new(); - bw.write_bits(&[0b0000_0001], 1); - bw.write_bits(&[0, 0b1010_1010], 9); - bw.write_bits(&[0b0011_1111], 6); - 
assert_eq!(vec![0b1000_0000, 0b0011_1111], bw.dump()); + bw.write_bits(0b1u8, 1); + bw.write_bits(0b0_0000_0000u16, 9); + bw.write_bits(0b11_1111u8, 6); + assert_eq!(vec![0b0000_0001, 0b1111_1100], bw.dump()); } #[test] @@ -309,10 +316,24 @@ mod tests { // Write a single bit in then dump it, making sure // the correct error is returned let mut bw = BitWriter::new(); - bw.write_bits(&[0], 1); + bw.write_bits(0u8, 1); bw.dump(); } + #[test] + #[should_panic] + fn catches_dirty_upper_bits() { + let mut bw = BitWriter::new(); + bw.write_bits(10u8, 1); + } + + #[test] + fn add_multiple_aligned() { + let mut bw = BitWriter::new(); + bw.write_bits(0x00_0F_F0_FFu32, 32); + assert_eq!(vec![0xFF, 0xF0, 0x0F, 0x00], bw.dump()); + } + // #[test] // fn catches_more_than_in_buf() { // todo!(); diff --git a/src/encoding/block_header.rs b/src/encoding/block_header.rs index 27477e25..cdfa7dc4 100644 --- a/src/encoding/block_header.rs +++ b/src/encoding/block_header.rs @@ -1,5 +1,5 @@ use crate::blocks::block::BlockType; -use std::vec::Vec; +use alloc::vec::Vec; #[derive(Debug)] pub struct BlockHeader { @@ -40,7 +40,7 @@ impl BlockHeader { mod tests { use super::BlockHeader; use crate::{blocks::block::BlockType, decoding::block_decoder}; - use std::vec::Vec; + use alloc::vec::Vec; #[test] fn block_header_serialize() { diff --git a/src/encoding/blocks/compressed.rs b/src/encoding/blocks/compressed.rs new file mode 100644 index 00000000..58a13ba5 --- /dev/null +++ b/src/encoding/blocks/compressed.rs @@ -0,0 +1,236 @@ +use alloc::vec::Vec; + +use crate::{ + encoding::{bit_writer::BitWriter, match_generator::Sequence, Matcher}, + fse::fse_encoder::{default_ll_table, default_ml_table, default_of_table, FSETable, State}, + huff0::huff0_encoder, +}; + +pub fn compress_block(matcher: &mut M, output: &mut Vec) { + let mut literals_vec = Vec::new(); + let mut sequences = Vec::new(); + matcher.start_matching(|seq| { + match seq { + Sequence::Literals { literals } => 
literals_vec.extend_from_slice(literals), + Sequence::Triple { + literals, + offset, + match_len, + } => { + literals_vec.extend_from_slice(literals); + sequences.push(crate::blocks::sequence_section::Sequence { + ll: literals.len() as u32, + ml: match_len as u32, + of: (offset + 3) as u32, // TODO make use of the offset history + }); + } + } + }); + + // literals section + + let mut writer = BitWriter::from(output); + if literals_vec.len() > 1024 { + compress_literals(&literals_vec, &mut writer); + } else { + raw_literals(&literals_vec, &mut writer); + } + + // sequences section + + if sequences.is_empty() { + writer.write_bits(0u8, 8); + } else { + encode_seqnum(sequences.len(), &mut writer); + + // use standard FSE tables + writer.write_bits(0u8, 8); + + let ll_table: FSETable = default_ll_table(); + let ml_table: FSETable = default_ml_table(); + let of_table: FSETable = default_of_table(); + + let sequence = sequences[sequences.len() - 1]; + let (ll_code, ll_add_bits, ll_num_bits) = encode_literal_length(sequence.ll); + let (of_code, of_add_bits, of_num_bits) = encode_offset(sequence.of); + let (ml_code, ml_add_bits, ml_num_bits) = encode_match_len(sequence.ml); + let mut ll_state: &State = ll_table.start_state(ll_code); + let mut ml_state: &State = ml_table.start_state(ml_code); + let mut of_state: &State = of_table.start_state(of_code); + + writer.write_bits(ll_add_bits, ll_num_bits); + writer.write_bits(ml_add_bits, ml_num_bits); + writer.write_bits(of_add_bits, of_num_bits); + + // encode backwards so the decoder reads the first sequence first + if sequences.len() > 1 { + for sequence in (0..=sequences.len() - 2).rev() { + let sequence = sequences[sequence]; + let (ll_code, ll_add_bits, ll_num_bits) = encode_literal_length(sequence.ll); + let (of_code, of_add_bits, of_num_bits) = encode_offset(sequence.of); + let (ml_code, ml_add_bits, ml_num_bits) = encode_match_len(sequence.ml); + + { + let next = of_table.next_state(of_code, of_state.index); + let diff = 
of_state.index - next.baseline; + writer.write_bits(diff as u64, next.num_bits as usize); + of_state = next; + } + { + let next = ml_table.next_state(ml_code, ml_state.index); + let diff = ml_state.index - next.baseline; + writer.write_bits(diff as u64, next.num_bits as usize); + ml_state = next; + } + { + let next = ll_table.next_state(ll_code, ll_state.index); + let diff = ll_state.index - next.baseline; + writer.write_bits(diff as u64, next.num_bits as usize); + ll_state = next; + } + + writer.write_bits(ll_add_bits, ll_num_bits); + writer.write_bits(ml_add_bits, ml_num_bits); + writer.write_bits(of_add_bits, of_num_bits); + } + } + writer.write_bits(ml_state.index as u64, ml_table.table_size.ilog2() as usize); + writer.write_bits(of_state.index as u64, of_table.table_size.ilog2() as usize); + writer.write_bits(ll_state.index as u64, ll_table.table_size.ilog2() as usize); + + let bits_to_fill = writer.misaligned(); + if bits_to_fill == 0 { + writer.write_bits(1u32, 8); + } else { + writer.write_bits(1u32, bits_to_fill); + } + } + writer.flush(); +} + +fn encode_seqnum(seqnum: usize, writer: &mut BitWriter>>) { + const UPPER_LIMIT: usize = 0xFFFF + 0x7F00; + match seqnum { + 1..=127 => writer.write_bits(seqnum as u32, 8), + 128..=0x7FFF => { + let upper = ((seqnum >> 8) | 0x80) as u8; + let lower = seqnum as u8; + writer.write_bits(upper, 8); + writer.write_bits(lower, 8); + } + 0x8000..=UPPER_LIMIT => { + let encode = seqnum - 0x7F00; + let upper = (encode >> 8) as u8; + let lower = encode as u8; + writer.write_bits(255u8, 8); + writer.write_bits(upper, 8); + writer.write_bits(lower, 8); + } + _ => unreachable!(), + } +} + +fn encode_literal_length(len: u32) -> (u8, u32, usize) { + match len { + 0..=15 => (len as u8, 0, 0), + 16..=17 => (16, len - 16, 1), + 18..=19 => (17, len - 18, 1), + 20..=21 => (18, len - 20, 1), + 22..=23 => (19, len - 22, 1), + 24..=27 => (20, len - 24, 2), + 28..=31 => (21, len - 28, 2), + 32..=39 => (22, len - 32, 3), + 40..=47 => (23, 
len - 40, 3), + 48..=63 => (24, len - 48, 4), + 64..=127 => (25, len - 64, 6), + 128..=255 => (26, len - 128, 7), + 256..=511 => (27, len - 256, 8), + 512..=1023 => (28, len - 512, 9), + 1024..=2047 => (29, len - 1024, 10), + 2048..=4095 => (30, len - 2048, 11), + 4096..=8191 => (31, len - 4096, 12), + 8192..=16383 => (32, len - 8192, 13), + 16384..=32767 => (33, len - 16384, 14), + 32768..=65535 => (34, len - 32768, 15), + 65536..=131071 => (35, len - 65536, 16), + 131072.. => unreachable!(), + } +} + +fn encode_match_len(len: u32) -> (u8, u32, usize) { + match len { + 0..=2 => unreachable!(), + 3..=34 => (len as u8 - 3, 0, 0), + 35..=36 => (32, len - 35, 1), + 37..=38 => (33, len - 37, 1), + 39..=40 => (34, len - 39, 1), + 41..=42 => (35, len - 41, 1), + 43..=46 => (36, len - 43, 2), + 47..=50 => (37, len - 47, 2), + 51..=58 => (38, len - 51, 3), + 59..=66 => (39, len - 59, 3), + 67..=82 => (40, len - 67, 4), + 83..=98 => (41, len - 83, 4), + 99..=130 => (42, len - 99, 5), + 131..=258 => (43, len - 131, 7), + 259..=514 => (44, len - 259, 8), + 515..=1026 => (45, len - 515, 9), + 1027..=2050 => (46, len - 1027, 10), + 2051..=4098 => (47, len - 2051, 11), + 4099..=8194 => (48, len - 4099, 12), + 8195..=16386 => (49, len - 8195, 13), + 16387..=32770 => (50, len - 16387, 14), + 32771..=65538 => (51, len - 32771, 15), + 65539..=131074 => (52, len - 65539, 16), + 131075..
=> unreachable!(), + } +} + +fn encode_offset(len: u32) -> (u8, u32, usize) { + let log = len.ilog2(); + let lower = len & ((1 << log) - 1); + (log as u8, lower, log as usize) +} + +fn raw_literals(literals: &[u8], writer: &mut BitWriter<&mut Vec>) { + writer.write_bits(0u8, 2); + writer.write_bits(0b11u8, 2); + writer.write_bits(literals.len() as u32, 20); + writer.append_bytes(literals); +} + +fn compress_literals(literals: &[u8], writer: &mut BitWriter<&mut Vec>) { + let reset_idx = writer.index(); + writer.write_bits(2u8, 2); // compressed literals type + + let encoder_table = huff0_encoder::HuffmanTable::build_from_data(literals); + + let (size_format, size_bits) = match literals.len() { + 0..6 => (0b00u8, 10), + 6..1024 => (0b01, 10), + 1024..16384 => (0b10, 14), + 16384..262144 => (0b11, 18), + _ => unimplemented!("too many literals"), + }; + + writer.write_bits(size_format, 2); + writer.write_bits(literals.len() as u32, size_bits); + let size_index = writer.index(); + writer.write_bits(0u32, size_bits); + let index_before = writer.index(); + let mut encoder = huff0_encoder::HuffmanEncoder::new(encoder_table, writer); + if size_format == 0 { + encoder.encode(literals) + } else { + encoder.encode4x(literals) + }; + let encoded_len = (writer.index() - index_before) / 8; + writer.change_bits(size_index, encoded_len as u64, size_bits); + let total_len = (writer.index() - reset_idx) / 8; + + // If encoded len is bigger than the raw literals we are better off just writing the raw literals here + if total_len >= literals.len() { + writer.reset_to(reset_idx); + raw_literals(literals, writer); + } +} diff --git a/src/encoding/blocks/mod.rs b/src/encoding/blocks/mod.rs index 50f31f06..671ff6fb 100644 --- a/src/encoding/blocks/mod.rs +++ b/src/encoding/blocks/mod.rs @@ -3,6 +3,6 @@ //! //! There are a few different kinds of blocks, and implementations for those kinds are //! in this module. 
-mod raw; +mod compressed; -pub(super) use raw::*; +pub(super) use compressed::*; diff --git a/src/encoding/blocks/raw.rs b/src/encoding/blocks/raw.rs deleted file mode 100644 index 6ed340ea..00000000 --- a/src/encoding/blocks/raw.rs +++ /dev/null @@ -1,18 +0,0 @@ -use std::vec::Vec; - -/// Write the data from input into output. The data is not compressed. -pub(crate) fn compress_raw_block(input: &[u8], output: &mut Vec) { - output.extend_from_slice(input); -} - -#[cfg(test)] -mod tests { - use super::compress_raw_block; - use std::{vec, vec::Vec}; - #[test] - fn raw_block_compressed() { - let mut output: Vec = Vec::new(); - compress_raw_block(&[1, 2, 3], &mut output); - assert_eq!(vec![1_u8, 2, 3], output); - } -} diff --git a/src/encoding/frame_encoder.rs b/src/encoding/frame_encoder.rs index a9899549..deb275b0 100644 --- a/src/encoding/frame_encoder.rs +++ b/src/encoding/frame_encoder.rs @@ -1,12 +1,17 @@ //! Utilities and interfaces for encoding an entire frame. +use alloc::vec::Vec; use core::convert::TryInto; -use std::vec::Vec; -use super::{block_header::BlockHeader, blocks::compress_raw_block, frame_header::FrameHeader}; +use super::{ + block_header::BlockHeader, blocks::compress_block, frame_header::FrameHeader, + match_generator::MatchGeneratorDriver, Matcher, +}; + +use crate::io::{Read, Write}; /// Blocks cannot be larger than 128KB in size. -const MAX_BLOCK_SIZE: usize = 128000; +const MAX_BLOCK_SIZE: usize = 128 * 1024 - 20; /// The compression mode used impacts the speed of compression, /// and resulting compression ratios. Faster compression will result @@ -16,8 +21,6 @@ pub enum CompressionLevel { /// it in a Zstandard frame. 
Uncompressed, /// This level is roughly equivalent to Zstd compression level 1 - /// - /// UNIMPLEMENTED Fastest, /// This level is roughly equivalent to Zstd level 3, /// or the one used by the official compressor when no level @@ -43,80 +46,140 @@ pub enum CompressionLevel { /// # Examples /// ``` /// use ruzstd::encoding::{FrameCompressor, CompressionLevel}; -/// let mock_data = &[0x1, 0x2, 0x3, 0x4]; +/// let mock_data: &[_] = &[0x1, 0x2, 0x3, 0x4]; +/// let mut output = std::vec::Vec::new(); /// // Initialize a compressor. -/// let compressor = FrameCompressor::new(mock_data, CompressionLevel::Uncompressed); +/// let mut compressor = FrameCompressor::new(mock_data, &mut output, CompressionLevel::Uncompressed); /// -/// let mut output = Vec::new(); /// // `compress` writes the compressed output into the provided buffer. -/// compressor.compress(&mut output); +/// compressor.compress(); /// ``` -pub struct FrameCompressor<'input> { - uncompressed_data: &'input [u8], +pub struct FrameCompressor { + uncompressed_data: R, + compressed_data: W, compression_level: CompressionLevel, } -impl<'input> FrameCompressor<'input> { +impl FrameCompressor { /// Create a new `FrameCompressor` from the provided slice, but don't start compression yet. 
pub fn new( - uncompressed_data: &'input [u8], + uncompressed_data: R, + compressed_data: W, compression_level: CompressionLevel, - ) -> FrameCompressor<'input> { + ) -> FrameCompressor { Self { uncompressed_data, + compressed_data, compression_level, } } /// Compress the uncompressed data into a valid Zstd frame and write it into the provided buffer - pub fn compress(&self, output: &mut Vec) { + pub fn compress(&mut self) { + let mut output = Vec::with_capacity(1024 * 130); + let output = &mut output; let header = FrameHeader { - frame_content_size: Some(self.uncompressed_data.len().try_into().unwrap()), - single_segment: true, + frame_content_size: None, + single_segment: false, content_checksum: false, dictionary_id: None, - window_size: None, + window_size: Some(256 * 1024), }; header.serialize(output); - // Special handling is needed for compression of a totally empty file (why you'd want to do that, I don't know) - if self.uncompressed_data.is_empty() { - let header = BlockHeader { - last_block: true, - block_type: crate::blocks::block::BlockType::Raw, - block_size: 0, - }; - // Write the header, then the block - header.serialize(output); - } - match self.compression_level { - CompressionLevel::Uncompressed => { - // Blocks are compressed by writing a header, then writing - // the block in repetition until the last block is reached. 
- let mut index = 0; - while index < self.uncompressed_data.len() { - let last_block = index + MAX_BLOCK_SIZE > self.uncompressed_data.len(); - // We read till the end of the data, or till the max block size, whichever comes sooner - let block_size = if last_block { - self.uncompressed_data.len() - index - } else { - MAX_BLOCK_SIZE - }; + + let mut matcher = MatchGeneratorDriver::new(1024 * 128, 1); + loop { + let mut uncompressed_data = matcher.get_next_space(); + let mut read_bytes = 0; + let last_block; + 'read_loop: loop { + let new_bytes = self + .uncompressed_data + .read(&mut uncompressed_data[read_bytes..]) + .unwrap(); + if new_bytes == 0 { + last_block = true; + break 'read_loop; + } + read_bytes += new_bytes; + if read_bytes == uncompressed_data.len() { + last_block = false; + break 'read_loop; + } + } + uncompressed_data.resize(read_bytes, 0); + + // Special handling is needed for compression of a totally empty file (why you'd want to do that, I don't know) + if uncompressed_data.is_empty() { + let header = BlockHeader { + last_block: true, + block_type: crate::blocks::block::BlockType::Raw, + block_size: 0, + }; + // Write the header, then the block + header.serialize(output); + self.compressed_data.write_all(output).unwrap(); + output.clear(); + break; + } + + match self.compression_level { + CompressionLevel::Uncompressed => { let header = BlockHeader { last_block, block_type: crate::blocks::block::BlockType::Raw, - block_size: block_size.try_into().unwrap(), + block_size: read_bytes.try_into().unwrap(), }; // Write the header, then the block header.serialize(output); - compress_raw_block( - &self.uncompressed_data[index..(index + block_size)], - output, - ); - index += block_size; + output.extend_from_slice(&uncompressed_data); + } + CompressionLevel::Fastest => { + if uncompressed_data.iter().all(|x| uncompressed_data[0].eq(x)) { + let rle_byte = uncompressed_data[0]; + matcher.commit_space(uncompressed_data); + matcher.skip_matching(); + let 
header = BlockHeader { + last_block, + block_type: crate::blocks::block::BlockType::RLE, + block_size: read_bytes.try_into().unwrap(), + }; + // Write the header, then the block + header.serialize(output); + output.push(rle_byte); + } else { + let mut compressed = Vec::new(); + matcher.commit_space(uncompressed_data); + compress_block(&mut matcher, &mut compressed); + if compressed.len() >= MAX_BLOCK_SIZE { + let header = BlockHeader { + last_block, + block_type: crate::blocks::block::BlockType::Raw, + block_size: read_bytes.try_into().unwrap(), + }; + // Write the header, then the block + header.serialize(output); + output.extend_from_slice(matcher.get_last_space()); + } else { + let header = BlockHeader { + last_block, + block_type: crate::blocks::block::BlockType::Compressed, + block_size: (compressed.len()).try_into().unwrap(), + }; + // Write the header, then the block + header.serialize(output); + output.extend(compressed); + } + } + } + _ => { + unimplemented!(); } } - _ => { - unimplemented!(); + self.compressed_data.write_all(output).unwrap(); + output.clear(); + if last_block { + break; } } } @@ -124,24 +187,201 @@ impl<'input> FrameCompressor<'input> { #[cfg(test)] mod tests { + use alloc::vec; + use super::FrameCompressor; - use crate::frame::MAGIC_NUM; - use std::vec::Vec; + use crate::{frame::MAGIC_NUM, FrameDecoder}; + use alloc::vec::Vec; #[test] fn frame_starts_with_magic_num() { - let mock_data = &[1_u8, 2, 3]; - let compressor = FrameCompressor::new(mock_data, super::CompressionLevel::Uncompressed); + let mock_data = [1_u8, 2, 3].as_slice(); let mut output: Vec = Vec::new(); - compressor.compress(&mut output); + let mut compressor = FrameCompressor::new( + mock_data, + &mut output, + super::CompressionLevel::Uncompressed, + ); + compressor.compress(); assert!(output.starts_with(&MAGIC_NUM.to_le_bytes())); } #[test] fn very_simple_raw_compress() { - let mock_data = &[1_u8, 2, 3]; - let compressor = FrameCompressor::new(mock_data, 
super::CompressionLevel::Uncompressed); + let mock_data = [1_u8, 2, 3].as_slice(); + let mut output: Vec = Vec::new(); + let mut compressor = FrameCompressor::new( + mock_data, + &mut output, + super::CompressionLevel::Uncompressed, + ); + compressor.compress(); + } + + #[test] + fn very_simple_compress() { + let mut mock_data = vec![0; 1 << 17]; + mock_data.extend(vec![1; (1 << 17) - 1]); + mock_data.extend(vec![2; (1 << 18) - 1]); + mock_data.extend(vec![2; 1 << 17]); + mock_data.extend(vec![3; (1 << 17) - 1]); + let mut output: Vec = Vec::new(); + let mut compressor = FrameCompressor::new( + mock_data.as_slice(), + &mut output, + super::CompressionLevel::Uncompressed, + ); + compressor.compress(); + + let mut decoder = FrameDecoder::new(); + let mut decoded = Vec::with_capacity(mock_data.len()); + decoder.decode_all_to_vec(&output, &mut decoded).unwrap(); + assert_eq!(mock_data, decoded); + + let mut decoded = Vec::new(); + zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap(); + assert_eq!(mock_data, decoded); + } + + #[test] + fn rle_compress() { + let mock_data = vec![0; 1 << 19]; + let mut output: Vec = Vec::new(); + let mut compressor = FrameCompressor::new( + mock_data.as_slice(), + &mut output, + super::CompressionLevel::Uncompressed, + ); + compressor.compress(); + + let mut decoder = FrameDecoder::new(); + let mut decoded = Vec::with_capacity(mock_data.len()); + decoder.decode_all_to_vec(&output, &mut decoded).unwrap(); + assert_eq!(mock_data, decoded); + } + + #[test] + fn aaa_compress() { + let mock_data = vec![0, 1, 3, 4, 5]; let mut output: Vec = Vec::new(); - compressor.compress(&mut output); + let mut compressor = FrameCompressor::new( + mock_data.as_slice(), + &mut output, + super::CompressionLevel::Uncompressed, + ); + compressor.compress(); + + let mut decoder = FrameDecoder::new(); + let mut decoded = Vec::with_capacity(mock_data.len()); + decoder.decode_all_to_vec(&output, &mut decoded).unwrap(); + assert_eq!(mock_data, 
decoded); + + let mut decoded = Vec::new(); + zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap(); + assert_eq!(mock_data, decoded); + } + + #[cfg(feature = "std")] + #[test] + fn fuzz_targets() { + use std::io::Read; + fn decode_ruzstd(data: &mut dyn std::io::Read) -> Vec { + let mut decoder = crate::StreamingDecoder::new(data).unwrap(); + let mut result: Vec = Vec::new(); + decoder.read_to_end(&mut result).expect("Decoding failed"); + result + } + + fn decode_ruzstd_writer(mut data: impl Read) -> Vec { + let mut decoder = crate::FrameDecoder::new(); + decoder.reset(&mut data).unwrap(); + let mut result = vec![]; + while !decoder.is_finished() || decoder.can_collect() > 0 { + decoder + .decode_blocks( + &mut data, + crate::BlockDecodingStrategy::UptoBytes(1024 * 1024), + ) + .unwrap(); + decoder.collect_to_writer(&mut result).unwrap(); + } + result + } + + fn encode_zstd(data: &[u8]) -> Result, std::io::Error> { + zstd::stream::encode_all(std::io::Cursor::new(data), 3) + } + + fn encode_ruzstd_uncompressed(data: &mut dyn std::io::Read) -> Vec { + let mut input = Vec::new(); + data.read_to_end(&mut input).unwrap(); + let mut output = Vec::new(); + + let mut compressor = crate::encoding::FrameCompressor::new( + input.as_slice(), + &mut output, + crate::encoding::CompressionLevel::Uncompressed, + ); + compressor.compress(); + output + } + + fn encode_ruzstd_compressed(data: &mut dyn std::io::Read) -> Vec { + let mut input = Vec::new(); + data.read_to_end(&mut input).unwrap(); + let mut output = Vec::new(); + + let mut compressor = crate::encoding::FrameCompressor::new( + input.as_slice(), + &mut output, + crate::encoding::CompressionLevel::Fastest, + ); + compressor.compress(); + output + } + + fn decode_zstd(data: &[u8]) -> Result, std::io::Error> { + let mut output = Vec::new(); + zstd::stream::copy_decode(data, &mut output)?; + Ok(output) + } + if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) { + for file in
std::fs::read_dir("fuzz/artifacts/interop").unwrap() { + if file.as_ref().unwrap().file_type().unwrap().is_file() { + let data = std::fs::read(file.unwrap().path()).unwrap(); + let data = data.as_slice(); + // Decoding + let compressed = encode_zstd(data).unwrap(); + let decoded = decode_ruzstd(&mut compressed.as_slice()); + let decoded2 = decode_ruzstd_writer(&mut compressed.as_slice()); + assert!( + decoded == data, + "Decoded data did not match the original input during decompression" + ); + assert_eq!( + decoded2, data, + "Decoded data did not match the original input during decompression" + ); + + // Encoding + // Uncompressed encoding + let mut input = data; + let compressed = encode_ruzstd_uncompressed(&mut input); + let decoded = decode_zstd(&compressed).unwrap(); + assert_eq!( + decoded, data, + "Decoded data did not match the original input during compression" + ); + // Compressed encoding + let mut input = data; + let compressed = encode_ruzstd_compressed(&mut input); + let decoded = decode_zstd(&compressed).unwrap(); + assert_eq!( + decoded, data, + "Decoded data did not match the original input during compression" + ); + } + } + } } } diff --git a/src/encoding/frame_header.rs b/src/encoding/frame_header.rs index 2412aa1a..beadf2b1 100644 --- a/src/encoding/frame_header.rs +++ b/src/encoding/frame_header.rs @@ -4,7 +4,7 @@ use crate::encoding::{ util::{find_min_size, minify_val}, }; use crate::frame; -use std::vec::Vec; +use alloc::vec::Vec; /// A header for a single Zstandard frame. 
/// @@ -44,11 +44,10 @@ impl FrameHeader { // `Window_Descriptor // TODO: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor - // if !self.single_segment { - // unimplemented!( - // "Support for using window size over frame content size is not implemented" - // ); - // } + if !self.single_segment { + let exponent = 7; + output.push(exponent << 3); + } if let Some(id) = self.dictionary_id { output.extend(minify_val(id)); @@ -80,65 +79,66 @@ impl FrameHeader { // | 1 | 2 // | 2 | 4 // | 3 | 8 - if let Some(frame_content_size) = self.frame_content_size { - let field_size = find_min_size(frame_content_size); - let flag_value: u8 = match field_size { - 1 => 0, - 2 => 1, - 4 => 2, - 3 => 8, + + // `Dictionary_ID_flag`: + if let Some(id) = self.dictionary_id { + let flag_value: u8 = match find_min_size(id) { + 0 => 0, + 1 => 1, + 2 => 2, + 4 => 3, _ => panic!(), }; + bw.write_bits(flag_value, 2); + } else { + // A `Dictionary_ID` was not provided + bw.write_bits(0u8, 2); + } - bw.write_bits(&[flag_value], 2); + // `Content_Checksum_flag`: + if self.content_checksum { + bw.write_bits(1u8, 1); } else { - // `Frame_Content_Size` was not provided - bw.write_bits(&[0], 2); + bw.write_bits(0u8, 1); } + // `Reserved_bit`: + // This value must be zero + bw.write_bits(0u8, 1); + + // `Unused_bit`: + // An encoder compliant with this spec must set this bit to zero + bw.write_bits(0u8, 1); + // `Single_Segment_flag`: // If this flag is set, data must be regenerated within a single continuous memory segment, // and the `Frame_Content_Size` field must be present in the header. // If this flag is not set, the `Window_Descriptor` field must be present in the frame header. 
if self.single_segment { assert!(self.frame_content_size.is_some(), "if the `single_segment` flag is set to true, then a frame content size must be provided"); - bw.write_bits(&[1], 1); + bw.write_bits(1u8, 1); } else { assert!( self.window_size.is_some(), "if the `single_segment` flag is set to false, then a window size must be provided" ); - bw.write_bits(&[0], 1); + bw.write_bits(0u8, 1); } - // `Unused_bit`: - // An encoder compliant with this spec must set this bit to zero - bw.write_bits(&[0], 1); - - // `Reserved_bit`: - // This value must be zero - bw.write_bits(&[0], 1); - - // `Content_Checksum_flag`: - if self.content_checksum { - bw.write_bits(&[1], 1); - } else { - bw.write_bits(&[0], 1); - } - - // `Dictionary_ID_flag`: - if let Some(id) = self.dictionary_id { - let flag_value: u8 = match find_min_size(id) { - 0 => 0, - 1 => 1, - 2 => 2, - 4 => 3, + if let Some(frame_content_size) = self.frame_content_size { + let field_size = find_min_size(frame_content_size); + let flag_value: u8 = match field_size { + 1 => 0, + 2 => 1, + 4 => 2, + 8 => 3, _ => panic!(), }; - bw.write_bits(&[flag_value], 2); + + bw.write_bits(flag_value, 2); } else { - // A `Dictionary_ID` was not provided - bw.write_bits(&[0], 2); + // `Frame_Content_Size` was not provided + bw.write_bits(0u8, 2); } bw.dump()[0] @@ -163,7 +163,7 @@ fn minify_val_fcs(val: u64) -> Vec { mod tests { use super::FrameHeader; use crate::frame::{read_frame_header, FrameDescriptor}; - use std::vec::Vec; + use alloc::vec::Vec; #[test] fn frame_header_descriptor_decode() { diff --git a/src/encoding/match_generator.rs b/src/encoding/match_generator.rs new file mode 100644 index 00000000..63e72624 --- /dev/null +++ b/src/encoding/match_generator.rs @@ -0,0 +1,581 @@ +use alloc::vec::Vec; +use core::num::NonZeroUsize; + +use super::Matcher; + +const MIN_MATCH_LEN: usize = 5; + +/// Takes care of allocating and reusing vecs +pub(crate) struct MatchGeneratorDriver { + vec_pool: Vec>, + suffix_pool: Vec,
match_generator: MatchGenerator, + slice_size: usize, +} + +impl MatchGeneratorDriver { + /// slice_size says how big the slices should be that are allocated to work with + /// max_slices_in_window says how many slices should at most be used while looking for matches + pub(crate) fn new(slice_size: usize, max_slices_in_window: usize) -> Self { + Self { + vec_pool: Vec::new(), + suffix_pool: Vec::new(), + match_generator: MatchGenerator::new(max_slices_in_window * slice_size), + slice_size, + } + } +} + +impl Matcher for MatchGeneratorDriver { + fn get_next_space(&mut self) -> Vec { + self.vec_pool.pop().unwrap_or_else(|| { + let mut space = alloc::vec![0; self.slice_size]; + space.resize(space.capacity(), 0); + space + }) + } + + fn get_last_space(&mut self) -> &[u8] { + self.match_generator.window.last().unwrap().data.as_slice() + } + + fn commit_space(&mut self, space: Vec) { + let vec_pool = &mut self.vec_pool; + let suffixes = self + .suffix_pool + .pop() + .unwrap_or_else(|| SuffixStore::with_capacity(space.len())); + let suffix_pool = &mut self.suffix_pool; + self.match_generator + .add_data(space, suffixes, |mut data, mut suffixes| { + data.resize(data.capacity(), 0); + vec_pool.push(data); + suffixes.slots.clear(); + suffixes.slots.resize(suffixes.slots.capacity(), None); + suffix_pool.push(suffixes); + }); + } + + fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) { + while self.match_generator.next_sequence(&mut handle_sequence) {} + } + fn skip_matching(&mut self) { + self.match_generator.skip_matching(); + } +} + +/// This stores the index of a suffix of a string by hashing the first few bytes of that suffix +/// This means that collisions just overwrite and that you need to check validity after a get +struct SuffixStore { + // We use NonZeroUsize to enable niche optimization here. 
+ // On store we do +1 and on get -1 + // This is ok since usize::MAX is never a valid offset + slots: Vec>, + len_log: u32, +} + +impl SuffixStore { + fn with_capacity(capacity: usize) -> Self { + Self { + slots: alloc::vec![None; capacity], + len_log: capacity.ilog2(), + } + } + + #[inline(always)] + fn insert(&mut self, suffix: &[u8], idx: usize) { + let key = self.key(suffix); + self.slots[key] = Some(NonZeroUsize::new(idx + 1).unwrap()); + } + + #[inline(always)] + fn contains_key(&self, suffix: &[u8]) -> bool { + let key = self.key(suffix); + self.slots[key].is_some() + } + + #[inline(always)] + fn get(&self, suffix: &[u8]) -> Option { + let key = self.key(suffix); + self.slots[key].map(|x| >::into(x) - 1) + } + + #[inline(always)] + fn key(&self, suffix: &[u8]) -> usize { + let s0 = suffix[0] as u64; + let s1 = suffix[1] as u64; + let s2 = suffix[2] as u64; + let s3 = suffix[3] as u64; + let s4 = suffix[4] as u64; + + const POLY: u64 = 0xCF3BCCDCABu64; + + let s0 = (s0 << 24).wrapping_mul(POLY); + let s1 = (s1 << 32).wrapping_mul(POLY); + let s2 = (s2 << 40).wrapping_mul(POLY); + let s3 = (s3 << 48).wrapping_mul(POLY); + let s4 = (s4 << 56).wrapping_mul(POLY); + + let index = s0 ^ s1 ^ s2 ^ s3 ^ s4; + let index = index >> (64 - self.len_log); + index as usize % self.slots.len() + } +} + +/// We keep a window of a few of these entries +/// All of these are valid targets for a match to be generated for +struct WindowEntry { + data: Vec, + /// Stores indexes into data + suffixes: SuffixStore, + /// Makes offset calculations efficient + base_offset: usize, +} + +pub(crate) struct MatchGenerator { + max_window_size: usize, + /// Data window we are operating on to find matches + /// The data we want to find matches for is in the last slice + window: Vec, + window_size: usize, + #[cfg(debug_assertions)] + concat_window: Vec, + /// Index in the last slice that we already processed + suffix_idx: usize, + /// Gets updated when a new sequence is returned to point right 
behind that sequence + last_idx_in_sequence: usize, +} + +#[derive(PartialEq, Eq, Debug)] +pub(crate) enum Sequence<'data> { + Triple { + literals: &'data [u8], + offset: usize, + match_len: usize, + }, + Literals { + literals: &'data [u8], + }, +} + +impl MatchGenerator { + /// max_size defines how many bytes will be used at most in the window used for matching + fn new(max_size: usize) -> Self { + Self { + max_window_size: max_size, + window: Vec::new(), + window_size: 0, + #[cfg(debug_assertions)] + concat_window: Vec::new(), + suffix_idx: 0, + last_idx_in_sequence: 0, + } + } + + /// Processes bytes in the current window until either a match is found or no more matches can be found + /// * If a match is found handle_sequence is called with the Triple variant + /// * If no more matches can be found but there are bytes still left handle_sequence is called with the Literals variant + /// * If no more matches can be found and no more bytes are left this returns false + fn next_sequence(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) -> bool { + loop { + let last_entry = self.window.last().unwrap(); + let data_slice = &last_entry.data; + + // We already reached the end of the window, check if we need to return a Literals{} + if self.suffix_idx >= data_slice.len() { + if self.last_idx_in_sequence != self.suffix_idx { + let literals = &data_slice[self.last_idx_in_sequence..]; + self.last_idx_in_sequence = self.suffix_idx; + handle_sequence(Sequence::Literals { literals }); + return true; + } else { + return false; + } + } + + // If the remaining data is smaller than the minimum match length we can stop and return a Literals{} + let data_slice = &data_slice[self.suffix_idx..]; + if data_slice.len() < MIN_MATCH_LEN { + let last_idx_in_sequence = self.last_idx_in_sequence; + self.last_idx_in_sequence = last_entry.data.len(); + self.suffix_idx = last_entry.data.len(); + handle_sequence(Sequence::Literals { + literals: 
&last_entry.data[last_idx_in_sequence..], + }); + return true; + } + + // This is the key we are looking to find a match for + let key = &data_slice[..MIN_MATCH_LEN]; + + // Look in each window entry + for (match_entry_idx, match_entry) in self.window.iter().enumerate() { + let is_last = match_entry_idx == self.window.len() - 1; + if let Some(match_index) = match_entry.suffixes.get(key) { + let match_slice = if is_last { + &match_entry.data[match_index..self.suffix_idx] + } else { + &match_entry.data[match_index..] + }; + + // Check how long the common prefix actually is + let match_len = Self::common_prefix_len(match_slice, data_slice); + + // Collisions in the suffix store might make this check fail + if match_len >= MIN_MATCH_LEN { + let offset = match_entry.base_offset + self.suffix_idx - match_index; + + // If we are in debug/tests make sure the match we found is actually at the offset we calculated + #[cfg(debug_assertions)] + { + let unprocessed = last_entry.data.len() - self.suffix_idx; + let start = self.concat_window.len() - unprocessed - offset; + let end = start + match_len; + let check_slice = &self.concat_window[start..end]; + debug_assert_eq!(check_slice, &match_slice[..match_len]); + } + + // For each index in the match we found we do not need to look for another match + // But we still want them registered in the suffix store + self.add_suffixes_till(self.suffix_idx + match_len); + + // All literals that were not included between this match and the last are now included here + let last_entry = self.window.last().unwrap(); + let literals = &last_entry.data[self.last_idx_in_sequence..self.suffix_idx]; + + // Update the indexes, all indexes upto and including the current index have been included in a sequence now + self.suffix_idx += match_len; + self.last_idx_in_sequence = self.suffix_idx; + handle_sequence(Sequence::Triple { + literals, + offset, + match_len, + }); + + return true; + } + } + } + + let last_entry = self.window.last_mut().unwrap(); + 
let key = &last_entry.data[self.suffix_idx..self.suffix_idx + MIN_MATCH_LEN]; + if !last_entry.suffixes.contains_key(key) { + last_entry.suffixes.insert(key, self.suffix_idx); + } + self.suffix_idx += 1; + } + } + + /// Find the common prefix length between two byte slices + #[inline(always)] + fn common_prefix_len(a: &[u8], b: &[u8]) -> usize { + Self::mismatch_chunks::<8>(a, b) + } + + /// Find the common prefix length between two byte slices with a configurable chunk length + /// This enables vectorization optimizations + fn mismatch_chunks(xs: &[u8], ys: &[u8]) -> usize { + let off = core::iter::zip(xs.chunks_exact(N), ys.chunks_exact(N)) + .take_while(|(x, y)| x == y) + .count() + * N; + off + core::iter::zip(&xs[off..], &ys[off..]) + .take_while(|(x, y)| x == y) + .count() + } + + /// Process bytes and add the suffixes to the suffix store up to a specific index + #[inline(always)] + fn add_suffixes_till(&mut self, idx: usize) { + let last_entry = self.window.last_mut().unwrap(); + if last_entry.data.len() < MIN_MATCH_LEN { + return; + } + let slice = &last_entry.data[self.suffix_idx..idx]; + for (key_index, key) in slice.windows(MIN_MATCH_LEN).enumerate() { + if !last_entry.suffixes.contains_key(key) { + last_entry.suffixes.insert(key, self.suffix_idx + key_index); + } + } + } + + /// Skip matching for the whole current window entry + fn skip_matching(&mut self) { + let len = self.window.last().unwrap().data.len(); + self.add_suffixes_till(len); + self.suffix_idx = len; + self.last_idx_in_sequence = len; + } + + /// Add a new window entry. Will panic if the last window entry hasn't been processed properly. 
+ /// If any resources are released by pushing the new entry they are returned via the callback + fn add_data( + &mut self, + data: Vec, + suffixes: SuffixStore, + reuse_space: impl FnMut(Vec, SuffixStore), + ) { + assert!( + self.window.is_empty() || self.suffix_idx == self.window.last().unwrap().data.len() + ); + self.reserve(data.len(), reuse_space); + #[cfg(debug_assertions)] + self.concat_window.extend_from_slice(&data); + + if let Some(last_len) = self.window.last().map(|last| last.data.len()) { + for entry in self.window.iter_mut() { + entry.base_offset += last_len; + } + } + + let len = data.len(); + self.window.push(WindowEntry { + data, + suffixes, + base_offset: 0, + }); + self.window_size += len; + self.suffix_idx = 0; + self.last_idx_in_sequence = 0; + } + + /// Reserve space for a new window entry + /// If any resources are released by pushing the new entry they are returned via the callback + fn reserve(&mut self, amount: usize, mut reuse_space: impl FnMut(Vec, SuffixStore)) { + assert!(self.max_window_size >= amount); + while self.window_size + amount > self.max_window_size { + let removed = self.window.remove(0); + self.window_size -= removed.data.len(); + #[cfg(debug_assertions)] + self.concat_window.drain(0..removed.data.len()); + + let WindowEntry { + suffixes, + data: leaked_vec, + base_offset: _, + } = removed; + reuse_space(leaked_vec, suffixes); + } + } +} + +#[test] +fn matches() { + let mut matcher = MatchGenerator::new(1000); + let mut original_data = Vec::new(); + let mut reconstructed = Vec::new(); + + let assert_seq_equal = |seq1: Sequence<'_>, seq2: Sequence<'_>, reconstructed: &mut Vec| { + assert_eq!(seq1, seq2); + match seq2 { + Sequence::Literals { literals } => reconstructed.extend_from_slice(literals), + Sequence::Triple { + literals, + offset, + match_len, + } => { + reconstructed.extend_from_slice(literals); + let start = reconstructed.len() - offset; + let end = start + match_len; + 
reconstructed.extend_from_within(start..end); + } + } + }; + + matcher.add_data( + alloc::vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + SuffixStore::with_capacity(100), + |_, _| {}, + ); + original_data.extend_from_slice(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[0, 0, 0, 0, 0], + offset: 5, + match_len: 5, + }, + &mut reconstructed, + ) + }); + + assert!(!matcher.next_sequence(|_| {})); + + matcher.add_data( + alloc::vec![1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0,], + SuffixStore::with_capacity(100), + |_, _| {}, + ); + original_data.extend_from_slice(&[ + 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, + ]); + + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[1, 2, 3, 4, 5, 6], + offset: 6, + match_len: 6, + }, + &mut reconstructed, + ) + }); + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[], + offset: 12, + match_len: 6, + }, + &mut reconstructed, + ) + }); + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[], + offset: 28, + match_len: 5, + }, + &mut reconstructed, + ) + }); + assert!(!matcher.next_sequence(|_| {})); + + matcher.add_data( + alloc::vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0, 0, 0, 0], + SuffixStore::with_capacity(100), + |_, _| {}, + ); + original_data.extend_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0, 0, 0, 0]); + + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[], + offset: 23, + match_len: 6, + }, + &mut reconstructed, + ) + }); + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[7, 8, 9, 10, 11], + offset: 44, + match_len: 5, + }, + &mut reconstructed, + ) + }); + assert!(!matcher.next_sequence(|_| {})); + + matcher.add_data( + alloc::vec![0, 0, 0, 0, 0], + SuffixStore::with_capacity(100), 
+ |_, _| {}, + ); + original_data.extend_from_slice(&[0, 0, 0, 0, 0]); + + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[], + offset: 49, + match_len: 5, + }, + &mut reconstructed, + ) + }); + assert!(!matcher.next_sequence(|_| {})); + + matcher.add_data( + alloc::vec![7, 8, 9, 10, 11], + SuffixStore::with_capacity(100), + |_, _| {}, + ); + original_data.extend_from_slice(&[7, 8, 9, 10, 11]); + + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[], + offset: 15, + match_len: 5, + }, + &mut reconstructed, + ) + }); + assert!(!matcher.next_sequence(|_| {})); + + matcher.add_data( + alloc::vec![1, 3, 5, 7, 9], + SuffixStore::with_capacity(100), + |_, _| {}, + ); + matcher.skip_matching(); + original_data.extend_from_slice(&[1, 3, 5, 7, 9]); + reconstructed.extend_from_slice(&[1, 3, 5, 7, 9]); + assert!(!matcher.next_sequence(|_| {})); + + matcher.add_data( + alloc::vec![1, 3, 5, 7, 9], + SuffixStore::with_capacity(100), + |_, _| {}, + ); + original_data.extend_from_slice(&[1, 3, 5, 7, 9]); + + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[], + offset: 5, + match_len: 5, + }, + &mut reconstructed, + ) + }); + assert!(!matcher.next_sequence(|_| {})); + + matcher.add_data( + alloc::vec![0, 0, 11, 13, 15, 17, 20, 11, 13, 15, 17, 20, 21, 23], + SuffixStore::with_capacity(100), + |_, _| {}, + ); + original_data.extend_from_slice(&[0, 0, 11, 13, 15, 17, 20, 11, 13, 15, 17, 20, 21, 23]); + + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Triple { + literals: &[0, 0, 11, 13, 15, 17, 20], + offset: 5, + match_len: 5, + }, + &mut reconstructed, + ) + }); + matcher.next_sequence(|seq| { + assert_seq_equal( + seq, + Sequence::Literals { + literals: &[21, 23], + }, + &mut reconstructed, + ) + }); + assert!(!matcher.next_sequence(|_| {})); + + assert_eq!(reconstructed, original_data); +} diff --git a/src/encoding/mod.rs 
b/src/encoding/mod.rs index 0b53f871..357776ca 100644 --- a/src/encoding/mod.rs +++ b/src/encoding/mod.rs @@ -4,6 +4,48 @@ pub(crate) mod bit_writer; pub(crate) mod block_header; pub(crate) mod blocks; mod frame_encoder; -pub use frame_encoder::*; pub(crate) mod frame_header; +pub(crate) mod match_generator; pub(crate) mod util; + +use crate::io::{Read, Write}; +use alloc::vec::Vec; +pub use frame_encoder::*; +use match_generator::Sequence; + +/// Convenience function to compress some source into a target without reusing any resources of the compressor +/// ```rust +/// use ruzstd::encoding::{compress, CompressionLevel}; +/// let data: &[u8] = &[0,0,0,0,0,0,0,0,0,0,0,0]; +/// let mut target = Vec::new(); +/// compress(data, &mut target, CompressionLevel::Fastest); +/// ``` +pub fn compress(source: R, target: W, level: CompressionLevel) { + let mut frame_enc = FrameCompressor::new(source, target, level); + frame_enc.compress(); +} + +/// Convenience function to compress some source into a target without reusing any resources of the compressor into a Vec +/// ```rust +/// use ruzstd::encoding::{compress_to_vec, CompressionLevel}; +/// let data: &[u8] = &[0,0,0,0,0,0,0,0,0,0,0,0]; +/// let compressed = compress_to_vec(data, CompressionLevel::Fastest); +/// ``` +pub fn compress_to_vec(source: R, level: CompressionLevel) -> Vec { + let mut vec = Vec::new(); + compress(source, &mut vec, level); + vec +} + +pub(crate) trait Matcher { + /// Get a space where we can put data to be matched on + fn get_next_space(&mut self) -> alloc::vec::Vec; + /// Get a reference to the last committed space + fn get_last_space(&mut self) -> &[u8]; + /// Commit a space to the matcher so it can be matched against + fn commit_space(&mut self, space: alloc::vec::Vec); + /// Just process the data in the last committed space for future matching + fn skip_matching(&mut self); + /// Process the data in the last committed space for future matching AND generate matches for the data + fn 
start_matching(&mut self, handle_sequence: impl for<'a> FnMut(Sequence<'a>)); +} diff --git a/src/encoding/util.rs b/src/encoding/util.rs index 1a1103f1..920b1a4c 100644 --- a/src/encoding/util.rs +++ b/src/encoding/util.rs @@ -1,4 +1,4 @@ -use std::vec::Vec; +use alloc::vec::Vec; /// Returns the minimum number of bytes needed to represent this value, as /// either 1, 2, 4, or 8 bytes. A value of 0 will still return one byte. @@ -33,7 +33,7 @@ pub fn minify_val(val: u64) -> Vec { mod tests { use super::find_min_size; use super::minify_val; - use std::vec; + use alloc::vec; #[test] fn min_size_detection() { diff --git a/src/fse/fse_decoder.rs b/src/fse/fse_decoder.rs index 7bb2d5e3..45adc931 100644 --- a/src/fse/fse_decoder.rs +++ b/src/fse/fse_decoder.rs @@ -6,6 +6,7 @@ use alloc::vec::Vec; /// all literals from 0 to the highest present one /// /// +#[derive(Debug)] pub struct FSETable { /// The maximum symbol in the table (inclusive). Limits the probabilities length to max_symbol + 1. max_symbol: u8, @@ -144,7 +145,7 @@ impl From for FSEDecoderError { } /// A single entry in an FSE table. -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub struct Entry { /// This value is used as an offset value, and it is added /// to a value read from the stream to determine the next state value. 
@@ -188,7 +189,8 @@ impl<'t> FSEDecoder<'t> { if self.table.accuracy_log == 0 { return Err(FSEDecoderError::TableIsUninitialized); } - self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize]; + let new_state = bits.get_bits(self.table.accuracy_log); + self.state = self.table.decode[new_state as usize]; Ok(()) } @@ -444,6 +446,9 @@ fn calc_baseline_and_numbits( num_states_symbol: u32, state_number: u32, ) -> (u32, u8) { + if num_states_symbol == 0 { + return (0, 0); + } let num_state_slices = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol { num_states_symbol } else { diff --git a/src/fse/fse_encoder.rs b/src/fse/fse_encoder.rs new file mode 100644 index 00000000..2f2e9a15 --- /dev/null +++ b/src/fse/fse_encoder.rs @@ -0,0 +1,414 @@ +use crate::encoding::bit_writer::BitWriter; +use alloc::vec::Vec; + +pub(crate) struct FSEEncoder<'output, V: AsMut>> { + pub(super) table: FSETable, + writer: &'output mut BitWriter, +} + +impl>> FSEEncoder<'_, V> { + pub fn new(table: FSETable, writer: &mut BitWriter) -> FSEEncoder<'_, V> { + FSEEncoder { table, writer } + } + + pub fn into_table(self) -> FSETable { + self.table + } + + /// Encodes the data using the provided table + /// Writes + /// * Table description + /// * Encoded data + /// * Last state index + /// * Padding bits to fill up last byte + pub fn encode(&mut self, data: &[u8]) { + self.write_table(); + + let mut state = self.table.start_state(data[data.len() - 1]); + for x in data[0..data.len() - 1].iter().rev().copied() { + let next = self.table.next_state(x, state.index); + let diff = state.index - next.baseline; + self.writer.write_bits(diff as u64, next.num_bits as usize); + state = next; + } + self.writer + .write_bits(state.index as u64, self.acc_log() as usize); + + let bits_to_fill = self.writer.misaligned(); + if bits_to_fill == 0 { + self.writer.write_bits(1u32, 8); + } else { + self.writer.write_bits(1u32, bits_to_fill); + } + } + + /// Encodes the data using 
the provided table but with two interleaved streams + /// Writes + /// * Table description + /// * Encoded data with two interleaved states + /// * Both Last state indexes + /// * Padding bits to fill up last byte + pub fn encode_interleaved(&mut self, data: &[u8]) { + self.write_table(); + + let mut state_1 = self.table.start_state(data[data.len() - 1]); + let mut state_2 = self.table.start_state(data[data.len() - 2]); + + // The first two symbols are represented by the start states + // Then encode the state transitions for two symbols at a time + let mut idx = data.len() - 4; + loop { + { + let state = state_1; + let x = data[idx + 1]; + let next = self.table.next_state(x, state.index); + let diff = state.index - next.baseline; + self.writer.write_bits(diff as u64, next.num_bits as usize); + state_1 = next; + } + { + let state = state_2; + let x = data[idx]; + let next = self.table.next_state(x, state.index); + let diff = state.index - next.baseline; + self.writer.write_bits(diff as u64, next.num_bits as usize); + state_2 = next; + } + + if idx < 2 { + break; + } + idx -= 2; + } + + // Determine if we have an even or odd number of symbols to encode + // If odd we need to encode the last states transition and encode the final states in the flipped order + if idx == 1 { + let state = state_1; + let x = data[0]; + let next = self.table.next_state(x, state.index); + let diff = state.index - next.baseline; + self.writer.write_bits(diff as u64, next.num_bits as usize); + state_1 = next; + + self.writer + .write_bits(state_2.index as u64, self.acc_log() as usize); + self.writer + .write_bits(state_1.index as u64, self.acc_log() as usize); + } else { + self.writer + .write_bits(state_1.index as u64, self.acc_log() as usize); + self.writer + .write_bits(state_2.index as u64, self.acc_log() as usize); + } + + let bits_to_fill = self.writer.misaligned(); + if bits_to_fill == 0 { + self.writer.write_bits(1u32, 8); + } else { + self.writer.write_bits(1u32, bits_to_fill); + } 
+ } + + fn write_table(&mut self) { + self.writer.write_bits(self.acc_log() - 5, 4); + let mut probability_counter = 0usize; + let probability_sum = 1 << self.acc_log(); + + let mut prob_idx = 0; + while probability_counter < probability_sum { + let max_remaining_value = probability_sum - probability_counter + 1; + let bits_to_write = max_remaining_value.ilog2() + 1; + let low_threshold = ((1 << bits_to_write) - 1) - (max_remaining_value); + let mask = (1 << (bits_to_write - 1)) - 1; + + let prob = self.table.states[prob_idx].probability; + prob_idx += 1; + let value = (prob + 1) as u32; + if value < low_threshold as u32 { + self.writer.write_bits(value, bits_to_write as usize - 1); + } else if value > mask { + self.writer + .write_bits(value + low_threshold as u32, bits_to_write as usize); + } else { + self.writer.write_bits(value, bits_to_write as usize); + } + + if prob == -1 { + probability_counter += 1; + } else if prob > 0 { + probability_counter += prob as usize; + } else { + let mut zeros = 0u8; + while self.table.states[prob_idx].probability == 0 { + zeros += 1; + prob_idx += 1; + if zeros == 3 { + self.writer.write_bits(3u8, 2); + zeros = 0; + } + } + self.writer.write_bits(zeros, 2); + } + } + self.writer.write_bits(0u8, self.writer.misaligned()); + } + + pub(super) fn acc_log(&self) -> u8 { + self.table.table_size.ilog2() as u8 + } +} + +#[derive(Debug)] +pub struct FSETable { + /// Indexed by symbol + pub(super) states: [SymbolStates; 256], + /// Sum of all states.states.len() + pub(crate) table_size: usize, +} + +impl FSETable { + pub(crate) fn next_state(&self, symbol: u8, idx: usize) -> &State { + let states = &self.states[symbol as usize]; + states.get(idx, self.table_size) + } + + pub(crate) fn start_state(&self, symbol: u8) -> &State { + let states = &self.states[symbol as usize]; + &states.states[0] + } +} + +#[derive(Debug)] +pub(super) struct SymbolStates { + /// Sorted by baseline to allow easy lookup using an index + pub(super) states: Vec, 
+ pub(super) probability: i32, +} + +impl SymbolStates { + fn get(&self, idx: usize, max_idx: usize) -> &State { + let start_search_at = (idx * self.states.len()) / max_idx; + + self.states[start_search_at..] + .iter() + .find(|state| state.contains(idx)) + .unwrap() + } +} + +#[derive(Debug)] +pub(crate) struct State { + /// How many bits the range of this state needs to be encoded as + pub(crate) num_bits: u8, + /// The first index targeted by this state + pub(crate) baseline: usize, + /// The last index targeted by this state (baseline + the maximum number with numbits bits allows) + pub(crate) last_index: usize, + /// Index of this state in the decoding table + pub(crate) index: usize, +} + +impl State { + fn contains(&self, idx: usize) -> bool { + self.baseline <= idx && self.last_index >= idx + } +} + +pub fn build_table_from_data(data: &[u8], max_log: u8, avoid_0_numbit: bool) -> FSETable { + let mut counts = [0; 256]; + for x in data { + counts[*x as usize] += 1; + } + build_table_from_counts(&counts, max_log, avoid_0_numbit) +} + +fn build_table_from_counts(counts: &[usize], max_log: u8, avoid_0_numbit: bool) -> FSETable { + let mut probs = [0; 256]; + let mut min_count = 0; + for (idx, count) in counts.iter().copied().enumerate() { + probs[idx] = count as i32; + if count > 0 && (count < min_count || min_count == 0) { + min_count = count; + } + } + + // shift all probabilities down so that the lowest are 1 + min_count -= 1; + for prob in probs.iter_mut() { + if *prob > 0 { + *prob -= min_count as i32; + } + } + + // normalize probabilities to a 2^x + let sum = probs.iter().sum::(); + assert!(sum > 0); + let sum = sum as usize; + let acc_log = (sum.ilog2() as u8 + 1).max(5); + let acc_log = u8::min(acc_log, max_log); + + if sum < 1 << acc_log { + // just raise the maximum probability as much as possible + // TODO is this optimal? 
+ let diff = (1 << acc_log) - sum; + let max = probs.iter_mut().max().unwrap(); + *max += diff as i32; + } else { + // decrease the smallest ones to 1 first + let mut diff = sum - (1 << max_log); + while diff > 0 { + let min = probs.iter_mut().filter(|prob| **prob > 1).min().unwrap(); + let decrease = usize::min(*min as usize - 1, diff); + diff -= decrease; + *min -= decrease as i32; + } + } + let max = probs.iter_mut().max().unwrap(); + if avoid_0_numbit && *max > 1 << (acc_log - 1) { + let redistribute = *max - (1 << (acc_log - 1)); + *max -= redistribute; + let max = *max; + + // find first occurrence of the second_max to avoid lifting the last zero + let second_max = *probs.iter_mut().filter(|x| **x != max).max().unwrap(); + let second_max = probs.iter_mut().find(|x| **x == second_max).unwrap(); + *second_max += redistribute; + assert!(*second_max <= max); + } + build_table_from_probabilities(&probs, acc_log) +} + +pub(super) fn build_table_from_probabilities(probs: &[i32], acc_log: u8) -> FSETable { + let mut states = core::array::from_fn::(|_| SymbolStates { + states: Vec::new(), + probability: 0, + }); + + // distribute -1 symbols + let mut negative_idx = (1 << acc_log) - 1; + for (symbol, _prob) in probs + .iter() + .copied() + .enumerate() + .filter(|prob| prob.1 == -1) + { + states[symbol].states.push(State { + num_bits: acc_log, + baseline: 0, + last_index: (1 << acc_log) - 1, + index: negative_idx, + }); + states[symbol].probability = -1; + negative_idx -= 1; + } + + // distribute other symbols + + // Setup all needed states per symbol with their respective index + let mut idx = 0; + for (symbol, prob) in probs.iter().copied().enumerate() { + if prob <= 0 { + continue; + } + states[symbol].probability = prob; + let states = &mut states[symbol].states; + for _ in 0..prob { + states.push(State { + num_bits: 0, + baseline: 0, + last_index: 0, + index: idx, + }); + + idx = next_position(idx, 1 << acc_log); + while idx > negative_idx { + idx = 
next_position(idx, 1 << acc_log); + } + } + assert_eq!(states.len(), prob as usize); + } + + // After all states know their index we can determine the numbits and baselines + for (symbol, prob) in probs.iter().copied().enumerate() { + if prob <= 0 { + continue; + } + let prob = prob as u32; + let state = &mut states[symbol]; + + // We process the states in their order in the table + state.states.sort_by(|l, r| l.index.cmp(&r.index)); + + let prob_log = if prob.is_power_of_two() { + prob.ilog2() + } else { + prob.ilog2() + 1 + }; + let rounded_up = 1u32 << prob_log; + + // The lower states target double the amount of indexes -> numbits + 1 + let double_states = rounded_up - prob; + let single_states = prob - double_states; + let num_bits = acc_log - prob_log as u8; + let mut baseline = (single_states as usize * (1 << (num_bits))) % (1 << acc_log); + for (idx, state) in state.states.iter_mut().enumerate() { + if (idx as u32) < double_states { + let num_bits = num_bits + 1; + state.baseline = baseline; + state.num_bits = num_bits; + state.last_index = baseline + ((1 << num_bits) - 1); + + baseline += 1 << num_bits; + baseline %= 1 << acc_log; + } else { + state.baseline = baseline; + state.num_bits = num_bits; + state.last_index = baseline + ((1 << num_bits) - 1); + baseline += 1 << num_bits; + } + } + + // For encoding we use the states ordered by the indexes they target + state.states.sort_by(|l, r| l.baseline.cmp(&r.baseline)); + } + + FSETable { + table_size: 1 << acc_log, + states, + } +} + +/// Calculate the position of the next entry of the table given the current +/// position and size of the table. 
+fn next_position(mut p: usize, table_size: usize) -> usize { + p += (table_size >> 1) + (table_size >> 3) + 3; + p &= table_size - 1; + p +} + +const ML_DIST: &[i32] = &[ + 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, +]; + +const LL_DIST: &[i32] = &[ + 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, + -1, -1, -1, -1, +]; + +const OF_DIST: &[i32] = &[ + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, +]; + +pub(crate) fn default_ml_table() -> FSETable { + build_table_from_probabilities(ML_DIST, 6) +} + +pub(crate) fn default_ll_table() -> FSETable { + build_table_from_probabilities(LL_DIST, 6) +} + +pub(crate) fn default_of_table() -> FSETable { + build_table_from_probabilities(OF_DIST, 5) +} diff --git a/src/fse/mod.rs b/src/fse/mod.rs index 1fa6b3f6..7b431376 100644 --- a/src/fse/mod.rs +++ b/src/fse/mod.rs @@ -13,4 +13,124 @@ //! 
mod fse_decoder; + pub use fse_decoder::*; +use fse_encoder::FSEEncoder; + +use crate::{decoding::bit_reader_reverse::BitReaderReversed, encoding::bit_writer::BitWriter}; +pub mod fse_encoder; + +#[test] +fn tables_equal() { + let probs = &[0, 0, -1, 3, 2, 2, (1 << 6) - 8]; + let mut dec_table = FSETable::new(255); + dec_table.build_from_probabilities(6, probs).unwrap(); + let enc_table = fse_encoder::build_table_from_probabilities(probs, 6); + + check_tables(&dec_table, &enc_table); +} + +fn check_tables(dec_table: &fse_decoder::FSETable, enc_table: &fse_encoder::FSETable) { + for (idx, dec_state) in dec_table.decode.iter().enumerate() { + let enc_states = &enc_table.states[dec_state.symbol as usize]; + let enc_state = enc_states + .states + .iter() + .find(|state| state.index == idx) + .unwrap(); + assert_eq!(enc_state.baseline, dec_state.base_line as usize); + assert_eq!(enc_state.num_bits, dec_state.num_bits); + } +} + +#[test] +fn roundtrip() { + round_trip(&(0..64).collect::>()); + let mut data = alloc::vec![]; + data.extend(0..32); + data.extend(0..32); + data.extend(0..32); + data.extend(0..32); + data.extend(0..32); + data.extend(20..32); + data.extend(20..32); + data.extend(0..32); + data.extend(20..32); + data.extend(100..255); + data.extend(20..32); + data.extend(20..32); + round_trip(&data); + + #[cfg(feature = "std")] + if std::fs::exists("fuzz/artifacts/fse").unwrap_or(false) { + for file in std::fs::read_dir("fuzz/artifacts/fse").unwrap() { + if file.as_ref().unwrap().file_type().unwrap().is_file() { + let data = std::fs::read(file.unwrap().path()).unwrap(); + round_trip(&data); + } + } + } +} + +/// Only needed for testing. 
+/// +/// Encodes the data with a table built from that data +/// Decodes the result again by first decoding the table and then the data +/// Asserts that the decoded data equals the input +pub fn round_trip(data: &[u8]) { + if data.len() < 2 { + return; + } + if data.iter().all(|x| *x == data[0]) { + return; + } + if data.len() < 64 { + return; + } + + let mut writer = BitWriter::new(); + let mut encoder = FSEEncoder::new( + fse_encoder::build_table_from_data(data, 22, false), + &mut writer, + ); + let mut dec_table = FSETable::new(255); + encoder.encode(data); + let acc_log = encoder.acc_log(); + let enc_table = encoder.into_table(); + let encoded = writer.dump(); + + let table_bytes = dec_table.build_decoder(&encoded, acc_log).unwrap(); + let encoded = &encoded[table_bytes..]; + let mut decoder = FSEDecoder::new(&dec_table); + + check_tables(&dec_table, &enc_table); + + let mut br = BitReaderReversed::new(encoded); + let mut skipped_bits = 0; + loop { + let val = br.get_bits(1); + skipped_bits += 1; + if val == 1 || skipped_bits > 8 { + break; + } + } + if skipped_bits > 8 { + //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data + panic!("Corrupted end marker"); + } + decoder.init_state(&mut br).unwrap(); + let mut decoded = alloc::vec::Vec::new(); + + for x in data { + let w = decoder.decode_symbol(); + assert_eq!(w, *x); + decoded.push(w); + if decoded.len() < data.len() { + decoder.update_state(&mut br); + } + } + + assert_eq!(&decoded, data); + + assert_eq!(br.bits_remaining(), 0); +} diff --git a/src/huff0/huff0_decoder.rs b/src/huff0/huff0_decoder.rs index 129cfc84..e7d98446 100644 --- a/src/huff0/huff0_decoder.rs +++ b/src/huff0/huff0_decoder.rs @@ -209,7 +209,7 @@ impl From for HuffmanDecoderError { /// A single entry in the table contains the decoded symbol/literal and the /// size of the prefix code. 
-#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub struct Entry { /// The byte that the prefix code replaces during encoding. symbol: u8, @@ -293,7 +293,7 @@ impl HuffmanTable { bits: Vec::with_capacity(256), bit_ranks: Vec::with_capacity(11), rank_indexes: Vec::with_capacity(11), - fse_table: FSETable::new(100), + fse_table: FSETable::new(255), } } @@ -360,9 +360,7 @@ impl HuffmanTable { }); } //fse decompress weights - let bytes_used_by_fse_header = self - .fse_table - .build_decoder(fse_stream, /*TODO find actual max*/ 100)?; + let bytes_used_by_fse_header = self.fse_table.build_decoder(fse_stream, 6)?; if bytes_used_by_fse_header > header as usize { return Err(err::FSETableUsedTooManyBytes { diff --git a/src/huff0/huff0_encoder.rs b/src/huff0/huff0_encoder.rs new file mode 100644 index 00000000..dfb3c65d --- /dev/null +++ b/src/huff0/huff0_encoder.rs @@ -0,0 +1,462 @@ +use alloc::vec::Vec; +use core::cmp::Ordering; + +use crate::{ + encoding::bit_writer::BitWriter, + fse::fse_encoder::{self, FSEEncoder}, +}; + +pub(crate) struct HuffmanEncoder<'output, V: AsMut>> { + table: HuffmanTable, + writer: &'output mut BitWriter, +} + +impl>> HuffmanEncoder<'_, V> { + pub fn new(table: HuffmanTable, writer: &mut BitWriter) -> HuffmanEncoder<'_, V> { + HuffmanEncoder { table, writer } + } + + /// Encodes the data using the provided table + /// Writes + /// * Table description + /// * Encoded data + /// * Padding bits to fill up last byte + pub fn encode(&mut self, data: &[u8]) { + self.write_table(); + Self::encode_stream(&self.table, self.writer, data); + } + + /// Encodes the data using the provided table in 4 concatenated streams + /// Writes + /// * Table description + /// * Jumptable + /// * Encoded data in 4 streams, each padded to fill the last byte + pub fn encode4x(&mut self, data: &[u8]) { + assert!(data.len() >= 4); + + // Split data in 4 equally sized parts (the last one might be a bit smaller than the rest) + let split_size = (data.len() + 3) / 4; + 
let src1 = &data[..split_size]; + let src2 = &data[split_size..split_size * 2]; + let src3 = &data[split_size * 2..split_size * 3]; + let src4 = &data[split_size * 3..]; + + // Write table description + self.write_table(); + + // Reserve space for the jump table, will be changed later + let size_idx = self.writer.index(); + self.writer.write_bits(0u16, 16); + self.writer.write_bits(0u16, 16); + self.writer.write_bits(0u16, 16); + + // Write the 4 streams, noting the sizes of the encoded streams + let index_before = self.writer.index(); + Self::encode_stream(&self.table, self.writer, src1); + let size1 = (self.writer.index() - index_before) / 8; + + let index_before = self.writer.index(); + Self::encode_stream(&self.table, self.writer, src2); + let size2 = (self.writer.index() - index_before) / 8; + + let index_before = self.writer.index(); + Self::encode_stream(&self.table, self.writer, src3); + let size3 = (self.writer.index() - index_before) / 8; + + Self::encode_stream(&self.table, self.writer, src4); + + // Sanity check, if this doesn't hold we produce a broken stream + assert!(size1 <= u16::MAX as usize); + assert!(size2 <= u16::MAX as usize); + assert!(size3 <= u16::MAX as usize); + + // Update the jumptable with the real sizes + self.writer.change_bits(size_idx, size1 as u16, 16); + self.writer.change_bits(size_idx + 16, size2 as u16, 16); + self.writer.change_bits(size_idx + 32, size3 as u16, 16); + } + + /// Encode one stream and pad it to fill the last byte + fn encode_stream>>( + table: &HuffmanTable, + writer: &mut BitWriter, + data: &[u8], + ) { + for symbol in data.iter().rev() { + let (code, num_bits) = table.codes[*symbol as usize]; + writer.write_bits(code, num_bits as usize); + } + + let bits_to_fill = writer.misaligned(); + if bits_to_fill == 0 { + writer.write_bits(1u32, 8); + } else { + writer.write_bits(1u32, bits_to_fill); + } + } + + pub(super) fn weights(&self) -> Vec { + let max = self.table.codes.iter().map(|(_, nb)| nb).max().unwrap(); + 
let weights = self + .table + .codes + .iter() + .copied() + .map(|(_, nb)| if nb == 0 { 0 } else { max - nb + 1 }) + .collect::>(); + + weights + } + + fn write_table(&mut self) { + // TODO strategy for determining this? + let weights = self.weights(); + let weights = &weights[..weights.len() - 1]; // dont encode last weight + if weights.len() > 16 { + let size_idx = self.writer.index(); + self.writer.write_bits(0u8, 8); + let idx_before = self.writer.index(); + let mut encoder = FSEEncoder::new( + fse_encoder::build_table_from_data(weights, 6, true), + self.writer, + ); + encoder.encode_interleaved(weights); + let encoded_len = (self.writer.index() - idx_before) / 8; + assert!(encoded_len < 128); + self.writer.change_bits(size_idx, encoded_len as u8, 8); + } else { + self.writer.write_bits(weights.len() as u8 + 127, 8); + let pairs = weights.chunks_exact(2); + let remainder = pairs.remainder(); + for pair in pairs.into_iter() { + let weight1 = pair[0]; + let weight2 = pair[1]; + assert!(weight1 < 16); + assert!(weight2 < 16); + self.writer.write_bits(weight2, 4); + self.writer.write_bits(weight1, 4); + } + if !remainder.is_empty() { + let weight = remainder[0]; + assert!(weight < 16); + self.writer.write_bits(weight << 4, 8); + } + } + } +} + +pub struct HuffmanTable { + /// Index is the symbol, values are the bitstring in the lower bits of the u32 and the amount of bits in the u8 + codes: Vec<(u32, u8)>, +} + +impl HuffmanTable { + pub fn build_from_data(data: &[u8]) -> Self { + let mut counts = [0; 256]; + let mut max = 0; + for x in data { + counts[*x as usize] += 1; + max = max.max(*x); + } + + Self::build_from_counts(&counts[..=max as usize]) + } + + pub fn build_from_counts(counts: &[usize]) -> Self { + assert!(counts.len() <= 256); + let zeros = counts.iter().filter(|x| **x == 0).count(); + let mut weights = distribute_weights(counts.len() - zeros); + let limit = weights.len().ilog2() as usize + 2; + redistribute_weights(&mut weights, limit); + + 
weights.reverse(); + let mut counts_sorted = counts.iter().enumerate().collect::>(); + counts_sorted.sort_by(|(_, c1), (_, c2)| c1.cmp(c2)); + + let mut weights_distributed = alloc::vec![0; counts.len()]; + for (idx, count) in counts_sorted { + if *count == 0 { + weights_distributed[idx] = 0; + } else { + weights_distributed[idx] = weights.pop().unwrap(); + } + } + + Self::build_from_weights(&weights_distributed) + } + + pub fn build_from_weights(weights: &[usize]) -> Self { + let mut sorted = Vec::with_capacity(weights.len()); + struct SortEntry { + symbol: u8, + weight: usize, + } + + // TODO this doesn't need to be a temporary Vec, it could be done in a [_; 264] + // only non-zero weights are interesting here + for (symbol, weight) in weights.iter().copied().enumerate() { + if weight > 0 { + sorted.push(SortEntry { + symbol: symbol as u8, + weight, + }); + } + } + // We process symbols ordered by weight and then ordered by symbol + sorted.sort_by(|left, right| match left.weight.cmp(&right.weight) { + Ordering::Equal => left.symbol.cmp(&right.symbol), + other => other, + }); + + // Prepare huffman table with placeholders + let mut table = HuffmanTable { + codes: Vec::with_capacity(weights.len()), + }; + for _ in 0..weights.len() { + table.codes.push((0, 0)); + } + + // Determine the number of bits needed for codes with the lowest weight + let weight_sum = sorted.iter().map(|e| 1 << (e.weight - 1)).sum::(); + if !weight_sum.is_power_of_two() { + panic!("This is an internal error"); + } + let max_num_bits = highest_bit_set(weight_sum) - 1; // this is a log_2 of a clean power of two + + // Starting at the symbols with the lowest weight we update the placeholders in the table + let mut current_code = 0; + let mut current_weight = 0; + let mut current_num_bits = 0; + for entry in sorted.iter() { + // If the entry isn't the same weight as the last one we need to change a few things + if current_weight != entry.weight { + // The code shifts by the difference of the 
weights to allow for enough unique values + current_code >>= entry.weight - current_weight; + // Encoding a symbol of this weight will take less bits than the previous weight + current_num_bits = max_num_bits - entry.weight + 1; + // Run the next update when the weight changes again + current_weight = entry.weight; + } + table.codes[entry.symbol as usize] = (current_code as u32, current_num_bits as u8); + current_code += 1; + } + + table + } +} + +/// Assert that the provided value is greater than zero, and returns index of the first set bit +fn highest_bit_set(x: usize) -> usize { + assert!(x > 0); + usize::BITS as usize - x.leading_zeros() as usize +} + +#[test] +fn huffman() { + let table = HuffmanTable::build_from_weights(&[2, 2, 2, 1, 1]); + assert_eq!(table.codes[0], (1, 2)); + assert_eq!(table.codes[1], (2, 2)); + assert_eq!(table.codes[2], (3, 2)); + assert_eq!(table.codes[3], (0, 3)); + assert_eq!(table.codes[4], (1, 3)); + + let table = HuffmanTable::build_from_weights(&[4, 3, 2, 0, 1, 1]); + assert_eq!(table.codes[0], (1, 1)); + assert_eq!(table.codes[1], (1, 2)); + assert_eq!(table.codes[2], (1, 3)); + assert_eq!(table.codes[3], (0, 0)); + assert_eq!(table.codes[4], (0, 4)); + assert_eq!(table.codes[5], (1, 4)); +} + +/// Distributes weights that add up to a clean power of two +fn distribute_weights(amount: usize) -> Vec { + assert!(amount >= 2); + assert!(amount <= 256); + let mut weights = Vec::new(); + + // This is the trivial power of two we always need + weights.push(1); + weights.push(1); + + // This is the weight we are adding right now + let mut target_weight = 1; + // Counts how many times we have added weights + let mut weight_counter = 2; + + // We always add a power of 2 new weights so that the weights that we add equal + // the weights are already in the vec if raised to the power of two. 
+ // This means we double the weights in the vec -> results in a new power of two + // + // Example: [1, 1] -> [1,1,2] (2^1 + 2^1 == 2^2) + // + // Example: [1, 1] -> [1,1,1,1] (2^1 + 2^1 == 2^1 + 2^1) + // [1,1,1,1] -> [1,1,1,1,3] (2^1 + 2^1 + 2^1 + 2^1 == 2^3) + while weights.len() < amount { + let mut add_new = 1 << (weight_counter - target_weight); + let available_space = amount - weights.len(); + + // If the amount of new weights needed to get to the next power of two would exceed amount + // We instead add 1 of a bigger weight and start the cycle again + if add_new > available_space { + // TODO we could maybe instead do this until add_new <= available_space? + // target_weight += 1 + // add_new /= 2 + target_weight = weight_counter; + add_new = 1; + } + + for _ in 0..add_new { + weights.push(target_weight); + } + weight_counter += 1; + } + + assert_eq!(amount, weights.len()); + + weights +} + +/// Sometimes distribute_weights generates weights that require too many bits to encode +/// This redistributes the weights to have less variance by raising the lower weights while still maintaining the +/// required attributes of the weight distribution +fn redistribute_weights(weights: &mut [usize], max_num_bits: usize) { + let weight_sum_log = weights + .iter() + .copied() + .map(|x| 1 << x) + .sum::() + .ilog2() as usize; + + // Nothing needs to be done, this is already fine + if weight_sum_log < max_num_bits { + return; + } + + // We need to decrease the weight difference by the difference between weight_sum_log and max_num_bits + let decrease_weights_by = weight_sum_log - max_num_bits + 1; + + // To do that we raise the lower weights up by that difference, recording how much weight we added in the process + let mut added_weights = 0; + for weight in weights.iter_mut() { + if *weight < decrease_weights_by { + for add in *weight..decrease_weights_by { + added_weights += 1 << add; + } + *weight = decrease_weights_by; + } + } + + // Then we reduce weights until the 
added weights are equaled out + while added_weights > 0 { + // Find the highest weight that is still lower or equal to the added weight + let mut current_idx = 0; + let mut current_weight = 0; + for (idx, weight) in weights.iter().copied().enumerate() { + if 1 << (weight - 1) > added_weights { + break; + } + if weight > current_weight { + current_weight = weight; + current_idx = idx; + } + } + + // Reduce that weight by 1 + added_weights -= 1 << (current_weight - 1); + weights[current_idx] -= 1; + } + + // At the end we normalize the weights so that they start at 1 again + if weights[0] > 1 { + let offset = weights[0] - 1; + for weight in weights.iter_mut() { + *weight -= offset; + } + } +} + +#[test] +fn weights() { + // assert_eq!(distribute_weights(5).as_slice(), &[1, 1, 2, 3, 4]); + for amount in 2..=256 { + let mut weights = distribute_weights(amount); + assert_eq!(weights.len(), amount); + let sum = weights + .iter() + .copied() + .map(|weight| 1 << weight) + .sum::(); + assert!(sum.is_power_of_two()); + + for num_bit_limit in (amount.ilog2() as usize + 1)..=11 { + redistribute_weights(&mut weights, num_bit_limit); + let sum = weights + .iter() + .copied() + .map(|weight| 1 << weight) + .sum::(); + assert!(sum.is_power_of_two()); + assert!( + sum.ilog2() <= 11, + "Max bits too big: sum: {} {weights:?}", + sum + ); + + let codes = HuffmanTable::build_from_weights(&weights).codes; + for (code, num_bits) in codes.iter().copied() { + for (code2, num_bits2) in codes.iter().copied() { + if num_bits == 0 || num_bits2 == 0 || (code, num_bits) == (code2, num_bits2) { + continue; + } + if num_bits <= num_bits2 { + let code2_shifted = code2 >> (num_bits2 - num_bits); + assert_ne!( + code, code2_shifted, + "{:b},{num_bits:} is prefix of {:b},{num_bits2:}", + code, code2 + ); + } + } + } + } + } +} + +#[test] +fn counts() { + let counts = &[3, 0, 4, 1, 5]; + let table = HuffmanTable::build_from_counts(counts).codes; + + assert_eq!(table[1].1, 0); + assert!(table[3].1 >= 
table[0].1); + assert!(table[0].1 >= table[2].1); + assert!(table[2].1 >= table[4].1); + + let counts = &[3, 0, 4, 0, 7, 2, 2, 2, 0, 2, 2, 1, 5]; + let table = HuffmanTable::build_from_counts(counts).codes; + + assert_eq!(table[1].1, 0); + assert_eq!(table[3].1, 0); + assert_eq!(table[8].1, 0); + assert!(table[11].1 >= table[5].1); + assert!(table[5].1 >= table[6].1); + assert!(table[6].1 >= table[7].1); + assert!(table[7].1 >= table[9].1); + assert!(table[9].1 >= table[10].1); + assert!(table[10].1 >= table[0].1); + assert!(table[0].1 >= table[2].1); + assert!(table[2].1 >= table[12].1); + assert!(table[12].1 >= table[4].1); +} + +#[test] +fn from_data() { + let counts = &[3, 0, 4, 1, 5]; + let table = HuffmanTable::build_from_counts(counts).codes; + + let data = &[0, 2, 4, 4, 0, 3, 2, 2, 0, 2]; + let table2 = HuffmanTable::build_from_data(data).codes; + + assert_eq!(table, table2); +} diff --git a/src/huff0/mod.rs b/src/huff0/mod.rs index 3d847d65..5cbe0d08 100644 --- a/src/huff0/mod.rs +++ b/src/huff0/mod.rs @@ -3,4 +3,81 @@ /// used symbols get longer codes. Codes are prefix free, meaning no two codes /// will start with the same sequence of bits. mod huff0_decoder; +use alloc::vec::Vec; + pub use huff0_decoder::*; + +use crate::{decoding::bit_reader_reverse::BitReaderReversed, encoding::bit_writer::BitWriter}; +pub mod huff0_encoder; + +/// Only needed for testing. 
+/// +/// Encodes the data with a table built from that data +/// Decodes the result again by first decoding the table and then the data +/// Asserts that the decoded data equals the input +pub fn round_trip(data: &[u8]) { + if data.len() < 2 { + return; + } + if data.iter().all(|x| *x == data[0]) { + return; + } + let mut writer = BitWriter::new(); + let encoder_table = huff0_encoder::HuffmanTable::build_from_data(data); + let mut encoder = huff0_encoder::HuffmanEncoder::new(encoder_table, &mut writer); + + encoder.encode(data); + let encoded = writer.dump(); + let mut decoder_table = HuffmanTable::new(); + let table_bytes = decoder_table.build_decoder(&encoded).unwrap(); + let mut decoder = HuffmanDecoder::new(&decoder_table); + + let mut br = BitReaderReversed::new(&encoded[table_bytes as usize..]); + let mut skipped_bits = 0; + loop { + let val = br.get_bits(1); + skipped_bits += 1; + if val == 1 || skipped_bits > 8 { + break; + } + } + if skipped_bits > 8 { + //if more than 7 bits are 0, this is not the correct end of the bitstream. 
Either a bug or corrupted data + panic!("Corrupted end marker"); + } + + decoder.init_state(&mut br); + let mut decoded = Vec::new(); + while br.bits_remaining() > -(decoder_table.max_num_bits as isize) { + decoded.push(decoder.decode_symbol()); + decoder.next_state(&mut br); + } + assert_eq!(&decoded, data); +} + +#[test] +fn roundtrip() { + round_trip(&[1, 1, 1, 1, 2, 3]); + round_trip(&[1, 1, 1, 1, 2, 3, 5, 45, 12, 90]); + + for size in 2..512 { + use alloc::vec; + let data = vec![123; size]; + round_trip(&data); + let mut data = Vec::new(); + for x in 0..size { + data.push(x as u8); + } + round_trip(&data); + } + + #[cfg(feature = "std")] + if std::fs::exists("fuzz/artifacts/huff0").unwrap_or(false) { + for file in std::fs::read_dir("fuzz/artifacts/huff0").unwrap() { + if file.as_ref().unwrap().file_type().unwrap().is_file() { + let data = std::fs::read(file.unwrap().path()).unwrap(); + round_trip(&data); + } + } + } +} diff --git a/src/io_nostd.rs b/src/io_nostd.rs index 4948263f..b0f989b1 100644 --- a/src/io_nostd.rs +++ b/src/io_nostd.rs @@ -9,6 +9,7 @@ pub enum ErrorKind { UnexpectedEof, WouldBlock, Other, + WriteAllEof, } impl ErrorKind { @@ -19,6 +20,7 @@ impl ErrorKind { UnexpectedEof => "unexpected end of file", WouldBlock => "operation would block", Other => "other error", + WriteAllEof => "write_all hit EOF", } } } @@ -61,6 +63,10 @@ impl Error { self.kind } + pub fn is_interrupted(&self) -> bool { + matches!(self.kind, ErrorKind::Interrupted) + } + pub fn get_ref(&self) -> Option<&(dyn core::fmt::Display + Send + Sync)> { self.err.as_ref().map(|e| e.as_ref()) } @@ -109,6 +115,18 @@ pub trait Read { Ok(()) } } + + fn read_to_end(&mut self, output: &mut alloc::vec::Vec) -> Result<(), Error> { + let mut buf = [0u8; 1024 * 16]; + loop { + let bytes = self.read(&mut buf)?; + if bytes == 0 { + break; + } + output.extend_from_slice(&buf[..bytes]); + } + Ok(()) + } } impl Read for &[u8] { @@ -127,7 +145,7 @@ impl Read for &[u8] { } } -impl<'a, T> Read for 
&'a mut T +impl Read for &mut T where T: Read, { @@ -139,9 +157,22 @@ pub trait Write { fn write(&mut self, buf: &[u8]) -> Result; fn flush(&mut self) -> Result<(), Error>; + fn write_all(&mut self, mut buf: &[u8]) -> Result<(), Error> { + while !buf.is_empty() { + match self.write(buf) { + Ok(0) => { + return Err(Error::from(ErrorKind::WriteAllEof)); + } + Ok(n) => buf = &buf[n..], + Err(ref e) if e.is_interrupted() => {} + Err(e) => return Err(e), + } + } + Ok(()) + } } -impl<'a, T> Write for &'a mut T +impl Write for &mut T where T: Write, { @@ -168,3 +199,15 @@ impl Write for &mut [u8] { Ok(()) } } + +impl Write for alloc::vec::Vec { + #[inline] + fn write(&mut self, data: &[u8]) -> Result { + self.extend_from_slice(data); + Ok(data.len()) + } + + fn flush(&mut self) -> Result<(), Error> { + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index f68e1a20..ee00bad3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,12 +6,12 @@ //! Decompression can be achieved by using the [`StreamingDecoder`] interface. //! //! ## Compression -//! Although functionality has not yet been implemented past raw frames, refer to the -//! [encoding] module for more info. +//! The [encoding] module contains the internals for compression. +//! Compression can be achieved by using the [`encoding::compress`]/[`encoding::compress_to_vec`] functions or the [`FrameCompressor`] interface. //! //! # Speed //! The decoder has been measured to be roughly between 3.5 to 1.4 times slower -//! than the original implementation. +//! than the original implementation depending on the compressed data. #![no_std] #![deny(trivial_casts, trivial_numeric_casts, rust_2018_idioms)] @@ -35,7 +35,6 @@ macro_rules! 
vprintln { pub mod blocks; pub mod decoding; -#[cfg(feature = "std")] pub mod encoding; pub mod frame; pub mod frame_decoder; @@ -56,3 +55,5 @@ pub use io_nostd as io; pub use frame_decoder::BlockDecodingStrategy; pub use frame_decoder::FrameDecoder; pub use streaming_decoder::StreamingDecoder; + +pub use encoding::FrameCompressor; diff --git a/src/tests/encode_corpus.rs b/src/tests/encode_corpus.rs index ad61ec52..e8000e2a 100644 --- a/src/tests/encode_corpus.rs +++ b/src/tests/encode_corpus.rs @@ -26,13 +26,17 @@ fn test_encode_corpus_files_uncompressed_our_decompressor() { if path.extension() == Some(OsStr::new("zst")) { continue; } + println!("Trying file: {:?}", path); let input = fs::read(entry.path()).unwrap(); - let compressor = - FrameCompressor::new(&input, crate::encoding::CompressionLevel::Uncompressed); let mut compressed_file: Vec = Vec::new(); - compressor.compress(&mut compressed_file); + let mut compressor = FrameCompressor::new( + input.as_slice(), + &mut compressed_file, + crate::encoding::CompressionLevel::Fastest, + ); + compressor.compress(); let mut decompressed_output = Vec::new(); let mut decoder = crate::streaming_decoder::StreamingDecoder::new(compressed_file.as_slice()).unwrap(); @@ -83,10 +87,133 @@ fn test_encode_corpus_files_uncompressed_original_decompressor() { println!("Trying file: {:?}", path); let input = fs::read(entry.path()).unwrap(); - let compressor = - FrameCompressor::new(&input, crate::encoding::CompressionLevel::Uncompressed); let mut compressed_file: Vec = Vec::new(); - compressor.compress(&mut compressed_file); + let mut compressor = FrameCompressor::new( + input.as_slice(), + &mut compressed_file, + crate::encoding::CompressionLevel::Fastest, + ); + compressor.compress(); + let mut decompressed_output = Vec::new(); + // zstd::stream::copy_decode(compressed_file.as_slice(), &mut decompressed_output).unwrap(); + match zstd::stream::copy_decode(compressed_file.as_slice(), &mut decompressed_output) { + Ok(()) => { + if 
input != decompressed_output { + failures.push((path.to_owned(), "Input didn't equal output".to_owned())); + } + } + Err(e) => { + failures.push(( + path.to_owned(), + format!("Decompressor threw an error: {e:?}"), + )); + } + }; + + if !failures.is_empty() { + panic!( + "Decompression of the compressed file fails on the following files: {:?}", + failures + ); + } + } +} + +#[test] +fn test_encode_corpus_files_compressed_our_decompressor() { + extern crate std; + use crate::encoding::FrameCompressor; + use alloc::borrow::ToOwned; + use alloc::vec::Vec; + use std::ffi::OsStr; + use std::fs; + use std::io::Read; + use std::path::PathBuf; + use std::println; + + let mut failures: Vec = Vec::new(); + let mut files: Vec<_> = fs::read_dir("./decodecorpus_files").unwrap().collect(); + if fs::read_dir("./local_corpus_files").is_ok() { + files.extend(fs::read_dir("./local_corpus_files").unwrap()); + } + + files.sort_by_key(|x| match x { + Err(_) => "".to_owned(), + Ok(entry) => entry.path().to_str().unwrap().to_owned(), + }); + + for entry in files.iter().map(|f| f.as_ref().unwrap()) { + let path = entry.path(); + if path.extension() == Some(OsStr::new("zst")) { + continue; + } + println!("Trying file: {:?}", path); + let input = fs::read(entry.path()).unwrap(); + + let mut compressed_file: Vec = Vec::new(); + let mut compressor = FrameCompressor::new( + input.as_slice(), + &mut compressed_file, + crate::encoding::CompressionLevel::Fastest, + ); + compressor.compress(); + let mut decompressed_output = Vec::new(); + let mut decoder = + crate::streaming_decoder::StreamingDecoder::new(compressed_file.as_slice()).unwrap(); + decoder.read_to_end(&mut decompressed_output).unwrap(); + + if input != decompressed_output { + failures.push(path); + } + } + + if !failures.is_empty() { + panic!( + "Decompression of compressed file failed on the following files: {:?}", + failures + ); + } +} + +#[test] +fn test_encode_corpus_files_compressed_original_decompressor() { + extern crate std; 
+ use crate::encoding::FrameCompressor; + use alloc::borrow::ToOwned; + use alloc::format; + use alloc::vec::Vec; + use std::ffi::OsStr; + use std::fs; + use std::path::PathBuf; + use std::println; + use std::string::String; + + let mut failures: Vec<(PathBuf, String)> = Vec::new(); + let mut files: Vec<_> = fs::read_dir("./decodecorpus_files").unwrap().collect(); + if fs::read_dir("./local_corpus_files").is_ok() { + files.extend(fs::read_dir("./local_corpus_files").unwrap()); + } + + files.sort_by_key(|x| match x { + Err(_) => "".to_owned(), + Ok(entry) => entry.path().to_str().unwrap().to_owned(), + }); + + for entry in files.iter().map(|f| f.as_ref().unwrap()) { + let path = entry.path(); + if path.extension() == Some(OsStr::new("zst")) { + continue; + } + println!("Trying file: {:?}", path); + let input = fs::read(entry.path()).unwrap(); + + let mut compressed_file: Vec = Vec::new(); + let mut compressor = FrameCompressor::new( + input.as_slice(), + &mut compressed_file, + crate::encoding::CompressionLevel::Fastest, + ); + compressor.compress(); let mut decompressed_output = Vec::new(); // zstd::stream::copy_decode(compressed_file.as_slice(), &mut decompressed_output).unwrap(); match zstd::stream::copy_decode(compressed_file.as_slice(), &mut decompressed_output) {