From f8f5f0f826c887a0fd2d223f1e63efc496872343 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Thu, 17 Oct 2024 09:21:00 -0700 Subject: [PATCH] Add a cell for displaying loaded validation results. PiperOrigin-RevId: 686939867 --- analysis.ipynb | 29 +++++++++++++++++++++++++++++ chirp/inference/call_density.py | 24 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/analysis.ipynb b/analysis.ipynb index 316f4a98..cddc15e1 100644 --- a/analysis.ipynb +++ b/analysis.ipynb @@ -312,6 +312,35 @@ "roc_auc_estimate = call_density.estimate_roc_auc(validation_examples)\n", "print(f'Estimated ROC-AUC : {roc_auc_estimate:5.4f}')" ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "0ii9H72iQknv" + }, + "outputs": [], + "source": [ + "#@title Display Logged Validation Examples. { vertical-output: true }\n", + "\n", + "validation_results = search.TopKSearchResults(top_k=len(validation_examples))\n", + "for v in validation_examples:\n", + " validation_results.update(v.to_search_result(\n", + " target_class, project_state.embedding_model.sample_rate))\n", + "\n", + "samples_per_page = 40 #@param\n", + "page_state = display.PageState(\n", + " np.ceil(combined_results.top_k / samples_per_page))\n", + "\n", + "display.display_paged_results(\n", + " validation_results,\n", + " page_state, samples_per_page,\n", + " project_state=project_state,\n", + " embedding_sample_rate=project_state.embedding_model.sample_rate,\n", + " exclusive_labels=True,\n", + " checkbox_labels=[target_class, f'not {target_class}', 'unsure'],\n", + ")" + ] } ], "metadata": { diff --git a/chirp/inference/call_density.py b/chirp/inference/call_density.py index dac25a38..b86e5322 100644 --- a/chirp/inference/call_density.py +++ b/chirp/inference/call_density.py @@ -21,6 +21,7 @@ from chirp.inference.search import search from etils import epath +import ipywidgets import numpy as np import pandas as pd import scipy @@ -59,6 +60,29 @@ def to_row(self): self.bin_weight, ] + def to_search_result(self, target_class: str, embedding_sample_rate: int): + """Convert to a search result for display only.""" + result = search.SearchResult( + filename=self.filename, + timestamp_offset=int(embedding_sample_rate * self.timestamp_offset), + score=self.score, + sort_score=np.random.uniform(), + embedding=np.zeros(shape=(0,), dtype=np.float32), + ) + b = ipywidgets.RadioButtons( + options=[target_class, f'not {target_class}', 'unsure'] + ) + if self.is_pos == 1: + b.value = target_class + elif self.is_pos == -1: + b.value = f'not {target_class}' + elif self.is_pos == 0: + b.value = 'unsure' + else: + raise ValueError(f'unexpected value ({self.is_pos})') + result.label_widgets = [b] + return result + @classmethod def from_search_result( cls,