Compare commits

..

4 Commits

3 changed files with 176 additions and 35 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
/target /target
Cargo.lock

87
src/bit_buffer.rs Normal file
View File

@ -0,0 +1,87 @@
/// Accumulates individual bits into a byte vector, most-significant bit first.
///
/// Bits are packed into a working byte starting at bit 7 and moving down to
/// bit 0; each time the working byte fills up it is appended to the output.
#[derive(Debug)]
pub struct BitWriter {
    /// Bytes that have been completely filled, in the order they were written.
    data: Vec<u8>,
    /// The byte currently being assembled; positions not yet written are zero.
    bits: u8,
    /// Position (7 = MSB, 0 = LSB) that the next written bit will occupy in `bits`.
    nextbit: usize,
}

impl BitWriter {
    /// Creates an empty writer positioned at the most-significant bit of the first byte.
    pub fn new() -> Self {
        BitWriter {
            data: Vec::new(),
            bits: 0,
            nextbit: 7,
        }
    }

    /// Appends a single bit (`true` = 1, `false` = 0) to the stream.
    ///
    /// When the eighth bit of the working byte is written, that byte is moved
    /// into the output and a fresh, zeroed working byte is started.
    pub fn write(&mut self, bit: bool) {
        if bit {
            self.bits |= 1 << self.nextbit;
        }
        if self.nextbit == 0 {
            self.data.push(self.bits);
            self.bits = 0;
            self.nextbit = 7;
        } else {
            self.nextbit -= 1;
        }
    }

    /// Consumes the writer and returns every byte written so far.
    ///
    /// A partially filled final byte is included, with its unwritten
    /// low-order positions left as zero.
    pub fn flush(mut self) -> Vec<u8> {
        // `nextbit` only sits at 7 when no bit is pending in the working byte.
        // Checking `bits != 0` instead would silently drop pending bits that
        // all happen to be zero.
        if self.nextbit != 7 {
            self.data.push(self.bits);
        }
        self.data
    }
}

impl Default for BitWriter {
    /// Equivalent to [`BitWriter::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Front-popping for byte sequences: take the leading byte off and hand it back.
trait Poppable {
    /// Removes the first byte and returns it, or returns `None` when nothing is left.
    fn pop(&mut self) -> Option<u8>;
}

impl Poppable for &[u8] {
    /// Yields the first byte and re-points the slice at the bytes after it;
    /// an empty slice is left untouched and yields `None`.
    fn pop(&mut self) -> Option<u8> {
        let (&head, tail) = self.split_first()?;
        *self = tail;
        Some(head)
    }
}
/// Reads a borrowed byte slice one bit at a time, most-significant bit first.
///
/// Every bit of every byte is yielded, so a stream whose final byte was only
/// partially filled by the writer also yields that byte's padding bits.
#[derive(Debug)]
pub struct BitReader<'a> {
    /// Bytes not yet fully consumed; the first element is the byte being read.
    data: &'a [u8],
    /// Position (7 = MSB, 0 = LSB) within the current byte of the next bit to return.
    nextbit: usize,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the most-significant bit of `data[0]`.
    pub fn new(data: &'a [u8]) -> Self {
        BitReader { data, nextbit: 7 }
    }

    /// Returns the next bit, or `None` once every byte has been consumed.
    pub fn pop(&mut self) -> Option<bool> {
        let (&current, rest) = self.data.split_first()?;
        let bit = (current >> self.nextbit) & 1 != 0;
        if self.nextbit == 0 {
            // Lowest bit of this byte was just read: advance to the next byte.
            self.data = rest;
            self.nextbit = 7;
        } else {
            self.nextbit -= 1;
        }
        Some(bit)
    }
}

impl<'a> From<&'a [u8]> for BitReader<'a> {
    /// Equivalent to [`BitReader::new`].
    fn from(value: &'a [u8]) -> Self {
        BitReader::new(value)
    }
}

impl Iterator for BitReader<'_> {
    type Item = bool;

    /// Yields bits in the same order as [`BitReader::pop`].
    fn next(&mut self) -> Option<Self::Item> {
        self.pop()
    }
}

View File

@ -1,53 +1,97 @@
// https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html // https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
use std::collections::HashMap; use std::collections::HashMap;
mod bit_buffer;
use bit_buffer::{BitReader, BitWriter};
type Model = HashMap<u8, (f64, f64)>; type Model = HashMap<u8, (f64, f64)>;
fn get_symbol(model: &Model, d: f64) -> Option<u8> { fn get_symbol(model: &Model, low: f64, high: f64) -> Option<u8> {
// Brute force
for (symbol, (start, end)) in model { for (symbol, (start, end)) in model {
if d >= *start && d < *end { if low >= *start && high < *end {
return Some(*symbol); return Some(*symbol);
} }
} }
return None; return None;
} }
fn encode(data: &[u8], model: &Model) -> f64 { fn encode(input: &[u8], model: &Model) -> Vec<u8> {
let mut high: f64 = 1.0; const HALF: u64 = 1 << (u64::BITS - 1);
let mut low: f64 = 0.0; const LOW_CONVERGE: u64 = 0b10 << (u64::BITS - 2);
for symbol in data { const HIGH_CONVERGE: u64 = 0b01 << (u64::BITS - 2);
let p = model.get(symbol).expect("Invalid/Unsupported data");
let mut output = BitWriter::new();
let mut high = u64::MAX;
let mut low = u64::MIN;
let mut pending_bits = 0;
for symbol in input {
let range = high - low; let range = high - low;
high = low + range * p.1; let p = model.get(symbol).expect("Invalid/Unsupported data");
low = low + range * p.0; high = low + (range as f64 * p.1) as u64;
low = low + (range as f64 * p.0) as u64;
loop {
if high < HALF {
output.write(false);
print!("0");
while pending_bits > 0 {
output.write(true);
print!("1");
pending_bits -= 1;
}
} else if low >= HALF {
output.write(true);
print!("1");
while pending_bits > 0 {
output.write(true);
print!("0");
pending_bits -= 1;
}
} else if low >= LOW_CONVERGE && high < HIGH_CONVERGE {
println!("BET");
pending_bits += 1;
low <<= 1;
low &= HALF - 1;
high <<= 1;
high &= HALF + 1;
continue;
} else {
break;
}
low <<= 1;
high <<= 1;
high |= 1;
}
} }
return low + (high - low) / 2.0; println!("");
return output.flush();
} }
fn decode(message: f64, model: &Model) { fn decode(input: &[u8], model: &Model) -> Vec<u8> {
let mut high: f64 = 1.0; let mut high = 1.0;
let mut low: f64 = 0.0; let mut low = 0.0;
loop { let mut output = vec![];
let range = high - low; for bit in BitReader::new(input) {
let d = (message - low) / range; let diff = high - low;
let c = match get_symbol(&model, d) { if bit {
Some(c) => c, //print!("1");
None => { low = low + (diff / 2.0);
println!(""); } else {
eprintln!("Decode error: d={d}"); high = high - (diff / 2.0);
return; //print!("0");
} }
}; if let Some(symbol) = get_symbol(model, low, high) {
if c == b'-' { //println!("\nGot sym: {} from [{}, {})", symbol as char, low, high);
println!(""); output.push(symbol);
return; let (slow, shigh) = model.get(&symbol).unwrap();
let symdiff = *shigh - *slow;
high = (high - *slow) / symdiff;
low = (low - *slow) / symdiff;
} }
print!("{}", c as char);
let p = model.get(&c).expect("Decode error");
high = low + range * p.1;
low = low + range * p.0;
} }
return output;
} }
fn make_model(probabilities: &[(u8, f64)]) -> Model { fn make_model(probabilities: &[(u8, f64)]) -> Model {
@ -92,8 +136,17 @@ const ENGLISH: &[(u8, f64)] = &[
(b'-', 0.02), (b'-', 0.02),
]; ];
fn main() { fn main() {
let data = b"hello world-";
println!("MODEL:");
let model: Model = make_model(ENGLISH); let model: Model = make_model(ENGLISH);
let message = encode(b"hello world-", &model); println!("");
println!("{message}");
decode(message, &model); let _enc = encode(data, &model);
let _dec = decode(&_enc, &model);
println!("{}", String::from_utf8(_dec).unwrap());
println!(
"Compression Ratio: {}",
data.len() as f64 / _enc.len() as f64
);
} }