Skip to content

Commit

Permalink
Clean up embedding configuration for better observability by expandin…
Browse files Browse the repository at this point in the history
…g 'AudioGlob' to include all characteristics needed for embedding a collection of files. This allows better serializability of embedding configuration and checks for compatibility across different runs. In particular, we can choose different 'high-level' settings per-dataset. We also allow updating the base_path for some audio (eg, allowing pointing at a local cache of the audio files instead of a remote source).

Also, change default embedding behavior to convert to the model's sample rate. (Using the file's 'native' sample rate is for advanced use only...)

PiperOrigin-RevId: 683643120
  • Loading branch information
sdenton4 authored and copybara-github committed Oct 8, 2024
1 parent 952df72 commit 7d1be69
Show file tree
Hide file tree
Showing 9 changed files with 397 additions and 142 deletions.
36 changes: 27 additions & 9 deletions chirp/projects/agile2/1_embed_audio_v2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
"import os\n",
"from IPython.display import display\n",
"import ipywidgets as widgets\n",
"\n",
"from chirp.projects.agile2 import colab_utils\n",
"from chirp.projects.agile2 import embed\n",
"from chirp.projects.agile2 import source_info\n",
"from chirp.projects.hoplite import interface\n"
]
},
Expand Down Expand Up @@ -60,13 +62,24 @@
"#@markdown `multispecies_whale` for marine mammals.\n",
"model_choice = 'perch_8' #@param['perch_8', 'humpback', 'multispecies_whale', 'surfperch', 'birdnet_V2.3']\n",
"\n",
"globs_to_process = {dataset_name: (dataset_base_path, dataset_fileglob,),}\n",
"use_file_sharding = True #@param {type:'boolean'}\n",
"\n",
"audio_glob = source_info.AudioSourceConfig(\n",
" dataset_name=dataset_name,\n",
" base_path=dataset_base_path,\n",
" file_glob=dataset_fileglob,\n",
" min_audio_len_s=1.0,\n",
" target_sample_rate_hz=-2,\n",
" shard_len_s=60.0 if use_file_sharding else None,\n",
")\n",
"\n",
"# You do not need to change this unless you want to maintain multiple distinct\n",
"# embedding databases.\n",
"db_path = None\n",
"configs = colab_utils.load_configs(\n",
" globs_to_process, db_path, model_config_key=model_choice)\n",
" source_info.AudioSources((audio_glob,)),\n",
" db_path,\n",
" model_config_key=model_choice)\n",
"configs"
]
},
Expand Down Expand Up @@ -117,18 +130,14 @@
"source": [
"#@title Run the embedding { vertical-output: true }\n",
"\n",
"# If the DB already exists, we need to make sure that the the current\n",
"# model_config is compatible with the model_config that was used previously.\n",
"colab_utils.validate_and_save_configs(configs, db)\n",
"\n",
"print(f'Embedding dataset: {[key for key in globs_to_process]}')\n",
"print(f'Embedding dataset: {audio_glob.dataset_name}')\n",
"\n",
"worker = embed.EmbedWorker(\n",
" embed_config=configs.audio_sources_config,\n",
" audio_sources=configs.audio_sources_config,\n",
" db=db,\n",
" model_config=configs.model_config)\n",
"\n",
"worker.process_all()\n",
"worker.process_all(target_dataset_name=audio_glob.dataset_name)\n",
"\n",
"print('\\n\\nEmbedding complete, total embeddings: ', db.count_embeddings())"
]
Expand All @@ -147,6 +156,15 @@
" print(f'\\nDataset \\'{dataset}\\':')\n",
" print('\\tnum embeddings: ', db.get_embeddings_by_source(dataset, source_id=None).shape[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hr_AUAfI7UG_"
},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
43 changes: 8 additions & 35 deletions chirp/projects/agile2/colab_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses

from chirp.projects.agile2 import embed
from chirp.projects.agile2 import source_info
from chirp.projects.hoplite import db_loader
from chirp.projects.hoplite import interface
from chirp.projects.zoo import models
Expand All @@ -30,7 +31,7 @@ class AgileConfigs:
"""Container for the various configs used in the Agile notebooks."""

# Config for the raw audio sources.
audio_sources_config: embed.EmbedConfig
audio_sources_config: source_info.AudioSources
# Database config for the embeddings database.
db_config: db_loader.DBConfig
# Config for the embedding model.
Expand All @@ -45,33 +46,9 @@ def as_config_dict(self) -> config_dict.ConfigDict:
})


def validate_and_save_configs(
configs: AgileConfigs,
db: interface.GraphSearchDBInterface,
):
"""Validates that the model config is compatible with the DB."""

model_config = configs.model_config
db_metadata = db.get_metadata(None)
if 'model_config' in db_metadata:
if db_metadata['model_config'].model_key != model_config.model_key:
raise AssertionError(
'The configured embedding model does not match the embedding model'
' that is already in the DB. You either need to drop the database or'
" use the '%s' model confg."
% db_metadata['model_config'].model_key
)

db.insert_metadata('model_config', model_config.to_config_dict())
db.insert_metadata(
'embed_config', configs.audio_sources_config.to_config_dict()
)
db.commit()


def load_configs(
audio_globs: dict[str, tuple[str, str]],
db_path: str,
audio_sources: source_info.AudioSources,
db_path: str | None = None,
model_config_key: str = 'perch_8',
) -> AgileConfigs:
"""Load default configs for the notebook and return them as an AgileConfigs.
Expand All @@ -87,13 +64,14 @@ def load_configs(
AgileConfigs object with the loaded configs.
"""
if db_path is None:
if len(audio_globs) > 1:
if len(audio_sources.audio_globs) > 1:
raise ValueError(
'db_path must be specified when embedding multiple datasets.'
)
# Put the DB in the same directory as the audio.
db_path = (
epath.Path(next(iter(audio_globs.values()))[0]) / 'hoplite_db.sqlite'
epath.Path(next(iter(audio_sources.audio_globs)).base_path)
/ 'hoplite_db.sqlite'
)

model_key, embedding_dim, model_config = models.get_preset_model_config(
Expand All @@ -109,13 +87,8 @@ def load_configs(
'embedding_dim': embedding_dim,
})

audio_srcs_config = embed.EmbedConfig(
audio_globs=audio_globs,
min_audio_len_s=1.0,
)

return AgileConfigs(
audio_sources_config=audio_srcs_config,
audio_sources_config=audio_sources,
db_config=db_loader.DBConfig('sqlite', db_config),
model_config=db_model_config,
)
31 changes: 20 additions & 11 deletions chirp/projects/agile2/convert_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from chirp.inference import embed_lib
from chirp.inference import tf_examples
from chirp.projects.agile2 import embed
from chirp.projects.agile2 import source_info
from chirp.projects.hoplite import in_mem_impl
from chirp.projects.hoplite import interface
from chirp.projects.hoplite import sqlite_impl
Expand Down Expand Up @@ -70,19 +71,27 @@ def convert_tfrecords(
)
file_id_depth = legacy_config.embed_fn_config['file_id_depth']
audio_globs = []
for glob in legacy_config.source_file_patterns:
new_glob = glob.split('/')[-file_id_depth - 1 :]
audio_globs.append(new_glob)
for i, glob in enumerate(legacy_config.source_file_patterns):
base_path, file_glob = glob.split('/')[-file_id_depth - 1 :]
if i > 0:
partial_dataset_name = f'{dataset_name}_{i}'
else:
partial_dataset_name = dataset_name
audio_globs.append(
source_info.AudioSourceConfig(
dataset_name=partial_dataset_name,
base_path=base_path,
file_glob=file_glob,
min_audio_len_s=legacy_config.embed_fn_config.min_audio_s,
target_sample_rate_hz=legacy_config.embed_fn_config.get(
'target_sample_rate_hz', -2
),
)
)

embed_config = embed.EmbedConfig(
audio_globs={dataset_name: tuple(audio_globs)},
min_audio_len_s=legacy_config.embed_fn_config.min_audio_s,
target_sample_rate_hz=legacy_config.embed_fn_config.get(
'target_sample_rate_hz', -1
),
)
audio_sources = source_info.AudioSources(audio_globs=tuple(audio_globs))
db.insert_metadata('legacy_config', legacy_config)
db.insert_metadata('embed_config', embed_config.to_config_dict())
db.insert_metadata('audio_sources', audio_sources.to_config_dict())
db.insert_metadata('model_config', model_config.to_config_dict())
hop_size_s = model_config.model_config.hop_size_s

Expand Down
102 changes: 72 additions & 30 deletions chirp/projects/agile2/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,42 +45,27 @@ class ModelConfig(hoplite_interface.EmbeddingMetadata):
model_config: config_dict.ConfigDict


@dataclasses.dataclass
class EmbedConfig(hoplite_interface.EmbeddingMetadata):
"""Configuration for embedding processing.
Attributes:
audio_globs: Mapping from dataset name to pairs of `(root directory, file
glob)`.
min_audio_len_s: Minimum audio length to process.
target_sample_rate_hz: Target sample rate for audio. If -2, use the
embedding model's declared sample rate. If -1, use the file's native
sample rate. If > 0, resample to the specified rate.
"""

audio_globs: dict[str, tuple[str, str]]
min_audio_len_s: float
target_sample_rate_hz: int = -1


class EmbedWorker:
"""Worker for embedding audio examples."""

def __init__(
self,
embed_config: EmbedConfig,
audio_sources: source_info.AudioSources,
model_config: ModelConfig,
db: hoplite_interface.GraphSearchDBInterface,
embedding_model: zoo_interface.EmbeddingModel | None = None,
):
self.db = db
self.model_config = model_config
self.embed_config = embed_config
self.audio_sources = audio_sources
if embedding_model is None:
model_class = models.model_class_map()[model_config.model_key]
self.embedding_model = model_class.from_config(model_config.model_config)
else:
self.embedding_model = embedding_model
self.audio_globs = {
g.dataset_name: g for g in self.audio_sources.audio_globs
}

def _log_error(self, source_id, exception, counter_name):
logging.warning(
Expand All @@ -92,15 +77,68 @@ def _log_error(self, source_id, exception, counter_name):
exception,
)

def get_sample_rate_hz(self) -> int:
def _update_audio_sources(self):
"""Validates the embed config and/or saves it to the DB."""
db_metadata = self.db.get_metadata(None)
if 'audio_sources' not in db_metadata:
self.db.insert_metadata(
'audio_sources', self.audio_sources.to_config_dict()
)
return

db_audio_sources = source_info.AudioSources.from_config_dict(
db_metadata['audio_sources']
)
merged = self.audio_sources.merge_update(db_audio_sources)
self.db.insert_metadata('audio_sources', merged.to_config_dict())
self.audio_sources = merged

def _update_model_config(self):
"""Validates the model config and/or saves it to the DB."""
db_metadata = self.db.get_metadata(None)
if 'model_config' not in db_metadata:
self.db.insert_metadata(
'model_config', self.model_config.to_config_dict()
)
return

db_model_config = ModelConfig(**db_metadata['model_config'])
if self.model_config == db_model_config:
return

# Validate the config against the DB.
# TODO(tomdenton): Implement compatibility checks for model configs.
if self.model_config.model_key != db_model_config.model_key:
raise AssertionError(
'The configured model key does not match the model key that is '
'already in the DB.'
)
if self.model_config.embedding_dim != db_model_config.embedding_dim:
raise AssertionError(
'The configured embedding dimension does not match the embedding '
'dimension that is already in the DB.'
)
self.db.insert_metadata('model_config', self.model_config.to_config_dict())

def update_configs(self):
"""Validates the configs and saves them to the DB."""
self._update_model_config()
self._update_audio_sources()
self.db.commit()

def get_sample_rate_hz(self, source_id: source_info.SourceId) -> int:
"""Get the sample rate of the embedding model."""
if self.embed_config.target_sample_rate_hz == -2:
dataset_name = source_id.dataset_name
if dataset_name not in self.audio_globs:
raise ValueError(f'Dataset name {dataset_name} not found in audio globs.')
audio_glob = self.audio_globs[dataset_name]
if audio_glob.target_sample_rate_hz == -2:
return self.embedding_model.sample_rate
elif self.embed_config.target_sample_rate_hz == -1:
elif audio_glob.target_sample_rate_hz == -1:
# Uses the file's native sample rate.
return -1
elif self.embed_config.target_sample_rate_hz > 0:
return self.embed_config.target_sample_rate_hz
elif audio_glob.target_sample_rate_hz > 0:
return audio_glob.target_sample_rate_hz
else:
raise ValueError('Invalid target_sample_rate.')

Expand All @@ -110,7 +148,7 @@ def load_audio(self, source_id: source_info.SourceId) -> np.ndarray | None:
audio_array = audio_utils.load_audio_window(
source_id.filepath,
source_id.offset_s,
self.embed_config.target_sample_rate_hz,
self.get_sample_rate_hz(source_id),
source_id.shard_len_s,
)
return np.array(audio_array)
Expand Down Expand Up @@ -141,12 +179,13 @@ def process_source_id(
self, source_id: source_info.SourceId
) -> Iterator[tuple[hoplite_interface.EmbeddingSource, np.ndarray]]:
"""Process a single audio source."""
glob = self.audio_globs[source_id.dataset_name]
audio_array = self.load_audio(source_id)
if audio_array is None:
return
if (
audio_array.shape[0]
< self.embed_config.min_audio_len_s * self.embedding_model.sample_rate
< glob.min_audio_len_s * self.embedding_model.sample_rate
):
self._log_error(source_id, 'no_exception', 'audio_too_short')
return
Expand All @@ -170,10 +209,13 @@ def process_source_id(
for channel_embedding in embedding:
yield (emb_source_id, channel_embedding)

def process_all(self):
def process_all(self, target_dataset_name: str | None = None):
"""Process all audio examples."""
audio_sources = source_info.AudioSources(self.embed_config.audio_globs)
for source_id in audio_sources.iterate_all_sources():
self.update_configs()
# TODO(tomdenton): Prefetch audio in parallel for faster execution.
for source_id in self.audio_sources.iterate_all_sources(
target_dataset_name
):
for emb_source_id, embedding in self.process_source_id(source_id):
self.db.insert_embedding(embedding, emb_source_id)
self.db.commit()
Loading

0 comments on commit 7d1be69

Please sign in to comment.