Skip to content

Commit

Permalink
Create embedding models from configurations with a from_config clas…
Browse files Browse the repository at this point in the history
…smethod instead of abusing the dataclass `post_init` function.

PiperOrigin-RevId: 557267692
  • Loading branch information
sdenton4 authored and copybara-github committed Aug 16, 2023
1 parent e4cd2c8 commit f152d4c
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 124 deletions.
2 changes: 2 additions & 0 deletions chirp/inference/configs/separate_soundscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ def get_config() -> config_dict.ConfigDict:
},
'separator_model_tf_config': {
'model_path': sep_model_checkpoint_path,
'window_size_s': 5.0,
'sample_rate': 32000,
'frame_size': 32000,
'target_class_list': None,
},
},
'speech_filter_threshold': 0.95,
Expand Down
3 changes: 3 additions & 0 deletions chirp/inference/configs/separated_seabirds.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_config() -> config_dict.ConfigDict:
'model_key': 'separate_embed_model',
'model_config': {
'sample_rate': 32000,
'embed_raw': True,
'taxonomy_model_tf_config': {
'model_path': emb_model_checkpoint_path,
'window_size_s': 5.0,
Expand All @@ -54,8 +55,10 @@ def get_config() -> config_dict.ConfigDict:
},
'separator_model_tf_config': {
'model_path': sep_model_checkpoint_path,
'window_size_s': 5.0,
'sample_rate': 32000,
'frame_size': 32000,
'target_class_list': None,
},
},
'speech_filter_threshold': 0.0,
Expand Down
5 changes: 2 additions & 3 deletions chirp/inference/embed_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ def __init__(

def setup(self):
if self.embedding_model is None:
self.embedding_model = models.model_class_map()[self.model_key](
**self.model_config
)
model_class = models.model_class_map()[self.model_key]
self.embedding_model = model_class.from_config(self.model_config)
if hasattr(self, 'model_key'):
del self.model_key
if hasattr(self, 'model_config'):
Expand Down
8 changes: 8 additions & 0 deletions chirp/inference/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from chirp.taxonomy import namespace
import librosa
from ml_collections import config_dict
import numpy as np

LogitType = Dict[str, np.ndarray]
Expand Down Expand Up @@ -111,6 +112,13 @@ class EmbeddingModel:

sample_rate: int

@classmethod
def from_config(
cls, model_config: config_dict.ConfigDict
) -> 'EmbeddingModel':
"""Load the model from a configuration dict."""
raise NotImplementedError

def embed(self, audio_array: np.ndarray) -> InferenceOutputs:
"""Create InferenceOutputs from an audio array.
Expand Down
Loading

0 comments on commit f152d4c

Please sign in to comment.