From cfefbcc63c4812ec7da119341e17b68386c8c05d Mon Sep 17 00:00:00 2001
From: Tom Denton
Date: Tue, 15 Aug 2023 15:20:03 -0700
Subject: [PATCH] Create embedding models from configurations with a
 `from_config` classmethod instead of abusing the dataclass `post_init`
 function.

PiperOrigin-RevId: 557267692
---
 .../inference/configs/separate_soundscapes.py |   2 +
 chirp/inference/configs/separated_seabirds.py |   3 +
 chirp/inference/embed_lib.py                  |   5 +-
 chirp/inference/interface.py                  |   8 +
 chirp/inference/models.py                     | 258 +++++++++++-------
 chirp/tests/inference_test.py                 |  27 +-
 6 files changed, 179 insertions(+), 124 deletions(-)

diff --git a/chirp/inference/configs/separate_soundscapes.py b/chirp/inference/configs/separate_soundscapes.py
index 660fdb3e..b4b022ef 100644
--- a/chirp/inference/configs/separate_soundscapes.py
+++ b/chirp/inference/configs/separate_soundscapes.py
@@ -52,8 +52,10 @@ def get_config() -> config_dict.ConfigDict:
         },
         'separator_model_tf_config': {
             'model_path': sep_model_checkpoint_path,
+            'window_size_s': 5.0,
             'sample_rate': 32000,
             'frame_size': 32000,
+            'target_class_list': None,
         },
     },
     'speech_filter_threshold': 0.95,
diff --git a/chirp/inference/configs/separated_seabirds.py b/chirp/inference/configs/separated_seabirds.py
index dcd3d84e..0097a4cc 100644
--- a/chirp/inference/configs/separated_seabirds.py
+++ b/chirp/inference/configs/separated_seabirds.py
@@ -46,6 +46,7 @@ def get_config() -> config_dict.ConfigDict:
     'model_key': 'separate_embed_model',
     'model_config': {
         'sample_rate': 32000,
+        'embed_raw': True,
         'taxonomy_model_tf_config': {
             'model_path': emb_model_checkpoint_path,
             'window_size_s': 5.0,
@@ -54,8 +55,10 @@ def get_config() -> config_dict.ConfigDict:
         },
         'separator_model_tf_config': {
             'model_path': sep_model_checkpoint_path,
+            'window_size_s': 5.0,
             'sample_rate': 32000,
             'frame_size': 32000,
+            'target_class_list': None,
         },
     },
     'speech_filter_threshold': 0.0,
diff --git a/chirp/inference/embed_lib.py b/chirp/inference/embed_lib.py
index 16aa516b..2e64e6a2 100644
--- a/chirp/inference/embed_lib.py
+++ b/chirp/inference/embed_lib.py
@@ -139,9 +139,8 @@ def __init__(
 
   def setup(self):
     if self.embedding_model is None:
-      self.embedding_model = models.model_class_map()[self.model_key](
-          **self.model_config
-      )
+      model_class = models.model_class_map()[self.model_key]
+      self.embedding_model = model_class.from_config(self.model_config)
     if hasattr(self, 'model_key'):
       del self.model_key
     if hasattr(self, 'model_config'):
diff --git a/chirp/inference/interface.py b/chirp/inference/interface.py
index 03396bbc..da455c22 100644
--- a/chirp/inference/interface.py
+++ b/chirp/inference/interface.py
@@ -20,6 +20,7 @@
 from chirp.taxonomy import namespace
 import librosa
+from ml_collections import config_dict
 import numpy as np
 
 LogitType = Dict[str, np.ndarray]
@@ -111,6 +112,13 @@ class EmbeddingModel:
 
   sample_rate: int
 
+  @classmethod
+  def from_config(
+      cls, model_config: config_dict.ConfigDict
+  ) -> 'EmbeddingModel':
+    """Load the model from a configuration dict."""
+    raise NotImplementedError
+
   def embed(self, audio_array: np.ndarray) -> InferenceOutputs:
     """Create InferenceOutputs from an audio array.
diff --git a/chirp/inference/models.py b/chirp/inference/models.py
index 60bd8d62..6c35efc6 100644
--- a/chirp/inference/models.py
+++ b/chirp/inference/models.py
@@ -54,27 +54,33 @@ class SeparateEmbedModel(interface.EmbeddingModel):
     rate is used to resample prior to computing the embedding.
 
   Attributes:
-    taxonomy_model_tf_config: Configuration for a TaxonomyModelTF.
-    separator_model_tf_config: Configuration for a SeparationModelTF.
+    separation_model: SeparatorModelTF.
+    embedding_model: TaxonomyModelTF.
     embed_raw: If True, the outputs will include embeddings of the original
       audio in addition to embeddings for the separated channels. The
       embeddings will have shape [T, C+1, D], with the raw audio embedding on
       channel 0.
-    separation_model: SeparationModelTF, automatically populated during init.
-    embedding_model: TaxonomyModelTF, automatically populated during init.
   """
 
-  taxonomy_model_tf_config: config_dict.ConfigDict
-  separator_model_tf_config: config_dict.ConfigDict
+  separation_model: 'SeparatorModelTF'
+  embedding_model: 'TaxonomyModelTF'
   embed_raw: bool = True
-  # Populated during init.
-  separation_model: Any = None
-  embedding_model: Any = None
+
+  @classmethod
+  def from_config(cls, config: config_dict.ConfigDict) -> 'SeparateEmbedModel':
+    separation_model = SeparatorModelTF.from_config(
+        config.separator_model_tf_config
+    )
+    embedding_model = TaxonomyModelTF.from_config(
+        config.taxonomy_model_tf_config
+    )
+    return cls(
+        sample_rate=config.sample_rate,
+        separation_model=separation_model,
+        embedding_model=embedding_model,
+        embed_raw=config.embed_raw,
+    )
 
   def __post_init__(self):
-    if self.separation_model is None:
-      self.separation_model = SeparatorModelTF(**self.separator_model_tf_config)
-      self.embedding_model = TaxonomyModelTF(**self.taxonomy_model_tf_config)
     if self.separation_model.sample_rate != self.embedding_model.sample_rate:
       raise ValueError(
           'Separation and embedding models must have matching rates.'
       )
@@ -93,6 +99,8 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
     separation_outputs = self.separation_model.batch_embed(framed_audio)
     # separated_audio has shape [F, C, T]
     separated_audio = separation_outputs.separated_audio
+    if separated_audio is None:
+      raise RuntimeError('Separation model returned None for separated audio.')
 
     if self.embed_raw:
       separated_audio = np.concatenate(
@@ -109,12 +117,13 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
 
     embedding_outputs = self.embedding_model.batch_embed(separated_audio)
 
-    # Batch embeddings have shape [Batch, Time, Channels, Features]
-    # Time is 1 because we have framed using the embedding model's window_size.
-    # The batch size is num_frames * num_channels.
-    embeddings = np.reshape(
-        embedding_outputs.embeddings, [num_frames, num_channels, -1]
-    )
+    if embedding_outputs.embeddings is not None:
+      # Batch embeddings have shape [Batch, Time, Channels, Features]
+      # Time is 1 because we have framed using the embedding model's
+      # window_size. The batch size is num_frames * num_channels.
+      embedding_outputs.embeddings = np.reshape(
+          embedding_outputs.embeddings, [num_frames, num_channels, -1]
+      )
 
     # Take the maximum logits over the channels dimension.
     if embedding_outputs.logits is not None:
@@ -126,7 +135,7 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
       max_logits = None
 
     return interface.InferenceOutputs(
-        embeddings=embeddings,
+        embeddings=embedding_outputs.embeddings,
         logits=max_logits,
         # Because the separated audio is framed, it does not match the
         # outputs interface, so we do not return it.
@@ -141,28 +150,34 @@ class BirbSepModelTF1(interface.EmbeddingModel):
 
   model_path: str
   window_size_s: float
   keep_raw_channel: bool
+  session: Any
+  input_placeholder_ns: Any
+  output_tensor_ns: Any
 
-  # The following are populated at init time.
-  session: Any | None = None
-  input_placeholder_ns: Any | None = None
-  output_tensor_ns: Any | None = None
-
-  def __post_init__(self):
+  @classmethod
+  def from_config(cls, config: config_dict.ConfigDict) -> 'BirbSepModelTF1':
     """Load model files and create TF1 session graph."""
-    metagraph_path_ns = epath.Path(self.model_path) / 'inference.meta'
-    checkpoint_path = tf.train.latest_checkpoint(self.model_path)
+    metagraph_path_ns = epath.Path(config.model_path) / 'inference.meta'
+    checkpoint_path = tf.train.latest_checkpoint(config.model_path)
     graph_ns = tf.Graph()
     sess_ns = tf1.Session(graph=graph_ns)
     with graph_ns.as_default():
       new_saver = tf1.train.import_meta_graph(metagraph_path_ns)
       new_saver.restore(sess_ns, checkpoint_path)
-      self.input_placeholder_ns = graph_ns.get_tensor_by_name(
+      input_placeholder_ns = graph_ns.get_tensor_by_name(
           'input_audio/receiver_audio:0'
       )
-      self.output_tensor_ns = graph_ns.get_tensor_by_name(
-          'denoised_waveforms:0'
-      )
-      self.session = sess_ns
+      output_tensor_ns = graph_ns.get_tensor_by_name('denoised_waveforms:0')
+      session = sess_ns
+    return cls(
+        model_path=config.model_path,
+        sample_rate=config.sample_rate,
+        window_size_s=config.window_size_s,
+        keep_raw_channel=config.keep_raw_channel,
+        session=session,
+        input_placeholder_ns=input_placeholder_ns,
+        output_tensor_ns=output_tensor_ns,
+    )
 
   def embed(self, audio_array: Any) -> interface.InferenceOutputs:
     start_sample = 0
@@ -204,45 +219,48 @@ class TaxonomyModelTF(interface.EmbeddingModel):
     window_size_s: Window size for framing audio in seconds. TODO(tomdenton):
       Ideally this should come from a model metadata file.
     hop_size_s: Hop size for inference.
-    target_class_list: If provided, restricts logits to this ClassList.
     model: Loaded TF SavedModel.
     class_list: Loaded class_list for the model's output logits.
     batchable: Whether the model supports batched input.
+    target_class_list: If provided, restricts logits to this ClassList.
+    target_peak: Peak normalization value.
   """
 
   model_path: str
   window_size_s: float
   hop_size_s: float
+  model: Any  # TF SavedModel
+  class_list: namespace.ClassList
+  batchable: bool
   target_class_list: namespace.ClassList | None = None
   target_peak: float | None = 0.25
 
-  # The following are populated during init.
-  model: Any | None = None  # TF SavedModel
-  class_list: namespace.ClassList | None = None
-  batchable: bool = False
-
-  def __post_init__(self):
+  @classmethod
+  def from_config(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF':
     logging.info('Loading taxonomy model...')
-    base_path = epath.Path(self.model_path)
+    base_path = epath.Path(config.model_path)
     if (base_path / 'saved_model.pb').exists() and (
         base_path / 'assets'
     ).exists():
       # This looks like a TFHub downloaded model.
       model_path = base_path
-      label_csv_path = epath.Path(self.model_path) / 'assets' / 'label.csv'
+      label_csv_path = epath.Path(config.model_path) / 'assets' / 'label.csv'
     else:
       # Probably a savedmodel distributed directly.
       model_path = base_path / 'savedmodel'
       label_csv_path = base_path / 'label.csv'
 
-    self.model = tf.saved_model.load(model_path)
+    model = tf.saved_model.load(model_path)
     with label_csv_path.open('r') as f:
-      self.class_list = namespace.ClassList.from_csv(f)
+      class_list = namespace.ClassList.from_csv(f)
 
     # Check whether the model supports polymorphic batch shape.
-    sig = self.model.signatures['serving_default']
-    self.batchable = sig.inputs[0].shape[0] is None
+    sig = model.signatures['serving_default']
+    batchable = sig.inputs[0].shape[0] is None
+    return cls(
+        model=model, class_list=class_list, batchable=batchable, **config
+    )
 
   def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
     if self.batchable:
@@ -306,26 +324,26 @@ class SeparatorModelTF(interface.EmbeddingModel):
     window_size_s: Window size for framing audio in seconds. The audio will
       be chunked into frames of size window_size_s, which may help avoid
       memory blowouts.
-    target_class_list: If provided, restricts logits to this ClassList.
     model: Loaded TF SavedModel.
     class_list: Loaded class_list for the model's output logits.
+    target_class_list: If provided, restricts logits to this ClassList.
   """
 
   model_path: str
   frame_size: int
-  window_size_s: float | None = None
+  window_size_s: float
+  model: Any
+  class_list: namespace.ClassList
   target_class_list: namespace.ClassList | None = None
 
-  # The following are populated during init.
-  model: Any | None = None  # TF SavedModel
-  class_list: namespace.ClassList | None = None
-
-  def __post_init__(self):
-    logging.info('Loading taxonomy model...')
-    self.model = tf.saved_model.load(epath.Path(self.model_path) / 'savedmodel')
-    label_csv_path = epath.Path(self.model_path) / 'label.csv'
+  @classmethod
+  def from_config(cls, config: config_dict.ConfigDict) -> 'SeparatorModelTF':
+    logging.info('Loading taxonomy separator model...')
+    model = tf.saved_model.load(epath.Path(config.model_path) / 'savedmodel')
+    label_csv_path = epath.Path(config.model_path) / 'label.csv'
     with label_csv_path.open('r') as f:
-      self.class_list = namespace.ClassList.from_csv(f)
+      class_list = namespace.ClassList.from_csv(f)
+    return cls(model=model, class_list=class_list, **config)
 
   def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
     # Drop samples to allow reshaping to frame_size
@@ -375,41 +393,55 @@ class BirdNet(interface.EmbeddingModel):
 
   Attributes:
     model_path: Path to the saved model checkpoint or TFLite file.
-    class_list_name: Name of the BirdNet class list.
+    model: The TF SavedModel or TFLite interpreter.
+    tflite: Whether the model is a TFLite model.
+    class_list: The loaded class list.
     window_size_s: Window size for framing audio in seconds.
     hop_size_s: Hop size for inference.
     num_tflite_threads: Number of threads to use with TFLite model.
+    class_list_name: Name of the BirdNet class list.
     target_class_list: If provided, restricts logits to this ClassList.
-    model: The TF SavedModel or TFLite interpreter.
-    tflite: Whether the model is a TFLite model.
-    class_list: The loaded class list.
   """
 
   model_path: str
+  model: Any
+  tflite: bool
+  class_list: namespace.ClassList
   window_size_s: float = 3.0
   hop_size_s: float = 3.0
   num_tflite_threads: int = 16
+  class_list_name: str = 'birdnet_v2_1'
   target_class_list: namespace.ClassList | None = None
-  class_list_name: str = 'birdnet_v2_1'
 
-  # The following are populated during init.
-  model: Any | None = None
-  tflite: bool = False
-  class_list: namespace.ClassList | None = None
 
-  def __post_init__(self):
+  @classmethod
+  def from_config(cls, config: config_dict.ConfigDict) -> 'BirdNet':
     logging.info('Loading BirdNet model...')
-    if self.model_path.endswith('.tflite'):
-      self.tflite = True
+    if config.model_path.endswith('.tflite'):
+      tflite = True
       with tempfile.NamedTemporaryFile() as tmpf:
-        model_file = epath.Path(self.model_path)
+        model_file = epath.Path(config.model_path)
         model_file.copy(tmpf.name, overwrite=True)
-        self.model = tf.lite.Interpreter(
-            tmpf.name, num_threads=self.num_tflite_threads
+        model = tf.lite.Interpreter(
+            tmpf.name, num_threads=config.num_tflite_threads
         )
-        self.model.allocate_tensors()
+        model.allocate_tensors()
     else:
-      self.tflite = False
-      self.model = tf.saved_model.load(self.model_path)
+      tflite = False
+      model = tf.saved_model.load(config.model_path)
+    db = namespace_db.load_db()
+    class_list = db.class_lists[config.class_list_name]
+    return cls(
+        sample_rate=config.sample_rate,
+        model_path=config.model_path,
+        class_list_name=config.class_list_name,
+        window_size_s=config.window_size_s,
+        hop_size_s=config.hop_size_s,
+        num_tflite_threads=config.num_tflite_threads,
+        target_class_list=config.target_class_list,
+        model=model,
+        tflite=tflite,
+        class_list=class_list,
+    )
 
   def embed_saved_model(
       self, audio_array: np.ndarray
@@ -470,8 +502,24 @@ class HandcraftedFeaturesModel(interface.EmbeddingModel):
 
   window_size_s: float
   hop_size_s: float
-  melspec_config: config_dict.ConfigDict
-  features_config: config_dict.ConfigDict
+  melspec_layer: frontend.Frontend
+  features_layer: handcrafted_features.HandcraftedFeatures
+
+  @classmethod
+  def from_config(
+      cls, config: config_dict.ConfigDict
+  ) -> 'HandcraftedFeaturesModel':
+    melspec_layer = frontend.MelSpectrogram(**config.melspec_config)
+    features_layer = handcrafted_features.HandcraftedFeatures(
+        **config.features_config
+    )
+    return cls(
+        sample_rate=config.sample_rate,
+        window_size_s=config.window_size_s,
+        hop_size_s=config.hop_size_s,
+        melspec_layer=melspec_layer,
+        features_layer=features_layer,
+    )
 
   @classmethod
   def beans_baseline(cls, sample_rate=32000, frame_rate=100):
@@ -488,20 +536,15 @@ def beans_baseline(cls, sample_rate=32000, frame_rate=100):
         'compute_mfccs': True,
         'aggregation': 'beans',
     })
+    config = config_dict.ConfigDict({
+        'sample_rate': sample_rate,
+        'melspec_config': mel_config,
+        'features_config': features_config,
+        'window_size_s': 1.0,
+        'hop_size_s': 1.0,
+    })
     # pylint: disable=unexpected-keyword-arg
-    return HandcraftedFeaturesModel(
-        sample_rate=sample_rate,
-        window_size_s=1.0,
-        hop_size_s=1.0,
-        melspec_config=mel_config,
-        features_config=features_config,
-    )
-
-  def __post_init__(self):
-    self.melspec_layer = frontend.MelSpectrogram(**self.melspec_config)
-    self.features_layer = handcrafted_features.HandcraftedFeatures(
-        **self.features_config
-    )
+    return HandcraftedFeaturesModel.from_config(config)
 
   def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
     framed_audio = self.frame_audio(
@@ -523,23 +566,42 @@ def batch_embed(self, audio_batch: np.ndarray) -> interface.InferenceOutputs:
 
 class TFHubModel(interface.EmbeddingModel):
   """Generic wrapper for TFHub models which produce embeddings."""
 
+  model: Any  # TFHub loaded model.
   model_url: str
   embedding_index: int
   logits_index: int = -1
 
+  @classmethod
+  def from_config(cls, config: config_dict.ConfigDict) -> 'TFHubModel':
+    model = hub.load(config.model_url)
+    return cls(
+        sample_rate=config.sample_rate,
+        model=model,
+        model_url=config.model_url,
+        embedding_index=config.embedding_index,
+        logits_index=config.logits_index,
+    )
+
   @classmethod
   def yamnet(cls):
     # Parent class takes a sample_rate arg which pylint doesn't find.
-    # pylint: disable=too-many-function-args
-    return TFHubModel(16000, 'https://tfhub.dev/google/yamnet/1', 1, 0)
+    config = config_dict.ConfigDict({
+        'sample_rate': 16000,
+        'model_url': 'https://tfhub.dev/google/yamnet/1',
+        'embedding_index': 1,
+        'logits_index': 0,
+    })
+    return TFHubModel.from_config(config)
 
   @classmethod
   def vggish(cls):
-    # pylint: disable=too-many-function-args
-    return TFHubModel(16000, 'https://tfhub.dev/google/vggish/1', -1)
-
-  def __post_init__(self):
-    self.model = hub.load(self.model_url)
+    config = config_dict.ConfigDict({
+        'sample_rate': 16000,
+        'model_url': 'https://tfhub.dev/google/vggish/1',
+        'embedding_index': -1,
+        'logits_index': -1,
+    })
+    return TFHubModel.from_config(config)
 
   def embed(
       self, audio_array: np.ndarray[Any, np.dtype[Any]]
   )
@@ -577,6 +639,10 @@ class PlaceholderModel(interface.EmbeddingModel):
   window_size_s: float = 1.0
   hop_size_s: float = 1.0
 
+  @classmethod
+  def from_config(cls, config: config_dict.ConfigDict) -> 'PlaceholderModel':
+    return cls(**config)
+
   def __post_init__(self):
     db = namespace_db.load_db()
     self.class_list = db.class_lists['caples']
diff --git a/chirp/tests/inference_test.py b/chirp/tests/inference_test.py
index 7eddca5e..58e7a390 100644
--- a/chirp/tests/inference_test.py
+++ b/chirp/tests/inference_test.py
@@ -212,29 +212,9 @@ def test_load_configs(self, config_name):
     self.assertIsNotNone(config)
 
   def test_handcrafted_features(self):
-    sample_rate = 32000
-    frame_rate = 100
-    mel_config = {
-        'sample_rate': sample_rate,
-        'features': 160,
-        'stride': sample_rate // frame_rate,
-        'kernel_size': int(0.08 * sample_rate),
-        'freq_range': (60.0, sample_rate / 2.0),
-        'scaling_config': frontend.LogScalingConfig(),
-    }
-    features_config = {
-        'compute_mfccs': True,
-        'aggregation': 'beans',
-    }
-    model = models.HandcraftedFeaturesModel(
-        sample_rate=sample_rate,
-        window_size_s=1.0,
-        hop_size_s=1.0,
-        melspec_config=mel_config,
-        features_config=features_config,
-    )
+    model = models.HandcraftedFeaturesModel.beans_baseline()
 
-    audio = np.zeros([5 * sample_rate], dtype=np.float32)
+    audio = np.zeros([5 * 32000], dtype=np.float32)
     outputs = model.embed(audio)
     # Five frames because we have 5s of audio with window 1.0 and hop 1.0.
     # Beans aggregation with mfccs creates 20 MFCC channels, and then computes
@@ -259,11 +239,8 @@ def test_sep_embed_wrapper(self):
         make_separated_audio=False,
         target_class_list=target_class_list,
     )
-    fake_config = config_dict.ConfigDict()
    sep_embed = models.SeparateEmbedModel(
         sample_rate=22050,
-        taxonomy_model_tf_config=fake_config,
-        separator_model_tf_config=fake_config,
         separation_model=separator,
         embedding_model=embeddor,
     )
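
Usage sketch for the new construction path (assumptions: a chirp checkout is
available, and PlaceholderModel's defaults produce embeddings; it is used here
because it needs no checkpoint on disk):

    # Build a model from a ConfigDict via from_config, the path that
    # embed_lib.EmbedFn.setup() now takes. In a pipeline the class is
    # resolved with models.model_class_map()[model_key]; here we call the
    # classmethod directly to keep the example self-contained.
    import numpy as np

    from chirp.inference import models
    from ml_collections import config_dict

    model_config = config_dict.ConfigDict({'sample_rate': 32000})
    model = models.PlaceholderModel.from_config(model_config)

    # Embed five seconds of silence and inspect the placeholder embeddings.
    outputs = model.embed(np.zeros([5 * 32000], dtype=np.float32))
    print(outputs.embeddings.shape)

Because construction now happens in an explicit classmethod, the dataclass
fields hold only fully-built objects, and __post_init__ is reserved for
validation rather than expensive model loading.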