From 39b827d5e1bf1f19c10ba51a21babbaf830ea290 Mon Sep 17 00:00:00 2001 From: Hieu Pham Date: Tue, 31 Dec 2024 19:30:26 -0800 Subject: [PATCH] Optimizations to KMeans implementation (#252) --- Cargo.lock | 8 + Cargo.toml | 1 + py/demo_search.py | 35 +- rs/demo/src/main.rs | 4 +- rs/demo/src/search.rs | 6 +- rs/index/Cargo.toml | 2 + rs/index/src/collection/reader.rs | 2 +- rs/index/src/hnsw/builder.rs | 1 - rs/index/src/hnsw/index.rs | 1 - rs/index/src/ivf/builder.rs | 344 ++++++++++++----- rs/index/src/ivf/reader.rs | 12 +- rs/index/src/ivf/writer.rs | 18 +- rs/index/src/spann/builder.rs | 10 +- rs/index/src/spann/index.rs | 26 +- rs/index/src/spann/reader.rs | 2 +- rs/index/src/spann/writer.rs | 2 +- rs/index_writer/src/index_writer.rs | 4 +- rs/index_writer/src/input/hdf5.rs | 19 +- rs/utils/src/kmeans_builder/kmeans_builder.rs | 351 ++++++++++-------- rs/utils/src/lib.rs | 4 + 20 files changed, 551 insertions(+), 301 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c0ee294..717859f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -173,6 +173,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atomic_refcell" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" + [[package]] name = "atty" version = "0.2.14" @@ -1362,6 +1368,7 @@ name = "index" version = "0.1.0" dependencies = [ "anyhow", + "atomic_refcell", "bit-vec", "byteorder", "dashmap", @@ -1373,6 +1380,7 @@ dependencies = [ "ordered-float", "quantization", "rand 0.8.5", + "rayon", "roaring", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 3ddfc56..6ad88d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,3 +58,4 @@ rayon = "1.10.0" sorted-vec = "0.8.5" dashmap = "6.1.0" reqwest = {version = "0.12.11", features = ["json"]} +atomic_refcell = "0.1.13" diff --git a/py/demo_search.py b/py/demo_search.py index 3287282..dd11fc5 100644 --- a/py/demo_search.py +++ b/py/demo_search.py @@ -5,26 +5,29 @@ if __name__ == "__main__": # Example usage for IndexServer muopdb_client = mp.IndexServerClient() - query = "personal career development" + query = "baby goes to school" query_vector = ollama.embeddings(model='nomic-embed-text', prompt=query)["embedding"] # Read back the raw data to print the responses with open("/mnt/muopdb/raw/1m_sentences.txt", "r") as f: sentences = [line.strip() for line in f] - start = time.time() - search_response = muopdb_client.search( - index_name="test-collection-1", - vector=query_vector, - top_k=5, - ef_construction=50, - record_metrics=False - ) - end = time.time() - print(f"Time taken for search: {end - start} seconds") + i = 0 + while i < 5: + start = time.time() + search_response = muopdb_client.search( + index_name="test-collection-1", + vector=query_vector, + top_k=5, + ef_construction=50, + record_metrics=True + ) + end = time.time() + print(f"Time taken for search: {end - start} seconds") - print(f"Number of results: {len(search_response.ids)}") - print("================") - for id in search_response.ids: - print(f"RESULT: {sentences[id - 1]}") - print("================") + print(f"Number of results: {len(search_response.ids)}") + print("================") + for id in search_response.ids: + print(f"RESULT: {sentences[id - 1]}") + print("================") + i += 1 diff --git a/rs/demo/src/main.rs b/rs/demo/src/main.rs index b526ff1..5d783d2 
100644 --- a/rs/demo/src/main.rs +++ b/rs/demo/src/main.rs @@ -3,9 +3,9 @@ use std::time::Instant; use anyhow::{Context, Result}; use hdf5::File; use log::{LevelFilter, info}; +use ndarray::s; use proto::muopdb::index_server_client::IndexServerClient; use proto::muopdb::{FlushRequest, InsertRequest}; -use ndarray::s; #[tokio::main] async fn main() -> Result<()> { @@ -54,7 +54,7 @@ async fn main() -> Result<()> { }); client.insert(request).await?; - start_idx = end_idx; + start_idx = end_idx; } let mut duration = start.elapsed(); diff --git a/rs/demo/src/search.rs b/rs/demo/src/search.rs index f51ed9c..1e72642 100644 --- a/rs/demo/src/search.rs +++ b/rs/demo/src/search.rs @@ -42,7 +42,11 @@ async fn main() -> Result<()> { let response_body = response.text().await.expect("Failed to read response body"); let response_map: serde_json::Value = serde_json::from_str(&response_body).unwrap(); let query_vector_value = response_map["embedding"].as_array().unwrap(); - let query_vector: Vec = query_vector_value.iter().map(|x| x.as_f64().unwrap()).map(|x| x as f32).collect(); + let query_vector: Vec = query_vector_value + .iter() + .map(|x| x.as_f64().unwrap()) + .map(|x| x as f32) + .collect(); // Create search request let request = tonic::Request::new(SearchRequest { diff --git a/rs/index/Cargo.toml b/rs/index/Cargo.toml index 3af595e..acdcf90 100644 --- a/rs/index/Cargo.toml +++ b/rs/index/Cargo.toml @@ -22,3 +22,5 @@ tempdir.workspace = true utils.workspace = true serde.workspace = true serde_json.workspace = true +rayon.workspace = true +atomic_refcell.workspace = true diff --git a/rs/index/src/collection/reader.rs b/rs/index/src/collection/reader.rs index 16b7ace..591771b 100644 --- a/rs/index/src/collection/reader.rs +++ b/rs/index/src/collection/reader.rs @@ -71,7 +71,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters: 10, - num_data_points: 1000, + num_data_points_for_clustering: 1000, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory: "./".to_string(), diff --git a/rs/index/src/hnsw/builder.rs b/rs/index/src/hnsw/builder.rs index bbed143..ff5adce 100644 --- a/rs/index/src/hnsw/builder.rs +++ b/rs/index/src/hnsw/builder.rs @@ -203,7 +203,6 @@ impl HnswBuilder { } while let Some(node) = queue.pop_front() { visited.set(node.try_into().unwrap(), true); - debug!("Visited node {}", node); let edges = graph.edges.get_mut(&node); if let Some(edges) = edges { diff --git a/rs/index/src/hnsw/index.rs b/rs/index/src/hnsw/index.rs index ecd19df..1afcc7b 100644 --- a/rs/index/src/hnsw/index.rs +++ b/rs/index/src/hnsw/index.rs @@ -394,7 +394,6 @@ impl Searchable for Hnsw { ef_construction: u32, context: &mut SearchContext, ) -> Option> { - // TODO(hicder): Add ef parameter Some(self.ann_search(query, k, ef_construction, context)) } } diff --git a/rs/index/src/ivf/builder.rs b/rs/index/src/ivf/builder.rs index a7479fe..e83f11a 100644 --- a/rs/index/src/ivf/builder.rs +++ b/rs/index/src/ivf/builder.rs @@ -1,14 +1,18 @@ -use std::cmp::{Ordering, Reverse}; +use std::cmp::{max, min, Ordering, Reverse}; use std::collections::{BinaryHeap, HashMap}; use std::fs::{create_dir, create_dir_all}; use std::io::ErrorKind; +use std::time::Instant; use anyhow::{anyhow, Result}; +use atomic_refcell::AtomicRefCell; +use log::debug; use rand::seq::SliceRandom; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use sorted_vec::SortedVec; use utils::distance::l2::L2DistanceCalculator; use utils::kmeans_builder::kmeans_builder::{KMeansBuilder, KMeansVariant}; -use 
utils::{CalculateSquared, DistanceCalculator}; +use utils::{ceil_div, CalculateSquared, DistanceCalculator}; use crate::posting_list::file::FileBackedAppendablePostingListStorage; use crate::posting_list::PostingListStorage; @@ -19,7 +23,7 @@ pub struct IvfBuilderConfig { pub max_iteration: usize, pub batch_size: usize, pub num_clusters: usize, - pub num_data_points: usize, + pub num_data_points_for_clustering: usize, pub max_clusters_per_vector: usize, // Threshold to add a vector to more than one cluster pub distance_threshold: f32, @@ -37,8 +41,8 @@ pub struct IvfBuilderConfig { pub struct IvfBuilder { config: IvfBuilderConfig, - vectors: Box>, - centroids: Box>, + vectors: AtomicRefCell + Send + Sync>>, + centroids: AtomicRefCell + Send + Sync>>, posting_lists: Box PostingListStorage<'a>>, doc_id_mapping: Vec, } @@ -147,22 +151,24 @@ impl IvfBuilder { let vectors_path = format!("{}/builder_vector_storage", config.base_directory); create_dir(&vectors_path)?; - let vectors = Box::new(FileBackedAppendableVectorStorage::::new( - vectors_path, - config.memory_size, - config.file_size, - config.num_features, - )); + let vectors: AtomicRefCell + Send + Sync>> = + AtomicRefCell::new(Box::new(FileBackedAppendableVectorStorage::::new( + vectors_path, + config.memory_size, + config.file_size, + config.num_features, + ))); let centroids_path = format!("{}/builder_centroid_storage", config.base_directory); create_dir(¢roids_path)?; - let centroids = Box::new(FileBackedAppendableVectorStorage::::new( - centroids_path, - config.memory_size, - config.file_size, - config.num_features, - )); + let centroids: AtomicRefCell + Send + Sync>> = + AtomicRefCell::new(Box::new(FileBackedAppendableVectorStorage::::new( + centroids_path, + config.memory_size, + config.file_size, + config.num_features, + ))); let posting_lists_path = format!("{}/builder_posting_list_storage", config.base_directory); create_dir(&posting_lists_path)?; @@ -186,16 +192,16 @@ impl IvfBuilder { &self.config } - pub fn vectors(&self) -> &dyn VectorStorage { - &*self.vectors + pub fn vectors(&self) -> &AtomicRefCell + Send + Sync>> { + &self.vectors } pub fn doc_id_mapping(&self) -> &[u64] { &*self.doc_id_mapping } - pub fn centroids(&self) -> &dyn VectorStorage { - &*self.centroids + pub fn centroids(&self) -> &AtomicRefCell + Send + Sync>> { + &self.centroids } pub fn posting_lists(&self) -> &dyn for<'a> PostingListStorage<'a> { @@ -208,14 +214,14 @@ impl IvfBuilder { /// Add a new vector to the dataset for training pub fn add_vector(&mut self, doc_id: u64, data: &[f32]) -> Result<()> { - self.vectors.append(&data)?; + self.vectors.borrow_mut().append(&data)?; self.generate_id(doc_id)?; Ok(()) } /// Add a new centroid - pub fn add_centroid(&mut self, centroid: &[f32]) -> Result<()> { - self.centroids.append(centroid)?; + pub fn add_centroid(&self, centroid: &[f32]) -> Result<()> { + self.centroids.borrow_mut().append(centroid)?; Ok(()) } @@ -273,30 +279,51 @@ impl IvfBuilder { } pub fn build_posting_lists(&mut self) -> Result<()> { - let mut posting_lists: Vec> = vec![Vec::with_capacity(0); self.centroids.len()]; + debug!("Building posting lists"); + let mut posting_lists: Vec> = + vec![Vec::with_capacity(0); self.centroids.borrow().len()]; // Assign vectors to nearest centroids - for i in 0..self.vectors.len() { - let vector = self.vectors.get(i as u32)?; - let nearest_centroids = Self::find_nearest_centroids( - &vector, - self.centroids.as_ref(), - self.config.max_clusters_per_vector, - )?; - // Find the nearest distance, ensuring 
that NaN values are treated as greater than any - // other value - let nearest_distance = nearest_centroids - .iter() - .map(|pad| pad.distance) - .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Greater)) - .expect("nearest_distance should not be None"); - for point_and_distance in nearest_centroids.iter() { - if (point_and_distance.distance - nearest_distance).abs() - <= nearest_distance * self.config.distance_threshold - { - posting_lists[point_and_distance.point_id].push(i as u64); + // self.assign_docs_to_cluster(doc_ids, flattened_centroids) + + let doc_ids = (0..self.vectors.borrow().len()).collect::>(); + // let vector_clone = self.vectors.clone(); + let max_clusters_per_vector = self.config.max_clusters_per_vector; + let posting_list_per_doc = doc_ids + .par_iter() + .map(|doc_id| { + let nearest_centroids = Self::find_nearest_centroids( + self.vectors.borrow().get(*doc_id as u32).unwrap(), + self.centroids.borrow().as_ref(), + max_clusters_per_vector, + ) + .expect("Nearest centroids should not be None"); + // Find the nearest distance, ensuring that NaN values are treated as greater than any + // other value + let nearest_distance = nearest_centroids + .iter() + .map(|pad| pad.distance) + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Greater)) + .expect("nearest_distance should not be None"); + let mut accepted_centroid_ids = vec![]; + for centroid_and_distance in nearest_centroids.iter() { + if (centroid_and_distance.distance - nearest_distance).abs() + <= nearest_distance * self.config.distance_threshold + { + accepted_centroid_ids.push(centroid_and_distance.point_id as u64); + } } - } - } + accepted_centroid_ids + }) + .collect::>>(); + + posting_list_per_doc + .iter() + .enumerate() + .for_each(|(i, posting_list_for_doc)| { + posting_list_for_doc.iter().for_each(|posting_list_id| { + posting_lists[*posting_list_id as usize].push(i as u64); + }); + }); let posting_list_storage_location = format!( "{}/builder_posting_list_storage", @@ -340,47 +367,130 @@ impl IvfBuilder { }); } - for doc_id in doc_ids { - let vector = self.vectors.get(doc_id as u32)?; - let nearest_centroid = Self::find_nearest_centroid_inmemory( - &vector, - flattened_centroids, - self.config.num_features, - ); - posting_list_infos[nearest_centroid] + let num_features = self.config.num_features; + let vectors = self.vectors.borrow(); + let nearest_centroids = doc_ids + .par_iter() + .map(|doc_id| { + Self::find_nearest_centroid_inmemory( + vectors.get(*doc_id as u32).unwrap(), + &flattened_centroids, + num_features, + ) + }) + .collect::>(); + + for (doc_id, nearest_centroid) in doc_ids.iter().zip(nearest_centroids.iter()) { + posting_list_infos[*nearest_centroid] .posting_list - .push(doc_id); + .push(*doc_id); } Ok(posting_list_infos) } - fn get_flattened_dataset(&self, doc_ids: &[usize]) -> Result> { + fn get_sample_dataset_from_doc_ids( + &self, + doc_ids: &[usize], + sample_size: usize, + ) -> Result> { + let mut rng = rand::thread_rng(); let mut flattened_dataset: Vec = vec![]; - for i in 0..doc_ids.len() { - let vector = self.vectors.get(doc_ids[i] as u32)?; - flattened_dataset.extend_from_slice(vector); - } + doc_ids + .choose_multiple(&mut rng, sample_size) + .for_each(|doc_id| { + flattened_dataset + .extend_from_slice(self.vectors.borrow().get(*doc_id as u32).unwrap()); + }); Ok(flattened_dataset) } - fn cluster_docs(&self, doc_ids: Vec) -> Result> { + fn cluster_docs( + &self, + doc_ids: Vec, + max_posting_list_size: usize, + ) -> Result> { + let num_clusters = ceil_div(doc_ids.len(), 
max_posting_list_size); + + let num_points_for_clustering = max( + num_clusters * 10, + self.config.num_data_points_for_clustering, + ); let kmeans = KMeansBuilder::new( - self.config.num_clusters, + num_clusters, self.config.max_iteration, self.config.tolerance, self.config.num_features, KMeansVariant::Lloyd, ); - let flattened_dataset = self.get_flattened_dataset(doc_ids.as_ref())?; + let flattened_dataset = + self.get_sample_dataset_from_doc_ids(&doc_ids, num_points_for_clustering)?; + debug!( + "Local clustering with {} docs, sample size {}, num_clusters {}", + doc_ids.len(), + flattened_dataset.len() / self.config.num_features, + num_clusters + ); + let start = Instant::now(); let result = kmeans.fit(flattened_dataset)?; - self.assign_docs_to_cluster(doc_ids, result.centroids.as_ref()) + debug!("Time taken for kmeans: {}", start.elapsed().as_micros()); + let posting_list_infos = self.assign_docs_to_cluster(doc_ids, result.centroids.as_ref())?; + + { + let mut num_more_max = 0; + let mut num_less_max = 0; + let mut max_size = 0; + for pli in &posting_list_infos { + if pli.posting_list.len() > max_posting_list_size { + num_more_max += 1; + } else { + num_less_max += 1; + } + + if pli.posting_list.len() > max_size { + max_size = pli.posting_list.len(); + } + } + debug!( + "Number of posting lists with more than max_posting_list_size: {}", + num_more_max + ); + debug!( + "Number of posting lists with less than max_posting_list_size: {}", + num_less_max + ); + debug!("Max posting list size: {}", max_size); + } + Ok(posting_list_infos) + } + + fn compute_actual_num_clusters( + &self, + total_data_points: usize, + num_clusters: usize, + max_points_per_centroid: usize, + ) -> usize { + let num_centroids = num_clusters; + let num_points_per_centroid = total_data_points / num_centroids; + ceil_div( + total_data_points, + min(num_points_per_centroid, max_points_per_centroid), + ) } pub fn build_centroids(&mut self) -> Result<()> { // First pass to get the initial centroids - let kmeans = KMeansBuilder::new( + let num_clusters = self.compute_actual_num_clusters( + self.vectors.borrow().len(), self.config.num_clusters, + self.config.max_posting_list_size, + ); + debug!( + "First pass, will attemp to build {} centroids", + num_clusters + ); + let kmeans = KMeansBuilder::new( + num_clusters, self.config.max_iteration, self.config.tolerance, self.config.num_features, @@ -389,42 +499,76 @@ impl IvfBuilder { // Sample the dataset to build the first set of centroids let mut rng = rand::thread_rng(); - let num_input_vectors = self.vectors.len(); + let num_input_vectors = self.vectors.borrow().len(); // Create a vector from 0 to num_input_vectors and then shuffle it let mut flattened_dataset: Vec = vec![]; let indices: Vec = (0..num_input_vectors as usize).collect(); + let num_points_for_clustering = + max(num_clusters, self.config.num_data_points_for_clustering); + debug!( + "Partial clustering with {} points", + num_points_for_clustering + ); let selected = indices - .choose_multiple(&mut rng, self.config.num_data_points) + .choose_multiple(&mut rng, num_points_for_clustering) .cloned() .collect::>(); selected.iter().for_each(|index| { - flattened_dataset.extend_from_slice(self.vectors.get(*index as u32).unwrap()); + flattened_dataset.extend_from_slice(self.vectors.borrow().get(*index as u32).unwrap()); }); let result = kmeans.fit(flattened_dataset)?; let posting_list_infos = self.assign_docs_to_cluster(indices, result.centroids.as_ref())?; + { + let mut num_more_max = 0; + let mut num_less_max = 0; + 
for pli in &posting_list_infos { + if pli.posting_list.len() > self.config.max_posting_list_size { + num_more_max += 1; + } else { + num_less_max += 1; + } + } + debug!( + "Number of posting lists with more than max_posting_list_size: {}", + num_more_max + ); + debug!( + "Number of posting lists with less than max_posting_list_size: {}", + num_less_max + ); + } // Repeatedly run kmeans on the longest posting list until no posting list is longer // than max_posting_list_size let mut heap = BinaryHeap::::new(); for posting_list_info in posting_list_infos { heap.push(posting_list_info); } + let mut num_iter = 0; while heap.len() > 0 { match heap.peek() { None => break, Some(longest_posting_list) => { - if longest_posting_list.posting_list.len() < self.config.max_posting_list_size { + if longest_posting_list.posting_list.len() <= self.config.max_posting_list_size + { break; } } } let longest_posting_list = heap.pop().unwrap(); - let new_posting_list_infos = - self.cluster_docs(longest_posting_list.posting_list.clone())?; + debug!( + "Clustering longest posting list with length: {}", + longest_posting_list.posting_list.len() + ); + num_iter += 1; + let new_posting_list_infos = self.cluster_docs( + longest_posting_list.posting_list.clone(), + self.config.max_posting_list_size, + )?; // Add the new posting list infos to the heap for posting_list_info in new_posting_list_infos { @@ -432,6 +576,8 @@ impl IvfBuilder { } } + debug!("Number of iterations to cluster: {}", num_iter); + // Add the centroids to the centroid storage // We don't need to add the posting lists to the posting list storage, since later on // we will add them @@ -555,7 +701,7 @@ impl IvfBuilder { /// Assign new ids to the vectors fn get_reassigned_ids(&mut self) -> Result> { - let vector_length = self.vectors.len(); + let vector_length = self.vectors.borrow().len(); let mut assigned_ids = vec![-1; vector_length]; let mut cur_idx = self.assign_ids_until_last_stopping_point(&mut assigned_ids)?; @@ -622,18 +768,20 @@ impl IvfBuilder { ); create_dir_all(&new_vectors_path)?; - let mut new_vector_storage = Box::new(FileBackedAppendableVectorStorage::::new( - new_vectors_path, - self.config.memory_size, - self.config.file_size, - self.config.num_features, - )); + let new_vector_storage: AtomicRefCell + Send + Sync>> = + AtomicRefCell::new(Box::new(FileBackedAppendableVectorStorage::::new( + new_vectors_path, + self.config.memory_size, + self.config.file_size, + self.config.num_features, + ))); for i in 0..reverse_assigned_ids.len() { let mapped_id = reverse_assigned_ids[i]; - let vector = self.vectors.get(mapped_id as u32).unwrap(); + // let vector = self.vectors.borrow().get(mapped_id as u32).unwrap(); new_vector_storage - .append(vector) + .borrow_mut() + .append(self.vectors.borrow().get(mapped_id as u32).unwrap()) .unwrap_or_else(|_| panic!("append failed")); } @@ -708,7 +856,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 2, distance_threshold: 0.1, base_directory, @@ -772,7 +920,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 2, distance_threshold: 0.1, base_directory, @@ -810,7 +958,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 2, distance_threshold: 0.1, 
base_directory, @@ -866,7 +1014,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 2, distance_threshold: 0.1, base_directory, @@ -933,7 +1081,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 2, distance_threshold: 0.1, base_directory, @@ -1006,7 +1154,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 2, distance_threshold: 0.1, base_directory, @@ -1079,7 +1227,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 2, distance_threshold: 0.1, base_directory, @@ -1174,7 +1322,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 2, distance_threshold: 0.1, base_directory, @@ -1252,7 +1400,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: NUM_VECTORS, + num_data_points_for_clustering: NUM_VECTORS, max_clusters_per_vector: 2, distance_threshold: 0.1, base_directory, @@ -1289,6 +1437,7 @@ mod tests { assert_eq!( builder .vectors + .borrow() .get(i as u32) .expect(&format!("Failed to retrieve vector #{}", i))[0], expected_vectors[i] @@ -1326,7 +1475,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory, @@ -1347,8 +1496,8 @@ mod tests { let result = builder.build(); assert!(result.is_ok()); - assert_eq!(builder.vectors.len(), num_vectors); - assert_eq!(builder.centroids.len(), num_clusters); + assert_eq!(builder.vectors.borrow().len(), num_vectors); + assert_eq!(builder.centroids.borrow().len(), num_clusters); assert_eq!(builder.posting_lists.len(), num_clusters); // Total size of vectors is bigger than file size, check that they are flushed to disk @@ -1377,4 +1526,15 @@ mod tests { .div_ceil(builder.config.file_size) ); } + + #[test] + fn test_sample() { + let num: Vec = (0..100).collect(); + let mut rng = rand::thread_rng(); + let sample = num + .choose_multiple(&mut rng, 10) + .cloned() + .collect::>(); + println!("{:?}", sample); + } } diff --git a/rs/index/src/ivf/reader.rs b/rs/index/src/ivf/reader.rs index d9eb783..c8cc8b7 100644 --- a/rs/index/src/ivf/reader.rs +++ b/rs/index/src/ivf/reader.rs @@ -72,7 +72,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory: base_directory.clone(), @@ -115,8 +115,10 @@ mod tests { for i in 0..num_vectors { let ref_vector = builder .vectors() + .borrow() .get(i as u32) - .expect("Failed to read vector from FileBackedAppendableVectorStorage"); + .expect("Failed to read vector from FileBackedAppendableVectorStorage") + .to_vec(); let read_vector = index .vector_storage .get(i, &mut context) @@ -156,8 +158,10 @@ mod tests { for i in 0..num_clusters { let ref_vector = builder .centroids() + .borrow() .get(i as u32) - .expect("Failed to read centroid from FileBackedAppendableVectorStorage"); + .expect("Failed to read 
centroid from FileBackedAppendableVectorStorage") + .to_vec(); let read_vector = index .index_storage .get_centroid(i) @@ -210,7 +214,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory: base_directory.clone(), diff --git a/rs/index/src/ivf/writer.rs b/rs/index/src/ivf/writer.rs index d3d508f..783c516 100644 --- a/rs/index/src/ivf/writer.rs +++ b/rs/index/src/ivf/writer.rs @@ -35,8 +35,8 @@ impl IvfWriter { } let num_features = ivf_builder.config().num_features; - let num_clusters = ivf_builder.centroids().len(); - let num_vectors = ivf_builder.vectors().len(); + let num_clusters = ivf_builder.centroids().borrow().len(); + let num_vectors = ivf_builder.vectors().borrow().len(); // Write vectors let vectors_len = self @@ -125,9 +125,11 @@ impl IvfWriter { self.quantizer.quantized_dimension(), )); - for i in 0..full_vectors.len() { - let vector = full_vectors.get(i as u32)?; - let quantized_vector = Q::QuantizedT::process_vector(vector, &self.quantizer); + for i in 0..full_vectors.borrow().len() { + let quantized_vector = Q::QuantizedT::process_vector( + full_vectors.borrow().get(i as u32)?, + &self.quantizer, + ); quantized_vectors.append(&quantized_vector)?; } @@ -162,7 +164,7 @@ impl IvfWriter { let mut file = File::create(path)?; let mut writer = BufWriter::new(&mut file); - let bytes_written = ivf_builder.centroids().write(&mut writer)?; + let bytes_written = ivf_builder.centroids().borrow().write(&mut writer)?; Ok(bytes_written) } @@ -409,7 +411,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory: base_directory.clone(), @@ -497,7 +499,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory: base_directory.clone(), diff --git a/rs/index/src/spann/builder.rs b/rs/index/src/spann/builder.rs index 35369d9..00ed787 100644 --- a/rs/index/src/spann/builder.rs +++ b/rs/index/src/spann/builder.rs @@ -19,7 +19,7 @@ pub struct SpannBuilderConfig { pub max_iteration: usize, pub batch_size: usize, pub num_clusters: usize, - pub num_data_points: usize, + pub num_data_points_for_clustering: usize, pub max_clusters_per_vector: usize, // Threshold to add a vector to more than one cluster pub distance_threshold: f32, @@ -48,7 +48,7 @@ impl Default for SpannBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters: 10, - num_data_points: 1000, + num_data_points_for_clustering: 1000, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory: "./".to_string(), @@ -73,7 +73,7 @@ impl SpannBuilder { max_iteration: config.max_iteration, batch_size: config.batch_size, num_clusters: config.num_clusters, - num_data_points: config.num_data_points, + num_data_points_for_clustering: config.num_data_points_for_clustering, max_clusters_per_vector: config.max_clusters_per_vector, distance_threshold: config.distance_threshold, base_directory: config.base_directory.clone(), @@ -118,11 +118,11 @@ impl SpannBuilder { self.ivf_builder.build()?; let centroid_storage = self.ivf_builder.centroids(); - let num_centroids = centroid_storage.len(); + let num_centroids = centroid_storage.borrow().len(); for i in 0..num_centroids { 
self.centroid_builder - .insert(i as u64, ¢roid_storage.get(i as u32).unwrap())?; + .insert(i as u64, ¢roid_storage.borrow().get(i as u32).unwrap())?; } Ok(()) diff --git a/rs/index/src/spann/index.rs b/rs/index/src/spann/index.rs index 587dc79..f8dd68a 100644 --- a/rs/index/src/spann/index.rs +++ b/rs/index/src/spann/index.rs @@ -1,3 +1,6 @@ +use std::cmp::Ordering; + +use log::debug; use quantization::noq::noq::NoQuantizer; use crate::hnsw::index::Hnsw; @@ -37,11 +40,30 @@ impl Searchable for Spann { // TODO(hicder): Fully implement SPANN, which includes adjusting number of centroids match self.centroids.search(query, k, ef_construction, context) { Some(nearest_centroids) => { - let nearest_centroid_ids = - nearest_centroids.iter().map(|x| x.id as usize).collect(); if nearest_centroids.is_empty() { return None; } + + // Get the nearest centroid, and only search those that are within 10% of the distance of the nearest centroid + let nearest_distance = nearest_centroids + .iter() + .map(|pad| pad.score) + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Greater)) + .expect("nearest_distance should not be None"); + + let nearest_centroid_ids: Vec = nearest_centroids + .iter() + .filter(|centroid_and_distance| { + centroid_and_distance.score - nearest_distance < nearest_distance * 0.1 + }) + .map(|x| x.id as usize) + .collect(); + + debug!( + "Number of nearest centroids: {}", + nearest_centroid_ids.len() + ); + let results = self.posting_lists.search_with_centroids_and_remap( query, nearest_centroid_ids, diff --git a/rs/index/src/spann/reader.rs b/rs/index/src/spann/reader.rs index 5d780ec..68f7ae0 100644 --- a/rs/index/src/spann/reader.rs +++ b/rs/index/src/spann/reader.rs @@ -55,7 +55,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory: base_directory.clone(), diff --git a/rs/index/src/spann/writer.rs b/rs/index/src/spann/writer.rs index 81cc6b8..528583c 100644 --- a/rs/index/src/spann/writer.rs +++ b/rs/index/src/spann/writer.rs @@ -85,7 +85,7 @@ mod tests { max_iteration: 1000, batch_size: 4, num_clusters, - num_data_points: num_vectors, + num_data_points_for_clustering: num_vectors, max_clusters_per_vector: 1, distance_threshold: 0.1, base_directory: base_directory.clone(), diff --git a/rs/index_writer/src/index_writer.rs b/rs/index_writer/src/index_writer.rs index 95f245b..6cee8ec 100644 --- a/rs/index_writer/src/index_writer.rs +++ b/rs/index_writer/src/index_writer.rs @@ -202,7 +202,7 @@ impl IndexWriter { max_iteration: index_builder_config.ivf_config.max_iteration, batch_size: index_builder_config.ivf_config.batch_size, num_clusters: index_builder_config.ivf_config.num_clusters, - num_data_points: index_builder_config.ivf_config.num_data_points, + num_data_points_for_clustering: index_builder_config.ivf_config.num_data_points, max_clusters_per_vector: index_builder_config.ivf_config.max_clusters_per_vector, distance_threshold: index_builder_config.ivf_config.distance_threshold, base_directory: path.to_string(), @@ -356,7 +356,7 @@ impl IndexWriter { max_iteration: index_writer_config.ivf_config.max_iteration, batch_size: index_writer_config.ivf_config.batch_size, num_clusters: index_writer_config.ivf_config.num_clusters, - num_data_points: index_writer_config.ivf_config.num_data_points, + num_data_points_for_clustering: index_writer_config.ivf_config.num_data_points, max_clusters_per_vector: 
index_writer_config.ivf_config.max_clusters_per_vector, distance_threshold: index_writer_config.ivf_config.distance_threshold, base_directory: root_path.to_string(), diff --git a/rs/index_writer/src/input/hdf5.rs b/rs/index_writer/src/input/hdf5.rs index 697daa4..ecd300b 100644 --- a/rs/index_writer/src/input/hdf5.rs +++ b/rs/index_writer/src/input/hdf5.rs @@ -144,20 +144,23 @@ mod tests { cluster_sizes[*assignment] += 1; } - assert_eq!( - cluster_sizes, - vec![1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000] - ); - - // Check that the distance between the point to its centroid is less than 0.1 + // Check that each point is assigned to its closest centroid for i in 0..flattened_dataset.len() / 128 { let point = &flattened_dataset[i * 128..(i + 1) * 128]; let centroid_id = result.assignments[i]; let centroid = &result.centroids[centroid_id * 128..(centroid_id + 1) * 128]; let dist = L2DistanceCalculator::calculate(&point, ¢roid); - // We might need to adjust this threshold - assert!(dist < 70.0); + for j in 0..10 { + let dist_to_centroid = L2DistanceCalculator::calculate( + &point, + &result.centroids[j * 128..(j + 1) * 128], + ); + if dist_to_centroid < dist { + println!("Point {} is assigned to centroid {} with distance {}, but should be assigned to centroid {} with distance {}", i, centroid_id, dist, j, dist_to_centroid); + } + assert!(dist_to_centroid >= dist); + } } } } diff --git a/rs/utils/src/kmeans_builder/kmeans_builder.rs b/rs/utils/src/kmeans_builder/kmeans_builder.rs index 7a2aae1..6ee271b 100644 --- a/rs/utils/src/kmeans_builder/kmeans_builder.rs +++ b/rs/utils/src/kmeans_builder/kmeans_builder.rs @@ -1,6 +1,10 @@ +use std::cmp::min; +use std::simd::{LaneCount, Simd, SupportedLaneCount}; + use anyhow::{anyhow, Ok, Result}; use kmeans::KMeansConfig; use log::debug; +use rand::seq::SliceRandom; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use rayon::slice::ParallelSlice; @@ -13,7 +17,7 @@ pub enum KMeansVariant { } pub struct KMeansBuilder { - pub num_cluters: usize, + pub num_clusters: usize, pub max_iter: usize, // Factor which determine how much penalty large cluster has over small cluster. @@ -32,6 +36,7 @@ pub struct KMeansResult { // Flattened centroids pub centroids: Vec, pub assignments: Vec, + pub error: f32, } // TODO(hicder): Add support for different variants of k-means. 
@@ -45,7 +50,7 @@ impl KMeansBuilder { variant: KMeansVariant, ) -> Self { Self { - num_cluters, + num_clusters: num_cluters, max_iter, tolerance, dimension, @@ -63,7 +68,7 @@ impl KMeansBuilder { cluster_init_values: Vec, ) -> Self { Self { - num_cluters, + num_clusters: num_cluters, max_iter, tolerance, dimension, @@ -89,14 +94,16 @@ impl KMeansBuilder { .build(); let kmean: kmeans::KMeans<_, 16> = kmeans::KMeans::new(data, sample_count, self.dimension); let result = kmean.kmeans_lloyd( - self.num_cluters, + self.num_clusters, self.max_iter, kmeans::KMeans::init_random_sample, &conf, ); let kmeans_result = KMeansResult { + // intial_centroids: vec![], centroids: result.centroids, assignments: result.assignments, + error: result.distsum, }; Ok(kmeans_result) } @@ -115,117 +122,143 @@ impl KMeansBuilder { KMeansVariant::Lloyd => { if self.dimension % 16 == 0 { return self - .run_lloyd::>(flattened_data); + .run_lloyd::, 16>(flattened_data); } else if self.dimension % 8 == 0 { - return self.run_lloyd::>(flattened_data); + return self + .run_lloyd::, 8>(flattened_data); } else if self.dimension % 4 == 0 { - return self.run_lloyd::>(flattened_data); + return self + .run_lloyd::, 4>(flattened_data); } else { - return self.run_lloyd::(flattened_data); + return self.run_lloyd::(flattened_data); } } } } - fn run_lloyd( + fn init_random_points(&self, points: &Vec<&[f32]>, num_clusters: usize) -> Result> { + match &self.cluster_init_values { + Some(cluster_init_values) if cluster_init_values.len() == num_clusters => { + return Ok(cluster_init_values + .iter() + .map(|point_id| points[*point_id]) + .flatten() + .cloned() + .collect()); + } + _ => { + let mut rng = rand::thread_rng(); + let mut centroids = vec![]; + points + .choose_multiple(&mut rng, num_clusters) + .for_each(|point| { + centroids.extend_from_slice(*point); + }); + return Ok(centroids); + } + } + } + + fn run_lloyd( &self, flattened_data_points: Vec, - ) -> Result { + ) -> Result + where + LaneCount: SupportedLaneCount, + { let data_points = flattened_data_points .par_chunks_exact(self.dimension) .map(|x| x) .collect::>(); let num_data_points = data_points.len(); - let mut cluster_labels = vec![0; num_data_points]; + let num_clusters = min(self.num_clusters, num_data_points); - // Random initialization of cluster labels - match &self.cluster_init_values { - Some(values) => { - for i in 0..num_data_points { - cluster_labels[i] = values[i % values.len()]; - } - } - None => { - for i in 0..num_data_points { - cluster_labels[i] = rand::random::() % self.num_cluters; - } - } - } - - let mut cluster_sizes = vec![0; self.num_cluters]; - for i in 0..num_data_points { - cluster_sizes[cluster_labels[i]] += 1; - } + // Choose random few points as initial centroids + let mut centroids = self.init_random_points(&data_points, num_clusters)?; + let mut cluster_sizes = vec![0; num_clusters]; - let mut centroids = vec![0.0; self.num_cluters * self.dimension]; - for i in 0..num_data_points { - let data_point = &data_points[i]; - let label = cluster_labels[i]; - for j in 0..self.dimension { - centroids[label * self.dimension + j] += data_point[j]; - } + // Add size penalty term + let mut penalties = vec![0.0; num_clusters]; + if self.tolerance > 0.0 { + penalties + .iter_mut() + .enumerate() + .for_each(|x| *x.1 = self.tolerance * cluster_sizes[x.0] as f32); } - centroids.iter_mut().enumerate().for_each(|x| { - let idx = x.0 / self.dimension; - if cluster_sizes[idx] > 0 { - *x.1 /= cluster_sizes[idx] as f32; - } - }); - - // Add size penalty term - 
let mut penalties = vec![0.0; self.num_cluters]; - penalties - .iter_mut() - .enumerate() - .for_each(|x| *x.1 = self.tolerance * cluster_sizes[x.0] as f32); + let mut cluster_labels = vec![0; data_points.len()]; - for _iteration in 0..self.max_iter { - let old_labels = cluster_labels.clone(); + let mut last_dist = f32::MAX; + let mut iteration = 0; + loop { + let last_labels = cluster_labels.clone(); // Reassign points using modified distance (Equation 8) - - cluster_labels = data_points + let mut cluster_labels_with_min_cost = data_points .par_iter() .map(|data_point| { let dp = *data_point; // Calculate distance to each centroid - let mut min_cost = f32::MAX; - let mut label = 0; - for centroid_id in 0..self.num_cluters { - let centroid = centroids - [centroid_id * self.dimension..(centroid_id + 1) * self.dimension] - .as_ref(); - let distance = T::calculate_squared(dp, centroid) + penalties[centroid_id]; - - if distance < min_cost { - min_cost = distance; - label = centroid_id; - } - } - label + let res = centroids + .chunks_exact(self.dimension) + .enumerate() + .map(|(centroid_id, centroid)| { + let distance = + T::calculate_squared(dp, centroid) + penalties[centroid_id]; + (centroid_id, distance) + }) + .fold((0, f32::MAX), |(min_label, min_cost), (label, distance)| { + if distance < min_cost { + (label, distance) + } else { + (min_label, min_cost) + } + }); + res }) - .collect::>(); - - // Reinitialize cluster sizes - cluster_sizes.iter_mut().for_each(|x| *x = 0); - for i in 0..num_data_points { - cluster_sizes[cluster_labels[i]] += 1; - } - - // Flattened centroids - centroids.iter_mut().for_each(|x| *x = 0.0); - for i in 0..num_data_points { - let data_point = &data_points[i]; - let label = cluster_labels[i]; - for j in 0..self.dimension { - centroids[label * self.dimension + j] += data_point[j]; - } - } + .collect::>(); + + let mut total_dist = 0.0; + rayon::scope(|s| { + s.spawn(|_| { + total_dist = cluster_labels_with_min_cost + .iter() + .map(|(_, distance)| (*distance).sqrt()) + .sum::(); + }); + s.spawn(|_| { + centroids.iter_mut().for_each(|x| *x = 0.0); + data_points + .iter() + .zip(cluster_labels_with_min_cost.iter()) + .for_each(|(data_point, (label, _))| { + let dp = *data_point; + let centroid_slice = &mut centroids + [*label * self.dimension..(label + 1) * self.dimension]; + centroid_slice + .chunks_exact_mut(SIMD_WIDTH) + .zip( + dp.chunks_exact(SIMD_WIDTH) + .map(|v| Simd::::from_slice(v)), + ) + .for_each(|(c, s)| { + let c_simd = Simd::::from_slice(c); + let result = c_simd + s; + c.copy_from_slice(result.as_array()); + }); + }); + }); + s.spawn(|_| { + cluster_sizes.iter_mut().for_each(|x| *x = 0); + for i in 0..num_data_points { + cluster_sizes[cluster_labels_with_min_cost[i].0] += 1; + } + }); + }); let mut contains_empty_cluster = false; - + // Reinitialize cluster sizes centroids.iter_mut().enumerate().for_each(|x| { let idx = x.0 / self.dimension; if cluster_sizes[idx] > 0 { @@ -235,47 +268,97 @@ impl KMeansBuilder { } }); + // Check if there is any empty cluster if contains_empty_cluster { - // Compute largest cluster - let largest_cluster = cluster_sizes.iter().max().unwrap(); - let largest_cluster_id = cluster_sizes - .iter() - .position(|x| x == largest_cluster) - .unwrap(); - let chosen_point = cluster_labels - .iter() - .position(|x| *x == largest_cluster_id) - .unwrap(); + // Find the point that is in a cluster with more than 1 points, that is farthest away from its cluster + for cluster_id in 0..cluster_sizes.len() { + if cluster_sizes[cluster_id] 
== 0 { + // debug!("Cluster {} with 0 points found", cluster_id); + let mut max_distance = 0.0; + let mut chosen_point_id = 0; + let mut chosen_cluster_id = 0; + for i in 0..cluster_labels_with_min_cost.len() { + let checking_cluster_id = cluster_labels_with_min_cost[i].0; + if cluster_sizes[checking_cluster_id] > 1 { + let point = data_points[i]; + let cluster = centroids + .chunks_exact(self.dimension) + .nth(cluster_id) + .unwrap(); + let distance = + L2DistanceCalculator::calculate_squared(point, cluster); + if distance > max_distance { + max_distance = distance; + chosen_point_id = i; + chosen_cluster_id = checking_cluster_id; + } + } + } + + let old_size = cluster_sizes[chosen_cluster_id] as f32; + cluster_sizes[chosen_cluster_id] -= 1; + + let chosen_point = data_points[chosen_point_id]; + for j in 0..self.dimension { + let x = centroids[chosen_cluster_id * self.dimension + j]; + centroids[chosen_cluster_id * self.dimension + j] = + (x * old_size - chosen_point[j]) / (old_size - 1.0); + } - // Handle empty clusters - for i in 0..self.num_cluters { - if cluster_sizes[i] == 0 { - // Set the centroid of this cluster to the point + // add chosen point to the new cluster + cluster_labels_with_min_cost[chosen_point_id].0 = cluster_id; + cluster_sizes[cluster_id] = 1; + // update centroid for this cluster for j in 0..self.dimension { - centroids[i * self.dimension + j] = data_points[chosen_point][j]; + centroids[cluster_id * self.dimension + j] = chosen_point[j]; } - cluster_sizes[i] = 1; } } } // Add size penalty term - let mut penalties = vec![0.0; self.num_cluters]; - penalties - .iter_mut() - .enumerate() - .for_each(|x| *x.1 = self.tolerance * cluster_sizes[x.0] as f32); + if self.tolerance > 0.0 { + penalties = vec![0.0; num_clusters]; + penalties + .iter_mut() + .enumerate() + .for_each(|x| *x.1 = self.tolerance * cluster_sizes[x.0] as f32); + + total_dist += penalties + .iter() + .zip(cluster_sizes.iter()) + .map(|(penalty, size)| *penalty * (*size as f32)) + .sum::(); + } - // Check convergence - if cluster_labels == old_labels { - debug!("Converged at iteration {}", _iteration); + debug!( + "Iteration: {}, Error {} -> {}, improvement: {}", + iteration, + last_dist, + total_dist, + last_dist - total_dist + ); + // TODO(hicder): Make 0.0005 a parameter + cluster_labels = cluster_labels_with_min_cost + .iter() + .map(|(label, _)| *label) + .collect(); + if cluster_labels == last_labels || iteration >= self.max_iter { + debug!( + "Converged at iteration {}, improvement: {}", + iteration, + total_dist - last_dist + ); break; } + last_dist = total_dist; + iteration += 1; } Ok(KMeansResult { centroids: centroids, assignments: cluster_labels, + error: last_dist, }) } } @@ -311,13 +394,13 @@ mod tests { 1e-4, 2, KMeansVariant::Lloyd, - vec![0, 0, 0, 1, 1, 1, 2, 2, 2], + vec![0, 1, 2], ); let result = kmeans .fit(flattened_data) .expect("KMeans run should succeed"); - assert_eq!(kmeans.num_cluters, 3); + assert_eq!(kmeans.num_clusters, 3); assert_eq!(kmeans.max_iter, 100); assert_eq!(kmeans.tolerance, 1e-4); assert_eq!(kmeans.dimension, 2); @@ -332,50 +415,6 @@ mod tests { assert_eq!(result.assignments[2], result.assignments[8]); } - #[test] - fn test_kmeans_lloyd_really_large_penalty() { - // This test tests the fact that, point (5.0, 5.0) is assigned to cluster 2 even though - // it is supposed to be assigned to cluster 1. The penalty for unbalancing a cluster is - // extremely large, which forces the point to be reassigned to a different cluster. 
- let data = vec![ - vec![0.0, 0.0], - vec![40.0, 40.0], - vec![90.0, 90.0], - vec![1.0, 1.0], - vec![41.0, 41.0], - vec![91.0, 91.0], - vec![2.0, 2.0], - vec![5.0, 5.0], - vec![92.0, 92.0], - ]; - - let flattened_data = data - .iter() - .map(|x| x.as_slice()) - .flatten() - .cloned() - .collect(); - let kmeans = KMeansBuilder::new_with_cluster_init_values( - 3, - 100, - 10000.0, - 2, - KMeansVariant::Lloyd, - vec![0, 0, 0, 0, 0, 0, 2, 2, 2], - ); - let result = kmeans - .fit(flattened_data) - .expect("KMeans run should succeed"); - - assert_eq!(result.centroids.len(), 3 * 2); - assert_eq!(result.assignments[0], result.assignments[3]); - assert_eq!(result.assignments[0], result.assignments[6]); - assert_eq!(result.assignments[1], result.assignments[4]); - assert_eq!(result.assignments[1], result.assignments[7]); - assert_eq!(result.assignments[2], result.assignments[5]); - assert_eq!(result.assignments[2], result.assignments[8]); - } - #[test] fn test_kmeans_no_distance_penalty() { let data = vec![ @@ -402,7 +441,7 @@ mod tests { 0.0, 2, KMeansVariant::Lloyd, - vec![0, 0, 0, 1, 1, 1, 2, 2, 2], + vec![0, 1, 2], ); let result = kmeans diff --git a/rs/utils/src/lib.rs b/rs/utils/src/lib.rs index 8088ff7..3185c7f 100644 --- a/rs/utils/src/lib.rs +++ b/rs/utils/src/lib.rs @@ -23,3 +23,7 @@ pub trait DistanceCalculator { pub trait CalculateSquared { fn calculate_squared(a: &[f32], b: &[f32]) -> f32; } + +pub fn ceil_div(a: usize, b: usize) -> usize { + (a + b - 1) / b +}
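
Note on the cluster sizing introduced by this change: cluster_docs now derives the number of local clusters from the posting-list cap via the new ceil_div helper in rs/utils/src/lib.rs, then samples at least ten points per cluster (or num_data_points_for_clustering, whichever is larger) before running k-means. A minimal standalone sketch of that arithmetic; doc_count and sample_size are illustrative names, not identifiers from the crate:

/// Round-up integer division, matching the helper added to rs/utils/src/lib.rs.
fn ceil_div(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

fn main() {
    // A 10_500-doc posting list with a 1_000-doc cap needs 11 local clusters.
    let doc_count = 10_500;
    let max_posting_list_size = 1_000;
    let num_clusters = ceil_div(doc_count, max_posting_list_size);
    assert_eq!(num_clusters, 11);

    // The builder then samples at least 10 points per cluster, or the configured
    // num_data_points_for_clustering, whichever is larger, before fitting k-means.
    let num_data_points_for_clustering = 5_000;
    let sample_size = std::cmp::max(num_clusters * 10, num_data_points_for_clustering);
    assert_eq!(sample_size, 5_000);
    println!("clusters: {num_clusters}, sample size: {sample_size}");
}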
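
build_centroids now keeps the posting lists in a max-heap ordered by length and repeatedly re-clusters the longest one until every list fits under max_posting_list_size. A stripped-down sketch of that control flow, assuming a stand-in PostingList type and a split closure in place of the real cluster_docs call:

use std::cmp::Ordering;
use std::collections::BinaryHeap;

// Ordered by length so the heap always pops the longest posting list first,
// mirroring the PostingListInfo heap in IvfBuilder::build_centroids.
struct PostingList {
    doc_ids: Vec<u64>,
}

impl PartialEq for PostingList {
    fn eq(&self, other: &Self) -> bool {
        self.doc_ids.len() == other.doc_ids.len()
    }
}

impl Eq for PostingList {}

impl Ord for PostingList {
    fn cmp(&self, other: &Self) -> Ordering {
        self.doc_ids.len().cmp(&other.doc_ids.len())
    }
}

impl PartialOrd for PostingList {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Split the longest posting list until every list fits under the cap.
/// `split` stands in for the local k-means pass (`cluster_docs` in the diff),
/// which yields roughly ceil_div(len, cap) smaller lists each time.
fn balance(
    initial: Vec<PostingList>,
    max_posting_list_size: usize,
    split: impl Fn(&PostingList) -> Vec<PostingList>,
) -> Vec<PostingList> {
    let mut heap: BinaryHeap<PostingList> = initial.into_iter().collect();
    while let Some(longest) = heap.peek() {
        if longest.doc_ids.len() <= max_posting_list_size {
            break; // the longest list fits, so every list fits
        }
        let longest = heap.pop().unwrap();
        for piece in split(&longest) {
            heap.push(piece); // oversized pieces get split again later
        }
    }
    heap.into_vec()
}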
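
The rewritten run_lloyd assigns each point to the centroid minimizing its squared distance plus tolerance times the current cluster size, recomputes centroids, and re-seeds any empty cluster from a point borrowed out of a larger one; the real code does the assignment in parallel with rayon and accumulates centroids with std::simd. Below is a much-simplified single-threaded sketch of the size-penalized Lloyd iteration only; kmeans_with_size_penalty is an illustrative name, not the crate's API, and it seeds centroids from the first points instead of a random sample:

fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Size-penalized Lloyd iterations: each point pays its squared distance to a
/// centroid plus `tolerance * cluster_size`, so crowded clusters become less
/// attractive. `points` is row-major (num_points x dim) and is assumed to hold
/// at least `num_clusters` points; the first ones seed the centroids.
fn kmeans_with_size_penalty(
    points: &[f32],
    dim: usize,
    num_clusters: usize,
    tolerance: f32,
    max_iter: usize,
) -> (Vec<f32>, Vec<usize>) {
    let mut centroids = points[..num_clusters * dim].to_vec();
    let mut labels: Vec<usize> = Vec::new();
    let mut sizes = vec![0usize; num_clusters];
    for _ in 0..max_iter {
        // Penalties come from the cluster sizes of the previous iteration.
        let penalties: Vec<f32> = sizes.iter().map(|&s| tolerance * s as f32).collect();
        let new_labels: Vec<usize> = points
            .chunks_exact(dim)
            .map(|p| {
                centroids
                    .chunks_exact(dim)
                    .enumerate()
                    .map(|(c, centroid)| (c, squared_l2(p, centroid) + penalties[c]))
                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                    .map(|(c, _)| c)
                    .unwrap()
            })
            .collect();
        if new_labels == labels {
            break; // assignments stopped changing
        }
        labels = new_labels;
        // Recompute each centroid as the mean of its assigned points.
        sizes = vec![0usize; num_clusters];
        centroids.iter_mut().for_each(|x| *x = 0.0);
        for (p, &label) in points.chunks_exact(dim).zip(&labels) {
            sizes[label] += 1;
            for (c, x) in centroids[label * dim..(label + 1) * dim].iter_mut().zip(p) {
                *c += *x;
            }
        }
        for (c, &size) in centroids.chunks_exact_mut(dim).zip(&sizes) {
            if size > 0 {
                c.iter_mut().for_each(|x| *x /= size as f32);
            }
        }
    }
    (centroids, labels)
}

With tolerance set to 0.0 the penalty vanishes and this reduces to plain Lloyd's, which matches the guard the patch adds around the penalty computation.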
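
Spann::search now prunes the centroids returned by the HNSW pass, keeping only those whose score is within 10% of the closest one before probing posting lists. A small sketch of that filter, assuming lower scores mean closer centroids; prune_centroids is an illustrative helper, not part of the index crate:

/// Keep only centroids whose distance is within 10% of the closest one, as the
/// updated Spann::search does before searching posting lists (the 0.1 factor is
/// hard-coded there).
fn prune_centroids(scored: Vec<(usize, f32)>) -> Vec<usize> {
    let nearest = scored.iter().map(|(_, d)| *d).fold(f32::INFINITY, f32::min);
    scored
        .into_iter()
        .filter(|(_, d)| *d - nearest < nearest * 0.1)
        .map(|(id, _)| id)
        .collect()
}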