From bc2c5dc960b20a29631f805cd6392f8f7caf0771 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Mon, 4 Mar 2024 14:25:43 +0000 Subject: [PATCH] merge --- vortex-alp/Cargo.toml | 3 ++ vortex-alp/src/alp.rs | 35 +------------------ vortex-alp/src/compute.rs | 47 ++++++++++++++++++++++++++ vortex-alp/src/lib.rs | 1 + vortex-dict/Cargo.toml | 3 ++ vortex-dict/src/compress.rs | 5 +-- vortex-dict/src/compute.rs | 18 ++++++++++ vortex-dict/src/dict.rs | 11 +----- vortex-dict/src/lib.rs | 1 + vortex-fastlanes/src/bitpacking/mod.rs | 23 +++---------- vortex-fastlanes/src/for/mod.rs | 7 ++-- vortex-ffor/src/ffor.rs | 46 ++++--------------------- vortex-ree/Cargo.toml | 3 ++ vortex-ree/src/compute.rs | 17 ++++++++++ vortex-ree/src/lib.rs | 1 + vortex-ree/src/ree.rs | 14 +++----- vortex-roaring/Cargo.toml | 6 ++-- vortex-roaring/src/boolean/compute.rs | 21 ++++++++++++ vortex-roaring/src/boolean/mod.rs | 16 ++------- vortex-roaring/src/integer/compute.rs | 27 +++++++++++++++ vortex-roaring/src/integer/mod.rs | 20 ++--------- vortex-zigzag/Cargo.toml | 6 ++-- vortex-zigzag/src/compute.rs | 38 +++++++++++++++++++++ vortex-zigzag/src/lib.rs | 1 + vortex-zigzag/src/zigzag.rs | 34 ++----------------- vortex/src/array/bool/mod.rs | 8 ++--- vortex/src/array/chunked/compute.rs | 8 ++--- vortex/src/array/chunked/mod.rs | 1 - vortex/src/array/constant/compress.rs | 3 +- vortex/src/array/constant/compute.rs | 20 +++++++++++ vortex/src/array/constant/mod.rs | 10 ++---- vortex/src/array/constant/take.rs | 10 ------ vortex/src/array/mod.rs | 11 +++--- vortex/src/array/primitive/compute.rs | 30 ++++++++++++++++ vortex/src/array/primitive/mod.rs | 28 ++++----------- vortex/src/array/sparse/compute.rs | 37 ++++++++++++++++++++ vortex/src/array/sparse/mod.rs | 31 ++--------------- vortex/src/array/struct_/compute.rs | 25 ++++++++++++++ vortex/src/array/struct_/mod.rs | 13 +------ vortex/src/array/typed/compute.rs | 19 +++++++++++ vortex/src/array/typed/mod.rs | 7 +--- vortex/src/array/varbin/compute.rs | 28 +++++++++++++++ vortex/src/array/varbin/mod.rs | 31 +++++------------ vortex/src/array/varbinview/compute.rs | 28 +++++++++++++++ vortex/src/array/varbinview/mod.rs | 21 +++--------- vortex/src/compute/mod.rs | 1 - vortex/src/compute/scalar_at.rs | 3 +- vortex/src/compute/take.rs | 16 ++++----- vortex/src/lib.rs | 2 ++ 49 files changed, 451 insertions(+), 344 deletions(-) create mode 100644 vortex-alp/src/compute.rs create mode 100644 vortex-dict/src/compute.rs create mode 100644 vortex-ree/src/compute.rs create mode 100644 vortex-roaring/src/boolean/compute.rs create mode 100644 vortex-roaring/src/integer/compute.rs create mode 100644 vortex-zigzag/src/compute.rs delete mode 100644 vortex/src/array/constant/take.rs create mode 100644 vortex/src/array/primitive/compute.rs create mode 100644 vortex/src/array/sparse/compute.rs create mode 100644 vortex/src/array/struct_/compute.rs create mode 100644 vortex/src/array/typed/compute.rs create mode 100644 vortex/src/array/varbin/compute.rs create mode 100644 vortex/src/array/varbinview/compute.rs diff --git a/vortex-alp/Cargo.toml b/vortex-alp/Cargo.toml index 776c8c31ce..cb1a1cdd1d 100644 --- a/vortex-alp/Cargo.toml +++ b/vortex-alp/Cargo.toml @@ -18,3 +18,6 @@ linkme = "0.3.22" itertools = "0.12.1" codecz = { version = "0.1.0", path = "../codecz" } log = { version = "0.4.20", features = [] } + +[lints] +workspace = true diff --git a/vortex-alp/src/alp.rs b/vortex-alp/src/alp.rs index 43f4b0d747..65d37418aa 100644 --- a/vortex-alp/src/alp.rs +++ b/vortex-alp/src/alp.rs @@ -1,14 +1,12 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use codecz::alp; pub use codecz::alp::ALPExponents; use vortex::array::{Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef}; use vortex::compress::EncodingCompression; -use vortex::dtype::{DType, FloatWidth, IntWidth, Signedness}; +use vortex::dtype::{DType, IntWidth, Signedness}; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::{NullableScalar, Scalar}; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; @@ -106,37 +104,6 @@ impl Array for ALPArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if let Some(patch) = self - .patches() - .and_then(|p| p.scalar_at(index).ok()) - .and_then(|p| p.into_nonnull()) - { - return Ok(patch); - } - - let Some(encoded_val) = self.encoded.scalar_at(index)?.into_nonnull() else { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - }; - match self.dtype { - DType::Float(FloatWidth::_32, _) => { - let encoded_val: i32 = encoded_val.try_into().unwrap(); - Ok(alp::decode_single::(encoded_val, self.exponents) - .unwrap() - .into()) - } - - DType::Float(FloatWidth::_64, _) => { - let encoded_val: i64 = encoded_val.try_into().unwrap(); - Ok(alp::decode_single::(encoded_val, self.exponents) - .unwrap() - .into()) - } - - _ => unreachable!(), - } - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex-alp/src/compute.rs b/vortex-alp/src/compute.rs new file mode 100644 index 0000000000..4ae36d02f8 --- /dev/null +++ b/vortex-alp/src/compute.rs @@ -0,0 +1,47 @@ +use crate::ALPArray; +use codecz::alp; +use vortex::array::Array; +use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::ArrayCompute; +use vortex::dtype::{DType, FloatWidth}; +use vortex::error::VortexResult; +use vortex::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for ALPArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for ALPArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if let Some(patch) = self + .patches() + .and_then(|p| scalar_at(p, index).ok()) + .and_then(|p| p.into_nonnull()) + { + return Ok(patch); + } + + let Some(encoded_val) = scalar_at(self.encoded(), index)?.into_nonnull() else { + return Ok(NullableScalar::none(self.dtype().clone()).boxed()); + }; + match self.dtype() { + DType::Float(FloatWidth::_32, _) => { + let encoded_val: i32 = encoded_val.try_into().unwrap(); + Ok(alp::decode_single::(encoded_val, self.exponents()) + .unwrap() + .into()) + } + + DType::Float(FloatWidth::_64, _) => { + let encoded_val: i64 = encoded_val.try_into().unwrap(); + Ok(alp::decode_single::(encoded_val, self.exponents()) + .unwrap() + .into()) + } + + _ => unreachable!(), + } + } +} diff --git a/vortex-alp/src/lib.rs b/vortex-alp/src/lib.rs index 6f0fe3ae42..8ff1d432cc 100644 --- a/vortex-alp/src/lib.rs +++ b/vortex-alp/src/lib.rs @@ -4,6 +4,7 @@ use vortex::array::{EncodingRef, ENCODINGS}; mod alp; mod compress; +mod compute; mod downcast; mod serde; mod stats; diff --git a/vortex-dict/Cargo.toml b/vortex-dict/Cargo.toml index 5f9e19ab6d..0ea173dfe0 100644 --- a/vortex-dict/Cargo.toml +++ b/vortex-dict/Cargo.toml @@ -19,3 +19,6 @@ hashbrown = "0.14.3" linkme = "0.3.22" log = "0.4.20" num-traits = "0.2.17" + +[lints] +workspace = true diff --git a/vortex-dict/src/compress.rs b/vortex-dict/src/compress.rs index a3f6dd4ea4..f1ef44f8c2 100644 --- a/vortex-dict/src/compress.rs +++ b/vortex-dict/src/compress.rs @@ -11,6 +11,7 @@ use vortex::array::primitive::PrimitiveArray; use vortex::array::varbin::VarBinArray; use vortex::array::{Array, ArrayKind, ArrayRef, CloneOptionalArray}; use vortex::compress::{CompressConfig, CompressCtx, Compressor, EncodingCompression}; +use vortex::compute::scalar_at::scalar_at; use vortex::dtype::DType; use vortex::match_each_native_ptype; use vortex::ptype::NativePType; @@ -207,8 +208,8 @@ fn bytes_at_primitive<'a, T: NativePType + AsPrimitive>( } fn bytes_at<'a>(offsets: &'a dyn Array, bytes: &'a [u8], idx: usize) -> &'a [u8] { - let start: usize = offsets.scalar_at(idx).unwrap().try_into().unwrap(); - let stop: usize = offsets.scalar_at(idx + 1).unwrap().try_into().unwrap(); + let start: usize = scalar_at(offsets, idx).unwrap().try_into().unwrap(); + let stop: usize = scalar_at(offsets, idx + 1).unwrap().try_into().unwrap(); &bytes[start..stop] } diff --git a/vortex-dict/src/compute.rs b/vortex-dict/src/compute.rs new file mode 100644 index 0000000000..b8292458b8 --- /dev/null +++ b/vortex-dict/src/compute.rs @@ -0,0 +1,18 @@ +use crate::DictArray; +use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::ArrayCompute; +use vortex::error::VortexResult; +use vortex::scalar::Scalar; + +impl ArrayCompute for DictArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for DictArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + let dict_index: usize = scalar_at(self.codes(), index)?.try_into()?; + scalar_at(self.dict(), dict_index) + } +} diff --git a/vortex-dict/src/dict.rs b/vortex-dict/src/dict.rs index bb332cd453..8cb0cbcd54 100644 --- a/vortex-dict/src/dict.rs +++ b/vortex-dict/src/dict.rs @@ -1,14 +1,11 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use vortex::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, -}; +use vortex::array::{check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId}; use vortex::compress::EncodingCompression; use vortex::dtype::{DType, Signedness}; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; @@ -75,12 +72,6 @@ impl Array for DictArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - let dict_index: usize = self.codes().scalar_at(index)?.try_into()?; - self.dict().scalar_at(dict_index) - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex-dict/src/lib.rs b/vortex-dict/src/lib.rs index afef63b7da..7342c1d293 100644 --- a/vortex-dict/src/lib.rs +++ b/vortex-dict/src/lib.rs @@ -5,6 +5,7 @@ pub use compress::*; pub use dict::*; mod compress; +mod compute; mod dict; mod downcast; mod serde; diff --git a/vortex-fastlanes/src/bitpacking/mod.rs b/vortex-fastlanes/src/bitpacking/mod.rs index 20b6b86c77..4f49307ec6 100644 --- a/vortex-fastlanes/src/bitpacking/mod.rs +++ b/vortex-fastlanes/src/bitpacking/mod.rs @@ -6,10 +6,11 @@ use vortex::array::{ check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, }; use vortex::compress::EncodingCompression; +use vortex::compute::scalar_at::scalar_at; +use vortex::compute::ArrayCompute; use vortex::dtype::DType; use vortex::error::VortexResult; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::{NullableScalar, Scalar}; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stat, Stats, StatsCompute, StatsSet}; @@ -74,7 +75,7 @@ impl BitPackedArray { pub fn is_valid(&self, index: usize) -> bool { self.validity() - .map(|v| v.scalar_at(index).and_then(|v| v.try_into()).unwrap()) + .map(|v| scalar_at(v, index).and_then(|v| v.try_into()).unwrap()) .unwrap_or(true) } } @@ -115,22 +116,6 @@ impl Array for BitPackedArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if !self.is_valid(index) { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - } - - if let Some(patch) = self - .patches() - .and_then(|p| p.scalar_at(index).ok()) - .and_then(|p| p.into_nonnull()) - { - return Ok(patch); - } - - todo!("Decode single element from BitPacked array"); - } - fn iter_arrow(&self) -> Box { todo!() } @@ -156,6 +141,8 @@ impl Array for BitPackedArray { } } +impl ArrayCompute for BitPackedArray {} + impl<'arr> AsRef<(dyn Array + 'arr)> for BitPackedArray { fn as_ref(&self) -> &(dyn Array + 'arr) { self diff --git a/vortex-fastlanes/src/for/mod.rs b/vortex-fastlanes/src/for/mod.rs index cd7b8fba9a..c0a3ec6e20 100644 --- a/vortex-fastlanes/src/for/mod.rs +++ b/vortex-fastlanes/src/for/mod.rs @@ -4,6 +4,7 @@ use std::sync::{Arc, RwLock}; use vortex::array::{Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef}; use vortex::compress::EncodingCompression; +use vortex::compute::ArrayCompute; use vortex::dtype::DType; use vortex::error::VortexResult; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; @@ -78,10 +79,6 @@ impl Array for FoRArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, _index: usize) -> VortexResult> { - todo!() - } - fn iter_arrow(&self) -> Box { todo!() } @@ -110,6 +107,8 @@ impl Array for FoRArray { } } +impl ArrayCompute for FoRArray {} + impl<'arr> AsRef<(dyn Array + 'arr)> for FoRArray { fn as_ref(&self) -> &(dyn Array + 'arr) { self diff --git a/vortex-ffor/src/ffor.rs b/vortex-ffor/src/ffor.rs index fd187d9fa0..5fd1325bd8 100644 --- a/vortex-ffor/src/ffor.rs +++ b/vortex-ffor/src/ffor.rs @@ -1,17 +1,17 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use vortex::array::downcast::DowncastArrayBuiltin; use vortex::array::{ check_validity_buffer, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, }; use vortex::compress::EncodingCompression; +use vortex::compute::scalar_at::scalar_at; +use vortex::compute::ArrayCompute; use vortex::dtype::DType; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::match_each_integer_ptype; -use vortex::scalar::{NullableScalar, Scalar}; +use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; @@ -104,7 +104,7 @@ impl FFORArray { pub fn is_valid(&self, index: usize) -> bool { self.validity() - .map(|v| v.scalar_at(index).and_then(|v| v.try_into()).unwrap()) + .map(|v| scalar_at(v, index).and_then(|v| v.try_into()).unwrap()) .unwrap_or(true) } } @@ -145,42 +145,6 @@ impl Array for FFORArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if !self.is_valid(index) { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - } - - if let Some(patch) = self - .patches() - .and_then(|p| p.scalar_at(index).ok()) - .and_then(|p| p.into_nonnull()) - { - return Ok(patch); - } - - let Some(parray) = self.encoded().maybe_primitive() else { - return Err(VortexError::InvalidEncoding( - self.encoded().encoding().id().clone(), - )); - }; - - if let Ok(ptype) = self.dtype().try_into() { - match_each_integer_ptype!(ptype, |$T| { - return Ok(codecz::ffor::decode_single::<$T>( - parray.buffer().as_slice(), - self.len, - self.num_bits, - self.min_val().try_into().unwrap(), - index, - ) - .unwrap() - .into()); - }) - } else { - return Err(VortexError::InvalidDType(self.dtype().clone())); - } - } - fn iter_arrow(&self) -> Box { todo!() } @@ -204,6 +168,8 @@ impl Array for FFORArray { } } +impl ArrayCompute for FFORArray {} + impl<'arr> AsRef<(dyn Array + 'arr)> for FFORArray { fn as_ref(&self) -> &(dyn Array + 'arr) { self diff --git a/vortex-ree/Cargo.toml b/vortex-ree/Cargo.toml index 22bf5a8dad..92f4c5b515 100644 --- a/vortex-ree/Cargo.toml +++ b/vortex-ree/Cargo.toml @@ -19,3 +19,6 @@ linkme = "0.3.22" half = "2.3.1" num-traits = "0.2.17" itertools = "0.10.5" + +[lints] +workspace = true diff --git a/vortex-ree/src/compute.rs b/vortex-ree/src/compute.rs new file mode 100644 index 0000000000..8fb64b37dd --- /dev/null +++ b/vortex-ree/src/compute.rs @@ -0,0 +1,17 @@ +use crate::REEArray; +use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::ArrayCompute; +use vortex::error::VortexResult; +use vortex::scalar::Scalar; + +impl ArrayCompute for REEArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for REEArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + scalar_at(self.values(), self.find_physical_index(index)?) + } +} diff --git a/vortex-ree/src/lib.rs b/vortex-ree/src/lib.rs index 7fecbf9c23..c6e878a3e1 100644 --- a/vortex-ree/src/lib.rs +++ b/vortex-ree/src/lib.rs @@ -4,6 +4,7 @@ use vortex::array::{EncodingRef, ENCODINGS}; pub use ree::*; mod compress; +mod compute; mod downcast; mod ree; mod serde; diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 6457b5870a..06ed424925 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -10,18 +10,18 @@ use num_traits::AsPrimitive; use codecz::ree::SupportsREE; use vortex::array::primitive::PrimitiveArray; use vortex::array::{ - check_index_bounds, check_slice_bounds, check_validity_buffer, Array, ArrayKind, ArrayRef, - ArrowIterator, CloneOptionalArray, Encoding, EncodingId, EncodingRef, + check_slice_bounds, check_validity_buffer, Array, ArrayKind, ArrayRef, ArrowIterator, + CloneOptionalArray, Encoding, EncodingId, EncodingRef, }; use vortex::arrow::match_arrow_numeric_type; use vortex::compress::EncodingCompression; use vortex::compute; +use vortex::compute::scalar_at::scalar_at; use vortex::compute::search_sorted::SearchSortedSide; use vortex::dtype::{DType, Nullability, Signedness}; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; use vortex::ptype::NativePType; -use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stat, Stats, StatsSet}; @@ -158,11 +158,6 @@ impl Array for REEArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - self.values.scalar_at(self.find_physical_index(index)?) - } - fn iter_arrow(&self) -> Box { // TODO(robert): Plumb offset rewriting to zig to fuse with REE decompression let ends: Vec = self @@ -308,8 +303,7 @@ fn run_ends_logical_length>(ends: &T) -> usize { if ends.as_ref().is_empty() { 0 } else { - ends.as_ref() - .scalar_at(ends.as_ref().len() - 1) + scalar_at(ends.as_ref(), ends.as_ref().len() - 1) .and_then(|end| end.try_into()) .unwrap_or_else(|_| panic!("Couldn't convert ends to usize")) } diff --git a/vortex-roaring/Cargo.toml b/vortex-roaring/Cargo.toml index 57dcbe86e1..5dacd630f7 100644 --- a/vortex-roaring/Cargo.toml +++ b/vortex-roaring/Cargo.toml @@ -11,12 +11,12 @@ include = { workspace = true } edition = { workspace = true } rust-version = { workspace = true } -[lints] -workspace = true - [dependencies] vortex = { "path" = "../vortex" } linkme = "0.3.22" croaring = "1.0.1" num-traits = "0.2.17" log = "0.4.20" + +[lints] +workspace = true diff --git a/vortex-roaring/src/boolean/compute.rs b/vortex-roaring/src/boolean/compute.rs new file mode 100644 index 0000000000..411245c962 --- /dev/null +++ b/vortex-roaring/src/boolean/compute.rs @@ -0,0 +1,21 @@ +use crate::RoaringBoolArray; +use vortex::compute::scalar_at::ScalarAtFn; +use vortex::compute::ArrayCompute; +use vortex::error::VortexResult; +use vortex::scalar::Scalar; + +impl ArrayCompute for RoaringBoolArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for RoaringBoolArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.bitmap.contains(index as u32) { + Ok(true.into()) + } else { + Ok(false.into()) + } + } +} diff --git a/vortex-roaring/src/boolean/mod.rs b/vortex-roaring/src/boolean/mod.rs index 373b2db8bf..9131e5c4e7 100644 --- a/vortex-roaring/src/boolean/mod.rs +++ b/vortex-roaring/src/boolean/mod.rs @@ -5,19 +5,19 @@ use croaring::{Bitmap, Native}; use compress::roaring_encode; use vortex::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, - EncodingId, EncodingRef, + check_slice_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, + EncodingRef, }; use vortex::compress::EncodingCompression; use vortex::dtype::DType; use vortex::dtype::Nullability::NonNullable; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -84,16 +84,6 @@ impl Array for RoaringBoolArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - if self.bitmap.contains(index as u32) { - Ok(true.into()) - } else { - Ok(false.into()) - } - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex-roaring/src/integer/compute.rs b/vortex-roaring/src/integer/compute.rs new file mode 100644 index 0000000000..9998a14958 --- /dev/null +++ b/vortex-roaring/src/integer/compute.rs @@ -0,0 +1,27 @@ +use crate::RoaringIntArray; +use vortex::compute::scalar_at::ScalarAtFn; +use vortex::compute::ArrayCompute; +use vortex::error::VortexResult; +use vortex::ptype::PType; +use vortex::scalar::Scalar; + +impl ArrayCompute for RoaringIntArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for RoaringIntArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + // Unwrap since we know the index is valid + let bitmap_value = self.bitmap.select(index as u32).unwrap(); + let scalar: Box = match self.ptype { + PType::U8 => (bitmap_value as u8).into(), + PType::U16 => (bitmap_value as u16).into(), + PType::U32 => bitmap_value.into(), + PType::U64 => (bitmap_value as u64).into(), + _ => unreachable!("RoaringIntArray constructor should have disallowed this type"), + }; + Ok(scalar) + } +} diff --git a/vortex-roaring/src/integer/mod.rs b/vortex-roaring/src/integer/mod.rs index e24469c078..7c650fb794 100644 --- a/vortex-roaring/src/integer/mod.rs +++ b/vortex-roaring/src/integer/mod.rs @@ -5,19 +5,19 @@ use croaring::{Bitmap, Native}; use compress::roaring_encode; use vortex::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, - EncodingId, EncodingRef, + check_slice_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, + EncodingRef, }; use vortex::compress::EncodingCompression; use vortex::dtype::DType; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; use vortex::ptype::PType; -use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -96,20 +96,6 @@ impl Array for RoaringIntArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - // Unwrap since we know the index is valid - let bitmap_value = self.bitmap.select(index as u32).unwrap(); - let scalar: Box = match self.ptype { - PType::U8 => (bitmap_value as u8).into(), - PType::U16 => (bitmap_value as u16).into(), - PType::U32 => bitmap_value.into(), - PType::U64 => (bitmap_value as u64).into(), - _ => unreachable!("RoaringIntArray constructor should have disallowed this type"), - }; - Ok(scalar) - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex-zigzag/Cargo.toml b/vortex-zigzag/Cargo.toml index 2f1280f925..c8a608496f 100644 --- a/vortex-zigzag/Cargo.toml +++ b/vortex-zigzag/Cargo.toml @@ -11,11 +11,11 @@ include = { workspace = true } edition = { workspace = true } rust-version = { workspace = true } -[lints] -workspace = true - [dependencies] vortex = { "path" = "../vortex" } linkme = "0.3.22" vortex-alloc = { version = "0.1.0", path = "../vortex-alloc" } zigzag = "0.1.0" + +[lints] +workspace = true diff --git a/vortex-zigzag/src/compute.rs b/vortex-zigzag/src/compute.rs new file mode 100644 index 0000000000..17b70555e3 --- /dev/null +++ b/vortex-zigzag/src/compute.rs @@ -0,0 +1,38 @@ +use crate::ZigZagArray; +use vortex::array::Array; +use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::ArrayCompute; +use vortex::dtype::{DType, IntWidth, Signedness}; +use vortex::error::{VortexError, VortexResult}; +use vortex::scalar::{NullableScalar, Scalar}; +use zigzag::ZigZag; + +impl ArrayCompute for ZigZagArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for ZigZagArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + let scalar = scalar_at(self.encoded(), index)?; + let Some(scalar) = scalar.as_nonnull() else { + return Ok(NullableScalar::none(self.dtype().clone()).boxed()); + }; + match self.dtype() { + DType::Int(IntWidth::_8, Signedness::Signed, _) => { + Ok(i8::decode(scalar.try_into()?).into()) + } + DType::Int(IntWidth::_16, Signedness::Signed, _) => { + Ok(i16::decode(scalar.try_into()?).into()) + } + DType::Int(IntWidth::_32, Signedness::Signed, _) => { + Ok(i32::decode(scalar.try_into()?).into()) + } + DType::Int(IntWidth::_64, Signedness::Signed, _) => { + Ok(i64::decode(scalar.try_into()?).into()) + } + _ => Err(VortexError::InvalidDType(self.dtype().clone())), + } + } +} diff --git a/vortex-zigzag/src/lib.rs b/vortex-zigzag/src/lib.rs index 55da8b3bcd..904c0baff0 100644 --- a/vortex-zigzag/src/lib.rs +++ b/vortex-zigzag/src/lib.rs @@ -4,6 +4,7 @@ use vortex::array::{EncodingRef, ENCODINGS}; pub use zigzag::*; mod compress; +mod compute; mod downcast; mod serde; mod stats; diff --git a/vortex-zigzag/src/zigzag.rs b/vortex-zigzag/src/zigzag.rs index a9888d2608..08ebace534 100644 --- a/vortex-zigzag/src/zigzag.rs +++ b/vortex-zigzag/src/zigzag.rs @@ -1,17 +1,11 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use zigzag::ZigZag; - -use vortex::array::{ - check_index_bounds, Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, - EncodingRef, -}; +use vortex::array::{Array, ArrayKind, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef}; use vortex::compress::EncodingCompression; -use vortex::dtype::{DType, IntWidth, Signedness}; +use vortex::dtype::{DType, Signedness}; use vortex::error::{VortexError, VortexResult}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::{NullableScalar, Scalar}; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stats, StatsSet}; @@ -91,30 +85,6 @@ impl Array for ZigZagArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - let scalar = self.encoded().scalar_at(index)?; - let Some(scalar) = scalar.as_nonnull() else { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - }; - match self.dtype() { - DType::Int(IntWidth::_8, Signedness::Signed, _) => { - Ok(i8::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_16, Signedness::Signed, _) => { - Ok(i16::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_32, Signedness::Signed, _) => { - Ok(i32::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_64, Signedness::Signed, _) => { - Ok(i64::decode(scalar.try_into()?).into()) - } - _ => Err(VortexError::InvalidDType(self.dtype().clone())), - } - } - fn iter_arrow(&self) -> Box { todo!() } diff --git a/vortex/src/array/bool/mod.rs b/vortex/src/array/bool/mod.rs index b74340650d..bc06a4274b 100644 --- a/vortex/src/array/bool/mod.rs +++ b/vortex/src/array/bool/mod.rs @@ -8,10 +8,10 @@ use linkme::distributed_slice; use crate::arrow::CombineChunks; use crate::compress::EncodingCompression; +use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, Nullability}; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stat, Stats, StatsSet}; @@ -50,8 +50,8 @@ impl BoolArray { fn is_valid(&self, index: usize) -> bool { self.validity - .as_ref() - .map(|v| v.scalar_at(index).unwrap().try_into().unwrap()) + .as_deref() + .map(|v| scalar_at(v, index).unwrap().try_into().unwrap()) .unwrap_or(true) } @@ -106,8 +106,6 @@ impl Array for BoolArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> {} - fn iter_arrow(&self) -> Box { Box::new(iter::once(Arc::new(BooleanArray::new( self.buffer.clone(), diff --git a/vortex/src/array/chunked/compute.rs b/vortex/src/array/chunked/compute.rs index 0ebee7b0d3..e1cd17c740 100644 --- a/vortex/src/array/chunked/compute.rs +++ b/vortex/src/array/chunked/compute.rs @@ -1,10 +1,8 @@ -use crate::array::bool::BoolArray; use crate::array::chunked::ChunkedArray; -use crate::array::Array; -use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; use crate::compute::ArrayCompute; use crate::error::VortexResult; -use crate::scalar::{NullableScalar, Scalar}; +use crate::scalar::Scalar; impl ArrayCompute for ChunkedArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -15,6 +13,6 @@ impl ArrayCompute for ChunkedArray { impl ScalarAtFn for ChunkedArray { fn scalar_at(&self, index: usize) -> VortexResult> { let (chunk_index, chunk_offset) = self.find_physical_location(index); - self.chunks[chunk_index].scalar_at(chunk_offset) + scalar_at(self.chunks[chunk_index].as_ref(), chunk_offset) } } diff --git a/vortex/src/array/chunked/mod.rs b/vortex/src/array/chunked/mod.rs index a13093adbf..6618b5a206 100644 --- a/vortex/src/array/chunked/mod.rs +++ b/vortex/src/array/chunked/mod.rs @@ -14,7 +14,6 @@ use crate::compress::EncodingCompression; use crate::dtype::DType; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::Scalar; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; diff --git a/vortex/src/array/constant/compress.rs b/vortex/src/array/constant/compress.rs index 0826d5b98c..2e3687e71c 100644 --- a/vortex/src/array/constant/compress.rs +++ b/vortex/src/array/constant/compress.rs @@ -1,6 +1,7 @@ use crate::array::constant::{ConstantArray, ConstantEncoding}; use crate::array::{Array, ArrayRef}; use crate::compress::{CompressConfig, CompressCtx, Compressor, EncodingCompression}; +use crate::compute::scalar_at::scalar_at; use crate::stats::Stat; impl EncodingCompression for ConstantEncoding { @@ -22,5 +23,5 @@ fn constant_compressor( _like: Option<&dyn Array>, _ctx: CompressCtx, ) -> ArrayRef { - ConstantArray::new(array.scalar_at(0).unwrap(), array.len()).boxed() + ConstantArray::new(scalar_at(array, 0).unwrap(), array.len()).boxed() } diff --git a/vortex/src/array/constant/compute.rs b/vortex/src/array/constant/compute.rs index e46daaa542..300d1854f5 100644 --- a/vortex/src/array/constant/compute.rs +++ b/vortex/src/array/constant/compute.rs @@ -1,9 +1,29 @@ use crate::array::constant::ConstantArray; +use crate::array::{Array, ArrayRef}; +use crate::compute::scalar_at::ScalarAtFn; use crate::compute::take::TakeFn; use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::Scalar; impl ArrayCompute for ConstantArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } + fn take(&self) -> Option<&dyn TakeFn> { Some(self) } } + +impl ScalarAtFn for ConstantArray { + fn scalar_at(&self, _index: usize) -> VortexResult> { + Ok(dyn_clone::clone_box(self.scalar())) + } +} + +impl TakeFn for ConstantArray { + fn take(&self, indices: &dyn Array) -> VortexResult { + Ok(ConstantArray::new(dyn_clone::clone_box(self.scalar()), indices.len()).boxed()) + } +} diff --git a/vortex/src/array/constant/mod.rs b/vortex/src/array/constant/mod.rs index 7a21343685..e8673aa581 100644 --- a/vortex/src/array/constant/mod.rs +++ b/vortex/src/array/constant/mod.rs @@ -5,8 +5,8 @@ use arrow::array::Datum; use linkme::distributed_slice; use crate::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, - EncodingRef, ENCODINGS, + check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, + ENCODINGS, }; use crate::arrow::compute::repeat; use crate::compress::EncodingCompression; @@ -21,7 +21,6 @@ mod compress; mod compute; mod serde; mod stats; -mod take; #[derive(Debug, Clone)] pub struct ConstantArray { @@ -80,11 +79,6 @@ impl Array for ConstantArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - Ok(self.scalar.clone()) - } - fn iter_arrow(&self) -> Box { let arrow_scalar: Box = self.scalar.as_ref().into(); Box::new(std::iter::once(repeat(arrow_scalar.as_ref(), self.length))) diff --git a/vortex/src/array/constant/take.rs b/vortex/src/array/constant/take.rs deleted file mode 100644 index 0760e501fd..0000000000 --- a/vortex/src/array/constant/take.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::array::constant::ConstantArray; -use crate::array::{Array, ArrayRef}; -use crate::compute::take::TakeFn; -use crate::error::VortexResult; - -impl TakeFn for ConstantArray { - fn take(&self, indices: &dyn Array) -> VortexResult { - Ok(ConstantArray::new(dyn_clone::clone_box(self.scalar()), indices.len()).boxed()) - } -} diff --git a/vortex/src/array/mod.rs b/vortex/src/array/mod.rs index e42e503af6..072afc4b50 100644 --- a/vortex/src/array/mod.rs +++ b/vortex/src/array/mod.rs @@ -19,7 +19,6 @@ use crate::compute::ArrayCompute; use crate::dtype::{DType, Nullability}; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::Scalar; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::Stats; @@ -45,7 +44,9 @@ pub type ArrayRef = Box; /// /// This differs from Apache Arrow where logical and physical are combined in /// the data type, e.g. LargeString, RunEndEncoded. -pub trait Array: ArrayDisplay + Debug + Send + Sync + dyn_clone::DynClone + 'static { +pub trait Array: + ArrayDisplay + ArrayCompute + Debug + Send + Sync + dyn_clone::DynClone + 'static +{ /// Converts itself to a reference of [`Any`], which enables downcasting to concrete types. fn as_any(&self) -> &dyn Any; /// Move an owned array to `ArrayRef` @@ -69,10 +70,6 @@ pub trait Array: ArrayDisplay + Debug + Send + Sync + dyn_clone::DynClone + 'sta /// Approximate size in bytes of the array. Only takes into account variable size portion of the array fn nbytes(&self) -> usize; - fn compute(&self) -> Option<&dyn ArrayCompute> { - None - } - fn serde(&self) -> &dyn ArraySerde; } @@ -98,7 +95,7 @@ pub fn check_slice_bounds(array: &dyn Array, start: usize, stop: usize) -> Vorte Ok(()) } -pub fn check_validity_buffer(validity: Option<&ArrayRef>) -> VortexResult<()> { +pub fn check_validity_buffer(validity: Option<&dyn Array>) -> VortexResult<()> { // TODO(ngates): take a length parameter and check that the length of the validity buffer matches if validity .map(|v| !matches!(v.dtype(), DType::Bool(Nullability::NonNullable))) diff --git a/vortex/src/array/primitive/compute.rs b/vortex/src/array/primitive/compute.rs new file mode 100644 index 0000000000..b11e959fab --- /dev/null +++ b/vortex/src/array/primitive/compute.rs @@ -0,0 +1,30 @@ +use crate::array::primitive::PrimitiveArray; +use crate::array::Array; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::match_each_native_ptype; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for PrimitiveArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for PrimitiveArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.is_valid(index) { + Ok( + match_each_native_ptype!(self.ptype, |$T| self.buffer.typed_data::<$T>() + .get(index) + .unwrap() + .clone() + .into() + ), + ) + } else { + Ok(NullableScalar::none(self.dtype().clone()).boxed()) + } + } +} diff --git a/vortex/src/array/primitive/mod.rs b/vortex/src/array/primitive/mod.rs index 2cff46a45c..845b21e52c 100644 --- a/vortex/src/array/primitive/mod.rs +++ b/vortex/src/array/primitive/mod.rs @@ -15,20 +15,21 @@ use log::debug; use crate::array::bool::BoolArray; use crate::array::{ - check_index_bounds, check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, - Encoding, EncodingId, EncodingRef, ENCODINGS, + check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding, + EncodingId, EncodingRef, ENCODINGS, }; use crate::arrow::CombineChunks; use crate::compress::EncodingCompression; +use crate::compute::scalar_at::scalar_at; use crate::dtype::DType; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; use crate::ptype::{match_each_native_ptype, NativePType, PType}; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -109,8 +110,8 @@ impl PrimitiveArray { pub fn is_valid(&self, index: usize) -> bool { self.validity - .as_ref() - .map(|v| v.scalar_at(index).unwrap().try_into().unwrap()) + .as_deref() + .map(|v| scalar_at(v, index).unwrap().try_into().unwrap()) .unwrap_or(true) } @@ -166,23 +167,6 @@ impl Array for PrimitiveArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - if self.is_valid(index) { - Ok( - match_each_native_ptype!(self.ptype, |$T| self.buffer.typed_data::<$T>() - .get(index) - .unwrap() - .clone() - .into() - ), - ) - } else { - Ok(NullableScalar::none(self.dtype().clone()).boxed()) - } - } - fn iter_arrow(&self) -> Box { Box::new(iter::once(make_array( ArrayData::builder(self.dtype().into()) diff --git a/vortex/src/array/sparse/compute.rs b/vortex/src/array/sparse/compute.rs new file mode 100644 index 0000000000..02b5a9071e --- /dev/null +++ b/vortex/src/array/sparse/compute.rs @@ -0,0 +1,37 @@ +use crate::array::sparse::SparseArray; +use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; +use crate::compute::search_sorted::{search_sorted_usize, SearchSortedSide}; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for SparseArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for SparseArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + // Check whether `true_patch_index` exists in the patch index array + // First, get the index of the patch index array that is the first index + // greater than or equal to the true index + let true_patch_index = index + self.indices_offset; + search_sorted_usize(self.indices(), true_patch_index, SearchSortedSide::Left).and_then( + |idx| { + // If the value at this index is equal to the true index, then it exists in the patch index array + // and we should return the value at the corresponding index in the patch values array + scalar_at(self.indices(), idx) + .or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed())) + .and_then(usize::try_from) + .and_then(|patch_index| { + if patch_index == true_patch_index { + scalar_at(self.values(), idx) + } else { + Ok(NullableScalar::none(self.values().dtype().clone()).boxed()) + } + }) + }, + ) + } +} diff --git a/vortex/src/array/sparse/mod.rs b/vortex/src/array/sparse/mod.rs index 86a2e623c0..60f1dbe1ae 100644 --- a/vortex/src/array/sparse/mod.rs +++ b/vortex/src/array/sparse/mod.rs @@ -11,8 +11,7 @@ use linkme::distributed_slice; use crate::array::ENCODINGS; use crate::array::{ - check_index_bounds, check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, - EncodingRef, + check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, }; use crate::compress::EncodingCompression; use crate::compute::search_sorted::{search_sorted_usize, SearchSortedSide}; @@ -20,11 +19,11 @@ use crate::dtype::DType; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; use crate::match_arrow_numeric_type; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -118,32 +117,6 @@ impl Array for SparseArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; - - // Check whether `true_patch_index` exists in the patch index array - // First, get the index of the patch index array that is the first index - // greater than or equal to the true index - let true_patch_index = index + self.indices_offset; - search_sorted_usize(self.indices(), true_patch_index, SearchSortedSide::Left).and_then( - |idx| { - // If the value at this index is equal to the true index, then it exists in the patch index array - // and we should return the value at the corresponding index in the patch values array - self.indices() - .scalar_at(idx) - .or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed())) - .and_then(usize::try_from) - .and_then(|patch_index| { - if patch_index == true_patch_index { - self.values().scalar_at(idx) - } else { - Ok(NullableScalar::none(self.values().dtype().clone()).boxed()) - } - }) - }, - ) - } - fn iter_arrow(&self) -> Box { // Resolve our indices into a vector of usize applying the offset let mut indices = Vec::with_capacity(self.len()); diff --git a/vortex/src/array/struct_/compute.rs b/vortex/src/array/struct_/compute.rs new file mode 100644 index 0000000000..cbf5adc70d --- /dev/null +++ b/vortex/src/array/struct_/compute.rs @@ -0,0 +1,25 @@ +use crate::array::struct_::StructArray; +use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::{Scalar, StructScalar}; +use itertools::Itertools; + +impl ArrayCompute for StructArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for StructArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + Ok(StructScalar::new( + self.dtype.clone(), + self.fields + .iter() + .map(|field| scalar_at(field.as_ref(), index)) + .try_collect()?, + ) + .boxed()) + } +} diff --git a/vortex/src/array/struct_/mod.rs b/vortex/src/array/struct_/mod.rs index 698d39a43d..8ea1b16fd6 100644 --- a/vortex/src/array/struct_/mod.rs +++ b/vortex/src/array/struct_/mod.rs @@ -12,7 +12,6 @@ use crate::compress::EncodingCompression; use crate::dtype::{DType, FieldNames}; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::{Scalar, StructScalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; @@ -22,6 +21,7 @@ use super::{ }; mod compress; +mod compute; mod serde; mod stats; @@ -112,17 +112,6 @@ impl Array for StructArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - Ok(StructScalar::new( - self.dtype.clone(), - self.fields - .iter() - .map(|field| field.scalar_at(index)) - .try_collect()?, - ) - .boxed()) - } - fn iter_arrow(&self) -> Box { let fields = self.arrow_fields(); Box::new( diff --git a/vortex/src/array/typed/compute.rs b/vortex/src/array/typed/compute.rs new file mode 100644 index 0000000000..331a8ccdfb --- /dev/null +++ b/vortex/src/array/typed/compute.rs @@ -0,0 +1,19 @@ +use crate::array::typed::TypedArray; +use crate::array::Array; +use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; +use crate::compute::ArrayCompute; +use crate::error::VortexResult; +use crate::scalar::Scalar; + +impl ArrayCompute for TypedArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for TypedArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + let underlying = scalar_at(self.array.as_ref(), index)?; + underlying.as_ref().cast(self.dtype()) + } +} diff --git a/vortex/src/array/typed/mod.rs b/vortex/src/array/typed/mod.rs index 945cc2957d..40ad18649f 100644 --- a/vortex/src/array/typed/mod.rs +++ b/vortex/src/array/typed/mod.rs @@ -9,11 +9,11 @@ use crate::compress::EncodingCompression; use crate::dtype::DType; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::Scalar; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -85,11 +85,6 @@ impl Array for TypedArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - let underlying = self.array.scalar_at(index)?; - underlying.as_ref().cast(self.dtype()) - } - // TODO(robert): Have cast happen in enc space and not in arrow space fn iter_arrow(&self) -> Box { let datatype: DataType = self.dtype().into(); diff --git a/vortex/src/array/varbin/compute.rs b/vortex/src/array/varbin/compute.rs new file mode 100644 index 0000000000..5823e3c819 --- /dev/null +++ b/vortex/src/array/varbin/compute.rs @@ -0,0 +1,28 @@ +use crate::array::varbin::VarBinArray; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; +use crate::dtype::DType; +use crate::error::VortexResult; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for VarBinArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for VarBinArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.is_valid(index) { + self.bytes_at(index).map(|bytes| { + if matches!(self.dtype, DType::Utf8(_)) { + unsafe { String::from_utf8_unchecked(bytes) }.into() + } else { + bytes.into() + } + }) + } else { + Ok(NullableScalar::none(self.dtype.clone()).boxed()) + } + } +} diff --git a/vortex/src/array/varbin/mod.rs b/vortex/src/array/varbin/mod.rs index beab3d5dab..45a112b23a 100644 --- a/vortex/src/array/varbin/mod.rs +++ b/vortex/src/array/varbin/mod.rs @@ -12,21 +12,22 @@ use crate::array::bool::BoolArray; use crate::array::downcast::DowncastArrayBuiltin; use crate::array::primitive::PrimitiveArray; use crate::array::{ - check_index_bounds, check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, - Encoding, EncodingId, EncodingRef, ENCODINGS, + check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding, + EncodingId, EncodingRef, ENCODINGS, }; use crate::arrow::CombineChunks; use crate::compress::EncodingCompression; +use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, IntWidth, Nullability, Signedness}; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; use crate::match_each_native_ptype; use crate::ptype::NativePType; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; mod compress; +mod compute; mod serde; mod stats; @@ -92,8 +93,8 @@ impl VarBinArray { fn is_valid(&self, index: usize) -> bool { self.validity - .as_ref() - .map(|v| v.scalar_at(index).unwrap().try_into().unwrap()) + .as_deref() + .map(|v| scalar_at(v, index).unwrap().try_into().unwrap()) .unwrap_or(true) } @@ -179,7 +180,7 @@ impl VarBinArray { } pub fn bytes_at(&self, index: usize) -> VortexResult> { - check_index_bounds(self, index)?; + // check_index_bounds(self, index)?; let (start, end): (usize, usize) = if let Some(p) = self.offsets.maybe_primitive() { match_each_native_ptype!(p.ptype(), |$P| { @@ -188,8 +189,8 @@ impl VarBinArray { }) } else { ( - self.offsets().scalar_at(index)?.try_into()?, - self.offsets().scalar_at(index + 1)?.try_into()?, + scalar_at(self.offsets(), index)?.try_into()?, + scalar_at(self.offsets(), index + 1)?.try_into()?, ) }; let sliced = self.bytes().slice(start, end)?; @@ -234,20 +235,6 @@ impl Array for VarBinArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if self.is_valid(index) { - self.bytes_at(index).map(|bytes| { - if matches!(self.dtype, DType::Utf8(_)) { - unsafe { String::from_utf8_unchecked(bytes) }.into() - } else { - bytes.into() - } - }) - } else { - Ok(NullableScalar::none(self.dtype.clone()).boxed()) - } - } - fn iter_arrow(&self) -> Box { let offsets_data = self.offsets.iter_arrow().combine_chunks().into_data(); let bytes_data = self.bytes.iter_arrow().combine_chunks().into_data(); diff --git a/vortex/src/array/varbinview/compute.rs b/vortex/src/array/varbinview/compute.rs new file mode 100644 index 0000000000..ca169f2636 --- /dev/null +++ b/vortex/src/array/varbinview/compute.rs @@ -0,0 +1,28 @@ +use crate::array::varbinview::VarBinViewArray; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; +use crate::dtype::DType; +use crate::error::VortexResult; +use crate::scalar::{NullableScalar, Scalar}; + +impl ArrayCompute for VarBinViewArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} + +impl ScalarAtFn for VarBinViewArray { + fn scalar_at(&self, index: usize) -> VortexResult> { + if self.is_valid(index) { + self.bytes_at(index).map(|bytes| { + if matches!(self.dtype, DType::Utf8(_)) { + unsafe { String::from_utf8_unchecked(bytes) }.into() + } else { + bytes.into() + } + }) + } else { + Ok(NullableScalar::none(self.dtype.clone()).boxed()) + } + } +} diff --git a/vortex/src/array/varbinview/mod.rs b/vortex/src/array/varbinview/mod.rs index bfa7a4ac95..ee8838379f 100644 --- a/vortex/src/array/varbinview/mod.rs +++ b/vortex/src/array/varbinview/mod.rs @@ -1,4 +1,5 @@ mod compress; +mod compute; mod serde; use std::any::Any; @@ -17,10 +18,10 @@ use crate::array::{ }; use crate::arrow::CombineChunks; use crate::compress::EncodingCompression; +use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, IntWidth, Nullability, Signedness}; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::{NullableScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; @@ -145,8 +146,8 @@ impl VarBinViewArray { fn is_valid(&self, index: usize) -> bool { self.validity - .as_ref() - .map(|v| v.scalar_at(index).unwrap().try_into().unwrap()) + .as_deref() + .map(|v| scalar_at(v, index).unwrap().try_into().unwrap()) .unwrap_or(true) } @@ -246,20 +247,6 @@ impl Array for VarBinViewArray { Stats::new(&self.stats, self) } - fn scalar_at(&self, index: usize) -> VortexResult> { - if self.is_valid(index) { - self.bytes_at(index).map(|bytes| { - if matches!(self.dtype, DType::Utf8(_)) { - unsafe { String::from_utf8_unchecked(bytes) }.into() - } else { - bytes.into() - } - }) - } else { - Ok(NullableScalar::none(self.dtype.clone()).boxed()) - } - } - fn iter_arrow(&self) -> Box { let data_arr: ArrowArrayRef = if matches!(self.dtype, DType::Utf8(_)) { let mut data_buf = StringBuilder::with_capacity(self.len(), self.plain_size()); diff --git a/vortex/src/compute/mod.rs b/vortex/src/compute/mod.rs index ed4376bff1..f9273d7829 100644 --- a/vortex/src/compute/mod.rs +++ b/vortex/src/compute/mod.rs @@ -1,5 +1,4 @@ use crate::compute::scalar_at::ScalarAtFn; -use polars_arrow::scalar::Scalar; use take::TakeFn; pub mod add; diff --git a/vortex/src/compute/scalar_at.rs b/vortex/src/compute/scalar_at.rs index eae1ce3634..cf37161247 100644 --- a/vortex/src/compute/scalar_at.rs +++ b/vortex/src/compute/scalar_at.rs @@ -12,8 +12,7 @@ pub fn scalar_at(array: &dyn Array, index: usize) -> VortexResult VortexResult { - array - .compute() - .and_then(|c| c.take()) - .map(|t| t.take(indices)) - .unwrap_or_else(|| { - // TODO(ngates): default implementation of decode and then try again - Err(VortexError::ComputeError( - format!("take not implemented for {}", &array.encoding().id()).into(), - )) - }) + array.take().map(|t| t.take(indices)).unwrap_or_else(|| { + // TODO(ngates): default implementation of decode and then try again + Err(VortexError::ComputeError( + format!("take not implemented for {}", &array.encoding().id()).into(), + )) + }) } diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index fad5ab605f..0fbd8789b5 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(iterator_try_collect)] + pub mod array; pub mod arrow; pub mod scalar;