Skip to content

Commit

Permalink
Optimizations to KMeans implementation (#252)
Browse files: browse the repository at this point in the history
  • Loading branch information
hicder authored Jan 1, 2025
1 parent f9a4678 commit 39b827d
Show file tree
Hide file tree
Showing 20 changed files with 551 additions and 301 deletions.
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@ rayon = "1.10.0"
sorted-vec = "0.8.5"
dashmap = "6.1.0"
reqwest = {version = "0.12.11", features = ["json"]}
atomic_refcell = "0.1.13"
35 changes: 19 additions & 16 deletions py/demo_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,29 @@
if __name__ == "__main__":
# Example usage for IndexServer
muopdb_client = mp.IndexServerClient()
query = "personal career development"
query = "baby goes to school"
query_vector = ollama.embeddings(model='nomic-embed-text', prompt=query)["embedding"]

# Read back the raw data to print the responses
with open("/mnt/muopdb/raw/1m_sentences.txt", "r") as f:
sentences = [line.strip() for line in f]

start = time.time()
search_response = muopdb_client.search(
index_name="test-collection-1",
vector=query_vector,
top_k=5,
ef_construction=50,
record_metrics=False
)
end = time.time()
print(f"Time taken for search: {end - start} seconds")
i = 0
while i < 5:
start = time.time()
search_response = muopdb_client.search(
index_name="test-collection-1",
vector=query_vector,
top_k=5,
ef_construction=50,
record_metrics=True
)
end = time.time()
print(f"Time taken for search: {end - start} seconds")

print(f"Number of results: {len(search_response.ids)}")
print("================")
for id in search_response.ids:
print(f"RESULT: {sentences[id - 1]}")
print("================")
print(f"Number of results: {len(search_response.ids)}")
print("================")
for id in search_response.ids:
print(f"RESULT: {sentences[id - 1]}")
print("================")
i += 1
4 changes: 2 additions & 2 deletions rs/demo/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::time::Instant;
use anyhow::{Context, Result};
use hdf5::File;
use log::{LevelFilter, info};
use ndarray::s;
use proto::muopdb::index_server_client::IndexServerClient;
use proto::muopdb::{FlushRequest, InsertRequest};
use ndarray::s;

#[tokio::main]
async fn main() -> Result<()> {
Expand Down Expand Up @@ -54,7 +54,7 @@ async fn main() -> Result<()> {
});

client.insert(request).await?;
start_idx = end_idx;
start_idx = end_idx;
}

let mut duration = start.elapsed();
Expand Down
6 changes: 5 additions & 1 deletion rs/demo/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ async fn main() -> Result<()> {
let response_body = response.text().await.expect("Failed to read response body");
let response_map: serde_json::Value = serde_json::from_str(&response_body).unwrap();
let query_vector_value = response_map["embedding"].as_array().unwrap();
let query_vector: Vec<f32> = query_vector_value.iter().map(|x| x.as_f64().unwrap()).map(|x| x as f32).collect();
let query_vector: Vec<f32> = query_vector_value
.iter()
.map(|x| x.as_f64().unwrap())
.map(|x| x as f32)
.collect();

// Create search request
let request = tonic::Request::new(SearchRequest {
Expand Down
2 changes: 2 additions & 0 deletions rs/index/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ tempdir.workspace = true
utils.workspace = true
serde.workspace = true
serde_json.workspace = true
rayon.workspace = true
atomic_refcell.workspace = true
2 changes: 1 addition & 1 deletion rs/index/src/collection/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ mod tests {
max_iteration: 1000,
batch_size: 4,
num_clusters: 10,
num_data_points: 1000,
num_data_points_for_clustering: 1000,
max_clusters_per_vector: 1,
distance_threshold: 0.1,
base_directory: "./".to_string(),
Expand Down
1 change: 0 additions & 1 deletion rs/index/src/hnsw/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ impl<Q: Quantizer> HnswBuilder<Q> {
}
while let Some(node) = queue.pop_front() {
visited.set(node.try_into().unwrap(), true);
debug!("Visited node {}", node);

let edges = graph.edges.get_mut(&node);
if let Some(edges) = edges {
Expand Down
1 change: 0 additions & 1 deletion rs/index/src/hnsw/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,6 @@ impl<Q: Quantizer> Searchable for Hnsw<Q> {
ef_construction: u32,
context: &mut SearchContext,
) -> Option<Vec<IdWithScore>> {
// TODO(hicder): Add ef parameter
Some(self.ann_search(query, k, ef_construction, context))
}
}
Expand Down
Loading

0 comments on commit 39b827d

Please sign in to comment.