Skip to content

Commit

Permalink
Create embedding models from configurations with a from_config clas…
Browse files Browse the repository at this point in the history
…smethod instead of abusing the dataclass `post_init` function.

PiperOrigin-RevId: 557267692
  • Loading branch information
sdenton4 authored and copybara-github committed Aug 15, 2023
1 parent e4cd2c8 commit b460a89
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 99 deletions.
5 changes: 2 additions & 3 deletions chirp/inference/embed_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ def __init__(

def setup(self):
if self.embedding_model is None:
self.embedding_model = models.model_class_map()[self.model_key](
**self.model_config
)
model_class = models.model_class_map()[self.model_key]
self.embedding_model = model_class.from_config(**self.model_config)
if hasattr(self, 'model_key'):
del self.model_key
if hasattr(self, 'model_config'):
Expand Down
8 changes: 8 additions & 0 deletions chirp/inference/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from chirp.taxonomy import namespace
import librosa
from ml_collections import config_dict
import numpy as np

LogitType = Dict[str, np.ndarray]
Expand Down Expand Up @@ -111,6 +112,13 @@ class EmbeddingModel:

sample_rate: int

@classmethod
def from_config(
cls, model_config: config_dict.ConfigDict
) -> 'EmbeddingModel':
"""Load the model from a configuration dict."""
raise NotImplementedError

def embed(self, audio_array: np.ndarray) -> InferenceOutputs:
"""Create InferenceOutputs from an audio array.
Expand Down
Loading

0 comments on commit b460a89

Please sign in to comment.