Threaded greedy search.
PiperOrigin-RevId: 664989722
sdenton4 authored and copybara-github committed Aug 19, 2024
1 parent 5fb0209 commit 48a359f
Showing 2 changed files with 83 additions and 39 deletions.
119 changes: 81 additions & 38 deletions chirp/projects/hoplite/index.py
@@ -18,7 +18,8 @@
 import collections
 import concurrent
 import dataclasses
-from typing import Callable
+import threading
+from typing import Any, Callable
 
 from chirp.projects.hoplite import brutalism
 from chirp.projects.hoplite import graph_utils
@@ -120,6 +121,7 @@ def greedy_search(
       search_list_size: int = 100,
       deterministic: bool = False,
       max_visits: int | None = None,
+      max_workers: int = 10,
   ) -> tuple[search_results.TopKSearchResults, np.ndarray]:
     """Apply the Vamana greedy search.
@@ -129,56 +131,64 @@ def greedy_search(
       search_list_size: Top-k value for search.
       deterministic: Ensure that the search path is fully reproducible.
       max_visits: Visit no more than this many nodes.
+      max_workers: Max number of worker threads.
 
     Returns:
       The TopKSearchResults and the sequence of all 'visited' nodes.
     """
-    visited = {}
+    visited = np.array([], dtype=np.int64)
     results = search_results.TopKSearchResults(search_list_size)
+    state = {}
+    state['db'] = self.db
+    state['score_fn'] = self.score_fn
 
     # Insert start node into the TopKResults.
     start_node_embedding = self.db.get_embedding(start_node)
     start_score = self.score_fn(start_node_embedding, query_embedding)
     result = search_results.SearchResult(start_node, start_score)
     results.update(result)
 
-    while max_visits is None or len(visited) < max_visits:
-      # Get the best result we have not yet visited.
-      for r in results:
-        if r.embedding_id not in visited:
-          visit_idx = r.embedding_id
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=max_workers,
+        initializer=brutalism.worker_initializer,
+        initargs=(state,),
+    ) as executor:
+      while max_visits is None or len(visited) < max_visits:
+        jobs = []
+        if len(results.search_results) >= search_list_size:
+          min_score = results.min_score
+        else:
+          min_score = -np.inf
+        result_ids = np.array([r.embedding_id for r in results.search_results])
+        unvisited = np.setdiff1d(result_ids, visited)
+        if unvisited.shape[0] == 0:
+          # All candidates visited; we're done.
           break
-      else:
-        break
-
-      # Add the selected node to 'visited'.
-      visited[visit_idx] = None
-
-      # We will examine neighbors of the visited node.
-      nbrs = self.db.get_edges(visit_idx)
-      # Filter visited neighbors.
-      nbrs = nbrs[np.array(tuple(n not in visited for n in nbrs), dtype=bool)]
-
-      nbrs, nbr_embeddings = self.db.get_embeddings(nbrs)
-      if deterministic:
-        order = np.argsort(nbrs)
-        nbr_embeddings = nbr_embeddings[order]
-        nbrs = nbrs[order]
-      nbr_scores = self.score_fn(nbr_embeddings, query_embedding)
-
-      if len(results.search_results) >= search_list_size:
-        # Drop any elements bigger than the current result set's min_score.
-        keep_args = np.where(nbr_scores >= results.min_score)
-        nbrs = nbrs[keep_args]
-        nbr_scores = nbr_scores[keep_args]
-
-      for nbr_idx, nbr_score in zip(nbrs, nbr_scores):
-        if results.will_filter(nbr_idx, nbr_score):
-          continue
-        results.update(
-            search_results.SearchResult(nbr_idx, nbr_score), force_insert=True
-        )
-    return results, np.array(tuple(visited.keys()))
+        for r in unvisited:
+          jobs.append(
+              executor.submit(
+                  greedy_search_worker,
+                  query_embedding,
+                  r,
+                  min_score,
+                  visited,
+                  state,
+                  deterministic,
+              )
+          )
+
+        # Process the results.
+        for job in jobs:
+          nbrs, nbr_scores = job.result()
+          for nbr_idx, nbr_score in zip(nbrs, nbr_scores):
+            if results.will_filter(nbr_idx, nbr_score):
+              continue
+            results.update(
+                search_results.SearchResult(nbr_idx, nbr_score),
+                force_insert=True,
+            )
+        visited = np.concatenate([visited, np.array(unvisited)], axis=0)
+    return results, visited

   def index(
       self,
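
The rewritten loop above replaces the one-node-at-a-time Vamana expansion with a per-round fan-out: every not-yet-visited candidate in the current result set is expanded concurrently, and the workers' results are merged before the next round begins. A minimal, self-contained sketch of that fan-out/fan-in pattern (generic names, not hoplite code):

import concurrent.futures

def expand(node):
  # Stand-in for greedy_search_worker: returns (neighbors, scores).
  return [node + 1, node + 2], [0.9, 0.5]

frontier = [0, 1, 2]  # plays the role of the 'unvisited' batch
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
  jobs = [executor.submit(expand, n) for n in frontier]
  for job in jobs:
    nbrs, scores = job.result()  # fan-in: collect before the next round
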
@@ -484,3 +494,36 @@ def multi_test_recall(
     )
     recalls.append(recall)
   return float(np.mean(recalls))


+def greedy_search_worker(
+    query_embedding: np.ndarray,
+    target_idx: int,
+    min_score: float,
+    visited: np.ndarray,
+    state: dict[str, Any],
+    deterministic: bool = False,
+):
+  """Worker task for threaded greedy search."""
+  name = threading.current_thread().name
+  db = state[name + 'db']
+
+  # We will examine neighbors of the visited node.
+  nbrs = db.get_edges(target_idx)
+  # Filter visited neighbors.
+  nbrs = np.setdiff1d(nbrs, visited)
+  if not nbrs.shape[0]:
+    return nbrs, np.array([], np.float16)
+
+  nbrs, nbr_embeddings = db.get_embeddings(nbrs)
+  if deterministic:
+    order = np.argsort(nbrs)
+    nbr_embeddings = nbr_embeddings[order]
+    nbrs = nbrs[order]
+  nbr_scores = state['score_fn'](nbr_embeddings, query_embedding)
+
+  # Drop any elements scoring below the current result set's min_score.
+  keep_args = np.where(nbr_scores >= min_score)
+  nbrs = nbrs[keep_args]
+  nbr_scores = nbr_scores[keep_args]
+  return nbrs, nbr_scores
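
Note: brutalism.worker_initializer is not shown in this diff. The worker's state[name + 'db'] lookup implies the initializer installs a per-thread database handle keyed by the thread's name, presumably along these lines (a sketch; thread_split() is an assumption about the hoplite db interface):

import threading

def worker_initializer(state):
  # Assumed behavior: give each worker thread its own DB handle so that
  # concurrent reads never share one connection.
  name = threading.current_thread().name
  state[name + 'db'] = state['db'].thread_split()
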
3 changes: 2 additions & 1 deletion chirp/projects/hoplite/tests/hoplite_test.py
@@ -228,6 +228,7 @@ def test_greedy_search_impl_agreement(self):
       nbrs = in_mem_db.get_edges(x)
       for y in nbrs:
         sqlite_db.insert_edge(id_mapping[x], id_mapping[y])
+    sqlite_db.commit()
 
     rng = np.random.default_rng(42)
     query = rng.normal(size=(EMBEDDING_SIZE,), loc=0, scale=1.0)
@@ -239,7 +240,7 @@
         query, search_list_size=32, start_node=0, deterministic=True
     )
     results_s, path_s = v_s.greedy_search(
-        query, search_list_size=32, start_node=1, deterministic=True
+        query, search_list_size=32, start_node=id_mapping[0], deterministic=True
     )
     self.assertSameElements((id_mapping[x] for x in path_m), path_s)
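
For orientation, a hypothetical call exercising the new parameter (mirroring the test above, with v_m and query as defined there; max_workers is the only new argument):

results, path = v_m.greedy_search(
    query,
    search_list_size=32,
    start_node=0,
    deterministic=True,
    max_workers=10,
)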
