Skip to content

Commit

Permalink
Quality-of-life (QoL) improvements for validation.
Browse files Browse the repository at this point in the history
* Include the target class name in the validation log filename.
* Sample without replacement when creating the validation set, avoiding duplicate validation examples.
* Shuffle the validation results with a fixed seed, so that the results are displayed in the same pseudo-random order on every run.

PiperOrigin-RevId: 631928509
  • Loading branch information
sdenton4 authored and copybara-github committed May 8, 2024
1 parent fb72a85 commit 07b56b9
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@
"source": [
"#@title Basic Configuration. { vertical-output: true }\n",
"\n",
"data_source = 'filesystem' #@param['filesystem', 'a2o']\n",
"data_source = 'filesystem' #@param['filesystem', 'a2o'] {type:'string'}\n",
"a2o_auth_token = '' #@param\n",
"\n",
"#@markdown Define the model: Usually perch or birdnet.\n",
"model_choice = 'perch' #@param\n",
"model_choice = 'perch' #@param {type:'string'}\n",
"#@markdown Set the base directory for the project.\n",
"working_dir = '/tmp/agile' #@param\n",
"working_dir = '/tmp/hawaii' #@param {type:'string'}\n",
"\n",
"# Set the embedding and labeled data directories.\n",
"labeled_data_path = epath.Path(working_dir) / 'labeled'\n",
Expand All @@ -99,7 +99,7 @@
" bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(\n",
" embedding_config=embedding_config,\n",
" annotated_path=labeled_data_path,\n",
" embeddings_glob = '*/embeddings-*')\n",
" embeddings_glob='*/embeddings-*')\n",
" embeddings_path = embedding_config.output_dir\n",
"elif (embeddings_path\n",
" or (epath.Path(working_dir) / 'embeddings/config.json').exists()):\n",
Expand Down Expand Up @@ -197,7 +197,7 @@
"source": [
"#@title Validation and Call Density. { vertical-output: true }\n",
"\n",
"target_class = 'my_class' #@param\n",
"target_class = 'my_class' #@param {type:'string'}\n",
"\n",
"#@markdown Bin bounds for validation. Should be an ordered list, beginning with\n",
"#@markdown 0.0 and ending with 1.0.\n",
Expand Down Expand Up @@ -227,9 +227,15 @@
"q_bounds = np.quantile(all_logits, bounds)\n",
"binned = [[] for _ in range(num_bins)]\n",
"for r in results.search_results:\n",
" bin = np.argmax(r.score < q_bounds) - 1\n",
" binned[bin].append(r)\n",
"binned = [np.random.choice(b, samples_per_bin) for b in binned]\n",
" result_bin = np.argmax(r.score < q_bounds) - 1\n",
" binned[result_bin].append(r)\n",
"binned = [np.random.choice(b, samples_per_bin, replace=False) for b in binned]\n",
"\n",
"combined_results = []\n",
"for b in binned:\n",
" combined_results.extend(b)\n",
"rng = np.random.default_rng(42)\n",
"rng.shuffle(combined_results)\n",
"\n",
"ys, _, _, = plt.hist(all_logits, bins=100, density=True)\n",
"for q in q_bounds:\n",
Expand All @@ -247,16 +253,12 @@
"source": [
"#@title Display Results. { vertical-output: true }\n",
"\n",
"combined = []\n",
"for b in binned:\n",
" combined.extend(b)\n",
"np.random.shuffle(combined)\n",
"\n",
"samples_per_page = 40 #@param\n",
"page_state = display.PageState(np.ceil(len(combined) / samples_per_page))\n",
"samples_per_page = 40 #@param\n",
"page_state = display.PageState(\n",
" np.ceil(len(combined_results) / samples_per_page))\n",
"\n",
"display.display_paged_results(\n",
" search.TopKSearchResults(combined, len(combined)),\n",
" search.TopKSearchResults(combined_results, len(combined_results)),\n",
" page_state, samples_per_page,\n",
" project_state=project_state,\n",
" embedding_sample_rate=project_state.embedding_model.sample_rate,\n",
Expand All @@ -275,7 +277,8 @@
"source": [
"#@title Collate results and write validation log. { vertical-output: true }\n",
"\n",
"validation_log_filepath = epath.Path(working_dir) / 'validation.csv'\n",
"validation_log_filepath = (\n",
" epath.Path(working_dir) / f'validation_{target_class}.csv')\n",
"\n",
"filenames = []\n",
"timestamp_offsets = []\n",
Expand All @@ -288,7 +291,7 @@
"bin_wts = [1.0 / 2**(k + 1) for k in range(num_bins - 1)]\n",
"bin_wts.append(bin_wts[-1])\n",
"\n",
"for r in combined:\n",
"for r in combined_results:\n",
" if not r.label_widgets: continue\n",
" value = r.label_widgets[0].value\n",
" if value is None:\n",
Expand All @@ -298,9 +301,9 @@
" timestamp_offsets.append(r.timestamp_offset)\n",
"\n",
" # Get the bin number and sampling weight for the search result.\n",
" bin = np.argmax(r.score < q_bounds) - 1\n",
" bins.append(bin)\n",
" weights.append(bin_wts[bin])\n",
" result_bin = np.argmax(r.score < q_bounds) - 1\n",
" bins.append(result_bin)\n",
" weights.append(bin_wts[result_bin])\n",
"\n",
" if value == target_class:\n",
" is_pos.append(1)\n",
Expand Down Expand Up @@ -337,11 +340,11 @@
"bin_pos = [0 for i in range(num_bins)]\n",
"bin_neg = [0 for i in range(num_bins)]\n",
"for score, pos in zip(scores, is_pos):\n",
" bin = np.argmax(score < q_bounds) - 1\n",
" result_bin = np.argmax(score < q_bounds) - 1\n",
" if pos == 1:\n",
" bin_pos[bin] += 1\n",
" bin_pos[result_bin] += 1\n",
" elif pos == -1:\n",
" bin_neg[bin] += 1\n",
" bin_neg[result_bin] += 1\n",
"\n",
"# Create beta distributions.\n",
"prior = 0.1\n",
Expand Down

0 comments on commit 07b56b9

Please sign in to comment.