Skip to content

Commit

Permalink
Dot-product implementation and distance template (#237)
Browse files (browse the repository at this point in the history)
* add distance config for index and quantizer

* distance template for IVF and kmeans index

* move lane_conforming to another file and update code per review comments
  • Loading branch information
thinh2 authored Jan 3, 2025
1 parent 87db916 commit 0c096c2
Show file tree
Hide file tree
Showing 22 changed files with 356 additions and 250 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
flamegraph.svg
perf.data*
.idea
.DS_store
.DS_Store
.venv
*/__pycache__/
33 changes: 18 additions & 15 deletions rs/index/src/ivf/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::cmp::{max, min, Ordering, Reverse};
use std::collections::{BinaryHeap, HashMap};
use std::fs::{create_dir, create_dir_all};
use std::io::ErrorKind;
use std::marker::PhantomData;

use anyhow::{anyhow, Result};
use atomic_refcell::AtomicRefCell;
Expand Down Expand Up @@ -38,12 +39,13 @@ pub struct IvfBuilderConfig {
pub max_posting_list_size: usize,
}

pub struct IvfBuilder {
pub struct IvfBuilder<D: DistanceCalculator + CalculateSquared + Send + Sync> {
config: IvfBuilderConfig,
vectors: AtomicRefCell<Box<dyn VectorStorage<f32> + Send + Sync>>,
centroids: AtomicRefCell<Box<dyn VectorStorage<f32> + Send + Sync>>,
posting_lists: Box<dyn for<'a> PostingListStorage<'a>>,
doc_id_mapping: Vec<u64>,
_marker: PhantomData<D>,
}

// TODO(tyb): maybe merge with HNSW's one
Expand Down Expand Up @@ -141,7 +143,7 @@ impl PartialEq for PostingListWithStoppingPoints {

// `Eq` declares no methods of its own, so the impl body is intentionally empty;
// it only marks this type's `PartialEq` comparison as a full equivalence relation.
impl Eq for PostingListWithStoppingPoints {}

impl IvfBuilder {
impl<D: DistanceCalculator + CalculateSquared + Send + Sync> IvfBuilder<D> {
/// Create a new IvfBuilder
pub fn new(config: IvfBuilderConfig) -> Result<Self> {
// Create the base directory and all parent directories if they don't exist
Expand Down Expand Up @@ -184,6 +186,7 @@ impl IvfBuilder {
centroids,
posting_lists,
doc_id_mapping: Vec::new(),
_marker: PhantomData,
})
}

Expand Down Expand Up @@ -245,7 +248,7 @@ impl IvfBuilder {
let mut centroid_index = 0;
for i in 0..flattened_centroids.len() / dimension {
let centroid = &flattened_centroids[i * dimension..(i + 1) * dimension];
let dist = L2DistanceCalculator::calculate(&vector, &centroid);
let dist = D::calculate(&vector, &centroid);
if dist > max_distance {
max_distance = dist;
centroid_index = i;
Expand Down Expand Up @@ -415,7 +418,7 @@ impl IvfBuilder {
num_clusters * 10,
self.config.num_data_points_for_clustering,
);
let kmeans = KMeansBuilder::new(
let kmeans = KMeansBuilder::<D>::new(
num_clusters,
self.config.max_iteration,
self.config.tolerance,
Expand Down Expand Up @@ -453,7 +456,7 @@ impl IvfBuilder {
self.config.num_clusters,
self.config.max_posting_list_size,
);
let kmeans = KMeansBuilder::new(
let kmeans = KMeansBuilder::<D>::new(
num_clusters,
self.config.max_iteration,
self.config.tolerance,
Expand Down Expand Up @@ -789,7 +792,7 @@ mod tests {
let file_size = 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -853,7 +856,7 @@ mod tests {
let file_size = 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -891,7 +894,7 @@ mod tests {
let file_size = 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -947,7 +950,7 @@ mod tests {
let file_size = 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -1014,7 +1017,7 @@ mod tests {
let file_size = 4096 * 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -1087,7 +1090,7 @@ mod tests {
let file_size = 4096 * 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -1160,7 +1163,7 @@ mod tests {
let file_size = 4096 * 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -1255,7 +1258,7 @@ mod tests {
let file_size = 4096 * 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -1333,7 +1336,7 @@ mod tests {
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
const NUM_VECTORS: usize = 22;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -1408,7 +1411,7 @@ mod tests {
let file_size = 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down
39 changes: 27 additions & 12 deletions rs/index/src/ivf/index.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::collections::BinaryHeap;
use std::marker::PhantomData;

use anyhow::{Context, Result};
use quantization::quantization::Quantizer;
use quantization::typing::VectorOps;
use utils::distance::l2::L2DistanceCalculator;
use utils::distance::l2::L2DistanceCalculatorImpl::StreamingSIMD;
use utils::mem::transmute_u8_to_slice;
use utils::DistanceCalculator;
Expand All @@ -13,7 +13,7 @@ use crate::posting_list::combined_file::FixedIndexFile;
use crate::utils::{IdWithScore, SearchContext};
use crate::vector::fixed_file::FixedFileVectorStorage;

pub struct Ivf<Q: Quantizer> {
pub struct Ivf<Q: Quantizer, D: DistanceCalculator> {
// The dataset.
pub vector_storage: FixedFileVectorStorage<Q::QuantizedT>,

Expand All @@ -29,9 +29,11 @@ pub struct Ivf<Q: Quantizer> {
pub num_clusters: usize,

pub quantizer: Q,

_marker: PhantomData<D>,
}

impl<Q: Quantizer> Ivf<Q> {
impl<Q: Quantizer, D: DistanceCalculator> Ivf<Q, D> {
pub fn new(
vector_storage: FixedFileVectorStorage<Q::QuantizedT>,
index_storage: FixedIndexFile,
Expand All @@ -43,6 +45,7 @@ impl<Q: Quantizer> Ivf<Q> {
index_storage,
num_clusters,
quantizer,
_marker: PhantomData,
}
}

Expand All @@ -56,7 +59,7 @@ impl<Q: Quantizer> Ivf<Q> {
let centroid = index_storage
.get_centroid(i as usize)
.with_context(|| format!("Failed to get centroid at index {}", i))?;
let dist = L2DistanceCalculator::calculate(&vector, &centroid);
let dist = D::calculate(&vector, &centroid);
distances.push((i as usize, dist));
}
distances.select_nth_unstable_by(num_probes - 1, |a, b| a.1.total_cmp(&b.1));
Expand Down Expand Up @@ -146,7 +149,7 @@ impl<Q: Quantizer> Ivf<Q> {
}
}

impl<Q: Quantizer> Searchable for Ivf<Q> {
impl<Q: Quantizer, D: DistanceCalculator> Searchable for Ivf<Q, D> {
fn search(
&self,
query: &[f32],
Expand Down Expand Up @@ -180,6 +183,7 @@ mod tests {
use num_traits::ops::bytes::ToBytes;
use quantization::noq::noq::NoQuantizer;
use quantization::pq::pq::ProductQuantizer;
use utils::distance::l2::L2DistanceCalculator;
use utils::mem::transmute_slice_to_u8;

use super::*;
Expand Down Expand Up @@ -337,7 +341,12 @@ mod tests {
let num_clusters = 2;

let quantizer = NoQuantizer::new(3);
let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer);
let ivf = Ivf::<NoQuantizer, L2DistanceCalculator>::new(
storage,
index_storage,
num_clusters,
quantizer,
);

assert_eq!(ivf.num_clusters, num_clusters);
let cluster_0 = transmute_u8_to_slice::<u64>(
Expand Down Expand Up @@ -383,9 +392,12 @@ mod tests {
FixedIndexFile::new(file_path).expect("FixedIndexFile should be created");
let num_probes = 2;

let nearest =
Ivf::<NoQuantizer>::find_nearest_centroids(&vector, &index_storage, num_probes)
.expect("Nearest centroids should be found");
let nearest = Ivf::<NoQuantizer, L2DistanceCalculator>::find_nearest_centroids(
&vector,
&index_storage,
num_probes,
)
.expect("Nearest centroids should be found");

assert_eq!(nearest[0], 1);
assert_eq!(nearest[1], 0);
Expand Down Expand Up @@ -431,7 +443,8 @@ mod tests {
let num_probes = 2;

let quantizer = NoQuantizer::new(num_features);
let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer);
let ivf: Ivf<NoQuantizer, L2DistanceCalculator> =
Ivf::new(storage, index_storage, num_clusters, quantizer);

let query = vec![2.0, 3.0, 4.0];
let k = 2;
Expand Down Expand Up @@ -505,7 +518,8 @@ mod tests {
let num_clusters = 2;
let num_probes = 2;

let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer);
let ivf: Ivf<ProductQuantizer, L2DistanceCalculator> =
Ivf::new(storage, index_storage, num_clusters, quantizer);

let query = vec![2.0, 3.0, 4.0];
let k = 2;
Expand Down Expand Up @@ -557,7 +571,8 @@ mod tests {
let num_probes = 1;

let quantizer = NoQuantizer::new(num_features);
let ivf = Ivf::new(storage, index_storage, num_clusters, quantizer);
let ivf: Ivf<NoQuantizer, L2DistanceCalculator> =
Ivf::new(storage, index_storage, num_clusters, quantizer);

let query = vec![1.0, 2.0, 3.0];
let k = 5; // More than available results
Expand Down
16 changes: 9 additions & 7 deletions rs/index/src/ivf/reader.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::Result;
use quantization::quantization::Quantizer;
use utils::DistanceCalculator;

use crate::ivf::index::Ivf;
use crate::posting_list::combined_file::FixedIndexFile;
Expand All @@ -14,7 +15,7 @@ impl IvfReader {
Self { base_directory }
}

pub fn read<Q: Quantizer>(&self) -> Result<Ivf<Q>> {
pub fn read<Q: Quantizer, D: DistanceCalculator>(&self) -> Result<Ivf<Q, D>> {
let index_storage = FixedIndexFile::new(format!("{}/index", self.base_directory))?;

let vector_storage_path = format!("{}/vectors", self.base_directory);
Expand Down Expand Up @@ -45,6 +46,7 @@ mod tests {
use compression::noc::noc::PlainEncoder;
use quantization::noq::noq::{NoQuantizer, NoQuantizerWriter};
use tempdir::TempDir;
use utils::distance::l2::L2DistanceCalculator;
use utils::mem::transmute_u8_to_slice;
use utils::test_utils::generate_random_vector;

Expand All @@ -68,9 +70,9 @@ mod tests {
let num_features = 4;
let file_size = 4096;
let quantizer = NoQuantizer::new(num_features);
let writer = IvfWriter::<_, PlainEncoder>::new(base_directory.clone(), quantizer);
let writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(base_directory.clone(), quantizer);

let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand Down Expand Up @@ -105,7 +107,7 @@ mod tests {

let reader = IvfReader::new(base_directory.clone());
let index = reader
.read::<NoQuantizer>()
.read::<NoQuantizer, L2DistanceCalculator>()
.expect("Failed to read index file");

// Check if files were created
Expand Down Expand Up @@ -212,9 +214,9 @@ mod tests {
let noq_writer = NoQuantizerWriter::new(quantizer_directory);
assert!(noq_writer.write(&quantizer).is_ok());

let writer = IvfWriter::<_, PlainEncoder>::new(base_directory.clone(), quantizer);
let writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(base_directory.clone(), quantizer);

let mut builder = IvfBuilder::new(IvfBuilderConfig {
let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
Expand All @@ -241,7 +243,7 @@ mod tests {

let reader = IvfReader::new(base_directory.clone());
let index = reader
.read::<NoQuantizer>()
.read::<NoQuantizer, L2DistanceCalculator>()
.expect("Failed to read index file");

let num_centroids = index.num_clusters;
Expand Down
Loading

0 comments on commit 0c096c2

Please sign in to comment.