Move notebooks to top level.
PiperOrigin-RevId: 560192193
sdenton4 authored and copybara-github committed Aug 25, 2023
1 parent 719eac9 commit 7e5087a
Showing 6 changed files with 890 additions and 938 deletions.
343 changes: 343 additions & 0 deletions chirp/active_learning.ipynb
@@ -0,0 +1,343 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "eU3pUl3cT5sq"
},
"source": [
"# Small-Model Training\n",
"\n",
"Train a small linear model over fixed embeddings, using a curated set of training data.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nrCu2iSQoIHs"
},
"source": [
"# Imports and Configuration."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "PhOcIEtYG6uL"
},
"outputs": [],
"source": [
"#@title Imports. { vertical-output: true }\n",
"import collections\n",
"import json\n",
"from ml_collections import config_dict\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from etils import epath\n",
"import matplotlib.pyplot as plt\n",
"import tqdm\n",
"\n",
"from chirp.inference import colab_utils\n",
"colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n",
"\n",
"# Chirp imports\n",
"from chirp.models import metrics\n",
"from chirp.inference import tf_examples\n",
"from chirp.projects.bootstrap import bootstrap\n",
"from chirp.projects.bootstrap import display\n",
"from chirp.projects.bootstrap import search\n",
"from chirp.projects.multicluster import classify\n",
"from chirp.projects.multicluster import data_lib\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "VCfIqyIyh5wU"
},
"outputs": [],
"source": [
"#@title Configure data and model locations. { vertical-output: true }\n",
"\n",
"# Path containing TFRecords of unlabeled embeddings.\n",
"# We will load the model which was used to compute the embeddings automatically.\n",
"embeddings_path = '/tmp/embeddings' #@param\n",
"\n",
"# Path to the labeled wav data.\n",
"# Should be in 'folder-of-folders' format - a folder with sub-folders for\n",
"# each class of interest.\n",
"# Audio in sub-folders should be wav files.\n",
"# Audio should ideally be 5s audio clips, but the system is quite forgiving.\n",
"labeled_data_path = '/tmp/labeled' #@param\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "A264FcxyG039"
},
"outputs": [],
"source": [
"#@title Load the model. { vertical-output: true }\n",
"\n",
"# Get relevant info from the embedding configuration.\n",
"embeddings_path = epath.Path(embeddings_path)\n",
"with (embeddings_path / 'config.json').open() as f:\n",
" embedding_config = config_dict.ConfigDict(json.loads(f.read()))\n",
"embeddings_glob = embeddings_path / 'embeddings-*'\n",
"\n",
"config = bootstrap.BootstrapConfig.load_from_embedding_config(\n",
" embeddings_path=embeddings_path,\n",
" annotated_path=labeled_data_path)\n",
"embedding_hop_size_s = config.embedding_hop_size_s\n",
"project_state = bootstrap.BootstrapState(config)\n",
"embedding_model = project_state.embedding_model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "waSrbkzHl2o3"
},
"source": [
"# Supervised Learning."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "ye4YRiARHHeR"
},
"outputs": [],
"source": [
"#@title Load+Embed the Labeled Dataset. { vertical-output: true }\n",
"\n",
"# Time-pooling strategy for examples longer than the model's window size.\n",
"time_pooling = 'mean' #@param\n",
"\n",
"merged = data_lib.MergedDataset(config.annotated_path,\n",
" embedding_model,\n",
" time_pooling=time_pooling)\n",
"\n",
"# Label distribution\n",
"lbl_counts = np.sum(merged.data['label_hot'], axis=0)\n",
"print('num classes :', (lbl_counts > 0).sum())\n",
"print('mean ex / class :', lbl_counts.sum() / (lbl_counts > 0).sum())\n",
"print('min ex / class :', (lbl_counts + (lbl_counts == 0) * 1e6).min())\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "oMzrwSzURe1i"
},
"outputs": [],
"source": [
"#@title Train linear model over embeddings. { vertical-output: true }\n",
"\n",
"# Number of random training examples to choose form each class.\n",
"example_per_class = 128 #@param\n",
"\n",
"# Number of random re-trainings. Allows judging model stability.\n",
"num_seeds = 1 #@param\n",
"\n",
"# Classifier training hyperparams.\n",
"# These should be good defaults.\n",
"batch_size = 32\n",
"num_epochs = 128\n",
"num_hiddens = -1\n",
"learning_rate = 1e-3\n",
"\n",
"metrics = collections.defaultdict(list)\n",
"for seed in range(num_seeds):\n",
" if num_hiddens > 0:\n",
" model = classify.get_two_layer_model(\n",
" num_hiddens, merged.embedding_dim, merged.num_classes)\n",
" else:\n",
" model = classify.get_linear_model(\n",
" merged.embedding_dim, merged.num_classes)\n",
" run_metrics = classify.train_embedding_model(\n",
" model, merged, example_per_class, num_epochs, seed, batch_size, learning_rate)\n",
" metrics['acc'].append(run_metrics.top1_accuracy)\n",
" metrics['auc_roc'].append(run_metrics.auc_roc)\n",
" metrics['cmap'].append(run_metrics.cmap_value)\n",
" metrics['maps'].append(run_metrics.class_maps)\n",
" metrics['recall'].append(run_metrics.recall)\n",
"mean_acc = np.mean(metrics['acc'])\n",
"mean_auc_roc = np.mean(metrics['auc_roc'])\n",
"mean_cmap = np.mean(metrics['cmap'])\n",
"mean_recall = np.mean(metrics['recall'])\n",
"print(f'{example_per_class:d}, acc:{mean_acc:5.2f}, '\n",
" f'auc_roc:{mean_auc_roc:5.2f}, cmap:{mean_cmap:5.2f}, '\n",
" f'recall:{mean_recall:5.2f}')\n",
"for lbl, auc in zip(merged.labels, run_metrics.class_maps):\n",
" if np.isnan(auc):\n",
" continue\n",
" print(f'{lbl:8s}, auc_roc:{auc:5.2f}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dBjsS_8xh6IB"
},
"source": [
"# Evaluation on Unlabeled Data"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "4-Sxvh0JCCRW"
},
"outputs": [],
"source": [
"#@title Run model on target unlabeled data. { vertical-output: true }\n",
"\n",
"# Choose the target class to work with.\n",
"target_class = '' #@param\n",
"# Choose a target logit; will display results close to the target.\n",
"# Set to None to get the highest-logit examples.\n",
"target_logit = None #@param\n",
"# Number of results to display.\n",
"num_results = 25 #@param\n",
"\n",
"# Create the embeddings dataset.\n",
"embeddings_ds = tf_examples.create_embeddings_dataset(\n",
" embeddings_path, file_glob='embeddings-*')\n",
"target_class_idx = merged.labels.index(target_class)\n",
"results, all_logits = search.classifer_search_embeddings_parallel(\n",
" embeddings_classifier=model,\n",
" target_index=target_class_idx,\n",
" embeddings_dataset=embeddings_ds,\n",
" hop_size_s=embedding_hop_size_s,\n",
" target_score=target_logit,\n",
" top_k=num_results\n",
")\n",
"\n",
"# Plot the histogram of logits.\n",
"_, ys, _ = plt.hist(all_logits, bins=128, density=True)\n",
"plt.xlabel(f'{target_class} logit')\n",
"plt.ylabel('density')\n",
"# plt.yscale('log')\n",
"plt.plot([target_logit, target_logit], [0.0, np.max(ys)], 'r:')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "WCGcfWxiEL4m"
},
"outputs": [],
"source": [
"#@title Display results for the target label. { vertical-output: true }\n",
"\n",
"display_labels = merged.labels\n",
"\n",
"extra_labels = [] #@param\n",
"for label in extra_labels:\n",
" if label not in merged.labels:\n",
" display_labels += (label,)\n",
"if 'unknown' not in merged.labels:\n",
" display_labels += ('unknown',)\n",
"\n",
"display.display_search_results(\n",
" results, embedding_model.sample_rate,\n",
" project_state.source_map,\n",
" checkbox_labels=display_labels,\n",
" max_workers=5)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "JR4jUZKvatdK"
},
"outputs": [],
"source": [
"#@title Add selected results to the labeled data. { vertical-output: true }\n",
"\n",
"results.write_labeled_data(\n",
" config.annotated_path, embedding_model.sample_rate)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "WkpWsvQ9DGYl"
},
"outputs": [],
"source": [
"#@title Write classifier inference CSV. { vertical-output: true }\n",
"\n",
"threshold = 1.0 #@param\n",
"output_filepath = '/tmp/inference.csv' #@param\n",
"\n",
"# Create the embeddings dataset.\n",
"embeddings_ds = tf_examples.create_embeddings_dataset(\n",
" embeddings_path, file_glob='embeddings-*')\n",
"\n",
"def classify_batch(batch):\n",
" \"\"\"Classify a batch of embeddings.\"\"\"\n",
" emb = batch[tf_examples.EMBEDDING]\n",
" emb_shape = tf.shape(emb)\n",
" flat_emb = tf.reshape(emb, [-1, emb_shape[-1]])\n",
" logits = model(flat_emb)\n",
" logits = tf.reshape(\n",
" logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]])\n",
" # Take the maximum logit over channels.\n",
" logits = tf.reduce_max(logits, axis=-2)\n",
" batch['logits'] = logits\n",
" return batch\n",
"\n",
"inference_ds = tf_examples.create_embeddings_dataset(\n",
" embeddings_path, file_glob='embeddings-*')\n",
"inference_ds = inference_ds.map(\n",
" classify_batch, num_parallel_calls=tf.data.AUTOTUNE\n",
")\n",
"\n",
"with open(output_filepath, 'w') as f:\n",
" # Write column headers.\n",
" headers = ['filename', 'timestamp_s', 'label', 'logit']\n",
" f.write(', '.join(headers) + '\\n')\n",
" for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):\n",
" for t in range(ex['logits'].shape[0]):\n",
" for i, label in enumerate(merged.class_names):\n",
" if ex['logits'][t, i] > threshold:\n",
" offset = ex['timestamp_s'] + t * config.embedding_hop_size_s\n",
" logit = '{:.2f}'.format(ex['logits'][t, i])\n",
" row = [ex['filename'].decode('utf-8'),\n",
" '{:.2f}'.format(offset),\n",
" label, logit]\n",
" f.write(', '.join(row) + '\\n')\n"
]
}
],
"metadata": {
"colab": {
"name": "active_learning.ipynb",
"private_outputs": true,
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}