From 9df51bc9ee15180b579a2e4dc26f4c0d84d0e665 Mon Sep 17 00:00:00 2001 From: BuildKite Date: Tue, 31 Dec 2024 22:00:34 +0100 Subject: [PATCH] get_posting_list now returns iterator and not &[u64] --- rs/demo/Cargo.toml | 2 +- rs/index/src/ivf/index.rs | 25 +++++++++++++++------- rs/index/src/ivf/reader.rs | 17 +++++++++------ rs/index/src/posting_list/combined_file.rs | 25 ++++++++++++---------- 4 files changed, 42 insertions(+), 27 deletions(-) diff --git a/rs/demo/Cargo.toml b/rs/demo/Cargo.toml index 1423047e..ebbc7840 100644 --- a/rs/demo/Cargo.toml +++ b/rs/demo/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "demo" version = "0.1.0" -edition = "2024" +edition = "2021" [dependencies] tonic.workspace = true diff --git a/rs/index/src/ivf/index.rs b/rs/index/src/ivf/index.rs index 67d0a9cf..2edf9ba7 100644 --- a/rs/index/src/ivf/index.rs +++ b/rs/index/src/ivf/index.rs @@ -5,6 +5,7 @@ use quantization::quantization::Quantizer; use quantization::typing::VectorOps; use utils::distance::l2::L2DistanceCalculator; use utils::distance::l2::L2DistanceCalculatorImpl::StreamingSIMD; +use utils::mem::transmute_u8_to_slice; use utils::DistanceCalculator; use crate::index::Searchable; @@ -71,18 +72,18 @@ impl Ivf { query: &[f32], context: &mut SearchContext, ) -> Vec { - if let Ok(list) = self.index_storage.get_posting_list(centroid) { + if let Ok(byte_slice) = self.index_storage.get_posting_list(centroid) { let quantized_query = Q::QuantizedT::process_vector(query, &self.quantizer); let mut results: Vec = Vec::new(); - for &idx in list { - match self.vector_storage.get(idx as usize, context) { + for idx in transmute_u8_to_slice::(byte_slice).iter() { + match self.vector_storage.get(*idx as usize, context) { Some(vector) => { let distance = self.quantizer .distance(&quantized_query, vector, StreamingSIMD); results.push(IdWithScore { score: distance, - id: idx, + id: *idx, }); } None => {} @@ -339,10 +340,18 @@ mod tests { let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer); assert_eq!(ivf.num_clusters, num_clusters); - let cluster_0 = ivf.index_storage.get_posting_list(0); - let cluster_1 = ivf.index_storage.get_posting_list(1); - assert!(cluster_0.map_or(false, |list| list.contains(&0))); - assert!(cluster_1.map_or(false, |list| list.contains(&2))); + let cluster_0 = transmute_u8_to_slice::( + ivf.index_storage + .get_posting_list(0) + .expect("Failed to get posting list"), + ); + let cluster_1 = transmute_u8_to_slice::( + ivf.index_storage + .get_posting_list(1) + .expect("Failed to get posting list"), + ); + assert!(cluster_0.contains(&0)); + assert!(cluster_1.contains(&2)); } #[test] diff --git a/rs/index/src/ivf/reader.rs b/rs/index/src/ivf/reader.rs index c8cc8b7a..2045386c 100644 --- a/rs/index/src/ivf/reader.rs +++ b/rs/index/src/ivf/reader.rs @@ -44,6 +44,7 @@ mod tests { use quantization::noq::noq::{NoQuantizer, NoQuantizerWriter}; use tempdir::TempDir; + use utils::mem::transmute_u8_to_slice; use utils::test_utils::generate_random_vector; use super::*; @@ -177,10 +178,12 @@ mod tests { .posting_lists_mut() .get(i as u32) .expect("Failed to read vector from FileBackedAppendablePostingListStorage"); - let read_vector = index - .index_storage - .get_posting_list(i) - .expect("Failed to read vector from FixedIndexFile"); + let read_vector = transmute_u8_to_slice::( + index + .index_storage + .get_posting_list(i) + .expect("Failed to read vector from FixedIndexFile"), + ); for (val_ref, val_read) in ref_vector.iter().zip(read_vector.iter()) { assert_eq!(val_ref, *val_read); } @@ -244,9 +247,9 @@ mod tests { for i in 0..num_centroids { // Assert that posting lists size is less than or equal to max_posting_list_size - let posting_list = index.index_storage.get_posting_list(i); - assert!(posting_list.is_ok()); - let posting_list = posting_list.unwrap(); + let posting_list_byte_arr = index.index_storage.get_posting_list(i); + assert!(posting_list_byte_arr.is_ok()); + let posting_list = transmute_u8_to_slice::(posting_list_byte_arr.unwrap()); // It's possible that the posting list size is more than max_posting_list_size, // but it should be less than 2x. diff --git a/rs/index/src/posting_list/combined_file.rs b/rs/index/src/posting_list/combined_file.rs index 940c7b96..a0f036ac 100644 --- a/rs/index/src/posting_list/combined_file.rs +++ b/rs/index/src/posting_list/combined_file.rs @@ -132,7 +132,7 @@ impl FixedIndexFile { Ok(transmute_u8_to_slice::(slice)) } - pub fn get_posting_list(&self, index: usize) -> Result<&[u64]> { + pub fn get_posting_list(&self, index: usize) -> Result<&[u8]> { if index >= self.header.num_clusters as usize { return Err(anyhow!("Index out of bound")); } @@ -150,8 +150,7 @@ impl FixedIndexFile { ..metadata_offset + PL_METADATA_LEN * size_of::()]; let pl_offset = u64::from_le_bytes(slice.try_into()?) as usize + posting_list_start_offset; - let slice = &self.mmap[pl_offset..pl_offset + pl_len * size_of::()]; - Ok(transmute_u8_to_slice::(slice)) + Ok(&self.mmap[pl_offset..pl_offset + pl_len * size_of::()]) } pub fn header(&self) -> &Header { @@ -281,16 +280,20 @@ mod tests { assert!(combined_file.get_centroid(2).is_err()); assert_eq!( - combined_file - .get_posting_list(0) - .expect("Failed to read posting_list"), - &posting_lists[0] + transmute_u8_to_slice::( + combined_file + .get_posting_list(0) + .expect("Failed to read posting_list") + ), + posting_lists[0] ); assert_eq!( - combined_file - .get_posting_list(1) - .expect("Failed to read posting_list"), - &posting_lists[1] + transmute_u8_to_slice::( + combined_file + .get_posting_list(1) + .expect("Failed to read posting_list") + ), + posting_lists[1] ); assert!(combined_file.get_posting_list(2).is_err()); }