Clean up main and switch to modelA

This commit is contained in:
Lucas Schumacher 2024-09-08 06:28:49 -04:00
parent 5eac451458
commit b68ef65f31
2 changed files with 115 additions and 90 deletions

View File

@ -1,101 +1,46 @@
// https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html // https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
use std::collections::HashMap; #[allow(non_snake_case)]
mod modelA;
mod bit_buffer; mod bit_buffer;
use bit_buffer::{BitReader, BitWriter}; //mod model;
mod model; use modelA::ModelA;
use model::{get_symbol, make_model, Model, ENGLISH};
fn encode(input: &[u8], model: &Model) -> Vec<u8> {
const HALF: u64 = 1 << (u64::BITS - 1);
const LOW_CONVERGE: u64 = 0b10 << (u64::BITS - 2);
const HIGH_CONVERGE: u64 = 0b01 << (u64::BITS - 2);
let mut output = BitWriter::new();
let mut high = u64::MAX;
let mut low = u64::MIN;
let mut pending_bits = 0;
for symbol in input {
let range = high - low;
let p = model.get(symbol).expect("Invalid/Unsupported data");
high = low + (range as f64 * p.1) as u64;
low = low + (range as f64 * p.0) as u64;
loop {
if high < HALF {
output.write(false);
print!("0");
while pending_bits > 0 {
output.write(true);
print!("1");
pending_bits -= 1;
}
} else if low >= HALF {
output.write(true);
print!("1");
while pending_bits > 0 {
output.write(true);
print!("0");
pending_bits -= 1;
}
} else if low >= LOW_CONVERGE && high < HIGH_CONVERGE {
println!("BET");
pending_bits += 1;
low <<= 1;
low &= HALF - 1;
high <<= 1;
high &= HALF + 1;
continue;
} else {
break;
}
low <<= 1;
high <<= 1;
high |= 1;
}
}
println!("");
return output.flush();
}
fn decode(input: &[u8], model: &Model) -> Vec<u8> {
let mut high = 1.0;
let mut low = 0.0;
let mut output = vec![];
for bit in BitReader::new(input) {
let diff = high - low;
if bit {
//print!("1");
low = low + (diff / 2.0);
} else {
high = high - (diff / 2.0);
//print!("0");
}
if let Some(symbol) = get_symbol(model, low, high) {
//println!("\nGot sym: {} from [{}, {})", symbol as char, low, high);
output.push(symbol);
let (slow, shigh) = model.get(&symbol).unwrap();
let symdiff = *shigh - *slow;
high = (high - *slow) / symdiff;
low = (low - *slow) / symdiff;
}
}
return output;
}
fn main() { fn main() {
let data = b"hello world-"; let data = b"
println!("MODEL:"); I'd just like to interject for a moment. What you're refering to as Linux, is in fact, GNU/Linux, or as I've re
let model: Model = make_model(ENGLISH); aken to calling it, GNU plus Linux. Linux is not an operating system unto itself, but rather another free compo
a fully functioning GNU system made useful by the GNU corelibs, shell utilities and vital system components com
a full OS as defined by POSIX.
Many computer users run a modified version of the GNU system every day, without realizing it. Through a peculia
f events, the version of GNU which is widely used today is often called Linux, and many of its users are not aw
it is basically the GNU system, developed by the GNU Project.
There really is a Linux, and these people are using it, but it is just a part of the system they use. Linux is
el: the program in the system that allocates the machine's resources to the other programs that you run. The ke
an essential part of an operating system, but useless by itself; it can only function in the context of a compl
ating system. Linux is normally used in combination with the GNU operating system: the whole system is basicall
th Linux added, or GNU/Linux. All the so-called Linux distributions are really distributions of GNU/Linux!
";
type CodeValue = u32;
println!("compressing...");
let model: ModelA<CodeValue> = ModelA::default();
model.print_metrics();
println!(""); println!("");
let _enc = encode(data, &model); let enc = model.compress(data);
let _dec = decode(&_enc, &model); //println!("{}", enc.len());
println!("ModelA compressed to {} bytes", enc.len());
println!("{}", String::from_utf8(_dec).unwrap());
println!( println!(
"Compression Ratio: {}", "Compression Ratio: {}",
data.len() as f64 / _enc.len() as f64 data.len() as f64 / enc.len() as f64
); );
//println!("--------- Compressed data ---------\n{}", dump_hex(&enc));
println!("");
println!("decompressing...");
let model: ModelA<CodeValue> = ModelA::default();
let dec = model.decompress(&enc).unwrap();
println!("{}", String::from_utf8_lossy(&dec));
} }

View File

@ -1,5 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
pub type Model = HashMap<u8, (f64, f64)>; pub type Model = HashMap<u8, (f64, f64)>;
use crate::bit_buffer::{BitReader, BitWriter};
pub fn get_symbol(model: &Model, low: f64, high: f64) -> Option<u8> { pub fn get_symbol(model: &Model, low: f64, high: f64) -> Option<u8> {
for (symbol, (start, end)) in model { for (symbol, (start, end)) in model {
@ -51,3 +52,82 @@ pub const ENGLISH: &[(u8, f64)] = &[
(b' ', 0.01), (b' ', 0.01),
(b'-', 0.02), (b'-', 0.02),
]; ];
fn encode(input: &[u8], model: &Model) -> Vec<u8> {
const HALF: u64 = 1 << (u64::BITS - 1);
const LOW_CONVERGE: u64 = 0b10 << (u64::BITS - 2);
const HIGH_CONVERGE: u64 = 0b01 << (u64::BITS - 2);
let mut output = BitWriter::new();
let mut high = u64::MAX;
let mut low = u64::MIN;
let mut pending_bits = 0;
for symbol in input {
let range = high - low;
let p = model.get(symbol).expect("Invalid/Unsupported data");
high = low + (range as f64 * p.1) as u64;
low = low + (range as f64 * p.0) as u64;
loop {
if high < HALF {
output.write(false);
print!("0");
while pending_bits > 0 {
output.write(true);
print!("1");
pending_bits -= 1;
}
} else if low >= HALF {
output.write(true);
print!("1");
while pending_bits > 0 {
output.write(true);
print!("0");
pending_bits -= 1;
}
} else if low >= LOW_CONVERGE && high < HIGH_CONVERGE {
println!("BET");
pending_bits += 1;
low <<= 1;
low &= HALF - 1;
high <<= 1;
high &= HALF + 1;
continue;
} else {
break;
}
low <<= 1;
high <<= 1;
high |= 1;
}
}
println!("");
return output.flush();
}
fn decode(input: &[u8], model: &Model) -> Vec<u8> {
let mut high = 1.0;
let mut low = 0.0;
let mut output = vec![];
for bit in BitReader::new(input) {
let diff = high - low;
if bit {
//print!("1");
low = low + (diff / 2.0);
} else {
high = high - (diff / 2.0);
//print!("0");
}
if let Some(symbol) = get_symbol(model, low, high) {
//println!("\nGot sym: {} from [{}, {})", symbol as char, low, high);
output.push(symbol);
let (slow, shigh) = model.get(&symbol).unwrap();
let symdiff = *shigh - *slow;
high = (high - *slow) / symdiff;
low = (low - *slow) / symdiff;
}
}
return output;
}