Skip to content

Commit

Permalink
Fix decompressing long inputs (#30)
Browse files Browse the repository at this point in the history
* Start adding new clippy lints

* Improve compressor bitwriter slightly

* Make num_bits a u32

* Add test for compressing long inputs

* Make `num_bits` a `u8`

* Remove unreachable `expect`s

* Add cache to Bindings action

* Add long input decompress test

* Fix long input decompress bug

* Fix clippy

* Limit n in `read_bits`

* Make `code` a `u32`

* Remove `dict_size`

* Use stream code constants

* Remove unneeded insert

* Make `enlarge_in` a `u64`

* Decrease position var size
  • Loading branch information
adumbidiot authored Oct 20, 2022
1 parent 50b0435 commit 69bf361
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 58 deletions.
23 changes: 6 additions & 17 deletions src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
use crate::constants::BASE64_KEY;
use crate::constants::CLOSE_CODE;
use crate::constants::START_CODE_BITS;
use crate::constants::U16_CODE;
use crate::constants::U8_CODE;
use crate::constants::URI_KEY;
use crate::IntoWideIter;
use std::collections::HashMap;
use std::collections::HashSet;
use std::convert::TryInto;

/// The starting size of a codepoint.
///
/// Compression starts with the following codes:
/// 0: u8
/// 1: u16
/// 2: close stream
const START_NUM_BITS: u8 = 2;

/// The stream code for a `u8`.
const U8_CODE: u32 = 0;

/// The stream code for a `u16`.
const U16_CODE: u32 = 1;

/// The number of "base codes",
/// the default codes of all streams.
///
Expand Down Expand Up @@ -96,7 +85,7 @@ where

bit_buffer: 0,

num_bits: START_NUM_BITS,
num_bits: START_CODE_BITS,

bit_position: 0,
bits_per_char,
Expand All @@ -114,10 +103,10 @@ where
{
Some(Some(first_w_char)) => {
if first_w_char < 256 {
self.write_bits(self.num_bits, U8_CODE);
self.write_bits(self.num_bits, U8_CODE.into());
self.write_bits(8, first_w_char.into());
} else {
self.write_bits(self.num_bits, U16_CODE);
self.write_bits(self.num_bits, U16_CODE.into());
self.write_bits(16, first_w_char.into());
}
self.decrement_enlarge_in();
Expand Down
16 changes: 15 additions & 1 deletion src/constants.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
pub const URI_KEY: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-$";
pub const BASE64_KEY: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";

/// The stream code for a `u8`.
pub const U8_CODE: u8 = 0;

/// The stream code for a `u16`.
pub const U16_CODE: u8 = 1;

/// The stream code that signals the end of the stream.
pub const CLOSE_CODE: u16 = 2;
pub const CLOSE_CODE: u8 = 2;

/// The starting size of a code.
///
/// Compression starts with the following codes:
/// 0: u8
/// 1: u16
/// 2: close stream
pub const START_CODE_BITS: u8 = 2;
89 changes: 49 additions & 40 deletions src/decompress.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// TODO: Disable this
#![allow(clippy::cast_possible_truncation)]

use crate::constants::BASE64_KEY;
use crate::constants::CLOSE_CODE;
use crate::constants::START_CODE_BITS;
use crate::constants::U16_CODE;
use crate::constants::U8_CODE;
use crate::constants::URI_KEY;
use crate::IntoWideIter;
use std::convert::TryFrom;
Expand All @@ -12,8 +12,8 @@ use std::convert::TryInto;
pub struct DecompressContext<I> {
val: u16,
compressed_data: I,
position: usize,
reset_val: usize,
position: u16,
reset_val: u16,
}

impl<I> DecompressContext<I>
Expand All @@ -24,8 +24,17 @@ where
///
/// # Errors
/// Returns `None` if the iterator is empty.
///
/// # Panics
/// Panics if `bits_per_char` is greater than the number of bits in a `u16`.
#[inline]
pub fn new(mut compressed_data: I, reset_val: usize) -> Option<Self> {
pub fn new(mut compressed_data: I, bits_per_char: u8) -> Option<Self> {
assert!(usize::from(bits_per_char) <= std::mem::size_of::<u16>() * 8);

let reset_val_pow = bits_per_char - 1;
// (1 << 15) <= u16::MAX
let reset_val: u16 = 1 << reset_val_pow;

Some(DecompressContext {
val: compressed_data.next()?,
compressed_data,
Expand All @@ -36,7 +45,7 @@ where

#[inline]
pub fn read_bit(&mut self) -> Option<bool> {
let res = self.val & (self.position as u16);
let res = self.val & self.position;
self.position >>= 1;

if self.position == 0 {
Expand All @@ -47,11 +56,14 @@ where
Some(res != 0)
}

/// Read n bits.
///
/// `u32` is the return type as we expect all possible codes to be within that type's range.
#[inline]
pub fn read_bits(&mut self, n: usize) -> Option<u32> {
pub fn read_bits(&mut self, n: u8) -> Option<u32> {
let mut res = 0;
let max_power = 2_u32.pow(n as u32);
let mut power = 1;
let max_power: u32 = 1 << n;
let mut power: u32 = 1;
while power != max_power {
res |= u32::from(self.read_bit()?) * power;
power <<= 1;
Expand Down Expand Up @@ -162,69 +174,67 @@ pub fn decompress_from_uint8_array(compressed: &[u8]) -> Option<Vec<u16>> {
/// # Panics
/// Panics if `bits_per_char` is greater than the number of bits in a `u16`.
#[inline]
pub fn decompress_internal<I>(compressed: I, bits_per_char: usize) -> Option<Vec<u16>>
pub fn decompress_internal<I>(compressed: I, bits_per_char: u8) -> Option<Vec<u16>>
where
I: Iterator<Item = u16>,
{
assert!(bits_per_char <= std::mem::size_of::<u16>() * 8);

// u16::MAX < u32::MAX
let reset_val_pow = u32::try_from(bits_per_char).unwrap() - 1;
let reset_val = 2_usize.pow(reset_val_pow);
let mut ctx = match DecompressContext::new(compressed, reset_val) {
let mut ctx = match DecompressContext::new(compressed, bits_per_char) {
Some(ctx) => ctx,
None => return Some(Vec::new()),
};

let mut dictionary: Vec<Vec<u16>> = Vec::with_capacity(3);
let mut dictionary: Vec<Vec<u16>> = Vec::with_capacity(16);
for i in 0_u16..3_u16 {
dictionary.push(vec![i]);
}

let next = ctx.read_bits(2)?;
let first_entry: u16 = match next as u16 {
0 | 1 => {
let bits_to_read = (next * 8) + 8;
ctx.read_bits(bits_to_read as usize)? as u16
// u8::MAX > u2::MAX
let code = u8::try_from(ctx.read_bits(START_CODE_BITS)?).unwrap();
let first_entry = match code {
U8_CODE | U16_CODE => {
let bits_to_read = (code * 8) + 8;
// bits_to_read == 8 or 16 <= 16
u16::try_from(ctx.read_bits(bits_to_read)?).unwrap()
}
CLOSE_CODE => return Some(Vec::new()),
_ => return None,
};
dictionary.insert(3, vec![first_entry]);
dictionary.push(vec![first_entry]);

let mut w = vec![first_entry];
let mut result = vec![first_entry];
let mut num_bits = 3;
let mut enlarge_in = 4;
let mut dict_size = 4;
let mut num_bits: u8 = 3;
let mut enlarge_in: u64 = 4;
let mut entry;
loop {
let mut cc = ctx.read_bits(num_bits)? as usize;
match cc as u16 {
0 | 1 => {
let bits_to_read = (cc * 8) + 8;
let mut code = ctx.read_bits(num_bits)?;
match u8::try_from(code) {
Ok(code_u8 @ (U8_CODE | U16_CODE)) => {
let bits_to_read = (code_u8 * 8) + 8;
// if cc == 0 {
// if (errorCount++ > 10000) return "Error"; // TODO: Error logic
// }

let bits = ctx.read_bits(bits_to_read)? as u16;
// bits_to_read == 8 or 16 <= 16
let bits = u16::try_from(ctx.read_bits(bits_to_read)?).unwrap();
dictionary.push(vec![bits]);
dict_size += 1;
cc = dict_size - 1;
code = u32::try_from(dictionary.len() - 1).ok()?;
enlarge_in -= 1;
}
CLOSE_CODE => return Some(result),
Ok(CLOSE_CODE) => return Some(result),
_ => {}
}

if enlarge_in == 0 {
enlarge_in = 2_u32.pow(num_bits as u32);
enlarge_in = 1 << num_bits;
num_bits += 1;
}

if let Some(entry_value) = dictionary.get(cc) {
// Return error if code cannot be converted to dictionary index
let code_usize = usize::try_from(code).ok()?;
if let Some(entry_value) = dictionary.get(code_usize) {
entry = entry_value.clone();
} else if cc == dict_size {
} else if code_usize == dictionary.len() {
entry = w.clone();
entry.push(*w.first()?);
} else {
Expand All @@ -237,13 +247,12 @@ where
let mut to_be_inserted = w.clone();
to_be_inserted.push(*entry.first()?);
dictionary.push(to_be_inserted);
dict_size += 1;
enlarge_in -= 1;

w = entry;

if enlarge_in == 0 {
enlarge_in = 2_u32.pow(num_bits as u32);
enlarge_in = 1 << num_bits;
num_bits += 1;
}
}
Expand Down
4 changes: 4 additions & 0 deletions tests/valid_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,8 @@ fn valid_long_input_round() {
assert_eq!(a, b, "[index={}] {} != {}", i, a, b);
}
}
assert_eq!(compressed, js_compressed);

let decompressed = lz_str::decompress(&compressed).expect("decompression failed");
assert_eq!(decompressed, data);
}

0 comments on commit 69bf361

Please sign in to comment.