Update beam and pandas dependencies, and add instructions to download and prepare the BIRB evaluation data

PiperOrigin-RevId: 542269331
vdumoulin authored and copybara-github committed Aug 14, 2023
1 parent 52a7ea1 commit 91624a0
Showing 18 changed files with 1,760 additions and 556 deletions.
18 changes: 16 additions & 2 deletions README.md
@@ -12,8 +12,8 @@ You might need the following dependencies.
 # Install Poetry for package management
 curl -sSL https://install.python-poetry.org | python3 -
 
-# Install dependencies for librosa (required for testing only)
-sudo apt-get install libsndfile1
+# Install dependencies for librosa
+sudo apt-get install libsndfile1 ffmpeg
 
 # Install all dependencies specified in the poetry configs.
 poetry install
@@ -26,4 +26,18 @@ dependencies, in which you can run the Chirp codebase. To run the tests, try
 poetry run python -m unittest discover -s chirp/tests -p "*test.py"
 ```
 
+## BIRB data preparation
+
+### Evaluation data
+
+After [installing](#installation) the `chirp` package, run the following command from the repository's root directory:
+
+```bash
+poetry run tfds build -i chirp.data.bird_taxonomy,chirp.data.soundscapes \
+    soundscapes/{ssw,hawaii,coffee_farms,sierras_kahl,high_sierras,peru}_full_length \
+    bird_taxonomy/{downstream_full_length,class_representatives_slice_peaked}
+```
+
+The process should take 36 to 48 hours to complete and use around 256 GiB of disk space.
+
 *This is not an officially supported Google product.*
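As a quick sanity check after the build completes, the generated datasets can be read back with `tfds.load`. This is a minimal sketch, not part of the commit; it assumes the build wrote to the default `~/tensorflow_datasets` directory, that the config exposes a `train` split, and that examples carry an `audio` feature:

```python
import tensorflow_datasets as tfds

# Load one of the freshly built configs and inspect a single example.
ds = tfds.load('soundscapes/ssw_full_length', split='train')
for example in ds.take(1):
    print(example['audio'].shape)
```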
53 changes: 44 additions & 9 deletions chirp/data/bird_taxonomy/bird_taxonomy.py
@@ -50,6 +50,10 @@
   be retrieved from the 'filename' feature: 'XC{xeno_canto_id}.mp3'.
 """
 
+# The maximum audio sequence length to consider if a localization function is
+# provided. This is 5 * 60 seconds = 5 minutes.
+_MAX_LOCALIZATION_LENGTH_S = 5 * 60
+
 LocalizationFn = Callable[[Any, int, float], jnp.ndarray]
 
 
@@ -335,6 +339,12 @@ def _info(self) -> tfds.core.DatasetInfo:
     )
 
   def _split_generators(self, dl_manager: tfds.download.DownloadManager):
+    # Increase the file handle resource soft limit to the hard limit. The
+    # dataset is large enough that it causes TFDS to hit the soft limit.
+    import resource
+    _low, _high = resource.getrlimit(resource.RLIMIT_NOFILE)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (_high, _high))
+
     # No checksum is found for the new taxonomy_info. dl_manager may raise
     # an error when removing the line below.
     dl_manager._force_checksums_validation = (
@@ -435,6 +445,9 @@ def _process_example(row):
       # Resampling can introduce artifacts that push the signal outside the
       # [-1, 1) interval.
      audio = np.clip(audio, -1.0, 1.0 - (1.0 / float(1 << 15)))
+      # Skip empty audio files.
+      if audio.shape[0] == 0 or np.max(np.abs(audio)) == 0.0:
+        return None
       # The scrubbed foreground annotations are replaced by ''. When this is the
       # case, we translate this annotation into [] rather than [''].
       foreground_label = (
@@ -463,19 +476,24 @@ def _process_example(row):
           'sound_type': source['sound_type'],
       }
 
-    pipeline = beam.Create(source_info.iterrows()) | beam.Map(_process_example)
-
     if self.builder_config.localization_fn:
 
-      def _localize_intervals(args):
+      def localize_intervals_fn(args):
         key, example = args
         sample_rate_hz = self.builder_config.sample_rate_hz
         interval_length_s = self.builder_config.interval_length_s
         target_length = int(sample_rate_hz * interval_length_s)
 
-        audio = audio_utils.pad_to_length_if_shorter(
-            example['audio'], target_length
-        )
+        audio = example['audio']
+
+        # We limit audio sequence length to _MAX_LOCALIZATION_LENGTH_S when
+        # localizing intervals because the localization function can result in
+        # very large memory consumption for long audio sequences.
+        max_length = sample_rate_hz * _MAX_LOCALIZATION_LENGTH_S
+        if audio.shape[0] > max_length:
+          audio = audio[:max_length]
+
+        audio = audio_utils.pad_to_length_if_shorter(audio, target_length)
         # Pass padded audio to avoid localization_fn having to pad again
         audio_intervals = self.builder_config.localization_fn(
             audio, sample_rate_hz, interval_length_s
@@ -499,6 +517,23 @@ def _localize_intervals(args):
         ))
         return interval_examples
 
-      pipeline = pipeline | beam.FlatMap(_localize_intervals)
-
-    return pipeline
+    else:
+      localize_intervals_fn = None
+
+    for i, key_and_example in enumerate(
+        map(_process_example, source_info.iterrows())
+    ):
+      # Since the audio files have variable length, the JAX compilation cache
+      # can use up a large amount of memory after a while.
+      if i % 100 == 0:
+        jax.clear_caches()
+
+      # Skip empty audio files.
+      if key_and_example is None:
+        continue
+
+      if localize_intervals_fn:
+        for key_and_example in localize_intervals_fn(key_and_example):
+          yield key_and_example
+      else:
+        yield key_and_example
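The soft-limit bump at the top of `_split_generators` is a general pattern for processes that hold many files open at once. A standalone sketch of the same idea (not chirp code):

```python
import resource

# Only the soft limit can be raised by an unprivileged process; the hard
# limit acts as a ceiling, so (hard, hard) is the highest valid setting.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
print(f'open-file soft limit raised from {soft} to {hard}')
```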
2 changes: 1 addition & 1 deletion chirp/data/filter_scrub_utils.py
@@ -314,7 +314,7 @@ def is_not_in(
 def append(df: pd.DataFrame, row: dict[str, Any]):
   if set(row.keys()) != set(df.columns):
     raise ValueError
-  new_df = df.append(row, ignore_index=True)
+  new_df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
   return new_df
 
 
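`DataFrame.append` was deprecated in pandas 1.4 and removed in pandas 2.0, which is what this change works around: the row dict is wrapped in a one-row `DataFrame` and concatenated. A minimal sketch of the replacement pattern, with made-up data:

```python
import pandas as pd

df = pd.DataFrame({'species': ['ssw'], 'count': [3]})
row = {'species': 'hawaii', 'count': 5}

# pd.concat over [df, one-row DataFrame] replaces the removed df.append(row).
new_df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
print(new_df)
```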
7 changes: 3 additions & 4 deletions chirp/data/soundscapes/soundscapes.py
@@ -604,7 +604,6 @@ def _process_group(
       beam.metrics.Metrics.counter('soundscapes', 'examples').inc()
       return valid_segments
 
-    pipeline = beam.Create(
-        enumerate(segments.groupby('filename'))
-    ) | beam.FlatMap(_process_group)
-    return pipeline
+    for group in enumerate(segments.groupby('filename')):
+      for key, example in _process_group(group):
+        yield key, example
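As in `bird_taxonomy.py`, the Beam pipeline construction is dropped in favor of yielding examples directly, so dataset generation no longer needs a Beam runner. A schematic sketch of the before/after pattern (names are illustrative, not from the commit):

```python
import apache_beam as beam

def generate_with_beam(items, process_fn):
    # Before: build a composite PTransform and let the Beam runner execute it.
    return beam.Create(items) | beam.FlatMap(process_fn)

def generate_in_process(items, process_fn):
    # After: a plain generator that runs sequentially in-process.
    for item in items:
        yield from process_fn(item)
```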
4 changes: 2 additions & 2 deletions chirp/models/class_average.py
@@ -30,8 +30,8 @@ class ClassAverage(metrics.Metric):
   form a multi-hot encoding.
   """
 
-  total: jnp.array
-  count: jnp.array
+  total: jnp.ndarray
+  count: jnp.ndarray
 
   @classmethod
   def empty(cls):
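All the `jnp.array` → `jnp.ndarray` edits in this commit fix the same issue: `jnp.array` is the array *constructor* function, while `jnp.ndarray` is the array *type*, so only the latter is meaningful in annotations. A minimal illustration:

```python
import jax.numpy as jnp

# jnp.ndarray is the correct annotation; jnp.array is a function, and using
# it as a type annotation is misleading (and confuses type checkers).
def scale(x: jnp.ndarray, factor: float) -> jnp.ndarray:
    return factor * x

print(scale(jnp.array([1.0, 2.0]), 2.0))  # [2. 4.]
```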
5 changes: 3 additions & 2 deletions chirp/projects/bootstrap/search.py
@@ -104,8 +104,9 @@ def write_labeled_data(self, labeled_data_path: str, sample_rate: int):
       for label in labels:
         output_path = labeled_data_path / label
         output_path.mkdir(parents=True, exist_ok=True)
-        output_filepath = output_path / output_filename
-        wavfile.write(output_filepath, sample_rate, r.audio)
+        output_filepath = epath.Path(output_path / output_filename)
+        with output_filepath.open('wb') as f:
+          wavfile.write(f, sample_rate, r.audio)
         counts[label] += 1
     for label, count in counts.items():
       print(f'Wrote {count} examples for label {label}')
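`scipy.io.wavfile.write` accepts an open file object as well as a filename, so routing the write through `epath` lets the same code target remote filesystems (e.g. GCS paths) that a bare path string would not. A small sketch with a made-up output path:

```python
import numpy as np
from etils import epath
from scipy.io import wavfile

# Hypothetical example: write one second of silence at 16 kHz.
out = epath.Path('/tmp/labeled_data/song/example_00.wav')
out.parent.mkdir(parents=True, exist_ok=True)
with out.open('wb') as f:
    wavfile.write(f, 16_000, np.zeros(16_000, dtype=np.float32))
```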
33 changes: 29 additions & 4 deletions chirp/projects/multicluster/data_lib.py
@@ -179,14 +179,15 @@ def embed_dataset(
     exclude_classes: Sequence[str] = (),
     load_audio: bool = True,
     target_sample_rate: int = -1,
+    audio_file_pattern: str = '*.wav',
 ) -> Tuple[Sequence[str], Dict[str, np.ndarray]]:
-  """Add embeddings to an eval dataset.
+  """Embed a dataset, creating an in-memory copy of all data with embeddings added.
 
   The base_dir should contain folders corresponding to classes, and each
-  sub-folder should contain wav files for the respective class.
+  sub-folder should contain audio files for the respective class.
 
-  Note that any wav files in the base_dir directly will be ignored.
+  Note that any audio files in the base_dir directly will be ignored.
 
   Args:
     base_dir: Directory containing audio data.
@@ -196,6 +197,8 @@ def embed_dataset(
     load_audio: Whether to load audio into memory.
     target_sample_rate: Resample loaded audio to this sample rate. If -1, loads
       raw audio with no resampling. If -2, uses the embedding_model sample rate.
+    audio_file_pattern: The glob pattern to use for finding audio files within
+      the sub-folders.
 
   Returns:
     Ordered labels and a Dict containing the entire embedded dataset.
@@ -224,7 +227,18 @@ def embed_dataset(
   for label_idx, label in enumerate(labels):
     label_hot = np.zeros([len(labels)], np.int32)
     label_hot[label_idx] = 1
-    filepaths = [fp.as_posix() for fp in (base_dir / label).glob('*.wav')]
+
+    filepaths = [
+        fp.as_posix() for fp in (base_dir / label).glob(audio_file_pattern)
+    ]
+
+    if not filepaths:
+      raise ValueError(
+          'No files matching {} were found in directory {}'.format(
+              audio_file_pattern, base_dir / label
+          )
+      )
+
     audio_iterator = audio_utils.multi_load_audio_window(
         filepaths, None, target_sample_rate, -1
    )
@@ -235,7 +249,18 @@ def embed_dataset(
       audio = _pad_audio(audio, window_size)
       audio = audio.astype(np.float32)
       outputs = embedding_model.embed(audio)
-      embeds = outputs.pooled_embeddings(time_pooling, 'squeeze')
+
+      if not outputs.embeddings:
+        raise ValueError('Embedding model did not produce any embeddings!')
+
+      # If the audio was separated then the raw audio is in the first channel.
+      # Embedding shape is either [B, F, C, D] or [F, C, D] so channel is
+      # always -2.
+      channel_pooling = (
+          'squeeze' if outputs.embeddings.shape[-2] == 1 else 'first'
+      )
+
+      embeds = outputs.pooled_embeddings(time_pooling, channel_pooling)
       merged['embeddings'].append(embeds)
 
       filename = epath.Path(fp).name
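The channel-pooling selection above relies on the channel axis sitting at position -2 in both the batched `[B, F, C, D]` and unbatched `[F, C, D]` embedding layouts. A toy illustration of that selection logic (shapes only, not the chirp API):

```python
import numpy as np

def pick_channel_pooling(embeddings: np.ndarray) -> str:
    # The channel axis is -2 whether or not a batch dimension is present.
    return 'squeeze' if embeddings.shape[-2] == 1 else 'first'

print(pick_channel_pooling(np.zeros((100, 1, 512))))     # 'squeeze': [F, C=1, D]
print(pick_channel_pooling(np.zeros((8, 100, 4, 512))))  # 'first':   [B, F, C=4, D]
```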
2 changes: 1 addition & 1 deletion chirp/projects/sfda/mca.py
@@ -44,7 +44,7 @@ def empty(cls) -> "MCA":
 
   @classmethod
   def from_model_output(
-      cls, scores: jnp.array, label: jnp.array, **_
+      cls, scores: jnp.ndarray, label: jnp.ndarray, **_
   ) -> clu_metrics.Metric:
     num_classes = label.shape[-1]
     if scores.shape[-1] != num_classes:
2 changes: 1 addition & 1 deletion chirp/projects/sfda/method_utils.py
@@ -43,7 +43,7 @@
 
 
 @jax.jit
-def jax_cdist(features_a: jnp.array, features_b: jnp.array) -> jnp.array:
+def jax_cdist(features_a: jnp.ndarray, features_b: jnp.ndarray) -> jnp.ndarray:
   """A jax equivalent of scipy.spatial.distance.cdist.
 
   Computes the pairwise squared euclidean distance between each pair of features
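The docstring notes that this computes pairwise *squared* euclidean distances (unlike `scipy.spatial.distance.cdist`, which returns unsquared euclidean distances by default). The body isn't shown in this diff; a sketch of one standard way to write it, not necessarily the chirp implementation:

```python
import jax
import jax.numpy as jnp

@jax.jit
def squared_cdist(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    # Pairwise ||a_i - b_j||^2 via the expansion ||a||^2 - 2 a.b + ||b||^2.
    a_sq = jnp.sum(a**2, axis=-1)[:, None]
    b_sq = jnp.sum(b**2, axis=-1)[None, :]
    return a_sq - 2.0 * (a @ b.T) + b_sq

a, b = jnp.ones((4, 8)), jnp.zeros((3, 8))
print(squared_cdist(a, b).shape)  # (4, 3)
```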
14 changes: 7 additions & 7 deletions chirp/projects/sfda/methods/nrc.py
@@ -49,9 +49,9 @@
 class NRCLoss(clu_metrics.Metric):
   """Computes NRC's loss for the standard single-label case."""
 
-  probabilities_sum: jnp.array
-  nn_loss_sum: jnp.array
-  extended_nn_loss_sum: jnp.array
+  probabilities_sum: jnp.ndarray
+  nn_loss_sum: jnp.ndarray
+  extended_nn_loss_sum: jnp.ndarray
   label_mask: jnp.ndarray | None
   n_samples: int
 
@@ -138,10 +138,10 @@ def compute(self):
 class NRCMultiLoss(clu_metrics.Metric):
   """Computes NRC's loss for the multi-label case."""
 
-  probabilities_sum: jnp.array
-  nn_loss_sum: jnp.array
-  extended_nn_loss_sum: jnp.array
-  label_mask: jnp.array
+  probabilities_sum: jnp.ndarray
+  nn_loss_sum: jnp.ndarray
+  extended_nn_loss_sum: jnp.ndarray
+  label_mask: jnp.ndarray
   n_samples: int
 
   @classmethod
16 changes: 8 additions & 8 deletions chirp/projects/sfda/methods/shot.py
@@ -35,10 +35,10 @@
 class SHOTMultiLabelLoss(clu_metrics.Metric):
   """Computes the loss used in SHOT-full for the multi-label case."""
 
-  probabilities_sum: jnp.array
-  entropy_sum: jnp.array
-  pl_xent_sum: jnp.array
-  label_mask: jnp.array
+  probabilities_sum: jnp.ndarray
+  entropy_sum: jnp.ndarray
+  pl_xent_sum: jnp.ndarray
+  label_mask: jnp.ndarray
   n_samples: int
   beta: float
 
@@ -109,9 +109,9 @@ def compute(self):
 class SHOTLoss(clu_metrics.Metric):
   """Computes the loss used in SHOT-full for the single-label case."""
 
-  probabilities_sum: jnp.array
-  entropy_sum: jnp.array
-  pl_xent_sum: jnp.array
+  probabilities_sum: jnp.ndarray
+  entropy_sum: jnp.ndarray
+  pl_xent_sum: jnp.ndarray
   label_mask: jnp.ndarray | None
   n_samples: int
   beta: float
 
@@ -121,7 +121,7 @@ def from_model_output(
       cls,
       probabilities: jnp.ndarray,
       pseudo_label: jnp.ndarray,
-      label_mask: jnp.array,
+      label_mask: jnp.ndarray,
       beta: float,
       **_
   ) -> "SHOTLoss":
10 changes: 5 additions & 5 deletions chirp/projects/sfda/metrics.py
@@ -52,7 +52,7 @@ class MarginalEntropy(clu_metrics.Metric):
   non-`averageable` metric, which is why we dedicate a separate metric for it.
   """
 
-  probability_sum: jnp.array
+  probability_sum: jnp.ndarray
   n_samples: int
   multi_label: bool
   label_mask: jnp.ndarray | None
@@ -87,7 +87,7 @@ def compute(self):
 
   @classmethod
   def empty(cls) -> "MarginalEntropy":
-    return cls(
+    return cls(  # pytype: disable=wrong-arg-types  # jnp-array
         probability_sum=0.0, n_samples=0, multi_label=False, label_mask=None
     )
 
@@ -100,9 +100,9 @@ class MarginalBinaryEntropy(clu_metrics.Metric):
   multi_label.
   """
 
-  probability_sum: jnp.array
+  probability_sum: jnp.ndarray
   n_samples: int
-  label_mask: jnp.array
+  label_mask: jnp.ndarray
   multi_label: bool
 
   @classmethod
@@ -140,6 +140,6 @@ def compute(self):
 
   @classmethod
   def empty(cls) -> "MarginalBinaryEntropy":
-    return cls(
+    return cls(  # pytype: disable=wrong-arg-types  # jnp-array
         probability_sum=0.0, n_samples=0, label_mask=0.0, multi_label=False
     )
12 changes: 9 additions & 3 deletions chirp/tests/bird_taxonomy_test.py
@@ -76,9 +76,15 @@ def setUpClass(cls):
     subdir = epath.Path(cls.tempdir) / 'audio-data' / 'comter'
     subdir.mkdir(parents=True)
     for i in range(4):
-      tfds.core.lazy_imports.pydub.AudioSegment.silent(duration=10000).export(
-          subdir / f'XC{i:05d}.mp3', format='mp3'
-      )
+      tfds.core.lazy_imports.pydub.AudioSegment(
+          b'\0\1' * int(10_000 * 10),
+          metadata={
+              'channels': 1,
+              'sample_width': 2,
+              'frame_rate': 10_000,
+              'frame_width': 2,
+          },
+      ).export(subdir / f'XC{i:05d}.mp3', format='mp3')
 
   @classmethod
   def tearDownClass(cls):
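A plausible reading of this fixture change (the commit message doesn't say): `AudioSegment.silent` produces all-zero samples, which the new empty-audio check in `_process_example` would now skip, so the test synthesizes non-silent raw PCM instead. The same pydub construction outside the test harness:

```python
import pydub

# Ten seconds of a repeating non-zero 16-bit sample pattern, 10 kHz mono.
segment = pydub.AudioSegment(
    b'\0\1' * (10_000 * 10),
    metadata={
        'channels': 1,
        'sample_width': 2,
        'frame_rate': 10_000,
        'frame_width': 2,
    },
)
segment.export('/tmp/XC00000.mp3', format='mp3')
```

(pydub's MP3 export depends on ffmpeg, which the README change above now installs.)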
9 changes: 7 additions & 2 deletions chirp/tests/filter_scrub_test.py
@@ -335,7 +335,9 @@ def test_append_query(self):
     new_df = fsu.apply_query(self.toy_df, append_query)
     self.assertEqual(
         new_df.to_dict(),
-        self.toy_df.append(new_row, ignore_index=True).to_dict(),
+        pd.concat(
+            [self.toy_df, pd.DataFrame([new_row])], ignore_index=True
+        ).to_dict(),
     )
 
     # Append query with keys not matching the dataframe
@@ -456,7 +458,10 @@ def test_merge_concat_no_duplicates(self):
     # of .to_dict().
     self.assertTrue(
         fsu.apply_parallel(self.toy_df, query_parallel).equals(
-            self.toy_df.loc[[0]].append([scrubbed_r0, self.toy_df.loc[1]])
+            pd.concat([
+                self.toy_df.loc[[0]],
+                pd.DataFrame([scrubbed_r0, self.toy_df.loc[1]]),
+            ])
         )
     )
 
[Diffs for the remaining 4 changed files did not load and are not shown.]
