From 051debfc427cd3276cc8055ca4d46abd5bd9ac01 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Thu, 21 Mar 2024 12:01:08 -0700 Subject: [PATCH] Notebook updates: * Split the single big notebook into embed_audio, agile_modeling (search+classifier building), and analysis (validation and call density estimation). * Add paged results in agile modeling notebook, which is helpful when dealing with very large result sets. * Add validation and call density code. PiperOrigin-RevId: 617919728 --- agile_modeling.ipynb | 297 +++++++---------------- analysis.ipynb | 338 +++++++++++++++++++++++++++ chirp/inference/classify/classify.py | 71 ++++++ chirp/inference/search/display.py | 112 +++++++-- chirp/inference/search/search.py | 5 + embed_audio.ipynb | 238 +++++++++++++++++++ 6 files changed, 823 insertions(+), 238 deletions(-) create mode 100644 analysis.ipynb create mode 100644 embed_audio.ipynb diff --git a/agile_modeling.ipynb b/agile_modeling.ipynb index 30fe6e6c..da906280 100644 --- a/agile_modeling.ipynb +++ b/agile_modeling.ipynb @@ -8,7 +8,9 @@ "source": [ "# Agile Modeling for Bioacoustics.\n", "\n", - "This notebook provides a single-machine workflow for using pre-trained models to embed raw audio files, search, and create classifiers for target signals. This notebook is ideal for a single machine with a GPU for accelarated embedding." + "This notebook provides a workflow for creating custom classifiers for target signals, by first **searching** for training data, and then engaging in an **active learning** loop.\n", + "\n", + "We assume that embeddings have been pre-computed using `embed.ipynb`." ] }, { @@ -41,10 +43,12 @@ "colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n", "\n", "from chirp import audio_utils\n", + "from chirp.inference import interface\n", "from chirp.inference import embed_lib\n", "from chirp.inference import tf_examples\n", "from chirp.inference import models\n", "from chirp.models import metrics\n", + "from chirp.taxonomy import namespace\n", "from chirp.inference.search import bootstrap\n", "from chirp.inference.search import search\n", "from chirp.inference.search import display\n", @@ -70,6 +74,7 @@ "# Set the embedding and labeled data directories.\n", "embeddings_path = epath.Path(working_dir) / 'embeddings'\n", "labeled_data_path = epath.Path(working_dir) / 'labeled'\n", + "custom_classifier_path = epath.Path(working_dir) / 'custom_classifier'\n", "embeddings_glob = embeddings_path / 'embeddings-*'\n", "\n", "# OPTIONAL: Set up separation model.\n", @@ -117,171 +122,6 @@ " separator = None" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "_XpWruWMArWo" - }, - "source": [ - "## Embed Audio" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "id": "qx-SWjFYALok" - }, - "outputs": [], - "source": [ - "#@title Embedding Configuration. { vertical-output: true }\n", - "\n", - "config = config_dict.ConfigDict()\n", - "config.embed_fn_config = config_dict.ConfigDict()\n", - "config.embed_fn_config.model_config = config_dict.ConfigDict()\n", - "\n", - "# IMPORTANT: Select the targe audio files.\n", - "# source_file_patterns should contain a list of globs of audio files, like:\n", - "# ['/home/me/*.wav', '/home/me/other/*.flac']\n", - "config.source_file_patterns = [''] #@param\n", - "config.output_dir = embeddings_path.as_posix()\n", - "\n", - "# For Perch, set the perch_tfhub_model_version, and the model will load\n", - "# automagically from TFHub. 
Alternatively, set the model path for a local\n", - "# copy of the model.\n", - "# Note that only one of perch_model_path and perch_tfhub_version should be set.\n", - "perch_tfhub_version = 4 #@param\n", - "perch_model_path = '' #@param\n", - "\n", - "# For BirdNET, point to the specific tflite file.\n", - "birdnet_model_path = '' #@param\n", - "if model_choice == 'perch':\n", - " config.embed_fn_config.model_key = 'taxonomy_model_tf'\n", - " config.embed_fn_config.model_config.window_size_s = 5.0\n", - " config.embed_fn_config.model_config.hop_size_s = 5.0\n", - " config.embed_fn_config.model_config.sample_rate = 32000\n", - " config.embed_fn_config.model_config.tfhub_version = perch_tfhub_version\n", - " config.embed_fn_config.model_config.model_path = perch_model_path\n", - "elif model_choice == 'birdnet':\n", - " config.embed_fn_config.model_key = 'birdnet'\n", - " config.embed_fn_config.model_config.window_size_s = 3.0\n", - " config.embed_fn_config.model_config.hop_size_s = 3.0\n", - " config.embed_fn_config.model_config.sample_rate = 48000\n", - " config.embed_fn_config.model_config.model_path = birdnet_model_path\n", - " # Note: The v2_1 class list is appropriate for Birdnet 2.1, 2.2, and 2.3.\n", - " config.embed_fn_config.model_config.class_list_name = 'birdnet_v2_1'\n", - " config.embed_fn_config.model_config.num_tflite_threads = 4\n", - "\n", - "# Only write embeddings to reduce size.\n", - "config.embed_fn_config.write_embeddings = True\n", - "config.embed_fn_config.write_logits = False\n", - "config.embed_fn_config.write_separated_audio = False\n", - "config.embed_fn_config.write_raw_audio = False\n", - "\n", - "# Number of parent directories to include in the filename.\n", - "config.embed_fn_config.file_id_depth = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "id": "jb-pEadVDidv" - }, - "outputs": [], - "source": [ - "#@title Set up. { vertical-output: true }\n", - "\n", - "# Set up the embedding function, including loading models.\n", - "embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)\n", - "print('\\n\\nLoading model(s)...')\n", - "embed_fn.setup()\n", - "\n", - "# Create output directory and write the configuration.\n", - "output_dir = epath.Path(config.output_dir)\n", - "output_dir.mkdir(exist_ok=True, parents=True)\n", - "embed_lib.maybe_write_config(config, output_dir)\n", - "\n", - "# Create SourceInfos.\n", - "source_infos = embed_lib.create_source_infos(\n", - " config.source_file_patterns,\n", - " num_shards_per_file=config.get('num_shards_per_file', -1),\n", - " shard_len_s=config.get('shard_len_s', -1))\n", - "print(f'Found {len(source_infos)} source infos.')\n", - "\n", - "print('\\n\\nTest-run of model...')\n", - "window_size_s = config.embed_fn_config.model_config.window_size_s\n", - "sr = config.embed_fn_config.model_config.sample_rate\n", - "z = np.zeros([int(sr * window_size_s)])\n", - "embed_fn.embedding_model.embed(z)\n", - "print('Setup complete!')" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "id": "Dvnwf_LZDkBf" - }, - "outputs": [], - "source": [ - "#@title Run embedding. 
{ vertical-output: true }\n", - "\n", - "# Uses multiple threads to load audio before embedding.\n", - "# This tends to be faster, but can fail if any audio files are corrupt.\n", - "\n", - "embed_fn.min_audio_s = 1.0\n", - "record_file = (output_dir / 'embeddings.tfrecord').as_posix()\n", - "succ, fail = 0, 0\n", - "\n", - "existing_embedding_ids = embed_lib.get_existing_source_ids(\n", - " output_dir, 'embeddings-*')\n", - "\n", - "new_source_infos = embed_lib.get_new_source_infos(\n", - " source_infos, existing_embedding_ids, config.embed_fn_config.file_id_depth)\n", - "\n", - "print(f'Found {len(new_source_infos)} existing embedding ids.'\n", - " f'Processing {len(new_source_infos)} new source infos. ')\n", - "\n", - "audio_iterator = audio_utils.multi_load_audio_window(\n", - " filepaths=[s.filepath for s in new_source_infos],\n", - " offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],\n", - " sample_rate=config.embed_fn_config.model_config.sample_rate,\n", - " window_size_s=config.get('shard_len_s', -1.0),\n", - ")\n", - "with tf_examples.EmbeddingsTFRecordMultiWriter(\n", - " output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:\n", - " for source_info, audio in tqdm.tqdm(\n", - " zip(new_source_infos, audio_iterator), total=len(new_source_infos)):\n", - " file_id = source_info.file_id(config.embed_fn_config.file_id_depth)\n", - " offset_s = source_info.shard_num * source_info.shard_len_s\n", - " example = embed_fn.audio_to_example(file_id, offset_s, audio)\n", - " if example is None:\n", - " fail += 1\n", - " continue\n", - " file_writer.write(example.SerializeToString())\n", - " succ += 1\n", - " file_writer.flush()\n", - "print(f'\\n\\nSuccessfully processed {succ} source_infos, failed {fail} times.')\n", - "\n", - "fns = [fn for fn in output_dir.glob('embeddings-*')]\n", - "ds = tf.data.TFRecordDataset(fns)\n", - "parser = tf_examples.get_example_parser()\n", - "ds = ds.map(parser)\n", - "for ex in ds.as_numpy_iterator():\n", - " print(ex['filename'])\n", - " print(ex['embedding'].shape, flush=True)\n", - " break\n", - "\n", - "# Load/refresh bootstrap_config for subsequent steps.\n", - "print('\\nRefreshing bootstrap_config.', flush=True)\n", - "bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(\n", - " embeddings_path=embeddings_path,\n", - " annotated_path=labeled_data_path)\n", - "\n", - "project_state = bootstrap.BootstrapState(bootstrap_config)\n" - ] - }, { "cell_type": "markdown", "metadata": { @@ -369,7 +209,7 @@ "source": [ "#@title Select the query channel. { vertical-output: true }\n", "\n", - "query_label = 'some_audio' #@param\n", + "query_label = 'my_label' #@param\n", "query_channel = 0 #@param\n", "\n", "if query_channel < 0 or sep_outputs is None:\n", @@ -451,10 +291,18 @@ "source": [ "#@title Display results. 
{ vertical-output: true }\n", "\n", - "display.display_search_results(\n", - " results, sample_rate, project_state.source_map,\n", + "samples_per_page = 10\n", + "page_state = display.PageState(\n", + " np.ceil(len(results.search_results) / samples_per_page))\n", + "\n", + "display.display_paged_results(\n", + " results, page_state, samples_per_page,\n", + " embedding_sample_rate=project_state.embedding_model.sample_rate,\n", + " source_map=project_state.source_map,\n", + " exclusive_labels=False,\n", " checkbox_labels=[query_label, 'unknown'],\n", - " max_workers=5)" + " max_workers=5,\n", + ")" ] }, { @@ -522,11 +370,11 @@ "\n", "# Number of random training examples to choose form each class.\n", "# Set exactly one of train_ratio and train_examples_per_class\n", - "train_ratio = None #@param\n", - "train_examples_per_class = 2 #@param\n", + "train_ratio = 0.9 #@param\n", + "train_examples_per_class = None #@param\n", "\n", "# Number of random re-trainings. Allows judging model stability.\n", - "num_seeds = 1 #@param\n", + "num_seeds = 8 #@param\n", "\n", "# Classifier training hyperparams.\n", "# These should be good defaults.\n", @@ -581,7 +429,7 @@ "#@title Run model on target unlabeled data. { vertical-output: true }\n", "\n", "# Choose the target class to work with.\n", - "target_class = 'some_audio' #@param\n", + "target_class = 'my_class' #@param\n", "# Choose a target logit; will display results close to the target.\n", "# Set to None to get the highest-logit examples.\n", "target_logit = None #@param\n", @@ -629,11 +477,17 @@ "if 'unknown' not in merged.labels:\n", " display_labels += ('unknown',)\n", "\n", - "display.display_search_results(\n", - " results, project_state.embedding_model.sample_rate,\n", - " project_state.source_map,\n", + "samples_per_page = 10\n", + "page_state = display.PageState(\n", + " np.ceil(len(results.search_results) / samples_per_page))\n", + "\n", + "display.display_paged_results(\n", + " results, page_state, samples_per_page,\n", + " embedding_sample_rate=project_state.embedding_model.sample_rate,\n", + " source_map=project_state.source_map,\n", + " exclusive_labels=False,\n", " checkbox_labels=display_labels,\n", - " max_workers=5)" + ")" ] }, { @@ -651,6 +505,27 @@ " project_state.embedding_model.sample_rate)" ] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "ZxasEcnhd7kP" + }, + "outputs": [], + "source": [ + "#@title Save the Custom Classifier. { vertical-output: true }\n", + "\n", + "wrapped_model = interface.LogitsOutputHead(\n", + " model_path=custom_classifier_path.as_posix(),\n", + " logits_key='logits',\n", + " logits_model=model,\n", + " class_list=namespace.ClassList('custom', merged.labels),\n", + ")\n", + "wrapped_model.save_model(\n", + " custom_classifier_path,\n", + " embeddings_path)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -670,54 +545,38 @@ "source": [ "#@title Write classifier inference CSV. { vertical-output: true }\n", "\n", - "threshold = 1.0 #@param\n", "output_filepath = '/tmp/inference.csv' #@param\n", "\n", + "# Set detection thresholds.\n", + "default_threshold = 0.0 #@param\n", + "if default_threshold is None:\n", + " # In this case, all logits are written. 
This can lead to very large CSV files.\n", + " class_thresholds = None\n", + "else:\n", + " class_thresholds = collections.defaultdict(lambda: default_threshold)\n", + " # Set per-class thresholds here.\n", + " class_thresholds['my_class'] = 1.0\n", + "\n", + "exclude_classes = ['unknown'] #@param\n", + "\n", + "# include_classes is ignored if empty.\n", + "# If non-empty, only scores for these classes will be written.\n", + "include_classes = [] #@param\n", + "\n", "# Create the embeddings dataset.\n", "embeddings_ds = tf_examples.create_embeddings_dataset(\n", " embeddings_path, file_glob='embeddings-*')\n", "\n", - "def classify_batch(batch):\n", - " \"\"\"Classify a batch of embeddings.\"\"\"\n", - " emb = batch[tf_examples.EMBEDDING]\n", - " emb_shape = tf.shape(emb)\n", - " flat_emb = tf.reshape(emb, [-1, emb_shape[-1]])\n", - " logits = model(flat_emb)\n", - " logits = tf.reshape(\n", - " logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]])\n", - " # Take the maximum logit over channels.\n", - " logits = tf.reduce_max(logits, axis=-2)\n", - " batch['logits'] = logits\n", - " return batch\n", - "\n", - "inference_ds = tf_examples.create_embeddings_dataset(\n", - " embeddings_path, file_glob='embeddings-*')\n", - "inference_ds = inference_ds.map(\n", - " classify_batch, num_parallel_calls=tf.data.AUTOTUNE\n", - ")\n", - "\n", - "with open(output_filepath, 'w') as f:\n", - " # Write column headers.\n", - " headers = ['filename', 'timestamp_s', 'label', 'logit']\n", - " f.write(', '.join(headers) + '\\n')\n", - " for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):\n", - " for t in range(ex['logits'].shape[0]):\n", - " for i, label in enumerate(merged.labels):\n", - " if ex['logits'][t, i] > threshold:\n", - " offset = ex['timestamp_s'] + t * bootstrap_config.embedding_hop_size_s\n", - " logit = '{:.2f}'.format(ex['logits'][t, i])\n", - " row = [ex['filename'].decode('utf-8'),\n", - " '{:.2f}'.format(offset),\n", - " label, logit]\n", - " f.write(', '.join(row) + '\\n')\n" + "classify.write_inference_csv(\n", + " embeddings_ds=embeddings_ds,\n", + " model=model,\n", + " labels=merged.labels,\n", + " output_filepath=output_filepath,\n", + " threshold=class_thresholds,\n", + " embedding_hop_size_s=bootstrap_config.embedding_hop_size_s,\n", + " include_classes=include_classes,\n", + " exclude_classes=exclude_classes)\n" ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HSqxSk74EIgs" - }, - "source": [] } ], "metadata": { diff --git a/analysis.ipynb b/analysis.ipynb new file mode 100644 index 00000000..3404ec27 --- /dev/null +++ b/analysis.ipynb @@ -0,0 +1,338 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ndV0dmyzhpHE" + }, + "source": [ + "# Analysis of Bioacoustic Data\n", + "\n", + "This notebook provides tools for analyzing data using a custom classifier (developed with `agile_modeling.ipynb`)." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "i984ftjPcxDu" + }, + "outputs": [], + "source": [ + "#@title Imports. 
{ vertical-output: true }\n", + "\n", + "import collections\n", + "from etils import epath\n", + "from ml_collections import config_dict\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "from chirp.inference import colab_utils\n", + "colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n", + "\n", + "from chirp.inference import interface\n", + "from chirp.inference import tf_examples\n", + "from chirp.inference.search import bootstrap\n", + "from chirp.inference.search import search\n", + "from chirp.inference.search import display\n", + "from chirp.inference.classify import classify\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "TRETHuu1h7uZ" + }, + "outputs": [], + "source": [ + "#@title Basic Configuration. { vertical-output: true }\n", + "\n", + "# Define the model: Usually perch or birdnet.\n", + "model_choice = 'perch' #@param\n", + "# Set the base directory for the project.\n", + "working_dir = '/tmp/agile' #@param\n", + "\n", + "# Set the embedding and labeled data directories.\n", + "embeddings_path = epath.Path(working_dir) / 'embeddings'\n", + "labeled_data_path = epath.Path(working_dir) / 'labeled'\n", + "custom_classifier_path = epath.Path(working_dir) / 'custom_classifier'\n", + "embeddings_glob = embeddings_path / 'embeddings-*'\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "ake6Xk_Hh-nN" + }, + "outputs": [], + "source": [ + "#@title Load Existing Project State and Models. { vertical-output: true }\n", + "\n", + "# If you have already computed embeddings, run this cell to load models\n", + "# and find existing data.\n", + "\n", + "if (embeddings_path / 'config.json').exists():\n", + " # Get relevant info from the embedding configuration.\n", + " bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(\n", + " embeddings_path=embeddings_path,\n", + " annotated_path=labeled_data_path)\n", + " project_state = bootstrap.BootstrapState(bootstrap_config)\n", + "\n", + "cfg = config_dict.ConfigDict({\n", + " 'model_path': custom_classifier_path,\n", + " 'logits_key': 'custom',\n", + "})\n", + "loaded_model = interface.LogitsOutputHead.from_config(cfg)\n", + "model = loaded_model.logits_model\n", + "class_list = loaded_model.class_list\n", + "print('Loaded custom model with classes: ')\n", + "print('\\t' + '\\n\\t'.join(class_list.classes))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "Ur03VoLyuBHR" + }, + "outputs": [], + "source": [ + "#@title Write classifier inference CSV. { vertical-output: true }\n", + "\n", + "output_filepath = '/tmp/inference.csv' #@param\n", + "\n", + "# Set detection thresholds.\n", + "default_threshold = 0.0 #@param\n", + "if default_threshold is None:\n", + " # In this case, all logits are written. 
This can lead to very large CSV files.\n", + " class_thresholds = None\n", + "else:\n", + " class_thresholds = collections.defaultdict(lambda: default_threshold)\n", + " # Set per-class thresholds here.\n", + " class_thresholds['my_class'] = 1.0\n", + "\n", + "# Classes for which we do not want to write detections.\n", + "exclude_classes = ['unknown'] #@param\n", + "\n", + "# include_classes is ignored if empty.\n", + "# If non-empty, only scores for these classes will be written.\n", + "include_classes = [] #@param\n", + "\n", + "# Create the embeddings dataset.\n", + "embeddings_ds = tf_examples.create_embeddings_dataset(\n", + " embeddings_path, file_glob='embeddings-*')\n", + "\n", + "classify.write_inference_csv(\n", + " embeddings_ds=embeddings_ds,\n", + " model=model,\n", + " labels=class_list.classes,\n", + " output_filepath=output_filepath,\n", + " threshold=class_thresholds,\n", + " embedding_hop_size_s=bootstrap_config.embedding_hop_size_s,\n", + " include_classes=include_classes,\n", + " exclude_classes=exclude_classes)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GdJmpn0XzMj6" + }, + "source": [ + "## Call Density Estimation\n", + "\n", + "See 'All Thresholds Barred': https://arxiv.org/abs/2402.15360" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "-lhqypjsu2L9" + }, + "outputs": [], + "source": [ + "#@title Validation and Call Density. { vertical-output: true }\n", + "# For validation, we select random samples from logarithmic-quantile bins.\n", + "\n", + "target_class = 'my_class' #@param\n", + "\n", + "num_bins = 4 #@param\n", + "samples_per_bin = 50 #@param\n", + "# The highest bin contains 2**-num_bins of the data.\n", + "top_k = samples_per_bin * 2**(num_bins + 1)\n", + "\n", + "embeddings_ds = tf_examples.create_embeddings_dataset(\n", + " embeddings_path, file_glob='embeddings-*')\n", + "results, all_logits = search.classifer_search_embeddings_parallel(\n", + " embeddings_classifier=model,\n", + " target_index=class_list.classes.index(target_class),\n", + " random_sample=True,\n", + " top_k=top_k,\n", + " hop_size_s=bootstrap_config.embedding_hop_size_s,\n", + " embeddings_dataset=embeddings_ds,\n", + ")\n", + "\n", + "# Pick samples_per_bin examples from each quantile.\n", + "def get_quantile_bounds(n_bins):\n", + " lowers = [1.0 - 1.0 / 2**(k + 1) for k in range(n_bins - 1)]\n", + " return np.array([0.0] + lowers + [1.0])\n", + "\n", + "bounds = get_quantile_bounds(num_bins)\n", + "q_bounds = np.quantile(all_logits, bounds)\n", + "binned = [[] for _ in range(num_bins)]\n", + "for r in results.search_results:\n", + " bin = np.argmax(r.score < q_bounds) - 1\n", + " binned[bin].append(r)\n", + "binned = [np.random.choice(b, samples_per_bin) for b in binned]\n", + "\n", + "combined = []\n", + "for b in binned:\n", + " combined.extend(b)\n", + "np.random.shuffle(combined)\n", + "\n", + "samples_per_page = 10\n", + "page_state = display.PageState(np.ceil(len(combined) / samples_per_page))\n", + "\n", + "display.display_paged_results(\n", + " search.TopKSearchResults(combined, len(combined)),\n", + " page_state, samples_per_page,\n", + " embedding_sample_rate=project_state.embedding_model.sample_rate,\n", + " source_map=project_state.source_map,\n", + " exclusive_labels=True,\n", + " checkbox_labels=[target_class, f'not {target_class}', 'unsure'],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "A30usazfu8h8" + }, + "outputs": [], + "source": [ + "#@title Collate results 
and write validation log. { vertical-output: true }\n",
+    "\n",
+    "validation_log_filepath = epath.Path(working_dir) / 'validation.csv'\n",
+    "\n",
+    "filenames = []\n",
+    "timestamp_offsets = []\n",
+    "scores = []\n",
+    "is_pos = []\n",
+    "\n",
+    "for r in combined:\n",
+    "  if not r.label_widgets: continue\n",
+    "  value = r.label_widgets[0].value\n",
+    "  if value is None:\n",
+    "    continue\n",
+    "  filenames.append(r.filename)\n",
+    "  scores.append(r.score)\n",
+    "  timestamp_offsets.append(r.timestamp_offset)\n",
+    "  if value == target_class:\n",
+    "    is_pos.append(1)\n",
+    "  elif value == f'not {target_class}':\n",
+    "    is_pos.append(-1)\n",
+    "  elif value == 'unsure':\n",
+    "    is_pos.append(0)\n",
+    "\n",
+    "label = [target_class for _ in range(len(filenames))]\n",
+    "log = pd.DataFrame({\n",
+    "    'filenames': filenames,\n",
+    "    'timestamp_offsets': timestamp_offsets,\n",
+    "    'scores': scores,\n",
+    "    'label': label,\n",
+    "    'is_pos': is_pos})\n",
+    "log.to_csv(validation_log_filepath, mode='a')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "id": "uZHVzPttwGZ2"
+   },
+   "outputs": [],
+   "source": [
+    "#@title Estimate Model Quality and Call Density. { vertical-output: true }\n",
+    "\n",
+    "import scipy.stats\n",
+    "\n",
+    "# Collect validated labels by bin.\n",
+    "bin_pos = [0 for i in range(num_bins)]\n",
+    "bin_neg = [0 for i in range(num_bins)]\n",
+    "for score, pos in zip(scores, is_pos):\n",
+    "  bin = np.argmax(score < q_bounds) - 1\n",
+    "  if pos == 1:\n",
+    "    bin_pos[bin] += 1\n",
+    "  elif pos == -1:\n",
+    "    bin_neg[bin] += 1\n",
+    "\n",
+    "# Create beta distributions.\n",
+    "prior = 0.1\n",
+    "betas = [scipy.stats.beta(p + prior, n + prior)\n",
+    "         for p, n in zip(bin_pos, bin_neg)]\n",
+    "# MLE positive rate in each bin.\n",
+    "mle_b = np.array([bin_pos[b] / (bin_pos[b] + bin_neg[b] + 1e-6)\n",
+    "                  for b in range(num_bins)])\n",
+    "# Probability of each bin, P(b).\n",
+    "p_b = np.array([2**-k for k in range(1, num_bins)] + [2**(-num_bins + 1)])\n",
+    "\n",
+    "# MLE total call density.\n",
+    "q_mle = np.dot(mle_b, p_b)\n",
+    "\n",
+    "num_beta_samples = 10_000\n",
+    "q_betas = []\n",
+    "for _ in range(num_beta_samples):\n",
+    "  qs_pos = np.array([b.rvs(size=1)[0] for b in betas]) # P(+|b)\n",
+    "  q_beta = np.dot(qs_pos, p_b)\n",
+    "  q_betas.append(q_beta)\n",
+    "\n",
+    "# Plot call density estimate.\n",
+    "plt.figure(figsize=(10, 5))\n",
+    "xs, ys, _ = plt.hist(q_betas, density=True, bins=25, alpha=0.25)\n",
+    "plt.plot([q_mle, q_mle], [0.0, np.max(xs)], 'k:', alpha=0.75,\n",
+    "         label='q_mle')\n",
+    "\n",
+    "low, high = np.quantile(q_betas, [0.05, 0.95])\n",
+    "plt.plot([low, low], [0.0, np.max(xs)], 'g', alpha=0.75, label='low conf')\n",
+    "plt.plot([high, high], [0.0, np.max(xs)], 'g', alpha=0.75, label='high conf')\n",
+    "\n",
+    "plt.xlim(0.0, 1.0)\n",
+    "plt.xlabel('Call Rate (q)')\n",
+    "plt.ylabel('P(q)')\n",
+    "plt.title(f'Call Density Estimation ({target_class})')\n",
+    "plt.legend()\n",
+    "plt.show()\n",
+    "\n",
+    "print(f'MLE Call Density: {q_mle:.4f}')\n",
+    "print(f'(Low/MLE/High) Call Density Estimate: ({low:5.4f} / {q_mle:5.4f} / {high:5.4f})')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "id": "6PPrCBc-15k_"
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "private_outputs": true,
+   "toc_visible": true
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git 
a/chirp/inference/classify/classify.py b/chirp/inference/classify/classify.py index b1952323..89785daf 100644 --- a/chirp/inference/classify/classify.py +++ b/chirp/inference/classify/classify.py @@ -18,10 +18,12 @@ import dataclasses from typing import Sequence +from chirp.inference import tf_examples from chirp.inference.classify import data_lib from chirp.models import metrics import numpy as np import tensorflow as tf +import tqdm @dataclasses.dataclass @@ -144,3 +146,72 @@ def train_embedding_model( learning_rate=learning_rate, ) return test_metrics + + +def get_inference_dataset( + embeddings_ds: tf.data.Dataset, + model: tf.keras.Model, +): + """Create a dataset which includes the model's predictions.""" + + def classify_batch(batch): + """Classify a batch of embeddings.""" + emb = batch[tf_examples.EMBEDDING] + emb_shape = tf.shape(emb) + flat_emb = tf.reshape(emb, [-1, emb_shape[-1]]) + logits = model(flat_emb) + logits = tf.reshape( + logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]] + ) + # Take the maximum logit over channels. + logits = tf.reduce_max(logits, axis=-2) + batch['logits'] = logits + return batch + + inference_ds = embeddings_ds.map( + classify_batch, num_parallel_calls=tf.data.AUTOTUNE + ) + return inference_ds + + +def write_inference_csv( + embeddings_ds: tf.data.Dataset, + model: tf.keras.Model, + labels: Sequence[str], + output_filepath: str, + embedding_hop_size_s: float, + threshold: dict[str, float] | None = None, + exclude_classes: Sequence[str] = ('unknown',), + include_classes: Sequence[str] = (), +): + """Write a CSV file of inference results.""" + inference_ds = get_inference_dataset(embeddings_ds, model) + + detection_count = 0 + nondetection_count = 0 + with open(output_filepath, 'w') as f: + # Write column headers. 
+ headers = ['filename', 'timestamp_s', 'label', 'logit'] + f.write(', '.join(headers) + '\n') + for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()): + for t in range(ex['logits'].shape[0]): + for i, label in enumerate(labels): + if label in exclude_classes: + continue + if include_classes and label not in include_classes: + continue + if threshold is None or ex['logits'][t, i] > threshold[label]: + offset = ex['timestamp_s'] + t * embedding_hop_size_s + logit = '{:.2f}'.format(ex['logits'][t, i]) + row = [ + ex['filename'].decode('utf-8'), + '{:.2f}'.format(offset), + label, + logit, + ] + f.write(', '.join(row) + '\n') + detection_count += 1 + else: + nondetection_count += 1 + print('\n\n\n Detection count: ', detection_count) + print('NonDetection count: ', nondetection_count) diff --git a/chirp/inference/search/display.py b/chirp/inference/search/display.py index aade74c8..10658a2a 100644 --- a/chirp/inference/search/display.py +++ b/chirp/inference/search/display.py @@ -15,6 +15,7 @@ """Utility functions for displaying audio and results in Colab/Jupyter.""" +import dataclasses import functools from typing import Sequence @@ -22,6 +23,7 @@ from chirp.inference.search import search from chirp.models import frontend import IPython +from IPython.display import clear_output from IPython.display import display as ipy_display import ipywidgets from librosa import display as librosa_display @@ -81,6 +83,38 @@ def plot_audio_melspec( ipy_display(IPython.display.Audio(audio, rate=sample_rate)) +def _make_result_buttons(button_labels: Sequence[str]): + """Creates buttons for selected labels.""" + + def button_callback(x): + x.value = not x.value + if x.value: + x.button_style = 'success' + else: + x.button_style = '' + + buttons = [] + for lbl in button_labels: + check = ipywidgets.Button( + description=lbl, + disabled=False, + button_style='', + ) + check.value = False + check.on_click(button_callback) + + buttons.append(check) + return buttons + + +def _make_result_radio_buttons(button_labels: Sequence[str]): + """Make radio buttons with the indicated labels.""" + b = ipywidgets.RadioButtons(options=button_labels) + # Explicitly set value to None to avoid pre-selecting the first option. 
+ b.value = None + return [b] + + def display_search_results( results: search.TopKSearchResults, embedding_sample_rate: int, @@ -88,6 +122,8 @@ def display_search_results( window_s: float = 5.0, checkbox_labels: Sequence[str] = (), max_workers=5, + exclusive_labels=False, + rank_offset: int = 0, ): """Display search results, and add audio and annotation info to results.""" @@ -104,33 +140,71 @@ def display_search_results( ): plot_audio_melspec(result_audio_window, embedding_sample_rate) plt.show() - print(f'rank : {rank}') + print(f'rank : {rank + rank_offset}') print(f'source file : {r.filename}') offset_s = r.timestamp_offset print(f'offset_s : {offset_s:.2f}') print(f'score : {(r.score):.2f}') - label_widgets = [] - def button_callback(x): - x.value = not x.value - if x.value: - x.button_style = 'success' + if not r.label_widgets: + if exclusive_labels: + r.label_widgets = _make_result_radio_buttons(checkbox_labels) else: - x.button_style = '' + r.label_widgets = _make_result_buttons(checkbox_labels) - for lbl in checkbox_labels: - check = ipywidgets.Button( - description=lbl, - disabled=False, - button_style='', - ) - check.value = False - check.on_click(button_callback) + for b in r.label_widgets: + ipy_display(b) - label_widgets.append(check) - ipy_display(check) # Attach audio and widgets to the SearchResult. r.audio = result_audio_window - r.label_widgets = label_widgets - print('-' * 80) + + +@dataclasses.dataclass +class PageState: + max_page: int + curr_page: int = 0 + + def increment(self, inc): + self.curr_page += inc + self.curr_page = min(self.max_page, self.curr_page) + self.curr_page = max(0, self.curr_page) + + +def display_paged_results( + all_results: search.TopKSearchResults, + page_state: PageState, + samples_per_page: int = 10, + **kwargs, +): + """Display search results in pages.""" + + def increment_page_callback(x, inc, page_state): + page_state.increment(inc) + display_page(page_state) + + next_page_button = ipywidgets.Button(description='Next Page', disabled=False) + next_page_button.on_click(lambda x: increment_page_callback(x, 1, page_state)) + prev_page_button = ipywidgets.Button(description='Prev Page', disabled=False) + prev_page_button.on_click( + lambda x: increment_page_callback(x, -1, page_state) + ) + + def display_page(page_state): + clear_output() + num_pages = len(all_results.search_results) // samples_per_page + page = page_state.curr_page + print(f'Results Page: {page} / {num_pages}') + st, end = page * samples_per_page, (page + 1) * samples_per_page + results_page = search.TopKSearchResults( + all_results.search_results[st:end], top_k=samples_per_page + ) + display_search_results( + results_page, rank_offset=page * samples_per_page, **kwargs + ) + print(f'Results Page: {page} / {num_pages}') + ipy_display(prev_page_button) + ipy_display(next_page_button) + + # Display the first page. 
+ display_page(page_state) diff --git a/chirp/inference/search/search.py b/chirp/inference/search/search.py index 49071589..98bf22b9 100644 --- a/chirp/inference/search/search.py +++ b/chirp/inference/search/search.py @@ -60,6 +60,9 @@ class TopKSearchResults: min_score: float = -1.0 _min_score_idx: int = -1 + def __post_init__(self): + self._update_deseridata() + def __iter__(self): for r in self.search_results: yield r @@ -85,6 +88,8 @@ def will_filter(self, score: float) -> bool: return score < self.min_score def _update_deseridata(self): + if not self.search_results: + return self._min_score_idx = np.argmin([r.sort_score for r in self.search_results]) self.min_score = self.search_results[self._min_score_idx].sort_score diff --git a/embed_audio.ipynb b/embed_audio.ipynb new file mode 100644 index 00000000..416e42fd --- /dev/null +++ b/embed_audio.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "AjNsE-YjbCew" + }, + "source": [ + "# Mass Embedding of Bioacoustic Audio\n", + "\n", + "This notebook facilitates pre-computing embeddings of audio data for subsequent\n", + "use with search, classification, and analysis." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "avlqyEzpa_rN" + }, + "source": [ + "## Configuration and Imports." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "3kuA7b5Wap7o" + }, + "outputs": [], + "source": [ + "#@title Imports. { vertical-output: true }\n", + "\n", + "from etils import epath\n", + "from ml_collections import config_dict\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import tqdm\n", + "from chirp.inference import colab_utils\n", + "colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n", + "\n", + "from chirp import audio_utils\n", + "from chirp.inference import embed_lib\n", + "from chirp.inference import tf_examples\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "-l4NL0CAavKl" + }, + "outputs": [], + "source": [ + "#@title Basic Configuration. { vertical-output: true }\n", + "\n", + "# Define the model: Usually perch or birdnet.\n", + "model_choice = 'perch' #@param\n", + "# Set the base directory for the project.\n", + "working_dir = '/tmp/agile' #@param\n", + "\n", + "# Set the embedding and labeled data directories.\n", + "embeddings_path = epath.Path(working_dir) / 'embeddings'\n", + "labeled_data_path = epath.Path(working_dir) / 'labeled'\n", + "embeddings_glob = embeddings_path / 'embeddings-*'\n", + "\n", + "# OPTIONAL: Set up separation model.\n", + "separation_model_key = 'separator_model_tf' #@param\n", + "separation_model_path = '' #@param\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G7W8Rl0ma8Mm" + }, + "source": [ + "## Embed Audio" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "vobyRomeazNr" + }, + "outputs": [], + "source": [ + "#@title Embedding Configuration. 
{ vertical-output: true }\n", + "\n", + "config = config_dict.ConfigDict()\n", + "config.embed_fn_config = config_dict.ConfigDict()\n", + "config.embed_fn_config.model_config = config_dict.ConfigDict()\n", + "\n", + "# IMPORTANT: Select the targe audio files.\n", + "# source_file_patterns should contain a list of globs of audio files, like:\n", + "# ['/home/me/*.wav', '/home/me/other/*.flac']\n", + "config.source_file_patterns = [''] #@param\n", + "config.output_dir = embeddings_path.as_posix()\n", + "\n", + "# For Perch, set the perch_tfhub_model_version, and the model will load\n", + "# automagically from TFHub. Alternatively, set the model path for a local\n", + "# copy of the model.\n", + "# Note that only one of perch_model_path and perch_tfhub_version should be set.\n", + "perch_tfhub_version = 4 #@param\n", + "perch_model_path = '' #@param\n", + "\n", + "# For BirdNET, point to the specific tflite file.\n", + "birdnet_model_path = '' #@param\n", + "if model_choice == 'perch':\n", + " config.embed_fn_config.model_key = 'taxonomy_model_tf'\n", + " config.embed_fn_config.model_config.window_size_s = 5.0\n", + " config.embed_fn_config.model_config.hop_size_s = 5.0\n", + " config.embed_fn_config.model_config.sample_rate = 32000\n", + " config.embed_fn_config.model_config.tfhub_version = perch_tfhub_version\n", + " config.embed_fn_config.model_config.model_path = perch_model_path\n", + "elif model_choice == 'birdnet':\n", + " config.embed_fn_config.model_key = 'birdnet'\n", + " config.embed_fn_config.model_config.window_size_s = 3.0\n", + " config.embed_fn_config.model_config.hop_size_s = 3.0\n", + " config.embed_fn_config.model_config.sample_rate = 48000\n", + " config.embed_fn_config.model_config.model_path = birdnet_model_path\n", + " # Note: The v2_1 class list is appropriate for Birdnet 2.1, 2.2, and 2.3.\n", + " config.embed_fn_config.model_config.class_list_name = 'birdnet_v2_1'\n", + " config.embed_fn_config.model_config.num_tflite_threads = 4\n", + "\n", + "# Only write embeddings to reduce size.\n", + "config.embed_fn_config.write_embeddings = True\n", + "config.embed_fn_config.write_logits = False\n", + "config.embed_fn_config.write_separated_audio = False\n", + "config.embed_fn_config.write_raw_audio = False\n", + "\n", + "# Number of parent directories to include in the filename.\n", + "config.embed_fn_config.file_id_depth = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "FiWVT22ja1Y0" + }, + "outputs": [], + "source": [ + "#@title Set up. 
{ vertical-output: true }\n", + "\n", + "# Set up the embedding function, including loading models.\n", + "embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)\n", + "print('\\n\\nLoading model(s)...')\n", + "embed_fn.setup()\n", + "\n", + "# Create output directory and write the configuration.\n", + "output_dir = epath.Path(config.output_dir)\n", + "output_dir.mkdir(exist_ok=True, parents=True)\n", + "embed_lib.maybe_write_config(config, output_dir)\n", + "\n", + "# Create SourceInfos.\n", + "source_infos = embed_lib.create_source_infos(\n", + " config.source_file_patterns,\n", + " num_shards_per_file=config.get('num_shards_per_file', -1),\n", + " shard_len_s=config.get('shard_len_s', -1))\n", + "print(f'Found {len(source_infos)} source infos.')\n", + "\n", + "print('\\n\\nTest-run of model...')\n", + "window_size_s = config.embed_fn_config.model_config.window_size_s\n", + "sr = config.embed_fn_config.model_config.sample_rate\n", + "z = np.zeros([int(sr * window_size_s)])\n", + "embed_fn.embedding_model.embed(z)\n", + "print('Setup complete!')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "jf8RVwRwa350" + }, + "outputs": [], + "source": [ + "#@title Run embedding. { vertical-output: true }\n", + "\n", + "# Uses multiple threads to load audio before embedding.\n", + "# This tends to be faster, but can fail if any audio files are corrupt.\n", + "\n", + "embed_fn.min_audio_s = 1.0\n", + "record_file = (output_dir / 'embeddings.tfrecord').as_posix()\n", + "succ, fail = 0, 0\n", + "\n", + "existing_embedding_ids = embed_lib.get_existing_source_ids(\n", + " output_dir, 'embeddings-*')\n", + "\n", + "new_source_infos = embed_lib.get_new_source_infos(\n", + " source_infos, existing_embedding_ids, config.embed_fn_config.file_id_depth)\n", + "\n", + "print(f'Found {len(new_source_infos)} existing embedding ids.'\n", + " f'Processing {len(new_source_infos)} new source infos. ')\n", + "\n", + "audio_iterator = audio_utils.multi_load_audio_window(\n", + " filepaths=[s.filepath for s in new_source_infos],\n", + " offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],\n", + " sample_rate=config.embed_fn_config.model_config.sample_rate,\n", + " window_size_s=config.get('shard_len_s', -1.0),\n", + ")\n", + "with tf_examples.EmbeddingsTFRecordMultiWriter(\n", + " output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:\n", + " for source_info, audio in tqdm.tqdm(\n", + " zip(new_source_infos, audio_iterator), total=len(new_source_infos)):\n", + " file_id = source_info.file_id(config.embed_fn_config.file_id_depth)\n", + " offset_s = source_info.shard_num * source_info.shard_len_s\n", + " example = embed_fn.audio_to_example(file_id, offset_s, audio)\n", + " if example is None:\n", + " fail += 1\n", + " continue\n", + " file_writer.write(example.SerializeToString())\n", + " succ += 1\n", + " file_writer.flush()\n", + "print(f'\\n\\nSuccessfully processed {succ} source_infos, failed {fail} times.')\n", + "\n", + "fns = [fn for fn in output_dir.glob('embeddings-*')]\n", + "ds = tf.data.TFRecordDataset(fns)\n", + "parser = tf_examples.get_example_parser()\n", + "ds = ds.map(parser)\n", + "for ex in ds.as_numpy_iterator():\n", + " print(ex['filename'])\n", + " print(ex['embedding'].shape, flush=True)\n", + " break\n" + ] + } + ], + "metadata": { + "colab": { + "private_outputs": true, + "toc_visible": true + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}
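Usage sketch for the new batched-inference API introduced in this patch (illustrative only; paths and the working directory are placeholders, and it assumes embeddings produced by embed_audio.ipynb and a custom classifier saved by agile_modeling.ipynb):

    # Run classify.write_inference_csv over precomputed embeddings with a saved
    # custom classifier. All paths below are placeholders.
    import collections

    from etils import epath
    from ml_collections import config_dict

    from chirp.inference import interface
    from chirp.inference import tf_examples
    from chirp.inference.classify import classify
    from chirp.inference.search import bootstrap

    working_dir = epath.Path('/tmp/agile')  # placeholder
    embeddings_path = working_dir / 'embeddings'
    custom_classifier_path = working_dir / 'custom_classifier'

    # Recover hop size and other settings from the embedding run's config.
    bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(
        embeddings_path=embeddings_path,
        annotated_path=working_dir / 'labeled')

    # Load the classifier saved via LogitsOutputHead.save_model(...).
    loaded_model = interface.LogitsOutputHead.from_config(
        config_dict.ConfigDict({
            'model_path': custom_classifier_path,
            'logits_key': 'custom',
        }))

    embeddings_ds = tf_examples.create_embeddings_dataset(
        embeddings_path, file_glob='embeddings-*')

    # Write one CSV row per detection above the per-class logit threshold.
    class_thresholds = collections.defaultdict(lambda: 1.0)
    classify.write_inference_csv(
        embeddings_ds=embeddings_ds,
        model=loaded_model.logits_model,
        labels=loaded_model.class_list.classes,
        output_filepath='/tmp/inference.csv',
        threshold=class_thresholds,
        embedding_hop_size_s=bootstrap_config.embedding_hop_size_s,
        exclude_classes=['unknown'])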