diff --git a/bench-vortex/src/bin/notimplemented.rs b/bench-vortex/src/bin/notimplemented.rs index 0f6cdb2713..190406dcb3 100644 --- a/bench-vortex/src/bin/notimplemented.rs +++ b/bench-vortex/src/bin/notimplemented.rs @@ -198,7 +198,7 @@ fn compute_funcs(encodings: &[ArrayData]) { for arr in encodings { let mut impls = vec![Cell::new(arr.encoding().id().as_ref())]; impls.push(bool_to_cell(arr.encoding().cast_fn().is_some())); - impls.push(bool_to_cell(arr.with_dyn(|a| a.fill_forward().is_some()))); + impls.push(bool_to_cell(arr.encoding().fill_forward_fn().is_some())); impls.push(bool_to_cell(arr.encoding().filter_fn().is_some())); impls.push(bool_to_cell(arr.encoding().scalar_at_fn().is_some())); impls.push(bool_to_cell( diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 2dd04f362d..e2e47971d8 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -10,13 +10,13 @@ use vortex_scalar::Scalar; use super::{ByteBoolArray, ByteBoolEncoding}; -impl ArrayCompute for ByteBoolArray { - fn fill_forward(&self) -> Option<&dyn FillForwardFn> { +impl ArrayCompute for ByteBoolArray {} + +impl ComputeVTable for ByteBoolEncoding { + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { None } -} -impl ComputeVTable for ByteBoolEncoding { fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -99,20 +99,23 @@ impl TakeFn for ByteBoolEncoding { } } -impl FillForwardFn for ByteBoolArray { - fn fill_forward(&self) -> VortexResult { - let validity = self.logical_validity(); - if self.dtype().nullability() == Nullability::NonNullable { - return Ok(self.to_array()); +impl FillForwardFn for ByteBoolEncoding { + fn fill_forward(&self, array: &ByteBoolArray) -> VortexResult { + let validity = array.logical_validity(); + if array.dtype().nullability() == Nullability::NonNullable { + return Ok(array.to_array()); } // all valid, but we need to convert to non-nullable if validity.all_valid() { - return Ok(Self::try_new(self.buffer().clone(), Validity::AllValid)?.into_array()); + return Ok( + ByteBoolArray::try_new(array.buffer().clone(), Validity::AllValid)?.into_array(), + ); } // all invalid => fill with default value (false) if validity.all_invalid() { return Ok( - Self::try_from_vec(vec![false; self.len()], Validity::AllValid)?.into_array(), + ByteBoolArray::try_from_vec(vec![false; array.len()], Validity::AllValid)? + .into_array(), ); } @@ -120,7 +123,7 @@ impl FillForwardFn for ByteBoolArray { .to_null_buffer()? .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?; - let bools = self.maybe_null_slice(); + let bools = array.maybe_null_slice(); let mut last_value = bool::default(); let filled = bools @@ -135,7 +138,7 @@ impl FillForwardFn for ByteBoolArray { }) .collect::>(); - Ok(Self::try_from_vec(filled, Validity::AllValid)?.into_array()) + Ok(ByteBoolArray::try_from_vec(filled, Validity::AllValid)?.into_array()) } } diff --git a/vortex-array/src/array/bool/compute/fill.rs b/vortex-array/src/array/bool/compute/fill.rs index 71831a921a..de7cf0725a 100644 --- a/vortex-array/src/array/bool/compute/fill.rs +++ b/vortex-array/src/array/bool/compute/fill.rs @@ -2,35 +2,37 @@ use arrow_buffer::BooleanBuffer; use vortex_dtype::Nullability; use vortex_error::{vortex_err, VortexResult}; -use crate::array::BoolArray; +use crate::array::{BoolArray, BoolEncoding}; use crate::compute::unary::FillForwardFn; use crate::validity::{ArrayValidity, Validity}; use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, ToArrayData}; -impl FillForwardFn for BoolArray { - fn fill_forward(&self) -> VortexResult { - let validity = self.logical_validity(); +impl FillForwardFn for BoolEncoding { + fn fill_forward(&self, array: &BoolArray) -> VortexResult { + let validity = array.logical_validity(); // nothing to see or do in this case - if self.dtype().nullability() == Nullability::NonNullable { - return Ok(self.to_array()); + if array.dtype().nullability() == Nullability::NonNullable { + return Ok(array.to_array()); } + // all valid, but we need to convert to non-nullable if validity.all_valid() { - return Ok(Self::new(self.boolean_buffer(), Nullability::Nullable).into_array()); + return Ok(BoolArray::new(array.boolean_buffer(), Nullability::Nullable).into_array()); } // all invalid => fill with default value (false) if validity.all_invalid() { - return Ok( - Self::try_new(BooleanBuffer::new_unset(self.len()), Validity::AllValid)? - .into_array(), - ); + return Ok(BoolArray::try_new( + BooleanBuffer::new_unset(array.len()), + Validity::AllValid, + )? + .into_array()); } let validity = validity .to_null_buffer()? .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?; - let bools = self.boolean_buffer(); + let bools = array.boolean_buffer(); let mut last_value = false; let buffer = BooleanBuffer::from_iter(bools.iter().zip(validity.inner().iter()).map( |(v, valid)| { @@ -40,7 +42,7 @@ impl FillForwardFn for BoolArray { last_value }, )); - Ok(Self::try_new(buffer, Validity::AllValid)?.into_array()) + Ok(BoolArray::try_new(buffer, Validity::AllValid)?.into_array()) } } diff --git a/vortex-array/src/array/bool/compute/mod.rs b/vortex-array/src/array/bool/compute/mod.rs index 559d95857b..9d26575fda 100644 --- a/vortex-array/src/array/bool/compute/mod.rs +++ b/vortex-array/src/array/bool/compute/mod.rs @@ -13,10 +13,6 @@ mod slice; mod take; impl ArrayCompute for BoolArray { - fn fill_forward(&self) -> Option<&dyn FillForwardFn> { - Some(self) - } - fn and(&self) -> Option<&dyn AndFn> { Some(self) } @@ -27,6 +23,10 @@ impl ArrayCompute for BoolArray { } impl ComputeVTable for BoolEncoding { + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } diff --git a/vortex-array/src/array/primitive/compute/fill.rs b/vortex-array/src/array/primitive/compute/fill.rs index 04821b1f98..d883a9d599 100644 --- a/vortex-array/src/array/primitive/compute/fill.rs +++ b/vortex-array/src/array/primitive/compute/fill.rs @@ -2,34 +2,35 @@ use vortex_dtype::{match_each_native_ptype, Nullability}; use vortex_error::{vortex_err, VortexResult}; use crate::array::primitive::PrimitiveArray; +use crate::array::PrimitiveEncoding; use crate::compute::unary::FillForwardFn; use crate::validity::{ArrayValidity, Validity}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, ToArrayData}; -impl FillForwardFn for PrimitiveArray { - fn fill_forward(&self) -> VortexResult { - if self.dtype().nullability() == Nullability::NonNullable { - return Ok(self.to_array()); +impl FillForwardFn for PrimitiveEncoding { + fn fill_forward(&self, array: &PrimitiveArray) -> VortexResult { + if array.dtype().nullability() == Nullability::NonNullable { + return Ok(array.to_array()); } - let validity = self.logical_validity(); + let validity = array.logical_validity(); if validity.all_valid() { return Ok(PrimitiveArray::new( - self.buffer().clone(), - self.ptype(), + array.buffer().clone(), + array.ptype(), Validity::AllValid, ) .into_array()); } - match_each_native_ptype!(self.ptype(), |$T| { + match_each_native_ptype!(array.ptype(), |$T| { if validity.all_invalid() { - return Ok(PrimitiveArray::from_vec(vec![$T::default(); self.len()], Validity::AllValid).into_array()); + return Ok(PrimitiveArray::from_vec(vec![$T::default(); array.len()], Validity::AllValid).into_array()); } let nulls = validity.to_null_buffer()?.ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?; - let maybe_null_slice = self.maybe_null_slice::<$T>(); + let maybe_null_slice = array.maybe_null_slice::<$T>(); let mut last_value = $T::default(); let filled = maybe_null_slice .iter() diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index 32e27f550a..eb67205ff5 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -24,10 +24,6 @@ impl ArrayCompute for PrimitiveArray { MaybeCompareFn::maybe_compare(self, other, operator) } - fn fill_forward(&self) -> Option<&dyn FillForwardFn> { - Some(self) - } - fn subtract_scalar(&self) -> Option<&dyn SubtractScalarFn> { Some(self) } @@ -42,6 +38,10 @@ impl ComputeVTable for PrimitiveEncoding { Some(self) } + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 3072e530ad..30e85b4f48 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -37,6 +37,13 @@ pub trait ComputeVTable { None } + /// Array function that returns new arrays a non-null value is repeated across runs of nulls. + /// + /// See: [FillForwardFn]. + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { + None + } + /// Filter an array with a given mask. /// /// See: [FilterFn]. @@ -76,13 +83,6 @@ pub trait ArrayCompute { None } - /// Array function that returns new arrays a non-null value is repeated across runs of nulls. - /// - /// See: [FillForwardFn]. - fn fill_forward(&self) -> Option<&dyn FillForwardFn> { - None - } - /// Broadcast subtraction of scalar from Vortex array. /// /// See: [SubtractScalarFn]. diff --git a/vortex-array/src/compute/unary/fill_forward.rs b/vortex-array/src/compute/unary/fill_forward.rs index c08dd7d91f..71f6bb2342 100644 --- a/vortex-array/src/compute/unary/fill_forward.rs +++ b/vortex-array/src/compute/unary/fill_forward.rs @@ -1,5 +1,6 @@ -use vortex_error::{vortex_err, VortexResult}; +use vortex_error::{vortex_err, VortexError, VortexResult}; +use crate::encoding::Encoding; use crate::{ArrayDType, ArrayData}; /// Trait for filling forward on an array, i.e., replacing nulls with the last non-null value. @@ -7,8 +8,24 @@ use crate::{ArrayDType, ArrayData}; /// If the array is non-nullable, it is returned as-is. /// If the array is entirely nulls, the fill forward operation returns an array of the same length, filled with the default value of the array's type. /// The DType of the returned array is the same as the input array; the Validity of the returned array is always either NonNullable or AllValid. -pub trait FillForwardFn { - fn fill_forward(&self) -> VortexResult; +pub trait FillForwardFn { + fn fill_forward(&self, array: &Array) -> VortexResult; +} + +impl FillForwardFn for E +where + E: FillForwardFn, + for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, +{ + fn fill_forward(&self, array: &ArrayData) -> VortexResult { + let array_ref = <&E::Array>::try_from(array)?; + let encoding = array + .encoding() + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("Mismatched encoding"))?; + FillForwardFn::fill_forward(encoding, array_ref) + } } pub fn fill_forward(array: impl AsRef) -> VortexResult { @@ -16,15 +33,14 @@ pub fn fill_forward(array: impl AsRef) -> VortexResult { if !array.dtype().is_nullable() { return Ok(array.clone()); } - - array.with_dyn(|a| { - a.fill_forward() - .map(|t| t.fill_forward()) - .unwrap_or_else(|| { - Err(vortex_err!( - NotImplemented: "fill_forward", - array.encoding().id() - )) - }) - }) + array + .encoding() + .fill_forward_fn() + .map(|f| f.fill_forward(array)) + .unwrap_or_else(|| { + Err(vortex_err!( + NotImplemented: "fill_forward", + array.encoding().id() + )) + }) }