diff --git a/vortex-array/src/compute/flatten.rs b/vortex-array/src/compute/flatten.rs index 525ad4636f..6578e7dc56 100644 --- a/vortex-array/src/compute/flatten.rs +++ b/vortex-array/src/compute/flatten.rs @@ -49,6 +49,16 @@ pub fn flatten(array: &dyn Array) -> VortexResult { }) } +pub fn flatten_varbin(array: &dyn Array) -> VortexResult { + if let FlattenedArray::VarBin(vb) = flatten(array)? { + Ok(vb) + } else { + Err(VortexError::InvalidArgument( + format!("Cannot flatten array {} into varbin", array).into(), + )) + } +} + pub fn flatten_bool(array: &dyn Array) -> VortexResult { if let FlattenedArray::Bool(b) = flatten(array)? { Ok(b) diff --git a/vortex-dict/src/compute.rs b/vortex-dict/src/compute.rs index 96ae1c3de4..7bb4060f4f 100644 --- a/vortex-dict/src/compute.rs +++ b/vortex-dict/src/compute.rs @@ -1,11 +1,21 @@ +use std::sync::Arc; + +use vortex::array::primitive::PrimitiveArray; +use vortex::array::varbin::VarBinArray; +use vortex::compute::flatten::{flatten, flatten_primitive, FlattenFn, FlattenedArray}; use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; +use vortex::compute::take::take; use vortex::compute::ArrayCompute; use vortex::scalar::Scalar; -use vortex_error::VortexResult; +use vortex_error::{VortexError, VortexResult}; use crate::DictArray; impl ArrayCompute for DictArray { + fn flatten(&self) -> Option<&dyn FlattenFn> { + Some(self) + } + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -17,3 +27,69 @@ impl ScalarAtFn for DictArray { scalar_at(self.values(), dict_index) } } + +impl FlattenFn for DictArray { + fn flatten(&self) -> VortexResult { + let codes = flatten_primitive(self.codes())?; + let values = flatten(self.values())?; + + match values { + FlattenedArray::Primitive(v) => take(&v, &codes).map(|r| { + FlattenedArray::Primitive( + Arc::try_unwrap(r.into_any().downcast::().unwrap()) + .expect("Expected take on PrimitiveArray array to produce new array"), + ) + }), + FlattenedArray::VarBin(vb) => take(&vb, &codes).map(|r| { + FlattenedArray::VarBin( + Arc::try_unwrap(r.into_any().downcast::().unwrap()) + .expect("Expected take on VarBin array to produce new array"), + ) + }), + _ => Err(VortexError::InvalidArgument( + "Only VarBin and Primitive values array are supported".into(), + )), + } + } +} + +#[cfg(test)] +mod test { + use vortex::array::downcast::DowncastArrayBuiltin; + use vortex::array::primitive::PrimitiveArray; + use vortex::array::varbin::VarBinArray; + use vortex::array::Array; + use vortex::compute::flatten::{flatten_primitive, flatten_varbin}; + use vortex_schema::{DType, Nullability}; + + use crate::{dict_encode_typed_primitive, dict_encode_varbin, DictArray}; + + #[test] + fn flatten_nullable_primitive() { + let reference = + PrimitiveArray::from_iter(vec![Some(42), Some(-9), None, Some(42), None, Some(-9)]); + let (codes, values) = dict_encode_typed_primitive::(&reference); + let dict = DictArray::new(codes.to_array(), values.to_array()); + let flattened_dict = flatten_primitive(&dict).unwrap(); + assert_eq!(flattened_dict.buffer(), reference.buffer()); + } + + #[test] + fn flatten_nullable_varbin() { + let reference = VarBinArray::from_iter( + vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")], + DType::Utf8(Nullability::Nullable), + ); + let (codes, values) = dict_encode_varbin(&reference); + let dict = DictArray::new(codes.to_array(), values.to_array()); + let flattened_dict = flatten_varbin(&dict).unwrap(); + assert_eq!( + flattened_dict.offsets().as_primitive().buffer(), + reference.offsets().as_primitive().buffer() + ); + assert_eq!( + flattened_dict.bytes().as_primitive().buffer(), + reference.bytes().as_primitive().buffer() + ); + } +}