Skip to content

Commit

Permalink
Use a priority queue for TopKSearchResults. This speeds up handling of large result lists substantially.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 635296548
  • Loading branch information
sdenton4 authored and copybara-github committed May 20, 2024
1 parent 9a22887 commit 240aeb7
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 32 deletions.
3 changes: 2 additions & 1 deletion chirp/inference/search/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def display_page(page_state):
print(f'Results Page: {page} / {num_pages}')
st, end = page * samples_per_page, (page + 1) * samples_per_page
results_page = search.TopKSearchResults(
all_results.search_results[st:end], top_k=samples_per_page
top_k=samples_per_page,
search_results=all_results.search_results[st:end],
)
display_search_results(
results=results_page, rank_offset=page * samples_per_page, **kwargs
Expand Down
59 changes: 30 additions & 29 deletions chirp/inference/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import collections
import dataclasses
import functools
import heapq
from typing import Any, Callable, List, Sequence

from chirp.inference import tf_examples
Expand Down Expand Up @@ -50,55 +51,55 @@ def __hash__(self):
"""Return an identifier for this result."""
return hash((self.filename, self.timestamp_offset))

def __lt__(self, other):
  # Rich comparisons delegate to sort_score so that SearchResult objects
  # can be stored directly in a heapq-managed list (heapq is a min-heap,
  # so the lowest-scoring result sits at the front of the queue).
  # NOTE(review): assumes `other` is also a SearchResult (or at least has
  # a comparable `sort_score`) — no type check is performed here.
  return self.sort_score < other.sort_score

def __gt__(self, other):
  # Greater-than by sort_score; used by sorted(..., reverse=True) callers.
  return self.sort_score > other.sort_score

def __le__(self, other):
  # Less-than-or-equal by sort_score.
  return self.sort_score <= other.sort_score

def __ge__(self, other):
  # Greater-than-or-equal by sort_score.
  return self.sort_score >= other.sort_score


@dataclasses.dataclass
class TopKSearchResults:
"""Top-K search results."""

search_results: List[SearchResult]
top_k: int
min_score: float = -1.0
_min_score_idx: int = -1
search_results: List[SearchResult] = dataclasses.field(default_factory=list)

def __post_init__(self):
self._update_deseridata()
heapq.heapify(self.search_results)

def __iter__(self):
for r in self.search_results:
yield r
iter_queue = self.search_results[:]
while iter_queue:
yield heapq.heappop(iter_queue)

@property
def min_score(self):
return self.search_results[0].sort_score

def update(self, search_result: SearchResult) -> None:
"""Update Results with the new result."""
if len(self.search_results) < self.top_k:
# Add the result, regardless of score, until we have k results.
pass
elif search_result.sort_score < self.min_score:
# Early return to save compute.
if self.will_filter(search_result.sort_score):
return
elif len(self.search_results) >= self.top_k:
self.search_results.pop(self._min_score_idx)
self.search_results.append(search_result)
self._update_deseridata()
if len(self.search_results) >= self.top_k:
heapq.heappop(self.search_results)
heapq.heappush(self.search_results, search_result)

def will_filter(self, score: float) -> bool:
"""Check whether a score is relevant."""
if len(self.search_results) < self.top_k:
# Add the result, regardless of score, until we have k results.
return False
return score < self.min_score

def _update_deseridata(self):
if not self.search_results:
return
self._min_score_idx = np.argmin([r.sort_score for r in self.search_results])
self.min_score = self.search_results[self._min_score_idx].sort_score
return score < self.search_results[0].sort_score

def sort(self):
"""Sort the results."""
scores = np.array([r.sort_score for r in self.search_results])
idxs = np.argsort(-scores)
self.search_results = [self.search_results[idx] for idx in idxs]
self._update_deseridata()
def sort(self) -> None:
self.search_results = sorted(self.search_results, reverse=True)

def write_labeled_data(self, labeled_data_path: str, sample_rate: int):
"""Write labeled results to the labeled data collection."""
Expand Down Expand Up @@ -284,7 +285,7 @@ def search_embeddings_parallel(
embeddings_dataset = embeddings_dataset.filter(filter_fn)
embeddings_dataset = embeddings_dataset.prefetch(1024)

results = TopKSearchResults([], top_k=top_k)
results = TopKSearchResults(top_k=top_k)
all_distances = []
try:
for ex in tqdm.tqdm(embeddings_dataset.as_numpy_iterator()):
Expand Down
2 changes: 1 addition & 1 deletion chirp/inference/tests/bootstrap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_bootstrap_from_embeddings(self):

# Check that we can iterate over TopKSearchResults,
# and successfully attach audio.
search_results = search.TopKSearchResults([], top_k=3)
search_results = search.TopKSearchResults(top_k=3)
for i, ex in enumerate(ds.as_numpy_iterator()):
result = search.SearchResult(
embedding=ex['embedding'],
Expand Down
2 changes: 1 addition & 1 deletion chirp/inference/tests/search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_top_k_search_results(self):
)
fake_results.append(r)

results = search.TopKSearchResults([], top_k=10)
results = search.TopKSearchResults(top_k=10)
for i, r in enumerate(fake_results):
results.update(r)
self.assertLen(results.search_results, min([i + 1, 10]))
Expand Down

0 comments on commit 240aeb7

Please sign in to comment.