From 175b7e0c59f3c36b7b9b4065ced947d8e8cf12e3 Mon Sep 17 00:00:00 2001 From: Wonwoo Choi Date: Sun, 29 Sep 2024 02:32:32 +0900 Subject: [PATCH] Implement decoding of ANS streams --- jxl/src/entropy_coding/ans.rs | 430 +++++++++++++++++++++++++- jxl/src/entropy_coding/context_map.rs | 2 +- jxl/src/entropy_coding/decode.rs | 57 ++-- jxl/src/error.rs | 6 +- 4 files changed, 466 insertions(+), 29 deletions(-) diff --git a/jxl/src/entropy_coding/ans.rs b/jxl/src/entropy_coding/ans.rs index 9a8bdcb..e7773eb 100644 --- a/jxl/src/entropy_coding/ans.rs +++ b/jxl/src/entropy_coding/ans.rs @@ -2,6 +2,432 @@ // // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// +// Originally written for jxl-oxide. + +use crate::bit_reader::BitReader; +use crate::error::{Error, Result}; + +const LOG_SUM_PROBS: usize = 12; +const SUM_PROBS: u16 = 1 << LOG_SUM_PROBS; + +const RLE_MARKER_SYM: u16 = LOG_SUM_PROBS as u16 + 1; + +#[derive(Debug)] +struct AnsHistogram { + buckets: Vec, + log_bucket_size: usize, + bucket_mask: u32, + // For optimizing fast-lossless case. + #[allow(unused)] + single_symbol: Option, +} + +// log_alphabet_size <= 8 and log_bucket_size <= 7, so u8 is sufficient for symbols and cutoffs. +#[derive(Debug, Copy, Clone)] +#[repr(C)] +struct Bucket { + alias_symbol: u8, + alias_cutoff: u8, + dist: u16, + alias_offset: u16, + alias_dist_xor: u16, +} + +impl AnsHistogram { + fn decode_dist_single_symbol(br: &mut BitReader, dist: &mut [u16]) -> Result { + let table_size = dist.len(); + + let v0 = Self::read_u8(br)? as usize; + let v1 = Self::read_u8(br)? as usize; + if v0 == v1 { + return Err(Error::InvalidAnsHistogram); + } + + let alphabet_size = v0.max(v1) + 1; + if alphabet_size > table_size { + return Err(Error::InvalidAnsHistogram); + } + + let prob = br.read(LOG_SUM_PROBS)? as u16; + dist[v0] = prob; + dist[v1] = SUM_PROBS - prob; + + Ok(alphabet_size) + } + + fn decode_dist_two_symbols(br: &mut BitReader, dist: &mut [u16]) -> Result { + let table_size = dist.len(); + + let val = Self::read_u8(br)? as usize; + let alphabet_size = val + 1; + if alphabet_size > table_size { + return Err(Error::InvalidAnsHistogram); + } + + dist[val] = SUM_PROBS; + + Ok(alphabet_size) + } + + fn decode_dist_evenly_distributed(br: &mut BitReader, dist: &mut [u16]) -> Result { + let table_size = dist.len(); + + let alphabet_size = Self::read_u8(br)? as usize + 1; + if alphabet_size > table_size { + return Err(Error::InvalidAnsHistogram); + } + + let base = SUM_PROBS as usize / alphabet_size; + let remainder = SUM_PROBS as usize % alphabet_size; + dist[0..remainder].fill(base as u16 + 1); + dist[remainder..alphabet_size].fill(base as u16); + + Ok(alphabet_size) + } + + fn decode_dist_complex(br: &mut BitReader, dist: &mut [u16]) -> Result { + let table_size = dist.len(); + + let mut len = 0usize; + while len < 3 { + if br.read(1)? != 0 { + len += 1; + } else { + break; + } + } + + let shift = (br.read(len)? + (1 << len) - 1) as i16; + if shift > 13 { + return Err(Error::InvalidAnsHistogram); + } + + let alphabet_size = Self::read_u8(br)? as usize + 3; + if alphabet_size > table_size { + return Err(Error::InvalidAnsHistogram); + } + + // TODO(tirr-c): This could be an array of length `SUM_PROB / 4` (4 is from the minimum + // value of `repeat_count`). Change if using array is faster. + let mut repeat_ranges = Vec::new(); + let mut omit_data = None; + let mut idx = 0; + while idx < alphabet_size { + dist[idx] = Self::read_prefix(br)?; + if dist[idx] == RLE_MARKER_SYM { + let repeat_count = Self::read_u8(br)? as usize + 4; + if idx + repeat_count > alphabet_size { + return Err(Error::InvalidAnsHistogram); + } + repeat_ranges.push(idx..(idx + repeat_count)); + idx += repeat_count; + continue; + } + match &mut omit_data { + Some((log, pos)) => { + if dist[idx] > *log { + *log = dist[idx]; + *pos = idx; + } + } + data => { + *data = Some((dist[idx], idx)); + } + } + idx += 1; + } + let Some((_, omit_pos)) = omit_data else { + return Err(Error::InvalidAnsHistogram); + }; + if dist.get(omit_pos + 1) == Some(&RLE_MARKER_SYM) { + return Err(Error::InvalidAnsHistogram); + } + + let mut repeat_range_idx = 0usize; + let mut acc = 0; + let mut prev_dist = 0u16; + for (idx, code) in dist.iter_mut().enumerate() { + if repeat_range_idx < repeat_ranges.len() + && repeat_ranges[repeat_range_idx].start <= idx + { + if repeat_ranges[repeat_range_idx].end == idx { + repeat_range_idx += 1; + } else { + *code = prev_dist; + acc += *code; + if acc > SUM_PROBS { + return Err(Error::InvalidAnsHistogram); + } + continue; + } + } + + if *code == 0 { + prev_dist = 0; + continue; + } + if idx == omit_pos { + prev_dist = 0; + continue; + } + if *code > 1 { + let zeros = (*code - 1) as i16; + let bitcount = (shift - ((LOG_SUM_PROBS as i16 - zeros) >> 1)).clamp(0, zeros); + *code = (1 << zeros) + ((br.read(bitcount as usize)? as u16) << (zeros - bitcount)); + } + prev_dist = *code; + acc += *code; + if acc > SUM_PROBS { + return Err(Error::InvalidAnsHistogram); + } + } + dist[omit_pos] = SUM_PROBS - acc; + + Ok(alphabet_size) + } + + fn build_alias_map(alphabet_size: usize, log_bucket_size: usize, dist: &[u16]) -> Vec { + #[derive(Debug)] + struct WorkingBucket { + dist: u16, + alias_symbol: u16, + alias_offset: u16, + alias_cutoff: u16, + } + + let bucket_size = 1u16 << log_bucket_size; + let mut buckets: Vec<_> = dist + .iter() + .enumerate() + .map(|(i, &dist)| WorkingBucket { + dist, + alias_symbol: if i < alphabet_size { i as u16 } else { 0 }, + alias_offset: 0, + alias_cutoff: dist, + }) + .collect(); + + let mut underfull = Vec::new(); + let mut overfull = Vec::new(); + for (idx, &WorkingBucket { dist, .. }) in buckets.iter().enumerate() { + match dist.cmp(&bucket_size) { + std::cmp::Ordering::Less => underfull.push(idx), + std::cmp::Ordering::Equal => {} + std::cmp::Ordering::Greater => overfull.push(idx), + } + } + while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) { + let by = bucket_size - buckets[u].alias_cutoff; + buckets[o].alias_cutoff -= by; + buckets[u].alias_symbol = o as u16; + buckets[u].alias_offset = buckets[o].alias_cutoff; + match buckets[o].alias_cutoff.cmp(&bucket_size) { + std::cmp::Ordering::Less => underfull.push(o), + std::cmp::Ordering::Equal => {} + std::cmp::Ordering::Greater => overfull.push(o), + } + } + + // Assertion failure only happens if `dist` doesn't sum to `SUM_PROB`, which is checked + // before building alias map. + assert!(overfull.is_empty() && underfull.is_empty()); + + buckets + .iter() + .enumerate() + .map(|(idx, bucket)| { + if bucket.alias_cutoff == bucket_size { + Bucket { + dist: bucket.dist, + alias_symbol: idx as u8, + alias_offset: 0, + alias_cutoff: 0, + alias_dist_xor: 0, + } + } else { + Bucket { + dist: bucket.dist, + alias_symbol: bucket.alias_symbol as u8, + alias_offset: bucket.alias_offset - bucket.alias_cutoff, + alias_cutoff: bucket.alias_cutoff as u8, + alias_dist_xor: bucket.dist ^ buckets[bucket.alias_symbol as usize].dist, + } + } + }) + .collect() + } + + // log_alphabet_size: 5 + u(2) + pub fn decode(br: &mut BitReader, log_alpha_size: usize) -> Result { + debug_assert!((5..=8).contains(&log_alpha_size)); + let table_size = (1u16 << log_alpha_size) as usize; + // 4 <= log_bucket_size <= 7 + let log_bucket_size = LOG_SUM_PROBS - log_alpha_size; + let bucket_size = 1u16 << log_bucket_size; + let bucket_mask = bucket_size as u32 - 1; + + let mut dist = vec![0u16; table_size]; + let alphabet_size = if br.read(1)? != 0 { + if br.read(1)? != 0 { + Self::decode_dist_single_symbol(br, &mut dist)? + } else { + Self::decode_dist_two_symbols(br, &mut dist)? + } + } else if br.read(1)? != 0 { + Self::decode_dist_evenly_distributed(br, &mut dist)? + } else { + Self::decode_dist_complex(br, &mut dist)? + }; + + if let Some(single_sym_idx) = dist.iter().position(|&d| d == SUM_PROBS) { + let buckets = dist + .into_iter() + .enumerate() + .map(|(i, dist)| Bucket { + dist, + alias_symbol: single_sym_idx as u8, + alias_offset: bucket_size * i as u16, + alias_cutoff: 0, + alias_dist_xor: dist ^ SUM_PROBS, + }) + .collect(); + return Ok(Self { + buckets, + log_bucket_size, + bucket_mask, + single_symbol: Some(single_sym_idx as u32), + }); + } + + Ok(Self { + buckets: Self::build_alias_map(alphabet_size, log_bucket_size, &dist), + log_bucket_size, + bucket_mask, + single_symbol: None, + }) + } + + fn read_u8(bitstream: &mut BitReader) -> Result { + Ok(if bitstream.read(1)? != 0 { + let n = bitstream.read(3)?; + ((1 << n) + bitstream.read(n as usize)?) as u8 + } else { + 0 + }) + } + + fn read_prefix(br: &mut BitReader) -> Result { + // Prefix code lookup table. + #[rustfmt::skip] + const TABLE: [(u8, u8); 128] = [ + (10, 3), (12, 7), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4), + (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4), + (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4), + (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4), + (10, 3), (11, 6), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4), + (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4), + (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4), + (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4), + (10, 3), (13, 7), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4), + (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4), + (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4), + (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4), + (10, 3), (11, 6), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4), + (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4), + (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4), + (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4), + ]; + + let index = br.peek(7); + let (sym, bits) = TABLE[index as usize]; + br.consume(bits as usize)?; + Ok(sym as u16) + } +} + +impl AnsHistogram { + #[inline(always)] + pub fn read(&self, br: &mut BitReader, state: &mut u32) -> Result { + let idx = *state & 0xfff; + let i = (idx >> self.log_bucket_size) as usize; + let pos = idx & self.bucket_mask; + + let bucket = self.buckets[i]; + let alias_symbol = bucket.alias_symbol as usize; + let alias_cutoff = bucket.alias_cutoff as u32; + let dist = bucket.dist as u32; + + let map_to_alias = pos >= alias_cutoff; + let (offset, dist_xor) = if map_to_alias { + (bucket.alias_offset as u32, bucket.alias_dist_xor as u32) + } else { + (0, 0) + }; + + let dist = dist ^ dist_xor; + let symbol = if map_to_alias { alias_symbol } else { i }; + let offset = offset + pos; + + let next_state = (*state >> LOG_SUM_PROBS) * dist + offset; + let appended_state = (next_state << 16) | br.peek(16) as u32; + let select_appended = next_state < (1 << 16); + *state = if select_appended { + appended_state + } else { + next_state + }; + br.consume(if select_appended { 16 } else { 0 })?; + Ok(symbol as u32) + } + + // For optimizing fast-lossless case. + #[allow(unused)] + #[inline] + pub fn single_symbol(&self) -> Option { + self.single_symbol + } +} + +#[derive(Debug)] +pub struct AnsCodes { + histograms: Vec, +} + +impl AnsCodes { + pub fn decode(num: usize, log_alpha_size: usize, br: &mut BitReader) -> Result { + let histograms = (0..num) + .map(|_| AnsHistogram::decode(br, log_alpha_size)) + .collect::>()?; + Ok(Self { histograms }) + } +} + +#[derive(Debug)] +pub struct AnsReader(u32); + +impl AnsReader { + /// Expected final ANS state. + const CHECKSUM: u32 = 0x130000; + + pub fn new_unused() -> Self { + Self(0) + } + + pub fn init(br: &mut BitReader) -> Result { + let initial_state = br.read(32)? as u32; + Ok(Self(initial_state)) + } + + pub fn read(&mut self, codes: &AnsCodes, br: &mut BitReader, ctx: usize) -> Result { + codes.histograms[ctx].read(br, &mut self.0) + } -//use crate::bit_reader::BitReader; -//use crate::error::Error; + pub fn check_final_state(self) -> Result<()> { + if self.0 == Self::CHECKSUM { + Ok(()) + } else { + Err(Error::AnsChecksumMismatch) + } + } +} diff --git a/jxl/src/entropy_coding/context_map.rs b/jxl/src/entropy_coding/context_map.rs index fa744d0..4fb8b8e 100644 --- a/jxl/src/entropy_coding/context_map.rs +++ b/jxl/src/entropy_coding/context_map.rs @@ -55,7 +55,7 @@ pub fn decode_context_map(num_contexts: usize, br: &mut BitReader) -> Result 2)?; - let reader = histograms.make_reader(br)?; + let mut reader = histograms.make_reader(br)?; let mut ctx_map: Vec = (0..num_contexts) .map(|_| { diff --git a/jxl/src/entropy_coding/decode.rs b/jxl/src/entropy_coding/decode.rs index e8c2ee1..4a402dd 100644 --- a/jxl/src/entropy_coding/decode.rs +++ b/jxl/src/entropy_coding/decode.rs @@ -13,10 +13,10 @@ use crate::entropy_coding::ans::*; use crate::entropy_coding::context_map::*; use crate::entropy_coding::huffman::*; use crate::entropy_coding::hybrid_uint::*; -use crate::error::Error; +use crate::error::{Error, Result}; use crate::headers::encodings::*; -pub fn decode_varint16(br: &mut BitReader) -> Result { +pub fn decode_varint16(br: &mut BitReader) -> Result { if br.read(1)? != 0 { let nbits = br.read(4)? as usize; if nbits == 0 { @@ -47,6 +47,7 @@ struct LZ77Params { #[derive(Debug)] enum Codes { Huffman(HuffmanCodes), + Ans(AnsCodes), } #[derive(Debug)] @@ -66,40 +67,39 @@ pub struct Histograms { #[derive(Debug)] pub struct Reader<'a> { histograms: &'a Histograms, + ans_reader: AnsReader, } impl<'a> Reader<'a> { fn read_internal( - &self, + &mut self, br: &mut BitReader, uint_config: &HybridUint, cluster: usize, - ) -> Result { + ) -> Result { let symbol = match &self.histograms.codes { Codes::Huffman(hc) => hc.read(br, cluster)?, + Codes::Ans(ans) => self.ans_reader.read(ans, br, cluster)?, }; uint_config.read(symbol, br) } - pub fn read(&self, br: &mut BitReader, context: usize) -> Result { + pub fn read(&mut self, br: &mut BitReader, context: usize) -> Result { assert!(!self.histograms.lz77_params.enabled); let cluster = self.histograms.context_map[context] as usize; self.read_internal(br, &self.histograms.uint_configs[cluster], cluster) } - pub fn check_final_state(self) -> Result<(), Error> { + pub fn check_final_state(self) -> Result<()> { match &self.histograms.codes { Codes::Huffman(_) => Ok(()), + Codes::Ans(_) => self.ans_reader.check_final_state(), } } } impl Histograms { - pub fn decode( - num_contexts: usize, - br: &mut BitReader, - allow_lz77: bool, - ) -> Result { + pub fn decode(num_contexts: usize, br: &mut BitReader, allow_lz77: bool) -> Result { let lz77_params = LZ77Params::read_unconditional(&(), br, &Empty {})?; if !allow_lz77 && lz77_params.enabled { return Err(Error::LZ77Disallowed); @@ -133,12 +133,16 @@ impl Histograms { }; let num_histograms = *context_map.iter().max().unwrap() + 1; let uint_configs = ((0..num_histograms).map(|_| HybridUint::decode(log_alpha_size, br))) - .collect::>()?; + .collect::>()?; let codes = if use_prefix_code { Codes::Huffman(HuffmanCodes::decode(num_histograms as usize, br)?) } else { - unimplemented!(); + Codes::Ans(AnsCodes::decode( + num_histograms as usize, + log_alpha_size, + br, + )?) }; Ok(Histograms { @@ -150,26 +154,29 @@ impl Histograms { codes, }) } - fn make_reader_impl( - &self, - _br: &mut BitReader, - _image_width: Option, - ) -> Result { + + fn make_reader_impl(&self, br: &mut BitReader, _image_width: Option) -> Result { if self.lz77_params.enabled { unimplemented!() } - Ok(Reader { histograms: self }) + + let ans_reader = if matches!(self.codes, Codes::Ans(_)) { + AnsReader::init(br)? + } else { + AnsReader::new_unused() + }; + + Ok(Reader { + histograms: self, + ans_reader, + }) } - pub fn make_reader(&self, br: &mut BitReader) -> Result { + pub fn make_reader(&self, br: &mut BitReader) -> Result { self.make_reader_impl(br, None) } - pub fn make_reader_with_width( - &self, - br: &mut BitReader, - image_width: usize, - ) -> Result { + pub fn make_reader_with_width(&self, br: &mut BitReader, image_width: usize) -> Result { self.make_reader_impl(br, Some(image_width)) } } diff --git a/jxl/src/error.rs b/jxl/src/error.rs index 2ede5ae..71c726c 100644 --- a/jxl/src/error.rs +++ b/jxl/src/error.rs @@ -55,6 +55,10 @@ pub enum Error { AlphabetTooLargeHuff(usize), #[error("Invalid Huffman code")] InvalidHuffman, + #[error("Invalid ANS histogram")] + InvalidAnsHistogram, + #[error("ANS stream checksum mismatch")] + AnsChecksumMismatch, #[error("Integer too large: nbits {0} > 29")] IntegerTooLarge(u32), #[error("Invalid context map: context id {0} > 255")] @@ -77,4 +81,4 @@ pub enum Error { ArithmeticOverflow, } -pub type Result = std::result::Result; +pub type Result = std::result::Result;