Skip to content

Commit

Permalink
Add a cell for displaying loaded validation results.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686939867
  • Loading branch information
sdenton4 authored and copybara-github committed Oct 17, 2024
1 parent 1c040df commit d7d316e
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 1 deletion.
29 changes: 29 additions & 0 deletions analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,35 @@
"roc_auc_estimate = call_density.estimate_roc_auc(validation_examples)\n",
"print(f'Estimated ROC-AUC : {roc_auc_estimate:5.4f}')"
]
},
{
 "cell_type": "code",
 "execution_count": 0,
 "metadata": {
  "id": "0ii9H72iQknv"
 },
 "outputs": [],
 "source": [
  "#@title Display Logged Validation Examples. { vertical-output: true }\n",
  "\n",
  "# Rebuild displayable search results from the logged validation examples.\n",
  "validation_results = search.TopKSearchResults(top_k=len(validation_examples))\n",
  "for v in validation_examples:\n",
  "  validation_results.update(v.to_search_result(\n",
  "      target_class, project_state.embedding_model.sample_rate))\n",
  "\n",
  "samples_per_page = 40 #@param\n",
  "# Size the pager from the results shown in this cell, not from the\n",
  "# unrelated `combined_results` produced elsewhere in the notebook.\n",
  "page_state = display.PageState(\n",
  "    np.ceil(validation_results.top_k / samples_per_page))\n",
  "\n",
  "display.display_paged_results(\n",
  "    validation_results,\n",
  "    page_state, samples_per_page,\n",
  "    project_state=project_state,\n",
  "    embedding_sample_rate=project_state.embedding_model.sample_rate,\n",
  "    exclusive_labels=True,\n",
  "    checkbox_labels=[target_class, f'not {target_class}', 'unsure'],\n",
  ")"
 ]
}
],
"metadata": {
Expand Down
24 changes: 24 additions & 0 deletions chirp/inference/call_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from chirp.inference.search import search
from etils import epath
import ipywidgets
import numpy as np
import pandas as pd
import scipy
Expand Down Expand Up @@ -59,6 +60,29 @@ def to_row(self):
self.bin_weight,
]

def to_search_result(
    self, target_class: str, embedding_sample_rate: int | None = None
):
  """Convert this validation example to a search result for display only.

  Args:
    target_class: Name of the class being validated. Used to build the three
      label options: the class itself, 'not <class>', and 'unsure'.
    embedding_sample_rate: Sample rate of the embedding model. Accepted
      because callers pass it as a second positional argument; it is not
      needed to build the result and is currently ignored.

  Returns:
    A search.SearchResult carrying this example's filename, timestamp offset
    and score, with a single RadioButtons widget in `label_widgets` whose
    value is pre-selected from `self.is_pos`.

  Raises:
    ValueError: If `self.is_pos` is not one of 1, -1 or 0.
  """
  del embedding_sample_rate  # Unused; kept for call-site compatibility.

  positive_label = target_class
  negative_label = f'not {target_class}'
  unsure_label = 'unsure'

  result = search.SearchResult(
      filename=self.filename,
      timestamp_offset=self.timestamp_offset,
      score=self.score,
      # Random sort key: the result is for display, not for ranking.
      sort_score=np.random.uniform(),
      # Placeholder; no embedding is stored with a validation example.
      embedding=np.zeros(shape=(0,), dtype=np.float32),
  )
  b = ipywidgets.RadioButtons(
      options=[positive_label, negative_label, unsure_label]
  )
  # Restore the previously logged decision as the selected radio option.
  if self.is_pos == 1:
    b.value = positive_label
  elif self.is_pos == -1:
    b.value = negative_label
  elif self.is_pos == 0:
    b.value = unsure_label
  else:
    raise ValueError(f'unexpected value ({self.is_pos})')
  result.label_widgets = [b]
  return result

@classmethod
def from_search_result(
cls,
Expand Down
2 changes: 1 addition & 1 deletion chirp/inference/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SearchResult:
# Source file containing corresponding audio.
filename: str
# Time offset for audio.
timestamp_offset: int
timestamp_offset: float

# The following are populated as needed.
audio: np.ndarray | None = None
Expand Down
28 changes: 28 additions & 0 deletions chirp/inference/tests/call_density_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
import shutil
import string
import tempfile
from unittest import mock

from chirp.inference import call_density
from etils import epath
import IPython
import ipywidgets
import numpy as np
from sklearn import metrics

Expand All @@ -31,6 +34,18 @@
class CallDensityTest(absltest.TestCase):

def setUp(self):
  """Mock out ipywidgets comms and create a scratch directory for the test."""
  # Without this, unit tests using ipywidgets will fail with 'Comms cannot be
  # opened without a kernel and a comm_manager attached to that kernel'. This
  # mocks out the comms. This is a little fragile because it sets a private
  # attribute and may break for future ipywidgets library upgrades.
  setattr(
      ipywidgets.Widget,
      '_comm_default',
      lambda self: mock.MagicMock(spec=IPython.kernel.comm.Comm),
  )

  # Call the parent setUp exactly once.
  super().setUp()
  self.tempdir = tempfile.mkdtemp()

Expand Down Expand Up @@ -156,6 +171,19 @@ def test_write_read_log(self):
got_examples = call_density.load_validation_log(log_filepath)
self.assertLen(got_examples, len(examples))

with self.subTest('to_result'):
r = got_examples[0].to_search_result('someclass', 10)
self.assertEqual(r.filename, got_examples[0].filename)
self.assertEqual(r.timestamp_offset, got_examples[0].timestamp_offset)
self.assertEqual(r.score, got_examples[0].score)
if examples[0].is_pos == 1:
self.assertEqual(r.label_widgets[0].value, 'someclass')
elif examples[0].is_pos == -1:
self.assertEqual(r.label_widgets[0].value, 'not someclass')
elif examples[0].is_pos == 0:
self.assertEqual(r.label_widgets[0].value, 'unsure')
else:
raise ValueError(f'unexpected value ({examples[0].is_pos})')

if __name__ == '__main__':
absltest.main()

0 comments on commit d7d316e

Please sign in to comment.