Skip to content

Commit

Permalink
Optimize bitpacked take (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
lwwmanning authored Apr 5, 2024
1 parent 0496b54 commit 1bec4c2
Show file tree
Hide file tree
Showing 13 changed files with 459 additions and 127 deletions.
135 changes: 100 additions & 35 deletions vortex-array/src/array/sparse/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ impl TakeFn for SparseArray {
fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
let flat_indices = flatten_primitive(indices)?;
// if we are taking a lot of values we should build a hashmap
let (positions, physical_take_indices) = if indices.len() > 512 {
take_map(self, flat_indices)?
let (positions, physical_take_indices) = if indices.len() > 128 {
take_map(self, &flat_indices)?
} else {
take_search_sorted(self, flat_indices)?
take_search_sorted(self, &flat_indices)?
};

let taken_values = take(self.values(), &physical_take_indices)?;
Expand All @@ -149,7 +149,7 @@ impl TakeFn for SparseArray {

fn take_map(
array: &SparseArray,
indices: PrimitiveArray,
indices: &PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
let indices_map: HashMap<u64, u64> = array
.resolved_indices()
Expand All @@ -162,7 +162,7 @@ fn take_map(
.iter()
.map(|pi| *pi as u64)
.enumerate()
.filter_map(|(i, pi)| indices_map.get(&pi).copied().map(|phy_idx| (i as u64, phy_idx)))
.filter_map(|(i, pi)| indices_map.get(&pi).map(|phy_idx| (i as u64, phy_idx)))
.unzip()
});
Ok((
Expand All @@ -173,8 +173,9 @@ fn take_map(

fn take_search_sorted(
array: &SparseArray,
indices: PrimitiveArray,
indices: &PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
// adjust the input indices (to take) by the internal index offset of the array
let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| {
indices.typed_data::<$P>()
.iter()
Expand All @@ -183,25 +184,42 @@ fn take_search_sorted(
});

// TODO(robert): Use binary search instead of search_sorted + take and index validation to avoid extra work
let physical_indices = PrimitiveArray::from(
adjusted_indices
.iter()
.map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64))
.collect::<VortexResult<Vec<_>>>()?,
);
// search_sorted for the adjusted indices (need to validate that they are an exact match still)
let physical_indices = adjusted_indices
.iter()
.map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64))
.collect::<VortexResult<Vec<_>>>()?;

// filter out indices that are out of bounds, which will cause the take to fail
let (adjusted_indices, physical_indices): (Vec<usize>, Vec<u64>) = adjusted_indices
.iter()
.zip_eq(physical_indices)
.filter(|(_, phys_idx)| *phys_idx < array.indices().len() as u64)
.unzip();

let physical_indices = PrimitiveArray::from(physical_indices);
let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?;
let (positions, patch_indices): (Vec<u64>, Vec<u64>) = match_each_integer_ptype!(taken_indices.ptype(), |$P| {
let exact_matches: Vec<bool> = match_each_integer_ptype!(taken_indices.ptype(), |$P| {
taken_indices
.typed_data::<$P>()
.iter()
.copied()
.enumerate()
.zip_eq(adjusted_indices)
.zip_eq(physical_indices.typed_data::<u64>())
.filter(|(((_, taken_idx), orig_idx), _)| *taken_idx as usize == *orig_idx)
.map(|(((i, _), _), physical_idx)| (i as u64, *physical_idx))
.unzip()
.map(|(taken_idx, adj_idx)| *taken_idx as usize == adj_idx)
.collect()
});
let (positions, patch_indices): (Vec<u64>, Vec<u64>) = physical_indices
.typed_data::<u64>()
.iter()
.enumerate()
.filter_map(|(i, phy_idx)| {
// search_sorted != binary search, so we need to filter out indices that weren't found
if exact_matches[i] {
Some((i as u64, *phy_idx))
} else {
None
}
})
.unzip();
Ok((
PrimitiveArray::from(positions),
PrimitiveArray::from(patch_indices),
Expand All @@ -210,23 +228,30 @@ fn take_search_sorted(

#[cfg(test)]
mod test {
use itertools::Itertools;
use vortex_schema::{DType, FloatWidth, Nullability};

use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::primitive::PrimitiveArray;
use crate::array::sparse::compute::take_map;
use crate::array::sparse::SparseArray;
use crate::array::Array;
use crate::compute::as_contiguous::as_contiguous;
use crate::compute::take::take;
use crate::scalar::Scalar;

#[test]
fn sparse_take() {
let sparse = SparseArray::new(
fn sparse_array() -> SparseArray {
SparseArray::new(
PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(),
PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(),
100,
Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
);
)
}

#[test]
fn sparse_take() {
let sparse = sparse_array();
let taken = take(&sparse, &PrimitiveArray::from(vec![0, 47, 47, 0, 99])).unwrap();
assert_eq!(
taken
Expand All @@ -248,12 +273,7 @@ mod test {

#[test]
fn nonexistent_take() {
let sparse = SparseArray::new(
PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(),
PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(),
100,
Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
);
let sparse = sparse_array();
let taken = take(&sparse, &PrimitiveArray::from(vec![69])).unwrap();
assert_eq!(
taken
Expand All @@ -275,12 +295,7 @@ mod test {

#[test]
fn ordered_take() {
let sparse = SparseArray::new(
PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(),
PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array(),
100,
Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
);
let sparse = sparse_array();
let taken = take(&sparse, &PrimitiveArray::from(vec![69, 37])).unwrap();
assert_eq!(
taken
Expand All @@ -300,4 +315,54 @@ mod test {
);
assert_eq!(taken.len(), 2);
}

#[test]
fn take_slices_and_reassemble() {
let sparse = sparse_array();
let indices: PrimitiveArray = (0u64..10).collect_vec().into();
let slices = (0..10)
.map(|i| sparse.slice(i * 10, (i + 1) * 10).unwrap())
.collect_vec();

let taken = slices
.iter()
.map(|s| take(s, &indices).unwrap())
.collect_vec();
for i in [1, 2, 5, 6, 7, 8] {
assert_eq!(taken[i].as_sparse().indices().len(), 0);
}
for i in [0, 3, 4, 9] {
assert_eq!(taken[i].as_sparse().indices().len(), 1);
}

let contiguous = as_contiguous(&taken).unwrap();
assert_eq!(
contiguous
.as_sparse()
.indices()
.as_primitive()
.typed_data::<u64>(),
[0u64, 7, 7, 9] // relative offsets
);
assert_eq!(
contiguous
.as_sparse()
.values()
.as_primitive()
.typed_data::<f64>(),
sparse.values().as_primitive().typed_data()
);
}

#[test]
fn test_take_map() {
let sparse = sparse_array();
let indices = PrimitiveArray::from((0u64..100).collect_vec());
let (positions, patch_indices) = take_map(&sparse, &indices).unwrap();
assert_eq!(
positions.typed_data::<u64>(),
sparse.indices().as_primitive().typed_data()
);
assert_eq!(patch_indices.typed_data::<u64>(), [0u64, 1, 2, 3]);
}
}
3 changes: 3 additions & 0 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ impl SparseArray {

// TODO(ngates): replace this with a binary search that tells us if we get an exact match.
let idx = search_sorted(self.indices(), true_index, SearchSortedSide::Left)?;
if idx >= self.indices().len() {
return Ok(None);
}

// If the value at this index is equal to the true index, then it exists in the
// indices array.
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ impl CompressCtx {
ValidityView::Valid(_) | ValidityView::Invalid(_) => {
Ok(Some(validity.to_owned_view()))
}
ValidityView::Array(a) => Ok(Some(Validity::array(self.compress(a, None)?))),
ValidityView::Array(a) => Ok(Some(Validity::array(self.compress(a, None)?)?)),
}
} else {
Ok(None)
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ impl<'a> ReadCtx<'a> {
[1u8] => Ok(Some(Validity::Invalid(self.read_usize()?))),
[2u8] => Ok(Some(Validity::array(
self.with_schema(&Validity::DTYPE).read()?,
))),
)?)),
_ => panic!("Invalid validity tag"),
}
} else {
Expand Down
6 changes: 3 additions & 3 deletions vortex-array/src/validity/owned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ pub enum Validity {
impl Validity {
pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);

pub fn array(array: ArrayRef) -> Self {
pub fn array(array: ArrayRef) -> VortexResult<Self> {
if !matches!(array.dtype(), &Validity::DTYPE) {
panic!("Validity array must be of type bool");
vortex_bail!("Validity array must be of type bool");
}
Self::Array(array)
Ok(Self::Array(array))
}

pub fn try_from_logical(
Expand Down
4 changes: 2 additions & 2 deletions vortex-dict/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ pub fn dict_encode_typed_primitive<T: NativePType>(
validity.push(false);
validity.extend(vec![true; values.len() - 1]);

Some(Validity::array(BoolArray::from(validity).into_array()))
Some(Validity::array(BoolArray::from(validity).into_array()).unwrap())
} else {
None
};
Expand Down Expand Up @@ -222,7 +222,7 @@ where
validity.push(false);
validity.extend(vec![true; offsets.len() - 2]);

Some(Validity::array(BoolArray::from(validity).into_array()))
Some(Validity::array(BoolArray::from(validity).into_array()).unwrap())
} else {
None
};
Expand Down
4 changes: 4 additions & 0 deletions vortex-fastlanes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ simplelog = { workspace = true }
[[bench]]
name = "bitpacking"
harness = false

[[bench]]
name = "bitpacking_take"
harness = false
3 changes: 3 additions & 0 deletions vortex-fastlanes/benches/bitpacking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ fn pack_unpack(c: &mut Criterion) {
});

let packed = bitpack_primitive(&values, bits);
let unpacked = unpack_primitive::<u32>(&packed, bits, 0, values.len());
assert_eq!(unpacked, values);

c.bench_function("unpack_1M", |b| {
b.iter(|| black_box(unpack_primitive::<u32>(&packed, bits, 0, values.len())));
});
Expand Down
Loading

0 comments on commit 1bec4c2

Please sign in to comment.