Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make LogitsOutputHead callable. Fixes yet another Keras SNAFU. #660

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@
" 'model_path': custom_classifier_path,\n",
" 'logits_key': 'custom',\n",
"})\n",
"loaded_model = interface.LogitsOutputHead.from_config(cfg)\n",
"model = loaded_model.logits_model\n",
"class_list = loaded_model.class_list\n",
"logits_head = interface.LogitsOutputHead.from_config(cfg)\n",
"model = logits_head.logits_model\n",
"class_list = logits_head.class_list\n",
"print('Loaded custom model with classes: ')\n",
"print('\\t' + '\\n\\t'.join(class_list.classes))"
]
Expand Down Expand Up @@ -167,7 +167,7 @@
"\n",
"classify.write_inference_csv(\n",
" embeddings_ds=embeddings_ds,\n",
" model=model,\n",
" model=logits_head,\n",
" labels=class_list.classes,\n",
" output_filepath=output_filepath,\n",
" threshold=class_thresholds,\n",
Expand Down Expand Up @@ -216,7 +216,7 @@
"\n",
"embeddings_ds = project_state.create_embeddings_dataset(shuffle_files=True)\n",
"results, all_logits = search.classifer_search_embeddings_parallel(\n",
" embeddings_classifier=model,\n",
" embeddings_classifier=logits_head,\n",
" target_index=class_list.classes.index(target_class),\n",
" random_sample=True,\n",
" top_k=top_k,\n",
Expand Down
5 changes: 3 additions & 2 deletions chirp/inference/classify/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
from typing import Sequence

from chirp.inference import interface
from chirp.inference import tf_examples
from chirp.inference.classify import data_lib
from chirp.models import metrics
Expand Down Expand Up @@ -156,7 +157,7 @@ def train_embedding_model(

def get_inference_dataset(
embeddings_ds: tf.data.Dataset,
model: tf.keras.Model,
model: interface.LogitsOutputHead,
):
"""Create a dataset which includes the model's predictions."""

Expand All @@ -182,7 +183,7 @@ def classify_batch(batch):

def write_inference_csv(
embeddings_ds: tf.data.Dataset,
model: tf.keras.Model,
model: interface.LogitsOutputHead,
labels: Sequence[str],
output_filepath: str,
embedding_hop_size_s: float,
Expand Down
25 changes: 15 additions & 10 deletions chirp/inference/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ def from_config(cls, config: config_dict.ConfigDict):
**config,
)

def __call__(self, embeddings: np.ndarray) -> InferenceOutputs:
"""Apply the wrapped logits_model to embeddings with shape [B, D]."""
if callable(self.logits_model):
logits = self.logits_model(embeddings)
elif hasattr(self.logits_model, 'signatures'):
# TODO(tomdenton): Figure out why the Keras saved model isn't callable.
flat_logits = self.logits_model.signatures['serving_default'](
inputs=embeddings
)
logits = flat_logits['output_0'].numpy()
else:
raise ValueError('could not figure out how to call wrapped model.')
return logits

def save_model(self, output_path: str, embeddings_path: str):
"""Write a SavedModel and metadata to disk."""
# Write the model.
Expand Down Expand Up @@ -230,16 +244,7 @@ def add_logits(self, model_outputs: InferenceOutputs, keep_original: bool):
logging.warning('No embeddings found in model outputs.')
return model_outputs
flat_embeddings = np.reshape(embeddings, [-1, embeddings.shape[-1]])
# TODO(tomdenton): Figure out why the keras saved model isn't callable.
if callable(self.logits_model):
flat_logits = self.logits_model(flat_embeddings)
elif hasattr(self.logits_model, 'signatures'):
flat_logits = self.logits_model.signatures['serving_default'](
inputs=flat_embeddings
)
flat_logits = flat_logits['output_0'].numpy()
else:
raise ValueError('could not figure out how to call wrapped model.')
flat_logits = self(flat_embeddings)
logits_shape = np.concatenate(
[np.shape(embeddings)[:-1], np.shape(flat_logits)[-1:]], axis=0
)
Expand Down
3 changes: 2 additions & 1 deletion chirp/inference/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import heapq
from typing import Any, Callable, List, Sequence

from chirp.inference import interface
from chirp.inference import tf_examples
from etils import epath
import numpy as np
Expand Down Expand Up @@ -316,7 +317,7 @@ def search_embeddings_parallel(


def classifer_search_embeddings_parallel(
embeddings_classifier: tf.keras.Model,
embeddings_classifier: interface.LogitsOutputHead,
target_index: int,
**kwargs,
):
Expand Down
28 changes: 25 additions & 3 deletions chirp/inference/tests/classify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

"""Test small-model classification."""

from collections.abc import Sequence
import tempfile

from absl import app
from chirp.inference import interface
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib
from chirp.taxonomy import namespace
import numpy as np

from absl.testing import absltest
Expand Down Expand Up @@ -60,9 +61,10 @@ def test_train_linear_model(self):
num_classes = 4
num_points = 100
model = classify.get_linear_model(embedding_dim, num_classes)
rng = np.random.default_rng(42)
merged = self.make_merged_dataset(
num_points=num_points,
rng=np.random.default_rng(42),
rng=rng,
num_classes=num_classes,
embedding_dim=embedding_dim,
)
Expand All @@ -76,6 +78,26 @@ def test_train_linear_model(self):
batch_size=16,
learning_rate=0.01,
)
query = rng.normal(size=(num_points, embedding_dim)).astype(np.float32)

logits = model(query)

# Save and restore the model.
class_names = ['a', 'b', 'c', 'd']
with tempfile.TemporaryDirectory() as logits_model_dir:
logits_model = interface.LogitsOutputHead(
model_path=logits_model_dir,
logits_key='some_model',
logits_model=model,
class_list=namespace.ClassList('classes', class_names),
)
logits_model.save_model(logits_model_dir, '')
restored_model = interface.LogitsOutputHead.from_config_file(
logits_model_dir
)
restored_logits = restored_model(query)
error = np.abs(restored_logits - logits).sum()
self.assertEqual(error, 0)


if __name__ == '__main__':
Expand Down
Loading