Skip to content

Commit

Permalink
Chunked take (#447)
Browse files Browse the repository at this point in the history
* Improves the performance of ChunkedArray::take for strict-sorted
indices.
* Makes integer scalar subtraction wrapping instead of performing an expensive
statistics-based overflow/underflow check on the array's min/max
* Fixes PrimitiveArray slice to be zero-copy
  • Loading branch information
gatesn authored Jul 10, 2024
1 parent 9419059 commit dba7012
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 144 deletions.
15 changes: 7 additions & 8 deletions vortex-array/src/array/chunked/canonical.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
use arrow_buffer::{BooleanBuffer, Buffer, MutableBuffer};
use itertools::Itertools;
use vortex_dtype::{match_each_native_ptype, DType, Nullability, PType, StructDType};
use vortex_dtype::{DType, Nullability, PType, StructDType};
use vortex_error::{vortex_bail, ErrString, VortexResult};

use crate::accessor::ArrayAccessor;
Expand Down Expand Up @@ -152,12 +152,11 @@ fn pack_primitives(
buffer.extend_from_slice(chunk.buffer());
}

match_each_native_ptype!(ptype, |$T| {
Ok(PrimitiveArray::try_new(
ScalarBuffer::<$T>::from(buffer),
validity,
)?)
})
Ok(PrimitiveArray::new(
Buffer::from(buffer).into(),
ptype,
validity,
))
}

/// Builds a new [VarBinArray] by repacking the values from the chunks into a single
Expand Down
64 changes: 62 additions & 2 deletions vortex-array/src/array/chunked/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,35 @@
use itertools::Itertools;
use vortex_dtype::PType;
use vortex_error::VortexResult;

use crate::array::chunked::ChunkedArray;
use crate::array::primitive::PrimitiveArray;
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::slice::slice;
use crate::compute::take::{take, TakeFn};
use crate::compute::unary::cast::try_cast;
use crate::compute::unary::scalar_at::scalar_at;
use crate::compute::unary::scalar_subtract::subtract_scalar;
use crate::stats::ArrayStatistics;
use crate::ArrayDType;
use crate::{Array, IntoArray, ToArray};

impl TakeFn for ChunkedArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
if self.len() == indices.len() {
return Ok(self.to_array());
// Fast path for strict sorted indices.
if indices
.statistics()
.compute_is_strict_sorted()
.unwrap_or(false)
{
if self.len() == indices.len() {
return Ok(self.to_array());
}

return take_strict_sorted(self, indices);
}

// FIXME(ngates): this is wrong, need to canonicalise
let indices = PrimitiveArray::try_from(try_cast(indices, PType::U64.into())?)?;

// While the chunk idx remains the same, accumulate a list of chunk indices.
Expand Down Expand Up @@ -51,6 +67,50 @@ impl TakeFn for ChunkedArray {
}
}

/// When the indices are non-null and strict-sorted, we can do better
///
/// Strategy: because the indices are strictly increasing, all indices that land in a
/// given chunk form one contiguous run. We walk the indices once, slice out each
/// per-chunk run, rebase it to be chunk-relative, take from that chunk, and finally
/// reassemble the per-chunk results into a new ChunkedArray.
fn take_strict_sorted(chunked: &ChunkedArray, indices: &Array) -> VortexResult<Array> {
    // One optional slot per chunk; stays None for chunks no index falls into.
    let mut indices_by_chunk = vec![None; chunked.nchunks()];

    // Track our position in the indices array
    let mut pos = 0;
    while pos < indices.len() {
        // Locate the chunk index for the current index
        let idx = usize::try_from(&scalar_at(indices, pos)?).unwrap();
        let (chunk_idx, _idx_in_chunk) = chunked.find_chunk_idx(idx);

        // Find the end of this chunk, and locate that position in the indices array.
        // NOTE(review): chunk_ends() appears to hold cumulative row offsets, so
        // [chunk_ends[i], chunk_ends[i+1]) is chunk i's row range — confirm.
        let chunk_begin = usize::try_from(&scalar_at(&chunked.chunk_ends(), chunk_idx)?).unwrap();
        let chunk_end = usize::try_from(&scalar_at(&chunked.chunk_ends(), chunk_idx + 1)?).unwrap();
        // Binary-search for the first index >= chunk_end; everything in
        // [pos, chunk_end_pos) belongs to the current chunk.
        // NOTE(review): unwrap assumes search_sorted cannot fail here — confirm.
        let chunk_end_pos = search_sorted(indices, chunk_end, SearchSortedSide::Left)
            .unwrap()
            .to_index();

        // Now we can say the slice of indices belonging to this chunk is [pos, chunk_end_pos)
        let chunk_indices = slice(indices, pos, chunk_end_pos)?;

        // Adjust the indices so they're relative to the chunk
        let chunk_indices = subtract_scalar(&chunk_indices, &chunk_begin.into())?;
        // Strict sortedness means each chunk is visited at most once, so this
        // slot cannot already be occupied.
        indices_by_chunk[chunk_idx] = Some(chunk_indices);

        pos = chunk_end_pos;
    }

    // Now we can take the chunks: skip untouched chunks, take from the rest,
    // short-circuiting on the first error via try_collect.
    let chunks = indices_by_chunk
        .iter()
        .enumerate()
        .filter_map(|(chunk_idx, indices)| indices.as_ref().map(|i| (chunk_idx, i)))
        .map(|(chunk_idx, chunk_indices)| {
            take(
                &chunked.chunk(chunk_idx).expect("chunk not found"),
                chunk_indices,
            )
        })
        .try_collect()?;

    Ok(ChunkedArray::try_new(chunks, chunked.dtype().clone())?.into_array())
}

#[cfg(test)]
mod test {
use crate::array::chunked::ChunkedArray;
Expand Down
12 changes: 6 additions & 6 deletions vortex-array/src/array/primitive/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ impl CastFn for PrimitiveArray {

// Short-cut if we can just change the nullability
if self.ptype() == ptype && !self.dtype().is_nullable() && dtype.is_nullable() {
match_each_native_ptype!(self.ptype(), |$T| {
return Ok(
PrimitiveArray::try_new(self.scalar_buffer::<$T>(), Validity::AllValid)?
.into_array(),
);
})
return Ok(PrimitiveArray::new(
self.buffer().clone(),
self.ptype(),
Validity::AllValid,
)
.into_array());
}

// FIXME(ngates): #260 - check validity and nullability
Expand Down
15 changes: 7 additions & 8 deletions vortex-array/src/array/primitive/compute/slice.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use vortex_dtype::match_each_native_ptype;
use vortex_error::VortexResult;

use crate::array::primitive::PrimitiveArray;
Expand All @@ -8,12 +7,12 @@ use crate::IntoArray;

impl SliceFn for PrimitiveArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
match_each_native_ptype!(self.ptype(), |$T| {
Ok(PrimitiveArray::try_new(
self.scalar_buffer::<$T>().slice(start, stop - start),
self.validity().slice(start, stop)?,
)?
.into_array())
})
assert!(start <= stop, "start must be <= stop");
let byte_width = self.ptype().byte_width();
let buffer = self.buffer().slice(start * byte_width..stop * byte_width);
Ok(
PrimitiveArray::new(buffer, self.ptype(), self.validity().slice(start, stop)?)
.into_array(),
)
}
}
79 changes: 4 additions & 75 deletions vortex-array/src/array/primitive/compute/subtract_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use itertools::Itertools;
use num_traits::ops::overflowing::OverflowingSub;
use num_traits::SaturatingSub;
use num_traits::WrappingSub;
use vortex_dtype::{match_each_float_ptype, match_each_integer_ptype, NativePType};
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_scalar::PrimitiveScalar;
Expand All @@ -9,7 +8,6 @@ use vortex_scalar::Scalar;
use crate::array::constant::ConstantArray;
use crate::array::primitive::PrimitiveArray;
use crate::compute::unary::scalar_subtract::SubtractScalarFn;
use crate::stats::{ArrayStatistics, Stat};
use crate::validity::ArrayValidity;
use crate::{Array, ArrayDType, IntoArray};

Expand Down Expand Up @@ -50,7 +48,7 @@ impl SubtractScalarFn for PrimitiveArray {

fn subtract_scalar_integer<
'a,
T: NativePType + OverflowingSub + SaturatingSub + for<'b> TryFrom<&'b Scalar, Error = VortexError>,
T: NativePType + WrappingSub + for<'b> TryFrom<&'b Scalar, Error = VortexError>,
>(
subtract_from: &PrimitiveArray,
to_subtract: T,
Expand All @@ -60,31 +58,12 @@ fn subtract_scalar_integer<
return Ok(subtract_from.clone());
}

if let Some(min) = subtract_from.statistics().compute_as_cast::<T>(Stat::Min) {
if let (_, true) = min.overflowing_sub(&to_subtract) {
vortex_bail!(
"Integer subtraction over/underflow: {}, {}",
min,
to_subtract
)
}
}
if let Some(max) = subtract_from.statistics().compute_as_cast::<T>(Stat::Max) {
if let (_, true) = max.overflowing_sub(&to_subtract) {
vortex_bail!(
"Integer subtraction over/underflow: {}, {}",
max,
to_subtract
)
}
}

let contains_nulls = !subtract_from.logical_validity().all_valid();
let subtraction_result = if contains_nulls {
let sub_vec = subtract_from
.maybe_null_slice()
.iter()
.map(|&v: &T| v.saturating_sub(&to_subtract))
.map(|&v: &T| v.wrapping_sub(&to_subtract))
.collect_vec();
PrimitiveArray::from_vec(sub_vec, subtract_from.validity())
} else {
Expand All @@ -102,7 +81,6 @@ fn subtract_scalar_integer<
#[cfg(test)]
mod test {
use itertools::Itertools;
use vortex_scalar::Scalar;

use crate::array::primitive::PrimitiveArray;
use crate::compute::unary::scalar_subtract::subtract_scalar;
Expand Down Expand Up @@ -148,7 +126,7 @@ mod test {
.unwrap();

let results = flattened.maybe_null_slice::<u16>().to_vec();
assert_eq!(results, &[0u16, 1, 0, 2]);
assert_eq!(results, &[0u16, 1, 65535, 2]);
let valid_indices = flattened
.validity()
.to_logical(flattened.len())
Expand All @@ -175,55 +153,6 @@ mod test {
assert_eq!(results, &[2.0f64, 3.0, 4.0]);
}

#[test]
fn test_scalar_subtract_unsigned_underflow() {
let values = vec![u8::MIN, 2, 3].into_array();
let _results =
subtract_scalar(&values, &1u8.into()).expect_err("should fail with underflow");
let values = vec![u16::MIN, 2, 3].into_array();
let _results =
subtract_scalar(&values, &1u16.into()).expect_err("should fail with underflow");
let values = vec![u32::MIN, 2, 3].into_array();
let _results =
subtract_scalar(&values, &1u32.into()).expect_err("should fail with underflow");
let values = vec![u64::MIN, 2, 3].into_array();
let _results =
subtract_scalar(&values, &1u64.into()).expect_err("should fail with underflow");
}

#[test]
fn test_scalar_subtract_signed_overflow() {
let values = vec![i8::MAX, 2, 3].into_array();
let to_subtract: Scalar = (-1i8).into();
let _results =
subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow");
let values = vec![i16::MAX, 2, 3].into_array();
let _results =
subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow");
let values = vec![i32::MAX, 2, 3].into_array();
let _results =
subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow");
let values = vec![i64::MAX, 2, 3].into_array();
let _results =
subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow");
}

#[test]
fn test_scalar_subtract_signed_underflow() {
let values = vec![i8::MIN, 2, 3].into_array();
let _results =
subtract_scalar(&values, &1i8.into()).expect_err("should fail with underflow");
let values = vec![i16::MIN, 2, 3].into_array();
let _results =
subtract_scalar(&values, &1i16.into()).expect_err("should fail with underflow");
let values = vec![i32::MIN, 2, 3].into_array();
let _results =
subtract_scalar(&values, &1i32.into()).expect_err("should fail with underflow");
let values = vec![i64::MIN, 2, 3].into_array();
let _results =
subtract_scalar(&values, &1i64.into()).expect_err("should fail with underflow");
}

#[test]
fn test_scalar_subtract_float_underflow_is_ok() {
let values = vec![f32::MIN, 2.0, 3.0].into_array();
Expand Down
48 changes: 25 additions & 23 deletions vortex-array/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arrow_buffer::{ArrowNativeType, ScalarBuffer};
use arrow_buffer::{ArrowNativeType, Buffer as ArrowBuffer, MutableBuffer, ScalarBuffer};
use itertools::Itertools;
use num_traits::AsPrimitive;
use serde::{Deserialize, Serialize};
Expand All @@ -23,30 +23,38 @@ pub struct PrimitiveMetadata {
}

impl PrimitiveArray {
// TODO(ngates): remove the Arrow types from this API.
pub fn try_new<T: NativePType + ArrowNativeType>(
buffer: ScalarBuffer<T>,
validity: Validity,
) -> VortexResult<Self> {
Ok(Self {
pub fn new(buffer: Buffer, ptype: PType, validity: Validity) -> Self {
let length = match_each_native_ptype!(ptype, |$P| {
let (prefix, values, suffix) = unsafe { buffer.align_to::<$P>() };
assert!(
prefix.is_empty() && suffix.is_empty(),
"buffer is not aligned"
);
values.len()
});

Self {
typed: TypedArray::try_from_parts(
DType::from(T::PTYPE).with_nullability(validity.nullability()),
buffer.len(),
DType::from(ptype).with_nullability(validity.nullability()),
length,
PrimitiveMetadata {
validity: validity.to_metadata(buffer.len())?,
validity: validity.to_metadata(length).expect("invalid validity"),
},
Some(Buffer::from(buffer.into_inner())),
Some(buffer),
validity.into_array().into_iter().collect_vec().into(),
StatsSet::new(),
)?,
})
)
.expect("should be valid"),
}
}

pub fn from_vec<T: NativePType>(values: Vec<T>, validity: Validity) -> Self {
match_each_native_ptype!(T::PTYPE, |$P| {
Self::try_new(ScalarBuffer::<$P>::from(
unsafe { std::mem::transmute::<Vec<T>, Vec<$P>>(values) }
), validity).unwrap()
PrimitiveArray::new(
ArrowBuffer::from(MutableBuffer::from(unsafe { std::mem::transmute::<Vec<T>, Vec<$P>>(values) })).into(),
T::PTYPE,
validity,
)
})
}

Expand Down Expand Up @@ -131,13 +139,7 @@ impl PrimitiveArray {
"can't reinterpret cast between integers of two different widths"
);

match_each_native_ptype!(ptype, |$P| {
PrimitiveArray::try_new(
ScalarBuffer::<$P>::new(self.buffer().clone().into(), 0, self.len()),
self.validity(),
)
.unwrap()
})
PrimitiveArray::new(self.buffer().clone(), ptype, self.validity())
}

pub fn patch<P: AsPrimitive<usize>, T: NativePType + ArrowNativeType>(
Expand Down
Loading

0 comments on commit dba7012

Please sign in to comment.