diff --git a/src/bit_buffer.rs b/src/bit_buffer.rs new file mode 100644 index 0000000..b9ed89b --- /dev/null +++ b/src/bit_buffer.rs @@ -0,0 +1,87 @@ +#[derive(Debug)] +pub struct BitWriter { + data: Vec, + bits: u8, + nextbit: usize, +} +impl BitWriter { + pub fn new() -> Self { + return BitWriter { + data: vec![], + bits: 0, + nextbit: 7, + }; + } + pub fn write(&mut self, bit: bool) { + if bit { + self.bits = 1 << self.nextbit | self.bits; + } + if self.nextbit == 0 { + self.data.push(self.bits); + self.bits = 0; + self.nextbit = 7; + } else { + self.nextbit -= 1; + } + } + pub fn flush(mut self) -> Vec { + if self.bits != 0 { + self.data.push(self.bits); + } + return self.data; + } +} + +trait Poppable { + fn pop(&mut self) -> Option; +} +impl Poppable for &[u8] { + fn pop(&mut self) -> Option { + if self.len() == 0 { + return None; + } + let out = self[0]; + *self = &self[1..]; + return Some(out); + } +} + +pub struct BitReader<'a> { + data: &'a [u8], + // bits: u8, + nextbit: usize, +} + +impl<'a> BitReader<'a> { + pub fn new(data: &'a [u8]) -> Self { + BitReader { data, nextbit: 7 } + } +} + +impl<'a> From<&'a [u8]> for BitReader<'a> { + fn from(value: &'a [u8]) -> Self { + BitReader::new(value) + } +} + +impl BitReader<'_> { + pub fn pop(&mut self) -> Option { + if self.data.len() == 0 { + return None; + } + let bit = (self.data[0] >> self.nextbit) & 1; + if self.nextbit == 0 { + self.data.pop(); + self.nextbit = 8; + } + self.nextbit -= 1; + return Some(bit > 0); + } +} + +impl Iterator for BitReader<'_> { + type Item = bool; + fn next(&mut self) -> Option { + self.pop() + } +} diff --git a/src/main.rs b/src/main.rs index 578ace1..f432492 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,9 @@ // https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html use std::collections::HashMap; +mod bit_buffer; +use bit_buffer::{BitReader, BitWriter}; + type Model = HashMap; fn get_symbol(model: &Model, d: f64) -> Option { @@ -12,6 +15,14 @@ fn get_symbol(model: &Model, d: f64) -> Option { } return None; } +fn _get_symbol(model: &Model, low: f64, high: f64) -> Option { + for (symbol, (start, end)) in model { + if low >= *start && high < *end { + return Some(*symbol); + } + } + return None; +} fn encode(data: &[u8], model: &Model) -> f64 { let mut high: f64 = 1.0; @@ -24,6 +35,33 @@ fn encode(data: &[u8], model: &Model) -> f64 { } return low + (high - low) / 2.0; } +fn _encode(input: &[u8], model: &Model) -> Vec { + let mut output = BitWriter::new(); + let mut high = u64::MAX; + let mut low = u64::MIN; + for symbol in input { + let range = high - low; + let p = model.get(symbol).expect("Invalid/Unsupported data"); + high = low + (range as f64 * p.1) as u64; + low = low + (range as f64 * p.0) as u64; + loop { + if high < 1 << 63 { + output.write(false); + print!("0"); + } else if low >= 1 << (u64::BITS - 1) { + output.write(true); + print!("1"); + } else { + break; + } + low <<= 1; + high <<= 1; + high |= 1; + } + } + println!(""); + return output.flush(); +} fn decode(message: f64, model: &Model) { let mut high: f64 = 1.0; @@ -49,6 +87,31 @@ fn decode(message: f64, model: &Model) { low = low + range * p.0; } } +fn _decode(input: &[u8], model: &Model) -> Vec { + let mut high = 1.0; + let mut low = 0.0; + let mut output = vec![]; + for bit in BitReader::new(input) { + let diff = high - low; + if bit { + //print!("1"); + low = low + (diff / 2.0); + } else { + high = high - (diff / 2.0); + //print!("0"); + } + if let Some(symbol) = _get_symbol(model, low, high) { + //println!("\nGot sym: {} from [{}, {})", symbol as char, low, high); + output.push(symbol); + let (slow, shigh) = model.get(&symbol).unwrap(); + let symdiff = *shigh - *slow; + high = (high - *slow) / symdiff; + low = (low - *slow) / symdiff; + } + } + + return output; +} fn make_model(probabilities: &[(u8, f64)]) -> Model { let mut model = HashMap::new(); @@ -92,8 +155,21 @@ const ENGLISH: &[(u8, f64)] = &[ (b'-', 0.02), ]; fn main() { + let data = b"hello world-"; + println!("MODEL:"); let model: Model = make_model(ENGLISH); - let message = encode(b"hello world-", &model); + println!(""); + let message = encode(data, &model); println!("{message}"); decode(message, &model); + println!("Compression Ratio: {}\n", data.len() as f64 / 8.0); + + let _enc = _encode(data, &model); + let _dec = _decode(&_enc, &model); + + println!("{}", String::from_utf8(_dec).unwrap()); + println!( + "Compression Ratio: {}", + data.len() as f64 / _enc.len() as f64 + ); }