Skip to content

Commit

Permalink
Make kmeans test less flaky (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
hicder authored Dec 4, 2024
1 parent 4179cea commit 7b1f6b1
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 10 deletions.
2 changes: 1 addition & 1 deletion rs/index/src/hnsw/reader.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::fs::File;

use anyhow::Result;
use byteorder::{ByteOrder, LittleEndian};
use memmap2::Mmap;
use quantization::quantization::Quantizer;
use anyhow::Result;

use crate::hnsw::index::Hnsw;
use crate::hnsw::writer::{Header, Version};
Expand Down
2 changes: 1 addition & 1 deletion rs/index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod hnsw;
pub mod index;
pub mod ivf;
pub mod posting_list;
pub mod spann;
pub mod traverse_state;
pub mod utils;
pub mod vector;
pub mod spann;
2 changes: 1 addition & 1 deletion rs/index/src/spann/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pub mod reader;
pub mod index;
pub mod reader;
78 changes: 78 additions & 0 deletions rs/index_writer/src/index_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,4 +465,82 @@ mod tests {
assert!(ivf_vector_storage.exists());
assert!(ivf_index.exists());
}

#[test]
fn test_index_writer_process_ivf_hnsw() {
// Setup test data
let mut rng = rand::thread_rng();
let dimension = 10;
let num_rows = 100;
let data: Vec<Vec<f32>> = (0..num_rows)
.map(|_| (0..dimension).map(|_| rng.gen::<f32>()).collect())
.collect();

let mut mock_input = MockInput::new(data);

// Create a temporary directory for output
let temp_dir = TempDir::new("test_index_writer_process_ivf_hnsw")
.expect("Failed to create temporary directory");
let base_directory = temp_dir
.path()
.to_str()
.expect("Failed to convert temporary directory path to string")
.to_string();

// Configure IndexWriter
let base_config = BaseConfig {
output_path: base_directory.clone(),
dimension,
max_memory_size: 1024 * 1024 * 1024, // 1 GB
file_size: 1024 * 1024 * 1024, // 1 GB
};
let hnsw_config = HnswConfig {
num_layers: 2,
max_num_neighbors: 10,
ef_construction: 100,
reindex: false,
quantizer_type: QuantizerType::ProductQuantizer,
subvector_dimension: 2,
num_bits: 2,
num_training_rows: 50,

max_iteration: 10,
batch_size: 10,
};
let ivf_config = IvfConfig {
num_clusters: 2,
num_data_points: 100,
max_clusters_per_vector: 1,

max_iteration: 10,
batch_size: 10,
tolerance: 0.0,
max_posting_list_size: usize::MAX,
};
let config = IndexWriterConfig::HnswIvf(HnswIvfConfig {
base_config,
hnsw_config,
ivf_config,
});

let mut index_writer = IndexWriter::new(config);

// Process the input
assert!(index_writer.process(&mut mock_input).is_ok());

// Check if output directories and files exist
let quantizer_directory_path = format!("{}/centroid_quantizer", base_directory);
let pq_directory = Path::new(&quantizer_directory_path);
let centroids_directory_path = format!("{}/centroids", base_directory);
let centroids_directory = Path::new(&centroids_directory_path);
let hnsw_vector_storage_path =
format!("{}/vector_storage", centroids_directory.to_str().unwrap());
let hnsw_vector_storage = Path::new(&hnsw_vector_storage_path);
let hnsw_index_path = format!("{}/index", centroids_directory.to_str().unwrap());
let hnsw_index = Path::new(&hnsw_index_path);
assert!(pq_directory.exists());
assert!(centroids_directory.exists());
assert!(hnsw_vector_storage.exists());
assert!(hnsw_index.exists());
}
}
65 changes: 58 additions & 7 deletions rs/utils/src/kmeans_builder/kmeans_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub struct KMeansBuilder {

// Variant for this algorithm. Currently only Lloyd is supported.
pub variant: KMeansVariant,

pub cluster_init_values: Option<Vec<usize>>,
}

pub struct KMeansResult {
Expand All @@ -49,6 +51,25 @@ impl KMeansBuilder {
tolerance,
dimension,
variant,
cluster_init_values: None,
}
}

pub fn new_with_cluster_init_values(
num_cluters: usize,
max_iter: usize,
tolerance: f32,
dimension: usize,
variant: KMeansVariant,
cluster_init_values: Vec<usize>,
) -> Self {
Self {
num_cluters,
max_iter,
tolerance,
dimension,
variant,
cluster_init_values: Some(cluster_init_values),
}
}

Expand Down Expand Up @@ -120,8 +141,17 @@ impl KMeansBuilder {
let mut cluster_labels = vec![0; num_data_points];

// Random initialization of cluster labels
for i in 0..num_data_points {
cluster_labels[i] = rand::random::<usize>() % self.num_cluters;
match &self.cluster_init_values {
Some(values) => {
for i in 0..num_data_points {
cluster_labels[i] = values[i % values.len()];
}
}
None => {
for i in 0..num_data_points {
cluster_labels[i] = rand::random::<usize>() % self.num_cluters;
}
}
}

let mut final_centroids = vec![0.0; self.num_cluters * self.dimension];
Expand Down Expand Up @@ -174,8 +204,7 @@ impl KMeansBuilder {
let centroid = centroids
[centroid_id * self.dimension..(centroid_id + 1) * self.dimension]
.as_ref();
let distance =
T::calculate_squared(dp, centroid) + penalties[centroid_id];
let distance = T::calculate_squared(dp, centroid) + penalties[centroid_id];

if distance < min_cost {
min_cost = distance;
Expand Down Expand Up @@ -224,7 +253,14 @@ mod tests {
.cloned()
.collect();

let kmeans = KMeansBuilder::new(3, 100, 1e-4, 2, KMeansVariant::Lloyd);
let kmeans = KMeansBuilder::new_with_cluster_init_values(
3,
100,
1e-4,
2,
KMeansVariant::Lloyd,
vec![0, 0, 0, 1, 1, 1, 2, 2, 2],
);
let result = kmeans
.fit(flattened_data)
.expect("KMeans run should succeed");
Expand Down Expand Up @@ -267,7 +303,14 @@ mod tests {
.flatten()
.cloned()
.collect();
let kmeans = KMeansBuilder::new(3, 100, 10000.0, 2, KMeansVariant::Lloyd);
let kmeans = KMeansBuilder::new_with_cluster_init_values(
3,
100,
10000.0,
2,
KMeansVariant::Lloyd,
vec![0, 0, 0, 1, 1, 1, 2, 2, 2],
);
let result = kmeans
.fit(flattened_data)
.expect("KMeans run should succeed");
Expand Down Expand Up @@ -301,7 +344,15 @@ mod tests {
.flatten()
.cloned()
.collect();
let kmeans = KMeansBuilder::new(3, 100, 0.0, 2, KMeansVariant::Lloyd);
let kmeans = KMeansBuilder::new_with_cluster_init_values(
3,
100,
0.0,
2,
KMeansVariant::Lloyd,
vec![0, 0, 0, 1, 1, 1, 2, 2, 2],
);

let result = kmeans
.fit(flattened_data)
.expect("KMeans run should succeed");
Expand Down

0 comments on commit 7b1f6b1

Please sign in to comment.