diff --git a/vortex-ree/src/compress.rs b/vortex-ree/src/compress.rs index 9dc9d2fd18..21d143eae2 100644 --- a/vortex-ree/src/compress.rs +++ b/vortex-ree/src/compress.rs @@ -1,11 +1,13 @@ +use std::cmp::min; + use itertools::Itertools; +use num_traits::{AsPrimitive, FromPrimitive}; use vortex::array::downcast::DowncastArrayBuiltin; use vortex::array::primitive::{PrimitiveArray, PrimitiveEncoding}; use vortex::array::{Array, ArrayRef, Encoding}; use vortex::compress::{CompressConfig, CompressCtx, EncodingCompression}; -use vortex::compute::cast::cast; -use vortex::compute::flatten::flatten_primitive; +use vortex::match_each_integer_ptype; use vortex::ptype::{match_each_native_ptype, NativePType}; use vortex::stats::Stat; use vortex::validity::{ArrayValidity, Validity}; @@ -115,20 +117,41 @@ pub fn ree_decode( ends: &PrimitiveArray, values: &PrimitiveArray, validity: Option, + offset: usize, + length: usize, ) -> VortexResult { - // TODO(ngates): switch over ends without necessarily casting match_each_native_ptype!(values.ptype(), |$P| { - Ok(PrimitiveArray::from_nullable(ree_decode_primitive( - flatten_primitive(cast(ends, PType::U64.into())?.as_ref())?.typed_data(), - values.typed_data::<$P>(), - ), validity)) + match_each_integer_ptype!(ends.ptype(), |$E| { + Ok(PrimitiveArray::from_nullable(ree_decode_primitive( + ends.typed_data::<$E>(), + values.typed_data::<$P>(), + offset, + length, + ), validity)) + }) }) } -pub fn ree_decode_primitive(run_ends: &[u64], values: &[T]) -> Vec { - let mut decoded = Vec::with_capacity(run_ends.last().map(|x| *x as usize).unwrap_or(0_usize)); - for (&end, &value) in run_ends.iter().zip_eq(values) { - decoded.extend(std::iter::repeat(value).take(end as usize - decoded.len())); +pub fn ree_decode_primitive< + E: NativePType + AsPrimitive + FromPrimitive + Ord, + T: NativePType, +>( + run_ends: &[E], + values: &[T], + offset: usize, + length: usize, +) -> Vec { + let offset_e = ::from_usize(offset).unwrap(); + let length_e = ::from_usize(length).unwrap(); + let trimmed_ends = run_ends + .iter() + .map(|v| *v - offset_e) + .map(|v| min(v, length_e)) + .take_while(|v| *v <= length_e); + + let mut decoded = Vec::with_capacity(length); + for (end, &value) in trimmed_ends.zip_eq(values) { + decoded.extend(std::iter::repeat(value).take(end.as_() - decoded.len())); } decoded } @@ -156,7 +179,7 @@ mod test { fn decode() { let ends = PrimitiveArray::from(vec![2, 5, 10]); let values = PrimitiveArray::from(vec![1i32, 2, 3]); - let decoded = ree_decode(&ends, &values, None).unwrap(); + let decoded = ree_decode(&ends, &values, None, 0, 0).unwrap(); assert_eq!( decoded.typed_data::(), @@ -183,6 +206,8 @@ mod test { arr.ends().as_primitive(), arr.values().as_primitive(), arr.validity(), + 0, + 0, ) .unwrap(); diff --git a/vortex-ree/src/compute.rs b/vortex-ree/src/compute.rs index 0a7d11b76d..f5890a3883 100644 --- a/vortex-ree/src/compute.rs +++ b/vortex-ree/src/compute.rs @@ -1,11 +1,7 @@ -use std::cmp::min; -use vortex::array::primitive::PrimitiveArray; use vortex::array::Array; -use vortex::compute::cast::cast; -use vortex::compute::flatten::{flatten, flatten_primitive, FlattenFn, FlattenedArray}; +use vortex::compute::flatten::{flatten, FlattenFn, FlattenedArray}; use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; use vortex::compute::ArrayCompute; -use vortex::ptype::PType; use vortex::scalar::Scalar; use vortex::validity::ArrayValidity; use vortex_error::{VortexError, VortexResult}; @@ -25,19 +21,17 @@ impl ArrayCompute for REEArray { impl FlattenFn for REEArray { fn flatten(&self) -> VortexResult { - let ends: PrimitiveArray = - flatten_primitive(cast(self.ends(), PType::U64.into())?.as_ref())? - .typed_data::() - .iter() - .map(|v| v - self.offset() as u64) - .map(|v| min(v, self.len() as u64)) - .take_while(|v| *v <= (self.len() as u64)) - .collect::>() - .into(); + let ends = flatten(self.ends())?; + let FlattenedArray::Primitive(pends) = ends else { + return Err(VortexError::InvalidArgument( + "REE Ends array didn't flatten to primitive".into(), + )); + }; let values = flatten(self.values())?; if let FlattenedArray::Primitive(pvalues) = values { - ree_decode(&ends, &pvalues, self.validity()).map(FlattenedArray::Primitive) + ree_decode(&pends, &pvalues, self.validity(), self.offset(), self.len()) + .map(FlattenedArray::Primitive) } else { Err(VortexError::InvalidArgument( "Cannot yet flatten non-primitive REE array".into(),