diff --git a/.github/workflows/ci_no_jaxtrain.yml b/.github/workflows/ci_no_jaxtrain.yml index c47649c3..7c93bca2 100644 --- a/.github/workflows/ci_no_jaxtrain.yml +++ b/.github/workflows/ci_no_jaxtrain.yml @@ -32,4 +32,4 @@ jobs: poetry install --without jaxtrain - name: Test with unittest # TODO: Group together jaxtrain tests so they can be easily excluded. - run: poetry run python -m unittest discover -s chirp/tests -p "*inference_test.py" + run: poetry run python -m unittest discover -s chirp/inference/tests -p "*test.py" diff --git a/chirp/projects/multicluster/classify.py b/chirp/inference/classify/classify.py similarity index 97% rename from chirp/projects/multicluster/classify.py rename to chirp/inference/classify/classify.py index b3d13e4a..b1952323 100644 --- a/chirp/projects/multicluster/classify.py +++ b/chirp/inference/classify/classify.py @@ -18,8 +18,8 @@ import dataclasses from typing import Sequence +from chirp.inference.classify import data_lib from chirp.models import metrics -from chirp.projects.multicluster import data_lib import numpy as np import tensorflow as tf @@ -131,9 +131,7 @@ def train_embedding_model( ) -> ClassifierMetrics: """Trains a classification model over embeddings and labels.""" train_locs, test_locs, _ = merged.create_random_train_test_split( - train_ratio, - train_examples_per_class, - random_seed, + train_ratio, train_examples_per_class, random_seed, exclude_eval_classes=exclude_eval_classes, ) test_metrics = train_from_locs( diff --git a/chirp/projects/multicluster/data_lib.py b/chirp/inference/classify/data_lib.py similarity index 100% rename from chirp/projects/multicluster/data_lib.py rename to chirp/inference/classify/data_lib.py diff --git a/chirp/projects/bootstrap/bootstrap.py b/chirp/inference/search/bootstrap.py similarity index 95% rename from chirp/projects/bootstrap/bootstrap.py rename to chirp/inference/search/bootstrap.py index 96c3e8a0..c6b76807 100644 --- a/chirp/projects/bootstrap/bootstrap.py +++ b/chirp/inference/search/bootstrap.py @@ -91,12 +91,12 @@ class BootstrapConfig: tensor_dtype: str # The following are populated automatically from the embedding config. - embedding_hop_size_s: float | None = None - file_id_depth: int | None = None - audio_globs: Sequence[str] | None = None - model_key: str | None = None - model_config: config_dict.ConfigDict | None = None - tf_record_shards: int | None = None + embedding_hop_size_s: float + file_id_depth: int + audio_globs: Sequence[str] + model_key: str + model_config: config_dict.ConfigDict + tf_record_shards: int @classmethod def load_from_embedding_config( diff --git a/chirp/projects/bootstrap/display.py b/chirp/inference/search/display.py similarity index 98% rename from chirp/projects/bootstrap/display.py rename to chirp/inference/search/display.py index 8a292d15..aade74c8 100644 --- a/chirp/projects/bootstrap/display.py +++ b/chirp/inference/search/display.py @@ -19,8 +19,8 @@ from typing import Sequence from chirp import audio_utils +from chirp.inference.search import search from chirp.models import frontend -from chirp.projects.bootstrap import search import IPython from IPython.display import display as ipy_display import ipywidgets diff --git a/chirp/projects/bootstrap/search.py b/chirp/inference/search/search.py similarity index 100% rename from chirp/projects/bootstrap/search.py rename to chirp/inference/search/search.py diff --git a/chirp/tests/inference_test.py b/chirp/inference/tests/embed_test.py similarity index 95% rename from chirp/tests/inference_test.py rename to chirp/inference/tests/embed_test.py index 4e06e89f..fac351cb 100644 --- a/chirp/tests/inference_test.py +++ b/chirp/inference/tests/embed_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for inference library.""" +"""Tests for mass-embedding functionality.""" import os import tempfile @@ -28,12 +28,12 @@ from chirp.inference import interface from chirp.inference import models from chirp.inference import tf_examples +from chirp.inference.classify import classify +from chirp.inference.classify import data_lib +from chirp.inference.search import bootstrap +from chirp.inference.search import display +from chirp.inference.search import search from chirp.models import metrics -from chirp.projects.bootstrap import bootstrap -from chirp.projects.bootstrap import display -from chirp.projects.bootstrap import search -from chirp.projects.multicluster import classify -from chirp.projects.multicluster import data_lib from chirp.taxonomy import namespace from etils import epath from ml_collections import config_dict @@ -56,7 +56,7 @@ def _make_output_head_model(model_path: str, embedding_dim: int = 1280): ) -class InferenceTest(parameterized.TestCase): +class EmbedTest(parameterized.TestCase): def test_imports(self): # Test that imports work in external github environment. @@ -139,9 +139,7 @@ def test_embed_fn( self.assertIsNotNone(embed_fn.embedding_model) test_wav_path = os.fspath( - path_utils.get_absolute_path( - 'tests/testdata/tfds_builder_wav_directory_test/clap.wav' - ) + path_utils.get_absolute_path('inference/tests/testdata/clap.wav') ) source_info = embed_lib.SourceInfo(test_wav_path, 0, 10) @@ -205,7 +203,9 @@ def test_embed_fn( def test_embed_fn_from_config(self, config_filename): # Test that we can load a model from a golden config and compute embeddings. test_config_path = os.fspath( - path_utils.get_absolute_path(f'tests/testdata/{config_filename}.json') + path_utils.get_absolute_path( + f'inference/tests/testdata/{config_filename}.json' + ) ) embed_config = embed_lib.load_embedding_config(test_config_path, '') embed_fn = embed_lib.EmbedFn(**embed_config) @@ -213,9 +213,7 @@ def test_embed_fn_from_config(self, config_filename): self.assertIsNotNone(embed_fn.embedding_model) test_wav_path = os.fspath( - path_utils.get_absolute_path( - 'tests/testdata/tfds_builder_wav_directory_test/clap.wav' - ) + path_utils.get_absolute_path('inference/tests/testdata/clap.wav') ) source_info = embed_lib.SourceInfo(test_wav_path, 0, 10) example = embed_fn.process(source_info, crop_s=10.0)[0] @@ -253,9 +251,7 @@ def test_embed_fn_source_variations(self): parser = tf_examples.get_example_parser() test_wav_path = os.fspath( - path_utils.get_absolute_path( - 'tests/testdata/tfds_builder_wav_directory_test/clap.wav' - ) + path_utils.get_absolute_path('inference/tests/testdata/clap.wav') ) # Check that a SourceInfo with window_size_s <= 0 embeds the entire file. @@ -315,9 +311,7 @@ def test_keyed_write_logits(self): self.assertIsNotNone(embed_fn.embedding_model) test_wav_path = os.fspath( - path_utils.get_absolute_path( - 'tests/testdata/tfds_builder_wav_directory_test/clap.wav' - ) + path_utils.get_absolute_path('inference/tests/testdata/clap.wav') ) source_info = embed_lib.SourceInfo(test_wav_path, 0, 10) @@ -396,9 +390,7 @@ def test_embed_short_audio(self): self.assertIsNotNone(embed_fn.embedding_model) test_wav_path = os.fspath( - path_utils.get_absolute_path( - 'tests/testdata/tfds_builder_wav_directory_test/clap.wav' - ) + path_utils.get_absolute_path('inference/tests/testdata/clap.wav') ) source_info = embed_lib.SourceInfo(test_wav_path, 0, 10) # Crop to 3.0s to ensure we can handle short audio examples. @@ -443,11 +435,7 @@ def test_frame_audio(self): def test_create_source_infos(self): # Just one file, but it's all good. - globs = [ - path_utils.get_absolute_path( - 'tests/testdata/tfds_builder_wav_directory_test/clap.wav' - ) - ] + globs = [path_utils.get_absolute_path('inference/tests/testdata/clap.wav')] # Disable sharding by setting shard_len_s <= 0. got_infos = embed_lib.create_source_infos( globs, shard_len_s=-1, num_shards_per_file=100 @@ -672,9 +660,7 @@ def test_pooled_embeddings(self): def test_beam_pipeline(self): """Check that we can write embeddings to TFRecord file.""" test_wav_path = os.fspath( - path_utils.get_absolute_path( - 'tests/testdata/tfds_builder_wav_directory_test/clap.wav' - ) + path_utils.get_absolute_path('inference/tests/testdata/clap.wav') ) source_infos = [embed_lib.SourceInfo(test_wav_path, 0, 10)] base_pipeline = test_pipeline.TestPipeline() diff --git a/chirp/projects/bootstrap/tests/search_test.py b/chirp/inference/tests/search_test.py similarity index 99% rename from chirp/projects/bootstrap/tests/search_test.py rename to chirp/inference/tests/search_test.py index 81a6e0f3..f8bc044b 100644 --- a/chirp/projects/bootstrap/tests/search_test.py +++ b/chirp/inference/tests/search_test.py @@ -15,7 +15,7 @@ """Tests for the bootstrap search component.""" -from chirp.projects.bootstrap import search +from chirp.inference.search import search import numpy as np from absl.testing import absltest diff --git a/chirp/inference/tests/testdata/clap.wav b/chirp/inference/tests/testdata/clap.wav new file mode 100644 index 00000000..aa17258d Binary files /dev/null and b/chirp/inference/tests/testdata/clap.wav differ diff --git a/chirp/tests/testdata/embedding_config_v0.json b/chirp/inference/tests/testdata/embedding_config_v0.json similarity index 100% rename from chirp/tests/testdata/embedding_config_v0.json rename to chirp/inference/tests/testdata/embedding_config_v0.json diff --git a/chirp/tests/testdata/embedding_config_v1.json b/chirp/inference/tests/testdata/embedding_config_v1.json similarity index 100% rename from chirp/tests/testdata/embedding_config_v1.json rename to chirp/inference/tests/testdata/embedding_config_v1.json