Skip to content

Commit

Permalink
Fill forward compute function (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Mar 5, 2024
1 parent d784211 commit e6a7477
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 5 deletions.
51 changes: 49 additions & 2 deletions vortex-array/src/array/bool/compute.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::array::bool::BoolArray;
use crate::array::Array;
use crate::compute::cast::CastBoolFn;
use crate::array::{Array, ArrayRef};
use crate::compute::cast::{cast_bool, CastBoolFn};
use crate::compute::fill::FillForwardFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::ArrayCompute;
use crate::error::VortexResult;
Expand All @@ -11,6 +12,10 @@ impl ArrayCompute for BoolArray {
Some(self)
}

fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}

fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}
Expand All @@ -31,3 +36,45 @@ impl ScalarAtFn for BoolArray {
}
}
}

impl FillForwardFn for BoolArray {
fn fill_forward(&self) -> VortexResult<ArrayRef> {
if self.validity().is_none() {
Ok(dyn_clone::clone_box(self))
} else {
let validity = cast_bool(self.validity().unwrap())?;
let bools = self.buffer();
let mut last_value = false;
let filled = bools
.iter()
.zip(validity.buffer().iter())
.map(|(v, valid)| {
if valid {
last_value = v;
}
last_value
})
.collect::<Vec<_>>();
Ok(BoolArray::from(filled).boxed())
}
}
}

#[cfg(test)]
mod test {
use crate::array::bool::BoolArray;
use crate::array::downcast::DowncastArrayBuiltin;
use crate::compute;

#[test]
fn fill_forward() {
let barr = BoolArray::from_iter(vec![None, Some(false), None, Some(true), None]);
let filled = compute::fill::fill_forward(&barr).unwrap();
let filled_bool = filled.as_bool();
assert_eq!(
filled_bool.buffer().iter().collect::<Vec<bool>>(),
vec![false, false, false, true, true]
);
assert!(filled_bool.validity().is_none());
}
}
3 changes: 1 addition & 2 deletions vortex-array/src/array/bool/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ impl StatsCompute for BoolArray {
let mut prev_bit = self.buffer().value(0);
let mut true_count: usize = if prev_bit { 1 } else { 0 };
let mut run_count: usize = 0;
for i in 1..self.len() {
let bit = self.buffer().value(i);
for bit in self.buffer().iter().skip(1) {
if bit {
true_count += 1
}
Expand Down
80 changes: 80 additions & 0 deletions vortex-array/src/array/primitive/compute/fill.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use num_traits::Zero;

use crate::array::primitive::PrimitiveArray;
use crate::array::{Array, ArrayRef};
use crate::compute::cast::cast_bool;
use crate::compute::fill::FillForwardFn;
use crate::error::VortexResult;
use crate::match_each_native_ptype;
use crate::stats::Stat;

impl FillForwardFn for PrimitiveArray {
fn fill_forward(&self) -> VortexResult<ArrayRef> {
if self.validity().is_none() {
Ok(dyn_clone::clone_box(self))
} else if self
.stats()
.get_or_compute_as::<usize>(&Stat::NullCount)
.unwrap()
== 0usize
{
return Ok(PrimitiveArray::new(*self.ptype(), self.buffer().clone(), None).boxed());
} else {
match_each_native_ptype!(self.ptype(), |$P| {
let validity = cast_bool(self.validity().unwrap())?;
let typed_data = self.typed_data::<$P>();
let mut last_value = $P::zero();
let filled = typed_data
.iter()
.zip(validity.buffer().iter())
.map(|(v, valid)| {
if valid {
last_value = *v;
}
last_value
})
.collect::<Vec<_>>();
Ok(filled.into())
})
}
}
}

#[cfg(test)]
mod test {
use crate::array::bool::BoolArray;
use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::primitive::PrimitiveArray;
use crate::array::Array;
use crate::compute;

#[test]
fn leading_none() {
let arr = PrimitiveArray::from_iter(vec![None, Some(8u8), None, Some(10), None]);
let filled = compute::fill::fill_forward(arr.as_ref()).unwrap();
let filled_primitive = filled.as_primitive();
assert_eq!(filled_primitive.typed_data::<u8>(), vec![0, 8, 8, 10, 10]);
assert!(filled_primitive.validity().is_none());
}

#[test]
fn all_none() {
let arr = PrimitiveArray::from_iter(vec![Option::<u8>::None, None, None, None, None]);
let filled = compute::fill::fill_forward(arr.as_ref()).unwrap();
let filled_primitive = filled.as_primitive();
assert_eq!(filled_primitive.typed_data::<u8>(), vec![0, 0, 0, 0, 0]);
assert!(filled_primitive.validity().is_none());
}

#[test]
fn nullable_non_null() {
let arr = PrimitiveArray::from_nullable(
vec![8u8, 10u8, 12u8, 14u8, 16u8],
Some(BoolArray::from(vec![true, true, true, true, true]).boxed()),
);
let filled = compute::fill::fill_forward(arr.as_ref()).unwrap();
let filled_primitive = filled.as_primitive();
assert_eq!(filled_primitive.typed_data::<u8>(), vec![8, 10, 12, 14, 16]);
assert!(filled_primitive.validity().is_none());
}
}
6 changes: 6 additions & 0 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::array::primitive::PrimitiveArray;
use crate::compute::cast::CastPrimitiveFn;
use crate::compute::fill::FillForwardFn;
use crate::compute::patch::PatchFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::ArrayCompute;

mod cast;
mod fill;
mod patch;
mod scalar_at;

Expand All @@ -13,6 +15,10 @@ impl ArrayCompute for PrimitiveArray {
Some(self)
}

fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}

fn patch(&self) -> Option<&dyn PatchFn> {
Some(self)
}
Expand Down
4 changes: 4 additions & 0 deletions vortex-array/src/array/primitive/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ struct StatsAccumulator<T: NativePType> {
is_sorted: bool,
is_strict_sorted: bool,
run_count: usize,
null_count: usize,
bit_widths: Vec<usize>,
}

Expand All @@ -122,6 +123,7 @@ impl<T: NativePType> StatsAccumulator<T> {
is_sorted: true,
is_strict_sorted: true,
run_count: 1,
null_count: 0,
bit_widths: vec![0; size_of::<T>() * 8 + 1],
};
stats.bit_widths[first_value.bit_width()] += 1;
Expand All @@ -133,6 +135,7 @@ impl<T: NativePType> StatsAccumulator<T> {
Some(n) => self.next(n),
None => {
self.bit_widths[0] += 1;
self.null_count += 1;
}
}
}
Expand Down Expand Up @@ -160,6 +163,7 @@ impl<T: NativePType> StatsAccumulator<T> {
StatsSet::from(HashMap::from([
(Stat::Min, self.min.into()),
(Stat::Max, self.max.into()),
(Stat::NullCount, self.null_count.into()),
(Stat::IsConstant, (self.min == self.max).into()),
(Stat::BitWidthFreq, ListScalarVec(self.bit_widths).into()),
(Stat::IsSorted, self.is_sorted.into()),
Expand Down
22 changes: 22 additions & 0 deletions vortex-array/src/compute/fill.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use crate::array::{Array, ArrayRef};
use crate::error::{VortexError, VortexResult};

pub trait FillForwardFn {
fn fill_forward(&self) -> VortexResult<ArrayRef>;
}

pub fn fill_forward(array: &dyn Array) -> VortexResult<ArrayRef> {
if !array.dtype().is_nullable() {
return Ok(dyn_clone::clone_box(array));
}

array
.fill_forward()
.map(|t| t.fill_forward())
.unwrap_or_else(|| {
Err(VortexError::NotImplemented(
"fill_forward",
array.encoding().id(),
))
})
}
6 changes: 6 additions & 0 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use cast::{CastBoolFn, CastPrimitiveFn};
use fill::FillForwardFn;
use patch::PatchFn;
use scalar_at::ScalarAtFn;
use take::TakeFn;

pub mod add;
pub mod as_contiguous;
pub mod cast;
pub mod fill;
pub mod patch;
pub mod repeat;
pub mod scalar_at;
Expand All @@ -21,6 +23,10 @@ pub trait ArrayCompute {
None
}

fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
None
}

fn patch(&self) -> Option<&dyn PatchFn> {
None
}
Expand Down
7 changes: 6 additions & 1 deletion vortex-array/src/compute/patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ pub fn patch(array: &dyn Array, patch: &dyn Array) -> VortexResult<ArrayRef> {
));
}

// TODO(ngates): check the dtype matches
if array.dtype().as_nullable() != patch.dtype().as_nullable() {
return Err(VortexError::MismatchedTypes(
array.dtype().clone(),
patch.dtype().clone(),
));
}

array
.patch()
Expand Down

0 comments on commit e6a7477

Please sign in to comment.