Skip to content

Commit

Permalink
Add Take for REEArray (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Mar 28, 2024
1 parent c56ba98 commit f3ce3ac
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions vortex-ree/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use vortex::array::Array;
use vortex::compute::flatten::{flatten, FlattenFn, FlattenedArray};
use vortex::array::primitive::PrimitiveArray;
use vortex::array::{Array, ArrayRef};
use vortex::compute::flatten::{flatten, flatten_primitive, FlattenFn, FlattenedArray};
use vortex::compute::scalar_at::{scalar_at, ScalarAtFn};
use vortex::compute::take::{take, TakeFn};
use vortex::compute::ArrayCompute;
use vortex::match_each_integer_ptype;
use vortex::scalar::Scalar;
use vortex::validity::ArrayValidity;
use vortex_error::{VortexError, VortexResult};
Expand All @@ -17,6 +20,10 @@ impl ArrayCompute for REEArray {
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}

fn take(&self) -> Option<&dyn TakeFn> {
Some(self)
}
}

impl FlattenFn for REEArray {
Expand Down Expand Up @@ -45,3 +52,39 @@ impl ScalarAtFn for REEArray {
scalar_at(self.values(), self.find_physical_index(index)?)
}
}

impl TakeFn for REEArray {
fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
let primitive_indices = flatten_primitive(indices)?;
let physical_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| {
primitive_indices
.typed_data::<$P>()
.iter()
.map(|idx| {
self.find_physical_index(*idx as usize)
.map(|loc| loc as u64)
})
.collect::<VortexResult<Vec<_>>>()?
});
take(self.values(), &PrimitiveArray::from(physical_indices))
}
}

#[cfg(test)]
mod test {
use vortex::array::downcast::DowncastArrayBuiltin;
use vortex::array::primitive::PrimitiveArray;
use vortex::compute::take::take;

use crate::REEArray;

#[test]
fn ree_take() {
let ree = REEArray::encode(&PrimitiveArray::from(vec![
1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5,
]))
.unwrap();
let taken = take(&ree, &PrimitiveArray::from(vec![9, 8, 1, 3])).unwrap();
assert_eq!(taken.as_primitive().typed_data::<i32>(), &[5, 5, 1, 4]);
}
}

0 comments on commit f3ce3ac

Please sign in to comment.