Skip to content

Commit

Permalink
Make BinaryBooleanFn consistent with CompareFn (#1488)
Browse files Browse the repository at this point in the history
So now the implementation returns `Result<Option<ArrayData>>` and the
dispatch logic is almost identical.
  • Loading branch information
gatesn authored Nov 27, 2024
1 parent 7507c99 commit dbeb724
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 28 deletions.
6 changes: 1 addition & 5 deletions vortex-array/src/array/bool/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@ mod slice;
mod take;

impl ComputeVTable for BoolEncoding {
fn binary_boolean_fn(
&self,
_lhs: &ArrayData,
_rhs: &ArrayData,
) -> Option<&dyn BinaryBooleanFn<ArrayData>> {
fn binary_boolean_fn(&self) -> Option<&dyn BinaryBooleanFn<ArrayData>> {
// We only implement this when other is a constant value, otherwise we fall back to the
// default implementation that canonicalizes to Arrow.
// TODO(ngates): implement this for constants.
Expand Down
10 changes: 8 additions & 2 deletions vortex-array/src/array/constant/compute/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ impl BinaryBooleanFn<ConstantArray> for ConstantEncoding {
lhs: &ConstantArray,
rhs: &ArrayData,
op: BinaryOperator,
) -> VortexResult<ArrayData> {
) -> VortexResult<Option<ArrayData>> {
// We only implement this for constant <-> constant arrays, otherwise we allow fall back
// to the Arrow implementation.
if !rhs.is_constant() {
return Ok(None);
}

let length = lhs.len();
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
let lhs = lhs.scalar().as_bool().value();
Expand All @@ -35,7 +41,7 @@ impl BinaryBooleanFn<ConstantArray> for ConstantEncoding {
.map(|b| Scalar::bool(b, nullable.into()))
.unwrap_or_else(|| Scalar::null(DType::Bool(nullable.into())));

Ok(ConstantArray::new(scalar, length).into_array())
Ok(Some(ConstantArray::new(scalar, length).into_array()))
}
}

Expand Down
10 changes: 2 additions & 8 deletions vortex-array/src/array/constant/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@ use crate::compute::{
use crate::{ArrayData, IntoArrayData};

impl ComputeVTable for ConstantEncoding {
fn binary_boolean_fn(
&self,
lhs: &ArrayData,
rhs: &ArrayData,
) -> Option<&dyn BinaryBooleanFn<ArrayData>> {
// We only need to deal with this if both sides are constant, otherwise other arrays
// will have handled the RHS being constant.
(lhs.is_constant() && rhs.is_constant()).then_some(self)
fn binary_boolean_fn(&self) -> Option<&dyn BinaryBooleanFn<ArrayData>> {
Some(self)
}

fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Expand Down
26 changes: 20 additions & 6 deletions vortex-array/src/compute/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub trait BinaryBooleanFn<Array> {
array: &Array,
other: &ArrayData,
op: BinaryOperator,
) -> VortexResult<ArrayData>;
) -> VortexResult<Option<ArrayData>>;
}

impl<E: Encoding> BinaryBooleanFn<ArrayData> for E
Expand All @@ -38,7 +38,7 @@ where
lhs: &ArrayData,
rhs: &ArrayData,
op: BinaryOperator,
) -> VortexResult<ArrayData> {
) -> VortexResult<Option<ArrayData>> {
let array_ref = <&E::Array>::try_from(lhs)?;
let encoding = lhs
.encoding()
Expand Down Expand Up @@ -92,9 +92,18 @@ fn binary_boolean(lhs: &ArrayData, rhs: &ArrayData, op: BinaryOperator) -> Vorte
return binary_boolean(rhs, lhs, op);
}

// If the RHS is constant and the LHS is Arrow, we can't do any better than arrow_compare.
if lhs.is_arrow() && rhs.is_constant() {
return arrow_boolean(lhs.clone(), rhs.clone(), op);
}

// Check if either LHS or RHS supports the operation directly.
if let Some(f) = lhs.encoding().binary_boolean_fn(lhs, rhs) {
return f.binary_boolean(lhs, rhs, op);
if let Some(result) = lhs
.encoding()
.binary_boolean_fn()
.and_then(|f| f.binary_boolean(lhs, rhs, op).transpose())
{
return result;
} else {
log::debug!(
"No boolean implementation found for LHS {}, RHS {}, and operator {:?}",
Expand All @@ -103,8 +112,13 @@ fn binary_boolean(lhs: &ArrayData, rhs: &ArrayData, op: BinaryOperator) -> Vorte
op,
);
}
if let Some(f) = rhs.encoding().binary_boolean_fn(rhs, lhs) {
return f.binary_boolean(rhs, lhs, op);

if let Some(result) = rhs
.encoding()
.binary_boolean_fn()
.and_then(|f| f.binary_boolean(rhs, lhs, op).transpose())
{
return result;
} else {
log::debug!(
"No boolean implementation found for LHS {}, RHS {}, and operator {:?}",
Expand Down
2 changes: 0 additions & 2 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ pub fn compare(
if left.len() != right.len() {
vortex_bail!("Compare operations only support arrays of the same length");
}

// TODO(adamg): This is a placeholder until we figure out type coercion and casting
if !left.dtype().eq_ignore_nullability(right.dtype()) {
vortex_bail!("Compare operations only support arrays of the same type");
}
Expand Down
6 changes: 1 addition & 5 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ pub trait ComputeVTable {
/// Implementation of binary boolean logic operations.
///
/// See: [BinaryBooleanFn].
fn binary_boolean_fn(
&self,
_lhs: &ArrayData,
_rhs: &ArrayData,
) -> Option<&dyn BinaryBooleanFn<ArrayData>> {
fn binary_boolean_fn(&self) -> Option<&dyn BinaryBooleanFn<ArrayData>> {
None
}

Expand Down

0 comments on commit dbeb724

Please sign in to comment.