Skip to content

Commit

Permalink
Implement take for BoolArray
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 committed Mar 26, 2024
1 parent 2b3a96d commit 7e5fbc7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
46 changes: 44 additions & 2 deletions vortex-array/src/array/bool/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,52 @@
use arrow_buffer::BooleanBuffer;
use num_traits::AsPrimitive;

use vortex_error::VortexResult;

use crate::array::bool::BoolArray;
use crate::array::{Array, ArrayRef};
use crate::compute::flatten::flatten_primitive;
use crate::compute::take::TakeFn;
use vortex_error::VortexResult;
use crate::match_each_integer_ptype;
use crate::validity::ArrayValidity;

impl TakeFn for BoolArray {
fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
todo!()
let validity = self.validity().map(|v| v.take(indices)).transpose()?;
let indices = flatten_primitive(indices)?;
match_each_integer_ptype!(indices.ptype(), |$I| {
Ok(BoolArray::from_nullable(
take_bool(self.buffer(), indices.typed_data::<$I>()),
validity,
).into_array())
})
}
}

fn take_bool<I: AsPrimitive<usize>>(bools: &BooleanBuffer, indices: &[I]) -> Vec<bool> {
indices.iter().map(|&idx| bools.value(idx.as_())).collect()
}

#[cfg(test)]
mod test {
use crate::array::bool::BoolArray;
use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::primitive::PrimitiveArray;
use crate::compute::take::take;

#[test]
fn take_nullable() {
let reference = BoolArray::from_iter(vec![
Some(false),
Some(true),
Some(false),
None,
Some(false),
]);
let res = take(&reference, &PrimitiveArray::from(vec![0, 3, 4])).unwrap();
assert_eq!(
res.as_bool().buffer(),
BoolArray::from_iter(vec![Some(false), None, Some(false)]).buffer()
);
}
}
5 changes: 4 additions & 1 deletion vortex-array/src/array/bool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use crate::validity::{ArrayValidity, Validity};
use super::{check_slice_bounds, Array, ArrayRef, Encoding, EncodingId, EncodingRef, ENCODINGS};

mod compute;
mod flatten;
mod serde;
mod stats;

Expand Down Expand Up @@ -51,6 +50,10 @@ impl BoolArray {
)
}

pub fn from_nullable(values: Vec<bool>, validity: Option<Validity>) -> Self {
BoolArray::new(BooleanBuffer::from(values), validity)
}

#[inline]
pub fn buffer(&self) -> &BooleanBuffer {
&self.buffer
Expand Down

0 comments on commit 7e5fbc7

Please sign in to comment.