Skip to content

Commit

Permalink
Use a priority queue for TopKSearchResults. This speeds up handling o…
Browse files Browse the repository at this point in the history
…f large result lists substantially.

PiperOrigin-RevId: 635480598
  • Loading branch information
sdenton4 authored and copybara-github committed May 20, 2024
1 parent 9a22887 commit 3850baf
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 43 deletions.
2 changes: 1 addition & 1 deletion analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@
" np.ceil(len(combined_results) / samples_per_page))\n",
"\n",
"display.display_paged_results(\n",
" search.TopKSearchResults(combined_results, len(combined_results)),\n",
" search.TopKSearchResults(len(combined_results), combined_results),\n",
" page_state, samples_per_page,\n",
" project_state=project_state,\n",
" embedding_sample_rate=project_state.embedding_model.sample_rate,\n",
Expand Down
7 changes: 6 additions & 1 deletion chirp/inference/search/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def plot_audio_melspec(
):
"""Plot a melspectrogram from audio."""
melspec_layer = get_melspec_layer(sample_rate)
if audio.shape[0] < sample_rate / 100:
audio = np.concatenate(
audio, np.zeros([sample_rate // 100 + 1], dtype=audio.dtype), axis=0
)
melspec = melspec_layer.apply({}, audio[np.newaxis, :])[0]
plot_melspec(melspec, newfig=newfig, sample_rate=sample_rate, frame_rate=100)
plt.show()
Expand Down Expand Up @@ -191,7 +195,8 @@ def display_page(page_state):
print(f'Results Page: {page} / {num_pages}')
st, end = page * samples_per_page, (page + 1) * samples_per_page
results_page = search.TopKSearchResults(
all_results.search_results[st:end], top_k=samples_per_page
top_k=samples_per_page,
search_results=all_results.search_results[st:end],
)
display_search_results(
results=results_page, rank_offset=page * samples_per_page, **kwargs
Expand Down
71 changes: 39 additions & 32 deletions chirp/inference/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import collections
import dataclasses
import functools
import heapq
from typing import Any, Callable, List, Sequence

from chirp.inference import tf_examples
Expand Down Expand Up @@ -50,55 +51,62 @@ def __hash__(self):
"""Return an identifier for this result."""
return hash((self.filename, self.timestamp_offset))

def __lt__(self, other):
return self.sort_score < other.sort_score

def __gt__(self, other):
return self.sort_score > other.sort_score

def __le__(self, other):
return self.sort_score <= other.sort_score

def __ge__(self, other):
return self.sort_score >= other.sort_score


@dataclasses.dataclass
class TopKSearchResults:
"""Top-K search results."""
"""Wrapper for sorting and handling TopK search results.
This class maintains a queue of SearchResult objects, sorted by their
sort_score. When updated with a new SearchResult, the result is either added
or ignored appropriately. For speed, the `will_filter` method allows checking
immediately whether a result with a given score will be discarded. The results
are kept in heap-order for efficient updating.
Iterating over the search results will produce a copy of the results, with
in-order iteration over results from largest to smallest sort_score.
"""

search_results: List[SearchResult]
top_k: int
min_score: float = -1.0
_min_score_idx: int = -1
search_results: List[SearchResult] = dataclasses.field(default_factory=list)

def __post_init__(self):
self._update_deseridata()
heapq.heapify(self.search_results)

def __iter__(self):
for r in self.search_results:
yield r
iter_queue = sorted(self.search_results, reverse=True)
for result in iter_queue:
yield result

@property
def min_score(self):
return self.search_results[0].sort_score

def update(self, search_result: SearchResult) -> None:
"""Update Results with the new result."""
if len(self.search_results) < self.top_k:
# Add the result, regardless of score, until we have k results.
pass
elif search_result.sort_score < self.min_score:
# Early return to save compute.
if self.will_filter(search_result.sort_score):
return
elif len(self.search_results) >= self.top_k:
self.search_results.pop(self._min_score_idx)
self.search_results.append(search_result)
self._update_deseridata()
if len(self.search_results) >= self.top_k:
heapq.heappop(self.search_results)
heapq.heappush(self.search_results, search_result)

def will_filter(self, score: float) -> bool:
"""Check whether a result with this score would be discarded."""
if len(self.search_results) < self.top_k:
# Add the result, regardless of score, until we have k results.
return False
return score < self.min_score

def _update_deseridata(self):
if not self.search_results:
return
self._min_score_idx = np.argmin([r.sort_score for r in self.search_results])
self.min_score = self.search_results[self._min_score_idx].sort_score

def sort(self):
"""Sort the results."""
scores = np.array([r.sort_score for r in self.search_results])
idxs = np.argsort(-scores)
self.search_results = [self.search_results[idx] for idx in idxs]
self._update_deseridata()
return score < self.search_results[0].sort_score

def write_labeled_data(self, labeled_data_path: str, sample_rate: int):
"""Write labeled results to the labeled data collection."""
Expand Down Expand Up @@ -284,7 +292,7 @@ def search_embeddings_parallel(
embeddings_dataset = embeddings_dataset.filter(filter_fn)
embeddings_dataset = embeddings_dataset.prefetch(1024)

results = TopKSearchResults([], top_k=top_k)
results = TopKSearchResults(top_k=top_k)
all_distances = []
try:
for ex in tqdm.tqdm(embeddings_dataset.as_numpy_iterator()):
Expand All @@ -304,7 +312,6 @@ def search_embeddings_parallel(
except KeyboardInterrupt:
pass
all_distances = np.concatenate(all_distances)
results.sort()
return results, all_distances


Expand Down
2 changes: 1 addition & 1 deletion chirp/inference/tests/bootstrap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_bootstrap_from_embeddings(self):

# Check that we can iterate over TopKSearchResults,
# and successfully attach audio.
search_results = search.TopKSearchResults([], top_k=3)
search_results = search.TopKSearchResults(top_k=3)
for i, ex in enumerate(ds.as_numpy_iterator()):
result = search.SearchResult(
embedding=ex['embedding'],
Expand Down
18 changes: 10 additions & 8 deletions chirp/inference/tests/search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,27 @@ def test_top_k_search_results(self):
)
fake_results.append(r)

results = search.TopKSearchResults([], top_k=10)
results = search.TopKSearchResults(top_k=10)
for i, r in enumerate(fake_results):
results.update(r)
self.assertLen(results.search_results, min([i + 1, 10]))
# Get the 10th smallest value amongst the dists seen so far (or the largest, if fewer than 10).
true_min_neg_dist = -np.max(sorted(dists[: i + 1])[:10])
arg_min_dist = np.argmin([r.sort_score for r in results])
arg_min_dist = np.argmin([r.sort_score for r in results.search_results])
self.assertEqual(results.min_score, true_min_neg_dist)
self.assertEqual(
results.search_results[arg_min_dist].sort_score, results.min_score
)

self.assertLen(results.search_results, results.top_k)
results.sort()
for i in range(1, 10):
self.assertGreater(
results.search_results[i - 1].sort_score,
results.search_results[i].sort_score,
)
last_score = None
for i, result in enumerate(results):
if i > 0:
self.assertGreater(
last_score,
result.sort_score,
)
last_score = result.sort_score

@parameterized.product(
metric_name=('euclidean', 'cosine', 'mip'),
Expand Down

0 comments on commit 3850baf

Please sign in to comment.