// https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html use std::collections::HashMap; mod bit_buffer; use bit_buffer::{BitReader, BitWriter}; type Model = HashMap; fn get_symbol(model: &Model, low: f64, high: f64) -> Option { for (symbol, (start, end)) in model { if low >= *start && high < *end { return Some(*symbol); } } return None; } fn encode(input: &[u8], model: &Model) -> Vec { let mut output = BitWriter::new(); let mut high = u64::MAX; let mut low = u64::MIN; for symbol in input { let range = high - low; let p = model.get(symbol).expect("Invalid/Unsupported data"); high = low + (range as f64 * p.1) as u64; low = low + (range as f64 * p.0) as u64; loop { if high < 1 << 63 { output.write(false); print!("0"); } else if low >= 1 << (u64::BITS - 1) { output.write(true); print!("1"); } else { break; } low <<= 1; high <<= 1; high |= 1; } } println!(""); return output.flush(); } fn decode(input: &[u8], model: &Model) -> Vec { let mut high = 1.0; let mut low = 0.0; let mut output = vec![]; for bit in BitReader::new(input) { let diff = high - low; if bit { //print!("1"); low = low + (diff / 2.0); } else { high = high - (diff / 2.0); //print!("0"); } if let Some(symbol) = get_symbol(model, low, high) { //println!("\nGot sym: {} from [{}, {})", symbol as char, low, high); output.push(symbol); let (slow, shigh) = model.get(&symbol).unwrap(); let symdiff = *shigh - *slow; high = (high - *slow) / symdiff; low = (low - *slow) / symdiff; } } return output; } fn make_model(probabilities: &[(u8, f64)]) -> Model { let mut model = HashMap::new(); let mut end: f64 = 0.0; for (symbol, probability) in probabilities { let start: f64 = end; end = start + probability; model.insert(*symbol, (start, end)); println!("{}: [{}, {})", *symbol as char, start, end); } return model; } const ENGLISH: &[(u8, f64)] = &[ (b'a', 0.08), (b'b', 0.01), (b'c', 0.02), (b'd', 0.04), (b'e', 0.12), (b'f', 0.02), (b'g', 0.02), (b'h', 0.06), (b'i', 0.07), (b'j', 0.01), (b'k', 0.01), (b'l', 0.04), (b'm', 0.02), (b'n', 0.06), (b'o', 0.07), (b'p', 0.01), (b'q', 0.01), (b'r', 0.06), (b's', 0.06), (b't', 0.09), (b'u', 0.02), (b'v', 0.01), (b'w', 0.02), (b'x', 0.01), (b'y', 0.02), (b'z', 0.01), (b' ', 0.01), (b'-', 0.02), ]; fn main() { let data = b"hello world-"; println!("MODEL:"); let model: Model = make_model(ENGLISH); println!(""); let _enc = encode(data, &model); let _dec = decode(&_enc, &model); println!("{}", String::from_utf8(_dec).unwrap()); println!( "Compression Ratio: {}", data.len() as f64 / _enc.len() as f64 ); }