diff --git a/examples/example.rs b/examples/example.rs index bfdda96..47ed695 100644 --- a/examples/example.rs +++ b/examples/example.rs @@ -1,3 +1,4 @@ +use sac::model::Model; use sac::modelA::ModelA; const DATA: &[u8] = b" diff --git a/src/bit_buffer.rs b/src/bit_buffer.rs index c0eb5e0..0ec0eee 100644 --- a/src/bit_buffer.rs +++ b/src/bit_buffer.rs @@ -95,8 +95,8 @@ impl BitReader { #[cfg(test)] mod tests { use super::*; + use crate::model::Metrics; use crate::modelA::tests::COMPRESSED_BYTES; - use crate::modelA::Metrics; struct InputBits<'a> { input: &'a [u8], diff --git a/src/lib.rs b/src/lib.rs index 2d3e5ef..4626581 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ // https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html +pub mod model; #[allow(non_snake_case)] pub mod modelA; -mod bit_buffer; +pub mod bit_buffer; diff --git a/src/main.rs b/src/main.rs index bd83bbc..ce1afc4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use std::{ path::Path, }; +use sac::model::Model; use sac::modelA::ModelA; enum Mode { diff --git a/src/model.rs b/src/model.rs index c83597d..2a0043d 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,133 +1,207 @@ -use std::collections::HashMap; -pub type Model = HashMap; +use num::{FromPrimitive, Integer}; +use std::{ + io::{self, Read, Write}, + ops::{BitAnd, Shl}, + usize, +}; + use crate::bit_buffer::{BitReader, BitWriter}; -pub fn get_symbol(model: &Model, low: f64, high: f64) -> Option { - for (symbol, (start, end)) in model { - if low >= *start && high < *end { - return Some(*symbol); +trait Precision { + const PRECISION: usize; +} +macro_rules! unsignedImplDigits { + ($($type: ident),*) => { $( + impl Precision for $type { + const PRECISION: usize = (std::mem::size_of::<$type>() * 8); } + )* }; +} +macro_rules! signedImplDigits { + ($($type: ident),*) => { $( + impl Precision for $type { + const PRECISION: usize = (std::mem::size_of::<$type>() * 8) - 1; + } + )* }; +} +unsignedImplDigits!(u32, u64); +signedImplDigits!(i32, i64, i128); + +pub trait Metrics: + Integer + FromPrimitive + Copy + BitAnd + Shl +{ + const PRECISION: usize; + + const FREQUENCY_BITS: usize = (Self::PRECISION / 2) - 1; + const CODE_VALUE_BITS: usize = Self::FREQUENCY_BITS + 2; + const MAX_CODE: usize = if Self::CODE_VALUE_BITS == 64 { + u64::MAX as usize + } else { + (1 << Self::CODE_VALUE_BITS) - 1 + }; + const MAX_FREQ: usize = (1 << Self::FREQUENCY_BITS) - 1; + + const ONE_FOURTH: usize = 1 << (Self::CODE_VALUE_BITS - 2); + const ONE_HALF: usize = 2 * Self::ONE_FOURTH; + const THREE_FOURTHS: usize = 3 * Self::ONE_FOURTH; + + fn print_metrics() { + println!("--------- Metrics ---------"); + println!(" PRECISION: {}", Self::PRECISION); + println!(" FREQUENCY_BITS: {}", Self::FREQUENCY_BITS); + println!("CODE_VALUE_BITS: {}", Self::CODE_VALUE_BITS); + println!(" MAX_CODE: {}", Self::MAX_CODE); + println!(" MAX_FREQ: {}", Self::MAX_FREQ); + println!(" ONE_FOURTH: {}", Self::ONE_FOURTH); + println!(" ONE_HALF: {}", Self::ONE_HALF); + println!(" THREE_FOURTHS: {}", Self::THREE_FOURTHS); } - return None; +} +impl< + T: Precision + Integer + FromPrimitive + Copy + BitAnd + Shl, + > Metrics for T +{ + const PRECISION: usize = T::PRECISION; } -pub fn make_model(probabilities: &[(u8, f64)]) -> Model { - let mut model = HashMap::new(); - let mut end: f64 = 0.0; - for (symbol, probability) in probabilities { - let start: f64 = end; - end = start + probability; - model.insert(*symbol, (start, end)); - println!("{}: [{}, {})", *symbol as char, start, end); - } - return model; +#[derive(Debug)] +pub struct Prob { + pub low: T, + pub high: T, + pub max_code: T, } -pub const ENGLISH: &[(u8, f64)] = &[ - (b'a', 0.08), - (b'b', 0.01), - (b'c', 0.02), - (b'd', 0.04), - (b'e', 0.12), - (b'f', 0.02), - (b'g', 0.02), - (b'h', 0.06), - (b'i', 0.07), - (b'j', 0.01), - (b'k', 0.01), - (b'l', 0.04), - (b'm', 0.02), - (b'n', 0.06), - (b'o', 0.07), - (b'p', 0.01), - (b'q', 0.01), - (b'r', 0.06), - (b's', 0.06), - (b't', 0.09), - (b'u', 0.02), - (b'v', 0.01), - (b'w', 0.02), - (b'x', 0.01), - (b'y', 0.02), - (b'z', 0.01), - (b' ', 0.01), - (b'-', 0.02), -]; -fn encode(input: &[u8], model: &Model) -> Vec { - const HALF: u64 = 1 << (u64::BITS - 1); - const LOW_CONVERGE: u64 = 0b10 << (u64::BITS - 2); - const HIGH_CONVERGE: u64 = 0b01 << (u64::BITS - 2); +pub trait Model { + fn get_probability(&mut self, c: i32) -> Prob; + fn get_char(&mut self, scaled_value: CodeWord) -> Option<(i32, Prob)>; + fn get_max_code(&self) -> CodeWord; - let mut output = BitWriter::new(); + #[allow(non_snake_case)] + fn decompress>>( + mut self, + input: I, + output: &mut O, + ) -> io::Result<()> + where + Self: Sized, + { + let ONE: CodeWord = CodeWord::one(); + let ZERO: CodeWord = CodeWord::zero(); + let ONE_HALF: CodeWord = CodeWord::from_usize(CodeWord::ONE_HALF).unwrap(); + let ONE_FORTH: CodeWord = CodeWord::from_usize(CodeWord::ONE_FOURTH).unwrap(); + let THREE_FOURTHS: CodeWord = CodeWord::from_usize(CodeWord::THREE_FOURTHS).unwrap(); - let mut high = u64::MAX; - let mut low = u64::MIN; - let mut pending_bits = 0; + let mut input: BitReader = input + .into() + .with_repeat_bits(CodeWord::CODE_VALUE_BITS as u16); - for symbol in input { - let range = high - low; - let p = model.get(symbol).expect("Invalid/Unsupported data"); - high = low + (range as f64 * p.1) as u64; - low = low + (range as f64 * p.0) as u64; + let mut low: CodeWord = ZERO; + let mut high: CodeWord = CodeWord::from_usize(CodeWord::MAX_CODE).unwrap(); + let mut value: CodeWord = ZERO; + + for _ in 0..CodeWord::CODE_VALUE_BITS { + value = (value << CodeWord::one()) + if input.get_bit()? { ONE } else { ZERO }; + } loop { - if high < HALF { - output.write(false); - print!("0"); - while pending_bits > 0 { - output.write(true); - print!("1"); - pending_bits -= 1; - } - } else if low >= HALF { - output.write(true); - print!("1"); - while pending_bits > 0 { - output.write(true); - print!("0"); - pending_bits -= 1; - } - } else if low >= LOW_CONVERGE && high < HIGH_CONVERGE { - println!("BET"); - pending_bits += 1; - low <<= 1; - low &= HALF - 1; - high <<= 1; - high &= HALF + 1; - continue; - } else { + let range: CodeWord = high - low + ONE; + let scaled_value = ((value - low + ONE) * self.get_max_code() - ONE) / range; + let (c, p) = self.get_char(scaled_value).unwrap(); + if c > 255 || c < 0 { break; } - low <<= 1; - high <<= 1; - high |= 1; + output.write(&[c as u8])?; + high = low + (range * p.high) / p.max_code - ONE; + low = low + (range * p.low) / p.max_code; + loop { + if high < ONE_HALF { + } else if low >= ONE_HALF { + value = value - ONE_HALF; + low = low - ONE_HALF; + high = high - ONE_HALF + } else if low >= ONE_FORTH && high < THREE_FOURTHS { + value = value - ONE_FORTH; + low = low - ONE_FORTH; + high = high - ONE_FORTH; + } else { + break; + } + low = low << ONE; + high = (high << ONE) + ONE; + value = (value << ONE) + if input.get_bit()? { ONE } else { ZERO }; + } } + return Ok(()); } - println!(""); - return output.flush(); -} -fn decode(input: &[u8], model: &Model) -> Vec { - let mut high = 1.0; - let mut low = 0.0; - let mut output = vec![]; - for bit in BitReader::new(input) { - let diff = high - low; - if bit { - //print!("1"); - low = low + (diff / 2.0); + #[allow(non_snake_case)] + fn compress(mut self, input: IN, output: &mut OUT) -> std::io::Result<()> + where + Self: Sized, + { + let ONE: CodeWord = CodeWord::one(); + let ZERO: CodeWord = CodeWord::zero(); + let MAX_CODE: CodeWord = CodeWord::from_usize(CodeWord::MAX_CODE).unwrap(); + let ONE_HALF: CodeWord = CodeWord::from_usize(CodeWord::ONE_HALF).unwrap(); + let ONE_FORTH: CodeWord = CodeWord::from_usize(CodeWord::ONE_FOURTH).unwrap(); + let THREE_FOURTHS: CodeWord = CodeWord::from_usize(CodeWord::THREE_FOURTHS).unwrap(); + + let mut output: BitWriter = output.into(); + + let mut pending_bits: i32 = 0; + let mut low: CodeWord = ZERO; + let mut high: CodeWord = MAX_CODE; + + let mut iter = input + .bytes() + .map(|r| r.map(|b| b as i32)) + .chain([Ok(256_i32)]); + while let Some(Ok(mut c)) = iter.next() { + if c > 255 || c < 0 { + c = 256; + } + let p = self.get_probability(c); + let range: CodeWord = high - low + ONE; + high = low + (range * p.high / p.max_code) - ONE; + low = low + (range * p.low / p.max_code); + + loop { + if high < ONE_HALF { + write_with_pending(false, &mut pending_bits, &mut output)?; + } else if low >= ONE_HALF { + write_with_pending(true, &mut pending_bits, &mut output)?; + } else if low >= ONE_FORTH && high < THREE_FOURTHS { + pending_bits += 1; + low = low - ONE_FORTH; + high = high - ONE_FORTH; + } else { + break; + } + high = ((high << ONE) + ONE) & MAX_CODE; + low = (low << ONE) & MAX_CODE; + } + if c == 256 { + break; + } + } + pending_bits += 1; + if low < ONE_FORTH { + write_with_pending(false, &mut pending_bits, &mut output)?; } else { - high = high - (diff / 2.0); - //print!("0"); + write_with_pending(true, &mut pending_bits, &mut output)?; } - if let Some(symbol) = get_symbol(model, low, high) { - //println!("\nGot sym: {} from [{}, {})", symbol as char, low, high); - output.push(symbol); - let (slow, shigh) = model.get(&symbol).unwrap(); - let symdiff = *shigh - *slow; - high = (high - *slow) / symdiff; - low = (low - *slow) / symdiff; - } - } - return output; + return output.flush(); + } +} +fn write_with_pending( + bit: bool, + pending: &mut i32, + output: &mut BitWriter, +) -> std::io::Result<()> { + output.write(bit)?; + for _ in 0..*pending { + output.write(!bit)?; + } + *pending = 0; + Ok(()) } diff --git a/src/modelA.rs b/src/modelA.rs index 0a24073..721ab49 100644 --- a/src/modelA.rs +++ b/src/modelA.rs @@ -1,81 +1,7 @@ -use std::{ - fmt::Display, - io::{self, Read, Write}, - ops::{BitAnd, Shl}, - usize, -}; +use crate::model::{Metrics, Model, Prob}; -use num::{FromPrimitive, Integer}; - -use crate::bit_buffer::{BitReader, BitWriter}; - -trait Digits { - const PRECISION: usize; -} -macro_rules! unsignedImplDigits { - ($($type: ident),*) => { $( - impl Digits for $type { - const PRECISION: usize = (std::mem::size_of::<$type>() * 8); - } - )* }; -} -macro_rules! signedImplDigits { - ($($type: ident),*) => { $( - impl Digits for $type { - const PRECISION: usize = (std::mem::size_of::<$type>() * 8) - 1; - } - )* }; -} -unsignedImplDigits!(u32, u64); -signedImplDigits!(i32, i64, i128); - -pub trait Metrics: - Integer + FromPrimitive + Copy + BitAnd + Shl -{ - const PRECISION: usize; - - const FREQUENCY_BITS: usize = (Self::PRECISION / 2) - 1; - const CODE_VALUE_BITS: usize = Self::FREQUENCY_BITS + 2; - const MAX_CODE: usize = if Self::CODE_VALUE_BITS == 64 { - u64::MAX as usize - } else { - (1 << Self::CODE_VALUE_BITS) - 1 - }; - const MAX_FREQ: usize = (1 << Self::FREQUENCY_BITS) - 1; - - const ONE_FOURTH: usize = 1 << (Self::CODE_VALUE_BITS - 2); - const ONE_HALF: usize = 2 * Self::ONE_FOURTH; - const THREE_FOURTHS: usize = 3 * Self::ONE_FOURTH; - - fn print_metrics() { - println!("--------- Metrics ---------"); - println!(" PRECISION: {}", Self::PRECISION); - println!(" FREQUENCY_BITS: {}", Self::FREQUENCY_BITS); - println!("CODE_VALUE_BITS: {}", Self::CODE_VALUE_BITS); - println!(" MAX_CODE: {}", Self::MAX_CODE); - println!(" MAX_FREQ: {}", Self::MAX_FREQ); - println!(" ONE_FOURTH: {}", Self::ONE_FOURTH); - println!(" ONE_HALF: {}", Self::ONE_HALF); - println!(" THREE_FOURTHS: {}", Self::THREE_FOURTHS); - } -} -impl + Shl> - Metrics for T -{ - const PRECISION: usize = T::PRECISION; -} - -#[derive(Debug)] -struct Prob { - low: T, - high: T, - total: T, -} - -#[derive(Debug)] -#[allow(non_camel_case_types)] -pub struct ModelA { - cumulative_frequency: [CODE_VALUE; 258], +pub struct ModelA { + cumulative_frequency: [T; 258], m_frozen: bool, } @@ -93,38 +19,39 @@ impl Default for ModelA { } } -#[allow(non_snake_case)] -#[allow(non_camel_case_types)] -impl ModelA { +impl ModelA { pub fn print_metrics(&self) { - CODE_VALUE::print_metrics(); + T::print_metrics(); } fn update(&mut self, c: i32) { for i in (c as usize + 1)..258 { - self.cumulative_frequency[i] = self.cumulative_frequency[i] + CODE_VALUE::one(); + self.cumulative_frequency[i] = self.cumulative_frequency[i] + T::one(); } - if self.cumulative_frequency[257] >= CODE_VALUE::from_usize(CODE_VALUE::MAX_FREQ).unwrap() { + if self.cumulative_frequency[257] >= T::from_usize(T::MAX_FREQ).unwrap() { self.m_frozen = true; } } - fn getProbability(&mut self, c: i32) -> Prob { +} +impl Model for ModelA { + fn get_probability(&mut self, c: i32) -> crate::model::Prob { let p = Prob { low: self.cumulative_frequency[c as usize], high: self.cumulative_frequency[c as usize + 1], - total: self.cumulative_frequency[257], + max_code: self.cumulative_frequency[257], }; if !self.m_frozen { self.update(c); } return p; } - fn getChar(&mut self, scaled_value: CODE_VALUE) -> Option<(i32, Prob)> { + + fn get_char(&mut self, scaled_value: T) -> Option<(i32, crate::model::Prob)> { for i in 0..258 { if scaled_value < self.cumulative_frequency[i + 1] { let p = Prob { low: self.cumulative_frequency[i], high: self.cumulative_frequency[i + 1], - total: self.cumulative_frequency[257], + max_code: self.cumulative_frequency[257], }; if !self.m_frozen { self.update(i as i32) @@ -134,135 +61,10 @@ impl ModelA { } return None; } - fn getCount(&self) -> CODE_VALUE { + + fn get_max_code(&self) -> T { self.cumulative_frequency[257] } - - pub fn decompress>>( - mut self, - input: I, - output: &mut O, - ) -> io::Result<()> { - let ONE: CODE_VALUE = CODE_VALUE::one(); - let ZERO: CODE_VALUE = CODE_VALUE::zero(); - let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap(); - let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap(); - let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap(); - - let mut input: BitReader = input - .into() - .with_repeat_bits(CODE_VALUE::CODE_VALUE_BITS as u16); - - let mut low: CODE_VALUE = ZERO; - let mut high: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap(); - let mut value: CODE_VALUE = ZERO; - - for _ in 0..CODE_VALUE::CODE_VALUE_BITS { - value = (value << CODE_VALUE::one()) + if input.get_bit()? { ONE } else { ZERO }; - } - loop { - let range: CODE_VALUE = high - low + ONE; - let scaled_value = ((value - low + ONE) * self.getCount() - ONE) / range; - let (c, p) = self.getChar(scaled_value).unwrap(); - if c > 255 || c < 0 { - break; - } - output.write(&[c as u8])?; - high = low + (range * p.high) / p.total - ONE; - low = low + (range * p.low) / p.total; - loop { - if high < ONE_HALF { - } else if low >= ONE_HALF { - value = value - ONE_HALF; - low = low - ONE_HALF; - high = high - ONE_HALF - } else if low >= ONE_FORTH && high < THREE_FOURTHS { - value = value - ONE_FORTH; - low = low - ONE_FORTH; - high = high - ONE_FORTH; - } else { - break; - } - low = low << ONE; - high = (high << ONE) + ONE; - value = (value << ONE) + if input.get_bit()? { ONE } else { ZERO }; - } - } - return Ok(()); - } - - pub fn compress( - mut self, - input: IN, - output: &mut OUT, - ) -> std::io::Result<()> { - let ONE: CODE_VALUE = CODE_VALUE::one(); - let ZERO: CODE_VALUE = CODE_VALUE::zero(); - let MAX_CODE: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap(); - let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap(); - let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap(); - let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap(); - - let mut output: BitWriter = output.into(); - - let mut pending_bits: i32 = 0; - let mut low: CODE_VALUE = ZERO; - let mut high: CODE_VALUE = MAX_CODE; - - let mut iter = input - .bytes() - .map(|r| r.map(|b| b as i32)) - .chain([Ok(256_i32)]); - while let Some(Ok(mut c)) = iter.next() { - if c > 255 || c < 0 { - c = 256; - } - let p = self.getProbability(c); - let range: CODE_VALUE = high - low + ONE; - high = low + (range * p.high / p.total) - ONE; - low = low + (range * p.low / p.total); - - loop { - if high < ONE_HALF { - Self::write_with_pending(false, &mut pending_bits, &mut output)?; - } else if low >= ONE_HALF { - Self::write_with_pending(true, &mut pending_bits, &mut output)?; - } else if low >= ONE_FORTH && high < THREE_FOURTHS { - pending_bits += 1; - low = low - ONE_FORTH; - high = high - ONE_FORTH; - } else { - break; - } - high = ((high << ONE) + ONE) & MAX_CODE; - low = (low << ONE) & MAX_CODE; - } - if c == 256 { - break; - } - } - pending_bits += 1; - if low < ONE_FORTH { - Self::write_with_pending(false, &mut pending_bits, &mut output)?; - } else { - Self::write_with_pending(true, &mut pending_bits, &mut output)?; - } - - return output.flush(); - } - - fn write_with_pending( - bit: bool, - pending: &mut i32, - output: &mut BitWriter, - ) -> std::io::Result<()> { - output.write(bit)?; - for _ in 0..*pending { - output.write(!bit)?; - } - *pending = 0; - Ok(()) - } } #[cfg(test)]