From 0527dc557bf6681d38d808773781ff8ca8e602d7 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Wed, 22 May 2024 12:31:53 -0700 Subject: [PATCH] Make LogitsOutputHead callable. Works around Keras SavedModels that load as non-callable objects by falling back to the serving_default signature. PiperOrigin-RevId: 636256769 --- analysis.ipynb | 10 +++++----- chirp/inference/classify/classify.py | 5 +++-- chirp/inference/interface.py | 25 +++++++++++++++---------- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/analysis.ipynb b/analysis.ipynb index b69c751e..f4934dff 100644 --- a/analysis.ipynb +++ b/analysis.ipynb @@ -120,9 +120,9 @@ " 'model_path': custom_classifier_path,\n", " 'logits_key': 'custom',\n", "})\n", - "loaded_model = interface.LogitsOutputHead.from_config(cfg)\n", - "model = loaded_model.logits_model\n", - "class_list = loaded_model.class_list\n", + "logits_head = interface.LogitsOutputHead.from_config(cfg)\n", + "model = logits_head.logits_model\n", + "class_list = logits_head.class_list\n", "print('Loaded custom model with classes: ')\n", "print('\\t' + '\\n\\t'.join(class_list.classes))" ] @@ -167,7 +167,7 @@ "\n", "classify.write_inference_csv(\n", " embeddings_ds=embeddings_ds,\n", - " model=model,\n", + " model=logits_head,\n", " labels=class_list.classes,\n", " output_filepath=output_filepath,\n", " threshold=class_thresholds,\n", @@ -216,7 +216,7 @@ "\n", "embeddings_ds = project_state.create_embeddings_dataset(shuffle_files=True)\n", "results, all_logits = search.classifer_search_embeddings_parallel(\n", - " embeddings_classifier=model,\n", + " embeddings_classifier=logits_head,\n", " target_index=class_list.classes.index(target_class),\n", " random_sample=True,\n", " top_k=top_k,\n", diff --git a/chirp/inference/classify/classify.py b/chirp/inference/classify/classify.py index 176c8cf1..c20868df 100644 --- a/chirp/inference/classify/classify.py +++ b/chirp/inference/classify/classify.py @@ -18,6 +18,7 @@ import dataclasses from typing import Sequence +from chirp.inference import interface from chirp.inference 
import tf_examples from chirp.inference.classify import data_lib from chirp.models import metrics @@ -156,7 +157,7 @@ def train_embedding_model( def get_inference_dataset( embeddings_ds: tf.data.Dataset, - model: tf.keras.Model, + model: interface.LogitsOutputHead, ): """Create a dataset which includes the model's predictions.""" @@ -182,7 +183,7 @@ def classify_batch(batch): def write_inference_csv( embeddings_ds: tf.data.Dataset, - model: tf.keras.Model, + model: interface.LogitsOutputHead, labels: Sequence[str], output_filepath: str, embedding_hop_size_s: float, diff --git a/chirp/inference/interface.py b/chirp/inference/interface.py index e996497f..a4d415d5 100644 --- a/chirp/inference/interface.py +++ b/chirp/inference/interface.py @@ -202,6 +202,20 @@ def from_config(cls, config: config_dict.ConfigDict): **config, ) + def __call__(self, embeddings: np.ndarray) -> np.ndarray: + """Return logits from the wrapped logits_model for [B, D] embeddings.""" + if callable(self.logits_model): + logits = self.logits_model(embeddings) + elif hasattr(self.logits_model, 'signatures'): + # TODO(tomdenton): Figure out why the Keras saved model isn't callable. + flat_logits = self.logits_model.signatures['serving_default']( + inputs=embeddings + ) + logits = flat_logits['output_0'].numpy() + else: + raise ValueError('could not figure out how to call wrapped model.') + return logits + def save_model(self, output_path: str, embeddings_path: str): """Write a SavedModel and metadata to disk.""" # Write the model. @@ -230,16 +244,7 @@ def add_logits(self, model_outputs: InferenceOutputs, keep_original: bool): logging.warning('No embeddings found in model outputs.') return model_outputs flat_embeddings = np.reshape(embeddings, [-1, embeddings.shape[-1]]) - # TODO(tomdenton): Figure out why the keras saved model isn't callable. 
- if callable(self.logits_model): - flat_logits = self.logits_model(flat_embeddings) - elif hasattr(self.logits_model, 'signatures'): - flat_logits = self.logits_model.signatures['serving_default']( - inputs=flat_embeddings - ) - flat_logits = flat_logits['output_0'].numpy() - else: - raise ValueError('could not figure out how to call wrapped model.') + flat_logits = self(flat_embeddings) logits_shape = np.concatenate( [np.shape(embeddings)[:-1], np.shape(flat_logits)[-1:]], axis=0 )