From 8ea134e008505bddfe3b3db7ecc1605454e67371 Mon Sep 17 00:00:00 2001 From: Lucas Schumacher Date: Fri, 6 Sep 2024 23:25:46 -0400 Subject: [PATCH] Add working encoder adaptation of modelA --- src/modelA.rs | 342 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 src/modelA.rs diff --git a/src/modelA.rs b/src/modelA.rs new file mode 100644 index 0000000..5d541bf --- /dev/null +++ b/src/modelA.rs @@ -0,0 +1,342 @@ +use core::panic; +use std::{ + fmt::Display, + ops::{BitAnd, Shl}, + usize, +}; + +use num::{FromPrimitive, Integer}; + +use crate::bit_buffer::{BitWriter, Poppable}; + +trait Digits { + const PRECISION: usize; + fn as_byte(&self) -> u8; +} +macro_rules! unsignedImplDigits { + ($($type: ident),*) => { $( + impl Digits for $type { + const PRECISION: usize = (std::mem::size_of::<$type>() * 8); + fn as_byte(&self) -> u8 {*self as u8} + } + )* }; +} +macro_rules! signedImplDigits { + ($($type: ident),*) => { $( + impl Digits for $type { + const PRECISION: usize = (std::mem::size_of::<$type>() * 8) - 1; + fn as_byte(&self) -> u8 {*self as u8} + } + )* }; +} +unsignedImplDigits!(u16, u32, u64, u128); +signedImplDigits!(i16, i32, i64, i128); + +pub trait Metrics: + Integer + FromPrimitive + Copy + BitAnd + Shl +{ + const PRECISION: usize; + + const FREQUENCY_BITS: usize = (Self::PRECISION / 2) - 1; + const CODE_VALUE_BITS: usize = Self::FREQUENCY_BITS + 2; + const MAX_CODE: usize = (1 << Self::CODE_VALUE_BITS) - 1; + const MAX_FREQ: usize = (1 << Self::FREQUENCY_BITS) - 1; + + const ONE_FOURTH: usize = 1 << (Self::CODE_VALUE_BITS - 2); + const ONE_HALF: usize = 2 * Self::ONE_FOURTH; + const THREE_FOURTHS: usize = 3 * Self::ONE_FOURTH; + + fn as_byte(&self) -> u8; + + fn print_metrics() { + println!("--------- Metrics ---------"); + println!(" PRECISION: {}", Self::PRECISION); + println!(" FREQUENCY_BITS: {}", Self::FREQUENCY_BITS); + println!("CODE_VALUE_BITS: {}", Self::CODE_VALUE_BITS); + println!(" MAX_CODE: {}", Self::MAX_CODE); + println!(" MAX_FREQ: {}", Self::MAX_FREQ); + println!(" ONE_FOURTH: {}", Self::ONE_FOURTH); + println!(" ONE_HALF: {}", Self::ONE_HALF); + println!(" THREE_FOURTHS: {}", Self::THREE_FOURTHS); + } +} +impl + Shl> + Metrics for T +{ + const PRECISION: usize = T::PRECISION; + fn as_byte(&self) -> u8 { + self.as_byte() + } +} + +/* +const PRECISION: u32 = 32; + +// 15 bits for frequency count +const FREQUENCY_BITS: u32 = (PRECISION / 2) - 1; +// 17 bits for CODE_VALUE +const VALUE_BITS: u32 = FREQUENCY_BITS + 2; + +const MAX_CODE: u32 = !((!0) << VALUE_BITS); +const MAX_FREQ: u32 = !((!0) << FREQUENCY_BITS); +const HALF: u32 = 1 << (VALUE_BITS - 1); +const LOW_CONVERGE: u32 = 0b10 << (VALUE_BITS - 2); +const HIGH_CONVERGE: u32 = 0b01 << (VALUE_BITS - 2); +*/ + +#[derive(Debug)] +struct Prob { + low: T, + high: T, + total: T, +} + +struct InputBits<'a> { + input: &'a [u8], + current_byte: u32, + last_mask: u32, + code_value_bits: i32, +} + +impl<'a> InputBits<'a> { + pub fn new(data: &'a [u8]) -> Self { + Self { + input: data, + current_byte: 0, + last_mask: 1, + code_value_bits: T::CODE_VALUE_BITS as i32, + } + } + fn get_bit(&mut self) -> bool { + if self.last_mask == 1 { + match self.input.pop() { + None => { + if self.code_value_bits <= 0 { + panic!("IDK Man"); + } else { + self.code_value_bits -= 8; + } + } + Some(byte) => self.current_byte = byte as u32, + } + self.last_mask = 0x80; + } else { + self.last_mask >>= 1; + } + return (self.current_byte & self.last_mask) != 0; + } +} + +//TODO: use unified trait +//trait CodeValue: Metrics + Integer + Into {} + +#[derive(Debug)] +#[allow(non_camel_case_types)] +pub struct ModelA { + cumulative_frequency: [CODE_VALUE; 258], + m_frozen: bool, +} + +impl Default for ModelA { + fn default() -> Self { + let m_frozen = false; + let mut cumulative_frequency = [T::zero(); 258]; + for i in 0..258 { + cumulative_frequency[i] = T::from_usize(i).unwrap(); + } + Self { + cumulative_frequency, + m_frozen, + } + } +} + +#[allow(dead_code)] +#[allow(non_snake_case)] +#[allow(non_camel_case_types)] +impl ModelA { + pub fn print_metrics(&self) { + CODE_VALUE::print_metrics(); + } + fn update(&mut self, c: i32) { + for i in (c as usize + 1)..258 { + self.cumulative_frequency[i] = self.cumulative_frequency[i] + CODE_VALUE::one(); + } + if self.cumulative_frequency[257] >= CODE_VALUE::from_usize(CODE_VALUE::MAX_FREQ).unwrap() { + self.m_frozen = true; + } + } + fn getProbability(&mut self, c: i32) -> Prob { + let p = Prob { + low: self.cumulative_frequency[c as usize], + high: self.cumulative_frequency[c as usize + 1], + total: self.cumulative_frequency[257], + }; + if !self.m_frozen { + self.update(c); + } + return p; + } + fn getChar(&mut self, scaled_value: CODE_VALUE) -> Option<(i32, Prob)> { + for i in 0..258 { + if scaled_value < self.cumulative_frequency[i + 1] { + let p = Prob { + low: self.cumulative_frequency[i], + high: self.cumulative_frequency[i + 1], + total: self.cumulative_frequency[257], + }; + if !self.m_frozen { + self.update(i as i32) + } + return Some((i as i32, p)); + } + } + return None; + } + fn getCount(&self) -> CODE_VALUE { + self.cumulative_frequency[257] + } + pub fn decompress(mut self, input: &[u8]) -> Option> { + let ONE: CODE_VALUE = CODE_VALUE::one(); + let ZERO: CODE_VALUE = CODE_VALUE::zero(); + let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap(); + let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap(); + let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap(); + + let mut input = InputBits::new::(input); + let mut output = vec![]; + + let mut low: CODE_VALUE = ZERO; + let mut high: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap(); + let mut value: CODE_VALUE = ZERO; + + for _ in 0..CODE_VALUE::CODE_VALUE_BITS { + value = (value << CODE_VALUE::one()) + if input.get_bit() { ONE } else { ZERO }; + } + loop { + let range: CODE_VALUE = high - low + ONE; + let scaled_value = ((value - low + ONE) * self.getCount() - ONE) / range; + let (c, p) = self.getChar(scaled_value).unwrap(); + if c > 255 || c < 0 { + break; + } + output.push(value.as_byte()); + high = low + (range * p.high) / p.total - ONE; + low = low + (range * p.low) / p.total; + loop { + if high < ONE_HALF { + } else if low >= ONE_HALF { + value = value - ONE_HALF; + low = low - ONE_HALF; + high = high - ONE_HALF + } else if low >= ONE_FORTH && high < THREE_FOURTHS { + value = value - ONE_FORTH; + low = low - ONE_FORTH; + high = high - ONE_FORTH; + } else { + break; + } + low = low << ONE; + high = (high << ONE) + ONE; + value = (value << ONE) + if input.get_bit() { ONE } else { ZERO }; + } + } + return Some(output); + } + + pub fn compress(mut self, input: &[u8]) -> Vec { + let ONE: CODE_VALUE = CODE_VALUE::one(); + let ZERO: CODE_VALUE = CODE_VALUE::zero(); + let MAX_CODE: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap(); + let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap(); + let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap(); + let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap(); + + let mut output = BitWriter::new(); + + let mut pending_bits: i32 = 0; + let mut low: CODE_VALUE = ZERO; + let mut high: CODE_VALUE = MAX_CODE; + + for mut c in input.iter().map(|b| *b as i32).chain([256_i32]) { + if c > 255 || c < 0 { + c = 256; + } else { + println!("c: '{}'", c as u8 as char); + } + let p = self.getProbability(c); + let range: CODE_VALUE = high - low + ONE; + high = low + (range * p.high / p.total) - ONE; + low = low + (range * p.low / p.total); + + loop { + if high < ONE_HALF { + Self::write_with_pending(false, &mut pending_bits, &mut output); + } else if low >= ONE_HALF { + Self::write_with_pending(true, &mut pending_bits, &mut output); + } else if low >= ONE_FORTH && high < THREE_FOURTHS { + pending_bits += 1; + low = low - ONE_FORTH; + high = high - ONE_FORTH; + } else { + break; + } + high = ((high << ONE) + ONE) & MAX_CODE; + low = (low << ONE) & MAX_CODE; + } + if c == 256 { + break; + } + } + println!("EOF"); + pending_bits += 1; + if low < ONE_FORTH { + Self::write_with_pending(false, &mut pending_bits, &mut output); + } else { + Self::write_with_pending(true, &mut pending_bits, &mut output); + } + + println!(""); + return output.into(); + } + + fn write_with_pending(bit: bool, pending: &mut i32, output: &mut BitWriter) { + print!("{}\n", if bit { "1" } else { "0" }); + output.write(bit); + for _ in 0..*pending { + output.write(!bit); + print!("{}\n", if !bit { "1" } else { "0" }); + } + *pending = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + const UNCOMPRESSED_BYTES: &[u8; 13] = b"hello world-\n"; + /// Compressed bytes taken from output of the c++ version + const COMPRESSED_BYTES: [u8; 14] = [ + 0x67, 0xfc, 0x3e, 0x4a, 0x9d, 0x03, 0x7f, 0x35, 0xf1, 0x08, 0xd8, 0xa6, 0xbc, 0xda, + ]; + + #[test] + fn compression_test() { + let model: ModelA = ModelA::default(); + let enc = model.compress(UNCOMPRESSED_BYTES); + assert_eq!(COMPRESSED_BYTES.len(), enc.len()); + for (a, b) in enc.iter().zip(COMPRESSED_BYTES.iter()) { + assert_eq!(a, b); + } + } + + #[test] + fn decompression_test() { + let model: ModelA = ModelA::default(); + let dec = model.decompress(&COMPRESSED_BYTES).unwrap(); + assert_eq!(UNCOMPRESSED_BYTES.len(), dec.len()); + for (a, b) in dec.iter().zip(UNCOMPRESSED_BYTES.iter()) { + assert_eq!(a, b); + } + } +}