Skip to content

Commit

Permalink
CompareFn VTable (#1426)
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn authored Nov 21, 2024
1 parent 052b08b commit df11488
Show file tree
Hide file tree
Showing 17 changed files with 169 additions and 106 deletions.
2 changes: 1 addition & 1 deletion bench-vortex/src/bin/notimplemented.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ fn compute_funcs(encodings: &[ArrayData]) {
for arr in encodings {
let mut impls = vec![Cell::new(arr.encoding().id().as_ref())];
impls.push(bool_to_cell(arr.encoding().cast_fn().is_some()));
impls.push(bool_to_cell(arr.with_dyn(|a| a.compare().is_some())));
impls.push(bool_to_cell(arr.encoding().compare_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().fill_forward_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().filter_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().scalar_at_fn().is_some()));
Expand Down
21 changes: 13 additions & 8 deletions encodings/alp/src/alp/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ use vortex_dtype::Nullability;
use vortex_error::VortexResult;
use vortex_scalar::{PValue, Scalar};

use crate::{ALPArray, ALPFloat};

impl CompareFn for ALPArray {
fn compare(&self, array: &ArrayData, operator: Operator) -> VortexResult<Option<ArrayData>> {
if let Some(const_scalar) = array.as_constant() {
use crate::{ALPArray, ALPEncoding, ALPFloat};

impl CompareFn<ALPArray> for ALPEncoding {
fn compare(
&self,
lhs: &ALPArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
if let Some(const_scalar) = rhs.as_constant() {
let pvalue = const_scalar.value().as_pvalue()?;

return match pvalue {
Some(PValue::F32(f)) => alp_scalar_compare(self, f, operator).map(Some),
Some(PValue::F64(f)) => alp_scalar_compare(self, f, operator).map(Some),
Some(PValue::F32(f)) => alp_scalar_compare(lhs, f, operator).map(Some),
Some(PValue::F64(f)) => alp_scalar_compare(lhs, f, operator).map(Some),
Some(_) | None => Ok(Some(
ConstantArray::new(Scalar::bool(false, Nullability::Nullable), self.len())
ConstantArray::new(Scalar::bool(false, Nullability::Nullable), lhs.len())
.into_array(),
)),
};
Expand Down
8 changes: 4 additions & 4 deletions encodings/alp/src/alp/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ use vortex_scalar::Scalar;

use crate::{match_each_alp_float_ptype, ALPArray, ALPEncoding, ALPFloat};

impl ArrayCompute for ALPArray {
fn compare(&self) -> Option<&dyn CompareFn> {
impl ArrayCompute for ALPArray {}

impl ComputeVTable for ALPEncoding {
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}
}

impl ComputeVTable for ALPEncoding {
fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down
19 changes: 12 additions & 7 deletions encodings/dict/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,24 @@ use vortex_array::compute::{compare, CompareFn, Operator};
use vortex_array::{ArrayData, IntoArrayData};
use vortex_error::VortexResult;

use crate::DictArray;
use crate::{DictArray, DictEncoding};

impl CompareFn for DictArray {
fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult<Option<ArrayData>> {
impl CompareFn<DictArray> for DictEncoding {
fn compare(
&self,
lhs: &DictArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
// If the RHS is constant, then we just need to compare against our encoded values.
if let Some(const_scalar) = other.as_constant() {
if let Some(const_scalar) = rhs.as_constant() {
// Ensure the other is the same length as the dictionary
return compare(
self.values(),
ConstantArray::new(const_scalar, self.values().len()),
lhs.values(),
ConstantArray::new(const_scalar, lhs.values().len()),
operator,
)
.and_then(|values| Self::try_new(self.codes(), values))
.and_then(|values| DictArray::try_new(lhs.codes(), values))
.map(|a| a.into_array())
.map(Some);
}
Expand Down
8 changes: 4 additions & 4 deletions encodings/dict/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use vortex_scalar::Scalar;

use crate::{DictArray, DictEncoding};

impl ArrayCompute for DictArray {
fn compare(&self) -> Option<&dyn CompareFn> {
impl ArrayCompute for DictArray {}

impl ComputeVTable for DictEncoding {
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}
}

impl ComputeVTable for DictEncoding {
fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down
19 changes: 12 additions & 7 deletions encodings/fsst/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ use vortex_buffer::Buffer;
use vortex_dtype::DType;
use vortex_error::VortexResult;

use crate::FSSTArray;

impl CompareFn for FSSTArray {
fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult<Option<ArrayData>> {
match (other.as_constant(), operator) {
use crate::{FSSTArray, FSSTEncoding};

impl CompareFn<FSSTArray> for FSSTEncoding {
fn compare(
&self,
lhs: &FSSTArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
match (rhs.as_constant(), operator) {
// TODO(ngates): implement short-circuit comparisons for other operators.
(Some(constant_array), Operator::Eq | Operator::NotEq) => compare_fsst_constant(
self,
&ConstantArray::new(constant_array, self.len()),
lhs,
&ConstantArray::new(constant_array, lhs.len()),
operator == Operator::Eq,
)
.map(Some),
Expand Down
8 changes: 4 additions & 4 deletions encodings/fsst/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ use vortex_scalar::Scalar;

use crate::{FSSTArray, FSSTEncoding};

impl ArrayCompute for FSSTArray {
fn compare(&self) -> Option<&dyn CompareFn> {
impl ArrayCompute for FSSTArray {}

impl ComputeVTable for FSSTEncoding {
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}
}

impl ComputeVTable for FSSTEncoding {
fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down
27 changes: 16 additions & 11 deletions encodings/runend/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,29 @@ use vortex_array::compute::{compare, CompareFn, Operator};
use vortex_array::{ArrayData, ArrayLen, IntoArrayData};
use vortex_error::VortexResult;

use crate::RunEndArray;
use crate::{RunEndArray, RunEndEncoding};

impl CompareFn for RunEndArray {
fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult<Option<ArrayData>> {
impl CompareFn<RunEndArray> for RunEndEncoding {
fn compare(
&self,
lhs: &RunEndArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
// If the RHS is constant, then we just need to compare against our encoded values.
if let Some(const_scalar) = other.as_constant() {
if let Some(const_scalar) = rhs.as_constant() {
return compare(
self.values(),
ConstantArray::new(const_scalar, self.values().len()),
lhs.values(),
ConstantArray::new(const_scalar, lhs.values().len()),
operator,
)
.and_then(|values| {
Self::with_offset_and_length(
self.ends(),
RunEndArray::with_offset_and_length(
lhs.ends(),
values,
self.validity().into_nullable(),
self.offset(),
self.len(),
lhs.validity().into_nullable(),
lhs.offset(),
lhs.len(),
)
})
.map(|a| a.into_array())
Expand Down
8 changes: 4 additions & 4 deletions encodings/runend/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ use vortex_scalar::{Scalar, ScalarValue};

use crate::{RunEndArray, RunEndEncoding};

impl ArrayCompute for RunEndArray {
fn compare(&self) -> Option<&dyn CompareFn> {
impl ArrayCompute for RunEndArray {}

impl ComputeVTable for RunEndEncoding {
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}
}

impl ComputeVTable for RunEndEncoding {
fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down
25 changes: 15 additions & 10 deletions vortex-array/src/array/chunked/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ mod scalar_at;
mod slice;
mod take;

impl ArrayCompute for ChunkedArray {
fn compare(&self) -> Option<&dyn CompareFn> {
Some(self)
}
}
impl ArrayCompute for ChunkedArray {}

impl ComputeVTable for ChunkedEncoding {
fn cast_fn(&self) -> Option<&dyn CastFn<ArrayData>> {
Some(self)
}

fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}

fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down Expand Up @@ -56,13 +56,18 @@ impl CastFn<ChunkedArray> for ChunkedEncoding {
}
}

impl CompareFn for ChunkedArray {
fn compare(&self, array: &ArrayData, operator: Operator) -> VortexResult<Option<ArrayData>> {
impl CompareFn<ChunkedArray> for ChunkedEncoding {
fn compare(
&self,
lhs: &ChunkedArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
let mut idx = 0;
let mut compare_chunks = Vec::with_capacity(self.nchunks());
let mut compare_chunks = Vec::with_capacity(lhs.nchunks());

for chunk in self.chunks() {
let sliced = slice(array, idx, idx + chunk.len())?;
for chunk in lhs.chunks() {
let sliced = slice(rhs, idx, idx + chunk.len())?;
let cmp_result = compare(&chunk, &sliced, operator)?;
compare_chunks.push(cmp_result);

Expand Down
19 changes: 12 additions & 7 deletions vortex-array/src/array/constant/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use vortex_error::VortexResult;

use crate::array::ConstantArray;
use crate::array::{ConstantArray, ConstantEncoding};
use crate::compute::{scalar_cmp, CompareFn, Operator};
use crate::{ArrayData, ArrayLen, IntoArrayData};

impl CompareFn for ConstantArray {
fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult<Option<ArrayData>> {
impl CompareFn<ConstantArray> for ConstantEncoding {
fn compare(
&self,
lhs: &ConstantArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
// We only support comparing a constant array to another constant array.
// For all other encodings, we assume the constant is on the RHS.
if let Some(const_scalar) = other.as_constant() {
let lhs = self.owned_scalar();
let scalar = scalar_cmp(&lhs, &const_scalar, operator);
return Ok(Some(ConstantArray::new(scalar, self.len()).into_array()));
if let Some(const_scalar) = rhs.as_constant() {
let lhs_scalar = lhs.owned_scalar();
let scalar = scalar_cmp(&lhs_scalar, &const_scalar, operator);
return Ok(Some(ConstantArray::new(scalar, lhs.len()).into_array()));
}

Ok(None)
Expand Down
10 changes: 5 additions & 5 deletions vortex-array/src/array/constant/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ use crate::compute::{
};
use crate::{ArrayData, IntoArrayData};

impl ArrayCompute for ConstantArray {
fn compare(&self) -> Option<&dyn CompareFn> {
Some(self)
}
}
impl ArrayCompute for ConstantArray {}

impl ComputeVTable for ConstantEncoding {
fn binary_boolean_fn(
Expand All @@ -31,6 +27,10 @@ impl ComputeVTable for ConstantEncoding {
(lhs.is_constant() && rhs.is_constant()).then_some(self)
}

fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}

fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down
23 changes: 14 additions & 9 deletions vortex-array/src/array/extension/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,28 @@ use crate::compute::{compare, CompareFn, Operator};
use crate::encoding::EncodingVTable;
use crate::{ArrayDType, ArrayData, ArrayLen};

impl CompareFn for ExtensionArray {
fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult<Option<ArrayData>> {
impl CompareFn<ExtensionArray> for ExtensionEncoding {
fn compare(
&self,
lhs: &ExtensionArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
// If the RHS is a constant, we can extract the storage scalar.
if let Some(const_ext) = other.as_constant() {
if let Some(const_ext) = rhs.as_constant() {
let scalar_ext = ExtScalar::try_new(const_ext.dtype(), const_ext.value())?;
let storage_scalar = ConstantArray::new(
Scalar::new(self.storage().dtype().clone(), scalar_ext.value().clone()),
self.len(),
Scalar::new(lhs.storage().dtype().clone(), scalar_ext.value().clone()),
lhs.len(),
);

return compare(self.storage(), storage_scalar, operator).map(Some);
return compare(lhs.storage(), storage_scalar, operator).map(Some);
}

// If the RHS is an extension array matching ours, we can extract the storage.
if other.is_encoding(ExtensionEncoding.id()) {
let rhs_ext = ExtensionArray::try_from(other.clone())?;
return compare(self.storage(), rhs_ext.storage(), operator).map(Some);
if rhs.is_encoding(ExtensionEncoding.id()) {
let rhs_ext = ExtensionArray::try_from(rhs.clone())?;
return compare(lhs.storage(), rhs_ext.storage(), operator).map(Some);
}

// Otherwise, we need the RHS to handle this comparison.
Expand Down
10 changes: 5 additions & 5 deletions vortex-array/src/array/extension/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@ use crate::compute::{
use crate::variants::ExtensionArrayTrait;
use crate::{ArrayData, IntoArrayData};

impl ArrayCompute for ExtensionArray {
fn compare(&self) -> Option<&dyn CompareFn> {
Some(self)
}
}
impl ArrayCompute for ExtensionArray {}

impl ComputeVTable for ExtensionEncoding {
fn cast_fn(&self) -> Option<&dyn CastFn<ArrayData>> {
Expand All @@ -26,6 +22,10 @@ impl ComputeVTable for ExtensionEncoding {
None
}

fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}

fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<ArrayData>> {
Some(self)
}
Expand Down
Loading

0 comments on commit df11488

Please sign in to comment.