Support different scoring functions (MIP, cosine, random, classifier logit) in brute-force search.

PiperOrigin-RevId: 556568615
sdenton4 authored and copybara-github committed Aug 15, 2023
1 parent b427160 commit e5dac7a
Showing 5 changed files with 303 additions and 146 deletions.
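For orientation before the diffs: the scoring modes named in the commit message reduce to simple vector operations. The sketch below illustrates what a score_fn dispatch of this kind typically computes; the function name, shapes, and sign conventions are assumptions for illustration, not chirp's actual implementation (the classifier-logit mode goes through the separate classifer_search_embeddings_parallel path shown in active_learning.ipynb below).

import numpy as np

def score_embeddings(embeddings, query, score_fn='euclidean'):
  """Scores [N, D] embeddings against a [D] query vector."""
  if score_fn == 'euclidean':
    # Distance: lower is better.
    return np.linalg.norm(embeddings - query[np.newaxis, :], axis=-1)
  if score_fn == 'mip':
    # Maximum inner product: higher is better.
    return embeddings @ query
  if score_fn == 'cosine':
    # Inner product of normalized vectors: higher is better.
    norms = np.linalg.norm(embeddings, axis=-1) * np.linalg.norm(query)
    return embeddings @ query / np.maximum(norms, 1e-12)
  if score_fn == 'random':
    # Uniform random scores, e.g. for unbiased sampling baselines.
    return np.random.uniform(size=embeddings.shape[0])
  raise ValueError(f'Unknown score_fn: {score_fn}')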
27 changes: 21 additions & 6 deletions chirp/inference/active_learning.ipynb
@@ -39,6 +39,7 @@
"import tensorflow as tf\n",
"from etils import epath\n",
"import matplotlib.pyplot as plt\n",
"import tqdm\n",
"\n",
"from chirp.inference import colab_utils\n",
"colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n",
@@ -72,7 +73,7 @@
"# each class of interest.\n",
"# Audio in sub-folders should be wav files.\n",
"# Audio should ideally be 5s audio clips, but the system is quite forgiving.\n",
"labeled_data_path = '' #@param\n"
"labeled_data_path = '/tmp/labeled' #@param\n"
]
},
{
@@ -90,11 +91,11 @@
"with (embeddings_path / 'config.json').open() as f:\n",
" embedding_config = config_dict.ConfigDict(json.loads(f.read()))\n",
"embeddings_glob = embeddings_path / 'embeddings-*'\n",
"embedding_hop_size_s = embedding_config.embed_fn_config.model_config.hop_size_s\n",
"\n",
"config = bootstrap.BootstrapConfig.load_from_embedding_config(\n",
" embeddings_path=embeddings_path,\n",
" annotated_path=labeled_data_path)\n",
"embedding_hop_size_s = config.embedding_hop_size_s\n",
"project_state = bootstrap.BootstrapState(config)\n",
"embedding_model = project_state.embedding_model"
]
@@ -205,7 +206,8 @@
"# Choose the target class to work with.\n",
"target_class = '' #@param\n",
"# Choose a target logit; will display results close to the target.\n",
"target_logit = 2.0 #@param\n",
"# Set to None to get the highest-logit examples.\n",
"target_logit = None #@param\n",
"# Number of results to display.\n",
"num_results = 25 #@param\n",
"\n",
@@ -214,8 +216,12 @@
" embeddings_path, file_glob='embeddings-*')\n",
"target_class_idx = merged.labels.index(target_class)\n",
"results, all_logits = search.classifer_search_embeddings_parallel(\n",
" embeddings_ds, model, target_class_idx, hop_size_s=embedding_hop_size_s,\n",
" target_logit=target_logit, top_k=num_results\n",
" embeddings_classifier=model,\n",
" target_index=target_class_idx,\n",
" embeddings_dataset=embeddings_ds,\n",
" hop_size_s=embedding_hop_size_s,\n",
" target_score=target_logit,\n",
" top_k=num_results\n",
")\n",
"\n",
"# Plot the histogram of logits.\n",
@@ -307,7 +313,7 @@
" # Write column headers.\n",
" headers = ['filename', 'timestamp_s', 'label', 'logit']\n",
" f.write(', '.join(headers) + '\\n')\n",
" for ex in inference_ds.as_numpy_iterator():\n",
" for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):\n",
" for t in range(ex['logits'].shape[0]):\n",
" for i, label in enumerate(merged.class_names):\n",
" if ex['logits'][t, i] \u003e threshold:\n",
@@ -318,6 +324,15 @@
" label, logit]\n",
" f.write(', '.join(row) + '\\n')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1Rb3LbPvMKde"
},
"outputs": [],
"source": []
}
],
"metadata": {
76 changes: 45 additions & 31 deletions chirp/inference/search_embeddings.ipynb
@@ -10,8 +10,6 @@
"source": [
"#@title Imports. { vertical-output: true }\n",
"\n",
"# Disable annoying warnings.\n",
"\n",
"# Global imports\n",
"import json\n",
"from ml_collections import config_dict\n",
@@ -40,10 +38,13 @@
"#@title Configuration and Setup. { vertical-output: true }\n",
"\n",
"# Path to embeddings of unlabeled data.\n",
"embeddings_path = '/tmp/embeddings' #@param\n",
"embeddings_path = '' #@param\n",
"\n",
"# Path for storing annotated examples.\n",
"labeled_data_path = '/tmp/labeled_data' #@param\n"
"labeled_data_path = '' #@param\n",
"\n",
"separation_model_key = 'separator_model_tf' #@param\n",
"separation_model_path = '' #@param\n"
]
},
{
@@ -80,8 +81,6 @@
"outputs": [],
"source": [
"#@title Load Separation Model (Optional) { vertical-output: true }\n",
"separation_model_key = '' #@param\n",
"separation_model_path = '' #@param\n",
"\n",
"if config.model_key == 'separate_embed_model' and not separation_model_key.strip():\n",
" separation_model_key = 'separator_model_tf'\n",
@@ -94,7 +93,7 @@
" model_path=separation_model_path,\n",
" frame_size=32000,\n",
" )\n",
" print(\"Loaded separator model at {}\".format(separation_model_path))\n",
" print('Loaded separator model at {}'.format(separation_model_path))\n",
"else:\n",
" print('No separation model loaded.')\n",
" separator = None"
@@ -138,7 +137,7 @@
"end = int(st + window_s * sample_rate)\n",
"if end \u003e audio.shape[0]:\n",
" end = audio.shape[0]\n",
" st = max([0, end - window_s * sample_rate])\n",
" st = max([0, int(end - window_s * sample_rate)])\n",
"audio_window = audio[st:end]\n",
"display.plot_audio_melspec(audio_window, sample_rate)\n",
"\n",
@@ -178,7 +177,7 @@
"source": [
"#@title Select the query channel. { vertical-output: true }\n",
"\n",
"query_label = 'my_label' #@param\n",
"query_label = 'some_audio' #@param\n",
"query_channel = -1 #@param\n",
"\n",
"if query_channel \u003c 0 or sep_outputs is None:\n",
@@ -202,38 +201,43 @@
"source": [
"#@title Run Top-K Search. { vertical-output: true }\n",
"\n",
"# Number of search results to capture.\n",
"top_k = 25 #@param\n",
"\n",
"# Target distance for search results.\n",
"# This lets us try to hone in on a 'classifier boundary' instead of just\n",
"# looking at the closest matches.\n",
"target_dist = 0 #@param\n",
"# Set to 'None' for raw 'best results' search.\n",
"target_score = None #@param\n",
"\n",
"# Number of search results to capture.\n",
"top_k = 10 #@param\n",
"metric = 'euclidean' #@param['euclidean', 'mip', 'cosine']\n",
"\n",
"random_sample = False #@param\n",
"\n",
"ds = project_state.create_embeddings_dataset()\n",
"results, all_distances = search.search_embeddings_parallel(\n",
" ds, query[np.newaxis, np.newaxis, :], hop_size_s=model_config.hop_size_s,\n",
" top_k=top_k, target_dist=target_dist)\n",
"results, all_scores = search.search_embeddings_parallel(\n",
" ds, query[np.newaxis, np.newaxis, :],\n",
" hop_size_s=config.embedding_hop_size_s,\n",
" top_k=top_k, target_score=target_score, score_fn=metric,\n",
" random_sample=random_sample)\n",
"\n",
"# Plot histogram of distances\n",
"_, ys, _ = plt.hist(all_distances, bins=128, density=True)\n",
"hit_distances = [r.distance for r in results.search_results]\n",
"plt.scatter(hit_distances, np.zeros_like(hit_distances), marker='|',\n",
"ys, _, _ = plt.hist(all_scores, bins=128, density=True)\n",
"hit_scores = [r.score for r in results.search_results]\n",
"plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n",
" color='r', alpha=0.5)\n",
"\n",
"plt.xlabel('distance')\n",
"plt.xlabel(metric)\n",
"plt.ylabel('density')\n",
"if target_dist \u003e 0:\n",
" plt.plot([target_dist, target_dist], [0.0, np.max(ys)], 'r:')\n",
"min_dist = np.min(all_distances)\n",
"plt.plot([min_dist, min_dist], [0.0, np.max(ys)], 'g:')\n",
"\n",
"plt.show()\n",
"\n",
"# Compute the proportion of files with min_dist \u003c target_dist\n",
"hit_percentage = (np.sum(\n",
" [d \u003c target_dist for d in all_distances]) / all_distances.shape[0])\n",
"print(f'file min_dist\u003ctarget percentage : {hit_percentage:5.3f}')"
"if target_score is not None:\n",
" plt.plot([target_score, target_score], [0.0, np.max(ys)], 'r:')\n",
" # Compute the proportion of scores \u003c target_score\n",
" hit_percentage = (all_scores \u003c target_score).mean()\n",
" print(f'score \u003c target_score percentage : {hit_percentage:5.3f}')\n",
"min_score = np.min(all_scores)\n",
"plt.plot([min_score, min_score], [0.0, np.max(ys)], 'g:')\n",
"\n",
"plt.show()\n"
]
},
{
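Taken together, the new knobs change what the search keeps: target_score=None keeps the best-scoring windows, a float keeps the windows scoring nearest that value (useful for probing a classifier boundary), and random_sample draws a uniform sample instead. A rough sketch of the selection logic these parameters imply, assuming higher-is-better scores (for euclidean the ordering is reversed); this is an illustration, not the library's code:

import numpy as np

def select_windows(all_scores, top_k, target_score=None, random_sample=False):
  """Returns indices of the windows such a search would keep."""
  if random_sample:
    # Uniform sample, e.g. for building an unbiased labeled set.
    return np.random.choice(len(all_scores), size=top_k, replace=False)
  if target_score is None:
    # Raw 'best results': the top_k highest scores, best first.
    return np.argsort(all_scores)[-top_k:][::-1]
  # Otherwise keep scores closest to the target.
  return np.argsort(np.abs(all_scores - target_score))[:top_k]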
@@ -265,6 +269,15 @@
"results.write_labeled_data(config.annotated_path,\n",
" project_state.embedding_model.sample_rate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "C8cIzgMxSMWT"
},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -279,7 +292,8 @@
"file_id": "1HQNRQL-pQu-9kuZzKI9Fkiy6R7jYGKir",
"timestamp": 1689436763547
}
]
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
19 changes: 11 additions & 8 deletions chirp/projects/bootstrap/display.py
@@ -94,18 +94,21 @@ def display_search_results(
   # Parallel load the audio windows.
   filepaths = [source_map[r.filename] for r in results]
   offsets = [r.timestamp_offset for r in results]
-  for r, result_audio_window in zip(
-      results,
-      audio_utils.multi_load_audio_window(
-          filepaths, offsets, embedding_sample_rate, window_s, max_workers
-      ),
+  for rank, (r, result_audio_window) in enumerate(
+      zip(
+          results,
+          audio_utils.multi_load_audio_window(
+              filepaths, offsets, embedding_sample_rate, window_s, max_workers
+          ),
+      )
   ):
     plot_audio_melspec(result_audio_window, embedding_sample_rate)
     plt.show()
-    print(f'source file: {r.filename}')
+    print(f'rank : {rank}')
+    print(f'source file : {r.filename}')
     offset_s = r.timestamp_offset
-    print(f'offset: {offset_s:6.2f}')
-    print(f'distance: {(r.distance + results.distance_offset):6.2f}')
+    print(f'offset_s : {offset_s:.2f}')
+    print(f'score : {(r.score):.2f}')
     label_widgets = []
 
     def button_callback(x):
