Skip to content

Commit

Permalink
Unary and Binary functions trait (#726)
Browse files Browse the repository at this point in the history
Introduces specialized traits for unary (like `map`) and binary (like `&`)
operations on arrays. This allows us to write specialized and
ergonomic implementations for cases where we know the encoding of an
array (as in `ArrayCompute` functions).

## Benchmarks
As of 08bbd8d running on my M3 Max macbook

```
arrow_unary_add         time:   [228.22 µs 233.98 µs 239.93 µs]

vortex_unary_add        time:   [1.5322 µs 1.5469 µs 1.5604 µs]

arrow_binary_add        time:   [237.74 µs 240.93 µs 243.98 µs]

vortex_binary_add       time:   [93.821 µs 94.796 µs 95.918 µs]
```
  • Loading branch information
AdamGS authored Sep 5, 2024
1 parent d91eac7 commit 55220c0
Show file tree
Hide file tree
Showing 8 changed files with 410 additions and 5 deletions.
4 changes: 4 additions & 0 deletions vortex-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ harness = false
[[bench]]
name = "iter"
harness = false

[[bench]]
name = "fn"
harness = false
67 changes: 67 additions & 0 deletions vortex-array/benches/fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use arrow_array::types::UInt32Type;
use arrow_array::UInt32Array;
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use vortex::array::PrimitiveArray;
use vortex::elementwise::{BinaryFn, UnaryFn};
use vortex::validity::Validity;
use vortex::IntoArray;

/// Benchmarks `UnaryFn::unary` on a one-million-element, all-valid `u32` `PrimitiveArray`,
/// adding one to every element.
fn vortex_unary_add(c: &mut Criterion) {
    let values: Vec<u32> = (0_u32..1_000_000).collect();
    let array = PrimitiveArray::from_vec(values, Validity::AllValid);

    c.bench_function("vortex_unary_add", |bencher| {
        bencher.iter_batched(
            // Setup: hand each iteration its own clone of the input.
            || array.clone(),
            |input| input.unary(|v: u32| v + 1).unwrap(),
            BatchSize::SmallInput,
        )
    });
}

fn arrow_unary_add(c: &mut Criterion) {
let data = UInt32Array::from_iter_values(0_u32..1_000_000);
c.bench_function("arrow_unary_add", |b| {
b.iter_batched(
|| data.clone(),
|data: arrow_array::PrimitiveArray<UInt32Type>| data.unary::<_, UInt32Type>(|v| v + 1),
BatchSize::SmallInput,
)
});
}

/// Benchmarks `BinaryFn::binary` by adding two one-million-element, all-valid `u32` arrays
/// element by element.
fn vortex_binary_add(c: &mut Criterion) {
    let left_values: Vec<u32> = (0_u32..1_000_000).collect();
    let right_values: Vec<u32> = (0_u32..1_000_000).collect();

    let left = PrimitiveArray::from_vec(left_values, Validity::AllValid);
    // `binary` takes its right-hand side as a type-erased `Array`.
    let right = PrimitiveArray::from_vec(right_values, Validity::AllValid).into_array();

    c.bench_function("vortex_binary_add", |bencher| {
        bencher.iter_batched(
            || (left.clone(), right.clone()),
            |(l_arr, r_arr)| l_arr.binary(r_arr, |l: u32, r: u32| l + r),
            BatchSize::SmallInput,
        )
    });
}

/// Benchmarks `arrow_arith::arity::binary` by adding two one-million-element `UInt32Array`s
/// element by element. Serves as the baseline for `vortex_binary_add`.
fn arrow_binary_add(c: &mut Criterion) {
    let left = UInt32Array::from_iter_values(0_u32..1_000_000);
    let right = UInt32Array::from_iter_values(0_u32..1_000_000);

    c.bench_function("arrow_binary_add", |bencher| {
        bencher.iter_batched(
            || (left.clone(), right.clone()),
            |(l_arr, r_arr)| {
                let summed =
                    arrow_arith::arity::binary::<_, _, _, UInt32Type>(&l_arr, &r_arr, |a, b| a + b);
                summed.unwrap()
            },
            BatchSize::SmallInput,
        )
    });
}

// Registers the four benchmarks above under the `benches` group using Criterion's default
// configuration. Each Arrow baseline is listed directly before its Vortex counterpart.
criterion_group!(
name = benches;
config = Criterion::default();
targets =
arrow_unary_add,
vortex_unary_add,
arrow_binary_add,
vortex_binary_add,
);
// Generates the benchmark binary's entry point for the `benches` group.
criterion_main!(benches);
178 changes: 177 additions & 1 deletion vortex-array/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::mem::{transmute, MaybeUninit};
use std::ptr;
use std::sync::Arc;

Expand All @@ -10,7 +11,8 @@ use vortex_buffer::Buffer;
use vortex_dtype::{match_each_native_ptype, DType, NativePType, PType};
use vortex_error::{vortex_bail, VortexResult};

use crate::iter::{Accessor, AccessorRef};
use crate::elementwise::{dyn_cast_array_iter, BinaryFn, UnaryFn};
use crate::iter::{Accessor, AccessorRef, Batch, ITER_BATCH_SIZE};
use crate::stats::StatsSet;
use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata};
use crate::variants::{ArrayVariants, PrimitiveArrayTrait};
Expand Down Expand Up @@ -314,8 +316,141 @@ impl Array {
}
}

// Number of elements `UnaryFn::unary` processes per fixed-size chunk.
// The value is arbitrary: a few sizes were tried and this one performed better than smaller ones.
// The best value is presumably hardware dependent, but this seems to be good enough.
const CHUNK_SIZE: usize = 1024;

impl UnaryFn for PrimitiveArray {
fn unary<I: NativePType, O: NativePType, F: Fn(I) -> O>(
&self,
unary_fn: F,
) -> VortexResult<Array> {
let data = self.maybe_null_slice::<I>();
let mut output: Vec<MaybeUninit<O>> = Vec::with_capacity(data.len());
// Safety: we are going to apply the fn to every element and store it so the full length will be utilized
unsafe { output.set_len(data.len()) };

let chunks = data.chunks_exact(CHUNK_SIZE);

// We start with the reminder because of ownership
let reminder_start_idx = data.len() - (data.len() % CHUNK_SIZE);
for (index, item) in chunks.remainder().iter().enumerate() {
// Safety: This access is bound by the same range as the output's capacity and length, so its within the Vec's allocated memory
unsafe {
*output.get_unchecked_mut(reminder_start_idx + index) =
MaybeUninit::new(unary_fn(*item));
}
}

let mut offset = 0;

for chunk in chunks {
// We know the size of the chunk, and we know output is the same length as the input array
let chunk: [I; CHUNK_SIZE] = chunk.try_into()?;
let mut output_slice: [_; CHUNK_SIZE] =
output[offset..offset + CHUNK_SIZE].try_into()?;

for idx in 0..CHUNK_SIZE {
output_slice[idx] = MaybeUninit::new(unary_fn(chunk[idx]));
}

offset += CHUNK_SIZE;
}

// Safety: `MaybeUninit` is a transparent struct and we know the actual length of the vec.
let output = unsafe { transmute::<Vec<MaybeUninit<O>>, Vec<O>>(output) };

Ok(PrimitiveArray::from_vec(output, self.validity()).into_array())
}
}

impl BinaryFn for PrimitiveArray {
    /// Applies `binary_fn` element-wise to this array (left-hand side) and `rhs`, returning a
    /// new primitive array of the results.
    ///
    /// `I` must be the native type matching this array's `PType`. The right-hand values are
    /// produced by `dyn_cast_array_iter::<U>`. The result's validity is the `and` of both
    /// inputs' validities.
    ///
    /// # Errors
    ///
    /// Returns an error when the two arrays differ in length, when their dtypes differ
    /// (ignoring nullability), when `I` does not match this array's `PType`, when the
    /// validities cannot be combined, or when the right-hand side does not yield exactly one
    /// value per element.
    fn binary<I: NativePType, U: NativePType, O: NativePType, F: Fn(I, U) -> O>(
        &self,
        rhs: Array,
        binary_fn: F,
    ) -> VortexResult<Array> {
        if self.len() != rhs.len() {
            vortex_bail!(InvalidArgument: "Both arguments to `binary` should be of the same length");
        }
        if !self.dtype().eq_ignore_nullability(rhs.dtype()) {
            vortex_bail!(MismatchedTypes: self.dtype(), rhs.dtype());
        }

        if PType::try_from(self.dtype())? != I::PTYPE {
            vortex_bail!(MismatchedTypes: self.dtype(), I::PTYPE);
        }

        let lhs = self.maybe_null_slice::<I>();

        // Compute the fallible validity before allocating the output buffer.
        let validity = self
            .validity()
            .and(rhs.with_dyn(|a| a.logical_validity().into_validity()))?;

        let mut output: Vec<MaybeUninit<O>> = Vec::with_capacity(self.len());
        // SAFETY: the elements are `MaybeUninit`, which may legally be uninitialized, and the
        // new length equals the capacity reserved on the previous line.
        unsafe { output.set_len(self.len()) };

        // Walk the right-hand side batch by batch, pairing each batch with the matching window
        // of the left-hand side.
        let mut idx_offset = 0;
        for batch in dyn_cast_array_iter::<U>(&rhs) {
            let batch_len = batch.len();
            process_batch(
                &lhs[idx_offset..idx_offset + batch_len],
                batch,
                &binary_fn,
                idx_offset,
                output.as_mut_slice(),
            );
            idx_offset += batch_len;
        }

        // The transmute below is only sound if every output element was written.
        if idx_offset != self.len() {
            vortex_bail!(InvalidArgument: "The right-hand side of `binary` did not yield a value for every element");
        }

        // SAFETY: `MaybeUninit<O>` is a transparent wrapper around `O`, and the check above
        // confirms the batches covered all `self.len()` elements.
        let output = unsafe { transmute::<Vec<MaybeUninit<O>>, Vec<O>>(output) };

        Ok(PrimitiveArray::from_vec(output, validity).into_array())
    }
}

/// Applies `f` to each pair of `lhs` and `batch` elements and stores the results in
/// `output[idx_offset..idx_offset + lhs.len()]`.
///
/// # Panics
///
/// Panics when `batch` and `lhs` differ in length, or when the destination range does not fit
/// inside `output`.
fn process_batch<I: NativePType, U: NativePType, O: NativePType, F: Fn(I, U) -> O>(
    lhs: &[I],
    batch: Batch<U>,
    f: F,
    idx_offset: usize,
    output: &mut [MaybeUninit<O>],
) {
    assert_eq!(batch.len(), lhs.len());

    let rhs: &[U] = batch.data();
    // One bounds check for the whole batch. Every write below goes through this mutable
    // sub-slice, so the results land in `output` itself rather than in a copy of it.
    let out = &mut output[idx_offset..idx_offset + lhs.len()];

    if batch.len() == ITER_BATCH_SIZE {
        // Full batch: borrow all three operands as fixed-size arrays so the loop has a
        // compile-time trip count. These are references, not copies.
        let lhs: &[I; ITER_BATCH_SIZE] = lhs.try_into().unwrap();
        let rhs: &[U; ITER_BATCH_SIZE] = rhs.try_into().unwrap();
        let out: &mut [MaybeUninit<O>; ITER_BATCH_SIZE] = out.try_into().unwrap();

        for idx in 0..ITER_BATCH_SIZE {
            out[idx] = MaybeUninit::new(f(lhs[idx], rhs[idx]));
        }
    } else {
        // Partial batch.
        for ((slot, l), r) in out.iter_mut().zip(lhs).zip(rhs) {
            *slot = MaybeUninit::new(f(*l, *r));
        }
    }
}

#[cfg(test)]
mod tests {
use vortex_scalar::Scalar;

use super::*;

#[test]
Expand Down Expand Up @@ -347,4 +482,45 @@ mod tests {
assert_eq!(idx as u32, v.unwrap());
}
}

// Despite its name, this test exercises `unary`: it compares every element against a captured
// scalar and expects a `u8` flag of 1 for each match.
#[test]
fn binary_fn_example() {
    let twos = PrimitiveArray::from_vec(vec![2u32, 2, 2, 2], Validity::AllValid);
    let needle = Scalar::from(2u32);

    let flags = twos
        .unary(move |value: u32| {
            let expected = u32::try_from(&needle).unwrap();
            u8::from(value == expected)
        })
        .unwrap();

    let batches = flags
        .with_dyn(|a| a.as_primitive_array_unchecked().u8_iter())
        .unwrap();

    // Every input equals the scalar, so every flag must be set.
    for flag in batches.flatten() {
        assert_eq!(flag.unwrap(), 1);
    }
}

// Adds one to every element of a small all-valid `u32` array and checks each result.
#[test]
fn unary_fn_example() {
    let twos = PrimitiveArray::from_vec(vec![2u32, 2, 2, 2], Validity::AllValid);
    let incremented = twos.unary(|value: u32| value + 1).unwrap();

    let batches = incremented
        .with_dyn(|a| a.as_primitive_array_unchecked().u32_iter())
        .unwrap();

    for item in batches.flatten() {
        assert_eq!(item.unwrap(), 3);
    }
}
}
91 changes: 91 additions & 0 deletions vortex-array/src/elementwise.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use vortex_dtype::{NativePType, PType};
use vortex_error::VortexResult;

use crate::iter::Batch;
use crate::{Array, ArrayDType};

pub trait BinaryFn {
fn binary<I: NativePType, U: NativePType, O: NativePType, F: Fn(I, U) -> O>(
&self,
rhs: Array,
binary_fn: F,
) -> VortexResult<Array>;
}

pub trait UnaryFn {
fn unary<I: NativePType, O: NativePType, F: Fn(I) -> O>(
&self,
unary_fn: F,
) -> VortexResult<Array>;
}

/// Iterates over a primitive `array` in batches, casting every batch to the native type `N`.
///
/// The array's own `PType` selects which typed iterator is used; each batch it yields is then
/// converted with `Batch::as_`.
///
/// # Panics
///
/// Panics when the array's dtype is not a primitive type, or when the array does not provide
/// the iterator matching its `PType`.
pub fn dyn_cast_array_iter<N: NativePType>(array: &Array) -> Box<dyn Iterator<Item = Batch<N>>> {
    // Every arm is identical apart from the typed iterator method, so generate them.
    macro_rules! cast_batches {
        ($iter_fn:ident) => {
            Box::new(
                array
                    .with_dyn(|a| a.as_primitive_array_unchecked().$iter_fn())
                    .unwrap()
                    .map(|b| b.as_::<N>()),
            )
        };
    }

    match PType::try_from(array.dtype()).unwrap() {
        PType::U8 => cast_batches!(u8_iter),
        PType::U16 => cast_batches!(u16_iter),
        PType::U32 => cast_batches!(u32_iter),
        PType::U64 => cast_batches!(u64_iter),
        PType::I8 => cast_batches!(i8_iter),
        PType::I16 => cast_batches!(i16_iter),
        PType::I32 => cast_batches!(i32_iter),
        PType::I64 => cast_batches!(i64_iter),
        // F16 needs its own iterator; this arm previously requested `u64_iter` by mistake.
        PType::F16 => cast_batches!(f16_iter),
        PType::F32 => cast_batches!(f32_iter),
        PType::F64 => cast_batches!(f64_iter),
    }
}
Loading

0 comments on commit 55220c0

Please sign in to comment.