Add working encoder adaptation of modelA
This commit is contained in:
parent
8bfe71a1af
commit
8ea134e008
342
src/modelA.rs
Normal file
342
src/modelA.rs
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
use core::panic;
|
||||||
|
use std::{
|
||||||
|
fmt::Display,
|
||||||
|
ops::{BitAnd, Shl},
|
||||||
|
usize,
|
||||||
|
};
|
||||||
|
|
||||||
|
use num::{FromPrimitive, Integer};
|
||||||
|
|
||||||
|
use crate::bit_buffer::{BitWriter, Poppable};
|
||||||
|
|
||||||
|
trait Digits {
|
||||||
|
const PRECISION: usize;
|
||||||
|
fn as_byte(&self) -> u8;
|
||||||
|
}
|
||||||
|
macro_rules! unsignedImplDigits {
|
||||||
|
($($type: ident),*) => { $(
|
||||||
|
impl Digits for $type {
|
||||||
|
const PRECISION: usize = (std::mem::size_of::<$type>() * 8);
|
||||||
|
fn as_byte(&self) -> u8 {*self as u8}
|
||||||
|
}
|
||||||
|
)* };
|
||||||
|
}
|
||||||
|
macro_rules! signedImplDigits {
|
||||||
|
($($type: ident),*) => { $(
|
||||||
|
impl Digits for $type {
|
||||||
|
const PRECISION: usize = (std::mem::size_of::<$type>() * 8) - 1;
|
||||||
|
fn as_byte(&self) -> u8 {*self as u8}
|
||||||
|
}
|
||||||
|
)* };
|
||||||
|
}
|
||||||
|
unsignedImplDigits!(u16, u32, u64, u128);
|
||||||
|
signedImplDigits!(i16, i32, i64, i128);
|
||||||
|
|
||||||
|
pub trait Metrics:
|
||||||
|
Integer + FromPrimitive + Copy + BitAnd<Output = Self> + Shl<Output = Self>
|
||||||
|
{
|
||||||
|
const PRECISION: usize;
|
||||||
|
|
||||||
|
const FREQUENCY_BITS: usize = (Self::PRECISION / 2) - 1;
|
||||||
|
const CODE_VALUE_BITS: usize = Self::FREQUENCY_BITS + 2;
|
||||||
|
const MAX_CODE: usize = (1 << Self::CODE_VALUE_BITS) - 1;
|
||||||
|
const MAX_FREQ: usize = (1 << Self::FREQUENCY_BITS) - 1;
|
||||||
|
|
||||||
|
const ONE_FOURTH: usize = 1 << (Self::CODE_VALUE_BITS - 2);
|
||||||
|
const ONE_HALF: usize = 2 * Self::ONE_FOURTH;
|
||||||
|
const THREE_FOURTHS: usize = 3 * Self::ONE_FOURTH;
|
||||||
|
|
||||||
|
fn as_byte(&self) -> u8;
|
||||||
|
|
||||||
|
fn print_metrics() {
|
||||||
|
println!("--------- Metrics ---------");
|
||||||
|
println!(" PRECISION: {}", Self::PRECISION);
|
||||||
|
println!(" FREQUENCY_BITS: {}", Self::FREQUENCY_BITS);
|
||||||
|
println!("CODE_VALUE_BITS: {}", Self::CODE_VALUE_BITS);
|
||||||
|
println!(" MAX_CODE: {}", Self::MAX_CODE);
|
||||||
|
println!(" MAX_FREQ: {}", Self::MAX_FREQ);
|
||||||
|
println!(" ONE_FOURTH: {}", Self::ONE_FOURTH);
|
||||||
|
println!(" ONE_HALF: {}", Self::ONE_HALF);
|
||||||
|
println!(" THREE_FOURTHS: {}", Self::THREE_FOURTHS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<T: Digits + Integer + FromPrimitive + Copy + BitAnd<Output = Self> + Shl<Output = Self>>
|
||||||
|
Metrics for T
|
||||||
|
{
|
||||||
|
const PRECISION: usize = T::PRECISION;
|
||||||
|
fn as_byte(&self) -> u8 {
|
||||||
|
self.as_byte()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
const PRECISION: u32 = 32;
|
||||||
|
|
||||||
|
// 15 bits for frequency count
|
||||||
|
const FREQUENCY_BITS: u32 = (PRECISION / 2) - 1;
|
||||||
|
// 17 bits for CODE_VALUE
|
||||||
|
const VALUE_BITS: u32 = FREQUENCY_BITS + 2;
|
||||||
|
|
||||||
|
const MAX_CODE: u32 = !((!0) << VALUE_BITS);
|
||||||
|
const MAX_FREQ: u32 = !((!0) << FREQUENCY_BITS);
|
||||||
|
const HALF: u32 = 1 << (VALUE_BITS - 1);
|
||||||
|
const LOW_CONVERGE: u32 = 0b10 << (VALUE_BITS - 2);
|
||||||
|
const HIGH_CONVERGE: u32 = 0b01 << (VALUE_BITS - 2);
|
||||||
|
*/
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Prob<T> {
|
||||||
|
low: T,
|
||||||
|
high: T,
|
||||||
|
total: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct InputBits<'a> {
|
||||||
|
input: &'a [u8],
|
||||||
|
current_byte: u32,
|
||||||
|
last_mask: u32,
|
||||||
|
code_value_bits: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> InputBits<'a> {
|
||||||
|
pub fn new<T: Metrics>(data: &'a [u8]) -> Self {
|
||||||
|
Self {
|
||||||
|
input: data,
|
||||||
|
current_byte: 0,
|
||||||
|
last_mask: 1,
|
||||||
|
code_value_bits: T::CODE_VALUE_BITS as i32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn get_bit(&mut self) -> bool {
|
||||||
|
if self.last_mask == 1 {
|
||||||
|
match self.input.pop() {
|
||||||
|
None => {
|
||||||
|
if self.code_value_bits <= 0 {
|
||||||
|
panic!("IDK Man");
|
||||||
|
} else {
|
||||||
|
self.code_value_bits -= 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(byte) => self.current_byte = byte as u32,
|
||||||
|
}
|
||||||
|
self.last_mask = 0x80;
|
||||||
|
} else {
|
||||||
|
self.last_mask >>= 1;
|
||||||
|
}
|
||||||
|
return (self.current_byte & self.last_mask) != 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: use unified trait
|
||||||
|
//trait CodeValue: Metrics + Integer + Into<u8> {}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(non_camel_case_types)]
|
||||||
|
pub struct ModelA<CODE_VALUE> {
|
||||||
|
cumulative_frequency: [CODE_VALUE; 258],
|
||||||
|
m_frozen: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Metrics> Default for ModelA<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let m_frozen = false;
|
||||||
|
let mut cumulative_frequency = [T::zero(); 258];
|
||||||
|
for i in 0..258 {
|
||||||
|
cumulative_frequency[i] = T::from_usize(i).unwrap();
|
||||||
|
}
|
||||||
|
Self {
|
||||||
|
cumulative_frequency,
|
||||||
|
m_frozen,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
#[allow(non_camel_case_types)]
|
||||||
|
impl<CODE_VALUE: Metrics + Display> ModelA<CODE_VALUE> {
|
||||||
|
pub fn print_metrics(&self) {
|
||||||
|
CODE_VALUE::print_metrics();
|
||||||
|
}
|
||||||
|
fn update(&mut self, c: i32) {
|
||||||
|
for i in (c as usize + 1)..258 {
|
||||||
|
self.cumulative_frequency[i] = self.cumulative_frequency[i] + CODE_VALUE::one();
|
||||||
|
}
|
||||||
|
if self.cumulative_frequency[257] >= CODE_VALUE::from_usize(CODE_VALUE::MAX_FREQ).unwrap() {
|
||||||
|
self.m_frozen = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn getProbability(&mut self, c: i32) -> Prob<CODE_VALUE> {
|
||||||
|
let p = Prob {
|
||||||
|
low: self.cumulative_frequency[c as usize],
|
||||||
|
high: self.cumulative_frequency[c as usize + 1],
|
||||||
|
total: self.cumulative_frequency[257],
|
||||||
|
};
|
||||||
|
if !self.m_frozen {
|
||||||
|
self.update(c);
|
||||||
|
}
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
fn getChar(&mut self, scaled_value: CODE_VALUE) -> Option<(i32, Prob<CODE_VALUE>)> {
|
||||||
|
for i in 0..258 {
|
||||||
|
if scaled_value < self.cumulative_frequency[i + 1] {
|
||||||
|
let p = Prob {
|
||||||
|
low: self.cumulative_frequency[i],
|
||||||
|
high: self.cumulative_frequency[i + 1],
|
||||||
|
total: self.cumulative_frequency[257],
|
||||||
|
};
|
||||||
|
if !self.m_frozen {
|
||||||
|
self.update(i as i32)
|
||||||
|
}
|
||||||
|
return Some((i as i32, p));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
fn getCount(&self) -> CODE_VALUE {
|
||||||
|
self.cumulative_frequency[257]
|
||||||
|
}
|
||||||
|
pub fn decompress(mut self, input: &[u8]) -> Option<Vec<u8>> {
|
||||||
|
let ONE: CODE_VALUE = CODE_VALUE::one();
|
||||||
|
let ZERO: CODE_VALUE = CODE_VALUE::zero();
|
||||||
|
let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap();
|
||||||
|
let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap();
|
||||||
|
let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap();
|
||||||
|
|
||||||
|
let mut input = InputBits::new::<CODE_VALUE>(input);
|
||||||
|
let mut output = vec![];
|
||||||
|
|
||||||
|
let mut low: CODE_VALUE = ZERO;
|
||||||
|
let mut high: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap();
|
||||||
|
let mut value: CODE_VALUE = ZERO;
|
||||||
|
|
||||||
|
for _ in 0..CODE_VALUE::CODE_VALUE_BITS {
|
||||||
|
value = (value << CODE_VALUE::one()) + if input.get_bit() { ONE } else { ZERO };
|
||||||
|
}
|
||||||
|
loop {
|
||||||
|
let range: CODE_VALUE = high - low + ONE;
|
||||||
|
let scaled_value = ((value - low + ONE) * self.getCount() - ONE) / range;
|
||||||
|
let (c, p) = self.getChar(scaled_value).unwrap();
|
||||||
|
if c > 255 || c < 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
output.push(value.as_byte());
|
||||||
|
high = low + (range * p.high) / p.total - ONE;
|
||||||
|
low = low + (range * p.low) / p.total;
|
||||||
|
loop {
|
||||||
|
if high < ONE_HALF {
|
||||||
|
} else if low >= ONE_HALF {
|
||||||
|
value = value - ONE_HALF;
|
||||||
|
low = low - ONE_HALF;
|
||||||
|
high = high - ONE_HALF
|
||||||
|
} else if low >= ONE_FORTH && high < THREE_FOURTHS {
|
||||||
|
value = value - ONE_FORTH;
|
||||||
|
low = low - ONE_FORTH;
|
||||||
|
high = high - ONE_FORTH;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
low = low << ONE;
|
||||||
|
high = (high << ONE) + ONE;
|
||||||
|
value = (value << ONE) + if input.get_bit() { ONE } else { ZERO };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Some(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn compress(mut self, input: &[u8]) -> Vec<u8> {
|
||||||
|
let ONE: CODE_VALUE = CODE_VALUE::one();
|
||||||
|
let ZERO: CODE_VALUE = CODE_VALUE::zero();
|
||||||
|
let MAX_CODE: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::MAX_CODE).unwrap();
|
||||||
|
let ONE_HALF: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_HALF).unwrap();
|
||||||
|
let ONE_FORTH: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::ONE_FOURTH).unwrap();
|
||||||
|
let THREE_FOURTHS: CODE_VALUE = CODE_VALUE::from_usize(CODE_VALUE::THREE_FOURTHS).unwrap();
|
||||||
|
|
||||||
|
let mut output = BitWriter::new();
|
||||||
|
|
||||||
|
let mut pending_bits: i32 = 0;
|
||||||
|
let mut low: CODE_VALUE = ZERO;
|
||||||
|
let mut high: CODE_VALUE = MAX_CODE;
|
||||||
|
|
||||||
|
for mut c in input.iter().map(|b| *b as i32).chain([256_i32]) {
|
||||||
|
if c > 255 || c < 0 {
|
||||||
|
c = 256;
|
||||||
|
} else {
|
||||||
|
println!("c: '{}'", c as u8 as char);
|
||||||
|
}
|
||||||
|
let p = self.getProbability(c);
|
||||||
|
let range: CODE_VALUE = high - low + ONE;
|
||||||
|
high = low + (range * p.high / p.total) - ONE;
|
||||||
|
low = low + (range * p.low / p.total);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if high < ONE_HALF {
|
||||||
|
Self::write_with_pending(false, &mut pending_bits, &mut output);
|
||||||
|
} else if low >= ONE_HALF {
|
||||||
|
Self::write_with_pending(true, &mut pending_bits, &mut output);
|
||||||
|
} else if low >= ONE_FORTH && high < THREE_FOURTHS {
|
||||||
|
pending_bits += 1;
|
||||||
|
low = low - ONE_FORTH;
|
||||||
|
high = high - ONE_FORTH;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
high = ((high << ONE) + ONE) & MAX_CODE;
|
||||||
|
low = (low << ONE) & MAX_CODE;
|
||||||
|
}
|
||||||
|
if c == 256 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("EOF");
|
||||||
|
pending_bits += 1;
|
||||||
|
if low < ONE_FORTH {
|
||||||
|
Self::write_with_pending(false, &mut pending_bits, &mut output);
|
||||||
|
} else {
|
||||||
|
Self::write_with_pending(true, &mut pending_bits, &mut output);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("");
|
||||||
|
return output.into();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_with_pending(bit: bool, pending: &mut i32, output: &mut BitWriter) {
|
||||||
|
print!("{}\n", if bit { "1" } else { "0" });
|
||||||
|
output.write(bit);
|
||||||
|
for _ in 0..*pending {
|
||||||
|
output.write(!bit);
|
||||||
|
print!("{}\n", if !bit { "1" } else { "0" });
|
||||||
|
}
|
||||||
|
*pending = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
const UNCOMPRESSED_BYTES: &[u8; 13] = b"hello world-\n";
|
||||||
|
/// Compressed bytes taken from output of the c++ version
|
||||||
|
const COMPRESSED_BYTES: [u8; 14] = [
|
||||||
|
0x67, 0xfc, 0x3e, 0x4a, 0x9d, 0x03, 0x7f, 0x35, 0xf1, 0x08, 0xd8, 0xa6, 0xbc, 0xda,
|
||||||
|
];
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compression_test() {
|
||||||
|
let model: ModelA<i32> = ModelA::default();
|
||||||
|
let enc = model.compress(UNCOMPRESSED_BYTES);
|
||||||
|
assert_eq!(COMPRESSED_BYTES.len(), enc.len());
|
||||||
|
for (a, b) in enc.iter().zip(COMPRESSED_BYTES.iter()) {
|
||||||
|
assert_eq!(a, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decompression_test() {
|
||||||
|
let model: ModelA<i32> = ModelA::default();
|
||||||
|
let dec = model.decompress(&COMPRESSED_BYTES).unwrap();
|
||||||
|
assert_eq!(UNCOMPRESSED_BYTES.len(), dec.len());
|
||||||
|
for (a, b) in dec.iter().zip(UNCOMPRESSED_BYTES.iter()) {
|
||||||
|
assert_eq!(a, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user