Skip to content

Commit

Permalink
Fill forward
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 committed Mar 5, 2024
1 parent 13fcc77 commit 9ea0d57
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 4 deletions.
59 changes: 57 additions & 2 deletions vortex-array/src/array/bool/compute.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
use crate::array::bool::BoolArray;
use crate::array::Array;
use crate::compute::cast::CastBoolFn;
use crate::array::{Array, ArrayRef};
use crate::compute::cast::{cast_bool, CastBoolFn};
use crate::compute::fill::FillForwardFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::ArrayCompute;
use crate::error::VortexResult;
use crate::scalar::{NullableScalar, Scalar, ScalarRef};
use crate::stats::Stat;

impl ArrayCompute for BoolArray {
/// Returns the boolean-cast kernel for this array.
///
/// Always `Some`: `BoolArray` itself serves as the `CastBoolFn` implementation.
fn cast_bool(&self) -> Option<&dyn CastBoolFn> {
Some(self)
}

/// Returns the fill-forward kernel for this array.
///
/// Always `Some`: `BoolArray` itself serves as the `FillForwardFn` implementation.
fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}

/// Returns the scalar-lookup kernel for this array.
///
/// Always `Some`: `BoolArray` itself serves as the `ScalarAtFn` implementation.
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}
Expand All @@ -31,3 +37,52 @@ impl ScalarAtFn for BoolArray {
}
}
}

impl FillForwardFn for BoolArray {
    /// Replaces every null slot with the most recent valid value that precedes it.
    ///
    /// Slots that come before the first valid value are filled with `false`, the
    /// initial carry value.
    ///
    /// Three cases are handled:
    /// * no validity array: a clone of `self` is returned;
    /// * a validity array with a null count of zero: the value buffer is reused
    ///   and the validity is dropped;
    /// * otherwise: values are rebuilt element by element, carrying the last
    ///   valid value forward.
    ///
    /// # Errors
    ///
    /// Propagates the error from `cast_bool` when the validity array cannot be
    /// cast to a boolean array.
    fn fill_forward(&self) -> VortexResult<ArrayRef> {
        // Without a validity array there is nothing to fill.
        let validity = match self.validity() {
            Some(validity) => validity,
            None => return Ok(dyn_clone::clone_box(self)),
        };

        // Fast path: validity exists but marks nothing as null. If the stat is
        // unavailable we fall through to the general path, which is correct for
        // any input, instead of panicking.
        let null_count = self.stats().get_or_compute_as::<usize>(&Stat::NullCount);
        if null_count == Some(0) {
            return Ok(BoolArray::new(self.buffer().clone(), None).boxed());
        }

        let validity = cast_bool(validity)?;
        // Carry value; also what leading nulls are filled with.
        let mut last_value = false;
        let filled = self
            .buffer()
            .iter()
            .zip(validity.buffer().iter())
            .map(|(value, is_valid)| {
                if is_valid {
                    last_value = value;
                }
                last_value
            })
            .collect::<Vec<_>>();
        Ok(BoolArray::from(filled).boxed())
    }
}

#[cfg(test)]
mod test {
    use crate::array::bool::BoolArray;
    use crate::array::downcast::DowncastArrayBuiltin;
    use crate::compute;

    /// Nulls take the previous valid value; a leading null becomes `false`,
    /// and the result has no validity array.
    #[test]
    fn fill_forward() {
        let input = BoolArray::from_iter(vec![None, Some(false), None, Some(true), None]);

        let output = compute::fill::fill_forward(&input).unwrap();
        let output_bool = output.as_bool();

        let actual: Vec<bool> = output_bool.buffer().iter().collect();
        let expected = vec![false, false, false, true, true];
        assert_eq!(actual, expected);
        assert!(output_bool.validity().is_none());
    }
}
58 changes: 56 additions & 2 deletions vortex-array/src/array/primitive/compute/fill.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,51 @@
use num_traits::Zero;

use crate::array::primitive::PrimitiveArray;
use crate::array::ArrayRef;
use crate::array::{Array, ArrayRef};
use crate::compute::cast::cast_bool;
use crate::compute::fill::FillForwardFn;
use crate::error::VortexResult;
use crate::match_each_native_ptype;
use crate::stats::Stat;

impl FillForwardFn for PrimitiveArray {
fn fill_forward(&self) -> VortexResult<ArrayRef> {
todo!()
if self.validity().is_none() {
Ok(dyn_clone::clone_box(self))
} else if self
.stats()
.get_or_compute_as::<usize>(&Stat::NullCount)
.unwrap()
== 0usize
{
return Ok(PrimitiveArray::new(*self.ptype(), self.buffer().clone(), None).boxed());
} else {
match_each_native_ptype!(self.ptype(), |$P| {
let validity = cast_bool(self.validity().unwrap())?;
let typed_data = self.typed_data::<$P>();
let mut last_value = $P::zero();
let filled = typed_data
.iter()
.zip(validity.buffer().iter())
.map(|(v, valid)| {
if valid {
last_value = *v;
}
last_value
})
.collect::<Vec<_>>();
Ok(PrimitiveArray::from(filled).boxed())
})
}
}
}

#[cfg(test)]
mod test {
use crate::array::bool::BoolArray;
use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::primitive::PrimitiveArray;
use crate::array::Array;
use crate::compute;

#[test]
Expand All @@ -23,4 +56,25 @@ mod test {
assert_eq!(filled_primitive.typed_data::<u8>(), vec![0, 8, 8, 10, 10]);
assert!(filled_primitive.validity().is_none());
}

/// With every slot null, the result is all zeros and has no validity array.
#[test]
fn all_none() {
    let input = PrimitiveArray::from_iter(vec![None::<u8>; 5]);

    let output = compute::fill::fill_forward(input.as_ref()).unwrap();
    let output_primitive = output.as_primitive();

    assert_eq!(output_primitive.typed_data::<u8>(), vec![0u8; 5]);
    assert!(output_primitive.validity().is_none());
}

/// With an all-true validity array, values are unchanged and the validity is dropped.
#[test]
fn nullable_non_null() {
    let values = vec![8u8, 10u8, 12u8, 14u8, 16u8];
    let all_valid = BoolArray::from(vec![true; 5]).boxed();
    let input = PrimitiveArray::from_nullable(values, Some(all_valid));

    let output = compute::fill::fill_forward(input.as_ref()).unwrap();
    let output_primitive = output.as_primitive();

    assert_eq!(output_primitive.typed_data::<u8>(), vec![8, 10, 12, 14, 16]);
    assert!(output_primitive.validity().is_none());
}
}
4 changes: 4 additions & 0 deletions vortex-array/src/array/primitive/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ struct StatsAccumulator<T: NativePType> {
is_sorted: bool,
is_strict_sorted: bool,
run_count: usize,
null_count: usize,
bit_widths: Vec<usize>,
}

Expand All @@ -122,6 +123,7 @@ impl<T: NativePType> StatsAccumulator<T> {
is_sorted: true,
is_strict_sorted: true,
run_count: 1,
null_count: 0,
bit_widths: vec![0; size_of::<T>() * 8 + 1],
};
stats.bit_widths[first_value.bit_width()] += 1;
Expand All @@ -133,6 +135,7 @@ impl<T: NativePType> StatsAccumulator<T> {
Some(n) => self.next(n),
None => {
self.bit_widths[0] += 1;
self.null_count += 1;
}
}
}
Expand Down Expand Up @@ -160,6 +163,7 @@ impl<T: NativePType> StatsAccumulator<T> {
StatsSet::from(HashMap::from([
(Stat::Min, self.min.into()),
(Stat::Max, self.max.into()),
(Stat::NullCount, self.null_count.into()),
(Stat::IsConstant, (self.min == self.max).into()),
(Stat::BitWidthFreq, ListScalarVec(self.bit_widths).into()),
(Stat::IsSorted, self.is_sorted.into()),
Expand Down

0 comments on commit 9ea0d57

Please sign in to comment.