Refactor compress and decompress to use generic io::Read and io::Write traits

This commit is contained in:
Lucas Schumacher 2024-09-10 18:14:59 -04:00
parent b68ef65f31
commit 34d0f0bafa
3 changed files with 204 additions and 136 deletions

View File

@ -1,50 +1,110 @@
#[derive(Debug)] use std::io::{self, Bytes, Cursor, Read, Write};
pub struct BitWriter {
data: Vec<u8>, pub struct BitWriter<'a, W: ?Sized + Write> {
bits: u8, bits: u8,
nextbit: usize, nextbit: usize,
output: &'a mut W,
} }
impl BitWriter { impl<'a, W: Write> From<&'a mut W> for BitWriter<'a, W> {
pub fn new() -> Self { fn from(value: &'a mut W) -> Self {
BitWriter::new(value)
}
}
impl<'a, W: Write> BitWriter<'a, W> {
pub fn new(writer: &'a mut W) -> Self {
return BitWriter { return BitWriter {
data: vec![],
bits: 0, bits: 0,
nextbit: 7, nextbit: 7,
output: writer,
}; };
//writer.into()
} }
pub fn write(&mut self, bit: bool) { pub fn write(&mut self, bit: bool) -> io::Result<()> {
if bit { if bit {
self.bits = 1 << self.nextbit | self.bits; self.bits = 1 << self.nextbit | self.bits;
} }
if self.nextbit == 0 { if self.nextbit == 0 {
self.data.push(self.bits); self.output.write(&[self.bits])?;
self.bits = 0; self.bits = 0;
self.nextbit = 7; self.nextbit = 7;
} else { } else {
self.nextbit -= 1; self.nextbit -= 1;
} }
Ok(())
} }
pub fn flush(mut self) -> Vec<u8> { pub fn flush(self) -> std::io::Result<()> {
if self.bits != 0 { if self.bits != 0 {
self.data.push(self.bits); self.output.write(&[self.bits])?;
} }
return self.data; return self.output.flush();
} }
} }
impl Into<Vec<u8>> for BitWriter { pub struct BitReader<T> {
fn into(self) -> Vec<u8> { next: u8,
self.flush() bits: u8,
repeat_bits: i32,
input: Bytes<T>,
} }
}
impl IntoIterator for BitWriter {
type Item = u8;
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter { impl<R: Read> From<R> for BitReader<R> {
self.flush().into_iter() fn from(value: R) -> Self {
BitReader::new(value.bytes())
} }
} }
impl<T: AsRef<[u8]>> From<T> for BitReader<Cursor<T>> {
fn from(value: T) -> Self {
let c = Cursor::new(value);
c.into()
}
}
impl<R: Read> BitReader<R> {
    /// Creates a reader that pulls its bytes from `value`.
    ///
    /// The reader starts with no buffered byte and a padding budget of zero,
    /// so `get_bit` fails as soon as `value` runs out.
    pub fn new(value: Bytes<R>) -> Self {
        BitReader {
            next: 0,
            bits: 0,
            repeat_bits: 0,
            input: value,
        }
    }

    /// Returns the next bit of the input, most significant bit of each byte first.
    ///
    /// Once the underlying bytes are exhausted the reader keeps serving the bits
    /// of the last byte it read (all zeros if it never read one) for as long as
    /// the padding budget `repeat_bits` is positive. Every such refill costs 8
    /// from the budget, so the budget is effectively rounded up to whole bytes.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported by the underlying byte iterator, and
    /// returns an error with the message `"No more bits"` when the input is
    /// exhausted and no padding budget is left.
    pub fn get_bit(&mut self) -> io::Result<bool> {
        // `next` is the mask of the bit to hand out; zero means the buffered
        // byte is used up and a refill is required.
        if self.next == 0 {
            match self.input.next().transpose()? {
                Some(byte) => self.bits = byte,
                None if self.repeat_bits <= 0 => {
                    return Err(io::Error::other("No more bits"));
                }
                // Out of input but still within budget: leave `bits` untouched
                // so the previous byte is replayed.
                None => self.repeat_bits -= 8,
            }
            self.next = 1 << 7;
        }
        let bit = self.bits & self.next != 0;
        self.next >>= 1;
        Ok(bit)
    }
}
impl<T> BitReader<T> {
    /// Consumes the reader and returns it with its padding budget set to `n_bits`.
    ///
    /// All other state (buffered byte, bit position, input) is carried over
    /// unchanged.
    pub fn with_repeat_bits(self, n_bits: u16) -> Self {
        Self {
            repeat_bits: i32::from(n_bits),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::modelA::tests::COMPRESSED_BYTES;
use crate::modelA::Metrics;
/// Slice-based bit reader kept in the test module; the tests below compare
/// its output bit-for-bit with `BitReader`.
struct InputBits<'a> {
// Source bytes; `get_bit` takes the next one via `pop()`.
input: &'a [u8],
// Most recently popped byte, widened to `u32`.
current_byte: u32,
// Mask of the bit returned by the latest `get_bit` call. A value of 1 means
// the current byte is used up, so starting at 1 forces a pop on first use.
last_mask: u32,
// Padding budget, initialised from `T::CODE_VALUE_BITS`. It is reduced by 8
// each time a new byte is needed but `input` is empty; once it is <= 0,
// `get_bit` returns `None` instead.
code_value_bits: i32,
}
pub trait Poppable { pub trait Poppable {
fn pop(&mut self) -> Option<u8>; fn pop(&mut self) -> Option<u8>;
@ -59,43 +119,64 @@ impl Poppable for &[u8] {
return Some(out); return Some(out);
} }
} }
impl<'a> InputBits<'a> {
pub struct BitReader<'a> { pub fn new<T: Metrics>(data: &'a [u8]) -> Self {
data: &'a [u8], Self {
// bits: u8, input: data,
nextbit: usize, current_byte: 0,
} last_mask: 1,
code_value_bits: T::CODE_VALUE_BITS as i32,
impl<'a> BitReader<'a> {
pub fn new(data: &'a [u8]) -> Self {
BitReader { data, nextbit: 7 }
} }
} }
fn get_bit(&mut self) -> Option<bool> {
impl<'a> From<&'a [u8]> for BitReader<'a> { if self.last_mask == 1 {
fn from(value: &'a [u8]) -> Self { match self.input.pop() {
BitReader::new(value) None => {
} if self.code_value_bits <= 0 {
}
impl BitReader<'_> {
pub fn pop(&mut self) -> Option<bool> {
if self.data.len() == 0 {
return None; return None;
} //panic!("IDK Man");
let bit = (self.data[0] >> self.nextbit) & 1; } else {
if self.nextbit == 0 { self.code_value_bits -= 8;
self.data.pop();
self.nextbit = 8;
}
self.nextbit -= 1;
return Some(bit > 0);
} }
} }
Some(byte) => self.current_byte = byte as u32,
}
self.last_mask = 0x80;
} else {
self.last_mask >>= 1;
}
let bit = (self.current_byte & self.last_mask) != 0;
return Some(bit);
}
}
// Each test below runs the shared `bit_reader_test_type` check, which compares
// `BitReader` against `InputBits` on `COMPRESSED_BYTES`, for one `Metrics` type.

/// Runs the `BitReader` vs `InputBits` comparison with `i32` as the `Metrics` type.
#[test]
fn bit_reader_test_i32() {
bit_reader_test_type::<i32>();
}
/// Runs the `BitReader` vs `InputBits` comparison with `u32` as the `Metrics` type.
#[test]
fn bit_reader_test_u32() {
bit_reader_test_type::<u32>();
}
/// Runs the `BitReader` vs `InputBits` comparison with `i64` as the `Metrics` type.
#[test]
fn bit_reader_test_i64() {
bit_reader_test_type::<i64>();
}
/// Runs the `BitReader` vs `InputBits` comparison with `u64` as the `Metrics` type.
#[test]
fn bit_reader_test_u64() {
bit_reader_test_type::<u64>();
}
/// Runs the `BitReader` vs `InputBits` comparison with `i128` as the `Metrics` type.
#[test]
fn bit_reader_test_i128() {
bit_reader_test_type::<i128>();
}
fn bit_reader_test_type<T: Metrics>() {
let mut br = BitReader::from(COMPRESSED_BYTES).with_repeat_bits(T::CODE_VALUE_BITS as u16);
let mut ib = InputBits::new::<T>(&COMPRESSED_BYTES);
impl Iterator for BitReader<'_> { while let Some(a) = ib.get_bit() {
type Item = bool; let b = br.get_bit().unwrap();
fn next(&mut self) -> Option<Self::Item> { assert_eq!(a, b);
self.pop() }
let _ = br.get_bit().expect_err("Extra bits");
} }
} }

View File

@ -24,23 +24,27 @@ ating system. Linux is normally used in combination with the GNU operating syste
th Linux added, or GNU/Linux. All the so-called Linux distributions are really distributions of GNU/Linux! th Linux added, or GNU/Linux. All the so-called Linux distributions are really distributions of GNU/Linux!
"; ";
type CodeValue = u32; type CodeValue = u32;
println!("compressing..."); println!(
"Using model: ModelA<{}>",
std::any::type_name::<CodeValue>()
);
let model: ModelA<CodeValue> = ModelA::default(); let model: ModelA<CodeValue> = ModelA::default();
model.print_metrics(); model.print_metrics();
println!(""); println!("");
let enc = model.compress(data); let mut compressed = Vec::new();
//println!("{}", enc.len()); println!("compressing...");
println!("ModelA compressed to {} bytes", enc.len()); model.compress(&data[..], &mut compressed).unwrap();
println!("ModelA compressed to {} bytes", compressed.len());
println!( println!(
"Compression Ratio: {}", "Compression Ratio: {}",
data.len() as f64 / enc.len() as f64 data.len() as f64 / compressed.len() as f64
); );
//println!("--------- Compressed data ---------\n{}", dump_hex(&enc));
println!(""); println!("");
println!("decompressing..."); println!("decompressing...");
let mut decompressed = Vec::new();
let model: ModelA<CodeValue> = ModelA::default(); let model: ModelA<CodeValue> = ModelA::default();
let dec = model.decompress(&enc).unwrap(); model.decompress(&compressed, &mut decompressed).unwrap();
println!("{}", String::from_utf8_lossy(&dec)); println!("{}", String::from_utf8_lossy(&decompressed));
} }

View File

@ -1,13 +1,13 @@
use core::panic;
use std::{ use std::{
fmt::Display, fmt::Display,
io::{self, Read, Write},
ops::{BitAnd, Shl}, ops::{BitAnd, Shl},
usize, usize,
}; };
use num::{FromPrimitive, Integer}; use num::{FromPrimitive, Integer};
use crate::bit_buffer::{BitWriter, Poppable}; use crate::bit_buffer::{BitReader, BitWriter};
trait Digits { trait Digits {
const PRECISION: usize; const PRECISION: usize;
@ -72,43 +72,6 @@ struct Prob<T> {
total: T, total: T,
} }
/// Bit-level reader over a byte slice, yielding bits most significant first.
struct InputBits<'a> {
    /// Source bytes; `get_bit` takes the next one via `pop()`.
    input: &'a [u8],
    /// Most recently popped byte, widened to `u32`.
    current_byte: u32,
    /// Mask of the bit returned by the latest `get_bit` call; 1 means the
    /// current byte is used up.
    last_mask: u32,
    /// Padding budget: reduced by 8 each time a byte is needed but `input` is
    /// empty.
    code_value_bits: i32,
}

impl<'a> InputBits<'a> {
    /// Creates a reader over `data` whose padding budget is `T::CODE_VALUE_BITS`.
    pub fn new<T: Metrics>(data: &'a [u8]) -> Self {
        let padding_budget = T::CODE_VALUE_BITS as i32;
        Self {
            input: data,
            current_byte: 0,
            // Starting at 1 makes the first `get_bit` call fetch a byte.
            last_mask: 1,
            code_value_bits: padding_budget,
        }
    }

    /// Returns the next bit.
    ///
    /// When `input` is empty the previously loaded byte is replayed as long as
    /// the padding budget is positive.
    ///
    /// # Panics
    ///
    /// Panics with `"IDK Man"` when a new byte is needed, `input` is empty and
    /// the padding budget is used up.
    fn get_bit(&mut self) -> bool {
        if self.last_mask != 1 {
            self.last_mask >>= 1;
        } else {
            if let Some(byte) = self.input.pop() {
                self.current_byte = byte as u32;
            } else if self.code_value_bits > 0 {
                self.code_value_bits -= 8;
            } else {
                panic!("IDK Man");
            }
            self.last_mask = 0x80;
        }
        (self.current_byte & self.last_mask) != 0
    }
}
#[derive(Debug)] #[derive(Debug)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
pub struct ModelA<CODE_VALUE> { pub struct ModelA<CODE_VALUE> {
@ -174,22 +137,28 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
fn getCount(&self) -> CODE_VALUE { fn getCount(&self) -> CODE_VALUE {
self.cumulative_frequency[257] self.cumulative_frequency[257]
} }
pub fn decompress(mut self, input: &[u8]) -> Option<Vec<u8>> {
pub fn decompress<T: io::Read, O: io::Write, I: Into<BitReader<T>>>(
mut self,
input: I,
output: &mut O,
) -> io::Result<()> {
let ONE: CODE_VALUE = CODE_VALUE::one(); let ONE: CODE_VALUE = CODE_VALUE::one();
let ZERO: CODE_VALUE = CODE_VALUE::zero(); let ZERO: CODE_VALUE = CODE_VALUE::zero();
let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap(); let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap();
let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap(); let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap();
let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap(); let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap();
let mut input = InputBits::new::<CODE_VALUE>(input); let mut input: BitReader<T> = input
let mut output = vec![]; .into()
.with_repeat_bits(CODE_VALUE::CODE_VALUE_BITS as u16);
let mut low: CODE_VALUE = ZERO; let mut low: CODE_VALUE = ZERO;
let mut high: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap(); let mut high: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap();
let mut value: CODE_VALUE = ZERO; let mut value: CODE_VALUE = ZERO;
for _ in 0..CODE_VALUE::CODE_VALUE_BITS { for _ in 0..CODE_VALUE::CODE_VALUE_BITS {
value = (value << CODE_VALUE::one()) + if input.get_bit() { ONE } else { ZERO }; value = (value << CODE_VALUE::one()) + if input.get_bit()? { ONE } else { ZERO };
} }
loop { loop {
let range: CODE_VALUE = high - low + ONE; let range: CODE_VALUE = high - low + ONE;
@ -198,7 +167,7 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
if c > 255 || c < 0 { if c > 255 || c < 0 {
break; break;
} }
output.push(c as u8); output.write(&[c as u8])?;
high = low + (range * p.high) / p.total - ONE; high = low + (range * p.high) / p.total - ONE;
low = low + (range * p.low) / p.total; low = low + (range * p.low) / p.total;
loop { loop {
@ -216,13 +185,17 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
} }
low = low << ONE; low = low << ONE;
high = (high << ONE) + ONE; high = (high << ONE) + ONE;
value = (value << ONE) + if input.get_bit() { ONE } else { ZERO }; value = (value << ONE) + if input.get_bit()? { ONE } else { ZERO };
} }
} }
return Some(output); return Ok(());
} }
pub fn compress(mut self, input: &[u8]) -> Vec<u8> { pub fn compress<IN: Read, OUT: Write>(
mut self,
input: IN,
output: &mut OUT,
) -> std::io::Result<()> {
let ONE: CODE_VALUE = CODE_VALUE::one(); let ONE: CODE_VALUE = CODE_VALUE::one();
let ZERO: CODE_VALUE = CODE_VALUE::zero(); let ZERO: CODE_VALUE = CODE_VALUE::zero();
let MAX_CODE: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap(); let MAX_CODE: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap();
@ -230,13 +203,17 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap(); let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap();
let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap(); let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap();
let mut output = BitWriter::new(); let mut output: BitWriter<OUT> = output.into();
let mut pending_bits: i32 = 0; let mut pending_bits: i32 = 0;
let mut low: CODE_VALUE = ZERO; let mut low: CODE_VALUE = ZERO;
let mut high: CODE_VALUE = MAX_CODE; let mut high: CODE_VALUE = MAX_CODE;
for mut c in input.iter().map(|b| *b as i32).chain([256_i32]) { let mut iter = input
.bytes()
.map(|r| r.map(|b| b as i32))
.chain([Ok(256_i32)]);
while let Some(Ok(mut c)) = iter.next() {
if c > 255 || c < 0 { if c > 255 || c < 0 {
c = 256; c = 256;
} }
@ -247,9 +224,9 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
loop { loop {
if high < ONE_HALF { if high < ONE_HALF {
Self::write_with_pending(false, &mut pending_bits, &mut output); Self::write_with_pending(false, &mut pending_bits, &mut output)?;
} else if low >= ONE_HALF { } else if low >= ONE_HALF {
Self::write_with_pending(true, &mut pending_bits, &mut output); Self::write_with_pending(true, &mut pending_bits, &mut output)?;
} else if low >= ONE_FORTH && high < THREE_FOURTHS { } else if low >= ONE_FORTH && high < THREE_FOURTHS {
pending_bits += 1; pending_bits += 1;
low = low - ONE_FORTH; low = low - ONE_FORTH;
@ -266,46 +243,52 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
} }
pending_bits += 1; pending_bits += 1;
if low < ONE_FORTH { if low < ONE_FORTH {
Self::write_with_pending(false, &mut pending_bits, &mut output); Self::write_with_pending(false, &mut pending_bits, &mut output)?;
} else { } else {
Self::write_with_pending(true, &mut pending_bits, &mut output); Self::write_with_pending(true, &mut pending_bits, &mut output)?;
} }
return output.into(); return output.flush();
} }
fn write_with_pending(bit: bool, pending: &mut i32, output: &mut BitWriter) { fn write_with_pending<W: std::io::Write>(
output.write(bit); bit: bool,
pending: &mut i32,
output: &mut BitWriter<W>,
) -> std::io::Result<()> {
output.write(bit)?;
for _ in 0..*pending { for _ in 0..*pending {
output.write(!bit); output.write(!bit)?;
} }
*pending = 0; *pending = 0;
Ok(())
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { pub mod tests {
use super::*; use super::*;
const UNCOMPRESSED_BYTES: &[u8; 13] = b"hello world-\n"; pub const UNCOMPRESSED_BYTES: &[u8; 13] = b"hello world-\n";
/// Compressed bytes taken from output of the c++ version /// Compressed bytes taken from output of the c++ version
const COMPRESSED_BYTES: [u8; 14] = [ pub const COMPRESSED_BYTES: [u8; 14] = [
0x67, 0xfc, 0x3e, 0x4a, 0x9d, 0x03, 0x7f, 0x35, 0xf1, 0x08, 0xd8, 0xa6, 0xbc, 0xda, 0x67, 0xfc, 0x3e, 0x4a, 0x9d, 0x03, 0x7f, 0x35, 0xf1, 0x08, 0xd8, 0xa6, 0xbc, 0xda,
]; ];
#[test] #[test]
fn compression_test() { fn compression_test() {
let model: ModelA<i32> = ModelA::default(); let model: ModelA<i32> = ModelA::default();
let enc = model.compress(UNCOMPRESSED_BYTES); let mut enc = Vec::new();
model.compress(&UNCOMPRESSED_BYTES[..], &mut enc).unwrap();
assert_eq!(COMPRESSED_BYTES.len(), enc.len()); assert_eq!(COMPRESSED_BYTES.len(), enc.len());
for (a, b) in enc.iter().zip(COMPRESSED_BYTES.iter()) { for (a, b) in enc.iter().zip(COMPRESSED_BYTES.iter()) {
assert_eq!(a, b); assert_eq!(a, b);
} }
} }
#[test] #[test]
fn decompression_test() { fn decompression_test() {
let model: ModelA<i32> = ModelA::default(); let model: ModelA<i32> = ModelA::default();
let dec = model.decompress(&COMPRESSED_BYTES).unwrap(); let mut dec = Vec::new();
model.decompress(&COMPRESSED_BYTES, &mut dec).unwrap();
assert_eq!(UNCOMPRESSED_BYTES.len(), dec.len()); assert_eq!(UNCOMPRESSED_BYTES.len(), dec.len());
for (a, b) in dec.iter().zip(UNCOMPRESSED_BYTES.iter()) { for (a, b) in dec.iter().zip(UNCOMPRESSED_BYTES.iter()) {
assert_eq!(a, b); assert_eq!(a, b);