From 826737dc96c74294d5651311c72d8be7f8a4b701 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Wed, 18 Dec 2024 18:06:00 +0000 Subject: [PATCH 1/3] ignore dev-targeted crates in coverage report (#1715) --- .github/coverage.yml | 9 +++++++++ .github/workflows/ci.yml | 3 +++ vortex-serde/src/layouts/write/mod.rs | 7 ------- 3 files changed, 12 insertions(+), 7 deletions(-) create mode 100644 .github/coverage.yml delete mode 100644 vortex-serde/src/layouts/write/mod.rs diff --git a/.github/coverage.yml b/.github/coverage.yml new file mode 100644 index 0000000000..2d55d9e9a1 --- /dev/null +++ b/.github/coverage.yml @@ -0,0 +1,9 @@ +ignore-non-existing: true + +excl-line: "unreachable!" +ignore: + - "bench-vortex/*" + - "fuzz/*" + - "home/*" + - "/*" + - "../*" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9104e88caa..320aa97514 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -157,12 +157,15 @@ jobs: RUSTDOCFLAGS: '-Zprofile' - uses: rraval/actions-rs-grcov@e96292badb0d33512d16654efb0ee3032a9a3cff id: grcov + with: + config: ".github/coverage.yml" env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Coveralls uses: coverallsapp/github-action@v2 with: file: "${{ steps.grcov.outputs.report }}" + license-check-and-audit-check: name: License Check and Audit Check diff --git a/vortex-serde/src/layouts/write/mod.rs b/vortex-serde/src/layouts/write/mod.rs deleted file mode 100644 index 5db6f3505d..0000000000 --- a/vortex-serde/src/layouts/write/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub use layouts::LayoutSpec; -pub use writer::LayoutWriter; - -mod footer; -mod layouts; -mod metadata_accumulators; -mod writer; From 6e7f731a831a590aa1183b6af575dc43f51dc3c1 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 18 Dec 2024 18:56:16 +0000 Subject: [PATCH 2/3] Add debug assertions to ComputeFn results (#1716) --- encodings/fsst/src/compute/compare.rs | 66 ++++++++-------- pyvortex/src/expr.rs | 2 +- .../src/array/chunked/compute/compare.rs | 10 ++- vortex-array/src/array/list/compute/mod.rs | 6 +- vortex-array/src/array/list/mod.rs | 20 ++++- .../src/array/varbin/compute/compare.rs | 6 +- vortex-array/src/builders/list.rs | 16 +++- vortex-array/src/compute/binary_numeric.rs | 32 +++++++- vortex-array/src/compute/boolean.rs | 31 +++++++- vortex-array/src/compute/cast.rs | 19 +++++ vortex-array/src/compute/compare.rs | 38 ++++++++- vortex-array/src/compute/fill_forward.rs | 20 ++++- vortex-array/src/compute/fill_null.rs | 20 +++++ vortex-array/src/compute/filter.rs | 77 ++++++++++--------- vortex-array/src/compute/invert.rs | 17 +++- vortex-array/src/compute/like.rs | 17 +++- vortex-array/src/compute/scalar_at.rs | 13 +++- vortex-array/src/compute/slice.rs | 21 ++++- vortex-array/src/compute/take.rs | 32 ++++++-- vortex-array/src/data/viewed.rs | 2 +- vortex-scalar/src/list.rs | 15 ++-- 21 files changed, 370 insertions(+), 110 deletions(-) diff --git a/encodings/fsst/src/compute/compare.rs b/encodings/fsst/src/compute/compare.rs index 2375edfcf9..dbdcff4324 100644 --- a/encodings/fsst/src/compute/compare.rs +++ b/encodings/fsst/src/compute/compare.rs @@ -1,10 +1,11 @@ use fsst::Symbol; use vortex_array::array::ConstantArray; use vortex_array::compute::{compare, CompareFn, Operator}; -use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayVariant, ToArrayData}; +use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; use vortex_buffer::Buffer; -use vortex_dtype::DType; -use vortex_error::VortexResult; +use vortex_dtype::{DType, Nullability}; +use vortex_error::{VortexExpect, VortexResult}; +use vortex_scalar::Scalar; use crate::{FSSTArray, FSSTEncoding}; @@ -16,10 +17,16 @@ impl CompareFn for FSSTEncoding { operator: Operator, ) -> VortexResult> { match (rhs.as_constant(), operator) { - // TODO(ngates): implement short-circuit comparisons for other operators. - (Some(constant_array), Operator::Eq | Operator::NotEq) => compare_fsst_constant( + (Some(constant), _) if constant.is_null() => { + // All comparisons to null must return null + Ok(Some( + ConstantArray::new(Scalar::null(DType::Bool(Nullability::Nullable)), lhs.len()) + .into_array(), + )) + } + (Some(constant), Operator::Eq | Operator::NotEq) => compare_fsst_constant( lhs, - &ConstantArray::new(constant_array, lhs.len()), + &ConstantArray::new(constant, lhs.len()), operator == Operator::Eq, ) .map(Some), @@ -49,34 +56,31 @@ fn compare_fsst_constant( let compressor = compressor.build(); let encoded_scalar = match left.dtype() { - DType::Utf8(_) => right - .scalar() - .as_utf8() - .value() - .map(|scalar| Buffer::from(compressor.compress(scalar.as_bytes()))), - DType::Binary(_) => right - .scalar() - .as_binary() - .value() - .map(|scalar| Buffer::from(compressor.compress(scalar.as_slice()))), + DType::Utf8(_) => { + let value = right + .scalar() + .as_utf8() + .value() + .vortex_expect("Expected non-null scalar"); + Buffer::from(compressor.compress(value.as_bytes())) + } + DType::Binary(_) => { + let value = right + .scalar() + .as_binary() + .value() + .vortex_expect("Expected non-null scalar"); + Buffer::from(compressor.compress(value.as_slice())) + } _ => unreachable!("FSSTArray can only have string or binary data type"), }; - match encoded_scalar { - None => { - // Eq and NotEq on null values yield nulls, per the Arrow behavior. - Ok(right.to_array()) - } - Some(encoded_scalar) => { - let rhs = ConstantArray::new(encoded_scalar, left.len()); - - compare( - left.codes(), - rhs, - if equal { Operator::Eq } else { Operator::NotEq }, - ) - } - } + let rhs = ConstantArray::new(encoded_scalar, left.len()); + compare( + left.codes(), + rhs, + if equal { Operator::Eq } else { Operator::NotEq }, + ) } #[cfg(test)] diff --git a/pyvortex/src/expr.rs b/pyvortex/src/expr.rs index 0f9b9972b9..2e1456d048 100644 --- a/pyvortex/src/expr.rs +++ b/pyvortex/src/expr.rs @@ -304,7 +304,7 @@ pub fn scalar_helper(dtype: DType, value: &Bound<'_, PyAny>) -> PyResult .iter() .map(|element| scalar_helper(element_type.as_ref().clone(), element)) .collect::>>()?; - Ok(Scalar::list(element_type, values)) + Ok(Scalar::list(element_type, values, Nullability::Nullable)) } DType::Extension(..) => todo!(), } diff --git a/vortex-array/src/array/chunked/compute/compare.rs b/vortex-array/src/array/chunked/compute/compare.rs index 5dcb31e7bb..7839d749c7 100644 --- a/vortex-array/src/array/chunked/compute/compare.rs +++ b/vortex-array/src/array/chunked/compute/compare.rs @@ -1,9 +1,9 @@ -use vortex_dtype::{DType, Nullability}; +use vortex_dtype::DType; use vortex_error::VortexResult; use crate::array::{ChunkedArray, ChunkedEncoding}; use crate::compute::{compare, slice, CompareFn, Operator}; -use crate::{ArrayData, IntoArrayData}; +use crate::{ArrayDType, ArrayData, IntoArrayData}; impl CompareFn for ChunkedEncoding { fn compare( @@ -24,7 +24,11 @@ impl CompareFn for ChunkedEncoding { } Ok(Some( - ChunkedArray::try_new(compare_chunks, DType::Bool(Nullability::Nullable))?.into_array(), + ChunkedArray::try_new( + compare_chunks, + DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()), + )? + .into_array(), )) } } diff --git a/vortex-array/src/array/list/compute/mod.rs b/vortex-array/src/array/list/compute/mod.rs index f5e6f8b4e3..317f574c75 100644 --- a/vortex-array/src/array/list/compute/mod.rs +++ b/vortex-array/src/array/list/compute/mod.rs @@ -23,7 +23,11 @@ impl ScalarAtFn for ListEncoding { let elem = array.elements_at(index)?; let scalars: Vec = (0..elem.len()).map(|i| scalar_at(&elem, i)).try_collect()?; - Ok(Scalar::list(Arc::new(elem.dtype().clone()), scalars)) + Ok(Scalar::list( + Arc::new(elem.dtype().clone()), + scalars, + array.dtype().nullability(), + )) } } diff --git a/vortex-array/src/array/list/mod.rs b/vortex-array/src/array/list/mod.rs index 45f25f76e8..cbea5d15fa 100644 --- a/vortex-array/src/array/list/mod.rs +++ b/vortex-array/src/array/list/mod.rs @@ -197,7 +197,7 @@ impl ValidityVTable for ListEncoding { mod test { use std::sync::Arc; - use vortex_dtype::PType; + use vortex_dtype::{Nullability, PType}; use vortex_scalar::Scalar; use crate::array::list::ListArray; @@ -228,15 +228,27 @@ mod test { ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap(); assert_eq!( - Scalar::list(Arc::new(PType::I32.into()), vec![1.into(), 2.into()]), + Scalar::list( + Arc::new(PType::I32.into()), + vec![1.into(), 2.into()], + Nullability::Nullable + ), scalar_at(&list, 0).unwrap() ); assert_eq!( - Scalar::list(Arc::new(PType::I32.into()), vec![3.into(), 4.into()]), + Scalar::list( + Arc::new(PType::I32.into()), + vec![3.into(), 4.into()], + Nullability::Nullable + ), scalar_at(&list, 1).unwrap() ); assert_eq!( - Scalar::list(Arc::new(PType::I32.into()), vec![5.into()]), + Scalar::list( + Arc::new(PType::I32.into()), + vec![5.into()], + Nullability::Nullable + ), scalar_at(&list, 2).unwrap() ); } diff --git a/vortex-array/src/array/varbin/compute/compare.rs b/vortex-array/src/array/varbin/compute/compare.rs index 7f06dd1dbc..d64395ae1d 100644 --- a/vortex-array/src/array/varbin/compute/compare.rs +++ b/vortex-array/src/array/varbin/compute/compare.rs @@ -6,7 +6,7 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::array::{VarBinArray, VarBinEncoding}; use crate::arrow::{Datum, FromArrowArray}; use crate::compute::{CompareFn, Operator}; -use crate::{ArrayData, IntoArrayData}; +use crate::{ArrayDType, ArrayData, IntoArrayData}; // This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical impl CompareFn for VarBinEncoding { @@ -17,6 +17,8 @@ impl CompareFn for VarBinEncoding { operator: Operator, ) -> VortexResult> { if let Some(rhs_const) = rhs.as_constant() { + let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable(); + let lhs = Datum::try_from(lhs.clone().into_array())?; // TODO(robert): Handle LargeString/Binary arrays @@ -46,7 +48,7 @@ impl CompareFn for VarBinEncoding { Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs)?, }; - Ok(Some(ArrayData::from_arrow(&array, true))) + Ok(Some(ArrayData::from_arrow(&array, nullable))) } else { Ok(None) } diff --git a/vortex-array/src/builders/list.rs b/vortex-array/src/builders/list.rs index 91fcb8d0e6..ae87bd00ad 100644 --- a/vortex-array/src/builders/list.rs +++ b/vortex-array/src/builders/list.rs @@ -156,17 +156,27 @@ mod tests { builder .append_value( - Scalar::list(dtype.clone(), vec![1i32.into(), 2i32.into(), 3i32.into()]).as_list(), + Scalar::list( + dtype.clone(), + vec![1i32.into(), 2i32.into(), 3i32.into()], + Nullability::NonNullable, + ) + .as_list(), ) .unwrap(); builder - .append_value(Scalar::empty(dtype.clone()).as_list()) + .append_value(Scalar::list_empty(dtype.clone(), Nullability::NonNullable).as_list()) .unwrap(); builder .append_value( - Scalar::list(dtype, vec![4i32.into(), 5i32.into(), 6i32.into()]).as_list(), + Scalar::list( + dtype, + vec![4i32.into(), 5i32.into(), 6i32.into()], + Nullability::NonNullable, + ) + .as_list(), ) .unwrap(); diff --git a/vortex-array/src/compute/binary_numeric.rs b/vortex-array/src/compute/binary_numeric.rs index 220608166d..e0a1889d06 100644 --- a/vortex-array/src/compute/binary_numeric.rs +++ b/vortex-array/src/compute/binary_numeric.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use arrow_array::ArrayRef; -use vortex_dtype::DType; +use vortex_dtype::{DType, PType}; use vortex_error::{vortex_bail, VortexError, VortexResult}; use vortex_scalar::{BinaryNumericOperator, Scalar}; @@ -117,6 +117,21 @@ pub fn binary_numeric( // Check if LHS supports the operation directly. if let Some(fun) = lhs.encoding().binary_numeric_fn() { if let Some(result) = fun.binary_numeric(lhs, rhs, op)? { + debug_assert_eq!( + result.len(), + lhs.len(), + "Numeric operation length mismatch {}", + lhs.encoding().id() + ); + debug_assert_eq!( + result.dtype(), + &DType::Primitive( + PType::try_from(lhs.dtype())?, + (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into() + ), + "Numeric operation dtype mismatch {}", + lhs.encoding().id() + ); return Ok(result); } } @@ -124,6 +139,21 @@ pub fn binary_numeric( // Check if RHS supports the operation directly. if let Some(fun) = rhs.encoding().binary_numeric_fn() { if let Some(result) = fun.binary_numeric(rhs, lhs, op)? { + debug_assert_eq!( + result.len(), + lhs.len(), + "Numeric operation length mismatch {}", + rhs.encoding().id() + ); + debug_assert_eq!( + result.dtype(), + &DType::Primitive( + PType::try_from(lhs.dtype())?, + (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into() + ), + "Numeric operation dtype mismatch {}", + rhs.encoding().id() + ); return Ok(result); } } diff --git a/vortex-array/src/compute/boolean.rs b/vortex-array/src/compute/boolean.rs index ab548fdc65..2bab7d22f6 100644 --- a/vortex-array/src/compute/boolean.rs +++ b/vortex-array/src/compute/boolean.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use arrow_array::cast::AsArray; use arrow_array::ArrayRef; +use vortex_dtype::DType; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use crate::arrow::FromArrowArray; @@ -106,16 +107,42 @@ pub fn binary_boolean( .encoding() .binary_boolean_fn() .and_then(|f| f.binary_boolean(lhs, rhs, op).transpose()) + .transpose()? { - return result; + debug_assert_eq!( + result.len(), + lhs.len(), + "Boolean operation length mismatch {}", + lhs.encoding().id() + ); + debug_assert_eq!( + result.dtype(), + &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()), + "Boolean operation dtype mismatch {}", + lhs.encoding().id() + ); + return Ok(result); } if let Some(result) = rhs .encoding() .binary_boolean_fn() .and_then(|f| f.binary_boolean(rhs, lhs, op).transpose()) + .transpose()? { - return result; + debug_assert_eq!( + result.len(), + lhs.len(), + "Boolean operation length mismatch {}", + rhs.encoding().id() + ); + debug_assert_eq!( + result.dtype(), + &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()), + "Boolean operation dtype mismatch {}", + rhs.encoding().id() + ); + return Ok(result); } log::debug!( diff --git a/vortex-array/src/compute/cast.rs b/vortex-array/src/compute/cast.rs index 7dedabcff9..e1c5dedb9d 100644 --- a/vortex-array/src/compute/cast.rs +++ b/vortex-array/src/compute/cast.rs @@ -33,6 +33,25 @@ pub fn try_cast(array: impl AsRef, dtype: &DType) -> VortexResult VortexResult { // TODO(ngates): check for null_count if dtype is non-nullable if let Some(f) = array.encoding().cast_fn() { return f.cast(array, dtype); diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index a19ddc6933..affacec13f 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -126,16 +126,42 @@ pub fn compare( .encoding() .compare_fn() .and_then(|f| f.compare(left, right, operator).transpose()) + .transpose()? { - return result; + debug_assert_eq!( + result.len(), + left.len(), + "Compare length mismatch {}", + left.encoding().id() + ); + debug_assert_eq!( + result.dtype(), + &DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into()), + "Compare dtype mismatch {}", + left.encoding().id() + ); + return Ok(result); } if let Some(result) = right .encoding() .compare_fn() .and_then(|f| f.compare(right, left, operator.swap()).transpose()) + .transpose()? { - return result; + debug_assert_eq!( + result.len(), + left.len(), + "Compare length mismatch {}", + right.encoding().id() + ); + debug_assert_eq!( + result.dtype(), + &DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into()), + "Compare dtype mismatch {}", + right.encoding().id() + ); + return Ok(result); } // Only log missing compare implementation if there's possibly better one than arrow, @@ -159,6 +185,7 @@ pub(crate) fn arrow_compare( rhs: &ArrayData, operator: Operator, ) -> VortexResult { + let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); let lhs = Datum::try_from(lhs.clone())?; let rhs = Datum::try_from(rhs.clone())?; @@ -171,7 +198,7 @@ pub(crate) fn arrow_compare( Operator::Lte => cmp::lt_eq(&lhs, &rhs)?, }; - Ok(ArrayData::from_arrow(&array, true)) + Ok(ArrayData::from_arrow(&array, nullable)) } pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar { @@ -187,7 +214,10 @@ pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar { Operator::Lte => lhs <= rhs, }; - Scalar::bool(b, Nullability::Nullable) + Scalar::bool( + b, + (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(), + ) } } diff --git a/vortex-array/src/compute/fill_forward.rs b/vortex-array/src/compute/fill_forward.rs index 2017d964a5..2c35769e44 100644 --- a/vortex-array/src/compute/fill_forward.rs +++ b/vortex-array/src/compute/fill_forward.rs @@ -33,7 +33,8 @@ pub fn fill_forward(array: impl AsRef) -> VortexResult { if !array.dtype().is_nullable() { return Ok(array.clone()); } - array + + let filled = array .encoding() .fill_forward_fn() .map(|f| f.fill_forward(array)) @@ -42,5 +43,20 @@ pub fn fill_forward(array: impl AsRef) -> VortexResult { NotImplemented: "fill_forward", array.encoding().id() )) - }) + })?; + + debug_assert_eq!( + filled.len(), + array.len(), + "FillForward length mismatch {}", + array.encoding().id() + ); + debug_assert_eq!( + filled.dtype(), + array.dtype(), + "FillForward dtype mismatch {}", + array.encoding().id() + ); + + Ok(filled) } diff --git a/vortex-array/src/compute/fill_null.rs b/vortex-array/src/compute/fill_null.rs index 1a2ddf3ff6..ce6f541530 100644 --- a/vortex-array/src/compute/fill_null.rs +++ b/vortex-array/src/compute/fill_null.rs @@ -41,6 +41,26 @@ pub fn fill_null(array: impl AsRef, fill_value: Scalar) -> VortexResu vortex_bail!(MismatchedTypes: array.dtype(), fill_value.dtype()) } + let fill_value_nullability = fill_value.dtype().nullability(); + let filled = fill_null_impl(array, fill_value)?; + + debug_assert_eq!( + filled.len(), + array.len(), + "FillNull length mismatch {}", + array.encoding().id() + ); + debug_assert_eq!( + filled.dtype(), + &array.dtype().with_nullability(fill_value_nullability), + "FillNull dtype mismatch {}", + array.encoding().id() + ); + + Ok(filled) +} + +fn fill_null_impl(array: &ArrayData, fill_value: Scalar) -> VortexResult { if let Some(fill_null_fn) = array.encoding().fill_null_fn() { return fill_null_fn.fill_null(array, fill_value); } diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index eab61522fc..030f8195b5 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -60,55 +60,58 @@ pub fn filter(array: &ArrayData, mask: FilterMask) -> VortexResult { ); } + let true_count = mask.true_count(); + // Fast-path for empty mask. - if mask.true_count() == 0 { + if true_count == 0 { return Ok(Canonical::empty(array.dtype())?.into()); } // Fast-path for full mask - if mask.true_count() == mask.len() { + if true_count == mask.len() { return Ok(array.clone()); } + let filtered = filter_impl(array, mask)?; + + debug_assert_eq!( + filtered.len(), + true_count, + "Filter length mismatch {}", + array.encoding().id() + ); + debug_assert_eq!( + filtered.dtype(), + array.dtype(), + "Filter dtype mismatch {}", + array.encoding().id() + ); + + Ok(filtered) +} + +fn filter_impl(array: &ArrayData, mask: FilterMask) -> VortexResult { if let Some(filter_fn) = array.encoding().filter_fn() { - let true_count = mask.true_count(); - let result = filter_fn.filter(array, mask)?; - if array.dtype() != result.dtype() { - vortex_bail!( - "FilterFn {} changed array dtype from {} to {}", - array.encoding().id(), - array.dtype(), - result.dtype() - ); - } - if true_count != result.len() { - vortex_bail!( - "FilterFn {} returned incorrect length: expected {}, got {}", - array.encoding().id(), - true_count, - result.len() - ); - } - Ok(result) - } else { - // We can use scalar_at if the mask has length 1. - if mask.true_count() == 1 && array.encoding().scalar_at_fn().is_some() { - let idx = mask.indices()?[0]; - return Ok(ConstantArray::new(scalar_at(array, idx)?, 1).into_array()); - } + return filter_fn.filter(array, mask); + } - // Fallback: implement using Arrow kernels. - log::debug!( - "No filter implementation found for {}", - array.encoding().id(), - ); + // We can use scalar_at if the mask has length 1. + if mask.true_count() == 1 && array.encoding().scalar_at_fn().is_some() { + let idx = mask.indices()?[0]; + return Ok(ConstantArray::new(scalar_at(array, idx)?, 1).into_array()); + } - let array_ref = array.clone().into_arrow()?; - let mask_array = BooleanArray::new(mask.to_boolean_buffer()?, None); - let filtered = arrow_select::filter::filter(array_ref.as_ref(), &mask_array)?; + // Fallback: implement using Arrow kernels. + log::debug!( + "No filter implementation found for {}", + array.encoding().id(), + ); - Ok(ArrayData::from_arrow(filtered, array.dtype().is_nullable())) - } + let array_ref = array.clone().into_arrow()?; + let mask_array = BooleanArray::new(mask.to_boolean_buffer()?, None); + let filtered = arrow_select::filter::filter(array_ref.as_ref(), &mask_array)?; + + Ok(ArrayData::from_arrow(filtered, array.dtype().is_nullable())) } /// Represents the mask argument to a filter function. diff --git a/vortex-array/src/compute/invert.rs b/vortex-array/src/compute/invert.rs index 0a25c3a373..f1f965d62a 100644 --- a/vortex-array/src/compute/invert.rs +++ b/vortex-array/src/compute/invert.rs @@ -32,7 +32,22 @@ pub fn invert(array: &ArrayData) -> VortexResult { } if let Some(f) = array.encoding().invert_fn() { - return f.invert(array); + let inverted = f.invert(array)?; + + debug_assert_eq!( + inverted.len(), + array.len(), + "Invert length mismatch {}", + array.encoding().id() + ); + debug_assert_eq!( + inverted.dtype(), + array.dtype(), + "Invert dtype mismatch {}", + array.encoding().id() + ); + + return Ok(inverted); } // Otherwise, we canonicalize into a boolean array and invert. diff --git a/vortex-array/src/compute/like.rs b/vortex-array/src/compute/like.rs index d2c1ad28e6..e3a19e8436 100644 --- a/vortex-array/src/compute/like.rs +++ b/vortex-array/src/compute/like.rs @@ -60,7 +60,22 @@ pub fn like( } if let Some(f) = array.encoding().like_fn() { - return f.like(array, pattern, options); + let result = f.like(array, pattern, options)?; + + debug_assert_eq!( + result.len(), + array.len(), + "Like length mismatch {}", + array.encoding().id() + ); + debug_assert_eq!( + result.dtype(), + &DType::Bool((array.dtype().is_nullable() || pattern.dtype().is_nullable()).into()), + "Like dtype mismatch {}", + array.encoding().id() + ); + + return Ok(result); } // Otherwise, we canonicalize into a UTF8 array. diff --git a/vortex-array/src/compute/scalar_at.rs b/vortex-array/src/compute/scalar_at.rs index 14b78e23f1..428bf6df1f 100644 --- a/vortex-array/src/compute/scalar_at.rs +++ b/vortex-array/src/compute/scalar_at.rs @@ -38,9 +38,18 @@ pub fn scalar_at(array: impl AsRef, index: usize) -> VortexResult { @@ -40,7 +40,7 @@ pub fn slice(array: impl AsRef, start: usize, stop: usize) -> VortexR let array = array.as_ref(); check_slice_bounds(array, start, stop)?; - array + let sliced = array .encoding() .slice_fn() .map(|f| f.slice(array, start, stop)) @@ -49,7 +49,22 @@ pub fn slice(array: impl AsRef, start: usize, stop: usize) -> VortexR NotImplemented: "slice", array.encoding().id() )) - }) + })?; + + debug_assert_eq!( + sliced.len(), + stop - start, + "Slice length mismatch {}", + array.encoding().id() + ); + debug_assert_eq!( + sliced.dtype(), + array.dtype(), + "Slice dtype mismatch {}", + array.encoding().id() + ); + + Ok(sliced) } fn check_slice_bounds(array: &ArrayData, start: usize, stop: usize) -> VortexResult<()> { diff --git a/vortex-array/src/compute/take.rs b/vortex-array/src/compute/take.rs index 20ea46f556..a7e1a42a7f 100644 --- a/vortex-array/src/compute/take.rs +++ b/vortex-array/src/compute/take.rs @@ -46,6 +46,11 @@ pub fn take( array: impl AsRef, indices: impl AsRef, ) -> VortexResult { + // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to + // the filter function since they're typically optimised for this case. + // TODO(ngates): if indices min is quite high, we could slice self and offset the indices + // such that canonicalize does less work. + let array = array.as_ref(); let indices = indices.as_ref(); @@ -56,18 +61,35 @@ pub fn take( ); } - // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to - // the filter function since they're typically optimised for this case. - // If the indices are all within bounds, we can skip bounds checking. let checked_indices = indices .statistics() .get_as::(Stat::Max) .is_some_and(|max| max < array.len()); - // TODO(ngates): if indices min is quite high, we could slice self and offset the indices - // such that canonicalize does less work. + let taken = take_impl(array, indices, checked_indices)?; + + debug_assert_eq!( + taken.len(), + indices.len(), + "Take length mismatch {}", + array.encoding().id() + ); + debug_assert_eq!( + array.dtype(), + taken.dtype(), + "Take dtype mismatch {}", + array.encoding().id() + ); + Ok(taken) +} + +fn take_impl( + array: &ArrayData, + indices: &ArrayData, + checked_indices: bool, +) -> VortexResult { // If TakeFn defined for the encoding, delegate to TakeFn. // If we know from stats that indices are all valid, we can avoid all bounds checks. if let Some(take_fn) = array.encoding().take_fn() { diff --git a/vortex-array/src/data/viewed.rs b/vortex-array/src/data/viewed.rs index 32f53ffa58..2fa25fdf6e 100644 --- a/vortex-array/src/data/viewed.rs +++ b/vortex-array/src/data/viewed.rs @@ -156,7 +156,7 @@ impl Statistics for ViewedArrayData { .stats()? .bit_width_freq() .map(|v| v.iter().map(Scalar::from).collect_vec()) - .map(|v| Scalar::list(element_dtype, v)) + .map(|v| Scalar::list(element_dtype, v, Nullability::NonNullable)) } Stat::TrailingZeroFreq => self .flatbuffer() diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index bdc1fd0212..9d493e4ea0 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -1,8 +1,7 @@ use std::ops::Deref; use std::sync::Arc; -use vortex_dtype::DType; -use vortex_dtype::Nullability::{NonNullable, Nullable}; +use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult}; use crate::value::{InnerScalarValue, ScalarValue}; @@ -72,7 +71,11 @@ impl<'a> ListScalar<'a> { } impl Scalar { - pub fn list(element_dtype: Arc, children: Vec) -> Self { + pub fn list( + element_dtype: Arc, + children: Vec, + nullability: Nullability, + ) -> Self { for child in &children { if child.dtype() != &*element_dtype { vortex_panic!( @@ -83,16 +86,16 @@ impl Scalar { } } Self { - dtype: DType::List(element_dtype, NonNullable), + dtype: DType::List(element_dtype, nullability), value: ScalarValue(InnerScalarValue::List( children.into_iter().map(|x| x.value).collect::>(), )), } } - pub fn empty(element_dtype: Arc) -> Self { + pub fn list_empty(element_dtype: Arc, nullability: Nullability) -> Self { Self { - dtype: DType::List(element_dtype, Nullable), + dtype: DType::List(element_dtype, nullability), value: ScalarValue(InnerScalarValue::Null), } } From d6a24756ef47efd0947a1f89bac245ee4db67743 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 18 Dec 2024 19:05:13 +0000 Subject: [PATCH 3/3] chore: move IoBuf to vortex-io (#1714) --- bench-vortex/src/taxi_data.rs | 3 +-- vortex-buffer/src/lib.rs | 1 - {vortex-buffer => vortex-io}/src/io_buf.rs | 3 +-- vortex-io/src/lib.rs | 2 ++ vortex-io/src/object_store.rs | 3 +-- vortex-io/src/tokio.rs | 3 +-- vortex-io/src/write.rs | 2 +- 7 files changed, 7 insertions(+), 10 deletions(-) rename {vortex-buffer => vortex-io}/src/io_buf.rs (99%) diff --git a/bench-vortex/src/taxi_data.rs b/bench-vortex/src/taxi_data.rs index c612d5ae69..4dad5929b8 100644 --- a/bench-vortex/src/taxi_data.rs +++ b/bench-vortex/src/taxi_data.rs @@ -3,9 +3,8 @@ use std::io::Write; use std::path::PathBuf; use futures::executor::block_on; -use vortex::buffer::io_buf::IoBuf; use vortex::error::VortexError; -use vortex::io::VortexWrite; +use vortex::io::{IoBuf, VortexWrite}; use crate::data_downloads::{data_vortex_uncompressed, download_data}; use crate::reader::rewrite_parquet_as_vortex; diff --git a/vortex-buffer/src/lib.rs b/vortex-buffer/src/lib.rs index ba42277a24..e22fa7d5b8 100644 --- a/vortex-buffer/src/lib.rs +++ b/vortex-buffer/src/lib.rs @@ -16,7 +16,6 @@ use arrow_buffer::{ArrowNativeType, Buffer as ArrowBuffer, MutableBuffer as Arro pub use string::*; mod flexbuffers; -pub mod io_buf; mod string; /// Buffer is an owned, cheaply cloneable byte array. diff --git a/vortex-buffer/src/io_buf.rs b/vortex-io/src/io_buf.rs similarity index 99% rename from vortex-buffer/src/io_buf.rs rename to vortex-io/src/io_buf.rs index ec55a65f44..71c4334936 100644 --- a/vortex-buffer/src/io_buf.rs +++ b/vortex-io/src/io_buf.rs @@ -3,8 +3,7 @@ use std::ops::Range; use bytes::Bytes; - -use crate::Buffer; +use vortex_buffer::Buffer; /// Trait for types that can provide a readonly byte buffer interface to I/O frameworks. /// diff --git a/vortex-io/src/lib.rs b/vortex-io/src/lib.rs index f7168b4898..4f6d37d90b 100644 --- a/vortex-io/src/lib.rs +++ b/vortex-io/src/lib.rs @@ -9,6 +9,7 @@ pub use buf::*; pub use dispatcher::*; +pub use io_buf::*; pub use limit::*; #[cfg(feature = "object_store")] pub use object_store::*; @@ -23,6 +24,7 @@ mod buf; #[cfg(feature = "compio")] mod compio; mod dispatcher; +mod io_buf; mod limit; #[cfg(feature = "object_store")] mod object_store; diff --git a/vortex-io/src/object_store.rs b/vortex-io/src/object_store.rs index 532cf42fad..dbf5dacc8f 100644 --- a/vortex-io/src/object_store.rs +++ b/vortex-io/src/object_store.rs @@ -8,12 +8,11 @@ use bytes::Bytes; use futures_util::StreamExt; use object_store::path::Path; use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore, WriteMultipart}; -use vortex_buffer::io_buf::IoBuf; use vortex_buffer::Buffer; use vortex_error::{VortexExpect, VortexResult, VortexUnwrap}; use crate::aligned::AlignedBytesMut; -use crate::{VortexBufReader, VortexReadAt, VortexWrite, ALIGNMENT}; +use crate::{IoBuf, VortexBufReader, VortexReadAt, VortexWrite, ALIGNMENT}; pub trait ObjectStoreExt { fn vortex_read( diff --git a/vortex-io/src/tokio.rs b/vortex-io/src/tokio.rs index 0ed29e0fa2..9fc50a9d46 100644 --- a/vortex-io/src/tokio.rs +++ b/vortex-io/src/tokio.rs @@ -8,11 +8,10 @@ use std::sync::Arc; use bytes::Bytes; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use vortex_buffer::io_buf::IoBuf; use vortex_error::VortexUnwrap; use crate::aligned::AlignedBytesMut; -use crate::{VortexReadAt, VortexWrite, ALIGNMENT}; +use crate::{IoBuf, VortexReadAt, VortexWrite, ALIGNMENT}; pub struct TokioAdapter(pub IO); diff --git a/vortex-io/src/write.rs b/vortex-io/src/write.rs index 12220c9ff1..0e32b0b07c 100644 --- a/vortex-io/src/write.rs +++ b/vortex-io/src/write.rs @@ -1,7 +1,7 @@ use std::future::{ready, Future}; use std::io::{self, Cursor, Write}; -use vortex_buffer::io_buf::IoBuf; +use crate::IoBuf; pub trait VortexWrite { fn write_all(&mut self, buffer: B) -> impl Future>;