diff --git a/vortex-array/src/array/bool/compute.rs b/vortex-array/src/array/bool/compute.rs index 627e1b7946..2cbf1c077a 100644 --- a/vortex-array/src/array/bool/compute.rs +++ b/vortex-array/src/array/bool/compute.rs @@ -1,6 +1,7 @@ use crate::array::bool::BoolArray; -use crate::array::Array; -use crate::compute::cast::CastBoolFn; +use crate::array::{Array, ArrayRef}; +use crate::compute::cast::{cast_bool, CastBoolFn}; +use crate::compute::fill::FillForwardFn; use crate::compute::scalar_at::ScalarAtFn; use crate::compute::ArrayCompute; use crate::error::VortexResult; @@ -11,6 +12,10 @@ impl ArrayCompute for BoolArray { Some(self) } + fn fill_forward(&self) -> Option<&dyn FillForwardFn> { + Some(self) + } + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -31,3 +36,45 @@ impl ScalarAtFn for BoolArray { } } } + +impl FillForwardFn for BoolArray { + fn fill_forward(&self) -> VortexResult { + if self.validity().is_none() { + Ok(dyn_clone::clone_box(self)) + } else { + let validity = cast_bool(self.validity().unwrap())?; + let bools = self.buffer(); + let mut last_value = false; + let filled = bools + .iter() + .zip(validity.buffer().iter()) + .map(|(v, valid)| { + if valid { + last_value = v; + } + last_value + }) + .collect::>(); + Ok(BoolArray::from(filled).boxed()) + } + } +} + +#[cfg(test)] +mod test { + use crate::array::bool::BoolArray; + use crate::array::downcast::DowncastArrayBuiltin; + use crate::compute; + + #[test] + fn fill_forward() { + let barr = BoolArray::from_iter(vec![None, Some(false), None, Some(true), None]); + let filled = compute::fill::fill_forward(&barr).unwrap(); + let filled_bool = filled.as_bool(); + assert_eq!( + filled_bool.buffer().iter().collect::>(), + vec![false, false, false, true, true] + ); + assert!(filled_bool.validity().is_none()); + } +} diff --git a/vortex-array/src/array/bool/stats.rs b/vortex-array/src/array/bool/stats.rs index f875e5933a..cf71e8622f 100644 --- a/vortex-array/src/array/bool/stats.rs +++ b/vortex-array/src/array/bool/stats.rs @@ -17,8 +17,7 @@ impl StatsCompute for BoolArray { let mut prev_bit = self.buffer().value(0); let mut true_count: usize = if prev_bit { 1 } else { 0 }; let mut run_count: usize = 0; - for i in 1..self.len() { - let bit = self.buffer().value(i); + for bit in self.buffer().iter().skip(1) { if bit { true_count += 1 } diff --git a/vortex-array/src/array/primitive/compute/fill.rs b/vortex-array/src/array/primitive/compute/fill.rs new file mode 100644 index 0000000000..f2dd4c98d7 --- /dev/null +++ b/vortex-array/src/array/primitive/compute/fill.rs @@ -0,0 +1,80 @@ +use num_traits::Zero; + +use crate::array::primitive::PrimitiveArray; +use crate::array::{Array, ArrayRef}; +use crate::compute::cast::cast_bool; +use crate::compute::fill::FillForwardFn; +use crate::error::VortexResult; +use crate::match_each_native_ptype; +use crate::stats::Stat; + +impl FillForwardFn for PrimitiveArray { + fn fill_forward(&self) -> VortexResult { + if self.validity().is_none() { + Ok(dyn_clone::clone_box(self)) + } else if self + .stats() + .get_or_compute_as::(&Stat::NullCount) + .unwrap() + == 0usize + { + return Ok(PrimitiveArray::new(*self.ptype(), self.buffer().clone(), None).boxed()); + } else { + match_each_native_ptype!(self.ptype(), |$P| { + let validity = cast_bool(self.validity().unwrap())?; + let typed_data = self.typed_data::<$P>(); + let mut last_value = $P::zero(); + let filled = typed_data + .iter() + .zip(validity.buffer().iter()) + .map(|(v, valid)| { + if valid { + last_value = *v; + } + last_value + }) + .collect::>(); + Ok(filled.into()) + }) + } + } +} + +#[cfg(test)] +mod test { + use crate::array::bool::BoolArray; + use crate::array::downcast::DowncastArrayBuiltin; + use crate::array::primitive::PrimitiveArray; + use crate::array::Array; + use crate::compute; + + #[test] + fn leading_none() { + let arr = PrimitiveArray::from_iter(vec![None, Some(8u8), None, Some(10), None]); + let filled = compute::fill::fill_forward(arr.as_ref()).unwrap(); + let filled_primitive = filled.as_primitive(); + assert_eq!(filled_primitive.typed_data::(), vec![0, 8, 8, 10, 10]); + assert!(filled_primitive.validity().is_none()); + } + + #[test] + fn all_none() { + let arr = PrimitiveArray::from_iter(vec![Option::::None, None, None, None, None]); + let filled = compute::fill::fill_forward(arr.as_ref()).unwrap(); + let filled_primitive = filled.as_primitive(); + assert_eq!(filled_primitive.typed_data::(), vec![0, 0, 0, 0, 0]); + assert!(filled_primitive.validity().is_none()); + } + + #[test] + fn nullable_non_null() { + let arr = PrimitiveArray::from_nullable( + vec![8u8, 10u8, 12u8, 14u8, 16u8], + Some(BoolArray::from(vec![true, true, true, true, true]).boxed()), + ); + let filled = compute::fill::fill_forward(arr.as_ref()).unwrap(); + let filled_primitive = filled.as_primitive(); + assert_eq!(filled_primitive.typed_data::(), vec![8, 10, 12, 14, 16]); + assert!(filled_primitive.validity().is_none()); + } +} diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index 5516e07ef3..6047878bb5 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -1,10 +1,12 @@ use crate::array::primitive::PrimitiveArray; use crate::compute::cast::CastPrimitiveFn; +use crate::compute::fill::FillForwardFn; use crate::compute::patch::PatchFn; use crate::compute::scalar_at::ScalarAtFn; use crate::compute::ArrayCompute; mod cast; +mod fill; mod patch; mod scalar_at; @@ -13,6 +15,10 @@ impl ArrayCompute for PrimitiveArray { Some(self) } + fn fill_forward(&self) -> Option<&dyn FillForwardFn> { + Some(self) + } + fn patch(&self) -> Option<&dyn PatchFn> { Some(self) } diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index fb94a5f41c..20fc036f22 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -110,6 +110,7 @@ struct StatsAccumulator { is_sorted: bool, is_strict_sorted: bool, run_count: usize, + null_count: usize, bit_widths: Vec, } @@ -122,6 +123,7 @@ impl StatsAccumulator { is_sorted: true, is_strict_sorted: true, run_count: 1, + null_count: 0, bit_widths: vec![0; size_of::() * 8 + 1], }; stats.bit_widths[first_value.bit_width()] += 1; @@ -133,6 +135,7 @@ impl StatsAccumulator { Some(n) => self.next(n), None => { self.bit_widths[0] += 1; + self.null_count += 1; } } } @@ -160,6 +163,7 @@ impl StatsAccumulator { StatsSet::from(HashMap::from([ (Stat::Min, self.min.into()), (Stat::Max, self.max.into()), + (Stat::NullCount, self.null_count.into()), (Stat::IsConstant, (self.min == self.max).into()), (Stat::BitWidthFreq, ListScalarVec(self.bit_widths).into()), (Stat::IsSorted, self.is_sorted.into()), diff --git a/vortex-array/src/compute/fill.rs b/vortex-array/src/compute/fill.rs new file mode 100644 index 0000000000..d7c6d0618e --- /dev/null +++ b/vortex-array/src/compute/fill.rs @@ -0,0 +1,22 @@ +use crate::array::{Array, ArrayRef}; +use crate::error::{VortexError, VortexResult}; + +pub trait FillForwardFn { + fn fill_forward(&self) -> VortexResult; +} + +pub fn fill_forward(array: &dyn Array) -> VortexResult { + if !array.dtype().is_nullable() { + return Ok(dyn_clone::clone_box(array)); + } + + array + .fill_forward() + .map(|t| t.fill_forward()) + .unwrap_or_else(|| { + Err(VortexError::NotImplemented( + "fill_forward", + array.encoding().id(), + )) + }) +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 4391e4ccf7..30fb93aa02 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -1,4 +1,5 @@ use cast::{CastBoolFn, CastPrimitiveFn}; +use fill::FillForwardFn; use patch::PatchFn; use scalar_at::ScalarAtFn; use take::TakeFn; @@ -6,6 +7,7 @@ use take::TakeFn; pub mod add; pub mod as_contiguous; pub mod cast; +pub mod fill; pub mod patch; pub mod repeat; pub mod scalar_at; @@ -21,6 +23,10 @@ pub trait ArrayCompute { None } + fn fill_forward(&self) -> Option<&dyn FillForwardFn> { + None + } + fn patch(&self) -> Option<&dyn PatchFn> { None } diff --git a/vortex-array/src/compute/patch.rs b/vortex-array/src/compute/patch.rs index f58a650cb0..bfe9cf9ccc 100644 --- a/vortex-array/src/compute/patch.rs +++ b/vortex-array/src/compute/patch.rs @@ -13,7 +13,12 @@ pub fn patch(array: &dyn Array, patch: &dyn Array) -> VortexResult { )); } - // TODO(ngates): check the dtype matches + if array.dtype().as_nullable() != patch.dtype().as_nullable() { + return Err(VortexError::MismatchedTypes( + array.dtype().clone(), + patch.dtype().clone(), + )); + } array .patch()