Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add debug assertions to ComputeFn results #1716

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion vortex-array/src/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use arrow_array::ArrayRef;
use vortex_dtype::DType;
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, VortexError, VortexResult};
use vortex_scalar::{BinaryNumericOperator, Scalar};

Expand Down Expand Up @@ -117,13 +117,43 @@ pub fn binary_numeric(
// Check if LHS supports the operation directly.
if let Some(fun) = lhs.encoding().binary_numeric_fn() {
if let Some(result) = fun.binary_numeric(lhs, rhs, op)? {
debug_assert_eq!(
result.len(),
lhs.len(),
"Numeric operation length mismatch {}",
lhs.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Primitive(
PType::try_from(lhs.dtype())?,
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
),
"Numeric operation dtype mismatch {}",
lhs.encoding().id()
);
return Ok(result);
}
}

// Check if RHS supports the operation directly.
if let Some(fun) = rhs.encoding().binary_numeric_fn() {
if let Some(result) = fun.binary_numeric(rhs, lhs, op)? {
debug_assert_eq!(
result.len(),
lhs.len(),
"Numeric operation length mismatch {}",
rhs.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Primitive(
PType::try_from(lhs.dtype())?,
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
),
"Numeric operation dtype mismatch {}",
rhs.encoding().id()
);
return Ok(result);
}
}
Expand Down
31 changes: 29 additions & 2 deletions vortex-array/src/compute/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::ArrayRef;
use vortex_dtype::DType;
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};

use crate::arrow::FromArrowArray;
Expand Down Expand Up @@ -106,16 +107,42 @@ pub fn binary_boolean(
.encoding()
.binary_boolean_fn()
.and_then(|f| f.binary_boolean(lhs, rhs, op).transpose())
.transpose()?
{
return result;
debug_assert_eq!(
result.len(),
lhs.len(),
"Boolean operation length mismatch {}",
lhs.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
"Boolean operation dtype mismatch {}",
lhs.encoding().id()
);
return Ok(result);
}

if let Some(result) = rhs
.encoding()
.binary_boolean_fn()
.and_then(|f| f.binary_boolean(rhs, lhs, op).transpose())
.transpose()?
{
return result;
debug_assert_eq!(
result.len(),
lhs.len(),
"Boolean operation length mismatch {}",
rhs.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
"Boolean operation dtype mismatch {}",
rhs.encoding().id()
);
return Ok(result);
}

log::debug!(
Expand Down
19 changes: 19 additions & 0 deletions vortex-array/src/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ pub fn try_cast(array: impl AsRef<ArrayData>, dtype: &DType) -> VortexResult<Arr
return Ok(array.clone());
}

let casted = try_cast_impl(array, dtype)?;

debug_assert_eq!(
casted.len(),
array.len(),
"Cast length mismatch {}",
array.encoding().id()
);
debug_assert_eq!(
casted.dtype(),
dtype,
"Cast dtype mismatch {}",
array.encoding().id()
);

Ok(casted)
}

fn try_cast_impl(array: &ArrayData, dtype: &DType) -> VortexResult<ArrayData> {
// TODO(ngates): check for null_count if dtype is non-nullable
if let Some(f) = array.encoding().cast_fn() {
return f.cast(array, dtype);
Expand Down
30 changes: 28 additions & 2 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,42 @@ pub fn compare(
.encoding()
.compare_fn()
.and_then(|f| f.compare(left, right, operator).transpose())
.transpose()?
{
return result;
debug_assert_eq!(
result.len(),
left.len(),
"Compare length mismatch {}",
left.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into()),
"Compare dtype mismatch {}",
left.encoding().id()
);
return Ok(result);
}

if let Some(result) = right
.encoding()
.compare_fn()
.and_then(|f| f.compare(right, left, operator.swap()).transpose())
.transpose()?
{
return result;
debug_assert_eq!(
result.len(),
left.len(),
"Compare length mismatch {}",
right.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into()),
"Compare dtype mismatch {}",
right.encoding().id()
);
return Ok(result);
}

// Only log missing compare implementation if there's possibly a better one than Arrow,
Expand Down
20 changes: 18 additions & 2 deletions vortex-array/src/compute/fill_forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ pub fn fill_forward(array: impl AsRef<ArrayData>) -> VortexResult<ArrayData> {
if !array.dtype().is_nullable() {
return Ok(array.clone());
}
array

let filled = array
.encoding()
.fill_forward_fn()
.map(|f| f.fill_forward(array))
Expand All @@ -42,5 +43,20 @@ pub fn fill_forward(array: impl AsRef<ArrayData>) -> VortexResult<ArrayData> {
NotImplemented: "fill_forward",
array.encoding().id()
))
})
})?;

debug_assert_eq!(
filled.len(),
array.len(),
"FillForward length mismatch {}",
array.encoding().id()
);
debug_assert_eq!(
filled.dtype(),
array.dtype(),
"FillForward dtype mismatch {}",
array.encoding().id()
);

Ok(filled)
}
20 changes: 20 additions & 0 deletions vortex-array/src/compute/fill_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,26 @@ pub fn fill_null(array: impl AsRef<ArrayData>, fill_value: Scalar) -> VortexResu
vortex_bail!(MismatchedTypes: array.dtype(), fill_value.dtype())
}

let fill_value_nullability = fill_value.dtype().nullability();
let filled = fill_null_impl(array, fill_value)?;

debug_assert_eq!(
filled.len(),
array.len(),
"FillNull length mismatch {}",
array.encoding().id()
);
debug_assert_eq!(
filled.dtype(),
&array.dtype().with_nullability(fill_value_nullability),
"FillNull dtype mismatch {}",
array.encoding().id()
);

Ok(filled)
}

fn fill_null_impl(array: &ArrayData, fill_value: Scalar) -> VortexResult<ArrayData> {
if let Some(fill_null_fn) = array.encoding().fill_null_fn() {
return fill_null_fn.fill_null(array, fill_value);
}
Expand Down
77 changes: 40 additions & 37 deletions vortex-array/src/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,55 +60,58 @@ pub fn filter(array: &ArrayData, mask: FilterMask) -> VortexResult<ArrayData> {
);
}

let true_count = mask.true_count();

// Fast-path for empty mask.
if mask.true_count() == 0 {
if true_count == 0 {
return Ok(Canonical::empty(array.dtype())?.into());
}

// Fast-path for full mask.
if mask.true_count() == mask.len() {
if true_count == mask.len() {
return Ok(array.clone());
}

let filtered = filter_impl(array, mask)?;

debug_assert_eq!(
filtered.len(),
true_count,
"Filter length mismatch {}",
array.encoding().id()
);
debug_assert_eq!(
filtered.dtype(),
array.dtype(),
"Filter dtype mismatch {}",
array.encoding().id()
);

Ok(filtered)
}

fn filter_impl(array: &ArrayData, mask: FilterMask) -> VortexResult<ArrayData> {
if let Some(filter_fn) = array.encoding().filter_fn() {
let true_count = mask.true_count();
let result = filter_fn.filter(array, mask)?;
if array.dtype() != result.dtype() {
vortex_bail!(
"FilterFn {} changed array dtype from {} to {}",
array.encoding().id(),
array.dtype(),
result.dtype()
);
}
if true_count != result.len() {
vortex_bail!(
"FilterFn {} returned incorrect length: expected {}, got {}",
array.encoding().id(),
true_count,
result.len()
);
}
Ok(result)
} else {
// We can use scalar_at if the mask selects exactly one element.
if mask.true_count() == 1 && array.encoding().scalar_at_fn().is_some() {
let idx = mask.indices()?[0];
return Ok(ConstantArray::new(scalar_at(array, idx)?, 1).into_array());
}
return filter_fn.filter(array, mask);
}

// Fallback: implement using Arrow kernels.
log::debug!(
"No filter implementation found for {}",
array.encoding().id(),
);
// We can use scalar_at if the mask selects exactly one element.
if mask.true_count() == 1 && array.encoding().scalar_at_fn().is_some() {
let idx = mask.indices()?[0];
return Ok(ConstantArray::new(scalar_at(array, idx)?, 1).into_array());
}

let array_ref = array.clone().into_arrow()?;
let mask_array = BooleanArray::new(mask.to_boolean_buffer()?, None);
let filtered = arrow_select::filter::filter(array_ref.as_ref(), &mask_array)?;
// Fallback: implement using Arrow kernels.
log::debug!(
"No filter implementation found for {}",
array.encoding().id(),
);

Ok(ArrayData::from_arrow(filtered, array.dtype().is_nullable()))
}
let array_ref = array.clone().into_arrow()?;
let mask_array = BooleanArray::new(mask.to_boolean_buffer()?, None);
let filtered = arrow_select::filter::filter(array_ref.as_ref(), &mask_array)?;

Ok(ArrayData::from_arrow(filtered, array.dtype().is_nullable()))
}

/// Represents the mask argument to a filter function.
Expand Down
17 changes: 16 additions & 1 deletion vortex-array/src/compute/invert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,22 @@ pub fn invert(array: &ArrayData) -> VortexResult<ArrayData> {
}

if let Some(f) = array.encoding().invert_fn() {
return f.invert(array);
let inverted = f.invert(array)?;

debug_assert_eq!(
inverted.len(),
array.len(),
"Invert length mismatch {}",
array.encoding().id()
);
debug_assert_eq!(
inverted.dtype(),
array.dtype(),
"Invert dtype mismatch {}",
array.encoding().id()
);

return Ok(inverted);
}

// Otherwise, we canonicalize into a boolean array and invert.
Expand Down
17 changes: 16 additions & 1 deletion vortex-array/src/compute/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,22 @@ pub fn like(
}

if let Some(f) = array.encoding().like_fn() {
return f.like(array, pattern, options);
let result = f.like(array, pattern, options)?;

debug_assert_eq!(
result.len(),
array.len(),
"Like length mismatch {}",
array.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Bool((array.dtype().is_nullable() || pattern.dtype().is_nullable()).into()),
"Like dtype mismatch {}",
array.encoding().id()
);

return Ok(result);
}

// Otherwise, we canonicalize into a UTF8 array.
Expand Down
Loading
Loading