From 65684ab43cdbcc67c0a62e128c9b947b772064ed Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 18 Dec 2024 18:16:10 +0000 Subject: [PATCH] Some fixes --- encodings/fsst/src/compute/compare.rs | 66 ++++++++++++++------------- pyvortex/src/expr.rs | 2 +- vortex-array/src/compute/compare.rs | 3 +- vortex-array/src/compute/scalar_at.rs | 2 +- 4 files changed, 39 insertions(+), 34 deletions(-) diff --git a/encodings/fsst/src/compute/compare.rs b/encodings/fsst/src/compute/compare.rs index 2375edfcf9..dbdcff4324 100644 --- a/encodings/fsst/src/compute/compare.rs +++ b/encodings/fsst/src/compute/compare.rs @@ -1,10 +1,11 @@ use fsst::Symbol; use vortex_array::array::ConstantArray; use vortex_array::compute::{compare, CompareFn, Operator}; -use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayVariant, ToArrayData}; +use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; use vortex_buffer::Buffer; -use vortex_dtype::DType; -use vortex_error::VortexResult; +use vortex_dtype::{DType, Nullability}; +use vortex_error::{VortexExpect, VortexResult}; +use vortex_scalar::Scalar; use crate::{FSSTArray, FSSTEncoding}; @@ -16,10 +17,16 @@ impl CompareFn for FSSTEncoding { operator: Operator, ) -> VortexResult> { match (rhs.as_constant(), operator) { - // TODO(ngates): implement short-circuit comparisons for other operators. - (Some(constant_array), Operator::Eq | Operator::NotEq) => compare_fsst_constant( + (Some(constant), _) if constant.is_null() => { + // All comparisons to null must return null + Ok(Some( + ConstantArray::new(Scalar::null(DType::Bool(Nullability::Nullable)), lhs.len()) + .into_array(), + )) + } + (Some(constant), Operator::Eq | Operator::NotEq) => compare_fsst_constant( lhs, - &ConstantArray::new(constant_array, lhs.len()), + &ConstantArray::new(constant, lhs.len()), operator == Operator::Eq, ) .map(Some), @@ -49,34 +56,31 @@ fn compare_fsst_constant( let compressor = compressor.build(); let encoded_scalar = match left.dtype() { - DType::Utf8(_) => right - .scalar() - .as_utf8() - .value() - .map(|scalar| Buffer::from(compressor.compress(scalar.as_bytes()))), - DType::Binary(_) => right - .scalar() - .as_binary() - .value() - .map(|scalar| Buffer::from(compressor.compress(scalar.as_slice()))), + DType::Utf8(_) => { + let value = right + .scalar() + .as_utf8() + .value() + .vortex_expect("Expected non-null scalar"); + Buffer::from(compressor.compress(value.as_bytes())) + } + DType::Binary(_) => { + let value = right + .scalar() + .as_binary() + .value() + .vortex_expect("Expected non-null scalar"); + Buffer::from(compressor.compress(value.as_slice())) + } _ => unreachable!("FSSTArray can only have string or binary data type"), }; - match encoded_scalar { - None => { - // Eq and NotEq on null values yield nulls, per the Arrow behavior. - Ok(right.to_array()) - } - Some(encoded_scalar) => { - let rhs = ConstantArray::new(encoded_scalar, left.len()); - - compare( - left.codes(), - rhs, - if equal { Operator::Eq } else { Operator::NotEq }, - ) - } - } + let rhs = ConstantArray::new(encoded_scalar, left.len()); + compare( + left.codes(), + rhs, + if equal { Operator::Eq } else { Operator::NotEq }, + ) } #[cfg(test)] diff --git a/pyvortex/src/expr.rs b/pyvortex/src/expr.rs index 0f9b9972b9..2e1456d048 100644 --- a/pyvortex/src/expr.rs +++ b/pyvortex/src/expr.rs @@ -304,7 +304,7 @@ pub fn scalar_helper(dtype: DType, value: &Bound<'_, PyAny>) -> PyResult .iter() .map(|element| scalar_helper(element_type.as_ref().clone(), element)) .collect::>>()?; - Ok(Scalar::list(element_type, values)) + Ok(Scalar::list(element_type, values, Nullability::Nullable)) } DType::Extension(..) => todo!(), } diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index 94aa673d13..affacec13f 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -185,6 +185,7 @@ pub(crate) fn arrow_compare( rhs: &ArrayData, operator: Operator, ) -> VortexResult { + let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); let lhs = Datum::try_from(lhs.clone())?; let rhs = Datum::try_from(rhs.clone())?; @@ -197,7 +198,7 @@ pub(crate) fn arrow_compare( Operator::Lte => cmp::lt_eq(&lhs, &rhs)?, }; - Ok(ArrayData::from_arrow(&array, true)) + Ok(ArrayData::from_arrow(&array, nullable)) } pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar { diff --git a/vortex-array/src/compute/scalar_at.rs b/vortex-array/src/compute/scalar_at.rs index e9e3aae58b..428bf6df1f 100644 --- a/vortex-array/src/compute/scalar_at.rs +++ b/vortex-array/src/compute/scalar_at.rs @@ -45,8 +45,8 @@ pub fn scalar_at(array: impl AsRef, index: usize) -> VortexResult