Skip to content

Commit

Permalink
Merge branch 'develop' into jc/scalar-subtract
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcasale committed May 1, 2024
2 parents be37995 + eabb8e6 commit d2e0a08
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 38 deletions.
6 changes: 6 additions & 0 deletions vortex-array/src/array/chunked/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ use crate::array::chunked::ChunkedArray;
use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn};
use crate::compute::scalar_at::{scalar_at, ScalarAtFn};
use crate::compute::scalar_subtract::SubtractScalarFn;
use crate::compute::slice::SliceFn;
use crate::compute::take::TakeFn;
use crate::compute::ArrayCompute;
use crate::{Array, OwnedArray, ToStatic};

mod slice;
mod take;

impl ArrayCompute for ChunkedArray<'_> {
Expand All @@ -20,6 +22,10 @@ impl ArrayCompute for ChunkedArray<'_> {
Some(self)
}

fn slice(&self) -> Option<&dyn SliceFn> {
Some(self)
}

fn take(&self) -> Option<&dyn TakeFn> {
Some(self)
}
Expand Down
40 changes: 40 additions & 0 deletions vortex-array/src/array/chunked/compute/slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use vortex_error::VortexResult;

use crate::array::chunked::ChunkedArray;
use crate::compute::slice::{slice, SliceFn};
use crate::{ArrayDType, IntoArray, OwnedArray};

impl SliceFn for ChunkedArray<'_> {
fn slice(&self, start: usize, stop: usize) -> VortexResult<OwnedArray> {
let (offset_chunk, offset_in_first_chunk) = self.find_chunk_idx(start);
let (length_chunk, length_in_last_chunk) = self.find_chunk_idx(stop);

if length_chunk == offset_chunk {
if let Some(chunk) = self.chunk(offset_chunk) {
return ChunkedArray::try_new(
vec![slice(&chunk, offset_in_first_chunk, length_in_last_chunk)?],
self.dtype().clone(),
)
.map(|a| a.into_array());
}
}

let mut chunks = (offset_chunk..length_chunk + 1)
.map(|i| {
self.chunk(i)
.expect("find_chunk_idx returned an incorrect index")
})
.collect::<Vec<_>>();
if let Some(c) = chunks.first_mut() {
*c = slice(c, offset_in_first_chunk, c.len())?;
}

if length_in_last_chunk == 0 {
chunks.pop();
} else if let Some(c) = chunks.last_mut() {
*c = slice(c, 0, length_in_last_chunk)?;
}

ChunkedArray::try_new(chunks, self.dtype().clone()).map(|a| a.into_array())
}
}
82 changes: 44 additions & 38 deletions vortex-array/src/array/chunked/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ mod test {

use crate::array::chunked::{ChunkedArray, OwnedChunkedArray};
use crate::compute::scalar_subtract::subtract_scalar;
use crate::compute::slice::slice;
use crate::{Array, IntoArray, ToArray};

#[allow(dead_code)]
fn chunked_array() -> OwnedChunkedArray {
ChunkedArray::try_new(
vec![
Expand All @@ -175,7 +175,6 @@ mod test {
.unwrap()
}

#[allow(dead_code)]
fn assert_equal_slices<T: NativePType>(arr: Array, slice: &[T]) {
let mut values = Vec::with_capacity(arr.len());
ChunkedArray::try_from(arr)
Expand All @@ -187,58 +186,65 @@ mod test {
}

#[test]
fn test_scalar_subtract() {
let chunk1 = vec![1.0f64, 2.0, 3.0].into_array();
let chunk2 = vec![4.0f64, 5.0, 6.0].into_array();
let to_subtract = -1f64;
pub fn slice_middle() {
assert_equal_slices(slice(chunked_array().array(), 2, 5).unwrap(), &[3u64, 4, 5])
}

#[test]
pub fn slice_begin() {
assert_equal_slices(slice(chunked_array().array(), 1, 3).unwrap(), &[2u64, 3]);
}

#[test]
pub fn slice_aligned() {
assert_equal_slices(slice(chunked_array().array(), 3, 6).unwrap(), &[4u64, 5, 6]);
}

#[test]
pub fn slice_many_aligned() {
assert_equal_slices(
slice(chunked_array().array(), 0, 6).unwrap(),
&[1u64, 2, 3, 4, 5, 6],
);
}

let chunked = ChunkedArray::from_iter(vec![chunk1, chunk2]);
#[test]
pub fn slice_end() {
assert_equal_slices(slice(chunked_array().array(), 7, 8).unwrap(), &[8u64]);
}

#[test]
fn test_scalar_subtract() {
let chunked = chunked_array();
let to_subtract = 1u64;
let array = subtract_scalar(&chunked.to_array(), &to_subtract.into()).unwrap();

let chunked = ChunkedArray::try_from(array).unwrap();
let mut chunks_out = chunked.chunks();

let results = chunks_out
.next()
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<u64>()
.to_vec();
assert_eq!(results, &[0u64, 1, 2]);
let results = chunks_out
.next()
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<f64>()
.typed_data::<u64>()
.to_vec();
assert_eq!(results, &[2.0f64, 3.0, 4.0]);
assert_eq!(results, &[3u64, 4, 5]);
let results = chunks_out
.next()
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<f64>()
.typed_data::<u64>()
.to_vec();
assert_eq!(results, &[5.0f64, 6.0, 7.0]);
}

// FIXME(ngates): bring back when slicing is a compute function.
// #[test]
// pub fn slice_middle() {
// assert_equal_slices(chunked_array().slice(2, 5).unwrap(), &[3u64, 4, 5])
// }
//
// #[test]
// pub fn slice_begin() {
// assert_equal_slices(chunked_array().slice(1, 3).unwrap(), &[2u64, 3]);
// }
//
// #[test]
// pub fn slice_aligned() {
// assert_equal_slices(chunked_array().slice(3, 6).unwrap(), &[4u64, 5, 6]);
// }
//
// #[test]
// pub fn slice_many_aligned() {
// assert_equal_slices(chunked_array().slice(0, 6).unwrap(), &[1u64, 2, 3, 4, 5, 6]);
// }
//
// #[test]
// pub fn slice_end() {
// assert_equal_slices(chunked_array().slice(7, 8).unwrap(), &[8u64]);
// }
assert_eq!(results, &[6u64, 7, 8]);
}
}

0 comments on commit d2e0a08

Please sign in to comment.