From 1bec4c20127356845b4c9a83e6a3a86c8b5902f0 Mon Sep 17 00:00:00 2001
From: Will Manning
Date: Fri, 5 Apr 2024 18:41:46 -0400
Subject: [PATCH] Optimize bitpacked `take` (#192)

---
 vortex-array/src/array/sparse/compute.rs    | 135 ++++++++++----
 vortex-array/src/array/sparse/mod.rs        |   3 +
 vortex-array/src/compress.rs                |   2 +-
 vortex-array/src/serde/mod.rs               |   2 +-
 vortex-array/src/validity/owned.rs          |   6 +-
 vortex-dict/src/compress.rs                 |   4 +-
 vortex-fastlanes/Cargo.toml                 |   4 +
 vortex-fastlanes/benches/bitpacking.rs      |   3 +
 vortex-fastlanes/benches/bitpacking_take.rs | 143 +++++++++++++++
 vortex-fastlanes/src/bitpacking/compress.rs |  77 ++------
 vortex-fastlanes/src/bitpacking/compute.rs  | 190 +++++++++++++++++---
 vortex-fastlanes/src/bitpacking/mod.rs      |  16 ++
 vortex-fastlanes/src/lib.rs                 |   1 +
 13 files changed, 459 insertions(+), 127 deletions(-)
 create mode 100644 vortex-fastlanes/benches/bitpacking_take.rs

diff --git a/vortex-array/src/array/sparse/compute.rs b/vortex-array/src/array/sparse/compute.rs
index cfdea6affd..cbc7a80323 100644
--- a/vortex-array/src/array/sparse/compute.rs
+++ b/vortex-array/src/array/sparse/compute.rs
@@ -129,10 +129,10 @@ impl TakeFn for SparseArray {
     fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
         let flat_indices = flatten_primitive(indices)?;
         // if we are taking a lot of values we should build a hashmap
-        let (positions, physical_take_indices) = if indices.len() > 512 {
-            take_map(self, flat_indices)?
+        let (positions, physical_take_indices) = if indices.len() > 128 {
+            take_map(self, &flat_indices)?
         } else {
-            take_search_sorted(self, flat_indices)?
+            take_search_sorted(self, &flat_indices)?
         };
 
         let taken_values = take(self.values(), &physical_take_indices)?;
@@ -149,7 +149,7 @@ impl TakeFn for SparseArray {
 
 fn take_map(
     array: &SparseArray,
-    indices: PrimitiveArray,
+    indices: &PrimitiveArray,
 ) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
     let indices_map: HashMap<u64, u64> = array
         .resolved_indices()
@@ -162,7 +162,7 @@ fn take_map(
         .iter()
         .map(|pi| *pi as u64)
         .enumerate()
-        .filter_map(|(i, pi)| indices_map.get(&pi).copied().map(|phy_idx| (i as u64, phy_idx)))
+        .filter_map(|(i, pi)| indices_map.get(&pi).map(|phy_idx| (i as u64, *phy_idx)))
         .unzip()
     });
     Ok((
@@ -173,8 +173,9 @@ fn take_map(
 
 fn take_search_sorted(
     array: &SparseArray,
-    indices: PrimitiveArray,
+    indices: &PrimitiveArray,
 ) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
+    // adjust the input indices (to take) by the internal index offset of the array
     let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| {
         indices.typed_data::<$P>()
             .iter()
             .map(|i| *i as usize + array.indices_offset())
             .collect_vec()
     });
 
     // TODO(robert): Use binary search instead of search_sorted + take and index validation to avoid extra work
-    let physical_indices = PrimitiveArray::from(
-        adjusted_indices
-            .iter()
-            .map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64))
-            .collect::<VortexResult<Vec<u64>>>()?,
-    );
+    // search_sorted for the adjusted indices (need to validate that they are an exact match still)
+    let physical_indices = adjusted_indices
+        .iter()
+        .map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64))
+        .collect::<VortexResult<Vec<u64>>>()?;
+
+    // filter out indices that are out of bounds, which would otherwise cause the take to fail
+    let (adjusted_indices, physical_indices): (Vec<usize>, Vec<u64>) = adjusted_indices
+        .iter()
+        .copied()
+        .zip_eq(physical_indices)
+        .filter(|(_, phys_idx)| *phys_idx < array.indices().len() as u64)
+        .unzip();
+
+    let physical_indices = PrimitiveArray::from(physical_indices);
+
     let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?;
-    let (positions, patch_indices): (Vec<u64>, Vec<u64>) = match_each_integer_ptype!(taken_indices.ptype(), |$P| {
+    let exact_matches: Vec<bool> = match_each_integer_ptype!(taken_indices.ptype(), |$P| {
         taken_indices
             .typed_data::<$P>()
             .iter()
-            .copied()
-            .enumerate()
             .zip_eq(adjusted_indices)
-            .zip_eq(physical_indices.typed_data::<u64>())
-            .filter(|(((_, taken_idx), orig_idx), _)| *taken_idx as usize == *orig_idx)
-            .map(|(((i, _), _), physical_idx)| (i as u64, *physical_idx))
-            .unzip()
+            .map(|(taken_idx, adj_idx)| *taken_idx as usize == adj_idx)
+            .collect()
     });
+    let (positions, patch_indices): (Vec<u64>, Vec<u64>) = physical_indices
+        .typed_data::<u64>()
+        .iter()
+        .enumerate()
+        .filter_map(|(i, phy_idx)| {
+            // search_sorted != binary search, so we need to filter out indices that weren't found
+            if exact_matches[i] {
+                Some((i as u64, *phy_idx))
+            } else {
+                None
+            }
+        })
+        .unzip();
     Ok((
         PrimitiveArray::from(positions),
         PrimitiveArray::from(patch_indices),
@@ -210,23 +228,30 @@ fn take_search_sorted(
 
 #[cfg(test)]
 mod test {
+    use itertools::Itertools;
     use vortex_schema::{DType, FloatWidth, Nullability};
 
     use crate::array::downcast::DowncastArrayBuiltin;
     use crate::array::primitive::PrimitiveArray;
+    use crate::array::sparse::compute::take_map;
     use crate::array::sparse::SparseArray;
     use crate::array::Array;
+    use crate::compute::as_contiguous::as_contiguous;
     use crate::compute::take::take;
     use crate::scalar::Scalar;
 
-    #[test]
-    fn sparse_take() {
-        let sparse = SparseArray::new(
+    fn sparse_array() -> SparseArray {
+        SparseArray::new(
             PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(),
             PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(),
             100,
             Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
-        );
+        )
+    }
+
+    #[test]
+    fn sparse_take() {
+        let sparse = sparse_array();
         let taken = take(&sparse, &PrimitiveArray::from(vec![0, 47, 47, 0, 99])).unwrap();
         assert_eq!(
             taken
@@ -248,12 +273,7 @@ mod test {
 
     #[test]
     fn nonexistent_take() {
-        let sparse = SparseArray::new(
-            PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(),
-            PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(),
-            100,
-            Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
-        );
+        let sparse = sparse_array();
         let taken = take(&sparse, &PrimitiveArray::from(vec![69])).unwrap();
         assert_eq!(
             taken
@@ -275,12 +295,7 @@ mod test {
 
     #[test]
     fn ordered_take() {
-        let sparse = SparseArray::new(
-            PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(),
-            PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(),
-            100,
-            Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
-        );
+        let sparse = sparse_array();
         let taken = take(&sparse, &PrimitiveArray::from(vec![69, 37])).unwrap();
         assert_eq!(
             taken
@@ -300,4 +315,54 @@ mod test {
         );
         assert_eq!(taken.len(), 2);
     }
+
+    #[test]
+    fn take_slices_and_reassemble() {
+        let sparse = sparse_array();
+        let indices: PrimitiveArray = (0u64..10).collect_vec().into();
+        let slices = (0..10)
+            .map(|i| sparse.slice(i * 10, (i + 1) * 10).unwrap())
+            .collect_vec();
+
+        let taken = slices
+            .iter()
+            .map(|s| take(s, &indices).unwrap())
+            .collect_vec();
+        for i in [1, 2, 5, 6, 7, 8] {
+            assert_eq!(taken[i].as_sparse().indices().len(), 0);
+        }
+        for i in [0, 3, 4, 9] {
+            assert_eq!(taken[i].as_sparse().indices().len(), 1);
+        }
+
+        let contiguous = as_contiguous(&taken).unwrap();
+        assert_eq!(
+            contiguous
+                .as_sparse()
+                .indices()
+                .as_primitive()
+                .typed_data::<u64>(),
+            [0u64, 7, 7, 9] // relative offsets
+        );
+        assert_eq!(
+            contiguous
+                .as_sparse()
+                .values()
+                .as_primitive()
+                .typed_data::<f64>(),
+            sparse.values().as_primitive().typed_data()
+        );
+    }
+
+    #[test]
+    fn test_take_map() {
+        let sparse = sparse_array();
+        let indices = PrimitiveArray::from((0u64..100).collect_vec());
+        let (positions, patch_indices) = take_map(&sparse, &indices).unwrap();
+        assert_eq!(
+            positions.typed_data::<u64>(),
+            sparse.indices().as_primitive().typed_data()
+        );
+        assert_eq!(patch_indices.typed_data::<u64>(), [0u64, 1, 2, 3]);
+    }
 }
diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs
index 9321c567a9..fd88430e11 100644
--- a/vortex-array/src/array/sparse/mod.rs
+++ b/vortex-array/src/array/sparse/mod.rs
@@ -96,6 +96,9 @@ impl SparseArray {
         // TODO(ngates): replace this with a binary search that tells us if we get an exact match.
         let idx = search_sorted(self.indices(), true_index, SearchSortedSide::Left)?;
+        if idx >= self.indices().len() {
+            return Ok(None);
+        }
 
         // If the value at this index is equal to the true index, then it exists in the
         // indices array.
diff --git a/vortex-array/src/compress.rs b/vortex-array/src/compress.rs
index dbbfd0d6e1..6e9ab3ff4c 100644
--- a/vortex-array/src/compress.rs
+++ b/vortex-array/src/compress.rs
@@ -219,7 +219,7 @@ impl CompressCtx {
                 ValidityView::Valid(_) | ValidityView::Invalid(_) => {
                     Ok(Some(validity.to_owned_view()))
                 }
-                ValidityView::Array(a) => Ok(Some(Validity::array(self.compress(a, None)?))),
+                ValidityView::Array(a) => Ok(Some(Validity::array(self.compress(a, None)?)?)),
             }
         } else {
             Ok(None)
diff --git a/vortex-array/src/serde/mod.rs b/vortex-array/src/serde/mod.rs
index c9b22a8e42..4735713f67 100644
--- a/vortex-array/src/serde/mod.rs
+++ b/vortex-array/src/serde/mod.rs
@@ -223,7 +223,7 @@ impl<'a> ReadCtx<'a> {
                 [1u8] => Ok(Some(Validity::Invalid(self.read_usize()?))),
                 [2u8] => Ok(Some(Validity::array(
                     self.with_schema(&Validity::DTYPE).read()?,
-                ))),
+                )?)),
                 _ => panic!("Invalid validity tag"),
             }
         } else {
diff --git a/vortex-array/src/validity/owned.rs b/vortex-array/src/validity/owned.rs
index 6dcfb91b18..3e5733fbca 100644
--- a/vortex-array/src/validity/owned.rs
+++ b/vortex-array/src/validity/owned.rs
@@ -28,11 +28,11 @@ pub enum Validity {
 impl Validity {
     pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
 
-    pub fn array(array: ArrayRef) -> Self {
+    pub fn array(array: ArrayRef) -> VortexResult<Self> {
         if !matches!(array.dtype(), &Validity::DTYPE) {
-            panic!("Validity array must be of type bool");
+            vortex_bail!("Validity array must be of type bool");
         }
-        Self::Array(array)
+        Ok(Self::Array(array))
     }
 
     pub fn try_from_logical(
diff --git a/vortex-dict/src/compress.rs b/vortex-dict/src/compress.rs
index 9ed1eb1474..d1f7332eae 100644
--- a/vortex-dict/src/compress.rs
+++ b/vortex-dict/src/compress.rs
@@ -140,7 +140,7 @@ pub fn dict_encode_typed_primitive(
             validity.push(false);
             validity.extend(vec![true; values.len() - 1]);
 
-            Some(Validity::array(BoolArray::from(validity).into_array()))
+            Some(Validity::array(BoolArray::from(validity).into_array()).unwrap())
         } else {
             None
         };
@@ -222,7 +222,7 @@ where
             validity.push(false);
             validity.extend(vec![true; offsets.len() - 2]);
 
-            Some(Validity::array(BoolArray::from(validity).into_array()))
+            Some(Validity::array(BoolArray::from(validity).into_array()).unwrap())
         } else {
             None
         };
diff --git a/vortex-fastlanes/Cargo.toml b/vortex-fastlanes/Cargo.toml
index d64989076e..24a265a240 100644
--- a/vortex-fastlanes/Cargo.toml
+++ b/vortex-fastlanes/Cargo.toml
@@ -33,3 +33,7 @@ simplelog = { workspace = true }
 [[bench]]
 name = "bitpacking"
 harness = false
+
+[[bench]]
+name = "bitpacking_take"
+harness = false
diff --git a/vortex-fastlanes/benches/bitpacking.rs b/vortex-fastlanes/benches/bitpacking.rs
index a9f3ee649b..465669ee84 100644
--- a/vortex-fastlanes/benches/bitpacking.rs
+++ b/vortex-fastlanes/benches/bitpacking.rs
@@ -29,6 +29,9 @@ fn pack_unpack(c: &mut Criterion) {
     });
 
     let packed = bitpack_primitive(&values, bits);
+    let unpacked = unpack_primitive::<u32>(&packed, bits, 0, values.len());
+    assert_eq!(unpacked, values);
+
     c.bench_function("unpack_1M", |b| {
         b.iter(|| black_box(unpack_primitive::<u32>(&packed, bits, 0, values.len())));
     });
diff --git a/vortex-fastlanes/benches/bitpacking_take.rs b/vortex-fastlanes/benches/bitpacking_take.rs
new file mode 100644
index 0000000000..670347de8b
--- /dev/null
+++ b/vortex-fastlanes/benches/bitpacking_take.rs
@@ -0,0 +1,143 @@
+use std::sync::Arc;
+
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use itertools::Itertools;
+use rand::distributions::Uniform;
+use rand::{thread_rng, Rng};
+use vortex::array::downcast::DowncastArrayBuiltin;
+use vortex::array::primitive::PrimitiveArray;
+use vortex::compress::{CompressConfig, CompressCtx, EncodingCompression};
+use vortex::compute::take::take;
+use vortex::encoding::EncodingRef;
+use vortex_fastlanes::{BitPackedEncoding, DowncastFastlanes};
+
+fn values(len: usize, bits: usize) -> Vec<u32> {
+    let rng = thread_rng();
+    let range = Uniform::new(0_u32, 2_u32.pow(bits as u32));
+    rng.sample_iter(range).take(len).collect()
+}
+
+fn bench_take(c: &mut Criterion) {
+    let cfg = CompressConfig::new().with_enabled([&BitPackedEncoding as EncodingRef]);
+    let ctx = CompressCtx::new(Arc::new(cfg));
+
+    let values = values(1_000_000, 8);
+    let uncompressed = PrimitiveArray::from(values.clone());
+    let packed = BitPackedEncoding {}
+        .compress(&uncompressed, None, ctx)
+        .unwrap();
+
+    let stratified_indices: PrimitiveArray = (0..10).map(|i| i * 10_000).collect::<Vec<_>>().into();
+    c.bench_function("take_10_stratified", |b| {
+        b.iter(|| black_box(take(&packed, &stratified_indices).unwrap()));
+    });
+
+    let contiguous_indices: PrimitiveArray = (0..10).collect::<Vec<_>>().into();
+    c.bench_function("take_10_contiguous", |b| {
+        b.iter(|| black_box(take(&packed, &contiguous_indices).unwrap()));
+    });
+
+    let rng = thread_rng();
+    let range = Uniform::new(0, values.len());
+    let random_indices: PrimitiveArray = rng
+        .sample_iter(range)
+        .take(10_000)
+        .map(|i| i as u32)
+        .collect_vec()
+        .into();
+    c.bench_function("take_10K_random", |b| {
+        b.iter(|| black_box(take(&packed, &random_indices).unwrap()));
+    });
+
+    let contiguous_indices: PrimitiveArray = (0..10_000).collect::<Vec<_>>().into();
+    c.bench_function("take_10K_contiguous", |b| {
+        b.iter(|| black_box(take(&packed, &contiguous_indices).unwrap()));
+    });
+}
+
+fn bench_patched_take(c: &mut Criterion) {
+    let cfg = CompressConfig::new().with_enabled([&BitPackedEncoding as EncodingRef]);
+    let ctx = CompressCtx::new(Arc::new(cfg));
+
+    let big_base2 = 1048576;
+    let num_exceptions = 10000;
+    let values = (0u32..big_base2 + num_exceptions).collect_vec();
+
+    let uncompressed = PrimitiveArray::from(values.clone());
+    let packed = BitPackedEncoding {}
+        .compress(&uncompressed, None, ctx)
+        .unwrap();
+    let packed = packed.as_bitpacked();
+    assert!(packed.patches().is_some());
+    assert_eq!(
+        packed.patches().unwrap().as_sparse().values().len(),
+        num_exceptions as usize
+    );
+
+    let stratified_indices: PrimitiveArray = (0..10).map(|i| i * 10_000).collect::<Vec<_>>().into();
+    c.bench_function("patched_take_10_stratified", |b| {
+        b.iter(|| black_box(take(packed, &stratified_indices).unwrap()));
+    });
+
+    let contiguous_indices: PrimitiveArray = (0..10).collect::<Vec<_>>().into();
+    c.bench_function("patched_take_10_contiguous", |b| {
+        b.iter(|| black_box(take(packed, &contiguous_indices).unwrap()));
+    });
+
+    let rng = thread_rng();
+    let range = Uniform::new(0, values.len());
+    let random_indices: PrimitiveArray = rng
+        .sample_iter(range)
+        .take(10_000)
+        .map(|i| i as u32)
+        .collect_vec()
+        .into();
+    c.bench_function("patched_take_10K_random", |b| {
+        b.iter(|| black_box(take(packed, &random_indices).unwrap()));
+    });
+
+    let not_patch_indices: PrimitiveArray = (0u32..num_exceptions)
+        .cycle()
+        .take(10000)
+        .collect_vec()
+        .into();
+    c.bench_function("patched_take_10K_contiguous_not_patches", |b| {
+        b.iter(|| black_box(take(packed, &not_patch_indices).unwrap()));
+    });
+
+    let patch_indices: PrimitiveArray = (big_base2..big_base2 + num_exceptions)
+        .cycle()
+        .take(10000)
+        .collect_vec()
+        .into();
+    c.bench_function("patched_take_10K_contiguous_patches", |b| {
+        b.iter(|| black_box(take(packed, &patch_indices).unwrap()));
+    });
+
+    // There are currently 2 magic parameters of note:
+    // 1. the threshold at which sparse take will switch from search_sorted to map (currently 128)
+    // 2. the threshold at which bitpacked take will switch from bulk patching to per chunk patching (currently 64)
+    //
+    // There are thus 3 cases to consider:
+    // 1. N < 64 per chunk, covered by patched_take_10K_random
+    // 2. N > 128 per chunk, covered by patched_take_10K_contiguous_*
+    // 3. 64 < N < 128 per chunk, which is what we're trying to cover here (with 100 per chunk).
+    //
+    // As a result of the above, we get both search_sorted and per chunk patching, almost entirely on patches.
+    // I've iterated on both thresholds (1) and (2) using this collection of benchmarks, and those
+    // were roughly the best values that I found.
+    let per_chunk_count = 100;
+    let adversarial_indices: PrimitiveArray = (0..(num_exceptions + 1024) / 1024)
+        .cycle()
+        .map(|chunk_idx| big_base2 - 1024 + chunk_idx * 1024)
+        .flat_map(|base_idx| (base_idx..(base_idx + per_chunk_count)))
+        .take(10000)
+        .collect_vec()
+        .into();
+    c.bench_function("patched_take_10K_adversarial", |b| {
+        b.iter(|| black_box(take(packed, &adversarial_indices).unwrap()));
+    });
+}
+
+criterion_group!(benches, bench_take, bench_patched_take);
+criterion_main!(benches);
diff --git a/vortex-fastlanes/src/bitpacking/compress.rs b/vortex-fastlanes/src/bitpacking/compress.rs
index 4ffd8570c7..f2df43cea0 100644
--- a/vortex-fastlanes/src/bitpacking/compress.rs
+++ b/vortex-fastlanes/src/bitpacking/compress.rs
@@ -9,7 +9,7 @@ use vortex::compress::{CompressConfig, CompressCtx, EncodingCompression};
 use vortex::compute::cast::cast;
 use vortex::compute::flatten::flatten_primitive;
 use vortex::match_each_integer_ptype;
-use vortex::ptype::PType::{I16, I32, I64, I8, U16, U32, U64, U8};
+use vortex::ptype::PType::U8;
 use vortex::ptype::{NativePType, PType};
 use vortex::scalar::{ListScalarVec, Scalar};
 use vortex::stats::Stat;
@@ -18,7 +18,7 @@ use vortex::view::ToOwnedView;
 use vortex_error::{vortex_bail, vortex_err, VortexResult};
 
 use crate::downcast::DowncastFastlanes;
-use crate::{BitPackedArray, BitPackedEncoding};
+use crate::{match_integers_by_width, BitPackedArray, BitPackedEncoding};
 
 impl EncodingCompression for BitPackedEncoding {
     fn cost(&self) -> u8 {
@@ -75,7 +75,7 @@ impl EncodingCompression for BitPackedEncoding {
             return Ok(parray.clone().into_array());
         }
 
-        let packed = bitpack(parray, bit_width);
+        let packed = bitpack(parray, bit_width)?;
 
         let validity = ctx.compress_validity(parray.validity())?;
 
@@ -101,17 +101,13 @@ impl EncodingCompression for BitPackedEncoding {
     }
 }
 
-fn bitpack(parray: &PrimitiveArray, bit_width: usize) -> ArrayRef {
+fn bitpack(parray: &PrimitiveArray, bit_width: usize) -> VortexResult<ArrayRef> {
     // We know the min is > 0, so it's safe to re-interpret signed integers as unsigned.
     // TODO(ngates): we should implement this using a vortex cast to centralize this hack.
-    let bytes = match parray.ptype() {
-        I8 | U8 => bitpack_primitive(parray.buffer().typed_data::<u8>(), bit_width),
-        I16 | U16 => bitpack_primitive(parray.buffer().typed_data::<u16>(), bit_width),
-        I32 | U32 => bitpack_primitive(parray.buffer().typed_data::<u32>(), bit_width),
-        I64 | U64 => bitpack_primitive(parray.buffer().typed_data::<u64>(), bit_width),
-        _ => panic!("Unsupported ptype {}", parray.ptype()),
-    };
-    PrimitiveArray::from(bytes).into_array()
+    let bytes = match_integers_by_width!(parray.ptype(), |$P| {
+        bitpack_primitive(parray.buffer().typed_data::<$P>(), bit_width)
+    });
+    Ok(PrimitiveArray::from(bytes).into_array())
 }
 
 pub fn bitpack_primitive<T: NativePType + TryBitPack>(array: &[T], bit_width: usize) -> Vec<u8> {
@@ -169,25 +165,12 @@ pub fn unpack(array: &BitPackedArray) -> VortexResult<PrimitiveArray> {
     let encoded = flatten_primitive(&cast(array.encoded(), U8.into())?)?;
     let ptype: PType = array.dtype().try_into()?;
 
-    let mut unpacked = match ptype {
-        I8 | U8 => PrimitiveArray::from_nullable(
-            unpack_primitive::<u8>(encoded.typed_data::<u8>(), bit_width, offset, length),
-            array.validity().to_owned_view(),
-        ),
-        I16 | U16 => PrimitiveArray::from_nullable(
-            unpack_primitive::<u16>(encoded.typed_data::<u8>(), bit_width, offset, length),
-            array.validity().to_owned_view(),
-        ),
-        I32 | U32 => PrimitiveArray::from_nullable(
-            unpack_primitive::<u32>(encoded.typed_data::<u8>(), bit_width, offset, length),
-            array.validity().to_owned_view(),
-        ),
-        I64 | U64 => PrimitiveArray::from_nullable(
-            unpack_primitive::<u64>(encoded.typed_data::<u8>(), bit_width, offset, length),
-            array.validity().to_owned_view(),
-        ),
-        _ => panic!("Unsupported ptype {}", ptype),
-    };
+    let mut unpacked = match_integers_by_width!(ptype, |$P| {
+        PrimitiveArray::from_nullable(
+            unpack_primitive::<$P>(encoded.typed_data::<u8>(), bit_width, offset, length),
+            array.validity().to_owned_view(),
+        )
+    });
 
     // Cast to signed if necessary
     if ptype.is_signed_int() {
@@ -280,35 +263,11 @@ pub(crate) fn unpack_single(array: &BitPackedArray, index: usize) -> VortexResul
     let ptype: PType = array.dtype().try_into()?;
     let index_in_encoded = index + array.offset();
 
-    let scalar: Scalar = unsafe {
-        match ptype {
-            I8 | U8 => unpack_single_primitive::<u8>(
-                encoded.typed_data::<u8>(),
-                bit_width,
-                index_in_encoded,
-            )
-            .map(|v| v.into()),
-            I16 | U16 => unpack_single_primitive::<u16>(
-                encoded.typed_data::<u8>(),
-                bit_width,
-                index_in_encoded,
-            )
-            .map(|v| v.into()),
-            I32 | U32 => unpack_single_primitive::<u32>(
-                encoded.typed_data::<u8>(),
-                bit_width,
-                index_in_encoded,
-            )
-            .map(|v| v.into()),
-            I64 | U64 => unpack_single_primitive::<u64>(
-                encoded.typed_data::<u8>(),
-                bit_width,
-                index_in_encoded,
-            )
-            .map(|v| v.into()),
-            _ => vortex_bail!("Unsupported ptype {}", ptype),
-        }?
-    };
+    let scalar: Scalar = match_integers_by_width!(ptype, |$P| {
+        unsafe {
+            unpack_single_primitive::<$P>(encoded.typed_data::<u8>(), bit_width, index_in_encoded).map(|v| v.into())
+        }
+    })?;
 
     // Cast to fix signedness and nullability
     scalar.cast(array.dtype())
diff --git a/vortex-fastlanes/src/bitpacking/compute.rs b/vortex-fastlanes/src/bitpacking/compute.rs
index 1b69074d2e..42faa50353 100644
--- a/vortex-fastlanes/src/bitpacking/compute.rs
+++ b/vortex-fastlanes/src/bitpacking/compute.rs
@@ -1,20 +1,23 @@
 use std::cmp::min;
 
+use fastlanez::TryBitPack;
 use itertools::Itertools;
+use vortex::array::downcast::DowncastArrayBuiltin;
 use vortex::array::primitive::PrimitiveArray;
+use vortex::array::sparse::SparseArray;
 use vortex::array::{Array, ArrayRef};
-use vortex::compute::as_contiguous::as_contiguous;
 use vortex::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray};
 use vortex::compute::scalar_at::{scalar_at, ScalarAtFn};
 use vortex::compute::take::{take, TakeFn};
 use vortex::compute::ArrayCompute;
 use vortex::match_each_integer_ptype;
+use vortex::ptype::NativePType;
 use vortex::scalar::Scalar;
-use vortex_error::{vortex_err, VortexResult};
+use vortex::validity::OwnedValidity;
+use vortex_error::{vortex_bail, vortex_err, VortexResult};
 
 use crate::bitpacking::compress::{unpack, unpack_single};
-use crate::downcast::DowncastFastlanes;
-use crate::BitPackedArray;
+use crate::{match_integers_by_width, unpack_single_primitive, BitPackedArray};
 
 impl ArrayCompute for BitPackedArray {
     fn flatten(&self) -> Option<&dyn FlattenFn> {
@@ -53,32 +56,118 @@ impl ScalarAtFn for BitPackedArray {
 
 impl TakeFn for BitPackedArray {
     fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
-        let prim_indices = flatten_primitive(indices)?;
-        // Group indices into 1024 chunks and relativise them to the beginning of each chunk
-        let relative_indices: Vec<(usize, Vec<u16>)> = match_each_integer_ptype!(prim_indices.ptype(), |$P| {
-            let grouped_indices = prim_indices
-                .typed_data::<$P>()
-                .iter()
-                .group_by(|idx| (**idx / 1024) as usize);
-            grouped_indices
-                .into_iter()
-                .map(|(k, g)| (k, g.map(|idx| (*idx % 1024) as u16).collect()))
-                .collect()
+        let indices = flatten_primitive(indices)?;
+        let ptype = self.dtype().try_into()?;
+        let taken_validity = self.validity().map(|v| v.take(&indices)).transpose()?;
+        let taken = match_integers_by_width!(ptype, |$T| {
+            PrimitiveArray::from_nullable(take_primitive::<$T>(self, &indices)?, taken_validity)
         });
+        Ok(taken.reinterpret_cast(ptype).into_array())
+    }
+}
 
-        let taken = relative_indices
+fn take_primitive<T: NativePType + TryBitPack>(
+    array: &BitPackedArray,
+    indices: &PrimitiveArray,
+) -> VortexResult<Vec<T>> {
+    // Group indices into 1024-element chunks and relativise them to the beginning of each chunk
+    let relative_indices: Vec<(usize, Vec<u16>)> = match_each_integer_ptype!(indices.ptype(), |$P| {
+        indices
+            .typed_data::<$P>()
+            .iter()
+            .group_by(|idx| (**idx / 1024) as usize)
             .into_iter()
-            .map(|(chunk, offsets)| {
-                let sliced = self.slice(chunk * 1024, min((chunk + 1) * 1024, self.len()))?;
-
-                take(
-                    &unpack(sliced.as_bitpacked())?,
-                    &PrimitiveArray::from(offsets),
-                )
-            })
-            .collect::<VortexResult<Vec<_>>>()?;
-        as_contiguous(&taken)
+            .map(|(k, g)| (k, g.map(|idx| (*idx % 1024) as u16).collect()))
+            .collect()
+    });
+
+    let bit_width = array.bit_width();
+    let packed = flatten_primitive(array.encoded())?;
+    let packed = packed.typed_data::<u8>();
+
+    let patches = array
+        .patches()
+        .map(|p| {
+            p.maybe_sparse()
+                .ok_or(vortex_err!("Only sparse patches are currently supported!"))
+        })
+        .transpose()?;
+
+    // if we have a small number of relatively large batches, we gain by slicing and then patching inside the loop
+    // if we have a large number of relatively small batches, the overhead isn't worth it, and we're better off with a bulk patch
+    // roughly, if we have an average of less than 64 elements per batch, we prefer bulk patching
+    let prefer_bulk_patch = relative_indices.len() * 64 > indices.len();
+
+    // assuming the buffer is already allocated (which will happen at most once)
+    // then unpacking all 1024 elements takes ~8.8x as long as unpacking a single element
+    // see https://github.com/fulcrum-so/vortex/pull/190#issue-2223752833
+    // however, the gap should be smaller with larger registers (e.g., AVX-512) vs the 128 bit
+    // ones on M2 MacBook Air.
+    let unpack_chunk_threshold = 8;
+
+    let mut output = Vec::with_capacity(indices.len());
+    let mut buffer: Vec<T> = Vec::new();
+    for (chunk, offsets) in relative_indices {
+        let packed_chunk = &packed[chunk * 128 * bit_width..][..128 * bit_width];
+        if offsets.len() > unpack_chunk_threshold {
+            buffer.clear();
+            TryBitPack::try_unpack_into(packed_chunk, bit_width, &mut buffer)
+                .map_err(|_| vortex_err!("Unsupported bit width {}", bit_width))?;
+            for index in &offsets {
+                output.push(buffer[*index as usize]);
+            }
+        } else {
+            for index in &offsets {
+                output.push(unsafe {
+                    unpack_single_primitive::<T>(packed_chunk, bit_width, *index as usize)?
+                });
+            }
+        }
+
+        if !prefer_bulk_patch {
+            if let Some(patches) = patches {
+                let patches_slice =
+                    patches.slice(chunk * 1024, min((chunk + 1) * 1024, patches.len()))?;
+                let patches_slice = patches_slice
+                    .maybe_sparse()
+                    .ok_or(vortex_err!("Only sparse patches are currently supported!"))?;
+                let offsets = PrimitiveArray::from(offsets);
+                do_patch_for_take_primitive(patches_slice, &offsets, &mut output)?;
+            }
+        }
     }
+
+    if prefer_bulk_patch {
+        if let Some(patches) = patches {
+            do_patch_for_take_primitive(patches, indices, &mut output)?;
+        }
+    }
+
+    Ok(output)
+}
+
+fn do_patch_for_take_primitive<T: NativePType>(
+    patches: &SparseArray,
+    indices: &PrimitiveArray,
+    output: &mut [T],
+) -> VortexResult<()> {
+    let taken_patches = take(patches, indices)?;
+    let taken_patches = taken_patches
+        .maybe_sparse()
+        .ok_or(vortex_err!("Only sparse patches are currently supported!"))?;
+
+    let base_index = output.len() - indices.len();
+    let output_patches = flatten_primitive(taken_patches.values())?;
+    taken_patches
+        .resolved_indices()
+        .iter()
+        .map(|idx| base_index + *idx)
+        .zip_eq(output_patches.typed_data::<T>())
+        .for_each(|(idx, val)| {
+            output[idx] = *val;
+        });
+
+    Ok(())
 }
 
 #[cfg(test)]
@@ -86,6 +175,8 @@ mod test {
     use std::sync::Arc;
 
     use itertools::Itertools;
+    use rand::distributions::Uniform;
+    use rand::{thread_rng, Rng};
     use vortex::array::downcast::DowncastArrayBuiltin;
     use vortex::array::primitive::{PrimitiveArray, PrimitiveEncoding};
     use vortex::array::Array;
@@ -112,6 +203,53 @@ mod test {
         assert_eq!(res_bytes, &[0, 62, 31, 33, 9, 18]);
     }
 
+    #[test]
+    fn take_random_indices() {
+        let cfg = CompressConfig::new().with_enabled([&BitPackedEncoding as EncodingRef]);
+        let ctx = CompressCtx::new(Arc::new(cfg));
+
+        let num_patches: usize = 128;
+        let values = (0..u16::MAX as u32 + num_patches as u32).collect::<Vec<_>>();
+        let uncompressed = PrimitiveArray::from(values.clone());
+        let packed = BitPackedEncoding {}
+            .compress(&uncompressed, None, ctx)
+            .unwrap();
+        let packed = packed.as_bitpacked();
+        assert!(packed.patches().is_some());
+
+        let patches = packed.patches().unwrap().as_sparse();
+        assert_eq!(
+            patches.resolved_indices(),
+            ((values.len() + 1 - num_patches)..values.len()).collect_vec()
+        );
+
+        let rng = thread_rng();
+        let range = Uniform::new(0, values.len());
+        let random_indices: PrimitiveArray = rng
+            .sample_iter(range)
+            .take(10_000)
+            .map(|i| i as u32)
+            .collect_vec()
+            .into();
+        let taken = take(packed, &random_indices).unwrap();
+
+        // sanity check
+        random_indices
+            .typed_data::<u32>()
+            .iter()
+            .enumerate()
+            .for_each(|(ti, i)| {
+                assert_eq!(
+                    scalar_at(packed, *i as usize).unwrap(),
+                    Scalar::from(values[*i as usize])
+                );
+                assert_eq!(
+                    scalar_at(&taken, ti).unwrap(),
+                    Scalar::from(values[*i as usize])
+                );
+            });
+    }
+
     #[test]
     fn test_scalar_at() {
         let cfg = CompressConfig::new().with_enabled([&BitPackedEncoding as EncodingRef]);
diff --git a/vortex-fastlanes/src/bitpacking/mod.rs b/vortex-fastlanes/src/bitpacking/mod.rs
index 88639cccc9..5a6de529cd 100644
--- a/vortex-fastlanes/src/bitpacking/mod.rs
+++ b/vortex-fastlanes/src/bitpacking/mod.rs
@@ -228,6 +228,22 @@ impl Encoding for BitPackedEncoding {
     }
 }
 
+#[macro_export]
+macro_rules! match_integers_by_width {
+    ($self:expr, | $_:tt $enc:ident | $($body:tt)*) => ({
+        macro_rules! __with__ {( $_ $enc:ident ) => ( $($body)* )}
+        use vortex::ptype::PType;
+        use vortex_error::vortex_bail;
+        match $self {
+            PType::I8 | PType::U8 => __with__! { u8 },
+            PType::I16 | PType::U16 => __with__! { u16 },
+            PType::I32 | PType::U32 => __with__! { u32 },
+            PType::I64 | PType::U64 => __with__! { u64 },
+            _ => vortex_bail!("Unsupported ptype {}", $self),
+        }
+    })
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::Arc;
diff --git a/vortex-fastlanes/src/lib.rs b/vortex-fastlanes/src/lib.rs
index 3c200e050c..198d1c6a31 100644
--- a/vortex-fastlanes/src/lib.rs
+++ b/vortex-fastlanes/src/lib.rs
@@ -3,6 +3,7 @@
 
 pub use bitpacking::*;
 pub use delta::*;
+pub use downcast::*;
 use linkme::distributed_slice;
 pub use r#for::*;
 use vortex::encoding::{EncodingRef, ENCODINGS};
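
Note (illustrative, not part of the patch): the heart of `take_primitive` above is grouping the incoming indices by 1024-element FastLanes chunk, relativising them to the chunk start, and then choosing per chunk between unpacking the whole chunk (many hits) or unpacking single elements (few hits). The sketch below reproduces just that grouping and threshold decision in plain Rust, assuming `u64` indices and using a `BTreeMap` instead of the itertools `group_by` used in the patch; `plan_take` is a hypothetical helper, not part of the Vortex API.

use std::collections::BTreeMap;

// Group indices by 1024-element chunk and pick a decode strategy per chunk,
// mirroring the `unpack_chunk_threshold` heuristic in `take_primitive`.
fn plan_take(indices: &[u64], unpack_chunk_threshold: usize) -> Vec<(usize, usize, bool)> {
    // chunk id -> in-chunk offsets, relativised to the start of the 1024-element chunk
    let mut by_chunk: BTreeMap<usize, Vec<u16>> = BTreeMap::new();
    for &idx in indices {
        by_chunk
            .entry((idx / 1024) as usize)
            .or_default()
            .push((idx % 1024) as u16);
    }

    // For each chunk, report (chunk id, number of requested offsets, bulk-unpack?).
    by_chunk
        .into_iter()
        .map(|(chunk, offsets)| {
            let bulk = offsets.len() > unpack_chunk_threshold;
            (chunk, offsets.len(), bulk)
        })
        .collect()
}

fn main() {
    // 2 hits in chunk 0 (single-element unpacks), 9 hits in chunk 9 (bulk unpack).
    let indices: Vec<u64> = [0, 5].into_iter().chain(10_000..10_009).collect();
    for (chunk, n, bulk) in plan_take(&indices, 8) {
        println!("chunk {chunk}: {n} offsets, bulk unpack = {bulk}");
    }
}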