Skip to content

Commit

Permalink
Make LogitsOutputHead callable. Fixes yet another Keras SNAFU.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636256769
  • Loading branch information
sdenton4 authored and copybara-github committed May 22, 2024
1 parent 81f7604 commit 0527dc5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
10 changes: 5 additions & 5 deletions analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@
" 'model_path': custom_classifier_path,\n",
" 'logits_key': 'custom',\n",
"})\n",
"loaded_model = interface.LogitsOutputHead.from_config(cfg)\n",
"model = loaded_model.logits_model\n",
"class_list = loaded_model.class_list\n",
"logits_head = interface.LogitsOutputHead.from_config(cfg)\n",
"model = logits_head.logits_model\n",
"class_list = logits_head.class_list\n",
"print('Loaded custom model with classes: ')\n",
"print('\\t' + '\\n\\t'.join(class_list.classes))"
]
Expand Down Expand Up @@ -167,7 +167,7 @@
"\n",
"classify.write_inference_csv(\n",
" embeddings_ds=embeddings_ds,\n",
" model=model,\n",
" model=logits_head,\n",
" labels=class_list.classes,\n",
" output_filepath=output_filepath,\n",
" threshold=class_thresholds,\n",
Expand Down Expand Up @@ -216,7 +216,7 @@
"\n",
"embeddings_ds = project_state.create_embeddings_dataset(shuffle_files=True)\n",
"results, all_logits = search.classifer_search_embeddings_parallel(\n",
" embeddings_classifier=model,\n",
" embeddings_classifier=logits_head,\n",
" target_index=class_list.classes.index(target_class),\n",
" random_sample=True,\n",
" top_k=top_k,\n",
Expand Down
5 changes: 3 additions & 2 deletions chirp/inference/classify/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
from typing import Sequence

from chirp.inference import interface
from chirp.inference import tf_examples
from chirp.inference.classify import data_lib
from chirp.models import metrics
Expand Down Expand Up @@ -156,7 +157,7 @@ def train_embedding_model(

def get_inference_dataset(
embeddings_ds: tf.data.Dataset,
model: tf.keras.Model,
model: interface.LogitsOutputHead,
):
"""Create a dataset which includes the model's predictions."""

Expand All @@ -182,7 +183,7 @@ def classify_batch(batch):

def write_inference_csv(
embeddings_ds: tf.data.Dataset,
model: tf.keras.Model,
model: interface.LogitsOutputHead,
labels: Sequence[str],
output_filepath: str,
embedding_hop_size_s: float,
Expand Down
25 changes: 15 additions & 10 deletions chirp/inference/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ def from_config(cls, config: config_dict.ConfigDict):
**config,
)

def __call__(self, embeddings: np.ndarray) -> InferenceOutputs:
"""Apply the wrapped logits_model to embeddings with shape [B, D]."""
if callable(self.logits_model):
logits = self.logits_model(embeddings)
elif hasattr(self.logits_model, 'signatures'):
# TODO(tomdenton): Figure out why the Keras saved model isn't callable.
flat_logits = self.logits_model.signatures['serving_default'](
inputs=embeddings
)
logits = flat_logits['output_0'].numpy()
else:
raise ValueError('could not figure out how to call wrapped model.')
return logits

def save_model(self, output_path: str, embeddings_path: str):
"""Write a SavedModel and metadata to disk."""
# Write the model.
Expand Down Expand Up @@ -230,16 +244,7 @@ def add_logits(self, model_outputs: InferenceOutputs, keep_original: bool):
logging.warning('No embeddings found in model outputs.')
return model_outputs
flat_embeddings = np.reshape(embeddings, [-1, embeddings.shape[-1]])
# TODO(tomdenton): Figure out why the keras saved model isn't callable.
if callable(self.logits_model):
flat_logits = self.logits_model(flat_embeddings)
elif hasattr(self.logits_model, 'signatures'):
flat_logits = self.logits_model.signatures['serving_default'](
inputs=flat_embeddings
)
flat_logits = flat_logits['output_0'].numpy()
else:
raise ValueError('could not figure out how to call wrapped model.')
flat_logits = self(flat_embeddings)
logits_shape = np.concatenate(
[np.shape(embeddings)[:-1], np.shape(flat_logits)[-1:]], axis=0
)
Expand Down

0 comments on commit 0527dc5

Please sign in to comment.