Skip to content

Commit

Permalink
don't return errors on too large requests on a reversed bitreader
Browse files Browse the repository at this point in the history
  • Loading branch information
KillingSpark committed Apr 1, 2024
1 parent 9f48937 commit ceb03d7
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 74 deletions.
4 changes: 2 additions & 2 deletions benches/reversedbitreader_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ruzstd::decoding::bit_reader_reverse::BitReaderReversed;
fn do_all_accesses(br: &mut BitReaderReversed, accesses: &[u8]) -> u64 {
let mut sum = 0;
for x in accesses {
sum += br.get_bits(*x).unwrap();
sum += br.get_bits(*x);
}
let _ = black_box(br);
sum
Expand All @@ -24,7 +24,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let mut br = BitReaderReversed::new(&rand_vec);
while br.bits_remaining() > 0 {
let x = rng.gen_range(0..20);
br.get_bits(x).unwrap();
br.get_bits(x);
access_vec.push(x);
}

Expand Down
50 changes: 16 additions & 34 deletions src/decoding/bit_reader_reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,40 +103,33 @@ impl<'s> BitReaderReversed<'s> {
}

#[inline(always)]
pub fn get_bits(&mut self, n: u8) -> Result<u64, GetBitsError> {
pub fn get_bits(&mut self, n: u8) -> u64 {
if n == 0 {
return Ok(0);
return 0;
}
if self.bits_in_container >= n {
return Ok(self.get_bits_unchecked(n));
return self.get_bits_unchecked(n);
}

self.get_bits_cold(n)
}

#[cold]
fn get_bits_cold(&mut self, n: u8) -> Result<u64, GetBitsError> {
if n > 56 {
return Err(GetBitsError::TooManyBits {
num_requested_bits: usize::from(n),
limit: 56,
});
}

fn get_bits_cold(&mut self, n: u8) -> u64 {
let signed_n = n as isize;

if self.bits_remaining() <= 0 {
self.idx -= signed_n;
return Ok(0);
return 0;
}

if self.bits_remaining() < signed_n {
let emulated_read_shift = signed_n - self.bits_remaining();
let v = self.get_bits(self.bits_remaining() as u8)?;
let v = self.get_bits(self.bits_remaining() as u8);
debug_assert!(self.idx == 0);
let value = v << emulated_read_shift;
self.idx -= emulated_read_shift;
return Ok(value);
return value;
}

while (self.bits_in_container < n) && self.idx > 0 {
Expand All @@ -147,23 +140,18 @@ impl<'s> BitReaderReversed<'s> {

//if we reach this point there are enough bits in the container

Ok(self.get_bits_unchecked(n))
self.get_bits_unchecked(n)
}

#[inline(always)]
pub fn get_bits_triple(
&mut self,
n1: u8,
n2: u8,
n3: u8,
) -> Result<(u64, u64, u64), GetBitsError> {
pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
let sum = n1 as usize + n2 as usize + n3 as usize;
if sum == 0 {
return Ok((0, 0, 0));
return (0, 0, 0);
}
if sum > 56 {
// try and get the values separatly
return Ok((self.get_bits(n1)?, self.get_bits(n2)?, self.get_bits(n3)?));
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
}
let sum = sum as u8;

Expand All @@ -184,29 +172,23 @@ impl<'s> BitReaderReversed<'s> {
self.get_bits_unchecked(n3)
};

return Ok((v1, v2, v3));
return (v1, v2, v3);
}

self.get_bits_triple_cold(n1, n2, n3, sum)
}

#[cold]
fn get_bits_triple_cold(
&mut self,
n1: u8,
n2: u8,
n3: u8,
sum: u8,
) -> Result<(u64, u64, u64), GetBitsError> {
fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
let sum_signed = sum as isize;

if self.bits_remaining() <= 0 {
self.idx -= sum_signed;
return Ok((0, 0, 0));
return (0, 0, 0);
}

if self.bits_remaining() < sum_signed {
return Ok((self.get_bits(n1)?, self.get_bits(n2)?, self.get_bits(n3)?));
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
}

while (self.bits_in_container < sum) && self.idx > 0 {
Expand All @@ -233,7 +215,7 @@ impl<'s> BitReaderReversed<'s> {
self.get_bits_unchecked(n3)
};

Ok((v1, v2, v3))
(v1, v2, v3)
}

#[inline(always)]
Expand Down
12 changes: 6 additions & 6 deletions src/decoding/literals_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ fn decompress_literals(
//skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand All @@ -136,11 +136,11 @@ fn decompress_literals(
//if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
}
decoder.init_state(&mut br)?;
decoder.init_state(&mut br);

while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
target.push(decoder.decode_symbol());
decoder.next_state(&mut br)?;
decoder.next_state(&mut br);
}
if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
return Err(DecompressLiteralsError::BitstreamReadMismatch {
Expand All @@ -158,7 +158,7 @@ fn decompress_literals(
let mut br = BitReaderReversed::new(source);
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand All @@ -168,10 +168,10 @@ fn decompress_literals(
//if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
}
decoder.init_state(&mut br)?;
decoder.init_state(&mut br);
while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
target.push(decoder.decode_symbol());
decoder.next_state(&mut br)?;
decoder.next_state(&mut br);
}
bytes_read += source.len() as u32;
}
Expand Down
18 changes: 9 additions & 9 deletions src/decoding/sequence_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub fn decode_sequences(
//skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand Down Expand Up @@ -137,7 +137,7 @@ fn decode_sequences_with_rle(
});
}

let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits)?;
let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
let offset = obits as u32 + (1u32 << of_code);

if offset == 0 {
Expand All @@ -157,13 +157,13 @@ fn decode_sequences_with_rle(
// br.bits_remaining() / 8,
//);
if scratch.ll_rle.is_none() {
ll_dec.update_state(br)?;
ll_dec.update_state(br);
}
if scratch.ml_rle.is_none() {
ml_dec.update_state(br)?;
ml_dec.update_state(br);
}
if scratch.of_rle.is_none() {
of_dec.update_state(br)?;
of_dec.update_state(br);
}
}

Expand Down Expand Up @@ -212,7 +212,7 @@ fn decode_sequences_without_rle(
});
}

let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits)?;
let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
let offset = obits as u32 + (1u32 << of_code);

if offset == 0 {
Expand All @@ -231,9 +231,9 @@ fn decode_sequences_without_rle(
// br.bits_remaining(),
// br.bits_remaining() / 8,
//);
ll_dec.update_state(br)?;
ml_dec.update_state(br)?;
of_dec.update_state(br)?;
ll_dec.update_state(br);
ml_dec.update_state(br);
of_dec.update_state(br);
}

if br.bits_remaining() < 0 {
Expand Down
10 changes: 3 additions & 7 deletions src/fse/fse_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,19 @@ impl<'t> FSEDecoder<'t> {
if self.table.accuracy_log == 0 {
return Err(FSEDecoderError::TableIsUninitialized);
}
self.state = self.table.decode[bits.get_bits(self.table.accuracy_log)? as usize];
self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize];

Ok(())
}

pub fn update_state(
&mut self,
bits: &mut BitReaderReversed<'_>,
) -> Result<(), FSEDecoderError> {
pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
let num_bits = self.state.num_bits;
let add = bits.get_bits(num_bits)?;
let add = bits.get_bits(num_bits);
let base_line = self.state.base_line;
let new_state = base_line + add as u32;
self.state = self.table.decode[new_state as usize];

//println!("Update: {}, {} -> {}", base_line, add, self.state);
Ok(())
}
}

Expand Down
24 changes: 9 additions & 15 deletions src/huff0/huff0_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,20 @@ impl<'t> HuffmanDecoder<'t> {
self.table.decode[self.state as usize].symbol
}

pub fn init_state(
&mut self,
br: &mut BitReaderReversed<'_>,
) -> Result<u8, HuffmanDecoderError> {
pub fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
let num_bits = self.table.max_num_bits;
let new_bits = br.get_bits(num_bits)?;
let new_bits = br.get_bits(num_bits);
self.state = new_bits;
Ok(num_bits)
num_bits
}

pub fn next_state(
&mut self,
br: &mut BitReaderReversed<'_>,
) -> Result<u8, HuffmanDecoderError> {
pub fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
let num_bits = self.table.decode[self.state as usize].num_bits;
let new_bits = br.get_bits(num_bits)?;
let new_bits = br.get_bits(num_bits);
self.state <<= num_bits;
self.state &= self.table.decode.len() as u64 - 1;
self.state |= new_bits;
Ok(num_bits)
num_bits
}
}

Expand Down Expand Up @@ -235,7 +229,7 @@ impl HuffmanTable {
//skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand All @@ -254,7 +248,7 @@ impl HuffmanTable {
loop {
let w = dec1.decode_symbol();
self.weights.push(w);
dec1.update_state(&mut br)?;
dec1.update_state(&mut br);

if br.bits_remaining() <= -1 {
//collect final states
Expand All @@ -264,7 +258,7 @@ impl HuffmanTable {

let w = dec2.decode_symbol();
self.weights.push(w);
dec2.update_state(&mut br)?;
dec2.update_state(&mut br);

if br.bits_remaining() <= -1 {
//collect final states
Expand Down
2 changes: 1 addition & 1 deletion src/tests/bit_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn test_bitreader_reversed() {
num_bits = 128 - bits_read;
}

let bits = br.get_bits(num_bits).unwrap();
let bits = br.get_bits(num_bits);
bits_read += num_bits;
accumulator |= u128::from(bits) << (128 - bits_read);
if bits_read >= 128 {
Expand Down

0 comments on commit ceb03d7

Please sign in to comment.