Skip to content

Commit

Permalink
table generation now matches
Browse files Browse the repository at this point in the history
  • Loading branch information
KillingSpark committed Oct 13, 2024
1 parent 262a7be commit 73f7797
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 49 deletions.
105 changes: 63 additions & 42 deletions src/fse/fse_encoder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::encoding::bit_writer::BitWriter;
use core::{iter::from_fn, u8};
use std::vec::{self, Vec};
use std::{
dbg,
vec::{self, Vec},
};

pub struct FSEEncoder {
table: FSETable,
Expand Down Expand Up @@ -33,7 +36,7 @@ impl FSEEncoder {
#[derive(Debug)]
pub struct FSETable {
/// Indexed by symbol
states: [SymbolStates; 256],
pub(super) states: [SymbolStates; 256],
table_size: usize,
}

Expand All @@ -45,9 +48,9 @@ impl FSETable {
}

#[derive(Debug)]
struct SymbolStates {
pub(super) struct SymbolStates {
/// Sorted by baseline
states: Vec<State>,
pub(super) states: Vec<State>,
}

impl SymbolStates {
Expand All @@ -61,12 +64,12 @@ impl SymbolStates {
}

#[derive(Debug)]
struct State {
num_bits: u8,
baseline: usize,
last_index: usize,
pub(super) struct State {
pub(super) num_bits: u8,
pub(super) baseline: usize,
pub(super) last_index: usize,
/// Index of this state in the decoding table
index: usize,
pub(super) index: usize,
}

impl State {
Expand Down Expand Up @@ -121,59 +124,78 @@ pub(super) fn build_table_from_probabilities(probs: &[i32], acc_log: u8) -> FSET
let mut states =
core::array::from_fn::<SymbolStates, 256, _>(|_| SymbolStates { states: Vec::new() });

let mut indexes_used = alloc::vec![false; 1 << acc_log];

// distribute -1 symbols
let mut idx = (1 << acc_log) - 1;
for (symbol, prob) in probs.iter().copied().filter(|prob| *prob == -1).enumerate() {
states[symbol].states.push(State{
let mut negative_idx = (1 << acc_log) - 1;
for (symbol, _prob) in probs.iter().copied().enumerate().filter(|prob| prob.1 == -1) {
dbg!(symbol, negative_idx);
states[symbol].states.push(State {
num_bits: acc_log,
baseline: 0,
last_index: (1 << acc_log) - 1,
index: idx,
index: negative_idx,
});
indexes_used[idx] = true;
idx -= 1;
negative_idx -= 1;
}

// distribute other symbols
let mut idx = 0;
for (symbol, prob) in probs.iter().copied().enumerate() {
if prob == 0 {
if prob <= 0 {
continue;
}
let states = &mut states[symbol].states;
let prob_log = (prob as u32).ilog2();
let rounded_up = 1 << (prob_log + 1);
for _ in 0..prob {
states.push(State {
num_bits: 0,
baseline: 0,
last_index: 0,
index: idx,
});

idx = next_position(idx, 1 << acc_log);
while idx > negative_idx {
idx = next_position(idx, 1 << acc_log);
}
}
assert_eq!(states.len(), prob as usize);
}

for (symbol, prob) in probs.iter().copied().enumerate() {
if prob <= 0 {
continue;
}
let prob = prob as u32;
let state = &mut states[symbol];
state.states.sort_by(|l,r| l.index.cmp(&r.index));

let prob_log = if prob.is_power_of_two() {
prob.ilog2()
} else {
prob.ilog2() + 1
};
let rounded_up = 1u32 << prob_log;
let double_states = rounded_up - prob;
let single_states = prob - double_states;
let num_bits = acc_log - prob_log as u8;
let mut baseline = 0;
for state_idx in 0..prob {
if state_idx < double_states {
let mut baseline = (single_states as usize * (1 << (num_bits))) % (1 << acc_log);
for (idx, state) in state.states.iter_mut().enumerate() {
if (idx as u32) < double_states {
let num_bits = num_bits + 1;
states.push(State{
num_bits: num_bits,
baseline,
last_index: baseline + ((1 << num_bits) - 1),
index: idx,
});
state.baseline = baseline;
state.num_bits = num_bits;
state.last_index= baseline + ((1 << num_bits) - 1);

baseline += 1 << num_bits;
indexes_used[idx] = true;
baseline %= 1 << acc_log;
} else {
states.push(State{
num_bits,
baseline,
last_index: baseline + ((1 << num_bits) - 1),
index: idx,
});
state.baseline = baseline;
state.num_bits = num_bits;
state.last_index= baseline + ((1 << num_bits) - 1);
baseline += 1 << num_bits;
indexes_used[idx] = true;
}

while indexes_used[idx] {
idx = next_position(idx, 1 << acc_log);
}
}
state.states.sort_by(|l,r| l.baseline.cmp(&r.baseline));
}

FSETable {
Expand All @@ -182,11 +204,10 @@ pub(super) fn build_table_from_probabilities(probs: &[i32], acc_log: u8) -> FSET
}
}

//utility functions for building the decoding table from probabilities
/// Calculate the position of the next entry of the table given the current
/// position and size of the table.
fn next_position(mut p: usize, table_size: usize) -> usize {
p += (table_size >> 1) + (table_size >> 3) + 3;
p &= table_size - 1;
p
}
}
21 changes: 14 additions & 7 deletions src/fse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@
//! <https://arxiv.org/pdf/1311.2540>

mod fse_decoder;

pub use fse_decoder::*;
mod fse_encoder;
pub mod fse_encoder;

#[test]
fn tables() {
let probs = &[0,0,-1,3,2,2];
fn tables_equal() {
let probs = &[0, 0, -1, 3, 2, 2, (1 << 6) - 8];
let mut dec_table = FSETable::new(255);
dec_table.build_from_probabilities(3, probs).unwrap();
let enc_table = fse_encoder::build_table_from_probabilities(probs, 3);
panic!("{:?}\n{:?}", dec_table, enc_table);
}
dec_table.build_from_probabilities(6, probs).unwrap();
let enc_table = fse_encoder::build_table_from_probabilities(probs, 6);

for (idx, dec_state) in dec_table.decode.iter().enumerate() {
let enc_states = &enc_table.states[dec_state.symbol as usize];
let enc_state = enc_states.states.iter().find(| state| state.index == idx).unwrap();
assert_eq!(enc_state.baseline, dec_state.base_line as usize);
assert_eq!(enc_state.num_bits, dec_state.num_bits);
}
}

0 comments on commit 73f7797

Please sign in to comment.