Skip to content

Commit

Permalink
Switch to Keras v3 format for nn_ensemble (with legacy h5 support)
Browse files (browse the repository at this point in the history)
  • Loading branch information
osma committed Aug 16, 2023
1 parent e037b78 commit 4886e08
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
16 changes: 12 additions & 4 deletions annif/backend/nn_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBacke

name = "nn_ensemble"

MODEL_FILE = "nn-model.h5"
MODEL_FILE = "nn-model.keras"
MODEL_FILE_FALLBACK = "nn-model.h5"
LMDB_FILE = "nn-train.mdb"

DEFAULT_PARAMETERS = {
Expand All @@ -122,11 +123,18 @@ def initialize(self, parallel: bool = False) -> None:
return
model_filename = os.path.join(self.datadir, self.MODEL_FILE)
if not os.path.exists(model_filename):
raise NotInitializedException(
"model file {} not found".format(model_filename),
backend_id=self.backend_id,
model_filename_fallback = os.path.join(
self.datadir, self.MODEL_FILE_FALLBACK
)
if os.path.exists(model_filename_fallback):
model_filename = model_filename_fallback
else:
raise NotInitializedException(
"model file {} not found".format(model_filename),
backend_id=self.backend_id,
)
self.debug("loading Keras model from {}".format(model_filename))
print("loading Keras model from {}".format(model_filename))
self._model = load_model(
model_filename, custom_objects={"MeanLayer": MeanLayer}
)
Expand Down
53 changes: 47 additions & 6 deletions tests/test_backend_nn_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ def test_nn_ensemble_train_and_learn(registry, tmpdir):
assert nn_ensemble._model.optimizer.learning_rate.value() == 0.001

datadir = py.path.local(project.datadir)
assert datadir.join("nn-model.h5").exists()
assert datadir.join("nn-model.h5").size() > 0
assert datadir.join("nn-model.keras").exists()
assert datadir.join("nn-model.keras").size() > 0

# test online learning
modelfile = datadir.join("nn-model.h5")
modelfile = datadir.join("nn-model.keras")

old_size = modelfile.size()
old_mtime = modelfile.mtime()
Expand All @@ -129,7 +129,7 @@ def test_nn_ensemble_train_cached(registry):
datadir = py.path.local(project.datadir)
assert datadir.join("nn-train.mdb").exists()

datadir.join("nn-model.h5").remove()
datadir.join("nn-model.keras").remove()

nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
nn_ensemble = nn_ensemble_type(
Expand All @@ -140,8 +140,8 @@ def test_nn_ensemble_train_cached(registry):

nn_ensemble.train("cached")

assert datadir.join("nn-model.h5").exists()
assert datadir.join("nn-model.h5").size() > 0
assert datadir.join("nn-model.keras").exists()
assert datadir.join("nn-model.keras").size() > 0


def test_nn_ensemble_train_and_learn_params(registry, tmpdir, capfd):
Expand Down Expand Up @@ -254,3 +254,44 @@ def test_nn_ensemble_suggest(app_project):

assert nn_ensemble._model is not None
assert len(results) > 0


def test_nn_ensemble_h5_fallback(registry, app_project, monkeypatch):
    """Check that a model saved under the legacy .h5 file name is loaded.

    The backend is trained from cached data while its MODEL_FILE attribute
    is temporarily patched to "nn-model.h5". Any "nn-model.keras" file is
    then removed and the in-memory model is dropped, so the following
    suggest() call has to load the model from the .h5 file.
    """
    nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
    nn_ensemble = nn_ensemble_type(
        backend_id="nn_ensemble",
        config_params={"sources": "dummy-en"},
        project=app_project,
    )

    datadir = py.path.local(app_project.datadir)

    # train model from cached data and save it in legacy HDF5 format;
    # the patch is undone when the context exits, so MODEL_FILE has its
    # normal value again for the loading step below
    with monkeypatch.context() as m:
        m.setattr(nn_ensemble, "MODEL_FILE", "nn-model.h5")
        nn_ensemble.train("cached")

    # remove any existing .keras model; guard the call so the test does not
    # fail when no earlier test has left such a file behind
    keras_model = datadir.join("nn-model.keras")
    if keras_model.exists():
        keras_model.remove()

    # check that the trained model was saved in a HDF5 file
    h5_model = datadir.join("nn-model.h5")
    assert h5_model.exists()
    assert h5_model.size() > 0

    # delete the model, so it has to be loaded again
    nn_ensemble._model = None

    # now try using suggest, which forces loading it from .h5 file
    results = nn_ensemble.suggest(
        [
            """Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen
tiede tai oikeammin joukko tieteitä, jotka tutkivat ihmisen
menneisyyttä. Tutkimusta tehdään analysoimalla muinaisjäännöksiä
eli niitä jälkiä, joita ihmisten toiminta on jättänyt maaperään
tai vesistöjen pohjaan."""
        ]
    )[0]

    assert nn_ensemble._model is not None
    assert len(results) > 0

0 comments on commit 4886e08

Please sign in to comment.