// Listing metadata from the original file view: 128 lines, 3.2 KiB, Rust.
// https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
use std::collections::HashMap;
|
|
|
|
mod bit_buffer;
|
|
use bit_buffer::{BitReader, BitWriter};
|
|
|
|
/// Coding table: maps each symbol byte to its `(start, end)` interval, i.e.
/// the cumulative probability range assigned to it by `make_model`.
type Model = HashMap<u8, (f64, f64)>;
|
|
|
|
fn get_symbol(model: &Model, low: f64, high: f64) -> Option<u8> {
|
|
for (symbol, (start, end)) in model {
|
|
if low >= *start && high < *end {
|
|
return Some(*symbol);
|
|
}
|
|
}
|
|
return None;
|
|
}
|
|
|
|
fn encode(input: &[u8], model: &Model) -> Vec<u8> {
|
|
let mut output = BitWriter::new();
|
|
let mut high = u64::MAX;
|
|
let mut low = u64::MIN;
|
|
for symbol in input {
|
|
let range = high - low;
|
|
let p = model.get(symbol).expect("Invalid/Unsupported data");
|
|
high = low + (range as f64 * p.1) as u64;
|
|
low = low + (range as f64 * p.0) as u64;
|
|
loop {
|
|
if high < 1 << 63 {
|
|
output.write(false);
|
|
print!("0");
|
|
} else if low >= 1 << (u64::BITS - 1) {
|
|
output.write(true);
|
|
print!("1");
|
|
} else {
|
|
break;
|
|
}
|
|
low <<= 1;
|
|
high <<= 1;
|
|
high |= 1;
|
|
}
|
|
}
|
|
println!("");
|
|
return output.flush();
|
|
}
|
|
|
|
fn decode(input: &[u8], model: &Model) -> Vec<u8> {
|
|
let mut high = 1.0;
|
|
let mut low = 0.0;
|
|
let mut output = vec![];
|
|
for bit in BitReader::new(input) {
|
|
let diff = high - low;
|
|
if bit {
|
|
//print!("1");
|
|
low = low + (diff / 2.0);
|
|
} else {
|
|
high = high - (diff / 2.0);
|
|
//print!("0");
|
|
}
|
|
if let Some(symbol) = get_symbol(model, low, high) {
|
|
//println!("\nGot sym: {} from [{}, {})", symbol as char, low, high);
|
|
output.push(symbol);
|
|
let (slow, shigh) = model.get(&symbol).unwrap();
|
|
let symdiff = *shigh - *slow;
|
|
high = (high - *slow) / symdiff;
|
|
low = (low - *slow) / symdiff;
|
|
}
|
|
}
|
|
|
|
return output;
|
|
}
|
|
|
|
fn make_model(probabilities: &[(u8, f64)]) -> Model {
|
|
let mut model = HashMap::new();
|
|
let mut end: f64 = 0.0;
|
|
for (symbol, probability) in probabilities {
|
|
let start: f64 = end;
|
|
end = start + probability;
|
|
model.insert(*symbol, (start, end));
|
|
println!("{}: [{}, {})", *symbol as char, start, end);
|
|
}
|
|
return model;
|
|
}
|
|
/// Per-symbol probabilities fed to `make_model` by `main`: the letters
/// `a` through `z`, then space, then `-`. The listed values nominally total
/// 1.00. Presumably rough English letter frequencies, going by the name.
#[rustfmt::skip]
const ENGLISH: &[(u8, f64)] = &[
    (b'a', 0.08), (b'b', 0.01), (b'c', 0.02), (b'd', 0.04),
    (b'e', 0.12), (b'f', 0.02), (b'g', 0.02), (b'h', 0.06),
    (b'i', 0.07), (b'j', 0.01), (b'k', 0.01), (b'l', 0.04),
    (b'm', 0.02), (b'n', 0.06), (b'o', 0.07), (b'p', 0.01),
    (b'q', 0.01), (b'r', 0.06), (b's', 0.06), (b't', 0.09),
    (b'u', 0.02), (b'v', 0.01), (b'w', 0.02), (b'x', 0.01),
    (b'y', 0.02), (b'z', 0.01), (b' ', 0.01), (b'-', 0.02),
];
|
|
fn main() {
|
|
let data = b"hello world-";
|
|
println!("MODEL:");
|
|
let model: Model = make_model(ENGLISH);
|
|
println!("");
|
|
|
|
let _enc = encode(data, &model);
|
|
let _dec = decode(&_enc, &model);
|
|
|
|
println!("{}", String::from_utf8(_dec).unwrap());
|
|
println!(
|
|
"Compression Ratio: {}",
|
|
data.len() as f64 / _enc.len() as f64
|
|
);
|
|
}
|