Skip to content

Commit

Permalink
Add TakeFn for SparseArray (#206)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Apr 5, 2024
1 parent 27580f2 commit 8b8fded
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 12 deletions.
2 changes: 1 addition & 1 deletion vortex-array/src/array/sparse/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl EncodingCompression for SparseEncoding {
ctx.named("values")
.compress(sparse_array.values(), sparse_like.map(|sa| sa.values()))?,
sparse_array.len(),
sparse_array.fill_value.clone(),
sparse_array.fill_value().clone(),
)
.into_array())
}
Expand Down
148 changes: 147 additions & 1 deletion vortex-array/src/array/sparse/compute.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use arrow_buffer::BooleanBufferBuilder;
use itertools::Itertools;
use vortex_error::{vortex_bail, VortexResult};
Expand All @@ -9,10 +11,12 @@ use crate::array::{Array, ArrayRef};
use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn};
use crate::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray};
use crate::compute::scalar_at::{scalar_at, ScalarAtFn};
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::take::{take, TakeFn};
use crate::compute::ArrayCompute;
use crate::match_each_native_ptype;
use crate::ptype::NativePType;
use crate::scalar::Scalar;
use crate::{match_each_integer_ptype, match_each_native_ptype};

impl ArrayCompute for SparseArray {
fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> {
Expand All @@ -26,6 +30,10 @@ impl ArrayCompute for SparseArray {
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}

fn take(&self) -> Option<&dyn TakeFn> {
Some(self)
}
}

impl AsContiguousFn for SparseArray {
Expand Down Expand Up @@ -116,3 +124,141 @@ impl ScalarAtFn for SparseArray {
}
}
}

impl TakeFn for SparseArray {
    /// Take the elements at `indices` from this sparse array.
    ///
    /// The result is a new `SparseArray` of length `indices.len()` that reuses this array's
    /// fill value. For every requested index that matches a stored index, the output holds the
    /// corresponding stored value at the *position of that request within `indices`*; all other
    /// output positions are left unpatched.
    ///
    /// # Errors
    ///
    /// Propagates any error from flattening `indices`, from searching the stored indices, or
    /// from taking the stored indices/values.
    fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
        let flat_indices = flatten_primitive(indices)?;
        // if we are taking a lot of values we should build a hashmap
        let (positions, physical_indices) = if indices.len() > 512 {
            take_map(self, flat_indices)?
        } else {
            take_search_sorted(self, flat_indices)?
        };

        let taken_values = take(self.values(), &physical_indices)?;

        // `positions` records where in the requested `indices` each hit occurred, so misses
        // keep their own slots instead of shifting later hits towards the front.
        Ok(SparseArray::new(
            positions.into_array(),
            taken_values,
            indices.len(),
            self.fill_value().clone(),
        )
        .into_array())
    }
}

/// Resolve requested indices against the stored indices using a hash map.
///
/// Returns `(positions, physical_indices)`: for every requested index that is present among
/// `array.resolved_indices()`, `positions` holds its position within `indices` and
/// `physical_indices` holds the position of the matching entry within the stored
/// indices/values. Both outputs have the same length and `positions` is strictly increasing.
fn take_map(
    array: &SparseArray,
    indices: PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
    // Map from resolved (offset-free) index to its position among the stored entries.
    let indices_map: HashMap<u64, u64> = array
        .resolved_indices()
        .iter()
        .enumerate()
        .map(|(i, r)| (*r as u64, i as u64))
        .collect();
    let (positions, patch_indices): (Vec<u64>, Vec<u64>) =
        match_each_integer_ptype!(indices.ptype(), |$P| {
            indices.typed_data::<$P>()
                .iter()
                .map(|i| *i as u64)
                .enumerate()
                .filter_map(|(position, pi)| {
                    indices_map
                        .get(&pi)
                        .map(|physical_idx| (position as u64, *physical_idx))
                })
                .unzip()
        });
    Ok((
        PrimitiveArray::from(positions),
        PrimitiveArray::from(patch_indices),
    ))
}

/// Resolve requested indices against the stored indices using `search_sorted`.
///
/// Returns `(positions, physical_indices)` with the same meaning as [`take_map`]: the position
/// of every hit within `indices`, paired with the position of the matching entry within the
/// stored indices/values.
fn take_search_sorted(
    array: &SparseArray,
    indices: PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
    // Stored indices still include `indices_offset`, so shift the requests into that space.
    let adjusted_indices: Vec<usize> = match_each_integer_ptype!(indices.ptype(), |$P| {
        indices.typed_data::<$P>()
            .iter()
            .map(|i| *i as usize + array.indices_offset())
            .collect::<Vec<_>>()
    });

    // TODO(robert): Use binary search instead of search_sorted + take and index validation to avoid extra work
    let stored_len = array.indices().len() as u64;
    // (position within `indices`, adjusted requested index, candidate physical index)
    let mut candidates: Vec<(u64, usize, u64)> = Vec::with_capacity(adjusted_indices.len());
    for (position, logical_idx) in adjusted_indices.into_iter().enumerate() {
        let physical_idx =
            search_sorted(array.indices(), logical_idx, SearchSortedSide::Left).map(|s| s as u64)?;
        // An insertion point equal to the number of stored indices means the request is larger
        // than every stored index: it cannot be a hit, and taking it would be out of bounds.
        if physical_idx < stored_len {
            candidates.push((position as u64, logical_idx, physical_idx));
        }
    }

    let candidate_physical = PrimitiveArray::from(
        candidates
            .iter()
            .map(|(_, _, physical_idx)| *physical_idx)
            .collect::<Vec<u64>>(),
    );
    let stored_indices = flatten_primitive(&take(array.indices(), &candidate_physical)?)?;

    // A candidate is a real hit only if the stored index at its insertion point equals the
    // requested index.
    let (positions, physical_indices): (Vec<u64>, Vec<u64>) =
        match_each_integer_ptype!(stored_indices.ptype(), |$P| {
            stored_indices
                .typed_data::<$P>()
                .iter()
                .zip_eq(candidates.iter())
                .filter(|(stored_idx, (_, logical_idx, _))| **stored_idx as usize == *logical_idx)
                .map(|(_, (position, _, physical_idx))| (*position, *physical_idx))
                .unzip()
        });
    Ok((
        PrimitiveArray::from(positions),
        PrimitiveArray::from(physical_indices),
    ))
}

#[cfg(test)]
mod test {
    use vortex_schema::{DType, FloatWidth, Nullability};

    use crate::array::downcast::DowncastArrayBuiltin;
    use crate::array::primitive::PrimitiveArray;
    use crate::array::sparse::SparseArray;
    use crate::array::Array;
    use crate::compute::take::take;
    use crate::scalar::Scalar;

    /// Build the sparse array shared by the tests: length 100, four stored f64 values at
    /// indices 0, 37, 47 and 99, with a null fill value.
    fn sparse_fixture() -> SparseArray {
        let stored_indices = PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array();
        let stored_values = PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array();
        let fill = Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable));
        SparseArray::new(stored_indices, stored_values, 100, fill)
    }

    /// Taking only stored indices (with repeats) yields one patch per requested index.
    #[test]
    fn sparse_take() {
        let sparse = sparse_fixture();
        let requested = PrimitiveArray::from(vec![0, 47, 47, 0, 99]);
        let taken = take(&sparse, &requested).unwrap();
        let result = taken.as_sparse();

        assert_eq!(
            result.indices().as_primitive().typed_data::<u64>(),
            [0, 1, 2, 3, 4]
        );
        assert_eq!(
            result.values().as_primitive().typed_data::<f64>(),
            [1.23f64, 9.99, 9.99, 1.23, 3.5]
        );
    }

    /// Taking an index that is not stored yields a sparse array with no patches.
    #[test]
    fn nonexistent_take() {
        let sparse = sparse_fixture();
        let requested = PrimitiveArray::from(vec![69]);
        let taken = take(&sparse, &requested).unwrap();
        let result = taken.as_sparse();

        assert_eq!(result.indices().as_primitive().typed_data::<u64>(), []);
        assert_eq!(result.values().as_primitive().typed_data::<f64>(), []);
    }
}
19 changes: 9 additions & 10 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
use std::sync::{Arc, RwLock};

use itertools::Itertools;
use linkme::distributed_slice;
use vortex_error::{vortex_bail, VortexResult};
use vortex_schema::DType;

use crate::array::constant::ConstantArray;
use crate::array::{check_slice_bounds, Array, ArrayRef};
use crate::compress::EncodingCompression;
use crate::compute::cast::cast;
use crate::compute::flatten::flatten_primitive;
use crate::compute::scalar_at::scalar_at;
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::ArrayCompute;
use crate::encoding::{Encoding, EncodingId, EncodingRef, ENCODINGS};
use crate::formatter::{ArrayDisplay, ArrayFormatter};
use crate::ptype::PType;
use crate::scalar::Scalar;
use crate::serde::{ArraySerde, EncodingSerde};
use crate::stats::{Stats, StatsCompute, StatsSet};
use crate::validity::ArrayValidity;
use crate::validity::Validity;
use crate::{impl_array, ArrayWalker};
use crate::{impl_array, match_each_integer_ptype, ArrayWalker};

mod compress;
mod compute;
Expand Down Expand Up @@ -112,12 +109,14 @@ impl SparseArray {

/// Return indices as a vector of usize with the indices_offset applied.
///
/// Each stored index is cast to `usize` and has `indices_offset` subtracted from it.
///
/// # Panics
///
/// Panics if the indices cannot be flattened into a primitive array, because the result of
/// that conversion is unwrapped.
pub fn resolved_indices(&self) -> Vec<usize> {
flatten_primitive(cast(self.indices(), PType::U64.into()).unwrap().as_ref())
.unwrap()
.typed_data::<u64>()
.iter()
.map(|v| (*v as usize) - self.indices_offset)
.collect_vec()
let flat_indices = flatten_primitive(self.indices()).unwrap();
match_each_integer_ptype!(flat_indices.ptype(), |$P| {
flat_indices
.typed_data::<$P>()
.iter()
.map(|v| (*v as usize) - self.indices_offset)
.collect::<Vec<_>>()
})
}
}

Expand Down

0 comments on commit 8b8fded

Please sign in to comment.