diff --git a/vortex-ree/src/compute.rs b/vortex-ree/src/compute.rs index f5890a3883..6624c0d2cb 100644 --- a/vortex-ree/src/compute.rs +++ b/vortex-ree/src/compute.rs @@ -1,7 +1,10 @@ -use vortex::array::Array; -use vortex::compute::flatten::{flatten, FlattenFn, FlattenedArray}; +use vortex::array::primitive::PrimitiveArray; +use vortex::array::{Array, ArrayRef}; +use vortex::compute::flatten::{flatten, flatten_primitive, FlattenFn, FlattenedArray}; use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::take::{take, TakeFn}; use vortex::compute::ArrayCompute; +use vortex::match_each_integer_ptype; use vortex::scalar::Scalar; use vortex::validity::ArrayValidity; use vortex_error::{VortexError, VortexResult}; @@ -17,6 +20,10 @@ impl ArrayCompute for REEArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } + + fn take(&self) -> Option<&dyn TakeFn> { + Some(self) + } } impl FlattenFn for REEArray { @@ -45,3 +52,39 @@ impl ScalarAtFn for REEArray { scalar_at(self.values(), self.find_physical_index(index)?) } } + +impl TakeFn for REEArray { + fn take(&self, indices: &dyn Array) -> VortexResult { + let primitive_indices = flatten_primitive(indices)?; + let physical_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| { + primitive_indices + .typed_data::<$P>() + .iter() + .map(|idx| { + self.find_physical_index(*idx as usize) + .map(|loc| loc as u64) + }) + .collect::>>()? + }); + take(self.values(), &PrimitiveArray::from(physical_indices)) + } +} + +#[cfg(test)] +mod test { + use vortex::array::downcast::DowncastArrayBuiltin; + use vortex::array::primitive::PrimitiveArray; + use vortex::compute::take::take; + + use crate::REEArray; + + #[test] + fn ree_take() { + let ree = REEArray::encode(&PrimitiveArray::from(vec![ + 1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5, + ])) + .unwrap(); + let taken = take(&ree, &PrimitiveArray::from(vec![9, 8, 1, 3])).unwrap(); + assert_eq!(taken.as_primitive().typed_data::(), &[5, 5, 1, 4]); + } +}