Skip to content

Commit

Permalink
Slightly simplified SparseArray FlattenFn (#205)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Apr 5, 2024
1 parent 51ec1b3 commit 27580f2
Showing 1 changed file with 29 additions and 38 deletions.
67 changes: 29 additions & 38 deletions vortex-array/src/array/sparse/compute.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use arrow_buffer::BooleanBufferBuilder;
use itertools::Itertools;
use vortex_error::{vortex_bail, vortex_err, VortexResult};
use vortex_error::{vortex_bail, VortexResult};

use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::primitive::PrimitiveArray;
use crate::array::sparse::SparseArray;
use crate::array::{Array, ArrayRef};
use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn};
use crate::compute::flatten::{flatten, FlattenFn, FlattenedArray};
use crate::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray};
use crate::compute::scalar_at::{scalar_at, ScalarAtFn};
use crate::compute::ArrayCompute;
use crate::match_each_native_ptype;
Expand Down Expand Up @@ -67,54 +67,45 @@ impl FlattenFn for SparseArray {

let mut validity = BooleanBufferBuilder::new(self.len());
validity.append_n(self.len(), false);
let values = flatten(self.values())?;
let null_fill = self.fill_value().is_null();
if let FlattenedArray::Primitive(ref parray) = values {
match_each_native_ptype!(parray.ptype(), |$P| {
flatten_primitive::<$P>(
self,
parray,
indices,
null_fill,
validity
)
})
} else {
Err(vortex_err!(
"Cannot flatten SparseArray with non-primitive values"
))
}
let values = flatten_primitive(self.values())?;
match_each_native_ptype!(values.ptype(), |$P| {
flatten_sparse_values(
values.typed_data::<$P>(),
&indices,
self.len(),
self.fill_value(),
validity
)
})
}
}
fn flatten_primitive<T: NativePType>(
sparse_array: &SparseArray,
parray: &PrimitiveArray,
indices: Vec<usize>,
null_fill: bool,

fn flatten_sparse_values<T: NativePType>(
values: &[T],
indices: &[usize],
len: usize,
fill_value: &Scalar,
mut validity: BooleanBufferBuilder,
) -> VortexResult<FlattenedArray> {
let fill_value = if null_fill {
let primitive_fill = if fill_value.is_null() {
T::default()
} else {
sparse_array.fill_value.clone().try_into()?
fill_value.try_into()?
};
let mut values = vec![fill_value; sparse_array.len()];
let mut result = vec![primitive_fill; len];

for (offset, v) in parray.typed_data::<T>().iter().enumerate() {
let idx = indices[offset];
values[idx] = *v;
validity.set_bit(idx, true);
for (v, idx) in values.iter().zip_eq(indices) {
result[*idx] = *v;
validity.set_bit(*idx, true);
}

let validity = validity.finish();
if null_fill {
Ok(FlattenedArray::Primitive(PrimitiveArray::from_nullable(
values,
Some(validity.into()),
)))
let array = if fill_value.is_null() {
PrimitiveArray::from_nullable(result, Some(validity.into()))
} else {
Ok(FlattenedArray::Primitive(PrimitiveArray::from(values)))
}
PrimitiveArray::from(result)
};
Ok(FlattenedArray::Primitive(array))
}

impl ScalarAtFn for SparseArray {
Expand Down

0 comments on commit 27580f2

Please sign in to comment.