Implement CastFn for chunkedarray (#497)
try_cast on chunked array will now cast the chunks
a10y authored Jul 22, 2024
1 parent 4660d26 commit 1f31308
Showing 2 changed files with 73 additions and 15 deletions.
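
For orientation, a minimal sketch of the behavior this commit enables, mirroring the new test added in the diff below. The vortex_array:: paths are assumed external equivalents of the crate:: paths used inside the crate, and the function name is illustrative; this is not part of the commit.

use vortex_dtype::{DType, Nullability, PType};
use vortex_error::VortexResult;

use vortex_array::array::chunked::ChunkedArray;
use vortex_array::array::primitive::PrimitiveArray;
use vortex_array::compute::unary::try_cast;
use vortex_array::validity::Validity;
use vortex_array::{IntoArray, IntoArrayVariant};

fn widen_chunked_u32_to_u64() -> VortexResult<Vec<u64>> {
    // Build a chunked array from two u32 chunks.
    let chunked = ChunkedArray::try_new(
        vec![
            PrimitiveArray::from_vec(vec![0u32, 1], Validity::NonNullable).into_array(),
            PrimitiveArray::from_vec(vec![2u32, 3], Validity::NonNullable).into_array(),
        ],
        DType::Primitive(PType::U32, Nullability::NonNullable),
    )?
    .into_array();

    // With this commit, try_cast dispatches to ChunkedArray's CastFn, which casts
    // each chunk and reassembles a chunked array with the target dtype.
    let widened = try_cast(
        &chunked,
        &DType::Primitive(PType::U64, Nullability::NonNullable),
    )?;
    Ok(widened.into_primitive()?.into_maybe_null_slice::<u64>())
}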
63 changes: 62 additions & 1 deletion vortex-array/src/array/chunked/compute/mod.rs
@@ -1,14 +1,20 @@
use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;

use crate::array::chunked::ChunkedArray;
use crate::compute::unary::{scalar_at, ScalarAtFn, SubtractScalarFn};
use crate::compute::unary::{scalar_at, try_cast, CastFn, ScalarAtFn, SubtractScalarFn};
use crate::compute::{ArrayCompute, SliceFn, TakeFn};
use crate::{Array, IntoArray};

mod slice;
mod take;

impl ArrayCompute for ChunkedArray {
    fn cast(&self) -> Option<&dyn CastFn> {
        Some(self)
    }

    fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
        Some(self)
    }
@@ -32,3 +38,58 @@ impl ScalarAtFn for ChunkedArray {
        scalar_at(&self.chunk(chunk_index).unwrap(), chunk_offset)
    }
}

impl CastFn for ChunkedArray {
    fn cast(&self, dtype: &DType) -> VortexResult<Array> {
        let mut cast_chunks = Vec::new();
        for chunk in self.chunks() {
            cast_chunks.push(try_cast(&chunk, dtype)?);
        }

        Ok(ChunkedArray::try_new(cast_chunks, dtype.clone())?.into_array())
    }
}

#[cfg(test)]
mod test {
    use vortex_dtype::{DType, Nullability, PType};

    use crate::array::chunked::ChunkedArray;
    use crate::array::primitive::PrimitiveArray;
    use crate::compute::unary::try_cast;
    use crate::validity::Validity;
    use crate::{IntoArray, IntoArrayVariant};

    #[test]
    fn test_cast_chunked() {
        let arr0 = PrimitiveArray::from_vec(vec![0u32, 1], Validity::NonNullable).into_array();
        let arr1 = PrimitiveArray::from_vec(vec![2u32, 3], Validity::NonNullable).into_array();

        let chunked = ChunkedArray::try_new(
            vec![arr0, arr1],
            DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap()
        .into_array();

        // Two levels of chunking, just to be fancy.
        let root = ChunkedArray::try_new(
            vec![chunked],
            DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap()
        .into_array();

        assert_eq!(
            try_cast(
                &root,
                &DType::Primitive(PType::U64, Nullability::NonNullable)
            )
            .unwrap()
            .into_primitive()
            .unwrap()
            .into_maybe_null_slice::<u64>(),
            vec![0u64, 1, 2, 3],
        );
    }
}
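
An aside on the shape of the new CastFn impl above: the loop-and-push over self.chunks() could equivalently be written as a collect, assuming chunks() yields an iterator of Array and VortexResult is a plain std Result alias. A sketch as a free function against assumed external vortex_array:: paths, illustrative only and not the committed code:

use vortex_dtype::DType;
use vortex_error::VortexResult;

use vortex_array::array::chunked::ChunkedArray;
use vortex_array::compute::unary::try_cast;
use vortex_array::{Array, IntoArray};

// Cast every chunk to the target dtype, then rebuild the chunked array with
// that dtype; same behavior as ChunkedArray's CastFn::cast shown above.
fn cast_chunked(array: &ChunkedArray, dtype: &DType) -> VortexResult<Array> {
    let cast_chunks = array
        .chunks()
        .map(|chunk| try_cast(&chunk, dtype))
        .collect::<VortexResult<Vec<_>>>()?;
    Ok(ChunkedArray::try_new(cast_chunks, dtype.clone())?.into_array())
}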
25 changes: 11 additions & 14 deletions vortex-array/src/canonical.rs
@@ -289,13 +289,10 @@ fn varbin_to_arrow(varbin_array: VarBinArray) -> ArrayRef {
fn temporal_to_arrow(temporal_array: TemporalArray) -> ArrayRef {
    macro_rules! extract_temporal_values {
        ($values:expr, $prim:ty) => {{
            let temporal_values = try_cast(
                &temporal_array.temporal_values(),
                <$prim as NativePType>::PTYPE.into(),
            )
            .expect("values must cast to primitive type")
            .into_primitive()
            .expect("must be primitive array");
            let temporal_values = try_cast($values, <$prim as NativePType>::PTYPE.into())
                .expect("values must cast to primitive type")
                .into_primitive()
                .expect("must be primitive array");
            let len = temporal_values.len();
            let nulls = temporal_values
                .logical_validity()
@@ -312,41 +309,41 @@ fn temporal_to_arrow(temporal_array: TemporalArray) -> ArrayRef {
        TemporalMetadata::Date(time_unit) => match time_unit {
            TimeUnit::D => {
                let (scalars, nulls) =
                    extract_temporal_values!(temporal_array.temporal_values(), i32);
                    extract_temporal_values!(&temporal_array.temporal_values(), i32);
                Arc::new(Date32Array::new(scalars, nulls))
            }
            TimeUnit::Ms => {
                let (scalars, nulls) =
                    extract_temporal_values!(temporal_array.temporal_values(), i64);
                    extract_temporal_values!(&temporal_array.temporal_values(), i64);
                Arc::new(Date64Array::new(scalars, nulls))
            }
            _ => panic!("invalid time_unit {time_unit} for vortex.date"),
        },
        TemporalMetadata::Time(time_unit) => match time_unit {
            TimeUnit::S => {
                let (scalars, nulls) =
                    extract_temporal_values!(temporal_array.temporal_values(), i32);
                    extract_temporal_values!(&temporal_array.temporal_values(), i32);
                Arc::new(Time32SecondArray::new(scalars, nulls))
            }
            TimeUnit::Ms => {
                let (scalars, nulls) =
                    extract_temporal_values!(temporal_array.temporal_values(), i32);
                    extract_temporal_values!(&temporal_array.temporal_values(), i32);
                Arc::new(Time32MillisecondArray::new(scalars, nulls))
            }
            TimeUnit::Us => {
                let (scalars, nulls) =
                    extract_temporal_values!(temporal_array.temporal_values(), i64);
                    extract_temporal_values!(&temporal_array.temporal_values(), i64);
                Arc::new(Time64MicrosecondArray::new(scalars, nulls))
            }
            TimeUnit::Ns => {
                let (scalars, nulls) =
                    extract_temporal_values!(temporal_array.temporal_values(), i64);
                    extract_temporal_values!(&temporal_array.temporal_values(), i64);
                Arc::new(Time64NanosecondArray::new(scalars, nulls))
            }
            _ => panic!("invalid TimeUnit for Time32 array {time_unit}"),
        },
        TemporalMetadata::Timestamp(time_unit, _) => {
            let (scalars, nulls) = extract_temporal_values!(temporal_array.temporal_values(), i64);
            let (scalars, nulls) = extract_temporal_values!(&temporal_array.temporal_values(), i64);
            match time_unit {
                TimeUnit::Ns => Arc::new(TimestampNanosecondArray::new(scalars, nulls)),
                TimeUnit::Us => Arc::new(TimestampMicrosecondArray::new(scalars, nulls)),
