From 2f62fe0d66fb9cbbec5334e6568e548cd12b4474 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Thu, 5 Dec 2024 02:33:16 +0100 Subject: [PATCH] Adding a vector to multiple posting lists (#167) --- rs/index/src/ivf/builder.rs | 114 +++++++++++++++++++++++++--- rs/index/src/ivf/reader.rs | 2 + rs/index/src/ivf/writer.rs | 1 + rs/index/src/posting_list/mod.rs | 2 +- rs/index_writer/src/config.rs | 1 + rs/index_writer/src/index_writer.rs | 6 +- 6 files changed, 113 insertions(+), 13 deletions(-) diff --git a/rs/index/src/ivf/builder.rs b/rs/index/src/ivf/builder.rs index 7738fc5a..6a2501fc 100644 --- a/rs/index/src/ivf/builder.rs +++ b/rs/index/src/ivf/builder.rs @@ -1,3 +1,4 @@ +use std::cmp::Ordering; use std::collections::BinaryHeap; use std::fs::{create_dir, create_dir_all}; @@ -18,8 +19,10 @@ pub struct IvfBuilderConfig { pub num_clusters: usize, pub num_data_points: usize, pub max_clusters_per_vector: usize, + // Threshold to add a vector to more than one cluster + pub distance_threshold: f32, - // Parameters for storages. 
+ // Parameters for storages pub base_directory: String, pub memory_size: usize, pub file_size: usize, @@ -38,6 +41,12 @@ pub struct IvfBuilder { doc_id_mapping: Vec, } +// TODO(tyb): maybe merge with HNSW's one +pub struct PointAndDistance { + pub point_id: usize, + pub distance: f32, +} + #[derive(Debug)] struct PostingListInfo { centroid: Vec, @@ -176,8 +185,8 @@ impl IvfBuilder { vector: &[f32], centroids: &dyn VectorStorage, num_probes: usize, - ) -> Result<Vec<usize>> { - let mut distances: Vec<(usize, f32)> = Vec::new(); + ) -> Result<Vec<PointAndDistance>> { + let mut distances: Vec<PointAndDistance> = Vec::new(); let num_centroids = centroids.len(); for i in 0..num_centroids { let centroid = centroids.get(i as u32)?; @@ -185,11 +194,14 @@ impl IvfBuilder { if dist.is_nan() { println!("NAN found"); } - distances.push((i, dist)); + distances.push(PointAndDistance { + point_id: i, + distance: dist, + }); } - distances.select_nth_unstable_by(num_probes - 1, |a, b| a.1.total_cmp(&b.1)); + distances.select_nth_unstable_by(num_probes - 1, |a, b| a.distance.total_cmp(&b.distance)); distances.truncate(num_probes); - Ok(distances.into_iter().map(|(idx, _)| idx).collect()) + Ok(distances) } pub fn build_posting_lists(&mut self) -> Result<()> { @@ -197,14 +209,24 @@ impl IvfBuilder { // Assign vectors to nearest centroids for i in 0..self.vectors.len() { let vector = self.vectors.get(i as u32)?; - let nearest_centroid = Self::find_nearest_centroids( + let nearest_centroids = Self::find_nearest_centroids( &vector, self.centroids.as_ref(), self.config.max_clusters_per_vector, )?; - - for centroid_id in nearest_centroid { - posting_lists[centroid_id].push(i as u64); + // Find the nearest distance. NOTE(review): `unwrap_or(Ordering::Greater)` only ranks a NaN + // as greater when it is the left operand; a later NaN can still win -- consider `total_cmp` + let nearest_distance = nearest_centroids + .iter() + .map(|pad| pad.distance) + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Greater)) + .expect("nearest_distance should not be None"); + for point_and_distance in nearest_centroids.iter() { + if 
(point_and_distance.distance - nearest_distance).abs() + <= nearest_distance * self.config.distance_threshold + { + posting_lists[point_and_distance.point_id].push(i as u64); + } } } @@ -404,9 +426,78 @@ mod tests { } #[test] - fn test_ivf_builder() { + fn test_build_posting_lists() { env_logger::init(); + let temp_dir = tempdir::TempDir::new("buil_posting_lists_test") + .expect("Failed to create temporary directory"); + let base_directory = temp_dir + .path() + .to_str() + .expect("Failed to convert temporary directory path to string") + .to_string(); + let num_clusters = 2; + let num_vectors = 6; + let num_features = 1; + let file_size = 4096; + let balance_factor = 0.0; + let max_posting_list_size = usize::MAX; + let mut builder = IvfBuilder::new(IvfBuilderConfig { + max_iteration: 1000, + batch_size: 4, + num_clusters, + num_data_points: num_vectors, + max_clusters_per_vector: 2, + distance_threshold: 0.1, + base_directory, + memory_size: 1024, + file_size, + num_features, + tolerance: balance_factor, + max_posting_list_size, + }) + .expect("Failed to create builder"); + // Generate `num_vectors` one-dimensional f32 vectors with values 1.0 through 6.0 + for i in 0..num_vectors { + builder + .add_vector(i as u64, vec![(i + 1) as f32]) + .expect("Vector should be added"); + } + + let _ = builder.add_centroid(&[2.5]); + let _ = builder.add_centroid(&[5.5]); + + let result = builder.build_posting_lists(); + assert!(result.is_ok()); + + assert_eq!( + builder + .posting_lists + .get(0) + .expect("Failed to get posting list") + .slices + .iter() + .flat_map(|slice| slice.chunks_exact(8)) + .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) + .collect::<Vec<u64>>(), + vec![0, 1, 2, 3] + ); + assert_eq!( + builder + .posting_lists + .get(1) + .expect("Failed to get posting list") + .slices + .iter() + .flat_map(|slice| slice.chunks_exact(8)) + .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) + .collect::<Vec<u64>>(), + vec![3, 4, 5] + ); + } + + #[test] + fn test_ivf_builder() { let temp_dir = 
tempdir::TempDir::new("ivf_builder_test") .expect("Failed to create temporary directory"); let base_directory = temp_dir @@ -426,6 +517,7 @@ mod tests { num_clusters, num_data_points: num_vectors, max_clusters_per_vector: 1, + distance_threshold: 0.1, base_directory, memory_size: 1024, file_size, diff --git a/rs/index/src/ivf/reader.rs b/rs/index/src/ivf/reader.rs index f3555eab..347b58e7 100644 --- a/rs/index/src/ivf/reader.rs +++ b/rs/index/src/ivf/reader.rs @@ -61,6 +61,7 @@ mod tests { num_clusters, num_data_points: num_vectors, max_clusters_per_vector: 1, + distance_threshold: 0.1, base_directory: base_directory.clone(), memory_size: 1024, file_size, @@ -182,6 +183,7 @@ mod tests { num_clusters, num_data_points: num_vectors, max_clusters_per_vector: 1, + distance_threshold: 0.1, base_directory: base_directory.clone(), memory_size: 1024, file_size, diff --git a/rs/index/src/ivf/writer.rs b/rs/index/src/ivf/writer.rs index bb3dc4da..72aeb279 100644 --- a/rs/index/src/ivf/writer.rs +++ b/rs/index/src/ivf/writer.rs @@ -353,6 +353,7 @@ mod tests { num_clusters, num_data_points: num_vectors, max_clusters_per_vector: 1, + distance_threshold: 0.1, base_directory: base_directory.clone(), memory_size: 1024, file_size, diff --git a/rs/index/src/posting_list/mod.rs b/rs/index/src/posting_list/mod.rs index d9c2caf9..e7679ec8 100644 --- a/rs/index/src/posting_list/mod.rs +++ b/rs/index/src/posting_list/mod.rs @@ -16,7 +16,7 @@ pub struct PostingListStorageConfig { } pub struct PostingList<'a> { - slices: Vec<&'a [u8]>, + pub slices: Vec<&'a [u8]>, } pub struct PostingListIterator<'a> { diff --git a/rs/index_writer/src/config.rs b/rs/index_writer/src/config.rs index b3dd6c97..7e4c33c6 100644 --- a/rs/index_writer/src/config.rs +++ b/rs/index_writer/src/config.rs @@ -54,6 +54,7 @@ pub struct IvfConfig { pub num_clusters: usize, pub num_data_points: usize, pub max_clusters_per_vector: usize, + pub distance_threshold: f32, // KMeans training parameters pub max_iteration: usize, 
diff --git a/rs/index_writer/src/index_writer.rs b/rs/index_writer/src/index_writer.rs index cabec492..dbcef441 100644 --- a/rs/index_writer/src/index_writer.rs +++ b/rs/index_writer/src/index_writer.rs @@ -10,7 +10,7 @@ use quantization::pq_builder::{ProductQuantizerBuilder, ProductQuantizerBuilderC use rand::seq::SliceRandom; use crate::config::{ - HnswConfigWithBase, SpannConfigWithBase, IndexWriterConfig, IvfConfigWithBase, QuantizerType, + HnswConfigWithBase, IndexWriterConfig, IvfConfigWithBase, QuantizerType, SpannConfigWithBase, }; use crate::input::Input; @@ -131,6 +131,7 @@ impl IndexWriter { num_clusters: index_builder_config.ivf_config.num_clusters, num_data_points: index_builder_config.ivf_config.num_data_points, max_clusters_per_vector: index_builder_config.ivf_config.max_clusters_per_vector, + distance_threshold: index_builder_config.ivf_config.distance_threshold, base_directory: path.to_string(), memory_size: index_builder_config.base_config.max_memory_size, file_size: index_builder_config.base_config.file_size, @@ -191,6 +192,7 @@ impl IndexWriter { num_clusters: ivf_config.num_clusters, num_data_points: ivf_config.num_data_points, max_clusters_per_vector: ivf_config.max_clusters_per_vector, + distance_threshold: ivf_config.distance_threshold, base_directory: index_writer_config.base_config.output_path.clone(), memory_size: index_writer_config.base_config.max_memory_size, file_size: index_writer_config.base_config.file_size, @@ -453,6 +455,7 @@ mod tests { num_clusters: 2, num_data_points: 100, max_clusters_per_vector: 1, + distance_threshold: 0.1, max_iteration: 10, batch_size: 10, @@ -527,6 +530,7 @@ mod tests { num_clusters: 2, num_data_points: 100, max_clusters_per_vector: 1, + distance_threshold: 0.1, max_iteration: 10, batch_size: 10,