
Chunked take #447

Merged 2 commits on Jul 10, 2024
Changes from 1 commit
15 changes: 7 additions & 8 deletions vortex-array/src/array/chunked/canonical.rs
@@ -1,6 +1,6 @@
use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
use arrow_buffer::{BooleanBuffer, Buffer, MutableBuffer};
use itertools::Itertools;
use vortex_dtype::{match_each_native_ptype, DType, Nullability, PType, StructDType};
use vortex_dtype::{DType, Nullability, PType, StructDType};
use vortex_error::{vortex_bail, ErrString, VortexResult};

use crate::accessor::ArrayAccessor;
@@ -152,12 +152,11 @@ fn pack_primitives(
        buffer.extend_from_slice(chunk.buffer());
    }

    match_each_native_ptype!(ptype, |$T| {
        Ok(PrimitiveArray::try_new(
            ScalarBuffer::<$T>::from(buffer),
            validity,
        )?)
    })
    Ok(PrimitiveArray::new(
        Buffer::from(buffer).into(),
        ptype,
        validity,
    ))
}

/// Builds a new [VarBinArray] by repacking the values from the chunks into a single
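For context, the new `pack_primitives` path concatenates the chunks' raw bytes and constructs the array with an explicit `PType`, instead of monomorphizing over `ScalarBuffer<T>`. A minimal standalone sketch of the byte-level packing using only `arrow_buffer` (the chunk values here are hypothetical, and the final `PrimitiveArray::new` call is omitted):

```rust
use arrow_buffer::{Buffer, MutableBuffer};

fn main() {
    // Two hypothetical chunks of little-endian u32 values.
    let chunks: Vec<Buffer> = vec![
        Buffer::from_slice_ref([1u32, 2, 3]),
        Buffer::from_slice_ref([4u32, 5]),
    ];

    // Concatenate the raw bytes into a single buffer, as pack_primitives does.
    let mut buffer = MutableBuffer::with_capacity(chunks.iter().map(|c| c.len()).sum());
    for chunk in &chunks {
        buffer.extend_from_slice(chunk.as_slice());
    }

    // 5 values * 4 bytes each; the ptype (U32) travels alongside the bytes
    // rather than in the buffer's type parameter.
    let packed = Buffer::from(buffer);
    assert_eq!(packed.len(), 5 * std::mem::size_of::<u32>());
}
```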
64 changes: 62 additions & 2 deletions vortex-array/src/array/chunked/compute/take.rs
@@ -1,19 +1,35 @@
use itertools::Itertools;
use vortex_dtype::PType;
use vortex_error::VortexResult;

use crate::array::chunked::ChunkedArray;
use crate::array::primitive::PrimitiveArray;
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::slice::slice;
use crate::compute::take::{take, TakeFn};
use crate::compute::unary::cast::try_cast;
use crate::compute::unary::scalar_at::scalar_at;
use crate::compute::unary::scalar_subtract::subtract_scalar;
use crate::stats::ArrayStatistics;
use crate::ArrayDType;
use crate::{Array, IntoArray, ToArray};

impl TakeFn for ChunkedArray {
    fn take(&self, indices: &Array) -> VortexResult<Array> {
        if self.len() == indices.len() {
            return Ok(self.to_array());
        // Fast path for strict sorted indices.
        if indices
            .statistics()
            .compute_is_strict_sorted()
            .unwrap_or(false)
        {
            if self.len() == indices.len() {
                return Ok(self.to_array());
            }

            return take_strict_sorted(self, indices);
        }

        // FIXME(ngates): this is wrong, need to canonicalise
        let indices = PrimitiveArray::try_from(try_cast(indices, PType::U64.into())?)?;

        // While the chunk idx remains the same, accumulate a list of chunk indices.
@@ -51,6 +67,50 @@ impl TakeFn for ChunkedArray {
    }
}

/// When the indices are non-null and strict-sorted, we can do better: each chunk
/// is handed one contiguous slice of the indices instead of resolving them one by one.
fn take_strict_sorted(chunked: &ChunkedArray, indices: &Array) -> VortexResult<Array> {
    let mut indices_by_chunk = vec![None; chunked.nchunks()];

    // Track our position in the indices array
    let mut pos = 0;
    while pos < indices.len() {
        // Locate the chunk index for the current index
        let idx = usize::try_from(&scalar_at(indices, pos)?).unwrap();
        let (chunk_idx, _idx_in_chunk) = chunked.find_chunk_idx(idx);

        // Find the end of this chunk, and locate that position in the indices array.
        let chunk_begin = usize::try_from(&scalar_at(&chunked.chunk_ends(), chunk_idx)?).unwrap();
        let chunk_end = usize::try_from(&scalar_at(&chunked.chunk_ends(), chunk_idx + 1)?).unwrap();
        let chunk_end_pos = search_sorted(indices, chunk_end, SearchSortedSide::Left)
            .unwrap()
            .to_index();

        // Now we can say the slice of indices belonging to this chunk is [pos, chunk_end_pos)
        let chunk_indices = slice(indices, pos, chunk_end_pos)?;

        // Adjust the indices so they're relative to the chunk
        let chunk_indices = subtract_scalar(&chunk_indices, &chunk_begin.into())?;
        indices_by_chunk[chunk_idx] = Some(chunk_indices);

        pos = chunk_end_pos;
    }

    // Now we can take the chunks
    let chunks = indices_by_chunk
        .iter()
        .enumerate()
        .filter_map(|(chunk_idx, indices)| indices.as_ref().map(|i| (chunk_idx, i)))
        .map(|(chunk_idx, chunk_indices)| {
            take(
                &chunked.chunk(chunk_idx).expect("chunk not found"),
                chunk_indices,
            )
        })
        .try_collect()?;

    Ok(ChunkedArray::try_new(chunks, chunked.dtype().clone())?.into_array())
}

#[cfg(test)]
mod test {
    use crate::array::chunked::ChunkedArray;
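To make the partitioning loop in `take_strict_sorted` concrete, here is a self-contained sketch over plain vectors. The chunk layout (lengths [3, 3], so `chunk_ends = [0, 3, 6]`) and the index values are made up, and `partition_point` stands in for both `find_chunk_idx` and `search_sorted`:

```rust
/// Split strictly-sorted indices into per-chunk index lists, rebased to each chunk.
fn partition_sorted_indices(indices: &[u64], chunk_ends: &[u64]) -> Vec<Option<Vec<u64>>> {
    let mut by_chunk = vec![None; chunk_ends.len() - 1];
    let mut pos = 0;
    while pos < indices.len() {
        // Locate the chunk holding indices[pos]: the first chunk whose end exceeds it.
        let chunk_idx = chunk_ends[1..].partition_point(|&end| end <= indices[pos]);
        let (chunk_begin, chunk_end) = (chunk_ends[chunk_idx], chunk_ends[chunk_idx + 1]);

        // First position whose index falls past this chunk (search_sorted, Left side).
        let chunk_end_pos = pos + indices[pos..].partition_point(|&i| i < chunk_end);

        // The slice [pos, chunk_end_pos) belongs to this chunk; rebase it.
        by_chunk[chunk_idx] =
            Some(indices[pos..chunk_end_pos].iter().map(|i| i - chunk_begin).collect());
        pos = chunk_end_pos;
    }
    by_chunk
}

fn main() {
    // Indices [0, 2, 4, 5] over chunks of length 3 and 3:
    // chunk 0 takes [0, 2]; chunk 1 takes [4, 5], rebased to [1, 2].
    assert_eq!(
        partition_sorted_indices(&[0, 2, 4, 5], &[0, 3, 6]),
        vec![Some(vec![0, 2]), Some(vec![1, 2])]
    );
}
```

Each chunk then answers a single `take` over its contiguous slice, rather than resolving every index individually.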
12 changes: 6 additions & 6 deletions vortex-array/src/array/primitive/compute/cast.rs
@@ -14,12 +14,12 @@ impl CastFn for PrimitiveArray {

        // Short-cut if we can just change the nullability
        if self.ptype() == ptype && !self.dtype().is_nullable() && dtype.is_nullable() {
            match_each_native_ptype!(self.ptype(), |$T| {
                return Ok(
                    PrimitiveArray::try_new(self.scalar_buffer::<$T>(), Validity::AllValid)?
                        .into_array(),
                );
            })
            return Ok(PrimitiveArray::new(
                self.buffer().clone(),
                self.ptype(),
                Validity::AllValid,
            )
            .into_array());
        }

        // FIXME(ngates): #260 - check validity and nullability
15 changes: 7 additions & 8 deletions vortex-array/src/array/primitive/compute/slice.rs
@@ -1,4 +1,3 @@
use vortex_dtype::match_each_native_ptype;
use vortex_error::VortexResult;

use crate::array::primitive::PrimitiveArray;
@@ -8,12 +7,12 @@ use crate::IntoArray;

impl SliceFn for PrimitiveArray {
    fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
        match_each_native_ptype!(self.ptype(), |$T| {
            Ok(PrimitiveArray::try_new(
                self.scalar_buffer::<$T>().slice(start, stop - start),
                self.validity().slice(start, stop)?,
            )?
            .into_array())
        })
        assert!(start <= stop, "start must be <= stop");
        let byte_width = self.ptype().byte_width();
        let buffer = self.buffer().slice(start * byte_width..stop * byte_width);
        Ok(
            PrimitiveArray::new(buffer, self.ptype(), self.validity().slice(start, stop)?)
                .into_array(),
        )
    }
}
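The new `slice` maps element positions to a byte range via the ptype's byte width, so no monomorphized `ScalarBuffer` is needed. A small sketch of that mapping under an assumed `u32` ptype (byte width 4), using plain std types:

```rust
fn main() {
    let byte_width = std::mem::size_of::<u32>();
    // Four u32 values serialized little-endian into raw bytes.
    let data: Vec<u8> = [1u32, 2, 3, 4].iter().flat_map(|v| v.to_le_bytes()).collect();

    // Slicing elements [1, 3) means slicing bytes [4, 12).
    let (start, stop) = (1usize, 3usize);
    let sliced = &data[start * byte_width..stop * byte_width];
    assert_eq!(sliced.len(), (stop - start) * byte_width);
    // The first element of the byte slice is the original value 2.
    assert_eq!(sliced[..byte_width], 2u32.to_le_bytes());
}
```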
79 changes: 4 additions & 75 deletions vortex-array/src/array/primitive/compute/subtract_scalar.rs
@@ -1,6 +1,5 @@
use itertools::Itertools;
use num_traits::ops::overflowing::OverflowingSub;
use num_traits::SaturatingSub;
use num_traits::WrappingSub;
use vortex_dtype::{match_each_float_ptype, match_each_integer_ptype, NativePType};
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_scalar::PrimitiveScalar;
@@ -9,7 +8,6 @@ use vortex_scalar::Scalar;
use crate::array::constant::ConstantArray;
use crate::array::primitive::PrimitiveArray;
use crate::compute::unary::scalar_subtract::SubtractScalarFn;
use crate::stats::{ArrayStatistics, Stat};
use crate::validity::ArrayValidity;
use crate::{Array, ArrayDType, IntoArray};

@@ -50,7 +48,7 @@ impl SubtractScalarFn for PrimitiveArray {

fn subtract_scalar_integer<
    'a,
    T: NativePType + OverflowingSub + SaturatingSub + for<'b> TryFrom<&'b Scalar, Error = VortexError>,
    T: NativePType + WrappingSub + for<'b> TryFrom<&'b Scalar, Error = VortexError>,
>(
    subtract_from: &PrimitiveArray,
    to_subtract: T,
@@ -60,31 +58,12 @@
        return Ok(subtract_from.clone());
    }

    if let Some(min) = subtract_from.statistics().compute_as_cast::<T>(Stat::Min) {
        if let (_, true) = min.overflowing_sub(&to_subtract) {
            vortex_bail!(
                "Integer subtraction over/underflow: {}, {}",
                min,
                to_subtract
            )
        }
    }
    if let Some(max) = subtract_from.statistics().compute_as_cast::<T>(Stat::Max) {
        if let (_, true) = max.overflowing_sub(&to_subtract) {
            vortex_bail!(
                "Integer subtraction over/underflow: {}, {}",
                max,
                to_subtract
            )
        }
    }

    let contains_nulls = !subtract_from.logical_validity().all_valid();
    let subtraction_result = if contains_nulls {
        let sub_vec = subtract_from
            .maybe_null_slice()
            .iter()
            .map(|&v: &T| v.saturating_sub(&to_subtract))
            .map(|&v: &T| v.wrapping_sub(&to_subtract))
> Review comment (Contributor): why are we changing the behavior of this, and removing the assertions?
>
> Reply (Author): The assertions are expensive since they require computing stats.
> We could either just use checked subtraction and allow this to panic, or we pick between saturating and wrapping.
>
> Tbh, I've never seen unchecked subtraction do anything other than wrap, so it seemed like the better choice. e.g. https://arrow.apache.org/docs/python/generated/pyarrow.compute.subtract.html
            .collect_vec();
        PrimitiveArray::from_vec(sub_vec, subtract_from.validity())
    } else {
@@ -102,7 +81,6 @@ fn subtract_scalar_integer<
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_scalar::Scalar;

    use crate::array::primitive::PrimitiveArray;
    use crate::compute::unary::scalar_subtract::subtract_scalar;
@@ -148,7 +126,7 @@ mod test {
            .unwrap();

        let results = flattened.maybe_null_slice::<u16>().to_vec();
        assert_eq!(results, &[0u16, 1, 0, 2]);
        assert_eq!(results, &[0u16, 1, 65535, 2]);
        let valid_indices = flattened
            .validity()
            .to_logical(flattened.len())
@@ -175,55 +153,6 @@
        assert_eq!(results, &[2.0f64, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_subtract_unsigned_underflow() {
        let values = vec![u8::MIN, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &1u8.into()).expect_err("should fail with underflow");
        let values = vec![u16::MIN, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &1u16.into()).expect_err("should fail with underflow");
        let values = vec![u32::MIN, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &1u32.into()).expect_err("should fail with underflow");
        let values = vec![u64::MIN, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &1u64.into()).expect_err("should fail with underflow");
    }

    #[test]
    fn test_scalar_subtract_signed_overflow() {
        let values = vec![i8::MAX, 2, 3].into_array();
        let to_subtract: Scalar = (-1i8).into();
        let _results =
            subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow");
        let values = vec![i16::MAX, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow");
        let values = vec![i32::MAX, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow");
        let values = vec![i64::MAX, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &to_subtract).expect_err("should fail with overflow");
    }

    #[test]
    fn test_scalar_subtract_signed_underflow() {
        let values = vec![i8::MIN, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &1i8.into()).expect_err("should fail with underflow");
        let values = vec![i16::MIN, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &1i16.into()).expect_err("should fail with underflow");
        let values = vec![i32::MIN, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &1i32.into()).expect_err("should fail with underflow");
        let values = vec![i64::MIN, 2, 3].into_array();
        let _results =
            subtract_scalar(&values, &1i64.into()).expect_err("should fail with underflow");
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let values = vec![f32::MIN, 2.0, 3.0].into_array();
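Per the review thread above, the PR trades the stats-backed overflow checks (and the over/underflow tests deleted here) for wrapping subtraction, which is what the updated test asserts: `1u16 - 2 == 65535`. For reference, these are plain std methods, not Vortex APIs:

```rust
fn main() {
    // Unsigned subtraction wraps modulo 2^N instead of erroring.
    assert_eq!(1u16.wrapping_sub(2), 65535);
    assert_eq!(0u8.wrapping_sub(1), 255);
    // Signed types wrap across the type's boundaries the same way.
    assert_eq!(i8::MIN.wrapping_sub(1), i8::MAX);
}
```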
48 changes: 25 additions & 23 deletions vortex-array/src/array/primitive/mod.rs
@@ -1,4 +1,4 @@
use arrow_buffer::{ArrowNativeType, ScalarBuffer};
use arrow_buffer::{ArrowNativeType, Buffer as ArrowBuffer, MutableBuffer, ScalarBuffer};
use itertools::Itertools;
use num_traits::AsPrimitive;
use serde::{Deserialize, Serialize};
@@ -23,30 +23,38 @@ pub struct PrimitiveMetadata {
}

impl PrimitiveArray {
    // TODO(ngates): remove the Arrow types from this API.
    pub fn try_new<T: NativePType + ArrowNativeType>(
        buffer: ScalarBuffer<T>,
        validity: Validity,
    ) -> VortexResult<Self> {
        Ok(Self {
    pub fn new(buffer: Buffer, ptype: PType, validity: Validity) -> Self {
        let length = match_each_native_ptype!(ptype, |$P| {
            let (prefix, values, suffix) = unsafe { buffer.align_to::<$P>() };
            assert!(
                prefix.is_empty() && suffix.is_empty(),
                "buffer is not aligned"
            );
            values.len()
        });

        Self {
            typed: TypedArray::try_from_parts(
                DType::from(T::PTYPE).with_nullability(validity.nullability()),
                buffer.len(),
                DType::from(ptype).with_nullability(validity.nullability()),
                length,
                PrimitiveMetadata {
                    validity: validity.to_metadata(buffer.len())?,
                    validity: validity.to_metadata(length).expect("invalid validity"),
                },
                Some(Buffer::from(buffer.into_inner())),
                Some(buffer),
                validity.into_array().into_iter().collect_vec().into(),
                StatsSet::new(),
            )?,
        })
            )
            .expect("should be valid"),
        }
    }

    pub fn from_vec<T: NativePType>(values: Vec<T>, validity: Validity) -> Self {
        match_each_native_ptype!(T::PTYPE, |$P| {
            Self::try_new(ScalarBuffer::<$P>::from(
                unsafe { std::mem::transmute::<Vec<T>, Vec<$P>>(values) }
            ), validity).unwrap()
            PrimitiveArray::new(
                ArrowBuffer::from(MutableBuffer::from(unsafe { std::mem::transmute::<Vec<T>, Vec<$P>>(values) })).into(),
                T::PTYPE,
                validity,
            )
        })
    }

@@ -131,13 +139,7 @@ impl PrimitiveArray {
            "can't reinterpret cast between integers of two different widths"
        );

        match_each_native_ptype!(ptype, |$P| {
            PrimitiveArray::try_new(
                ScalarBuffer::<$P>::new(self.buffer().clone().into(), 0, self.len()),
                self.validity(),
            )
            .unwrap()
        })
        PrimitiveArray::new(self.buffer().clone(), ptype, self.validity())
    }

    pub fn patch<P: AsPrimitive<usize>, T: NativePType + ArrowNativeType>(
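The new `PrimitiveArray::new` derives the logical length by reinterpreting the raw buffer as the ptype's native type, asserting alignment along the way. A standalone sketch of that check, with a plain byte slice standing in for the Vortex `Buffer` and a made-up helper name:

```rust
/// Reinterpret raw bytes as `T`, panicking unless the slice is aligned and an
/// exact multiple of `size_of::<T>()` — mirroring the align_to check above.
fn checked_len<T>(bytes: &[u8]) -> usize {
    let (prefix, values, suffix) = unsafe { bytes.align_to::<T>() };
    assert!(
        prefix.is_empty() && suffix.is_empty(),
        "buffer is not aligned"
    );
    values.len()
}

fn main() {
    // Backing the bytes with a Vec<u32> guarantees 4-byte alignment.
    let values = vec![1u32, 2, 3, 4];
    let bytes = unsafe {
        std::slice::from_raw_parts(
            values.as_ptr() as *const u8,
            values.len() * std::mem::size_of::<u32>(),
        )
    };
    assert_eq!(checked_len::<u32>(bytes), 4);
}
```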