Skip to content

Commit

Permalink
Consolidate search and small-classifier code under inference.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615067551
  • Loading branch information
sdenton4 authored and copybara-github committed Mar 14, 2024
1 parent e7c2dfd commit 79fd47d
Show file tree
Hide file tree
Showing 11 changed files with 28 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_no_jaxtrain.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ jobs:
poetry install --without jaxtrain
- name: Test with unittest
# TODO: Group together jaxtrain tests so they can be easily excluded.
run: poetry run python -m unittest discover -s chirp/tests -p "*inference_test.py"
run: poetry run python -m unittest discover -s chirp/inference/tests -p "*test.py"
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import dataclasses
from typing import Sequence

from chirp.inference.classify import data_lib
from chirp.models import metrics
from chirp.projects.multicluster import data_lib
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -131,9 +131,7 @@ def train_embedding_model(
) -> ClassifierMetrics:
"""Trains a classification model over embeddings and labels."""
train_locs, test_locs, _ = merged.create_random_train_test_split(
train_ratio,
train_examples_per_class,
random_seed,
train_ratio, train_examples_per_class, random_seed,
exclude_eval_classes=exclude_eval_classes,
)
test_metrics = train_from_locs(
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ class BootstrapConfig:
tensor_dtype: str

# The following are populated automatically from the embedding config.
embedding_hop_size_s: float | None = None
file_id_depth: int | None = None
audio_globs: Sequence[str] | None = None
model_key: str | None = None
model_config: config_dict.ConfigDict | None = None
tf_record_shards: int | None = None
embedding_hop_size_s: float
file_id_depth: int
audio_globs: Sequence[str]
model_key: str
model_config: config_dict.ConfigDict
tf_record_shards: int

@classmethod
def load_from_embedding_config(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from typing import Sequence

from chirp import audio_utils
from chirp.inference.search import search
from chirp.models import frontend
from chirp.projects.bootstrap import search
import IPython
from IPython.display import display as ipy_display
import ipywidgets
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for inference library."""
"""Tests for mass-embedding functionality."""

import os
import tempfile
Expand All @@ -28,12 +28,12 @@
from chirp.inference import interface
from chirp.inference import models
from chirp.inference import tf_examples
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib
from chirp.inference.search import bootstrap
from chirp.inference.search import display
from chirp.inference.search import search
from chirp.models import metrics
from chirp.projects.bootstrap import bootstrap
from chirp.projects.bootstrap import display
from chirp.projects.bootstrap import search
from chirp.projects.multicluster import classify
from chirp.projects.multicluster import data_lib
from chirp.taxonomy import namespace
from etils import epath
from ml_collections import config_dict
Expand All @@ -56,7 +56,7 @@ def _make_output_head_model(model_path: str, embedding_dim: int = 1280):
)


class InferenceTest(parameterized.TestCase):
class EmbedTest(parameterized.TestCase):

def test_imports(self):
# Test that imports work in external github environment.
Expand Down Expand Up @@ -139,9 +139,7 @@ def test_embed_fn(
self.assertIsNotNone(embed_fn.embedding_model)

test_wav_path = os.fspath(
path_utils.get_absolute_path(
'tests/testdata/tfds_builder_wav_directory_test/clap.wav'
)
path_utils.get_absolute_path('inference/tests/testdata/clap.wav')
)

source_info = embed_lib.SourceInfo(test_wav_path, 0, 10)
Expand Down Expand Up @@ -205,17 +203,17 @@ def test_embed_fn(
def test_embed_fn_from_config(self, config_filename):
# Test that we can load a model from a golden config and compute embeddings.
test_config_path = os.fspath(
path_utils.get_absolute_path(f'tests/testdata/{config_filename}.json')
path_utils.get_absolute_path(
f'inference/tests/testdata/{config_filename}.json'
)
)
embed_config = embed_lib.load_embedding_config(test_config_path, '')
embed_fn = embed_lib.EmbedFn(**embed_config)
embed_fn.setup()
self.assertIsNotNone(embed_fn.embedding_model)

test_wav_path = os.fspath(
path_utils.get_absolute_path(
'tests/testdata/tfds_builder_wav_directory_test/clap.wav'
)
path_utils.get_absolute_path('inference/tests/testdata/clap.wav')
)
source_info = embed_lib.SourceInfo(test_wav_path, 0, 10)
example = embed_fn.process(source_info, crop_s=10.0)[0]
Expand Down Expand Up @@ -253,9 +251,7 @@ def test_embed_fn_source_variations(self):
parser = tf_examples.get_example_parser()

test_wav_path = os.fspath(
path_utils.get_absolute_path(
'tests/testdata/tfds_builder_wav_directory_test/clap.wav'
)
path_utils.get_absolute_path('inference/tests/testdata/clap.wav')
)

# Check that a SourceInfo with window_size_s <= 0 embeds the entire file.
Expand Down Expand Up @@ -315,9 +311,7 @@ def test_keyed_write_logits(self):
self.assertIsNotNone(embed_fn.embedding_model)

test_wav_path = os.fspath(
path_utils.get_absolute_path(
'tests/testdata/tfds_builder_wav_directory_test/clap.wav'
)
path_utils.get_absolute_path('inference/tests/testdata/clap.wav')
)

source_info = embed_lib.SourceInfo(test_wav_path, 0, 10)
Expand Down Expand Up @@ -396,9 +390,7 @@ def test_embed_short_audio(self):
self.assertIsNotNone(embed_fn.embedding_model)

test_wav_path = os.fspath(
path_utils.get_absolute_path(
'tests/testdata/tfds_builder_wav_directory_test/clap.wav'
)
path_utils.get_absolute_path('inference/tests/testdata/clap.wav')
)
source_info = embed_lib.SourceInfo(test_wav_path, 0, 10)
# Crop to 3.0s to ensure we can handle short audio examples.
Expand Down Expand Up @@ -443,11 +435,7 @@ def test_frame_audio(self):

def test_create_source_infos(self):
# Just one file, but it's all good.
globs = [
path_utils.get_absolute_path(
'tests/testdata/tfds_builder_wav_directory_test/clap.wav'
)
]
globs = [path_utils.get_absolute_path('inference/tests/testdata/clap.wav')]
# Disable sharding by setting shard_len_s <= 0.
got_infos = embed_lib.create_source_infos(
globs, shard_len_s=-1, num_shards_per_file=100
Expand Down Expand Up @@ -672,9 +660,7 @@ def test_pooled_embeddings(self):
def test_beam_pipeline(self):
"""Check that we can write embeddings to TFRecord file."""
test_wav_path = os.fspath(
path_utils.get_absolute_path(
'tests/testdata/tfds_builder_wav_directory_test/clap.wav'
)
path_utils.get_absolute_path('inference/tests/testdata/clap.wav')
)
source_infos = [embed_lib.SourceInfo(test_wav_path, 0, 10)]
base_pipeline = test_pipeline.TestPipeline()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""Tests for the bootstrap search component."""

from chirp.projects.bootstrap import search
from chirp.inference.search import search
import numpy as np

from absl.testing import absltest
Expand Down
Binary file added chirp/inference/tests/testdata/clap.wav
Binary file not shown.

0 comments on commit 79fd47d

Please sign in to comment.