Skip to content

Commit

Permalink
Fix memory explosion in multi_load_audio_window when audio is loaded faster than it can be processed.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 621643020
  • Loading branch information
sdenton4 authored and copybara-github committed Apr 4, 2024
1 parent 9a45259 commit e609b91
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 33 deletions.
44 changes: 31 additions & 13 deletions chirp/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
General utilities for processing audio and spectrograms.
"""
import concurrent
import dataclasses
import functools
import logging
import os
import queue
import tempfile
from typing import Generator, Sequence
from typing import Any, Callable, Generator, Iterator, Sequence
import warnings

from chirp import path_utils
Expand Down Expand Up @@ -148,38 +150,54 @@ def multi_load_audio_window(
sample_rate: int,
window_size_s: float,
max_workers: int = 5,
buffer_size: int = -1,
) -> Generator[np.ndarray, None, None]:
"""Generator for loading audio windows in parallel.
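Note that audio is intended to be returned in the same order as the filepaths.
Caution: the queue-based loader yields results in the order the loads complete,
so this ordering is not guaranteed when max_workers > 1.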
Also, this ultimately relies on soundfile, which can be buggy in some cases.
Caution: Because this generator uses an Executor, it can continue holding
resources while not being used. If you are using this in a notebook, you
should use this in a 'nameless' context, like:
```
for audio in multi_load_audio_window(...):
...
```
Otherwise, the generator will continue to hold resources until the notebook
is closed.
Args:
filepaths: Paths to audio to load.
offsets: Read offset in seconds for each file, or None if no offsets are
needed.
sample_rate: Sample rate for returned audio.
window_size_s: Window length to read from each file. Set <0 to read all.
max_workers: Number of threads to allocate.
buffer_size: Max number of audio windows to queue up. Defaults to 2x the
number of workers.
Yields:
Loaded audio windows.
"""
if buffer_size == -1:
buffer_size = 2 * max_workers
q = queue.Queue(maxsize=buffer_size)
loader = functools.partial(
load_audio_window, sample_rate=sample_rate, window_size_s=window_size_s
)
if offsets is None:
offsets = [0.0 for _ in filepaths]
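# ThreadPoolExecutor works well despite the GIL (presumably because audio
# loading is I/O-bound -- TODO confirm; the original comment was truncated).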
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
futures = []
for fp, offset in zip(filepaths, offsets):
future = executor.submit(loader, offset_s=offset, filepath=fp)
futures.append(future)
while futures:
yield futures.pop(0).result()
mapping = lambda fp, offset: q.put(loader(fp, offset))
task_iterator = zip(filepaths, offsets)

executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
try:
for fp, offset in task_iterator:
executor.submit(mapping, fp, offset)
for _ in range(len(filepaths)):
yield q.get()
finally:
# This is run when the generator is closed, or when an exception is raised.
executor.shutdown(wait=False, cancel_futures=True)


def load_xc_audio(xc_id: str, sample_rate: int) -> jnp.ndarray:
Expand Down
59 changes: 39 additions & 20 deletions embed_audio.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
"config.embed_fn_config = config_dict.ConfigDict()\n",
"config.embed_fn_config.model_config = config_dict.ConfigDict()\n",
"\n",
"# IMPORTANT: Select the targe audio files.\n",
"# IMPORTANT: Select the target audio files.\n",
"# source_file_patterns should contain a list of globs of audio files, like:\n",
"# ['/home/me/*.wav', '/home/me/other/*.flac']\n",
"config.source_file_patterns = [''] #@param\n",
Expand Down Expand Up @@ -195,25 +195,28 @@
"print(f'Found {len(new_source_infos)} existing embedding ids.'\n",
" f'Processing {len(new_source_infos)} new source infos. ')\n",
"\n",
"audio_iterator = audio_utils.multi_load_audio_window(\n",
" filepaths=[s.filepath for s in new_source_infos],\n",
" offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],\n",
" sample_rate=config.embed_fn_config.model_config.sample_rate,\n",
" window_size_s=config.get('shard_len_s', -1.0),\n",
")\n",
"with tf_examples.EmbeddingsTFRecordMultiWriter(\n",
" output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:\n",
" for source_info, audio in tqdm.tqdm(\n",
" zip(new_source_infos, audio_iterator), total=len(new_source_infos)):\n",
" file_id = source_info.file_id(config.embed_fn_config.file_id_depth)\n",
" offset_s = source_info.shard_num * source_info.shard_len_s\n",
" example = embed_fn.audio_to_example(file_id, offset_s, audio)\n",
" if example is None:\n",
" fail += 1\n",
" continue\n",
" file_writer.write(example.SerializeToString())\n",
" succ += 1\n",
" file_writer.flush()\n",
"try:\n",
" audio_iterator = audio_utils.multi_load_audio_window(\n",
" filepaths=[s.filepath for s in new_source_infos],\n",
" offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],\n",
" sample_rate=config.embed_fn_config.model_config.sample_rate,\n",
" window_size_s=config.get('shard_len_s', -1.0),\n",
" )\n",
" with tf_examples.EmbeddingsTFRecordMultiWriter(\n",
" output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:\n",
" for source_info, audio in tqdm.tqdm(\n",
" zip(new_source_infos, audio_iterator), total=len(new_source_infos)):\n",
" file_id = source_info.file_id(config.embed_fn_config.file_id_depth)\n",
" offset_s = source_info.shard_num * source_info.shard_len_s\n",
" example = embed_fn.audio_to_example(file_id, offset_s, audio)\n",
" if example is None:\n",
" fail += 1\n",
" continue\n",
" file_writer.write(example.SerializeToString())\n",
" succ += 1\n",
" file_writer.flush()\n",
"finally:\n",
" del(audio_iterator)\n",
"print(f'\\n\\nSuccessfully processed {succ} source_infos, failed {fail} times.')\n",
"\n",
"fns = [fn for fn in output_dir.glob('embeddings-*')]\n",
Expand All @@ -225,12 +228,28 @@
" print(ex['embedding'].shape, flush=True)\n",
" break\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "pS8vC2JWliEG"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"private_outputs": true,
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
Expand Down

0 comments on commit e609b91

Please sign in to comment.