From 051627f1a0d14b1ae0c9780faa91aad5d52946ab Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 10 Jul 2024 22:17:55 +0100 Subject: [PATCH 1/2] Chunked take --- vortex-array/src/array/chunked/canonical.rs | 15 ++-- .../src/array/chunked/compute/take.rs | 64 ++++++++++++++- .../src/array/primitive/compute/cast.rs | 12 +-- .../src/array/primitive/compute/slice.rs | 15 ++-- .../primitive/compute/subtract_scalar.rs | 79 +------------------ vortex-array/src/array/primitive/mod.rs | 48 +++++------ vortex-array/src/arrow/array.rs | 35 +++----- vortex-buffer/src/lib.rs | 7 ++ 8 files changed, 131 insertions(+), 144 deletions(-) diff --git a/vortex-array/src/array/chunked/canonical.rs b/vortex-array/src/array/chunked/canonical.rs index 2d88d7fdd9..c42f361ee5 100644 --- a/vortex-array/src/array/chunked/canonical.rs +++ b/vortex-array/src/array/chunked/canonical.rs @@ -1,6 +1,6 @@ -use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer}; +use arrow_buffer::{BooleanBuffer, Buffer, MutableBuffer}; use itertools::Itertools; -use vortex_dtype::{match_each_native_ptype, DType, Nullability, PType, StructDType}; +use vortex_dtype::{DType, Nullability, PType, StructDType}; use vortex_error::{vortex_bail, ErrString, VortexResult}; use crate::accessor::ArrayAccessor; @@ -152,12 +152,11 @@ fn pack_primitives( buffer.extend_from_slice(chunk.buffer()); } - match_each_native_ptype!(ptype, |$T| { - Ok(PrimitiveArray::try_new( - ScalarBuffer::<$T>::from(buffer), - validity, - )?) - }) + Ok(PrimitiveArray::new( + Buffer::from(buffer).into(), + ptype, + validity, + )) } /// Builds a new [VarBinArray] by repacking the values from the chunks into a single diff --git a/vortex-array/src/array/chunked/compute/take.rs b/vortex-array/src/array/chunked/compute/take.rs index 9afd96649a..e0e7533df9 100644 --- a/vortex-array/src/array/chunked/compute/take.rs +++ b/vortex-array/src/array/chunked/compute/take.rs @@ -1,19 +1,35 @@ +use itertools::Itertools; use vortex_dtype::PType; use vortex_error::VortexResult; use crate::array::chunked::ChunkedArray; use crate::array::primitive::PrimitiveArray; +use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; +use crate::compute::slice::slice; use crate::compute::take::{take, TakeFn}; use crate::compute::unary::cast::try_cast; +use crate::compute::unary::scalar_at::scalar_at; +use crate::compute::unary::scalar_subtract::subtract_scalar; +use crate::stats::ArrayStatistics; use crate::ArrayDType; use crate::{Array, IntoArray, ToArray}; impl TakeFn for ChunkedArray { fn take(&self, indices: &Array) -> VortexResult { - if self.len() == indices.len() { - return Ok(self.to_array()); + // Fast path for strict sorted indices. + if indices + .statistics() + .compute_is_strict_sorted() + .unwrap_or(false) + { + if self.len() == indices.len() { + return Ok(self.to_array()); + } + + return take_strict_sorted(self, indices); } + // FIXME(ngates): this is wrong, need to canonicalise let indices = PrimitiveArray::try_from(try_cast(indices, PType::U64.into())?)?; // While the chunk idx remains the same, accumulate a list of chunk indices. @@ -51,6 +67,50 @@ impl TakeFn for ChunkedArray { } } +/// When the indices are non-null and strict-sorted, we can do better +fn take_strict_sorted(chunked: &ChunkedArray, indices: &Array) -> VortexResult { + let mut indices_by_chunk = vec![None; chunked.nchunks()]; + + // Track our position in the indices array + let mut pos = 0; + while pos < indices.len() { + // Locate the chunk index for the current index + let idx = usize::try_from(&scalar_at(indices, pos)?).unwrap(); + let (chunk_idx, _idx_in_chunk) = chunked.find_chunk_idx(idx); + + // Find the end of this chunk, and locate that position in the indices array. + let chunk_begin = usize::try_from(&scalar_at(&chunked.chunk_ends(), chunk_idx)?).unwrap(); + let chunk_end = usize::try_from(&scalar_at(&chunked.chunk_ends(), chunk_idx + 1)?).unwrap(); + let chunk_end_pos = search_sorted(indices, chunk_end, SearchSortedSide::Left) + .unwrap() + .to_index(); + + // Now we can say the slice of indices belonging to this chunk is [pos, chunk_end_pos) + let chunk_indices = slice(indices, pos, chunk_end_pos)?; + + // Adjust the indices so they're relative to the chunk + let chunk_indices = subtract_scalar(&chunk_indices, &chunk_begin.into())?; + indices_by_chunk[chunk_idx] = Some(chunk_indices); + + pos = chunk_end_pos; + } + + // Now we can take the chunks + let chunks = indices_by_chunk + .iter() + .enumerate() + .filter_map(|(chunk_idx, indices)| indices.as_ref().map(|i| (chunk_idx, i))) + .map(|(chunk_idx, chunk_indices)| { + take( + &chunked.chunk(chunk_idx).expect("chunk not found"), + chunk_indices, + ) + }) + .try_collect()?; + + Ok(ChunkedArray::try_new(chunks, chunked.dtype().clone())?.into_array()) +} + #[cfg(test)] mod test { use crate::array::chunked::ChunkedArray; diff --git a/vortex-array/src/array/primitive/compute/cast.rs b/vortex-array/src/array/primitive/compute/cast.rs index e1339571c1..906f295b64 100644 --- a/vortex-array/src/array/primitive/compute/cast.rs +++ b/vortex-array/src/array/primitive/compute/cast.rs @@ -14,12 +14,12 @@ impl CastFn for PrimitiveArray { // Short-cut if we can just change the nullability if self.ptype() == ptype && !self.dtype().is_nullable() && dtype.is_nullable() { - match_each_native_ptype!(self.ptype(), |$T| { - return Ok( - PrimitiveArray::try_new(self.scalar_buffer::<$T>(), Validity::AllValid)? - .into_array(), - ); - }) + return Ok(PrimitiveArray::new( + self.buffer().clone(), + self.ptype(), + Validity::AllValid, + ) + .into_array()); } // FIXME(ngates): #260 - check validity and nullability diff --git a/vortex-array/src/array/primitive/compute/slice.rs b/vortex-array/src/array/primitive/compute/slice.rs index 567b699d25..568d2b51af 100644 --- a/vortex-array/src/array/primitive/compute/slice.rs +++ b/vortex-array/src/array/primitive/compute/slice.rs @@ -1,4 +1,3 @@ -use vortex_dtype::match_each_native_ptype; use vortex_error::VortexResult; use crate::array::primitive::PrimitiveArray; @@ -8,12 +7,12 @@ use crate::IntoArray; impl SliceFn for PrimitiveArray { fn slice(&self, start: usize, stop: usize) -> VortexResult { - match_each_native_ptype!(self.ptype(), |$T| { - Ok(PrimitiveArray::try_new( - self.scalar_buffer::<$T>().slice(start, stop - start), - self.validity().slice(start, stop)?, - )? - .into_array()) - }) + assert!(start <= stop, "start must be <= stop"); + let byte_width = self.ptype().byte_width(); + let buffer = self.buffer().slice(start * byte_width..stop * byte_width); + Ok( + PrimitiveArray::new(buffer, self.ptype(), self.validity().slice(start, stop)?) + .into_array(), + ) } } diff --git a/vortex-array/src/array/primitive/compute/subtract_scalar.rs b/vortex-array/src/array/primitive/compute/subtract_scalar.rs index bfbce0e86d..871353699d 100644 --- a/vortex-array/src/array/primitive/compute/subtract_scalar.rs +++ b/vortex-array/src/array/primitive/compute/subtract_scalar.rs @@ -1,6 +1,5 @@ use itertools::Itertools; -use num_traits::ops::overflowing::OverflowingSub; -use num_traits::SaturatingSub; +use num_traits::WrappingSub; use vortex_dtype::{match_each_float_ptype, match_each_integer_ptype, NativePType}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_scalar::PrimitiveScalar; @@ -9,7 +8,6 @@ use vortex_scalar::Scalar; use crate::array::constant::ConstantArray; use crate::array::primitive::PrimitiveArray; use crate::compute::unary::scalar_subtract::SubtractScalarFn; -use crate::stats::{ArrayStatistics, Stat}; use crate::validity::ArrayValidity; use crate::{Array, ArrayDType, IntoArray}; @@ -50,7 +48,7 @@ impl SubtractScalarFn for PrimitiveArray { fn subtract_scalar_integer< 'a, - T: NativePType + OverflowingSub + SaturatingSub + for<'b> TryFrom<&'b Scalar, Error = VortexError>, + T: NativePType + WrappingSub + for<'b> TryFrom<&'b Scalar, Error = VortexError>, >( subtract_from: &PrimitiveArray, to_subtract: T, @@ -60,31 +58,12 @@ fn subtract_scalar_integer< return Ok(subtract_from.clone()); } - if let Some(min) = subtract_from.statistics().compute_as_cast::(Stat::Min) { - if let (_, true) = min.overflowing_sub(&to_subtract) { - vortex_bail!( - "Integer subtraction over/underflow: {}, {}", - min, - to_subtract - ) - } - } - if let Some(max) = subtract_from.statistics().compute_as_cast::(Stat::Max) { - if let (_, true) = max.overflowing_sub(&to_subtract) { - vortex_bail!( - "Integer subtraction over/underflow: {}, {}", - max, - to_subtract - ) - } - } - let contains_nulls = !subtract_from.logical_validity().all_valid(); let subtraction_result = if contains_nulls { let sub_vec = subtract_from .maybe_null_slice() .iter() - .map(|&v: &T| v.saturating_sub(&to_subtract)) + .map(|&v: &T| v.wrapping_sub(&to_subtract)) .collect_vec(); PrimitiveArray::from_vec(sub_vec, subtract_from.validity()) } else { @@ -102,7 +81,6 @@ fn subtract_scalar_integer< #[cfg(test)] mod test { use itertools::Itertools; - use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; use crate::compute::unary::scalar_subtract::subtract_scalar; @@ -148,7 +126,7 @@ mod test { .unwrap(); let results = flattened.maybe_null_slice::().to_vec(); - assert_eq!(results, &[0u16, 1, 0, 2]); + assert_eq!(results, &[0u16, 1, 65535, 2]); let valid_indices = flattened .validity() .to_logical(flattened.len()) @@ -175,55 +153,6 @@ mod test { assert_eq!(results, &[2.0f64, 3.0, 4.0]); } - #[test] - fn test_scalar_subtract_unsigned_underflow() { - let values = vec![u8::MIN, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &1u8.into()).expect_err("should fail with underflow"); - let values = vec![u16::MIN, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &1u16.into()).expect_err("should fail with underflow"); - let values = vec![u32::MIN, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &1u32.into()).expect_err("should fail with underflow"); - let values = vec![u64::MIN, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &1u64.into()).expect_err("should fail with underflow"); - } - - #[test] - fn test_scalar_subtract_signed_overflow() { - let values = vec![i8::MAX, 2, 3].into_array(); - let to_subtract: Scalar = (-1i8).into(); - let _results = - subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow"); - let values = vec![i16::MAX, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow"); - let values = vec![i32::MAX, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow"); - let values = vec![i64::MAX, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow"); - } - - #[test] - fn test_scalar_subtract_signed_underflow() { - let values = vec![i8::MIN, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &1i8.into()).expect_err("should fail with underflow"); - let values = vec![i16::MIN, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &1i16.into()).expect_err("should fail with underflow"); - let values = vec![i32::MIN, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &1i32.into()).expect_err("should fail with underflow"); - let values = vec![i64::MIN, 2, 3].into_array(); - let _results = - subtract_scalar(&values, &1i64.into()).expect_err("should fail with underflow"); - } - #[test] fn test_scalar_subtract_float_underflow_is_ok() { let values = vec![f32::MIN, 2.0, 3.0].into_array(); diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index ca03f19f69..01117a18de 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -1,4 +1,4 @@ -use arrow_buffer::{ArrowNativeType, ScalarBuffer}; +use arrow_buffer::{ArrowNativeType, Buffer as ArrowBuffer, MutableBuffer, ScalarBuffer}; use itertools::Itertools; use num_traits::AsPrimitive; use serde::{Deserialize, Serialize}; @@ -23,30 +23,38 @@ pub struct PrimitiveMetadata { } impl PrimitiveArray { - // TODO(ngates): remove the Arrow types from this API. - pub fn try_new( - buffer: ScalarBuffer, - validity: Validity, - ) -> VortexResult { - Ok(Self { + pub fn new(buffer: Buffer, ptype: PType, validity: Validity) -> Self { + let length = match_each_native_ptype!(ptype, |$P| { + let (prefix, values, suffix) = unsafe { buffer.align_to::<$P>() }; + assert!( + prefix.is_empty() && suffix.is_empty(), + "buffer is not aligned" + ); + values.len() + }); + + Self { typed: TypedArray::try_from_parts( - DType::from(T::PTYPE).with_nullability(validity.nullability()), - buffer.len(), + DType::from(ptype).with_nullability(validity.nullability()), + length, PrimitiveMetadata { - validity: validity.to_metadata(buffer.len())?, + validity: validity.to_metadata(length).expect("invalid validity"), }, - Some(Buffer::from(buffer.into_inner())), + Some(buffer), validity.into_array().into_iter().collect_vec().into(), StatsSet::new(), - )?, - }) + ) + .expect("should be valid"), + } } pub fn from_vec(values: Vec, validity: Validity) -> Self { match_each_native_ptype!(T::PTYPE, |$P| { - Self::try_new(ScalarBuffer::<$P>::from( - unsafe { std::mem::transmute::, Vec<$P>>(values) } - ), validity).unwrap() + PrimitiveArray::new( + ArrowBuffer::from(MutableBuffer::from(unsafe { std::mem::transmute::, Vec<$P>>(values) })).into(), + T::PTYPE, + validity, + ) }) } @@ -131,13 +139,7 @@ impl PrimitiveArray { "can't reinterpret cast between integers of two different widths" ); - match_each_native_ptype!(ptype, |$P| { - PrimitiveArray::try_new( - ScalarBuffer::<$P>::new(self.buffer().clone().into(), 0, self.len()), - self.validity(), - ) - .unwrap() - }) + PrimitiveArray::new(self.buffer().clone(), ptype, self.validity()) } pub fn patch, T: NativePType + ArrowNativeType>( diff --git a/vortex-array/src/arrow/array.rs b/vortex-array/src/arrow/array.rs index 47bb9f5f41..a1eff32aa4 100644 --- a/vortex-array/src/arrow/array.rs +++ b/vortex-array/src/arrow/array.rs @@ -20,8 +20,8 @@ use arrow_buffer::buffer::{NullBuffer, OffsetBuffer}; use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer}; use arrow_schema::{DataType, TimeUnit}; use itertools::Itertools; -use vortex_dtype::DType; use vortex_dtype::NativePType; +use vortex_dtype::{DType, PType}; use crate::array::bool::BoolArray; use crate::array::datetime::LocalDateTimeArray; @@ -37,13 +37,7 @@ use crate::{ArrayData, IntoArray, IntoArrayData}; impl IntoArrayData for Buffer { fn into_array_data(self) -> ArrayData { - let length = self.len(); - PrimitiveArray::try_new( - ScalarBuffer::::new(self, 0, length), - Validity::NonNullable, - ) - .unwrap() - .into_array_data() + PrimitiveArray::new(self.into(), PType::U8, Validity::NonNullable).into_array_data() } } @@ -57,24 +51,18 @@ impl IntoArrayData for NullBuffer { impl IntoArrayData for ScalarBuffer { fn into_array_data(self) -> ArrayData { - let length = self.len(); - PrimitiveArray::try_new( - Self::new(self.into_inner(), 0, length), - Validity::NonNullable, - ) - .unwrap() - .into_array_data() + PrimitiveArray::new(self.into_inner().into(), T::PTYPE, Validity::NonNullable) + .into_array_data() } } impl IntoArrayData for OffsetBuffer { fn into_array_data(self) -> ArrayData { - let length = self.len(); - let array = PrimitiveArray::try_new( - ScalarBuffer::::new(self.into_inner().into_inner(), 0, length), + let array = PrimitiveArray::new( + self.into_inner().into_inner().into(), + O::PTYPE, Validity::NonNullable, ) - .unwrap() .into_array_data(); array.set(Stat::IsSorted, true.into()); array.set(Stat::IsStrictSorted, true.into()); @@ -87,9 +75,12 @@ where ::Native: NativePType, { fn from_arrow(value: &ArrowPrimitiveArray, nullable: bool) -> Self { - let arr = PrimitiveArray::try_new(value.values().clone(), nulls(value.nulls(), nullable)) - .unwrap() - .into_array_data(); + let arr = PrimitiveArray::new( + value.values().clone().into_inner().into(), + T::Native::PTYPE, + nulls(value.nulls(), nullable), + ) + .into_array_data(); if T::DATA_TYPE.is_numeric() { return arr; diff --git a/vortex-buffer/src/lib.rs b/vortex-buffer/src/lib.rs index 7dcddd4bdf..34ff9ef258 100644 --- a/vortex-buffer/src/lib.rs +++ b/vortex-buffer/src/lib.rs @@ -56,6 +56,13 @@ impl Buffer { Self::Bytes(_) => Err(self), } } + + pub fn from_vec(values: Vec) -> Self + where + T: ArrowNativeType, + { + Self::Arrow(ArrowBuffer::from_vec(values)) + } } impl Deref for Buffer { From 012f8bf9552cb175e7f6e343ff101cb0ee937af6 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 10 Jul 2024 22:51:29 +0100 Subject: [PATCH 2/2] merge --- vortex-buffer/src/lib.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vortex-buffer/src/lib.rs b/vortex-buffer/src/lib.rs index 34ff9ef258..7dcddd4bdf 100644 --- a/vortex-buffer/src/lib.rs +++ b/vortex-buffer/src/lib.rs @@ -56,13 +56,6 @@ impl Buffer { Self::Bytes(_) => Err(self), } } - - pub fn from_vec(values: Vec) -> Self - where - T: ArrowNativeType, - { - Self::Arrow(ArrowBuffer::from_vec(values)) - } } impl Deref for Buffer {