Skip to content

Commit

Permalink
Adding a vector to multiple posting lists (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
tyb0807 authored Dec 5, 2024
1 parent d548578 commit 2f62fe0
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 13 deletions.
114 changes: 103 additions & 11 deletions rs/index/src/ivf/builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::fs::{create_dir, create_dir_all};

Expand All @@ -18,8 +19,10 @@ pub struct IvfBuilderConfig {
pub num_clusters: usize,
pub num_data_points: usize,
pub max_clusters_per_vector: usize,
// Threshold to add a vector to more than one cluster
pub distance_threshold: f32,

// Parameters for storages.
// Parameters for storages
pub base_directory: String,
pub memory_size: usize,
pub file_size: usize,
Expand All @@ -38,6 +41,12 @@ pub struct IvfBuilder {
doc_id_mapping: Vec<u64>,
}

// TODO(tyb): maybe merge with HNSW's one
/// A point identifier paired with a distance value.
///
/// `find_nearest_centroids` returns these with `point_id` set to the centroid
/// index and `distance` set to the squared L2 distance between the query
/// vector and that centroid.
///
/// Note: `distance` is an `f32`, so `PartialEq` follows IEEE-754 semantics
/// (a value holding `NaN` never compares equal, not even to itself).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PointAndDistance {
    /// Index of the point (for IVF: the centroid index).
    pub point_id: usize,
    /// Distance associated with `point_id`.
    pub distance: f32,
}

#[derive(Debug)]
struct PostingListInfo {
centroid: Vec<f32>,
Expand Down Expand Up @@ -176,35 +185,48 @@ impl IvfBuilder {
vector: &[f32],
centroids: &dyn VectorStorage<f32>,
num_probes: usize,
) -> Result<Vec<usize>> {
let mut distances: Vec<(usize, f32)> = Vec::new();
) -> Result<Vec<PointAndDistance>> {
let mut distances: Vec<PointAndDistance> = Vec::new();
let num_centroids = centroids.len();
for i in 0..num_centroids {
let centroid = centroids.get(i as u32)?;
let dist = L2DistanceCalculator::calculate_squared(&vector, &centroid);
if dist.is_nan() {
println!("NAN found");
}
distances.push((i, dist));
distances.push(PointAndDistance {
point_id: i,
distance: dist,
});
}
distances.select_nth_unstable_by(num_probes - 1, |a, b| a.1.total_cmp(&b.1));
distances.select_nth_unstable_by(num_probes - 1, |a, b| a.distance.total_cmp(&b.distance));
distances.truncate(num_probes);
Ok(distances.into_iter().map(|(idx, _)| idx).collect())
Ok(distances)
}

pub fn build_posting_lists(&mut self) -> Result<()> {
let mut posting_lists: Vec<Vec<u64>> = vec![Vec::with_capacity(0); self.centroids.len()];
// Assign vectors to nearest centroids
for i in 0..self.vectors.len() {
let vector = self.vectors.get(i as u32)?;
let nearest_centroid = Self::find_nearest_centroids(
let nearest_centroids = Self::find_nearest_centroids(
&vector,
self.centroids.as_ref(),
self.config.max_clusters_per_vector,
)?;

for centroid_id in nearest_centroid {
posting_lists[centroid_id].push(i as u64);
// Find the nearest distance, ensuring that NaN values are treated as greater than any
// other value
let nearest_distance = nearest_centroids
.iter()
.map(|pad| pad.distance)
.min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Greater))
.expect("nearest_distance should not be None");
for point_and_distance in nearest_centroids.iter() {
if (point_and_distance.distance - nearest_distance).abs()
<= nearest_distance * self.config.distance_threshold
{
posting_lists[point_and_distance.point_id].push(i as u64);
}
}
}

Expand Down Expand Up @@ -404,9 +426,78 @@ mod tests {
}

#[test]
fn test_ivf_builder() {
fn test_build_posting_lists() {
    /// Decodes the raw byte slices of a posting list into the doc ids they
    /// hold, reading consecutive 8-byte little-endian `u64` values.
    fn decode_doc_ids(slices: &[&[u8]]) -> Vec<u64> {
        slices
            .iter()
            .flat_map(|slice| slice.chunks_exact(8))
            .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
            .collect()
    }

    // `env_logger::init()` panics when a global logger is already installed,
    // which happens as soon as another test in the same binary initializes
    // logging. `try_init()` reports that case as an `Err`, which is safe to
    // ignore here.
    let _ = env_logger::try_init();

    let temp_dir = tempdir::TempDir::new("buil_posting_lists_test")
        .expect("Failed to create temporary directory");
    let base_directory = temp_dir
        .path()
        .to_str()
        .expect("Failed to convert temporary directory path to string")
        .to_string();
    let num_clusters = 2;
    let num_vectors = 6;
    let num_features = 1;
    let file_size = 4096;
    let balance_factor = 0.0;
    let max_posting_list_size = usize::MAX;
    let mut builder = IvfBuilder::new(IvfBuilderConfig {
        max_iteration: 1000,
        batch_size: 4,
        num_clusters,
        num_data_points: num_vectors,
        // Allow each vector to land in both clusters when the distances are
        // within `distance_threshold` of each other.
        max_clusters_per_vector: 2,
        distance_threshold: 0.1,
        base_directory,
        memory_size: 1024,
        file_size,
        num_features,
        tolerance: balance_factor,
        max_posting_list_size,
    })
    .expect("Failed to create builder");

    // Add six 1-dimensional vectors with values 1.0..=6.0 (doc ids 0..=5).
    for i in 0..num_vectors {
        builder
            .add_vector(i as u64, vec![(i + 1) as f32])
            .expect("Vector should be added");
    }

    // Fixed centroids: the vector 4.0 (doc id 3) is equidistant from both, so
    // it is expected in both posting lists.
    let _ = builder.add_centroid(&[2.5]);
    let _ = builder.add_centroid(&[5.5]);

    let result = builder.build_posting_lists();
    assert!(result.is_ok());

    assert_eq!(
        decode_doc_ids(
            &builder
                .posting_lists
                .get(0)
                .expect("Failed to get posting list")
                .slices
        ),
        vec![0, 1, 2, 3]
    );
    assert_eq!(
        decode_doc_ids(
            &builder
                .posting_lists
                .get(1)
                .expect("Failed to get posting list")
                .slices
        ),
        vec![3, 4, 5]
    );
}

#[test]
fn test_ivf_builder() {
let temp_dir = tempdir::TempDir::new("ivf_builder_test")
.expect("Failed to create temporary directory");
let base_directory = temp_dir
Expand All @@ -426,6 +517,7 @@ mod tests {
num_clusters,
num_data_points: num_vectors,
max_clusters_per_vector: 1,
distance_threshold: 0.1,
base_directory,
memory_size: 1024,
file_size,
Expand Down
2 changes: 2 additions & 0 deletions rs/index/src/ivf/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ mod tests {
num_clusters,
num_data_points: num_vectors,
max_clusters_per_vector: 1,
distance_threshold: 0.1,
base_directory: base_directory.clone(),
memory_size: 1024,
file_size,
Expand Down Expand Up @@ -182,6 +183,7 @@ mod tests {
num_clusters,
num_data_points: num_vectors,
max_clusters_per_vector: 1,
distance_threshold: 0.1,
base_directory: base_directory.clone(),
memory_size: 1024,
file_size,
Expand Down
1 change: 1 addition & 0 deletions rs/index/src/ivf/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ mod tests {
num_clusters,
num_data_points: num_vectors,
max_clusters_per_vector: 1,
distance_threshold: 0.1,
base_directory: base_directory.clone(),
memory_size: 1024,
file_size,
Expand Down
2 changes: 1 addition & 1 deletion rs/index/src/posting_list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub struct PostingListStorageConfig {
}

pub struct PostingList<'a> {
slices: Vec<&'a [u8]>,
pub slices: Vec<&'a [u8]>,
}

pub struct PostingListIterator<'a> {
Expand Down
1 change: 1 addition & 0 deletions rs/index_writer/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub struct IvfConfig {
pub num_clusters: usize,
pub num_data_points: usize,
pub max_clusters_per_vector: usize,
pub distance_threshold: f32,

// KMeans training parameters
pub max_iteration: usize,
Expand Down
6 changes: 5 additions & 1 deletion rs/index_writer/src/index_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use quantization::pq_builder::{ProductQuantizerBuilder, ProductQuantizerBuilderC
use rand::seq::SliceRandom;

use crate::config::{
HnswConfigWithBase, SpannConfigWithBase, IndexWriterConfig, IvfConfigWithBase, QuantizerType,
HnswConfigWithBase, IndexWriterConfig, IvfConfigWithBase, QuantizerType, SpannConfigWithBase,
};
use crate::input::Input;

Expand Down Expand Up @@ -131,6 +131,7 @@ impl IndexWriter {
num_clusters: index_builder_config.ivf_config.num_clusters,
num_data_points: index_builder_config.ivf_config.num_data_points,
max_clusters_per_vector: index_builder_config.ivf_config.max_clusters_per_vector,
distance_threshold: index_builder_config.ivf_config.distance_threshold,
base_directory: path.to_string(),
memory_size: index_builder_config.base_config.max_memory_size,
file_size: index_builder_config.base_config.file_size,
Expand Down Expand Up @@ -191,6 +192,7 @@ impl IndexWriter {
num_clusters: ivf_config.num_clusters,
num_data_points: ivf_config.num_data_points,
max_clusters_per_vector: ivf_config.max_clusters_per_vector,
distance_threshold: ivf_config.distance_threshold,
base_directory: index_writer_config.base_config.output_path.clone(),
memory_size: index_writer_config.base_config.max_memory_size,
file_size: index_writer_config.base_config.file_size,
Expand Down Expand Up @@ -453,6 +455,7 @@ mod tests {
num_clusters: 2,
num_data_points: 100,
max_clusters_per_vector: 1,
distance_threshold: 0.1,

max_iteration: 10,
batch_size: 10,
Expand Down Expand Up @@ -527,6 +530,7 @@ mod tests {
num_clusters: 2,
num_data_points: 100,
max_clusters_per_vector: 1,
distance_threshold: 0.1,

max_iteration: 10,
batch_size: 10,
Expand Down

0 comments on commit 2f62fe0

Please sign in to comment.