diff --git a/analysis.ipynb b/analysis.ipynb
index b69c751e..f4934dff 100644
--- a/analysis.ipynb
+++ b/analysis.ipynb
@@ -120,9 +120,9 @@
     " 'model_path': custom_classifier_path,\n",
     " 'logits_key': 'custom',\n",
     "})\n",
-    "loaded_model = interface.LogitsOutputHead.from_config(cfg)\n",
-    "model = loaded_model.logits_model\n",
-    "class_list = loaded_model.class_list\n",
+    "logits_head = interface.LogitsOutputHead.from_config(cfg)\n",
+    "model = logits_head.logits_model\n",
+    "class_list = logits_head.class_list\n",
     "print('Loaded custom model with classes: ')\n",
     "print('\\t' + '\\n\\t'.join(class_list.classes))"
    ]
@@ -167,7 +167,7 @@
     "\n",
     "classify.write_inference_csv(\n",
     "    embeddings_ds=embeddings_ds,\n",
-    "    model=model,\n",
+    "    model=logits_head,\n",
     "    labels=class_list.classes,\n",
     "    output_filepath=output_filepath,\n",
     "    threshold=class_thresholds,\n",
@@ -216,7 +216,7 @@
     "\n",
     "embeddings_ds = project_state.create_embeddings_dataset(shuffle_files=True)\n",
     "results, all_logits = search.classifer_search_embeddings_parallel(\n",
-    "    embeddings_classifier=model,\n",
+    "    embeddings_classifier=logits_head,\n",
     "    target_index=class_list.classes.index(target_class),\n",
     "    random_sample=True,\n",
     "    top_k=top_k,\n",
diff --git a/chirp/inference/classify/classify.py b/chirp/inference/classify/classify.py
index 176c8cf1..c20868df 100644
--- a/chirp/inference/classify/classify.py
+++ b/chirp/inference/classify/classify.py
@@ -18,6 +18,7 @@
 import dataclasses
 from typing import Sequence
 
+from chirp.inference import interface
 from chirp.inference import tf_examples
 from chirp.inference.classify import data_lib
 from chirp.models import metrics
@@ -156,7 +157,7 @@ def train_embedding_model(
 
 def get_inference_dataset(
     embeddings_ds: tf.data.Dataset,
-    model: tf.keras.Model,
+    model: interface.LogitsOutputHead,
 ):
   """Create a dataset which includes the model's predictions."""
 
@@ -182,7 +183,7 @@ def classify_batch(batch):
 
 def write_inference_csv(
     embeddings_ds: tf.data.Dataset,
-    model: tf.keras.Model,
+    model: interface.LogitsOutputHead,
     labels: Sequence[str],
     output_filepath: str,
     embedding_hop_size_s: float,
diff --git a/chirp/inference/interface.py b/chirp/inference/interface.py
index e996497f..a4d415d5 100644
--- a/chirp/inference/interface.py
+++ b/chirp/inference/interface.py
@@ -202,6 +202,20 @@ def from_config(cls, config: config_dict.ConfigDict):
         **config,
     )
 
+  def __call__(self, embeddings: np.ndarray) -> InferenceOutputs:
+    """Apply the wrapped logits_model to embeddings with shape [B, D]."""
+    if callable(self.logits_model):
+      logits = self.logits_model(embeddings)
+    elif hasattr(self.logits_model, 'signatures'):
+      # TODO(tomdenton): Figure out why the Keras saved model isn't callable.
+      flat_logits = self.logits_model.signatures['serving_default'](
+          inputs=embeddings
+      )
+      logits = flat_logits['output_0'].numpy()
+    else:
+      raise ValueError('could not figure out how to call wrapped model.')
+    return logits
+
   def save_model(self, output_path: str, embeddings_path: str):
     """Write a SavedModel and metadata to disk."""
     # Write the model.
@@ -230,16 +244,7 @@ def add_logits(self, model_outputs: InferenceOutputs, keep_original: bool):
       logging.warning('No embeddings found in model outputs.')
       return model_outputs
     flat_embeddings = np.reshape(embeddings, [-1, embeddings.shape[-1]])
-    # TODO(tomdenton): Figure out why the keras saved model isn't callable.
-    if callable(self.logits_model):
-      flat_logits = self.logits_model(flat_embeddings)
-    elif hasattr(self.logits_model, 'signatures'):
-      flat_logits = self.logits_model.signatures['serving_default'](
-          inputs=flat_embeddings
-      )
-      flat_logits = flat_logits['output_0'].numpy()
-    else:
-      raise ValueError('could not figure out how to call wrapped model.')
+    flat_logits = self(flat_embeddings)
     logits_shape = np.concatenate(
         [np.shape(embeddings)[:-1], np.shape(flat_logits)[-1:]], axis=0
     )
diff --git a/chirp/inference/search/search.py b/chirp/inference/search/search.py
index 7e70990c..f3d947d1 100644
--- a/chirp/inference/search/search.py
+++ b/chirp/inference/search/search.py
@@ -21,6 +21,7 @@
 import heapq
 from typing import Any, Callable, List, Sequence
 
+from chirp.inference import interface
 from chirp.inference import tf_examples
 from etils import epath
 import numpy as np
@@ -316,7 +317,7 @@ def search_embeddings_parallel(
 
 
 def classifer_search_embeddings_parallel(
-    embeddings_classifier: tf.keras.Model,
+    embeddings_classifier: interface.LogitsOutputHead,
     target_index: int,
     **kwargs,
 ):
diff --git a/chirp/inference/tests/classify_test.py b/chirp/inference/tests/classify_test.py
index e99afd4e..3475442c 100644
--- a/chirp/inference/tests/classify_test.py
+++ b/chirp/inference/tests/classify_test.py
@@ -15,11 +15,12 @@
 
 """Test small-model classification."""
 
-from collections.abc import Sequence
+import tempfile
 
-from absl import app
+from chirp.inference import interface
 from chirp.inference.classify import classify
 from chirp.inference.classify import data_lib
+from chirp.taxonomy import namespace
 import numpy as np
 
 from absl.testing import absltest
@@ -60,9 +61,10 @@ def test_train_linear_model(self):
     num_classes = 4
     num_points = 100
     model = classify.get_linear_model(embedding_dim, num_classes)
+    rng = np.random.default_rng(42)
     merged = self.make_merged_dataset(
         num_points=num_points,
-        rng=np.random.default_rng(42),
+        rng=rng,
         num_classes=num_classes,
         embedding_dim=embedding_dim,
     )
@@ -76,6 +78,26 @@ def test_train_linear_model(self):
         batch_size=16,
         learning_rate=0.01,
     )
+    query = rng.normal(size=(num_points, embedding_dim)).astype(np.float32)
+
+    logits = model(query)
+
+    # Save and restore the model.
+    class_names = ['a', 'b', 'c', 'd']
+    with tempfile.TemporaryDirectory() as logits_model_dir:
+      logits_model = interface.LogitsOutputHead(
+          model_path=logits_model_dir,
+          logits_key='some_model',
+          logits_model=model,
+          class_list=namespace.ClassList('classes', class_names),
+      )
+      logits_model.save_model(logits_model_dir, '')
+      restored_model = interface.LogitsOutputHead.from_config_file(
+          logits_model_dir
+      )
+      restored_logits = restored_model(query)
+      error = np.abs(restored_logits - logits).sum()
+      self.assertEqual(error, 0)
 
 
 if __name__ == '__main__':
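For context, here is a minimal usage sketch of the save/restore round-trip that the new LogitsOutputHead.__call__ enables, mirroring the test added in classify_test.py. The embedding dimension, batch size, class names, and logits_key value below are illustrative assumptions, not values taken from the diff.

# Sketch only: mirrors the round-trip exercised by the new classify_test.py test.
# embedding_dim, the batch size, the class names, and logits_key are assumptions.
import tempfile

import numpy as np

from chirp.inference import interface
from chirp.inference.classify import classify
from chirp.taxonomy import namespace

embedding_dim = 1280
model = classify.get_linear_model(embedding_dim, 4)
embeddings = np.random.normal(size=(8, embedding_dim)).astype(np.float32)

with tempfile.TemporaryDirectory() as model_dir:
  head = interface.LogitsOutputHead(
      model_path=model_dir,
      logits_key='custom',
      logits_model=model,
      class_list=namespace.ClassList('classes', ['a', 'b', 'c', 'd']),
  )
  head.save_model(model_dir, '')
  restored = interface.LogitsOutputHead.from_config_file(model_dir)
  # The restored head wraps a SavedModel that is not directly callable;
  # __call__ dispatches to its 'serving_default' signature instead.
  logits = restored(embeddings)
  print(logits.shape)  # (8, 4)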