diff --git a/src/main.rs b/src/main.rs index bbc5db0..4bf799b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,21 +16,46 @@ fn get_symbol(model: &Model, low: f64, high: f64) -> Option { } fn encode(input: &[u8], model: &Model) -> Vec { + const HALF: u64 = 1 << (u64::BITS - 1); + const LOW_CONVERGE: u64 = 0b10 << (u64::BITS - 2); + const HIGH_CONVERGE: u64 = 0b01 << (u64::BITS - 2); + let mut output = BitWriter::new(); + let mut high = u64::MAX; let mut low = u64::MIN; + let mut pending_bits = 0; + for symbol in input { let range = high - low; let p = model.get(symbol).expect("Invalid/Unsupported data"); high = low + (range as f64 * p.1) as u64; low = low + (range as f64 * p.0) as u64; loop { - if high < 1 << 63 { + if high < HALF { output.write(false); print!("0"); - } else if low >= 1 << (u64::BITS - 1) { + while pending_bits > 0 { + output.write(true); + print!("1"); + pending_bits -= 1; + } + } else if low >= HALF { output.write(true); print!("1"); + while pending_bits > 0 { + output.write(true); + print!("0"); + pending_bits -= 1; + } + } else if low >= LOW_CONVERGE && high < HIGH_CONVERGE { + println!("BET"); + pending_bits += 1; + low <<= 1; + low &= HALF - 1; + high <<= 1; + high &= HALF + 1; + continue; } else { break; }