From 79a063072dd5a955525918503cb67b4f66bd4b8b Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 5 Apr 2024 17:34:26 +0100 Subject: [PATCH] SparseArray TakeFn returns results in the requested order (#212) --- vortex-array/src/array/sparse/compute.rs | 83 +++++++++++++++++------- 1 file changed, 61 insertions(+), 22 deletions(-) diff --git a/vortex-array/src/array/sparse/compute.rs b/vortex-array/src/array/sparse/compute.rs index d484d8af7d..cfdea6affd 100644 --- a/vortex-array/src/array/sparse/compute.rs +++ b/vortex-array/src/array/sparse/compute.rs @@ -129,17 +129,16 @@ impl TakeFn for SparseArray { fn take(&self, indices: &dyn Array) -> VortexResult { let flat_indices = flatten_primitive(indices)?; // if we are taking a lot of values we should build a hashmap - let exact_taken_indices = if indices.len() > 512 { + let (positions, physical_take_indices) = if indices.len() > 512 { take_map(self, flat_indices)? } else { take_search_sorted(self, flat_indices)? }; - let taken_values = take(self.values(), &exact_taken_indices)?; + let taken_values = take(self.values(), &physical_take_indices)?; Ok(SparseArray::new( - PrimitiveArray::from((0u64..exact_taken_indices.len() as u64).collect::>()) - .into_array(), + positions.into_array(), taken_values, indices.len(), self.fill_value().clone(), @@ -148,27 +147,34 @@ impl TakeFn for SparseArray { } } -fn take_map(array: &SparseArray, indices: PrimitiveArray) -> VortexResult { +fn take_map( + array: &SparseArray, + indices: PrimitiveArray, +) -> VortexResult<(PrimitiveArray, PrimitiveArray)> { let indices_map: HashMap = array .resolved_indices() .iter() .enumerate() .map(|(i, r)| (*r as u64, i as u64)) .collect(); - let patch_indices: Vec = match_each_integer_ptype!(indices.ptype(), |$P| { + let (positions, patch_indices): (Vec, Vec) = match_each_integer_ptype!(indices.ptype(), |$P| { indices.typed_data::<$P>() .iter() - .map(|i| *i as u64) - .filter_map(|pi| indices_map.get(&pi).copied()) - .collect::>() + .map(|pi| *pi as u64) + .enumerate() + .filter_map(|(i, pi)| indices_map.get(&pi).copied().map(|phy_idx| (i as u64, phy_idx))) + .unzip() }); - Ok(PrimitiveArray::from(patch_indices)) + Ok(( + PrimitiveArray::from(positions), + PrimitiveArray::from(patch_indices), + )) } fn take_search_sorted( array: &SparseArray, indices: PrimitiveArray, -) -> VortexResult { +) -> VortexResult<(PrimitiveArray, PrimitiveArray)> { let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| { indices.typed_data::<$P>() .iter() @@ -184,17 +190,22 @@ fn take_search_sorted( .collect::>>()?, ); let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?; - match_each_integer_ptype!(taken_indices.ptype(), |$P| { - Ok(PrimitiveArray::from(taken_indices - .typed_data::<$P>() - .iter() - .copied() - .zip_eq(adjusted_indices) - .zip_eq(physical_indices.typed_data::()) - .filter(|((taken_idx, orig_idx), _)| *taken_idx as usize == *orig_idx) - .map(|(_, physical_idx)| *physical_idx) - .collect::>())) - }) + let (positions, patch_indices): (Vec, Vec) = match_each_integer_ptype!(taken_indices.ptype(), |$P| { + taken_indices + .typed_data::<$P>() + .iter() + .copied() + .enumerate() + .zip_eq(adjusted_indices) + .zip_eq(physical_indices.typed_data::()) + .filter(|(((_, taken_idx), orig_idx), _)| *taken_idx as usize == *orig_idx) + .map(|(((i, _), _), physical_idx)| (i as u64, *physical_idx)) + .unzip() + }); + Ok(( + PrimitiveArray::from(positions), + PrimitiveArray::from(patch_indices), + )) } #[cfg(test)] @@ -261,4 +272,32 @@ mod test { [] ); } + + #[test] + fn ordered_take() { + let sparse = SparseArray::new( + PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(), + PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(), + 100, + Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)), + ); + let taken = take(&sparse, &PrimitiveArray::from(vec![69, 37])).unwrap(); + assert_eq!( + taken + .as_sparse() + .indices() + .as_primitive() + .typed_data::(), + [1] + ); + assert_eq!( + taken + .as_sparse() + .values() + .as_primitive() + .typed_data::(), + [0.47f64] + ); + assert_eq!(taken.len(), 2); + } }