diff --git a/bench-vortex/src/bin/notimplemented.rs b/bench-vortex/src/bin/notimplemented.rs index 6342f53446..69fb5671a0 100644 --- a/bench-vortex/src/bin/notimplemented.rs +++ b/bench-vortex/src/bin/notimplemented.rs @@ -187,7 +187,7 @@ fn compute_funcs(encodings: &[ArrayData]) { for arr in encodings { let mut impls = vec![Cell::new(arr.encoding().id().as_ref())]; impls.push(bool_to_cell(arr.encoding().cast_fn().is_some())); - impls.push(bool_to_cell(arr.with_dyn(|a| a.compare().is_some()))); + impls.push(bool_to_cell(arr.encoding().compare_fn().is_some())); impls.push(bool_to_cell(arr.encoding().fill_forward_fn().is_some())); impls.push(bool_to_cell(arr.encoding().filter_fn().is_some())); impls.push(bool_to_cell(arr.encoding().scalar_at_fn().is_some())); diff --git a/encodings/alp/src/alp/compute/compare.rs b/encodings/alp/src/alp/compute/compare.rs index c28c229882..5be5611539 100644 --- a/encodings/alp/src/alp/compute/compare.rs +++ b/encodings/alp/src/alp/compute/compare.rs @@ -5,18 +5,23 @@ use vortex_dtype::Nullability; use vortex_error::VortexResult; use vortex_scalar::{PValue, Scalar}; -use crate::{ALPArray, ALPFloat}; - -impl CompareFn for ALPArray { - fn compare(&self, array: &ArrayData, operator: Operator) -> VortexResult> { - if let Some(const_scalar) = array.as_constant() { +use crate::{ALPArray, ALPEncoding, ALPFloat}; + +impl CompareFn for ALPEncoding { + fn compare( + &self, + lhs: &ALPArray, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { + if let Some(const_scalar) = rhs.as_constant() { let pvalue = const_scalar.value().as_pvalue()?; return match pvalue { - Some(PValue::F32(f)) => alp_scalar_compare(self, f, operator).map(Some), - Some(PValue::F64(f)) => alp_scalar_compare(self, f, operator).map(Some), + Some(PValue::F32(f)) => alp_scalar_compare(lhs, f, operator).map(Some), + Some(PValue::F64(f)) => alp_scalar_compare(lhs, f, operator).map(Some), Some(_) | None => Ok(Some( - ConstantArray::new(Scalar::bool(false, Nullability::Nullable), self.len()) + ConstantArray::new(Scalar::bool(false, Nullability::Nullable), lhs.len()) .into_array(), )), }; diff --git a/encodings/alp/src/alp/compute/mod.rs b/encodings/alp/src/alp/compute/mod.rs index d56bf79df1..82525315c4 100644 --- a/encodings/alp/src/alp/compute/mod.rs +++ b/encodings/alp/src/alp/compute/mod.rs @@ -12,13 +12,13 @@ use vortex_scalar::Scalar; use crate::{match_each_alp_float_ptype, ALPArray, ALPEncoding, ALPFloat}; -impl ArrayCompute for ALPArray { - fn compare(&self) -> Option<&dyn CompareFn> { +impl ArrayCompute for ALPArray {} + +impl ComputeVTable for ALPEncoding { + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } -} -impl ComputeVTable for ALPEncoding { fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } diff --git a/encodings/dict/src/compute/compare.rs b/encodings/dict/src/compute/compare.rs index 572d730c34..042bbce4e1 100644 --- a/encodings/dict/src/compute/compare.rs +++ b/encodings/dict/src/compute/compare.rs @@ -3,19 +3,24 @@ use vortex_array::compute::{compare, CompareFn, Operator}; use vortex_array::{ArrayData, IntoArrayData}; use vortex_error::VortexResult; -use crate::DictArray; +use crate::{DictArray, DictEncoding}; -impl CompareFn for DictArray { - fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult> { +impl CompareFn for DictEncoding { + fn compare( + &self, + lhs: &DictArray, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { // If the RHS is constant, then we just need to compare against our encoded values. - if let Some(const_scalar) = other.as_constant() { + if let Some(const_scalar) = rhs.as_constant() { // Ensure the other is the same length as the dictionary return compare( - self.values(), - ConstantArray::new(const_scalar, self.values().len()), + lhs.values(), + ConstantArray::new(const_scalar, lhs.values().len()), operator, ) - .and_then(|values| Self::try_new(self.codes(), values)) + .and_then(|values| DictArray::try_new(lhs.codes(), values)) .map(|a| a.into_array()) .map(Some); } diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index ef9bdb0b66..b6c1f63bda 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -11,13 +11,13 @@ use vortex_scalar::Scalar; use crate::{DictArray, DictEncoding}; -impl ArrayCompute for DictArray { - fn compare(&self) -> Option<&dyn CompareFn> { +impl ArrayCompute for DictArray {} + +impl ComputeVTable for DictEncoding { + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } -} -impl ComputeVTable for DictEncoding { fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } diff --git a/encodings/fsst/src/compute/compare.rs b/encodings/fsst/src/compute/compare.rs index 6cb0ab8d88..16c94cea98 100644 --- a/encodings/fsst/src/compute/compare.rs +++ b/encodings/fsst/src/compute/compare.rs @@ -6,15 +6,20 @@ use vortex_buffer::Buffer; use vortex_dtype::DType; use vortex_error::VortexResult; -use crate::FSSTArray; - -impl CompareFn for FSSTArray { - fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult> { - match (other.as_constant(), operator) { +use crate::{FSSTArray, FSSTEncoding}; + +impl CompareFn for FSSTEncoding { + fn compare( + &self, + lhs: &FSSTArray, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { + match (rhs.as_constant(), operator) { // TODO(ngates): implement short-circuit comparisons for other operators. (Some(constant_array), Operator::Eq | Operator::NotEq) => compare_fsst_constant( - self, - &ConstantArray::new(constant_array, self.len()), + lhs, + &ConstantArray::new(constant_array, lhs.len()), operator == Operator::Eq, ) .map(Some), diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index f9a40bac9d..918fc36c32 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -13,13 +13,13 @@ use vortex_scalar::Scalar; use crate::{FSSTArray, FSSTEncoding}; -impl ArrayCompute for FSSTArray { - fn compare(&self) -> Option<&dyn CompareFn> { +impl ArrayCompute for FSSTArray {} + +impl ComputeVTable for FSSTEncoding { + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } -} -impl ComputeVTable for FSSTEncoding { fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } diff --git a/encodings/runend/src/compute/compare.rs b/encodings/runend/src/compute/compare.rs index ab5ff308a4..77388216ff 100644 --- a/encodings/runend/src/compute/compare.rs +++ b/encodings/runend/src/compute/compare.rs @@ -3,24 +3,29 @@ use vortex_array::compute::{compare, CompareFn, Operator}; use vortex_array::{ArrayData, ArrayLen, IntoArrayData}; use vortex_error::VortexResult; -use crate::RunEndArray; +use crate::{RunEndArray, RunEndEncoding}; -impl CompareFn for RunEndArray { - fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult> { +impl CompareFn for RunEndEncoding { + fn compare( + &self, + lhs: &RunEndArray, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { // If the RHS is constant, then we just need to compare against our encoded values. - if let Some(const_scalar) = other.as_constant() { + if let Some(const_scalar) = rhs.as_constant() { return compare( - self.values(), - ConstantArray::new(const_scalar, self.values().len()), + lhs.values(), + ConstantArray::new(const_scalar, lhs.values().len()), operator, ) .and_then(|values| { - Self::with_offset_and_length( - self.ends(), + RunEndArray::with_offset_and_length( + lhs.ends(), values, - self.validity().into_nullable(), - self.offset(), - self.len(), + lhs.validity().into_nullable(), + lhs.offset(), + lhs.len(), ) }) .map(|a| a.into_array()) diff --git a/encodings/runend/src/compute/mod.rs b/encodings/runend/src/compute/mod.rs index 2485337e6b..1ac3b20d4b 100644 --- a/encodings/runend/src/compute/mod.rs +++ b/encodings/runend/src/compute/mod.rs @@ -19,13 +19,13 @@ use vortex_scalar::{Scalar, ScalarValue}; use crate::{RunEndArray, RunEndEncoding}; -impl ArrayCompute for RunEndArray { - fn compare(&self) -> Option<&dyn CompareFn> { +impl ArrayCompute for RunEndArray {} + +impl ComputeVTable for RunEndEncoding { + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } -} -impl ComputeVTable for RunEndEncoding { fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } diff --git a/vortex-array/src/array/chunked/compute/mod.rs b/vortex-array/src/array/chunked/compute/mod.rs index 07ee90f9d0..c506786761 100644 --- a/vortex-array/src/array/chunked/compute/mod.rs +++ b/vortex-array/src/array/chunked/compute/mod.rs @@ -14,17 +14,17 @@ mod scalar_at; mod slice; mod take; -impl ArrayCompute for ChunkedArray { - fn compare(&self) -> Option<&dyn CompareFn> { - Some(self) - } -} +impl ArrayCompute for ChunkedArray {} impl ComputeVTable for ChunkedEncoding { fn cast_fn(&self) -> Option<&dyn CastFn> { Some(self) } + fn compare_fn(&self) -> Option<&dyn CompareFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } @@ -56,13 +56,18 @@ impl CastFn for ChunkedEncoding { } } -impl CompareFn for ChunkedArray { - fn compare(&self, array: &ArrayData, operator: Operator) -> VortexResult> { +impl CompareFn for ChunkedEncoding { + fn compare( + &self, + lhs: &ChunkedArray, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { let mut idx = 0; - let mut compare_chunks = Vec::with_capacity(self.nchunks()); + let mut compare_chunks = Vec::with_capacity(lhs.nchunks()); - for chunk in self.chunks() { - let sliced = slice(array, idx, idx + chunk.len())?; + for chunk in lhs.chunks() { + let sliced = slice(rhs, idx, idx + chunk.len())?; let cmp_result = compare(&chunk, &sliced, operator)?; compare_chunks.push(cmp_result); diff --git a/vortex-array/src/array/constant/compute/compare.rs b/vortex-array/src/array/constant/compute/compare.rs index df3a3bb28c..d267b44699 100644 --- a/vortex-array/src/array/constant/compute/compare.rs +++ b/vortex-array/src/array/constant/compute/compare.rs @@ -1,17 +1,22 @@ use vortex_error::VortexResult; -use crate::array::ConstantArray; +use crate::array::{ConstantArray, ConstantEncoding}; use crate::compute::{scalar_cmp, CompareFn, Operator}; use crate::{ArrayData, ArrayLen, IntoArrayData}; -impl CompareFn for ConstantArray { - fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult> { +impl CompareFn for ConstantEncoding { + fn compare( + &self, + lhs: &ConstantArray, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { // We only support comparing a constant array to another constant array. // For all other encodings, we assume the constant is on the RHS. - if let Some(const_scalar) = other.as_constant() { - let lhs = self.owned_scalar(); - let scalar = scalar_cmp(&lhs, &const_scalar, operator); - return Ok(Some(ConstantArray::new(scalar, self.len()).into_array())); + if let Some(const_scalar) = rhs.as_constant() { + let lhs_scalar = lhs.owned_scalar(); + let scalar = scalar_cmp(&lhs_scalar, &const_scalar, operator); + return Ok(Some(ConstantArray::new(scalar, lhs.len()).into_array())); } Ok(None) diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index aa18a9a18a..2508363881 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -14,11 +14,7 @@ use crate::compute::{ }; use crate::{ArrayData, IntoArrayData}; -impl ArrayCompute for ConstantArray { - fn compare(&self) -> Option<&dyn CompareFn> { - Some(self) - } -} +impl ArrayCompute for ConstantArray {} impl ComputeVTable for ConstantEncoding { fn binary_boolean_fn( @@ -31,6 +27,10 @@ impl ComputeVTable for ConstantEncoding { (lhs.is_constant() && rhs.is_constant()).then_some(self) } + fn compare_fn(&self) -> Option<&dyn CompareFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } diff --git a/vortex-array/src/array/extension/compute/compare.rs b/vortex-array/src/array/extension/compute/compare.rs index 911dfbc9a7..571d8ec118 100644 --- a/vortex-array/src/array/extension/compute/compare.rs +++ b/vortex-array/src/array/extension/compute/compare.rs @@ -6,23 +6,28 @@ use crate::compute::{compare, CompareFn, Operator}; use crate::encoding::EncodingVTable; use crate::{ArrayDType, ArrayData, ArrayLen}; -impl CompareFn for ExtensionArray { - fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult> { +impl CompareFn for ExtensionEncoding { + fn compare( + &self, + lhs: &ExtensionArray, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { // If the RHS is a constant, we can extract the storage scalar. - if let Some(const_ext) = other.as_constant() { + if let Some(const_ext) = rhs.as_constant() { let scalar_ext = ExtScalar::try_new(const_ext.dtype(), const_ext.value())?; let storage_scalar = ConstantArray::new( - Scalar::new(self.storage().dtype().clone(), scalar_ext.value().clone()), - self.len(), + Scalar::new(lhs.storage().dtype().clone(), scalar_ext.value().clone()), + lhs.len(), ); - return compare(self.storage(), storage_scalar, operator).map(Some); + return compare(lhs.storage(), storage_scalar, operator).map(Some); } // If the RHS is an extension array matching ours, we can extract the storage. - if other.is_encoding(ExtensionEncoding.id()) { - let rhs_ext = ExtensionArray::try_from(other.clone())?; - return compare(self.storage(), rhs_ext.storage(), operator).map(Some); + if rhs.is_encoding(ExtensionEncoding.id()) { + let rhs_ext = ExtensionArray::try_from(rhs.clone())?; + return compare(lhs.storage(), rhs_ext.storage(), operator).map(Some); } // Otherwise, we need the RHS to handle this comparison. diff --git a/vortex-array/src/array/extension/compute/mod.rs b/vortex-array/src/array/extension/compute/mod.rs index e3786efe5f..eb1d500bbc 100644 --- a/vortex-array/src/array/extension/compute/mod.rs +++ b/vortex-array/src/array/extension/compute/mod.rs @@ -12,11 +12,7 @@ use crate::compute::{ use crate::variants::ExtensionArrayTrait; use crate::{ArrayData, IntoArrayData}; -impl ArrayCompute for ExtensionArray { - fn compare(&self) -> Option<&dyn CompareFn> { - Some(self) - } -} +impl ArrayCompute for ExtensionArray {} impl ComputeVTable for ExtensionEncoding { fn cast_fn(&self) -> Option<&dyn CastFn> { @@ -26,6 +22,10 @@ impl ComputeVTable for ExtensionEncoding { None } + fn compare_fn(&self) -> Option<&dyn CompareFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index 8b481c3c46..7fae112c6b 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -3,10 +3,11 @@ use std::fmt::{Display, Formatter}; use arrow_ord::cmp; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_scalar::Scalar; use crate::arrow::{Datum, FromArrowArray}; +use crate::encoding::Encoding; use crate::{ArrayDType, ArrayData}; #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd)] @@ -69,10 +70,36 @@ impl Operator { } } -pub trait CompareFn { +pub trait CompareFn { /// Compares two arrays and returns a new boolean array with the result of the comparison. /// Or, returns None if comparison is not supported for these arrays. - fn compare(&self, other: &ArrayData, operator: Operator) -> VortexResult>; + fn compare( + &self, + lhs: &Array, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult>; +} + +impl CompareFn for E +where + E: CompareFn, + for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, +{ + fn compare( + &self, + lhs: &ArrayData, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { + let lhs_ref = <&E::Array>::try_from(lhs)?; + let encoding = lhs + .encoding() + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("Mismatched encoding"))?; + CompareFn::compare(encoding, lhs_ref, rhs, operator) + } } pub fn compare( @@ -102,10 +129,11 @@ pub fn compare( return arrow_compare(left, right, operator); } - if let Some(result) = left.with_dyn(|lhs| { - lhs.compare() - .and_then(|f| f.compare(right, operator).transpose()) - }) { + if let Some(result) = left + .encoding() + .compare_fn() + .and_then(|f| f.compare(left, right, operator).transpose()) + { return result; } else { log::debug!( @@ -116,10 +144,11 @@ pub fn compare( ); } - if let Some(result) = right.with_dyn(|rhs| { - rhs.compare() - .and_then(|f| f.compare(left, operator.swap()).transpose()) - }) { + if let Some(result) = right + .encoding() + .compare_fn() + .and_then(|f| f.compare(right, left, operator.swap()).transpose()) + { return result; } else { log::debug!( diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 308e171d91..5b6b619640 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -23,7 +23,6 @@ pub trait FilterFn { fn filter(&self, array: &Array, mask: FilterMask) -> VortexResult; } -// TODO(ngates): write a macro for dispatching array-specific compute over ArrayData. impl FilterFn for E where E: FilterFn, diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index e4d0a0085f..94150c190e 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -46,6 +46,13 @@ pub trait ComputeVTable { None } + /// Binary operator implementation for arrays against other arrays. + /// + ///See: [CompareFn]. + fn compare_fn(&self) -> Option<&dyn CompareFn> { + None + } + /// Array function that returns new arrays a non-null value is repeated across runs of nulls. /// /// See: [FillForwardFn]. @@ -98,11 +105,4 @@ pub trait ComputeVTable { } /// Trait providing compute functions on top of Vortex arrays. -pub trait ArrayCompute { - /// Binary operator implementation for arrays against other arrays. - /// - ///See: [CompareFn]. - fn compare(&self) -> Option<&dyn CompareFn> { - None - } -} +pub trait ArrayCompute {}