Skip to content

Commit

Permalink
REE flattening pushes slicing to decompression
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 committed Mar 26, 2024
1 parent 303a168 commit ca1b2bc
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
49 changes: 37 additions & 12 deletions vortex-ree/src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::cmp::min;

use itertools::Itertools;
use num_traits::{AsPrimitive, FromPrimitive};

use vortex::array::downcast::DowncastArrayBuiltin;
use vortex::array::primitive::{PrimitiveArray, PrimitiveEncoding};
use vortex::array::{Array, ArrayRef, Encoding};
use vortex::compress::{CompressConfig, CompressCtx, EncodingCompression};
use vortex::compute::cast::cast;
use vortex::compute::flatten::flatten_primitive;
use vortex::match_each_integer_ptype;
use vortex::ptype::{match_each_native_ptype, NativePType};
use vortex::stats::Stat;
use vortex::validity::{ArrayValidity, Validity};
Expand Down Expand Up @@ -115,20 +117,41 @@ pub fn ree_decode(
ends: &PrimitiveArray,
values: &PrimitiveArray,
validity: Option<Validity>,
offset: usize,
length: usize,
) -> VortexResult<PrimitiveArray> {
// TODO(ngates): switch over ends without necessarily casting
match_each_native_ptype!(values.ptype(), |$P| {
Ok(PrimitiveArray::from_nullable(ree_decode_primitive(
flatten_primitive(cast(ends, PType::U64.into())?.as_ref())?.typed_data(),
values.typed_data::<$P>(),
), validity))
match_each_integer_ptype!(ends.ptype(), |$E| {
Ok(PrimitiveArray::from_nullable(ree_decode_primitive(
ends.typed_data::<$E>(),
values.typed_data::<$P>(),
offset,
length,
), validity))
})
})
}

pub fn ree_decode_primitive<T: NativePType>(run_ends: &[u64], values: &[T]) -> Vec<T> {
let mut decoded = Vec::with_capacity(run_ends.last().map(|x| *x as usize).unwrap_or(0_usize));
for (&end, &value) in run_ends.iter().zip_eq(values) {
decoded.extend(std::iter::repeat(value).take(end as usize - decoded.len()));
pub fn ree_decode_primitive<
E: NativePType + AsPrimitive<usize> + FromPrimitive + Ord,
T: NativePType,
>(
run_ends: &[E],
values: &[T],
offset: usize,
length: usize,
) -> Vec<T> {
let offset_e = <E as FromPrimitive>::from_usize(offset).unwrap();
let length_e = <E as FromPrimitive>::from_usize(length).unwrap();
let trimmed_ends = run_ends
.iter()
.map(|v| *v - offset_e)
.map(|v| min(v, length_e))
.take_while(|v| *v <= length_e);

let mut decoded = Vec::with_capacity(length);
for (end, &value) in trimmed_ends.zip_eq(values) {
decoded.extend(std::iter::repeat(value).take(end.as_() - decoded.len()));
}
decoded
}
Expand Down Expand Up @@ -156,7 +179,7 @@ mod test {
fn decode() {
let ends = PrimitiveArray::from(vec![2, 5, 10]);
let values = PrimitiveArray::from(vec![1i32, 2, 3]);
let decoded = ree_decode(&ends, &values, None).unwrap();
let decoded = ree_decode(&ends, &values, None, 0, 0).unwrap();

assert_eq!(
decoded.typed_data::<i32>(),
Expand All @@ -183,6 +206,8 @@ mod test {
arr.ends().as_primitive(),
arr.values().as_primitive(),
arr.validity(),
0,
0,
)
.unwrap();

Expand Down
24 changes: 9 additions & 15 deletions vortex-ree/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
use std::cmp::min;
use vortex::array::primitive::PrimitiveArray;
use vortex::array::Array;
use vortex::compute::cast::cast;
use vortex::compute::flatten::{flatten, flatten_primitive, FlattenFn, FlattenedArray};
use vortex::compute::flatten::{flatten, FlattenFn, FlattenedArray};
use vortex::compute::scalar_at::{scalar_at, ScalarAtFn};
use vortex::compute::ArrayCompute;
use vortex::ptype::PType;
use vortex::scalar::Scalar;
use vortex::validity::ArrayValidity;
use vortex_error::{VortexError, VortexResult};
Expand All @@ -25,19 +21,17 @@ impl ArrayCompute for REEArray {

impl FlattenFn for REEArray {
fn flatten(&self) -> VortexResult<FlattenedArray> {
let ends: PrimitiveArray =
flatten_primitive(cast(self.ends(), PType::U64.into())?.as_ref())?
.typed_data::<u64>()
.iter()
.map(|v| v - self.offset() as u64)
.map(|v| min(v, self.len() as u64))
.take_while(|v| *v <= (self.len() as u64))
.collect::<Vec<u64>>()
.into();
let ends = flatten(self.ends())?;
let FlattenedArray::Primitive(pends) = ends else {
return Err(VortexError::InvalidArgument(
"REE Ends array didn't flatten to primitive".into(),
));
};

let values = flatten(self.values())?;
if let FlattenedArray::Primitive(pvalues) = values {
ree_decode(&ends, &pvalues, self.validity()).map(FlattenedArray::Primitive)
ree_decode(&pends, &pvalues, self.validity(), self.offset(), self.len())
.map(FlattenedArray::Primitive)
} else {
Err(VortexError::InvalidArgument(
"Cannot yet flatten non-primitive REE array".into(),
Expand Down

0 comments on commit ca1b2bc

Please sign in to comment.