Skip to content

Commit

Permalink
FillForward VTable (#1405)
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn authored Nov 20, 2024
1 parent 2e722f4 commit 091a15d
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 66 deletions.
2 changes: 1 addition & 1 deletion bench-vortex/src/bin/notimplemented.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ fn compute_funcs(encodings: &[ArrayData]) {
for arr in encodings {
let mut impls = vec![Cell::new(arr.encoding().id().as_ref())];
impls.push(bool_to_cell(arr.encoding().cast_fn().is_some()));
impls.push(bool_to_cell(arr.with_dyn(|a| a.fill_forward().is_some())));
impls.push(bool_to_cell(arr.encoding().fill_forward_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().filter_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().scalar_at_fn().is_some()));
impls.push(bool_to_cell(
Expand Down
29 changes: 16 additions & 13 deletions encodings/bytebool/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ use vortex_scalar::Scalar;

use super::{ByteBoolArray, ByteBoolEncoding};

impl ArrayCompute for ByteBoolArray {
fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
impl ArrayCompute for ByteBoolArray {}

impl ComputeVTable for ByteBoolEncoding {
fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn<ArrayData>> {
None
}
}

impl ComputeVTable for ByteBoolEncoding {
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<ArrayData>> {
Some(self)
}
Expand Down Expand Up @@ -99,28 +99,31 @@ impl TakeFn<ByteBoolArray> for ByteBoolEncoding {
}
}

impl FillForwardFn for ByteBoolArray {
fn fill_forward(&self) -> VortexResult<ArrayData> {
let validity = self.logical_validity();
if self.dtype().nullability() == Nullability::NonNullable {
return Ok(self.to_array());
impl FillForwardFn<ByteBoolArray> for ByteBoolEncoding {
fn fill_forward(&self, array: &ByteBoolArray) -> VortexResult<ArrayData> {
let validity = array.logical_validity();
if array.dtype().nullability() == Nullability::NonNullable {
return Ok(array.to_array());
}
// all valid, but we need to convert to non-nullable
if validity.all_valid() {
return Ok(Self::try_new(self.buffer().clone(), Validity::AllValid)?.into_array());
return Ok(
ByteBoolArray::try_new(array.buffer().clone(), Validity::AllValid)?.into_array(),
);
}
// all invalid => fill with default value (false)
if validity.all_invalid() {
return Ok(
Self::try_from_vec(vec![false; self.len()], Validity::AllValid)?.into_array(),
ByteBoolArray::try_from_vec(vec![false; array.len()], Validity::AllValid)?
.into_array(),
);
}

let validity = validity
.to_null_buffer()?
.ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;

let bools = self.maybe_null_slice();
let bools = array.maybe_null_slice();
let mut last_value = bool::default();

let filled = bools
Expand All @@ -135,7 +138,7 @@ impl FillForwardFn for ByteBoolArray {
})
.collect::<Vec<_>>();

Ok(Self::try_from_vec(filled, Validity::AllValid)?.into_array())
Ok(ByteBoolArray::try_from_vec(filled, Validity::AllValid)?.into_array())
}
}

Expand Down
28 changes: 15 additions & 13 deletions vortex-array/src/array/bool/compute/fill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,37 @@ use arrow_buffer::BooleanBuffer;
use vortex_dtype::Nullability;
use vortex_error::{vortex_err, VortexResult};

use crate::array::BoolArray;
use crate::array::{BoolArray, BoolEncoding};
use crate::compute::unary::FillForwardFn;
use crate::validity::{ArrayValidity, Validity};
use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, ToArrayData};

impl FillForwardFn for BoolArray {
fn fill_forward(&self) -> VortexResult<ArrayData> {
let validity = self.logical_validity();
impl FillForwardFn<BoolArray> for BoolEncoding {
fn fill_forward(&self, array: &BoolArray) -> VortexResult<ArrayData> {
let validity = array.logical_validity();
// nothing to see or do in this case
if self.dtype().nullability() == Nullability::NonNullable {
return Ok(self.to_array());
if array.dtype().nullability() == Nullability::NonNullable {
return Ok(array.to_array());
}

// all valid, but we need to convert to non-nullable
if validity.all_valid() {
return Ok(Self::new(self.boolean_buffer(), Nullability::Nullable).into_array());
return Ok(BoolArray::new(array.boolean_buffer(), Nullability::Nullable).into_array());
}
// all invalid => fill with default value (false)
if validity.all_invalid() {
return Ok(
Self::try_new(BooleanBuffer::new_unset(self.len()), Validity::AllValid)?
.into_array(),
);
return Ok(BoolArray::try_new(
BooleanBuffer::new_unset(array.len()),
Validity::AllValid,
)?
.into_array());
}

let validity = validity
.to_null_buffer()?
.ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;

let bools = self.boolean_buffer();
let bools = array.boolean_buffer();
let mut last_value = false;
let buffer = BooleanBuffer::from_iter(bools.iter().zip(validity.inner().iter()).map(
|(v, valid)| {
Expand All @@ -40,7 +42,7 @@ impl FillForwardFn for BoolArray {
last_value
},
));
Ok(Self::try_new(buffer, Validity::AllValid)?.into_array())
Ok(BoolArray::try_new(buffer, Validity::AllValid)?.into_array())
}
}

Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/bool/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ mod slice;
mod take;

impl ArrayCompute for BoolArray {
fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}

fn and(&self) -> Option<&dyn AndFn> {
Some(self)
}
Expand All @@ -27,6 +23,10 @@ impl ArrayCompute for BoolArray {
}

impl ComputeVTable for BoolEncoding {
fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn<ArrayData>> {
Some(self)
}

fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down
21 changes: 11 additions & 10 deletions vortex-array/src/array/primitive/compute/fill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,35 @@ use vortex_dtype::{match_each_native_ptype, Nullability};
use vortex_error::{vortex_err, VortexResult};

use crate::array::primitive::PrimitiveArray;
use crate::array::PrimitiveEncoding;
use crate::compute::unary::FillForwardFn;
use crate::validity::{ArrayValidity, Validity};
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, ToArrayData};

impl FillForwardFn for PrimitiveArray {
fn fill_forward(&self) -> VortexResult<ArrayData> {
if self.dtype().nullability() == Nullability::NonNullable {
return Ok(self.to_array());
impl FillForwardFn<PrimitiveArray> for PrimitiveEncoding {
fn fill_forward(&self, array: &PrimitiveArray) -> VortexResult<ArrayData> {
if array.dtype().nullability() == Nullability::NonNullable {
return Ok(array.to_array());
}

let validity = self.logical_validity();
let validity = array.logical_validity();
if validity.all_valid() {
return Ok(PrimitiveArray::new(
self.buffer().clone(),
self.ptype(),
array.buffer().clone(),
array.ptype(),
Validity::AllValid,
)
.into_array());
}

match_each_native_ptype!(self.ptype(), |$T| {
match_each_native_ptype!(array.ptype(), |$T| {
if validity.all_invalid() {
return Ok(PrimitiveArray::from_vec(vec![$T::default(); self.len()], Validity::AllValid).into_array());
return Ok(PrimitiveArray::from_vec(vec![$T::default(); array.len()], Validity::AllValid).into_array());
}

let nulls = validity.to_null_buffer()?.ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;
let maybe_null_slice = self.maybe_null_slice::<$T>();
let maybe_null_slice = array.maybe_null_slice::<$T>();
let mut last_value = $T::default();
let filled = maybe_null_slice
.iter()
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ impl ArrayCompute for PrimitiveArray {
MaybeCompareFn::maybe_compare(self, other, operator)
}

fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}

fn subtract_scalar(&self) -> Option<&dyn SubtractScalarFn> {
Some(self)
}
Expand All @@ -42,6 +38,10 @@ impl ComputeVTable for PrimitiveEncoding {
Some(self)
}

fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn<ArrayData>> {
Some(self)
}

fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down
14 changes: 7 additions & 7 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ pub trait ComputeVTable {
None
}

/// Array function that returns new arrays a non-null value is repeated across runs of nulls.
///
/// See: [FillForwardFn].
fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn<ArrayData>> {
None
}

/// Filter an array with a given mask.
///
/// See: [FilterFn].
Expand Down Expand Up @@ -76,13 +83,6 @@ pub trait ArrayCompute {
None
}

/// Array function that returns new arrays a non-null value is repeated across runs of nulls.
///
/// See: [FillForwardFn].
fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
None
}

/// Broadcast subtraction of scalar from Vortex array.
///
/// See: [SubtractScalarFn].
Expand Down
44 changes: 30 additions & 14 deletions vortex-array/src/compute/unary/fill_forward.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,46 @@
use vortex_error::{vortex_err, VortexResult};
use vortex_error::{vortex_err, VortexError, VortexResult};

use crate::encoding::Encoding;
use crate::{ArrayDType, ArrayData};

/// Trait for filling forward on an array, i.e., replacing nulls with the last non-null value.
///
/// If the array is non-nullable, it is returned as-is.
/// If the array is entirely nulls, the fill forward operation returns an array of the same length, filled with the default value of the array's type.
/// The DType of the returned array is the same as the input array; the Validity of the returned array is always either NonNullable or AllValid.
pub trait FillForwardFn {
fn fill_forward(&self) -> VortexResult<ArrayData>;
pub trait FillForwardFn<Array> {
fn fill_forward(&self, array: &Array) -> VortexResult<ArrayData>;
}

impl<E: Encoding + 'static> FillForwardFn<ArrayData> for E
where
E: FillForwardFn<E::Array>,
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn fill_forward(&self, array: &ArrayData) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
FillForwardFn::fill_forward(encoding, array_ref)
}
}

pub fn fill_forward(array: impl AsRef<ArrayData>) -> VortexResult<ArrayData> {
let array = array.as_ref();
if !array.dtype().is_nullable() {
return Ok(array.clone());
}

array.with_dyn(|a| {
a.fill_forward()
.map(|t| t.fill_forward())
.unwrap_or_else(|| {
Err(vortex_err!(
NotImplemented: "fill_forward",
array.encoding().id()
))
})
})
array
.encoding()
.fill_forward_fn()
.map(|f| f.fill_forward(array))
.unwrap_or_else(|| {
Err(vortex_err!(
NotImplemented: "fill_forward",
array.encoding().id()
))
})
}

0 comments on commit 091a15d

Please sign in to comment.