Skip to content

Commit

Permalink
Some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn committed Dec 18, 2024
1 parent 1340d8d commit 65684ab
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 34 deletions.
66 changes: 35 additions & 31 deletions encodings/fsst/src/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use fsst::Symbol;
use vortex_array::array::ConstantArray;
use vortex_array::compute::{compare, CompareFn, Operator};
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayVariant, ToArrayData};
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};
use vortex_buffer::Buffer;
use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_dtype::{DType, Nullability};
use vortex_error::{VortexExpect, VortexResult};
use vortex_scalar::Scalar;

use crate::{FSSTArray, FSSTEncoding};

Expand All @@ -16,10 +17,16 @@ impl CompareFn<FSSTArray> for FSSTEncoding {
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
match (rhs.as_constant(), operator) {
// TODO(ngates): implement short-circuit comparisons for other operators.
(Some(constant_array), Operator::Eq | Operator::NotEq) => compare_fsst_constant(
(Some(constant), _) if constant.is_null() => {
// All comparisons to null must return null
Ok(Some(
ConstantArray::new(Scalar::null(DType::Bool(Nullability::Nullable)), lhs.len())
.into_array(),
))
}
(Some(constant), Operator::Eq | Operator::NotEq) => compare_fsst_constant(
lhs,
&ConstantArray::new(constant_array, lhs.len()),
&ConstantArray::new(constant, lhs.len()),
operator == Operator::Eq,
)
.map(Some),
Expand Down Expand Up @@ -49,34 +56,31 @@ fn compare_fsst_constant(
let compressor = compressor.build();

let encoded_scalar = match left.dtype() {
DType::Utf8(_) => right
.scalar()
.as_utf8()
.value()
.map(|scalar| Buffer::from(compressor.compress(scalar.as_bytes()))),
DType::Binary(_) => right
.scalar()
.as_binary()
.value()
.map(|scalar| Buffer::from(compressor.compress(scalar.as_slice()))),
DType::Utf8(_) => {
let value = right
.scalar()
.as_utf8()
.value()
.vortex_expect("Expected non-null scalar");
Buffer::from(compressor.compress(value.as_bytes()))
}
DType::Binary(_) => {
let value = right
.scalar()
.as_binary()
.value()
.vortex_expect("Expected non-null scalar");
Buffer::from(compressor.compress(value.as_slice()))
}
_ => unreachable!("FSSTArray can only have string or binary data type"),
};

match encoded_scalar {
None => {
// Eq and NotEq on null values yield nulls, per the Arrow behavior.
Ok(right.to_array())
}
Some(encoded_scalar) => {
let rhs = ConstantArray::new(encoded_scalar, left.len());

compare(
left.codes(),
rhs,
if equal { Operator::Eq } else { Operator::NotEq },
)
}
}
let rhs = ConstantArray::new(encoded_scalar, left.len());
compare(
left.codes(),
rhs,
if equal { Operator::Eq } else { Operator::NotEq },
)
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion pyvortex/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ pub fn scalar_helper(dtype: DType, value: &Bound<'_, PyAny>) -> PyResult<Scalar>
.iter()
.map(|element| scalar_helper(element_type.as_ref().clone(), element))
.collect::<PyResult<Vec<_>>>()?;
Ok(Scalar::list(element_type, values))
Ok(Scalar::list(element_type, values, Nullability::Nullable))
}
DType::Extension(..) => todo!(),
}
Expand Down
3 changes: 2 additions & 1 deletion vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ pub(crate) fn arrow_compare(
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<ArrayData> {
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
let lhs = Datum::try_from(lhs.clone())?;
let rhs = Datum::try_from(rhs.clone())?;

Expand All @@ -197,7 +198,7 @@ pub(crate) fn arrow_compare(
Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
};

Ok(ArrayData::from_arrow(&array, true))
Ok(ArrayData::from_arrow(&array, nullable))
}

pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/compute/scalar_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ pub fn scalar_at(array: impl AsRef<ArrayData>, index: usize) -> VortexResult<Sca
.unwrap_or_else(|| Err(vortex_err!(NotImplemented: "scalar_at", array.encoding().id())))?;

debug_assert_eq!(
array.dtype(),
scalar.dtype(),
array.dtype(),
"ScalarAt dtype mismatch {}",
array.encoding().id()
);
Expand Down

0 comments on commit 65684ab

Please sign in to comment.