Skip to content

Commit

Permalink
feat: add BinaryNumericFn for array arithmetic (#1640)
Browse files Browse the repository at this point in the history
I did not implement any binary numeric functions because it is not clear
that there are any cases where we can outrun decompression. Two run-end
arrays might be a happy path? Two dictionaries, maybe, if the
dictionaries are much smaller than the decompressed arrays?

Binary scalar numeric functions are more obviously valuable: ClickBench
includes several uses of scalar add or subtract.
  • Loading branch information
danking authored Dec 17, 2024
1 parent e69bde8 commit fa08a07
Show file tree
Hide file tree
Showing 23 changed files with 570 additions and 268 deletions.
4 changes: 2 additions & 2 deletions bench-vortex/src/bin/notimplemented.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ fn compute_funcs(encodings: &[ArrayData]) {
"fill_forward",
"filter",
"scalar_at",
"subtract_scalar",
"binary_numeric",
"search_sorted",
"slice",
"take",
Expand All @@ -190,7 +190,7 @@ fn compute_funcs(encodings: &[ArrayData]) {
impls.push(bool_to_cell(arr.encoding().fill_forward_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().filter_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().scalar_at_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().subtract_scalar_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().binary_numeric_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().search_sorted_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().slice_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().take_fn().is_some()));
Expand Down
27 changes: 24 additions & 3 deletions encodings/dict/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@ mod compare;
mod like;

use vortex_array::compute::{
filter, scalar_at, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, LikeFn,
ScalarAtFn, SliceFn, TakeFn,
binary_numeric, filter, scalar_at, slice, take, BinaryNumericFn, CompareFn, ComputeVTable,
FilterFn, FilterMask, LikeFn, ScalarAtFn, SliceFn, TakeFn,
};
use vortex_array::{ArrayData, IntoArrayData};
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use vortex_scalar::{BinaryNumericOperator, Scalar};

use crate::{DictArray, DictEncoding};

impl ComputeVTable for DictEncoding {
/// Exposes this encoding's [`BinaryNumericFn`] implementation.
///
/// Returning `Some` advertises that dictionary arrays provide their own binary numeric kernel.
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's [`CompareFn`] implementation.
///
/// Returning `Some` advertises that dictionary arrays provide their own compare kernel.
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}
Expand All @@ -37,6 +41,23 @@ impl ComputeVTable for DictEncoding {
}
}

impl BinaryNumericFn<DictArray> for DictEncoding {
fn binary_numeric(
&self,
array: &DictArray,
rhs: &ArrayData,
op: BinaryNumericOperator,
) -> VortexResult<Option<ArrayData>> {
if !rhs.is_constant() {
return Ok(None);
}

DictArray::try_new(array.codes(), binary_numeric(&array.values(), rhs, op)?)
.map(IntoArrayData::into_array)
.map(Some)
}
}

impl ScalarAtFn<DictArray> for DictEncoding {
fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult<Scalar> {
let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?;
Expand Down
32 changes: 29 additions & 3 deletions encodings/runend/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@ use std::ops::AddAssign;
use num_traits::AsPrimitive;
use vortex_array::array::{BooleanBuffer, PrimitiveArray};
use vortex_array::compute::{
filter, scalar_at, slice, CompareFn, ComputeVTable, FillNullFn, FilterFn, FilterMask, InvertFn,
ScalarAtFn, SliceFn, TakeFn,
binary_numeric, filter, scalar_at, slice, BinaryNumericFn, CompareFn, ComputeVTable,
FillNullFn, FilterFn, FilterMask, InvertFn, ScalarAtFn, SliceFn, TakeFn,
};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};
use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType};
use vortex_error::{VortexResult, VortexUnwrap};
use vortex_scalar::Scalar;
use vortex_scalar::{BinaryNumericOperator, Scalar};

use crate::{RunEndArray, RunEndEncoding};

impl ComputeVTable for RunEndEncoding {
/// Exposes this encoding's [`BinaryNumericFn`] implementation.
///
/// Returning `Some` advertises that run-end arrays provide their own binary numeric kernel.
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's [`CompareFn`] implementation.
///
/// Returning `Some` advertises that run-end arrays provide their own compare kernel.
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}
Expand Down Expand Up @@ -50,6 +54,28 @@ impl ComputeVTable for RunEndEncoding {
}
}

impl BinaryNumericFn<RunEndArray> for RunEndEncoding {
fn binary_numeric(
&self,
array: &RunEndArray,
rhs: &ArrayData,
op: BinaryNumericOperator,
) -> VortexResult<Option<ArrayData>> {
if !rhs.is_constant() {
return Ok(None);
}

RunEndArray::with_offset_and_length(
array.ends(),
binary_numeric(&array.values(), rhs, op)?,
array.offset(),
array.len(),
)
.map(IntoArrayData::into_array)
.map(Some)
}
}

impl ScalarAtFn<RunEndArray> for RunEndEncoding {
fn scalar_at(&self, array: &RunEndArray, index: usize) -> VortexResult<Scalar> {
scalar_at(array.values(), array.find_physical_index(index)?)
Expand Down
3 changes: 1 addition & 2 deletions vortex-array/benches/scalar_subtract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ fn scalar_subtract(c: &mut Criterion) {

group.bench_function("vortex", |b| {
b.iter(|| {
let array =
vortex_array::compute::subtract_scalar(&chunked, &to_subtract.into()).unwrap();
let array = vortex_array::compute::sub_scalar(&chunked, to_subtract.into()).unwrap();

let chunked = ChunkedArray::try_from(array).unwrap();
black_box(chunked);
Expand Down
12 changes: 6 additions & 6 deletions vortex-array/src/array/chunked/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use vortex_error::VortexResult;
use crate::array::chunked::ChunkedArray;
use crate::array::ChunkedEncoding;
use crate::compute::{
try_cast, BinaryBooleanFn, CastFn, CompareFn, ComputeVTable, FillNullFn, FilterFn, InvertFn,
ScalarAtFn, SliceFn, SubtractScalarFn, TakeFn,
try_cast, BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, ComputeVTable, FillNullFn,
FilterFn, InvertFn, ScalarAtFn, SliceFn, TakeFn,
};
use crate::{ArrayData, IntoArrayData};

Expand All @@ -23,6 +23,10 @@ impl ComputeVTable for ChunkedEncoding {
Some(self)
}

/// Exposes this encoding's [`BinaryNumericFn`] implementation.
///
/// Returning `Some` advertises that chunked arrays provide their own binary numeric kernel.
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's [`CastFn`] implementation.
///
/// Returning `Some` advertises that chunked arrays provide their own cast kernel.
fn cast_fn(&self) -> Option<&dyn CastFn<ArrayData>> {
Some(self)
}
Expand Down Expand Up @@ -51,10 +55,6 @@ impl ComputeVTable for ChunkedEncoding {
Some(self)
}

fn subtract_scalar_fn(&self) -> Option<&dyn SubtractScalarFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's [`TakeFn`] implementation.
///
/// Returning `Some` advertises that chunked arrays provide their own take kernel.
fn take_fn(&self) -> Option<&dyn TakeFn<ArrayData>> {
Some(self)
}
Expand Down
9 changes: 4 additions & 5 deletions vortex-array/src/array/chunked/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use vortex_scalar::Scalar;
use crate::array::chunked::ChunkedArray;
use crate::array::ChunkedEncoding;
use crate::compute::{
scalar_at, search_sorted_usize, slice, subtract_scalar, take, try_cast, SearchSortedSide,
TakeFn,
scalar_at, search_sorted_usize, slice, sub_scalar, take, try_cast, SearchSortedSide, TakeFn,
};
use crate::stats::ArrayStatistics;
use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
Expand Down Expand Up @@ -93,15 +92,15 @@ fn take_strict_sorted(chunked: &ChunkedArray, indices: &ArrayData) -> VortexResu
.max_value_as_u64()
.try_into()?
{
subtract_scalar(
sub_scalar(
&chunk_indices,
&Scalar::from(chunk_begin).cast(chunk_indices.dtype())?,
Scalar::from(chunk_begin).cast(chunk_indices.dtype())?,
)?
} else {
// Note: this try_cast (memory copy) is unnecessary; we could instead upcast in the subtract fn
// and avoid an extra copy.
let u64_chunk_indices = try_cast(&chunk_indices, PType::U64.into())?;
subtract_scalar(&u64_chunk_indices, &chunk_begin.into())?
sub_scalar(&u64_chunk_indices, chunk_begin.into())?
};

indices_by_chunk[chunk_idx] = Some(chunk_indices);
Expand Down
36 changes: 22 additions & 14 deletions vortex-array/src/array/chunked/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use itertools::Itertools;
use serde::{Deserialize, Serialize};
use vortex_dtype::{DType, Nullability, PType};
use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult, VortexUnwrap};
use vortex_scalar::Scalar;
use vortex_scalar::BinaryNumericOperator;

use crate::array::primitive::PrimitiveArray;
use crate::compute::{
scalar_at, search_sorted_usize, subtract_scalar, SearchSortedSide, SubtractScalarFn,
binary_numeric, scalar_at, search_sorted_usize, slice, BinaryNumericFn, SearchSortedSide,
};
use crate::encoding::ids;
use crate::iter::{ArrayIterator, ArrayIteratorAdapter};
Expand Down Expand Up @@ -234,17 +234,25 @@ impl ValidityVTable<ChunkedArray> for ChunkedEncoding {
}
}

impl SubtractScalarFn<ChunkedArray> for ChunkedEncoding {
fn subtract_scalar(
impl BinaryNumericFn<ChunkedArray> for ChunkedEncoding {
fn binary_numeric(
&self,
array: &ChunkedArray,
to_subtract: &Scalar,
) -> VortexResult<ArrayData> {
let chunks = array
.chunks()
.map(|chunk| subtract_scalar(&chunk, to_subtract))
.collect::<VortexResult<Vec<_>>>()?;
Ok(ChunkedArray::try_new(chunks, array.dtype().clone())?.into_array())
rhs: &ArrayData,
op: BinaryNumericOperator,
) -> VortexResult<Option<ArrayData>> {
let mut start = 0;

let mut new_chunks = Vec::with_capacity(array.nchunks());
for chunk in array.chunks() {
let end = start + chunk.len();
new_chunks.push(binary_numeric(&chunk, &slice(rhs, start, end)?, op)?);
start = end;
}

ChunkedArray::try_new(new_chunks, array.dtype().clone())
.map(IntoArrayData::into_array)
.map(Some)
}
}

Expand All @@ -254,7 +262,7 @@ mod test {
use vortex_error::VortexResult;

use crate::array::chunked::ChunkedArray;
use crate::compute::{scalar_at, subtract_scalar};
use crate::compute::{scalar_at, sub_scalar};
use crate::{assert_arrays_eq, ArrayDType, IntoArrayData, IntoArrayVariant};

fn chunked_array() -> ChunkedArray {
Expand All @@ -271,9 +279,9 @@ mod test {

#[test]
fn test_scalar_subtract() {
let chunked = chunked_array();
let chunked = chunked_array().into_array();
let to_subtract = 1u64;
let array = subtract_scalar(&chunked, &to_subtract.into()).unwrap();
let array = sub_scalar(&chunked, to_subtract.into()).unwrap();

let chunked = ChunkedArray::try_from(array).unwrap();
let mut chunks_out = chunked.chunks();
Expand Down
31 changes: 31 additions & 0 deletions vortex-array/src/array/constant/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use vortex_error::{vortex_err, VortexResult};
use vortex_scalar::BinaryNumericOperator;

use crate::array::{ConstantArray, ConstantEncoding};
use crate::compute::BinaryNumericFn;
use crate::{ArrayData, ArrayLen as _, IntoArrayData as _};

impl BinaryNumericFn<ConstantArray> for ConstantEncoding {
    /// Combines two constant arrays into a new constant array by applying `op` to their
    /// scalars once.
    ///
    /// Returns `Ok(None)` when `rhs` is not constant.
    ///
    /// # Errors
    ///
    /// Propagates any error from `checked_numeric_operator`, and reports "numeric overflow"
    /// when that checked operation yields no value.
    fn binary_numeric(
        &self,
        array: &ConstantArray,
        rhs: &ArrayData,
        op: BinaryNumericOperator,
    ) -> VortexResult<Option<ArrayData>> {
        let rhs_scalar = match rhs.as_constant() {
            Some(scalar) => scalar,
            None => return Ok(None),
        };

        // Keep the left scalar bound to a local so its primitive view stays valid below.
        let lhs_scalar = array.scalar();
        let combined = lhs_scalar
            .as_primitive()
            .checked_numeric_operator(rhs_scalar.as_primitive(), op)?
            .ok_or_else(|| vortex_err!("numeric overflow"))?;

        // The result has the same length as the left-hand array.
        Ok(Some(ConstantArray::new(combined, array.len()).into_array()))
    }
}
9 changes: 7 additions & 2 deletions vortex-array/src/array/constant/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod binary_numeric;
mod boolean;
mod compare;
mod invert;
Expand All @@ -9,8 +10,8 @@ use vortex_scalar::Scalar;
use crate::array::constant::ConstantArray;
use crate::array::ConstantEncoding;
use crate::compute::{
BinaryBooleanFn, CompareFn, ComputeVTable, FilterFn, FilterMask, InvertFn, ScalarAtFn,
SearchSortedFn, SliceFn, TakeFn,
BinaryBooleanFn, BinaryNumericFn, CompareFn, ComputeVTable, FilterFn, FilterMask, InvertFn,
ScalarAtFn, SearchSortedFn, SliceFn, TakeFn,
};
use crate::{ArrayData, IntoArrayData};

Expand All @@ -19,6 +20,10 @@ impl ComputeVTable for ConstantEncoding {
Some(self)
}

/// Exposes this encoding's [`BinaryNumericFn`] implementation.
///
/// Returning `Some` advertises that constant arrays provide their own binary numeric kernel.
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's [`CompareFn`] implementation.
///
/// Returning `Some` advertises that constant arrays provide their own compare kernel.
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}
Expand Down
20 changes: 18 additions & 2 deletions vortex-array/src/array/null/compute.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use vortex_dtype::{match_each_integer_ptype, DType};
use vortex_error::{vortex_bail, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::{BinaryNumericOperator, Scalar};

use crate::array::null::NullArray;
use crate::array::NullEncoding;
use crate::compute::{ComputeVTable, ScalarAtFn, SliceFn, TakeFn};
use crate::compute::{BinaryNumericFn, ComputeVTable, ScalarAtFn, SliceFn, TakeFn};
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};

Expand All @@ -13,6 +13,10 @@ impl ComputeVTable for NullEncoding {
Some(self)
}

/// Exposes this encoding's [`BinaryNumericFn`] implementation.
///
/// Returning `Some` advertises that null arrays provide their own binary numeric kernel.
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's [`SliceFn`] implementation.
///
/// Returning `Some` advertises that null arrays provide their own slice kernel.
fn slice_fn(&self) -> Option<&dyn SliceFn<ArrayData>> {
Some(self)
}
Expand All @@ -22,6 +26,18 @@ impl ComputeVTable for NullEncoding {
}
}

impl BinaryNumericFn<NullArray> for NullEncoding {
    /// Produces an all-null array of the same length as `array`.
    ///
    /// The right-hand side and the operator are never inspected: for any arithmetic
    /// operation, `NULL op X` is `NULL` for every `X`.
    fn binary_numeric(
        &self,
        array: &NullArray,
        _rhs: &ArrayData,
        _op: BinaryNumericOperator,
    ) -> VortexResult<Option<ArrayData>> {
        let all_null = NullArray::new(array.len());
        Ok(Some(all_null.into_array()))
    }
}

impl SliceFn<NullArray> for NullEncoding {
fn slice(&self, _array: &NullArray, start: usize, stop: usize) -> VortexResult<ArrayData> {
Ok(NullArray::new(stop - start).into_array())
Expand Down
7 changes: 1 addition & 6 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::array::PrimitiveEncoding;
use crate::compute::{
CastFn, ComputeVTable, FillForwardFn, FilterFn, ScalarAtFn, SearchSortedFn,
SearchSortedUsizeFn, SliceFn, SubtractScalarFn, TakeFn,
SearchSortedUsizeFn, SliceFn, TakeFn,
};
use crate::ArrayData;

Expand All @@ -11,7 +11,6 @@ mod filter;
mod scalar_at;
mod search_sorted;
mod slice;
mod subtract_scalar;
mod take;

impl ComputeVTable for PrimitiveEncoding {
Expand Down Expand Up @@ -43,10 +42,6 @@ impl ComputeVTable for PrimitiveEncoding {
Some(self)
}

fn subtract_scalar_fn(&self) -> Option<&dyn SubtractScalarFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's [`TakeFn`] implementation.
///
/// Returning `Some` advertises that primitive arrays provide their own take kernel.
fn take_fn(&self) -> Option<&dyn TakeFn<ArrayData>> {
Some(self)
}
Expand Down
Loading

0 comments on commit fa08a07

Please sign in to comment.