From afa9b5266f785144826b609b5a6d2634648ccb37 Mon Sep 17 00:00:00 2001 From: Moritz Borcherding Date: Tue, 15 Oct 2024 12:23:02 +0200 Subject: [PATCH] encode huffman tables according to spec --- src/encoding/bit_writer.rs | 7 ++++ src/fse/fse_decoder.rs | 3 +- src/fse/fse_encoder.rs | 75 ++++++++++++++++++++++++++++++++++++-- src/fse/mod.rs | 2 +- src/huff0/huff0_decoder.rs | 12 +----- src/huff0/huff0_encoder.rs | 39 +++++++++++++++++++- src/huff0/mod.rs | 5 ++- 7 files changed, 124 insertions(+), 19 deletions(-) diff --git a/src/encoding/bit_writer.rs b/src/encoding/bit_writer.rs index 2419c67..60e87ae 100644 --- a/src/encoding/bit_writer.rs +++ b/src/encoding/bit_writer.rs @@ -36,6 +36,13 @@ impl BitWriter { } } + pub fn append_bytes(&mut self, data: &[u8]) { + if self.misaligned() != 0 { + panic!("Don't append bytes when writer is misaligned") + } + self.output.extend_from_slice(data); + } + pub fn write_bits(&mut self, bits: impl Into, num_bits: usize) { self.write_bits_64(bits.into(), num_bits); } diff --git a/src/fse/fse_decoder.rs b/src/fse/fse_decoder.rs index 26e34e7..45adc93 100644 --- a/src/fse/fse_decoder.rs +++ b/src/fse/fse_decoder.rs @@ -189,7 +189,8 @@ impl<'t> FSEDecoder<'t> { if self.table.accuracy_log == 0 { return Err(FSEDecoderError::TableIsUninitialized); } - self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize]; + let new_state = bits.get_bits(self.table.accuracy_log); + self.state = self.table.decode[new_state as usize]; Ok(()) } diff --git a/src/fse/fse_encoder.rs b/src/fse/fse_encoder.rs index 64897c0..33dc432 100644 --- a/src/fse/fse_encoder.rs +++ b/src/fse/fse_encoder.rs @@ -38,6 +38,66 @@ impl FSEEncoder { writer.dump() } + pub fn encode_interleaved(&mut self, data: &[u8]) -> Vec { + self.write_table(); + + let mut state_1 = &self.table.states[data[data.len() - 1] as usize].states[0]; + let mut state_2 = &self.table.states[data[data.len() - 2] as usize].states[0]; + + let mut idx = data.len() - 4; + loop { + { + let state = state_1; + let x = data[idx + 1]; + let next = self.table.next_state(x, state.index); + let diff = state.index - next.baseline; + self.writer.write_bits(diff as u64, next.num_bits as usize); + state_1 = next; + } + { + let state = state_2; + let x = data[idx]; + let next = self.table.next_state(x, state.index); + let diff = state.index - next.baseline; + self.writer.write_bits(diff as u64, next.num_bits as usize); + state_2 = next; + } + + if idx < 2 { + break; + } + idx -= 2; + } + if idx == 1 { + let state = state_1; + let x = data[0]; + let next = self.table.next_state(x, state.index); + let diff = state.index - next.baseline; + self.writer.write_bits(diff as u64, next.num_bits as usize); + state_1 = next; + + self.writer + .write_bits(state_2.index as u64, self.acc_log() as usize); + self.writer + .write_bits(state_1.index as u64, self.acc_log() as usize); + } else { + self.writer + .write_bits(state_1.index as u64, self.acc_log() as usize); + self.writer + .write_bits(state_2.index as u64, self.acc_log() as usize); + } + + let mut writer = BitWriter::new(); + core::mem::swap(&mut self.writer, &mut writer); + let bits_to_fill = writer.misaligned(); + if bits_to_fill == 0 { + writer.write_bits(1u32, 8); + } else { + writer.write_bits(1u32, bits_to_fill); + } + writer.dump() + } + fn write_table(&mut self) { self.writer.write_bits(self.acc_log() - 5, 4); let mut probability_counter = 0usize; @@ -133,15 +193,15 @@ impl State { } } -pub fn build_table_from_data(data: &[u8]) -> FSETable { +pub fn build_table_from_data(data: &[u8], avoid_0_numbit: bool) -> FSETable { let mut counts = [0; 256]; for x in data { counts[*x as usize] += 1; } - build_table_from_counts(&counts) + build_table_from_counts(&counts, avoid_0_numbit) } -fn build_table_from_counts(counts: &[usize]) -> FSETable { +fn build_table_from_counts(counts: &[usize], avoid_0_numbit: bool) -> FSETable { let mut probs = [0; 256]; let mut min_count = 0; for (idx, count) in counts.iter().copied().enumerate() { @@ -172,6 +232,15 @@ fn build_table_from_counts(counts: &[usize]) -> FSETable { let max = probs.iter_mut().max().unwrap(); *max += diff as i32; + if avoid_0_numbit && *max > 1 << (acc_log - 1) { + let redistribute = *max - (1 << (acc_log - 1)); + *max -= redistribute; + let max = *max; + let second_max = probs.iter_mut().filter(|x| **x != max).max().unwrap(); + *second_max += redistribute; + assert!(*second_max <= max); + } + build_table_from_probabilities(&probs, acc_log) } diff --git a/src/fse/mod.rs b/src/fse/mod.rs index 46531e1..462444c 100644 --- a/src/fse/mod.rs +++ b/src/fse/mod.rs @@ -83,7 +83,7 @@ pub fn round_trip(data: &[u8]) { return; } - let mut encoder: FSEEncoder = FSEEncoder::new(fse_encoder::build_table_from_data(data)); + let mut encoder: FSEEncoder = FSEEncoder::new(fse_encoder::build_table_from_data(data, false)); let mut dec_table = FSETable::new(255); let encoded = encoder.encode(data); diff --git a/src/huff0/huff0_decoder.rs b/src/huff0/huff0_decoder.rs index cd353ce..df035d5 100644 --- a/src/huff0/huff0_decoder.rs +++ b/src/huff0/huff0_decoder.rs @@ -293,7 +293,7 @@ impl HuffmanTable { bits: Vec::with_capacity(256), bit_ranks: Vec::with_capacity(11), rank_indexes: Vec::with_capacity(11), - fse_table: FSETable::new(100), + fse_table: FSETable::new(255), } } @@ -586,14 +586,4 @@ impl HuffmanTable { Ok(()) } - - /// For internal tests construct directly from weights - pub(super) fn from_weights(mut weights: Vec) -> Self { - // Last weight is inferred by build_table_from_weights - weights.pop(); - let mut new = Self::new(); - new.weights = weights; - new.build_table_from_weights().unwrap(); - new - } } diff --git a/src/huff0/huff0_encoder.rs b/src/huff0/huff0_encoder.rs index 66131c4..6ef4ead 100644 --- a/src/huff0/huff0_encoder.rs +++ b/src/huff0/huff0_encoder.rs @@ -1,7 +1,10 @@ use alloc::vec::Vec; use core::cmp::Ordering; -use crate::encoding::bit_writer::BitWriter; +use crate::{ + encoding::bit_writer::BitWriter, + fse::fse_encoder::{self, FSEEncoder}, +}; pub struct HuffmanEncoder { table: HuffmanTable, @@ -16,6 +19,7 @@ impl HuffmanEncoder { } } pub fn encode(&mut self, data: &[u8]) { + self.write_table(); for symbol in data.iter().rev() { let (code, num_bits) = self.table.codes[*symbol as usize]; self.writer.write_bits(code, num_bits as usize); @@ -44,6 +48,39 @@ impl HuffmanEncoder { weights } + + fn write_table(&mut self) { + // TODO strategy for determining this? + let weights = self.weights(); + let weights = &weights[..weights.len() - 1]; // dont encode last weight + + if weights.len() > 16 { + // TODO share output vec between encoders + // TODO assert that no 0 num_bit states are generated here + let mut encoder = FSEEncoder::new(fse_encoder::build_table_from_data(&weights, true)); + let encoded = encoder.encode_interleaved(&weights); + assert!(encoded.len() < 128); + self.writer.write_bits(encoded.len() as u8, 8); + self.writer.append_bytes(&encoded); + } else { + self.writer.write_bits(weights.len() as u8 + 127, 8); + let pairs = weights.chunks_exact(2); + let remainder = pairs.remainder(); + for pair in pairs.into_iter() { + let weight1 = pair[0]; + let weight2 = pair[1]; + assert!(weight1 < 16); + assert!(weight2 < 16); + self.writer.write_bits(weight2, 4); + self.writer.write_bits(weight1, 4); + } + if !remainder.is_empty() { + let weight = remainder[0]; + assert!(weight < 16); + self.writer.write_bits(weight << 4, 8); + } + } + } } pub struct HuffmanTable { diff --git a/src/huff0/mod.rs b/src/huff0/mod.rs index 1a65241..36d1f20 100644 --- a/src/huff0/mod.rs +++ b/src/huff0/mod.rs @@ -22,10 +22,11 @@ pub fn round_trip(data: &[u8]) { encoder.encode(data); let encoded = encoder.dump(); - let decoder_table = HuffmanTable::from_weights(encoder.weights()); + let mut decoder_table = HuffmanTable::new(); + let table_bytes = decoder_table.build_decoder(&encoded).unwrap(); let mut decoder = HuffmanDecoder::new(&decoder_table); - let mut br = BitReaderReversed::new(&encoded); + let mut br = BitReaderReversed::new(&encoded[table_bytes as usize..]); let mut skipped_bits = 0; loop { let val = br.get_bits(1);