Skip to content

Commit

Permalink
fix(inflate): use inputwrapper struct instead of iter to simplify inp…
Browse files Browse the repository at this point in the history
…ut reading and change some data types for performance
  • Loading branch information
oyvindln committed Nov 19, 2024
1 parent 9f1fc5e commit 423bdf8
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 66 deletions.
114 changes: 52 additions & 62 deletions miniz_oxide/src/inflate/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use super::*;
use crate::shared::{update_adler32, HUFFMAN_LENGTH_ORDER};
use ::core::cell::Cell;

use ::core::cmp;
use ::core::convert::TryInto;
use ::core::{cmp, slice};

use self::output_buffer::OutputBuffer;
use self::output_buffer::{InputWrapper, OutputBuffer};

pub const TINFL_LZ_DICT_SIZE: usize = 32_768;

Expand Down Expand Up @@ -47,7 +47,7 @@ impl HuffmanTable {

/// Get the symbol and the code length from the huffman tree.
#[inline]
fn tree_lookup(&self, fast_symbol: i32, bit_buf: BitBuffer, mut code_len: u32) -> (i32, u32) {
fn tree_lookup(&self, fast_symbol: i32, bit_buf: BitBuffer, mut code_len: u8) -> (i32, u32) {
let mut symbol = fast_symbol;
// We step through the tree until we encounter a positive value, which indicates a
// symbol.
Expand All @@ -65,7 +65,9 @@ impl HuffmanTable {
break;
}
}
(symbol, code_len)
// Note: Using a u8 for code_len inside this function seems to improve performance, but changing it
// in localvars seems to worsen things so we convert it to a u32 here.
(symbol, u32::from(code_len))
}

#[inline]
Expand All @@ -87,7 +89,7 @@ impl HuffmanTable {
}
} else {
// We didn't get a symbol from the fast lookup table, so check the tree instead.
Some(self.tree_lookup(symbol, bit_buf, FAST_LOOKUP_BITS.into()))
Some(self.tree_lookup(symbol, bit_buf, FAST_LOOKUP_BITS))
}
}
}
Expand Down Expand Up @@ -370,10 +372,12 @@ const DIST_BASE: [u16; 30] = [
/// Get the number of extra bits used for a distance code.
/// (Code numbers above `NUM_DISTANCE_CODES` will give some garbage
/// value.)
#[inline(always)]
const fn num_extra_bits_for_distance_code(code: u8) -> u8 {
// TODO: Need to verify that this is faster on all platforms.
// This can be easily calculated without a lookup.
let c = code >> 1;
c - (c != 0) as u8
c.saturating_sub(1)
}

/// The mask used when indexing the base/extra arrays.
Expand All @@ -392,27 +396,12 @@ fn memset<T: Copy>(slice: &mut [T], val: T) {
/// # Panics
/// Panics if there are less than two bytes left.
#[inline]
fn read_u16_le(iter: &mut slice::Iter<u8>) -> u16 {
fn read_u16_le(iter: &mut InputWrapper) -> u16 {
let ret = {
let two_bytes = iter.as_ref()[..2].try_into().unwrap();
let two_bytes = iter.as_slice()[..2].try_into().unwrap_or_default();
u16::from_le_bytes(two_bytes)
};
iter.nth(1);
ret
}

/// Read an le u32 value from the slice iterator.
///
/// # Panics
/// Panics if there are less than four bytes left.
#[inline(always)]
#[cfg(target_pointer_width = "64")]
fn read_u32_le(iter: &mut slice::Iter<u8>) -> u32 {
let ret = {
let four_bytes: [u8; 4] = iter.as_ref()[..4].try_into().unwrap();
u32::from_le_bytes(four_bytes)
};
iter.nth(3);
iter.advance(2);
ret
}

Expand All @@ -423,10 +412,10 @@ fn read_u32_le(iter: &mut slice::Iter<u8>) -> u32 {
/// This function assumes that there are at least 4 bytes left in the input buffer.
#[inline(always)]
#[cfg(target_pointer_width = "64")]
fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut slice::Iter<u8>) {
fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut InputWrapper) {
// Read four bytes into the buffer at once.
if l.num_bits < 30 {
l.bit_buf |= BitBuffer::from(read_u32_le(in_iter)) << l.num_bits;
l.bit_buf |= BitBuffer::from(in_iter.read_u32_le()) << l.num_bits;
l.num_bits += 32;
}
}
Expand All @@ -435,7 +424,7 @@ fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut slice::Iter<u8>) {
/// Ensures at least 16 bits are present, requires at least 2 bytes in the in buffer.
#[inline(always)]
#[cfg(not(target_pointer_width = "64"))]
fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut slice::Iter<u8>) {
fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut InputWrapper) {
// If the buffer is 32-bit wide, read 2 bytes instead.
if l.num_bits < 15 {
l.bit_buf |= BitBuffer::from(read_u16_le(in_iter)) << l.num_bits;
Expand Down Expand Up @@ -491,7 +480,7 @@ fn decode_huffman_code<F>(
l: &mut LocalVars,
table: usize,
flags: u32,
in_iter: &mut slice::Iter<u8>,
in_iter: &mut InputWrapper,
f: F,
) -> Action
where
Expand All @@ -501,7 +490,7 @@ where
// ready in the bit buffer to start decoding the next huffman code.
if l.num_bits < 15 {
// First, make sure there is enough data in the bit buffer to decode a huffman code.
if in_iter.len() < 2 {
if in_iter.bytes_left() < 2 {
// If there are fewer than 2 bytes left in the input buffer, we try to look up
// the huffman code with what's available, and return if that doesn't succeed.
// Original explanation in miniz:
Expand Down Expand Up @@ -581,7 +570,7 @@ where
// Mask out the length value.
symbol &= 511;
} else {
let res = r.tables[table].tree_lookup(symbol, l.bit_buf, u32::from(FAST_LOOKUP_BITS));
let res = r.tables[table].tree_lookup(symbol, l.bit_buf, FAST_LOOKUP_BITS);
symbol = res.0;
code_len = res.1;
};
Expand All @@ -599,13 +588,13 @@ where
/// returning the result.
/// If reading fails, `Action::End` is returned.
#[inline]
fn read_byte<F>(in_iter: &mut slice::Iter<u8>, flags: u32, f: F) -> Action
fn read_byte<F>(in_iter: &mut InputWrapper, flags: u32, f: F) -> Action
where
F: FnOnce(u8) -> Action,
{
match in_iter.next() {
match in_iter.read_byte() {
None => end_of_input(flags),
Some(&byte) => f(byte),
Some(byte) => f(byte),
}
}

Expand All @@ -618,7 +607,7 @@ where
fn read_bits<F>(
l: &mut LocalVars,
amount: u32,
in_iter: &mut slice::Iter<u8>,
in_iter: &mut InputWrapper,
flags: u32,
f: F,
) -> Action
Expand Down Expand Up @@ -647,7 +636,7 @@ where
}

#[inline]
fn pad_to_bytes<F>(l: &mut LocalVars, in_iter: &mut slice::Iter<u8>, flags: u32, f: F) -> Action
fn pad_to_bytes<F>(l: &mut LocalVars, in_iter: &mut InputWrapper, flags: u32, f: F) -> Action
where
F: FnOnce(&mut LocalVars) -> Action,
{
Expand Down Expand Up @@ -854,7 +843,7 @@ struct LocalVars {
pub num_bits: u32,
pub dist: u32,
pub counter: u32,
pub num_extra: u32,
pub num_extra: u8,
}

#[inline]
Expand Down Expand Up @@ -981,7 +970,7 @@ fn apply_match(
/// and already improves decompression speed a fair bit.
fn decompress_fast(
r: &mut DecompressorOxide,
in_iter: &mut slice::Iter<u8>,
in_iter: &mut InputWrapper,
out_buf: &mut OutputBuffer,
flags: u32,
local_vars: &mut LocalVars,
Expand All @@ -1001,7 +990,7 @@ fn decompress_fast(
// + 29 + 32 (left in bit buf, including last 13 dist extra) = 111 bits < 14 bytes
// We need the one extra byte as we may write one length and one full match
// before checking again.
if out_buf.bytes_left() < 259 || in_iter.len() < 14 {
if out_buf.bytes_left() < 259 || in_iter.bytes_left() < 14 {
state = State::DecodeLitlen;
break 'o TINFLStatus::Done;
}
Expand Down Expand Up @@ -1063,18 +1052,19 @@ fn decompress_fast(
// The symbol was a length code.
// # Optimization
// Mask the value to avoid bounds checks
// We could use get_unchecked later if can statically verify that
// this will never go out of bounds.
l.num_extra = u32::from(LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK]);
// While the maximum is checked, the compiler isn't able to know that the
// value won't wrap around here.
l.num_extra = LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK];
l.counter = u32::from(LENGTH_BASE[(l.counter - 257) as usize & BASE_EXTRA_MASK]);
// Length and distance codes have a number of extra bits depending on
// the base, which together with the base gives us the exact value.

// We need to make sure we have at least 33 bits (so min 5 bytes) in the buffer at this spot.
fill_bit_buffer(&mut l, in_iter);
if l.num_extra != 0 {
let extra_bits = l.bit_buf & ((1 << l.num_extra) - 1);
l.bit_buf >>= l.num_extra;
l.num_bits -= l.num_extra;
l.num_bits -= u32::from(l.num_extra);
l.counter += extra_bits as u32;
}

Expand All @@ -1093,7 +1083,7 @@ fn decompress_fast(
break 'o TINFLStatus::Failed;
}

l.num_extra = u32::from(num_extra_bits_for_distance_code(symbol as u8));
l.num_extra = num_extra_bits_for_distance_code(symbol as u8);
l.dist = u32::from(DIST_BASE[symbol as usize]);
} else {
state.begin(InvalidCodeLen);
Expand All @@ -1104,7 +1094,7 @@ fn decompress_fast(
fill_bit_buffer(&mut l, in_iter);
let extra_bits = l.bit_buf & ((1 << l.num_extra) - 1);
l.bit_buf >>= l.num_extra;
l.num_bits -= l.num_extra;
l.num_bits -= u32::from(l.num_extra);
l.dist += extra_bits as u32;
}

Expand Down Expand Up @@ -1194,7 +1184,7 @@ pub fn decompress(
return (TINFLStatus::BadParam, 0, 0);
}

let mut in_iter = in_buf.iter();
let mut in_iter = InputWrapper::from_slice(in_buf);

let mut state = r.state;

Expand All @@ -1206,7 +1196,7 @@ pub fn decompress(
num_bits: r.num_bits,
dist: r.dist,
counter: r.counter,
num_extra: r.num_extra,
num_extra: r.num_extra as u8,
};

let mut status = 'state_machine: loop {
Expand Down Expand Up @@ -1351,20 +1341,20 @@ pub fn decompress(
}),

RawMemcpy2 => generate_state!(state, 'state_machine, {
if in_iter.len() > 0 {
if in_iter.bytes_left() > 0 {
// Copy as many raw bytes as possible from the input to the output using memcpy.
// Raw block lengths are limited to 64 * 1024, so casting through usize and u32
// is not an issue.
let space_left = out_buf.bytes_left();
let bytes_to_copy = cmp::min(cmp::min(
space_left,
in_iter.len()),
in_iter.bytes_left()),
l.counter as usize
);

out_buf.write_slice(&in_iter.as_slice()[..bytes_to_copy]);

in_iter.nth(bytes_to_copy - 1);
in_iter.advance(bytes_to_copy);
l.counter -= bytes_to_copy as u32;
Action::Jump(RawMemcpy1)
} else {
Expand Down Expand Up @@ -1456,7 +1446,7 @@ pub fn decompress(
}),

ReadExtraBitsCodeSize => generate_state!(state, 'state_machine, {
let num_extra = l.num_extra;
let num_extra = l.num_extra.into();
read_bits(&mut l, num_extra, &mut in_iter, flags, |l, mut extra_bits| {
// Mask to avoid a bounds check.
extra_bits += [3, 3, 11][(l.dist as usize - 16) & 3];
Expand All @@ -1478,7 +1468,7 @@ pub fn decompress(
}),

DecodeLitlen => generate_state!(state, 'state_machine, {
if in_iter.len() < 4 || out_buf.bytes_left() < 2 {
if in_iter.bytes_left() < 4 || out_buf.bytes_left() < 2 {
// See if we can decode a literal with the data we have left.
// Jumps to next state (WriteSymbol) if successful.
decode_huffman_code(
Expand All @@ -1496,7 +1486,7 @@ pub fn decompress(
// If there is enough space, use the fast inner decompression
// function.
out_buf.bytes_left() >= 259 &&
in_iter.len() >= 14
in_iter.bytes_left() >= 14
{
let (status, new_state) = decompress_fast(
r,
Expand Down Expand Up @@ -1587,7 +1577,7 @@ pub fn decompress(
// We could use get_unchecked later if we can statically verify that
// this will never go out of bounds.
l.num_extra =
u32::from(LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK]);
LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK];
l.counter = u32::from(LENGTH_BASE[(l.counter - 257) as usize & BASE_EXTRA_MASK]);
// Length and distance codes have a number of extra bits depending on
// the base, which together with the base gives us the exact value.
Expand All @@ -1600,7 +1590,7 @@ pub fn decompress(
}),

ReadExtraBitsLitlen => generate_state!(state, 'state_machine, {
let num_extra = l.num_extra;
let num_extra = l.num_extra.into();
read_bits(&mut l, num_extra, &mut in_iter, flags, |l, extra_bits| {
l.counter += extra_bits as u32;
Action::Jump(DecodeDistance)
Expand All @@ -1622,7 +1612,7 @@ pub fn decompress(
// Invalid distance code.
return Action::Jump(InvalidDist)
}
l.num_extra = u32::from(num_extra_bits_for_distance_code(symbol as u8));
l.num_extra = num_extra_bits_for_distance_code(symbol as u8);
l.dist = u32::from(DIST_BASE[symbol]);
if l.num_extra != 0 {
// ReadEXTRA_BITS_DISTANCE
Expand All @@ -1634,7 +1624,7 @@ pub fn decompress(
}),

ReadExtraBitsDistance => generate_state!(state, 'state_machine, {
let num_extra = l.num_extra;
let num_extra = l.num_extra.into();
read_bits(&mut l, num_extra, &mut in_iter, flags, |l, extra_bits| {
l.dist += extra_bits as u32;
Action::Jump(HuffDecodeOuterLoop2)
Expand Down Expand Up @@ -1710,9 +1700,9 @@ pub fn decompress(
if r.finish != 0 {
pad_to_bytes(&mut l, &mut in_iter, flags, |_| Action::None);

let in_consumed = in_buf.len() - in_iter.len();
let in_consumed = in_buf.len() - in_iter.bytes_left();
let undo = undo_bytes(&mut l, in_consumed as u32) as usize;
in_iter = in_buf[in_consumed - undo..].iter();
in_iter = InputWrapper::from_slice(in_buf[in_consumed - undo..].iter().as_slice());

l.bit_buf &= ((1 as BitBuffer) << l.num_bits) - 1;
debug_assert_eq!(l.num_bits, 0);
Expand Down Expand Up @@ -1765,7 +1755,7 @@ pub fn decompress(
let in_undo = if status != TINFLStatus::NeedsMoreInput
&& status != TINFLStatus::FailedCannotMakeProgress
{
undo_bytes(&mut l, (in_buf.len() - in_iter.len()) as u32) as usize
undo_bytes(&mut l, (in_buf.len() - in_iter.bytes_left()) as u32) as usize
} else {
0
};
Expand All @@ -1785,7 +1775,7 @@ pub fn decompress(
r.num_bits = l.num_bits;
r.dist = l.dist;
r.counter = l.counter;
r.num_extra = l.num_extra;
r.num_extra = l.num_extra.into();

r.bit_buf &= ((1 as BitBuffer) << r.num_bits) - 1;

Expand Down Expand Up @@ -1816,7 +1806,7 @@ pub fn decompress(

(
status,
in_buf.len() - in_iter.len() - in_undo,
in_buf.len() - in_iter.bytes_left() - in_undo,
out_buf.position() - out_pos,
)
}
Expand Down Expand Up @@ -1911,7 +1901,7 @@ mod test {
num_bits: d.num_bits,
dist: d.dist,
counter: d.counter,
num_extra: d.num_extra,
num_extra: d.num_extra as u8,
};
init_tree(&mut d, &mut l).unwrap();
let llt = &d.tables[LITLEN_TABLE];
Expand Down
Loading

0 comments on commit 423bdf8

Please sign in to comment.