From 55220c0c77f28c4a11ad0837da3fa88a40e0660d Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Thu, 5 Sep 2024 17:52:53 +0100 Subject: [PATCH] Unary and Binary functions trait (#726) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introducing specialized traits for unary (like map) and binary (like &) operations on arrays. This will allow us to have specialized and ergonomic implementations for cases when we know the encoding of an array (like in `ArrayCompute` functions) ## Benchmarks As of 08bbd8db running on my M3 Max macbook ``` arrow_unary_add time: [228.22 µs 233.98 µs 239.93 µs] vortex_unary_add time: [1.5322 µs 1.5469 µs 1.5604 µs] arrow_binary_add time: [237.74 µs 240.93 µs 243.98 µs] vortex_binary_add time: [93.821 µs 94.796 µs 95.918 µs] ``` --- vortex-array/Cargo.toml | 4 + vortex-array/benches/fn.rs | 67 +++++++++ vortex-array/src/array/primitive/mod.rs | 178 +++++++++++++++++++++++- vortex-array/src/elementwise.rs | 91 ++++++++++++ vortex-array/src/iter/mod.rs | 18 ++- vortex-array/src/lib.rs | 1 + vortex-array/src/validity.rs | 24 +++- vortex-array/src/variants.rs | 32 +++++ 8 files changed, 410 insertions(+), 5 deletions(-) create mode 100644 vortex-array/benches/fn.rs create mode 100644 vortex-array/src/elementwise.rs diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 8de78bb24b..037b975524 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -89,3 +89,7 @@ harness = false [[bench]] name = "iter" harness = false + +[[bench]] +name = "fn" +harness = false diff --git a/vortex-array/benches/fn.rs b/vortex-array/benches/fn.rs new file mode 100644 index 0000000000..d5bcc07091 --- /dev/null +++ b/vortex-array/benches/fn.rs @@ -0,0 +1,67 @@ +use arrow_array::types::UInt32Type; +use arrow_array::UInt32Array; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use vortex::array::PrimitiveArray; +use vortex::elementwise::{BinaryFn, UnaryFn}; +use vortex::validity::Validity; +use vortex::IntoArray; + +fn vortex_unary_add(c: &mut Criterion) { + let data = PrimitiveArray::from_vec((0_u32..1_000_000).collect::>(), Validity::AllValid); + c.bench_function("vortex_unary_add", |b| { + b.iter_batched( + || (data.clone()), + |data| data.unary(|v: u32| v + 1).unwrap(), + BatchSize::SmallInput, + ) + }); +} + +fn arrow_unary_add(c: &mut Criterion) { + let data = UInt32Array::from_iter_values(0_u32..1_000_000); + c.bench_function("arrow_unary_add", |b| { + b.iter_batched( + || data.clone(), + |data: arrow_array::PrimitiveArray| data.unary::<_, UInt32Type>(|v| v + 1), + BatchSize::SmallInput, + ) + }); +} + +fn vortex_binary_add(c: &mut Criterion) { + let lhs = PrimitiveArray::from_vec((0_u32..1_000_000).collect::>(), Validity::AllValid); + let rhs = PrimitiveArray::from_vec((0_u32..1_000_000).collect::>(), Validity::AllValid) + .into_array(); + c.bench_function("vortex_binary_add", |b| { + b.iter_batched( + || (lhs.clone(), rhs.clone()), + |(lhs, rhs)| lhs.binary(rhs, |l: u32, r: u32| l + r), + BatchSize::SmallInput, + ) + }); +} + +fn arrow_binary_add(c: &mut Criterion) { + let lhs = UInt32Array::from_iter_values(0_u32..1_000_000); + let rhs = UInt32Array::from_iter_values(0_u32..1_000_000); + c.bench_function("arrow_binary_add", |b| { + b.iter_batched( + || (lhs.clone(), rhs.clone()), + |(lhs, rhs)| { + arrow_arith::arity::binary::<_, _, _, UInt32Type>(&lhs, &rhs, |a, b| a + b).unwrap() + }, + BatchSize::SmallInput, + ) + }); +} + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = + arrow_unary_add, + vortex_unary_add, + arrow_binary_add, + vortex_binary_add, +); +criterion_main!(benches); diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index b801472c23..947064df08 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -1,3 +1,4 @@ +use std::mem::{transmute, MaybeUninit}; use std::ptr; use std::sync::Arc; @@ -10,7 +11,8 @@ use vortex_buffer::Buffer; use vortex_dtype::{match_each_native_ptype, DType, NativePType, PType}; use vortex_error::{vortex_bail, VortexResult}; -use crate::iter::{Accessor, AccessorRef}; +use crate::elementwise::{dyn_cast_array_iter, BinaryFn, UnaryFn}; +use crate::iter::{Accessor, AccessorRef, Batch, ITER_BATCH_SIZE}; use crate::stats::StatsSet; use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}; use crate::variants::{ArrayVariants, PrimitiveArrayTrait}; @@ -314,8 +316,141 @@ impl Array { } } +// This is an arbitrary value, tried a few seems like this is a better value than smaller ones, +// I assume there's some hardware dependency here but this seems to be good enough +const CHUNK_SIZE: usize = 1024; + +impl UnaryFn for PrimitiveArray { + fn unary O>( + &self, + unary_fn: F, + ) -> VortexResult { + let data = self.maybe_null_slice::(); + let mut output: Vec> = Vec::with_capacity(data.len()); + // Safety: we are going to apply the fn to every element and store it so the full length will be utilized + unsafe { output.set_len(data.len()) }; + + let chunks = data.chunks_exact(CHUNK_SIZE); + + // We start with the reminder because of ownership + let reminder_start_idx = data.len() - (data.len() % CHUNK_SIZE); + for (index, item) in chunks.remainder().iter().enumerate() { + // Safety: This access is bound by the same range as the output's capacity and length, so its within the Vec's allocated memory + unsafe { + *output.get_unchecked_mut(reminder_start_idx + index) = + MaybeUninit::new(unary_fn(*item)); + } + } + + let mut offset = 0; + + for chunk in chunks { + // We know the size of the chunk, and we know output is the same length as the input array + let chunk: [I; CHUNK_SIZE] = chunk.try_into()?; + let mut output_slice: [_; CHUNK_SIZE] = + output[offset..offset + CHUNK_SIZE].try_into()?; + + for idx in 0..CHUNK_SIZE { + output_slice[idx] = MaybeUninit::new(unary_fn(chunk[idx])); + } + + offset += CHUNK_SIZE; + } + + // Safety: `MaybeUninit` is a transparent struct and we know the actual length of the vec. + let output = unsafe { transmute::>, Vec>(output) }; + + Ok(PrimitiveArray::from_vec(output, self.validity()).into_array()) + } +} + +impl BinaryFn for PrimitiveArray { + fn binary O>( + &self, + rhs: Array, + binary_fn: F, + ) -> VortexResult { + if self.len() != rhs.len() { + vortex_bail!(InvalidArgument: "Both arguments to `binary` should be of the same length"); + } + if !self.dtype().eq_ignore_nullability(rhs.dtype()) { + vortex_bail!(MismatchedTypes: self.dtype(), rhs.dtype()); + } + + if PType::try_from(self.dtype())? != I::PTYPE { + vortex_bail!(MismatchedTypes: self.dtype(), I::PTYPE); + } + + let lhs = self.maybe_null_slice::(); + + let mut output: Vec> = Vec::with_capacity(self.len()); + // Safety: we are going to apply the fn to every element and store it so the full length will be utilized + unsafe { output.set_len(self.len()) }; + + let validity = self + .validity() + .and(rhs.with_dyn(|a| a.logical_validity().into_validity()))?; + + let mut idx_offset = 0; + let rhs_iter = dyn_cast_array_iter::(&rhs); + + for batch in rhs_iter { + let batch_len = batch.len(); + process_batch( + &lhs[idx_offset..idx_offset + batch_len], + batch, + &binary_fn, + idx_offset, + output.as_mut_slice(), + ); + idx_offset += batch_len; + } + + // Safety: `MaybeUninit` is a transparent struct and we know the actual length of the vec. + let output = unsafe { transmute::>, Vec>(output) }; + + Ok(PrimitiveArray::from_vec(output, validity).into_array()) + } +} + +fn process_batch O>( + lhs: &[I], + batch: Batch, + f: F, + idx_offset: usize, + output: &mut [MaybeUninit], +) { + assert_eq!(batch.len(), lhs.len()); + + if batch.len() == ITER_BATCH_SIZE { + let lhs: [I; ITER_BATCH_SIZE] = lhs.try_into().unwrap(); + let rhs: [U; ITER_BATCH_SIZE] = batch.data().try_into().unwrap(); + // We know output is of the same length and lhs/rhs + let mut output_slice: [_; ITER_BATCH_SIZE] = output + [idx_offset..idx_offset + ITER_BATCH_SIZE] + .try_into() + .unwrap(); + + for idx in 0..ITER_BATCH_SIZE { + unsafe { + *output_slice.get_unchecked_mut(idx) = MaybeUninit::new(f(lhs[idx], rhs[idx])); + } + } + } else { + for (idx, rhs_item) in batch.data().iter().enumerate() { + // Safety: output is the same length as the original array, so we know these are still valid indexes + unsafe { + *output.get_unchecked_mut(idx + idx_offset) = + MaybeUninit::new(f(lhs[idx], *rhs_item)); + } + } + } +} + #[cfg(test)] mod tests { + use vortex_scalar::Scalar; + use super::*; #[test] @@ -347,4 +482,45 @@ mod tests { assert_eq!(idx as u32, v.unwrap()); } } + + #[test] + fn binary_fn_example() { + let input = PrimitiveArray::from_vec(vec![2u32, 2, 2, 2], Validity::AllValid); + + let scalar = Scalar::from(2u32); + + let o = input + .unary(move |v: u32| { + let scalar_v = u32::try_from(&scalar).unwrap(); + if v == scalar_v { + 1_u8 + } else { + 0_u8 + } + }) + .unwrap(); + + let output_iter = o + .with_dyn(|a| a.as_primitive_array_unchecked().u8_iter()) + .unwrap() + .flatten(); + + for v in output_iter { + assert_eq!(v.unwrap(), 1); + } + } + + #[test] + fn unary_fn_example() { + let input = PrimitiveArray::from_vec(vec![2u32, 2, 2, 2], Validity::AllValid); + let output = input.unary(|u: u32| u + 1).unwrap(); + + for o in output + .with_dyn(|a| a.as_primitive_array_unchecked().u32_iter()) + .unwrap() + .flatten() + { + assert_eq!(o.unwrap(), 3); + } + } } diff --git a/vortex-array/src/elementwise.rs b/vortex-array/src/elementwise.rs new file mode 100644 index 0000000000..c6f8eabb9e --- /dev/null +++ b/vortex-array/src/elementwise.rs @@ -0,0 +1,91 @@ +use vortex_dtype::{NativePType, PType}; +use vortex_error::VortexResult; + +use crate::iter::Batch; +use crate::{Array, ArrayDType}; + +pub trait BinaryFn { + fn binary O>( + &self, + rhs: Array, + binary_fn: F, + ) -> VortexResult; +} + +pub trait UnaryFn { + fn unary O>( + &self, + unary_fn: F, + ) -> VortexResult; +} + +pub fn dyn_cast_array_iter(array: &Array) -> Box>> { + match PType::try_from(array.dtype()).unwrap() { + PType::U8 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().u8_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::U16 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().u16_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::U32 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().u32_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::U64 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().u64_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::I8 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().i8_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::I16 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().i16_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::I32 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().i32_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::I64 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().i64_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::F16 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().u64_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::F32 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().f32_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + PType::F64 => Box::new( + array + .with_dyn(|a| a.as_primitive_array_unchecked().f64_iter()) + .unwrap() + .map(|b| b.as_::()), + ), + } +} diff --git a/vortex-array/src/iter/mod.rs b/vortex-array/src/iter/mod.rs index 42b8b677de..56077944a0 100644 --- a/vortex-array/src/iter/mod.rs +++ b/vortex-array/src/iter/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; pub use adapter::*; pub use ext::*; -use vortex_dtype::DType; +use vortex_dtype::{DType, NativePType}; use vortex_error::VortexResult; use crate::validity::Validity; @@ -11,7 +11,7 @@ use crate::Array; mod adapter; mod ext; -pub const BATCH_SIZE: usize = 1024; +pub const ITER_BATCH_SIZE: usize = 1024; /// A stream of array chunks along with a DType. /// Analogous to Arrow's RecordBatchReader. @@ -24,7 +24,7 @@ pub type AccessorRef = Arc>; /// Define the basic behavior required for batched iterators pub trait Accessor: Send + Sync { fn batch_size(&self, start_idx: usize) -> usize { - usize::min(BATCH_SIZE, self.array_len() - start_idx) + usize::min(ITER_BATCH_SIZE, self.array_len() - start_idx) } fn array_len(&self) -> usize; fn is_valid(&self, index: usize) -> bool; @@ -124,6 +124,18 @@ impl Batch { pub unsafe fn get_unchecked(&self, index: usize) -> &T { unsafe { self.data.get_unchecked(index) } } + + pub fn as_(self) -> Batch { + assert_eq!(std::mem::size_of::(), std::mem::size_of::()); + Batch { + data: unsafe { std::mem::transmute::, Vec>(self.data) }, + validity: self.validity, + } + } + + pub fn data(&self) -> &[T] { + self.data.as_ref() + } } pub struct FlattenedBatch { diff --git a/vortex-array/src/lib.rs b/vortex-array/src/lib.rs index bf4ae5ba61..049747aba0 100644 --- a/vortex-array/src/lib.rs +++ b/vortex-array/src/lib.rs @@ -41,6 +41,7 @@ pub mod compress; pub mod compute; mod context; mod data; +pub mod elementwise; pub mod encoding; mod implementation; pub mod iter; diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index b333869d0c..f3991ca1c2 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -5,7 +5,7 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::array::BoolArray; use crate::compute::unary::scalar_at_unchecked; -use crate::compute::{filter, slice, take}; +use crate::compute::{and, filter, slice, take}; use crate::stats::ArrayStatistics; use crate::{Array, IntoArray, IntoArrayVariant}; @@ -140,6 +140,28 @@ impl Validity { } } } + + /// Logically & two Validity values of the same length + pub fn and(self, rhs: Validity) -> VortexResult { + let validity = match (self, rhs) { + // Any `AllInvalid` makes the output all invalid values + (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid, + // All truthy values on one side, which makes no effect on an `Array` variant + (Validity::Array(a), Validity::AllValid) + | (Validity::Array(a), Validity::NonNullable) + | (Validity::NonNullable, Validity::Array(a)) + | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a.clone()), + (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable, + // Both sides are all valid + (Validity::NonNullable, Validity::AllValid) + | (Validity::AllValid, Validity::NonNullable) + | (Validity::AllValid, Validity::AllValid) => Validity::AllValid, + // Here we actually have to do some work + (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(and(&lhs, &rhs)?), + }; + + Ok(validity) + } } impl PartialEq for Validity { diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 71915ffc76..6a4da9b5a0 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -134,6 +134,10 @@ pub trait PrimitiveArrayTrait: ArrayTrait { None } + fn f16_accessor(&self) -> Option> { + None + } + fn f32_accessor(&self) -> Option> { None } @@ -142,6 +146,14 @@ pub trait PrimitiveArrayTrait: ArrayTrait { None } + fn u8_iter(&self) -> Option> { + self.u8_accessor().map(VectorizedArrayIter::new) + } + + fn u16_iter(&self) -> Option> { + self.u16_accessor().map(VectorizedArrayIter::new) + } + fn u32_iter(&self) -> Option> { self.u32_accessor().map(VectorizedArrayIter::new) } @@ -150,6 +162,26 @@ pub trait PrimitiveArrayTrait: ArrayTrait { self.u64_accessor().map(VectorizedArrayIter::new) } + fn i8_iter(&self) -> Option> { + self.i8_accessor().map(VectorizedArrayIter::new) + } + + fn i16_iter(&self) -> Option> { + self.i16_accessor().map(VectorizedArrayIter::new) + } + + fn i32_iter(&self) -> Option> { + self.i32_accessor().map(VectorizedArrayIter::new) + } + + fn i64_iter(&self) -> Option> { + self.i64_accessor().map(VectorizedArrayIter::new) + } + + fn f16_iter(&self) -> Option> { + self.f16_accessor().map(VectorizedArrayIter::new) + } + fn f32_iter(&self) -> Option> { self.f32_accessor().map(VectorizedArrayIter::new) }