Remove fixed precision encoder and decoder

This commit is contained in:
Lucas Schumacher 2024-08-19 22:59:31 -04:00
parent 91f1860ce5
commit 5a3bfa9618

View File

@ -6,16 +6,7 @@ use bit_buffer::{BitReader, BitWriter};
type Model = HashMap<u8, (f64, f64)>; type Model = HashMap<u8, (f64, f64)>;
fn get_symbol(model: &Model, d: f64) -> Option<u8> { fn get_symbol(model: &Model, low: f64, high: f64) -> Option<u8> {
// Brute force
for (symbol, (start, end)) in model {
if d >= *start && d < *end {
return Some(*symbol);
}
}
return None;
}
fn _get_symbol(model: &Model, low: f64, high: f64) -> Option<u8> {
for (symbol, (start, end)) in model { for (symbol, (start, end)) in model {
if low >= *start && high < *end { if low >= *start && high < *end {
return Some(*symbol); return Some(*symbol);
@ -24,18 +15,7 @@ fn _get_symbol(model: &Model, low: f64, high: f64) -> Option<u8> {
return None; return None;
} }
fn encode(data: &[u8], model: &Model) -> f64 { fn encode(input: &[u8], model: &Model) -> Vec<u8> {
let mut high: f64 = 1.0;
let mut low: f64 = 0.0;
for symbol in data {
let p = model.get(symbol).expect("Invalid/Unsupported data");
let range = high - low;
high = low + range * p.1;
low = low + range * p.0;
}
return low + (high - low) / 2.0;
}
fn _encode(input: &[u8], model: &Model) -> Vec<u8> {
let mut output = BitWriter::new(); let mut output = BitWriter::new();
let mut high = u64::MAX; let mut high = u64::MAX;
let mut low = u64::MIN; let mut low = u64::MIN;
@ -63,31 +43,7 @@ fn _encode(input: &[u8], model: &Model) -> Vec<u8> {
return output.flush(); return output.flush();
} }
fn decode(message: f64, model: &Model) { fn decode(input: &[u8], model: &Model) -> Vec<u8> {
let mut high: f64 = 1.0;
let mut low: f64 = 0.0;
loop {
let range = high - low;
let d = (message - low) / range;
let c = match get_symbol(&model, d) {
Some(c) => c,
None => {
println!("");
eprintln!("Decode error: d={d}");
return;
}
};
if c == b'-' {
println!("");
return;
}
print!("{}", c as char);
let p = model.get(&c).expect("Decode error");
high = low + range * p.1;
low = low + range * p.0;
}
}
fn _decode(input: &[u8], model: &Model) -> Vec<u8> {
let mut high = 1.0; let mut high = 1.0;
let mut low = 0.0; let mut low = 0.0;
let mut output = vec![]; let mut output = vec![];
@ -100,7 +56,7 @@ fn _decode(input: &[u8], model: &Model) -> Vec<u8> {
high = high - (diff / 2.0); high = high - (diff / 2.0);
//print!("0"); //print!("0");
} }
if let Some(symbol) = _get_symbol(model, low, high) { if let Some(symbol) = get_symbol(model, low, high) {
//println!("\nGot sym: {} from [{}, {})", symbol as char, low, high); //println!("\nGot sym: {} from [{}, {})", symbol as char, low, high);
output.push(symbol); output.push(symbol);
let (slow, shigh) = model.get(&symbol).unwrap(); let (slow, shigh) = model.get(&symbol).unwrap();
@ -159,13 +115,9 @@ fn main() {
println!("MODEL:"); println!("MODEL:");
let model: Model = make_model(ENGLISH); let model: Model = make_model(ENGLISH);
println!(""); println!("");
let message = encode(data, &model);
println!("{message}");
decode(message, &model);
println!("Compression Ratio: {}\n", data.len() as f64 / 8.0);
let _enc = _encode(data, &model); let _enc = encode(data, &model);
let _dec = _decode(&_enc, &model); let _dec = decode(&_enc, &model);
println!("{}", String::from_utf8(_dec).unwrap()); println!("{}", String::from_utf8(_dec).unwrap());
println!( println!(