diff --git a/vortex-array/src/array/sparse/compress.rs b/vortex-array/src/array/sparse/compress.rs index 97dd58e806..43591be8de 100644 --- a/vortex-array/src/array/sparse/compress.rs +++ b/vortex-array/src/array/sparse/compress.rs @@ -32,7 +32,7 @@ impl EncodingCompression for SparseEncoding { ctx.named("values") .compress(sparse_array.values(), sparse_like.map(|sa| sa.values()))?, sparse_array.len(), - sparse_array.fill_value.clone(), + sparse_array.fill_value().clone(), ) .into_array()) } diff --git a/vortex-array/src/array/sparse/compute.rs b/vortex-array/src/array/sparse/compute.rs index 12c94f1bfb..d484d8af7d 100644 --- a/vortex-array/src/array/sparse/compute.rs +++ b/vortex-array/src/array/sparse/compute.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use arrow_buffer::BooleanBufferBuilder; use itertools::Itertools; use vortex_error::{vortex_bail, VortexResult}; @@ -9,10 +11,12 @@ use crate::array::{Array, ArrayRef}; use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; +use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; +use crate::compute::take::{take, TakeFn}; use crate::compute::ArrayCompute; -use crate::match_each_native_ptype; use crate::ptype::NativePType; use crate::scalar::Scalar; +use crate::{match_each_integer_ptype, match_each_native_ptype}; impl ArrayCompute for SparseArray { fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> { @@ -26,6 +30,10 @@ impl ArrayCompute for SparseArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } + + fn take(&self) -> Option<&dyn TakeFn> { + Some(self) + } } impl AsContiguousFn for SparseArray { @@ -116,3 +124,141 @@ impl ScalarAtFn for SparseArray { } } } + +impl TakeFn for SparseArray { + fn take(&self, indices: &dyn Array) -> VortexResult { + let flat_indices = flatten_primitive(indices)?; + // if we are taking a lot of values we should build a hashmap + let exact_taken_indices = if indices.len() > 512 { + take_map(self, flat_indices)? + } else { + take_search_sorted(self, flat_indices)? + }; + + let taken_values = take(self.values(), &exact_taken_indices)?; + + Ok(SparseArray::new( + PrimitiveArray::from((0u64..exact_taken_indices.len() as u64).collect::>()) + .into_array(), + taken_values, + indices.len(), + self.fill_value().clone(), + ) + .into_array()) + } +} + +fn take_map(array: &SparseArray, indices: PrimitiveArray) -> VortexResult { + let indices_map: HashMap = array + .resolved_indices() + .iter() + .enumerate() + .map(|(i, r)| (*r as u64, i as u64)) + .collect(); + let patch_indices: Vec = match_each_integer_ptype!(indices.ptype(), |$P| { + indices.typed_data::<$P>() + .iter() + .map(|i| *i as u64) + .filter_map(|pi| indices_map.get(&pi).copied()) + .collect::>() + }); + Ok(PrimitiveArray::from(patch_indices)) +} + +fn take_search_sorted( + array: &SparseArray, + indices: PrimitiveArray, +) -> VortexResult { + let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| { + indices.typed_data::<$P>() + .iter() + .map(|i| *i as usize + array.indices_offset()) + .collect::>() + }); + + // TODO(robert): Use binary search instead of search_sorted + take and index validation to avoid extra work + let physical_indices = PrimitiveArray::from( + adjusted_indices + .iter() + .map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64)) + .collect::>>()?, + ); + let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?; + match_each_integer_ptype!(taken_indices.ptype(), |$P| { + Ok(PrimitiveArray::from(taken_indices + .typed_data::<$P>() + .iter() + .copied() + .zip_eq(adjusted_indices) + .zip_eq(physical_indices.typed_data::()) + .filter(|((taken_idx, orig_idx), _)| *taken_idx as usize == *orig_idx) + .map(|(_, physical_idx)| *physical_idx) + .collect::>())) + }) +} + +#[cfg(test)] +mod test { + use vortex_schema::{DType, FloatWidth, Nullability}; + + use crate::array::downcast::DowncastArrayBuiltin; + use crate::array::primitive::PrimitiveArray; + use crate::array::sparse::SparseArray; + use crate::array::Array; + use crate::compute::take::take; + use crate::scalar::Scalar; + + #[test] + fn sparse_take() { + let sparse = SparseArray::new( + PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(), + PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(), + 100, + Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)), + ); + let taken = take(&sparse, &PrimitiveArray::from(vec![0, 47, 47, 0, 99])).unwrap(); + assert_eq!( + taken + .as_sparse() + .indices() + .as_primitive() + .typed_data::(), + [0, 1, 2, 3, 4] + ); + assert_eq!( + taken + .as_sparse() + .values() + .as_primitive() + .typed_data::(), + [1.23f64, 9.99, 9.99, 1.23, 3.5] + ); + } + + #[test] + fn nonexistent_take() { + let sparse = SparseArray::new( + PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(), + PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(), + 100, + Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)), + ); + let taken = take(&sparse, &PrimitiveArray::from(vec![69])).unwrap(); + assert_eq!( + taken + .as_sparse() + .indices() + .as_primitive() + .typed_data::(), + [] + ); + assert_eq!( + taken + .as_sparse() + .values() + .as_primitive() + .typed_data::(), + [] + ); + } +} diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index a80bcbb46f..9321c567a9 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -1,6 +1,5 @@ use std::sync::{Arc, RwLock}; -use itertools::Itertools; use linkme::distributed_slice; use vortex_error::{vortex_bail, VortexResult}; use vortex_schema::DType; @@ -8,20 +7,18 @@ use vortex_schema::DType; use crate::array::constant::ConstantArray; use crate::array::{check_slice_bounds, Array, ArrayRef}; use crate::compress::EncodingCompression; -use crate::compute::cast::cast; use crate::compute::flatten::flatten_primitive; use crate::compute::scalar_at::scalar_at; use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::compute::ArrayCompute; use crate::encoding::{Encoding, EncodingId, EncodingRef, ENCODINGS}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::ptype::PType; use crate::scalar::Scalar; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsCompute, StatsSet}; use crate::validity::ArrayValidity; use crate::validity::Validity; -use crate::{impl_array, ArrayWalker}; +use crate::{impl_array, match_each_integer_ptype, ArrayWalker}; mod compress; mod compute; @@ -112,12 +109,14 @@ impl SparseArray { /// Return indices as a vector of usize with the indices_offset applied. pub fn resolved_indices(&self) -> Vec { - flatten_primitive(cast(self.indices(), PType::U64.into()).unwrap().as_ref()) - .unwrap() - .typed_data::() - .iter() - .map(|v| (*v as usize) - self.indices_offset) - .collect_vec() + let flat_indices = flatten_primitive(self.indices()).unwrap(); + match_each_integer_ptype!(flat_indices.ptype(), |$P| { + flat_indices + .typed_data::<$P>() + .iter() + .map(|v| (*v as usize) - self.indices_offset) + .collect::>() + }) } }