From 9c3e82cfb4337e3241681e14e82bb8bb4bbe70de Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Wed, 6 Mar 2024 16:38:18 +0000 Subject: [PATCH] Fix dictionary encoding (#81) --- vortex-dict/src/compress.rs | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/vortex-dict/src/compress.rs b/vortex-dict/src/compress.rs index ce6ef03372..bcb2d861fb 100644 --- a/vortex-dict/src/compress.rs +++ b/vortex-dict/src/compress.rs @@ -225,7 +225,7 @@ fn dict_encode_typed_varbin( validity: Option<&dyn Array>, ) -> (PrimitiveArray, VarBinArray) where - O: NativePType + Unsigned + FromPrimitive, + O: NativePType + Unsigned + FromPrimitive + AsPrimitive, K: NativePType + Unsigned + FromPrimitive + AsPrimitive, V: Fn(usize) -> U, U: AsRef<[u8]>, @@ -242,7 +242,7 @@ where let byte_ref = byte_val.as_ref(); let value_hash = hasher.hash_one(byte_ref); let raw_entry = lookup_dict.raw_entry_mut().from_hash(value_hash, |idx| { - byte_ref == value_lookup(idx.as_()).as_ref() + byte_ref == bytes_at_primitive(offsets.as_slice(), bytes.as_slice(), idx.as_()) }); let code: K = match raw_entry { @@ -252,7 +252,11 @@ where bytes.extend_from_slice(byte_ref); offsets.push(::from_usize(bytes.len()).unwrap()); vac.insert_with_hasher(value_hash, next_code, (), |idx| { - hasher.hash_one(value_lookup(idx.as_()).as_ref()) + hasher.hash_one(bytes_at_primitive( + offsets.as_slice(), + bytes.as_slice(), + idx.as_(), + )) }); next_code } @@ -272,6 +276,7 @@ where #[cfg(test)] mod test { + use vortex::array::downcast::DowncastArrayBuiltin; use vortex::array::primitive::PrimitiveArray; use vortex::array::varbin::VarBinArray; use vortex::compute::scalar_at::scalar_at; @@ -359,4 +364,19 @@ mod test { "again" ); } + + #[test] + fn repeated_values() { + let arr = VarBinArray::from(vec!["a", "a", "b", "b", "a", "b", "a", "b"]); + let (codes, values) = dict_encode_varbin(&arr); + assert_eq!( + values.bytes().as_primitive().typed_data::(), + "ab".as_bytes() + ); + assert_eq!( + values.offsets().as_primitive().typed_data::(), + &[0, 1, 2] + ); + assert_eq!(codes.typed_data::(), &[0u8, 0, 1, 1, 0, 1, 0, 1]); + } }