
Notebook updates:
* Split the single big notebook into embed_audio, agile_modeling (search+classifier building), and analysis (validation and call density estimation).
* Add paged results to the agile modeling notebook, helpful when dealing with very large result sets.
* Add validation and call density code.
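The paged-results support mentioned above comes down to two calls from `chirp.inference.search.display`: a `PageState` initialized with the page count, and `display_paged_results` to render one page of results at a time. A minimal usage sketch, mirroring the notebook cells in the diff below (`results` and `project_state` are produced by the notebook's search and bootstrap steps):

```python
# Minimal sketch of the paged-results flow, as used in the cells below.
import numpy as np
from chirp.inference.search import display

samples_per_page = 10
# PageState tracks the current page; initialize it with the page count.
page_state = display.PageState(
    np.ceil(len(results.search_results) / samples_per_page))

display.display_paged_results(
    results, page_state, samples_per_page,
    embedding_sample_rate=project_state.embedding_model.sample_rate,
    source_map=project_state.source_map,
    exclusive_labels=False,
    checkbox_labels=['my_label', 'unknown'],
    max_workers=5,
)
```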

PiperOrigin-RevId: 618826808
sdenton4 authored and copybara-github committed Mar 25, 2024
1 parent ce5befa commit e9a8f97
Showing 6 changed files with 823 additions and 238 deletions.
297 changes: 78 additions & 219 deletions agile_modeling.ipynb
@@ -8,7 +8,9 @@
"source": [
"# Agile Modeling for Bioacoustics.\n",
"\n",
"This notebook provides a single-machine workflow for using pre-trained models to embed raw audio files, search, and create classifiers for target signals. This notebook is ideal for a single machine with a GPU for accelarated embedding."
"This notebook provides a workflow for creating custom classifiers for target signals, by first **searching** for training data, and then engaging in an **active learning** loop.\n",
"\n",
"We assume that embeddings have been pre-computed using `embed.ipynb`."
]
},
{
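Since embedding now happens in `embed.ipynb`, this notebook starts from embeddings already on disk. A minimal sketch of how they are picked up, using the bootstrap helpers imported in the next cell (`working_dir` is the notebook's working-directory parameter, assumed here):

```python
# Minimal sketch: load the config written next to the precomputed
# embeddings, then build the project state used by search and display.
from etils import epath
from chirp.inference.search import bootstrap

embeddings_path = epath.Path(working_dir) / 'embeddings'  # as configured below
labeled_data_path = epath.Path(working_dir) / 'labeled'

bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(
    embeddings_path=embeddings_path,
    annotated_path=labeled_data_path)
project_state = bootstrap.BootstrapState(bootstrap_config)
```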
@@ -41,10 +43,12 @@
"colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n",
"\n",
"from chirp import audio_utils\n",
"from chirp.inference import interface\n",
"from chirp.inference import embed_lib\n",
"from chirp.inference import tf_examples\n",
"from chirp.inference import models\n",
"from chirp.models import metrics\n",
"from chirp.taxonomy import namespace\n",
"from chirp.inference.search import bootstrap\n",
"from chirp.inference.search import search\n",
"from chirp.inference.search import display\n",
@@ -70,6 +74,7 @@
"# Set the embedding and labeled data directories.\n",
"embeddings_path = epath.Path(working_dir) / 'embeddings'\n",
"labeled_data_path = epath.Path(working_dir) / 'labeled'\n",
"custom_classifier_path = epath.Path(working_dir) / 'custom_classifier'\n",
"embeddings_glob = embeddings_path / 'embeddings-*'\n",
"\n",
"# OPTIONAL: Set up separation model.\n",
@@ -117,171 +122,6 @@
" separator = None"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_XpWruWMArWo"
},
"source": [
"## Embed Audio"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "qx-SWjFYALok"
},
"outputs": [],
"source": [
"#@title Embedding Configuration. { vertical-output: true }\n",
"\n",
"config = config_dict.ConfigDict()\n",
"config.embed_fn_config = config_dict.ConfigDict()\n",
"config.embed_fn_config.model_config = config_dict.ConfigDict()\n",
"\n",
"# IMPORTANT: Select the targe audio files.\n",
"# source_file_patterns should contain a list of globs of audio files, like:\n",
"# ['/home/me/*.wav', '/home/me/other/*.flac']\n",
"config.source_file_patterns = [''] #@param\n",
"config.output_dir = embeddings_path.as_posix()\n",
"\n",
"# For Perch, set the perch_tfhub_model_version, and the model will load\n",
"# automagically from TFHub. Alternatively, set the model path for a local\n",
"# copy of the model.\n",
"# Note that only one of perch_model_path and perch_tfhub_version should be set.\n",
"perch_tfhub_version = 4 #@param\n",
"perch_model_path = '' #@param\n",
"\n",
"# For BirdNET, point to the specific tflite file.\n",
"birdnet_model_path = '' #@param\n",
"if model_choice == 'perch':\n",
" config.embed_fn_config.model_key = 'taxonomy_model_tf'\n",
" config.embed_fn_config.model_config.window_size_s = 5.0\n",
" config.embed_fn_config.model_config.hop_size_s = 5.0\n",
" config.embed_fn_config.model_config.sample_rate = 32000\n",
" config.embed_fn_config.model_config.tfhub_version = perch_tfhub_version\n",
" config.embed_fn_config.model_config.model_path = perch_model_path\n",
"elif model_choice == 'birdnet':\n",
" config.embed_fn_config.model_key = 'birdnet'\n",
" config.embed_fn_config.model_config.window_size_s = 3.0\n",
" config.embed_fn_config.model_config.hop_size_s = 3.0\n",
" config.embed_fn_config.model_config.sample_rate = 48000\n",
" config.embed_fn_config.model_config.model_path = birdnet_model_path\n",
" # Note: The v2_1 class list is appropriate for Birdnet 2.1, 2.2, and 2.3.\n",
" config.embed_fn_config.model_config.class_list_name = 'birdnet_v2_1'\n",
" config.embed_fn_config.model_config.num_tflite_threads = 4\n",
"\n",
"# Only write embeddings to reduce size.\n",
"config.embed_fn_config.write_embeddings = True\n",
"config.embed_fn_config.write_logits = False\n",
"config.embed_fn_config.write_separated_audio = False\n",
"config.embed_fn_config.write_raw_audio = False\n",
"\n",
"# Number of parent directories to include in the filename.\n",
"config.embed_fn_config.file_id_depth = 1"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "jb-pEadVDidv"
},
"outputs": [],
"source": [
"#@title Set up. { vertical-output: true }\n",
"\n",
"# Set up the embedding function, including loading models.\n",
"embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)\n",
"print('\\n\\nLoading model(s)...')\n",
"embed_fn.setup()\n",
"\n",
"# Create output directory and write the configuration.\n",
"output_dir = epath.Path(config.output_dir)\n",
"output_dir.mkdir(exist_ok=True, parents=True)\n",
"embed_lib.maybe_write_config(config, output_dir)\n",
"\n",
"# Create SourceInfos.\n",
"source_infos = embed_lib.create_source_infos(\n",
" config.source_file_patterns,\n",
" num_shards_per_file=config.get('num_shards_per_file', -1),\n",
" shard_len_s=config.get('shard_len_s', -1))\n",
"print(f'Found {len(source_infos)} source infos.')\n",
"\n",
"print('\\n\\nTest-run of model...')\n",
"window_size_s = config.embed_fn_config.model_config.window_size_s\n",
"sr = config.embed_fn_config.model_config.sample_rate\n",
"z = np.zeros([int(sr * window_size_s)])\n",
"embed_fn.embedding_model.embed(z)\n",
"print('Setup complete!')"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "Dvnwf_LZDkBf"
},
"outputs": [],
"source": [
"#@title Run embedding. { vertical-output: true }\n",
"\n",
"# Uses multiple threads to load audio before embedding.\n",
"# This tends to be faster, but can fail if any audio files are corrupt.\n",
"\n",
"embed_fn.min_audio_s = 1.0\n",
"record_file = (output_dir / 'embeddings.tfrecord').as_posix()\n",
"succ, fail = 0, 0\n",
"\n",
"existing_embedding_ids = embed_lib.get_existing_source_ids(\n",
" output_dir, 'embeddings-*')\n",
"\n",
"new_source_infos = embed_lib.get_new_source_infos(\n",
" source_infos, existing_embedding_ids, config.embed_fn_config.file_id_depth)\n",
"\n",
"print(f'Found {len(new_source_infos)} existing embedding ids.'\n",
" f'Processing {len(new_source_infos)} new source infos. ')\n",
"\n",
"audio_iterator = audio_utils.multi_load_audio_window(\n",
" filepaths=[s.filepath for s in new_source_infos],\n",
" offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],\n",
" sample_rate=config.embed_fn_config.model_config.sample_rate,\n",
" window_size_s=config.get('shard_len_s', -1.0),\n",
")\n",
"with tf_examples.EmbeddingsTFRecordMultiWriter(\n",
" output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:\n",
" for source_info, audio in tqdm.tqdm(\n",
" zip(new_source_infos, audio_iterator), total=len(new_source_infos)):\n",
" file_id = source_info.file_id(config.embed_fn_config.file_id_depth)\n",
" offset_s = source_info.shard_num * source_info.shard_len_s\n",
" example = embed_fn.audio_to_example(file_id, offset_s, audio)\n",
" if example is None:\n",
" fail += 1\n",
" continue\n",
" file_writer.write(example.SerializeToString())\n",
" succ += 1\n",
" file_writer.flush()\n",
"print(f'\\n\\nSuccessfully processed {succ} source_infos, failed {fail} times.')\n",
"\n",
"fns = [fn for fn in output_dir.glob('embeddings-*')]\n",
"ds = tf.data.TFRecordDataset(fns)\n",
"parser = tf_examples.get_example_parser()\n",
"ds = ds.map(parser)\n",
"for ex in ds.as_numpy_iterator():\n",
" print(ex['filename'])\n",
" print(ex['embedding'].shape, flush=True)\n",
" break\n",
"\n",
"# Load/refresh bootstrap_config for subsequent steps.\n",
"print('\\nRefreshing bootstrap_config.', flush=True)\n",
"bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(\n",
" embeddings_path=embeddings_path,\n",
" annotated_path=labeled_data_path)\n",
"\n",
"project_state = bootstrap.BootstrapState(bootstrap_config)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -369,7 +209,7 @@
"source": [
"#@title Select the query channel. { vertical-output: true }\n",
"\n",
"query_label = 'some_audio' #@param\n",
"query_label = 'my_label' #@param\n",
"query_channel = 0 #@param\n",
"\n",
"if query_channel < 0 or sep_outputs is None:\n",
@@ -451,10 +291,18 @@
"source": [
"#@title Display results. { vertical-output: true }\n",
"\n",
"display.display_search_results(\n",
" results, sample_rate, project_state.source_map,\n",
"samples_per_page = 10\n",
"page_state = display.PageState(\n",
" np.ceil(len(results.search_results) / samples_per_page))\n",
"\n",
"display.display_paged_results(\n",
" results, page_state, samples_per_page,\n",
" embedding_sample_rate=project_state.embedding_model.sample_rate,\n",
" source_map=project_state.source_map,\n",
" exclusive_labels=False,\n",
" checkbox_labels=[query_label, 'unknown'],\n",
" max_workers=5)"
" max_workers=5,\n",
")"
]
},
{
@@ -522,11 +370,11 @@
"\n",
"# Number of random training examples to choose form each class.\n",
"# Set exactly one of train_ratio and train_examples_per_class\n",
"train_ratio = None #@param\n",
"train_examples_per_class = 2 #@param\n",
"train_ratio = 0.9 #@param\n",
"train_examples_per_class = None #@param\n",
"\n",
"# Number of random re-trainings. Allows judging model stability.\n",
"num_seeds = 1 #@param\n",
"num_seeds = 8 #@param\n",
"\n",
"# Classifier training hyperparams.\n",
"# These should be good defaults.\n",
@@ -581,7 +429,7 @@
"#@title Run model on target unlabeled data. { vertical-output: true }\n",
"\n",
"# Choose the target class to work with.\n",
"target_class = 'some_audio' #@param\n",
"target_class = 'my_class' #@param\n",
"# Choose a target logit; will display results close to the target.\n",
"# Set to None to get the highest-logit examples.\n",
"target_logit = None #@param\n",
@@ -629,11 +477,17 @@
"if 'unknown' not in merged.labels:\n",
" display_labels += ('unknown',)\n",
"\n",
"display.display_search_results(\n",
" results, project_state.embedding_model.sample_rate,\n",
" project_state.source_map,\n",
"samples_per_page = 10\n",
"page_state = display.PageState(\n",
" np.ceil(len(results.search_results) / samples_per_page))\n",
"\n",
"display.display_paged_results(\n",
" results, page_state, samples_per_page,\n",
" embedding_sample_rate=project_state.embedding_model.sample_rate,\n",
" source_map=project_state.source_map,\n",
" exclusive_labels=False,\n",
" checkbox_labels=display_labels,\n",
" max_workers=5)"
")"
]
},
{
@@ -651,6 +505,27 @@
" project_state.embedding_model.sample_rate)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "ZxasEcnhd7kP"
},
"outputs": [],
"source": [
"#@title Save the Custom Classifier. { vertical-output: true }\n",
"\n",
"wrapped_model = interface.LogitsOutputHead(\n",
" model_path=custom_classifier_path.as_posix(),\n",
" logits_key='logits',\n",
" logits_model=model,\n",
" class_list=namespace.ClassList('custom', merged.labels),\n",
")\n",
"wrapped_model.save_model(\n",
" custom_classifier_path,\n",
" embeddings_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -670,54 +545,38 @@
"source": [
"#@title Write classifier inference CSV. { vertical-output: true }\n",
"\n",
"threshold = 1.0 #@param\n",
"output_filepath = '/tmp/inference.csv' #@param\n",
"\n",
"# Set detection thresholds.\n",
"default_threshold = 0.0 #@param\n",
"if default_threshold is None:\n",
" # In this case, all logits are written. This can lead to very large CSV files.\n",
" class_thresholds = None\n",
"else:\n",
" class_thresholds = collections.defaultdict(lambda: default_threshold)\n",
" # Set per-class thresholds here.\n",
" class_thresholds['my_class'] = 1.0\n",
"\n",
"exclude_classes = ['unknown'] #@param\n",
"\n",
"# include_classes is ignored if empty.\n",
"# If non-empty, only scores for these classes will be written.\n",
"include_classes = [] #@param\n",
"\n",
"# Create the embeddings dataset.\n",
"embeddings_ds = tf_examples.create_embeddings_dataset(\n",
" embeddings_path, file_glob='embeddings-*')\n",
"\n",
"def classify_batch(batch):\n",
" \"\"\"Classify a batch of embeddings.\"\"\"\n",
" emb = batch[tf_examples.EMBEDDING]\n",
" emb_shape = tf.shape(emb)\n",
" flat_emb = tf.reshape(emb, [-1, emb_shape[-1]])\n",
" logits = model(flat_emb)\n",
" logits = tf.reshape(\n",
" logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]])\n",
" # Take the maximum logit over channels.\n",
" logits = tf.reduce_max(logits, axis=-2)\n",
" batch['logits'] = logits\n",
" return batch\n",
"\n",
"inference_ds = tf_examples.create_embeddings_dataset(\n",
" embeddings_path, file_glob='embeddings-*')\n",
"inference_ds = inference_ds.map(\n",
" classify_batch, num_parallel_calls=tf.data.AUTOTUNE\n",
")\n",
"\n",
"with open(output_filepath, 'w') as f:\n",
" # Write column headers.\n",
" headers = ['filename', 'timestamp_s', 'label', 'logit']\n",
" f.write(', '.join(headers) + '\\n')\n",
" for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):\n",
" for t in range(ex['logits'].shape[0]):\n",
" for i, label in enumerate(merged.labels):\n",
" if ex['logits'][t, i] > threshold:\n",
" offset = ex['timestamp_s'] + t * bootstrap_config.embedding_hop_size_s\n",
" logit = '{:.2f}'.format(ex['logits'][t, i])\n",
" row = [ex['filename'].decode('utf-8'),\n",
" '{:.2f}'.format(offset),\n",
" label, logit]\n",
" f.write(', '.join(row) + '\\n')\n"
"classify.write_inference_csv(\n",
" embeddings_ds=embeddings_ds,\n",
" model=model,\n",
" labels=merged.labels,\n",
" output_filepath=output_filepath,\n",
" threshold=class_thresholds,\n",
" embedding_hop_size_s=bootstrap_config.embedding_hop_size_s,\n",
" include_classes=include_classes,\n",
" exclude_classes=exclude_classes)\n"
]
},
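The inference CSV can get large, so a quick sanity check is to load it and summarize detections per label. A sketch assuming the column layout of the previous inline writer (`filename, timestamp_s, label, logit`); whether `classify.write_inference_csv` emits exactly these columns is an assumption here:

```python
# Sketch: summarize the inference CSV (column names assumed to match the
# old inline writer shown in this diff: filename, timestamp_s, label, logit).
import pandas as pd

df = pd.read_csv('/tmp/inference.csv', skipinitialspace=True)
print(df.groupby('label').size())   # detections per class
print(df.nlargest(10, 'logit'))     # strongest detections
```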
{
"cell_type": "markdown",
"metadata": {
"id": "HSqxSk74EIgs"
},
"source": []
}
],
"metadata": {
(Diff for the remaining 5 changed files not shown.)
