From 0c096c225424a0aaa48966eb535eab23c2faae05 Mon Sep 17 00:00:00 2001 From: iamthinh Date: Fri, 3 Jan 2025 23:02:47 +0700 Subject: [PATCH] Dot-product implementation and distance template (#237) * add distance config for index and quantizer * distance template for IVF and kmeans index * move lane_conforming to another file and update code according to suggested comments --- .gitignore | 2 +- rs/index/src/ivf/builder.rs | 33 +-- rs/index/src/ivf/index.rs | 39 +++- rs/index/src/ivf/reader.rs | 16 +- rs/index/src/ivf/writer.rs | 46 ++-- rs/index/src/spann/builder.rs | 5 +- rs/index/src/spann/index.rs | 10 +- rs/index/src/spann/reader.rs | 4 +- rs/index/src/spann/writer.rs | 3 +- rs/index_writer/src/config.rs | 9 + rs/index_writer/src/index_writer.rs | 63 +++++- rs/index_writer/src/input/hdf5.rs | 3 +- rs/proto/src/muopdb.rs | 214 +++++++----------- rs/utils/benches/dot_product.rs | 6 +- rs/utils/benches/kmeans.rs | 3 +- rs/utils/src/distance/dot_product.rs | 37 ++- rs/utils/src/distance/l2.rs | 19 +- rs/utils/src/distance/lane_conforming.rs | 51 +++++ rs/utils/src/distance/mod.rs | 3 +- rs/utils/src/kmeans_builder/kmeans_builder.rs | 30 ++- rs/utils/src/lib.rs | 6 + rs/utils/src/scripts/run_kmeans.rs | 4 +- 22 files changed, 356 insertions(+), 250 deletions(-) create mode 100644 rs/utils/src/distance/lane_conforming.rs diff --git a/.gitignore b/.gitignore index 9c9eb81..fd24bc3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,6 @@ flamegraph.svg perf.data* .idea -.DS_store +.DS_Store .venv */__pycache__/ diff --git a/rs/index/src/ivf/builder.rs b/rs/index/src/ivf/builder.rs index 8f2b754..2c412fd 100644 --- a/rs/index/src/ivf/builder.rs +++ b/rs/index/src/ivf/builder.rs @@ -2,6 +2,7 @@ use std::cmp::{max, min, Ordering, Reverse}; use std::collections::{BinaryHeap, HashMap}; use std::fs::{create_dir, create_dir_all}; use std::io::ErrorKind; +use std::marker::PhantomData; use anyhow::{anyhow, Result}; use atomic_refcell::AtomicRefCell; @@ -38,12 +39,13 @@ pub struct 
IvfBuilderConfig { pub max_posting_list_size: usize, } -pub struct IvfBuilder { +pub struct IvfBuilder { config: IvfBuilderConfig, vectors: AtomicRefCell + Send + Sync>>, centroids: AtomicRefCell + Send + Sync>>, posting_lists: Box PostingListStorage<'a>>, doc_id_mapping: Vec, + _marker: PhantomData, } // TODO(tyb): maybe merge with HNSW's one @@ -141,7 +143,7 @@ impl PartialEq for PostingListWithStoppingPoints { impl Eq for PostingListWithStoppingPoints {} -impl IvfBuilder { +impl IvfBuilder { /// Create a new IvfBuilder pub fn new(config: IvfBuilderConfig) -> Result { // Create the base directory and all parent directories if they don't exist @@ -184,6 +186,7 @@ impl IvfBuilder { centroids, posting_lists, doc_id_mapping: Vec::new(), + _marker: PhantomData, }) } @@ -245,7 +248,7 @@ impl IvfBuilder { let mut centroid_index = 0; for i in 0..flattened_centroids.len() / dimension { let centroid = &flattened_centroids[i * dimension..(i + 1) * dimension]; - let dist = L2DistanceCalculator::calculate(&vector, &centroid); + let dist = D::calculate(&vector, &centroid); if dist > max_distance { max_distance = dist; centroid_index = i; @@ -415,7 +418,7 @@ impl IvfBuilder { num_clusters * 10, self.config.num_data_points_for_clustering, ); - let kmeans = KMeansBuilder::new( + let kmeans = KMeansBuilder::::new( num_clusters, self.config.max_iteration, self.config.tolerance, @@ -453,7 +456,7 @@ impl IvfBuilder { self.config.num_clusters, self.config.max_posting_list_size, ); - let kmeans = KMeansBuilder::new( + let kmeans = KMeansBuilder::::new( num_clusters, self.config.max_iteration, self.config.tolerance, @@ -789,7 +792,7 @@ mod tests { let file_size = 4096; let balance_factor = 0.0; let max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -853,7 +856,7 @@ mod tests { let file_size = 4096; let balance_factor = 0.0; let
max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -891,7 +894,7 @@ mod tests { let file_size = 4096; let balance_factor = 0.0; let max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -947,7 +950,7 @@ mod tests { let file_size = 4096; let balance_factor = 0.0; let max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -1014,7 +1017,7 @@ mod tests { let file_size = 4096 * 4096; let balance_factor = 0.0; let max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -1087,7 +1090,7 @@ mod tests { let file_size = 4096 * 4096; let balance_factor = 0.0; let max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -1160,7 +1163,7 @@ mod tests { let file_size = 4096 * 4096; let balance_factor = 0.0; let max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -1255,7 +1258,7 @@ mod tests { let file_size = 4096 * 4096; let balance_factor = 0.0; let max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ 
-1333,7 +1336,7 @@ mod tests { let balance_factor = 0.0; let max_posting_list_size = usize::MAX; const NUM_VECTORS: usize = 22; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -1408,7 +1411,7 @@ mod tests { let file_size = 4096; let balance_factor = 0.0; let max_posting_list_size = usize::MAX; - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, diff --git a/rs/index/src/ivf/index.rs b/rs/index/src/ivf/index.rs index 6d2e8f8..1612e69 100644 --- a/rs/index/src/ivf/index.rs +++ b/rs/index/src/ivf/index.rs @@ -1,9 +1,9 @@ use std::collections::BinaryHeap; +use std::marker::PhantomData; use anyhow::{Context, Result}; use quantization::quantization::Quantizer; use quantization::typing::VectorOps; -use utils::distance::l2::L2DistanceCalculator; use utils::distance::l2::L2DistanceCalculatorImpl::StreamingSIMD; use utils::mem::transmute_u8_to_slice; use utils::DistanceCalculator; @@ -13,7 +13,7 @@ use crate::posting_list::combined_file::FixedIndexFile; use crate::utils::{IdWithScore, SearchContext}; use crate::vector::fixed_file::FixedFileVectorStorage; -pub struct Ivf { +pub struct Ivf { // The dataset. 
pub vector_storage: FixedFileVectorStorage, @@ -29,9 +29,11 @@ pub struct Ivf { pub num_clusters: usize, pub quantizer: Q, + + _marker: PhantomData, } -impl Ivf { +impl Ivf { pub fn new( vector_storage: FixedFileVectorStorage, index_storage: FixedIndexFile, @@ -43,6 +45,7 @@ impl Ivf { index_storage, num_clusters, quantizer, + _marker: PhantomData, } } @@ -56,7 +59,7 @@ impl Ivf { let centroid = index_storage .get_centroid(i as usize) .with_context(|| format!("Failed to get centroid at index {}", i))?; - let dist = L2DistanceCalculator::calculate(&vector, &centroid); + let dist = D::calculate(&vector, &centroid); distances.push((i as usize, dist)); } distances.select_nth_unstable_by(num_probes - 1, |a, b| a.1.total_cmp(&b.1)); @@ -146,7 +149,7 @@ impl Ivf { } } -impl Searchable for Ivf { +impl Searchable for Ivf { fn search( &self, query: &[f32], @@ -180,6 +183,7 @@ mod tests { use num_traits::ops::bytes::ToBytes; use quantization::noq::noq::NoQuantizer; use quantization::pq::pq::ProductQuantizer; + use utils::distance::l2::L2DistanceCalculator; use utils::mem::transmute_slice_to_u8; use super::*; @@ -337,7 +341,12 @@ mod tests { let num_clusters = 2; let quantizer = NoQuantizer::new(3); - let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer); + let ivf = Ivf::::new( + storage, + index_storage, + num_clusters, + quantizer, + ); assert_eq!(ivf.num_clusters, num_clusters); let cluster_0 = transmute_u8_to_slice::( @@ -383,9 +392,12 @@ mod tests { FixedIndexFile::new(file_path).expect("FixedIndexFile should be created"); let num_probes = 2; - let nearest = - Ivf::::find_nearest_centroids(&vector, &index_storage, num_probes) - .expect("Nearest centroids should be found"); + let nearest = Ivf::::find_nearest_centroids( + &vector, + &index_storage, + num_probes, + ) + .expect("Nearest centroids should be found"); assert_eq!(nearest[0], 1); assert_eq!(nearest[1], 0); @@ -431,7 +443,8 @@ mod tests { let num_probes = 2; let quantizer = NoQuantizer::new(num_features); -
let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer); + let ivf: Ivf = + Ivf::new(storage, index_storage, num_clusters, quantizer); let query = vec![2.0, 3.0, 4.0]; let k = 2; @@ -505,7 +518,8 @@ mod tests { let num_clusters = 2; let num_probes = 2; - let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer); + let ivf: Ivf = + Ivf::new(storage, index_storage, num_clusters, quantizer); let query = vec![2.0, 3.0, 4.0]; let k = 2; @@ -557,7 +571,8 @@ mod tests { let num_probes = 1; let quantizer = NoQuantizer::new(num_features); - let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer); + let ivf: Ivf = + Ivf::new(storage, index_storage, num_clusters, quantizer); let query = vec![1.0, 2.0, 3.0]; let k = 5; // More than available results diff --git a/rs/index/src/ivf/reader.rs b/rs/index/src/ivf/reader.rs index 6bf9075..ad4a201 100644 --- a/rs/index/src/ivf/reader.rs +++ b/rs/index/src/ivf/reader.rs @@ -1,5 +1,6 @@ use anyhow::Result; use quantization::quantization::Quantizer; +use utils::DistanceCalculator; use crate::ivf::index::Ivf; use crate::posting_list::combined_file::FixedIndexFile; @@ -14,7 +15,7 @@ impl IvfReader { Self { base_directory } } - pub fn read(&self) -> Result> { + pub fn read(&self) -> Result> { let index_storage = FixedIndexFile::new(format!("{}/index", self.base_directory))?; let vector_storage_path = format!("{}/vectors", self.base_directory); @@ -45,6 +46,7 @@ mod tests { use compression::noc::noc::PlainEncoder; use quantization::noq::noq::{NoQuantizer, NoQuantizerWriter}; use tempdir::TempDir; + use utils::distance::l2::L2DistanceCalculator; use utils::mem::transmute_u8_to_slice; use utils::test_utils::generate_random_vector; @@ -68,9 +70,9 @@ mod tests { let num_features = 4; let file_size = 4096; let quantizer = NoQuantizer::new(num_features); - let writer = IvfWriter::<_, PlainEncoder>::new(base_directory.clone(), quantizer); + let writer = IvfWriter::<_, PlainEncoder, 
L2DistanceCalculator>::new(base_directory.clone(), quantizer); - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -105,7 +107,7 @@ mod tests { let reader = IvfReader::new(base_directory.clone()); let index = reader - .read::() + .read::() .expect("Failed to read index file"); // Check if files were created @@ -212,9 +214,9 @@ mod tests { let noq_writer = NoQuantizerWriter::new(quantizer_directory); assert!(noq_writer.write(&quantizer).is_ok()); - let writer = IvfWriter::<_, PlainEncoder>::new(base_directory.clone(), quantizer); + let writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(base_directory.clone(), quantizer); - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -241,7 +243,7 @@ mod tests { let reader = IvfReader::new(base_directory.clone()); let index = reader - .read::() + .read::() .expect("Failed to read index file"); let num_centroids = index.num_clusters; diff --git a/rs/index/src/ivf/writer.rs b/rs/index/src/ivf/writer.rs index 9885462..f8441c3 100644 --- a/rs/index/src/ivf/writer.rs +++ b/rs/index/src/ivf/writer.rs @@ -10,26 +10,39 @@ use num_traits::ToBytes; use quantization::quantization::Quantizer; use quantization::typing::VectorOps; use utils::io::{append_file_to_writer, wrap_write}; +use utils::{CalculateSquared, DistanceCalculator}; use crate::ivf::builder::IvfBuilder; use crate::posting_list::combined_file::{Header, Version}; -pub struct IvfWriter { +pub struct IvfWriter +where + Q: Quantizer, + C: IntSeqEncoder, + D: DistanceCalculator + CalculateSquared + Send + Sync +{ base_directory: String, quantizer: Q, - _marker: PhantomData, + _marker_1: PhantomData, + _marker_2: PhantomData, } -impl IvfWriter { +impl IvfWriter +where + Q: Quantizer, + C: IntSeqEncoder + 'static, + D: 
DistanceCalculator + CalculateSquared + Send + Sync, +{ pub fn new(base_directory: String, quantizer: Q) -> Self { Self { base_directory, quantizer, - _marker: PhantomData, + _marker_1: PhantomData, + _marker_2: PhantomData, } } - pub fn write(&self, ivf_builder: &mut IvfBuilder, reindex: bool) -> Result<()> { + pub fn write(&self, ivf_builder: &mut IvfBuilder, reindex: bool) -> Result<()> { if reindex { // Reindex the vectors for efficient lookup ivf_builder @@ -110,7 +123,7 @@ impl IvfWriter { Ok(()) } - fn quantize_and_write_vectors(&self, ivf_builder: &IvfBuilder) -> Result { + fn quantize_and_write_vectors(&self, ivf_builder: &IvfBuilder) -> Result { // Quantize vectors let full_vectors = &ivf_builder.vectors(); let quantized_vectors_path = format!("{}/quantized", self.base_directory); @@ -142,7 +155,7 @@ impl IvfWriter { Ok(bytes_written) } - fn write_doc_id_mapping(&self, ivf_builder: &IvfBuilder) -> Result { + fn write_doc_id_mapping(&self, ivf_builder: &IvfBuilder) -> Result { let path = format!("{}/doc_id_mapping", self.base_directory); let mut file = File::create(path)?; let mut writer = BufWriter::new(&mut file); @@ -157,7 +170,7 @@ impl IvfWriter { Ok(bytes_written) } - fn write_centroids(&self, ivf_builder: &IvfBuilder) -> Result { + fn write_centroids(&self, ivf_builder: &IvfBuilder) -> Result { let path = format!("{}/centroids", self.base_directory); let mut file = File::create(path)?; let mut writer = BufWriter::new(&mut file); @@ -166,7 +179,7 @@ impl IvfWriter { Ok(bytes_written) } - fn write_posting_lists_and_metadata(&self, ivf_builder: &mut IvfBuilder) -> Result { + fn write_posting_lists_and_metadata(&self, ivf_builder: &mut IvfBuilder) -> Result { let metadata_path = format!("{}/posting_list_metadata", self.base_directory); let mut metadata_file = File::create(metadata_path)?; let mut metadata_writer = BufWriter::new(&mut metadata_file); @@ -300,6 +313,7 @@ mod tests { use quantization::noq::noq::NoQuantizer; use 
quantization::pq::pq::ProductQuantizer; use tempdir::TempDir; + use utils::distance::l2::L2DistanceCalculator; use utils::test_utils::generate_random_vector; use super::*; @@ -326,7 +340,7 @@ mod tests { // Create an IvfWriter instance let num_features = 10; let quantizer = NoQuantizer::new(num_features); - let ivf_writer = IvfWriter::<_, PlainEncoder>::new(base_directory.clone(), quantizer); + let ivf_writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(base_directory.clone(), quantizer); // Create test files create_test_file(&base_directory, "centroids", &[5, 6, 7, 8])?; @@ -427,7 +441,7 @@ mod tests { // Pad to 8-byte alignment let padding_written = - IvfWriter::::write_pad(initial_size, &mut writer, 8) + IvfWriter::::write_pad(initial_size, &mut writer, 8) .unwrap(); assert_eq!(padding_written, 5); // 3 bytes written, so 5 bytes of padding needed @@ -459,9 +473,9 @@ mod tests { let quantizer = ProductQuantizer::new(3, 1, subvector_dimension, codebook, base_directory.clone()) .expect("Can't create product quantizer"); - let ivf_writer = IvfWriter::<_, PlainEncoder>::new(base_directory.clone(), quantizer); + let ivf_writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(base_directory.clone(), quantizer); - let mut ivf_builder = IvfBuilder::new(IvfBuilderConfig { + let mut ivf_builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, @@ -547,7 +561,7 @@ mod tests { let file_size = 4096; let quantizer = NoQuantizer::new(num_features); - let ivf_writer = IvfWriter::<_, EliasFano>::new(base_directory.clone(), quantizer); + let ivf_writer = IvfWriter::<_, EliasFano, L2DistanceCalculator>::new(base_directory.clone(), quantizer); let mut ivf_builder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, @@ -633,9 +647,9 @@ mod tests { let num_features = 4; let file_size = 4096; let quantizer = NoQuantizer::new(num_features); - let writer = IvfWriter::<_, PlainEncoder>::new(base_directory.clone(), 
quantizer); + let writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(base_directory.clone(), quantizer); - let mut builder = IvfBuilder::new(IvfBuilderConfig { + let mut builder: IvfBuilder = IvfBuilder::new(IvfBuilderConfig { max_iteration: 1000, batch_size: 4, num_clusters, diff --git a/rs/index/src/spann/builder.rs b/rs/index/src/spann/builder.rs index 0ed12c5..17fb8d5 100644 --- a/rs/index/src/spann/builder.rs +++ b/rs/index/src/spann/builder.rs @@ -2,6 +2,7 @@ use anyhow::{Ok, Result}; use log::debug; use quantization::noq::noq::NoQuantizer; use serde::{Deserialize, Serialize}; +use utils::distance::l2::L2DistanceCalculator; use crate::hnsw::builder::HnswBuilder; use crate::ivf::builder::{IvfBuilder, IvfBuilderConfig}; @@ -64,13 +65,13 @@ impl Default for SpannBuilderConfig { pub struct SpannBuilder { pub config: SpannBuilderConfig, - pub ivf_builder: IvfBuilder, + pub ivf_builder: IvfBuilder, pub centroid_builder: HnswBuilder, } impl SpannBuilder { pub fn new(config: SpannBuilderConfig) -> Result { - let ivf_builder = IvfBuilder::new(IvfBuilderConfig { + let ivf_builder = IvfBuilder::::new(IvfBuilderConfig { max_iteration: config.max_iteration, batch_size: config.batch_size, num_clusters: config.num_clusters, diff --git a/rs/index/src/spann/index.rs b/rs/index/src/spann/index.rs index f8dd68a..41f4c63 100644 --- a/rs/index/src/spann/index.rs +++ b/rs/index/src/spann/index.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use log::debug; use quantization::noq::noq::NoQuantizer; +use utils::distance::l2::L2DistanceCalculator; use crate::hnsw::index::Hnsw; use crate::index::Searchable; @@ -9,11 +10,14 @@ use crate::ivf::index::Ivf; pub struct Spann { centroids: Hnsw, - posting_lists: Ivf, + posting_lists: Ivf, } impl Spann { - pub fn new(centroids: Hnsw, posting_lists: Ivf) -> Self { + pub fn new( + centroids: Hnsw, + posting_lists: Ivf, + ) -> Self { Self { centroids, posting_lists, @@ -24,7 +28,7 @@ impl Spann { &self.centroids } - pub fn 
get_posting_lists(&self) -> &Ivf { + pub fn get_posting_lists(&self) -> &Ivf { &self.posting_lists } } diff --git a/rs/index/src/spann/reader.rs b/rs/index/src/spann/reader.rs index 68f7ae0..3486659 100644 --- a/rs/index/src/spann/reader.rs +++ b/rs/index/src/spann/reader.rs @@ -1,5 +1,6 @@ use anyhow::Result; use quantization::noq::noq::NoQuantizer; +use utils::distance::l2::L2DistanceCalculator; use super::index::Spann; use crate::hnsw::reader::HnswReader; @@ -19,7 +20,8 @@ impl SpannReader { let centroid_path = format!("{}/centroids", self.base_directory); let centroids = HnswReader::new(centroid_path).read::()?; - let posting_lists = IvfReader::new(posting_list_path).read::()?; + let posting_lists = + IvfReader::new(posting_list_path).read::()?; Ok(Spann::new(centroids, posting_lists)) } diff --git a/rs/index/src/spann/writer.rs b/rs/index/src/spann/writer.rs index 1cd1187..4428617 100644 --- a/rs/index/src/spann/writer.rs +++ b/rs/index/src/spann/writer.rs @@ -2,6 +2,7 @@ use anyhow::Result; use compression::noc::noc::PlainEncoder; use log::debug; use quantization::noq::noq::{NoQuantizer, NoQuantizerWriter}; +use utils::distance::l2::L2DistanceCalculator; use super::builder::SpannBuilder; use crate::hnsw::writer::HnswWriter; @@ -51,7 +52,7 @@ impl SpannWriter { ivf_quantizer_writer.write(&ivf_quantizer)?; debug!("Writing IVF index"); - let ivf_writer = IvfWriter::<_, PlainEncoder>::new(ivf_directory, ivf_quantizer); + let ivf_writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(ivf_directory, ivf_quantizer); ivf_writer.write(&mut spann_builder.ivf_builder, index_writer_config.reindex)?; spann_builder.ivf_builder.cleanup()?; debug!("Finish writing IVF index"); diff --git a/rs/index_writer/src/config.rs b/rs/index_writer/src/config.rs index 5e59669..6530f67 100644 --- a/rs/index_writer/src/config.rs +++ b/rs/index_writer/src/config.rs @@ -8,6 +8,13 @@ pub enum QuantizerType { NoQuantizer, } +#[derive(Debug, Clone, Default, Serialize, Deserialize, 
PartialEq)] +pub enum DistanceType { + DotProduct, + #[default] + L2, +} + #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] pub enum IndexType { Hnsw, @@ -28,6 +35,7 @@ pub struct BaseConfig { pub file_size: usize, pub index_type: IndexType, + pub index_distance_type: DistanceType, } #[derive(Debug, Clone, Default, Serialize, Deserialize)] @@ -37,6 +45,7 @@ pub struct QuantizerConfig { pub subvector_dimension: usize, pub num_bits: u8, pub num_training_rows: usize, + pub quantizer_distance_type: DistanceType, // Quantizer builder parameters pub max_iteration: usize, diff --git a/rs/index_writer/src/index_writer.rs b/rs/index_writer/src/index_writer.rs index 5c71ce7..66f7dd0 100644 --- a/rs/index_writer/src/index_writer.rs +++ b/rs/index_writer/src/index_writer.rs @@ -13,9 +13,13 @@ use quantization::pq::pq::{ProductQuantizer, ProductQuantizerConfig, ProductQuan use quantization::pq::pq_builder::{ProductQuantizerBuilder, ProductQuantizerBuilderConfig}; use quantization::quantization::Quantizer; use rand::seq::SliceRandom; +use utils::distance::dot_product::DotProductDistanceCalculator; +use utils::distance::l2::L2DistanceCalculator; +use utils::{CalculateSquared, DistanceCalculator}; use crate::config::{ - HnswConfigWithBase, IndexWriterConfig, IvfConfigWithBase, QuantizerType, SpannConfigWithBase, + DistanceType, HnswConfigWithBase, IndexWriterConfig, IvfConfigWithBase, QuantizerType, + SpannConfigWithBase, }; use crate::input::Input; @@ -183,13 +187,18 @@ impl IndexWriter { Ok(()) } - fn write_quantizer_and_build_ivf_index Result<()>>( + fn write_quantizer_and_build_ivf_index( &mut self, input: &mut impl Input, index_builder_config: &IvfConfigWithBase, quantizer: T, writer_fn: F, - ) -> Result<()> { + ) -> Result<()> + where + T: Quantizer, + D: DistanceCalculator + CalculateSquared + Send + Sync, + F: Fn(&String, &T) -> Result<()>, + { info!("Start writing product quantizer"); let path = &self.output_root; @@ -199,7 +208,7 @@ impl IndexWriter { 
// Use the provided writer function to write the quantizer writer_fn(&quantizer_directory, &quantizer)?; - let mut ivf_builder = IvfBuilder::new(IvfBuilderConfig { + let mut ivf_builder = IvfBuilder::::new(IvfBuilderConfig { max_iteration: index_builder_config.ivf_config.max_iteration, batch_size: index_builder_config.ivf_config.batch_size, num_clusters: index_builder_config.ivf_config.num_clusters, @@ -229,7 +238,7 @@ impl IndexWriter { std::fs::create_dir_all(&path)?; info!("Start writing index"); - let ivf_writer = IvfWriter::<_, PlainEncoder>::new(path.to_string(), quantizer); + let ivf_writer = IvfWriter::<_, PlainEncoder, D>::new(path.to_string(), quantizer); ivf_writer.write(&mut ivf_builder, index_builder_config.base_config.reindex)?; // Cleanup tmp directory. It's ok to fail @@ -276,7 +285,22 @@ impl IndexWriter { pq_writer.write(pq) }; - self.write_quantizer_and_build_ivf_index(input, index_builder_config, pq, pq_writer_fn) + match index_builder_config.base_config.index_distance_type { + DistanceType::DotProduct => self + .write_quantizer_and_build_ivf_index::<_, DotProductDistanceCalculator, _>( + input, + index_builder_config, + pq, + pq_writer_fn, + ), + DistanceType::L2 => self + .write_quantizer_and_build_ivf_index::<_, L2DistanceCalculator, _>( + input, + index_builder_config, + pq, + pq_writer_fn, + ), + } } fn build_ivf_noq( @@ -299,7 +323,22 @@ impl IndexWriter { noq_writer.write(noq) }; - self.write_quantizer_and_build_ivf_index(input, index_builder_config, noq, noq_writer_fn) + match index_builder_config.base_config.index_distance_type { + DistanceType::DotProduct => self + .write_quantizer_and_build_ivf_index::<_, DotProductDistanceCalculator, _>( + input, + index_builder_config, + noq, + noq_writer_fn, + ), + DistanceType::L2 => self + .write_quantizer_and_build_ivf_index::<_, L2DistanceCalculator, _>( + input, + index_builder_config, + noq, + noq_writer_fn, + ), + } } fn do_build_ivf_index( @@ -430,7 +469,9 @@ mod tests { use 
tempdir::TempDir; use super::*; - use crate::config::{BaseConfig, HnswConfig, IndexType, IvfConfig, QuantizerConfig}; + use crate::config::{ + BaseConfig, DistanceType, HnswConfig, IndexType, IvfConfig, QuantizerConfig, + }; use crate::input::Row; // Mock Input implementation for testing struct MockInput { @@ -514,9 +555,11 @@ mod tests { max_memory_size: 1024 * 1024 * 1024, // 1 GB file_size: 1024 * 1024 * 1024, // 1 GB index_type: IndexType::Hnsw, + index_distance_type: DistanceType::L2, }; let quantizer_config = QuantizerConfig { quantizer_type: QuantizerType::ProductQuantizer, + quantizer_distance_type: DistanceType::L2, subvector_dimension: 2, num_bits: 2, num_training_rows: 50, @@ -585,9 +628,11 @@ mod tests { max_memory_size: 1024 * 1024 * 1024, // 1 GB file_size: 1024 * 1024 * 1024, // 1 GB index_type: IndexType::Ivf, + index_distance_type: DistanceType::DotProduct, }; let quantizer_config = QuantizerConfig { quantizer_type: QuantizerType::ProductQuantizer, + quantizer_distance_type: DistanceType::L2, subvector_dimension: 2, num_bits: 2, num_training_rows: 50, @@ -658,9 +703,11 @@ mod tests { max_memory_size: 1024 * 1024 * 1024, // 1 GB file_size: 1024 * 1024 * 1024, // 1 GB index_type: IndexType::Spann, + index_distance_type: DistanceType::L2, }; let quantizer_config = QuantizerConfig { quantizer_type: QuantizerType::ProductQuantizer, + quantizer_distance_type: DistanceType::L2, subvector_dimension: 2, num_bits: 2, num_training_rows: 50, diff --git a/rs/index_writer/src/input/hdf5.rs b/rs/index_writer/src/input/hdf5.rs index ecd300b..c30d0ff 100644 --- a/rs/index_writer/src/input/hdf5.rs +++ b/rs/index_writer/src/input/hdf5.rs @@ -131,7 +131,8 @@ mod tests { flattened_dataset.extend_from_slice(row.data); } - let kmeans = KMeansBuilder::new(10, 10000, 0.0, 128, KMeansVariant::Lloyd); + let kmeans = + KMeansBuilder::::new(10, 10000, 0.0, 128, KMeansVariant::Lloyd); let result = kmeans .fit(flattened_dataset.clone()) .expect("Failed to run KMeans model"); 
diff --git a/rs/proto/src/muopdb.rs b/rs/proto/src/muopdb.rs index d0471dc..4b78a3a 100644 --- a/rs/proto/src/muopdb.rs +++ b/rs/proto/src/muopdb.rs @@ -93,8 +93,8 @@ pub struct InsertBinaryResponse {} /// Generated client implementations. pub mod aggregator_client { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::*; use tonic::codegen::http::Uri; + use tonic::codegen::*; #[derive(Debug, Clone)] pub struct AggregatorClient { inner: tonic::client::Grpc, @@ -138,9 +138,8 @@ pub mod aggregator_client { >::ResponseBody, >, >, - , - >>::Error: Into + Send + Sync, + >>::Error: + Into + Send + Sync, { AggregatorClient::new(InterceptedService::new(inner, interceptor)) } @@ -163,15 +162,12 @@ pub mod aggregator_client { &mut self, request: impl tonic::IntoRequest, ) -> Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static("/muopdb.Aggregator/Get"); self.inner.unary(request.into_request(), path, codec).await @@ -181,8 +177,8 @@ pub mod aggregator_client { /// Generated client implementations. 
pub mod index_server_client { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::*; use tonic::codegen::http::Uri; + use tonic::codegen::*; #[derive(Debug, Clone)] pub struct IndexServerClient { inner: tonic::client::Grpc, @@ -226,9 +222,8 @@ pub mod index_server_client { >::ResponseBody, >, >, - , - >>::Error: Into + Send + Sync, + >>::Error: + Into + Send + Sync, { IndexServerClient::new(InterceptedService::new(inner, interceptor)) } @@ -251,38 +246,28 @@ pub mod index_server_client { &mut self, request: impl tonic::IntoRequest, ) -> Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/muopdb.IndexServer/Search", - ); + let path = http::uri::PathAndQuery::from_static("/muopdb.IndexServer/Search"); self.inner.unary(request.into_request(), path, codec).await } pub async fn insert( &mut self, request: impl tonic::IntoRequest, ) -> Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/muopdb.IndexServer/Insert", - ); + let path = http::uri::PathAndQuery::from_static("/muopdb.IndexServer/Insert"); self.inner.unary(request.into_request(), path, codec).await } pub async fn insert_binary( @@ -308,15 +293,12 @@ pub mod index_server_client { &mut self, request: impl 
tonic::IntoRequest, ) -> Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static("/muopdb.IndexServer/Flush"); self.inner.unary(request.into_request(), path, codec).await @@ -354,10 +336,7 @@ pub mod aggregator_server { send_compression_encodings: Default::default(), } } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> InterceptedService + pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService where F: tonic::service::Interceptor, { @@ -385,10 +364,7 @@ pub mod aggregator_server { type Response = http::Response; type Error = std::convert::Infallible; type Future = BoxFuture; - fn poll_ready( - &mut self, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: http::Request) -> Self::Future { @@ -397,13 +373,9 @@ pub mod aggregator_server { "/muopdb.Aggregator/Get" => { #[allow(non_camel_case_types)] struct GetSvc(pub Arc); - impl tonic::server::UnaryService - for GetSvc { + impl tonic::server::UnaryService for GetSvc { type Response = super::GetResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request, @@ -420,28 +392,23 @@ pub mod aggregator_server { let inner = inner.0; let method = GetSvc(inner); let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ); + let mut grpc = tonic::server::Grpc::new(codec).apply_compression_config( + 
accept_compression_encodings, + send_compression_encodings, + ); let res = grpc.unary(method, req).await; Ok(res) }; Box::pin(fut) } - _ => { - Box::pin(async move { - Ok( - http::Response::builder() - .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") - .body(empty_body()) - .unwrap(), - ) - }) - } + _ => Box::pin(async move { + Ok(http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap()) + }), } } } @@ -512,10 +479,7 @@ pub mod index_server_server { send_compression_encodings: Default::default(), } } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> InterceptedService + pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService where F: tonic::service::Interceptor, { @@ -543,10 +507,7 @@ pub mod index_server_server { type Response = http::Response; type Error = std::convert::Infallible; type Future = BoxFuture; - fn poll_ready( - &mut self, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: http::Request) -> Self::Future { @@ -555,15 +516,9 @@ pub mod index_server_server { "/muopdb.IndexServer/Search" => { #[allow(non_camel_case_types)] struct SearchSvc(pub Arc); - impl< - T: IndexServer, - > tonic::server::UnaryService - for SearchSvc { + impl tonic::server::UnaryService for SearchSvc { type Response = super::SearchResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request, @@ -580,11 +535,10 @@ pub mod index_server_server { let inner = inner.0; let method = SearchSvc(inner); let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ); + let mut grpc = 
tonic::server::Grpc::new(codec).apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ); let res = grpc.unary(method, req).await; Ok(res) }; @@ -593,15 +547,9 @@ pub mod index_server_server { "/muopdb.IndexServer/Insert" => { #[allow(non_camel_case_types)] struct InsertSvc(pub Arc); - impl< - T: IndexServer, - > tonic::server::UnaryService - for InsertSvc { + impl tonic::server::UnaryService for InsertSvc { type Response = super::InsertResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request, @@ -618,11 +566,10 @@ pub mod index_server_server { let inner = inner.0; let method = InsertSvc(inner); let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ); + let mut grpc = tonic::server::Grpc::new(codec).apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ); let res = grpc.unary(method, req).await; Ok(res) }; @@ -671,13 +618,9 @@ pub mod index_server_server { "/muopdb.IndexServer/Flush" => { #[allow(non_camel_case_types)] struct FlushSvc(pub Arc); - impl tonic::server::UnaryService - for FlushSvc { + impl tonic::server::UnaryService for FlushSvc { type Response = super::FlushResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request, @@ -694,28 +637,23 @@ pub mod index_server_server { let inner = inner.0; let method = FlushSvc(inner); let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ); + let mut grpc = tonic::server::Grpc::new(codec).apply_compression_config( + accept_compression_encodings, + 
send_compression_encodings, + ); let res = grpc.unary(method, req).await; Ok(res) }; Box::pin(fut) } - _ => { - Box::pin(async move { - Ok( - http::Response::builder() - .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") - .body(empty_body()) - .unwrap(), - ) - }) - } + _ => Box::pin(async move { + Ok(http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap()) + }), } } } diff --git a/rs/utils/benches/dot_product.rs b/rs/utils/benches/dot_product.rs index 0c38634..e10fd70 100644 --- a/rs/utils/benches/dot_product.rs +++ b/rs/utils/benches/dot_product.rs @@ -1,7 +1,6 @@ - use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; -use utils::test_utils::generate_random_vector; use utils::distance::dot_product::DotProductDistanceCalculator; +use utils::test_utils::generate_random_vector; use utils::DistanceCalculator; fn benches_dot_product(c: &mut Criterion) { @@ -14,7 +13,8 @@ fn benches_dot_product(c: &mut Criterion) { 1536, // VECTOR_DIM_OPENAI_SMALL 3072, // VECTOR_DIM_OPENAI_LARGE ] - .iter() { + .iter() + { let a = generate_random_vector(*size); let b = generate_random_vector(*size); diff --git a/rs/utils/benches/kmeans.rs b/rs/utils/benches/kmeans.rs index 365645e..bc03592 100644 --- a/rs/utils/benches/kmeans.rs +++ b/rs/utils/benches/kmeans.rs @@ -1,4 +1,5 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use utils::distance::l2::L2DistanceCalculator; use utils::kmeans_builder::kmeans_builder; fn bench_kmeans(c: &mut Criterion) { @@ -12,7 +13,7 @@ fn bench_kmeans(c: &mut Criterion) { } } - let kmeans = kmeans_builder::KMeansBuilder::new( + let kmeans = kmeans_builder::KMeansBuilder::::new( 100, 1000, 0.0, diff --git a/rs/utils/src/distance/dot_product.rs b/rs/utils/src/distance/dot_product.rs index 3c3a2bd..59aa0ad 100644 --- a/rs/utils/src/distance/dot_product.rs 
+++ b/rs/utils/src/distance/dot_product.rs @@ -1,5 +1,8 @@ -use crate::DistanceCalculator; -use std::{ops::AddAssign, simd::{num::SimdFloat, LaneCount, Simd, SupportedLaneCount}}; +use std::ops::AddAssign; +use std::simd::num::SimdFloat; +use std::simd::{LaneCount, Simd, SupportedLaneCount}; + +use crate::{CalculateSquared, DistanceCalculator}; pub struct DotProductDistanceCalculator {} @@ -22,7 +25,13 @@ impl DotProductDistanceCalculator { pub fn neg_score(x: f32) -> f32 { -x } -} +} + +impl CalculateSquared for DotProductDistanceCalculator { + fn calculate_squared(a: &[f32], b: &[f32]) -> f32 { + DotProductDistanceCalculator::calculate(a, b) + } +} impl DistanceCalculator for DotProductDistanceCalculator { #[inline(always)] @@ -30,17 +39,17 @@ impl DistanceCalculator for DotProductDistanceCalculator { let mut res = 0.0; let mut a_vec = a; let mut b_vec = b; - - if a_vec.len() > 16 { - let mut accumulator= Simd::::splat(0.0); + + if a_vec.len() > 16 { + let mut accumulator = Simd::::splat(0.0); Self::accumulate_lanes::<16>(a_vec, b_vec, &mut accumulator); res += accumulator.reduce_sum(); a_vec = a_vec.chunks_exact(16).remainder(); b_vec = b_vec.chunks_exact(16).remainder(); } - if a_vec.len() > 8 { - let mut accumulator= Simd::::splat(0.0); + if a_vec.len() > 8 { + let mut accumulator = Simd::::splat(0.0); Self::accumulate_lanes::<8>(a_vec, b_vec, &mut accumulator); res += accumulator.reduce_sum(); a_vec = a_vec.chunks_exact(8).remainder(); @@ -48,7 +57,7 @@ impl DistanceCalculator for DotProductDistanceCalculator { } if a_vec.len() > 4 { - let mut accumulator= Simd::::splat(0.0); + let mut accumulator = Simd::::splat(0.0); Self::accumulate_lanes::<4>(a_vec, b_vec, &mut accumulator); res += accumulator.reduce_sum(); a_vec = a_vec.chunks_exact(4).remainder(); @@ -66,8 +75,8 @@ impl DistanceCalculator for DotProductDistanceCalculator { a: &[f32], b: &[f32], accumulator: &mut Simd, - ) where - LaneCount: SupportedLaneCount, + ) where + LaneCount: SupportedLaneCount, 
{ a.chunks_exact(LANES) .zip(b.chunks_exact(LANES)) @@ -77,13 +86,17 @@ impl DistanceCalculator for DotProductDistanceCalculator { accumulator.add_assign(a_simd * b_simd); }); } + + #[inline(always)] + fn outermost_op(x: f32) -> f32 { + Self::neg_score(x) + } } #[cfg(test)] mod tests { use super::*; use crate::test_utils::generate_random_vector; - #[test] fn test_dot_product_distance_calculator() { let a = generate_random_vector(128); diff --git a/rs/utils/src/distance/l2.rs b/rs/utils/src/distance/l2.rs index 77108c5..61e8f94 100644 --- a/rs/utils/src/distance/l2.rs +++ b/rs/utils/src/distance/l2.rs @@ -86,23 +86,10 @@ impl DistanceCalculator for L2DistanceCalculator { *acc += diff.mul(diff); }); } -} - -/// Calculator where we know in advance that the dimension of vectors is a multiple of LANES. -/// This skips a bunch of checks and allows for a more efficient implementation. -pub struct LaneConformingL2DistanceCalculator -where - LaneCount: SupportedLaneCount, {} - -impl CalculateSquared for LaneConformingL2DistanceCalculator -where - LaneCount: SupportedLaneCount, -{ + #[inline(always)] - fn calculate_squared(a: &[f32], b: &[f32]) -> f32 { - let mut simd = Simd::::splat(0.0); - L2DistanceCalculator::accumulate_lanes(a, b, &mut simd); - simd.reduce_sum() + fn outermost_op(x: f32) -> f32 { + x } } diff --git a/rs/utils/src/distance/lane_conforming.rs b/rs/utils/src/distance/lane_conforming.rs new file mode 100644 index 0000000..98565fb --- /dev/null +++ b/rs/utils/src/distance/lane_conforming.rs @@ -0,0 +1,51 @@ +use std::simd::num::SimdFloat; +use std::{marker::PhantomData, simd::{LaneCount, Simd, SupportedLaneCount}}; + +use crate::{CalculateSquared, DistanceCalculator}; + +/// Calculator where we know in advance that the dimension of vectors is a multiple of LANES. +/// This skips a bunch of checks and allows for a more efficient implementation. 
+pub struct LaneConformingDistanceCalculator +where + LaneCount: SupportedLaneCount, +{ + _marker: PhantomData, +} + +impl CalculateSquared + for LaneConformingDistanceCalculator +where + LaneCount: SupportedLaneCount, +{ + #[inline(always)] + fn calculate_squared(a: &[f32], b: &[f32]) -> f32 { + let mut simd = Simd::::splat(0.0); + D::accumulate_lanes(a, b, &mut simd); + D::outermost_op(simd.reduce_sum()) + } +} +#[cfg(test)] +mod tests { + use crate::{distance::{dot_product::DotProductDistanceCalculator, l2::L2DistanceCalculator}, test_utils::generate_random_vector}; + use super::*; + + #[test] + fn test_calculate_l2_distance() { + let a = generate_random_vector(16); + let b = generate_random_vector(16); + let eps = 1e-5; + let conforming_result = LaneConformingDistanceCalculator::<4, L2DistanceCalculator>::calculate_squared(&a, &b); + let l2_result = L2DistanceCalculator::calculate_squared(&a, &b); + assert!((conforming_result - l2_result).abs() < eps) + } + + #[test] + fn test_calculate_dot_product_distance() { + let a = generate_random_vector(16); + let b = generate_random_vector(16); + let eps = 1e-5; + let conforming_result = LaneConformingDistanceCalculator::<4, DotProductDistanceCalculator>::calculate_squared(&a, &b); + let dot_product_result = DotProductDistanceCalculator::calculate_squared(&a, &b); + assert!((conforming_result - dot_product_result).abs() < eps) + } +} diff --git a/rs/utils/src/distance/mod.rs b/rs/utils/src/distance/mod.rs index bc60355..2c8b563 100644 --- a/rs/utils/src/distance/mod.rs +++ b/rs/utils/src/distance/mod.rs @@ -1,2 +1,3 @@ -pub mod l2; pub mod dot_product; +pub mod l2; +pub mod lane_conforming; diff --git a/rs/utils/src/kmeans_builder/kmeans_builder.rs b/rs/utils/src/kmeans_builder/kmeans_builder.rs index 6ee271b..6d0a80e 100644 --- a/rs/utils/src/kmeans_builder/kmeans_builder.rs +++ b/rs/utils/src/kmeans_builder/kmeans_builder.rs @@ -1,3 +1,4 @@ +use std::marker::PhantomData; use std::cmp::min; use std::simd::{LaneCount, 
Simd, SupportedLaneCount}; @@ -8,15 +9,15 @@ use rand::seq::SliceRandom; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use rayon::slice::ParallelSlice; -use crate::distance::l2::{L2DistanceCalculator, LaneConformingL2DistanceCalculator}; -use crate::CalculateSquared; +use crate::distance::lane_conforming::LaneConformingDistanceCalculator; +use crate::{CalculateSquared, DistanceCalculator}; #[derive(PartialEq, Debug)] pub enum KMeansVariant { Lloyd, } -pub struct KMeansBuilder { +pub struct KMeansBuilder { pub num_clusters: usize, pub max_iter: usize, @@ -30,6 +31,8 @@ pub struct KMeansBuilder { pub variant: KMeansVariant, pub cluster_init_values: Option>, + + _marker: PhantomData, } pub struct KMeansResult { @@ -41,7 +44,7 @@ pub struct KMeansResult { // TODO(hicder): Add support for different variants of k-means. // TODO(hicder): Add support for different distance metrics. -impl KMeansBuilder { +impl KMeansBuilder { pub fn new( num_cluters: usize, max_iter: usize, @@ -56,6 +59,7 @@ impl KMeansBuilder { dimension, variant, cluster_init_values: None, + _marker: PhantomData, } } @@ -74,6 +78,7 @@ impl KMeansBuilder { dimension, variant, cluster_init_values: Some(cluster_init_values), + _marker: PhantomData, } } @@ -122,15 +127,15 @@ impl KMeansBuilder { KMeansVariant::Lloyd => { if self.dimension % 16 == 0 { return self - .run_lloyd::, 16>(flattened_data); + .run_lloyd::, 16>(flattened_data); } else if self.dimension % 8 == 0 { return self - .run_lloyd::, 8>(flattened_data); + .run_lloyd::, 8>(flattened_data); } else if self.dimension % 4 == 0 { return self - .run_lloyd::, 4>(flattened_data); + .run_lloyd::, 4>(flattened_data); } else { - return self.run_lloyd::(flattened_data); + return self.run_lloyd::(flattened_data); } } } @@ -286,7 +291,7 @@ impl KMeansBuilder { .nth(cluster_id) .unwrap(); let distance = - L2DistanceCalculator::calculate_squared(point, cluster); + T::calculate_squared(point, cluster); if distance > max_distance { max_distance = 
distance; chosen_point_id = i; @@ -365,6 +370,9 @@ impl KMeansBuilder { #[cfg(test)] mod tests { + + use crate::distance::l2::L2DistanceCalculator; + use super::*; #[test] @@ -388,7 +396,7 @@ mod tests { .cloned() .collect(); - let kmeans = KMeansBuilder::new_with_cluster_init_values( + let kmeans = KMeansBuilder::::new_with_cluster_init_values( 3, 100, 1e-4, @@ -435,7 +443,7 @@ mod tests { .flatten() .cloned() .collect(); - let kmeans = KMeansBuilder::new_with_cluster_init_values( + let kmeans = KMeansBuilder::::new_with_cluster_init_values( 3, 100, 0.0, diff --git a/rs/utils/src/lib.rs b/rs/utils/src/lib.rs index 3185c7f..fc89020 100644 --- a/rs/utils/src/lib.rs +++ b/rs/utils/src/lib.rs @@ -18,6 +18,12 @@ pub trait DistanceCalculator { accumulator: &mut Simd, ) where LaneCount: SupportedLaneCount; + + /* + * The outermost operator of the distance function, + * to be used with accumulate_lanes for lane conforming code. + */ + fn outermost_op(x: f32) -> f32; } pub trait CalculateSquared { diff --git a/rs/utils/src/scripts/run_kmeans.rs b/rs/utils/src/scripts/run_kmeans.rs index 401c95e..981a8e1 100644 --- a/rs/utils/src/scripts/run_kmeans.rs +++ b/rs/utils/src/scripts/run_kmeans.rs @@ -1,3 +1,4 @@ +use utils::distance::l2::L2DistanceCalculator; use utils::kmeans_builder::kmeans_builder::{KMeansBuilder, KMeansVariant}; fn main() { @@ -12,7 +13,8 @@ fn main() { } } - let kmeans = KMeansBuilder::new(10000, 5, 0.0, dimension, KMeansVariant::Lloyd); + let kmeans = + KMeansBuilder::::new(10000, 5, 0.0, dimension, KMeansVariant::Lloyd); let _result = kmeans .fit(flattened_dataset) .expect("Failed to run KMeans model");