From f983056e3cfe3ea3717026f24aa9cfb54883c2cf Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Mon, 4 Mar 2024 14:01:56 +0000 Subject: [PATCH 1/3] merge --- vortex/src/array/bool/compute.rs | 36 +++++++++++++++++++++++++++ vortex/src/array/bool/mod.rs | 15 +++--------- vortex/src/array/chunked/compute.rs | 34 ++++++++++++++++++++++++++ vortex/src/array/chunked/mod.rs | 11 +++------ vortex/src/array/mod.rs | 9 ------- vortex/src/compute/mod.rs | 7 ++++++ vortex/src/compute/scalar_at.rs | 38 +++++++++++++++++++++++++++++ 7 files changed, 122 insertions(+), 28 deletions(-) create mode 100644 vortex/src/array/bool/compute.rs create mode 100644 vortex/src/array/chunked/compute.rs create mode 100644 vortex/src/compute/scalar_at.rs diff --git a/vortex/src/array/bool/compute.rs b/vortex/src/array/bool/compute.rs new file mode 100644 index 0000000000..371edd78f6 --- /dev/null +++ b/vortex/src/array/bool/compute.rs @@ -0,0 +1,36 @@ +// (c) Copyright 2024 Fulcrum Technologies, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::array::bool::BoolArray; +use crate::array::Array; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for BoolArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for BoolArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.is_valid(index) { + Ok(self.buffer.value(index).into()) + } else { + Ok(NullableScalar::none(self.dtype().clone()).boxed()) + } + } +} diff --git a/vortex/src/array/bool/mod.rs b/vortex/src/array/bool/mod.rs index c59c072a9a..c595626735 100644 --- a/vortex/src/array/bool/mod.rs +++ b/vortex/src/array/bool/mod.rs @@ -30,11 +30,12 @@ use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stat, Stats, StatsSet}; use super::{ - check_index_bounds, check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, - Encoding, EncodingId, EncodingRef, ENCODINGS, + check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding, + EncodingId, EncodingRef, ENCODINGS, }; mod compress; +mod compute; mod serde; mod stats; @@ -119,15 +120,7 @@ impl Array for BoolArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - if self.is_valid(index) { - Ok(self.buffer.value(index).into()) - } else { - Ok(NullableScalar::none(self.dtype().clone()).boxed()) - } - } + fn scalar_at(&self, index: usize) -> VortexResult> {} fn iter_arrow(&self) -> Box { Box::new(iter::once(Arc::new(BooleanArray::new( diff --git a/vortex/src/array/chunked/compute.rs b/vortex/src/array/chunked/compute.rs new file mode 100644 index 0000000000..8eda56cc6e --- /dev/null +++ b/vortex/src/array/chunked/compute.rs @@ -0,0 +1,34 @@ +// (c) Copyright 2024 Fulcrum Technologies, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::array::bool::BoolArray; +use crate::array::chunked::ChunkedArray; +use crate::array::Array; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for ChunkedArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for ChunkedArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + let (chunk_index, chunk_offset) = self.find_physical_location(index); + self.chunks[chunk_index].scalar_at(chunk_offset) + } +} diff --git a/vortex/src/array/chunked/mod.rs b/vortex/src/array/chunked/mod.rs index abd39f9b46..07900b33a1 100644 --- a/vortex/src/array/chunked/mod.rs +++ b/vortex/src/array/chunked/mod.rs @@ -21,8 +21,8 @@ use itertools::Itertools; use linkme::distributed_slice; use crate::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, - EncodingRef, ENCODINGS, + check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, + ENCODINGS, }; use crate::compress::EncodingCompression; use crate::dtype::DType; @@ -33,6 +33,7 @@ use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -140,12 +141,6 @@ impl Array for ChunkedArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - let (chunk_index, chunk_offset) = self.find_physical_location(index); - self.chunks[chunk_index].scalar_at(chunk_offset) - } - fn iter_arrow(&self) -> Box { Box::new(ChunkedArrowIterator::new(self)) } diff --git a/vortex/src/array/mod.rs b/vortex/src/array/mod.rs index 529677c458..b5a2d52ed4 100644 --- a/vortex/src/array/mod.rs +++ b/vortex/src/array/mod.rs @@ -74,8 +74,6 @@ pub trait Array: ArrayDisplay + Debug + Send + Sync + dyn_clone::DynClone + 'sta fn dtype(&self) -> &DType; /// Get statistics for the array fn stats(&self) -> Stats; - /// Get scalar value at given index - fn scalar_at(&self, index: usize) -> VortexResult>; /// Produce arrow batches from the encoding fn iter_arrow(&self) -> Box; /// Limit array to start..stop range @@ -104,13 +102,6 @@ pub fn check_slice_bounds(array: &dyn Array, start: usize, stop: usize) -> Vorte Ok(()) } -pub fn check_index_bounds(array: &dyn Array, index: usize) -> VortexResult<()> { - if index >= array.len() { - return Err(VortexError::OutOfBounds(index, 0, array.len())); - } - Ok(()) -} - pub fn check_validity_buffer(validity: Option<&ArrayRef>) -> VortexResult<()> { // TODO(ngates): take a length parameter and check that the length of the validity buffer matches if validity diff --git a/vortex/src/compute/mod.rs b/vortex/src/compute/mod.rs index 6eea3060c6..eef77ed228 100644 --- a/vortex/src/compute/mod.rs +++ b/vortex/src/compute/mod.rs @@ -12,16 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::compute::scalar_at::ScalarAtFn; +use polars_arrow::scalar::Scalar; use take::TakeFn; pub mod add; pub mod as_contiguous; pub mod cast; pub mod repeat; +pub mod scalar_at; pub mod search_sorted; pub mod take; pub trait ArrayCompute { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + None + } + fn take(&self) -> Option<&dyn TakeFn> { None } diff --git a/vortex/src/compute/scalar_at.rs b/vortex/src/compute/scalar_at.rs new file mode 100644 index 0000000000..347cab77cd --- /dev/null +++ b/vortex/src/compute/scalar_at.rs @@ -0,0 +1,38 @@ +// (c) Copyright 2024 Fulcrum Technologies, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::array::Array; +use crate::error::{VortexError, VortexResult}; +use crate::scalar::Scalar; + +pub trait ScalarAtFn { + fn scalar_at(&self, index: usize) -> VortexResult>; +} + +pub fn scalar_at(array: &dyn Array, index: usize) -> VortexResult> { + if index >= array.len() { + return Err(VortexError::OutOfBounds(index, 0, array.len())); + } + + array + .compute() + .and_then(|c| c.scalar_at()) + .map(|t| t.scalar_at(index)) + .unwrap_or_else(|| { + // TODO(ngates): default implementation of decode and then try again + Err(VortexError::ComputeError( + format!("scalar_at not implemented for {}", &array.encoding().id()).into(), + )) + }) +} From bc2c5dc960b20a29631f805cd6392f8f7caf0771 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Mon, 4 Mar 2024 14:25:43 +0000 Subject: [PATCH 2/3] merge --- vortex-alp/Cargo.toml | 3 ++ vortex-alp/src/alp.rs | 35 +------------------ vortex-alp/src/compute.rs | 47 ++++++++++++++++++++++++++ vortex-alp/src/lib.rs | 1 + vortex-dict/Cargo.toml | 3 ++ vortex-dict/src/compress.rs | 5 +-- vortex-dict/src/compute.rs | 18 ++++++++++ vortex-dict/src/dict.rs | 11 +----- vortex-dict/src/lib.rs | 1 + vortex-fastlanes/src/bitpacking/mod.rs | 23 +++---------- vortex-fastlanes/src/for/mod.rs | 7 ++-- vortex-ffor/src/ffor.rs | 46 ++++--------------------- vortex-ree/Cargo.toml | 3 ++ vortex-ree/src/compute.rs | 17 ++++++++++ vortex-ree/src/lib.rs | 1 + vortex-ree/src/ree.rs | 14 +++----- vortex-roaring/Cargo.toml | 6 ++-- vortex-roaring/src/boolean/compute.rs | 21 ++++++++++++ vortex-roaring/src/boolean/mod.rs | 16 ++------- vortex-roaring/src/integer/compute.rs | 27 +++++++++++++++ vortex-roaring/src/integer/mod.rs | 20 ++--------- vortex-zigzag/Cargo.toml | 6 ++-- vortex-zigzag/src/compute.rs | 38 +++++++++++++++++++++ vortex-zigzag/src/lib.rs | 1 + vortex-zigzag/src/zigzag.rs | 34 ++----------------- vortex/src/array/bool/mod.rs | 8 ++--- vortex/src/array/chunked/compute.rs | 8 ++--- vortex/src/array/chunked/mod.rs | 1 - vortex/src/array/constant/compress.rs | 3 +- vortex/src/array/constant/compute.rs | 20 +++++++++++ vortex/src/array/constant/mod.rs | 10 ++---- vortex/src/array/constant/take.rs | 10 ------ vortex/src/array/mod.rs | 11 +++--- vortex/src/array/primitive/compute.rs | 30 ++++++++++++++++ vortex/src/array/primitive/mod.rs | 28 ++++----------- vortex/src/array/sparse/compute.rs | 37 ++++++++++++++++++++ vortex/src/array/sparse/mod.rs | 31 ++--------------- vortex/src/array/struct_/compute.rs | 25 ++++++++++++++ vortex/src/array/struct_/mod.rs | 13 +------ vortex/src/array/typed/compute.rs | 19 +++++++++++ vortex/src/array/typed/mod.rs | 7 +--- vortex/src/array/varbin/compute.rs | 28 +++++++++++++++ vortex/src/array/varbin/mod.rs | 31 +++++------------ vortex/src/array/varbinview/compute.rs | 28 +++++++++++++++ vortex/src/array/varbinview/mod.rs | 21 +++--------- vortex/src/compute/mod.rs | 1 - vortex/src/compute/scalar_at.rs | 3 +- vortex/src/compute/take.rs | 16 ++++----- vortex/src/lib.rs | 2 ++ 49 files changed, 451 insertions(+), 344 deletions(-) create mode 100644 vortex-alp/src/compute.rs create mode 100644 vortex-dict/src/compute.rs create mode 100644 vortex-ree/src/compute.rs create mode 100644 vortex-roaring/src/boolean/compute.rs create mode 100644 vortex-roaring/src/integer/compute.rs create mode 100644 vortex-zigzag/src/compute.rs delete mode 100644 vortex/src/array/constant/take.rs create mode 100644 vortex/src/array/primitive/compute.rs create mode 100644 vortex/src/array/sparse/compute.rs create mode 100644 vortex/src/array/struct_/compute.rs create mode 100644 vortex/src/array/typed/compute.rs create mode 100644 vortex/src/array/varbin/compute.rs create mode 100644 vortex/src/array/varbinview/compute.rs diff --git a/vortex-alp/Cargo.toml b/vortex-alp/Cargo.toml index 776c8c31ce..cb1a1cdd1d 100644 --- a/vortex-alp/Cargo.toml +++ b/vortex-alp/Cargo.toml @@ -18,3 +18,6 @@ linkme = "0.3.22" itertools = "0.12.1" codecz = { version = "0.1.0", path = "../codecz" } log = { version = "0.4.20", features = [] } + +[lints] +workspace = true diff --git a/vortex-alp/src/alp.rs b/vortex-alp/src/alp.rs index 43f4b0d747..65d37418aa 100644 --- a/vortex-alp/src/alp.rs +++ b/vortex-alp/src/alp.rs @@ -1,14 +1,12 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use codecz::alp; pub use codecz::alp::ALPExponents; use vortex::array::{Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef}; use vortex::compress::EncodingCompression; -use vortex::dtype::{DType, FloatWidth, IntWidth, Signedness}; +use vortex::dtype::{DType, IntWidth, Signedness}; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::{NullableScalar, Scalar}; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; @@ -106,37 +104,6 @@ impl Array for ALPArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if let Some(patch) = self - .patches() - .and_then(|p| p.scalar_at(index).ok()) - .and_then(|p| p.into_nonnull()) - { - return Ok(patch); - } - - let Some(encoded_val) = self.encoded.scalar_at(index)?.into_nonnull() else { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - }; - match self.dtype { - DType::Float(FloatWidth::_32, _) => { - let encoded_val: i32 = encoded_val.try_into().unwrap(); - Ok(alp::decode_single::(encoded_val, self.exponents) - .unwrap() - .into()) - } - - DType::Float(FloatWidth::_64, _) => { - let encoded_val: i64 = encoded_val.try_into().unwrap(); - Ok(alp::decode_single::(encoded_val, self.exponents) - .unwrap() - .into()) - } - - _ => unreachable!(), - } - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex-alp/src/compute.rs b/vortex-alp/src/compute.rs new file mode 100644 index 0000000000..4ae36d02f8 --- /dev/null +++ b/vortex-alp/src/compute.rs @@ -0,0 +1,47 @@ +use crate::ALPArray; +use codecz::alp; +use vortex::array::Array; +use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::ArrayCompute; +use vortex::dtype::{DType, FloatWidth}; +use vortex::error::VortexResult; +use vortex::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for ALPArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for ALPArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if let Some(patch) = self + .patches() + .and_then(|p| scalar_at(p, index).ok()) + .and_then(|p| p.into_nonnull()) + { + return Ok(patch); + } + + let Some(encoded_val) = scalar_at(self.encoded(), index)?.into_nonnull() else { + return Ok(NullableScalar::none(self.dtype().clone()).boxed()); + }; + match self.dtype() { + DType::Float(FloatWidth::_32, _) => { + let encoded_val: i32 = encoded_val.try_into().unwrap(); + Ok(alp::decode_single::(encoded_val, self.exponents()) + .unwrap() + .into()) + } + + DType::Float(FloatWidth::_64, _) => { + let encoded_val: i64 = encoded_val.try_into().unwrap(); + Ok(alp::decode_single::(encoded_val, self.exponents()) + .unwrap() + .into()) + } + + _ => unreachable!(), + } + } +} diff --git a/vortex-alp/src/lib.rs b/vortex-alp/src/lib.rs index 6f0fe3ae42..8ff1d432cc 100644 --- a/vortex-alp/src/lib.rs +++ b/vortex-alp/src/lib.rs @@ -4,6 +4,7 @@ use vortex::array::{EncodingRef, ENCODINGS}; mod alp; mod compress; +mod compute; mod downcast; mod serde; mod stats; diff --git a/vortex-dict/Cargo.toml b/vortex-dict/Cargo.toml index 5f9e19ab6d..0ea173dfe0 100644 --- a/vortex-dict/Cargo.toml +++ b/vortex-dict/Cargo.toml @@ -19,3 +19,6 @@ hashbrown = "0.14.3" linkme = "0.3.22" log = "0.4.20" num-traits = "0.2.17" + +[lints] +workspace = true diff --git a/vortex-dict/src/compress.rs b/vortex-dict/src/compress.rs index a3f6dd4ea4..f1ef44f8c2 100644 --- a/vortex-dict/src/compress.rs +++ b/vortex-dict/src/compress.rs @@ -11,6 +11,7 @@ use vortex::array::primitive::PrimitiveArray; use vortex::array::varbin::VarBinArray; use vortex::array::{Array, ArrayKind, ArrayRef, CloneOptionalArray}; use vortex::compress::{CompressConfig, CompressCtx, Compressor, EncodingCompression}; +use vortex::compute::scalar_at::scalar_at; use vortex::dtype::DType; use vortex::match_each_native_ptype; use vortex::ptype::NativePType; @@ -207,8 +208,8 @@ fn bytes_at_primitive<'a, T: NativePType + AsPrimitive>( } fn bytes_at<'a>(offsets: &'a dyn Array, bytes: &'a [u8], idx: usize) -> &'a [u8] { - let start: usize = offsets.scalar_at(idx).unwrap().try_into().unwrap(); - let stop: usize = offsets.scalar_at(idx + 1).unwrap().try_into().unwrap(); + let start: usize = scalar_at(offsets, idx).unwrap().try_into().unwrap(); + let stop: usize = scalar_at(offsets, idx + 1).unwrap().try_into().unwrap(); &bytes[start..stop] } diff --git a/vortex-dict/src/compute.rs b/vortex-dict/src/compute.rs new file mode 100644 index 0000000000..b8292458b8 --- /dev/null +++ b/vortex-dict/src/compute.rs @@ -0,0 +1,18 @@ +use crate::DictArray; +use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::ArrayCompute; +use vortex::error::VortexResult; +use vortex::scalar::Scalar; + +impl ArrayCompute for DictArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for DictArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + let dict_index: usize = scalar_at(self.codes(), index)?.try_into()?; + scalar_at(self.dict(), dict_index) + } +} diff --git a/vortex-dict/src/dict.rs b/vortex-dict/src/dict.rs index bb332cd453..8cb0cbcd54 100644 --- a/vortex-dict/src/dict.rs +++ b/vortex-dict/src/dict.rs @@ -1,14 +1,11 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use vortex::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, -}; +use vortex::array::{check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId}; use vortex::compress::EncodingCompression; use vortex::dtype::{DType, Signedness}; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; @@ -75,12 +72,6 @@ impl Array for DictArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - let dict_index: usize = self.codes().scalar_at(index)?.try_into()?; - self.dict().scalar_at(dict_index) - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex-dict/src/lib.rs b/vortex-dict/src/lib.rs index afef63b7da..7342c1d293 100644 --- a/vortex-dict/src/lib.rs +++ b/vortex-dict/src/lib.rs @@ -5,6 +5,7 @@ pub use compress::*; pub use dict::*; mod compress; +mod compute; mod dict; mod downcast; mod serde; diff --git a/vortex-fastlanes/src/bitpacking/mod.rs b/vortex-fastlanes/src/bitpacking/mod.rs index 20b6b86c77..4f49307ec6 100644 --- a/vortex-fastlanes/src/bitpacking/mod.rs +++ b/vortex-fastlanes/src/bitpacking/mod.rs @@ -6,10 +6,11 @@ use vortex::array::{ check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, }; use vortex::compress::EncodingCompression; +use vortex::compute::scalar_at::scalar_at; +use vortex::compute::ArrayCompute; use vortex::dtype::DType; use vortex::error::VortexResult; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::{NullableScalar, Scalar}; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stat, Stats, StatsCompute, StatsSet}; @@ -74,7 +75,7 @@ impl BitPackedArray { pub fn is_valid(&self, index: usize) -> bool { self.validity() - .map(|v| v.scalar_at(index).and_then(|v| v.try_into()).unwrap()) + .map(|v| scalar_at(v, index).and_then(|v| v.try_into()).unwrap()) .unwrap_or(true) } } @@ -115,22 +116,6 @@ impl Array for BitPackedArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if !self.is_valid(index) { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - } - - if let Some(patch) = self - .patches() - .and_then(|p| p.scalar_at(index).ok()) - .and_then(|p| p.into_nonnull()) - { - return Ok(patch); - } - - todo!("Decode single element from BitPacked array"); - } - fn iter_arrow(&self) -> Box { todo!() } @@ -156,6 +141,8 @@ impl Array for BitPackedArray { } } +impl ArrayCompute for BitPackedArray {} + impl<'arr> AsRef<(dyn Array + 'arr)> for BitPackedArray { fn as_ref(&self) -> &(dyn Array + 'arr) { self diff --git a/vortex-fastlanes/src/for/mod.rs b/vortex-fastlanes/src/for/mod.rs index cd7b8fba9a..c0a3ec6e20 100644 --- a/vortex-fastlanes/src/for/mod.rs +++ b/vortex-fastlanes/src/for/mod.rs @@ -4,6 +4,7 @@ use std::sync::{Arc, RwLock}; use vortex::array::{Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef}; use vortex::compress::EncodingCompression; +use vortex::compute::ArrayCompute; use vortex::dtype::DType; use vortex::error::VortexResult; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; @@ -78,10 +79,6 @@ impl Array for FoRArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, _index: usize) -> VortexResult> { - todo!() - } - fn iter_arrow(&self) -> Box { todo!() } @@ -110,6 +107,8 @@ impl Array for FoRArray { } } +impl ArrayCompute for FoRArray {} + impl<'arr> AsRef<(dyn Array + 'arr)> for FoRArray { fn as_ref(&self) -> &(dyn Array + 'arr) { self diff --git a/vortex-ffor/src/ffor.rs b/vortex-ffor/src/ffor.rs index fd187d9fa0..5fd1325bd8 100644 --- a/vortex-ffor/src/ffor.rs +++ b/vortex-ffor/src/ffor.rs @@ -1,17 +1,17 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use vortex::array::downcast::DowncastArrayBuiltin; use vortex::array::{ check_validity_buffer, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, }; use vortex::compress::EncodingCompression; +use vortex::compute::scalar_at::scalar_at; +use vortex::compute::ArrayCompute; use vortex::dtype::DType; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::match_each_integer_ptype; -use vortex::scalar::{NullableScalar, Scalar}; +use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; @@ -104,7 +104,7 @@ impl FFORArray { pub fn is_valid(&self, index: usize) -> bool { self.validity() - .map(|v| v.scalar_at(index).and_then(|v| v.try_into()).unwrap()) + .map(|v| scalar_at(v, index).and_then(|v| v.try_into()).unwrap()) .unwrap_or(true) } } @@ -145,42 +145,6 @@ impl Array for FFORArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if !self.is_valid(index) { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - } - - if let Some(patch) = self - .patches() - .and_then(|p| p.scalar_at(index).ok()) - .and_then(|p| p.into_nonnull()) - { - return Ok(patch); - } - - let Some(parray) = self.encoded().maybe_primitive() else { - return Err(VortexError::InvalidEncoding( - self.encoded().encoding().id().clone(), - )); - }; - - if let Ok(ptype) = self.dtype().try_into() { - match_each_integer_ptype!(ptype, |$T| { - return Ok(codecz::ffor::decode_single::<$T>( - parray.buffer().as_slice(), - self.len, - self.num_bits, - self.min_val().try_into().unwrap(), - index, - ) - .unwrap() - .into()); - }) - } else { - return Err(VortexError::InvalidDType(self.dtype().clone())); - } - } - fn iter_arrow(&self) -> Box { todo!() } @@ -204,6 +168,8 @@ impl Array for FFORArray { } } +impl ArrayCompute for FFORArray {} + impl<'arr> AsRef<(dyn Array + 'arr)> for FFORArray { fn as_ref(&self) -> &(dyn Array + 'arr) { self diff --git a/vortex-ree/Cargo.toml b/vortex-ree/Cargo.toml index 22bf5a8dad..92f4c5b515 100644 --- a/vortex-ree/Cargo.toml +++ b/vortex-ree/Cargo.toml @@ -19,3 +19,6 @@ linkme = "0.3.22" half = "2.3.1" num-traits = "0.2.17" itertools = "0.10.5" + +[lints] +workspace = true diff --git a/vortex-ree/src/compute.rs b/vortex-ree/src/compute.rs new file mode 100644 index 0000000000..8fb64b37dd --- /dev/null +++ b/vortex-ree/src/compute.rs @@ -0,0 +1,17 @@ +use crate::REEArray; +use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::ArrayCompute; +use vortex::error::VortexResult; +use vortex::scalar::Scalar; + +impl ArrayCompute for REEArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for REEArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + scalar_at(self.values(), self.find_physical_index(index)?) + } +} diff --git a/vortex-ree/src/lib.rs b/vortex-ree/src/lib.rs index 7fecbf9c23..c6e878a3e1 100644 --- a/vortex-ree/src/lib.rs +++ b/vortex-ree/src/lib.rs @@ -4,6 +4,7 @@ use vortex::array::{EncodingRef, ENCODINGS}; pub use ree::*; mod compress; +mod compute; mod downcast; mod ree; mod serde; diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 6457b5870a..06ed424925 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -10,18 +10,18 @@ use num_traits::AsPrimitive; use codecz::ree::SupportsREE; use vortex::array::primitive::PrimitiveArray; use vortex::array::{ - check_index_bounds, check_slice_bounds, check_validity_buffer, Array, ArrayKind, ArrayRef, - ArrowIterator, CloneOptionalArray, Encoding, EncodingId, EncodingRef, + check_slice_bounds, check_validity_buffer, Array, ArrayKind, ArrayRef, ArrowIterator, + CloneOptionalArray, Encoding, EncodingId, EncodingRef, }; use vortex::arrow::match_arrow_numeric_type; use vortex::compress::EncodingCompression; use vortex::compute; +use vortex::compute::scalar_at::scalar_at; use vortex::compute::search_sorted::SearchSortedSide; use vortex::dtype::{DType, Nullability, Signedness}; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; use vortex::ptype::NativePType; -use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stat, Stats, StatsSet}; @@ -158,11 +158,6 @@ impl Array for REEArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - self.values.scalar_at(self.find_physical_index(index)?) - } - fn iter_arrow(&self) -> Box { // TODO(robert): Plumb offset rewriting to zig to fuse with REE decompression let ends: Vec = self @@ -308,8 +303,7 @@ fn run_ends_logical_length>(ends: &T) -> usize { if ends.as_ref().is_empty() { 0 } else { - ends.as_ref() - .scalar_at(ends.as_ref().len() - 1) + scalar_at(ends.as_ref(), ends.as_ref().len() - 1) .and_then(|end| end.try_into()) .unwrap_or_else(|_| panic!("Couldn't convert ends to usize")) } diff --git a/vortex-roaring/Cargo.toml b/vortex-roaring/Cargo.toml index 57dcbe86e1..5dacd630f7 100644 --- a/vortex-roaring/Cargo.toml +++ b/vortex-roaring/Cargo.toml @@ -11,12 +11,12 @@ include = { workspace = true } edition = { workspace = true } rust-version = { workspace = true } -[lints] -workspace = true - [dependencies] vortex = { "path" = "../vortex" } linkme = "0.3.22" croaring = "1.0.1" num-traits = "0.2.17" log = "0.4.20" + +[lints] +workspace = true diff --git a/vortex-roaring/src/boolean/compute.rs b/vortex-roaring/src/boolean/compute.rs new file mode 100644 index 0000000000..411245c962 --- /dev/null +++ b/vortex-roaring/src/boolean/compute.rs @@ -0,0 +1,21 @@ +use crate::RoaringBoolArray; +use vortex::compute::scalar_at::ScalarAtFn; +use vortex::compute::ArrayCompute; +use vortex::error::VortexResult; +use vortex::scalar::Scalar; + +impl ArrayCompute for RoaringBoolArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for RoaringBoolArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.bitmap.contains(index as u32) { + Ok(true.into()) + } else { + Ok(false.into()) + } + } +} diff --git a/vortex-roaring/src/boolean/mod.rs b/vortex-roaring/src/boolean/mod.rs index 373b2db8bf..9131e5c4e7 100644 --- a/vortex-roaring/src/boolean/mod.rs +++ b/vortex-roaring/src/boolean/mod.rs @@ -5,19 +5,19 @@ use croaring::{Bitmap, Native}; use compress::roaring_encode; use vortex::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, - EncodingId, EncodingRef, + check_slice_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, + EncodingRef, }; use vortex::compress::EncodingCompression; use vortex::dtype::DType; use vortex::dtype::Nullability::NonNullable; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -84,16 +84,6 @@ impl Array for RoaringBoolArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - if self.bitmap.contains(index as u32) { - Ok(true.into()) - } else { - Ok(false.into()) - } - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex-roaring/src/integer/compute.rs b/vortex-roaring/src/integer/compute.rs new file mode 100644 index 0000000000..9998a14958 --- /dev/null +++ b/vortex-roaring/src/integer/compute.rs @@ -0,0 +1,27 @@ +use crate::RoaringIntArray; +use vortex::compute::scalar_at::ScalarAtFn; +use vortex::compute::ArrayCompute; +use vortex::error::VortexResult; +use vortex::ptype::PType; +use vortex::scalar::Scalar; + +impl ArrayCompute for RoaringIntArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for RoaringIntArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + // Unwrap since we know the index is valid + let bitmap_value = self.bitmap.select(index as u32).unwrap(); + let scalar: Box = match self.ptype { + PType::U8 => (bitmap_value as u8).into(), + PType::U16 => (bitmap_value as u16).into(), + PType::U32 => bitmap_value.into(), + PType::U64 => (bitmap_value as u64).into(), + _ => unreachable!("RoaringIntArray constructor should have disallowed this type"), + }; + Ok(scalar) + } +} diff --git a/vortex-roaring/src/integer/mod.rs b/vortex-roaring/src/integer/mod.rs index e24469c078..7c650fb794 100644 --- a/vortex-roaring/src/integer/mod.rs +++ b/vortex-roaring/src/integer/mod.rs @@ -5,19 +5,19 @@ use croaring::{Bitmap, Native}; use compress::roaring_encode; use vortex::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, - EncodingId, EncodingRef, + check_slice_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, + EncodingRef, }; use vortex::compress::EncodingCompression; use vortex::dtype::DType; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; use vortex::ptype::PType; -use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -96,20 +96,6 @@ impl Array for RoaringIntArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - // Unwrap since we know the index is valid - let bitmap_value = self.bitmap.select(index as u32).unwrap(); - let scalar: Box = match self.ptype { - PType::U8 => (bitmap_value as u8).into(), - PType::U16 => (bitmap_value as u16).into(), - PType::U32 => bitmap_value.into(), - PType::U64 => (bitmap_value as u64).into(), - _ => unreachable!("RoaringIntArray constructor should have disallowed this type"), - }; - Ok(scalar) - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex-zigzag/Cargo.toml b/vortex-zigzag/Cargo.toml index 2f1280f925..c8a608496f 100644 --- a/vortex-zigzag/Cargo.toml +++ b/vortex-zigzag/Cargo.toml @@ -11,11 +11,11 @@ include = { workspace = true } edition = { workspace = true } rust-version = { workspace = true } -[lints] -workspace = true - [dependencies] vortex = { "path" = "../vortex" } linkme = "0.3.22" vortex-alloc = { version = "0.1.0", path = "../vortex-alloc" } zigzag = "0.1.0" + +[lints] +workspace = true diff --git a/vortex-zigzag/src/compute.rs b/vortex-zigzag/src/compute.rs new file mode 100644 index 0000000000..17b70555e3 --- /dev/null +++ b/vortex-zigzag/src/compute.rs @@ -0,0 +1,38 @@ +use crate::ZigZagArray; +use vortex::array::Array; +use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::ArrayCompute; +use vortex::dtype::{DType, IntWidth, Signedness}; +use vortex::error::{VortexError, VortexResult}; +use vortex::scalar::{NullableScalar, Scalar}; +use zigzag::ZigZag; + +impl ArrayCompute for ZigZagArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for ZigZagArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + let scalar = scalar_at(self.encoded(), index)?; + let Some(scalar) = scalar.as_nonnull() else { + return Ok(NullableScalar::none(self.dtype().clone()).boxed()); + }; + match self.dtype() { + DType::Int(IntWidth::_8, Signedness::Signed, _) => { + Ok(i8::decode(scalar.try_into()?).into()) + } + DType::Int(IntWidth::_16, Signedness::Signed, _) => { + Ok(i16::decode(scalar.try_into()?).into()) + } + DType::Int(IntWidth::_32, Signedness::Signed, _) => { + Ok(i32::decode(scalar.try_into()?).into()) + } + DType::Int(IntWidth::_64, Signedness::Signed, _) => { + Ok(i64::decode(scalar.try_into()?).into()) + } + _ => Err(VortexError::InvalidDType(self.dtype().clone())), + } + } +} diff --git a/vortex-zigzag/src/lib.rs b/vortex-zigzag/src/lib.rs index 55da8b3bcd..904c0baff0 100644 --- a/vortex-zigzag/src/lib.rs +++ b/vortex-zigzag/src/lib.rs @@ -4,6 +4,7 @@ use vortex::array::{EncodingRef, ENCODINGS}; pub use zigzag::*; mod compress; +mod compute; mod downcast; mod serde; mod stats; diff --git a/vortex-zigzag/src/zigzag.rs b/vortex-zigzag/src/zigzag.rs index a9888d2608..08ebace534 100644 --- a/vortex-zigzag/src/zigzag.rs +++ b/vortex-zigzag/src/zigzag.rs @@ -1,17 +1,11 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use zigzag::ZigZag; - -use vortex::array::{ - check_index_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, - EncodingRef, -}; +use vortex::array::{Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef}; use vortex::compress::EncodingCompression; -use vortex::dtype::{DType, IntWidth, Signedness}; +use vortex::dtype::{DType, Signedness}; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::{NullableScalar, Scalar}; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; @@ -91,30 +85,6 @@ impl Array for ZigZagArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - let scalar = self.encoded().scalar_at(index)?; - let Some(scalar) = scalar.as_nonnull() else { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - }; - match self.dtype() { - DType::Int(IntWidth::_8, Signedness::Signed, _) => { - Ok(i8::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_16, Signedness::Signed, _) => { - Ok(i16::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_32, Signedness::Signed, _) => { - Ok(i32::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_64, Signedness::Signed, _) => { - Ok(i64::decode(scalar.try_into()?).into()) - } - _ => Err(VortexError::InvalidDType(self.dtype().clone())), - } - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex/src/array/bool/mod.rs b/vortex/src/array/bool/mod.rs index b74340650d..bc06a4274b 100644 --- a/vortex/src/array/bool/mod.rs +++ b/vortex/src/array/bool/mod.rs @@ -8,10 +8,10 @@ use linkme::distributed_slice; use crate::arrow::CombineChunks; use crate::compress::EncodingCompression; +use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, Nullability}; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stat, Stats, StatsSet}; @@ -50,8 +50,8 @@ impl BoolArray { fn is_valid(&self, index: usize) -> bool { self.validity - .as_ref() - .map(|v| v.scalar_at(index).unwrap().try_into().unwrap()) + .as_deref() + .map(|v| scalar_at(v, index).unwrap().try_into().unwrap()) .unwrap_or(true) } @@ -106,8 +106,6 @@ impl Array for BoolArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> {} - fn iter_arrow(&self) -> Box { Box::new(iter::once(Arc::new(BooleanArray::new( self.buffer.clone(), diff --git a/vortex/src/array/chunked/compute.rs b/vortex/src/array/chunked/compute.rs index 0ebee7b0d3..e1cd17c740 100644 --- a/vortex/src/array/chunked/compute.rs +++ b/vortex/src/array/chunked/compute.rs @@ -1,10 +1,8 @@ -use crate::array::bool::BoolArray; use crate::array::chunked::ChunkedArray; -use crate::array::Array; -use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; use crate::compute::ArrayCompute; use crate::error::VortexResult; -use crate::scalar::{NullableScalar, Scalar}; +use crate::scalar::Scalar; impl ArrayCompute for ChunkedArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -15,6 +13,6 @@ impl ArrayCompute for ChunkedArray { impl ScalarAtFn for ChunkedArray { fn scalar_at(&self, index: usize) -> VortexResult> { let (chunk_index, chunk_offset) = self.find_physical_location(index); - self.chunks[chunk_index].scalar_at(chunk_offset) + scalar_at(self.chunks[chunk_index].as_ref(), chunk_offset) } } diff --git a/vortex/src/array/chunked/mod.rs b/vortex/src/array/chunked/mod.rs index a13093adbf..6618b5a206 100644 --- a/vortex/src/array/chunked/mod.rs +++ b/vortex/src/array/chunked/mod.rs @@ -14,7 +14,6 @@ use crate::compress::EncodingCompression; use crate::dtype::DType; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::Scalar; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; diff --git a/vortex/src/array/constant/compress.rs b/vortex/src/array/constant/compress.rs index 0826d5b98c..2e3687e71c 100644 --- a/vortex/src/array/constant/compress.rs +++ b/vortex/src/array/constant/compress.rs @@ -1,6 +1,7 @@ use crate::array::constant::{ConstantArray, ConstantEncoding}; use crate::array::{Array, ArrayRef}; use crate::compress::{CompressConfig, CompressCtx, Compressor, EncodingCompression}; +use crate::compute::scalar_at::scalar_at; use crate::stats::Stat; impl EncodingCompression for ConstantEncoding { @@ -22,5 +23,5 @@ fn constant_compressor( _like: Option<&dyn Array>, _ctx: CompressCtx, ) -> ArrayRef { - ConstantArray::new(array.scalar_at(0).unwrap(), array.len()).boxed() + ConstantArray::new(scalar_at(array, 0).unwrap(), array.len()).boxed() } diff --git a/vortex/src/array/constant/compute.rs b/vortex/src/array/constant/compute.rs index e46daaa542..300d1854f5 100644 --- a/vortex/src/array/constant/compute.rs +++ b/vortex/src/array/constant/compute.rs @@ -1,9 +1,29 @@ use crate::array::constant::ConstantArray; +use crate::array::{Array, ArrayRef}; +use crate::compute::scalar_at::ScalarAtFn; use crate::compute::take::TakeFn; use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::Scalar; impl ArrayCompute for ConstantArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } + fn take(&self) -> Option<&dyn TakeFn> { Some(self) } } + +impl ScalarAtFn for ConstantArray { + fn scalar_at(&self, _index: usize) -> VortexResult> { + Ok(dyn_clone::clone_box(self.scalar())) + } +} + +impl TakeFn for ConstantArray { + fn take(&self, indices: &dyn Array) -> VortexResult { + Ok(ConstantArray::new(dyn_clone::clone_box(self.scalar()), indices.len()).boxed()) + } +} diff --git a/vortex/src/array/constant/mod.rs b/vortex/src/array/constant/mod.rs index 7a21343685..e8673aa581 100644 --- a/vortex/src/array/constant/mod.rs +++ b/vortex/src/array/constant/mod.rs @@ -5,8 +5,8 @@ use arrow::array::Datum; use linkme::distributed_slice; use crate::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, - EncodingRef, ENCODINGS, + check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, + ENCODINGS, }; use crate::arrow::compute::repeat; use crate::compress::EncodingCompression; @@ -21,7 +21,6 @@ mod compress; mod compute; mod serde; mod stats; -mod take; #[derive(Debug, Clone)] pub struct ConstantArray { @@ -80,11 +79,6 @@ impl Array for ConstantArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - Ok(self.scalar.clone()) - } - fn iter_arrow(&self) -> Box { let arrow_scalar: Box = self.scalar.as_ref().into(); Box::new(std::iter::once(repeat(arrow_scalar.as_ref(), self.length))) diff --git a/vortex/src/array/constant/take.rs b/vortex/src/array/constant/take.rs deleted file mode 100644 index 0760e501fd..0000000000 --- a/vortex/src/array/constant/take.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::array::constant::ConstantArray; -use crate::array::{Array, ArrayRef}; -use crate::compute::take::TakeFn; -use crate::error::VortexResult; - -impl TakeFn for ConstantArray { - fn take(&self, indices: &dyn Array) -> VortexResult { - Ok(ConstantArray::new(dyn_clone::clone_box(self.scalar()), indices.len()).boxed()) - } -} diff --git a/vortex/src/array/mod.rs b/vortex/src/array/mod.rs index e42e503af6..072afc4b50 100644 --- a/vortex/src/array/mod.rs +++ b/vortex/src/array/mod.rs @@ -19,7 +19,6 @@ use crate::compute::ArrayCompute; use crate::dtype::{DType, Nullability}; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::Scalar; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::Stats; @@ -45,7 +44,9 @@ pub type ArrayRef = Box; /// /// This differs from Apache Arrow where logical and physical are combined in /// the data type, e.g. LargeString, RunEndEncoded. -pub trait Array: ArrayDisplay + Debug + Send + Sync + dyn_clone::DynClone + 'static { +pub trait Array: + ArrayDisplay + ArrayCompute + Debug + Send + Sync + dyn_clone::DynClone + 'static +{ /// Converts itself to a reference of [`Any`], which enables downcasting to concrete types. fn as_any(&self) -> &dyn Any; /// Move an owned array to `ArrayRef` @@ -69,10 +70,6 @@ pub trait Array: ArrayDisplay + Debug + Send + Sync + dyn_clone::DynClone + 'sta /// Approximate size in bytes of the array. Only takes into account variable size portion of the array fn nbytes(&self) -> usize; - fn compute(&self) -> Option<&dyn ArrayCompute> { - None - } - fn serde(&self) -> &dyn ArraySerde; } @@ -98,7 +95,7 @@ pub fn check_slice_bounds(array: &dyn Array, start: usize, stop: usize) -> Vorte Ok(()) } -pub fn check_validity_buffer(validity: Option<&ArrayRef>) -> VortexResult<()> { +pub fn check_validity_buffer(validity: Option<&dyn Array>) -> VortexResult<()> { // TODO(ngates): take a length parameter and check that the length of the validity buffer matches if validity .map(|v| !matches!(v.dtype(), DType::Bool(Nullability::NonNullable))) diff --git a/vortex/src/array/primitive/compute.rs b/vortex/src/array/primitive/compute.rs new file mode 100644 index 0000000000..b11e959fab --- /dev/null +++ b/vortex/src/array/primitive/compute.rs @@ -0,0 +1,30 @@ +use crate::array::primitive::PrimitiveArray; +use crate::array::Array; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::match_each_native_ptype; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for PrimitiveArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for PrimitiveArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.is_valid(index) { + Ok( + match_each_native_ptype!(self.ptype, |$T| self.buffer.typed_data::<$T>() + .get(index) + .unwrap() + .clone() + .into() + ), + ) + } else { + Ok(NullableScalar::none(self.dtype().clone()).boxed()) + } + } +} diff --git a/vortex/src/array/primitive/mod.rs b/vortex/src/array/primitive/mod.rs index 2cff46a45c..845b21e52c 100644 --- a/vortex/src/array/primitive/mod.rs +++ b/vortex/src/array/primitive/mod.rs @@ -15,20 +15,21 @@ use log::debug; use crate::array::bool::BoolArray; use crate::array::{ - check_index_bounds, check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, - Encoding, EncodingId, EncodingRef, ENCODINGS, + check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding, + EncodingId, EncodingRef, ENCODINGS, }; use crate::arrow::CombineChunks; use crate::compress::EncodingCompression; +use crate::compute::scalar_at::scalar_at; use crate::dtype::DType; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; use crate::ptype::{match_each_native_ptype, NativePType, PType}; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -109,8 +110,8 @@ impl PrimitiveArray { pub fn is_valid(&self, index: usize) -> bool { self.validity - .as_ref() - .map(|v| v.scalar_at(index).unwrap().try_into().unwrap()) + .as_deref() + .map(|v| scalar_at(v, index).unwrap().try_into().unwrap()) .unwrap_or(true) } @@ -166,23 +167,6 @@ impl Array for PrimitiveArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - if self.is_valid(index) { - Ok( - match_each_native_ptype!(self.ptype, |$T| self.buffer.typed_data::<$T>() - .get(index) - .unwrap() - .clone() - .into() - ), - ) - } else { - Ok(NullableScalar::none(self.dtype().clone()).boxed()) - } - } - fn iter_arrow(&self) -> Box { Box::new(iter::once(make_array( ArrayData::builder(self.dtype().into()) diff --git a/vortex/src/array/sparse/compute.rs b/vortex/src/array/sparse/compute.rs new file mode 100644 index 0000000000..02b5a9071e --- /dev/null +++ b/vortex/src/array/sparse/compute.rs @@ -0,0 +1,37 @@ +use crate::array::sparse::SparseArray; +use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; +use crate::compute::search_sorted::{search_sorted_usize, SearchSortedSide}; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for SparseArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for SparseArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + // Check whether `true_patch_index` exists in the patch index array + // First, get the index of the patch index array that is the first index + // greater than or equal to the true index + let true_patch_index = index + self.indices_offset; + search_sorted_usize(self.indices(), true_patch_index, SearchSortedSide::Left).and_then( + |idx| { + // If the value at this index is equal to the true index, then it exists in the patch index array + // and we should return the value at the corresponding index in the patch values array + scalar_at(self.indices(), idx) + .or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed())) + .and_then(usize::try_from) + .and_then(|patch_index| { + if patch_index == true_patch_index { + scalar_at(self.values(), idx) + } else { + Ok(NullableScalar::none(self.values().dtype().clone()).boxed()) + } + }) + }, + ) + } +} diff --git a/vortex/src/array/sparse/mod.rs b/vortex/src/array/sparse/mod.rs index 86a2e623c0..60f1dbe1ae 100644 --- a/vortex/src/array/sparse/mod.rs +++ b/vortex/src/array/sparse/mod.rs @@ -11,8 +11,7 @@ use linkme::distributed_slice; use crate::array::ENCODINGS; use crate::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, - EncodingRef, + check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, }; use crate::compress::EncodingCompression; use crate::compute::search_sorted::{search_sorted_usize, SearchSortedSide}; @@ -20,11 +19,11 @@ use crate::dtype::DType; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; use crate::match_arrow_numeric_type; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -118,32 +117,6 @@ impl Array for SparseArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - // Check whether `true_patch_index` exists in the patch index array - // First, get the index of the patch index array that is the first index - // greater than or equal to the true index - let true_patch_index = index + self.indices_offset; - search_sorted_usize(self.indices(), true_patch_index, SearchSortedSide::Left).and_then( - |idx| { - // If the value at this index is equal to the true index, then it exists in the patch index array - // and we should return the value at the corresponding index in the patch values array - self.indices() - .scalar_at(idx) - .or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed())) - .and_then(usize::try_from) - .and_then(|patch_index| { - if patch_index == true_patch_index { - self.values().scalar_at(idx) - } else { - Ok(NullableScalar::none(self.values().dtype().clone()).boxed()) - } - }) - }, - ) - } - fn iter_arrow(&self) -> Box { // Resolve our indices into a vector of usize applying the offset let mut indices = Vec::with_capacity(self.len()); diff --git a/vortex/src/array/struct_/compute.rs b/vortex/src/array/struct_/compute.rs new file mode 100644 index 0000000000..cbf5adc70d --- /dev/null +++ b/vortex/src/array/struct_/compute.rs @@ -0,0 +1,25 @@ +use crate::array::struct_::StructArray; +use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::{Scalar, StructScalar}; +use itertools::Itertools; + +impl ArrayCompute for StructArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for StructArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + Ok(StructScalar::new( + self.dtype.clone(), + self.fields + .iter() + .map(|field| scalar_at(field.as_ref(), index)) + .try_collect()?, + ) + .boxed()) + } +} diff --git a/vortex/src/array/struct_/mod.rs b/vortex/src/array/struct_/mod.rs index 698d39a43d..8ea1b16fd6 100644 --- a/vortex/src/array/struct_/mod.rs +++ b/vortex/src/array/struct_/mod.rs @@ -12,7 +12,6 @@ use crate::compress::EncodingCompression; use crate::dtype::{DType, FieldNames}; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::{Scalar, StructScalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; @@ -22,6 +21,7 @@ use super::{ }; mod compress; +mod compute; mod serde; mod stats; @@ -112,17 +112,6 @@ impl Array for StructArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - Ok(StructScalar::new( - self.dtype.clone(), - self.fields - .iter() - .map(|field| field.scalar_at(index)) - .try_collect()?, - ) - .boxed()) - } - fn iter_arrow(&self) -> Box { let fields = self.arrow_fields(); Box::new( diff --git a/vortex/src/array/typed/compute.rs b/vortex/src/array/typed/compute.rs new file mode 100644 index 0000000000..331a8ccdfb --- /dev/null +++ b/vortex/src/array/typed/compute.rs @@ -0,0 +1,19 @@ +use crate::array::typed::TypedArray; +use crate::array::Array; +use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::Scalar; + +impl ArrayCompute for TypedArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for TypedArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + let underlying = scalar_at(self.array.as_ref(), index)?; + underlying.as_ref().cast(self.dtype()) + } +} diff --git a/vortex/src/array/typed/mod.rs b/vortex/src/array/typed/mod.rs index 945cc2957d..40ad18649f 100644 --- a/vortex/src/array/typed/mod.rs +++ b/vortex/src/array/typed/mod.rs @@ -9,11 +9,11 @@ use crate::compress::EncodingCompression; use crate::dtype::DType; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::Scalar; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -85,11 +85,6 @@ impl Array for TypedArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - let underlying = self.array.scalar_at(index)?; - underlying.as_ref().cast(self.dtype()) - } - // TODO(robert): Have cast happen in enc space and not in arrow space fn iter_arrow(&self) -> Box { let datatype: DataType = self.dtype().into(); diff --git a/vortex/src/array/varbin/compute.rs b/vortex/src/array/varbin/compute.rs new file mode 100644 index 0000000000..5823e3c819 --- /dev/null +++ b/vortex/src/array/varbin/compute.rs @@ -0,0 +1,28 @@ +use crate::array::varbin::VarBinArray; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; +use crate::dtype::DType; +use crate::error::VortexResult; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for VarBinArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for VarBinArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.is_valid(index) { + self.bytes_at(index).map(|bytes| { + if matches!(self.dtype, DType::Utf8(_)) { + unsafe { String::from_utf8_unchecked(bytes) }.into() + } else { + bytes.into() + } + }) + } else { + Ok(NullableScalar::none(self.dtype.clone()).boxed()) + } + } +} diff --git a/vortex/src/array/varbin/mod.rs b/vortex/src/array/varbin/mod.rs index beab3d5dab..45a112b23a 100644 --- a/vortex/src/array/varbin/mod.rs +++ b/vortex/src/array/varbin/mod.rs @@ -12,21 +12,22 @@ use crate::array::bool::BoolArray; use crate::array::downcast::DowncastArrayBuiltin; use crate::array::primitive::PrimitiveArray; use crate::array::{ - check_index_bounds, check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, - Encoding, EncodingId, EncodingRef, ENCODINGS, + check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding, + EncodingId, EncodingRef, ENCODINGS, }; use crate::arrow::CombineChunks; use crate::compress::EncodingCompression; +use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, IntWidth, Nullability, Signedness}; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; use crate::match_each_native_ptype; use crate::ptype::NativePType; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -92,8 +93,8 @@ impl VarBinArray { fn is_valid(&self, index: usize) -> bool { self.validity - .as_ref() - .map(|v| v.scalar_at(index).unwrap().try_into().unwrap()) + .as_deref() + .map(|v| scalar_at(v, index).unwrap().try_into().unwrap()) .unwrap_or(true) } @@ -179,7 +180,7 @@ impl VarBinArray { } pub fn bytes_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; + // check_index_bounds(self, index)?; let (start, end): (usize, usize) = if let Some(p) = self.offsets.maybe_primitive() { match_each_native_ptype!(p.ptype(), |$P| { @@ -188,8 +189,8 @@ impl VarBinArray { }) } else { ( - self.offsets().scalar_at(index)?.try_into()?, - self.offsets().scalar_at(index + 1)?.try_into()?, + scalar_at(self.offsets(), index)?.try_into()?, + scalar_at(self.offsets(), index + 1)?.try_into()?, ) }; let sliced = self.bytes().slice(start, end)?; @@ -234,20 +235,6 @@ impl Array for VarBinArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if self.is_valid(index) { - self.bytes_at(index).map(|bytes| { - if matches!(self.dtype, DType::Utf8(_)) { - unsafe { String::from_utf8_unchecked(bytes) }.into() - } else { - bytes.into() - } - }) - } else { - Ok(NullableScalar::none(self.dtype.clone()).boxed()) - } - } - fn iter_arrow(&self) -> Box { let offsets_data = self.offsets.iter_arrow().combine_chunks().into_data(); let bytes_data = self.bytes.iter_arrow().combine_chunks().into_data(); diff --git a/vortex/src/array/varbinview/compute.rs b/vortex/src/array/varbinview/compute.rs new file mode 100644 index 0000000000..ca169f2636 --- /dev/null +++ b/vortex/src/array/varbinview/compute.rs @@ -0,0 +1,28 @@ +use crate::array::varbinview::VarBinViewArray; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; +use crate::dtype::DType; +use crate::error::VortexResult; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for VarBinViewArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for VarBinViewArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.is_valid(index) { + self.bytes_at(index).map(|bytes| { + if matches!(self.dtype, DType::Utf8(_)) { + unsafe { String::from_utf8_unchecked(bytes) }.into() + } else { + bytes.into() + } + }) + } else { + Ok(NullableScalar::none(self.dtype.clone()).boxed()) + } + } +} diff --git a/vortex/src/array/varbinview/mod.rs b/vortex/src/array/varbinview/mod.rs index bfa7a4ac95..ee8838379f 100644 --- a/vortex/src/array/varbinview/mod.rs +++ b/vortex/src/array/varbinview/mod.rs @@ -1,4 +1,5 @@ mod compress; +mod compute; mod serde; use std::any::Any; @@ -17,10 +18,10 @@ use crate::array::{ }; use crate::arrow::CombineChunks; use crate::compress::EncodingCompression; +use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, IntWidth, Nullability, Signedness}; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; @@ -145,8 +146,8 @@ impl VarBinViewArray { fn is_valid(&self, index: usize) -> bool { self.validity - .as_ref() - .map(|v| v.scalar_at(index).unwrap().try_into().unwrap()) + .as_deref() + .map(|v| scalar_at(v, index).unwrap().try_into().unwrap()) .unwrap_or(true) } @@ -246,20 +247,6 @@ impl Array for VarBinViewArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if self.is_valid(index) { - self.bytes_at(index).map(|bytes| { - if matches!(self.dtype, DType::Utf8(_)) { - unsafe { String::from_utf8_unchecked(bytes) }.into() - } else { - bytes.into() - } - }) - } else { - Ok(NullableScalar::none(self.dtype.clone()).boxed()) - } - } - fn iter_arrow(&self) -> Box { let data_arr: ArrowArrayRef = if matches!(self.dtype, DType::Utf8(_)) { let mut data_buf = StringBuilder::with_capacity(self.len(), self.plain_size()); diff --git a/vortex/src/compute/mod.rs b/vortex/src/compute/mod.rs index ed4376bff1..f9273d7829 100644 --- a/vortex/src/compute/mod.rs +++ b/vortex/src/compute/mod.rs @@ -1,5 +1,4 @@ use crate::compute::scalar_at::ScalarAtFn; -use polars_arrow::scalar::Scalar; use take::TakeFn; pub mod add; diff --git a/vortex/src/compute/scalar_at.rs b/vortex/src/compute/scalar_at.rs index eae1ce3634..cf37161247 100644 --- a/vortex/src/compute/scalar_at.rs +++ b/vortex/src/compute/scalar_at.rs @@ -12,8 +12,7 @@ pub fn scalar_at(array: &dyn Array, index: usize) -> VortexResult VortexResult { - array - .compute() - .and_then(|c| c.take()) - .map(|t| t.take(indices)) - .unwrap_or_else(|| { - // TODO(ngates): default implementation of decode and then try again - Err(VortexError::ComputeError( - format!("take not implemented for {}", &array.encoding().id()).into(), - )) - }) + array.take().map(|t| t.take(indices)).unwrap_or_else(|| { + // TODO(ngates): default implementation of decode and then try again + Err(VortexError::ComputeError( + format!("take not implemented for {}", &array.encoding().id()).into(), + )) + }) } diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index fad5ab605f..0fbd8789b5 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(iterator_try_collect)] + pub mod array; pub mod arrow; pub mod scalar; From 1a7efaf7c8464f8a6525f904f74752c76f4db554 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Mon, 4 Mar 2024 14:35:36 +0000 Subject: [PATCH 3/3] merge --- vortex-dict/src/compress.rs | 6 +++--- vortex-ree/src/ree.rs | 9 +++++---- vortex-roaring/src/boolean/mod.rs | 11 ++++++----- vortex-roaring/src/integer/mod.rs | 7 ++++--- vortex/src/array/bool/mod.rs | 6 +++--- vortex/src/array/primitive/compress.rs | 3 ++- vortex/src/array/primitive/mod.rs | 12 ++++++------ vortex/src/array/sparse/mod.rs | 22 +++++++++++++--------- vortex/src/array/typed/mod.rs | 5 +++-- vortex/src/array/varbin/mod.rs | 9 +++++---- vortex/src/array/varbinview/mod.rs | 6 +++--- vortex/src/compute/add.rs | 2 +- 12 files changed, 54 insertions(+), 44 deletions(-) diff --git a/vortex-dict/src/compress.rs b/vortex-dict/src/compress.rs index f1ef44f8c2..227ecd559a 100644 --- a/vortex-dict/src/compress.rs +++ b/vortex-dict/src/compress.rs @@ -269,7 +269,7 @@ where mod test { use vortex::array::primitive::PrimitiveArray; use vortex::array::varbin::VarBinArray; - use vortex::array::Array; + use vortex::compute::scalar_at::scalar_at; use crate::compress::{dict_encode_typed_primitive, dict_encode_varbin}; @@ -298,8 +298,8 @@ mod test { assert!(!codes.is_valid(2)); assert!(!codes.is_valid(5)); assert!(!codes.is_valid(7)); - assert_eq!(values.scalar_at(0), Ok(1.into())); - assert_eq!(values.scalar_at(2), Ok(3.into())); + assert_eq!(scalar_at(values.as_ref(), 0), Ok(1.into())); + assert_eq!(scalar_at(values.as_ref(), 2), Ok(3.into())); } #[test] diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 06ed424925..ece19ab5f4 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -315,6 +315,7 @@ mod test { use arrow::array::types::Int32Type; use itertools::Itertools; use vortex::array::Array; + use vortex::compute::scalar_at::scalar_at; use crate::REEArray; use vortex::dtype::{DType, IntWidth, Nullability, Signedness}; @@ -331,10 +332,10 @@ mod test { // 0, 1 => 1 // 2, 3, 4 => 2 // 5, 6, 7, 8, 9 => 3 - assert_eq!(arr.scalar_at(0).unwrap().try_into(), Ok(1)); - assert_eq!(arr.scalar_at(2).unwrap().try_into(), Ok(2)); - assert_eq!(arr.scalar_at(5).unwrap().try_into(), Ok(3)); - assert_eq!(arr.scalar_at(9).unwrap().try_into(), Ok(3)); + assert_eq!(scalar_at(arr.as_ref(), 0).unwrap().try_into(), Ok(1)); + assert_eq!(scalar_at(arr.as_ref(), 2).unwrap().try_into(), Ok(2)); + assert_eq!(scalar_at(arr.as_ref(), 5).unwrap().try_into(), Ok(3)); + assert_eq!(scalar_at(arr.as_ref(), 9).unwrap().try_into(), Ok(3)); } #[test] diff --git a/vortex-roaring/src/boolean/mod.rs b/vortex-roaring/src/boolean/mod.rs index 9131e5c4e7..da547ba0d5 100644 --- a/vortex-roaring/src/boolean/mod.rs +++ b/vortex-roaring/src/boolean/mod.rs @@ -155,6 +155,7 @@ impl Encoding for RoaringBoolEncoding { mod test { use vortex::array::bool::BoolArray; use vortex::array::Array; + use vortex::compute::scalar_at::scalar_at; use vortex::error::VortexResult; use vortex::scalar::Scalar; @@ -172,17 +173,17 @@ mod test { } #[test] - pub fn scalar_at() -> VortexResult<()> { + pub fn test_scalar_at() -> VortexResult<()> { let bool: &dyn Array = &BoolArray::from(vec![true, false, true, true]); let array = RoaringBoolArray::encode(bool)?; let truthy: Box = true.into(); let falsy: Box = false.into(); - assert_eq!(array.scalar_at(0)?, truthy); - assert_eq!(array.scalar_at(1)?, falsy); - assert_eq!(array.scalar_at(2)?, truthy); - assert_eq!(array.scalar_at(3)?, truthy); + assert_eq!(scalar_at(array.as_ref(), 0)?, truthy); + assert_eq!(scalar_at(array.as_ref(), 1)?, falsy); + assert_eq!(scalar_at(array.as_ref(), 2)?, truthy); + assert_eq!(scalar_at(array.as_ref(), 3)?, truthy); Ok(()) } diff --git a/vortex-roaring/src/integer/mod.rs b/vortex-roaring/src/integer/mod.rs index 7c650fb794..a2efcc4d27 100644 --- a/vortex-roaring/src/integer/mod.rs +++ b/vortex-roaring/src/integer/mod.rs @@ -157,17 +157,18 @@ impl Encoding for RoaringIntEncoding { mod test { use vortex::array::primitive::PrimitiveArray; use vortex::array::Array; + use vortex::compute::scalar_at::scalar_at; use vortex::error::VortexResult; use crate::RoaringIntArray; #[test] - pub fn scalar_at() -> VortexResult<()> { + pub fn test_scalar_at() -> VortexResult<()> { let ints: &dyn Array = &PrimitiveArray::from_vec::(vec![2, 12, 22, 32]); let array = RoaringIntArray::encode(ints)?; - assert_eq!(array.scalar_at(0), Ok(2u32.into())); - assert_eq!(array.scalar_at(1), Ok(12u32.into())); + assert_eq!(scalar_at(array.as_ref(), 0), Ok(2u32.into())); + assert_eq!(scalar_at(array.as_ref(), 1), Ok(12u32.into())); Ok(()) } diff --git a/vortex/src/array/bool/mod.rs b/vortex/src/array/bool/mod.rs index bc06a4274b..a0d3820f39 100644 --- a/vortex/src/array/bool/mod.rs +++ b/vortex/src/array/bool/mod.rs @@ -234,9 +234,9 @@ mod test { .slice(1, 4) .unwrap(); assert_eq!(arr.len(), 3); - assert_eq!(arr.scalar_at(0).unwrap().try_into(), Ok(true)); - assert_eq!(arr.scalar_at(1).unwrap().try_into(), Ok(false)); - assert_eq!(arr.scalar_at(2).unwrap().try_into(), Ok(false)); + assert_eq!(scalar_at(arr.as_ref(), 0).unwrap().try_into(), Ok(true)); + assert_eq!(scalar_at(arr.as_ref(), 1).unwrap().try_into(), Ok(false)); + assert_eq!(scalar_at(arr.as_ref(), 2).unwrap().try_into(), Ok(false)); } #[test] diff --git a/vortex/src/array/primitive/compress.rs b/vortex/src/array/primitive/compress.rs index 1d5aaf0fd8..1c0cf0c5c8 100644 --- a/vortex/src/array/primitive/compress.rs +++ b/vortex/src/array/primitive/compress.rs @@ -32,12 +32,13 @@ mod test { use crate::array::primitive::PrimitiveArray; use crate::array::Encoding; use crate::compress::CompressCtx; + use crate::compute::scalar_at::scalar_at; #[test] pub fn compress_constant() { let arr = PrimitiveArray::from_vec(vec![1, 1, 1, 1]); let res = CompressCtx::default().compress(arr.as_ref(), None); assert_eq!(res.encoding().id(), ConstantEncoding.id()); - assert_eq!(res.scalar_at(3).unwrap().try_into(), Ok(1)); + assert_eq!(scalar_at(res.as_ref(), 3).unwrap().try_into(), Ok(1)); } } diff --git a/vortex/src/array/primitive/mod.rs b/vortex/src/array/primitive/mod.rs index 845b21e52c..17f19470b2 100644 --- a/vortex/src/array/primitive/mod.rs +++ b/vortex/src/array/primitive/mod.rs @@ -319,9 +319,9 @@ mod test { ); // Ensure we can fetch the scalar at the given index. - assert_eq!(arr.scalar_at(0).unwrap().try_into(), Ok(1)); - assert_eq!(arr.scalar_at(1).unwrap().try_into(), Ok(2)); - assert_eq!(arr.scalar_at(2).unwrap().try_into(), Ok(3)); + assert_eq!(scalar_at(arr.as_ref(), 0).unwrap().try_into(), Ok(1)); + assert_eq!(scalar_at(arr.as_ref(), 1).unwrap().try_into(), Ok(2)); + assert_eq!(scalar_at(arr.as_ref(), 2).unwrap().try_into(), Ok(3)); } #[test] @@ -330,8 +330,8 @@ mod test { .slice(1, 4) .unwrap(); assert_eq!(arr.len(), 3); - assert_eq!(arr.scalar_at(0).unwrap().try_into(), Ok(2)); - assert_eq!(arr.scalar_at(1).unwrap().try_into(), Ok(3)); - assert_eq!(arr.scalar_at(2).unwrap().try_into(), Ok(4)); + assert_eq!(scalar_at(arr.as_ref(), 0).unwrap().try_into(), Ok(2)); + assert_eq!(scalar_at(arr.as_ref(), 1).unwrap().try_into(), Ok(3)); + assert_eq!(scalar_at(arr.as_ref(), 2).unwrap().try_into(), Ok(4)); } } diff --git a/vortex/src/array/sparse/mod.rs b/vortex/src/array/sparse/mod.rs index 60f1dbe1ae..4f6c8c0e64 100644 --- a/vortex/src/array/sparse/mod.rs +++ b/vortex/src/array/sparse/mod.rs @@ -231,6 +231,7 @@ mod test { use crate::array::sparse::SparseArray; use crate::array::Array; + use crate::compute::scalar_at::scalar_at; use crate::error::VortexError; fn sparse_array() -> SparseArray { @@ -291,13 +292,13 @@ mod test { } #[test] - pub fn scalar_at() { + pub fn test_scalar_at() { assert_eq!( - usize::try_from(sparse_array().scalar_at(2).unwrap()).unwrap(), + usize::try_from(scalar_at(sparse_array().as_ref(), 2).unwrap()).unwrap(), 100 ); assert_eq!( - sparse_array().scalar_at(10).err().unwrap(), + scalar_at(sparse_array().as_ref(), 10).err().unwrap(), VortexError::OutOfBounds(10, 0, 10) ); } @@ -305,9 +306,12 @@ mod test { #[test] pub fn scalar_at_sliced() { let sliced = sparse_array().slice(2, 7).unwrap(); - assert_eq!(usize::try_from(sliced.scalar_at(0).unwrap()).unwrap(), 100); assert_eq!( - sliced.scalar_at(5).err().unwrap(), + usize::try_from(scalar_at(sliced.as_ref(), 0).unwrap()).unwrap(), + 100 + ); + assert_eq!( + scalar_at(sliced.as_ref(), 5).err().unwrap(), VortexError::OutOfBounds(5, 0, 5) ); } @@ -316,21 +320,21 @@ mod test { pub fn scalar_at_sliced_twice() { let sliced_once = sparse_array().slice(1, 8).unwrap(); assert_eq!( - usize::try_from(sliced_once.scalar_at(1).unwrap()).unwrap(), + usize::try_from(scalar_at(sliced_once.as_ref(), 1).unwrap()).unwrap(), 100 ); assert_eq!( - sliced_once.scalar_at(7).err().unwrap(), + scalar_at(sliced_once.as_ref(), 7).err().unwrap(), VortexError::OutOfBounds(7, 0, 7) ); let sliced_twice = sliced_once.slice(1, 6).unwrap(); assert_eq!( - usize::try_from(sliced_twice.scalar_at(3).unwrap()).unwrap(), + usize::try_from(scalar_at(sliced_twice.as_ref(), 3).unwrap()).unwrap(), 200 ); assert_eq!( - sliced_twice.scalar_at(5).err().unwrap(), + scalar_at(sliced_twice.as_ref(), 5).err().unwrap(), VortexError::OutOfBounds(5, 0, 5) ); } diff --git a/vortex/src/array/typed/mod.rs b/vortex/src/array/typed/mod.rs index 40ad18649f..30ce3041fb 100644 --- a/vortex/src/array/typed/mod.rs +++ b/vortex/src/array/typed/mod.rs @@ -161,6 +161,7 @@ mod test { use crate::array::typed::TypedArray; use crate::array::Array; + use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, Nullability, TimeUnit}; use crate::scalar::{LocalTimeScalar, PScalar, Scalar}; @@ -171,11 +172,11 @@ mod test { DType::LocalTime(TimeUnit::Us, Nullability::NonNullable), ); assert_eq!( - arr.scalar_at(0).unwrap().as_ref(), + scalar_at(arr.as_ref(), 0).unwrap().as_ref(), &LocalTimeScalar::new(PScalar::U64(64_799_000_000), TimeUnit::Us) as &dyn Scalar ); assert_eq!( - arr.scalar_at(1).unwrap().as_ref(), + scalar_at(arr.as_ref(), 1).unwrap().as_ref(), &LocalTimeScalar::new(PScalar::U64(43_000_000_000), TimeUnit::Us) as &dyn Scalar ); } diff --git a/vortex/src/array/varbin/mod.rs b/vortex/src/array/varbin/mod.rs index 45a112b23a..c6dd68100f 100644 --- a/vortex/src/array/varbin/mod.rs +++ b/vortex/src/array/varbin/mod.rs @@ -383,6 +383,7 @@ mod test { use crate::array::primitive::PrimitiveArray; use crate::array::varbin::VarBinArray; use crate::arrow::CombineChunks; + use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, Nullability}; fn binary_array() -> VarBinArray { @@ -402,12 +403,12 @@ mod test { } #[test] - pub fn scalar_at() { + pub fn test_scalar_at() { let binary_arr = binary_array(); assert_eq!(binary_arr.len(), 2); - assert_eq!(binary_arr.scalar_at(0), Ok("hello world".into())); + assert_eq!(scalar_at(binary_arr.as_ref(), 0), Ok("hello world".into())); assert_eq!( - binary_arr.scalar_at(1), + scalar_at(binary_arr.as_ref(), 1), Ok("hello world this is a long string".into()) ) } @@ -416,7 +417,7 @@ mod test { pub fn slice() { let binary_arr = binary_array().slice(1, 2).unwrap(); assert_eq!( - binary_arr.scalar_at(0), + scalar_at(binary_arr.as_ref(), 0), Ok("hello world this is a long string".into()) ); } diff --git a/vortex/src/array/varbinview/mod.rs b/vortex/src/array/varbinview/mod.rs index ee8838379f..fc90e88825 100644 --- a/vortex/src/array/varbinview/mod.rs +++ b/vortex/src/array/varbinview/mod.rs @@ -392,9 +392,9 @@ mod test { pub fn varbin_view() { let binary_arr = binary_array(); assert_eq!(binary_arr.len(), 2); - assert_eq!(binary_arr.scalar_at(0), Ok("hello world".into())); + assert_eq!(scalar_at(binary_arr.as_ref(), 0), Ok("hello world".into())); assert_eq!( - binary_arr.scalar_at(1), + scalar_at(binary_arr.as_ref(), 1), Ok("hello world this is a long string".into()) ) } @@ -403,7 +403,7 @@ mod test { pub fn slice() { let binary_arr = binary_array().slice(1, 2).unwrap(); assert_eq!( - binary_arr.scalar_at(0), + scalar_at(binary_arr.as_ref(), 0), Ok("hello world this is a long string".into()) ); } diff --git a/vortex/src/compute/add.rs b/vortex/src/compute/add.rs index 6ef1002204..bbf9ec63e8 100644 --- a/vortex/src/compute/add.rs +++ b/vortex/src/compute/add.rs @@ -45,6 +45,6 @@ mod test { let rhs = ConstantArray::new(47.into(), 100); let result = add(&lhs, &rhs).unwrap(); assert_eq!(result.len(), 100); - // assert_eq!(result.scalar_at(0), 94); + // assert_eq!(scalar_at(result, 0), 94); } }