Skip to content

Commit

Permalink
get_posting_list now returns &[u8] and not &[u64]
Browse files Browse the repository at this point in the history
  • Loading branch information
BuildKite committed Jan 2, 2025
1 parent 9ff16eb commit 091357d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 27 deletions.
2 changes: 1 addition & 1 deletion rs/demo/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "demo"
version = "0.1.0"
edition = "2024"
edition = "2021"

[dependencies]
tonic.workspace = true
Expand Down
25 changes: 17 additions & 8 deletions rs/index/src/ivf/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use quantization::quantization::Quantizer;
use quantization::typing::VectorOps;
use utils::distance::l2::L2DistanceCalculator;
use utils::distance::l2::L2DistanceCalculatorImpl::StreamingSIMD;
use utils::mem::transmute_u8_to_slice;
use utils::DistanceCalculator;

use crate::index::Searchable;
Expand Down Expand Up @@ -71,18 +72,18 @@ impl<Q: Quantizer> Ivf<Q> {
query: &[f32],
context: &mut SearchContext,
) -> Vec<IdWithScore> {
if let Ok(list) = self.index_storage.get_posting_list(centroid) {
if let Ok(byte_slice) = self.index_storage.get_posting_list(centroid) {
let quantized_query = Q::QuantizedT::process_vector(query, &self.quantizer);
let mut results: Vec<IdWithScore> = Vec::new();
for &idx in list {
match self.vector_storage.get(idx as usize, context) {
for idx in transmute_u8_to_slice::<u64>(byte_slice).iter() {
match self.vector_storage.get(*idx as usize, context) {
Some(vector) => {
let distance =
self.quantizer
.distance(&quantized_query, vector, StreamingSIMD);
results.push(IdWithScore {
score: distance,
id: idx,
id: *idx,
});
}
None => {}
Expand Down Expand Up @@ -339,10 +340,18 @@ mod tests {
let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer);

assert_eq!(ivf.num_clusters, num_clusters);
let cluster_0 = ivf.index_storage.get_posting_list(0);
let cluster_1 = ivf.index_storage.get_posting_list(1);
assert!(cluster_0.map_or(false, |list| list.contains(&0)));
assert!(cluster_1.map_or(false, |list| list.contains(&2)));
let cluster_0 = transmute_u8_to_slice::<u64>(
ivf.index_storage
.get_posting_list(0)
.expect("Failed to get posting list"),
);
let cluster_1 = transmute_u8_to_slice::<u64>(
ivf.index_storage
.get_posting_list(1)
.expect("Failed to get posting list"),
);
assert!(cluster_0.contains(&0));
assert!(cluster_1.contains(&2));
}

#[test]
Expand Down
17 changes: 10 additions & 7 deletions rs/index/src/ivf/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ mod tests {

use quantization::noq::noq::{NoQuantizer, NoQuantizerWriter};
use tempdir::TempDir;
use utils::mem::transmute_u8_to_slice;
use utils::test_utils::generate_random_vector;

use super::*;
Expand Down Expand Up @@ -177,10 +178,12 @@ mod tests {
.posting_lists_mut()
.get(i as u32)
.expect("Failed to read vector from FileBackedAppendablePostingListStorage");
let read_vector = index
.index_storage
.get_posting_list(i)
.expect("Failed to read vector from FixedIndexFile");
let read_vector = transmute_u8_to_slice::<u64>(
index
.index_storage
.get_posting_list(i)
.expect("Failed to read vector from FixedIndexFile"),
);
for (val_ref, val_read) in ref_vector.iter().zip(read_vector.iter()) {
assert_eq!(val_ref, *val_read);
}
Expand Down Expand Up @@ -244,9 +247,9 @@ mod tests {

for i in 0..num_centroids {
// Assert that each posting list's size stays within the allowed bound relative to max_posting_list_size
let posting_list = index.index_storage.get_posting_list(i);
assert!(posting_list.is_ok());
let posting_list = posting_list.unwrap();
let posting_list_byte_arr = index.index_storage.get_posting_list(i);
assert!(posting_list_byte_arr.is_ok());
let posting_list = transmute_u8_to_slice::<u64>(posting_list_byte_arr.unwrap());

// It's possible that the posting list size is more than max_posting_list_size,
// but it should be less than 2x.
Expand Down
25 changes: 14 additions & 11 deletions rs/index/src/posting_list/combined_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl FixedIndexFile {
Ok(transmute_u8_to_slice::<f32>(slice))
}

pub fn get_posting_list(&self, index: usize) -> Result<&[u64]> {
pub fn get_posting_list(&self, index: usize) -> Result<&[u8]> {
if index >= self.header.num_clusters as usize {
return Err(anyhow!("Index out of bound"));
}
Expand All @@ -150,8 +150,7 @@ impl FixedIndexFile {
..metadata_offset + PL_METADATA_LEN * size_of::<u64>()];
let pl_offset = u64::from_le_bytes(slice.try_into()?) as usize + posting_list_start_offset;

let slice = &self.mmap[pl_offset..pl_offset + pl_len * size_of::<u64>()];
Ok(transmute_u8_to_slice::<u64>(slice))
Ok(&self.mmap[pl_offset..pl_offset + pl_len * size_of::<u64>()])
}

pub fn header(&self) -> &Header {
Expand Down Expand Up @@ -281,16 +280,20 @@ mod tests {
assert!(combined_file.get_centroid(2).is_err());

assert_eq!(
combined_file
.get_posting_list(0)
.expect("Failed to read posting_list"),
&posting_lists[0]
transmute_u8_to_slice::<u64>(
combined_file
.get_posting_list(0)
.expect("Failed to read posting_list")
),
posting_lists[0]
);
assert_eq!(
combined_file
.get_posting_list(1)
.expect("Failed to read posting_list"),
&posting_lists[1]
transmute_u8_to_slice::<u64>(
combined_file
.get_posting_list(1)
.expect("Failed to read posting_list")
),
posting_lists[1]
);
assert!(combined_file.get_posting_list(2).is_err());
}
Expand Down

0 comments on commit 091357d

Please sign in to comment.