Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch Compute Function #50

Merged
merged 17 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/fastlanez
69 changes: 69 additions & 0 deletions vortex/src/array/primitive/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use crate::array::primitive::PrimitiveArray;
use crate::array::CloneOptionalArray;
use crate::compute::cast::CastPrimitiveFn;
use crate::error::{VortexError, VortexResult};
use crate::match_each_native_ptype;
use crate::ptype::{NativePType, PType};

impl CastPrimitiveFn for PrimitiveArray {
fn cast_primitive(&self, ptype: &PType) -> VortexResult<PrimitiveArray> {
if self.ptype() == ptype {
Ok(self.clone())
} else {
match_each_native_ptype!(ptype, |$T| {
Ok(PrimitiveArray::from_nullable(
cast::<$T>(self)?,
self.validity().clone_optional(),
))
})
}
}
}

fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Vec<T>> {
match_each_native_ptype!(array.ptype(), |$E| {
array
.typed_data::<$E>()
.iter()
// TODO(ngates): allow configurable checked/unchecked casting
.map(|v| {
T::from(*v).ok_or_else(|| {
VortexError::ComputeError(format!("Failed to cast {} to {:?}", v, T::PTYPE).into())
})
})
.collect()
})
}

#[cfg(test)]
mod test {
use crate::array::primitive::PrimitiveArray;
use crate::compute;
use crate::error::VortexError;
use crate::ptype::PType;

#[test]
fn cast_u32_u8() {
let arr = PrimitiveArray::from_vec(vec![0u32, 10, 200]);
let u8arr = compute::cast::cast_primitive(&arr, &PType::U8).unwrap();
assert_eq!(u8arr.typed_data::<u8>(), vec![0u8, 10, 200]);
}

#[test]
fn cast_u32_f32() {
let arr = PrimitiveArray::from_vec(vec![0u32, 10, 200]);
let u8arr = compute::cast::cast_primitive(&arr, &PType::F32).unwrap();
assert_eq!(u8arr.typed_data::<f32>(), vec![0.0f32, 10., 200.]);
}

#[test]
fn cast_i32_u32() {
let arr = PrimitiveArray::from_vec(vec![-1i32]);
assert_eq!(
compute::cast::cast_primitive(&arr, &PType::U32)
.err()
.unwrap(),
VortexError::ComputeError("Failed to cast -1 to U32".into(),)
)
}
}
23 changes: 23 additions & 0 deletions vortex/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use crate::array::primitive::PrimitiveArray;
use crate::compute::cast::CastPrimitiveFn;
use crate::compute::patch::PatchFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::ArrayCompute;

mod cast;
mod patch;
mod scalar_at;

impl ArrayCompute for PrimitiveArray {
fn cast_primitive(&self) -> Option<&dyn CastPrimitiveFn> {
Some(self)
}

fn patch(&self) -> Option<&dyn PatchFn> {
Some(self)
}

fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}
}
39 changes: 39 additions & 0 deletions vortex/src/array/primitive/compute/patch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use itertools::Itertools;

use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::primitive::PrimitiveArray;
use crate::array::sparse::{SparseArray, SparseEncoding};
use crate::array::{Array, ArrayRef, CloneOptionalArray};
use crate::compute::patch::PatchFn;
use crate::error::{VortexError, VortexResult};
use crate::{compute, match_each_native_ptype};

impl PatchFn for PrimitiveArray {
fn patch(&self, patch: &dyn Array) -> VortexResult<ArrayRef> {
match patch.encoding().id() {
&SparseEncoding::ID => patch_with_sparse(self, patch.as_sparse()),
// TODO(ngates): support a default implementation based on iter_arrow?
_ => Err(VortexError::MissingKernel(
"patch",
self.encoding().id(),
vec![patch.encoding().id()],
)),
}
}
}

fn patch_with_sparse(array: &PrimitiveArray, patch: &SparseArray) -> VortexResult<ArrayRef> {
let patch_indices = patch.resolved_indices();
match_each_native_ptype!(array.ptype(), |$T| {
let mut values = Vec::from(array.typed_data::<$T>());
let patch_values = compute::cast::cast_primitive(patch.values(), array.ptype())?;
for (idx, value) in patch_indices.iter().zip_eq(patch_values.typed_data::<$T>().iter()) {
values[*idx] = *value;
}
Ok(PrimitiveArray::from_nullable(
values,
// TODO(ngates): if patch values has null, we need to patch into the validity buffer
array.validity().clone_optional(),
).boxed())
})
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
use crate::array::primitive::PrimitiveArray;
use crate::array::Array;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::ArrayCompute;
use crate::error::VortexResult;
use crate::match_each_native_ptype;
use crate::scalar::{NullableScalar, Scalar};

impl ArrayCompute for PrimitiveArray {
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}
}

impl ScalarAtFn for PrimitiveArray {
fn scalar_at(&self, index: usize) -> VortexResult<Box<dyn Scalar>> {
if self.is_valid(index) {
Expand Down
13 changes: 12 additions & 1 deletion vortex/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::sync::{Arc, RwLock};
use allocator_api2::alloc::Allocator;
use arrow::alloc::ALIGNMENT as ARROW_ALIGNMENT;
use arrow::array::{make_array, ArrayData, AsArray};
use arrow::buffer::{Buffer, NullBuffer};
use arrow::buffer::{Buffer, NullBuffer, ScalarBuffer};
use linkme::distributed_slice;
use log::debug;

Expand Down Expand Up @@ -129,6 +129,17 @@ impl PrimitiveArray {
pub fn validity(&self) -> Option<&dyn Array> {
self.validity.as_deref()
}

pub fn scalar_buffer<T: NativePType>(&self) -> ScalarBuffer<T> {
ScalarBuffer::from(self.buffer().clone())
}

pub fn typed_data<T: NativePType>(&self) -> &[T] {
if self.ptype() != &T::PTYPE {
panic!("Invalid PType")
}
self.buffer().typed_data()
}
}

impl Array for PrimitiveArray {
Expand Down
33 changes: 20 additions & 13 deletions vortex/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use std::iter;
use std::sync::{Arc, RwLock};

use arrow::array::AsArray;
use arrow::array::BooleanBufferBuilder;
use arrow::array::{ArrayRef as ArrowArrayRef, PrimitiveArray as ArrowPrimitiveArray};
use arrow::array::{
ArrayRef as ArrowArrayRef, BooleanBufferBuilder, PrimitiveArray as ArrowPrimitiveArray,
};
use arrow::buffer::{NullBuffer, ScalarBuffer};
use arrow::datatypes::UInt64Type;
use linkme::distributed_slice;
Expand Down Expand Up @@ -79,6 +80,22 @@ impl SparseArray {
pub fn indices(&self) -> &dyn Array {
self.indices.as_ref()
}

/// Return indices as a vector of usize with the indices_offset applied.
pub fn resolved_indices(&self) -> Vec<usize> {
let mut indices = Vec::with_capacity(self.len());
self.indices().iter_arrow().for_each(|c| {
indices.extend(
arrow::compute::cast(c.as_ref(), &arrow::datatypes::DataType::UInt64)
.unwrap()
.as_primitive::<UInt64Type>()
.values()
.into_iter()
.map(|v| (*v as usize) - self.indices_offset),
)
});
indices
}
}

impl Array for SparseArray {
Expand Down Expand Up @@ -119,16 +136,7 @@ impl Array for SparseArray {

fn iter_arrow(&self) -> Box<ArrowIterator> {
// Resolve our indices into a vector of usize applying the offset
let mut indices = Vec::with_capacity(self.len());
self.indices().iter_arrow().for_each(|c| {
indices.extend(
c.as_primitive::<UInt64Type>()
.values()
.into_iter()
.map(|v| (*v as usize) - self.indices_offset),
)
});

let indices = self.resolved_indices();
let array: ArrowArrayRef = match_arrow_numeric_type!(self.values().dtype(), |$E| {
let mut validity = BooleanBufferBuilder::new(self.len());
validity.append_n(self.len(), false);
Expand All @@ -147,7 +155,6 @@ impl Array for SparseArray {
Some(NullBuffer::from(validity.finish())),
))
});

Box::new(iter::once(array))
}

Expand Down
2 changes: 1 addition & 1 deletion vortex/src/compute/as_contiguous.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use arrow::buffer::BooleanBuffer;
use itertools::Itertools;
use vortex_alloc::{AlignedVec, ALIGNED_ALLOCATOR};

use crate::array::bool::{BoolArray, BoolEncoding};
use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::primitive::{PrimitiveArray, PrimitiveEncoding};
use crate::array::{Array, ArrayRef, CloneOptionalArray};
use crate::error::{VortexError, VortexResult};
use crate::ptype::{match_each_native_ptype, NativePType};
use vortex_alloc::{AlignedVec, ALIGNED_ALLOCATOR};

pub fn as_contiguous(arrays: Vec<ArrayRef>) -> VortexResult<ArrayRef> {
if arrays.is_empty() {
Expand Down
23 changes: 19 additions & 4 deletions vortex/src/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
use crate::dtype::DType;
use crate::scalar::Scalar;
use crate::array::primitive::PrimitiveArray;
use crate::array::Array;
use crate::error::{VortexError, VortexResult};
use crate::ptype::PType;

pub fn cast_scalar(_value: &dyn Scalar, _dtype: &DType) -> Box<dyn Scalar> {
todo!()
pub trait CastPrimitiveFn {
fn cast_primitive(&self, ptype: &PType) -> VortexResult<PrimitiveArray>;
}

pub fn cast_primitive(array: &dyn Array, ptype: &PType) -> VortexResult<PrimitiveArray> {
PType::try_from(array.dtype()).map_err(|_| VortexError::InvalidDType(array.dtype().clone()))?;
array
.cast_primitive()
.map(|t| t.cast_primitive(ptype))
.unwrap_or_else(|| {
Err(VortexError::NotImplemented(
"cast_primitive",
array.encoding().id(),
))
})
}
14 changes: 12 additions & 2 deletions vortex/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
use crate::compute::scalar_at::ScalarAtFn;
use cast::CastPrimitiveFn;
use patch::PatchFn;
use scalar_at::ScalarAtFn;
use take::TakeFn;

pub mod add;
pub mod as_contiguous;
pub mod cast;
pub mod patch;
pub mod repeat;
pub mod scalar_at;
pub mod search_sorted;
pub mod take;

pub trait ArrayCompute {
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
fn cast_primitive(&self) -> Option<&dyn CastPrimitiveFn> {
None
}

fn patch(&self) -> Option<&dyn PatchFn> {
None
}

fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
None
}
fn take(&self) -> Option<&dyn TakeFn> {
None
}
Expand Down
22 changes: 22 additions & 0 deletions vortex/src/compute/patch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use crate::array::{Array, ArrayRef};
use crate::error::{VortexError, VortexResult};

pub trait PatchFn {
fn patch(&self, patch: &dyn Array) -> VortexResult<ArrayRef>;
}

/// Returns a new array where the non-null values from the patch array are replaced in the original.
pub fn patch(array: &dyn Array, patch: &dyn Array) -> VortexResult<ArrayRef> {
if array.len() != patch.len() {
return Err(VortexError::InvalidArgument(
"patch array must have the same length as the original array".into(),
));
}

// TODO(ngates): check the dtype matches

array
.patch()
.map(|t| t.patch(patch))
.unwrap_or_else(|| Err(VortexError::NotImplemented("take", array.encoding().id())))
}
6 changes: 3 additions & 3 deletions vortex/src/compute/scalar_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ pub fn scalar_at(array: &dyn Array, index: usize) -> VortexResult<Box<dyn Scalar
.scalar_at()
.map(|t| t.scalar_at(index))
.unwrap_or_else(|| {
// TODO(ngates): default implementation of decode and then try again
Err(VortexError::ComputeError(
format!("scalar_at not implemented for {}", &array.encoding().id()).into(),
Err(VortexError::NotImplemented(
"scalar_at",
array.encoding().id(),
))
})
}
10 changes: 4 additions & 6 deletions vortex/src/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ pub trait TakeFn {
}

pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
array.take().map(|t| t.take(indices)).unwrap_or_else(|| {
// TODO(ngates): default implementation of decode and then try again
Err(VortexError::ComputeError(
format!("take not implemented for {}", &array.encoding().id()).into(),
))
})
array
.take()
.map(|t| t.take(indices))
.unwrap_or_else(|| Err(VortexError::NotImplemented("take", array.encoding().id())))
}
8 changes: 8 additions & 0 deletions vortex/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ pub enum VortexError {
LengthMismatch,
#[error("{0}")]
ComputeError(ErrString),
#[error("{0}")]
InvalidArgument(ErrString),
// Used when a function is not implemented for a given array type.
#[error("function {0} not implemented for {1}")]
NotImplemented(&'static str, &'static EncodingId),
// Used when a function is implemented for an array type, but the RHS is not supported.
#[error("missing kernel {0} for {1} and {2:?}")]
MissingKernel(&'static str, &'static EncodingId, Vec<&'static EncodingId>),
#[error("invalid data type: {0}")]
InvalidDType(DType),
#[error("invalid physical type: {0:?}")]
Expand Down
Loading