diff --git a/vortex-ree/src/compute.rs b/vortex-ree/src/compute.rs index f5890a3883..57b42162e6 100644 --- a/vortex-ree/src/compute.rs +++ b/vortex-ree/src/compute.rs @@ -1,7 +1,10 @@ -use vortex::array::Array; -use vortex::compute::flatten::{flatten, FlattenFn, FlattenedArray}; +use vortex::array::primitive::PrimitiveArray; +use vortex::array::{Array, ArrayRef}; +use vortex::compute::flatten::{flatten, flatten_primitive, FlattenFn, FlattenedArray}; use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::take::{take, TakeFn}; use vortex::compute::ArrayCompute; +use vortex::match_each_integer_ptype; use vortex::scalar::Scalar; use vortex::validity::ArrayValidity; use vortex_error::{VortexError, VortexResult}; @@ -17,6 +20,10 @@ impl ArrayCompute for REEArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } + + fn take(&self) -> Option<&dyn TakeFn> { + Some(self) + } } impl FlattenFn for REEArray { @@ -45,3 +52,50 @@ impl ScalarAtFn for REEArray { scalar_at(self.values(), self.find_physical_index(index)?) } } + +impl TakeFn for REEArray { + fn take(&self, indices: &dyn Array) -> VortexResult { + let primitive_indices = flatten_primitive(indices)?; + let mut values_to_take: Vec = Vec::new(); + let physical_indices: Vec = match_each_integer_ptype!(primitive_indices.ptype(), |$P| { + primitive_indices + .typed_data::<$P>() + .iter() + .map(|idx| { + self.find_physical_index(*idx as usize).map(|loc| { + values_to_take + .iter() + .position(|to_take| *to_take == loc as u64) + .map(|p| p as u64) + .unwrap_or_else(|| { + let position = values_to_take.len(); + values_to_take.push(loc as u64); + position as u64 + }) + }) + }) + .collect::>>()? + }); + let taken_values = take(self.values(), &PrimitiveArray::from(values_to_take))?; + take(&taken_values, &PrimitiveArray::from(physical_indices)) + } +} + +#[cfg(test)] +mod test { + use vortex::array::downcast::DowncastArrayBuiltin; + use vortex::array::primitive::PrimitiveArray; + use vortex::compute::take::take; + + use crate::REEArray; + + #[test] + fn ree_take() { + let ree = REEArray::encode(&PrimitiveArray::from(vec![ + 1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5, + ])) + .unwrap(); + let taken = take(&ree, &PrimitiveArray::from(vec![8, 1, 3])).unwrap(); + assert_eq!(taken.as_primitive().typed_data::(), &[5, 1, 4]); + } +}