Refactor compress and decompress to use generic io::Read and io::Write traits
This commit is contained in:
parent
b68ef65f31
commit
34d0f0bafa
@ -1,55 +1,115 @@
|
||||
#[derive(Debug)]
|
||||
pub struct BitWriter {
|
||||
data: Vec<u8>,
|
||||
use std::io::{self, Bytes, Cursor, Read, Write};
|
||||
|
||||
pub struct BitWriter<'a, W: ?Sized + Write> {
|
||||
bits: u8,
|
||||
nextbit: usize,
|
||||
output: &'a mut W,
|
||||
}
|
||||
impl BitWriter {
|
||||
pub fn new() -> Self {
|
||||
impl<'a, W: Write> From<&'a mut W> for BitWriter<'a, W> {
|
||||
fn from(value: &'a mut W) -> Self {
|
||||
BitWriter::new(value)
|
||||
}
|
||||
}
|
||||
impl<'a, W: Write> BitWriter<'a, W> {
|
||||
pub fn new(writer: &'a mut W) -> Self {
|
||||
return BitWriter {
|
||||
data: vec![],
|
||||
bits: 0,
|
||||
nextbit: 7,
|
||||
output: writer,
|
||||
};
|
||||
//writer.into()
|
||||
}
|
||||
pub fn write(&mut self, bit: bool) {
|
||||
pub fn write(&mut self, bit: bool) -> io::Result<()> {
|
||||
if bit {
|
||||
self.bits = 1 << self.nextbit | self.bits;
|
||||
}
|
||||
if self.nextbit == 0 {
|
||||
self.data.push(self.bits);
|
||||
self.output.write(&[self.bits])?;
|
||||
self.bits = 0;
|
||||
self.nextbit = 7;
|
||||
} else {
|
||||
self.nextbit -= 1;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
pub fn flush(mut self) -> Vec<u8> {
|
||||
pub fn flush(self) -> std::io::Result<()> {
|
||||
if self.bits != 0 {
|
||||
self.data.push(self.bits);
|
||||
self.output.write(&[self.bits])?;
|
||||
}
|
||||
return self.data;
|
||||
return self.output.flush();
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<Vec<u8>> for BitWriter {
|
||||
fn into(self) -> Vec<u8> {
|
||||
self.flush()
|
||||
/// Pulls single bits, most-significant bit first, out of a byte stream.
pub struct BitReader<T> {
    /// Mask selecting the next bit of `bits` to hand out; zero means the
    /// current byte is used up and another one must be fetched.
    next: u8,
    /// The byte whose bits are currently being served.
    bits: u8,
    /// Budget of extra bits that may still be served once the input has
    /// run dry (decremented eight at a time).
    repeat_bits: i32,
    /// Byte-wise iterator over the underlying reader.
    input: Bytes<T>,
}
|
||||
|
||||
impl<R: Read> From<R> for BitReader<R> {
|
||||
fn from(value: R) -> Self {
|
||||
BitReader::new(value.bytes())
|
||||
}
|
||||
}
|
||||
impl IntoIterator for BitWriter {
|
||||
type Item = u8;
|
||||
type IntoIter = std::vec::IntoIter<Self::Item>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.flush().into_iter()
|
||||
impl<T: AsRef<[u8]>> From<T> for BitReader<Cursor<T>> {
|
||||
fn from(value: T) -> Self {
|
||||
let c = Cursor::new(value);
|
||||
c.into()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Poppable {
|
||||
impl<R: Read> BitReader<R> {
|
||||
pub fn new(value: Bytes<R>) -> Self {
|
||||
BitReader {
|
||||
next: 0,
|
||||
bits: 0,
|
||||
repeat_bits: 0,
|
||||
input: value,
|
||||
}
|
||||
}
|
||||
pub fn get_bit(&mut self) -> io::Result<bool> {
|
||||
if self.next == 0 {
|
||||
let next = self.input.next().transpose()?;
|
||||
if let Some(byte) = next {
|
||||
self.bits = byte;
|
||||
} else if self.repeat_bits <= 0 {
|
||||
return Err(io::Error::other("No more bits"));
|
||||
} else {
|
||||
println!("{}", self.repeat_bits);
|
||||
self.repeat_bits -= 8;
|
||||
}
|
||||
self.next = 1 << 7;
|
||||
}
|
||||
let bit = (self.bits & self.next) > 0;
|
||||
self.next = self.next >> 1;
|
||||
return Ok(bit);
|
||||
}
|
||||
}
|
||||
impl<T> BitReader<T> {
|
||||
pub fn with_repeat_bits(mut self, n_bits: u16) -> Self {
|
||||
self.repeat_bits = n_bits as i32;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::modelA::tests::COMPRESSED_BYTES;
|
||||
use crate::modelA::Metrics;
|
||||
|
||||
struct InputBits<'a> {
|
||||
input: &'a [u8],
|
||||
current_byte: u32,
|
||||
last_mask: u32,
|
||||
code_value_bits: i32,
|
||||
}
|
||||
|
||||
pub trait Poppable {
|
||||
fn pop(&mut self) -> Option<u8>;
|
||||
}
|
||||
impl Poppable for &[u8] {
|
||||
}
|
||||
impl Poppable for &[u8] {
|
||||
fn pop(&mut self) -> Option<u8> {
|
||||
if self.len() == 0 {
|
||||
return None;
|
||||
@ -58,44 +118,65 @@ impl Poppable for &[u8] {
|
||||
*self = &self[1..];
|
||||
return Some(out);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BitReader<'a> {
|
||||
data: &'a [u8],
|
||||
// bits: u8,
|
||||
nextbit: usize,
|
||||
}
|
||||
|
||||
impl<'a> BitReader<'a> {
|
||||
pub fn new(data: &'a [u8]) -> Self {
|
||||
BitReader { data, nextbit: 7 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a [u8]> for BitReader<'a> {
|
||||
fn from(value: &'a [u8]) -> Self {
|
||||
BitReader::new(value)
|
||||
impl<'a> InputBits<'a> {
|
||||
pub fn new<T: Metrics>(data: &'a [u8]) -> Self {
|
||||
Self {
|
||||
input: data,
|
||||
current_byte: 0,
|
||||
last_mask: 1,
|
||||
code_value_bits: T::CODE_VALUE_BITS as i32,
|
||||
}
|
||||
}
|
||||
|
||||
impl BitReader<'_> {
|
||||
pub fn pop(&mut self) -> Option<bool> {
|
||||
if self.data.len() == 0 {
|
||||
}
|
||||
fn get_bit(&mut self) -> Option<bool> {
|
||||
if self.last_mask == 1 {
|
||||
match self.input.pop() {
|
||||
None => {
|
||||
if self.code_value_bits <= 0 {
|
||||
return None;
|
||||
//panic!("IDK Man");
|
||||
} else {
|
||||
self.code_value_bits -= 8;
|
||||
}
|
||||
let bit = (self.data[0] >> self.nextbit) & 1;
|
||||
if self.nextbit == 0 {
|
||||
self.data.pop();
|
||||
self.nextbit = 8;
|
||||
}
|
||||
self.nextbit -= 1;
|
||||
return Some(bit > 0);
|
||||
Some(byte) => self.current_byte = byte as u32,
|
||||
}
|
||||
}
|
||||
self.last_mask = 0x80;
|
||||
} else {
|
||||
self.last_mask >>= 1;
|
||||
}
|
||||
let bit = (self.current_byte & self.last_mask) != 0;
|
||||
return Some(bit);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn bit_reader_test_i32() {
|
||||
bit_reader_test_type::<i32>();
|
||||
}
|
||||
#[test]
|
||||
fn bit_reader_test_u32() {
|
||||
bit_reader_test_type::<u32>();
|
||||
}
|
||||
#[test]
|
||||
fn bit_reader_test_i64() {
|
||||
bit_reader_test_type::<i64>();
|
||||
}
|
||||
#[test]
|
||||
fn bit_reader_test_u64() {
|
||||
bit_reader_test_type::<u64>();
|
||||
}
|
||||
#[test]
|
||||
fn bit_reader_test_i128() {
|
||||
bit_reader_test_type::<i128>();
|
||||
}
|
||||
fn bit_reader_test_type<T: Metrics>() {
|
||||
let mut br = BitReader::from(COMPRESSED_BYTES).with_repeat_bits(T::CODE_VALUE_BITS as u16);
|
||||
let mut ib = InputBits::new::<T>(&COMPRESSED_BYTES);
|
||||
|
||||
impl Iterator for BitReader<'_> {
|
||||
type Item = bool;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.pop()
|
||||
while let Some(a) = ib.get_bit() {
|
||||
let b = br.get_bit().unwrap();
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
let _ = br.get_bit().expect_err("Extra bits");
|
||||
}
|
||||
}
|
||||
|
||||
20
src/main.rs
20
src/main.rs
@ -24,23 +24,27 @@ ating system. Linux is normally used in combination with the GNU operating syste
|
||||
th Linux added, or GNU/Linux. All the so-called Linux distributions are really distributions of GNU/Linux!
|
||||
";
|
||||
type CodeValue = u32;
|
||||
println!("compressing...");
|
||||
println!(
|
||||
"Using model: ModelA<{}>",
|
||||
std::any::type_name::<CodeValue>()
|
||||
);
|
||||
let model: ModelA<CodeValue> = ModelA::default();
|
||||
model.print_metrics();
|
||||
println!("");
|
||||
|
||||
let enc = model.compress(data);
|
||||
//println!("{}", enc.len());
|
||||
println!("ModelA compressed to {} bytes", enc.len());
|
||||
let mut compressed = Vec::new();
|
||||
println!("compressing...");
|
||||
model.compress(&data[..], &mut compressed).unwrap();
|
||||
println!("ModelA compressed to {} bytes", compressed.len());
|
||||
println!(
|
||||
"Compression Ratio: {}",
|
||||
data.len() as f64 / enc.len() as f64
|
||||
data.len() as f64 / compressed.len() as f64
|
||||
);
|
||||
//println!("--------- Compressed data ---------\n{}", dump_hex(&enc));
|
||||
println!("");
|
||||
|
||||
println!("decompressing...");
|
||||
let mut decompressed = Vec::new();
|
||||
let model: ModelA<CodeValue> = ModelA::default();
|
||||
let dec = model.decompress(&enc).unwrap();
|
||||
println!("{}", String::from_utf8_lossy(&dec));
|
||||
model.decompress(&compressed, &mut decompressed).unwrap();
|
||||
println!("{}", String::from_utf8_lossy(&decompressed));
|
||||
}
|
||||
|
||||
109
src/modelA.rs
109
src/modelA.rs
@ -1,13 +1,13 @@
|
||||
use core::panic;
|
||||
use std::{
|
||||
fmt::Display,
|
||||
io::{self, Read, Write},
|
||||
ops::{BitAnd, Shl},
|
||||
usize,
|
||||
};
|
||||
|
||||
use num::{FromPrimitive, Integer};
|
||||
|
||||
use crate::bit_buffer::{BitWriter, Poppable};
|
||||
use crate::bit_buffer::{BitReader, BitWriter};
|
||||
|
||||
trait Digits {
|
||||
const PRECISION: usize;
|
||||
@ -72,43 +72,6 @@ struct Prob<T> {
|
||||
total: T,
|
||||
}
|
||||
|
||||
struct InputBits<'a> {
|
||||
input: &'a [u8],
|
||||
current_byte: u32,
|
||||
last_mask: u32,
|
||||
code_value_bits: i32,
|
||||
}
|
||||
|
||||
impl<'a> InputBits<'a> {
|
||||
pub fn new<T: Metrics>(data: &'a [u8]) -> Self {
|
||||
Self {
|
||||
input: data,
|
||||
current_byte: 0,
|
||||
last_mask: 1,
|
||||
code_value_bits: T::CODE_VALUE_BITS as i32,
|
||||
}
|
||||
}
|
||||
fn get_bit(&mut self) -> bool {
|
||||
if self.last_mask == 1 {
|
||||
match self.input.pop() {
|
||||
None => {
|
||||
if self.code_value_bits <= 0 {
|
||||
panic!("IDK Man");
|
||||
} else {
|
||||
self.code_value_bits -= 8;
|
||||
}
|
||||
}
|
||||
Some(byte) => self.current_byte = byte as u32,
|
||||
}
|
||||
self.last_mask = 0x80;
|
||||
} else {
|
||||
self.last_mask >>= 1;
|
||||
}
|
||||
let bit = (self.current_byte & self.last_mask) != 0;
|
||||
return bit;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(non_camel_case_types)]
|
||||
pub struct ModelA<CODE_VALUE> {
|
||||
@ -174,22 +137,28 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
|
||||
fn getCount(&self) -> CODE_VALUE {
|
||||
self.cumulative_frequency[257]
|
||||
}
|
||||
pub fn decompress(mut self, input: &[u8]) -> Option<Vec<u8>> {
|
||||
|
||||
pub fn decompress<T: io::Read, O: io::Write, I: Into<BitReader<T>>>(
|
||||
mut self,
|
||||
input: I,
|
||||
output: &mut O,
|
||||
) -> io::Result<()> {
|
||||
let ONE: CODE_VALUE = CODE_VALUE::one();
|
||||
let ZERO: CODE_VALUE = CODE_VALUE::zero();
|
||||
let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap();
|
||||
let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap();
|
||||
let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap();
|
||||
|
||||
let mut input = InputBits::new::<CODE_VALUE>(input);
|
||||
let mut output = vec![];
|
||||
let mut input: BitReader<T> = input
|
||||
.into()
|
||||
.with_repeat_bits(CODE_VALUE::CODE_VALUE_BITS as u16);
|
||||
|
||||
let mut low: CODE_VALUE = ZERO;
|
||||
let mut high: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap();
|
||||
let mut value: CODE_VALUE = ZERO;
|
||||
|
||||
for _ in 0..CODE_VALUE::CODE_VALUE_BITS {
|
||||
value = (value << CODE_VALUE::one()) + if input.get_bit() { ONE } else { ZERO };
|
||||
value = (value << CODE_VALUE::one()) + if input.get_bit()? { ONE } else { ZERO };
|
||||
}
|
||||
loop {
|
||||
let range: CODE_VALUE = high - low + ONE;
|
||||
@ -198,7 +167,7 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
|
||||
if c > 255 || c < 0 {
|
||||
break;
|
||||
}
|
||||
output.push(c as u8);
|
||||
output.write(&[c as u8])?;
|
||||
high = low + (range * p.high) / p.total - ONE;
|
||||
low = low + (range * p.low) / p.total;
|
||||
loop {
|
||||
@ -216,13 +185,17 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
|
||||
}
|
||||
low = low << ONE;
|
||||
high = (high << ONE) + ONE;
|
||||
value = (value << ONE) + if input.get_bit() { ONE } else { ZERO };
|
||||
value = (value << ONE) + if input.get_bit()? { ONE } else { ZERO };
|
||||
}
|
||||
}
|
||||
return Some(output);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
pub fn compress(mut self, input: &[u8]) -> Vec<u8> {
|
||||
pub fn compress<IN: Read, OUT: Write>(
|
||||
mut self,
|
||||
input: IN,
|
||||
output: &mut OUT,
|
||||
) -> std::io::Result<()> {
|
||||
let ONE: CODE_VALUE = CODE_VALUE::one();
|
||||
let ZERO: CODE_VALUE = CODE_VALUE::zero();
|
||||
let MAX_CODE: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap();
|
||||
@ -230,13 +203,17 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
|
||||
let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap();
|
||||
let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap();
|
||||
|
||||
let mut output = BitWriter::new();
|
||||
let mut output: BitWriter<OUT> = output.into();
|
||||
|
||||
let mut pending_bits: i32 = 0;
|
||||
let mut low: CODE_VALUE = ZERO;
|
||||
let mut high: CODE_VALUE = MAX_CODE;
|
||||
|
||||
for mut c in input.iter().map(|b| *b as i32).chain([256_i32]) {
|
||||
let mut iter = input
|
||||
.bytes()
|
||||
.map(|r| r.map(|b| b as i32))
|
||||
.chain([Ok(256_i32)]);
|
||||
while let Some(Ok(mut c)) = iter.next() {
|
||||
if c > 255 || c < 0 {
|
||||
c = 256;
|
||||
}
|
||||
@ -247,9 +224,9 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
|
||||
|
||||
loop {
|
||||
if high < ONE_HALF {
|
||||
Self::write_with_pending(false, &mut pending_bits, &mut output);
|
||||
Self::write_with_pending(false, &mut pending_bits, &mut output)?;
|
||||
} else if low >= ONE_HALF {
|
||||
Self::write_with_pending(true, &mut pending_bits, &mut output);
|
||||
Self::write_with_pending(true, &mut pending_bits, &mut output)?;
|
||||
} else if low >= ONE_FORTH && high < THREE_FOURTHS {
|
||||
pending_bits += 1;
|
||||
low = low - ONE_FORTH;
|
||||
@ -266,46 +243,52 @@ impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
|
||||
}
|
||||
pending_bits += 1;
|
||||
if low < ONE_FORTH {
|
||||
Self::write_with_pending(false, &mut pending_bits, &mut output);
|
||||
Self::write_with_pending(false, &mut pending_bits, &mut output)?;
|
||||
} else {
|
||||
Self::write_with_pending(true, &mut pending_bits, &mut output);
|
||||
Self::write_with_pending(true, &mut pending_bits, &mut output)?;
|
||||
}
|
||||
|
||||
return output.into();
|
||||
return output.flush();
|
||||
}
|
||||
|
||||
fn write_with_pending(bit: bool, pending: &mut i32, output: &mut BitWriter) {
|
||||
output.write(bit);
|
||||
fn write_with_pending<W: std::io::Write>(
|
||||
bit: bool,
|
||||
pending: &mut i32,
|
||||
output: &mut BitWriter<W>,
|
||||
) -> std::io::Result<()> {
|
||||
output.write(bit)?;
|
||||
for _ in 0..*pending {
|
||||
output.write(!bit);
|
||||
output.write(!bit)?;
|
||||
}
|
||||
*pending = 0;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
const UNCOMPRESSED_BYTES: &[u8; 13] = b"hello world-\n";
|
||||
pub const UNCOMPRESSED_BYTES: &[u8; 13] = b"hello world-\n";
|
||||
/// Compressed bytes taken from output of the c++ version
|
||||
const COMPRESSED_BYTES: [u8; 14] = [
|
||||
pub const COMPRESSED_BYTES: [u8; 14] = [
|
||||
0x67, 0xfc, 0x3e, 0x4a, 0x9d, 0x03, 0x7f, 0x35, 0xf1, 0x08, 0xd8, 0xa6, 0xbc, 0xda,
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn compression_test() {
|
||||
let model: ModelA<i32> = ModelA::default();
|
||||
let enc = model.compress(UNCOMPRESSED_BYTES);
|
||||
let mut enc = Vec::new();
|
||||
model.compress(&UNCOMPRESSED_BYTES[..], &mut enc).unwrap();
|
||||
assert_eq!(COMPRESSED_BYTES.len(), enc.len());
|
||||
for (a, b) in enc.iter().zip(COMPRESSED_BYTES.iter()) {
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompression_test() {
|
||||
let model: ModelA<i32> = ModelA::default();
|
||||
let dec = model.decompress(&COMPRESSED_BYTES).unwrap();
|
||||
let mut dec = Vec::new();
|
||||
model.decompress(&COMPRESSED_BYTES, &mut dec).unwrap();
|
||||
assert_eq!(UNCOMPRESSED_BYTES.len(), dec.len());
|
||||
for (a, b) in dec.iter().zip(UNCOMPRESSED_BYTES.iter()) {
|
||||
assert_eq!(a, b);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user