From 34d0f0bafad38d245c67411d21824704dbe4bf11 Mon Sep 17 00:00:00 2001 From: Lucas Schumacher Date: Tue, 10 Sep 2024 18:14:59 -0400 Subject: [PATCH] Refactor compress and decompress to use generic io::Read and io::Write traits --- src/bit_buffer.rs | 211 ++++++++++++++++++++++++++++++++-------------- src/main.rs | 20 +++-- src/modelA.rs | 109 ++++++++++-------------- 3 files changed, 204 insertions(+), 136 deletions(-) diff --git a/src/bit_buffer.rs b/src/bit_buffer.rs index 10a373e..a7ab2ac 100644 --- a/src/bit_buffer.rs +++ b/src/bit_buffer.rs @@ -1,101 +1,182 @@ -#[derive(Debug)] -pub struct BitWriter { - data: Vec, +use std::io::{self, Bytes, Cursor, Read, Write}; + +pub struct BitWriter<'a, W: ?Sized + Write> { bits: u8, nextbit: usize, + output: &'a mut W, } -impl BitWriter { - pub fn new() -> Self { +impl<'a, W: Write> From<&'a mut W> for BitWriter<'a, W> { + fn from(value: &'a mut W) -> Self { + BitWriter::new(value) + } +} +impl<'a, W: Write> BitWriter<'a, W> { + pub fn new(writer: &'a mut W) -> Self { return BitWriter { - data: vec![], bits: 0, nextbit: 7, + output: writer, }; + //writer.into() } - pub fn write(&mut self, bit: bool) { + pub fn write(&mut self, bit: bool) -> io::Result<()> { if bit { self.bits = 1 << self.nextbit | self.bits; } if self.nextbit == 0 { - self.data.push(self.bits); + self.output.write(&[self.bits])?; self.bits = 0; self.nextbit = 7; } else { self.nextbit -= 1; } + Ok(()) } - pub fn flush(mut self) -> Vec { + pub fn flush(self) -> std::io::Result<()> { if self.bits != 0 { - self.data.push(self.bits); + self.output.write(&[self.bits])?; } - return self.data; + return self.output.flush(); } } -impl Into> for BitWriter { - fn into(self) -> Vec { - self.flush() +pub struct BitReader { + next: u8, + bits: u8, + repeat_bits: i32, + input: Bytes, +} + +impl From for BitReader { + fn from(value: R) -> Self { + BitReader::new(value.bytes()) } } -impl IntoIterator for BitWriter { - type Item = u8; - type IntoIter = std::vec::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.flush().into_iter() +impl> From for BitReader> { + fn from(value: T) -> Self { + let c = Cursor::new(value); + c.into() } } -pub trait Poppable { - fn pop(&mut self) -> Option; -} -impl Poppable for &[u8] { - fn pop(&mut self) -> Option { - if self.len() == 0 { - return None; +impl BitReader { + pub fn new(value: Bytes) -> Self { + BitReader { + next: 0, + bits: 0, + repeat_bits: 0, + input: value, } - let out = self[0]; - *self = &self[1..]; - return Some(out); } -} - -pub struct BitReader<'a> { - data: &'a [u8], - // bits: u8, - nextbit: usize, -} - -impl<'a> BitReader<'a> { - pub fn new(data: &'a [u8]) -> Self { - BitReader { data, nextbit: 7 } - } -} - -impl<'a> From<&'a [u8]> for BitReader<'a> { - fn from(value: &'a [u8]) -> Self { - BitReader::new(value) - } -} - -impl BitReader<'_> { - pub fn pop(&mut self) -> Option { - if self.data.len() == 0 { - return None; + pub fn get_bit(&mut self) -> io::Result { + if self.next == 0 { + let next = self.input.next().transpose()?; + if let Some(byte) = next { + self.bits = byte; + } else if self.repeat_bits <= 0 { + return Err(io::Error::other("No more bits")); + } else { + println!("{}", self.repeat_bits); + self.repeat_bits -= 8; + } + self.next = 1 << 7; } - let bit = (self.data[0] >> self.nextbit) & 1; - if self.nextbit == 0 { - self.data.pop(); - self.nextbit = 8; - } - self.nextbit -= 1; - return Some(bit > 0); + let bit = (self.bits & self.next) > 0; + self.next = self.next >> 1; + return Ok(bit); + } +} +impl BitReader { + pub fn with_repeat_bits(mut self, n_bits: u16) -> Self { + self.repeat_bits = n_bits as i32; + self } } -impl Iterator for BitReader<'_> { - type Item = bool; - fn next(&mut self) -> Option { - self.pop() +#[cfg(test)] +mod tests { + use super::*; + use crate::modelA::tests::COMPRESSED_BYTES; + use crate::modelA::Metrics; + + struct InputBits<'a> { + input: &'a [u8], + current_byte: u32, + last_mask: u32, + code_value_bits: i32, + } + + pub trait Poppable { + fn pop(&mut self) -> Option; + } + impl Poppable for &[u8] { + fn pop(&mut self) -> Option { + if self.len() == 0 { + return None; + } + let out = self[0]; + *self = &self[1..]; + return Some(out); + } + } + impl<'a> InputBits<'a> { + pub fn new(data: &'a [u8]) -> Self { + Self { + input: data, + current_byte: 0, + last_mask: 1, + code_value_bits: T::CODE_VALUE_BITS as i32, + } + } + fn get_bit(&mut self) -> Option { + if self.last_mask == 1 { + match self.input.pop() { + None => { + if self.code_value_bits <= 0 { + return None; + //panic!("IDK Man"); + } else { + self.code_value_bits -= 8; + } + } + Some(byte) => self.current_byte = byte as u32, + } + self.last_mask = 0x80; + } else { + self.last_mask >>= 1; + } + let bit = (self.current_byte & self.last_mask) != 0; + return Some(bit); + } + } + #[test] + fn bit_reader_test_i32() { + bit_reader_test_type::(); + } + #[test] + fn bit_reader_test_u32() { + bit_reader_test_type::(); + } + #[test] + fn bit_reader_test_i64() { + bit_reader_test_type::(); + } + #[test] + fn bit_reader_test_u64() { + bit_reader_test_type::(); + } + #[test] + fn bit_reader_test_i128() { + bit_reader_test_type::(); + } + fn bit_reader_test_type() { + let mut br = BitReader::from(COMPRESSED_BYTES).with_repeat_bits(T::CODE_VALUE_BITS as u16); + let mut ib = InputBits::new::(&COMPRESSED_BYTES); + + while let Some(a) = ib.get_bit() { + let b = br.get_bit().unwrap(); + assert_eq!(a, b); + } + let _ = br.get_bit().expect_err("Extra bits"); } } diff --git a/src/main.rs b/src/main.rs index 68a21c3..ef3acde 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,23 +24,27 @@ ating system. Linux is normally used in combination with the GNU operating syste th Linux added, or GNU/Linux. All the so-called Linux distributions are really distributions of GNU/Linux! "; type CodeValue = u32; - println!("compressing..."); + println!( + "Using model: ModelA<{}>", + std::any::type_name::() + ); let model: ModelA = ModelA::default(); model.print_metrics(); println!(""); - let enc = model.compress(data); - //println!("{}", enc.len()); - println!("ModelA compressed to {} bytes", enc.len()); + let mut compressed = Vec::new(); + println!("compressing..."); + model.compress(&data[..], &mut compressed).unwrap(); + println!("ModelA compressed to {} bytes", compressed.len()); println!( "Compression Ratio: {}", - data.len() as f64 / enc.len() as f64 + data.len() as f64 / compressed.len() as f64 ); - //println!("--------- Compressed data ---------\n{}", dump_hex(&enc)); println!(""); println!("decompressing..."); + let mut decompressed = Vec::new(); let model: ModelA = ModelA::default(); - let dec = model.decompress(&enc).unwrap(); - println!("{}", String::from_utf8_lossy(&dec)); + model.decompress(&compressed, &mut decompressed).unwrap(); + println!("{}", String::from_utf8_lossy(&decompressed)); } diff --git a/src/modelA.rs b/src/modelA.rs index b43c9b5..0a24073 100644 --- a/src/modelA.rs +++ b/src/modelA.rs @@ -1,13 +1,13 @@ -use core::panic; use std::{ fmt::Display, + io::{self, Read, Write}, ops::{BitAnd, Shl}, usize, }; use num::{FromPrimitive, Integer}; -use crate::bit_buffer::{BitWriter, Poppable}; +use crate::bit_buffer::{BitReader, BitWriter}; trait Digits { const PRECISION: usize; @@ -72,43 +72,6 @@ struct Prob { total: T, } -struct InputBits<'a> { - input: &'a [u8], - current_byte: u32, - last_mask: u32, - code_value_bits: i32, -} - -impl<'a> InputBits<'a> { - pub fn new(data: &'a [u8]) -> Self { - Self { - input: data, - current_byte: 0, - last_mask: 1, - code_value_bits: T::CODE_VALUE_BITS as i32, - } - } - fn get_bit(&mut self) -> bool { - if self.last_mask == 1 { - match self.input.pop() { - None => { - if self.code_value_bits <= 0 { - panic!("IDK Man"); - } else { - self.code_value_bits -= 8; - } - } - Some(byte) => self.current_byte = byte as u32, - } - self.last_mask = 0x80; - } else { - self.last_mask >>= 1; - } - let bit = (self.current_byte & self.last_mask) != 0; - return bit; - } -} - #[derive(Debug)] #[allow(non_camel_case_types)] pub struct ModelA { @@ -174,22 +137,28 @@ impl ModelA { fn getCount(&self) -> CODE_VALUE { self.cumulative_frequency[257] } - pub fn decompress(mut self, input: &[u8]) -> Option> { + + pub fn decompress>>( + mut self, + input: I, + output: &mut O, + ) -> io::Result<()> { let ONE: CODE_VALUE = CODE_VALUE::one(); let ZERO: CODE_VALUE = CODE_VALUE::zero(); let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap(); let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap(); let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap(); - let mut input = InputBits::new::(input); - let mut output = vec![]; + let mut input: BitReader = input + .into() + .with_repeat_bits(CODE_VALUE::CODE_VALUE_BITS as u16); let mut low: CODE_VALUE = ZERO; let mut high: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap(); let mut value: CODE_VALUE = ZERO; for _ in 0..CODE_VALUE::CODE_VALUE_BITS { - value = (value << CODE_VALUE::one()) + if input.get_bit() { ONE } else { ZERO }; + value = (value << CODE_VALUE::one()) + if input.get_bit()? { ONE } else { ZERO }; } loop { let range: CODE_VALUE = high - low + ONE; @@ -198,7 +167,7 @@ impl ModelA { if c > 255 || c < 0 { break; } - output.push(c as u8); + output.write(&[c as u8])?; high = low + (range * p.high) / p.total - ONE; low = low + (range * p.low) / p.total; loop { @@ -216,13 +185,17 @@ impl ModelA { } low = low << ONE; high = (high << ONE) + ONE; - value = (value << ONE) + if input.get_bit() { ONE } else { ZERO }; + value = (value << ONE) + if input.get_bit()? { ONE } else { ZERO }; } } - return Some(output); + return Ok(()); } - pub fn compress(mut self, input: &[u8]) -> Vec { + pub fn compress( + mut self, + input: IN, + output: &mut OUT, + ) -> std::io::Result<()> { let ONE: CODE_VALUE = CODE_VALUE::one(); let ZERO: CODE_VALUE = CODE_VALUE::zero(); let MAX_CODE: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap(); @@ -230,13 +203,17 @@ impl ModelA { let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap(); let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap(); - let mut output = BitWriter::new(); + let mut output: BitWriter = output.into(); let mut pending_bits: i32 = 0; let mut low: CODE_VALUE = ZERO; let mut high: CODE_VALUE = MAX_CODE; - for mut c in input.iter().map(|b| *b as i32).chain([256_i32]) { + let mut iter = input + .bytes() + .map(|r| r.map(|b| b as i32)) + .chain([Ok(256_i32)]); + while let Some(Ok(mut c)) = iter.next() { if c > 255 || c < 0 { c = 256; } @@ -247,9 +224,9 @@ impl ModelA { loop { if high < ONE_HALF { - Self::write_with_pending(false, &mut pending_bits, &mut output); + Self::write_with_pending(false, &mut pending_bits, &mut output)?; } else if low >= ONE_HALF { - Self::write_with_pending(true, &mut pending_bits, &mut output); + Self::write_with_pending(true, &mut pending_bits, &mut output)?; } else if low >= ONE_FORTH && high < THREE_FOURTHS { pending_bits += 1; low = low - ONE_FORTH; @@ -266,46 +243,52 @@ impl ModelA { } pending_bits += 1; if low < ONE_FORTH { - Self::write_with_pending(false, &mut pending_bits, &mut output); + Self::write_with_pending(false, &mut pending_bits, &mut output)?; } else { - Self::write_with_pending(true, &mut pending_bits, &mut output); + Self::write_with_pending(true, &mut pending_bits, &mut output)?; } - return output.into(); + return output.flush(); } - fn write_with_pending(bit: bool, pending: &mut i32, output: &mut BitWriter) { - output.write(bit); + fn write_with_pending( + bit: bool, + pending: &mut i32, + output: &mut BitWriter, + ) -> std::io::Result<()> { + output.write(bit)?; for _ in 0..*pending { - output.write(!bit); + output.write(!bit)?; } *pending = 0; + Ok(()) } } #[cfg(test)] -mod tests { +pub mod tests { use super::*; - const UNCOMPRESSED_BYTES: &[u8; 13] = b"hello world-\n"; + pub const UNCOMPRESSED_BYTES: &[u8; 13] = b"hello world-\n"; /// Compressed bytes taken from output of the c++ version - const COMPRESSED_BYTES: [u8; 14] = [ + pub const COMPRESSED_BYTES: [u8; 14] = [ 0x67, 0xfc, 0x3e, 0x4a, 0x9d, 0x03, 0x7f, 0x35, 0xf1, 0x08, 0xd8, 0xa6, 0xbc, 0xda, ]; #[test] fn compression_test() { let model: ModelA = ModelA::default(); - let enc = model.compress(UNCOMPRESSED_BYTES); + let mut enc = Vec::new(); + model.compress(&UNCOMPRESSED_BYTES[..], &mut enc).unwrap(); assert_eq!(COMPRESSED_BYTES.len(), enc.len()); for (a, b) in enc.iter().zip(COMPRESSED_BYTES.iter()) { assert_eq!(a, b); } } - #[test] fn decompression_test() { let model: ModelA = ModelA::default(); - let dec = model.decompress(&COMPRESSED_BYTES).unwrap(); + let mut dec = Vec::new(); + model.decompress(&COMPRESSED_BYTES, &mut dec).unwrap(); assert_eq!(UNCOMPRESSED_BYTES.len(), dec.len()); for (a, b) in dec.iter().zip(UNCOMPRESSED_BYTES.iter()) { assert_eq!(a, b);