Showing 6 changed files with 890 additions and 938 deletions.
@@ -0,0 +1,343 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eU3pUl3cT5sq"
      },
      "source": [
        "# Small-Model Training\n",
        "\n",
        "Train a small linear model over fixed embeddings, using a curated set of training data.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nrCu2iSQoIHs"
      },
      "source": [
        "# Imports and Configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "id": "PhOcIEtYG6uL"
      },
      "outputs": [],
      "source": [
        "#@title Imports. { vertical-output: true }\n",
        "import collections\n",
        "import json\n",
        "from ml_collections import config_dict\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from etils import epath\n",
        "import matplotlib.pyplot as plt\n",
        "import tqdm\n",
        "\n",
        "from chirp.inference import colab_utils\n",
        "colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n",
        "\n",
        "# Chirp imports\n",
        "from chirp.models import metrics\n",
        "from chirp.inference import tf_examples\n",
        "from chirp.projects.bootstrap import bootstrap\n",
        "from chirp.projects.bootstrap import display\n",
        "from chirp.projects.bootstrap import search\n",
        "from chirp.projects.multicluster import classify\n",
        "from chirp.projects.multicluster import data_lib\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "id": "VCfIqyIyh5wU"
      },
      "outputs": [],
      "source": [
        "#@title Configure data and model locations. { vertical-output: true }\n",
        "\n",
        "# Path containing TFRecords of unlabeled embeddings.\n",
"# We will load the model which was used to compute the embeddings automatically.\n", | ||
"embeddings_path = '/tmp/embeddings' #@param\n", | ||
"\n", | ||
"# Path to the labeled wav data.\n", | ||
"# Should be in 'folder-of-folders' format - a folder with sub-folders for\n", | ||
"# each class of interest.\n", | ||
"# Audio in sub-folders should be wav files.\n", | ||
"# Audio should ideally be 5s audio clips, but the system is quite forgiving.\n", | ||
"labeled_data_path = '/tmp/labeled' #@param\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"id": "A264FcxyG039" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Load the model. { vertical-output: true }\n", | ||
"\n", | ||
"# Get relevant info from the embedding configuration.\n", | ||
"embeddings_path = epath.Path(embeddings_path)\n", | ||
"with (embeddings_path / 'config.json').open() as f:\n", | ||
" embedding_config = config_dict.ConfigDict(json.loads(f.read()))\n", | ||
"embeddings_glob = embeddings_path / 'embeddings-*'\n", | ||
"\n", | ||
"config = bootstrap.BootstrapConfig.load_from_embedding_config(\n", | ||
" embeddings_path=embeddings_path,\n", | ||
" annotated_path=labeled_data_path)\n", | ||
"embedding_hop_size_s = config.embedding_hop_size_s\n", | ||
"project_state = bootstrap.BootstrapState(config)\n", | ||
"embedding_model = project_state.embedding_model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "waSrbkzHl2o3" | ||
}, | ||
"source": [ | ||
"# Supervised Learning." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"id": "ye4YRiARHHeR" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Load+Embed the Labeled Dataset. { vertical-output: true }\n", | ||
"\n", | ||
"# Time-pooling strategy for examples longer than the model's window size.\n", | ||
"time_pooling = 'mean' #@param\n", | ||
"\n", | ||
"merged = data_lib.MergedDataset(config.annotated_path,\n", | ||
" embedding_model,\n", | ||
" time_pooling=time_pooling)\n", | ||
"\n", | ||
"# Label distribution\n", | ||
"lbl_counts = np.sum(merged.data['label_hot'], axis=0)\n", | ||
"print('num classes :', (lbl_counts > 0).sum())\n", | ||
"print('mean ex / class :', lbl_counts.sum() / (lbl_counts > 0).sum())\n", | ||
"print('min ex / class :', (lbl_counts + (lbl_counts == 0) * 1e6).min())\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"id": "oMzrwSzURe1i" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Train linear model over embeddings. { vertical-output: true }\n", | ||
"\n", | ||
"# Number of random training examples to choose form each class.\n", | ||
"example_per_class = 128 #@param\n", | ||
"\n", | ||
"# Number of random re-trainings. Allows judging model stability.\n", | ||
"num_seeds = 1 #@param\n", | ||
"\n", | ||
"# Classifier training hyperparams.\n", | ||
"# These should be good defaults.\n", | ||
"batch_size = 32\n", | ||
"num_epochs = 128\n", | ||
"num_hiddens = -1\n", | ||
"learning_rate = 1e-3\n", | ||
"\n", | ||
"metrics = collections.defaultdict(list)\n", | ||
"for seed in range(num_seeds):\n", | ||
" if num_hiddens > 0:\n", | ||
" model = classify.get_two_layer_model(\n", | ||
" num_hiddens, merged.embedding_dim, merged.num_classes)\n", | ||
" else:\n", | ||
" model = classify.get_linear_model(\n", | ||
" merged.embedding_dim, merged.num_classes)\n", | ||
" run_metrics = classify.train_embedding_model(\n", | ||
" model, merged, example_per_class, num_epochs, seed, batch_size, learning_rate)\n", | ||
" metrics['acc'].append(run_metrics.top1_accuracy)\n", | ||
" metrics['auc_roc'].append(run_metrics.auc_roc)\n", | ||
" metrics['cmap'].append(run_metrics.cmap_value)\n", | ||
" metrics['maps'].append(run_metrics.class_maps)\n", | ||
" metrics['recall'].append(run_metrics.recall)\n", | ||
"mean_acc = np.mean(metrics['acc'])\n", | ||
"mean_auc_roc = np.mean(metrics['auc_roc'])\n", | ||
"mean_cmap = np.mean(metrics['cmap'])\n", | ||
"mean_recall = np.mean(metrics['recall'])\n", | ||
"print(f'{example_per_class:d}, acc:{mean_acc:5.2f}, '\n", | ||
" f'auc_roc:{mean_auc_roc:5.2f}, cmap:{mean_cmap:5.2f}, '\n", | ||
" f'recall:{mean_recall:5.2f}')\n", | ||
"for lbl, auc in zip(merged.labels, run_metrics.class_maps):\n", | ||
" if np.isnan(auc):\n", | ||
" continue\n", | ||
" print(f'{lbl:8s}, auc_roc:{auc:5.2f}')" | ||
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dBjsS_8xh6IB"
      },
      "source": [
        "# Evaluation on Unlabeled Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "id": "4-Sxvh0JCCRW"
      },
      "outputs": [],
      "source": [
        "#@title Run model on target unlabeled data. { vertical-output: true }\n",
        "\n",
        "# Choose the target class to work with.\n",
        "target_class = '' #@param\n",
        "# Choose a target logit; will display results close to the target.\n",
        "# Set to None to get the highest-logit examples.\n",
        "target_logit = None #@param\n",
        "# Number of results to display.\n",
        "num_results = 25 #@param\n",
        "\n",
        "# Create the embeddings dataset.\n",
        "embeddings_ds = tf_examples.create_embeddings_dataset(\n",
        "    embeddings_path, file_glob='embeddings-*')\n",
        "target_class_idx = merged.labels.index(target_class)\n",
        "results, all_logits = search.classifer_search_embeddings_parallel(\n",
        "    embeddings_classifier=model,\n",
        "    target_index=target_class_idx,\n",
        "    embeddings_dataset=embeddings_ds,\n",
        "    hop_size_s=embedding_hop_size_s,\n",
        "    target_score=target_logit,\n",
        "    top_k=num_results\n",
        ")\n",
        "\n",
        "# Plot the histogram of logits.\n",
"_, ys, _ = plt.hist(all_logits, bins=128, density=True)\n", | ||
"plt.xlabel(f'{target_class} logit')\n", | ||
"plt.ylabel('density')\n", | ||
"# plt.yscale('log')\n", | ||
"plt.plot([target_logit, target_logit], [0.0, np.max(ys)], 'r:')\n", | ||
"plt.show()\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"id": "WCGcfWxiEL4m" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Display results for the target label. { vertical-output: true }\n", | ||
"\n", | ||
"display_labels = merged.labels\n", | ||
"\n", | ||
"extra_labels = [] #@param\n", | ||
"for label in extra_labels:\n", | ||
" if label not in merged.labels:\n", | ||
" display_labels += (label,)\n", | ||
"if 'unknown' not in merged.labels:\n", | ||
" display_labels += ('unknown',)\n", | ||
"\n", | ||
"display.display_search_results(\n", | ||
" results, embedding_model.sample_rate,\n", | ||
" project_state.source_map,\n", | ||
" checkbox_labels=display_labels,\n", | ||
" max_workers=5)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"id": "JR4jUZKvatdK" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Add selected results to the labeled data. { vertical-output: true }\n", | ||
"\n", | ||
"results.write_labeled_data(\n", | ||
" config.annotated_path, embedding_model.sample_rate)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"id": "WkpWsvQ9DGYl" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Write classifier inference CSV. { vertical-output: true }\n", | ||
"\n", | ||
"threshold = 1.0 #@param\n", | ||
"output_filepath = '/tmp/inference.csv' #@param\n", | ||
"\n", | ||
"# Create the embeddings dataset.\n", | ||
"embeddings_ds = tf_examples.create_embeddings_dataset(\n", | ||
" embeddings_path, file_glob='embeddings-*')\n", | ||
"\n", | ||
"def classify_batch(batch):\n", | ||
" \"\"\"Classify a batch of embeddings.\"\"\"\n", | ||
" emb = batch[tf_examples.EMBEDDING]\n", | ||
" emb_shape = tf.shape(emb)\n", | ||
" flat_emb = tf.reshape(emb, [-1, emb_shape[-1]])\n", | ||
" logits = model(flat_emb)\n", | ||
" logits = tf.reshape(\n", | ||
" logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]])\n", | ||
" # Take the maximum logit over channels.\n", | ||
" logits = tf.reduce_max(logits, axis=-2)\n", | ||
" batch['logits'] = logits\n", | ||
" return batch\n", | ||
"\n", | ||
"inference_ds = tf_examples.create_embeddings_dataset(\n", | ||
" embeddings_path, file_glob='embeddings-*')\n", | ||
"inference_ds = inference_ds.map(\n", | ||
" classify_batch, num_parallel_calls=tf.data.AUTOTUNE\n", | ||
")\n", | ||
"\n", | ||
"with open(output_filepath, 'w') as f:\n", | ||
" # Write column headers.\n", | ||
" headers = ['filename', 'timestamp_s', 'label', 'logit']\n", | ||
" f.write(', '.join(headers) + '\\n')\n", | ||
" for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):\n", | ||
" for t in range(ex['logits'].shape[0]):\n", | ||
" for i, label in enumerate(merged.class_names):\n", | ||
" if ex['logits'][t, i] > threshold:\n", | ||
" offset = ex['timestamp_s'] + t * config.embedding_hop_size_s\n", | ||
" logit = '{:.2f}'.format(ex['logits'][t, i])\n", | ||
" row = [ex['filename'].decode('utf-8'),\n", | ||
" '{:.2f}'.format(offset),\n", | ||
" label, logit]\n", | ||
" f.write(', '.join(row) + '\\n')\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"name": "active_learning.ipynb", | ||
"private_outputs": true, | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |